Skip to content

Commit

Permalink
Extend LLVM IR type mangling to handle scalars. (halide#7212)
Browse files Browse the repository at this point in the history
Extend LLVM IR type mangling to handle scalars and use this in vector predication intrinsic codegen.

Fixes an error denerating vector predicated strided stores.
  • Loading branch information
Zalman Stern authored and ardier committed Mar 3, 2024
1 parent c3c4085 commit 0e98086
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 17 deletions.
31 changes: 17 additions & 14 deletions src/CodeGen_LLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5146,22 +5146,33 @@ llvm::Constant *CodeGen_LLVM::get_splat(int lanes, llvm::Constant *value,
return ConstantVector::getSplat(ec, value);
}

std::string CodeGen_LLVM::mangle_llvm_vector_type(llvm::Type *type) {
std::string CodeGen_LLVM::mangle_llvm_type(llvm::Type *type) {
std::string type_string = ".";
bool is_scalable = isa<llvm::ScalableVectorType>(type);
llvm::ElementCount llvm_vector_ec;
if (is_scalable) {
if (isa<PointerType>(type)) {
const auto *vt = cast<llvm::PointerType>(type);
type_string = ".p" + std::to_string(vt->getAddressSpace());
} else if (isa<llvm::ScalableVectorType>(type)) {
const auto *vt = cast<llvm::ScalableVectorType>(type);
const char *type_designator = vt->getElementType()->isIntegerTy() ? "i" : "f";
std::string bits_designator = std::to_string(vt->getScalarSizeInBits());
llvm_vector_ec = vt->getElementCount();
type_string = ".nxv" + std::to_string(vt->getMinNumElements()) + type_designator + bits_designator;
} else {
} else if (isa<llvm::FixedVectorType>(type)) {
const auto *vt = cast<llvm::FixedVectorType>(type);
const char *type_designator = vt->getElementType()->isIntegerTy() ? "i" : "f";
std::string bits_designator = std::to_string(vt->getScalarSizeInBits());
llvm_vector_ec = vt->getElementCount();
type_string = ".v" + std::to_string(vt->getNumElements()) + type_designator + bits_designator;
} else if (type->isIntegerTy()) {
type_string = ".i" + std::to_string(type->getScalarSizeInBits());
} else if (type->isFloatTy()) {
type_string = ".f" + std::to_string(type->getScalarSizeInBits());
} else {
std::string type_name;
llvm::raw_string_ostream type_name_stream(type_name);
type->print(type_name_stream, true);
internal_error << "Attempt to mangle unknown LLVM type " << type_name << "\n";
}
return type_string;
}
Expand Down Expand Up @@ -5205,19 +5216,11 @@ bool CodeGen_LLVM::try_vector_predication_intrinsic(const std::string &name, VPR
args.push_back(arg.value);
if (arg.mangle_index) {
llvm::Type *llvm_type = arg.value->getType();
if (isa<PointerType>(llvm_type)) {
mangled_types[arg.mangle_index.value()] = ".p0";
} else {
mangled_types[arg.mangle_index.value()] = mangle_llvm_vector_type(llvm_type);
}
mangled_types[arg.mangle_index.value()] = mangle_llvm_type(llvm_type);
}
}
if (result_type.mangle_index) {
if (isa<PointerType>(llvm_result_type)) {
mangled_types[result_type.mangle_index.value()] = ".p0";
} else {
mangled_types[result_type.mangle_index.value()] = mangle_llvm_vector_type(llvm_result_type);
}
mangled_types[result_type.mangle_index.value()] = mangle_llvm_type(llvm_result_type);
}

std::string full_name = name;
Expand Down
7 changes: 4 additions & 3 deletions src/CodeGen_LLVM.h
Original file line number Diff line number Diff line change
Expand Up @@ -304,12 +304,13 @@ class CodeGen_LLVM : public IRVisitor {
llvm::Value *codegen_buffer_pointer(llvm::Value *base_address, Type type, llvm::Value *index);
// @}

/** Return type string for LLVM vector type using LLVM IR intrinsic type mangling.
* E.g. ".nxv4i32" for a scalable vector of four 32-bit integers,
/** Return type string for LLVM type using LLVM IR intrinsic type mangling.
* E.g. ".i32 or ".f32" for scalars, ".p0" for pointers,
* ".nxv4i32" for a scalable vector of four 32-bit integers,
* or ".v4f32" for a fixed vector of four 32-bit floats.
* The dot is included in the result.
*/
std::string mangle_llvm_vector_type(llvm::Type *type);
std::string mangle_llvm_type(llvm::Type *type);

/** Turn a Halide Type into an llvm::Value representing a constant halide_type_t */
llvm::Value *make_halide_type_t(const Type &);
Expand Down

0 comments on commit 0e98086

Please sign in to comment.