#16502: Add Unary with params support to BinaryNg
patrickroberts committed Jan 24, 2025
1 parent 080f063 commit a568827
Showing 12 changed files with 198 additions and 297 deletions.
54 changes: 52 additions & 2 deletions tests/ttnn/unit_tests/operations/eltwise/test_binary_bcast.py
@@ -154,8 +154,7 @@ def test_binary_scalar_ops(a_shape, b_shape, ttnn_fn, activations, device):
    a_pt, a_tt = rand_bf16_gen(a_shape, device)
    b_pt, b_tt = rand_bf16_gen(b_shape, device, min=min, max=max)

-    cq_id = 0
-    out_tt = ttnn_op(a_tt, b_tt, queue_id=cq_id, lhs_activations=lhs, rhs_activations=rhs, post_activations=post)
+    out_tt = ttnn_op(a_tt, b_tt, lhs_activations=lhs, rhs_activations=rhs, post_activations=post)

    for golden_activation in golden_lhs:
        a_pt = golden_activation(a_pt).bfloat16()
@@ -179,6 +178,57 @@ def compare(tt, pt):
    assert compare([out_tt], [out_pt])


+activation_with_param_fns = {
+    "ADD_UNARY_SFPU": torch.add,
+    "SUB_UNARY_SFPU": torch.sub,
+    "MUL_UNARY_SFPU": torch.mul,
+    "DIV_UNARY_SFPU": torch.div,
+    "POWER": torch.pow,
+}
+
+
+@pytest.mark.parametrize(
+    "a_shape, b_shape",
+    (
+        (torch.Size([1, 1, 1, 1]), torch.Size([5, 3, 32, 32])),
+        (torch.Size([5, 1, 64, 1]), torch.Size([1, 3, 1, 128])),
+        (torch.Size([5, 1, 1, 64]), torch.Size([1, 3, 128, 1])),
+    ),
+)
+@pytest.mark.parametrize("ttnn_fn", ("add", "sub", "mul", "div"))
+@pytest.mark.parametrize(
+    "post_activations",
+    (
+        (),
+        (("ADD_UNARY_SFPU", 7),),
+        (("SUB_UNARY_SFPU", 6),),
+        (("MUL_UNARY_SFPU", 5),),
+        (("DIV_UNARY_SFPU", 4),),
+        (("POWER", 3),),
+    ),
+)
+def test_binary_scalar_ops_with_unary_param(a_shape, b_shape, ttnn_fn, post_activations, device):
+    torch.manual_seed(0)
+    ttnn_op = getattr(ttnn.experimental, ttnn_fn)
+    post = [(getattr(ttnn.UnaryOpType, op), param) for op, param in post_activations]
+    golden_post = ((lambda x: activation_with_param_fns[op](x, param)) for op, param in post_activations)
+    # make 0 exclusive for rhs of div
+    min, max = (1, 0) if ttnn_fn == "div" else (0, 1)
+
+    a_pt, a_tt = rand_bf16_gen(a_shape, device)
+    b_pt, b_tt = rand_bf16_gen(b_shape, device, min=min, max=max)
+
+    out_tt = ttnn_op(a_tt, b_tt, post_activations=post)
+
+    golden_fn = ttnn.get_golden_function(ttnn_op)
+    out_pt = golden_fn(a_pt, b_pt).bfloat16()
+
+    for golden_activation in golden_post:
+        out_pt = golden_activation(out_pt).bfloat16()
+
+    assert compare_pcc([out_tt], [out_pt])
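In short, the new test fuses a parameterized unary into the binary op on device and checks it against the unfused torch expression on host. A minimal standalone sketch of the same pattern (the shapes and the (op, param) tuple form are taken from the test above; the open device handle is an assumption):

import torch
import ttnn

# Assumes an already-open device handle, e.g. device = ttnn.open_device(device_id=0)
a_pt = torch.rand(1, 1, 32, 32, dtype=torch.bfloat16)
b_pt = torch.rand(1, 1, 32, 32, dtype=torch.bfloat16)
a_tt = ttnn.from_torch(a_pt, layout=ttnn.TILE_LAYOUT, device=device)
b_tt = ttnn.from_torch(b_pt, layout=ttnn.TILE_LAYOUT, device=device)

# Fused on device: the POWER unary carries its parameter (3) and runs as a
# post-activation of the binary add, i.e. (a + b) ** 3 in a single op.
out_tt = ttnn.experimental.add(a_tt, b_tt, post_activations=[(ttnn.UnaryOpType.POWER, 3)])

# Unfused golden on host, mirroring activation_with_param_fns["POWER"].
out_pt = torch.pow(a_pt + b_pt, 3)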


@pytest.mark.parametrize(
    "a_shape, b_shape",
    (
@@ -107,7 +107,7 @@ std::map<std::string, std::string> get_defines(
            op_binary_type = "EltwiseBinaryType::ELWADD";
            defines.merge(get_defines(UnaryOpType::LOG2, std::nullopt, "0", idst));
            break;
-        default: TT_ASSERT(false && "Undefined op type");
+        default: TT_THROW("Undefined op type {}", op_type);
    }

    using DataType = tt::tt_metal::DataType;
@@ -138,10 +138,10 @@ std::map<std::string, std::string> get_defines(
        (input_dtype.value() == DataType::UINT16 && output_dtype.value() == DataType::BFLOAT4_B) ||
        (input_dtype.value() == DataType::BFLOAT4_B && output_dtype.value() == DataType::INT32) ||
        (input_dtype.value() == DataType::INT32 && output_dtype.value() == DataType::BFLOAT4_B))) {
-        TT_ASSERT(defines.count("SFPU_OP_CHAIN_0") == 0 && "SFPU_OP_CHAIN_0 already defined");
+        TT_ASSERT(defines.count("SFPU_OP_CHAIN_0") == 0, "SFPU_OP_CHAIN_0 already defined");

-        auto in_dataformat = std::to_string((uint32_t)datatype_to_dataformat_converter(input_dtype.value()));
-        auto out_dataformat = std::to_string((uint32_t)datatype_to_dataformat_converter(output_dtype.value()));
+        auto in_dataformat = (uint32_t)datatype_to_dataformat_converter(input_dtype.value());
+        auto out_dataformat = (uint32_t)datatype_to_dataformat_converter(output_dtype.value());
        defines.insert(
            {"SFPU_OP_CHAIN_0",
             fmt::format("typecast_tile_init(); typecast_tile<{0}u, {1}u>(i);", in_dataformat, out_dataformat)});
36 changes: 18 additions & 18 deletions ttnn/cpp/ttnn/operations/eltwise/binary_ng/binary_ng.cpp
@@ -24,9 +24,9 @@ Tensor BinaryNg<binary_op_type>::invoke(
    const std::optional<const DataType>& output_dtype,
    const std::optional<MemoryConfig>& memory_config,
    std::optional<Tensor> optional_output_tensor,
-    tt::stl::Span<const ttnn::operations::unary::UnaryOpType> lhs_activations,
-    tt::stl::Span<const ttnn::operations::unary::UnaryOpType> rhs_activations,
-    tt::stl::Span<const ttnn::operations::unary::UnaryOpType> post_activations) {
+    tt::stl::Span<const ttnn::operations::unary::UnaryWithParam> lhs_activations,
+    tt::stl::Span<const ttnn::operations::unary::UnaryWithParam> rhs_activations,
+    tt::stl::Span<const ttnn::operations::unary::UnaryWithParam> post_activations) {
    bool typecast_a = needs_typecast_to_bfloat16(input_tensor_a);
    bool typecast_b = needs_typecast_to_bfloat16(input_tensor_b);
    Tensor input_a = typecast_a ? typecast_to(DataType::BFLOAT16, input_tensor_a) : input_tensor_a;
@@ -52,9 +52,9 @@ Tensor BinaryNg<binary_op_type>::invoke(
    const std::optional<const DataType>& output_dtype,
    const std::optional<MemoryConfig>& memory_config,
    std::optional<Tensor> optional_output_tensor,
-    tt::stl::Span<const ttnn::operations::unary::UnaryOpType> lhs_activations,
-    tt::stl::Span<const ttnn::operations::unary::UnaryOpType> rhs_activations,
-    tt::stl::Span<const ttnn::operations::unary::UnaryOpType> post_activations) {
+    tt::stl::Span<const ttnn::operations::unary::UnaryWithParam> lhs_activations,
+    tt::stl::Span<const ttnn::operations::unary::UnaryWithParam> rhs_activations,
+    tt::stl::Span<const ttnn::operations::unary::UnaryWithParam> post_activations) {
    return invoke(
        DefaultQueueId,
        input_tensor_a,
@@ -75,9 +75,9 @@ Tensor BinaryNg<binary_op_type>::invoke(
    const std::optional<const DataType>& output_dtype,
    const std::optional<MemoryConfig>& memory_config,
    std::optional<Tensor> optional_output_tensor,
-    tt::stl::Span<const ttnn::operations::unary::UnaryOpType> lhs_activations,
-    tt::stl::Span<const ttnn::operations::unary::UnaryOpType> rhs_activations,
-    tt::stl::Span<const ttnn::operations::unary::UnaryOpType> post_activations) {
+    tt::stl::Span<const ttnn::operations::unary::UnaryWithParam> lhs_activations,
+    tt::stl::Span<const ttnn::operations::unary::UnaryWithParam> rhs_activations,
+    tt::stl::Span<const ttnn::operations::unary::UnaryWithParam> post_activations) {
    bool typecast_a = needs_typecast_to_bfloat16(input_tensor_a);
    Tensor input_a = typecast_a ? typecast_to(DataType::BFLOAT16, input_tensor_a) : input_tensor_a;

@@ -101,9 +101,9 @@ Tensor BinaryNg<binary_op_type>::invoke(
    const std::optional<const DataType>& output_dtype,
    const std::optional<MemoryConfig>& memory_config,
    std::optional<Tensor> optional_output_tensor,
-    tt::stl::Span<const ttnn::operations::unary::UnaryOpType> lhs_activations,
-    tt::stl::Span<const ttnn::operations::unary::UnaryOpType> rhs_activations,
-    tt::stl::Span<const ttnn::operations::unary::UnaryOpType> post_activations) {
+    tt::stl::Span<const ttnn::operations::unary::UnaryWithParam> lhs_activations,
+    tt::stl::Span<const ttnn::operations::unary::UnaryWithParam> rhs_activations,
+    tt::stl::Span<const ttnn::operations::unary::UnaryWithParam> post_activations) {
    return invoke(
        DefaultQueueId,
        input_tensor_a,
@@ -127,9 +127,9 @@ Tensor BinaryNgBitwise<binary_op_type>::invoke(
        input_tensor_a.get_dtype() == DataType::INT32 && input_tensor_b.get_dtype() == DataType::INT32,
        "Bitwise ops require input tensors to be of INT32 datatype ");

-    tt::stl::Span<const unary::UnaryOpType> lhs_activations = {};
-    tt::stl::Span<const unary::UnaryOpType> rhs_activations = {};
-    tt::stl::Span<const unary::UnaryOpType> post_activations = {};
+    tt::stl::Span<const unary::UnaryWithParam> lhs_activations = {};
+    tt::stl::Span<const unary::UnaryWithParam> rhs_activations = {};
+    tt::stl::Span<const unary::UnaryWithParam> post_activations = {};

    return ttnn::prim::binary_ng(
        queue_id,
@@ -164,9 +164,9 @@ Tensor BinaryNgBitwise<binary_op_type>::invoke(
    TT_FATAL(
        input_tensor_a.get_dtype() == DataType::INT32, "Bitwise ops require input tensor to be of INT32 datatype ");

-    tt::stl::Span<const unary::UnaryOpType> lhs_activations = {};
-    tt::stl::Span<const unary::UnaryOpType> rhs_activations = {};
-    tt::stl::Span<const unary::UnaryOpType> post_activations = {};
+    tt::stl::Span<const unary::UnaryWithParam> lhs_activations = {};
+    tt::stl::Span<const unary::UnaryWithParam> rhs_activations = {};
+    tt::stl::Span<const unary::UnaryWithParam> post_activations = {};

    return ttnn::prim::binary_ng(
        queue_id,
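The net effect of these signature changes is that all host-side entry points now accept UnaryWithParam spans, so parameterless and parameterized activations can be mixed in one call. A hedged Python sketch (the SILU choice is illustrative; passing bare ttnn.UnaryOpType values for parameterless activations assumes the implicit conversion exercised by the earlier test still applies):

# Roughly silu(a) * silu(b) + 7, fused into a single binary_ng launch.
out_tt = ttnn.experimental.mul(
    a_tt,
    b_tt,
    lhs_activations=[ttnn.UnaryOpType.SILU],  # parameterless, bare op type (assumed convertible)
    rhs_activations=[ttnn.UnaryOpType.SILU],
    post_activations=[(ttnn.UnaryOpType.ADD_UNARY_SFPU, 7)],  # parameterized (op, param) pair
)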
24 changes: 12 additions & 12 deletions ttnn/cpp/ttnn/operations/eltwise/binary_ng/binary_ng.hpp
@@ -21,19 +21,19 @@ struct BinaryNg {
        const std::optional<const DataType>& output_dtype = std::nullopt,
        const std::optional<MemoryConfig>& memory_config = std::nullopt,
        std::optional<Tensor> optional_output_tensor = std::nullopt,
-        tt::stl::Span<const unary::UnaryOpType> lhs_activations = {},
-        tt::stl::Span<const unary::UnaryOpType> rhs_activations = {},
-        tt::stl::Span<const unary::UnaryOpType> post_activations = {});
+        tt::stl::Span<const unary::UnaryWithParam> lhs_activations = {},
+        tt::stl::Span<const unary::UnaryWithParam> rhs_activations = {},
+        tt::stl::Span<const unary::UnaryWithParam> post_activations = {});

    static Tensor invoke(
        const Tensor& input_tensor_a,
        const Tensor& input_tensor_b,
        const std::optional<const DataType>& output_dtype = std::nullopt,
        const std::optional<MemoryConfig>& memory_config = std::nullopt,
        std::optional<Tensor> optional_output_tensor = std::nullopt,
-        tt::stl::Span<const unary::UnaryOpType> lhs_activations = {},
-        tt::stl::Span<const unary::UnaryOpType> rhs_activations = {},
-        tt::stl::Span<const unary::UnaryOpType> post_activations = {});
+        tt::stl::Span<const unary::UnaryWithParam> lhs_activations = {},
+        tt::stl::Span<const unary::UnaryWithParam> rhs_activations = {},
+        tt::stl::Span<const unary::UnaryWithParam> post_activations = {});

    static Tensor invoke(
        uint8_t queue_id,
@@ -42,19 +42,19 @@ struct BinaryNg {
        const std::optional<const DataType>& output_dtype = std::nullopt,
        const std::optional<MemoryConfig>& memory_config = std::nullopt,
        std::optional<Tensor> optional_output_tensor = std::nullopt,
-        tt::stl::Span<const unary::UnaryOpType> lhs_activations = {},
-        tt::stl::Span<const unary::UnaryOpType> rhs_activations = {},
-        tt::stl::Span<const unary::UnaryOpType> post_activations = {});
+        tt::stl::Span<const unary::UnaryWithParam> lhs_activations = {},
+        tt::stl::Span<const unary::UnaryWithParam> rhs_activations = {},
+        tt::stl::Span<const unary::UnaryWithParam> post_activations = {});

    static Tensor invoke(
        const Tensor& input_tensor_a,
        float scalar,
        const std::optional<const DataType>& output_dtype = std::nullopt,
        const std::optional<MemoryConfig>& memory_config = std::nullopt,
        std::optional<Tensor> optional_output_tensor = std::nullopt,
-        tt::stl::Span<const unary::UnaryOpType> lhs_activations = {},
-        tt::stl::Span<const unary::UnaryOpType> rhs_activations = {},
-        tt::stl::Span<const unary::UnaryOpType> post_activations = {});
+        tt::stl::Span<const unary::UnaryWithParam> lhs_activations = {},
+        tt::stl::Span<const unary::UnaryWithParam> rhs_activations = {},
+        tt::stl::Span<const unary::UnaryWithParam> post_activations = {});
};

template <BinaryOpType binary_op_type>
24 changes: 12 additions & 12 deletions ttnn/cpp/ttnn/operations/eltwise/binary_ng/binary_ng_pybind.cpp
@@ -24,9 +24,9 @@ void bind_binary_ng_operation(py::module& module, T op, const std::string& docst
            const std::optional<const DataType>& dtype,
            const std::optional<ttnn::MemoryConfig>& memory_config,
            const std::optional<ttnn::Tensor>& output_tensor,
-            const ttnn::SmallVector<unary::UnaryOpType>& lhs_activations,
-            const ttnn::SmallVector<unary::UnaryOpType>& rhs_activations,
-            const ttnn::SmallVector<unary::UnaryOpType>& post_activations,
+            const ttnn::SmallVector<unary::UnaryWithParam>& lhs_activations,
+            const ttnn::SmallVector<unary::UnaryWithParam>& rhs_activations,
+            const ttnn::SmallVector<unary::UnaryWithParam>& post_activations,
            const uint8_t& queue_id) -> ttnn::Tensor {
                return self(
                    queue_id,
@@ -45,9 +45,9 @@ void bind_binary_ng_operation(py::module& module, T op, const std::string& docst
py::arg("dtype") = std::nullopt,
py::arg("memory_config") = std::nullopt,
py::arg("output_tensor") = std::nullopt,
py::arg("lhs_activations") = ttnn::SmallVector<unary::UnaryOpType>(),
py::arg("rhs_activations") = ttnn::SmallVector<unary::UnaryOpType>(),
py::arg("post_activations") = ttnn::SmallVector<unary::UnaryOpType>(),
py::arg("lhs_activations") = ttnn::SmallVector<unary::UnaryWithParam>(),
py::arg("rhs_activations") = ttnn::SmallVector<unary::UnaryWithParam>(),
py::arg("post_activations") = ttnn::SmallVector<unary::UnaryWithParam>(),
py::arg("queue_id") = 0},

// tensor and tensor
@@ -58,9 +58,9 @@ void bind_binary_ng_operation(py::module& module, T op, const std::string& docst
            const std::optional<const DataType>& dtype,
            const std::optional<ttnn::MemoryConfig>& memory_config,
            const std::optional<ttnn::Tensor>& output_tensor,
-            const ttnn::SmallVector<unary::UnaryOpType>& lhs_activations,
-            const ttnn::SmallVector<unary::UnaryOpType>& rhs_activations,
-            const ttnn::SmallVector<unary::UnaryOpType>& post_activations,
+            const ttnn::SmallVector<unary::UnaryWithParam>& lhs_activations,
+            const ttnn::SmallVector<unary::UnaryWithParam>& rhs_activations,
+            const ttnn::SmallVector<unary::UnaryWithParam>& post_activations,
            uint8_t queue_id) -> ttnn::Tensor {
                return self(
                    queue_id,
@@ -79,9 +79,9 @@ void bind_binary_ng_operation(py::module& module, T op, const std::string& docst
py::arg("dtype") = std::nullopt,
py::arg("memory_config") = std::nullopt,
py::arg("output_tensor") = std::nullopt,
py::arg("lhs_activations") = ttnn::SmallVector<unary::UnaryOpType>(),
py::arg("rhs_activations") = ttnn::SmallVector<unary::UnaryOpType>(),
py::arg("post_activations") = ttnn::SmallVector<unary::UnaryOpType>(),
py::arg("lhs_activations") = ttnn::SmallVector<unary::UnaryWithParam>(),
py::arg("rhs_activations") = ttnn::SmallVector<unary::UnaryWithParam>(),
py::arg("post_activations") = ttnn::SmallVector<unary::UnaryWithParam>(),
py::arg("queue_id") = 0});
}
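Both bindings default queue_id to 0 and the activation lists to empty, so the tensor-scalar overload can be driven the same way from Python. A small hedged sketch (the scalar value and the activation pairing are illustrative, not from this commit):

# Tensor-scalar overload: computes roughly (a * 2.0) - 6 in one fused call.
# queue_id may be omitted; it defaults to 0 per the binding above.
out_tt = ttnn.experimental.mul(
    a_tt,
    2.0,
    post_activations=[(ttnn.UnaryOpType.SUB_UNARY_SFPU, 6)],
)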

@@ -360,9 +360,9 @@ BinaryNgDeviceOperation::invoke(
    const std::optional<const DataType>& output_dtype,
    const std::optional<MemoryConfig>& memory_config,
    std::optional<Tensor> output_tensor,
-    tt::stl::Span<const ttnn::operations::unary::UnaryOpType> lhs_activations,
-    tt::stl::Span<const ttnn::operations::unary::UnaryOpType> rhs_activations,
-    tt::stl::Span<const ttnn::operations::unary::UnaryOpType> post_activations) {
+    tt::stl::Span<const ttnn::operations::unary::UnaryWithParam> lhs_activations,
+    tt::stl::Span<const ttnn::operations::unary::UnaryWithParam> rhs_activations,
+    tt::stl::Span<const ttnn::operations::unary::UnaryWithParam> post_activations) {
    auto subtile_broadcast_type = get_subtile_broadcast_type(
        input_tensor_a.get_logical_shape()[-2],
        input_tensor_a.get_logical_shape()[-1],
@@ -400,9 +400,9 @@ BinaryNgDeviceOperation::invoke(
    const std::optional<const DataType>& output_dtype,
    const std::optional<MemoryConfig>& memory_config,
    std::optional<Tensor> output_tensor,
-    tt::stl::Span<const unary::UnaryOpType> lhs_activations,
-    tt::stl::Span<const unary::UnaryOpType> rhs_activations,
-    tt::stl::Span<const unary::UnaryOpType> post_activations) {
+    tt::stl::Span<const unary::UnaryWithParam> lhs_activations,
+    tt::stl::Span<const unary::UnaryWithParam> rhs_activations,
+    tt::stl::Span<const unary::UnaryWithParam> post_activations) {
    DataType dtype1 = input_tensor_a.get_dtype();
    bool device_check = input_tensor_a.device()->arch() != tt::ARCH::GRAYSKULL;
    bool is_sfpu_op = (utils::is_binary_sfpu_op(binary_op_type, dtype1, dtype1) && device_check);
@@ -31,9 +31,9 @@ struct BinaryNgDeviceOperation {

    struct operation_attributes_t {
        BinaryOpType binary_op_type;
-        ttnn::SmallVector<unary::UnaryOpType> lhs_activations;
-        ttnn::SmallVector<unary::UnaryOpType> rhs_activations;
-        ttnn::SmallVector<unary::UnaryOpType> post_activations;
+        ttnn::SmallVector<unary::UnaryWithParam> lhs_activations;
+        ttnn::SmallVector<unary::UnaryWithParam> rhs_activations;
+        ttnn::SmallVector<unary::UnaryWithParam> post_activations;
        std::optional<float> scalar;
        tt::tt_metal::MemoryConfig memory_config;
        DataType input_dtype;
@@ -92,9 +92,9 @@ struct BinaryNgDeviceOperation {
        const std::optional<const DataType>& output_dtype,
        const std::optional<MemoryConfig>& memory_config,
        std::optional<Tensor> optional_output_tensor,
-        tt::stl::Span<const unary::UnaryOpType> lhs_activations,
-        tt::stl::Span<const unary::UnaryOpType> rhs_activations,
-        tt::stl::Span<const unary::UnaryOpType> post_activations);
+        tt::stl::Span<const unary::UnaryWithParam> lhs_activations,
+        tt::stl::Span<const unary::UnaryWithParam> rhs_activations,
+        tt::stl::Span<const unary::UnaryWithParam> post_activations);

    // tensor-scalar invocation
    static std::tuple<operation_attributes_t, tensor_args_t> invoke(
@@ -104,9 +104,9 @@ struct BinaryNgDeviceOperation {
        const std::optional<const DataType>& output_dtype,
        const std::optional<MemoryConfig>& memory_config,
        std::optional<Tensor> optional_output_tensor,
-        tt::stl::Span<const unary::UnaryOpType> lhs_activations,
-        tt::stl::Span<const unary::UnaryOpType> rhs_activations,
-        tt::stl::Span<const unary::UnaryOpType> post_activations);
+        tt::stl::Span<const unary::UnaryWithParam> lhs_activations,
+        tt::stl::Span<const unary::UnaryWithParam> rhs_activations,
+        tt::stl::Span<const unary::UnaryWithParam> post_activations);
};

} // namespace ttnn::operations::binary_ng