From 06ec749e460370a43b5d42bad809797b8de14620 Mon Sep 17 00:00:00 2001 From: yoavg Date: Fri, 26 May 2017 02:08:20 +0300 Subject: [PATCH] introduce a version of matmul which does not depend on first arg Former-commit-id: 46c9a58f8b8ee3650bccdb31424a99f61994bdd6 --- dynet/dynet.h | 4 +++- dynet/expr.cc | 2 +- dynet/nodes-common.cc | 15 ++++++++++++--- 3 files changed, 16 insertions(+), 5 deletions(-) diff --git a/dynet/dynet.h b/dynet/dynet.h index eb2ab8290..9e434af18 100644 --- a/dynet/dynet.h +++ b/dynet/dynet.h @@ -636,8 +636,10 @@ struct Node { Device* device; /**< pointer to the node, or null to inherit device from first input, or default when there is no input */ + unsigned matmul_count = 0; // how many matmul nodes am I the first arg of? (default member initializer: every Node constructor starts at 0, not only Node()) + protected: - Node() : args(), device(default_device) {} + Node() : args(), device(default_device), matmul_count(0) {} explicit Node(const std::initializer_list& a) : args(a), device(default_device) {} template explicit Node(const T&c) : args(c.begin(), c.end()), device(default_device) {} diff --git a/dynet/expr.cc b/dynet/expr.cc index a73b1ad39..580fa3aa0 100644 --- a/dynet/expr.cc +++ b/dynet/expr.cc @@ -52,7 +52,7 @@ Expression operator+(const Expression& x, real y) { return y + x; } Expression operator-(const Expression& x, const Expression& y) { return x + (-y); } Expression operator-(real x, const Expression& y) { return Expression(y.pg, y.pg->add_function({y.i}, x)); } Expression operator-(const Expression& x, real y) { return -(y - x); } -Expression operator*(const Expression& x, const Expression& y) { return Expression(x.pg, x.pg->add_function({x.i, y.i})); } +Expression operator*(const Expression& x, const Expression& y) { x.pg->nodes[x.i]->matmul_count++; return Expression(x.pg, x.pg->add_function({x.i, y.i})); } Expression operator*(const Expression& x, float y) { return Expression(x.pg, x.pg->add_function({x.i}, y)); } Expression cmult(const Expression& x, const Expression& y) { if (x.dim().batch_size() == 1) diff --git 
a/dynet/nodes-common.cc b/dynet/nodes-common.cc index 174da835e..94b56c054 100644 --- a/dynet/nodes-common.cc +++ b/dynet/nodes-common.cc @@ -879,7 +879,13 @@ int MatrixMultiply::autobatch_sig(const ComputationGraph & cg, SigMap &sm) const // TODO do we want to treat different dimensions of first/second arg differently? if(dim.bd == 1) { Sig s(nt::matmul); - s.add_node(args[0]); + // if arg0 is likely to be shared (matmul_count > 2), key the sig on the arg0 node itself plus arg1's dims. + // otherwise, key the sig on the dims of both args. + if (cg.nodes[args[0]]->matmul_count > 2) { //TODO why 2? can we set a better number? (autobatch_concat tests the same threshold; keep the two in sync) + s.add_node(args[0]); s.add_dim(cg.nodes[args[1]]->dim); + } else { + s.add_dim(cg.nodes[args[0]]->dim); s.add_dim(cg.nodes[args[1]]->dim); + } return sm.get_idx(s); } else { return 0; // TODO handle the batched case as well? should it differ at all? @@ -887,8 +893,11 @@ int MatrixMultiply::autobatch_sig(const ComputationGraph & cg, SigMap &sm) const } std::vector MatrixMultiply::autobatch_concat(const ComputationGraph & cg) const { - vector ret(args.size(), 0); - if (dim.bd == 1) { ret[1] = 1; } + vector ret(2, 0); + if (dim.bd == 1) { + ret[1] = 1; + if (cg.nodes[args[0]]->matmul_count <= 2) { ret[0] = 1; } // complement of the "> 2" test in autobatch_sig: arg0 is flagged too when it is not treated as shared + } return ret; }