-
Notifications
You must be signed in to change notification settings - Fork 3.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[runtime] AOTExecutor implementation and c target code-generator #10283
Changes from 16 commits
8745a3c
81a4352
3ea361b
2e4c277
e5571e2
9e81293
66c8ef5
b49bea6
033885c
3717f0c
b78b6c4
2384fcc
5ec990f
aef76ca
84e5b92
ccc6c9c
c0be55c
fda6d53
f32ba4b
4f658a6
2d42f22
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||
---|---|---|---|---|
@@ -0,0 +1,26 @@ | ||||
# Licensed to the Apache Software Foundation (ASF) under one | ||||
# or more contributor license agreements. See the NOTICE file | ||||
# distributed with this work for additional information | ||||
# regarding copyright ownership. The ASF licenses this file | ||||
# to you under the Apache License, Version 2.0 (the | ||||
# "License"); you may not use this file except in compliance | ||||
# with the License. You may obtain a copy of the License at | ||||
# | ||||
# http://www.apache.org/licenses/LICENSE-2.0 | ||||
# | ||||
# Unless required by applicable law or agreed to in writing, | ||||
# software distributed under the License is distributed on an | ||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||||
# KIND, either express or implied. See the License for the | ||||
# specific language governing permissions and limitations | ||||
# under the License. | ||||
|
||||
"""This module contains Python wrappers for the TVM C++ Executor implementations. | ||||
NOTE: at present, only AOT Executor is contained here. The others are: | ||||
- GraphExecutor, in python/tvm/contrib/graph_executor.py | ||||
- VM Executor, in python/tvm/runtime/vm.py | ||||
TODO(areusch): Consolidate these into this module. | ||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I just realized that, we have two notions of tvm/python/tvm/relay/build_module.py Line 647 in 55849e6
which is used a lot in the test cases. Do we intend to support There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. yeah that is a good point. added support here. |
||||
""" | ||||
from .aot_executor import AotModule |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,182 @@ | ||
# Licensed to the Apache Software Foundation (ASF) under one | ||
# or more contributor license agreements. See the NOTICE file | ||
# distributed with this work for additional information | ||
# regarding copyright ownership. The ASF licenses this file | ||
# to you under the Apache License, Version 2.0 (the | ||
# "License"); you may not use this file except in compliance | ||
# with the License. You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, | ||
# software distributed under the License is distributed on an | ||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
# KIND, either express or implied. See the License for the | ||
# specific language governing permissions and limitations | ||
# under the License. | ||
"""A Python wrapper for the Module-based Model Runtime Interface for Ahead-of-Time compilation.""" | ||
|
||
import numpy as np | ||
|
||
|
||
class AotModule(object): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. (stylistic) : AotModule --> AOTModule or AoTModule ? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. i kind of find it easier to read the acronyms if we use CapWords, this also follows https://www.python.org/dev/peps/pep-0008/#class-names which is linked from numpydoc |
||
"""Wraps the AOT executor runtime.Module. | ||
|
||
This is a thin wrapper of the underlying TVM module. | ||
you can also directly call set_input, run, and get_output | ||
of underlying module functions | ||
|
||
Parameters | ||
---------- | ||
module : tvm.runtime.Module | ||
The internal tvm module that holds the actual graph functions. | ||
|
||
Attributes | ||
---------- | ||
module : tvm.runtime.Module | ||
The internal tvm module that holds the actual graph functions. | ||
|
||
Examples | ||
-------- | ||
|
||
.. code-block:: python | ||
|
||
import tvm | ||
from tvm import relay | ||
from tvm.contrib import graph_executor | ||
|
||
# build the library using graph executor | ||
lib = relay.build(...) | ||
lib.export_library("compiled_lib.so") | ||
# load it back as a runtime | ||
lib: tvm.runtime.Module = tvm.runtime.load_module("compiled_lib.so") | ||
# Call the library factory function for default and create | ||
# a new runtime.Module, wrap with graph module. | ||
gmod = graph_executor.GraphModule(lib["default"](dev)) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Need update? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. fixed, thanks! |
||
# use the graph module. | ||
gmod.set_input("x", data) | ||
gmod.run() | ||
""" | ||
|
||
def __init__(self, module): | ||
self.module = module | ||
self._set_input = module["set_input"] | ||
self._run = module["run"] | ||
self._get_output = module["get_output"] | ||
self._get_input = module["get_input"] | ||
self._get_num_outputs = module["get_num_outputs"] | ||
self._get_input_index = module["get_input_index"] | ||
self._get_num_inputs = module["get_num_inputs"] | ||
|
||
def set_input(self, key=None, value=None, **params): | ||
"""Set inputs to the module via kwargs | ||
|
||
Parameters | ||
---------- | ||
key : int or str | ||
The input key | ||
|
||
value : the input value. | ||
The input key | ||
|
||
params : dict of str to NDArray | ||
Additional arguments | ||
""" | ||
if key is not None: | ||
v = self._get_input(key) | ||
if v is None: | ||
raise RuntimeError("Could not find '%s' in graph's inputs" % key) | ||
v.copyfrom(value) | ||
|
||
if params: | ||
# upload big arrays first to avoid memory issue in rpc mode | ||
keys = list(params.keys()) | ||
keys.sort(key=lambda x: -np.prod(params[x].shape)) | ||
for k in keys: | ||
# TODO(zhiics) Skip the weights for submodule in a better way. | ||
# We should use MetadataModule for initialization and remove | ||
# params from set_input | ||
val = self._get_input(k) | ||
if val: | ||
self._get_input(k).copyfrom(params[k]) | ||
|
||
def run(self, **input_dict): | ||
"""Run forward execution of the graph | ||
|
||
Parameters | ||
---------- | ||
input_dict: dict of str to NDArray | ||
List of input values to be feed to | ||
""" | ||
if input_dict: | ||
self.set_input(**input_dict) | ||
self._run() | ||
|
||
def get_num_outputs(self): | ||
"""Get the number of outputs from the graph | ||
|
||
Returns | ||
------- | ||
count : int | ||
The number of outputs. | ||
""" | ||
return self._get_num_outputs() | ||
|
||
def get_num_inputs(self): | ||
"""Get the number of inputs to the graph | ||
|
||
Returns | ||
------- | ||
count : int | ||
The number of inputs. | ||
""" | ||
return self._get_num_inputs() | ||
|
||
def get_input(self, index, out=None): | ||
"""Get index-th input to out | ||
|
||
Parameters | ||
---------- | ||
index : int | ||
The input index | ||
|
||
out : NDArray | ||
The output array container | ||
""" | ||
if out: | ||
self._get_input(index).copyto(out) | ||
return out | ||
|
||
return self._get_input(index) | ||
|
||
def get_input_index(self, name): | ||
"""Get inputs index via input name. | ||
|
||
Parameters | ||
---------- | ||
name : str | ||
The input key name | ||
|
||
Returns | ||
------- | ||
index: int | ||
The input index. -1 will be returned if the given input name is not found. | ||
""" | ||
return self._get_input_index(name) | ||
|
||
def get_output(self, index, out=None): | ||
"""Get index-th output to out | ||
|
||
Parameters | ||
---------- | ||
index : int | ||
The output index | ||
|
||
out : NDArray | ||
The output array container | ||
""" | ||
if out: | ||
self._get_output(index, out) | ||
return out | ||
|
||
return self._get_output(index) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You are not using
run_model
anymore, are you? In your previous branch indeed I'm seeingtvmgen_default_run_model
generated, but after rebase it is replaced withtvmgen_default___tvm_main__
.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
That is right!.
We had run (entry_point), run_model and tvm_main.
The entry_point is supposed to call run_model, however, run_model is identical to tvm_main -- therefore, it was removed due to no need of maintaining two symbols for relay and tir versions of main.
So I think it needs to be tvm_main now.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
changed to just "AOT main" since the function should probably eventually be renamed based on
mod_name
.