Skip to content

Commit

Permalink
Address feedback
Browse files Browse the repository at this point in the history
  • Loading branch information
ghpvnist committed Jun 9, 2023
1 parent f88a6fd commit 0b21091
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 35 deletions.
15 changes: 10 additions & 5 deletions stablehlo/dialect/TypeInference.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1891,6 +1891,7 @@ LogicalResult inferDotGeneralOp(
return emitOptionalError(location,
"lhs and rhs should have the same "
"number of batching dimensions");

// dot_general_c2
if (lhsContractingDimensions.size() != rhsContractingDimensions.size())
return emitOptionalError(location,
Expand All @@ -1911,6 +1912,7 @@ LogicalResult inferDotGeneralOp(
}
return success();
};

// dot_general_c3
if (failed(checkDimsDistinct(lhsBatchingDimensions, lhsContractingDimensions,
dimSet, "lhs_batching_dimensions",
Expand All @@ -1934,21 +1936,23 @@ LogicalResult inferDotGeneralOp(
" is out of range: ", "[0, ", rank, ")");
return success();
};

auto lhsRankedType = lhsType.dyn_cast<RankedTensorType>();
auto rhsRankedType = rhsType.dyn_cast<RankedTensorType>();
// dot_general_c5
// dot_general_c6
if (lhsRankedType) {
// dot_general_c5
// dot_general_c6
if (failed(checkDimsInRange(lhsRankedType.getRank(), lhsBatchingDimensions,
"lhs_batching_dimensions")) ||
failed(checkDimsInRange(lhsRankedType.getRank(),
lhsContractingDimensions,
"lhs_contracting_dimensions")))
return failure();
}
// dot_general_c7
// dot_general_c8

auto rhsRankedType = rhsType.dyn_cast<RankedTensorType>();
if (rhsRankedType) {
// dot_general_c7
// dot_general_c8
if (failed(checkDimsInRange(rhsRankedType.getRank(), rhsBatchingDimensions,
"rhs_batching_dimensions")) ||
failed(checkDimsInRange(rhsRankedType.getRank(),
Expand Down Expand Up @@ -1999,6 +2003,7 @@ LogicalResult inferDotGeneralOp(
if (!llvm::is_contained(rhsBatchingDimensions, i) &&
!llvm::is_contained(rhsContractingDimensions, i))
dimensions.push_back(rhsShape[i]);

// dot_general_c12
inferredReturnShapes.emplace_back(dimensions);
return success();
Expand Down
30 changes: 0 additions & 30 deletions stablehlo/tests/ops_stablehlo.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -3006,36 +3006,6 @@ func.func @dot_general_c2(%arg0: tensor<?x?x?xf32>, %arg1: tensor<?x?x?xf32>) ->

// -----

func.func @dot_general_c3(%arg0: tensor<?x?x?xf32>, %arg1: tensor<?x?x?xf32>) -> tensor<?x?x?xf32> {
// expected-error @+1 {{has duplicated dimension from lhs_batching_dimensions and lhs_contracting_dimensions: 0}}
%0 = "stablehlo.dot_general"(%arg0, %arg1) {
dot_dimension_numbers = #stablehlo.dot<
lhs_batching_dimensions = [0, 0],
rhs_batching_dimensions = [0, 0],
lhs_contracting_dimensions = [1],
rhs_contracting_dimensions = [1]
>
} : (tensor<?x?x?xf32>, tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
func.return %0 : tensor<?x?x?xf32>
}

// -----

func.func @dot_general_c3(%arg0: tensor<?x?x?xf32>, %arg1: tensor<?x?x?xf32>) -> tensor<?x?x?xf32> {
// expected-error @+1 {{has duplicated dimension from lhs_batching_dimensions and lhs_contracting_dimensions: 1}}
%0 = "stablehlo.dot_general"(%arg0, %arg1) {
dot_dimension_numbers = #stablehlo.dot<
lhs_batching_dimensions = [0],
rhs_batching_dimensions = [0],
lhs_contracting_dimensions = [1, 1],
rhs_contracting_dimensions = [1, 1]
>
} : (tensor<?x?x?xf32>, tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
func.return %0 : tensor<?x?x?xf32>
}

// -----

func.func @dot_general_c3(%arg0: tensor<?x?x?xf32>, %arg1: tensor<?x?x?xf32>) -> tensor<?x?x?xf32> {
// expected-error @+1 {{has duplicated dimension from lhs_batching_dimensions and lhs_contracting_dimensions: 0}}
%0 = "stablehlo.dot_general"(%arg0, %arg1) {
Expand Down

0 comments on commit 0b21091

Please sign in to comment.