From 023014bed4b242d9ec95cfea9e0ec2c46035efe0 Mon Sep 17 00:00:00 2001 From: mcw-anasuya Date: Sat, 11 Jan 2025 00:59:38 +0000 Subject: [PATCH] #16143: Add validation check and typecast bfloat8_b, bfloat4_b --- .../operations/eltwise/test_binary_bcast.py | 262 ++++++++++++++++-- .../eltwise/binary_ng/binary_ng.cpp | 146 ++++++---- .../eltwise/binary_ng/binary_ng.hpp | 19 +- .../eltwise/binary_ng/binary_ng_pybind.cpp | 15 +- .../device/binary_ng_device_operation.cpp | 12 + 5 files changed, 350 insertions(+), 104 deletions(-) diff --git a/tests/ttnn/unit_tests/operations/eltwise/test_binary_bcast.py b/tests/ttnn/unit_tests/operations/eltwise/test_binary_bcast.py index cd869ae24e8..4e76ab1ea33 100644 --- a/tests/ttnn/unit_tests/operations/eltwise/test_binary_bcast.py +++ b/tests/ttnn/unit_tests/operations/eltwise/test_binary_bcast.py @@ -32,23 +32,6 @@ "mul", "div", "bias_gelu", - "add_", - "sub_", - "mul_", - "gt_", - "lt_", - "lte_", - "gte_", - "eq_", - "ne_", - "logical_and_", - "logical_or_", - "logical_xor_", - "ldexp_", - "logaddexp_", - "logaddexp2_", - "squared_difference_", - "bias_gelu_", } activation_fns = { "EXP": torch.exp, @@ -144,22 +127,16 @@ def rand_bf16_gen(shape, device, *, min=0, max=1): ), *parameters({"add"}, {((), (), (op,)) for op in activation_fns.keys()}), }.difference( - parameters({"eq", "ne", "ne_"}, {square_lhs, sin_rhs, exp_floor_lhs_exp_rhs, log_lhs_sqrt_abs_post}), - parameters({"eq_"}, {square_lhs, sin_rhs, exp_floor_lhs_exp_rhs}), + parameters({"eq", "ne"}, {square_lhs, sin_rhs, exp_floor_lhs_exp_rhs, log_lhs_sqrt_abs_post}), parameters({"logaddexp", "logaddexp2"}, {floor_lhs_ceil_rhs_cos_post}), - parameters({"gte", "lt", "lte", "lt_"}, {exp_floor_lhs_exp_rhs, log_lhs_sqrt_abs_post}), - parameters({"lte_"}, {sin_rhs, log_lhs_sqrt_abs_post}), - parameters({"gte_"}, {exp_floor_lhs_exp_rhs}), - parameters({"gt_"}, {sin_rhs}), - parameters({"logical_and", "logical_or", "logical_xor", "bias_gelu", "logical_or_"}, {log_lhs_sqrt_abs_post}), + parameters({"gte", "lt", "lte"}, {exp_floor_lhs_exp_rhs, log_lhs_sqrt_abs_post}), + parameters({"logical_and", "logical_or", "logical_xor", "bias_gelu"}, {log_lhs_sqrt_abs_post}), parameters({"div"}, {exp_post, tanh_post, exp2_post, expm1_post, i0_post, tan_post}), parameters({"sub"}, {log_post, log2_post, log10_post}), parameters({"ldexp"}, {erfinv_post, tan_post, floor_post, ceil_post}), parameters({"squared_difference"}, {erfinv_post, i0_post}), - parameters({"bias_gelu_"}, {square_lhs, log_lhs_sqrt_abs_post}), parameters({"add"}, {tan_post, tanh_post}), {("mul", log_lhs_sqrt_abs_post)}, - {("mul_", log_lhs_sqrt_abs_post)}, ), ) def test_binary_scalar_ops(a_shape, b_shape, ttnn_fn, activations, device): @@ -276,3 +253,236 @@ def test_01_volume_tensors(device, a, b, c_golden, memory_config): c = ttnn.to_torch(ttnn_c).reshape((-1)) assert c.tolist() == c_golden + + +binary_inplace_fns = { + "add_", + "sub_", + "mul_", + "div_", + "gt_", + "lt_", + "lte_", + "gte_", + "eq_", + "ne_", + "logical_and_", + "logical_or_", + "logical_xor_", + "ldexp_", + "logaddexp_", + "logaddexp2_", + "squared_difference_", + "bias_gelu_", +} + + +@pytest.mark.parametrize( + "a_shape, b_shape", + ( + (torch.Size([5, 3, 128, 64]), torch.Size([1, 3, 128, 1])), + (torch.Size([5, 3, 32, 32]), torch.Size([1, 1, 1, 1])), + (torch.Size([5, 1, 1, 128]), torch.Size([5, 1, 1, 1])), + ), +) +@pytest.mark.parametrize( + "ttnn_fn, activations", + { + *parameters( + binary_inplace_fns, + { + no_activations, + square_lhs, + sin_rhs, + 
floor_lhs_ceil_rhs_cos_post, + exp_floor_lhs_exp_rhs, + log_lhs_sqrt_abs_post, + }, + ) + }.difference( + parameters({"eq_", "ne_"}, {square_lhs, sin_rhs, exp_floor_lhs_exp_rhs}), + parameters({"lt_", "gte_"}, {exp_floor_lhs_exp_rhs}), + parameters({"lte_"}, {sin_rhs, log_lhs_sqrt_abs_post}), + parameters({"bias_gelu_"}, {log_lhs_sqrt_abs_post}), + parameters({"mul_"}, {log_lhs_sqrt_abs_post}), + ), +) +@skip_for_grayskull("Possible accuracy issues with grayskull") +def test_inplace_binary_ops_with_tensor(a_shape, b_shape, ttnn_fn, activations, device): + torch.manual_seed(0) + + ttnn_op = getattr(ttnn.experimental, ttnn_fn) + lhs, rhs, post = ([getattr(ttnn.UnaryOpType, op) for op in ops] for ops in activations) + golden_lhs, golden_rhs, golden_post = ((activation_fns[op] for op in ops) for ops in activations) + min, max = (1, 0) if ttnn_fn == "div_" else (0, 1) + + torch_input_tensor_a, input_tensor_a = rand_bf16_gen(a_shape, device) + torch_input_tensor_b, input_tensor_b = rand_bf16_gen(b_shape, device, min=min, max=max) + + input_tensor_a = ttnn.from_torch( + torch_input_tensor_a, device=device, layout=ttnn.TILE_LAYOUT, memory_config=ttnn.DRAM_MEMORY_CONFIG + ) + input_tensor_b = ttnn.from_torch( + torch_input_tensor_b, device=device, layout=ttnn.TILE_LAYOUT, memory_config=ttnn.DRAM_MEMORY_CONFIG + ) + + for golden_activation in golden_lhs: + torch_input_tensor_a = golden_activation(torch_input_tensor_a).bfloat16() + + for golden_activation in golden_rhs: + torch_input_tensor_b = golden_activation(torch_input_tensor_b).bfloat16() + + golden_fn = ttnn.get_golden_function(ttnn_op) + torch_output_tensor = golden_fn(torch_input_tensor_a, torch_input_tensor_b).bfloat16() + + for golden_activation in golden_post: + torch_output_tensor = golden_activation(torch_output_tensor).bfloat16() + + ttnn_op(input_tensor_a, input_tensor_b, lhs_activations=lhs, rhs_activations=rhs, post_activations=post) + output_tensor = ttnn.to_torch(input_tensor_a) + assert output_tensor.shape == torch_output_tensor.shape + + def compare(output_tensor, torch_output_tensor): + imprecise_cases = { + *parameters( + {"logaddexp2_"}, {exp_floor_lhs_exp_rhs, no_activations, sin_rhs, log_lhs_sqrt_abs_post, square_lhs} + ), + *parameters({"bias_gelu_"}, {no_activations, sin_rhs, square_lhs}), + *parameters({"gt_", "lte_", "gte_", "lt_"}, {sin_rhs, square_lhs}), + } + return ( + ttnn.pearson_correlation_coefficient(torch_output_tensor, output_tensor) >= 0.98 + if (ttnn_fn, activations) in imprecise_cases + else ttnn.pearson_correlation_coefficient(torch_output_tensor, output_tensor) >= 0.999 + ) + + assert compare(output_tensor, torch_output_tensor) + + +@pytest.mark.parametrize( + "a_shape, b_shape", + ( + (torch.Size([5, 3, 128, 64]), torch.Size([1, 3, 128, 1])), + (torch.Size([5, 3, 32, 32]), torch.Size([1, 1, 1, 1])), + (torch.Size([5, 1, 1, 128]), torch.Size([5, 1, 1, 1])), + (torch.Size([1, 71, 7, 7]), torch.Size([7, 7])), + (torch.Size([920, 1, 256]), torch.Size([256])), + (torch.Size([4, 12, 64, 64]), torch.Size([12, 1, 1])), + ), +) +@pytest.mark.parametrize("input_dtype", [ttnn.bfloat4_b, ttnn.bfloat8_b]) +@pytest.mark.parametrize("ttnn_fn", ["add_", "sub_", "mul_"]) +@skip_for_grayskull("Possible accuracy issues with grayskull") +def test_inplace_add_bf4b_bf8b(a_shape, b_shape, input_dtype, ttnn_fn, device): + torch.manual_seed(0) + + torch_input_tensor_a, input_tensor_a = rand_bf16_gen(a_shape, device, min=-1e3, max=1e3) + torch_input_tensor_b, input_tensor_b = rand_bf16_gen(b_shape, device, min=-1e3, max=1e3) + 
ttnn_op = getattr(ttnn.experimental, ttnn_fn) + input_tensor_a = ttnn.from_torch( + torch_input_tensor_a, + dtype=input_dtype, + device=device, + layout=ttnn.TILE_LAYOUT, + memory_config=ttnn.DRAM_MEMORY_CONFIG, + ) + input_tensor_b = ttnn.from_torch( + torch_input_tensor_b, + dtype=input_dtype, + device=device, + layout=ttnn.TILE_LAYOUT, + memory_config=ttnn.DRAM_MEMORY_CONFIG, + ) + + golden_function = ttnn.get_golden_function(ttnn_op) + torch_output_tensor = golden_function(torch_input_tensor_a, torch_input_tensor_b) + + ttnn_op(input_tensor_a, input_tensor_b) + output_tensor = ttnn.to_torch(input_tensor_a) + assert output_tensor.shape == torch_output_tensor.shape + + def compare(output_tensor, torch_output_tensor, ttnn_fn, input_dtype): + imprecise_cases = {"add_": {ttnn.bfloat4_b}, "sub_": {ttnn.bfloat4_b}, "mul_": {ttnn.bfloat4_b}} + if ttnn_fn in imprecise_cases and input_dtype in imprecise_cases[ttnn_fn]: + return ttnn.pearson_correlation_coefficient(torch_output_tensor, output_tensor) >= 0.97 + else: + return ttnn.pearson_correlation_coefficient(torch_output_tensor, output_tensor) >= 0.999 + + assert compare(output_tensor, torch_output_tensor, ttnn_fn, input_dtype) + + +@pytest.mark.parametrize( + "a_shape, b_shape", + ( + (torch.Size([1, 3, 128, 1]), torch.Size([5, 3, 128, 64])), + (torch.Size([1, 1, 1, 1]), torch.Size([5, 3, 32, 32])), + (torch.Size([5, 1, 1, 64]), torch.Size([1, 3, 128, 1])), + (torch.Size([16, 1]), torch.Size([1, 1, 32])), + ), +) +@pytest.mark.parametrize( + "ttnn_fn", + binary_inplace_fns, +) +def test_inplace_binary_scalar_ops_invalid_bcast(a_shape, b_shape, ttnn_fn, device): + torch.manual_seed(0) + ttnn_op = getattr(ttnn.experimental, ttnn_fn) + + _, input_tensor_a = rand_bf16_gen(a_shape, device) + _, input_tensor_b = rand_bf16_gen(b_shape, device) + + with pytest.raises(RuntimeError) as e: + cq_id = 0 + ttnn_op(input_tensor_a, input_tensor_b, queue_id=cq_id) + assert "In-place operation rule violation" in str(e.value) + + +@pytest.mark.parametrize( + "ttnn_fn", + [ + "add_", + "sub_", + "mul_", + "div_", + "gt_", + "lt_", + "lte_", + "gte_", + "eq_", + "ne_", + "squared_difference_", + ], +) +@pytest.mark.parametrize( + "a_shape", + ( + torch.Size([5, 3, 128, 64]), + torch.Size([1, 1, 1, 1]), + torch.Size([5, 3, 32, 32]), + torch.Size([16, 1]), + torch.Size([1, 1, 32]), + torch.Size([920, 1, 256]), + ), +) +@skip_for_grayskull("Possible accuracy issues with grayskull") +@pytest.mark.parametrize("scalar", [-0.25, -16.5, 0.0, 0.05, 1.7, 19.0]) +def test_inplace_binary_with_scalar(a_shape, scalar, ttnn_fn, device): + torch.manual_seed(0) + + ttnn_op = getattr(ttnn.experimental, ttnn_fn) + torch_input_tensor_a, input_tensor_a = rand_bf16_gen(a_shape, device) + input_tensor_a = ttnn.from_torch( + torch_input_tensor_a, + dtype=ttnn.bfloat16, + device=device, + layout=ttnn.TILE_LAYOUT, + memory_config=ttnn.DRAM_MEMORY_CONFIG, + ) + + golden_function = ttnn.get_golden_function(ttnn_op) + torch_output_tensor = golden_function(torch_input_tensor_a, scalar) + + ttnn_op(input_tensor_a, scalar) + output_tensor = ttnn.to_torch(input_tensor_a) + assert output_tensor.shape == torch_output_tensor.shape + assert ttnn.pearson_correlation_coefficient(torch_output_tensor, output_tensor) >= 0.99 diff --git a/ttnn/cpp/ttnn/operations/eltwise/binary_ng/binary_ng.cpp b/ttnn/cpp/ttnn/operations/eltwise/binary_ng/binary_ng.cpp index 93dee3e6cdd..b536f4d891e 100644 --- a/ttnn/cpp/ttnn/operations/eltwise/binary_ng/binary_ng.cpp +++ 
b/ttnn/cpp/ttnn/operations/eltwise/binary_ng/binary_ng.cpp @@ -5,6 +5,15 @@ #include "binary_ng.hpp" #include "device/binary_ng_device_operation.hpp" +#include "ttnn/operations/copy.hpp" + +inline bool needs_typecast_to_bfloat16(const Tensor& input) { + return (input.get_dtype() == DataType::BFLOAT8_B || input.get_dtype() == DataType::BFLOAT4_B); +} + +inline Tensor typecast_to(DataType dtype, const Tensor& input) { + return input.get_dtype() == dtype ? input : ttnn::typecast(input, dtype); +} namespace ttnn::operations::binary_ng { @@ -110,45 +119,65 @@ Tensor InplaceBinaryNg::invoke( uint8_t queue_id, const Tensor& input_tensor_a, const Tensor& input_tensor_b, - const std::optional& output_dtype, tt::stl::Span lhs_activations, tt::stl::Span rhs_activations, tt::stl::Span post_activations) { - auto input_a = typecast_to(DataType::BFLOAT16, input_tensor_a); - auto input_b = typecast_to(DataType::BFLOAT16, input_tensor_b); + bool typecast_a = needs_typecast_to_bfloat16(input_tensor_a); + bool typecast_b = needs_typecast_to_bfloat16(input_tensor_b); - return ttnn::prim::binary_ng( - queue_id, - input_a, - input_b, - binary_op_type, - output_dtype, - input_tensor_a.memory_config(), - input_tensor_a, - lhs_activations, - rhs_activations, - post_activations); + if (!typecast_a && !typecast_b) { + return ttnn::prim::binary_ng( + queue_id, + input_tensor_a, + input_tensor_b, + binary_op_type, + input_tensor_a.get_dtype(), + input_tensor_a.memory_config(), + input_tensor_a, + lhs_activations, + rhs_activations, + post_activations); + } else { + Tensor input_a = typecast_to(DataType::BFLOAT16, input_tensor_a); + Tensor input_b = typecast_to(DataType::BFLOAT16, input_tensor_b); + + ttnn::prim::binary_ng( + queue_id, + input_a, + input_b, + binary_op_type, + input_a.get_dtype(), + input_a.memory_config(), + input_a, + lhs_activations, + rhs_activations, + post_activations); + + if (typecast_a) { + copy::detail::copy_impl( + queue_id, + input_a, + {ttnn::operations::unary::UnaryWithParam( + ttnn::operations::unary::UnaryOpType::TYPECAST, + {static_cast(input_a.get_dtype()), static_cast(input_tensor_a.get_dtype())})}, + input_tensor_a.memory_config(), + input_tensor_a); + + return input_tensor_a; + } + return input_tensor_a; + } } template Tensor InplaceBinaryNg::invoke( const Tensor& input_tensor_a, const Tensor& input_tensor_b, - const std::optional& output_dtype, tt::stl::Span lhs_activations, tt::stl::Span rhs_activations, tt::stl::Span post_activations) { - return ttnn::prim::binary_ng( - DefaultQueueId, - input_tensor_a, - input_tensor_b, - binary_op_type, - output_dtype, - input_tensor_a.memory_config(), - input_tensor_a, - lhs_activations, - rhs_activations, - post_activations); + return InplaceBinaryNg::invoke( + DefaultQueueId, input_tensor_a, input_tensor_b, lhs_activations, rhs_activations, post_activations); } template @@ -156,44 +185,59 @@ Tensor InplaceBinaryNg::invoke( uint8_t queue_id, const Tensor& input_tensor_a, const float scalar, - const std::optional& output_dtype, tt::stl::Span lhs_activations, tt::stl::Span rhs_activations, tt::stl::Span post_activations) { - auto input_a = typecast_to(DataType::BFLOAT16, input_tensor_a); + bool typecast_a = needs_typecast_to_bfloat16(input_tensor_a); - return ttnn::prim::binary_ng( - queue_id, - input_a, - scalar, - binary_op_type, - output_dtype, - input_tensor_a.memory_config(), - input_tensor_a, - lhs_activations, - rhs_activations, - post_activations); + if (!typecast_a) { + return ttnn::prim::binary_ng( + queue_id, + input_tensor_a, + 
scalar, + binary_op_type, + input_tensor_a.get_dtype(), + input_tensor_a.memory_config(), + input_tensor_a, + lhs_activations, + rhs_activations, + post_activations); + } else { + Tensor input_a = typecast_to(DataType::BFLOAT16, input_tensor_a); + ttnn::prim::binary_ng( + queue_id, + input_a, + scalar, + binary_op_type, + input_a.get_dtype(), + input_a.memory_config(), + input_a, + lhs_activations, + rhs_activations, + post_activations); + + copy::detail::copy_impl( + queue_id, + input_a, + {ttnn::operations::unary::UnaryWithParam( + ttnn::operations::unary::UnaryOpType::TYPECAST, + {static_cast(input_a.get_dtype()), static_cast(input_tensor_a.get_dtype())})}, + input_tensor_a.memory_config(), + input_tensor_a); + + return input_tensor_a; + } } template Tensor InplaceBinaryNg::invoke( const Tensor& input_tensor_a, const float scalar, - const std::optional& output_dtype, tt::stl::Span lhs_activations, tt::stl::Span rhs_activations, tt::stl::Span post_activations) { - return ttnn::prim::binary_ng( - DefaultQueueId, - input_tensor_a, - scalar, - binary_op_type, - output_dtype, - input_tensor_a.memory_config(), - input_tensor_a, - lhs_activations, - rhs_activations, - post_activations); + return InplaceBinaryNg::invoke( + DefaultQueueId, input_tensor_a, scalar, lhs_activations, rhs_activations, post_activations); } template struct BinaryNg; diff --git a/ttnn/cpp/ttnn/operations/eltwise/binary_ng/binary_ng.hpp b/ttnn/cpp/ttnn/operations/eltwise/binary_ng/binary_ng.hpp index 80751733f46..9c7351f802e 100644 --- a/ttnn/cpp/ttnn/operations/eltwise/binary_ng/binary_ng.hpp +++ b/ttnn/cpp/ttnn/operations/eltwise/binary_ng/binary_ng.hpp @@ -7,7 +7,6 @@ #include "ttnn/decorators.hpp" #include "ttnn/operations/eltwise/binary_ng/types.hpp" -#include "ttnn/operations/copy.hpp" #include "ttnn/operations/eltwise/unary/common/unary_op_types.hpp" namespace ttnn::operations::binary_ng { @@ -63,24 +62,21 @@ struct InplaceBinaryNg { uint8_t queue_id, const Tensor& input_tensor_a, const Tensor& input_tensor_b, - const std::optional& output_dtype = std::nullopt, tt::stl::Span lhs_activations = {}, tt::stl::Span rhs_activations = {}, tt::stl::Span post_activations = {}); static Tensor invoke( - uint8_t queue_id, - const Tensor& input_tensor, - const float scalar, - const std::optional& output_dtype = std::nullopt, + const Tensor& input_tensor_a, + const Tensor& input_tensor_b, tt::stl::Span lhs_activations = {}, tt::stl::Span rhs_activations = {}, tt::stl::Span post_activations = {}); static Tensor invoke( - const Tensor& input_tensor_a, - const Tensor& input_tensor_b, - const std::optional& output_dtype = std::nullopt, + uint8_t queue_id, + const Tensor& input_tensor, + const float scalar, tt::stl::Span lhs_activations = {}, tt::stl::Span rhs_activations = {}, tt::stl::Span post_activations = {}); @@ -88,7 +84,6 @@ struct InplaceBinaryNg { static Tensor invoke( const Tensor& input_tensor, const float scalar, - const std::optional& output_dtype = std::nullopt, tt::stl::Span lhs_activations = {}, tt::stl::Span rhs_activations = {}, tt::stl::Span post_activations = {}); @@ -96,10 +91,6 @@ struct InplaceBinaryNg { } // namespace ttnn::operations::binary_ng -inline Tensor typecast_to(DataType dtype, const Tensor& input) { - return input.get_dtype() == dtype ? 
input : ttnn::typecast(input, dtype); -} - namespace ttnn::experimental { constexpr auto add = ttnn::register_operation_with_auto_launch_op< "ttnn::experimental::add", diff --git a/ttnn/cpp/ttnn/operations/eltwise/binary_ng/binary_ng_pybind.cpp b/ttnn/cpp/ttnn/operations/eltwise/binary_ng/binary_ng_pybind.cpp index 92f6d52b3f2..fde01a1a215 100644 --- a/ttnn/cpp/ttnn/operations/eltwise/binary_ng/binary_ng_pybind.cpp +++ b/ttnn/cpp/ttnn/operations/eltwise/binary_ng/binary_ng_pybind.cpp @@ -97,18 +97,15 @@ void bind_inplace_binary_ng_operation(py::module& module, T op, const std::strin [](const T& self, const ttnn::Tensor& input_tensor_a, const float scalar, - const std::optional& dtype, const ttnn::SmallVector& lhs_activations, const ttnn::SmallVector& rhs_activations, const ttnn::SmallVector& post_activations, const uint8_t& queue_id) -> ttnn::Tensor { - return self( - queue_id, input_tensor_a, scalar, dtype, lhs_activations, rhs_activations, post_activations); + return self(queue_id, input_tensor_a, scalar, lhs_activations, rhs_activations, post_activations); }, py::arg("input_tensor_a"), py::arg("scalar"), py::kw_only(), - py::arg("dtype") = std::nullopt, py::arg("lhs_activations") = ttnn::SmallVector(), py::arg("rhs_activations") = ttnn::SmallVector(), py::arg("post_activations") = ttnn::SmallVector(), @@ -119,24 +116,16 @@ void bind_inplace_binary_ng_operation(py::module& module, T op, const std::strin [](const T& self, const ttnn::Tensor& input_tensor_a, const ttnn::Tensor& input_tensor_b, - const std::optional& dtype, const ttnn::SmallVector& lhs_activations, const ttnn::SmallVector& rhs_activations, const ttnn::SmallVector& post_activations, uint8_t queue_id) -> ttnn::Tensor { return self( - queue_id, - input_tensor_a, - input_tensor_b, - dtype, - lhs_activations, - rhs_activations, - post_activations); + queue_id, input_tensor_a, input_tensor_b, lhs_activations, rhs_activations, post_activations); }, py::arg("input_tensor_a"), py::arg("input_tensor_b"), py::kw_only(), - py::arg("dtype") = std::nullopt, py::arg("lhs_activations") = ttnn::SmallVector(), py::arg("rhs_activations") = ttnn::SmallVector(), py::arg("post_activations") = ttnn::SmallVector(), diff --git a/ttnn/cpp/ttnn/operations/eltwise/binary_ng/device/binary_ng_device_operation.cpp b/ttnn/cpp/ttnn/operations/eltwise/binary_ng/device/binary_ng_device_operation.cpp index 4d0df8d6ca5..32f097e0993 100644 --- a/ttnn/cpp/ttnn/operations/eltwise/binary_ng/device/binary_ng_device_operation.cpp +++ b/ttnn/cpp/ttnn/operations/eltwise/binary_ng/device/binary_ng_device_operation.cpp @@ -103,6 +103,10 @@ void BinaryNgDeviceOperation::validate_on_program_cache_hit( const int rank_b = input_shape_b.rank(); const int larger_rank = std::max(rank_a, rank_b); + if (input_tensor_a.buffer()->address() == output_tensor->buffer()->address()) { + TT_FATAL(rank_a >= rank_b, "In-place operation rule violation for rank a: {}, rank b: {}", rank_a, rank_b); + } + for (int i = -1; i >= -larger_rank; --i) { auto a_dim = (i >= -rank_a) ? input_shape_a[i] : 1; auto b_dim = (i >= -rank_b) ? input_shape_b[i] : 1; @@ -112,6 +116,14 @@ void BinaryNgDeviceOperation::validate_on_program_cache_hit( i, a_dim, b_dim); + if (input_tensor_a.buffer()->address() == output_tensor->buffer()->address()) { + TT_FATAL( + a_dim >= b_dim, + "In-place operation rule violation for dimension {}, dim a: {}, dim b: {}", + i, + a_dim, + b_dim); + } } }
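
Illustrative usage note (not part of the diff above): a minimal sketch of the behaviour the new tests exercise, assuming a device handle is already open as `device`. It shows an in-place add_ on bfloat8_b inputs, which this patch typecasts to bfloat16 internally, computes in place, and typecasts back into the original bfloat8_b tensor, plus the new broadcast validation that rejects an in-place op whose result would not fit the destination. Shapes, dtypes, and the PCC threshold mirror the new tests; everything else is illustrative only.

import torch
import ttnn

# Inputs mirror the ranges used by the bfloat8_b/bfloat4_b tests (min=-1e3, max=1e3).
torch_a = torch.rand(5, 3, 32, 32, dtype=torch.bfloat16) * 2e3 - 1e3
torch_b = torch.rand(1, 1, 1, 1, dtype=torch.bfloat16) * 2e3 - 1e3

# bfloat8_b inputs: add_ now typecasts to bfloat16 internally, computes in place,
# then typecasts the result back into input a's original bfloat8_b storage.
a = ttnn.from_torch(torch_a, dtype=ttnn.bfloat8_b, device=device,
                    layout=ttnn.TILE_LAYOUT, memory_config=ttnn.DRAM_MEMORY_CONFIG)
b = ttnn.from_torch(torch_b, dtype=ttnn.bfloat8_b, device=device,
                    layout=ttnn.TILE_LAYOUT, memory_config=ttnn.DRAM_MEMORY_CONFIG)
ttnn.experimental.add_(a, b)
# Threshold taken from test_inplace_add_bf4b_bf8b for bfloat8_b.
assert ttnn.pearson_correlation_coefficient(torch_a + torch_b, ttnn.to_torch(a)) >= 0.999

# Broadcasting that would enlarge the destination is now rejected up front:
# a (1, 1, 1, 1) destination cannot hold a (5, 3, 32, 32) result.
small = ttnn.from_torch(torch_b, dtype=ttnn.bfloat16, device=device, layout=ttnn.TILE_LAYOUT)
big = ttnn.from_torch(torch_a, dtype=ttnn.bfloat16, device=device, layout=ttnn.TILE_LAYOUT)
try:
    ttnn.experimental.add_(small, big)
except RuntimeError as e:
    assert "In-place operation rule violation" in str(e)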