Skip to content

Commit

Permalink
Handle arith.const in dispatchIndexOpFoldResult func
Browse files Browse the repository at this point in the history
Change-Id: I15280932f88d8ff638f5d0f964a1c03ce7a7881a
  • Loading branch information
rutkoor committed Jan 10, 2025
1 parent eeac0ff commit 533c396
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 0 deletions.
8 changes: 8 additions & 0 deletions mlir/lib/Dialect/Utils/StaticValueUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,14 @@ void dispatchIndexOpFoldResult(OpFoldResult ofr,
staticVec.push_back(apInt.getSExtValue());
return;
}

OpFoldResult result = getAsOpFoldResult(v);
if (auto attr = result.dyn_cast<Attribute>()) {
APInt apInt = cast<IntegerAttr>(attr).getValue();
staticVec.push_back(apInt.getSExtValue());
return;
}

dynamicVec.push_back(v);
staticVec.push_back(ShapedType::kDynamic);
}
Expand Down
20 changes: 20 additions & 0 deletions mlir/test/Dialect/Tensor/bubble-reshapes.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,26 @@ func.func @bubble_parallel_reshapes(%arg0: tensor<?x?x?x?xf32>, %s0: index, %s1:

// -----

func.func @bubble_parallel_reshapes2(%arg0: tensor<?x2x2x6xf32>, %s0: index, %s1: index) -> tensor<?x4x2x3xf32> {
%c2 = arith.constant 2 : index
%c3 = arith.constant 3 : index
%collapse = tensor.collapse_shape %arg0 [[0], [1, 2], [3]] : tensor<?x2x2x6xf32> into tensor<?x4x6xf32>
%expand = tensor.expand_shape %collapse [[0], [1], [2, 3]]
output_shape [%s0, %s1, %c2, %c3] : tensor<?x4x6xf32> into tensor<?x4x2x3xf32>
return %expand : tensor<?x4x2x3xf32>
}
// CHECK: func @bubble_parallel_reshapes2
// CHECK-SAME: %[[ARG0:.+]]: tensor<?x2x2x6xf32>
// CHECK-SAME: %[[S0:.+]]: index, %[[S1:.+]]: index
// CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index
// CHECK-DAG: %[[C3:.+]] = arith.constant 3 : index
// CHECK: %[[EXPAND:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0], [1], [2], [3, 4]]
// CHECK-SAME: output_shape [%[[S0]], 2, 2, %[[C2]], %[[C3]]] : tensor<?x2x2x6xf32> into tensor<?x2x2x2x3xf32>
// CHECK: %[[COLLAPSE:.+]] = tensor.collapse_shape %[[EXPAND]] {{\[}}[0], [1, 2], [3], [4]] : tensor<?x2x2x2x3xf32> into tensor<?x4x2x3xf32>
// CHECK: return %[[COLLAPSE]]

// -----

func.func @no_bubble_full_intersecting_reshapes(%arg0: tensor<?x?x?x?xf32>, %s0: index, %s1: index, %s2: index, %s3: index) -> tensor<?x?x?x?xf32> {
%collapse = tensor.collapse_shape %arg0 [[0], [1, 2], [3]] : tensor<?x?x?x?xf32> into tensor<?x?x?xf32>
%expand = tensor.expand_shape %collapse [[0], [1, 2], [3]]
Expand Down

0 comments on commit 533c396

Please sign in to comment.