Skip to content

Commit

Permalink
Address comments
Browse files Browse the repository at this point in the history
  • Loading branch information
trahman-quic1 committed Jun 2, 2022
1 parent 037b7e2 commit 009e536
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 40 deletions.
37 changes: 32 additions & 5 deletions python/tvm/topi/hexagon/slice_ops/add_subtract_multiply.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,26 @@
# pylint: disable=invalid-name
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name, unused-variable, unused-argument, too-many-locals

"""Compute and schedule for add, multiply, subtract slice op
Please note the following assumptions made by the implementation:
1) The inputs will be multiple of crouton layout except for the axis that needs broadcasting."""

from tvm import te
from tvm import tir
Expand All @@ -7,25 +29,30 @@


def add_broadcast_compute(A, B):
"""Call the add op from topi"""
return topi.add(A, B)


def subtract_broadcast_compute(A, B):
"""Call the subtract op from topi"""
return topi.subtract(A, B)


def multiply_broadcast_compute(A, B):
"""Call the multiply op from topi"""
return topi.multiply(A, B)


def get_layout(layout):
def get_2d_layout(layout):
"""Get the 2d layout for transformation"""
layout += "-2d"
return get_layout_transform_fn(layout)


def STIR_broadcast_schedule(
M, A, B, output_layout: str, input_A_layout: str, input_B_layout: str, op_name: str
):
"""Schedule for input and output layout nhwc-8h2w32c2w considering broadcast"""
func = te.create_prim_func([A, B, M])

s = tir.Schedule(func)
Expand All @@ -35,14 +62,14 @@ def STIR_broadcast_schedule(
block = s.get_block(block_dict[op_name])

if input_A_layout == "nhwc-8h2w32c2w":
input_A_transformed_layout = get_layout(input_A_layout)
input_A_transformed_layout = get_2d_layout(input_A_layout)
s.transform_layout(block, buffer=("read", 0), index_map=input_A_transformed_layout)

if input_B_layout == "nhwc-8h2w32c2w":
input_B_transformed_layout = get_layout(input_B_layout)
input_B_transformed_layout = get_2d_layout(input_B_layout)
s.transform_layout(block, buffer=("read", 1), index_map=input_B_transformed_layout)

output_transformed_layout = get_layout(output_layout)
output_transformed_layout = get_2d_layout(output_layout)
s.transform_layout(block, buffer=("write", 0), index_map=output_transformed_layout)

n, h, w, c = s.get_loops(block)
Expand Down
42 changes: 9 additions & 33 deletions python/tvm/topi/hexagon/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,62 +15,38 @@
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name

"""Common hexagon specific utilities"""
from tvm import te


def n11c_1024c_2d(n, h, w, c):
"""Return index map for n11c_1024 2d layout"""
return [n, h, w, c // 1024, te.AXIS_SEPARATOR, c % 1024]


def n11c_1024c_1d(n, h, w, c):
"""Return index map for n11c_1024 1d layout"""
return [n, h, w, c // 1024, c % 1024]


def nhwc_8h2w32c2w_2d(n, h, w, c):
"""Return index map for nhwc_8h2w32c2w 2d layout"""
return [n, h // 8, w // 4, c // 32, te.AXIS_SEPARATOR, h % 8, (w % 4) // 2, c % 32, w % 2]


def nhwc_8h2w32c2w_1d(n, h, w, c):
"""Return index map for nhwc_8h2w32c2w 1d layout"""
return [n, h // 8, w // 4, c // 32, h % 8, (w % 4) // 2, c % 32, w % 2]


def get_layout_transform_fn(layout):
"""Return index map function as per the layout string"""
if layout == "nhwc-8h2w32c2w-2d":
return nhwc_8h2w32c2w_2d
if layout == "nhwc-8h2w32c2w-1d":
return nhwc_8h2w32c2w_1d
elif layout == "n11c-1024c-2d":
if layout == "n11c-1024c-2d":
return n11c_1024c_2d
elif layout == "n11c-1024c-1d":
if layout == "n11c-1024c-1d":
return n11c_1024c_1d
else:
raise RuntimeError(f"Unexpected layout '{layout}'")


def apply_transform(s, block, block_index: int, buffer_type: str, layout: str):
"""Apply transform layout on a buffer
Parameters
----------
s: Schedule
block : BlockRV
The block that accesses the target buffer
buffer_index: int
The index of the buffer in block's read or write region
buffer_type : str
Type of the buffer index, "read" or "write"
layout : str
Layout of the buffer
"""
transform_fn = get_layout_transform_fn(layout)
if layout == "nhwc-8h2w32c2w-1d":
axis_separators = [4]
elif layout == "n11c-1024c-1d":
axis_separators = [2]
else:
raise RuntimeError(f"Unexpected layout '{layout}'")

s.transform_layout(block, block_index, buffer_type, transform_fn)
if axis_separators:
s.set_axis_separator(block, block_index, buffer_type, axis_separators)
raise RuntimeError(f"Unexpected layout '{layout}'")
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
import pytest
import numpy as np

np.set_printoptions(threshold=np.inf)
from tvm import te, topi

import tvm.testing
Expand Down Expand Up @@ -75,6 +74,7 @@ def transformed_expected_output_np(expected_output_np, output_layout):
def hexagon_wrapper_allocation(
device, layout, axis_separators, tensor_shape=None, data=None, transformed_data=None, dtype=None
):
"""Input layout can either be nhwc-8h2w32c2w or nhwc"""
if layout == "nhwc-8h2w32c2w":
data_nd = allocate_hexagon_array(
device,
Expand Down Expand Up @@ -233,4 +233,4 @@ def test_transform(


if __name__ == "__main__":
sys.exit(pytest.main(sys.argv))
tvm.testing.main()

0 comments on commit 009e536

Please sign in to comment.