Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Ensure int masks and ranks #110

Merged
merged 1 commit into from
Oct 30, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion flowjax/bijections/masked_autoregressive.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,9 @@ def __init__(
else:
self.cond_shape = (cond_dim,)
# we give conditioning variables rank -1 (no masking of edges to output)
in_ranks = jnp.hstack((jnp.arange(dim), -jnp.ones(cond_dim)))
in_ranks = jnp.hstack(
(jnp.arange(dim), -jnp.ones(cond_dim, dtype=jnp.int32))
)

hidden_ranks = jnp.arange(nn_width) % dim
out_ranks = jnp.repeat(jnp.arange(dim), transformer_init_params.size)
Expand Down
2 changes: 1 addition & 1 deletion flowjax/nn/masked_autoregressive.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ def __init__(
key (KeyArray): Jax PRNGKey.
"""
in_ranks, hidden_ranks, out_ranks = (
jnp.asarray(a) for a in (in_ranks, hidden_ranks, out_ranks)
jnp.asarray(a, jnp.int32) for a in (in_ranks, hidden_ranks, out_ranks)
)
masks = []
if depth == 0:
Expand Down
12 changes: 8 additions & 4 deletions tests/test_masks.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,25 +10,29 @@ def test_rank_based_mask():
expected_mask = jnp.array([[0, 0], [1, 0], [1, 0], [1, 1]], dtype=jnp.int32)

mask = rank_based_mask(in_ranks, out_ranks)
assert mask.dtype == jnp.int32
assert jnp.all(expected_mask == mask)

in_ranks = jnp.array([0, 0, 1, 1])
out_ranks = jnp.array([0, 1])

expected_mask = jnp.array([[0, 0, 0, 0], [1, 1, 0, 0]], dtype=jnp.int32)
mask = rank_based_mask(in_ranks, out_ranks)
assert mask.dtype == jnp.int32
assert jnp.all(expected_mask == mask)


def test_block_tril_mask():
args = [(1, 2), 3]
expected = jnp.array([[0, 0, 0, 0, 0, 0], [1, 1, 0, 0, 0, 0], [1, 1, 1, 1, 0, 0]])
result = block_tril_mask(*args)
assert jnp.all(expected == result)
mask = block_tril_mask(*args)
assert mask.dtype == jnp.int32
assert jnp.all(expected == mask)


def test_block_diag_mask():
args = [(1, 2), 3]
expected = jnp.array([[1, 1, 0, 0, 0, 0], [0, 0, 1, 1, 0, 0], [0, 0, 0, 0, 1, 1]])
result = block_diag_mask(*args)
assert jnp.all(expected == result)
mask = block_diag_mask(*args)
assert mask.dtype == jnp.int32
assert jnp.all(expected == mask)