Revert "[WIP] Do Not Merge. Static memory allocation for cached_op (a…
Browse files Browse the repository at this point in the history
…pache#10817)" (apache#11311)

This reverts commit 2dbd143.
marcoabreu authored Jun 15, 2018
1 parent 258e96d commit e48a8fd
Showing 18 changed files with 523 additions and 1,388 deletions.
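For context, the change being reverted here (apache#10817) exposed static memory allocation through Gluon's hybridize flags; the docstring and tests removed further down in this diff describe it. A minimal Python sketch of that now-removed usage, assuming a build that still contains apache#10817:

import mxnet as mx
from mxnet import gluon, autograd, nd

# Sketch of the usage removed by this revert; static_alloc and static_shape
# are the flags introduced by apache#10817 and dropped again here.
net = gluon.nn.Dense(10)
net.initialize()
net.hybridize(static_alloc=True, static_shape=True)

x = nd.random.uniform(shape=(2, 8))
with autograd.record():
    y = net(x)
y.backward()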
5 changes: 5 additions & 0 deletions include/mxnet/c_api.h
@@ -987,6 +987,11 @@ MXNET_DLL int MXCreateCachedOpEx(SymbolHandle handle,
int num_flags,
const char** keys,
const char** vals,
int num_inputs,
const char** input_names,
int num_params,
const char** param_names,
NDArrayHandle* params,
CachedOpHandle *out);
/*!
* \brief free cached operator
89 changes: 89 additions & 0 deletions include/mxnet/imperative.h
@@ -35,6 +35,23 @@
#include "./ndarray.h"

namespace mxnet {
/*! \brief CachedOp Parameters */
struct CachedOpConfig : public dmlc::Parameter<CachedOpConfig> {
uint32_t inline_limit;
uint32_t forward_bulk_size;
uint32_t backward_bulk_size;
DMLC_DECLARE_PARAMETER(CachedOpConfig) {
DMLC_DECLARE_FIELD(inline_limit)
.set_default(2)
.describe("Maximum number of operators that can be inlined.");
DMLC_DECLARE_FIELD(forward_bulk_size)
.set_default(dmlc::GetEnv("MXNET_EXEC_BULK_EXEC_MAX_NODE_TRAIN", 15))
.describe("Segment size of bulk execution during forward pass.");
DMLC_DECLARE_FIELD(backward_bulk_size)
.set_default(dmlc::GetEnv("MXNET_EXEC_BULK_EXEC_MAX_NODE_TRAIN", 15))
.describe("Segment size of bulk execution during backward pass.");
}
};
/*! \brief runtime functions for NDArray */
class Imperative {
public:
@@ -77,6 +94,67 @@ class Imperative {
&& info.out_grads.size() == 1;
}
};
class CachedOp {
public:
CachedOp(
const nnvm::Symbol& sym,
const std::vector<std::pair<std::string, std::string> >& flags,
const std::vector<std::string> arg_names,
const std::unordered_map<std::string, std::vector<NDArray> >& params);
uint32_t num_inputs() {
return fwd_graph_.indexed_graph().input_nodes().size();
}
uint32_t num_outputs() {
return fwd_graph_.outputs.size();
}
uint32_t num_backward_inputs() {
return bwd_ograd_dep_.size() + bwd_in_dep_.size() + bwd_out_dep_.size();
}
std::vector<bool>& save_inputs() {
return save_inputs_;
}
std::vector<bool>& save_outputs() {
return save_outputs_;
}
const std::unordered_set<uint32_t>& mutable_input_nodes() {
return fwd_graph_.indexed_graph().mutable_input_nodes();
}
nnvm::Graph GetForwardGraph(const bool recording,
const std::vector<NDArray*>& inputs);
nnvm::Graph GetBackwardGraph(const OpStatePtr& state,
const std::vector<OpReqType>& reqs,
const std::vector<NDArray*>& inputs);
std::vector<nnvm::NodeEntry> Gradient(const nnvm::NodePtr& node,
const std::vector<nnvm::NodeEntry>& ograds);
void Forward(const std::shared_ptr<CachedOp>& op_ptr,
const std::vector<NDArray*>& args,
const std::vector<NDArray*>& outputs);
void Backward(const bool retain_graph,
const OpStatePtr& state,
const std::vector<NDArray*>& inputs,
const std::vector<OpReqType>& reqs,
const std::vector<NDArray*>& outputs);

private:
struct CachedOpState {
std::vector<NDArray> buff;
std::vector<OpStatePtr> states;
};
std::mutex mutex_;
CachedOpConfig config_;
nnvm::Graph fwd_graph_;
nnvm::Graph grad_graph_;
nnvm::Graph full_graph_;
std::unordered_map<Context, std::vector<NDArray> > params_;
bool inlining_;
std::vector<nnvm::NodeEntry> ograd_entries_;
std::vector<bool> curr_grad_req_;
std::vector<uint32_t> bwd_in_dep_, bwd_out_dep_, bwd_ograd_dep_;
std::vector<uint32_t> fwd_args_idx_;
std::vector<uint32_t> fwd_params_idx_;
std::vector<uint32_t> bwd_input_eid_;
std::vector<bool> save_inputs_, save_outputs_;
};
/*! \brief whether operator recording is on. */
bool is_training() const {
return is_train_;
@@ -144,6 +222,15 @@ class Imperative {
uint32_t num_inputs, uint32_t num_outputs,
std::vector<bool> *p_save_inputs,
std::vector<bool> *p_save_outputs);
void RunGraph(
const bool retain_graph,
const nnvm::IndexedGraph& idx,
const std::vector<NDArray*> arrays,
size_t node_start, size_t node_end,
std::vector<OpReqType>&& array_reqs,
std::vector<uint32_t>&& ref_count,
std::vector<OpStatePtr> *p_states,
const DispatchModeVector& dispatch_modes);
/*! \brief indicate whether is training. */
#if DMLC_CXX11_THREAD_LOCAL
static thread_local bool is_train_;
@@ -160,5 +247,7 @@ class Imperative {
int backward_bulk_size_{0};
};

using CachedOpPtr = std::shared_ptr<Imperative::CachedOp>;

} // namespace mxnet
#endif // MXNET_IMPERATIVE_H_
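The hunk above restores CachedOpConfig and the CachedOp class to imperative.h. A minimal sketch of passing the three restored config flags as string key/value pairs through the Python frontend, assuming a post-revert build (the symbol and shapes below are illustrative):

import mxnet as mx
from mxnet import ndarray as nd

data = mx.sym.var('data')
weight = mx.sym.var('weight')
sym = mx.sym.FullyConnected(data, weight=weight, no_bias=True, num_hidden=4)

# Flags map onto the CachedOpConfig fields declared above.
flags = [('inline_limit', '2'),
         ('forward_bulk_size', '15'),
         ('backward_bulk_size', '15')]
op = nd.CachedOp(sym, flags)               # inputs/params omitted: every input is passed at call time
y = op(nd.ones((2, 8)), nd.ones((4, 8)))   # data, weight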
8 changes: 0 additions & 8 deletions include/mxnet/ndarray.h
@@ -155,14 +155,6 @@ class NDArray {
return byte_offset_ > 0 || shape() != ptr_->storage_shape;
}

/* \brief Check whether the two arrays are the same array */
inline bool IsSame(const NDArray& other) {
return ptr_ == other.ptr_ &&
shape_ == other.shape_ &&
byte_offset_ == other.byte_offset_ &&
dtype_ == other.dtype_;
}

/*!
* \return the shape of current NDArray.
*/
33 changes: 13 additions & 20 deletions include/mxnet/op_attr_types.h
@@ -126,36 +126,25 @@ class OpStatePtr {
template<typename T, typename... Args>
static OpStatePtr Create(Args&&... args) {
OpStatePtr ret;
auto state = new T(std::forward<Args>(args)...);
auto var = Engine::Get()->NewVariable();
ret.ptr_.reset(
new OpState(var, state),
[](OpState* p) {
Engine::Get()->DeleteVariable([](RunContext s) {}, Context::CPU(), p->var);
delete reinterpret_cast<T*>(p->state);
delete p;
});
ret.ptr_ = std::make_shared<OpState>();
ret.ptr_->var_ = Engine::Get()->NewVariable();
ret.ptr_->state_.construct<T>(std::forward<Args>(args)...);

return ret;
}
/* \brief Get engine variable associated with this state */
engine::VarHandle get_var() const {
return ptr_->var;
return ptr_->var_;
}
/* \brief Get state of type T */
template<typename T>
T& get_state() const {
return *reinterpret_cast<T*>(ptr_->state);
return dmlc::get<T>(ptr_->state_);
}
/* \brief clear state */
void reset() {
ptr_.reset();
}
/* \brief checks whether the managed object is managed only by the current
OpStatePtr instance */
bool unique() const {
return ptr_.unique();
}
/* \brief Whether state is empty */
explicit operator bool() const {
return ptr_ ? true : false;
@@ -164,12 +153,16 @@ class OpStatePtr {
private:
/* \brief state structure */
struct OpState {
engine::VarHandle var;
void* state;

OpState(engine::VarHandle var_, void* state_) : var(var_), state(state_) {}
OpState() {}
OpState(const OpState& other) = delete;
OpState& operator=(const OpState& other) = delete;

~OpState() {
Engine::Get()->DeleteVariable([](RunContext s) {}, Context::CPU(), var_);
}

engine::VarHandle var_;
dmlc::any state_;
};
/* \brief shared pointer to state */
std::shared_ptr<OpState> ptr_;
16 changes: 15 additions & 1 deletion python/mxnet/_ctypes/ndarray.py
@@ -105,14 +105,28 @@ def _imperative_invoke(handle, ndargs, keys, vals, out):
class CachedOp(object):
"""Cached operator handle."""
__slots__ = ["handle"]
def __init__(self, sym, flags=()):
def __init__(self, sym, flags=(), inputs=None, params=None):
self.handle = CachedOpHandle()
param_names = []
param_arrays = []
if inputs is None:
assert params is None, "When inputs is None params must also be None."
inputs = sym.list_inputs()
elif params is not None:
for name, arrs in params.items():
param_arrays.extend(arrs)
param_names.extend([name] * len(arrs))

check_call(_LIB.MXCreateCachedOpEx(
sym.handle,
len(flags),
c_str_array([key for key, _ in flags]),
c_str_array([str(val) for _, val in flags]),
len(inputs),
c_str_array(inputs),
len(param_names),
c_str_array(param_names),
c_handle_array(param_arrays),
ctypes.byref(self.handle)))

def __del__(self):
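The restored constructor separates runtime inputs from bound parameters, which Gluon's HybridBlock normally supplies from _build_cache. A minimal sketch of calling it directly, assuming a post-revert build; the names and shapes are illustrative:

import mxnet as mx
from mxnet import ndarray as nd

data = mx.sym.var('data')
weight = mx.sym.var('weight')
sym = mx.sym.FullyConnected(data, weight=weight, no_bias=True, num_hidden=4)

op = nd.CachedOp(sym, flags=(),
                 inputs=['data'],                       # runtime inputs, by name
                 params={'weight': [nd.ones((4, 8))]})  # bound parameters, one NDArray per device
y = op(nd.ones((2, 8)))   # only the non-parameter inputs are passed at call time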
74 changes: 25 additions & 49 deletions python/mxnet/gluon/block.py
@@ -502,16 +502,8 @@ def hybridize(self, active=True, **kwargs):
----------
active : bool, default True
Whether to turn hybrid on or off.
static_alloc : bool, default False
Statically allocate memory to improve speed. Memory usage may increase.
static_shape : bool, default False
Optimize for invariant input shapes between iterations. Must also
set static_alloc to True. Change of input shapes is still allowed
but slower.
forward_bulk_size : int, default 15
Segment size of bulk execution during forward pass.
backward_bulk_size : int, default 15
Segment size of bulk execution during backward pass.
**kwargs : string
Additional flags for hybridized operator.
"""
for cld in self._children.values():
cld.hybridize(active, **kwargs)
@@ -704,7 +696,7 @@ def __init__(self, prefix=None, params=None):
self._out_format = None
self._in_format = None
self._active = False
self._flags = []
self._flags = {}

def __setattr__(self, name, value):
"""Registers parameters."""
@@ -731,43 +723,39 @@ def _get_graph(self, *args):
return self._cached_graph

def _build_cache(self, *args):
data, out = self._get_graph(*args)
data_names = {data.name : i for i, data in enumerate(data)}
params = self.collect_params()
input_names = out.list_inputs()
inputs, out = self._get_graph(*args)
input_names = [i.name for i in inputs]

params = self.collect_params()
param_names = set(params.keys())
expected_names = set(input_names)
expected_names = set(out.list_inputs())
for name in expected_names:
assert name in param_names or name in data_names, \
assert name in param_names or name in input_names, \
"Unknown input to HybridBlock: %s"%name

used_data_names = [i for i in data_names if i in expected_names]
if len(used_data_names) != len(data_names):
unused = ', '.join(['%d-th'%i for name, i in data_names.items()
used_input_names = [i for i in input_names if i in expected_names]
if len(used_input_names) != len(input_names):
unused = ', '.join(['%d-th'%i for i, name in enumerate(input_names)
if name not in expected_names])
warnings.warn("The %s input to HybridBlock is not used by any "
"computation. Is this intended?"%unused, stacklevel=4)

used_param_names = [i for i in param_names if i in expected_names]
used_param_names = set(i for i in param_names if i in expected_names)
if len(used_param_names) != len(param_names):
unused = ', '.join(list(param_names - set(used_param_names)))
unused = ', '.join(list(param_names - used_param_names))
warnings.warn("Parameter %s is not used by any computation. "
"Is this intended?"%unused, stacklevel=4)

data_indices = []
param_indices = []
self._cached_op_args = []
for i, name in enumerate(input_names):
if name in data_names:
data_indices.append(i)
self._cached_op_args.append((True, data_names[name]))
else:
param_indices.append(i)
self._cached_op_args.append((False, params[name]))
flags = [('data_indices', data_indices), ('param_indices', param_indices)] + \
self._flags
self._cached_op = ndarray.CachedOp(out, flags)
used_params = {k: params[k] for k in used_param_names}
try:
param_dict = {k: v.list_data() for k, v in used_params.items()}
except DeferredInitializationError:
self._deferred_infer_shape(*args)
for i in used_params.values():
i._finish_deferred_init()
param_dict = {k: v.list_data() for k, v in used_params.items()}

self._cached_op = ndarray.CachedOp(out, self._flags, input_names, param_dict)

def _deferred_infer_shape(self, *args):
try:
@@ -783,19 +771,7 @@ def _call_cached_op(self, *args):

args, fmt = _flatten(args, "input")
assert fmt == self._in_format, "Invalid input format"
try:
cargs = [args[i] if is_arg else i.data()
for is_arg, i in self._cached_op_args]
except DeferredInitializationError:
self._deferred_infer_shape(*args)
cargs = []
for is_arg, i in self._cached_op_args:
if is_arg:
cargs.append(args[i])
else:
i._finish_deferred_init()
cargs.append(i.data())
out = self._cached_op(*cargs)
out = self._cached_op(*args)
if isinstance(out, NDArray):
out = [out]
return _regroup(out, self._out_format)[0]
@@ -816,7 +792,7 @@ def register_child(self, block, name=None):

def hybridize(self, active=True, **kwargs):
self._active = active
self._flags = list(kwargs.items())
self._flags = kwargs.items()
self._clear_cached_op()
if active and self._forward_hooks or self._forward_pre_hooks:
warnings.warn('"{}" is being hybridized while still having forward hook/pre-hook. '
26 changes: 22 additions & 4 deletions src/c_api/c_api_ndarray.cc
@@ -36,7 +36,6 @@
#include "../common/utils.h"
#include "../common/exec_utils.h"
#include "../imperative/imperative_utils.h"
#include "../imperative/cached_op.h"

using namespace mxnet;

@@ -161,15 +160,24 @@ int MXCreateCachedOp(SymbolHandle handle,
std::vector<std::string> input_names;
input_names.reserve(inputs.size());
for (const auto& i : inputs) input_names.push_back(i->attrs.name);
*out = new CachedOpPtr(new CachedOp(
*sym, std::vector<std::pair<std::string, std::string> >()));
*out = new std::shared_ptr<Imperative::CachedOp>(
new Imperative::CachedOp(
*sym,
std::vector<std::pair<std::string, std::string> >(),
input_names,
std::unordered_map<std::string, std::vector<NDArray> >()));
API_END();
}

int MXCreateCachedOpEx(SymbolHandle handle,
int num_flags,
const char** keys,
const char** vals,
int num_args,
const char** arg_names,
int num_params,
const char** param_names,
NDArrayHandle* params,
CachedOpHandle *out) {
nnvm::Symbol* sym = static_cast<nnvm::Symbol*>(handle);

@@ -178,7 +186,17 @@ int MXCreateCachedOpEx(SymbolHandle handle,
for (int i = 0; i < num_flags; ++i) {
flags.push_back({keys[i], vals[i]});
}
*out = new CachedOpPtr(new CachedOp(*sym, flags));
std::vector<std::string> args;
for (int i = 0; i < num_args; ++i) {
args.push_back(arg_names[i]);
}
std::unordered_map<std::string, std::vector<NDArray> > param_dict;
for (int i = 0; i < num_params; ++i) {
param_dict[param_names[i]].emplace_back(
*reinterpret_cast<NDArray*>(params[i]));
}
*out = new std::shared_ptr<Imperative::CachedOp>(
new Imperative::CachedOp(*sym, flags, args, param_dict));
API_END();
}

3 changes: 1 addition & 2 deletions src/engine/threaded_engine.cc
@@ -278,8 +278,6 @@ void ThreadedEngine::DeleteOperator(OprHandle op) {
}

void ThreadedEngine::Push(OprHandle op, Context exec_ctx, int priority, bool profiling) {
BulkFlush();

ThreadedOpr* threaded_opr = ThreadedOpr::CastFromBase(op);
OprBlock* opr_block = OprBlock::New();
opr_block->opr = threaded_opr;
@@ -325,6 +323,7 @@ void ThreadedEngine::PushAsync(AsyncFn fn, Context exec_ctx,
<< device_count_;
}
#endif
BulkFlush();
ThreadedOpr *opr = NewOperator(std::move(fn), const_vars, mutable_vars, prop, opr_name, wait);
opr->temporary = true;
const bool profiling = profiler_->IsProfiling(profiler::Profiler::kImperative);
165 changes: 77 additions & 88 deletions src/executor/attach_op_execs_pass.cc
@@ -134,10 +134,6 @@ class StatefulComputeExecutor : public StorageFallbackOpExecutor {
return state_.get_var();
}

OpStatePtr state() const override {
return state_;
}

explicit StatefulComputeExecutor(const OpStatePtr& state,
const FStatefulCompute& fcompute,
ExecType exec_type,
@@ -146,6 +142,7 @@ class StatefulComputeExecutor : public StorageFallbackOpExecutor {
state_(state), fcompute_(fcompute), exec_type_(exec_type) {}

private:
friend Graph AttachOpExecs(Graph g);
OpStatePtr state_;
FStatefulCompute fcompute_;
ExecType exec_type_;
@@ -173,16 +170,13 @@ class StatefulComputeExExecutor : public OpExecutor {
return state_.get_var();
}

OpStatePtr state() const override {
return state_;
}

explicit StatefulComputeExExecutor(const OpStatePtr& state,
const FStatefulComputeEx& fcompute,
ExecType exec_type)
: state_(state), fcompute_(fcompute), exec_type_(exec_type) {}

private:
friend Graph AttachOpExecs(Graph g);
OpStatePtr state_;
FStatefulComputeEx fcompute_;
ExecType exec_type_;
@@ -247,15 +241,16 @@ class FComputeExExecutor : public OpExecutor {
ExecType exec_type_;
};

void CreateOpExecs(const Graph& g, OpExecVector* p_ret, size_t i) {
// pass to attach operator executors
Graph AttachOpExecs(Graph g) {
using nnvm::DTypeVector;
using nnvm::ShapeVector;
using nnvm::FMutateInputs;

static auto& fcreate_op_state = nnvm::Op::GetAttr<FCreateOpState>("FCreateOpState");
static auto& fmutate_inputs = nnvm::Op::GetAttr<FMutateInputs>("FMutateInputs");
static auto& fexec_type = nnvm::Op::GetAttr<FExecType>("FExecType");
static auto& is_layer_backward = nnvm::Op::GetAttr<bool>("TIsLayerOpBackward");
auto& fcreate_op_state = nnvm::Op::GetAttr<FCreateOpState>("FCreateOpState");
auto& fmutate_inputs = nnvm::Op::GetAttr<FMutateInputs>("FMutateInputs");
auto& fexec_type = nnvm::Op::GetAttr<FExecType>("FExecType");
auto& is_layer_backward = nnvm::Op::GetAttr<bool>("TIsLayerOpBackward");

const auto& vdtype = g.GetAttr<DTypeVector>("dtype");
const auto& vshape = g.GetAttr<ShapeVector>("shape");
@@ -264,88 +259,82 @@ void CreateOpExecs(const Graph& g, OpExecVector* p_ret, size_t i) {

// get the graph
const auto& idx = g.indexed_graph();
OpExecVector& ret = *p_ret;
std::vector<std::shared_ptr<OpExecutor> > ret(idx.num_nodes());

// initialize the nodes
const auto& inode = idx[i];
if (inode.source->is_variable()) return;
const nnvm::Op *op = inode.source->op();
ExecType exec_type = ExecType::kSync;
std::vector<uint32_t> mutate_index;
if (fmutate_inputs.count(op)) {
mutate_index = fmutate_inputs[op](inode.source->attrs);
}
if (fexec_type.count(op)) {
exec_type = fexec_type[op](inode.source->attrs);
}
CHECK(dispatch_modes[i] != DispatchMode::kUndefined);
if (fcreate_op_state.count(op)) {
std::vector<TShape> ishape;
std::vector<int> itype;
for (const auto& e : inode.inputs) {
ishape.emplace_back(vshape[idx.entry_id(e)]);
itype.emplace_back(vdtype[idx.entry_id(e)]);
}

OpStatePtr state = fcreate_op_state[op](
inode.source->attrs, vctx[i], ishape, itype);
FStatefulComputeEx fcompute_ex = common::GetFCompute<FStatefulComputeEx>(
op, "FStatefulComputeEx", vctx[i]);
// FStatefulComputeEx is dispatched only when dispatch_mode is DispatchMode::kFComputeEx
if (fcompute_ex != nullptr && dispatch_modes[i] == DispatchMode::kFComputeEx) {
ret[i] = std::make_shared<StatefulComputeExExecutor>(state, fcompute_ex, exec_type);
} else {
FStatefulCompute fcompute = common::GetFCompute<FStatefulCompute>(
op, "FStatefulCompute", vctx[i]);
CHECK(fcompute != nullptr)
<< "One of FStatefulCompute and FStatefulComputeEx must be registered "
<< "for stateful operator " << op->name;
ret[i] = std::make_shared<StatefulComputeExecutor>(state, fcompute,
exec_type, mutate_index);
for (size_t i = 0; i < idx.num_nodes(); ++i) {
const auto& inode = idx[i];
if (inode.source->is_variable()) continue;
const nnvm::Op *op = inode.source->op();
ExecType exec_type = ExecType::kSync;
std::vector<uint32_t> mutate_index;
if (fmutate_inputs.count(op)) {
mutate_index = fmutate_inputs[op](inode.source->attrs);
}
} else if (is_layer_backward.get(op, false)) {
CHECK_GE(inode.control_deps.size(), 1);
uint32_t fwd_id = inode.control_deps[0];
CHECK(vctx[fwd_id] == vctx[i]);
CHECK(ret[fwd_id] != nullptr);
FStatefulComputeEx fcompute_ex = common::GetFCompute<FStatefulComputeEx>(
op, "FStatefulComputeEx", vctx[i]);
// FStatefulComputeEx is dispatched only when dispatch_mode is DispatchMode::kFComputeEx
if (fcompute_ex != nullptr && dispatch_modes[i] == DispatchMode::kFComputeEx) {
ret[i] = std::make_shared<StatefulComputeExExecutor>(
ret[fwd_id].get()->state(), fcompute_ex, exec_type);
} else {
FStatefulCompute fcompute = common::GetFCompute<FStatefulCompute>(
op, "FStatefulCompute", vctx[i]);
CHECK(fcompute != nullptr)
<< "One of FStatefulCompute and FStatefulComputeEx must be registered "
<< "for stateful operator " << op->name;
ret[i] = std::make_shared<StatefulComputeExecutor>(
ret[fwd_id].get()->state(), fcompute, exec_type, mutate_index);
if (fexec_type.count(op)) {
exec_type = fexec_type[op](inode.source->attrs);
}
} else {
FCompute fcompute = common::GetFCompute<FCompute>(op, "FCompute", vctx[i]);
FComputeEx fcomp_ex = common::GetFCompute<FComputeEx>(op, "FComputeEx", vctx[i]);
if (fcomp_ex != nullptr && dispatch_modes[i] == DispatchMode::kFComputeEx) {
ret[i] = std::make_shared<FComputeExExecutor>(
inode.source->attrs, fcomp_ex, exec_type);
} else if (fcompute != nullptr) {
ret[i] = std::make_shared<FComputeExecutor>(
inode.source->attrs, fcompute, exec_type, mutate_index);
CHECK(dispatch_modes[i] != DispatchMode::kUndefined);
if (fcreate_op_state.count(op)) {
std::vector<TShape> ishape;
std::vector<int> itype;
for (const auto& e : inode.inputs) {
ishape.emplace_back(vshape[idx.entry_id(e)]);
itype.emplace_back(vdtype[idx.entry_id(e)]);
}

OpStatePtr state = fcreate_op_state[op](
inode.source->attrs, vctx[i], ishape, itype);
FStatefulComputeEx fcompute_ex = common::GetFCompute<FStatefulComputeEx>(
op, "FStatefulComputeEx", vctx[i]);
// FStatefulComputeEx is dispatched only when dispatch_mode is DispatchMode::kFComputeEx
if (fcompute_ex != nullptr && dispatch_modes[i] == DispatchMode::kFComputeEx) {
ret[i] = std::make_shared<StatefulComputeExExecutor>(state, fcompute_ex, exec_type);
} else {
FStatefulCompute fcompute = common::GetFCompute<FStatefulCompute>(
op, "FStatefulCompute", vctx[i]);
CHECK(fcompute != nullptr)
<< "One of FStatefulCompute and FStatefulComputeEx must be registered "
<< "for stateful operator " << op->name;
ret[i] = std::make_shared<StatefulComputeExecutor>(state, fcompute,
exec_type, mutate_index);
}
} else if (is_layer_backward.get(op, false)) {
CHECK_GE(inode.control_deps.size(), 1);
uint32_t fwd_id = inode.control_deps[0];
CHECK(vctx[fwd_id] == vctx[i]);
CHECK(ret[fwd_id] != nullptr);
FStatefulComputeEx fcompute_ex = common::GetFCompute<FStatefulComputeEx>(
op, "FStatefulComputeEx", vctx[i]);
// FStatefulComputeEx is dispatched only when dispatch_mode is DispatchMode::kFComputeEx
if (fcompute_ex != nullptr && dispatch_modes[i] == DispatchMode::kFComputeEx) {
ret[i] = std::make_shared<StatefulComputeExExecutor>(
dynamic_cast<StatefulComputeExExecutor*>(ret[fwd_id].get())->state_,
fcompute_ex, exec_type);
} else {
FStatefulCompute fcompute = common::GetFCompute<FStatefulCompute>(
op, "FStatefulCompute", vctx[i]);
CHECK(fcompute != nullptr)
<< "One of FStatefulCompute and FStatefulComputeEx must be registered "
<< "for stateful operator " << op->name;
ret[i] = std::make_shared<StatefulComputeExecutor>(
dynamic_cast<StatefulComputeExecutor*>(ret[fwd_id].get())->state_,
fcompute, exec_type, mutate_index);
}
} else {
LOG(INFO) << "Neither FCompute nor FComputeEx registered " << op->name;
FCompute fcompute = common::GetFCompute<FCompute>(op, "FCompute", vctx[i]);
FComputeEx fcomp_ex = common::GetFCompute<FComputeEx>(op, "FComputeEx", vctx[i]);
if (fcomp_ex != nullptr && dispatch_modes[i] == DispatchMode::kFComputeEx) {
ret[i] = std::make_shared<FComputeExExecutor>(
inode.source->attrs, fcomp_ex, exec_type);
} else if (fcompute != nullptr) {
ret[i] = std::make_shared<FComputeExecutor>(
inode.source->attrs, fcompute, exec_type, mutate_index);
} else {
LOG(INFO) << "Neither FCompute nor FComputeEx registered " << op->name;
}
}
}
}


// pass to attach operator executors
Graph AttachOpExecs(Graph g) {
const auto& idx = g.indexed_graph();
OpExecVector ret(idx.num_nodes());
for (size_t i = 0; i < idx.num_nodes(); ++i) {
CreateOpExecs(g, &ret, i);
}
g.attrs["op_execs"] = std::make_shared<nnvm::any>(ret);
return g;
}
16 changes: 4 additions & 12 deletions src/executor/attach_op_resource_pass.cc
@@ -30,23 +30,20 @@
namespace mxnet {
namespace exec {

void AttachOpResources(
const Graph& g,
const OpExecVector& op_execs,
size_t start_nid,
size_t end_nid) {
Graph AttachOpResources(Graph g) {
static auto& fresource =
nnvm::Op::GetAttr<FResourceRequest>("FResourceRequest");
static auto& fresource_ex =
nnvm::Op::GetAttr<FResourceRequestEx>("FResourceRequestEx");
auto& op_execs = nnvm::get<OpExecVector>(*g.attrs.at("op_execs"));
const auto& vctx = g.GetAttr<ContextVector>("context");
const auto& vdispatch = g.GetAttr<DispatchModeVector>("dispatch_mode");
const auto& dev_masks = g.GetAttr<DevMaskVector>("dev_mask");
const auto& idx = g.indexed_graph();
// Use global resource pool for each executor for now.
std::map<Context, Resource> cached_temp;
// Resource allocation
for (uint32_t nid = start_nid; nid < end_nid; ++nid) {
for (uint32_t nid = 0; nid < idx.num_nodes(); ++nid) {
const auto& inode = idx[nid];
if (inode.source->is_variable()) continue;
const Context &ctx = vctx[nid];
@@ -87,12 +84,7 @@ void AttachOpResources(
requested.push_back(ResourceManager::Get()->Request(ctx, ResourceRequest::kTempSpace));
}
}
return g;
}

void AttachOpResources(const Graph& g) {
const auto& op_execs = g.GetAttr<OpExecVector>("op_execs");
AttachOpResources(g, op_execs, 0, g.indexed_graph().num_nodes());
}

} // namespace exec
} // namespace mxnet
28 changes: 4 additions & 24 deletions src/executor/exec_pass.h
@@ -82,10 +82,6 @@ class OpExecutor {
virtual engine::VarHandle var() const {
return nullptr;
}
/*! \return return operator state */
virtual OpStatePtr state() const {
return OpStatePtr();
}
};

/*!
@@ -106,14 +102,6 @@ using ContextVector = std::vector<Context>;
*/
using DevMaskVector = std::vector<int>;

/*!
* \brief create OpExecutor for a node in graph
*
* \param g input graph
* \param p_ret OpExecVector for input and output
* \param i the id of the node
*/
void CreateOpExecs(const Graph& g, OpExecVector* p_ret, size_t i);
/*!
* \brief Attach OpExecutor to the graph attributes.
*
@@ -127,20 +115,12 @@ Graph AttachOpExecs(Graph g);
* \brief Attach Resource to the OpExecVector of the graph.
*
* \param g input graph need to contain op_exec attribute.
*/
void AttachOpResources(const Graph& g);
/*!
* \brief Attach Resource to the OpExecVector
*
* \param g input graph
* \param op_execs OpExecutor vector
* \param start_nid starting node id
* \param end_nid end node id
* \return graph with new attribute "op_exec" of type OpExecVector
* The fields on the OpExecVector are not yet been setup.
*/
void AttachOpResources(const Graph& g,
const OpExecVector& op_execs,
size_t start_nid,
size_t end_nid);
Graph AttachOpResources(Graph g);

/*!
* \brief Discover chance of inplace addto operators.
* i.e. z = plus(z, source_op), and encourage it to become z += source_op.
2 changes: 1 addition & 1 deletion src/executor/graph_executor.cc
@@ -912,7 +912,7 @@ void GraphExecutor::FinishInitGraph(nnvm::Symbol symbol,
}

g = AttachOpExecs(g);
AttachOpResources(g);
g = AttachOpResources(g);
graph_ = std::move(g);

if (shared_exec != nullptr) {
750 changes: 144 additions & 606 deletions src/imperative/cached_op.cc

Large diffs are not rendered by default.

174 changes: 0 additions & 174 deletions src/imperative/cached_op.h

This file was deleted.

90 changes: 89 additions & 1 deletion src/imperative/imperative.cc
@@ -19,7 +19,6 @@
#include <unordered_set>
#include <iostream>
#include "./imperative_utils.h"
#include "./cached_op.h"

namespace mxnet {
#if DMLC_CXX11_THREAD_LOCAL
@@ -267,6 +266,95 @@ void Imperative::RecordOp(
}
}

void Imperative::RunGraph(
const bool retain_graph,
const nnvm::IndexedGraph& idx,
const std::vector<NDArray*> arrays,
size_t node_start, size_t node_end,
std::vector<OpReqType>&& array_reqs,
std::vector<uint32_t>&& ref_count,
std::vector<OpStatePtr> *p_states,
const DispatchModeVector &dispatch_modes) {
using namespace nnvm;
using namespace imperative;
static auto& createop = nnvm::Op::GetAttr<FCreateOpState>("FCreateOpState");
static auto& is_layer_backward = Op::GetAttr<bool>("TIsLayerOpBackward");
static const auto bwd_cached_op = Op::Get("_backward_CachedOp");

std::vector<OpStatePtr>& states = *p_states;
bool recording = is_recording();

std::vector<NDArray*> ndinputs, ndoutputs;
ShapeVector arg_shapes;
DTypeVector arg_dtypes;
std::vector<OpReqType> req;

for (size_t i = node_start; i < node_end; ++i) {
const nnvm::IndexedGraph::Node& node = idx[i];
if (node.source->op() == nullptr) continue;
auto num_outputs = node.source->num_outputs();
ndinputs.clear();
ndinputs.reserve(node.inputs.size());
for (const auto& j : node.inputs) {
ndinputs.emplace_back(arrays[idx.entry_id(j)]);
CHECK(!ndinputs.back()->is_none()) << idx[j.node_id].source->attrs.name << " " << j.index;
}
ndoutputs.clear();
ndoutputs.reserve(num_outputs);
req.clear();
req.reserve(num_outputs);
for (size_t j = 0; j < num_outputs; ++j) {
size_t eid = idx.entry_id(i, j);
ndoutputs.emplace_back(arrays[eid]);
req.push_back(array_reqs[eid]);
CHECK(!ndoutputs.back()->is_none());
}
const Context& ctx = ndoutputs[0]->ctx();
const DispatchMode dispatch_mode = dispatch_modes[i];
if (node.source->op() == bwd_cached_op) {
const auto& cached_op = dmlc::get<CachedOpPtr>(node.source->attrs.parsed);
nnvm::Node* fwd_node = node.source->control_deps[0].get();
auto fwd_node_id = idx.node_id(fwd_node);
cached_op->Backward(retain_graph, states[fwd_node_id], ndinputs, req, ndoutputs);
} else if (createop.count(node.source->op())) {
arg_shapes.clear();
arg_dtypes.clear();
arg_shapes.reserve(ndinputs.size());
arg_dtypes.reserve(ndinputs.size());
for (size_t i = 0; i < ndinputs.size(); ++i) {
arg_shapes.emplace_back(ndinputs[i]->shape());
arg_dtypes.emplace_back(ndinputs[i]->dtype());
}
states[i] = createop[node.source->op()](
node.source->attrs, ctx, arg_shapes, arg_dtypes);
InvokeOp(ctx, node.source->attrs, ndinputs, ndoutputs, req, dispatch_mode, states[i]);
if (recording) RecordOp(NodeAttrs(node.source->attrs), ndinputs, ndoutputs, states[i]);
} else if (is_layer_backward.get(node.source->op(), false)) {
nnvm::Node* fwd_node = node.source->control_deps[0].get();
auto fwd_node_id = idx.node_id(fwd_node);
InvokeOp(ctx, node.source->attrs, ndinputs, ndoutputs,
req, dispatch_mode, states[fwd_node_id]);
if (recording) {
RecordOp(NodeAttrs(node.source->attrs), ndinputs, ndoutputs, states[fwd_node_id]);
}
} else {
InvokeOp(ctx, node.source->attrs, ndinputs, ndoutputs, req, dispatch_mode);
if (recording) RecordOp(NodeAttrs(node.source->attrs), ndinputs, ndoutputs);
}

for (const auto& j : node.inputs) {
size_t eid = idx.entry_id(j);
--ref_count[eid];
if (ref_count[eid] == 0) arrays[eid]->ptr_.reset();
}
for (size_t j = 0; j < ndoutputs.size(); ++j) {
size_t eid = idx.entry_id(i, j);
if (ref_count[eid] == 0) arrays[eid]->ptr_.reset();
}
}
}


std::vector<NDArray*> Imperative::Backward(
const std::vector<NDArray*>& outputs,
const std::vector<NDArray*>& ograds,
120 changes: 0 additions & 120 deletions src/imperative/imperative_utils.cc

This file was deleted.

256 changes: 33 additions & 223 deletions src/imperative/imperative_utils.h
@@ -23,7 +23,6 @@
#include <utility>
#include <algorithm>
#include <vector>
#include <map>
#include <string>
#include "../executor/graph_executor.h"
#include "../executor/exec_pass.h"
@@ -39,24 +38,11 @@ namespace mxnet {
namespace imperative {

struct MemoryPlanInfo {
int storage_id;
uint32_t root;
uint32_t sid;
size_t size;
bool inplace;
};

struct EngineOprDeleter {
void operator()(engine::Opr* handle) {
Engine::Get()->DeleteOperator(handle);
}
};

struct EngineOprSeg {
bool skip;
size_t next_nid;
std::unique_ptr<engine::Opr, EngineOprDeleter> opr;
};

using MemoryPlanVector = std::vector<MemoryPlanInfo>;

inline Context GetContext(const nnvm::NodeAttrs& attrs,
@@ -729,12 +715,10 @@ inline std::vector<Context> PlaceDevice(const nnvm::IndexedGraph& idx) {


inline MemoryPlanVector PlanMemory(
nnvm::Graph* p_g,
nnvm::StorageVector&& storage,
nnvm::Graph* p_g, nnvm::StorageVector&& storage,
const std::vector<uint32_t>& ref_count,
const std::pair<uint32_t, uint32_t>& node_range = {0, 0},
const std::pair<uint32_t, uint32_t>& entry_range = {0, 0},
bool detect_inplace_addto = false) {
const std::pair<uint32_t, uint32_t>& entry_range = {0, 0}) {
using namespace nnvm;
nnvm::Graph& g = *p_g;
const auto& idx = g.indexed_graph();
@@ -744,31 +728,31 @@ inline MemoryPlanVector PlanMemory(
g.attrs["ref_count"] = std::make_shared<dmlc::any>(ref_count);
g.attrs["storage"] = std::make_shared<dmlc::any>(std::move(storage));
g = nnvm::ApplyPass(g, "PlanMemory");
if (detect_inplace_addto) g = exec::DetectInplaceAddTo(g);

const auto& dtypes = g.GetAttr<DTypeVector>("dtype");
const auto& shapes = g.GetAttr<ShapeVector>("shape");
const auto& storage_inplace = g.GetAttr<std::vector<int> >("storage_inplace_index");
const auto& storage_ids = g.GetAttr<StorageVector>("storage_id");
const auto& stypes = g.GetAttr<StorageTypeVector>("storage_type");
auto storage_ids = g.MoveCopyAttr<StorageVector>("storage_id");
auto storage_inplace = g.MoveCopyAttr<std::vector<int> >("storage_inplace_index");
uint32_t entry_start = entry_range.first;
uint32_t entry_end =
entry_range.second > entry_start ? entry_range.second : idx.num_node_entries();
MemoryPlanVector mem_plan(idx.num_node_entries());
std::unordered_map<int, uint32_t> sid_to_root;
std::unordered_map<int, uint32_t> sid_to_loc;

for (uint32_t i = entry_start; i < entry_end; ++i) {
if (stypes[i] != kDefaultStorage) continue;
if (storage_ids[i] < 0) {
mem_plan[i] = {storage_ids[i], i, 0, false};
} else if (!sid_to_root.count(storage_ids[i])) {
mem_plan[i] = {i, mshadow::mshadow_sizeof(dtypes[i]) * shapes[i].Size(), false};
} else if (!sid_to_loc.count(storage_ids[i])) {
CHECK_LT(storage_inplace[i], 0);
sid_to_root[storage_ids[i]] = i;
mem_plan[i] = {storage_ids[i], i,
mshadow::mshadow_sizeof(dtypes[i]) * shapes[i].Size(),
false};
sid_to_loc[storage_ids[i]] = i;
mem_plan[i].sid = i;
mem_plan[i].size = mshadow::mshadow_sizeof(dtypes[i]) * shapes[i].Size();
} else {
uint32_t root = sid_to_root[storage_ids[i]];
mem_plan[i] = {storage_ids[i], root, 0, storage_inplace[i] >= 0};
mem_plan[root].size = std::max(mem_plan[root].size,
uint32_t loc = sid_to_loc[storage_ids[i]];
mem_plan[i] = {loc, 0, storage_inplace[i] >= 0};
mem_plan[loc].size = std::max(mem_plan[loc].size,
mshadow::mshadow_sizeof(dtypes[i]) * shapes[i].Size());
}
}
@@ -777,213 +761,39 @@ inline MemoryPlanVector PlanMemory(
}


inline std::multimap<size_t, NDArray> AllocateMemory(
const nnvm::Graph& g,
const nnvm::IndexedGraph& idx,
const Context& default_ctx,
const uint32_t entry_start, const uint32_t entry_end,
const MemoryPlanVector& mem_plan,
const std::vector<NDArray*>& arrays,
std::vector<OpReqType> *array_reqs,
std::multimap<size_t, NDArray>&& pool = std::multimap<size_t, NDArray>()) {
inline void AllocateMemory(const nnvm::Graph& g,
const nnvm::IndexedGraph& idx,
const Context& default_ctx,
const uint32_t entry_start, const uint32_t entry_end,
const MemoryPlanVector& mem_plan,
const std::vector<NDArray*>& arrays,
std::vector<OpReqType> *array_reqs) {
using namespace nnvm;
const auto& dtypes = g.GetAttr<DTypeVector>("dtype");
const auto& shapes = g.GetAttr<ShapeVector>("shape");
const auto& stypes = g.GetAttr<StorageTypeVector>("storage_type");

std::multimap<size_t, NDArray> new_pool;

for (uint32_t i = entry_start; i < entry_end; ++i) {
if (mem_plan[i].storage_id == exec::kExternalStorageID) continue;
CHECK(arrays[i]->is_none());
if (mem_plan[i].storage_id == exec::kDynamicStorageID) {
*arrays[i] = NDArray(static_cast<NDArrayStorageType>(stypes[i]),
shapes[i], default_ctx, true, dtypes[i]);
continue;
}
CHECK_EQ(stypes[i], kDefaultStorage);
if (mem_plan[i].root == i) {
CHECK_GT(mem_plan[i].size, 0);
auto iter = pool.lower_bound(mem_plan[i].size);
if (iter != pool.end()) {
*arrays[i] = iter->second.AsArray(shapes[i], dtypes[i]);
new_pool.insert(*iter);
pool.erase(iter);
} else {
if (!arrays[i]->is_none()) continue;
if (stypes[i] == kDefaultStorage) {
if (mem_plan[i].sid == i) {
CHECK_GT(mem_plan[i].size, 0);
NDArray buff(TShape({static_cast<nnvm::dim_t>(mem_plan[i].size)}),
default_ctx, true, mshadow::kUint8);
*arrays[i] = buff.AsArray(shapes[i], dtypes[i]);
new_pool.insert({mem_plan[i].size, buff});
}
} else {
CHECK_GE(mem_plan[mem_plan[i].root].storage_id, 0);
*arrays[i] = arrays[mem_plan[i].root]->AsArray(shapes[i], dtypes[i]);
if (mem_plan[i].inplace && array_reqs->at(i) == kWriteTo) {
array_reqs->at(i) = kWriteInplace;
}
}
}

return new_pool;
}

inline void SetupOpExec(
const nnvm::Graph& g,
size_t nid,
const std::shared_ptr<exec::OpExecutor>& exec,
const std::vector<NDArray*> arrays,
const std::vector<OpReqType> array_reqs) {
const auto& idx = g.indexed_graph();
const auto& inode = idx[nid];
CHECK_EQ(exec->in_array.size(), 0U);
CHECK_EQ(exec->out_array.size(), 0U);
for (const auto& e : inode.inputs) {
CHECK(!arrays[idx.entry_id(e)]->is_none()) << inode.source->attrs.name;
exec->in_array.push_back(*arrays[idx.entry_id(e)]);
}
for (uint32_t index = 0; index < inode.source->num_outputs(); ++index) {
uint32_t eid = idx.entry_id(nid, index);
CHECK(!arrays[eid]->is_none()) << inode.source->attrs.name;
exec->out_array.push_back(*arrays[eid]);
exec->req.push_back(array_reqs[eid]);
}

exec->Setup();
}

inline Engine::OprHandle CreateEngineOp(
const Context& default_ctx,
const std::vector<std::shared_ptr<exec::OpExecutor> >& execs) {
CHECK_GT(execs.size(), 0);
std::vector<Engine::VarHandle> use_vars, mutate_vars;

for (const auto& exec : execs) {
CHECK_GT(exec->out_array.size(), 0);
CHECK(execs.size() == 1 || exec->exec_type() == ExecType::kSync);

// the variables
for (const auto& nd : exec->in_array) {
use_vars.push_back(nd.var());
}
for (auto& r : exec->op_ctx.requested) {
mutate_vars.push_back(r.var);
}
for (auto& nd : exec->out_array) {
mutate_vars.push_back(nd.var());
}
if (exec->var() != nullptr) {
mutate_vars.push_back(exec->var());
}
}

// dedup vars
Engine::Get()->DeduplicateVarHandle(&use_vars, &mutate_vars);
bool is_gpu = default_ctx.dev_mask() == gpu::kDevMask;
bool is_async = execs.size() > 1 ? false : execs[0]->exec_type() == ExecType::kAsync;

auto exec_fun = [execs, is_async, is_gpu] (
RunContext ctx, Engine::CallbackOnComplete on_complete) {
if (is_async) {
execs[0]->op_ctx.async_on_complete = on_complete;
}
for (const auto& exec : execs) exec->Run(ctx, is_gpu);
// call on complete only if it is async op
if (!is_async) {
if (is_gpu) {
#if MXNET_USE_CUDA
// Wait GPU kernel to finish.
ctx.get_stream<gpu>()->Wait();
#else
LOG(FATAL) << MXNET_GPU_NOT_ENABLED_ERROR;
#endif
}
on_complete();
}
};

return Engine::Get()->NewOperator(
exec_fun, use_vars, mutate_vars, FnProperty::kNormal);
}

inline void CreateEngineOpSeg(
const nnvm::IndexedGraph& idx,
const Context default_ctx,
const size_t start_nid,
const size_t end_nid,
const size_t bulk_size,
const std::unordered_set<uint32_t>& excludes,
const std::vector<std::shared_ptr<exec::OpExecutor> >& execs,
const std::vector<int> skip_plus_node,
std::vector<EngineOprSeg> *opr_segs) {
size_t seg_start = start_nid;
std::vector<std::shared_ptr<exec::OpExecutor> > seg_execs;
for (size_t nid = start_nid; nid < end_nid; ++nid) {
const auto& node = idx[nid];
if (node.source->is_variable()) continue;
if (skip_plus_node.size() && skip_plus_node[nid]) continue;
auto& exec = execs[nid];
bool is_async = exec->exec_type() != ExecType::kSync;
bool valid = exec->out_array.size() > 0;

// Stop at async nodes and invalid node (due to input/output is not allocated)
bool stop = is_async || !valid || seg_execs.size() >= bulk_size;
for (size_t i = 0; i < node.inputs.size() && !stop; ++i) {
if (excludes.count(idx.entry_id(node.inputs[i]))) stop = true;
}
auto num_outputs = node.source->num_outputs();
for (size_t i = 0; i < num_outputs && !stop; ++i) {
if (excludes.count(idx.entry_id(nid, i))) stop = true;
}

// Create opr segment for previous nodes.
if (stop && nid > seg_start) {
auto& seg = (*opr_segs)[seg_start];
if (seg_execs.size()) {
seg = EngineOprSeg{false, nid};
seg.opr.reset(CreateEngineOp(default_ctx, seg_execs));
} else {
seg = EngineOprSeg{true, nid, nullptr};
*arrays[i] = arrays[mem_plan[i].sid]->AsArray(shapes[i], dtypes[i]);
if (mem_plan[i].inplace && array_reqs->at(i) == kWriteTo) {
array_reqs->at(i) = kWriteInplace;
}
}
seg_start = nid;
seg_execs.clear();
}

seg_execs.push_back(exec);

auto& seg = (*opr_segs)[nid];
if (is_async) {
seg = EngineOprSeg{false, nid + 1};
seg.opr.reset(CreateEngineOp(default_ctx, seg_execs));
seg_execs.clear();
seg_start = nid + 1;
} else if (!valid) {
seg = EngineOprSeg{false, nid + 1, nullptr};
seg_execs.clear();
seg_start = nid + 1;
}
}
// The last segment
if (end_nid > seg_start) {
auto& seg = (*opr_segs)[seg_start];
if (seg_execs.size()) {
seg = EngineOprSeg{false, end_nid};
seg.opr.reset(CreateEngineOp(default_ctx, seg_execs));
} else {
seg = EngineOprSeg{true, end_nid, nullptr};
*arrays[i] = NDArray(static_cast<NDArrayStorageType>(stypes[i]),
shapes[i], default_ctx, true, dtypes[i]);
}
}
}


void RunGraph(const bool retain_graph,
const nnvm::IndexedGraph& idx,
const std::vector<NDArray*> arrays,
size_t node_start, size_t node_end,
std::vector<OpReqType>&& array_reqs,
std::vector<uint32_t>&& ref_count,
std::vector<OpStatePtr> *p_states,
const DispatchModeVector &dispatch_modes);

} // namespace imperative
} // namespace mxnet

56 changes: 1 addition & 55 deletions tests/python/unittest/test_gluon.py
@@ -22,7 +22,6 @@
from mxnet.ndarray.ndarray import _STORAGE_TYPE_STR_TO_ID
from common import setup_module, with_seed, assertRaises, teardown
import numpy as np
from numpy.testing import assert_array_equal
from nose.tools import raises, assert_raises
from copy import deepcopy
import warnings
@@ -1125,6 +1124,7 @@ def test_hybrid_multi_context():
net.hybridize()
net(mx.nd.zeros((1, 3, 32, 32), ctx=mx.cpu(0))).asnumpy()


@with_seed()
def test_zero_grad():
data = mx.nd.random.uniform(shape=(3,3))
@@ -1137,60 +1137,6 @@ def test_zero_grad():
grad = net.collect_params()['test_zero_grad_weight'].grad()
assert_almost_equal(grad.asnumpy(), grad.asnumpy() * 0)

def check_hybrid_static_memory(**kwargs):
x = mx.nd.random.uniform(shape=(2, 3, 32, 32))
x.attach_grad()

net1 = gluon.model_zoo.vision.get_resnet(
1, 18, pretrained=True, prefix='net_', ctx=mx.context.current_context())
net2 = gluon.model_zoo.vision.get_resnet(
1, 18, pretrained=True, prefix='net_', ctx=mx.context.current_context())
net2.hybridize(**kwargs)
net1(x)
net2(x)

def test(net, x):
with mx.autograd.record():
y = net(x) + net(x)
y.backward()

grads = {k: v.grad() for k, v in net.collect_params().items() if v.grad_req != 'null'}

return y, grads

y1, grads1 = test(net1, x)
y2, grads2 = test(net2, x)

assert_almost_equal(y1.asnumpy(), y2.asnumpy(), rtol=1e-3, atol=1e-5)
for key in grads1:
assert_almost_equal(grads1[key].asnumpy(), grads2[key].asnumpy(), rtol=1e-3, atol=1e-5)

def test_hybrid_static_memory():
check_hybrid_static_memory()
check_hybrid_static_memory(static_alloc=True)
check_hybrid_static_memory(static_alloc=True, static_shape=True)

def check_hybrid_static_memory_switching(**kwargs):
net = gluon.model_zoo.vision.get_resnet(
1, 18, pretrained=True, ctx=mx.context.current_context())
net.hybridize(**kwargs)

x = mx.nd.random.uniform(shape=(4, 3, 32, 32))
net(x)
with mx.autograd.record():
y = net(x)
y.backward()
x = mx.nd.random.uniform(shape=(2, 3, 32, 32))
net(x)
with mx.autograd.record():
y = net(x)
y.backward()
mx.nd.waitall()

def test_hybrid_static_memory_switching():
check_hybrid_static_memory_switching()
check_hybrid_static_memory_switching(static_alloc=True)
check_hybrid_static_memory_switching(static_alloc=True, static_shape=True)

@with_seed()
def test_hook():
