diff --git a/Project.toml b/Project.toml index c20837c..0388f6f 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "ZygoteRules" uuid = "700de1a5-db45-46bc-99cf-38207098b444" -version = "0.2.6" +version = "0.3.0" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" diff --git a/src/adjoint.jl b/src/adjoint.jl index a5883dc..7ab0c59 100644 --- a/src/adjoint.jl +++ b/src/adjoint.jl @@ -31,19 +31,7 @@ abstract type AContext end function adjoint end function _pullback end function pullback end - - function unthunk_tangent end -@inline unthunk_tangent(x) = x -@inline unthunk_tangent(x::Tuple) = map(unthunk_tangent, x) -@inline unthunk_tangent(x::NamedTuple) = map(unthunk_tangent, x) -@inline unthunk_tangent(x::AbstractThunk) = wrap_chainrules_output(unthunk(x)) -@inline unthunk_tangent(x::NTuple{N,<:Number}) where N = x -@inline unthunk_tangent(x::AbstractArray{<:Number,N}) where N = x -@inline unthunk_tangent(x::AbstractArray) = map(unthunk_tangent, x) -unthunk_tangent(d::IdDict) = IdDict([unthunk_tangent(k) => unthunk_tangent(v) for (k, v) in d]) -@non_differentiable unthunk_tangent(::IdDict) - function gradm(ex, mut = false, keepthunks = false) @capture(shortdef(ex), (name_(args__) = body_) |