add ability to move items between inventories

This commit is contained in:
evilchili 2024-09-02 14:53:49 -07:00
parent d9b3c4500e
commit 17a951b1b2
3 changed files with 56 additions and 8 deletions

View File

@ -1,8 +1,9 @@
from typing import Union from typing import Union
from sqlalchemy import ForeignKey from sqlalchemy import ForeignKey
from sqlalchemy.orm import Mapped, mapped_column, relationship from sqlalchemy.orm import Mapped
from sqlalchemy.orm import base as sa_base from sqlalchemy.orm import base as sa_base
from sqlalchemy.orm import mapped_column, relationship
from ttfrog.db.schema.inventory import Inventory, InventoryMap, InventoryType from ttfrog.db.schema.inventory import Inventory, InventoryMap, InventoryType
from ttfrog.db.schema.item import Item, ItemType from ttfrog.db.schema.item import Item, ItemType

View File

@ -1,8 +1,9 @@
from typing import List, Union from typing import List, Union
from sqlalchemy import ForeignKey, UniqueConstraint from sqlalchemy import ForeignKey, UniqueConstraint
from sqlalchemy.orm import Mapped, mapped_column, relationship from sqlalchemy.orm import Mapped
from sqlalchemy.orm import base as sa_base from sqlalchemy.orm import base as sa_base
from sqlalchemy.orm import mapped_column, relationship
from ttfrog.db.base import BaseObject, EnumField from ttfrog.db.base import BaseObject, EnumField
from ttfrog.db.schema.item import Item, ItemProperty, ItemType from ttfrog.db.schema.item import Item, ItemProperty, ItemType
@ -145,11 +146,25 @@ class InventoryMap(BaseObject):
self.attuned = False self.attuned = False
return True return True
def move_to(self, inventory):
if inventory == self.inventory:
return False
self.inventory.remove(self)
self.inventory = inventory
inventory.item_map.append(self)
return True
def __getattr__(self, name: str): def __getattr__(self, name: str):
if name == sa_base.DEFAULT_STATE_ATTR: if name == sa_base.DEFAULT_STATE_ATTR:
raise AttributeError() raise AttributeError()
return getattr(self.item, name) return getattr(self.item, name)
def __contains__(self, obj):
if self.item.item_type == ItemType.CONTAINER:
return obj in self.item.inventory
raise RuntimeException("Item {self.item.name} is not a container.")
class Inventory(BaseObject): class Inventory(BaseObject):
__tablename__ = "inventory" __tablename__ = "inventory"
__table_args__ = (UniqueConstraint("character_id", "container_id", "inventory_type"),) __table_args__ = (UniqueConstraint("character_id", "container_id", "inventory_type"),)
@ -177,6 +192,7 @@ class Inventory(BaseObject):
yield mapping yield mapping
if mapping.item.item_type == ItemType.CONTAINER: if mapping.item.item_type == ItemType.CONTAINER:
yield from inventory_contents(mapping.item.inventory) yield from inventory_contents(mapping.item.inventory)
yield from inventory_contents(self) yield from inventory_contents(self)
@property @property
@ -186,6 +202,7 @@ class Inventory(BaseObject):
yield mapping yield mapping
if mapping.item.item_type == ItemType.CONTAINER: if mapping.item.item_type == ItemType.CONTAINER:
yield from inventory_map(mapping.item.inventory) yield from inventory_map(mapping.item.inventory)
yield from inventory_map(self) yield from inventory_map(self)
def get(self, item): def get(self, item):
@ -222,8 +239,6 @@ class Inventory(BaseObject):
yield from self.all_items yield from self.all_items
class Charge(BaseObject): class Charge(BaseObject):
__tablename__ = "charge" __tablename__ = "charge"
id: Mapped[int] = mapped_column(init=False, primary_key=True, autoincrement=True) id: Mapped[int] = mapped_column(init=False, primary_key=True, autoincrement=True)

View File

@ -161,13 +161,18 @@ def test_spell_slots(db, carl, wizard):
def test_containers(db, carl): def test_containers(db, carl):
with db.transaction(): with db.transaction():
ten_foot_pole = Item(name="10ft. Pole", item_type=ItemType.ITEM, consumable=False) ten_foot_pole = Item(name="10ft. Pole", item_type=ItemType.ITEM, consumable=False)
rope = Item(name="50 ft. of Rope", item_type=ItemType.ITEM, consumable=True, count=50)
bag_of_holding = Container(name="Bag of Holding") bag_of_holding = Container(name="Bag of Holding")
db.add_or_update([carl, ten_foot_pole, bag_of_holding]) db.add_or_update([carl, ten_foot_pole, rope, bag_of_holding])
# add the ten_foot_pole to the bag of holding # add some items to the bag of holding
assert bag_of_holding.add(ten_foot_pole) assert bag_of_holding.add(ten_foot_pole)
assert bag_of_holding.add(rope)
db.add_or_update(bag_of_holding) db.add_or_update(bag_of_holding)
pole_from_bag = bag_of_holding.get(ten_foot_pole) pole_from_bag = bag_of_holding.get(ten_foot_pole)
rope_from_bag = bag_of_holding.get(rope)
assert pole_from_bag.item == ten_foot_pole assert pole_from_bag.item == ten_foot_pole
assert pole_from_bag in bag_of_holding assert pole_from_bag in bag_of_holding
assert pole_from_bag not in carl.equipment assert pole_from_bag not in carl.equipment
@ -176,11 +181,38 @@ def test_containers(db, carl):
assert carl.equipment.add(bag_of_holding) assert carl.equipment.add(bag_of_holding)
db.add_or_update(bag_of_holding) db.add_or_update(bag_of_holding)
assert pole_from_bag in carl.equipment assert pole_from_bag in carl.equipment
assert rope_from_bag in carl.equipment
# test equality of mappings # test equality of mappings
carls_bag = carl.equipment.get(bag_of_holding) carls_bag = carl.equipment.get(bag_of_holding)
carls_pole = carl.equipment.get(ten_foot_pole) carls_pole = carl.equipment.get(ten_foot_pole)
carls_rope = carl.equipment.get(rope)
assert carls_pole == pole_from_bag assert carls_pole == pole_from_bag
assert carls_rope == rope_from_bag
# remove the pole from the bag # use some rope
assert carls_bag.remove(pole_from_bag) carls_rope.consume(10)
assert carls_rope.count == 40
# move the rope out of the bag of holding, but not the pole
assert carls_rope.move_to(carl.equipment)
assert carls_rope not in carls_bag
assert carls_pole in carls_bag
db.add_or_update(carl)
# get the db record anew, in case the in-memory representation isn't
# what's recorded in the database. Then make sure we didn't break
# anything by asserting we still only have 40ft of rope.
carl = db.Character.filter_by(name="carl").one()
assert carls_rope in carl.equipment
assert carls_rope not in carl.equipment.get(bag_of_holding)
assert carls_rope.count == 40
# old references are still valid
assert rope_from_bag == carls_rope
# use the rest of the rope
assert carls_rope.consume(40) == 0
print(rope_from_bag.inventory)
assert rope_from_bag not in carl.equipment
assert rope_from_bag not in carl.equipment.get(bag_of_holding)