From ee7c3056ba3b15e866e7ceb095393b1caf9a522b Mon Sep 17 00:00:00 2001 From: Albert Julius Liu Date: Wed, 25 Dec 2024 14:34:22 -0800 Subject: [PATCH] more progress on expression rework #203 --- src/icepool/__init__.py | 6 +- src/icepool/evaluator/expression.py | 5 +- src/icepool/evaluator/multiset_evaluator.py | 50 ++-- src/icepool/expression/__init__.py | 2 +- src/icepool/generator/alignment.py | 5 +- src/icepool/generator/compound_keep.py | 18 +- src/icepool/generator/deal.py | 13 +- src/icepool/generator/keep.py | 8 +- src/icepool/generator/mixture.py | 20 +- src/icepool/generator/multi_deal.py | 16 +- src/icepool/generator/multiset_generator.py | 2 +- src/icepool/generator/pool.py | 19 +- src/icepool/multiset_expression.py | 259 +++++++++++++++--- .../{expression => }/multiset_function.py | 2 +- src/icepool/multiset_variable.py | 19 +- src/icepool/transform/__init__.py | 5 + src/icepool/transform/binary_operator.py | 114 ++++++++ src/icepool/transform/multiset_transform.py | 36 +-- src/icepool/typing.py | 8 +- tests/evaluator_test.py | 2 +- tests/import_all_test.py | 2 +- tests/neon_city_overdrive_test.py | 2 +- tests/pop_order_test.py | 40 +-- tests/subset_target_test.py | 2 +- 24 files changed, 485 insertions(+), 170 deletions(-) rename src/icepool/{expression => }/multiset_function.py (98%) create mode 100644 src/icepool/transform/binary_operator.py diff --git a/src/icepool/__init__.py b/src/icepool/__init__.py index b0f5d901..805cbd3b 100644 --- a/src/icepool/__init__.py +++ b/src/icepool/__init__.py @@ -127,7 +127,7 @@ from icepool.multiset_expression import MultisetExpression, implicit_convert_to_expression -from icepool.generator.multiset_generator import MultisetGenerator, InitialMultisetGenerator, NextMultisetGenerator +from icepool.generator.multiset_generator import MultisetGenerator from icepool.generator.alignment import Alignment from icepool.evaluator.multiset_evaluator import MultisetEvaluator @@ -135,14 +135,14 @@ from icepool.generator.deal import Deal from icepool.generator.multi_deal import MultiDeal -from icepool.expression.multiset_function import multiset_function +from icepool.multiset_function import multiset_function from icepool.multiset_variable import MultisetVariable from icepool.population.format import format_probability_inverse -import icepool.expression as expression import icepool.generator as generator import icepool.evaluator as evaluator +import icepool.transform as transform import icepool.typing as typing diff --git a/src/icepool/evaluator/expression.py b/src/icepool/evaluator/expression.py index 8e0b04d9..3e91b723 100644 --- a/src/icepool/evaluator/expression.py +++ b/src/icepool/evaluator/expression.py @@ -3,7 +3,7 @@ from functools import cached_property import itertools import icepool -import icepool.expression +import icepool.multiset_expression from icepool.evaluator.multiset_evaluator import MultisetEvaluator from icepool.typing import Order, Outcome, T, U_co @@ -15,7 +15,8 @@ class ExpressionEvaluator(MultisetEvaluator[T, U_co]): """Assigns an expression to be evaluated first to each input of an evaluator.""" def __init__(self, - *expressions: 'icepool.expression.MultisetExpression[T]', + *expressions: + 'icepool.multiset_expression.MultisetExpression[T]', evaluator: MultisetEvaluator[T, U_co], truth_value: bool | None = None) -> None: self._evaluator = evaluator diff --git a/src/icepool/evaluator/multiset_evaluator.py b/src/icepool/evaluator/multiset_evaluator.py index 62de1354..69a78d51 100644 --- a/src/icepool/evaluator/multiset_evaluator.py +++ b/src/icepool/evaluator/multiset_evaluator.py @@ -198,7 +198,7 @@ def validate_arity(self, arity: int) -> None: @cached_property def _cache( self - ) -> 'MutableMapping[tuple[Order, Alignment, tuple[MultisetGenerator, ...], Hashable], Mapping[Any, int]]': + ) -> 'MutableMapping[tuple[Order, Alignment, tuple[MultisetExpression, ...], Hashable], Mapping[Any, int]]': """Cached results. The key is `(order, extra_outcomes, generators, state)`. @@ -257,34 +257,26 @@ def evaluate( from icepool.evaluator.expression import ExpressionEvaluator return ExpressionEvaluator(*expressions, evaluator=self) - if not all( - isinstance(expression, icepool.MultisetGenerator) - for expression in expressions): - from icepool.evaluator.expression import ExpressionEvaluator - return ExpressionEvaluator(*expressions, evaluator=self).evaluate() - - generators = cast(tuple[icepool.MultisetGenerator, ...], expressions) - self.validate_arity( - sum(generator.output_arity() for generator in generators)) + sum(expression.output_arity() for expression in expressions)) - generators = self.prefix_generators() + generators + expressions = self.prefix_generators() + expressions - if not all(generator._is_resolvable() for generator in generators): + if not all(expression._is_resolvable() for expression in expressions): return icepool.Die([]) - algorithm, order = self._select_algorithm(*generators) + algorithm, order = self._select_algorithm(*expressions) - outcomes = icepool.sorted_union(*(generator.outcomes() - for generator in generators)) + outcomes = icepool.sorted_union(*(expression.outcomes() + for expression in expressions)) extra_outcomes = Alignment(self.extra_outcomes(outcomes)) dist: MutableMapping[Any, int] = defaultdict(int) - iterators = MultisetEvaluator._initialize_generators(generators) + iterators = MultisetEvaluator._initialize_expressions(expressions) for p in itertools.product(*iterators): - sub_generators, sub_weights = zip(*p) + sub_expressions, sub_weights = zip(*p) prod_weight = math.prod(sub_weights) - sub_result = algorithm(order, extra_outcomes, sub_generators) + sub_result = algorithm(order, extra_outcomes, sub_expressions) for sub_state, sub_weight in sub_result.items(): dist[sub_state] += sub_weight * prod_weight @@ -306,9 +298,9 @@ def evaluate( __call__ = evaluate def _select_algorithm( - self, *generators: 'icepool.MultisetGenerator[T, Any]' + self, *generators: 'icepool.MultisetExpression[T]' ) -> tuple[ - 'Callable[[Order, Alignment[T], tuple[icepool.MultisetGenerator[T, Any], ...]], Mapping[Any, int]]', + 'Callable[[Order, Alignment[T], tuple[icepool.MultisetExpression[T], ...]], Mapping[Any, int]]', Order]: """Selects an algorithm and iteration order. @@ -324,7 +316,8 @@ def _select_algorithm( return self._eval_internal, eval_order preferred_pop_order, pop_order_reason = merge_pop_orders( - *(generator._preferred_pop_order() for generator in generators)) + *(generator._local_preferred_pop_order() + for generator in generators)) if preferred_pop_order is None: preferred_pop_order = Order.Any @@ -346,7 +339,7 @@ def _select_algorithm( def _eval_internal( self, order: Order, extra_outcomes: 'Alignment[T]', - generators: 'tuple[icepool.MultisetGenerator[T, Any], ...]' + generators: 'tuple[icepool.MultisetExpression[T], ...]' ) -> Mapping[Any, int]: """Internal algorithm for iterating in the more-preferred order. @@ -438,16 +431,17 @@ def _eval_internal_forward( return result @staticmethod - def _initialize_generators( - generators: 'tuple[icepool.MultisetGenerator[T, Any], ...]' - ) -> 'tuple[icepool.InitialMultisetGenerator, ...]': - return tuple(generator._generate_initial() for generator in generators) + def _initialize_expressions( + expressions: 'tuple[icepool.MultisetExpression[T], ...]' + ) -> 'tuple[icepool.InitialMultisetGeneration, ...]': + return tuple(expression._generate_initial() + for expression in expressions) @staticmethod def _pop_generators( order: Order, extra_outcomes: 'Alignment[T]', - generators: 'tuple[icepool.MultisetGenerator[T, Any], ...]' - ) -> 'tuple[T, Alignment[T], tuple[icepool.NextMultisetGenerator, ...]]': + generators: 'tuple[icepool.MultisetExpression[T], ...]' + ) -> 'tuple[T, Alignment[T], tuple[icepool.PopMultisetGeneration, ...]]': """Pops a single outcome from the generators. Args: diff --git a/src/icepool/expression/__init__.py b/src/icepool/expression/__init__.py index 102eba67..fa04301f 100644 --- a/src/icepool/expression/__init__.py +++ b/src/icepool/expression/__init__.py @@ -13,7 +13,7 @@ from icepool.expression.keep import KeepExpression from icepool.expression.match import SortMatchExpression, MaximumMatchExpression -from icepool.expression.multiset_function import multiset_function +from icepool.multiset_function import multiset_function __all__ = [ 'multiset_function', 'MultisetExpression', 'MultisetVariable', diff --git a/src/icepool/generator/alignment.py b/src/icepool/generator/alignment.py index 067b1859..bb67c789 100644 --- a/src/icepool/generator/alignment.py +++ b/src/icepool/generator/alignment.py @@ -50,12 +50,13 @@ def _generate_max(self, max_outcome) -> AlignmentGenerator: else: yield Alignment(self.outcomes()[:-1]), (), 1 - def _preferred_pop_order(self) -> tuple[Order | None, PopOrderReason]: + def _local_preferred_pop_order( + self) -> tuple[Order | None, PopOrderReason]: return Order.Any, PopOrderReason.NoPreference def denominator(self) -> int: return 1 @cached_property - def _hash_key(self) -> Hashable: + def _local_hash_key(self) -> Hashable: return Alignment, self._outcomes diff --git a/src/icepool/generator/compound_keep.py b/src/icepool/generator/compound_keep.py index 2f972ff1..692dba1e 100644 --- a/src/icepool/generator/compound_keep.py +++ b/src/icepool/generator/compound_keep.py @@ -1,8 +1,9 @@ __docformat__ = 'google' import icepool +from icepool.multiset_expression import InitialMultisetGeneration, PopMultisetGeneration from icepool.generator.keep import KeepGenerator, pop_max_from_keep_tuple, pop_min_from_keep_tuple -from icepool.generator.multiset_generator import InitialMultisetGenerator, NextMultisetGenerator, MultisetGenerator +from icepool.generator.multiset_generator import MultisetGenerator from icepool.generator.pop_order import PopOrderReason, merge_pop_orders import itertools @@ -22,7 +23,7 @@ def __init__(self, inners: Sequence[KeepGenerator[T]], def outcomes(self) -> Sequence[T]: return icepool.sorted_union(*(inner.outcomes() - for inner in self._inner_generators)) + for inner in self._inner_generators)) def output_arity(self) -> int: return 1 @@ -30,10 +31,10 @@ def output_arity(self) -> int: def _is_resolvable(self) -> bool: return all(inner._is_resolvable() for inner in self._inner_generators) - def _generate_initial(self) -> InitialMultisetGenerator: + def _generate_initial(self) -> InitialMultisetGeneration: yield self, 1 - def _generate_min(self, min_outcome) -> NextMultisetGenerator: + def _generate_min(self, min_outcome) -> PopMultisetGeneration: for t in itertools.product(*(inner._generate_min(min_outcome) for inner in self._inner_generators)): generators, counts, weights = zip(*t) @@ -44,7 +45,7 @@ def _generate_min(self, min_outcome) -> NextMultisetGenerator: yield CompoundKeepGenerator( generators, popped_keep_tuple), (result_count, ), total_weight - def _generate_max(self, max_outcome) -> NextMultisetGenerator: + def _generate_max(self, max_outcome) -> PopMultisetGeneration: for t in itertools.product(*(inner._generate_max(max_outcome) for inner in self._inner_generators)): generators, counts, weights = zip(*t) @@ -55,8 +56,9 @@ def _generate_max(self, max_outcome) -> NextMultisetGenerator: yield CompoundKeepGenerator( generators, popped_keep_tuple), (result_count, ), total_weight - def _preferred_pop_order(self) -> tuple[Order | None, PopOrderReason]: - return merge_pop_orders(*(inner._preferred_pop_order() + def _local_preferred_pop_order( + self) -> tuple[Order | None, PopOrderReason]: + return merge_pop_orders(*(inner._local_preferred_pop_order() for inner in self._inner_generators)) def denominator(self) -> int: @@ -68,7 +70,7 @@ def _set_keep_tuple(self, keep_tuple: tuple[int, return CompoundKeepGenerator(self._inner_generators, keep_tuple) @property - def _hash_key(self) -> Hashable: + def _local_hash_key(self) -> Hashable: return CompoundKeepGenerator, tuple( inner._hash_key for inner in self._inner_generators), self._keep_tuple diff --git a/src/icepool/generator/deal.py b/src/icepool/generator/deal.py index 9b9396c7..5f9b4fbc 100644 --- a/src/icepool/generator/deal.py +++ b/src/icepool/generator/deal.py @@ -3,7 +3,7 @@ import icepool from icepool.generator.keep import KeepGenerator, pop_max_from_keep_tuple, pop_min_from_keep_tuple from icepool.collection.counts import CountsKeysView -from icepool.generator.multiset_generator import InitialMultisetGenerator, NextMultisetGenerator +from icepool.multiset_expression import InitialMultisetGeneration, PopMultisetGeneration import icepool.generator.pop_order from icepool.generator.pop_order import PopOrderReason @@ -85,10 +85,10 @@ def _is_resolvable(self) -> bool: def denominator(self) -> int: return icepool.math.comb(self.deck().size(), self._hand_size) - def _generate_initial(self) -> InitialMultisetGenerator: + def _generate_initial(self) -> InitialMultisetGeneration: yield self, 1 - def _generate_min(self, min_outcome) -> NextMultisetGenerator: + def _generate_min(self, min_outcome) -> PopMultisetGeneration: if not self.outcomes() or min_outcome != self.min_outcome(): yield self, (0, ), 1 return @@ -115,7 +115,7 @@ def _generate_min(self, min_outcome) -> NextMultisetGenerator: popped_deal = Deal._new_raw(popped_deck, 0, ()) yield popped_deal, (sum(self.keep_tuple()), ), skip_weight - def _generate_max(self, max_outcome) -> NextMultisetGenerator: + def _generate_max(self, max_outcome) -> PopMultisetGeneration: if not self.outcomes() or max_outcome != self.max_outcome(): yield self, (0, ), 1 return @@ -142,7 +142,8 @@ def _generate_max(self, max_outcome) -> NextMultisetGenerator: popped_deal = Deal._new_raw(popped_deck, 0, ()) yield popped_deal, (sum(self.keep_tuple()), ), skip_weight - def _preferred_pop_order(self) -> tuple[Order | None, PopOrderReason]: + def _local_preferred_pop_order( + self) -> tuple[Order | None, PopOrderReason]: lo_skip, hi_skip = icepool.generator.pop_order.lo_hi_skip( self.keep_tuple()) if lo_skip > hi_skip: @@ -153,7 +154,7 @@ def _preferred_pop_order(self) -> tuple[Order | None, PopOrderReason]: return Order.Any, PopOrderReason.NoPreference @cached_property - def _hash_key(self) -> Hashable: + def _local_hash_key(self) -> Hashable: return Deal, self.deck(), self._hand_size, self._keep_tuple def __repr__(self) -> str: diff --git a/src/icepool/generator/keep.py b/src/icepool/generator/keep.py index e771ab72..ea0729b6 100644 --- a/src/icepool/generator/keep.py +++ b/src/icepool/generator/keep.py @@ -1,7 +1,8 @@ __docformat__ = 'google' import icepool -from icepool.generator.multiset_generator import InitialMultisetGenerator, NextMultisetGenerator, MultisetGenerator +from icepool.multiset_expression import InitialMultisetGeneration, PopMultisetGeneration +from icepool.generator.multiset_generator import MultisetGenerator import operator from collections import defaultdict @@ -10,10 +11,11 @@ from abc import ABC, abstractmethod from types import EllipsisType from typing import Hashable, Literal, Mapping, MutableMapping, Sequence, cast, overload, TYPE_CHECKING +import icepool.multiset_expression from icepool.typing import ImplicitConversionError, Outcome, T if TYPE_CHECKING: - from icepool.expression import MultisetExpression + from icepool.multiset_expression import MultisetExpression class KeepGenerator(MultisetGenerator[T, tuple[int]]): @@ -235,7 +237,7 @@ def additive_union( *args: 'MultisetExpression[T] | Mapping[T, int] | Sequence[T]' ) -> 'MultisetExpression[T]': args = tuple( - icepool.expression.implicit_convert_to_expression(arg) + icepool.multiset_expression.implicit_convert_to_expression(arg) for arg in args) if all(isinstance(arg, KeepGenerator) for arg in args): generators = cast(tuple[KeepGenerator, ...], args) diff --git a/src/icepool/generator/mixture.py b/src/icepool/generator/mixture.py index 711f89c2..5e024684 100644 --- a/src/icepool/generator/mixture.py +++ b/src/icepool/generator/mixture.py @@ -2,7 +2,8 @@ import icepool -from icepool.generator.multiset_generator import InitialMultisetGenerator, NextMultisetGenerator, MultisetGenerator +from icepool.multiset_expression import InitialMultisetGeneration, PopMultisetGeneration +from icepool.generator.multiset_generator import MultisetGenerator from icepool.generator.pop_order import PopOrderReason, merge_pop_orders import math @@ -15,7 +16,7 @@ from typing import TYPE_CHECKING, Callable, Hashable, Literal, Mapping, MutableMapping, Sequence, overload if TYPE_CHECKING: - from icepool.expression import MultisetExpression + from icepool.multiset_expression import MultisetExpression class MixtureGenerator(MultisetGenerator[T, tuple[int]]): @@ -63,7 +64,7 @@ def __init__(self, def outcomes(self) -> Sequence[T]: return icepool.sorted_union(*(inner.outcomes() - for inner in self._inner_generators)) + for inner in self._inner_generators)) def output_arity(self) -> int: result = None @@ -79,21 +80,22 @@ def output_arity(self) -> int: def _is_resolvable(self) -> bool: return all(inner._is_resolvable() for inner in self._inner_generators) - def _generate_initial(self) -> InitialMultisetGenerator: + def _generate_initial(self) -> InitialMultisetGeneration: yield from self._inner_generators.items() - def _generate_min(self, min_outcome) -> NextMultisetGenerator: + def _generate_min(self, min_outcome) -> PopMultisetGeneration: raise RuntimeError( 'MixtureMultisetGenerator should have decayed to another generator type by this point.' ) - def _generate_max(self, max_outcome) -> NextMultisetGenerator: + def _generate_max(self, max_outcome) -> PopMultisetGeneration: raise RuntimeError( 'MixtureMultisetGenerator should have decayed to another generator type by this point.' ) - def _preferred_pop_order(self) -> tuple[Order | None, PopOrderReason]: - return merge_pop_orders(*(inner._preferred_pop_order() + def _local_preferred_pop_order( + self) -> tuple[Order | None, PopOrderReason]: + return merge_pop_orders(*(inner._local_preferred_pop_order() for inner in self._inner_generators)) @cached_property @@ -107,7 +109,7 @@ def denominator(self) -> int: return self._denominator @property - def _hash_key(self) -> Hashable: + def _local_hash_key(self) -> Hashable: # This is not intended to be cached directly, so we are a little loose here. return MixtureGenerator, tuple(self._inner_generators.items()) diff --git a/src/icepool/generator/multi_deal.py b/src/icepool/generator/multi_deal.py index 89a3eb3d..e5565a01 100644 --- a/src/icepool/generator/multi_deal.py +++ b/src/icepool/generator/multi_deal.py @@ -5,7 +5,8 @@ from typing import Any, Hashable, cast import icepool from icepool.collection.counts import CountsKeysView -from icepool.generator.multiset_generator import InitialMultisetGenerator, NextMultisetGenerator, MultisetGenerator +from icepool.multiset_expression import InitialMultisetGeneration, PopMultisetGeneration +from icepool.generator.multiset_generator import MultisetGenerator from icepool.math import iter_hypergeom from icepool.generator.pop_order import PopOrderReason @@ -92,11 +93,11 @@ def _denomiator(self) -> int: def denominator(self) -> int: return self._denomiator - def _generate_initial(self) -> InitialMultisetGenerator: + def _generate_initial(self) -> InitialMultisetGeneration: yield self, 1 def _generate_common(self, popped_deck: 'icepool.Deck[T]', - deck_count: int) -> NextMultisetGenerator: + deck_count: int) -> PopMultisetGeneration: """Common implementation for _generate_min and _generate_max.""" min_count = max( 0, deck_count + self.total_cards_dealt() - self.deck().size()) @@ -112,7 +113,7 @@ def _generate_common(self, popped_deck: 'icepool.Deck[T]', weight = weight_total * weight_split yield popped_deal, counts, weight - def _generate_min(self, min_outcome) -> NextMultisetGenerator: + def _generate_min(self, min_outcome) -> PopMultisetGeneration: if not self.outcomes() or min_outcome != self.min_outcome(): yield self, (0, ), 1 return @@ -121,7 +122,7 @@ def _generate_min(self, min_outcome) -> NextMultisetGenerator: yield from self._generate_common(popped_deck, deck_count) - def _generate_max(self, max_outcome) -> NextMultisetGenerator: + def _generate_max(self, max_outcome) -> PopMultisetGeneration: if not self.outcomes() or max_outcome != self.max_outcome(): yield self, (0, ), 1 return @@ -130,11 +131,12 @@ def _generate_max(self, max_outcome) -> NextMultisetGenerator: yield from self._generate_common(popped_deck, deck_count) - def _preferred_pop_order(self) -> tuple[Order | None, PopOrderReason]: + def _local_preferred_pop_order( + self) -> tuple[Order | None, PopOrderReason]: return Order.Any, PopOrderReason.NoPreference @cached_property - def _hash_key(self) -> Hashable: + def _local_hash_key(self) -> Hashable: return MultiDeal, self.deck(), self.hand_sizes() def __repr__(self) -> str: diff --git a/src/icepool/generator/multiset_generator.py b/src/icepool/generator/multiset_generator.py index 748c83e2..d306154b 100644 --- a/src/icepool/generator/multiset_generator.py +++ b/src/icepool/generator/multiset_generator.py @@ -52,7 +52,7 @@ def _can_keep(self) -> bool: def _free_arity(self) -> int: return 0 - def order(self) -> Order: + def local_order(self) -> Order: return Order.Any # Overridden to switch bound generators with variables. diff --git a/src/icepool/generator/pool.py b/src/icepool/generator/pool.py index 6b4b8d32..cf5640cc 100644 --- a/src/icepool/generator/pool.py +++ b/src/icepool/generator/pool.py @@ -1,11 +1,11 @@ __docformat__ = 'google' import icepool -import icepool.expression +import icepool.multiset_expression import icepool.math import icepool.creation_args +from icepool.multiset_expression import InitialMultisetGeneration, PopMultisetGeneration from icepool.generator.keep import KeepGenerator, pop_max_from_keep_tuple, pop_min_from_keep_tuple -from icepool.generator.multiset_generator import InitialMultisetGenerator, NextMultisetGenerator import icepool.generator.pop_order from icepool.generator.pop_order import PopOrderReason @@ -19,7 +19,7 @@ from typing import TYPE_CHECKING, Any, Collection, Iterator, Mapping, MutableMapping, Sequence, cast if TYPE_CHECKING: - from icepool.expression import MultisetExpression + from icepool.multiset_expression import MultisetExpression class Pool(KeepGenerator[T]): @@ -170,7 +170,8 @@ def outcomes(self) -> Sequence[T]: def output_arity(self) -> int: return 1 - def _preferred_pop_order(self) -> tuple[Order | None, PopOrderReason]: + def _local_preferred_pop_order( + self) -> tuple[Order | None, PopOrderReason]: can_truncate_min, can_truncate_max = icepool.generator.pop_order.can_truncate( self.unique_dice()) if can_truncate_min and not can_truncate_max: @@ -195,10 +196,10 @@ def max_outcome(self) -> T: """The max outcome among all dice in this pool.""" return self._outcomes[-1] - def _generate_initial(self) -> InitialMultisetGenerator: + def _generate_initial(self) -> InitialMultisetGeneration: yield self, 1 - def _generate_min(self, min_outcome) -> NextMultisetGenerator: + def _generate_min(self, min_outcome) -> PopMultisetGeneration: """Pops the given outcome from this pool, if it is the min outcome. Yields: @@ -244,7 +245,7 @@ def _generate_min(self, min_outcome) -> NextMultisetGenerator: popped_pool = Pool._new_raw((), self._outcomes[1:], ()) yield popped_pool, (sum(self.keep_tuple()), ), skip_weight - def _generate_max(self, max_outcome) -> NextMultisetGenerator: + def _generate_max(self, max_outcome) -> PopMultisetGeneration: """Pops the given outcome from this pool, if it is the max outcome. Yields: @@ -298,7 +299,7 @@ def additive_union( *args: 'MultisetExpression[T] | Mapping[T, int] | Sequence[T]' ) -> 'MultisetExpression[T]': args = tuple( - icepool.expression.implicit_convert_to_expression(arg) + icepool.multiset_expression.implicit_convert_to_expression(arg) for arg in args) if all(isinstance(arg, Pool) for arg in args): pools = cast(tuple[Pool[T], ...], args) @@ -325,7 +326,7 @@ def __str__(self) -> str: for die, count in self._dice)) @cached_property - def _hash_key(self) -> tuple: + def _local_hash_key(self) -> tuple: return Pool, self._dice, self._outcomes, self._keep_tuple diff --git a/src/icepool/multiset_expression.py b/src/icepool/multiset_expression.py index 706a2206..5d3ac4d2 100644 --- a/src/icepool/multiset_expression.py +++ b/src/icepool/multiset_expression.py @@ -7,12 +7,17 @@ import icepool from icepool.collection.counts import Counts -from icepool.generator.pop_order import PopOrderReason +from icepool.generator.pop_order import PopOrderReason, merge_pop_orders from icepool.typing import T, U, ImplicitConversionError, Order, Outcome, T from typing import Any, Callable, Collection, Generic, Hashable, Iterable, Iterator, Literal, Mapping, Self, Sequence, Type, TypeAlias, TypeVar, cast, overload from abc import ABC, abstractmethod +InitialMultisetGeneration: TypeAlias = Iterator[tuple['MultisetExpression', + int]] +PopMultisetGeneration: TypeAlias = Iterator[tuple['MultisetExpression', + Sequence, int]] + def implicit_convert_to_expression( arg: 'MultisetExpression[T] | Mapping[T, int] | Sequence[T]' @@ -108,7 +113,7 @@ def _is_resolvable(self) -> bool: """ @abstractmethod - def _generate_initial(self) -> Iterator[tuple['MultisetExpression', int]]: + def _generate_initial(self) -> InitialMultisetGeneration: """Initialize the expression before any outcomes are emitted. Yields: @@ -119,9 +124,7 @@ def _generate_initial(self) -> Iterator[tuple['MultisetExpression', int]]: """ @abstractmethod - def _generate_min( - self, min_outcome: T - ) -> Iterator[tuple['MultisetExpression', Sequence, int]]: + def _generate_min(self, min_outcome: T) -> PopMultisetGeneration: """Pops the min outcome from this expression if it matches the argument. Yields: @@ -141,9 +144,7 @@ def _generate_min( """ @abstractmethod - def _generate_max( - self, max_outcome: T - ) -> Iterator[tuple['MultisetExpression', Sequence, int]]: + def _generate_max(self, max_outcome: T) -> PopMultisetGeneration: """Pops the max outcome from this expression if it matches the argument. Yields: @@ -163,8 +164,9 @@ def _generate_max( """ @abstractmethod - def _preferred_pop_order(self) -> tuple[Order | None, PopOrderReason]: - """Returns the preferred pop order of the expression, along with the priority of that pop order. + def _local_preferred_pop_order( + self) -> tuple[Order | None, PopOrderReason]: + """Returns the preferred pop order of this expression node, along with the priority of that pop order. Greater priorities strictly outrank lower priorities. An order of `None` represents conflicting orders and can occur in the @@ -172,15 +174,8 @@ def _preferred_pop_order(self) -> tuple[Order | None, PopOrderReason]: """ @abstractmethod - def order(self) -> Order: - """Any ordering that is required by this expression. - - This should check any previous expressions for their order, and - raise a ValueError for any incompatibilities. - - Returns: - The required order. - """ + def local_order(self) -> Order: + """Any ordering that is required by this expression node.""" @abstractmethod def _free_arity(self) -> int: @@ -254,6 +249,16 @@ def _hash(self) -> int: def __hash__(self) -> int: return self._hash + def _iter_nodes(self) -> 'Iterator[MultisetExpression]': + """Iterates over the nodes in this expression in post-order (leaves first).""" + for child in self._children: + yield from child._iter_nodes() + yield self + + def _preferred_pop_order(self) -> tuple[Order | None, PopOrderReason]: + return merge_pop_orders(*(node._local_preferred_pop_order() + for node in self._iter_nodes())) + # Sampling. def sample(self) -> tuple[tuple, ...]: @@ -268,7 +273,8 @@ def sample(self) -> tuple[tuple, ...]: if not self.outcomes(): raise ValueError('Cannot sample from an empty set of outcomes.') - preferred_pop_order, pop_order_reason = self._preferred_pop_order() + preferred_pop_order, pop_order_reason = self._local_preferred_pop_order( + ) if preferred_pop_order is not None and preferred_pop_order > 0: outcome = self.min_outcome() @@ -291,6 +297,199 @@ def sample(self) -> tuple[tuple, ...]: else: return head + # Binary operators. + + def __add__(self, + other: 'MultisetExpression[T] | Mapping[T, int] | Sequence[T]', + /) -> 'MultisetExpression[T]': + try: + return MultisetExpression.additive_union(self, other) + except ImplicitConversionError: + return NotImplemented + + def __radd__( + self, + other: 'MultisetExpression[T] | Mapping[T, int] | Sequence[T]', + /) -> 'MultisetExpression[T]': + try: + return MultisetExpression.additive_union(other, self) + except ImplicitConversionError: + return NotImplemented + + def additive_union( + *args: 'MultisetExpression[T] | Mapping[T, int] | Sequence[T]' + ) -> 'MultisetExpression[T]': + """The combined elements from all of the multisets. + + Same as `a + b + c + ...`. + + Any resulting counts that would be negative are set to zero. + + Example: + ```python + [1, 2, 2, 3] + [1, 2, 4] -> [1, 1, 2, 2, 2, 3, 4] + ``` + """ + expressions = tuple( + implicit_convert_to_expression(arg) for arg in args) + return icepool.transform.MultisetAdditiveUnion(*expressions) + + def __sub__(self, + other: 'MultisetExpression[T] | Mapping[T, int] | Sequence[T]', + /) -> 'MultisetExpression[T]': + try: + return MultisetExpression.difference(self, other) + except ImplicitConversionError: + return NotImplemented + + def __rsub__( + self, + other: 'MultisetExpression[T] | Mapping[T, int] | Sequence[T]', + /) -> 'MultisetExpression[T]': + try: + return MultisetExpression.difference(other, self) + except ImplicitConversionError: + return NotImplemented + + def difference( + *args: 'MultisetExpression[T] | Mapping[T, int] | Sequence[T]' + ) -> 'MultisetExpression[T]': + """The elements from the left multiset that are not in any of the others. + + Same as `a - b - c - ...`. + + Any resulting counts that would be negative are set to zero. + + Example: + ```python + [1, 2, 2, 3] - [1, 2, 4] -> [2, 3] + ``` + + If no arguments are given, the result will be an empty multiset, i.e. + all zero counts. + + Note that, as a multiset operation, this will only cancel elements 1:1. + If you want to drop all elements in a set of outcomes regardless of + count, either use `drop_outcomes()` instead, or use a large number of + counts on the right side. + """ + expressions = tuple( + implicit_convert_to_expression(arg) for arg in args) + return icepool.transform.MultisetDifference(*expressions) + + def __and__(self, + other: 'MultisetExpression[T] | Mapping[T, int] | Sequence[T]', + /) -> 'MultisetExpression[T]': + try: + return MultisetExpression.intersection(self, other) + except ImplicitConversionError: + return NotImplemented + + def __rand__( + self, + other: 'MultisetExpression[T] | Mapping[T, int] | Sequence[T]', + /) -> 'MultisetExpression[T]': + try: + return MultisetExpression.intersection(other, self) + except ImplicitConversionError: + return NotImplemented + + def intersection( + *args: 'MultisetExpression[T] | Mapping[T, int] | Sequence[T]' + ) -> 'MultisetExpression[T]': + """The elements that all the multisets have in common. + + Same as `a & b & c & ...`. + + Any resulting counts that would be negative are set to zero. + + Example: + ```python + [1, 2, 2, 3] & [1, 2, 4] -> [1, 2] + ``` + + Note that, as a multiset operation, this will only intersect elements + 1:1. + If you want to keep all elements in a set of outcomes regardless of + count, either use `keep_outcomes()` instead, or use a large number of + counts on the right side. + """ + expressions = tuple( + implicit_convert_to_expression(arg) for arg in args) + return icepool.transform.MultisetIntersection(*expressions) + + def __or__(self, + other: 'MultisetExpression[T] | Mapping[T, int] | Sequence[T]', + /) -> 'MultisetExpression[T]': + try: + return MultisetExpression.union(self, other) + except ImplicitConversionError: + return NotImplemented + + def __ror__(self, + other: 'MultisetExpression[T] | Mapping[T, int] | Sequence[T]', + /) -> 'MultisetExpression[T]': + try: + return MultisetExpression.union(other, self) + except ImplicitConversionError: + return NotImplemented + + def union( + *args: 'MultisetExpression[T] | Mapping[T, int] | Sequence[T]' + ) -> 'MultisetExpression[T]': + """The most of each outcome that appear in any of the multisets. + + Same as `a | b | c | ...`. + + Any resulting counts that would be negative are set to zero. + + Example: + ```python + [1, 2, 2, 3] | [1, 2, 4] -> [1, 2, 2, 3, 4] + ``` + """ + expressions = tuple( + implicit_convert_to_expression(arg) for arg in args) + return icepool.transform.MultisetUnion(*expressions) + + def __xor__(self, + other: 'MultisetExpression[T] | Mapping[T, int] | Sequence[T]', + /) -> 'MultisetExpression[T]': + try: + return MultisetExpression.symmetric_difference(self, other) + except ImplicitConversionError: + return NotImplemented + + def __rxor__( + self, + other: 'MultisetExpression[T] | Mapping[T, int] | Sequence[T]', + /) -> 'MultisetExpression[T]': + try: + # Symmetric. + return MultisetExpression.symmetric_difference(self, other) + except ImplicitConversionError: + return NotImplemented + + def symmetric_difference( + self, + other: 'MultisetExpression[T] | Mapping[T, int] | Sequence[T]', + /) -> 'MultisetExpression[T]': + """The elements that appear in the left or right multiset but not both. + + Same as `a ^ b`. + + Specifically, this produces the absolute difference between counts. + If you don't want negative counts to be used from the inputs, you can + do `left.keep_counts('>=', 0) ^ right.keep_counts('>=', 0)`. + + Example: + ```python + [1, 2, 2, 3] ^ [1, 2, 4] -> [2, 3, 4] + ``` + """ + other = implicit_convert_to_expression(other) + return icepool.transform.MultisetSymmetricDifference(self, other) + def _compare( self, right: 'MultisetExpression[T] | Mapping[T, int] | Sequence[T]', @@ -298,21 +497,14 @@ def _compare( *, truth_value_callback: 'Callable[[], bool] | None' = None ) -> 'icepool.Die[bool] | icepool.MultisetEvaluator[T, bool]': - if isinstance(right, MultisetExpression): - evaluator = icepool.evaluator.ExpressionEvaluator( - self, right, evaluator=operation_class()) - elif isinstance(right, (Mapping, Sequence)): - right_expression = icepool.implicit_convert_to_expression(right) - evaluator = icepool.evaluator.ExpressionEvaluator( - self, right_expression, evaluator=operation_class()) - else: - raise TypeError('Operand not comparable with expression.') + right = icepool.implicit_convert_to_expression(right) - if evaluator._free_arity == 0: + if self._free_arity() == 0 and right._free_arity() == 0: if truth_value_callback is not None: def data_callback() -> Counts[bool]: - die = cast('icepool.Die[bool]', evaluator.evaluate()) + die = cast('icepool.Die[bool]', + operation_class().evaluate(self, right)) if not isinstance(die, icepool.Die): raise TypeError('Did not resolve to a die.') return die._data @@ -320,9 +512,10 @@ def data_callback() -> Counts[bool]: return icepool.DieWithTruth(data_callback, truth_value_callback) else: - return evaluator.evaluate() + return operation_class().evaluate(self, right) else: - return evaluator + return icepool.evaluator.ExpressionEvaluator( + self, right, evaluator=operation_class()) def __eq__( # type: ignore self, diff --git a/src/icepool/expression/multiset_function.py b/src/icepool/multiset_function.py similarity index 98% rename from src/icepool/expression/multiset_function.py rename to src/icepool/multiset_function.py index 59567067..f0617487 100644 --- a/src/icepool/expression/multiset_function.py +++ b/src/icepool/multiset_function.py @@ -2,7 +2,7 @@ import icepool.evaluator from icepool.evaluator.multiset_evaluator import MultisetEvaluator -from icepool.expression.variable import MultisetVariable as MV +from icepool.multiset_variable import MultisetVariable as MV import inspect from functools import update_wrapper diff --git a/src/icepool/multiset_variable.py b/src/icepool/multiset_variable.py index 0413dd3f..da1b27cc 100644 --- a/src/icepool/multiset_variable.py +++ b/src/icepool/multiset_variable.py @@ -1,7 +1,7 @@ __docformat__ = 'google' from icepool.generator.pop_order import PopOrderReason -from icepool.multiset_expression import MultisetExpression +from icepool.multiset_expression import MultisetExpression, InitialMultisetGeneration, PopMultisetGeneration from typing import Any, Hashable, Iterator, Self, Sequence @@ -28,23 +28,20 @@ def output_arity(self) -> int: def _is_resolvable(self) -> bool: raise UnboundMultisetExpressionError() - def _generate_initial(self) -> Iterator[tuple['MultisetExpression', int]]: + def _generate_initial(self) -> InitialMultisetGeneration: raise UnboundMultisetExpressionError() - def _generate_min( - self, min_outcome - ) -> Iterator[tuple['MultisetExpression', Sequence, int]]: + def _generate_min(self, min_outcome) -> PopMultisetGeneration: raise UnboundMultisetExpressionError() - def _generate_max( - self, max_outcome - ) -> Iterator[tuple['MultisetExpression', Sequence, int]]: + def _generate_max(self, max_outcome) -> PopMultisetGeneration: raise UnboundMultisetExpressionError() - def _preferred_pop_order(self) -> tuple[Order | None, PopOrderReason]: + def _local_preferred_pop_order( + self) -> tuple[Order | None, PopOrderReason]: raise UnboundMultisetExpressionError() - def order(self) -> Order: + def local_order(self) -> Order: return Order.Any def _free_arity(self) -> int: @@ -57,7 +54,7 @@ def _unbind(self, next_index: int) -> 'tuple[MultisetExpression, int]': return self, next_index def _local_hash_key(self) -> Hashable: - raise UnboundMultisetExpressionError() + return (MultisetVariable, self._index) def __str__(self) -> str: return f'mv[{self._index}]' diff --git a/src/icepool/transform/__init__.py b/src/icepool/transform/__init__.py index e69de29b..4224bb64 100644 --- a/src/icepool/transform/__init__.py +++ b/src/icepool/transform/__init__.py @@ -0,0 +1,5 @@ +from icepool.transform.binary_operator import (MultisetIntersection, + MultisetDifference, + MultisetUnion, + MultisetAdditiveUnion, + MultisetSymmetricDifference) diff --git a/src/icepool/transform/binary_operator.py b/src/icepool/transform/binary_operator.py new file mode 100644 index 00000000..8819578c --- /dev/null +++ b/src/icepool/transform/binary_operator.py @@ -0,0 +1,114 @@ +__docformat__ = 'google' + +import icepool + +from icepool.multiset_expression import MultisetExpression +from icepool.transform.multiset_transform import MultisetTransform + +import operator +from abc import abstractmethod +from functools import cached_property, reduce + +from typing import Hashable, Iterable +from icepool.typing import Order, T + + +class MultisetBinaryOperator(MultisetTransform[T]): + + def __init__(self, *children: MultisetExpression[T]) -> None: + """Constructor. + + Args: + *children: Any number of expressions to feed into the operator. + If zero expressions are provided, the result will have all zero + counts. + If more than two expressions are provided, the counts will be + `reduce`d. + """ + self._children = children + + @staticmethod + @abstractmethod + def merge_counts(left: int, right: int) -> int: + """Merge counts produced by the left and right expression.""" + + @staticmethod + @abstractmethod + def symbol() -> str: + """A symbol representing this operation.""" + + def _copy( + self, copy_children: 'Iterable[MultisetExpression[T]]' + ) -> 'MultisetExpression[T]': + return type(self)(*copy_children) + + def _transform_next( + self, next_children: 'Iterable[MultisetExpression[T]]', outcome: T, + counts: 'Iterable[int]') -> 'tuple[MultisetExpression[T], int]': + count = reduce(self.merge_counts, counts) + return type(self)(*next_children), count + + def local_order(self) -> Order: + return Order.Any + + def _local_hash_key(self) -> Hashable: + return type(self) + + def __str__(self) -> str: + return '(' + (' ' + self.symbol() + ' ').join( + str(child) for child in self._children) + ')' + + +class MultisetIntersection(MultisetBinaryOperator): + + @staticmethod + def merge_counts(left: int, right: int) -> int: + return min(left, right) + + @staticmethod + def symbol() -> str: + return '&' + + +class MultisetDifference(MultisetBinaryOperator): + + @staticmethod + def merge_counts(left: int, right: int) -> int: + return left - right + + @staticmethod + def symbol() -> str: + return '-' + + +class MultisetUnion(MultisetBinaryOperator): + + @staticmethod + def merge_counts(left: int, right: int) -> int: + return max(left, right) + + @staticmethod + def symbol() -> str: + return '|' + + +class MultisetAdditiveUnion(MultisetBinaryOperator): + + @staticmethod + def merge_counts(left: int, right: int) -> int: + return left + right + + @staticmethod + def symbol() -> str: + return '+' + + +class MultisetSymmetricDifference(MultisetBinaryOperator): + + @staticmethod + def merge_counts(left: int, right: int) -> int: + return abs(left - right) + + @staticmethod + def symbol() -> str: + return '^' diff --git a/src/icepool/transform/multiset_transform.py b/src/icepool/transform/multiset_transform.py index 249a58f5..d6b4aa88 100644 --- a/src/icepool/transform/multiset_transform.py +++ b/src/icepool/transform/multiset_transform.py @@ -2,7 +2,7 @@ import icepool from icepool.generator.pop_order import PopOrderReason, merge_pop_orders -from icepool.multiset_expression import MultisetExpression +from icepool.multiset_expression import MultisetExpression, InitialMultisetGeneration, PopMultisetGeneration import itertools import math @@ -12,21 +12,24 @@ from abc import abstractmethod -C = TypeVar('C', bound='MultisetTransform') -"""Type variable representing a subclass of `MultisetTransform`.""" - class MultisetTransform(MultisetExpression[T]): """Internal node of an expression taking one or more counts and producing a single count.""" @abstractmethod - def _copy(self: C, children: 'Iterable[MultisetExpression[T]]') -> C: - """Creates a copy of self with the given children.""" + def _copy( + self, copy_children: 'Iterable[MultisetExpression[T]]' + ) -> 'MultisetExpression[T]': + """Creates a copy of self with the given children. + + I considered using `copy.copy` but this doesn't play well with + incidental members such as in `@cached_property`. + """ @abstractmethod - def _transform_next(self: C, - next_children: 'Iterable[MultisetExpression[T]]', - outcome: T, counts: 'Iterable[int]') -> tuple[C, int]: + def _transform_next( + self, next_children: 'Iterable[MultisetExpression[T]]', outcome: T, + counts: 'Iterable[int]') -> 'tuple[MultisetExpression[T], int]': """Produce the next state of this expression. Args: @@ -52,16 +55,14 @@ def output_arity(self) -> int: def _is_resolvable(self) -> bool: return all(child._is_resolvable() for child in self._children) - def _generate_initial(self) -> Iterator[tuple['MultisetExpression', int]]: + def _generate_initial(self) -> InitialMultisetGeneration: for t in itertools.product(*(child._generate_initial() for child in self._children)): next_children, weights = zip(*t) next_self = self._copy(next_children) yield next_self, math.prod(weights) - def _generate_min( - self, min_outcome: T - ) -> Iterator[tuple['MultisetExpression', Sequence, int]]: + def _generate_min(self, min_outcome: T) -> PopMultisetGeneration: for t in itertools.product(*(child._generate_min(min_outcome) for child in self._children)): next_children, counts, weights = zip(*t) @@ -69,9 +70,7 @@ def _generate_min( counts) yield next_self, (count, ), math.prod(weights) - def _generate_max( - self, max_outcome: T - ) -> Iterator[tuple['MultisetExpression', Sequence, int]]: + def _generate_max(self, max_outcome: T) -> PopMultisetGeneration: for t in itertools.product(*(child._generate_min(max_outcome) for child in self._children)): next_children, counts, weights = zip(*t) @@ -79,8 +78,9 @@ def _generate_max( counts) yield next_self, (count, ), math.prod(weights) - def _preferred_pop_order(self) -> tuple[Order | None, PopOrderReason]: - return merge_pop_orders(*(child._preferred_pop_order() + def _local_preferred_pop_order( + self) -> tuple[Order | None, PopOrderReason]: + return merge_pop_orders(*(child._local_preferred_pop_order() for child in self._children)) def _free_arity(self) -> int: diff --git a/src/icepool/typing.py b/src/icepool/typing.py index 6b4805e5..7aa4507e 100644 --- a/src/icepool/typing.py +++ b/src/icepool/typing.py @@ -6,7 +6,7 @@ from typing import Any, Callable, Hashable, Iterable, Literal, Mapping, Protocol, Sequence, Sized, TypeAlias, TypeGuard, TypeVar, TYPE_CHECKING if TYPE_CHECKING: - from icepool.expression.multiset_expression import MultisetExpression + from icepool.multiset_expression import MultisetExpression S = TypeVar('S', bound='Sequence') """A sequence type.""" @@ -17,7 +17,7 @@ T_co = TypeVar('T_co', bound='Outcome', covariant=True) """An outcome type.""" -T = TypeVar('T_contra', bound='Outcome', contravariant=True) +T_contra = TypeVar('T_contra', bound='Outcome', contravariant=True) """An outcome type.""" U = TypeVar('U', bound='Outcome') @@ -66,13 +66,13 @@ class RerollType(enum.Enum): """Indicates an outcome should be rerolled (with unlimited depth).""" -class Outcome(Hashable, Protocol[T]): +class Outcome(Hashable, Protocol[T_contra]): """Protocol to attempt to verify that outcome types are hashable and sortable. Far from foolproof, e.g. it cannot enforce total ordering. """ - def __lt__(self, other: T) -> bool: + def __lt__(self, other: T_contra) -> bool: ... diff --git a/tests/evaluator_test.py b/tests/evaluator_test.py index bd51ec8f..e080058b 100644 --- a/tests/evaluator_test.py +++ b/tests/evaluator_test.py @@ -2,7 +2,7 @@ import pytest from icepool import d4, d6, d8, d10, d12, Pool, Vector -from icepool.expression import multiset_function +from icepool.multiset_function import multiset_function class SumRerollIfAnyOnes(icepool.MultisetEvaluator): diff --git a/tests/import_all_test.py b/tests/import_all_test.py index 18f54664..104316f9 100644 --- a/tests/import_all_test.py +++ b/tests/import_all_test.py @@ -1,6 +1,6 @@ from icepool import * from icepool.evaluator import * -from icepool.expression import * +from icepool.generator import * import pytest diff --git a/tests/neon_city_overdrive_test.py b/tests/neon_city_overdrive_test.py index c7fd0bef..61532d24 100644 --- a/tests/neon_city_overdrive_test.py +++ b/tests/neon_city_overdrive_test.py @@ -2,7 +2,7 @@ import pytest from icepool import d6, Die -from icepool.expression import multiset_function +from icepool.multiset_function import multiset_function # See https://rpg.stackexchange.com/q/171498/72732 # for approaches by myself and others. diff --git a/tests/pop_order_test.py b/tests/pop_order_test.py index 02578959..4fb666ee 100644 --- a/tests/pop_order_test.py +++ b/tests/pop_order_test.py @@ -51,59 +51,59 @@ def test_pop_order_conflict_override(): def test_pool_single_type(): pool = icepool.Pool([d6, d6, d6]) - assert pool._preferred_pop_order() == (Order.Any, - PopOrderReason.NoPreference) + assert pool._local_preferred_pop_order() == (Order.Any, + PopOrderReason.NoPreference) def test_pool_standard(): pool = icepool.Pool([d8, d12, d6]) - assert pool._preferred_pop_order() == (Order.Descending, - PopOrderReason.PoolComposition) + assert pool._local_preferred_pop_order() == ( + Order.Descending, PopOrderReason.PoolComposition) def test_pool_standard_negative(): pool = icepool.Pool([-d8, -d12, -d6]) - assert pool._preferred_pop_order() == (Order.Ascending, - PopOrderReason.PoolComposition) + assert pool._local_preferred_pop_order() == ( + Order.Ascending, PopOrderReason.PoolComposition) def test_pool_non_truncate(): pool = icepool.Pool([-d8, d12, -d6]) - assert pool._preferred_pop_order() == (Order.Any, - PopOrderReason.NoPreference) + assert pool._local_preferred_pop_order() == (Order.Any, + PopOrderReason.NoPreference) def test_pool_skip_min(): pool = icepool.Pool([d6, d6, d6])[0, 1, 1] - assert pool._preferred_pop_order() == (Order.Descending, - PopOrderReason.KeepSkip) + assert pool._local_preferred_pop_order() == (Order.Descending, + PopOrderReason.KeepSkip) def test_pool_skip_max(): pool = icepool.Pool([d6, d6, d6])[1, 1, 0] - assert pool._preferred_pop_order() == (Order.Ascending, - PopOrderReason.KeepSkip) + assert pool._local_preferred_pop_order() == (Order.Ascending, + PopOrderReason.KeepSkip) def test_pool_skip_min_but_truncate(): pool = icepool.Pool([-d6, -d6, -d8])[0, 1, 1] - assert pool._preferred_pop_order() == (Order.Ascending, - PopOrderReason.PoolComposition) + assert pool._local_preferred_pop_order() == ( + Order.Ascending, PopOrderReason.PoolComposition) def test_pool_skip_max_but_truncate(): pool = icepool.Pool([d6, d6, d8])[1, 1, 0] - assert pool._preferred_pop_order() == (Order.Descending, - PopOrderReason.PoolComposition) + assert pool._local_preferred_pop_order() == ( + Order.Descending, PopOrderReason.PoolComposition) def test_deck_skip_min(): deck = icepool.Deck(range(10)).deal(4)[..., 1, 1] - assert deck._preferred_pop_order() == (Order.Descending, - PopOrderReason.KeepSkip) + assert deck._local_preferred_pop_order() == (Order.Descending, + PopOrderReason.KeepSkip) def test_deck_skip_max(): deck = icepool.Deck(range(10)).deal(4)[1, 1, ...] - assert deck._preferred_pop_order() == (Order.Ascending, - PopOrderReason.KeepSkip) + assert deck._local_preferred_pop_order() == (Order.Ascending, + PopOrderReason.KeepSkip) diff --git a/tests/subset_target_test.py b/tests/subset_target_test.py index f92c7c80..cd3c20ef 100644 --- a/tests/subset_target_test.py +++ b/tests/subset_target_test.py @@ -3,7 +3,7 @@ import pytest from icepool import d, Die, Order, Pool -from icepool.expression.multiset_expression import MultisetExpression +from icepool.multiset_expression import MultisetExpression targets_to_test = [ (),