Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Rename ifelse to condition
Browse files Browse the repository at this point in the history
  • Loading branch information
junrushao committed Jul 20, 2018
1 parent 2b46358 commit b28bacf
Show file tree
Hide file tree
Showing 6 changed files with 115 additions and 115 deletions.
2 changes: 1 addition & 1 deletion docs/api/python/ndarray/contrib.md
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ In the rest of this document, we list routines provided by the `ndarray.contrib`
quantize
foreach
while_loop
ifelse
condition
```

## API Reference
Expand Down
2 changes: 1 addition & 1 deletion docs/api/python/symbol/contrib.md
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ In the rest of this document, we list routines provided by the `symbol.contrib`
quantize
foreach
while_loop
ifelse
condition
```

## API Reference
Expand Down
24 changes: 12 additions & 12 deletions python/mxnet/ndarray/contrib.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
except ImportError:
pass

__all__ = ["rand_zipfian", "foreach", "while_loop", "ifelse"]
__all__ = ["rand_zipfian", "foreach", "while_loop", "condition"]

# pylint: disable=line-too-long
def rand_zipfian(true_classes, num_sampled, range_max, ctx=None):
Expand Down Expand Up @@ -363,19 +363,19 @@ def _func_wrapper(loop_vars):
))
return stacked_outputs, list(loop_vars)

def ifelse(cond, then_func, else_func, inputs):
def condition(cond_func, then_func, else_func, inputs): # pylint: disable=redefined-outer-name
"""Run an if-then-else using user-defined condition and computation
This operator simulates a if-like branch which chooses to do one of
the two customized computations according to the specified condition.
`inputs` is a list of NDArrays on which the condition and computations rely on.
`cond` is a user-defined function, used as the if condition.
`cond_func` is a user-defined function, used as the if condition.
It consumes `inputs`, and produces a scalar MXNet NDArray,
indicating which branch of computation should be used.
The `cond` is variadic, and its signature should be
`cond(*loop_vars) => NDArray`.
The `cond_func` is variadic, and its signature should be
`cond_func(*loop_vars) => NDArray`.
`then_func` is a user-defined function, used as computation of the then branch.
It consumes `inputs`, and produces `outputs`.
Expand All @@ -394,26 +394,26 @@ def ifelse(cond, then_func, else_func, inputs):
Parameters
----------
cond: a Python function.
cond_func: a Python function.
The branch condition.
then_func: a Python function.
The computation to be executed if `cond` is true.
The computation to be executed if `cond_func` is true.
else_func: a Python function.
The computation to be executed if `cond` is false.
The computation to be executed if `cond_func` is false.
inputs: list of NDArrays.
The variables fed to `cond`, `then_func` and `else_func`.
The variables fed to `cond_func`, `then_func` and `else_func`.
Returns
-------
outputs: a list of NDArrays, representing the result of computation.
Examples
--------
>>> cond = lambda a, b: a * b < 5
>>> cond_func = lambda a, b: a * b < 5
>>> then_func = lambda a, b: (a + 5) * (b + 5)
>>> else_func = lambda a, b: (a - 5) * (b - 5)
>>> inputs = (mx.nd.array([1]), mx.nd.array([2]))
>>> outputs = mx.nd.contrib.ifelse(cond, then_func, else_func, inputs)
>>> outputs = mx.nd.contrib.cond(cond_func, then_func, else_func, inputs)
>>> outputs[0]
[42.]
<NDArray 1 @cpu(0)>
Expand Down Expand Up @@ -448,7 +448,7 @@ def _to_ndarray_tuple(inputs, name):
inputs = _to_ndarray_tuple(inputs, "inputs")
if len(inputs) == 0:
raise ValueError("inputs should contain at least one element")
branch = _to_python_scalar(cond(*inputs), bool, "Return value of cond")
branch = _to_python_scalar(cond_func(*inputs), bool, "Return value of cond_func")
if branch:
outputs = then_func(*inputs)
outputs = _to_ndarray_tuple(outputs, "outputs of then_func")
Expand Down
32 changes: 16 additions & 16 deletions python/mxnet/symbol/contrib.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
from ..base import SymbolHandle, _as_list
from ..attribute import AttrScope

__all__ = ["rand_zipfian", "foreach", "while_loop", "ifelse"]
__all__ = ["rand_zipfian", "foreach", "while_loop", "condition"]

def rand_zipfian(true_classes, num_sampled, range_max):
"""Draw random samples from an approximately log-uniform or Zipfian distribution.
Expand Down Expand Up @@ -557,19 +557,19 @@ def _union_inputs(*graphs):
final_loop_vars = [result[i] for i in range(num_out_data, num_outputs)]
return outputs, final_loop_vars

def ifelse(cond, then_func, else_func, inputs, name="ifelse"):
def condition(cond_func, then_func, else_func, inputs, name="cond"): # pylint: disable=redefined-outer-name
"""Run an if-then-else using user-defined condition and computation
This operator simulates a if-like branch which chooses to do one of
the two customized computations according to the specified condition.
`inputs` is a list of Symbols on which the condition and computations rely on.
`cond` is a user-defined function, used as the if condition.
`cond_func` is a user-defined function, used as the if condition.
It consumes `inputs`, and produces a scalar MXNet symbol,
indicating which branch of computation should be used.
The `cond` is variadic, and its signature should be
`cond(*loop_vars) => Symbol`.
The `cond_func` is variadic, and its signature should be
`cond_func(*loop_vars) => Symbol`.
`then_func` is a user-defined function, used as computation of the then branch.
It consumes `inputs`, and produces `outputs`.
Expand All @@ -588,26 +588,26 @@ def ifelse(cond, then_func, else_func, inputs, name="ifelse"):
Parameters
----------
cond: a Python function.
cond_func: a Python function.
The branch condition.
then_func: a Python function.
The computation to be executed if `cond` is true.
The computation to be executed if `cond_func` is true.
else_func: a Python function.
The computation to be executed if `cond` is false.
The computation to be executed if `cond_func` is false.
inputs: list of Symbols.
The variables fed to `cond`, `then_func` and `else_func`.
The variables fed to `cond_func`, `then_func` and `else_func`.
Returns
-------
outputs: a list of Symbols, representing the result of computation.
Examples
--------
>>> cond = lambda a, b: a * b < 5
>>> cond_func = lambda a, b: a * b < 5
>>> then_func = lambda a, b: (a + 5) * (b + 5)
>>> else_func = lambda a, b: (a - 5) * (b - 5)
>>> inputs = (mx.sym.var('a'), mx.sym.var('b'))
>>> outputs = mx.sym.contrib.ifelse(cond, then_func, else_func, inputs)
>>> outputs = mx.sym.contrib.cond(cond_func, then_func, else_func, inputs)
"""
def _to_symbol_tuple(inputs, name):
"""Converts "inputs", possibly a single mxnet Symbol, a list of mxnet Symbol,
Expand Down Expand Up @@ -681,10 +681,10 @@ def _union_inputs(*graphs):
inputs = _to_symbol_tuple(inputs, "inputs")
if len(inputs) == 0:
raise ValueError("loop_vars should contain at least one element")
# create graph for `cond'
cond_g, num_outputs = _create_subgraph(inputs, cond, name + "_cond")
if num_outputs != 1:
raise ValueError("cond should always produce a single output")
# create graph for `cond_func'
cond_g, cond_num_outputs = _create_subgraph(inputs, cond_func, name + "_cond")
if cond_num_outputs != 1:
raise ValueError("cond_func should always produce a single output")
# create graph for `then`
then_g, then_num_outputs = _create_subgraph(inputs, then_func, name + "_then")
# create graph for `else`
Expand All @@ -694,7 +694,7 @@ def _union_inputs(*graphs):
# find symbols used in either cond_g or func_g
input_syms, (cond_input_locs, then_input_locs, else_input_locs) = \
_union_inputs(cond_g, then_g, else_g)
result = symbol._internal._ifelse(
result = symbol._internal._cond(
# [cond, then_g, else_g, *input_syms]
cond_g,
then_g,
Expand Down
Loading

0 comments on commit b28bacf

Please sign in to comment.