From c6377abc4164603eedcea64c9c214c518ba21012 Mon Sep 17 00:00:00 2001 From: tqchen Date: Sat, 18 Mar 2023 22:05:02 -0400 Subject: [PATCH] [CODEGEN][METAL] Fix ramp codegen Fix ramp node codegen for the Metal backend. The default C codegen can cause problems in vector index assignment. Confirmed on Apple M2. --- src/target/source/codegen_metal.cc | 11 +++++++++++ src/target/source/codegen_metal.h | 1 + 2 files changed, 12 insertions(+) diff --git a/src/target/source/codegen_metal.cc b/src/target/source/codegen_metal.cc index 928d961d50ee..ad9560eef214 100644 --- a/src/target/source/codegen_metal.cc +++ b/src/target/source/codegen_metal.cc @@ -299,6 +299,17 @@ void CodeGenMetal::VisitExpr_(const BroadcastNode* op, std::ostream& os) { // N os << ')'; } +void CodeGenMetal::VisitExpr_(const RampNode* op, std::ostream& os) { // NOLINT(*) + PrintType(op->dtype, os); + os << "("; + for (int i = 0; i < op->lanes; ++i) { + if (i != 0) os << ", "; + os << "(" << PrintExpr(op->base) << ")" + << "+(" << PrintExpr(op->stride) << "*" << i << ")"; + } + os << ')'; +} + void CodeGenMetal::VisitExpr_(const CallNode* op, std::ostream& os) { // NOLINT(*) if (op->op.same_as(builtin::reinterpret())) { // generate as_type(ARG) diff --git a/src/target/source/codegen_metal.h b/src/target/source/codegen_metal.h index 9fb8f80303f9..4e464c6636a8 100644 --- a/src/target/source/codegen_metal.h +++ b/src/target/source/codegen_metal.h @@ -51,6 +51,7 @@ class CodeGenMetal final : public CodeGenC { void PrintVecElemStore(const std::string& vec, DataType t, int i, const std::string& value) final; // overload visitor void VisitExpr_(const BroadcastNode* op, std::ostream& os) final; // NOLINT(*) + void VisitExpr_(const RampNode* op, std::ostream& os) final; // NOLINT(*) void VisitExpr_(const CallNode* op, std::ostream& os) final; // NOLINT(*) void VisitExpr_(const FloatImmNode* op, std::ostream& os) final; // reuse parent's function.