From d9b3c4500e118cbf7f782faeca599ac704d34f45 Mon Sep 17 00:00:00 2001 From: evilchili Date: Mon, 2 Sep 2024 14:02:16 -0700 Subject: [PATCH] make inventory maps proxy items and containers propxy inventories --- src/ttfrog/db/schema/container.py | 16 +++- src/ttfrog/db/schema/inventory.py | 152 ++++++++++++++++-------------- test/test_inventories.py | 10 +- 3 files changed, 100 insertions(+), 78 deletions(-) diff --git a/src/ttfrog/db/schema/container.py b/src/ttfrog/db/schema/container.py index fdd93a3..0afab29 100644 --- a/src/ttfrog/db/schema/container.py +++ b/src/ttfrog/db/schema/container.py @@ -1,7 +1,10 @@ +from typing import Union + from sqlalchemy import ForeignKey from sqlalchemy.orm import Mapped, mapped_column, relationship +from sqlalchemy.orm import base as sa_base -from ttfrog.db.schema.inventory import Inventory, InventoryType +from ttfrog.db.schema.inventory import Inventory, InventoryMap, InventoryType from ttfrog.db.schema.item import Item, ItemType __all__ = [ @@ -19,3 +22,14 @@ class Container(Item): lazy="immediate", default_factory=lambda: Inventory(inventory_type=InventoryType.EQUIPMENT), ) + + def __contains__(self, obj: Union[InventoryMap, Item]): + return obj in self.inventory + + def __iter__(self): + yield from self.inventory + + def __getattr__(self, name: str): + if name == sa_base.DEFAULT_STATE_ATTR: + raise AttributeError() + return getattr(self.inventory, name) diff --git a/src/ttfrog/db/schema/inventory.py b/src/ttfrog/db/schema/inventory.py index c942003..8111fcb 100644 --- a/src/ttfrog/db/schema/inventory.py +++ b/src/ttfrog/db/schema/inventory.py @@ -1,7 +1,8 @@ -from typing import List +from typing import List, Union from sqlalchemy import ForeignKey, UniqueConstraint from sqlalchemy.orm import Mapped, mapped_column, relationship +from sqlalchemy.orm import base as sa_base from ttfrog.db.base import BaseObject, EnumField from ttfrog.db.schema.item import Item, ItemProperty, ItemType @@ -32,77 +33,6 @@ def inventory_map_creator(fields): return InventoryMap(**fields) -class Inventory(BaseObject): - __tablename__ = "inventory" - __table_args__ = (UniqueConstraint("character_id", "container_id", "inventory_type"),) - id: Mapped[int] = mapped_column(init=False, primary_key=True, autoincrement=True) - inventory_type: Mapped[InventoryType] = mapped_column(nullable=False) - - item_map: Mapped[List["InventoryMap"]] = relationship( - uselist=True, cascade="all,delete,delete-orphan", lazy="immediate", default_factory=lambda: [] - ) - - character_id: Mapped[int] = mapped_column(ForeignKey("character.id"), nullable=True, default=None) - container_id: Mapped[int] = mapped_column(ForeignKey("item.id"), nullable=True, default=None) - - character = relationship("Character", init=False, viewonly=True, lazy="immediate") - container = relationship("Item", init=False, viewonly=True, lazy="immediate") - - @property - def items(self): - return [mapping.item for mapping in self.item_map] - - @property - def all_items(self): - def inventory_contents(inventory): - for mapping in inventory.item_map: - yield mapping - if mapping.item.item_type == ItemType.CONTAINER: - yield from inventory_contents(mapping.item.inventory) - yield from inventory_contents(self) - - @property - def all_item_maps(self): - def inventory_map(inventory): - for mapping in inventory.item_map: - yield mapping - if mapping.item.item_type == ItemType.CONTAINER: - yield from inventory_map(mapping.item.inventory) - yield from inventory_map(self) - - def get(self, item): - return self.get_all(item)[0] - - def get_all(self, item): - return [mapping for mapping in self.all_item_maps if mapping.item == item] - - def add(self, item): - if item.item_type not in inventory_type_map[self.inventory_type]: - return False - mapping = InventoryMap(inventory_id=self.id, item_id=item.id) - if item.consumable: - mapping.count = item.count - if item.charges: - mapping.charges = [Charge(inventory_map_id=mapping.id) for i in range(item.charges)] - self.item_map.append(mapping) - return mapping - - def remove(self, mapping): - if mapping in self.item_map: - self.item_map.remove(mapping) - return True - return False - - def __contains__(self, obj): - for item in self.all_items: - if item == obj: - return True - return False - - def __iter__(self): - yield from self.all_items - - class InventoryMap(BaseObject): __tablename__ = "inventory_map" id: Mapped[int] = mapped_column(init=False, primary_key=True, autoincrement=True) @@ -215,6 +145,84 @@ class InventoryMap(BaseObject): self.attuned = False return True + def __getattr__(self, name: str): + if name == sa_base.DEFAULT_STATE_ATTR: + raise AttributeError() + return getattr(self.item, name) + +class Inventory(BaseObject): + __tablename__ = "inventory" + __table_args__ = (UniqueConstraint("character_id", "container_id", "inventory_type"),) + id: Mapped[int] = mapped_column(init=False, primary_key=True, autoincrement=True) + inventory_type: Mapped[InventoryType] = mapped_column(nullable=False) + + item_map: Mapped[List["InventoryMap"]] = relationship( + uselist=True, cascade="all,delete,delete-orphan", lazy="immediate", default_factory=lambda: [] + ) + + character_id: Mapped[int] = mapped_column(ForeignKey("character.id"), nullable=True, default=None) + container_id: Mapped[int] = mapped_column(ForeignKey("item.id"), nullable=True, default=None) + + character = relationship("Character", init=False, viewonly=True, lazy="immediate") + container = relationship("Item", init=False, viewonly=True, lazy="immediate") + + @property + def items(self): + return [mapping.item for mapping in self.item_map] + + @property + def all_items(self): + def inventory_contents(inventory): + for mapping in inventory.item_map: + yield mapping + if mapping.item.item_type == ItemType.CONTAINER: + yield from inventory_contents(mapping.item.inventory) + yield from inventory_contents(self) + + @property + def all_item_maps(self): + def inventory_map(inventory): + for mapping in inventory.item_map: + yield mapping + if mapping.item.item_type == ItemType.CONTAINER: + yield from inventory_map(mapping.item.inventory) + yield from inventory_map(self) + + def get(self, item): + return self.get_all(item)[0] + + def get_all(self, item): + return [mapping for mapping in self.all_item_maps if mapping.item == item] + + def add(self, item): + if item.item_type not in inventory_type_map[self.inventory_type]: + return False + mapping = InventoryMap(inventory_id=self.id, item_id=item.id) + if item.consumable: + mapping.count = item.count + if item.charges: + mapping.charges = [Charge(inventory_map_id=mapping.id) for i in range(item.charges)] + self.item_map.append(mapping) + return mapping + + def remove(self, mapping): + if mapping in self.item_map: + self.item_map.remove(mapping) + return True + return False + + def __contains__(self, obj: Union[InventoryMap, Item]): + if isinstance(obj, InventoryMap): + item = obj.item + else: + item = obj + return item in [mapping.item for mapping in self.all_items] + + def __iter__(self): + yield from self.all_items + + + class Charge(BaseObject): __tablename__ = "charge" diff --git a/test/test_inventories.py b/test/test_inventories.py index 94c2829..f3da768 100644 --- a/test/test_inventories.py +++ b/test/test_inventories.py @@ -165,11 +165,11 @@ def test_containers(db, carl): db.add_or_update([carl, ten_foot_pole, bag_of_holding]) # add the ten_foot_pole to the bag of holding - assert bag_of_holding.inventory.add(ten_foot_pole) + assert bag_of_holding.add(ten_foot_pole) db.add_or_update(bag_of_holding) - pole_from_bag = bag_of_holding.inventory.get(ten_foot_pole) - assert pole_from_bag - assert pole_from_bag in bag_of_holding.inventory + pole_from_bag = bag_of_holding.get(ten_foot_pole) + assert pole_from_bag.item == ten_foot_pole + assert pole_from_bag in bag_of_holding assert pole_from_bag not in carl.equipment # add the bag of holding to carl's equipment @@ -183,4 +183,4 @@ def test_containers(db, carl): assert carls_pole == pole_from_bag # remove the pole from the bag - assert carls_bag.item.inventory.remove(pole_from_bag) + assert carls_bag.remove(pole_from_bag)