diff --git a/src/_custom_impl.jl b/src/_custom_impl.jl
new file mode 100644
index 000000000..e6fc7dd34
--- /dev/null
+++ b/src/_custom_impl.jl
@@ -0,0 +1,167 @@
+import RawMutex
+
+immutable CustomOpInfo
+    forward :: Ptr{Void}
+    backward :: Ptr{Void}
+    delete :: Ptr{Void}
+    p_forward :: Ptr{Void}
+    p_backward :: Ptr{Void}
+    p_delete :: Ptr{Void}
+
+    function CustomOpInfo(op :: Operator)
+        # The cfunction signature must match _wrapper_fb below exactly.
+        # NOTE(review): libmxnet passes is_train as a C int — confirm Bool matches the ABI.
+        c_wrapper_fb = cfunction(_wrapper_fb, Bool, (Cint, Ptr{Ptr{Void}}, Ptr{Cint}, Ptr{Cint}, Bool, Ptr{Void}))
+        p_f = _create_entry(op, _forward_entry)
+        p_b = _create_entry(op, _backward_entry)
+        new(c_wrapper_fb, c_wrapper_fb, C_NULL, p_f, p_b, C_NULL)
+    end
+end
+
+# Pin julia objects so they stay alive while libmxnet holds raw pointers to them.
+const __op_pinned_memory = WeakKeyDict{Operator, Vector{Any}}()
+function _pin!(op :: Operator, x :: ANY)
+    xs = get(__op_pinned_memory, op, Any[])
+    push!(xs, x)
+    __op_pinned_memory[op] = xs
+end
+
+function _finalizer(op :: Operator)
+    if haskey(__op_pinned_memory, op)
+        delete!(__op_pinned_memory, op)
+    end
+end
+
+##
+# Forward and backward can be called from different threads in libmxnet and
+# so we need to take special care in handling these callbacks correctly in
+# the julia runtime.
+
+immutable _FB
+    handle :: Ptr{Void}
+    m_entry :: RawMutex.Mutex
+    num_ndarray :: Cint
+    ndarries :: Ptr{Ptr{Void}}
+    tags :: Ptr{Cint}
+    reqs :: Ptr{Cint}
+    is_train :: Bool
+end
+
+_FB(handle :: Ptr{Void}, m_entry) = _FB(handle, m_entry, 0, 0, 0, 0, 0)
+@assert isbits(_FB)
+
+# This function is called async and because the Julia runtime is not thread safe, we are
+# very limited in the things we can do. Using a immutable that is a bitstype we can pass,
+# return values to the handling tasks.
+function _wrapper_fb(num_ndarray :: Cint, ndarries :: Ptr{Ptr{Void}}, tags :: Ptr{Cint}, reqs :: Ptr{Cint}, is_train :: Bool, payload :: Ptr{Void})
+    # Load the libuv async handle
+    ptr = convert(Ptr{_FB}, payload)
+    handle = unsafe_load(ptr, 1).handle
+    m_entry = unsafe_load(ptr, 1).m_entry
+
+    # lock the hard part; the handler task unlocks after running the entry.
+    RawMutex.lock(m_entry)
+
+    # Publish the call arguments for the handler task through the shared _FB slot.
+    val = _FB(handle, m_entry, num_ndarray, ndarries, tags, reqs, is_train)
+    unsafe_store!(ptr, val, 1)
+
+    ccall(:uv_async_send, Void, (Ptr{Void},), handle)
+
+    return true # Better solution?
+end
+
+# Map libmxnet's integer write-request codes onto the symbols used by `assign`.
+function _req_enum(req)
+    if req == 0
+        return :null
+    elseif req == 1
+        return :write
+    elseif req == 2
+        return :inplace
+    elseif req == 3
+        return :add
+    else
+        error("Don't know req value $req")
+    end
+end
+
+function _forward_entry(op :: Operator, payload :: _FB)
+    info("Forward entry function")
+    num_ndarray = payload.num_ndarray
+    ndarries = unsafe_wrap(Array, payload.ndarries, num_ndarray)
+    tags = unsafe_wrap(Array, payload.tags, num_ndarray)
+    reqs = unsafe_wrap(Array, payload.reqs, num_ndarray)
+    is_train = payload.is_train
+
+    in_data = NDArray[]
+    out_data = NDArray[]
+    aux = NDArray[]
+    for (ndarray, tag) in zip(ndarries, tags)
+        handle = mx.MX_NDArrayHandle(ndarray)
+        # tag 0: input, tag 1: output (writable), tag 4: auxiliary (writable)
+        if tag == 1 || tag == 4
+            tensors = tag == 1 ? out_data : aux
+            push!(tensors, NDArray(handle, true))
+        elseif tag == 0
+            push!(in_data, NDArray(handle, false))
+        else
+            error("Received incorrect tag: $tag for handle $handle")
+        end
+    end
+    req = map(_req_enum, reqs)
+    forward(op, is_train, req, in_data, out_data, aux)
+end
+
+function _backward_entry(op :: Operator, payload :: _FB)
+    info("Backward entry function")
+    num_ndarray = payload.num_ndarray
+    ndarries = unsafe_wrap(Array, payload.ndarries, num_ndarray)
+    tags = unsafe_wrap(Array, payload.tags, num_ndarray)
+    reqs = unsafe_wrap(Array, payload.reqs, num_ndarray)
+
+    in_data = NDArray[]
+    out_data = NDArray[]
+    in_grad = NDArray[]
+    out_grad = NDArray[]
+    aux = NDArray[]
+    for (ndarray, tag) in zip(ndarries, tags)
+        handle = mx.MX_NDArrayHandle(ndarray)
+        # tag 2: input gradient, tag 4: auxiliary — writable; 0/1/3 are read-only.
+        if tag == 2 || tag == 4
+            tensors = tag == 2 ? in_grad : aux
+            push!(tensors, NDArray(handle, true))
+        elseif tag == 0 || tag == 1 || tag == 3
+            tensors = tag == 0 ? in_data :
+                      tag == 1 ? out_data : out_grad
+            push!(tensors, NDArray(handle, false))
+        else
+            error("Received incorrect tag: $tag for handle $handle")
+        end
+    end
+    req = map(_req_enum, reqs)
+    # Argument order matches the Custom.backward interface declared in custom.jl.
+    backward(op, req, out_grad, in_data, out_data, in_grad, aux)
+end
+
+# Set up a libuv async handle plus a handler task for one callback direction.
+# Returns a pointer to the shared _FB slot that _wrapper_fb communicates through.
+function _create_entry(op :: Operator, _entry :: Function)
+    cond = Base.AsyncCondition()
+    m_entry = RawMutex.create_mutex()
+
+    ref = Ref(_FB(Base.unsafe_convert(Ptr{Void}, cond), m_entry))
+    ptr = Base.unsafe_convert(Ptr{Void}, ref)
+
+    task = @schedule begin
+        try
+            while true
+                wait(cond) # Do we need to replace the AsyncCondition?
+                _entry(op, ref[])
+                RawMutex.unlock(m_entry)
+            end
+        catch err
+            @show err
+            rethrow()
+        finally
+            Base.close(cond)
+            RawMutex.close_mutex(m_entry)
+        end
+    end
+    _pin!(op, task)
+    return ptr
+end
diff --git a/src/_custom_prop.jl b/src/_custom_prop.jl
new file mode 100644
index 000000000..a71d03492
--- /dev/null
+++ b/src/_custom_prop.jl
@@ -0,0 +1,186 @@
+immutable CustomOpPropInfo
+    list_arguments :: Ptr{Void}
+    list_outputs :: Ptr{Void}
+    infer_shape :: Ptr{Void}
+    declare_backward_dependency :: Ptr{Void}
+    create_operator :: Ptr{Void}
+    list_auxiliary_states :: Ptr{Void}
+    delete :: Ptr{Void}
+    p_list_arguments :: Ptr{Void}
+    p_list_outputs :: Ptr{Void}
+    p_infer_shape :: Ptr{Void}
+    p_declare_backward_dependency :: Ptr{Void}
+    p_create_operator :: Ptr{Void}
+    p_list_auxiliary_states :: Ptr{Void}
+    p_delete :: Ptr{Void}
+    function CustomOpPropInfo(op :: CustomOpProp)
+        payload = pointer_from_objref(op)
+        c_infer_shape = cfunction(_infer_shape_entry, Bool, (Cint, Ptr{Ptr{Void}}, Ptr{Cint}, Ptr{Void}))
+        c_list_outputs = cfunction(_list_outputs_entry, Bool, (Ptr{Ptr{Ptr{Cchar}}}, Ptr{Void}))
+        c_list_arguments = cfunction(_list_arguments_entry, Bool, (Ptr{Ptr{Ptr{Cchar}}}, Ptr{Void}))
+        c_list_auxiliary_states = cfunction(_list_auxiliary_states_entry, Bool, (Ptr{Ptr{Ptr{Cchar}}}, Ptr{Void}))
+        c_declare_backward_dependency = cfunction(_declare_backward_dependency_entry, Bool, (Ptr{Cint}, Ptr{Cint}, Ptr{Cint}, Ptr{Cint}, Ptr{Ptr{Cint}}, Ptr{Void}))
+        c_delete = cfunction(_delete_entry, Void, (Ptr{Void},))
+
+        # create_operator is not implemented yet; the values follow the field order above.
+        new(c_list_arguments, c_list_outputs, c_infer_shape,
+            c_declare_backward_dependency, C_NULL, c_list_auxiliary_states, c_delete,
+            payload, payload, payload, payload, C_NULL, payload, payload)
+    end
+end
+
+const __prop_pinned_memory = WeakKeyDict{CustomOpProp, Vector{Any}}()
+function _pin!(op :: CustomOpProp, x :: ANY)
+    xs = get(__prop_pinned_memory, op, Any[])
+    push!(xs, x)
+    __prop_pinned_memory[op] = xs
+end
+
+function _finalizer(op :: CustomOpProp)
+    if haskey(__prop_pinned_memory, op)
+        delete!(__prop_pinned_memory, op)
+    end
+end
+
+function _delete_entry(payload :: Ptr{Void})
+    # Figure out what to do here. This is to keep this part of the memory alive
+end
+
+function _infer_shape_entry(num_tensor, tensor_dims, tensor_shapes, payload)
+    try
+        op = unsafe_pointer_to_objref(payload) :: CustomOpProp
+        n_in = length(list_arguments(op))
+        n_out = length(list_outputs(op))
+        n_aux = length(list_auxiliary_states(op))
+
+        @assert num_tensor == n_in + n_out + n_aux
+
+        shapes = Vector{Cuint}[]
+        # copy and revert input shapes (libmxnet uses row-major order)
+        for i in 1:n_in
+            # Get size of array and create julia arry
+            ndims = unsafe_load(tensor_dims, i)
+            shape = zeros(Cuint, ndims)
+            tshape = unsafe_load(tensor_shapes, i)
+            for j in 1:ndims
+                shape[j] = unsafe_load(tshape, ndims - j + 1)
+            end
+            push!(shapes, shape)
+        end
+
+        ret = infer_shape(op, shapes)
+        if length(ret) == 2
+            ishapes, oshapes = ret
+            ashapes = Cuint[]
+        elseif length(ret) == 3
+            ishapes, oshapes, ashapes = ret
+        else
+            error("infer_shape must return 2 or 3 lists.")
+        end
+
+        @assert length(ishapes) == n_in
+        @assert length(oshapes) == n_out
+        @assert length(ashapes) == n_aux
+
+        # We now have to reverse the arrays again
+        # We can't perform a inplace operation in case the arrays share memory
+        rshapes = Vector{Cuint}[]
+        for shape in ishapes
+            push!(rshapes, reverse(shape))
+        end
+        for shape in oshapes
+            push!(rshapes, reverse(shape))
+        end
+        for shape in ashapes
+            push!(rshapes, reverse(shape))
+        end
+
+        _pin!(op, rshapes)
+
+        for i in 1:num_tensor
+            unsafe_store!(tensor_shapes, pointer(rshapes[i]), i)
+            unsafe_store!(tensor_dims, length(rshapes[i]), i)
+        end
+    catch err
+        println(STDERR, "Error in infer_shape: ")
+        showerror(STDERR, err)
+        return false
+    end
+    return true
+end
+
+function _list_arguments_entry(data :: Ptr{Ptr{Ptr{Cchar}}}, payload :: Ptr{Void})
+    try
+        op = unsafe_pointer_to_objref(payload) :: CustomOpProp
+        arguments = list_arguments(op)
+        _pin!(op, arguments)
+        ptrs = Ptr{Cchar}[Base.unsafe_convert(Ptr{Cchar}, s) for s in arguments]
+        _pin!(op, ptrs)
+        push!(ptrs, C_NULL)
+        unsafe_store!(data, pointer(ptrs), 1)
+    catch err
+        println(STDERR, "Error in list_arguments: ")
+        showerror(STDERR, err)
+        return false
+    end
+    return true
+end
+
+function _list_outputs_entry(data :: Ptr{Ptr{Ptr{Cchar}}}, payload :: Ptr{Void})
+    try
+        op = unsafe_pointer_to_objref(payload) :: CustomOpProp
+        outputs = list_outputs(op)
+        _pin!(op, outputs)
+        ptrs = Ptr{Cchar}[Base.unsafe_convert(Ptr{Cchar}, s) for s in outputs]
+        _pin!(op, ptrs)
+        push!(ptrs, C_NULL)
+        unsafe_store!(data, pointer(ptrs), 1)
+    catch err
+        println(STDERR, "Error in list_outputs: ")
+        showerror(STDERR, err)
+        return false
+    end
+    return true
+end
+
+function _list_auxiliary_states_entry(data :: Ptr{Ptr{Ptr{Cchar}}}, payload :: Ptr{Void})
+    try
+        op = unsafe_pointer_to_objref(payload) :: CustomOpProp
+        aux = list_auxiliary_states(op)
+        _pin!(op, aux)
+        ptrs = Ptr{Cchar}[Base.unsafe_convert(Ptr{Cchar}, s) for s in aux]
+        _pin!(op, ptrs)
+        push!(ptrs, C_NULL)
+        unsafe_store!(data, pointer(ptrs), 1)
+    catch err
+        println(STDERR, "Error in list_auxiliary_states: ")
+        showerror(STDERR, err)
+        return false
+    end
+    return true
+end
+
+function _declare_backward_dependency_entry(_out_grad :: Ptr{Cint},
+                                            _in_data :: Ptr{Cint},
+                                            _out_data :: Ptr{Cint},
+                                            num_dep :: Ptr{Cint},
+                                            deps :: Ptr{Ptr{Cint}},
+                                            payload :: Ptr{Void})
+    try
+        op = unsafe_pointer_to_objref(payload) :: CustomOpProp
+        out_grad = unsafe_wrap(Array, _out_grad, length(list_outputs(op)))
+        in_data = unsafe_wrap(Array, _in_data, length(list_arguments(op)))
+        out_data = unsafe_wrap(Array, _out_data, length(list_outputs(op)))
+
+        rdeps = convert(Vector{Cint}, declare_backward_dependency(op, out_grad, in_data, out_data))
+        _pin!(op, rdeps)
+
+        unsafe_store!(num_dep, length(rdeps), 1)
+        unsafe_store!(deps, pointer(rdeps), 1)
+    catch err
+        println(STDERR, "Error in declare_backward_dependency: ")
+        showerror(STDERR, err)
+        return false
+    end
+    return true
+end
diff --git a/src/custom.jl b/src/custom.jl
new file mode 100644
index 000000000..5384d78d4
--- /dev/null
+++ b/src/custom.jl
@@ -0,0 +1,99 @@
+"""
+Custom and native operators written in Julia.
+The interface is given by the abstract type [`Operator`](@ref).
+"""
+module Custom
+
+using Compat
+import Compat: String
+
+import ..mx
+import ..mx: NDArray
+
+export assign
+
+"""
+    Operator
+
+Base type for custom operators implemented in Julia.
+"""
+abstract Operator
+
+"""
+    forward(op, is_train, req, in_data, out_data, aux)
+
+Forward interface. Custom operators must override this.
+
+# Arguments:
+* `is_train::Bool`: Whether we are in training
+* `req::Vector{Symbol}`: How to assign to out_data. Can be :null, :write, :inplace, or :add. You can use `assign(dst, req, src)` to handle this
+* `in_data::Vector{NDArray}`
+* `out_data::Vector{NDArray}`
+* `aux::Vector{NDArray}`
+"""
+function forward(op::Operator, is_train, req, in_data, out_data, aux)
+    throw(MethodError(forward, (op, is_train, req, in_data, out_data, aux)))
+end
+
+"""
+    backward(op, req, out_grad, in_data, out_data, in_grad, aux)
+
+Backwards interface. Custom operators must override this.
+"""
+function backward(op::Operator, req, out_grad, in_data, out_data, in_grad, aux)
+    throw(MethodError(backward, (op, req, out_grad, in_data, out_data, in_grad, aux)))
+end
+
+"""
+    assign(dst, req, src)
+
+Assign `src` to `dst` according to the write request `req`
+(:null, :write, :inplace or :add).
+"""
+function assign(dst, req, src)
+    if req == :null
+        return nothing
+    elseif req == :write || req == :inplace
+        dst[:] = src
+    elseif req == :add
+        dst[:] += src
+    else
+        error("Unable to handle $req in assign.")
+    end
+    return nothing
+end
+
+abstract CustomOpProp
+
+function needs_top_grad(self :: CustomOpProp)
+    return false
+end
+
+function infer_shape(self :: CustomOpProp, in_shape)
+    return in_shape, [in_shape[1]], []
+end
+
+function list_outputs(self :: CustomOpProp)
+    return String["output"]
+end
+
+function list_arguments(self :: CustomOpProp)
+    return String["data"]
+end
+
+function list_auxiliary_states(self :: CustomOpProp)
+    return String[]
+end
+
+function declare_backward_dependency(self :: CustomOpProp, out_grad, in_data, out_data)
+    deps = Int[]
+    if needs_top_grad(self)
+        append!(deps, out_grad)
+    end
+    append!(deps, in_data)
+    append!(deps, out_data)
+    return deps
+end
+
+include("_custom_prop.jl")
+include("_custom_impl.jl")
+end