[mlir][VectorOps] Support string literals in vector.print #68695

Merged: 6 commits, Oct 24, 2023

MacDue (Member) commented Oct 10, 2023

Printing strings within integration tests is currently quite annoyingly
verbose, and can't be tucked into shared helpers as the types depend on
the length of the string:

llvm.mlir.global internal constant @hello_world("Hello, World!\0")

func.func @entry() {
  %0 = llvm.mlir.addressof @hello_world : !llvm.ptr<array<14 x i8>>
  %1 = llvm.mlir.constant(0 : index) : i64
  %2 = llvm.getelementptr %0[%1, %1]
    : (!llvm.ptr<array<14 x i8>>, i64, i64) -> !llvm.ptr<i8>
  llvm.call @printCString(%2) : (!llvm.ptr<i8>) -> ()
  return
}

So this patch adds a simple extension to vector.print to simplify
this:

func.func @entry() {
   // Print a vector of characters ;)
   vector.print str "Hello, World!"
   return
}

Most of the logic for this is now shared with cf.assert which already
does something similar.

Depends on #68694

MacDue (Member Author) commented Oct 10, 2023

cc @c-rhodes

c-rhodes (Collaborator) left a comment

thanks for the patch Ben! I think this is generally a really useful feature and will clean up some of the tests nicely

llvmbot (Member) commented Oct 13, 2023

@llvm/pr-subscribers-mlir-execution-engine
@llvm/pr-subscribers-mlir-cf
@llvm/pr-subscribers-mlir-llvm
@llvm/pr-subscribers-mlir-vector

@llvm/pr-subscribers-mlir

Author: Benjamin Maxwell (MacDue)

Full diff: https://github.com/llvm/llvm-project/pull/68695.diff

9 Files Affected:

  • (added) mlir/include/mlir/Conversion/LLVMCommon/PrintCallHelper.h (+30)
  • (modified) mlir/include/mlir/Dialect/Vector/IR/VectorOps.td (+33-4)
  • (modified) mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp (+3-46)
  • (modified) mlir/lib/Conversion/LLVMCommon/CMakeLists.txt (+1)
  • (added) mlir/lib/Conversion/LLVMCommon/PrintCallHelper.cpp (+64)
  • (modified) mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp (+5-1)
  • (modified) mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir (+14)
  • (modified) mlir/test/Dialect/Vector/invalid.mlir (+16)
  • (added) mlir/test/Integration/Dialect/Vector/CPU/test-print-str.mlir (+14)
diff --git a/mlir/include/mlir/Conversion/LLVMCommon/PrintCallHelper.h b/mlir/include/mlir/Conversion/LLVMCommon/PrintCallHelper.h
new file mode 100644
index 000000000000000..457cd98ca3dc2c8
--- /dev/null
+++ b/mlir/include/mlir/Conversion/LLVMCommon/PrintCallHelper.h
@@ -0,0 +1,30 @@
+//===- PrintCallHelper.h - Helper to emit runtime print calls ---*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_LLVMIR_PRINTCALLHELPER_H_
+#define MLIR_DIALECT_LLVMIR_PRINTCALLHELPER_H_
+
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "llvm/ADT/StringRef.h"
+
+namespace mlir {
+
+class OpBuilder;
+class LLVMTypeConverter;
+
+namespace LLVM {
+
+/// Generate IR that prints the given string to stdout.
+void createPrintStrCall(OpBuilder &builder, Location loc, ModuleOp moduleOp,
+                        StringRef symbolName, StringRef string,
+                        const LLVMTypeConverter &typeConverter);
+} // namespace LLVM
+
+} // namespace mlir
+
+#endif
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 2df2fe4c5ce8e9c..f946d124fb2fa5e 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -26,6 +26,7 @@ include "mlir/Interfaces/InferTypeOpInterface.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
 include "mlir/Interfaces/VectorInterfaces.td"
 include "mlir/Interfaces/ViewLikeInterface.td"
+include "mlir/IR/BuiltinAttributes.td"
 
 // TODO: Add an attribute to specify a different algebra with operators other
 // than the current set: {*, +}.
@@ -2476,12 +2477,18 @@ def Vector_TransposeOp :
 }
 
 def Vector_PrintOp :
-  Vector_Op<"print", []>,
+  Vector_Op<"print", [
+    PredOpTrait<
+      "`source` or `punctuation` are not set when printing strings",
+      CPred<"!getStringLiteral() || (!getSource() && getPunctuation() == PrintPunctuation::NewLine)">
+    >,
+  ]>,
   Arguments<(ins Optional<Type<Or<[
     AnyVectorOfAnyRank.predicate,
     AnyInteger.predicate, Index.predicate, AnyFloat.predicate
   ]>>>:$source, DefaultValuedAttr<Vector_PrintPunctuation,
-                      "::mlir::vector::PrintPunctuation::NewLine">:$punctuation)
+                      "::mlir::vector::PrintPunctuation::NewLine">:$punctuation,
+                OptionalAttr<Builtin_StringAttr>:$stringLiteral)
   > {
   let summary = "print operation (for testing and debugging)";
   let description = [{
@@ -2520,6 +2527,13 @@ def Vector_PrintOp :
     ```mlir
     vector.print punctuation <newline>
     ```
+
+    Additionally, to aid with debugging and testing `vector.print` can also
+    print constant strings:
+
+    ```mlir
+    vector.print str "Hello, World!"
+    ```
   }];
   let extraClassDeclaration = [{
     Type getPrintType() {
@@ -2528,11 +2542,26 @@ def Vector_PrintOp :
   }];
   let builders = [
     OpBuilder<(ins "PrintPunctuation":$punctuation), [{
-      build($_builder, $_state, {}, punctuation);
+      build($_builder, $_state, {}, punctuation, {});
+    }]>,
+    OpBuilder<(ins "::mlir::Value":$source), [{
+      build($_builder, $_state, source, PrintPunctuation::NewLine);
+    }]>,
+    OpBuilder<(ins "::mlir::Value":$source, "PrintPunctuation":$punctuation), [{
+      build($_builder, $_state, source, punctuation, {});
+    }]>,
+    OpBuilder<(ins "::llvm::StringRef":$string), [{
+      build($_builder, $_state, {}, PrintPunctuation::NewLine, $_builder.getStringAttr(string));
     }]>,
   ];
 
-  let assemblyFormat = "($source^ `:` type($source))? (`punctuation` $punctuation^)? attr-dict";
+  let assemblyFormat = [{
+      ($source^ `:` type($source))?
+        oilist(
+            `str` $stringLiteral
+          | `punctuation` $punctuation)
+        attr-dict
+    }];
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp b/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp
index a4f146bbe475cc6..6b7647b038f1d94 100644
--- a/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp
+++ b/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp
@@ -16,6 +16,7 @@
 #include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
 #include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
 #include "mlir/Conversion/LLVMCommon/Pattern.h"
+#include "mlir/Conversion/LLVMCommon/PrintCallHelper.h"
 #include "mlir/Conversion/LLVMCommon/VectorPattern.h"
 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
 #include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
@@ -36,51 +37,6 @@ using namespace mlir;
 
 #define PASS_NAME "convert-cf-to-llvm"
 
-static std::string generateGlobalMsgSymbolName(ModuleOp moduleOp) {
-  std::string prefix = "assert_msg_";
-  int counter = 0;
-  while (moduleOp.lookupSymbol(prefix + std::to_string(counter)))
-    ++counter;
-  return prefix + std::to_string(counter);
-}
-
-/// Generate IR that prints the given string to stderr.
-static void createPrintMsg(OpBuilder &builder, Location loc, ModuleOp moduleOp,
-                           StringRef msg,
-                           const LLVMTypeConverter &typeConverter) {
-  auto ip = builder.saveInsertionPoint();
-  builder.setInsertionPointToStart(moduleOp.getBody());
-  MLIRContext *ctx = builder.getContext();
-
-  // Create a zero-terminated byte representation and allocate global symbol.
-  SmallVector<uint8_t> elementVals;
-  elementVals.append(msg.begin(), msg.end());
-  elementVals.push_back(0);
-  auto dataAttrType = RankedTensorType::get(
-      {static_cast<int64_t>(elementVals.size())}, builder.getI8Type());
-  auto dataAttr =
-      DenseElementsAttr::get(dataAttrType, llvm::ArrayRef(elementVals));
-  auto arrayTy =
-      LLVM::LLVMArrayType::get(IntegerType::get(ctx, 8), elementVals.size());
-  std::string symbolName = generateGlobalMsgSymbolName(moduleOp);
-  auto globalOp = builder.create<LLVM::GlobalOp>(
-      loc, arrayTy, /*constant=*/true, LLVM::Linkage::Private, symbolName,
-      dataAttr);
-
-  // Emit call to `printStr` in runtime library.
-  builder.restoreInsertionPoint(ip);
-  auto msgAddr = builder.create<LLVM::AddressOfOp>(
-      loc, typeConverter.getPointerType(arrayTy), globalOp.getName());
-  SmallVector<LLVM::GEPArg> indices(1, 0);
-  Value gep = builder.create<LLVM::GEPOp>(
-      loc, typeConverter.getPointerType(builder.getI8Type()), arrayTy, msgAddr,
-      indices);
-  Operation *printer = LLVM::lookupOrCreatePrintStrFn(
-      moduleOp, typeConverter.useOpaquePointers());
-  builder.create<LLVM::CallOp>(loc, TypeRange(), SymbolRefAttr::get(printer),
-                               gep);
-}
-
 namespace {
 /// Lower `cf.assert`. The default lowering calls the `abort` function if the
 /// assertion is violated and has no effect otherwise. The failure message is
@@ -105,7 +61,8 @@ struct AssertOpLowering : public ConvertOpToLLVMPattern<cf::AssertOp> {
 
     // Failed block: Generate IR to print the message and call `abort`.
     Block *failureBlock = rewriter.createBlock(opBlock->getParent());
-    createPrintMsg(rewriter, loc, module, op.getMsg(), *getTypeConverter());
+    LLVM::createPrintStrCall(rewriter, loc, module, "assert_msg", op.getMsg(),
+                             *getTypeConverter());
     if (abortOnFailedAssert) {
       // Insert the `abort` declaration if necessary.
       auto abortFunc = module.lookupSymbol<LLVM::LLVMFuncOp>("abort");
diff --git a/mlir/lib/Conversion/LLVMCommon/CMakeLists.txt b/mlir/lib/Conversion/LLVMCommon/CMakeLists.txt
index 091cd539f0ae014..568d9339aaabcb4 100644
--- a/mlir/lib/Conversion/LLVMCommon/CMakeLists.txt
+++ b/mlir/lib/Conversion/LLVMCommon/CMakeLists.txt
@@ -3,6 +3,7 @@ add_mlir_conversion_library(MLIRLLVMCommonConversion
   LoweringOptions.cpp
   MemRefBuilder.cpp
   Pattern.cpp
+  PrintCallHelper.cpp
   StructBuilder.cpp
   TypeConverter.cpp
   VectorPattern.cpp
diff --git a/mlir/lib/Conversion/LLVMCommon/PrintCallHelper.cpp b/mlir/lib/Conversion/LLVMCommon/PrintCallHelper.cpp
new file mode 100644
index 000000000000000..40b9382452fbb45
--- /dev/null
+++ b/mlir/lib/Conversion/LLVMCommon/PrintCallHelper.cpp
@@ -0,0 +1,64 @@
+//===- PrintCallHelper.cpp - Helper to emit runtime print calls -----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/LLVMCommon/PrintCallHelper.h"
+#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
+#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "llvm/ADT/ArrayRef.h"
+
+using namespace mlir;
+using namespace llvm;
+
+static std::string ensureSymbolNameIsUnique(ModuleOp moduleOp,
+                                            StringRef symbolName) {
+  static int counter = 0;
+  std::string uniqueName = std::string(symbolName);
+  while (moduleOp.lookupSymbol(uniqueName)) {
+    uniqueName = std::string(symbolName) + "_" + std::to_string(counter++);
+  }
+  return uniqueName;
+}
+
+void mlir::LLVM::createPrintStrCall(OpBuilder &builder, Location loc,
+                                    ModuleOp moduleOp, StringRef symbolName,
+                                    StringRef string,
+                                    const LLVMTypeConverter &typeConverter) {
+  auto ip = builder.saveInsertionPoint();
+  builder.setInsertionPointToStart(moduleOp.getBody());
+  MLIRContext *ctx = builder.getContext();
+
+  // Create a zero-terminated byte representation and allocate global symbol.
+  SmallVector<uint8_t> elementVals;
+  elementVals.append(string.begin(), string.end());
+  elementVals.push_back(0);
+  auto dataAttrType = RankedTensorType::get(
+      {static_cast<int64_t>(elementVals.size())}, builder.getI8Type());
+  auto dataAttr =
+      DenseElementsAttr::get(dataAttrType, llvm::ArrayRef(elementVals));
+  auto arrayTy =
+      LLVM::LLVMArrayType::get(IntegerType::get(ctx, 8), elementVals.size());
+  auto globalOp = builder.create<LLVM::GlobalOp>(
+      loc, arrayTy, /*constant=*/true, LLVM::Linkage::Private,
+      ensureSymbolNameIsUnique(moduleOp, symbolName), dataAttr);
+
+  // Emit call to `printStr` in runtime library.
+  builder.restoreInsertionPoint(ip);
+  auto msgAddr = builder.create<LLVM::AddressOfOp>(
+      loc, typeConverter.getPointerType(arrayTy), globalOp.getName());
+  SmallVector<LLVM::GEPArg> indices(1, 0);
+  Value gep = builder.create<LLVM::GEPOp>(
+      loc, typeConverter.getPointerType(builder.getI8Type()), arrayTy, msgAddr,
+      indices);
+  Operation *printer = LLVM::lookupOrCreatePrintStrFn(
+      moduleOp, typeConverter.useOpaquePointers());
+  builder.create<LLVM::CallOp>(loc, TypeRange(), SymbolRefAttr::get(printer),
+                               gep);
+}
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index 8427d60f14c0bcc..4af58653c8227ae 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -9,6 +9,7 @@
 #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
 
 #include "mlir/Conversion/ArithCommon/AttrToLLVMConverter.h"
+#include "mlir/Conversion/LLVMCommon/PrintCallHelper.h"
 #include "mlir/Conversion/LLVMCommon/TypeConverter.h"
 #include "mlir/Conversion/LLVMCommon/VectorPattern.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
@@ -1548,7 +1549,10 @@ class VectorPrintOpConversion : public ConvertOpToLLVMPattern<vector::PrintOp> {
     }
 
     auto punct = printOp.getPunctuation();
-    if (punct != PrintPunctuation::NoPunctuation) {
+    if (auto stringLiteral = printOp.getStringLiteral()) {
+      LLVM::createPrintStrCall(rewriter, loc, parent, "vector_print_str",
+                               *stringLiteral, *getTypeConverter());
+    } else if (punct != PrintPunctuation::NoPunctuation) {
       emitCall(rewriter, printOp->getLoc(), [&] {
         switch (punct) {
         case PrintPunctuation::Close:
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index 9aa4d735681f576..65b3a78e295f0c4 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -1068,6 +1068,20 @@ func.func @vector_print_scalar_f64(%arg0: f64) {
 
 // -----
 
+// CHECK-LABEL: module {
+// CHECK: llvm.func @puts(!llvm.ptr)
+// CHECK: llvm.mlir.global private constant @[[GLOBAL_STR:.*]](dense<[72, 101, 108, 108, 111, 44, 32, 87, 111, 114, 108, 100, 33, 0]> : tensor<14xi8>) {addr_space = 0 : i32} : !llvm.array<14 x i8>
+// CHECK: @vector_print_string
+//       CHECK-NEXT: %[[GLOBAL_ADDR:.*]] = llvm.mlir.addressof @[[GLOBAL_STR]] : !llvm.ptr
+//       CHECK-NEXT: %[[STR_PTR:.*]] = llvm.getelementptr %[[GLOBAL_ADDR]][0] : (!llvm.ptr) -> !llvm.ptr, !llvm.array<14 x i8>
+//       CHECK-NEXT: llvm.call @puts(%[[STR_PTR]]) : (!llvm.ptr) -> ()
+func.func @vector_print_string() {
+  vector.print str "Hello, World!"
+  return
+}
+
+// -----
+
 func.func @extract_strided_slice1(%arg0: vector<4xf32>) -> vector<2xf32> {
   %0 = vector.extract_strided_slice %arg0 {offsets = [2], sizes = [2], strides = [1]} : vector<4xf32> to vector<2xf32>
   return %0 : vector<2xf32>
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index 5967a8d69bbfcc0..1664ddde7e48d76 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -1016,6 +1016,22 @@ func.func private @print_needs_vector(%arg0: tensor<8xf32>) {
 
 // -----
 
+func.func @cannot_print_string_with_punctuation_set() {
+  // expected-error@+1 {{`source` or `punctuation` are not set when printing strings}}
+  vector.print str "Whoops!" punctuation <comma>
+  return
+}
+
+// -----
+
+func.func @cannot_print_string_with_source_set(%vec: vector<[4]xf32>) {
+  // expected-error@+1 {{`source` or `punctuation` are not set when printing strings}}
+  vector.print %vec: vector<[4]xf32> str "Yay!"
+  return
+}
+
+// -----
+
 func.func @reshape_bad_input_shape(%arg0 : vector<3x2x4xf32>) {
   %c2 = arith.constant 2 : index
   %c3 = arith.constant 3 : index
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/test-print-str.mlir b/mlir/test/Integration/Dialect/Vector/CPU/test-print-str.mlir
new file mode 100644
index 000000000000000..4a11987121b3308
--- /dev/null
+++ b/mlir/test/Integration/Dialect/Vector/CPU/test-print-str.mlir
@@ -0,0 +1,14 @@
+// RUN: mlir-opt %s -test-lower-to-llvm | \
+// RUN: mlir-cpu-runner -e entry -entry-point-result=void  \
+// RUN:   -shared-libs=%mlir_c_runner_utils | \
+// RUN: FileCheck %s
+
+/// This tests printing (multiple) string literals works.
+
+func.func @entry() {
+   // CHECK: Hello, World!
+   vector.print str "Hello, World!"
+   // CHECK-NEXT: Bye!
+   vector.print str "Bye!"
+   return
+}
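
The naming and encoding logic in PrintCallHelper.cpp can be tried outside MLIR with a small stand-alone sketch. It assumes a `std::set<std::string>` as a stand-in for the module's symbol table; `SymbolTableStub`, `makeUniqueName`, `toZeroTerminatedBytes` and `registerTwice` are illustrative names, not part of the patch. Unlike the patch, the counter is passed in explicitly rather than kept in a function-local `static`.

```cpp
#include <cassert>
#include <cstdint>
#include <set>
#include <string>
#include <string_view>
#include <vector>

// Stand-in for the module's symbol table (assumption: only names matter here).
using SymbolTableStub = std::set<std::string>;

// Same shape as ensureSymbolNameIsUnique: try the bare name first, then
// append "_<counter>" until the name is unused.
std::string makeUniqueName(const SymbolTableStub &symbols,
                           std::string_view symbolName, int &counter) {
  std::string uniqueName(symbolName);
  while (symbols.count(uniqueName) != 0)
    uniqueName = std::string(symbolName) + "_" + std::to_string(counter++);
  return uniqueName;
}

// Builds the zero-terminated byte payload that would be stored in the global.
std::vector<uint8_t> toZeroTerminatedBytes(std::string_view text) {
  std::vector<uint8_t> bytes(text.begin(), text.end());
  bytes.push_back(0);
  return bytes;
}

// Usage helper: registers two globals under the same base name.
SymbolTableStub registerTwice(std::string_view base) {
  SymbolTableStub symbols;
  int counter = 0;
  symbols.insert(makeUniqueName(symbols, base, counter));
  symbols.insert(makeUniqueName(symbols, base, counter));
  assert(symbols.size() == 2 && "names must not collide");
  return symbols;
}
```

For "Hello, World!" the payload has 14 bytes (13 characters plus the terminator), which lines up with the `tensor<14xi8>` in the vector-to-llvm.mlir CHECK line.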

@MacDue MacDue requested a review from banach-space October 13, 2023 10:43
MacDue (Member Author) commented Oct 13, 2023

Split the test updates to: #68973

c-rhodes (Collaborator) left a comment

LGTM thanks Ben. Would be good to have some input from others so please allow time for that before landing, cheers.

banach-space (Contributor) left a comment

LGTM, thanks!

aartbik (Contributor) left a comment

Thanks for adding this very useful feature!

Note that calling "puts" introduces a direct dependence on the std lib, whereas the original idea of the vector print support in the CRunnerUtils library was that users only need to provide implementations of "just" a few basic print operations (printI64, printU64, etc.), usually indeed using a std lib themselves, as in

extern "C" void printNewline() { fputc('\n', stdout); }

In that sense, would it make sense to provide

extern "C" void printString(char *s) { fputs(s, stdout); }

and use that extra indirection to stay in line with the original philosophy I had for this library?

@MacDue MacDue force-pushed the vector_print_vector_of_characters branch from 24d9b25 to cfd10aa Compare October 19, 2023 08:05
MacDue (Member Author) commented Oct 19, 2023

I've switched this to use printCString() rather than puts() directly now, as that already exists in the runner library 👍

aartbik (Contributor) commented Oct 19, 2023

Apologies for nitpicking... but, printCString belongs to the group of memref prints, defined in RunnerUtils.h. The group of methods that constitute the "small set of primitives" that are easy to port to other platforms live in CRunnerUtils.h, and are grouped together here:

//===----------------------------------------------------------------------===//
// Small runtime support library for vector.print lowering during codegen.
//===----------------------------------------------------------------------===//
extern "C" MLIR_CRUNNERUTILS_EXPORT void printI64(int64_t i);
extern "C" MLIR_CRUNNERUTILS_EXPORT void printU64(uint64_t u);
extern "C" MLIR_CRUNNERUTILS_EXPORT void printF32(float f);
extern "C" MLIR_CRUNNERUTILS_EXPORT void printF64(double d);
extern "C" MLIR_CRUNNERUTILS_EXPORT void printOpen();
extern "C" MLIR_CRUNNERUTILS_EXPORT void printClose();
extern "C" MLIR_CRUNNERUTILS_EXPORT void printComma();
extern "C" MLIR_CRUNNERUTILS_EXPORT void printNewline();

It feels like we need to add a printString(char *s) to this set (and then also the implementation at the top of CRunnerUtils.cpp).

In the big picture, I realize this may seem like too much nitpicking, but it is much closer to the design philosophy of the vector.print support when I first added it (enabling the first actual running end-to-end MLIR tests ;-)
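
To make the "small set of primitives" idea concrete, here is a minimal sketch of what a port might supply. It is not the real CRunnerUtils code: the `runtime_sketch::sink` indirection and the `captureOutput` helper are assumptions added so the output can be redirected, and it writes through `std::ostream` instead of the `fputc`/`fputs` calls on `stdout` used in the examples in this thread.

```cpp
#include <cassert>
#include <iostream>
#include <sstream>
#include <string>

namespace runtime_sketch {
// Where the primitives write; defaults to standard output (assumption).
inline std::ostream *&sink() {
  static std::ostream *current = &std::cout;
  return current;
}
} // namespace runtime_sketch

// Two of the primitives a port would have to supply. Each is a thin wrapper,
// so swapping the sink is the only platform-specific work in this sketch.
extern "C" void printNewline() { *runtime_sketch::sink() << '\n'; }
extern "C" void printString(const char *s) { *runtime_sketch::sink() << s; }

// Runs `body` with the primitives redirected into a string.
std::string captureOutput(void (*body)()) {
  assert(body != nullptr && "body must be callable");
  std::ostringstream buffer;
  std::ostream *saved = runtime_sketch::sink();
  runtime_sketch::sink() = &buffer;
  body();
  runtime_sketch::sink() = saved;
  return buffer.str();
}
```

Generated code would only ever call the `extern "C"` entry points, so a platform without a full standard library could replace the bodies without touching the lowering.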

MacDue (Member Author) commented Oct 19, 2023

No worries! I've added printString to CRunnerUtils and switched to that. I've also defined an implementation in RunnerUtils since this is shared with cf.assert, and several tests there link RunnerUtils, but not CRunnerUtils, so it's easier if linking either gives you printString().

@MacDue MacDue force-pushed the vector_print_vector_of_characters branch 2 times, most recently from f0c9887 to 2c370be Compare October 19, 2023 17:56
banach-space (Contributor) commented

"I've also defined an implementation in RunnerUtils since this is shared with cf.assert"

Hi Ben, I think that we should avoid defining the same symbol both in RunnerUtils and CRunnerUtils. I know that the linker won't complain (IIRC, it will just use the first definition that it encounters). However, once the definitions in two libraries diverge, people might see different behavior depending on the order in which they list these runtime libraries. That would be undesirable.

Perhaps we should just focus on vector.print and skip cf.assert for now? Would that simplify things?

Printing strings within integration tests is currently quite annoyingly
verbose, and can't be tucked into shared helpers as the types depend on
the length of the string:

```
llvm.mlir.global internal constant @hello_world("Hello, World!\0")

func.func @entry() {
  %0 = llvm.mlir.addressof @hello_world : !llvm.ptr<array<14 x i8>>
  %1 = llvm.mlir.constant(0 : index) : i64
  %2 = llvm.getelementptr %0[%1, %1]
    : (!llvm.ptr<array<14 x i8>>, i64, i64) -> !llvm.ptr<i8>
  llvm.call @printCString(%2) : (!llvm.ptr<i8>) -> ()
  return
}
```

So this patch adds a simple extension to `vector.print` to simplify
this:
```
func.func @entry() {
   // Print a vector of characters ;)
   vector.print str "Hello, World!"
   return
}
```

Most of the logic for this is now shared with `cf.assert` which already
does something similar.
PrintCallHelper is the only use of this, so we can safely switch.
I've also defined this in RunnerUtils so linking either or both gives
you printString; this avoids the need to update a bunch of tests that
use cf.assert.
@MacDue MacDue force-pushed the vector_print_vector_of_characters branch from 2c370be to 2b8b1c9 Compare October 23, 2023 11:02
MacDue (Member Author) commented Oct 23, 2023

"Perhaps we should just focus on vector.print and skip cf.assert for now? Would that simplify things?"

We can't skip cf.assert as this is shared code, but I've now reworked things so you can specify the runtime function to use. cf.assert will now keep using puts() and vector.print will use a new printString() function (which is defined only in CRunnerUtils).
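
A rough sketch of the "specify the runtime function" idea, stripped of all MLIR types. `selectPrintFunction`, `calleeForAssert` and `calleeForVectorPrint` are hypothetical names, and falling back to `puts` when no name is given is an assumption based on the description above, not the patch's actual signature.

```cpp
#include <cassert>
#include <optional>
#include <string>
#include <string_view>

// Picks the runtime symbol a lowering should call. Assumption: callers that
// pass nothing fall back to the C library's `puts`.
std::string selectPrintFunction(
    std::optional<std::string_view> runtimeFunctionName = std::nullopt) {
  std::string callee(runtimeFunctionName.value_or("puts"));
  assert(!callee.empty() && "runtime function name must not be empty");
  return callee;
}

// Hypothetical call sites, named after the two ops discussed in the thread.
std::string calleeForAssert() { return selectPrintFunction(); }
std::string calleeForVectorPrint() {
  return selectPrintFunction("printString");
}
```

Keeping the choice at the call site means the shared helper does not need to know which dialect it is lowering.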


printCString() should be removed in a later patch (and uses updated to use vector.print str or printString() directly)

banach-space (Contributor) left a comment

Thanks Ben, I believe that this addresses Aart's concerns about not depending on stdlib and also for vector.print to be fully customizable via hooks in CRunnerUtils.cpp.

Please wait for @aartbik to confirm before landing :)

Comment on lines -161 to +162
-extern "C" void printCString(char *str) { printf("%s", str); }
+/// Deprecated. This should be unified with printString from CRunnerUtils.
+extern "C" void printCString(char *str) { fputs(str, stdout); }
Contributor:

[nit] I would just skip these changes altogether.

Member Author:

It follows on for #68973 (which removes all uses of printCString()), so I'll follow this up after that.

aartbik (Contributor) left a comment

Thank you very kindly for taking my suggestion very seriously. Looks great!

@MacDue MacDue merged commit 3be3883 into llvm:main Oct 24, 2023
2 checks passed
MacDue added a commit to MacDue/llvm-project that referenced this pull request Oct 24, 2023
This cuts down on a fair amount of boilerplate.

Depends on: llvm#68695
MacDue added a commit that referenced this pull request Oct 25, 2023
This cuts down on a fair amount of boilerplate.

Depends on: #68695
MacDue added a commit that referenced this pull request Oct 26, 2023
This patch adds a pass that ensures that loads, stores, and allocations
of SVE vector types will be legal in the LLVM backend. It does this at
the memref level, so this pass must be applied before lowering all the
way to LLVM.

This pass currently fixes two issues.

## Loading and storing predicate types

It is only legal to load/store predicate types equal to (or greater
than) a full predicate register, which in MLIR is `vector<[16]xi1>`.
Smaller predicate types (`vector<[1|2|4|8]xi1>`) must be converted
to/from a full predicate type (referred to as a `svbool`) before and
after storing and loading respectively. This pass does this by widening
allocations and inserting conversion intrinsics.

For example:


```mlir
%alloca = memref.alloca() : memref<vector<[4]xi1>>
%mask = vector.constant_mask [4] : vector<[4]xi1>
memref.store %mask, %alloca[] : memref<vector<[4]xi1>>
%reload = memref.load %alloca[] : memref<vector<[4]xi1>>
```
Becomes:
```mlir
%alloca = memref.alloca() {alignment = 1 : i64} : memref<vector<[16]xi1>>
%mask = vector.constant_mask [4] : vector<[4]xi1>
%svbool = arm_sve.convert_to_svbool %mask : vector<[4]xi1>
memref.store %svbool, %alloca[] : memref<vector<[16]xi1>>
%reload_svbool = memref.load %alloca[] : memref<vector<[16]xi1>>
%reload = arm_sve.convert_from_svbool %reload_svbool : vector<[4]xi1>
```

## Relax alignments for SVE vector allocas

The storage for SVE vector types only needs to have an alignment that
matches the element type (for example 4 byte alignment for `f32`s).
However, the LLVM backend currently defaults to aligning to `base size x
element size` bytes. For non-legal vector types like `vector<[8]xf32>`
this results in 8 x 4 = 32-byte alignment, but the backend only supports
up to 16-byte alignment for SVE vectors on the stack. Explicitly setting
a smaller alignment prevents this issue.

Depends on: #68586 and #68695 (for testing)
zahiraam pushed a commit to zahiraam/llvm-project that referenced this pull request Oct 26, 2023
```tablegen
Vector_Op<"print", [
    PredOpTrait<
      "`source` or `punctuation` are not set when printing strings",
      CPred<"!getStringLiteral() || (!getSource() && getPunctuation() == PrintPunctuation::NewLine)">
```

Contributor

Why isn't other punctuation allowed? It would be very useful to be able to print strings without the trailing newline.

Collaborator

Strings don't have a trailing newline.

Contributor

The lowering for them in LLVM::createPrintStrCall adds one unconditionally.

Contributor

I'll post a complete patch for this in a bit, but I was thinking something along the lines of:

```diff
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index b66b55ae8d57..fd5aec1982eb 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -1531,8 +1531,28 @@ public:
 
     auto punct = printOp.getPunctuation();
     if (auto stringLiteral = printOp.getStringLiteral()) {
+      std::string str;
+      llvm::raw_string_ostream punctuatedLiteral(str);
+      if (punct == PrintPunctuation::Open)
+        punctuatedLiteral << "( ";
+      punctuatedLiteral << stringLiteral->str();
+      switch (punct) {
+      case PrintPunctuation::Close:
+        punctuatedLiteral << " )";
+        break;
+      case PrintPunctuation::Comma:
+        punctuatedLiteral << ", ";
+        break;
+      case PrintPunctuation::NewLine:
+        punctuatedLiteral << '\n';
+        break;
+      case PrintPunctuation::Open:
+      case PrintPunctuation::NoPunctuation:
+        break;
+      }
       LLVM::createPrintStrCall(rewriter, loc, parent, "vector_print_str",
-                               *stringLiteral, *getTypeConverter());
+                               punctuatedLiteral.str(), *getTypeConverter(),
+                               /*addNewLine=*/false);
     } else if (punct != PrintPunctuation::NoPunctuation) {
       emitCall(rewriter, printOp->getLoc(), [&] {
         switch (punct) {
```

so that:

```diff
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/test-print-str.mlir b/mlir/test/Integration/Dialect/Vector/CPU/test-print-str.mlir
index 78d6609ccaf9..b47c5b38f783 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/test-print-str.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/test-print-str.mlir
@@ -8,6 +8,16 @@
 func.func @entry() {
    // CHECK: Hello, World!
    vector.print str "Hello, World!"
+
+   // CHECK-NEXT: Nice to meet you ( finally, today, and in this place )
+   vector.print str "Nice to meet you " punctuation <no_punctuation>
+   vector.print str "finally" punctuation <open>
+   vector.print punctuation <comma>
+   vector.print str "today" punctuation <comma>
+   vector.print str "and in " punctuation <no_punctuation>
+   vector.print str "this place" punctuation <close>
+   vector.print punctuation <newline>
+
    // CHECK-NEXT: Bye!
    vector.print str "Bye!"
    return
```
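
For readers who want to try the proposed punctuation handling outside of MLIR, here is a minimal standalone sketch. It mirrors only the string-building logic from the first diff above. The `PrintPunctuation` enum here is a local stand-in and `punctuate` is a hypothetical helper, neither is the MLIR definition, and the sketch reflects the proposal in this comment rather than merged code.

```cpp
#include <string>

// Local stand-in for the PrintPunctuation enum used in the diff; the
// enumerator order here is arbitrary and is not taken from MLIR.
enum class PrintPunctuation { NoPunctuation, NewLine, Comma, Open, Close };

// Hypothetical helper: returns the text one `vector.print str` would emit
// under the proposed patch, given its literal and punctuation.
std::string punctuate(const std::string &literal, PrintPunctuation punct) {
  std::string out;
  // An opening bracket goes before the literal.
  if (punct == PrintPunctuation::Open)
    out += "( ";
  out += literal;
  // All other punctuation goes after the literal.
  switch (punct) {
  case PrintPunctuation::Close:
    out += " )";
    break;
  case PrintPunctuation::Comma:
    out += ", ";
    break;
  case PrintPunctuation::NewLine:
    out += '\n';
    break;
  case PrintPunctuation::Open:
  case PrintPunctuation::NoPunctuation:
    break;
  }
  return out;
}
```

Concatenating the results for the literals in the test diff, and assuming the string-less `vector.print punctuation <comma>` and `<newline>` ops contribute `", "` and `"\n"` respectively, yields the line the `CHECK-NEXT` expects: `Nice to meet you ( finally, today, and in this place )`.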

Collaborator

Hm, that's odd. I verified yesterday that no newline is printed for strings, and that still seems to be the case. Are we seeing different behaviour? The createPrintStrCall for vector.print <str> already has addNewLine=false:

```cpp
if (auto stringLiteral = printOp.getStringLiteral()) {
  LLVM::createPrintStrCall(rewriter, loc, parent, "vector_print_str",
                           *stringLiteral, *getTypeConverter(),
                           /*addNewline=*/false);
```

Although for an empty string it will print a newline, as that branch won't be entered.

Contributor

Ah, my branch is a bit behind. This has been fixed already: c1b8c6c
