Skip to content

Commit

Permalink
fix: tests
Browse files Browse the repository at this point in the history
  • Loading branch information
clement-bonnet committed Mar 13, 2024
1 parent fb58be4 commit b492701
Showing 1 changed file with 12 additions and 12 deletions.
24 changes: 12 additions & 12 deletions jumanji/environments/logic/sliding_tile_puzzle/env_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def test_sliding_tile_puzzle_step_jit(
) -> None:
"""Confirm that the step is only compiled once when jitted."""
up_action = jnp.array(0)
down_action = jnp.array(1)
down_action = jnp.array(2)

chex.clear_trace_counter()
step_fn = jax.jit(chex.assert_max_traces(sliding_tile_puzzle.step, n=1))
Expand Down Expand Up @@ -78,7 +78,7 @@ def test_sliding_tile_puzzle_get_action_mask(
# Check that the action mask is a boolean array with the correct shape.
assert action_mask.dtype == jnp.bool_
assert action_mask.shape == (4,)
assert jnp.array_equal(action_mask, jnp.array([False, True, False, True]))
assert jnp.array_equal(action_mask, jnp.array([False, True, True, False]))


def test_sliding_tile_puzzle_does_not_smoke(
Expand All @@ -95,22 +95,22 @@ def test_env_one_move_to_solve(sliding_tile_puzzle: SlidingTilePuzzle) -> None:
# Set up a state that is one move away from being solved.
one_move_away = jnp.array(
[
[3, 1, 2],
[0, 4, 5],
[6, 7, 8],
[1, 2, 3],
[4, 5, 0],
[7, 8, 6],
]
)
empty_tile_position = jnp.array([1, 0])
empty_tile_position = jnp.array([1, 2])
state = State(
puzzle=one_move_away,
empty_tile_position=empty_tile_position,
key=jax.random.PRNGKey(0),
step_count=0,
)

# The correct action to solve the puzzle is to move the empty tile down (action=1).
up_action = jnp.array(0)
next_state, timestep = sliding_tile_puzzle.step(state, up_action)
# The correct action to solve the puzzle is to move the empty tile down (action=2).
down_action = jnp.array(2)
next_state, timestep = sliding_tile_puzzle.step(state, down_action)

assert jnp.array_equal(next_state.puzzle, sliding_tile_puzzle.solved_puzzle)
assert timestep.last()
Expand All @@ -137,8 +137,8 @@ def test_env_legal_move_changes_board_as_expected(
# [4, 2, 5],
# [7, 8, 6],
# ]
# A legal move is to move the empty tile down (action=1).
action = jnp.array(1)
# A legal move is to move the empty tile down (action=2).
action = jnp.array(2)
next_state, _ = sliding_tile_puzzle.step(state, action)
expected_puzzle = jnp.array(
[
Expand All @@ -149,7 +149,7 @@ def test_env_legal_move_changes_board_as_expected(
)
assert jnp.array_equal(next_state.puzzle, expected_puzzle)

action = jnp.array(3)
action = jnp.array(1)
next_state, _ = sliding_tile_puzzle.step(next_state, action)
expected_puzzle = jnp.array(
[
Expand Down

0 comments on commit b492701

Please sign in to comment.