Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[MLIR][DLTI] Pretty parsing and printing for DLTI attrs #113365

Merged
merged 5 commits into from
Oct 31, 2024

Conversation

rolfmorel
Copy link
Contributor

Unifies parsing and printing for DLTI attributes. Introduces a format of #dlti.attr<key1 = val1, ..., keyN = valN> syntax for all queryable DLTI attributes similar to that of the DictionaryAttr, while retaining support for specifying key-value pairs with #dlti.dl_entry (whether to retain this is TBD).

As the new format does away with all the boilerplate, it is much easier to parse for humans. This makes an especially big difference for nested attributes.

Updates the DLTI-using tests and includes fixes for misc error checking/ error messages.

@llvmbot
Copy link
Member

llvmbot commented Oct 22, 2024

@llvm/pr-subscribers-mlir-llvm

@llvm/pr-subscribers-mlir-gpu

Author: Rolf Morel (rolfmorel)

Changes

Unifies parsing and printing for DLTI attributes. Introduces a format of #dlti.attr&lt;key1 = val1, ..., keyN = valN&gt; syntax for all queryable DLTI attributes similar to that of the DictionaryAttr, while retaining support for specifying key-value pairs with #dlti.dl_entry (whether to retain this is TBD).

As the new format does away with all the boilerplate, it is much easier to parse for humans. This makes an especially big difference for nested attributes.

Updates the DLTI-using tests and includes fixes for misc error checking/ error messages.


Patch is 73.01 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/113365.diff

12 Files Affected:

  • (modified) mlir/include/mlir/Dialect/DLTI/DLTIAttrs.td (+28-22)
  • (modified) mlir/include/mlir/Interfaces/DataLayoutInterfaces.h (+2-3)
  • (modified) mlir/include/mlir/Interfaces/DataLayoutInterfaces.td (+1-1)
  • (modified) mlir/lib/Dialect/DLTI/DLTI.cpp (+191-129)
  • (modified) mlir/lib/Interfaces/DataLayoutInterfaces.cpp (+11-2)
  • (modified) mlir/test/Dialect/DLTI/invalid.mlir (+14-23)
  • (modified) mlir/test/Dialect/DLTI/query.mlir (+51-52)
  • (modified) mlir/test/Dialect/DLTI/roundtrip.mlir (+29-21)
  • (modified) mlir/test/Dialect/DLTI/valid.mlir (+124-93)
  • (modified) mlir/test/Dialect/GPU/outlining.mlir (+13-13)
  • (modified) mlir/test/Target/LLVMIR/Import/data-layout.ll (+24-24)
  • (modified) mlir/unittests/Interfaces/DataLayoutInterfacesTest.cpp (+22-16)
diff --git a/mlir/include/mlir/Dialect/DLTI/DLTIAttrs.td b/mlir/include/mlir/Dialect/DLTI/DLTIAttrs.td
index 53d38407608bed..1caf5fd8787c7b 100644
--- a/mlir/include/mlir/Dialect/DLTI/DLTIAttrs.td
+++ b/mlir/include/mlir/Dialect/DLTI/DLTIAttrs.td
@@ -93,6 +93,10 @@ def DLTI_DataLayoutSpecAttr :
   }];
 }
 
+//===----------------------------------------------------------------------===//
+// MapAttr
+//===----------------------------------------------------------------------===//
+
 def DLTI_MapAttr : DLTIAttr<"Map", [DLTIQueryInterface]> {
   let summary = "A mapping of DLTI-information by way of key-value pairs";
   let description = [{
@@ -106,18 +110,16 @@ def DLTI_MapAttr : DLTIAttr<"Map", [DLTIQueryInterface]> {
 
     Consider the following flat encoding of a single-key dictionary
     ```
-    #dlti.map<#dlti.dl_entry<"CPU::cache::L1::size_in_bytes", 65536 : i32>>
+    #dlti.map<"CPU::cache::L1::size_in_bytes" = 65536 : i32>>
     ```
     versus nested maps, which make it possible to obtain sub-dictionaries of
     related information (with the following example making use of other
     attributes that also implement the `DLTIQueryInterface`):
     ```
-    #dlti.target_system_spec<"CPU":
-      #dlti.target_device_spec<#dlti.dl_entry<"cache",
-        #dlti.map<#dlti.dl_entry<"L1",
-          #dlti.map<#dlti.dl_entry<"size_in_bytes", 65536 : i32>>>,
-                  #dlti.dl_entry<"L1d",
-          #dlti.map<#dlti.dl_entry<"size_in_bytes", 32768 : i32>>> >>>>
+    #dlti.target_system_spec<"CPU" =
+      #dlti.target_device_spec<"cache" =
+        #dlti.map<"L1" = #dlti.map<"size_in_bytes" = 65536 : i32>,
+                  "L1d" = #dlti.map<"size_in_bytes" = 32768 : i32> >>>
     ```
 
     With the flat encoding, the implied structure of the key is ignored, that is
@@ -139,7 +141,7 @@ def DLTI_MapAttr : DLTIAttr<"Map", [DLTIQueryInterface]> {
   );
   let mnemonic = "map";
   let genVerifyDecl = 1;
-  let assemblyFormat = "`<` $entries `>`";
+  let hasCustomAssemblyFormat = 1;
   let extraClassDeclaration = [{
     /// Returns the attribute associated with the key.
     FailureOr<Attribute> query(DataLayoutEntryKey key) {
@@ -167,20 +169,23 @@ def DLTI_TargetSystemSpecAttr :
     ```
     dlti.target_system_spec =
      #dlti.target_system_spec<
-      "CPU": #dlti.target_device_spec<
-              #dlti.dl_entry<"dlti.L1_cache_size_in_bytes", 4096: ui32>>,
-      "GPU": #dlti.target_device_spec<
-              #dlti.dl_entry<"dlti.max_vector_op_width", 64 : ui32>>,
-      "XPU": #dlti.target_device_spec<
-              #dlti.dl_entry<"dlti.max_vector_op_width", 4096 : ui32>>>
+      "CPU" = #dlti.target_device_spec<
+        "L1_cache_size_in_bytes" = 4096: ui32>,
+      "GPU" = #dlti.target_device_spec<
+        "max_vector_op_width" = 64 : ui32>,
+      "XPU" = #dlti.target_device_spec<
+        "max_vector_op_width" = 4096 : ui32>>
     ```
+
+    The verifier checks that keys are strings and pointed to values implement
+    DLTI's TargetDeviceSpecInterface.
   }];
   let parameters = (ins
-    ArrayRefParameter<"DeviceIDTargetDeviceSpecPair", "">:$entries
+    ArrayRefParameter<"DataLayoutEntryInterface">:$entries
   );
   let mnemonic = "target_system_spec";
   let genVerifyDecl = 1;
-  let assemblyFormat = "`<` $entries `>`";
+  let hasCustomAssemblyFormat = 1;
   let extraClassDeclaration = [{
     /// Return the device specification that matches the given device ID
     std::optional<TargetDeviceSpecInterface>
@@ -197,8 +202,10 @@ def DLTI_TargetSystemSpecAttr :
     $cppClass::getDeviceSpecForDeviceID(
         TargetSystemSpecInterface::DeviceID deviceID) {
       for (const auto& entry : getEntries()) {
-        if (entry.first == deviceID)
-          return entry.second;
+        if (entry.getKey() == DataLayoutEntryKey(deviceID))
+          if (auto deviceSpec =
+              llvm::dyn_cast<TargetDeviceSpecInterface>(entry.getValue()))
+            return deviceSpec;
       }
       return std::nullopt;
     }
@@ -219,16 +226,15 @@ def DLTI_TargetDeviceSpecAttr :
 
     Example:
     ```
-    #dlti.target_device_spec<
-      #dlti.dl_entry<"dlti.max_vector_op_width", 64 : ui32>>
+    #dlti.target_device_spec<"max_vector_op_width" = 64 : ui32>
     ```
   }];
   let parameters = (ins
-    ArrayRefParameter<"DataLayoutEntryInterface", "">:$entries
+    ArrayRefParameter<"DataLayoutEntryInterface">:$entries
   );
   let mnemonic = "target_device_spec";
   let genVerifyDecl = 1;
-  let assemblyFormat = "`<` $entries `>`";
+  let hasCustomAssemblyFormat = 1;
 
   let extraClassDeclaration = [{
     /// Returns the attribute associated with the key.
diff --git a/mlir/include/mlir/Interfaces/DataLayoutInterfaces.h b/mlir/include/mlir/Interfaces/DataLayoutInterfaces.h
index 848d2dee4a6309..e3fdf85b15ea59 100644
--- a/mlir/include/mlir/Interfaces/DataLayoutInterfaces.h
+++ b/mlir/include/mlir/Interfaces/DataLayoutInterfaces.h
@@ -15,6 +15,7 @@
 #ifndef MLIR_INTERFACES_DATALAYOUTINTERFACES_H
 #define MLIR_INTERFACES_DATALAYOUTINTERFACES_H
 
+#include "mlir/IR/Attributes.h"
 #include "mlir/IR/DialectInterface.h"
 #include "mlir/IR/OpDefinition.h"
 #include "llvm/ADT/DenseMap.h"
@@ -32,10 +33,8 @@ using DataLayoutEntryKey = llvm::PointerUnion<Type, StringAttr>;
 using DataLayoutEntryList = llvm::SmallVector<DataLayoutEntryInterface, 4>;
 using DataLayoutEntryListRef = llvm::ArrayRef<DataLayoutEntryInterface>;
 using TargetDeviceSpecListRef = llvm::ArrayRef<TargetDeviceSpecInterface>;
-using DeviceIDTargetDeviceSpecPair =
+using TargetDeviceSpecEntry =
     std::pair<StringAttr, TargetDeviceSpecInterface>;
-using DeviceIDTargetDeviceSpecPairListRef =
-    llvm::ArrayRef<DeviceIDTargetDeviceSpecPair>;
 class DataLayoutOpInterface;
 class DataLayoutSpecInterface;
 class ModuleOp;
diff --git a/mlir/include/mlir/Interfaces/DataLayoutInterfaces.td b/mlir/include/mlir/Interfaces/DataLayoutInterfaces.td
index d6e955be4291a3..061dee2399d9ad 100644
--- a/mlir/include/mlir/Interfaces/DataLayoutInterfaces.td
+++ b/mlir/include/mlir/Interfaces/DataLayoutInterfaces.td
@@ -304,7 +304,7 @@ def TargetSystemSpecInterface : AttrInterface<"TargetSystemSpecInterface", [DLTI
   let methods = [
     InterfaceMethod<
       /*description=*/"Returns the list of layout entries.",
-      /*retTy=*/"llvm::ArrayRef<DeviceIDTargetDeviceSpecPair>",
+      /*retTy=*/"llvm::ArrayRef<DataLayoutEntryInterface>",
       /*methodName=*/"getEntries",
       /*args=*/(ins)
     >,
diff --git a/mlir/lib/Dialect/DLTI/DLTI.cpp b/mlir/lib/Dialect/DLTI/DLTI.cpp
index 85ec9fc93248a1..d8946d865a1836 100644
--- a/mlir/lib/Dialect/DLTI/DLTI.cpp
+++ b/mlir/lib/Dialect/DLTI/DLTI.cpp
@@ -8,6 +8,7 @@
 
 #include "mlir/Dialect/DLTI/DLTI.h"
 #include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/BuiltinDialect.h"
 #include "mlir/IR/BuiltinOps.h"
 #include "mlir/IR/BuiltinTypes.h"
@@ -28,6 +29,123 @@ using namespace mlir;
 
 #define DEBUG_TYPE "dlti"
 
+//===----------------------------------------------------------------------===//
+// parsing
+//===----------------------------------------------------------------------===//
+
+static ParseResult parseKeyValuePair(AsmParser &parser,
+                                     DataLayoutEntryInterface &entry,
+                                     bool tryType = false) {
+  Attribute value;
+
+  if (tryType) {
+    Type type;
+    OptionalParseResult parsedType = parser.parseOptionalType(type);
+    if (parsedType.has_value()) {
+      if (failed(parsedType.value()))
+        return parser.emitError(parser.getCurrentLocation())
+               << "error while parsing type DLTI key";
+
+      if (failed(parser.parseEqual()) || failed(parser.parseAttribute(value)))
+        return failure();
+
+      entry = DataLayoutEntryAttr::get(type, value);
+      return ParseResult::success();
+    }
+  }
+
+  std::string ident;
+  OptionalParseResult parsedStr = parser.parseOptionalString(&ident);
+  if (parsedStr.has_value() && !ident.empty()) {
+    if (failed(parsedStr.value()))
+      return parser.emitError(parser.getCurrentLocation())
+             << "error while parsing string DLTI key";
+
+    if (failed(parser.parseEqual()) || failed(parser.parseAttribute(value)))
+      return failure(); // Assume that an error has already been emitted.
+
+    entry = DataLayoutEntryAttr::get(
+        StringAttr::get(parser.getContext(), ident), value);
+    return ParseResult::success();
+  }
+
+  OptionalParseResult parsedEntry = parser.parseAttribute(entry);
+  if (parsedEntry.has_value()) {
+    if (succeeded(parsedEntry.value()))
+      return parsedEntry.value();
+    return failure(); // Assume that an error has already been emitted.
+  }
+  return parser.emitError(parser.getCurrentLocation())
+         << "failed to parse DLTI entry";
+}
+
+template <class Attr>
+static Attribute parseAngleBracketedEntries(AsmParser &parser, Type ty,
+                                            bool tryType = false,
+                                            bool allowEmpty = false) {
+  SmallVector<DataLayoutEntryInterface> entries;
+  if (failed(parser.parseCommaSeparatedList(
+          AsmParser::Delimiter::LessGreater, [&]() {
+            return parseKeyValuePair(parser, entries.emplace_back(), tryType);
+          })))
+    return {};
+
+  if (entries.empty() && !allowEmpty) {
+    parser.emitError(parser.getNameLoc()) << "no DLTI entries provided";
+    return {};
+  }
+
+  return Attr::getChecked([&] { return parser.emitError(parser.getNameLoc()); },
+                          parser.getContext(), ArrayRef(entries));
+}
+
+//===----------------------------------------------------------------------===//
+// printing
+//===----------------------------------------------------------------------===//
+
+static inline std::string keyToStr(DataLayoutEntryKey key) {
+  std::string buf;
+  llvm::TypeSwitch<DataLayoutEntryKey>(key)
+      .Case<StringAttr, Type>( // The only two kinds of key we know of.
+          [&](auto key) { llvm::raw_string_ostream(buf) << key; })
+      .Default([](auto) { llvm_unreachable("unexpected entry key kind"); });
+  return buf;
+}
+
+template <class T>
+static void printAngleBracketedEntries(AsmPrinter &os, T &&entries) {
+  os << "<";
+  llvm::interleaveComma(std::forward<T>(entries), os, [&](auto entry) {
+    os << keyToStr(entry.getKey()) << " = " << entry.getValue();
+  });
+  os << ">";
+}
+
+//===----------------------------------------------------------------------===//
+// verifying
+//===----------------------------------------------------------------------===//
+
+static LogicalResult verifyEntries(function_ref<InFlightDiagnostic()> emitError,
+                                   ArrayRef<DataLayoutEntryInterface> entries,
+                                   bool allowTypes = true) {
+  DenseSet<DataLayoutEntryKey> keys;
+  for (DataLayoutEntryInterface entry : entries) {
+    if (!entry)
+      return emitError() << "contained invalid DLTI entry";
+    DataLayoutEntryKey key = entry.getKey();
+    if (key.isNull())
+      return emitError() << "contained invalid DLTI key";
+    if (!allowTypes && llvm::dyn_cast<Type>(key))
+      return emitError() << "type as DLIT key is not allowed";
+    if (!keys.insert(key).second)
+      return emitError() << "repeated DLTI key: " << keyToStr(key);
+    if (!entry.getValue())
+      return emitError() << "value associated to DLTI key " << keyToStr(key)
+                         << " is invalid";
+  }
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // DataLayoutEntryAttr
 //===----------------------------------------------------------------------===//
@@ -71,15 +189,16 @@ DataLayoutEntryKey DataLayoutEntryAttr::getKey() const {
 Attribute DataLayoutEntryAttr::getValue() const { return getImpl()->value; }
 
 /// Parses an attribute with syntax:
-///   attr ::= `#target.` `dl_entry` `<` (type | quoted-string) `,` attr `>`
-Attribute DataLayoutEntryAttr::parse(AsmParser &parser, Type ty) {
+///   dl-entry-attr ::= `#dlti.` `dl_entry` `<` (type | quoted-string) `,`
+///     attr `>`
+Attribute DataLayoutEntryAttr::parse(AsmParser &parser, Type type) {
   if (failed(parser.parseLess()))
     return {};
 
-  Type type = nullptr;
+  Type typeKey = nullptr;
   std::string identifier;
   SMLoc idLoc = parser.getCurrentLocation();
-  OptionalParseResult parsedType = parser.parseOptionalType(type);
+  OptionalParseResult parsedType = parser.parseOptionalType(typeKey);
   if (parsedType.has_value() && failed(parsedType.value()))
     return {};
   if (!parsedType.has_value()) {
@@ -95,38 +214,29 @@ Attribute DataLayoutEntryAttr::parse(AsmParser &parser, Type ty) {
       failed(parser.parseGreater()))
     return {};
 
-  return type ? get(type, value)
-              : get(parser.getBuilder().getStringAttr(identifier), value);
+  return typeKey ? get(typeKey, value)
+                 : get(parser.getBuilder().getStringAttr(identifier), value);
 }
 
-void DataLayoutEntryAttr::print(AsmPrinter &os) const {
-  os << "<";
-  if (auto type = llvm::dyn_cast_if_present<Type>(getKey()))
-    os << type;
-  else
-    os << "\"" << getKey().get<StringAttr>().strref() << "\"";
-  os << ", " << getValue() << ">";
+void DataLayoutEntryAttr::print(AsmPrinter &printer) const {
+  printer << "<" << keyToStr(getKey()) << ", " << getValue() << ">";
 }
 
 //===----------------------------------------------------------------------===//
 // DLTIMapAttr
 //===----------------------------------------------------------------------===//
 
-static LogicalResult verifyEntries(function_ref<InFlightDiagnostic()> emitError,
-                                   ArrayRef<DataLayoutEntryInterface> entries) {
-  DenseSet<Type> types;
-  DenseSet<StringAttr> ids;
-  for (DataLayoutEntryInterface entry : entries) {
-    if (auto type = llvm::dyn_cast_if_present<Type>(entry.getKey())) {
-      if (!types.insert(type).second)
-        return emitError() << "repeated layout entry key: " << type;
-    } else {
-      auto id = entry.getKey().get<StringAttr>();
-      if (!ids.insert(id).second)
-        return emitError() << "repeated layout entry key: " << id.getValue();
-    }
-  }
-  return success();
+/// Parses an attribute with syntax:
+///   map-attr ::= `#dlti.` `map` `<` entry-list `>`
+///   entry-list ::= entry | entry `,` entry-list
+///   entry ::= ((type | quoted-string) `=` attr) | dl-entry-attr
+Attribute MapAttr::parse(AsmParser &parser, Type type) {
+  return parseAngleBracketedEntries<MapAttr>(parser, type, /*tryType=*/true,
+                                             /*allowEmpty=*/true);
+}
+
+void MapAttr::print(AsmPrinter &printer) const {
+  printAngleBracketedEntries(printer, getEntries());
 }
 
 LogicalResult MapAttr::verify(function_ref<InFlightDiagnostic()> emitError,
@@ -282,98 +392,40 @@ DataLayoutSpecAttr::getStackAlignmentIdentifier(MLIRContext *context) const {
       DLTIDialect::kDataLayoutStackAlignmentKey);
 }
 
-/// Parses an attribute with syntax
-///   attr ::= `#target.` `dl_spec` `<` attr-list? `>`
-///   attr-list ::= attr
-///               | attr `,` attr-list
+/// Parses an attribute with syntax:
+///   dl-spec-attr ::= `#dlti.` `dl_spec` `<` entry-list `>`
+///   entry-list ::= | entry | entry `,` entry-list
+///   entry ::= ((type | quoted-string) = attr) | dl-entry-attr
 Attribute DataLayoutSpecAttr::parse(AsmParser &parser, Type type) {
-  if (failed(parser.parseLess()))
-    return {};
-
-  // Empty spec.
-  if (succeeded(parser.parseOptionalGreater()))
-    return get(parser.getContext(), {});
-
-  SmallVector<DataLayoutEntryInterface> entries;
-  if (parser.parseCommaSeparatedList(
-          [&]() { return parser.parseAttribute(entries.emplace_back()); }) ||
-      parser.parseGreater())
-    return {};
-
-  return getChecked([&] { return parser.emitError(parser.getNameLoc()); },
-                    parser.getContext(), entries);
+  return parseAngleBracketedEntries<DataLayoutSpecAttr>(parser, type,
+                                                        /*tryType=*/true,
+                                                        /*allowEmpty=*/true);
 }
 
-void DataLayoutSpecAttr::print(AsmPrinter &os) const {
-  os << "<";
-  llvm::interleaveComma(getEntries(), os);
-  os << ">";
+void DataLayoutSpecAttr::print(AsmPrinter &printer) const {
+  printAngleBracketedEntries(printer, getEntries());
 }
 
 //===----------------------------------------------------------------------===//
 // TargetDeviceSpecAttr
 //===----------------------------------------------------------------------===//
 
-namespace mlir {
-/// A FieldParser for key-value pairs of DeviceID-target device spec pairs that
-/// make up a target system spec.
-template <>
-struct FieldParser<DeviceIDTargetDeviceSpecPair> {
-  static FailureOr<DeviceIDTargetDeviceSpecPair> parse(AsmParser &parser) {
-    std::string deviceID;
-
-    if (failed(parser.parseString(&deviceID))) {
-      parser.emitError(parser.getCurrentLocation())
-          << "DeviceID is missing, or is not of string type";
-      return failure();
-    }
-
-    if (failed(parser.parseColon())) {
-      parser.emitError(parser.getCurrentLocation()) << "Missing colon";
-      return failure();
-    }
-
-    auto target_device_spec =
-        FieldParser<TargetDeviceSpecInterface>::parse(parser);
-    if (failed(target_device_spec)) {
-      parser.emitError(parser.getCurrentLocation())
-          << "Error in parsing target device spec";
-      return failure();
-    }
-
-    return std::make_pair(parser.getBuilder().getStringAttr(deviceID),
-                          *target_device_spec);
-  }
-};
-
-inline AsmPrinter &operator<<(AsmPrinter &printer,
-                              DeviceIDTargetDeviceSpecPair param) {
-  return printer << param.first << " : " << param.second;
-}
-
-} // namespace mlir
-
 LogicalResult
 TargetDeviceSpecAttr::verify(function_ref<InFlightDiagnostic()> emitError,
                              ArrayRef<DataLayoutEntryInterface> entries) {
-  // Entries in a target device spec can only have StringAttr as key. It does
-  // not support type as a key. Hence not reusing
-  // DataLayoutEntryInterface::verify.
-  DenseSet<StringAttr> ids;
-  for (DataLayoutEntryInterface entry : entries) {
-    if (auto type = llvm::dyn_cast_if_present<Type>(entry.getKey())) {
-      return emitError()
-             << "dlti.target_device_spec does not allow type as a key: "
-             << type;
-    } else {
-      // Check that keys in a target device spec are unique.
-      auto id = entry.getKey().get<StringAttr>();
-      if (!ids.insert(id).second)
-        return emitError() << "repeated layout entry key: " << id.getValue();
-    }
-  }
+  return verifyEntries(emitError, entries, /*allowTypes=*/false);
+}
 
-  return success();
+/// Parses an attribute with syntax:
+///   dev-spec-attr ::= `#dlti.` `target_device_spec` `<` entry-list `>`
+///   entry-list ::= entry | entry `,` entry-list
+///   entry ::= (quoted-string `=` attr) | dl-entry-attr
+Attribute TargetDeviceSpecAttr::parse(AsmParser &parser, Type type) {
+  return parseAngleBracketedEntries<TargetDeviceSpecAttr>(parser, type);
+}
+
+void TargetDeviceSpecAttr::print(AsmPrinter &printer) const {
+  printAngleBracketedEntries(printer, getEntries());
 }
 
 //===----------------------------------------------------------------------===//
@@ -382,27 +434,46 @@ TargetDeviceSpecAttr::verify(function_ref<InFlightDiagnostic()> emitError,
 
 LogicalResult
 TargetSystemSpecAttr::verify(function_ref<InFlightDiagnostic()> emitError,
-                             ArrayRef<DeviceIDTargetDeviceSpecPair> entries) {
-  DenseSet<TargetSystemSpecInterface::DeviceID> device_ids;
+                             ArrayRef<DataLayoutEntryInterface> entries) {
+  DenseSet<TargetSystemSpecInterface::DeviceID> deviceIds;
 
   for (const auto &entry : entries) {
-    TargetDeviceSpecInterface target_device_spec = entry.second;
-
-    // First verify that a target device spec is valid.
-    if (failed(TargetDeviceSpecAttr::verify(emitError,
-                                            target_device_spec.getEntries())))
-      return failure();
+    auto deviceId =
+        llvm::dyn_cast<TargetSystemSpecInterface::DeviceID>(entry.getKey());
+    if (!deviceId)
+      return emitError() << "non-string key of DLTI system spec";
+
+    if (auto targetDeviceSpec...
[truncated]

@llvmbot
Copy link
Member

llvmbot commented Oct 22, 2024

@llvm/pr-subscribers-mlir-dlti

Author: Rolf Morel (rolfmorel)

Changes

Unifies parsing and printing for DLTI attributes. Introduces a format of #dlti.attr&lt;key1 = val1, ..., keyN = valN&gt; syntax for all queryable DLTI attributes similar to that of the DictionaryAttr, while retaining support for specifying key-value pairs with #dlti.dl_entry (whether to retain this is TBD).

As the new format does away with all the boilerplate, it is much easier to parse for humans. This makes an especially big difference for nested attributes.

Updates the DLTI-using tests and includes fixes for misc error checking/ error messages.


Patch is 73.01 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/113365.diff

12 Files Affected:

  • (modified) mlir/include/mlir/Dialect/DLTI/DLTIAttrs.td (+28-22)
  • (modified) mlir/include/mlir/Interfaces/DataLayoutInterfaces.h (+2-3)
  • (modified) mlir/include/mlir/Interfaces/DataLayoutInterfaces.td (+1-1)
  • (modified) mlir/lib/Dialect/DLTI/DLTI.cpp (+191-129)
  • (modified) mlir/lib/Interfaces/DataLayoutInterfaces.cpp (+11-2)
  • (modified) mlir/test/Dialect/DLTI/invalid.mlir (+14-23)
  • (modified) mlir/test/Dialect/DLTI/query.mlir (+51-52)
  • (modified) mlir/test/Dialect/DLTI/roundtrip.mlir (+29-21)
  • (modified) mlir/test/Dialect/DLTI/valid.mlir (+124-93)
  • (modified) mlir/test/Dialect/GPU/outlining.mlir (+13-13)
  • (modified) mlir/test/Target/LLVMIR/Import/data-layout.ll (+24-24)
  • (modified) mlir/unittests/Interfaces/DataLayoutInterfacesTest.cpp (+22-16)
diff --git a/mlir/include/mlir/Dialect/DLTI/DLTIAttrs.td b/mlir/include/mlir/Dialect/DLTI/DLTIAttrs.td
index 53d38407608bed..1caf5fd8787c7b 100644
--- a/mlir/include/mlir/Dialect/DLTI/DLTIAttrs.td
+++ b/mlir/include/mlir/Dialect/DLTI/DLTIAttrs.td
@@ -93,6 +93,10 @@ def DLTI_DataLayoutSpecAttr :
   }];
 }
 
+//===----------------------------------------------------------------------===//
+// MapAttr
+//===----------------------------------------------------------------------===//
+
 def DLTI_MapAttr : DLTIAttr<"Map", [DLTIQueryInterface]> {
   let summary = "A mapping of DLTI-information by way of key-value pairs";
   let description = [{
@@ -106,18 +110,16 @@ def DLTI_MapAttr : DLTIAttr<"Map", [DLTIQueryInterface]> {
 
     Consider the following flat encoding of a single-key dictionary
     ```
-    #dlti.map<#dlti.dl_entry<"CPU::cache::L1::size_in_bytes", 65536 : i32>>
+    #dlti.map<"CPU::cache::L1::size_in_bytes" = 65536 : i32>>
     ```
     versus nested maps, which make it possible to obtain sub-dictionaries of
     related information (with the following example making use of other
     attributes that also implement the `DLTIQueryInterface`):
     ```
-    #dlti.target_system_spec<"CPU":
-      #dlti.target_device_spec<#dlti.dl_entry<"cache",
-        #dlti.map<#dlti.dl_entry<"L1",
-          #dlti.map<#dlti.dl_entry<"size_in_bytes", 65536 : i32>>>,
-                  #dlti.dl_entry<"L1d",
-          #dlti.map<#dlti.dl_entry<"size_in_bytes", 32768 : i32>>> >>>>
+    #dlti.target_system_spec<"CPU" =
+      #dlti.target_device_spec<"cache" =
+        #dlti.map<"L1" = #dlti.map<"size_in_bytes" = 65536 : i32>,
+                  "L1d" = #dlti.map<"size_in_bytes" = 32768 : i32> >>>
     ```
 
     With the flat encoding, the implied structure of the key is ignored, that is
@@ -139,7 +141,7 @@ def DLTI_MapAttr : DLTIAttr<"Map", [DLTIQueryInterface]> {
   );
   let mnemonic = "map";
   let genVerifyDecl = 1;
-  let assemblyFormat = "`<` $entries `>`";
+  let hasCustomAssemblyFormat = 1;
   let extraClassDeclaration = [{
     /// Returns the attribute associated with the key.
     FailureOr<Attribute> query(DataLayoutEntryKey key) {
@@ -167,20 +169,23 @@ def DLTI_TargetSystemSpecAttr :
     ```
     dlti.target_system_spec =
      #dlti.target_system_spec<
-      "CPU": #dlti.target_device_spec<
-              #dlti.dl_entry<"dlti.L1_cache_size_in_bytes", 4096: ui32>>,
-      "GPU": #dlti.target_device_spec<
-              #dlti.dl_entry<"dlti.max_vector_op_width", 64 : ui32>>,
-      "XPU": #dlti.target_device_spec<
-              #dlti.dl_entry<"dlti.max_vector_op_width", 4096 : ui32>>>
+      "CPU" = #dlti.target_device_spec<
+        "L1_cache_size_in_bytes" = 4096: ui32>,
+      "GPU" = #dlti.target_device_spec<
+        "max_vector_op_width" = 64 : ui32>,
+      "XPU" = #dlti.target_device_spec<
+        "max_vector_op_width" = 4096 : ui32>>
     ```
+
+    The verifier checks that keys are strings and pointed to values implement
+    DLTI's TargetDeviceSpecInterface.
   }];
   let parameters = (ins
-    ArrayRefParameter<"DeviceIDTargetDeviceSpecPair", "">:$entries
+    ArrayRefParameter<"DataLayoutEntryInterface">:$entries
   );
   let mnemonic = "target_system_spec";
   let genVerifyDecl = 1;
-  let assemblyFormat = "`<` $entries `>`";
+  let hasCustomAssemblyFormat = 1;
   let extraClassDeclaration = [{
     /// Return the device specification that matches the given device ID
     std::optional<TargetDeviceSpecInterface>
@@ -197,8 +202,10 @@ def DLTI_TargetSystemSpecAttr :
     $cppClass::getDeviceSpecForDeviceID(
         TargetSystemSpecInterface::DeviceID deviceID) {
       for (const auto& entry : getEntries()) {
-        if (entry.first == deviceID)
-          return entry.second;
+        if (entry.getKey() == DataLayoutEntryKey(deviceID))
+          if (auto deviceSpec =
+              llvm::dyn_cast<TargetDeviceSpecInterface>(entry.getValue()))
+            return deviceSpec;
       }
       return std::nullopt;
     }
@@ -219,16 +226,15 @@ def DLTI_TargetDeviceSpecAttr :
 
     Example:
     ```
-    #dlti.target_device_spec<
-      #dlti.dl_entry<"dlti.max_vector_op_width", 64 : ui32>>
+    #dlti.target_device_spec<"max_vector_op_width" = 64 : ui32>
     ```
   }];
   let parameters = (ins
-    ArrayRefParameter<"DataLayoutEntryInterface", "">:$entries
+    ArrayRefParameter<"DataLayoutEntryInterface">:$entries
   );
   let mnemonic = "target_device_spec";
   let genVerifyDecl = 1;
-  let assemblyFormat = "`<` $entries `>`";
+  let hasCustomAssemblyFormat = 1;
 
   let extraClassDeclaration = [{
     /// Returns the attribute associated with the key.
diff --git a/mlir/include/mlir/Interfaces/DataLayoutInterfaces.h b/mlir/include/mlir/Interfaces/DataLayoutInterfaces.h
index 848d2dee4a6309..e3fdf85b15ea59 100644
--- a/mlir/include/mlir/Interfaces/DataLayoutInterfaces.h
+++ b/mlir/include/mlir/Interfaces/DataLayoutInterfaces.h
@@ -15,6 +15,7 @@
 #ifndef MLIR_INTERFACES_DATALAYOUTINTERFACES_H
 #define MLIR_INTERFACES_DATALAYOUTINTERFACES_H
 
+#include "mlir/IR/Attributes.h"
 #include "mlir/IR/DialectInterface.h"
 #include "mlir/IR/OpDefinition.h"
 #include "llvm/ADT/DenseMap.h"
@@ -32,10 +33,8 @@ using DataLayoutEntryKey = llvm::PointerUnion<Type, StringAttr>;
 using DataLayoutEntryList = llvm::SmallVector<DataLayoutEntryInterface, 4>;
 using DataLayoutEntryListRef = llvm::ArrayRef<DataLayoutEntryInterface>;
 using TargetDeviceSpecListRef = llvm::ArrayRef<TargetDeviceSpecInterface>;
-using DeviceIDTargetDeviceSpecPair =
+using TargetDeviceSpecEntry =
     std::pair<StringAttr, TargetDeviceSpecInterface>;
-using DeviceIDTargetDeviceSpecPairListRef =
-    llvm::ArrayRef<DeviceIDTargetDeviceSpecPair>;
 class DataLayoutOpInterface;
 class DataLayoutSpecInterface;
 class ModuleOp;
diff --git a/mlir/include/mlir/Interfaces/DataLayoutInterfaces.td b/mlir/include/mlir/Interfaces/DataLayoutInterfaces.td
index d6e955be4291a3..061dee2399d9ad 100644
--- a/mlir/include/mlir/Interfaces/DataLayoutInterfaces.td
+++ b/mlir/include/mlir/Interfaces/DataLayoutInterfaces.td
@@ -304,7 +304,7 @@ def TargetSystemSpecInterface : AttrInterface<"TargetSystemSpecInterface", [DLTI
   let methods = [
     InterfaceMethod<
       /*description=*/"Returns the list of layout entries.",
-      /*retTy=*/"llvm::ArrayRef<DeviceIDTargetDeviceSpecPair>",
+      /*retTy=*/"llvm::ArrayRef<DataLayoutEntryInterface>",
       /*methodName=*/"getEntries",
       /*args=*/(ins)
     >,
diff --git a/mlir/lib/Dialect/DLTI/DLTI.cpp b/mlir/lib/Dialect/DLTI/DLTI.cpp
index 85ec9fc93248a1..d8946d865a1836 100644
--- a/mlir/lib/Dialect/DLTI/DLTI.cpp
+++ b/mlir/lib/Dialect/DLTI/DLTI.cpp
@@ -8,6 +8,7 @@
 
 #include "mlir/Dialect/DLTI/DLTI.h"
 #include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/BuiltinDialect.h"
 #include "mlir/IR/BuiltinOps.h"
 #include "mlir/IR/BuiltinTypes.h"
@@ -28,6 +29,123 @@ using namespace mlir;
 
 #define DEBUG_TYPE "dlti"
 
+//===----------------------------------------------------------------------===//
+// parsing
+//===----------------------------------------------------------------------===//
+
+static ParseResult parseKeyValuePair(AsmParser &parser,
+                                     DataLayoutEntryInterface &entry,
+                                     bool tryType = false) {
+  Attribute value;
+
+  if (tryType) {
+    Type type;
+    OptionalParseResult parsedType = parser.parseOptionalType(type);
+    if (parsedType.has_value()) {
+      if (failed(parsedType.value()))
+        return parser.emitError(parser.getCurrentLocation())
+               << "error while parsing type DLTI key";
+
+      if (failed(parser.parseEqual()) || failed(parser.parseAttribute(value)))
+        return failure();
+
+      entry = DataLayoutEntryAttr::get(type, value);
+      return ParseResult::success();
+    }
+  }
+
+  std::string ident;
+  OptionalParseResult parsedStr = parser.parseOptionalString(&ident);
+  if (parsedStr.has_value() && !ident.empty()) {
+    if (failed(parsedStr.value()))
+      return parser.emitError(parser.getCurrentLocation())
+             << "error while parsing string DLTI key";
+
+    if (failed(parser.parseEqual()) || failed(parser.parseAttribute(value)))
+      return failure(); // Assume that an error has already been emitted.
+
+    entry = DataLayoutEntryAttr::get(
+        StringAttr::get(parser.getContext(), ident), value);
+    return ParseResult::success();
+  }
+
+  OptionalParseResult parsedEntry = parser.parseAttribute(entry);
+  if (parsedEntry.has_value()) {
+    if (succeeded(parsedEntry.value()))
+      return parsedEntry.value();
+    return failure(); // Assume that an error has already been emitted.
+  }
+  return parser.emitError(parser.getCurrentLocation())
+         << "failed to parse DLTI entry";
+}
+
+template <class Attr>
+static Attribute parseAngleBracketedEntries(AsmParser &parser, Type ty,
+                                            bool tryType = false,
+                                            bool allowEmpty = false) {
+  SmallVector<DataLayoutEntryInterface> entries;
+  if (failed(parser.parseCommaSeparatedList(
+          AsmParser::Delimiter::LessGreater, [&]() {
+            return parseKeyValuePair(parser, entries.emplace_back(), tryType);
+          })))
+    return {};
+
+  if (entries.empty() && !allowEmpty) {
+    parser.emitError(parser.getNameLoc()) << "no DLTI entries provided";
+    return {};
+  }
+
+  return Attr::getChecked([&] { return parser.emitError(parser.getNameLoc()); },
+                          parser.getContext(), ArrayRef(entries));
+}
+
+//===----------------------------------------------------------------------===//
+// printing
+//===----------------------------------------------------------------------===//
+
+static inline std::string keyToStr(DataLayoutEntryKey key) {
+  std::string buf;
+  llvm::TypeSwitch<DataLayoutEntryKey>(key)
+      .Case<StringAttr, Type>( // The only two kinds of key we know of.
+          [&](auto key) { llvm::raw_string_ostream(buf) << key; })
+      .Default([](auto) { llvm_unreachable("unexpected entry key kind"); });
+  return buf;
+}
+
+template <class T>
+static void printAngleBracketedEntries(AsmPrinter &os, T &&entries) {
+  os << "<";
+  llvm::interleaveComma(std::forward<T>(entries), os, [&](auto entry) {
+    os << keyToStr(entry.getKey()) << " = " << entry.getValue();
+  });
+  os << ">";
+}
+
+//===----------------------------------------------------------------------===//
+// verifying
+//===----------------------------------------------------------------------===//
+
+static LogicalResult verifyEntries(function_ref<InFlightDiagnostic()> emitError,
+                                   ArrayRef<DataLayoutEntryInterface> entries,
+                                   bool allowTypes = true) {
+  DenseSet<DataLayoutEntryKey> keys;
+  for (DataLayoutEntryInterface entry : entries) {
+    if (!entry)
+      return emitError() << "contained invalid DLTI entry";
+    DataLayoutEntryKey key = entry.getKey();
+    if (key.isNull())
+      return emitError() << "contained invalid DLTI key";
+    if (!allowTypes && llvm::dyn_cast<Type>(key))
+      return emitError() << "type as DLIT key is not allowed";
+    if (!keys.insert(key).second)
+      return emitError() << "repeated DLTI key: " << keyToStr(key);
+    if (!entry.getValue())
+      return emitError() << "value associated to DLTI key " << keyToStr(key)
+                         << " is invalid";
+  }
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // DataLayoutEntryAttr
 //===----------------------------------------------------------------------===//
@@ -71,15 +189,16 @@ DataLayoutEntryKey DataLayoutEntryAttr::getKey() const {
 Attribute DataLayoutEntryAttr::getValue() const { return getImpl()->value; }
 
 /// Parses an attribute with syntax:
-///   attr ::= `#target.` `dl_entry` `<` (type | quoted-string) `,` attr `>`
-Attribute DataLayoutEntryAttr::parse(AsmParser &parser, Type ty) {
+///   dl-entry-attr ::= `#dlti.` `dl_entry` `<` (type | quoted-string) `,`
+///     attr `>`
+Attribute DataLayoutEntryAttr::parse(AsmParser &parser, Type type) {
   if (failed(parser.parseLess()))
     return {};
 
-  Type type = nullptr;
+  Type typeKey = nullptr;
   std::string identifier;
   SMLoc idLoc = parser.getCurrentLocation();
-  OptionalParseResult parsedType = parser.parseOptionalType(type);
+  OptionalParseResult parsedType = parser.parseOptionalType(typeKey);
   if (parsedType.has_value() && failed(parsedType.value()))
     return {};
   if (!parsedType.has_value()) {
@@ -95,38 +214,29 @@ Attribute DataLayoutEntryAttr::parse(AsmParser &parser, Type ty) {
       failed(parser.parseGreater()))
     return {};
 
-  return type ? get(type, value)
-              : get(parser.getBuilder().getStringAttr(identifier), value);
+  return typeKey ? get(typeKey, value)
+                 : get(parser.getBuilder().getStringAttr(identifier), value);
 }
 
-void DataLayoutEntryAttr::print(AsmPrinter &os) const {
-  os << "<";
-  if (auto type = llvm::dyn_cast_if_present<Type>(getKey()))
-    os << type;
-  else
-    os << "\"" << getKey().get<StringAttr>().strref() << "\"";
-  os << ", " << getValue() << ">";
+void DataLayoutEntryAttr::print(AsmPrinter &printer) const {
+  printer << "<" << keyToStr(getKey()) << ", " << getValue() << ">";
 }
 
 //===----------------------------------------------------------------------===//
 // DLTIMapAttr
 //===----------------------------------------------------------------------===//
 
-static LogicalResult verifyEntries(function_ref<InFlightDiagnostic()> emitError,
-                                   ArrayRef<DataLayoutEntryInterface> entries) {
-  DenseSet<Type> types;
-  DenseSet<StringAttr> ids;
-  for (DataLayoutEntryInterface entry : entries) {
-    if (auto type = llvm::dyn_cast_if_present<Type>(entry.getKey())) {
-      if (!types.insert(type).second)
-        return emitError() << "repeated layout entry key: " << type;
-    } else {
-      auto id = entry.getKey().get<StringAttr>();
-      if (!ids.insert(id).second)
-        return emitError() << "repeated layout entry key: " << id.getValue();
-    }
-  }
-  return success();
+/// Parses an attribute with syntax:
+///   map-attr ::= `#dlti.` `map` `<` entry-list `>`
+///   entry-list ::= entry | entry `,` entry-list
+///   entry ::= ((type | quoted-string) `=` attr) | dl-entry-attr
+Attribute MapAttr::parse(AsmParser &parser, Type type) {
+  return parseAngleBracketedEntries<MapAttr>(parser, type, /*tryType=*/true,
+                                             /*allowEmpty=*/true);
+}
+
+void MapAttr::print(AsmPrinter &printer) const {
+  printAngleBracketedEntries(printer, getEntries());
 }
 
 LogicalResult MapAttr::verify(function_ref<InFlightDiagnostic()> emitError,
@@ -282,98 +392,40 @@ DataLayoutSpecAttr::getStackAlignmentIdentifier(MLIRContext *context) const {
       DLTIDialect::kDataLayoutStackAlignmentKey);
 }
 
-/// Parses an attribute with syntax
-///   attr ::= `#target.` `dl_spec` `<` attr-list? `>`
-///   attr-list ::= attr
-///               | attr `,` attr-list
+/// Parses an attribute with syntax:
+///   dl-spec-attr ::= `#dlti.` `dl_spec` `<` entry-list `>`
+///   entry-list ::= | entry | entry `,` entry-list
+///   entry ::= ((type | quoted-string) = attr) | dl-entry-attr
 Attribute DataLayoutSpecAttr::parse(AsmParser &parser, Type type) {
-  if (failed(parser.parseLess()))
-    return {};
-
-  // Empty spec.
-  if (succeeded(parser.parseOptionalGreater()))
-    return get(parser.getContext(), {});
-
-  SmallVector<DataLayoutEntryInterface> entries;
-  if (parser.parseCommaSeparatedList(
-          [&]() { return parser.parseAttribute(entries.emplace_back()); }) ||
-      parser.parseGreater())
-    return {};
-
-  return getChecked([&] { return parser.emitError(parser.getNameLoc()); },
-                    parser.getContext(), entries);
+  return parseAngleBracketedEntries<DataLayoutSpecAttr>(parser, type,
+                                                        /*tryType=*/true,
+                                                        /*allowEmpty=*/true);
 }
 
-void DataLayoutSpecAttr::print(AsmPrinter &os) const {
-  os << "<";
-  llvm::interleaveComma(getEntries(), os);
-  os << ">";
+void DataLayoutSpecAttr::print(AsmPrinter &printer) const {
+  printAngleBracketedEntries(printer, getEntries());
 }
 
 //===----------------------------------------------------------------------===//
 // TargetDeviceSpecAttr
 //===----------------------------------------------------------------------===//
 
-namespace mlir {
-/// A FieldParser for key-value pairs of DeviceID-target device spec pairs that
-/// make up a target system spec.
-template <>
-struct FieldParser<DeviceIDTargetDeviceSpecPair> {
-  static FailureOr<DeviceIDTargetDeviceSpecPair> parse(AsmParser &parser) {
-    std::string deviceID;
-
-    if (failed(parser.parseString(&deviceID))) {
-      parser.emitError(parser.getCurrentLocation())
-          << "DeviceID is missing, or is not of string type";
-      return failure();
-    }
-
-    if (failed(parser.parseColon())) {
-      parser.emitError(parser.getCurrentLocation()) << "Missing colon";
-      return failure();
-    }
-
-    auto target_device_spec =
-        FieldParser<TargetDeviceSpecInterface>::parse(parser);
-    if (failed(target_device_spec)) {
-      parser.emitError(parser.getCurrentLocation())
-          << "Error in parsing target device spec";
-      return failure();
-    }
-
-    return std::make_pair(parser.getBuilder().getStringAttr(deviceID),
-                          *target_device_spec);
-  }
-};
-
-inline AsmPrinter &operator<<(AsmPrinter &printer,
-                              DeviceIDTargetDeviceSpecPair param) {
-  return printer << param.first << " : " << param.second;
-}
-
-} // namespace mlir
-
 LogicalResult
 TargetDeviceSpecAttr::verify(function_ref<InFlightDiagnostic()> emitError,
                              ArrayRef<DataLayoutEntryInterface> entries) {
-  // Entries in a target device spec can only have StringAttr as key. It does
-  // not support type as a key. Hence not reusing
-  // DataLayoutEntryInterface::verify.
-  DenseSet<StringAttr> ids;
-  for (DataLayoutEntryInterface entry : entries) {
-    if (auto type = llvm::dyn_cast_if_present<Type>(entry.getKey())) {
-      return emitError()
-             << "dlti.target_device_spec does not allow type as a key: "
-             << type;
-    } else {
-      // Check that keys in a target device spec are unique.
-      auto id = entry.getKey().get<StringAttr>();
-      if (!ids.insert(id).second)
-        return emitError() << "repeated layout entry key: " << id.getValue();
-    }
-  }
+  return verifyEntries(emitError, entries, /*allowTypes=*/false);
+}
 
-  return success();
+/// Parses an attribute with syntax:
+///   dev-spec-attr ::= `#dlti.` `target_device_spec` `<` entry-list `>`
+///   entry-list ::= entry | entry `,` entry-list
+///   entry ::= (quoted-string `=` attr) | dl-entry-attr
+Attribute TargetDeviceSpecAttr::parse(AsmParser &parser, Type type) {
+  return parseAngleBracketedEntries<TargetDeviceSpecAttr>(parser, type);
+}
+
+void TargetDeviceSpecAttr::print(AsmPrinter &printer) const {
+  printAngleBracketedEntries(printer, getEntries());
 }
 
 //===----------------------------------------------------------------------===//
@@ -382,27 +434,46 @@ TargetDeviceSpecAttr::verify(function_ref<InFlightDiagnostic()> emitError,
 
 LogicalResult
 TargetSystemSpecAttr::verify(function_ref<InFlightDiagnostic()> emitError,
-                             ArrayRef<DeviceIDTargetDeviceSpecPair> entries) {
-  DenseSet<TargetSystemSpecInterface::DeviceID> device_ids;
+                             ArrayRef<DataLayoutEntryInterface> entries) {
+  DenseSet<TargetSystemSpecInterface::DeviceID> deviceIds;
 
   for (const auto &entry : entries) {
-    TargetDeviceSpecInterface target_device_spec = entry.second;
-
-    // First verify that a target device spec is valid.
-    if (failed(TargetDeviceSpecAttr::verify(emitError,
-                                            target_device_spec.getEntries())))
-      return failure();
+    auto deviceId =
+        llvm::dyn_cast<TargetSystemSpecInterface::DeviceID>(entry.getKey());
+    if (!deviceId)
+      return emitError() << "non-string key of DLTI system spec";
+
+    if (auto targetDeviceSpec...
[truncated]

@llvmbot
Copy link
Member

llvmbot commented Oct 22, 2024

@llvm/pr-subscribers-mlir

Author: Rolf Morel (rolfmorel)

Changes

Unifies parsing and printing for DLTI attributes. Introduces a format of #dlti.attr&lt;key1 = val1, ..., keyN = valN&gt; syntax for all queryable DLTI attributes similar to that of the DictionaryAttr, while retaining support for specifying key-value pairs with #dlti.dl_entry (whether to retain this is TBD).

As the new format does away with all the boilerplate, it is much easier to parse for humans. This makes an especially big difference for nested attributes.

Updates the DLTI-using tests and includes fixes for misc error checking/ error messages.


Patch is 73.01 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/113365.diff

12 Files Affected:

  • (modified) mlir/include/mlir/Dialect/DLTI/DLTIAttrs.td (+28-22)
  • (modified) mlir/include/mlir/Interfaces/DataLayoutInterfaces.h (+2-3)
  • (modified) mlir/include/mlir/Interfaces/DataLayoutInterfaces.td (+1-1)
  • (modified) mlir/lib/Dialect/DLTI/DLTI.cpp (+191-129)
  • (modified) mlir/lib/Interfaces/DataLayoutInterfaces.cpp (+11-2)
  • (modified) mlir/test/Dialect/DLTI/invalid.mlir (+14-23)
  • (modified) mlir/test/Dialect/DLTI/query.mlir (+51-52)
  • (modified) mlir/test/Dialect/DLTI/roundtrip.mlir (+29-21)
  • (modified) mlir/test/Dialect/DLTI/valid.mlir (+124-93)
  • (modified) mlir/test/Dialect/GPU/outlining.mlir (+13-13)
  • (modified) mlir/test/Target/LLVMIR/Import/data-layout.ll (+24-24)
  • (modified) mlir/unittests/Interfaces/DataLayoutInterfacesTest.cpp (+22-16)
diff --git a/mlir/include/mlir/Dialect/DLTI/DLTIAttrs.td b/mlir/include/mlir/Dialect/DLTI/DLTIAttrs.td
index 53d38407608bed..1caf5fd8787c7b 100644
--- a/mlir/include/mlir/Dialect/DLTI/DLTIAttrs.td
+++ b/mlir/include/mlir/Dialect/DLTI/DLTIAttrs.td
@@ -93,6 +93,10 @@ def DLTI_DataLayoutSpecAttr :
   }];
 }
 
+//===----------------------------------------------------------------------===//
+// MapAttr
+//===----------------------------------------------------------------------===//
+
 def DLTI_MapAttr : DLTIAttr<"Map", [DLTIQueryInterface]> {
   let summary = "A mapping of DLTI-information by way of key-value pairs";
   let description = [{
@@ -106,18 +110,16 @@ def DLTI_MapAttr : DLTIAttr<"Map", [DLTIQueryInterface]> {
 
     Consider the following flat encoding of a single-key dictionary
     ```
-    #dlti.map<#dlti.dl_entry<"CPU::cache::L1::size_in_bytes", 65536 : i32>>
+    #dlti.map<"CPU::cache::L1::size_in_bytes" = 65536 : i32>>
     ```
     versus nested maps, which make it possible to obtain sub-dictionaries of
     related information (with the following example making use of other
     attributes that also implement the `DLTIQueryInterface`):
     ```
-    #dlti.target_system_spec<"CPU":
-      #dlti.target_device_spec<#dlti.dl_entry<"cache",
-        #dlti.map<#dlti.dl_entry<"L1",
-          #dlti.map<#dlti.dl_entry<"size_in_bytes", 65536 : i32>>>,
-                  #dlti.dl_entry<"L1d",
-          #dlti.map<#dlti.dl_entry<"size_in_bytes", 32768 : i32>>> >>>>
+    #dlti.target_system_spec<"CPU" =
+      #dlti.target_device_spec<"cache" =
+        #dlti.map<"L1" = #dlti.map<"size_in_bytes" = 65536 : i32>,
+                  "L1d" = #dlti.map<"size_in_bytes" = 32768 : i32> >>>
     ```
 
     With the flat encoding, the implied structure of the key is ignored, that is
@@ -139,7 +141,7 @@ def DLTI_MapAttr : DLTIAttr<"Map", [DLTIQueryInterface]> {
   );
   let mnemonic = "map";
   let genVerifyDecl = 1;
-  let assemblyFormat = "`<` $entries `>`";
+  let hasCustomAssemblyFormat = 1;
   let extraClassDeclaration = [{
     /// Returns the attribute associated with the key.
     FailureOr<Attribute> query(DataLayoutEntryKey key) {
@@ -167,20 +169,23 @@ def DLTI_TargetSystemSpecAttr :
     ```
     dlti.target_system_spec =
      #dlti.target_system_spec<
-      "CPU": #dlti.target_device_spec<
-              #dlti.dl_entry<"dlti.L1_cache_size_in_bytes", 4096: ui32>>,
-      "GPU": #dlti.target_device_spec<
-              #dlti.dl_entry<"dlti.max_vector_op_width", 64 : ui32>>,
-      "XPU": #dlti.target_device_spec<
-              #dlti.dl_entry<"dlti.max_vector_op_width", 4096 : ui32>>>
+      "CPU" = #dlti.target_device_spec<
+        "L1_cache_size_in_bytes" = 4096: ui32>,
+      "GPU" = #dlti.target_device_spec<
+        "max_vector_op_width" = 64 : ui32>,
+      "XPU" = #dlti.target_device_spec<
+        "max_vector_op_width" = 4096 : ui32>>
     ```
+
+    The verifier checks that keys are strings and pointed to values implement
+    DLTI's TargetDeviceSpecInterface.
   }];
   let parameters = (ins
-    ArrayRefParameter<"DeviceIDTargetDeviceSpecPair", "">:$entries
+    ArrayRefParameter<"DataLayoutEntryInterface">:$entries
   );
   let mnemonic = "target_system_spec";
   let genVerifyDecl = 1;
-  let assemblyFormat = "`<` $entries `>`";
+  let hasCustomAssemblyFormat = 1;
   let extraClassDeclaration = [{
     /// Return the device specification that matches the given device ID
     std::optional<TargetDeviceSpecInterface>
@@ -197,8 +202,10 @@ def DLTI_TargetSystemSpecAttr :
     $cppClass::getDeviceSpecForDeviceID(
         TargetSystemSpecInterface::DeviceID deviceID) {
       for (const auto& entry : getEntries()) {
-        if (entry.first == deviceID)
-          return entry.second;
+        if (entry.getKey() == DataLayoutEntryKey(deviceID))
+          if (auto deviceSpec =
+              llvm::dyn_cast<TargetDeviceSpecInterface>(entry.getValue()))
+            return deviceSpec;
       }
       return std::nullopt;
     }
@@ -219,16 +226,15 @@ def DLTI_TargetDeviceSpecAttr :
 
     Example:
     ```
-    #dlti.target_device_spec<
-      #dlti.dl_entry<"dlti.max_vector_op_width", 64 : ui32>>
+    #dlti.target_device_spec<"max_vector_op_width" = 64 : ui32>
     ```
   }];
   let parameters = (ins
-    ArrayRefParameter<"DataLayoutEntryInterface", "">:$entries
+    ArrayRefParameter<"DataLayoutEntryInterface">:$entries
   );
   let mnemonic = "target_device_spec";
   let genVerifyDecl = 1;
-  let assemblyFormat = "`<` $entries `>`";
+  let hasCustomAssemblyFormat = 1;
 
   let extraClassDeclaration = [{
     /// Returns the attribute associated with the key.
diff --git a/mlir/include/mlir/Interfaces/DataLayoutInterfaces.h b/mlir/include/mlir/Interfaces/DataLayoutInterfaces.h
index 848d2dee4a6309..e3fdf85b15ea59 100644
--- a/mlir/include/mlir/Interfaces/DataLayoutInterfaces.h
+++ b/mlir/include/mlir/Interfaces/DataLayoutInterfaces.h
@@ -15,6 +15,7 @@
 #ifndef MLIR_INTERFACES_DATALAYOUTINTERFACES_H
 #define MLIR_INTERFACES_DATALAYOUTINTERFACES_H
 
+#include "mlir/IR/Attributes.h"
 #include "mlir/IR/DialectInterface.h"
 #include "mlir/IR/OpDefinition.h"
 #include "llvm/ADT/DenseMap.h"
@@ -32,10 +33,8 @@ using DataLayoutEntryKey = llvm::PointerUnion<Type, StringAttr>;
 using DataLayoutEntryList = llvm::SmallVector<DataLayoutEntryInterface, 4>;
 using DataLayoutEntryListRef = llvm::ArrayRef<DataLayoutEntryInterface>;
 using TargetDeviceSpecListRef = llvm::ArrayRef<TargetDeviceSpecInterface>;
-using DeviceIDTargetDeviceSpecPair =
+using TargetDeviceSpecEntry =
     std::pair<StringAttr, TargetDeviceSpecInterface>;
-using DeviceIDTargetDeviceSpecPairListRef =
-    llvm::ArrayRef<DeviceIDTargetDeviceSpecPair>;
 class DataLayoutOpInterface;
 class DataLayoutSpecInterface;
 class ModuleOp;
diff --git a/mlir/include/mlir/Interfaces/DataLayoutInterfaces.td b/mlir/include/mlir/Interfaces/DataLayoutInterfaces.td
index d6e955be4291a3..061dee2399d9ad 100644
--- a/mlir/include/mlir/Interfaces/DataLayoutInterfaces.td
+++ b/mlir/include/mlir/Interfaces/DataLayoutInterfaces.td
@@ -304,7 +304,7 @@ def TargetSystemSpecInterface : AttrInterface<"TargetSystemSpecInterface", [DLTI
   let methods = [
     InterfaceMethod<
       /*description=*/"Returns the list of layout entries.",
-      /*retTy=*/"llvm::ArrayRef<DeviceIDTargetDeviceSpecPair>",
+      /*retTy=*/"llvm::ArrayRef<DataLayoutEntryInterface>",
       /*methodName=*/"getEntries",
       /*args=*/(ins)
     >,
diff --git a/mlir/lib/Dialect/DLTI/DLTI.cpp b/mlir/lib/Dialect/DLTI/DLTI.cpp
index 85ec9fc93248a1..d8946d865a1836 100644
--- a/mlir/lib/Dialect/DLTI/DLTI.cpp
+++ b/mlir/lib/Dialect/DLTI/DLTI.cpp
@@ -8,6 +8,7 @@
 
 #include "mlir/Dialect/DLTI/DLTI.h"
 #include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/BuiltinDialect.h"
 #include "mlir/IR/BuiltinOps.h"
 #include "mlir/IR/BuiltinTypes.h"
@@ -28,6 +29,123 @@ using namespace mlir;
 
 #define DEBUG_TYPE "dlti"
 
+//===----------------------------------------------------------------------===//
+// parsing
+//===----------------------------------------------------------------------===//
+
+static ParseResult parseKeyValuePair(AsmParser &parser,
+                                     DataLayoutEntryInterface &entry,
+                                     bool tryType = false) {
+  Attribute value;
+
+  if (tryType) {
+    Type type;
+    OptionalParseResult parsedType = parser.parseOptionalType(type);
+    if (parsedType.has_value()) {
+      if (failed(parsedType.value()))
+        return parser.emitError(parser.getCurrentLocation())
+               << "error while parsing type DLTI key";
+
+      if (failed(parser.parseEqual()) || failed(parser.parseAttribute(value)))
+        return failure();
+
+      entry = DataLayoutEntryAttr::get(type, value);
+      return ParseResult::success();
+    }
+  }
+
+  std::string ident;
+  OptionalParseResult parsedStr = parser.parseOptionalString(&ident);
+  if (parsedStr.has_value() && !ident.empty()) {
+    if (failed(parsedStr.value()))
+      return parser.emitError(parser.getCurrentLocation())
+             << "error while parsing string DLTI key";
+
+    if (failed(parser.parseEqual()) || failed(parser.parseAttribute(value)))
+      return failure(); // Assume that an error has already been emitted.
+
+    entry = DataLayoutEntryAttr::get(
+        StringAttr::get(parser.getContext(), ident), value);
+    return ParseResult::success();
+  }
+
+  OptionalParseResult parsedEntry = parser.parseAttribute(entry);
+  if (parsedEntry.has_value()) {
+    if (succeeded(parsedEntry.value()))
+      return parsedEntry.value();
+    return failure(); // Assume that an error has already been emitted.
+  }
+  return parser.emitError(parser.getCurrentLocation())
+         << "failed to parse DLTI entry";
+}
+
+template <class Attr>
+static Attribute parseAngleBracketedEntries(AsmParser &parser, Type ty,
+                                            bool tryType = false,
+                                            bool allowEmpty = false) {
+  SmallVector<DataLayoutEntryInterface> entries;
+  if (failed(parser.parseCommaSeparatedList(
+          AsmParser::Delimiter::LessGreater, [&]() {
+            return parseKeyValuePair(parser, entries.emplace_back(), tryType);
+          })))
+    return {};
+
+  if (entries.empty() && !allowEmpty) {
+    parser.emitError(parser.getNameLoc()) << "no DLTI entries provided";
+    return {};
+  }
+
+  return Attr::getChecked([&] { return parser.emitError(parser.getNameLoc()); },
+                          parser.getContext(), ArrayRef(entries));
+}
+
+//===----------------------------------------------------------------------===//
+// printing
+//===----------------------------------------------------------------------===//
+
+static inline std::string keyToStr(DataLayoutEntryKey key) {
+  std::string buf;
+  llvm::TypeSwitch<DataLayoutEntryKey>(key)
+      .Case<StringAttr, Type>( // The only two kinds of key we know of.
+          [&](auto key) { llvm::raw_string_ostream(buf) << key; })
+      .Default([](auto) { llvm_unreachable("unexpected entry key kind"); });
+  return buf;
+}
+
+template <class T>
+static void printAngleBracketedEntries(AsmPrinter &os, T &&entries) {
+  os << "<";
+  llvm::interleaveComma(std::forward<T>(entries), os, [&](auto entry) {
+    os << keyToStr(entry.getKey()) << " = " << entry.getValue();
+  });
+  os << ">";
+}
+
+//===----------------------------------------------------------------------===//
+// verifying
+//===----------------------------------------------------------------------===//
+
+static LogicalResult verifyEntries(function_ref<InFlightDiagnostic()> emitError,
+                                   ArrayRef<DataLayoutEntryInterface> entries,
+                                   bool allowTypes = true) {
+  DenseSet<DataLayoutEntryKey> keys;
+  for (DataLayoutEntryInterface entry : entries) {
+    if (!entry)
+      return emitError() << "contained invalid DLTI entry";
+    DataLayoutEntryKey key = entry.getKey();
+    if (key.isNull())
+      return emitError() << "contained invalid DLTI key";
+    if (!allowTypes && llvm::dyn_cast<Type>(key))
+      return emitError() << "type as DLIT key is not allowed";
+    if (!keys.insert(key).second)
+      return emitError() << "repeated DLTI key: " << keyToStr(key);
+    if (!entry.getValue())
+      return emitError() << "value associated to DLTI key " << keyToStr(key)
+                         << " is invalid";
+  }
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // DataLayoutEntryAttr
 //===----------------------------------------------------------------------===//
@@ -71,15 +189,16 @@ DataLayoutEntryKey DataLayoutEntryAttr::getKey() const {
 Attribute DataLayoutEntryAttr::getValue() const { return getImpl()->value; }
 
 /// Parses an attribute with syntax:
-///   attr ::= `#target.` `dl_entry` `<` (type | quoted-string) `,` attr `>`
-Attribute DataLayoutEntryAttr::parse(AsmParser &parser, Type ty) {
+///   dl-entry-attr ::= `#dlti.` `dl_entry` `<` (type | quoted-string) `,`
+///     attr `>`
+Attribute DataLayoutEntryAttr::parse(AsmParser &parser, Type type) {
   if (failed(parser.parseLess()))
     return {};
 
-  Type type = nullptr;
+  Type typeKey = nullptr;
   std::string identifier;
   SMLoc idLoc = parser.getCurrentLocation();
-  OptionalParseResult parsedType = parser.parseOptionalType(type);
+  OptionalParseResult parsedType = parser.parseOptionalType(typeKey);
   if (parsedType.has_value() && failed(parsedType.value()))
     return {};
   if (!parsedType.has_value()) {
@@ -95,38 +214,29 @@ Attribute DataLayoutEntryAttr::parse(AsmParser &parser, Type ty) {
       failed(parser.parseGreater()))
     return {};
 
-  return type ? get(type, value)
-              : get(parser.getBuilder().getStringAttr(identifier), value);
+  return typeKey ? get(typeKey, value)
+                 : get(parser.getBuilder().getStringAttr(identifier), value);
 }
 
-void DataLayoutEntryAttr::print(AsmPrinter &os) const {
-  os << "<";
-  if (auto type = llvm::dyn_cast_if_present<Type>(getKey()))
-    os << type;
-  else
-    os << "\"" << getKey().get<StringAttr>().strref() << "\"";
-  os << ", " << getValue() << ">";
+void DataLayoutEntryAttr::print(AsmPrinter &printer) const {
+  printer << "<" << keyToStr(getKey()) << ", " << getValue() << ">";
 }
 
 //===----------------------------------------------------------------------===//
 // DLTIMapAttr
 //===----------------------------------------------------------------------===//
 
-static LogicalResult verifyEntries(function_ref<InFlightDiagnostic()> emitError,
-                                   ArrayRef<DataLayoutEntryInterface> entries) {
-  DenseSet<Type> types;
-  DenseSet<StringAttr> ids;
-  for (DataLayoutEntryInterface entry : entries) {
-    if (auto type = llvm::dyn_cast_if_present<Type>(entry.getKey())) {
-      if (!types.insert(type).second)
-        return emitError() << "repeated layout entry key: " << type;
-    } else {
-      auto id = entry.getKey().get<StringAttr>();
-      if (!ids.insert(id).second)
-        return emitError() << "repeated layout entry key: " << id.getValue();
-    }
-  }
-  return success();
+/// Parses an attribute with syntax:
+///   map-attr ::= `#dlti.` `map` `<` entry-list `>`
+///   entry-list ::= entry | entry `,` entry-list
+///   entry ::= ((type | quoted-string) `=` attr) | dl-entry-attr
+Attribute MapAttr::parse(AsmParser &parser, Type type) {
+  return parseAngleBracketedEntries<MapAttr>(parser, type, /*tryType=*/true,
+                                             /*allowEmpty=*/true);
+}
+
+void MapAttr::print(AsmPrinter &printer) const {
+  printAngleBracketedEntries(printer, getEntries());
 }
 
 LogicalResult MapAttr::verify(function_ref<InFlightDiagnostic()> emitError,
@@ -282,98 +392,40 @@ DataLayoutSpecAttr::getStackAlignmentIdentifier(MLIRContext *context) const {
       DLTIDialect::kDataLayoutStackAlignmentKey);
 }
 
-/// Parses an attribute with syntax
-///   attr ::= `#target.` `dl_spec` `<` attr-list? `>`
-///   attr-list ::= attr
-///               | attr `,` attr-list
+/// Parses an attribute with syntax:
+///   dl-spec-attr ::= `#dlti.` `dl_spec` `<` entry-list `>`
+///   entry-list ::= | entry | entry `,` entry-list
+///   entry ::= ((type | quoted-string) = attr) | dl-entry-attr
 Attribute DataLayoutSpecAttr::parse(AsmParser &parser, Type type) {
-  if (failed(parser.parseLess()))
-    return {};
-
-  // Empty spec.
-  if (succeeded(parser.parseOptionalGreater()))
-    return get(parser.getContext(), {});
-
-  SmallVector<DataLayoutEntryInterface> entries;
-  if (parser.parseCommaSeparatedList(
-          [&]() { return parser.parseAttribute(entries.emplace_back()); }) ||
-      parser.parseGreater())
-    return {};
-
-  return getChecked([&] { return parser.emitError(parser.getNameLoc()); },
-                    parser.getContext(), entries);
+  return parseAngleBracketedEntries<DataLayoutSpecAttr>(parser, type,
+                                                        /*tryType=*/true,
+                                                        /*allowEmpty=*/true);
 }
 
-void DataLayoutSpecAttr::print(AsmPrinter &os) const {
-  os << "<";
-  llvm::interleaveComma(getEntries(), os);
-  os << ">";
+void DataLayoutSpecAttr::print(AsmPrinter &printer) const {
+  printAngleBracketedEntries(printer, getEntries());
 }
 
 //===----------------------------------------------------------------------===//
 // TargetDeviceSpecAttr
 //===----------------------------------------------------------------------===//
 
-namespace mlir {
-/// A FieldParser for key-value pairs of DeviceID-target device spec pairs that
-/// make up a target system spec.
-template <>
-struct FieldParser<DeviceIDTargetDeviceSpecPair> {
-  static FailureOr<DeviceIDTargetDeviceSpecPair> parse(AsmParser &parser) {
-    std::string deviceID;
-
-    if (failed(parser.parseString(&deviceID))) {
-      parser.emitError(parser.getCurrentLocation())
-          << "DeviceID is missing, or is not of string type";
-      return failure();
-    }
-
-    if (failed(parser.parseColon())) {
-      parser.emitError(parser.getCurrentLocation()) << "Missing colon";
-      return failure();
-    }
-
-    auto target_device_spec =
-        FieldParser<TargetDeviceSpecInterface>::parse(parser);
-    if (failed(target_device_spec)) {
-      parser.emitError(parser.getCurrentLocation())
-          << "Error in parsing target device spec";
-      return failure();
-    }
-
-    return std::make_pair(parser.getBuilder().getStringAttr(deviceID),
-                          *target_device_spec);
-  }
-};
-
-inline AsmPrinter &operator<<(AsmPrinter &printer,
-                              DeviceIDTargetDeviceSpecPair param) {
-  return printer << param.first << " : " << param.second;
-}
-
-} // namespace mlir
-
 LogicalResult
 TargetDeviceSpecAttr::verify(function_ref<InFlightDiagnostic()> emitError,
                              ArrayRef<DataLayoutEntryInterface> entries) {
-  // Entries in a target device spec can only have StringAttr as key. It does
-  // not support type as a key. Hence not reusing
-  // DataLayoutEntryInterface::verify.
-  DenseSet<StringAttr> ids;
-  for (DataLayoutEntryInterface entry : entries) {
-    if (auto type = llvm::dyn_cast_if_present<Type>(entry.getKey())) {
-      return emitError()
-             << "dlti.target_device_spec does not allow type as a key: "
-             << type;
-    } else {
-      // Check that keys in a target device spec are unique.
-      auto id = entry.getKey().get<StringAttr>();
-      if (!ids.insert(id).second)
-        return emitError() << "repeated layout entry key: " << id.getValue();
-    }
-  }
+  return verifyEntries(emitError, entries, /*allowTypes=*/false);
+}
 
-  return success();
+/// Parses an attribute with syntax:
+///   dev-spec-attr ::= `#dlti.` `target_device_spec` `<` entry-list `>`
+///   entry-list ::= entry | entry `,` entry-list
+///   entry ::= (quoted-string `=` attr) | dl-entry-attr
+Attribute TargetDeviceSpecAttr::parse(AsmParser &parser, Type type) {
+  return parseAngleBracketedEntries<TargetDeviceSpecAttr>(parser, type);
+}
+
+void TargetDeviceSpecAttr::print(AsmPrinter &printer) const {
+  printAngleBracketedEntries(printer, getEntries());
 }
 
 //===----------------------------------------------------------------------===//
@@ -382,27 +434,46 @@ TargetDeviceSpecAttr::verify(function_ref<InFlightDiagnostic()> emitError,
 
 LogicalResult
 TargetSystemSpecAttr::verify(function_ref<InFlightDiagnostic()> emitError,
-                             ArrayRef<DeviceIDTargetDeviceSpecPair> entries) {
-  DenseSet<TargetSystemSpecInterface::DeviceID> device_ids;
+                             ArrayRef<DataLayoutEntryInterface> entries) {
+  DenseSet<TargetSystemSpecInterface::DeviceID> deviceIds;
 
   for (const auto &entry : entries) {
-    TargetDeviceSpecInterface target_device_spec = entry.second;
-
-    // First verify that a target device spec is valid.
-    if (failed(TargetDeviceSpecAttr::verify(emitError,
-                                            target_device_spec.getEntries())))
-      return failure();
+    auto deviceId =
+        llvm::dyn_cast<TargetSystemSpecInterface::DeviceID>(entry.getKey());
+    if (!deviceId)
+      return emitError() << "non-string key of DLTI system spec";
+
+    if (auto targetDeviceSpec...
[truncated]

Copy link

github-actions bot commented Oct 22, 2024

✅ With the latest revision this PR passed the C/C++ code formatter.

Copy link
Contributor

@AntonLydike AntonLydike left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Really cool stuff! Great upgrade to usability!

Copy link
Contributor

@Dinistro Dinistro left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Dropped a few nit comments, but this should definitely be reviewed by @ftynse

mlir/lib/Interfaces/DataLayoutInterfaces.cpp Outdated Show resolved Hide resolved
mlir/lib/Dialect/DLTI/DLTI.cpp Show resolved Hide resolved
mlir/lib/Dialect/DLTI/DLTI.cpp Show resolved Hide resolved
mlir/lib/Dialect/DLTI/DLTI.cpp Outdated Show resolved Hide resolved
mlir/lib/Dialect/DLTI/DLTI.cpp Show resolved Hide resolved
mlir/lib/Dialect/DLTI/DLTI.cpp Outdated Show resolved Hide resolved
mlir/lib/Dialect/DLTI/DLTI.cpp Show resolved Hide resolved
mlir/include/mlir/Dialect/DLTI/DLTIAttrs.td Outdated Show resolved Hide resolved
mlir/include/mlir/Interfaces/DataLayoutInterfaces.td Outdated Show resolved Hide resolved
mlir/lib/Dialect/DLTI/DLTI.cpp Show resolved Hide resolved
mlir/lib/Dialect/DLTI/DLTI.cpp Outdated Show resolved Hide resolved
mlir/lib/Dialect/DLTI/DLTI.cpp Show resolved Hide resolved
mlir/lib/Dialect/DLTI/DLTI.cpp Outdated Show resolved Hide resolved
mlir/lib/Dialect/DLTI/DLTI.cpp Show resolved Hide resolved
@ftynse
Copy link
Member

ftynse commented Oct 28, 2024

LGTM with comments addressed.

@llvmbot llvmbot added the flang Flang issues not falling into any other category label Oct 30, 2024
@rolfmorel
Copy link
Contributor Author

Thank you, @Dinistro and @ftynse for the review (and @AntonLydike for the approval)!

Everything should be addressed now.

I will wait a day or two to see if there are any further remarks and otherwise proceed with the merge.

Thanks again!

Copy link
Contributor

@Dinistro Dinistro left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, thanks for addressing the comments.

mlir/lib/Dialect/DLTI/DLTI.cpp Outdated Show resolved Hide resolved
Unifies parsing and printing for DLTI attributes. Introduces syntax of
`#dlti.attr<key1 = val1, ..., keyN = valN>` for all queryable DLTI
 attributes, while retaining support for specifying key-value entry
pairs with `#dlti.dl_entry` (whether to retain this is TBD).

As the new format does away with much of the boilerplate, it is much
easier on the eye. This makes an especially big difference for nested
attributes.

Updates the DLTI tests and includes fixes for misc error checking/
error messages.
@rolfmorel rolfmorel merged commit 5c1752e into llvm:main Oct 31, 2024
5 of 7 checks passed
smallp-o-p pushed a commit to smallp-o-p/llvm-project that referenced this pull request Nov 3, 2024
Unifies parsing and printing for DLTI attributes. Introduces a format of
`#dlti.attr<key1 = val1, ..., keyN = valN>` syntax for all queryable
DLTI attributes similar to that of the DictionaryAttr, while retaining
support for specifying key-value pairs with `#dlti.dl_entry` (whether to
retain this is TBD).

As the new format does away with most of the boilerplate, it is much easier
to parse for humans. This makes an especially big difference for nested
attributes.

Updates the DLTI-using tests and includes fixes for misc error checking/
error messages.
smallp-o-p pushed a commit to smallp-o-p/llvm-project that referenced this pull request Nov 3, 2024
Unifies parsing and printing for DLTI attributes. Introduces a format of
`#dlti.attr<key1 = val1, ..., keyN = valN>` syntax for all queryable
DLTI attributes similar to that of the DictionaryAttr, while retaining
support for specifying key-value pairs with `#dlti.dl_entry` (whether to
retain this is TBD).

As the new format does away with most of the boilerplate, it is much easier
to parse for humans. This makes an especially big difference for nested
attributes.

Updates the DLTI-using tests and includes fixes for misc error checking/
error messages.
NoumanAmir657 pushed a commit to NoumanAmir657/llvm-project that referenced this pull request Nov 4, 2024
Unifies parsing and printing for DLTI attributes. Introduces a format of
`#dlti.attr<key1 = val1, ..., keyN = valN>` syntax for all queryable
DLTI attributes similar to that of the DictionaryAttr, while retaining
support for specifying key-value pairs with `#dlti.dl_entry` (whether to
retain this is TBD).

As the new format does away with most of the boilerplate, it is much easier
to parse for humans. This makes an especially big difference for nested
attributes.

Updates the DLTI-using tests and includes fixes for misc error checking/
error messages.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
flang Flang issues not falling into any other category mlir:dlti mlir:gpu mlir:llvm mlir
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants