
Commit

Merge pull request #36 from antinucleon/op
Add BatchNorm, change interface
antinucleon committed Sep 2, 2015
2 parents 61aa571 + a498f6e commit 4cbf9de
Showing 29 changed files with 773 additions and 200 deletions.
6 changes: 4 additions & 2 deletions Makefile
@@ -63,14 +63,14 @@ endif

BIN = tests/test_simple_engine
OBJ = narray_function_cpu.o
OBJCXX11 = reshape_cpu.o dag_engine.o simple_engine.o narray.o c_api.o operator.o symbol.o storage.o fully_connected_cpu.o static_graph.o activation_cpu.o graph_executor.o softmax_cpu.o elementwise_sum_cpu.o pooling_cpu.o convolution_cpu.o io.o iter_mnist.o
OBJCXX11 = batch_norm_cpu.o reshape_cpu.o dag_engine.o simple_engine.o narray.o c_api.o operator.o symbol.o storage.o fully_connected_cpu.o static_graph.o activation_cpu.o graph_executor.o softmax_cpu.o elementwise_sum_cpu.o pooling_cpu.o convolution_cpu.o io.o iter_mnist.o
CUOBJ =
SLIB = lib/libmxnet.so
ALIB = lib/libmxnet.a
LIB_DEP = $(DMLC_CORE)/libdmlc.a

ifeq ($(USE_CUDA), 1)
CUOBJ += reshape_gpu.o narray_function_gpu.o fully_connected_gpu.o activation_gpu.o elementwise_sum_gpu.o pooling_gpu.o softmax_gpu.o convolution_gpu.o
CUOBJ += batch_norm_gpu.o reshape_gpu.o narray_function_gpu.o fully_connected_gpu.o activation_gpu.o elementwise_sum_gpu.o pooling_gpu.o softmax_gpu.o convolution_gpu.o
endif

.PHONY: clean all test lint doc
@@ -105,6 +105,8 @@ convolution_cpu.o: src/operator/convolution.cc
convolution_gpu.o: src/operator/convolution.cu
reshape_cpu.o: src/operator/reshape.cc
reshape_gpu.o: src/operator/reshape.cu
batch_norm_cpu.o: src/operator/batch_norm.cc
batch_norm_gpu.o: src/operator/batch_norm.cu
io.o: src/io/io.cc
iter_mnist.o: src/io/iter_mnist.cc

23 changes: 22 additions & 1 deletion include/mxnet/c_api.h
100644 → 100755
@@ -358,6 +358,16 @@ MXNET_DLL int MXSymbolListArguments(SymbolHandle symbol,
MXNET_DLL int MXSymbolListReturns(SymbolHandle symbol,
mx_uint *out_size,
const char ***out_str_array);
/*!
* \brief List auxiliary states in the symbol.
* \param symbol the symbol
* \param out_size output size
* \param out_str_array pointer to hold the output string array
* \return 0 when success, -1 when failure happens
*/
MXNET_DLL int MXSymbolListAuxiliaryStates(SymbolHandle symbol,
mx_uint *out_size,
const char ***out_str_array);
/*!
* \brief Compose the symbol on other symbols.
*
@@ -406,6 +416,9 @@ MXNET_DLL int MXSymbolGrad(SymbolHandle sym,
 * \param out_shape_size size of the returning array of out_shapes
 * \param out_shape_ndim returning array of shape dimensions of each output shape.
 * \param out_shape_data returning array of pointers to the head of each output shape.
 * \param aux_shape_size size of the returning array of aux_shapes
 * \param aux_shape_ndim returning array of shape dimensions of each auxiliary shape.
 * \param aux_shape_data returning array of pointers to the head of each auxiliary shape.
* \param complete whether infer shape completes or more information is needed.
* \return 0 when success, -1 when failure happens
*/
@@ -420,6 +433,9 @@ MXNET_DLL int MXSymbolInferShape(SymbolHandle sym,
mx_uint *out_shape_size,
const mx_uint **out_shape_ndim,
const mx_uint ***out_shape_data,
mx_uint *aux_shape_size,
const mx_uint **aux_shape_ndim,
const mx_uint ***aux_shape_data,
int *complete);
//--------------------------------------------
// Part 4: Executor interface
@@ -428,9 +444,10 @@ MXNET_DLL int MXSymbolInferShape(SymbolHandle sym,
* \brief Executor forward method
*
* \param handle executor handle
 * \param is_train bool value to indicate whether the forward pass is for training
* \return 0 when success, -1 when failure happens
*/
MXNET_DLL int MXExecutorForward(ExecutorHandle handle);
MXNET_DLL int MXExecutorForward(ExecutorHandle handle, bool is_train);
/*!
 * \brief Executor run backward
*
@@ -466,6 +483,8 @@ MXNET_DLL int MXExecutorHeads(ExecutorHandle handle,
* \param in_args in args array
* \param arg_grad_store arg grads handle array
* \param grad_req_type grad req array
* \param aux_states_len length of auxiliary states
* \param aux_states auxiliary states array
* \param out output executor handle
* \return 0 when success, -1 when failure happens
*/
@@ -476,6 +495,8 @@ MXNET_DLL int MXExecutorBind(SymbolHandle symbol_handle,
NArrayHandle *in_args,
NArrayHandle *arg_grad_store,
mx_uint *grad_req_type,
mx_uint aux_states_len,
NArrayHandle *aux_states,
ExecutorHandle *out);

//--------------------------------------------
1 change: 1 addition & 0 deletions include/mxnet/narray.h
@@ -414,6 +414,7 @@ struct NArrayFunctionReg
} // namespace mxnet

namespace dmlc {
/*!\brief traits */
DMLC_DECLARE_TRAITS(has_saveload, mxnet::NArray, true);
} // namespace dmlc
#endif // MXNET_NARRAY_H_
21 changes: 18 additions & 3 deletions include/mxnet/operator.h
100644 → 100755
@@ -80,12 +80,15 @@ class Operator {
* \param req the request types of saving operation, can only be kWriteTo or kWriteInplace.
 * \param out_data array of output data; the pointer indicates that this is a holder:
 the space of the TBlobs in out_data must be pre-allocated with InferShape
 * \param aux_states Auxiliary states of the operator. Most operators do not
 need them; special cases such as BatchNorm do.
* \sa OpReqType, OpContext
*/
virtual void Forward(const OpContext &ctx,
const std::vector<TBlob> &in_data,
const std::vector<OpReqType> &req,
const std::vector<TBlob> &out_data) = 0;
const std::vector<TBlob> &out_data,
const std::vector<TBlob> &aux_states) = 0;
/*!
* \brief Perform a Backward Operation, write gradient to the in_grad.
*
@@ -111,14 +114,16 @@
* \param out_data the array of output data.
* \param req request types of the saving operation, can be all types.
* \param in_grad the array of gradient we need to write to.
 * \param aux_states Auxiliary states of the operator. Most operators do not need them.
* \sa OperatorProperty, OpReqType, OpContext
*/
virtual void Backward(const OpContext &ctx,
const std::vector<TBlob> &out_grad,
const std::vector<TBlob> &in_data,
const std::vector<TBlob> &out_data,
const std::vector<OpReqType> &req,
const std::vector<TBlob> &in_grad) {
const std::vector<TBlob> &in_grad,
const std::vector<TBlob> &aux_states) {
LOG(FATAL) << "Backward is not implemented";
}
};
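
Why the interface now carries aux_states: BatchNorm maintains running statistics that are neither inputs nor outputs of the graph. A rough NumPy sketch of the idea (illustrative only, not the mxnet kernel; the names and momentum value are assumptions):

import numpy as np

def batch_norm_forward(x, gamma, beta, moving_mean, moving_var,
                       is_train=True, eps=1e-5, momentum=0.9):
    # x: (N, C); moving_mean and moving_var play the role of aux_states.
    if is_train:
        mean, var = x.mean(axis=0), x.var(axis=0)
        # Training updates the auxiliary states in place.
        moving_mean *= momentum
        moving_mean += (1.0 - momentum) * mean
        moving_var *= momentum
        moving_var += (1.0 - momentum) * var
    else:
        # Evaluation reads the stored statistics instead of batch statistics.
        mean, var = moving_mean, moving_var
    return gamma * (x - mean) / np.sqrt(var + eps) + beta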
@@ -158,6 +163,13 @@ class OperatorProperty {
virtual std::vector<std::string> ListReturns() const {
return {"output"};
}
/*!
 * \brief Get the names of the auxiliary states of the Operator.
 * \return names of the auxiliary states.
*/
virtual std::vector<std::string> ListAuxiliaryStates() const {
return {};
}
/*! \return number of real return values of the Operator */
virtual int NumReturns() const {
return 1;
@@ -189,11 +201,14 @@
*
* \param out_shape the shape of outputs of the operator
* InferShape will modify the vector to fill output TShape
 * \param aux_shape the shape of auxiliary states of the operator
 InferShape will modify the vector to fill the auxiliary TShapes
* \return true if the shape inference is successful, false if there is not enough information.
* \throws dmlc::Error if the known arg_shapes are inconsistent.
*/
virtual bool InferShape(std::vector<TShape> *in_shape,
std::vector<TShape> *out_shape) const = 0;
std::vector<TShape> *out_shape,
std::vector<TShape> *aux_shape) const = 0;
/*!
* \brief Copy this OperatorProperty.
* \return a pointer to the copied OperatorProperty
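
To make the new aux_shape output concrete, here is a hand-written sketch (not mxnet code) of what shape inference could fill in for a BatchNorm-like operator: the output mirrors the data shape, and both auxiliary states are per-channel:

def infer_batch_norm_shapes(in_shape):
    # in_shape: [data, gamma, beta]; layout assumed (N, C, ...).
    data_shape = in_shape[0]
    channels = data_shape[1]
    in_shape[1] = (channels,)               # gamma
    in_shape[2] = (channels,)               # beta
    out_shape = [data_shape]
    aux_shape = [(channels,), (channels,)]  # moving mean, moving variance
    return in_shape, out_shape, aux_shape

print(infer_batch_norm_shapes([(32, 16, 8, 8), (), ()]))
# ([(32, 16, 8, 8), (16,), (16,)], [(32, 16, 8, 8)], [(16,), (16,)])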
25 changes: 19 additions & 6 deletions include/mxnet/symbolic.h
100644 → 100755
@@ -128,10 +128,12 @@ class StaticGraph {
*
* \param topo_order The topological order of node index, as created by TopoSort.
 * \param node_out_shapes The shapes of each output of the nodes in the graph.
 * \param node_aux_shapes The shapes of each auxiliary state of the nodes in the graph.
* \return if the shape inference is successful, return true, else return false.
*/
bool InferNodeShapes(const std::vector<uint32_t> &topo_order,
std::vector<std::vector<TShape> > *node_out_shapes) const;
std::vector<std::vector<TShape> > *node_out_shapes,
std::vector<std::vector<TShape> > *node_aux_shapes) const;
/*!
* \brief infer the shapes of outputs and unknown input arguments
* \param in_shape the shape of input arguments of the operator
@@ -144,10 +146,13 @@
*
* \param out_shape the shape of outputs of the operator
* InferShape will modify the vector to fill output TShape
 * \param aux_shape the shape of auxiliary states of the operator
 InferShape will modify the vector to fill the auxiliary TShapes
* \return if the shape inference is successful, return true, else return false.
*/
bool InferShape(std::vector<TShape>* in_shape,
std::vector<TShape>* out_shape) const;
std::vector<TShape>* out_shape,
std::vector<TShape>* aux_shape) const;
/*!
* \brief Add a full backward pass in the static graph.
* This function will add gradient nodes for each heads,
@@ -204,6 +209,8 @@ class Symbol {
std::vector<std::string> ListArguments() const;
/*! \return get the descriptions of outputs for this symbol */
std::vector<std::string> ListReturns() const;
/*! \return get the descriptions of auxiliary data for this symbol */
std::vector<std::string> ListAuxiliaryStates() const;
/*!
* \brief get the index th element from the returned tuple.
* \param index index of multi output
@@ -272,22 +279,26 @@
 * common practice: set the shape of the data input; usually the weights' shapes can be inferred
*
 * \param out_shapes Used to store the inferred shapes of the outputs.
 * \param aux_shapes Used to store the inferred shapes of the auxiliary states
* \return true if the shape inference is successful, false if there is not enough information.
* \throws dmlc::Error if the known arg_shapes are inconsistent.
*/
bool InferShape(std::vector<TShape> *arg_shapes,
std::vector<TShape> *out_shapes) const;
std::vector<TShape> *out_shapes,
std::vector<TShape> *aux_shapes) const;
/*!
* \brief infer the shapes by providing shapes of known arguments.
* \param known_arg_shapes map of argument name to shape of arguments with known shapes.
 * \param arg_shapes used to store the inferred shapes of arguments.
 * \param out_shapes used to store the inferred shapes of outputs.
 * \param aux_shapes used to store the inferred shapes of the auxiliary states
* \return true if the shape inference is successful, false if there is not enough information.
* \throws dmlc::Error if the known arg_shapes are inconsistent.
*/
bool InferShape(const std::unordered_map<std::string, TShape> &known_arg_shapes,
std::vector<TShape> *arg_shapes,
std::vector<TShape> *out_shapes) const;
std::vector<TShape> *out_shapes,
std::vector<TShape> *aux_shapes) const;
/*!
* \brief get number of outputs of this symbol
* \return number of outputs
@@ -378,7 +389,7 @@ class Executor {
* \brief Perform a Forward operation of Operator
* After this operation, user can get the result by using function head.
*/
virtual void Forward() = 0;
virtual void Forward(bool is_train) = 0;
/*!
* \brief Perform a Backward operation of the Operator.
* This must be called after Forward.
@@ -400,13 +411,15 @@
* \param in_args the NArray that stores the input arguments to the symbol.
* \param arg_grad_store NArray that is used to store the gradient output of the input arguments.
 * \param grad_req_type requirement type of gradient saving. Can only be in {kNullOp, kAddTo, kWriteTo}.
 * \param aux_states NArrays that are used as internal states by the operator
* \return a new executor.
*/
static Executor *Bind(Symbol symbol,
Context ctx,
const std::vector<NArray> &in_args,
const std::vector<NArray> &arg_grad_store,
const std::vector<OpReqType> &grad_req_type);
const std::vector<OpReqType> &grad_req_type,
const std::vector<NArray> &aux_states);
}; // class Executor
} // namespace mxnet
#endif // MXNET_SYMBOLIC_H_
13 changes: 10 additions & 3 deletions python/mxnet/executor.py
100644 → 100755
@@ -23,9 +23,16 @@ def __init__(self, handle):
raise TypeError("Handle type error")
self.handle = handle

def forward(self):
"""Do forward."""
check_call(_LIB.MXExecutorForward(self.handle))
def forward(self, is_train=True):
"""Do forward.
Parameters
----------
is_train: bool
whether this forward is for evaluation purpose
Note: for test only network, please indicate in Bind (TODO)
"""
check_call(_LIB.MXExecutorForward(self.handle, is_train))

def backward(self, grads):
"""Do backward on heads' gradient.
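
A short usage sketch of the new flag. Here `exe` stands for a hypothetical Executor obtained from Symbol.bind (see the end-to-end sketch at the bottom of this page), and `head_grads` is an illustrative list of output-gradient NArrays:

exe.forward(is_train=True)    # training pass: auxiliary states get updated
exe.backward(head_grads)
exe.forward(is_train=False)   # evaluation pass: auxiliary states are only read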
38 changes: 35 additions & 3 deletions python/mxnet/symbol.py
100644 → 100755
@@ -1,5 +1,5 @@
# coding: utf-8
# pylint: disable=invalid-name, protected-access, fixme
# pylint: disable=invalid-name, protected-access, fixme, too-many-arguments
"""Symbol support of mxnet"""
from __future__ import absolute_import

@@ -123,6 +123,20 @@ def list_returns(self):
self.handle, ctypes.byref(size), ctypes.byref(sarr)))
return [py_str(sarr[i]) for i in range(size.value)]

def list_auxiliary_states(self):
"""List all auxiliary states in the symbool.
Returns
-------
aux_states: list of string
List of the names of all auxiliary states.
"""
size = ctypes.c_uint()
sarr = ctypes.POINTER(ctypes.c_char_p)()
check_call(_LIB.MXSymbolListAuxiliaryStates(
self.handle, ctypes.byref(size), ctypes.byref(sarr)))
return [py_str(sarr[i]) for i in range(size.value)]

def infer_shape(self, *args, **kwargs):
"""Infer the shape of outputs and arguments of given known shapes of arguments.
@@ -147,6 +161,9 @@ def infer_shape(self, *args, **kwargs):
out_shapes : list of tuple or None
List of shapes of outputs.
The order is in the same order as list_returns()
aux_shapes : list of tuple or None
List of shapes of auxiliary states.
The order is in the same order as list_auxiliary_states()
"""
# pylint: disable=too-many-locals
if len(args) != 0 and len(kwargs) != 0:
@@ -176,6 +193,9 @@ def infer_shape(self, *args, **kwargs):
out_shape_size = mx_uint()
out_shape_ndim = ctypes.POINTER(mx_uint)()
out_shape_data = ctypes.POINTER(ctypes.POINTER(mx_uint))()
aux_shape_size = mx_uint()
aux_shape_ndim = ctypes.POINTER(mx_uint)()
aux_shape_data = ctypes.POINTER(ctypes.POINTER(mx_uint))()
complete = ctypes.c_int()
check_call(_LIB.MXSymbolInferShape(
self.handle, len(indptr) - 1,
@@ -188,13 +208,18 @@
ctypes.byref(out_shape_size),
ctypes.byref(out_shape_ndim),
ctypes.byref(out_shape_data),
ctypes.byref(aux_shape_size),
ctypes.byref(aux_shape_ndim),
ctypes.byref(aux_shape_data),
ctypes.byref(complete)))
if complete.value != 0:
arg_shapes = [
tuple(arg_shape_data[i][:arg_shape_ndim[i]]) for i in range(arg_shape_size.value)]
out_shapes = [
tuple(out_shape_data[i][:out_shape_ndim[i]]) for i in range(out_shape_size.value)]
return (arg_shapes, out_shapes)
aux_shapes = [
tuple(aux_shape_data[i][:aux_shape_ndim[i]]) for i in range(aux_shape_size.value)]
return (arg_shapes, out_shapes, aux_shapes)
else:
return (None, None, None)
# pylint: enable=too-many-locals
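
A hypothetical call of the extended method, assuming the new operator is registered as mx.symbol.BatchNorm and keeps two per-channel auxiliary states (the state names shown are illustrative):

import mxnet as mx

data = mx.symbol.Variable('data')
bn = mx.symbol.BatchNorm(data=data, name='bn')
arg_shapes, out_shapes, aux_shapes = bn.infer_shape(data=(32, 16, 8, 8))
print(bn.list_auxiliary_states())  # e.g. ['bn_moving_mean', 'bn_moving_var']
print(aux_shapes)                  # e.g. [(16,), (16,)]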
@@ -212,7 +237,7 @@ def debug_str(self):
self.handle, ctypes.byref(debug_str)))
return py_str(debug_str.value)

def bind(self, ctx, args, args_grad, reqs):
def bind(self, ctx, args, args_grad, reqs, aux_states=None):
"""bind current symbol to get an executor.
Parameters
@@ -225,15 +250,20 @@ def bind(self, ctx, args, args_grad, reqs):
input args' gradient
reqs: Array of enum
gradient requirements
aux_states: Array of NArray
input auxiliary states to the symbol
"""
# TODO(bing): consider a more friendly interface
# For example, pass in args_grad by dict
enum = {"null": 0, "write_to": 1, "in_place": 2, "add_to": 3}
if not isinstance(ctx, Context):
raise TypeError("Context type error")
if aux_states is None:
aux_states = []
args_handle = c_array(NArrayHandle, [item.handle for item in args])
args_grad_handle = c_array(NArrayHandle, [item.handle for item in args_grad])
reqs_array = c_array(mx_uint, [mx_uint(enum[item]) for item in reqs])
aux_args_handle = c_array(NArrayHandle, [item.handle for item in aux_states])
handle = ExecutorHandle()
check_call(_LIB.MXExecutorBind(self.handle,
mx_uint(ctx.device_mask),
@@ -242,6 +272,8 @@
args_handle,
args_grad_handle,
reqs_array,
len(aux_states),
aux_args_handle,
ctypes.byref(handle)))
return Executor(handle)

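
Putting the pieces together, an end-to-end sketch of the changed interface. It assumes BatchNorm is registered as mx.symbol.BatchNorm; `alloc(shape, ctx)` is a placeholder for whatever NArray allocation call the library provides, not an actual mxnet function:

import mxnet as mx

ctx = mx.Context('cpu')
data = mx.symbol.Variable('data')
net = mx.symbol.BatchNorm(data=data, name='bn')

arg_shapes, out_shapes, aux_shapes = net.infer_shape(data=(32, 16))
args = [alloc(s, ctx) for s in arg_shapes]    # alloc is a placeholder helper
grads = [alloc(s, ctx) for s in arg_shapes]
aux = [alloc(s, ctx) for s in aux_shapes]
reqs = ['write_to'] * len(args)               # matches the enum in bind()

exe = net.bind(ctx, args, grads, reqs, aux_states=aux)
exe.forward(is_train=True)    # updates the moving statistics held in aux
exe.forward(is_train=False)   # reads them without updating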