Skip to content

Commit

Permalink
Remove unthunk_tangent methods, leave only function def
Browse files Browse the repository at this point in the history
  • Loading branch information
pxl-th committed Jan 20, 2025
1 parent a4d77ab commit 62fc864
Show file tree
Hide file tree
Showing 2 changed files with 1 addition and 13 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "ZygoteRules"
uuid = "700de1a5-db45-46bc-99cf-38207098b444"
version = "0.2.6"
version = "0.3.0"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Expand Down
12 changes: 0 additions & 12 deletions src/adjoint.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,19 +31,7 @@ abstract type AContext end
function adjoint end
function _pullback end
function pullback end


function unthunk_tangent end
@inline unthunk_tangent(x) = x
@inline unthunk_tangent(x::Tuple) = map(unthunk_tangent, x)
@inline unthunk_tangent(x::NamedTuple) = map(unthunk_tangent, x)
@inline unthunk_tangent(x::AbstractThunk) = wrap_chainrules_output(unthunk(x))
@inline unthunk_tangent(x::NTuple{N,<:Number}) where N = x
@inline unthunk_tangent(x::AbstractArray{<:Number,N}) where N = x
@inline unthunk_tangent(x::AbstractArray) = map(unthunk_tangent, x)
unthunk_tangent(d::IdDict) = IdDict([unthunk_tangent(k) => unthunk_tangent(v) for (k, v) in d])
@non_differentiable unthunk_tangent(::IdDict)


function gradm(ex, mut = false, keepthunks = false)
@capture(shortdef(ex), (name_(args__) = body_) |
Expand Down

0 comments on commit 62fc864

Please sign in to comment.