Skip to content

Commit

Permalink
more progress on expression rework #203
Browse files Browse the repository at this point in the history
  • Loading branch information
HighDiceRoller committed Dec 25, 2024
1 parent 0327ef9 commit ee7c305
Show file tree
Hide file tree
Showing 24 changed files with 485 additions and 170 deletions.
6 changes: 3 additions & 3 deletions src/icepool/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,22 +127,22 @@

from icepool.multiset_expression import MultisetExpression, implicit_convert_to_expression

from icepool.generator.multiset_generator import MultisetGenerator, InitialMultisetGenerator, NextMultisetGenerator
from icepool.generator.multiset_generator import MultisetGenerator
from icepool.generator.alignment import Alignment
from icepool.evaluator.multiset_evaluator import MultisetEvaluator

from icepool.population.deck import Deck
from icepool.generator.deal import Deal
from icepool.generator.multi_deal import MultiDeal

from icepool.expression.multiset_function import multiset_function
from icepool.multiset_function import multiset_function
from icepool.multiset_variable import MultisetVariable

from icepool.population.format import format_probability_inverse

import icepool.expression as expression
import icepool.generator as generator
import icepool.evaluator as evaluator
import icepool.transform as transform

import icepool.typing as typing

Expand Down
5 changes: 3 additions & 2 deletions src/icepool/evaluator/expression.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from functools import cached_property
import itertools
import icepool
import icepool.expression
import icepool.multiset_expression

from icepool.evaluator.multiset_evaluator import MultisetEvaluator
from icepool.typing import Order, Outcome, T, U_co
Expand All @@ -15,7 +15,8 @@ class ExpressionEvaluator(MultisetEvaluator[T, U_co]):
"""Assigns an expression to be evaluated first to each input of an evaluator."""

def __init__(self,
*expressions: 'icepool.expression.MultisetExpression[T]',
*expressions:
'icepool.multiset_expression.MultisetExpression[T]',
evaluator: MultisetEvaluator[T, U_co],
truth_value: bool | None = None) -> None:
self._evaluator = evaluator
Expand Down
50 changes: 22 additions & 28 deletions src/icepool/evaluator/multiset_evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,7 @@ def validate_arity(self, arity: int) -> None:
@cached_property
def _cache(
self
) -> 'MutableMapping[tuple[Order, Alignment, tuple[MultisetGenerator, ...], Hashable], Mapping[Any, int]]':
) -> 'MutableMapping[tuple[Order, Alignment, tuple[MultisetExpression, ...], Hashable], Mapping[Any, int]]':
"""Cached results.
The key is `(order, extra_outcomes, generators, state)`.
Expand Down Expand Up @@ -257,34 +257,26 @@ def evaluate(
from icepool.evaluator.expression import ExpressionEvaluator
return ExpressionEvaluator(*expressions, evaluator=self)

if not all(
isinstance(expression, icepool.MultisetGenerator)
for expression in expressions):
from icepool.evaluator.expression import ExpressionEvaluator
return ExpressionEvaluator(*expressions, evaluator=self).evaluate()

generators = cast(tuple[icepool.MultisetGenerator, ...], expressions)

self.validate_arity(
sum(generator.output_arity() for generator in generators))
sum(expression.output_arity() for expression in expressions))

generators = self.prefix_generators() + generators
expressions = self.prefix_generators() + expressions

if not all(generator._is_resolvable() for generator in generators):
if not all(expression._is_resolvable() for expression in expressions):
return icepool.Die([])

algorithm, order = self._select_algorithm(*generators)
algorithm, order = self._select_algorithm(*expressions)

outcomes = icepool.sorted_union(*(generator.outcomes()
for generator in generators))
outcomes = icepool.sorted_union(*(expression.outcomes()
for expression in expressions))
extra_outcomes = Alignment(self.extra_outcomes(outcomes))

dist: MutableMapping[Any, int] = defaultdict(int)
iterators = MultisetEvaluator._initialize_generators(generators)
iterators = MultisetEvaluator._initialize_expressions(expressions)
for p in itertools.product(*iterators):
sub_generators, sub_weights = zip(*p)
sub_expressions, sub_weights = zip(*p)
prod_weight = math.prod(sub_weights)
sub_result = algorithm(order, extra_outcomes, sub_generators)
sub_result = algorithm(order, extra_outcomes, sub_expressions)
for sub_state, sub_weight in sub_result.items():
dist[sub_state] += sub_weight * prod_weight

Expand All @@ -306,9 +298,9 @@ def evaluate(
__call__ = evaluate

def _select_algorithm(
self, *generators: 'icepool.MultisetGenerator[T, Any]'
self, *generators: 'icepool.MultisetExpression[T]'
) -> tuple[
'Callable[[Order, Alignment[T], tuple[icepool.MultisetGenerator[T, Any], ...]], Mapping[Any, int]]',
'Callable[[Order, Alignment[T], tuple[icepool.MultisetExpression[T], ...]], Mapping[Any, int]]',
Order]:
"""Selects an algorithm and iteration order.
Expand All @@ -324,7 +316,8 @@ def _select_algorithm(
return self._eval_internal, eval_order

preferred_pop_order, pop_order_reason = merge_pop_orders(
*(generator._preferred_pop_order() for generator in generators))
*(generator._local_preferred_pop_order()
for generator in generators))

if preferred_pop_order is None:
preferred_pop_order = Order.Any
Expand All @@ -346,7 +339,7 @@ def _select_algorithm(

def _eval_internal(
self, order: Order, extra_outcomes: 'Alignment[T]',
generators: 'tuple[icepool.MultisetGenerator[T, Any], ...]'
generators: 'tuple[icepool.MultisetExpression[T], ...]'
) -> Mapping[Any, int]:
"""Internal algorithm for iterating in the more-preferred order.
Expand Down Expand Up @@ -438,16 +431,17 @@ def _eval_internal_forward(
return result

@staticmethod
def _initialize_generators(
generators: 'tuple[icepool.MultisetGenerator[T, Any], ...]'
) -> 'tuple[icepool.InitialMultisetGenerator, ...]':
return tuple(generator._generate_initial() for generator in generators)
def _initialize_expressions(
expressions: 'tuple[icepool.MultisetExpression[T], ...]'
) -> 'tuple[icepool.InitialMultisetGeneration, ...]':
return tuple(expression._generate_initial()
for expression in expressions)

@staticmethod
def _pop_generators(
order: Order, extra_outcomes: 'Alignment[T]',
generators: 'tuple[icepool.MultisetGenerator[T, Any], ...]'
) -> 'tuple[T, Alignment[T], tuple[icepool.NextMultisetGenerator, ...]]':
generators: 'tuple[icepool.MultisetExpression[T], ...]'
) -> 'tuple[T, Alignment[T], tuple[icepool.PopMultisetGeneration, ...]]':
"""Pops a single outcome from the generators.
Args:
Expand Down
2 changes: 1 addition & 1 deletion src/icepool/expression/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from icepool.expression.keep import KeepExpression
from icepool.expression.match import SortMatchExpression, MaximumMatchExpression

from icepool.expression.multiset_function import multiset_function
from icepool.multiset_function import multiset_function

__all__ = [
'multiset_function', 'MultisetExpression', 'MultisetVariable',
Expand Down
5 changes: 3 additions & 2 deletions src/icepool/generator/alignment.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,12 +50,13 @@ def _generate_max(self, max_outcome) -> AlignmentGenerator:
else:
yield Alignment(self.outcomes()[:-1]), (), 1

def _preferred_pop_order(self) -> tuple[Order | None, PopOrderReason]:
def _local_preferred_pop_order(
self) -> tuple[Order | None, PopOrderReason]:
return Order.Any, PopOrderReason.NoPreference

def denominator(self) -> int:
return 1

@cached_property
def _hash_key(self) -> Hashable:
def _local_hash_key(self) -> Hashable:
return Alignment, self._outcomes
18 changes: 10 additions & 8 deletions src/icepool/generator/compound_keep.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
__docformat__ = 'google'

import icepool
from icepool.multiset_expression import InitialMultisetGeneration, PopMultisetGeneration
from icepool.generator.keep import KeepGenerator, pop_max_from_keep_tuple, pop_min_from_keep_tuple
from icepool.generator.multiset_generator import InitialMultisetGenerator, NextMultisetGenerator, MultisetGenerator
from icepool.generator.multiset_generator import MultisetGenerator
from icepool.generator.pop_order import PopOrderReason, merge_pop_orders

import itertools
Expand All @@ -22,18 +23,18 @@ def __init__(self, inners: Sequence[KeepGenerator[T]],

def outcomes(self) -> Sequence[T]:
return icepool.sorted_union(*(inner.outcomes()
for inner in self._inner_generators))
for inner in self._inner_generators))

def output_arity(self) -> int:
return 1

def _is_resolvable(self) -> bool:
return all(inner._is_resolvable() for inner in self._inner_generators)

def _generate_initial(self) -> InitialMultisetGenerator:
def _generate_initial(self) -> InitialMultisetGeneration:
yield self, 1

def _generate_min(self, min_outcome) -> NextMultisetGenerator:
def _generate_min(self, min_outcome) -> PopMultisetGeneration:
for t in itertools.product(*(inner._generate_min(min_outcome)
for inner in self._inner_generators)):
generators, counts, weights = zip(*t)
Expand All @@ -44,7 +45,7 @@ def _generate_min(self, min_outcome) -> NextMultisetGenerator:
yield CompoundKeepGenerator(
generators, popped_keep_tuple), (result_count, ), total_weight

def _generate_max(self, max_outcome) -> NextMultisetGenerator:
def _generate_max(self, max_outcome) -> PopMultisetGeneration:
for t in itertools.product(*(inner._generate_max(max_outcome)
for inner in self._inner_generators)):
generators, counts, weights = zip(*t)
Expand All @@ -55,8 +56,9 @@ def _generate_max(self, max_outcome) -> NextMultisetGenerator:
yield CompoundKeepGenerator(
generators, popped_keep_tuple), (result_count, ), total_weight

def _preferred_pop_order(self) -> tuple[Order | None, PopOrderReason]:
return merge_pop_orders(*(inner._preferred_pop_order()
def _local_preferred_pop_order(
self) -> tuple[Order | None, PopOrderReason]:
return merge_pop_orders(*(inner._local_preferred_pop_order()
for inner in self._inner_generators))

def denominator(self) -> int:
Expand All @@ -68,7 +70,7 @@ def _set_keep_tuple(self, keep_tuple: tuple[int,
return CompoundKeepGenerator(self._inner_generators, keep_tuple)

@property
def _hash_key(self) -> Hashable:
def _local_hash_key(self) -> Hashable:
return CompoundKeepGenerator, tuple(
inner._hash_key
for inner in self._inner_generators), self._keep_tuple
Expand Down
13 changes: 7 additions & 6 deletions src/icepool/generator/deal.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import icepool
from icepool.generator.keep import KeepGenerator, pop_max_from_keep_tuple, pop_min_from_keep_tuple
from icepool.collection.counts import CountsKeysView
from icepool.generator.multiset_generator import InitialMultisetGenerator, NextMultisetGenerator
from icepool.multiset_expression import InitialMultisetGeneration, PopMultisetGeneration
import icepool.generator.pop_order
from icepool.generator.pop_order import PopOrderReason

Expand Down Expand Up @@ -85,10 +85,10 @@ def _is_resolvable(self) -> bool:
def denominator(self) -> int:
return icepool.math.comb(self.deck().size(), self._hand_size)

def _generate_initial(self) -> InitialMultisetGenerator:
def _generate_initial(self) -> InitialMultisetGeneration:
yield self, 1

def _generate_min(self, min_outcome) -> NextMultisetGenerator:
def _generate_min(self, min_outcome) -> PopMultisetGeneration:
if not self.outcomes() or min_outcome != self.min_outcome():
yield self, (0, ), 1
return
Expand All @@ -115,7 +115,7 @@ def _generate_min(self, min_outcome) -> NextMultisetGenerator:
popped_deal = Deal._new_raw(popped_deck, 0, ())
yield popped_deal, (sum(self.keep_tuple()), ), skip_weight

def _generate_max(self, max_outcome) -> NextMultisetGenerator:
def _generate_max(self, max_outcome) -> PopMultisetGeneration:
if not self.outcomes() or max_outcome != self.max_outcome():
yield self, (0, ), 1
return
Expand All @@ -142,7 +142,8 @@ def _generate_max(self, max_outcome) -> NextMultisetGenerator:
popped_deal = Deal._new_raw(popped_deck, 0, ())
yield popped_deal, (sum(self.keep_tuple()), ), skip_weight

def _preferred_pop_order(self) -> tuple[Order | None, PopOrderReason]:
def _local_preferred_pop_order(
self) -> tuple[Order | None, PopOrderReason]:
lo_skip, hi_skip = icepool.generator.pop_order.lo_hi_skip(
self.keep_tuple())
if lo_skip > hi_skip:
Expand All @@ -153,7 +154,7 @@ def _preferred_pop_order(self) -> tuple[Order | None, PopOrderReason]:
return Order.Any, PopOrderReason.NoPreference

@cached_property
def _hash_key(self) -> Hashable:
def _local_hash_key(self) -> Hashable:
return Deal, self.deck(), self._hand_size, self._keep_tuple

def __repr__(self) -> str:
Expand Down
8 changes: 5 additions & 3 deletions src/icepool/generator/keep.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
__docformat__ = 'google'

import icepool
from icepool.generator.multiset_generator import InitialMultisetGenerator, NextMultisetGenerator, MultisetGenerator
from icepool.multiset_expression import InitialMultisetGeneration, PopMultisetGeneration
from icepool.generator.multiset_generator import MultisetGenerator

import operator
from collections import defaultdict
Expand All @@ -10,10 +11,11 @@
from abc import ABC, abstractmethod
from types import EllipsisType
from typing import Hashable, Literal, Mapping, MutableMapping, Sequence, cast, overload, TYPE_CHECKING
import icepool.multiset_expression
from icepool.typing import ImplicitConversionError, Outcome, T

if TYPE_CHECKING:
from icepool.expression import MultisetExpression
from icepool.multiset_expression import MultisetExpression


class KeepGenerator(MultisetGenerator[T, tuple[int]]):
Expand Down Expand Up @@ -235,7 +237,7 @@ def additive_union(
*args: 'MultisetExpression[T] | Mapping[T, int] | Sequence[T]'
) -> 'MultisetExpression[T]':
args = tuple(
icepool.expression.implicit_convert_to_expression(arg)
icepool.multiset_expression.implicit_convert_to_expression(arg)
for arg in args)
if all(isinstance(arg, KeepGenerator) for arg in args):
generators = cast(tuple[KeepGenerator, ...], args)
Expand Down
20 changes: 11 additions & 9 deletions src/icepool/generator/mixture.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@

import icepool

from icepool.generator.multiset_generator import InitialMultisetGenerator, NextMultisetGenerator, MultisetGenerator
from icepool.multiset_expression import InitialMultisetGeneration, PopMultisetGeneration
from icepool.generator.multiset_generator import MultisetGenerator
from icepool.generator.pop_order import PopOrderReason, merge_pop_orders

import math
Expand All @@ -15,7 +16,7 @@
from typing import TYPE_CHECKING, Callable, Hashable, Literal, Mapping, MutableMapping, Sequence, overload

if TYPE_CHECKING:
from icepool.expression import MultisetExpression
from icepool.multiset_expression import MultisetExpression


class MixtureGenerator(MultisetGenerator[T, tuple[int]]):
Expand Down Expand Up @@ -63,7 +64,7 @@ def __init__(self,

def outcomes(self) -> Sequence[T]:
return icepool.sorted_union(*(inner.outcomes()
for inner in self._inner_generators))
for inner in self._inner_generators))

def output_arity(self) -> int:
result = None
Expand All @@ -79,21 +80,22 @@ def output_arity(self) -> int:
def _is_resolvable(self) -> bool:
return all(inner._is_resolvable() for inner in self._inner_generators)

def _generate_initial(self) -> InitialMultisetGenerator:
def _generate_initial(self) -> InitialMultisetGeneration:
yield from self._inner_generators.items()

def _generate_min(self, min_outcome) -> NextMultisetGenerator:
def _generate_min(self, min_outcome) -> PopMultisetGeneration:
raise RuntimeError(
'MixtureMultisetGenerator should have decayed to another generator type by this point.'
)

def _generate_max(self, max_outcome) -> NextMultisetGenerator:
def _generate_max(self, max_outcome) -> PopMultisetGeneration:
raise RuntimeError(
'MixtureMultisetGenerator should have decayed to another generator type by this point.'
)

def _preferred_pop_order(self) -> tuple[Order | None, PopOrderReason]:
return merge_pop_orders(*(inner._preferred_pop_order()
def _local_preferred_pop_order(
self) -> tuple[Order | None, PopOrderReason]:
return merge_pop_orders(*(inner._local_preferred_pop_order()
for inner in self._inner_generators))

@cached_property
Expand All @@ -107,7 +109,7 @@ def denominator(self) -> int:
return self._denominator

@property
def _hash_key(self) -> Hashable:
def _local_hash_key(self) -> Hashable:
# This is not intended to be cached directly, so we are a little loose here.
return MixtureGenerator, tuple(self._inner_generators.items())

Expand Down
Loading

0 comments on commit ee7c305

Please sign in to comment.