[Relax][Pass] Lowering passes for GPU IPC memory and allreduce
This PR introduces the lowering passes for GPU IPC memory and all-reduce. It contains the following changes:

1. a pass `IPCAllReduceRewrite` which rewrites `"runtime.disco.allreduce"` to `"runtime.disco.cuda_ipc.custom_allreduce"`, and accordingly rewrites the storage scope of the all-reduce inputs from `"global"` to `"ipc_memory"`;
2. a memory-planning enhancement that makes the planning aware of storage scopes, so that each storage scope is planned independently;
3. a pass `LowerGPUIPCAllocStorage` that rewrites the storage allocation of IPC memory from builtin ops to calls to the function `"runtime.disco.cuda_ipc.alloc_storage"`;
4. support for the op `relax.builtin.alloc_tensor` with a storage scope, where the default storage scope is `"global"`.

We write the new passes in Python for experimentation and fast development. They also serve as good demos of the efficient development that TVM's architecture enables.
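As a rough usage sketch (not part of this commit), the two passes can be combined with scope-aware memory planning in a pass pipeline. The snippet below is an assumption for illustration: it presumes the passes are exported from `tvm.relax.transform`, the helper name `lower_for_ipc` is hypothetical, and the exact pipeline position may differ in practice.

import tvm
from tvm import relax
from tvm.relax.transform import IPCAllReduceRewrite, LowerGPUIPCAllocStorage

def lower_for_ipc(mod: tvm.IRModule) -> tvm.IRModule:
    # Hypothetical ordering: rewrite the all-reduce first so that the
    # "ipc_memory" scope exists before planning and lowering run.
    seq = tvm.transform.Sequential(
        [
            # No-op when "runtime.disco.cuda_ipc.custom_allreduce" is absent.
            IPCAllReduceRewrite(allreduce_strategy=1),
            # Scope-aware planning: each storage scope is planned independently.
            relax.transform.StaticPlanBlockMemory(),
            # Lower "ipc_memory" allocations to runtime.disco.cuda_ipc.alloc_storage.
            LowerGPUIPCAllocStorage(),
        ]
    )
    return seq(mod)

Note that `IPCAllReduceRewrite` is safe to keep in a generic pipeline: it returns the module unchanged when the customized all-reduce runtime function is not registered.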
1 parent e257fb8 · commit 0af4ac4
Showing 10 changed files with 554 additions and 36 deletions.
@@ -0,0 +1,150 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Rewrite all-reduce operation to customized all-reduce impl with IPC memory.

The pass is written in Python for experimentation and fast development.
"""

from typing import Dict

import tvm
from tvm import relax
from tvm.ir.module import IRModule
from tvm.relax.analysis import remove_all_unused
from tvm.relax.expr import Expr, Var
from tvm.relax.expr_functor import PyExprMutator, PyExprVisitor, mutator, visitor

@tvm.transform.module_pass(opt_level=0, name="IPCAllReduceRewrite")
class IPCAllReduceRewrite:
    """Rewrite all-reduce operation to customized all-reduce impl with IPC memory."""

    def __init__(self, allreduce_strategy: int) -> None:
        """Constructor

        Parameters
        ----------
        allreduce_strategy : int
            The all-reduce strategy. Only 1 and 2 are supported:
            1 stands for one-shot, and 2 stands for two-shot.
        """
        if allreduce_strategy not in [1, 2]:
            raise ValueError(f"All-reduce strategy {allreduce_strategy} is not supported.")
        self.allreduce_strategy = allreduce_strategy

    def transform_module(self, mod: IRModule, _ctx: tvm.transform.PassContext) -> IRModule:
        """IRModule-level transformation"""
        fcustom_allreduce = tvm.get_global_func(
            "runtime.disco.cuda_ipc.custom_allreduce", allow_missing=True
        )
        if fcustom_allreduce is None:
            # Customized allreduce is not available.
            return mod

        binding_replacement_map = _Visitor(self.allreduce_strategy).visit(mod)
        return _Rewriter(mod, binding_replacement_map).transform()

@visitor
class _Visitor(PyExprVisitor):  # pylint: disable=abstract-method
    def __init__(self, allreduce_strategy: int) -> None:
        self.allreduce_strategy = allreduce_strategy
        self.alloc_map: Dict[Var, relax.Call] = {}
        self.binding_replacement_map: Dict[relax.Expr, relax.Expr] = {}
        self.builtin_alloc_tensor_op = tvm.ir.Op.get("relax.builtin.alloc_tensor")
        self.reshape_op = tvm.ir.Op.get("relax.reshape")

    def visit(self, mod: IRModule) -> Dict[relax.Expr, relax.Expr]:
        """Entry point"""
        for _, func in mod.functions_items():
            if isinstance(func, relax.Function):
                self.alloc_map.clear()
                self.visit_expr(func)
        return self.binding_replacement_map

    def visit_var_binding_(self, binding: relax.VarBinding) -> None:
        super().visit_var_binding_(binding)
        if (
            isinstance(binding.value, relax.Call)
            and binding.value.op == self.builtin_alloc_tensor_op
        ):
            # Record the allocation that defines this var.
            self.alloc_map[binding.var] = binding.value
        elif isinstance(binding.value, relax.Var) and binding.value in self.alloc_map:
            # Propagate through var-to-var aliasing.
            self.alloc_map[binding.var] = self.alloc_map[binding.value]
        elif (
            isinstance(binding.value, relax.Call)
            and binding.value.op == self.reshape_op
            and binding.value.args[0] in self.alloc_map
        ):
            # Reshape does not allocate; track it back to the original allocation.
            self.alloc_map[binding.var] = self.alloc_map[binding.value.args[0]]

    def visit_call_(self, call: relax.Call) -> None:  # pylint: disable=arguments-renamed
        if (
            not isinstance(call.op, relax.ExternFunc)
            or call.op.global_symbol != "runtime.disco.allreduce"
            or call.args[1].values[0] != 0
        ):
            # Return if the call is not a summation all-reduce.
            return

        assert len(call.args) == 3
        allreduce_input = call.args[0]
        alloc_tensor = self.alloc_map.get(allreduce_input, None)
        if alloc_tensor is None or alloc_tensor.args[3].value != "global":
            # Return if the allocation of the all-reduce input is not recorded,
            # or the scope of the allocation is not global.
            return

        # Set the scope of the alloc_tensor to IPC memory.
        self.binding_replacement_map[alloc_tensor] = relax.op.builtin.alloc_tensor(
            alloc_tensor.args[0],
            alloc_tensor.args[1],
            alloc_tensor.args[2],
            relax.StringImm("ipc_memory"),
        )
        self.binding_replacement_map[call] = relax.Call(
            relax.ExternFunc("runtime.disco.cuda_ipc.custom_allreduce"),
            args=[call.args[0], relax.PrimValue(self.allreduce_strategy), call.args[2]],
        )

@mutator
class _Rewriter(PyExprMutator):
    """Rewrite the IRModule according to the binding replacement provided by the visitor."""

    def __init__(
        self, mod: IRModule, binding_replacement_map: Dict[relax.Expr, relax.Expr]
    ) -> None:
        super().__init__(mod)
        self.mod = mod
        self.binding_replacement_map = binding_replacement_map

    def transform(self) -> IRModule:
        """Entry point"""
        for g_var, func in self.mod.functions_items():
            if isinstance(func, relax.Function):
                updated_func = self.visit_expr(func)
                updated_func = remove_all_unused(updated_func)
                self.builder_.update_func(g_var, updated_func)
        return self.builder_.get()

    def visit_call_(self, call: relax.Call) -> Expr:  # pylint: disable=arguments-renamed
        return (
            super().visit_call_(self.binding_replacement_map[call])
            if call in self.binding_replacement_map
            else super().visit_call_(call)
        )
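To make the rewrite concrete, here is an illustrative before/after sketch of the transformation this visitor/rewriter pair performs. The Relax pseudocode is hypothetical; shapes and variable names are assumptions, not taken from the commit.

# Illustrative pseudocode only; not actual TVMScript from this commit.
#
# Before IPCAllReduceRewrite, for a summation all-reduce (reduce kind 0):
#   buf = relax.builtin.alloc_tensor((4096,), "float16", 0, "global")
#   out = call "runtime.disco.allreduce"(buf, shape(0), dst)
#
# After the pass with allreduce_strategy=1 (one-shot):
#   buf = relax.builtin.alloc_tensor((4096,), "float16", 0, "ipc_memory")
#   out = call "runtime.disco.cuda_ipc.custom_allreduce"(buf, 1, dst)
#
# The two rewrites are coupled: the input's allocation moves to the
# "ipc_memory" scope, and the reduce-kind argument is replaced by the
# strategy value carried in a PrimValue.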
@@ -0,0 +1,85 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Lower the storage/tensor allocation on IPC memory.

The pass is written in Python for experimentation and fast development.
"""

import tvm
from tvm import relax
from tvm.ir.module import IRModule
from tvm.relax.analysis import remove_all_unused
from tvm.relax.expr import Expr
from tvm.relax.expr_functor import PyExprMutator, mutator

@tvm.transform.module_pass(opt_level=0, name="LowerGPUIPCAllocStorage")
class LowerGPUIPCAllocStorage:
    """Lower the storage/tensor allocation on IPC memory."""

    def transform_module(self, mod: IRModule, _ctx: tvm.transform.PassContext) -> IRModule:
        """IRModule-level transformation"""
        return _Rewriter(mod).transform()

@mutator
class _Rewriter(PyExprMutator):
    def __init__(self, mod: IRModule) -> None:
        super().__init__(mod)
        self.mod = mod
        self.memory_alloc_storage_op = tvm.ir.Op.get("relax.memory.alloc_storage")
        self.memory_alloc_tensor_op = tvm.ir.Op.get("relax.memory.alloc_tensor")
        self.builtin_alloc_tensor_op = tvm.ir.Op.get("relax.builtin.alloc_tensor")

    def transform(self) -> IRModule:
        """Entry point"""
        for g_var, func in self.mod.functions_items():
            if isinstance(func, relax.Function):
                updated_func = self.visit_expr(func)
                updated_func = remove_all_unused(updated_func)
                self.builder_.update_func(g_var, updated_func)
        return self.builder_.get()

    def visit_call_(self, call: relax.Call) -> Expr:  # pylint: disable=arguments-renamed
        # Only allocations with the "ipc_memory" storage scope are rewritten.
        if call.op == self.memory_alloc_storage_op and call.args[2].value == "ipc_memory":
            return self.rewrite_alloc_storage(call)
        elif call.op == self.builtin_alloc_tensor_op and call.args[3].value == "ipc_memory":
            return self.rewrite_alloc_tensor(call)
        else:
            return call

    def rewrite_alloc_storage(self, call: relax.Call) -> relax.Call:
        shape = call.args[0]
        dtype = call.args[3]
        return relax.Call(
            relax.ExternFunc("runtime.disco.cuda_ipc.alloc_storage"),
            args=[shape, dtype],
            sinfo_args=[call.struct_info],
        )

    def rewrite_alloc_tensor(self, call: relax.Call) -> relax.Call:
        shape = call.args[0]
        dtype = call.args[1]
        # Allocate the backing storage on IPC memory via the extern function,
        # then create the tensor from that storage with memory.alloc_tensor.
        ipc_alloc_storage = relax.Call(
            relax.ExternFunc("runtime.disco.cuda_ipc.alloc_storage"),
            args=[shape, dtype],
            sinfo_args=[relax.ObjectStructInfo()],
        )
        return relax.Call(
            self.memory_alloc_tensor_op,
            args=[ipc_alloc_storage, call.args[2], shape, dtype],
            sinfo_args=call.sinfo_args,
        )
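Similarly, a hedged sketch of the lowering performed by `rewrite_alloc_tensor`; the pseudocode is illustrative, with an assumed shape and dtype, and the device-index argument shown as 0.

# Illustrative pseudocode only; not actual TVMScript from this commit.
#
# Before LowerGPUIPCAllocStorage:
#   t = relax.builtin.alloc_tensor((4096,), "float16", 0, "ipc_memory")
#
# After the pass, the builtin allocation is split into an extern storage
# allocation on IPC memory plus an ordinary tensor-from-storage op:
#   s = call "runtime.disco.cuda_ipc.alloc_storage"((4096,), "float16")
#   t = relax.memory.alloc_tensor(s, 0, (4096,), "float16")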