Skip to content

Commit

Permalink
[MLIR][DLTI] Pretty parsing and printing for DLTI attrs (llvm#113365)
Browse files Browse the repository at this point in the history
Unifies parsing and printing for DLTI attributes. Introduces a format of
`#dlti.attr<key1 = val1, ..., keyN = valN>` syntax for all queryable
DLTI attributes similar to that of the DictionaryAttr, while retaining
support for specifying key-value pairs with `#dlti.dl_entry` (whether to
retain this is TBD).

As the new format does away with most of the boilerplate, it is much easier
to parse for humans. This makes an especially big difference for nested
attributes.

Updates the DLTI-using tests and includes fixes for misc error checking/
error messages.
  • Loading branch information
rolfmorel authored and smallp-o-p committed Nov 3, 2024
1 parent aacbb55 commit 90ff785
Show file tree
Hide file tree
Showing 14 changed files with 538 additions and 413 deletions.
2 changes: 1 addition & 1 deletion flang/test/Fir/tco-default-datalayout.fir
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,6 @@ module {
// CHECK: module attributes {
// CHECK-SAME: dlti.dl_spec = #dlti.dl_spec<
// ...
// CHECK-SAME: #dlti.dl_entry<i64, dense<[32, 64]> : vector<2xi64>>,
// CHECK-SAME: i64 = dense<[32, 64]> : vector<2xi64>,
// ...
// CHECK-SAME: llvm.data_layout = ""
2 changes: 1 addition & 1 deletion flang/test/Fir/tco-explicit-datalayout.fir
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,6 @@ module attributes {llvm.data_layout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i6
// CHECK: module attributes {
// CHECK-SAME: dlti.dl_spec = #dlti.dl_spec<
// ...
// CHECK-SAME: #dlti.dl_entry<i64, dense<128> : vector<2xi64>>,
// CHECK-SAME: i64 = dense<128> : vector<2xi64>,
// ...
// CHECK-SAME: llvm.data_layout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:128-i128:128-f80:128-n8:16:32:64-S128"
57 changes: 31 additions & 26 deletions mlir/include/mlir/Dialect/DLTI/DLTIAttrs.td
Original file line number Diff line number Diff line change
Expand Up @@ -88,11 +88,15 @@ def DLTI_DataLayoutSpecAttr :

/// Returns the attribute associated with the key.
FailureOr<Attribute> query(DataLayoutEntryKey key) {
return llvm::cast<mlir::DataLayoutSpecInterface>(*this).queryHelper(key);
return ::llvm::cast<mlir::DataLayoutSpecInterface>(*this).queryHelper(key);
}
}];
}

//===----------------------------------------------------------------------===//
// MapAttr
//===----------------------------------------------------------------------===//

def DLTI_MapAttr : DLTIAttr<"Map", [DLTIQueryInterface]> {
let summary = "A mapping of DLTI-information by way of key-value pairs";
let description = [{
Expand All @@ -106,18 +110,16 @@ def DLTI_MapAttr : DLTIAttr<"Map", [DLTIQueryInterface]> {

Consider the following flat encoding of a single-key dictionary
```
#dlti.map<#dlti.dl_entry<"CPU::cache::L1::size_in_bytes", 65536 : i32>>
#dlti.map<"CPU::cache::L1::size_in_bytes" = 65536 : i32>>
```
versus nested maps, which make it possible to obtain sub-dictionaries of
related information (with the following example making use of other
attributes that also implement the `DLTIQueryInterface`):
```
#dlti.target_system_spec<"CPU":
#dlti.target_device_spec<#dlti.dl_entry<"cache",
#dlti.map<#dlti.dl_entry<"L1",
#dlti.map<#dlti.dl_entry<"size_in_bytes", 65536 : i32>>>,
#dlti.dl_entry<"L1d",
#dlti.map<#dlti.dl_entry<"size_in_bytes", 32768 : i32>>> >>>>
#dlti.target_system_spec<"CPU" =
#dlti.target_device_spec<"cache" =
#dlti.map<"L1" = #dlti.map<"size_in_bytes" = 65536 : i32>,
"L1d" = #dlti.map<"size_in_bytes" = 32768 : i32> >>>
```

With the flat encoding, the implied structure of the key is ignored, that is
Expand All @@ -132,14 +134,13 @@ def DLTI_MapAttr : DLTIAttr<"Map", [DLTIQueryInterface]> {
`transform.dlti.query ["CPU","cache","L1","size_in_bytes"] at %op` gives
back the first leaf value contained. To access the other leaf, we need to do
`transform.dlti.query ["CPU","cache","L1d","size_in_bytes"] at %op`.
```
}];
let parameters = (ins
ArrayRefParameter<"DataLayoutEntryInterface", "">:$entries
);
let mnemonic = "map";
let genVerifyDecl = 1;
let assemblyFormat = "`<` $entries `>`";
let hasCustomAssemblyFormat = 1;
let extraClassDeclaration = [{
/// Returns the attribute associated with the key.
FailureOr<Attribute> query(DataLayoutEntryKey key) {
Expand Down Expand Up @@ -167,20 +168,23 @@ def DLTI_TargetSystemSpecAttr :
```
dlti.target_system_spec =
#dlti.target_system_spec<
"CPU": #dlti.target_device_spec<
#dlti.dl_entry<"dlti.L1_cache_size_in_bytes", 4096: ui32>>,
"GPU": #dlti.target_device_spec<
#dlti.dl_entry<"dlti.max_vector_op_width", 64 : ui32>>,
"XPU": #dlti.target_device_spec<
#dlti.dl_entry<"dlti.max_vector_op_width", 4096 : ui32>>>
"CPU" = #dlti.target_device_spec<
"L1_cache_size_in_bytes" = 4096: ui32>,
"GPU" = #dlti.target_device_spec<
"max_vector_op_width" = 64 : ui32>,
"XPU" = #dlti.target_device_spec<
"max_vector_op_width" = 4096 : ui32>>
```

The verifier checks that keys are strings and pointed to values implement
DLTI's TargetDeviceSpecInterface.
}];
let parameters = (ins
ArrayRefParameter<"DeviceIDTargetDeviceSpecPair", "">:$entries
ArrayRefParameter<"DataLayoutEntryInterface">:$entries
);
let mnemonic = "target_system_spec";
let genVerifyDecl = 1;
let assemblyFormat = "`<` $entries `>`";
let hasCustomAssemblyFormat = 1;
let extraClassDeclaration = [{
/// Return the device specification that matches the given device ID
std::optional<TargetDeviceSpecInterface>
Expand All @@ -189,16 +193,18 @@ def DLTI_TargetSystemSpecAttr :

/// Returns the attribute associated with the key.
FailureOr<Attribute> query(DataLayoutEntryKey key) const {
return llvm::cast<mlir::TargetSystemSpecInterface>(*this).queryHelper(key);
return ::llvm::cast<mlir::TargetSystemSpecInterface>(*this).queryHelper(key);
}
}];
let extraClassDefinition = [{
std::optional<TargetDeviceSpecInterface>
$cppClass::getDeviceSpecForDeviceID(
TargetSystemSpecInterface::DeviceID deviceID) {
for (const auto& entry : getEntries()) {
if (entry.first == deviceID)
return entry.second;
if (entry.getKey() == DataLayoutEntryKey(deviceID))
if (auto deviceSpec =
::llvm::dyn_cast<TargetDeviceSpecInterface>(entry.getValue()))
return deviceSpec;
}
return std::nullopt;
}
Expand All @@ -219,21 +225,20 @@ def DLTI_TargetDeviceSpecAttr :

Example:
```
#dlti.target_device_spec<
#dlti.dl_entry<"dlti.max_vector_op_width", 64 : ui32>>
#dlti.target_device_spec<"max_vector_op_width" = 64 : ui32>
```
}];
let parameters = (ins
ArrayRefParameter<"DataLayoutEntryInterface", "">:$entries
ArrayRefParameter<"DataLayoutEntryInterface">:$entries
);
let mnemonic = "target_device_spec";
let genVerifyDecl = 1;
let assemblyFormat = "`<` $entries `>`";
let hasCustomAssemblyFormat = 1;

let extraClassDeclaration = [{
/// Returns the attribute associated with the key.
FailureOr<Attribute> query(DataLayoutEntryKey key) const {
return llvm::cast<mlir::TargetDeviceSpecInterface>(*this).queryHelper(key);
return ::llvm::cast<mlir::TargetDeviceSpecInterface>(*this).queryHelper(key);
}
}];
}
Expand Down
6 changes: 2 additions & 4 deletions mlir/include/mlir/Interfaces/DataLayoutInterfaces.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#ifndef MLIR_INTERFACES_DATALAYOUTINTERFACES_H
#define MLIR_INTERFACES_DATALAYOUTINTERFACES_H

#include "mlir/IR/Attributes.h"
#include "mlir/IR/DialectInterface.h"
#include "mlir/IR/OpDefinition.h"
#include "llvm/ADT/DenseMap.h"
Expand All @@ -32,10 +33,7 @@ using DataLayoutEntryKey = llvm::PointerUnion<Type, StringAttr>;
using DataLayoutEntryList = llvm::SmallVector<DataLayoutEntryInterface, 4>;
using DataLayoutEntryListRef = llvm::ArrayRef<DataLayoutEntryInterface>;
using TargetDeviceSpecListRef = llvm::ArrayRef<TargetDeviceSpecInterface>;
using DeviceIDTargetDeviceSpecPair =
std::pair<StringAttr, TargetDeviceSpecInterface>;
using DeviceIDTargetDeviceSpecPairListRef =
llvm::ArrayRef<DeviceIDTargetDeviceSpecPair>;
using TargetDeviceSpecEntry = std::pair<StringAttr, TargetDeviceSpecInterface>;
class DataLayoutOpInterface;
class DataLayoutSpecInterface;
class ModuleOp;
Expand Down
6 changes: 3 additions & 3 deletions mlir/include/mlir/Interfaces/DataLayoutInterfaces.td
Original file line number Diff line number Diff line change
Expand Up @@ -276,7 +276,7 @@ def TargetDeviceSpecInterface : AttrInterface<"TargetDeviceSpecInterface", [DLTI
/// Helper for default implementation of `DLTIQueryInterface`'s `query`.
::mlir::FailureOr<::mlir::Attribute>
queryHelper(::mlir::DataLayoutEntryKey key) const {
if (auto strKey = llvm::dyn_cast<StringAttr>(key))
if (auto strKey = ::llvm::dyn_cast<StringAttr>(key))
if (DataLayoutEntryInterface spec = getSpecForIdentifier(strKey))
return spec.getValue();
return ::mlir::failure();
Expand Down Expand Up @@ -304,7 +304,7 @@ def TargetSystemSpecInterface : AttrInterface<"TargetSystemSpecInterface", [DLTI
let methods = [
InterfaceMethod<
/*description=*/"Returns the list of layout entries.",
/*retTy=*/"llvm::ArrayRef<DeviceIDTargetDeviceSpecPair>",
/*retTy=*/"::llvm::ArrayRef<DataLayoutEntryInterface>",
/*methodName=*/"getEntries",
/*args=*/(ins)
>,
Expand Down Expand Up @@ -334,7 +334,7 @@ def TargetSystemSpecInterface : AttrInterface<"TargetSystemSpecInterface", [DLTI
/// Helper for default implementation of `DLTIQueryInterface`'s `query`.
::mlir::FailureOr<::mlir::Attribute>
queryHelper(::mlir::DataLayoutEntryKey key) const {
if (auto strKey = llvm::dyn_cast<::mlir::StringAttr>(key))
if (auto strKey = ::llvm::dyn_cast<::mlir::StringAttr>(key))
if (auto deviceSpec = getDeviceSpecForDeviceID(strKey))
return *deviceSpec;
return ::mlir::failure();
Expand Down
Loading

0 comments on commit 90ff785

Please sign in to comment.