diff --git a/src/ttfrog/db/schema/container.py b/src/ttfrog/db/schema/container.py index 0afab29..78306b5 100644 --- a/src/ttfrog/db/schema/container.py +++ b/src/ttfrog/db/schema/container.py @@ -1,8 +1,9 @@ from typing import Union from sqlalchemy import ForeignKey -from sqlalchemy.orm import Mapped, mapped_column, relationship +from sqlalchemy.orm import Mapped from sqlalchemy.orm import base as sa_base +from sqlalchemy.orm import mapped_column, relationship from ttfrog.db.schema.inventory import Inventory, InventoryMap, InventoryType from ttfrog.db.schema.item import Item, ItemType diff --git a/src/ttfrog/db/schema/inventory.py b/src/ttfrog/db/schema/inventory.py index 8111fcb..89f5a7e 100644 --- a/src/ttfrog/db/schema/inventory.py +++ b/src/ttfrog/db/schema/inventory.py @@ -1,8 +1,9 @@ from typing import List, Union from sqlalchemy import ForeignKey, UniqueConstraint -from sqlalchemy.orm import Mapped, mapped_column, relationship +from sqlalchemy.orm import Mapped from sqlalchemy.orm import base as sa_base +from sqlalchemy.orm import mapped_column, relationship from ttfrog.db.base import BaseObject, EnumField from ttfrog.db.schema.item import Item, ItemProperty, ItemType @@ -145,11 +146,25 @@ class InventoryMap(BaseObject): self.attuned = False return True + def move_to(self, inventory): + if inventory == self.inventory: + return False + self.inventory.remove(self) + self.inventory = inventory + inventory.item_map.append(self) + return True + def __getattr__(self, name: str): if name == sa_base.DEFAULT_STATE_ATTR: raise AttributeError() return getattr(self.item, name) + def __contains__(self, obj): + if self.item.item_type == ItemType.CONTAINER: + return obj in self.item.inventory + raise RuntimeException("Item {self.item.name} is not a container.") + + class Inventory(BaseObject): __tablename__ = "inventory" __table_args__ = (UniqueConstraint("character_id", "container_id", "inventory_type"),) @@ -177,6 +192,7 @@ class Inventory(BaseObject): yield mapping if mapping.item.item_type == ItemType.CONTAINER: yield from inventory_contents(mapping.item.inventory) + yield from inventory_contents(self) @property @@ -186,6 +202,7 @@ class Inventory(BaseObject): yield mapping if mapping.item.item_type == ItemType.CONTAINER: yield from inventory_map(mapping.item.inventory) + yield from inventory_map(self) def get(self, item): @@ -222,8 +239,6 @@ class Inventory(BaseObject): yield from self.all_items - - class Charge(BaseObject): __tablename__ = "charge" id: Mapped[int] = mapped_column(init=False, primary_key=True, autoincrement=True) diff --git a/test/test_inventories.py b/test/test_inventories.py index f3da768..288e974 100644 --- a/test/test_inventories.py +++ b/test/test_inventories.py @@ -161,13 +161,18 @@ def test_spell_slots(db, carl, wizard): def test_containers(db, carl): with db.transaction(): ten_foot_pole = Item(name="10ft. Pole", item_type=ItemType.ITEM, consumable=False) + rope = Item(name="50 ft. of Rope", item_type=ItemType.ITEM, consumable=True, count=50) bag_of_holding = Container(name="Bag of Holding") - db.add_or_update([carl, ten_foot_pole, bag_of_holding]) + db.add_or_update([carl, ten_foot_pole, rope, bag_of_holding]) - # add the ten_foot_pole to the bag of holding + # add some items to the bag of holding assert bag_of_holding.add(ten_foot_pole) + assert bag_of_holding.add(rope) db.add_or_update(bag_of_holding) + pole_from_bag = bag_of_holding.get(ten_foot_pole) + rope_from_bag = bag_of_holding.get(rope) + assert pole_from_bag.item == ten_foot_pole assert pole_from_bag in bag_of_holding assert pole_from_bag not in carl.equipment @@ -176,11 +181,38 @@ def test_containers(db, carl): assert carl.equipment.add(bag_of_holding) db.add_or_update(bag_of_holding) assert pole_from_bag in carl.equipment + assert rope_from_bag in carl.equipment # test equality of mappings carls_bag = carl.equipment.get(bag_of_holding) carls_pole = carl.equipment.get(ten_foot_pole) + carls_rope = carl.equipment.get(rope) assert carls_pole == pole_from_bag + assert carls_rope == rope_from_bag - # remove the pole from the bag - assert carls_bag.remove(pole_from_bag) + # use some rope + carls_rope.consume(10) + assert carls_rope.count == 40 + + # move the rope out of the bag of holding, but not the pole + assert carls_rope.move_to(carl.equipment) + assert carls_rope not in carls_bag + assert carls_pole in carls_bag + db.add_or_update(carl) + + # get the db record anew, in case the in-memory representation isn't + # what's recorded in the database. Then make sure we didn't break + # anything by asserting we still only have 40ft of rope. + carl = db.Character.filter_by(name="carl").one() + assert carls_rope in carl.equipment + assert carls_rope not in carl.equipment.get(bag_of_holding) + assert carls_rope.count == 40 + + # old references are still valid + assert rope_from_bag == carls_rope + + # use the rest of the rope + assert carls_rope.consume(40) == 0 + print(rope_from_bag.inventory) + assert rope_from_bag not in carl.equipment + assert rope_from_bag not in carl.equipment.get(bag_of_holding)