fix behaviour of random values on empty data sets, add tests
This commit is contained in:
parent
4ed084965e
commit
4922eeb3e6
|
@ -122,9 +122,9 @@ class DataSource:
|
|||
"""
|
||||
|
||||
# If there is no data for the specified option, stop now.
|
||||
flattened = [option]
|
||||
flattened = []
|
||||
if not self.data[option]:
|
||||
return random.choice(flattened) if rand else flattened
|
||||
raise ValueError(f"There is no data for '{option}' in your data source.")
|
||||
|
||||
if hasattr(self.data[option], 'keys'):
|
||||
# if the option is a dict, we assume the values are lists; we select a random item
|
||||
|
|
|
@ -22,7 +22,14 @@ class WeightedSet:
|
|||
self.members, self.weights = list(zip(*weighted_members))
|
||||
|
||||
def random(self) -> str:
|
||||
return random.choices(self.members, self.weights)[0]
|
||||
nonzero_members = []
|
||||
nonzero_weights = []
|
||||
for i in range(self.weights):
|
||||
if float(self.weights[i]) == 0.0:
|
||||
continue
|
||||
nozero_members.append(self.members[i])
|
||||
nozero_weights.append(self.weights[i])
|
||||
return random.choices(nonzero_members, nonzero_weights)[0]
|
||||
|
||||
def __add__(self, obj):
|
||||
ws = WeightedSet()
|
||||
|
|
|
@ -1,7 +1,6 @@
|
|||
from io import StringIO
|
||||
|
||||
from random_sets import datasources
|
||||
from pprint import pprint as print
|
||||
|
||||
fixture_metadata = """
|
||||
metadata:
|
||||
|
@ -47,3 +46,31 @@ def test_datasource_random_values():
|
|||
|
||||
# each value has an "Option", a "choice", and a "description"
|
||||
assert len(randvals[0]) == 3
|
||||
|
||||
|
||||
def test_zero_frequency():
|
||||
fixture = StringIO(fixture_metadata + fixture_source)
|
||||
ds = datasources.DataSource(fixture)
|
||||
ds.set_frequency('nondefault')
|
||||
for val in ds.random_values(count=100):
|
||||
assert 'Option 1' not in val
|
||||
|
||||
|
||||
def test_distribution_accuracy_to_one_decimal_place():
|
||||
fixture = StringIO(fixture_metadata + fixture_source)
|
||||
ds = datasources.DataSource(fixture)
|
||||
ds.set_frequency('nondefault')
|
||||
counts = {
|
||||
'Option 1': 0,
|
||||
'Option 2': 0,
|
||||
'Option 3': 0,
|
||||
}
|
||||
|
||||
population = 10000
|
||||
|
||||
for val in ds.random_values(count=population):
|
||||
counts[val[0]] += 1
|
||||
|
||||
for (option, count) in counts.items():
|
||||
observed = count/population
|
||||
assert round(observed, 1) == round(ds.frequencies[option], 1)
|
||||
|
|
Loading…
Reference in New Issue
Block a user