Adding Dict, List, and RecordDict types
This commit is contained in:
parent
44ee664a77
commit
ef9a3053d1
|
@ -1,4 +1,16 @@
|
|||
from grung.types import BackReference, Collection, DateTime, Field, Integer, Password, Record, Timestamp
|
||||
from grung.types import (
|
||||
BackReference,
|
||||
Collection,
|
||||
DateTime,
|
||||
Dict,
|
||||
Field,
|
||||
Integer,
|
||||
List,
|
||||
Password,
|
||||
Record,
|
||||
RecordDict,
|
||||
Timestamp,
|
||||
)
|
||||
|
||||
|
||||
class User(Record):
|
||||
|
@ -26,3 +38,25 @@ class Group(Record):
|
|||
Collection("groups", Group),
|
||||
BackReference("parent", Group),
|
||||
]
|
||||
|
||||
|
||||
class Album(Record):
|
||||
@classmethod
|
||||
def fields(cls):
|
||||
return [
|
||||
*super().fields(),
|
||||
Field("name"),
|
||||
Dict("credits"),
|
||||
List("tracks"),
|
||||
BackReference("artist", Artist),
|
||||
]
|
||||
|
||||
|
||||
class Artist(User):
|
||||
@classmethod
|
||||
def fields(cls):
|
||||
return [
|
||||
*super().fields(),
|
||||
Field("name"),
|
||||
RecordDict("albums", Album),
|
||||
]
|
||||
|
|
|
@ -4,10 +4,10 @@ import hashlib
|
|||
import hmac
|
||||
import os
|
||||
import re
|
||||
import typing
|
||||
from collections import namedtuple
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
from typing import Dict, List
|
||||
|
||||
import nanoid
|
||||
from tinydb import TinyDB, where
|
||||
|
@ -51,6 +51,30 @@ class Integer(Field):
|
|||
return int(value)
|
||||
|
||||
|
||||
@dataclass
|
||||
class Dict(Field):
|
||||
value_type: type = str
|
||||
default: dict = field(default_factory=lambda: {})
|
||||
|
||||
def serialize(self, values: dict) -> Dict[(str, str)]:
|
||||
return dict((key, str(value)) for key, value in values.items())
|
||||
|
||||
def deserialize(self, values: dict, db: TinyDB, recurse: bool = False) -> Dict[(str, str)]:
|
||||
return values
|
||||
|
||||
|
||||
@dataclass
|
||||
class List(Field):
|
||||
value_type: type = list
|
||||
default: list = field(default_factory=lambda: [])
|
||||
|
||||
def serialize(self, values: list) -> Dict[(str, str)]:
|
||||
return values
|
||||
|
||||
def deserialize(self, values: lsit, db: TinyDB, recurse: bool = False) -> typing.List[str]:
|
||||
return values
|
||||
|
||||
|
||||
@dataclass
|
||||
class DateTime(Field):
|
||||
value_type: datetime
|
||||
|
@ -92,7 +116,7 @@ class Password(Field):
|
|||
try:
|
||||
if passwd[offset] != ":":
|
||||
return False
|
||||
digest = passwd[(offset + 1) :]
|
||||
digest = passwd[(offset + 1) :] # noqa
|
||||
if len(digest) != cls.digest_size * 2:
|
||||
return False
|
||||
return re.match(r"^[0-9a-f]+$", digest)
|
||||
|
@ -118,7 +142,7 @@ class Password(Field):
|
|||
record[self.name] = f"{salt}:{digest}"
|
||||
|
||||
|
||||
class Record(Dict[(str, Field)]):
|
||||
class Record(typing.Dict[(str, Field)]):
|
||||
"""
|
||||
Base type for a single database record.
|
||||
"""
|
||||
|
@ -148,7 +172,7 @@ class Record(Dict[(str, Field)]):
|
|||
"""
|
||||
rec = {}
|
||||
for name, _field in self._metadata.fields.items():
|
||||
rec[name] = _field.serialize(self[name])
|
||||
rec[name] = _field.serialize(self[name]) if isinstance(_field, Field) else _field
|
||||
return self.__class__(rec, doc_id=self.doc_id)
|
||||
|
||||
def deserialize(self, db, recurse: bool = True):
|
||||
|
@ -228,10 +252,10 @@ class Pointer(Field):
|
|||
elif type(value) == str:
|
||||
pt, puid = value.split("::")
|
||||
if puid:
|
||||
try:
|
||||
return db.table(pt).search(where("uid") == puid, recurse=recurse)[0]
|
||||
except IndexError:
|
||||
rec = db.table(pt).get(where("uid") == puid, recurse=recurse)
|
||||
if not rec:
|
||||
raise PointerReferenceError(f"Expected a {pt} with uid=={puid} but did not find one!")
|
||||
return rec
|
||||
return value
|
||||
|
||||
|
||||
|
@ -247,12 +271,12 @@ class Collection(Field):
|
|||
"""
|
||||
|
||||
value_type: type = Record
|
||||
default: List[value_type] = field(default_factory=lambda: [])
|
||||
default: typing.List[value_type] = field(default_factory=lambda: [])
|
||||
|
||||
def serialize(self, values: List[value_type]) -> List[str]:
|
||||
def serialize(self, values: typing.List[value_type]) -> typing.List[str]:
|
||||
return [Pointer.reference(val) for val in values]
|
||||
|
||||
def deserialize(self, values: List[str], db: TinyDB, recurse: bool = False) -> List[value_type]:
|
||||
def deserialize(self, values: typing.List[str], db: TinyDB, recurse: bool = False) -> typing.List[value_type]:
|
||||
"""
|
||||
Recursively deserialize the objects in this collection
|
||||
"""
|
||||
|
@ -275,3 +299,32 @@ class Collection(Field):
|
|||
for backref in target._metadata.backrefs(type(record)):
|
||||
target[backref.name] = record
|
||||
db.save(target)
|
||||
|
||||
|
||||
@dataclass
|
||||
class RecordDict(Field):
|
||||
value_type: type = Record
|
||||
default: typing.Dict[(str, Record)] = field(default_factory=lambda: {})
|
||||
|
||||
def serialize(self, values: typing.Dict[(str, value_type)]) -> typing.Dict[(str, str)]:
|
||||
return dict((key, Pointer.reference(val)) for (key, val) in values.items())
|
||||
|
||||
def deserialize(
|
||||
self, values: typing.Dict[(str, str)], db: TinyDB, recurse: bool = False
|
||||
) -> typing.Dict[(str, value_type)]:
|
||||
if not recurse:
|
||||
return values
|
||||
return dict((key, Pointer.dereference(val, db=db, recurse=False)) for (key, val) in values.items())
|
||||
|
||||
def after_insert(self, db: TinyDB, record: Record) -> None:
|
||||
"""
|
||||
Populate any backreferences in the members of this mapping with the parent record's uid.
|
||||
"""
|
||||
if not record[self.name]:
|
||||
return
|
||||
|
||||
for key, pointer in record[self.name].items():
|
||||
target = Pointer.dereference(pointer, db=db, recurse=False)
|
||||
for backref in target._metadata.backrefs(type(record)):
|
||||
target[backref.name] = record
|
||||
db.save(target)
|
||||
|
|
|
@ -146,3 +146,28 @@ def test_datetime(db):
|
|||
sleep(1)
|
||||
user = db.save(user)
|
||||
assert user.last_updated >= user.created
|
||||
|
||||
|
||||
def test_mapping(db):
|
||||
album = db.save(
|
||||
examples.Album(
|
||||
name="The Impossible Kid",
|
||||
credits={"Produced By": "Aesop Rock", "Lyrics By": "Aesop Rock", "Puke in the MeowMix By": "Kirby"},
|
||||
tracks=["Mystery Fish", "Rings", "Lotta Years", "Dorks"],
|
||||
)
|
||||
)
|
||||
assert album.credits["Produced By"] == "Aesop Rock"
|
||||
assert album.tracks[0] == "Mystery Fish"
|
||||
|
||||
aes = db.save(
|
||||
examples.Artist(
|
||||
name="Aesop Rock",
|
||||
albums={"The Impossible Kid": album},
|
||||
)
|
||||
)
|
||||
|
||||
album = db.Album.get(doc_id=album.doc_id)
|
||||
assert album.artist.uid == aes.uid
|
||||
assert album.name in aes.albums
|
||||
assert aes.albums[album.name].uid == album.uid
|
||||
assert "Kirby" in aes.albums[album.name].credits.values()
|
||||
|
|
Loading…
Reference in New Issue
Block a user