Adapt to GPUCompiler 0.18 #673

Merged · 10 commits · Apr 4, 2023
Project.toml (2 changes: 1 addition & 1 deletion)
@@ -19,7 +19,7 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
CEnum = "0.4"
EnzymeCore = "0.2.1"
Enzyme_jll = "0.0.51"
GPUCompiler = "0.16.7, 0.17"
GPUCompiler = "0.18"
LLVM = "4.14"
ObjectFile = "0.3"
julia = "1.6"
src/Enzyme.jl (58 changes: 38 additions & 20 deletions)
@@ -177,11 +177,14 @@ Enzyme.autodiff(ReverseWithPrimal, x->x*x, Active(3.0))

ModifiedBetween = Val(falses_from_args(Val(1), args...))

+tt = Tuple{map(T->eltype(Core.Typeof(T)), args′)...}
+world = GPUCompiler.get_world(Core.Typeof(f.val), tt)

if A <: Active
tt = Tuple{map(T->eltype(Core.Typeof(T)), args′)...}
rt = Core.Compiler.return_type(f.val, tt)
if !allocatedinline(rt) || rt isa Union
-forward, adjoint = Enzyme.Compiler.thunk(FA, Duplicated{rt}, tt′, #=Split=# Val(API.DEM_ReverseModeGradient), Val(width), ModifiedBetween, #=ReturnPrimal=#Val(ReturnPrimal), #=ShadowInit=#Val(true))
+forward, adjoint = Enzyme.Compiler.thunk(Val(world), FA, Duplicated{rt}, tt′, #=Split=# Val(API.DEM_ReverseModeGradient), Val(width), ModifiedBetween, #=ReturnPrimal=#Val(ReturnPrimal), #=ShadowInit=#Val(true))
res = forward(f, args′...)
tape = res[1]
if ReturnPrimal
@@ -193,10 +196,10 @@ Enzyme.autodiff(ReverseWithPrimal, x->x*x, Active(3.0))
elseif A <: Duplicated || A<: DuplicatedNoNeed || A <: BatchDuplicated || A<: BatchDuplicatedNoNeed
throw(ErrorException("Duplicated Returns not yet handled"))
end
-thunk = Enzyme.Compiler.thunk(FA, A, tt′, #=Split=# Val(API.DEM_ReverseModeCombined), Val(width), ModifiedBetween, Val(ReturnPrimal))
+thunk = Enzyme.Compiler.thunk(Val(world), FA, A, tt′, #=Split=# Val(API.DEM_ReverseModeCombined), Val(width), ModifiedBetween, Val(ReturnPrimal))
if A <: Active
tt = Tuple{map(T->eltype(Core.Typeof(T)), args′)...}
-rt = eltype(Compiler.return_type(thunk))
+rt = Core.Compiler.return_type(f.val, tt)
args′ = (args′..., one(rt))
end
thunk(f, args′...)
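
Side note on the `A <: Active` branch above: the reverse pass is seeded by appending `one(rt)` to the argument tuple as the adjoint of the scalar return. A hedged sketch of the user-visible effect (exact result layout may differ across Enzyme versions):

using Enzyme
f(x) = x * x
# The seed one(Float64) = 1.0 becomes d(return), so reverse mode should
# report the gradient 6.0 at x = 3.0, e.g. ((6.0,),) from:
Enzyme.autodiff(Reverse, f, Active(3.0))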
@@ -305,18 +308,15 @@ f(x) = x*x
end

ModifiedBetween = Val(falses_from_args(Val(1), args...))

+tt = Tuple{map(T->eltype(Core.Typeof(T)), args′)...}
+world = GPUCompiler.get_world(Core.Typeof(f.val), tt)

-thunk = Enzyme.Compiler.thunk(FA, RT, tt′, #=Mode=# Val(API.DEM_ForwardMode), Val(width),
+thunk = Enzyme.Compiler.thunk(Val(world), FA, RT, tt′, #=Mode=# Val(API.DEM_ForwardMode), Val(width),
ModifiedBetween, ReturnPrimal)
thunk(f, args′...)
end

-# F as first arg for `do` syntax
-@inline autodiff(dupf::Duplicated{F}, mode::Mode, ::Type{A}, args...) where {F,A<:Annotation} = autodiff(mode, dupf, A, args...)
-@inline autodiff(f::F, mode::Mode, ::Type{A}, args...) where {F,A<:Annotation} = autodiff(mode, f, A, args...)
-@inline autodiff(dupf::Duplicated{F}, mode::Mode, args...) where {F} = autodiff(mode, dupf, args...)
-@inline autodiff(f::F, mode::Mode, args...) where {F} = autodiff(mode, f, args...)

"""
autodiff_deferred(::ReverseMode, f, Activity, args...)

@@ -330,8 +330,11 @@ code, as well as high-order differentiation.
if width == 0
throw(ErrorException("Cannot differentiate with a batch size of 0"))
end
+tt = Tuple{map(T->eltype(Core.Typeof(T)), args′)...}

+world = GPUCompiler.get_world(Core.Typeof(f.val), tt)

if A isa UnionAll
-tt = Tuple{map(T->eltype(Core.Typeof(T)), args′)...}
rt = Core.Compiler.return_type(f.val, tt)
rt = A{rt}
else
@@ -345,7 +348,7 @@

ModifiedBetween = Val(falses_from_args(Val(1), args...))

-adjoint_ptr, primal_ptr = Compiler.deferred_codegen(FA, Val(tt′), Val(rt), Val(API.DEM_ReverseModeCombined), Val(width), ModifiedBetween, Val(ReturnPrimal))
+adjoint_ptr, primal_ptr = Compiler.deferred_codegen(Val(world), FA, Val(tt′), Val(rt), Val(API.DEM_ReverseModeCombined), Val(width), ModifiedBetween, Val(ReturnPrimal))
@assert primal_ptr === nothing
thunk = Compiler.CombinedAdjointThunk{FA, rt, tt′, typeof(Val(width)), Val(ReturnPrimal)}(adjoint_ptr)
if rt <: Active
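
For orientation, the deferred reverse path above mirrors the plain `autodiff` call shape. A minimal usage sketch following the `autodiff_deferred(::ReverseMode, f, Activity, args...)` signature from the docstring; `h` is hypothetical:

using Enzyme
h(x) = x * x
# Compilation is routed through deferred_codegen, which is what makes this
# entry point usable from GPU kernels and for higher-order differentiation:
Enzyme.autodiff_deferred(Reverse, h, Active, Active(3.0))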
@@ -387,8 +390,11 @@ code, as well as high-order differentiation.
else
A
end
+tt = Tuple{map(T->eltype(Core.Typeof(T)), args′)...}

+world = GPUCompiler.get_world(Core.Typeof(f.val), tt)

if RT isa UnionAll
-tt = Tuple{map(T->eltype(Core.Typeof(T)), args′)...}
rt = Core.Compiler.return_type(f.val, tt)
rt = RT{rt}
else
@@ -407,7 +413,8 @@
ReturnPrimal = Val(RT <: Duplicated || RT <: BatchDuplicated)
ModifiedBetween = Val(falses_from_args(Val(1), args...))

-adjoint_ptr, primal_ptr = Compiler.deferred_codegen(FA, Val(tt′), Val(rt), Val(API.DEM_ForwardMode), Val(width), ModifiedBetween, ReturnPrimal)

+adjoint_ptr, primal_ptr = Compiler.deferred_codegen(Val(world), FA, Val(tt′), Val(rt), Val(API.DEM_ForwardMode), Val(width), ModifiedBetween, ReturnPrimal)
@assert primal_ptr === nothing
thunk = Compiler.ForwardModeThunk{FA, rt, tt′, typeof(Val(width)), ReturnPrimal}(adjoint_ptr)
thunk(f, args′...)
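
The forward-mode deferred path is symmetric. A hedged sketch; the `DuplicatedNoNeed` return activity is an assumption consistent with the `ReturnPrimal` logic above, which only returns the primal for `Duplicated`/`BatchDuplicated`:

using Enzyme
g(x) = x * x
# Duplicated(3.0, 1.0) carries primal 3.0 and tangent seed 1.0, so this
# should propagate the tangent through g and yield 6.0:
Enzyme.autodiff_deferred(Forward, g, DuplicatedNoNeed, Duplicated(3.0, 1.0))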
@@ -430,6 +437,7 @@ Like [`autodiff_deferred`](@ref) but will try to guess the activity of the retur
@inline function autodiff_deferred(mode::M, f::FA, args...) where {FA<:Annotation, M<:Mode}
args′ = annotate(args...)
tt = Tuple{map(T->eltype(Core.Typeof(T)), args′)...}
+world = GPUCompiler.get_world(Core.Typeof(f.val), tt)
rt = Core.Compiler.return_type(f.val, tt)
if rt === Union{}
error("return type is Union{}, giving up.")
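
The activity-guessing variant leans on the same `Core.Compiler.return_type` inference that now also feeds `get_world`; the `Union{}` guard above rejects functions whose return type cannot be inferred as a concrete value. A hypothetical example of what it catches:

using Enzyme
fails(x) = error("no result")   # inferred return type is Union{}
# Enzyme.autodiff_deferred(Reverse, fails, Active(1.0))
# -> error("return type is Union{}, giving up.")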
@@ -499,8 +507,12 @@ result, ∂v, ∂A
ModifiedBetween = Val(ModifiedBetweenT)
end

+tt = Tuple{map(eltype, args)...}

+world = GPUCompiler.get_world(eltype(FA), tt)

@assert ReturnShadow
-Enzyme.Compiler.thunk(FA, A, Tuple{args...}, #=Split=# Val(API.DEM_ReverseModeGradient), Val(width), ModifiedBetween, #=ReturnPrimal=#Val(ReturnPrimal), #=ShadowInit=#Val(false), parent_job)
+Enzyme.Compiler.thunk(Val(world), FA, A, Tuple{args...}, #=Split=# Val(API.DEM_ReverseModeGradient), Val(width), ModifiedBetween, #=ReturnPrimal=#Val(ReturnPrimal), #=ShadowInit=#Val(false), parent_job)
end

"""
@@ -567,12 +579,15 @@ result, ∂v, ∂A
@assert ReturnShadow
TT = Tuple{args...}

+primal_tt = Tuple{map(eltype, args)...}
+world = GPUCompiler.get_world(eltype(FA), primal_tt)

# TODO this assumes that the thunk here has the correct parent/etc things for getting the right cuda instructions -> same caching behavior
-nondef = Enzyme.Compiler.thunk(FA, A, TT, #=Split=# Val(API.DEM_ReverseModeGradient), Val(width), ModifiedBetween, #=ReturnPrimal=#Val(ReturnPrimal))
+nondef = Enzyme.Compiler.thunk(Val(world), FA, A, TT, #=Split=# Val(API.DEM_ReverseModeGradient), Val(width), ModifiedBetween, #=ReturnPrimal=#Val(ReturnPrimal))
TapeType = Compiler.get_tape_type(typeof(nondef[1]))
A2 = Compiler.return_type(typeof(nondef[1]))

-adjoint_ptr, primal_ptr = Compiler.deferred_codegen(FA, Val(TT), Val(A2), Val(API.DEM_ReverseModeGradient), Val(width), ModifiedBetween, Val(ReturnPrimal), #=ShadowInit=#Val(false), TapeType)
+adjoint_ptr, primal_ptr = Compiler.deferred_codegen(Val(world), FA, Val(TT), Val(A2), Val(API.DEM_ReverseModeGradient), Val(width), ModifiedBetween, Val(ReturnPrimal), #=ShadowInit=#Val(false), TapeType)
AugT = Compiler.AugmentedForwardThunk{FA, A2, TT, Val{width}, Val(ReturnPrimal), TapeType}
@assert AugT == typeof(nondef[1])
AdjT = Compiler.AdjointThunk{FA, A2, TT, Val{width}, TapeType}
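
Note the two-step pattern in this hunk: a regular thunk (`nondef`) is compiled first purely to recover the tape type and return activity, and the `@assert`s then check that `deferred_codegen` reproduces thunks of exactly the same types, i.e. the deferred and non-deferred paths share caching behavior. In outline, restating the diff lines as comments:

# 1. nondef   = Enzyme.Compiler.thunk(Val(world), ...)      # host-side compile
# 2. TapeType = Compiler.get_tape_type(typeof(nondef[1]))   # recover tape layout
# 3. deferred_codegen(Val(world), ..., TapeType)            # deferred pointers
# 4. @assert AugT == typeof(nondef[1])                      # types must agree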
@@ -842,18 +857,20 @@ grad = jacobian(Reverse, f, [2.0, 3.0], Val(2))

tt′ = Tuple{BatchDuplicated{Core.Typeof(x), chunk}}
tt = Tuple{Core.Typeof(x)}
+world = GPUCompiler.get_world(Core.Typeof(f), tt)
rt = Core.Compiler.return_type(f, tt)
ModifiedBetween = Val((false, false))
FA = Const{Core.Typeof(f)}
-primal, adjoint = Enzyme.Compiler.thunk(FA, BatchDuplicatedNoNeed{rt}, tt′, #=Split=# Val(API.DEM_ReverseModeGradient), #=width=#Val(chunk), ModifiedBetween)
World = Val(nothing)
+primal, adjoint = Enzyme.Compiler.thunk(Val(world), FA, BatchDuplicatedNoNeed{rt}, tt′, #=Split=# Val(API.DEM_ReverseModeGradient), #=width=#Val(chunk), ModifiedBetween)

if num * chunk == n_out_val
last_size = chunk
primal2, adjoint2 = primal, adjoint
else
last_size = n_out_val - (num-1)*chunk
tt′ = Tuple{BatchDuplicated{Core.Typeof(x), last_size}}
-primal2, adjoint2 = Enzyme.Compiler.thunk(FA, BatchDuplicatedNoNeed{rt}, tt′, #=Split=# Val(API.DEM_ReverseModeGradient), #=width=#Val(last_size), ModifiedBetween)
+primal2, adjoint2 = Enzyme.Compiler.thunk(Val(world), FA, BatchDuplicatedNoNeed{rt}, tt′, #=Split=# Val(API.DEM_ReverseModeGradient), #=width=#Val(last_size), ModifiedBetween)
end

tmp = ntuple(num) do i
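
A worked instance of this chunked path, consistent with the `jacobian` docstring referenced in the hunk header; `f` is hypothetical and the exact container of the rows may differ by Enzyme version:

using Enzyme
f(x) = [x[1] * x[2], x[1] + x[2]]
# n_outs = Val(2) output rows with chunk width Val(2): both output adjoints
# are seeded in one batched reverse sweep. At x = [2.0, 3.0] the rows should
# be (3.0, 2.0) and (1.0, 1.0), i.e. the Jacobian [3.0 2.0; 1.0 1.0]:
J = Enzyme.jacobian(Reverse, f, [2.0, 3.0], Val(2), Val(2))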
@@ -879,10 +896,11 @@ end
@inline function jacobian(::ReverseMode, f::F, x::X, n_outs::Val{n_out_val}, ::Val{1} = Val(1)) where {F, X, n_out_val}
tt′ = Tuple{Duplicated{Core.Typeof(x)}}
tt = Tuple{Core.Typeof(x)}
+world = GPUCompiler.get_world(Core.Typeof(f), tt)
rt = Core.Compiler.return_type(f, tt)
ModifiedBetween = Val((false, false))
FA = Const{Core.Typeof(f)}
-primal, adjoint = Enzyme.Compiler.thunk(FA, DuplicatedNoNeed{rt}, tt′, #=Split=# Val(API.DEM_ReverseModeGradient), #=width=#Val(1), ModifiedBetween)
+primal, adjoint = Enzyme.Compiler.thunk(Val(world), FA, DuplicatedNoNeed{rt}, tt′, #=Split=# Val(API.DEM_ReverseModeGradient), #=width=#Val(1), ModifiedBetween)
rows = ntuple(n_outs) do i
Base.@_inline_meta
dx = zero(x)