#16143: Add validation check and typecast bfloat8_b, bfloat4_b
mcw-anasuya committed Jan 11, 2025
1 parent c06a1a4 commit 4b1f750
Showing 3 changed files with 126 additions and 29 deletions.
130 changes: 104 additions & 26 deletions tests/ttnn/unit_tests/operations/eltwise/test_binary_bcast.py
@@ -32,24 +32,8 @@
"mul",
"div",
"bias_gelu",
"add_",
"sub_",
"mul_",
"gt_",
"lt_",
"lte_",
"gte_",
"eq_",
"ne_",
"logical_and_",
"logical_or_",
"logical_xor_",
"ldexp_",
"logaddexp_",
"logaddexp2_",
"squared_difference_",
"bias_gelu_",
}

activation_fns = {
"EXP": torch.exp,
"GELU": torch.nn.functional.gelu,
@@ -144,22 +128,16 @@ def rand_bf16_gen(shape, device, *, min=0, max=1):
),
*parameters({"add"}, {((), (), (op,)) for op in activation_fns.keys()}),
}.difference(
parameters({"eq", "ne", "ne_"}, {square_lhs, sin_rhs, exp_floor_lhs_exp_rhs, log_lhs_sqrt_abs_post}),
parameters({"eq_"}, {square_lhs, sin_rhs, exp_floor_lhs_exp_rhs}),
parameters({"eq", "ne"}, {square_lhs, sin_rhs, exp_floor_lhs_exp_rhs, log_lhs_sqrt_abs_post}),
parameters({"logaddexp", "logaddexp2"}, {floor_lhs_ceil_rhs_cos_post}),
parameters({"gte", "lt", "lte", "lt_"}, {exp_floor_lhs_exp_rhs, log_lhs_sqrt_abs_post}),
parameters({"lte_"}, {sin_rhs, log_lhs_sqrt_abs_post}),
parameters({"gte_"}, {exp_floor_lhs_exp_rhs}),
parameters({"gt_"}, {sin_rhs}),
parameters({"logical_and", "logical_or", "logical_xor", "bias_gelu", "logical_or_"}, {log_lhs_sqrt_abs_post}),
parameters({"gte", "lt", "lte"}, {exp_floor_lhs_exp_rhs, log_lhs_sqrt_abs_post}),
parameters({"logical_and", "logical_or", "logical_xor", "bias_gelu"}, {log_lhs_sqrt_abs_post}),
parameters({"div"}, {exp_post, tanh_post, exp2_post, expm1_post, i0_post, tan_post}),
parameters({"sub"}, {log_post, log2_post, log10_post}),
parameters({"ldexp"}, {erfinv_post, tan_post, floor_post, ceil_post}),
parameters({"squared_difference"}, {erfinv_post, i0_post}),
parameters({"bias_gelu_"}, {square_lhs, log_lhs_sqrt_abs_post}),
parameters({"add"}, {tan_post, tanh_post}),
{("mul", log_lhs_sqrt_abs_post)},
{("mul_", log_lhs_sqrt_abs_post)},
),
)
def test_binary_scalar_ops(a_shape, b_shape, ttnn_fn, activations, device):
@@ -198,6 +176,106 @@ def compare(tt, pt):
assert compare([out_tt], [out_pt])


binary_inplace_fns = {
"add_",
"sub_",
"mul_",
"gt_",
"lt_",
"lte_",
"gte_",
"eq_",
"ne_",
"logical_and_",
"logical_or_",
"logical_xor_",
"ldexp_",
"logaddexp_",
"logaddexp2_",
"squared_difference_",
"bias_gelu_",
}


@pytest.mark.parametrize(
"a_shape, b_shape",
(
(torch.Size([5, 3, 128, 64]), torch.Size([1, 3, 128, 1])),
(torch.Size([5, 3, 32, 32]), torch.Size([1, 1, 1, 1])),
(torch.Size([5, 1, 1, 128]), torch.Size([5, 1, 1, 1])),
),
)
@pytest.mark.parametrize(
"ttnn_fn, activations",
{
*parameters(
binary_inplace_fns,
{
no_activations,
square_lhs,
sin_rhs,
floor_lhs_ceil_rhs_cos_post,
exp_floor_lhs_exp_rhs,
log_lhs_sqrt_abs_post,
},
)
}.difference(
parameters({"eq_", "ne_"}, {square_lhs, sin_rhs, exp_floor_lhs_exp_rhs}),
parameters({"lt_", "gte_"}, {exp_floor_lhs_exp_rhs}),
parameters({"lte_"}, {sin_rhs, log_lhs_sqrt_abs_post}),
parameters({"bias_gelu_"}, {log_lhs_sqrt_abs_post}),
parameters({"mul_"}, {log_lhs_sqrt_abs_post}),
),
)
def test_inplace_binary_scalar_ops(a_shape, b_shape, ttnn_fn, activations, device):
torch.manual_seed(0)

torch_input_tensor_a, input_tensor_a = rand_bf16_gen(a_shape, device)
torch_input_tensor_b, input_tensor_b = rand_bf16_gen(b_shape, device)
ttnn_op = getattr(ttnn.experimental, ttnn_fn)
lhs, rhs, post = ([getattr(ttnn.UnaryOpType, op) for op in ops] for ops in activations)
golden_lhs, golden_rhs, golden_post = ((activation_fns[op] for op in ops) for ops in activations)

input_tensor_a = ttnn.from_torch(
torch_input_tensor_a, device=device, layout=ttnn.TILE_LAYOUT, memory_config=ttnn.DRAM_MEMORY_CONFIG
)
input_tensor_b = ttnn.from_torch(
torch_input_tensor_b, device=device, layout=ttnn.TILE_LAYOUT, memory_config=ttnn.DRAM_MEMORY_CONFIG
)

for golden_activation in golden_lhs:
torch_input_tensor_a = golden_activation(torch_input_tensor_a).bfloat16()

for golden_activation in golden_rhs:
torch_input_tensor_b = golden_activation(torch_input_tensor_b).bfloat16()

golden_fn = ttnn.get_golden_function(ttnn_op)
torch_output_tensor = golden_fn(torch_input_tensor_a, torch_input_tensor_b).bfloat16()

for golden_activation in golden_post:
torch_output_tensor = golden_activation(torch_output_tensor).bfloat16()

ttnn_op(input_tensor_a, input_tensor_b, lhs_activations=lhs, rhs_activations=rhs, post_activations=post)
output_tensor = ttnn.to_torch(input_tensor_a)
assert output_tensor.shape == torch_output_tensor.shape

def compare(output_tensor, torch_output_tensor):
imprecise_cases = {
*parameters(
{"logaddexp2_"}, {exp_floor_lhs_exp_rhs, no_activations, sin_rhs, log_lhs_sqrt_abs_post, square_lhs}
),
*parameters({"bias_gelu_"}, {no_activations, sin_rhs, square_lhs}),
*parameters({"gt_", "lte_", "gte_", "lt_"}, {sin_rhs, square_lhs}),
}
return (
ttnn.pearson_correlation_coefficient(torch_output_tensor, output_tensor) >= 0.98
if (ttnn_fn, activations) in imprecise_cases
else ttnn.pearson_correlation_coefficient(torch_output_tensor, output_tensor) >= 0.999
)

assert compare(output_tensor, torch_output_tensor)


@pytest.mark.parametrize(
"a_shape, b_shape",
(
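A note on the test parametrizations above: the parameters helper is defined earlier in test_binary_bcast.py and is not part of this diff. A minimal sketch of the pattern it presumably implements (hypothetical reconstruction, not the repository's exact code), shown only to make the set-difference exclusions easier to read:

# Hypothetical sketch: pair every op name with every activation tuple, so that
# .difference(...) above can drop known-imprecise (op, activations) combinations.
def parameters(op_names, activation_tuples):
    return {(op, activations) for op in op_names for activations in activation_tuples}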
15 changes: 12 additions & 3 deletions ttnn/cpp/ttnn/operations/eltwise/binary_ng/binary_ng.cpp
@@ -114,8 +114,14 @@ Tensor InplaceBinaryNg<binary_op_type>::invoke(
tt::stl::Span<const ttnn::operations::unary::UnaryOpType> lhs_activations,
tt::stl::Span<const ttnn::operations::unary::UnaryOpType> rhs_activations,
tt::stl::Span<const ttnn::operations::unary::UnaryOpType> post_activations) {
auto input_a = typecast_to(DataType::BFLOAT16, input_tensor_a);
auto input_b = typecast_to(DataType::BFLOAT16, input_tensor_b);
auto input_a =
(input_tensor_a.get_dtype() == DataType::BFLOAT8_B || input_tensor_a.get_dtype() == DataType::BFLOAT4_B)
? typecast_to(DataType::BFLOAT16, input_tensor_a)
: input_tensor_a;
auto input_b =
(input_tensor_b.get_dtype() == DataType::BFLOAT8_B || input_tensor_b.get_dtype() == DataType::BFLOAT4_B)
? typecast_to(DataType::BFLOAT16, input_tensor_b)
: input_tensor_b;

return ttnn::prim::binary_ng(
queue_id,
@@ -160,7 +166,10 @@ Tensor InplaceBinaryNg<binary_op_type>::invoke(
tt::stl::Span<const ttnn::operations::unary::UnaryOpType> lhs_activations,
tt::stl::Span<const ttnn::operations::unary::UnaryOpType> rhs_activations,
tt::stl::Span<const ttnn::operations::unary::UnaryOpType> post_activations) {
auto input_a = typecast_to(DataType::BFLOAT16, input_tensor_a);
auto input_a =
(input_tensor_a.get_dtype() == DataType::BFLOAT8_B || input_tensor_a.get_dtype() == DataType::BFLOAT4_B)
? typecast_to(DataType::BFLOAT16, input_tensor_a)
: input_tensor_a;

return ttnn::prim::binary_ng(
queue_id,
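With the guarded typecast above, BFLOAT8_B and BFLOAT4_B inputs are upcast to BFLOAT16 before ttnn::prim::binary_ng runs, while all other dtypes now pass through without a typecast. A minimal usage sketch from the Python side (illustrative only; it assumes ttnn.experimental.add_ behaves as exercised by the tests above and that a device is available):

import torch
import ttnn

device = ttnn.open_device(device_id=0)

# Block-float inputs to an in-place binary op: after this commit the op
# upcasts bfloat8_b/bfloat4_b to bfloat16 internally rather than
# typecasting every input unconditionally.
a = ttnn.from_torch(torch.rand(1, 1, 32, 32), dtype=ttnn.bfloat8_b, layout=ttnn.TILE_LAYOUT, device=device)
b = ttnn.from_torch(torch.rand(1, 1, 32, 32), dtype=ttnn.bfloat8_b, layout=ttnn.TILE_LAYOUT, device=device)

ttnn.experimental.add_(a, b)  # result is written back into `a`
result = ttnn.to_torch(a)

ttnn.close_device(device)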
@@ -164,6 +164,16 @@ BinaryNgDeviceOperation::program_factory_t BinaryNgDeviceOperation::select_progr
BinaryNgDeviceOperation::tensor_return_value_t BinaryNgDeviceOperation::create_output_tensors(
const operation_attributes_t& operation_attributes, const tensor_args_t& tensor_args) {
const auto& output_tensor = tensor_args.output_tensor;
const auto& input_tensor_a = tensor_args.input_tensor_a;
if (output_tensor.has_value()) {
if (input_tensor_a.buffer()->address() == output_tensor->buffer()->address()) {
const auto& output_shape = compute_output_specs(operation_attributes, tensor_args);
TT_FATAL(
input_tensor_a.logical_shape() == output_shape.logical_shape(),
"In-place operation rule violation: Input tensor does not support output tensor shape.");
}
}

if (output_tensor.has_value()) {
return output_tensor.value();
}
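The validation added above only applies when a preallocated output tensor aliases input_tensor_a (the in-place path) and the computed broadcast shape no longer matches input_tensor_a's logical shape. A hedged illustration of the case it rejects, assuming the TT_FATAL surfaces as a RuntimeError through the Python bindings:

import pytest
import torch
import ttnn

device = ttnn.open_device(device_id=0)

# In-place add where broadcasting a against b would require an output larger
# than a itself -- the new shape check is expected to reject this.
a = ttnn.from_torch(torch.rand(1, 1, 1, 1).bfloat16(), layout=ttnn.TILE_LAYOUT, device=device)
b = ttnn.from_torch(torch.rand(1, 3, 32, 32).bfloat16(), layout=ttnn.TILE_LAYOUT, device=device)

with pytest.raises(RuntimeError):
    ttnn.experimental.add_(a, b)  # broadcast output [1, 3, 32, 32] cannot fit in `a`

ttnn.close_device(device)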
