Skip to content

Commit

Permalink
add custom handler for ptr_to_array runtime call
Browse files Browse the repository at this point in the history
  • Loading branch information
vchuravy committed Jan 8, 2025
1 parent 2309abd commit 79c8beb
Show file tree
Hide file tree
Showing 3 changed files with 79 additions and 0 deletions.
2 changes: 2 additions & 0 deletions src/compiler/validation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -78,11 +78,13 @@ function __init__()
"jl_get_keyword_sorter",
"ijl_get_keyword_sorter",
"jl_ptr_to_array",
"ijl_ptr_to_array",
"jl_box_float32",
"ijl_box_float32",
"jl_box_float64",
"ijl_box_float64",
"jl_ptr_to_array_1d",
"ijl_ptr_to_array_1d",
"jl_eqtable_get",
"ijl_eqtable_get",
"memcmp",
Expand Down
62 changes: 62 additions & 0 deletions src/rules/llvmrules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1779,6 +1779,62 @@ end
return nothing
end

@register_fwd function jl_ptr_to_array_fwd(B, orig, gutils, normalR, shadowR)
if is_constant_inst(gutils, orig)
return true
end
origops = collect(operands(orig))
width = get_width(gutils)
origops = collect(operands(orig))
width = get_width(gutils)

args = LLVM.Value[
new_from_original(gutils, origops[1]),
invert_pointer(gutils, origops[2], B), # data
new_from_original(gutils, origops[3]),
new_from_original(gutils, origops[4]),
]
valTys = API.CValueType[
API.VT_Primal,
API.VT_Shadow,
API.VT_Primal,
API.VT_Primal,
]

if width == 1
vargs = args
cal = call_samefunc_with_inverted_bundles!(B, gutils, orig, vargs, valTys, false) #=lookup=#
debug_from_orig!(gutils, cal, orig)
callconv!(cal, callconv(orig))
shadowres = cal
else
shadowres =
UndefValue(LLVM.LLVMType(API.EnzymeGetShadowType(width, value_type(orig))))
for idx = 1:width
vargs = LLVM.Value[
args[1],
extract_value!(B, args[2], idx - 1), # data
args[3],
args[4],
]
cal =
call_samefunc_with_inverted_bundles!(B, gutils, orig, vargs, valTys, false) #=lookup=#
debug_from_orig!(gutils, cal, orig)
callconv!(cal, callconv(orig))
shadowres = insert_value!(B, shadowres, call, idx - 1)
end
end
unsafe_store!(shadowR, shadowres.ref)

return false
end
@register_aug function jl_ptr_to_array_augfwd(B, orig, gutils, normalR, shadowR, tapeR)
jl_ptr_to_array_fwd(B, orig, gutils, normalR, shadowR)
end
@register_rev function jl_ptr_to_array_rev(B, orig, gutils, tape)
return nothing
end

@register_fwd function genericmemory_copyto_fwd(B, orig, gutils, normalR, shadowR)
if is_constant_inst(gutils, orig)
return true
Expand Down Expand Up @@ -2400,6 +2456,12 @@ end
@revfunc(jl_array_ptr_copy_rev),
@fwdfunc(jl_array_ptr_copy_fwd),
)
register_handler!(
("jl_ptr_to_array_1d", "ijl_ptr_to_array_1d", "jl_ptr_to_array", "ijl_ptr_to_array"),
@augfunc(jl_ptr_to_array_augfwd),
@revfunc(jl_ptr_to_array_rev),
@fwdfunc(jl_ptr_to_array_fwd),
)
register_handler!(
(),
@augfunc(jl_unhandled_augfwd),
Expand Down
15 changes: 15 additions & 0 deletions test/array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,3 +23,18 @@ end
@test dB[1] === dA1
@test dB[2] === dA2
end

function unsafe_wrap_test(a, i, x)
GC.@preserve a begin
ptr = pointer(a)
b = Base.unsafe_wrap(Array, ptr, length(a))
b[i] = x
end
a[i]
end

@testset "Unsafe wrap" begin
autodiff(Forward, f, Duplicated(zeros(1), zeros(1)), Const(1), Duplicated(1.0, 2.0))

# TODO test for batch and reverse
end

0 comments on commit 79c8beb

Please sign in to comment.