[Relax] Capture symbolic vars in struct info of weights #16834

Merged
merged 1 commit on Apr 3, 2024
48 changes: 33 additions & 15 deletions src/relax/transform/rewrite_cuda_graph.cc
@@ -239,13 +239,31 @@ class CUDAGraphRewritePlanner : public ExprVisitor {
if (pair.second->IsInstance<FunctionNode>()) {
// If a function has the num_input attribute, the last func->params.size() - num_inputs
// inputs are assumed to be fixed and thus they can be captured into a cuda graph.
// The symbolic variables in the struct info of the fixed inputs (weights) are also allowed
// to be captured.
// If hints for capturing symbolic variables are provided via the
// 'relax.rewrite_cuda_graph.capture_symbolic_vars' annotation, the actual variables with
// these names are extracted from the struct info and are also allowed to be captured.
const auto& func = Downcast<Function>(pair.second);
if (auto num_input = func->attrs.GetAttr<Integer>(attr::kNumInput)) {
for (size_t i = num_input.value().IntValue(); i < func->params.size(); ++i) {
auto num_inputs =
func->attrs.GetAttr<Integer>(attr::kNumInput).value_or(Integer(func->params.size()));
auto capture_symbolic_var_name_hints = ExtractSymbolicVarHints(func);
for (int i = 0; i < static_cast<int>(func->params.size()); ++i) {
Array<tir::Var> symbolic_vars = DefinableTIRVarsInStructInfo(
Downcast<StructInfo>(func->params[i]->struct_info_.value()));
if (i < num_inputs.IntValue()) {
for (const auto& symbolic_var : symbolic_vars) {
if (capture_symbolic_var_name_hints.count(symbolic_var->name_hint)) {
capture_symbolic_vars_.insert(symbolic_var.get());
}
}
} else {
static_vars_.insert(func->params[i].get());
for (const auto& symbolic_var : symbolic_vars) {
capture_symbolic_vars_.insert(symbolic_var.get());
}
}
}
CollectSymbolicVarHints(func);
disabled_storage_vars_ = OutputStorageCollector::Collect(func);
VisitExpr(func);
}
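For context (not part of the diff), below is a minimal TVMScript sketch of how the two capture paths in the loop above would be driven from the user side: symbolic variables appearing in the struct info of the weights (here `m`) are captured automatically once `num_input` marks those parameters as fixed, while a symbolic variable that only appears in the dynamic inputs (here `n`) has to be named in the `relax.rewrite_cuda_graph.capture_symbolic_vars` hint to become capturable. The module and variable names are illustrative only.

```python
# Illustrative sketch only; not taken from this PR's test suite.
from tvm.script import ir as I
from tvm.script import relax as R


@I.ir_module
class Annotated:
    @R.function
    def main(
        x: R.Tensor(("n",), "float16"),  # dynamic input: "n" needs an explicit hint
        w: R.Tensor(("m",), "float16"),  # fixed weight: "m" is captured automatically
    ) -> R.Tensor(("n",), "float16"):
        R.func_attr(
            {
                "num_input": 1,
                "relax.rewrite_cuda_graph.capture_symbolic_vars": ["n"],
            }
        )
        return x
```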
@@ -284,17 +302,16 @@ class CUDAGraphRewritePlanner : public ExprVisitor {
}

/*!
* \brief Collect the name hints of the symbolic variables that are allowed to be captured.
* \brief Extract the name hints of the symbolic variables that are allowed to be captured
* from the function attributes.
*/
void CollectSymbolicVarHints(const Function& func) {
capture_symbolic_vars_.clear();
if (auto symbolic_vars =
func->attrs.GetAttr<Array<String>>("relax.rewrite_cuda_graph.capture_symbolic_vars")) {
for (const auto& var : symbolic_vars.value()) {
capture_symbolic_vars_.insert(var);
}
}
std::unordered_set<String> ExtractSymbolicVarHints(const Function& func) {
auto symbolic_var_names =
func->attrs.GetAttr<Array<String>>("relax.rewrite_cuda_graph.capture_symbolic_vars")
.value_or(Array<String>());
return {symbolic_var_names.begin(), symbolic_var_names.end()};
}

/*!
*\brief Start a new static region. This method should be called when encountering a
* CUDA kernel launch (calls to PrimFunc or ExternFunc) that only depends on static parameters.
@@ -467,7 +484,7 @@ class CUDAGraphRewritePlanner : public ExprVisitor {
bool is_static = true;
tir::PostOrderVisit(expr, [&](const ObjectRef& e) {
if (auto var = e.as<tir::VarNode>()) {
if (!capture_symbolic_vars_.count(var->name_hint)) {
if (!capture_symbolic_vars_.count(var)) {
is_static = false;
return;
}
@@ -596,8 +613,9 @@ class CUDAGraphRewritePlanner : public ExprVisitor {
FunctionScope current_function_scope_;
// Variables whose buffer address is fixed
std::unordered_set<const VarNode*> static_vars_;
// The name of the variables that are allowed to be symbolic
std::unordered_set<String> capture_symbolic_vars_;
// Symbolic variables that are allowed to be captured. This can come from symbolic shapes of
// weights or hints in the function annotations.
std::unordered_set<const tir::VarNode*> capture_symbolic_vars_;
// Binding to the FuncBuilder if the binding is lifted. This is used to update the inputs/outputs
// of the lifted function when its binding is used outside.
std::unordered_map<const VarNode*, FuncBuilder*> binding_to_region_;
88 changes: 88 additions & 0 deletions tests/python/relax/test_transform_rewrite_cuda_graph.py
@@ -1088,5 +1088,93 @@ def main(x: R.Tensor((8,), dtype="float32")) -> R.Tuple(R.Tensor((8,), dtype="fl
return gv


class TestStaticInputWithSymbolicShape(BaseCompare):
@I.ir_module
class Before:
@R.function
def main(x: R.Tensor((8,), "float16"), w: R.Tensor(("m",))):
m = T.int64()
R.func_attr({"relax.force_pure": True, "num_input": 1})
storage1 = R.memory.alloc_storage(R.shape([8]), 0, "global", "float16")
alloc1 = R.memory.alloc_tensor(storage1, 0, R.shape([8]), "float16")
_ = R.call_packed("dummy", x, w, alloc1, sinfo_args=(R.Tuple,))
storage2 = R.memory.alloc_storage(R.shape([8]), 0, "global", "float16")
alloc2 = R.memory.alloc_tensor(storage2, 0, R.shape([8]), "float16")
_1 = R.call_packed("dummy", alloc1, w, alloc2, sinfo_args=(R.Tuple,))
storage3 = R.memory.alloc_storage(R.shape([8]), 0, "global", "float16")
alloc3 = R.memory.alloc_tensor(storage3, 0, R.shape([8]), "float16")
_2 = R.call_packed("dummy", alloc2, w, alloc3, sinfo_args=(R.Tuple,))
gv = (alloc3,)
return gv

@I.ir_module
class Expected:
@R.function(private=True)
def cuda_graph_alloc() -> R.Tuple(R.Object, R.Object):
R.func_attr({"relax.force_pure": True})
storage1: R.Object = R.memory.alloc_storage(
R.shape([8]), R.prim_value(0), R.str("global"), R.dtype("float16")
)
storage2: R.Object = R.memory.alloc_storage(
R.shape([8]), R.prim_value(0), R.str("global"), R.dtype("float16")
)
gv: R.Tuple(R.Object, R.Object) = storage1, storage2
return gv

@R.function(private=True)
def main_cuda_graph_capture(
alloc1: R.Tensor((8,), dtype="float16"),
w: R.Tensor(("m",)),
alloc2: R.Tensor((8,), dtype="float16"),
shape_expr: R.Shape(["m"]),
) -> R.Tuple:
m = T.int64()
R.func_attr({"relax.force_pure": True})
R.call_packed("dummy", alloc1, w, alloc2, sinfo_args=(R.Tuple,))
R.tuple()
return R.tuple()

@R.function
def main(
x: R.Tensor((8,), dtype="float16"), w: R.Tensor(("m",))
) -> R.Tuple(R.Tensor((8,), dtype="float16")):
m = T.int64()
R.func_attr({"num_input": 1, "relax.force_pure": True})
cls = Expected
gv: R.Tuple(R.Object, R.Object) = R.call_builtin_with_ctx(
"vm.builtin.cuda_graph.get_cached_alloc",
(cls.cuda_graph_alloc, R.prim_value(0)),
sinfo_args=(R.Tuple(R.Object, R.Object),),
)
storage1: R.Object = gv[0]
alloc1: R.Tensor((8,), dtype="float16") = R.memory.alloc_tensor(
storage1, R.prim_value(0), R.shape([8]), R.dtype("float16")
)
R.call_packed("dummy", x, w, alloc1, sinfo_args=(R.Tuple,))
storage2: R.Object = gv[1]
alloc2: R.Tensor((8,), dtype="float16") = R.memory.alloc_tensor(
storage2, R.prim_value(0), R.shape([8]), R.dtype("float16")
)
R.call_builtin_with_ctx(
"vm.builtin.cuda_graph.run_or_capture",
(
cls.main_cuda_graph_capture,
(alloc1, w, alloc2, R.shape([m])),
R.prim_value(0),
R.shape([m]),
),
sinfo_args=(R.Tuple,),
)
storage3: R.Object = R.memory.alloc_storage(
R.shape([8]), R.prim_value(0), R.str("global"), R.dtype("float16")
)
alloc3: R.Tensor((8,), dtype="float16") = R.memory.alloc_tensor(
storage3, R.prim_value(0), R.shape([8]), R.dtype("float16")
)
R.call_packed("dummy", alloc2, w, alloc3, sinfo_args=(R.Tuple,))
gv_1: R.Tuple(R.Tensor((8,), dtype="float16")) = (alloc3,)
return gv_1


if __name__ == "__main__":
tvm.testing.main()
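For completeness, `BaseCompare` (defined earlier in this test file) presumably applies the pass and checks structural equality against the expected module; a rough standalone equivalent for the new case, using TVM's public `RewriteCUDAGraph` pass and `tvm.ir.assert_structural_equal` (assumed public APIs, not shown in this diff), would be:

```python
# Rough standalone equivalent of what BaseCompare is assumed to do for the
# TestStaticInputWithSymbolicShape case above.
import tvm
from tvm import relax

before = TestStaticInputWithSymbolicShape.Before
expected = TestStaticInputWithSymbolicShape.Expected

after = relax.transform.RewriteCUDAGraph()(before)
tvm.ir.assert_structural_equal(after, expected)
```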