Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

migrate simplified dimnames #411

Open
wants to merge 5 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/make.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ makedocs(;
"sparsearrays.md",
"tuples.md",
"wrapping.md",
"dimnames.md",
]
)

Expand Down
9 changes: 9 additions & 0 deletions docs/src/dimnames.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
# Named Dimensions Interface

The following functions provide a common interface for interacting with named dimensions.

```@docs
ArrayInterface.has_dimnames
ArrayInterface.dimnames
ArrayInterface.to_dims
```
2 changes: 1 addition & 1 deletion docs/src/indexing.md
Original file line number Diff line number Diff line change
Expand Up @@ -46,4 +46,4 @@ and index translations.
ArrayInterface.ArrayIndex
ArrayInterface.GetIndex
ArrayInterface.SetIndex!
```
```
131 changes: 131 additions & 0 deletions src/ArrayInterface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,14 @@ else
end
end
end

@assume_effects :total function _find_first_egal(v::T, vals::NTuple{N, T}) where {N, T}
for i in 1:N
getfield(vals, i, false) === v && return i
end
return 0
end

@assume_effects :total __parameterless_type(T)=Base.typename(T).wrapper
parameterless_type(x) = parameterless_type(typeof(x))
parameterless_type(x::Type) = __parameterless_type(x)
Expand Down Expand Up @@ -1030,6 +1038,129 @@ ensures_sorted(@nospecialize( T::Type{<:AbstractRange})) = true
ensures_sorted(T::Type) = is_forwarding_wrapper(T) ? ensures_sorted(parent_type(T)) : false
ensures_sorted(@nospecialize(x)) = ensures_sorted(typeof(x))

DIMNAMES_EXTENDED_HELP = """
## Extended help

Structures that explicitly provide named dimensions must define both `has_dimnames` and
`dimnames`. Wrappers that don't change the layout of their parent data and define
`is_forwarding_wrapper` will propagate these methods freely. All other wrappers must
define `has_dimnames` and `dimnames`. For example:

```julia
function ArrayInterface.has_dimnames(T::Type{<:Wrapper})
has_dimnames(ArrayInterface.parent_type(T))
end

function ArrayInterface.dimnames(x::Wrapper)
if has_dimnames(x)
# appropriately modify wrapped dimension names to reflect changes lazy changes
# in the parent data layout
modify_wrapped_dimnames(dimnames(parent(x)))::NTuple{ndims(x), Symbol}
else # need to return "blank" dimension name :_ when names aren't defined
ntuple(_ -> :_, ndims(x))
end
end
```

In some cases `Wrapper` may modify some aspect of its parent data's layout that has no
impact on the dimension names (e.g., mapping offset indices to a parent array). In such
cases there may be no need to modify dimension names and simply defining
`ArrayInterface.dimnames(x::Wrapper) = dimnames(parent(x))` may be sufficient.

Since the utlity of dimension names is highly specific to the domain they are used in,
there are very few explicit guidelines how they should be modified by wrappers. The most
important guideline is that `dimnames(x)` returns an instance of type
`NTuple{ndims(x), Symbol}`.
"""

"""
has_dimnames(T::Type) -> Bool

Returns `true` if instances of `T` have named dimensions. Structures overloading this
method are also responsible for defining [`ArrayInterface.dimnames`](@ref).

See also: [`ArrayInterface.to_dims`](@ref)

$(DIMNAMES_EXTENDED_HELP)
"""
has_dimnames(T::Type) = is_forwarding_wrapper(T) ? has_dimnames(parent_type(T)) : false

"""
dimnames(x) -> NTuple{ndims(x), Symbol}
dimnames(x, dim::Integer) -> Symbol
dimnames(x, dim::Tuple{Vararg{Integer, N}}) -> NTuple{N, Symbol}

Return the names of the dimensions for `x`. `:_` is used to indicate a dimension does not
have a name. Structures overloading this method are also responsible for defining
[`ArrayInterface.has_dimnames`](@ref).

See also: [`ArrayInterface.to_dims`](@ref)

$(DIMNAMES_EXTENDED_HELP)
"""
@inline function dimnames(x::X) where {X}
if is_forwarding_wrapper(X)
return dimnames(buffer(x))
elseif isa(Base.IteratorSize(X), Base.HasShape)
return ntuple(_ -> :_, Val(ndims(X)))
else
return (:_,)
end
end
@inline function dimnames(x::X, dim::Tuple{Vararg{Integer, N}}) where {X, N}
has_dimnames(X) || return ntuple(_ -> :_, Val{N}())
dnames = dimnames(x)
nd = nfields(dnames)
ntuple(Val{N}()) do i
dim_i = Int(getfield(dim, i))
in(dim_i, 1:nd) ? getfield(dnames, dim_i, false) : :_
end
end
@inline function dimnames(x::X, dim::Integer) where {X}
if dim in 1:(isa(Base.IteratorSize(X), Base.HasShape) ? ndims(X) : 1)
return getfield(dimnames(x), Int(dim), false) # already know is inbounds
else # trailing dim is unnamed
return :_
end
end

@noinline function _throw_dimname(s::Symbol)
throw(DimensionMismatch("dimension name $(s) not found"))
end

"""
to_dims(x, d::Integer) -> Int
to_dims(x, d::Symbol) -> Int
to_dims(x, d::NTuple{N}) -> NTuple{N, Int}

Return the dimension(s) of `x` corresponding to `d`. Symbols are converted to dimensions
by searching through dimension names (see [`dimnames`](@ref)). Integers may be converted
to `Int` but are otherwise returned as is.

"""
to_dims(x, dim::Colon) = dim
to_dims(x, dim::Integer) = Int(dim)
function to_dims(x::X, s::Symbol) where {X}
dim = _find_first_egal(s, dimnames(x))
dim === 0 && _throw_dimname(s)
return dim
end
to_dims(x, dims::Tuple{Vararg{Int}}) = dims
function to_dims(x::X, dims::Tuple{Vararg{Union{Symbol, Integer}, N}}) where {X, N}
dnames = dimnames(x)
ntuple(Val{N}()) do i
dim = getfield(dims, i, false)
if dim isa Symbol
dim_i = _find_first_egal(dim, dnames)
dim_i === 0 && _throw_dimname(dim)
dim_i
else
dim_i = to_dims(x, dim)
end
dim_i
end
end

## Extensions

import Requires
Expand Down
44 changes: 43 additions & 1 deletion test/core.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,24 @@ using Random
using SparseArrays
using Test

struct NamedDimsWrapper{D,T,N,P<:AbstractArray{T,N}} <: AbstractArray{T,N}
parent::P

NamedDimsWrapper{D}(p::P) where {D,P} = new{D,eltype(P),ndims(p),P}(p)
end

ArrayInterface.has_dimnames(T::Type{<:NamedDimsWrapper}) = true
ArrayInterface.is_forwarding_wrapper(::Type{<:NamedDimsWrapper}) = true
ArrayInterface.parent_type(::Type{T}) where {P,T<:NamedDimsWrapper{<:Any,<:Any,<:Any,P}} = P
ArrayInterface.dimnames(::NamedDimsWrapper{D}) where {D} = D
Base.parent(x::NamedDimsWrapper) = getfield(x, :parent)
Base.size(x::NamedDimsWrapper) = size(parent(x))
Base.IndexStyle(T::Type{<:NamedDimsWrapper}) = IndexStyle(parent_type(T))
Base.@propagate_inbounds Base.getindex(x::NamedDimsWrapper, inds...) = parent(x)[inds...]
Base.@propagate_inbounds function Base.setindex!(x::NamedDimsWrapper, v, inds...)
setindex!(parent(x), v, inds...)
end

# ensure we are correctly parsing these
ArrayInterface.@assume_effects :total foo(x::Bool) = x
ArrayInterface.@assume_effects bar(x::Bool) = x
Expand Down Expand Up @@ -282,4 +300,28 @@ end
end
@test ArrayInterface.ldlt_instance(SymTridiagonal(A' * A)) isa typeof(ldlt(SymTridiagonal(A' * A)))
end
end
end

@testset "dimnames interface" begin
a = zeros(3, 4, 5);
nda = NamedDimsWrapper{(:x, :y, :z)}(a)

@test !@inferred(ArrayInterface.has_dimnames(typeof(a)))
@test @inferred(ArrayInterface.has_dimnames(typeof(nda)))

@test @inferred(ArrayInterface.dimnames(a)) === (:_, :_, :_)
@test @inferred(ArrayInterface.dimnames(nda)) === (:x, :y, :z)
@test @inferred(ArrayInterface.dimnames(nda, 1)) === :x
@test @inferred(ArrayInterface.dimnames(nda, (1, 2))) === (:x, :y)
@test @inferred(ArrayInterface.dimnames((1,))) === (:_,)

@test @inferred(ArrayInterface.to_dims(nda, (:))) === Colon()
@test @inferred(ArrayInterface.to_dims(nda, 1)) === 1
@test @inferred(ArrayInterface.to_dims(nda, :x)) === 1
@test @inferred(ArrayInterface.to_dims(nda, (1, 2))) === (1, 2)
@test @inferred(ArrayInterface.to_dims(nda, (:x, :y))) === (1, 2)
@test @inferred(ArrayInterface.to_dims(nda, (:y, :x))) === (2, 1)
@test @inferred(ArrayInterface.to_dims(nda, (:y, 1))) === (2, 1)

@test_throws DimensionMismatch ArrayInterface.to_dims(a, :x)
end