make inventories recursive

This commit is contained in:
evilchili 2024-09-02 12:39:59 -07:00
parent 01a4360dca
commit b09b07d172
4 changed files with 121 additions and 18 deletions

View File

@ -19,14 +19,16 @@ inventory_type_map = {
ItemType.SHIELD, ItemType.SHIELD,
ItemType.ITEM, ItemType.ITEM,
ItemType.SCROLL, ItemType.SCROLL,
ItemType.CONTAINER,
], ],
InventoryType.SPELL: [ItemType.SPELL], InventoryType.SPELL: [ItemType.SPELL],
} }
def inventory_map_creator(fields): def inventory_map_creator(fields):
if isinstance(fields, InventoryMap): # if isinstance(fields, InventoryMap):
return fields # return fields
# return InventoryMap(**fields)
return InventoryMap(**fields) return InventoryMap(**fields)
@ -36,7 +38,7 @@ class Inventory(BaseObject):
id: Mapped[int] = mapped_column(init=False, primary_key=True, autoincrement=True) id: Mapped[int] = mapped_column(init=False, primary_key=True, autoincrement=True)
inventory_type: Mapped[InventoryType] = mapped_column(nullable=False) inventory_type: Mapped[InventoryType] = mapped_column(nullable=False)
items: Mapped[List["InventoryMap"]] = relationship( item_map: Mapped[List["InventoryMap"]] = relationship(
uselist=True, cascade="all,delete,delete-orphan", lazy="immediate", default_factory=lambda: [] uselist=True, cascade="all,delete,delete-orphan", lazy="immediate", default_factory=lambda: []
) )
@ -46,11 +48,33 @@ class Inventory(BaseObject):
character = relationship("Character", init=False, viewonly=True, lazy="immediate") character = relationship("Character", init=False, viewonly=True, lazy="immediate")
container = relationship("Item", init=False, viewonly=True, lazy="immediate") container = relationship("Item", init=False, viewonly=True, lazy="immediate")
@property
def items(self):
return [mapping.item for mapping in self.item_map]
@property
def all_items(self):
def inventory_contents(inventory):
for mapping in inventory.item_map:
yield mapping
if mapping.item.item_type == ItemType.CONTAINER:
yield from inventory_contents(mapping.item.inventory)
yield from inventory_contents(self)
@property
def all_item_maps(self):
def inventory_map(inventory):
for mapping in inventory.item_map:
yield mapping
if mapping.item.item_type == ItemType.CONTAINER:
yield from inventory_map(mapping.item.inventory)
yield from inventory_map(self)
def get(self, item): def get(self, item):
return self.get_all(item)[0] return self.get_all(item)[0]
def get_all(self, item): def get_all(self, item):
return [mapping for mapping in self.items if mapping.item == item] return [mapping for mapping in self.all_item_maps if mapping.item == item]
def add(self, item): def add(self, item):
if item.item_type not in inventory_type_map[self.inventory_type]: if item.item_type not in inventory_type_map[self.inventory_type]:
@ -60,23 +84,23 @@ class Inventory(BaseObject):
mapping.count = item.count mapping.count = item.count
if item.charges: if item.charges:
mapping.charges = [Charge(inventory_map_id=mapping.id) for i in range(item.charges)] mapping.charges = [Charge(inventory_map_id=mapping.id) for i in range(item.charges)]
self.items.append(mapping) self.item_map.append(mapping)
return mapping return mapping
def remove(self, mapping): def remove(self, mapping):
if mapping not in self.items: if mapping in self.item_map:
return False self.item_map.remove(mapping)
self.items.remove(mapping) return True
return True return False
def __contains__(self, obj): def __contains__(self, obj):
for mapping in self.items: for item in self.all_items:
if mapping.item == obj: if item == obj:
return True return True
return False return False
def __iter__(self): def __iter__(self):
yield from self.items yield from self.all_items
class InventoryMap(BaseObject): class InventoryMap(BaseObject):
@ -138,7 +162,7 @@ class InventoryMap(BaseObject):
charges = item_property.charge_cost charges = item_property.charge_cost
if len(avail) < charges: if len(avail) < charges:
return False return False
for charge in avail: for charge in avail[:charges]:
charge.expended = True charge.expended = True
return True return True

View File

@ -144,7 +144,7 @@ class ItemProperty(BaseObject):
id: Mapped[int] = mapped_column(init=False, primary_key=True, autoincrement=True) id: Mapped[int] = mapped_column(init=False, primary_key=True, autoincrement=True)
name: Mapped[str] = mapped_column(String(collation="NOCASE"), nullable=False, unique=True) name: Mapped[str] = mapped_column(String(collation="NOCASE"), nullable=False, unique=True)
description: Mapped[str] = mapped_column(String, nullable=True, default=None) description: Mapped[str] = mapped_column(String, nullable=True, default=None)
charge_cost: Mapped[int] = mapped_column(nullable=True, info={"min": 0}, default=None) charge_cost: Mapped[int] = mapped_column(nullable=True, info={"min": 1}, default=None)
item_id: Mapped[int] = mapped_column(ForeignKey("item.id"), default=0) item_id: Mapped[int] = mapped_column(ForeignKey("item.id"), default=0)
# action/reaction/bonus # action/reaction/bonus

View File

@ -1,3 +1,4 @@
from ttfrog.db.schema.container import Container
from ttfrog.db.schema.item import Item, ItemType, Spell from ttfrog.db.schema.item import Item, ItemType, Spell
@ -41,7 +42,15 @@ def test_equipment_inventory(db, carl):
# can't equip it twice # can't equip it twice
assert not pole_one.equip() assert not pole_one.equip()
# unequip it # can't prepare or cast an item
assert not pole_one.prepare()
assert not pole_one.unprepare()
assert not pole_one.cast()
# not consumable or attunable
assert not pole_one.consume()
assert not pole_one.attune()
assert pole_one.unequip() assert pole_one.unequip()
# can't unequip the unequipped ones # can't unequip the unequipped ones
@ -123,6 +132,7 @@ def test_spell_slots(db, carl, wizard):
assert carl.spellcaster_level == 3 assert carl.spellcaster_level == 3
# cast fireball until he's out of 3rd level slots # cast fireball until he's out of 3rd level slots
assert not carl.spells.get(fireball).cast()
assert carl.spells.get(fireball).prepare() assert carl.spells.get(fireball).prepare()
assert carl.spells.get(fireball).cast() assert carl.spells.get(fireball).cast()
assert carl.spells.get(fireball).cast() assert carl.spells.get(fireball).cast()
@ -142,3 +152,35 @@ def test_spell_slots(db, carl, wizard):
# use the last 3rd level slot # use the last 3rd level slot
assert carl.spells.get(fireball).cast() assert carl.spells.get(fireball).cast()
assert not carl.spells.get(fireball).cast() assert not carl.spells.get(fireball).cast()
# unprepare it
assert carl.spells.get(fireball).unprepare()
assert not carl.spells.get(fireball).unprepare()
def test_containers(db, carl):
with db.transaction():
ten_foot_pole = Item(name="10ft. Pole", item_type=ItemType.ITEM, consumable=False)
bag_of_holding = Container(name="Bag of Holding")
db.add_or_update([carl, ten_foot_pole, bag_of_holding])
# add the ten_foot_pole to the bag of holding
assert bag_of_holding.inventory.add(ten_foot_pole)
db.add_or_update(bag_of_holding)
pole_from_bag = bag_of_holding.inventory.get(ten_foot_pole)
assert pole_from_bag
assert pole_from_bag in bag_of_holding.inventory
assert pole_from_bag not in carl.equipment
# add the bag of holding to carl's equipment
assert carl.equipment.add(bag_of_holding)
db.add_or_update(bag_of_holding)
assert pole_from_bag in carl.equipment
# test equality of mappings
carls_bag = carl.equipment.get(bag_of_holding)
carls_pole = carl.equipment.get(ten_foot_pole)
assert carls_pole == pole_from_bag
# remove the pole from the bag
assert carls_bag.item.inventory.remove(pole_from_bag)

View File

@ -1,5 +1,5 @@
from ttfrog.db.schema.constants import DamageType, Defenses from ttfrog.db.schema.constants import DamageType, Defenses
from ttfrog.db.schema.item import Armor, ItemProperty, Rarity, RechargeTime, Shield, Weapon from ttfrog.db.schema.item import Armor, Item, ItemProperty, Rarity, RechargeTime, Shield, Weapon
from ttfrog.db.schema.modifiers import Modifier from ttfrog.db.schema.modifiers import Modifier
@ -49,7 +49,7 @@ def test_charges(db, carl):
saving throw. On a failure, the target is forced to grin for one minute. While grinning, the target saving throw. On a failure, the target is forced to grin for one minute. While grinning, the target
cannot speak. The target can repeat the saving throw at the start of their turn." cannot speak. The target can repeat the saving throw at the start of their turn."
""", """,
charge_cost=1, charge_cost=2,
) )
# from sqlalchemy.orm import relationship # from sqlalchemy.orm import relationship
@ -88,6 +88,29 @@ def test_charges(db, carl):
assert len(carls_dagger.charges) == dagger_of_lulz.charges == 6 assert len(carls_dagger.charges) == dagger_of_lulz.charges == 6
assert len(carls_dagger.charges_available) == dagger_of_lulz.charges == 6 assert len(carls_dagger.charges_available) == dagger_of_lulz.charges == 6
assert carls_dagger.use(for_the_lulz) assert carls_dagger.use(for_the_lulz)
assert len(carls_dagger.charges_available) == 4
# use the remaining charges
assert carls_dagger.use(for_the_lulz)
assert carls_dagger.use(for_the_lulz)
# all out of charges
assert len(carls_dagger.charges_available) == 0
assert not carls_dagger.use(for_the_lulz)
def test_nocharges(db, carl):
smiles = ItemProperty(name="Smile!", description="The target grins for one minute.", charge_cost=None)
wand_of_unlimited_smiles = Item(name="Wand of Unlimited Smiles", description="description", properties=[smiles])
db.add_or_update(wand_of_unlimited_smiles)
carl.equipment.add(wand_of_unlimited_smiles)
db.add_or_update(carl)
# no charges means you can use it at will
assert carl.equipment.get(wand_of_unlimited_smiles).use(smiles)
assert carl.equipment.get(wand_of_unlimited_smiles).use(smiles)
assert carl.equipment.get(wand_of_unlimited_smiles).use(smiles)
def test_attunement(db, carl): def test_attunement(db, carl):
@ -139,7 +162,8 @@ def test_attunement(db, carl):
assert carl.armor_class == 12 assert carl.armor_class == 12
assert carls_shield not in carl.attuned_items assert carls_shield not in carl.attuned_items
carls_shield.attune() assert carls_shield.attune()
assert not carls_shield.attune()
assert carl.armor_class == 12 assert carl.armor_class == 12
assert plus_two_ac in carl.modifiers["armor_class"] assert plus_two_ac in carl.modifiers["armor_class"]
assert ranged_resistance in carl.modifiers[DamageType.ranged_weapon_attacks] assert ranged_resistance in carl.modifiers[DamageType.ranged_weapon_attacks]
@ -149,7 +173,20 @@ def test_attunement(db, carl):
assert carl.armor_class == 13 assert carl.armor_class == 13
assert carls_shield.unattune() assert carls_shield.unattune()
assert not carls_shield.unattune()
assert carl.armor_class == 13 assert carl.armor_class == 13
assert ranged_resistance not in carl.modifiers[DamageType.ranged_weapon_attacks] assert ranged_resistance not in carl.modifiers[DamageType.ranged_weapon_attacks]
assert carls_shield.unequip() assert carls_shield.unequip()
assert carl.armor_class == 11 assert carl.armor_class == 11
# can only attune 3 items
assert carl.equipment.add(shield)
assert carl.equipment.add(shield)
assert carl.equipment.add(shield)
db.add_or_update(carl)
assert carl.equipment.get_all(shield)[0].attune()
assert carl.equipment.get_all(shield)[1].attune()
assert carl.equipment.get_all(shield)[2].attune()
assert len(carl.attuned_items) == 3
assert not carl.equipment.get_all(shield)[3].attune()