Skip to content

Commit

Permalink
remove MultisetEvaluator.validate_arity() since the number of argum…
Browse files Browse the repository at this point in the history
…ents to `next_state()` usually works well enough

courtesy check in `output_arity()` that each input has `output_arity` exactly 1
  • Loading branch information
HighDiceRoller committed Dec 27, 2024
1 parent f2c0079 commit a8fd875
Show file tree
Hide file tree
Showing 9 changed files with 30 additions and 27 deletions.
7 changes: 6 additions & 1 deletion src/icepool/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,12 @@
from icepool.generator.compound_keep import CompoundKeepGenerator
from icepool.generator.mixture import MixtureGenerator

from icepool.multiset_expression import MultisetExpression, implicit_convert_to_expression, InitialMultisetGeneration, PopMultisetGeneration, MultisetBindingError
from icepool.multiset_expression import (MultisetExpression,
implicit_convert_to_expression,
InitialMultisetGeneration,
PopMultisetGeneration,
MultisetArityError,
MultisetBindingError)

from icepool.generator.multiset_generator import MultisetGenerator
from icepool.generator.alignment import Alignment
Expand Down
1 change: 1 addition & 0 deletions src/icepool/evaluator/argsort.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
__docformat__ = 'google'

from icepool.evaluator.multiset_evaluator import MultisetEvaluator
from icepool.multiset_expression import MultisetArityError

from icepool.order import Order, OrderReason
from typing import Any
Expand Down
1 change: 1 addition & 0 deletions src/icepool/evaluator/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import icepool
from icepool.evaluator.multiset_evaluator import MultisetEvaluator
from icepool.multiset_expression import MultisetArityError
from icepool.order import Order, OrderReason

import operator
Expand Down
4 changes: 0 additions & 4 deletions src/icepool/evaluator/joint.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,10 +96,6 @@ def _bound_inputs(self) -> 'tuple[icepool.MultisetExpression, ...]':
def bound_inputs(self) -> 'tuple[icepool.MultisetExpression, ...]':
return self._bound_inputs

def validate_arity(self, arity: int) -> None:
for subeval in self._sub_evaluators:
subeval.validate_arity(arity)

@cached_property
def _extra_arity(self) -> int:
return sum(expression.output_arity()
Expand Down
19 changes: 3 additions & 16 deletions src/icepool/evaluator/multiset_evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,19 +179,6 @@ def bound_inputs(self) -> 'tuple[icepool.MultisetExpression, ...]':
"""
return ()

def validate_arity(self, arity: int) -> None:
"""An optional method to verify the total input arity.
This is called after any implicit conversion to expressions, but does
not include any `bound_inputs()`.
Overriding `next_state` with a fixed number of counts will make this
check redundant.
Raises:
`ValueError` if the total input arity is not valid.
"""

@cached_property
def _cache(
self
Expand Down Expand Up @@ -254,11 +241,11 @@ def evaluate(
from icepool.evaluator.multiset_function import MultisetFunctionEvaluator
return MultisetFunctionEvaluator(*inputs, evaluator=self)

self.validate_arity(
sum(expression.output_arity() for expression in inputs))

inputs = self.bound_inputs() + inputs

# This is kept to verify inputs to operators each have arity exactly 1.
total_arity = sum(input.output_arity() for input in inputs)

if not all(expression._is_resolvable() for expression in inputs):
return icepool.Die([])

Expand Down
4 changes: 2 additions & 2 deletions src/icepool/generator/mixture.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import icepool

from icepool.multiset_expression import InitialMultisetGeneration, PopMultisetGeneration
from icepool.multiset_expression import MultisetArityError, InitialMultisetGeneration, PopMultisetGeneration
from icepool.generator.multiset_generator import MultisetGenerator
from icepool.order import Order, OrderReason, merge_order_preferences

Expand Down Expand Up @@ -72,7 +72,7 @@ def output_arity(self) -> int:
if result is None:
result = inner.output_arity()
elif result != inner.output_arity():
raise ValueError('Inconsistent output_arity.')
raise MultisetArityError('Inconsistent output_arity.')
if result is None:
raise ValueError('Empty MixtureMultisetGenerator.')
return result
Expand Down
4 changes: 4 additions & 0 deletions src/icepool/multiset_expression.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,10 @@
Sequence, int]]


class MultisetArityError(ValueError):
"""Indicates that an arity was not the same as required."""


class MultisetBindingError(TypeError):
"""Indicates a bound multiset variable was found where a free variable was expected, or vice versa."""

Expand Down
12 changes: 10 additions & 2 deletions src/icepool/operator/multiset_operator.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
__docformat__ = 'google'

import icepool
from icepool.multiset_expression import MultisetExpression, InitialMultisetGeneration, PopMultisetGeneration
from icepool.multiset_expression import MultisetExpression, InitialMultisetGeneration, PopMultisetGeneration, MultisetArityError

import itertools
import math
Expand Down Expand Up @@ -51,7 +51,15 @@ def outcomes(self) -> Sequence[T]:
for child in self._children))

def output_arity(self) -> int:
"""Transforms only output 1 count. For multiple outputs, use @multiset_function."""
"""Operators only output 1 count.
Each input to `MultisetOperator` must only output 1 count as well.
For multiple outputs, use @multiset_function.
"""
if any(child.output_arity() != 1 for child in self._children):
raise MultisetArityError(
'Each input to MultisetOperator must output exactly 1 count.')
return 1

def _is_resolvable(self) -> bool:
Expand Down
5 changes: 3 additions & 2 deletions tests/deck_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

from icepool import MultisetEvaluator, Deck
from icepool.evaluator import LargestStraightEvaluator
from icepool.multiset_expression import MultisetArityError

# no wraparound
best_run_evaluator = LargestStraightEvaluator()
Expand Down Expand Up @@ -62,8 +63,8 @@ def test_two_hand_sum_diff_size():
def test_multiple_bind_error():
deck = icepool.Deck(range(4), times=4)
deal = deck.deal(2, 2)
with pytest.raises(ValueError):
deal.unique()
with pytest.raises(MultisetArityError):
deal.unique().count()


def test_add():
Expand Down

0 comments on commit a8fd875

Please sign in to comment.