implement primary keys

This commit is contained in:
evilchili 2025-10-08 00:25:02 -07:00
parent ef9a3053d1
commit 7e05915540
3 changed files with 38 additions and 26 deletions

View File

@ -18,7 +18,7 @@ class User(Record):
def fields(cls): def fields(cls):
return [ return [
*super().fields(), *super().fields(),
Field("name"), Field("name", primary_key=True),
Integer("number", default=0), Integer("number", default=0),
Field("email", unique=True), Field("email", unique=True),
Password("password"), Password("password"),
@ -33,7 +33,7 @@ class Group(Record):
def fields(cls): def fields(cls):
return [ return [
*super().fields(), *super().fields(),
Field("name", unique=True), Field("name", primary_key=True),
Collection("members", User), Collection("members", User),
Collection("groups", Group), Collection("groups", Group),
BackReference("parent", Group), BackReference("parent", Group),
@ -43,8 +43,8 @@ class Group(Record):
class Album(Record): class Album(Record):
@classmethod @classmethod
def fields(cls): def fields(cls):
return [ inherited = [f for f in super().fields() if f.name != "name"]
*super().fields(), return inherited + [
Field("name"), Field("name"),
Dict("credits"), Dict("credits"),
List("tracks"), List("tracks"),
@ -55,8 +55,8 @@ class Album(Record):
class Artist(User): class Artist(User):
@classmethod @classmethod
def fields(cls): def fields(cls):
return [ inherited = [f for f in super().fields() if f.name != "name"]
*super().fields(), return inherited + [
Field("name"), Field("name"),
RecordDict("albums", Album), RecordDict("albums", Album),
] ]

View File

@ -14,7 +14,7 @@ from tinydb import TinyDB, where
from grung.exceptions import PointerReferenceError from grung.exceptions import PointerReferenceError
Metadata = namedtuple("Metadata", ["table", "fields", "backrefs"]) Metadata = namedtuple("Metadata", ["table", "fields", "backrefs", "primary_key"])
@dataclass @dataclass
@ -27,6 +27,7 @@ class Field:
value_type: type = str value_type: type = str
default: str = None default: str = None
unique: bool = False unique: bool = False
primary_key: bool = False
def before_insert(self, value: value_type, db: TinyDB, record: Record) -> None: def before_insert(self, value: value_type, db: TinyDB, record: Record) -> None:
pass pass
@ -71,7 +72,7 @@ class List(Field):
def serialize(self, values: list) -> Dict[(str, str)]: def serialize(self, values: list) -> Dict[(str, str)]:
return values return values
def deserialize(self, values: lsit, db: TinyDB, recurse: bool = False) -> typing.List[str]: def deserialize(self, values: list, db: TinyDB, recurse: bool = False) -> typing.List[str]:
return values return values
@ -150,8 +151,22 @@ class Record(typing.Dict[(str, Field)]):
def __init__(self, raw_doc: dict = {}, doc_id: int = None, **params): def __init__(self, raw_doc: dict = {}, doc_id: int = None, **params):
self.doc_id = doc_id self.doc_id = doc_id
fields = self.__class__.fields() fields = self.__class__.fields()
pkey = [field for field in fields if field.primary_key]
if len(pkey) > 1:
raise Exception(f"Cannnot have more than one primary key: {pkey}")
elif pkey:
pkey = pkey[0]
else:
# 1% collision rate at ~2M records
pkey = Field("uid", default=nanoid.generate(size=8), primary_key=True)
fields.append(pkey)
pkey.unique = True
self._metadata = Metadata( self._metadata = Metadata(
table=self.__class__.__name__, table=self.__class__.__name__,
primary_key=pkey.name,
fields={f.name: f for f in fields}, fields={f.name: f for f in fields},
backrefs=lambda value_type: ( backrefs=lambda value_type: (
field for field in fields if type(field) == BackReference and field.value_type == value_type field for field in fields if type(field) == BackReference and field.value_type == value_type
@ -160,11 +175,8 @@ class Record(typing.Dict[(str, Field)]):
super().__init__(dict({field.name: field.default for field in fields}, **raw_doc, **params)) super().__init__(dict({field.name: field.default for field in fields}, **raw_doc, **params))
@classmethod @classmethod
def fields(self): def fields(cls):
return [ return []
# 1% collision rate at ~2M records
Field("uid", default=nanoid.generate(size=8), unique=True)
]
def serialize(self): def serialize(self):
""" """
@ -204,7 +216,7 @@ class Record(typing.Dict[(str, Field)]):
def __getattr__(self, attr_name): def __getattr__(self, attr_name):
if attr_name in self: if attr_name in self:
return self.get(attr_name) return self.get(attr_name)
return super().__getattr__(attr_name) raise AttributeError(f"No such attribute: {attr_name}")
def __hash__(self): def __hash__(self):
return hash(str(dict(self))) return hash(str(dict(self)))
@ -242,7 +254,7 @@ class Pointer(Field):
if value: if value:
if not value.doc_id: if not value.doc_id:
raise PointerReferenceError(value) raise PointerReferenceError(value)
return f"{value._metadata.table}::{value.uid}" return f"{value._metadata.table}::{value._metadata.primary_key}::{value[value._metadata.primary_key]}"
return None return None
@classmethod @classmethod
@ -250,11 +262,12 @@ class Pointer(Field):
if not value: if not value:
return return
elif type(value) == str: elif type(value) == str:
pt, puid = value.split("::") table_name, pkey, pval = value.split("::")
if puid: if pval:
rec = db.table(pt).get(where("uid") == puid, recurse=recurse) table = db.table(table_name)
rec = table.get(where(pkey) == pval, recurse=recurse)
if not rec: if not rec:
raise PointerReferenceError(f"Expected a {pt} with uid=={puid} but did not find one!") raise PointerReferenceError(f"Expected a {table_name} with {pkey}=={pval} but did not find one!")
return rec return rec
return value return value

View File

@ -20,8 +20,8 @@ def db():
def test_crud(db): def test_crud(db):
user = examples.User(name="john", number=23, email="john@foo") user = examples.User(name="john", number=23, email="john@foo")
assert user.uid assert user._metadata.fields[user._metadata.primary_key].unique
assert user._metadata.fields["uid"].unique assert user._metadata.fields[user._metadata.primary_key].primary_key
# insert # insert
john_something = db.save(user) john_something = db.save(user)
@ -32,7 +32,6 @@ def test_crud(db):
assert john_something.name == user.name assert john_something.name == user.name
assert john_something.number == 23 assert john_something.number == 23
assert john_something.email == user.email assert john_something.email == user.email
assert john_something.uid == user.uid
# update # update
john_something.name = "james?" john_something.name = "james?"
@ -57,9 +56,9 @@ def test_pointers(db):
players = db.save(examples.Group(name="players", members=[user])) players = db.save(examples.Group(name="players", members=[user]))
user = db.table("User").get(doc_id=user.doc_id) user = db.table("User").get(doc_id=user.doc_id)
assert user.groups.uid == players.uid assert user.groups.name == players.name
assert players.members[0].groups.uid == players.uid assert players.members[0].groups.name == players.name
def test_subgroups(db): def test_subgroups(db):
@ -76,8 +75,8 @@ def test_subgroups(db):
assert snw in trek.groups assert snw in trek.groups
assert trek.parent is None assert trek.parent is None
assert tos.parent.uid == trek.uid assert tos.parent.name == trek.name
assert snw.parent.uid == trek.uid assert snw.parent.name == trek.name
unique_users = set([user for group in trek.groups for user in group.members]) unique_users = set([user for group in trek.groups for user in group.members])