From bece73757b49caee97ab688e80aaa12bd2cd9c01 Mon Sep 17 00:00:00 2001
From: Xingyu Zhou <zhoxingy@amazon.com>
Date: Mon, 25 Nov 2019 15:07:39 -0800
Subject: [PATCH] merge from apache/tvm to neo-ai/tvm (#59)

* [TOPI][OP] Support Faster-RCNN Proposal OP on CPU (#4297)

* Support Proposal operator on CPU.

* PyLint space issue

* PyLint space issue

* Pylint singleton-comparison issue

* [QNN][Legalize] Specialize for Platforms without any fast Int8 arithmetic units. (#4307)
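
A rough, illustrative sketch (not part of this patch) of what the new fallback legalization emits for qnn.conv2d on CPUs without VNNI/dotprod, mirroring helper_no_fast_int8_hw_legalization added in python/tvm/relay/qnn/op/legalizations.py below; the function name here is hypothetical:

    from tvm import relay

    def int16_fallback(data, kernel, input_zp, kernel_zp, **conv_attrs):
        # Upcast both operands to int16 and fold in the zero points so the
        # plain nn.conv2d can be used instead of qnn.conv2d.
        shifted_data = relay.subtract(relay.cast(data, "int16"),
                                      relay.const(input_zp, "int16"))
        shifted_kernel = relay.subtract(relay.cast(kernel, "int16"),
                                        relay.const(kernel_zp, "int16"))
        return relay.nn.conv2d(shifted_data, shifted_kernel, **conv_attrs)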

* fix error when memory_id is VTA_MEM_ID_OUT (#4330)

* [CI][DOCKER] Add ONNX runtime dep (#4314)

* [DOCKER] Add ONNX runtime dep

* Improve ci script

* [QNN] Quantize - Fixing the sequence of lowering. (#4316)

* [QNN] Use Int16 upcast in Fallback Conv2D. Fix test names. (#4329)

* [doc][fix] fix sphinx parsing for pass infra tutorial (#4337)

* change ci image version (#4313)

* [Codegen] remove fp16 function override for cuda  (#4331)

* add volatile override back

* [codegen] remove fp16 function override for cuda

* [CI] Set workspace to be per executor (#4336)

* [Build][Windows] Fix Windows build by including cctype (#4319)

* Fix build

* dummy change to retrigger CI

* dummy change to retrigger ci

* dummy change to retrigger ci

* Enable hipModuleGetGlobal() (#4321)

* [Relay][Pass] Add pass to remove unused functions in relay module (#4334)

* [Relay][Pass] Add pass to remove unused functions in relay module

* Add tests

* Fix lint

* Fix visit order

* Add pass argument

* Fix
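
The pass is exposed to Python as relay.transform.RemoveUnusedFunctions (see python/tvm/relay/transform.py below). A minimal usage sketch, assuming the Relay module API of this revision; the function names are made up for illustration:

    from tvm import relay

    mod = relay.Module()
    x = relay.var("x", shape=(2, 2), dtype="float32")
    mod["main"] = relay.Function([x], x + x)
    y = relay.var("y", shape=(2, 2), dtype="float32")
    mod["orphan"] = relay.Function([y], y * y)  # never reachable from main

    # Drop every global function not reachable from the listed entry points.
    mod = relay.transform.RemoveUnusedFunctions(["main"])(mod)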

* Add support for quant. mul operator in tflite frontend (#4283)

A test for qnn_mul has to be added when the qnn elemwise tests (#4282) get merged.

* Add topi.nn.fifo_buffer to TVM doc (#4343)

* Solve custom model of prelu (#4326)

* Deprecate NNVM warning msg (#4333)

* [Contrib] Add MKL DNN option (#4323)

* [Contrib] Add MKL DNN

* update

* update

* [Relay][Frontend][TF] Fix transpose when axes is not a param (#4327)

* [Relay][Frontend][TF] Use _infer_value_simulated when axes is not a const to Transpose

* uncomment tests

* dummy change to retrigger ci

* [RUNTIME] Add device query for AMD GcnArch (#4341)

* add gcnArch query

* kGcnArch query for cuda is a no-op

* [Test][Relay][Pass] Add test case for lambda lift (#4317)

* [Relay][Frontend][ONNX] operator support: DepthToSpace, SpaceToDepth (#4271)
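
Both converters lower to a reshape/transpose/reshape sequence (see python/tvm/relay/frontend/onnx.py below). A small NumPy sketch of the DCR-mode DepthToSpace equivalence, for illustration only:

    import numpy as np

    x = np.arange(1 * 8 * 2 * 2, dtype="float32").reshape(1, 8, 2, 2)  # NCHW
    bs = 2
    n, c, h, w = x.shape
    # Split the channel axis into two block axes, interleave them with H and W,
    # then merge them into the spatial dimensions.
    y = x.reshape(n, bs, bs, c // (bs * bs), h, w)
    y = y.transpose(0, 3, 4, 1, 5, 2)
    y = y.reshape(n, c // (bs * bs), h * bs, w * bs)  # (1, 2, 4, 4)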

* imp module is deprecated (#4275)

* [VTA] Bug fix for padded load with large inputs (#4293)

* bug fix for padded load with large inputs

* Update TensorLoad.scala

* Update test_vta_insn.py

* fix inconsistent tag name (#4134)

* [CodeGen] Add build config option disable_assert to control whether to generate assert (#4340)
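
A usage sketch of the new knob (the schedule below is a placeholder example, not from this patch); when disable_assert is set, AssertStmt nodes are stripped from the lowered IR by the new src/pass/skip_assert.cc pass:

    import tvm

    n = tvm.var("n")
    A = tvm.placeholder((n,), name="A")
    B = tvm.compute((n,), lambda i: A[i] + 1.0, name="B")
    s = tvm.create_schedule(B.op)

    with tvm.build_config(disable_assert=True):
        f = tvm.build(s, [A, B], target="llvm")  # no assert stmts in the output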

* Bump up CUDA log version in tophub.py (#4347)

* Add check to ensure input file was successfully opened in NNVM deploy code demo (#4315)

* [COMMUNITY] Add DISCLAIMER, KEYS for ASF release (#4345)

* [COMMUNITY] Add DISCLAIMER, KEYS for ASF release

* Add file name spec

* [Relay][VM][Interpreter] Enable first-class constructors in VM and interpreter via eta expansion (#4218)

* Fix constructor pretty printing

* Make Module::HasDef name consistent with API

* Add VM constructor compilation via eta expansion

* Lint

* Fix CI

* Fix failing test

* Address comment

* Retrigger CI

* Retrigger CI
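
A usage sketch of the reworked pass interface (matching the Python wrapper change in python/tvm/relay/transform.py below); the prelude is used here only as a convenient source of constructors:

    from tvm import relay
    from tvm.relay.prelude import Prelude

    mod = relay.Module()
    Prelude(mod)  # brings in List, Cons, Nil, @foldr, ...

    # Constructors such as Cons may now appear in argument position (e.g. the
    # @concat/@rev rewrites in prelude.rly below); eta expansion wraps them in
    # ordinary functions before VM compilation or interpretation.
    mod = relay.transform.EtaExpand(expand_constructor=True,
                                    expand_global_var=False)(mod)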

* Update dmlc_tvm_commit_id.txt
---
 CMakeLists.txt                                |   1 +
 DISCLAIMER                                    |  12 +
 Jenkinsfile                                   |  24 +-
 KEYS                                          |  74 +++
 NOTICE                                        |   7 +-
 cmake/config.cmake                            |   3 +
 cmake/modules/contrib/BLAS.cmake              |   7 +
 dmlc_tvm_commit_id.txt                        |   2 +-
 docker/install/ubuntu_install_onnx.sh         |   1 +
 docs/api/python/topi.rst                      |   2 +
 docs/deploy/nnvm.md                           |   8 +-
 include/tvm/build_module.h                    |   4 +
 include/tvm/ir_pass.h                         |   7 +
 include/tvm/relay/module.h                    |  14 +-
 include/tvm/relay/transform.h                 |   9 +-
 include/tvm/runtime/device_api.h              |   3 +-
 nnvm/python/nnvm/__init__.py                  |   4 +
 python/tvm/autotvm/tophub.py                  |   2 +-
 python/tvm/build_module.py                    |   3 +-
 python/tvm/hybrid/module.py                   |   9 +-
 python/tvm/relay/frontend/onnx.py             |  72 +++
 python/tvm/relay/frontend/tensorflow.py       |   4 +-
 python/tvm/relay/frontend/tflite.py           |   6 +-
 python/tvm/relay/qnn/op/legalizations.py      | 194 ++++++--
 python/tvm/relay/std/prelude.rly              |  15 +-
 python/tvm/relay/transform.py                 |  31 +-
 src/codegen/build_module.cc                   |   1 +
 src/codegen/codegen.cc                        |  12 +-
 src/codegen/codegen_cuda.cc                   |  22 +-
 src/codegen/literal/cuda_half_t.h             |   3 +-
 src/codegen/llvm/codegen_amdgpu.cc            |   2 +-
 src/common/util.h                             |   1 +
 src/pass/skip_assert.cc                       |  47 ++
 src/relay/backend/interpreter.cc              |  11 +
 src/relay/backend/vm/compiler.cc              |   7 +
 src/relay/backend/vm/lambda_lift.cc           |  22 +-
 src/relay/backend/vm/removed_unused_funcs.cc  | 134 ++++++
 src/relay/ir/alpha_equal.cc                   |   2 +-
 src/relay/ir/module.cc                        |  12 +-
 src/relay/ir/pretty_printer.cc                |   6 +-
 src/relay/pass/eta_expand.cc                  | 159 +++++--
 src/relay/pass/type_infer.cc                  |   2 +-
 src/relay/qnn/op/convolution.cc               |  41 +-
 src/relay/qnn/op/quantize.cc                  |  10 +-
 src/runtime/contrib/cblas/cblas.cc            |  10 +
 src/runtime/cuda/cuda_device_api.cc           |   1 +
 src/runtime/metal/metal_device_api.mm         |   1 +
 src/runtime/opencl/opencl_device_api.cc       |   1 +
 src/runtime/opengl/opengl_device_api.cc       |   1 +
 src/runtime/rocm/rocm_device_api.cc           |  14 +-
 src/runtime/rocm/rocm_module.cc               |  11 +-
 src/runtime/vulkan/vulkan.cc                  |   2 +
 tests/lint/check_file_type.py                 |   2 +
 tests/python/frontend/onnx/test_forward.py    |  71 ++-
 .../frontend/tensorflow/test_forward.py       |  18 +
 tests/python/frontend/tflite/test_forward.py  |   6 +-
 tests/python/relay/test_ir_text_printer.py    |  22 +
 tests/python/relay/test_op_level2.py          |  48 ++
 tests/python/relay/test_op_level5.py          |   2 +-
 tests/python/relay/test_op_qnn_conv2d.py      |  52 +-
 .../{test_qnn_mul.py => test_op_qnn_mul.py}   |   0
 tests/python/relay/test_op_qnn_requantize.py  | 445 +++++++++---------
 tests/python/relay/test_pass_eta_expand.py    |  75 ++-
 tests/python/relay/test_pass_lambda_lift.py   |  40 ++
 tests/python/relay/test_pass_qnn_legalize.py  | 153 +++++-
 .../test_pass_remove_unused_functions.py      |  75 +++
 tests/scripts/task_python_frontend.sh         |   1 +
 tests/scripts/task_python_integration.sh      |   2 +
 topi/python/topi/arm_cpu/depthwise_conv2d.py  |   4 +-
 topi/python/topi/cpp/__init__.py              |   9 +
 topi/python/topi/cpp/cuda.py                  |  21 +
 topi/python/topi/cpp/generic.py               |  21 +
 topi/python/topi/cpp/image.py                 |  21 +
 topi/python/topi/{cpp.py => cpp/impl.py}      |  28 +-
 topi/python/topi/cpp/nn.py                    |  21 +
 topi/python/topi/cpp/rocm.py                  |  21 +
 topi/python/topi/cpp/vision/__init__.py       |   7 +
 topi/python/topi/cpp/vision/yolo.py           |  21 +
 topi/python/topi/cpp/x86.py                   |  21 +
 topi/python/topi/vision/rcnn/proposal.py      | 283 ++++++++++-
 topi/python/topi/x86/dense.py                 |   2 +-
 topi/tests/python/test_topi_vision.py         |   2 +-
 tutorials/dev/relay_pass_infra.py             |   2 +-
 .../src/main/scala/core/TensorLoad.scala      |  29 +-
 vta/src/runtime.cc                            |   2 +-
 vta/tests/python/unittest/test_vta_insn.py    | 116 ++---
 86 files changed, 2120 insertions(+), 576 deletions(-)
 create mode 100644 DISCLAIMER
 create mode 100644 KEYS
 create mode 100644 src/pass/skip_assert.cc
 create mode 100644 src/relay/backend/vm/removed_unused_funcs.cc
 rename tests/python/relay/{test_qnn_mul.py => test_op_qnn_mul.py} (100%)
 create mode 100644 tests/python/relay/test_pass_lambda_lift.py
 create mode 100644 tests/python/relay/test_pass_remove_unused_functions.py
 create mode 100644 topi/python/topi/cpp/__init__.py
 create mode 100644 topi/python/topi/cpp/cuda.py
 create mode 100644 topi/python/topi/cpp/generic.py
 create mode 100644 topi/python/topi/cpp/image.py
 rename topi/python/topi/{cpp.py => cpp/impl.py} (64%)
 create mode 100644 topi/python/topi/cpp/nn.py
 create mode 100644 topi/python/topi/cpp/rocm.py
 create mode 100644 topi/python/topi/cpp/vision/__init__.py
 create mode 100644 topi/python/topi/cpp/vision/yolo.py
 create mode 100644 topi/python/topi/cpp/x86.py

diff --git a/CMakeLists.txt b/CMakeLists.txt
index 25489d07fd68..9c94f29ac4f8 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -53,6 +53,7 @@ tvm_option(PICOJSON_PATH "Path to PicoJSON" "3rdparty/picojson")
 # Contrib library options
 tvm_option(USE_BLAS "The blas library to be linked" none)
 tvm_option(USE_MKL_PATH "MKL root path when use MKL blas" none)
+tvm_option(USE_MKLDNN "Build with MKLDNN" OFF)
 tvm_option(USE_CUDNN "Build with cuDNN" OFF)
 tvm_option(USE_CUBLAS "Build with cuBLAS" OFF)
 tvm_option(USE_MIOPEN "Build with ROCM:MIOpen" OFF)
diff --git a/DISCLAIMER b/DISCLAIMER
new file mode 100644
index 000000000000..986b2c84f6b4
--- /dev/null
+++ b/DISCLAIMER
@@ -0,0 +1,12 @@
+Apache TVM (incubating) is an effort undergoing incubation at The
+Apache Software Foundation (ASF), sponsored by the Apache Incubator PMC.
+
+Incubation is required of all newly accepted
+projects until a further review indicates that the
+infrastructure, communications, and decision making process have
+stabilized in a manner consistent with other successful ASF
+projects.
+
+While incubation status is not necessarily a reflection
+of the completeness or stability of the code, it does indicate
+that the project has yet to be fully endorsed by the ASF.
diff --git a/Jenkinsfile b/Jenkinsfile
index 17f9a5669d03..10073278ded0 100644
--- a/Jenkinsfile
+++ b/Jenkinsfile
@@ -45,7 +45,7 @@
 //
 
 ci_lint = "tvmai/ci-lint:v0.51"
-ci_gpu = "tvmai/ci-gpu:v0.55"
+ci_gpu = "tvmai/ci-gpu:v0.56"
 ci_cpu = "tvmai/ci-cpu:v0.54"
 ci_i386 = "tvmai/ci-i386:v0.52"
 
@@ -64,6 +64,8 @@ docker_run = 'docker/bash.sh'
 // timeout in minutes
 max_time = 120
 
+workspace = "workspace/exec_${env.EXECUTOR_NUMBER}"
+
 // initialize source codes
 def init_git() {
   checkout scm
@@ -86,7 +88,7 @@ def init_git_win() {
 stage("Sanity Check") {
   timeout(time: max_time, unit: 'MINUTES') {
     node('CPU') {
-      ws('workspace/tvm/sanity') {
+      ws("${workspace}/tvm/sanity") {
         init_git()
         sh "${docker_run} ${ci_lint}  ./tests/scripts/task_lint.sh"
       }
@@ -134,7 +136,7 @@ def unpack_lib(name, libs) {
 stage('Build') {
   parallel 'BUILD: GPU': {
     node('GPUBUILD') {
-      ws('workspace/tvm/build-gpu') {
+      ws("${workspace}/tvm/build-gpu") {
         init_git()
         sh """
            mkdir -p build
@@ -182,7 +184,7 @@ stage('Build') {
   },
   'BUILD: CPU': {
     node('CPU') {
-      ws('workspace/tvm/build-cpu') {
+      ws("${workspace}/tvm/build-cpu") {
         init_git()
         sh """
            mkdir -p build
@@ -213,7 +215,7 @@ stage('Build') {
   },
   'BUILD : i386': {
     node('CPU') {
-      ws('workspace/tvm/build-i386') {
+      ws("${workspace}/tvm/build-i386") {
         init_git()
         sh """
            mkdir -p build
@@ -238,7 +240,7 @@ stage('Build') {
 stage('Unit Test') {
   parallel 'python3: GPU': {
     node('TensorCore') {
-      ws('workspace/tvm/ut-python-gpu') {
+      ws("${workspace}/tvm/ut-python-gpu") {
         init_git()
         unpack_lib('gpu', tvm_multilib)
         timeout(time: max_time, unit: 'MINUTES') {
@@ -250,7 +252,7 @@ stage('Unit Test') {
   },
   'python3: i386': {
     node('CPU') {
-      ws('workspace/tvm/ut-python-i386') {
+      ws("${workspace}/tvm/ut-python-i386") {
         init_git()
         unpack_lib('i386', tvm_multilib)
         timeout(time: max_time, unit: 'MINUTES') {
@@ -263,7 +265,7 @@ stage('Unit Test') {
   },
   'java: GPU': {
     node('GPU') {
-      ws('workspace/tvm/ut-java') {
+      ws("${workspace}/tvm/ut-java") {
         init_git()
         unpack_lib('gpu', tvm_multilib)
         timeout(time: max_time, unit: 'MINUTES') {
@@ -277,7 +279,7 @@ stage('Unit Test') {
 stage('Integration Test') {
   parallel 'topi: GPU': {
     node('GPU') {
-      ws('workspace/tvm/topi-python-gpu') {
+      ws("${workspace}/tvm/topi-python-gpu") {
         init_git()
         unpack_lib('gpu', tvm_multilib)
         timeout(time: max_time, unit: 'MINUTES') {
@@ -288,7 +290,7 @@ stage('Integration Test') {
   },
   'frontend: GPU': {
     node('GPU') {
-      ws('workspace/tvm/frontend-python-gpu') {
+      ws("${workspace}/tvm/frontend-python-gpu") {
         init_git()
         unpack_lib('gpu', tvm_multilib)
         timeout(time: max_time, unit: 'MINUTES') {
@@ -299,7 +301,7 @@ stage('Integration Test') {
   },
   'legacy: GPU': {
     node('GPU') {
-      ws('workspace/tvm/legacy-python-gpu') {
+      ws("${workspace}/tvm/legacy-python-gpu") {
         init_git()
         unpack_lib('gpu', tvm_multilib)
         timeout(time: max_time, unit: 'MINUTES') {
diff --git a/KEYS b/KEYS
new file mode 100644
index 000000000000..5395d5eef6c9
--- /dev/null
+++ b/KEYS
@@ -0,0 +1,74 @@
+This file contains the PGP keys of various developers.
+Please don't use them for email unless you have to. Their main
+purpose is code signing.
+
+Examples of importing this file in your keystore:
+ gpg --import KEYS.txt
+ (need pgp and other examples here)
+
+Examples of adding your key to this file:
+ pgp -kxa <your name> and append it to this file.
+ (pgpk -ll <your name> && pgpk -xa <your name>) >> this file.
+ (gpg --list-sigs <your name>
+     && gpg --armor --export <your name>) >> this file.
+
+-----------------------------------------------------------------------------------
+pub   rsa4096 2019-11-15 [SC]
+      EF52D68AD5276994249816836754EA97C55E3DEB
+uid           [ultimate] Tianqi Chen (CODE SIGNING KEY) <tqchen@apache.org>
+sig 3        6754EA97C55E3DEB 2019-11-15  Tianqi Chen (CODE SIGNING KEY) <tqchen@apache.org>
+sub   rsa4096 2019-11-15 [E]
+sig          6754EA97C55E3DEB 2019-11-15  Tianqi Chen (CODE SIGNING KEY) <tqchen@apache.org>
+
+-----BEGIN PGP PUBLIC KEY BLOCK-----
+
+mQINBF3OK24BEADD4hxjrsgb4jIDIACHS15X+5YP/YaUF5UDDQs/bNn/xGJGVl4/
+4sJ6qKZcvMDrWTmnNItYBuaHi1qhGvlcASBekm/9PU2U8lZmAF1lZkKIIYZkX+If
+s8PEYurE8cDr65orrdsFF8Zwb+u6x+gMsHNivsU2Kn3xbQjGmeW44UA+aaXzcJp6
+sVk3aX5DypoYJNBmbASyOjZVWkcrJ+NKEfJ1dKtka5/siqOjuvCd8NT5dJVhZbm3
+Sf8iclEMqog1LhdI/FhE2fB3C5hJkzcinq2v55qDaGqsL+qgT7agf9b4t0EgjbVh
+cs6jlCglad+Oz27BQIjt06HE1OB5T/Gxa080FK4JZMpxZJ5tDA2/7DQM2MyN84z/
+s62JuBJnsrzr4w8D/QcAyzAmyzAqvxLR/aqLgJTIcQiw6AenHovKkNbEQOBYE2T5
+ms7uVO2E2Tv42J4Te4OKhpId9mK+7elCLvOb2DfAJDdYxDN9c8dJTls+G6xmv0h9
+bb2+QRjkpDiFeu1hKNEe0/ST/YXDfRYpKl+1t/QZ+JccLgEdEwuo/IQ1e4POH2h0
+Zqvy7TR5obeTf0TvmLzW+i3s1oUkmSAnQEncSGnGnlugYk0BLuMMi9Fhx6qcC5pC
+cA3nsRqFKebtnpop+m+psFkmd//xKSXJt9IYVEbQVNiUKm9uYq6RxZEAmQARAQAB
+tDJUaWFucWkgQ2hlbiAoQ09ERSBTSUdOSU5HIEtFWSkgPHRxY2hlbkBhcGFjaGUu
+b3JnPokCTgQTAQgAOBYhBO9S1orVJ2mUJJgWg2dU6pfFXj3rBQJdzituAhsDBQsJ
+CAcCBhUKCQgLAgQWAgMBAh4BAheAAAoJEGdU6pfFXj3rVJIQALBArXEaFDdTw8wl
+65nPLU6+QPc6eMn7mz6BDp1V7xL6Lq1GbArLpmQHIFhfQ/5Qmg80wuFBU1CNSRHd
+tdZq3v8tB9Txvhy6bLQ+IijWH/TxSEPqnrkNsWBQLqAygDC5O3Ook/T6B5kuc176
+Kz+w+YhzPS5hoPfJK6xGoKDNlkhmI/EnUjAq459VNpXeoeemiydzvApiCHH0VfOj
+XnmgAJsAJA21EfT5Wuh/WODsf0HkaXB0xoWZfE/ugIQBLhZi9nUTYgwU2r4a+v4A
+4C2T1OyJ3mDU+Oi/z6d0WJvsIrLCFcF4Q7b/6+MGkgLDGlsEKK2LZMrulGzQ1QY/
+O4ck3dVDseqT2urplrTamDIh1IQmOt1FqMFwugdjfQwJ5HQeX6IeUGZei2Av/IZR
+8Vw5Wxtm1Aksz3Js6iP3QmAh7txDUKO+eT5zLSXBoPmkleLnvCdtlvwaSNCAudHw
+12h10IV286OetJvyyjmh/q/30sKNGiuucLMzPMwtLNW/j3cts3fqRHIHxepT6m94
+FoYIlwVu4afiGgSi/7cN4p9GgfwnFGeETd25pgNG0KdXbVWniO1dTEKzOtvtuPYK
+Y88ZAfdOgj4dyeI9ZnJV8RaZvpImDPVHGQm69/071jBxyWZnVi/YtOm+DjHfw0Vi
+uiUdzoIb54oWW8tbiNg/nfiLUaJBuQINBF3OK24BEAC9W8Cwubu4Dpr4m0IIrLF5
+zRRqQm9QIcEC0QHf6w1c2NWQTJP+MQY/jZLjtKw5yCQDghT+qsil2p8xCM0EqRd6
+6NqxsAoweTCoV0MwolQv5T3KuP54SlNWjO+6gT73LkKuOHoIyy5cS9pIITlExHy+
+XHtfQi1keDpWUEyvSRG9slu1DcxAeo6nFEpCuoQ+xx/lrCMxDlyZJCDhj2fXs2hK
+8oKLV5NbIuifbXbCiOvZUdBHk0yLCEc6wNsVR30yLijSiPCKsAPcsG0PjQnz3eTb
+0czq+6g50zUVOTioUghIlZ1DhCsxQGnlxoLY71pnmc7qVszdXPV2Mp7/KSIhDJFQ
+LN0enDVz9aRXfpEK3SifxaPVNd61O/BGziza+XCK5qpEQL95UM2NdQCWixYmIOJE
+k95tpnagtNupMkrY6WEa0CjVBzF1kdr5WpeUd6w85rA/opcqpQ8yLmvpyJ4tXZhN
+7oAWZSUzyB904FMswUEhaS7pEJIlACeFcPwm31Jv/637gw1CopZpDxDUaW5/boG5
+9Gp9D/GV2gyMrHAcwA1gZSbmolv5ZYcnUmwTPijVNZ+o70HBbvbNZqziPgy9G+L/
+oGBkY/fpg7qfaGtAbOUbx1ck04CbafSUQIxpCG8in6zwrIRnn4uj6q4wIZ8SnvQ0
+h3Ug0DmdsxvB/xdfillH/QARAQABiQI2BBgBCAAgFiEE71LWitUnaZQkmBaDZ1Tq
+l8VePesFAl3OK24CGwwACgkQZ1Tql8VePeuZ1Q//csRsGDKNrW5e0EitEcfPZ0PC
+teEw7A16dniXiCQF39KxxLzjCjUq7U8iWNm7bn1zdXcSVYZow+i5hFWXgZLKTKep
+tQoocJmQ7kPV5oiTBewFy9T4BICUekj/EhXhSz1wxb3GSc+uHL2IUlFkixTY4k4B
+9zq49gkNkTM02Or3quu1ZWAgeol1BSyV0tcI1h3M0OXtrN6idLyzQJFRyMYtzfwp
+Pd2+hdaKAl8mKANs/GMJni3QvyVXzuJxMP6SNOFx4mWj0UVFVZvosv1lLXDesvwY
+sNZmz5IkfuU4DHz1ZzZc3sThkpBdBiadvyKtNsenNh5nEXtwVhpiFf3IdZAvG7Ks
+7i3Fx1/ObbvxMCWeFoB6oP/swHr9i6dqntiJoB6Gl5y1ye3qte8PiNuwRVhz+YOK
+58Ga3wWMvODpi2AgSFv7cd1OFXXsoonORfmpcfAp+h6dIr/ttQMP2929/NoX3Cs4
+/pXoG9L5EOpMfj0Q24sAGW8VzuCAHL3e7QSijFuSHZxz9oe4C28/mAY+KP0dif0Q
+O3rq4kpqlhseyzcRyE1LWBvzuCeSTui2OPmyivFY57TOPnMHm5sXVby1VUiwm0B0
+RgBtZDRLv765lAFGtp43sccZ7zfRaKhkVmzh3bAZ62nJyQNGw0TWg96Pf7Kjb0Bv
+ha8fS9ysWDy/Ye65MP4=
+=MSiP
+-----END PGP PUBLIC KEY BLOCK-----
diff --git a/NOTICE b/NOTICE
index 45468c50ba1b..4f6447124e97 100644
--- a/NOTICE
+++ b/NOTICE
@@ -1 +1,6 @@
-TVM End to End Deep Learning Compiler Stack: https://tvm.ai/
+Apache TVM (incubating)
+Copyright 2017 and onwards The Apache Software Foundation
+
+This product includes software developed at
+The Apache Software Foundation (http://www.apache.org/).
+
diff --git a/cmake/config.cmake b/cmake/config.cmake
index 8ca6d6f58bd5..71666541eec5 100644
--- a/cmake/config.cmake
+++ b/cmake/config.cmake
@@ -115,6 +115,9 @@ set(USE_BLAS none)
 # set(USE_MKL_PATH <path to venv or site-packages directory>) if using `pip install mkl`
 set(USE_MKL_PATH none)
 
+# Whether use MKLDNN library
+set(USE_MKLDNN OFF)
+
 # Whether use OpenMP thread pool, choices: gnu, intel
 # Note: "gnu" uses gomp library, "intel" uses iomp5 library
 set(USE_OPENMP none)
diff --git a/cmake/modules/contrib/BLAS.cmake b/cmake/modules/contrib/BLAS.cmake
index 6a5828749762..bd8c0d0c445f 100644
--- a/cmake/modules/contrib/BLAS.cmake
+++ b/cmake/modules/contrib/BLAS.cmake
@@ -55,3 +55,10 @@ elseif(USE_BLAS STREQUAL "none")
 else()
   message(FATAL_ERROR "Invalid option: USE_BLAS=" ${USE_BLAS})
 endif()
+
+if(USE_MKLDNN STREQUAL "ON")
+  find_library(BLAS_LIBRARY_MKLDNN dnnl)
+  list(APPEND TVM_RUNTIME_LINKER_LIBS ${BLAS_LIBRARY_MKLDNN})
+  add_definitions(-DUSE_DNNL=1)
+  message(STATUS "Use MKLDNN library " ${BLAS_LIBRARY_MKLDNN})
+endif()
diff --git a/dmlc_tvm_commit_id.txt b/dmlc_tvm_commit_id.txt
index 979933ddb488..2bd2b4b9a515 100644
--- a/dmlc_tvm_commit_id.txt
+++ b/dmlc_tvm_commit_id.txt
@@ -1 +1 @@
-e541c75863775f9011a658a36b86f084133bfbb7
+2c5c4da697753ca79ea1551cc91c3072cecbbbb1
diff --git a/docker/install/ubuntu_install_onnx.sh b/docker/install/ubuntu_install_onnx.sh
index 54210b83f4d6..a915ca02c05a 100755
--- a/docker/install/ubuntu_install_onnx.sh
+++ b/docker/install/ubuntu_install_onnx.sh
@@ -22,6 +22,7 @@ set -o pipefail
 
 # fix to certain version for now
 pip3 install onnx==1.5.0
+pip3 install onnxruntime==1.0.0
 
 # torch depends on a number of other packages, but unhelpfully, does
 # not expose that in the wheel!!!
diff --git a/docs/api/python/topi.rst b/docs/api/python/topi.rst
index 0e203c176711..3cfebb8503e4 100644
--- a/docs/api/python/topi.rst
+++ b/docs/api/python/topi.rst
@@ -69,6 +69,7 @@ List of operators
    topi.nn.conv2d_hwcn
    topi.nn.depthwise_conv2d_nchw
    topi.nn.depthwise_conv2d_nhwc
+   topi.nn.fifo_buffer
    topi.max
    topi.sum
    topi.min
@@ -199,6 +200,7 @@ topi.nn
 .. autofunction:: topi.nn.conv2d_hwcn
 .. autofunction:: topi.nn.depthwise_conv2d_nchw
 .. autofunction:: topi.nn.depthwise_conv2d_nhwc
+.. autofunction:: topi.nn.fifo_buffer
 
 topi.image
 ~~~~~~~~~~
diff --git a/docs/deploy/nnvm.md b/docs/deploy/nnvm.md
index 4040de35ea54..650912231b12 100644
--- a/docs/deploy/nnvm.md
+++ b/docs/deploy/nnvm.md
@@ -59,9 +59,11 @@ An example in c++.
 #include <tvm/runtime/registry.h>
 #include <tvm/runtime/packed_func.h>
 
+#include <algorithm>
 #include <fstream>
 #include <iterator>
-#include <algorithm>
+#include <stdexcept>
+#include <string>
 
 int main()
 {
@@ -97,7 +99,9 @@ int main()
     int64_t in_shape[4] = {1, 3, 224, 224};
     TVMArrayAlloc(in_shape, in_ndim, dtype_code, dtype_bits, dtype_lanes, device_type, device_id, &x);
     // load image data saved in binary
-    std::ifstream data_fin("cat.bin", std::ios::binary);
+    const std::string data_filename = "cat.bin";
+    std::ifstream data_fin(data_filename, std::ios::binary);
+    if(!data_fin) throw std::runtime_error("Could not open: " + data_filename);
     data_fin.read(static_cast<char*>(x->data), 3 * 224 * 224 * 4);
 
     // get the function from the module(set input data)
diff --git a/include/tvm/build_module.h b/include/tvm/build_module.h
index 7114a4550331..a83288ce3662 100644
--- a/include/tvm/build_module.h
+++ b/include/tvm/build_module.h
@@ -229,6 +229,9 @@ class BuildConfigNode : public Node {
   /*! \brief Whether to disable loop vectorization. */
   bool disable_vectorize = false;
 
+  /*! \brief Whether to disable assert stmt generation. */
+  bool disable_assert = false;
+
   void VisitAttrs(AttrVisitor* v) {
     v->Visit("data_alignment", &data_alignment);
     v->Visit("offset_factor", &offset_factor);
@@ -244,6 +247,7 @@ class BuildConfigNode : public Node {
     v->Visit("instrument_bound_checkers", &instrument_bound_checkers);
     v->Visit("disable_select_rewriting", &disable_select_rewriting);
     v->Visit("disable_vectorize", &disable_vectorize);
+    v->Visit("disable_assert", &disable_assert);
   }
 
   static constexpr const char* _type_key = "BuildConfig";
diff --git a/include/tvm/ir_pass.h b/include/tvm/ir_pass.h
index 76d7d61f1e3d..5c5c4bb2f452 100644
--- a/include/tvm/ir_pass.h
+++ b/include/tvm/ir_pass.h
@@ -563,6 +563,13 @@ LoweredFunc LowerCustomDatatypes(LoweredFunc f, const std::string& target);
  */
 LoweredFunc InferFragment(LoweredFunc f);
 
+/*!
+ * \brief skip assert stmt generation
+ * \param f The function to be transformed.
+ * \return Transformed function.
+ */
+LoweredFunc SkipAssert(LoweredFunc f);
+
 /*!
  * \brief Verify if memory accesses are legal for a specific target device type.
  *
diff --git a/include/tvm/relay/module.h b/include/tvm/relay/module.h
index 1ef7ca88280e..0d3f46cd3cc0 100644
--- a/include/tvm/relay/module.h
+++ b/include/tvm/relay/module.h
@@ -144,6 +144,13 @@ class ModuleNode : public RelayNode {
    */
   TVM_DLL bool ContainGlobalVar(const std::string& name) const;
 
+  /*!
+   * \brief Check if the global_type_var_map_ contains a global type variable.
+   * \param name The variable name.
+   * \returns true if contains, otherwise false.
+   */
+  TVM_DLL bool ContainGlobalTypeVar(const std::string& name) const;
+
   /*!
    * \brief Lookup a global function by its variable.
    * \param str The unique string specifying the global variable.
@@ -198,13 +205,6 @@ class ModuleNode : public RelayNode {
    */
   TVM_DLL TypeData LookupDef(const std::string& var) const;
 
-  /*!
-   * \brief Check if a global type definition exists
-   * \param var The name of the global type definition.
-   * \return Whether the definition exists.
-   */
-  TVM_DLL bool HasDef(const std::string& var) const;
-
   /*!
    * \brief Look up a constructor by its tag.
    * \param tag The tag for the constructor.
diff --git a/include/tvm/relay/transform.h b/include/tvm/relay/transform.h
index 10de08710fbe..ddadbe4fc31d 100644
--- a/include/tvm/relay/transform.h
+++ b/include/tvm/relay/transform.h
@@ -552,17 +552,20 @@ TVM_DLL Pass Legalize(const std::string& legalize_map_attr_name = "FTVMLegalize"
 TVM_DLL Pass CanonicalizeCast();
 
 /*!
- * \brief Add abstraction over a function
+ * \brief Add abstraction over a constructor or global variable bound to a function.
  *
  * For example: `square` is transformed to
- * `fun x -> square x`.
+ * `fn (%x: int32) -> int32 { square(x) }`.
  *
  * See https://en.wikipedia.org/wiki/Lambda_calculus#%CE%B7-conversion
  * for more details.
  *
+ * \param expand_constructor Whether to expand constructors.
+ * \param expand_global_var Whether to expand global variables.
+ *
  * \return The pass.
  */
-TVM_DLL Pass EtaExpand();
+TVM_DLL Pass EtaExpand(bool expand_constructor, bool expand_global_var);
 
 /*!
  * \brief Print the IR for a module to help debugging.
diff --git a/include/tvm/runtime/device_api.h b/include/tvm/runtime/device_api.h
index bb362dcdec66..4b0fcd3159ab 100644
--- a/include/tvm/runtime/device_api.h
+++ b/include/tvm/runtime/device_api.h
@@ -42,7 +42,8 @@ enum DeviceAttrKind : int {
   kDeviceName = 5,
   kMaxClockRate = 6,
   kMultiProcessorCount = 7,
-  kMaxThreadDimensions = 8
+  kMaxThreadDimensions = 8,
+  kGcnArch = 9
 };
 
 /*! \brief Number of bytes each allocation must align to */
diff --git a/nnvm/python/nnvm/__init__.py b/nnvm/python/nnvm/__init__.py
index 31b88587764d..aaaa8b18a2d2 100644
--- a/nnvm/python/nnvm/__init__.py
+++ b/nnvm/python/nnvm/__init__.py
@@ -2,6 +2,7 @@
 # coding: utf-8
 """NNVM python API for ease of use and help new framework establish python API. """
 from __future__ import absolute_import as _abs
+import warnings
 
 from . import _base
 from . import symbol as sym
@@ -10,3 +11,6 @@
 from . import frontend
 
 __version__ = _base.__version__
+
+warnings.warn("NNVM is deprecated and will be removed in a future version. Use Relay instead.",
+              FutureWarning)
diff --git a/python/tvm/autotvm/tophub.py b/python/tvm/autotvm/tophub.py
index c290063f1d7f..12f2b803184d 100644
--- a/python/tvm/autotvm/tophub.py
+++ b/python/tvm/autotvm/tophub.py
@@ -50,7 +50,7 @@
     'arm_cpu':          "v0.04",
     'llvm':             "v0.03",
 
-    'cuda':             "v0.05",
+    'cuda':             "v0.06",
     'rocm':             "v0.03",
     'opencl':           "v0.03",
     'mali':             "v0.05",
diff --git a/python/tvm/build_module.py b/python/tvm/build_module.py
index 217318ebfa84..f96e28323595 100644
--- a/python/tvm/build_module.py
+++ b/python/tvm/build_module.py
@@ -144,7 +144,8 @@ class BuildConfig(NodeBase):
         "dump_pass_ir": False,
         "instrument_bound_checkers": False,
         "disable_select_rewriting": False,
-        "disable_vectorize": False
+        "disable_vectorize": False,
+        "disable_assert": False
     }
     _dump_ir = DumpIR()
 
diff --git a/python/tvm/hybrid/module.py b/python/tvm/hybrid/module.py
index 13e45a7516fa..9811ae1bd4d6 100644
--- a/python/tvm/hybrid/module.py
+++ b/python/tvm/hybrid/module.py
@@ -22,7 +22,6 @@
 """
 
 import ast
-import imp
 
 from ..contrib import util
 from .util import _internal_assert
@@ -112,5 +111,9 @@ def visit_FunctionDef(self, node):
         if self.name is None:
             self.name = finder.name
         self.root_ = finder.root
-        py_module = imp.load_source(self.name, path)
-        self.func_ = getattr(py_module, self.name)
+
+        _, local_ = {}, {}
+        exec(self.src_, _, local_) #pylint: disable=exec-used
+        local_.pop('tvm')
+        assert len(local_) == 1
+        self.func_ = list(local_.values())[0]
diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py
index 0c581c96d4e5..3d90d15e1916 100644
--- a/python/tvm/relay/frontend/onnx.py
+++ b/python/tvm/relay/frontend/onnx.py
@@ -472,6 +472,76 @@ def _impl_v5(cls, inputs, attr, params):
                 static_shape.asnumpy().astype('int32')))
         return out
 
+
+class DepthToSpace(OnnxOpConverter):
+    """ Operator converter for DepthToSpace.
+    """
+
+    @classmethod
+    def _impl_v11(cls, inputs, attr, params):
+
+        block_size = int(attr['blocksize'])
+        mode = attr.get("mode", "DCR")
+
+        # handle NCHW layout
+        indata = infer_value_simulated(inputs[0], params)
+        in_n, in_c, in_h, in_w = indata.shape
+
+        # reshape to proper output
+        new_c = int(in_c / (block_size * block_size))
+        new_h = in_h * block_size
+        new_w = in_w * block_size
+        newshape = (in_n, new_c, new_h, new_w)
+
+        if mode == "DCR":
+            # expand input to larger dimension.
+            expanded = _op.reshape(inputs[0],
+                                   newshape=(in_n, block_size, block_size, new_c, in_h, in_w))
+            # reorder to expand spatial blocks.
+            transposed = _op.transpose(expanded, axes=(0, 3, 4, 1, 5, 2))
+
+        else:  # CRD mode
+            # expand input to larger dimension.
+            expanded = _op.reshape(inputs[0],
+                                   newshape=(in_n, new_c, block_size, block_size, in_h, in_w))
+            # reorder to expand spatial blocks.
+            transposed = _op.transpose(expanded, axes=(0, 1, 4, 2, 5, 3))
+
+        return AttrCvt(op_name="reshape",
+                       extras={'newshape': newshape},
+                       ignores=['mode', 'blocksize'])([transposed], attr)
+
+
+class SpaceToDepth(OnnxOpConverter):
+    """ Operator converter for SpaceToDepth.
+    """
+
+    @classmethod
+    def _impl_v1(cls, inputs, attr, params):
+
+        block_size = int(attr['blocksize'])
+
+        # handle NCHW layout
+        indata = infer_value_simulated(inputs[0], params)
+        in_n, in_c, in_h, in_w = indata.shape
+
+        # reshape to proper output
+        new_c = in_c * (block_size * block_size)
+        new_h = int(in_h / block_size)
+        new_w = int(in_w / block_size)
+        newshape = (in_n, new_c, new_h, new_w)
+
+        # expand input to larger dimension.
+        expanded = _op.reshape(inputs[0],
+                               newshape=(in_n, in_c, new_h, block_size, new_w, block_size))
+        # reorder to expand spatial blocks.
+        transposed = _op.transpose(expanded, axes=(0, 3, 5, 1, 2, 4))
+
+        return AttrCvt(op_name="reshape",
+                       extras={'newshape': newshape},
+                       ignores=['blocksize'])([transposed], attr)
+
+
 class Concat(OnnxOpConverter):
     """ Operator converter for Concat.
     """
@@ -1121,6 +1191,8 @@ def _get_convert_map(opset):
         'Split': Split.get_converter(opset),
         'Slice': Slice.get_converter(opset),
         'Transpose': AttrCvt('transpose', {'perm': 'axes'}),
+        'DepthToSpace': DepthToSpace.get_converter(opset),
+        'SpaceToDepth': SpaceToDepth.get_converter(opset),
         'Gather': Gather.get_converter(opset),
         'Squeeze': AttrCvt('squeeze', {'axes': 'axis'}),
         'Unsqueeze': Unsqueeze.get_converter(opset),
diff --git a/python/tvm/relay/frontend/tensorflow.py b/python/tvm/relay/frontend/tensorflow.py
index 35c857a3d77f..e44653ff1ba9 100644
--- a/python/tvm/relay/frontend/tensorflow.py
+++ b/python/tvm/relay/frontend/tensorflow.py
@@ -1039,8 +1039,8 @@ def _impl(inputs, attr, params):
         # otherwise its value is get from params
         try:
             axes = _get_list_param(params, inputs[1])
-        except (IndexError, KeyError):
-            axes = None
+        except (IndexError, KeyError, AttributeError):
+            axes = _infer_value_simulated(inputs[1], params).asnumpy()
         return _op.transpose(inputs[0], axes=axes)
     return _impl
 
diff --git a/python/tvm/relay/frontend/tflite.py b/python/tvm/relay/frontend/tflite.py
index 8966aa6b389e..415f04eff52c 100644
--- a/python/tvm/relay/frontend/tflite.py
+++ b/python/tvm/relay/frontend/tflite.py
@@ -575,8 +575,7 @@ def convert_mul(self, op):
         """Convert TFLite MUL"""
         # Check if the input tensor is quantized, call QNN op
         if self.is_quantized(op):
-            raise tvm.error.OpNotImplemented(
-                'TFlite quantized mul operator is not supported yet.')
+            return self._convert_elemwise(_qnn.op.mul, op)
         return self._convert_elemwise(_op.multiply, op)
 
     def convert_div(self, op):
@@ -1341,14 +1340,13 @@ def convert_prelu(self, op):
         alpha_tensor = input_tensors[1]
         alpha_tensor_type = alpha_tensor.tensor.Type()
         alpha_tensor_type_str = self.get_tensor_type_str(alpha_tensor_type)
-        alpha_expr = self.exp_tab.new_const(self.get_tensor_value(alpha_tensor),
+        alpha_expr = self.exp_tab.new_const(self.get_tensor_value(alpha_tensor).flatten(),
                                             dtype=alpha_tensor_type_str)
         in_expr = self.get_expr(input_tensor.tensor_idx)
         out = _op.nn.prelu(in_expr, alpha_expr, axis=3)
 
         return out
 
-
     def get_expr(self, input_tensor_idx):
         return self.exp_tab.get_expr(get_tensor_name(self.subgraph, input_tensor_idx))
 
diff --git a/python/tvm/relay/qnn/op/legalizations.py b/python/tvm/relay/qnn/op/legalizations.py
index 6b2e073822f1..3f94f98a03a5 100644
--- a/python/tvm/relay/qnn/op/legalizations.py
+++ b/python/tvm/relay/qnn/op/legalizations.py
@@ -22,10 +22,43 @@
 from tvm import relay
 from .. import op as reg
 
+#################################################
+# Register the functions for different operators.
+#################################################
+
 # Registering QNN Conv2D legalization function.
 @reg.register_qnn_legalize("qnn.conv2d")
 def legalize_qnn_conv2d(attrs, inputs, types):
-    """Legalizes QNN conv2d op.
+    return qnn_conv2d_legalize(attrs, inputs, types)
+
+# Registering QNN dense legalization function.
+@reg.register_qnn_legalize("qnn.dense")
+def legalize_qnn_dense(attrs, inputs, types):
+    return qnn_dense_legalize(attrs, inputs, types)
+
+# Default to None. If overridden by target, this will not be run.
+# Generic QNN Conv2D legalization function.
+@tvm.target.generic_func
+def qnn_conv2d_legalize(attrs, inputs, types):
+    """Default legalization is None."""
+    return None
+
+# Generic QNN Conv2D legalization function.
+@tvm.target.generic_func
+def qnn_dense_legalize(attrs, inputs, types):
+    """Default legalization is None."""
+    return None
+
+###################
+# Helper functions.
+###################
+
+# Helper function for lowering in the absence of fast Int8 arithmetic units.
+def helper_no_fast_int8_hw_legalization(attrs, inputs, types, relay_op):
+    """ Converts QNN operators into a sequence of Relay operators that are friendly to HW that do
+    not have fast Int8 arithmetic. For example, for ARM, LLVM utilizes the assembly instructions
+    much more efficiently if the convolution or dense operator input datatypes are int16 instead of
+    int8. More details are present at https://github.com/apache/incubator-tvm/pull/4277.
 
     Parameters
     ----------
@@ -41,19 +74,27 @@ def legalize_qnn_conv2d(attrs, inputs, types):
     result : tvm.relay.Expr
         The legalized expr
     """
-    return qnn_conv2d_legalize(attrs, inputs, types)
 
-# Generic QNN Conv2D legalization function.
-@tvm.target.generic_func
-def qnn_conv2d_legalize(attrs, inputs, types):
-    """Default legalization is None."""
-    return None
+    # Collect the input exprs.
+    data, kernel = inputs
 
-# Intel x86 QNN Conv2D legalization function.
-@qnn_conv2d_legalize.register('cpu')
-def _qnn_conv2d_legalize(attrs, inputs, types):
-    """Legalizes QNN conv2d op. VNNI supports u8 x i8 fast conv/MM. If the dtypes are already good,
-    we dont transform. Else, we shift the tensor values and zero points to change the dtype.
+    input_zp = attrs['input_zero_point']
+    kernel_zp = attrs['kernel_zero_point']
+
+    shift_data = relay.subtract(relay.cast(data, dtype='int16'),
+                                relay.const(input_zp, 'int16'))
+    shift_kernel = relay.subtract(relay.cast(kernel, dtype='int16'),
+                                  relay.const(kernel_zp, 'int16'))
+    new_attrs = {k : attrs[k] for k in attrs.keys()}
+    del new_attrs['kernel_zero_point']
+    del new_attrs['input_zero_point']
+    return relay_op(shift_data, shift_kernel, **new_attrs)
+
+# Helper function to change dtypes to uint8 x int8. Intel VNNI instructions prefer this setting.
+def helper_change_dtypes_to_uint8_int8(attrs, inputs, types, relay_op):
+    """Legalizes QNN conv2d/dense op for Intel HW. VNNI supports u8 x i8 fast conv/MM. If the dtypes
+    are already good, we don't transform. Else, we shift the tensor values and zero points to change
+    the dtype.
 
     Converting from int8 to uint8 can be done in following manner.
 
@@ -82,26 +123,18 @@ def _qnn_conv2d_legalize(attrs, inputs, types):
         The legalized expr
     """
 
-    def _shift(data, out_dtype):
+    def _shift(data, zero_point, out_dtype):
         """Shifts (add/subtracts) the qnn tensor with +/-128)"""
         if out_dtype == 'uint8':
             shift = 128
         elif out_dtype == 'int8':
             shift = -128
         else:
-            raise ValueError("Unsupport out dtype.")
+            raise ValueError("Unsupported out dtype.")
         data_modified = relay.cast(data, 'int32')
         data_modified = relay.add(data_modified, relay.const(shift, 'int32'))
         data_modified = relay.cast(data_modified, out_dtype)
-        return data_modified
-
-    def _is_int8_hw_support(target):
-        """
-        Checks to ensure that we can use Intel DLBoost instructions - Check if the target is skylake
-        and above.
-        """
-        supported_arches = {'-mcpu=skylake-avx512', '-mcpu=cascadelake'}
-        return supported_arches.intersection(set(target.options))
+        return (data_modified, zero_point + shift)
 
     # Collect the dtypes.
     data_dtype = types[0].dtype
@@ -110,11 +143,6 @@ def _is_int8_hw_support(target):
     # Collect the input exprs.
     data, kernel = inputs
 
-    # The VNNI transformations are applicable only Skylake and above.g
-    target = tvm.target.current_target(allow_none=False)
-    if not _is_int8_hw_support(target):
-        return None
-
     # VNNI supports u8 x i8 fast conv/MM. Don't do anything if it is already satisfied.
     if data_dtype == 'uint8' and kernel_dtype == 'int8':
         return None
@@ -123,18 +151,118 @@ def _is_int8_hw_support(target):
     input_zp = attrs['input_zero_point']
     if data_dtype == 'int8':
         # Compute (QA + 128) and (zp_a + 128)
-        data = _shift(data, 'uint8')
-        input_zp = input_zp + 128
+        data, input_zp = _shift(data, input_zp, 'uint8')
 
     # Shift kernel if necessary.
     kernel_zp = attrs['kernel_zero_point']
     if kernel_dtype == 'uint8':
         # Compute (QA - 128) and (zp_a - 128)
-        kernel = _shift(kernel, 'int8')
-        kernel_zp = kernel_zp - 128
+        kernel, kernel_zp = _shift(kernel, kernel_zp, 'int8')
 
     # Call qnn.conv2d with modified inputs and zero points.
     new_attrs = {k : attrs[k] for k in attrs.keys()}
     new_attrs['input_zero_point'] = input_zp
     new_attrs['kernel_zero_point'] = kernel_zp
-    return relay.qnn.op.conv2d(data, kernel, **new_attrs)
+    return relay_op(data, kernel, **new_attrs)
+
+# Helper function to change dtypes to be same. ARM dotprod instructions prefer this setting.
+def helper_change_dtypes_to_be_same(attrs, inputs, types, relay_op):
+    """ Sometimes MxNet + MLDNN can lead to uint8 x int8 datatypes for the conv inputs. However,
+    many devices like ARM prefer the datatypes to be same for the HW units. This helper transforms
+    conv2d/dense such that both the dtypes are same.
+
+    Parameters
+    ----------
+    attrs : tvm.attrs.Attrs
+        Attributes of current convolution
+    inputs : list of tvm.relay.Expr
+        The args of the Relay expr to be legalized
+    types : list of types
+        List of input and output types
+
+    Returns
+    -------
+    result : tvm.relay.Expr
+        The legalized expr
+    """
+
+    def _shift(data, zero_point, out_dtype):
+        """Shifts (adds/subtracts) the qnn tensor by 128)"""
+        if out_dtype == 'uint8':
+            shift = 128
+        elif out_dtype == 'int8':
+            shift = -128
+        else:
+            raise ValueError("Unsupported out dtype.")
+        data_modified = relay.cast(data, 'int32')
+        data_modified = relay.add(data_modified, relay.const(shift, 'int32'))
+        data_modified = relay.cast(data_modified, out_dtype)
+        return (data_modified, zero_point + shift)
+
+    # Collect the dtypes.
+    data_dtype = types[0].dtype
+    kernel_dtype = types[1].dtype
+
+    if data_dtype == kernel_dtype:
+        return None
+
+    # Collect the input exprs.
+    data, kernel = inputs
+
+    assert 'int8' in data_dtype and 'int8' in kernel_dtype, \
+            "Qnn Conv2D/Dense only accepts uint8 or int8 inputs"
+
+    # Shift input if necessary.
+    input_zp = attrs['input_zero_point']
+    data, input_zp = _shift(data, input_zp, kernel_dtype)
+
+    new_attrs = {k : attrs[k] for k in attrs.keys()}
+    new_attrs['input_zero_point'] = input_zp
+    return relay_op(data, kernel, **new_attrs)
+
+def is_fast_int8_on_intel():
+    """ Checks whether the hardware has support for fast Int8 arithmetic operations. """
+    target = tvm.target.current_target(allow_none=False)
+    intel_supported_arches = {'-mcpu=skylake-avx512', '-mcpu=cascadelake'}
+    return intel_supported_arches.intersection(set(target.options))
+
+def is_fast_int8_on_arm():
+    """ Checks whether the hardware has support for fast Int8 arithmetic operations. """
+    target = tvm.target.current_target(allow_none=False)
+    return '+v8.2a,+dotprod' in ' '.join(target.options)
+
+########################
+# ARM CPU legalizations.
+########################
+
+@qnn_conv2d_legalize.register('arm_cpu')
+def _qnn_conv2d_legalize_arm_cpu(attrs, inputs, types):
+    # ARM prefers the dtypes to be same.
+    if is_fast_int8_on_arm():
+        return helper_change_dtypes_to_be_same(attrs, inputs, types, relay.qnn.op.conv2d)
+    return helper_no_fast_int8_hw_legalization(attrs, inputs, types, relay.nn.conv2d)
+
+@qnn_dense_legalize.register('arm_cpu')
+def _qnn_dense_legalize_arm_cpu(attrs, inputs, types):
+    # ARM prefers the dtypes to be same.
+    if is_fast_int8_on_arm():
+        return helper_change_dtypes_to_be_same(attrs, inputs, types, relay.qnn.op.dense)
+    return helper_no_fast_int8_hw_legalization(attrs, inputs, types, relay.nn.dense)
+
+##########################
+# Intel CPU legalizations.
+##########################
+
+@qnn_conv2d_legalize.register('cpu')
+def _qnn_conv2d_legalize_intel_cpu(attrs, inputs, types):
+    # The VNNI transformations prefer uint8 x int8 datatypes.
+    if is_fast_int8_on_intel():
+        return helper_change_dtypes_to_uint8_int8(attrs, inputs, types, relay.qnn.op.conv2d)
+    return helper_no_fast_int8_hw_legalization(attrs, inputs, types, relay.nn.conv2d)
+
+@qnn_dense_legalize.register('cpu')
+def _qnn_dense_legalize_intel_cpu(attrs, inputs, types):
+    # The VNNI transformations prefer uint8 x int8 datatypes.
+    if is_fast_int8_on_intel():
+        return helper_change_dtypes_to_uint8_int8(attrs, inputs, types, relay.qnn.op.dense)
+    return helper_no_fast_int8_hw_legalization(attrs, inputs, types, relay.nn.dense)
diff --git a/python/tvm/relay/std/prelude.rly b/python/tvm/relay/std/prelude.rly
index a5c2c9f8a9cb..fa05d1a7bd98 100644
--- a/python/tvm/relay/std/prelude.rly
+++ b/python/tvm/relay/std/prelude.rly
@@ -158,13 +158,9 @@ def @sum(%xs: List[Tensor[(), int32]]) {
 /*
  * Concatenates two lists.
  */
+
 def @concat[A](%xs: List[A], %ys: List[A]) -> List[A] {
-  let %updater = fn(%x: A, %xss: List[A]) -> List[A] {
-    Cons(%x, %xss)
-  };
-  @foldr(%updater, %ys, %xs)
-  // TODO(weberlo): write it like below, once VM constructor compilation is fixed
-  // @foldr(Cons, %ys, %xs)
+  @foldr(Cons, %ys, %xs)
 }
 
 /*
@@ -199,12 +195,7 @@ def @zip[A, B](%xs: List[A], %ys: List[B]) -> List[(A, B)] {
  * Reverses a list.
  */
 def @rev[A](%xs: List[A]) -> List[A] {
-  let %updater = fn(%xss: List[A], %x: A) -> List[A] {
-    Cons(%x, %xss)
-  };
-  @foldl(%updater, Nil, %xs)
-  // TODO(weberlo): write it like below, once VM constructor compilation is fixed
-  // @foldl(@flip(Cons), Nil, %xs)
+  @foldl(@flip(Cons), Nil, %xs)
 }
 
 /*
diff --git a/python/tvm/relay/transform.py b/python/tvm/relay/transform.py
index d3509dabddf9..540c1f5b79cd 100644
--- a/python/tvm/relay/transform.py
+++ b/python/tvm/relay/transform.py
@@ -297,6 +297,22 @@ def BackwardFoldScaleAxis():
     """
     return _transform.BackwardFoldScaleAxis()
 
+def RemoveUnusedFunctions(entry_functions=None):
+    """Remove unused global relay functions in a relay module.
+
+    Parameters
+    ----------
+    entry_functions: list[string]
+        The set of entry functions to start from.
+
+    Returns
+    -------
+    ret : tvm.relay.Pass
+        The registered pass to remove unused functions.
+    """
+    if entry_functions is None:
+        entry_functions = ['main']
+    return _transform.RemoveUnusedFunctions(entry_functions)
 
 def ForwardFoldScaleAxis():
     """Fold the scaling of axis into weights of conv2d/dense.
@@ -513,15 +529,23 @@ def ToCPS(expr, mod=None):
     return _transform.to_cps(expr, mod)
 
 
-def EtaExpand():
-    """Add abstraction over a function
+def EtaExpand(expand_constructor=False, expand_global_var=False):
+    """Add abstraction over a constructor or global variable bound to a function
+
+    Parameters
+    ----------
+    expand_constructor: bool
+        Whether to expand constructors.
+
+    expand_global_var: bool
+        Whether to expand global variables.
 
     Returns
     -------
     ret: tvm.relay.Pass
         The registered pass that eta expands an expression.
     """
-    return _transform.EtaExpand()
+    return _transform.EtaExpand(expand_constructor, expand_global_var)
 
 
 def ToGraphNormalForm():
@@ -943,6 +967,7 @@ def create_function_pass(pass_arg):
         return create_function_pass(pass_func)
     return create_function_pass
 
+
 @function_pass(opt_level=1)
 class ChangeBatch:
     """
diff --git a/src/codegen/build_module.cc b/src/codegen/build_module.cc
index 3f279f8772df..ac991d4bfea3 100644
--- a/src/codegen/build_module.cc
+++ b/src/codegen/build_module.cc
@@ -672,6 +672,7 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
   p->stream << "instrument_bound_checkers=" << op->instrument_bound_checkers << ", ";
   p->stream << "disable_select_rewriting=" << op->disable_select_rewriting;
   p->stream << "disable_vectorize=" << op->disable_vectorize;
+  p->stream << "disable_assert=" << op->disable_assert;
   p->stream << ")";
 });
 
diff --git a/src/codegen/codegen.cc b/src/codegen/codegen.cc
index ed9484b211b0..4ea37ba7317b 100644
--- a/src/codegen/codegen.cc
+++ b/src/codegen/codegen.cc
@@ -26,6 +26,7 @@
 #include <tvm/ir_pass.h>
 #include <tvm/runtime/registry.h>
 #include <tvm/runtime/module.h>
+#include <tvm/build_module.h>
 #include <dmlc/memory_io.h>
 #include <sstream>
 #include <iostream>
@@ -40,12 +41,21 @@ runtime::Module Build(const Array<LoweredFunc>& funcs,
   if (pos != std::string::npos) {
     mode = mode.substr(0, pos);
   }
+  Array<LoweredFunc> transformed_funcs;
+  for (const auto& x : funcs) {
+    if (BuildConfig::Current()->disable_assert) {
+      auto func = ir::SkipAssert(x);
+      transformed_funcs.push_back(func);
+    }
+  }
   std::string build_f_name = "codegen.build_" + mode;
   // the build function.
   const PackedFunc* bf = runtime::Registry::Get(build_f_name);
   CHECK(bf != nullptr)
       << "Target " << target << " is not enabled";
-  runtime::Module m = (*bf)(funcs, target);
+  runtime::Module m = transformed_funcs.empty() ?
+                      (*bf)(funcs, target) :
+                      (*bf)(transformed_funcs, target);
   return m;
 }
 
diff --git a/src/codegen/codegen_cuda.cc b/src/codegen/codegen_cuda.cc
index 22e8d842e424..2a412823d6ef 100644
--- a/src/codegen/codegen_cuda.cc
+++ b/src/codegen/codegen_cuda.cc
@@ -58,15 +58,19 @@ std::string CodeGenCUDA::Finish() {
                 << "{\n  return __hgt(__half(a), __half(b)) ? a : b;\n}\n";
     decl_stream << "__device__ half min(half a, half b)\n"
                 << "{\n  return __hlt(__half(a), __half(b)) ? a : b;\n}\n";
-    decl_stream << "__device__ half operator<="
-                << "(__half a,  __half b)\n"
-                << "{\n  return __hlt(a, b);\n}\n";
-    decl_stream << "__device__ half operator+"
-                << "(__half a,  __half &b)\n"
-                <<"{\n  return __hadd(a, b);\n}\n";
-    decl_stream << "__device__ half operator*"
-                << "(__half a, __half b)\n"
-                <<   "{\n  return __hmul(a, b);\n}\n";
+    // FIXME(tvm-team): "volatile" is used to enable cross thread reduction,
+    // which is needed by operations such as softmax.
+    // However, volatile overloading is not supported in NVRTC and CUDA < 9.2.
+    // We need to figure out a solution which can satisfy both scenario.
+    // decl_stream << "__device__ half operator<="
+    //             << "(const volatile __half &a,  const volatile __half &b)\n"
+    //             << "{\n  return __hlt(a, b);\n}\n";
+    // decl_stream << "__device__ half operator+"
+    //             << "(const volatile __half &a,  const volatile __half &b)\n"
+    //             <<"{\n  return __hadd(a, b);\n}\n";
+    // decl_stream << "__device__ half operator*"
+    //             << "(const volatile __half &a, const volatile __half &b)\n"
+    //             <<   "{\n  return __hmul(a, b);\n}\n";
     // otherwise simulate computation via float32
     decl_stream << "#else\n";
     decl_stream << _cuda_half_t_def;
diff --git a/src/codegen/literal/cuda_half_t.h b/src/codegen/literal/cuda_half_t.h
index 23075b0b6e76..0889032aadd4 100644
--- a/src/codegen/literal/cuda_half_t.h
+++ b/src/codegen/literal/cuda_half_t.h
@@ -28,6 +28,7 @@
 static constexpr const char* _cuda_half_t_def = R"(
 typedef unsigned short uint16_t;
 typedef unsigned char uint8_t;
+typedef signed char int8_t;
 typedef int int32_t;
 typedef unsigned long long uint64_t;
 typedef unsigned int uint32_t;
@@ -76,7 +77,7 @@ class TVM_ALIGNED(2) half {
   TVM_XINLINE explicit half(const uint8_t& value) { constructor(value); }
   TVM_XINLINE explicit half(const int32_t& value) { constructor(value); }
   TVM_XINLINE explicit half(const uint32_t& value) { constructor(value); }
-  TVM_XINLINE explicit half(const int64_t& value) { constructor(value); }
+  TVM_XINLINE explicit half(const long long& value) { constructor(value); }
   TVM_XINLINE explicit half(const uint64_t& value) { constructor(value); }
 
   TVM_XINLINE operator float() const {                          \
diff --git a/src/codegen/llvm/codegen_amdgpu.cc b/src/codegen/llvm/codegen_amdgpu.cc
index 28b2deb3f2b7..19179c3e6064 100644
--- a/src/codegen/llvm/codegen_amdgpu.cc
+++ b/src/codegen/llvm/codegen_amdgpu.cc
@@ -174,7 +174,7 @@ inline int DetectROCMComputeVersion(const std::string& target) {
     TVMRetValue val;
     api->GetAttr(tvm_ctx, tvm::runtime::kExist, &val);
     if (val.operator int() == 1) {
-      tvm::runtime::DeviceAPI::Get(tvm_ctx)->GetAttr(tvm_ctx, tvm::runtime::kComputeVersion, &val);
+      tvm::runtime::DeviceAPI::Get(tvm_ctx)->GetAttr(tvm_ctx, tvm::runtime::kGcnArch, &val);
       return val.operator int();
     }
   }
diff --git a/src/common/util.h b/src/common/util.h
index 93f32f48a2a6..85db7f387093 100644
--- a/src/common/util.h
+++ b/src/common/util.h
@@ -35,6 +35,7 @@
 #include <sstream>
 #include <algorithm>
 #include <array>
+#include <cctype>
 #include <memory>
 
 namespace tvm {
diff --git a/src/pass/skip_assert.cc b/src/pass/skip_assert.cc
new file mode 100644
index 000000000000..5f310a61dfe3
--- /dev/null
+++ b/src/pass/skip_assert.cc
@@ -0,0 +1,47 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#include <tvm/ir.h>
+#include <tvm/ir_pass.h>
+#include <tvm/ir_mutator.h>
+
+namespace tvm {
+namespace ir {
+
+class AssertSkipper : public IRMutator {
+ public:
+  Stmt Mutate_(const AssertStmt* op, const Stmt& s) final {
+    Stmt stmt = IRMutator::Mutate_(op, s);
+    op = stmt.as<AssertStmt>();
+    return op->body;
+  }
+};
+
+Stmt SkipAssert(Stmt stmt) {
+  return AssertSkipper().Mutate(stmt);
+}
+
+LoweredFunc SkipAssert(LoweredFunc f) {
+  auto n = make_node<LoweredFuncNode>(*f.operator->());
+  n->body = SkipAssert(f->body);
+  return LoweredFunc(n);
+}
+
+}  // namespace ir
+}  // namespace tvm
diff --git a/src/relay/backend/interpreter.cc b/src/relay/backend/interpreter.cc
index 01693e5b3673..45283582bf05 100644
--- a/src/relay/backend/interpreter.cc
+++ b/src/relay/backend/interpreter.cc
@@ -26,6 +26,7 @@
 #include <tvm/relay/expr_functor.h>
 #include <tvm/relay/pattern_functor.h>
 #include <tvm/relay/interpreter.h>
+#include <tvm/relay/transform.h>
 #include <tvm/relay/analysis.h>
 #include <tvm/relay/attrs/debug.h>
 #include <tvm/relay/feature.h>
@@ -789,6 +790,16 @@ CreateInterpreter(
     Module mod,
     DLContext context,
     Target target) {
+  if (mod.defined()) {
+    // eta expand to support constructors in argument position
+    transform::Sequential seq({
+        transform::EtaExpand(
+            /* expand_constructor */ true, /* expand_global_var */ false)});
+    transform::PassContext pass_ctx = transform::PassContext::Current();
+    tvm::With<transform::PassContext> ctx(pass_ctx);
+    mod = seq(mod);
+  }
+
   auto intrp = std::make_shared<Interpreter>(mod, context, target);
   auto packed = [intrp](Expr expr) {
     auto f = DetectFeature(expr);
diff --git a/src/relay/backend/vm/compiler.cc b/src/relay/backend/vm/compiler.cc
index 7f828c473bbe..c38ca1ae0469 100644
--- a/src/relay/backend/vm/compiler.cc
+++ b/src/relay/backend/vm/compiler.cc
@@ -54,6 +54,7 @@ namespace transform {
 
 Pass LambdaLift();
 Pass InlinePrimitives();
+Pass RemoveUnusedFunctions(Array<tvm::Expr> entry_functions);
 
 Pass ManifestAlloc(Target target_host) {
   auto f = tvm::runtime::Registry::Get("relay.transform.ManifestAlloc");
@@ -863,6 +864,8 @@ void VMCompiler::Compile(Module mod,
 
 Module VMCompiler::OptimizeModule(const Module& mod, const TargetsMap& targets) {
   Array<Pass> pass_seqs;
+  Array<tvm::Expr> entry_functions{tvm::Expr{"main"}};
+  pass_seqs.push_back(transform::RemoveUnusedFunctions(entry_functions));
   // Run all dialect legalization passes.
   pass_seqs.push_back(relay::qnn::transform::Legalize());
 
@@ -871,6 +874,10 @@ Module VMCompiler::OptimizeModule(const Module& mod, const TargetsMap& targets)
     pass_seqs.push_back(transform::Legalize());
   }
 
+  // eta expand to support constructors in argument position
+  pass_seqs.push_back(transform::EtaExpand(
+    /* expand_constructor */ true, /* expand_global_var */ false));
+
   pass_seqs.push_back(transform::SimplifyInference());
   PackedFunc fskip = PackedFunc([](TVMArgs args, TVMRetValue* rv) {
     Expr expr = args[0];
diff --git a/src/relay/backend/vm/lambda_lift.cc b/src/relay/backend/vm/lambda_lift.cc
index 6290ef7c6e93..6ef31e626dbb 100644
--- a/src/relay/backend/vm/lambda_lift.cc
+++ b/src/relay/backend/vm/lambda_lift.cc
@@ -61,8 +61,8 @@ Function MarkClosure(const Function& func) {
  * We will lift a function out into a global which takes the set of the free
  * vars and then return the new created function.
  */
-struct LambdaLifter : ExprMutator {
-  Module module_;
+class LambdaLifter : public ExprMutator {
+ public:
   explicit LambdaLifter(const Module& module) : module_(module) {}
 
   Expr VisitExpr_(const FunctionNode* func_node) final {
@@ -100,8 +100,8 @@ struct LambdaLifter : ExprMutator {
     // The "inner" function should be used to generate the
     // code for the closure.
     Function lifted_func;
-    if (free_vars.size() == 0) {
-      lifted_func = FunctionNode::make(body->params, body->body, body->ret_type, free_type_vars);
+    if (free_vars.size() == 0 && free_type_vars.size() == 0) {
+      lifted_func = FunctionNode::make(body->params, body->body, body->ret_type, body->type_params);
     } else {
       lifted_func =
           FunctionNode::make(free_vars, body, func->func_type_annotation(), free_type_vars);
@@ -114,8 +114,15 @@ struct LambdaLifter : ExprMutator {
     auto name = GenerateName(lifted_func);
     auto global = GlobalVarNode::make(name);
 
-    // Add the lifted function to the module.
-    module_->Add(global, lifted_func);
+    if (module_->ContainGlobalVar(name)) {
+      const auto existing_func = module_->Lookup(name);
+      CHECK(AlphaEqual(lifted_func, existing_func)) << "lifted function hash collision";
+      // If an identical function already exists, use its global var.
+      global = module_->GetGlobalVar(name);
+    } else {
+      // Add the lifted function to the module.
+      module_->Add(global, lifted_func);
+    }
 
     if (free_vars.size() == 0) {
       return std::move(global);
@@ -145,6 +152,9 @@ struct LambdaLifter : ExprMutator {
     }
     return module_;
   }
+
+ private:
+  Module module_;
 };
 
 }  // namespace vm
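A minimal Python sketch of the de-duplication introduced above: when a lifted function's generated name already exists in the module, the existing global is reused only if the two definitions are alpha-equal (illustrative, not TVM code):

def add_or_reuse(module, name, lifted_func, alpha_equal):
    if name in module:
        # Identical definition already present: reuse its global var.
        assert alpha_equal(module[name], lifted_func), "lifted function hash collision"
        return name
    # Otherwise add the lifted function to the module.
    module[name] = lifted_func
    return name

mod = {}
add_or_reuse(mod, "lifted_0", ("x", "x + 1"), alpha_equal=lambda a, b: a == b)
add_or_reuse(mod, "lifted_0", ("x", "x + 1"), alpha_equal=lambda a, b: a == b)  # reused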
diff --git a/src/relay/backend/vm/removed_unused_funcs.cc b/src/relay/backend/vm/removed_unused_funcs.cc
new file mode 100644
index 000000000000..a01204077c55
--- /dev/null
+++ b/src/relay/backend/vm/removed_unused_funcs.cc
@@ -0,0 +1,134 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ *  Copyright (c) 2019 by Contributors
+ * \file tvm/relay/backend/vm/remove_unused_funcs.cc
+ * \brief Remove unused global relay functions in a relay module.
+ */
+
+#include <tvm/relay/expr.h>
+#include <tvm/relay/expr_functor.h>
+#include <tvm/logging.h>
+#include <tvm/relay/analysis.h>
+#include <tvm/relay/transform.h>
+#include <tvm/runtime/vm.h>
+#include <iostream>
+#include <unordered_set>
+#include <vector>
+
+namespace tvm {
+namespace relay {
+namespace vm {
+
+/**
+ * \brief Detects all the functions that can possibly be called by the entry functions.
+ */
+struct CallTracer : ExprVisitor {
+  Module module_;
+
+  // Record the names of all encountered functions
+  std::unordered_set<std::string> called_funcs_;
+
+  // Record the expressions that are being visited
+  std::unordered_set<Expr, NodeHash, NodeEqual> visiting_;
+
+  explicit CallTracer(const Module& module)
+    : module_{module},
+      called_funcs_{},
+      visiting_{} {}
+
+  void VisitExpr_(const CallNode* call_node) final {
+    Expr op = call_node->op;
+    for (auto param : call_node->args) {
+      VisitExpr(param);
+    }
+    if (auto func_node = op.as<FunctionNode>()) {
+      auto func = GetRef<Function>(func_node);
+      auto it = visiting_.find(func);
+      if (it != visiting_.end()) {
+        return;
+      }
+      visiting_.insert(func);
+      VisitExpr(func);
+    } else if (auto global = op.as<GlobalVarNode>()) {
+      called_funcs_.insert(global->name_hint);
+      auto func = module_->Lookup(global->name_hint);
+      auto it = visiting_.find(func);
+      if (it != visiting_.end()) {
+        return;
+      }
+      visiting_.insert(func);
+      VisitExpr(func);
+    }
+  }
+
+  std::unordered_set<std::string> Trace(const std::string& entry) {
+    called_funcs_.insert(entry);
+    auto main_func = module_->Lookup(entry);
+    VisitExpr(main_func);
+    return called_funcs_;
+  }
+};
+
+/*!
+ * \brief Remove functions that are not used.
+ *
+ * \param module The Relay module.
+ * \param entry_funcs The set of functions that can serve as entry functions.
+ *
+ * \return The module with dead functions removed.
+ */
+Module RemoveUnusedFunctions(const Module& module,
+                             Array<tvm::Expr> entry_funcs) {
+  std::unordered_set<std::string> called_funcs{};
+  for (auto entry : entry_funcs) {
+    auto* str_name = entry.as<ir::StringImm>();
+    auto funcs = CallTracer(module).Trace(str_name->value);
+    called_funcs.insert(funcs.cbegin(), funcs.cend());
+  }
+  auto existing_functions = module->functions;
+  for (auto f : existing_functions) {
+    auto it = called_funcs.find(f.first->name_hint);
+    if (it == called_funcs.end()) {
+      module->Remove(f.first);
+    }
+  }
+  return module;
+}
+
+}  // namespace vm
+
+namespace transform {
+
+Pass RemoveUnusedFunctions(Array<tvm::Expr> entry_functions) {
+  runtime::TypedPackedFunc<Module(Module, PassContext)> pass_func =
+    [=](Module m, PassContext pc) {
+    return relay::vm::RemoveUnusedFunctions(m, entry_functions);
+  };
+  return CreateModulePass(pass_func, 1, "RemoveUnusedFunctions", {});
+}
+
+TVM_REGISTER_API("relay._transform.RemoveUnusedFunctions")
+.set_body_typed(RemoveUnusedFunctions);
+
+}  // namespace transform
+
+}  // namespace relay
+}  // namespace tvm
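A minimal Python sketch of the reachability walk the pass performs: keep every function reachable from the entry points and drop the rest (illustrative, not TVM code):

def remove_unused(functions, entry_points):
    # functions maps a global name to the set of global names it calls.
    live, stack = set(), list(entry_points)
    while stack:
        name = stack.pop()
        if name in live or name not in functions:
            continue
        live.add(name)
        stack.extend(functions[name])
    return {name: callees for name, callees in functions.items() if name in live}

funcs = {"main": {"helper"}, "helper": set(), "dead": {"helper"}}
assert set(remove_unused(funcs, ["main"])) == {"main", "helper"}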
diff --git a/src/relay/ir/alpha_equal.cc b/src/relay/ir/alpha_equal.cc
index 0dbcf992e028..df91f794f6d1 100644
--- a/src/relay/ir/alpha_equal.cc
+++ b/src/relay/ir/alpha_equal.cc
@@ -69,7 +69,7 @@ class AlphaEqualHandler:
       }
       if (lhsm->type_definitions.size() != rhsm->type_definitions.size()) return false;
       for (const auto& p : lhsm->type_definitions) {
-        if (!rhsm->HasDef(p.first->var->name_hint) ||
+        if (!rhsm->ContainGlobalTypeVar(p.first->var->name_hint) ||
             !Equal(p.second, rhsm->LookupDef(p.first->var->name_hint))) {
           return false;
         }
diff --git a/src/relay/ir/module.cc b/src/relay/ir/module.cc
index 960c28f94c76..3bd8d59aaf49 100644
--- a/src/relay/ir/module.cc
+++ b/src/relay/ir/module.cc
@@ -68,6 +68,10 @@ bool ModuleNode::ContainGlobalVar(const std::string& name) const {
   return global_var_map_.find(name) != global_var_map_.end();
 }
 
+bool ModuleNode::ContainGlobalTypeVar(const std::string& name) const {
+  return global_type_var_map_.find(name) != global_type_var_map_.end();
+}
+
 GlobalVar ModuleNode::GetGlobalVar(const std::string& name) const {
   auto it = global_var_map_.find(name);
   CHECK(it != global_var_map_.end())
@@ -239,11 +243,6 @@ TypeData ModuleNode::LookupDef(const std::string& name) const {
   return this->LookupDef(id);
 }
 
-bool ModuleNode::HasDef(const std::string& name) const {
-  auto it = global_type_var_map_.find(name);
-  return it != global_type_var_map_.end();
-}
-
 Constructor ModuleNode::LookupTag(const int32_t tag) {
   auto it = constructor_tag_map_.find(tag);
   CHECK(it != constructor_tag_map_.end())
@@ -336,7 +335,8 @@ TVM_REGISTER_API("relay._module.Module_Add")
   } else if (val->IsInstance<GlobalVarNode>()) {
     GlobalVar gv = Downcast<GlobalVar>(val);
     auto mod_copy = Module(make_node<ModuleNode>(*mod.operator->()));
-    mod_copy = transform::EtaExpand()(mod_copy);
+    mod_copy = transform::EtaExpand(
+      /* expand_constructor */ false, /* expand_global_var */ true)(mod_copy);
     auto func = mod_copy->Lookup(gv->name_hint);
     mod->Add(var, Downcast<Function>(func), update);
   } else {
diff --git a/src/relay/ir/pretty_printer.cc b/src/relay/ir/pretty_printer.cc
index b2a8396706f2..f42069b99603 100644
--- a/src/relay/ir/pretty_printer.cc
+++ b/src/relay/ir/pretty_printer.cc
@@ -669,7 +669,7 @@ class PrettyPrinter :
   Doc VisitExpr_(const ConstructorNode* n) final {
     Doc doc;
     doc << n->name_hint;
-    if (n->inputs.size() != 0) {
+    if (in_adt_def_ && n->inputs.size() != 0) {
       doc << "(";
       std::vector<Doc> inputs;
       for (Type input : n->inputs) {
@@ -775,6 +775,7 @@ class PrettyPrinter :
   }
 
   Doc VisitType_(const TypeDataNode* node) final {
+    in_adt_def_ = true;
     Doc doc;
     doc << "type " << Print(node->header);
 
@@ -802,6 +803,7 @@ class PrettyPrinter :
       adt_body << ",";
     }
     doc << Brace(adt_body);
+    in_adt_def_ = false;
     return doc;
   }
 
@@ -876,6 +878,8 @@ class PrettyPrinter :
   TextMetaDataContext meta_;
   /*! \brief counter of temporary variable */
   size_t temp_var_counter_{0};
+  /*! \brief whether the printer is currently in an ADT definition */
+  bool in_adt_def_;
   /*! \brief arena for dependency graph */
   common::Arena arena_;
   /*! \brief dependency graph of the expr */
diff --git a/src/relay/pass/eta_expand.cc b/src/relay/pass/eta_expand.cc
index a5d04871ba95..dca08cc834d1 100644
--- a/src/relay/pass/eta_expand.cc
+++ b/src/relay/pass/eta_expand.cc
@@ -20,57 +20,144 @@
 /*!
  * \file eta_expand.cc
  *
- * \brief Add abstraction over a function. For example, abs will become (fun x -> abs x).
+ * \brief Add an abstraction over constructors and/or global variables bound to a function.
  *
  */
-#include <tvm/relay/type.h>
 #include <tvm/relay/transform.h>
+#include <tvm/relay/type.h>
+#include <tvm/relay/expr_functor.h>
+#include "../ir/type_functor.h"
 
 namespace tvm {
 namespace relay {
+namespace eta_expand {
+
+/*!
+ * \brief mutator to replace type variables with fresh ones, while maintaining alpha equality
+ */
+class TypeVarReplacer : public TypeMutator {
+ public:
+  TypeVarReplacer() : replace_map_({}) {}
 
-Expr EtaExpand(const Expr& e, const Module& mod) {
-  tvm::Array<Var> original_params;
-  tvm::Array<Expr> params;
-  tvm::Array<Var> args;
-  tvm::Array<TypeVar> original_type_params;
-  Type ret_type;
-
-  if (e->IsInstance<GlobalVarNode>()) {
-    auto gvar_node = e.as<GlobalVarNode>();
-    auto func = mod->Lookup(GetRef<GlobalVar>(gvar_node));
-    original_params = func->params;
-    original_type_params = func->type_params;
-    ret_type = func->ret_type;
-  } else {
-    CHECK(e->IsInstance<FunctionNode>());
-    auto func = GetRef<Function>(e.as<FunctionNode>());
-    original_params = func->params;
-    original_type_params = func->type_params;
-    ret_type = func->ret_type;
+  Type VisitType_(const TypeVarNode* type_var_node) final {
+    const auto type_var = GetRef<TypeVar>(type_var_node);
+    if (replace_map_.find(type_var) == replace_map_.end()) {
+      replace_map_[type_var] = TypeVarNode::make("A", Kind::kType);
+    }
+    return replace_map_[type_var];
   }
 
-  for (size_t i = 0; i < original_params.size(); ++i) {
-    auto var = VarNode::make("a", original_params[i]->type_annotation);
-    params.push_back(var);
-    args.push_back(var);
+ private:
+  /*! \brief variable replacement map to remap old type vars to fresh ones */
+  std::unordered_map<TypeVar, TypeVar, NodeHash, NodeEqual> replace_map_;
+};
+
+/*!
+ * \brief mutator to perform eta expansion on all functions in a module
+ */
+class EtaExpander : public ExprMutator {
+ public:
+  explicit EtaExpander(const Module& mod, bool expand_constructor, bool expand_global_var)
+      : mod_(mod),
+        type_var_replacer_(TypeVarReplacer()),
+        expand_constructor_(expand_constructor),
+        expand_global_var_(expand_global_var) {
+    CHECK(expand_constructor || expand_global_var)
+      << "must expand at least one language feature";
   }
 
-  auto new_func =
-      FunctionNode::make(args, CallNode::make(e, params), ret_type, original_type_params);
+  Module Expand() {
+    for (GlobalVar global_var : mod_->GetGlobalVars()) {
+      const Function func = mod_->Lookup(global_var);
+      const Function new_func = Downcast<Function>(VisitExpr(func));
+      mod_->Update(global_var, new_func);
+    }
+    return mod_;
+  }
 
-  return std::move(new_func);
-}
+  Expr VisitExpr_(const CallNode* call) final {
+    // we don't need to expand constructors when they are being called, so we
+    // prevent them from being visited here
+    Expr new_op = call->op;
+    if (!call->op.as<ConstructorNode>()) {
+      new_op = VisitExpr(new_op);
+    }
+    tvm::Array<Expr> new_args;
+    for (const auto& arg : call->args) {
+      new_args.push_back(VisitExpr(arg));
+    }
+    return CallNode::make(new_op, new_args, call->attrs, call->type_args);
+  }
+
+  Expr VisitExpr_(const ConstructorNode* cons_node) final {
+    Constructor cons = GetRef<Constructor>(cons_node);
+    if (!expand_constructor_) {
+      return std::move(cons);
+    }
+    // NOTE: we only reach this case if the constructor is not being applied to any arguments
+    tvm::Array<Expr> params;
+    for (const auto& type : cons->inputs) {
+      Type param_type = type_var_replacer_.VisitType(type);
+      params.push_back(VarNode::make("eta_expand_param", param_type));
+    }
+    tvm::Array<Type> type_params;
+    TypeData adt_def = mod_->LookupDef(cons->belong_to);
+    for (const auto& type_var : adt_def->type_vars) {
+      type_params.push_back(type_var_replacer_.VisitType(type_var));
+    }
+    Expr body = CallNode::make(cons, params, Attrs());
+    Type ret_type = TypeCallNode::make(cons->belong_to, type_params);
+
+    return FunctionNode::make(
+      Downcast<tvm::Array<Var>>(params),
+      body,
+      ret_type,
+      Downcast<tvm::Array<TypeVar>>(type_params));
+  }
+
+  Expr VisitExpr_(const GlobalVarNode* gvar_node) final {
+    GlobalVar gvar = GetRef<GlobalVar>(gvar_node);
+    if (!expand_global_var_) {
+      return std::move(gvar);
+    }
+
+    const auto func = mod_->Lookup(gvar);
+    tvm::Array<Expr> params;
+    tvm::Array<Var> args;
+    for (size_t i = 0; i < func->params.size(); ++i) {
+      auto var = VarNode::make("eta_expand_param", func->params[i]->type_annotation);
+      params.push_back(var);
+      args.push_back(var);
+    }
+
+    return FunctionNode::make(
+      args,
+      CallNode::make(gvar, params),
+      func->ret_type,
+      func->type_params);
+  }
+
+ private:
+  /*! \brief reference to module being expanded */
+  const Module mod_;
+  /*! \brief type variable replacer */
+  TypeVarReplacer type_var_replacer_;
+  /*! \brief whether to expand constructor nodes */
+  bool expand_constructor_;
+  /*! \brief whether to expand global variable nodes */
+  bool expand_global_var_;
+};
+
+}  // namespace eta_expand
 
 namespace transform {
 
-Pass EtaExpand() {
-  runtime::TypedPackedFunc<Function(Function, Module, PassContext)> pass_func =
-    [=](Function f, Module m, PassContext pc) {
-      return Downcast<Function>(EtaExpand(f, m));
-    };
-  Pass expanded = CreateFunctionPass(pass_func, 1, "EtaExpand", {});
-  return Sequential({expanded, InferType()});
+Pass EtaExpand(bool expand_constructor, bool expand_global_var) {
+  runtime::TypedPackedFunc<Module(Module, PassContext)> pass_func =
+    [=](Module mod, PassContext pc) {
+    return eta_expand::EtaExpander(mod, expand_constructor, expand_global_var).Expand();
+  };
+  return CreateModulePass(pass_func, 1, "EtaExpand", {});
 }
 
 TVM_REGISTER_API("relay._transform.EtaExpand")
diff --git a/src/relay/pass/type_infer.cc b/src/relay/pass/type_infer.cc
index bc84bddaad79..9d6878170bb5 100644
--- a/src/relay/pass/type_infer.cc
+++ b/src/relay/pass/type_infer.cc
@@ -653,7 +653,7 @@ class TypeInferencer::Resolver : public ExprMutator, PatternMutator {
   }
 
   Expr VisitExpr_(const ConstructorNode* op) final {
-    return GetRef<Constructor>(op);
+    return AttachCheckedType(op);
   }
 
   Expr VisitExpr_(const MatchNode* op) final {
diff --git a/src/relay/qnn/op/convolution.cc b/src/relay/qnn/op/convolution.cc
index 1103620829d1..75cac5d0f120 100644
--- a/src/relay/qnn/op/convolution.cc
+++ b/src/relay/qnn/op/convolution.cc
@@ -106,8 +106,6 @@ WorkloadType GetWorkload(const Array<tvm::relay::Type>& arg_types, const QnnConv
  * \brief Fallback to simpler lowering for dilation or depthwise conv.
  * \param data The input expr.
  * \param weight The weight expr.
- * \param zp_data The data zero point expr.
- * \param zp_kernel The kernel zero point expr.
  * \param param The qnn conv2d attributes.
  * \return The fallback lowered sequence of Relay expr.
  * \note In case of dilation, normal lowering would require a dilated pool.
@@ -115,16 +113,19 @@ WorkloadType GetWorkload(const Array<tvm::relay::Type>& arg_types, const QnnConv
  *       Relay operations. This will potentially lead to performance degradation
  *       as the convolution is called on int32 tensors instead of int8 tensors.
  */
-Expr Conv2DFallBack(const Expr& data, const Expr& weight, const Expr& zp_data,
-                    const Expr& zp_kernel, const QnnConv2DAttrs* param) {
-  auto shifted_data = data;
+Expr Conv2DFallBack(const Expr& data, const Expr& weight, const QnnConv2DAttrs* param) {
+  // Upcast the zero point to Int16.
+  auto zp_data = MakeConstantScalar(Int(16), param->input_zero_point);
+  auto zp_kernel = MakeConstantScalar(Int(16), param->kernel_zero_point);
+
+  auto shifted_data = Cast(data, Int(16));
   if (param->input_zero_point != 0) {
-    shifted_data = Subtract(Cast(data, Int(32)), zp_data);
+    shifted_data = Subtract(Cast(data, Int(16)), zp_data);
   }
 
-  auto shifted_kernel = weight;
+  auto shifted_kernel = Cast(weight, Int(16));
   if (param->kernel_zero_point != 0) {
-    shifted_kernel = Subtract(Cast(weight, Int(32)), zp_kernel);
+    shifted_kernel = Subtract(Cast(weight, Int(16)), zp_kernel);
   }
 
   return Conv2D(shifted_data, shifted_kernel, param->strides, param->padding, param->dilation,
@@ -186,7 +187,6 @@ Expr Conv2DFirstTerm(const Expr& padded_data, const Expr& weight, const QnnConv2
 /*
  * \brief Calculates the second term in the qnn.conv2d lowering sequence.
  * \param padded_data The padded data expr.
- * \param zp_kernel The kernel zero point expr.
  * \param param The qnn conv2d attributes.
  * \param kernel_h The height of kernel.
  * \param kernel_w The width of kernel.
@@ -200,8 +200,11 @@ Expr Conv2DFirstTerm(const Expr& padded_data, const Expr& weight, const QnnConv2
  *       followed by a reduce on the C axis. Using avg_pool2d also gives an
  *       opportunity to reuse alter_op_layout infrastructure.
  */
-Expr Conv2DSecondTerm(const Expr& padded_data, const Expr& zp_kernel, const QnnConv2DAttrs* param,
-                      int kernel_h, int kernel_w, int out_channels) {
+Expr Conv2DSecondTerm(const Expr& padded_data, const QnnConv2DAttrs* param, int kernel_h,
+                      int kernel_w, int out_channels) {
+  // Constant Expr for the kernel zero point.
+  auto zp_kernel = MakeConstantScalar(Int(32), param->kernel_zero_point);
+
   auto casted_t2 = Cast(padded_data, Int(32));
 
   // We can reduce the H and W axis by using avg_pool2d. However, avg_pool2d averages the sum.
@@ -241,7 +244,6 @@ Expr Conv2DSecondTerm(const Expr& padded_data, const Expr& zp_kernel, const QnnC
 /*
  * \brief Calculates the third term in the qnn.conv2d lowering sequence.
  * \param weight The weight expr.
- * \param zp_data The data zero point expr.
  * \param param The qnn conv2d attributes.
  * \param batch_size The batch size.
  * \param out_channels The number of output channels.
@@ -254,8 +256,11 @@ Expr Conv2DSecondTerm(const Expr& padded_data, const Expr& zp_kernel, const QnnC
  *       a 1D tensor. The tensor is then reshaped to conform to NHWC/NCHW
  *       format.
  */
-Expr Conv2DThirdTerm(const Expr& weight, const Expr& zp_data, const QnnConv2DAttrs* param,
-                     int batch_size, int out_channels) {
+Expr Conv2DThirdTerm(const Expr& weight, const QnnConv2DAttrs* param, int batch_size,
+                     int out_channels) {
+  // Constant expr for input zero point.
+  auto zp_data = MakeConstantScalar(Int(32), param->input_zero_point);
+
   // Find which dimensions are C, R, S.
   Array<Integer> axes_t3;
   if (param->kernel_layout == "OIHW") {
@@ -415,21 +420,19 @@ Expr QnnConv2DCanonicalize(const Attrs& attrs, const Array<Expr>& new_args,
   int batch_size, in_channels, out_channels, kernel_h, kernel_w;
   std::tie(batch_size, in_channels, out_channels, kernel_h, kernel_w) =
       GetWorkload(arg_types, param);
-  auto zp_data = MakeConstantScalar(Int(32), param->input_zero_point);
-  auto zp_kernel = MakeConstantScalar(Int(32), param->kernel_zero_point);
 
   // Fallback to int32 conv if there is dilation or depthwise conv2d
   CHECK_EQ(param->dilation.size(), 2) << "qnn.conv2d only supports 2D dilation";
   auto dilation_h = get_const_int(param->dilation[0]);
   auto dilation_w = get_const_int(param->dilation[1]);
   if (dilation_h != 1 || dilation_w != 1 || param->groups != 1) {
-    return Conv2DFallBack(data, weight, zp_data, zp_kernel, param);
+    return Conv2DFallBack(data, weight, param);
   }
 
   auto padded_data = Conv2DPadInput(data, param);
   auto term1 = Conv2DFirstTerm(padded_data, weight, param);
-  auto term2 = Conv2DSecondTerm(padded_data, zp_kernel, param, kernel_h, kernel_w, out_channels);
-  auto term3 = Conv2DThirdTerm(weight, zp_data, param, batch_size, out_channels);
+  auto term2 = Conv2DSecondTerm(padded_data, param, kernel_h, kernel_w, out_channels);
+  auto term3 = Conv2DThirdTerm(weight, param, batch_size, out_channels);
   auto term4 = Conv2DFourthTerm(param, batch_size, in_channels, kernel_h, kernel_w);
   return Conv2DCombineTerms(term1, term2, term3, term4, param);
 }
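A minimal NumPy sketch of why the fallback now upcasts to Int16 rather than Int32: the difference between a (u)int8 value and its zero point always fits in 16 bits, so the narrower upcast is sufficient and cheaper (illustrative values):

import numpy as np

data = np.array([0, 127, 255], dtype=np.uint8)
input_zero_point = 128
# Upcast to int16 before subtracting the zero point; the result is in [-255, 255].
shifted = data.astype(np.int16) - np.int16(input_zero_point)
assert shifted.tolist() == [-128, -1, 127]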
diff --git a/src/relay/qnn/op/quantize.cc b/src/relay/qnn/op/quantize.cc
index 1f7dbc1b6bb6..6df9b433560a 100644
--- a/src/relay/qnn/op/quantize.cc
+++ b/src/relay/qnn/op/quantize.cc
@@ -48,8 +48,8 @@ bool QuantizeRel(const Array<Type>& types,
   const auto* quantize_attrs = attrs.as<QuantizeAttrs>();
   const Array<tvm::Expr> oshape = data->shape;
   const DataType out_dtype = quantize_attrs->out_dtype;
-  CHECK(out_dtype == Int(8) || out_dtype == UInt(8))
-    << "Output type should be one of [int8, unit8 ] but was " << out_dtype;
+  CHECK(out_dtype == Int(8) || out_dtype == UInt(8) || out_dtype == Int(32))
+    << "Output type should be one of [int8, unit8, int32] but was " << out_dtype;
   // assign output type
   reporter->Assign(types[1], TensorTypeNode::make(oshape, out_dtype));
   return true;
@@ -72,12 +72,12 @@ Expr MakeQuantize(Expr data,
 Expr QuantizeLower(const Expr& input_tensor,
                    const QuantizeAttrs* attrs) {
   const auto out_dtype = attrs->out_dtype;
-  const auto output_zero_point = MakeConstantScalar(Int(32), attrs->output_zero_point);
+  const auto output_zero_point = MakeConstantScalar(Float(32), attrs->output_zero_point);
   const auto scale = MakeConstantScalar(Float(32), attrs->output_scale);
   const int32_t min_val = GetQmin(out_dtype);
   const int32_t max_val = GetQmax(out_dtype);
-  auto scale_data = Cast(Round(Divide(input_tensor, scale)), Int(32));
-  auto add_zero_point = Add(scale_data, output_zero_point);
+  auto scale_data = Divide(input_tensor, scale);
+  auto add_zero_point = Cast(Round(Add(scale_data, output_zero_point)), Int(32));
   auto clamped_output = Clip(add_zero_point, min_val, max_val);
   auto clamp_out_dtype = Cast(clamped_output, out_dtype);
   return clamp_out_dtype;
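A minimal NumPy sketch of the revised lowering order above: divide by the scale, add the (float) zero point, round once, then clip and cast (illustrative values):

import numpy as np

def quantize(x, scale, zero_point, qmin=-128, qmax=127, out_dtype=np.int8):
    scaled = x / scale                           # scale_data
    rounded = np.round(scaled + zero_point)      # add the zero point before rounding
    clipped = np.clip(rounded.astype(np.int32), qmin, qmax)
    return clipped.astype(out_dtype)

x = np.array([-1.0, 0.0, 0.49, 10.0], dtype=np.float32)
assert quantize(x, scale=0.1, zero_point=1).tolist() == [-9, 1, 6, 101]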
diff --git a/src/runtime/contrib/cblas/cblas.cc b/src/runtime/contrib/cblas/cblas.cc
index c6558673c108..ef9f5d6564de 100644
--- a/src/runtime/contrib/cblas/cblas.cc
+++ b/src/runtime/contrib/cblas/cblas.cc
@@ -31,6 +31,9 @@ extern "C" {
 #else
 #include <cblas.h>
 #endif
+#if USE_DNNL == 1
+#include <dnnl.h>
+#endif
 }
 
 namespace tvm {
@@ -40,12 +43,19 @@ using namespace runtime;
 
 inline CBLAS_TRANSPOSE BooleanToTranspose(bool trans) { return trans ? CblasTrans : CblasNoTrans; }
 
+inline char BooleanToTransposeChar(bool trans) { return trans ? 'T' : 'N'; }
+
 struct CblasSgemmOp {
   typedef float TDatatype;
   void operator()(bool ta, bool tb, int M, int N, int K, float alpha, float* A, int lda, float* B,
                   int ldb, float beta, float* C, int ldc) {
+#if USE_DNNL == 1
+    dnnl_sgemm(BooleanToTransposeChar(tb), BooleanToTransposeChar(ta), N, M, K, alpha, B,
+               ldb, A, lda, beta, C, ldc);
+#else
     cblas_sgemm(CblasColMajor, BooleanToTranspose(ta), BooleanToTranspose(tb), M, N, K, alpha, A,
                 lda, B, ldb, beta, C, ldc);
+#endif
   }
 };
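A minimal NumPy sketch of why the DNNL path swaps the operands and transpose flags: dnnl_sgemm is row-major while the cblas call is column-major, and a column-major C has the same memory layout as a row-major C^T = op(B)^T op(A)^T (illustrative shapes):

import numpy as np

A = np.random.rand(4, 3).astype(np.float32)   # M x K
B = np.random.rand(3, 5).astype(np.float32)   # K x N
C = A @ B                                     # column-major sgemm result
C_swapped = (B.T @ A.T).T                     # row-major sgemm with swapped operands
np.testing.assert_allclose(C, C_swapped, rtol=1e-5)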
 
diff --git a/src/runtime/cuda/cuda_device_api.cc b/src/runtime/cuda/cuda_device_api.cc
index 87ee4041efaf..a504edae7fb2 100644
--- a/src/runtime/cuda/cuda_device_api.cc
+++ b/src/runtime/cuda/cuda_device_api.cc
@@ -105,6 +105,7 @@ class CUDADeviceAPI final : public DeviceAPI {
         *rv = ss.str();
         return;
       }
+      case kGcnArch: return;
     }
     *rv = value;
   }
diff --git a/src/runtime/metal/metal_device_api.mm b/src/runtime/metal/metal_device_api.mm
index e38329add367..d319e5094c0e 100644
--- a/src/runtime/metal/metal_device_api.mm
+++ b/src/runtime/metal/metal_device_api.mm
@@ -63,6 +63,7 @@
     case kMultiProcessorCount: return;
     case kMaxThreadDimensions: return;
     case kExist: break;
+    case kGcnArch: return;
   }
 }
 
diff --git a/src/runtime/opencl/opencl_device_api.cc b/src/runtime/opencl/opencl_device_api.cc
index 1e6af53f3860..882ee8310f2a 100644
--- a/src/runtime/opencl/opencl_device_api.cc
+++ b/src/runtime/opencl/opencl_device_api.cc
@@ -114,6 +114,7 @@ void OpenCLWorkspace::GetAttr(
       *rv = ss.str();
       break;
     }
+    case kGcnArch: return;
   }
 }
 
diff --git a/src/runtime/opengl/opengl_device_api.cc b/src/runtime/opengl/opengl_device_api.cc
index db065522c607..1b1487e4ceec 100644
--- a/src/runtime/opengl/opengl_device_api.cc
+++ b/src/runtime/opengl/opengl_device_api.cc
@@ -117,6 +117,7 @@ void OpenGLWorkspace::GetAttr(
     case kMaxClockRate: return;
     case kMultiProcessorCount: return;
     case kMaxThreadDimensions: return;
+    case kGcnArch: return;
   }
 }
 
diff --git a/src/runtime/rocm/rocm_device_api.cc b/src/runtime/rocm/rocm_device_api.cc
index cff72f58f69a..c49af895480a 100644
--- a/src/runtime/rocm/rocm_device_api.cc
+++ b/src/runtime/rocm/rocm_device_api.cc
@@ -26,9 +26,10 @@
 
 #include <dmlc/logging.h>
 #include <dmlc/thread_local.h>
-#include <tvm/runtime/registry.h>
 #include <hip/hip_runtime_api.h>
 #include <hsa/hsa.h>
+#include <tvm/runtime/registry.h>
+#include "../../../include/tvm/runtime/device_api.h"
 #include "rocm_common.h"
 
 namespace tvm {
@@ -62,16 +63,17 @@ class ROCMDeviceAPI final : public DeviceAPI {
         break;
       }
       case kMaxSharedMemoryPerBlock: return;
-      case kComputeVersion: {
+      case kComputeVersion:
+      case kDeviceName: return;
+      case kMaxClockRate: return;
+      case kMultiProcessorCount: return;
+      case kMaxThreadDimensions: return;
+      case kGcnArch: {
         hipDeviceProp_t prop;
         ROCM_CALL(hipGetDeviceProperties(&prop, ctx.device_id));
         *rv = prop.gcnArch;
         return;
       }
-      case kDeviceName: return;
-      case kMaxClockRate: return;
-      case kMultiProcessorCount: return;
-      case kMaxThreadDimensions: return;
     }
     *rv = value;
   }
diff --git a/src/runtime/rocm/rocm_module.cc b/src/runtime/rocm/rocm_module.cc
index c2bea8a8d745..5bc92f5b244a 100644
--- a/src/runtime/rocm/rocm_module.cc
+++ b/src/runtime/rocm/rocm_module.cc
@@ -122,16 +122,9 @@ class ROCMModuleNode : public runtime::ModuleNode {
     hipDeviceptr_t global = nullptr;
     size_t nbytes = 0;
 
-    hipError_t result = hipSuccess;
-    // ROCM doesn't support hipModuleGetGlobal yet.
-    // hipError_t result = hipModuleGetGlobal(&global, &nbytes,
-    //                                    module_[device_id], global_name.c_str());
+    ROCM_DRIVER_CALL(hipModuleGetGlobal(&global, &nbytes,
+                                        module_[device_id], global_name.c_str()));
     CHECK_EQ(nbytes, expect_nbytes);
-    if (result != hipSuccess) {
-      LOG(FATAL)
-          << "ROCMError: hipModuleGetGlobal " << global_name
-          << " failed with error: " << hipGetErrorString(result);
-    }
     return global;
   }
 
diff --git a/src/runtime/vulkan/vulkan.cc b/src/runtime/vulkan/vulkan.cc
index daf4ae7c55f7..b14260e07816 100644
--- a/src/runtime/vulkan/vulkan.cc
+++ b/src/runtime/vulkan/vulkan.cc
@@ -398,6 +398,8 @@ void VulkanDeviceAPI::GetAttr(TVMContext ctx, DeviceAttrKind kind, TVMRetValue*
       break;
     case kMaxThreadDimensions:
       break;
+    case kGcnArch:
+      return;
   }
 }
 
diff --git a/tests/lint/check_file_type.py b/tests/lint/check_file_type.py
index e5f2dc7e2aa6..d21b9488daff 100644
--- a/tests/lint/check_file_type.py
+++ b/tests/lint/check_file_type.py
@@ -95,6 +95,8 @@
     "docker/with_the_same_user",
     "LICENSE",
     "NOTICE",
+    "KEYS",
+    "DISCLAIMER",
     "Jenkinsfile",
     # sgx file
     "apps/sgx/enclave/sgx-deps.diff",
diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py
index 6391a1a9504d..e074bac90f2a 100644
--- a/tests/python/frontend/onnx/test_forward.py
+++ b/tests/python/frontend/onnx/test_forward.py
@@ -77,19 +77,19 @@ def get_tvm_output(graph_def, input_data, target, ctx, output_shape=None, output
         return tvm_output.asnumpy()
 
 
-def get_caffe2_output(model, x, dtype='float32'):
-    import caffe2.python.onnx.backend
-    prepared_backend = caffe2.python.onnx.backend.prepare(model)
-    W = {model.graph.input[0].name: x.astype(dtype)}
-    c2_out = prepared_backend.run(W)[0]
-    return c2_out
+def get_onnxruntime_output(model, x, dtype='float32'):
+    import onnxruntime.backend
+    rep = onnxruntime.backend.prepare(model, 'CPU')
+    x = x.astype(dtype)
+    ort_out = rep.run(x)[0]
+    return ort_out
 
 
 def verify_onnx_forward_impl(graph_file, data_shape, out_shape):
     dtype = 'float32'
     x = np.random.uniform(size=data_shape)
     model = onnx.load_model(graph_file)
-    c2_out = get_caffe2_output(model, x, dtype)
+    c2_out = get_onnxruntime_output(model, x, dtype)
     for target, ctx in ctx_list():
         tvm_out = get_tvm_output(model, x, target, ctx, out_shape, dtype)
         tvm.testing.assert_allclose(c2_out, tvm_out, rtol=1e-5, atol=1e-5)
@@ -142,6 +142,57 @@ def test_reshape():
     tvm.testing.assert_allclose(ref_shape, tvm_out.shape)
 
 
+def verify_depth_to_space(inshape, outshape, mode, blockSize):
+    node = onnx.helper.make_node('DepthToSpace',
+                                 inputs=['x'],
+                                 outputs=['y'],
+                                 blocksize=blockSize)
+
+    graph = helper.make_graph([node],
+                              "depth_to_space_test",
+                              inputs=[helper.make_tensor_value_info("x", TensorProto.FLOAT, list(inshape))],
+                              outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, list(outshape))])
+
+    model = helper.make_model(graph, producer_name='depth_to_space_test')
+
+    for target, ctx in ctx_list():
+        x = np.random.uniform(size=inshape).astype('float32')
+        tvm_out = get_tvm_output(model, x, target, ctx, outshape, 'float32')
+        onnx_out = get_onnxruntime_output(model, x, 'float32')
+        tvm.testing.assert_allclose(onnx_out, tvm_out)
+
+
+def test_depth_to_space():
+    # current onnx.checker uses the OpSet-1 version of DepthToSpace, which doesn't have a mode argument.
+    # TODO: we can add a mode argument to test CRD mode and DCR mode
+    # in the future when we update to a newer onnx version.
+    verify_depth_to_space((1, 8, 2, 3), (1, 2, 4, 6), mode="CRD", blockSize=2)
+
+
+def verify_space_to_depth(inshape, outshape, blockSize):
+    node = onnx.helper.make_node('SpaceToDepth',
+                                 inputs=['x'],
+                                 outputs=['y'],
+                                 blocksize=blockSize)
+
+    graph = helper.make_graph([node],
+                              "space_to_depth_test",
+                              inputs=[helper.make_tensor_value_info("x", TensorProto.FLOAT, list(inshape))],
+                              outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, list(outshape))])
+
+    model = helper.make_model(graph, producer_name='space_to_depth_test')
+
+    for target, ctx in ctx_list():
+        x = np.random.uniform(size=inshape).astype('float32')
+        tvm_out = get_tvm_output(model, x, target, ctx, outshape, 'float32')
+        onnx_out = get_onnxruntime_output(model, x, 'float32')
+        tvm.testing.assert_allclose(onnx_out, tvm_out)
+
+
+def test_space_to_depth():
+    verify_space_to_depth((1, 1, 4, 6), (1, 4, 2, 3), 2)
+
+
 def test_shape():
     in_shape = (4, 3, 3, 4)
     ref_shape = (6, 2, 4, 3)
@@ -1372,7 +1423,7 @@ def check_torch_conversion(model, input_size):
     onnx_model = onnx.load(file_name)
     for target, ctx in ctx_list():
         input_data = np.random.uniform(size=input_size).astype('int32')
-        c2_out = get_caffe2_output(onnx_model, input_data)
+        c2_out = get_onnxruntime_output(onnx_model, input_data)
         tvm_out = get_tvm_output(onnx_model, input_data, target, ctx)
         tvm.testing.assert_allclose(c2_out, tvm_out)
 
@@ -1574,6 +1625,7 @@ def test_erf():
     z = scipy.special.erf(x)
     verify_erf(x, z)
 
+
 def verify_where(condition, x, y, dtype, outdata):
     node = helper.make_node('Where', inputs=['condition', 'x', 'y'], outputs=['out'])
     graph = helper.make_graph([node],
@@ -1588,6 +1640,7 @@ def verify_where(condition, x, y, dtype, outdata):
         tvm_out = get_tvm_output(model, [condition, x, y], target, ctx, outdata.shape)
         tvm.testing.assert_allclose(outdata, tvm_out)
 
+
 def test_where():
     condition = np.array([[1, 0], [1, 1]], dtype=np.bool)
     x = np.array([[1, 2], [3, 4]], dtype=np.int64)
@@ -1704,3 +1757,5 @@ def test_or():
     test_erf()
     test_where()
     test_or()
+    test_depth_to_space()
+    test_space_to_depth()
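A minimal NumPy reference for DepthToSpace in the default DCR layout (the OpSet-1 behaviour the checker enforces, as noted in the test above); SpaceToDepth is the inverse reshape/transpose (illustrative, for cross-checking the test shapes):

import numpy as np

def depth_to_space(x, blocksize):
    n, c, h, w = x.shape
    b = blocksize
    tmp = x.reshape(n, b, b, c // (b * b), h, w)
    tmp = tmp.transpose(0, 3, 4, 1, 5, 2)
    return tmp.reshape(n, c // (b * b), h * b, w * b)

x = np.arange(1 * 8 * 2 * 3, dtype=np.float32).reshape(1, 8, 2, 3)
assert depth_to_space(x, 2).shape == (1, 2, 4, 6)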
diff --git a/tests/python/frontend/tensorflow/test_forward.py b/tests/python/frontend/tensorflow/test_forward.py
index 17db2f5cc9a8..e02532fa748b 100644
--- a/tests/python/frontend/tensorflow/test_forward.py
+++ b/tests/python/frontend/tensorflow/test_forward.py
@@ -2114,6 +2114,22 @@ def _test_forward_transpose(ishape, axes=None):
 
         compare_tf_with_tvm(data, 'transpose_data:0', 'transpose:0')
 
+def _test_forward_tranapose_axes_input(ishape, axes):
+    data = np.random.uniform(size=ishape).astype(np.float32)
+    axes_np = np.array(axes).astype(np.int32)
+
+    with tf.Graph().as_default():
+        in1 = tf.placeholder(
+            shape=data.shape, dtype=data.dtype, name="transpose_data")
+
+        const1 = tf.constant(axes_np, dtype=tf.int32)
+
+        # make axes an input to tf.transpose, but not an input to the graph,
+        # so it can be extracted with infer_value_simulated
+        axes = tf.reverse(const1, axis=[-1])
+        tf.transpose(in1, axes)
+
+        compare_tf_with_tvm([data], ['transpose_data:0'], 'transpose:0')
 
 def test_forward_transpose():
     _test_forward_transpose((2, 3, 4), (1, 2, 0))
@@ -2122,6 +2138,8 @@ def test_forward_transpose():
     _test_forward_transpose((2, 3, 4), (1, 2, 0))
     _test_forward_transpose((2, 3, 4), (0, 1, 2))
     _test_forward_transpose((2, 3, 4, 5), (3, 0, 1, 2))
+    _test_forward_tranapose_axes_input((2, 3, 4), (1, 2, 0))
+    _test_forward_tranapose_axes_input((2, 3, 4, 5), (3, 0, 1, 2))
 
 
 def test_forward_ceil():
diff --git a/tests/python/frontend/tflite/test_forward.py b/tests/python/frontend/tflite/test_forward.py
index 83a0730d74f8..8d1902694ee4 100644
--- a/tests/python/frontend/tflite/test_forward.py
+++ b/tests/python/frontend/tflite/test_forward.py
@@ -934,18 +934,18 @@ def test_forward_relu():
     """ ReLU """
     _test_relu(np.arange(6.0, dtype=np.float32).reshape((1, 6)))
 
-def _test_prelu(data):
+def _test_prelu(data, alpha):
     """ One iteration of PReLU """
     with tf.Graph().as_default():
         in_data = array_ops.placeholder(shape=data.shape, dtype=data.dtype)
-        alpha = np.full((data.shape[-1],), 0.2, dtype=data.dtype)
         # This specific pattern will be replaced into PRelu by tflite
         out = nn_ops.relu(in_data) + (-alpha * nn_ops.relu(-in_data))
         compare_tflite_with_tvm(data, 'Placeholder:0', [in_data], [out])
 
 def test_forward_prelu():
     """ PReLU """
-    _test_prelu(np.random.uniform(-5, 5, size=(1, 32, 32, 3)).astype("float32"))
+    _test_prelu(np.random.uniform(-5, 5, size=(1, 32, 32, 3)).astype("float32"), np.full((3,), 0.2, dtype="float32"))
+    _test_prelu(np.random.uniform(-5, 5, size=(1, 32, 32, 3)).astype("float32"), np.full((1, 1, 3), 0.2, dtype="float32"))
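A minimal NumPy check of the identity the PReLU test relies on: relu(x) + (-alpha) * relu(-x) is exactly PReLU with slope alpha on the negative side (illustrative shapes matching the test):

import numpy as np

def relu(x):
    return np.maximum(x, 0.0)

x = np.random.uniform(-5, 5, size=(1, 32, 32, 3)).astype("float32")
alpha = np.full((3,), 0.2, dtype="float32")
prelu_ref = np.where(x >= 0, x, alpha * x)
fused_pattern = relu(x) + (-alpha) * relu(-x)
np.testing.assert_allclose(prelu_ref, fused_pattern, rtol=1e-6)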
 
 #######################################################################
 # Fully Connected
diff --git a/tests/python/relay/test_ir_text_printer.py b/tests/python/relay/test_ir_text_printer.py
index 0d6a02e6c8e4..6426bf3410c8 100644
--- a/tests/python/relay/test_ir_text_printer.py
+++ b/tests/python/relay/test_ir_text_printer.py
@@ -218,6 +218,27 @@ def test_zeros():
     x = relay.op.zeros([], "float32")
     astext(x)
 
+
+def test_unapplied_constructor():
+    type_def_str = r"""
+type List[A] {
+  Cons(A, List[A]),
+  Nil,
+}
+    """
+    main_def_str = r"""
+def @main[A]() -> fn (A, List[A]) -> List[A] {
+  Cons
+}
+    """
+    mod = relay.fromtext(SEMVER + type_def_str + main_def_str)
+    mod_str = str(mod)
+    # ensure constructors are printed correctly in type definitions (with their
+    # signature) and as exprs (without their signature)
+    assert type_def_str.strip() in mod_str
+    assert main_def_str.strip() in mod_str
+
+
 if __name__ == "__main__":
     do_print[0] = True
     test_lstm()
@@ -239,3 +260,4 @@ def test_zeros():
     test_let_if_scope()
     test_variable_name()
     test_call_node_order()
+    test_unapplied_constructor()
diff --git a/tests/python/relay/test_op_level2.py b/tests/python/relay/test_op_level2.py
index eee5e4ffa52b..08c5eb0d5cfc 100644
--- a/tests/python/relay/test_op_level2.py
+++ b/tests/python/relay/test_op_level2.py
@@ -22,6 +22,7 @@
 from tvm import relay
 from tvm.relay import transform
 from tvm.relay.testing import ctx_list
+from tvm.contrib import util
 import topi.testing
 
 def run_infer_type(expr):
@@ -134,6 +135,46 @@ def run_test_conv2d(dtype, out_dtype, scale, dshape, kshape,
             op_res1 = intrp1.evaluate(func)(data, kernel)
             tvm.testing.assert_allclose(op_res1.asnumpy(), ref_res, rtol=1e-5, atol=1e-5)
 
+    def compile_test_conv2d_arm_cpu(dtype, out_dtype, scale, dshape, kshape,
+                        padding=(1, 1),
+                        groups=1,
+                        dilation=(1, 1),
+                        **attrs):
+        x = relay.var("x", shape=dshape, dtype=dtype)
+        w = relay.var("w", dtype=dtype)
+        y = relay.nn.conv2d(x, w,
+                            padding=padding,
+                            dilation=dilation,
+                            groups=groups,
+                            **attrs)
+        func = relay.Function([x, w], y)
+        mod = tvm.relay.Module()
+        mod["main"] = func
+
+        test_schedule='{"i": ["llvm -device=arm_cpu", "topi_nn_depthwise_conv2d_nchw", \
+                        [["TENSOR", [1, 512, 32, 32], "float32"], \
+                        ["TENSOR", [512, 1, 3, 3], "float32"], \
+                        [1, 1], [1, 1], [1, 1], "float32"], {}, \
+                        ["depthwise_conv2d_nchw", [1, 512, 32, 32, "float32"], \
+                        [512, 1, 3, 3, "float32"], [1, 1], [1, 1], [1, 1], "float32"], \
+                        {"i": 743640, "t": "contrib_spatial_pack", "c": null, \
+                        "e": [["tile_co", "sp", [512, 1]], ["tile_oh", "sp", [8, 1]], \
+                        ["tile_ow", "sp", [1, 8]], \
+                        ["reorder_0", "re", [0, 1, 2, 3, 4, 5, 8, 6, 7]], \
+                        ["reorder_1", "re", [0, 1, 2, 3, 6, 4, 5]], \
+                        ["ann_reduce", "an", ["unroll", "none"]], \
+                        ["ann_spatial", "an", ["unroll", "unroll", "vec"]], \
+                        ["data_pad_inline", "ot", 4], ["data_vec_inline", "ot", 1], \
+                        ["conv_inline", "ot", 0]]}], "r": [[0.0002933163], \
+                        0, 3.1976189613342285, 1570811630.6058347], "v": 0.1}'
+        temp = util.tempdir()
+        with open(temp.relpath("temp.log"), "w") as log_file:
+            log_file.write(test_schedule)
+        with autotvm.apply_history_best(temp.relpath("temp.log")):
+            with relay.build_config(opt_level=3):
+                print('Compiling...')
+                graph_json, mod, params = tvm.relay.build(mod, target="llvm -device=arm_cpu")
+
     # depthwise conv2d
     dshape = (1, 32, 18, 18)
     kshape = (32, 1, 3, 3)
@@ -142,6 +183,13 @@ def run_test_conv2d(dtype, out_dtype, scale, dshape, kshape,
                     fref=lambda x, w: topi.testing.depthwise_conv2d_python_nchw(
                         x, w, (1, 1), "SAME"))
 
+    # depthwise conv2d for arm_cpu
+    dshape = (1, 512, 32, 32)
+    kshape = (512, 1, 3, 3)
+    compile_test_conv2d_arm_cpu("float32", "float32", 1, dshape, kshape,
+                                padding=(1, 1), channels=512,
+                                groups=512, kernel_size=(3, 3))
+
     # CUDA is disabled for 'direct' schedule:
     # https://github.com/apache/incubator-tvm/pull/3070#issuecomment-486597553
     # group conv2d
diff --git a/tests/python/relay/test_op_level5.py b/tests/python/relay/test_op_level5.py
index fb5dbcc8b558..f7447465c3ac 100644
--- a/tests/python/relay/test_op_level5.py
+++ b/tests/python/relay/test_op_level5.py
@@ -424,7 +424,7 @@ def verify_proposal(np_cls_prob, np_bbox_pred, np_im_info, np_out, attrs):
 
         func = relay.Function([cls_prob, bbox_pred, im_info], z)
         func = run_infer_type(func)
-        for target in ['cuda']:
+        for target in ['llvm', 'cuda']:
             if not tvm.module.enabled(target):
                 print("Skip test because %s is not enabled." % target)
                 continue
diff --git a/tests/python/relay/test_op_qnn_conv2d.py b/tests/python/relay/test_op_qnn_conv2d.py
index b4e8bfd71b62..71368f84d023 100644
--- a/tests/python/relay/test_op_qnn_conv2d.py
+++ b/tests/python/relay/test_op_qnn_conv2d.py
@@ -160,7 +160,7 @@ def get_output(func, golden_inputs):
     qnn_output = get_output(qnn_func, golden_inputs)
     np.testing.assert_equal(qnn_output, golden_output)
 
-def no_zero_point_test():
+def test_no_zero_point():
     # uint8 input
     data_shape = (2, 1, 2, 4)
     data_dtype = 'uint8'
@@ -203,7 +203,7 @@ def no_zero_point_test():
     verify(ref_func, qnn_func, data_shape, data_dtype,
             kernel_shape, kernel_dtype)
 
-def kernel_zero_point_test():
+def test_kernel_zero_point():
     # uint8 input
     data_shape = (2, 4, 2, 4)
     data_dtype = 'uint8'
@@ -247,7 +247,7 @@ def kernel_zero_point_test():
             kernel_shape, kernel_dtype)
 
 
-def input_zero_point_test():
+def test_input_zero_point():
     # uint8 input
     data_shape = (2, 4, 2, 4)
     data_dtype = 'uint8'
@@ -290,7 +290,7 @@ def input_zero_point_test():
     verify(ref_func, qnn_func, data_shape, data_dtype,
             kernel_shape, kernel_dtype)
 
-def both_zero_point_test():
+def test_both_zero_point():
     # uint8 input
     data_shape = (2, 4, 2, 4)
     data_dtype = 'uint8'
@@ -333,7 +333,7 @@ def both_zero_point_test():
     verify(ref_func, qnn_func, data_shape, data_dtype,
             kernel_shape, kernel_dtype)
 
-def layout_test():
+def test_layout():
     # uint8 input
     data_shape = (2, 2, 4, 4) # NHWC
     data_dtype = 'uint8'
@@ -378,7 +378,7 @@ def layout_test():
 
 
 
-def padding_test():
+def test_padding():
     # uint8 input
     data_shape = (1, 4, 2, 2)
     data_dtype = 'uint8'
@@ -421,7 +421,7 @@ def padding_test():
     verify(ref_func, qnn_func, data_shape, data_dtype,
             kernel_shape, kernel_dtype)
 
-def dilation_test():
+def test_dilation():
     # uint8 input
     data_shape = (2, 4, 4, 4)
     data_dtype = 'uint8'
@@ -444,7 +444,7 @@ def dilation_test():
             kernel_shape, kernel_dtype)
 
 
-def const_folding_test():
+def test_const_folding():
     data_shape = (2, 4, 2, 4)
     data_dtype = 'uint8'
     kernel_shape = (3, 4, 2, 2)
@@ -470,7 +470,7 @@ def const_folding_test():
     folded_func = folded_mod["main"]
     assert "reshape" not in folded_func.astext()
 
-def kernel_size_1x1_test():
+def test_kernel_size_1x1():
     # uint8 input
     data_shape = (2, 4, 2, 4)
     data_dtype = 'uint8'
@@ -493,7 +493,7 @@ def kernel_size_1x1_test():
     verify(ref_func, qnn_func, data_shape, data_dtype,
             kernel_shape, kernel_dtype)
 
-def tflite_large_irregular_test():
+def test_tflite_large_irregular():
     # uint8 input
     data_shape = (1, 1024, 1, 1)
     data_dtype = 'uint8'
@@ -526,7 +526,7 @@ def tflite_large_irregular_test():
     golden_output = np.full((1, 1001, 1, 1), 0).astype('uint8')
     np.testing.assert_equal(qnn_output, golden_output)
 
-def tflite_output_multiplier_greater_than_one():
+def test_tflite_output_multiplier_greater_than_one():
     # uint8 input
     data_shape = (2, 1, 2, 4)
     data_dtype = 'uint8'
@@ -570,7 +570,7 @@ def tflite_output_multiplier_greater_than_one():
                               0, 0)).reshape(2, 3, 1, 2)
     np.testing.assert_equal(qnn_output, golden_output)
 
-def tflite_anistropic_strides():
+def test_tflite_anistropic_strides():
     # uint8 input
     data_shape = (1, 1, 3, 6)
     data_dtype = 'uint8'
@@ -607,7 +607,7 @@ def tflite_anistropic_strides():
     golden_output = np.array((124, -92, 164, -132)).reshape(1, 1, 2, 2)
     np.testing.assert_equal(qnn_output, golden_output)
 
-def broadcast_layout_test():
+def test_broadcast_layout():
     # Test broadcast support for NHWC layout.
     data_shape = (1, 229, 229, 3) # NHWC
     data_dtype = 'uint8'
@@ -641,16 +641,16 @@ def broadcast_layout_test():
         graph, lib, params = relay.build(mod, "llvm -mcpu=skylake-avx512")
 
 if __name__ == "__main__":
-    no_zero_point_test()
-    input_zero_point_test()
-    kernel_zero_point_test()
-    both_zero_point_test()
-    layout_test()
-    padding_test()
-    dilation_test()
-    const_folding_test()
-    kernel_size_1x1_test()
-    tflite_large_irregular_test()
-    tflite_output_multiplier_greater_than_one()
-    tflite_anistropic_strides()
-    broadcast_layout_test()
+    test_no_zero_point()
+    test_input_zero_point()
+    test_kernel_zero_point()
+    test_both_zero_point()
+    test_layout()
+    test_padding()
+    test_dilation()
+    test_const_folding()
+    test_kernel_size_1x1()
+    test_tflite_large_irregular()
+    test_broadcast_layout()
+    test_tflite_output_multiplier_greater_than_one()
+    test_tflite_anistropic_strides()
diff --git a/tests/python/relay/test_qnn_mul.py b/tests/python/relay/test_op_qnn_mul.py
similarity index 100%
rename from tests/python/relay/test_qnn_mul.py
rename to tests/python/relay/test_op_qnn_mul.py
diff --git a/tests/python/relay/test_op_qnn_requantize.py b/tests/python/relay/test_op_qnn_requantize.py
index 18e2f308969b..3818135ecda7 100644
--- a/tests/python/relay/test_op_qnn_requantize.py
+++ b/tests/python/relay/test_op_qnn_requantize.py
@@ -22,230 +22,227 @@
 
 roundings = ["UPWARD", "TONEAREST"]
 
-def test_requantize():
-    def verify(mod, goldens):
-        with relay.build_config(opt_level=3):
-            graph, lib, params = relay.build(mod, "llvm", params=None)
-            golden_data, golden_output = goldens
-            rt_mod = graph_runtime.create(graph, lib, ctx=tvm.cpu(0))
-            rt_mod.set_input("quantized_data",golden_data)
-            rt_mod.set_input(**params)
-            rt_mod.run()
-            res = rt_mod.get_output(0).asnumpy()
-            np.testing.assert_equal(res, golden_output)
-
-    def get_mod(data_shape, data_dtype, out_dtype, input_scale, output_scale,
-            input_zero_point=0, output_zero_point=0, rounding="TONEAREST"):
-        quantized_data = relay.var("quantized_data", shape=data_shape,
-                dtype=data_dtype)
-        mod = relay.qnn.op.requantize(
-                quantized_data,
-                input_scale=input_scale,
-                input_zero_point=input_zero_point,
-                output_scale=output_scale,
-                output_zero_point=output_zero_point,
-                rounding=rounding,
-                out_dtype=out_dtype)
-
-        mod = relay.Function(relay.analysis.free_vars(mod), mod)
-        mod = relay.Module.from_expr(mod)
-        return mod
-
-    def same_scale_test():
-        # Have same scales, everything within range
-        golden_data = np.arange(-100, 100, 1).astype('int32')
-        golden_output = golden_data
-
-        for rounding in roundings:
-            mod = get_mod(data_shape=(200, ),
-                          data_dtype='int32',
-                          out_dtype="int8",
-                          input_scale=0.5,
-                          output_scale=0.5,
-                          rounding=rounding)
-            assert 'right_shift' not in mod.astext()
-            verify(mod, (golden_data, golden_output))
-
-    def downscale_test():
-        for rounding in roundings:
-            mod = get_mod(data_shape=(32, ),
-                          data_dtype='int32',
-                          out_dtype='int8',
-                          input_scale=1,
-                          output_scale=16,
-                          rounding=rounding)
-
-            # Try positive values
-            # 8 corresponds to 0.5, resulting in 1
-            golden_data = np.arange(0, 32, 1).astype('int32')
-            golden_output = np.repeat([0, 1, 2], [8, 16, 8])
-            verify(mod, (golden_data, golden_output))
-
-            # Try negative values
-            # -8 corresponds to -0.5. For UPWARD, this is 0
-            golden_data = np.arange(0, -32, -1).astype('int32')
-            if rounding == "UPWARD":
-                golden_output = np.repeat([0, -1, -2], [9, 16, 7])
-            else:
-                golden_output = np.repeat([0, -1, -2], [8, 16, 8])
-            verify(mod, (golden_data, golden_output))
-
-            # Try a different scale
-            mod = get_mod(data_shape=(32, ),
-                          data_dtype='int32',
-                          out_dtype="int8",
-                          input_scale=1,
-                          output_scale=4,
-                          rounding=rounding)
-
-            # Try positive values
-            # 2I corresponds to 0.5, resulting in 1
-            golden_data = np.arange(0, 32, 1).astype('int32')
-            golden_output = np.repeat([0, 1, 2, 3, 4, 5, 6, 7, 8],
+def verify(mod, goldens):
+    with relay.build_config(opt_level=3):
+        graph, lib, params = relay.build(mod, "llvm", params=None)
+        golden_data, golden_output = goldens
+        rt_mod = graph_runtime.create(graph, lib, ctx=tvm.cpu(0))
+        rt_mod.set_input("quantized_data",golden_data)
+        rt_mod.set_input(**params)
+        rt_mod.run()
+        res = rt_mod.get_output(0).asnumpy()
+        np.testing.assert_equal(res, golden_output)
+
+def get_mod(data_shape, data_dtype, out_dtype, input_scale, output_scale,
+        input_zero_point=0, output_zero_point=0, rounding="TONEAREST"):
+    quantized_data = relay.var("quantized_data", shape=data_shape,
+            dtype=data_dtype)
+    mod = relay.qnn.op.requantize(
+            quantized_data,
+            input_scale=input_scale,
+            input_zero_point=input_zero_point,
+            output_scale=output_scale,
+            output_zero_point=output_zero_point,
+            rounding=rounding,
+            out_dtype=out_dtype)
+
+    mod = relay.Function(relay.analysis.free_vars(mod), mod)
+    mod = relay.Module.from_expr(mod)
+    return mod
+
+def test_same_scale():
+    # Have same scales, everything within range
+    golden_data = np.arange(-100, 100, 1).astype('int32')
+    golden_output = golden_data
+
+    for rounding in roundings:
+        mod = get_mod(data_shape=(200, ),
+                      data_dtype='int32',
+                      out_dtype="int8",
+                      input_scale=0.5,
+                      output_scale=0.5,
+                      rounding=rounding)
+        assert 'right_shift' not in mod.astext()
+        verify(mod, (golden_data, golden_output))
+
+def test_downscale():
+    for rounding in roundings:
+        mod = get_mod(data_shape=(32, ),
+                      data_dtype='int32',
+                      out_dtype='int8',
+                      input_scale=1,
+                      output_scale=16,
+                      rounding=rounding)
+
+        # Try positive values
+        # 8 corresponds to 0.5, resulting in 1
+        golden_data = np.arange(0, 32, 1).astype('int32')
+        golden_output = np.repeat([0, 1, 2], [8, 16, 8])
+        verify(mod, (golden_data, golden_output))
+
+        # Try negative values
+        # -8 corresponds to -0.5. For UPWARD, this is 0
+        golden_data = np.arange(0, -32, -1).astype('int32')
+        if rounding == "UPWARD":
+            golden_output = np.repeat([0, -1, -2], [9, 16, 7])
+        else:
+            golden_output = np.repeat([0, -1, -2], [8, 16, 8])
+        verify(mod, (golden_data, golden_output))
+
+        # Try a different scale
+        mod = get_mod(data_shape=(32, ),
+                      data_dtype='int32',
+                      out_dtype="int8",
+                      input_scale=1,
+                      output_scale=4,
+                      rounding=rounding)
+
+        # Try positive values
+        # 2 corresponds to 0.5, resulting in 1
+        golden_data = np.arange(0, 32, 1).astype('int32')
+        golden_output = np.repeat([0, 1, 2, 3, 4, 5, 6, 7, 8],
+                                  [2, 4, 4, 4, 4, 4, 4, 4, 2])
+        verify(mod, (golden_data, golden_output))
+
+        # Try negative values
+        # -8 corresponds to -0.5. For UPWARD, this is 0
+        golden_data = np.arange(0, -32, -1).astype('int32')
+        if rounding == "UPWARD":
+            golden_output = np.repeat([0, -1, -2, -3, -4, -5, -6, -7, -8],
+                                      [3, 4, 4, 4, 4, 4, 4, 4, 1])
+        else:
+            golden_output = np.repeat([0, -1, -2, -3, -4, -5, -6, -7, -8],
                                       [2, 4, 4, 4, 4, 4, 4, 4, 2])
-            verify(mod, (golden_data, golden_output))
-
-            # Try negative values
-            # -8 corresponds to -0.5. For UPWARD, this is 0
-            golden_data = np.arange(0, -32, -1).astype('int32')
-            if rounding == "UPWARD":
-                golden_output = np.repeat([0, -1, -2, -3, -4, -5, -6, -7, -8],
-                                          [3, 4, 4, 4, 4, 4, 4, 4, 1])
-            else:
-                golden_output = np.repeat([0, -1, -2, -3, -4, -5, -6, -7, -8],
-                                          [2, 4, 4, 4, 4, 4, 4, 4, 2])
-            verify(mod, (golden_data, golden_output))
-
-            # Try uint8 out_dtype
-            mod = get_mod(data_shape=(32, ),
-                          data_dtype='int32',
-                          out_dtype='uint8',
-                          input_scale=1,
-                          output_scale=16,
-                          rounding=rounding)
-
-            # Try positive values
-            # 8 corresponds to 0.5, resulting in 1
-            golden_data = np.arange(0, 32, 1).astype('int32')
-            golden_output = np.repeat([0, 1, 2], [8, 16, 8])
-            verify(mod, (golden_data, golden_output))
-
-            # Try uint8 in_dtyope and uint8 out_dtype
-            mod = get_mod(data_shape=(32, ),
-                          data_dtype='uint8',
-                          out_dtype='uint8',
-                          input_scale=1,
-                          output_scale=16,
-                          rounding=rounding)
-
-            # Try positive values
-            # 8 corresponds to 0.5, resulting in 1
-            golden_data = np.arange(0, 32, 1).astype('int32')
-            golden_output = np.repeat([0, 1, 2], [8, 16, 8])
-            verify(mod, (golden_data, golden_output))
-
-    def upscale_test():
-        for rounding in roundings:
-            mod = get_mod(data_shape=(32, ),
-                          data_dtype='int32',
-                          out_dtype="int8",
-                          input_scale=2,
-                          output_scale=1,
-                          rounding=rounding)
-
-            # Try positive values
-            # 8 corresponds to 0.5, resulting in 1
-            golden_data = np.arange(0, 32, 1).astype('int32')
-            golden_output = np.multiply(2, golden_data)
-            verify(mod, (golden_data, golden_output))
-
-            # Try negative values
-            # -8 corresponds to -0.5. For UPWARD, this is 0
-            golden_data = np.arange(0, -32, -1).astype('int32')
-            golden_output = np.multiply(2, golden_data)
-            verify(mod, (golden_data, golden_output))
-
-    def saturation_test():
-        for rounding in roundings:
-            mod = get_mod(data_shape=(16, ),
-                          data_dtype='int32',
-                          out_dtype="int8",
-                          input_scale=0.5,
-                          output_scale=0.5,
-                          rounding=rounding)
-            golden_data = np.arange(0, 16, 1).astype('int32')
-            golden_data = np.add(120, golden_data)
-            output = np.array([120, 121, 122, 123, 124, 125, 126, 127,
-                               127, 127, 127, 127, 127, 127, 127, 127])
-            golden_output = output
-            verify(mod, (golden_data, golden_output))
-
-            # Try negative numbers
-            golden_data = np.arange(0, -16, -1).astype('int32')
-            golden_data = np.add(-120, golden_data)
-            output = np.array([-120, -121, -122, -123, -124, -125, -126, -127,
-                               -128, -128, -128, -128, -128, -128, -128, -128])
-            golden_output = output
-            verify(mod, (golden_data, golden_output))
-
-    def zero_point_test():
-        # Output zero point
-        for rounding in roundings:
-            mod = get_mod(data_shape=(32, ),
-                          data_dtype='int32',
-                          out_dtype='int8',
-                          input_scale=1,
-                          output_scale=16,
-                          output_zero_point=1,
-                          rounding=rounding)
-
-            # Try positive values
-            # 8 corresponds to 0.5, resulting in 1
-            golden_data = np.arange(0, 32, 1).astype('int32')
-            golden_output = np.repeat([0, 1, 2], [8, 16, 8])
-            golden_output = np.add(1, golden_output)
-            verify(mod, (golden_data, golden_output))
-
-            # Try negative values
-            # -8 corresponds to -0.5. For UPWARD, this is 0
-            golden_data = np.arange(-32, -64, -1).astype('int32')
-            if rounding == "UPWARD":
-                golden_output = np.repeat([-2, -3, -4], [9, 16, 7])
-            else:
-                golden_output = np.repeat([-2, -3, -4], [8, 16, 8])
-            golden_output = np.add(1, golden_output)
-            verify(mod, (golden_data, golden_output))
-
-        # Input zero point
-        for rounding in roundings:
-            mod = get_mod(data_shape=(32, ),
-                          data_dtype='int32',
-                          out_dtype='int8',
-                          input_scale=1,
-                          output_scale=16,
-                          input_zero_point=16,
-                          rounding=rounding)
-
-            # Try positive values
-            golden_data = np.arange(32, 64, 1).astype('int32')
-            golden_output = np.repeat([2, 3, 4], [8, 16, 8])
-            golden_output = np.subtract(golden_output, 1)
-            verify(mod, (golden_data, golden_output))
-
-            # Try negative values
-            golden_data = np.arange(-32, -64, -1).astype('int32')
-            if rounding == "UPWARD":
-                golden_output = np.repeat([-2, -3, -4], [9, 16, 7])
-            else:
-                golden_output = np.repeat([-2, -3, -4], [8, 16, 8])
-            golden_output = np.subtract(golden_output, 1)
-            verify(mod, (golden_data, golden_output))
-
-    same_scale_test()
-    downscale_test()
-    upscale_test()
-    saturation_test()
-    zero_point_test()
+        verify(mod, (golden_data, golden_output))
+
+        # Try uint8 out_dtype
+        mod = get_mod(data_shape=(32, ),
+                      data_dtype='int32',
+                      out_dtype='uint8',
+                      input_scale=1,
+                      output_scale=16,
+                      rounding=rounding)
+
+        # Try positive values
+        # 8 corresponds to 0.5, resulting in 1
+        golden_data = np.arange(0, 32, 1).astype('int32')
+        golden_output = np.repeat([0, 1, 2], [8, 16, 8])
+        verify(mod, (golden_data, golden_output))
+
+        # Try uint8 in_dtype and uint8 out_dtype
+        mod = get_mod(data_shape=(32, ),
+                      data_dtype='uint8',
+                      out_dtype='uint8',
+                      input_scale=1,
+                      output_scale=16,
+                      rounding=rounding)
+
+        # Try positive values
+        # 8 corresponds to 0.5, resulting in 1
+        golden_data = np.arange(0, 32, 1).astype('int32')
+        golden_output = np.repeat([0, 1, 2], [8, 16, 8])
+        verify(mod, (golden_data, golden_output))
+
+def test_upscale():
+    for rounding in roundings:
+        mod = get_mod(data_shape=(32, ),
+                      data_dtype='int32',
+                      out_dtype="int8",
+                      input_scale=2,
+                      output_scale=1,
+                      rounding=rounding)
+
+        # Try positive values
+        # With input_scale=2 and output_scale=1, every value is doubled
+        golden_data = np.arange(0, 32, 1).astype('int32')
+        golden_output = np.multiply(2, golden_data)
+        verify(mod, (golden_data, golden_output))
+
+        # Try negative values
+        # Negative values are doubled as well
+        golden_data = np.arange(0, -32, -1).astype('int32')
+        golden_output = np.multiply(2, golden_data)
+        verify(mod, (golden_data, golden_output))
+
+def test_saturation():
+    for rounding in roundings:
+        mod = get_mod(data_shape=(16, ),
+                      data_dtype='int32',
+                      out_dtype="int8",
+                      input_scale=0.5,
+                      output_scale=0.5,
+                      rounding=rounding)
+        golden_data = np.arange(0, 16, 1).astype('int32')
+        golden_data = np.add(120, golden_data)
+        output = np.array([120, 121, 122, 123, 124, 125, 126, 127,
+                           127, 127, 127, 127, 127, 127, 127, 127])
+        golden_output = output
+        verify(mod, (golden_data, golden_output))
+
+        # Try negative numbers
+        golden_data = np.arange(0, -16, -1).astype('int32')
+        golden_data = np.add(-120, golden_data)
+        output = np.array([-120, -121, -122, -123, -124, -125, -126, -127,
+                           -128, -128, -128, -128, -128, -128, -128, -128])
+        golden_output = output
+        verify(mod, (golden_data, golden_output))
+
+def test_zero_point():
+    # Output zero point
+    for rounding in roundings:
+        mod = get_mod(data_shape=(32, ),
+                      data_dtype='int32',
+                      out_dtype='int8',
+                      input_scale=1,
+                      output_scale=16,
+                      output_zero_point=1,
+                      rounding=rounding)
+
+        # Try positive values
+        # 8 corresponds to 0.5, resulting in 1
+        golden_data = np.arange(0, 32, 1).astype('int32')
+        golden_output = np.repeat([0, 1, 2], [8, 16, 8])
+        golden_output = np.add(1, golden_output)
+        verify(mod, (golden_data, golden_output))
+
+        # Try negative values
+        # -40 corresponds to -2.5. For UPWARD, this rounds to -2
+        golden_data = np.arange(-32, -64, -1).astype('int32')
+        if rounding == "UPWARD":
+            golden_output = np.repeat([-2, -3, -4], [9, 16, 7])
+        else:
+            golden_output = np.repeat([-2, -3, -4], [8, 16, 8])
+        golden_output = np.add(1, golden_output)
+        verify(mod, (golden_data, golden_output))
+
+    # Input zero point
+    for rounding in roundings:
+        mod = get_mod(data_shape=(32, ),
+                      data_dtype='int32',
+                      out_dtype='int8',
+                      input_scale=1,
+                      output_scale=16,
+                      input_zero_point=16,
+                      rounding=rounding)
+
+        # Try positive values
+        golden_data = np.arange(32, 64, 1).astype('int32')
+        golden_output = np.repeat([2, 3, 4], [8, 16, 8])
+        golden_output = np.subtract(golden_output, 1)
+        verify(mod, (golden_data, golden_output))
+
+        # Try negative values
+        golden_data = np.arange(-32, -64, -1).astype('int32')
+        if rounding == "UPWARD":
+            golden_output = np.repeat([-2, -3, -4], [9, 16, 7])
+        else:
+            golden_output = np.repeat([-2, -3, -4], [8, 16, 8])
+        golden_output = np.subtract(golden_output, 1)
+        verify(mod, (golden_data, golden_output))
 
 if __name__ == "__main__":
-    test_requantize()
+    test_same_scale()
+    test_downscale()
+    test_upscale()
+    test_saturation()
+    test_zero_point()
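
For reference, the golden arrays in the requantize tests above follow one rule: scale each input by input_scale / output_scale, then round according to the requested mode (saturation to the out_dtype range and the zero-point offsets exercised later are omitted here). A minimal NumPy sketch, assuming TONEAREST means round-half-away-from-zero as the golden values imply; this helper is illustrative and not part of the patch:

import numpy as np

def requantize_ref(data, input_scale, output_scale, rounding="UPWARD"):
    # Scale into the output domain.
    scaled = data.astype('float64') * input_scale / output_scale
    if rounding == "UPWARD":
        # Half-way values round towards +inf: -0.5 -> 0, 0.5 -> 1.
        return np.floor(scaled + 0.5).astype('int32')
    # TONEAREST: half-way values round away from zero: -0.5 -> -1, 0.5 -> 1.
    return (np.sign(scaled) * np.floor(np.abs(scaled) + 0.5)).astype('int32')

# 8 with scales 1/16 is 0.5 and rounds to 1 under both modes; -8 is -0.5,
# which rounds to 0 under UPWARD and to -1 under TONEAREST.
assert list(requantize_ref(np.array([8, -8]), 1, 16, "UPWARD")) == [1, 0]
assert list(requantize_ref(np.array([8, -8]), 1, 16, "TONEAREST")) == [1, -1]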
diff --git a/tests/python/relay/test_pass_eta_expand.py b/tests/python/relay/test_pass_eta_expand.py
index 73c3a4eb4073..b9eb2a1e692d 100644
--- a/tests/python/relay/test_pass_eta_expand.py
+++ b/tests/python/relay/test_pass_eta_expand.py
@@ -14,27 +14,70 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
+import os
+
+import numpy as np
+
+import tvm
 from tvm import relay
-import tvm.relay.module as _module
 import tvm.relay.transform as _transform
 
-def test_eta_expand_basic():
-    x = relay.var('x', 'int32')
-    orig = relay.Function([x], x)
-    mod = _module.Module.from_expr(orig)
-    seq = _transform.Sequential([_transform.EtaExpand()])
+def test_eta_expand_global_var():
+    mod = relay.fromtext(r"""
+        v0.0.4
+        def @aux(%x: Tensor[(), int32]) -> Tensor[(), int32] {
+            %x
+        }
+        def @main() -> (fn(Tensor[(), int32]) -> Tensor[(), int32]) {
+            @aux
+        }
+    """)
+    seq = _transform.Sequential([_transform.EtaExpand(expand_global_var=True)])
     with _transform.PassContext(opt_level=3):
         mod = seq(mod)
+    expected = relay.fromtext(r"""
+        v0.0.4
+        def @aux(%x: Tensor[(), int32]) -> Tensor[(), int32] {
+            %x
+        }
+        def @main() -> (fn(Tensor[(), int32]) -> Tensor[(), int32]) {
+            fn (%x: Tensor[(), int32]) -> Tensor[(), int32] {
+                @aux(%x)
+            }
+        }
+    """)
+    relay.analysis.assert_graph_equal(mod['main'], expected['main'])
+
 
-    got = mod["main"]
+def test_eta_expand_constructor():
+    mod = relay.fromtext(r"""
+        v0.0.4
+        type List[A] {
+            Cons(A, List[A]),
+            Nil,
+        }
+        def @main[A]() -> (fn(A, List[A]) -> List[A]) {
+            Cons
+        }
+    """)
+    seq = _transform.Sequential([_transform.EtaExpand(expand_constructor=True)])
+    with _transform.PassContext(opt_level=3):
+        mod = seq(mod)
+    expected = relay.fromtext(r"""
+        v0.0.4
+        type List[A] {
+            Cons(A, List[A]),
+            Nil,
+        }
+        def @main[A]() -> (fn(A, List[A]) -> List[A]) {
+            fn [A](%x: A, %xs: List[A]) -> List[A] {
+                Cons(%x, %xs)
+            }
+        }
+    """)
+    relay.analysis.assert_graph_equal(mod['main'], expected['main'])
 
-    y = relay.var('y', 'int32')
-    expected = relay.Function([y], orig(y))
-    gv = relay.GlobalVar("gv")
-    mod[gv] = expected
-    mod = _transform.InferType()(mod)
-    expected = mod["gv"]
-    assert(relay.analysis.alpha_equal(got, expected))
 
-if __name__ == "__main__":
-    test_eta_expand_basic()
+if __name__ == '__main__':
+    test_eta_expand_global_var()
+    test_eta_expand_constructor()
diff --git a/tests/python/relay/test_pass_lambda_lift.py b/tests/python/relay/test_pass_lambda_lift.py
new file mode 100644
index 000000000000..ffcdb5e3ea9c
--- /dev/null
+++ b/tests/python/relay/test_pass_lambda_lift.py
@@ -0,0 +1,40 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import numpy as np
+import pytest
+
+import tvm
+from tvm import relay
+from tvm.relay import transform
+
+def test_basic():
+    mod = relay.Module()
+    x2 = relay.var('x2', shape=(10, 5))
+    y2 = relay.var('y2', shape=(1, 5))
+    level2_func = relay.Function([x2, y2], relay.op.add(x2, y2))
+
+    x1 = relay.var('x1', shape=(10, 5))
+    y1 = relay.var('y1', shape=(1, 5))
+    level1_func = relay.Function([x1, y1], level2_func(x1, y1))
+
+    mod["main"] = level1_func
+    new_mod = transform.LambdaLift()(mod)
+    assert len(new_mod.functions) == 2
+
+if __name__ == "__main__":
+    pytest.main()
+
diff --git a/tests/python/relay/test_pass_qnn_legalize.py b/tests/python/relay/test_pass_qnn_legalize.py
index 55c1fa6d3187..8ace7bc745ea 100644
--- a/tests/python/relay/test_pass_qnn_legalize.py
+++ b/tests/python/relay/test_pass_qnn_legalize.py
@@ -23,6 +23,14 @@
 from tvm.relay.qnn.op import register_qnn_legalize
 from tvm.relay import transform, analysis
 
+def alpha_equal(x, y):
+    """
+    Wrapper around alpha equality which ensures that
+    the hash function respects equality.
+    """
+    x = x['main']
+    y = y['main']
+    return analysis.alpha_equal(x, y) and analysis.structural_hash(x) == analysis.structural_hash(y)
 
 def run_opt_pass(expr, passes):
     passes = passes if isinstance(passes, list) else [passes]
@@ -82,11 +90,11 @@ def expected():
     b = run_opt_pass(expected(), transform.InferType())
     assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
 
+
 def test_qnn_legalize_qnn_conv2d():
-    data_shape = (1, 64, 256, 256)
-    kernel_shape = (128, 64, 3, 3)
-    for dtype in ['uint8', 'int8']:
-        data_dtype =  kernel_dtype = dtype
+    def _get_mod(data_dtype, kernel_dtype):
+        data_shape = (1, 64, 256, 256)
+        kernel_shape = (128, 64, 3, 3)
         data = relay.var("data", shape=data_shape,
                 dtype=data_dtype)
         kernel = relay.var("kernel", shape=kernel_shape,
@@ -104,12 +112,145 @@ def test_qnn_legalize_qnn_conv2d():
 
         mod = relay.Function(relay.analysis.free_vars(func), func)
         mod = relay.Module.from_expr(mod)
+        return mod
+
+    # Check uint8 x uint8 and int8 x int8 transformation
+    for dtype in ('uint8', 'int8'):
+        mod = _get_mod(dtype, dtype)
 
+        #############################################################
+        # Check transformations for platforms with fast Int8 support.
+        #############################################################
+        # Check that Intel VNNI gets picked up.
         with tvm.target.create('llvm -mcpu=skylake-avx512'):
-            mod = relay.qnn.transform.Legalize()(mod)
+            legalized_mod = relay.qnn.transform.Legalize()(mod)
+            assert 'cast' in legalized_mod.astext() and "qnn.conv2d" in legalized_mod.astext()
+
+        # Since the dtypes already match, there should be no transformation
+        with tvm.target.create('llvm -device=arm_cpu -target=aarch64-linux-gnu -mattr=+v8.2a,+dotprod'):
+            legalized_mod = relay.qnn.transform.Legalize()(mod)
+            assert alpha_equal(mod, legalized_mod)
+
+        ################################################################
+        # Check transformations for platforms without fast Int8 support.
+        ################################################################
+        # Older Intel versions.
+        with tvm.target.create('llvm'):
+            legalized_mod = relay.qnn.transform.Legalize()(mod)
+            assert 'cast' in legalized_mod.astext() and "qnn" not in legalized_mod.astext()
+
+        # Older ARM versions.
+        with tvm.target.create('llvm -device=arm_cpu -target=aarch64-linux-gnu'):
+            legalized_mod = relay.qnn.transform.Legalize()(mod)
+            assert 'cast' in legalized_mod.astext() and "qnn" not in legalized_mod.astext()
+
+    # Check uint8 x int8 transformation
+    mod = _get_mod('uint8', 'int8')
+    #############################################################
+    # Check transformations for platforms with fast Int8 support.
+    #############################################################
+    # Check no transformation for Intel VNNI.
+    with tvm.target.create('llvm -mcpu=skylake-avx512'):
+        legalized_mod = relay.qnn.transform.Legalize()(mod)
+        assert alpha_equal(mod, legalized_mod)
+
+    # ARM with dotprod support - check that the transformation has happened.
+    with tvm.target.create('llvm -device=arm_cpu -target=aarch64-linux-gnu -mattr=+v8.2a,+dotprod'):
+        legalized_mod = relay.qnn.transform.Legalize()(mod)
+        assert 'cast' in legalized_mod.astext() and "qnn.conv2d" in legalized_mod.astext()
+
+    ################################################################
+    # Check transformations for platforms without fast Int8 support.
+    ################################################################
+    # Older Intel versions.
+    with tvm.target.create('llvm'):
+        legalized_mod = relay.qnn.transform.Legalize()(mod)
+        assert 'cast' in legalized_mod.astext() and "qnn" not in legalized_mod.astext()
+
+    # Older ARM versions.
+    with tvm.target.create('llvm -device=arm_cpu -target=aarch64-linux-gnu'):
+        legalized_mod = relay.qnn.transform.Legalize()(mod)
+        assert 'cast' in legalized_mod.astext() and "qnn" not in legalized_mod.astext()
+
+
+def test_qnn_legalize_qnn_dense():
+    def _get_mod(data_dtype, kernel_dtype):
+        data_shape = (10, 3)
+        kernel_shape = (20, 3)
+        data = relay.var("data", shape=data_shape,
+                dtype=data_dtype)
+        kernel = relay.var("kernel", shape=kernel_shape,
+                dtype=kernel_dtype)
+        func = relay.qnn.op.dense(
+                data, kernel,
+                input_zero_point=1,
+                kernel_zero_point=1,
+                out_dtype='int32')
+
+        mod = relay.Function(relay.analysis.free_vars(func), func)
+        mod = relay.Module.from_expr(mod)
+        return mod
+
+    # Check uint8 x uint8 and int8 x int8 transformation
+    for dtype in ('uint8', 'int8'):
+        mod = _get_mod(dtype, dtype)
+
+        #############################################################
+        # Check transformations for platforms with fast Int8 support.
+        #############################################################
+        # Check that Intel VNNI gets picked up.
+        with tvm.target.create('llvm -mcpu=skylake-avx512'):
+            legalized_mod = relay.qnn.transform.Legalize()(mod)
+            assert 'cast' in legalized_mod.astext() and "qnn.dense" in legalized_mod.astext()
+
+        # Since the dtypes already match, there should be no transformation
+        with tvm.target.create('llvm -device=arm_cpu -target=aarch64-linux-gnu -mattr=+v8.2a,+dotprod'):
+            legalized_mod = relay.qnn.transform.Legalize()(mod)
+            assert alpha_equal(mod, legalized_mod)
+
+        ################################################################
+        # Check transformations for platforms without fast Int8 support.
+        ################################################################
+        # Older Intel versions.
+        with tvm.target.create('llvm'):
+            legalized_mod = relay.qnn.transform.Legalize()(mod)
+            assert 'cast' in legalized_mod.astext() and "qnn" not in legalized_mod.astext()
+
+        # Older ARM versions.
+        with tvm.target.create('llvm -device=arm_cpu -target=aarch64-linux-gnu'):
+            legalized_mod = relay.qnn.transform.Legalize()(mod)
+            assert 'cast' in legalized_mod.astext() and "qnn" not in legalized_mod.astext()
+
+    # Check uint8 x int8 transformation
+    mod = _get_mod('uint8', 'int8')
+    #############################################################
+    # Check transformations for platforms with fast Int8 support.
+    #############################################################
+    # Check no transformation for Intel VNNI.
+    with tvm.target.create('llvm -mcpu=skylake-avx512'):
+        legalized_mod = relay.qnn.transform.Legalize()(mod)
+        assert alpha_equal(mod, legalized_mod)
+
+    # ARM with dotprod support - check that the transformation has happened.
+    with tvm.target.create('llvm -device=arm_cpu -target=aarch64-linux-gnu -mattr=+v8.2a,+dotprod'):
+        legalized_mod = relay.qnn.transform.Legalize()(mod)
+        assert 'cast' in legalized_mod.astext() and "qnn.dense" in legalized_mod.astext()
+
+    ################################################################
+    # Check transformations for platforms without fast Int8 support.
+    ################################################################
+    # Older Intel versions.
+    with tvm.target.create('llvm'):
+        legalized_mod = relay.qnn.transform.Legalize()(mod)
+        assert 'cast' in legalized_mod.astext() and "qnn" not in legalized_mod.astext()
+
+    # Older ARM versions.
+    with tvm.target.create('llvm -device=arm_cpu -target=aarch64-linux-gnu'):
+        legalized_mod = relay.qnn.transform.Legalize()(mod)
+        assert 'cast' in legalized_mod.astext() and "qnn" not in legalized_mod.astext()
 
-        assert 'cast' in mod.astext()
 
 if __name__ == "__main__":
     test_qnn_legalize()
     test_qnn_legalize_qnn_conv2d()
+    test_qnn_legalize_qnn_dense()
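
The matrix exercised by the two tests above reduces to one question per target: does the platform have fast int8 instructions, and do the data/kernel dtypes already match what those instructions expect? A rough sketch of that decision table, illustrative only (TVM's legalization inspects the target through helper checks rather than string matching):

def is_fast_int8_target(target_str):
    # Skylake-AVX512 (VNNI) and ARMv8.2-A with dotprod are the fast-int8
    # platforms exercised by the tests above.
    return 'skylake-avx512' in target_str or 'dotprod' in target_str

def expected_legalization(target_str, data_dtype, kernel_dtype):
    # Summarizes what the assertions above check for each target/dtype combo.
    if not is_fast_int8_target(target_str):
        # Older x86/ARM: fall back to an int16 upcast; the qnn op disappears.
        return 'cast to int16, qnn op removed'
    if 'skylake-avx512' in target_str:
        # VNNI wants uint8 data with an int8 kernel.
        fine_as_is = (data_dtype, kernel_dtype) == ('uint8', 'int8')
    else:
        # ARM dotprod is fine as long as both dtypes match.
        fine_as_is = data_dtype == kernel_dtype
    return 'no transformation' if fine_as_is else 'insert casts, keep qnn op'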
diff --git a/tests/python/relay/test_pass_remove_unused_functions.py b/tests/python/relay/test_pass_remove_unused_functions.py
new file mode 100644
index 000000000000..c4a0c41bfdd1
--- /dev/null
+++ b/tests/python/relay/test_pass_remove_unused_functions.py
@@ -0,0 +1,77 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import pytest
+
+import tvm
+from tvm import relay
+from tvm.relay import transform
+from tvm.relay.prelude import Prelude
+
+def test_remove_all_prelude_functions():
+    mod = relay.Module()
+    p = Prelude(mod)
+    x = relay.var("x", shape=(1, 16))
+    mod["main"] = relay.Function([x], x)
+    mod = relay.transform.RemoveUnusedFunctions()(mod)
+    l = set([x[0].name_hint for x in mod.functions.items()])
+    assert l == set(['main'])
+
+def test_remove_all_prelude_functions_but_referenced_functions():
+    mod = relay.Module()
+    p = Prelude(mod)
+    x = relay.var("x", shape=(1, 16))
+    id_func = relay.Function([x], x)
+    id_name = relay.GlobalVar('id_func')
+    mod[id_name] = id_func
+
+    mod["main"] = relay.Function([x], id_name(x))
+    mod = relay.transform.RemoveUnusedFunctions()(mod)
+    l = set([x[0].name_hint for x in mod.functions.items()])
+    assert l == set(['id_func', 'main'])
+
+def test_keep_only_referenced_prelude_functions():
+    mod = relay.Module()
+    p = Prelude(mod)
+    l = p.nil()
+    for i in [4, 3, 2, 1, 0]:
+        l = p.cons(relay.const(i), l)
+    body = p.hd(p.tl(p.tl(l)))
+    mod["main"] = relay.Function([], body)
+    mod = relay.transform.RemoveUnusedFunctions()(mod)
+    l = set([x[0].name_hint for x in mod.functions.items()])
+    assert l == set(['tl', 'hd', 'main'])
+
+def test_multiple_entry_functions():
+    mod = relay.Module()
+    p = Prelude(mod)
+    l = p.nil()
+    for i in [4, 3, 2, 1, 0]:
+        l = p.cons(relay.const(i), l)
+    body = p.hd(p.tl(p.tl(l)))
+    mod["main1"] = relay.Function([], body)
+
+    x = relay.var("x", shape=(1, 16))
+    id_func = relay.Function([x], x)
+    id_name = relay.GlobalVar('id_func')
+    mod[id_name] = id_func
+    mod["main2"] = relay.Function([x], id_name(x))
+    mod = relay.transform.RemoveUnusedFunctions(['main1', 'main2'])(mod)
+    l = set([x[0].name_hint for x in mod.functions.items()])
+    assert l == set(['tl', 'hd', 'main2', 'id_func', 'main1'])
+
+if __name__ == '__main__':
+    pytest.main()
diff --git a/tests/scripts/task_python_frontend.sh b/tests/scripts/task_python_frontend.sh
index 78ad8e8ca73a..7a7bcacf6140 100755
--- a/tests/scripts/task_python_frontend.sh
+++ b/tests/scripts/task_python_frontend.sh
@@ -21,6 +21,7 @@ set -u
 
 export PYTHONPATH=nnvm/python:python:topi/python
 # to avoid openblas threading error
+export TVM_BIND_THREADS=0
 export OMP_NUM_THREADS=1
 
 # Rebuild cython
diff --git a/tests/scripts/task_python_integration.sh b/tests/scripts/task_python_integration.sh
index ab0f3a852307..ebbcf2106617 100755
--- a/tests/scripts/task_python_integration.sh
+++ b/tests/scripts/task_python_integration.sh
@@ -21,6 +21,8 @@ set -u
 
 export PYTHONPATH=python:topi/python:apps/extension/python
 export LD_LIBRARY_PATH="build:${LD_LIBRARY_PATH:-}"
+export TVM_BIND_THREADS=0
+export TVM_NUM_THREADS=2
 
 rm -rf python/tvm/*.pyc python/tvm/*/*.pyc python/tvm/*/*/*.pyc
 
diff --git a/topi/python/topi/arm_cpu/depthwise_conv2d.py b/topi/python/topi/arm_cpu/depthwise_conv2d.py
index a5dc1e94a9fb..207fc712c450 100644
--- a/topi/python/topi/arm_cpu/depthwise_conv2d.py
+++ b/topi/python/topi/arm_cpu/depthwise_conv2d.py
@@ -134,7 +134,7 @@ def _callback(op):
                 data = data_pad.op.input_tensors[0]
             _schedule(cfg, s, data, data_pad, kernel, output)
 
-        if op.tag == 'spatial_depthwise_conv_nchw_output':
+        if op.tag == 'spatial_depthwise_conv2d_nchw_output':
             output = op.output(0)
             conv = op.input_tensors[0]
             data_vec = conv.op.input_tensors[0]
@@ -316,7 +316,7 @@ def _decl_spatial_pack(cfg, data, kernel, strides, padding, dilation, out_dtype,
                          conv[n,
                               idxdiv(co, VC), idxdiv(h, VH), idxdiv(w, VW),
                               idxmod(h, VH), idxmod(w, VW), idxmod(co, VC)],
-                         name='output_unpack', tag='spatial_depthwise_conv_nchw_output')
+                         name='output_unpack', tag='spatial_depthwise_conv2d_nchw_output')
     return output
 
 def _schedule_spatial_pack(cfg, s, data_vec, kernel_vec,
diff --git a/topi/python/topi/cpp/__init__.py b/topi/python/topi/cpp/__init__.py
new file mode 100644
index 000000000000..c52a819a274a
--- /dev/null
+++ b/topi/python/topi/cpp/__init__.py
@@ -0,0 +1,9 @@
+"""FFI for C++ TOPI ops and schedules"""
+from .impl import * #pylint: disable=wildcard-import
+from . import cuda
+from . import nn
+from . import vision
+from . import x86
+from . import generic
+from . import rocm
+from . import image
diff --git a/topi/python/topi/cpp/cuda.py b/topi/python/topi/cpp/cuda.py
new file mode 100644
index 000000000000..920b2717437d
--- /dev/null
+++ b/topi/python/topi/cpp/cuda.py
@@ -0,0 +1,21 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""FFI for CUDA TOPI ops and schedules"""
+
+from tvm._ffi.function import _init_api_prefix
+
+_init_api_prefix("topi.cpp.cuda", "topi.cuda")
diff --git a/topi/python/topi/cpp/generic.py b/topi/python/topi/cpp/generic.py
new file mode 100644
index 000000000000..a8a71656c1aa
--- /dev/null
+++ b/topi/python/topi/cpp/generic.py
@@ -0,0 +1,21 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""FFI for generic TOPI ops and schedules"""
+
+from tvm._ffi.function import _init_api_prefix
+
+_init_api_prefix("topi.cpp.generic", "topi.generic")
diff --git a/topi/python/topi/cpp/image.py b/topi/python/topi/cpp/image.py
new file mode 100644
index 000000000000..c6a8f2c9db60
--- /dev/null
+++ b/topi/python/topi/cpp/image.py
@@ -0,0 +1,21 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""FFI for image TOPI ops and schedules"""
+
+from tvm._ffi.function import _init_api_prefix
+
+_init_api_prefix("topi.cpp.image", "topi.image")
diff --git a/topi/python/topi/cpp.py b/topi/python/topi/cpp/impl.py
similarity index 64%
rename from topi/python/topi/cpp.py
rename to topi/python/topi/cpp/impl.py
index a1c1c8e94a84..5eff6040a66e 100644
--- a/topi/python/topi/cpp.py
+++ b/topi/python/topi/cpp/impl.py
@@ -14,11 +14,11 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-"""FFI for C++ TOPI ops and schedules"""
+"""Load Lib for C++ TOPI ops and schedules"""
 import sys
 import os
 import ctypes
-from imp import new_module as _new_module
+
 from tvm._ffi.function import _init_api_prefix
 from tvm._ffi import libinfo
 
@@ -42,27 +42,3 @@ def _load_lib():
 _LIB, _LIB_NAME = _load_lib()
 
 _init_api_prefix("topi.cpp", "topi")
-
-def _create_module(name):
-    fullname = __name__ + "." + name
-    mod = _new_module(fullname)
-    sys.modules[fullname] = mod
-    return mod
-
-# pylint: disable-msg=C0103
-nn = _create_module("nn")
-_init_api_prefix("topi.cpp.nn", "topi.nn")
-generic = _create_module("generic")
-_init_api_prefix("topi.cpp.generic", "topi.generic")
-cuda = _create_module("cuda")
-_init_api_prefix("topi.cpp.cuda", "topi.cuda")
-rocm = _create_module("rocm")
-_init_api_prefix("topi.cpp.rocm", "topi.rocm")
-x86 = _create_module("x86")
-_init_api_prefix("topi.cpp.x86", "topi.x86")
-vision = _create_module("vision")
-_init_api_prefix("topi.cpp.vision", "topi.vision")
-yolo = _create_module("vision.yolo")
-_init_api_prefix("topi.cpp.vision.yolo", "topi.vision.yolo")
-image = _create_module("image")
-_init_api_prefix("topi.cpp.image", "topi.image")
diff --git a/topi/python/topi/cpp/nn.py b/topi/python/topi/cpp/nn.py
new file mode 100644
index 000000000000..59bf1477501d
--- /dev/null
+++ b/topi/python/topi/cpp/nn.py
@@ -0,0 +1,21 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""FFI for NN TOPI ops and schedules"""
+
+from tvm._ffi.function import _init_api_prefix
+
+_init_api_prefix("topi.cpp.nn", "topi.nn")
diff --git a/topi/python/topi/cpp/rocm.py b/topi/python/topi/cpp/rocm.py
new file mode 100644
index 000000000000..d57ce3e3cae1
--- /dev/null
+++ b/topi/python/topi/cpp/rocm.py
@@ -0,0 +1,21 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""FFI for Rocm TOPI ops and schedules"""
+
+from tvm._ffi.function import _init_api_prefix
+
+_init_api_prefix("topi.cpp.rocm", "topi.rocm")
diff --git a/topi/python/topi/cpp/vision/__init__.py b/topi/python/topi/cpp/vision/__init__.py
new file mode 100644
index 000000000000..b965a0314592
--- /dev/null
+++ b/topi/python/topi/cpp/vision/__init__.py
@@ -0,0 +1,7 @@
+"""FFI for vision TOPI ops and schedules"""
+
+from tvm._ffi.function import _init_api_prefix
+
+from . import yolo
+
+_init_api_prefix("topi.cpp.vision", "topi.vision")
diff --git a/topi/python/topi/cpp/vision/yolo.py b/topi/python/topi/cpp/vision/yolo.py
new file mode 100644
index 000000000000..072ab29ff524
--- /dev/null
+++ b/topi/python/topi/cpp/vision/yolo.py
@@ -0,0 +1,21 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""FFI for Yolo TOPI ops and schedules"""
+
+from tvm._ffi.function import _init_api_prefix
+
+_init_api_prefix("topi.cpp.vision.yolo", "topi.vision.yolo")
diff --git a/topi/python/topi/cpp/x86.py b/topi/python/topi/cpp/x86.py
new file mode 100644
index 000000000000..a6db26e336bb
--- /dev/null
+++ b/topi/python/topi/cpp/x86.py
@@ -0,0 +1,21 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""FFI for x86 TOPI ops and schedules"""
+
+from tvm._ffi.function import _init_api_prefix
+
+_init_api_prefix("topi.cpp.x86", "topi.x86")
diff --git a/topi/python/topi/vision/rcnn/proposal.py b/topi/python/topi/vision/rcnn/proposal.py
index 1df25a06566d..507d464e081b 100644
--- a/topi/python/topi/vision/rcnn/proposal.py
+++ b/topi/python/topi/vision/rcnn/proposal.py
@@ -14,11 +14,12 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-# pylint: disable=invalid-name
+# pylint: disable=invalid-name, singleton-comparison
 """Proposal operator"""
 import math
 import tvm
-
+from ...util import get_const_tuple, get_const_int
+from ...sort import argsort
 
 def generate_anchor(ratio, scale, base_size):
     """Generate anchor"""
@@ -60,6 +61,261 @@ def reg_iou(x1, y1, x2, y2, dx1, dy1, dx2, dy2):
     pred_y2 = y2 + dy2
     return pred_x1, pred_y1, pred_x2, pred_y2
 
+def predict_bbox_ir(cls_prob_buf, bbox_pred_buf, im_info_buf, out_buf, scales, ratios,
+                    feature_stride, rpn_min_size, iou_loss):
+    """Predict bounding boxes based on anchors, scores and deltas.
+
+    Parameters
+    ----------
+    cls_prob_buf : tvm.schedule.Buffer
+        4-D with shape [batch, 2 * num_anchors, height, width]
+
+    bbox_pred_buf : tvm.schedule.Buffer
+        4-D with shape [batch, 4 * num_anchors, height, width]
+
+    im_info_buf : tvm.schedule.Buffer
+        2-D with shape [batch, 3]
+
+    out_buf : tvm.schedule.Buffer
+        3-D with shape [batch, num_bbox, 5]
+        The last dimension is in format of [w_start, h_start, w_end, h_end, score]
+
+    scales : list/tuple of float
+        Scales of anchor windows.
+
+    ratios : list/tuple of float
+        Ratios of anchor windows.
+
+    feature_stride : int
+        The size of the receptive field of each unit in the convolution layer of the RPN,
+        for example the product of all strides prior to this layer.
+
+    rpn_min_size : int
+        Minimum height or width in proposal.
+
+    iou_loss : bool
+        Usage of IoU loss.
+
+    Returns
+    -------
+    stmt : Stmt
+        The result IR statement.
+    """
+    batch, num_anchors, height, width = get_const_tuple(cls_prob_buf.shape)
+    num_anchors //= 2
+    ib = tvm.ir_builder.create()
+
+    p_score = ib.buffer_ptr(cls_prob_buf)
+    p_delta = ib.buffer_ptr(bbox_pred_buf)
+    p_im_info = ib.buffer_ptr(im_info_buf)
+    p_out = ib.buffer_ptr(out_buf)
+
+    idxm = tvm.indexmod
+    idxd = tvm.indexdiv
+
+    with ib.for_range(0, batch * height * width) as tid:
+        w = idxm(tid, width)
+        h = idxm(idxd(tid, width), height)
+        b = idxd(idxd(tid, width), height)
+
+        for k in range(num_anchors):
+            out_index = tid * num_anchors + k
+            ratio = ratios[k // len(scales)]
+            scale = scales[k % len(scales)]
+            anchor = generate_anchor(ratio, scale, feature_stride)
+            im_height = p_im_info[b * 3]
+            im_width = p_im_info[b * 3 + 1]
+            x1 = anchor[0] + w * feature_stride
+            y1 = anchor[1] + h * feature_stride
+            x2 = anchor[2] + w * feature_stride
+            y2 = anchor[3] + h * feature_stride
+
+            delta = [p_delta[((((b * num_anchors + k) * 4 + i) * height + h) * width + w)]
+                     for i in range(4)]
+            regression_func = reg_iou if iou_loss else reg_bbox
+            pred_x1, pred_y1, pred_x2, pred_y2 = regression_func(x1, y1, x2, y2, *delta)
+
+            pred_x1 = tvm.max(tvm.min(pred_x1, im_width - 1.0), 0.0)
+            pred_y1 = tvm.max(tvm.min(pred_y1, im_height - 1.0), 0.0)
+            pred_x2 = tvm.max(tvm.min(pred_x2, im_width - 1.0), 0.0)
+            pred_y2 = tvm.max(tvm.min(pred_y2, im_height - 1.0), 0.0)
+
+            real_height = (im_height / feature_stride).astype('int32')
+            real_width = (im_width / feature_stride).astype('int32')
+
+            bbox_w = pred_x2 - pred_x1 + 1.0
+            bbox_h = pred_y2 - pred_y1 + 1.0
+            min_size = p_im_info[b * 3 + 2] * rpn_min_size
+
+            pred_score = p_score[((b * num_anchors * 2 + num_anchors + k) * height + h) * width + w]
+            pred_score = tvm.expr.Select(tvm.any(h >= real_height, w >= real_width),
+                                         -1.0, pred_score)
+            p_out[out_index * 5 + 0] = pred_x1
+            p_out[out_index * 5 + 1] = pred_y1
+            p_out[out_index * 5 + 2] = pred_x2
+            p_out[out_index * 5 + 3] = pred_y2
+            p_out[out_index * 5 + 4] = pred_score
+
+            with ib.if_scope(tvm.any(bbox_w < min_size, bbox_h < min_size)):
+                p_out[out_index * 5 + 0] -= min_size / 2.0
+                p_out[out_index * 5 + 1] -= min_size / 2.0
+                p_out[out_index * 5 + 2] += min_size / 2.0
+                p_out[out_index * 5 + 3] += min_size / 2.0
+                p_out[out_index * 5 + 4] = -1.0
+
+    return ib.get()
+
+
+def argsort_ir(data_buf, out_index_buf):
+    """Batched odd-even transposition sort.
+
+    Parameters
+    ----------
+    data_buf : tvm.schedule.Buffer
+        2-D with shape [batch, num_bbox]
+
+    out_index_buf : tvm.schedule.Buffer
+        2-D with shape [batch, num_bbox]. Indices of data in sorted order.
+
+    Returns
+    -------
+    stmt : Stmt
+        The result IR statement.
+    """
+    batch, num_bbox = get_const_tuple(data_buf.shape)
+    ib = tvm.ir_builder.create()
+    p_data = ib.buffer_ptr(data_buf)
+    index_out = ib.buffer_ptr(out_index_buf)
+    temp_data = ib.allocate("float32", (1,), name="temp_data", scope="local")
+    temp_index = ib.allocate("int32", (1,), name="temp_index", scope="local")
+    idxm = tvm.indexmod
+    with ib.for_range(0, batch, for_type="unroll") as b:
+        start = b * num_bbox
+        for i in range(2):
+            with ib.for_range(0, (num_bbox + 1) // 2) as tid:
+                bbox_id = tid * 2 + i
+                with ib.if_scope(bbox_id < num_bbox):
+                    index_out[start + bbox_id] = bbox_id
+        with ib.for_range(0, num_bbox) as k:
+            with ib.for_range(0, (num_bbox + 1) // 2) as tid:
+                offset = start + 2 * tid + idxm(k, 2)
+                with ib.if_scope(tvm.all(offset + 1 < num_bbox,
+                                         p_data[offset] < p_data[offset + 1])):
+                    temp_data[0] = p_data[offset]
+                    p_data[offset] = p_data[offset + 1]
+                    p_data[offset + 1] = temp_data[0]
+                    temp_index[0] = index_out[offset]
+                    index_out[offset] = index_out[offset + 1]
+                    index_out[offset + 1] = temp_index[0]
+    return ib.get()
+
+
+def nms_ir(sorted_bbox_buf, out_buf, nms_threshold):
+    """Non-maximum supression.
+
+    Parameters
+    ----------
+    sorted_bbox_buf : tvm.schedule.Buffer
+        3-D with shape [batch, num_bbox, 5]. The last dimension is in format of
+        [w_start, h_start, w_end, h_end, score].
+
+    out_buf : tvm.schedule.Buffer
+        2-D with shape [batch, num_bbox]. Boolean mask of whether a bounding box should be removed.
+
+    nms_threshold : float
+        Non-maximum suppression threshold.
+
+    Returns
+    -------
+    stmt : Stmt
+        The result IR statement.
+    """
+    def calculate_overlap(out_tensor, box_a_idx, box_b_idx):
+        """Calculate overlap of two boxes.
+        """
+        w = tvm.max(0.0, tvm.min(out_tensor[box_a_idx + 2], out_tensor[box_b_idx + 2])
+                    - tvm.max(out_tensor[box_a_idx], out_tensor[box_b_idx]) + 1.0)
+        h = tvm.max(0.0, tvm.min(out_tensor[box_a_idx + 3], out_tensor[box_b_idx + 3])
+                    - tvm.max(out_tensor[box_a_idx + 1], out_tensor[box_b_idx + 1]) + 1.0)
+        i = w * h
+        u = (out_tensor[box_a_idx + 2] - out_tensor[box_a_idx] + 1.0) * \
+            (out_tensor[box_a_idx + 3] - out_tensor[box_a_idx + 1] + 1.0) + \
+            (out_tensor[box_b_idx + 2] - out_tensor[box_b_idx] + 1.0) * \
+            (out_tensor[box_b_idx + 3] - out_tensor[box_b_idx + 1] + 1.0) - i
+        return i / u
+
+    batch, num_bbox = get_const_tuple(out_buf.shape)
+    ib = tvm.ir_builder.create()
+    p_data = ib.buffer_ptr(sorted_bbox_buf)
+    p_out = ib.buffer_ptr(out_buf)
+    with ib.for_range(0, batch, for_type="unroll", name="n") as b:
+        base_idx = b * num_bbox
+        for i in range(num_bbox):
+            p_out[base_idx + i] = False
+        with ib.for_range(0, num_bbox - 1) as l:
+            with ib.for_range(0, num_bbox) as i:
+                with ib.if_scope(tvm.all(i < num_bbox, i > l, p_out[base_idx + l] == False)):
+                    iou = calculate_overlap(p_data, (base_idx + l) * 5, (base_idx + i) * 5)
+                    with ib.if_scope(iou > nms_threshold):
+                        p_out[base_idx + i] = True
+    return ib.get()
+
+
+def prepare_output_ir(sorted_bbox_buf, remove_mask_buf, out_buf):
+    """Copy output after applying nms to continuous memory.
+
+    Parameters
+    ----------
+    sorted_bbox_buf : tvm.schedule.Buffer
+        3-D with shape [batch, num_bbox, 5]. The last dimension is in format of
+        [w_start, h_start, w_end, h_end, score].
+
+    remove_mask_buf : tvm.schedule.Buffer
+        2-D with shape [batch, num_bbox]. Boolean mask of whether a bounding box should be removed.
+
+    out_buf : tvm.schedule.Buffer
+        2-D with shape [batch * rpn_post_nms_top_n, 5]. The last dimension is in format of
+        [batch_index, w_start, h_start, w_end, h_end].
+
+    Returns
+    -------
+    stmt : Stmt
+        The result IR statement.
+    """
+    batch, num_bbox, _ = get_const_tuple(sorted_bbox_buf.shape)
+    rpn_post_nms_top_n = get_const_int(out_buf.shape[0]) // batch
+    ib = tvm.ir_builder.create()
+    i = ib.allocate('int32', (batch,), 'i', scope='local')
+    p_sorted_bbox = ib.buffer_ptr(sorted_bbox_buf)
+    p_remove = ib.buffer_ptr(remove_mask_buf)
+    p_out = ib.buffer_ptr(out_buf)
+
+    nkeep = ib.allocate('int32', (batch,), 'nkeep', scope='local')
+
+    with ib.for_range(0, batch) as b:
+        nkeep[b] = 0
+        i[b] = 0
+
+    with ib.for_range(0, num_bbox) as j:
+        with ib.for_range(0, batch) as b:
+            with ib.if_scope(p_remove[b * num_bbox + j] == False):
+                nkeep[b] += 1
+    with ib.for_range(0, batch) as b:
+        with ib.if_scope(nkeep[b] > 0):
+            with ib.for_range(0, tvm.ceil(
+                tvm.const(rpn_post_nms_top_n, 'float32') / nkeep[b]).astype('int32')):
+                with ib.for_range(0, num_bbox) as j:
+                    offset_j = (b * num_bbox + j) * 5
+                    offset_i = (b * rpn_post_nms_top_n + i[b]) * 5
+                    with ib.if_scope(tvm.all(i[b] < rpn_post_nms_top_n,
+                                             p_remove[(b*num_bbox+j)] == False)):
+                        p_out[offset_i] = tvm.expr.Cast('float32', b)
+                        with ib.for_range(0, 4, for_type='unroll') as k:
+                            p_out[offset_i + k + 1] = p_sorted_bbox[offset_j + k]
+                        i[b] = i[b] + 1
+
+    body = ib.get()
+    return body
 
 @tvm.target.generic_func
 def proposal(cls_prob, bbox_pred, im_info, scales, ratios, feature_stride, threshold,
@@ -109,4 +365,25 @@ def proposal(cls_prob, bbox_pred, im_info, scales, ratios, feature_stride, thres
         [batch_index, w_start, h_start, w_end, h_end].
     """
     # pylint: disable=unused-argument
-    raise ValueError("missing register for topi.vision.rcnn.proposal")
+    batch, _, height, width = get_const_tuple(cls_prob.shape)
+    num_anchors = len(scales) * len(ratios)
+    num_bbox = height * width * num_anchors
+    rpn_pre_nms_top_n = min(rpn_pre_nms_top_n, num_bbox) if rpn_pre_nms_top_n > 0 else num_bbox
+
+    bbox = tvm.extern((batch, num_bbox, 5), [cls_prob, bbox_pred, im_info], lambda ins, outs:
+                      predict_bbox_ir(ins[0], ins[1], ins[2], outs[0], scales, ratios,
+                                      feature_stride, rpn_min_size, iou_loss),
+                      dtype=bbox_pred.dtype)
+    score = tvm.compute((batch, num_bbox), lambda b, i: bbox[b, i, 4], tag='bbox_score')
+    valid_count_shape = (1,)
+    valid_count = tvm.compute(valid_count_shape, lambda i: num_bbox)
+    sorted_index = argsort(score, valid_count=valid_count, axis=1, is_ascend=False)
+    sorted_bbox = tvm.compute((batch, rpn_pre_nms_top_n, 5),
+                              lambda b, i, j: bbox[b, sorted_index[b, i], j], tag='sorted_bbox')
+    nms_remove_mask = tvm.extern((batch, rpn_pre_nms_top_n), [sorted_bbox],
+                                 lambda ins, outs: nms_ir(ins[0], outs[0], threshold),
+                                 dtype='bool')
+    nms_out = tvm.extern((batch * rpn_post_nms_top_n, 5), [sorted_bbox, nms_remove_mask],
+                         lambda ins, outs: prepare_output_ir(ins[0], ins[1], outs[0]),
+                         dtype=sorted_bbox.dtype)
+    return nms_out
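
Read in plain Python, the remove mask produced by nms_ir above is a greedy NMS over the score-sorted boxes. The sketch below is illustrative only (not part of the patch) and uses the same +1.0 pixel convention as calculate_overlap:

import numpy as np

def iou(box_a, box_b):
    # Boxes are [w_start, h_start, w_end, h_end]; widths/heights include +1.0
    # to match calculate_overlap above.
    w = max(0.0, min(box_a[2], box_b[2]) - max(box_a[0], box_b[0]) + 1.0)
    h = max(0.0, min(box_a[3], box_b[3]) - max(box_a[1], box_b[1]) + 1.0)
    inter = w * h
    area_a = (box_a[2] - box_a[0] + 1.0) * (box_a[3] - box_a[1] + 1.0)
    area_b = (box_b[2] - box_b[0] + 1.0) * (box_b[3] - box_b[1] + 1.0)
    return inter / (area_a + area_b - inter)

def nms_remove_mask(sorted_boxes, threshold):
    # True marks a suppressed box, matching the mask written by nms_ir.
    removed = np.zeros(len(sorted_boxes), dtype=bool)
    for keep in range(len(sorted_boxes) - 1):
        if removed[keep]:
            continue
        for cand in range(keep + 1, len(sorted_boxes)):
            if not removed[cand] and iou(sorted_boxes[keep], sorted_boxes[cand]) > threshold:
                removed[cand] = True
    return removed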
diff --git a/topi/python/topi/x86/dense.py b/topi/python/topi/x86/dense.py
index 2a739d5c5d8f..605a1754c846 100644
--- a/topi/python/topi/x86/dense.py
+++ b/topi/python/topi/x86/dense.py
@@ -32,7 +32,7 @@ def _declaration_dense(cfg, data, weight, bias=None, out_dtype=None):
     if "cblas" in target.libs:
         C = cblas.matmul(data, weight, False, True)
         if bias is not None:
-            C = tvm.compute(C.shape, lambda i, j: C[i, j] + bias[j].astype(out_dtype),
+            C = tvm.compute(C.shape, lambda i, j: C[i, j] + bias[j],
                             tag=tag.BROADCAST)
         return C
 
diff --git a/topi/tests/python/test_topi_vision.py b/topi/tests/python/test_topi_vision.py
index 08b6d2e7d414..a081f0797dad 100644
--- a/topi/tests/python/test_topi_vision.py
+++ b/topi/tests/python/test_topi_vision.py
@@ -378,7 +378,7 @@ def check_device(device):
             f(tvm_cls_prob, tvm_bbox_pred, tvm_im_info, tvm_out)
             tvm.testing.assert_allclose(tvm_out.asnumpy(), np_out, rtol=1e-4)
 
-    for device in ['cuda']:
+    for device in ['llvm', 'cuda']:
         check_device(device)
 
 
diff --git a/tutorials/dev/relay_pass_infra.py b/tutorials/dev/relay_pass_infra.py
index 2a2d1f50eb88..87a3bf1c3ca7 100644
--- a/tutorials/dev/relay_pass_infra.py
+++ b/tutorials/dev/relay_pass_infra.py
@@ -16,7 +16,7 @@
 # under the License.
 # pylint: disable=line-too-long
 """
-.. _tutorial-relay-pass-infra
+.. _tutorial-relay-pass-infra:
 
 How to Use Relay Pass Infra
 ===========================
diff --git a/vta/hardware/chisel/src/main/scala/core/TensorLoad.scala b/vta/hardware/chisel/src/main/scala/core/TensorLoad.scala
index ca6803c15aba..d184cd2c286a 100644
--- a/vta/hardware/chisel/src/main/scala/core/TensorLoad.scala
+++ b/vta/hardware/chisel/src/main/scala/core/TensorLoad.scala
@@ -103,20 +103,21 @@ class TensorLoad(tensorType: String = "none", debug: Boolean = false)(
           when(dec.xpad_1 =/= 0.U) {
             state := sXPad1
           }.elsewhen(dec.ypad_1 =/= 0.U) {
-              state := sYPad1
-            }
-            .otherwise {
-              state := sIdle
-            }
-        }.elsewhen(dataCtrl.io.stride || dataCtrl.io.split) {
+            state := sYPad1
+          }
+          .otherwise {
+            state := sIdle
+          }
+        }.elsewhen(dataCtrl.io.stride) {
           when(dec.xpad_1 =/= 0.U) {
             state := sXPad1
           }.elsewhen(dec.xpad_0 =/= 0.U) {
-              state := sXPad0
-            }
-            .otherwise {
-              state := sReadCmd
-            }
+            state := sXPad0
+          }.otherwise {
+            state := sReadCmd
+          }
+        }.elsewhen(dataCtrl.io.split) {
+          state := sReadCmd
         }
       }
     }
@@ -168,13 +169,11 @@ class TensorLoad(tensorType: String = "none", debug: Boolean = false)(
   xPadCtrl0.io.start := dec.xpad_0 =/= 0.U &
     ((state === sIdle & io.start) |
       (state === sYPad0 & yPadCtrl0.io.done) |
-      (io.vme_rd.data
-        .fire() & ~dataCtrlDone & (dataCtrl.io.stride | dataCtrl.io.split) & dec.xpad_1 === 0.U) |
+      (io.vme_rd.data.fire() & ~dataCtrlDone & dataCtrl.io.stride & dec.xpad_1 === 0.U) |
       (state === sXPad1 & xPadCtrl1.io.done & ~dataCtrlDone))
 
   xPadCtrl1.io.start := dec.xpad_1 =/= 0.U & io.vme_rd.data.fire() &
-    ((dataCtrl.io.done) |
-      (~dataCtrl.io.done & (dataCtrl.io.stride | dataCtrl.io.split) & dec.xpad_1 =/= 0.U))
+    ((dataCtrl.io.done) | (~dataCtrl.io.done & dataCtrl.io.stride & dec.xpad_1 =/= 0.U))
 
   yPadCtrl0.io.inst := io.inst
   yPadCtrl1.io.inst := io.inst
diff --git a/vta/src/runtime.cc b/vta/src/runtime.cc
index 79fc0e293a85..b657d63a95f5 100644
--- a/vta/src/runtime.cc
+++ b/vta/src/runtime.cc
@@ -1016,7 +1016,7 @@ class CommandQueue {
           elem_bytes = VTA_ACC_ELEM_BYTES;
           break;
       case VTA_MEM_ID_OUT:
-          elem_bytes = VTA_INP_ELEM_BYTES;
+          elem_bytes = VTA_OUT_ELEM_BYTES;
           break;
       default:
           LOG(FATAL) << "Memory id not recognized:" << memory_id;
diff --git a/vta/tests/python/unittest/test_vta_insn.py b/vta/tests/python/unittest/test_vta_insn.py
index 574273f274f4..ef3c45ce58d6 100644
--- a/vta/tests/python/unittest/test_vta_insn.py
+++ b/vta/tests/python/unittest/test_vta_insn.py
@@ -24,6 +24,7 @@
 import vta.testing
 from vta.testing import simulator
 
+np.random.seed(0xdeadb)
 
 def test_save_load_out():
     """Test save/store output command"""
@@ -88,68 +89,73 @@ def _run(env, remote):
 def test_padded_load():
     """Test padded load."""
     def _run(env, remote):
-        # declare
-        n = 3
-        m = 5
-        pad_before = [2, 1, 0, 0]
-        pad_after = [1, 2, 0, 0]
-        x = tvm.placeholder(
-            (n, m, env.BATCH, env.BLOCK_OUT),
-            name="x",
-            dtype=env.acc_dtype)
-        x_buf = topi.nn.pad(x, pad_before, pad_after, name="y")
-        # insert no-op that won't be optimized away
-        y_buf = tvm.compute((n + pad_before[0] + pad_after[0],
+        def check_padded_load(pad_before, pad_after, test_name=None):
+            # declare
+            n = 3
+            m = 5
+            x = tvm.placeholder(
+                (n, m, env.BATCH, env.BLOCK_OUT),
+                name="x",
+                dtype=env.acc_dtype)
+            x_buf = topi.nn.pad(x, pad_before, pad_after, name="y")
+            # insert no-op that won't be optimized away
+            y_buf = tvm.compute((n + pad_before[0] + pad_after[0],
+                                 m + pad_before[1] + pad_after[1],
+                                 env.BATCH,
+                                 env.BLOCK_OUT), lambda *i: x_buf(*i)>>0, "y_buf")
+            y = tvm.compute((n + pad_before[0] + pad_after[0],
                              m + pad_before[1] + pad_after[1],
                              env.BATCH,
-                             env.BLOCK_OUT), lambda *i: x_buf(*i)>>0, "y_buf")
-        y = tvm.compute((n + pad_before[0] + pad_after[0],
-                         m + pad_before[1] + pad_after[1],
-                         env.BATCH,
-                         env.BLOCK_OUT), lambda *i: y_buf(*i).astype(env.inp_dtype), "y")
-        # schedule
-        s = tvm.create_schedule(y.op)
-        s[x_buf].set_scope(env.acc_scope)
-        s[x_buf].pragma(x_buf.op.axis[0], env.dma_copy)
-        s[y_buf].set_scope(env.acc_scope)
-        s[y_buf].pragma(y_buf.op.axis[0], env.alu)
-        s[y].pragma(y.op.axis[0], env.dma_copy)
-        # build
-        with vta.build_config():
-            mod = vta.build(s, [x, y], "ext_dev", env.target_host)
+                             env.BLOCK_OUT), lambda *i: y_buf(*i).astype(env.inp_dtype), "y")
+            # schedule
+            s = tvm.create_schedule(y.op)
+            s[x_buf].set_scope(env.acc_scope)
+            s[x_buf].pragma(x_buf.op.axis[0], env.dma_copy)
+            s[y_buf].set_scope(env.acc_scope)
+            s[y_buf].pragma(y_buf.op.axis[0], env.alu)
+            s[y].pragma(y.op.axis[0], env.dma_copy)
+            # build
+            with vta.build_config():
+                mod = vta.build(s, [x, y], "ext_dev", env.target_host)
 
-        if not remote:
-            return
-        temp = util.tempdir()
-        mod.save(temp.relpath("padded_load.o"))
-        remote.upload(temp.relpath("padded_load.o"))
-        f = remote.load_module("padded_load.o")
-        # verify
-        ctx = remote.ext_dev(0)
-        x_np = np.random.randint(-10, 10, size=(
-            n, m, env.BATCH, env.BLOCK_OUT)).astype(x.dtype)
-        y_np = np.zeros((n + pad_before[0] + pad_after[0],
-                         m + pad_before[1] + pad_after[1],
-                         env.BATCH,
-                         env.BLOCK_OUT)).astype(y.dtype)
-        y_np[pad_before[0]:pad_before[0] + n,
-             pad_before[1]:pad_before[1] + m,
-             :] = x_np
-        x_nd = tvm.nd.array(x_np, ctx)
-        y_nd = tvm.nd.empty(y_np.shape, ctx=ctx, dtype=y_np.dtype)
+            if not remote:
+                return
+            temp = util.tempdir()
+            mod.save(temp.relpath("padded_load.o"))
+            remote.upload(temp.relpath("padded_load.o"))
+            f = remote.load_module("padded_load.o")
+            # verify
+            ctx = remote.ext_dev(0)
+            x_np = np.random.randint(0, 10, size=(
+                n, m, env.BATCH, env.BLOCK_OUT)).astype(x.dtype)
+            y_np = np.zeros((n + pad_before[0] + pad_after[0],
+                             m + pad_before[1] + pad_after[1],
+                             env.BATCH,
+                             env.BLOCK_OUT)).astype(y.dtype)
+            y_np[pad_before[0]:pad_before[0] + n,
+                 pad_before[1]:pad_before[1] + m,
+                 :] = x_np
+            x_nd = tvm.nd.array(x_np, ctx)
+            y_nd = tvm.nd.empty(y_np.shape, ctx=ctx, dtype=y_np.dtype)
 
-        if env.TARGET in ["sim", "tsim"]:
-            simulator.clear_stats()
+            if env.TARGET in ["sim", "tsim"]:
+                simulator.clear_stats()
 
-        f(x_nd, y_nd)
+            f(x_nd, y_nd)
 
-        np.testing.assert_equal(y_np, y_nd.asnumpy())
+            np.testing.assert_equal(y_np, y_nd.asnumpy())
 
-        if env.TARGET in ["sim", "tsim"]:
-            sim_stats = simulator.stats()
-            print("Padded load execution statistics:")
-            for k, v in sim_stats.items():
-                print("\t{:<16}: {:>16}".format(k, v))
+            if env.TARGET in ["sim", "tsim"]:
+                sim_stats = simulator.stats()
+                print("Padded {} load execution statistics:".format(test_name))
+                for k, v in sim_stats.items():
+                    print("\t{:<16}: {:>16}".format(k, v))
+
+        check_padded_load([2, 0, 0, 0], [0, 0, 0, 0], test_name="Y0")
+        check_padded_load([0, 2, 0, 0], [0, 0, 0, 0], test_name="Y1")
+        check_padded_load([0, 0, 0, 0], [2, 0, 0, 0], test_name="X0")
+        check_padded_load([0, 0, 0, 0], [0, 2, 0, 0], test_name="X1")
+        check_padded_load([1, 1, 0, 0], [1, 1, 0, 0], test_name="all")
 
     vta.testing.run(_run)
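
# A minimal standalone NumPy sketch of the reference result that check_padded_load
# compares the device output against, shown here for context. The helper name
# reference_padded_load and the BATCH=1 / BLOCK_OUT=16 shapes are illustrative
# assumptions, not values taken from this patch.
import numpy as np

def reference_padded_load(x_np, pad_before, pad_after):
    # Zero-pad the first two axes of x_np by pad_before/pad_after, mirroring the
    # y_np construction in the test above.
    n, m = x_np.shape[0], x_np.shape[1]
    y_np = np.zeros((n + pad_before[0] + pad_after[0],
                     m + pad_before[1] + pad_after[1]) + x_np.shape[2:],
                    dtype=x_np.dtype)
    y_np[pad_before[0]:pad_before[0] + n,
         pad_before[1]:pad_before[1] + m] = x_np
    return y_np

# Example: the "all" case from the test (pad of 1 on both ends of both axes).
x = np.random.randint(0, 10, size=(3, 5, 1, 16)).astype("int32")
y = reference_padded_load(x, [1, 1, 0, 0], [1, 1, 0, 0])
assert y.shape == (5, 7, 1, 16)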