random-sets/random_sets/sets.py

53 lines
1.4 KiB
Python
Raw Normal View History

2023-12-23 15:34:32 -08:00
import random
from pathlib import Path
from random_sets.datasources import DataSource
class WeightedSet:
    """
    A set in which members each have a weight, used for selecting at random.

    Members and weights are kept as two parallel tuples, ``members`` and
    ``weights``, in the order the pairs were supplied.

    Usage:

        >>> ws = WeightedSet(('foo', 1.0), ('bar', 0.5))
        >>> ws.random()  # doctest: +SKIP
        'foo'
    """

    def __init__(self, *weighted_members: tuple):
        """
        Build the set from zero or more ``(member, weight)`` pairs.

        Args:
            *weighted_members: ``(member, weight)`` pairs. With none given,
                the set starts out empty.
        """
        # Start from empty tuples rather than lists: zip() below yields
        # tuples, and keeping one container type means an empty set can be
        # concatenated with a populated one in __add__.
        self.members = ()
        self.weights = ()
        if weighted_members:
            # Transpose [(m1, w1), (m2, w2), ...] into (m1, m2, ...) and
            # (w1, w2, ...).
            self.members, self.weights = zip(*weighted_members)

    def random(self) -> str:
        """
        Return one member, chosen at random in proportion to its weight.

        Only the member is returned, not its weight.

        Raises:
            IndexError: raised by ``random.choices`` when the set is empty.
        """
        return random.choices(self.members, self.weights)[0]

    def __add__(self, obj):
        """
        Return a new WeightedSet holding this set's members followed by
        those of ``obj``, each keeping its own weight.

        Neither operand is modified.
        """
        ws = WeightedSet()
        # Coerce both sides to tuples so that operands whose attributes were
        # reassigned to lists by a caller still concatenate cleanly.
        ws.members = tuple(self.members) + tuple(obj.members)
        ws.weights = tuple(self.weights) + tuple(obj.weights)
        return ws

    def __str__(self):
        """Return the members and the weights, one sequence per line."""
        return f"{self.members}\n{self.weights}"
class DataSourceSet(WeightedSet):
    """
    A WeightedSet whose members and weights come from a DataSource.

    The keys of the source's ``frequencies`` mapping are the members and the
    corresponding values are their weights.
    """

    def __init__(self, source: Path):
        """
        Read the text of ``source``, wrap it in a DataSource, and register
        every ``frequencies`` entry as a weighted member.
        """
        self.source = DataSource(source.read_text())
        frequencies = self.source.frequencies
        super().__init__(*((name, weight) for name, weight in frequencies.items()))

    def random(self):
        """
        Choose a key at random by weight and return the entry stored under
        that key in the source's ``as_dict()`` mapping.
        """
        chosen = super().random()
        entries = self.source.as_dict()
        return entries[chosen]
def equal_weights(terms: list, weight: float = 1.0, blank: bool = True) -> WeightedSet:
    """
    Build a WeightedSet in which every term carries the same weight.

    Args:
        terms: The members to include.
        weight: The weight assigned to each term.
        blank: When true, an empty-string member is placed ahead of the
            terms. Its weight is fixed at 1.0, independent of ``weight``.

    Returns:
        A new WeightedSet of the terms, preceded by the blank member when
        ``blank`` is true.
    """
    pairs = [(term, weight) for term in terms]
    if blank:
        # Prepend the blank pair before constructing the set instead of
        # adding two sets together: with empty ``terms`` the sum of a
        # populated set and an empty one used to raise TypeError.
        pairs.insert(0, ("", 1.0))
    return WeightedSet(*pairs)