diff --git a/src/grung/examples.py b/src/grung/examples.py index 6b4a364..4e8af47 100644 --- a/src/grung/examples.py +++ b/src/grung/examples.py @@ -1,4 +1,16 @@ -from grung.types import BackReference, Collection, DateTime, Field, Integer, Password, Record, Timestamp +from grung.types import ( + BackReference, + Collection, + DateTime, + Dict, + Field, + Integer, + List, + Password, + Record, + RecordDict, + Timestamp, +) class User(Record): @@ -26,3 +38,25 @@ class Group(Record): Collection("groups", Group), BackReference("parent", Group), ] + + +class Album(Record): + @classmethod + def fields(cls): + return [ + *super().fields(), + Field("name"), + Dict("credits"), + List("tracks"), + BackReference("artist", Artist), + ] + + +class Artist(User): + @classmethod + def fields(cls): + return [ + *super().fields(), + Field("name"), + RecordDict("albums", Album), + ] diff --git a/src/grung/types.py b/src/grung/types.py index c6ea2a9..4ebdcad 100644 --- a/src/grung/types.py +++ b/src/grung/types.py @@ -4,10 +4,10 @@ import hashlib import hmac import os import re +import typing from collections import namedtuple from dataclasses import dataclass, field from datetime import datetime -from typing import Dict, List import nanoid from tinydb import TinyDB, where @@ -51,6 +51,30 @@ class Integer(Field): return int(value) +@dataclass +class Dict(Field): + value_type: type = str + default: dict = field(default_factory=lambda: {}) + + def serialize(self, values: dict) -> Dict[(str, str)]: + return dict((key, str(value)) for key, value in values.items()) + + def deserialize(self, values: dict, db: TinyDB, recurse: bool = False) -> Dict[(str, str)]: + return values + + +@dataclass +class List(Field): + value_type: type = list + default: list = field(default_factory=lambda: []) + + def serialize(self, values: list) -> Dict[(str, str)]: + return values + + def deserialize(self, values: lsit, db: TinyDB, recurse: bool = False) -> typing.List[str]: + return values + + @dataclass class DateTime(Field): value_type: datetime @@ -92,7 +116,7 @@ class Password(Field): try: if passwd[offset] != ":": return False - digest = passwd[(offset + 1) :] + digest = passwd[(offset + 1) :] # noqa if len(digest) != cls.digest_size * 2: return False return re.match(r"^[0-9a-f]+$", digest) @@ -118,7 +142,7 @@ class Password(Field): record[self.name] = f"{salt}:{digest}" -class Record(Dict[(str, Field)]): +class Record(typing.Dict[(str, Field)]): """ Base type for a single database record. """ @@ -148,7 +172,7 @@ class Record(Dict[(str, Field)]): """ rec = {} for name, _field in self._metadata.fields.items(): - rec[name] = _field.serialize(self[name]) + rec[name] = _field.serialize(self[name]) if isinstance(_field, Field) else _field return self.__class__(rec, doc_id=self.doc_id) def deserialize(self, db, recurse: bool = True): @@ -228,10 +252,10 @@ class Pointer(Field): elif type(value) == str: pt, puid = value.split("::") if puid: - try: - return db.table(pt).search(where("uid") == puid, recurse=recurse)[0] - except IndexError: + rec = db.table(pt).get(where("uid") == puid, recurse=recurse) + if not rec: raise PointerReferenceError(f"Expected a {pt} with uid=={puid} but did not find one!") + return rec return value @@ -247,12 +271,12 @@ class Collection(Field): """ value_type: type = Record - default: List[value_type] = field(default_factory=lambda: []) + default: typing.List[value_type] = field(default_factory=lambda: []) - def serialize(self, values: List[value_type]) -> List[str]: + def serialize(self, values: typing.List[value_type]) -> typing.List[str]: return [Pointer.reference(val) for val in values] - def deserialize(self, values: List[str], db: TinyDB, recurse: bool = False) -> List[value_type]: + def deserialize(self, values: typing.List[str], db: TinyDB, recurse: bool = False) -> typing.List[value_type]: """ Recursively deserialize the objects in this collection """ @@ -275,3 +299,32 @@ class Collection(Field): for backref in target._metadata.backrefs(type(record)): target[backref.name] = record db.save(target) + + +@dataclass +class RecordDict(Field): + value_type: type = Record + default: typing.Dict[(str, Record)] = field(default_factory=lambda: {}) + + def serialize(self, values: typing.Dict[(str, value_type)]) -> typing.Dict[(str, str)]: + return dict((key, Pointer.reference(val)) for (key, val) in values.items()) + + def deserialize( + self, values: typing.Dict[(str, str)], db: TinyDB, recurse: bool = False + ) -> typing.Dict[(str, value_type)]: + if not recurse: + return values + return dict((key, Pointer.dereference(val, db=db, recurse=False)) for (key, val) in values.items()) + + def after_insert(self, db: TinyDB, record: Record) -> None: + """ + Populate any backreferences in the members of this mapping with the parent record's uid. + """ + if not record[self.name]: + return + + for key, pointer in record[self.name].items(): + target = Pointer.dereference(pointer, db=db, recurse=False) + for backref in target._metadata.backrefs(type(record)): + target[backref.name] = record + db.save(target) diff --git a/test/test_db.py b/test/test_db.py index 68d549c..3b931c9 100644 --- a/test/test_db.py +++ b/test/test_db.py @@ -146,3 +146,28 @@ def test_datetime(db): sleep(1) user = db.save(user) assert user.last_updated >= user.created + + +def test_mapping(db): + album = db.save( + examples.Album( + name="The Impossible Kid", + credits={"Produced By": "Aesop Rock", "Lyrics By": "Aesop Rock", "Puke in the MeowMix By": "Kirby"}, + tracks=["Mystery Fish", "Rings", "Lotta Years", "Dorks"], + ) + ) + assert album.credits["Produced By"] == "Aesop Rock" + assert album.tracks[0] == "Mystery Fish" + + aes = db.save( + examples.Artist( + name="Aesop Rock", + albums={"The Impossible Kid": album}, + ) + ) + + album = db.Album.get(doc_id=album.doc_id) + assert album.artist.uid == aes.uid + assert album.name in aes.albums + assert aes.albums[album.name].uid == album.uid + assert "Kirby" in aes.albums[album.name].credits.values()