Skip to content

Commit

Permalink
Enzyme: adapt to pending version breaking update
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses committed Sep 16, 2024
1 parent d7077da commit 656d61d
Show file tree
Hide file tree
Showing 2 changed files with 88 additions and 62 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ CUDA_Runtime_jll = "0.14"
ChainRulesCore = "1"
Crayons = "4"
DataFrames = "1"
EnzymeCore = "0.7.3"
EnzymeCore = "0.8"
ExprTools = "0.1"
GPUArrays = "10.0.1"
GPUCompiler = "0.24, 0.25, 0.26"
Expand Down
148 changes: 87 additions & 61 deletions ext/EnzymeCoreExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -36,14 +36,14 @@ function metaf(fn, args::Vararg{Any, N}) where N
nothing
end

function EnzymeCore.EnzymeRules.forward(ofn::Const{typeof(cufunction)},
function EnzymeCore.EnzymeRules.forward(config, ofn::Const{typeof(cufunction)},
::Type{<:Duplicated}, f::Const{F},
tt::Const{TT}; kwargs...) where {F,TT}
res = ofn.val(f.val, tt.val; kwargs...)
return Duplicated(res, res)
end

function EnzymeCore.EnzymeRules.forward(ofn::Const{typeof(cufunction)},
function EnzymeCore.EnzymeRules.forward(config, ofn::Const{typeof(cufunction)},
::Type{BatchDuplicated{T,N}}, f::Const{F},
tt::Const{TT}; kwargs...) where {F,TT,T,N}
res = ofn.val(f.val, tt.val; kwargs...)
Expand All @@ -52,7 +52,7 @@ function EnzymeCore.EnzymeRules.forward(ofn::Const{typeof(cufunction)},
end)
end

function EnzymeCore.EnzymeRules.forward(ofn::Const{typeof(cudaconvert)},
function EnzymeCore.EnzymeRules.forward(config, ofn::Const{typeof(cudaconvert)},
::Type{RT}, x::IT) where {RT, IT}
if RT <: Duplicated
RT(ofn.val(x.val), ofn.val(x.dval))
Expand All @@ -73,34 +73,45 @@ function EnzymeCore.EnzymeRules.forward(ofn::Const{typeof(cudaconvert)},
end
end

function EnzymeCore.EnzymeRules.forward(ofn::Const{Type{CT}},
function EnzymeCore.EnzymeRules.forward(config, ofn::Const{Type{CT}},
::Type{RT}, uval::EnzymeCore.Annotation{UndefInitializer}, args...) where {CT <: CuArray, RT}
primargs = ntuple(Val(length(args))) do i
Base.@_inline_meta
args[i].val
end
if RT <: Duplicated
shadow = ofn.val(uval.val, primargs...)::CT
fill!(shadow, 0)
Duplicated(ofn.val(uval.val, primargs...), shadow)
elseif RT <: Const
ofn.val(uval.val, primargs...)
elseif RT <: DuplicatedNoNeed
shadow = ofn.val(uval.val, primargs...)::CT
fill!(shadow, 0)
shadow::CT
else
tup = ntuple(Val(EnzymeCore.batch_size(RT))) do i
Base.@_inline_meta

if EnzymeRules.needs_primal(config) && EnzymeRules.needs_shadow(config)
if EnzymeRules.width(config) == 1
shadow = ofn.val(uval.val, primargs...)::CT
fill!(shadow, 0)
shadow::CT
Duplicated(ofn.val(uval.val, primargs...), shadow)
else
tup = ntuple(Val(EnzymeRules.width(config))) do i
Base.@_inline_meta
shadow = ofn.val(uval.val, primargs...)::CT
fill!(shadow, 0)
shadow::CT
end
BatchDuplicated(ofn.val(uval.val, primargs...), tup)
end
if RT <: BatchDuplicated
BatchDuplicated(ofv.val(uval.val), tup)
elseif EnzymeRules.needs_shadow(config)
if EnzymeRules.width(config) == 1
shadow = ofn.val(uval.val, primargs...)::CT
fill!(shadow, 0)
shadow
else
tup = ntuple(Val(EnzymeRules.width(config))) do i
Base.@_inline_meta
shadow = ofn.val(uval.val, primargs...)::CT
fill!(shadow, 0)
shadow::CT
end
tup
end
elseif EnzymeRules.needs_primal(config)
ofn.val(uval.val, primargs...)
else
nothing
end
end

Expand All @@ -110,54 +121,71 @@ function EnzymeCore.EnzymeRules.forward(ofn::Const{Type{CT}},
Base.@_inline_meta
args[i].val
end
if RT <: Duplicated
shadow = ofn.val(uval.val, primargs...; kwargs...)
Duplicated(ofn.val(uval.dval, primargs...; kwargs...), shadow)
elseif RT <: Const
ofn.val(uval.val, primargs...; kwargs...)
elseif RT <: DuplicatedNoNeed
ofn.val(uval.dval, primargs...; kwargs...)
else
tup = ntuple(Val(EnzymeCore.batch_size(RT))) do i
Base.@_inline_meta
shadow = ofn.val(uval.dval[i], primargs...; kwargs...)

if EnzymeRules.needs_primal(config) && EnzymeRules.needs_shadow(config)
if EnzymeRules.width(config) == 1
shadow = ofn.val(uval.val, primargs...; kwargs...)
Duplicated(ofn.val(uval.val, primargs...; kwargs...), shadow)
else
tup = ntuple(Val(EnzymeRules.width(config))) do i
Base.@_inline_meta
ofn.val(uval.val, primargs...; kwargs...)
end
BatchDuplicated(ofn.val(uval.val, primargs...; kwargs...), tup)
end
if RT <: BatchDuplicated
BatchDuplicated(ofv.val(uval.val), tup)
elseif EnzymeRules.needs_shadow(config)
if EnzymeRules.width(config) == 1
shadow = ofn.val(uval.val, primargs...; kwargs...)
shadow
else
tup = ntuple(Val(EnzymeRules.width(config))) do i
Base.@_inline_meta
ofn.val(uval.val, primargs...; kwargs...)
end
tup
end
elseif EnzymeRules.needs_primal(config)
ofn.val(uval.val, primargs...; kwargs...)
else
nothing
end
end

function EnzymeCore.EnzymeRules.forward(ofn::Const{typeof(synchronize)},
function EnzymeCore.EnzymeRules.forward(config, ofn::Const{typeof(synchronize)},
::Type{RT}, args::Vararg{EnzymeCore.Annotation, N}; kwargs...) where {RT, N}
pargs = ntuple(Val(N)) do i
Base.@_inline_meta
args[i].val
end
res = ofn.val(pargs...; kwargs...)

if RT <: Duplicated
return Duplicated(res, res)
elseif RT <: Const
return res
elseif RT <: DuplicatedNoNeed
return res
else
tup = ntuple(Val(EnzymeCore.batch_size(RT))) do i
Base.@_inline_meta
res
if EnzymeRules.needs_primal(config) && EnzymeRules.needs_shadow(config)
if EnzymeRules.width(config) == 1
Duplicated(res, res)
else
tup = ntuple(Val(EnzymeRules.width(config))) do i
Base.@_inline_meta
res
end
BatchDuplicated(ofn.val(uval.val, primargs...; kwargs...), tup)
end
if RT <: BatchDuplicated
return BatchDuplicated(res, tup)
elseif EnzymeRules.needs_shadow(config)
if EnzymeRules.width(config) == 1
res
else
return tup
ntuple(Val(EnzymeRules.width(config))) do i
Base.@_inline_meta
res
end
end
elseif EnzymeRules.needs_primal(config)
res
else
nothing
end
end

function EnzymeCore.EnzymeRules.forward(ofn::EnzymeCore.Annotation{CUDA.HostKernel{F,TT}},
function EnzymeCore.EnzymeRules.forward(config, ofn::EnzymeCore.Annotation{CUDA.HostKernel{F,TT}},
::Type{Const{Nothing}}, args...;
kwargs...) where {F,TT}

Expand Down Expand Up @@ -195,7 +223,7 @@ function EnzymeCore.EnzymeRules.augmented_primal(config, ofn::Const{typeof(cufun
else
nothing
end
return EnzymeRules.AugmentedReturn{(EnzymeRules.needs_primal(config) ? CT : Nothing), (EnzymeRules.needs_shadow(config) ? (EnzymeRules.width(config) == 1 ? CT : NTuple{EnzymeRules.width(config), CT}) : Nothing), Nothing}(primal, shadow, nothing)
return EnzymeRules.AugmentedReturn{EnzymeRules.primal_type(config, RT), EnzymeRules.shadow_type(config, RT), Nothing}(primal, shadow, nothing)
end

function EnzymeCore.EnzymeRules.reverse(config, ofn::EnzymeCore.Const{typeof(cufunction)},::Type{RT}, subtape, f, tt; kwargs...) where RT
Expand Down Expand Up @@ -322,7 +350,7 @@ function EnzymeCore.EnzymeRules.reverse(config, ofn::EnzymeCore.Annotation{CUDA.
end
end

function EnzymeCore.EnzymeRules.forward(ofn::Const{typeof(Base.fill!)}, ::Type{RT}, A::EnzymeCore.Annotation{<:DenseCuArray{T}}, x) where {RT, T <: CUDA.MemsetCompatTypes}
function EnzymeCore.EnzymeRules.forward(config, ofn::Const{typeof(Base.fill!)}, ::Type{RT}, A::EnzymeCore.Annotation{<:DenseCuArray{T}}, x) where {RT, T <: CUDA.MemsetCompatTypes}
if A isa Const || A isa Duplicated || A isa BatchDuplicated
ofn.val(A.val, x.val)
end
Expand All @@ -337,16 +365,14 @@ function EnzymeCore.EnzymeRules.forward(ofn::Const{typeof(Base.fill!)}, ::Type{R
end
end

if RT <: Duplicated
return A
elseif RT <: Const
return A.val
elseif RT <: DuplicatedNoNeed
return A.dval
elseif RT <: BatchDuplicated
return A
if EnzymeRules.needs_primal(config) && EnzymeRules.needs_shadow(config)
A
elseif EnzymeRules.needs_shadow(config)
A.dval
elseif EnzymeRules.needs_primal(config)
A.val
else
return A.dval
nothing
end
end

Expand Down Expand Up @@ -441,7 +467,7 @@ function EnzymeCore.EnzymeRules.augmented_primal(config, ofn::Const{Type{CT}}, :
else
nothing
end
return EnzymeRules.AugmentedReturn{(EnzymeRules.needs_primal(config) ? eltype(RT) : Nothing), (EnzymeRules.needs_shadow(config) ? (EnzymeRules.width(config) == 1 ? eltype(RT) : NTuple{EnzymeRules.width(config), eltype(RT)}) : Nothing), Nothing}(primal, shadow, nothing)
return EnzymeRules.AugmentedReturn{EnzymeRules.primal_type(config, RT), EnzymeRules.shadow_type(config, RT), Nothing}(primal, shadow, nothing)
end

function EnzymeCore.EnzymeRules.reverse(config, ofn::Const{Type{CT}}, ::Type{RT}, tape, A::EnzymeCore.Annotation{UndefInitializer}, args::Vararg{EnzymeCore.Annotation, N}) where {CT <: CuArray, RT, N}
Expand Down Expand Up @@ -475,7 +501,7 @@ function EnzymeCore.EnzymeRules.augmented_primal(config, ofn::Const{Type{CT}}, :
else
nothing
end
return EnzymeRules.AugmentedReturn{(EnzymeRules.needs_primal(config) ? eltype(RT) : Nothing), (EnzymeRules.needs_shadow(config) ? (EnzymeRules.width(config) == 1 ? eltype(RT) : NTuple{EnzymeRules.width(config), eltype(RT)}) : Nothing), Nothing}(primal, shadow, nothing)
return EnzymeRules.AugmentedReturn{EnzymeRules.primal_type(config, RT), EnzymeRules.shadow_type(config, RT), Nothing}(primal, shadow, nothing)
end

function EnzymeCore.EnzymeRules.reverse(config, ofn::Const{Type{CT}}, ::Type{RT}, tape, A::EnzymeCore.Annotation{DR}, args::Vararg{EnzymeCore.Annotation, N}; kwargs...) where {CT <: CuArray, DR <: CUDA.DataRef, RT, N}
Expand Down

0 comments on commit 656d61d

Please sign in to comment.