From 2be67c385dc6381ba5681371bb2d40f93e4cf609 Mon Sep 17 00:00:00 2001 From: Anton Smirnov Date: Thu, 16 Jan 2025 22:49:52 +0200 Subject: [PATCH 1/6] Do not at-thunk with mixed-type accum --- src/lib/lib.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/lib/lib.jl b/src/lib/lib.jl index 179951033..0b9046925 100644 --- a/src/lib/lib.jl +++ b/src/lib/lib.jl @@ -40,8 +40,8 @@ end accum(x::NamedTuple, y::ChainRulesCore.Tangent) = accum(x, wrap_chainrules_output(y)) accum(x::ChainRulesCore.Tangent, y::NamedTuple) = accum(wrap_chainrules_output(x), y) -accum(x, y::AbstractThunk) = @thunk(accum(x, unthunk(y))) -accum(x::AbstractThunk, y) = @thunk(accum(unthunk(x), y)) +accum(x, y::AbstractThunk) = accum(x, unthunk(y)) +accum(x::AbstractThunk, y) = accum(unthunk(x), y) accum(x::AbstractThunk, y::AbstractThunk) = @thunk(accum(unthunk(x), unthunk(y))) # Core functions From 9bf08fb01ecc0325a873e13919bc37663270e2f8 Mon Sep 17 00:00:00 2001 From: Anton Smirnov Date: Sat, 18 Jan 2025 00:27:52 +0200 Subject: [PATCH 2/6] Unthunk gradient before returning --- src/compiler/chainrules.jl | 2 +- src/compiler/interface.jl | 2 +- src/lib/lib.jl | 7 +++++-- 3 files changed, 7 insertions(+), 4 deletions(-) diff --git a/src/compiler/chainrules.jl b/src/compiler/chainrules.jl index c3bb9e208..7ee5e3ab7 100644 --- a/src/compiler/chainrules.jl +++ b/src/compiler/chainrules.jl @@ -1,6 +1,6 @@ # ToDo: Move some of this to ZygoteRules, or move unthunk_tangent for Tuple and NamedTuple from # Zygote rules here? -function unthunk_tangent end +# function unthunk_tangent end @inline unthunk_tangent(x::AbstractThunk) = wrap_chainrules_output(unthunk(x)) @inline unthunk_tangent(x::NTuple{N,<:Number}) where N = x @inline unthunk_tangent(x::AbstractArray{<:Number,N}) where N = x diff --git a/src/compiler/interface.jl b/src/compiler/interface.jl index 8f251d761..bfc00ca46 100644 --- a/src/compiler/interface.jl +++ b/src/compiler/interface.jl @@ -151,7 +151,7 @@ julia> gradient([7, 11], 0, 1) do x, y, d """ function gradient(f, args...) y, back = pullback(f, args...) - grad = back(sensitivity(y)) + grad = unthunk_tangent(back(sensitivity(y))) return _project_all(args, grad) end diff --git a/src/lib/lib.jl b/src/lib/lib.jl index 0b9046925..90e596d95 100644 --- a/src/lib/lib.jl +++ b/src/lib/lib.jl @@ -40,8 +40,11 @@ end accum(x::NamedTuple, y::ChainRulesCore.Tangent) = accum(x, wrap_chainrules_output(y)) accum(x::ChainRulesCore.Tangent, y::NamedTuple) = accum(wrap_chainrules_output(x), y) -accum(x, y::AbstractThunk) = accum(x, unthunk(y)) -accum(x::AbstractThunk, y) = accum(unthunk(x), y) +accum(x::Nothing, y::AbstractThunk) = y +accum(x::AbstractThunk, y::Nothing) = x + +accum(x, y::AbstractThunk) = @thunk(accum(x, unthunk(y))) +accum(x::AbstractThunk, y) = @thunk(accum(unthunk(x), y)) accum(x::AbstractThunk, y::AbstractThunk) = @thunk(accum(unthunk(x), unthunk(y))) # Core functions From 8baa2b2d29305d5897525ab9d5c4d8380d7a0238 Mon Sep 17 00:00:00 2001 From: Anton Smirnov Date: Sat, 18 Jan 2025 01:34:25 +0200 Subject: [PATCH 3/6] Add tests --- src/compiler/chainrules.jl | 11 +++++------ src/compiler/interface.jl | 6 +++--- test/chainrules.jl | 28 ++++++++++++++++++++++++++++ 3 files changed, 36 insertions(+), 9 deletions(-) diff --git a/src/compiler/chainrules.jl b/src/compiler/chainrules.jl index 7ee5e3ab7..e0e09a63b 100644 --- a/src/compiler/chainrules.jl +++ b/src/compiler/chainrules.jl @@ -1,11 +1,10 @@ # ToDo: Move some of this to ZygoteRules, or move unthunk_tangent for Tuple and NamedTuple from # Zygote rules here? -# function unthunk_tangent end -@inline unthunk_tangent(x::AbstractThunk) = wrap_chainrules_output(unthunk(x)) -@inline unthunk_tangent(x::NTuple{N,<:Number}) where N = x -@inline unthunk_tangent(x::AbstractArray{<:Number,N}) where N = x -@inline unthunk_tangent(x::AbstractArray) = map(unthunk_tangent, x) -unthunk_tangent(d::IdDict) = IdDict([unthunk_tangent(k) => unthunk_tangent(v) for (k, v) in d]) +@inline ZygoteRules.unthunk_tangent(x::AbstractThunk) = wrap_chainrules_output(unthunk(x)) +@inline ZygoteRules.unthunk_tangent(x::NTuple{N,<:Number}) where N = x +@inline ZygoteRules.unthunk_tangent(x::AbstractArray{<:Number,N}) where N = x +@inline ZygoteRules.unthunk_tangent(x::AbstractArray) = map(unthunk_tangent, x) +ZygoteRules.unthunk_tangent(d::IdDict) = IdDict([unthunk_tangent(k) => unthunk_tangent(v) for (k, v) in d]) @non_differentiable unthunk_tangent(::IdDict) diff --git a/src/compiler/interface.jl b/src/compiler/interface.jl index bfc00ca46..a5da774a8 100644 --- a/src/compiler/interface.jl +++ b/src/compiler/interface.jl @@ -151,8 +151,8 @@ julia> gradient([7, 11], 0, 1) do x, y, d """ function gradient(f, args...) y, back = pullback(f, args...) - grad = unthunk_tangent(back(sensitivity(y))) - return _project_all(args, grad) + grad = back(sensitivity(y)) + return _project_all(args, unthunk_tangent(grad)) end # Base.adjoint(f::Function) = x -> gradient(f, x)[1] # piracy! @@ -218,7 +218,7 @@ function withgradient(f, args...) else back(sensitivity(y)) end - results = _project_all(args, grad) + results = _project_all(args, unthunk_tangent(grad)) (val=y, grad=results) end diff --git a/test/chainrules.jl b/test/chainrules.jl index 3017a9e18..99d81fff8 100644 --- a/test/chainrules.jl +++ b/test/chainrules.jl @@ -428,3 +428,31 @@ end @test Zygote.wrap_chainrules_input([[2.0; 4.0], [1.0; 3.0]]) == [[2.0; 4.0], [1.0; 3.0]] @test Zygote.wrap_chainrules_input([nothing; 4.0]) == [0.0; 4.0] # ChainRules uses the numeric zero where possible end + +@testset "Lazy" begin + custom_add(x, y) = x + y + function ChainRulesCore.rrule(::typeof(custom_add), x, y) + function pullback(Δ) + return NoTangent(), unthunk(Δ), @thunk(error("Should not compute.")) + end + custom_add(x, y), pullback + end + + x, y = 1f0, 1f0 + Zygote.gradient(x) do x + sum(custom_add(x, y)) + end +end + +@testset "No thunks in the gradient" begin + struct Dense + w::Matrix{Float32} + end + (d::Dense)(x) = d.w * x + + layers = [Dense(rand(Float32, 3, 3))] + x = ones(Float32, 3) + g = gradient(layers -> sum(layers[1](x)), layers)[1] + @test g[1] isa NamedTuple + @test g[1].w isa Array +end From e68a4ad6f247ef8ba6061c438658e0a25b4d1116 Mon Sep 17 00:00:00 2001 From: Anton Smirnov Date: Sat, 18 Jan 2025 14:13:59 +0200 Subject: [PATCH 4/6] Fix test --- test/chainrules.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/test/chainrules.jl b/test/chainrules.jl index 99d81fff8..ed8e98b94 100644 --- a/test/chainrules.jl +++ b/test/chainrules.jl @@ -445,12 +445,12 @@ end end @testset "No thunks in the gradient" begin - struct Dense + struct CustomDense w::Matrix{Float32} end - (d::Dense)(x) = d.w * x + (d::CustomDense)(x) = d.w * x - layers = [Dense(rand(Float32, 3, 3))] + layers = [CustomDense(rand(Float32, 3, 3))] x = ones(Float32, 3) g = gradient(layers -> sum(layers[1](x)), layers)[1] @test g[1] isa NamedTuple From 2a83a66e0ec141c9cae747bfd77a62f08ca9ecdd Mon Sep 17 00:00:00 2001 From: Anton Smirnov Date: Mon, 20 Jan 2025 14:50:57 +0200 Subject: [PATCH 5/6] Fix version --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 725d9a1f7..ddd14436e 100644 --- a/Project.toml +++ b/Project.toml @@ -57,7 +57,7 @@ Requires = "1.1" SpecialFunctions = "1.6, 2" Statistics = "1" Tracker = "0.2" -ZygoteRules = "0.2.5" +ZygoteRules = "=0.2.5" julia = "1.6" [extras] From f297436d83c6a77518ecc0ecd84038659570600c Mon Sep 17 00:00:00 2001 From: Anton Smirnov Date: Tue, 21 Jan 2025 02:01:19 +0200 Subject: [PATCH 6/6] Update deps --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index ddd14436e..d453525d0 100644 --- a/Project.toml +++ b/Project.toml @@ -57,7 +57,7 @@ Requires = "1.1" SpecialFunctions = "1.6, 2" Statistics = "1" Tracker = "0.2" -ZygoteRules = "=0.2.5" +ZygoteRules = "0.2.7" julia = "1.6" [extras]