From 113e9940cc2ec965437d734a93a4ecf0401d4f85 Mon Sep 17 00:00:00 2001 From: Albert Julius Liu Date: Wed, 11 Sep 2024 23:36:43 -0700 Subject: [PATCH] zero-quantity outcomes are pruned from Die/Deck construction --- src/icepool/creation_args.py | 10 +++++++++- src/icepool/function.py | 5 +++++ src/icepool/population/deck.py | 3 +++ src/icepool/population/die.py | 3 ++- tests/comparator_test.py | 2 +- tests/evaluator_test.py | 2 +- tests/from_cumulative_test.py | 4 ++-- 7 files changed, 23 insertions(+), 6 deletions(-) diff --git a/src/icepool/creation_args.py b/src/icepool/creation_args.py index a3ecd4e8..49d127c2 100644 --- a/src/icepool/creation_args.py +++ b/src/icepool/creation_args.py @@ -104,8 +104,12 @@ def merge_weights_lcm(subdatas: Sequence[Mapping[T, int]], data: MutableMapping[Any, int] = defaultdict(int) for subdata, subdata_denominator, w in zip(subdatas, subdata_denominators, weights): - factor = denominator_lcm * w // subdata_denominator if subdata_denominator else 0 + if subdata_denominator == 0 or w == 0: + continue + factor = denominator_lcm * w // subdata_denominator for outcome, weight in subdata.items(): + if weight == 0: + continue data[outcome] += weight * factor return data @@ -122,7 +126,11 @@ def merge_duplicates(subdatas: Sequence[Mapping[T, int]], data: MutableMapping[Any, int] = defaultdict(int) for subdata, subdup in zip(subdatas, duplicates): + if subdup == 0: + continue for outcome, dup in subdata.items(): + if dup == 0: + continue data[outcome] += dup * subdup return data diff --git a/src/icepool/function.py b/src/icepool/function.py index 8e208ac4..b07df3d3 100644 --- a/src/icepool/function.py +++ b/src/icepool/function.py @@ -198,11 +198,16 @@ def from_rv(rv, outcomes: Sequence[float], denominator: int, def from_rv(rv, outcomes: Sequence[int] | Sequence[float], denominator: int, **kwargs) -> 'icepool.Die[int] | icepool.Die[float]': """Constructs a `Die` from a rv object (as `scipy.stats`). + + This is done using the CDF. + Args: rv: A rv object (as `scipy.stats`). outcomes: An iterable of `int`s or `float`s that will be the outcomes of the resulting `Die`. If the distribution is discrete, outcomes must be `int`s. + Some outcomes may be omitted if their probability is too small + compared to the denominator. denominator: The denominator of the resulting `Die` will be set to this. **kwargs: These will be forwarded to `rv.cdf()`. """ diff --git a/src/icepool/population/deck.py b/src/icepool/population/deck.py index 3b80a87b..805fef9e 100644 --- a/src/icepool/population/deck.py +++ b/src/icepool/population/deck.py @@ -34,6 +34,9 @@ def __new__(cls, times: Sequence[int] | int = 1) -> 'Deck[T_co]': """Constructor for a `Deck`. + All quantities must be non-negative. Outcomes with zero quantity will be + omitted. + Args: outcomes: The cards of the `Deck`. This can be one of the following: * A `Sequence` of outcomes. Duplicates will contribute diff --git a/src/icepool/population/die.py b/src/icepool/population/die.py index fe3fdbc2..54cbfb6b 100644 --- a/src/icepool/population/die.py +++ b/src/icepool/population/die.py @@ -87,7 +87,8 @@ def __new__( * Use a dict: `Die({1:1, 2:1, 3:1, 4:1, 5:1, 6:1})` * Give the faces as a sequence: `Die([1, 2, 3, 4, 5, 6])` - All quantities must be non-negative, though they can be zero. + All quantities must be non-negative. Outcomes with zero quantity will be + omitted. Several methods and functions foward **kwargs to this constructor. However, these only affect the construction of the returned or yielded diff --git a/tests/comparator_test.py b/tests/comparator_test.py index e1b01a45..f9e44bb9 100644 --- a/tests/comparator_test.py +++ b/tests/comparator_test.py @@ -106,7 +106,7 @@ def test_cmp_len(): assert len(icepool.d6.cmp(0)) == 1 assert len(icepool.d6.cmp(7)) == 1 assert len(icepool.Die([1]).cmp(1)) == 1 - assert len(icepool.Die({-1: 0, 0: 0, 1: 0}).cmp(0)) == 3 + assert len(icepool.Die({-1: 0, 0: 0, 1: 0}).cmp(0)) == 0 def test_quantity_le(): diff --git a/tests/evaluator_test.py b/tests/evaluator_test.py index 0a926e33..bd51ec8f 100644 --- a/tests/evaluator_test.py +++ b/tests/evaluator_test.py @@ -51,7 +51,7 @@ def test_sum_descending_keep_highest(): def test_zero_weight_outcomes(): result = icepool.Die(range(5), times=[0, 1, 0, 1, 0]).highest(3, 2) - assert len(result) == 9 + assert len(result) == 3 def sum_dice_func(state, outcome, count): diff --git a/tests/from_cumulative_test.py b/tests/from_cumulative_test.py index f8839262..02f96bb9 100644 --- a/tests/from_cumulative_test.py +++ b/tests/from_cumulative_test.py @@ -27,5 +27,5 @@ def test_from_rv_norm(): 1000000, loc=die.mean(), scale=die.standard_deviation()) - assert die.probabilities('<=') == pytest.approx( - norm_die.probabilities('<='), abs=1e-3) + assert [die.probability('<=', x) for x in range(600)] == pytest.approx( + [norm_die.probability('<=', x) for x in range(600)], abs=1e-3)