Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[microNPU] Determine block configs using the cascader #10695

Merged
merged 1 commit into from
Mar 22, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions python/tvm/contrib/ethosu/cascader/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from tvm import tir
from .cascader_options import CascaderOptions
from .graph import CascaderGraph, Part, Tensor, TESubgraph
from .parts import EthosuPart
from .tensor_config import MemoryRegion
from .proposal import Proposal
from .proposal_generator import generate_proposals
Expand Down Expand Up @@ -125,6 +126,23 @@ def apply_proposal(proposal: Proposal, sch: te.Schedule) -> None:

"""
for plan in proposal.plans:
for part in plan.part_group:
if isinstance(part, EthosuPart):
tensor_config = plan.tensor_configs[part.output_tensor]
stripe_config = tensor_config.stripe_configs[0]
block_config = part.get_block_config(stripe_config)
iv = part.subgraph.output_tensor.op.axis[0]
block_shape = block_config.output_shape
if len(block_shape) == 4:
height, width, depth = block_shape[1:]
else:
height = block_shape[1]
width = block_shape[3]
depth = block_shape[2] * block_shape[4]
sch[part.subgraph.output_tensor].pragma(iv, "block_config_height", height)
sch[part.subgraph.output_tensor].pragma(iv, "block_config_width", width)
sch[part.subgraph.output_tensor].pragma(iv, "block_config_depth", depth)

output_tensor_config = plan.output_config
output_tensor = output_tensor_config.tensor
output_part = output_tensor.producers[0]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,9 @@ def get_binary_elementwise_params(
# Get feature map info
serial_ifm, _ = get_ifm_params(input_pointer, producers)
serial_ifm2, _ = get_ifm_params(input_pointer1, producers)
serial_ofm, replace_pointer, is_allocator = get_ofm_params(output_pointer, consumers, producers)
serial_ofm, serial_block_config, replace_pointer, is_allocator = get_ofm_params(
output_pointer, consumers, producers
)
# Get activation info
serial_activation = SerialActivation(
op=attrs["activation"], clip_min=attrs["clip_min"], clip_max=attrs["clip_max"]
Expand All @@ -100,6 +102,7 @@ def get_binary_elementwise_params(
reversed_operands=reversed_operands,
activation=serial_activation,
rounding_mode=attrs["rounding_mode"],
block_config=serial_block_config,
),
output_pointer,
replace_pointer,
Expand Down
5 changes: 4 additions & 1 deletion python/tvm/relay/backend/contrib/ethosu/tir/convolution.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,9 @@ def get_conv2d_params(stmt, producers, consumers):
output_pointer = stores[0].buffer.data
# Get feature map info
serial_ifm, serial_padding = get_ifm_params(input_pointer, producers)
serial_ofm, replace_pointer, is_allocator = get_ofm_params(output_pointer, consumers, producers)
serial_ofm, serial_block_config, replace_pointer, is_allocator = get_ofm_params(
output_pointer, consumers, producers
)
# Get kernel info
serial_kernel = SerialKernel(
width=int(rw.extent),
Expand Down Expand Up @@ -103,6 +105,7 @@ def get_conv2d_params(stmt, producers, consumers):
activation=serial_activation,
rounding_mode=attrs["rounding_mode"],
upscale=attrs["upscale"],
block_config=serial_block_config,
),
output_pointer,
replace_pointer,
Expand Down
5 changes: 4 additions & 1 deletion python/tvm/relay/backend/contrib/ethosu/tir/depthwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,9 @@ def get_depthwise_conv2d_params(
output_pointer = stores[0].buffer.data
# Get feature map info
serial_ifm, serial_padding = get_ifm_params(input_pointer, producers)
serial_ofm, replace_pointer, is_allocator = get_ofm_params(output_pointer, consumers, producers)
serial_ofm, serial_block_config, replace_pointer, is_allocator = get_ofm_params(
output_pointer, consumers, producers
)
# Get kernel info
serial_kernel = SerialKernel(
width=int(rw.extent),
Expand Down Expand Up @@ -113,6 +115,7 @@ def get_depthwise_conv2d_params(
activation=serial_activation,
rounding_mode=attrs["rounding_mode"],
upscale="NONE",
block_config=serial_block_config,
),
output_pointer,
replace_pointer,
Expand Down
17 changes: 14 additions & 3 deletions python/tvm/relay/backend/contrib/ethosu/tir/dma.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
"""Extract parameters from the DMA operators in TIR."""
import tvm
from .utils import get_outer_loops, get_base_address, get_strides, get_op_attrs
from .spec import SerialFeatureMap, SerialPadding
from .spec import SerialBlockConfig, SerialFeatureMap, SerialPadding


def get_pad_params(stmt):
Expand Down Expand Up @@ -253,6 +253,14 @@ def get_write_params(stmt):

base_address = [get_base_address(index) for index in inner.indices]
data_type = inner.buffer.data.type_annotation.element_type.dtype
if "block_config_height" in attrs:
block_config = SerialBlockConfig(
height=int(attrs["block_config_height"]),
width=int(attrs["block_config_width"]),
depth=int(attrs["block_config_depth"]),
)
else:
block_config = SerialBlockConfig(0, 0, 0)
return (
SerialFeatureMap(
data_type=data_type,
Expand All @@ -273,6 +281,7 @@ def get_write_params(stmt):
stride_w=strides[1],
stride_c=strides[2],
),
block_config,
input_pointer,
output_pointer,
)
Expand Down Expand Up @@ -327,6 +336,8 @@ def get_ofm_params(pointer, consumers, producers):
-------
serial_ifm : SerialFeatureMap
The serializable OFM.
serial_block_config : SerialBlockConfig
The serializable block config.
output_pointer : tvm.tir.Var
The pointer that the OFM DMA pipeline produces.
is_allocator : bool
Expand All @@ -336,11 +347,11 @@ def get_ofm_params(pointer, consumers, producers):
convert_to_nhcwb16 = consumers[pointer]
out_channels, _, output_pointer = get_convert_to_nhcwb16_params(convert_to_nhcwb16)
write = consumers[output_pointer]
serial_ofm, _, output_pointer = get_write_params(write)
serial_ofm, serial_block_config, _, output_pointer = get_write_params(write)
is_allocator = True
if output_pointer not in producers:
is_allocator = False
elif producers[output_pointer] != write:
is_allocator = False
serial_ofm.channels = out_channels
return serial_ofm, output_pointer, is_allocator
return serial_ofm, serial_block_config, output_pointer, is_allocator
10 changes: 9 additions & 1 deletion python/tvm/relay/backend/contrib/ethosu/tir/identity.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,14 @@
"""Extract information from the identity operator in TIR."""
from typing import Dict, Tuple
import tvm
from .spec import SerialKernel, SerialActivation, SerialPooling, SerialPadding, SerialFeatureMap
from .spec import (
SerialBlockConfig,
SerialKernel,
SerialActivation,
SerialPooling,
SerialPadding,
SerialFeatureMap,
)
from .utils import get_op_attrs, get_base_address, get_strides, get_loads


Expand Down Expand Up @@ -164,6 +171,7 @@ def get_identity_params(
activation=serial_activation,
upscale="NONE",
rounding_mode="TFL",
block_config=SerialBlockConfig(0, 0, 0),
),
output_pointer,
replace_pointer,
Expand Down
5 changes: 4 additions & 1 deletion python/tvm/relay/backend/contrib/ethosu/tir/pooling.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,9 @@ def get_pooling_params(
output_pointer = stores[0].buffer.data
# Get feature map info
serial_ifm, serial_padding = get_ifm_params(input_pointer, producers)
serial_ofm, replace_pointer, is_allocator = get_ofm_params(output_pointer, consumers, producers)
serial_ofm, serial_block_config, replace_pointer, is_allocator = get_ofm_params(
output_pointer, consumers, producers
)
# Get kernel info
serial_kernel = SerialKernel(
width=int(rw.extent),
Expand All @@ -90,6 +92,7 @@ def get_pooling_params(
activation=serial_activation,
rounding_mode=attrs["rounding_mode"],
upscale=attrs["upscale"],
block_config=serial_block_config,
),
output_pointer,
replace_pointer,
Expand Down
7 changes: 7 additions & 0 deletions python/tvm/relay/backend/contrib/ethosu/tir/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,13 @@ def _add_pragmas(stage, ax):
for attr, val in stage.op.attrs.items():
if attr not in ("op", "lut") and not isinstance(val, Propagator):
stage.pragma(ax, str(attr), val)
if stage.op.axis[0] in stage.iter_var_attrs:
attrs = stage.iter_var_attrs[stage.op.axis[0]]
if "block_config_height" in attrs.pragma_keys:
pragmas = dict(zip([k.value for k in attrs.pragma_keys], attrs.pragma_values))
stage.pragma(ax, "block_config_height", pragmas["block_config_height"])
stage.pragma(ax, "block_config_width", pragmas["block_config_width"])
stage.pragma(ax, "block_config_depth", pragmas["block_config_depth"])

for stage in sch.stages:
if (
Expand Down
20 changes: 20 additions & 0 deletions python/tvm/relay/backend/contrib/ethosu/tir/spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,16 @@ def __init__(self, op: str, clip_min: int, clip_max: int):
self.clip_max = clip_max


class SerialBlockConfig(SerializableFormat):
"""Specialization class to retrieve arguments of a BlockConfig
(similar to NpuBlockConfig of Vela) on a predefined ordering"""

def __init__(self, height: int, width: int, depth: int):
self.height = height
self.width = width
self.depth = depth


class Serial2DConvolution(SerializableFormat):
"""Specialization class to retrieve arguments of
a ethosu.conv2d tir extern call on a predefined ordering"""
Expand All @@ -190,6 +200,7 @@ def __init__(
activation: SerialActivation,
rounding_mode: str,
upscale: str,
block_config: SerialBlockConfig,
):
self.ifm = ifm
self.ofm = ofm
Expand All @@ -201,6 +212,7 @@ def __init__(
self.activation = activation
self.rounding_mode = rounding_mode
self.upscale = upscale
self.block_config = block_config


class Serial2DDepthwise(SerializableFormat):
Expand All @@ -219,6 +231,7 @@ def __init__(
activation: SerialActivation,
rounding_mode: str,
upscale: str,
block_config: SerialBlockConfig,
):
self.ifm = ifm
self.ofm = ofm
Expand All @@ -230,6 +243,7 @@ def __init__(
self.activation = activation
self.rounding_mode = rounding_mode
self.upscale = upscale
self.block_config = block_config


class SerialCopy(SerializableFormat):
Expand Down Expand Up @@ -261,6 +275,7 @@ def __init__(
activation: SerialActivation,
rounding_mode: str,
upscale: str,
block_config: SerialBlockConfig,
):
self.ifm = ifm
self.ofm = ofm
Expand All @@ -270,6 +285,7 @@ def __init__(
self.activation = activation
self.rounding_mode = rounding_mode
self.upscale = upscale
self.block_config = block_config


class SerialBinaryElementwise(SerializableFormat):
Expand All @@ -285,6 +301,7 @@ def __init__(
reversed_operands: bool,
activation: SerialActivation,
rounding_mode: str,
block_config: SerialBlockConfig,
):
self.ifm = ifm
self.ifm2 = ifm2
Expand All @@ -293,6 +310,7 @@ def __init__(
self.reversed_operands = reversed_operands
self.activation = activation
self.rounding_mode = rounding_mode
self.block_config = block_config


class SerialUnaryElementwise(SerializableFormat):
Expand All @@ -306,9 +324,11 @@ def __init__(
operator_type: str,
activation: SerialActivation,
rounding_mode: str,
block_config: SerialBlockConfig,
):
self.ifm = ifm
self.ofm = ofm
self.operator_type = operator_type
self.activation = activation
self.rounding_mode = rounding_mode
self.block_config = block_config
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,9 @@ def get_unary_elementwise_params(stmt, producers, consumers):
output_pointer = inner.buffer.data
# Get feature map info
serial_ifm, _ = get_ifm_params(input_pointer, producers)
serial_ofm, replace_pointer, is_allocator = get_ofm_params(output_pointer, consumers, producers)
serial_ofm, serial_block_config, replace_pointer, is_allocator = get_ofm_params(
output_pointer, consumers, producers
)
# Get activation info
serial_activation = SerialActivation(
op=attrs["activation"], clip_min=attrs["clip_min"], clip_max=attrs["clip_max"]
Expand All @@ -73,6 +75,7 @@ def get_unary_elementwise_params(stmt, producers, consumers):
operator_type=attrs["operator_type"],
activation=serial_activation,
rounding_mode=attrs["rounding_mode"],
block_config=serial_block_config,
),
output_pointer,
replace_pointer,
Expand Down
Loading