diff --git a/gym_go/gogame.py b/gym_go/gogame.py index 4d68339..5cc3c71 100644 --- a/gym_go/gogame.py +++ b/gym_go/gogame.py @@ -153,7 +153,7 @@ def batch_next_states(batch_states, batch_action1d, canonical=False): def invalid_moves(state): # return a fixed size binary vector if game_ended(state): - return np.zeros(action_size(state)) + return np.ones(action_size(state)) return np.append(state[govars.INVD_CHNL].flatten(), 0) diff --git a/gym_go/tests/test_invalid_moves.py b/gym_go/tests/test_invalid_moves.py index 3693bbe..e94147d 100644 --- a/gym_go/tests/test_invalid_moves.py +++ b/gym_go/tests/test_invalid_moves.py @@ -4,7 +4,7 @@ import gym import numpy as np -from gym_go import govars +from gym_go import govars, gogame class TestGoEnvInvalidMoves(unittest.TestCase): @@ -175,6 +175,8 @@ def test_invalid_game_already_over_move(self): with self.assertRaises(Exception): self.env.step((0, 0)) + self.assertTrue((gogame.invalid_moves(self.env.state()) == 1).all()) + def test_small_suicide(self): """ 7, 8, 0,