Zygote won't work with FFT Plans #11

Closed
pzimbrod opened this issue Sep 1, 2021 · 3 comments
Labels
bug Something isn't working


@pzimbrod
Contributor

pzimbrod commented Sep 1, 2021

When taking the Jacobian of a sample FFT pipeline, Zygote complains about a missing field in the corresponding FFT plan:

using Zygote, FFTW

n = rand(100);
f = plan_rfft(n);
fi = plan_irfft(rfft(n), length(n));
fi * (f * n);

jacobian(x -> fi * (f * x), n)
ERROR: type ScaledPlan has no field region
Stacktrace:
  [1] getproperty(x::AbstractFFTs.ScaledPlan{ComplexF64, FFTW.rFFTWPlan{ComplexF64, 1, false, 1, UnitRange{Int64}}, Float64}, f::Symbol)
    @ Base ./Base.jl:33
  [2] (::Zygote.var"#931#932"{AbstractFFTs.ScaledPlan{ComplexF64, FFTW.rFFTWPlan{ComplexF64, 1, false, 1, UnitRange{Int64}}, Float64}, Vector{ComplexF64}})(Δ::Vector{Float64})
    @ Zygote ~/.julia/packages/Zygote/TaBlo/src/lib/array.jl:788
  [3] (::Zygote.var"#3479#back#933"{Zygote.var"#931#932"{AbstractFFTs.ScaledPlan{ComplexF64, FFTW.rFFTWPlan{ComplexF64, 1, false, 1, UnitRange{Int64}}, Float64}, Vector{ComplexF64}}})(Δ::Vector{Float64})
    @ Zygote ~/.julia/packages/ZygoteRules/OjfTt/src/adjoint.jl:59
  [4] Pullback
    @ ./REPL[6]:1 [inlined]
  [5] (::typeof((#1)))(Δ::Vector{Float64})
    @ Zygote ~/.julia/packages/Zygote/TaBlo/src/compiler/interface2.jl:0
  [6] (::Zygote.var"#209#210"{Tuple{Tuple{Nothing}}, typeof((#1))})(Δ::Vector{Float64})
    @ Zygote ~/.julia/packages/Zygote/TaBlo/src/lib/lib.jl:203
  [7] (::Zygote.var"#1746#back#211"{Zygote.var"#209#210"{Tuple{Tuple{Nothing}}, typeof((#1))}})(Δ::Vector{Float64})
    @ Zygote ~/.julia/packages/ZygoteRules/OjfTt/src/adjoint.jl:59
  [8] Pullback
    @ ./operators.jl:938 [inlined]
  [9] (::typeof((ComposedFunction{typeof(Zygote._jvec), var"#1#2"}(Zygote._jvec, var"#1#2"()))))(Δ::Vector{Float64})
    @ Zygote ~/.julia/packages/Zygote/TaBlo/src/compiler/interface2.jl:0
 [10] (::Zygote.var"#46#47"{typeof((ComposedFunction{typeof(Zygote._jvec), var"#1#2"}(Zygote._jvec, var"#1#2"())))})(Δ::Vector{Float64})
    @ Zygote ~/.julia/packages/Zygote/TaBlo/src/compiler/interface.jl:41
 [11] withjacobian(f::Function, args::Vector{Float64})
    @ Zygote ~/.julia/packages/Zygote/TaBlo/src/lib/grad.jl:162
 [12] jacobian(f::Function, args::Vector{Float64})
    @ Zygote ~/.julia/packages/Zygote/TaBlo/src/lib/grad.jl:140
 [13] top-level scope
    @ REPL[6]:1

However, when the unplanned rfft/irfft functions are used directly instead of precomputed plans, the error vanishes:

jacobian(x -> irfft(rfft(x), length(x)), n)
([1.0 0.0 … 0.0 0.0; 0.0 1.0 … 0.0 6.187926798518962e-19; … ; 0.0 0.0 … 1.0 0.0; 0.0 5.204170427930421e-18 … 0.0 1.0],)

This is a problem because we compute the FFT and its inverse many times, so the speed-up from using precomputed plans should be considerable.

@pzimbrod pzimbrod self-assigned this Sep 1, 2021
@pzimbrod pzimbrod added the bug Something isn't working label Sep 1, 2021
@pzimbrod
Contributor Author

pzimbrod commented Sep 1, 2021

This problem is mentioned in FluxML/Zygote.jl#899 and subsequently in JuliaMath/FFTW.jl#182. Zygote's adjoint for plan multiplication accesses the region field of the plan, but some plan types lack that field, such as the ScaledPlan returned by plan_irfft here.
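A quick way to see which of the two plans lacks the field, assuming the same setup as in the original report (the exact plan types may differ between FFTW.jl/AbstractFFTs.jl versions):

```julia
using FFTW

n = rand(100)
f = plan_rfft(n)                      # a plain FFTW real-to-complex plan
fi = plan_irfft(rfft(n), length(n))   # a ScaledPlan wrapping an unscaled brfft plan

# Zygote's plan adjoint reads `region`; per the error above, only `f` provides it.
println(typeof(fi))
println(hasproperty(f, :region))
println(hasproperty(fi, :region))
```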

The problem can plausibly be avoided by switching to the unscaled inverse transform and doing the scaling manually afterwards.

@pzimbrod
Contributor Author

pzimbrod commented Sep 1, 2021

Apparently, you can use plan_brfft instead for the inverse and scale afterwards:

fib = plan_brfft(rfft(n), length(n));

# Check for same results
irfft(f * n, length(n)) ≈ fib * (f * n) ./ length(n)
true

# Does Zygote run now?
jacobian(x -> fib * (f * x) ./ length(x), n)
([0.51 0.0 … 0.0 0.0; 0.0 0.51 … 0.0 3.996891658508158e-19; … ; 0.0 0.0 … 0.51 0.0; 0.0 8.673617379884035e-19 … 0.0 0.51],)

@pzimbrod
Contributor Author

Another workaround would be to define custom adjoints for Zygote, as partially discussed here. That will probably require some serious fiddling, though.
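A rough sketch of that custom-adjoint route, assuming 1-D real input: apply_rfft, apply_brfft and halfspectrum_weights are hypothetical helper names, not part of Zygote or FFTW. The pullbacks follow from the half-spectrum weighting (the DC bin, and the Nyquist bin for even lengths, appear once in the full spectrum; every other bin stands for a conjugate pair). For brevity the backward pass calls the unplanned rfft/brfft. Since the round trip is the identity map, its Jacobian should be the identity matrix. This is worth checking, because the plan_brfft output above shows 0.51 rather than 1.0 on the diagonal.

```julia
using Zygote, FFTW

# Hypothetical wrappers (not part of Zygote or FFTW). The forward pass applies a
# precomputed plan; the custom adjoints below replace Zygote's generic plan rule,
# which is the one that reads the missing `region` field.
apply_rfft(p, x) = p * x        # p = plan_rfft(x)
apply_brfft(p, y, n) = p * y    # p = plan_brfft(y, n); n = length of the real output

# Half-spectrum weights: DC (and Nyquist for even n) count once, all other bins twice.
function halfspectrum_weights(n::Integer)
    m = n ÷ 2 + 1
    return [k == 1 || (iseven(n) && k == m) ? 1.0 : 2.0 for k in 1:m]
end

# Pullback of y = rfft(x) for real 1-D x: x̄ = brfft(ȳ ./ w, n).
Zygote.@adjoint function apply_rfft(p, x::AbstractVector{<:Real})
    n = length(x)
    w = halfspectrum_weights(n)
    return p * x, Δ -> (nothing, brfft(Δ ./ w, n))
end

# Pullback of x = brfft(y, n): ȳ = w .* rfft(x̄). No gradient for the plan or n.
Zygote.@adjoint function apply_brfft(p, y::AbstractVector{<:Complex}, n::Integer)
    w = halfspectrum_weights(n)
    return p * y, Δ -> (nothing, w .* rfft(Δ), nothing)
end

# Round trip as in the report: planned rfft, planned unscaled inverse, then scaling.
x = rand(100)
n = length(x)
pf = plan_rfft(x)
pb = plan_brfft(rfft(x), n)
J = jacobian(v -> apply_brfft(pb, apply_rfft(pf, v), n) ./ n, x)[1]

# Largest deviation of J from the identity matrix; should be at rounding-error level.
println(maximum(abs.(J .- [i == j ? 1.0 : 0.0 for i in 1:n, j in 1:n])))
```

Replacing the unplanned rfft/brfft calls inside the pullbacks with additional precomputed plans would keep the backward pass planned as well.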
