diff --git a/stablehlo/dialect/TypeInference.cpp b/stablehlo/dialect/TypeInference.cpp index a7133ae2139..31af9fe3d68 100644 --- a/stablehlo/dialect/TypeInference.cpp +++ b/stablehlo/dialect/TypeInference.cpp @@ -1891,6 +1891,7 @@ LogicalResult inferDotGeneralOp( return emitOptionalError(location, "lhs and rhs should have the same " "number of batching dimensions"); + // dot_general_c2 if (lhsContractingDimensions.size() != rhsContractingDimensions.size()) return emitOptionalError(location, @@ -1911,6 +1912,7 @@ LogicalResult inferDotGeneralOp( } return success(); }; + // dot_general_c3 if (failed(checkDimsDistinct(lhsBatchingDimensions, lhsContractingDimensions, dimSet, "lhs_batching_dimensions", @@ -1934,11 +1936,11 @@ LogicalResult inferDotGeneralOp( " is out of range: ", "[0, ", rank, ")"); return success(); }; + auto lhsRankedType = lhsType.dyn_cast(); - auto rhsRankedType = rhsType.dyn_cast(); - // dot_general_c5 - // dot_general_c6 if (lhsRankedType) { + // dot_general_c5 + // dot_general_c6 if (failed(checkDimsInRange(lhsRankedType.getRank(), lhsBatchingDimensions, "lhs_batching_dimensions")) || failed(checkDimsInRange(lhsRankedType.getRank(), @@ -1946,9 +1948,11 @@ LogicalResult inferDotGeneralOp( "lhs_contracting_dimensions"))) return failure(); } - // dot_general_c7 - // dot_general_c8 + + auto rhsRankedType = rhsType.dyn_cast(); if (rhsRankedType) { + // dot_general_c7 + // dot_general_c8 if (failed(checkDimsInRange(rhsRankedType.getRank(), rhsBatchingDimensions, "rhs_batching_dimensions")) || failed(checkDimsInRange(rhsRankedType.getRank(), @@ -1999,6 +2003,7 @@ LogicalResult inferDotGeneralOp( if (!llvm::is_contained(rhsBatchingDimensions, i) && !llvm::is_contained(rhsContractingDimensions, i)) dimensions.push_back(rhsShape[i]); + // dot_general_c12 inferredReturnShapes.emplace_back(dimensions); return success(); diff --git a/stablehlo/tests/ops_stablehlo.mlir b/stablehlo/tests/ops_stablehlo.mlir index 8d743120301..a6ebfd781b7 100644 --- a/stablehlo/tests/ops_stablehlo.mlir +++ b/stablehlo/tests/ops_stablehlo.mlir @@ -3006,36 +3006,6 @@ func.func @dot_general_c2(%arg0: tensor, %arg1: tensor) -> // ----- -func.func @dot_general_c3(%arg0: tensor, %arg1: tensor) -> tensor { - // expected-error @+1 {{has duplicated dimension from lhs_batching_dimensions and lhs_contracting_dimensions: 0}} - %0 = "stablehlo.dot_general"(%arg0, %arg1) { - dot_dimension_numbers = #stablehlo.dot< - lhs_batching_dimensions = [0, 0], - rhs_batching_dimensions = [0, 0], - lhs_contracting_dimensions = [1], - rhs_contracting_dimensions = [1] - > - } : (tensor, tensor) -> tensor - func.return %0 : tensor -} - -// ----- - -func.func @dot_general_c3(%arg0: tensor, %arg1: tensor) -> tensor { - // expected-error @+1 {{has duplicated dimension from lhs_batching_dimensions and lhs_contracting_dimensions: 1}} - %0 = "stablehlo.dot_general"(%arg0, %arg1) { - dot_dimension_numbers = #stablehlo.dot< - lhs_batching_dimensions = [0], - rhs_batching_dimensions = [0], - lhs_contracting_dimensions = [1, 1], - rhs_contracting_dimensions = [1, 1] - > - } : (tensor, tensor) -> tensor - func.return %0 : tensor -} - -// ----- - func.func @dot_general_c3(%arg0: tensor, %arg1: tensor) -> tensor { // expected-error @+1 {{has duplicated dimension from lhs_batching_dimensions and lhs_contracting_dimensions: 0}} %0 = "stablehlo.dot_general"(%arg0, %arg1) {