Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Deprecate multithreading in bootstrap #674

Merged
merged 8 commits into from
Apr 11, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
MixedModels v4.10.0 Release Notes
==============================
* Multithreading in `parametricbootstrap` with `use_threads` is now deprecated and a no-op. With improvements in BLAS threading, multithreading at the Julia level did not help performance and sometimes hurt it. [#674]

MixedModels v4.9.0 Release Notes
==============================
* Support `StatsModels` 0.7, drop support for `StatsModels` 0.6. [#664]
Expand Down Expand Up @@ -400,3 +404,4 @@ Package dependencies
[#664]: https://github.com/JuliaStats/MixedModels.jl/issues/664
[#665]: https://github.com/JuliaStats/MixedModels.jl/issues/665
[#667]: https://github.com/JuliaStats/MixedModels.jl/issues/667
[#674]: https://github.com/JuliaStats/MixedModels.jl/issues/674
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "MixedModels"
uuid = "ff71e718-51f3-5ec2-a782-8ffcbfa3c316"
author = ["Phillip Alday <[email protected]>", "Douglas Bates <[email protected]>", "Jose Bayoan Santiago Calderon <[email protected]>"]
version = "4.9.0"
version = "4.10.0"

[deps]
Arrow = "69666777-d1a9-59fb-9406-91d4454c9d45"
Expand Down
2 changes: 1 addition & 1 deletion docs/src/bootstrap.md
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ m2 = fit(
DisplayAs.Text(ans) # hide
```
```@example Main
samp2 = parametricbootstrap(rng, 10_000, m2, use_threads=true);
samp2 = parametricbootstrap(rng, 10_000, m2);
df2 = DataFrame(samp2.allpars);
first(df2, 10)
```
Expand Down
54 changes: 13 additions & 41 deletions src/bootstrap.jl
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ end

"""
parametricbootstrap([rng::AbstractRNG], nsamp::Integer, m::MixedModel{T}, ftype=T;
β = coef(m), σ = m.σ, θ = m.θ, use_threads=false, hide_progress=false)
β = coef(m), σ = m.σ, θ = m.θ, hide_progress=false)

Perform `nsamp` parametric bootstrap replication fits of `m`, returning a `MixedModelBootstrap`.

Expand All @@ -52,20 +52,8 @@ performance benefits.
- `β`, `σ`, and `θ` are the values of `m`'s parameters for simulating the responses.
- `σ` is only valid for `LinearMixedModel` and `GeneralizedLinearMixedModel` for
families with a dispersion parameter.
- `use_threads` determines whether or not to use thread-based parallelism.
- `hide_progress` can be used to disable the progress bar. Note that the progress
bar is automatically disabled for non-interactive (i.e. logging) contexts.

!!! note
Note that `use_threads=true` may not offer a performance boost and may even
decrease performance if multithreaded linear algebra (BLAS) routines are available.
In this case, threads at the level of the linear algebra may already occupy all
processors/processor cores. There are plans to provide better support in coordinating
Julia- and BLAS-level threads in the future.

!!! warning
The PRNG shared between threads is locked using `Threads.SpinLock`, which
should not be used recursively. Do not wrap `parametricbootstrap` in an outer `SpinLock`.
"""
function parametricbootstrap(
rng::AbstractRNG,
Expand All @@ -88,35 +76,19 @@ function parametricbootstrap(

β_names = (Symbol.(fixefnames(morig))...,)

# we need arrays of these for in-place operations to work across threads
m_threads = [m]
βsc_threads = [βsc]
θsc_threads = [θsc]

if use_threads
Threads.resize_nthreads!(m_threads)
Threads.resize_nthreads!(βsc_threads)
Threads.resize_nthreads!(θsc_threads)
end
# we use locks to guarantee thread-safety, but there might be better ways to do this for some RNGs
# see https://docs.julialang.org/en/v1.3/manual/parallel-computing/#Side-effects-and-mutable-function-arguments-1
# see https://docs.julialang.org/en/v1/stdlib/Future/index.html
rnglock = Threads.SpinLock()
samp = replicate(n; use_threads=use_threads, hide_progress=hide_progress) do
tidx = use_threads ? Threads.threadid() : 1
mod = m_threads[tidx]
local βsc = βsc_threads[tidx]
local θsc = θsc_threads[tidx]
lock(rnglock)
mod = simulate!(rng, mod; β=β, σ=σ, θ=θ)
unlock(rnglock)
refit!(mod; progress=false)
use_threads && Base.depwarn(
"use_threads is deprecated and will be removed in a future release",
:parametricbootstrap,
)
samp = replicate(n; hide_progress=hide_progress) do
simulate!(rng, m; β, σ, θ)
refit!(m; progress=false)
(
objective=ftype.(mod.objective),
σ=ismissing(mod.σ) ? missing : ftype(mod.σ),
β=NamedTuple{β_names}(fixef!(βsc, mod)),
se=SVector{p,ftype}(stderror!(βsc, mod)),
θ=SVector{k,ftype}(getθ!(θsc, mod)),
objective=ftype.(m.objective),
σ=ismissing(m.σ) ? missing : ftype(m.σ),
β=NamedTuple{β_names}(fixef!(βsc, m)),
se=SVector{p,ftype}(stderror!(βsc, m)),
θ=SVector{k,ftype}(getθ!(θsc, m)),
)
end
return MixedModelBootstrap{ftype}(
Expand Down
27 changes: 9 additions & 18 deletions src/utilities.jl
Original file line number Diff line number Diff line change
Expand Up @@ -124,39 +124,30 @@ end
_is_logging(io) = isa(io, Base.TTY) == false || (get(ENV, "CI", nothing) == "true")

"""
replicate(f::Function, n::Integer; use_threads=false)
replicate(f::Function, n::Integer; hide_progress=false)

Return a vector of the values of `n` calls to `f()` - used in simulations where the value of `f` is stochastic.

`hide_progress` can be used to disable the progress bar. Note that the progress
bar is automatically disabled for non-interactive (i.e. logging) contexts.

!!! warning
If `f()` is not thread-safe or depends on a non-thread-safe RNG,
then you must set `use_threads=false`. Also note that ordering of replications
is not guaranteed when `use_threads=true`, although the replications are not
otherwise affected for thread-safe `f()`.
"""
function replicate(f::Function, n::Integer; use_threads=false, hide_progress=false)
# no macro version yet: https://github.com/timholy/ProgressMeter.jl/issues/143
use_threads && Base.depwarn(
"use_threads is deprecated and will be removed in a future release",
:replicate,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
:replicate,
:replicate

)
# and we want some advanced options
p = Progress(n; output=Base.stderr, enabled=!hide_progress && !_is_logging(stderr))
# get the type
rr = f()
next!(p)
# pre-allocate
results = [rr for _ in Base.OneTo(n)]
if use_threads
Threads.@threads for idx in 2:n
results[idx] = f()
next!(p)
end
else
for idx in 2:n
results[idx] = f()
next!(p)
end
for idx in 2:n
results[idx] = f()
next!(p)
end
finish!(p)
return results
end

Expand Down
12 changes: 2 additions & 10 deletions test/bootstrap.jl
Original file line number Diff line number Diff line change
Expand Up @@ -94,16 +94,8 @@ end
@test propertynames(coefp) == [:iter, :coefname, :β, :se, :z, :p]

@testset "threaded bootstrap" begin
bsamp_threaded = parametricbootstrap(MersenneTwister(1234321), 100, fm;
use_threads=true, hide_progress=true)
# even though it's bad practice with floating point, exact equality should
# be a valid test here -- if everything is working right, then it's the exact
# same operations occurring within each bootstrap sample, which IEEE 754
# guarantees will yield the same result
@test sort(bsamp_threaded.σ) == sort(bsamp.σ)
@test sort(bsamp_threaded.θ) == sort(bsamp.θ)
@test sort(columntable(bsamp_threaded.β).β) == sort(columntable(bsamp.β).β)
@test sum(issingular(bsamp)) == sum(issingular(bsamp_threaded))
@test_logs (:warn, r"use_threads is deprecated") parametricbootstrap(MersenneTwister(1234321), 1, fm;
use_threads=true, hide_progress=true)
end

@testset "zerocorr + Base.length + ftype" begin
Expand Down
12 changes: 1 addition & 11 deletions test/utilities.jl
Original file line number Diff line number Diff line change
Expand Up @@ -40,17 +40,7 @@ end
end

@testset "threaded_replicate" begin
rng = StableRNG(42);
single_thread = replicate(10;use_threads=false) do; only(randn(rng, 1)) ; end
rng = StableRNG(42);
multi_thread = replicate(10;use_threads=true) do
if Threads.threadid() % 2 == 0
sleep(0.001)
end
r = only(randn(rng, 1));
end

@test all(sort!(single_thread) .≈ sort!(multi_thread))
@test_logs (:warn, r"use_threads is deprecated") replicate(string, 1; use_threads=true)
end

@testset "datasets" begin
Expand Down