Skip to content

Commit

Permalink
[TIR] LowerTVMBuiltin may use device_type from PrimFunc annotation
Browse files Browse the repository at this point in the history
If an allocation occurs within a host function, it may not have a
device/host split.
  • Loading branch information
Lunderberg committed Mar 14, 2024
1 parent 33b9ea4 commit c3a6aba
Showing 1 changed file with 18 additions and 6 deletions.
24 changes: 18 additions & 6 deletions src/tir/transforms/lower_tvm_builtin.cc
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,19 @@ namespace tir {
// These information are needed during codegen.
class BuiltinLower : public StmtExprMutator {
public:
static PrimFunc Build(PrimFunc func) {
Optional<PrimExpr> device_type = NullOpt;
if (auto target = func->GetAttr<Target>(tvm::attr::kTarget)) {
device_type = Integer(target.value()->kind->default_device_type);
}

BuiltinLower mutator(device_type);
func.CopyOnWrite()->body = mutator.VisitBodyAndRealizeAlloca(func->body);
return func;
}

BuiltinLower(Optional<PrimExpr> device_type = NullOpt) : device_type_(device_type) {}

// NOTE: Right now, we make the following scoping requirement
// for memory allocated by the following primitives
// - tvm_stack_make_array
Expand Down Expand Up @@ -656,13 +669,12 @@ class BuiltinLower : public StmtExprMutator {
namespace transform {

Pass LowerTVMBuiltin() {
auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) {
if (IsHostFunc(f).value_or(false)) {
auto global_symbol = f->GetAttr<String>(tvm::attr::kGlobalSymbol);
f.CopyOnWrite()->body = BuiltinLower().Build(f->body);
VLOG(2) << "LowerTVMBuiltin: " << f;
auto pass_func = [](PrimFunc func, IRModule m, PassContext ctx) {
if (IsHostFunc(func).value_or(false)) {
func = BuiltinLower::Build(func);
VLOG(2) << "LowerTVMBuiltin: " << func;
}
return f;
return func;
};
return CreatePrimFuncPass(pass_func, 0, "tir.LowerTVMBuiltin", {});
}
Expand Down

0 comments on commit c3a6aba

Please sign in to comment.