diff --git a/src/lib/array.jl b/src/lib/array.jl index bdfb908a7..9fe2e14d8 100644 --- a/src/lib/array.jl +++ b/src/lib/array.jl @@ -487,7 +487,8 @@ end X = _pairdiffquotmat(exp, n, w, ew, ew, ew) V = E.vectors VF = factorize(V) - Ā = (V * ((VF \ F̄' * V) .* X) / VF)' + Āc = (V * ((VF \ F̄' * V) .* X) / VF)' + Ā = isreal(A) && isreal(F̄) ? real(Āc) : Āc return (Ā,) end diff --git a/test/gradcheck.jl b/test/gradcheck.jl index e408d1f12..55b27073a 100644 --- a/test/gradcheck.jl +++ b/test/gradcheck.jl @@ -603,6 +603,12 @@ end end end end + A = [ 0.0 1.0 0.0 + 0.0 0.0 1.0 + -4.34 -18.31 -0.43] + _,back = Zygote.pullback(exp,A) + Ȳ = rand(3,3) + @test isreal(back(Ȳ)[1]) end end