Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
  • Loading branch information
denizyuret committed Dec 25, 2018
1 parent 3f503ab commit 09d4dfd
Showing 1 changed file with 55 additions and 52 deletions.
107 changes: 55 additions & 52 deletions src/core.jl
Original file line number Diff line number Diff line change
Expand Up @@ -63,72 +63,72 @@ broadcastable(x::Value) = x # This is necessary, default is collect(x) which
# This should only catch non-primitive functions
broadcasted(::Style{Value}, f, args...) = recording() ? f(Bcasted.(args)...).value : broadcasted(f, value.(args)...)

function forw(f, args...; kwargs...)
if recording()
bcast(f, args, kwargs)
else
fforw(f, args, kwargs)
end
end
# This should only catch primitive functions
forw(f, args...; kwargs...) = recording() ? bcast(f, args, kwargs) : fforw(f, args, kwargs)

function fforw(f, args, kwargs)
@assert !recording()
aval = args
@inbounds for i in 1:length(aval)
if isa(aval[i], Value)
if aval === args
aval = Any[args...]
@timer "fforw" begin
@assert !recording()
aval = args
@inbounds for i in 1:length(aval)
if isa(aval[i], Value)
if aval === args
aval = Any[args...]
end
# @assert !isa(aval[i], Result) # This can happen during back
@assert !isa(aval[i], Bcasted)
aval[i] = aval[i].value
@assert !isa(aval[i], Value) "Illegal value recursion: $(typeof(args[i]))"
end
@assert isa(aval[i], Param) "$(typeof(aval[i])) while not recording."
aval[i] = aval[i].value
@assert !isa(aval[i], Value) "Illegal value recursion: $(typeof(args[i]))"
end
@assert aval !== args "forw called without Value args"
end
if aval === args
error("forw called without Value args")
else
f(aval...; kwargs...)
end
f(aval...; kwargs...)
end

function bcast(f, args, kwargs)
@assert recording()
aval = args
@inbounds for i in 1:length(aval)
if isa(aval[i], Bcasted)
if aval === args
aval = Any[args...]
@timer "bcast" begin
@assert recording()
aval = args
@inbounds for i in 1:length(aval)
if isa(aval[i], Bcasted)
if aval === args
aval = Any[args...]
end
aval[i] = aval[i].value
@assert !isa(aval[i], Bcasted)
end
aval[i] = aval[i].value
@assert !isa(aval[i], Bcasted)
end
bcasted = (aval !== args)
if bcasted && f !== broadcast
aval = pushfirst!(aval, f)
f = broadcast
end
end
if aval === args
track(f, aval, kwargs, false)
else
aval = pushfirst!(aval, f)
track(broadcast, aval, kwargs, true) |> Bcasted
end
v = track(f, aval, kwargs, bcasted)
bcasted ? Bcasted(v) : v
end

function track(f, args, kwargs, bcasted)
@assert recording()
aval = args
@inbounds for i in 1:length(aval)
if isa(aval[i], Tracked)
if aval === args
aval = isa(args, Array) ? copy(args) : Any[args...]
@timer "track" begin
@assert recording()
aval = args
@inbounds for i in 1:length(aval)
if isa(aval[i], Tracked)
if aval === args
aval = isa(args, Array) ? copy(args) : Any[args...]
end
aval[i] = aval[i].value
@assert !isa(aval[i], Value)
end
aval[i] = aval[i].value
@assert !isa(aval[i], Value)
end
end
if aval === args
@assert bcasted
return f(args...; kwargs...)
@assert bcasted "Tracking function without Value args."
f(args...; kwargs...)
else
v = f(aval...; kwargs...)
return Result(v, f, args, kwargs)
@timer ftimer(f,aval) (v = f(aval...; kwargs...))
@timer "record" Result(v, f, args, kwargs)
end
end

Expand Down Expand Up @@ -156,7 +156,6 @@ back(x...; o...) = throw(ArgumentError("AutoGrad does not yet support back"*stri
abstract type Arg{N} end

function differentiate(f, x...; o...)
global _tapes
duplicate(x)=(isa(x,Value) ? identity(x) : x)
if !isempty(_tapes) # PR#75: to avoid tape confusion
x = map(duplicate,x) # duplicate tracked function arguments.
Expand All @@ -178,29 +177,32 @@ function differentiate(f, x...; o...)
n1 = first(tape.list)
if result !== n1.Value; error("Result not on tape"); end
n1.outgrad = one(value(result))
tm(r::Result,i::Int)=(r.func==broadcast ? "$(r.args[1]).[$(i-1)]" : "$(r.func)[$i]")
for n in tape.list
if n.outgrad == nothing; continue; end
r = n.Value
@inbounds for i in 1:length(n.parents)
if !isassigned(n.parents, i); continue; end
p = n.parents[i]
@timer tm(r,i) (g = back(r.func, Arg{i}, n.outgrad, r, r.args...; r.kwargs...))
@timer btimer(r,i) (g = back(r.func, Arg{i}, n.outgrad, r, r.args...; r.kwargs...))
@timer "sum_outgrads" (p.outgrad = sum_outgrads(p.outgrad, g))
end
if isempty(_tapes) && isa(r,Result) && n !== n1; gcnode(n); end # save memory
end
return tape
end

default_gc(n::Node) = (n.outgrad=nothing; n.Value.value=nothing)
default_gc(n::Node) = nothing # (n.outgrad=nothing; n.Value.value=nothing)
gcnode = default_gc
set_gc_function(f::Function) = (global gcnode = f)

# This allows argument expressions like @diff sin(sqrt(x)) which fail with differentiate
# because arguments get evaluated before the tape gets created.
macro diff(fx); :(differentiate(()->$(esc(fx)))); end

# Used by @timer
btimer(r::Result,i::Int)=(r.func===broadcast ? "$(r.args[1]).[$(i-1)]" : "$(r.func)[$i]")
ftimer(f::Function,a::Array{Any})=(f===broadcast ? "$(a[1])." : "$f")

# Old style grad and gradloss
function grad(fun::Function, argnum::Int=1, loss=false)
function gradfun(args...; kwargs...)
Expand All @@ -218,6 +220,7 @@ end
gradloss(f,a=1)=grad(f,a,true)



### DEPRECATED:

# # Fix iterate, first, last, cons!, collect, get/grad?
Expand Down Expand Up @@ -336,7 +339,7 @@ gradloss(f,a=1)=grad(f,a,true)



# ftimer(f,a)=(f===broadcast ? "$(a[1])." : "$f") # used by @timer
#

# function record(t::Tape, r::Result)
# nargs = length(r.args)
Expand Down

0 comments on commit 09d4dfd

Please sign in to comment.