Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[TIR] SplitHostDevice, handle subroutines #14918

Merged
merged 17 commits into from
May 26, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
17 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions include/tvm/target/target.h
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,9 @@ class Target : public ObjectRef {
*/
static Target WithHost(const Target& target, const Target& host);

/*! \return The target with the host stripped out */
Target WithoutHost() const;

/*!
* \brief Returns true if \p this target represents an external codegen. If so,
* \p this->kind->name can be used as the "Compiler" attribute on partitioned functions,
Expand Down
38 changes: 38 additions & 0 deletions include/tvm/tir/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -263,13 +263,51 @@ TVM_DLL Pass LowerCustomDatatypes();
*/
TVM_DLL Pass DecorateDeviceScope();

/*!
* \brief Annotate locations that should be run on the device
*
* Insert `AttrStmt` nodes specifying a target on which regions within
* the PrimFunc should be executed. Only modifies functions that have
* a `tvm::attr::kTarget` attribute, and where that target defines a
* host.
*
* \return The pass.
*/
TVM_DLL Pass AnnotateDeviceRegions();

/*!
* \brief Split the function into a host function and device functions.
*
* The resulting host-side function will keep the same
* `tvm::attr::kTarget` attribute (e.g. `T.target("cuda",
* host=T.target("llvm"))`). This ensures that `MakePackedAPI` knows
* which device type should be used for the input buffers.
*
* The resulting device-side function will
* have the host stripped from its target attribute
* (e.g. `T.target("cuda")`).
*
* \return The pass.
*/
TVM_DLL Pass SplitHostDevice();

/*!
* \brief Lower cross-device function calls.
*
* Prior to this pass, host to device calls are represented as
* subroutine calls, with environment parameters (e.g. env_thread)
* specified internally. The device function is an internal function,
* without a `tvm::attr::kGlobalSymbol` attribute.
*
* After this pass, host to device calls are represented as
* tvm_call_packed built-in. The device function is an
* externally-exposed function, with a non-empty
* `tvm::attr::kGlobalSymbol` attribute.
*
* \return The pass.
*/
TVM_DLL Pass LowerDeviceKernelLaunch();

/*!
* \brief skip assert stmt.
*
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/tir/op.py
Original file line number Diff line number Diff line change
Expand Up @@ -445,7 +445,7 @@ def call_tir(global_var: tvm.ir.GlobalVar, *args):
The call expression.
"""
assert isinstance(global_var, tvm.ir.GlobalVar)
return Call(dtype="handle", op=global_var, args=args)
return Call(dtype="void", op=global_var, args=args)


def start_profile_intrinsic(id):
Expand Down
38 changes: 38 additions & 0 deletions python/tvm/tir/transform/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -435,6 +435,22 @@ def MakeUnpackedAPI():
return _ffi_api.MakeUnpackedAPI() # type: ignore


def AnnotateDeviceRegions():
"""Annotate locations that should be run on the device

Insert `AttrStmt` nodes specifying a target on which regions
within the PrimFunc should be executed. Only modifies functions
that have a `tvm::attr::kTarget` attribute, and where that target
defines a host.

Returns
-------
fpass : tvm.transform.Pass
The result pass
"""
return _ffi_api.AnnotateDeviceRegions() # type: ignore


def SplitHostDevice():
"""Split the function into a host function and device functions.

Expand All @@ -446,6 +462,28 @@ def SplitHostDevice():
return _ffi_api.SplitHostDevice() # type: ignore


def LowerDeviceKernelLaunch():
"""Lower cross-device function calls.

Prior to this pass, host to device calls are represented as
subroutine calls, with environment parameters (e.g. env_thread)
specified internally. The device function is an internal
function, without a `tvm::attr::kGlobalSymbol` attribute.

After this pass, host to device calls are represented as
tvm_call_packed built-in. The device function is an
externally-exposed function, with a non-empty
`tvm::attr::kGlobalSymbol` attribute.


Returns
-------
fpass : tvm.transform.Pass
The result pass
"""
return _ffi_api.LowerDeviceKernelLaunch() # type: ignore


def DecorateDeviceScope():
"""Decorate all the function's body as device function.

Expand Down
3 changes: 3 additions & 0 deletions src/driver/driver_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -587,7 +587,10 @@ transform::Sequential MixedModulePassManager(IRModule mixed_mod, Target target)
mixed_pass_list.push_back(tir::transform::MakePackedAPI());
}
mixed_pass_list.push_back(tir::transform::BF16StorageLegalize());

mixed_pass_list.push_back(tir::transform::AnnotateDeviceRegions());
mixed_pass_list.push_back(tir::transform::SplitHostDevice());
mixed_pass_list.push_back(tir::transform::LowerDeviceKernelLaunch());

return transform::Sequential(mixed_pass_list);
}
Expand Down
10 changes: 10 additions & 0 deletions src/target/target.cc
Original file line number Diff line number Diff line change
Expand Up @@ -662,6 +662,16 @@ Map<String, ObjectRef> TargetNode::Export() const {

Optional<Target> TargetNode::GetHost() const { return this->host.as<Target>(); }

Target Target::WithoutHost() const {
if ((*this)->GetHost()) {
auto output = make_object<TargetNode>(*get());
output->host = NullOpt;
return Target(output);
} else {
return *this;
}
}

int TargetNode::GetTargetDeviceType() const {
if (Optional<Integer> device_type = GetAttr<Integer>("target_device_type")) {
return Downcast<Integer>(device_type)->value;
Expand Down
81 changes: 81 additions & 0 deletions src/tir/transforms/annotate_device_regions.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file annotate_device_regions.cc
* \brief Split device function from host.
*/
#include <tvm/ir/transform.h>
#include <tvm/runtime/registry.h>
#include <tvm/target/target.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>

namespace tvm {
namespace tir {

class DeviceRegionAnnotater : public StmtMutator {
public:
explicit DeviceRegionAnnotater(Target device_target) : device_target_(device_target) {}

Stmt VisitStmt_(const AttrStmtNode* op) final {
if (op->attr_key == tvm::attr::kTarget) {
// If a target attribute already exists, use it as-is.
return GetRef<Stmt>(op);
} else if (op->attr_key == attr::thread_extent || op->attr_key == attr::pipeline_exec_scope ||
op->attr_key == attr::device_scope) {
// These attributes are only allowed in device-side code, so
// they should be annotated with the function's default target.
Stmt body = GetRef<Stmt>(op);
return AttrStmt(device_target_, tvm::attr::kTarget, 0, body);
} else {
// All other annotations are ignored
return StmtMutator::VisitStmt_(op);
}
}

private:
Target device_target_;
};

namespace transform {

Pass AnnotateDeviceRegions() {
auto pass_func = [](PrimFunc func, IRModule mod, PassContext ctx) -> PrimFunc {
auto opt_target = func->GetAttr<Target>(tvm::attr::kTarget);
ICHECK(opt_target) << "AnnotateDeviceRegions: Require the target attribute";
Target target = opt_target.value();

if (target->GetHost()) {
DeviceRegionAnnotater mutator(target.WithoutHost());
func.CopyOnWrite()->body = mutator(func->body);
}
return func;
};

return CreatePrimFuncPass(pass_func, 0, "tir.AnnotateDeviceRegions", {});
}

TVM_REGISTER_GLOBAL("tir.transform.AnnotateDeviceRegions").set_body_typed(AnnotateDeviceRegions);

} // namespace transform
} // namespace tir
} // namespace tvm
Loading