Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix indexing for chains in different threads #154

Merged
merged 3 commits into from
Jan 10, 2025

Conversation

penelopeysm
Copy link
Member

@penelopeysm penelopeysm commented Jan 8, 2025

Closes #153

If there were 6 threads and 9 chains, the previous code (

chainidxs = if i == nchunks
((i - 1) * chunksize + 1):nchains
else
((i - 1) * chunksize + 1):(i * chunksize)
end
) would generate the following chainidxs for each thread

1:2
3:4
5:6
7:8
9:10
11:10

Somewhere later on (probably in the fifth thread with chainidxs = 9:10) this causes indexing errors as we only have a vector of size 9.

This PR changes it such that for the same inputs (6 threads and 9 chains) it generates the following chainidxs

1:2
3:4
5:6
7:7
8:8
9:9

The error in #153 is fixed

julia> samples = [sample_chains(chain_count) for chain_count in 1:30]
30-element Vector{Chains{Float64, AxisArrays.AxisArray{Float64, 3, Array{Float64, 3}, Tuple{AxisArrays.Axis{:iter, StepRange{Int64, Int64}}, AxisArrays.Axis{:var, Vector{Symbol}}, AxisArrays.Axis{:chain, UnitRange{Int64}}}}, Missing, @NamedTuple{parameters::Vector{Symbol}, internals::Vector{Symbol}}}}:
 MCMC chain (1000×14×1 Array{Float64, 3})
 MCMC chain (1000×14×2 Array{Float64, 3})
 MCMC chain (1000×14×3 Array{Float64, 3})
 MCMC chain (1000×14×4 Array{Float64, 3})
 MCMC chain (1000×14×5 Array{Float64, 3})
 MCMC chain (1000×14×6 Array{Float64, 3})
 MCMC chain (1000×14×7 Array{Float64, 3})
 MCMC chain (1000×14×8 Array{Float64, 3})
 MCMC chain (1000×14×9 Array{Float64, 3})
 MCMC chain (1000×14×10 Array{Float64, 3})
 MCMC chain (1000×14×11 Array{Float64, 3})
 MCMC chain (1000×14×12 Array{Float64, 3})
 MCMC chain (1000×14×13 Array{Float64, 3})
 MCMC chain (1000×14×14 Array{Float64, 3})
 MCMC chain (1000×14×15 Array{Float64, 3})
 MCMC chain (1000×14×16 Array{Float64, 3})
 MCMC chain (1000×14×17 Array{Float64, 3})
 MCMC chain (1000×14×18 Array{Float64, 3})
 MCMC chain (1000×14×19 Array{Float64, 3})
 MCMC chain (1000×14×20 Array{Float64, 3})
 MCMC chain (1000×14×21 Array{Float64, 3})
 MCMC chain (1000×14×22 Array{Float64, 3})
 MCMC chain (1000×14×23 Array{Float64, 3})
 MCMC chain (1000×14×24 Array{Float64, 3})
 MCMC chain (1000×14×25 Array{Float64, 3})
 MCMC chain (1000×14×26 Array{Float64, 3})
 MCMC chain (1000×14×27 Array{Float64, 3})
 MCMC chain (1000×14×28 Array{Float64, 3})
 MCMC chain (1000×14×29 Array{Float64, 3})
 MCMC chain (1000×14×30 Array{Float64, 3})

julia> println("Thread count: $(Threads.nthreads())")
Thread count: 6

julia> for i in eachindex(samples)
           isnothing(samples[i]) && println("BoundsError with $i chains")
       end

@penelopeysm penelopeysm force-pushed the py/fix-multithreaded-chain-indexing branch from 23ba9c0 to d194f86 Compare January 8, 2025 15:15
@TuringLang TuringLang deleted a comment from github-actions bot Jan 8, 2025
@penelopeysm penelopeysm requested a review from mhauru January 8, 2025 15:16
@penelopeysm penelopeysm force-pushed the py/fix-multithreaded-chain-indexing branch from d194f86 to ec621b7 Compare January 8, 2025 15:19
src/sample.jl Outdated Show resolved Hide resolved
Copy link
Member

@mhauru mhauru left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm happy with the code, just wondering about whether this could be tested somehow. End-to-end is hard, but might it make sense to put the chainidxs logic in a function that could be tested with given values for nchains and number of threads?

src/sample.jl Outdated Show resolved Hide resolved
@penelopeysm
Copy link
Member Author

@mhauru the logic itself feels too basic to be worth testing to me 😅. I checked again locally with the most recent commit and it still behaves correctly, so it has been end-to-end tested, albeit not in CI.

Copy link
Member

@mhauru mhauru left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I get paranoid about off-by-one mistakes with all index arithmetic, and testing can also guard against botched future changes. If the person who wrote the previous version would have added tests we wouldn't be here. But I do see that this is very minimal, and also it's not introducing untested functionality, it's rather fixing untested functionality, so I'm okay with merging as-is. Thanks for fixing!

@penelopeysm penelopeysm merged commit 82f02f1 into master Jan 10, 2025
28 checks passed
@penelopeysm penelopeysm deleted the py/fix-multithreaded-chain-indexing branch January 10, 2025 15:18
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

BoundsError when sampling with MCMCThreads
3 participants