diff --git a/language/types.py b/language/types.py index d8dc22d..5e4b033 100644 --- a/language/types.py +++ b/language/types.py @@ -2,6 +2,7 @@ import inspect import random from collections import defaultdict from typing import Union +from random_sets.sets import WeightedSet, equal_weights class LanguageError(Exception): @@ -17,35 +18,6 @@ class ImprobableTemplateError(Exception): """ -class WeightedSet: - """ - A set in which members each have a weight, used for selecting at random. - - Usage: - >>> ws = WeightedSet(('foo', 1.0), ('bar', 0.5)) - >>> ws.random() - ('foo', 1.0) - """ - - def __init__(self, *weighted_members: tuple): - self.members = [] - self.weights = [] - if weighted_members: - self.members, self.weights = list(zip(*weighted_members)) - - def random(self) -> str: - return random.choices(self.members, self.weights)[0] - - def __add__(self, obj): - ws = WeightedSet() - ws.members = self.members + obj.members - ws.weights = self.weights + obj.weights - return ws - - def __str__(self): - return f"{self.members}\n{self.weights}" - - class Syllable: """ One syllable of a word. Used to populate a SyllableSet. @@ -402,9 +374,3 @@ class NameGenerator: def __str__(self) -> str: return self.name()[0]["fullname"] - -def equal_weights(terms: list, weight: float = 1.0, blank: bool = True) -> WeightedSet: - ws = WeightedSet(*[(term, weight) for term in terms]) - if blank: - ws = WeightedSet(("", 1.0)) + ws - return ws diff --git a/pyproject.toml b/pyproject.toml index a2c2f33..3419662 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -13,6 +13,7 @@ python = "^3.10" rich = "^13.7.0" typer = "^0.9.0" dice = "^4.0.0" +random_sets = { git = "https://github.com/evilchili/random-sets", branch="main" } [tool.poetry.group.dev.dependencies] pytest = "^7.4.3"