Skip to content

Commit

Permalink
Fix cholesky part 2 (#1256)
Browse files Browse the repository at this point in the history
* Fix cholesky part 2

* Mark char as inactive
  • Loading branch information
wsmoses authored Jan 28, 2024
1 parent 4dfe204 commit 3095a4f
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 9 deletions.
18 changes: 9 additions & 9 deletions src/internal_rules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,7 @@ end

@inline EnzymeRules.inactive_type(v::Type{Nothing}) = true
@inline EnzymeRules.inactive_type(v::Type{Union{}}) = true
@inline EnzymeRules.inactive_type(v::Type{Char}) = true
@inline EnzymeRules.inactive_type(v::Type{T}) where {T<:Integer} = true
@inline EnzymeRules.inactive_type(v::Type{Function}) = true
@inline EnzymeRules.inactive_type(v::Type{T}) where {T<:DataType} = true
Expand Down Expand Up @@ -138,7 +139,7 @@ function EnzymeRules.forward(::Const{typeof(Base.deepcopy)}, ::Type{<:BatchDupli
end

# Deepcopy preserving the primal if runtime inactive
@inline function deepcopy_rtact(copied::RT, primal::RT, seen::IdDict, shadow::RT) where {RT <: Integer}
@inline function deepcopy_rtact(copied::RT, primal::RT, seen::IdDict, shadow::RT) where {RT <: Union{Integer, Char}}
return Base.deepcopy_internal(shadow, seen)
end
@inline function deepcopy_rtact(copied::RT, primal::RT, seen::IdDict, shadow::RT) where {RT <: AbstractFloat}
Expand Down Expand Up @@ -684,7 +685,7 @@ function EnzymeRules.forward(
end
end

function EnzymeRules.augmented_primal(config, func::Const{typeof(cholesky)}, RT::Type, A; kwargs...)
function EnzymeRules.augmented_primal(config, func::Const{typeof(cholesky)}, RT::Type, A::Annotation{AT}; kwargs...) where {AT <: Array}
fact = if EnzymeRules.needs_primal(config)
cholesky(A.val; kwargs...)
else
Expand All @@ -696,11 +697,11 @@ function EnzymeRules.augmented_primal(config, func::Const{typeof(cholesky)}, RT:
nothing
else
if EnzymeRules.width(config) == 1
Cholesky(Matrix(fact), 'L', 0)
Enzyme.make_zero(fact)
else
ntuple(Val(EnzymeRules.width(config))) do i
Base.@_inline_meta
Cholesky(Matrix(fact), 'L', 0)
Enzyme.make_zero(fact)
end
end
end
Expand All @@ -718,8 +719,8 @@ function EnzymeRules.reverse(
::Const{typeof(cholesky)},
RT::Type,
dfact,
A;
kwargs...)
A::Annotation{AT};
kwargs...) where {AT <: Array}

if !(RT <: Const) && !isa(A, Const)
dAs = EnzymeRules.width(config) == 1 ? (A.dval,) : A.dval
Expand Down Expand Up @@ -813,14 +814,13 @@ function EnzymeRules.reverse(
# dA −= z B(out)^T

func.val(cache_A, dB, kwargs...)

if !isa(A, Const)
dA = EnzymeRules.width(config) == 1 ? A.dval : A.dval[b]

if AType <: Array
mul!(dA, dB, transpose(cache_Bout), 1, -1)
mul!(dA, dB, transpose(cache_Bout), -1, 1)
else
mul!(dA.factors, dB, transpose(cache_Bout), 1, -1)
mul!(dA.factors, dB, transpose(cache_Bout), -1, 1)
end
end
end
Expand Down
4 changes: 4 additions & 0 deletions src/typetree.jl
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,10 @@ function typetree(::Type{T}, ctx, dl, seen=nothing) where T <: Integer
return TypeTree(API.DT_Integer, -1, ctx)
end

function typetree(::Type{Char}, ctx, dl, seen=nothing)
return TypeTree(API.DT_Integer, -1, ctx)
end

function typetree(::Type{Float16}, ctx, dl, seen=nothing)
return TypeTree(API.DT_Half, -1, ctx)
end
Expand Down
16 changes: 16 additions & 0 deletions test/internal_rules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -341,6 +341,22 @@ end
@test isapprox(fwdJ, fdJ)
end
@test isapprox(fwdJ, revJ)

function h(A, b)
C = cholesky(A)
b2 = copy(b)
ldiv!(C, b2)
@inbounds b2[1]
end

A = [1.3 0.5; 0.5 1.5]
b = [1., 2.]
V = [1.0 0.0; 0.0 0.0]
dA = zero(A)
Enzyme.autodiff(Reverse, h, Active, Duplicated(A, dA), Const(b))

dA_sym = - (transpose(A) \ [1.0, 0.0]) * transpose(A \ b)
@test isapprox(dA, dA_sym)
end
end
end
Expand Down

0 comments on commit 3095a4f

Please sign in to comment.