[Topi] [Hexagon] Conv2d slice op initial version
quic-sanirudh committed May 27, 2022
1 parent aaee8aa commit 913fca3
Showing 4 changed files with 558 additions and 1 deletion.
22 changes: 22 additions & 0 deletions python/tvm/topi/hexagon/slice_ops/__init__.py
@@ -0,0 +1,22 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

""" Computes and Schedules for Hexagon slice ops. """

# pylint: disable=wildcard-import

from .conv2d import *
200 changes: 200 additions & 0 deletions python/tvm/topi/hexagon/slice_ops/conv2d.py
@@ -0,0 +1,200 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=line-too-long, invalid-name

"""Hexagon slice conv2d compute and schedule"""
import typing

import tvm
from tvm import te


def conv2d_compute(
Input: te.Tensor,
Weights: te.Tensor,
out_shape: typing.Tuple,
stride: typing.Tuple,
dilation: typing.Tuple,
dtype: str,
) -> te.Tensor:
"""Compute for slice conv2d op for hexagon.
This op makes the following assumptions:
1. This op is written for a sliced convolution with 2d physical buffers
2. The input is assumed to be in NHWC layout and filter is in HWIO layout
3. Grouped convolutions are not supported; depthwise convolution will get a separate compute definition
4. To realize grouped convolutions, the op is expected to be sliced along the groups, with one call to this compute placed per group
Parameters
----------
Input : te.Tensor
Input activations, padded along the inner dimensions to the physical buffer size
Weights : te.Tensor
Weights without dilation
out_shape : typing.Tuple
The logical output shape without considering input padding
stride : typing.Tuple
The (height, width) stride of the convolution
dilation : typing.Tuple
The (height, width) dilation applied to the filter
dtype : str
The output data type
Returns
-------
Output : te.Tensor
Output of applying 2D convolution of Weights on Input
"""

filt_shape = Weights.shape

reduce_channel = tvm.te.reduce_axis((0, filt_shape[2]), name="reduce_channel")
reduce_height = tvm.te.reduce_axis((0, filt_shape[0]), name="reduce_height")
reduce_width = tvm.te.reduce_axis((0, filt_shape[1]), name="reduce_width")
stride_height, stride_width = stride
dilation_height, dilation_width = dilation
Output = tvm.te.compute(
out_shape,
lambda n, h, w, c: tvm.te.sum(
(
Input[
n,
h * stride_height + reduce_height * dilation_height,
w * stride_width + reduce_width * dilation_width,
reduce_channel,
]
* Weights[reduce_height, reduce_width, reduce_channel, c]
).astype(dtype),
axis=[reduce_channel, reduce_height, reduce_width],
),
name="Output",
)
return Output
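# --------------------------------------------------------------------------
# Illustrative usage sketch (not part of the original commit): invoking the
# compute with the fp16 shapes from the TVMScript reference in
# conv2d_STIR_schedule below, i.e. a padded (1, 24, 12, 32) NHWC input, a
# (3, 3, 32, 32) HWIO filter, and 1x1 stride/dilation:
#
# input_t = te.placeholder((1, 24, 12, 32), dtype="float16", name="InputTensor")
# weights_t = te.placeholder((3, 3, 32, 32), dtype="float16", name="Weights")
# output_t = conv2d_compute(
#     input_t, weights_t, (1, 16, 8, 32), (1, 1), (1, 1), "float16"
# )
# --------------------------------------------------------------------------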


def conv2d_schedule(
out: te.Tensor,
ins: typing.List[te.Tensor],
transform_activation: typing.Callable,
transform_weights: typing.Callable,
) -> te.Schedule:
"""TE Schedule for the sliced conv2d op
This schedule makes the following assumptions:
1. There is only one output tensor
2. The activations and weights have specific layouts defined by the last 2 arguments
3. All transformation functions are expected to be bijections for now
Parameters
----------
out : te.Tensor
The output tensor returned by a call to conv2d_compute
ins : typing.List[te.Tensor]
A list of two tensors: the input activations and the weights
transform_activation : typing.Callable
The transformation function definition for the expected activations layout
transform_weights : typing.Callable
The transformation function definition for the expected weights layout
Returns
-------
sch : te.Schedule
The TE schedule for slice conv2d
"""
Input, Weights = ins
Output = out
sch = tvm.te.create_schedule(Output.op)
rc, rh, rw = sch[Output].op.reduce_axis
sch[Input].transform_layout(transform_activation)
sch[Weights].transform_layout(transform_weights)
transformed_axis = sch[Output].transform_layout(transform_activation)
fused_out_axis = sch[Output].fuse(transformed_axis[-1], transformed_axis[-2])
sch[Output].reorder(*[*transformed_axis[:-2], rh, rw, rc, fused_out_axis])
# The code below doesn't work yet, as vectorization across the 2D boundary is not yet supported
# sch[Output].vectorize(fused_out_axis)
return sch
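# --------------------------------------------------------------------------
# Illustrative sketch (not part of the original commit): the transform
# callables map logical indices to the physical blocked layout. One plausible
# activation/output transform, mirroring the split factors used in
# conv2d_STIR_schedule below (H tiled by 8, W by 4 with an innermost pair
# of 2, C by 32):
#
# def transform_activation(n, h, w, c):
#     return [n, h // 8, w // 4, c // 32, h % 8, (w % 4) // 2, c % 32, w % 2]
#
# The exact weights transform is layout-dependent; a hypothetical variant
# that blocks both channel dimensions of the HWIO filter by 32 could be:
#
# def transform_weights(h, w, i, o):
#     return [h, w, i // 32, o // 32, i % 32, o % 32]
# --------------------------------------------------------------------------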


def conv2d_STIR_schedule(
outs: te.Tensor,
ins: typing.List[te.Tensor],
transform_activation: typing.Callable,
transform_weights: typing.Callable,
) -> tvm.tir.Schedule:
"""STIR schedule definition for the compute defined above by conv2d_compute.
- Auto-generated prim_func before applying schedule primitives for reference
- The below TVMScript code is for conv2d with padded input dimensions and a stride of 1x1
# from tvm.script import tir as T
@T.prim_func
def func(InputTensor: T.Buffer[(1, 24, 12, 32), "float16"], Weights: T.Buffer[(3, 3, 32, 32), "float16"], compute: T.Buffer[(1, 16, 8, 32), "float16"]) -> None:
# function attr dict
T.func_attr({"global_symbol": "main", "tir.noalias": True})
# body
# with T.block("root")
for i0, i1, i2, i3, i4, i5, i6 in T.grid(1, 16, 8, 32, 32, 3, 3):
with T.block("compute"):
n, h, w, c, rc, rh, rw = T.axis.remap("SSSSRRR", [i0, i1, i2, i3, i4, i5, i6])
T.reads(InputTensor[n, h + rh, w + rw, rc], Weights[rh, rw, rc, c])
T.writes(compute[n, h, w, c])
with T.init():
compute[n, h, w, c] = T.float16(0)
compute[n, h, w, c] = compute[n, h, w, c] + InputTensor[n, h + rh, w + rw, rc] * Weights[rh, rw, rc, c]
Parameters
----------
outs : te.Tensor
The output Tensor as returned by a call to conv2d_compute
ins : typing.List[te.Tensor]
A list of two tensors: the input activations and the weights
transform_activation : typing.Callable
The transformation function definition for the expected activations layout
transform_weights : typing.Callable
The transformation function definition for the expected weights layout
Returns
-------
sch : tvm.tir.Schedule
The STIR schedule for slice conv2d compute
"""

assert len(ins) == 2, "This schedule expects only 2 inputs - Activations and Weights"
source_expr = ins + [outs]
prim_func = tvm.te.create_prim_func(source_expr)
sch = tvm.tir.Schedule(prim_func)

compute = sch.get_block("Output")
# Apply layout_transform for activation
sch.transform_layout(compute, ins[0].name, transform_activation)

# Apply layout_transform for weights
sch.transform_layout(compute, ins[1].name, transform_weights)

# Apply layout_transform for output
sch.transform_layout(compute, outs.name, transform_activation)

n, h, w, c, rc, rh, rw = sch.get_loops(compute) # This still returns the original 7d loop
ho, hi = sch.split(h, [None, 8]) # Tile H by 8
wo, wi = sch.split(w, [None, 4]) # Tile W by 4 ...
wio, wii = sch.split(wi, [2, 2]) # ... and split the W tile into an outer and inner pair of 2
co, ci = sch.split(c, [None, 32]) # Tile C by 32
# Make the 8h x 2w x 32c x 2w block the innermost loops
sch.reorder(n, ho, wo, co, hi, wio, rh, rw, rc, ci, wii)
sch.decompose_reduction(compute, rh)
# Vectorization of the fused innermost axes is not enabled yet
# ci_wii = sch.fuse(ci, wii)
# sch.vectorize(ci_wii)
return sch
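# --------------------------------------------------------------------------
# Illustrative end-to-end sketch (not part of the original commit), reusing
# the hypothetical placeholders and transform callables from the sketches
# above:
#
# output_t = conv2d_compute(
#     input_t, weights_t, (1, 16, 8, 32), (1, 1), (1, 1), "float16"
# )
# sch = conv2d_STIR_schedule(
#     output_t, [input_t, weights_t], transform_activation, transform_weights
# )
# print(sch.mod.script())  # inspect the scheduled TIR
# --------------------------------------------------------------------------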
2 changes: 1 addition & 1 deletion tests/python/contrib/test_hexagon/infrastructure.py
@@ -48,7 +48,7 @@ def allocate_hexagon_array(
for dim_i, dim_f in zip(boundaries[:-1], boundaries[1:])
]

-    arr = tvm.nd.empty(physical_shape, dtype=dtype, device=dev)
+    arr = tvm.nd.empty(physical_shape, dtype=dtype, device=dev, mem_scope=mem_scope)

if data is not None:
arr.copyfrom(data.reshape(physical_shape))
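# Illustrative sketch (not part of the original commit): with mem_scope now
# forwarded to tvm.nd.empty, callers can request a specific memory scope for
# the array, e.g. Hexagon's VTCM (assuming a Hexagon device handle `dev` and
# a hypothetical shape):
#
# arr = tvm.nd.empty((64, 64), dtype="float16", device=dev, mem_scope="global.vtcm")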