Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[spirv] Fix bug of CTBuffer DX memory layout with matrix #3672

Merged
merged 17 commits into from
Apr 29, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions tools/clang/include/clang/SPIRV/SpirvBasicBlock.h
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,12 @@ class SpirvBasicBlock {
/// block.
void addInstruction(SpirvInstruction *inst) { instructions.push_back(inst); }

/// Adds the given instruction as the first instruction of this SPIR-V basic
/// block.
///
/// \param inst The instruction to prepend. It is pushed to the front of
/// `instructions`, i.e. ahead of every instruction already in the block.
void addFirstInstruction(SpirvInstruction *inst) {
instructions.push_front(inst);
}

/// Return true if instructions is empty. Otherwise, return false.
bool empty() { return instructions.empty(); }

Expand Down
61 changes: 61 additions & 0 deletions tools/clang/include/clang/SPIRV/SpirvBuilder.h
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,10 @@ class SpirvBuilder {
createAccessChain(QualType resultType, SpirvInstruction *base,
llvm::ArrayRef<SpirvInstruction *> indexes,
SourceLocation loc);
/// \brief Overload of createAccessChain that takes the result type as a
/// SPIR-V type (const SpirvType *) instead of a Clang QualType.
SpirvAccessChain *
createAccessChain(const SpirvType *resultType, SpirvInstruction *base,
llvm::ArrayRef<SpirvInstruction *> indexes,
SourceLocation loc);

/// \brief Creates a unary operation with the given SPIR-V opcode. Returns
/// the instruction pointer for the result.
Expand Down Expand Up @@ -503,6 +507,22 @@ class SpirvBuilder {
/// OpIgnoreIntersectionKHR/OpTerminateIntersectionKHR
void createRaytracingTerminateKHR(spv::Op opcode, SourceLocation loc);

/// \brief Returns a clone SPIR-V variable for a CTBuffer with FXC memory
/// layout and creates the instructions that copy the CTBuffer into the clone
/// variable in module.init, if the CTBuffer contains an HLSL 1xN matrix.
/// Otherwise, returns nullptr.
///
/// Motivation for the clone variable:
/// A type1xN matrix is translated as a typeN vector throughout code
/// generation, but in a CTBuffer with the FXC memory layout rule a type1xN
/// matrix must have a stride of 16 bytes between its elements. A SPIR-V vector
/// cannot carry a stride, so a SPIR-V array type[N] with a 16-byte stride has
/// to be used for it. Because every other place still treats the matrix as a
/// typeN vector, that has side effects. The clone variable fixes the issue:
///   1. The CTBuffer receives the data from the CPU.
///   2. The data is copied to the clone variable.
///   3. The clone variable is used in all other places.
///
/// \param instr The instruction to check. NOTE(review): presumably the
/// CTBuffer variable itself; confirm against the definition.
/// \return The clone variable, or nullptr when no clone is needed.
SpirvInstruction *initializeCloneVarForFxcCTBuffer(SpirvInstruction *instr);

// === SPIR-V Module Structure ===
inline void setMemoryModel(spv::AddressingModel, spv::MemoryModel);

Expand Down Expand Up @@ -666,6 +686,37 @@ class SpirvBuilder {
SpirvInstruction *constOffsets, SpirvInstruction *sample,
SpirvInstruction *minLod);

/// \brief Creates instructions to copy sub-components of CTBuffer src to its
/// clone dst. This method assumes
/// 1. src has a pointer type to a type with FXC memory layout rule
/// 2. dst has a pointer type to a type with void memory layout rule
///
/// Here "src" is the fxcCTBuffer parameter and "dst" is the clone parameter.
void
createCopyInstructionsFromFxcCTBufferToClone(SpirvInstruction *fxcCTBuffer,
SpirvInstruction *clone);
/// NOTE(review): judging by its name and parameters, this is the array-typed
/// counterpart of the copy above (fxcCTBufferArrTy describes the source,
/// cloneType the destination); confirm against the definition.
void createCopyArrayInFxcCTBufferToClone(const ArrayType *fxcCTBufferArrTy,
SpirvInstruction *fxcCTBuffer,
const SpirvType *cloneType,
SpirvInstruction *clone,
SourceLocation loc);
/// NOTE(review): judging by its name and parameters, this is the struct-typed
/// counterpart of the copy above (fxcCTBufferStructTy describes the source,
/// cloneType the destination); confirm against the definition.
void createCopyStructInFxcCTBufferToClone(
const StructType *fxcCTBufferStructTy, SpirvInstruction *fxcCTBuffer,
const SpirvType *cloneType, SpirvInstruction *clone, SourceLocation loc);

/// \brief Sets moduleInitInsertPoint as insertPoint.
void switchInsertPointToModuleInit();

/// \brief Adds OpFunctionCall instructions for ModuleInit to all entry
/// points.
void addModuleInitCallToEntryPoints();

/// \brief Ends building of the module initialization function.
void endModuleInitFunction();

/// \brief Creates a clone SPIR-V variable for CTBuffer.
SpirvVariable *createCloneVarForFxcCTBuffer(QualType astType,
const SpirvType *spvType,
SpirvInstruction *var);

private:
ASTContext &astContext;
SpirvContext &context; ///< From which we allocate various SPIR-V object
Expand All @@ -674,6 +725,11 @@ class SpirvBuilder {
SpirvFunction *function; ///< The current function being built
SpirvBasicBlock *insertPoint; ///< The current basic block being built

SpirvFunction *moduleInit; ///< The module initialization
///< function
SpirvBasicBlock *moduleInitInsertPoint; ///< The basic block of the module
///< initialization function

const SpirvCodeGenOptions &spirvOptions; ///< Command line options.

/// A struct containing information regarding a builtin variable.
Expand All @@ -695,6 +751,11 @@ class SpirvBuilder {
// To avoid generating multiple OpStrings for the same string literal
// the SpirvBuilder will generate and reuse them.
llvm::DenseMap<std::string, SpirvString *, StringMapInfo> stringLiterals;

/// Mapping of CTBuffers including matrix 1xN with FXC memory layout to their
/// clone variables. We need it to avoid multiple clone variables for the same
/// CTBuffer.
llvm::DenseMap<SpirvVariable *, SpirvVariable *> fxcCTBufferToClone;
};

void SpirvBuilder::requireCapability(spv::Capability cap, SourceLocation loc) {
Expand Down
14 changes: 14 additions & 0 deletions tools/clang/include/clang/SPIRV/SpirvContext.h
Original file line number Diff line number Diff line change
Expand Up @@ -374,6 +374,17 @@ class SpirvContext {
return declToDebugFunction[decl];
}

/// Adds inst to instructionsWithLoweredType.
///
/// \param inst The instruction to record; it is inserted into the
/// instructionsWithLoweredType set.
void addToInstructionsWithLoweredType(const SpirvInstruction *inst) {
instructionsWithLoweredType.insert(inst);
}

/// Returns whether inst is in instructionsWithLoweredType or not.
bool hasLoweredType(const SpirvInstruction *inst) {
return instructionsWithLoweredType.find(inst) !=
instructionsWithLoweredType.end();
}

private:
/// \brief The allocator used to create SPIR-V entity objects.
///
Expand Down Expand Up @@ -463,6 +474,9 @@ class SpirvContext {

// Mapping from SPIR-V OpVariable to SPIR-V image format.
llvm::DenseMap<const SpirvVariable *, spv::ImageFormat> spvVarToImageFormat;

// Set of instructions that already have lowered SPIR-V types.
llvm::DenseSet<const SpirvInstruction *> instructionsWithLoweredType;
};

} // end namespace spirv
Expand Down
9 changes: 8 additions & 1 deletion tools/clang/include/clang/SPIRV/SpirvFunction.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,14 @@

#include <vector>

#include "clang/SPIRV/SpirvBasicBlock.h"
#include "clang/SPIRV/SpirvInstruction.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallVector.h"

namespace clang {
namespace spirv {

class SpirvBasicBlock;
class SpirvVisitor;

/// The class representing a SPIR-V function in memory.
Expand Down Expand Up @@ -91,6 +91,13 @@ class SpirvFunction {
void addVariable(SpirvVariable *);
void addBasicBlock(SpirvBasicBlock *);

/// Prepends the given instruction to the body of this SPIR-V function, i.e.
/// makes it the first instruction of the function's first basic block.
///
/// The function must already contain at least one basic block; this is
/// checked with an assertion.
void addFirstInstruction(SpirvInstruction *inst) {
  assert(basicBlocks.size() != 0);
  basicBlocks.front()->addFirstInstruction(inst);
}

/// Legalization-specific code
///
/// Note: the following methods are used for properly handling aliasing.
Expand Down
4 changes: 4 additions & 0 deletions tools/clang/include/clang/SPIRV/SpirvModule.h
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,10 @@ class SpirvModule {

llvm::ArrayRef<SpirvVariable *> getVariables() const { return variables; }

/// Returns a view (llvm::ArrayRef) over this module's entryPoints.
llvm::ArrayRef<SpirvEntryPoint *> getEntryPoints() const {
  return llvm::ArrayRef<SpirvEntryPoint *>(entryPoints);
}

private:
// Use a set for storing capabilities. This will ensure there are no duplicate
// capabilities. Although the set stores pointers, the provided
Expand Down
30 changes: 26 additions & 4 deletions tools/clang/lib/SPIRV/AlignmentSizeCalculator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@

#include "AlignmentSizeCalculator.h"
#include "clang/AST/Attr.h"
#include "clang/SPIRV/AstTypeProbe.h"

namespace {

Expand Down Expand Up @@ -129,6 +128,10 @@ std::pair<uint32_t, uint32_t> AlignmentSizeCalculator::getAlignmentAndSize(
// - Vector base alignment is set as its element type's base alignment.
// - Arrays/structs do not need to have padding at the end; arrays/structs do
// not affect the base offset of the member following them.
// - For typeNxM matrix, if M > 1,
// - It must be aligned to 16 bytes.
// - Its size must be (16 * (M - 1)) + N * sizeof(type).
// - We have the same rule for column_major typeNxM and row_major typeMxN.
//
// FxcSBuffer:
// - Vector/matrix/array base alignment is set as its element type's base
Expand Down Expand Up @@ -186,6 +189,27 @@ std::pair<uint32_t, uint32_t> AlignmentSizeCalculator::getAlignmentAndSize(
}
}

// FxcCTBuffer for typeNxM matrix where M > 1,
// - It must be aligned to 16 bytes.
// - Its size must be (16 * (M - 1)) + N * sizeof(type).
// - We have the same rule for column_major typeNxM and row_major typeMxN.
if (rule == SpirvLayoutRule::FxcCTBuffer && hlsl::IsHLSLMatType(type)) {
uint32_t rowCount = 0, colCount = 0;
hlsl::GetHLSLMatRowColCount(type, rowCount, colCount);
if (!useRowMajor(isRowMajor, type))
std::swap(rowCount, colCount);
if (colCount > 1) {
auto elemType = hlsl::GetHLSLMatElementType(type);
uint32_t alignment = 0, size = 0;
std::tie(alignment, size) =
getAlignmentAndSize(elemType, rule, isRowMajor, stride);
alignment = roundToPow2(alignment * (rowCount == 3 ? 4 : rowCount),
kStd140Vec4Alignment);
*stride = alignment;
return {alignment, 16 * (colCount - 1) + rowCount * size};
}
}

{ // Rule 2 and 3
QualType elemType = {};
uint32_t elemCount = {};
Expand Down Expand Up @@ -215,9 +239,7 @@ std::pair<uint32_t, uint32_t> AlignmentSizeCalculator::getAlignmentAndSize(
// The base alignment and array stride are set to match the base alignment
// of a single array element, according to rules 1, 2, and 3, and rounded
// up to the base alignment of a vec4.
bool rowMajor = isRowMajor.hasValue()
? isRowMajor.getValue()
: isRowMajorMatrix(spvOptions, type);
bool rowMajor = useRowMajor(isRowMajor, type);

const uint32_t vecStorageSize = rowMajor ? rowCount : colCount;

Expand Down
8 changes: 8 additions & 0 deletions tools/clang/lib/SPIRV/AlignmentSizeCalculator.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

#include "dxc/Support/SPIRVOptions.h"
#include "clang/AST/ASTContext.h"
#include "clang/SPIRV/AstTypeProbe.h"

namespace clang {
namespace spirv {
Expand Down Expand Up @@ -48,6 +49,13 @@ class AlignmentSizeCalculator {
uint32_t fieldAlignment,
uint32_t *currentOffset);

/// \brief Decides whether type should be treated as a row-major matrix.
///
/// An explicitly provided isRowMajor value takes precedence. When it holds no
/// value, the decision is delegated to isRowMajorMatrix(spvOptions, type).
///
/// \return true for row-major, false otherwise.
bool useRowMajor(llvm::Optional<bool> isRowMajor, clang::QualType type) {
  if (isRowMajor.hasValue())
    return isRowMajor.getValue();
  return isRowMajorMatrix(spvOptions, type);
}

private:
/// Emits error to the diagnostic engine associated with this visitor.
template <unsigned N>
Expand Down
22 changes: 10 additions & 12 deletions tools/clang/lib/SPIRV/DeclResultIdMapper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -947,8 +947,7 @@ SpirvVariable *DeclResultIdMapper::createExternVar(const VarDecl *var) {
if (spvImageFormat != spv::ImageFormat::Unknown)
spvContext.registerImageFormatForSpirvVariable(varInstr, spvImageFormat);

DeclSpirvInfo info(varInstr);
astDecls[var] = info;
astDecls[var] = createDeclSpirvInfo(varInstr);

createDebugGlobalVariable(varInstr, type, loc, name);

Expand Down Expand Up @@ -991,8 +990,7 @@ DeclResultIdMapper::createOrUpdateStringVar(const VarDecl *var) {
const StringLiteral *stringLiteral =
dyn_cast<StringLiteral>(var->getInit()->IgnoreParenCasts());
SpirvString *init = spvBuilder.getString(stringLiteral->getString());
DeclSpirvInfo info(init);
astDecls[var] = info;
astDecls[var] = createDeclSpirvInfo(init);
return init;
}

Expand Down Expand Up @@ -1089,7 +1087,7 @@ void DeclResultIdMapper::createEnumConstant(const EnumConstantDecl *decl) {
SpirvVariable *varInstr = spvBuilder.addModuleVar(
astContext.IntTy, spv::StorageClass::Private, /*isPrecise*/ false,
decl->getName(), enumConstant, decl->getLocation());
astDecls[valueDecl] = DeclSpirvInfo(varInstr);
astDecls[valueDecl] = createDeclSpirvInfo(varInstr);
}

SpirvVariable *DeclResultIdMapper::createCTBuffer(const HLSLBufferDecl *decl) {
Expand All @@ -1111,7 +1109,7 @@ SpirvVariable *DeclResultIdMapper::createCTBuffer(const HLSLBufferDecl *decl) {
continue;

const auto *varDecl = cast<VarDecl>(subDecl);
astDecls[varDecl] = DeclSpirvInfo(bufferVar, index++);
astDecls[varDecl] = createDeclSpirvInfo(bufferVar, index++);
}
resourceVars.emplace_back(
bufferVar, decl, decl->getLocation(), getResourceBinding(decl),
Expand Down Expand Up @@ -1185,7 +1183,7 @@ SpirvVariable *DeclResultIdMapper::createCTBuffer(const VarDecl *decl) {
decl->getName());

// We register the VarDecl here.
astDecls[decl] = DeclSpirvInfo(bufferVar);
astDecls[decl] = createDeclSpirvInfo(bufferVar);
resourceVars.emplace_back(
bufferVar, decl, decl->getLocation(), getResourceBinding(decl),
decl->getAttr<VKBindingAttr>(), decl->getAttr<VKCounterBindingAttr>());
Expand All @@ -1212,7 +1210,7 @@ SpirvVariable *DeclResultIdMapper::createPushConstant(const VarDecl *decl) {
structName, decl->getName());

// Register the VarDecl
astDecls[decl] = DeclSpirvInfo(var);
astDecls[decl] = createDeclSpirvInfo(var);

// Do not push this variable into resourceVars since it does not need
// descriptor set.
Expand Down Expand Up @@ -1241,7 +1239,7 @@ DeclResultIdMapper::createShaderRecordBuffer(const VarDecl *decl,
kind, structName, decl->getName());

// Register the VarDecl
astDecls[decl] = DeclSpirvInfo(var);
astDecls[decl] = createDeclSpirvInfo(var);

// Do not push this variable into resourceVars since it does not need
// descriptor set.
Expand Down Expand Up @@ -1276,7 +1274,7 @@ DeclResultIdMapper::createShaderRecordBuffer(const HLSLBufferDecl *decl,
continue;

const auto *varDecl = cast<VarDecl>(subDecl);
astDecls[varDecl] = DeclSpirvInfo(bufferVar, index++);
astDecls[varDecl] = createDeclSpirvInfo(bufferVar, index++);
}
return bufferVar;
}
Expand Down Expand Up @@ -1312,7 +1310,7 @@ void DeclResultIdMapper::createGlobalsCBuffer(const VarDecl *var) {
return;
}

astDecls[varDecl] = DeclSpirvInfo(globals, index++);
astDecls[varDecl] = createDeclSpirvInfo(globals, index++);
}
}
}
Expand Down Expand Up @@ -1385,7 +1383,7 @@ DeclResultIdMapper::getCounterVarFields(const DeclaratorDecl *decl) {
// Registers specConstant as the SPIR-V instruction recorded for decl in
// astDecls.
void DeclResultIdMapper::registerSpecConstant(const VarDecl *decl,
SpirvInstruction *specConstant) {
// Flag the instruction as an rvalue before it is recorded.
specConstant->setRValue();
// Record it through createDeclSpirvInfo; no field index is passed, so the
// helper's default index is used.
astDecls[decl] = createDeclSpirvInfo(specConstant);
}

void DeclResultIdMapper::createCounterVar(
Expand Down
11 changes: 11 additions & 0 deletions tools/clang/lib/SPIRV/DeclResultIdMapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -453,6 +453,16 @@ class DeclResultIdMapper {
/// Returns nullptr if no such decl was previously registered.
const DeclSpirvInfo *getDeclSpirvInfo(const ValueDecl *decl) const;

/// \brief Builds the DeclSpirvInfo for the given instr and index.
///
/// instr is first offered to spvBuilder.initializeCloneVarForFxcCTBuffer. When
/// that returns a clone variable (the CTBuffer includes a 1xN matrix with FXC
/// memory layout), the returned info refers to the clone instead of instr.
/// When it returns nullptr, instr is used unchanged.
DeclSpirvInfo createDeclSpirvInfo(SpirvInstruction *instr,
                                  int index = -1) const {
  SpirvInstruction *clone =
      spvBuilder.initializeCloneVarForFxcCTBuffer(instr);
  return DeclSpirvInfo(clone != nullptr ? clone : instr, index);
}

public:
/// \brief Returns the information for the given decl.
///
Expand Down Expand Up @@ -786,6 +796,7 @@ class DeclResultIdMapper {
/// Mapping of all Clang AST decls to their instruction pointers.
llvm::DenseMap<const ValueDecl *, DeclSpirvInfo> astDecls;
llvm::DenseMap<const ValueDecl *, SpirvFunction *> astFunctionDecls;

/// Vector of all defined stage variables.
llvm::SmallVector<StageVar, 8> stageVars;
/// Mapping from Clang AST decls to the corresponding stage variables.
Expand Down
Loading