fix behaviour of random values on empty data sets, add tests

This commit is contained in:
evilchili 2024-02-17 16:50:06 -08:00
parent 4ed084965e
commit 4922eeb3e6
3 changed files with 38 additions and 4 deletions

View File

@ -122,9 +122,9 @@ class DataSource:
"""
# If there is no data for the specified option, stop now.
flattened = [option]
flattened = []
if not self.data[option]:
return random.choice(flattened) if rand else flattened
raise ValueError(f"There is no data for '{option}' in your data source.")
if hasattr(self.data[option], 'keys'):
# if the option is a dict, we assume the values are lists; we select a random item

View File

@ -22,7 +22,14 @@ class WeightedSet:
self.members, self.weights = list(zip(*weighted_members))
def random(self) -> str:
return random.choices(self.members, self.weights)[0]
nonzero_members = []
nonzero_weights = []
for i in range(self.weights):
if float(self.weights[i]) == 0.0:
continue
nozero_members.append(self.members[i])
nozero_weights.append(self.weights[i])
return random.choices(nonzero_members, nonzero_weights)[0]
def __add__(self, obj):
ws = WeightedSet()

View File

@ -1,7 +1,6 @@
from io import StringIO
from random_sets import datasources
from pprint import pprint as print
fixture_metadata = """
metadata:
@ -47,3 +46,31 @@ def test_datasource_random_values():
# each value has an "Option", a "choice", and a "description"
assert len(randvals[0]) == 3
def test_zero_frequency():
fixture = StringIO(fixture_metadata + fixture_source)
ds = datasources.DataSource(fixture)
ds.set_frequency('nondefault')
for val in ds.random_values(count=100):
assert 'Option 1' not in val
def test_distribution_accuracy_to_one_decimal_place():
fixture = StringIO(fixture_metadata + fixture_source)
ds = datasources.DataSource(fixture)
ds.set_frequency('nondefault')
counts = {
'Option 1': 0,
'Option 2': 0,
'Option 3': 0,
}
population = 10000
for val in ds.random_values(count=population):
counts[val[0]] += 1
for (option, count) in counts.items():
observed = count/population
assert round(observed, 1) == round(ds.frequencies[option], 1)