Add interpreter for DotGeneralOp
ghpvnist committed Feb 7, 2023
1 parent 34914e2 commit 76185f8
Showing 9 changed files with 333 additions and 78 deletions.
2 changes: 1 addition & 1 deletion docs/status.md
@@ -77,7 +77,7 @@ one of the following tracking labels.
| custom_call | yes | yes | infeasible | yes | no |
| divide | yes | yes | yes | yes | no |
| dot | no | revisit | infeasible | yes | no |
| dot_general | yes | revisit | infeasible | no | no |
| dot_general | yes | revisit | infeasible | no | yes |
| dynamic_broadcast_in_dim | no | revisit | infeasible | no | no |
| dynamic_conv | no | revisit | no | no | no |
| dynamic_gather | no | revisit | revisit | no | no |
2 changes: 1 addition & 1 deletion stablehlo/dialect/StablehloOps.td
@@ -2200,7 +2200,7 @@ def StableHLO_DotOp: StableHLO_Op<"dot", [Pure]> {
}

def StableHLO_DotGeneralOp: StableHLO_ShapedInterfaceOp<"dot_general", [Pure]> {
let summary = "General Dot operator";
let summary = "DotGeneral operation";
let description = [{
Computes dot products between slices of `lhs` and slices of `rhs` and
produces a `result` tensor.
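For the common configuration used in most of the tests added below (batching dimensions [0]/[0], contracting dimensions [2]/[1]), `dot_general` reduces to a batched matrix multiply: `result[b][i][j] = sum_k lhs[b][i][k] * rhs[b][k][j]`. A minimal standalone C++ sketch of just that special case, reusing the input values from the `ui64`/`f64` tests; the variable names here are illustrative and not part of the commit:

```cpp
// Informal illustration: dot_general as a batched matrix multiply for
// batching dims [0]/[0] and contracting dims [2]/[1]. Inputs mirror the
// ui64/f64 interpreter tests added in this commit.
#include <cstdio>

int main() {
  const int B = 2, M = 2, K = 2, N = 2;
  long lhs[B][M][K] = {{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}};
  long rhs[B][K][N] = {{{1, 0}, {0, 1}}, {{1, 0}, {0, 1}}};  // identity per batch
  long result[B][M][N] = {};
  for (int b = 0; b < B; ++b)
    for (int i = 0; i < M; ++i)
      for (int j = 0; j < N; ++j)
        for (int k = 0; k < K; ++k)
          // result[b][i][j] accumulates the dot product of an lhs row slice
          // and an rhs column slice within the same batch b.
          result[b][i][j] += lhs[b][i][k] * rhs[b][k][j];
  for (int b = 0; b < B; ++b)
    for (int i = 0; i < M; ++i)
      for (int j = 0; j < N; ++j)
        std::printf("%ld ", result[b][i][j]);  // prints: 1 2 3 4 5 6 7 8
  std::printf("\n");
  return 0;
}
```

With an identity `rhs` in each batch, the result reproduces `lhs`, which is exactly what those tests check.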
15 changes: 11 additions & 4 deletions stablehlo/dialect/TypeInference.cpp
@@ -1900,10 +1900,12 @@ LogicalResult inferDotGeneralOp(
if (failed(verifyPrecisionConfig(location, precisionConfig)))
return failure();

// dot_general_c2
if (lhsBatchingDimensions.size() != rhsBatchingDimensions.size())
return emitOptionalError(location,
"lhs and rhs should have the same "
"number of batching dimensions");
// dot_general_c3
if (lhsContractingDimensions.size() != rhsContractingDimensions.size())
return emitOptionalError(location,
"lhs and rhs should have the same "
@@ -1923,14 +1925,14 @@ LogicalResult inferDotGeneralOp(
}
return success();
};

// dot_general_c4
if (failed(checkDimsDistinct(lhsBatchingDimensions, lhsContractingDimensions,
dimSet, "lhs_batching_dimensions",
"lhs_contracting_dimensions")))
return failure();

dimSet.clear();

// dot_general_c5
if (failed(checkDimsDistinct(rhsBatchingDimensions, rhsContractingDimensions,
dimSet, "rhs_batching_dimensions",
"rhs_contracting_dimensions")))
@@ -1948,7 +1950,8 @@
};
auto lhsRankedType = lhsType.dyn_cast<RankedTensorType>();
auto rhsRankedType = rhsType.dyn_cast<RankedTensorType>();

// dot_general_c6
// dot_general_c7
if (lhsRankedType) {
if (failed(checkDimsInRange(lhsRankedType.getRank(), lhsBatchingDimensions,
"lhs_batching_dimensions")) ||
@@ -1957,6 +1960,8 @@
"lhs_contracting_dimensions")))
return failure();
}
// dot_general_c8
// dot_general_c9
if (rhsRankedType) {
if (failed(checkDimsInRange(rhsRankedType.getRank(), rhsBatchingDimensions,
"rhs_batching_dimensions")) ||
@@ -1972,6 +1977,7 @@

for (auto [lhs, rhs] :
llvm::zip(lhsBatchingDimensions, rhsBatchingDimensions)) {
// dot_general_c10
if (!verifyCompatibleDims(lhsShape[lhs], rhsShape[rhs]))
return emitOptionalError(location,
"batching dimension sizes must "
Expand All @@ -1980,6 +1986,7 @@ LogicalResult inferDotGeneralOp(

for (auto [lhs, rhs] :
llvm::zip(lhsContractingDimensions, rhsContractingDimensions)) {
// dot_general_c11
if (!verifyCompatibleDims(lhsShape[lhs], rhsShape[rhs]))
return emitOptionalError(location,
"contracting dimension sizes must "
@@ -2006,7 +2013,7 @@
if (!llvm::is_contained(rhsBatchingDimensions, i) &&
!llvm::is_contained(rhsContractingDimensions, i))
dimensions.push_back(rhsShape[i]);

// dot_general_c13
inferredReturnShapes.emplace_back(dimensions);
return success();
}
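The tail of `inferDotGeneralOp` above assembles the inferred result shape from the lhs batching dimensions, then the remaining ("free") lhs dimensions, then the free rhs dimensions. A small standalone sketch of that shape rule; `inferShape` is a hypothetical stand-in for illustration, not the StableHLO API:

```cpp
// Informal illustration of the result-shape rule at the end of
// inferDotGeneralOp: batching dims, then free lhs dims, then free rhs dims.
// inferShape is a hypothetical helper, not part of the commit.
#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <vector>

std::vector<int64_t> inferShape(const std::vector<int64_t> &lhsShape,
                                const std::vector<int64_t> &rhsShape,
                                const std::vector<int64_t> &lhsBatch,
                                const std::vector<int64_t> &rhsBatch,
                                const std::vector<int64_t> &lhsContract,
                                const std::vector<int64_t> &rhsContract) {
  auto contains = [](const std::vector<int64_t> &dims, int64_t d) {
    return std::find(dims.begin(), dims.end(), d) != dims.end();
  };
  std::vector<int64_t> result;
  // Batching dimensions come first, in lhs order.
  for (int64_t d : lhsBatch) result.push_back(lhsShape[d]);
  // Then lhs dimensions that are neither batching nor contracting.
  for (int64_t i = 0; i < static_cast<int64_t>(lhsShape.size()); ++i)
    if (!contains(lhsBatch, i) && !contains(lhsContract, i))
      result.push_back(lhsShape[i]);
  // Then the free rhs dimensions.
  for (int64_t i = 0; i < static_cast<int64_t>(rhsShape.size()); ++i)
    if (!contains(rhsBatch, i) && !contains(rhsContract, i))
      result.push_back(rhsShape[i]);
  return result;
}

int main() {
  // lhs: 2x3x4, rhs: 2x4x5, batching [0]/[0], contracting [2]/[1] -> 2x3x5.
  for (int64_t d : inferShape({2, 3, 4}, {2, 4, 5}, {0}, {0}, {2}, {1}))
    std::printf("%lld ", static_cast<long long>(d));
  std::printf("\n");  // prints: 2 3 5
  return 0;
}
```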
1 change: 1 addition & 0 deletions stablehlo/reference/Element.h
@@ -41,6 +41,7 @@ class Element {
: type_(type), value_(std::make_pair(value.real(), value.imag())) {}

Element(const Element &other) = default;
Element() = default;
/// @}

/// Assignment operator.
16 changes: 16 additions & 0 deletions stablehlo/reference/Interpreter.cpp
@@ -85,6 +85,22 @@ llvm::Expected<SmallVector<Tensor>> eval(func::FuncOp func,
Tensor runtimeOperand = fetchOperand(cosineOp.getOperand());
Tensor runtimeResult = evalCosineOp(runtimeOperand, cosineOp.getType());
populateResults({runtimeResult});
} else if (auto dotGeneralOp = dyn_cast<DotGeneralOp>(op)) {
Tensor runtimeLhs = fetchOperand(dotGeneralOp.getLhs());
Tensor runtimeRhs = fetchOperand(dotGeneralOp.getRhs());
auto lhsBatchingDimensions =
dotGeneralOp.getDotDimensionNumbers().getLhsBatchingDimensions();
auto rhsBatchingDimensions =
dotGeneralOp.getDotDimensionNumbers().getRhsBatchingDimensions();
auto lhsContractingDimensions =
dotGeneralOp.getDotDimensionNumbers().getLhsContractingDimensions();
auto rhsContractingDimensions =
dotGeneralOp.getDotDimensionNumbers().getRhsContractingDimensions();
Tensor runtimeResult =
evalDotGeneralOp(runtimeLhs, runtimeRhs, lhsBatchingDimensions,
rhsBatchingDimensions, lhsContractingDimensions,
rhsContractingDimensions, dotGeneralOp.getType());
populateResults({runtimeResult});
} else if (auto floorOp = dyn_cast<FloorOp>(op)) {
Tensor runtimeOperand = fetchOperand(floorOp.getOperand());
Tensor runtimeResult = evalFloorOp(runtimeOperand, floorOp.getType());
92 changes: 92 additions & 0 deletions stablehlo/reference/Ops.cpp
@@ -72,6 +72,98 @@ Tensor evalCosineOp(const Tensor &operand, Type resultType) {
return result;
}

Tensor evalDotGeneralOp(const Tensor &lhs, const Tensor &rhs,
ArrayRef<int64_t> lhsBatchingDimensions,
ArrayRef<int64_t> rhsBatchingDimensions,
ArrayRef<int64_t> lhsContractingDimensions,
ArrayRef<int64_t> rhsContractingDimensions,
Type resultType) {
Tensor result(resultType);
SmallVector<int64_t> lhsNonBatchingNonContractingDims;
SmallVector<int64_t> rhsNonBatchingNonContractingDims;
for (auto i = 0; i < lhs.getType().getRank(); ++i)
if (!llvm::is_contained(lhsBatchingDimensions, i) &&
!llvm::is_contained(lhsContractingDimensions, i))
lhsNonBatchingNonContractingDims.push_back(i);
for (auto i = 0; i < rhs.getType().getRank(); ++i)
if (!llvm::is_contained(rhsBatchingDimensions, i) &&
!llvm::is_contained(rhsContractingDimensions, i))
rhsNonBatchingNonContractingDims.push_back(i);

SmallVector<int64_t> contractingDimSizes;
auto totalContractingSize = 1;
for (uint64_t i = 0; i < lhsContractingDimensions.size(); ++i) {
contractingDimSizes.push_back(
lhs.getType().getShape()[lhsContractingDimensions[i]]);
totalContractingSize *=
lhs.getType().getShape()[lhsContractingDimensions[i]];
}
auto initializeElementToZero = [&]() {
auto resultElTy = resultType.dyn_cast<ShapedType>().getElementType();
Element sum;
if (isSupportedSignedIntegerType(resultElTy)) {
sum = Element(resultElTy, APInt(resultElTy.getIntOrFloatBitWidth(), 0,
/*isSigned=*/true));
} else if (isSupportedUnsignedIntegerType(resultElTy)) {
sum = Element(resultElTy, APInt(resultElTy.getIntOrFloatBitWidth(), 0,
/*isSigned=*/false));
} else if (isSupportedBooleanType(resultElTy)) {
sum = Element(resultElTy, false);
} else if (isSupportedFloatType(resultElTy)) {
APFloat val((double)0.0);
bool roundingErr;
val.convert(resultElTy.cast<FloatType>().getFloatSemantics(),
APFloat::rmNearestTiesToEven, &roundingErr);
sum = Element(resultElTy, val);
} else if (isSupportedComplexType(resultElTy)) {
APFloat real((double)0.0);
APFloat imag((double)0.0);
auto flType =
resultElTy.cast<ComplexType>().getElementType().cast<FloatType>();
bool roundingErr;
real.convert(flType.getFloatSemantics(), APFloat::rmNearestTiesToEven,
&roundingErr);
imag.convert(flType.getFloatSemantics(), APFloat::rmNearestTiesToEven,
&roundingErr);
sum = Element(resultElTy, std::complex<APFloat>(real, imag));
}
return sum;
};
for (auto resultItr = result.index_begin(); resultItr != result.index_end();
++resultItr) {
SmallVector<int64_t> lhsIdx(lhs.getType().getRank());
SmallVector<int64_t> rhsIdx(rhs.getType().getRank());
int64_t resultDim = 0;
// Indices do not change for batching dimensions.
for (uint64_t i = 0; i < lhsBatchingDimensions.size(); ++i, ++resultDim) {
lhsIdx[lhsBatchingDimensions[i]] = (*resultItr)[resultDim];
rhsIdx[rhsBatchingDimensions[i]] = (*resultItr)[resultDim];
}
// The non-batching, non-contracting dimensions of the operands are copied
// over to the result dimensions following the batching dimensions.
for (uint64_t i = 0; i < lhsNonBatchingNonContractingDims.size(); i++)
lhsIdx[lhsNonBatchingNonContractingDims[i]] = (*resultItr)[resultDim++];
for (uint64_t i = 0; i < rhsNonBatchingNonContractingDims.size(); i++)
rhsIdx[rhsNonBatchingNonContractingDims[i]] = (*resultItr)[resultDim++];
Element sum = initializeElementToZero();
    // Sum over all combinations of the contracting dimension indices.
for (auto i = 0; i < totalContractingSize; ++i) {
sum = sum + lhs.get(lhsIdx) * rhs.get(rhsIdx);
if (contractingDimSizes.empty()) continue;
for (int64_t j = contractingDimSizes.size() - 1; j >= 0; --j) {
lhsIdx[lhsContractingDimensions[j]]++;
rhsIdx[rhsContractingDimensions[j]]++;
if (lhsIdx[lhsContractingDimensions[j]] != contractingDimSizes[j])
break;
lhsIdx[lhsContractingDimensions[j]] = 0;
rhsIdx[rhsContractingDimensions[j]] = 0;
}
}
result.set((*resultItr), sum);
}
return result;
}

Tensor evalFloorOp(const Tensor &operand, Type resultType) {
Tensor result(resultType);
for (auto it = result.index_begin(); it != result.index_end(); ++it)
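The inner loop in `evalDotGeneralOp` above advances the contracting indices like a mixed-radix counter: the last contracting dimension moves fastest, and a wrap to zero carries into the previous dimension. A standalone sketch of that iteration pattern on its own; the sizes {2, 3} are an arbitrary example, not taken from the commit:

```cpp
// Informal illustration of the contracting-index iteration used by
// evalDotGeneralOp: a mixed-radix ("odometer") counter over the contracting
// dimension sizes. The sizes {2, 3} are an arbitrary example.
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
  std::vector<int64_t> sizes = {2, 3};
  std::vector<int64_t> idx(sizes.size(), 0);
  int64_t total = 1;
  for (int64_t s : sizes) total *= s;
  for (int64_t step = 0; step < total; ++step) {
    // In the interpreter, this is where lhs.get(lhsIdx) * rhs.get(rhsIdx)
    // would be accumulated into the running sum.
    std::printf("(%lld, %lld)\n", static_cast<long long>(idx[0]),
                static_cast<long long>(idx[1]));
    // Advance the last dimension; on wrap-around, carry into the previous one.
    for (int64_t j = static_cast<int64_t>(sizes.size()) - 1; j >= 0; --j) {
      if (++idx[j] != sizes[j]) break;
      idx[j] = 0;
    }
  }
  return 0;
}
// Output: (0, 0) (0, 1) (0, 2) (1, 0) (1, 1) (1, 2), one tuple per line.
```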
6 changes: 6 additions & 0 deletions stablehlo/reference/Ops.h
@@ -28,6 +28,12 @@ Tensor evalAndOp(const Tensor &lhs, const Tensor &rhs, Type resultType);
Tensor evalCeilOp(const Tensor &operand, Type resultType);
Tensor evalConstantOp(ElementsAttr value);
Tensor evalCosineOp(const Tensor &operand, Type resultType);
Tensor evalDotGeneralOp(const Tensor &lhs, const Tensor &rhs,
ArrayRef<int64_t> lhsBatchingDimensions,
ArrayRef<int64_t> rhsBatchingDimensions,
ArrayRef<int64_t> lhsContractingDimensions,
ArrayRef<int64_t> rhsContractingDimensions,
Type resultType);
Tensor evalFloorOp(const Tensor &operand, Type resultType);
Tensor evalIotaOp(int64_t iotaDimension, Type resultType);
Tensor evalMaxOp(const Tensor &lhs, const Tensor &rhs, Type resultType);
130 changes: 130 additions & 0 deletions stablehlo/tests/interpret_dot_general.mlir
@@ -0,0 +1,130 @@
// RUN: stablehlo-interpreter --interpret -split-input-file %s | FileCheck %s

// CHECK-LABEL: Evaluated results of function: dot_general_op_test_si64
func.func @dot_general_op_test_si64() -> tensor<2x2xi64> {
%0 = stablehlo.constant dense<[[[1, 2], [3, 4]], [[5, 6], [7, 8]]]> : tensor<2x2x2xi64>
%1 = stablehlo.constant dense<[[[1, 0], [0, 1]], [[1, 0], [0, 1]]]> : tensor<2x2x2xi64>
%2 = "stablehlo.dot_general"(%0, %1) {
dot_dimension_numbers = #stablehlo.dot<
lhs_batching_dimensions = [0, 2],
rhs_batching_dimensions = [2, 1],
lhs_contracting_dimensions = [1],
rhs_contracting_dimensions = [0]
>,
precision_config = [#stablehlo<precision DEFAULT>]
} : (tensor<2x2x2xi64>, tensor<2x2x2xi64>) -> tensor<2x2xi64>
func.return %2 : tensor<2x2xi64>
// CHECK-NEXT: tensor<2x2xi64>
// CHECK-NEXT: 4 : i64
// CHECK-NEXT: 0 : i64
// CHECK-NEXT: 0 : i64
// CHECK-NEXT: 14 : i64
}
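// Informal reading of this case (not part of the generated checks): with
// lhs_batching_dimensions = [0, 2] and rhs_batching_dimensions = [2, 1],
// result[b, c] = sum over k of lhs[b, k, c] * rhs[k, c, b], so
// result[0, 0] = 1*1 + 3*1 = 4, result[0, 1] = 2*0 + 4*0 = 0,
// result[1, 0] = 5*0 + 7*0 = 0, result[1, 1] = 6*1 + 8*1 = 14.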

// -----

// CHECK-LABEL: Evaluated results of function: dot_general_op_test_ui64
func.func @dot_general_op_test_ui64() -> tensor<2x2x2xui64> {
%0 = stablehlo.constant dense<[[[1, 2], [3, 4]], [[5, 6], [7, 8]]]> : tensor<2x2x2xui64>
%1 = stablehlo.constant dense<[[[1, 0], [0, 1]], [[1, 0], [0, 1]]]> : tensor<2x2x2xui64>
%2 = "stablehlo.dot_general"(%0, %1) {
dot_dimension_numbers = #stablehlo.dot<
lhs_batching_dimensions = [0],
rhs_batching_dimensions = [0],
lhs_contracting_dimensions = [2],
rhs_contracting_dimensions = [1]
>,
precision_config = [#stablehlo<precision DEFAULT>]
} : (tensor<2x2x2xui64>, tensor<2x2x2xui64>) -> tensor<2x2x2xui64>
func.return %2 : tensor<2x2x2xui64>
// CHECK-NEXT: tensor<2x2x2xui64>
// CHECK-NEXT: 1 : ui64
// CHECK-NEXT: 2 : ui64
// CHECK-NEXT: 3 : ui64
// CHECK-NEXT: 4 : ui64
// CHECK-NEXT: 5 : ui64
// CHECK-NEXT: 6 : ui64
// CHECK-NEXT: 7 : ui64
// CHECK-NEXT: 8 : ui64
}

// -----

// CHECK-LABEL: Evaluated results of function: dot_general_op_test_i1
func.func @dot_general_op_test_i1() -> tensor<2x2x2xi1> {
%0 = stablehlo.constant dense<[[[true, true], [true, true]], [[false, false], [false, false]]]> : tensor<2x2x2xi1>
%1 = stablehlo.constant dense<[[[true, false], [false, true]], [[true, false], [false, true]]]> : tensor<2x2x2xi1>
%2 = "stablehlo.dot_general"(%0, %1) {
dot_dimension_numbers = #stablehlo.dot<
lhs_batching_dimensions = [0],
rhs_batching_dimensions = [0],
lhs_contracting_dimensions = [2],
rhs_contracting_dimensions = [1]
>,
precision_config = [#stablehlo<precision DEFAULT>]
} : (tensor<2x2x2xi1>, tensor<2x2x2xi1>) -> tensor<2x2x2xi1>
func.return %2 : tensor<2x2x2xi1>
// CHECK-NEXT: tensor<2x2x2xi1>
// CHECK-NEXT: true
// CHECK-NEXT: true
// CHECK-NEXT: true
// CHECK-NEXT: true
// CHECK-NEXT: false
// CHECK-NEXT: false
// CHECK-NEXT: false
// CHECK-NEXT: false
}

// -----

// CHECK-LABEL: Evaluated results of function: dot_general_op_test_f64
func.func @dot_general_op_test_f64() -> tensor<2x2x2xf64> {
%0 = stablehlo.constant dense<[[[1.0, 2.0], [3.0, 4.0]], [[5.0, 6.0], [7.0, 8.0]]]> : tensor<2x2x2xf64>
%1 = stablehlo.constant dense<[[[1.0, 0.0], [0.0, 1.0]], [[1.0, 0.0], [0.0, 1.0]]]> : tensor<2x2x2xf64>
%2 = "stablehlo.dot_general"(%0, %1) {
dot_dimension_numbers = #stablehlo.dot<
lhs_batching_dimensions = [0],
rhs_batching_dimensions = [0],
lhs_contracting_dimensions = [2],
rhs_contracting_dimensions = [1]
>,
precision_config = [#stablehlo<precision DEFAULT>]
} : (tensor<2x2x2xf64>, tensor<2x2x2xf64>) -> tensor<2x2x2xf64>
func.return %2 : tensor<2x2x2xf64>
// CHECK-NEXT: tensor<2x2x2xf64>
// CHECK-NEXT: 1.000000e+00 : f64
// CHECK-NEXT: 2.000000e+00 : f64
// CHECK-NEXT: 3.000000e+00 : f64
// CHECK-NEXT: 4.000000e+00 : f64
// CHECK-NEXT: 5.000000e+00 : f64
// CHECK-NEXT: 6.000000e+00 : f64
// CHECK-NEXT: 7.000000e+00 : f64
// CHECK-NEXT: 8.000000e+00 : f64
}

// -----

// CHECK-LABEL: Evaluated results of function: dot_general_op_test_c128
func.func @dot_general_op_test_c128() -> tensor<2x2x2xcomplex<f64>> {
%0 = stablehlo.constant dense<[[[(1.0, 0.0), (2.0, 0.0)], [(3.0, 0.0), (4.0, 0.0)]], [[(5.0, 0.0), (6.0, 0.0)], [(7.0, 0.0), (8.0, 0.0)]]]> : tensor<2x2x2xcomplex<f64>>
%1 = stablehlo.constant dense<[[[(1.0, 0.0), (0.0, 0.0)], [(0.0, 0.0), (1.0, 0.0)]], [[(1.0, 0.0), (0.0, 0.0)], [(0.0, 0.0), (1.0, 0.0)]]]> : tensor<2x2x2xcomplex<f64>>
%2 = "stablehlo.dot_general"(%0, %1) {
dot_dimension_numbers = #stablehlo.dot<
lhs_batching_dimensions = [0],
rhs_batching_dimensions = [0],
lhs_contracting_dimensions = [2],
rhs_contracting_dimensions = [1]
>,
precision_config = [#stablehlo<precision DEFAULT>]
} : (tensor<2x2x2xcomplex<f64>>, tensor<2x2x2xcomplex<f64>>) -> tensor<2x2x2xcomplex<f64>>
func.return %2 : tensor<2x2x2xcomplex<f64>>
// CHECK-NEXT: tensor<2x2x2xcomplex<f64>>
// CHECK-NEXT: [1.000000e+00 : f64, 0.000000e+00 : f64]
// CHECK-NEXT: [2.000000e+00 : f64, 0.000000e+00 : f64]
// CHECK-NEXT: [3.000000e+00 : f64, 0.000000e+00 : f64]
// CHECK-NEXT: [4.000000e+00 : f64, 0.000000e+00 : f64]
// CHECK-NEXT: [5.000000e+00 : f64, 0.000000e+00 : f64]
// CHECK-NEXT: [6.000000e+00 : f64, 0.000000e+00 : f64]
// CHECK-NEXT: [7.000000e+00 : f64, 0.000000e+00 : f64]
// CHECK-NEXT: [8.000000e+00 : f64, 0.000000e+00 : f64]
}
