[mlir][sparse] Add more error messages and avoid crashing in new parser #67034

Merged — 2 commits merged into llvm:main from grammar_check on Sep 22, 2023

Conversation

@yinying-lisa-li (Contributor):

Updates:

  1. Added more invalid encodings to test the robustness of the new syntax.
  2. Replaced the asserts that crashed the parser with checks that return booleans.
  3. Clarified several error messages, and handled the failure case where a quoted string is given in place of a keyword for level formats and properties (see the example below).
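
For illustration, one of the new test cases from invalid_encoding.mlir in this PR: the quoted level format used to trip an assert, and now yields a diagnostic instead.

// expected-error@+1 {{expected valid keyword, such as compressed without quotes}}
#a = #sparse_tensor.encoding<{map = (d0) -> (d0 : "wrong")}>
func.func private @tensor_dimtolvl_mismatch(%arg0: tensor<8xi32, #a>) -> ()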

@llvmbot added the mlir:sparse (Sparse compiler in MLIR) and mlir labels on Sep 21, 2023
@llvmbot (Member) commented on Sep 21, 2023:

@llvm/pr-subscribers-mlir-sparse

@llvm/pr-subscribers-mlir


Full diff: https://github.com/llvm/llvm-project/pull/67034.diff

4 Files Affected:

  • (modified) mlir/lib/Dialect/SparseTensor/IR/Detail/LvlTypeParser.cpp (+9-6)
  • (modified) mlir/lib/Dialect/SparseTensor/IR/Detail/Var.cpp (+19-31)
  • (modified) mlir/lib/Dialect/SparseTensor/IR/Detail/Var.h (+2-2)
  • (modified) mlir/test/Dialect/SparseTensor/invalid_encoding.mlir (+158-5)
diff --git a/mlir/lib/Dialect/SparseTensor/IR/Detail/LvlTypeParser.cpp b/mlir/lib/Dialect/SparseTensor/IR/Detail/LvlTypeParser.cpp
index b8483f5db130dcf..020e0640d988cfc 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/Detail/LvlTypeParser.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/Detail/LvlTypeParser.cpp
@@ -49,9 +49,10 @@ using namespace mlir::sparse_tensor::ir_detail;
 
 FailureOr<uint8_t> LvlTypeParser::parseLvlType(AsmParser &parser) const {
   StringRef base;
-  FAILURE_IF_FAILED(parser.parseOptionalKeyword(&base));
-  uint8_t properties = 0;
   const auto loc = parser.getCurrentLocation();
+  ERROR_IF(failed(parser.parseOptionalKeyword(&base)),
+           "expected valid keyword, such as compressed without quotes")
+  uint8_t properties = 0;
 
   ParseResult res = parser.parseCommaSeparatedList(
       mlir::OpAsmParser::Delimiter::OptionalParen,
@@ -73,19 +74,21 @@ FailureOr<uint8_t> LvlTypeParser::parseLvlType(AsmParser &parser) const {
   } else if (base.compare("singleton") == 0) {
     properties |= static_cast<uint8_t>(LevelFormat::Singleton);
   } else {
-    parser.emitError(loc, "unknown level format");
+    parser.emitError(loc, "unknown level format: ") << base;
     return failure();
   }
 
   ERROR_IF(!isValidDLT(static_cast<DimLevelType>(properties)),
-           "invalid level type");
+           "invalid level type: level format doesn't support the properties");
   return properties;
 }
 
 ParseResult LvlTypeParser::parseProperty(AsmParser &parser,
                                          uint8_t *properties) const {
   StringRef strVal;
-  FAILURE_IF_FAILED(parser.parseOptionalKeyword(&strVal));
+  auto loc = parser.getCurrentLocation();
+  ERROR_IF(failed(parser.parseOptionalKeyword(&strVal)),
+           "expected valid keyword, such as nonordered without quotes.")
   if (strVal.compare("nonunique") == 0) {
     *properties |= static_cast<uint8_t>(LevelNondefaultProperty::Nonunique);
   } else if (strVal.compare("nonordered") == 0) {
@@ -95,7 +98,7 @@ ParseResult LvlTypeParser::parseProperty(AsmParser &parser,
   } else if (strVal.compare("block2_4") == 0) {
     *properties |= static_cast<uint8_t>(LevelNondefaultProperty::Block2_4);
   } else {
-    parser.emitError(parser.getCurrentLocation(), "unknown level property");
+    parser.emitError(loc, "unknown level property: ") << strVal;
     return failure();
   }
   return success();
diff --git a/mlir/lib/Dialect/SparseTensor/IR/Detail/Var.cpp b/mlir/lib/Dialect/SparseTensor/IR/Detail/Var.cpp
index 3b00e17657f1f97..44eba668021ba79 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/Detail/Var.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/Detail/Var.cpp
@@ -196,26 +196,17 @@ minSMLoc(AsmParser &parser, llvm::SMLoc sm1, llvm::SMLoc sm2) {
   return pair1 <= pair2 ? sm1 : sm2;
 }
 
-LLVM_ATTRIBUTE_UNUSED static void
-assertInternalConsistency(VarEnv const &env, VarInfo::ID id, StringRef name) {
-#ifndef NDEBUG
+bool isInternalConsistent(VarEnv const &env, VarInfo::ID id, StringRef name) {
   const auto &var = env.access(id);
-  assert(var.getName() == name && "found inconsistent name");
-  assert(var.getID() == id && "found inconsistent VarInfo::ID");
-#endif // NDEBUG
+  return (var.getName() == name && var.getID() == id);
 }
 
 // NOTE(wrengr): if we can actually obtain an `AsmParser` for `minSMLoc`
 // (or find some other way to convert SMLoc to FileLineColLoc), then this
 // would no longer be `const VarEnv` (and couldn't be a free-function either).
-LLVM_ATTRIBUTE_UNUSED static void assertUsageConsistency(VarEnv const &env,
-                                                         VarInfo::ID id,
-                                                         llvm::SMLoc loc,
-                                                         VarKind vk) {
-#ifndef NDEBUG
+bool isUsageConsistent(VarEnv const &env, VarInfo::ID id, llvm::SMLoc loc,
+                       VarKind vk) {
   const auto &var = env.access(id);
-  assert(var.getKind() == vk &&
-         "a variable of that name already exists with a different VarKind");
   // Since the same variable can occur at several locations,
   // it would not be appropriate to do `assert(var.getLoc() == loc)`.
   /* TODO(wrengr):
@@ -223,7 +214,7 @@ LLVM_ATTRIBUTE_UNUSED static void assertUsageConsistency(VarEnv const &env,
   assert(minLoc && "Location mismatch/incompatibility");
   var.loc = minLoc;
   // */
-#endif // NDEBUG
+  return var.getKind() == vk;
 }
 
 std::optional<VarInfo::ID> VarEnv::lookup(StringRef name) const {
@@ -236,24 +227,23 @@ std::optional<VarInfo::ID> VarEnv::lookup(StringRef name) const {
   if (iter == ids.end())
     return std::nullopt;
   const auto id = iter->second;
-#ifndef NDEBUG
-  assertInternalConsistency(*this, id, name);
-#endif // NDEBUG
+  if (!isInternalConsistent(*this, id, name))
+    return std::nullopt;
   return id;
 }
 
-std::pair<VarInfo::ID, bool> VarEnv::create(StringRef name, llvm::SMLoc loc,
-                                            VarKind vk, bool verifyUsage) {
+std::optional<std::pair<VarInfo::ID, bool>>
+VarEnv::create(StringRef name, llvm::SMLoc loc, VarKind vk, bool verifyUsage) {
   const auto &[iter, didInsert] = ids.try_emplace(name, nextID());
   const auto id = iter->second;
   if (didInsert) {
     vars.emplace_back(id, name, loc, vk);
   } else {
-#ifndef NDEBUG
-    assertInternalConsistency(*this, id, name);
-    if (verifyUsage)
-      assertUsageConsistency(*this, id, loc, vk);
-#endif // NDEBUG
+    if (!isInternalConsistent(*this, id, name))
+      return std::nullopt;
+    if (verifyUsage)
+      if (!isUsageConsistent(*this, id, loc, vk))
+        return std::nullopt;
   }
   return std::make_pair(id, didInsert);
 }
@@ -265,20 +255,18 @@ VarEnv::lookupOrCreate(Policy creationPolicy, StringRef name, llvm::SMLoc loc,
   case Policy::MustNot: {
     const auto oid = lookup(name);
     if (!oid)
-      return std::nullopt; // Doesn't exist, but must not create.
-#ifndef NDEBUG
-    assertUsageConsistency(*this, *oid, loc, vk);
-#endif // NDEBUG
+      return std::nullopt;  // Doesn't exist, but must not create.
+    if (!isUsageConsistent(*this, *oid, loc, vk))
+      return std::nullopt;
     return std::make_pair(*oid, false);
   }
   case Policy::May:
     return create(name, loc, vk, /*verifyUsage=*/true);
   case Policy::Must: {
     const auto res = create(name, loc, vk, /*verifyUsage=*/false);
-    // const auto id = res.first;
-    const auto didCreate = res.second;
+    const auto didCreate = res->second;
     if (!didCreate)
-      return std::nullopt; // Already exists, but must create.
+      return std::nullopt;  // Already exists, but must create.
     return res;
   }
   }
diff --git a/mlir/lib/Dialect/SparseTensor/IR/Detail/Var.h b/mlir/lib/Dialect/SparseTensor/IR/Detail/Var.h
index a488b3ea2d56ba4..145586a83a2528c 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/Detail/Var.h
+++ b/mlir/lib/Dialect/SparseTensor/IR/Detail/Var.h
@@ -453,8 +453,8 @@ class VarEnv final {
   /// for the variable with the given name (i.e., either the newly created
   /// variable, or the pre-existing variable), and a bool indicating whether
   /// a new variable was created.
-  std::pair<VarInfo::ID, bool> create(StringRef name, llvm::SMLoc loc,
-                                      VarKind vk, bool verifyUsage = false);
+  std::optional<std::pair<VarInfo::ID, bool>>
+  create(StringRef name, llvm::SMLoc loc, VarKind vk, bool verifyUsage = false);
 
   /// Attempts to lookup or create a variable according to the given
   /// `Policy`.  Returns nullopt in one of two circumstances:
diff --git a/mlir/test/Dialect/SparseTensor/invalid_encoding.mlir b/mlir/test/Dialect/SparseTensor/invalid_encoding.mlir
index 42eb4e0a46182e7..883ba9cc81fd8f0 100644
--- a/mlir/test/Dialect/SparseTensor/invalid_encoding.mlir
+++ b/mlir/test/Dialect/SparseTensor/invalid_encoding.mlir
@@ -1,7 +1,49 @@
 // RUN: mlir-opt %s -split-input-file -verify-diagnostics
 
-// expected-error@+1 {{expected a non-empty array for lvlTypes}}
-#a = #sparse_tensor.encoding<{lvlTypes = []}>
+// expected-error@+1 {{expected '(' in dimension-specifier list}}
+#a = #sparse_tensor.encoding<{map = []}>
+func.func private @scalar(%arg0: tensor<f64, #a>) -> ()
+
+// -----
+
+// expected-error@+1 {{expected '->'}}
+#a = #sparse_tensor.encoding<{map = ()}>
+func.func private @scalar(%arg0: tensor<f64, #a>) -> ()
+
+// -----
+
+// expected-error@+1 {{expected ')' in dimension-specifier list}}
+#a = #sparse_tensor.encoding<{map = (d0 -> d0)}>
+func.func private @scalar(%arg0: tensor<f64, #a>) -> ()
+
+// -----
+
+// expected-error@+1 {{expected '(' in dimension-specifier list}}
+#a = #sparse_tensor.encoding<{map = d0 -> d0}>
+func.func private @scalar(%arg0: tensor<f64, #a>) -> ()
+
+// -----
+
+// expected-error@+1 {{expected '(' in level-specifier list}}
+#a = #sparse_tensor.encoding<{map = (d0) -> d0}>
+func.func private @scalar(%arg0: tensor<f64, #a>) -> ()
+
+// -----
+
+// expected-error@+1 {{expected ':'}}
+#a = #sparse_tensor.encoding<{map = (d0) -> (d0)}>
+func.func private @scalar(%arg0: tensor<f64, #a>) -> ()
+
+// -----
+
+// expected-error@+1 {{expected valid keyword}}
+#a = #sparse_tensor.encoding<{map = (d0) -> (d0:)}>
+func.func private @scalar(%arg0: tensor<f64, #a>) -> ()
+
+// -----
+
+// expected-error@+1 {{expected valid keyword}}
+#a = #sparse_tensor.encoding<{map = (d0) -> (d0 : (compressed))}>
 func.func private @scalar(%arg0: tensor<f64, #a>) -> ()
 
 // -----
@@ -18,17 +60,61 @@ func.func private @tensor_sizes_mismatch(%arg0: tensor<8xi32, #a>) -> ()
 
 // -----
 
-#a = #sparse_tensor.encoding<{lvlTypes = [1]}> // expected-error {{expected a string value in lvlTypes}}
+// expected-error@+1 {{unexpected dimToLvl mapping from 2 to 1}}
+#a = #sparse_tensor.encoding<{map = (d0, d1) -> (d0 : dense)}>
+func.func private @tensor_sizes_mismatch(%arg0: tensor<8xi32, #a>) -> ()
+
+// -----
+
+// expected-error@+1 {{expected bare identifier}}
+#a = #sparse_tensor.encoding<{map = (1)}>
+func.func private @tensor_type_mismatch(%arg0: tensor<8xi32, #a>) -> ()
+
+// -----
+
+// expected-error@+1 {{unexpected key: nap}}
+#a = #sparse_tensor.encoding<{nap = (d0) -> (d0 : dense)}>
+func.func private @tensor_type_mismatch(%arg0: tensor<8xi32, #a>) -> ()
+
+// -----
+
+// expected-error@+1 {{expected '(' in dimension-specifier list}}
+#a = #sparse_tensor.encoding<{map =  -> (d0 : dense)}>
 func.func private @tensor_type_mismatch(%arg0: tensor<8xi32, #a>) -> ()
 
 // -----
 
-#a = #sparse_tensor.encoding<{lvlTypes = ["strange"]}> // expected-error {{unexpected level-type: strange}}
+// expected-error@+1 {{unknown level format: strange}}
+#a = #sparse_tensor.encoding<{map = (d0) -> (d0 : strange)}>
 func.func private @tensor_value_mismatch(%arg0: tensor<8xi32, #a>) -> ()
 
 // -----
 
-#a = #sparse_tensor.encoding<{dimToLvl = "wrong"}> // expected-error {{expected an affine map for dimToLvl}}
+// expected-error@+1 {{expected valid keyword, such as compressed without quotes}}
+#a = #sparse_tensor.encoding<{map = (d0) -> (d0 : "wrong")}>
+func.func private @tensor_dimtolvl_mismatch(%arg0: tensor<8xi32, #a>) -> ()
+
+// -----
+
+// expected-error@+1 {{expected valid keyword, such as nonordered without quotes}}
+#a = #sparse_tensor.encoding<{map = (d0) -> (d0 : compressed("wrong"))}>
+func.func private @tensor_dimtolvl_mismatch(%arg0: tensor<8xi32, #a>) -> ()
+
+// -----
+// expected-error@+1 {{expected ')' in level-specifier list}}
+#a = #sparse_tensor.encoding<{map = (d0) -> (d0 : compressed[high])}>
+func.func private @tensor_dimtolvl_mismatch(%arg0: tensor<8xi32, #a>) -> ()
+
+// -----
+
+// expected-error@+1 {{unknown level property: wrong}}
+#a = #sparse_tensor.encoding<{map = (d0) -> (d0 : compressed(wrong))}>
+func.func private @tensor_dimtolvl_mismatch(%arg0: tensor<8xi32, #a>) -> ()
+
+// -----
+
+// expected-error@+1 {{use of undeclared identifier}}
+#a = #sparse_tensor.encoding<{map = (d0) -> (d0 : compressed, dense)}>
 func.func private @tensor_dimtolvl_mismatch(%arg0: tensor<8xi32, #a>) -> ()
 
 // -----
@@ -39,6 +125,73 @@ func.func private @tensor_no_permutation(%arg0: tensor<16x32xf32, #a>) -> ()
 
 // -----
 
+// expected-error@+1 {{unexpected character}}
+#a = #sparse_tensor.encoding<{map = (d0, d1) -> (d0 : compressed; d1 : dense)}>
+func.func private @tensor_dimtolvl_mismatch(%arg0: tensor<16x32xi32, #a>) -> ()
+
+// -----
+
+// expected-error@+1 {{expected attribute value}}
+#a = #sparse_tensor.encoding<{map = (d0: d1) -> (d0 : compressed, d1 : dense)}>
+func.func private @tensor_dimtolvl_mismatch(%arg0: tensor<16x32xi32, #a>) -> ()
+
+// -----
+
+// expected-error@+1 {{expected ':'}}
+#a = #sparse_tensor.encoding<{map = (d0, d1) -> (d0 = compressed, d1 = dense)}>
+func.func private @tensor_dimtolvl_mismatch(%arg0: tensor<16x32xi32, #a>) -> ()
+
+// -----
+
+// expected-error@+1 {{expected attribute value}}
+#a = #sparse_tensor.encoding<{map = (d0 : compressed, d1 : compressed)}>
+func.func private @tensor_dimtolvl_mismatch(%arg0: tensor<16x32xi32, #a>) -> ()
+
+// -----
+
+// expected-error@+1 {{use of undeclared identifier}}
+#a = #sparse_tensor.encoding<{map = (d0 = compressed, d1 = compressed)}>
+func.func private @tensor_dimtolvl_mismatch(%arg0: tensor<16x32xi32, #a>) -> ()
+
+// -----
+
+// expected-error@+1 {{use of undeclared identifier}}
+#a = #sparse_tensor.encoding<{map = (d0 = l0, d1 = l1) {l0, l1} -> (l0 = d0 : dense, l1 = d1 : compressed)}>
+func.func private @tensor_dimtolvl_mismatch(%arg0: tensor<16x32xi32, #a>) -> ()
+
+// -----
+
+// expected-error@+1 {{expected '='}}
+#a = #sparse_tensor.encoding<{map = {l0, l1} (d0 = l0, d1 = l1) -> (l0 : d0 = dense, l1 : d1 = compressed)}>
+func.func private @tensor_dimtolvl_mismatch(%arg0: tensor<16x32xi32, #a>) -> ()
+
+// -----
+// expected-error@+1 {{use of undeclared identifier 'd0'}}
+#a = #sparse_tensor.encoding<{map = {l0, l1} (d0 = l0, d1 = l1) -> (d0 : l0 = dense, d1 : l1 = compressed)}>
+func.func private @tensor_dimtolvl_mismatch(%arg0: tensor<16x32xi32, #a>) -> ()
+
+// -----
+// expected-error@+1 {{use of undeclared identifier 'd0'}}
+#a = #sparse_tensor.encoding<{map = {l0, l1} (d0 = l0, d1 = l1) -> (d0 : dense, d1 : compressed)}>
+func.func private @tensor_dimtolvl_mismatch(%arg0: tensor<16x32xi32, #a>) -> ()
+
+// -----
+// expected-error@+1 {{expected '='}}
+#a = #sparse_tensor.encoding<{map = {l0, l1} (d0 = l0, d1 = l1) -> (l0 : dense, l1 : compressed)}>
+func.func private @tensor_dimtolvl_mismatch(%arg0: tensor<16x32xi32, #a>) -> ()
+
+// -----
+// expected-error@+1 {{use of undeclared identifier}}
+#a = #sparse_tensor.encoding<{map = {l0, l1} (d0 = l0, d1 = l1) -> (l0 = dense, l1 = compressed)}>
+func.func private @tensor_dimtolvl_mismatch(%arg0: tensor<16x32xi32, #a>) -> ()
+
+// -----
+// expected-error@+1 {{use of undeclared identifier 'd0'}}
+#a = #sparse_tensor.encoding<{map = {l0, l1} (d0 = l0, d1 = l1) -> (d0 = l0 : dense, d1 = l1 : compressed)}>
+func.func private @tensor_dimtolvl_mismatch(%arg0: tensor<16x32xi32, #a>) -> ()
+
+// -----
+
 #a = #sparse_tensor.encoding<{posWidth = "x"}> // expected-error {{expected an integral position bitwidth}}
 func.func private @tensor_no_int_ptr(%arg0: tensor<16x32xf32, #a>) -> ()
 

@aartbik (Contributor) left a comment:
I love all the new tests!


// -----

// expected-error@+1 {{expected valid keyword}}
@aartbik (Contributor), on the line above:
"valid keyword' still seems a bit abstract; can we say level format?

@yinying-lisa-li (Contributor, Author):

Sounds good! Done.

@yinying-lisa-li merged commit 8466eb7 into llvm:main on Sep 22, 2023
@yinying-lisa-li deleted the grammar_check branch on Sep 22, 2023 at 16:03