Skip to content

Commit

Permalink
[microNPU] Some housekeeping in the test_ethosu folder (#10824)
Browse files Browse the repository at this point in the history
* [microNPU] Some housekeeping in the test_ethosu folder

* Move the utility functions from test_codegen.py into infra.py for
  wider accessibility
* Remove some unused code
* Make the conv2d codegen tests more general

* Update test_identity_optimizer.py

* Update test_lut_optimizer.py
ekalda authored Mar 31, 2022
1 parent 5629f8a commit 8226bd0
Showing 5 changed files with 199 additions and 359 deletions.
141 changes: 116 additions & 25 deletions tests/python/contrib/test_ethosu/infra.py
Original file line number Diff line number Diff line change
@@ -28,7 +28,8 @@

import os
import struct
import numpy
import numpy as np
import tflite.Model
import math
from enum import IntEnum
import tensorflow as tf
@@ -41,7 +42,11 @@
from tvm import relay
import tvm.relay.backend.contrib.ethosu.op as ethosu_ops
from tvm.topi.nn.utils import get_pad_tuple
from tvm.relay.expr_functor import ExprMutator
from tvm.relay.op.annotation import compiler_begin, compiler_end
from tvm.relay.backend.contrib.ethosu import preprocess

from tvm.relay.op.contrib.ethosu import partition_for_ethosu
from tests.python.relay.aot.aot_test_utils import (
AOTCompiledTestModel,
AOTDataLinkage,
@@ -180,13 +185,13 @@ def __init__(self, random_state):
self._random_state = random_state

def generate(self, size, dtype):
if dtype == numpy.float32:
if dtype == np.float32:
print("random float32")
return self._random_state.uniform(-1, 1, size).astype(dtype)
else:
print("random (u)int min=%d max=%d", numpy.iinfo(dtype).min, numpy.iinfo(dtype).max)
low = numpy.iinfo(dtype).min
high = numpy.iinfo(dtype).max + 1
print("random (u)int min=%d max=%d", np.iinfo(dtype).min, np.iinfo(dtype).max)
low = np.iinfo(dtype).min
high = np.iinfo(dtype).max + 1
return self._random_state.randint(low, high, size, dtype)


@@ -213,7 +218,7 @@ def generate_ref_data_tflite(model):

# Initialize random generators with a fixed seed to get deterministic results
seed = 0
random_state = numpy.random.RandomState(seed)
random_state = np.random.RandomState(seed)

inputgen = InputGenerator(random_state)

@@ -237,39 +242,125 @@ def generate_ref_data_tflite(model):
return input_data, expected_output_data


def make_partitioned_function(relay_op):
def get_tflite_graph(tf_func, shapes, ranges=None):
tensor_specs = [tf.TensorSpec(shape, dtype=tf.float32) for shape in shapes]
if not ranges:
ranges = [(0, 1) for _ in shapes]
concrete_func = tf_func.get_concrete_function(*tensor_specs)

ifm0 = relay.analysis.free_vars(relay_op)
ifm_shape = ifm0[0].type_annotation.shape
ifm_dtype = ifm0[0].type_annotation.dtype
# Convert the model
def representative_dataset():
for _ in range(100):
inputs = []
for i, shape in enumerate(shapes):
data = np.random.uniform(
low=ranges[i][0], high=ranges[i][1], size=tuple(shape)
).astype("float32")
inputs.append(data)

ifm = relay.var("ifm", shape=ifm_shape, dtype=ifm_dtype)
yield inputs

glb_ethosu = relay.GlobalVar("tvmgen_default_ethosu_main_0")
converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func])
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.representative_dataset = representative_dataset
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
converter.inference_input_type = tf.int8
converter.inference_output_type = tf.int8
tflite_graph = converter.convert()

func = (
relay.Function(ifm0, relay_op)
.with_attr("Inline", 1)
.with_attr("Compiler", "ethos-u")
.with_attr("global_symbol", "tvmgen_default_ethosu_main_0")
.with_attr("Primitive", 1)
tflite_model = tflite.Model.Model.GetRootAsModel(tflite_graph, 0)

relay_module, params = relay.frontend.from_tflite(tflite_model)
mod = partition_for_ethosu(relay_module, params)
return mod, tflite_graph


def compare_ethosu_with_reference(
mod, input_data, output_data, accel_type, output_tolerance=0, print_cmm=False
):
compiled_models = build_source(
mod,
input_data,
output_data,
accel_type,
output_tolerance=output_tolerance,
)
mod = tvm.IRModule()
mod[glb_ethosu] = func
mod = relay.transform.InferType()(mod)

call = relay.Call(glb_ethosu, [ifm])
mod["main"] = relay.Function([ifm], call)
mod = relay.transform.InferType()(mod)
# Assumes only two runtime.Modules are created -- i.e. single offload module
ethosu_module = compiled_models[0].executor_factory.lib.imported_modules[0].imported_modules[0]

# Verify generated C source
if print_cmm:
get_artifacts = tvm._ffi.get_global_func("runtime.module.ethos-u.get_artifacts")
compilation_artifacts = get_artifacts(ethosu_module)
cmms = bytes.fromhex(compilation_artifacts[0].command_stream)
print_payload(cmms)

verify_source(compiled_models, accel_type)


def compare_tvm_with_tflite(
tf_func, shapes, accel_type, ranges=None, output_tolerance=0, print_cmm=False
):
mod, tflite_graph = get_tflite_graph(tf_func, shapes, ranges)

# Generate reference data
input_data, output_data = generate_ref_data_tflite(tflite_graph)

compare_ethosu_with_reference(
mod,
input_data,
output_data,
accel_type,
output_tolerance=output_tolerance,
print_cmm=print_cmm,
)


class EthosUAnnotator(ExprMutator):
"""Annotate entire graph for Ethos-U offload"""

def __init__(self):
super(EthosUAnnotator, self).__init__()
self.compiler = "ethos-u"
self.last_call = True

def visit_call(self, call):
curr_last = self.last_call
self.last_call = False

params = []
for arg in call.args:
param = super().visit(arg)
if isinstance(param, relay.expr.Var):
param = compiler_begin(param, self.compiler)
params.append(param)

new_call = relay.Call(call.op, params, call.attrs)
if curr_last:
new_call = compiler_end(new_call, self.compiler)
return new_call

def visit_constant(self, constant):
new_constant = compiler_begin(constant, self.compiler)
return new_constant


def create_ethosu_partition(mod):
mod["main"] = EthosUAnnotator().visit(mod["main"])
mod = relay.transform.MergeCompilerRegions()(mod)
mod = relay.transform.InferType()(mod)
mod = relay.transform.PartitionGraph()(mod)
mod = relay.transform.InferType()(mod)
mod = preprocess.preprocess_ext_io()(mod)
return mod


def generate_weights_data(shape, dtype):
size = 1
for dim in shape:
size *= dim
return (numpy.arange(size) % 255).reshape(shape).astype(dtype)
return (np.arange(size) % 255).reshape(shape).astype(dtype)


def get_convolutional_args(call, include_buffers=False, remove_constants=False):
Loading

0 comments on commit 8226bd0

Please sign in to comment.