implement primary keys
This commit is contained in:
parent
ef9a3053d1
commit
7e05915540
|
@ -18,7 +18,7 @@ class User(Record):
|
||||||
def fields(cls):
|
def fields(cls):
|
||||||
return [
|
return [
|
||||||
*super().fields(),
|
*super().fields(),
|
||||||
Field("name"),
|
Field("name", primary_key=True),
|
||||||
Integer("number", default=0),
|
Integer("number", default=0),
|
||||||
Field("email", unique=True),
|
Field("email", unique=True),
|
||||||
Password("password"),
|
Password("password"),
|
||||||
|
@ -33,7 +33,7 @@ class Group(Record):
|
||||||
def fields(cls):
|
def fields(cls):
|
||||||
return [
|
return [
|
||||||
*super().fields(),
|
*super().fields(),
|
||||||
Field("name", unique=True),
|
Field("name", primary_key=True),
|
||||||
Collection("members", User),
|
Collection("members", User),
|
||||||
Collection("groups", Group),
|
Collection("groups", Group),
|
||||||
BackReference("parent", Group),
|
BackReference("parent", Group),
|
||||||
|
@ -43,8 +43,8 @@ class Group(Record):
|
||||||
class Album(Record):
|
class Album(Record):
|
||||||
@classmethod
|
@classmethod
|
||||||
def fields(cls):
|
def fields(cls):
|
||||||
return [
|
inherited = [f for f in super().fields() if f.name != "name"]
|
||||||
*super().fields(),
|
return inherited + [
|
||||||
Field("name"),
|
Field("name"),
|
||||||
Dict("credits"),
|
Dict("credits"),
|
||||||
List("tracks"),
|
List("tracks"),
|
||||||
|
@ -55,8 +55,8 @@ class Album(Record):
|
||||||
class Artist(User):
|
class Artist(User):
|
||||||
@classmethod
|
@classmethod
|
||||||
def fields(cls):
|
def fields(cls):
|
||||||
return [
|
inherited = [f for f in super().fields() if f.name != "name"]
|
||||||
*super().fields(),
|
return inherited + [
|
||||||
Field("name"),
|
Field("name"),
|
||||||
RecordDict("albums", Album),
|
RecordDict("albums", Album),
|
||||||
]
|
]
|
||||||
|
|
|
@ -14,7 +14,7 @@ from tinydb import TinyDB, where
|
||||||
|
|
||||||
from grung.exceptions import PointerReferenceError
|
from grung.exceptions import PointerReferenceError
|
||||||
|
|
||||||
Metadata = namedtuple("Metadata", ["table", "fields", "backrefs"])
|
Metadata = namedtuple("Metadata", ["table", "fields", "backrefs", "primary_key"])
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
|
@ -27,6 +27,7 @@ class Field:
|
||||||
value_type: type = str
|
value_type: type = str
|
||||||
default: str = None
|
default: str = None
|
||||||
unique: bool = False
|
unique: bool = False
|
||||||
|
primary_key: bool = False
|
||||||
|
|
||||||
def before_insert(self, value: value_type, db: TinyDB, record: Record) -> None:
|
def before_insert(self, value: value_type, db: TinyDB, record: Record) -> None:
|
||||||
pass
|
pass
|
||||||
|
@ -71,7 +72,7 @@ class List(Field):
|
||||||
def serialize(self, values: list) -> Dict[(str, str)]:
|
def serialize(self, values: list) -> Dict[(str, str)]:
|
||||||
return values
|
return values
|
||||||
|
|
||||||
def deserialize(self, values: lsit, db: TinyDB, recurse: bool = False) -> typing.List[str]:
|
def deserialize(self, values: list, db: TinyDB, recurse: bool = False) -> typing.List[str]:
|
||||||
return values
|
return values
|
||||||
|
|
||||||
|
|
||||||
|
@ -150,8 +151,22 @@ class Record(typing.Dict[(str, Field)]):
|
||||||
def __init__(self, raw_doc: dict = {}, doc_id: int = None, **params):
|
def __init__(self, raw_doc: dict = {}, doc_id: int = None, **params):
|
||||||
self.doc_id = doc_id
|
self.doc_id = doc_id
|
||||||
fields = self.__class__.fields()
|
fields = self.__class__.fields()
|
||||||
|
|
||||||
|
pkey = [field for field in fields if field.primary_key]
|
||||||
|
if len(pkey) > 1:
|
||||||
|
raise Exception(f"Cannnot have more than one primary key: {pkey}")
|
||||||
|
elif pkey:
|
||||||
|
pkey = pkey[0]
|
||||||
|
else:
|
||||||
|
# 1% collision rate at ~2M records
|
||||||
|
pkey = Field("uid", default=nanoid.generate(size=8), primary_key=True)
|
||||||
|
fields.append(pkey)
|
||||||
|
|
||||||
|
pkey.unique = True
|
||||||
|
|
||||||
self._metadata = Metadata(
|
self._metadata = Metadata(
|
||||||
table=self.__class__.__name__,
|
table=self.__class__.__name__,
|
||||||
|
primary_key=pkey.name,
|
||||||
fields={f.name: f for f in fields},
|
fields={f.name: f for f in fields},
|
||||||
backrefs=lambda value_type: (
|
backrefs=lambda value_type: (
|
||||||
field for field in fields if type(field) == BackReference and field.value_type == value_type
|
field for field in fields if type(field) == BackReference and field.value_type == value_type
|
||||||
|
@ -160,11 +175,8 @@ class Record(typing.Dict[(str, Field)]):
|
||||||
super().__init__(dict({field.name: field.default for field in fields}, **raw_doc, **params))
|
super().__init__(dict({field.name: field.default for field in fields}, **raw_doc, **params))
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def fields(self):
|
def fields(cls):
|
||||||
return [
|
return []
|
||||||
# 1% collision rate at ~2M records
|
|
||||||
Field("uid", default=nanoid.generate(size=8), unique=True)
|
|
||||||
]
|
|
||||||
|
|
||||||
def serialize(self):
|
def serialize(self):
|
||||||
"""
|
"""
|
||||||
|
@ -204,7 +216,7 @@ class Record(typing.Dict[(str, Field)]):
|
||||||
def __getattr__(self, attr_name):
|
def __getattr__(self, attr_name):
|
||||||
if attr_name in self:
|
if attr_name in self:
|
||||||
return self.get(attr_name)
|
return self.get(attr_name)
|
||||||
return super().__getattr__(attr_name)
|
raise AttributeError(f"No such attribute: {attr_name}")
|
||||||
|
|
||||||
def __hash__(self):
|
def __hash__(self):
|
||||||
return hash(str(dict(self)))
|
return hash(str(dict(self)))
|
||||||
|
@ -242,7 +254,7 @@ class Pointer(Field):
|
||||||
if value:
|
if value:
|
||||||
if not value.doc_id:
|
if not value.doc_id:
|
||||||
raise PointerReferenceError(value)
|
raise PointerReferenceError(value)
|
||||||
return f"{value._metadata.table}::{value.uid}"
|
return f"{value._metadata.table}::{value._metadata.primary_key}::{value[value._metadata.primary_key]}"
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
@ -250,11 +262,12 @@ class Pointer(Field):
|
||||||
if not value:
|
if not value:
|
||||||
return
|
return
|
||||||
elif type(value) == str:
|
elif type(value) == str:
|
||||||
pt, puid = value.split("::")
|
table_name, pkey, pval = value.split("::")
|
||||||
if puid:
|
if pval:
|
||||||
rec = db.table(pt).get(where("uid") == puid, recurse=recurse)
|
table = db.table(table_name)
|
||||||
|
rec = table.get(where(pkey) == pval, recurse=recurse)
|
||||||
if not rec:
|
if not rec:
|
||||||
raise PointerReferenceError(f"Expected a {pt} with uid=={puid} but did not find one!")
|
raise PointerReferenceError(f"Expected a {table_name} with {pkey}=={pval} but did not find one!")
|
||||||
return rec
|
return rec
|
||||||
return value
|
return value
|
||||||
|
|
||||||
|
|
|
@ -20,8 +20,8 @@ def db():
|
||||||
|
|
||||||
def test_crud(db):
|
def test_crud(db):
|
||||||
user = examples.User(name="john", number=23, email="john@foo")
|
user = examples.User(name="john", number=23, email="john@foo")
|
||||||
assert user.uid
|
assert user._metadata.fields[user._metadata.primary_key].unique
|
||||||
assert user._metadata.fields["uid"].unique
|
assert user._metadata.fields[user._metadata.primary_key].primary_key
|
||||||
|
|
||||||
# insert
|
# insert
|
||||||
john_something = db.save(user)
|
john_something = db.save(user)
|
||||||
|
@ -32,7 +32,6 @@ def test_crud(db):
|
||||||
assert john_something.name == user.name
|
assert john_something.name == user.name
|
||||||
assert john_something.number == 23
|
assert john_something.number == 23
|
||||||
assert john_something.email == user.email
|
assert john_something.email == user.email
|
||||||
assert john_something.uid == user.uid
|
|
||||||
|
|
||||||
# update
|
# update
|
||||||
john_something.name = "james?"
|
john_something.name = "james?"
|
||||||
|
@ -57,9 +56,9 @@ def test_pointers(db):
|
||||||
players = db.save(examples.Group(name="players", members=[user]))
|
players = db.save(examples.Group(name="players", members=[user]))
|
||||||
|
|
||||||
user = db.table("User").get(doc_id=user.doc_id)
|
user = db.table("User").get(doc_id=user.doc_id)
|
||||||
assert user.groups.uid == players.uid
|
assert user.groups.name == players.name
|
||||||
|
|
||||||
assert players.members[0].groups.uid == players.uid
|
assert players.members[0].groups.name == players.name
|
||||||
|
|
||||||
|
|
||||||
def test_subgroups(db):
|
def test_subgroups(db):
|
||||||
|
@ -76,8 +75,8 @@ def test_subgroups(db):
|
||||||
assert snw in trek.groups
|
assert snw in trek.groups
|
||||||
|
|
||||||
assert trek.parent is None
|
assert trek.parent is None
|
||||||
assert tos.parent.uid == trek.uid
|
assert tos.parent.name == trek.name
|
||||||
assert snw.parent.uid == trek.uid
|
assert snw.parent.name == trek.name
|
||||||
|
|
||||||
unique_users = set([user for group in trek.groups for user in group.members])
|
unique_users = set([user for group in trek.groups for user in group.members])
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user