Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[MLIR] reverse int8 type's printing logic #69361

Merged
merged 2 commits into from
Oct 18, 2023

Conversation

yaochengji
Copy link
Member

@yaochengji yaochengji commented Oct 17, 2023

Specializing for 8-bit integers to ensure values are printed as integers in a generic way will cause a bug, see #69310

@llvmbot llvmbot added mlir:core MLIR Core Infrastructure mlir labels Oct 17, 2023
@llvmbot
Copy link
Member

llvmbot commented Oct 17, 2023

@llvm/pr-subscribers-mlir-core

@llvm/pr-subscribers-mlir

Author: Chengji Yao (yaochengji)

Changes

Specializing for 8-bit integers to ensure values are printed as integers in a generic way will cause a bug., see #69310


Full diff: https://github.com/llvm/llvm-project/pull/69361.diff

4 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td (+2-2)
  • (modified) mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td (+1-1)
  • (modified) mlir/include/mlir/IR/OpImplementation.h (+1-13)
  • (modified) mlir/lib/Dialect/Mesh/IR/MeshOps.cpp (+8-8)
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td b/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td
index d761743a82bf86b..867c98078ae5171 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td
@@ -58,8 +58,8 @@ def MeshSharding : AttrDef<Mesh_Dialect, "MeshSharding"> {
 
   let parameters = (ins
     AttrParameter<"::mlir::SymbolRefAttr", "cluster placed">:$cluster,
-    ArrayRefParameter<"::mlir::DenseI8ArrayAttr">:$split_axes,
-    OptionalArrayRefParameter<"int8_t">:$partial_axes,
+    ArrayRefParameter<"::mlir::DenseI64ArrayAttr">:$split_axes,
+    OptionalArrayRefParameter<"int64_t">:$partial_axes,
     OptionalParameter<"::mlir::mesh::Partial">:$partial_type
   );
 
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
index 8ca4b6653104221..a8aa0a694bee29f 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
@@ -70,7 +70,7 @@ def Mesh_ClusterOp : Mesh_Op<"cluster", [Symbol]> {
   }];
   let arguments = (ins
     SymbolNameAttr:$sym_name,
-    I8Attr:$rank,
+    I64Attr:$rank,
     DefaultValuedAttr<DenseI64ArrayAttr, "{}">:$dim_sizes
   );
   let assemblyFormat = [{
diff --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h
index 379392ace46961a..f1fabf95a68b7ad 100644
--- a/mlir/include/mlir/IR/OpImplementation.h
+++ b/mlir/include/mlir/IR/OpImplementation.h
@@ -350,8 +350,7 @@ template <typename AsmPrinterT, typename T,
                                !std::is_convertible<T &, Attribute &>::value &&
                                !std::is_convertible<T &, ValueRange>::value &&
                                !std::is_convertible<T &, APFloat &>::value &&
-                               !llvm::is_one_of<T, bool, int8_t, uint8_t, float,
-                                                double>::value,
+                               !llvm::is_one_of<T, bool, float, double>::value,
                            T> * = nullptr>
 inline std::enable_if_t<std::is_base_of<AsmPrinter, AsmPrinterT>::value,
                         AsmPrinterT &>
@@ -367,17 +366,6 @@ operator<<(AsmPrinterT &p, bool value) {
   return p << (value ? StringRef("true") : "false");
 }
 
-/// Specialization for 8-bit integers to ensure values are printed as integers
-// and not characters.
-template <
-    typename AsmPrinterT, typename T,
-    std::enable_if_t<llvm::is_one_of<T, int8_t, uint8_t>::value, T> * = nullptr>
-inline std::enable_if_t<std::is_base_of<AsmPrinter, AsmPrinterT>::value,
-                        AsmPrinterT &>
-operator<<(AsmPrinterT &p, T value) {
-  return p << static_cast<int16_t>(value);
-}
-
 template <typename AsmPrinterT, typename ValueRangeT>
 inline std::enable_if_t<std::is_base_of<AsmPrinter, AsmPrinterT>::value,
                         AsmPrinterT &>
diff --git a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
index b2a47102528758c..e8dc14cf0fa9c04 100644
--- a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
+++ b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
@@ -47,7 +47,7 @@ Operation *MeshDialect::materializeConstant(OpBuilder &builder, Attribute value,
 
 LogicalResult ClusterOp::verify() {
   ArrayRef<int64_t> dimSizes = getDimSizes();
-  uint8_t rank = getRank();
+  uint64_t rank = getRank();
 
   if (rank == 0)
     return emitOpError("rank of cluster is expected to be a positive integer");
@@ -71,15 +71,15 @@ LogicalResult ClusterOp::verify() {
 
 LogicalResult
 MeshShardingAttr::verify(function_ref<InFlightDiagnostic()> emitError,
-                         SymbolRefAttr, ArrayRef<DenseI8ArrayAttr> splitAxes,
-                         ArrayRef<int8_t> partialAxes, Partial) {
+                         SymbolRefAttr, ArrayRef<DenseI64ArrayAttr> splitAxes,
+                         ArrayRef<int64_t> partialAxes, Partial) {
   // TODO: At present cluster symbol ref is not verified. This is due to the
   // difficulty in fetching the corresponding symbol op based on an attribute.
 
-  llvm::SmallSet<int8_t, 4> visitedAxes;
+  llvm::SmallSet<int64_t, 4> visitedAxes;
 
-  auto checkMeshAxis = [&](ArrayRef<int8_t> axesArray) -> LogicalResult {
-    for (int8_t axis : axesArray) {
+  auto checkMeshAxis = [&](ArrayRef<int64_t> axesArray) -> LogicalResult {
+    for (int64_t axis : axesArray) {
       if (axis < 0)
         return emitError() << "mesh axis is expected to be non-negative";
       if (!visitedAxes.insert(axis).second)
@@ -88,8 +88,8 @@ MeshShardingAttr::verify(function_ref<InFlightDiagnostic()> emitError,
     return success();
   };
 
-  for (DenseI8ArrayAttr subAxes : splitAxes) {
-    ArrayRef<int8_t> subAxesArray = subAxes.asArrayRef();
+  for (DenseI64ArrayAttr subAxes : splitAxes) {
+    ArrayRef<int64_t> subAxesArray = subAxes.asArrayRef();
     if (failed(checkMeshAxis(subAxesArray)))
       return failure();
   }

@sogartar
Copy link
Contributor

LGTM

@@ -58,8 +58,8 @@ def MeshSharding : AttrDef<Mesh_Dialect, "MeshSharding"> {

let parameters = (ins
AttrParameter<"::mlir::SymbolRefAttr", "cluster placed">:$cluster,
ArrayRefParameter<"::mlir::DenseI8ArrayAttr">:$split_axes,
OptionalArrayRefParameter<"int8_t">:$partial_axes,
ArrayRefParameter<"::mlir::DenseI64ArrayAttr">:$split_axes,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What about DenseI32ArrayAttr?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Modified, both split_axes and partial_axes to int32

@rorth
Copy link
Collaborator

rorth commented Oct 18, 2023

FWIW, I've tested this patch on both amd64-pc-solaris2.11 and sparcv9-sun-solaris2.11 and MLIR/Flang test results are back to normal. Thanks for the quick fix.

@joker-eph
Copy link
Collaborator

@yaochengji you most recent commit is authored from a @hotmial.com email, there is a typo there.
(Merging with the right email now).

@joker-eph joker-eph merged commit 1a21196 into llvm:main Oct 18, 2023
2 checks passed
@yaochengji
Copy link
Member Author

@yaochengji you most recent commit is authored from a @hotmial.com email, there is a typo there. (Merging with the right email now).

Oh, Thanks! It is now corrected on my dev machine.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
mlir:core MLIR Core Infrastructure mlir
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants