#16144: migrate float32 support and bitwise ops
KalaivaniMCW committed Jan 17, 2025
1 parent 8982a37 commit fbeabc4
Showing 31 changed files with 1,796 additions and 261 deletions.
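The new tests exercise the migrated float32 binary ops through the ttnn.experimental namespace and check the device output against the registered golden function. A minimal sketch of that call pattern, distilled from the tests below; the shape, value range, and the open `device` handle are illustrative assumptions, not part of this commit:

import torch
import ttnn

# `device` is assumed to be an already-open device handle
# (e.g. the pytest `device` fixture used by these tests).
a = torch.rand((1, 1, 32, 32), dtype=torch.float32) * 100 - 50
b = torch.rand((1, 1, 32, 32), dtype=torch.float32) * 100 - 50

a_tt = ttnn.from_torch(a, dtype=ttnn.float32, device=device, layout=ttnn.TILE_LAYOUT)
b_tt = ttnn.from_torch(b, dtype=ttnn.float32, device=device, layout=ttnn.TILE_LAYOUT)

out_tt = ttnn.experimental.add(a_tt, b_tt, queue_id=0)  # any of the migrated binary ops
out = ttnn.to_torch(out_tt)

golden = ttnn.get_golden_function(ttnn.experimental.add)(a, b)
assert ttnn.pearson_correlation_coefficient(golden, out) >= 0.999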
110 changes: 109 additions & 1 deletion tests/ttnn/unit_tests/operations/eltwise/test_binary_bcast.py
@@ -9,8 +9,10 @@
from tests.ttnn.unit_tests.operations.eltwise.backward.utility_funcs import (
    compare_pcc,
)
from models.utility_functions import skip_for_grayskull
from models.utility_functions import skip_for_grayskull, torch_random
from itertools import product as parameters
from functools import partial
from tests.tt_eager.python_api_testing.sweep_tests.generation_funcs import gen_func_with_cast_tt


binary_fns = {
@@ -255,3 +257,109 @@ def test_01_volume_tensors(device, a, b, c_golden, memory_config):

    assert c.tolist() == c_golden


@skip_for_grayskull("Requires wormhole_b0 to run")
@pytest.mark.parametrize(
    "input_shapes",
    (
        (torch.Size([1, 1, 1, 1]), torch.Size([5, 3, 32, 32])),
        (torch.Size([5, 1, 64, 1]), torch.Size([1, 3, 1, 128])),
        (torch.Size([5, 1, 1, 64]), torch.Size([1, 3, 128, 1])),
        (torch.Size([5, 1, 1]), torch.Size([1, 32, 128])),
    ),
)
@pytest.mark.parametrize(
    "ttnn_fn",
    [
        ttnn.experimental.add,
        ttnn.experimental.sub,
        ttnn.experimental.mul,
        ttnn.experimental.div,
        ttnn.experimental.rsub,
        ttnn.experimental.eq,
        ttnn.experimental.ne,
        ttnn.experimental.gt,
        ttnn.experimental.gte,
        ttnn.experimental.lt,
        ttnn.experimental.lte,
        ttnn.experimental.logical_or,
        ttnn.experimental.logical_xor,
        ttnn.experimental.logical_and,
        ttnn.experimental.ldexp,
        ttnn.experimental.logaddexp,
        ttnn.experimental.logaddexp2,
        ttnn.experimental.squared_difference,
        ttnn.experimental.bias_gelu,
    ],
)
@pytest.mark.parametrize(
    "dtype",
    ([ttnn.float32]),
)
def test_binary_sfpu_ops(input_shapes, dtype, ttnn_fn, device):
    a_shape, b_shape = input_shapes

    a_pt = gen_func_with_cast_tt(partial(torch_random, low=-50, high=50, dtype=torch.float32), dtype)(a_shape)
    b_pt = gen_func_with_cast_tt(partial(torch_random, low=-50, high=50, dtype=torch.float32), dtype)(b_shape)

    a_tt = ttnn.from_torch(
        a_pt, dtype=dtype, device=device, layout=ttnn.TILE_LAYOUT, memory_config=ttnn.DRAM_MEMORY_CONFIG
    )
    b_tt = ttnn.from_torch(
        b_pt, dtype=dtype, device=device, layout=ttnn.TILE_LAYOUT, memory_config=ttnn.DRAM_MEMORY_CONFIG
    )
    cq_id = 0
    out_tt = ttnn_fn(a_tt, b_tt, queue_id=cq_id)
    tt_out = ttnn.to_torch(out_tt)

    golden_fn = ttnn.get_golden_function(ttnn_fn)
    out_pt = golden_fn(a_pt, b_pt)
    status = ttnn.pearson_correlation_coefficient(out_pt, tt_out)
    assert status >= 0.999


@skip_for_grayskull("Requires wormhole_b0 to run")
@pytest.mark.parametrize(
    "input_shapes",
    (
        (torch.Size([1, 1, 1, 1]), torch.Size([5, 3, 32, 32])),
        (torch.Size([5, 1, 64, 1]), torch.Size([1, 3, 1, 128])),
        (torch.Size([5, 1, 1, 64]), torch.Size([1, 3, 128, 1])),
        (torch.Size([5, 1, 1]), torch.Size([1, 32, 128])),
    ),
)
@pytest.mark.parametrize(
    "ttnn_fn",
    [
        ttnn.experimental.bitwise_and,
        ttnn.experimental.bitwise_or,
        ttnn.experimental.bitwise_xor,
        ttnn.experimental.bitwise_left_shift,
        ttnn.experimental.bitwise_right_shift,
    ],
)
@pytest.mark.parametrize(
    "dtype",
    ([ttnn.int32]),
)
def test_binary_sfpu_bitwise_ops(input_shapes, dtype, ttnn_fn, device):
    a_shape, b_shape = input_shapes

    a_pt = gen_func_with_cast_tt(partial(torch_random, low=-100, high=100, dtype=torch.int32), dtype)(a_shape)
    b_pt = gen_func_with_cast_tt(partial(torch_random, low=0, high=31, dtype=torch.int32), dtype)(b_shape)

    a_tt = ttnn.from_torch(
        a_pt, dtype=dtype, device=device, layout=ttnn.TILE_LAYOUT, memory_config=ttnn.DRAM_MEMORY_CONFIG
    )
    b_tt = ttnn.from_torch(
        b_pt, dtype=dtype, device=device, layout=ttnn.TILE_LAYOUT, memory_config=ttnn.DRAM_MEMORY_CONFIG
    )
    cq_id = 0
    out_tt = ttnn_fn(a_tt, b_tt, queue_id=cq_id)
    tt_out = ttnn.to_torch(out_tt)

    golden_fn = ttnn.get_golden_function(ttnn_fn)
    out_pt = golden_fn(a_pt, b_pt)

    status = ttnn.pearson_correlation_coefficient(out_pt, tt_out)
    assert status >= 0.999
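The bitwise tests above use int32 inputs and keep the second operand in [0, 31], which also keeps the shift amounts within the 32-bit width. The same pattern as a standalone sketch, under the same assumptions (open `device` handle, illustrative shape):

a = torch.randint(-100, 100, (1, 1, 32, 32), dtype=torch.int32)
b = torch.randint(0, 31, (1, 1, 32, 32), dtype=torch.int32)

a_tt = ttnn.from_torch(a, dtype=ttnn.int32, device=device, layout=ttnn.TILE_LAYOUT)
b_tt = ttnn.from_torch(b, dtype=ttnn.int32, device=device, layout=ttnn.TILE_LAYOUT)

out_tt = ttnn.experimental.bitwise_left_shift(a_tt, b_tt, queue_id=0)
out = ttnn.to_torch(out_tt)

golden = ttnn.get_golden_function(ttnn.experimental.bitwise_left_shift)(a, b)
assert ttnn.pearson_correlation_coefficient(golden, out) >= 0.999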
23 changes: 0 additions & 23 deletions tests/ttnn/unit_tests/operations/eltwise/test_binary_fp32.py
@@ -23,7 +23,6 @@ def test_sub_fp32(device, ttnn_function):
z_torch = golden_fn(x_torch, y_torch)
x_tt = ttnn.from_torch(x_torch, dtype=ttnn.float32, layout=ttnn.TILE_LAYOUT, device=device)
y_tt = ttnn.from_torch(y_torch, dtype=ttnn.float32, layout=ttnn.TILE_LAYOUT, device=device)
z_tt = ttnn.from_torch(z_torch, dtype=ttnn.float32, layout=ttnn.TILE_LAYOUT, device=device)
z_tt_sub = ttnn.subtract(x_tt, y_tt)
tt_out = ttnn.to_torch(z_tt_sub)

@@ -45,7 +44,6 @@ def test_rsub_fp32(device, ttnn_function):
z_torch = golden_fn(x_torch, y_torch)
x_tt = ttnn.from_torch(x_torch, dtype=ttnn.float32, layout=ttnn.TILE_LAYOUT, device=device)
y_tt = ttnn.from_torch(y_torch, dtype=ttnn.float32, layout=ttnn.TILE_LAYOUT, device=device)
z_tt = ttnn.from_torch(z_torch, dtype=ttnn.float32, layout=ttnn.TILE_LAYOUT, device=device)
z_tt_sub = ttnn.rsub(x_tt, y_tt)
tt_out = ttnn.to_torch(z_tt_sub)

@@ -67,7 +65,6 @@ def test_add_fp32(device, ttnn_function):
z_torch = golden_fn(x_torch, y_torch)
x_tt = ttnn.from_torch(x_torch, dtype=ttnn.float32, layout=ttnn.TILE_LAYOUT, device=device)
y_tt = ttnn.from_torch(y_torch, dtype=ttnn.float32, layout=ttnn.TILE_LAYOUT, device=device)
z_tt = ttnn.from_torch(z_torch, dtype=ttnn.float32, layout=ttnn.TILE_LAYOUT, device=device)
z_tt_add = ttnn.add(x_tt, y_tt)
tt_out = ttnn.to_torch(z_tt_add)

@@ -89,7 +86,6 @@ def test_add_int32(device, ttnn_function):
z_torch = golden_fn(x_torch, y_torch)
x_tt = ttnn.from_torch(x_torch, dtype=ttnn.int32, layout=ttnn.TILE_LAYOUT, device=device)
y_tt = ttnn.from_torch(y_torch, dtype=ttnn.int32, layout=ttnn.TILE_LAYOUT, device=device)
z_tt = ttnn.from_torch(z_torch, dtype=ttnn.int32, layout=ttnn.TILE_LAYOUT, device=device)
z_tt_add = ttnn.add(x_tt, y_tt)
tt_out = ttnn.to_torch(z_tt_add)

@@ -111,7 +107,6 @@ def test_mul_fp32(device, ttnn_function):
z_torch = golden_fn(x_torch, y_torch)
x_tt = ttnn.from_torch(x_torch, dtype=ttnn.float32, layout=ttnn.TILE_LAYOUT, device=device)
y_tt = ttnn.from_torch(y_torch, dtype=ttnn.float32, layout=ttnn.TILE_LAYOUT, device=device)
z_tt = ttnn.from_torch(z_torch, dtype=ttnn.float32, layout=ttnn.TILE_LAYOUT, device=device)
z_tt_out = ttnn.mul(x_tt, y_tt)
tt_out = ttnn.to_torch(z_tt_out)

@@ -139,7 +134,6 @@ def test_div_fp32(device, ttnn_function):
z_torch = golden_fn(x_torch, y_torch)
x_tt = ttnn.from_torch(x_torch, dtype=ttnn.float32, layout=ttnn.TILE_LAYOUT, device=device)
y_tt = ttnn.from_torch(y_torch, dtype=ttnn.float32, layout=ttnn.TILE_LAYOUT, device=device)
z_tt = ttnn.from_torch(z_torch, dtype=ttnn.float32, layout=ttnn.TILE_LAYOUT, device=device)
z_tt_div = ttnn.divide(x_tt, y_tt)
tt_out = ttnn.to_torch(z_tt_div)

@@ -220,7 +214,6 @@ def test_pow_fp32(device, ttnn_function):
z_torch = golden_fn(x_torch, y_torch)
x_tt = ttnn.from_torch(x_torch, dtype=ttnn.float32, layout=ttnn.TILE_LAYOUT, device=device)
y_tt = ttnn.from_torch(y_torch, dtype=ttnn.float32, layout=ttnn.TILE_LAYOUT, device=device)
z_tt = ttnn.from_torch(z_torch, dtype=ttnn.float32, layout=ttnn.TILE_LAYOUT, device=device)
z_tt_pow = ttnn.pow(x_tt, y_tt)
tt_out = ttnn.to_torch(z_tt_pow)

@@ -241,7 +234,6 @@ def test_add_fp32_activ(device, ttnn_function):
z_torch = torch.square(x_torch + y_torch)
x_tt = ttnn.from_torch(x_torch, dtype=ttnn.float32, layout=ttnn.TILE_LAYOUT, device=device)
y_tt = ttnn.from_torch(y_torch, dtype=ttnn.float32, layout=ttnn.TILE_LAYOUT, device=device)
z_tt = ttnn.from_torch(z_torch, dtype=ttnn.float32, layout=ttnn.TILE_LAYOUT, device=device)
z_tt_add = ttnn.add(x_tt, y_tt, activations=[ttnn.UnaryWithParam(ttnn.UnaryOpType.POWER, 2)])
tt_out = ttnn.to_torch(z_tt_add)

@@ -271,7 +263,6 @@ def test_add_fp32_input_activ(device, ttnn_function, shape):
z_torch = torch.square(torch.nn.functional.silu(x_torch) + y_torch)
x_tt = ttnn.from_torch(x_torch, dtype=ttnn.float32, layout=ttnn.TILE_LAYOUT, device=device)
y_tt = ttnn.from_torch(y_torch, dtype=ttnn.float32, layout=ttnn.TILE_LAYOUT, device=device)
z_tt = ttnn.from_torch(z_torch, dtype=ttnn.float32, layout=ttnn.TILE_LAYOUT, device=device)
z_tt_add = ttnn.add(
x_tt,
y_tt,
@@ -298,7 +289,6 @@ def test_logaddexp_fp32(device, ttnn_function):
z_torch = golden_fn(x_torch, y_torch)
x_tt = ttnn.from_torch(x_torch, dtype=ttnn.float32, layout=ttnn.TILE_LAYOUT, device=device)
y_tt = ttnn.from_torch(y_torch, dtype=ttnn.float32, layout=ttnn.TILE_LAYOUT, device=device)
z_tt = ttnn.from_torch(z_torch, dtype=ttnn.float32, layout=ttnn.TILE_LAYOUT, device=device)
z_tt_out = ttnn.logaddexp(x_tt, y_tt)
tt_out = ttnn.to_torch(z_tt_out)

@@ -320,7 +310,6 @@ def test_logaddexp2_fp32(device, ttnn_function):
z_torch = golden_fn(x_torch, y_torch)
x_tt = ttnn.from_torch(x_torch, dtype=ttnn.float32, layout=ttnn.TILE_LAYOUT, device=device)
y_tt = ttnn.from_torch(y_torch, dtype=ttnn.float32, layout=ttnn.TILE_LAYOUT, device=device)
z_tt = ttnn.from_torch(z_torch, dtype=ttnn.float32, layout=ttnn.TILE_LAYOUT, device=device)
z_tt_out = ttnn.logaddexp2(x_tt, y_tt)
tt_out = ttnn.to_torch(z_tt_out)

@@ -342,7 +331,6 @@ def test_ldexp_fp32(device, ttnn_function):
z_torch = golden_fn(x_torch, y_torch)
x_tt = ttnn.from_torch(x_torch, dtype=ttnn.float32, layout=ttnn.TILE_LAYOUT, device=device)
y_tt = ttnn.from_torch(y_torch, dtype=ttnn.float32, layout=ttnn.TILE_LAYOUT, device=device)
z_tt = ttnn.from_torch(z_torch, dtype=ttnn.float32, layout=ttnn.TILE_LAYOUT, device=device)
z_tt_out = ttnn.ldexp(x_tt, y_tt)
tt_out = ttnn.to_torch(z_tt_out)

@@ -364,7 +352,6 @@ def test_bias_gelu_fp32(device, ttnn_function):
z_torch = golden_fn(x_torch, y_torch)
x_tt = ttnn.from_torch(x_torch, dtype=ttnn.float32, layout=ttnn.TILE_LAYOUT, device=device)
y_tt = ttnn.from_torch(y_torch, dtype=ttnn.float32, layout=ttnn.TILE_LAYOUT, device=device)
z_tt = ttnn.from_torch(z_torch, dtype=ttnn.float32, layout=ttnn.TILE_LAYOUT, device=device)
z_tt_out = ttnn.bias_gelu(x_tt, y_tt)
tt_out = ttnn.to_torch(z_tt_out)

@@ -386,7 +373,6 @@ def test_squared_difference_fp32(device, ttnn_function):
z_torch = golden_fn(x_torch, y_torch)
x_tt = ttnn.from_torch(x_torch, dtype=ttnn.float32, layout=ttnn.TILE_LAYOUT, device=device)
y_tt = ttnn.from_torch(y_torch, dtype=ttnn.float32, layout=ttnn.TILE_LAYOUT, device=device)
z_tt = ttnn.from_torch(z_torch, dtype=ttnn.float32, layout=ttnn.TILE_LAYOUT, device=device)
z_tt_out = ttnn.squared_difference(x_tt, y_tt)
tt_out = ttnn.to_torch(z_tt_out)

@@ -408,7 +394,6 @@ def test_logical_or_fp32(device, ttnn_function):
z_torch = golden_fn(x_torch, y_torch)
x_tt = ttnn.from_torch(x_torch, dtype=ttnn.float32, layout=ttnn.TILE_LAYOUT, device=device)
y_tt = ttnn.from_torch(y_torch, dtype=ttnn.float32, layout=ttnn.TILE_LAYOUT, device=device)
z_tt = ttnn.from_torch(z_torch, dtype=ttnn.float32, layout=ttnn.TILE_LAYOUT, device=device)
z_tt_out = ttnn.logical_or(x_tt, y_tt)
tt_out = ttnn.to_torch(z_tt_out)

@@ -430,7 +415,6 @@ def test_logical_xor_fp32(device, ttnn_function):
z_torch = golden_fn(x_torch, y_torch)
x_tt = ttnn.from_torch(x_torch, dtype=ttnn.float32, layout=ttnn.TILE_LAYOUT, device=device)
y_tt = ttnn.from_torch(y_torch, dtype=ttnn.float32, layout=ttnn.TILE_LAYOUT, device=device)
z_tt = ttnn.from_torch(z_torch, dtype=ttnn.float32, layout=ttnn.TILE_LAYOUT, device=device)
z_tt_out = ttnn.logical_xor(x_tt, y_tt)
tt_out = ttnn.to_torch(z_tt_out)

@@ -452,7 +436,6 @@ def test_logical_and_fp32(device, ttnn_function):
z_torch = golden_fn(x_torch, y_torch)
x_tt = ttnn.from_torch(x_torch, dtype=ttnn.float32, layout=ttnn.TILE_LAYOUT, device=device)
y_tt = ttnn.from_torch(y_torch, dtype=ttnn.float32, layout=ttnn.TILE_LAYOUT, device=device)
z_tt = ttnn.from_torch(z_torch, dtype=ttnn.float32, layout=ttnn.TILE_LAYOUT, device=device)
z_tt_out = ttnn.logical_and(x_tt, y_tt)
tt_out = ttnn.to_torch(z_tt_out)

@@ -479,7 +462,6 @@ def test_relational_fp32(device, ttnn_function):
z_torch = golden_fn(x_torch, y_torch)
x_tt = ttnn.from_torch(x_torch, dtype=ttnn.float32, layout=ttnn.TILE_LAYOUT, device=device)
y_tt = ttnn.from_torch(y_torch, dtype=ttnn.float32, layout=ttnn.TILE_LAYOUT, device=device)
z_tt = ttnn.from_torch(z_torch, dtype=ttnn.float32, layout=ttnn.TILE_LAYOUT, device=device)
z_tt_out = ttnn_function(x_tt, y_tt)
tt_out = ttnn.to_torch(z_tt_out)

@@ -501,7 +483,6 @@ def test_bitwise_and(device, ttnn_function):
z_torch = golden_fn(x_torch, y_torch)
x_tt = ttnn.from_torch(x_torch, dtype=ttnn.int32, layout=ttnn.TILE_LAYOUT, device=device)
y_tt = ttnn.from_torch(y_torch, dtype=ttnn.int32, layout=ttnn.TILE_LAYOUT, device=device)
z_tt = ttnn.from_torch(z_torch, dtype=ttnn.int32, layout=ttnn.TILE_LAYOUT, device=device)
z_tt_out = ttnn.bitwise_and(x_tt, y_tt)
tt_out = ttnn.to_torch(z_tt_out)

@@ -523,7 +504,6 @@ def test_bitwise_or(device, ttnn_function):
z_torch = golden_fn(x_torch, y_torch)
x_tt = ttnn.from_torch(x_torch, dtype=ttnn.int32, layout=ttnn.TILE_LAYOUT, device=device)
y_tt = ttnn.from_torch(y_torch, dtype=ttnn.int32, layout=ttnn.TILE_LAYOUT, device=device)
z_tt = ttnn.from_torch(z_torch, dtype=ttnn.int32, layout=ttnn.TILE_LAYOUT, device=device)
z_tt_out = ttnn.bitwise_or(x_tt, y_tt)
tt_out = ttnn.to_torch(z_tt_out)

@@ -545,7 +525,6 @@ def test_bitwise_xor(device, ttnn_function):
z_torch = golden_fn(x_torch, y_torch)
x_tt = ttnn.from_torch(x_torch, dtype=ttnn.int32, layout=ttnn.TILE_LAYOUT, device=device)
y_tt = ttnn.from_torch(y_torch, dtype=ttnn.int32, layout=ttnn.TILE_LAYOUT, device=device)
z_tt = ttnn.from_torch(z_torch, dtype=ttnn.int32, layout=ttnn.TILE_LAYOUT, device=device)
z_tt_out = ttnn.bitwise_xor(x_tt, y_tt)
tt_out = ttnn.to_torch(z_tt_out)

@@ -567,7 +546,6 @@ def test_bitwise_left_shift(device, ttnn_function):
z_torch = golden_fn(x_torch, y_torch)
x_tt = ttnn.from_torch(x_torch, dtype=ttnn.int32, layout=ttnn.TILE_LAYOUT, device=device)
y_tt = ttnn.from_torch(y_torch, dtype=ttnn.int32, layout=ttnn.TILE_LAYOUT, device=device)
z_tt = ttnn.from_torch(z_torch, dtype=ttnn.int32, layout=ttnn.TILE_LAYOUT, device=device)
z_tt_out = ttnn.bitwise_left_shift(x_tt, y_tt)
tt_out = ttnn.to_torch(z_tt_out)

@@ -589,7 +567,6 @@ def test_bitwise_right_shift(device, ttnn_function):
z_torch = golden_fn(x_torch, y_torch)
x_tt = ttnn.from_torch(x_torch, dtype=ttnn.int32, layout=ttnn.TILE_LAYOUT, device=device)
y_tt = ttnn.from_torch(y_torch, dtype=ttnn.int32, layout=ttnn.TILE_LAYOUT, device=device)
z_tt = ttnn.from_torch(z_torch, dtype=ttnn.int32, layout=ttnn.TILE_LAYOUT, device=device)
z_tt_out = ttnn.bitwise_right_shift(x_tt, y_tt)
tt_out = ttnn.to_torch(z_tt_out)

(Diffs for the remaining 29 changed files are not shown here.)
