Skip to content

Commit

Permalink
compute Matrix length from dims
Browse files Browse the repository at this point in the history
Since changing Array to use Memory as the backing, we had the option of
making non-Vector arrays more flexible, but had instead preserved the
restriction that they must be zero offset and equal in length to the
Memory. This results in extra complexity, restrictions, and allocations
however, but doesn't gain any known benefits. This PR aims to test if
nanosoldier detects any benefit, or whether this restriction has
outlived its usefulness.
  • Loading branch information
vtjnash committed Oct 18, 2024
1 parent ca3713e commit c578faa
Show file tree
Hide file tree
Showing 4 changed files with 19 additions and 53 deletions.
4 changes: 0 additions & 4 deletions base/array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3130,10 +3130,6 @@ function _wrap(ref::MemoryRef{T}, dims::NTuple{N, Int}) where {T, N}
mem_len = length(mem) + 1 - memoryrefoffset(ref)
len = Core.checked_dims(dims...)
@boundscheck mem_len >= len || invalid_wrap_err(mem_len, dims, len)
if N != 1 && !(ref === GenericMemoryRef(mem) && len === mem_len)
mem = ccall(:jl_genericmemory_slice, Memory{T}, (Any, Ptr{Cvoid}, Int), mem, ref.ptr_or_offset, len)
ref = memoryref(mem)
end
return ref
end

Expand Down
3 changes: 2 additions & 1 deletion base/essentials.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@ const Bottom = Union{}
# Define minimal array interface here to help code used in macros:
length(a::Array{T, 0}) where {T} = 1
length(a::Array{T, 1}) where {T} = getfield(a, :size)[1]
length(a::Array) = getfield(getfield(getfield(a, :ref), :mem), :length)
length(a::Array{T, 2}) where {T} = (sz = getfield(a, :size); sz[1] * sz[2])
# other sizes are handled by generic prod definition for AbstractArray
length(a::GenericMemory) = getfield(a, :length)
throw_boundserror(A, I) = (@noinline; throw(BoundsError(A, I)))

Expand Down
32 changes: 17 additions & 15 deletions base/reshapedarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,31 +35,33 @@ end
length(R::ReshapedArrayIterator) = length(R.iter)
eltype(::Type{<:ReshapedArrayIterator{I}}) where {I} = @isdefined(I) ? ReshapedIndex{eltype(I)} : Any

## reshape(::Array, ::Dims) returns an Array, except for isbitsunion eltypes (issue #28611)
@noinline throw_dmrsa(dims, len) =
throw(DimensionMismatch("new dimensions $(dims) must be consistent with array length $len"))

## reshape(::Array, ::Dims) returns a new Array (to avoid conditionally aliasing the structure, only the data)
# reshaping to same # of dimensions
@eval function reshape(a::Array{T,M}, dims::NTuple{N,Int}) where {T,N,M}
throw_dmrsa(dims, len) =
throw(DimensionMismatch("new dimensions $(dims) must be consistent with array length $len"))
len = Core.checked_dims(dims...) # make sure prod(dims) doesn't overflow (and because of the comparison to length(a))
if len != length(a)
throw_dmrsa(dims, length(a))
end
isbitsunion(T) && return ReshapedArray(a, dims, ())
if N == M && dims == size(a)
return a
end
ref = a.ref
if M == 1 && N !== 1
mem = ref.mem::Memory{T}
if !(ref === memoryref(mem) && len === mem.length)
mem = ccall(:jl_genericmemory_slice, Memory{T}, (Any, Ptr{Cvoid}, Int), mem, ref.ptr_or_offset, len)
ref = memoryref(mem)::typeof(ref)
end
end
# or we could use `a = Array{T,N}(undef, ntuple(0, Val(N))); a.ref = ref; a.size = dims; return a` here
# or we could use `a = Array{T,N}(undef, ntuple(i->0, Val(N))); a.ref = ref; a.size = dims; return a` here to avoid the eval
return $(Expr(:new, :(Array{T,N}), :ref, :dims))
end

## reshape!(::Array, ::Dims) returns the original array, but must have the same dimensions and length as the original
# see also resize! for a similar operation that can change the length
function reshape!(a::Array{T,N}, dims::NTuple{N,Int}) where {T,N}
len = Core.checked_dims(dims...) # make sure prod(dims) doesn't overflow (and because of the comparison to length(a))
if len != length(a)
throw_dmrsa(dims, length(a))
end
setfield!(a, :dims, dims)
return a
end



"""
reshape(A, dims...) -> AbstractArray
Expand Down
33 changes: 0 additions & 33 deletions src/genericmemory.c
Original file line number Diff line number Diff line change
Expand Up @@ -221,39 +221,6 @@ JL_DLLEXPORT jl_genericmemory_t *jl_alloc_memory_any(size_t n)
return jl_alloc_genericmemory(jl_memory_any_type, n);
}

JL_DLLEXPORT jl_genericmemory_t *jl_genericmemory_slice(jl_genericmemory_t *mem, void *data, size_t len)
{
// Given a GenericMemoryRef represented as `jl_genericmemory_ref ref = {data, mem}`,
// return a new GenericMemory that only accesses the slice from the given GenericMemoryRef to
// the given length if this is possible to return. This allows us to make
// `length(Array)==length(Array.ref.mem)`, for simplification of this.
jl_datatype_t *dt = (jl_datatype_t*)jl_typetagof(mem);
const jl_datatype_layout_t *layout = dt->layout;
// repeated checks here ensure the values cannot overflow, since we know mem->length is a reasonable value
if (len > mem->length)
jl_exceptionf(jl_argumenterror_type, "invalid GenericMemory slice"); // TODO: make a BoundsError
if (layout->flags.arrayelem_isunion) {
if (!((size_t)data == 0 && mem->length == len))
jl_exceptionf(jl_argumenterror_type, "invalid GenericMemory slice"); // only exact slices are supported
data = mem->ptr;
}
else if (layout->size == 0) {
if ((size_t)data > mem->length || (size_t)data + len > mem->length)
jl_exceptionf(jl_argumenterror_type, "invalid GenericMemory slice"); // TODO: make a BoundsError
data = mem->ptr;
}
else {
if (data < mem->ptr || (char*)data > (char*)mem->ptr + mem->length * layout->size || (char*)data + len * layout->size > (char*)mem->ptr + mem->length * layout->size)
jl_exceptionf(jl_argumenterror_type, "invalid GenericMemory slice"); // TODO: make a BoundsError
}
jl_task_t *ct = jl_current_task;
jl_genericmemory_t *newmem = (jl_genericmemory_t*)jl_gc_alloc(ct->ptls, sizeof(jl_genericmemory_t) + sizeof(void*), dt);
newmem->length = len;
newmem->ptr = data;
jl_genericmemory_data_owner_field(newmem) = jl_genericmemory_owner(mem);
return newmem;
}

JL_DLLEXPORT void jl_genericmemory_copyto(jl_genericmemory_t *dest, char* destdata,
jl_genericmemory_t *src, char* srcdata,
size_t n) JL_NOTSAFEPOINT
Expand Down

0 comments on commit c578faa

Please sign in to comment.