Skip to content

Commit

Permalink
[microNPU] Add support for conv2d running on two cores on U65
Browse files Browse the repository at this point in the history
The 512 mac variant has two cores that process the weights in
parallel, so we need to split the weights and biases into two
and encode them separately.

Change-Id: I53791f614288ac4df181b9462fc632d35b934a86
  • Loading branch information
ekalda committed Feb 16, 2022
1 parent 64e94ab commit 5f302f3
Show file tree
Hide file tree
Showing 14 changed files with 293 additions and 142 deletions.
4 changes: 2 additions & 2 deletions python/tvm/relay/backend/contrib/ethosu/legalize.py
Original file line number Diff line number Diff line change
Expand Up @@ -1151,7 +1151,7 @@ def callback(

if axis == [1, 2] and params.keepdims:
weight_scale = 1
weight_values = np.ones([out_channels, filter_height, filter_width, in_channels])
weight_values = np.ones([out_channels, filter_height, filter_width, 1])
scale_bias = vela_api.pack_biases(
biases=np.zeros(ifm_shape[-1]),
ifm_scale=params.ifm.q_params.scale_f32,
Expand Down Expand Up @@ -1216,7 +1216,7 @@ def callback(
)
else:
weight_scale = 1 / (filter_height * filter_width)
weight_values = np.ones([out_channels, filter_height, filter_width, in_channels])
weight_values = np.ones([out_channels, filter_height, filter_width, 1])
bias = -1 * int(params.ifm.q_params.zero_point) * filter_height * filter_width

scale_bias = vela_api.pack_biases(
Expand Down
71 changes: 62 additions & 9 deletions python/tvm/relay/backend/contrib/ethosu/tir/convolution.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,10 @@
# under the License.
# pylint: disable=invalid-name, unused-argument
"""Extract parameters from the convolution operators in TIR."""
import math
import tvm
from ..vela_api import SCALE_BIAS_LENGTH
from ethosu.vela import api as vapi
from ..vela_api import SCALE_BIAS_LENGTH, get_accelerator_config
from .utils import get_outer_loops, get_op_attrs, get_base_address, get_loads, get_stores
from .dma import get_ifm_params, get_ofm_params
from .spec import SerialKernel, SerialAddressRange, SerialActivation, Serial2DConvolution
Expand Down Expand Up @@ -50,6 +52,8 @@ def get_conv2d_params(stmt, producers, consumers):
Whether this operator allocates its output.
"""
accel_config = get_accelerator_config()

attrs, body = get_op_attrs(stmt)
_, _, _, _, _, inner = get_outer_loops(body, "NHWC")
rh = inner
Expand All @@ -76,17 +80,64 @@ def get_conv2d_params(stmt, producers, consumers):
# Get scale_bias info
scale_bias_load = loads[3]
scale_bias_base = get_base_address(scale_bias_load.index)
serial_scale_bias = SerialAddressRange(
address=tvm.tir.Load("uint8", scale_bias_load.buffer_var, scale_bias_base),
length=SCALE_BIAS_LENGTH * serial_ofm[3],
)
# Get weight info
weight_load = loads[2]
weight_base = get_base_address(weight_load.index)
serial_weight = SerialAddressRange(
address=tvm.tir.Load("uint8", weight_load.buffer_var, weight_base),
length=serial_ofm[3] * serial_kernel[0] * serial_kernel[1] * rc.extent,
)
channels = serial_ofm[3] if isinstance(serial_ofm[3], int) else serial_ofm[3].value

if accel_config == vapi.NpuAccelerator.Ethos_U65_512:
scale_bias_length = SCALE_BIAS_LENGTH * math.ceil(channels / 2)
scale_bias2_length = SCALE_BIAS_LENGTH * math.floor(channels / 2)

serial_scale_bias = SerialAddressRange(
address=tvm.tir.Load("uint8", scale_bias_load.buffer_var, scale_bias_base),
length=scale_bias_length,
)
serial_scale_bias2 = SerialAddressRange(
address=tvm.tir.Load(
"uint8", scale_bias_load.buffer_var, scale_bias_base + scale_bias_length
),
length=scale_bias2_length,
)

weight_length = (
channels * serial_kernel[0] * serial_kernel[1] * math.ceil(rc.extent.value / 2)
)
weight2_length = (
channels * serial_kernel[0] * serial_kernel[1] * math.floor(rc.extent.value / 2)
)

serial_weight = SerialAddressRange(
address=tvm.tir.Load("uint8", weight_load.buffer_var, weight_base),
length=weight_length,
)
serial_weight2 = SerialAddressRange(
address=tvm.tir.Load("uint8", weight_load.buffer_var, weight_base + weight_length),
length=weight2_length,
)
else:
scale_bias_length = SCALE_BIAS_LENGTH * channels

serial_scale_bias = SerialAddressRange(
address=tvm.tir.Load("uint8", scale_bias_load.buffer_var, scale_bias_base),
length=scale_bias_length,
)
# Insert -1s into the spec to denote the absence of the other pointer
serial_scale_bias2 = SerialAddressRange(
address=tvm.tir.IntImm("int8", -1),
length=tvm.tir.IntImm("int8", -1),
)

weight_length = channels * serial_kernel[0] * serial_kernel[1] * rc.extent.value

serial_weight = SerialAddressRange(
address=tvm.tir.Load("uint8", weight_load.buffer_var, weight_base),
length=weight_length,
)
serial_weight2 = SerialAddressRange(
address=tvm.tir.IntImm("int8", -1),
length=tvm.tir.IntImm("int8", -1),
)
# Get activation info
serial_activation = SerialActivation(
op=attrs["activation"], clip_min=attrs["clip_min"], clip_max=attrs["clip_max"]
Expand All @@ -97,8 +148,10 @@ def get_conv2d_params(stmt, producers, consumers):
ofm=serial_ofm,
kernel=serial_kernel,
weight=serial_weight,
weight2=serial_weight2,
weight_zero_point=attrs["weight_zero_point"],
scale_bias=serial_scale_bias,
scale_bias2=serial_scale_bias2,
padding=serial_padding,
activation=serial_activation,
rounding_mode=attrs["rounding_mode"],
Expand Down
115 changes: 91 additions & 24 deletions python/tvm/relay/backend/contrib/ethosu/tir/passes.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,21 +14,22 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name, unused-argument, no-else-return, inconsistent-return-statements
# pylint: disable=invalid-name, unused-argument, no-else-return, inconsistent-return-statements, too-many-nested-blocks
"""The TIR passes to be run on Arm(R) Ethos(TM)-U NPU TIR Compiler."""
from collections import namedtuple
import numpy as np # type: ignore

import tvm
from tvm.relay.backend.contrib.ethosu import vela_api
from tvm.relay.backend.contrib.ethosu import tir_to_cs_translator as tirtocs
from ethosu.vela import api as vapi
from .convolution import get_conv2d_params
from .depthwise import get_depthwise_conv2d_params
from .pooling import get_pooling_params
from .binary_elementwise import get_binary_elementwise_params
from .identity import get_identity_params
from .unary_elementwise import get_unary_elementwise_params
from .transform import get_copy_params
from .utils import get_weights_pointer, get_scale_bias_pointer


def RemoveZeroStores():
Expand Down Expand Up @@ -306,6 +307,7 @@ def EncodeConstants(const_dict):
pointer_to_buffer = {}
rewrite_buffer = {}
rewrite_pointer = {}
pointer_to_offset = {}
accel_config = vela_api.get_accelerator_config()

def _align_scale_bias(tir_extern_call, bias):
Expand Down Expand Up @@ -338,35 +340,91 @@ def _new_buffer(old_buffer, new_value):
rewrite_buffer[old_buffer] = new_buffer
rewrite_pointer[old_buffer.data] = new_buffer.data

def _encode_weights_or_bias(ptr1, ptr2, stmt, encode_func):
    """Encode a constant (weights or bias) for one core, or split it and
    encode each part separately for two cores.

    Parameters
    ----------
    ptr1
        Pointer identifying the constant; must be a key of ``pointer_to_buffer``.
    ptr2
        ``None`` to encode the constant in one piece. Otherwise it must be
        the same pointer as ``ptr1`` and the constant is encoded in two parts.
    stmt
        Forwarded unchanged as the first argument of ``encode_func``.
    encode_func
        Callable invoked as ``encode_func(stmt, constant)``.

    Returns
    -------
    new_const
        The encoded constant. In the two-part case this is both encoded parts
        appended one after the other and cast to uint8.
    new_const_length
        Length of the first encoded part (of the whole encoded constant when
        no split takes place).
    """
    assert ptr1 in pointer_to_buffer
    constant = buffer_to_const[pointer_to_buffer[ptr1]]

    # Single pointer: nothing to split, encode everything at once
    if ptr2 is None:
        encoded = encode_func(stmt, constant)
        return encoded, len(encoded)

    # Assume OHWI
    channels = constant.shape[0]
    per_channel = np.split(constant, channels, axis=0)

    # Slices at even positions along axis 0 form the first part
    even_part = np.concatenate(per_channel[0::2], axis=0)
    first_encoded = encode_func(stmt, even_part)
    first_length = len(first_encoded)

    # Slices at odd positions along axis 0 are encoded separately as the second part
    assert ptr1.same_as(ptr2)
    odd_part = np.concatenate(per_channel[1::2], axis=0)
    second_encoded = encode_func(stmt, odd_part)

    combined = np.append(first_encoded, second_encoded).astype("uint8")
    return combined, first_length

def _visit_encode_pre(stmt):
if isinstance(stmt, tvm.tir.Call):
op = str(stmt.args[0].value)
# Handle copies as a special-case by propagating the buffer information
# from the read to the write pointer.
if stmt.args[0] == "ethosu_copy":
if op == "ethosu_copy":
read_pointer = stmt.args[1].buffer_var
if read_pointer in pointer_to_buffer:
write_pointer = stmt.args[3].buffer_var
# Assert writing to the base of the write_var (pre-StorageRewrite)
assert stmt.args[3].index == 0
assert stmt.args[1].index == 0
pointer_to_buffer[write_pointer] = pointer_to_buffer[read_pointer]
else:

ops_with_weights = {
"ethosu_conv2d": tirtocs.translate_ethosu_conv2d,
"ethosu_depthwise_conv2d": tirtocs.translate_ethosu_depthwise_conv2d,
}
if op in ops_with_weights.keys():
npu_op, _ = ops_with_weights[op](stmt)

# Encode the weights
weights_pointer = get_weights_pointer(stmt)
if weights_pointer is not None:
assert weights_pointer in pointer_to_buffer
weights_buffer = pointer_to_buffer[weights_pointer]
weights_value = buffer_to_const[weights_buffer]
new_weights_value = _encode_weights(stmt, weights_value)
_new_buffer(weights_buffer, new_weights_value)
# Align the scale_bias to 16 bytes
scale_bias_pointer = get_scale_bias_pointer(stmt)
if scale_bias_pointer is not None:
assert scale_bias_pointer in pointer_to_buffer
scale_bias_buffer = pointer_to_buffer[scale_bias_pointer]
scale_bias_value = buffer_to_const[scale_bias_buffer]
new_scale_bias_value = _align_scale_bias(stmt, scale_bias_value)
_new_buffer(scale_bias_buffer, new_scale_bias_value)
weights_pointer = npu_op.weights[0].address.buffer_var
weights2_pointer = (
npu_op.weights[1].address.buffer_var
if accel_config == vapi.NpuAccelerator.Ethos_U65_512
else None
)

new_weights, new_weights_length = _encode_weights_or_bias(
weights_pointer, weights2_pointer, stmt, _encode_weights
)

weights_buffer = pointer_to_buffer[weights_pointer]
_new_buffer(weights_buffer, new_weights)
pointer_to_offset[weights_pointer] = new_weights_length

# Align the bias(es) to 16 bytes
scale_bias_pointer = npu_op.biases[0].address.buffer_var
scale_bias2_pointer = (
npu_op.biases[1].address.buffer_var
if accel_config == vapi.NpuAccelerator.Ethos_U65_512
else None
)

new_scale_bias, new_scale_bias_length = _encode_weights_or_bias(
scale_bias_pointer, scale_bias2_pointer, stmt, _align_scale_bias
)

scale_bias_buffer = pointer_to_buffer[scale_bias_pointer]

_new_buffer(scale_bias_buffer, new_scale_bias)
pointer_to_offset[scale_bias_pointer] = new_scale_bias_length

def _visit_encode_post(stmt):
# Because encoding may change the data type (e.g. bias to uint8) and type information
Expand Down Expand Up @@ -406,6 +464,14 @@ def _visit_rewrite(stmt):
# Only rewrite the arguments of buffers that have been encoded
if buffer in new_buffers:
new_arg = np.prod(list(pointer_to_buffer[pointer].shape))
if isinstance(stmt.args[i + 1], tvm.tir.Load):
if pointer.same_as(stmt.args[i + 1].buffer_var):
# we've got a pair of loads from the same buffer
new_arg = stmt.args[i + 1].index.value
elif isinstance(stmt.args[i - 3], tvm.tir.Load):
if pointer.same_as(stmt.args[i - 3].buffer_var):
new_arg = new_arg - load.index.value

new_args.append(new_arg)
continue
new_args.append(stmt.args[i])
Expand Down Expand Up @@ -433,10 +499,11 @@ def _visit_rewrite(stmt):
load_pointer = stmt.buffer_var
if load_pointer in rewrite_pointer:
new_pointer = rewrite_pointer[load_pointer]
offset = stmt.index
if offset != 0:
offset = pointer_to_offset[load_pointer]
element_type = new_pointer.type_annotation.element_type.dtype
return tvm.tir.Load(
element_type, new_pointer, stmt.index, stmt.predicate, stmt.span
)
return tvm.tir.Load(element_type, new_pointer, offset, stmt.predicate, stmt.span)
if isinstance(stmt, tvm.tir.AttrStmt):
node_pointer = stmt.node
if node_pointer in rewrite_pointer:
Expand All @@ -448,7 +515,7 @@ def _visit_rewrite(stmt):
def _ftransform(f, mod, ctx):
for i, param in enumerate(f.params):
if i in const_dict:
buffer_to_const[f.buffer_map[param]] = const_dict[i].flatten()
buffer_to_const[f.buffer_map[param]] = const_dict[i]
pointer_to_buffer[f.buffer_map[param].data] = f.buffer_map[param]

# First analyse what needs to be rewritten
Expand All @@ -469,7 +536,7 @@ def _ftransform(f, mod, ctx):
new_value = buffer_to_const[new_buffer]
new_const_dict[i] = new_value
elif buffer in buffer_to_const:
new_const_dict[i] = buffer_to_const[buffer]
new_const_dict[i] = buffer_to_const[buffer].flatten()
new_buffer_map[param] = buffer
else:
new_buffer_map[param] = buffer
Expand Down
4 changes: 4 additions & 0 deletions python/tvm/relay/backend/contrib/ethosu/tir/spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,8 +184,10 @@ def __init__(
ofm: SerialFeatureMap,
kernel: SerialKernel,
weight: SerialAddressRange,
weight2: SerialAddressRange,
weight_zero_point: int,
scale_bias: SerialAddressRange,
scale_bias2: SerialAddressRange,
padding: SerialPadding,
activation: SerialActivation,
rounding_mode: str,
Expand All @@ -195,8 +197,10 @@ def __init__(
self.ofm = ofm
self.kernel = kernel
self.weight = weight
self.weight2 = weight2
self.weight_zero_point = weight_zero_point
self.scale_bias = scale_bias
self.scale_bias2 = scale_bias2
self.padding = padding
self.activation = activation
self.rounding_mode = rounding_mode
Expand Down
18 changes: 0 additions & 18 deletions python/tvm/relay/backend/contrib/ethosu/tir/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,24 +20,6 @@
from tvm import arith


# TODO(@mbaret): Formalise this with a specification
def get_weights_pointer(tir_extern_call):
    """Return the weights pointer of an NPU extern call, or None.

    The pointer is the ``buffer_var`` of argument 41, and is only returned
    when the first argument names a conv2d or depthwise conv2d operator.
    """
    op_name = tir_extern_call.args[0]
    if op_name not in ("ethosu_conv2d", "ethosu_depthwise_conv2d"):
        return None
    weights_arg = tir_extern_call.args[41]
    return weights_arg.buffer_var


# TODO(@mbaret): Formalise this with a specification
def get_scale_bias_pointer(tir_extern_call):
    """Return the scale_bias pointer of an NPU extern call, or None.

    The pointer is the ``buffer_var`` of argument 44, and is only returned
    when the first argument names a conv2d or depthwise conv2d operator.
    """
    op_name = tir_extern_call.args[0]
    if op_name not in ("ethosu_conv2d", "ethosu_depthwise_conv2d"):
        return None
    scale_bias_arg = tir_extern_call.args[44]
    return scale_bias_arg.buffer_var


def get_op_attrs(stmt):
"""Iterate through nested attribute statements accumulating their values
in an attribute dictionary.
Expand Down
Loading

0 comments on commit 5f302f3

Please sign in to comment.