diff --git a/ttfrog/webserver/controllers/base.py b/ttfrog/webserver/controllers/base.py index da6f5c5..e74e7a4 100644 --- a/ttfrog/webserver/controllers/base.py +++ b/ttfrog/webserver/controllers/base.py @@ -5,8 +5,7 @@ from collections import defaultdict from pyramid.httpexceptions import HTTPFound from pyramid.interfaces import IRoutesMapper - -from sqlalchemy.inspection import inspect +from wtforms.fields import SelectField from ttfrog.attribute_map import AttributeMap from ttfrog.db.manager import db @@ -27,8 +26,10 @@ def get_all_routes(request): return routes -def query_factory(model): - return lambda: db.query(model).all() +class DeferredSelectField(SelectField): + def __init__(self, *args, model=None, **kwargs): + super().__init__(*args, **kwargs) + self.choices = db.query(model).all() class BaseController: @@ -80,17 +81,6 @@ class BaseController: if 'all_records' not in self.attrs: self.attrs['all_records'] = db.query(self.model).all() - def coerce_foreign_keys(self): - inspector = inspect(db.engine) - foreign_keys = inspector.get_foreign_keys(table_name=self.record.__class__.__tablename__) - for foreign_key in foreign_keys: - for col in inspector.get_columns(foreign_key['referred_table']): - if col['name'] == foreign_key['referred_columns'][0]: - col_name = foreign_key['constrained_columns'][0] - col_type = col['type'].python_type - col_value = col_type(getattr(self.record, col_name)) - setattr(self.record, col_name, col_value) - def template_context(self, **kwargs) -> dict: return AttributeMap.from_dict({ 'c': dict( @@ -110,7 +100,6 @@ class BaseController: if not self.form.validate(): return self.form.populate_obj(self.record) - self.coerce_foreign_keys() if self.record.id: return with db.transaction(): diff --git a/ttfrog/webserver/controllers/character_sheet.py b/ttfrog/webserver/controllers/character_sheet.py index 2f4a31b..57380cc 100644 --- a/ttfrog/webserver/controllers/character_sheet.py +++ b/ttfrog/webserver/controllers/character_sheet.py @@ -1,8 +1,6 @@ -from ttfrog.webserver.controllers.base import BaseController, query_factory +from ttfrog.webserver.controllers.base import BaseController, DeferredSelectField from ttfrog.db.schema import Character, Ancestry -from ttfrog.db.manager import db -from wtforms_alchemy import ModelForm, QuerySelectField -from wtforms.validators import InputRequired +from wtforms_alchemy import ModelForm from wtforms.fields import SubmitField @@ -11,14 +9,9 @@ class CharacterForm(ModelForm): model = Character exclude = ['slug'] - def get_session(): - return db.session - save = SubmitField() delete = SubmitField() - - ancestry = QuerySelectField('Ancestry', validators=[InputRequired()], - query_factory=query_factory(Ancestry), get_label='name') + ancestry = DeferredSelectField('Ancestry', model=Ancestry, coerce=str, validate_choice=True) class CharacterSheet(BaseController):