Skip to content

Commit

Permalink
Added vmapnt! and vmapntt!
Browse files Browse the repository at this point in the history
  • Loading branch information
chriselrod committed Mar 3, 2020
1 parent d38ea34 commit 0b7aaa5
Show file tree
Hide file tree
Showing 3 changed files with 59 additions and 2 deletions.
2 changes: 1 addition & 1 deletion src/LoopVectorization.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ const SUPPORTED_TYPES = Union{Float16,Float32,Float64,Integer}

export LowDimArray, stridedpointer, vectorizable,
@avx, @_avx, *ˡ, _avx_!,
vmap, vmap!,
vmap, vmap!, vmapnt!, vmapntt!,
vfilter, vfilter!


Expand Down
50 changes: 50 additions & 0 deletions src/map.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,56 @@ function vmap(f::F, args...) where {F}
vmap!(f, dest, args...)
end

"""
    vmapnt!(f, y::AbstractVector, args...)

Evaluate `f` on SIMD vectors loaded from each of `args` and write the results into `y`,
using `vstorent!` (streaming stores, judging by the name) for every full-width vector and
a masked `vstore!` for the final partial vector. Returns `y`.

`pointer(y)` must be aligned to `VectorizationBase.REGISTER_SIZE` bytes.

Only `y` is validated: no length or alignment checks are performed on `args`, which are
accessed through raw pointers at the same offsets as `y`. Callers must ensure each
argument holds at least `length(y)` elements.

# Throws
- `AssertionError`: if `pointer(y)` is not aligned to `VectorizationBase.REGISTER_SIZE`.
"""
function vmapnt!(f::F, y::AbstractVector{T}, args::Vararg{<:Any,A}) where {F,T,A}
    ptry = pointer(y)
    # Explicit throw rather than `@assert`: this guards raw-pointer stores, so it must not
    # be elided if assertions are ever disabled. The exception type is kept for callers.
    reinterpret(UInt, ptry) & (VectorizationBase.REGISTER_SIZE - 1) == 0 ||
        throw(AssertionError("destination must be aligned to VectorizationBase.REGISTER_SIZE"))
    W, _ = VectorizationBase.pick_vector_width_shift(T)
    ptrargs = pointer.(args)
    V = VectorizationBase.pick_vector_width_val(T)
    N = length(y)
    i = 0 # zero-based element offset handed to the pointer loads/stores
    # Main loop, unrolled by four: runs while at least 4W elements remain.
    while i < N - ((W << 2) - 1)
        vstorent!(ptry, extract_data(f(vload.(V, ptrargs, i)...)), i); i += W
        vstorent!(ptry, extract_data(f(vload.(V, ptrargs, i)...)), i); i += W
        vstorent!(ptry, extract_data(f(vload.(V, ptrargs, i)...)), i); i += W
        vstorent!(ptry, extract_data(f(vload.(V, ptrargs, i)...)), i); i += W
    end
    # Up to three remaining full-width vectors.
    while i < N - (W - 1)
        vstorent!(ptry, extract_data(f(vload.(V, ptrargs, i)...)), i); i += W
    end
    # Fewer than W elements left: masked load and ordinary masked store.
    if i < N
        # `N & (W - 1)` is the remainder `N mod W`, assuming W is a power of two.
        m = mask(T, N & (W - 1))
        vstore!(ptry, extract_data(f(vload.(V, ptrargs, i, m)...)), i, m)
    end
    return y
end
"""
    vmapntt!(f, y::AbstractVector, args...)

Threaded counterpart of `vmapnt!`: evaluate `f` on SIMD vectors loaded from each of `args`
and write the results into `y`. Chunks of four vectors are distributed over threads with
`Base.Threads.@threads` and written with `vstorent!` (streaming stores, judging by the
name). The leftover full-width vectors and the final partial vector are handled serially
afterwards. Returns `y`.

`pointer(y)` must be aligned to `VectorizationBase.REGISTER_SIZE` bytes.

Only `y` is validated: no length or alignment checks are performed on `args`, which are
accessed through raw pointers at the same offsets as `y`. Callers must ensure each
argument holds at least `length(y)` elements. `f` is called concurrently from several
threads.

# Throws
- `AssertionError`: if `pointer(y)` is not aligned to `VectorizationBase.REGISTER_SIZE`.
"""
function vmapntt!(f::F, y::AbstractVector{T}, args::Vararg{<:Any,A}) where {F,T,A}
    ptry = pointer(y)
    # Explicit throw rather than `@assert`: this guards raw-pointer stores, so it must not
    # be elided if assertions are ever disabled. The exception type is kept for callers.
    reinterpret(UInt, ptry) & (VectorizationBase.REGISTER_SIZE - 1) == 0 ||
        throw(AssertionError("destination must be aligned to VectorizationBase.REGISTER_SIZE"))
    W, Wshift = VectorizationBase.pick_vector_width_shift(T)
    ptrargs = pointer.(args)
    V = VectorizationBase.pick_vector_width_val(T)
    N = length(y)
    # Each threaded iteration handles four vectors, i.e. `1 << Wsh` elements
    # (assuming W == 1 << Wshift -- TODO confirm against VectorizationBase).
    Wsh = Wshift + 2
    Niter = N >>> Wsh
    Base.Threads.@threads for j in 0:Niter-1
        i = j << Wsh # zero-based element offset of this chunk
        vstorent!(ptry, extract_data(f(vload.(V, ptrargs, i)...)), i); i += W
        vstorent!(ptry, extract_data(f(vload.(V, ptrargs, i)...)), i); i += W
        vstorent!(ptry, extract_data(f(vload.(V, ptrargs, i)...)), i); i += W
        vstorent!(ptry, extract_data(f(vload.(V, ptrargs, i)...)), i)
    end
    # Serial remainder starts right after the last threaded chunk.
    ii = Niter << Wsh
    # Up to three remaining full-width vectors.
    while ii < N - (W - 1)
        vstorent!(ptry, extract_data(f(vload.(V, ptrargs, ii)...)), ii); ii += W
    end
    # Fewer than W elements left: masked load and ordinary masked store.
    if ii < N
        # `N & (W - 1)` is the remainder `N mod W`, assuming W is a power of two.
        m = mask(T, N & (W - 1))
        vstore!(ptry, extract_data(f(vload.(V, ptrargs, ii, m)...)), ii, m)
    end
    return y
end

# @inline vmap!(f, y, x...) = @avx y .= f.(x...)
# @inline vmap(f, x...) = @avx f.(x...)

Expand Down
9 changes: 8 additions & 1 deletion test/map.jl
Original file line number Diff line number Diff line change
@@ -1,16 +1,23 @@
# Checks vmap, vmapnt! and vmapntt! against Base.map, and vmap! against Base.map!.
@testset "map" begin
    @inline foo(x, y) = exp(x) - sin(y)
    N = 3781
    for T ∈ (Float32, Float64)
        @show T, @__LINE__
        a = rand(T, N); b = rand(T, N);
        c1 = map(foo, a, b);
        c2 = vmap(foo, a, b);
        @test c1 ≈ c2
        # Pre-fill with NaN so elements the call fails to write make the comparison fail.
        fill!(c2, NaN); vmapnt!(foo, c2, a, b);
        @test c1 ≈ c2
        fill!(c2, NaN); vmapntt!(foo, c2, a, b);
        @test c1 ≈ c2
        # Views starting at index 2 shift the destination pointer by one element, which
        # is expected to trip the alignment check. `foo` is passed explicitly so the
        # views land in the destination/argument positions.
        @test_throws AssertionError @views vmapnt!(foo, c2[2:end], a[2:end], b[2:end])
        @test_throws AssertionError @views vmapntt!(foo, c2[2:end], a[2:end], b[2:end])

        c = rand(T, 100); x = rand(T, 10^4); y1 = similar(x); y2 = similar(x);
        map!(xᵢ -> clenshaw(xᵢ, c), y1, x)
        vmap!(xᵢ -> clenshaw(xᵢ, c), y2, x)
        @test y1 ≈ y2

    end
end

0 comments on commit 0b7aaa5

Please sign in to comment.