Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Heterogeneous][Bugfix] Fix bug of wrongly generated device_map #2990

Merged
merged 6 commits into from
Apr 13, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 38 additions & 33 deletions src/relay/pass/device_annotation.cc
Original file line number Diff line number Diff line change
Expand Up @@ -334,9 +334,9 @@ class AnnotatationVisitor : private ExprVisitor {
* -Pass 1: Propagating the source device type to ops in a bottom-up way to the
* ancestors until encountering another copy op. For example, this way
* provides add, x, and y device types from the copy operator, `copy1`.
* -Pass 2: Propagating the destination device type of "the last" copy op in a
* top-down manner to the nodes on the output paths. For instance,
* this offers `subtract` and `exp` the same device type as `copy3`.
* -Pass 2: Propagating the destination device type of "the last" copy op to the
* remain nodes. For instance, this offers `subtract` and `exp` the
* same device type as `copy3`.
*/

class DeviceInfo {
Expand Down Expand Up @@ -371,17 +371,22 @@ class DeviceInfo {
}

void VisitExpr_(const ConstantNode* cn) final {
post_dfs_order_.push_back(cn);
post_dfs_order_.push_back(std::make_pair(cn, has_copy_));
}

void VisitExpr_(const CallNode* call) final {
// Skip annotation nodes.
if (!IsOnDeviceNode(call)) {
ExprVisitor::VisitExpr_(call);
post_dfs_order_.push_back(call);

if (GetDeviceCopyNode(call)) {
num_device_copy_ops_++;
bool has_copy_prev = has_copy_;
has_copy_ = true;
ExprVisitor::VisitExpr_(call);
post_dfs_order_.push_back(std::make_pair(call, has_copy_));
has_copy_ = has_copy_prev;
} else {
ExprVisitor::VisitExpr_(call);
post_dfs_order_.push_back(std::make_pair(call, has_copy_));
}
}
}
Expand All @@ -393,23 +398,27 @@ class DeviceInfo {

void VisitExpr_(const TupleGetItemNode* op) final {
ExprVisitor::VisitExpr_(op);
post_dfs_order_.push_back(op);
std::make_pair(op, has_copy_);
}

void VisitExpr_(const VarNode* vn) final { post_dfs_order_.push_back(vn); }
void VisitExpr_(const VarNode* vn) final {
post_dfs_order_.push_back(std::make_pair(vn, has_copy_));
}

void VisitExpr_(const LetNode* ln) final {
ExprVisitor::VisitExpr_(ln);
post_dfs_order_.push_back(ln);
post_dfs_order_.push_back(std::make_pair(ln, has_copy_));
}

void VisitExpr_(const IfNode* in) final {
ExprVisitor::VisitExpr_(in);
post_dfs_order_.push_back(in);
post_dfs_order_.push_back(std::make_pair(in, has_copy_));
}


int num_device_copy_ops_{0};
std::vector<const ExprNode*> post_dfs_order_;
bool has_copy_ = false;
std::vector<std::pair<const ExprNode*, bool>> post_dfs_order_;
friend DeviceInfo;
};

Expand All @@ -435,46 +444,41 @@ class DeviceInfo {

void PropagateDeviceId() {
// Bottom-up propagation.
BottomUpPropagation();
// Top-down propagation.
TopDownPropagation();
int out_dev_type = BottomUpPropagation();
// propagation for remained nodes.
FillPropagation(out_dev_type);
}

void BottomUpPropagation() {
int BottomUpPropagation() {
const CallNode* last_copy_node = nullptr;
int cur_dev_type = -1;
int out_dev_type = -1;
for (auto it = post_visitor_.post_dfs_order_.crbegin();
it != post_visitor_.post_dfs_order_.crend(); ++it) {
if (const auto* node = GetDeviceCopyNode(*it)) {
if (const auto* node = GetDeviceCopyNode(it->first)) {
last_copy_node = dynamic_cast<const CallNode*>(node);
const auto* attrs = last_copy_node->attrs.as<DeviceCopyAttrs>();
cur_dev_type = attrs->src_dev_type;
device_map_.Set(GetRef<Expr>(*it), attrs->dst_dev_type);
if (out_dev_type == -1) out_dev_type = attrs->dst_dev_type;
if (it->second) device_map_.Set(GetRef<Expr>(it->first),
attrs->dst_dev_type);
} else if (last_copy_node) {
Expr expr = GetRef<Expr>(*it);
Expr expr = GetRef<Expr>(it->first);
CHECK_EQ(device_map_.count(expr), 0U);
device_map_.Set(expr, cur_dev_type);
if (it->second) device_map_.Set(expr, cur_dev_type);
}
}
return out_dev_type;
}

void TopDownPropagation() {
const CallNode* last_copy_node = nullptr;
int cur_dev_type = -1;
void FillPropagation(int out_dev_type) {
for (const auto& it : post_visitor_.post_dfs_order_) {
if (const auto* node = GetDeviceCopyNode(it)) {
last_copy_node = dynamic_cast<const CallNode*>(node);
const auto* attrs = last_copy_node->attrs.as<DeviceCopyAttrs>();
cur_dev_type = attrs->dst_dev_type;
} else if (last_copy_node) {
Expr expr = GetRef<Expr>(it);
if (device_map_.count(expr) == 0) {
device_map_.Set(expr, cur_dev_type);
}
}
Expr expr = GetRef<Expr>(it.first);
if (!it.second) device_map_.Set(expr, out_dev_type);
}
}


PostDfsOrderVisitor post_visitor_;
Map<Expr, Integer> device_map_;
};
Expand Down Expand Up @@ -509,3 +513,4 @@ TVM_REGISTER_API("relay._ir_pass.CollectDeviceAnnotationOps")

} // namespace relay
} // namespace tvm

92 changes: 84 additions & 8 deletions tests/python/relay/test_pass_annotation.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,7 @@ def check_storage_and_device_types():
check_storage_and_device_types()


def test_fusible_network():
def run_fusible_network(dev, tgt):
R""" The network is as following:
x y
\ /
Expand Down Expand Up @@ -413,20 +413,96 @@ def test_fallback_all_operators(device, tgt):
check_annotated_graph(annotated_func, expected_func)
test_runtime(target, device, annotated_func)


test_fuse_log_add(dev, tgt)
test_fuse_all(dev, tgt)
test_fallback_exp(dev, tgt)
test_fallback_all_operators(dev, tgt)

def run_unpropagatable_graph(dev, tgt):
R""" The network is as following:
a b c d
\ / \ /
add mul
\ /
subtract
"""

a = relay.var("a", shape=(10, 10))
b = relay.var("b", shape=(10, 10))
c = relay.var("c", shape=(10, 10))
d = relay.var("d", shape=(10, 10))
a_data = np.random.rand(10, 10).astype('float32')
b_data = np.random.rand(10, 10).astype('float32')
c_data = np.random.rand(10, 10).astype('float32')
d_data = np.random.rand(10, 10).astype('float32')
tmp_add = a_data + b_data
tmp_mul = np.multiply(c_data, d_data)
ref_res = np.subtract(tmp_add, tmp_mul)

fallback_device = tvm.context("cpu")
target = {"cpu": "llvm", dev: tgt}
cpu_ctx = fallback_device
dev_ctx = tvm.context(dev)

def annotated():
add = relay.add(a, b)
_add = relay.annotation.on_device(add, dev_ctx)
mul = relay.multiply(c, d)
_mul = relay.annotation.on_device(mul, cpu_ctx)
sub = relay.subtract(add, mul)
_sub = relay.annotation.on_device(sub, dev_ctx)
func = relay.Function([a, b, c, d],
relay.Tuple(tvm.convert([_add, _mul,
_sub, sub])))
func = relay.ir_pass.infer_type(func)
func = relay.ir_pass.rewrite_annotated_ops(func,
dev_ctx.device_type)
func = relay.ir_pass.infer_type(func)
return relay.Function(relay.ir_pass.free_vars(func.body[3]),
func.body[3])

def expected():
add = relay.add(a, b)
mul = relay.multiply(c, d)
copy_mul_sub = relay.device_copy(mul, cpu_ctx, dev_ctx)
sub = relay.subtract(add, copy_mul_sub)
func = relay.Function([a, b, c, d], sub)
return func

annotated_func = annotated()
expected_func = expected()
expected_index = [2, 2, 2, 1, 1, 1, 2, 2]
check_annotated_graph(annotated_func, expected_func)
params = {"a": a_data, "b": b_data, "c": c_data, "d": d_data}
config = {"opt_level": 0}
config["fallback_device"] = fallback_device
with relay.build_config(**config):
graph, lib, params = relay.build(annotated_func, target, params=params)
contexts = [tvm.cpu(0), tvm.context(dev)]
graph_json = json.loads(graph)
if "device_index" in graph_json["attrs"]:
device_index = graph_json["attrs"]["device_index"][1]
assert device_index == expected_index
mod = graph_runtime.create(graph, lib, contexts)
mod.set_input(**params)
mod.run()
res = mod.get_output(0).asnumpy()
tvm.testing.assert_allclose(res, ref_res, rtol=1e-5, atol=1e-5)

def test_check_run():
for dev, tgt in [("opencl", "opencl"), ("cuda", "cuda"),
("opencl", str(tvm.target.intel_graphics()))]:
("opencl", str(tvm.target.intel_graphics()))]:
if not tvm.module.enabled(dev):
print("Skip test because %s is not enabled." % dev)
continue
test_fuse_log_add(dev, tgt)
test_fuse_all(dev, tgt)
test_fallback_exp(dev, tgt)
test_fallback_all_operators(dev, tgt)

run_fusible_network(dev, tgt)
run_unpropagatable_graph(dev, tgt)


if __name__ == "__main__":
test_redundant_annotation()
test_annotate_all()
test_annotate_none()
test_conv_network()
test_fusible_network()
test_check_run()