From 2f8a794fa022849088850471a9adc7ed91f8aade Mon Sep 17 00:00:00 2001 From: Junru Shao Date: Thu, 19 Jul 2018 08:54:57 -0700 Subject: [PATCH 1/7] Initial commit for `Ifelse` --- benchmark/python/control_flow/foreach_rnn.py | 195 ------- .../python/control_flow/while_loop_rnn.py | 213 ------- docs/api/python/ndarray/contrib.md | 1 + docs/api/python/symbol/contrib.md | 1 + python/mxnet/ndarray/contrib.py | 97 +++- python/mxnet/symbol/contrib.py | 153 ++++- src/operator/control_flow.cc | 537 ++++++++++++++---- src/operator/subgraph_op_common.cc | 28 + src/operator/subgraph_op_common.h | 49 ++ .../unittest/test_contrib_control_flow.py | 156 +++++ 10 files changed, 918 insertions(+), 512 deletions(-) delete mode 100644 benchmark/python/control_flow/foreach_rnn.py delete mode 100644 benchmark/python/control_flow/while_loop_rnn.py diff --git a/benchmark/python/control_flow/foreach_rnn.py b/benchmark/python/control_flow/foreach_rnn.py deleted file mode 100644 index 4ce7a429ee9d..000000000000 --- a/benchmark/python/control_flow/foreach_rnn.py +++ /dev/null @@ -1,195 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. 
- -import subprocess -import mxnet as mx -from mxnet import gluon -import time -import copy - -def get_gpus(): - """ - return a list of GPUs - """ - try: - re = subprocess.check_output(["nvidia-smi", "-L"], universal_newlines=True) - except OSError: - return [] - return range(len([i for i in re.split('\n') if 'GPU' in i])) - -class TestRNNLayer(gluon.HybridBlock): - def __init__(self, cell, prefix=None, params=None): - super(TestRNNLayer, self).__init__(prefix=prefix, params=params) - self.cell = cell - - def hybrid_forward(self, F, inputs, states): - out, states = F.contrib.foreach(self.cell, inputs, states) - return out - -def benchmark_rnn(cell, rnn_data, states): - ctx = rnn_data.context - num_batches = 20 - - # Imperative - cell0 = copy.deepcopy(cell) - layer0 = TestRNNLayer(cell0) - layer0.initialize(ctx=ctx) - - # Hybridize - cell1 = copy.deepcopy(cell) - cell1.hybridize() - layer1 = TestRNNLayer(cell1) - layer1.initialize(ctx=ctx) - - # Hybridize - cell2 = copy.deepcopy(cell) - layer2 = TestRNNLayer(cell2) - layer2.initialize(ctx=ctx) - layer2.hybridize() - layer2(rnn_data, states) - - # Hybridize - cell3 = copy.deepcopy(cell) - cell3.hybridize(static_alloc=True) - layer3 = TestRNNLayer(cell3) - layer3.initialize(ctx=ctx) - - tic = time.time() - for i in range(num_batches): - res0 = layer0(rnn_data, states) - mx.nd.waitall() - print("Imperative inference takes " + str(time.time() - tic)) - - tic = time.time() - for i in range(num_batches): - res1 = layer1(rnn_data, states) - mx.nd.waitall() - print("Hybrid-cell inference takes " + str(time.time() - tic)) - - tic = time.time() - for i in range(num_batches): - res3 = layer3(rnn_data, states) - mx.nd.waitall() - print("Static-hybrid-cell inference takes " + str(time.time() - tic)) - - tic = time.time() - for i in range(num_batches): - res2 = layer2(rnn_data, states) - mx.nd.waitall() - print("Hybrid inference takes " + str(time.time() - tic)) - - layer2.export("foreach_rnn") - symnet = mx.symbol.load('foreach_rnn-symbol.json') - args1 = {} - params = layer2.collect_params() - for key in params.keys(): - args1[key] = params[key].data() - args1['data0'] = rnn_data - for i in range(len(states)): - args1['data' + str(i + 1)] = states[i] - exe = symnet.bind(ctx=ctx, args=args1) - tic = time.time() - for i in range(num_batches): - exe.forward(is_train=False) - mx.nd.waitall() - print("Symbol inference takes " + str(time.time() - tic)) - - tic = time.time() - for i in range(num_batches): - with mx.autograd.record(): - res0 = layer0(rnn_data, states) - res0.backward() - mx.nd.waitall() - print("Imperative training takes " + str(time.time() - tic)) - - tic = time.time() - for i in range(num_batches): - with mx.autograd.record(): - res1 = layer1(rnn_data, states) - res1.backward() - mx.nd.waitall() - print("Hybrid-cell training takes " + str(time.time() - tic)) - - tic = time.time() - for i in range(num_batches): - with mx.autograd.record(): - res3 = layer3(rnn_data, states) - res3.backward() - mx.nd.waitall() - print("Static-hybrid-cell training takes " + str(time.time() - tic)) - - tic = time.time() - for i in range(num_batches): - with mx.autograd.record(): - res2 = layer2(rnn_data, states) - res2.backward() - mx.nd.waitall() - print("Hybrid training takes " + str(time.time() - tic)) - - # gradients for the backward of the foreach symbol - args_grad1 = {} - for key in args1.keys(): - args_grad1[key] = mx.nd.empty(args1[key].shape, ctx=ctx) - exe = symnet.bind(ctx=ctx, args=args1, args_grad=args_grad1) - tic = time.time() - for i in 
range(num_batches): - exe.forward(is_train=True) - exe.backward(res2) - mx.nd.waitall() - print("Symbol training takes " + str(time.time() - tic)) - print("") - -if __name__ == '__main__': - ndim = 512 - seq_len = 100 - batch_sizes = [1, 32] - cells = [gluon.rnn.RNNCell(ndim, prefix='rnn_'), - gluon.rnn.GRUCell(ndim, prefix='rnn_'), - gluon.rnn.LSTMCell(ndim, prefix='rnn_')] - ctxs = [mx.cpu(0), mx.gpu(0)] - for cell in cells: - for ctx in ctxs: - for batch_size in batch_sizes: - if len(get_gpus()) == 0 and ctx == mx.gpu(0): - continue - if isinstance(cell, gluon.rnn.RNNCell): - rnn_data = mx.nd.normal(loc=0, scale=1, shape=(seq_len, batch_size, ndim), - ctx=mx.cpu(0)) - states = [] - states.append(mx.nd.normal(loc=0, scale=1, shape=(batch_size, ndim), - ctx=mx.cpu(0))) - elif isinstance(cell, gluon.rnn.GRUCell): - rnn_data = mx.nd.normal(loc=0, scale=1, shape=(seq_len, batch_size, ndim), - ctx=mx.cpu(0)) - states = [] - states.append(mx.nd.normal(loc=0, scale=1, shape=(batch_size, ndim), - ctx=mx.cpu(0))) - elif isinstance(cell, gluon.rnn.LSTMCell): - rnn_data = mx.nd.normal(loc=0, scale=1, shape=(seq_len, batch_size, ndim), - ctx=mx.cpu(0)) - states = [] - states.append(mx.nd.normal(loc=0, scale=1, shape=(batch_size, ndim), - ctx=mx.cpu(0))) - states.append(mx.nd.normal(loc=0, scale=1, shape=(batch_size, ndim), - ctx=mx.cpu(0))) - if ctx == mx.gpu(0): - dev = "GPU" - else: - dev = "CPU" - print("Benchmark {} in {} (batch size: {})".format(cell._alias(), dev, - batch_size)) - benchmark_rnn(cell, rnn_data, states) diff --git a/benchmark/python/control_flow/while_loop_rnn.py b/benchmark/python/control_flow/while_loop_rnn.py deleted file mode 100644 index 42aaee5840dd..000000000000 --- a/benchmark/python/control_flow/while_loop_rnn.py +++ /dev/null @@ -1,213 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. 
- -# Code borrowed from ./benchmark/python/control_flow/foreach_rnn.py - -import subprocess -import mxnet as mx -from mxnet import gluon -import time -import copy - -def get_gpus(): - """ - return a list of GPUs - """ - try: - re = subprocess.check_output(["nvidia-smi", "-L"], universal_newlines=True) - except OSError: - return [] - return range(len([i for i in re.split('\n') if 'GPU' in i])) - -class TestRNNLayer(gluon.HybridBlock): - def __init__(self, cell, length, prefix=None, params=None): - super(TestRNNLayer, self).__init__(prefix=prefix, params=params) - self.length = length - self.cell = cell - - def hybrid_forward(self, F, inputs, states): - def _func(*states): - i = states[0] - s = states[1: ] - data = inputs.take(i).squeeze(axis=0) - out, new_s = self.cell(data, s) - new_s = [i + 1] + new_s - return out, new_s - out, states = F.contrib.while_loop( - cond=lambda i, *_: i < self.length, - func=_func, - loop_vars=states, - max_iterations=self.length, - ) - return out + states - -def benchmark_rnn(cell, rnn_data, states, length): - ctx = rnn_data.context - num_batches = 20 - - # Imperative - cell0 = copy.deepcopy(cell) - layer0 = TestRNNLayer(cell0, length) - layer0.initialize(ctx=ctx) - - # Hybrid-cell - cell1 = copy.deepcopy(cell) - cell1.hybridize() - layer1 = TestRNNLayer(cell1, length) - layer1.initialize(ctx=ctx) - - # Hybrid - cell2 = copy.deepcopy(cell) - layer2 = TestRNNLayer(cell2, length) - layer2.initialize(ctx=ctx) - layer2.hybridize() - layer2(rnn_data, states) - - # Static-hybrid-cell - cell3 = copy.deepcopy(cell) - cell3.hybridize(static_alloc=True) - layer3 = TestRNNLayer(cell3, length) - layer3.initialize(ctx=ctx) - - tic = time.time() - for i in range(num_batches): - res0 = layer0(rnn_data, states) - mx.nd.waitall() - print("Imperative inference takes " + str(time.time() - tic)) - - tic = time.time() - for i in range(num_batches): - res1 = layer1(rnn_data, states) - mx.nd.waitall() - print("Hybrid-cell inference takes " + str(time.time() - tic)) - - tic = time.time() - for i in range(num_batches): - res3 = layer3(rnn_data, states) - mx.nd.waitall() - print("Static-hybrid-cell inference takes " + str(time.time() - tic)) - - tic = time.time() - for i in range(num_batches): - res2 = layer2(rnn_data, states) - mx.nd.waitall() - print("Hybrid inference takes " + str(time.time() - tic)) - - layer2.export("while_loop_rnn") - symnet = mx.symbol.load('while_loop_rnn-symbol.json') - args1 = {} - params = layer2.collect_params() - for key in params.keys(): - args1[key] = params[key].data() - args1['data0'] = rnn_data - for i in range(len(states)): - args1['data' + str(i + 1)] = states[i] - exe = symnet.bind(ctx=ctx, args=args1) - tic = time.time() - for i in range(num_batches): - exe.forward(is_train=False) - mx.nd.waitall() - print("Symbol inference takes " + str(time.time() - tic)) - - tic = time.time() - for i in range(num_batches): - with mx.autograd.record(): - res0 = layer0(rnn_data, states) - res0[0].backward() - mx.nd.waitall() - print("Imperative training takes " + str(time.time() - tic)) - - tic = time.time() - for i in range(num_batches): - with mx.autograd.record(): - res1 = layer1(rnn_data, states) - res1[0].backward() - mx.nd.waitall() - print("Hybrid-cell training takes " + str(time.time() - tic)) - - tic = time.time() - for i in range(num_batches): - with mx.autograd.record(): - res3 = layer3(rnn_data, states) - res3[0].backward() - mx.nd.waitall() - print("Static-hybrid-cell training takes " + str(time.time() - tic)) - - tic = time.time() - for i in 
range(num_batches): - with mx.autograd.record(): - res2 = layer2(rnn_data, states) - res2[0].backward() - mx.nd.waitall() - print("Hybrid training takes " + str(time.time() - tic)) - - # gradients for the backward of the while_loop symbol - args_grad1 = {} - for key in args1.keys(): - if key != "data1": - args_grad1[key] = mx.nd.empty(args1[key].shape, ctx=ctx) - exe = symnet.bind(ctx=ctx, args=args1, args_grad=args_grad1) - tic = time.time() - for i in range(num_batches): - exe.forward(is_train=True) - exe.backward(res2) - mx.nd.waitall() - print("Symbol training takes " + str(time.time() - tic)) - print("") - -if __name__ == '__main__': - def _zeros(shape): - return mx.nd.zeros(shape=shape, ctx=mx.cpu(0)) - def _array(shape): - return mx.nd.normal(loc=0.0, scale=1.0, shape=shape, ctx=mx.cpu(0)) - ndim = 512 - seq_len = 100 - batch_sizes = [1, 32] - cells = [gluon.rnn.RNNCell(ndim, prefix='rnn_'), - gluon.rnn.GRUCell(ndim, prefix='rnn_'), - gluon.rnn.LSTMCell(ndim, prefix='rnn_')] - ctxs = [mx.cpu(0), mx.gpu(0)] - for cell in cells: - for ctx in ctxs: - for batch_size in batch_sizes: - if len(get_gpus()) == 0 and ctx == mx.gpu(0): - continue - if isinstance(cell, gluon.rnn.RNNCell): - rnn_data = _array((seq_len, batch_size, ndim)) - states = [ - _zeros((1, )), - _array((batch_size, ndim)), - ] - if isinstance(cell, gluon.rnn.GRUCell): - rnn_data = _array((seq_len, batch_size, ndim)) - states = [ - _zeros((1, )), - _array((batch_size, ndim)), - ] - elif isinstance(cell, gluon.rnn.LSTMCell): - rnn_data = _array((seq_len, batch_size, ndim)) - states = [ - _zeros((1, )), - _array((batch_size, ndim)), - _array((batch_size, ndim)), - ] - if ctx == mx.gpu(0): - dev = "GPU" - else: - dev = "CPU" - print("Benchmark {} in {} (batch size: {})".format(cell._alias(), dev, batch_size)) - benchmark_rnn(cell, rnn_data, states, seq_len) diff --git a/docs/api/python/ndarray/contrib.md b/docs/api/python/ndarray/contrib.md index 0cf8724de301..80d8ef23b459 100644 --- a/docs/api/python/ndarray/contrib.md +++ b/docs/api/python/ndarray/contrib.md @@ -54,6 +54,7 @@ In the rest of this document, we list routines provided by the `ndarray.contrib` quantize foreach while_loop + ifelse ``` ## API Reference diff --git a/docs/api/python/symbol/contrib.md b/docs/api/python/symbol/contrib.md index ba43f2d6633c..96ce7987d800 100644 --- a/docs/api/python/symbol/contrib.md +++ b/docs/api/python/symbol/contrib.md @@ -54,6 +54,7 @@ In the rest of this document, we list routines provided by the `symbol.contrib` quantize foreach while_loop + ifelse ``` ## API Reference diff --git a/python/mxnet/ndarray/contrib.py b/python/mxnet/ndarray/contrib.py index b67cf5a55daf..b7b63c4e10e6 100644 --- a/python/mxnet/ndarray/contrib.py +++ b/python/mxnet/ndarray/contrib.py @@ -28,7 +28,7 @@ except ImportError: pass -__all__ = ["rand_zipfian", "foreach", "while_loop"] +__all__ = ["rand_zipfian", "foreach", "while_loop", "ifelse"] # pylint: disable=line-too-long def rand_zipfian(true_classes, num_sampled, range_max, ctx=None): @@ -192,7 +192,6 @@ def check_input(inputs, in_type, msg): outputs = outputs[0] return (outputs, states) - def while_loop(cond, func, loop_vars, max_iterations=None): """Run a while loop with user-defined computation and loop condition. 
@@ -363,3 +362,97 @@ def _func_wrapper(loop_vars): [" Step %d, shape is %s" % (i, str(x.shape)) for i, x in enumerate(items)] )) return stacked_outputs, list(loop_vars) + +def ifelse(cond, then_func, else_func, inputs): + """Run a if-then-else using user-defined condition and computation + + This operator simulates a if-like branch which chooses to do one of + the two customized computations according to the specified condition. + + `inputs` is a list of NDArrays on which the condition and computations reply on. + + `cond` is a user-defined function, used as the if condition. + It consumes `inputs`, and produces a scalar MXNet NDArray, + indicating which branch of computation should be used. + The `cond` is variadic, and its signature should be + `cond(*loop_vars) => NDArray`. + + `then_func` is a user-defined function, used as computation of the then branch. + It consumes `inputs`, and produces `outputs`. + The `then_func` is variadic, and its signature should be + `then_func(*loop_vars) => List[NDArray]`. + + `else_func` is a user-defined function, used as computation of the else branch. + It also consumes `inputs`, and produces `outputs`. + The `else_func` is variadic, and its signature should be + `else_func(*loop_vars) => List[NDArray]`. + + The `outputs` produces by `then_func` and `else_func` should have the same number + of elements, all of which should be in the same shape, of the same dtype and stype. + + This function returns a list of NDArrays, representing the computation result. + + Parameters + ---------- + cond: a Python function. + The branch condition. + then_func: a Python function. + The computation to be executed if `cond` is true. + else_func: a Python function. + The computation to be executed if `cond` is false. + inputs: list of NDArrays. + The variables fed to `cond`, `then_func` and `else_func`. + + Returns + ------- + outputs: a list of NDArrays, representing the result of computation. + + Examples + -------- + >>> cond = lambda a, b: a * b < 5 + >>> then_func = lambda a, b: (a + 5) * (b + 5) + >>> else_func = lambda a, b: (a - 5) * (b - 5) + >>> inputs = (mx.nd.array([1]), mx.nd.array([2])) + >>> outputs = mx.nd.contrib.ifelse(cond, then_func, else_func, inputs) + >>> outputs[0] + [42.] 
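+
+    For illustration, with input values chosen so that `cond` evaluates to
+    false, the same call takes the else branch instead:
+
+    >>> inputs = (mx.nd.array([3]), mx.nd.array([4]))
+    >>> outputs = mx.nd.contrib.ifelse(cond, then_func, else_func, inputs)
+    >>> outputs[0]
+    [2.]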
+ + """ + def _to_python_scalar(inputs, type_, name): + """Converts "inputs", possibly typed mxnet NDArray, a numpy ndarray, other python types, + to the given type + """ + if hasattr(inputs, "asscalar"): + inputs = inputs.asscalar() + try: + inputs = type_(inputs) + except: + raise ValueError("Cannot convert %s to python %s" % (name, type_.__name__)) + return inputs + + def _to_ndarray_tuple(inputs, name): + """Converts "inputs", possibly a single mxnet NDArray, a list of mxnet NDArray, + a tuple of mxnet NDArray, into a tuple of NDArray + """ + if isinstance(inputs, list): + inputs = tuple(inputs) + if isinstance(inputs, ndarray.NDArray): + inputs = (inputs, ) + if not isinstance(inputs, tuple): + raise ValueError("%s must be an NDArray, or a tuple or list of NDArrays" % (name, )) + for item in inputs: + if not isinstance(item, ndarray.NDArray): + raise ValueError("%s must be an NDArray, or a tuple or list of NDArrays" % (name, )) + return inputs + + inputs = _to_ndarray_tuple(inputs, "inputs") + if len(inputs) == 0: + raise ValueError("inputs should contain at least one element") + branch = _to_python_scalar(cond(*inputs), bool, "Return value of cond") + if branch: + outputs = then_func(*inputs) + outputs = _to_ndarray_tuple(outputs, "outputs of then_func") + else: + outputs = else_func(*inputs) + outputs = _to_ndarray_tuple(outputs, "outputs of else_func") + return list(outputs) diff --git a/python/mxnet/symbol/contrib.py b/python/mxnet/symbol/contrib.py index 2c11921383c8..33932ba5ad94 100644 --- a/python/mxnet/symbol/contrib.py +++ b/python/mxnet/symbol/contrib.py @@ -34,7 +34,7 @@ from ..base import SymbolHandle, _as_list from ..attribute import AttrScope -__all__ = ["rand_zipfian", "foreach", "while_loop"] +__all__ = ["rand_zipfian", "foreach", "while_loop", "ifelse"] def rand_zipfian(true_classes, num_sampled, range_max): """Draw random samples from an approximately log-uniform or Zipfian distribution. @@ -556,3 +556,154 @@ def _union_inputs(*graphs): outputs = [result[i] for i in range(num_out_data)] final_loop_vars = [result[i] for i in range(num_out_data, num_outputs)] return outputs, final_loop_vars + +def ifelse(cond, then_func, else_func, inputs, name="ifelse"): + """Run a if-then-else using user-defined condition and computation + + This operator simulates a if-like branch which chooses to do one of + the two customized computations according to the specified condition. + + `inputs` is a list of Symbols on which the condition and computations reply on. + + `cond` is a user-defined function, used as the if condition. + It consumes `inputs`, and produces a scalar MXNet symbol, + indicating which branch of computation should be used. + The `cond` is variadic, and its signature should be + `cond(*loop_vars) => Symbol`. + + `then_func` is a user-defined function, used as computation of the then branch. + It consumes `inputs`, and produces `outputs`. + The `then_func` is variadic, and its signature should be + `then_func(*loop_vars) => List[Symbol]`. + + `else_func` is a user-defined function, used as computation of the else branch. + It also consumes `inputs`, and produces `outputs`. + The `else_func` is variadic, and its signature should be + `else_func(*loop_vars) => List[Symbol]`. + + The `outputs` produces by `then_func` and `else_func` should have the same number + of elements, all of which should be in the same shape, of the same dtype and stype. + + This function returns a list of symbols, representing the computation result. 
+ + Parameters + ---------- + cond: a Python function. + The branch condition. + then_func: a Python function. + The computation to be executed if `cond` is true. + else_func: a Python function. + The computation to be executed if `cond` is false. + inputs: list of Symbols. + The variables fed to `cond`, `then_func` and `else_func`. + + Returns + ------- + outputs: a list of Symbols, representing the result of computation. + + Examples + -------- + >>> cond = lambda a, b: a * b < 5 + >>> then_func = lambda a, b: (a + 5) * (b + 5) + >>> else_func = lambda a, b: (a - 5) * (b - 5) + >>> inputs = (mx.sym.var('a'), mx.sym.var('b')) + >>> outputs = mx.sym.contrib.ifelse(cond, then_func, else_func, inputs) + """ + def _to_symbol_tuple(inputs, name): + """Converts "inputs", possibly a single mxnet Symbol, a list of mxnet Symbol, + a tuple of mxnet Symbol, into a tuple of Symbol + """ + if isinstance(inputs, list): + inputs = tuple(inputs) + if isinstance(inputs, Symbol): + inputs = (inputs, ) + if not isinstance(inputs, tuple): + raise ValueError("%s must be a Symbol, or a tuple or list of Symbol" % (name, )) + for item in inputs: + if not isinstance(item, Symbol): + raise ValueError("%s must be a Symbol, or a tuple or list of Symbol" % (name, )) + return inputs + + def _create_subgraph(graph_vars, graph_func, subgraph_name): + with AttrScope(__subgraph_name__=subgraph_name): + # create new variables with the same name, + # them feed them to the given func + new_graph_vars = [symbol.var(sym.name) for sym in graph_vars] + outputs = graph_func(*new_graph_vars) + outputs = _to_symbol_tuple(outputs, "outputs") + num_outputs = len(outputs) + # nnvm cut-graph does not allow inputs and outputs overlap + # so we calculate the name of inputs, and copy outputs once it overlaps with inputs + all_input_names = symbol.Group(outputs).list_inputs() + make_identity = lambda x: symbol.op.identity(x) if x.name in all_input_names else x + # group all outputs of graph_func + graph = symbol.Group(list(map(make_identity, outputs))) + return graph, num_outputs + + def _union_inputs(*graphs): + # Given a list of graphs, each whose inputs are either from input_vars or other variables. + # 1) calculate a list `inputs`, the union of their inputs. 
+ # 2) for each graph, determine in which indices their inputs reside in `inputs` + # 3) for each variable in the input of `graph`, find which index it is + inputs = [] # List[Symbol], result of 1) + locs = [] # List[Tuple(List[Int], List[Int])], a list of tuples, + # where tuples are results of 2) and 3) + input_id_to_loc = {} # Dict[int, int], given id(sym), input_id_to_loc maps it + # to a `loc`, where inputs[loc] = sym + for graph in graphs: + # input_syms: all inputs to the `graph` + name_to_input_syms = {sym.name: sym for sym in _get_graph_inputs(graph)} + # some input_vars are inputs to `graph`, some are not + name_to_input_vars = {sym.name: sym for sym in inputs} + # other inputs to `graph` created by cut_graph + name_to_cut_g_syms = {sym.list_outputs()[0]: sym for sym in _cut_subgraph(graph)} + # collect arguments for each subgraph + input_locs = [] # results from the second step + for name in graph.list_inputs(): + assert name in name_to_input_syms # it should obviously hold + # name -> sym + if name in name_to_input_vars: + sym = name_to_input_vars[name] + elif name in name_to_cut_g_syms: + sym = name_to_cut_g_syms[name] + else: + sym = copy.deepcopy(name_to_input_syms[name]) + # do 2), and 1) is implicitly done + if id(sym) in input_id_to_loc: + loc = input_id_to_loc[id(sym)] + else: + loc = len(input_id_to_loc) + inputs.append(sym) + input_id_to_loc[id(sym)] = loc + input_locs.append(loc) + locs.append(input_locs) + return inputs, locs + inputs = _to_symbol_tuple(inputs, "inputs") + if len(inputs) == 0: + raise ValueError("loop_vars should contain at least one element") + # create graph for `cond' + cond_g, num_outputs = _create_subgraph(inputs, cond, name + "_cond") + if num_outputs != 1: + raise ValueError("cond should always produce a single output") + # create graph for `then` + then_g, then_num_outputs = _create_subgraph(inputs, then_func, name + "_then") + # create graph for `else` + else_g, else_num_outputs = _create_subgraph(inputs, else_func, name + "_else") + if then_num_outputs != else_num_outputs: + raise ValueError("Number of outputs differs between then-branch and else-branch") + # find symbols used in either cond_g or func_g + input_syms, (cond_input_locs, then_input_locs, else_input_locs) = \ + _union_inputs(cond_g, then_g, else_g) + result = symbol._internal._ifelse( + # [cond, then_g, else_g, *input_syms] + cond_g, + then_g, + else_g, + *input_syms, + cond_input_locs=cond_input_locs, + then_input_locs=then_input_locs, + else_input_locs=else_input_locs, + num_outputs=then_num_outputs + ) + result = _to_symbol_tuple(result, "result") + return list(result) diff --git a/src/operator/control_flow.cc b/src/operator/control_flow.cc index b00ed9b19d8c..261bd5070f7d 100644 --- a/src/operator/control_flow.cc +++ b/src/operator/control_flow.cc @@ -508,6 +508,18 @@ struct WhileLoopParam : public dmlc::Parameter { DMLC_DECLARE_FIELD(func_var_locs) .describe("The locations of loop_vars among func's inputs."); } + template + bool sync_in_out(std::vector *in, + std::vector *out, + std::function is_empty) const { + for (int i = this->num_out_data; i < this->num_outputs; ++i) { + // each out->at(i) is a params, loop_var + T &x = in->at(this->func_input_locs[this->func_var_locs[i - this->num_out_data]]); + T &y = out->at(i); + fill_value(&x, &y, is_empty(x), is_empty(y)); + } + return true; + } }; // struct WhileLoopParam DMLC_REGISTER_PARAMETER(WhileLoopParam); @@ -540,84 +552,8 @@ class WhileLoopState: public LoopState { } } } - template - static void extract_by_loc(const 
std::vector &array, - const nnvm::Tuple input_locs, - std::vector *out) { - out->clear(); - out->reserve(input_locs.ndim()); - for (dim_t i : input_locs) { - out->push_back(array[i]); - } - } - static bool is_shape_udf(const TShape &x) { - return x.ndim() == 0 || x.Size() == 0; - } - static bool is_stype_udf(const int &x) { - return x == exec::kBadStorageID; - } - static bool is_type_udf(const int &x) { - return x == -1; - } - template - static bool fill_value(T *x, T *y, bool x_empty, bool y_empty) { - if (*x == *y || (x_empty && y_empty)) { - return true; - } - if (!x_empty && !y_empty) { - return false; - } - if (x_empty) { - *x = *y; - } - if (y_empty) { - *y = *x; - } - return true; - } - template - static bool sync_in_in(const nnvm::Tuple &input_locs, - std::vector *in, - std::vector *subg_in, - std::function is_empty) { - for (size_t i = 0; i < input_locs.ndim(); ++i) { - T &x = in->at(input_locs[i]); - T &y = subg_in->at(i); - fill_value(&x, &y, is_empty(x), is_empty(y)); - } - return true; - } - template - static bool sync_in_out(const WhileLoopParam& params, - std::vector *in, - std::vector *out, - std::function is_empty) { - for (int i = params.num_out_data; i < params.num_outputs; ++i) { - // each out->at(i) is a params, loop_var - T &x = in->at(params.func_input_locs[params.func_var_locs[i - params.num_out_data]]); - T &y = out->at(i); - fill_value(&x, &y, is_empty(x), is_empty(y)); - } - return true; - } }; -template -T _asscalar(const NDArray &a) { - CHECK_EQ(a.shape().Size(), 1U); - T data; - a.SyncCopyToCPU(&data, 1U); - return data; -} - -bool as_bool_scalar(const NDArray &a) { - MSHADOW_TYPE_SWITCH(a.dtype(), DType, { - return static_cast(_asscalar(a)); - }); - LOG(FATAL) << "Unknown dtype"; - return false; -} - static void WhileLoopComputeExCPU(const OpStatePtr& state_ptr, const OpContext& ctx, const std::vector& inputs, @@ -648,13 +584,13 @@ static void WhileLoopComputeExCPU(const OpStatePtr& state_ptr, CHECK_EQ(params.max_iterations, outputs[i].shape()[0]); // construct inputs and outputs for cond std::vector cond_inputs, cond_outputs = {NDArray()}; - WhileLoopState::extract_by_loc(inputs, params.cond_input_locs, &cond_inputs); + extract_by_loc(inputs, params.cond_input_locs, &cond_inputs); std::vector cond_input_ptr, cond_output_ptr; to_ptr_vec(cond_inputs, &cond_input_ptr); to_ptr_vec(cond_outputs, &cond_output_ptr); // construct inputs and outputs for func std::vector func_inputs, func_outputs(outputs.size()); - WhileLoopState::extract_by_loc(inputs, params.func_input_locs, &func_inputs); + extract_by_loc(inputs, params.func_input_locs, &func_inputs); for (size_t &step = state.n_iterations = 0; step < (size_t) params.max_iterations; ++step) { state.cond_op->Forward(nullptr, cond_input_ptr, cond_output_ptr); if (!as_bool_scalar(*cond_output_ptr[0])) { @@ -716,8 +652,8 @@ static void WhileLoopGradComputeExCPU(const OpStatePtr& state_ptr, } std::vector outputs; std::vector req; - WhileLoopState::extract_by_loc(_outputs, params.func_input_locs, &outputs); - WhileLoopState::extract_by_loc(_req, params.func_input_locs, &req); + extract_by_loc(_outputs, params.func_input_locs, &outputs); + extract_by_loc(_req, params.func_input_locs, &req); if (state.n_iterations == 0) { for (int i = params.num_out_data; i < params.num_outputs; ++i) { int j = params.func_var_locs[i - params.num_out_data]; @@ -796,7 +732,7 @@ static bool WhileLoopShape(const nnvm::NodeAttrs& attrs, std::vector *out_shape) { using nnvm::ShapeVector; const WhileLoopParam& params = nnvm::get(attrs.parsed); - 
static const std::function is_udf = WhileLoopState::is_shape_udf; + static const std::function is_udf = is_shape_udf; // sanity checks CHECK_EQ(in_shape->size() + 2U, (size_t) params.num_args); CHECK_EQ(out_shape->size(), (size_t) params.num_outputs); @@ -811,7 +747,7 @@ static bool WhileLoopShape(const nnvm::NodeAttrs& attrs, // create subg_in ShapeVector subg_in; ShapeVector &subg_out = *_subg_out; - WhileLoopState::extract_by_loc(*in_shape, input_locs, &subg_in); + extract_by_loc(*in_shape, input_locs, &subg_in); // create an indexed graph nnvm::Graph g; g.outputs = subg->outputs; @@ -884,35 +820,35 @@ static bool WhileLoopShape(const nnvm::NodeAttrs& attrs, }; ShapeVector cond_out_shape{TShape(1U)}; // this means: [(1, )] ShapeVector func_out_shape(params.num_outputs); - CHECK(WhileLoopState::sync_in_out(params, in_shape, out_shape, is_udf)); + CHECK(params.sync_in_out(in_shape, out_shape, is_udf)); bool succ_0 = infer_subg(attrs.subgraphs[0], &cond_out_shape, params.cond_input_locs, 0, false); - CHECK(WhileLoopState::sync_in_out(params, in_shape, out_shape, is_udf)); + CHECK(params.sync_in_out(in_shape, out_shape, is_udf)); bool succ_1 = infer_subg(attrs.subgraphs[1], &func_out_shape, \ params.func_input_locs, params.num_out_data, true); - CHECK(WhileLoopState::sync_in_out(params, in_shape, out_shape, is_udf)); + CHECK(params.sync_in_out(in_shape, out_shape, is_udf)); return succ_0 && succ_1; } static bool WhileLoopType(const nnvm::NodeAttrs& attrs, std::vector *in_type, std::vector *out_type) { const WhileLoopParam& params = nnvm::get(attrs.parsed); - static const std::function is_udf = WhileLoopState::is_type_udf; + static const std::function is_udf = is_type_udf; CHECK_EQ(in_type->size() + 2U, (size_t) params.num_args); CHECK_EQ(out_type->size(), (size_t) params.num_outputs); CHECK_EQ(attrs.subgraphs.size(), 2U); CHECK_EQ(attrs.subgraphs[0]->outputs.size(), 1U); std::vector cond_in_type; std::vector func_in_type; - WhileLoopState::extract_by_loc(*in_type, params.cond_input_locs, &cond_in_type); - WhileLoopState::extract_by_loc(*in_type, params.func_input_locs, &func_in_type); + extract_by_loc(*in_type, params.cond_input_locs, &cond_in_type); + extract_by_loc(*in_type, params.func_input_locs, &func_in_type); std::vector cond_out_type = {0}; - CHECK(WhileLoopState::sync_in_out(params, in_type, out_type, is_udf)); + CHECK(params.sync_in_out(in_type, out_type, is_udf)); bool succ_0 = InferSubgraphDataType(*attrs.subgraphs[0], &cond_in_type, &cond_out_type); - CHECK(WhileLoopState::sync_in_out(params, in_type, out_type, is_udf)); - CHECK(WhileLoopState::sync_in_in(params.cond_input_locs, in_type, &cond_in_type, is_udf)); + CHECK(params.sync_in_out(in_type, out_type, is_udf)); + CHECK(sync_in_in(params.cond_input_locs, in_type, &cond_in_type, is_udf)); bool succ_1 = InferSubgraphDataType(*attrs.subgraphs[1], &func_in_type, out_type); - CHECK(WhileLoopState::sync_in_out(params, in_type, out_type, is_udf)); - CHECK(WhileLoopState::sync_in_in(params.func_input_locs, in_type, &func_in_type, is_udf)); + CHECK(params.sync_in_out(in_type, out_type, is_udf)); + CHECK(sync_in_in(params.func_input_locs, in_type, &func_in_type, is_udf)); return succ_0 && succ_1; } @@ -922,28 +858,28 @@ static bool WhileLoopStorageType(const nnvm::NodeAttrs& attrs, std::vector *in_attrs, std::vector *out_attrs) { const WhileLoopParam& params = nnvm::get(attrs.parsed); - static const std::function is_udf = WhileLoopState::is_stype_udf; + static const std::function is_udf = is_stype_udf; CHECK_EQ(in_attrs->size() + 
2U, (size_t) params.num_args); CHECK_EQ(out_attrs->size(), (size_t) params.num_outputs); CHECK_EQ(attrs.subgraphs.size(), 2U); CHECK_EQ(attrs.subgraphs[0]->outputs.size(), 1U); std::vector cond_in_attrs; std::vector func_in_attrs; - WhileLoopState::extract_by_loc(*in_attrs, params.cond_input_locs, &cond_in_attrs); - WhileLoopState::extract_by_loc(*in_attrs, params.func_input_locs, &func_in_attrs); + extract_by_loc(*in_attrs, params.cond_input_locs, &cond_in_attrs); + extract_by_loc(*in_attrs, params.func_input_locs, &func_in_attrs); std::vector cond_out_attrs = {kDefaultStorage}; DispatchMode cond_mode = DispatchMode::kUndefined; DispatchMode func_mode = DispatchMode::kUndefined; *dispatch_mode = DispatchMode::kFComputeEx; - CHECK(WhileLoopState::sync_in_out(params, in_attrs, out_attrs, is_udf)); + CHECK(params.sync_in_out(in_attrs, out_attrs, is_udf)); bool succ_0 = InferSubgraphStorage(*attrs.subgraphs[0], dev_mask, \ &cond_mode, &cond_in_attrs, &cond_out_attrs); - CHECK(WhileLoopState::sync_in_out(params, in_attrs, out_attrs, is_udf)); - CHECK(WhileLoopState::sync_in_in(params.cond_input_locs, in_attrs, &cond_in_attrs, is_udf)); + CHECK(params.sync_in_out(in_attrs, out_attrs, is_udf)); + CHECK(sync_in_in(params.cond_input_locs, in_attrs, &cond_in_attrs, is_udf)); bool succ_1 = InferSubgraphStorage(*attrs.subgraphs[1], dev_mask, \ &func_mode, &func_in_attrs, out_attrs); - CHECK(WhileLoopState::sync_in_out(params, in_attrs, out_attrs, is_udf)); - CHECK(WhileLoopState::sync_in_in(params.func_input_locs, in_attrs, &func_in_attrs, is_udf)); + CHECK(params.sync_in_out(in_attrs, out_attrs, is_udf)); + CHECK(sync_in_in(params.func_input_locs, in_attrs, &func_in_attrs, is_udf)); return succ_0 && succ_1; } @@ -977,6 +913,342 @@ WhileLoopGradient(const nnvm::NodePtr& n, const std::vector& og return entries; } +struct IfelseParam : public dmlc::Parameter { + int num_args; + int num_outputs; + nnvm::Tuple cond_input_locs; + nnvm::Tuple then_input_locs; + nnvm::Tuple else_input_locs; + DMLC_DECLARE_PARAMETER(IfelseParam) { + DMLC_DECLARE_FIELD(num_args).set_lower_bound(3) + .describe("Number of input arguments, including cond, then and else as three symbol inputs."); + DMLC_DECLARE_FIELD(num_outputs).set_lower_bound(1) + .describe("The number of outputs of the subgraph."); + DMLC_DECLARE_FIELD(cond_input_locs) + .describe("The locations of cond's inputs in the given inputs."); + DMLC_DECLARE_FIELD(then_input_locs) + .describe("The locations of then's inputs in the given inputs."); + DMLC_DECLARE_FIELD(else_input_locs) + .describe("The locations of else's inputs in the given inputs."); + } +}; // struct IfelseParam + +DMLC_REGISTER_PARAMETER(IfelseParam); + +class IfelseState { + public: + IfelseParam params; + CachedOpPtr cond_op; + LoopState then_branch; + LoopState else_branch; + int branch_selection; // 1 if then branch; 0 if else branch; -1 if undefined + + IfelseState(const IfelseParam ¶ms, + const Symbol &cond, + const Symbol &then_sym, + const Symbol &else_sym): + params(params), + cond_op(LoopState::MakeSharedOp(cond)), + then_branch(then_sym), + else_branch(else_sym), + branch_selection(-1) { + } +}; + +static void IfelseComputeExCPU(const OpStatePtr& state_ptr, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + // The argument `inputs' are loop_vars and other inputs + // loop_vars are stored in stored in `loop_vars_locs' + // The argument `outputs' are output and new_loop_vars + // [0: num_out_data) are outputs at each step. 
+ // [num_out_data: ) are new_loop_vars + IfelseState &state = state_ptr.get_state(); + const IfelseParam& params = state.params; + // a helper function, converting std::vector to std::vector + const auto to_ptr_vec = [](std::vector &in, std::vector *out) { + out->clear(); + out->reserve(in.size()); + std::transform(std::begin(in), + std::end(in), + std::back_inserter(*out), + [](NDArray &a) {return &a;}); + }; + // sanity checks + CHECK_EQ(inputs.size() + 3U, (size_t) params.num_args); + CHECK_EQ(outputs.size(), (size_t) params.num_outputs); + CHECK_EQ(outputs.size(), req.size()); + // construct inputs and outputs for cond + std::vector cond_inputs; + std::vector cond_outputs = {NDArray()}; + std::vector cond_input_ptr; + std::vector cond_output_ptr; + extract_by_loc(inputs, params.cond_input_locs, &cond_inputs); + to_ptr_vec(cond_inputs, &cond_input_ptr); + to_ptr_vec(cond_outputs, &cond_output_ptr); + int &branch_selection = state.branch_selection; + // run cond + state.cond_op->Forward(nullptr, cond_input_ptr, cond_output_ptr); + branch_selection = as_bool_scalar(*cond_output_ptr[0]); + // select the right branch + const nnvm::Tuple &func_input_locs = branch_selection + ? params.then_input_locs + : params.else_input_locs; + LoopState &loop_state = branch_selection + ? state.then_branch + : state.else_branch; + // extract inputs for the branch + std::vector func_inputs; + extract_by_loc(inputs, func_input_locs, &func_inputs); + loop_state.Forward(0, func_inputs, req, outputs, ctx.need_grad); +} + +static void IfelseGradComputeExCPU(const OpStatePtr& state_ptr, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& _req, + const std::vector& outputs) { + IfelseState &state = state_ptr.get_state(); + const IfelseParam& params = state.params; + // sanity checks + CHECK_EQ(outputs.size() + 3U, (size_t) params.num_args); + CHECK_EQ(outputs.size(), _req.size()); + // select the right branch + int branch_selection = state.branch_selection; + CHECK_NE(branch_selection, -1); + const nnvm::Tuple &func_input_locs = branch_selection + ? params.then_input_locs + : params.else_input_locs; + LoopState &loop_state = branch_selection + ? 
state.then_branch + : state.else_branch; + // construct parameters + std::vector ograds(inputs.begin(), inputs.begin() + params.num_outputs); + std::vector req; + extract_by_loc(_req, func_input_locs, &req); + std::vector igrads; + extract_by_loc(outputs, func_input_locs, &igrads); + loop_state.Backward(0, ograds, req, igrads); + loop_state.Cleanup(); +} + +static bool IfelseShape(const nnvm::NodeAttrs& attrs, + std::vector *in_shape, + std::vector *out_shape) { + using nnvm::ShapeVector; + const IfelseParam& params = nnvm::get(attrs.parsed); + static const std::function is_udf = is_shape_udf; + // sanity checks + CHECK_EQ(in_shape->size() + 3U, (size_t) params.num_args); + CHECK_EQ(out_shape->size(), (size_t) params.num_outputs); + CHECK_EQ(attrs.subgraphs.size(), 3U); + CHECK_EQ(attrs.subgraphs[0]->outputs.size(), 1U); + CHECK_EQ(attrs.subgraphs[1]->outputs.size(), attrs.subgraphs[2]->outputs.size()); + // infer shape for cond, then and else + auto infer_subg = [¶ms, in_shape, out_shape](std::shared_ptr subg, + ShapeVector *_subg_out, + const nnvm::Tuple &input_locs, + bool fill_out_shape) { + // create subg_in + ShapeVector subg_in; + ShapeVector &subg_out = *_subg_out; + extract_by_loc(*in_shape, input_locs, &subg_in); + // create an indexed graph + nnvm::Graph g; + g.outputs = subg->outputs; + const auto& idx = g.indexed_graph(); + // get input nodes + const auto &input_nids = idx.input_nodes(); + // sanity checks + CHECK_EQ(input_nids.size(), subg_in.size()); + CHECK_EQ(g.outputs.size(), subg_out.size()); + CHECK_EQ(idx.input_nodes().size(), subg_in.size()); + CHECK_EQ(idx.outputs().size(), subg_out.size()); + // create empty shapes for inference + ShapeVector shapes(idx.num_node_entries()); + // copy subg_in into shapes + for (size_t i = 0; i < subg_in.size(); ++i) { + auto eid = idx.entry_id(input_nids[i], 0); + shapes[eid] = subg_in[i]; + } + // copy subg_out into shapes + for (size_t i = 0; i < subg_out.size(); ++i) { + auto eid = idx.entry_id(g.outputs[i]); + shapes[eid] = subg_out[i]; + } + // copy done, call InferShape + g.attrs["shape"] = std::make_shared(std::move(shapes)); + g = exec::InferShape(std::move(g)); + // now `shapes' won't be used anymore, use new_shapes instead + const auto& new_shapes = g.GetAttr("shape"); + // copy subg_in back to in_shape + for (size_t i = 0; i < subg_in.size(); ++i) { + auto eid = idx.entry_id(input_nids[i], 0); + auto g_out_shape = new_shapes[eid]; + if (g_out_shape.ndim() == 0 || g_out_shape.Size() == 0) { + // when the shape is not fully inferred + continue; + } + SHAPE_ASSIGN_CHECK(*in_shape, input_locs[i], g_out_shape); + } + if (!fill_out_shape) { + return true; + } + // copy subg_out back to out_shape + for (size_t i = 0; i < g.outputs.size(); ++i) { + auto eid = idx.entry_id(g.outputs[i]); + auto g_out_shape = new_shapes[eid]; + if (g_out_shape.ndim() == 0 || g_out_shape.Size() == 0) { + // when the shape is not fully inferred + continue; + } + SHAPE_ASSIGN_CHECK(*out_shape, i, g_out_shape); + } + return g.GetAttr("shape_num_unknown_nodes") == 0; + }; + ShapeVector cond_out_shape{TShape(1U)}; // this means: [(1, )] + ShapeVector then_out_shape(params.num_outputs); + ShapeVector else_out_shape(params.num_outputs); + bool succ_0 = infer_subg(attrs.subgraphs[0], &cond_out_shape, \ + params.cond_input_locs, false); + bool succ_1 = infer_subg(attrs.subgraphs[1], &then_out_shape, \ + params.then_input_locs, true); + bool succ_2 = infer_subg(attrs.subgraphs[2], &else_out_shape, \ + params.else_input_locs, true); + return succ_0 && succ_1 && 
succ_2; +} + +static bool IfelseType(const nnvm::NodeAttrs& attrs, + std::vector *in_type, + std::vector *out_type) { + const IfelseParam& params = nnvm::get(attrs.parsed); + static const std::function is_udf = is_type_udf; + CHECK_EQ(in_type->size() + 3U, (size_t) params.num_args); + CHECK_EQ(out_type->size(), (size_t) params.num_outputs); + CHECK_EQ(attrs.subgraphs.size(), 3U); + CHECK_EQ(attrs.subgraphs[0]->outputs.size(), 1U); + CHECK_EQ(attrs.subgraphs[1]->outputs.size(), attrs.subgraphs[2]->outputs.size()); + std::vector cond_in_type; + std::vector then_in_type; + std::vector else_in_type; + extract_by_loc(*in_type, params.cond_input_locs, &cond_in_type); + extract_by_loc(*in_type, params.then_input_locs, &then_in_type); + extract_by_loc(*in_type, params.else_input_locs, &else_in_type); + std::vector cond_out_type = {0}; + bool succ_0 = InferSubgraphDataType(*attrs.subgraphs[0], &cond_in_type, &cond_out_type); + CHECK(sync_in_in(params.cond_input_locs, in_type, &cond_in_type, is_udf)); + bool succ_1 = InferSubgraphDataType(*attrs.subgraphs[1], &then_in_type, out_type); + CHECK(sync_in_in(params.then_input_locs, in_type, &then_in_type, is_udf)); + bool succ_2 = InferSubgraphDataType(*attrs.subgraphs[2], &else_in_type, out_type); + CHECK(sync_in_in(params.else_input_locs, in_type, &else_in_type, is_udf)); + return succ_0 && succ_1 && succ_2; +} + +static bool IfelseStorageType(const nnvm::NodeAttrs& attrs, + const int dev_mask, + DispatchMode* dispatch_mode, + std::vector *in_attrs, + std::vector *out_attrs) { + const IfelseParam& params = nnvm::get(attrs.parsed); + static const std::function is_udf = is_stype_udf; + CHECK_EQ(in_attrs->size() + 3U, (size_t) params.num_args); + CHECK_EQ(out_attrs->size(), (size_t) params.num_outputs); + CHECK_EQ(attrs.subgraphs.size(), 3U); + CHECK_EQ(attrs.subgraphs[0]->outputs.size(), 1U); + CHECK_EQ(attrs.subgraphs[1]->outputs.size(), attrs.subgraphs[2]->outputs.size()); + std::vector cond_in_attrs; + std::vector then_in_attrs; + std::vector else_in_attrs; + extract_by_loc(*in_attrs, params.cond_input_locs, &cond_in_attrs); + extract_by_loc(*in_attrs, params.then_input_locs, &then_in_attrs); + extract_by_loc(*in_attrs, params.else_input_locs, &else_in_attrs); + std::vector cond_out_attrs = {kDefaultStorage}; + DispatchMode cond_mode = DispatchMode::kUndefined; + DispatchMode then_mode = DispatchMode::kUndefined; + DispatchMode else_mode = DispatchMode::kUndefined; + *dispatch_mode = DispatchMode::kFComputeEx; + bool succ_0 = InferSubgraphStorage(*attrs.subgraphs[0], dev_mask, \ + &cond_mode, &cond_in_attrs, &cond_out_attrs); + CHECK(sync_in_in(params.cond_input_locs, in_attrs, &cond_in_attrs, is_udf)); + bool succ_1 = InferSubgraphStorage(*attrs.subgraphs[1], dev_mask, \ + &then_mode, &then_in_attrs, out_attrs); + CHECK(sync_in_in(params.then_input_locs, in_attrs, &then_in_attrs, is_udf)); + bool succ_2 = InferSubgraphStorage(*attrs.subgraphs[2], dev_mask, \ + &else_mode, &else_in_attrs, out_attrs); + CHECK(sync_in_in(params.else_input_locs, in_attrs, &else_in_attrs, is_udf)); + return succ_0 && succ_1 && succ_2; +} + +static bool BackwardIfelseStorageType(const nnvm::NodeAttrs& attrs, + const int dev_mask, + DispatchMode* dispatch_mode, + std::vector *in_attrs, + std::vector *out_attrs) { + const IfelseParam& params = nnvm::get(attrs.parsed); + CHECK_EQ(out_attrs->size() + 3U, (size_t) params.num_args); + CHECK_EQ(attrs.subgraphs.size(), 3U); + static const std::function is_udf = is_stype_udf; + auto sub_pass = [&](const std::shared_ptr &subg, const 
nnvm::Tuple &input_locs) { + // A. first construct subg_in_attrs + // need subg_in_attrs as subg_bwd_out (copy), subg_fwd_in (extract), subg_fwd_out (copy) + std::vector subg_in_attrs; + size_t num_elts = params.num_outputs * 2 + input_locs.ndim(); + subg_in_attrs.reserve(num_elts); + // part 1. subg_bwd_out (copy) + subg_in_attrs.insert(subg_in_attrs.end(), + in_attrs->begin(), + in_attrs->begin() + params.num_outputs); + // part 2. subg_fwd_in (extract) + std::vector fwd_in(in_attrs->begin() + params.num_outputs, + in_attrs->begin() + params.num_outputs + params.num_args - 3); + std::vector subg_fwd_in; + extract_by_loc(fwd_in, input_locs, &subg_fwd_in); + subg_in_attrs.insert(subg_in_attrs.end(), + subg_fwd_in.begin(), + subg_fwd_in.end()); + // part 3. subg_fwd_out (copy) + subg_in_attrs.insert(subg_in_attrs.end(), + in_attrs->begin() + params.num_outputs + params.num_args - 3, + in_attrs->end()); + // check correctness of the number of elements + CHECK_EQ(subg_in_attrs.size(), num_elts); + // B. then we construct subg_out_attrs by extracting from out_attrs + std::vector subg_out_attrs; + extract_by_loc(*out_attrs, input_locs, &subg_out_attrs); + // then we construct the subgraph and do inference + CachedOp op(*subg, {}); + bool ret = op.BackwardStorageType(attrs, dev_mask, dispatch_mode, \ + &subg_in_attrs, &subg_out_attrs); + CHECK(sync_in_in(input_locs, out_attrs, &subg_out_attrs, is_udf)); + return ret; + }; + bool succ_0 = sub_pass(attrs.subgraphs[1], params.then_input_locs); + bool succ_1 = sub_pass(attrs.subgraphs[2], params.else_input_locs); + return succ_0 && succ_1; +} + +static OpStatePtr CreateIfelseState(const NodeAttrs& attrs, + Context ctx, + const std::vector& ishape, + const std::vector& itype) { + const IfelseParam& params = nnvm::get(attrs.parsed); + return OpStatePtr::Create( + params, + *attrs.subgraphs[0], + *attrs.subgraphs[1], + *attrs.subgraphs[2]); +} + +static std::vector +IfelseGradient(const nnvm::NodePtr& n, const std::vector& ograds) { + ElemwiseGradUseInOut fgrad{"_backward_ifelse"}; + std::vector entries = fgrad(n, ograds); + entries[0].node->attrs.subgraphs = n->attrs.subgraphs; + return entries; +} + NNVM_REGISTER_OP(_foreach) .MXNET_DESCRIBE("Run a for loop over an NDArray with user-defined computation") .set_attr_parser(ParamParser) @@ -1100,5 +1372,68 @@ NNVM_REGISTER_OP(_backward_while_loop) .set_attr("FStatefulComputeEx", WhileLoopGradComputeExCPU) .set_attr("FStatefulComputeEx", WhileLoopGradComputeExCPU); +NNVM_REGISTER_OP(_ifelse) +.MXNET_DESCRIBE("Run a if-then-else using user-defined condition and computation") +.set_attr_parser(ParamParser) +.set_attr("FInferStorageType", IfelseStorageType) +.set_num_inputs([](const NodeAttrs& attrs) { + const IfelseParam& params = nnvm::get(attrs.parsed); + return params.num_args; +}) +.set_num_outputs([](const NodeAttrs& attrs) { + const IfelseParam& params = nnvm::get(attrs.parsed); + return params.num_outputs; +}) +.set_attr("FListInputNames", + [](const NodeAttrs& attrs) { + const IfelseParam& params = nnvm::get(attrs.parsed); + std::vector names; + names.reserve(params.num_args); + names.push_back("cond"); + names.push_back("then_branch"); + names.push_back("else_branch"); + for (int i = 3; i < params.num_args; ++i) + names.push_back("data" + std::to_string(i - 3)); + return names; +}) +.set_attr("FInputGraph", + [](const NodeAttrs& attrs) { + return std::vector{0, 1, 2}; +}) +.set_attr("FGradient", IfelseGradient) +.set_attr("FCreateOpState", CreateIfelseState) +.set_attr("FInferShape", IfelseShape) 
+.set_attr("FInferType", IfelseType) +.set_attr("FStatefulComputeEx", IfelseComputeExCPU) +.set_attr("FExecType", [](const NodeAttrs& attrs) { + return ExecType::kSubgraphExec; +}) +.set_attr("FStatefulComputeEx", IfelseComputeExCPU) +.set_attr("key_var_num_args", "num_args") +.add_argument("cond", "Symbol", "Input graph for the condition.") +.add_argument("then_branch", "Symbol", "Input graph for the then branch.") +.add_argument("else_branch", "Symbol", "Input graph for the else branch.") +.add_argument("data", "NDArray-or-Symbol[]", + "The input arrays that include data arrays and states.") +.add_arguments(IfelseParam::__FIELDS__()); + +NNVM_REGISTER_OP(_backward_ifelse) +.set_num_inputs([](const NodeAttrs& attrs){ + const IfelseParam& params = nnvm::get(attrs.parsed); + return params.num_outputs * 2 + params.num_args - 3; +}) +.set_num_outputs([](const NodeAttrs& attrs){ + const IfelseParam& params = nnvm::get(attrs.parsed); + return params.num_args - 3; +}) +.set_attr("FExecType", [](const NodeAttrs& attrs) { + return ExecType::kSubgraphExec; +}) +.set_attr("FInferStorageType", BackwardIfelseStorageType) +.set_attr_parser(ParamParser) +.set_attr("TIsLayerOpBackward", true) +.set_attr("TIsBackward", true) +.set_attr("FStatefulComputeEx", IfelseGradComputeExCPU) +.set_attr("FStatefulComputeEx", IfelseGradComputeExCPU); } // namespace op } // namespace mxnet diff --git a/src/operator/subgraph_op_common.cc b/src/operator/subgraph_op_common.cc index d845aa907d33..7a99aedb8602 100644 --- a/src/operator/subgraph_op_common.cc +++ b/src/operator/subgraph_op_common.cc @@ -161,6 +161,34 @@ bool InferSubgraphShape(const nnvm::Symbol &subgraph, return g.GetAttr("shape_num_unknown_nodes") == 0; } +template +T _asscalar(const NDArray &a) { + CHECK_EQ(a.shape().Size(), 1U); + T data; + a.SyncCopyToCPU(&data, 1U); + return data; +} + +bool as_bool_scalar(const NDArray &a) { + MSHADOW_TYPE_SWITCH(a.dtype(), DType, { + return static_cast(_asscalar(a)); + }); + LOG(FATAL) << "Unknown dtype"; + return false; +} + +bool is_shape_udf(const TShape &x) { + return x.ndim() == 0 || x.Size() == 0; +} + +bool is_stype_udf(const int &x) { + return x == exec::kBadStorageID; +} + +bool is_type_udf(const int &x) { + return x == -1; +} + LoopState::LoopState(const Symbol &g) { this->subgraph_sym = g; this->subgraph.outputs = g.outputs; diff --git a/src/operator/subgraph_op_common.h b/src/operator/subgraph_op_common.h index f73f09cd5c85..24983ae34632 100644 --- a/src/operator/subgraph_op_common.h +++ b/src/operator/subgraph_op_common.h @@ -57,6 +57,55 @@ bool InferSubgraphStorage(const nnvm::Symbol &subgraph, std::vector *in_attrs, std::vector *out_attrs); +bool as_bool_scalar(const NDArray &a); + +bool is_shape_udf(const TShape &x); + +bool is_stype_udf(const int &x); + +bool is_type_udf(const int &x); + +template +void extract_by_loc(const std::vector &array, + const nnvm::Tuple input_locs, + std::vector *out) { + out->clear(); + out->reserve(input_locs.ndim()); + for (dim_t i : input_locs) { + out->push_back(array[i]); + } +} + +template +bool fill_value(T *x, T *y, bool x_empty, bool y_empty) { + if (*x == *y || (x_empty && y_empty)) { + return true; + } + if (!x_empty && !y_empty) { + return false; + } + if (x_empty) { + *x = *y; + } + if (y_empty) { + *y = *x; + } + return true; +} + +template +bool sync_in_in(const nnvm::Tuple &input_locs, + std::vector *in, + std::vector *subg_in, + std::function is_empty) { + for (size_t i = 0; i < input_locs.ndim(); ++i) { + T &x = in->at(input_locs[i]); + T &y = 
subg_in->at(i); + fill_value(&x, &y, is_empty(x), is_empty(y)); + } + return true; +} + /* * This contains the states for running a loop and provides methods * of running the subgraph computation for an iteration. diff --git a/tests/python/unittest/test_contrib_control_flow.py b/tests/python/unittest/test_contrib_control_flow.py index 1cc5b21ac86c..12694572bb7c 100644 --- a/tests/python/unittest/test_contrib_control_flow.py +++ b/tests/python/unittest/test_contrib_control_flow.py @@ -974,6 +974,162 @@ def _func(*states): y = y.asnumpy() assert_almost_equal(x, y, rtol=1e-4, atol=1e-4) +def _verify_ifelse(cond, then_func, else_func, input_var_shapes, free_var_shapes, is_train): + + def _create_symbol(prefix, i): + return mx.sym.var(prefix + str(i)) + + def _create_array(shape): + return mx.nd.random.uniform(-1.0, 1.0, shape=shape) + + def _to_numpy_list(arrays): + return [x.asnumpy() if x is not None else x for x in arrays] + + def _merge_dict(*dicts): + result = {} + for item in dicts: + result.update(item) + return result + + _input_syms = [_create_symbol("InputVar", i) for i, _ in enumerate(input_var_shapes)] + _free_syms = [_create_symbol("FreeVar", i) for i, _ in enumerate(free_var_shapes)] + _input_vars = [_create_array(x) for x in input_var_shapes] + _free_vars = [_create_array(x) for x in free_var_shapes] + _args_dict = _merge_dict( + {"InputVar" + str(i): x for i, x in enumerate(_input_vars)}, + {"FreeVar" + str(i): x for i, x in enumerate(_free_vars)}, + ) + + def _get_imperative_result(): + free_vars = [x.copy() for x in _free_vars] + input_vars = [x.copy() for x in _input_vars] + out_grads = [] + if is_train: + for var in free_vars + input_vars: + var.attach_grad() + with mx.autograd.record(train_mode=is_train): + outputs = mx.nd.contrib.ifelse( + cond=lambda *__input_vars: cond(__input_vars, free_vars), + then_func=lambda *__input_vars: then_func(__input_vars, free_vars), + else_func=lambda *__input_vars: else_func(__input_vars, free_vars), + inputs=input_vars, + ) + outputs = [x * 2 for x in outputs] + grads = [] + if is_train: + out_grads = [_create_array(x.shape) for x in outputs] + cat_out = mx.nd.concat(*[x.reshape(-1) for x in outputs], dim=0) + cat_out.backward(out_grad=mx.nd.concat(*[x.reshape(-1) for x in out_grads], dim=0)) + grads = [free_vars[i].grad for i, _ in enumerate(free_var_shapes)] \ + + [input_vars[i].grad for i, _ in enumerate(input_var_shapes)] + return _to_numpy_list(outputs), _to_numpy_list(grads), out_grads + + def _get_symbolic_result(out_grads): + outputs_sym = mx.sym.contrib.ifelse( + cond=lambda *__loop_vars: cond(__loop_vars, _free_syms), + then_func=lambda *__loop_vars: then_func(__loop_vars, _free_syms), + else_func=lambda *__loop_vars: else_func(__loop_vars, _free_syms), + inputs=_input_syms, + ) + outputs_sym = [x * 2 for x in outputs_sym] + outputs_sym = mx.sym.Group(outputs_sym) + executor = outputs_sym.bind( + ctx=default_context(), + args={name: _args_dict[name].copy() for name in outputs_sym.list_inputs()}, + args_grad=None if not is_train else _merge_dict( + {"InputVar" + str(i): mx.nd.zeros(s) for i, s in enumerate(input_var_shapes)}, + {"FreeVar" + str(i): mx.nd.zeros(s) for i, s in enumerate(free_var_shapes)}, + ), + ) + outputs = executor.forward(is_train=is_train) + grads = [] + if is_train: + executor.backward(out_grads=out_grads) + grads = [executor.grad_dict.get("FreeVar" + str(i), None) for i, _ in enumerate(free_var_shapes)] \ + + [executor.grad_dict.get("InputVar" + str(i), None) for i, _ in enumerate(input_var_shapes)] + 
return _to_numpy_list(outputs), _to_numpy_list(grads) + + imp_outs, imp_grads, out_grads = _get_imperative_result() + sym_outs, sym_grads = _get_symbolic_result(out_grads) + for imp_out, sym_out in zip(imp_outs, sym_outs): + if imp_out is None or sym_out is None: + continue + assert_almost_equal(imp_out, sym_out, rtol=1e-5, atol=1e-5) + for imp_grad, sym_grad in zip(imp_grads, sym_grads): + if imp_grad is None or sym_grad is None: + continue + assert_almost_equal(imp_grad, sym_grad, rtol=1e-5, atol=1e-5) + + +@with_seed() +def test_ifelse(): + # whether there are free variables in three graphs + # whether these three graphs contain input_vars + # whether to use all input_vars + # which branch to choose + def run_case(cond_func, then_func, else_func, **params): + def make_cond(is_inverse): + def cond(inputs, free): + x = cond_func(inputs, free) + if is_inverse: + if isinstance(x, mx.sym.Symbol): + return mx.sym.logical_not(x) + else: + return mx.nd.logical_not(x) + return x + return cond + for is_train in [True, False]: + for is_inverse in [False, True]: + _verify_ifelse( + cond=make_cond(is_inverse), + then_func=then_func, + else_func=else_func, + is_train=is_train, + **params + ) + # Each function can + # 1. use_free_vars or not: T/F + # 2. use_input_vars or not: T/F + # 3. use_all_input_vars or not: T/F + # (a, b, c) are inputs, (d, e, f) are free_vars + cond_funcs = [ + lambda a, b, c, d, e, f: (a * b).sum() < 0.5, # F, T, F + lambda a, b, c, d, e, f: (a + b + c).sum() < 0.5, # F, T, T + lambda a, b, c, d, e, f: (d + e).sum() < 0.5, # T, F, F + lambda a, b, c, d, e, f: (d + e * a).sum() < 0.5, # T, T, F + lambda a, b, c, d, e, f: (d + e * a + b * c).sum() < 0.5, # T, T, T + ] + body_funcs = [ + lambda a, b, c, d, e, f: a * b, # F, T, F + lambda a, b, c, d, e, f: a * b * c, # F, T, T + lambda a, b, c, d, e, f: d * e, # T, F, F + lambda a, b, c, d, e, f: d * e * a, # T, T, F + lambda a, b, c, d, e, f: d * e * a * b * c, # T, T, T + # some extra tests + lambda a, b, c, d, e, f: b * c, + lambda a, b, c, d, e, f: a * c, + lambda a, b, c, d, e, f: (a + b) * c, + lambda a, b, c, d, e, f: c * (b - a), + ] + # enumerate all kinds of possible combinations + for cond_func in cond_funcs: + for then_func in body_funcs: + for else_func in body_funcs: + run_case( + cond_func=lambda x, y: cond_func(x[0], x[1], x[2], y[0], y[1], y[2]), + then_func=lambda x, y: then_func(x[0], x[1], x[2], y[0], y[1], y[2]), + else_func=lambda x, y: else_func(x[0], x[1], x[2], y[0], y[1], y[2]), + input_var_shapes=[ + (2, 3), + (2, 3), + (2, 3), + ], + free_var_shapes=[ + (2, 3), + (2, 3), + (2, 3), + ] + ) if __name__ == '__main__': import nose From 2b46358310409d6121c5697c87f6861d293016ad Mon Sep 17 00:00:00 2001 From: Junru Shao Date: Fri, 20 Jul 2018 11:29:35 -0700 Subject: [PATCH 2/7] Address comments --- python/mxnet/ndarray/contrib.py | 4 ++-- python/mxnet/symbol/contrib.py | 4 ++-- src/operator/control_flow.cc | 1 + src/operator/subgraph_op_common.h | 13 +++++++++++++ 4 files changed, 18 insertions(+), 4 deletions(-) diff --git a/python/mxnet/ndarray/contrib.py b/python/mxnet/ndarray/contrib.py index b7b63c4e10e6..12407cf4fe74 100644 --- a/python/mxnet/ndarray/contrib.py +++ b/python/mxnet/ndarray/contrib.py @@ -364,12 +364,12 @@ def _func_wrapper(loop_vars): return stacked_outputs, list(loop_vars) def ifelse(cond, then_func, else_func, inputs): - """Run a if-then-else using user-defined condition and computation + """Run an if-then-else using user-defined condition and computation This operator simulates a 
if-like branch which chooses to do one of the two customized computations according to the specified condition. - `inputs` is a list of NDArrays on which the condition and computations reply on. + `inputs` is a list of NDArrays on which the condition and computations rely on. `cond` is a user-defined function, used as the if condition. It consumes `inputs`, and produces a scalar MXNet NDArray, diff --git a/python/mxnet/symbol/contrib.py b/python/mxnet/symbol/contrib.py index 33932ba5ad94..13bb89e8d9f1 100644 --- a/python/mxnet/symbol/contrib.py +++ b/python/mxnet/symbol/contrib.py @@ -558,12 +558,12 @@ def _union_inputs(*graphs): return outputs, final_loop_vars def ifelse(cond, then_func, else_func, inputs, name="ifelse"): - """Run a if-then-else using user-defined condition and computation + """Run an if-then-else using user-defined condition and computation This operator simulates a if-like branch which chooses to do one of the two customized computations according to the specified condition. - `inputs` is a list of Symbols on which the condition and computations reply on. + `inputs` is a list of Symbols on which the condition and computations rely on. `cond` is a user-defined function, used as the if condition. It consumes `inputs`, and produces a scalar MXNet symbol, diff --git a/src/operator/control_flow.cc b/src/operator/control_flow.cc index 261bd5070f7d..5159a27cb508 100644 --- a/src/operator/control_flow.cc +++ b/src/operator/control_flow.cc @@ -1117,6 +1117,7 @@ static bool IfelseShape(const nnvm::NodeAttrs& attrs, params.then_input_locs, true); bool succ_2 = infer_subg(attrs.subgraphs[2], &else_out_shape, \ params.else_input_locs, true); + sync_out_out(&then_out_shape, &else_out_shape, is_udf); return succ_0 && succ_1 && succ_2; } diff --git a/src/operator/subgraph_op_common.h b/src/operator/subgraph_op_common.h index 24983ae34632..ebf727f0f5a8 100644 --- a/src/operator/subgraph_op_common.h +++ b/src/operator/subgraph_op_common.h @@ -106,6 +106,19 @@ bool sync_in_in(const nnvm::Tuple &input_locs, return true; } +template +bool sync_out_out(std::vector *out_1, + std::vector *out_2, + std::function is_empty) { + CHECK_EQ(out_1->size(), out_2->size()); + for (size_t i = 0; i < out_1->size(); ++i) { + T &x = out_1->at(i); + T &y = out_2->at(i); + fill_value(&x, &y, is_empty(x), is_empty(y)); + } + return true; +} + /* * This contains the states for running a loop and provides methods * of running the subgraph computation for an iteration. 
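The fill_value/sync_out_out helpers added above let shape, dtype and storage-type inference flow between the two branches: whenever an output attribute is still unknown on one side, it is filled in from the other, and inference only fails when both sides are known but disagree. A minimal Python sketch of that idea, assuming a shape counts as unknown when it is the empty tuple (the helper names below are illustrative, not part of the patch):

def fill_value(x, y, x_empty, y_empty):
    # Copy the known value into the unknown slot; flag a conflict only when
    # both values are known and differ.
    if x == y or (x_empty and y_empty):
        return x, y, True
    if not x_empty and not y_empty:
        return x, y, False
    return (y, y, True) if x_empty else (x, x, True)

def sync_out_out(then_out, else_out, is_empty):
    # Propagate inferred attributes between then/else outputs, position by position.
    assert len(then_out) == len(else_out)
    for i, (x, y) in enumerate(zip(then_out, else_out)):
        then_out[i], else_out[i], ok = fill_value(x, y, is_empty(x), is_empty(y))
        assert ok, "then and else outputs disagree at position %d" % i

then_shapes = [(2, 3), (4,)]
else_shapes = [(2, 3), ()]   # second output shape not yet inferred
sync_out_out(then_shapes, else_shapes, is_empty=lambda s: len(s) == 0)
# else_shapes is now [(2, 3), (4,)]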
From b28bacf575ca39e3d83118e9cf596eef3c1cdd69 Mon Sep 17 00:00:00 2001 From: Junru Shao Date: Fri, 20 Jul 2018 16:15:11 -0700 Subject: [PATCH 3/7] Rename ifelse to condition --- docs/api/python/ndarray/contrib.md | 2 +- docs/api/python/symbol/contrib.md | 2 +- python/mxnet/ndarray/contrib.py | 24 +-- python/mxnet/symbol/contrib.py | 32 ++-- src/operator/control_flow.cc | 154 +++++++++--------- .../unittest/test_contrib_control_flow.py | 16 +- 6 files changed, 115 insertions(+), 115 deletions(-) diff --git a/docs/api/python/ndarray/contrib.md b/docs/api/python/ndarray/contrib.md index 80d8ef23b459..f575bc8e7ce2 100644 --- a/docs/api/python/ndarray/contrib.md +++ b/docs/api/python/ndarray/contrib.md @@ -54,7 +54,7 @@ In the rest of this document, we list routines provided by the `ndarray.contrib` quantize foreach while_loop - ifelse + condition ``` ## API Reference diff --git a/docs/api/python/symbol/contrib.md b/docs/api/python/symbol/contrib.md index 96ce7987d800..69d38beffd1c 100644 --- a/docs/api/python/symbol/contrib.md +++ b/docs/api/python/symbol/contrib.md @@ -54,7 +54,7 @@ In the rest of this document, we list routines provided by the `symbol.contrib` quantize foreach while_loop - ifelse + condition ``` ## API Reference diff --git a/python/mxnet/ndarray/contrib.py b/python/mxnet/ndarray/contrib.py index 12407cf4fe74..4e30e5e16e53 100644 --- a/python/mxnet/ndarray/contrib.py +++ b/python/mxnet/ndarray/contrib.py @@ -28,7 +28,7 @@ except ImportError: pass -__all__ = ["rand_zipfian", "foreach", "while_loop", "ifelse"] +__all__ = ["rand_zipfian", "foreach", "while_loop", "condition"] # pylint: disable=line-too-long def rand_zipfian(true_classes, num_sampled, range_max, ctx=None): @@ -363,7 +363,7 @@ def _func_wrapper(loop_vars): )) return stacked_outputs, list(loop_vars) -def ifelse(cond, then_func, else_func, inputs): +def condition(cond_func, then_func, else_func, inputs): # pylint: disable=redefined-outer-name """Run an if-then-else using user-defined condition and computation This operator simulates a if-like branch which chooses to do one of @@ -371,11 +371,11 @@ def ifelse(cond, then_func, else_func, inputs): `inputs` is a list of NDArrays on which the condition and computations rely on. - `cond` is a user-defined function, used as the if condition. + `cond_func` is a user-defined function, used as the if condition. It consumes `inputs`, and produces a scalar MXNet NDArray, indicating which branch of computation should be used. - The `cond` is variadic, and its signature should be - `cond(*loop_vars) => NDArray`. + The `cond_func` is variadic, and its signature should be + `cond_func(*loop_vars) => NDArray`. `then_func` is a user-defined function, used as computation of the then branch. It consumes `inputs`, and produces `outputs`. @@ -394,14 +394,14 @@ def ifelse(cond, then_func, else_func, inputs): Parameters ---------- - cond: a Python function. + cond_func: a Python function. The branch condition. then_func: a Python function. - The computation to be executed if `cond` is true. + The computation to be executed if `cond_func` is true. else_func: a Python function. - The computation to be executed if `cond` is false. + The computation to be executed if `cond_func` is false. inputs: list of NDArrays. - The variables fed to `cond`, `then_func` and `else_func`. + The variables fed to `cond_func`, `then_func` and `else_func`. 
Returns ------- @@ -409,11 +409,11 @@ def ifelse(cond, then_func, else_func, inputs): Examples -------- - >>> cond = lambda a, b: a * b < 5 + >>> cond_func = lambda a, b: a * b < 5 >>> then_func = lambda a, b: (a + 5) * (b + 5) >>> else_func = lambda a, b: (a - 5) * (b - 5) >>> inputs = (mx.nd.array([1]), mx.nd.array([2])) - >>> outputs = mx.nd.contrib.ifelse(cond, then_func, else_func, inputs) + >>> outputs = mx.nd.contrib.cond(cond_func, then_func, else_func, inputs) >>> outputs[0] [42.] @@ -448,7 +448,7 @@ def _to_ndarray_tuple(inputs, name): inputs = _to_ndarray_tuple(inputs, "inputs") if len(inputs) == 0: raise ValueError("inputs should contain at least one element") - branch = _to_python_scalar(cond(*inputs), bool, "Return value of cond") + branch = _to_python_scalar(cond_func(*inputs), bool, "Return value of cond_func") if branch: outputs = then_func(*inputs) outputs = _to_ndarray_tuple(outputs, "outputs of then_func") diff --git a/python/mxnet/symbol/contrib.py b/python/mxnet/symbol/contrib.py index 13bb89e8d9f1..3274b7833c47 100644 --- a/python/mxnet/symbol/contrib.py +++ b/python/mxnet/symbol/contrib.py @@ -34,7 +34,7 @@ from ..base import SymbolHandle, _as_list from ..attribute import AttrScope -__all__ = ["rand_zipfian", "foreach", "while_loop", "ifelse"] +__all__ = ["rand_zipfian", "foreach", "while_loop", "condition"] def rand_zipfian(true_classes, num_sampled, range_max): """Draw random samples from an approximately log-uniform or Zipfian distribution. @@ -557,7 +557,7 @@ def _union_inputs(*graphs): final_loop_vars = [result[i] for i in range(num_out_data, num_outputs)] return outputs, final_loop_vars -def ifelse(cond, then_func, else_func, inputs, name="ifelse"): +def condition(cond_func, then_func, else_func, inputs, name="cond"): # pylint: disable=redefined-outer-name """Run an if-then-else using user-defined condition and computation This operator simulates a if-like branch which chooses to do one of @@ -565,11 +565,11 @@ def ifelse(cond, then_func, else_func, inputs, name="ifelse"): `inputs` is a list of Symbols on which the condition and computations rely on. - `cond` is a user-defined function, used as the if condition. + `cond_func` is a user-defined function, used as the if condition. It consumes `inputs`, and produces a scalar MXNet symbol, indicating which branch of computation should be used. - The `cond` is variadic, and its signature should be - `cond(*loop_vars) => Symbol`. + The `cond_func` is variadic, and its signature should be + `cond_func(*loop_vars) => Symbol`. `then_func` is a user-defined function, used as computation of the then branch. It consumes `inputs`, and produces `outputs`. @@ -588,14 +588,14 @@ def ifelse(cond, then_func, else_func, inputs, name="ifelse"): Parameters ---------- - cond: a Python function. + cond_func: a Python function. The branch condition. then_func: a Python function. - The computation to be executed if `cond` is true. + The computation to be executed if `cond_func` is true. else_func: a Python function. - The computation to be executed if `cond` is false. + The computation to be executed if `cond_func` is false. inputs: list of Symbols. - The variables fed to `cond`, `then_func` and `else_func`. + The variables fed to `cond_func`, `then_func` and `else_func`. 
Returns ------- @@ -603,11 +603,11 @@ def ifelse(cond, then_func, else_func, inputs, name="ifelse"): Examples -------- - >>> cond = lambda a, b: a * b < 5 + >>> cond_func = lambda a, b: a * b < 5 >>> then_func = lambda a, b: (a + 5) * (b + 5) >>> else_func = lambda a, b: (a - 5) * (b - 5) >>> inputs = (mx.sym.var('a'), mx.sym.var('b')) - >>> outputs = mx.sym.contrib.ifelse(cond, then_func, else_func, inputs) + >>> outputs = mx.sym.contrib.cond(cond_func, then_func, else_func, inputs) """ def _to_symbol_tuple(inputs, name): """Converts "inputs", possibly a single mxnet Symbol, a list of mxnet Symbol, @@ -681,10 +681,10 @@ def _union_inputs(*graphs): inputs = _to_symbol_tuple(inputs, "inputs") if len(inputs) == 0: raise ValueError("loop_vars should contain at least one element") - # create graph for `cond' - cond_g, num_outputs = _create_subgraph(inputs, cond, name + "_cond") - if num_outputs != 1: - raise ValueError("cond should always produce a single output") + # create graph for `cond_func' + cond_g, cond_num_outputs = _create_subgraph(inputs, cond_func, name + "_cond") + if cond_num_outputs != 1: + raise ValueError("cond_func should always produce a single output") # create graph for `then` then_g, then_num_outputs = _create_subgraph(inputs, then_func, name + "_then") # create graph for `else` @@ -694,7 +694,7 @@ def _union_inputs(*graphs): # find symbols used in either cond_g or func_g input_syms, (cond_input_locs, then_input_locs, else_input_locs) = \ _union_inputs(cond_g, then_g, else_g) - result = symbol._internal._ifelse( + result = symbol._internal._cond( # [cond, then_g, else_g, *input_syms] cond_g, then_g, diff --git a/src/operator/control_flow.cc b/src/operator/control_flow.cc index 5159a27cb508..7c1beccb0288 100644 --- a/src/operator/control_flow.cc +++ b/src/operator/control_flow.cc @@ -913,13 +913,13 @@ WhileLoopGradient(const nnvm::NodePtr& n, const std::vector& og return entries; } -struct IfelseParam : public dmlc::Parameter { +struct CondParam : public dmlc::Parameter { int num_args; int num_outputs; nnvm::Tuple cond_input_locs; nnvm::Tuple then_input_locs; nnvm::Tuple else_input_locs; - DMLC_DECLARE_PARAMETER(IfelseParam) { + DMLC_DECLARE_PARAMETER(CondParam) { DMLC_DECLARE_FIELD(num_args).set_lower_bound(3) .describe("Number of input arguments, including cond, then and else as three symbol inputs."); DMLC_DECLARE_FIELD(num_outputs).set_lower_bound(1) @@ -931,42 +931,42 @@ struct IfelseParam : public dmlc::Parameter { DMLC_DECLARE_FIELD(else_input_locs) .describe("The locations of else's inputs in the given inputs."); } -}; // struct IfelseParam +}; // struct CondParam -DMLC_REGISTER_PARAMETER(IfelseParam); +DMLC_REGISTER_PARAMETER(CondParam); -class IfelseState { +class CondState { public: - IfelseParam params; + CondParam params; CachedOpPtr cond_op; LoopState then_branch; LoopState else_branch; int branch_selection; // 1 if then branch; 0 if else branch; -1 if undefined - IfelseState(const IfelseParam ¶ms, - const Symbol &cond, - const Symbol &then_sym, - const Symbol &else_sym): - params(params), - cond_op(LoopState::MakeSharedOp(cond)), - then_branch(then_sym), - else_branch(else_sym), - branch_selection(-1) { + CondState(const CondParam ¶ms, + const Symbol &cond, + const Symbol &then_sym, + const Symbol &else_sym): + params(params), + cond_op(LoopState::MakeSharedOp(cond)), + then_branch(then_sym), + else_branch(else_sym), + branch_selection(-1) { } }; -static void IfelseComputeExCPU(const OpStatePtr& state_ptr, - const OpContext& ctx, - const std::vector& inputs, 
- const std::vector& req, - const std::vector& outputs) { +static void CondComputeExCPU(const OpStatePtr& state_ptr, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { // The argument `inputs' are loop_vars and other inputs // loop_vars are stored in stored in `loop_vars_locs' // The argument `outputs' are output and new_loop_vars // [0: num_out_data) are outputs at each step. // [num_out_data: ) are new_loop_vars - IfelseState &state = state_ptr.get_state(); - const IfelseParam& params = state.params; + CondState &state = state_ptr.get_state(); + const CondParam& params = state.params; // a helper function, converting std::vector to std::vector const auto to_ptr_vec = [](std::vector &in, std::vector *out) { out->clear(); @@ -1005,13 +1005,13 @@ static void IfelseComputeExCPU(const OpStatePtr& state_ptr, loop_state.Forward(0, func_inputs, req, outputs, ctx.need_grad); } -static void IfelseGradComputeExCPU(const OpStatePtr& state_ptr, - const OpContext& ctx, - const std::vector& inputs, - const std::vector& _req, - const std::vector& outputs) { - IfelseState &state = state_ptr.get_state(); - const IfelseParam& params = state.params; +static void CondGradComputeExCPU(const OpStatePtr& state_ptr, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& _req, + const std::vector& outputs) { + CondState &state = state_ptr.get_state(); + const CondParam& params = state.params; // sanity checks CHECK_EQ(outputs.size() + 3U, (size_t) params.num_args); CHECK_EQ(outputs.size(), _req.size()); @@ -1034,11 +1034,11 @@ static void IfelseGradComputeExCPU(const OpStatePtr& state_ptr, loop_state.Cleanup(); } -static bool IfelseShape(const nnvm::NodeAttrs& attrs, - std::vector *in_shape, - std::vector *out_shape) { +static bool CondShape(const nnvm::NodeAttrs& attrs, + std::vector *in_shape, + std::vector *out_shape) { using nnvm::ShapeVector; - const IfelseParam& params = nnvm::get(attrs.parsed); + const CondParam& params = nnvm::get(attrs.parsed); static const std::function is_udf = is_shape_udf; // sanity checks CHECK_EQ(in_shape->size() + 3U, (size_t) params.num_args); @@ -1121,10 +1121,10 @@ static bool IfelseShape(const nnvm::NodeAttrs& attrs, return succ_0 && succ_1 && succ_2; } -static bool IfelseType(const nnvm::NodeAttrs& attrs, - std::vector *in_type, - std::vector *out_type) { - const IfelseParam& params = nnvm::get(attrs.parsed); +static bool CondType(const nnvm::NodeAttrs& attrs, + std::vector *in_type, + std::vector *out_type) { + const CondParam& params = nnvm::get(attrs.parsed); static const std::function is_udf = is_type_udf; CHECK_EQ(in_type->size() + 3U, (size_t) params.num_args); CHECK_EQ(out_type->size(), (size_t) params.num_outputs); @@ -1147,12 +1147,12 @@ static bool IfelseType(const nnvm::NodeAttrs& attrs, return succ_0 && succ_1 && succ_2; } -static bool IfelseStorageType(const nnvm::NodeAttrs& attrs, - const int dev_mask, - DispatchMode* dispatch_mode, - std::vector *in_attrs, - std::vector *out_attrs) { - const IfelseParam& params = nnvm::get(attrs.parsed); +static bool CondStorageType(const nnvm::NodeAttrs& attrs, + const int dev_mask, + DispatchMode* dispatch_mode, + std::vector *in_attrs, + std::vector *out_attrs) { + const CondParam& params = nnvm::get(attrs.parsed); static const std::function is_udf = is_stype_udf; CHECK_EQ(in_attrs->size() + 3U, (size_t) params.num_args); CHECK_EQ(out_attrs->size(), (size_t) params.num_outputs); @@ -1182,12 +1182,12 @@ static bool IfelseStorageType(const 
nnvm::NodeAttrs& attrs, return succ_0 && succ_1 && succ_2; } -static bool BackwardIfelseStorageType(const nnvm::NodeAttrs& attrs, - const int dev_mask, - DispatchMode* dispatch_mode, - std::vector *in_attrs, - std::vector *out_attrs) { - const IfelseParam& params = nnvm::get(attrs.parsed); +static bool BackwardCondStorageType(const nnvm::NodeAttrs& attrs, + const int dev_mask, + DispatchMode* dispatch_mode, + std::vector *in_attrs, + std::vector *out_attrs) { + const CondParam& params = nnvm::get(attrs.parsed); CHECK_EQ(out_attrs->size() + 3U, (size_t) params.num_args); CHECK_EQ(attrs.subgraphs.size(), 3U); static const std::function is_udf = is_stype_udf; @@ -1230,12 +1230,12 @@ static bool BackwardIfelseStorageType(const nnvm::NodeAttrs& attrs, return succ_0 && succ_1; } -static OpStatePtr CreateIfelseState(const NodeAttrs& attrs, - Context ctx, - const std::vector& ishape, - const std::vector& itype) { - const IfelseParam& params = nnvm::get(attrs.parsed); - return OpStatePtr::Create( +static OpStatePtr CreateCondState(const NodeAttrs& attrs, + Context ctx, + const std::vector& ishape, + const std::vector& itype) { + const CondParam& params = nnvm::get(attrs.parsed); + return OpStatePtr::Create( params, *attrs.subgraphs[0], *attrs.subgraphs[1], @@ -1243,8 +1243,8 @@ static OpStatePtr CreateIfelseState(const NodeAttrs& attrs, } static std::vector -IfelseGradient(const nnvm::NodePtr& n, const std::vector& ograds) { - ElemwiseGradUseInOut fgrad{"_backward_ifelse"}; +CondGradient(const nnvm::NodePtr& n, const std::vector& ograds) { + ElemwiseGradUseInOut fgrad{"_backward_cond"}; std::vector entries = fgrad(n, ograds); entries[0].node->attrs.subgraphs = n->attrs.subgraphs; return entries; @@ -1373,21 +1373,21 @@ NNVM_REGISTER_OP(_backward_while_loop) .set_attr("FStatefulComputeEx", WhileLoopGradComputeExCPU) .set_attr("FStatefulComputeEx", WhileLoopGradComputeExCPU); -NNVM_REGISTER_OP(_ifelse) +NNVM_REGISTER_OP(_cond) .MXNET_DESCRIBE("Run a if-then-else using user-defined condition and computation") -.set_attr_parser(ParamParser) -.set_attr("FInferStorageType", IfelseStorageType) +.set_attr_parser(ParamParser) +.set_attr("FInferStorageType", CondStorageType) .set_num_inputs([](const NodeAttrs& attrs) { - const IfelseParam& params = nnvm::get(attrs.parsed); + const CondParam& params = nnvm::get(attrs.parsed); return params.num_args; }) .set_num_outputs([](const NodeAttrs& attrs) { - const IfelseParam& params = nnvm::get(attrs.parsed); + const CondParam& params = nnvm::get(attrs.parsed); return params.num_outputs; }) .set_attr("FListInputNames", [](const NodeAttrs& attrs) { - const IfelseParam& params = nnvm::get(attrs.parsed); + const CondParam& params = nnvm::get(attrs.parsed); std::vector names; names.reserve(params.num_args); names.push_back("cond"); @@ -1401,40 +1401,40 @@ NNVM_REGISTER_OP(_ifelse) [](const NodeAttrs& attrs) { return std::vector{0, 1, 2}; }) -.set_attr("FGradient", IfelseGradient) -.set_attr("FCreateOpState", CreateIfelseState) -.set_attr("FInferShape", IfelseShape) -.set_attr("FInferType", IfelseType) -.set_attr("FStatefulComputeEx", IfelseComputeExCPU) +.set_attr("FGradient", CondGradient) +.set_attr("FCreateOpState", CreateCondState) +.set_attr("FInferShape", CondShape) +.set_attr("FInferType", CondType) +.set_attr("FStatefulComputeEx", CondComputeExCPU) .set_attr("FExecType", [](const NodeAttrs& attrs) { return ExecType::kSubgraphExec; }) -.set_attr("FStatefulComputeEx", IfelseComputeExCPU) +.set_attr("FStatefulComputeEx", CondComputeExCPU) 
.set_attr("key_var_num_args", "num_args") .add_argument("cond", "Symbol", "Input graph for the condition.") .add_argument("then_branch", "Symbol", "Input graph for the then branch.") .add_argument("else_branch", "Symbol", "Input graph for the else branch.") .add_argument("data", "NDArray-or-Symbol[]", "The input arrays that include data arrays and states.") -.add_arguments(IfelseParam::__FIELDS__()); +.add_arguments(CondParam::__FIELDS__()); -NNVM_REGISTER_OP(_backward_ifelse) +NNVM_REGISTER_OP(_backward_cond) .set_num_inputs([](const NodeAttrs& attrs){ - const IfelseParam& params = nnvm::get(attrs.parsed); + const CondParam& params = nnvm::get(attrs.parsed); return params.num_outputs * 2 + params.num_args - 3; }) .set_num_outputs([](const NodeAttrs& attrs){ - const IfelseParam& params = nnvm::get(attrs.parsed); + const CondParam& params = nnvm::get(attrs.parsed); return params.num_args - 3; }) .set_attr("FExecType", [](const NodeAttrs& attrs) { return ExecType::kSubgraphExec; }) -.set_attr("FInferStorageType", BackwardIfelseStorageType) -.set_attr_parser(ParamParser) +.set_attr("FInferStorageType", BackwardCondStorageType) +.set_attr_parser(ParamParser) .set_attr("TIsLayerOpBackward", true) .set_attr("TIsBackward", true) -.set_attr("FStatefulComputeEx", IfelseGradComputeExCPU) -.set_attr("FStatefulComputeEx", IfelseGradComputeExCPU); +.set_attr("FStatefulComputeEx", CondGradComputeExCPU) +.set_attr("FStatefulComputeEx", CondGradComputeExCPU); } // namespace op } // namespace mxnet diff --git a/tests/python/unittest/test_contrib_control_flow.py b/tests/python/unittest/test_contrib_control_flow.py index 12694572bb7c..87eac8960339 100644 --- a/tests/python/unittest/test_contrib_control_flow.py +++ b/tests/python/unittest/test_contrib_control_flow.py @@ -974,7 +974,7 @@ def _func(*states): y = y.asnumpy() assert_almost_equal(x, y, rtol=1e-4, atol=1e-4) -def _verify_ifelse(cond, then_func, else_func, input_var_shapes, free_var_shapes, is_train): +def _verify_cond(cond_func, then_func, else_func, input_var_shapes, free_var_shapes, is_train): def _create_symbol(prefix, i): return mx.sym.var(prefix + str(i)) @@ -1008,8 +1008,8 @@ def _get_imperative_result(): for var in free_vars + input_vars: var.attach_grad() with mx.autograd.record(train_mode=is_train): - outputs = mx.nd.contrib.ifelse( - cond=lambda *__input_vars: cond(__input_vars, free_vars), + outputs = mx.nd.contrib.condition( + cond_func=lambda *__input_vars: cond_func(__input_vars, free_vars), then_func=lambda *__input_vars: then_func(__input_vars, free_vars), else_func=lambda *__input_vars: else_func(__input_vars, free_vars), inputs=input_vars, @@ -1025,8 +1025,8 @@ def _get_imperative_result(): return _to_numpy_list(outputs), _to_numpy_list(grads), out_grads def _get_symbolic_result(out_grads): - outputs_sym = mx.sym.contrib.ifelse( - cond=lambda *__loop_vars: cond(__loop_vars, _free_syms), + outputs_sym = mx.sym.contrib.condition( + cond_func=lambda *__loop_vars: cond_func(__loop_vars, _free_syms), then_func=lambda *__loop_vars: then_func(__loop_vars, _free_syms), else_func=lambda *__loop_vars: else_func(__loop_vars, _free_syms), inputs=_input_syms, @@ -1062,7 +1062,7 @@ def _get_symbolic_result(out_grads): @with_seed() -def test_ifelse(): +def test_cond(): # whether there are free variables in three graphs # whether these three graphs contain input_vars # whether to use all input_vars @@ -1080,8 +1080,8 @@ def cond(inputs, free): return cond for is_train in [True, False]: for is_inverse in [False, True]: - _verify_ifelse( - 
cond=make_cond(is_inverse), + _verify_cond( + cond_func=make_cond(is_inverse), then_func=then_func, else_func=else_func, is_train=is_train, From 8f2a8e09be756364af44c3154b69e99d04a4cd70 Mon Sep 17 00:00:00 2001 From: Junru Shao Date: Sat, 21 Jul 2018 00:10:12 -0700 Subject: [PATCH 4/7] API change --- python/mxnet/ndarray/contrib.py | 46 +++++++----------- python/mxnet/symbol/contrib.py | 47 ++++++++----------- .../unittest/test_contrib_control_flow.py | 14 +++--- 3 files changed, 43 insertions(+), 64 deletions(-) diff --git a/python/mxnet/ndarray/contrib.py b/python/mxnet/ndarray/contrib.py index 4e30e5e16e53..45f058336eb5 100644 --- a/python/mxnet/ndarray/contrib.py +++ b/python/mxnet/ndarray/contrib.py @@ -363,45 +363,38 @@ def _func_wrapper(loop_vars): )) return stacked_outputs, list(loop_vars) -def condition(cond_func, then_func, else_func, inputs): # pylint: disable=redefined-outer-name +def condition(cond, then_func, else_func): """Run an if-then-else using user-defined condition and computation This operator simulates a if-like branch which chooses to do one of the two customized computations according to the specified condition. - `inputs` is a list of NDArrays on which the condition and computations rely on. - - `cond_func` is a user-defined function, used as the if condition. - It consumes `inputs`, and produces a scalar MXNet NDArray, + `cond` is a scalar MXNet NDArray, indicating which branch of computation should be used. - The `cond_func` is variadic, and its signature should be - `cond_func(*loop_vars) => NDArray`. `then_func` is a user-defined function, used as computation of the then branch. - It consumes `inputs`, and produces `outputs`. - The `then_func` is variadic, and its signature should be - `then_func(*loop_vars) => List[NDArray]`. + It produces `outputs`, which is a list of NDArrays. + The signature of `then_func` should be + `then_func() => List[NDArray]`. `else_func` is a user-defined function, used as computation of the else branch. - It also consumes `inputs`, and produces `outputs`. - The `else_func` is variadic, and its signature should be - `else_func(*loop_vars) => List[NDArray]`. + It produces `outputs`, which is a list of NDArrays. + The signature of `else_func` should be + `else_func() => List[NDArray]`. The `outputs` produces by `then_func` and `else_func` should have the same number of elements, all of which should be in the same shape, of the same dtype and stype. - This function returns a list of NDArrays, representing the computation result. + This function returns a list of symbols, representing the computation result. Parameters ---------- - cond_func: a Python function. + cond: a MXNet NDArray representing a scalar. The branch condition. then_func: a Python function. - The computation to be executed if `cond_func` is true. + The computation to be executed if `cond` is true. else_func: a Python function. - The computation to be executed if `cond_func` is false. - inputs: list of NDArrays. - The variables fed to `cond_func`, `then_func` and `else_func`. + The computation to be executed if `cond` is false. 
Returns ------- @@ -409,11 +402,11 @@ def condition(cond_func, then_func, else_func, inputs): # pylint: disable=redef Examples -------- - >>> cond_func = lambda a, b: a * b < 5 + >>> a, b = mx.nd.array([1]), mx.nd.array([2]) + >>> cond = a * b < 5 >>> then_func = lambda a, b: (a + 5) * (b + 5) >>> else_func = lambda a, b: (a - 5) * (b - 5) - >>> inputs = (mx.nd.array([1]), mx.nd.array([2])) - >>> outputs = mx.nd.contrib.cond(cond_func, then_func, else_func, inputs) + >>> outputs = mx.nd.contrib.cond(cond, then_func, else_func) >>> outputs[0] [42.] @@ -445,14 +438,11 @@ def _to_ndarray_tuple(inputs, name): raise ValueError("%s must be an NDArray, or a tuple or list of NDArrays" % (name, )) return inputs - inputs = _to_ndarray_tuple(inputs, "inputs") - if len(inputs) == 0: - raise ValueError("inputs should contain at least one element") - branch = _to_python_scalar(cond_func(*inputs), bool, "Return value of cond_func") + branch = _to_python_scalar(cond, bool, "cond") if branch: - outputs = then_func(*inputs) + outputs = then_func() outputs = _to_ndarray_tuple(outputs, "outputs of then_func") else: - outputs = else_func(*inputs) + outputs = else_func() outputs = _to_ndarray_tuple(outputs, "outputs of else_func") return list(outputs) diff --git a/python/mxnet/symbol/contrib.py b/python/mxnet/symbol/contrib.py index 3274b7833c47..2d8b034533dc 100644 --- a/python/mxnet/symbol/contrib.py +++ b/python/mxnet/symbol/contrib.py @@ -557,29 +557,24 @@ def _union_inputs(*graphs): final_loop_vars = [result[i] for i in range(num_out_data, num_outputs)] return outputs, final_loop_vars -def condition(cond_func, then_func, else_func, inputs, name="cond"): # pylint: disable=redefined-outer-name +def condition(cond, then_func, else_func, name="cond"): """Run an if-then-else using user-defined condition and computation This operator simulates a if-like branch which chooses to do one of the two customized computations according to the specified condition. - `inputs` is a list of Symbols on which the condition and computations rely on. - - `cond_func` is a user-defined function, used as the if condition. - It consumes `inputs`, and produces a scalar MXNet symbol, + `cond` is a scalar MXNet Symbol, indicating which branch of computation should be used. - The `cond_func` is variadic, and its signature should be - `cond_func(*loop_vars) => Symbol`. `then_func` is a user-defined function, used as computation of the then branch. - It consumes `inputs`, and produces `outputs`. - The `then_func` is variadic, and its signature should be - `then_func(*loop_vars) => List[Symbol]`. + It produces `outputs`, which is a list of Symbols. + The signature of `then_func` should be + `then_func() => List[Symbol]`. `else_func` is a user-defined function, used as computation of the else branch. - It also consumes `inputs`, and produces `outputs`. - The `else_func` is variadic, and its signature should be - `else_func(*loop_vars) => List[Symbol]`. + It produces `outputs`, which is a list of Symbols. + The signature of `else_func` should be + `else_func() => List[Symbol]`. The `outputs` produces by `then_func` and `else_func` should have the same number of elements, all of which should be in the same shape, of the same dtype and stype. @@ -588,14 +583,12 @@ def condition(cond_func, then_func, else_func, inputs, name="cond"): # pylint: Parameters ---------- - cond_func: a Python function. + cond: a MXNet Symbol representing a scalar. The branch condition. then_func: a Python function. 
- The computation to be executed if `cond_func` is true. + The computation to be executed if `cond` is true. else_func: a Python function. - The computation to be executed if `cond_func` is false. - inputs: list of Symbols. - The variables fed to `cond_func`, `then_func` and `else_func`. + The computation to be executed if `cond` is false. Returns ------- @@ -603,11 +596,11 @@ def condition(cond_func, then_func, else_func, inputs, name="cond"): # pylint: Examples -------- - >>> cond_func = lambda a, b: a * b < 5 - >>> then_func = lambda a, b: (a + 5) * (b + 5) - >>> else_func = lambda a, b: (a - 5) * (b - 5) - >>> inputs = (mx.sym.var('a'), mx.sym.var('b')) - >>> outputs = mx.sym.contrib.cond(cond_func, then_func, else_func, inputs) + >>> a, b = mx.sym.var('a'), mx.sym.var('b') + >>> cond = a * b < 5 + >>> then_func = lambda: (a + 5) * (b + 5) + >>> else_func = lambda: (a - 5) * (b - 5) + >>> outputs = mx.sym.contrib.cond(cond, then_func, else_func) """ def _to_symbol_tuple(inputs, name): """Converts "inputs", possibly a single mxnet Symbol, a list of mxnet Symbol, @@ -678,13 +671,11 @@ def _union_inputs(*graphs): input_locs.append(loc) locs.append(input_locs) return inputs, locs - inputs = _to_symbol_tuple(inputs, "inputs") - if len(inputs) == 0: - raise ValueError("loop_vars should contain at least one element") + inputs = [] # create graph for `cond_func' - cond_g, cond_num_outputs = _create_subgraph(inputs, cond_func, name + "_cond") + cond_g, cond_num_outputs = _create_subgraph(inputs, lambda: cond, name + "_cond") if cond_num_outputs != 1: - raise ValueError("cond_func should always produce a single output") + raise ValueError("cond should always be a single output") # create graph for `then` then_g, then_num_outputs = _create_subgraph(inputs, then_func, name + "_then") # create graph for `else` diff --git a/tests/python/unittest/test_contrib_control_flow.py b/tests/python/unittest/test_contrib_control_flow.py index 87eac8960339..5323f4ae75b4 100644 --- a/tests/python/unittest/test_contrib_control_flow.py +++ b/tests/python/unittest/test_contrib_control_flow.py @@ -1009,10 +1009,9 @@ def _get_imperative_result(): var.attach_grad() with mx.autograd.record(train_mode=is_train): outputs = mx.nd.contrib.condition( - cond_func=lambda *__input_vars: cond_func(__input_vars, free_vars), - then_func=lambda *__input_vars: then_func(__input_vars, free_vars), - else_func=lambda *__input_vars: else_func(__input_vars, free_vars), - inputs=input_vars, + cond=cond_func(input_vars, free_vars), + then_func=lambda: then_func(input_vars, free_vars), + else_func=lambda: else_func(input_vars, free_vars), ) outputs = [x * 2 for x in outputs] grads = [] @@ -1026,10 +1025,9 @@ def _get_imperative_result(): def _get_symbolic_result(out_grads): outputs_sym = mx.sym.contrib.condition( - cond_func=lambda *__loop_vars: cond_func(__loop_vars, _free_syms), - then_func=lambda *__loop_vars: then_func(__loop_vars, _free_syms), - else_func=lambda *__loop_vars: else_func(__loop_vars, _free_syms), - inputs=_input_syms, + cond=cond_func(_input_syms, _free_syms), + then_func=lambda: then_func(_input_syms, _free_syms), + else_func=lambda: else_func(_input_syms, _free_syms), ) outputs_sym = [x * 2 for x in outputs_sym] outputs_sym = mx.sym.Group(outputs_sym) From b2e70144741c4e5c75e466712f79b598ed7b5356 Mon Sep 17 00:00:00 2001 From: Junru Shao Date: Sat, 21 Jul 2018 10:03:51 -0700 Subject: [PATCH 5/7] Trigger CI From 5a5a4b4808b1507d452c1b0e02dccc5dfe9e92b6 Mon Sep 17 00:00:00 2001 From: Junru Shao Date: Mon, 23 Jul 2018 
16:57:49 -0700 Subject: [PATCH 6/7] Rename condition to cond --- docs/api/python/ndarray/contrib.md | 2 +- docs/api/python/symbol/contrib.md | 2 +- python/mxnet/ndarray/contrib.py | 20 +++++++++---------- python/mxnet/symbol/contrib.py | 20 +++++++++---------- .../unittest/test_contrib_control_flow.py | 13 ++++++------ 5 files changed, 29 insertions(+), 28 deletions(-) diff --git a/docs/api/python/ndarray/contrib.md b/docs/api/python/ndarray/contrib.md index f575bc8e7ce2..97f38a53a366 100644 --- a/docs/api/python/ndarray/contrib.md +++ b/docs/api/python/ndarray/contrib.md @@ -54,7 +54,7 @@ In the rest of this document, we list routines provided by the `ndarray.contrib` quantize foreach while_loop - condition + cond ``` ## API Reference diff --git a/docs/api/python/symbol/contrib.md b/docs/api/python/symbol/contrib.md index 69d38beffd1c..c0a4da54cbde 100644 --- a/docs/api/python/symbol/contrib.md +++ b/docs/api/python/symbol/contrib.md @@ -54,7 +54,7 @@ In the rest of this document, we list routines provided by the `symbol.contrib` quantize foreach while_loop - condition + cond ``` ## API Reference diff --git a/python/mxnet/ndarray/contrib.py b/python/mxnet/ndarray/contrib.py index 45f058336eb5..aae898a3b7a2 100644 --- a/python/mxnet/ndarray/contrib.py +++ b/python/mxnet/ndarray/contrib.py @@ -16,7 +16,7 @@ # under the License. # coding: utf-8 -# pylint: disable=wildcard-import, unused-wildcard-import +# pylint: disable=wildcard-import, unused-wildcard-import,redefined-outer-name """Contrib NDArray API of MXNet.""" import math from ..context import current_context @@ -28,7 +28,7 @@ except ImportError: pass -__all__ = ["rand_zipfian", "foreach", "while_loop", "condition"] +__all__ = ["rand_zipfian", "foreach", "while_loop", "cond"] # pylint: disable=line-too-long def rand_zipfian(true_classes, num_sampled, range_max, ctx=None): @@ -363,13 +363,13 @@ def _func_wrapper(loop_vars): )) return stacked_outputs, list(loop_vars) -def condition(cond, then_func, else_func): +def cond(pred, then_func, else_func): """Run an if-then-else using user-defined condition and computation This operator simulates a if-like branch which chooses to do one of the two customized computations according to the specified condition. - `cond` is a scalar MXNet NDArray, + `pred` is a scalar MXNet NDArray, indicating which branch of computation should be used. `then_func` is a user-defined function, used as computation of the then branch. @@ -389,12 +389,12 @@ def condition(cond, then_func, else_func): Parameters ---------- - cond: a MXNet NDArray representing a scalar. + pred: a MXNet NDArray representing a scalar. The branch condition. then_func: a Python function. - The computation to be executed if `cond` is true. + The computation to be executed if `pred` is true. else_func: a Python function. - The computation to be executed if `cond` is false. + The computation to be executed if `pred` is false. Returns ------- @@ -403,10 +403,10 @@ def condition(cond, then_func, else_func): Examples -------- >>> a, b = mx.nd.array([1]), mx.nd.array([2]) - >>> cond = a * b < 5 + >>> pred = a * b < 5 >>> then_func = lambda a, b: (a + 5) * (b + 5) >>> else_func = lambda a, b: (a - 5) * (b - 5) - >>> outputs = mx.nd.contrib.cond(cond, then_func, else_func) + >>> outputs = mx.nd.contrib.cond(pred, then_func, else_func) >>> outputs[0] [42.] 
@@ -438,7 +438,7 @@ def _to_ndarray_tuple(inputs, name): raise ValueError("%s must be an NDArray, or a tuple or list of NDArrays" % (name, )) return inputs - branch = _to_python_scalar(cond, bool, "cond") + branch = _to_python_scalar(pred, bool, "pred") if branch: outputs = then_func() outputs = _to_ndarray_tuple(outputs, "outputs of then_func") diff --git a/python/mxnet/symbol/contrib.py b/python/mxnet/symbol/contrib.py index 2d8b034533dc..4ecd31456a69 100644 --- a/python/mxnet/symbol/contrib.py +++ b/python/mxnet/symbol/contrib.py @@ -34,7 +34,7 @@ from ..base import SymbolHandle, _as_list from ..attribute import AttrScope -__all__ = ["rand_zipfian", "foreach", "while_loop", "condition"] +__all__ = ["rand_zipfian", "foreach", "while_loop", "cond"] def rand_zipfian(true_classes, num_sampled, range_max): """Draw random samples from an approximately log-uniform or Zipfian distribution. @@ -557,13 +557,13 @@ def _union_inputs(*graphs): final_loop_vars = [result[i] for i in range(num_out_data, num_outputs)] return outputs, final_loop_vars -def condition(cond, then_func, else_func, name="cond"): +def cond(pred, then_func, else_func, name="cond"): """Run an if-then-else using user-defined condition and computation This operator simulates a if-like branch which chooses to do one of the two customized computations according to the specified condition. - `cond` is a scalar MXNet Symbol, + `pred` is a scalar MXNet Symbol, indicating which branch of computation should be used. `then_func` is a user-defined function, used as computation of the then branch. @@ -583,12 +583,12 @@ def condition(cond, then_func, else_func, name="cond"): Parameters ---------- - cond: a MXNet Symbol representing a scalar. + pred: a MXNet Symbol representing a scalar. The branch condition. then_func: a Python function. - The computation to be executed if `cond` is true. + The computation to be executed if `pred` is true. else_func: a Python function. - The computation to be executed if `cond` is false. + The computation to be executed if `pred` is false. 
Returns ------- @@ -597,10 +597,10 @@ def condition(cond, then_func, else_func, name="cond"): Examples -------- >>> a, b = mx.sym.var('a'), mx.sym.var('b') - >>> cond = a * b < 5 + >>> pred = a * b < 5 >>> then_func = lambda: (a + 5) * (b + 5) >>> else_func = lambda: (a - 5) * (b - 5) - >>> outputs = mx.sym.contrib.cond(cond, then_func, else_func) + >>> outputs = mx.sym.contrib.cond(pred, then_func, else_func) """ def _to_symbol_tuple(inputs, name): """Converts "inputs", possibly a single mxnet Symbol, a list of mxnet Symbol, @@ -673,9 +673,9 @@ def _union_inputs(*graphs): return inputs, locs inputs = [] # create graph for `cond_func' - cond_g, cond_num_outputs = _create_subgraph(inputs, lambda: cond, name + "_cond") + cond_g, cond_num_outputs = _create_subgraph(inputs, lambda: pred, name + "_pred") if cond_num_outputs != 1: - raise ValueError("cond should always be a single output") + raise ValueError("pred should always be a single output") # create graph for `then` then_g, then_num_outputs = _create_subgraph(inputs, then_func, name + "_then") # create graph for `else` diff --git a/tests/python/unittest/test_contrib_control_flow.py b/tests/python/unittest/test_contrib_control_flow.py index 5323f4ae75b4..f758b3952c08 100644 --- a/tests/python/unittest/test_contrib_control_flow.py +++ b/tests/python/unittest/test_contrib_control_flow.py @@ -1008,8 +1008,8 @@ def _get_imperative_result(): for var in free_vars + input_vars: var.attach_grad() with mx.autograd.record(train_mode=is_train): - outputs = mx.nd.contrib.condition( - cond=cond_func(input_vars, free_vars), + outputs = mx.nd.contrib.cond( + pred=cond_func(input_vars, free_vars), then_func=lambda: then_func(input_vars, free_vars), else_func=lambda: else_func(input_vars, free_vars), ) @@ -1024,8 +1024,8 @@ def _get_imperative_result(): return _to_numpy_list(outputs), _to_numpy_list(grads), out_grads def _get_symbolic_result(out_grads): - outputs_sym = mx.sym.contrib.condition( - cond=cond_func(_input_syms, _free_syms), + outputs_sym = mx.sym.contrib.cond( + pred=cond_func(_input_syms, _free_syms), then_func=lambda: then_func(_input_syms, _free_syms), else_func=lambda: else_func(_input_syms, _free_syms), ) @@ -1130,5 +1130,6 @@ def cond(inputs, free): ) if __name__ == '__main__': - import nose - nose.runmodule() + # import nose + # nose.runmodule() + test_cond() From cc355e2b6db06d1b703070232054cd9a5bf4bdfb Mon Sep 17 00:00:00 2001 From: Junru Shao Date: Mon, 23 Jul 2018 17:08:10 -0700 Subject: [PATCH 7/7] Fix lint --- python/mxnet/symbol/contrib.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/mxnet/symbol/contrib.py b/python/mxnet/symbol/contrib.py index 4ecd31456a69..884288364b3d 100644 --- a/python/mxnet/symbol/contrib.py +++ b/python/mxnet/symbol/contrib.py @@ -16,7 +16,7 @@ # under the License. # coding: utf-8 -# pylint: disable=wildcard-import, unused-wildcard-import +# pylint: disable=wildcard-import, unused-wildcard-import,redefined-outer-name """Contrib Symbol API of MXNet.""" import math import ctypes
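With the whole series applied, the contrib API takes a scalar pred plus two zero-argument callables that close over whatever NDArrays or Symbols they need, and the operator collects the union of the subgraph inputs itself. A short end-to-end sketch of the symbolic path, modeled on the docstring example and the unit test above (variable names, shapes and the CPU context are illustrative):

import mxnet as mx

a = mx.sym.var("a")
b = mx.sym.var("b")
outputs = mx.sym.contrib.cond(
    pred=a * b < 5,
    then_func=lambda: (a + 5) * (b + 5),
    else_func=lambda: (a - 5) * (b - 5),
)
net = mx.sym.Group(outputs)
exe = net.bind(ctx=mx.cpu(),
               args={"a": mx.nd.array([1.0]), "b": mx.nd.array([2.0])})
# 1 * 2 < 5, so the then branch runs and prints [ 42.]
print(exe.forward(is_train=False)[0].asnumpy())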