117 lines
3.3 KiB
Python
117 lines
3.3 KiB
Python
import base64
|
|
import hashlib
|
|
import json
|
|
import os
|
|
from contextlib import contextmanager
|
|
from functools import cached_property
|
|
from sqlite3 import IntegrityError
|
|
|
|
import transaction
|
|
from pyramid_sqlalchemy.meta import Session
|
|
from sqlalchemy import create_engine, event, insert
|
|
|
|
from ttfrog.db import schema
|
|
from ttfrog.path import database
|
|
|
|
|
|
class AlchemyEncoder(json.JSONEncoder):
    """
    JSON encoder that serializes objects through their ``__json__()`` method.

    Objects without a usable ``__json__`` hook are handed to the stock
    ``json.JSONEncoder.default``, which raises ``TypeError``.
    """

    def default(self, obj):
        """
        Return a JSON-serializable stand-in for ``obj``.

        Calls ``obj.__json__()``. If that lookup or call raises
        ``AttributeError`` or ``NotImplementedError``, defers to the base
        encoder instead.
        """
        try:
            serializable = obj.__json__()
        except (AttributeError, NotImplementedError):  # pragma: no cover
            serializable = super().default(obj)
        return serializable
|
|
|
|
|
|
class SQLDatabaseManager:
    """
    Manager for the application's SQLAlchemy database (SQLite by default).

    Exposes lazily-computed ``url``, ``engine``, ``session``, ``metadata`` and
    ``tables`` attributes, a ``transaction()`` context manager, and helpers
    for adding, querying, dumping and loading records.

    Public attribute names that are not defined on the class are resolved
    against the ``schema`` module and returned as a query on that name, e.g.
    ``manager.SomeModel`` is ``manager.query(schema.SomeModel)``.
    """

    @cached_property
    def url(self):
        """Database URL: ``$DATABASE_URL`` if set, else a sqlite URL built from ``database()``."""
        return os.environ.get("DATABASE_URL", f"sqlite:///{database()}")

    @cached_property
    def engine(self):
        """SQLAlchemy engine created from ``self.url``."""
        return create_engine(self.url)

    @cached_property
    def session(self):
        """The ``pyramid_sqlalchemy`` ``Session`` object used for all operations."""
        return Session

    @cached_property
    def metadata(self):
        """The metadata attached to ``schema.BaseObject``."""
        return schema.BaseObject.metadata

    @cached_property
    def tables(self):
        """Mapping of table name to table object, built from ``metadata.sorted_tables``."""
        return {table.name: table for table in self.metadata.sorted_tables}

    @contextmanager
    def transaction(self):
        """
        Context manager that wraps the ``with`` body in ``transaction.manager``.

        Yields the object produced by entering ``transaction.manager``. When
        the body finishes, the transaction is committed; if the commit raises,
        the transaction is aborted and the exception is re-raised.
        """
        with transaction.manager as tm:
            yield tm
            try:
                tm.commit()
            except Exception:  # pragma: no cover
                tm.abort()
                raise

    def add_or_update(self, record, *args, **kwargs):
        """
        Add one record, or each record in a list, to the session.

        Extra positional and keyword arguments are forwarded to
        ``session.add``. The session is flushed after each record is added.
        """
        records = record if isinstance(record, list) else [record]
        for rec in records:
            self.session.add(rec, *args, **kwargs)
            self.session.flush()

    def query(self, *args, **kwargs):
        """Return ``self.session.query(*args, **kwargs)``."""
        return self.session.query(*args, **kwargs)

    def slugify(self, rec: dict) -> str:
        """
        Create a uniquish slug from a dictionary.

        The slug is the first 10 characters of the URL-safe base64 encoding
        of the SHA-1 digest of ``str(rec["id"])``; no other key is consulted.

        Raises:
            KeyError: if ``rec`` has no ``"id"`` key.
        """
        digest = hashlib.sha1(str(rec["id"]).encode()).digest()
        return base64.urlsafe_b64encode(digest).decode("ascii")[:10]

    def init(self):
        """
        Bind the metadata and session to the engine and create all tables.

        Any current session is removed before the session is reconfigured.
        """
        self.metadata.bind = self.engine
        self.session.remove()
        self.session.configure(bind=self.engine)
        self.metadata.create_all(self.engine)

    def dump(self, names: "list | None" = None):
        """
        Serialize table contents to an indented JSON string.

        Args:
            names: table names to include. ``None`` or an empty list (the
                default) dumps every table in ``self.tables``.

        Returns:
            A JSON document mapping each selected table name to a list of
            row dictionaries, encoded with ``AlchemyEncoder``.
        """
        results = {}
        for table_name, table in self.tables.items():
            if not names or table_name in names:
                results[table_name] = [dict(row._mapping) for row in self.query(table).all()]
        return json.dumps(results, indent=2, cls=AlchemyEncoder)

    def load(self, data: dict):
        """
        Insert rows into existing tables.

        Args:
            data: mapping of table name to a list of row dictionaries.
                Tables mapped to an empty list are skipped.

        Raises:
            IntegrityError: if a table name is not present in ``self.tables``.
        """
        for table_name, rows in data.items():
            table = self.tables.get(table_name)
            if table is None:
                raise IntegrityError(f"Table {table_name} not found in database.")
            if not rows:
                continue
            self.session.execute(insert(table), rows)

    def __getattr__(self, name: str):
        """
        Resolve unknown public attributes as queries on ``schema`` members.

        Raises:
            AttributeError: if ``name`` starts with an underscore, or if the
                ``schema`` module has no attribute called ``name``.
        """
        # Private/dunder probes (hasattr, copy, pickle, inspect, ...) must get
        # an AttributeError rather than a query built from an unrelated module
        # attribute such as schema.__file__.
        if name.startswith("_"):
            raise AttributeError(name)
        return self.query(getattr(schema, name))
|
|
|
|
|
|
# Module-level SQLDatabaseManager instance. Its session is the target of the
# "after_flush" event listener registered further down in this module.
db = SQLDatabaseManager()
|
|
|
|
|
|
@event.listens_for(db.session, "after_flush")
def session_after_flush(session, flush_context):
    """
    Run post-insert hooks for objects that are new in the flushed session.

    Every object in ``session.new`` is checked for an ``__after_insert__``
    attribute; when one is present and truthy it is called with the session.
    """
    for pending in session.new:
        hook = getattr(pending, "__after_insert__", None)
        if not hook:
            continue
        hook(session)
|