From 7e05915540f6867d9b54924c0573a16f76a0a917 Mon Sep 17 00:00:00 2001 From: evilchili Date: Wed, 8 Oct 2025 00:25:02 -0700 Subject: [PATCH] implement primary keys --- src/grung/examples.py | 12 ++++++------ src/grung/types.py | 39 ++++++++++++++++++++++++++------------- test/test_db.py | 13 ++++++------- 3 files changed, 38 insertions(+), 26 deletions(-) diff --git a/src/grung/examples.py b/src/grung/examples.py index 4e8af47..904e392 100644 --- a/src/grung/examples.py +++ b/src/grung/examples.py @@ -18,7 +18,7 @@ class User(Record): def fields(cls): return [ *super().fields(), - Field("name"), + Field("name", primary_key=True), Integer("number", default=0), Field("email", unique=True), Password("password"), @@ -33,7 +33,7 @@ class Group(Record): def fields(cls): return [ *super().fields(), - Field("name", unique=True), + Field("name", primary_key=True), Collection("members", User), Collection("groups", Group), BackReference("parent", Group), @@ -43,8 +43,8 @@ class Group(Record): class Album(Record): @classmethod def fields(cls): - return [ - *super().fields(), + inherited = [f for f in super().fields() if f.name != "name"] + return inherited + [ Field("name"), Dict("credits"), List("tracks"), @@ -55,8 +55,8 @@ class Album(Record): class Artist(User): @classmethod def fields(cls): - return [ - *super().fields(), + inherited = [f for f in super().fields() if f.name != "name"] + return inherited + [ Field("name"), RecordDict("albums", Album), ] diff --git a/src/grung/types.py b/src/grung/types.py index 4ebdcad..0af2fcd 100644 --- a/src/grung/types.py +++ b/src/grung/types.py @@ -14,7 +14,7 @@ from tinydb import TinyDB, where from grung.exceptions import PointerReferenceError -Metadata = namedtuple("Metadata", ["table", "fields", "backrefs"]) +Metadata = namedtuple("Metadata", ["table", "fields", "backrefs", "primary_key"]) @dataclass @@ -27,6 +27,7 @@ class Field: value_type: type = str default: str = None unique: bool = False + primary_key: bool = False def before_insert(self, value: value_type, db: TinyDB, record: Record) -> None: pass @@ -71,7 +72,7 @@ class List(Field): def serialize(self, values: list) -> Dict[(str, str)]: return values - def deserialize(self, values: lsit, db: TinyDB, recurse: bool = False) -> typing.List[str]: + def deserialize(self, values: list, db: TinyDB, recurse: bool = False) -> typing.List[str]: return values @@ -150,8 +151,22 @@ class Record(typing.Dict[(str, Field)]): def __init__(self, raw_doc: dict = {}, doc_id: int = None, **params): self.doc_id = doc_id fields = self.__class__.fields() + + pkey = [field for field in fields if field.primary_key] + if len(pkey) > 1: + raise Exception(f"Cannnot have more than one primary key: {pkey}") + elif pkey: + pkey = pkey[0] + else: + # 1% collision rate at ~2M records + pkey = Field("uid", default=nanoid.generate(size=8), primary_key=True) + fields.append(pkey) + + pkey.unique = True + self._metadata = Metadata( table=self.__class__.__name__, + primary_key=pkey.name, fields={f.name: f for f in fields}, backrefs=lambda value_type: ( field for field in fields if type(field) == BackReference and field.value_type == value_type @@ -160,11 +175,8 @@ class Record(typing.Dict[(str, Field)]): super().__init__(dict({field.name: field.default for field in fields}, **raw_doc, **params)) @classmethod - def fields(self): - return [ - # 1% collision rate at ~2M records - Field("uid", default=nanoid.generate(size=8), unique=True) - ] + def fields(cls): + return [] def serialize(self): """ @@ -204,7 +216,7 @@ class Record(typing.Dict[(str, Field)]): def __getattr__(self, attr_name): if attr_name in self: return self.get(attr_name) - return super().__getattr__(attr_name) + raise AttributeError(f"No such attribute: {attr_name}") def __hash__(self): return hash(str(dict(self))) @@ -242,7 +254,7 @@ class Pointer(Field): if value: if not value.doc_id: raise PointerReferenceError(value) - return f"{value._metadata.table}::{value.uid}" + return f"{value._metadata.table}::{value._metadata.primary_key}::{value[value._metadata.primary_key]}" return None @classmethod @@ -250,11 +262,12 @@ class Pointer(Field): if not value: return elif type(value) == str: - pt, puid = value.split("::") - if puid: - rec = db.table(pt).get(where("uid") == puid, recurse=recurse) + table_name, pkey, pval = value.split("::") + if pval: + table = db.table(table_name) + rec = table.get(where(pkey) == pval, recurse=recurse) if not rec: - raise PointerReferenceError(f"Expected a {pt} with uid=={puid} but did not find one!") + raise PointerReferenceError(f"Expected a {table_name} with {pkey}=={pval} but did not find one!") return rec return value diff --git a/test/test_db.py b/test/test_db.py index 3b931c9..054ee8b 100644 --- a/test/test_db.py +++ b/test/test_db.py @@ -20,8 +20,8 @@ def db(): def test_crud(db): user = examples.User(name="john", number=23, email="john@foo") - assert user.uid - assert user._metadata.fields["uid"].unique + assert user._metadata.fields[user._metadata.primary_key].unique + assert user._metadata.fields[user._metadata.primary_key].primary_key # insert john_something = db.save(user) @@ -32,7 +32,6 @@ def test_crud(db): assert john_something.name == user.name assert john_something.number == 23 assert john_something.email == user.email - assert john_something.uid == user.uid # update john_something.name = "james?" @@ -57,9 +56,9 @@ def test_pointers(db): players = db.save(examples.Group(name="players", members=[user])) user = db.table("User").get(doc_id=user.doc_id) - assert user.groups.uid == players.uid + assert user.groups.name == players.name - assert players.members[0].groups.uid == players.uid + assert players.members[0].groups.name == players.name def test_subgroups(db): @@ -76,8 +75,8 @@ def test_subgroups(db): assert snw in trek.groups assert trek.parent is None - assert tos.parent.uid == trek.uid - assert snw.parent.uid == trek.uid + assert tos.parent.name == trek.name + assert snw.parent.name == trek.name unique_users = set([user for group in trek.groups for user in group.members])