Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] make autobatching matrix multiplies more flexible #566

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion dynet/dynet.h
Original file line number Diff line number Diff line change
Expand Up @@ -636,8 +636,10 @@ struct Node {

Device* device; /**< pointer to the node, or null to inherit device from first input, or default when there is no input */

unsigned matmul_count; // how many matmul nodes am I an arg of?

protected:
Node() : args(), device(default_device) {}
Node() : args(), device(default_device), matmul_count(0) {}
explicit Node(const std::initializer_list<VariableIndex>& a) : args(a), device(default_device) {}
template <typename T>
explicit Node(const T&c) : args(c.begin(), c.end()), device(default_device) {}
Expand Down
2 changes: 1 addition & 1 deletion dynet/expr.cc
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ Expression operator+(const Expression& x, real y) { return y + x; }
Expression operator-(const Expression& x, const Expression& y) { return x + (-y); }
Expression operator-(real x, const Expression& y) { return Expression(y.pg, y.pg->add_function<ConstantMinusX>({y.i}, x)); }
Expression operator-(const Expression& x, real y) { return -(y - x); }
Expression operator*(const Expression& x, const Expression& y) { return Expression(x.pg, x.pg->add_function<MatrixMultiply>({x.i, y.i})); }
Expression operator*(const Expression& x, const Expression& y) { x.pg->nodes[x.i]->matmul_count++; return Expression(x.pg, x.pg->add_function<MatrixMultiply>({x.i, y.i})); }
Expression operator*(const Expression& x, float y) { return Expression(x.pg, x.pg->add_function<ConstScalarMultiply>({x.i}, y)); }
Expression cmult(const Expression& x, const Expression& y) {
if (x.dim().batch_size() == 1)
Expand Down
15 changes: 12 additions & 3 deletions dynet/nodes-common.cc
Original file line number Diff line number Diff line change
Expand Up @@ -879,16 +879,25 @@ int MatrixMultiply::autobatch_sig(const ComputationGraph & cg, SigMap &sm) const
// TODO do we want to treat different dimensions of first/second arg differently?
if(dim.bd == 1) {
Sig s(nt::matmul);
s.add_node(args[0]);
// if arg0 is likely to be shared, include it in the sig.
// otherwise, include both args dims in the sig.
if (cg.nodes[args[0]]->matmul_count > 2) { //TODO why 2? can we set a better number?
s.add_node(args[0]); s.add_dim(cg.nodes[args[1]]->dim);
} else {
s.add_dim(cg.nodes[args[0]]->dim); s.add_dim(cg.nodes[args[1]]->dim);
}
return sm.get_idx(s);
} else {
return 0; // TODO handle the batched case as well? should it differ at all?
}
}

std::vector<int> MatrixMultiply::autobatch_concat(const ComputationGraph & cg) const {
vector<int> ret(args.size(), 0);
if (dim.bd == 1) { ret[1] = 1; }
vector<int> ret(2, 0);
if (dim.bd == 1) {
ret[1] = 1;
if (cg.nodes[args[0]]->matmul_count <= 2) { ret[0] = 1; }
}
return ret;
}

Expand Down