From 846a5bb8d08bd3459062441a379de8763623b97c Mon Sep 17 00:00:00 2001 From: Gaurav Arya Date: Mon, 6 Mar 2023 16:26:53 -0500 Subject: [PATCH 1/4] Ensure all fft-like functions fallback to version with region when region not provided --- src/definitions.jl | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/definitions.jl b/src/definitions.jl index 4532650..1cf542b 100644 --- a/src/definitions.jl +++ b/src/definitions.jl @@ -59,7 +59,7 @@ _to1(::Tuple, x) = copy1(eltype(x), x) for f in (:fft, :bfft, :ifft, :fft!, :bfft!, :ifft!, :rfft) pf = Symbol("plan_", f) @eval begin - $f(x::AbstractArray) = (y = to1(x); $pf(y) * y) + $f(x::AbstractArray) = $f(x, 1:ndims(x)) $f(x::AbstractArray, region) = (y = to1(x); $pf(y, region) * y) $pf(x::AbstractArray; kws...) = (y = to1(x); $pf(y, 1:ndims(y); kws...)) end @@ -207,9 +207,9 @@ bfft! for f in (:fft, :bfft, :ifft) pf = Symbol("plan_", f) @eval begin - $f(x::AbstractArray{<:Real}, region=1:ndims(x)) = $f(complexfloat(x), region) + $f(x::AbstractArray{<:Real}, region) = $f(complexfloat(x), region) $pf(x::AbstractArray{<:Real}, region; kws...) = $pf(complexfloat(x), region; kws...) - $f(x::AbstractArray{<:Complex{<:Union{Integer,Rational}}}, region=1:ndims(x)) = $f(complexfloat(x), region) + $f(x::AbstractArray{<:Complex{<:Union{Integer,Rational}}}, region) = $f(complexfloat(x), region) $pf(x::AbstractArray{<:Complex{<:Union{Integer,Rational}}}, region; kws...) = $pf(complexfloat(x), region; kws...) end end @@ -297,7 +297,7 @@ LinearAlgebra.mul!(y::AbstractArray, p::ScaledPlan, x::AbstractArray) = for f in (:brfft, :irfft) pf = Symbol("plan_", f) @eval begin - $f(x::AbstractArray, d::Integer) = $pf(x, d) * x + $f(x::AbstractArray, d::Integer) = $f(x, d, 1:ndims(x)) $f(x::AbstractArray, d::Integer, region) = $pf(x, d, region) * x $pf(x::AbstractArray, d::Integer;kws...) = $pf(x, d, 1:ndims(x);kws...) end @@ -305,8 +305,8 @@ end for f in (:brfft, :irfft) @eval begin - $f(x::AbstractArray{<:Real}, d::Integer, region=1:ndims(x)) = $f(complexfloat(x), d, region) - $f(x::AbstractArray{<:Complex{<:Union{Integer,Rational}}}, d::Integer, region=1:ndims(x)) = $f(complexfloat(x), d, region) + $f(x::AbstractArray{<:Real}, d::Integer, region) = $f(complexfloat(x), d, region) + $f(x::AbstractArray{<:Complex{<:Union{Integer,Rational}}}, d::Integer, region) = $f(complexfloat(x), d, region) end end From 20de5e72a51ed97a7ae1fb9f6e1e358bedc302c5 Mon Sep 17 00:00:00 2001 From: Gaurav Arya Date: Mon, 6 Mar 2023 16:57:58 -0500 Subject: [PATCH 2/4] Add testset for default dims --- test/runtests.jl | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/test/runtests.jl b/test/runtests.jl index 4d402c5..c2ba8d1 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -213,6 +213,21 @@ end @test @inferred(f9(plan_fft(zeros(10), 1), 10)) == 1/10 end +# Test that dims defaults to 1:ndims for fft-like functions +@testset "Default dims" begin + for x in (randn(3), randn(3, 4), randn(3, 4, 5)) + N = ndims(x) + complex_x = complex.(x) + @test fft(x) ≈ fft(x, 1:N) + @test ifft(x) ≈ ifft(x, 1:N) + @test bfft(x) ≈ bfft(x, 1:N) + @test rfft(x) ≈ rfft(x, 1:N) + d = 2 * size(x, 1) - 1 + @test irfft(x, d) ≈ irfft(x, d, 1:N) + @test brfft(x, d) ≈ brfft(x, d, 1:N) + end +end + @testset "ChainRules" begin @testset "shift functions" begin for x in (randn(3), randn(3, 4), randn(3, 4, 5)) From 3cbb2bb9990311b04378ae040fbc6e327794fcdd Mon Sep 17 00:00:00 2001 From: Gaurav Arya Date: Mon, 6 Mar 2023 17:16:19 -0500 Subject: [PATCH 3/4] Add tests for complex float promotion --- test/runtests.jl | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/test/runtests.jl b/test/runtests.jl index c2ba8d1..c48516c 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -228,6 +228,16 @@ end end end +@testset "Complex float promotion for backwards real FFTs" begin + for x in (rand(-5:5, 3), rand(-5:5, 3, 4), rand(-5:5, 3, 4, 5)) + N = ndims(x) + complex_x = complex.(x) + d = 2 * size(x, 1) - 1 + @test irfft(x, d) ≈ irfft(complex.(x), d) ≈ irfft(complex.(float.(x)), d) + @test brfft(x, d) ≈ brfft(complex.(x), d) ≈ brfft(complex.(float.(x)), d) + end +end + @testset "ChainRules" begin @testset "shift functions" begin for x in (randn(3), randn(3, 4), randn(3, 4, 5)) From bb29f03370c9e70557f9b845706c05703ad90c3d Mon Sep 17 00:00:00 2001 From: Gaurav Arya Date: Mon, 6 Mar 2023 17:35:22 -0500 Subject: [PATCH 4/4] Test complex float promotion for fft,ifft,bfft too --- test/runtests.jl | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/test/runtests.jl b/test/runtests.jl index c48516c..9cb528a 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -228,10 +228,12 @@ end end end -@testset "Complex float promotion for backwards real FFTs" begin +@testset "Complex float promotion" begin for x in (rand(-5:5, 3), rand(-5:5, 3, 4), rand(-5:5, 3, 4, 5)) N = ndims(x) - complex_x = complex.(x) + @test fft(x) ≈ fft(complex.(x)) ≈ fft(complex.(float.(x))) + @test ifft(x) ≈ ifft(complex.(x)) ≈ ifft(complex.(float.(x))) + @test bfft(x) ≈ bfft(complex.(x)) ≈ bfft(complex.(float.(x))) d = 2 * size(x, 1) - 1 @test irfft(x, d) ≈ irfft(complex.(x), d) ≈ irfft(complex.(float.(x)), d) @test brfft(x, d) ≈ brfft(complex.(x), d) ≈ brfft(complex.(float.(x)), d)