From a568827e6a08784e7f56afb7b69bd8af02b9b85f Mon Sep 17 00:00:00 2001
From: Patrick Roberts
Date: Thu, 23 Jan 2025 23:22:33 +0000
Subject: [PATCH] #16502: Add Unary with params support to BinaryNg

---
 .../operations/eltwise/test_binary_bcast.py   |  54 ++++++-
 .../eltwise/binary/common/binary_op_utils.cpp |   8 +-
 .../eltwise/binary_ng/binary_ng.cpp           |  36 ++---
 .../eltwise/binary_ng/binary_ng.hpp           |  24 +--
 .../eltwise/binary_ng/binary_ng_pybind.cpp    |  24 +--
 .../device/binary_ng_device_operation.cpp     |  12 +-
 .../device/binary_ng_device_operation.hpp     |  18 +--
 .../device/binary_ng_program_factory.cpp      |   8 +-
 .../binary_ng/device/binary_ng_utils.cpp      |  34 ++--
 .../binary_ng/device/binary_ng_utils.hpp      |   6 +-
 .../kernels/compute/eltwise_utils_common.hpp  | 145 ------------------
 .../eltwise/unary/common/unary_op_utils.cpp   | 126 ++++++++-------
 12 files changed, 198 insertions(+), 297 deletions(-)

diff --git a/tests/ttnn/unit_tests/operations/eltwise/test_binary_bcast.py b/tests/ttnn/unit_tests/operations/eltwise/test_binary_bcast.py
index aa0b6c3e4de..ad01e9796ee 100644
--- a/tests/ttnn/unit_tests/operations/eltwise/test_binary_bcast.py
+++ b/tests/ttnn/unit_tests/operations/eltwise/test_binary_bcast.py
@@ -154,8 +154,7 @@ def test_binary_scalar_ops(a_shape, b_shape, ttnn_fn, activations, device):
     a_pt, a_tt = rand_bf16_gen(a_shape, device)
     b_pt, b_tt = rand_bf16_gen(b_shape, device, min=min, max=max)
 
-    cq_id = 0
-    out_tt = ttnn_op(a_tt, b_tt, queue_id=cq_id, lhs_activations=lhs, rhs_activations=rhs, post_activations=post)
+    out_tt = ttnn_op(a_tt, b_tt, lhs_activations=lhs, rhs_activations=rhs, post_activations=post)
 
     for golden_activation in golden_lhs:
         a_pt = golden_activation(a_pt).bfloat16()
@@ -179,6 +178,57 @@ def compare(tt, pt):
     assert compare([out_tt], [out_pt])
 
 
+activation_with_param_fns = {
+    "ADD_UNARY_SFPU": torch.add,
+    "SUB_UNARY_SFPU": torch.sub,
+    "MUL_UNARY_SFPU": torch.mul,
+    "DIV_UNARY_SFPU": torch.div,
+    "POWER": torch.pow,
+}
+
+
+@pytest.mark.parametrize(
+    "a_shape, b_shape",
+    (
+        (torch.Size([1, 1, 1, 1]), torch.Size([5, 3, 32, 32])),
+        (torch.Size([5, 1, 64, 1]), torch.Size([1, 3, 1, 128])),
+        (torch.Size([5, 1, 1, 64]), torch.Size([1, 3, 128, 1])),
+    ),
+)
+@pytest.mark.parametrize("ttnn_fn", ("add", "sub", "mul", "div"))
+@pytest.mark.parametrize(
+    "post_activations",
+    (
+        (),
+        (("ADD_UNARY_SFPU", 7),),
+        (("SUB_UNARY_SFPU", 6),),
+        (("MUL_UNARY_SFPU", 5),),
+        (("DIV_UNARY_SFPU", 4),),
+        (("POWER", 3),),
+    ),
+)
+def test_binary_scalar_ops_with_unary_param(a_shape, b_shape, ttnn_fn, post_activations, device):
+    torch.manual_seed(0)
+    ttnn_op = getattr(ttnn.experimental, ttnn_fn)
+    post = [(getattr(ttnn.UnaryOpType, op), param) for op, param in post_activations]
+    golden_post = ((lambda x: activation_with_param_fns[op](x, param)) for op, param in post_activations)
+    # make 0 exclusive for rhs of div
+    min, max = (1, 0) if ttnn_fn == "div" else (0, 1)
+
+    a_pt, a_tt = rand_bf16_gen(a_shape, device)
+    b_pt, b_tt = rand_bf16_gen(b_shape, device, min=min, max=max)
+
+    out_tt = ttnn_op(a_tt, b_tt, post_activations=post)
+
+    golden_fn = ttnn.get_golden_function(ttnn_op)
+    out_pt = golden_fn(a_pt, b_pt).bfloat16()
+
+    for golden_activation in golden_post:
+        out_pt = golden_activation(out_pt).bfloat16()
+
+    assert compare_pcc([out_tt], [out_pt])
+
+
 @pytest.mark.parametrize(
     "a_shape, b_shape",
     (
diff --git a/ttnn/cpp/ttnn/operations/eltwise/binary/common/binary_op_utils.cpp b/ttnn/cpp/ttnn/operations/eltwise/binary/common/binary_op_utils.cpp
index 8bb60a72a3e..153c99488ba 100644
--- a/ttnn/cpp/ttnn/operations/eltwise/binary/common/binary_op_utils.cpp
+++ b/ttnn/cpp/ttnn/operations/eltwise/binary/common/binary_op_utils.cpp
@@ -107,7 +107,7 @@ std::map<std::string, std::string> get_defines(
             op_binary_type = "EltwiseBinaryType::ELWADD";
             defines.merge(get_defines(UnaryOpType::LOG2, std::nullopt, "0", idst));
             break;
-        default: TT_ASSERT(false && "Undefined op type");
+        default: TT_THROW("Undefined op type {}", op_type);
     }
 
     using DataType = tt::tt_metal::DataType;
@@ -138,10 +138,10 @@ std::map<std::string, std::string> get_defines(
          (input_dtype.value() == DataType::UINT16 && output_dtype.value() == DataType::BFLOAT4_B) ||
          (input_dtype.value() == DataType::BFLOAT4_B && output_dtype.value() == DataType::INT32) ||
          (input_dtype.value() == DataType::INT32 && output_dtype.value() == DataType::BFLOAT4_B))) {
-        TT_ASSERT(defines.count("SFPU_OP_CHAIN_0") == 0 && "SFPU_OP_CHAIN_0 already defined");
+        TT_ASSERT(defines.count("SFPU_OP_CHAIN_0") == 0, "SFPU_OP_CHAIN_0 already defined");
 
-        auto in_dataformat = std::to_string((uint32_t)datatype_to_dataformat_converter(input_dtype.value()));
-        auto out_dataformat = std::to_string((uint32_t)datatype_to_dataformat_converter(output_dtype.value()));
+        auto in_dataformat = (uint32_t)datatype_to_dataformat_converter(input_dtype.value());
+        auto out_dataformat = (uint32_t)datatype_to_dataformat_converter(output_dtype.value());
         defines.insert(
             {"SFPU_OP_CHAIN_0",
              fmt::format("typecast_tile_init(); typecast_tile<{0}u, {1}u>(i);", in_dataformat, out_dataformat)});
diff --git a/ttnn/cpp/ttnn/operations/eltwise/binary_ng/binary_ng.cpp b/ttnn/cpp/ttnn/operations/eltwise/binary_ng/binary_ng.cpp
index 1231a7c9254..9305592ec66 100644
--- a/ttnn/cpp/ttnn/operations/eltwise/binary_ng/binary_ng.cpp
+++ b/ttnn/cpp/ttnn/operations/eltwise/binary_ng/binary_ng.cpp
@@ -24,9 +24,9 @@ Tensor BinaryNg<binary_op_type>::invoke(
     const std::optional<const DataType>& output_dtype,
     const std::optional<MemoryConfig>& memory_config,
     std::optional<Tensor> optional_output_tensor,
-    tt::stl::Span<const unary::UnaryOpType> lhs_activations,
-    tt::stl::Span<const unary::UnaryOpType> rhs_activations,
-    tt::stl::Span<const unary::UnaryOpType> post_activations) {
+    tt::stl::Span<const unary::UnaryWithParam> lhs_activations,
+    tt::stl::Span<const unary::UnaryWithParam> rhs_activations,
+    tt::stl::Span<const unary::UnaryWithParam> post_activations) {
     bool typecast_a = needs_typecast_to_bfloat16(input_tensor_a);
     bool typecast_b = needs_typecast_to_bfloat16(input_tensor_b);
     Tensor input_a = typecast_a ? typecast_to(DataType::BFLOAT16, input_tensor_a) : input_tensor_a;
@@ -52,9 +52,9 @@ Tensor BinaryNg<binary_op_type>::invoke(
     const std::optional<const DataType>& output_dtype,
     const std::optional<MemoryConfig>& memory_config,
     std::optional<Tensor> optional_output_tensor,
-    tt::stl::Span<const unary::UnaryOpType> lhs_activations,
-    tt::stl::Span<const unary::UnaryOpType> rhs_activations,
-    tt::stl::Span<const unary::UnaryOpType> post_activations) {
+    tt::stl::Span<const unary::UnaryWithParam> lhs_activations,
+    tt::stl::Span<const unary::UnaryWithParam> rhs_activations,
+    tt::stl::Span<const unary::UnaryWithParam> post_activations) {
     return invoke(
         DefaultQueueId,
         input_tensor_a,
@@ -75,9 +75,9 @@ Tensor BinaryNg<binary_op_type>::invoke(
     const std::optional<const DataType>& output_dtype,
     const std::optional<MemoryConfig>& memory_config,
     std::optional<Tensor> optional_output_tensor,
-    tt::stl::Span<const unary::UnaryOpType> lhs_activations,
-    tt::stl::Span<const unary::UnaryOpType> rhs_activations,
-    tt::stl::Span<const unary::UnaryOpType> post_activations) {
+    tt::stl::Span<const unary::UnaryWithParam> lhs_activations,
+    tt::stl::Span<const unary::UnaryWithParam> rhs_activations,
+    tt::stl::Span<const unary::UnaryWithParam> post_activations) {
     bool typecast_a = needs_typecast_to_bfloat16(input_tensor_a);
     Tensor input_a = typecast_a ? typecast_to(DataType::BFLOAT16, input_tensor_a) : input_tensor_a;
@@ -101,9 +101,9 @@ Tensor BinaryNg<binary_op_type>::invoke(
     const std::optional<const DataType>& output_dtype,
     const std::optional<MemoryConfig>& memory_config,
     std::optional<Tensor> optional_output_tensor,
-    tt::stl::Span<const unary::UnaryOpType> lhs_activations,
-    tt::stl::Span<const unary::UnaryOpType> rhs_activations,
-    tt::stl::Span<const unary::UnaryOpType> post_activations) {
+    tt::stl::Span<const unary::UnaryWithParam> lhs_activations,
+    tt::stl::Span<const unary::UnaryWithParam> rhs_activations,
+    tt::stl::Span<const unary::UnaryWithParam> post_activations) {
     return invoke(
         DefaultQueueId,
         input_tensor_a,
@@ -127,9 +127,9 @@ Tensor BinaryNgBitwise<binary_op_type>::invoke(
         input_tensor_a.get_dtype() == DataType::INT32 && input_tensor_b.get_dtype() == DataType::INT32,
         "Bitwise ops require input tensors to be of INT32 datatype ");
 
-    tt::stl::Span<const unary::UnaryOpType> lhs_activations = {};
-    tt::stl::Span<const unary::UnaryOpType> rhs_activations = {};
-    tt::stl::Span<const unary::UnaryOpType> post_activations = {};
+    tt::stl::Span<const unary::UnaryWithParam> lhs_activations = {};
+    tt::stl::Span<const unary::UnaryWithParam> rhs_activations = {};
+    tt::stl::Span<const unary::UnaryWithParam> post_activations = {};
 
     return ttnn::prim::binary_ng(
         queue_id,
@@ -164,9 +164,9 @@ Tensor BinaryNgBitwise<binary_op_type>::invoke(
     TT_FATAL(
         input_tensor_a.get_dtype() == DataType::INT32, "Bitwise ops require input tensor to be of INT32 datatype ");
 
-    tt::stl::Span<const unary::UnaryOpType> lhs_activations = {};
-    tt::stl::Span<const unary::UnaryOpType> rhs_activations = {};
-    tt::stl::Span<const unary::UnaryOpType> post_activations = {};
+    tt::stl::Span<const unary::UnaryWithParam> lhs_activations = {};
+    tt::stl::Span<const unary::UnaryWithParam> rhs_activations = {};
+    tt::stl::Span<const unary::UnaryWithParam> post_activations = {};
 
     return ttnn::prim::binary_ng(
         queue_id,
diff --git a/ttnn/cpp/ttnn/operations/eltwise/binary_ng/binary_ng.hpp b/ttnn/cpp/ttnn/operations/eltwise/binary_ng/binary_ng.hpp
index c3f1727bf0e..e3bab50d7a9 100644
--- a/ttnn/cpp/ttnn/operations/eltwise/binary_ng/binary_ng.hpp
+++ b/ttnn/cpp/ttnn/operations/eltwise/binary_ng/binary_ng.hpp
@@ -21,9 +21,9 @@ struct BinaryNg {
         const std::optional<const DataType>& output_dtype = std::nullopt,
         const std::optional<MemoryConfig>& memory_config = std::nullopt,
         std::optional<Tensor> optional_output_tensor = std::nullopt,
-        tt::stl::Span<const unary::UnaryOpType> lhs_activations = {},
-        tt::stl::Span<const unary::UnaryOpType> rhs_activations = {},
-        tt::stl::Span<const unary::UnaryOpType> post_activations = {});
+        tt::stl::Span<const unary::UnaryWithParam> lhs_activations = {},
+        tt::stl::Span<const unary::UnaryWithParam> rhs_activations = {},
+        tt::stl::Span<const unary::UnaryWithParam> post_activations = {});
 
     static Tensor invoke(
         const Tensor& input_tensor_a,
@@ -31,9 +31,9 @@ struct BinaryNg {
         const std::optional<const DataType>& output_dtype = std::nullopt,
         const std::optional<MemoryConfig>& memory_config = std::nullopt,
         std::optional<Tensor> optional_output_tensor = std::nullopt,
-        tt::stl::Span<const unary::UnaryOpType> lhs_activations = {},
-        tt::stl::Span<const unary::UnaryOpType> rhs_activations = {},
-        tt::stl::Span<const unary::UnaryOpType> post_activations = {});
+        tt::stl::Span<const unary::UnaryWithParam> lhs_activations = {},
+        tt::stl::Span<const unary::UnaryWithParam> rhs_activations = {},
+        tt::stl::Span<const unary::UnaryWithParam> post_activations = {});
 
     static Tensor invoke(
         uint8_t queue_id,
@@ -42,9 +42,9 @@ struct BinaryNg {
         const std::optional<const DataType>& output_dtype = std::nullopt,
         const std::optional<MemoryConfig>& memory_config = std::nullopt,
         std::optional<Tensor> optional_output_tensor = std::nullopt,
-        tt::stl::Span<const unary::UnaryOpType> lhs_activations = {},
-        tt::stl::Span<const unary::UnaryOpType> rhs_activations = {},
-        tt::stl::Span<const unary::UnaryOpType> post_activations = {});
+        tt::stl::Span<const unary::UnaryWithParam> lhs_activations = {},
+        tt::stl::Span<const unary::UnaryWithParam> rhs_activations = {},
+        tt::stl::Span<const unary::UnaryWithParam> post_activations = {});
 
     static Tensor invoke(
         const Tensor& input_tensor_a,
@@ -52,9 +52,9 @@ struct BinaryNg {
         const std::optional<const DataType>& output_dtype = std::nullopt,
         const std::optional<MemoryConfig>& memory_config = std::nullopt,
         std::optional<Tensor> optional_output_tensor = std::nullopt,
-        tt::stl::Span<const unary::UnaryOpType> lhs_activations = {},
-        tt::stl::Span<const unary::UnaryOpType> rhs_activations = {},
-        tt::stl::Span<const unary::UnaryOpType> post_activations = {});
+        tt::stl::Span<const unary::UnaryWithParam> lhs_activations = {},
+        tt::stl::Span<const unary::UnaryWithParam> rhs_activations = {},
+        tt::stl::Span<const unary::UnaryWithParam> post_activations = {});
 };
 
 template <BinaryOpType binary_op_type>
diff --git a/ttnn/cpp/ttnn/operations/eltwise/binary_ng/binary_ng_pybind.cpp b/ttnn/cpp/ttnn/operations/eltwise/binary_ng/binary_ng_pybind.cpp
index fd76ab0dd97..d3f3e82b8ef 100644
--- a/ttnn/cpp/ttnn/operations/eltwise/binary_ng/binary_ng_pybind.cpp
+++ b/ttnn/cpp/ttnn/operations/eltwise/binary_ng/binary_ng_pybind.cpp
@@ -24,9 +24,9 @@ void bind_binary_ng_operation(py::module& module, T op, const std::string& docstring) {
            const std::optional<const DataType>& dtype,
            const std::optional<ttnn::MemoryConfig>& memory_config,
            const std::optional<ttnn::Tensor>& output_tensor,
-           const ttnn::SmallVector<unary::UnaryOpType>& lhs_activations,
-           const ttnn::SmallVector<unary::UnaryOpType>& rhs_activations,
-           const ttnn::SmallVector<unary::UnaryOpType>& post_activations,
+           const ttnn::SmallVector<unary::UnaryWithParam>& lhs_activations,
+           const ttnn::SmallVector<unary::UnaryWithParam>& rhs_activations,
+           const ttnn::SmallVector<unary::UnaryWithParam>& post_activations,
            const uint8_t& queue_id) -> ttnn::Tensor {
             return self(
                 queue_id,
@@ -45,9 +45,9 @@ void bind_binary_ng_operation(py::module& module, T op, const std::string& docstring) {
         py::arg("dtype") = std::nullopt,
         py::arg("memory_config") = std::nullopt,
         py::arg("output_tensor") = std::nullopt,
-        py::arg("lhs_activations") = ttnn::SmallVector<unary::UnaryOpType>(),
-        py::arg("rhs_activations") = ttnn::SmallVector<unary::UnaryOpType>(),
-        py::arg("post_activations") = ttnn::SmallVector<unary::UnaryOpType>(),
+        py::arg("lhs_activations") = ttnn::SmallVector<unary::UnaryWithParam>(),
+        py::arg("rhs_activations") = ttnn::SmallVector<unary::UnaryWithParam>(),
+        py::arg("post_activations") = ttnn::SmallVector<unary::UnaryWithParam>(),
         py::arg("queue_id") = 0},
 
        // tensor and tensor
@@ -58,9 +58,9 @@ void bind_binary_ng_operation(py::module& module, T op, const std::string& docstring) {
            const std::optional<const DataType>& dtype,
            const std::optional<ttnn::MemoryConfig>& memory_config,
            const std::optional<ttnn::Tensor>& output_tensor,
-           const ttnn::SmallVector<unary::UnaryOpType>& lhs_activations,
-           const ttnn::SmallVector<unary::UnaryOpType>& rhs_activations,
-           const ttnn::SmallVector<unary::UnaryOpType>& post_activations,
+           const ttnn::SmallVector<unary::UnaryWithParam>& lhs_activations,
+           const ttnn::SmallVector<unary::UnaryWithParam>& rhs_activations,
+           const ttnn::SmallVector<unary::UnaryWithParam>& post_activations,
            uint8_t queue_id) -> ttnn::Tensor {
             return self(
                 queue_id,
@@ -79,9 +79,9 @@ void bind_binary_ng_operation(py::module& module, T op, const std::string& docstring) {
         py::arg("dtype") = std::nullopt,
         py::arg("memory_config") = std::nullopt,
         py::arg("output_tensor") = std::nullopt,
-        py::arg("lhs_activations") = ttnn::SmallVector<unary::UnaryOpType>(),
-        py::arg("rhs_activations") = ttnn::SmallVector<unary::UnaryOpType>(),
-        py::arg("post_activations") = ttnn::SmallVector<unary::UnaryOpType>(),
+        py::arg("lhs_activations") = ttnn::SmallVector<unary::UnaryWithParam>(),
+        py::arg("rhs_activations") = ttnn::SmallVector<unary::UnaryWithParam>(),
+        py::arg("post_activations") = ttnn::SmallVector<unary::UnaryWithParam>(),
         py::arg("queue_id") = 0});
 }
 
diff --git a/ttnn/cpp/ttnn/operations/eltwise/binary_ng/device/binary_ng_device_operation.cpp b/ttnn/cpp/ttnn/operations/eltwise/binary_ng/device/binary_ng_device_operation.cpp
index 80a4d881c29..6a5908af087 100644
--- a/ttnn/cpp/ttnn/operations/eltwise/binary_ng/device/binary_ng_device_operation.cpp
+++ b/ttnn/cpp/ttnn/operations/eltwise/binary_ng/device/binary_ng_device_operation.cpp
@@ -360,9 +360,9 @@ BinaryNgDeviceOperation::invoke(
     const std::optional<const DataType>& output_dtype,
     const std::optional<MemoryConfig>& memory_config,
     std::optional<Tensor> output_tensor,
-    tt::stl::Span<const unary::UnaryOpType> lhs_activations,
-    tt::stl::Span<const unary::UnaryOpType> rhs_activations,
-    tt::stl::Span<const unary::UnaryOpType> post_activations) {
+    tt::stl::Span<const unary::UnaryWithParam> lhs_activations,
+    tt::stl::Span<const unary::UnaryWithParam> rhs_activations,
+    tt::stl::Span<const unary::UnaryWithParam> post_activations) {
     auto subtile_broadcast_type = get_subtile_broadcast_type(
         input_tensor_a.get_logical_shape()[-2],
         input_tensor_a.get_logical_shape()[-1],
@@ -400,9 +400,9 @@ BinaryNgDeviceOperation::invoke(
     const std::optional<const DataType>& output_dtype,
     const std::optional<MemoryConfig>& memory_config,
     std::optional<Tensor> output_tensor,
-    tt::stl::Span<const unary::UnaryOpType> lhs_activations,
-    tt::stl::Span<const unary::UnaryOpType> rhs_activations,
-    tt::stl::Span<const unary::UnaryOpType> post_activations) {
+    tt::stl::Span<const unary::UnaryWithParam> lhs_activations,
+    tt::stl::Span<const unary::UnaryWithParam> rhs_activations,
+    tt::stl::Span<const unary::UnaryWithParam> post_activations) {
     DataType dtype1 = input_tensor_a.get_dtype();
     bool device_check = input_tensor_a.device()->arch() != tt::ARCH::GRAYSKULL;
     bool is_sfpu_op = (utils::is_binary_sfpu_op(binary_op_type, dtype1, dtype1) && device_check);
diff --git a/ttnn/cpp/ttnn/operations/eltwise/binary_ng/device/binary_ng_device_operation.hpp b/ttnn/cpp/ttnn/operations/eltwise/binary_ng/device/binary_ng_device_operation.hpp
index 0985c20d5e9..598af525c8f 100644
--- a/ttnn/cpp/ttnn/operations/eltwise/binary_ng/device/binary_ng_device_operation.hpp
+++ b/ttnn/cpp/ttnn/operations/eltwise/binary_ng/device/binary_ng_device_operation.hpp
@@ -31,9 +31,9 @@ struct BinaryNgDeviceOperation {
     struct operation_attributes_t {
         BinaryOpType binary_op_type;
-        ttnn::SmallVector<unary::UnaryOpType> lhs_activations;
-        ttnn::SmallVector<unary::UnaryOpType> rhs_activations;
-        ttnn::SmallVector<unary::UnaryOpType> post_activations;
+        ttnn::SmallVector<unary::UnaryWithParam> lhs_activations;
+        ttnn::SmallVector<unary::UnaryWithParam> rhs_activations;
+        ttnn::SmallVector<unary::UnaryWithParam> post_activations;
         std::optional<float> scalar;
         tt::tt_metal::MemoryConfig memory_config;
         DataType input_dtype;
@@ -92,9 +92,9 @@ struct BinaryNgDeviceOperation {
         const std::optional<const DataType>& output_dtype,
         const std::optional<MemoryConfig>& memory_config,
         std::optional<Tensor> optional_output_tensor,
-        tt::stl::Span<const unary::UnaryOpType> lhs_activations,
-        tt::stl::Span<const unary::UnaryOpType> rhs_activations,
-        tt::stl::Span<const unary::UnaryOpType> post_activations);
+        tt::stl::Span<const unary::UnaryWithParam> lhs_activations,
+        tt::stl::Span<const unary::UnaryWithParam> rhs_activations,
+        tt::stl::Span<const unary::UnaryWithParam> post_activations);
 
     // tensor-scalar invocation
     static std::tuple<operation_attributes_t, tensor_args_t> invoke(
@@ -104,9 +104,9 @@ struct BinaryNgDeviceOperation {
         const std::optional<const DataType>& output_dtype,
         const std::optional<MemoryConfig>& memory_config,
         std::optional<Tensor> optional_output_tensor,
-        tt::stl::Span<const unary::UnaryOpType> lhs_activations,
-        tt::stl::Span<const unary::UnaryOpType> rhs_activations,
-        tt::stl::Span<const unary::UnaryOpType> post_activations);
+        tt::stl::Span<const unary::UnaryWithParam> lhs_activations,
+        tt::stl::Span<const unary::UnaryWithParam> rhs_activations,
+        tt::stl::Span<const unary::UnaryWithParam> post_activations);
 };
 
 }  // namespace ttnn::operations::binary_ng
diff --git a/ttnn/cpp/ttnn/operations/eltwise/binary_ng/device/binary_ng_program_factory.cpp b/ttnn/cpp/ttnn/operations/eltwise/binary_ng/device/binary_ng_program_factory.cpp
index 1022d38dc03..82535321957 100644
--- a/ttnn/cpp/ttnn/operations/eltwise/binary_ng/device/binary_ng_program_factory.cpp
+++ b/ttnn/cpp/ttnn/operations/eltwise/binary_ng/device/binary_ng_program_factory.cpp
@@ -322,9 +322,9 @@ BinaryNgDeviceOperation::ProgramFactory::cached_program_t BinaryNgDeviceOperation::ProgramFactory::create(
     auto compute_kernel_defines = op_config.as_defines(a.get_dtype());
 
     {
-        ttnn::SmallVector<unary::UnaryOpType> lhs_activations = operation_attributes.lhs_activations;
-        ttnn::SmallVector<unary::UnaryOpType> rhs_activations = operation_attributes.rhs_activations;
-        ttnn::SmallVector<unary::UnaryOpType> post_activations = operation_attributes.post_activations;
+        ttnn::SmallVector<unary::UnaryWithParam> lhs_activations = operation_attributes.lhs_activations;
+        ttnn::SmallVector<unary::UnaryWithParam> rhs_activations = operation_attributes.rhs_activations;
+        ttnn::SmallVector<unary::UnaryWithParam> post_activations = operation_attributes.post_activations;
 
         if (op_config.process_lhs.has_value()) {
             lhs_activations.push_back(*op_config.process_lhs);
@@ -342,7 +342,7 @@ BinaryNgDeviceOperation::ProgramFactory::cached_program_t BinaryNgDeviceOperation::ProgramFactory::create(
         add_activation_defines(compute_kernel_defines, rhs_activations, "RHS");
 
         if (lhs_activations.empty() and rhs_activations.empty() and post_activations.size() == 1 and
-            post_activations[0] == unary::UnaryOpType::RELU) {
+            post_activations[0].op_type == unary::UnaryOpType::RELU) {
             compute_kernel_defines["PACK_RELU"] = "1";
             compute_kernel_defines["PROCESS_POST_ACTIVATIONS(i)"] = "";
             unary::utils::update_macro_defines(unary::UnaryOpType::RELU, compute_kernel_defines);
diff --git a/ttnn/cpp/ttnn/operations/eltwise/binary_ng/device/binary_ng_utils.cpp b/ttnn/cpp/ttnn/operations/eltwise/binary_ng/device/binary_ng_utils.cpp
index e1d2fff727b..9308e02c82e 100644
--- a/ttnn/cpp/ttnn/operations/eltwise/binary_ng/device/binary_ng_utils.cpp
+++ b/ttnn/cpp/ttnn/operations/eltwise/binary_ng/device/binary_ng_utils.cpp
@@ -10,6 +10,14 @@
 #include 
 #include 
 
+namespace ttnn::operations::binary_ng {
+
+struct Lowercase {
+    std::string_view view;
+};
+
+}  // namespace ttnn::operations::binary_ng
+
 template <>
 struct fmt::formatter<ttnn::operations::binary_ng::Lowercase> : fmt::formatter<std::string_view> {
     auto format(const ttnn::operations::binary_ng::Lowercase& value, fmt::format_context& ctx) const {
@@ -282,21 +290,19 @@ std::map<std::string, std::string> OpConfig::as_defines(DataType dtype) const {
 
 void add_activation_defines(
     std::map<std::string, std::string>& defines,
-    tt::stl::Span<const unary::UnaryOpType> activations,
+    tt::stl::Span<const unary::UnaryWithParam> activations,
     std::string_view operand) {
-    auto prepend_separator = false;
-    std::string process = "";
-
-    for (auto& a : activations) {
-        if (prepend_separator) {
-            process += ';';
-        }
-        prepend_separator = true;
-        process += fmt::format("PROCESS_ACTIVATION({}, i)", magic_enum::enum_name(a));
-        unary::utils::update_macro_defines(a, defines);
-    }
-
-    defines[fmt::format("PROCESS_{}_ACTIVATIONS(i)", operand)] = process;
+    defines[fmt::format("PROCESS_{}_ACTIVATIONS(i)", operand)] = std::accumulate(
+        activations.begin(),
+        activations.end(),
+        std::string{},
+        [&](std::string&& process, const unary::UnaryWithParam& a) {
+            const auto& [op_init, op_func] = unary::utils::get_op_init_and_func(a.op_type, a.params, "i");
+            process += op_init;
+            process += op_func;
+            unary::utils::update_macro_defines(a.op_type, defines);
+            return std::move(process);
+        });
 }
 
 bool OpConfig::is_sfpu_op() const { return std::holds_alternative<SfpuBinaryOp>(binary_op); }
diff --git a/ttnn/cpp/ttnn/operations/eltwise/binary_ng/device/binary_ng_utils.hpp b/ttnn/cpp/ttnn/operations/eltwise/binary_ng/device/binary_ng_utils.hpp
index a10e277152f..34bec43fc1c 100644
--- a/ttnn/cpp/ttnn/operations/eltwise/binary_ng/device/binary_ng_utils.hpp
+++ b/ttnn/cpp/ttnn/operations/eltwise/binary_ng/device/binary_ng_utils.hpp
@@ -71,11 +71,7 @@ struct OpConfig {
 
 void add_activation_defines(
     std::map<std::string, std::string>& defines,
-    tt::stl::Span<const unary::UnaryOpType> activations,
+    tt::stl::Span<const unary::UnaryWithParam> activations,
     std::string_view operand);
 
-struct Lowercase {
-    std::string_view view;
-};
-
 }  // namespace ttnn::operations::binary_ng
diff --git a/ttnn/cpp/ttnn/operations/eltwise/binary_ng/device/kernels/compute/eltwise_utils_common.hpp b/ttnn/cpp/ttnn/operations/eltwise/binary_ng/device/kernels/compute/eltwise_utils_common.hpp
index 0580376dffd..efeb1ab832f 100644
--- a/ttnn/cpp/ttnn/operations/eltwise/binary_ng/device/kernels/compute/eltwise_utils_common.hpp
+++ b/ttnn/cpp/ttnn/operations/eltwise/binary_ng/device/kernels/compute/eltwise_utils_common.hpp
@@ -4,144 +4,6 @@
 
 #pragma once
 
-#define ACTIVATION_INIT_RELU relu_tile_init
-#define ACTIVATION_APPLY_RELU relu_tile
-
-#define ACTIVATION_INIT_SQUARE square_tile_init
-#define ACTIVATION_APPLY_SQUARE square_tile
-
-#define ACTIVATION_INIT_GTZ gtz_tile_init
-#define ACTIVATION_APPLY_GTZ gtz_tile
-
-#define ACTIVATION_INIT_LTZ ltz_tile_init
-#define ACTIVATION_APPLY_LTZ ltz_tile
-
-#define ACTIVATION_INIT_GEZ gez_tile_init
-#define ACTIVATION_APPLY_GEZ gez_tile
-
-#define ACTIVATION_INIT_LEZ lez_tile_init
-#define ACTIVATION_APPLY_LEZ lez_tile
-
-#define ACTIVATION_INIT_EQZ eqz_tile_init
-#define ACTIVATION_APPLY_EQZ eqz_tile
-
-#define ACTIVATION_INIT_NEZ nez_tile_init
-#define ACTIVATION_APPLY_NEZ nez_tile
-
-#define ACTIVATION_INIT_LOG log_tile_init
-#define ACTIVATION_APPLY_LOG log_tile
-
-#define ACTIVATION_INIT_TANH tanh_tile_init
-#define ACTIVATION_APPLY_TANH tanh_tile
-
-#define ACTIVATION_INIT_LOG2 log_with_base_tile_init
-#define ACTIVATION_APPLY_LOG2(i) log_with_base_tile(i, 0x3dc5u)
-
-#define ACTIVATION_INIT_LOG10 log_with_base_tile_init
-#define ACTIVATION_APPLY_LOG10(i) log_with_base_tile(i, 0x36f3u)
-
-#define ACTIVATION_INIT_EXP exp_tile_init
-#define ACTIVATION_APPLY_EXP exp_tile
-
-#define ACTIVATION_INIT_EXP2 exp2_tile_init
-#define ACTIVATION_APPLY_EXP2 exp2_tile
-
-#define ACTIVATION_INIT_EXPM1 expm1_tile_init
-#define ACTIVATION_APPLY_EXPM1 expm1_tile
-
-#define ACTIVATION_INIT_RECIP recip_tile_init
-#define ACTIVATION_APPLY_RECIP recip_tile
-
-#define ACTIVATION_INIT_GELU gelu_tile_init
-#define ACTIVATION_APPLY_GELU gelu_tile
-
-#define ACTIVATION_INIT_SQRT sqrt_tile_init
-#define ACTIVATION_APPLY_SQRT sqrt_tile
-
-#define ACTIVATION_INIT_SIGMOID sigmoid_tile_init
-#define ACTIVATION_APPLY_SIGMOID sigmoid_tile
-
-#define ACTIVATION_INIT_SIN sin_tile_init
-#define ACTIVATION_APPLY_SIN sin_tile
-
-#define ACTIVATION_INIT_COS cos_tile_init
-#define ACTIVATION_APPLY_COS cos_tile
-
-#define ACTIVATION_INIT_TAN tan_tile_init
-#define ACTIVATION_APPLY_TAN tan_tile
-
-#define ACTIVATION_INIT_ASIN asin_tile_init
-#define ACTIVATION_APPLY_ASIN asin_tile
-
-#define ACTIVATION_INIT_ACOS acos_tile_init
-#define ACTIVATION_APPLY_ACOS acos_tile
-
-#define ACTIVATION_INIT_ATAN atan_tile_init
-#define ACTIVATION_APPLY_ATAN atan_tile
-
-#define ACTIVATION_INIT_ABS abs_tile_init
-#define ACTIVATION_APPLY_ABS abs_tile
-
-#define ACTIVATION_INIT_SIGN sign_tile_init
-#define ACTIVATION_APPLY_SIGN sign_tile
-
-#define ACTIVATION_INIT_SIGNBIT signbit_tile_init
-#define ACTIVATION_APPLY_SIGNBIT signbit_tile
-
-#define ACTIVATION_INIT_RSQRT rsqrt_tile_init
-#define ACTIVATION_APPLY_RSQRT rsqrt_tile
-
-#define ACTIVATION_INIT_RELU6 relu_max_tile_init
-#define ACTIVATION_APPLY_RELU6(i) relu_max_tile(i, 0x40c00000u)
-
-#define ACTIVATION_INIT_ERF erf_tile_init
-#define ACTIVATION_APPLY_ERF erf_tile
-
-#define ACTIVATION_INIT_ERFC erfc_tile_init
-#define ACTIVATION_APPLY_ERFC erfc_tile
-
-#define ACTIVATION_INIT_ISINF isinf_tile_init
-#define ACTIVATION_APPLY_ISINF isinf_tile
-
-#define ACTIVATION_INIT_ISPOSINF isposinf_tile_init
-#define ACTIVATION_APPLY_ISPOSINF isposinf_tile
-
-#define ACTIVATION_INIT_ISNEGINF isneginf_tile_init
-#define ACTIVATION_APPLY_ISNEGINF isneginf_tile
-
-#define ACTIVATION_INIT_ISNAN isnan_tile_init
-#define ACTIVATION_APPLY_ISNAN isnan_tile
-
-#define ACTIVATION_INIT_ISFINITE isfinite_tile_init
-#define ACTIVATION_APPLY_ISFINITE isfinite_tile
-
-#define ACTIVATION_INIT_LOGICAL_NOT_UNARY logical_not_unary_tile_init
-#define ACTIVATION_APPLY_LOGICAL_NOT_UNARY logical_not_unary_tile
-
-#define ACTIVATION_INIT_ERFINV erfinv_tile_init
-#define ACTIVATION_APPLY_ERFINV erfinv_tile
-
-#define ACTIVATION_INIT_I0 i0_tile_init
-#define ACTIVATION_APPLY_I0 i0_tile
-
-#define ACTIVATION_INIT_I1 i1_tile_init
-#define ACTIVATION_APPLY_I1 i1_tile
-
-#define ACTIVATION_INIT_SILU silu_tile_init
-#define ACTIVATION_APPLY_SILU silu_tile
-
-#define ACTIVATION_INIT_NEG negative_tile_init
-#define ACTIVATION_APPLY_NEG negative_tile
-
-#define ACTIVATION_INIT_BITWISE_NOT bitwise_not_tile_init
-#define ACTIVATION_APPLY_BITWISE_NOT bitwise_not_tile
-
-#define ACTIVATION_INIT_FLOOR floor_tile_init
-#define ACTIVATION_APPLY_FLOOR floor_tile
-
-#define ACTIVATION_INIT_CEIL ceil_tile_init
-#define ACTIVATION_APPLY_CEIL ceil_tile
-
 #define IS_EMPTY(...) P_CAT(IS_EMPTY_, IS_BEGIN_PARENS(__VA_ARGS__))(__VA_ARGS__)
 #define IS_EMPTY_0(...) IS_BEGIN_PARENS(IS_EMPTY_NON_FUNCTION_C __VA_ARGS__())
 #define IS_EMPTY_1(...) 0
@@ -163,13 +25,6 @@
 #define P_COMPL_0 1
 #define P_COMPL_1 0
 
-#define ACTIVATION_INIT(elem) ACTIVATION_INIT_##elem()
-#define ACTIVATION_APPLY(elem, i) ACTIVATION_APPLY_##elem(i)
-
-#define PROCESS_ACTIVATION(elem, i) \
-    ACTIVATION_INIT(elem);          \
-    ACTIVATION_APPLY(elem, i)
-
 #define PROCESS_ACTIVATIONS(op, i) PROCESS_ACTIVATIONS_(op)(i)
 #define PROCESS_ACTIVATIONS_(op) PROCESS_##op##_ACTIVATIONS
 #define HAS_ACTIVATIONS(op) P_COMPL(IS_EMPTY(PROCESS_ACTIVATIONS(op, 0)))
diff --git a/ttnn/cpp/ttnn/operations/eltwise/unary/common/unary_op_utils.cpp b/ttnn/cpp/ttnn/operations/eltwise/unary/common/unary_op_utils.cpp
index 1e084f04b34..d0e0b2bf222 100644
--- a/ttnn/cpp/ttnn/operations/eltwise/unary/common/unary_op_utils.cpp
+++ b/ttnn/cpp/ttnn/operations/eltwise/unary/common/unary_op_utils.cpp
@@ -12,20 +12,6 @@ using namespace tt::tt_metal;
 namespace ttnn::operations::unary::utils {
 
 namespace {
-union Converter {
-public:
-    float f;
-    uint32_t u;
-
-    Converter(float f_) : f(f_) {};
-
-    static std::string to_hex(float f_) {
-        Converter obj(f_);
-        std::stringstream ss;
-        ss << "0x" << std::hex << obj.u;
-        return ss.str();
-    }
-};
 
 std::string get_macro_definition(UnaryOpType op_type) {
     switch (op_type) {
@@ -90,132 +76,137 @@ std::pair<std::string, std::string> get_op_init_and_func_parameterized(
     switch (op_type) {
         case UnaryOpType::FILL:
             op_init_and_name = {
-                "fill_tile_init();", fmt::format("fill_tile({}, {}u);", idst, Converter::to_hex(param0))};
+                "fill_tile_init();", fmt::format("fill_tile({}, {:#x}u);", idst, std::bit_cast<uint32_t>(param0))};
             break;
         case UnaryOpType::RELU_MAX:
             op_init_and_name = {
-                "relu_max_tile_init();", fmt::format("relu_max_tile({}, {}u);", idst, Converter::to_hex(param0))};
+                "relu_max_tile_init();",
+                fmt::format("relu_max_tile({}, {:#x}u);", idst, std::bit_cast<uint32_t>(param0))};
             break;
         case UnaryOpType::RELU_MIN:
             op_init_and_name = {
-                "relu_min_tile_init();", fmt::format("relu_min_tile({}, {}u);", idst, Converter::to_hex(param0))};
+                "relu_min_tile_init();",
+                fmt::format("relu_min_tile({}, {:#x}u);", idst, std::bit_cast<uint32_t>(param0))};
             break;
         case UnaryOpType::POWER:
-            op_init_and_name = {
-                "power_tile_init();", fmt::format("power_tile({}, {}u);", idst, std::to_string((uint32_t)param0))};
+            op_init_and_name = {"power_tile_init();", fmt::format("power_tile({}, {}u);", idst, (uint32_t)param0)};
             break;
         case UnaryOpType::LEAKY_RELU:
            op_init_and_name = {
-                "leaky_relu_tile_init();", fmt::format("leaky_relu_tile({}, {}u);", idst, Converter::to_hex(param0))};
+                "leaky_relu_tile_init();",
+                fmt::format("leaky_relu_tile({}, {:#x}u);", idst, std::bit_cast<uint32_t>(param0))};
             break;
         case UnaryOpType::ELU:
-            op_init_and_name = {"elu_tile_init();", fmt::format("elu_tile({}, {}u);", idst, Converter::to_hex(param0))};
+            op_init_and_name = {
+                "elu_tile_init();", fmt::format("elu_tile({}, {:#x}u);", idst, std::bit_cast<uint32_t>(param0))};
             break;
         case UnaryOpType::GELU:
             op_init_and_name = {
-                fmt::format("gelu_tile_init<{}u>();", std::to_string((uint32_t)param0)),
-                fmt::format("gelu_tile<{1}u>({0});", idst, std::to_string((uint32_t)param0))};
+                fmt::format("gelu_tile_init<{}u>();", (uint32_t)param0),
fmt::format("gelu_tile<{1}u>({0});", idst, (uint32_t)param0)}; break; case UnaryOpType::RSQRT: op_init_and_name = { - fmt::format("rsqrt_tile_init<{}u>();", std::to_string((uint32_t)param0)), - fmt::format("rsqrt_tile<{1}u>({0});", idst, std::to_string((uint32_t)param0))}; + fmt::format("rsqrt_tile_init<{}u>();", (uint32_t)param0), + fmt::format("rsqrt_tile<{1}u>({0});", idst, (uint32_t)param0)}; break; case UnaryOpType::HEAVISIDE: op_init_and_name = { - "heaviside_tile_init();", fmt::format("heaviside_tile({}, {}u);", idst, Converter::to_hex(param0))}; + "heaviside_tile_init();", + fmt::format("heaviside_tile({}, {:#x}u);", idst, std::bit_cast(param0))}; break; case UnaryOpType::BITWISE_XOR: op_init_and_name = { - "bitwise_xor_tile_init();", - fmt::format("bitwise_xor_tile({}, {}u);", idst, std::to_string((uint)param0))}; + "bitwise_xor_tile_init();", fmt::format("bitwise_xor_tile({}, {}u);", idst, (uint)param0)}; break; case UnaryOpType::BITWISE_AND: op_init_and_name = { - "bitwise_and_tile_init();", - fmt::format("bitwise_and_tile({}, {}u);", idst, std::to_string((uint)param0))}; + "bitwise_and_tile_init();", fmt::format("bitwise_and_tile({}, {}u);", idst, (uint)param0)}; break; case UnaryOpType::BITWISE_OR: op_init_and_name = { - "bitwise_or_tile_init();", - fmt::format("bitwise_or_tile({}, {}u);", idst, std::to_string((uint)param0))}; + "bitwise_or_tile_init();", fmt::format("bitwise_or_tile({}, {}u);", idst, (uint)param0)}; break; case UnaryOpType::RIGHT_SHIFT: op_init_and_name = { - "right_shift_tile_init();", - fmt::format("right_shift_tile({}, {}u);", idst, std::to_string((uint)param0))}; + "right_shift_tile_init();", fmt::format("right_shift_tile({}, {}u);", idst, (uint)param0)}; break; case UnaryOpType::LEFT_SHIFT: op_init_and_name = { - "left_shift_tile_init();", - fmt::format("left_shift_tile({}, {}u);", idst, std::to_string((uint)param0))}; + "left_shift_tile_init();", fmt::format("left_shift_tile({}, {}u);", idst, (uint)param0)}; break; case UnaryOpType::REMAINDER: op_init_and_name = { "remainder_tile_init();", fmt::format( - "remainder_tile({}, {}u, {}u);", + "remainder_tile({}, {:#x}u, {:#x}u);", idst, - Converter::to_hex(param0), - Converter::to_hex(1.0f / param0))}; + std::bit_cast(param0), + std::bit_cast(1.0f / param0))}; break; case UnaryOpType::FMOD: op_init_and_name = { "fmod_tile_init();", fmt::format( - "fmod_tile({}, {}u, {}u);", idst, Converter::to_hex(param0), Converter::to_hex(1.0f / param0))}; + "fmod_tile({}, {:#x}u, {:#x}u);", + idst, + std::bit_cast(param0), + std::bit_cast(1.0f / param0))}; break; case UnaryOpType::EXP: op_init_and_name = { - fmt::format("exp_tile_init<{}u>();", std::to_string((uint32_t)param0)), - fmt::format("exp_tile<{1}u>({0});", idst, std::to_string((uint32_t)param0))}; + fmt::format("exp_tile_init<{}u>();", (uint32_t)param0), + fmt::format("exp_tile<{1}u>({0});", idst, (uint32_t)param0)}; break; case UnaryOpType::ERF: op_init_and_name = { - fmt::format("erf_tile_init<{}u>();", std::to_string((uint32_t)param0)), - fmt::format("erf_tile<{1}u>({0});", idst, std::to_string((uint32_t)param0))}; + fmt::format("erf_tile_init<{}u>();", (uint32_t)param0), + fmt::format("erf_tile<{1}u>({0});", idst, (uint32_t)param0)}; break; case UnaryOpType::ERFC: op_init_and_name = { - fmt::format("erfc_tile_init<{}u>();", std::to_string((uint32_t)param0)), - fmt::format("erfc_tile<{1}u>({0});", idst, std::to_string((uint32_t)param0))}; + fmt::format("erfc_tile_init<{}u>();", (uint32_t)param0), + fmt::format("erfc_tile<{1}u>({0});", idst, 
             break;
         case UnaryOpType::RDIV: op_init_and_name = {}; break;
         case UnaryOpType::RSUB:
             op_init_and_name = {
-                "rsub_tile_init();", fmt::format("rsub_tile({}, {}u);", idst, Converter::to_hex(param0))};
+                "rsub_tile_init();", fmt::format("rsub_tile({}, {:#x}u);", idst, std::bit_cast<uint32_t>(param0))};
             break;
         case UnaryOpType::SUB_UNARY_SFPU:
             op_init_and_name = {
                 "binop_with_scalar_tile_init();",
-                fmt::format("sub_unary_tile({}, {}u);", idst, Converter::to_hex(param0))};
+                fmt::format("sub_unary_tile({}, {:#x}u);", idst, std::bit_cast<uint32_t>(param0))};
             break;
         case UnaryOpType::ADD_UNARY_SFPU:
             op_init_and_name = {
                 "binop_with_scalar_tile_init();",
-                fmt::format("add_unary_tile({}, {}u);", idst, Converter::to_hex(param0))};
+                fmt::format("add_unary_tile({}, {:#x}u);", idst, std::bit_cast<uint32_t>(param0))};
             break;
         case UnaryOpType::MUL_UNARY_SFPU:
             op_init_and_name = {
                 "binop_with_scalar_tile_init();",
-                fmt::format("mul_unary_tile({}, {}u);", idst, Converter::to_hex(param0))};
+                fmt::format("mul_unary_tile({}, {:#x}u);", idst, std::bit_cast<uint32_t>(param0))};
             break;
         case UnaryOpType::DIV_UNARY_SFPU:
             op_init_and_name = {
                 "binop_with_scalar_tile_init();",
-                fmt::format("div_unary_tile({}, {}u);", idst, Converter::to_hex(1.0f / param0))};
+                fmt::format("div_unary_tile({}, {:#x}u);", idst, std::bit_cast<uint32_t>(1.0f / param0))};
             break;
         case UnaryOpType::UNARY_NE:
             op_init_and_name = {
-                "unary_ne_tile_init();", fmt::format("unary_ne_tile({}, {}u);", idst, Converter::to_hex(param0))};
+                "unary_ne_tile_init();",
+                fmt::format("unary_ne_tile({}, {:#x}u);", idst, std::bit_cast<uint32_t>(param0))};
             break;
         case UnaryOpType::UNARY_GT:
             op_init_and_name = {
-                "unary_gt_tile_init();", fmt::format("unary_gt_tile({}, {}u);", idst, Converter::to_hex(param0))};
+                "unary_gt_tile_init();",
+                fmt::format("unary_gt_tile({}, {:#x}u);", idst, std::bit_cast<uint32_t>(param0))};
             break;
         case UnaryOpType::UNARY_LT:
             op_init_and_name = {
-                "unary_lt_tile_init();", fmt::format("unary_lt_tile({}, {}u);", idst, Converter::to_hex(param0))};
+                "unary_lt_tile_init();",
+                fmt::format("unary_lt_tile({}, {:#x}u);", idst, std::bit_cast<uint32_t>(param0))};
             break;
         case UnaryOpType::SOFTPLUS: {
             TT_ASSERT(params.size() == 2, "Expected softplus to take 2 parameters");
@@ -223,16 +214,16 @@ std::pair<std::string, std::string> get_op_init_and_func_parameterized(
             op_init_and_name = {
                 "softplus_tile_init();",
                 fmt::format(
-                    "softplus_tile({}, {}u, {}u, {}u);",
+                    "softplus_tile({}, {:#x}u, {:#x}u, {:#x}u);",
                     idst,
-                    Converter::to_hex(param0),
-                    Converter::to_hex(1.0f / param0),  // Pass reciprocal to avoid doing it on device
-                    Converter::to_hex(param1))};
+                    std::bit_cast<uint32_t>(param0),
+                    std::bit_cast<uint32_t>(1.0f / param0),  // Pass reciprocal to avoid doing it on device
+                    std::bit_cast<uint32_t>(param1))};
             break;
         }
         case UnaryOpType::PRELU_SFPU: {
             op_init_and_name = {
-                "prelu_tile_init();", fmt::format("prelu_tile({}, {}u);", idst, Converter::to_hex(param0))};
+                "prelu_tile_init();", fmt::format("prelu_tile({}, {:#x}u);", idst, std::bit_cast<uint32_t>(param0))};
             break;
         }
         case UnaryOpType::TYPECAST:
@@ -242,10 +233,10 @@ std::pair<std::string, std::string> get_op_init_and_func_parameterized(
                 fmt::format(
                     "typecast_tile<{1}u, {2}u>({0});",
                     idst,
-                    std::to_string((uint32_t)datatype_to_dataformat_converter((DataType)params[0])),
-                    std::to_string((uint32_t)datatype_to_dataformat_converter((DataType)params[1])))};
+                    (uint32_t)datatype_to_dataformat_converter((DataType)params[0]),
+                    (uint32_t)datatype_to_dataformat_converter((DataType)params[1]))};
             break;
-        default: TT_ASSERT(false && "unexpected parameterized type");
+        default: TT_THROW("unexpected parameterized op type {}", op_type);
     };
     return op_init_and_name;
 }
@@ -257,6 +248,8 @@ std::pair<std::string, std::string> get_op_init_and_func_default(UnaryOpType op_type, std::string idst) {
         case UnaryOpType::BITWISE_NOT:
             op_init_and_name = {"bitwise_not_tile_init();", fmt::format("bitwise_not_tile({});", idst)};
             break;
         case UnaryOpType::RECIP: op_init_and_name = {"recip_tile_init();", fmt::format("recip_tile({});", idst)}; break;
+        case UnaryOpType::GELU: op_init_and_name = {"gelu_tile_init();", fmt::format("gelu_tile({});", idst)}; break;
+        case UnaryOpType::RSQRT: op_init_and_name = {"rsqrt_tile_init();", fmt::format("rsqrt_tile({});", idst)}; break;
         case UnaryOpType::RELU: op_init_and_name = {"relu_tile_init();", fmt::format("relu_tile({});", idst)}; break;
         case UnaryOpType::SQRT: op_init_and_name = {"sqrt_tile_init();", fmt::format("sqrt_tile({});", idst)}; break;
         case UnaryOpType::SIGMOID:
@@ -285,6 +278,9 @@ std::pair<std::string, std::string> get_op_init_and_func_default(UnaryOpType op_type, std::string idst) {
             break;
         case UnaryOpType::I0: op_init_and_name = {"i0_tile_init();", fmt::format("i0_tile({});", idst)}; break;
         case UnaryOpType::I1: op_init_and_name = {"i1_tile_init();", fmt::format("i1_tile({});", idst)}; break;
+        case UnaryOpType::EXP: op_init_and_name = {"exp_tile_init();", fmt::format("exp_tile({});", idst)}; break;
+        case UnaryOpType::ERF: op_init_and_name = {"erf_tile_init();", fmt::format("erf_tile({0});", idst)}; break;
+        case UnaryOpType::ERFC: op_init_and_name = {"erfc_tile_init();", fmt::format("erfc_tile({});", idst)}; break;
         case UnaryOpType::ERFINV:
             op_init_and_name = {"erfinv_tile_init();", fmt::format("erfinv_tile({});", idst)};
             break;
@@ -325,9 +321,7 @@ std::pair<std::string, std::string> get_op_init_and_func_default(UnaryOpType op_type, std::string idst) {
         case UnaryOpType::IDENTITY_UINT32:
             op_init_and_name = {"identity_tile_init();", fmt::format("identity_tile_uint32({});", idst)};
             break;
-        case UnaryOpType::FLOOR:
-            op_init_and_name = {"floor_tile_init();", fmt::format("floor_tile({});", idst)};
-            break;
+        case UnaryOpType::FLOOR: op_init_and_name = {"floor_tile_init();", fmt::format("floor_tile({});", idst)}; break;
         case UnaryOpType::FLOOR_FLOAT32:
             op_init_and_name = {"floor_tile_init();", fmt::format("floor_tile_float32({});", idst)};
             break;
@@ -341,7 +335,7 @@ std::pair<std::string, std::string> get_op_init_and_func_default(UnaryOpType op_type, std::string idst) {
         case UnaryOpType::NEG:
             op_init_and_name = {"negative_tile_init();", fmt::format("negative_tile({});", idst)};
             break;
-        default: TT_ASSERT(false && "Undefined non-parametrized op type");
+        default: TT_THROW("Undefined non-parametrized op type {}", op_type);
     }
     return op_init_and_name;
 }
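
Usage sketch (editorial addition, not part of the patch): with this change the
lhs/rhs/post activation lists accept (UnaryOpType, param) pairs instead of bare
op types. The call below mirrors test_binary_scalar_ops_with_unary_param from
the test diff above; a_tt and b_tt are assumed to be bfloat16 device tensors,
e.g. produced by the test helper rand_bf16_gen.

    import ttnn

    # Fused on device: out = (a + b) * 5, with the multiply applied as a
    # parameterized unary post-activation instead of a second op launch.
    out_tt = ttnn.experimental.add(
        a_tt,
        b_tt,
        post_activations=[(ttnn.UnaryOpType.MUL_UNARY_SFPU, 5)],
    )

Each pair is lowered through unary::utils::get_op_init_and_func, so the float
parameter is baked into the generated kernel source as its IEEE-754 bit
pattern, e.g. mul_unary_tile(i, 0x40a00000u) for 5.0f.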