diff --git a/src/target/llvm/codegen_aarch64.cc b/src/target/llvm/codegen_aarch64.cc
index 94ad34bbcff2..785c45457e60 100644
--- a/src/target/llvm/codegen_aarch64.cc
+++ b/src/target/llvm/codegen_aarch64.cc
@@ -27,6 +27,7 @@
 #include
 #include
 
+#include "../../arith/scalable_expression.h"
 #include "codegen_cpu.h"
 #include "llvm_instance.h"
 
@@ -40,6 +41,7 @@ class CodeGenAArch64 final : public CodeGenCPU {
 
   void VisitStmt_(const AttrStmtNode* op);
   void AddFunction(const GlobalVar& gvar, const PrimFunc& f);
+  void SetTargetAttributes(llvm::Function* func) override;
 
   bool func_has_pstate_sm = false;
   bool func_has_pstate_za = false;
@@ -51,6 +53,27 @@ void CodeGenAArch64::AddFunction(const GlobalVar& gvar, const PrimFunc& f) {
   CodeGenCPU::AddFunction(gvar, f);
 }
 
+/*!
+ * \brief Set target attributes on a function, adding vscale_range() where applicable.
+ *
+ * When built with TVM_LLVM_VERSION >= 130 and the target has the "sve" or "sme"
+ * CPU feature, attaches vscale_range(1, N) to the function, where N is the number
+ * of entries in tvm::arith::kAArch64VScaleValues. Always finishes by delegating to
+ * CodeGenCPU::SetTargetAttributes.
+ *
+ * \param func The function to set attributes on.
+ */
+void CodeGenAArch64::SetTargetAttributes(llvm::Function* func) {
+#if TVM_LLVM_VERSION >= 130
+  // Add vscale_range() function attribute when appropriate.
+  if (llvm_target_->TargetHasCPUFeature("sve") || llvm_target_->TargetHasCPUFeature("sme")) {
+    func->addFnAttr(llvm::Attribute::getWithVScaleRangeArgs(
+        *llvm_target_->GetContext(), 1, tvm::arith::kAArch64VScaleValues.size()));
+  }
+#endif
+  CodeGenCPU::SetTargetAttributes(func);
+}
+
 /*!
  * \brief Visit and handle AArch64 specific pragmas. To be AArch64 specific,
  * the expectation is that they are prepended with "pragma_aarch64".
diff --git a/src/target/llvm/codegen_llvm.h b/src/target/llvm/codegen_llvm.h
index 0f7aa847ecb8..d46ab7320bf1 100644
--- a/src/target/llvm/codegen_llvm.h
+++ b/src/target/llvm/codegen_llvm.h
@@ -431,7 +431,7 @@ class CodeGenLLVM : public ExprFunctor,
    *
    * \param func The function to set attributes on.
    */
-  void SetTargetAttributes(llvm::Function* func);
+  virtual void SetTargetAttributes(llvm::Function* func);
   /*!
    * \brief Emit LLVM IR for conversion functions __extendhfsf2 and __truncsfhf2
    *   into the current llvm::Module.
diff --git a/tests/python/codegen/test_target_codegen_aarch64.py b/tests/python/codegen/test_target_codegen_aarch64.py
index 8f22ba5b73ed..6b8bb2d9d58b 100644
--- a/tests/python/codegen/test_target_codegen_aarch64.py
+++ b/tests/python/codegen/test_target_codegen_aarch64.py
@@ -537,6 +537,43 @@ def my_func(a: T.handle):
     assert re.findall(r" store ", llvm), "No scalable store in generated LLVM."
 
 
+@pytest.mark.skipif(
+    llvm_version_major() < 13,
+    reason="Function attribute vscale_range() is not supported in earlier versions of LLVM",
+)
+@pytest.mark.parametrize(
+    "mattr,expect_attr",
+    [
+        ("+neon", False),
+        ("+sve", True),
+        ("+v9a", True),
+        ("+sme", True),
+    ],
+)
+def test_vscale_range_function_attribute(mattr, expect_attr):
+    """vscale_range() must appear in the LLVM IR for +sve, +v9a and +sme, but not for +neon."""
+    target = f"llvm -mtriple=aarch64-linux-gnu -mattr={mattr}"
+
+    m = te.var("m")
+    A = te.placeholder(m, dtype="float32", name="A")
+    C = te.compute((m,), lambda i: A[i] + 1, name="C")
+    s = te.create_schedule([C.op])
+
+    with tvm.target.Target(target) as target:
+        f = tvm.build(s, [A, C], target)
+
+    # Check if the vscale_range() attribute exists
+    ll = f.get_source("ll")
+    attr = re.findall(r"vscale_range\(\d+,\s*\d+\)", ll)
+
+    if expect_attr:
+        assert len(attr) > 0, "Function attribute vscale_range() was not found in generated LLVM IR"
+    else:
+        assert (
+            len(attr) == 0
+        ), "Unexpected function attribute vscale_range() was found in generated LLVM IR"
+
+
 @pytest.mark.skipif(
     llvm_version_major() < 16, reason="Test requires an LLVM version of at least 16 to target SME"
 )