Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Topi][Unittests] Parametrized tests in test_topi_dense.py, split out gpu-independent implementations #8336

Merged
merged 3 commits into from
Jun 30, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 12 additions & 7 deletions python/tvm/relay/op/strategy/cuda.py
Original file line number Diff line number Diff line change
Expand Up @@ -705,24 +705,29 @@ def dense_strategy_cuda(attrs, inputs, out_type, target):
data, weights = inputs
b, i = get_const_tuple(data.shape)
o, _ = get_const_tuple(weights.shape)
if data.dtype == "int8" and weights.dtype == "int8" and out_type.dtype == "int32":
if (
target.kind.name == "cuda"
and data.dtype == "int8"
and weights.dtype == "int8"
and out_type.dtype == "int32"
):
strategy.add_implementation(
wrap_compute_dense(topi.cuda.dense_int8),
wrap_topi_schedule(topi.cuda.schedule_dense_int8),
name="dense_int8.cuda",
)
else:
strategy.add_implementation(
wrap_compute_dense(topi.cuda.dense_small_batch),
wrap_topi_schedule(topi.cuda.schedule_dense_small_batch),
name="dense_small_batch.cuda",
wrap_compute_dense(topi.gpu.dense_small_batch),
wrap_topi_schedule(topi.gpu.schedule_dense_small_batch),
name="dense_small_batch.gpu",
)

with SpecializedCondition(b >= 32):
strategy.add_implementation(
wrap_compute_dense(topi.cuda.dense_large_batch),
wrap_topi_schedule(topi.cuda.schedule_dense_large_batch),
name="dense_large_batch.cuda",
wrap_compute_dense(topi.gpu.dense_large_batch),
wrap_topi_schedule(topi.gpu.schedule_dense_large_batch),
name="dense_large_batch.gpu",
plevel=5,
)
if target.kind.name == "cuda":
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -414,7 +414,7 @@ def _get_targets(target_str=None):


DEFAULT_TEST_TARGETS = (
"llvm;cuda;opencl;metal;rocm;vulkan;nvptx;"
"llvm;cuda;opencl;metal;rocm;vulkan -from_device=0;nvptx;"
"llvm -device=arm_cpu;opencl -device=mali,aocl_sw_emu"
)

Expand Down
1 change: 1 addition & 0 deletions python/tvm/topi/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
from . import nn
from . import x86
from . import cuda
from . import gpu
from . import arm_cpu
from . import mali
from . import bifrost
Expand Down
191 changes: 0 additions & 191 deletions python/tvm/topi/cuda/dense.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,8 @@
import logging
from tvm import te
import tvm.autotvm as autotvm
from tvm.autotvm.task.space import SplitEntity
from tvm.contrib import cublas
from .tensor_intrin import dp4a
from .. import nn
from .. import tag
from .. import generic
from ..utils import traverse_inline, get_const_tuple
Expand Down Expand Up @@ -57,195 +55,6 @@ def schedule_dense_cublas(_, outs):
return generic.schedule_extern(outs)


@autotvm.register_topi_compute("dense_small_batch.cuda")
def dense_small_batch(cfg, data, weight, bias=None, out_dtype=None):
"""Dense operator on CUDA"""
return nn.dense(data, weight, bias, out_dtype)


@autotvm.register_topi_schedule("dense_small_batch.cuda")
def schedule_dense_small_batch(cfg, outs):
"""Schedule float32/64 dense with small batch size"""
outs = [outs] if isinstance(outs, te.tensor.Tensor) else outs
s = te.create_schedule([x.op for x in outs])

def _callback(op):
if op.tag == "dense":
_schedule_dense_small_batch(cfg, s, op.output(0))

traverse_inline(s, outs[0].op, _callback)
return s


def _schedule_dense_small_batch(cfg, s, C):
A, weights = C.op.input_tensors
_, in_dim_weights = get_const_tuple(weights.shape)
_, in_dim_A = get_const_tuple(A.shape)

if isinstance(in_dim_A, int):
in_dim = in_dim_A
elif isinstance(in_dim_weights, int):
in_dim = in_dim_weights
else:
in_dim = None

if in_dim is not None:
cfg.define_split("tile_k", in_dim, num_outputs=2)
if cfg.is_fallback:
cfg["tile_k"] = SplitEntity([-1, 64] if in_dim > 64 else [1, 64])
_, kf = cfg["tile_k"].apply(s, C, C.op.reduce_axis[0])
else:
tile_k = 64
_, kf = s[C].split(C.op.reduce_axis[0], tile_k)

CF = s.rfactor(C, kf)

if C.op in s.outputs:
Out = C
else:
Out = s.outputs[0].output(0)
s[C].compute_at(s[Out], s[Out].op.axis[1])
s[Out].bind(s[Out].op.axis[0], te.thread_axis("blockIdx.y"))
s[Out].bind(s[Out].op.axis[1], te.thread_axis("blockIdx.x"))

tx = s[C].op.reduce_axis[0]
thread_x = te.thread_axis("threadIdx.x")
s[C].bind(tx, thread_x)
s[CF].compute_at(s[C], tx)
s[C].set_store_predicate(thread_x.var.equal(0))
s[Out].set_store_predicate(thread_x.var.equal(0))


@autotvm.register_topi_compute("dense_large_batch.cuda")
def dense_large_batch(cfg, data, weight, bias=None, out_dtype=None):
"""Dense operator on CUDA"""
return nn.dense(data, weight, bias, out_dtype)


@autotvm.register_topi_schedule("dense_large_batch.cuda")
def schedule_dense_large_batch(cfg, outs):
"""Schedule float32/64 dense with large batch size"""
outs = [outs] if isinstance(outs, te.tensor.Tensor) else outs
s = te.create_schedule([x.op for x in outs])

def _callback(op):
if op.tag == "dense":
_schedule_dense_large_batch(cfg, s, op.output(0))

traverse_inline(s, outs[0].op, _callback)
return s


def _schedule_dense_large_batch(cfg, s, C):
"""Schedule float32/64 dense with large batch size"""
A, B = C.op.input_tensors
batch, in_dim = get_const_tuple(A.shape)
out_dim, _ = get_const_tuple(B.shape)
k = C.op.reduce_axis[0]

# create tuning space
try:
block_cand = [64, 128]
vthread_cand = [2 ** x for x in range(1, 7)]
n_thread_cand = [2 ** x for x in range(3, 7)]
cfg.define_split(
"tile_x",
batch,
num_outputs=4,
filter=lambda x: (
x.size[1] in vthread_cand
and x.size[2] in n_thread_cand
and (x.size[1] * x.size[2] * x.size[3]) in block_cand
),
)
cfg.define_split(
"tile_y",
out_dim,
num_outputs=4,
filter=lambda x: (
x.size[1] in vthread_cand
and x.size[2] in n_thread_cand
and (x.size[1] * x.size[2] * x.size[3]) in block_cand
),
)
cfg.define_split("tile_k", in_dim, num_outputs=3, filter=lambda x: x.size[0] > 2)
except IndexError:
# Index error happens when no entities left after filtering, which was designed
# to prune tuning space for better search efficiency.
logger.debug("Tuning space was created without pruning due to unfit shapes")
cfg.define_split("tile_x", batch, num_outputs=4)
cfg.define_split("tile_y", out_dim, num_outputs=4)
cfg.define_split("tile_k", in_dim, num_outputs=3)

if cfg.is_fallback:
if batch > 1:
cfg["tile_x"] = SplitEntity([-1, 2, 16, 2])
else:
cfg["tile_x"] = SplitEntity([1, 1, 1, 1])
if out_dim > 1:
cfg["tile_y"] = SplitEntity([-1, 2, 16, 2])
else:
cfg["tile_y"] = SplitEntity([1, 1, 1, 1])
if in_dim > 8:
cfg["tile_k"] = SplitEntity([-1, 8, 1])
else:
cfg["tile_k"] = SplitEntity([-1, 1, 1])

# Explicit memory access
AA = s.cache_read(A, "shared", [C])
BB = s.cache_read(B, "shared", [C])
AL = s.cache_read(AA, "local", [C])
BL = s.cache_read(BB, "local", [C])
CC = s.cache_write(C, "local")

# Deal with op fusion
if C.op not in s.outputs:
s[C].compute_inline()
C = s.outputs[0].output(0)

# Split and reorder computation
bx, txz, tx, xi = cfg["tile_x"].apply(s, C, C.op.axis[0])
by, tyz, ty, yi = cfg["tile_y"].apply(s, C, C.op.axis[1])
s[C].reorder(by, bx, tyz, txz, ty, tx, yi, xi)
s[CC].compute_at(s[C], tx)

# Binding
s[C].bind(by, te.thread_axis("blockIdx.y"))
s[C].bind(bx, te.thread_axis("blockIdx.x"))
s[C].bind(tyz, te.thread_axis("vthread"))
s[C].bind(txz, te.thread_axis("vthread"))
s[C].bind(ty, te.thread_axis("threadIdx.y"))
s[C].bind(tx, te.thread_axis("threadIdx.x"))

# Split reduction
yo, xo = CC.op.axis
ko, kt, ki = cfg["tile_k"].apply(s, CC, k)
s[CC].reorder(ko, kt, ki, yo, xo)
s[AA].compute_at(s[CC], ko)
s[BB].compute_at(s[CC], ko)
s[CC].unroll(kt)
s[AL].compute_at(s[CC], kt)
s[BL].compute_at(s[CC], kt)

# Schedule for A's shared memory load
num_thread_x = cfg["tile_x"].size[2]
ty, _ = s[AA].split(s[AA].op.axis[0], nparts=num_thread_x)
_, xi = s[AA].split(s[AA].op.axis[1], factor=num_thread_x * 4)
tx, xi = s[AA].split(xi, nparts=num_thread_x)
s[AA].bind(ty, te.thread_axis("threadIdx.y"))
s[AA].bind(tx, te.thread_axis("threadIdx.x"))
s[AA].double_buffer()

# Schedule for B' shared memory load
num_thread_y = cfg["tile_y"].size[2]
ty, _ = s[BB].split(s[BB].op.axis[0], nparts=num_thread_y)
_, xi = s[BB].split(s[BB].op.axis[1], factor=num_thread_y * 4)
tx, xi = s[BB].split(xi, nparts=num_thread_y)
s[BB].bind(ty, te.thread_axis("threadIdx.y"))
s[BB].bind(tx, te.thread_axis("threadIdx.x"))
s[BB].double_buffer()


@autotvm.register_topi_compute("dense_int8.cuda")
def dense_int8(cfg, data, weight, bias=None, out_dtype=None):
"""Dense operator for int8 on CUDA"""
Expand Down
20 changes: 20 additions & 0 deletions python/tvm/topi/gpu/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

# pylint: disable=redefined-builtin, wildcard-import
"""GPU specific declaration and schedules."""
from .dense import *
Loading