[Adreno] Adapt reduction schedule for adreno (#13100)
* [Adreno] Adapt reduction schedule for adreno

The original cuda schedule uses rfactor, which is 10x-50x slower on
Adreno than a schedule without barriers

* Address PR comments

* Remove copy-paste, reuse the cuda impl instead

* Address pylint hits

* Extend comment for cuda schedule_reduce_impl
elvin-n authored Oct 24, 2022
1 parent 3e02ac5 commit 03d989f
Showing 5 changed files with 144 additions and 4 deletions.
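For context on the performance note in the commit message: the CUDA-style reduction schedule factors the reduction axis with rfactor and combines partial results across threads, which implies cross-thread synchronization, while the Adreno schedule introduced below keeps each reduction inside a single thread and only parallelizes over output elements. A minimal illustrative sketch of the rfactor pattern (not code from this commit; shapes and split factors are arbitrary):

# Illustrative sketch (not part of this commit): cross-thread reduction via
# rfactor, the pattern the CUDA schedule uses and that this commit avoids
# for Adreno.
import tvm
from tvm import te

n = te.var("n")
m = te.var("m")
A = te.placeholder((n, m), name="A")
k = te.reduce_axis((0, m), name="k")
B = te.compute((n,), lambda i: te.sum(A[i, k], axis=k), name="B")

s = te.create_schedule(B.op)
ko, ki = s[B].split(B.op.reduce_axis[0], factor=16)
BF = s.rfactor(B, ki)  # partial sums, one per ki value / thread
s[B].bind(s[B].op.axis[0], te.thread_axis("blockIdx.x"))
s[B].bind(s[B].op.reduce_axis[0], te.thread_axis("threadIdx.x"))
s[BF].compute_at(s[B], s[B].op.reduce_axis[0])
# Combining the 16 partial results requires cross-thread synchronization
# (barriers), which the commit message reports as 10x-50x slower on Adreno.
print(tvm.lower(s, [A, B], simple_mode=True))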
7 changes: 7 additions & 0 deletions python/tvm/relay/op/strategy/adreno.py
@@ -208,6 +208,13 @@ def schedule_injective_adreno(attrs, outs, target):
        return topi.adreno.schedule_injective(outs)


@schedule_reduce.register(["adreno"])
def schedule_reduce_adreno(attrs, outs, target):
    """schedule reduction ops for adreno GPU"""
    with target:
        return topi.adreno.schedule_reduce(outs)


@concatenate_strategy.register(["adreno"])
def concatenate_strategy_adreno(attrs, inputs, out_type, target):
    strategy = _op.OpStrategy()
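For reference, a hedged end-to-end sketch of how this registration is exercised (assumes a TVM build with OpenCL codegen and LLVM enabled; the shapes mirror the tests added below). With the target set to "opencl -device=adreno", relay's op strategy dispatches reduce ops to schedule_reduce_adreno:

# Hedged sketch (not from this commit): compiling a channel mean for the
# Adreno target picks up schedule_reduce_adreno via the strategy registration.
import tvm
from tvm import relay

data = relay.var("data", shape=(1, 3, 720, 1280), dtype="float32")
func = relay.Function([data], relay.mean(data, axis=1, keepdims=True))
mod = tvm.IRModule.from_expr(func)

target = tvm.target.Target("opencl -device=adreno", host="llvm")
with tvm.transform.PassContext(opt_level=3):
    lib = relay.build(mod, target=target)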
1 change: 1 addition & 0 deletions python/tvm/topi/adreno/__init__.py
@@ -26,3 +26,4 @@
from .conv2d_nchw_winograd import *
from .conv2d_nhwc_winograd import *
from .injective import schedule_injective
from .reduction import *
69 changes: 69 additions & 0 deletions python/tvm/topi/adreno/reduction.py
@@ -0,0 +1,69 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name,unused-variable,too-many-locals,len-as-condition
"""Schedule for reduce operators"""
import numpy
from tvm import te
from ..utils import get_const_tuple
from .injective import schedule_injective_from_existing
from .utils import get_div
from ..cuda.reduction import schedule_reduce_impl


def _schedule_reduce_adreno(op, sch, is_idx_reduce=False):
    if is_idx_reduce:
        real_output = op.output(0)
        temp_idx_input = op.input_tensors[0].op.output(0)
        temp_val_input = op.input_tensors[0].op.output(1)
    else:
        real_output = op.output(0)
    shape = get_const_tuple(real_output.shape)
    latest4 = shape[-1] == 4
    div4 = numpy.prod(shape) % 4 == 0

    # Fuse and split the axis
    if latest4:
        fused_outer = sch[real_output].fuse(
            *[sch[real_output].op.axis[i] for i in range(len(sch[real_output].op.axis) - 1)]
        )
    else:
        fused_outer = sch[real_output].fuse(
            *[sch[real_output].op.axis[i] for i in range(len(sch[real_output].op.axis))]
        )

    ftc = numpy.prod(shape)
    a = fused_outer
    if latest4:
        sch[real_output].vectorize(sch[real_output].op.axis[-1])
    elif div4 and not is_idx_reduce:
        a, b = sch[real_output].split(fused_outer, factor=4)
        sch[real_output].vectorize(b)
        ftc = ftc / 4

    num_thread = get_div(ftc, 128)

    bx, outer_in = sch[real_output].split(a, factor=num_thread)

    sch[real_output].bind(bx, te.thread_axis("blockIdx.x"))
    sch[real_output].bind(outer_in, te.thread_axis("threadIdx.y"))
    if is_idx_reduce:
        sch[temp_idx_input].compute_at(sch[real_output], outer_in)
        sch[temp_val_input].compute_at(sch[real_output], outer_in)


def schedule_reduce(outs):
    return schedule_reduce_impl(outs, _schedule_reduce_adreno, schedule_injective_from_existing)
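A hedged usage sketch of the new schedule at the TOPI level (assumes a TVM build with OpenCL support; the shape is arbitrary): the output stage is fused, vectorized by 4 when the flattened size allows it, and bound to blockIdx.x/threadIdx.y as in _schedule_reduce_adreno above.

# Minimal usage sketch (assumption: TVM built with OpenCL; not taken from the
# commit itself).
import tvm
from tvm import te, topi

A = te.placeholder((1, 3, 720, 1280), name="A", dtype="float32")
B = topi.sum(A, axis=1, keepdims=True)  # tagged "comm_reduce"

target = tvm.target.Target("opencl -device=adreno")
with target:
    s = topi.adreno.schedule_reduce([B])

# Inspect the lowered IR; tvm.build(s, [A, B], target) would emit the kernel.
print(tvm.lower(s, [A, B], simple_mode=True))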
20 changes: 16 additions & 4 deletions python/tvm/topi/cuda/reduction.py
@@ -116,14 +116,22 @@ def is_scheduled(stage):
    return True


-def schedule_reduce(outs):
+def schedule_reduce_impl(outs, schedule_reduce_stage, schedule_injective_stage):
"""Schedule for inject->reduce->bcast ops.
Traverse over the stages in the schedule and schedule separate stages depending
on the position of the stage. Injecteve post-ops of reduction will be scheduled using
injection schedule, injective pre-ops of reduction will be inlined, reduction stage
will be scheduled using reduction schedule
Parameters
----------
outs: Array of Tensor
The computation graph description of reduce in the format
of an array of tensors.
schedule_reduce_stage: Function responsible for scheduling the reduction
stage
schedule_injective_stage: Function responsible for scheduling the
standalone injection stage
Returns
-------
@@ -153,7 +161,7 @@ def traverse_after_reduce(operator):
        """Internal traverse function"""
        if tag.is_broadcast(operator.tag):
            if operator not in scheduled_ops:
-                schedule_injective_from_existing(sch, operator.output(0))
+                schedule_injective_stage(sch, operator.output(0))
            for tensor in operator.input_tensors:
                if tensor.op not in scheduled_ops:
                    if enable_auto_inline:
@@ -162,13 +170,13 @@ def traverse_after_reduce(operator):
                        traverse_after_reduce(tensor.op)
        elif operator.tag == "comm_reduce":
            if operator not in scheduled_ops:
-                _schedule_reduce(operator, sch, is_idx_reduce=False)
+                schedule_reduce_stage(operator, sch, is_idx_reduce=False)
            for tensor in operator.input_tensors:
                if tensor.op not in scheduled_ops:
                    traverse_before_reduce(tensor.op)
        elif operator.tag == "comm_reduce_idx":
            if operator not in scheduled_ops:
-                _schedule_reduce(operator, sch, is_idx_reduce=True)
+                schedule_reduce_stage(operator, sch, is_idx_reduce=True)
            input_tensors = operator.input_tensors[0].op.input_tensors
            for tensor in input_tensors:
                if tensor.op not in scheduled_ops:
@@ -183,3 +191,7 @@ def traverse_after_reduce(operator):
    for out in outs:
        traverse_after_reduce(out.op)
    return sch


def schedule_reduce(outs):
    return schedule_reduce_impl(outs, _schedule_reduce, schedule_injective_from_existing)
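The refactoring turns the traversal into a shared helper: schedule_reduce_impl now takes the per-stage schedulers as callbacks, so a backend can reuse the inject->reduce->bcast traversal while supplying its own reduction and injective stage schedules, which is exactly what topi.adreno.schedule_reduce above does. A sketch of the contract a hypothetical backend would satisfy (the my_* names are illustrative, not part of TVM):

# Hedged sketch of reusing the shared traversal from another backend;
# my_schedule_reduce_stage / my_schedule_injective_stage are hypothetical names.
from tvm.topi.cuda.reduction import schedule_reduce_impl


def my_schedule_reduce_stage(op, sch, is_idx_reduce=False):
    # Schedule the reduction stage itself; for argmax/argmin (is_idx_reduce=True)
    # op.input_tensors[0] holds the paired index/value temporaries.
    ...


def my_schedule_injective_stage(sch, out):
    # Schedule a standalone injective (broadcast/elementwise) stage that
    # follows the reduction.
    return sch


def my_schedule_reduce(outs):
    return schedule_reduce_impl(outs, my_schedule_reduce_stage, my_schedule_injective_stage)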
51 changes: 51 additions & 0 deletions tests/python/relay/opencl_texture/test_reduction_texture.py
@@ -0,0 +1,51 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

import re
import tvm
import numpy as np
from tvm import relay
from tvm.relay import testing
from tvm.contrib import utils
from utils.adreno_utils import gpu_preprocess, build_run_compare


dtype = tvm.testing.parameter("float32")


@tvm.testing.requires_opencl
@tvm.testing.parametrize_targets("opencl -device=adreno")
def test_mean(target, dtype):
    # NCHW
    input_shape = (1, 3, 720, 1280)
    A = relay.var("data", shape=input_shape, dtype=dtype)
    mean = relay.mean(A, axis=1, keepdims=True)
    mod = relay.Function([A], mean)

    build_run_compare(mod, {}, {"data": input_shape}, dtype, target)


@tvm.testing.requires_opencl
@tvm.testing.parametrize_targets("opencl -device=adreno")
def test_argmax(target, dtype):
    # NCHW
    input_shape = (1, 3, 720, 1280)
    A = relay.var("data", shape=input_shape, dtype=dtype)
    argmax = relay.op.argmax(A, axis=[1])
    mod = relay.Function([A], argmax)

    build_run_compare(mod, {}, {"data": input_shape}, dtype, target)
