Skip to content

Commit

Permalink
Adds SEScope (Storage/Execution Scope) for use as new unit of plannin…
Browse files Browse the repository at this point in the history
…g in 'device' planning

This is the first step in apache/tvm-rfcs#38 to bring devices
and targets together when doing device planning. I've gone ahead and also included a
memory scope in this object since we will also need to propagate memory scopes across
Relay expressions once this basic preparation is in place. In the meantime that field will be
left as "".

Once device planning works in units of SEScopes it will be possible to directly read off
the device and target for any Relay sub-expression without the need for TargetMaps ort
the construction of default Targets.

SEScopes also support 'Join' and 'Default' operations needed when constraint solving in
the device planner. You can see those in use in my scratchpad branch:
  https://github.com/mbs-octoml/mbs-tvm/tree/mbs-scopes

This PR also brings some duplicated and the ad-hoc 'default target' handling logic
together into a CompilationConfig class. (Again, see the scratchpad branch for how that
will end up being used). I've placed that next to SEScope since it's main purpose is to
  a) establish the default SEScope for primitive ops
  b) establish the SEScope for the 'host'
  c) feed a definitive vector of Targets into device planning so it can resolve all
     "on_device" and "device_copy" device references to their full SEScope form.
  • Loading branch information
mbs-octoml committed Oct 21, 2021
1 parent e62075d commit 909f99d
Show file tree
Hide file tree
Showing 20 changed files with 1,000 additions and 35 deletions.
4 changes: 2 additions & 2 deletions include/tvm/ir/attrs.h
Original file line number Diff line number Diff line change
Expand Up @@ -489,7 +489,7 @@ struct AttrInitEntry {
~AttrInitEntry() DMLC_THROW_EXCEPTION {
if (value_missing_) {
std::ostringstream os;
os << type_key_ << ": Cannot find required field \'" << key_ << "\' during initialization."
os << type_key_ << ": Cannot find required field \'" << key_ << "\' during initialization. "
<< "If the key is defined check that its type matches the declared type.";
throw AttrError(os.str());
}
Expand Down Expand Up @@ -806,7 +806,7 @@ class AttrsNode : public BaseAttrsNode {
ICHECK_EQ(args.size() % 2, 0);
const int kLinearSearchBound = 16;
int hit_count = 0;
// applies two stratgies to lookup
// applies two strategies to lookup
if (args.size() < kLinearSearchBound) {
// linear search.
auto ffind = [&args](const char* key, runtime::TVMArgValue* val) {
Expand Down
384 changes: 384 additions & 0 deletions include/tvm/target/se_scope.h

Large diffs are not rendered by default.

13 changes: 9 additions & 4 deletions include/tvm/target/target.h
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,9 @@ class Target : public ObjectRef {
*/
TVM_DLL void ExitWithScope();
};

using TargetMap = Map<Integer, Target>;

/*!
* \brief Check and update host field of the given legacy target and target host pair.
* Note that this function is for legacy target api compatibility issue only, not
Expand All @@ -187,22 +190,24 @@ class Target : public ObjectRef {
* \param host The pointer to a Target typed object for target host to be updated
*/
void CheckAndUpdateHostConsistency(Target* target, Target* host);

/*!
* \brief Check and update host field of the given legacy heterogeneous targets and
* target host.Note that this function is for legacy target api compatibility issue only,
* not recommended for other use.
* \param target The pointer to a Map objects with values being Target objects
* \param target_map The pointer to a Map objects with values being Target objects
* \param host The Target typed object for target host to be updated
*/
void CheckAndUpdateHostConsistency(Map<Integer, Target>* target, Target* host);
void CheckAndUpdateHostConsistency(TargetMap* target_map, Target* host);

/*!
* \brief Check and update host field of the given legacy heterogeneous targets and
* target host.Note that this function is for legacy target api compatibility issue only,
* not recommended for other use.
* \param target The pointer to a Map objects with keys being Target objects
* \param ir_modules The pointer to a Map objects with keys being Target objects
* \param host The Target typed object for target host to be updated
*/
void CheckAndUpdateHostConsistency(Map<Target, IRModule>* target, Target* host);
void CheckAndUpdateHostConsistency(Map<Target, IRModule>* ir_modules, Target* host);

} // namespace tvm
#endif // TVM_TARGET_TARGET_H_
1 change: 1 addition & 0 deletions python/tvm/target/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@
riscv_cpu,
hexagon,
)
from .se_scope import make_se_scope
from .tag import list_tags
from .generic_func import GenericFunc
from .generic_func import generic_func, get_native_generic_func, override_native_generic_func
Expand Down
22 changes: 22 additions & 0 deletions python/tvm/target/se_scope.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Python bindings for creating SEScopes."""
from . import _ffi_api


def make_se_scope(device, target=None, memory_scope=""):
return _ffi_api.SEScope_ForDeviceTargetAndMemoryScope(device, target, memory_scope)
3 changes: 3 additions & 0 deletions src/ir/attr_functor.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
#define TVM_IR_ATTR_FUNCTOR_H_

#include <tvm/node/functor.h>
#include <tvm/target/se_scope.h>
#include <tvm/tir/expr.h>

#include <utility>
Expand Down Expand Up @@ -105,6 +106,7 @@ class AttrFunctor<R(const ObjectRef& n, Args...)> {
virtual R VisitAttr_(const tir::CastNode* op, Args... args) ATTR_FUNCTOR_DEFAULT;
virtual R VisitAttr_(const tir::CallNode* op, Args... args) ATTR_FUNCTOR_DEFAULT;
virtual R VisitAttr_(const tir::SelectNode* op, Args... args) ATTR_FUNCTOR_DEFAULT;
virtual R VisitAttr_(const SEScopeNode* op, Args... args) ATTR_FUNCTOR_DEFAULT;

private:
// initialize the vtable.
Expand Down Expand Up @@ -139,6 +141,7 @@ class AttrFunctor<R(const ObjectRef& n, Args...)> {
ATTR_FUNCTOR_DISPATCH(CastNode);
ATTR_FUNCTOR_DISPATCH(CallNode);
ATTR_FUNCTOR_DISPATCH(SelectNode);
ATTR_FUNCTOR_DISPATCH(SEScopeNode);
return vtable;
}
};
Expand Down
4 changes: 2 additions & 2 deletions src/parser/parser.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1437,8 +1437,8 @@ class Parser {
String attr_key = Downcast<String>(raw_attrs["attrs_type_key"]);
if (attr_key.size()) {
raw_attrs.erase("attrs_type_key");
auto tbl = tvm::ReflectionVTable::Global();
auto attr_obj = tbl->CreateObject(attr_key, raw_attrs);
auto attr_obj =
tvm::ReflectionVTable::Global()->CreateObject(attr_key, raw_attrs);
ICHECK(attr_obj.defined());
attrs = Downcast<Attrs>(attr_obj);
}
Expand Down
15 changes: 12 additions & 3 deletions src/printer/relay_text_printer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -773,8 +773,6 @@ Doc RelayTextPrinter::PrintAttr(const ObjectRef& value, bool meta) {
printed_attr << "?";
} else if (auto str_obj = value.as<tvm::StringObj>()) {
printed_attr << Doc::StrLiteral(GetRef<String>(str_obj));
} else if (const auto* on_device_attrs = value.as<OnDeviceAttrs>()) {
printed_attr << "device_type=" << on_device_attrs->device_type;
} else if (meta) {
printed_attr = meta_->GetMetaNode(Downcast<ObjectRef>(value));
} else {
Expand All @@ -787,7 +785,7 @@ Doc RelayTextPrinter::PrintAttr(const ObjectRef& value, bool meta) {
}

Doc RelayTextPrinter::VisitAttrDefault_(const Object* op) {
return PrintAttr(GetRef<ObjectRef>(op), true);
return PrintAttr(GetRef<ObjectRef>(op), /*meta=*/true);
}

Doc RelayTextPrinter::VisitAttr_(const ArrayNode* op) {
Expand All @@ -814,6 +812,17 @@ Doc RelayTextPrinter::VisitAttr_(const tir::StringImmNode* op) {
return Doc::StrLiteral(op->value);
}

Doc RelayTextPrinter::VisitAttr_(const SEScopeNode* op) {
if (show_meta_data_) {
return VisitAttrDefault_(op);
} else {
// TODO(mbs): Surely there's a better way?
std::ostringstream os;
os << GetRef<SEScope>(op);
return Doc::Text(os.str());
}
}

/*!
* \brief Attribute printer which prints the attributes in the call.
*/
Expand Down
1 change: 1 addition & 0 deletions src/printer/text_printer.h
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,7 @@ class RelayTextPrinter : public ExprFunctor<Doc(const Expr&)>,
Doc VisitAttr_(const tir::IntImmNode* op) final;
Doc VisitAttr_(const tir::FloatImmNode* op) final;
Doc VisitAttr_(const tir::StringImmNode* op) final;
Doc VisitAttr_(const SEScopeNode* op) final;

private:
/*! \brief Whether to print meta data. */
Expand Down
4 changes: 2 additions & 2 deletions src/relay/backend/aot_executor_codegen.cc
Original file line number Diff line number Diff line change
Expand Up @@ -716,7 +716,7 @@ class AOTExecutorCodegenModule : public runtime::ModuleNode {
ICHECK_EQ(args.num_args, 2) << "The expected of arguments are: "
<< "runtime::Module mod and Map<int, Target> targets";
void* mod = args[0];
Map<Integer, tvm::Target> targets = args[1];
TargetMap targets = args[1];
init(mod, targets);
});
} else if (name == "codegen") {
Expand Down Expand Up @@ -759,7 +759,7 @@ class AOTExecutorCodegenModule : public runtime::ModuleNode {
const char* type_key() const final { return "RelayGraphRuntimeCodegenModule"; }

private:
void init(void* mod, Map<Integer, tvm::Target> tmp) {
void init(void* mod, TargetMap tmp) {
tec::TargetMap targets;
Target target_host;
for (const auto& it : tmp) {
Expand Down
11 changes: 5 additions & 6 deletions src/relay/backend/build_module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,6 @@ Pass LabelOps();
}
namespace backend {

using TargetsMap = Map<tvm::Integer, tvm::Target>;
using namespace tvm::relay::transform;

/*!
Expand All @@ -56,7 +55,7 @@ struct BuildOutput {
};

struct ExecutorCodegen {
void Init(runtime::Module* m, TargetsMap targets) { CallFunc("init", m, targets); }
void Init(runtime::Module* m, TargetMap targets) { CallFunc("init", m, targets); }

void Codegen(const Function& func, String mod_name) { CallFunc("codegen", func, mod_name); }

Expand Down Expand Up @@ -278,7 +277,7 @@ class RelayBuildModule : public runtime::ModuleNode {
* \param target Target device
* \param target_host Host target device
*/
void Build(IRModule mod, const TargetsMap& targets, const tvm::Target& target_host,
void Build(IRModule mod, const TargetMap& targets, const tvm::Target& target_host,
const String executor, const String mod_name) {
for (const auto& pair : targets) {
VLOG(0) << "Build target " << pair.first << " = " << pair.second->str();
Expand Down Expand Up @@ -309,7 +308,7 @@ class RelayBuildModule : public runtime::ModuleNode {
*
* \return relay::IRModule The updated Relay IR module after optimization.
*/
IRModule Optimize(IRModule relay_module, const TargetsMap& targets,
IRModule Optimize(IRModule relay_module, const TargetMap& targets,
const std::unordered_map<std::string, runtime::NDArray>& params) {
targets_ = targets;
// No target_host setup it seems.
Expand Down Expand Up @@ -446,7 +445,7 @@ class RelayBuildModule : public runtime::ModuleNode {
const runtime::PackedFunc* pf = runtime::Registry::Get("codegen.LLVMModuleCreate");
if (!target_host.defined()) target_host = (pf != nullptr) ? Target("llvm") : Target("stackvm");

// Update all the targets in the targets_ TargetsMap
// Update all the targets in the targets_ TargetMap
CheckAndUpdateHostConsistency(&targets_, &target_host);

// Relay IRModule -> IRModule optimizations.
Expand Down Expand Up @@ -542,7 +541,7 @@ class RelayBuildModule : public runtime::ModuleNode {
protected:
std::unique_ptr<ExecutorCodegen> executor_codegen_;
/*! \brief target device */
TargetsMap targets_;
TargetMap targets_;
/*! \brief target host device */
tvm::Target target_host_;
/*! \brief parameters */
Expand Down
2 changes: 1 addition & 1 deletion src/relay/backend/graph_executor_codegen.cc
Original file line number Diff line number Diff line change
Expand Up @@ -619,7 +619,7 @@ class GraphExecutorCodegenModule : public runtime::ModuleNode {
ICHECK_EQ(args.num_args, 2) << "The expected of arguments are: "
<< "runtime::Module mod and Map<int, Target> targets";
void* mod = args[0];
Map<Integer, tvm::Target> tmp = args[1];
TargetMap tmp = args[1];
tec::TargetMap targets;
for (const auto& it : tmp) {
auto dev_type = it.first.as<tir::IntImmNode>();
Expand Down
2 changes: 1 addition & 1 deletion src/relay/backend/te_compiler.h
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ void UpdateFunctionMetadata(Function relay_func,
* \param dev_type
* \return Target
*/
Target GetTargetFromInteger(DLDeviceType dev_type, TargetMap targets);
Target GetTargetFromInteger(DLDeviceType dev_type, tec::TargetMap targets);

/*!
* \brief Update the "main" control function's metadata
Expand Down
2 changes: 1 addition & 1 deletion src/relay/backend/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -446,7 +446,7 @@ inline bool IsCompileEngineCacheDisabled() {
* \param is_vm A boolean indicating if the passes are used for vm or graph runtime.
* \return An array of passes.
*/
Array<Pass> GetPassPrefix(const Map<tvm::Integer, tvm::Target>& targets, bool is_vm);
Array<Pass> GetPassPrefix(const TargetMap& targets, bool is_vm);

/*! \brief Target hash function */
struct TargetStrHash {
Expand Down
16 changes: 8 additions & 8 deletions src/relay/backend/vm/compiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,6 @@ namespace vm {
using namespace tvm::runtime;
using namespace tvm::runtime::vm;
using namespace relay::transform;
using namespace tec;

// (@jroesch): VM passes, eventually declare as passes.
bool IsClosure(const Function& func);
Expand Down Expand Up @@ -251,7 +250,7 @@ int GetFallbackDevice() {

class VMFunctionCompiler : DeviceAwareExprFunctor<void(const Expr& n)> {
public:
VMFunctionCompiler(VMCompilerContext* context, TargetsMap targets, Target target_host)
VMFunctionCompiler(VMCompilerContext* context, TargetMap targets, Target target_host)
: DeviceAwareExprFunctor(context->module),
last_register_(0),
registers_num_(0),
Expand Down Expand Up @@ -458,7 +457,7 @@ class VMFunctionCompiler : DeviceAwareExprFunctor<void(const Expr& n)> {

void EmitShapeFunc(Function func, Array<Expr> inputs, Array<Expr> outputs) {
// Lower shape function
CCacheKey key(func, target_host_);
tec::CCacheKey key(func, target_host_);
auto cfunc = context_->compiler->LowerShapeFunc(key);
int op_index = -1;
// pick the only function inside the context
Expand Down Expand Up @@ -534,7 +533,7 @@ class VMFunctionCompiler : DeviceAwareExprFunctor<void(const Expr& n)> {
}
}

CCacheKey key(func, target);
tec::CCacheKey key(func, target);
auto mangle_fn = [](String name) { return name; };
auto cfunc = context_->compiler->Lower(key, mangle_fn); // <<<< one-func-at-a-time lowering

Expand Down Expand Up @@ -903,7 +902,8 @@ void VMCompiler::SetParam(const std::string& name, runtime::NDArray data_in) {
params_[name] = data_in;
}

void VMCompiler::Lower(IRModule mod, const TargetsMap& targets, const tvm::Target& target_host) {
void VMCompiler::Lower(IRModule mod, const tvm::TargetMap& targets,
const tvm::Target& target_host) {
exec_ = make_object<Executable>();
targets_ = targets;
target_host_ = target_host;
Expand Down Expand Up @@ -969,7 +969,7 @@ void VMCompiler::Lower(IRModule mod, const TargetsMap& targets, const tvm::Targe
backend::UpdateAutoSchedulerOpWeights(context_.compiler);
}

transform::Sequential MemoryOpt(tvm::Target host_target, TargetsMap targets) {
transform::Sequential MemoryOpt(tvm::Target host_target, tvm::TargetMap targets) {
Array<Pass> pass_seqs;
// Remove unused functions
Array<runtime::String> entry_functions{"main"};
Expand Down Expand Up @@ -1015,9 +1015,9 @@ transform::Sequential MemoryOpt(tvm::Target host_target, TargetsMap targets) {
return transform::Sequential(pass_seqs);
}

IRModule VMCompiler::OptimizeModule(IRModule mod, const TargetsMap& targets_arg,
IRModule VMCompiler::OptimizeModule(IRModule mod, const TargetMap& targets_arg,
const Target& target_host_arg) {
TargetsMap targets = targets_arg;
TargetMap targets = targets_arg;
Target target_host = target_host_arg;
CheckAndUpdateHostConsistency(&targets, &target_host);
if (params_.size()) {
Expand Down
7 changes: 3 additions & 4 deletions src/relay/backend/vm/compiler.h
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,6 @@ using TagNameMap = std::unordered_map<size_t, tvm::relay::Constructor>;
using GlobalMap = NodeMap<GlobalVar, Index>;
using ConstMap = NodeMap<Constant, Index>;
using ConstTensorShapeMap = NodeMap<TensorType, std::pair<Index, NDArray>>;
using TargetsMap = Map<tvm::Integer, tvm::Target>;

struct VMCompilerContext {
// The module context for the compilation
Expand Down Expand Up @@ -111,7 +110,7 @@ class VMCompiler : public runtime::ModuleNode {
* to target mapping. For homogeneous compilation, it is a singleton build target.
* \param target_host Host compilation target, if target is device.
*/
void Lower(IRModule mod, const TargetsMap& targets, const tvm::Target& target_host);
void Lower(IRModule mod, const TargetMap& targets, const tvm::Target& target_host);

/*! \brief Generate the machine code for lowered functions. */
void Codegen();
Expand All @@ -127,7 +126,7 @@ class VMCompiler : public runtime::ModuleNode {
*
* \return The optimized IRModule.
*/
IRModule OptimizeModule(IRModule mod, const TargetsMap& targets, const Target& target_host);
IRModule OptimizeModule(IRModule mod, const TargetMap& targets, const Target& target_host);

/*!
* \brief Populate the global function names in a map where the value is used
Expand All @@ -137,7 +136,7 @@ class VMCompiler : public runtime::ModuleNode {

protected:
/*! \brief Target devices. */
TargetsMap targets_;
TargetMap targets_;
/*! \brief Target host device. */
tvm::Target target_host_;
/*! \brief Global shared meta data */
Expand Down
Loading

0 comments on commit 909f99d

Please sign in to comment.