diff --git a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp index 5c8f6ded39ba4ed..163481069be42db 100644 --- a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp +++ b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp @@ -54,6 +54,14 @@ void dispatchIndexOpFoldResult(OpFoldResult ofr, staticVec.push_back(apInt.getSExtValue()); return; } + + OpFoldResult result = getAsOpFoldResult(v); + if (auto attr = result.dyn_cast()) { + APInt apInt = cast(attr).getValue(); + staticVec.push_back(apInt.getSExtValue()); + return; + } + dynamicVec.push_back(v); staticVec.push_back(ShapedType::kDynamic); } diff --git a/mlir/test/Dialect/Tensor/bubble-reshapes.mlir b/mlir/test/Dialect/Tensor/bubble-reshapes.mlir index cf6b12852bcd39c..15bc9b0435f6e6c 100644 --- a/mlir/test/Dialect/Tensor/bubble-reshapes.mlir +++ b/mlir/test/Dialect/Tensor/bubble-reshapes.mlir @@ -20,6 +20,26 @@ func.func @bubble_parallel_reshapes(%arg0: tensor, %s0: index, %s1: // ----- +func.func @bubble_parallel_reshapes2(%arg0: tensor, %s0: index, %s1: index) -> tensor { + %c2 = arith.constant 2 : index + %c3 = arith.constant 3 : index + %collapse = tensor.collapse_shape %arg0 [[0], [1, 2], [3]] : tensor into tensor + %expand = tensor.expand_shape %collapse [[0], [1], [2, 3]] + output_shape [%s0, %s1, %c2, %c3] : tensor into tensor + return %expand : tensor +} +// CHECK: func @bubble_parallel_reshapes2 +// CHECK-SAME: %[[ARG0:.+]]: tensor +// CHECK-SAME: %[[S0:.+]]: index, %[[S1:.+]]: index +// CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index +// CHECK-DAG: %[[C3:.+]] = arith.constant 3 : index +// CHECK: %[[EXPAND:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0], [1], [2], [3, 4]] +// CHECK-SAME: output_shape [%[[S0]], 2, 2, %[[C2]], %[[C3]]] : tensor into tensor +// CHECK: %[[COLLAPSE:.+]] = tensor.collapse_shape %[[EXPAND]] {{\[}}[0], [1, 2], [3], [4]] : tensor into tensor +// CHECK: return %[[COLLAPSE]] + +// ----- + func.func @no_bubble_full_intersecting_reshapes(%arg0: tensor, %s0: index, %s1: index, %s2: index, %s3: index) -> tensor { %collapse = tensor.collapse_shape %arg0 [[0], [1, 2], [3]] : tensor into tensor %expand = tensor.expand_shape %collapse [[0], [1, 2], [3]]