diff --git a/src/target/spirv/codegen_spirv.cc b/src/target/spirv/codegen_spirv.cc index c1fa921d4507..42d0027a326f 100644 --- a/src/target/spirv/codegen_spirv.cc +++ b/src/target/spirv/codegen_spirv.cc @@ -43,7 +43,7 @@ runtime::VulkanShader CodeGenSPIRV::BuildFunction(const PrimFunc& f, const std:: this->InitFuncState(); ICHECK(f->HasNonzeroAttr(tir::attr::kNoAlias)) << "SPIRV only takes restricted memory model"; std::vector pod_args; - uint32_t num_buffer = 0; + uint32_t i_buffer = 0; // Currently, all storage and uniform buffer arguments are passed as // a single descriptor set at index 0. If ever non-zero, must @@ -53,24 +53,25 @@ runtime::VulkanShader CodeGenSPIRV::BuildFunction(const PrimFunc& f, const std:: for (Var arg : f->params) { DataType t = arg.dtype(); if (t.is_handle()) { - if (auto* ptr = arg->type_annotation.as()) { - auto* prim = ptr->element_type.as(); - ICHECK(prim); - DataType value_storage_type = prim->dtype; - if (value_storage_type == DataType::UInt(1)) { - // We need a physically addressable buffer type to support boolean tensors. - // The loaded byte is cast to bool inside the LoadNode visitor below. - value_storage_type = DataType::UInt(8); - } - spirv::Value arg_value = builder_->BufferArgument(builder_->GetSType(value_storage_type), - descriptor_set, num_buffer); - builder_->SetName(arg_value, arg->name_hint); - storage_info_[arg.get()].UpdateContentType(value_storage_type); - var_map_[arg.get()] = arg_value; - } else { - LOG(FATAL) << "require all handles to be typed"; + auto* ptr = arg->type_annotation.as(); + ICHECK(ptr) << "All handles passed to the Vulkan codegen must have a type_annotation as a " + "PointerType, " + << "and must point to a PrimType"; + auto* prim = ptr->element_type.as(); + ICHECK(prim) << "All handles passed to the Vulkan codegen must have a type_annotation as a " + "PointerType, " + << "and must point to a PrimType"; + DataType value_storage_type = prim->dtype; + if (value_storage_type == DataType::Bool()) { + // We need a physically addressable buffer type to support boolean tensors. + // The loaded byte is cast to bool inside the LoadNode visitor below. + value_storage_type = boolean_storage_type_.with_lanes(value_storage_type.lanes()); } - ++num_buffer; + spirv::Value arg_value = builder_->BufferArgument(builder_->GetSType(value_storage_type), + descriptor_set, i_buffer++); + builder_->SetName(arg_value, arg->name_hint); + storage_info_[arg.get()].SetContentType(value_storage_type, arg->name_hint); + var_map_[arg.get()] = arg_value; } else { pod_args.push_back(arg); } @@ -95,7 +96,7 @@ runtime::VulkanShader CodeGenSPIRV::BuildFunction(const PrimFunc& f, const std:: } else { shader.flag |= 1 << runtime::vulkan::ShaderMetaDataFlagMask::kUseUBO; // If we need to pass more arguments than push constants could handle, we use UBO. - spirv::Value ptr = builder_->DeclareUniformBuffer(value_types, descriptor_set, num_buffer); + spirv::Value ptr = builder_->DeclareUniformBuffer(value_types, descriptor_set, i_buffer++); for (size_t i = 0; i < pod_args.size(); ++i) { spirv::Value value = builder_->GetUniform(ptr, value_types[i], static_cast(i)); var_map_[pod_args[i].get()] = value; @@ -404,14 +405,19 @@ spirv::Value CodeGenSPIRV::VisitExpr_(const BroadcastNode* op) { spirv::Value CodeGenSPIRV::VisitExpr_(const LoadNode* op) { ICHECK(is_one(op->predicate)); - auto it = storage_info_.find(op->buffer_var.get()); + + DataType desired_read_type = op->dtype; + if (desired_read_type == DataType::Bool()) { + desired_read_type = boolean_storage_type_.with_lanes(desired_read_type.lanes()); + } + + const VarNode* buffer_var = op->buffer_var.get(); + auto it = storage_info_.find(buffer_var); ICHECK(it != storage_info_.end()); StorageInfo& info = it->second; - if (!info.content_fixed) { - info.UpdateContentType(op->dtype); - } + info.CheckContentType(desired_read_type, op->index.dtype().lanes()); - spirv::SType content_type = builder_->GetSType(info.content_type); + spirv::SType content_type = builder_->GetSType(info.element_type); spirv::Value buffer = MakeValue(op->buffer_var); spirv::SType ptr_type = builder_->GetPointerType(content_type, buffer.stype.storage_class); @@ -419,47 +425,38 @@ spirv::Value CodeGenSPIRV::VisitExpr_(const LoadNode* op) { if (info.is_volatile) { mask |= spv::MemoryAccessVolatileMask; } - if (op->dtype.lanes() == 1) { + + if (desired_read_type == info.element_type) { + // Requested a single value from an array. This may be a scalar load + // or a vectorized load, based on the array element type. spirv::Value index = MakeValue(op->index); spirv::Value ptr = builder_->StructArrayAccess(ptr_type, buffer, index); spirv::Value loaded = builder_->MakeValue(spv::OpLoad, content_type, ptr, mask); - if (op->dtype == DataType::UInt(1)) { - // A bool tensor is backed by a byte buffer, we cast to bool here. - auto bool_ty = builder_->GetSType(DataType::UInt(1)); - return builder_->Cast(bool_ty, loaded); - } else { - ICHECK_EQ(info.content_type, op->dtype) - << "Vulkan only allow one type access to the same buffer"; - return loaded; + // OpTypeBool have no physical address/storage. Here, cast from + // the storage type to an OpTypeBool. + if (op->dtype == DataType::Bool()) { + auto spirv_bool = builder_->GetSType(DataType::Bool()); + loaded = builder_->Cast(spirv_bool, loaded); } + return loaded; + + } else if (desired_read_type.element_of() == info.element_type) { + // Requested several elements returned as an array. Read out each + // element and concatenate into the result. + std::vector values; + auto f = [&](int i, spirv::Value index) { + spirv::Value ptr = builder_->StructArrayAccess(ptr_type, buffer, index); + values.emplace_back(builder_->MakeValue(spv::OpLoad, content_type, ptr, mask)); + }; + this->Scalarize(op->index, f); + return builder_->Concat(values); + } else { - if (op->dtype.element_of() == info.content_type) { - // because content type is element type, we can only do scalarize load. - std::vector values; - auto f = [&](int i, spirv::Value index) { - spirv::Value ptr = builder_->StructArrayAccess(ptr_type, buffer, index); - values.emplace_back(builder_->MakeValue(spv::OpLoad, content_type, ptr, mask)); - }; - this->Scalarize(op->index, f); - return builder_->Concat(values); - } else { - if (const RampNode* ramp = op->index.as()) { - if (is_one(ramp->stride)) { - ICHECK_EQ(ramp->lanes, op->dtype.lanes()); - arith::ModularSet me = analyzer_->modular_set(ramp->base); - ICHECK((me->coeff % ramp->lanes) == 0 && (me->base % ramp->lanes) == 0) - << "Only aligned vector access is allowed in SPIRV"; - PrimExpr vec_index = - analyzer_->Simplify(ramp->base / make_const(ramp->base.dtype(), ramp->lanes)); - spirv::Value ptr = builder_->StructArrayAccess(ptr_type, buffer, MakeValue(vec_index)); - return builder_->MakeValue(spv::OpLoad, content_type, ptr, mask); - } - } - } - LOG(FATAL) << "Only aligned continuous vector access is allowed in SPIRV"; + LOG(FATAL) << "Cannot perform buffer access of buffer variable '" << buffer_var->name_hint + << "' with element type " << info.element_type << " using index of type " + << op->index->dtype << " to produce output of type " << op->dtype; + return spirv::Value(); } - LOG(FATAL) << "Only aligned continuous vector access is allowed in SPIRV"; - return spirv::Value(); } void CodeGenSPIRV::Scalarize(const PrimExpr& e, std::function f) { @@ -482,12 +479,9 @@ void CodeGenSPIRV::VisitStmt_(const StoreNode* op) { auto it = storage_info_.find(op->buffer_var.get()); ICHECK(it != storage_info_.end()); StorageInfo& info = it->second; + info.CheckContentType(op->value.dtype(), op->index.dtype().lanes()); - if (!info.content_fixed) { - info.UpdateContentType(op->value.dtype()); - } - - spirv::SType content_type = builder_->GetSType(info.content_type); + spirv::SType content_type = builder_->GetSType(info.element_type); spirv::Value buffer = MakeValue(op->buffer_var); spirv::Value value = MakeValue(op->value); spirv::SType ptr_type = builder_->GetPointerType(content_type, buffer.stype.storage_class); @@ -497,37 +491,29 @@ void CodeGenSPIRV::VisitStmt_(const StoreNode* op) { mask |= spv::MemoryAccessVolatileMask; } - if (op->value.dtype().lanes() == 1) { - ICHECK_EQ(info.content_type, op->value.dtype()) + if (op->value.dtype() == info.element_type) { + // Requested store of a single value. This may be a scalar store + // or a vectorized store, based on the array element type. + ICHECK_EQ(info.element_type, op->value.dtype()) << "Vulkan only allow one type access to the same buffer"; spirv::Value index = MakeValue(op->index); spirv::Value ptr = builder_->StructArrayAccess(ptr_type, buffer, index); builder_->MakeInst(spv::OpStore, ptr, value, mask); + + } else if (op->value.dtype().element_of() == info.element_type) { + // Requested store of several arbitrarily located values. Extract + // each value from the composite, then assign to the buffer. + auto f = [&](int i, spirv::Value index) { + spirv::Value elem = builder_->MakeValue(spv::OpCompositeExtract, content_type, value, i); + spirv::Value ptr = builder_->StructArrayAccess(ptr_type, buffer, index); + builder_->MakeInst(spv::OpStore, ptr, elem, mask); + }; + this->Scalarize(op->index, f); + } else { - if (op->value.dtype().element_of() == info.content_type) { - // because content type is element type, we can only do scalarize load. - auto f = [&](int i, spirv::Value index) { - spirv::Value elem = builder_->MakeValue(spv::OpCompositeExtract, content_type, value, i); - spirv::Value ptr = builder_->StructArrayAccess(ptr_type, buffer, index); - builder_->MakeInst(spv::OpStore, ptr, elem, mask); - }; - this->Scalarize(op->index, f); - } else { - if (const RampNode* ramp = op->index.as()) { - if (is_one(ramp->stride)) { - ICHECK_EQ(ramp->lanes, op->value.dtype().lanes()); - arith::ModularSet me = analyzer_->modular_set(ramp->base); - ICHECK((me->coeff % ramp->lanes) == 0 && (me->base % ramp->lanes) == 0) - << "Only aligned vector access is allowed in SPIRV"; - PrimExpr vec_index = - analyzer_->Simplify(ramp->base / make_const(ramp->base.dtype(), ramp->lanes)); - spirv::Value ptr = builder_->StructArrayAccess(ptr_type, buffer, MakeValue(vec_index)); - builder_->MakeInst(spv::OpStore, ptr, value, mask); - return; - } - } - LOG(FATAL) << "Only aligned continuous vector access is allowed in SPIRV"; - } + LOG(FATAL) << "Cannot store value of type " << op->value.dtype() << " into buffer variable '" + << op->buffer_var->name_hint << "' with element type " << info.element_type + << " using index of type " << op->index->dtype; } } @@ -663,8 +649,8 @@ void CodeGenSPIRV::VisitStmt_(const AllocateNode* op) { builder_->SetName(buf, op->buffer_var->name_hint); StorageInfo& info = storage_info_[op->buffer_var.get()]; - ICHECK(!info.content_fixed); - info.UpdateContentType(op->dtype); + ICHECK(!info.element_type_known); + info.SetContentType(op->dtype, op->buffer_var->name_hint); ICHECK(!var_map_.count(op->buffer_var.get())); var_map_[op->buffer_var.get()] = buf; diff --git a/src/target/spirv/codegen_spirv.h b/src/target/spirv/codegen_spirv.h index a44dc5fd3d34..8b14754f617f 100644 --- a/src/target/spirv/codegen_spirv.h +++ b/src/target/spirv/codegen_spirv.h @@ -114,47 +114,104 @@ class CodeGenSPIRV : public ExprFunctor, void VisitStmt_(const EvaluateNode* op) override; protected: - /*! \brief The storage information */ + /*! \brief Storage information for a buffer */ struct StorageInfo { + /*! \brief The name of the tir::Var for the buffer + * + * Used for error messages. + */ + std::string name_hint; + /*! \brief Whether it is volatile */ bool is_volatile{false}; - /*! \brief Whether it is volatile */ - bool content_fixed{false}; - /*! \brief Current content type */ - DataType content_type{DataType::Handle()}; - - // Update content type if it hasn't beenupdated. - void UpdateContentType(DataType type) { - if (content_fixed) { - ICHECK_EQ(type, content_type) << "Cannot use two different content type in GLSL model"; - } else { - this->content_type = type; - content_fixed = true; - } + + /*! \brief Whether the element type of the buffer is known. + * + * This value is determined based on the type_annotation of the + * buffer variable (AllocateNode) or of the parameter (shader + * arguments). + */ + bool element_type_known{false}; + + /*! \brief The known element type of the buffer. + * + * This value is determined based on the type_annotation of the + * buffer variable (AllocateNode) or of the parameter (shader + * arguments). + */ + DataType element_type{DataType()}; + + /* \brief Check that the access type matches the known type + * + * Asserts that the type given is the same as the type previously + * stored in this array. + * + * @param type The data type being stored/loaded in the buffer + * + * @param index_lanes The number of lanes of the index. The + * number of lanes in the value being stored/loaded should be the + * product of the number of lanes of the buffer element type and + * the number of lanes of the index. + */ + void CheckContentType(DataType type, int index_lanes = 1) { + ICHECK(element_type_known) << "Cannot check element type of buffer " << name_hint + << " no previous element type defined"; + DataType expected_type = element_type.with_lanes(index_lanes * element_type.lanes()); + ICHECK_EQ(type, expected_type) << "Attempted to access buffer " << name_hint + << " as element type " << type << " using an index of size " + << index_lanes << " when the element type is " << element_type; + } + + // Update content type if it hasn't been updated. + void SetContentType(DataType type, std::string name_hint) { + ICHECK(!element_type_known) << "Cannot set element type of buffer " << name_hint + << " a second time."; + this->element_type = type; + this->name_hint = name_hint; + element_type_known = true; } }; // Reset the state so it works for a new function. void InitFuncState(); // Get the thread index spirv::Value GetThreadIndex(const IterVar& iv, const PrimExpr& extent); + spirv::Value CreateStorageSync(const CallNode* op); void Scalarize(const PrimExpr& e, std::function f); + // SPIRV-related capabilities of the target SPIRVSupport spirv_support_; + // The builder std::unique_ptr builder_; + // Work group size of three uint32_t workgroup_size_[3]; + // Likely branch uint32_t weight_likely_branch_{128}; + + /* The data type used for the backing array for booleans. + * + * Currently matched to the data type used in Buffer::vstore and + * Buffer::vload. In the future, this should be the smallest + * integer type supported by the device, as not all Vulkan + * implementations support int8. + */ + DataType boolean_storage_type_{DataType::Int(8)}; + // the storage scope of allocation std::unordered_map storage_info_; + // The definition of local variable. std::unordered_map var_map_; + // The analyzer. std::unique_ptr analyzer_; + // deep comparison of PrimExpr ExprDeepEqual deep_equal_; + // binding of let variables. Enables duplicate var defs that map to same value std::unordered_map let_binding_; }; diff --git a/src/tir/ir/buffer_common.h b/src/tir/ir/buffer_common.h new file mode 100644 index 000000000000..8dac41a02e57 --- /dev/null +++ b/src/tir/ir/buffer_common.h @@ -0,0 +1,59 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/*! + * \file tir/ir/buffer_common.h + * \brief Common utils for buffer access + */ +#ifndef TVM_TIR_IR_BUFFER_COMMON_H_ +#define TVM_TIR_IR_BUFFER_COMMON_H_ + +#include +#include + +#include + +namespace tvm { +namespace tir { + +/*! + * \brief Returns the type of object pointed to. + * + * \param type The type to be checked. + * + * \return A (bool, DataType) pair. If the type is a pointer to a + * primitive, the boolean is true and the DataType is the pointed-to + * type. Otherwise, the boolean is false and the DataType is + * default-constructed. This can be replaced with std::optional with + * C++17 if/when C++17 is required. + */ +inline std::pair GetPointerType(const Type& type) { + if (type.defined()) { + if (auto* ptr_type = type.as()) { + if (auto* prim_type = ptr_type->element_type.as()) { + return {true, prim_type->dtype}; + } + } + } + + return {false, DataType()}; +} + +} // namespace tir +} // namespace tvm +#endif // TVM_TIR_IR_BUFFER_COMMON_H_ diff --git a/src/tir/ir/expr.cc b/src/tir/ir/expr.cc index 352f75abdf5e..c8be38fd8d29 100644 --- a/src/tir/ir/expr.cc +++ b/src/tir/ir/expr.cc @@ -29,6 +29,7 @@ #include #include "../../support/str_escape.h" +#include "buffer_common.h" namespace tvm { namespace tir { @@ -618,8 +619,42 @@ Load::Load(DataType dtype, Var buffer_var, PrimExpr index, PrimExpr predicate, S ICHECK(buffer_var.defined()); ICHECK(predicate.defined()); ICHECK(index.defined()); - ICHECK_EQ(dtype.lanes(), index.dtype().lanes()); - ICHECK_EQ(dtype.lanes(), predicate.dtype().lanes()); + + // Assume that the array elements have 1 lane, unless a type + // annotation tells us otherwise. + int element_lanes = 1; + auto pointer_type = tir::GetPointerType(buffer_var->type_annotation); + if (pointer_type.first) { + // Cannot check element type of array, as it may be different than + // the loaded type in some cases. + // + // 1. Booleans use DataType::Int(8) while stored, and the codegens + // handle cast to boolean. + // + // 2. The StorageRewrite pass can merge multiple allocations at + // the same scope, regardless of element type. The codegen is + // then responsible for casting to the output type. + + // TODO(Lunderberg): Uncomment this check once it can be applied. + // See https://discuss.tvm.apache.org/t/pre-rfc-vectorized-tir-buffers/10615 + // for discussion. + + // ICHECK(dtype.element_of() == pointer_type.second.element_of()) + // << "Type mismatch, cannot load type " << dtype << " from buffer " << + // buffer_var->name_hint + // << " of type " << pointer_type.second; + element_lanes = pointer_type.second.lanes(); + } + + // The C-based codegens assume that all loads occur on a array with + // non-vectorized elements, and cast between + // vectorized/non-vectorized arrays as needed. Ideally, these + // should be changed to explicit casts in the TIR graph, rather than + // being handled at the code-gen level. + ICHECK((dtype.lanes() == element_lanes * index.dtype().lanes()) || + (dtype.lanes() == index.dtype().lanes())); + ICHECK((dtype.lanes() == element_lanes * predicate.dtype().lanes()) || + (dtype.lanes() == index.dtype().lanes())); ObjectPtr node = make_object(); node->dtype = dtype; diff --git a/src/tir/ir/stmt.cc b/src/tir/ir/stmt.cc index 9a20f3ec9358..cb06df8b7655 100644 --- a/src/tir/ir/stmt.cc +++ b/src/tir/ir/stmt.cc @@ -25,6 +25,8 @@ #include #include +#include "buffer_common.h" + namespace tvm { namespace tir { @@ -234,8 +236,29 @@ Store::Store(Var buffer_var, PrimExpr value, PrimExpr index, PrimExpr predicate, ICHECK(value.defined()); ICHECK(index.defined()); ICHECK(predicate.defined()); - ICHECK_EQ(value.dtype().lanes(), index.dtype().lanes()); - ICHECK_EQ(value.dtype().lanes(), predicate.dtype().lanes()); + + // Assume that the array elements have 1 lane, unless a type + // annotation tells us otherwise. + int element_lanes = 1; + auto pointer_type = tir::GetPointerType(buffer_var->type_annotation); + if (pointer_type.first) { + // Currently cannot check element type of array, see Load::Load + // for details. + + // TODO(Lunderberg): Uncomment this check once it can be applied. + // See https://discuss.tvm.apache.org/t/pre-rfc-vectorized-tir-buffers/10615 + // for discussion. + + // ICHECK_EQ(value.dtype().element_of(), pointer_type.second.element_of()) + // << "Type mismatch, cannot store type " << value.dtype() << " into buffer " + // << buffer_var->name_hint << " of type " << pointer_type.second; + element_lanes = pointer_type.second.lanes(); + } + + ICHECK((value.dtype().lanes() == element_lanes * index.dtype().lanes()) || + (value.dtype().lanes() == index.dtype().lanes())); + ICHECK((value.dtype().lanes() == element_lanes * predicate.dtype().lanes()) || + (value.dtype().lanes() == index.dtype().lanes())); ObjectPtr node = make_object(); node->buffer_var = std::move(buffer_var); diff --git a/src/tir/transforms/storage_rewrite.cc b/src/tir/transforms/storage_rewrite.cc index 3a2990c928c7..592a6a33375e 100644 --- a/src/tir/transforms/storage_rewrite.cc +++ b/src/tir/transforms/storage_rewrite.cc @@ -37,6 +37,7 @@ #include #include "../../runtime/thread_storage_scope.h" +#include "../ir/buffer_common.h" #include "ir_utils.h" namespace tvm { @@ -326,7 +327,14 @@ class InplaceOpVerifier : public StmtExprVisitor { const StoreNode* store_{nullptr}; }; -// Planner to plan and rewrite memory allocation. +/* \brief Rewrite and merge memory allocation. + * + * Using LinearAccessPatternFinder, determines which buffers could share an + * allocation. This includes both sequential usage of the same buffer and + * merging small allocations at the same scope into a single larger allocation. + * The merging of small allocations requires the codegen to cast the resulting + * value from the storage type to the output type after access. + */ class StoragePlanRewriter : public StmtExprMutator { public: using StmtEntry = LinearAccessPatternFinder::StmtEntry; @@ -881,108 +889,547 @@ class StoragePlanRewriter : public StmtExprMutator { arith::Analyzer analyzer_; }; -// Turn alloc into vector alloc -// if all its access is the same vector type. -class VectorAllocRewriter : public StmtExprMutator { +/* Helper struct containing information on how a buffer is declared and used + * + */ +struct BufferVarInfo { + enum DeclarationLocation { + kPrimFuncParam = (1 << 0), + kPrimFuncBufferMap = (1 << 1), + kAllocateNode = (1 << 2), + kLetNode = (1 << 3), + }; + + // The tir::Var that represents this buffer. + Var var; + + // The data type of an element of the buffer. + DataType element_dtype; + + /* The extent of the buffer. + * + * If multidimensional, the extent of the last dimension of the buffer. If the + * size is unknown (e.g. pointer arguments to PrimFunc with no corresponding + * entry in buffer_map), then extent is zero. + */ + PrimExpr extent; + + // Where the buffer was declared + DeclarationLocation declaration_location; + + // When accessed, which element type is it accessed as. This may + // differ both in base type (e.g. int32* cast to float32* after + // packing in StorageRewrite) or in number of lanes (e.g. float16* + // cast to float16x4*). + std::unordered_set access_dtype; + + DataType get_preferred_dtype() const { + std::unordered_set base_access_dtype; + for (auto dtype : access_dtype) { + base_access_dtype.insert(dtype.element_of()); + } + // If the array is accessed as multiple base types within a + // function, no point in changing the declared type. CodeGenC can + // handle this with a type-cast prior to indexing. Vulkan will + // raise an error at code-gen time, if a later pass doesn't split + // it out. + if (base_access_dtype.size() != 1) { + return element_dtype; + } + + DataType preferred_base_type = *base_access_dtype.begin(); + + // If there is only one vectorizable size used to access the + // buffer, and if that access size is compatible with the array + // size, then the buffer is vectorizable. In the future, this + // could be improved to allow vectorized buffer access of size + // GCD(*lanes_used), if necessary. + int preferred_lanes = element_dtype.lanes(); + if ((element_dtype.lanes() == 1) && (access_dtype.size() == 1)) { + arith::Analyzer analyzer_; + arith::ModularSet me = analyzer_.modular_set(extent); + + int lanes = access_dtype.begin()->lanes(); + if ((me->coeff % lanes == 0) && (me->base % lanes == 0)) { + preferred_lanes = lanes; + } + } + + return preferred_base_type.with_lanes(preferred_lanes); + } +}; + +/* Checks whether buffers are accessed as scalar or vector parameters in a + * function. + * + */ +class VectorTypeAccessChecker : public StmtExprVisitor { public: - PrimExpr VisitExpr_(const LoadNode* op) final { - UpdateTypeMap(op->buffer_var.get(), op->dtype); - return StmtExprMutator::VisitExpr_(op); + /* Constructor + * + * @param params The parameters passed to a PrimFunc + * + * @param buffer_map The buffer_map associated with a PrimFunc + * + * @param allow_untyped_handles If a buffer or pointer variable is + * missing a type annotation, assume that it has the same underlying + * type as it is later accessed, with scalar element types. + */ + VectorTypeAccessChecker(const Array& params, const Map& buffer_map, + bool allow_untyped_pointers = false) + : allow_untyped_pointers_(allow_untyped_pointers) { + // If a parameter is in the buffer map, we want to track the + // version in the map. + for (auto it : buffer_map) { + Buffer& buffer = it.second; + Var buffer_var = buffer->data; + DataType dtype = buffer->dtype; + PrimExpr extent = buffer->shape.size() ? buffer->shape[buffer->shape.size() - 1] : 0; + OnArrayDeclaration(buffer_var, dtype, extent, BufferVarInfo::kPrimFuncParam); + } + + // If a pointer parameter isn't in the buffer map, then we want to + // track the parameter itself. + for (Var buffer_var : params) { + auto pointer_type = GetPointerType(buffer_var->type_annotation); + if (pointer_type.first && (buffer_map.count(buffer_var) == 0)) { + DataType dtype = pointer_type.second; + PrimExpr extent = 0; + OnArrayDeclaration(buffer_var, dtype, extent, BufferVarInfo::kPrimFuncBufferMap); + } + } } - Stmt VisitStmt_(const StoreNode* op) final { - UpdateTypeMap(op->buffer_var.get(), op->value.dtype()); - return StmtExprMutator::VisitStmt_(op); + void VisitExpr_(const LoadNode* op) final { + OnArrayAccess(op->dtype, op->buffer_var.get(), op->index, op->predicate); + StmtExprVisitor::VisitExpr_(op); } - PrimExpr VisitExpr_(const CallNode* op) final { + + void VisitStmt_(const StoreNode* op) final { + OnArrayAccess(op->value.dtype(), op->buffer_var.get(), op->index, op->predicate); + StmtExprVisitor::VisitStmt_(op); + } + void VisitExpr_(const CallNode* op) final { if (op->op.same_as(builtin::tvm_access_ptr())) { DataType dtype = op->args[0].dtype(); const VarNode* buffer = op->args[1].as(); - UpdateTypeMap(buffer, dtype); + PrimExpr index = op->args[2]; + OnArrayAccess(dtype, buffer, index, const_true(dtype.lanes())); } - return StmtExprMutator::VisitExpr_(op); + StmtExprVisitor::VisitExpr_(op); } - Stmt VisitStmt_(const AllocateNode* op) final { - Stmt stmt = StmtExprMutator::VisitStmt_(op); - op = stmt.as(); - const auto& tvec = acc_map_[op->buffer_var.get()]; - - if (tvec.size() == 1 && tvec[0].element_of() == op->dtype.element_of() && - tvec[0].lanes() % op->dtype.lanes() == 0 && tvec[0].lanes() != op->dtype.lanes()) { - int factor = tvec[0].lanes() / op->dtype.lanes(); - Array extents = op->extents; - arith::ModularSet me = analyzer_.modular_set(extents[extents.size() - 1]); - if (me->base % factor == 0 && me->coeff % factor == 0) { - extents.Set(extents.size() - 1, - extents[extents.size() - 1] / make_const(extents[0].dtype(), factor)); - // create a new buffer var - DataType new_dtype = tvec[0]; - Var new_buffer_var(op->buffer_var->name_hint, - PointerType(PrimType(new_dtype), GetPtrStorageScope(op->buffer_var))); - // update the remap req. - var_remap_.Set(op->buffer_var, new_buffer_var); - return Allocate(new_buffer_var, new_dtype, extents, op->condition, op->body); + void VisitStmt_(const AllocateNode* op) final { + const Array& extents = op->extents; + PrimExpr extent = extents[extents.size() - 1]; + OnArrayDeclaration(op->buffer_var, op->dtype, extent, BufferVarInfo::kAllocateNode); + + StmtExprVisitor::VisitStmt_(op); + } + + void VisitExpr_(const LetNode* op) final { + HandleLetNode(op->var); + StmtExprVisitor::VisitExpr_(op); + } + + void VisitStmt_(const LetStmtNode* op) final { + HandleLetNode(op->var); + StmtExprVisitor::VisitStmt_(op); + } + + void HandleLetNode(Var let_var) { + if (let_var->dtype.is_handle()) { + auto pointer_type = GetPointerType(let_var->type_annotation); + if (pointer_type.first) { + OnArrayDeclaration(let_var, pointer_type.second, 0, BufferVarInfo::kLetNode); + } else if (allow_untyped_pointers_) { + OnArrayDeclaration(let_var, let_var->dtype, 0, BufferVarInfo::kLetNode); + } else { + LOG(FATAL) << "Let statement of variable " << let_var->name_hint + << " is missing a type annotation, " + << "or type annotation is not a pointer to primitive"; } } - return stmt; } - void UpdateTypeMap(const VarNode* buffer, DataType t) { - auto& tvec = acc_map_[buffer]; - if (std::find(tvec.begin(), tvec.end(), t) == tvec.end()) { - tvec.push_back(t); + /* Update the type map for a buffer based on its declaration + * + * @param buffer The VarNode representing the buffer. + * + * @param element_dtype The dtype of a single element of the buffer. + * If unknown, when used with the allow_untyped_handles option, + * should be a handle dtype. + * + * @param extent The extent of the buffer. Zero if size is unknown. + * + * @param declaration_location How the buffer was allocated, so that + * some locations can be rewritten without others. + */ + void OnArrayDeclaration(Var buffer, DataType element_dtype, PrimExpr extent, + BufferVarInfo::DeclarationLocation declaration_location) { + ICHECK(info_map_.find(buffer.get()) == info_map_.end()) + << "Array declaration of " << buffer->name_hint << " occurred multiple times."; + + if (element_dtype == DataType::Bool()) { + element_dtype = DataType::Int(8).with_lanes(element_dtype.lanes()); } + + info_map_[buffer.get()] = {buffer, element_dtype, extent, declaration_location}; } - // Internal access map - std::unordered_map > acc_map_; - // Variables to remap - Map var_remap_; + /* Update the type map for a buffer based on its usage + * + * @param value_dtype The dtype of the value being stored to or + * loaded from the buffer. + * + * @param buffer The VarNode representing the buffer. + * + * @param index The index at which the value is being stored/loaded. + * + * @param predicate The predicate used for the store/load. + */ + void OnArrayAccess(DataType value_dtype, const VarNode* buffer, const PrimExpr& index, + const PrimExpr& predicate) { + auto it = info_map_.find(buffer); + ICHECK(it != info_map_.end()) << "Load/Store of buffer " << buffer->name_hint << " (" << buffer + << ") occurred before its declaration."; + BufferVarInfo& var_info = it->second; + + if (value_dtype.element_of() == DataType::Bool()) { + value_dtype = DataType::Int(8).with_lanes(value_dtype.lanes()); + } + + if (var_info.element_dtype.is_handle()) { + ICHECK(allow_untyped_pointers_) << "Variable " << buffer->name_hint + << " was missing a type annotation in its declaration"; + var_info.element_dtype = value_dtype.element_of(); + } + + DataType access_dtype = value_dtype; + + int lanes_used = var_info.element_dtype.lanes(); + + // This can happen due to a previous pass that had rewrite_store_load = + // false. This occurs from the StorageRewrite in tvm::lower, followed by the + // PointerValueTypeRewrite in BuildSPIRV. The rewrite_store_load = false is + // necessary because the C-based codegens do not yet support vectorized + // pointer types (e.g. float16x4*). Once they do, this if statement should + // instead be replaced by the below ICHECK_EQ. + if (index.dtype().lanes() * var_info.element_dtype.lanes() != value_dtype.lanes()) { + ICHECK_EQ(index.dtype().lanes(), value_dtype.lanes()); + lanes_used = 1; + var_info.element_dtype = var_info.element_dtype.with_lanes(1); + } + + // TODO(Lunderberg): Uncomment this check once it can be applied. + // See https://discuss.tvm.apache.org/t/pre-rfc-vectorized-tir-buffers/10615 + // for discussion. + + // ICHECK_EQ(index.dtype().lanes() * var_info.element_dtype.lanes(), value_dtype.lanes()) + // << "Attempting to retrieve " << value_dtype.lanes() << " lanes of data with " + // << index.dtype().lanes() << " indices into an array whose elements have " + // << var_info.element_dtype.lanes() << " lanes. " + // << "Expected output with " << index.dtype().lanes() * var_info.element_dtype.lanes() + // << " lanes."; + + // If the index is a RampNode with stride of 1 and offset + // divisible by the number of number of lanes, and the predicate + // does not apply any masking, then this array access could be + // vectorized. + const RampNode* ramp_index = index.as(); + if (ramp_index && is_one(ramp_index->stride) && is_one(predicate)) { + arith::ModularSet me = analyzer_.modular_set(ramp_index->base); + if ((me->coeff % ramp_index->lanes == 0) && (me->base % ramp_index->lanes == 0)) { + lanes_used = ramp_index->lanes; + } + } + + var_info.access_dtype.insert(access_dtype.with_lanes(lanes_used)); + } + + // Map of buffer variable information determined + std::unordered_map info_map_; + + // + bool allow_untyped_pointers_{false}; + // internal analyzer arith::Analyzer analyzer_; }; -PrimFunc PointerValueTypeRewrite(PrimFunc f) { - auto* n = f.CopyOnWrite(); - VectorAllocRewriter rewriter; - n->body = rewriter(std::move(n->body)); +/* \brief Rewrites buffer/pointer variables from scalar types to vectorized + * types. + * + * Some runtimes do not allow casting between composite types and the underlying + * base type (e.g. Vulkan, casting from 1-lane float16* to 4-lane float16x4*). + * In these cases, in order to have vectorized load/store on an array, the + * element type of that array must be vectorized. This is in contrast to C-style + * runtimes, in which `float16x4* vec = *(float16x4*)(float_arr + offset)` is + * valid. + * + * By default, VectorTypeRewriter will attempt to rewrite all buffer variables to + * vectorized access, if the load/store occurring in the PrimFunc are all + * vectorized. This includes adjusting the indices being used to access the + * array. (e.g. If `float16* scalar_arr` is being converted to `float16x4* + * vec_arr`, then `scalar_arr[Ramp(offset, 1, 4)]` will be converted to + * `vec_arr[offset/4]`.) + * + * Currently, several of the C-style runtimes do not support buffers whose + * elements are vectorized types, or rely on the presence of the Ramp nodes to + * identify vectorized loads. The boolean parameters in the constructor are to + * mimic the previous behavior of VectorTypeRewriter, to avoid breaking these + * runtimes. Once all runtimes support vectorized buffer elements, these + * parameters can be removed. + */ +class VectorTypeRewriter : public StmtExprMutator { + public: + /* Constructor + * + * @param checker The VectorTypeAccessChecker that has previously read out + * information from the PrimFunc + * + * @param rewrite_params Whether pointer-type parameters passed into the + * function should be rewritten from scalar types to vectorized types. + * + * @param rewrite_buffer_map Whether buffers present in the buffer_map should + * have their data variable be rewritten from scalar types to vectorized types. + * + * @param rewrite_allocate_node Whether the buffer variable associated with + * AllocateNodes should be rewritten from scalar types to vectorized types. + * + * @param rewrite_indices Whether the indices to the Load and Store nodes + * should be rewritten to correspond to the new buffer_var type. + * + * @param rewrite_let_node Whether pointer declarations in let nodes + * should be re-written. + */ + VectorTypeRewriter(const std::unordered_map& info_map, + bool rewrite_params = true, bool rewrite_buffer_map = true, + bool rewrite_allocate_node = true, bool rewrite_indices = true, + bool rewrite_let_node = true) + : rewrite_indices_(rewrite_indices) { + int rewrite_mask = 0; + if (rewrite_params) { + rewrite_mask |= BufferVarInfo::kPrimFuncParam; + } + if (rewrite_buffer_map) { + rewrite_mask |= BufferVarInfo::kPrimFuncBufferMap; + } + if (rewrite_allocate_node) { + rewrite_mask |= BufferVarInfo::kAllocateNode; + } + if (rewrite_let_node) { + rewrite_mask |= BufferVarInfo::kLetNode; + } - Map var_remap = std::move(rewriter.var_remap_); - Array args; + // Rewrite any buffer variables whose preferred type isn't their current type. + for (const auto& pair : info_map) { + const auto& var_info = pair.second; + DataType preferred = var_info.get_preferred_dtype(); + if (preferred != var_info.element_dtype && (rewrite_mask & var_info.declaration_location)) { + Var old_buffer_var = var_info.var; + Var new_buffer_var(old_buffer_var->name_hint, + PointerType(PrimType(preferred), GetPtrStorageScope(old_buffer_var)), + old_buffer_var->span); - // rewrite paramters if needed. - for (Var var : f->params) { - if (var.dtype().is_handle()) { - const auto& tvec = rewriter.acc_map_[var.get()]; + rewrite_map_[var_info.var.get()] = {var_info.var, new_buffer_var, var_info.element_dtype, + preferred}; + } + } + } - if (tvec.size() == 1) { - tir::Var new_var(var->name_hint, PointerType(PrimType(tvec[0]))); - args.push_back(new_var); - var_remap.Set(var, new_var); - } else { - // always set data type to be non vectorized so - // load/store can still work via scalarization - if (tvec.size() != 0 && !var->type_annotation.defined()) { - tir::Var new_var(var->name_hint, PointerType(PrimType(tvec[0].with_lanes(1)))); - args.push_back(new_var); - var_remap.Set(var, new_var); - } else { - args.push_back(var); - } + PrimExpr VisitExpr_(const LoadNode* op) final { + PrimExpr expr = StmtExprMutator::VisitExpr_(op); + op = expr.as(); + + if (!rewrite_indices_) { + return expr; + } + + auto it = rewrite_map_.find(op->buffer_var.get()); + if (it == rewrite_map_.end()) { + return expr; + } + const auto& info = it->second; + + DataType out_dtype_base = info.new_element_dtype.element_of(); + + const RampNode* ramp_index = op->index.as(); + if (ramp_index && is_one(ramp_index->stride)) { + PrimExpr new_index = + ramp_index->base / make_const(ramp_index->base.dtype(), ramp_index->lanes); + return Load(out_dtype_base.with_lanes(op->dtype.lanes()), info.new_buffer_var, new_index, + const_true(new_index.dtype().lanes()), op->span); + } else { + return Load(out_dtype_base, info.new_buffer_var, op->index, op->predicate); + } + } + + Stmt VisitStmt_(const StoreNode* op) final { + Stmt stmt = StmtExprMutator::VisitStmt_(op); + op = stmt.as(); + + if (!rewrite_indices_) { + return stmt; + } + + auto it = rewrite_map_.find(op->buffer_var.get()); + if (it == rewrite_map_.end()) { + return stmt; + } + const auto& info = it->second; + + const RampNode* ramp_index = op->index.as(); + if (ramp_index && is_one(ramp_index->stride)) { + PrimExpr new_index = + ramp_index->base / make_const(ramp_index->base.dtype(), ramp_index->lanes); + return Store(info.new_buffer_var, op->value, new_index, const_true(new_index.dtype().lanes()), + op->span); + } else { + return Store(info.new_buffer_var, op->value, op->index, op->predicate, op->span); + } + } + + PrimExpr VisitExpr_(const CallNode* op) final { + if (op->op.same_as(builtin::tvm_access_ptr())) { + PrimExpr expr = StmtExprMutator::VisitExpr_(op); + op = expr.as(); + + if (!rewrite_indices_) { + return expr; } + + const VarNode* buffer_var = op->args[1].as(); + auto it = rewrite_map_.find(buffer_var); + if (it == rewrite_map_.end()) { + return expr; + } + const auto& info = it->second; + + PrimExpr index = op->args[2]; + PrimExpr extent = op->args[3]; + PrimExpr flag = op->args[4]; + + PrimExpr e_dtype = tir::TypeAnnotation(info.new_element_dtype); + PrimExpr factor = make_const(extent.dtype(), info.new_element_dtype.lanes()); + extent = extent / factor; + index = index / factor; + Array acc_args{e_dtype, info.new_buffer_var, index, extent, flag}; + return Call(info.new_element_dtype, builtin::tvm_access_ptr(), acc_args); + } else { - args.push_back(var); + return StmtExprMutator::VisitExpr_(op); + } + } + + Stmt VisitStmt_(const AllocateNode* op) final { + Stmt stmt = StmtExprMutator::VisitStmt_(op); + op = stmt.as(); + + auto it = rewrite_map_.find(op->buffer_var.get()); + if (it == rewrite_map_.end()) { + return stmt; } + + const auto& info = it->second; + + Var new_buffer_var = info.new_buffer_var; + + int factor = info.new_element_dtype.lanes() / op->dtype.lanes(); + + Array extents = op->extents; + extents.Set(extents.size() - 1, + extents[extents.size() - 1] / make_const(extents[0].dtype(), factor)); + return Allocate(new_buffer_var, info.new_element_dtype, extents, op->condition, op->body); } - // no variable remap is needed. - if (var_remap.size() == 0) return f; + /* Update the parameters and all remaining variable references + * + * Should be called after calling operator() on the body of the + * function. + * + * @param func A pointer to the PrimFunc being modified. + */ + void Finalize(PrimFunc* func_ptr) const { + ICHECK(func_ptr) << "Finalize expects a non-null pointer"; + auto& func = *func_ptr; + auto* n = func.CopyOnWrite(); + + // Remap any remaining references to the old buffer variables + Map var_remap; + for (const auto& pair : rewrite_map_) { + const auto& info = pair.second; + var_remap.Set(info.old_buffer_var, info.new_buffer_var); + } + n->body = Substitute(n->body, var_remap); + + // Remap the argument list to use the new buffer variables. + Array new_params; + for (const auto& old_param : n->params) { + auto it = rewrite_map_.find(old_param.get()); + if (it == rewrite_map_.end()) { + new_params.push_back(old_param); + } else { + const auto& info = it->second; + new_params.push_back(info.new_buffer_var); + } + } + n->params = new_params; + + // Remap the Buffer objects in so that the buffers use the new buffer variables + Map new_buffer_map; + for (const auto& pair : n->buffer_map) { + Var key = pair.first; + Buffer old_buffer = pair.second; + Var old_var = old_buffer->data; + + auto it = rewrite_map_.find(old_var.get()); + if (it == rewrite_map_.end()) { + new_buffer_map.Set(key, old_buffer); + } else { + auto& info = it->second; + int factor = info.new_element_dtype.lanes() / info.old_element_dtype.lanes(); + ICHECK_EQ(factor * info.new_element_dtype.lanes(), info.old_element_dtype.lanes()); + + auto* buffer_cow = old_buffer.CopyOnWrite(); + buffer_cow->data = info.new_buffer_var; + buffer_cow->dtype = info.new_element_dtype; + size_t ndim = buffer_cow->shape.size(); + const auto& last_dim = buffer_cow->shape[ndim - 1]; + buffer_cow->shape.Set(ndim - 1, last_dim / make_const(last_dim.dtype(), factor)); + new_buffer_map.Set(key, old_buffer); + } + } + n->buffer_map = new_buffer_map; + } + + private: + struct RewriteInfo { + Var old_buffer_var; + Var new_buffer_var; + DataType old_element_dtype; + DataType new_element_dtype; + }; + + bool rewrite_indices_{true}; + std::unordered_map rewrite_map_; +}; + +// Rewrite allocates, pointer parameters, and buffer map into vectorized versions +// if each access into a buffer is the same vector type. +PrimFunc PointerValueTypeRewrite(PrimFunc f, bool allow_untyped_pointers = false, + bool rewrite_params = true, bool rewrite_buffer_map = true, + bool rewrite_allocate_node = true, bool rewrite_indices = true, + bool rewrite_let_node = true) { + VectorTypeAccessChecker checker(f->params, f->buffer_map, allow_untyped_pointers); + checker(f->body); + + VectorTypeRewriter rewriter(checker.info_map_, rewrite_params, rewrite_buffer_map, + rewrite_allocate_node, rewrite_indices, rewrite_let_node); + PrimFuncNode* n = f.CopyOnWrite(); + n->body = rewriter(std::move(n->body)); + rewriter.Finalize(&f); - // remap the variables. - ICHECK_EQ(args.size(), n->params.size()); - n->params = args; - n->body = Substitute(n->body, var_remap); return f; } @@ -992,7 +1439,7 @@ Pass StorageRewrite() { auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) { auto* n = f.CopyOnWrite(); n->body = StoragePlanRewriter().Rewrite(std::move(n->body), true); - return PointerValueTypeRewrite(std::move(f)); + return PointerValueTypeRewrite(std::move(f), true, false, false, true, false, true); }; return CreatePrimFuncPass(pass_func, 0, "tir.StorageRewrite", {}); } diff --git a/tests/python/unittest/test_target_codegen_vulkan.py b/tests/python/unittest/test_target_codegen_vulkan.py index 0551fcd54855..85e9cb12d8d2 100644 --- a/tests/python/unittest/test_target_codegen_vulkan.py +++ b/tests/python/unittest/test_target_codegen_vulkan.py @@ -433,5 +433,99 @@ def do_compute(A, B, n): tvm.testing.assert_allclose(b.numpy(), a_np) +class TestVectorizedIndices: + load_type, store_type = tvm.testing.parameters( + # Load N values, write to N locations. + # Vectorized copy. + ("ramp", "ramp"), + # Load 1 value, write to N locations. + # Scalar load, vectorized store. + # + # Most TVM operations (e.g. schedule[tensor].vectorize(axis)) have + # the broadcast outside of the index, but it is semantically okay + # for the broadcast to be inside the index, and it shows up with + # some optimizations. + ("broadcast", "ramp"), + # Load 1 values, write to 1 location. + # Broadcasting on both sides should be equivalent to a scalar copy. + ("broadcast", "broadcast"), + # Loads N values, write to 1 location. + # Disabled as it would have unclear semantics. + # ("ramp","broadcoast"), + ) + indirect_indices = tvm.testing.parameter(True, False, ids=["reorder", "no_reorder"]) + + @tvm.testing.fixture + def ref_data(self, load_type, store_type, indirect_indices): + n = 4 + + index_map = { + "ramp": np.arange(n), + "broadcast": np.zeros(n, dtype="int32"), + } + + a_np = np.random.randint(np.iinfo("int32").max, size=n).astype("int32") + b_np = np.zeros(shape=n, dtype=a_np.dtype) + reorder_np = np.arange(n, dtype="int32")[::-1] + + load_index = index_map[load_type] + store_index = index_map[store_type] + + if indirect_indices: + load_index = reorder_np[load_index] + + b_np[store_index] = a_np[load_index] + + return a_np, reorder_np, b_np + + @tvm.testing.fixture + def mod(self, target, load_type, store_type, indirect_indices): + target = tvm.target.Target(target) + + n = 4 + dtype = "int32" + A = te.placeholder((n,), dtype=dtype, name="A") + R = te.placeholder((n,), dtype=dtype, name="R") + + def do_compute(ins, outs): + ib = tvm.tir.ir_builder.create() + A, R = map(ib.buffer_ptr, ins) + B = ib.buffer_ptr(outs[0]) + + if "gpu" in target.keys: + ib.scope_attr(te.thread_axis("blockIdx.x"), "thread_extent", 0) + + index_map = { + "ramp": tvm.tir.Ramp(0, 1, 4), + "broadcast": tvm.tir.Broadcast(0, 4), + } + + load_index = index_map[load_type] + store_index = index_map[store_type] + + if indirect_indices: + load_index = tvm.tir.expr.Load("int32x4", R, load_index) + + transfer = tvm.tir.expr.Load("int32x4", A, load_index) + ib.emit(tvm.tir.stmt.Store(B, transfer, store_index)) + + return ib.get() + + B = te.extern(A.shape, [A, R], do_compute, dtype="int32") + s = te.create_schedule(B.op) + + return tvm.lower(s, [A, R, B]) + + def test_ramp_broadcast_index(self, target, dev, mod, ref_data): + f = tvm.build(mod, target=target) + + a_np, reorder_np, b_np = ref_data + a = tvm.nd.array(a_np, dev) + r = tvm.nd.array(reorder_np, dev) + b = tvm.nd.array(np.zeros(shape=b_np.shape, dtype="int32"), dev) + f(a, r, b) + tvm.testing.assert_allclose(b.numpy(), b_np) + + if __name__ == "__main__": sys.exit(pytest.main(sys.argv))