From b7415b2229e5f358ddf96c967e920a383a948413 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Mustafa=20Ebrar=20Akta=C5=9F?=
Date: Mon, 25 Mar 2024 10:37:47 +0200
Subject: [PATCH] feat: add ruy sgemm implementation (#1598)

* feat: add ruy sgemm implementation

* fix: move RUY sgemm implementation below BLAS
---
 src/cpu/backend.cc    |  3 +-
 src/cpu/primitives.cc | 63 +++++++++++++++++++++++++++++++++++++++++++
 2 files changed, 65 insertions(+), 1 deletion(-)

diff --git a/src/cpu/backend.cc b/src/cpu/backend.cc
index b82c62edb..a2e4e8bac 100644
--- a/src/cpu/backend.cc
+++ b/src/cpu/backend.cc
@@ -81,7 +81,8 @@ namespace ctranslate2 {
 #endif
 
 #ifdef CT2_WITH_RUY
-      if (is_int8) {
+      // Ruy now serves float32 GEMM (sgemm) in addition to the int8 kernels.
+      if (is_int8 || compute_type == ComputeType::FLOAT32) {
         return GemmBackend::RUY;
       }
 #endif
diff --git a/src/cpu/primitives.cc b/src/cpu/primitives.cc
index 49bee5b7b..0c6377bbb 100644
--- a/src/cpu/primitives.cc
+++ b/src/cpu/primitives.cc
@@ -707,6 +707,69 @@ namespace ctranslate2 {
     }
 #endif
 
+#ifdef CT2_WITH_RUY
+    // Ruy backend: computes C <- alpha * op(A) * op(B) + beta * C on packed
+    // float buffers, with C stored row-major.
+    case cpu::GemmBackend::RUY: {
+      // The Ruy layouts below are built with MakeSimpleLayout from the matrix
+      // shapes only, so operands with a custom leading dimension are rejected.
+      if (lda != (transpose_a ? m : k)
+          || ldb != (transpose_b ? k : n)
+          || ldc != n)
+        throw std::invalid_argument("Ruy GEMM does not support custom leading dimensions");
+
+      ruy::Context *context = cpu::get_ruy_context();
+
+      // A transposed operand is described to Ruy as a column-major matrix
+      // over the same buffer instead of being physically transposed.
+      const ruy::Order order_a = transpose_a ? ruy::Order::kColMajor : ruy::Order::kRowMajor;
+      const ruy::Order order_b = transpose_b ? ruy::Order::kColMajor : ruy::Order::kRowMajor;
+
+      ruy::Matrix<float> lhs;
+      ruy::MakeSimpleLayout(m, k, order_a, lhs.mutable_layout());
+      lhs.set_data(a);
+
+      ruy::Matrix<float> rhs;
+      ruy::MakeSimpleLayout(k, n, order_b, rhs.mutable_layout());
+      rhs.set_data(b);
+
+      ruy::Matrix<float> dst;
+      ruy::MakeSimpleLayout(m, n, ruy::Order::kRowMajor, dst.mutable_layout());
+      dst.set_data(c);
+
+      // MulParams::set_bias() takes a per-channel vector, not an m x n
+      // matrix, so it cannot carry the `beta * C` term; that term is
+      // accumulated explicitly below and the parameters stay at defaults.
+      ruy::MulParams<float, float> mul_params;
+
+      // ruy::Mul writes its result into `c`, so keep `beta * C` aside first.
+      // Scaling by `beta` directly (instead of `beta / alpha`) also keeps
+      // the result well defined when alpha == 0.
+      float *scaled_c = nullptr;
+      if (beta != 0.0f) {
+        scaled_c = static_cast<float*>(allocator.allocate(m * n * sizeof (float)));
+        mul(beta, c, scaled_c, m * n);
+      }
+
+      // C <- op(A) * op(B)
+      ruy::Mul(lhs, rhs, mul_params, context, &dst);
+
+      // C <- alpha * op(A) * op(B)
+      if (alpha != 1.0f) {
+        mul(alpha, c, m * n);
+      }
+
+      // C <- alpha * op(A) * op(B) + beta * C_in
+      if (scaled_c) {
+        const float *src = scaled_c;
+        for (float *out = c, *end = c + m * n; out != end; ++out, ++src)
+          *out += *src;
+        allocator.free(scaled_c);
+      }
+      break;
+    }
+#endif
+
     default:
       throw std::runtime_error("No SGEMM backend on CPU");
     }