Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Add BatchNorm, change interface #36

Merged
merged 5 commits into from
Sep 2, 2015
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -63,14 +63,14 @@ endif

BIN = tests/test_simple_engine
OBJ = narray_function_cpu.o
OBJCXX11 = reshape_cpu.o dag_engine.o simple_engine.o narray.o c_api.o operator.o symbol.o storage.o fully_connected_cpu.o static_graph.o activation_cpu.o graph_executor.o softmax_cpu.o elementwise_sum_cpu.o pooling_cpu.o convolution_cpu.o io.o iter_mnist.o
OBJCXX11 = batch_norm_cpu.o reshape_cpu.o dag_engine.o simple_engine.o narray.o c_api.o operator.o symbol.o storage.o fully_connected_cpu.o static_graph.o activation_cpu.o graph_executor.o softmax_cpu.o elementwise_sum_cpu.o pooling_cpu.o convolution_cpu.o io.o iter_mnist.o
CUOBJ =
SLIB = lib/libmxnet.so
ALIB = lib/libmxnet.a
LIB_DEP = $(DMLC_CORE)/libdmlc.a

ifeq ($(USE_CUDA), 1)
CUOBJ += reshape_gpu.o narray_function_gpu.o fully_connected_gpu.o activation_gpu.o elementwise_sum_gpu.o pooling_gpu.o softmax_gpu.o convolution_gpu.o
CUOBJ += batch_norm_gpu.o reshape_gpu.o narray_function_gpu.o fully_connected_gpu.o activation_gpu.o elementwise_sum_gpu.o pooling_gpu.o softmax_gpu.o convolution_gpu.o
endif

.PHONY: clean all test lint doc
Expand Down Expand Up @@ -105,6 +105,8 @@ convolution_cpu.o: src/operator/convolution.cc
convolution_gpu.o: src/operator/convolution.cu
reshape_cpu.o: src/operator/reshape.cc
reshape_gpu.o: src/operator/reshape.cu
batch_norm_cpu.o: src/operator/batch_norm.cc
batch_norm_gpu.o: src/operator/batch_norm.cu
io.o: src/io/io.cc
iter_mnist.o: src/io/iter_mnist.cc

Expand Down
23 changes: 22 additions & 1 deletion include/mxnet/c_api.h
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -358,6 +358,16 @@ MXNET_DLL int MXSymbolListArguments(SymbolHandle symbol,
MXNET_DLL int MXSymbolListReturns(SymbolHandle symbol,
mx_uint *out_size,
const char ***out_str_array);
/*!
* \brief List auxiliary states in the symbol.
* \param symbol the symbol
* \param out_size output size
* \param out_str_array pointer to hold the output string array
* \return 0 when success, -1 when failure happens
*/
MXNET_DLL int MXSymbolListAuxiliaryStates(SymbolHandle symbol,
mx_uint *out_size,
const char ***out_str_array);
/*!
* \brief Compose the symbol on other symbols.
*
Expand Down Expand Up @@ -406,6 +416,9 @@ MXNET_DLL int MXSymbolGrad(SymbolHandle sym,
* \param out_shape_size sizeof the returning array of out_shapes
* \param out_shape_ndim returning array of shape dimensions of each output shape.
* \param out_shape_data returning array of pointers to head of the input shape.
* \param aux_shape_size sizeof the returning array of aux_shapes
* \param aux_shape_ndim returning array of shape dimensions of each auxiliary shape.
* \param aux_shape_data returning array of pointers to head of the auxiliary shape.
* \param complete whether infer shape completes or more information is needed.
* \return 0 when success, -1 when failure happens
*/
Expand All @@ -420,6 +433,9 @@ MXNET_DLL int MXSymbolInferShape(SymbolHandle sym,
mx_uint *out_shape_size,
const mx_uint **out_shape_ndim,
const mx_uint ***out_shape_data,
mx_uint *aux_shape_size,
const mx_uint **aux_shape_ndim,
const mx_uint ***aux_shape_data,
int *complete);
//--------------------------------------------
// Part 4: Executor interface
Expand All @@ -428,9 +444,10 @@ MXNET_DLL int MXSymbolInferShape(SymbolHandle sym,
* \brief Executor forward method
*
* \param handle executor handle
* \param is_train bool value to indicate whether the forward pass is run in training mode (false means evaluation)
* \return 0 when success, -1 when failure happens
*/
MXNET_DLL int MXExecutorForward(ExecutorHandle handle);
MXNET_DLL int MXExecutorForward(ExecutorHandle handle, bool is_train);
/*!
* \brief Executor run backward
*
Expand Down Expand Up @@ -466,6 +483,8 @@ MXNET_DLL int MXExecutorHeads(ExecutorHandle handle,
* \param in_args in args array
* \param arg_grad_store arg grads handle array
* \param grad_req_type grad req array
* \param aux_states_len length of auxiliary states
* \param aux_states auxiliary states array
* \param out output executor handle
* \return 0 when success, -1 when failure happens
*/
Expand All @@ -476,6 +495,8 @@ MXNET_DLL int MXExecutorBind(SymbolHandle symbol_handle,
NArrayHandle *in_args,
NArrayHandle *arg_grad_store,
mx_uint *grad_req_type,
mx_uint aux_states_len,
NArrayHandle *aux_states,
ExecutorHandle *out);

//--------------------------------------------
Expand Down
1 change: 1 addition & 0 deletions include/mxnet/narray.h
Original file line number Diff line number Diff line change
Expand Up @@ -414,6 +414,7 @@ struct NArrayFunctionReg
} // namespace mxnet

namespace dmlc {
/*!\brief traits */
DMLC_DECLARE_TRAITS(has_saveload, mxnet::NArray, true);
} // namespace dmlc
#endif // MXNET_NARRAY_H_
21 changes: 18 additions & 3 deletions include/mxnet/operator.h
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -80,12 +80,15 @@ class Operator {
* \param req the request types of saving operation, can only be kWriteTo or kWriteInplace.
* \param out_data array of output data, pointer is used to indicate that this is holder
* the space of TBlob in out_data must be pre-allocated with InferShape
* \param aux_states Auxiliary states of the operator. Normally an operator doesn't
* need them; special cases like Batch Norm require them.
* \sa OpReqType, OpContext
*/
virtual void Forward(const OpContext &ctx,
const std::vector<TBlob> &in_data,
const std::vector<OpReqType> &req,
const std::vector<TBlob> &out_data) = 0;
const std::vector<TBlob> &out_data,
const std::vector<TBlob> &aux_states) = 0;
/*!
* \brief Perform a Backward Operation, write gradient to the in_grad.
*
Expand All @@ -111,14 +114,16 @@ class Operator {
* \param out_data the array of output data.
* \param req request types of the saving operation, can be all types.
* \param in_grad the array of gradient we need to write to.
* \param aux_states Auxiliary states of the operator. Normally an operator doesn't need them.
* \sa OperatorProperty, OpReqType, OpContext
*/
virtual void Backward(const OpContext &ctx,
const std::vector<TBlob> &out_grad,
const std::vector<TBlob> &in_data,
const std::vector<TBlob> &out_data,
const std::vector<OpReqType> &req,
const std::vector<TBlob> &in_grad) {
const std::vector<TBlob> &in_grad,
const std::vector<TBlob> &aux_states) {
LOG(FATAL) << "Backward is not implemented";
}
};
Expand Down Expand Up @@ -158,6 +163,13 @@ class OperatorProperty {
virtual std::vector<std::string> ListReturns() const {
return {"output"};
}
/*!
* \brief Get names of the auxiliary states of the Operator
* \return names of the auxiliary states.
*/
virtual std::vector<std::string> ListAuxiliaryStates() const {
return {};
}
/*! \return number of real return values of the Operator */
virtual int NumReturns() const {
return 1;
Expand Down Expand Up @@ -189,11 +201,14 @@ class OperatorProperty {
*
* \param out_shape the shape of outputs of the operator
* InferShape will modify the vector to fill output TShape
* \param aux_shape the shape of auxiliary states of the operator
* InferShape will modify the vector to fill output TShape
* \return true if the shape inference is successful, false if there is not enough information.
* \throws dmlc::Error if the known arg_shapes are inconsistent.
*/
virtual bool InferShape(std::vector<TShape> *in_shape,
std::vector<TShape> *out_shape) const = 0;
std::vector<TShape> *out_shape,
std::vector<TShape> *aux_shape) const = 0;
/*!
* \brief Copy this OperatorProperty.
* \return a pointer to the copied OperatorProperty
Expand Down
25 changes: 19 additions & 6 deletions include/mxnet/symbolic.h
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -128,10 +128,12 @@ class StaticGraph {
*
* \param topo_order The topological order of node index, as created by TopoSort.
* \param node_out_shapes The shapes of each output of the nodes in the graph.
* \param node_aux_shapes The shapes of each auxiliary state of the nodes in the graph.
* \return if the shape inference is successful, return true, else return false.
*/
bool InferNodeShapes(const std::vector<uint32_t> &topo_order,
std::vector<std::vector<TShape> > *node_out_shapes) const;
std::vector<std::vector<TShape> > *node_out_shapes,
std::vector<std::vector<TShape> > *node_aux_shapes) const;
/*!
* \brief infer the shapes of outputs and unknown input arguments
* \param in_shape the shape of input arguments of the operator
Expand All @@ -144,10 +146,13 @@ class StaticGraph {
*
* \param out_shape the shape of outputs of the operator
* InferShape will modify the vector to fill output TShape
* \param aux_shape the shape of auxiliary states of the operator
* InferShape will modify the vector to fill output TShape
* \return if the shape inference is successful, return true, else return false.
*/
bool InferShape(std::vector<TShape>* in_shape,
std::vector<TShape>* out_shape) const;
std::vector<TShape>* out_shape,
std::vector<TShape>* aux_shape) const;
/*!
* \brief Add a full backward pass in the static graph.
* This function will add gradient nodes for each heads,
Expand Down Expand Up @@ -204,6 +209,8 @@ class Symbol {
std::vector<std::string> ListArguments() const;
/*! \return get the descriptions of outputs for this symbol */
std::vector<std::string> ListReturns() const;
/*! \return get the descriptions of auxiliary data for this symbol */
std::vector<std::string> ListAuxiliaryStates() const;
/*!
* \brief get the index th element from the returned tuple.
* \param index index of multi output
Expand Down Expand Up @@ -272,22 +279,26 @@ class Symbol {
* common practice: set the shape of the data input, and usually weight shapes can be inferred
*
* \param out_shapes Use to store the infered shapes of outputs.
* \param aux_shapes Used to store the inferred shapes of auxiliary states
* \return true if the shape inference is successful, false if there is not enough information.
* \throws dmlc::Error if the known arg_shapes are inconsistent.
*/
bool InferShape(std::vector<TShape> *arg_shapes,
std::vector<TShape> *out_shapes) const;
std::vector<TShape> *out_shapes,
std::vector<TShape> *aux_shapes) const;
/*!
* \brief infer the shapes by providing shapes of known arguments.
* \param known_arg_shapes map of argument name to shape of arguments with known shapes.
* \param arg_shapes used to store inferred shapes of arguments.
* \param out_shapes used to store inferred shapes of outputs.
* \param aux_shapes used to store inferred shapes of auxiliary states.
* \return true if the shape inference is successful, false if there is not enough information.
* \throws dmlc::Error if the known arg_shapes are inconsistent.
*/
bool InferShape(const std::unordered_map<std::string, TShape> &known_arg_shapes,
std::vector<TShape> *arg_shapes,
std::vector<TShape> *out_shapes) const;
std::vector<TShape> *out_shapes,
std::vector<TShape> *aux_shapes) const;
/*!
* \brief get number of outputs of this symbol
* \return number of outputs
Expand Down Expand Up @@ -378,7 +389,7 @@ class Executor {
* \brief Perform a Forward operation of Operator
* After this operation, user can get the result by using function head.
*/
virtual void Forward() = 0;
virtual void Forward(bool is_train) = 0;
/*!
* \brief Perform a Backward operation of the Operator.
* This must be called after Forward.
Expand All @@ -400,13 +411,15 @@ class Executor {
* \param in_args the NArray that stores the input arguments to the symbol.
* \param arg_grad_store NArray that is used to store the gradient output of the input arguments.
* \param grad_req_type requirement type of gradient saving. Can only be in {kNullOp, kAddTo, kWriteTo}.
* \param aux_states NArray that is used as internal state in op
* \return a new executor.
*/
static Executor *Bind(Symbol symbol,
Context ctx,
const std::vector<NArray> &in_args,
const std::vector<NArray> &arg_grad_store,
const std::vector<OpReqType> &grad_req_type);
const std::vector<OpReqType> &grad_req_type,
const std::vector<NArray> &aux_states);
}; // class operator
} // namespace mxnet
#endif // MXNET_SYMBOLIC_H_
13 changes: 10 additions & 3 deletions python/mxnet/executor.py
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,16 @@ def __init__(self, handle):
raise TypeError("Handle type error")
self.handle = handle

def forward(self):
"""Do forward."""
check_call(_LIB.MXExecutorForward(self.handle))
def forward(self, is_train=True):
    """Run a forward pass on the bound executor.

    After this call, the results can be read from the executor's heads
    (outputs).

    Parameters
    ----------
    is_train : bool
        Whether this forward pass is run in training mode (True, the
        default). Pass False for evaluation/inference, so operators
        that behave differently at test time (e.g. BatchNorm) use
        their inference behavior.
        Note: for a test-only network, please indicate in Bind (TODO)
    """
    check_call(_LIB.MXExecutorForward(self.handle, is_train))

def backward(self, grads):
"""Do backward on heads' gradient.
Expand Down
38 changes: 35 additions & 3 deletions python/mxnet/symbol.py
100644 → 100755
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# coding: utf-8
# pylint: disable=invalid-name, protected-access, fixme
# pylint: disable=invalid-name, protected-access, fixme, too-many-arguments
"""Symbol support of mxnet"""
from __future__ import absolute_import

Expand Down Expand Up @@ -123,6 +123,20 @@ def list_returns(self):
self.handle, ctypes.byref(size), ctypes.byref(sarr)))
return [py_str(sarr[i]) for i in range(size.value)]

def list_auxiliary_states(self):
    """List all auxiliary states in the symbol.

    Auxiliary states are internal states needed by some operators at
    runtime (e.g. BatchNorm) that are not gradient arguments.

    Returns
    -------
    aux_states : list of string
        List of the names of the auxiliary states.
    """
    size = ctypes.c_uint()
    sarr = ctypes.POINTER(ctypes.c_char_p)()
    check_call(_LIB.MXSymbolListAuxiliaryStates(
        self.handle, ctypes.byref(size), ctypes.byref(sarr)))
    return [py_str(sarr[i]) for i in range(size.value)]

def infer_shape(self, *args, **kwargs):
"""Infer the shape of outputs and arguments of given known shapes of arguments.

Expand All @@ -147,6 +161,9 @@ def infer_shape(self, *args, **kwargs):
out_shapes : list of tuple or None
List of shapes of outputs.
The order is in the same order as list_returns()
aux_shapes : list of tuple or None
List of shapes of auxiliary states.
The order is in the same order as list_auxiliary_states()
"""
# pylint: disable=too-many-locals
if len(args) != 0 and len(kwargs) != 0:
Expand Down Expand Up @@ -176,6 +193,9 @@ def infer_shape(self, *args, **kwargs):
out_shape_size = mx_uint()
out_shape_ndim = ctypes.POINTER(mx_uint)()
out_shape_data = ctypes.POINTER(ctypes.POINTER(mx_uint))()
aux_shape_size = mx_uint()
aux_shape_ndim = ctypes.POINTER(mx_uint)()
aux_shape_data = ctypes.POINTER(ctypes.POINTER(mx_uint))()
complete = ctypes.c_int()
check_call(_LIB.MXSymbolInferShape(
self.handle, len(indptr) - 1,
Expand All @@ -188,13 +208,18 @@ def infer_shape(self, *args, **kwargs):
ctypes.byref(out_shape_size),
ctypes.byref(out_shape_ndim),
ctypes.byref(out_shape_data),
ctypes.byref(aux_shape_size),
ctypes.byref(aux_shape_ndim),
ctypes.byref(aux_shape_data),
ctypes.byref(complete)))
if complete.value != 0:
arg_shapes = [
tuple(arg_shape_data[i][:arg_shape_ndim[i]]) for i in range(arg_shape_size.value)]
out_shapes = [
tuple(out_shape_data[i][:out_shape_ndim[i]]) for i in range(out_shape_size.value)]
return (arg_shapes, out_shapes)
aux_shapes = [
tuple(aux_shape_data[i][:aux_shape_ndim[i]]) for i in range(aux_shape_size.value)]
return (arg_shapes, out_shapes, aux_shapes)
else:
return (None, None)
# pylint: enable=too-many-locals
Expand All @@ -212,7 +237,7 @@ def debug_str(self):
self.handle, ctypes.byref(debug_str)))
return py_str(debug_str.value)

def bind(self, ctx, args, args_grad, reqs):
def bind(self, ctx, args, args_grad, reqs, aux_states=None):
"""bind current symbol to get an executor.

Parameters
Expand All @@ -225,15 +250,20 @@ def bind(self, ctx, args, args_grad, reqs):
input args' gradient
reqs: Array of enum
gradient requirements
aux_states: Array of NArray
input auxiliary states to the symbol
"""
# TODO(bing): consider a more friendly interface
# For example, pass in args_grad by dict
enum = {"null" : 0, "write_to" : 1, "in_place":2, "add_to" : 3}
if not isinstance(ctx, Context):
raise TypeError("Context type error")
if aux_states == None:
aux_states = []
args_handle = c_array(NArrayHandle, [item.handle for item in args])
args_grad_handle = c_array(NArrayHandle, [item.handle for item in args_grad])
reqs_array = c_array(mx_uint, [mx_uint(enum[item]) for item in reqs])
aux_args_handle = c_array(NArrayHandle, [item.handle for item in aux_states])
handle = ExecutorHandle()
check_call(_LIB.MXExecutorBind(self.handle,
mx_uint(ctx.device_mask),
Expand All @@ -242,6 +272,8 @@ def bind(self, ctx, args, args_grad, reqs):
args_handle,
args_grad_handle,
reqs_array,
len(aux_states),
aux_args_handle,
ctypes.byref(handle)))
return Executor(handle)

Expand Down
Loading