Skip to content

Commit

Permalink
Add get_num_threads (#171)
Browse files Browse the repository at this point in the history
* Add get_num_threads

This commit adds `get_num_threads`, which returns the number of threads used by
the planner, and is the complement to `set_num_threads`. This simply wraps the
function `fftw_planner_nthreads`, which was
[newly added to fftw in version 3.3.9](https://github.com/FFTW/fftw3/blob/34082eb5d6ed7dc9436915df69f376c06fc39762/NEWS#L3).

* Set FFTW_jll compat to 3.3.9

`get_num_threads` requires FFTW_jll v3.3.9+7, but it doesn't seem possible to
specify a particular build in the compat section of Project.toml files. However,
this should work in most cases, as the most recent build of `FFTW_jll` should be
downloaded upon updating.

* bump to 1.3 for the new function

* Make test for get_num_threads fftw specific

No equivalent function for mkl

* Typo...

* another typo

* Add vendor check to `get_num_threads`

* Add a method of `set_num_threads` that restores the original nthreads

Additionally, separate previous `set_num_threads` method into a base function,
`_set_num_threads`, that wraps the `ccalls`, and `set_num_threads`, which will
acquire the `fftwlock`.

* Provide support for `get_num_threads` with MKL's FFTW

While MKL's FFTW does not provide access to the number of threads available to
the planner, this can be simulated by caching the value last passed to
`set_num_threads` and returning it with `get_num_threads` if
`fftw_vendor == :mkl`.

* Implement suggestions of @stevengj

* Fix typo in set_num_threads

* Add test for set_num_threads method that restores original num_threads

* Rename `nthreads` variable to `num_threads` to avoid shadowing Threads.nthreads

Since FFTW uses `Base.Threads`, and `nthreads` is a function defined in
`Base.Threads`, then the function argument `nthreads` shadows a function already
in the namespace of every function. While there is no inherent issue with this,
it can make debugging this code more confusing.

* Make one-line method of `set_num_threads` one line.

* First attempt at adding `num_threads` to `plan_...` functions

As suggested by @stevengj, I have add a `num_threads` keyword to the `plan_...`
functions. My approach here is fairly naive, and adds a bunch of redundant
boiler plate code to every `plan_` function.

Co-authored-by: Steven G. Johnson <[email protected]>
Co-authored-by: Mosè Giordano <[email protected]>
  • Loading branch information
3 people authored Mar 14, 2022
1 parent 683a6e8 commit 17bc81a
Show file tree
Hide file tree
Showing 5 changed files with 128 additions and 22 deletions.
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "FFTW"
uuid = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341"
version = "1.4.6"
version = "1.5.0"

[deps]
AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c"
Expand All @@ -12,7 +12,7 @@ Reexport = "189a3867-3050-52da-a836-e630ba90ab69"

[compat]
AbstractFFTs = "1.0"
FFTW_jll = "3.3"
FFTW_jll = "3.3.9"
MKL_jll = "2019.0.117, 2020, 2021, 2022"
Preferences = "1.2"
Reexport = "0.2, 1.0"
Expand Down
8 changes: 4 additions & 4 deletions src/dct.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,14 @@
# (This is part of the FFTW module.)

"""
plan_dct!(A [, dims [, flags [, timelimit]]])
plan_dct!(A [, dims [, flags [, timelimit [, num_threads]]]])
Same as [`plan_dct`](@ref), but operates in-place on `A`.
"""
function plan_dct! end

"""
plan_idct(A [, dims [, flags [, timelimit]]])
plan_idct(A [, dims [, flags [, timelimit [, num_threads]]]])
Pre-plan an optimized inverse discrete cosine transform (DCT), similar to
[`plan_fft`](@ref) except producing a function that computes
Expand All @@ -20,7 +20,7 @@ Pre-plan an optimized inverse discrete cosine transform (DCT), similar to
function plan_idct end

"""
plan_dct(A [, dims [, flags [, timelimit]]])
plan_dct(A [, dims [, flags [, timelimit [, num_threads]]]])
Pre-plan an optimized discrete cosine transform (DCT), similar to
[`plan_fft`](@ref) except producing a function that computes
Expand All @@ -30,7 +30,7 @@ Pre-plan an optimized discrete cosine transform (DCT), similar to
function plan_dct end

"""
plan_idct!(A [, dims [, flags [, timelimit]]])
plan_idct!(A [, dims [, flags [, timelimit [, num_threads]]]])
Same as [`plan_idct`](@ref), but operates in-place on `A`.
"""
Expand Down
126 changes: 110 additions & 16 deletions src/fft.jl
Original file line number Diff line number Diff line change
Expand Up @@ -38,14 +38,14 @@ an array of real or complex floating-point numbers.
function r2r! end

"""
plan_r2r!(A, kind [, dims [, flags [, timelimit]]])
plan_r2r!(A, kind [, dims [, flags [, timelimit [, num_threads]]]])
Similar to [`plan_fft`](@ref), but corresponds to [`r2r!`](@ref).
"""
function plan_r2r! end

"""
plan_r2r(A, kind [, dims [, flags [, timelimit]]])
plan_r2r(A, kind [, dims [, flags [, timelimit [, num_threads]]]])
Pre-plan an optimized r2r transform, similar to [`plan_fft`](@ref)
except that the transforms (and the first three arguments)
Expand Down Expand Up @@ -171,9 +171,33 @@ end

# Threads

@exclusive function set_num_threads(nthreads::Integer)
ccall((:fftw_plan_with_nthreads,libfftw3[]), Cvoid, (Int32,), nthreads)
ccall((:fftwf_plan_with_nthreads,libfftw3f[]), Cvoid, (Int32,), nthreads)
# Must only be called after acquiring fftwlock
function _set_num_threads(num_threads::Integer)
@static if fftw_provider == "mkl"
_last_num_threads[] = num_threads
end
ccall((:fftw_plan_with_nthreads,libfftw3[]), Cvoid, (Int32,), num_threads)
ccall((:fftwf_plan_with_nthreads,libfftw3f[]), Cvoid, (Int32,), num_threads)
end

@exclusive set_num_threads(num_threads::Integer) = _set_num_threads(num_threads)

function get_num_threads()
@static if fftw_provider == "fftw"
ccall((:fftw_planner_nthreads,libfftw3[]), Cint, ())
else
_last_num_threads[]
end
end

@exclusive function set_num_threads(f::Function, num_threads::Integer)
orig_num_threads = get_num_threads()
_set_num_threads(num_threads)
try
f()
finally
_set_num_threads(orig_num_threads)
end
end

# pointer type for fftw_plan (opaque pointer)
Expand Down Expand Up @@ -684,22 +708,43 @@ for (f,direction) in ((:fft,FORWARD), (:bfft,BACKWARD))
@eval begin
function $plan_f(X::StridedArray{T,N}, region;
flags::Integer=ESTIMATE,
timelimit::Real=NO_TIMELIMIT) where {T<:fftwComplex,N}
timelimit::Real=NO_TIMELIMIT,
num_threads::Union{Nothing, Integer} = nothing) where {T<:fftwComplex,N}
if num_threads !== nothing
plan = set_num_threads(num_threads) do
$plan_f(X, region; flags = flags, timelimit = timelimit)
end
return plan
end
cFFTWPlan{T,$direction,false,N}(X, fakesimilar(flags, X, T),
region, flags, timelimit)
end

function $plan_f!(X::StridedArray{T,N}, region;
flags::Integer=ESTIMATE,
timelimit::Real=NO_TIMELIMIT) where {T<:fftwComplex,N}
flags::Integer=ESTIMATE,
timelimit::Real=NO_TIMELIMIT,
num_threads::Union{Nothing, Integer} = nothing ) where {T<:fftwComplex,N}
if num_threads !== nothing
plan = set_num_threads(num_threads) do
$plan_f!(X, region; flags = flags, timelimit = timelimit)
end
return plan
end
cFFTWPlan{T,$direction,true,N}(X, X, region, flags, timelimit)
end
$plan_f(X::StridedArray{<:fftwComplex}; kws...) =
$plan_f(X, 1:ndims(X); kws...)
$plan_f!(X::StridedArray{<:fftwComplex}; kws...) =
$plan_f!(X, 1:ndims(X); kws...)

function plan_inv(p::cFFTWPlan{T,$direction,inplace,N}) where {T<:fftwComplex,N,inplace}
function plan_inv(p::cFFTWPlan{T,$direction,inplace,N};
num_threads::Union{Nothing, Integer} = nothing) where {T<:fftwComplex,N,inplace}
if num_threads !== nothing
plan = set_num_threads(num_threads) do
plan_inv(p)
end
return plan
end
X = Array{T}(undef, p.sz)
Y = inplace ? X : fakesimilar(p.flags, X, T)
ScaledPlan(cFFTWPlan{T,$idirection,inplace,N}(X, Y, p.region,
Expand Down Expand Up @@ -735,15 +780,29 @@ for (Tr,Tc) in ((:Float32,:(Complex{Float32})),(:Float64,:(Complex{Float64})))
@eval begin
function plan_rfft(X::StridedArray{$Tr,N}, region;
flags::Integer=ESTIMATE,
timelimit::Real=NO_TIMELIMIT) where N
timelimit::Real=NO_TIMELIMIT,
num_threads::Union{Nothing, Integer} = nothing) where N
if num_threads !== nothing
plan = set_num_threads(num_threads) do
plan_rfft(X, region; flags = flags, timelimit = timelimit)
end
return plan
end
osize = rfft_output_size(X, region)
Y = flags&ESTIMATE != 0 ? FakeArray{$Tc}(osize) : Array{$Tc}(undef, osize)
rFFTWPlan{$Tr,$FORWARD,false,N}(X, Y, region, flags, timelimit)
end

function plan_brfft(X::StridedArray{$Tc,N}, d::Integer, region;
flags::Integer=ESTIMATE,
timelimit::Real=NO_TIMELIMIT) where N
timelimit::Real=NO_TIMELIMIT,
num_threads::Union{Nothing, Integer} = nothing) where N
if num_threads !== nothing
plan = set_num_threads(num_threads) do
plan_brfft(X, d, region; flags = flags, timelimit = timelimit)
end
return plan
end
osize = brfft_output_size(X, d, region)
Y = flags&ESTIMATE != 0 ? FakeArray{$Tr}(osize) : Array{$Tr}(undef, osize)

Expand All @@ -763,7 +822,14 @@ for (Tr,Tc) in ((:Float32,:(Complex{Float32})),(:Float64,:(Complex{Float64})))
plan_rfft(X::StridedArray{$Tr};kws...)=plan_rfft(X,1:ndims(X);kws...)
plan_brfft(X::StridedArray{$Tr};kws...)=plan_brfft(X,1:ndims(X);kws...)

function plan_inv(p::rFFTWPlan{$Tr,$FORWARD,false,N}) where N
function plan_inv(p::rFFTWPlan{$Tr,$FORWARD,false,N},
num_threads::Union{Nothing, Integer} = nothing) where N
if num_threads !== nothing
plan = set_num_threads(num_threads) do
plan_inv(p)
end
return plan
end
X = Array{$Tr}(undef, p.sz)
Y = p.flags&ESTIMATE != 0 ? FakeArray{$Tc}(p.osz) : Array{$Tc}(undef, p.osz)
ScaledPlan(rFFTWPlan{$Tc,$BACKWARD,false,N}(Y, X, p.region,
Expand All @@ -773,7 +839,14 @@ for (Tr,Tc) in ((:Float32,:(Complex{Float32})),(:Float64,:(Complex{Float64})))
normalization(X, p.region))
end

function plan_inv(p::rFFTWPlan{$Tc,$BACKWARD,false,N}) where N
function plan_inv(p::rFFTWPlan{$Tc,$BACKWARD,false,N};
num_threads::Union{Nothing, Integer} = nothing) where N
if num_threads !== nothing
plan = set_num_threads(num_threads) do
plan_inv(p)
end
return plan
end
X = Array{$Tc}(undef, p.sz)
Y = p.flags&ESTIMATE != 0 ? FakeArray{$Tr}(p.osz) : Array{$Tr}(undef, p.osz)
ScaledPlan(rFFTWPlan{$Tr,$FORWARD,false,N}(Y, X, p.region,
Expand Down Expand Up @@ -832,14 +905,28 @@ end

function plan_r2r(X::StridedArray{T,N}, kinds, region;
flags::Integer=ESTIMATE,
timelimit::Real=NO_TIMELIMIT) where {T<:fftwNumber,N}
timelimit::Real=NO_TIMELIMIT,
num_threads::Union{Nothing, Integer} = nothing) where {T<:fftwNumber,N}
if num_threads !== nothing
plan = set_num_threads(num_threads) do
plan_r2r(X, kinds, region; flags = flags, timelimit = timelimit)
end
return plan
end
r2rFFTWPlan{T,Any,false,N}(X, fakesimilar(flags, X, T), region, kinds,
flags, timelimit)
end

function plan_r2r!(X::StridedArray{T,N}, kinds, region;
flags::Integer=ESTIMATE,
timelimit::Real=NO_TIMELIMIT) where {T<:fftwNumber,N}
timelimit::Real=NO_TIMELIMIT,
num_threads::Union{Nothing, Integer} = nothing) where {T<:fftwNumber,N}
if num_threads !== nothing
plan = set_num_threads(num_threads) do
plan_r2r(X, kinds, region; flags = flags, timelimit = timelimit)
end
return plan
end
r2rFFTWPlan{T,Any,true,N}(X, X, region, kinds, flags, timelimit)
end

Expand All @@ -861,7 +948,14 @@ function logical_size(n::Integer, k::Integer)
return 2n
end

function plan_inv(p::r2rFFTWPlan{T,K,inplace,N}) where {T<:fftwNumber,K,inplace,N}
function plan_inv(p::r2rFFTWPlan{T,K,inplace,N};
num_threads::Union{Nothing, Integer} = nothing) where {T<:fftwNumber,K,inplace,N}
if num_threads !== nothing
set_num_threads(num_threads) do
plan = plan_inv(p)
end
return plan
end
X = Array{T}(undef, p.sz)
iK = fix_kinds(p.region, [inv_kind[k] for k in K])
Y = inplace ? X : fakesimilar(p.flags, X, T)
Expand Down
1 change: 1 addition & 0 deletions src/providers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -85,4 +85,5 @@ end
import MKL_jll
libfftw3[] = MKL_jll.libmkl_rt_path
libfftw3f[] = MKL_jll.libmkl_rt_path
const _last_num_threads = Ref(Cint(1))
end
11 changes: 11 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -528,3 +528,14 @@ end
@test occursin("dft-thr", string(p2))
end
end

@testset "Setting and getting planner nthreads" begin
FFTW.set_num_threads(1)
@test FFTW.get_num_threads() == 1
FFTW.set_num_threads(2)
@test FFTW.get_num_threads() == 2
plan = FFTW.set_num_threads(1) do # Should leave get_num_threads unchanged
plan_rfft(m4, 1)
end
@test FFTW.get_num_threads() == 2 # Unchanged
end

2 comments on commit 17bc81a

@ararslan
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/62829

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v1.5.0 -m "<description of version>" 17bc81a0fcf9875d777ea4bee2fca70fc23c8a0c
git push origin v1.5.0

Please sign in to comment.