Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
add doc and remove example code.
Browse files Browse the repository at this point in the history
  • Loading branch information
zheng-da committed May 19, 2018
1 parent dacd6a0 commit b7339c2
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 148 deletions.
146 changes: 1 addition & 145 deletions python/mxnet/gluon/contrib/rnn/rnn_cell.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,10 @@

# coding: utf-8
"""Definition of various recurrent neural network cells."""
__all__ = ['VariationalDropoutCell', 'LSTMPCell', 'SymHybridRNNCell', 'RNNCell']
__all__ = ['VariationalDropoutCell', 'LSTMPCell']

import inspect

from .... import symbol, ndarray
from ....base import _as_list
from ...rnn import BidirectionalCell, SequentialRNNCell, ModifierCell, HybridRecurrentCell
from ...rnn.rnn_cell import _format_sequence, _get_begin_state, _mask_sequence_variable_length
from ...rnn.rnn_cell import RNNCell as GluonRNNCell
from ... import tensor_types

class VariationalDropoutCell(ModifierCell):
Expand Down Expand Up @@ -320,142 +315,3 @@ def hybrid_forward(self, F, inputs, states, i2h_weight,

return next_r, [next_r, next_c]
# pylint: enable= arguments-differ

class SymHybridRNNCell(HybridRecurrentCell):
    """Hybrid recurrent cell whose `unroll` drives the time loop with `F.contrib.foreach`.

    Subclasses provide the per-step computation; this class only changes how a
    non-list input sequence is unrolled.
    """
    def __init__(self, prefix=None, params=None):
        super(SymHybridRNNCell, self).__init__(prefix=prefix, params=params)

    def unroll(self, inputs, begin_state=None, layout='TNC',
               merge_outputs=None, valid_length=None):
        """Unroll the cell over `inputs`.

        Parameters
        ----------
        inputs : Symbol, NDArray or list
            The input sequence. A list is delegated to the parent class'
            `unroll` with `len(inputs)` as the sequence length. Anything else
            is iterated by `F.contrib.foreach`.
        begin_state : list or None
            Initial states, resolved through `_get_begin_state`.
        layout : str, default 'TNC'
            Position of 'N' gives the batch axis used to read the batch size
            from an NDArray input.
            NOTE(review): foreach iterates over dimension 0, so the foreach
            path presumably expects time to be the first axis -- confirm.
        merge_outputs : bool or None
            Only forwarded to the parent class when `inputs` is a list.
        valid_length : Symbol, NDArray or None
            Only forwarded to the parent class when `inputs` is a list.

        Returns
        -------
        outputs
            First result of `F.contrib.foreach` (or of the parent `unroll`).
        last_states
            Second result of `F.contrib.foreach` (or of the parent `unroll`).
        """
        # A list of per-step inputs is handled by the parent class. The bound
        # super() method already carries `self`, so it must not be passed again.
        if isinstance(inputs, list):
            return super(SymHybridRNNCell, self).unroll(len(inputs), inputs, begin_state,
                                                        layout, merge_outputs, valid_length)

        self.reset()
        batch_size = 0
        if isinstance(inputs, symbol.Symbol):
            F = symbol
        else:
            batch_size = inputs.shape[layout.find('N')]
            F = ndarray
        begin_state = _get_begin_state(self, F, begin_state, inputs, batch_size)

        def iter_func(step_input, states):
            return self(step_input, states)

        # NOTE: valid_length masking and merge_outputs formatting are not
        # applied on this path; the results of foreach are returned as-is.
        outputs, last_states = F.contrib.foreach(iter_func, inputs, begin_state)
        return outputs, last_states

class RNNCell(SymHybridRNNCell):
    r"""Elman recurrent neural network cell.

    One step of the cell evaluates

    .. math::

        h_t = act(w_{ih} * x_t + b_{ih} + w_{hh} * h_{(t-1)} + b_{hh})

    where :math:`h_t` is the hidden state at time `t`, :math:`x_t` is the input
    at time `t`, and `act` is the function selected by `activation`
    (`tanh` by default).

    Parameters
    ----------
    hidden_size : int
        Number of hidden units; also the size of the output.
    activation : str or Symbol, default 'tanh'
        Activation applied to the sum of the two projections.
    i2h_weight_initializer : str or Initializer
        Initializer for the input-to-hidden weight matrix.
    h2h_weight_initializer : str or Initializer
        Initializer for the hidden-to-hidden (recurrent) weight matrix.
    i2h_bias_initializer : str or Initializer, default 'zeros'
        Initializer for the input-to-hidden bias vector.
    h2h_bias_initializer : str or Initializer, default 'zeros'
        Initializer for the hidden-to-hidden bias vector.
    input_size : int, default 0
        Second dimension of the input-to-hidden weight matrix. Deferred
        initialization is allowed for all parameters.
    prefix : str or None
        Prefix for the name of the `Block`
        (and for weight names if `params` is `None`).
    params : Parameter or None
        Container for weight sharing between cells. Created if `None`.

    Inputs:
        - **data**: input tensor with shape `(batch_size, input_size)`.
        - **states**: a list of one recurrent state tensor with shape
          `(batch_size, num_hidden)`.

    Outputs:
        - **out**: output tensor with shape `(batch_size, num_hidden)`.
        - **next_states**: a list holding that same output tensor as the
          next recurrent state.
    """
    def __init__(self, hidden_size, activation='tanh',
                 i2h_weight_initializer=None, h2h_weight_initializer=None,
                 i2h_bias_initializer='zeros', h2h_bias_initializer='zeros',
                 input_size=0, prefix=None, params=None):
        super(RNNCell, self).__init__(prefix=prefix, params=params)
        self._hidden_size = hidden_size
        self._activation = activation
        self._input_size = input_size

        i2h_shape = (hidden_size, input_size)
        h2h_shape = (hidden_size, hidden_size)
        bias_shape = (hidden_size,)

        # Parameters are registered in the order: i2h weight, h2h weight,
        # i2h bias, h2h bias.
        self.i2h_weight = self.params.get(
            'i2h_weight', shape=i2h_shape,
            init=i2h_weight_initializer, allow_deferred_init=True)
        self.h2h_weight = self.params.get(
            'h2h_weight', shape=h2h_shape,
            init=h2h_weight_initializer, allow_deferred_init=True)
        self.i2h_bias = self.params.get(
            'i2h_bias', shape=bias_shape,
            init=i2h_bias_initializer, allow_deferred_init=True)
        self.h2h_bias = self.params.get(
            'h2h_bias', shape=bias_shape,
            init=h2h_bias_initializer, allow_deferred_init=True)

    def state_info(self, batch_size=0):
        """Describe the single recurrent state: shape `(batch_size, hidden_size)`, layout 'NC'."""
        hidden_state = {'shape': (batch_size, self._hidden_size), '__layout__': 'NC'}
        return [hidden_state]

    def _alias(self):
        """Return the short alias of this cell type."""
        return 'rnn'

    def __repr__(self):
        """Render as `ClassName(in -> out, activation)`; an input size of 0 prints as None."""
        pieces = ['{name}({mapping}']
        if hasattr(self, '_activation'):
            pieces.append(', {_activation}')
        pieces.append(')')
        template = ''.join(pieces)

        weight_shape = self.i2h_weight.shape
        in_units = weight_shape[1] or None
        mapping = '{0} -> {1}'.format(in_units, weight_shape[0])
        return template.format(name=self.__class__.__name__,
                               mapping=mapping,
                               **self.__dict__)

    def hybrid_forward(self, F, inputs, states, i2h_weight,
                       h2h_weight, i2h_bias, h2h_bias):
        """Run one step: project input and previous state, add them, apply the activation."""
        # Operator names are tagged with the current step counter.
        step_tag = 't%d_' % self._counter
        input_proj = F.FullyConnected(data=inputs,
                                      weight=i2h_weight,
                                      bias=i2h_bias,
                                      num_hidden=self._hidden_size,
                                      name=step_tag + 'i2h')
        hidden_proj = F.FullyConnected(data=states[0],
                                       weight=h2h_weight,
                                       bias=h2h_bias,
                                       num_hidden=self._hidden_size,
                                       name=step_tag + 'h2h')
        activated = self._get_activation(F, input_proj + hidden_proj,
                                         self._activation,
                                         name=step_tag + 'out')

        # The output doubles as the next recurrent state.
        return activated, [activated]
62 changes: 59 additions & 3 deletions python/mxnet/ndarray/contrib.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,12 +99,68 @@ def rand_zipfian(true_classes, num_sampled, range_max, ctx=None):
return sampled_classes, expected_count_true, expected_count_sampled
# pylint: enable=line-too-long

def foreach(func, input, init_states, back_prop=False, name="foreach"):
def foreach(func, data, init_states, name="foreach"):
"""Run a for loop with user-defined computation over NDArrays on dimension 0.
This operator simulates a for loop and func has the computation for an iteration
of the for loop. It runs the computation in func on each slice from the input
NDArrays.
func takes two arguments as input and outputs a tuple of two elements,
as illustrated below:
out, states = func(data1, states)
data1 can be either an NDArray or a list of NDArrays. If data is an NDArray,
data1 is an NDArray. Otherwise, data1 is a list of NDArrays and has the same
size as data. states is a list of NDArrays and has the same size as init_states.
Similarly, out can be either an NDArray or a list of NDArrays, which are concatenated
as the first output of foreach; states from the last execution of func
are the second output of foreach.
The computation done by this operator is equivalent to the pseudo code below
when the input data is NDArray:
states = init_states
outs = []
for i in range(data.shape[0]):
s = data[i]
out, states = func(s, states)
outs.append(out)
outs = stack(*outs)
Parameters
----------
func : a Python function.
Define computation in an iteration.
data: an NDArray or a list of NDArrays.
The input data.
init_states: a list of NDArrays.
The initial values of the loop states.
name: string.
The name of the operator.
Returns
-------
outputs: an NDArray or a list of NDArrays.
The output data concatenated from the output of all iterations.
states: a list of NDArrays.
The loop states in the last iteration.
Examples
--------
>>> step = lambda data, states: (data + states[0], [states[0] * 2])
>>> data = mx.nd.random.uniform(shape=(2, 10))
>>> states = [mx.nd.random.uniform(shape=(10))]
>>> outs, states = mx.nd.contrib.foreach(step, data, states)
"""

assert isinstance(init_states, list), "init_states should be a list"
states = init_states
outputs = []
for i in range(input.shape[0]):
ele = input[i]
for i in range(data.shape[0]):
ele = data[i]
outs, states = func(ele, states)
outs = _as_list(outs)
if (i == 0):
Expand Down

0 comments on commit b7339c2

Please sign in to comment.