[Topi][Unittests] Parametrized tests in test_topi_dense.py, split out gpu-independent implementations (apache#8336)

* [Topi][UnitTests] Parametrized tests in test_topi_dense.py

Tests now run for multiple data types and can be extended with
additional data types.
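
For reference, a minimal sketch of the parametrization pattern, mirroring the
tvm.testing helpers used in the updated test file below. The test body here is
a hypothetical placeholder, not part of this change.

```python
import tvm.testing

# Module-level parameters: any test function that accepts an argument
# with the same name is collected once per value (or tuple of values).
use_bias = tvm.testing.parameter(True, False)
in_dtype, out_dtype = tvm.testing.parameters(
    ("float32", "float32"),
    ("int8", "int32"),
)


def test_example(use_bias, in_dtype, out_dtype):
    # Collected by pytest as 2 (use_bias) x 2 (dtype pairs) = 4 cases.
    assert out_dtype in ("float32", "int32")
```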

* [Topi] Separated generic-gpu nn.dense implementations into topi.gpu.dense

As a follow-up to the renaming of "gpu" to "cuda", this separates
implementations that require CUDA specifically (e.g. dense_cublas.cuda)
from implementations that require a GPU, but not necessarily a CUDA GPU
(e.g. dense_small_batch.gpu).

My intent is to pair this migration with the extension of unit tests
to cover additional GPU runtimes, migrating only implementations that
run correctly on non-CUDA GPU devices.
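
A minimal usage sketch of the relocated implementations, assuming a TVM build
with the corresponding GPU codegen enabled. The "vulkan" target is only
illustrative; any GPU target supported by these schedules can be substituted.

```python
import tvm
from tvm import te, topi

A = te.placeholder((1, 1024), name="A")
B = te.placeholder((1000, 1024), name="B")

# The dense compute/schedule formerly under topi.cuda now live in
# topi.gpu and are not tied to CUDA.
target = tvm.target.Target("vulkan")
with target:
    D = topi.gpu.dense_small_batch(A, B)
    s = topi.gpu.schedule_dense_small_batch([D])

# Compile for the chosen GPU target; running it still requires a
# matching device at runtime.
f = tvm.build(s, [A, B, D], target, name="dense")
```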

* [Vulkan][Codegen] Updated storage sync to avoid incorrect matmul results on some GPUs

- In ThreadAllreduceBuilder, separate the loads from the stores so that
  a memory barrier can be placed between them.

- In the Vulkan codegen, added Workgroup memory semantics to the
  subgroup thread sync, since threads within a subgroup can still
  access workgroup-scoped memory.  Longer term, TIR enhancements may be
  needed to separate control synchronization from memory synchronization.
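
The load/store separation can be illustrated with a small NumPy simulation of
the in-warp tree reduction; this is only an illustration of the access pattern
being fixed, not TVM code.

```python
import numpy as np

# One partial result per "thread" in a warp of 8.
buf = np.arange(8, dtype=np.float32)
expected = buf.sum()

offset = len(buf)
while offset > 1:
    offset //= 2
    # All loads complete before any store, mirroring the effect of the
    # let-bound loads and the warp-level barrier added in
    # lower_thread_allreduce.cc.  Interleaving them (one thread
    # overwriting buf[i + offset] while another is still reading it) is
    # the race that produced incorrect matmul results on some GPUs.
    loaded = [buf[i] + buf[i + offset] for i in range(offset)]
    for i in range(offset):
        buf[i] = loaded[i]

assert buf[0] == expected
```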

Co-authored-by: Eric Lunderberg <[email protected]>
2 people authored and ylc committed Sep 29, 2021

1 parent b613a87 commit 5061751
Showing 10 changed files with 435 additions and 315 deletions.
19 changes: 12 additions & 7 deletions python/tvm/relay/op/strategy/cuda.py
@@ -705,24 +705,29 @@ def dense_strategy_cuda(attrs, inputs, out_type, target):
data, weights = inputs
b, i = get_const_tuple(data.shape)
o, _ = get_const_tuple(weights.shape)
if data.dtype == "int8" and weights.dtype == "int8" and out_type.dtype == "int32":
if (
target.kind.name == "cuda"
and data.dtype == "int8"
and weights.dtype == "int8"
and out_type.dtype == "int32"
):
strategy.add_implementation(
wrap_compute_dense(topi.cuda.dense_int8),
wrap_topi_schedule(topi.cuda.schedule_dense_int8),
name="dense_int8.cuda",
)
else:
strategy.add_implementation(
wrap_compute_dense(topi.cuda.dense_small_batch),
wrap_topi_schedule(topi.cuda.schedule_dense_small_batch),
name="dense_small_batch.cuda",
wrap_compute_dense(topi.gpu.dense_small_batch),
wrap_topi_schedule(topi.gpu.schedule_dense_small_batch),
name="dense_small_batch.gpu",
)

with SpecializedCondition(b >= 32):
strategy.add_implementation(
wrap_compute_dense(topi.cuda.dense_large_batch),
wrap_topi_schedule(topi.cuda.schedule_dense_large_batch),
name="dense_large_batch.cuda",
wrap_compute_dense(topi.gpu.dense_large_batch),
wrap_topi_schedule(topi.gpu.schedule_dense_large_batch),
name="dense_large_batch.gpu",
plevel=5,
)
if target.kind.name == "cuda":
2 changes: 1 addition & 1 deletion python/tvm/testing.py
@@ -414,7 +414,7 @@ def _get_targets(target_str=None):


DEFAULT_TEST_TARGETS = (
"llvm;cuda;opencl;metal;rocm;vulkan;nvptx;"
"llvm;cuda;opencl;metal;rocm;vulkan -from_device=0;nvptx;"
"llvm -device=arm_cpu;opencl -device=mali,aocl_sw_emu"
)

1 change: 1 addition & 0 deletions python/tvm/topi/__init__.py
@@ -51,6 +51,7 @@
from . import nn
from . import x86
from . import cuda
from . import gpu
from . import arm_cpu
from . import mali
from . import bifrost
191 changes: 0 additions & 191 deletions python/tvm/topi/cuda/dense.py
@@ -19,10 +19,8 @@
import logging
from tvm import te
import tvm.autotvm as autotvm
from tvm.autotvm.task.space import SplitEntity
from tvm.contrib import cublas
from .tensor_intrin import dp4a
from .. import nn
from .. import tag
from .. import generic
from ..utils import traverse_inline, get_const_tuple
@@ -57,195 +55,6 @@ def schedule_dense_cublas(_, outs):
return generic.schedule_extern(outs)


@autotvm.register_topi_compute("dense_small_batch.cuda")
def dense_small_batch(cfg, data, weight, bias=None, out_dtype=None):
"""Dense operator on CUDA"""
return nn.dense(data, weight, bias, out_dtype)


@autotvm.register_topi_schedule("dense_small_batch.cuda")
def schedule_dense_small_batch(cfg, outs):
"""Schedule float32/64 dense with small batch size"""
outs = [outs] if isinstance(outs, te.tensor.Tensor) else outs
s = te.create_schedule([x.op for x in outs])

def _callback(op):
if op.tag == "dense":
_schedule_dense_small_batch(cfg, s, op.output(0))

traverse_inline(s, outs[0].op, _callback)
return s


def _schedule_dense_small_batch(cfg, s, C):
A, weights = C.op.input_tensors
_, in_dim_weights = get_const_tuple(weights.shape)
_, in_dim_A = get_const_tuple(A.shape)

if isinstance(in_dim_A, int):
in_dim = in_dim_A
elif isinstance(in_dim_weights, int):
in_dim = in_dim_weights
else:
in_dim = None

if in_dim is not None:
cfg.define_split("tile_k", in_dim, num_outputs=2)
if cfg.is_fallback:
cfg["tile_k"] = SplitEntity([-1, 64] if in_dim > 64 else [1, 64])
_, kf = cfg["tile_k"].apply(s, C, C.op.reduce_axis[0])
else:
tile_k = 64
_, kf = s[C].split(C.op.reduce_axis[0], tile_k)

CF = s.rfactor(C, kf)

if C.op in s.outputs:
Out = C
else:
Out = s.outputs[0].output(0)
s[C].compute_at(s[Out], s[Out].op.axis[1])
s[Out].bind(s[Out].op.axis[0], te.thread_axis("blockIdx.y"))
s[Out].bind(s[Out].op.axis[1], te.thread_axis("blockIdx.x"))

tx = s[C].op.reduce_axis[0]
thread_x = te.thread_axis("threadIdx.x")
s[C].bind(tx, thread_x)
s[CF].compute_at(s[C], tx)
s[C].set_store_predicate(thread_x.var.equal(0))
s[Out].set_store_predicate(thread_x.var.equal(0))


@autotvm.register_topi_compute("dense_large_batch.cuda")
def dense_large_batch(cfg, data, weight, bias=None, out_dtype=None):
"""Dense operator on CUDA"""
return nn.dense(data, weight, bias, out_dtype)


@autotvm.register_topi_schedule("dense_large_batch.cuda")
def schedule_dense_large_batch(cfg, outs):
"""Schedule float32/64 dense with large batch size"""
outs = [outs] if isinstance(outs, te.tensor.Tensor) else outs
s = te.create_schedule([x.op for x in outs])

def _callback(op):
if op.tag == "dense":
_schedule_dense_large_batch(cfg, s, op.output(0))

traverse_inline(s, outs[0].op, _callback)
return s


def _schedule_dense_large_batch(cfg, s, C):
"""Schedule float32/64 dense with large batch size"""
A, B = C.op.input_tensors
batch, in_dim = get_const_tuple(A.shape)
out_dim, _ = get_const_tuple(B.shape)
k = C.op.reduce_axis[0]

# create tuning space
try:
block_cand = [64, 128]
vthread_cand = [2 ** x for x in range(1, 7)]
n_thread_cand = [2 ** x for x in range(3, 7)]
cfg.define_split(
"tile_x",
batch,
num_outputs=4,
filter=lambda x: (
x.size[1] in vthread_cand
and x.size[2] in n_thread_cand
and (x.size[1] * x.size[2] * x.size[3]) in block_cand
),
)
cfg.define_split(
"tile_y",
out_dim,
num_outputs=4,
filter=lambda x: (
x.size[1] in vthread_cand
and x.size[2] in n_thread_cand
and (x.size[1] * x.size[2] * x.size[3]) in block_cand
),
)
cfg.define_split("tile_k", in_dim, num_outputs=3, filter=lambda x: x.size[0] > 2)
except IndexError:
# Index error happens when no entities left after filtering, which was designed
# to prune tuning space for better search efficiency.
logger.debug("Tuning space was created without pruning due to unfit shapes")
cfg.define_split("tile_x", batch, num_outputs=4)
cfg.define_split("tile_y", out_dim, num_outputs=4)
cfg.define_split("tile_k", in_dim, num_outputs=3)

if cfg.is_fallback:
if batch > 1:
cfg["tile_x"] = SplitEntity([-1, 2, 16, 2])
else:
cfg["tile_x"] = SplitEntity([1, 1, 1, 1])
if out_dim > 1:
cfg["tile_y"] = SplitEntity([-1, 2, 16, 2])
else:
cfg["tile_y"] = SplitEntity([1, 1, 1, 1])
if in_dim > 8:
cfg["tile_k"] = SplitEntity([-1, 8, 1])
else:
cfg["tile_k"] = SplitEntity([-1, 1, 1])

# Explicit memory access
AA = s.cache_read(A, "shared", [C])
BB = s.cache_read(B, "shared", [C])
AL = s.cache_read(AA, "local", [C])
BL = s.cache_read(BB, "local", [C])
CC = s.cache_write(C, "local")

# Deal with op fusion
if C.op not in s.outputs:
s[C].compute_inline()
C = s.outputs[0].output(0)

# Split and reorder computation
bx, txz, tx, xi = cfg["tile_x"].apply(s, C, C.op.axis[0])
by, tyz, ty, yi = cfg["tile_y"].apply(s, C, C.op.axis[1])
s[C].reorder(by, bx, tyz, txz, ty, tx, yi, xi)
s[CC].compute_at(s[C], tx)

# Binding
s[C].bind(by, te.thread_axis("blockIdx.y"))
s[C].bind(bx, te.thread_axis("blockIdx.x"))
s[C].bind(tyz, te.thread_axis("vthread"))
s[C].bind(txz, te.thread_axis("vthread"))
s[C].bind(ty, te.thread_axis("threadIdx.y"))
s[C].bind(tx, te.thread_axis("threadIdx.x"))

# Split reduction
yo, xo = CC.op.axis
ko, kt, ki = cfg["tile_k"].apply(s, CC, k)
s[CC].reorder(ko, kt, ki, yo, xo)
s[AA].compute_at(s[CC], ko)
s[BB].compute_at(s[CC], ko)
s[CC].unroll(kt)
s[AL].compute_at(s[CC], kt)
s[BL].compute_at(s[CC], kt)

# Schedule for A's shared memory load
num_thread_x = cfg["tile_x"].size[2]
ty, _ = s[AA].split(s[AA].op.axis[0], nparts=num_thread_x)
_, xi = s[AA].split(s[AA].op.axis[1], factor=num_thread_x * 4)
tx, xi = s[AA].split(xi, nparts=num_thread_x)
s[AA].bind(ty, te.thread_axis("threadIdx.y"))
s[AA].bind(tx, te.thread_axis("threadIdx.x"))
s[AA].double_buffer()

# Schedule for B' shared memory load
num_thread_y = cfg["tile_y"].size[2]
ty, _ = s[BB].split(s[BB].op.axis[0], nparts=num_thread_y)
_, xi = s[BB].split(s[BB].op.axis[1], factor=num_thread_y * 4)
tx, xi = s[BB].split(xi, nparts=num_thread_y)
s[BB].bind(ty, te.thread_axis("threadIdx.y"))
s[BB].bind(tx, te.thread_axis("threadIdx.x"))
s[BB].double_buffer()


@autotvm.register_topi_compute("dense_int8.cuda")
def dense_int8(cfg, data, weight, bias=None, out_dtype=None):
"""Dense operator for int8 on CUDA"""
20 changes: 20 additions & 0 deletions python/tvm/topi/gpu/__init__.py
@@ -0,0 +1,20 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

# pylint: disable=redefined-builtin, wildcard-import
"""GPU specific declaration and schedules."""
from .dense import *
218 changes: 218 additions & 0 deletions python/tvm/topi/gpu/dense.py
@@ -0,0 +1,218 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

# pylint: disable=invalid-name, unused-argument
"""Schedule for dense operator"""

import logging

from tvm import autotvm, te
from tvm.autotvm.task.space import SplitEntity

from .. import nn
from ..utils import traverse_inline, get_const_tuple

logger = logging.getLogger("topi")


@autotvm.register_topi_compute("dense_small_batch.gpu")
def dense_small_batch(cfg, data, weight, bias=None, out_dtype=None):
"""Dense operator on GPU"""
return nn.dense(data, weight, bias, out_dtype)


@autotvm.register_topi_schedule("dense_small_batch.gpu")
def schedule_dense_small_batch(cfg, outs):
"""Schedule float32/64 dense with small batch size"""
outs = [outs] if isinstance(outs, te.tensor.Tensor) else outs
s = te.create_schedule([x.op for x in outs])

def _callback(op):
if op.tag == "dense":
_schedule_dense_small_batch(cfg, s, op.output(0))

traverse_inline(s, outs[0].op, _callback)
return s


def _schedule_dense_small_batch(cfg, s, C):
A, weights = C.op.input_tensors
_, in_dim_weights = get_const_tuple(weights.shape)
_, in_dim_A = get_const_tuple(A.shape)

if isinstance(in_dim_A, int):
in_dim = in_dim_A
elif isinstance(in_dim_weights, int):
in_dim = in_dim_weights
else:
in_dim = None

if in_dim is not None:
cfg.define_split("tile_k", in_dim, num_outputs=2)
if cfg.is_fallback:
cfg["tile_k"] = SplitEntity([-1, 64] if in_dim > 64 else [1, 64])
_, kf = cfg["tile_k"].apply(s, C, C.op.reduce_axis[0])
else:
tile_k = 64
_, kf = s[C].split(C.op.reduce_axis[0], tile_k)

CF = s.rfactor(C, kf)

if C.op in s.outputs:
Out = C
else:
Out = s.outputs[0].output(0)
s[C].compute_at(s[Out], s[Out].op.axis[1])
s[Out].bind(s[Out].op.axis[0], te.thread_axis("blockIdx.y"))
s[Out].bind(s[Out].op.axis[1], te.thread_axis("blockIdx.x"))

tx = s[C].op.reduce_axis[0]
thread_x = te.thread_axis("threadIdx.x")
s[C].bind(tx, thread_x)
s[CF].compute_at(s[C], tx)
s[C].set_store_predicate(thread_x.var.equal(0))
s[Out].set_store_predicate(thread_x.var.equal(0))


@autotvm.register_topi_compute("dense_large_batch.gpu")
def dense_large_batch(cfg, data, weight, bias=None, out_dtype=None):
"""Dense operator on GPU"""
return nn.dense(data, weight, bias, out_dtype)


@autotvm.register_topi_schedule("dense_large_batch.gpu")
def schedule_dense_large_batch(cfg, outs):
"""Schedule float32/64 dense with large batch size"""
outs = [outs] if isinstance(outs, te.tensor.Tensor) else outs
s = te.create_schedule([x.op for x in outs])

def _callback(op):
if op.tag == "dense":
_schedule_dense_large_batch(cfg, s, op.output(0))

traverse_inline(s, outs[0].op, _callback)
return s


def _schedule_dense_large_batch(cfg, s, C):
"""Schedule float32/64 dense with large batch size"""
A, B = C.op.input_tensors
batch, in_dim = get_const_tuple(A.shape)
out_dim, _ = get_const_tuple(B.shape)
k = C.op.reduce_axis[0]

# create tuning space
try:
block_cand = [64, 128]
vthread_cand = [2 ** x for x in range(1, 7)]
n_thread_cand = [2 ** x for x in range(3, 7)]
cfg.define_split(
"tile_x",
batch,
num_outputs=4,
filter=lambda x: (
x.size[1] in vthread_cand
and x.size[2] in n_thread_cand
and (x.size[1] * x.size[2] * x.size[3]) in block_cand
),
)
cfg.define_split(
"tile_y",
out_dim,
num_outputs=4,
filter=lambda x: (
x.size[1] in vthread_cand
and x.size[2] in n_thread_cand
and (x.size[1] * x.size[2] * x.size[3]) in block_cand
),
)
cfg.define_split("tile_k", in_dim, num_outputs=3, filter=lambda x: x.size[0] > 2)
except IndexError:
# Index error happens when no entities left after filtering, which was designed
# to prune tuning space for better search efficiency.
logger.debug("Tuning space was created without pruning due to unfit shapes")
cfg.define_split("tile_x", batch, num_outputs=4)
cfg.define_split("tile_y", out_dim, num_outputs=4)
cfg.define_split("tile_k", in_dim, num_outputs=3)

if cfg.is_fallback:
if batch > 1:
cfg["tile_x"] = SplitEntity([-1, 2, 16, 2])
else:
cfg["tile_x"] = SplitEntity([1, 1, 1, 1])
if out_dim > 1:
cfg["tile_y"] = SplitEntity([-1, 2, 16, 2])
else:
cfg["tile_y"] = SplitEntity([1, 1, 1, 1])
if in_dim > 8:
cfg["tile_k"] = SplitEntity([-1, 8, 1])
else:
cfg["tile_k"] = SplitEntity([-1, 1, 1])

# Explicit memory access
AA = s.cache_read(A, "shared", [C])
BB = s.cache_read(B, "shared", [C])
AL = s.cache_read(AA, "local", [C])
BL = s.cache_read(BB, "local", [C])
CC = s.cache_write(C, "local")

# Deal with op fusion
if C.op not in s.outputs:
s[C].compute_inline()
C = s.outputs[0].output(0)

# Split and reorder computation
bx, txz, tx, xi = cfg["tile_x"].apply(s, C, C.op.axis[0])
by, tyz, ty, yi = cfg["tile_y"].apply(s, C, C.op.axis[1])
s[C].reorder(by, bx, tyz, txz, ty, tx, yi, xi)
s[CC].compute_at(s[C], tx)

# Binding
s[C].bind(by, te.thread_axis("blockIdx.y"))
s[C].bind(bx, te.thread_axis("blockIdx.x"))
s[C].bind(tyz, te.thread_axis("vthread"))
s[C].bind(txz, te.thread_axis("vthread"))
s[C].bind(ty, te.thread_axis("threadIdx.y"))
s[C].bind(tx, te.thread_axis("threadIdx.x"))

# Split reduction
yo, xo = CC.op.axis
ko, kt, ki = cfg["tile_k"].apply(s, CC, k)
s[CC].reorder(ko, kt, ki, yo, xo)
s[AA].compute_at(s[CC], ko)
s[BB].compute_at(s[CC], ko)
s[CC].unroll(kt)
s[AL].compute_at(s[CC], kt)
s[BL].compute_at(s[CC], kt)

# Schedule for A's shared memory load
num_thread_x = cfg["tile_x"].size[2]
ty, _ = s[AA].split(s[AA].op.axis[0], nparts=num_thread_x)
_, xi = s[AA].split(s[AA].op.axis[1], factor=num_thread_x * 4)
tx, xi = s[AA].split(xi, nparts=num_thread_x)
s[AA].bind(ty, te.thread_axis("threadIdx.y"))
s[AA].bind(tx, te.thread_axis("threadIdx.x"))
s[AA].double_buffer()

# Schedule for B' shared memory load
num_thread_y = cfg["tile_y"].size[2]
ty, _ = s[BB].split(s[BB].op.axis[0], nparts=num_thread_y)
_, xi = s[BB].split(s[BB].op.axis[1], factor=num_thread_y * 4)
tx, xi = s[BB].split(xi, nparts=num_thread_y)
s[BB].bind(ty, te.thread_axis("threadIdx.y"))
s[BB].bind(tx, te.thread_axis("threadIdx.x"))
s[BB].double_buffer()
28 changes: 19 additions & 9 deletions src/target/spirv/codegen_spirv.cc
@@ -63,6 +63,7 @@ runtime::VulkanShader CodeGenSPIRV::BuildFunction(const PrimFunc& f, const std::
}
spirv::Value arg_value = builder_->BufferArgument(builder_->GetSType(value_storage_type),
descriptor_set, num_buffer);
builder_->SetName(arg_value, arg->name_hint);
storage_info_[arg.get()].UpdateContentType(value_storage_type);
var_map_[arg.get()] = arg_value;
} else {
@@ -144,15 +145,21 @@ spirv::Value CodeGenSPIRV::CreateStorageSync(const CallNode* op) {
uint32_t vulkan_api_version = spirv_support_.vulkan_api_version;

int64_t sync_scope;
int64_t memory_semantics;
int64_t memory_semantics = spv::MemorySemanticsSequentiallyConsistentMask;
if ((sync == "warp") && (vulkan_api_version >= VK_API_VERSION_1_1)) {
// Synchronize control at the Subgroup level, but memory at the
// Workgroup level. This is because different invocations in a
// subgroup may have each modified memory that exists at the
// workgroup scope. This should be changed if/when tir exposes
// more information as to which memory access needs to be
// synchronized.
sync_scope = spv::ScopeSubgroup;
memory_semantics =
spv::MemorySemanticsSequentiallyConsistentMask | spv::MemorySemanticsSubgroupMemoryMask;
memory_semantics |=
spv::MemorySemanticsSubgroupMemoryMask | spv::MemorySemanticsWorkgroupMemoryMask;

} else if ((sync == "shared") || (sync == "warp")) {
sync_scope = spv::ScopeWorkgroup;
memory_semantics =
spv::MemorySemanticsSequentiallyConsistentMask | spv::MemorySemanticsWorkgroupMemoryMask;
memory_semantics |= spv::MemorySemanticsWorkgroupMemoryMask;
} else {
LOG(FATAL) << "Do not support sync " << sync;
}
@@ -161,6 +168,7 @@ spirv::Value CodeGenSPIRV::CreateStorageSync(const CallNode* op) {
builder_->MakeInst(spv::OpControlBarrier, builder_->IntImm(type_int, sync_scope),
builder_->IntImm(type_int, sync_scope),
builder_->IntImm(type_int, memory_semantics));

return value;
}

@@ -642,14 +650,16 @@ void CodeGenSPIRV::VisitStmt_(const AllocateNode* op) {
if (info.scope.rank == runtime::StorageRank::kLocal) {
buf =
builder_->Allocate(etype, static_cast<uint32_t>(constant_size), spv::StorageClassFunction);
} else {
// shared memory
ICHECK(info.scope.rank == runtime::StorageRank::kShared)
<< "Can only allocate shared or local memory inside kernel";
} else if (info.scope.rank == runtime::StorageRank::kShared) {
// Shared memory
buf =
builder_->Allocate(etype, static_cast<uint32_t>(constant_size), spv::StorageClassWorkgroup);
} else {
LOG(FATAL) << "Can only allocate shared or local memory inside kernel";
}

builder_->SetName(buf, op->buffer_var->name_hint);

ICHECK(!info.content_fixed);
info.UpdateContentType(op->dtype);
ICHECK(!var_map_.count(op->buffer_var.get()));
57 changes: 48 additions & 9 deletions src/tir/transforms/lower_thread_allreduce.cc
@@ -388,7 +388,7 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
size_t size = shared_bufs.size();
PrimExpr buf_index = BufIndex(reduce_index, group_index, reduce_extent);
// make reduction
auto freduce = [&](int offset) {
auto fload = [&](int offset) {
Array<PrimExpr> a, b;
for (size_t i = 0; i < size; ++i) {
b.push_back(Load(types[i], shared_bufs[i],
@@ -397,12 +397,19 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
a.push_back(Load(types[i], shared_bufs[i], buf_index, const_true()));
}
Array<PrimExpr> ret = (*combiner)(a, b);
return ret;
};
auto fstore = [&](const Array<PrimExpr>& ret) {
std::vector<Stmt> stores(size);
for (size_t i = 0; i < size; ++i) {
stores[i] = Store(shared_bufs[i], ret[i], buf_index, const_true());
}
return SeqStmt::Flatten(stores);
};
auto freduce = [&](int offset) {
auto ret = fload(offset);
return fstore(ret);
};
// Step one, check for
if (reduce_align > reduce_extent) {
// reduction with the boundary condition
@@ -420,15 +427,47 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
seq.emplace_back(SyncThread("shared"));
}
// in warp synchronization.
std::vector<Stmt> in_warp_seq;
PrimExpr in_warp_cond = reduce_index < (reduce_align >> 1);
while (reduce_align > 1) {
reduce_align = reduce_align >> 1;
in_warp_seq.emplace_back(freduce(reduce_align));
in_warp_seq.emplace_back(SyncThread("warp"));
}
if (in_warp_seq.size() != 0) {
if (reduce_align > 1) {
PrimExpr in_warp_cond = reduce_index < (reduce_align >> 1);

std::vector<Stmt> in_warp_seq;

while (reduce_align > 1) {
reduce_align = reduce_align >> 1;

// freduce can read/write to the same memory location. For
// example, with reduce_align of 4, threadIdx 3 reads from
// memory location 7 as threadIdx 7 is writing to it.
// Therefore, we need to separate out the load from the store
// with a memory barrier in-between. This isn't necessary for
// the earlier normal synchronization, because those are each
// protected by an if-statement. The if-statement is avoided
// here to reduce thread divergence.
auto loads = fload(reduce_align);

Array<Var> in_warp_local_vars;
for (auto expr : loads) {
Var var(
"w_" + std::to_string(reduce_align) + "_" + std::to_string(in_warp_local_vars.size()),
expr->dtype);
in_warp_local_vars.push_back(var);
}

std::vector<Stmt> in_let_statement;
in_let_statement.emplace_back(SyncThread("warp"));
in_let_statement.emplace_back(
fstore({in_warp_local_vars.begin(), in_warp_local_vars.end()}));
in_let_statement.emplace_back(SyncThread("warp"));

Stmt body = SeqStmt::Flatten(in_let_statement);
for (size_t i = 0; i < size; i++) {
body = LetStmt(in_warp_local_vars[i], loads[i], body);
}
in_warp_seq.push_back(body);
}

Stmt warp_body = SeqStmt::Flatten(in_warp_seq);

seq.emplace_back(IfThenElse(in_warp_cond, warp_body));
seq.emplace_back(SyncThread("shared"));
}
2 changes: 1 addition & 1 deletion tests/python/relay/test_autotvm_task_extraction.py
@@ -115,7 +115,7 @@ def get_net(batch, in_dim, out_dim, dtype, out_dtype):

mod, params = get_net(1, 16, 32, "float32", "float32")
tasks = autotvm.task.extract_from_program(mod, target=target, params=params, ops=(dense,))
assert len(tasks) == 1 and tasks[0].name == "dense_small_batch.cuda"
assert len(tasks) == 1 and tasks[0].name == "dense_small_batch.gpu"

mod, params = get_net(1, 16, 32, "int8", "int32")
tasks = autotvm.task.extract_from_program(mod, target=target, params=params, ops=(dense,))
212 changes: 115 additions & 97 deletions tests/python/topi/python/test_topi_dense.py
@@ -15,26 +15,37 @@
# specific language governing permissions and limitations
# under the License.
"""Test code for dense operator"""
import contextlib
import numpy as np
import pytest
import sys

import tvm
from tvm import te
from tvm import topi
import tvm.testing
import tvm.topi.testing
from tvm import te, topi
from tvm.topi.utils import get_const_tuple
from tvm.contrib.pickle_memoize import memoize

from common import Int8Fallback
import tvm.testing

_dense_implement = {
use_bias = tvm.testing.parameter(True, False)
batch_size = tvm.testing.parameter(1, 2, 128)
in_dim, out_dim = tvm.testing.parameters((1024, 1000))
in_dtype, out_dtype = tvm.testing.parameters(
("float32", "float32"),
("int8", "int32"),
)


_dense_implementations = {
"generic": [(topi.nn.dense, topi.generic.schedule_dense)],
"cpu": [
(topi.x86.dense_nopack, topi.x86.schedule_dense_nopack),
(topi.x86.dense_pack, topi.x86.schedule_dense_pack),
],
"gpu": [
(topi.cuda.dense_small_batch, topi.cuda.schedule_dense_small_batch),
(topi.cuda.dense_large_batch, topi.cuda.schedule_dense_large_batch),
(topi.gpu.dense_small_batch, topi.gpu.schedule_dense_small_batch),
(topi.gpu.dense_large_batch, topi.gpu.schedule_dense_large_batch),
],
"mali": [(topi.mali.dense, topi.mali.schedule_dense)],
"bifrost": [(topi.bifrost.dense, topi.bifrost.schedule_dense)],
@@ -43,108 +54,115 @@
}


def verify_dense(batch, in_dim, out_dim, use_bias=True):
A = te.placeholder((batch, in_dim), name="A")
B = te.placeholder((out_dim, in_dim), name="B")
C = te.placeholder((out_dim,), name="C")
dtype = A.dtype

# use memoize to pickle the test data for next time use
@memoize("topi.tests.test_topi_dense")
def get_ref_data():
a_np = np.random.uniform(size=(batch, in_dim)).astype(dtype)
b_np = np.random.uniform(size=(out_dim, in_dim)).astype(dtype)
c_np = np.random.uniform(size=(out_dim,)).astype(dtype)
if use_bias:
d_np = np.maximum(np.dot(a_np, b_np.T) + c_np, 0.0)
else:
d_np = np.maximum(np.dot(a_np, b_np.T), 0.0)
return (a_np, b_np, c_np, d_np)

# get the test data
a_np, b_np, c_np, d_np = get_ref_data()

def check_device(device, dev):
print("Running on target: %s" % device)
for fcompute, fschedule in tvm.topi.testing.dispatch(device, _dense_implement):
with tvm.target.Target(device):
D = fcompute(A, B, C if use_bias else None)
D = topi.nn.relu(D)
s = fschedule([D])
a = tvm.nd.array(a_np, dev)
b = tvm.nd.array(b_np, dev)
c = tvm.nd.array(c_np, dev)
d = tvm.nd.array(np.zeros(get_const_tuple(D.shape), dtype=dtype), dev)
f = tvm.build(s, [A, B, C, D], device, name="dense")
f(a, b, c, d)
tvm.testing.assert_allclose(d.numpy(), d_np, rtol=1e-5)

for device, dev in tvm.testing.enabled_targets():
check_device(device, dev)


def verify_dense_int8(batch, in_dim, out_dim, use_bias=True):
dtype = "int8"
out_dtype = "int32"
A = te.placeholder((batch, in_dim), name="A", dtype=dtype)
B = te.placeholder((out_dim, in_dim), name="B", dtype=dtype)
@tvm.testing.fixture(cache_return_value=True)
def dense_ref_data(batch_size, in_dim, out_dim, use_bias, in_dtype, out_dtype):
if "float" in in_dtype:
a_np = np.random.uniform(size=(batch_size, in_dim)).astype(in_dtype)
b_np = np.random.uniform(size=(out_dim, in_dim)).astype(in_dtype)
c_np = np.random.uniform(size=(out_dim,)).astype(out_dtype)
elif in_dtype == "int8":
a_np = np.random.randint(low=-128, high=127, size=(batch_size, in_dim)).astype(in_dtype)
b_np = np.random.randint(low=-128, high=127, size=(out_dim, in_dim)).astype(in_dtype)
c_np = np.random.randint(low=-128, high=127, size=(out_dim,)).astype(out_dtype)
else:
raise ValueError("No method to generate test data for data type '{}'".format(in_dtype))

matmul = np.dot(a_np.astype(out_dtype), b_np.T.astype(out_dtype))

if use_bias:
matmul += c_np

d_np = np.maximum(matmul, 0)
return (a_np, b_np, c_np, d_np)


def test_dense(
target,
dev,
batch_size,
in_dim,
out_dim,
use_bias,
dense_ref_data,
in_dtype,
out_dtype,
implementations=None,
):
target = tvm.target.Target(target)

if (
in_dtype == "int8"
and target.kind.name == "cuda"
and not tvm.contrib.nvcc.have_int8(dev.compute_version)
):
pytest.xfail("CUDA int8 intrinsics not available")

if (
in_dtype == "int8"
and target.kind.name == "vulkan"
and not target.attrs.get("supports_int8", False)
):
pytest.xfail("Vulkan int8 driver support not available")

if (
target.kind.name not in ["llvm", "c"]
and len(set(target.keys) & set(_dense_implementations)) == 0
):
pytest.xfail("No implementation for tvm.topi.testing.dispatch to find")

A = te.placeholder((batch_size, in_dim), name="A", dtype=in_dtype)
B = te.placeholder((out_dim, in_dim), name="B", dtype=in_dtype)
C = te.placeholder((out_dim,), name="C", dtype=out_dtype)

# use memoize to pickle the test data for next time use
@memoize("topi.tests.test_topi_dense_int8")
def get_ref_data():
a_np = np.random.randint(low=-128, high=127, size=(batch, in_dim)).astype(dtype)
b_np = np.random.randint(low=-128, high=127, size=(out_dim, in_dim)).astype(dtype)
c_np = np.random.randint(low=-128, high=127, size=(out_dim,)).astype(out_dtype)
d_np = np.dot(a_np.astype(out_dtype), b_np.T.astype(out_dtype))
if use_bias:
d_np += c_np
d_np = np.maximum(d_np, 0.0)
return (a_np, b_np, c_np, d_np)

# get the test data
a_np, b_np, c_np, d_np = get_ref_data()

def check_device(device):
dev = tvm.device(device, 0)
if device == "cuda" and not tvm.contrib.nvcc.have_int8(dev.compute_version):
print("Skip because int8 intrinsics are not available")
return

print("Running on target: %s" % device)
with tvm.target.Target(device):
D = topi.cuda.dense_int8(A, B, C if use_bias else None, out_dtype)
a_np, b_np, c_np, d_np = dense_ref_data

if implementations is None:
implementations = tvm.topi.testing.dispatch(target, _dense_implementations)

for fcompute, fschedule in implementations:
with tvm.target.Target(target):
D = fcompute(A, B, C if use_bias else None, out_dtype)
D = topi.nn.relu(D)
s = topi.cuda.schedule_dense_int8([D])
s = fschedule([D])

a = tvm.nd.array(a_np, dev)
b = tvm.nd.array(b_np, dev)
c = tvm.nd.array(c_np, dev)
d = tvm.nd.array(np.zeros(get_const_tuple(D.shape), dtype=out_dtype), dev)
f = tvm.build(s, [A, B, C, D], device, name="dense")
f = tvm.build(s, [A, B, C, D], target, name="dense")
f(a, b, c, d)
tvm.testing.assert_allclose(d.numpy(), d_np, rtol=1e-5)

for device in ["cuda"]:
check_device(device)


@tvm.testing.uses_gpu
def test_dense():
verify_dense(1, 1024, 1000, use_bias=True)
verify_dense(1, 1024, 1000, use_bias=False)
verify_dense(2, 1024, 1000, use_bias=True)
verify_dense(128, 1024, 1000, use_bias=False)
verify_dense(128, 1024, 1000, use_bias=True)


@tvm.testing.requires_cuda
@tvm.testing.requires_gpu
def test_dense_int8():
@pytest.mark.parametrize("target,in_dtype,out_dtype", [("cuda", "int8", "int32")])
def test_dense_cuda_int8(
target,
dev,
batch_size,
in_dim,
out_dim,
use_bias,
dense_ref_data,
in_dtype,
out_dtype,
):
implementations = [
(topi.cuda.dense_int8, topi.cuda.schedule_dense_int8),
]
with Int8Fallback():
verify_dense_int8(2, 1024, 1000, use_bias=True)
verify_dense_int8(2, 1024, 1000, use_bias=False)
test_dense(
target,
dev,
batch_size,
in_dim,
out_dim,
use_bias,
dense_ref_data,
in_dtype,
out_dtype,
implementations=implementations,
)


if __name__ == "__main__":
test_dense()
test_dense_int8()
sys.exit(pytest.main(sys.argv))
