[BYOC][VitisAI] Fix issue in Vitis AI codegen out tensor names matching & update docs and docker #7350

Merged 10 commits on Feb 24, 2021
1 change: 0 additions & 1 deletion docker/Dockerfile.demo_vitis_ai
@@ -20,7 +20,6 @@ FROM xilinx/vitis-ai:latest

RUN apt-get update --fix-missing


COPY install/ubuntu_install_core.sh /install/ubuntu_install_core.sh
RUN bash /install/ubuntu_install_core.sh

Expand Down
12 changes: 6 additions & 6 deletions docker/install/ubuntu_install_vitis_ai_core.sh
@@ -21,9 +21,9 @@ set -u
set -o pipefail

# install libraries for building Vitis-AI on ubuntu
apt-get update && apt-get install -y --no-install-recommends \
graphviz\
gnupg2

apt-get update && apt-get install -y gcc-aarch64-linux-gnu

apt-get update && apt-get install -y \
graphviz \
gnupg2 \
gpg-agent \
gcc-aarch64-linux-gnu \
&& rm -rf /var/lib/apt/lists/*
68 changes: 56 additions & 12 deletions docs/deploy/vitis_ai.rst
@@ -309,7 +309,7 @@ Edge hardware setup
https://github.com/Xilinx/PYNQ/releases/tag/v2.5
2. Follow Pynq instructions for setting up the board: `pynq
setup <https://pynq.readthedocs.io/en/latest/getting_started.html>`__
3. After connecting to the board, make sure to run as root. Execute
3. After connecting to the board, make sure to run as root. **Execute**
``su``
4. Set up DPU on Pynq by following the steps here: `DPU Pynq
setup <https://github.com/Xilinx/DPU-PYNQ>`__
@@ -441,7 +441,7 @@ TVM.
import tvm
import tvm.relay as relay
from tvm.contrib.target import vitis_ai
from tvm.contrib import util, graph_runtime
from tvm.contrib import utils, graph_runtime
from tvm.relay.build_module import bind_params_by_name
from tvm.relay.op.contrib.vitis_ai import annotation

@@ -524,6 +524,8 @@ model in TVM with Vitis-AI at the edge. The first couple of steps will
have to be run on the host machine and take care of quantization and
compilation for deployment at the edge.

A complete ResNet 18 example can be found `here <https://github.com/Xilinx/pyxir/tree/master/examples/tvm>`__.

Host steps
^^^^^^^^^^

@@ -541,20 +543,50 @@ TVM.
import tvm
import tvm.relay as relay
from tvm.contrib.target import vitis_ai
from tvm.contrib import util, graph_runtime
from tvm.contrib import utils, graph_runtime
from tvm.relay.build_module import bind_params_by_name
from tvm.relay.op.contrib.vitis_ai import annotation

After importing a convolutional neural network model using the usual
Relay APIs, annotate the Relay expression for the given Vitis-AI DPU
target and partition the graph.
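
For instance, a model import might look as follows; a minimal sketch assuming an
ONNX model file and input signature (the file name, input name, and shape are
illustrative, and any Relay frontend works equally well):

.. code:: python

   import onnx
   import tvm.relay as relay

   # Hypothetical model file and input signature (assumptions, not from the docs)
   onnx_model = onnx.load("resnet18.onnx")
   shape_dict = {"data": (1, 3, 224, 224)}
   mod, params = relay.frontend.from_onnx(onnx_model, shape=shape_dict)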

.. note::

We recommend switching DPU convolutions' data layouts to NHWC and CPU convolutions'
data layouts to NCHW for best DPU and CPU performance. You can use the ConvertLayout
transformation pass two times to achieve this, as demonstrated in the code block
underneath.

.. code:: python

mod["main"] = bind_params_by_name(mod["main"], params)

# For the edge DPU we recommend switching the convolutions' data layout
# to NHWC for best performance. Therefore, we first convert the layouts
# of all convolutions to NHWC before partitioning. Afterwards, we can
# convert any remaining convolutions (to be executed on CPU) back to NCHW.
desired_layouts = {'nn.conv2d': ['NHWC', 'default']}
seq = tvm.transform.Sequential([relay.transform.RemoveUnusedFunctions(),
relay.transform.ConvertLayout(desired_layouts),
relay.transform.FoldConstant()])
with tvm.transform.PassContext(opt_level=3):
mod = seq(mod)

# Annotate and partition the Relay expression for the given target
mod = annotation(mod, params, target)
mod = relay.transform.MergeCompilerRegions()(mod)
mod = relay.transform.PartitionGraph()(mod)

# After partitioning we recommend transforming the remaining convolutions
# (that will be executed on CPU, if any) back to NCHW data layout
# for best CPU performance
desired_layouts = {'nn.conv2d': ['NCHW', 'default']}
seq = tvm.transform.Sequential([relay.transform.RemoveUnusedFunctions(),
relay.transform.ConvertLayout(desired_layouts),
relay.transform.FoldConstant()])
with tvm.transform.PassContext(opt_level=3):
mod = seq(mod)

Now, we can build the TVM runtime library for executing the model. The
TVM target is 'llvm' as the operations that can't be handled by the DPU
@@ -572,13 +604,9 @@ can be included.

.. code:: python

from tvm.contrib import util

temp = util.tempdir()

tvm_target = 'llvm'
target = 'DPUCZDX8G-zcu104'
export_rt_mod_file = temp.relpath("vitis_ai.rtmod")
export_rt_mod_file = "vitis_ai.rtmod"

with tvm.transform.PassContext(opt_level=3, config={'relay.ext.vitis_ai.options.target': target,
'relay.ext.vitis_ai.options.export_runtime_module': export_rt_mod_file}):
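    # The body of this context is collapsed in the diff; the standard Relay
    # build call would follow, roughly (a sketch, not part of this change):
    lib = relay.build(mod, tvm_target, params=params)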
@@ -604,9 +632,9 @@ Save the TVM lib module so that the Vitis-AI runtime module will also be exported

.. code:: python

from tvm.contrib import util
from tvm.contrib import utils

temp = util.tempdir()
temp = utils.tempdir()
lib.export_library(temp.relpath("tvm_lib.so"))
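
The on-the-fly quantization step that follows here is collapsed in this diff. As a
rough sketch of what it typically looks like (``input_name`` and ``inputs`` are
assumed calibration data, and the sample count of 128 is an assumption as well,
reportedly adjustable through the ``PX_QUANT_SIZE`` environment variable):

.. code:: python

   module = graph_runtime.GraphModule(lib["default"](tvm.cpu()))

   # Feed calibration inputs; Vitis-AI quantizes on-the-fly using these
   # first runs before the model is finally compiled for the DPU.
   for i in range(128):
       module.set_input(input_name, inputs[i])
       module.run()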

After quantizing and compiling the model for Vitis-AI acceleration using the
@@ -638,15 +666,31 @@ Edge steps
^^^^^^^^^^

After setting up TVM with Vitis-AI on the edge device, you can now load
the TVM runtime module into memory and feed inputs for inference.
the TVM runtime module into memory and feed inputs for inference. A nearly
complete runtime script can be found underneath. Make sure to run the script
as root (execute ``su`` in terminal to log into root).


.. note::

You will see a warning about the 'cpu-tf' runtime not being found. This warning is
expected on the board and can be ignored. Note also that you **shouldn't** import the
PyXIR targets in the run script (``import pyxir.contrib.target.DPUCZDX8G``).

.. code:: python

import pyxir
import tvm
from tvm.contrib import graph_runtime

ctx = tvm.cpu()

# input_name = ...
# input_data = ...

# load the module into memory
lib = tvm.runtime.load_module("tvm_dpu_arm.so")

module = graph_runtime.GraphModule(lib["default"](ctx))
module.set_input(name, data)
module.set_input(input_name, input_data)
module.run()
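
# Read back the result once inference completes; a sketch, with the
# output index (0) assumed
prediction = module.get_output(0).asnumpy()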
4 changes: 2 additions & 2 deletions python/tvm/contrib/target/vitis_ai.py
@@ -132,12 +132,12 @@ def vitis_ai_compiler(ref):
layers = xgraph.get_layers()

# Get the output tensor names using XGraph and output Relay ids
out_tensor_names = []
out_tensor_names = [1] * len(output_relay_ids)
for layer in layers:
    if not layer.internal:
        for relay_id in layer.attrs["relay_id"]:
            if relay_id in output_relay_ids:
                out_tensor_names.append(layer.name)
                out_tensor_names[output_relay_ids.index(relay_id)] = layer.name
                break
if not out_tensor_names:
    raise ValueError(
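For context, here is a hypothetical, self-contained illustration (not part of this
PR) of why the positional assignment above matters: with ``append``, the output
names follow XGraph layer iteration order, which need not match the order of the
Relay function's outputs.

.. code:: python

   # Hypothetical data: Relay output ids in function order, and XGraph
   # layers encountered in a different order.
   output_relay_ids = [7, 3]
   layers = [("layer_a", 3), ("layer_b", 7)]

   out_tensor_names = [None] * len(output_relay_ids)
   for name, relay_id in layers:
       if relay_id in output_relay_ids:
           # Place each name at the position of its Relay output instead
           # of appending in iteration order.
           out_tensor_names[output_relay_ids.index(relay_id)] = name

   print(out_tensor_names)  # ['layer_b', 'layer_a'], aligned with the outputs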
4 changes: 4 additions & 0 deletions python/tvm/relay/op/contrib/vitis_ai.py
@@ -85,6 +85,10 @@ def visit_call(self, call):

def annotation(mod, params, target):
    """Annotate Relay expression for Vitis-AI DPU accelerators"""
    # We need type information for supporting models that contain operations that don't
    # have a Relay to XLayer translation
    mod = relay.transform.InferType()(mod)

    xgraph = pyxir.frontend.tvm.from_relay(mod, params, postprocessing=None)
    xgraph = pyxir.partition(xgraph, targets=[target])
