Skip to content

Commit

Permalink
[Relax][Pass] Lowering passes for GPU IPC memory and allreduce
Browse files Browse the repository at this point in the history
This PR introduces the lowering passes for GPU IPC memory and
all-reduce. It contains the following changes:

1. a pass `IPCAllreduceRewrite` which rewrites `"runtime.disco.allreduce"`
to `"runtime.disco.cuda_ipc.custom_allreduce"`, and rewrites
the storage scopes of the all-reduce inputs's from "global" to
"ipc_memory" accordingly.

2. memory planning enhancement, making the planning be aware of
storage scopes. So each storage scope will be planned independently.

3. a pass `LowerGPUIPCAllocStorage` that rewrites the storage allocation
of IPC memory from builtin ops to calls to function `"runtime.disco.cuda_ipc.alloc_storage"`.

4. supports the op `relax.builtin.alloc_tensor` with storage scope.
The default storage scope is `"global"`.

We write the new passes in Python for experiment and fast development.
These are good demos showing we can have efficient development
with the architecture enabled by TVM.
  • Loading branch information
MasterJH5574 committed Mar 20, 2024
1 parent e257fb8 commit 0af4ac4
Show file tree
Hide file tree
Showing 10 changed files with 554 additions and 36 deletions.
20 changes: 17 additions & 3 deletions python/tvm/relax/op/builtin/builtin.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,18 @@
"""The builtin Relax operators."""

from typing import Union
from ...expr import Call, Expr, PrimValue, DataTypeImm

from ...expr import Call, DataTypeImm, Expr, PrimValue, StringImm
from ...utils import args_converter
from . import _ffi_api


@args_converter.auto
def alloc_tensor(
shape: Expr, dtype: Union[str, Expr], runtime_device_index: Union[int, Expr]
shape: Expr,
dtype: Union[str, Expr],
runtime_device_index: Union[int, Expr],
storage_scope: Union[str, Expr] = "global",
) -> Call:
"""Construct a Call to allocate a tensor with specific shape, dtype, runtime_device_index.
Expand All @@ -39,6 +43,9 @@ def alloc_tensor(
The device index indicating on which device the tensor is to be allocated at runtime.
Index -1 is reserved for the host device.
storage_scope : Union[str, Expr]
The storage scope to allocate the storage to.
Returns
-------
result : Call
Expand All @@ -48,8 +55,15 @@ def alloc_tensor(
dtype = DataTypeImm(dtype)
if isinstance(runtime_device_index, int):
runtime_device_index = PrimValue(runtime_device_index)
if isinstance(storage_scope, str):
storage_scope = StringImm(storage_scope)
if not isinstance(storage_scope, StringImm):
raise ValueError(
"relax.builtin.alloc_tensor expects string as the storage scope, "
f"but {storage_scope} is got."
)

return _ffi_api.alloc_tensor(shape, dtype, runtime_device_index) # type: ignore
return _ffi_api.alloc_tensor(shape, dtype, runtime_device_index, storage_scope) # type: ignore


def stop_lift_params(x: Expr) -> Expr:
Expand Down
2 changes: 2 additions & 0 deletions python/tvm/relax/transform/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,9 @@
function_pass,
)

from .ipc_allreduce_rewrite import IPCAllReduceRewrite
from .lazy_transform_params import LazyTransformParams
from .lower_gpu_ipc_alloc_storage import LowerGPUIPCAllocStorage
from .optimize_layout_transform import OptimizeLayoutTransform
from .remove_redundant_reshape import RemoveRedundantReshape
from .fast_math import FastMathTransform
Expand Down
150 changes: 150 additions & 0 deletions python/tvm/relax/transform/ipc_allreduce_rewrite.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Rewrite all-reduce operation to customized all-reduce impl with IPC memory.
The pass is written in Python for experiment, fast development.
"""

from typing import Dict

import tvm
from tvm import relax
from tvm.ir.module import IRModule
from tvm.relax.analysis import remove_all_unused
from tvm.relax.expr import Expr, Var
from tvm.relax.expr_functor import PyExprMutator, PyExprVisitor, mutator, visitor


@tvm.transform.module_pass(opt_level=0, name="IPCAllReduceRewrite")
class IPCAllReduceRewrite:
"""Rewrite all-reduce operation to customized all-reduce impl with IPC memory."""

def __init__(self, allreduce_strategy: int) -> None:
"""Constructor
Parameters
----------
allreduce_strategy : int
The all-reduce strategy. Only "1" and "2" are supported.
"1" stands for one-shot, and "2" stands for two-shot.
"""
if allreduce_strategy not in [1, 2]:
raise ValueError(f"All-reduce strategy {allreduce_strategy} is not supported.")
self.allreduce_strategy = allreduce_strategy

def transform_module(self, mod: IRModule, _ctx: tvm.transform.PassContext) -> IRModule:
"""IRModule-level transformation"""
fcustom_allreduce = tvm.get_global_func(
"runtime.disco.cuda_ipc.custom_allreduce", allow_missing=True
)
if fcustom_allreduce is None:
# Customized allreduce is not available.
return mod

binding_replacement_map = _Visitor(self.allreduce_strategy).visit(mod)
return _Rewriter(mod, binding_replacement_map).transform()


@visitor
class _Visitor(PyExprVisitor): # pylint: disable=abstract-method
def __init__(self, allreduce_strategy: int) -> None:
self.allreduce_strategy = allreduce_strategy
self.alloc_map: Dict[Var, relax.Call] = {}
self.binding_replacement_map: Dict[relax.Expr, relax.Expr] = {}
self.builtin_alloc_tensor_op = tvm.ir.Op.get("relax.builtin.alloc_tensor")
self.reshape_op = tvm.ir.Op.get("relax.reshape")

def visit(self, mod: IRModule) -> Dict[relax.Expr, relax.Expr]:
"""Entry point"""
for _, func in mod.functions_items():
if isinstance(func, relax.Function):
self.alloc_map.clear()
self.visit_expr(func)
return self.binding_replacement_map

def visit_var_binding_(self, binding: relax.VarBinding):
super().visit_var_binding_(binding)
if (
isinstance(binding.value, relax.Call)
and binding.value.op == self.builtin_alloc_tensor_op
):
self.alloc_map[binding.var] = binding.value
elif isinstance(binding.value, relax.Var) and binding.value in self.alloc_map:
self.alloc_map[binding.var] = self.alloc_map[binding.value]
elif (
isinstance(binding.value, relax.Call)
and binding.value.op == self.reshape_op
and binding.value.args[0] in self.alloc_map
):
self.alloc_map[binding.var] = self.alloc_map[binding.value.args[0]]

def visit_call_(self, call: relax.Call) -> None: # pylint: disable=arguments-renamed
if (
not isinstance(call.op, relax.ExternFunc)
or call.op.global_symbol != "runtime.disco.allreduce"
or call.args[1].values[0] != 0
):
# Return if the call is not a summation all-reduce.
return call

assert len(call.args) == 3
allreduce_input = call.args[0]
alloc_tensor = self.alloc_map.get(allreduce_input, None)
if alloc_tensor is None or alloc_tensor.args[3].value != "global":
# Return if the allocation of all-reduce input is not recorded,
# or the scope of the allocation is not global.
return call

# Set the scope of the alloc_tensor to IPC memory.
alloc_tensor = self.alloc_map[allreduce_input]
self.binding_replacement_map[alloc_tensor] = relax.op.builtin.alloc_tensor(
alloc_tensor.args[0],
alloc_tensor.args[1],
alloc_tensor.args[2],
relax.StringImm("ipc_memory"),
)
self.binding_replacement_map[call] = relax.Call(
relax.ExternFunc("runtime.disco.cuda_ipc.custom_allreduce"),
args=[call.args[0], relax.PrimValue(self.allreduce_strategy), call.args[2]],
)


@mutator
class _Rewriter(PyExprMutator):
"""Rewrite the IRModule according to the binding replacement provided by the visitor."""

def __init__(
self, mod: IRModule, binding_replacement_map: Dict[relax.Expr, relax.Expr]
) -> None:
super().__init__(mod)
self.mod = mod
self.binding_replacement_map = binding_replacement_map

def transform(self) -> IRModule:
"""Entry point"""
for g_var, func in self.mod.functions_items():
if isinstance(func, relax.Function):
updated_func = self.visit_expr(func)
updated_func = remove_all_unused(updated_func)
self.builder_.update_func(g_var, updated_func)
return self.builder_.get()

def visit_call_(self, call: relax.Call) -> Expr: # pylint: disable=arguments-renamed
return (
super().visit_call_(self.binding_replacement_map[call])
if call in self.binding_replacement_map
else super().visit_call_(call)
)
85 changes: 85 additions & 0 deletions python/tvm/relax/transform/lower_gpu_ipc_alloc_storage.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Lower the storage/tensor allocation on IPC memory.
The pass is written in Python for experiment, fast development.
"""

import tvm
from tvm import relax
from tvm.ir.module import IRModule
from tvm.relax.analysis import remove_all_unused
from tvm.relax.expr import Expr
from tvm.relax.expr_functor import PyExprMutator, mutator


@tvm.transform.module_pass(opt_level=0, name="LowerGPUIPCAllocStorage")
class LowerGPUIPCAllocStorage:
"""Lower the storage/tensor allocation on IPC memory."""

def transform_module(self, mod: IRModule, _ctx: tvm.transform.PassContext) -> IRModule:
"""IRModule-level transformation"""
return _Rewriter(mod).transform()


@mutator
class _Rewriter(PyExprMutator):
def __init__(self, mod: IRModule) -> None:
super().__init__(mod)
self.mod = mod
self.memory_alloc_storage_op = tvm.ir.Op.get("relax.memory.alloc_storage")
self.memory_alloc_tensor_op = tvm.ir.Op.get("relax.memory.alloc_tensor")
self.builtin_alloc_tensor_op = tvm.ir.Op.get("relax.builtin.alloc_tensor")

def transform(self) -> IRModule:
"""Entry point"""
for g_var, func in self.mod.functions_items():
if isinstance(func, relax.Function):
updated_func = self.visit_expr(func)
updated_func = remove_all_unused(updated_func)
self.builder_.update_func(g_var, updated_func)
return self.builder_.get()

def visit_call_(self, call: relax.Call) -> Expr: # pylint: disable=arguments-renamed
if call.op == self.memory_alloc_storage_op and call.args[2].value == "ipc_memory":
return self.rewrite_alloc_storage(call)
elif call.op == self.builtin_alloc_tensor_op and call.args[3].value == "ipc_memory":
return self.rewrite_alloc_tensor(call)
else:
return call

def rewrite_alloc_storage(self, call: relax.Call) -> relax.Call:
shape = call.args[0]
dtype = call.args[3]
return relax.Call(
relax.ExternFunc("runtime.disco.cuda_ipc.alloc_storage"),
args=[shape, dtype],
sinfo_args=[call.struct_info],
)

def rewrite_alloc_tensor(self, call: relax.Call) -> relax.Call:
shape = call.args[0]
dtype = call.args[1]
ipc_alloc_storage = relax.Call(
relax.ExternFunc("runtime.disco.cuda_ipc.alloc_storage"),
args=[shape, dtype],
sinfo_args=[relax.ObjectStructInfo()],
)
return relax.Call(
self.memory_alloc_tensor_op,
args=[ipc_alloc_storage, call.args[2], shape, dtype],
sinfo_args=call.sinfo_args,
)
9 changes: 6 additions & 3 deletions src/relax/op/op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -841,19 +841,22 @@ StructInfo InferStructInfoAllocateTensor(const Call& call, const BlockBuilder& c
}

RELAY_REGISTER_OP("relax.builtin.alloc_tensor")
.set_num_inputs(3)
.set_num_inputs(4)
.add_argument("shape", "Expr", "The shape of the tensor to allocate.")
.add_argument("dtype", "DataTypeImm", "The dtype of the tensor to allocate.")
.add_argument("runtime_device_index", "PrimValue",
"The device index indicating on which device the tensor is to be "
"allocated at runtime. Index -1 is reserved for the host device.")
.add_argument("storage_scope", "StringImm",
"The storage scope of the storage to allocate. Default is global.")
.set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoAllocateTensor)
// memory allocation isn't considered a "visible effect" as far as purity is concerned
.set_attr<Bool>("FPurity", Bool(true));

Expr MakeAllocTensor(Expr shape, DataTypeImm dtype, PrimValue runtime_device_index) {
Expr MakeAllocTensor(Expr shape, DataTypeImm dtype, PrimValue runtime_device_index,
StringImm storage_scope) {
static const Op& op = Op::Get("relax.builtin.alloc_tensor");
return Call(op, {shape, dtype, runtime_device_index}, Attrs(), {});
return Call(op, {shape, dtype, runtime_device_index, storage_scope}, Attrs(), {});
}

TVM_REGISTER_GLOBAL("relax.op.builtin.alloc_tensor").set_body_typed(MakeAllocTensor);
Expand Down
24 changes: 12 additions & 12 deletions src/relax/transform/call_tir_rewrite.cc
Original file line number Diff line number Diff line change
Expand Up @@ -85,12 +85,12 @@ class CallTIRMutator : public ExprMutator {
dev_index = GetDeviceIndex(mod_, tensor_sinfo->vdevice.value());
}
if (!is_inplace) {
outs.push_back(
builder_->Emit(Call(alloc_tensor_op,
{Downcast<ShapeExpr>(tensor_sinfo->shape.value()),
DataTypeImm(tensor_sinfo->dtype), PrimValue::Int64(dev_index)},
Attrs()),
"alloc"));
outs.push_back(builder_->Emit(Call(alloc_tensor_op,
{Downcast<ShapeExpr>(tensor_sinfo->shape.value()),
DataTypeImm(tensor_sinfo->dtype),
PrimValue::Int64(dev_index), StringImm("global")},
Attrs()),
"alloc"));
} else {
// if there is only one output, it must be an in-place argument, but check anyway
ICHECK(inplace_attrs->inplace_indices[0].IntValue() != -1)
Expand All @@ -113,12 +113,12 @@ class CallTIRMutator : public ExprMutator {
<< "call_tir expects all TensorStructInfo has shape, but got " << field_tensor
<< " as an element of TupleStructInfo";
if (!is_inplace || inplace_attrs->inplace_indices[i].IntValue() == -1) {
outs.push_back(
builder_->Emit(Call(alloc_tensor_op,
{Downcast<ShapeExpr>(field_tensor->shape.value()),
DataTypeImm(field_tensor->dtype), PrimValue::Int64(0)},
Attrs()),
"alloc"));
outs.push_back(builder_->Emit(
Call(alloc_tensor_op,
{Downcast<ShapeExpr>(field_tensor->shape.value()),
DataTypeImm(field_tensor->dtype), PrimValue::Int64(0), StringImm("global")},
Attrs()),
"alloc"));
} else {
outs.push_back(Downcast<Tuple>(call->args[1])
->fields[inplace_attrs->inplace_indices[i].IntValue()]);
Expand Down
12 changes: 6 additions & 6 deletions src/relax/transform/lower_alloc_tensor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -35,14 +35,14 @@ class Mutator : public ExprMutator {
static const Op& mem_alloc_tensor_op = Op::Get("relax.memory.alloc_tensor");

if (op->op.same_as(alloc_tensor_op)) {
CHECK_EQ(op->args.size(), 3) << "Op " << op->op << " should have three arguments, "
<< "[shape, dtype, runtime_device_index]. "
CHECK_EQ(op->args.size(), 4) << "Op " << op->op << " should have three arguments, "
<< "[shape, dtype, runtime_device_index, storage_scope]. "
<< "However, received " << GetRef<Call>(op);

auto shape_arg = op->args[0];
auto dtype = Downcast<DataTypeImm>(op->args[1]);
PrimValue runtime_device_index = Downcast<PrimValue>(op->args[2]);
std::string storage_scope = "global";
StringImm storage_scope = Downcast<StringImm>(op->args[3]);

auto shape = [&]() -> Array<PrimExpr> {
if (auto ptr = shape_arg.as<ShapeExprNode>()) {
Expand Down Expand Up @@ -71,9 +71,9 @@ class Mutator : public ExprMutator {

auto offset = PrimValue::Int64(0);

Expr storage = relax::Call(mem_alloc_storage_op,
{ShapeExpr({nbytes}), runtime_device_index,
StringImm(storage_scope), DataTypeImm(DataType::UInt(8))});
Expr storage =
relax::Call(mem_alloc_storage_op, {ShapeExpr({nbytes}), runtime_device_index,
storage_scope, DataTypeImm(DataType::UInt(8))});
storage = builder_->Emit(storage, "storage");
Expr tensor = relax::Call(mem_alloc_tensor_op, {storage, offset, shape_arg, dtype});
return tensor;
Expand Down
Loading

0 comments on commit 0af4ac4

Please sign in to comment.