[Preview] Custom ops #130

Closed · wants to merge 8 commits
src/_custom_impl.jl (new file: 164 additions, 0 deletions)
import RawMutex

immutable CustomOpInfo
forward :: Ptr{Void}
backward :: Ptr{Void}
delete :: Ptr{Void}
p_forward :: Ptr{Void}
p_backward :: Ptr{Void}
p_delete :: Ptr{Void}

function CustomOpInfo(op :: Operator)
c_wrapper_fb = cfunction(_wrapper_fb, Bool, (Cint, Ptr{Ptr{Void}}, Ptr{Cint}, Ptr{Cint}, Bool, Ptr{Void}))
p_f = _create_entry(op, _forward_entry)
p_b = _create_entry(op, _backward_entry)
new(c_wrapper_fb, c_wrapper_fb, C_NULL, p_f, p_b, C_NULL)
end
end

const __op_pinned_memory = WeakKeyDict{Operator, Vector{Any}}()
function _pin!(op :: Operator, x :: ANY)
xs = get(__op_pinned_memory, op, Any[])
push!(xs, x)
__op_pinned_memory[op] = xs
end

function _finalizer(op :: Operator)
if haskey(__op_pinned_memory, op)
delete!(__op_pinned_memory, op)
end
end

##
# Forward and backward can be called from different threads in libmxnet, so we
# need to take special care to handle these callbacks correctly in the Julia
# runtime.
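#
# A sketch of the mechanism as implemented below: the C callback `_wrapper_fb`
# may run on a foreign thread, so it must not allocate or call back into the
# Julia runtime. It only stores a bitstype `_FB` payload behind a raw pointer,
# takes a mutex, and signals a libuv async handle. The task spawned in
# `_create_entry` wakes up on the corresponding `Base.AsyncCondition`, runs the
# actual forward/backward entry on the Julia side, and then releases the mutex
# so libmxnet can deliver the next callback.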

immutable _FB
handle :: Ptr{Void}
m_entry :: RawMutex.Mutex
num_ndarray :: Cint
ndarrays :: Ptr{Ptr{Void}}
tags :: Ptr{Cint}
reqs :: Ptr{Cint}
is_train :: Bool
end

_FB(handle :: Ptr{Void}, m_entry) = _FB(handle, m_entry, 0, C_NULL, C_NULL, C_NULL, false)
@assert isbits(_FB)

# This function is called asynchronously, and because the Julia runtime is not
# thread-safe we are very limited in what we can do here. Using an immutable
# bits type we can pass values to, and return values from, the handling task.
function _wrapper_fb(num_ndarray :: Cint, ndarrays :: Ptr{Ptr{Void}}, tags :: Ptr{Cint}, reqs :: Ptr{Cint}, is_train :: Bool, payload :: Ptr{Void})
# Load the libuv async handle
ptr = convert(Ptr{_FB}, payload)
handle = unsafe_load(ptr, 1).handle
m_entry = unsafe_load(ptr, 1).m_entry

# lock the hard part
RawMutex.lock(m_entry)

# Create result
val = _FB(handle, m_entry, num_ndarray, ndarrays, tags, reqs, is_train)

unsafe_store!(ptr, val, 1)

ccall(:uv_async_send, Cint, (Ptr{Void},), handle)

return true # Better solution?
end

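# These request codes mirror MXNet's OpReqType enum:
# kNullOp = 0, kWriteTo = 1, kWriteInplace = 2, kAddTo = 3.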
function _req_enum(req)
if req == 0
return :null
elseif req == 1
return :write
elseif req == 2
return :inplace
elseif req == 3
return :add
else
error("Don't know req value $req")
end
end

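# The tag libmxnet passes for each NDArray identifies its role:
# 0 = in_data, 1 = out_data, 2 = in_grad, 3 = out_grad, 4 = aux.
# Arrays the operator is allowed to write to (out_data, in_grad, aux) are
# wrapped as writable NDArrays below; all others are wrapped read-only.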
function _forward_entry(op :: Operator, payload :: _FB)
info("Forward entry function")
num_ndarray = payload.num_ndarray
ndarrays = unsafe_wrap(Array, payload.ndarrays, num_ndarray)
tags = unsafe_wrap(Array, payload.tags, num_ndarray)
reqs = unsafe_wrap(Array, payload.reqs, num_ndarray)
is_train = payload.is_train

in_data = NDArray[]
out_data = NDArray[]
aux = NDArray[]
for (ndarray, tag) in zip(ndarrays, tags)
handle = mx.MX_NDArrayHandle(ndarray)
if tag == 1 || tag == 4
tensors = tag == 1 ? out_data : aux
push!(tensors, NDArray(handle, true))
elseif tag == 0
push!(in_data, NDArray(handle, false))
else
error("Received incorrect tag: $tag for handle $handle")
end
end
req = map(_req_enum, reqs)
forward(op, is_train, req, in_data, out_data, aux)
end

function _backward_entry(op :: Operator, payload :: _FB)
info("Backward entry function")
num_ndarray = payload.num_ndarray
ndarrays = unsafe_wrap(Array, payload.ndarrays, num_ndarray)
tags = unsafe_wrap(Array, payload.tags, num_ndarray)
reqs = unsafe_wrap(Array, payload.reqs, num_ndarray)

in_data = NDArray[]
out_data = NDArray[]
in_grad = NDArray[]
out_grad = NDArray[]
aux = NDArray[]
for (ndarray, tag) in zip(ndarrays, tags)
handle = mx.MX_NDArrayHandle(ndarray)
if tag == 2 || tag == 4
tensors = tag == 2 ? in_grad : aux
push!(tensors, NDArray(handle, true))
elseif tag == 0 || tag == 1 || tag == 3
tensors = tag == 0 ? in_data :
tag == 1 ? out_data : out_grad
push!(tensors, NDArray(handle, false))
else
error("Received incorrect tag: $tag for handle $handle")
end
end
req = map(_req_enum, reqs)
backward(op, req, in_data, out_data, in_grad, out_grad, aux)
end

function _create_entry(op :: Operator, _entry :: Function)
cond = Base.AsyncCondition()
m_entry = RawMutex.create_mutex()

ref = Ref(_FB(Base.unsafe_convert(Ptr{Void}, cond), m_entry))
ptr = Base.unsafe_convert(Ptr{Void}, ref)

task = @schedule begin
try
while true
wait(cond) # Do we need to replace the AsyncCondition?
_entry(op, ref[])
RawMutex.unlock(m_entry)
end
catch err
@show err
rethrow()
finally
Base.close(cond)
RawMutex.close_mutex(m_entry)
end
end
_pin!(op, task)
return ptr
end
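
For context, a user-defined operator under this API subtypes `Operator` and implements the `forward`/`backward` methods that the entry functions above dispatch to. A minimal sketch of an identity operator, assuming the `Operator` abstract type, the `forward`/`backward` generic functions, and `copy!` between NDArrays from MXNet.jl (the registration path is not part of this file):

# Hypothetical identity operator, for illustration only.
immutable IdentityOp <: Operator end

# Called as forward(op, is_train, req, in_data, out_data, aux); each entry of
# req is one of :null, :write, :inplace, :add (see _req_enum above).
function forward(op :: IdentityOp, is_train, req, in_data, out_data, aux)
    if req[1] != :null
        copy!(out_data[1], in_data[1]) # a real op would honor :add by accumulating
    end
end

# Called as backward(op, req, in_data, out_data, in_grad, out_grad, aux).
function backward(op :: IdentityOp, req, in_data, out_data, in_grad, out_grad, aux)
    if req[1] != :null
        copy!(in_grad[1], out_grad[1])
    end
end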
src/_custom_prop.jl (new file: 185 additions, 0 deletions)
immutable CustomOpPropInfo
list_arguments :: Ptr{Void}
list_outputs :: Ptr{Void}
infer_shape :: Ptr{Void}
declare_backward_dependency :: Ptr{Void}
create_operator :: Ptr{Void}
list_auxiliary_states :: Ptr{Void}
delete :: Ptr{Void}
p_list_arguments :: Ptr{Void}
p_list_outputs :: Ptr{Void}
p_infer_shape :: Ptr{Void}
p_declare_backward_dependency :: Ptr{Void}
p_create_operator :: Ptr{Void}
p_list_auxiliary_states :: Ptr{Void}
p_delete :: Ptr{Void}
function CustomOpPropInfo(op :: CustomOpProp)
payload = pointer_from_objref(op)
c_infer_shape = cfunction(_infer_shape_entry, Bool, (Cint, Ptr{Cint}, Ptr{Ptr{Cuint}}, Ptr{Void}))
c_list_outputs = cfunction(_list_outputs_entry, Bool, (Ptr{Ptr{Ptr{Cchar}}}, Ptr{Void}))
c_list_arguments = cfunction(_list_arguments_entry, Bool, (Ptr{Ptr{Ptr{Cchar}}}, Ptr{Void}))
c_list_auxiliary_states = cfunction(_list_auxiliary_states_entry, Bool, (Ptr{Ptr{Ptr{Cchar}}}, Ptr{Void}))
c_declare_backward_dependency = cfunction(_declare_backward_dependency_entry, Bool, (Ptr{Cint}, Ptr{Cint}, Ptr{Cint}, Ptr{Cint}, Ptr{Ptr{Cint}}, Ptr{Void}))
c_delete = cfunction(_delete_entry, Void, (Ptr{Void},))

new(c_list_arguments, c_list_outputs, c_infer_shape, c_declare_backward_dependency,
C_NULL, # create_operator: no cfunction for it is defined in this file
c_list_auxiliary_states, c_delete,
payload, payload, payload, payload, payload, payload, payload)
end
end
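# Note: this layout mirrors libmxnet's C struct CustomOpPropInfo, in which each
# callback pointer is paired with an opaque payload pointer that is handed back
# to the callback as its final argument.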

const __prop_pinned_memory = WeakKeyDict{CustomOpProp, Vector{Any}}()
function _pin!(op :: CustomOpProp, x :: ANY)
xs = get(__prop_pinned_memory, op, Any[])
push!(xs, x)
__prop_pinned_memory[op] = xs
end

function _finalizer(op :: CustomOpProp)
if haskey(__prop_pinned_memory, op)
delete!(__prop_pinned_memory, op)
end
end

function _delete_entry(payload :: Ptr{Void})
# Figure out what to do here. This is to keep this part of the memory alive
end

function _infer_shape_entry(num_tensor, tensor_dims, tensor_shapes, payload)
try
op = unsafe_pointer_to_objref(payload) :: CustomOpProp
n_in = length(list_arguments(op))
n_out = length(list_outputs(op))
n_aux = length(list_auxiliary_states(op))

@assert num_tensor == n_in + n_out + n_aux

shapes = Vector{Cuint}[]
# copy and reverse input shapes
for i in 1:n_in
# Get size of array and create a Julia array
ndims = unsafe_load(tensor_dims, i)
shape = zeros(Cuint, ndims)
tshape = unsafe_load(tensor_shapes, i)
for j in 1:ndims
shape[j] = unsafe_load(tshape, ndims - j + 1)
end
push!(shapes, shape)
end

ret = infer_shape(op, shapes)
if length(ret) == 2
ishapes, oshapes = ret
ashapes = Vector{Cuint}[]
elseif length(ret) == 3
ishapes, oshapes, ashapes = ret
else
error("infer_shape must return 2 or 3 lists.")
end

@assert length(ishapes) == n_in
@assert length(oshapes) == n_out
@assert length(ashapes) == n_aux

# We now have to reverse the arrays again.
# We can't do this in place, in case the arrays share memory.
rshapes = Vector{Cuint}[]
for shape in ishapes
push!(rshapes, reverse(shape))
end
for shape in oshapes
push!(rshapes, reverse(shape))
end
for shape in ashapes
push!(rshapes, reverse(shape))
end

_pin!(op, rshapes)

for i in 1:num_tensor
unsafe_store!(tensor_shapes, pointer(rshapes[i]), i)
unsafe_store!(tensor_dims, length(rshapes[i]), i)
end
catch err
println(STDERR, "Error in infer_shape: ")
showerror(STDERR, err)
return false
end
return true
end

function _list_arguments_entry(data :: Ptr{Ptr{Ptr{Cchar}}}, payload :: Ptr{Void})
try
op = unsafe_pointer_to_objref(payload) :: CustomOpProp
arguments = list_arguments(op)
_pin!(op, arguments)
ptrs = Ptr{Cchar}[Base.unsafe_convert(Ptr{Cchar}, s) for s in arguments]
_pin!(op, ptrs)
push!(ptrs, C_NULL)
unsafe_store!(data, pointer(ptrs), 1)
catch err
println(STDERR, "Error in list_arguments: ")
showerror(STDERR, err)
return false
end
return true
end

function _list_outputs_entry(data :: Ptr{Ptr{Ptr{Cchar}}}, payload :: Ptr{Void})
try
op = unsafe_pointer_to_objref(payload) :: CustomOpProp
outputs = list_outputs(op)
_pin!(op, outputs)
ptrs = Ptr{Cchar}[Base.unsafe_convert(Ptr{Cchar}, s) for s in outputs]
_pin!(op, ptrs)
push!(ptrs, C_NULL)
unsafe_store!(data, pointer(ptrs), 1)
catch err
println(STDERR, "Error in list_outputs: ")
showerror(STDERR, err)
return false
end
return true
end

function _list_auxiliary_states_entry(data :: Ptr{Ptr{Ptr{Cchar}}}, payload :: Ptr{Void})
try
op = unsafe_pointer_to_objref(payload) :: CustomOpProp
aux = list_auxiliary_states(op)
_pin!(op, aux)
ptrs = Ptr{Cchar}[Base.unsafe_convert(Ptr{Cchar}, s) for s in aux]
_pin!(op, ptrs)
push!(ptrs, C_NULL)
unsafe_store!(data, pointer(ptrs), 1)
catch err
println(STDERR, "Error in list_auxiliary_states: ")
showerror(STDERR, err)
return false
end
return true
end

function _declare_backward_dependency_entry(_out_grad :: Ptr{Cint},
_in_data :: Ptr{Cint},
_out_data :: Ptr{Cint},
num_dep :: Ptr{Cint},
deps :: Ptr{Ptr{Cint}},
payload :: Ptr{Void})
try
op = unsafe_pointer_to_objref(payload) :: CustomOpProp
out_grad = unsafe_wrap(Array, _out_grad, length(list_outputs(op)))
in_data = unsafe_wrap(Array, _in_data, length(list_arguments(op)))
out_data = unsafe_wrap(Array, _out_data, length(list_outputs(op)))

rdeps = convert(Vector{Cint}, declare_backward_dependency(op, out_grad, in_data, out_data))
_pin!(op, rdeps)

unsafe_store!(num_dep, length(rdeps), 1)
unsafe_store!(deps, pointer(rdeps), 1)
catch err
println(STDERR, "Error in declare_backward_dependency: ")
showerror(STDERR, err)
return false
end
return true
end
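
Correspondingly, a property type subtypes `CustomOpProp` and implements the methods these entry points call. A minimal sketch for the identity operator above, assuming the `CustomOpProp` abstract type from this PR (the `create_operator` path is not shown in this diff):

# Hypothetical property type for the identity operator sketched earlier.
immutable IdentityProp <: CustomOpProp end

list_arguments(prop :: IdentityProp) = String["data"]
list_outputs(prop :: IdentityProp) = String["output"]
list_auxiliary_states(prop :: IdentityProp) = String[]

# Receives one shape per input (already converted to Julia column-major order
# by _infer_shape_entry) and returns (input_shapes, output_shapes).
infer_shape(prop :: IdentityProp, shapes) = shapes, [shapes[1]]

# The identity backward pass only needs the output gradient, so only those
# dependency ids are returned.
declare_backward_dependency(prop :: IdentityProp, out_grad, in_data, out_data) = out_grad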
