Adding Dict, List, and RecordDict types

This commit is contained in:
evilchili 2025-10-07 20:57:50 -07:00
parent 44ee664a77
commit ef9a3053d1
3 changed files with 123 additions and 11 deletions

View File

@ -1,4 +1,16 @@
from grung.types import BackReference, Collection, DateTime, Field, Integer, Password, Record, Timestamp from grung.types import (
BackReference,
Collection,
DateTime,
Dict,
Field,
Integer,
List,
Password,
Record,
RecordDict,
Timestamp,
)
class User(Record): class User(Record):
@ -26,3 +38,25 @@ class Group(Record):
Collection("groups", Group), Collection("groups", Group),
BackReference("parent", Group), BackReference("parent", Group),
] ]
class Album(Record):
@classmethod
def fields(cls):
return [
*super().fields(),
Field("name"),
Dict("credits"),
List("tracks"),
BackReference("artist", Artist),
]
class Artist(User):
@classmethod
def fields(cls):
return [
*super().fields(),
Field("name"),
RecordDict("albums", Album),
]

View File

@ -4,10 +4,10 @@ import hashlib
import hmac import hmac
import os import os
import re import re
import typing
from collections import namedtuple from collections import namedtuple
from dataclasses import dataclass, field from dataclasses import dataclass, field
from datetime import datetime from datetime import datetime
from typing import Dict, List
import nanoid import nanoid
from tinydb import TinyDB, where from tinydb import TinyDB, where
@ -51,6 +51,30 @@ class Integer(Field):
return int(value) return int(value)
@dataclass
class Dict(Field):
value_type: type = str
default: dict = field(default_factory=lambda: {})
def serialize(self, values: dict) -> Dict[(str, str)]:
return dict((key, str(value)) for key, value in values.items())
def deserialize(self, values: dict, db: TinyDB, recurse: bool = False) -> Dict[(str, str)]:
return values
@dataclass
class List(Field):
value_type: type = list
default: list = field(default_factory=lambda: [])
def serialize(self, values: list) -> Dict[(str, str)]:
return values
def deserialize(self, values: lsit, db: TinyDB, recurse: bool = False) -> typing.List[str]:
return values
@dataclass @dataclass
class DateTime(Field): class DateTime(Field):
value_type: datetime value_type: datetime
@ -92,7 +116,7 @@ class Password(Field):
try: try:
if passwd[offset] != ":": if passwd[offset] != ":":
return False return False
digest = passwd[(offset + 1) :] digest = passwd[(offset + 1) :] # noqa
if len(digest) != cls.digest_size * 2: if len(digest) != cls.digest_size * 2:
return False return False
return re.match(r"^[0-9a-f]+$", digest) return re.match(r"^[0-9a-f]+$", digest)
@ -118,7 +142,7 @@ class Password(Field):
record[self.name] = f"{salt}:{digest}" record[self.name] = f"{salt}:{digest}"
class Record(Dict[(str, Field)]): class Record(typing.Dict[(str, Field)]):
""" """
Base type for a single database record. Base type for a single database record.
""" """
@ -148,7 +172,7 @@ class Record(Dict[(str, Field)]):
""" """
rec = {} rec = {}
for name, _field in self._metadata.fields.items(): for name, _field in self._metadata.fields.items():
rec[name] = _field.serialize(self[name]) rec[name] = _field.serialize(self[name]) if isinstance(_field, Field) else _field
return self.__class__(rec, doc_id=self.doc_id) return self.__class__(rec, doc_id=self.doc_id)
def deserialize(self, db, recurse: bool = True): def deserialize(self, db, recurse: bool = True):
@ -228,10 +252,10 @@ class Pointer(Field):
elif type(value) == str: elif type(value) == str:
pt, puid = value.split("::") pt, puid = value.split("::")
if puid: if puid:
try: rec = db.table(pt).get(where("uid") == puid, recurse=recurse)
return db.table(pt).search(where("uid") == puid, recurse=recurse)[0] if not rec:
except IndexError:
raise PointerReferenceError(f"Expected a {pt} with uid=={puid} but did not find one!") raise PointerReferenceError(f"Expected a {pt} with uid=={puid} but did not find one!")
return rec
return value return value
@ -247,12 +271,12 @@ class Collection(Field):
""" """
value_type: type = Record value_type: type = Record
default: List[value_type] = field(default_factory=lambda: []) default: typing.List[value_type] = field(default_factory=lambda: [])
def serialize(self, values: List[value_type]) -> List[str]: def serialize(self, values: typing.List[value_type]) -> typing.List[str]:
return [Pointer.reference(val) for val in values] return [Pointer.reference(val) for val in values]
def deserialize(self, values: List[str], db: TinyDB, recurse: bool = False) -> List[value_type]: def deserialize(self, values: typing.List[str], db: TinyDB, recurse: bool = False) -> typing.List[value_type]:
""" """
Recursively deserialize the objects in this collection Recursively deserialize the objects in this collection
""" """
@ -275,3 +299,32 @@ class Collection(Field):
for backref in target._metadata.backrefs(type(record)): for backref in target._metadata.backrefs(type(record)):
target[backref.name] = record target[backref.name] = record
db.save(target) db.save(target)
@dataclass
class RecordDict(Field):
value_type: type = Record
default: typing.Dict[(str, Record)] = field(default_factory=lambda: {})
def serialize(self, values: typing.Dict[(str, value_type)]) -> typing.Dict[(str, str)]:
return dict((key, Pointer.reference(val)) for (key, val) in values.items())
def deserialize(
self, values: typing.Dict[(str, str)], db: TinyDB, recurse: bool = False
) -> typing.Dict[(str, value_type)]:
if not recurse:
return values
return dict((key, Pointer.dereference(val, db=db, recurse=False)) for (key, val) in values.items())
def after_insert(self, db: TinyDB, record: Record) -> None:
"""
Populate any backreferences in the members of this mapping with the parent record's uid.
"""
if not record[self.name]:
return
for key, pointer in record[self.name].items():
target = Pointer.dereference(pointer, db=db, recurse=False)
for backref in target._metadata.backrefs(type(record)):
target[backref.name] = record
db.save(target)

View File

@ -146,3 +146,28 @@ def test_datetime(db):
sleep(1) sleep(1)
user = db.save(user) user = db.save(user)
assert user.last_updated >= user.created assert user.last_updated >= user.created
def test_mapping(db):
album = db.save(
examples.Album(
name="The Impossible Kid",
credits={"Produced By": "Aesop Rock", "Lyrics By": "Aesop Rock", "Puke in the MeowMix By": "Kirby"},
tracks=["Mystery Fish", "Rings", "Lotta Years", "Dorks"],
)
)
assert album.credits["Produced By"] == "Aesop Rock"
assert album.tracks[0] == "Mystery Fish"
aes = db.save(
examples.Artist(
name="Aesop Rock",
albums={"The Impossible Kid": album},
)
)
album = db.Album.get(doc_id=album.doc_id)
assert album.artist.uid == aes.uid
assert album.name in aes.albums
assert aes.albums[album.name].uid == album.uid
assert "Kirby" in aes.albums[album.name].credits.values()