fix behaviour of random values on empty data sets, add tests
This commit is contained in:
parent
4ed084965e
commit
4922eeb3e6
|
@ -122,9 +122,9 @@ class DataSource:
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# If there is no data for the specified option, stop now.
|
# If there is no data for the specified option, stop now.
|
||||||
flattened = [option]
|
flattened = []
|
||||||
if not self.data[option]:
|
if not self.data[option]:
|
||||||
return random.choice(flattened) if rand else flattened
|
raise ValueError(f"There is no data for '{option}' in your data source.")
|
||||||
|
|
||||||
if hasattr(self.data[option], 'keys'):
|
if hasattr(self.data[option], 'keys'):
|
||||||
# if the option is a dict, we assume the values are lists; we select a random item
|
# if the option is a dict, we assume the values are lists; we select a random item
|
||||||
|
|
|
@ -22,7 +22,14 @@ class WeightedSet:
|
||||||
self.members, self.weights = list(zip(*weighted_members))
|
self.members, self.weights = list(zip(*weighted_members))
|
||||||
|
|
||||||
def random(self) -> str:
    """Return one member chosen at random, in proportion to its weight.

    Members whose weight is zero are dropped from the candidate pool before
    drawing, so they can never be returned.

    Returns:
        A single entry of ``self.members``.

    Raises:
        ValueError: if no member has a nonzero weight (including the case
            where the set is empty).
    """
    # Function-scope import: resolves the stdlib helper regardless of what
    # the bare name ``random`` is bound to in the surrounding scope.
    from random import choices

    # Pair each member with its weight and keep only the nonzero ones.
    # float() mirrors the original comparison so int and float zeros are
    # both recognised.
    candidates = [
        (member, weight)
        for member, weight in zip(self.members, self.weights)
        if float(weight) != 0.0
    ]
    if not candidates:
        raise ValueError("Cannot choose a random member: no member has a nonzero weight.")

    nonzero_members, nonzero_weights = zip(*candidates)
    return choices(nonzero_members, nonzero_weights)[0]
|
||||||
|
|
||||||
def __add__(self, obj):
|
def __add__(self, obj):
|
||||||
ws = WeightedSet()
|
ws = WeightedSet()
|
||||||
|
|
|
@ -1,7 +1,6 @@
|
||||||
from io import StringIO
|
from io import StringIO
|
||||||
|
|
||||||
from random_sets import datasources
|
from random_sets import datasources
|
||||||
from pprint import pprint as print
|
|
||||||
|
|
||||||
fixture_metadata = """
|
fixture_metadata = """
|
||||||
metadata:
|
metadata:
|
||||||
|
@ -47,3 +46,31 @@ def test_datasource_random_values():
|
||||||
|
|
||||||
# each value has an "Option", a "choice", and a "description"
|
# each value has an "Option", a "choice", and a "description"
|
||||||
assert len(randvals[0]) == 3
|
assert len(randvals[0]) == 3
|
||||||
|
|
||||||
|
|
||||||
|
def test_zero_frequency():
    """'Option 1' must never appear in values drawn under 'nondefault'."""
    source = StringIO(fixture_metadata + fixture_source)
    ds = datasources.DataSource(source)
    ds.set_frequency('nondefault')
    # NOTE(review): presumably the fixture gives 'Option 1' a zero weight
    # under the 'nondefault' frequency -- confirm against fixture_source.
    assert all('Option 1' not in val for val in ds.random_values(count=100))
|
||||||
|
|
||||||
|
|
||||||
|
def test_distribution_accuracy_to_one_decimal_place():
    """Observed selection rates should track the configured frequencies.

    Draws a large sample under the 'nondefault' frequency and checks that
    each option's observed share is within 0.05 (half of one decimal place)
    of the frequency reported by ``ds.frequencies``.
    """
    fixture = StringIO(fixture_metadata + fixture_source)
    ds = datasources.DataSource(fixture)
    ds.set_frequency('nondefault')

    population = 10000

    # Pre-seed every option with zero so options that are never drawn are
    # still compared against their expected frequency below.
    counts = dict.fromkeys(('Option 1', 'Option 2', 'Option 3'), 0)
    for val in ds.random_values(count=population):
        counts[val[0]] += 1

    for option, count in counts.items():
        observed = count / population
        expected = ds.frequencies[option]
        # Use an absolute tolerance instead of rounding both sides: rounding
        # makes the test flaky when the expected value sits near a rounding
        # boundary (e.g. 0.25), where tiny sampling noise flips the result.
        assert abs(observed - expected) <= 0.05, (
            f"{option}: observed {observed:.4f}, expected {expected}"
        )
|
||||||
|
|
Loading…
Reference in New Issue
Block a user