
Commit

wip #106
denizyuret committed Dec 25, 2018
1 parent 09d4dfd · commit b067fd3
Showing 4 changed files with 28 additions and 3 deletions.
5 changes: 4 additions & 1 deletion src/core.jl
@@ -22,7 +22,7 @@ mutable struct Result{T} <: Tracked{T}
     args
     kwargs
     Result{T}(v,f,a,k) where {T} = new(v,f,a,k)
-    Result{T}(v,f,a,k) where {T<:Value} = error("Result cannot take $T as arg.")
+    # Result{T}(v,f,a,k) where {T<:Value} = error("Result cannot take $T as arg.") # See #106
 end
 
 # value() should give a regular (non-Value) result regardless of recursion
@@ -132,6 +132,9 @@ function track(f, args, kwargs, bcasted)
     end
 end
 
+Result(v::T, f, args, kwargs) where {T<:Tracked} = v
+Result(v::T, f, args, kwargs) where {T<:Bcasted} = error("Result cannot take $T as arg")
+
 function Result(v::T, f, args, kwargs) where {T}
     record!(t::Tape, v::Tracked) = (n = get(t.dict, v, nothing); n === nothing ? record!(t, Node(v)) : n)
     record!(t::Tape, n::Node) = (t.dict[n.Value] = n; pushfirst!(t.list, n); n)
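Note on the src/core.jl change: the inner-constructor guard that rejected any Value argument is replaced by outer-constructor dispatch. A value that is already Tracked is now passed through unchanged (avoiding a "double Result" wrap), only a Bcasted argument errors, and plain values fall through to the recording constructor. A minimal standalone sketch of this pass-through dispatch pattern, with toy type names that are not AutoGrad's actual types:

    # Toy sketch of pass-through dispatch (hypothetical names).
    abstract type ToyTracked end

    struct ToyResult <: ToyTracked
        value
    end

    wrap(v) = ToyResult(v)      # generic case: wrap a plain value
    wrap(v::ToyTracked) = v     # already tracked: return as-is, no double wrap

    r = wrap(1.0)               # ToyResult(1.0)
    wrap(r) === r               # true: wrapping twice is a no-op

The most specific method wins in Julia's dispatch, so the pass-through case shadows the generic one whenever the argument is already in the tracked hierarchy.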
22 changes: 22 additions & 0 deletions test/core.jl
@@ -46,4 +46,26 @@ using Statistics
     # Double broadcasting
     x = Param([1.,2.]); f3(x)=sin(x); f4(x)=sin.(x)
     @test grad((@diff sum(f3.(x))), x) == grad((@diff sum(f4.(x))), x) == grad((@diff sum(f4(x))), x)
+
+    # Issue #106: Double Result
+    h(x) = exp(-x); h′(x,y) = -y
+    𝓁(x,y) = sum(abs2,x-y)/2
+    function neural_net(mparams, input; h=h, h′=h′, N=length(mparams))
+        δ = [];
+        X = Any[input];
+        for i=1:N
+            x = sum(mparams[i] .* [X[i],1])
+            y = h.(x)
+            push!(δ, h′.(x,y))
+            push!(X,y)
+        end
+        return X,δ
+    end
+    mparams = [[randn(),randn()] for i=1:3]
+    P = Param(mparams)
+    loss(P,x,y) = 𝓁(neural_net(P,x)[1][end],y)
+    x,y = randn(),randn()
+    J = @diff loss(P,x,y)
+    @test isa(J, AutoGrad.Tape)
+    @test_broken @gcheck loss(P,x,y)
 end
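The new regression test drives the user-facing API end to end: Param marks differentiable values, @diff records the computation on a Tape, and grad reads gradients off that tape. For reference, the basic pattern per AutoGrad's documented usage:

    using AutoGrad

    w = Param([1.0, 2.0, 3.0])    # mark w as a differentiable parameter
    J = @diff sum(w .* w)         # run the computation, recording a Tape
    grad(J, w)                    # => [2.0, 4.0, 6.0]

The test exercises the double-Result path because neural_net mixes broadcast calls (h.(x)) over values that may already be tracked, which previously tripped the removed constructor error.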
2 changes: 1 addition & 1 deletion test/gradcheck.jl
@@ -151,7 +151,7 @@ function gcheck(f, x...; kw=(), nsample=10, verbose=1, rtol=0.05, atol=0.01, del
     y = @diff gcsum(f, x...; kw...)
     if !isa(y, Tape); @warn("Output independent of params"); return true; end
     f0 = value(y)
-    ps = Param[ n.Value for n in y if isa(n.Value, Param) ]
+    ps = Param[ n.Value for n in y.list if isa(n.Value, Param) ]
     if isempty(ps); @error("Cannot find any params"); end
     vs = value.(ps)
     gs = (p->grad(y,p)).(ps)
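The gcheck fix walks the tape's node list explicitly (y.list, the field populated by the record! methods in src/core.jl above) rather than iterating the Tape object itself. A toy sketch of the collection pattern, with hypothetical type names mirroring the .Value and .list fields seen in the diffs:

    struct ToyNode; Value; end                     # mirrors Node's .Value field
    struct ToyTape; list::Vector{ToyNode}; end     # mirrors Tape's .list field

    t = ToyTape([ToyNode(:a), ToyNode(1.0), ToyNode(:b)])
    Symbol[n.Value for n in t.list if isa(n.Value, Symbol)]   # => [:a, :b]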
2 changes: 1 addition & 1 deletion test/header.jl
@@ -1,2 +1,2 @@
 using AutoGrad, Test
-using AutoGrad: gradcheck, randcheck
+using AutoGrad: gradcheck, randcheck, gcheck, @gcheck
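With gcheck and @gcheck now imported in the test header, tests can numerically verify gradients against finite differences. A usage sketch based on the gcheck signature visible in the gradcheck.jl hunk above (keyword defaults as shown there; @gcheck is assumed to wrap a call expression, as in the new test):

    using AutoGrad
    using AutoGrad: gcheck, @gcheck

    f(w, x) = sum(abs2, w .* x)
    w = Param(randn(3)); x = randn(3)
    gcheck(f, w, x)    # compare tape gradients to finite differences
    @gcheck f(w, x)    # macro form, as used in test/core.jl above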
