From 0e980868f4e20dbd9b4e9602780fe3cb3f5f8b81 Mon Sep 17 00:00:00 2001 From: Zalman Stern Date: Tue, 6 Dec 2022 23:15:46 -0800 Subject: [PATCH] Extend LLVM IR type mangling to handle scalars. (#7212) Extend LLVM IR type mangling to handle scalars and use this in vector predication intrinsic codegen. Fixes an error generating vector predicated strided stores. --- src/CodeGen_LLVM.cpp | 31 +++++++++++++++++-------------- src/CodeGen_LLVM.h | 7 ++++--- 2 files changed, 21 insertions(+), 17 deletions(-) diff --git a/src/CodeGen_LLVM.cpp b/src/CodeGen_LLVM.cpp index 225e19ee4513..200857ed9716 100644 --- a/src/CodeGen_LLVM.cpp +++ b/src/CodeGen_LLVM.cpp @@ -5146,22 +5146,33 @@ llvm::Constant *CodeGen_LLVM::get_splat(int lanes, llvm::Constant *value, return ConstantVector::getSplat(ec, value); } -std::string CodeGen_LLVM::mangle_llvm_vector_type(llvm::Type *type) { +std::string CodeGen_LLVM::mangle_llvm_type(llvm::Type *type) { std::string type_string = "."; - bool is_scalable = isa(type); llvm::ElementCount llvm_vector_ec; - if (is_scalable) { + if (isa(type)) { + const auto *vt = cast(type); + type_string = ".p" + std::to_string(vt->getAddressSpace()); + } else if (isa(type)) { const auto *vt = cast(type); const char *type_designator = vt->getElementType()->isIntegerTy() ? "i" : "f"; std::string bits_designator = std::to_string(vt->getScalarSizeInBits()); llvm_vector_ec = vt->getElementCount(); type_string = ".nxv" + std::to_string(vt->getMinNumElements()) + type_designator + bits_designator; - } else { + } else if (isa(type)) { const auto *vt = cast(type); const char *type_designator = vt->getElementType()->isIntegerTy() ? 
"i" : "f"; std::string bits_designator = std::to_string(vt->getScalarSizeInBits()); llvm_vector_ec = vt->getElementCount(); type_string = ".v" + std::to_string(vt->getNumElements()) + type_designator + bits_designator; + } else if (type->isIntegerTy()) { + type_string = ".i" + std::to_string(type->getScalarSizeInBits()); + } else if (type->isFloatTy()) { + type_string = ".f" + std::to_string(type->getScalarSizeInBits()); + } else { + std::string type_name; + llvm::raw_string_ostream type_name_stream(type_name); + type->print(type_name_stream, true); + internal_error << "Attempt to mangle unknown LLVM type " << type_name << "\n"; } return type_string; } @@ -5205,19 +5216,11 @@ bool CodeGen_LLVM::try_vector_predication_intrinsic(const std::string &name, VPR args.push_back(arg.value); if (arg.mangle_index) { llvm::Type *llvm_type = arg.value->getType(); - if (isa(llvm_type)) { - mangled_types[arg.mangle_index.value()] = ".p0"; - } else { - mangled_types[arg.mangle_index.value()] = mangle_llvm_vector_type(llvm_type); - } + mangled_types[arg.mangle_index.value()] = mangle_llvm_type(llvm_type); } } if (result_type.mangle_index) { - if (isa(llvm_result_type)) { - mangled_types[result_type.mangle_index.value()] = ".p0"; - } else { - mangled_types[result_type.mangle_index.value()] = mangle_llvm_vector_type(llvm_result_type); - } + mangled_types[result_type.mangle_index.value()] = mangle_llvm_type(llvm_result_type); } std::string full_name = name; diff --git a/src/CodeGen_LLVM.h b/src/CodeGen_LLVM.h index b132dac1d314..ed6ee0d1ee4c 100644 --- a/src/CodeGen_LLVM.h +++ b/src/CodeGen_LLVM.h @@ -304,12 +304,13 @@ class CodeGen_LLVM : public IRVisitor { llvm::Value *codegen_buffer_pointer(llvm::Value *base_address, Type type, llvm::Value *index); // @} - /** Return type string for LLVM vector type using LLVM IR intrinsic type mangling. - * E.g. ".nxv4i32" for a scalable vector of four 32-bit integers, + /** Return type string for LLVM type using LLVM IR intrinsic type mangling. 
+ * E.g. ".i32" or ".f32" for scalars, ".p0" for pointers, + * ".nxv4i32" for a scalable vector of four 32-bit integers, * or ".v4f32" for a fixed vector of four 32-bit floats. * The dot is included in the result. */ - std::string mangle_llvm_vector_type(llvm::Type *type); + std::string mangle_llvm_type(llvm::Type *type); /** Turn a Halide Type into an llvm::Value representing a constant halide_type_t */ llvm::Value *make_halide_type_t(const Type &);