From ed9b2c82aabd1650296540d5fa59ca4999bd4277 Mon Sep 17 00:00:00 2001 From: Valentin Churavy Date: Wed, 31 Aug 2016 01:08:46 +0900 Subject: [PATCH 1/8] add initial draft of the custom Operator interface --- src/_custom_impl.jl | 0 src/custom.jl | 88 +++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 88 insertions(+) create mode 100644 src/_custom_impl.jl create mode 100644 src/custom.jl diff --git a/src/_custom_impl.jl b/src/_custom_impl.jl new file mode 100644 index 000000000..e69de29bb diff --git a/src/custom.jl b/src/custom.jl new file mode 100644 index 000000000..78da4c808 --- /dev/null +++ b/src/custom.jl @@ -0,0 +1,88 @@ +""" +Custom and native operators written in Julia. +The interface is given by the abstract type [`Operator`](@ref). +""" +module Custom + +using Compat +import Compat: String + +import ..mx +import ..mx: NDArray + +export assign + +""" + Operator +""" +abstract Operator + +""" + forward(op, is_train, req, in_data, out_data, aux) + +Forward interface. Custom operators must override this. + +# Arguments: +* `is_train::Bool`: Whether we are in training +* `req::Vector{Symbol}`: How to assign to out_data. Can be :null, :write, :inplace, or :add. You can use `assign(dst, req, src)` to handle this +* `in_data::Vector{NDArray}` +* `out_data::Vector{NDArray}` +* `aux::Vector{NDArray}` +""" +function forward(op::Operator, is_train, req, in_data, out_data, aux) + throw(MethodError(forward, (op, is_train, req, in_data, out_data, aux))) +end + +""" +Backwards interface. Custom operators must override this. 
+""" +function backward(op::Operator, req, out_grad, in_data, out_data, in_grad, aux) + throw(MethodError(backward, (op, req, out_grad, in_data, out_data, in_grad, aux))) +end + +function assign(dst, req, src) + if req == :null + return nothing + elseif req == :write || req == :inplace + dst[:] = src + elseif req == :add + dst[:] += src + else + error("Unable to handle $req in assign.") + end + return nothing +end + +abstract CustomOpProp + +function needs_top_grad(self :: CustomOpProp) + return false +end + +function infer_shape(self :: CustomOpProp, in_shape) + return in_shape, [in_shape[1]], [] +end + +function list_outputs(self :: CustomOpProp) + return String["output"] +end + +function list_arguments(self :: CustomOpProp) + return String["data"] +end + +function list_auxiliary_states(self :: CustomOpProp) + return String[] +end + +function declare_backward_dependency(self :: CustomOpProp, out_grad, in_data, out_data) + deps = Int[] + if needs_top_grad(self) + append!(deps, out_grad) + end + append!(deps, in_data) + append!(deps, out_data) +end + +include("_impl_custom.jl") +end From 36b637db0ec9e27786c964e05ee32c5d115adfa6 Mon Sep 17 00:00:00 2001 From: Valentin Churavy Date: Wed, 31 Aug 2016 23:16:02 +0900 Subject: [PATCH 2/8] add infer_shape_entry --- src/_custom_impl.jl | 92 +++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 92 insertions(+) diff --git a/src/_custom_impl.jl b/src/_custom_impl.jl index e69de29bb..ec9ed18e8 100644 --- a/src/_custom_impl.jl +++ b/src/_custom_impl.jl @@ -0,0 +1,92 @@ +immutable CustomOpInfo + forward :: Ptr{Void} + backward :: Ptr{Void} + delete :: Ptr{Void} + p_forward :: Ptr{Void} + p_backward :: Ptr{Void} + p_delete :: Ptr{Void} +end + +immutable CustomOpPropInfo + list_arguments :: Ptr{Void} + list_outputs :: Ptr{Void} + infer_shape :: Ptr{Void} + declare_backward_dependency :: Ptr{Void} + create_operator :: Ptr{Void} + list_auxiliary_states :: Ptr{Void} + delete :: Ptr{Void} + p_list_arguments :: Ptr{Void} + 
p_list_outputs :: Ptr{Void} + p_infer_shape :: Ptr{Void} + p_declare_backward_dependency :: Ptr{Void} + p_create_operator :: Ptr{Void} + p_list_auxiliary_states :: Ptr{Void} + p_delete :: Ptr{Void} +end + +function infer_shape_entry(num_tensor, tensor_dims, tensor_shapes, payload) + try + op = unsafe_pointer_to_objref(payload) :: Operator n_in = length(list_arguments(op)) + n_out = length(list_outputs(op)) + n_aux = length(list_auxiliary_states()) + + @assert num_tensor == n_in + n_out + n_aux + + shapes = Vector{Cuint}[] + # copy and revert input shapes + for i in 1:n_in + # Get size of array and create julia arry + ndims = unsafe_load(tensor_dims, i) + shape = zeros(Cuint, ndims) + tshape = unsafe_load(tensor_shapes, i) + for j in 1:ndims + shape[j] = unsafe_load(tshapes, ndims-j + 1) + end + push!(shapes, shape) + end + + ret = infer_shape(op, shapes) + if length(ret) == 2 + ishapes, oshapes = ret + ashapes = Cuint[] + elseif lenght(ret) == 3 + ishapes, oshapes, ashapes = ret + else + error("infer_shape must return 2 or 3 lists.") + end + + @assert length(ishapes) == n_in + @assert length(oshapes) == n_out + @assert length(ashapes) == n_aux + + # We now have to reverse the arrays again + # We can't perform a inplace operation in case the arrays share memory + rshapes = Vector{Cuint} + for shape in ishapes + push!(rshapes, reverse(shape)) + end + for shape in oshapes + push!(rshapes, reverse(shape)) + end + for shape in ashapes + push!(rshapes, reverse(shape)) + end + + # link memory lifetime of rshapes and op + # TODO + + for i in 1:num_tensors + unsafe_store!(tensor_shapes, pointer(rshapes[i]), i) + unsafe_store!(tensor_dims, length(rshapes[i]), i) + end + catch error + println(STDERR, "Error in infer_shape: ") + showerror(STDERR, error) + return false + end + return true +end + +end + + From 220baf61253ede13a5dea6eccb66e25226497a34 Mon Sep 17 00:00:00 2001 From: Valentin Churavy Date: Thu, 1 Sep 2016 04:55:17 +0900 Subject: [PATCH 3/8] add more entry functions 
--- src/_custom_impl.jl | 114 +++++++++++++++++++++++++++++++++++++++++--- 1 file changed, 108 insertions(+), 6 deletions(-) diff --git a/src/_custom_impl.jl b/src/_custom_impl.jl index ec9ed18e8..bde63c940 100644 --- a/src/_custom_impl.jl +++ b/src/_custom_impl.jl @@ -22,11 +22,41 @@ immutable CustomOpPropInfo p_create_operator :: Ptr{Void} p_list_auxiliary_states :: Ptr{Void} p_delete :: Ptr{Void} + function CustomOpPropInfo(op :: CustomOpProp) + payload = pointer_from_objref(op) + c_infer_shape = cfunction(_infer_shape_entry, Bool, (Cint, Ptr{Ptr{Void}}, Ptr{Cint}, Ptr{Void})) + c_list_outputs = cfunction(_list_outputs_entry, Bool, (Ptr{Ptr{Ptr{Cchar}}}, Ptr{Void})) + c_list_arguments = cfunction(_list_arguments_entry, Bool, (Ptr{Ptr{Ptr{Cchar}}}, Ptr{Void})) + c_list_auxiliary_states = cfunction(_list_auxiliary_states_entry, Bool, (Ptr{Ptr{Ptr{Cchar}}}, Ptr{Void})) + c_declare_backward_dependency = cfunction(_declare_backward_dependency_entry, Bool, (Ptr{Cint}, Ptr{Cint}, Ptr{Cint}, Ptr{Cint}, Ptr{Ptr{Cint}}, Ptr{Void})) + c_delete = cfunction(_delete_entry, Void, (Ptr{Void},)) + + new(c_list_arguments, c_list_outputs, c_infer_shape, c_declare_backwards_dependency, c_list_auxiliary_states, c_delete, + payload, payload, payload, payload, payload, payload) + end +end + +const __prop_pinned_memory = WeakKeyDict{CustomOpProp, Vector{Any}}() +function _pin!(op :: CustomOpProp, x :: ANY) + xs = get(__prop_pinned_memory, op, Any[]) + push!(xs, x) + __prop_pinned_memory[op] = xs +end + +function _finalizer(op :: CustomOpProp) + if haskey(__prop_pinned_memory) + delete!(__prop_pinned_memory, op) + else +end + +function _delete_entry(payload :: Ptr{Void}) + # Figure out what to do here. 
This is to keep this part of the memory alive end -function infer_shape_entry(num_tensor, tensor_dims, tensor_shapes, payload) +function _infer_shape_entry(num_tensor, tensor_dims, tensor_shapes, payload) try - op = unsafe_pointer_to_objref(payload) :: Operator n_in = length(list_arguments(op)) + op = unsafe_pointer_to_objref(payload) :: CustomOpProp + n_in = length(list_arguments(op)) n_out = length(list_outputs(op)) n_aux = length(list_auxiliary_states()) @@ -72,21 +102,93 @@ function infer_shape_entry(num_tensor, tensor_dims, tensor_shapes, payload) push!(rshapes, reverse(shape)) end - # link memory lifetime of rshapes and op - # TODO + _pin!(op, rshapes) for i in 1:num_tensors unsafe_store!(tensor_shapes, pointer(rshapes[i]), i) unsafe_store!(tensor_dims, length(rshapes[i]), i) end - catch error + catch err println(STDERR, "Error in infer_shape: ") - showerror(STDERR, error) + showerror(STDERR, err) return false end return true end +function _list_arguments_entry(data :: Ptr{Ptr{Ptr{Cchar}}}, payload :: Ptr{Void}) + try + op = unsafe_pointer_to_objref(payload) :: CustomOpProp + arguments = list_arguments(op) + _pin!(op, arguments) + ptrs = Ptr{Cchar}[Base.unsafe_convert(Ptr{Cchar}, s) for s in arguments] + _pin!(op, ptrs) + push!(ptrs, C_NULL) + unsafe_store!(data, pointer(ptrs), 1) + catch err + println(STDERR, "Error in list_arguments: ") + showerror(STDERR, err) + return false + end + return true end +function _list_outputs_entry(data :: Ptr{Ptr{Ptr{Cchar}}}, payload :: Ptr{Void}) + try + op = unsafe_pointer_to_objref(payload) :: CustomOpProp + outputs = list_outputs(op) + _pin!(op, outputs) + ptrs = Ptr{Cchar}[Base.unsafe_convert(Ptr{Cchar}, s) for s in outputs] + _pin!(op, ptrs) + push!(ptrs, C_NULL) + unsafe_store!(data, pointer(ptrs), 1) + catch err + println(STDERR, "Error in list_outputs: ") + showerror(STDERR, err) + return false + end + return true +end + +function _list_auxiliary_states_entry(data :: Ptr{Ptr{Ptr{Cchar}}}, payload :: Ptr{Void}) + try 
+ op = unsafe_pointer_to_objref(payload) :: CustomOpProp + aux = list_auxiliary_states(op) + _pin!(op, aux) + ptrs = Ptr{Cchar}[Base.unsafe_convert(Ptr{Cchar}, s) for s in aux] + _pin!(op, ptrs) + push!(ptrs, C_NULL) + unsafe_store!(data, pointer(ptrs), 1) + catch err + println(STDERR, "Error in list_auxiliary_states: ") + showerror(STDERR, err) + return false + end + return true +end + +function _declare_backward_dependency(_out_grad :: Ptr{Cint}, + _in_data :: Ptr{Cint}, + _out_data :: Ptr{Cint} + num_dep :: Ptr{Cint}, + deps :: Ptr{Ptr{Cint}}, + payload :: Ptr{Void}) + try + op = unsafe_pointer_to_objref(payload) :: CustomOpProp + out_grad = unsafe_wrap(Array, _out_grad, length(list_outputs(op))) + in_data = unsafe_wrap(Array, _in_data, length(list_arguments(op))) + out_data = unsafe_wrap(Array, _out_data, length(list_outputs(op))) + + rdeps = convert(Vector{Cint}, declare_backward_dependency(op, out_grad, in_data, out_data)) + _pin!(op, rdeps) + + unsafe_store!(num_dep, length(rdeps), 1) + unsafe_store!(deps, pointer(rdeps), 1) + catch err + println(STDERR, "Error in declare_backward_dependency: ") + showerror(STDERR, err) + return false + end + return true +end From d264a9096dc71ab0dfafb053f997beb6650f89a8 Mon Sep 17 00:00:00 2001 From: Valentin Churavy Date: Thu, 1 Sep 2016 04:58:25 +0900 Subject: [PATCH 4/8] reorganize --- src/_custom_impl.jl | 186 -------------------------------------------- src/_custom_prop.jl | 185 +++++++++++++++++++++++++++++++++++++++++++ src/custom.jl | 3 +- 3 files changed, 187 insertions(+), 187 deletions(-) create mode 100644 src/_custom_prop.jl diff --git a/src/_custom_impl.jl b/src/_custom_impl.jl index bde63c940..7c6779cef 100644 --- a/src/_custom_impl.jl +++ b/src/_custom_impl.jl @@ -6,189 +6,3 @@ immutable CustomOpInfo p_backward :: Ptr{Void} p_delete :: Ptr{Void} end - -immutable CustomOpPropInfo - list_arguments :: Ptr{Void} - list_outputs :: Ptr{Void} - infer_shape :: Ptr{Void} - declare_backward_dependency :: Ptr{Void} - 
create_operator :: Ptr{Void} - list_auxiliary_states :: Ptr{Void} - delete :: Ptr{Void} - p_list_arguments :: Ptr{Void} - p_list_outputs :: Ptr{Void} - p_infer_shape :: Ptr{Void} - p_declare_backward_dependency :: Ptr{Void} - p_create_operator :: Ptr{Void} - p_list_auxiliary_states :: Ptr{Void} - p_delete :: Ptr{Void} - function CustomOpPropInfo(op :: CustomOpProp) - payload = pointer_from_objref(op) - c_infer_shape = cfunction(_infer_shape_entry, Bool, (Cint, Ptr{Ptr{Void}}, Ptr{Cint}, Ptr{Void})) - c_list_outputs = cfunction(_list_outputs_entry, Bool, (Ptr{Ptr{Ptr{Cchar}}}, Ptr{Void})) - c_list_arguments = cfunction(_list_arguments_entry, Bool, (Ptr{Ptr{Ptr{Cchar}}}, Ptr{Void})) - c_list_auxiliary_states = cfunction(_list_auxiliary_states_entry, Bool, (Ptr{Ptr{Ptr{Cchar}}}, Ptr{Void})) - c_declare_backward_dependency = cfunction(_declare_backward_dependency_entry, Bool, (Ptr{Cint}, Ptr{Cint}, Ptr{Cint}, Ptr{Cint}, Ptr{Ptr{Cint}}, Ptr{Void})) - c_delete = cfunction(_delete_entry, Void, (Ptr{Void},)) - - new(c_list_arguments, c_list_outputs, c_infer_shape, c_declare_backwards_dependency, c_list_auxiliary_states, c_delete, - payload, payload, payload, payload, payload, payload) - end -end - -const __prop_pinned_memory = WeakKeyDict{CustomOpProp, Vector{Any}}() -function _pin!(op :: CustomOpProp, x :: ANY) - xs = get(__prop_pinned_memory, op, Any[]) - push!(xs, x) - __prop_pinned_memory[op] = xs -end - -function _finalizer(op :: CustomOpProp) - if haskey(__prop_pinned_memory) - delete!(__prop_pinned_memory, op) - else -end - -function _delete_entry(payload :: Ptr{Void}) - # Figure out what to do here. 
This is to keep this part of the memory alive -end - -function _infer_shape_entry(num_tensor, tensor_dims, tensor_shapes, payload) - try - op = unsafe_pointer_to_objref(payload) :: CustomOpProp - n_in = length(list_arguments(op)) - n_out = length(list_outputs(op)) - n_aux = length(list_auxiliary_states()) - - @assert num_tensor == n_in + n_out + n_aux - - shapes = Vector{Cuint}[] - # copy and revert input shapes - for i in 1:n_in - # Get size of array and create julia arry - ndims = unsafe_load(tensor_dims, i) - shape = zeros(Cuint, ndims) - tshape = unsafe_load(tensor_shapes, i) - for j in 1:ndims - shape[j] = unsafe_load(tshapes, ndims-j + 1) - end - push!(shapes, shape) - end - - ret = infer_shape(op, shapes) - if length(ret) == 2 - ishapes, oshapes = ret - ashapes = Cuint[] - elseif lenght(ret) == 3 - ishapes, oshapes, ashapes = ret - else - error("infer_shape must return 2 or 3 lists.") - end - - @assert length(ishapes) == n_in - @assert length(oshapes) == n_out - @assert length(ashapes) == n_aux - - # We now have to reverse the arrays again - # We can't perform a inplace operation in case the arrays share memory - rshapes = Vector{Cuint} - for shape in ishapes - push!(rshapes, reverse(shape)) - end - for shape in oshapes - push!(rshapes, reverse(shape)) - end - for shape in ashapes - push!(rshapes, reverse(shape)) - end - - _pin!(op, rshapes) - - for i in 1:num_tensors - unsafe_store!(tensor_shapes, pointer(rshapes[i]), i) - unsafe_store!(tensor_dims, length(rshapes[i]), i) - end - catch err - println(STDERR, "Error in infer_shape: ") - showerror(STDERR, err) - return false - end - return true -end - -function _list_arguments_entry(data :: Ptr{Ptr{Ptr{Cchar}}}, payload :: Ptr{Void}) - try - op = unsafe_pointer_to_objref(payload) :: CustomOpProp - arguments = list_arguments(op) - _pin!(op, arguments) - ptrs = Ptr{Cchar}[Base.unsafe_convert(Ptr{Cchar}, s) for s in arguments] - _pin!(op, ptrs) - push!(ptrs, C_NULL) - unsafe_store!(data, pointer(ptrs), 1) - catch 
err - println(STDERR, "Error in list_arguments: ") - showerror(STDERR, err) - return false - end - return true -end - -function _list_outputs_entry(data :: Ptr{Ptr{Ptr{Cchar}}}, payload :: Ptr{Void}) - try - op = unsafe_pointer_to_objref(payload) :: CustomOpProp - outputs = list_outputs(op) - _pin!(op, outputs) - ptrs = Ptr{Cchar}[Base.unsafe_convert(Ptr{Cchar}, s) for s in outputs] - _pin!(op, ptrs) - push!(ptrs, C_NULL) - unsafe_store!(data, pointer(ptrs), 1) - catch err - println(STDERR, "Error in list_outputs: ") - showerror(STDERR, err) - return false - end - return true -end - -function _list_auxiliary_states_entry(data :: Ptr{Ptr{Ptr{Cchar}}}, payload :: Ptr{Void}) - try - op = unsafe_pointer_to_objref(payload) :: CustomOpProp - aux = list_auxiliary_states(op) - _pin!(op, aux) - ptrs = Ptr{Cchar}[Base.unsafe_convert(Ptr{Cchar}, s) for s in aux] - _pin!(op, ptrs) - push!(ptrs, C_NULL) - unsafe_store!(data, pointer(ptrs), 1) - catch err - println(STDERR, "Error in list_auxiliary_states: ") - showerror(STDERR, err) - return false - end - return true -end - -function _declare_backward_dependency(_out_grad :: Ptr{Cint}, - _in_data :: Ptr{Cint}, - _out_data :: Ptr{Cint} - num_dep :: Ptr{Cint}, - deps :: Ptr{Ptr{Cint}}, - payload :: Ptr{Void}) - try - op = unsafe_pointer_to_objref(payload) :: CustomOpProp - out_grad = unsafe_wrap(Array, _out_grad, length(list_outputs(op))) - in_data = unsafe_wrap(Array, _in_data, length(list_arguments(op))) - out_data = unsafe_wrap(Array, _out_data, length(list_outputs(op))) - - rdeps = convert(Vector{Cint}, declare_backward_dependency(op, out_grad, in_data, out_data)) - _pin!(op, rdeps) - - unsafe_store!(num_dep, length(rdeps), 1) - unsafe_store!(deps, pointer(rdeps), 1) - catch err - println(STDERR, "Error in declare_backward_dependency: ") - showerror(STDERR, err) - return false - end - return true -end - diff --git a/src/_custom_prop.jl b/src/_custom_prop.jl new file mode 100644 index 000000000..b31bcb3f3 --- /dev/null +++ 
b/src/_custom_prop.jl @@ -0,0 +1,185 @@ +immutable CustomOpPropInfo + list_arguments :: Ptr{Void} + list_outputs :: Ptr{Void} + infer_shape :: Ptr{Void} + declare_backward_dependency :: Ptr{Void} + create_operator :: Ptr{Void} + list_auxiliary_states :: Ptr{Void} + delete :: Ptr{Void} + p_list_arguments :: Ptr{Void} + p_list_outputs :: Ptr{Void} + p_infer_shape :: Ptr{Void} + p_declare_backward_dependency :: Ptr{Void} + p_create_operator :: Ptr{Void} + p_list_auxiliary_states :: Ptr{Void} + p_delete :: Ptr{Void} + function CustomOpPropInfo(op :: CustomOpProp) + payload = pointer_from_objref(op) + c_infer_shape = cfunction(_infer_shape_entry, Bool, (Cint, Ptr{Ptr{Void}}, Ptr{Cint}, Ptr{Void})) + c_list_outputs = cfunction(_list_outputs_entry, Bool, (Ptr{Ptr{Ptr{Cchar}}}, Ptr{Void})) + c_list_arguments = cfunction(_list_arguments_entry, Bool, (Ptr{Ptr{Ptr{Cchar}}}, Ptr{Void})) + c_list_auxiliary_states = cfunction(_list_auxiliary_states_entry, Bool, (Ptr{Ptr{Ptr{Cchar}}}, Ptr{Void})) + c_declare_backward_dependency = cfunction(_declare_backward_dependency_entry, Bool, (Ptr{Cint}, Ptr{Cint}, Ptr{Cint}, Ptr{Cint}, Ptr{Ptr{Cint}}, Ptr{Void})) + c_delete = cfunction(_delete_entry, Void, (Ptr{Void},)) + + new(c_list_arguments, c_list_outputs, c_infer_shape, c_declare_backward_dependency, C_NULL, c_list_auxiliary_states, c_delete, + payload, payload, payload, payload, payload, payload, payload) + end +end + +const __prop_pinned_memory = WeakKeyDict{CustomOpProp, Vector{Any}}() +function _pin!(op :: CustomOpProp, x :: ANY) + xs = get(__prop_pinned_memory, op, Any[]) + push!(xs, x) + __prop_pinned_memory[op] = xs +end + +function _finalizer(op :: CustomOpProp) + if haskey(__prop_pinned_memory) + delete!(__prop_pinned_memory, op) + else +end + +function _delete_entry(payload :: Ptr{Void}) + # Figure out what to do here. 
This is to keep this part of the memory alive +end + +function _infer_shape_entry(num_tensor, tensor_dims, tensor_shapes, payload) + try + op = unsafe_pointer_to_objref(payload) :: CustomOpProp + n_in = length(list_arguments(op)) + n_out = length(list_outputs(op)) + n_aux = length(list_auxiliary_states(op)) + + @assert num_tensor == n_in + n_out + n_aux + + shapes = Vector{Cuint}[] + # copy and revert input shapes + for i in 1:n_in + # Get size of array and create julia array + ndims = unsafe_load(tensor_dims, i) + shape = zeros(Cuint, ndims) + tshape = unsafe_load(tensor_shapes, i) + for j in 1:ndims + shape[j] = unsafe_load(tshape, ndims-j + 1) + end + push!(shapes, shape) + end + + ret = infer_shape(op, shapes) + if length(ret) == 2 + ishapes, oshapes = ret + ashapes = Cuint[] + elseif length(ret) == 3 + ishapes, oshapes, ashapes = ret + else + error("infer_shape must return 2 or 3 lists.") + end + + @assert length(ishapes) == n_in + @assert length(oshapes) == n_out + @assert length(ashapes) == n_aux + + # We now have to reverse the arrays again + # We can't perform a inplace operation in case the arrays share memory + rshapes = Vector{Cuint}[] + for shape in ishapes + push!(rshapes, reverse(shape)) + end + for shape in oshapes + push!(rshapes, reverse(shape)) + end + for shape in ashapes + push!(rshapes, reverse(shape)) + end + + _pin!(op, rshapes) + + for i in 1:num_tensor + unsafe_store!(tensor_shapes, pointer(rshapes[i]), i) + unsafe_store!(tensor_dims, length(rshapes[i]), i) + end + catch err + println(STDERR, "Error in infer_shape: ") + showerror(STDERR, err) + return false + end + return true +end + +function _list_arguments_entry(data :: Ptr{Ptr{Ptr{Cchar}}}, payload :: Ptr{Void}) + try + op = unsafe_pointer_to_objref(payload) :: CustomOpProp + arguments = list_arguments(op) + _pin!(op, arguments) + ptrs = Ptr{Cchar}[Base.unsafe_convert(Ptr{Cchar}, s) for s in arguments] + _pin!(op, ptrs) + push!(ptrs, C_NULL) + unsafe_store!(data, pointer(ptrs), 1) + catch 
err + println(STDERR, "Error in list_arguments: ") + showerror(STDERR, err) + return false + end + return true +end + +function _list_outputs_entry(data :: Ptr{Ptr{Ptr{Cchar}}}, payload :: Ptr{Void}) + try + op = unsafe_pointer_to_objref(payload) :: CustomOpProp + outputs = list_outputs(op) + _pin!(op, outputs) + ptrs = Ptr{Cchar}[Base.unsafe_convert(Ptr{Cchar}, s) for s in outputs] + _pin!(op, ptrs) + push!(ptrs, C_NULL) + unsafe_store!(data, pointer(ptrs), 1) + catch err + println(STDERR, "Error in list_outputs: ") + showerror(STDERR, err) + return false + end + return true +end + +function _list_auxiliary_states_entry(data :: Ptr{Ptr{Ptr{Cchar}}}, payload :: Ptr{Void}) + try + op = unsafe_pointer_to_objref(payload) :: CustomOpProp + aux = list_auxiliary_states(op) + _pin!(op, aux) + ptrs = Ptr{Cchar}[Base.unsafe_convert(Ptr{Cchar}, s) for s in aux] + _pin!(op, ptrs) + push!(ptrs, C_NULL) + unsafe_store!(data, pointer(ptrs), 1) + catch err + println(STDERR, "Error in list_auxiliary_states: ") + showerror(STDERR, err) + return false + end + return true +end + +function _declare_backward_dependency_entry(_out_grad :: Ptr{Cint}, + _in_data :: Ptr{Cint}, + _out_data :: Ptr{Cint}, + num_dep :: Ptr{Cint}, + deps :: Ptr{Ptr{Cint}}, + payload :: Ptr{Void}) + try + op = unsafe_pointer_to_objref(payload) :: CustomOpProp + out_grad = unsafe_wrap(Array, _out_grad, length(list_outputs(op))) + in_data = unsafe_wrap(Array, _in_data, length(list_arguments(op))) + out_data = unsafe_wrap(Array, _out_data, length(list_outputs(op))) + + rdeps = convert(Vector{Cint}, declare_backward_dependency(op, out_grad, in_data, out_data)) + _pin!(op, rdeps) + + unsafe_store!(num_dep, length(rdeps), 1) + unsafe_store!(deps, pointer(rdeps), 1) + catch err + println(STDERR, "Error in declare_backward_dependency: ") + showerror(STDERR, err) + return false + end + return true +end + diff --git a/src/custom.jl b/src/custom.jl index 78da4c808..5384d78d4 100644 --- a/src/custom.jl +++ 
+84,6 @@ function declare_backward_dependency(self :: CustomOpProp, out_grad, in_data, ou append!(deps, out_data) end -include("_impl_custom.jl") +include("_custom_prop.jl") +include("_custom_impl.jl") end From 6d593addd227e9b51fff219d29ec655ccc8a1c62 Mon Sep 17 00:00:00 2001 From: Valentin Churavy Date: Fri, 2 Sep 2016 05:19:36 +0900 Subject: [PATCH 5/8] add entry function for fb --- src/_custom_impl.jl | 39 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 39 insertions(+) diff --git a/src/_custom_impl.jl b/src/_custom_impl.jl index 7c6779cef..af1a49a36 100644 --- a/src/_custom_impl.jl +++ b/src/_custom_impl.jl @@ -1,3 +1,5 @@ +import RawMutex + immutable CustomOpInfo forward :: Ptr{Void} backward :: Ptr{Void} @@ -6,3 +8,40 @@ immutable CustomOpInfo p_backward :: Ptr{Void} p_delete :: Ptr{Void} end + +## +# Forward and backward can be called from different threads in libmxnet and +# so we need to take special care in handling these callbacks correctly in +# the julia runtime. + +immutable _FB + handle :: Ptr{Void} + m_entry :: RawMutex.Mutex + size :: Cint + data :: Ptr{Ptr{Void}} + tags :: Ptr{Cint} +end + +_FB(handle :: Ptr{Void}, m_entry) = _FB(handle,m_entry, 0, 0, 0) +@assert isbits(_FB) + +# This function is called async and because the Julia runtime is not thread safe, we are +# very limited in the things we can do. Using a immutable that is a bitstype we can pass, +# return values to the handling tasks. +function _wrapper_fb(size :: Cint, data :: Ptr{Ptr{Void}}, tags :: Ptr{Cint}, payload :: Ptr{Void}) + # Load the libuv async handle + ptr = convert(Ptr{_FB}, payload) + handle = unsafe_load(ptr, 1).handle + m_entry = unsafe_load(ptr, 1).m_entry + + # lock the hard part + RawMutex.lock(m_entry) + + # Create result + val = _FB(handle, m_entry, b_exit, size, data, tags) + unsafe_store!(ptr, val, 1) + + ccall(:uv_async_send, Void, (Ptr{Void},), handle) + + return true # Better solution? 
+end From 1b0639a40db31ac54d9525be40b9e3bf30153ffe Mon Sep 17 00:00:00 2001 From: Valentin Churavy Date: Mon, 5 Sep 2016 06:54:15 +0900 Subject: [PATCH 6/8] add infrastructure to create tasks to handle the entry functions --- src/_custom_impl.jl | 54 +++++++++++++++++++++++++++++++++++++++++++++ src/_custom_prop.jl | 2 +- 2 files changed, 55 insertions(+), 1 deletion(-) diff --git a/src/_custom_impl.jl b/src/_custom_impl.jl index af1a49a36..9994b6e94 100644 --- a/src/_custom_impl.jl +++ b/src/_custom_impl.jl @@ -7,6 +7,26 @@ immutable CustomOpInfo p_forward :: Ptr{Void} p_backward :: Ptr{Void} p_delete :: Ptr{Void} + + function CustomOpInfo(op :: Operator) + c_wrapper_fb = cfunction(_wrapper_fb, Bool, (Cint, Ptr{Ptr{Void}}, Ptr{Cint}, Ptr{Void})) + p_f = _create_entry(op, _forward_entry) + p_b = _create_entry(op, _backward_entry) + new(c_wrapper_fb, c_wrapper_fb, C_NULL, p_f, p_b, C_NULL) + end +end + +const __op_pinned_memory = WeakKeyDict{Operator, Vector{Any}}() +function _pin!(op :: Operator, x :: ANY) + xs = get(__op_pinned_memory, op, Any[]) + push!(xs, x) + __op_pinned_memory[op] = xs +end + +function _finalizer(op :: Operator) + if haskey(__op_pinned_memory, op) + delete!(__op_pinned_memory, op) + end end ## @@ -45,3 +65,37 @@ function _wrapper_fb(size :: Cint, data :: Ptr{Ptr{Void}}, tags :: Ptr{Cint}, pa return true # Better solution? end + +function _forward_entry(op :: Operator, payload :: _FB) + info("Forward entry function") +end + +function _backward_entry(op :: Operator, payload :: _FB) + info("Backward entry function") +end + +function _create_entry(op:: Operator, _entry :: Function) + cond = Base.AsyncCondition() + m_entry = RawMutex.create_mutex() + + ref = Ref(_FB(Base.unsafe_convert(Ptr{Void}, cond), m_entry)) + ptr = Base.unsafe_convert(Ptr{Void}, ref) + + task = @schedule begin + try + while true + wait(cond) # Do we need to replace the AsyncCondition? 
+ _entry(op, ref[]) + RawMutex.unlock(m_entry) + end + catch err + @show err + rethrow() + finally + Base.close(cond) + RawMutex.close_mutex(m_entry) + end + end + _pin!(op, task) + return ptr +end diff --git a/src/_custom_prop.jl b/src/_custom_prop.jl index b31bcb3f3..a71d03492 100644 --- a/src/_custom_prop.jl +++ b/src/_custom_prop.jl @@ -37,7 +37,7 @@ end function _finalizer(op :: CustomOpProp) if haskey(__prop_pinned_memory) delete!(__prop_pinned_memory, op) - else + end end function _delete_entry(payload :: Ptr{Void}) From db480f119151b082a33cd877724aee00406b86fc Mon Sep 17 00:00:00 2001 From: Valentin Churavy Date: Tue, 6 Sep 2016 02:10:09 +0900 Subject: [PATCH 7/8] handle entry functions for backwards and forwards --- src/_custom_impl.jl | 60 +++++++++++++++++++++++++++++++++++++++++---- 1 file changed, 55 insertions(+), 5 deletions(-) diff --git a/src/_custom_impl.jl b/src/_custom_impl.jl index 9994b6e94..0cc7559de 100644 --- a/src/_custom_impl.jl +++ b/src/_custom_impl.jl @@ -37,18 +37,20 @@ end immutable _FB handle :: Ptr{Void} m_entry :: RawMutex.Mutex - size :: Cint - data :: Ptr{Ptr{Void}} + num_ndarray :: Cint + ndarries :: Ptr{Ptr{Void}} tags :: Ptr{Cint} + reqs :: Ptr{Cint} + is_train :: Bool end -_FB(handle :: Ptr{Void}, m_entry) = _FB(handle,m_entry, 0, 0, 0) +_FB(handle :: Ptr{Void}, m_entry) = _FB(handle,m_entry, 0, 0, 0, 0, 0) @assert isbits(_FB) # This function is called async and because the Julia runtime is not thread safe, we are # very limited in the things we can do. Using a immutable that is a bitstype we can pass, # return values to the handling tasks. 
-function _wrapper_fb(size :: Cint, data :: Ptr{Ptr{Void}}, tags :: Ptr{Cint}, payload :: Ptr{Void}) +function _wrapper_fb(num_ndarray :: Cint, ndarries :: Ptr{Ptr{Void}}, tags :: Ptr{Cint}, reqs :: Ptr{Cint}, is_train :: Bool, payload :: Ptr{Void}) # Load the libuv async handle ptr = convert(Ptr{_FB}, payload) handle = unsafe_load(ptr, 1).handle @@ -58,7 +60,7 @@ function _wrapper_fb(size :: Cint, data :: Ptr{Ptr{Void}}, tags :: Ptr{Cint}, pa RawMutex.lock(m_entry) # Create result - val = _FB(handle, m_entry, b_exit, size, data, tags) + val = _FB(handle, m_entry, num_ndarray, ndarries, tags, reqs, is_train) unsafe_store!(ptr, val, 1) ccall(:uv_async_send, Void, (Ptr{Void},), handle) @@ -68,10 +70,58 @@ end function _forward_entry(op :: Operator, payload :: _FB) info("Forward entry function") + num_ndarray = payload.num_ndarray + ndarries = unsafe_wrap(Array, payload.ndarries, num_ndarray) + tags = unsafe_wrap(Array, payload.tags, num_ndarray) + reqs = unsafe_wrap(Array, payload.reqs, num_ndarray) + is_train = payload.is_train + + in_data = NDArray[] + out_data = NDArray[] + aux = NDArray[] + for (ndarray, tag) in zip(ndarries, tags) + handle = mx.MX_NDArrayHandle(ndarray) + if tag == 1 || tag == 4 + tensors = tag == 1 ? 
out_data : aux + push!(tensors, NDArray(handle, true)) + elseif tag == 0 + push!(in_data, NDArray(handle, false)) + else + error("Received incorrect tag: $tag for handle $handle") + end + end + @show reqs + req = reqs # TODO: map to symbols + forward(op, is_train, req, in_data, out_data, aux) end function _backward_entry(op :: Operator, payload :: _FB) info("Backward entry function") + num_ndarray = payload.num_ndarray + ndarries = unsafe_wrap(Array, payload.ndarries, num_ndarray) + tags = unsafe_wrap(Array, payload.tags, num_ndarray) + reqs = unsafe_wrap(Array, payload.reqs, num_ndarray) + + in_data = NDArray[] + out_data = NDArray[] + in_grad = NDArray[] + out_grad = NDArray[] + aux = NDArray[] + for (ndarray, tag) in zip(ndarries, tags) + handle = mx.MX_NDArrayHandle(ndarray) + if tag == 2 || tag == 4 + tensors = tag == 2 ? in_grad : aux + push!(tensors, NDArray(handle, true)) + elseif tag == 0 || tag == 1 || tag == 3 + tensors = tag == 0 ? in_data : + tag == 1 ? out_data : out_grad + push!(tensors, NDArray(handle, false)) + else + error("Received incorrect tag: $tag for handle $handle") + end + end + req = copy(req) # TODO: convert + backward(op, req, out_grad, in_data, out_data, in_grad, aux) end function _create_entry(op:: Operator, _entry :: Function) From ec21986bc89e750712f1e43934ac662f7f49d45b Mon Sep 17 00:00:00 2001 From: Valentin Churavy Date: Tue, 6 Sep 2016 04:54:55 +0900 Subject: [PATCH 8/8] map values req to symbol --- src/_custom_impl.jl | 19 ++++++++++++++++--- 1 file changed, 16 insertions(+), 3 deletions(-) diff --git a/src/_custom_impl.jl b/src/_custom_impl.jl index 0cc7559de..e6fc7dd34 100644 --- a/src/_custom_impl.jl +++ b/src/_custom_impl.jl @@ -68,6 +68,20 @@ function _wrapper_fb(num_ndarray :: Cint, ndarries :: Ptr{Ptr{Void}}, tags :: Pt return true # Better solution? 
end +function _req_enum(req) + if req == 0 + return :null + elseif req == 1 + return :write + elseif req == 2 + return :inplace + elseif req == 3 + return :add + else + error("Don't know req value $req") + end +end + function _forward_entry(op :: Operator, payload :: _FB) info("Forward entry function") num_ndarray = payload.num_ndarray @@ -90,8 +104,7 @@ function _forward_entry(op :: Operator, payload :: _FB) error("Received incorrect tag: $tag for handle $handle") end end - @show reqs - req = reqs # TODO: map to symbols + req = map(_req_enum, reqs) forward(op, is_train, req, in_data, out_data, aux) end @@ -120,7 +133,7 @@ function _backward_entry(op :: Operator, payload :: _FB) error("Received incorrect tag: $tag for handle $handle") end end - req = copy(req) # TODO: convert + req = map(_req_enum, reqs) backward(op, req, out_grad, in_data, out_data, in_grad, aux) end