Skip to content

Commit

Permalink
this?
Browse files Browse the repository at this point in the history
  • Loading branch information
dlwh committed Jan 24, 2025
1 parent 993c6a7 commit 99e78f4
Showing 1 changed file with 5 additions and 2 deletions.
7 changes: 5 additions & 2 deletions tests/test_fp8.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,14 +27,17 @@ def test_fp8_is_reasonable():
In, Out, key=jrandom.PRNGKey(0), dot_general=hax.quantization.Fp8DotGeneralOp.init(), init_scale=0.1
)

input = hax.random.normal(jrandom.PRNGKey(3), In)
input = hax.random.normal(jrandom.PRNGKey(3), In) * 0.1
input2 = hax.random.normal(jrandom.PRNGKey(2), In) * 0.1
output = linear(input)
fp8_output = fp8_linear(input)
fp8_output2 = fp8_linear(input2)

assert output.shape == fp8_output.shape
assert output.dtype == fp8_output.dtype

assert_trees_all_close(output.array, fp8_output.array, atol=1e-2, rtol=5e-2)
assert_trees_all_close(output.array, fp8_output.array, atol=2e-2, rtol=1e-2)
assert not jnp.allclose(fp8_output2.array, fp8_output.array, atol=2e-2, rtol=1e-2)


# https://github.com/google/flax/blob/6f2b08e024c2fd2f8cec42a6c82408cb35412319/tests/linen/linen_test.py#L1222
Expand Down

0 comments on commit 99e78f4

Please sign in to comment.