Skip to content

Commit

Permalink
remove assumption that points have to be real in chebeval
Browse files Browse the repository at this point in the history
  • Loading branch information
maltezfaria committed Feb 9, 2024
1 parent 411b93c commit 379bdde
Showing 1 changed file with 2 additions and 50 deletions.
52 changes: 2 additions & 50 deletions src/chebinterp.jl
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ Implementation of [`chebeval`](@ref) using a Clenshaw summation without vectoriz
"""
function chebeval_novec(
coefs,
x::SVector{N,<:Real},
x::SVector{N},
rec::HyperRectangle,
sz::Val{SZ},
) where {N,SZ}
Expand Down Expand Up @@ -217,55 +217,6 @@ statically which allows for various improvements by the compiler (like loop unro
end
end

@inline @fastmath function _evaluate_vec(
x,
c,
::Val{dim},
i1,
len,
sz::Val{SZ},
) where {dim,SZ}
@inbounds n = SZ[dim]
@inbounds xd = x[dim]
if dim == 1
@inbounds c₁ = c[i1]
if n 2
n == 1 && return c₁ + one(xd) * zero(c₁)
return muladd(xd, c[i1+1], c₁)
end
@inbounds bₖ = muladd(2xd, c[i1+(n-1)], c[i1+(n-2)])
@inbounds bₖ₊₁ = oftype(bₖ, c[i1+(n-1)])
for j in n-3:-1:1
@inbounds bⱼ = muladd(2xd, bₖ, c[i1+j]) - bₖ₊₁
bₖ, bₖ₊₁ = bⱼ, bₖ
end
return muladd(xd, bₖ, c₁) - bₖ₊₁
else
Δi = len ÷ n # column-major stride of current dimension

# we recurse downward on dim for cache locality,
# since earlier dimensions are contiguous
dim′ = Val{dim - 1}()

c₁ = _evaluate(x, c, dim′, i1, Δi, sz)
if n 2
n == 1 && return c₁ + one(xd) * zero(c₁)
c₂ = _evaluate(x, c, dim′, i1 + Δi, Δi, sz)
return c₁ + xd * c₂
end
cₙ₋₁ = _evaluate(x, c, dim′, i1 + (n - 2) * Δi, Δi, sz)
cₙ = _evaluate(x, c, dim′, i1 + (n - 1) * Δi, Δi, sz)
bₖ = muladd(2xd, cₙ, cₙ₋₁)
bₖ₊₁ = oftype(bₖ, cₙ)
for j in n-3:-1:1
cⱼ = _evaluate(x, c, dim′, i1 + j * Δi, Δi, sz)
bⱼ = muladd(2xd, bₖ, cⱼ) - bₖ₊₁
bₖ, bₖ₊₁ = bⱼ, bₖ
end
return muladd(xd, bₖ, c₁) - bₖ₊₁
end
end

"""
chebeval_vec(coefs,x,rec[,sz])
Expand All @@ -286,6 +237,7 @@ end
N = length(x)
T = eltype(coeffs)
V = SVector{SZ[1],T}
# V = Vec{SZ[1],T}
n = prod(SZ) ÷ SZ[1]
coeffs_ = reinterpret(V, coeffs)
# coeffs_ = unsafe_wrap(Array, reinterpret(Ptr{V}, pointer(coeffs)), n)
Expand Down

0 comments on commit 379bdde

Please sign in to comment.