This repository has been archived by the owner on Oct 25, 2023. It is now read-only.

[TUZ-6] Add a direct Onnx to Relax Importer #14

Merged
jwfromm merged 39 commits into relax from TUZ-6 on Feb 17, 2023

Conversation

jwfromm (Contributor) commented Feb 3, 2023

This PR is the culmination of work for the ONNX importer epic. It implements a converter similar in spirit to the ONNX -> Relay importer and even reuses many of the operator converters directly. Other operators are instead converted by directly emitting TensorIR functions using topi.
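
For orientation, here is a rough usage sketch of the new front end (the `from_onnx` entry-point name is an assumption based on the added `onnx_frontend.py` module and may not match the final API):

import onnx
from tvm.relax.frontend import onnx_frontend

# Load an ONNX model and convert it directly into a Relax IRModule.
onnx_model = onnx.load("model.onnx")
mod = onnx_frontend.from_onnx(onnx_model)  # assumed entry point
print(mod.script())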

What should we put here?

What are the right tests to run?

This PR depends on importing onnx; how should we update CI to support new dependencies?

name = dim.dim_param
value = dim.dim_value
if value is None or value == 0:
    value = tvm.tir.Var("d", "int64")
Member
Why name the var "d"?

Contributor

I think we usually use 'd' for dynamic, although I'm not sure.

Contributor Author

Yeah, it's just a stand-in dynamic shape. I've updated it to be `dyn` for more clarity.
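
For reference, a sketch of the dimension handling with the renamed stand-in variable (the helper itself is hypothetical; the logic mirrors the snippet above):

from tvm import tir

def _parse_dim(dim):
    """Map one ONNX dimension proto to a static int or a symbolic tir.Var."""
    if dim.dim_param:
        # Named dynamic dimension: reuse the symbol name provided by the model.
        return tir.Var(dim.dim_param, "int64")
    value = dim.dim_value
    if value is None or value == 0:
        # Unnamed unknown dimension: stand-in symbolic shape var, named "dyn".
        return tir.Var("dyn", "int64")
    return value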

slyubomirsky (Contributor) commented Feb 3, 2023

My overall high-level reactions (I will leave finer-grained comments in the code) from the perspective of documentation:

Gut reaction for what docs we would expect

  • A high-level comment outlining the overall algorithm for the conversion, explaining the relevant portions of an ONNX model and what each component means (at the very least, have links to ONNX docs)
  • We probably don't need a fully fledged tutorial explaining how to add new operators because the implementation of the importer might change over time, but a list of steps outlined in a comment near the implementation might be nice
  • A possible tutorial could show how to use the front-end interface for the importer and explain, broadly, what happens and how the ONNX model is transformed into a Relax program

Specific reactions

  • The operator implementations are generally straightforward, and so is the convert map; definitely good in terms of being "self-documenting" code
  • The graph conversion is an area where things can become very complicated and depend on details of ONNX
  • Name sanitization is a potentially very tricky issue that should be outlined in detail. It has been an enormous source of bugs and confusion in practice with Relay, so I think it should be explained as clearly as possible and made as idiot-proof as possible (see the sketch after this list)
  • The tests for individual operators seem to be lacking with respect to the goal of reaching branch coverage.
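
For illustration only, one way such sanitization could be made hard to misuse (a sketch, not what this PR implements): map every ONNX name to a valid, unique identifier and never hand out a raw name.

import re

def sanitize_name(name: str, taken: set) -> str:
    """Turn an arbitrary ONNX tensor name into a valid, unique identifier."""
    # Replace every character that is not legal in an identifier.
    sanitized = re.sub(r"\W", "_", name)
    if not sanitized or sanitized[0].isdigit():
        sanitized = "_" + sanitized
    # Deduplicate: distinct ONNX names may collide after sanitization.
    candidate, suffix = sanitized, 0
    while candidate in taken:
        suffix += 1
        candidate = f"{sanitized}_{suffix}"
    taken.add(candidate)
    return candidate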

How the importer should be tested

  • The importer is really a compiler, so tests should check differential equality: make sure the imported model produces the same result as the original (see the sketch after this list)
  • Fuzz random inputs to individual operators as a start, aiming for branch coverage
  • Also have some integration tests making sure that the operators can be composed into larger graphs as expected
  • A less useful form of testing: checking that a specific AST is produced for a given input program. That couples the tests too closely to the implementation; it is better to ensure equivalence of the results
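
A minimal sketch of such a differential check, assuming the importer exposes a `from_onnx` entry point in `onnx_frontend.py` and names the imported function "main" (build/VM calls are approximate and may differ by Relax version):

import numpy as np
import onnx
import onnxruntime
import tvm
from tvm import relax
from tvm.relax.frontend import onnx_frontend


def check_against_onnxruntime(model: onnx.ModelProto, inputs: dict, rtol=1e-5, atol=1e-5):
    """Compare onnxruntime against the Relax import path on the same inputs."""
    # Reference result from onnxruntime.
    ref_outputs = onnxruntime.InferenceSession(model.SerializeToString()).run(None, inputs)

    # Import, build, and run through Relax (single-output model assumed for brevity).
    mod = onnx_frontend.from_onnx(model)  # assumed entry point
    ex = relax.build(mod, target="llvm")  # may be relax.vm.build in older Relax versions
    vm = relax.VirtualMachine(ex, tvm.cpu())
    tvm_output = vm["main"](*[tvm.nd.array(v) for v in inputs.values()])

    # Differential equality: both backends must agree within tolerance.
    np.testing.assert_allclose(ref_outputs[0], tvm_output.numpy(), rtol=rtol, atol=atol)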



@tvm.register_func("relax.run.broadcast_to")
def torch_broadcast_to(data: tvm.nd.array, shape: tvm.nd.array) -> tvm.nd.array:
Member

Likely this is no longer needed if you can lower broadcast_to via emit_te.

Contributor Author

True, I don't think that lowering exists yet, but as soon as it does we'll remove this function.
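
For context, a minimal sketch of what that lowering could look like through emit_te once it is supported (only valid when the target shape is statically known):

from tvm import relax, topi

def emit_static_broadcast_to(bb: relax.BlockBuilder, data: relax.Expr, shape: tuple):
    """Emit broadcast_to as a TE-generated TensorIR PrimFunc instead of the
    registered "relax.run.broadcast_to" packed function."""
    return bb.emit_te(topi.broadcast_to, data, shape)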

slyubomirsky (Contributor) left a comment

After taking a detailed look at the implementation, I've highlighted some areas in which the tests can be improved (many test cases do check all the branches, which is good). Overall, I stand by my earlier comments that there should be some outline of the overall structure of the converter. There should also be some explanation of how the existing Relay converter's methods are used. Is that intended to remain permanently? I'm not sure it's good to depend on the Relay converter at all if we can avoid it; it's just another massive complication.

@@ -88,21 +88,22 @@ StructInfo InferStructInfoBroadcastTo(const Call& call, const BlockBuilder& ctx)
const auto* old_len_int = old_len.as<IntImmNode>();
if (old_len_int != nullptr && old_len_int->value == 1) {
  continue;
} else if (!analyzer->CanProveEqual(old_len, tgt_len)) {
Contributor

Could you explain why this was commented out?

Contributor Author

Ah yeah, this portion of code specifically requires shapes to be known and checks whether those static shapes are compatible for broadcasting. My opinion is that we should not enforce that, since dynamic broadcast_to shows up pretty frequently. This should be fully removed rather than commented out, though.

Contributor

I see. The intent of the original code was to require using a MatchCast to check the shape first. However, I agree with you that the operator implementation can do the checking, so that's unnecessary work for the user.
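
For readers unfamiliar with that pattern, a rough sketch of the MatchCast-first approach the original check assumed (names and exact builder methods are illustrative and may differ by Relax version):

from tvm import relax, tir

def broadcast_with_match_cast(bb: relax.BlockBuilder, x: relax.Expr):
    """Assert a concrete 2-D shape via MatchCast before calling broadcast_to."""
    m = tir.Var("m", "int64")
    n = tir.Var("n", "int64")
    # MatchCast binds the symbolic vars m and n, failing at runtime if x is not 2-D float32.
    x_checked = bb.match_cast(x, relax.TensorStructInfo([m, n], "float32"))
    # With the shape established, the broadcast target can be written in terms of m and n.
    return bb.emit(relax.op.broadcast_to(x_checked, relax.ShapeExpr([2, m, n])))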

check_correctness(model)


def test_const():
Contributor

Perhaps it would be good to have tests for multiple data types, or for the edge case discussed in the implementation (not supporting strings).
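
For example, a dtype-parameterized constant test could look roughly like this (model construction via onnx.helper; check_correctness is the suite's existing helper; string tensors are deliberately omitted since the implementation does not support them):

import numpy as np
import pytest
from onnx import helper, numpy_helper


@pytest.mark.parametrize("dtype", ["float32", "float16", "int64", "bool"])
def test_const_dtypes(dtype):
    data = np.array([[1, 0], [2, 3]]).astype(dtype)
    tensor = numpy_helper.from_array(data, name="const_value")
    const_node = helper.make_node("Constant", inputs=[], outputs=["out"], value=tensor)
    graph = helper.make_graph(
        [const_node],
        "test_const",
        inputs=[],
        outputs=[helper.make_tensor_value_info("out", tensor.data_type, data.shape)],
    )
    model = helper.make_model(graph, producer_name="test_const")
    check_correctness(model)  # defined in test_onnx_frontend.py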

converter, which should be `_impl_vx`. Number x is the largest version that is
smaller than or equal to the requested opset, among all supported versions.
"""
versions = [int(d.replace("_impl_v", "")) for d in dir(cls) if "_impl_v" in d]
Contributor

Might there be a simpler way to accomplish the dispatching by opset? E.g., as a property in the converter classes. Encoding it in the method names seems a little unusual. Is split the only operator that has multiple versions in this manner?

Contributor

In principle, we should additionally test the error case of rejecting an op due to its version (we could test this function separately).

Contributor

At the moment, yes, Split is the only operator that has multiple versions. That probably will not hold true in the future.

I am not sure about using a different kind of encoding. What is nice about this implementation is that each converter does not have to maintain a separate mapping from opset version to function implementation.
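
For readers following along, a simplified sketch of the `_impl_vx` dispatch scheme under discussion (class and argument names are placeholders, not the exact code in this PR):

class OnnxOpConverter:
    """Assumed common base class for per-op converters."""

    @classmethod
    def get_converter(cls, opset: int):
        """Return the `_impl_vx` method with the largest x that is <= opset."""
        versions = sorted(
            int(m.replace("_impl_v", "")) for m in dir(cls) if m.startswith("_impl_v")
        )
        supported = [v for v in versions if v <= opset]
        if not supported:
            raise NotImplementedError(f"opset {opset} is not supported for {cls.__name__}")
        return getattr(cls, f"_impl_v{supported[-1]}")


class Split(OnnxOpConverter):
    @classmethod
    def _impl_v1(cls, bb, inputs, attr):
        ...  # opsets 1-12: split sizes come from the 'split' attribute

    @classmethod
    def _impl_v13(cls, bb, inputs, attr):
        ...  # opset 13+: split sizes come from an optional second input

With this scheme, Split.get_converter(11) resolves to _impl_v1, while an opset below every available version raises, which also provides a natural hook for the version-rejection test mentioned above.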

jwfromm (Contributor Author) commented Feb 13, 2023

@driazati, it looks like the build is failing due to the image not having onnx. How should we update that as a dependency?

slyubomirsky (Contributor) left a comment

I think my principal concerns have been addressed and I thank those who made the changes. We can improve the documentation as we go along.

Josh Fromm and others added 15 commits February 16, 2023 13:13
jwfromm (Contributor Author) commented Feb 17, 2023

I'm going to merge this to help unblock other efforts. There are still a few CI-related issues we should try to fix. Specifically, it seems like CI failed despite all tests passing, because some of those tests intentionally trigger errors.

jwfromm merged commit 62d5ad1 into relax on Feb 17, 2023
jwfromm pushed a commit that referenced this pull request Feb 22, 2023
* Initial importer and testing scaffolding.

* Implement matmul operator and tests.

* Add a bunch of new operators.

* Add new ops 

* [Relax][Onnx] Implement Div, Sigmoid, Softmax, Transpose and Unsqueeze ops

* skip test_reshape

* [Relax][ONNX] Implement BiasGelu and Gelu ops

* [Relax][ONNX] Implement Where op

* [Relax][ONNX] Add Multiple ONNX Frontend Support for Clip / Equal / Shape / Not / Tanh (#3)

* Rebase w/ Equal, Not, Tanh, Sqrt, Relu, Clip, Conv, Pow, Erf.

* Fix cumsum but still needs work.

* Fix initializer for CumSum. (#9)

* Add Constant, Squeeze & Sub (#10)

* Add squeeze.

* Add Constant.

* Add sub.

* Support reusing Relay ONNX operator convertors in the Relax ONNX frontend (#8)

* [WIP] Support using Relay ops in the Relax ONNX frontend

Co-authored-by: Matthew Barrett  <[email protected]>
Co-authored-by: Michalis Papadimitriou <[email protected]>

* [WIP] small fixes

Co-authored-by: Matthew Barrett  <[email protected]>
Co-authored-by: Michalis Papadimitriou <[email protected]>

* [WIP] Support dynamic matmul and reshape

Co-authored-by: Matthew Barrett  <[email protected]>
Co-authored-by: Michalis Papadimitriou <[email protected]>

* Address PR comments

---------

Co-authored-by: Matthew Barrett  <[email protected]>
Co-authored-by: Michalis Papadimitriou <[email protected]>

* Add more ops (including all Reduce ops) using the relay frontend (#11)

* [WIP] add more ops. Some fail at the moment

* skip some tests

* Remove duplicate tests for squeeze

* Add Split op in the Relax ONNX frontend (#12)

* [Relax][ONNX] Add Split op

* Remove tmp

* Fix layer normalizations and Shape operator.

* Replace main loop with tvm testing.

* Simplify Slice for opset 13.

* [Relax][ONNX] Implement pad op

* Incorporate pad op, add static constantofshape op.

* Changes to shape to temporarily enable constantofshape in our models.

* Add initial tensor_to_shape implementation.

* Implemented dynamic broadcast_to to support expand and constantofshape.

* Changes sufficient for vortex end to end run.

* Formatting.

* Format tests.

* Re-add broadcast_to shape checking.

* Fix formatting.

* Remove overly strict manipulate check.

* Fix typing

* [Relax][Onnx] Implement Tile operator

* Switch to native relax attention importer.

* Address some of the PR comments

* Check for the imported model IR version

* switch from torch to numpy due to some incompatibility

* Fix make format.

* Clean up typing issues.

* Clarify variable name.

* Remove unneeded comprehension.

* Remove circular dependency.

* Add name sanitization for inputs

* Disable reshape rewrite pass until fixed.

* Fix long comment

* Update cpu image.

---------

Co-authored-by: Florin Blanaru <[email protected]>
Co-authored-by: Xiyou Zhou <[email protected]>
Co-authored-by: Matthew Barrett  <[email protected]>
Co-authored-by: Michalis Papadimitriou <[email protected]>
Co-authored-by: Florin Blanaru <[email protected]>
Co-authored-by: sung <[email protected]>
jwfromm pushed a commit that referenced this pull request Feb 25, 2023
jwfromm pushed a commit that referenced this pull request Feb 28, 2023
jwfromm deleted the TUZ-6 branch on March 7, 2023
vinx13 pushed a commit to vinx13/relax-octo that referenced this pull request Mar 29, 2023