2024-01-28 22:14:50 -08:00
import base64
import hashlib
2024-04-20 20:35:07 -07:00
import json
2024-03-26 00:53:21 -07:00
import os
2024-01-31 22:39:54 -08:00
from contextlib import contextmanager
2024-01-28 22:14:50 -08:00
from functools import cached_property
2024-06-30 23:21:23 -07:00
from sqlite3 import IntegrityError
2024-01-28 22:14:50 -08:00
2024-03-26 00:53:21 -07:00
import transaction
2024-04-28 14:30:47 -07:00
from pyramid_sqlalchemy.meta import Session
2024-06-30 23:21:23 -07:00
from sqlalchemy import create_engine, event, insert
2024-01-28 00:46:19 -08:00
2024-06-30 23:21:23 -07:00
from ttfrog.db import schema
2024-03-26 00:53:21 -07:00
from ttfrog.path import database
2024-01-30 01:25:02 -08:00
2024-04-20 20:35:07 -07:00
class AlchemyEncoder(json.JSONEncoder):
    """JSON encoder for objects that expose a ``__json__()`` serialization hook.

    Objects without a ``__json__`` attribute, or whose ``__json__`` raises
    ``NotImplementedError``, fall back to ``json.JSONEncoder.default``, which
    raises ``TypeError`` for unsupported types.
    """

    def default(self, obj):
        """Return a JSON-serializable representation of ``obj``.

        Errors raised inside ``__json__`` itself (including ``AttributeError``)
        propagate unchanged instead of being disguised as "not serializable".
        """
        # Look the hook up separately from calling it, so an AttributeError raised
        # *inside* __json__ is not mistaken for "this object has no __json__".
        to_json = getattr(obj, "__json__", None)
        if to_json is None:
            return super().default(obj)
        try:
            return to_json()
        except NotImplementedError:  # pragma: no cover
            return super().default(obj)
2024-01-28 00:46:19 -08:00
class SQLDatabaseManager:
    """Facade over the application's SQLAlchemy engine, session and schema.

    The connection URL comes from the ``DATABASE_URL`` environment variable,
    falling back to a SQLite file located by ``ttfrog.path.database()``.
    Unknown public attributes are resolved as model names in
    ``ttfrog.db.schema``: ``db.Foo`` is shorthand for ``db.query(schema.Foo)``.
    """

    @cached_property
    def url(self) -> str:
        """Database URL: ``$DATABASE_URL`` if set, else a ``sqlite:///`` file URL."""
        return os.environ.get("DATABASE_URL", f"sqlite:///{database()}")

    @cached_property
    def engine(self):
        """SQLAlchemy engine for :attr:`url`, created once per manager instance."""
        return create_engine(self.url)

    @cached_property
    def session(self):
        """The ``pyramid_sqlalchemy`` ``Session`` object used for all queries."""
        return Session

    @cached_property
    def metadata(self):
        """Metadata of every table declared on ``schema.BaseObject``."""
        return schema.BaseObject.metadata

    @cached_property
    def tables(self) -> dict:
        """Map of table name to ``Table``, built from ``metadata.sorted_tables``."""
        return {table.name: table for table in self.metadata.sorted_tables}

    @contextmanager
    def transaction(self):
        """Run the enclosed block inside a ``transaction.manager`` transaction.

        Yields the object produced by entering ``transaction.manager``. The
        manager's own exit handler commits when the block completes and aborts
        when it raises; exceptions are re-raised to the caller unchanged.
        """
        # Inside this method body, ``transaction`` resolves to the module-level
        # ``transaction`` package (class-scope names are not visible here).
        with transaction.manager as tm:
            yield tm

    def add_or_update(self, record, *args, **kwargs):
        """Add a record, or every record in a list/tuple, to the session.

        Extra positional and keyword arguments are forwarded to
        ``session.add`` for each record.
        """
        records = record if isinstance(record, (list, tuple)) else [record]
        for rec in records:
            self.session.add(rec, *args, **kwargs)

    def query(self, *args, **kwargs):
        """Shortcut for ``self.session.query(*args, **kwargs)``."""
        return self.session.query(*args, **kwargs)

    def slugify(self, rec: dict) -> str:
        """Create a short, URL-safe, "uniquish" slug from ``rec["id"]``.

        The slug is the first 10 characters of the URL-safe base64 encoding of
        the SHA-1 digest of ``str(rec["id"])``. Only the ``id`` key is used, and
        truncation means distinct ids can (rarely) collide.
        """
        digest = hashlib.sha1(str(rec["id"]).encode()).digest()
        return base64.urlsafe_b64encode(digest).decode("ascii")[:10]

    def init(self):
        """Bind the schema metadata to this manager's engine."""
        # NOTE(review): ``MetaData.bind`` is no longer used by SQLAlchemy 2.0; confirm
        # the SQLAlchemy version in use and whether the session also needs binding.
        self.metadata.bind = self.engine

    def dump(self, names: "list | None" = None) -> str:
        """Serialize table contents to a JSON string.

        Args:
            names: table names to include; ``None`` or empty dumps every table.

        Returns:
            JSON text (indent 2) of ``{table_name: [row_dict, ...]}``, encoded
            with :class:`AlchemyEncoder`.
        """
        # ``None`` default replaces the former shared mutable ``[]`` default.
        results = {}
        for table_name, table in self.tables.items():
            if not names or table_name in names:
                results[table_name] = [dict(row._mapping) for row in self.query(table).all()]
        return json.dumps(results, indent=2, cls=AlchemyEncoder)

    def load(self, data: dict):
        """Bulk-insert rows from a ``{table_name: [row_dict, ...]}`` mapping.

        Accepts the structure produced by :meth:`dump` (after ``json.loads``).
        Tables whose row list is empty are skipped.

        Raises:
            sqlite3.IntegrityError: if a key in ``data`` is not a known table.
        """
        for table_name, rows in data.items():
            table = self.tables.get(table_name)
            if table is None:
                raise IntegrityError(f"Table {table_name} not found in database.")
            if not rows:
                continue
            # A list of parameter dicts makes this an executemany-style bulk insert.
            self.session.execute(insert(table), rows)

    def __getattr__(self, name: str):
        """Resolve unknown attributes as schema model names and return a query.

        ``db.Foo`` is equivalent to ``db.query(schema.Foo)``. Private and dunder
        names raise ``AttributeError`` immediately, so probes such as
        ``hasattr(db, "__file__")`` never pass module internals to ``query``.
        """
        if name.startswith("_"):
            raise AttributeError(f"{type(self).__name__!r} object has no attribute {name!r}")
        return self.query(getattr(schema, name))
2024-01-28 00:46:19 -08:00
# Module-level SQLDatabaseManager instance; importers of this module share this one object.
db = SQLDatabaseManager()
2024-05-06 00:13:52 -07:00
@event.listens_for(db.session, "after_flush")
def session_after_flush(session, flush_context):
Listen to flush events looking for newly-created objects. For each one, if the
obj has a __after_insert__ method, call it.
for obj in session.new:
callback = getattr(obj, "__after_insert__", None)
if callback: