From b067fd397ee8646855bde3d51047eb62e8155017 Mon Sep 17 00:00:00 2001
From: Deniz Yuret
Date: Tue, 25 Dec 2018 05:43:34 -0500
Subject: [PATCH] wip #106

---
 src/core.jl       |  5 ++++-
 test/core.jl      | 22 ++++++++++++++++++++++
 test/gradcheck.jl |  2 +-
 test/header.jl    |  2 +-
 4 files changed, 28 insertions(+), 3 deletions(-)

diff --git a/src/core.jl b/src/core.jl
index 7320591..b8adad7 100644
--- a/src/core.jl
+++ b/src/core.jl
@@ -22,7 +22,7 @@ mutable struct Result{T} <: Tracked{T}
     args
     kwargs
     Result{T}(v,f,a,k) where {T} = new(v,f,a,k)
-    Result{T}(v,f,a,k) where {T<:Value} = error("Result cannot take $T as arg.")
+    # Result{T}(v,f,a,k) where {T<:Value} = error("Result cannot take $T as arg.") # See #106
 end
 
 # value() should give a regular (non-Value) result regardless of recursion
@@ -132,6 +132,9 @@ function track(f, args, kwargs, bcasted)
     end
 end
 
+Result(v::T, f, args, kwargs) where {T<:Tracked} = v
+Result(v::T, f, args, kwargs) where {T<:Bcasted} = error("Result cannot take $T as arg")
+
 function Result(v::T, f, args, kwargs) where {T}
     record!(t::Tape, v::Tracked) = (n = get(t.dict, v, nothing); n === nothing ? record!(t, Node(v)) : n)
     record!(t::Tape, n::Node) = (t.dict[n.Value] = n; pushfirst!(t.list, n); n)
diff --git a/test/core.jl b/test/core.jl
index 10066c9..fb60a20 100644
--- a/test/core.jl
+++ b/test/core.jl
@@ -46,4 +46,26 @@ using Statistics
     # Double broadcasting
     x = Param([1.,2.]); f3(x)=sin(x); f4(x)=sin.(x)
     @test grad((@diff sum(f3.(x))), x) == grad((@diff sum(f4.(x))), x) == grad((@diff sum(f4(x))), x)
+
+    # Issue #106: Double Result
+    h(x) = exp(-x); h′(x,y) = -y
+    𝓁(x,y) = sum(abs2,x-y)/2
+    function neural_net(mparams, input; h=h, h′=h′, N=length(mparams))
+        δ = []
+        X = Any[input]
+        for i=1:N
+            x = sum(mparams[i] .* [X[i],1])
+            y = h.(x)
+            push!(δ, h′.(x,y))
+            push!(X,y)
+        end
+        return X,δ
+    end
+    mparams = [[randn(),randn()] for i=1:3]
+    P = Param(mparams)
+    loss(P,x,y) = 𝓁(neural_net(P,x)[1][end],y)
+    x,y = randn(),randn()
+    J = @diff loss(P,x,y)
+    @test isa(J, AutoGrad.Tape)
+    @test_broken @gcheck loss(P,x,y)
 end
diff --git a/test/gradcheck.jl b/test/gradcheck.jl
index 53cc08e..8f736f2 100644
--- a/test/gradcheck.jl
+++ b/test/gradcheck.jl
@@ -151,7 +151,7 @@ function gcheck(f, x...; kw=(), nsample=10, verbose=1, rtol=0.05, atol=0.01, del
     y = @diff gcsum(f, x...; kw...)
     if !isa(y, Tape); @warn("Output independent of params"); return true; end
     f0 = value(y)
-    ps = Param[ n.Value for n in y if isa(n.Value, Param) ]
+    ps = Param[ n.Value for n in y.list if isa(n.Value, Param) ]
     if isempty(ps); @error("Cannot find any params"); end
     vs = value.(ps)
     gs = (p->grad(y,p)).(ps)
diff --git a/test/header.jl b/test/header.jl
index b2dd4f5..492d954 100644
--- a/test/header.jl
+++ b/test/header.jl
@@ -1,2 +1,2 @@
 using AutoGrad, Test
-using AutoGrad: gradcheck, randcheck
+using AutoGrad: gradcheck, randcheck, gcheck, @gcheck
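
For context on what the new `Result` methods change in practice, here is a minimal sketch, not part of the patch, that mirrors the new issue #106 test in `test/core.jl`: when all weights live inside one `Param`-wrapped nested collection, intermediate values reaching `Result` during tape recording can already be `Tracked`, and with this patch such values are passed through instead of raising an error (only `Bcasted` still errors). The `predict` helper and its layer layout below are hypothetical, written against AutoGrad's public `Param`, `@diff`, `grad`, `value` API as used in the test.

```julia
using AutoGrad

# Illustrative only: parameters stored as a nested collection inside one Param,
# one [w, b] pair per layer, as in the new test for issue #106.
P = Param([[randn(), randn()] for i in 1:3])

# One affine step per pair, with h(x) = exp(-x) as the activation (as in the test).
# In later iterations x is already tracked, so [x, 1] carries a tracked value
# back into Result -- the "double Result" situation the patch addresses.
function predict(p, x)
    for i in 1:length(p)
        x = exp(-sum(p[i] .* [x, 1]))
    end
    return x
end

loss(p, x, y) = abs2(predict(p, x) - y) / 2

x, y = randn(), randn()
J = @diff loss(P, x, y)
@show isa(J, AutoGrad.Tape)   # with the patch, @diff records a Tape instead of erroring
@show grad(J, P)              # same nested layout as P; note the patch still marks @gcheck as @test_broken
```

The dispatch split is the design point: `Result(v::Tracked, ...)` returns the already-tracked value unchanged, so nothing is double-wrapped, while `Result(v::Bcasted, ...)` keeps the old error because a broadcast wrapper should never reach this constructor.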