
Commit

#16871: Update conditions
mcw-anasuya committed Jan 31, 2025
1 parent 1b22adb commit 37a278c
Showing 2 changed files with 18 additions and 29 deletions.
@@ -53,6 +53,7 @@
)
# No typecast on inputs and optional output
def test_opt_output_no_typecast(input_shapes, dtype, ttnn_fn, device):
+ torch.manual_seed(0)
a_shape, b_shape, out_shape = input_shapes
ttnn_op = getattr(ttnn.experimental, ttnn_fn)

@@ -80,7 +81,7 @@ def test_opt_output_no_typecast(input_shapes, dtype, ttnn_fn, device):
golden_fn = ttnn.get_golden_function(ttnn_op)
torch_output_tensor = golden_fn(torch_input_tensor_a, torch_input_tensor_b)
status = ttnn.pearson_correlation_coefficient(torch_output_tensor, output_tensor)
- assert status >= 0.99
+ assert status >= 0.999


@skip_for_grayskull("Requires wormhole_b0 to run")
@@ -103,6 +104,7 @@ def test_opt_output_no_typecast(input_shapes, dtype, ttnn_fn, device):
)
# Typecast on both inputs and optional output
def test_opt_output_bf8b(input_shapes, dtype, ttnn_fn, device):
+ torch.manual_seed(0)
a_shape, b_shape, out_shape = input_shapes
ttnn_op = getattr(ttnn.experimental, ttnn_fn)

@@ -629,6 +631,7 @@ def test_inplace_sub_typecast_b(input_shapes, device):
@pytest.mark.parametrize("scalar", [-0.25, -16.5, 0.0, 0.05, 1.7, 19.0])
# Typecast on both input and optional tensor
def test_opt_output_scalar(input_shapes, ttnn_fn, scalar, device):
+ torch.manual_seed(0)
a_shape, out_shape = input_shapes
ttnn_op = getattr(ttnn.experimental, ttnn_fn)

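For context, the test-side change is small: each affected test now seeds torch's RNG before generating random inputs, and the Pearson correlation threshold is tightened from 0.99 to 0.999, which is only meaningful once the inputs are reproducible. A minimal standalone sketch of that pattern follows; the op name "add", the tensor shape, and the ttnn.from_torch conversion arguments are illustrative assumptions, not values taken from the test file.

import torch
import ttnn


def check_binary_op(device, ttnn_fn="add", shape=(1, 1, 32, 32)):
    # Fixed seed: the tightened 0.999 bound is stable only with reproducible inputs.
    torch.manual_seed(0)
    torch_a = torch.rand(shape, dtype=torch.bfloat16)
    torch_b = torch.rand(shape, dtype=torch.bfloat16)

    # Same op lookup the tests use; "add" is a placeholder for the parametrized ttnn_fn.
    ttnn_op = getattr(ttnn.experimental, ttnn_fn)
    a = ttnn.from_torch(torch_a, layout=ttnn.TILE_LAYOUT, device=device)
    b = ttnn.from_torch(torch_b, layout=ttnn.TILE_LAYOUT, device=device)
    output_tensor = ttnn_op(a, b)

    # Golden reference and PCC comparison, mirroring the updated assertion above.
    golden_fn = ttnn.get_golden_function(ttnn_op)
    torch_expected = golden_fn(torch_a, torch_b)
    status = ttnn.pearson_correlation_coefficient(torch_expected, output_tensor)
    assert status >= 0.999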
ttnn/cpp/ttnn/operations/eltwise/binary_ng/binary_ng.cpp (42 changes: 14 additions & 28 deletions)
@@ -31,11 +31,11 @@ Tensor BinaryNg<binary_op_type>::invoke(
const ttnn::DataType b_dtype = input_tensor_b.get_dtype();
const bool output_preallocated = optional_output_tensor.has_value();
const ttnn::DataType out_dtype =
- output_preallocated ? optional_output_tensor.value().get_dtype() : output_dtype.value_or(a_dtype);
+ output_preallocated ? optional_output_tensor->get_dtype() : output_dtype.value_or(a_dtype);

if (output_dtype.has_value() && output_preallocated) {
TT_FATAL(
- output_dtype.value() == out_dtype,
+ *output_dtype == out_dtype,
"If both output dtype and output tensor are provided, their dtypes should match");
}

@@ -50,7 +50,7 @@ Tensor BinaryNg<binary_op_type>::invoke(
input_tensor_b,
binary_op_type,
out_dtype,
- output_preallocated ? optional_output_tensor.value().memory_config()
+ output_preallocated ? optional_output_tensor->memory_config()
: memory_config.value_or(input_tensor_a.memory_config()),
optional_output_tensor,
lhs_activations,
@@ -59,11 +59,9 @@
} else {
Tensor input_a = typecast_to(DataType::BFLOAT16, input_tensor_a);
Tensor input_b = typecast_to(DataType::BFLOAT16, input_tensor_b);
- std::optional<Tensor> output_tensor =
- (output_preallocated && !typecast_out) ? optional_output_tensor
- : (output_preallocated && typecast_out)
- ? std::make_optional(ttnn::typecast(optional_output_tensor.value(), DataType::BFLOAT16))
- : std::nullopt;
+ const auto output_tensor = output_preallocated and typecast_out
+ ? ttnn::typecast(*optional_output_tensor, DataType::BFLOAT16)
+ : optional_output_tensor;

Tensor result = ttnn::prim::binary_ng(
queue_id,
@@ -77,12 +75,7 @@
rhs_activations,
post_activations);

- if (output_preallocated && typecast_out) {
- return ttnn::typecast(result, out_dtype, std::nullopt, optional_output_tensor);
- } else if (typecast_out) {
- return ttnn::typecast(result, out_dtype);
- }
- return (output_preallocated && !typecast_out) ? optional_output_tensor.value() : result;
+ return typecast_out ? ttnn::typecast(result, out_dtype, std::nullopt, optional_output_tensor) : result;
}
}
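The recurring edit in this file is mechanical: optional members are read through operator-> and operator* instead of .value(), and the nested ternary that picked the optional output tensor collapses to a single condition. A minimal sketch of that selection step, using stand-in types (Tensor, typecast, and the dtype string here are placeholders, not the repository's API):

#include <optional>
#include <string>

struct Tensor {
    std::string dtype;
};

// Stand-in for a typecast helper: returns a copy of the tensor with the requested dtype.
Tensor typecast(const Tensor& t, const std::string& dtype) {
    Tensor out = t;
    out.dtype = dtype;
    return out;
}

// Post-commit shape of the output-tensor selection: only one case builds a new optional;
// the "no preallocated output" and "no typecast" cases forward the optional as-is.
std::optional<Tensor> select_output(const std::optional<Tensor>& optional_output_tensor, bool typecast_out) {
    const bool output_preallocated = optional_output_tensor.has_value();
    return output_preallocated && typecast_out
               ? std::make_optional(typecast(*optional_output_tensor, "BFLOAT16"))  // *opt rather than .value()
               : optional_output_tensor;
}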

@@ -122,11 +115,11 @@ Tensor BinaryNg<binary_op_type>::invoke(
const ttnn::DataType a_dtype = input_tensor_a.get_dtype();
const bool output_preallocated = optional_output_tensor.has_value();
const ttnn::DataType out_dtype =
- output_preallocated ? optional_output_tensor.value().get_dtype() : output_dtype.value_or(a_dtype);
+ output_preallocated ? optional_output_tensor->get_dtype() : output_dtype.value_or(a_dtype);

if (output_dtype.has_value() && output_preallocated) {
TT_FATAL(
- output_dtype.value() == out_dtype,
+ *output_dtype == out_dtype,
"If both output dtype and output tensor are provided, their dtypes should match");
}

@@ -140,19 +133,17 @@
scalar,
binary_op_type,
out_dtype,
- output_preallocated ? optional_output_tensor.value().memory_config()
+ output_preallocated ? optional_output_tensor->memory_config()
: memory_config.value_or(input_tensor_a.memory_config()),
optional_output_tensor,
lhs_activations,
rhs_activations,
post_activations);
} else {
Tensor input_a = typecast_to(DataType::BFLOAT16, input_tensor_a);
- std::optional<Tensor> output_tensor =
- (output_preallocated && !typecast_out) ? optional_output_tensor
- : (output_preallocated && typecast_out)
- ? std::make_optional(ttnn::typecast(optional_output_tensor.value(), DataType::BFLOAT16))
- : std::nullopt;
+ const auto output_tensor = output_preallocated and typecast_out
+ ? ttnn::typecast(*optional_output_tensor, DataType::BFLOAT16)
+ : optional_output_tensor;

Tensor result = ttnn::prim::binary_ng(
queue_id,
@@ -166,12 +157,7 @@
rhs_activations,
post_activations);

- if (output_preallocated && typecast_out) {
- return ttnn::typecast(result, out_dtype, std::nullopt, optional_output_tensor);
- } else if (typecast_out) {
- return ttnn::typecast(result, out_dtype);
- }
- return (output_preallocated && !typecast_out) ? optional_output_tensor.value() : result;
+ return typecast_out ? ttnn::typecast(result, out_dtype, std::nullopt, optional_output_tensor) : result;
}
}
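Both overloads end with the same simplification: the old three-branch return (typecast into the preallocated output, typecast a fresh result, or return the preallocated or plain result) becomes one ternary, because passing the possibly-empty optional through to the typecast covers the first two cases, and in the no-typecast path the preallocated tensor was already handed to the primitive as its output, so returning the primitive's result covers the old preallocated case as well. A small sketch of that equivalence, with a stand-in typecast whose optional-output contract is assumed rather than taken from ttnn:

#include <cassert>
#include <optional>
#include <string>

struct Tensor {
    std::string dtype;
};

// Assumed contract, mirroring how the optional output argument is used above:
// write into `out` when it is provided, otherwise return a fresh tensor.
Tensor typecast(const Tensor& src, const std::string& dtype, std::optional<Tensor> out = std::nullopt) {
    if (out) {
        out->dtype = dtype;
        return *out;
    }
    Tensor copy = src;
    copy.dtype = dtype;
    return copy;
}

Tensor finish(const Tensor& result, const std::string& out_dtype, bool typecast_out,
              const std::optional<Tensor>& optional_output_tensor) {
    // Old control flow:
    //   if (preallocated && typecast_out)  -> typecast(result, out_dtype, optional_output_tensor)
    //   else if (typecast_out)             -> typecast(result, out_dtype)
    //   else                               -> preallocated ? *optional_output_tensor : result
    // New control flow: one ternary; the empty-or-not optional folds the first two branches.
    return typecast_out ? typecast(result, out_dtype, optional_output_tensor) : result;
}

int main() {
    Tensor r{"bfloat16"};
    assert(finish(r, "float32", true, std::nullopt).dtype == "float32");    // fresh typecast
    assert(finish(r, "bfloat16", false, std::nullopt).dtype == "bfloat16"); // pass-through
    return 0;
}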

