Skip to content

Commit

Permalink
zeros/ones/fill may accept arbitrary axes that are supported by…
Browse files Browse the repository at this point in the history
… `similar` (#53965)

The idea is that functions like `zeros` are essentially constructing a
container and filling it with a value. `similar` seems perfectly placed
to construct such a container, so we may accept arbitrary axes in
`zeros` as long as there's a corresponding `similar` method that is
defined for the axes. Packages therefore would only need to define
`similar`, and would get `zeros`/`ones` and `fill` for free. For
example, the following will work after this:
```julia
julia> using StaticArrays

julia> zeros(SOneTo(2), 2)
2×2 Matrix{Float64}:
 0.0  0.0
 0.0  0.0

julia> zeros(SOneTo(2), Base.OneTo(2))
2×2 Matrix{Float64}:
 0.0  0.0
 0.0  0.0
```
Neither of these work on the current master, as `StaticArrays` doesn't
define `zeros` for these combinations, even though it does define
`similar`. One may argue for these methods to be added to
`StaticArrays`, but this seems to be adding redundancy.

The flip side is that `OffsetArrays` defines exactly these methods, so
adding them to `Base` would break precompilation for the package.
However, `OffsetArrays` really shouldn't be defining these methods, as
this is type-piracy. The methods may be version-limited in
`OffsetArrays` if this PR is merged.

On the face of it, `trues` and `falses` should also work similarly, but
currently these seem to be bypassing `similar` and constructing a
`BitArray` explicitly. I have not added the corresponding methods for
these functions, but they may be added as well.
  • Loading branch information
jishnub authored Apr 14, 2024
1 parent b9aeafa commit 8d577ab
Show file tree
Hide file tree
Showing 5 changed files with 67 additions and 0 deletions.
6 changes: 6 additions & 0 deletions base/array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -529,6 +529,7 @@ function fill end
fill(v, dims::DimOrInd...) = fill(v, dims)
fill(v, dims::NTuple{N, Union{Integer, OneTo}}) where {N} = fill(v, map(to_dim, dims))
fill(v, dims::NTuple{N, Integer}) where {N} = (a=Array{typeof(v),N}(undef, dims); fill!(a, v); a)
fill(v, dims::NTuple{N, DimOrInd}) where {N} = (a=similar(Array{typeof(v),N}, dims); fill!(a, v); a)
fill(v, dims::Tuple{}) = (a=Array{typeof(v),0}(undef, dims); fill!(a, v); a)

"""
Expand Down Expand Up @@ -589,6 +590,11 @@ for (fname, felt) in ((:zeros, :zero), (:ones, :one))
fill!(a, $felt(T))
return a
end
function $fname(::Type{T}, dims::NTuple{N, DimOrInd}) where {T,N}
a = similar(Array{T,N}, dims)
fill!(a, $felt(T))
return a
end
end
end

Expand Down
2 changes: 2 additions & 0 deletions base/bitarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -404,6 +404,7 @@ falses(dims::DimOrInd...) = falses(dims)
falses(dims::NTuple{N, Union{Integer, OneTo}}) where {N} = falses(map(to_dim, dims))
falses(dims::NTuple{N, Integer}) where {N} = fill!(BitArray(undef, dims), false)
falses(dims::Tuple{}) = fill!(BitArray(undef, dims), false)
falses(dims::NTuple{N, DimOrInd}) where {N} = fill!(similar(BitArray, dims), false)

"""
trues(dims)
Expand All @@ -422,6 +423,7 @@ trues(dims::DimOrInd...) = trues(dims)
trues(dims::NTuple{N, Union{Integer, OneTo}}) where {N} = trues(map(to_dim, dims))
trues(dims::NTuple{N, Integer}) where {N} = fill!(BitArray(undef, dims), true)
trues(dims::Tuple{}) = fill!(BitArray(undef, dims), true)
trues(dims::NTuple{N, DimOrInd}) where {N} = fill!(similar(BitArray, dims), true)

function one(x::BitMatrix)
m, n = size(x)
Expand Down
22 changes: 22 additions & 0 deletions test/abstractarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@ using .Main.StructArrays
isdefined(Main, :FillArrays) || @eval Main include("testhelpers/FillArrays.jl")
using .Main.FillArrays

isdefined(Main, :SizedArrays) || @eval Main include("testhelpers/SizedArrays.jl")
using .Main.SizedArrays

A = rand(5,4,3)
@testset "Bounds checking" begin
@test checkbounds(Bool, A, 1, 1, 1) == true
Expand Down Expand Up @@ -2097,3 +2100,22 @@ end
@test r2[i] == z[j]
end
end

@testset "zero for arbitrary axes" begin
r = SizedArrays.SOneTo(2)
s = Base.OneTo(2)
_to_oneto(x::Integer) = Base.OneTo(2)
_to_oneto(x::Union{Base.OneTo, SizedArrays.SOneTo}) = x
for (f, v) in ((zeros, 0), (ones, 1), ((x...)->fill(3,x...),3))
for ax in ((r,r), (s, r), (2, r))
A = f(ax...)
@test axes(A) == map(_to_oneto, ax)
if all(x -> x isa SizedArrays.SOneTo, ax)
@test A isa SizedArrays.SizedArray && parent(A) isa Array
else
@test A isa Array
end
@test all(==(v), A)
end
end
end
22 changes: 22 additions & 0 deletions test/bitarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@
using Base: findprevnot, findnextnot
using Random, LinearAlgebra, Test

isdefined(Main, :SizedArrays) || @eval Main include("testhelpers/SizedArrays.jl")
using .Main.SizedArrays

tc(r1::NTuple{N,Any}, r2::NTuple{N,Any}) where {N} = all(x->tc(x...), [zip(r1,r2)...])
tc(r1::BitArray{N}, r2::Union{BitArray{N},Array{Bool,N}}) where {N} = true
tc(r1::SubArray{Bool,N1,BitArray{N2}}, r2::SubArray{Bool,N1,<:Union{BitArray{N2},Array{Bool,N2}}}) where {N1,N2} = true
Expand Down Expand Up @@ -82,6 +85,25 @@ allsizes = [((), BitArray{0}), ((v1,), BitVector),
@test !isassigned(b, length(b) + 1)
end

@testset "trues and falses with custom axes" begin
for ax in ((SizedArrays.SOneTo(2),), (SizedArrays.SOneTo(2), Base.OneTo(2)))
t = trues(ax)
if all(x -> x isa SizedArrays.SOneTo, ax)
@test t isa SizedArrays.SizedArray && parent(t) isa BitArray
else
@test t isa BitArray
end
@test all(t)

f = falses(ax)
if all(x -> x isa SizedArrays.SOneTo, ax)
@test t isa SizedArrays.SizedArray && parent(t) isa BitArray
else
@test t isa BitArray
end
@test !any(f)
end
end

@testset "Conversions for size $sz" for (sz, T) in allsizes
b1 = rand!(falses(sz...))
Expand Down
15 changes: 15 additions & 0 deletions test/testhelpers/SizedArrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -43,10 +43,25 @@ Base.size(a::SizedArray) = size(typeof(a))
Base.size(::Type{<:SizedArray{SZ}}) where {SZ} = SZ
Base.axes(a::SizedArray) = map(SOneTo, size(a))
Base.getindex(A::SizedArray, i...) = getindex(A.data, i...)
Base.setindex!(A::SizedArray, v, i...) = setindex!(A.data, v, i...)
Base.zero(::Type{T}) where T <: SizedArray = SizedArray{size(T)}(zeros(eltype(T), size(T)))
Base.parent(S::SizedArray) = S.data
+(S1::SizedArray{SZ}, S2::SizedArray{SZ}) where {SZ} = SizedArray{SZ}(S1.data + S2.data)
==(S1::SizedArray{SZ}, S2::SizedArray{SZ}) where {SZ} = S1.data == S2.data

homogenize_shape(t::Tuple) = (_homogenize_shape(first(t)), homogenize_shape(Base.tail(t))...)
homogenize_shape(::Tuple{}) = ()
_homogenize_shape(x::Integer) = x
_homogenize_shape(x::AbstractUnitRange) = length(x)
const Dims = Union{Integer, Base.OneTo, SOneTo}
function Base.similar(::Type{A}, shape::Tuple{Dims, Vararg{Dims}}) where {A<:AbstractArray}
similar(A, homogenize_shape(shape))
end
function Base.similar(::Type{A}, shape::Tuple{SOneTo, Vararg{SOneTo}}) where {A<:AbstractArray}
R = similar(A, length.(shape))
SizedArray{length.(shape)}(R)
end

const SizedMatrixLike = Union{SizedMatrix, Transpose{<:Any, <:SizedMatrix}, Adjoint{<:Any, <:SizedMatrix}}

_data(S::SizedArray) = S.data
Expand Down

0 comments on commit 8d577ab

Please sign in to comment.