implement primary keys
This commit is contained in:
parent
ef9a3053d1
commit
7e05915540
|
@ -18,7 +18,7 @@ class User(Record):
|
|||
def fields(cls):
|
||||
return [
|
||||
*super().fields(),
|
||||
Field("name"),
|
||||
Field("name", primary_key=True),
|
||||
Integer("number", default=0),
|
||||
Field("email", unique=True),
|
||||
Password("password"),
|
||||
|
@ -33,7 +33,7 @@ class Group(Record):
|
|||
def fields(cls):
|
||||
return [
|
||||
*super().fields(),
|
||||
Field("name", unique=True),
|
||||
Field("name", primary_key=True),
|
||||
Collection("members", User),
|
||||
Collection("groups", Group),
|
||||
BackReference("parent", Group),
|
||||
|
@ -43,8 +43,8 @@ class Group(Record):
|
|||
class Album(Record):
|
||||
@classmethod
|
||||
def fields(cls):
|
||||
return [
|
||||
*super().fields(),
|
||||
inherited = [f for f in super().fields() if f.name != "name"]
|
||||
return inherited + [
|
||||
Field("name"),
|
||||
Dict("credits"),
|
||||
List("tracks"),
|
||||
|
@ -55,8 +55,8 @@ class Album(Record):
|
|||
class Artist(User):
|
||||
@classmethod
|
||||
def fields(cls):
|
||||
return [
|
||||
*super().fields(),
|
||||
inherited = [f for f in super().fields() if f.name != "name"]
|
||||
return inherited + [
|
||||
Field("name"),
|
||||
RecordDict("albums", Album),
|
||||
]
|
||||
|
|
|
@ -14,7 +14,7 @@ from tinydb import TinyDB, where
|
|||
|
||||
from grung.exceptions import PointerReferenceError
|
||||
|
||||
Metadata = namedtuple("Metadata", ["table", "fields", "backrefs"])
|
||||
Metadata = namedtuple("Metadata", ["table", "fields", "backrefs", "primary_key"])
|
||||
|
||||
|
||||
@dataclass
|
||||
|
@ -27,6 +27,7 @@ class Field:
|
|||
value_type: type = str
|
||||
default: str = None
|
||||
unique: bool = False
|
||||
primary_key: bool = False
|
||||
|
||||
def before_insert(self, value: value_type, db: TinyDB, record: Record) -> None:
|
||||
pass
|
||||
|
@ -71,7 +72,7 @@ class List(Field):
|
|||
def serialize(self, values: list) -> Dict[(str, str)]:
|
||||
return values
|
||||
|
||||
def deserialize(self, values: lsit, db: TinyDB, recurse: bool = False) -> typing.List[str]:
|
||||
def deserialize(self, values: list, db: TinyDB, recurse: bool = False) -> typing.List[str]:
|
||||
return values
|
||||
|
||||
|
||||
|
@ -150,8 +151,22 @@ class Record(typing.Dict[(str, Field)]):
|
|||
def __init__(self, raw_doc: dict = {}, doc_id: int = None, **params):
|
||||
self.doc_id = doc_id
|
||||
fields = self.__class__.fields()
|
||||
|
||||
pkey = [field for field in fields if field.primary_key]
|
||||
if len(pkey) > 1:
|
||||
raise Exception(f"Cannnot have more than one primary key: {pkey}")
|
||||
elif pkey:
|
||||
pkey = pkey[0]
|
||||
else:
|
||||
# 1% collision rate at ~2M records
|
||||
pkey = Field("uid", default=nanoid.generate(size=8), primary_key=True)
|
||||
fields.append(pkey)
|
||||
|
||||
pkey.unique = True
|
||||
|
||||
self._metadata = Metadata(
|
||||
table=self.__class__.__name__,
|
||||
primary_key=pkey.name,
|
||||
fields={f.name: f for f in fields},
|
||||
backrefs=lambda value_type: (
|
||||
field for field in fields if type(field) == BackReference and field.value_type == value_type
|
||||
|
@ -160,11 +175,8 @@ class Record(typing.Dict[(str, Field)]):
|
|||
super().__init__(dict({field.name: field.default for field in fields}, **raw_doc, **params))
|
||||
|
||||
@classmethod
|
||||
def fields(self):
|
||||
return [
|
||||
# 1% collision rate at ~2M records
|
||||
Field("uid", default=nanoid.generate(size=8), unique=True)
|
||||
]
|
||||
def fields(cls):
|
||||
return []
|
||||
|
||||
def serialize(self):
|
||||
"""
|
||||
|
@ -204,7 +216,7 @@ class Record(typing.Dict[(str, Field)]):
|
|||
def __getattr__(self, attr_name):
|
||||
if attr_name in self:
|
||||
return self.get(attr_name)
|
||||
return super().__getattr__(attr_name)
|
||||
raise AttributeError(f"No such attribute: {attr_name}")
|
||||
|
||||
def __hash__(self):
|
||||
return hash(str(dict(self)))
|
||||
|
@ -242,7 +254,7 @@ class Pointer(Field):
|
|||
if value:
|
||||
if not value.doc_id:
|
||||
raise PointerReferenceError(value)
|
||||
return f"{value._metadata.table}::{value.uid}"
|
||||
return f"{value._metadata.table}::{value._metadata.primary_key}::{value[value._metadata.primary_key]}"
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
|
@ -250,11 +262,12 @@ class Pointer(Field):
|
|||
if not value:
|
||||
return
|
||||
elif type(value) == str:
|
||||
pt, puid = value.split("::")
|
||||
if puid:
|
||||
rec = db.table(pt).get(where("uid") == puid, recurse=recurse)
|
||||
table_name, pkey, pval = value.split("::")
|
||||
if pval:
|
||||
table = db.table(table_name)
|
||||
rec = table.get(where(pkey) == pval, recurse=recurse)
|
||||
if not rec:
|
||||
raise PointerReferenceError(f"Expected a {pt} with uid=={puid} but did not find one!")
|
||||
raise PointerReferenceError(f"Expected a {table_name} with {pkey}=={pval} but did not find one!")
|
||||
return rec
|
||||
return value
|
||||
|
||||
|
|
|
@ -20,8 +20,8 @@ def db():
|
|||
|
||||
def test_crud(db):
|
||||
user = examples.User(name="john", number=23, email="john@foo")
|
||||
assert user.uid
|
||||
assert user._metadata.fields["uid"].unique
|
||||
assert user._metadata.fields[user._metadata.primary_key].unique
|
||||
assert user._metadata.fields[user._metadata.primary_key].primary_key
|
||||
|
||||
# insert
|
||||
john_something = db.save(user)
|
||||
|
@ -32,7 +32,6 @@ def test_crud(db):
|
|||
assert john_something.name == user.name
|
||||
assert john_something.number == 23
|
||||
assert john_something.email == user.email
|
||||
assert john_something.uid == user.uid
|
||||
|
||||
# update
|
||||
john_something.name = "james?"
|
||||
|
@ -57,9 +56,9 @@ def test_pointers(db):
|
|||
players = db.save(examples.Group(name="players", members=[user]))
|
||||
|
||||
user = db.table("User").get(doc_id=user.doc_id)
|
||||
assert user.groups.uid == players.uid
|
||||
assert user.groups.name == players.name
|
||||
|
||||
assert players.members[0].groups.uid == players.uid
|
||||
assert players.members[0].groups.name == players.name
|
||||
|
||||
|
||||
def test_subgroups(db):
|
||||
|
@ -76,8 +75,8 @@ def test_subgroups(db):
|
|||
assert snw in trek.groups
|
||||
|
||||
assert trek.parent is None
|
||||
assert tos.parent.uid == trek.uid
|
||||
assert snw.parent.uid == trek.uid
|
||||
assert tos.parent.name == trek.name
|
||||
assert snw.parent.name == trek.name
|
||||
|
||||
unique_users = set([user for group in trek.groups for user in group.members])
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user