Adding Dict, List, and RecordDict types

This commit is contained in:
evilchili 2025-10-07 20:57:50 -07:00
parent 44ee664a77
commit ef9a3053d1
3 changed files with 123 additions and 11 deletions

View File

@ -1,4 +1,16 @@
from grung.types import BackReference, Collection, DateTime, Field, Integer, Password, Record, Timestamp
from grung.types import (
BackReference,
Collection,
DateTime,
Dict,
Field,
Integer,
List,
Password,
Record,
RecordDict,
Timestamp,
)
class User(Record):
@ -26,3 +38,25 @@ class Group(Record):
Collection("groups", Group),
BackReference("parent", Group),
]
class Album(Record):
@classmethod
def fields(cls):
return [
*super().fields(),
Field("name"),
Dict("credits"),
List("tracks"),
BackReference("artist", Artist),
]
class Artist(User):
@classmethod
def fields(cls):
return [
*super().fields(),
Field("name"),
RecordDict("albums", Album),
]

View File

@ -4,10 +4,10 @@ import hashlib
import hmac
import os
import re
import typing
from collections import namedtuple
from dataclasses import dataclass, field
from datetime import datetime
from typing import Dict, List
import nanoid
from tinydb import TinyDB, where
@ -51,6 +51,30 @@ class Integer(Field):
return int(value)
@dataclass
class Dict(Field):
value_type: type = str
default: dict = field(default_factory=lambda: {})
def serialize(self, values: dict) -> Dict[(str, str)]:
return dict((key, str(value)) for key, value in values.items())
def deserialize(self, values: dict, db: TinyDB, recurse: bool = False) -> Dict[(str, str)]:
return values
@dataclass
class List(Field):
value_type: type = list
default: list = field(default_factory=lambda: [])
def serialize(self, values: list) -> Dict[(str, str)]:
return values
def deserialize(self, values: lsit, db: TinyDB, recurse: bool = False) -> typing.List[str]:
return values
@dataclass
class DateTime(Field):
value_type: datetime
@ -92,7 +116,7 @@ class Password(Field):
try:
if passwd[offset] != ":":
return False
digest = passwd[(offset + 1) :]
digest = passwd[(offset + 1) :] # noqa
if len(digest) != cls.digest_size * 2:
return False
return re.match(r"^[0-9a-f]+$", digest)
@ -118,7 +142,7 @@ class Password(Field):
record[self.name] = f"{salt}:{digest}"
class Record(Dict[(str, Field)]):
class Record(typing.Dict[(str, Field)]):
"""
Base type for a single database record.
"""
@ -148,7 +172,7 @@ class Record(Dict[(str, Field)]):
"""
rec = {}
for name, _field in self._metadata.fields.items():
rec[name] = _field.serialize(self[name])
rec[name] = _field.serialize(self[name]) if isinstance(_field, Field) else _field
return self.__class__(rec, doc_id=self.doc_id)
def deserialize(self, db, recurse: bool = True):
@ -228,10 +252,10 @@ class Pointer(Field):
elif type(value) == str:
pt, puid = value.split("::")
if puid:
try:
return db.table(pt).search(where("uid") == puid, recurse=recurse)[0]
except IndexError:
rec = db.table(pt).get(where("uid") == puid, recurse=recurse)
if not rec:
raise PointerReferenceError(f"Expected a {pt} with uid=={puid} but did not find one!")
return rec
return value
@ -247,12 +271,12 @@ class Collection(Field):
"""
value_type: type = Record
default: List[value_type] = field(default_factory=lambda: [])
default: typing.List[value_type] = field(default_factory=lambda: [])
def serialize(self, values: List[value_type]) -> List[str]:
def serialize(self, values: typing.List[value_type]) -> typing.List[str]:
return [Pointer.reference(val) for val in values]
def deserialize(self, values: List[str], db: TinyDB, recurse: bool = False) -> List[value_type]:
def deserialize(self, values: typing.List[str], db: TinyDB, recurse: bool = False) -> typing.List[value_type]:
"""
Recursively deserialize the objects in this collection
"""
@ -275,3 +299,32 @@ class Collection(Field):
for backref in target._metadata.backrefs(type(record)):
target[backref.name] = record
db.save(target)
@dataclass
class RecordDict(Field):
value_type: type = Record
default: typing.Dict[(str, Record)] = field(default_factory=lambda: {})
def serialize(self, values: typing.Dict[(str, value_type)]) -> typing.Dict[(str, str)]:
return dict((key, Pointer.reference(val)) for (key, val) in values.items())
def deserialize(
self, values: typing.Dict[(str, str)], db: TinyDB, recurse: bool = False
) -> typing.Dict[(str, value_type)]:
if not recurse:
return values
return dict((key, Pointer.dereference(val, db=db, recurse=False)) for (key, val) in values.items())
def after_insert(self, db: TinyDB, record: Record) -> None:
"""
Populate any backreferences in the members of this mapping with the parent record's uid.
"""
if not record[self.name]:
return
for key, pointer in record[self.name].items():
target = Pointer.dereference(pointer, db=db, recurse=False)
for backref in target._metadata.backrefs(type(record)):
target[backref.name] = record
db.save(target)

View File

@ -146,3 +146,28 @@ def test_datetime(db):
sleep(1)
user = db.save(user)
assert user.last_updated >= user.created
def test_mapping(db):
album = db.save(
examples.Album(
name="The Impossible Kid",
credits={"Produced By": "Aesop Rock", "Lyrics By": "Aesop Rock", "Puke in the MeowMix By": "Kirby"},
tracks=["Mystery Fish", "Rings", "Lotta Years", "Dorks"],
)
)
assert album.credits["Produced By"] == "Aesop Rock"
assert album.tracks[0] == "Mystery Fish"
aes = db.save(
examples.Artist(
name="Aesop Rock",
albums={"The Impossible Kid": album},
)
)
album = db.Album.get(doc_id=album.doc_id)
assert album.artist.uid == aes.uid
assert album.name in aes.albums
assert aes.albums[album.name].uid == album.uid
assert "Kirby" in aes.albums[album.name].credits.values()