import base64
import hashlib
import json
import os
from contextlib import contextmanager
from functools import cached_property
from sqlite3 import IntegrityError
from typing import Optional

import transaction
from pyramid_sqlalchemy.meta import Session
from sqlalchemy import create_engine, event, insert

from ttfrog.db import schema
from ttfrog.path import database


class AlchemyEncoder(json.JSONEncoder):
    """
    JSON encoder that serializes objects through their ``__json__`` method.
    """

    def default(self, obj):
        """
        Return ``obj.__json__()``; if the object has no such method, or the
        method raises NotImplementedError, defer to the stock encoder.
        """
        try:
            return obj.__json__()
        except (AttributeError, NotImplementedError):  # pragma: no cover
            return super().default(obj)


class SQLDatabaseManager:
    """
    A context manager for working with a SQLite database.
    """

    @cached_property
    def url(self):
        """
        The database URL: the DATABASE_URL environment variable if set,
        otherwise a sqlite URL built from ``database()``.
        """
        return os.environ.get("DATABASE_URL", f"sqlite:///{database()}")

    @cached_property
    def engine(self):
        """
        The engine created for ``self.url``.
        """
        return create_engine(self.url)

    @cached_property
    def session(self):
        """
        The ``Session`` object imported from pyramid_sqlalchemy.
        """
        return Session

    @cached_property
    def metadata(self):
        """
        The metadata of the schema's ``BaseObject``.
        """
        return schema.BaseObject.metadata

    @cached_property
    def tables(self):
        """
        A mapping of table name to table, in ``metadata.sorted_tables`` order.
        """
        return {table.name: table for table in self.metadata.sorted_tables}

    @contextmanager
    def transaction(self):
        """
        Yield the object bound by ``with transaction.manager`` and commit it
        once the managed block finishes. If the commit raises, the transaction
        is aborted and the exception is re-raised.
        """
        with transaction.manager as tm:
            yield tm
            try:
                tm.commit()
            except Exception:  # pragma: no cover
                tm.abort()
                raise

    def add_or_update(self, record, *args, **kwargs):
        """
        Add one record, or a list of records, to the session and flush.

        Extra positional and keyword arguments are passed through to
        ``session.add`` for every record.
        """
        records = record if isinstance(record, list) else [record]
        for rec in records:
            self.session.add(rec, *args, **kwargs)
        self.session.flush()

    def query(self, *args, **kwargs):
        """
        Shorthand for ``self.session.query``.
        """
        return self.session.query(*args, **kwargs)

    def slugify(self, rec: dict) -> str:
        """
        Create a uniquish slug from a dictionary.

        The slug is the first 10 characters of the URL-safe base64 encoding
        of the SHA-1 digest of ``str(rec["id"])``.
        """
        sha1bytes = hashlib.sha1(str(rec["id"]).encode())
        return base64.urlsafe_b64encode(sha1bytes.digest()).decode("ascii")[:10]

    def init(self):
        """
        Bind the metadata and the session to the engine, discarding the
        current session first, then create all tables.
        """
        self.metadata.bind = self.engine
        self.session.remove()
        self.session.configure(bind=self.engine)
        self.metadata.create_all(self.engine)

    def dump(self, names: Optional[list] = None):
        """
        Return the contents of the database as a JSON string.

        The result maps each table name to a list of its rows as dicts.
        If ``names`` is given and non-empty, only the named tables are
        included; otherwise every table is dumped.
        """
        # `names` defaults to None rather than a shared mutable list; an
        # empty or missing value selects every table.
        results = {
            table_name: [dict(row._mapping) for row in self.query(table).all()]
            for table_name, table in self.tables.items()
            if not names or table_name in names
        }
        return json.dumps(results, indent=2, cls=AlchemyEncoder)

    def load(self, data: dict):
        """
        Insert rows into the database from a mapping of table name to rows.

        Tables with no rows are skipped. Raises IntegrityError when a table
        name is not present in ``self.tables``; rows for tables processed
        before the unknown one have already been executed at that point.
        """
        for table_name, rows in data.items():
            table = self.tables.get(table_name)
            if table is None:
                raise IntegrityError(f"Table {table_name} not found in database.")
            if not rows:
                continue
            self.session.execute(insert(table), rows)

    def __getattr__(self, name: str):
        """
        Fallback for attributes not otherwise defined on the manager: look
        ``name`` up on the schema module and return a query over it.
        """
        return self.query(getattr(schema, name))


db = SQLDatabaseManager()


@event.listens_for(db.session, "after_flush")
def session_after_flush(session, flush_context):
    """
    Listen to flush events looking for newly-created objects. For each one, if the
    obj has a __after_insert__ method, call it.
    """
    for obj in session.new:
        callback = getattr(obj, "__after_insert__", None)
        if callback:
            callback(session)