chore: sacrifice the required goats to the flake8 gods
 - I'm not sure why flake8 is so aggro all of the sudden. Freaking out about basic type hints and whatnot. 😭
justindujardin committed Nov 29, 2023
1 parent be5b329 commit 021f7ec
Showing 6 changed files with 19 additions and 17 deletions.
2 changes: 0 additions & 2 deletions .flake8

This file was deleted.

6 changes: 3 additions & 3 deletions mathy_envs/env.py
@@ -272,7 +272,7 @@ def get_state_transition(self, env_state: MathyEnvState) -> time_step.TimeStep:
         )

     def get_next_state(
-        self, env_state: MathyEnvState, action: Union[int, ActionType]
+        self, env_state: MathyEnvState, action: Union[int, np.int64, ActionType]
     ) -> Tuple[MathyEnvState, time_step.TimeStep, ExpressionChangeRule]:
         """
         # Parameters
@@ -597,7 +597,7 @@ def to_hash_key(self, env_state: MathyEnvState) -> str:
         """Convert env_state to a string for MCTS cache"""
         return env_state.agent.problem

-    def to_action(self, action: Union[int, ActionType]) -> ActionType:
+    def to_action(self, action: Union[int, np.int64, ActionType]) -> ActionType:
         """Resolve a given action input to a tuple of (rule_index, node_index).
         When given an int, it is treated as an index into the flattened 2d action
@@ -606,4 +606,4 @@ def to_action(self, action: Union[int, ActionType]) -> ActionType:
             return action
         token_index = action % self.max_seq_len
         action_index = int((action - token_index) / self.max_seq_len)
-        return action_index, token_index
+        return action_index, int(token_index)
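
The to_action arithmetic above flattens a 2D action space (rule x node) into a single index and back. A minimal standalone sketch of that round trip, assuming a hypothetical max_seq_len of 128 (the real value comes from the environment instance):

import numpy as np

MAX_SEQ_LEN = 128  # hypothetical; MathyEnv supplies the real value


def to_action(action, max_seq_len=MAX_SEQ_LEN):
    """Decode a flat index (int or np.int64) into (rule_index, node_index)."""
    if isinstance(action, tuple):
        return action  # already a (rule_index, node_index) pair
    token_index = action % max_seq_len
    rule_index = int((action - token_index) / max_seq_len)
    return rule_index, int(token_index)


# Encoding is the inverse: rule_index * max_seq_len + node_index
assert to_action(3 * 128 + 17) == (3, 17)
assert to_action(np.int64(3 * 128 + 17)) == (3, 17)  # np.int64 decodes the same way
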
2 changes: 1 addition & 1 deletion mathy_envs/gym/masked_discrete.py
@@ -19,7 +19,7 @@ def __init__(self, n: int, mask: np.ndarray):
         super(MaskedDiscrete, self).__init__(n) # type:ignore
         self.update_mask(mask)

-    def sample(self, mask: Optional[np.ndarray] = None) -> int:
+    def sample(self, mask: Optional[np.ndarray] = None) -> np.int64:
         mask = self.mask if mask is None else mask
         probability = self.mask / np.sum(self.mask)
         return self.np_random.choice(self.n, p=probability)
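
The changed annotation above reflects that numpy's choice returns a numpy integer rather than a Python int when drawing a single value. A minimal sketch of the masked sampling logic itself, with illustrative names that are not the library's API:

import numpy as np


def masked_sample(n, mask, rng=None):
    """Pick an action index with probability proportional to a 0/1 mask."""
    rng = np.random.default_rng() if rng is None else rng
    probability = mask / np.sum(mask)
    return rng.choice(n, p=probability)  # a numpy integer (np.int64), not a Python int


mask = np.array([1, 0, 1, 1, 0])
action = masked_sample(5, mask)
assert mask[action] == 1  # masked-out (0) actions are never sampled
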
9 changes: 5 additions & 4 deletions mathy_envs/gym/mathy_gym_env.py
@@ -59,16 +59,16 @@ def __init__(
         time = 1
         obs_size = mask + values + nodes + type + time
         self.observation_space = spaces.Box(
-            low=0, high=1, shape=(obs_size,), dtype=float
+            low=0, high=1, shape=(obs_size,), dtype=np.float32
         )

     @property
     def action_size(self) -> int:
         return self.mathy.action_size

     def step(
-        self, action: Union[int, ActionType]
-    ) -> Tuple[np.ndarray, Any, bool, Dict[str, object]]:
+        self, action: Union[int, np.int64, ActionType]
+    ) -> Tuple[np.ndarray, Any, bool, bool, Dict[str, object]]:
         assert self.state is not None, "call reset() before stepping the environment"
         self.state, transition, change = self.mathy.get_next_state(self.state, action)
         done = is_terminal_transition(transition)
@@ -79,7 +79,8 @@ def step(
         }
         if done:
             info["win"] = transition.reward > 0.0
-        return self._observe(self.state), transition.reward, done, info
+        # TODO: What is the second done here, need to update it.
+        return self._observe(self.state), transition.reward, done, done, info

     def _observe(self, state: MathyEnvState) -> np.ndarray:
         """Observe the environment at the given state, updating the observation
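
The step signature above now returns five values in the newer Gym/Gymnasium style, (observation, reward, terminated, truncated, info), and the TODO marks that the truncated flag is currently just a copy of done rather than a real time-limit signal. A minimal sketch of a caller unpacking that shape, assuming an already-constructed MathyGymEnv named env (illustrative usage, not taken from this commit; reset() details may differ):

# Run one episode against the 5-tuple step API.
observation = env.reset()
terminated = truncated = False
info = {}
while not (terminated or truncated):
    action = env.action_space.sample()  # any action permitted by the (masked) space
    observation, reward, terminated, truncated, info = env.step(action)
print("solved" if info.get("win") else "not solved")
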
14 changes: 7 additions & 7 deletions mathy_envs/state.py
@@ -58,11 +58,11 @@ def empty(cls, template: "MathyObservation") -> "MathyObservation":


 # fmt: off
-MathyObservation.nodes.__doc__ = "tree node types in the current environment state shape=[n,]" # noqa
-MathyObservation.mask.__doc__ = "0/1 mask where 0 indicates an invalid action shape=[n,]" # noqa
-MathyObservation.values.__doc__ = "tree node value sequences, with non number indices set to 0.0 shape=[n,]" # noqa
-MathyObservation.type.__doc__ = "two column hash of problem environment type shape=[2,]" # noqa
-MathyObservation.time.__doc__ = "float value between 0.0 and 1.0 indicating the time elapsed shape=[1,]" # noqa
+MathyObservation.nodes.__doc__ = "tree node types in the current environment state shape=[n,]" # noqa
+MathyObservation.mask.__doc__ = "0/1 mask where 0 indicates an invalid action shape=[n,]" # noqa
+MathyObservation.values.__doc__ = "tree node value sequences, with non number indices set to 0.0 shape=[n,]" # noqa
+MathyObservation.type.__doc__ = "two column hash of problem environment type shape=[2,]" # noqa
+MathyObservation.time.__doc__ = "float value between 0.0 and 1.0 indicating the time elapsed shape=[1,]" # noqa
 # fmt: on


@@ -75,8 +75,8 @@ class MathyEnvStateStep(NamedTuple):


 # fmt: off
-MathyEnvStateStep.raw.__doc__ = "the input text at the timestep" # noqa
-MathyEnvStateStep.action.__doc__ = "a tuple indicating the chosen action and the node it was applied to" # noqa
+MathyEnvStateStep.raw.__doc__ = "the input text at the timestep" # noqa
+MathyEnvStateStep.action.__doc__ = "a tuple indicating the chosen action and the node it was applied to" # noqa
 # fmt: on
3 changes: 3 additions & 0 deletions setup.cfg
@@ -4,3 +4,6 @@ ignore_missing_imports = True
 [mypy-wasabi.*]
 ignore_missing_imports = True

+[flake8]
+max-line-length = 88
+ignore = E701,E251
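
With the standalone .flake8 file deleted above, flake8 now picks this configuration up from the [flake8] section of setup.cfg, and max-line-length = 88 matches Black's default line length. E701 is "multiple statements on one line (colon)" and E251 is "unexpected spaces around keyword / parameter equals"; generic examples of lines those codes would normally flag (not taken from this repository):

done = True
info = {}
if done: info["win"] = True  # E701: multiple statements on one line (colon)


def sample(mask = None):  # E251: unexpected spaces around keyword equals
    return mask


print(info, sample())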