Skip to content

Commit

Permalink
relax restriction of x::AbstractVector in threaded_gradient!
Browse files Browse the repository at this point in the history
  • Loading branch information
longemen3000 authored Apr 12, 2024
1 parent a72e010 commit 1c32181
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 1 deletion.
2 changes: 1 addition & 1 deletion src/PolyesterForwardDiff.jl
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ function evaluate_chunks!(f::F, (r,Δx,x), start, stop, ::ForwardDiff.Chunk{C},
end
end

function threaded_gradient!(f::F, Δx::AbstractVector, x::AbstractVector, ::ForwardDiff.Chunk{C}, check = Val{false}()) where {F,C}
function threaded_gradient!(f::F, Δx::AbstractArray, x::AbstractArray, ::ForwardDiff.Chunk{C}, check = Val{false}()) where {F,C}
N = length(x)
d = cld_fast(N, C)
r = Ref{eltype(Δx)}()
Expand Down
9 changes: 9 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -37,3 +37,12 @@ ForwardDiff.jacobian!(dxref, g!, yref, x, ForwardDiff.JacobianConfig(g!, yref, x
PolyesterForwardDiff.threaded_jacobian!(g!, y, dx, x, ForwardDiff.Chunk(8),Val{true}());
@test dx dxref
@test y yref


X = randn(10,80);
dXref = similar(x);
dX = similar(x);
ForwardDiff.gradient!(dXref, f, X, ForwardDiff.GradientConfig(f, x, ForwardDiff.Chunk(8), nothing));
PolyesterForwardDiff.threaded_gradient!(f, dX, X, ForwardDiff.Chunk(8));

@test dX dXref

0 comments on commit 1c32181

Please sign in to comment.