Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

enhance pragma to support single point copy #863

Merged
merged 28 commits into from
Feb 4, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
f95ad29
modified schedule_dataflow_rewrite.cc to fix losing tensor problem
libing4752 Jan 2, 2018
dd15d65
modified schedule_dataflow_rewrite.cc for lint scan
libing4752 Jan 2, 2018
5276444
Merge branch 'master' into master
libing4752 Jan 2, 2018
c2747cb
modified schedule_dataflow_rewrite.cc for lint scan
libing4752 Jan 2, 2018
fd7599a
Merge branch 'master' of https://github.com/libing4752/tvm
libing4752 Jan 2, 2018
70bcd58
using tensor's value_index to index output of stage op
libing4752 Jan 3, 2018
4af36f0
Merge remote-tracking branch 'upstream/master'
libing4752 Jan 20, 2018
6d94540
repare address offset for different kinds of dtype
libing4752 Jan 22, 2018
370b588
bc
libing4752 Jan 22, 2018
4357ca8
aaa
libing4752 Jan 22, 2018
f0c79d6
aaaaa
libing4752 Jan 22, 2018
9135f4f
repare address for different dtypes
libing4752 Jan 22, 2018
a26cee5
remove nonsense files
libing4752 Jan 22, 2018
4270189
add whitespace of line 581
libing4752 Jan 22, 2018
e074dce
Merge branch 'master' into master
libing4752 Jan 22, 2018
d65953a
use base alloc elem_type
libing4752 Jan 23, 2018
111f16b
enhance the testcast of basic buffer is 64bits,32bits,16bits,8bits
libing4752 Jan 23, 2018
d9a2762
Merge branch 'master' of https://github.com/libing4752/tvm
libing4752 Jan 23, 2018
1f3b27e
Merge branch 'master' into master
libing4752 Jan 23, 2018
8436366
use extends[0]->type() as dtype of offset
libing4752 Jan 23, 2018
942fdc5
clear program writes
libing4752 Jan 23, 2018
05ac4fd
Merge branch 'master' of https://github.com/libing4752/tvm
libing4752 Jan 24, 2018
4b0adf9
Merge remote-tracking branch 'upstream/master'
libing4752 Feb 2, 2018
91c4769
enhance inject_copy_intin to support of pragma stmt with no loops
libing4752 Feb 2, 2018
5f1ae95
fix cpplint errors
libing4752 Feb 2, 2018
0c14bd4
fix cpplint error of !
libing4752 Feb 2, 2018
3d59239
enhance detectLinearEquation to support with no loop vars
libing4752 Feb 4, 2018
4a52de0
fix cpplint errors
libing4752 Feb 4, 2018
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 17 additions & 14 deletions src/arithmetic/detect_linear_equation.cc
Original file line number Diff line number Diff line change
Expand Up @@ -123,25 +123,28 @@ class LinearEqDetector
};

Array<Expr> DetectLinearEquation(const Expr& e, const Array<Var>& vars) {
CHECK_GE(vars.size(), 1U);
Expr base = e;
Array<Expr> coeff;

for (Var v : vars) {
LinearEqEntry ret;
if (!LinearEqDetector(v).Detect(base, &ret)) {
return Array<Expr>();
if (0 == vars.size()) {
coeff.push_back(make_const(Int(32), 1));
} else {
for (Var v : vars) {
LinearEqEntry ret;
if (!LinearEqDetector(v).Detect(base, &ret)) {
return Array<Expr>();
}
coeff.push_back(ret.coeff);
base = std::move(ret.base);
}
coeff.push_back(ret.coeff);
base = std::move(ret.base);
}

std::unordered_set<const Variable*> vset;
for (size_t i = vars.size(); i != 1; --i) {
vset.insert(vars[i - 1].get());
// The previous coeff contains the variable
if (ExprUseVar(coeff[i - 2], vset)) {
return Array<Expr>();
std::unordered_set<const Variable*> vset;
for (size_t i = vars.size(); i != 1; --i) {
vset.insert(vars[i - 1].get());
// The previous coeff contains the variable
if (ExprUseVar(coeff[i - 2], vset)) {
return Array<Expr>();
}
}
}
coeff.push_back(base);
Expand Down
26 changes: 18 additions & 8 deletions src/pass/inject_copy_intrin.cc
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ class CopyIntrinInjector : public IRMutator {
private:
bool MatchCopyPattern(Stmt stmt, Stmt *out) {
Stmt body = stmt;
bool is_single_point_copy = false;

// strip the loops
std::vector<const For*> loops;
Expand All @@ -53,7 +54,10 @@ class CopyIntrinInjector : public IRMutator {
const Select* select = store->value.as<Select>();
const Cast* cast = store->value.as<Cast>();
const Load* load = store->value.as<Load>();

if (0 == loops.size()) {
is_single_point_copy = true;
CHECK(select == nullptr);
}
// for now only support true condition matching
if (select != nullptr) {
load = select->true_value.as<Load>();
Expand All @@ -74,13 +78,19 @@ class CopyIntrinInjector : public IRMutator {
arith::DetectLinearEquation(load->index, loop_vars);
if (load_strides.size() == 0 || store_strides.size() == 0) return false;
Array<Expr> dst_shape;
for (const For* op : loops) {
dst_shape.push_back(op->extent);
auto loop_var_size = loop_vars.size();
if (is_single_point_copy) {
loop_var_size = 1;
dst_shape.push_back(make_const(Int(32), 1));
} else {
for (const For* op : loops) {
dst_shape.push_back(op->extent);
}
}
Array<Expr> src_shape = dst_shape;
Array<Expr> pad_before, pad_after;
Expr pad_value;
Expr src_elem_offset = load_strides[loop_vars.size()];
Expr src_elem_offset = load_strides[loop_var_size];
if (select != nullptr) {
Array<Expr> clip_bound =
arith::DetectClipBound(select->condition, loop_vars);
Expand Down Expand Up @@ -114,15 +124,15 @@ class CopyIntrinInjector : public IRMutator {
src_elem_offset = Simplify(src_elem_offset);
}
CHECK_EQ(load_strides.size(), store_strides.size());
CHECK_EQ(load_strides.size(), loop_vars.size() + 1);
Array<Expr> src_strides(load_strides.begin(), load_strides.begin() + loop_vars.size());
Array<Expr> dst_strides(store_strides.begin(), store_strides.begin() + loop_vars.size());
CHECK_EQ(load_strides.size(), loop_var_size + 1);
Array<Expr> src_strides(load_strides.begin(), load_strides.begin() + loop_var_size);
Array<Expr> dst_strides(store_strides.begin(), store_strides.begin() + loop_var_size);
Buffer dst = BufferNode::make(
Var(store->buffer_var.node_),
store->value.type(),
dst_shape,
dst_strides,
store_strides[loop_vars.size()],
store_strides[loop_var_size],
store->buffer_var->name_hint,
GetStorageScope(store->buffer_var.get()),
0, 0);
Expand Down
20 changes: 20 additions & 0 deletions tests/python/unittest/test_pass_inject_copy_intrin.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,25 @@ def cb(src, dst, pad_before, pad_after, pad_value):
return tvm.make.Evaluate(0)
stmt = tvm.ir_pass.InjectCopyIntrin(stmt, "memcpy", cb)

def test_single_point_test():
A = tvm.placeholder((1,), name='A')
B = tvm.compute((1,), lambda i:
A[i], name='B')
s = tvm.create_schedule(B.op)
s[B].pragma(B.op.axis[0], "memcpy")
bounds = tvm.schedule.InferBound(s)
stmt = tvm.schedule.ScheduleOps(s, bounds)
Ab = tvm.decl_buffer(A.shape, A.dtype, name='A')
Bb = tvm.decl_buffer(B.shape, B.dtype, name='B')
stmt = tvm.ir_pass.StorageFlatten(stmt, {A: Ab, B: Bb}, 64)
def cb(src, dst, pad_before, pad_after, pad_value):
assert tvm.ir_pass.Simplify(src.elem_offset).value == 0
assert tvm.ir_pass.Simplify(dst.elem_offset).value == 0
assert tvm.ir_pass.Simplify(src.strides[0]).value == 1
assert tvm.ir_pass.Simplify(dst.strides[0]).value == 1
return tvm.make.Evaluate(0)
stmt = tvm.ir_pass.InjectCopyIntrin(stmt, "memcpy", cb)

def assert_expr_equal(a, b):
assert tvm.ir_pass.Simplify(a - b).value == 0

Expand Down Expand Up @@ -80,3 +99,4 @@ def cb(src, dst, pad_before, pad_after, pad_value):
test_copy2d()
test_copy_pad()
test_copy_pad_split()
test_single_point_test()