diff --git a/random_sets/datasources.py b/random_sets/datasources.py index 18b1c30..442e82f 100644 --- a/random_sets/datasources.py +++ b/random_sets/datasources.py @@ -122,9 +122,9 @@ class DataSource: """ # If there is no data for the specified option, stop now. - flattened = [option] + flattened = [] if not self.data[option]: - return random.choice(flattened) if rand else flattened + raise ValueError(f"There is no data for '{option}' in your data source.") if hasattr(self.data[option], 'keys'): # if the option is a dict, we assume the values are lists; we select a random item diff --git a/random_sets/sets.py b/random_sets/sets.py index d17bcfa..5c58da0 100644 --- a/random_sets/sets.py +++ b/random_sets/sets.py @@ -22,7 +22,14 @@ class WeightedSet: self.members, self.weights = list(zip(*weighted_members)) def random(self) -> str: - return random.choices(self.members, self.weights)[0] + nonzero_members = [] + nonzero_weights = [] + for i in range(self.weights): + if float(self.weights[i]) == 0.0: + continue + nozero_members.append(self.members[i]) + nozero_weights.append(self.weights[i]) + return random.choices(nonzero_members, nonzero_weights)[0] def __add__(self, obj): ws = WeightedSet() diff --git a/test/test_datasources.py b/test/test_datasources.py index a17f5a1..a6ec357 100644 --- a/test/test_datasources.py +++ b/test/test_datasources.py @@ -1,7 +1,6 @@ from io import StringIO from random_sets import datasources -from pprint import pprint as print fixture_metadata = """ metadata: @@ -47,3 +46,31 @@ def test_datasource_random_values(): # each value has an "Option", a "choice", and a "description" assert len(randvals[0]) == 3 + + +def test_zero_frequency(): + fixture = StringIO(fixture_metadata + fixture_source) + ds = datasources.DataSource(fixture) + ds.set_frequency('nondefault') + for val in ds.random_values(count=100): + assert 'Option 1' not in val + + +def test_distribution_accuracy_to_one_decimal_place(): + fixture = StringIO(fixture_metadata + fixture_source) + ds = datasources.DataSource(fixture) + ds.set_frequency('nondefault') + counts = { + 'Option 1': 0, + 'Option 2': 0, + 'Option 3': 0, + } + + population = 10000 + + for val in ds.random_values(count=population): + counts[val[0]] += 1 + + for (option, count) in counts.items(): + observed = count/population + assert round(observed, 1) == round(ds.frequencies[option], 1)