From 5061751fccbed484c982a0e3b654744a12022e34 Mon Sep 17 00:00:00 2001 From: Lunderberg Date: Tue, 29 Jun 2021 22:59:16 -0700 Subject: [PATCH] [Topi][Unittests] Parametrized tests in `test_topi_dense.py`, split out gpu-independent implementations (#8336) * [Topi][UnitTests] Parametrized tests in test_topi_dense.py Now, tests run for multiple data types, can be extended with additional datatypes. * [Topi] Separated generic-gpu nn.dense implementations into topi.gpu.dense As a follow-up to the renaming of "gpu" to "cuda", separating implementations that require CUDA (e.g. dense_cublas.cuda) from implementations that require any GPU, but not necessarily a CUDA GPU (e.g. dense_small_batch.gpu). My intent is to pair this migration with the extension of unit tests to cover additional GPU runtimes, migrating only implementations that run correctly on non-CUDA GPU devices. * [Vulkan][Codegen] Updated storage sync to avoid incorrect matmul results on some GPUs - In ThreadAllreduceBuilder, separate out load/store so that they can have a memory barrier in-between. - In Vulkan codegen, added Workgroup memory sync for subgroup thread sync, since the different subgroup threads can still access workgroup memory. Longer-term, may need tir enhancements to separate out sync of control/memory. Co-authored-by: Eric Lunderberg --- python/tvm/relay/op/strategy/cuda.py | 19 +- python/tvm/testing.py | 2 +- python/tvm/topi/__init__.py | 1 + python/tvm/topi/cuda/dense.py | 191 --------------- python/tvm/topi/gpu/__init__.py | 20 ++ python/tvm/topi/gpu/dense.py | 218 ++++++++++++++++++ src/target/spirv/codegen_spirv.cc | 28 ++- src/tir/transforms/lower_thread_allreduce.cc | 57 ++++- .../relay/test_autotvm_task_extraction.py | 2 +- tests/python/topi/python/test_topi_dense.py | 212 +++++++++-------- 10 files changed, 435 insertions(+), 315 deletions(-) create mode 100644 python/tvm/topi/gpu/__init__.py create mode 100644 python/tvm/topi/gpu/dense.py diff --git a/python/tvm/relay/op/strategy/cuda.py b/python/tvm/relay/op/strategy/cuda.py index 6418f1f96b3b..683f3ecdb22b 100644 --- a/python/tvm/relay/op/strategy/cuda.py +++ b/python/tvm/relay/op/strategy/cuda.py @@ -705,7 +705,12 @@ def dense_strategy_cuda(attrs, inputs, out_type, target): data, weights = inputs b, i = get_const_tuple(data.shape) o, _ = get_const_tuple(weights.shape) - if data.dtype == "int8" and weights.dtype == "int8" and out_type.dtype == "int32": + if ( + target.kind.name == "cuda" + and data.dtype == "int8" + and weights.dtype == "int8" + and out_type.dtype == "int32" + ): strategy.add_implementation( wrap_compute_dense(topi.cuda.dense_int8), wrap_topi_schedule(topi.cuda.schedule_dense_int8), @@ -713,16 +718,16 @@ def dense_strategy_cuda(attrs, inputs, out_type, target): ) else: strategy.add_implementation( - wrap_compute_dense(topi.cuda.dense_small_batch), - wrap_topi_schedule(topi.cuda.schedule_dense_small_batch), - name="dense_small_batch.cuda", + wrap_compute_dense(topi.gpu.dense_small_batch), + wrap_topi_schedule(topi.gpu.schedule_dense_small_batch), + name="dense_small_batch.gpu", ) with SpecializedCondition(b >= 32): strategy.add_implementation( - wrap_compute_dense(topi.cuda.dense_large_batch), - wrap_topi_schedule(topi.cuda.schedule_dense_large_batch), - name="dense_large_batch.cuda", + wrap_compute_dense(topi.gpu.dense_large_batch), + wrap_topi_schedule(topi.gpu.schedule_dense_large_batch), + name="dense_large_batch.gpu", plevel=5, ) if target.kind.name == "cuda": diff --git a/python/tvm/testing.py b/python/tvm/testing.py index 
8178b0a14b29..4721c0050656 100644 --- a/python/tvm/testing.py +++ b/python/tvm/testing.py @@ -414,7 +414,7 @@ def _get_targets(target_str=None): DEFAULT_TEST_TARGETS = ( - "llvm;cuda;opencl;metal;rocm;vulkan;nvptx;" + "llvm;cuda;opencl;metal;rocm;vulkan -from_device=0;nvptx;" "llvm -device=arm_cpu;opencl -device=mali,aocl_sw_emu" ) diff --git a/python/tvm/topi/__init__.py b/python/tvm/topi/__init__.py index 8028dc2c2186..c7197e9358ac 100644 --- a/python/tvm/topi/__init__.py +++ b/python/tvm/topi/__init__.py @@ -51,6 +51,7 @@ from . import nn from . import x86 from . import cuda +from . import gpu from . import arm_cpu from . import mali from . import bifrost diff --git a/python/tvm/topi/cuda/dense.py b/python/tvm/topi/cuda/dense.py index 8adc38b84b1b..0f410aef9afd 100644 --- a/python/tvm/topi/cuda/dense.py +++ b/python/tvm/topi/cuda/dense.py @@ -19,10 +19,8 @@ import logging from tvm import te import tvm.autotvm as autotvm -from tvm.autotvm.task.space import SplitEntity from tvm.contrib import cublas from .tensor_intrin import dp4a -from .. import nn from .. import tag from .. import generic from ..utils import traverse_inline, get_const_tuple @@ -57,195 +55,6 @@ def schedule_dense_cublas(_, outs): return generic.schedule_extern(outs) -@autotvm.register_topi_compute("dense_small_batch.cuda") -def dense_small_batch(cfg, data, weight, bias=None, out_dtype=None): - """Dense operator on CUDA""" - return nn.dense(data, weight, bias, out_dtype) - - -@autotvm.register_topi_schedule("dense_small_batch.cuda") -def schedule_dense_small_batch(cfg, outs): - """Schedule float32/64 dense with small batch size""" - outs = [outs] if isinstance(outs, te.tensor.Tensor) else outs - s = te.create_schedule([x.op for x in outs]) - - def _callback(op): - if op.tag == "dense": - _schedule_dense_small_batch(cfg, s, op.output(0)) - - traverse_inline(s, outs[0].op, _callback) - return s - - -def _schedule_dense_small_batch(cfg, s, C): - A, weights = C.op.input_tensors - _, in_dim_weights = get_const_tuple(weights.shape) - _, in_dim_A = get_const_tuple(A.shape) - - if isinstance(in_dim_A, int): - in_dim = in_dim_A - elif isinstance(in_dim_weights, int): - in_dim = in_dim_weights - else: - in_dim = None - - if in_dim is not None: - cfg.define_split("tile_k", in_dim, num_outputs=2) - if cfg.is_fallback: - cfg["tile_k"] = SplitEntity([-1, 64] if in_dim > 64 else [1, 64]) - _, kf = cfg["tile_k"].apply(s, C, C.op.reduce_axis[0]) - else: - tile_k = 64 - _, kf = s[C].split(C.op.reduce_axis[0], tile_k) - - CF = s.rfactor(C, kf) - - if C.op in s.outputs: - Out = C - else: - Out = s.outputs[0].output(0) - s[C].compute_at(s[Out], s[Out].op.axis[1]) - s[Out].bind(s[Out].op.axis[0], te.thread_axis("blockIdx.y")) - s[Out].bind(s[Out].op.axis[1], te.thread_axis("blockIdx.x")) - - tx = s[C].op.reduce_axis[0] - thread_x = te.thread_axis("threadIdx.x") - s[C].bind(tx, thread_x) - s[CF].compute_at(s[C], tx) - s[C].set_store_predicate(thread_x.var.equal(0)) - s[Out].set_store_predicate(thread_x.var.equal(0)) - - -@autotvm.register_topi_compute("dense_large_batch.cuda") -def dense_large_batch(cfg, data, weight, bias=None, out_dtype=None): - """Dense operator on CUDA""" - return nn.dense(data, weight, bias, out_dtype) - - -@autotvm.register_topi_schedule("dense_large_batch.cuda") -def schedule_dense_large_batch(cfg, outs): - """Schedule float32/64 dense with large batch size""" - outs = [outs] if isinstance(outs, te.tensor.Tensor) else outs - s = te.create_schedule([x.op for x in outs]) - - def _callback(op): - if op.tag == "dense": - 
_schedule_dense_large_batch(cfg, s, op.output(0)) - - traverse_inline(s, outs[0].op, _callback) - return s - - -def _schedule_dense_large_batch(cfg, s, C): - """Schedule float32/64 dense with large batch size""" - A, B = C.op.input_tensors - batch, in_dim = get_const_tuple(A.shape) - out_dim, _ = get_const_tuple(B.shape) - k = C.op.reduce_axis[0] - - # create tuning space - try: - block_cand = [64, 128] - vthread_cand = [2 ** x for x in range(1, 7)] - n_thread_cand = [2 ** x for x in range(3, 7)] - cfg.define_split( - "tile_x", - batch, - num_outputs=4, - filter=lambda x: ( - x.size[1] in vthread_cand - and x.size[2] in n_thread_cand - and (x.size[1] * x.size[2] * x.size[3]) in block_cand - ), - ) - cfg.define_split( - "tile_y", - out_dim, - num_outputs=4, - filter=lambda x: ( - x.size[1] in vthread_cand - and x.size[2] in n_thread_cand - and (x.size[1] * x.size[2] * x.size[3]) in block_cand - ), - ) - cfg.define_split("tile_k", in_dim, num_outputs=3, filter=lambda x: x.size[0] > 2) - except IndexError: - # Index error happens when no entities left after filtering, which was designed - # to prune tuning space for better search efficiency. - logger.debug("Tuning space was created without pruning due to unfit shapes") - cfg.define_split("tile_x", batch, num_outputs=4) - cfg.define_split("tile_y", out_dim, num_outputs=4) - cfg.define_split("tile_k", in_dim, num_outputs=3) - - if cfg.is_fallback: - if batch > 1: - cfg["tile_x"] = SplitEntity([-1, 2, 16, 2]) - else: - cfg["tile_x"] = SplitEntity([1, 1, 1, 1]) - if out_dim > 1: - cfg["tile_y"] = SplitEntity([-1, 2, 16, 2]) - else: - cfg["tile_y"] = SplitEntity([1, 1, 1, 1]) - if in_dim > 8: - cfg["tile_k"] = SplitEntity([-1, 8, 1]) - else: - cfg["tile_k"] = SplitEntity([-1, 1, 1]) - - # Explicit memory access - AA = s.cache_read(A, "shared", [C]) - BB = s.cache_read(B, "shared", [C]) - AL = s.cache_read(AA, "local", [C]) - BL = s.cache_read(BB, "local", [C]) - CC = s.cache_write(C, "local") - - # Deal with op fusion - if C.op not in s.outputs: - s[C].compute_inline() - C = s.outputs[0].output(0) - - # Split and reorder computation - bx, txz, tx, xi = cfg["tile_x"].apply(s, C, C.op.axis[0]) - by, tyz, ty, yi = cfg["tile_y"].apply(s, C, C.op.axis[1]) - s[C].reorder(by, bx, tyz, txz, ty, tx, yi, xi) - s[CC].compute_at(s[C], tx) - - # Binding - s[C].bind(by, te.thread_axis("blockIdx.y")) - s[C].bind(bx, te.thread_axis("blockIdx.x")) - s[C].bind(tyz, te.thread_axis("vthread")) - s[C].bind(txz, te.thread_axis("vthread")) - s[C].bind(ty, te.thread_axis("threadIdx.y")) - s[C].bind(tx, te.thread_axis("threadIdx.x")) - - # Split reduction - yo, xo = CC.op.axis - ko, kt, ki = cfg["tile_k"].apply(s, CC, k) - s[CC].reorder(ko, kt, ki, yo, xo) - s[AA].compute_at(s[CC], ko) - s[BB].compute_at(s[CC], ko) - s[CC].unroll(kt) - s[AL].compute_at(s[CC], kt) - s[BL].compute_at(s[CC], kt) - - # Schedule for A's shared memory load - num_thread_x = cfg["tile_x"].size[2] - ty, _ = s[AA].split(s[AA].op.axis[0], nparts=num_thread_x) - _, xi = s[AA].split(s[AA].op.axis[1], factor=num_thread_x * 4) - tx, xi = s[AA].split(xi, nparts=num_thread_x) - s[AA].bind(ty, te.thread_axis("threadIdx.y")) - s[AA].bind(tx, te.thread_axis("threadIdx.x")) - s[AA].double_buffer() - - # Schedule for B' shared memory load - num_thread_y = cfg["tile_y"].size[2] - ty, _ = s[BB].split(s[BB].op.axis[0], nparts=num_thread_y) - _, xi = s[BB].split(s[BB].op.axis[1], factor=num_thread_y * 4) - tx, xi = s[BB].split(xi, nparts=num_thread_y) - s[BB].bind(ty, te.thread_axis("threadIdx.y")) - 
s[BB].bind(tx, te.thread_axis("threadIdx.x")) - s[BB].double_buffer() - - @autotvm.register_topi_compute("dense_int8.cuda") def dense_int8(cfg, data, weight, bias=None, out_dtype=None): """Dense operator for int8 on CUDA""" diff --git a/python/tvm/topi/gpu/__init__.py b/python/tvm/topi/gpu/__init__.py new file mode 100644 index 000000000000..6d9fd39e16b8 --- /dev/null +++ b/python/tvm/topi/gpu/__init__.py @@ -0,0 +1,20 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# pylint: disable=redefined-builtin, wildcard-import +"""GPU specific declaration and schedules.""" +from .dense import * diff --git a/python/tvm/topi/gpu/dense.py b/python/tvm/topi/gpu/dense.py new file mode 100644 index 000000000000..806aa9f5ca44 --- /dev/null +++ b/python/tvm/topi/gpu/dense.py @@ -0,0 +1,218 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# pylint: disable=invalid-name, unused-argument +"""Schedule for dense operator""" + +import logging + +from tvm import autotvm, te +from tvm.autotvm.task.space import SplitEntity + +from .. 
import nn +from ..utils import traverse_inline, get_const_tuple + +logger = logging.getLogger("topi") + + +@autotvm.register_topi_compute("dense_small_batch.gpu") +def dense_small_batch(cfg, data, weight, bias=None, out_dtype=None): + """Dense operator on GPU""" + return nn.dense(data, weight, bias, out_dtype) + + +@autotvm.register_topi_schedule("dense_small_batch.gpu") +def schedule_dense_small_batch(cfg, outs): + """Schedule float32/64 dense with small batch size""" + outs = [outs] if isinstance(outs, te.tensor.Tensor) else outs + s = te.create_schedule([x.op for x in outs]) + + def _callback(op): + if op.tag == "dense": + _schedule_dense_small_batch(cfg, s, op.output(0)) + + traverse_inline(s, outs[0].op, _callback) + return s + + +def _schedule_dense_small_batch(cfg, s, C): + A, weights = C.op.input_tensors + _, in_dim_weights = get_const_tuple(weights.shape) + _, in_dim_A = get_const_tuple(A.shape) + + if isinstance(in_dim_A, int): + in_dim = in_dim_A + elif isinstance(in_dim_weights, int): + in_dim = in_dim_weights + else: + in_dim = None + + if in_dim is not None: + cfg.define_split("tile_k", in_dim, num_outputs=2) + if cfg.is_fallback: + cfg["tile_k"] = SplitEntity([-1, 64] if in_dim > 64 else [1, 64]) + _, kf = cfg["tile_k"].apply(s, C, C.op.reduce_axis[0]) + else: + tile_k = 64 + _, kf = s[C].split(C.op.reduce_axis[0], tile_k) + + CF = s.rfactor(C, kf) + + if C.op in s.outputs: + Out = C + else: + Out = s.outputs[0].output(0) + s[C].compute_at(s[Out], s[Out].op.axis[1]) + s[Out].bind(s[Out].op.axis[0], te.thread_axis("blockIdx.y")) + s[Out].bind(s[Out].op.axis[1], te.thread_axis("blockIdx.x")) + + tx = s[C].op.reduce_axis[0] + thread_x = te.thread_axis("threadIdx.x") + s[C].bind(tx, thread_x) + s[CF].compute_at(s[C], tx) + s[C].set_store_predicate(thread_x.var.equal(0)) + s[Out].set_store_predicate(thread_x.var.equal(0)) + + +@autotvm.register_topi_compute("dense_large_batch.gpu") +def dense_large_batch(cfg, data, weight, bias=None, out_dtype=None): + """Dense operator on GPU""" + return nn.dense(data, weight, bias, out_dtype) + + +@autotvm.register_topi_schedule("dense_large_batch.gpu") +def schedule_dense_large_batch(cfg, outs): + """Schedule float32/64 dense with large batch size""" + outs = [outs] if isinstance(outs, te.tensor.Tensor) else outs + s = te.create_schedule([x.op for x in outs]) + + def _callback(op): + if op.tag == "dense": + _schedule_dense_large_batch(cfg, s, op.output(0)) + + traverse_inline(s, outs[0].op, _callback) + return s + + +def _schedule_dense_large_batch(cfg, s, C): + """Schedule float32/64 dense with large batch size""" + A, B = C.op.input_tensors + batch, in_dim = get_const_tuple(A.shape) + out_dim, _ = get_const_tuple(B.shape) + k = C.op.reduce_axis[0] + + # create tuning space + try: + block_cand = [64, 128] + vthread_cand = [2 ** x for x in range(1, 7)] + n_thread_cand = [2 ** x for x in range(3, 7)] + cfg.define_split( + "tile_x", + batch, + num_outputs=4, + filter=lambda x: ( + x.size[1] in vthread_cand + and x.size[2] in n_thread_cand + and (x.size[1] * x.size[2] * x.size[3]) in block_cand + ), + ) + cfg.define_split( + "tile_y", + out_dim, + num_outputs=4, + filter=lambda x: ( + x.size[1] in vthread_cand + and x.size[2] in n_thread_cand + and (x.size[1] * x.size[2] * x.size[3]) in block_cand + ), + ) + cfg.define_split("tile_k", in_dim, num_outputs=3, filter=lambda x: x.size[0] > 2) + except IndexError: + # Index error happens when no entities left after filtering, which was designed + # to prune tuning space for better search efficiency. 
+ logger.debug("Tuning space was created without pruning due to unfit shapes") + cfg.define_split("tile_x", batch, num_outputs=4) + cfg.define_split("tile_y", out_dim, num_outputs=4) + cfg.define_split("tile_k", in_dim, num_outputs=3) + + if cfg.is_fallback: + if batch > 1: + cfg["tile_x"] = SplitEntity([-1, 2, 16, 2]) + else: + cfg["tile_x"] = SplitEntity([1, 1, 1, 1]) + if out_dim > 1: + cfg["tile_y"] = SplitEntity([-1, 2, 16, 2]) + else: + cfg["tile_y"] = SplitEntity([1, 1, 1, 1]) + if in_dim > 8: + cfg["tile_k"] = SplitEntity([-1, 8, 1]) + else: + cfg["tile_k"] = SplitEntity([-1, 1, 1]) + + # Explicit memory access + AA = s.cache_read(A, "shared", [C]) + BB = s.cache_read(B, "shared", [C]) + AL = s.cache_read(AA, "local", [C]) + BL = s.cache_read(BB, "local", [C]) + CC = s.cache_write(C, "local") + + # Deal with op fusion + if C.op not in s.outputs: + s[C].compute_inline() + C = s.outputs[0].output(0) + + # Split and reorder computation + bx, txz, tx, xi = cfg["tile_x"].apply(s, C, C.op.axis[0]) + by, tyz, ty, yi = cfg["tile_y"].apply(s, C, C.op.axis[1]) + s[C].reorder(by, bx, tyz, txz, ty, tx, yi, xi) + s[CC].compute_at(s[C], tx) + + # Binding + s[C].bind(by, te.thread_axis("blockIdx.y")) + s[C].bind(bx, te.thread_axis("blockIdx.x")) + s[C].bind(tyz, te.thread_axis("vthread")) + s[C].bind(txz, te.thread_axis("vthread")) + s[C].bind(ty, te.thread_axis("threadIdx.y")) + s[C].bind(tx, te.thread_axis("threadIdx.x")) + + # Split reduction + yo, xo = CC.op.axis + ko, kt, ki = cfg["tile_k"].apply(s, CC, k) + s[CC].reorder(ko, kt, ki, yo, xo) + s[AA].compute_at(s[CC], ko) + s[BB].compute_at(s[CC], ko) + s[CC].unroll(kt) + s[AL].compute_at(s[CC], kt) + s[BL].compute_at(s[CC], kt) + + # Schedule for A's shared memory load + num_thread_x = cfg["tile_x"].size[2] + ty, _ = s[AA].split(s[AA].op.axis[0], nparts=num_thread_x) + _, xi = s[AA].split(s[AA].op.axis[1], factor=num_thread_x * 4) + tx, xi = s[AA].split(xi, nparts=num_thread_x) + s[AA].bind(ty, te.thread_axis("threadIdx.y")) + s[AA].bind(tx, te.thread_axis("threadIdx.x")) + s[AA].double_buffer() + + # Schedule for B' shared memory load + num_thread_y = cfg["tile_y"].size[2] + ty, _ = s[BB].split(s[BB].op.axis[0], nparts=num_thread_y) + _, xi = s[BB].split(s[BB].op.axis[1], factor=num_thread_y * 4) + tx, xi = s[BB].split(xi, nparts=num_thread_y) + s[BB].bind(ty, te.thread_axis("threadIdx.y")) + s[BB].bind(tx, te.thread_axis("threadIdx.x")) + s[BB].double_buffer() diff --git a/src/target/spirv/codegen_spirv.cc b/src/target/spirv/codegen_spirv.cc index 2628406f6f49..5d52bee44e98 100644 --- a/src/target/spirv/codegen_spirv.cc +++ b/src/target/spirv/codegen_spirv.cc @@ -63,6 +63,7 @@ runtime::VulkanShader CodeGenSPIRV::BuildFunction(const PrimFunc& f, const std:: } spirv::Value arg_value = builder_->BufferArgument(builder_->GetSType(value_storage_type), descriptor_set, num_buffer); + builder_->SetName(arg_value, arg->name_hint); storage_info_[arg.get()].UpdateContentType(value_storage_type); var_map_[arg.get()] = arg_value; } else { @@ -144,15 +145,21 @@ spirv::Value CodeGenSPIRV::CreateStorageSync(const CallNode* op) { uint32_t vulkan_api_version = spirv_support_.vulkan_api_version; int64_t sync_scope; - int64_t memory_semantics; + int64_t memory_semantics = spv::MemorySemanticsSequentiallyConsistentMask; if ((sync == "warp") && (vulkan_api_version >= VK_API_VERSION_1_1)) { + // Synchronize control at the Subgroup level, but memory at the + // Workgroup level. 
This is because different invocations in a + // subgroup may have each modified memory that exists at the + // workgroup scope. This should be changed if/when tir exposes + // more information as to which memory access needs to be + // synchronized. sync_scope = spv::ScopeSubgroup; - memory_semantics = - spv::MemorySemanticsSequentiallyConsistentMask | spv::MemorySemanticsSubgroupMemoryMask; + memory_semantics |= + spv::MemorySemanticsSubgroupMemoryMask | spv::MemorySemanticsWorkgroupMemoryMask; + } else if ((sync == "shared") || (sync == "warp")) { sync_scope = spv::ScopeWorkgroup; - memory_semantics = - spv::MemorySemanticsSequentiallyConsistentMask | spv::MemorySemanticsWorkgroupMemoryMask; + memory_semantics |= spv::MemorySemanticsWorkgroupMemoryMask; } else { LOG(FATAL) << "Do not support sync " << sync; } @@ -161,6 +168,7 @@ spirv::Value CodeGenSPIRV::CreateStorageSync(const CallNode* op) { builder_->MakeInst(spv::OpControlBarrier, builder_->IntImm(type_int, sync_scope), builder_->IntImm(type_int, sync_scope), builder_->IntImm(type_int, memory_semantics)); + return value; } @@ -642,14 +650,16 @@ void CodeGenSPIRV::VisitStmt_(const AllocateNode* op) { if (info.scope.rank == runtime::StorageRank::kLocal) { buf = builder_->Allocate(etype, static_cast(constant_size), spv::StorageClassFunction); - } else { - // shared memory - ICHECK(info.scope.rank == runtime::StorageRank::kShared) - << "Can only allocate shared or local memory inside kernel"; + } else if (info.scope.rank == runtime::StorageRank::kShared) { // Shared memory buf = builder_->Allocate(etype, static_cast(constant_size), spv::StorageClassWorkgroup); + } else { + LOG(FATAL) << "Can only allocate shared or local memory inside kernel"; } + + builder_->SetName(buf, op->buffer_var->name_hint); + ICHECK(!info.content_fixed); info.UpdateContentType(op->dtype); ICHECK(!var_map_.count(op->buffer_var.get())); diff --git a/src/tir/transforms/lower_thread_allreduce.cc b/src/tir/transforms/lower_thread_allreduce.cc index 9598f07e365e..9e536814fa12 100644 --- a/src/tir/transforms/lower_thread_allreduce.cc +++ b/src/tir/transforms/lower_thread_allreduce.cc @@ -388,7 +388,7 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { size_t size = shared_bufs.size(); PrimExpr buf_index = BufIndex(reduce_index, group_index, reduce_extent); // make reduction - auto freduce = [&](int offset) { + auto fload = [&](int offset) { Array a, b; for (size_t i = 0; i < size; ++i) { b.push_back(Load(types[i], shared_bufs[i], @@ -397,12 +397,19 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { a.push_back(Load(types[i], shared_bufs[i], buf_index, const_true())); } Array ret = (*combiner)(a, b); + return ret; + }; + auto fstore = [&](const Array& ret) { std::vector stores(size); for (size_t i = 0; i < size; ++i) { stores[i] = Store(shared_bufs[i], ret[i], buf_index, const_true()); } return SeqStmt::Flatten(stores); }; + auto freduce = [&](int offset) { + auto ret = fload(offset); + return fstore(ret); + }; // Step one, check for if (reduce_align > reduce_extent) { // reduction with the boundary condition @@ -420,15 +427,47 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { seq.emplace_back(SyncThread("shared")); } // in warp synchronization. 
- std::vector in_warp_seq; - PrimExpr in_warp_cond = reduce_index < (reduce_align >> 1); - while (reduce_align > 1) { - reduce_align = reduce_align >> 1; - in_warp_seq.emplace_back(freduce(reduce_align)); - in_warp_seq.emplace_back(SyncThread("warp")); - } - if (in_warp_seq.size() != 0) { + if (reduce_align > 1) { + PrimExpr in_warp_cond = reduce_index < (reduce_align >> 1); + + std::vector in_warp_seq; + + while (reduce_align > 1) { + reduce_align = reduce_align >> 1; + + // freduce can read/write to the same memory location. For + // example, with reduce_align of 4, threadIdx 3 reads from + // memory location 7 as threadIdx 7 is writing to it. + // Therefore, we need to separate out the load from the store + // with a memory barrier in-between. This isn't necessary for + // the earlier normal synchronization, because those are each + // protected by an if-statement. The if-statement is avoided + // here to reduce thread divergence. + auto loads = fload(reduce_align); + + Array in_warp_local_vars; + for (auto expr : loads) { + Var var( + "w_" + std::to_string(reduce_align) + "_" + std::to_string(in_warp_local_vars.size()), + expr->dtype); + in_warp_local_vars.push_back(var); + } + + std::vector in_let_statement; + in_let_statement.emplace_back(SyncThread("warp")); + in_let_statement.emplace_back( + fstore({in_warp_local_vars.begin(), in_warp_local_vars.end()})); + in_let_statement.emplace_back(SyncThread("warp")); + + Stmt body = SeqStmt::Flatten(in_let_statement); + for (size_t i = 0; i < size; i++) { + body = LetStmt(in_warp_local_vars[i], loads[i], body); + } + in_warp_seq.push_back(body); + } + Stmt warp_body = SeqStmt::Flatten(in_warp_seq); + seq.emplace_back(IfThenElse(in_warp_cond, warp_body)); seq.emplace_back(SyncThread("shared")); } diff --git a/tests/python/relay/test_autotvm_task_extraction.py b/tests/python/relay/test_autotvm_task_extraction.py index b3f1868969cc..83480a044f45 100644 --- a/tests/python/relay/test_autotvm_task_extraction.py +++ b/tests/python/relay/test_autotvm_task_extraction.py @@ -115,7 +115,7 @@ def get_net(batch, in_dim, out_dim, dtype, out_dtype): mod, params = get_net(1, 16, 32, "float32", "float32") tasks = autotvm.task.extract_from_program(mod, target=target, params=params, ops=(dense,)) - assert len(tasks) == 1 and tasks[0].name == "dense_small_batch.cuda" + assert len(tasks) == 1 and tasks[0].name == "dense_small_batch.gpu" mod, params = get_net(1, 16, 32, "int8", "int32") tasks = autotvm.task.extract_from_program(mod, target=target, params=params, ops=(dense,)) diff --git a/tests/python/topi/python/test_topi_dense.py b/tests/python/topi/python/test_topi_dense.py index 07301fad822c..235a09400387 100644 --- a/tests/python/topi/python/test_topi_dense.py +++ b/tests/python/topi/python/test_topi_dense.py @@ -15,26 +15,37 @@ # specific language governing permissions and limitations # under the License. 
"""Test code for dense operator""" +import contextlib import numpy as np +import pytest +import sys + import tvm -from tvm import te -from tvm import topi +import tvm.testing import tvm.topi.testing +from tvm import te, topi from tvm.topi.utils import get_const_tuple -from tvm.contrib.pickle_memoize import memoize from common import Int8Fallback -import tvm.testing -_dense_implement = { +use_bias = tvm.testing.parameter(True, False) +batch_size = tvm.testing.parameter(1, 2, 128) +in_dim, out_dim = tvm.testing.parameters((1024, 1000)) +in_dtype, out_dtype = tvm.testing.parameters( + ("float32", "float32"), + ("int8", "int32"), +) + + +_dense_implementations = { "generic": [(topi.nn.dense, topi.generic.schedule_dense)], "cpu": [ (topi.x86.dense_nopack, topi.x86.schedule_dense_nopack), (topi.x86.dense_pack, topi.x86.schedule_dense_pack), ], "gpu": [ - (topi.cuda.dense_small_batch, topi.cuda.schedule_dense_small_batch), - (topi.cuda.dense_large_batch, topi.cuda.schedule_dense_large_batch), + (topi.gpu.dense_small_batch, topi.gpu.schedule_dense_small_batch), + (topi.gpu.dense_large_batch, topi.gpu.schedule_dense_large_batch), ], "mali": [(topi.mali.dense, topi.mali.schedule_dense)], "bifrost": [(topi.bifrost.dense, topi.bifrost.schedule_dense)], @@ -43,108 +54,115 @@ } -def verify_dense(batch, in_dim, out_dim, use_bias=True): - A = te.placeholder((batch, in_dim), name="A") - B = te.placeholder((out_dim, in_dim), name="B") - C = te.placeholder((out_dim,), name="C") - dtype = A.dtype - - # use memoize to pickle the test data for next time use - @memoize("topi.tests.test_topi_dense") - def get_ref_data(): - a_np = np.random.uniform(size=(batch, in_dim)).astype(dtype) - b_np = np.random.uniform(size=(out_dim, in_dim)).astype(dtype) - c_np = np.random.uniform(size=(out_dim,)).astype(dtype) - if use_bias: - d_np = np.maximum(np.dot(a_np, b_np.T) + c_np, 0.0) - else: - d_np = np.maximum(np.dot(a_np, b_np.T), 0.0) - return (a_np, b_np, c_np, d_np) - - # get the test data - a_np, b_np, c_np, d_np = get_ref_data() - - def check_device(device, dev): - print("Running on target: %s" % device) - for fcompute, fschedule in tvm.topi.testing.dispatch(device, _dense_implement): - with tvm.target.Target(device): - D = fcompute(A, B, C if use_bias else None) - D = topi.nn.relu(D) - s = fschedule([D]) - a = tvm.nd.array(a_np, dev) - b = tvm.nd.array(b_np, dev) - c = tvm.nd.array(c_np, dev) - d = tvm.nd.array(np.zeros(get_const_tuple(D.shape), dtype=dtype), dev) - f = tvm.build(s, [A, B, C, D], device, name="dense") - f(a, b, c, d) - tvm.testing.assert_allclose(d.numpy(), d_np, rtol=1e-5) - - for device, dev in tvm.testing.enabled_targets(): - check_device(device, dev) - - -def verify_dense_int8(batch, in_dim, out_dim, use_bias=True): - dtype = "int8" - out_dtype = "int32" - A = te.placeholder((batch, in_dim), name="A", dtype=dtype) - B = te.placeholder((out_dim, in_dim), name="B", dtype=dtype) +@tvm.testing.fixture(cache_return_value=True) +def dense_ref_data(batch_size, in_dim, out_dim, use_bias, in_dtype, out_dtype): + if "float" in in_dtype: + a_np = np.random.uniform(size=(batch_size, in_dim)).astype(in_dtype) + b_np = np.random.uniform(size=(out_dim, in_dim)).astype(in_dtype) + c_np = np.random.uniform(size=(out_dim,)).astype(out_dtype) + elif in_dtype == "int8": + a_np = np.random.randint(low=-128, high=127, size=(batch_size, in_dim)).astype(in_dtype) + b_np = np.random.randint(low=-128, high=127, size=(out_dim, in_dim)).astype(in_dtype) + c_np = np.random.randint(low=-128, high=127, 
size=(out_dim,)).astype(out_dtype) + else: + raise ValueError("No method to generate test data for data type '{}'".format(in_dtype)) + + matmul = np.dot(a_np.astype(out_dtype), b_np.T.astype(out_dtype)) + + if use_bias: + matmul += c_np + + d_np = np.maximum(matmul, 0) + return (a_np, b_np, c_np, d_np) + + +def test_dense( + target, + dev, + batch_size, + in_dim, + out_dim, + use_bias, + dense_ref_data, + in_dtype, + out_dtype, + implementations=None, +): + target = tvm.target.Target(target) + + if ( + in_dtype == "int8" + and target.kind.name == "cuda" + and not tvm.contrib.nvcc.have_int8(dev.compute_version) + ): + pytest.xfail("CUDA int8 intrinsics not available") + + if ( + in_dtype == "int8" + and target.kind.name == "vulkan" + and not target.attrs.get("supports_int8", False) + ): + pytest.xfail("Vulkan int8 driver support not available") + + if ( + target.kind.name not in ["llvm", "c"] + and len(set(target.keys) & set(_dense_implementations)) == 0 + ): + pytest.xfail("No implementation for tvm.topi.testing.dispatch to find") + + A = te.placeholder((batch_size, in_dim), name="A", dtype=in_dtype) + B = te.placeholder((out_dim, in_dim), name="B", dtype=in_dtype) C = te.placeholder((out_dim,), name="C", dtype=out_dtype) - # use memoize to pickle the test data for next time use - @memoize("topi.tests.test_topi_dense_int8") - def get_ref_data(): - a_np = np.random.randint(low=-128, high=127, size=(batch, in_dim)).astype(dtype) - b_np = np.random.randint(low=-128, high=127, size=(out_dim, in_dim)).astype(dtype) - c_np = np.random.randint(low=-128, high=127, size=(out_dim,)).astype(out_dtype) - d_np = np.dot(a_np.astype(out_dtype), b_np.T.astype(out_dtype)) - if use_bias: - d_np += c_np - d_np = np.maximum(d_np, 0.0) - return (a_np, b_np, c_np, d_np) - - # get the test data - a_np, b_np, c_np, d_np = get_ref_data() - - def check_device(device): - dev = tvm.device(device, 0) - if device == "cuda" and not tvm.contrib.nvcc.have_int8(dev.compute_version): - print("Skip because int8 intrinsics are not available") - return - - print("Running on target: %s" % device) - with tvm.target.Target(device): - D = topi.cuda.dense_int8(A, B, C if use_bias else None, out_dtype) + a_np, b_np, c_np, d_np = dense_ref_data + + if implementations is None: + implementations = tvm.topi.testing.dispatch(target, _dense_implementations) + + for fcompute, fschedule in implementations: + with tvm.target.Target(target): + D = fcompute(A, B, C if use_bias else None, out_dtype) D = topi.nn.relu(D) - s = topi.cuda.schedule_dense_int8([D]) + s = fschedule([D]) + a = tvm.nd.array(a_np, dev) b = tvm.nd.array(b_np, dev) c = tvm.nd.array(c_np, dev) d = tvm.nd.array(np.zeros(get_const_tuple(D.shape), dtype=out_dtype), dev) - f = tvm.build(s, [A, B, C, D], device, name="dense") + f = tvm.build(s, [A, B, C, D], target, name="dense") f(a, b, c, d) tvm.testing.assert_allclose(d.numpy(), d_np, rtol=1e-5) - for device in ["cuda"]: - check_device(device) - - -@tvm.testing.uses_gpu -def test_dense(): - verify_dense(1, 1024, 1000, use_bias=True) - verify_dense(1, 1024, 1000, use_bias=False) - verify_dense(2, 1024, 1000, use_bias=True) - verify_dense(128, 1024, 1000, use_bias=False) - verify_dense(128, 1024, 1000, use_bias=True) - -@tvm.testing.requires_cuda -@tvm.testing.requires_gpu -def test_dense_int8(): +@pytest.mark.parametrize("target,in_dtype,out_dtype", [("cuda", "int8", "int32")]) +def test_dense_cuda_int8( + target, + dev, + batch_size, + in_dim, + out_dim, + use_bias, + dense_ref_data, + in_dtype, + out_dtype, +): + 
implementations = [ + (topi.cuda.dense_int8, topi.cuda.schedule_dense_int8), + ] with Int8Fallback(): - verify_dense_int8(2, 1024, 1000, use_bias=True) - verify_dense_int8(2, 1024, 1000, use_bias=False) + test_dense( + target, + dev, + batch_size, + in_dim, + out_dim, + use_bias, + dense_ref_data, + in_dtype, + out_dtype, + implementations=implementations, + ) if __name__ == "__main__": - test_dense() - test_dense_int8() + sys.exit(pytest.main(sys.argv))
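

Note (illustrative, not part of the patch): the test file above relies on the tvm.testing parametrization style. The following is a minimal, self-contained sketch of how tvm.testing.parameter, tvm.testing.parameters, and tvm.testing.fixture(cache_return_value=True) expand into one pytest case per parameter value. The file and fixture/test names (ref_data, test_ref_data_shape) are hypothetical.

# illustrative_parametrized_test.py -- sketch only, not part of this patch
import sys

import numpy as np
import pytest

import tvm.testing

# One test instance is generated per listed value.
batch_size = tvm.testing.parameter(1, 2, 128)
in_dtype, out_dtype = tvm.testing.parameters(
    ("float32", "float32"),
    ("int8", "int32"),
)


@tvm.testing.fixture(cache_return_value=True)
def ref_data(batch_size, in_dtype):
    # cache_return_value=True reuses the same random data across
    # parametrizations instead of regenerating it each time.
    return np.random.uniform(size=(batch_size, 1024)).astype(in_dtype)


def test_ref_data_shape(batch_size, in_dtype, ref_data):
    assert ref_data.shape == (batch_size, 1024)
    assert ref_data.dtype == in_dtype


if __name__ == "__main__":
    sys.exit(pytest.main(sys.argv))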
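

Note (illustrative, not part of the patch): the migrated compute/schedule pair in the new topi.gpu.dense module can be exercised directly on any GPU target, not only CUDA. A minimal sketch follows; the Vulkan target, device index, and shapes are assumptions for illustration, and outside an autotvm tuning session the schedules use their fallback configs, as in the unit test above.

# sketch: build nn.dense through the migrated topi.gpu schedules
import numpy as np

import tvm
import tvm.testing
from tvm import te, topi

target = tvm.target.Target("vulkan")  # any GPU target; "cuda", "rocm", etc. work the same way
dev = tvm.device("vulkan", 0)         # assumes a Vulkan-capable device is present

A = te.placeholder((2, 1024), name="A", dtype="float32")
B = te.placeholder((1000, 1024), name="B", dtype="float32")

with target:
    D = topi.gpu.dense_small_batch(A, B, None, "float32")
    s = topi.gpu.schedule_dense_small_batch([D])

f = tvm.build(s, [A, B, D], target, name="dense")

a = tvm.nd.array(np.random.uniform(size=(2, 1024)).astype("float32"), dev)
b = tvm.nd.array(np.random.uniform(size=(1000, 1024)).astype("float32"), dev)
d = tvm.nd.array(np.zeros((2, 1000), dtype="float32"), dev)
f(a, b, d)
tvm.testing.assert_allclose(d.numpy(), a.numpy() @ b.numpy().T, rtol=1e-5)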
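

Note (illustrative, not part of the patch): at the Relay level, the strategy change means that dense_strategy_cuda, which also serves other GPU targets such as Vulkan, now reserves the int8 path for CUDA proper and hands other GPU targets the dense_small_batch.gpu / dense_large_batch.gpu implementations. A rough end-to-end sketch with an assumed Vulkan target and illustrative shapes:

# sketch: compiling relay.nn.dense for a non-CUDA GPU target
import tvm
from tvm import relay

data = relay.var("data", shape=(1, 1024), dtype="float32")
weight = relay.var("weight", shape=(1000, 1024), dtype="float32")
mod = tvm.IRModule.from_expr(relay.nn.dense(data, weight))

# With this patch, a Vulkan build selects dense_small_batch.gpu here
# (dense_large_batch.gpu for batch >= 32) rather than a CUDA-only path.
with tvm.transform.PassContext(opt_level=3):
    lib = relay.build(mod, target="vulkan")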