Make objects iterable by default, add tests, refactoring

This commit is contained in:
evilchili 2024-04-20 20:35:07 -07:00
parent 44cd8fe9c9
commit 412efe2aec
10 changed files with 104 additions and 75 deletions

View File

@ -36,11 +36,14 @@ HOST={default_host}
PORT={default_port} PORT={default_port}
""" """
db_app = typer.Typer()
app = typer.Typer() app = typer.Typer()
app.add_typer(db_app, name="db", help="Manage the database.")
app_state = dict() app_state = dict()
@app.callback() @app.callback()
@db_app.callback()
def main( def main(
context: typer.Context, context: typer.Context,
root: Optional[Path] = typer.Option( root: Optional[Path] = typer.Option(
@ -59,20 +62,6 @@ def main(
) )
@app.command()
def setup(context: typer.Context):
"""
(Re)Initialize TableTop Frog. Idempotent; will preserve any existing configuration.
"""
from ttfrog.db.bootstrap import bootstrap
if not os.path.exists(app_state["env"]):
app_state["env"].parent.mkdir(parents=True, exist_ok=True)
app_state["env"].write_text(dedent(SETUP_HELP))
print(f"Wrote defaults file {app_state['env']}.")
bootstrap()
@app.command() @app.command()
def serve( def serve(
context: typer.Context, context: typer.Context,
@ -99,5 +88,35 @@ def serve(
application.start(host=host, port=port, debug=debug) application.start(host=host, port=port, debug=debug)
@db_app.command()
def setup(context: typer.Context):
"""
(Re)Initialize TableTop Frog. Idempotent; will preserve any existing configuration.
"""
from ttfrog.db.bootstrap import bootstrap
if not os.path.exists(app_state["env"]):
app_state["env"].parent.mkdir(parents=True, exist_ok=True)
app_state["env"].write_text(dedent(SETUP_HELP))
print(f"Wrote defaults file {app_state['env']}.")
bootstrap()
@db_app.command()
def list(context: typer.Context):
from ttfrog.db.manager import db
print("\n".join(sorted(db.tables.keys())))
@db_app.command(context_settings={"allow_extra_args": True, "ignore_unknown_options": True})
def dump(context: typer.Context):
"""
Dump tables (or the entire database) as a JSON blob.
"""
from ttfrog.db.manager import db
db.init()
print(db.dump(context.args))
if __name__ == "__main__": if __name__ == "__main__":
app() app()

View File

@ -2,7 +2,7 @@ import enum
import nanoid import nanoid
from nanoid_dictionary import human_alphabet from nanoid_dictionary import human_alphabet
from pyramid_sqlalchemy import BaseObject from pyramid_sqlalchemy import BaseObject as _BaseObject
from slugify import slugify from slugify import slugify
from sqlalchemy import Column, String from sqlalchemy import Column, String
@ -19,10 +19,11 @@ class SlugMixin:
return "-".join([self.slug, slugify(self.name.title().replace(" ", ""), ok="", only_ascii=True, lower=False)]) return "-".join([self.slug, slugify(self.name.title().replace(" ", ""), ok="", only_ascii=True, lower=False)])
class IterableMixin: class BaseObject(_BaseObject):
""" """
Allows for iterating over Model objects' column names and values Allows for iterating over Model objects' column names and values
""" """
__abstract__ = True
def __iter__(self): def __iter__(self):
values = vars(self) values = vars(self)
@ -42,14 +43,11 @@ class IterableMixin:
relvals.append(rel) relvals.append(rel)
yield relname, relvals yield relname, relvals
def __json__(self, request): def __json__(self):
serialized = dict() """
for key, value in self: Provide a custom JSON encoder.
try: """
serialized[key] = getattr(self.value, "__json__")(request) raise NotImplementedError()
except AttributeError:
serialized[key] = value
return serialized
def __repr__(self): def __repr__(self):
return str(dict(self)) return str(dict(self))
@ -90,7 +88,7 @@ class EnumField(enum.Enum):
A serializable enum. A serializable enum.
""" """
def __json__(self, request): def __json__(self):
return self.value return self.value
@ -116,6 +114,3 @@ CREATURE_TYPES = [
] ]
CreatureTypesEnum = EnumField("CreatureTypesEnum", ((k, k) for k in CREATURE_TYPES)) CreatureTypesEnum = EnumField("CreatureTypesEnum", ((k, k) for k in CREATURE_TYPES))
StatsEnum = EnumField("StatsEnum", ((k, k) for k in STATS)) StatsEnum = EnumField("StatsEnum", ((k, k) for k in STATS))
# class Table(*Bases):
Bases = [BaseObject, IterableMixin, SlugMixin]

View File

@ -1,5 +1,6 @@
import base64 import base64
import hashlib import hashlib
import json
import os import os
from contextlib import contextmanager from contextlib import contextmanager
from functools import cached_property from functools import cached_property
@ -9,13 +10,19 @@ from pyramid_sqlalchemy import Session, init_sqlalchemy
from pyramid_sqlalchemy import metadata as _metadata from pyramid_sqlalchemy import metadata as _metadata
from sqlalchemy import create_engine from sqlalchemy import create_engine
import ttfrog.db.schema import ttfrog.db.schema
from ttfrog.path import database from ttfrog.path import database
# from sqlalchemy.exc import IntegrityError assert ttfrog.db.schema
ttfrog.db.schema class AlchemyEncoder(json.JSONEncoder):
def default(self, obj):
try:
return getattr(obj, '__json__')()
except (AttributeError, NotImplementedError): # pragma: no cover
return super().default(obj)
class SQLDatabaseManager: class SQLDatabaseManager:
@ -49,11 +56,11 @@ class SQLDatabaseManager:
yield tm yield tm
try: try:
tm.commit() tm.commit()
except Exception: except Exception: # pragam: no cover
tm.abort() tm.abort()
raise raise
def add(self, *args, **kwargs): def add_or_update(self, *args, **kwargs):
self.session.add(*args, **kwargs) self.session.add(*args, **kwargs)
self.session.flush() self.session.flush()
@ -71,11 +78,12 @@ class SQLDatabaseManager:
init_sqlalchemy(self.engine) init_sqlalchemy(self.engine)
self.metadata.create_all(self.engine) self.metadata.create_all(self.engine)
def dump(self): def dump(self, names: list = []):
results = {} results = {}
for table_name, table in self.tables.items(): for table_name, table in self.tables.items():
results[table_name] = [row for row in self.query(table).all()] if not names or table_name in names:
return results results[table_name] = [dict(row._mapping) for row in self.query(table).all()]
return json.dumps(results, indent=2, cls=AlchemyEncoder)
def __getattr__(self, name: str): def __getattr__(self, name: str):
try: try:

View File

@ -1,4 +1,4 @@
from .character import * from .character import *
from .classes import * from .classes import *
from .property import * from .property import *
from .transaction import * from .log import *

View File

@ -2,7 +2,7 @@ from sqlalchemy import Column, Enum, ForeignKey, Integer, String, Text, UniqueCo
from sqlalchemy.ext.associationproxy import association_proxy from sqlalchemy.ext.associationproxy import association_proxy
from sqlalchemy.orm import relationship from sqlalchemy.orm import relationship
from ttfrog.db.base import BaseObject, Bases, CreatureTypesEnum, IterableMixin, SavingThrowsMixin, SkillsMixin from ttfrog.db.base import BaseObject, CreatureTypesEnum, SavingThrowsMixin, SkillsMixin, SlugMixin
__all__ = [ __all__ = [
"Ancestry", "Ancestry",
@ -28,6 +28,7 @@ def attr_map_creator(fields):
class AncestryTraitMap(BaseObject): class AncestryTraitMap(BaseObject):
__tablename__ = "trait_map" __tablename__ = "trait_map"
__table_args__ = (UniqueConstraint("ancestry_id", "ancestry_trait_id"), )
id = Column(Integer, primary_key=True, autoincrement=True) id = Column(Integer, primary_key=True, autoincrement=True)
ancestry_id = Column(Integer, ForeignKey("ancestry.id")) ancestry_id = Column(Integer, ForeignKey("ancestry.id"))
ancestry_trait_id = Column(Integer, ForeignKey("ancestry_trait.id")) ancestry_trait_id = Column(Integer, ForeignKey("ancestry_trait.id"))
@ -35,7 +36,7 @@ class AncestryTraitMap(BaseObject):
level = Column(Integer, nullable=False, info={"min": 1, "max": 20}) level = Column(Integer, nullable=False, info={"min": 1, "max": 20})
class Ancestry(*Bases): class Ancestry(BaseObject):
""" """
A character ancestry ("race"), which has zero or more AncestryTraits. A character ancestry ("race"), which has zero or more AncestryTraits.
""" """
@ -44,13 +45,13 @@ class Ancestry(*Bases):
id = Column(Integer, primary_key=True, autoincrement=True) id = Column(Integer, primary_key=True, autoincrement=True)
name = Column(String, index=True, unique=True) name = Column(String, index=True, unique=True)
creature_type = Column(Enum(CreatureTypesEnum)) creature_type = Column(Enum(CreatureTypesEnum))
traits = relationship("AncestryTraitMap", lazy="immediate") _traits = relationship("AncestryTraitMap", lazy="immediate")
def __repr__(self): def __repr__(self):
return self.name return self.name
class AncestryTrait(BaseObject, IterableMixin): class AncestryTrait(BaseObject):
""" """
A trait granted to a character via its Ancestry. A trait granted to a character via its Ancestry.
""" """
@ -64,12 +65,12 @@ class AncestryTrait(BaseObject, IterableMixin):
return self.name return self.name
class CharacterClassMap(BaseObject, IterableMixin): class CharacterClassMap(BaseObject):
__tablename__ = "class_map" __tablename__ = "class_map"
__table_args__ = (UniqueConstraint("character_id", "character_class_id"), )
id = Column(Integer, primary_key=True, autoincrement=True) id = Column(Integer, primary_key=True, autoincrement=True)
character_id = Column(Integer, ForeignKey("character.id"), nullable=False) character_id = Column(Integer, ForeignKey("character.id"), nullable=False)
character_class_id = Column(Integer, ForeignKey("character_class.id"), nullable=False) character_class_id = Column(Integer, ForeignKey("character_class.id"), nullable=False)
mapping = UniqueConstraint(character_id, character_class_id)
level = Column(Integer, nullable=False, info={"min": 1, "max": 20}, default=1) level = Column(Integer, nullable=False, info={"min": 1, "max": 20}, default=1)
character_class = relationship("CharacterClass", lazy="immediate") character_class = relationship("CharacterClass", lazy="immediate")
@ -79,13 +80,13 @@ class CharacterClassMap(BaseObject, IterableMixin):
return "{self.character.name}, {self.character_class.name}, level {self.level}" return "{self.character.name}, {self.character_class.name}, level {self.level}"
class CharacterClassAttributeMap(BaseObject, IterableMixin): class CharacterClassAttributeMap(BaseObject):
__tablename__ = "character_class_attribute_map" __tablename__ = "character_class_attribute_map"
__table_args__ = (UniqueConstraint("character_id", "class_attribute_id"), )
id = Column(Integer, primary_key=True, autoincrement=True) id = Column(Integer, primary_key=True, autoincrement=True)
character_id = Column(Integer, ForeignKey("character.id"), nullable=False) character_id = Column(Integer, ForeignKey("character.id"), nullable=False)
class_attribute_id = Column(Integer, ForeignKey("class_attribute.id"), nullable=False) class_attribute_id = Column(Integer, ForeignKey("class_attribute.id"), nullable=False)
option_id = Column(Integer, ForeignKey("class_attribute_option.id"), nullable=False) option_id = Column(Integer, ForeignKey("class_attribute_option.id"), nullable=False)
mapping = UniqueConstraint(character_id, class_attribute_id)
class_attribute = relationship("ClassAttribute", lazy="immediate") class_attribute = relationship("ClassAttribute", lazy="immediate")
option = relationship("ClassAttributeOption", lazy="immediate") option = relationship("ClassAttributeOption", lazy="immediate")
@ -100,7 +101,7 @@ class CharacterClassAttributeMap(BaseObject, IterableMixin):
) )
class Character(*Bases, SavingThrowsMixin, SkillsMixin): class Character(BaseObject, SlugMixin, SavingThrowsMixin, SkillsMixin):
__tablename__ = "character" __tablename__ = "character"
id = Column(Integer, primary_key=True, autoincrement=True) id = Column(Integer, primary_key=True, autoincrement=True)
name = Column(String, default="New Character", nullable=False) name = Column(String, default="New Character", nullable=False)
@ -132,7 +133,7 @@ class Character(*Bases, SavingThrowsMixin, SkillsMixin):
@property @property
def traits(self): def traits(self):
return [mapping.trait for mapping in self.ancestry.traits] return [mapping.trait for mapping in self.ancestry._traits]
@property @property
def level(self): def level(self):
@ -172,8 +173,11 @@ class Character(*Bases, SavingThrowsMixin, SkillsMixin):
def add_class_attribute(self, attribute, option): def add_class_attribute(self, attribute, option):
for thisclass in self.classes.values(): for thisclass in self.classes.values():
# this test is failing? current_level = self.levels[thisclass.name]
if attribute.name in thisclass.attributes_by_level.get(self.levels[thisclass.name], {}): current_attributes = thisclass.attributes_by_level.get(current_level, {})
if attribute.name in current_attributes:
if attribute.name in self.class_attributes:
return True
self.attribute_list.append( self.attribute_list.append(
CharacterClassAttributeMap( CharacterClassAttributeMap(
character_id=self.id, class_attribute_id=attribute.id, option_id=option.id character_id=self.id, class_attribute_id=attribute.id, option_id=option.id

View File

@ -3,7 +3,7 @@ from collections import defaultdict
from sqlalchemy import Column, Enum, ForeignKey, Integer, String from sqlalchemy import Column, Enum, ForeignKey, Integer, String
from sqlalchemy.orm import relationship from sqlalchemy.orm import relationship
from ttfrog.db.base import BaseObject, Bases, IterableMixin, SavingThrowsMixin, SkillsMixin, StatsEnum from ttfrog.db.base import BaseObject, SavingThrowsMixin, SkillsMixin, StatsEnum
__all__ = [ __all__ = [
"ClassAttributeMap", "ClassAttributeMap",
@ -13,7 +13,7 @@ __all__ = [
] ]
class ClassAttributeMap(BaseObject, IterableMixin): class ClassAttributeMap(BaseObject):
__tablename__ = "class_attribute_map" __tablename__ = "class_attribute_map"
class_attribute_id = Column(Integer, ForeignKey("class_attribute.id"), primary_key=True) class_attribute_id = Column(Integer, ForeignKey("class_attribute.id"), primary_key=True)
character_class_id = Column(Integer, ForeignKey("character_class.id"), primary_key=True) character_class_id = Column(Integer, ForeignKey("character_class.id"), primary_key=True)
@ -21,7 +21,7 @@ class ClassAttributeMap(BaseObject, IterableMixin):
attribute = relationship("ClassAttribute", uselist=False, viewonly=True, lazy="immediate") attribute = relationship("ClassAttribute", uselist=False, viewonly=True, lazy="immediate")
class ClassAttribute(BaseObject, IterableMixin): class ClassAttribute(BaseObject):
__tablename__ = "class_attribute" __tablename__ = "class_attribute"
id = Column(Integer, primary_key=True, autoincrement=True) id = Column(Integer, primary_key=True, autoincrement=True)
name = Column(String, nullable=False) name = Column(String, nullable=False)
@ -31,14 +31,14 @@ class ClassAttribute(BaseObject, IterableMixin):
return f"{self.id}: {self.name}" return f"{self.id}: {self.name}"
class ClassAttributeOption(BaseObject, IterableMixin): class ClassAttributeOption(BaseObject):
__tablename__ = "class_attribute_option" __tablename__ = "class_attribute_option"
id = Column(Integer, primary_key=True, autoincrement=True) id = Column(Integer, primary_key=True, autoincrement=True)
name = Column(String, nullable=False) name = Column(String, nullable=False)
attribute_id = Column(Integer, ForeignKey("class_attribute.id"), nullable=False) attribute_id = Column(Integer, ForeignKey("class_attribute.id"), nullable=False)
class CharacterClass(*Bases, SavingThrowsMixin, SkillsMixin): class CharacterClass(BaseObject, SavingThrowsMixin, SkillsMixin):
__tablename__ = "character_class" __tablename__ = "character_class"
id = Column(Integer, primary_key=True, autoincrement=True) id = Column(Integer, primary_key=True, autoincrement=True)
name = Column(String, index=True, unique=True) name = Column(String, index=True, unique=True)

View File

@ -1,11 +1,11 @@
from sqlalchemy import Column, Integer, String, Text from sqlalchemy import Column, Integer, String, Text
from ttfrog.db.base import BaseObject, IterableMixin from ttfrog.db.base import BaseObject
__all__ = ["TransactionLog"] __all__ = ["TransactionLog"]
class TransactionLog(BaseObject, IterableMixin): class TransactionLog(BaseObject):
__tablename__ = "transaction_log" __tablename__ = "transaction_log"
id = Column(Integer, primary_key=True, autoincrement=True) id = Column(Integer, primary_key=True, autoincrement=True)
source_table_name = Column(String, index=True, nullable=False) source_table_name = Column(String, index=True, nullable=False)

View File

@ -1,6 +1,6 @@
from sqlalchemy import Column, Integer, String, Text, UniqueConstraint from sqlalchemy import Column, Integer, String, Text, UniqueConstraint
from ttfrog.db.base import BaseObject, Bases, IterableMixin from ttfrog.db.base import BaseObject
__all__ = [ __all__ = [
"Skill", "Skill",
@ -9,7 +9,7 @@ __all__ = [
] ]
class Skill(*Bases): class Skill(BaseObject):
__tablename__ = "skill" __tablename__ = "skill"
id = Column(Integer, primary_key=True, autoincrement=True) id = Column(Integer, primary_key=True, autoincrement=True)
name = Column(String, index=True, unique=True) name = Column(String, index=True, unique=True)
@ -19,7 +19,7 @@ class Skill(*Bases):
return str(self.name) return str(self.name)
class Proficiency(*Bases): class Proficiency(BaseObject):
__tablename__ = "proficiency" __tablename__ = "proficiency"
id = Column(Integer, primary_key=True, autoincrement=True) id = Column(Integer, primary_key=True, autoincrement=True)
name = Column(String, index=True, unique=True) name = Column(String, index=True, unique=True)
@ -28,7 +28,7 @@ class Proficiency(*Bases):
return str(self.name) return str(self.name)
class Modifier(BaseObject, IterableMixin): class Modifier(BaseObject):
__tablename__ = "modifier" __tablename__ = "modifier"
__table_args__ = (UniqueConstraint("source_table_name", "source_table_id", "value", "type", "target"),) __table_args__ = (UniqueConstraint("source_table_name", "source_table_id", "value", "type", "target"),)
id = Column(Integer, primary_key=True, autoincrement=True) id = Column(Integer, primary_key=True, autoincrement=True)

View File

@ -13,7 +13,7 @@
"AncestryTraitMap": [ "AncestryTraitMap": [
{"ancestry_id": 1, "ancestry_trait_id": 1, "level": 1}, {"ancestry_id": 1, "ancestry_trait_id": 1, "level": 1},
{"ancestry_id": 2, "ancestry_trait_id": 2, "level": 1}, {"ancestry_id": 2, "ancestry_trait_id": 2, "level": 1},
{"ancestry_id": 2, "ancestry_trait_id": 2, "level": 1}, {"ancestry_id": 2, "ancestry_trait_id": 3, "level": 1},
{"ancestry_id": 3, "ancestry_trait_id": 3, "level": 1} {"ancestry_id": 3, "ancestry_trait_id": 3, "level": 1}
] ]
} }

View File

@ -1,7 +1,8 @@
import json
from ttfrog.db import schema from ttfrog.db import schema
def test_create_character(db, classes_factory, ancestries_factory): def test_manage_character(db, classes_factory, ancestries_factory):
with db.transaction(): with db.transaction():
# load the fixtures so they are bound to the current session # load the fixtures so they are bound to the current session
classes = classes_factory() classes = classes_factory()
@ -10,7 +11,7 @@ def test_create_character(db, classes_factory, ancestries_factory):
# create a human character (the default) # create a human character (the default)
char = schema.Character(name="Test Character") char = schema.Character(name="Test Character")
db.add(char) db.add_or_update(char)
assert char.id == 1 assert char.id == 1
assert char.armor_class == 10 assert char.armor_class == 10
assert char.name == "Test Character" assert char.name == "Test Character"
@ -19,14 +20,14 @@ def test_create_character(db, classes_factory, ancestries_factory):
# switch ancestry to tiefling # switch ancestry to tiefling
char.ancestry = ancestries["tiefling"] char.ancestry = ancestries["tiefling"]
db.add(char) db.add_or_update(char)
char = db.session.get(schema.Character, 1) char = db.session.get(schema.Character, 1)
assert char.ancestry.name == "tiefling" assert char.ancestry.name == "tiefling"
assert darkvision in char.traits assert darkvision in char.traits
# assign a class and level # assign a class and level
char.add_class(classes["fighter"], level=1) char.add_class(classes["fighter"], level=1)
db.add(char) db.add_or_update(char)
assert char.levels == {"fighter": 1} assert char.levels == {"fighter": 1}
assert char.level == 1 assert char.level == 1
assert char.class_attributes == {} assert char.class_attributes == {}
@ -34,37 +35,39 @@ def test_create_character(db, classes_factory, ancestries_factory):
# 'fighting style' is available, but not at this level # 'fighting style' is available, but not at this level
fighting_style = char.classes["fighter"].attributes_by_level[2]["Fighting Style"] fighting_style = char.classes["fighter"].attributes_by_level[2]["Fighting Style"]
assert char.add_class_attribute(fighting_style, fighting_style.options[0]) is False assert char.add_class_attribute(fighting_style, fighting_style.options[0]) is False
db.add(char) db.add_or_update(char)
assert char.class_attributes == {} assert char.class_attributes == {}
# level up # level up
char.add_class(classes["fighter"], level=2) char.add_class(classes["fighter"], level=2)
db.add(char) db.add_or_update(char)
assert char.levels == {"fighter": 2} assert char.levels == {"fighter": 2}
assert char.level == 2 assert char.level == 2
# Assign the fighting style # Assert the fighting style is added automatically and idempotent...ly?
assert char.add_class_attribute(fighting_style, fighting_style.options[0])
db.add(char)
assert char.class_attributes[fighting_style.name] == fighting_style.options[0] assert char.class_attributes[fighting_style.name] == fighting_style.options[0]
assert char.add_class_attribute(fighting_style, fighting_style.options[0]) is True
db.add_or_update(char)
# classes # classes
char.add_class(classes["rogue"], level=1) char.add_class(classes["rogue"], level=1)
db.add(char) db.add_or_update(char)
assert char.level == 3 assert char.level == 3
assert char.levels == {"fighter": 2, "rogue": 1} assert char.levels == {"fighter": 2, "rogue": 1}
# remove a class # remove a class
char.remove_class(classes["rogue"]) char.remove_class(classes["rogue"])
db.add(char) db.add_or_update(char)
assert char.levels == {"fighter": 2} assert char.levels == {"fighter": 2}
assert char.level == 2 assert char.level == 2
# remove all remaining classes # remove remaining class by setting level to zero
char.remove_class(classes["fighter"]) char.add_class(classes["fighter"], level=0)
db.add(char) db.add_or_update(char)
assert char.levels == {}
# ensure we're not persisting any orphan records in the map tables # ensure we're not persisting any orphan records in the map tables
dump = db.dump() dump = json.loads(db.dump())
assert dump["class_map"] == []
assert dump["class_map"] == [] assert dump["class_map"] == []
assert dump["character_class_attribute_map"] == [] assert dump["character_class_attribute_map"] == []