From 7556cd09db08a4daa602a787e84539386e04bd47 Mon Sep 17 00:00:00 2001 From: Luke Hutton Date: Fri, 8 Mar 2024 18:33:19 +0000 Subject: [PATCH] [SVE] Add codegen support for scalable buffer accesses This commit adds support for generating code for scalable loads and stores. It also adds support for the creation of scalable broadcast operations. Co-authored-by: Elen Kalda Co-authored-by: Neil Hickey Change-Id: Id4600a2d4537f5260f4a7dc7ed430df6b8e53eb3 --- include/tvm/runtime/data_type.h | 16 ++- python/tvm/testing/utils.py | 7 + src/target/llvm/codegen_llvm.cc | 66 ++++----- src/target/llvm/codegen_llvm.h | 1 - src/tir/ir/data_type_rewriter.cc | 2 +- src/tir/ir/expr.cc | 7 +- src/tir/transforms/storage_rewrite.cc | 7 + tests/cpp/tir_scalable_datatype.cc | 16 +++ .../codegen/test_target_codegen_aarch64.py | 41 ++++++ tests/python/target/test_arm_target.py | 125 ++++++++++++++++++ 10 files changed, 249 insertions(+), 39 deletions(-) diff --git a/include/tvm/runtime/data_type.h b/include/tvm/runtime/data_type.h index f6a7d424ed7d..8f3ae9b42460 100644 --- a/include/tvm/runtime/data_type.h +++ b/include/tvm/runtime/data_type.h @@ -110,6 +110,8 @@ class DataType { } return -lanes_as_int; } + /*! \return get vscale factor or lanes depending on scalability of the vector. */ + int get_lanes_or_vscale_factor() { return is_scalable_vector() ? vscale_factor() : lanes(); } /*! \return whether type is a scalar type. */ bool is_scalar() const { return !is_scalable_vector() && lanes() == 1; } /*! \return whether type is a scalar type. */ @@ -211,10 +213,13 @@ class DataType { /*! * \brief Construct an uint type. * \param bits The number of bits in the type. - * \param lanes The number of lanes + * \param lanes The number of lanes. + * \param is_scalable Whether the data type is scalable. * \return The constructed data type. */ - static DataType UInt(int bits, int lanes = 1) { return DataType(kDLUInt, bits, lanes); } + static DataType UInt(int bits, int lanes = 1, bool is_scalable = false) { + return DataType(kDLUInt, bits, lanes, is_scalable); + } /*! * \brief Construct an float type. * \param bits The number of bits in the type. @@ -243,10 +248,13 @@ class DataType { static DataType NVFloat8E5M2(int lanes = 1) { return DataType(kE5M2Float, 8, lanes); } /*! * \brief Construct a bool type. - * \param lanes The number of lanes + * \param lanes The number of lanes. + * \param is_scalable Whether the data type is scalable. * \return The constructed data type. */ - static DataType Bool(int lanes = 1) { return DataType::UInt(1, lanes); } + static DataType Bool(int lanes = 1, bool is_scalable = false) { + return DataType::UInt(1, lanes, is_scalable); + } /*! * \brief Construct a handle type. * \param bits The number of bits in the type. diff --git a/python/tvm/testing/utils.py b/python/tvm/testing/utils.py index 6e23a84bc290..e1b1c654570a 100644 --- a/python/tvm/testing/utils.py +++ b/python/tvm/testing/utils.py @@ -1045,6 +1045,13 @@ def _has_cpu_feat(features): ) +requires_aarch64_sve = Feature( + "arm_sve", + "AArch64 SVE", + run_time_check=lambda: _has_cpu_feat("sve"), +) + + requires_x86_vnni = Feature( "x86_vnni", "x86 VNNI Extensions", diff --git a/src/target/llvm/codegen_llvm.cc b/src/target/llvm/codegen_llvm.cc index eae26e5cac5b..bba1488274e2 100644 --- a/src/target/llvm/codegen_llvm.cc +++ b/src/target/llvm/codegen_llvm.cc @@ -587,10 +587,17 @@ llvm::Type* CodeGenLLVM::DTypeToLLVMType(const DataType& dtype) const { LOG(FATAL) << "do not support " << dtype; } } - if (dtype.lanes() != 1) { + if (!dtype.is_scalar()) { #if TVM_LLVM_VERSION >= 110 - return llvm::FixedVectorType::get(etype, dtype.lanes()); + if (dtype.is_scalable_vector()) { + return llvm::VectorType::get(etype, dtype.vscale_factor(), true); + } else { + return llvm::FixedVectorType::get(etype, dtype.lanes()); + } #else + ICHECK(!dtype.is_scalable_vector()) + << "Versions of LLVM < 11 do not support scalable vectors. Please upgrade to a later " + "version."; return llvm::VectorType::get(etype, dtype.lanes()); #endif } else { @@ -749,26 +756,6 @@ std::unique_ptr CodeGenLLVM::CreateDebugInfo(llvm::Modul return debug_info; } -llvm::Value* CodeGenLLVM::CreateBroadcast(llvm::Value* value, int lanes) { -#if TVM_LLVM_VERSION >= 110 - llvm::Type* type = llvm::FixedVectorType::get(value->getType(), lanes); -#else - llvm::Type* type = llvm::VectorType::get(value->getType(), lanes); -#endif - llvm::Constant* undef = llvm::UndefValue::get(type); - llvm::Constant* zero = ConstInt32(0); - value = builder_->CreateInsertElement(undef, value, zero); -#if TVM_LLVM_VERSION >= 120 - llvm::Constant* mask = llvm::ConstantVector::getSplat(llvm::ElementCount::getFixed(lanes), zero); -#elif TVM_LLVM_VERSION >= 110 - llvm::Constant* mask = - llvm::ConstantVector::getSplat(llvm::ElementCount(lanes, /*Scalable=*/false), zero); -#else - llvm::Constant* mask = llvm::ConstantVector::getSplat(lanes, zero); -#endif - return builder_->CreateShuffleVector(value, undef, mask); -} - llvm::Value* CodeGenLLVM::CreateVecSlice(llvm::Value* vec, int begin, int extent) { int num_elems = GetVectorNumElements(vec); if (extent == num_elems && begin == 0) return vec; @@ -1693,7 +1680,8 @@ void CodeGenLLVM::BufferAccessHelper( } PrimExpr last_index = indices[indices.size() - 1]; - ICHECK_EQ(value_dtype.lanes(), last_index.dtype().lanes() * buffer_element_dtype.lanes()); + ICHECK_EQ(value_dtype.get_lanes_or_vscale_factor(), + last_index.dtype().get_lanes_or_vscale_factor() * buffer_element_dtype.lanes()); // Record index and elemtype in original form used for alias info PrimExpr last_index_origin = last_index; @@ -1736,8 +1724,6 @@ void CodeGenLLVM::BufferAccessHelper( llvm::Value* last_index_value; int subelement_i = i; if (const RampNode* ramp = last_index.as()) { - // TODO(ekalda): P4 in https://github.com/apache/tvm/issues/16455 - ICHECK(!last_index.dtype().is_scalable_vector()); PrimExpr offset = ramp->base + (ramp->stride * i); last_index_value = MakeValue(offset); } else if (last_index.dtype().lanes() > 1) { @@ -1754,8 +1740,13 @@ void CodeGenLLVM::BufferAccessHelper( all_index_values.push_back(last_index_value); TypedPointer buffer_ptr = - CreateBufferPtr(MakeValue(buffer->data), buffer_element_dtype, all_index_values, - value_dtype.with_lanes(value_dtype.lanes() / last_index.dtype().lanes())); + value_dtype.is_scalable_vector() + ? CreateBufferPtr(MakeValue(buffer->data), buffer_element_dtype, all_index_values, + value_dtype.with_scalable_vscale_factor(value_dtype.vscale_factor() / + last_index.dtype().lanes())) + : CreateBufferPtr( + MakeValue(buffer->data), buffer_element_dtype, all_index_values, + value_dtype.with_lanes(value_dtype.lanes() / last_index.dtype().lanes())); auto instruction = make_instruction(buffer_ptr, subelement_i, alignment, is_volatile); AddAliasInfo(instruction, buffer->data.get(), last_index_origin, buffer_element_dtype_origin); } @@ -1870,10 +1861,23 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const ShuffleNode* op) { } llvm::Value* CodeGenLLVM::VisitExpr_(const BroadcastNode* op) { - // TODO(ekalda): P4 in https://github.com/apache/tvm/issues/16455 - ICHECK(!op->dtype.is_scalable_vector()); - int lanes = op->dtype.lanes(); - return CreateBroadcast(MakeValue(op->value), lanes); + DataType dtype = op->dtype; + llvm::Value* value = MakeValue(op->value); + llvm::Type* type = DTypeToLLVMType(dtype); + llvm::Constant* undef = llvm::UndefValue::get(type); + llvm::Constant* zero = ConstInt32(0); + value = builder_->CreateInsertElement(undef, value, zero); +#if TVM_LLVM_VERSION >= 110 + llvm::ElementCount ec = + llvm::ElementCount::get(dtype.get_lanes_or_vscale_factor(), dtype.is_scalable_vector()); + llvm::Constant* mask = llvm::ConstantVector::getSplat(ec, zero); +#else + ICHECK(!dtype.is_scalable_vector()) + << "Versions of LLVM < 11 do not support scalable vectors. Please upgrade to a later " + "version."; + llvm::Constant* mask = llvm::ConstantVector::getSplat(dtype.lanes(), zero); +#endif + return builder_->CreateShuffleVector(value, undef, mask); } void CodeGenLLVM::VisitStmt_(const BufferStoreNode* op) { diff --git a/src/target/llvm/codegen_llvm.h b/src/target/llvm/codegen_llvm.h index 2efac0307345..0f7aa847ecb8 100644 --- a/src/target/llvm/codegen_llvm.h +++ b/src/target/llvm/codegen_llvm.h @@ -468,7 +468,6 @@ class CodeGenLLVM : public ExprFunctor, llvm::Value* CreateAdd(DataType t, llvm::Value* a, llvm::Value* b); llvm::Value* CreateSub(DataType t, llvm::Value* a, llvm::Value* b); llvm::Value* CreateMul(DataType t, llvm::Value* a, llvm::Value* b); - llvm::Value* CreateBroadcast(llvm::Value* value, int lanes); virtual TypedPointer CreateBufferPtr(llvm::Value* buffer_ptr, DataType buffer_element_dtype, llvm::ArrayRef indices, DataType value_dtype); // Vector concatenation. diff --git a/src/tir/ir/data_type_rewriter.cc b/src/tir/ir/data_type_rewriter.cc index 2bd1e0608374..2d2c097be494 100644 --- a/src/tir/ir/data_type_rewriter.cc +++ b/src/tir/ir/data_type_rewriter.cc @@ -451,7 +451,7 @@ Stmt IndexDataTypeRewriter::VisitStmt_(const BufferStoreNode* op) { Buffer new_buffer = GetRemappedBuffer(op->buffer); auto value = this->VisitExpr(op->value); - if (new_buffer->dtype != value->dtype && value->dtype.lanes() == 1) { + if (new_buffer->dtype != value->dtype && value->dtype.is_scalar()) { value = cast(new_buffer->dtype, value); } auto indices = VisitIndices(op->indices); diff --git a/src/tir/ir/expr.cc b/src/tir/ir/expr.cc index 1b611d453418..c2baad209624 100644 --- a/src/tir/ir/expr.cc +++ b/src/tir/ir/expr.cc @@ -58,7 +58,9 @@ namespace tir { CHECK(a.dtype() == b.dtype()) << "TypeError: mismatched types. " << a.dtype() << " vs. " \ << b.dtype() << "\n"; \ ObjectPtr node = make_object(); \ - node->dtype = DataType::Bool(a.dtype().lanes()); \ + DataType a_dtype = a.dtype(); \ + node->dtype = \ + DataType::Bool(a_dtype.get_lanes_or_vscale_factor(), a_dtype.is_scalable_vector()); \ node->a = std::move(a); \ node->b = std::move(b); \ node->span = std::move(span); \ @@ -393,7 +395,8 @@ Not::Not(PrimExpr a, Span span) { ICHECK(a.dtype().is_bool()); ObjectPtr node = make_object(); - node->dtype = DataType::Bool(a.dtype().lanes()); + DataType a_dtype = a.dtype(); + node->dtype = DataType::Bool(a_dtype.get_lanes_or_vscale_factor(), a_dtype.is_scalable_vector()); node->a = std::move(a); node->span = std::move(span); data_ = std::move(node); diff --git a/src/tir/transforms/storage_rewrite.cc b/src/tir/transforms/storage_rewrite.cc index e40f683e21f8..3f34f2e870fd 100644 --- a/src/tir/transforms/storage_rewrite.cc +++ b/src/tir/transforms/storage_rewrite.cc @@ -1275,6 +1275,13 @@ class VectorTypeAccessChecker : public StmtExprVisitor { auto it = info_map_.find(buffer); ICHECK(it != info_map_.end()) << "Load/Store of buffer " << buffer->name_hint << " (" << buffer << ") occurred before its declaration."; + + if (value_dtype.is_scalable_vector()) { + // Scalable types are not currently supported in storage_rewrite. Scalable buffer + // accesses are not currently checked and therefore are not rewritten. + return; + } + BufferVarInfo& var_info = it->second; if (value_dtype.element_of() == DataType::Bool()) { diff --git a/tests/cpp/tir_scalable_datatype.cc b/tests/cpp/tir_scalable_datatype.cc index 23decef69e5a..4b4764555f7b 100644 --- a/tests/cpp/tir_scalable_datatype.cc +++ b/tests/cpp/tir_scalable_datatype.cc @@ -162,6 +162,22 @@ TEST(ScalableDataType, TestScalableDataTypeInvalidLanesAccess) { tvm::InternalError); } +TEST(ScalableDataType, TestScalableBool) { + tvm::DataType scalable_type = tvm::DataType::Bool(4, true); + ASSERT_EQ(scalable_type.code(), kDLUInt); + ASSERT_EQ(scalable_type.bits(), 1); + ASSERT_EQ(scalable_type.vscale_factor(), 4); + ASSERT_TRUE(scalable_type.is_scalable_vector()); +} + +TEST(ScalableDataType, TestScalableUInt) { + tvm::DataType scalable_type = tvm::DataType::UInt(1, 4, true); + ASSERT_EQ(scalable_type.code(), kDLUInt); + ASSERT_EQ(scalable_type.bits(), 1); + ASSERT_EQ(scalable_type.vscale_factor(), 4); + ASSERT_TRUE(scalable_type.is_scalable_vector()); +} + // ----------- // Integration // ----------- diff --git a/tests/python/codegen/test_target_codegen_aarch64.py b/tests/python/codegen/test_target_codegen_aarch64.py index 4e75f916d9b2..773c113f4a42 100644 --- a/tests/python/codegen/test_target_codegen_aarch64.py +++ b/tests/python/codegen/test_target_codegen_aarch64.py @@ -492,5 +492,46 @@ def main(A: T.Buffer((5,), "int32")): assert re.findall(r"llvm.vscale.i32", llvm), "No vscale in generated LLVM." +@pytest.mark.skipif( + llvm_version_major() < 11, reason="Vscale is not supported in earlier versions of LLVM" +) +def test_scalable_buffer_load_store(): + target = "llvm -mtriple=aarch64-linux-gnu -mattr=+sve" + + @T.prim_func + def my_func(a: T.handle, b: T.handle): + A = T.match_buffer(a, (128,), "float32") + B = T.match_buffer(b, (128,), "float32") + T.func_attr({"global_symbol": "my_module", "tir.noalias": True}) + B[T.ramp(0, 1, 4 * T.vscale())] = A[T.ramp(0, 1, 4 * T.vscale())] + + mod = tvm.build(my_func, target=target) + llvm = mod.get_source("ll") + + assert re.findall(r"load ", llvm), "No scalable load in generated LLVM." + assert re.findall(r" store ", llvm), "No scalable store in generated LLVM." + + +@pytest.mark.skipif( + llvm_version_major() < 11, reason="Vscale is not supported in earlier versions of LLVM" +) +def test_scalable_broadcast(): + target = "llvm -mtriple=aarch64-linux-gnu -mattr=+sve" + + @T.prim_func + def my_func(a: T.handle): + A = T.match_buffer(a, (128,), "float32") + T.func_attr({"global_symbol": "my_module", "tir.noalias": True}) + A[T.ramp(0, 1, 4 * T.vscale())] = T.broadcast(1, 4 * T.vscale()) + + mod = tvm.build(my_func, target=target) + llvm = mod.get_source("ll") + + assert re.findall( + r"shufflevector \( insertelement \(", llvm + ), "No scalable broadcast in generated LLVM." + assert re.findall(r" store ", llvm), "No scalable store in generated LLVM." + + if __name__ == "__main__": tvm.testing.main() diff --git a/tests/python/target/test_arm_target.py b/tests/python/target/test_arm_target.py index dc8452710a8a..158d941073c6 100644 --- a/tests/python/target/test_arm_target.py +++ b/tests/python/target/test_arm_target.py @@ -14,9 +14,16 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. + +import subprocess +import tempfile +import re + import pytest +import numpy as np import tvm +from tvm.script import tir as T from tvm.topi.arm_cpu.conv2d_int8 import is_int8_hw_support from tvm.target import codegen @@ -61,3 +68,121 @@ def test_arm_conv2d_int8_support( with tvm.target.Target(arm_target): monkeypatch.setattr(codegen, "llvm_version_major", lambda: llvm_version) assert is_int8_hw_support(input_dtype, kernel_dtype) == is_supported + + +@pytest.fixture(scope="session") +def sve_device_vector_length(): + c_code = r""" + #include + #include + + int main() { + printf("%ld\n", svcntb() * 8); + } + """ + + with tempfile.TemporaryDirectory() as tmp_dir: + c_path = f"{tmp_dir}/vl.c" + o_path = f"{tmp_dir}/out.o" + with open(c_path, "w") as f: + f.write(c_code) + tvm.contrib.cc.create_executable(o_path, c_path, ["-march=native"]) + out = subprocess.check_output(o_path, shell=True).strip().decode() + + return int(out) + + +@tvm.testing.requires_aarch64_sve +def test_scalable_div(sve_device_vector_length): + np.random.seed(0) + target = "llvm -mtriple=aarch64-linux-gnu -mattr=+sve" + dev = tvm.cpu(0) + + @T.prim_func + def my_func(a: T.handle): + A = T.match_buffer(a, (1,), "int32") + T.func_attr({"global_symbol": "my_module", "tir.noalias": True}) + A[0] = T.Div(10000, 4 * T.vscale()) + + mod = tvm.build(my_func, target=target) + + A_nd = tvm.nd.array(np.empty((1,), dtype="int32"), device=dev) + mod(A_nd) + + ref = 10000 // (sve_device_vector_length // 32) + tvm.testing.assert_allclose(A_nd.numpy()[0], ref) + + +@tvm.testing.requires_aarch64_sve +def test_scalable_buffer_load_store(sve_device_vector_length): + np.random.seed(0) + target = "llvm -mtriple=aarch64-linux-gnu -mattr=+sve" + num_elements = sve_device_vector_length // 32 + dev = tvm.cpu(0) + + @T.prim_func + def my_func(a: T.handle, b: T.handle): + A = T.match_buffer(a, (num_elements,), "float32") + B = T.match_buffer(b, (num_elements,), "float32") + T.func_attr({"global_symbol": "my_module", "tir.noalias": True}) + B[T.ramp(0, 1, 4 * T.vscale())] = A[T.ramp(0, 1, 4 * T.vscale())] + + mod = tvm.build(my_func, target=target) + + A_np = np.random.uniform(size=(num_elements,)).astype("float32") + B_np = np.zeros((num_elements,)).astype("float32") + A_nd = tvm.nd.array(A_np, device=dev) + B_nd = tvm.nd.array(B_np, device=dev) + mod(A_nd, B_nd) + + tvm.testing.assert_allclose(B_nd.numpy(), A_np) + + +@tvm.testing.requires_aarch64_sve +def test_scalable_loop_bound(sve_device_vector_length): + np.random.seed(0) + + dtype = "float32" + num_elements = sve_device_vector_length // 32 + target = "llvm -mtriple=aarch64-linux-gnu -mattr=+sve" + dev = tvm.cpu(0) + + @T.prim_func + def my_func(a: T.handle, b: T.handle): + A = T.match_buffer(a, (num_elements,), "float32") + B = T.match_buffer(b, (num_elements,), "float32") + T.func_attr({"global_symbol": "my_module", "tir.noalias": True}) + for i in T.serial(0, 4 * T.vscale()): + B[i] = A[i] + + mod = tvm.build(my_func, target=target) + + A_np = np.random.uniform(size=(num_elements,)).astype(dtype) + B_np = np.zeros((num_elements,)).astype(dtype) + A_nd = tvm.nd.array(A_np, device=dev) + B_nd = tvm.nd.array(B_np, device=dev) + mod(A_nd, B_nd) + + tvm.testing.assert_allclose(B_nd.numpy(), A_np) + + +@tvm.testing.requires_aarch64_sve +def test_scalable_broadcast(sve_device_vector_length): + target = "llvm -mtriple=aarch64-linux-gnu -mattr=+sve" + num_elements = sve_device_vector_length // 32 + dev = tvm.cpu(0) + + @T.prim_func + def my_func(a: T.handle): + A = T.match_buffer(a, (num_elements,), "float32") + T.func_attr({"global_symbol": "my_module", "tir.noalias": True}) + A[T.ramp(0, 1, 4 * T.vscale())] = T.broadcast(1, 4 * T.vscale()) + + mod = tvm.build(my_func, target=target) + + A_np = np.zeros((num_elements,)).astype("float32") + A_nd = tvm.nd.array(A_np, device=dev) + mod(A_nd) + + ref = np.ones((num_elements,)) + tvm.testing.assert_allclose(A_nd.numpy(), ref)