-
Notifications
You must be signed in to change notification settings - Fork 70
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Preview] Custom ops #130
Closed
Closed
[Preview] Custom ops #130
Changes from all commits
Commits
Show all changes
8 commits
Select commit
Hold shift + click to select a range
ed9b2c8
add initial draft of the custom Operator interface
vchuravy 36b637d
add infer_shape_entry
vchuravy 220baf6
add more entry functions
vchuravy d264a90
reorganize
vchuravy 6d593ad
add entry function for fb
vchuravy 1b0639a
add infrastructure to create tasks to handle the entry functions
vchuravy db480f1
handle entry functions for backwards and forwards
vchuravy ec21986
map values req to symbol
vchuravy File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,164 @@ | ||
import RawMutex | ||
|
||
# C-compatible callback table handed to libmxnet for a custom Operator.
# `forward`/`backward` both point at the shared C wrapper `_wrapper_fb`;
# the `p_*` fields carry the per-entry payload pointers produced by
# `_create_entry`, so the wrapper can locate the matching mailbox/mutex.
immutable CustomOpInfo
    forward    :: Ptr{Void}
    backward   :: Ptr{Void}
    delete     :: Ptr{Void}
    p_forward  :: Ptr{Void}
    p_backward :: Ptr{Void}
    p_delete   :: Ptr{Void}

    function CustomOpInfo(op :: Operator)
        # The argument-type tuple must match `_wrapper_fb`'s signature:
        # (num_ndarray, ndarries, tags, reqs, is_train, payload).
        # The previous tuple only had four entries and mismatched types.
        c_wrapper_fb = cfunction(_wrapper_fb, Bool,
                                 (Cint, Ptr{Ptr{Void}}, Ptr{Cint}, Ptr{Cint}, Bool, Ptr{Void}))
        p_f = _create_entry(op, _forward_entry)
        p_b = _create_entry(op, _backward_entry)
        # delete/p_delete are not implemented in this draft; pass C_NULL.
        new(c_wrapper_fb, c_wrapper_fb, C_NULL, p_f, p_b, C_NULL)
    end
end
|
||
# Keeps values reachable for as long as their Operator is alive, without
# leaking them once the operator itself is garbage collected.
# NOTE: the draft was missing `= WeakKeyDict` (compare the CustomOpProp
# version of this table), which made the line a malformed declaration.
const __op_pinned_memory = WeakKeyDict{Operator, Vector{Any}}()

# Pin `x` to the lifetime of `op` so raw pointers handed to C stay valid.
function _pin!(op :: Operator, x :: ANY)
    xs = get(__op_pinned_memory, op, Any[])
    push!(xs, x)
    __op_pinned_memory[op] = xs
end
|
||
# Finalizer for an Operator: drop everything pinned to `op`.
function _finalizer(op :: Operator)
    # `haskey` needs the key as second argument — the draft omitted `op`.
    if haskey(__op_pinned_memory, op)
        delete!(__op_pinned_memory, op)
    end
end
|
||
## | ||
# Forward and backward can be called from different threads in libmxnet and | ||
# so we need to take special care in handling these callbacks correctly in | ||
# the julia runtime. | ||
|
||
# Bits-type "mailbox" passed between the C callback (possibly running on a
# foreign thread) and the Julia handler task created in `_create_entry`.
# Because the Julia runtime is not thread safe, the callback may only write
# this immutable through a raw pointer and then signal libuv.
immutable _FB
    handle :: Ptr{Void}          # libuv async handle used to wake the handler task
    m_entry :: RawMutex.Mutex    # serializes callbacks until the task has finished
    num_ndarray :: Cint          # number of NDArray handles in `ndarries`
    ndarries :: Ptr{Ptr{Void}}   # raw NDArray handles (sic: "ndarries")
    tags :: Ptr{Cint}            # role tag per handle (input/output/grad/aux)
    reqs :: Ptr{Cint}            # write-request code per handle (see `_req_enum`)
    is_train :: Bool             # whether the pass runs in training mode
end

# Convenience constructor used before any callback data is available; the
# zeros convert to NULL pointers / `false` via the default constructor.
_FB(handle :: Ptr{Void}, m_entry) = _FB(handle,m_entry, 0, 0, 0, 0, 0)
# _FB must be a bits type so it can be stored through a raw pointer.
@assert isbits(_FB)
|
||
# This function is called async and because the Julia runtime is not thread safe, we are | ||
# very limited in the things we can do. Using a immutable that is a bitstype we can pass, | ||
# return values to the handling tasks. | ||
# This function is called async and because the Julia runtime is not thread
# safe, we are very limited in the things we can do: store the callback
# arguments into the preallocated `_FB` mailbox and wake the handler task.
# Returns `true` to signal success to libmxnet.
function _wrapper_fb(num_ndarray :: Cint, ndarries :: Ptr{Ptr{Void}}, tags :: Ptr{Cint}, reqs :: Ptr{Cint}, is_train :: Bool, payload :: Ptr{Void})
    # Load the libuv async handle and mutex stored by `_create_entry`
    ptr = convert(Ptr{_FB}, payload)
    handle = unsafe_load(ptr, 1).handle
    m_entry = unsafe_load(ptr, 1).m_entry

    # lock the hard part; the handler task unlocks after processing
    RawMutex.lock(m_entry)

    # Create result.  `_FB` has exactly these seven fields — the draft
    # passed an undefined extra `b_exit` argument here.
    val = _FB(handle, m_entry, num_ndarray, ndarries, tags, reqs, is_train)
    unsafe_store!(ptr, val, 1)

    # Wake the Julia task waiting on the AsyncCondition
    ccall(:uv_async_send, Void, (Ptr{Void},), handle)

    return true # Better solution?
end
|
||
# Translate libmxnet's integer write-request code into a Julia symbol.
# 0 = no write, 1 = overwrite, 2 = in-place overwrite, 3 = accumulate.
function _req_enum(req)
    req == 0 && return :null
    req == 1 && return :write
    req == 2 && return :inplace
    req == 3 && return :add
    error("Don't know req value $req")
end
|
||
# Handler-task side of the forward callback: reconstruct Julia arrays from
# the raw pointers captured in `payload` and dispatch to the user-defined
# `forward(op, ...)`.
function _forward_entry(op :: Operator, payload :: _FB)
    info("Forward entry function")
    num_ndarray = payload.num_ndarray
    ndarries = unsafe_wrap(Array, payload.ndarries, num_ndarray)
    tags = unsafe_wrap(Array, payload.tags, num_ndarray)
    reqs = unsafe_wrap(Array, payload.reqs, num_ndarray)
    is_train = payload.is_train

    in_data = NDArray[]
    out_data = NDArray[]
    aux = NDArray[]
    # tag semantics: 0 = input, 1 = output, 4 = auxiliary state
    for (ndarray, tag) in zip(ndarries, tags)
        handle = mx.MX_NDArrayHandle(ndarray)
        if tag == 1 || tag == 4
            # outputs and aux states are writable NDArrays
            tensors = tag == 1 ? out_data : aux
            push!(tensors, NDArray(handle, true))
        elseif tag == 0
            push!(in_data, NDArray(handle, false))
        else
            error("Received incorrect tag: $tag for handle $handle")
        end
    end
    # The draft called the undefined `_reg_enum`; the converter is `_req_enum`.
    req = map(_req_enum, reqs)
    forward(op, is_train, req, in_data, out_data, aux)
end
|
||
# Handler-task side of the backward callback: reconstruct Julia arrays from
# the raw pointers captured in `payload` and dispatch to the user-defined
# `backward(op, ...)`.
function _backward_entry(op :: Operator, payload :: _FB)
    info("Backward entry function")
    num_ndarray = payload.num_ndarray
    ndarries = unsafe_wrap(Array, payload.ndarries, num_ndarray)
    tags = unsafe_wrap(Array, payload.tags, num_ndarray)
    reqs = unsafe_wrap(Array, payload.reqs, num_ndarray)

    in_data = NDArray[]
    out_data = NDArray[]
    in_grad = NDArray[]
    out_grad = NDArray[]
    aux = NDArray[]
    # tag semantics: 0 = input, 1 = output, 2 = input gradient,
    # 3 = output gradient, 4 = auxiliary state
    for (ndarray, tag) in zip(ndarries, tags)
        handle = mx.MX_NDArrayHandle(ndarray)
        if tag == 2 || tag == 4
            # input gradients and aux states are writable NDArrays
            tensors = tag == 2 ? in_grad : aux
            push!(tensors, NDArray(handle, true))
        elseif tag == 0 || tag == 1 || tag == 3
            tensors = tag == 0 ? in_data :
                      tag == 1 ? out_data : out_grad
            push!(tensors, NDArray(handle, false))
        else
            error("Received incorrect tag: $tag for handle $handle")
        end
    end
    # The draft called the undefined `_reg_enum`; the converter is `_req_enum`.
    req = map(_req_enum, reqs)
    backward(op, req, in_data, out_data, in_grad, out_grad, aux)
end
|
||
# Allocate the async plumbing for one entry function (`_forward_entry` or
# `_backward_entry`): a libuv AsyncCondition the C callback can signal, a
# mutex serializing callbacks, and a long-lived task that services them.
# Returns a raw pointer to the `_FB` mailbox, to be used as the callback
# payload.
function _create_entry(op :: Operator, _entry :: Function)
    cond = Base.AsyncCondition()
    m_entry = RawMutex.create_mutex()

    # The Ref is kept alive by the closure below (and the task is pinned
    # to `op`), so `ptr` stays valid for the operator's lifetime.
    ref = Ref(_FB(Base.unsafe_convert(Ptr{Void}, cond), m_entry))
    ptr = Base.unsafe_convert(Ptr{Void}, ref)

    task = @schedule begin
        try
            while true
                wait(cond) # Do we need to replace the AsyncCondition?
                _entry(op, ref[])
                # release the lock taken by `_wrapper_fb`
                RawMutex.unlock(m_entry)
            end
        catch err
            @show err
            rethrow()
        finally
            Base.close(cond)
            RawMutex.close_mutex(m_entry)  # draft misspelled this `m_enrty`
        end
    end
    _pin!(op, task)
    return ptr
end
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,185 @@ | ||
# C-compatible callback table for a CustomOpProp: seven entry-point
# pointers followed by seven matching payload pointers.  The payload is
# the CustomOpProp object itself (as a raw pointer), recovered in each
# entry via `unsafe_pointer_to_objref`.
immutable CustomOpPropInfo
    list_arguments :: Ptr{Void}
    list_outputs :: Ptr{Void}
    infer_shape :: Ptr{Void}
    declare_backward_dependency :: Ptr{Void}
    create_operator :: Ptr{Void}
    list_auxiliary_states :: Ptr{Void}
    delete :: Ptr{Void}
    p_list_arguments :: Ptr{Void}
    p_list_outputs :: Ptr{Void}
    p_infer_shape :: Ptr{Void}
    p_declare_backward_dependency :: Ptr{Void}
    p_create_operator :: Ptr{Void}
    p_list_auxiliary_states :: Ptr{Void}
    p_delete :: Ptr{Void}

    function CustomOpPropInfo(op :: CustomOpProp)
        payload = pointer_from_objref(op)
        c_infer_shape = cfunction(_infer_shape_entry, Bool, (Cint, Ptr{Ptr{Void}}, Ptr{Cint}, Ptr{Void}))
        c_list_outputs = cfunction(_list_outputs_entry, Bool, (Ptr{Ptr{Ptr{Cchar}}}, Ptr{Void}))
        c_list_arguments = cfunction(_list_arguments_entry, Bool, (Ptr{Ptr{Ptr{Cchar}}}, Ptr{Void}))
        c_list_auxiliary_states = cfunction(_list_auxiliary_states_entry, Bool, (Ptr{Ptr{Ptr{Cchar}}}, Ptr{Void}))
        c_declare_backward_dependency = cfunction(_declare_backward_dependency_entry, Bool, (Ptr{Cint}, Ptr{Cint}, Ptr{Cint}, Ptr{Cint}, Ptr{Ptr{Cint}}, Ptr{Void}))
        c_delete = cfunction(_delete_entry, Void, (Ptr{Void},))

        # The draft passed 12 arguments for 14 fields and referenced the
        # misspelled `c_declare_backwards_dependency`.  `create_operator`
        # has no Julia entry yet, so its slot is C_NULL; all payload slots
        # receive the operator pointer.
        new(c_list_arguments, c_list_outputs, c_infer_shape,
            c_declare_backward_dependency, C_NULL, c_list_auxiliary_states, c_delete,
            payload, payload, payload, payload, payload, payload, payload)
    end
end
|
||
# Keeps values reachable for as long as their CustomOpProp is alive,
# without leaking them once the property object is garbage collected.
const __prop_pinned_memory = WeakKeyDict{CustomOpProp, Vector{Any}}()

# Pin `x` to the lifetime of `op` so raw pointers handed to C stay valid.
function _pin!(op :: CustomOpProp, x :: ANY)
    pinned = get(__prop_pinned_memory, op, Any[])
    push!(pinned, x)
    __prop_pinned_memory[op] = pinned
end
|
||
# Finalizer for a CustomOpProp: drop everything pinned to `op`.
function _finalizer(op :: CustomOpProp)
    # `haskey` needs the key as second argument — the draft omitted `op`.
    if haskey(__prop_pinned_memory, op)
        delete!(__prop_pinned_memory, op)
    end
end
|
||
# C entry invoked by libmxnet when the operator property is destroyed.
# Currently a no-op stub — presumably this should eventually unpin the
# memory associated with `payload` (TODO confirm intended semantics).
function _delete_entry(payload :: Ptr{Void})
    # Figure out what to do here. This is to keep this part of the memory alive
end
|
||
# C entry for InferShape: read the shapes libmxnet provides (row-major, so
# reversed relative to Julia), call the user's `infer_shape`, and write the
# inferred shapes for all tensors back through the out-parameters.
# Returns true on success, false on error (libmxnet checks the Bool).
function _infer_shape_entry(num_tensor, tensor_dims, tensor_shapes, payload)
    try
        op = unsafe_pointer_to_objref(payload) :: CustomOpProp
        n_in = length(list_arguments(op))
        n_out = length(list_outputs(op))
        n_aux = length(list_auxiliary_states(op))  # draft omitted `op`

        @assert num_tensor == n_in + n_out + n_aux

        shapes = Vector{Cuint}[]
        # copy and revert input shapes (libmxnet is row-major, Julia column-major)
        for i in 1:n_in
            # Get size of array and create julia array
            ndims = unsafe_load(tensor_dims, i)
            shape = zeros(Cuint, ndims)
            tshape = unsafe_load(tensor_shapes, i)
            for j in 1:ndims
                # draft read from the undefined `tshapes`
                shape[j] = unsafe_load(tshape, ndims - j + 1)
            end
            push!(shapes, shape)
        end

        ret = infer_shape(op, shapes)
        if length(ret) == 2
            ishapes, oshapes = ret
            ashapes = Vector{Cuint}[]  # no aux shapes returned
        elseif length(ret) == 3        # draft had `lenght`
            ishapes, oshapes, ashapes = ret
        else
            error("infer_shape must return 2 or 3 lists.")
        end

        @assert length(ishapes) == n_in
        @assert length(oshapes) == n_out
        @assert length(ashapes) == n_aux

        # We now have to reverse the arrays again.
        # We can't perform an inplace operation in case the arrays share memory.
        # (draft bound the *type* `Vector{Cuint}` instead of an empty array)
        rshapes = Vector{Cuint}[]
        for shape in ishapes
            push!(rshapes, reverse(shape))
        end
        for shape in oshapes
            push!(rshapes, reverse(shape))
        end
        for shape in ashapes
            push!(rshapes, reverse(shape))
        end

        # Pin so the pointers stored below remain valid after we return.
        _pin!(op, rshapes)

        for i in 1:num_tensor  # draft used the undefined `num_tensors`
            unsafe_store!(tensor_shapes, pointer(rshapes[i]), i)
            unsafe_store!(tensor_dims, length(rshapes[i]), i)
        end
    catch err
        println(STDERR, "Error in infer_shape: ")
        showerror(STDERR, err)
        return false
    end
    return true
end
|
||
# C entry for ListArguments: expose the user's argument names to libmxnet
# as a NULL-terminated array of C strings written through `data`.
# Returns true on success, false on error.
function _list_arguments_entry(data :: Ptr{Ptr{Ptr{Cchar}}}, payload :: Ptr{Void})
    try
        op = unsafe_pointer_to_objref(payload) :: CustomOpProp
        arguments = list_arguments(op)
        # Pin the strings and the pointer array so they outlive this call.
        _pin!(op, arguments)
        ptrs = Ptr{Cchar}[]
        for name in arguments
            push!(ptrs, Base.unsafe_convert(Ptr{Cchar}, name))
        end
        _pin!(op, ptrs)
        push!(ptrs, C_NULL)  # NULL terminator expected by the C side
        unsafe_store!(data, pointer(ptrs), 1)
    catch err
        println(STDERR, "Error in list_arguments: ")
        showerror(STDERR, err)
        return false
    end
    return true
end
|
||
# C entry for ListOutputs: expose the user's output names to libmxnet
# as a NULL-terminated array of C strings written through `data`.
# Returns true on success, false on error.
function _list_outputs_entry(data :: Ptr{Ptr{Ptr{Cchar}}}, payload :: Ptr{Void})
    try
        op = unsafe_pointer_to_objref(payload) :: CustomOpProp
        outputs = list_outputs(op)
        # Pin the strings and the pointer array so they outlive this call.
        _pin!(op, outputs)
        ptrs = Ptr{Cchar}[]
        for name in outputs
            push!(ptrs, Base.unsafe_convert(Ptr{Cchar}, name))
        end
        _pin!(op, ptrs)
        push!(ptrs, C_NULL)  # NULL terminator expected by the C side
        unsafe_store!(data, pointer(ptrs), 1)
    catch err
        println(STDERR, "Error in list_outputs: ")
        showerror(STDERR, err)
        return false
    end
    return true
end
|
||
# C entry for ListAuxiliaryStates: expose the user's auxiliary-state names
# to libmxnet as a NULL-terminated array of C strings written through `data`.
# Returns true on success, false on error.
function _list_auxiliary_states_entry(data :: Ptr{Ptr{Ptr{Cchar}}}, payload :: Ptr{Void})
    try
        op = unsafe_pointer_to_objref(payload) :: CustomOpProp
        aux = list_auxiliary_states(op)
        # Pin the strings and the pointer array so they outlive this call.
        _pin!(op, aux)
        ptrs = Ptr{Cchar}[]
        for name in aux
            push!(ptrs, Base.unsafe_convert(Ptr{Cchar}, name))
        end
        _pin!(op, ptrs)
        push!(ptrs, C_NULL)  # NULL terminator expected by the C side
        unsafe_store!(data, pointer(ptrs), 1)
    catch err
        println(STDERR, "Error in list_auxiliary_states: ")
        showerror(STDERR, err)
        return false
    end
    return true
end
|
||
# C entry for DeclareBackwardDependency: ask the user which of the
# out_grad/in_data/out_data tensor ids the backward pass depends on and
# write the dependency list back through `num_dep`/`deps`.
# Returns true on success, false on error.
# NOTE: renamed from `_declare_backward_dependency` to match the
# `cfunction(_declare_backward_dependency_entry, ...)` lookup in
# `CustomOpPropInfo`; the draft also omitted the comma after `_out_data`.
function _declare_backward_dependency_entry(_out_grad :: Ptr{Cint},
                                            _in_data :: Ptr{Cint},
                                            _out_data :: Ptr{Cint},
                                            num_dep :: Ptr{Cint},
                                            deps :: Ptr{Ptr{Cint}},
                                            payload :: Ptr{Void})
    try
        op = unsafe_pointer_to_objref(payload) :: CustomOpProp
        out_grad = unsafe_wrap(Array, _out_grad, length(list_outputs(op)))
        in_data = unsafe_wrap(Array, _in_data, length(list_arguments(op)))
        out_data = unsafe_wrap(Array, _out_data, length(list_outputs(op)))

        rdeps = convert(Vector{Cint}, declare_backward_dependency(op, out_grad, in_data, out_data))
        # Pin so the pointer stored below remains valid after we return.
        _pin!(op, rdeps)

        unsafe_store!(num_dep, length(rdeps), 1)
        unsafe_store!(deps, pointer(rdeps), 1)
    catch err
        println(STDERR, "Error in declare_backward_dependency: ")
        showerror(STDERR, err)
        return false
    end
    return true
end
|
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The `_FB` struct does not have a `b_exit` field — should the `b_exit` argument in the `_FB(...)` call inside `_wrapper_fb` be removed?