From 820c08b896e408d9ed1e064ada8f9138dd6a3a6b Mon Sep 17 00:00:00 2001 From: Jeff Bezanson Date: Mon, 18 Jul 2022 15:02:46 -0400 Subject: [PATCH] fix #45825, BitArray methods assuming 1-indexing of AbstractArray (#45835) --- base/abstractarray.jl | 10 +++++--- base/bitarray.jl | 53 +++++++++++++++++++++++-------------------- test/bitarray.jl | 35 ++++++++++++++++++++++++++++ 3 files changed, 70 insertions(+), 28 deletions(-) diff --git a/base/abstractarray.jl b/base/abstractarray.jl index b5b74bd8446c0..e97359cb87fcf 100644 --- a/base/abstractarray.jl +++ b/base/abstractarray.jl @@ -1035,6 +1035,10 @@ julia> y """ function copyto!(dest::AbstractArray, src::AbstractArray) isempty(src) && return dest + if dest isa BitArray + # avoid ambiguities with other copyto!(::AbstractArray, ::SourceArray) methods + return _copyto_bitarray!(dest, src) + end src′ = unalias(dest, src) copyto_unaliased!(IndexStyle(dest), dest, IndexStyle(src′), src′) end @@ -1139,10 +1143,10 @@ function copyto!(B::AbstractVecOrMat{R}, ir_dest::AbstractRange{Int}, jr_dest::A return B end -function copyto_axcheck!(dest, src) - @noinline checkaxs(axd, axs) = axd == axs || throw(DimensionMismatch("axes must agree, got $axd and $axs")) +@noinline _checkaxs(axd, axs) = axd == axs || throw(DimensionMismatch("axes must agree, got $axd and $axs")) - checkaxs(axes(dest), axes(src)) +function copyto_axcheck!(dest, src) + _checkaxs(axes(dest), axes(src)) copyto!(dest, src) end diff --git a/base/bitarray.jl b/base/bitarray.jl index 73f274df44a85..71d83b5b58f56 100644 --- a/base/bitarray.jl +++ b/base/bitarray.jl @@ -501,40 +501,42 @@ function Array{T,N}(B::BitArray{N}) where {T,N} end BitArray(A::AbstractArray{<:Any,N}) where {N} = BitArray{N}(A) + function BitArray{N}(A::AbstractArray{T,N}) where N where T B = BitArray(undef, convert(Dims{N}, size(A)::Dims{N})) - Bc = B.chunks - l = length(B) + _checkaxs(axes(B), axes(A)) + _copyto_bitarray!(B, A) + return B::BitArray{N} +end + +function _copyto_bitarray!(B::BitArray, A::AbstractArray) + l = length(A) l == 0 && return B - ind = 1 + l > length(B) && throw(BoundsError(B, length(B)+1)) + Bc = B.chunks + nc = num_bit_chunks(l) + Ai = first(eachindex(A)) @inbounds begin - for i = 1:length(Bc)-1 + for i = 1:nc-1 c = UInt64(0) for j = 0:63 - c |= (UInt64(convert(Bool, A[ind])::Bool) << j) - ind += 1 + c |= (UInt64(convert(Bool, A[Ai])::Bool) << j) + Ai = nextind(A, Ai) end Bc[i] = c end c = UInt64(0) - for j = 0:_mod64(l-1) - c |= (UInt64(convert(Bool, A[ind])::Bool) << j) - ind += 1 + tail = _mod64(l - 1) + 1 + for j = 0:tail-1 + c |= (UInt64(convert(Bool, A[Ai])::Bool) << j) + Ai = nextind(A, Ai) end - Bc[end] = c + msk = _msk_end(tail) + Bc[nc] = (c & msk) | (Bc[nc] & ~msk) end return B end -function BitArray{N}(A::Array{Bool,N}) where N - B = BitArray(undef, size(A)) - Bc = B.chunks - l = length(B) - l == 0 && return B - copy_to_bitarray_chunks!(Bc, 1, A, 1, l) - return B::BitArray{N} -end - reinterpret(::Type{Bool}, B::BitArray, dims::NTuple{N,Int}) where {N} = reinterpret(B, dims) reinterpret(B::BitArray, dims::NTuple{N,Int}) where {N} = reshape(B, dims) @@ -721,24 +723,25 @@ function _unsafe_setindex!(B::BitArray, X::AbstractArray, I::BitArray) lx = length(X) last_chunk_len = _mod64(length(B)-1)+1 - c = 1 + Xi = first(eachindex(X)) + lastXi = last(eachindex(X)) for i = 1:lc @inbounds Imsk = Ic[i] @inbounds C = Bc[i] u = UInt64(1) for j = 1:(i < lc ? 64 : last_chunk_len) if Imsk & u != 0 - lx < c && throw_setindex_mismatch(X, c) - @inbounds x = convert(Bool, X[c]) + Xi > lastXi && throw_setindex_mismatch(X, count(I)) + @inbounds x = convert(Bool, X[Xi]) C = ifelse(x, C | u, C & ~u) - c += 1 + Xi = nextind(X, Xi) end u <<= 1 end @inbounds Bc[i] = C end - if length(X) != c-1 - throw_setindex_mismatch(X, c-1) + if Xi != nextind(X, lastXi) + throw_setindex_mismatch(X, count(I)) end return B end diff --git a/test/bitarray.jl b/test/bitarray.jl index d17a9856596a4..3a528d4391a82 100644 --- a/test/bitarray.jl +++ b/test/bitarray.jl @@ -1787,3 +1787,38 @@ end @test all(bitarray[rangein, rangeout] .== true) end end + +# issue #45825 + +isdefined(Main, :OffsetArrays) || @eval Main include("testhelpers/OffsetArrays.jl") +using .Main.OffsetArrays + +let all_false = OffsetArray(falses(2001), -1000:1000) + @test !any(==(true), all_false) + # should be run with --check-bounds=yes + @test_throws DimensionMismatch BitArray(all_false) + all_false = OffsetArray(falses(2001), 1:2001) + @test !any(==(true), BitArray(all_false)) + all_false = OffsetArray(falses(100, 100), 0:99, -1:98) + @test !any(==(true), all_false) + @test_throws DimensionMismatch BitArray(all_false) + all_false = OffsetArray(falses(100, 100), 1:100, 1:100) + @test !any(==(true), all_false) +end +let a = falses(1000), + msk = BitArray(rand(Bool, 1000)), + n = count(msk), + b = OffsetArray(rand(Bool, n), (-n÷2):(n÷2)-iseven(n)) + a[msk] = b + @test a[msk] == collect(b) + a = falses(100, 100) + msk = BitArray(rand(Bool, 100, 100)) + n = count(msk) + b = OffsetArray(rand(Bool, 1, n), 1:1, (-n÷2):(n÷2)-iseven(n)) + a[msk] = b + @test a[msk] == vec(collect(b)) +end +let b = trues(10) + copyto!(b, view([0,0,0], :)) + @test b == [0,0,0,1,1,1,1,1,1,1] +end