Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[mlir][VectorOps] Support string literals in vector.print #68695

Merged
merged 6 commits into from
Oct 24, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 35 additions & 0 deletions mlir/include/mlir/Conversion/LLVMCommon/PrintCallHelper.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
//===- PrintCallHelper.h - Helper to emit runtime print calls ---*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_DIALECT_LLVMIR_PRINTCALLHELPER_H_
#define MLIR_DIALECT_LLVMIR_PRINTCALLHELPER_H_

#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "llvm/ADT/StringRef.h"
#include <optional>

namespace mlir {

class OpBuilder;
class LLVMTypeConverter;

namespace LLVM {

/// Generate IR that prints the given string to stdout.
/// If a custom runtime function is defined via `runtimeFunctionName`, it must
/// have the signature void(char const*). The default function is `printString`.
void createPrintStrCall(OpBuilder &builder, Location loc, ModuleOp moduleOp,
StringRef symbolName, StringRef string,
const LLVMTypeConverter &typeConverter,
bool addNewline = true,
std::optional<StringRef> runtimeFunctionName = {});
} // namespace LLVM

} // namespace mlir

#endif
9 changes: 7 additions & 2 deletions mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

#include "mlir/IR/Operation.h"
#include "mlir/Support/LLVM.h"
#include <optional>

namespace mlir {
class Location;
Expand All @@ -38,8 +39,12 @@ LLVM::LLVMFuncOp lookupOrCreatePrintF16Fn(ModuleOp moduleOp);
LLVM::LLVMFuncOp lookupOrCreatePrintBF16Fn(ModuleOp moduleOp);
LLVM::LLVMFuncOp lookupOrCreatePrintF32Fn(ModuleOp moduleOp);
LLVM::LLVMFuncOp lookupOrCreatePrintF64Fn(ModuleOp moduleOp);
LLVM::LLVMFuncOp lookupOrCreatePrintStrFn(ModuleOp moduleOp,
bool opaquePointers);
/// Declares a function to print a C-string.
/// If a custom runtime function is defined via `runtimeFunctionName`, it must
/// have the signature void(char const*). The default function is `printString`.
LLVM::LLVMFuncOp
lookupOrCreatePrintStringFn(ModuleOp moduleOp, bool opaquePointers,
std::optional<StringRef> runtimeFunctionName = {});
LLVM::LLVMFuncOp lookupOrCreatePrintOpenFn(ModuleOp moduleOp);
LLVM::LLVMFuncOp lookupOrCreatePrintCloseFn(ModuleOp moduleOp);
LLVM::LLVMFuncOp lookupOrCreatePrintCommaFn(ModuleOp moduleOp);
Expand Down
37 changes: 33 additions & 4 deletions mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ include "mlir/Interfaces/InferTypeOpInterface.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/Interfaces/VectorInterfaces.td"
include "mlir/Interfaces/ViewLikeInterface.td"
include "mlir/IR/BuiltinAttributes.td"

// TODO: Add an attribute to specify a different algebra with operators other
// than the current set: {*, +}.
Expand Down Expand Up @@ -2477,12 +2478,18 @@ def Vector_TransposeOp :
}

def Vector_PrintOp :
Vector_Op<"print", []>,
Vector_Op<"print", [
PredOpTrait<
"`source` or `punctuation` are not set when printing strings",
MacDue marked this conversation as resolved.
Show resolved Hide resolved
CPred<"!getStringLiteral() || (!getSource() && getPunctuation() == PrintPunctuation::NewLine)">
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why isn't other punctuation allowed? It would be very useful to be able to print strings without the trailing newline.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Strings don't have a trailing newline.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The lowering for them in LLVM::createPrintStrCall adds one unconditionally.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'll post a complete patch for this in a bit, but I was thinking something along the lines of:

diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index b66b55ae8d57..fd5aec1982eb 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -1531,8 +1531,28 @@ public:
 
     auto punct = printOp.getPunctuation();
     if (auto stringLiteral = printOp.getStringLiteral()) {
+      std::string str;
+      llvm::raw_string_ostream punctuatedLiteral(str);
+      if (punct == PrintPunctuation::Open)
+        punctuatedLiteral << "( ";
+      punctuatedLiteral << stringLiteral->str();
+      switch (punct) {
+      case PrintPunctuation::Close:
+        punctuatedLiteral << " )";
+        break;
+      case PrintPunctuation::Comma:
+        punctuatedLiteral << ", ";
+        break;
+      case PrintPunctuation::NewLine:
+        punctuatedLiteral << '\n';
+        break;
+      case PrintPunctuation::Open:
+      case PrintPunctuation::NoPunctuation:
+        break;
+      }
       LLVM::createPrintStrCall(rewriter, loc, parent, "vector_print_str",
-                               *stringLiteral, *getTypeConverter());
+                               punctuatedLiteral.str(), *getTypeConverter(),
+                               /*addNewLine=*/false);
     } else if (punct != PrintPunctuation::NoPunctuation) {
       emitCall(rewriter, printOp->getLoc(), [&] {
         switch (punct) {

so that:

diff --git a/mlir/test/Integration/Dialect/Vector/CPU/test-print-str.mlir b/mlir/test/Integration/Dialect/Vector/CPU/test-print-str.mlir
index 78d6609ccaf9..b47c5b38f783 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/test-print-str.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/test-print-str.mlir
@@ -8,6 +8,16 @@
 func.func @entry() {
    // CHECK: Hello, World!
    vector.print str "Hello, World!"
+
+   // CHECK-NEXT: Nice to meet you ( finally, today, and in this place )
+   vector.print str "Nice to meet you " punctuation <no_punctuation>
+   vector.print str "finally" punctuation <open>
+   vector.print punctuation <comma>
+   vector.print str "today" punctuation <comma>
+   vector.print str "and in " punctuation <no_punctuation>
+   vector.print str "this place" punctuation <close>
+   vector.print punctuation <newline>
+
    // CHECK-NEXT: Bye!
    vector.print str "Bye!"
    return

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

hm that's odd, I verified no newline is printed for strings yesterday and that seems to be the case, are we seeing different behaviour? The createPrintStrCall for vector.print <str> already has addNewLine=false

if (auto stringLiteral = printOp.getStringLiteral()) {
LLVM::createPrintStrCall(rewriter, loc, parent, "vector_print_str",
*stringLiteral, *getTypeConverter(),
/*addNewline=*/false);

although for empty string it will print newline as that wont be entered.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, my branch is a bit behind. This has been fixed already: c1b8c6c

>,
]>,
Arguments<(ins Optional<Type<Or<[
AnyVectorOfAnyRank.predicate,
AnyInteger.predicate, Index.predicate, AnyFloat.predicate
]>>>:$source, DefaultValuedAttr<Vector_PrintPunctuation,
"::mlir::vector::PrintPunctuation::NewLine">:$punctuation)
"::mlir::vector::PrintPunctuation::NewLine">:$punctuation,
OptionalAttr<Builtin_StringAttr>:$stringLiteral)
> {
let summary = "print operation (for testing and debugging)";
let description = [{
Expand Down Expand Up @@ -2521,6 +2528,13 @@ def Vector_PrintOp :
```mlir
vector.print punctuation <newline>
```

Additionally, to aid with debugging and testing `vector.print` can also
print constant strings:

```mlir
vector.print str "Hello, World!"
```
}];
let extraClassDeclaration = [{
Type getPrintType() {
Expand All @@ -2529,11 +2543,26 @@ def Vector_PrintOp :
}];
let builders = [
OpBuilder<(ins "PrintPunctuation":$punctuation), [{
build($_builder, $_state, {}, punctuation);
build($_builder, $_state, {}, punctuation, {});
}]>,
OpBuilder<(ins "::mlir::Value":$source), [{
build($_builder, $_state, source, PrintPunctuation::NewLine);
}]>,
OpBuilder<(ins "::mlir::Value":$source, "PrintPunctuation":$punctuation), [{
build($_builder, $_state, source, punctuation, {});
}]>,
OpBuilder<(ins "::llvm::StringRef":$string), [{
build($_builder, $_state, {}, PrintPunctuation::NewLine, $_builder.getStringAttr(string));
}]>,
];

let assemblyFormat = "($source^ `:` type($source))? (`punctuation` $punctuation^)? attr-dict";
let assemblyFormat = [{
($source^ `:` type($source))?
oilist(
`str` $stringLiteral
| `punctuation` $punctuation)
attr-dict
}];
}

//===----------------------------------------------------------------------===//
Expand Down
1 change: 1 addition & 0 deletions mlir/include/mlir/ExecutionEngine/CRunnerUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -465,6 +465,7 @@ extern "C" MLIR_CRUNNERUTILS_EXPORT void printI64(int64_t i);
extern "C" MLIR_CRUNNERUTILS_EXPORT void printU64(uint64_t u);
extern "C" MLIR_CRUNNERUTILS_EXPORT void printF32(float f);
extern "C" MLIR_CRUNNERUTILS_EXPORT void printF64(double d);
extern "C" MLIR_CRUNNERUTILS_EXPORT void printString(char const *s);
extern "C" MLIR_CRUNNERUTILS_EXPORT void printOpen();
extern "C" MLIR_CRUNNERUTILS_EXPORT void printClose();
extern "C" MLIR_CRUNNERUTILS_EXPORT void printComma();
Expand Down
50 changes: 4 additions & 46 deletions mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
#include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Conversion/LLVMCommon/PrintCallHelper.h"
#include "mlir/Conversion/LLVMCommon/VectorPattern.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
Expand All @@ -36,51 +37,6 @@ using namespace mlir;

#define PASS_NAME "convert-cf-to-llvm"

static std::string generateGlobalMsgSymbolName(ModuleOp moduleOp) {
std::string prefix = "assert_msg_";
int counter = 0;
while (moduleOp.lookupSymbol(prefix + std::to_string(counter)))
++counter;
return prefix + std::to_string(counter);
}

/// Generate IR that prints the given string to stderr.
static void createPrintMsg(OpBuilder &builder, Location loc, ModuleOp moduleOp,
StringRef msg,
const LLVMTypeConverter &typeConverter) {
auto ip = builder.saveInsertionPoint();
builder.setInsertionPointToStart(moduleOp.getBody());
MLIRContext *ctx = builder.getContext();

// Create a zero-terminated byte representation and allocate global symbol.
SmallVector<uint8_t> elementVals;
elementVals.append(msg.begin(), msg.end());
elementVals.push_back(0);
auto dataAttrType = RankedTensorType::get(
{static_cast<int64_t>(elementVals.size())}, builder.getI8Type());
auto dataAttr =
DenseElementsAttr::get(dataAttrType, llvm::ArrayRef(elementVals));
auto arrayTy =
LLVM::LLVMArrayType::get(IntegerType::get(ctx, 8), elementVals.size());
std::string symbolName = generateGlobalMsgSymbolName(moduleOp);
auto globalOp = builder.create<LLVM::GlobalOp>(
loc, arrayTy, /*constant=*/true, LLVM::Linkage::Private, symbolName,
dataAttr);

// Emit call to `printStr` in runtime library.
builder.restoreInsertionPoint(ip);
auto msgAddr = builder.create<LLVM::AddressOfOp>(
loc, typeConverter.getPointerType(arrayTy), globalOp.getName());
SmallVector<LLVM::GEPArg> indices(1, 0);
Value gep = builder.create<LLVM::GEPOp>(
loc, typeConverter.getPointerType(builder.getI8Type()), arrayTy, msgAddr,
indices);
Operation *printer = LLVM::lookupOrCreatePrintStrFn(
moduleOp, typeConverter.useOpaquePointers());
builder.create<LLVM::CallOp>(loc, TypeRange(), SymbolRefAttr::get(printer),
gep);
}

namespace {
/// Lower `cf.assert`. The default lowering calls the `abort` function if the
/// assertion is violated and has no effect otherwise. The failure message is
Expand All @@ -105,7 +61,9 @@ struct AssertOpLowering : public ConvertOpToLLVMPattern<cf::AssertOp> {

// Failed block: Generate IR to print the message and call `abort`.
Block *failureBlock = rewriter.createBlock(opBlock->getParent());
createPrintMsg(rewriter, loc, module, op.getMsg(), *getTypeConverter());
LLVM::createPrintStrCall(rewriter, loc, module, "assert_msg", op.getMsg(),
*getTypeConverter(), /*addNewLine=*/false,
/*runtimeFunctionName=*/"puts");
if (abortOnFailedAssert) {
// Insert the `abort` declaration if necessary.
auto abortFunc = module.lookupSymbol<LLVM::LLVMFuncOp>("abort");
Expand Down
1 change: 1 addition & 0 deletions mlir/lib/Conversion/LLVMCommon/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ add_mlir_conversion_library(MLIRLLVMCommonConversion
LoweringOptions.cpp
MemRefBuilder.cpp
Pattern.cpp
PrintCallHelper.cpp
StructBuilder.cpp
TypeConverter.cpp
VectorPattern.cpp
Expand Down
66 changes: 66 additions & 0 deletions mlir/lib/Conversion/LLVMCommon/PrintCallHelper.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
//===- PrintCallHelper.cpp - Helper to emit runtime print calls -----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/LLVMCommon/PrintCallHelper.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "llvm/ADT/ArrayRef.h"

using namespace mlir;
using namespace llvm;

static std::string ensureSymbolNameIsUnique(ModuleOp moduleOp,
StringRef symbolName) {
static int counter = 0;
std::string uniqueName = std::string(symbolName);
while (moduleOp.lookupSymbol(uniqueName)) {
uniqueName = std::string(symbolName) + "_" + std::to_string(counter++);
}
return uniqueName;
}

void mlir::LLVM::createPrintStrCall(
OpBuilder &builder, Location loc, ModuleOp moduleOp, StringRef symbolName,
StringRef string, const LLVMTypeConverter &typeConverter, bool addNewline,
std::optional<StringRef> runtimeFunctionName) {
auto ip = builder.saveInsertionPoint();
builder.setInsertionPointToStart(moduleOp.getBody());
MLIRContext *ctx = builder.getContext();

// Create a zero-terminated byte representation and allocate global symbol.
SmallVector<uint8_t> elementVals;
elementVals.append(string.begin(), string.end());
if (addNewline)
elementVals.push_back('\n');
elementVals.push_back('\0');
auto dataAttrType = RankedTensorType::get(
{static_cast<int64_t>(elementVals.size())}, builder.getI8Type());
auto dataAttr =
DenseElementsAttr::get(dataAttrType, llvm::ArrayRef(elementVals));
auto arrayTy =
LLVM::LLVMArrayType::get(IntegerType::get(ctx, 8), elementVals.size());
auto globalOp = builder.create<LLVM::GlobalOp>(
loc, arrayTy, /*constant=*/true, LLVM::Linkage::Private,
ensureSymbolNameIsUnique(moduleOp, symbolName), dataAttr);

// Emit call to `printStr` in runtime library.
builder.restoreInsertionPoint(ip);
auto msgAddr = builder.create<LLVM::AddressOfOp>(
loc, typeConverter.getPointerType(arrayTy), globalOp.getName());
SmallVector<LLVM::GEPArg> indices(1, 0);
Value gep = builder.create<LLVM::GEPOp>(
loc, typeConverter.getPointerType(builder.getI8Type()), arrayTy, msgAddr,
indices);
Operation *printer = LLVM::lookupOrCreatePrintStringFn(
moduleOp, typeConverter.useOpaquePointers(), runtimeFunctionName);
builder.create<LLVM::CallOp>(loc, TypeRange(), SymbolRefAttr::get(printer),
gep);
}
6 changes: 5 additions & 1 deletion mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"

#include "mlir/Conversion/ArithCommon/AttrToLLVMConverter.h"
#include "mlir/Conversion/LLVMCommon/PrintCallHelper.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Conversion/LLVMCommon/VectorPattern.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
Expand Down Expand Up @@ -1548,7 +1549,10 @@ class VectorPrintOpConversion : public ConvertOpToLLVMPattern<vector::PrintOp> {
}

auto punct = printOp.getPunctuation();
if (punct != PrintPunctuation::NoPunctuation) {
if (auto stringLiteral = printOp.getStringLiteral()) {
LLVM::createPrintStrCall(rewriter, loc, parent, "vector_print_str",
*stringLiteral, *getTypeConverter());
} else if (punct != PrintPunctuation::NoPunctuation) {
emitCall(rewriter, printOp->getLoc(), [&] {
switch (punct) {
case PrintPunctuation::Close:
Expand Down
9 changes: 5 additions & 4 deletions mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ static constexpr llvm::StringRef kPrintF16 = "printF16";
static constexpr llvm::StringRef kPrintBF16 = "printBF16";
static constexpr llvm::StringRef kPrintF32 = "printF32";
static constexpr llvm::StringRef kPrintF64 = "printF64";
static constexpr llvm::StringRef kPrintStr = "puts";
static constexpr llvm::StringRef kPrintString = "printString";
static constexpr llvm::StringRef kPrintOpen = "printOpen";
static constexpr llvm::StringRef kPrintClose = "printClose";
static constexpr llvm::StringRef kPrintComma = "printComma";
Expand Down Expand Up @@ -107,9 +107,10 @@ static LLVM::LLVMPointerType getVoidPtr(MLIRContext *context,
return getCharPtr(context, opaquePointers);
}

LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintStrFn(ModuleOp moduleOp,
bool opaquePointers) {
return lookupOrCreateFn(moduleOp, kPrintStr,
LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintStringFn(
ModuleOp moduleOp, bool opaquePointers,
std::optional<StringRef> runtimeFunctionName) {
return lookupOrCreateFn(moduleOp, runtimeFunctionName.value_or(kPrintString),
getCharPtr(moduleOp->getContext(), opaquePointers),
LLVM::LLVMVoidType::get(moduleOp->getContext()));
}
Expand Down
1 change: 1 addition & 0 deletions mlir/lib/ExecutionEngine/CRunnerUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ extern "C" void printI64(int64_t i) { fprintf(stdout, "%" PRId64, i); }
extern "C" void printU64(uint64_t u) { fprintf(stdout, "%" PRIu64, u); }
extern "C" void printF32(float f) { fprintf(stdout, "%g", f); }
extern "C" void printF64(double d) { fprintf(stdout, "%lg", d); }
extern "C" void printString(char const *s) { fputs(s, stdout); }
extern "C" void printOpen() { fputs("( ", stdout); }
extern "C" void printClose() { fputs(" )", stdout); }
extern "C" void printComma() { fputs(", ", stdout); }
Expand Down
3 changes: 2 additions & 1 deletion mlir/lib/ExecutionEngine/RunnerUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,8 @@ extern "C" void printMemrefC64(int64_t rank, void *ptr) {
_mlir_ciface_printMemrefC64(&descriptor);
}

extern "C" void printCString(char *str) { printf("%s", str); }
/// Deprecated. This should be unified with printString from CRunnerUtils.
extern "C" void printCString(char *str) { fputs(str, stdout); }
Comment on lines -161 to +162
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[nit] I would just skip these changes altoghether.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It follows on for #68973 (which removes all uses of printCString()), so I'll follow this up after that.


extern "C" void _mlir_ciface_printMemref0dF32(StridedMemRefType<float, 0> *M) {
impl::printMemRef(*M);
Expand Down
14 changes: 14 additions & 0 deletions mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -1068,6 +1068,20 @@ func.func @vector_print_scalar_f64(%arg0: f64) {

// -----

// CHECK-LABEL: module {
// CHECK: llvm.func @printString(!llvm.ptr)
// CHECK: llvm.mlir.global private constant @[[GLOBAL_STR:.*]]({{.*}})
// CHECK: @vector_print_string
// CHECK-NEXT: %[[GLOBAL_ADDR:.*]] = llvm.mlir.addressof @[[GLOBAL_STR]] : !llvm.ptr
// CHECK-NEXT: %[[STR_PTR:.*]] = llvm.getelementptr %[[GLOBAL_ADDR]][0] : (!llvm.ptr) -> !llvm.ptr
// CHECK-NEXT: llvm.call @printString(%[[STR_PTR]]) : (!llvm.ptr) -> ()
func.func @vector_print_string() {
vector.print str "Hello, World!"
return
}

// -----

func.func @extract_strided_slice1(%arg0: vector<4xf32>) -> vector<2xf32> {
%0 = vector.extract_strided_slice %arg0 {offsets = [2], sizes = [2], strides = [1]} : vector<4xf32> to vector<2xf32>
return %0 : vector<2xf32>
Expand Down
16 changes: 16 additions & 0 deletions mlir/test/Dialect/Vector/invalid.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -1016,6 +1016,22 @@ func.func private @print_needs_vector(%arg0: tensor<8xf32>) {

// -----

func.func @cannot_print_string_with_punctuation_set() {
// expected-error@+1 {{`source` or `punctuation` are not set when printing strings}}
vector.print str "Whoops!" punctuation <comma>
return
}

// -----

func.func @cannot_print_string_with_source_set(%vec: vector<[4]xf32>) {
// expected-error@+1 {{`source` or `punctuation` are not set when printing strings}}
vector.print %vec: vector<[4]xf32> str "Yay!"
return
}

// -----

func.func @reshape_bad_input_shape(%arg0 : vector<3x2x4xf32>) {
%c2 = arith.constant 2 : index
%c3 = arith.constant 3 : index
Expand Down
Loading