Skip to content

Commit

Permalink
Enzyme: adapt to pending version breaking update
Browse files Browse the repository at this point in the history
[only downstream]
  • Loading branch information
wsmoses committed Sep 18, 2024
1 parent cab8f2d commit bf47fa4
Show file tree
Hide file tree
Showing 2 changed files with 143 additions and 109 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ CUDA_Runtime_jll = "0.15"
ChainRulesCore = "1"
Crayons = "4"
DataFrames = "1"
EnzymeCore = "0.7.3"
EnzymeCore = "0.8"
ExprTools = "0.1"
GPUArrays = "10.0.1"
GPUCompiler = "0.24, 0.25, 0.26, 0.27"
Expand Down
250 changes: 142 additions & 108 deletions ext/EnzymeCoreExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -37,14 +37,14 @@ function metaf(fn, args::Vararg{Any, N}) where N
nothing
end

# Forward-mode rule for `cufunction` when a `Duplicated` result is requested.
#
# The function is called once on the primal arguments `f.val` and `tt.val`
# (forwarding `kwargs`), and the same result object is used as both the primal
# and the shadow of the returned `Duplicated`.
#
# `config` is the rule configuration argument taken by EnzymeCore 0.8 forward rules;
# it is not consulted here because the return annotation is fixed by the signature.
function EnzymeCore.EnzymeRules.forward(config, ofn::Const{typeof(cufunction)},
                                        ::Type{<:Duplicated}, f::Const{F},
                                        tt::Const{TT}; kwargs...) where {F,TT}
    res = ofn.val(f.val, tt.val; kwargs...)
    return Duplicated(res, res)
end

function EnzymeCore.EnzymeRules.forward(ofn::Const{typeof(cufunction)},
function EnzymeCore.EnzymeRules.forward(config, ofn::Const{typeof(cufunction)},
::Type{BatchDuplicated{T,N}}, f::Const{F},
tt::Const{TT}; kwargs...) where {F,TT,T,N}
res = ofn.val(f.val, tt.val; kwargs...)
Expand All @@ -53,24 +53,32 @@ function EnzymeCore.EnzymeRules.forward(ofn::Const{typeof(cufunction)},
end)
end

# Forward-mode rule for `cudaconvert`.
#
# Applies `ofn.val` to the primal and/or shadow value(s) of `x`, depending on what
# `config` requests:
#   * primal and shadow -> `Duplicated` (width 1) or `BatchDuplicated` (width > 1)
#   * shadow only       -> the converted shadow, or a tuple of converted shadows
#   * primal only       -> the converted primal
#   * neither           -> `nothing`
function EnzymeCore.EnzymeRules.forward(config, ofn::Const{typeof(cudaconvert)},
                                        ::Type{RT}, x::IT) where {RT, IT}
    if EnzymeRules.needs_primal(config) && EnzymeRules.needs_shadow(config)
        if EnzymeRules.width(config) == 1
            Duplicated(ofn.val(x.val), ofn.val(x.dval))
        else
            tup = ntuple(Val(EnzymeRules.width(config))) do i
                Base.@_inline_meta
                ofn.val(x.dval[i])::eltype(RT)
            end
            BatchDuplicated(ofn.val(x.val), tup)
        end
    elseif EnzymeRules.needs_shadow(config)
        if EnzymeRules.width(config) == 1
            ofn.val(x.dval)::eltype(RT)
        else
            ntuple(Val(EnzymeRules.width(config))) do i
                Base.@_inline_meta
                ofn.val(x.dval[i])::eltype(RT)
            end
        end
    elseif EnzymeRules.needs_primal(config)
        # The argument of this rule is `x`; there is no `uval` in scope here.
        ofn.val(x.val)::eltype(RT)
    else
        nothing
    end
end

Expand All @@ -93,99 +101,127 @@ function EnzymeCore.EnzymeRules.augmented_primal(config, ofn::Const{typeof(cudac
else
nothing
end
return EnzymeRules.AugmentedReturn{(EnzymeRules.needs_primal(config) ? eltype(RT) : Nothing), (EnzymeRules.needs_shadow(config) ? (EnzymeRules.width(config) == 1 ? eltype(RT) : NTuple{EnzymeRules.width(config), eltype(RT)}) : Nothing), Nothing}(primal, shadow, nothing)
return EnzymeRules.AugmentedReturn{EnzymeRules.primal_type(config, RT), EnzymeRules.shadow_type(config, RT), Nothing}(primal, shadow, nothing)
end

# Reverse pass for `cudaconvert`: returns a one-element tuple holding `nothing`
# for the single argument `x`. `config`, `tape` and `x` are not used.
function EnzymeCore.EnzymeRules.reverse(
    config,
    ofn::Const{typeof(cudaconvert)},
    ::Type{RT},
    tape,
    x::IT,
) where {RT, IT}
    return (nothing,)
end


# Forward-mode rule for constructing a `CuArray` from an `UndefInitializer`
# (`uval`) plus further constructor arguments.
#
# The primal is built by calling the constructor on the primal arguments. Each shadow
# is a separate array built with the same constructor arguments and then filled with
# zeros. What is returned depends on `config`:
#   * primal and shadow -> `Duplicated` (width 1) or `BatchDuplicated` (width > 1)
#   * shadow only       -> one zeroed array, or a tuple of zeroed arrays
#   * primal only       -> the constructed array
#   * neither           -> `nothing`
function EnzymeCore.EnzymeRules.forward(config, ofn::Const{Type{CT}},
        ::Type{RT}, uval::EnzymeCore.Annotation{UndefInitializer}, args...) where {CT <: CuArray, RT}
    # Strip the annotations from the remaining constructor arguments.
    primargs = ntuple(Val(length(args))) do i
        Base.@_inline_meta
        args[i].val
    end

    if EnzymeRules.needs_primal(config) && EnzymeRules.needs_shadow(config)
        if EnzymeRules.width(config) == 1
            shadow = ofn.val(uval.val, primargs...)::CT
            fill!(shadow, 0)
            Duplicated(ofn.val(uval.val, primargs...), shadow)
        else
            tup = ntuple(Val(EnzymeRules.width(config))) do i
                Base.@_inline_meta
                shadow = ofn.val(uval.val, primargs...)::CT
                fill!(shadow, 0)
                shadow::CT
            end
            BatchDuplicated(ofn.val(uval.val, primargs...), tup)
        end
    elseif EnzymeRules.needs_shadow(config)
        if EnzymeRules.width(config) == 1
            shadow = ofn.val(uval.val, primargs...)::CT
            fill!(shadow, 0)
            shadow
        else
            ntuple(Val(EnzymeRules.width(config))) do i
                Base.@_inline_meta
                shadow = ofn.val(uval.val, primargs...)::CT
                fill!(shadow, 0)
                shadow::CT
            end
        end
    elseif EnzymeRules.needs_primal(config)
        ofn.val(uval.val, primargs...)
    else
        nothing
    end
end

# Forward-mode rule for constructing a `CuArray` from a `CUDA.DataRef` (`uval`) plus
# further constructor arguments and keyword arguments.
#
# The primal array is built from `uval.val`; each shadow array is built from the
# matching shadow reference (`uval.dval`, or `uval.dval[i]` in batch mode), so that
# shadows do not wrap the primal's data. What is returned depends on `config`:
#   * primal and shadow -> `Duplicated` (width 1) or `BatchDuplicated` (width > 1)
#   * shadow only       -> one shadow array, or a tuple of shadow arrays
#   * primal only       -> the primal array
#   * neither           -> `nothing`
#
# NOTE(review): the shadow branches read `uval.dval`, so they assume `uval` carries a
# shadow whenever `config` asks for one — confirm against how Enzyme calls this rule.
function EnzymeCore.EnzymeRules.forward(config, ofn::Const{Type{CT}},
        ::Type{RT}, uval::EnzymeCore.Annotation{DR}, args...; kwargs...) where {CT <: CuArray, DR <: CUDA.DataRef, RT}
    # Strip the annotations from the remaining constructor arguments.
    primargs = ntuple(Val(length(args))) do i
        Base.@_inline_meta
        args[i].val
    end

    if EnzymeRules.needs_primal(config) && EnzymeRules.needs_shadow(config)
        if EnzymeRules.width(config) == 1
            shadow = ofn.val(uval.dval, primargs...; kwargs...)
            Duplicated(ofn.val(uval.val, primargs...; kwargs...), shadow)
        else
            tup = ntuple(Val(EnzymeRules.width(config))) do i
                Base.@_inline_meta
                ofn.val(uval.dval[i], primargs...; kwargs...)
            end
            BatchDuplicated(ofn.val(uval.val, primargs...; kwargs...), tup)
        end
    elseif EnzymeRules.needs_shadow(config)
        if EnzymeRules.width(config) == 1
            ofn.val(uval.dval, primargs...; kwargs...)
        else
            ntuple(Val(EnzymeRules.width(config))) do i
                Base.@_inline_meta
                ofn.val(uval.dval[i], primargs...; kwargs...)
            end
        end
    elseif EnzymeRules.needs_primal(config)
        ofn.val(uval.val, primargs...; kwargs...)
    else
        nothing
    end
end

# Forward-mode rule for `synchronize`.
#
# `synchronize` is called exactly once, on the primal values of `args` (forwarding
# `kwargs`). Its result `res` is then reused for every primal and shadow slot that
# `config` requests:
#   * primal and shadow -> `Duplicated(res, res)` or `BatchDuplicated(res, (res, ...))`
#   * shadow only       -> `res`, or a tuple of `res` repeated `width(config)` times
#   * primal only       -> `res`
#   * neither           -> `nothing`
function EnzymeCore.EnzymeRules.forward(config, ofn::Const{typeof(synchronize)},
        ::Type{RT}, args::Vararg{EnzymeCore.Annotation, N}; kwargs...) where {RT, N}
    pargs = ntuple(Val(N)) do i
        Base.@_inline_meta
        args[i].val
    end
    res = ofn.val(pargs...; kwargs...)

    if EnzymeRules.needs_primal(config) && EnzymeRules.needs_shadow(config)
        if EnzymeRules.width(config) == 1
            Duplicated(res, res)
        else
            tup = ntuple(Val(EnzymeRules.width(config))) do i
                Base.@_inline_meta
                res
            end
            # Reuse `res` as the primal: this method has no `uval`/`primargs`, and
            # calling `synchronize` a second time is not needed.
            BatchDuplicated(res, tup)
        end
    elseif EnzymeRules.needs_shadow(config)
        if EnzymeRules.width(config) == 1
            res
        else
            ntuple(Val(EnzymeRules.width(config))) do i
                Base.@_inline_meta
                res
            end
        end
    elseif EnzymeRules.needs_primal(config)
        res
    else
        nothing
    end
end

function EnzymeCore.EnzymeRules.forward(ofn::EnzymeCore.Annotation{CUDA.HostKernel{F,TT}},
function EnzymeCore.EnzymeRules.forward(config, ofn::EnzymeCore.Annotation{CUDA.HostKernel{F,TT}},
::Type{Const{Nothing}}, args...;
kwargs...) where {F,TT}

Expand Down Expand Up @@ -223,7 +259,7 @@ function EnzymeCore.EnzymeRules.augmented_primal(config, ofn::Const{typeof(cufun
else
nothing
end
return EnzymeRules.AugmentedReturn{(EnzymeRules.needs_primal(config) ? CT : Nothing), (EnzymeRules.needs_shadow(config) ? (EnzymeRules.width(config) == 1 ? CT : NTuple{EnzymeRules.width(config), CT}) : Nothing), Nothing}(primal, shadow, nothing)
return EnzymeRules.AugmentedReturn{EnzymeRules.primal_type(config, RT), EnzymeRules.shadow_type(config, RT), Nothing}(primal, shadow, nothing)
end

function EnzymeCore.EnzymeRules.reverse(config, ofn::EnzymeCore.Const{typeof(cufunction)},::Type{RT}, subtape, f, tt; kwargs...) where RT
Expand Down Expand Up @@ -350,7 +386,7 @@ function EnzymeCore.EnzymeRules.reverse(config, ofn::EnzymeCore.Annotation{CUDA.
end
end

function EnzymeCore.EnzymeRules.forward(ofn::Const{typeof(Base.fill!)}, ::Type{RT}, A::EnzymeCore.Annotation{<:DenseCuArray{T}}, x) where {RT, T <: CUDA.MemsetCompatTypes}
function EnzymeCore.EnzymeRules.forward(config, ofn::Const{typeof(Base.fill!)}, ::Type{RT}, A::EnzymeCore.Annotation{<:DenseCuArray{T}}, x) where {RT, T <: CUDA.MemsetCompatTypes}
if A isa Const || A isa Duplicated || A isa BatchDuplicated
ofn.val(A.val, x.val)
end
Expand All @@ -365,16 +401,14 @@ function EnzymeCore.EnzymeRules.forward(ofn::Const{typeof(Base.fill!)}, ::Type{R
end
end

if RT <: Duplicated
return A
elseif RT <: Const
return A.val
elseif RT <: DuplicatedNoNeed
return A.dval
elseif RT <: BatchDuplicated
return A
if EnzymeRules.needs_primal(config) && EnzymeRules.needs_shadow(config)
A
elseif EnzymeRules.needs_shadow(config)
A.dval
elseif EnzymeRules.needs_primal(config)
A.val
else
return A.dval
nothing
end
end

Expand Down Expand Up @@ -469,7 +503,7 @@ function EnzymeCore.EnzymeRules.augmented_primal(config, ofn::Const{Type{CT}}, :
else
nothing
end
return EnzymeRules.AugmentedReturn{(EnzymeRules.needs_primal(config) ? eltype(RT) : Nothing), (EnzymeRules.needs_shadow(config) ? (EnzymeRules.width(config) == 1 ? eltype(RT) : NTuple{EnzymeRules.width(config), eltype(RT)}) : Nothing), Nothing}(primal, shadow, nothing)
return EnzymeRules.AugmentedReturn{EnzymeRules.primal_type(config, RT), EnzymeRules.shadow_type(config, RT), Nothing}(primal, shadow, nothing)
end

function EnzymeCore.EnzymeRules.reverse(config, ofn::Const{Type{CT}}, ::Type{RT}, tape, A::EnzymeCore.Annotation{UndefInitializer}, args::Vararg{EnzymeCore.Annotation, N}) where {CT <: CuArray, RT, N}
Expand Down Expand Up @@ -503,7 +537,7 @@ function EnzymeCore.EnzymeRules.augmented_primal(config, ofn::Const{Type{CT}}, :
else
nothing
end
return EnzymeRules.AugmentedReturn{(EnzymeRules.needs_primal(config) ? eltype(RT) : Nothing), (EnzymeRules.needs_shadow(config) ? (EnzymeRules.width(config) == 1 ? eltype(RT) : NTuple{EnzymeRules.width(config), eltype(RT)}) : Nothing), Nothing}(primal, shadow, nothing)
return EnzymeRules.AugmentedReturn{EnzymeRules.primal_type(config, RT), EnzymeRules.shadow_type(config, RT), Nothing}(primal, shadow, nothing)
end

function EnzymeCore.EnzymeRules.reverse(config, ofn::Const{Type{CT}}, ::Type{RT}, tape, A::EnzymeCore.Annotation{DR}, args::Vararg{EnzymeCore.Annotation, N}; kwargs...) where {CT <: CuArray, DR <: CUDA.DataRef, RT, N}
Expand All @@ -517,7 +551,7 @@ function EnzymeCore.EnzymeRules.noalias(::Type{CT}, ::UndefInitializer, args...)
return nothing
end

function EnzymeCore.EnzymeRules.forward(ofn::Const{typeof(GPUArrays.mapreducedim!)},
function EnzymeCore.EnzymeRules.forward(config, ofn::Const{typeof(GPUArrays.mapreducedim!)},
::Type{RT},
f::EnzymeCore.Const{typeof(Base.identity)},
op::EnzymeCore.Const{typeof(Base.add_sum)},
Expand All @@ -544,16 +578,14 @@ function EnzymeCore.EnzymeRules.forward(ofn::Const{typeof(GPUArrays.mapreducedim
end
end

if RT <: Duplicated
return R
elseif RT <: Const
return R.val
elseif RT <: DuplicatedNoNeed
return R.dval
elseif RT <: BatchDuplicated
return R
if EnzymeRules.needs_primal(config) && EnzymeRules.needs_shadow(config)
R
elseif EnzymeRules.needs_shadow(config)
R.dval
elseif EnzymeRules.needs_primal(config)
R.val
else
return R.dval
nothing
end
end

Expand Down Expand Up @@ -605,34 +637,36 @@ function EnzymeCore.EnzymeRules.reverse(config, ofn::Const{typeof(GPUArrays.mapr
return (nothing, nothing, nothing, nothing)
end

# Forward-mode rule for `GPUArrays._mapreduce` specialised to `f == identity` and
# `op == Base.add_sum` (a sum reduction).
#
# The same reduction is applied to the primal `A.val` and to each shadow (`A.dval`,
# or `A.dval[i]` in batch mode), always with the caller's `dims` and `init`. What is
# returned depends on `config`:
#   * primal and shadow -> `Duplicated` (width 1) or `BatchDuplicated` (width > 1)
#   * shadow only       -> the reduced shadow, or a tuple of reduced shadows
#   * primal only       -> the reduced primal
#   * neither           -> `nothing`
#
# The batch width is taken from `EnzymeRules.width(config)`, the same source used by
# the `width(config) == 1` checks in this method.
function EnzymeCore.EnzymeRules.forward(config, ofn::Const{typeof(GPUArrays._mapreduce)},
                                        ::Type{RT},
                                        f::EnzymeCore.Const{typeof(Base.identity)},
                                        op::EnzymeCore.Const{typeof(Base.add_sum)},
                                        A::EnzymeCore.Annotation{<:AnyCuArray{T}}; dims::D, init) where {RT, T, D}
    if EnzymeRules.needs_primal(config) && EnzymeRules.needs_shadow(config)
        if EnzymeRules.width(config) == 1
            shadow = ofn.val(f.val, op.val, A.dval; dims, init)
            Duplicated(ofn.val(f.val, op.val, A.val; dims, init), shadow)
        else
            tup = ntuple(Val(EnzymeRules.width(config))) do i
                Base.@_inline_meta
                ofn.val(f.val, op.val, A.dval[i]; dims, init)
            end
            BatchDuplicated(ofn.val(f.val, op.val, A.val; dims, init), tup)
        end
    elseif EnzymeRules.needs_shadow(config)
        if EnzymeRules.width(config) == 1
            ofn.val(f.val, op.val, A.dval; dims, init)
        else
            ntuple(Val(EnzymeRules.width(config))) do i
                Base.@_inline_meta
                ofn.val(f.val, op.val, A.dval[i]; dims, init)
            end
        end
    elseif EnzymeRules.needs_primal(config)
        ofn.val(f.val, op.val, A.val; dims, init)
    else
        nothing
    end
end

Expand Down

0 comments on commit bf47fa4

Please sign in to comment.