Add interpreter for DotGeneralOp #748

Merged: 5 commits, Jun 9, 2023

Conversation

@ghpvnist (Member) commented Dec 13, 2022

We have the following non-quantization-related constraints in the spec (excluding the quantization-related C13 and C15-C20); a worked example follows the list:

(I1) lhs tensor.
(I2) rhs tensor.
(I3) lhs_batching_dimensions 1-dimensional tensor constant of type `si64`.
(I4) rhs_batching_dimensions 1-dimensional tensor constant of type `si64`.
(I5) lhs_contracting_dimensions 1-dimensional tensor constant of type `si64`.
(I6) rhs_contracting_dimensions 1-dimensional tensor constant of type `si64`.
(I7) precision_config variadic number of enums of `DEFAULT`, `HIGH`, and `HIGHEST`.
(C1) size(`lhs_batching_dimensions`) = size(`rhs_batching_dimensions`).
(C2) size(`lhs_contracting_dimensions`) =
size(`rhs_contracting_dimensions`).
(C3) `lhs_batching_dimensions` and `lhs_contracting_dimensions` combined are
unique.
(C4) `rhs_batching_dimensions` and `rhs_contracting_dimensions` combined are
unique.
(C5) 0 <= `lhs_batching_dimensions[i]` < rank(`lhs`) for all `i`
in [0, size(`lhs_batching_dimensions`)).
(C6) 0 <= `lhs_contracting_dimensions[i]` < rank(`lhs`) for all `i`
in [0, size(`lhs_contracting_dimensions`)).
(C7) 0 <= `rhs_batching_dimensions[i]` < rank(`rhs`) for all `i`
in [0, size(`rhs_batching_dimensions`)).
(C8) 0 <= `rhs_contracting_dimensions[i]` < rank(`rhs`) for all `i`
in [0, size(`rhs_contracting_dimensions`)).
(C9) dim(`lhs`, `lhs_batching_dimensions[i]`) =
dim(`rhs`, `rhs_batching_dimensions[i]`) for all `i` in [0,
size(`lhs_batching_dimensions`)).
(C10) dim(`lhs`, `lhs_contracting_dimensions[i]`) =
dim(`rhs`, `rhs_contracting_dimensions[i]`) for all `i` in [0,
size(`lhs_contracting_dimensions`)).
(C11) size(`precision_config`) = 2.
(C12) shape(`result`) = dim(`lhs`, `lhs_batching_dimensions`) +
dim(`lhs`, `lhs_result_dimensions`) + dim(`rhs`, `rhs_result_dimensions`).
(C14) element_type(`lhs`) = element_type(`rhs`).
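
For concreteness, here is a valid instance of the op satisfying all of these constraints, adapted from the dot_general example in the StableHLO spec (the values are illustrative):

```mlir
// %lhs:    [[[1, 2], [3, 4]], [[5, 6], [7, 8]]]
// %rhs:    [[[1, 0], [0, 1]], [[1, 0], [0, 1]]]
%result = "stablehlo.dot_general"(%lhs, %rhs) {
  dot_dimension_numbers = #stablehlo.dot<
    lhs_batching_dimensions = [0],
    rhs_batching_dimensions = [0],
    lhs_contracting_dimensions = [2],
    rhs_contracting_dimensions = [1]
  >,
  precision_config = [#stablehlo<precision DEFAULT>, #stablehlo<precision DEFAULT>]
} : (tensor<2x2x2xi64>, tensor<2x2x2xi64>) -> tensor<2x2x2xi64>
// %result: [[[1, 2], [3, 4]], [[5, 6], [7, 8]]]
```

Here `lhs_batching_dimensions` and `lhs_contracting_dimensions` combined are [0, 2]: unique and within rank(`lhs`) = 3, so C3, C5, and C6 hold, and the analogous checks hold for `rhs`. The result shape 2x2x2 is the batching dimensions followed by the lhs and rhs result dimensions, per C12.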

These constraints will be comprehensively covered by the following tests:

I1: a) lhs is not a tensor. (Covered by ODS).
I2: a) rhs is not a tensor. (Covered by ODS).
I3: a) lhs_batching_dimensions is not a 1-dimensional tensor. (Covered by ODS).
    b) element_type(lhs_batching_dimensions) != `si64`. (Covered by ODS).
I4: a) rhs_batching_dimensions is not a 1-dimensional tensor. (Covered by ODS).
    b) element_type(rhs_batching_dimensions) != `si64`. (Covered by ODS).
I5: a) lhs_contracting_dimensions is not a 1-dimensional tensor. (Covered by ODS).
    b) element_type(lhs_contracting_dimensions) != `si64`. (Covered by ODS).
I6: a) rhs_contracting_dimensions is not a 1-dimensional tensor. (Covered by ODS).
    b) element_type(rhs_contracting_dimensions) != `si64`. (Covered by ODS).
I7: a) precision_config is not a variadic number of enums of `DEFAULT`, `HIGH`, and `HIGHEST`. (Covered by ODS).
C1: a) size(lhs_batching_dimensions) != size(rhs_batching_dimensions).
C2: a) size(lhs_contracting_dimensions) != size(rhs_contracting_dimensions).
C3: a) lhs_batching_dimensions and lhs_contracting_dimensions combined are not unique.
C4: a) rhs_batching_dimensions and rhs_contracting_dimensions combined are not unique.
C5: a) lhs_batching_dimensions[i] < 0 for any i.
    b) lhs_batching_dimensions[i] >= rank(lhs) for any i.
C6: a) lhs_contracting_dimensions[i] < 0 for any i.
    b) lhs_contracting_dimensions[i] >= rank(lhs) for any i.
C7: a) rhs_batching_dimensions[i] < 0 for any i.
    b) rhs_batching_dimensions[i] >= rank(rhs) for any i.
C8: a) rhs_contracting_dimensions[i] < 0 for any i.
    b) rhs_contracting_dimensions[i] >= rank(rhs) for any i.
C9: a) dim(lhs, lhs_batching_dimensions[i]) != dim(rhs, rhs_batching_dimensions[i]) for any i.
C10: a) dim(lhs, lhs_contracting_dimensions[i]) != dim(rhs, rhs_contracting_dimensions[i]) for any i.
C11: a) size(precision_config) != 2.
C12: no negative test needed since it's just inferring the shape.
C14: a) element_type(lhs) != element_type(rhs).

If we drop the "Covered by ODS" pieces, this leaves us with the following test cases (a sketch of one such test follows the list):

C1a: size(lhs_batching_dimensions) != size(rhs_batching_dimensions).
C2a: size(lhs_contracting_dimensions) != size(rhs_contracting_dimensions).
C3a: lhs_batching_dimensions and lhs_contracting_dimensions combined are not unique.
C4a: rhs_batching_dimensions and rhs_contracting_dimensions combined are not unique.
C5a: lhs_batching_dimensions[i] < 0 for any i.
C5b: lhs_batching_dimensions[i] >= rank(lhs) for any i.
C6a: lhs_contracting_dimensions[i] < 0 for any i.
C6b: lhs_contracting_dimensions[i] >= rank(lhs) for any i.
C7a: rhs_batching_dimensions[i] < 0 for any i.
C7b: rhs_batching_dimensions[i] >= rank(rhs) for any i.
C8a: rhs_contracting_dimensions[i] < 0 for any i.
C8b: rhs_contracting_dimensions[i] >= rank(rhs) for any i.
C9a: dim(lhs, lhs_batching_dimensions[i]) != dim(rhs, rhs_batching_dimensions[i]) for any i.
C10a: dim(lhs, lhs_contracting_dimensions[i]) != dim(rhs, rhs_contracting_dimensions[i]) for any i.
C11a: size(precision_config) != 2.
C14a: element_type(lhs) != element_type(rhs).
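
As a sketch of how one of these negative tests might look in ops_stablehlo.mlir (the function name is hypothetical and the expected-error text is a placeholder, not necessarily the verifier's actual message):

```mlir
// C1a: size(lhs_batching_dimensions) != size(rhs_batching_dimensions).
func.func @dot_general_c1a(%lhs: tensor<2x3x4xf32>, %rhs: tensor<2x3x4xf32>) -> tensor<2x2xf32> {
  // expected-error@+1 {{lhs and rhs should have the same number of batching dimensions}}
  %0 = "stablehlo.dot_general"(%lhs, %rhs) {
    dot_dimension_numbers = #stablehlo.dot<
      lhs_batching_dimensions = [0, 1],
      rhs_batching_dimensions = [0],
      lhs_contracting_dimensions = [2],
      rhs_contracting_dimensions = [2]
    >,
    precision_config = [#stablehlo<precision DEFAULT>, #stablehlo<precision DEFAULT>]
  } : (tensor<2x3x4xf32>, tensor<2x3x4xf32>) -> tensor<2x2xf32>
  func.return %0 : tensor<2x2xf32>
}
```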

Notes:

closes #336

@ghpvnist ghpvnist self-assigned this Dec 13, 2022
@ghpvnist ghpvnist force-pushed the dot_general_interpreter branch from 12a9f55 to 3b009b8 Compare January 8, 2023 08:29
@ghpvnist ghpvnist force-pushed the dot_general_interpreter branch from 3b009b8 to 1d58458 Compare February 7, 2023 06:39
@ghpvnist (Member, Author) commented Feb 7, 2023

C11 currently does not have a test due to #755 and #879; similarly, C14 does not have one due to #1354.

@ghpvnist ghpvnist force-pushed the dot_general_interpreter branch from 1d58458 to 1fbd315 Compare February 7, 2023 06:43
@ghpvnist ghpvnist requested a review from sdasgup3 February 7, 2023 06:44
@ghpvnist ghpvnist assigned sdasgup3 and unassigned ghpvnist Feb 7, 2023
@ghpvnist ghpvnist force-pushed the dot_general_interpreter branch from 1fbd315 to 76185f8 Compare February 7, 2023 19:40
@ghpvnist ghpvnist marked this pull request as draft February 8, 2023 00:17
@ghpvnist ghpvnist assigned ghpvnist and unassigned sdasgup3 Feb 8, 2023
@ghpvnist ghpvnist force-pushed the dot_general_interpreter branch from 76185f8 to 901de0f Compare February 15, 2023 00:43
@ghpvnist ghpvnist added the Migrate to MHLO PR that needs to be migrated to MLIR-HLO label Feb 16, 2023
@ghpvnist ghpvnist force-pushed the dot_general_interpreter branch 2 times, most recently from 406b8f3 to 99522d1 Compare March 24, 2023 23:30
@ghpvnist ghpvnist force-pushed the dot_general_interpreter branch from 99522d1 to f81b418 Compare April 5, 2023 17:32
@ghpvnist (Member, Author) commented Apr 5, 2023

This is not a reference implementation of the spec; I decided to keep it that way because a spec-faithful reference implementation would have to rely on MLIR builders (to create ReduceOp), which hurts readability. Thoughts? @burmako
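
For context on the trade-off above: a spec-faithful interpretation would express each dot product via stablehlo.reduce, which the interpreter would have to assemble programmatically with MLIR builders. A minimal hand-written sketch of that decomposition for a simple matrix-vector case (hypothetical shapes and SSA names):

```mlir
// result[i] = sum over k of lhs[i, k] * rhs[k],
// i.e. dot_general with lhs_contracting_dimensions = [1], rhs_contracting_dimensions = [0].
%bcast = "stablehlo.broadcast_in_dim"(%rhs) {
  broadcast_dimensions = dense<1> : tensor<1xi64>
} : (tensor<3xf32>) -> tensor<2x3xf32>
%prod = stablehlo.multiply %lhs, %bcast : tensor<2x3xf32>
%zero = stablehlo.constant dense<0.0> : tensor<f32>
%result = "stablehlo.reduce"(%prod, %zero) ({
  ^bb0(%acc: tensor<f32>, %elem: tensor<f32>):
    %sum = stablehlo.add %acc, %elem : tensor<f32>
    "stablehlo.return"(%sum) : (tensor<f32>) -> ()
}) {dimensions = dense<1> : tensor<1xi64>} : (tensor<2x3xf32>, tensor<f32>) -> tensor<2xf32>
```

The merged interpreter computes these sums directly in stablehlo/reference/Ops.cpp instead, which keeps the code readable at the cost of not mirroring the spec's reduce-based formulation.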

@ghpvnist ghpvnist marked this pull request as ready for review April 5, 2023 17:39
@ghpvnist ghpvnist assigned sdasgup3 and unassigned ghpvnist Apr 5, 2023
@ghpvnist ghpvnist requested a review from sdasgup3 April 5, 2023 17:40
@burmako (Contributor) commented Apr 5, 2023

Happy to review as is, and then let's discuss that as part of the review?

Review threads (resolved): stablehlo/reference/Element.cpp, stablehlo/tests/interpret_dot_general.mlir, stablehlo/reference/Ops.cpp
@ghpvnist ghpvnist force-pushed the dot_general_interpreter branch 2 times, most recently from 1dfff45 to 072ea54 Compare May 31, 2023 22:03
@ghpvnist ghpvnist requested a review from burmako May 31, 2023 22:21
@ghpvnist ghpvnist assigned burmako and unassigned ghpvnist May 31, 2023
Review threads (resolved): stablehlo/dialect/TypeInference.cpp, stablehlo/dialect/StablehloOps.td, stablehlo/tests/infer_stablehlo.mlir, stablehlo/tests/interpret_dot_general.mlir, stablehlo/tests/ops_stablehlo.mlir
@burmako burmako assigned ghpvnist and unassigned burmako Jun 3, 2023
@ghpvnist ghpvnist force-pushed the dot_general_interpreter branch from 072ea54 to 361eb59 Compare June 7, 2023 22:58
@ghpvnist ghpvnist assigned burmako and unassigned ghpvnist Jun 7, 2023
@ghpvnist ghpvnist requested a review from burmako June 7, 2023 23:29
Review threads (resolved): stablehlo/dialect/TypeInference.cpp, stablehlo/dialect/StablehloOps.td, stablehlo/tests/ops_stablehlo.mlir
@burmako burmako assigned ghpvnist and unassigned burmako Jun 8, 2023
@ghpvnist ghpvnist requested a review from burmako June 8, 2023 17:36
@ghpvnist ghpvnist assigned burmako and unassigned ghpvnist Jun 8, 2023
@burmako burmako assigned ghpvnist and unassigned burmako Jun 9, 2023
@ghpvnist ghpvnist force-pushed the dot_general_interpreter branch from b99c8bc to 0b21091 Compare June 9, 2023 17:42
@ghpvnist ghpvnist merged commit 7410641 into openxla:main Jun 9, 2023
@ghpvnist ghpvnist deleted the dot_general_interpreter branch June 9, 2023 18:04
Labels: Interpreter, Migrate to MHLO (PR that needs to be migrated to MLIR-HLO)
Projects: Status: Done
Linked issues: Add interpreter for dot_general
Participants: 3