Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SVE] Add codegen support for vscale_range() function attribute #16962

Merged
merged 2 commits into from
May 8, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions src/target/llvm/codegen_llvm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,7 @@
#include <vector>

#include "../../arith/pattern_match.h"
#include "../../arith/scalable_expression.h"
#include "../build_common.h"
#include "../func_registry_generator.h"
#include "codegen_params.h"
Expand Down Expand Up @@ -1127,6 +1128,13 @@ void CodeGenLLVM::SetTargetAttributes(llvm::Function* func) {
if (!features.empty()) {
func->addFnAttr("target-features", features);
}
#if TVM_LLVM_VERSION >= 130
// Add vscale_range() function attribute when appropriate.
if (llvm_target_->TargetHasCPUFeature("sve") || llvm_target_->TargetHasCPUFeature("sme")) {
lhutton1 marked this conversation as resolved.
Show resolved Hide resolved
Anndrey24 marked this conversation as resolved.
Show resolved Hide resolved
func->addFnAttr(llvm::Attribute::getWithVScaleRangeArgs(
*llvm_target_->GetContext(), 1, tvm::arith::kAArch64VScaleValues.size()));
}
#endif
}

void CodeGenLLVM::EmitFloat16ConversionBuiltins(bool use_float16_abi) {
Expand Down
2 changes: 1 addition & 1 deletion src/target/llvm/codegen_llvm.h
Original file line number Diff line number Diff line change
Expand Up @@ -427,7 +427,7 @@ class CodeGenLLVM : public ExprFunctor<llvm::Value*(const PrimExpr&)>,
llvm::ArrayRef<llvm::Type*> arg_types);
/*!
* \brief Set target-related attributes on the LLVM function \p func. This
* includes "target-cpu" and "target-features" if present.
* includes "target-cpu", "target-features" and "vscale_range()" if present.
*
* \param func The function to set attributes on.
*/
Expand Down
38 changes: 38 additions & 0 deletions tests/python/codegen/test_target_codegen_aarch64.py
Original file line number Diff line number Diff line change
Expand Up @@ -537,6 +537,44 @@ def my_func(a: T.handle):
assert re.findall(r" store <vscale x 4 x float>", llvm), "No scalable store in generated LLVM."


@pytest.mark.skipif(
llvm_version_major() < 13,
reason="Function attribute vscale_range() is not supported in earlier versions of LLVM",
)
@pytest.mark.parametrize(
"mattr,expect_attr",
[
("+neon", False),
("+sve", True),
("+v9a", True),
("+sme", True),
],
)
def test_vscale_range_function_attribute(mattr, expect_attr):
target = f"llvm -mtriple=aarch64-linux-gnu -mattr={mattr}"

m = te.var("m")
A = te.placeholder(m, dtype="float32", name="A")
C = te.compute((m), lambda i: A[i] + 1, name="C")
s = te.create_schedule([C.op])

with tvm.target.Target(target) as target:
f = tvm.build(s, [A, C], target)

# Check if the vscale_range() attribute exists
ll = f.get_source("ll")
attr = re.findall(rf".*vscale_range\(\d+,\d+\)*.", ll)

if expect_attr:
assert (
len(attr) > 0
), f"Function attribute vscale_range() was not found in generated LLVM IR"
else:
assert (
len(attr) == 0
), f"Unexpected function attribute vscale_range() was found in generated LLVM IR"


@pytest.mark.skipif(
llvm_version_major() < 16, reason="Test requires an LLVM version of at least 16 to target SME"
)
Expand Down
Loading