Skip to content

Commit

Permalink
[Vulkan] Rewrote PointerValueTypeRewrite transform (apache#8528)
Browse files Browse the repository at this point in the history
* [Vulkan] Rewrote PointerValueTypeRewrite transform

In C-style codegen, pointer types can be freely cast between scalar
and vectorized types (e.g. `float16x4* <-> float16*`).  In SPIR-V,
these are separate types, and no such casting is allowed.  This was
previously handled by having a special-case for `Ramp(base, stride=1,
lanes)` in the codegen.  That method didn't cover all possible cases,
including Broadcast nodes used as indices.

PointerValueTypeRewrite previously re-wrote the AllocateNode and
parameter pointer types, but didn't update the Load/Store node.  This
change tracks which variables can be updated to a vectorized type, and
then updates all references to those.  This includes removing the
`RampNode`, as the vectorization is then included as part of the
variable type.

* [StorageRewrite] Updates as recommended in review.

- Added explicit TODO(Lunderberg) for follow-ups

- Pass `checker.info_map_` instead of `checker` to
  `VectorTypeRewriter`

* [Vulkan] Allow for pointer rewrites that change base type.

A single memory allocation may have more than one type of data stored
within it.  This allows the PointerTypeRewrite pass to recognize if a
function only uses the pointer as a particular base type.  This wasn't
an issue in C-based codegen, but is required for Vulkan.  Since Vulkan
shaders do not permit type-casting, the cast must be done when passing
the pointer argument into the shader.

Co-authored-by: Eric Lunderberg <[email protected]>
  • Loading branch information
2 people authored and ylc committed Sep 29, 2021
1 parent a7b0770 commit 951c1eb
Show file tree
Hide file tree
Showing 7 changed files with 885 additions and 184 deletions.
170 changes: 78 additions & 92 deletions src/target/spirv/codegen_spirv.cc
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ runtime::VulkanShader CodeGenSPIRV::BuildFunction(const PrimFunc& f, const std::
this->InitFuncState();
ICHECK(f->HasNonzeroAttr(tir::attr::kNoAlias)) << "SPIRV only takes restricted memory model";
std::vector<Var> pod_args;
uint32_t num_buffer = 0;
uint32_t i_buffer = 0;

// Currently, all storage and uniform buffer arguments are passed as
// a single descriptor set at index 0. If ever non-zero, must
Expand All @@ -53,24 +53,25 @@ runtime::VulkanShader CodeGenSPIRV::BuildFunction(const PrimFunc& f, const std::
for (Var arg : f->params) {
DataType t = arg.dtype();
if (t.is_handle()) {
if (auto* ptr = arg->type_annotation.as<PointerTypeNode>()) {
auto* prim = ptr->element_type.as<PrimTypeNode>();
ICHECK(prim);
DataType value_storage_type = prim->dtype;
if (value_storage_type == DataType::UInt(1)) {
// We need a physically addressable buffer type to support boolean tensors.
// The loaded byte is cast to bool inside the LoadNode visitor below.
value_storage_type = DataType::UInt(8);
}
spirv::Value arg_value = builder_->BufferArgument(builder_->GetSType(value_storage_type),
descriptor_set, num_buffer);
builder_->SetName(arg_value, arg->name_hint);
storage_info_[arg.get()].UpdateContentType(value_storage_type);
var_map_[arg.get()] = arg_value;
} else {
LOG(FATAL) << "require all handles to be typed";
auto* ptr = arg->type_annotation.as<PointerTypeNode>();
ICHECK(ptr) << "All handles passed to the Vulkan codegen must have a type_annotation as a "
"PointerType, "
<< "and must point to a PrimType";
auto* prim = ptr->element_type.as<PrimTypeNode>();
ICHECK(prim) << "All handles passed to the Vulkan codegen must have a type_annotation as a "
"PointerType, "
<< "and must point to a PrimType";
DataType value_storage_type = prim->dtype;
if (value_storage_type == DataType::Bool()) {
// We need a physically addressable buffer type to support boolean tensors.
// The loaded byte is cast to bool inside the LoadNode visitor below.
value_storage_type = boolean_storage_type_.with_lanes(value_storage_type.lanes());
}
++num_buffer;
spirv::Value arg_value = builder_->BufferArgument(builder_->GetSType(value_storage_type),
descriptor_set, i_buffer++);
builder_->SetName(arg_value, arg->name_hint);
storage_info_[arg.get()].SetContentType(value_storage_type, arg->name_hint);
var_map_[arg.get()] = arg_value;
} else {
pod_args.push_back(arg);
}
Expand All @@ -95,7 +96,7 @@ runtime::VulkanShader CodeGenSPIRV::BuildFunction(const PrimFunc& f, const std::
} else {
shader.flag |= 1 << runtime::vulkan::ShaderMetaDataFlagMask::kUseUBO;
// If we need to pass more arguments than push constants could handle, we use UBO.
spirv::Value ptr = builder_->DeclareUniformBuffer(value_types, descriptor_set, num_buffer);
spirv::Value ptr = builder_->DeclareUniformBuffer(value_types, descriptor_set, i_buffer++);
for (size_t i = 0; i < pod_args.size(); ++i) {
spirv::Value value = builder_->GetUniform(ptr, value_types[i], static_cast<uint32_t>(i));
var_map_[pod_args[i].get()] = value;
Expand Down Expand Up @@ -404,62 +405,58 @@ spirv::Value CodeGenSPIRV::VisitExpr_(const BroadcastNode* op) {

spirv::Value CodeGenSPIRV::VisitExpr_(const LoadNode* op) {
ICHECK(is_one(op->predicate));
auto it = storage_info_.find(op->buffer_var.get());

DataType desired_read_type = op->dtype;
if (desired_read_type == DataType::Bool()) {
desired_read_type = boolean_storage_type_.with_lanes(desired_read_type.lanes());
}

const VarNode* buffer_var = op->buffer_var.get();
auto it = storage_info_.find(buffer_var);
ICHECK(it != storage_info_.end());
StorageInfo& info = it->second;
if (!info.content_fixed) {
info.UpdateContentType(op->dtype);
}
info.CheckContentType(desired_read_type, op->index.dtype().lanes());

spirv::SType content_type = builder_->GetSType(info.content_type);
spirv::SType content_type = builder_->GetSType(info.element_type);
spirv::Value buffer = MakeValue(op->buffer_var);
spirv::SType ptr_type = builder_->GetPointerType(content_type, buffer.stype.storage_class);

uint32_t mask = spv::MemoryAccessMaskNone;
if (info.is_volatile) {
mask |= spv::MemoryAccessVolatileMask;
}
if (op->dtype.lanes() == 1) {

if (desired_read_type == info.element_type) {
// Requested a single value from an array. This may be a scalar load
// or a vectorized load, based on the array element type.
spirv::Value index = MakeValue(op->index);
spirv::Value ptr = builder_->StructArrayAccess(ptr_type, buffer, index);
spirv::Value loaded = builder_->MakeValue(spv::OpLoad, content_type, ptr, mask);
if (op->dtype == DataType::UInt(1)) {
// A bool tensor is backed by a byte buffer, we cast to bool here.
auto bool_ty = builder_->GetSType(DataType::UInt(1));
return builder_->Cast(bool_ty, loaded);
} else {
ICHECK_EQ(info.content_type, op->dtype)
<< "Vulkan only allow one type access to the same buffer";
return loaded;
// OpTypeBool have no physical address/storage. Here, cast from
// the storage type to an OpTypeBool.
if (op->dtype == DataType::Bool()) {
auto spirv_bool = builder_->GetSType(DataType::Bool());
loaded = builder_->Cast(spirv_bool, loaded);
}
return loaded;

} else if (desired_read_type.element_of() == info.element_type) {
// Requested several elements returned as an array. Read out each
// element and concatenate into the result.
std::vector<spirv::Value> values;
auto f = [&](int i, spirv::Value index) {
spirv::Value ptr = builder_->StructArrayAccess(ptr_type, buffer, index);
values.emplace_back(builder_->MakeValue(spv::OpLoad, content_type, ptr, mask));
};
this->Scalarize(op->index, f);
return builder_->Concat(values);

} else {
if (op->dtype.element_of() == info.content_type) {
// because content type is element type, we can only do scalarize load.
std::vector<spirv::Value> values;
auto f = [&](int i, spirv::Value index) {
spirv::Value ptr = builder_->StructArrayAccess(ptr_type, buffer, index);
values.emplace_back(builder_->MakeValue(spv::OpLoad, content_type, ptr, mask));
};
this->Scalarize(op->index, f);
return builder_->Concat(values);
} else {
if (const RampNode* ramp = op->index.as<RampNode>()) {
if (is_one(ramp->stride)) {
ICHECK_EQ(ramp->lanes, op->dtype.lanes());
arith::ModularSet me = analyzer_->modular_set(ramp->base);
ICHECK((me->coeff % ramp->lanes) == 0 && (me->base % ramp->lanes) == 0)
<< "Only aligned vector access is allowed in SPIRV";
PrimExpr vec_index =
analyzer_->Simplify(ramp->base / make_const(ramp->base.dtype(), ramp->lanes));
spirv::Value ptr = builder_->StructArrayAccess(ptr_type, buffer, MakeValue(vec_index));
return builder_->MakeValue(spv::OpLoad, content_type, ptr, mask);
}
}
}
LOG(FATAL) << "Only aligned continuous vector access is allowed in SPIRV";
LOG(FATAL) << "Cannot perform buffer access of buffer variable '" << buffer_var->name_hint
<< "' with element type " << info.element_type << " using index of type "
<< op->index->dtype << " to produce output of type " << op->dtype;
return spirv::Value();
}
LOG(FATAL) << "Only aligned continuous vector access is allowed in SPIRV";
return spirv::Value();
}

void CodeGenSPIRV::Scalarize(const PrimExpr& e, std::function<void(int i, spirv::Value v)> f) {
Expand All @@ -482,12 +479,9 @@ void CodeGenSPIRV::VisitStmt_(const StoreNode* op) {
auto it = storage_info_.find(op->buffer_var.get());
ICHECK(it != storage_info_.end());
StorageInfo& info = it->second;
info.CheckContentType(op->value.dtype(), op->index.dtype().lanes());

if (!info.content_fixed) {
info.UpdateContentType(op->value.dtype());
}

spirv::SType content_type = builder_->GetSType(info.content_type);
spirv::SType content_type = builder_->GetSType(info.element_type);
spirv::Value buffer = MakeValue(op->buffer_var);
spirv::Value value = MakeValue(op->value);
spirv::SType ptr_type = builder_->GetPointerType(content_type, buffer.stype.storage_class);
Expand All @@ -497,37 +491,29 @@ void CodeGenSPIRV::VisitStmt_(const StoreNode* op) {
mask |= spv::MemoryAccessVolatileMask;
}

if (op->value.dtype().lanes() == 1) {
ICHECK_EQ(info.content_type, op->value.dtype())
if (op->value.dtype() == info.element_type) {
// Requested store of a single value. This may be a scalar store
// or a vectorized store, based on the array element type.
ICHECK_EQ(info.element_type, op->value.dtype())
<< "Vulkan only allow one type access to the same buffer";
spirv::Value index = MakeValue(op->index);
spirv::Value ptr = builder_->StructArrayAccess(ptr_type, buffer, index);
builder_->MakeInst(spv::OpStore, ptr, value, mask);

} else if (op->value.dtype().element_of() == info.element_type) {
// Requested store of several arbitrarily located values. Extract
// each value from the composite, then assign to the buffer.
auto f = [&](int i, spirv::Value index) {
spirv::Value elem = builder_->MakeValue(spv::OpCompositeExtract, content_type, value, i);
spirv::Value ptr = builder_->StructArrayAccess(ptr_type, buffer, index);
builder_->MakeInst(spv::OpStore, ptr, elem, mask);
};
this->Scalarize(op->index, f);

} else {
if (op->value.dtype().element_of() == info.content_type) {
// because content type is element type, we can only do scalarize load.
auto f = [&](int i, spirv::Value index) {
spirv::Value elem = builder_->MakeValue(spv::OpCompositeExtract, content_type, value, i);
spirv::Value ptr = builder_->StructArrayAccess(ptr_type, buffer, index);
builder_->MakeInst(spv::OpStore, ptr, elem, mask);
};
this->Scalarize(op->index, f);
} else {
if (const RampNode* ramp = op->index.as<RampNode>()) {
if (is_one(ramp->stride)) {
ICHECK_EQ(ramp->lanes, op->value.dtype().lanes());
arith::ModularSet me = analyzer_->modular_set(ramp->base);
ICHECK((me->coeff % ramp->lanes) == 0 && (me->base % ramp->lanes) == 0)
<< "Only aligned vector access is allowed in SPIRV";
PrimExpr vec_index =
analyzer_->Simplify(ramp->base / make_const(ramp->base.dtype(), ramp->lanes));
spirv::Value ptr = builder_->StructArrayAccess(ptr_type, buffer, MakeValue(vec_index));
builder_->MakeInst(spv::OpStore, ptr, value, mask);
return;
}
}
LOG(FATAL) << "Only aligned continuous vector access is allowed in SPIRV";
}
LOG(FATAL) << "Cannot store value of type " << op->value.dtype() << " into buffer variable '"
<< op->buffer_var->name_hint << "' with element type " << info.element_type
<< " using index of type " << op->index->dtype;
}
}

Expand Down Expand Up @@ -663,8 +649,8 @@ void CodeGenSPIRV::VisitStmt_(const AllocateNode* op) {
builder_->SetName(buf, op->buffer_var->name_hint);

StorageInfo& info = storage_info_[op->buffer_var.get()];
ICHECK(!info.content_fixed);
info.UpdateContentType(op->dtype);
ICHECK(!info.element_type_known);
info.SetContentType(op->dtype, op->buffer_var->name_hint);

ICHECK(!var_map_.count(op->buffer_var.get()));
var_map_[op->buffer_var.get()] = buf;
Expand Down
85 changes: 71 additions & 14 deletions src/target/spirv/codegen_spirv.h
Original file line number Diff line number Diff line change
Expand Up @@ -114,47 +114,104 @@ class CodeGenSPIRV : public ExprFunctor<spirv::Value(const PrimExpr&)>,
void VisitStmt_(const EvaluateNode* op) override;

protected:
/*! \brief The storage information */
/*! \brief Storage information for a buffer */
struct StorageInfo {
/*! \brief The name of the tir::Var for the buffer
*
* Used for error messages.
*/
std::string name_hint;

/*! \brief Whether it is volatile */
bool is_volatile{false};
/*! \brief Whether it is volatile */
bool content_fixed{false};
/*! \brief Current content type */
DataType content_type{DataType::Handle()};

// Update content type if it hasn't beenupdated.
void UpdateContentType(DataType type) {
if (content_fixed) {
ICHECK_EQ(type, content_type) << "Cannot use two different content type in GLSL model";
} else {
this->content_type = type;
content_fixed = true;
}

/*! \brief Whether the element type of the buffer is known.
*
* This value is determined based on the type_annotation of the
* buffer variable (AllocateNode) or of the parameter (shader
* arguments).
*/
bool element_type_known{false};

/*! \brief The known element type of the buffer.
*
* This value is determined based on the type_annotation of the
* buffer variable (AllocateNode) or of the parameter (shader
* arguments).
*/
DataType element_type{DataType()};

/* \brief Check that the access type matches the known type
*
* Asserts that the type given is the same as the type previously
* stored in this array.
*
* @param type The data type being stored/loaded in the buffer
*
* @param index_lanes The number of lanes of the index. The
* number of lanes in the value being stored/loaded should be the
* product of the number of lanes of the buffer element type and
* the number of lanes of the index.
*/
void CheckContentType(DataType type, int index_lanes = 1) {
ICHECK(element_type_known) << "Cannot check element type of buffer " << name_hint
<< " no previous element type defined";
DataType expected_type = element_type.with_lanes(index_lanes * element_type.lanes());
ICHECK_EQ(type, expected_type) << "Attempted to access buffer " << name_hint
<< " as element type " << type << " using an index of size "
<< index_lanes << " when the element type is " << element_type;
}

// Update content type if it hasn't been updated.
void SetContentType(DataType type, std::string name_hint) {
ICHECK(!element_type_known) << "Cannot set element type of buffer " << name_hint
<< " a second time.";
this->element_type = type;
this->name_hint = name_hint;
element_type_known = true;
}
};
// Reset the state so it works for a new function.
void InitFuncState();
// Get the thread index
spirv::Value GetThreadIndex(const IterVar& iv, const PrimExpr& extent);

spirv::Value CreateStorageSync(const CallNode* op);
void Scalarize(const PrimExpr& e, std::function<void(int i, spirv::Value v)> f);

// SPIRV-related capabilities of the target
SPIRVSupport spirv_support_;

// The builder
std::unique_ptr<spirv::IRBuilder> builder_;

// Work group size of three
uint32_t workgroup_size_[3];

// Likely branch
uint32_t weight_likely_branch_{128};

/* The data type used for the backing array for booleans.
*
* Currently matched to the data type used in Buffer::vstore and
* Buffer::vload. In the future, this should be the smallest
* integer type supported by the device, as not all Vulkan
* implementations support int8.
*/
DataType boolean_storage_type_{DataType::Int(8)};

// the storage scope of allocation
std::unordered_map<const VarNode*, StorageInfo> storage_info_;

// The definition of local variable.
std::unordered_map<const VarNode*, spirv::Value> var_map_;

// The analyzer.
std::unique_ptr<arith::Analyzer> analyzer_;

// deep comparison of PrimExpr
ExprDeepEqual deep_equal_;

// binding of let variables. Enables duplicate var defs that map to same value
std::unordered_map<Var, const LetNode*, ObjectPtrHash, ObjectPtrEqual> let_binding_;
};
Expand Down
Loading

0 comments on commit 951c1eb

Please sign in to comment.