Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Performance Robustness in Reverse Pass #442

Merged
merged 18 commits into from
Jan 9, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "Mooncake"
uuid = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
authors = ["Will Tebbutt, Hong Ge, and contributors"]
version = "0.4.75"
version = "0.4.76"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand Down
11 changes: 4 additions & 7 deletions src/fwds_rvs_data.jl
Original file line number Diff line number Diff line change
Expand Up @@ -662,14 +662,11 @@ with.
if P isa DataType
names = fieldnames(P)
types = fieldtypes(P)
wrapped_field_zeros = map(enumerate(tangent_field_types(P))) do (n, tt)
wrapped_field_zeros = map(enumerate(always_initialised(P))) do (n, init)
fzero = :(zero_rdata_from_type($(types[n])))
if tt <: PossiblyUninitTangent
Q = :(rdata_type(tangent_type($(fieldtype(P, n)))))
return :(PossiblyUninitTangent{$Q}($fzero))
else
return fzero
end
init && return fzero
Q = :(rdata_type(tangent_type($(fieldtype(P, n)))))
return :(PossiblyUninitTangent{$Q}($fzero))
end
wrapped_field_zeros_tuple = Expr(:call, :tuple, wrapped_field_zeros...)
wrapped_expr = :(R(NamedTuple{$names}($wrapped_field_zeros_tuple)))
Expand Down
167 changes: 109 additions & 58 deletions src/interpreter/s2s_reverse_mode_ad.jl
Original file line number Diff line number Diff line change
Expand Up @@ -405,7 +405,7 @@ function make_ad_stmts!(stmt::ReturnNode, line::ID, info::ADInfo)
end
if is_active(stmt.val)
rdata_id = get_rev_data_id(info, stmt.val)
rvs = new_inst(Expr(:call, increment_ref!, rdata_id, Argument(2)))
rvs = increment_ref_stmts(rdata_id, Argument(2))
assert_id = ID()
val = __inc(stmt.val)
fwds = [
Expand Down Expand Up @@ -479,7 +479,13 @@ function make_ad_stmts!(stmt::PiNode, line::ID, info::ADInfo)
val_rdata_ref_id = get_rev_data_id(info, stmt.val)
output_rdata_ref_id = get_rev_data_id(info, line)
fwds = PiNode(__inc(stmt.val), fcodual_type(CC.widenconst(stmt.typ)))
rvs = Expr(:call, __pi_rvs!, P, val_rdata_ref_id, output_rdata_ref_id)

# Get the rdata from the output_rdata_ref, set the value stored in that ref to zero,
# and increment the rdata ref associated to the PiNode's value by the extracted rdata.
output_rdata_id = ID()
deref_stmts = deref_and_zero_stmts(P, output_rdata_ref_id, output_rdata_id)
inc_exprs = increment_ref_stmts(val_rdata_ref_id, output_rdata_id)
rvs = vcat(deref_stmts, inc_exprs)
else
# If the value of the PiNode is a constant / QuoteNode etc, then there is nothing to
# do on the reverse-pass.
Expand All @@ -494,11 +500,6 @@ function make_ad_stmts!(stmt::PiNode, line::ID, info::ADInfo)
return ad_stmt_info(line, nothing, fwds, rvs)
end

@inline function __pi_rvs!(::Type{P}, val_rdata_ref::Ref, output_rdata_ref::Ref) where {P}
increment_ref!(val_rdata_ref, __deref_and_zero(P, output_rdata_ref))
return nothing
end

# Constant GlobalRefs are handled. See const_codual. Non-constant GlobalRefs are handled by
# assuming that they are constant, and creating a CoDual with the value. We then check at
# run-time that the value has not changed.
Expand Down Expand Up @@ -723,17 +724,53 @@ function make_ad_stmts!(stmt::Expr, line::ID, info::ADInfo)
rvs_pass = if T_pb!! <: NoPullback
nothing
else
Expr(
:call,
__run_rvs_pass!,
get_primal_type(info, line),
sig,
pb,
get_rev_data_id(info, line),
map(Base.Fix1(get_rev_data_id, info), args)...,
# Get the rdata which we pass into the pullback from its rdata ref.
rdata_ref_id = get_rev_data_id(info, line)
rdata_output_id = ID()
rdata_output_expr = Expr(:call, getfield, rdata_ref_id, QuoteNode(:x))
rdata_output = (rdata_output_id, new_inst(rdata_output_expr))

# Zero out the value stored in this rdata ref now that we have its current
# value. The new value is rdata, so it must be an instance of a bits type, and it
# is therefore safe to interpolate it straight into the instruction.
zero_val = zero_like_rdata_from_type(get_primal_type(info, line))
zero_rdata_expr = Expr(:call, setfield!, rdata_ref_id, QuoteNode(:x), zero_val)
zero_rdata_ref = (ID(), new_inst(zero_rdata_expr))

# Run the pullback. The result is a tuple comprising `length(args)` elements.
call_pullback_id = ID()
call_pullback = (call_pullback_id, new_inst(Expr(:call, pb, rdata_output_id)))

# For each element of the tuple returned by call_pullback, if the corresponding
# value in the primal IR is an Argument / SSA (if `get_rev_data_id` does not
# return nothing), increment the value in its rdata ref. This is equivalent to
# rdata_ref[] = increment!!(rdata_ref[], rdata_inc_resulting_from_pullback),
# but written out manually to ensure nothing fails to inline.
# If the corresponding value in the primal IR is not an Argument / SSA (e.g. it
# is a literal, a `QuoteNode`, or a `GlobalRef`), do nothing as we do not track
# gradients w.r.t. it.
tmp = map(enumerate(args)) do (n, arg)
rev_data_id = get_rev_data_id(info, arg)

# If arg is not an SSA / Argument, then no rdata ref to inc.
rev_data_id === nothing && return nothing

# Extract rdata from result of calling pullback.
rdata_inc_id = ID()
rdata_inc_expr = Expr(:call, getfield, call_pullback_id, n)
rdata_inc = (rdata_inc_id, new_inst(rdata_inc_expr))

# Construct statements to increment ref.
return vcat(rdata_inc, increment_ref_stmts(rev_data_id, rdata_inc_id))
end

# Concatenate all statements, and return them.
vcat(
IDInstPair[rdata_output, zero_rdata_ref, call_pullback],
reduce(vcat, filter(x -> !(x === nothing), tmp); init=IDInstPair[]),
)
end
return ad_stmt_info(line, comms_id, fwds, new_inst(rvs_pass))
return ad_stmt_info(line, comms_id, fwds, rvs_pass)

elseif Meta.isexpr(stmt, :boundscheck)
# For some reason the compiler cannot handle boundscheck statements when we run it
Expand Down Expand Up @@ -782,6 +819,29 @@ function make_ad_stmts!(stmt::Expr, line::ID, info::ADInfo)
end
end

"""
    increment_ref_stmts(ref_id::ID, inc_data)::Vector{IDInstPair}

Build the statements which perform `ref[] = increment!!(ref[], inc_data)`, where `ref` is
the value associated to `ref_id`. `inc_data` is placed directly in the argument list of the
generated `increment!!` call.

Three statements are returned, in order: a `getfield` which reads field `:x` of the
reference, a call to `increment!!` on that value and `inc_data`, and a `setfield!` which
writes the result back into field `:x` of the reference.
"""
function increment_ref_stmts(ref_id::ID, inc_data)::Vector{IDInstPair}
    # Read the current contents of the reference.
    old_id = ID()
    load = (old_id, new_inst(Expr(:call, getfield, ref_id, QuoteNode(:x))))

    # Combine the current contents with the increment.
    sum_id = ID()
    add = (sum_id, new_inst(Expr(:call, increment!!, old_id, inc_data)))

    # Write the combined value back into the reference.
    store = (ID(), new_inst(Expr(:call, setfield!, ref_id, QuoteNode(:x), sum_id)))

    return IDInstPair[load, add, store]
end

# `Argument`s and `ID`s are reported as active; every other kind of value is not.
function is_active(::Union{Argument,ID})
    return true
end
function is_active(::Any)
    return false
end

Expand All @@ -807,33 +867,6 @@ end
# Unwrap a `CoDual` to its primal; any other value is passed through unchanged.
function __get_primal(x::CoDual)
    return primal(x)
end
function __get_primal(x)
    return x
end

"""
__run_rvs_pass!(
P::Type, ::Type{sig}, pb!!, ret_rev_data_ref::Ref, arg_rev_data_refs...
) where {sig}

Used in `make_ad_stmts!` method for `Expr(:call, ...)` and `Expr(:invoke, ...)`.
"""
@inline function __run_rvs_pass!(
P::Type, ::Type{sig}, pb!!, ret_rev_data_ref::Ref, arg_rev_data_refs...
) where {sig}
tuple_map(increment_if_ref!, arg_rev_data_refs, pb!!(ret_rev_data_ref[]))
set_ret_ref_to_zero!!(P, ret_rev_data_ref)
return nothing
end

@inline increment_if_ref!(ref::Ref, rvs_data) = increment_ref!(ref, rvs_data)
@inline increment_if_ref!(::Ref, ::ZeroRData) = nothing
@inline increment_if_ref!(::Nothing, ::Any) = nothing

@inline increment_ref!(x::Ref, t) = setindex!(x, increment!!(x[], t))
@inline increment_ref!(::Base.RefValue{NoRData}, t) = nothing

@inline function set_ret_ref_to_zero!!(::Type{P}, r::Ref{R}) where {P,R}
return r[] = zero_like_rdata_from_type(P)
end
@inline set_ret_ref_to_zero!!(::Type{P}, r::Base.RefValue{NoRData}) where {P} = nothing

const RuleMC{A,R} = MistyClosure{OpaqueClosure{A,R}}

#
Expand Down Expand Up @@ -1437,7 +1470,7 @@ function pullback_ir(

# De-reference the nth rdata.
rdata_id = ID()
rdata = new_inst(Expr(:call, getindex, arg_rdata_ref_ids[n]))
rdata = new_inst(Expr(:call, getfield, arg_rdata_ref_ids[n], QuoteNode(:x)))

# Get the nth lazy zero rdata.
lazy_zero_rdata_id = ID()
Expand Down Expand Up @@ -1511,11 +1544,12 @@ function conclude_rvs_block(

# Create statements which extract + zero the rdata refs associated to them.
rdata_ids = map(_ -> ID(), phi_ids)
deref_stmts = map(phi_ids, rdata_ids) do phi_id, deref_id
tmp = map(phi_ids, rdata_ids) do phi_id, deref_id
P = get_primal_type(info, phi_id)
r = get_rev_data_id(info, phi_id)
return (deref_id, new_inst(Expr(:call, __deref_and_zero, P, r)))
return deref_and_zero_stmts(P, r, deref_id)
end
deref_stmts = reduce(vcat, tmp; init=IDInstPair[])

# For each predecessor, create a `BBlock` which processes its corresponding edge in
# each of the `PhiNode`s.
Expand All @@ -1540,14 +1574,19 @@ function __get_value(edge::ID, x::IDPhiNode)
end

"""
__deref_and_zero(::Type{P}, x::Ref) where {P}
deref_and_zero_stmts(P, ref_id, val_id)

Helper, used in conclude_rvs_block.
Equivalent to something like
```julia
val = ref[]
ref[] = zero_like_rdata_from_type(P)
```
"""
@inline function __deref_and_zero(::Type{P}, x::Ref) where {P}
t = x[]
x[] = Mooncake.zero_like_rdata_from_type(P)
return t
function deref_and_zero_stmts(P, ref_id, val_id)
    # Statement which reads field `:x` of the reference; its result is bound to `val_id`.
    load = (val_id, new_inst(Expr(:call, getfield, ref_id, QuoteNode(:x))))

    # The zero value is computed here, while the statements are being built, and is
    # embedded directly as the final argument of the `setfield!` call.
    zero_rdata = Mooncake.zero_like_rdata_from_type(P)
    store = (ID(), new_inst(Expr(:call, setfield!, ref_id, QuoteNode(:x), zero_rdata)))

    return IDInstPair[load, store]
end

"""
Expand All @@ -1562,10 +1601,14 @@ of some block:
%6 = φ (#2 => _1, #3 => %5)
%7 = φ (#2 => 5., #3 => _2)
```
Let the tangent refs associated to `%6`, `%7`, and `_1`` be denoted `t%6`, `t%7`, and `t_1`
resp., and let `pred_id` be `#2`, then this function will produce a basic block of the form
Let the rdata refs associated to `%6`, `%7`, and `_1` be denoted `r%6`, `r%7`, and `r_1`
resp., and let `pred_id` be `#2`, and `increment_ref!` be the following function,
```julia
increment_ref!(t_1, t%6)
increment_ref!(ref, x) = ref[] = increment!!(ref[], x)
```
then this `rvs_phi_block` will produce a basic block of the form
```julia
increment_ref!(r_1, r%6)
nothing
goto #2
```
Expand All @@ -1577,15 +1620,23 @@ on.

The same ideas apply if `pred_id` were `#3`. The block would end with `#3`, and there would
be two `increment_ref!` calls because both `%5` and `_2` are not constants.

In practice, code which is equivalent to `increment_ref!` is created directly, rather than
inserting a call to a generic Julia function. This is because we need to be certain that
the getfield and setfield! calls applied to any references are visible to the SROA
optimisation pass. If we insert a call to a function like `increment_ref!`, it might not be
inlined away, making such references opaque.
"""
function rvs_phi_block(
pred_id::ID, rdata_ids::Vector{ID}, values::Vector{Any}, info::ADInfo
)
@assert length(rdata_ids) == length(values)
inc_stmts = map(rdata_ids, values) do id, val
stmt = Expr(:call, increment_if_ref!, get_rev_data_id(info, val), id)
return (ID(), new_inst(stmt))
tmp = map(rdata_ids, values) do id, val
rev_data_id = get_rev_data_id(info, val)
rev_data_id === nothing && return nothing
return increment_ref_stmts(rev_data_id, id)
end
inc_stmts = reduce(vcat, filter(x -> !(x === nothing), tmp); init=IDInstPair[])
goto_stmt = (ID(), new_inst(IDGotoNode(pred_id)))
return BBlock(ID(), vcat(inc_stmts, goto_stmt))
end
Expand Down
2 changes: 2 additions & 0 deletions src/interpreter/zero_like_rdata.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ error -- please open an issue in such a situation.
struct ZeroRData end

# `ZeroRData` behaves as an identity element under `increment!!`: combining it with any
# other value, on either side, returns that other value unchanged.
@inline function increment!!(::ZeroRData, r::R) where {R}
    return r
end
@inline function increment!!(r::R, ::ZeroRData) where {R}
    return r
end
# Resolves the ambiguity between the two methods above when both arguments are `ZeroRData`.
@inline function increment!!(::ZeroRData, ::ZeroRData)
    return ZeroRData()
end

"""
zero_like_rdata_type(::Type{P}) where {P}
Expand Down
38 changes: 12 additions & 26 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,43 +24,29 @@ the same length, while `map` will just produce a new tuple whose length is equal
shorter of `x` and `y`.
"""
@inline @generated function tuple_map(f::F, x::Tuple) where {F}
return Expr(:call, :tuple, map(n -> :(f(getfield(x, $n))), eachindex(x.parameters))...)
return Expr(:call, :tuple, map(n -> :(f(getfield(x, $n))), 1:fieldcount(x))...)
end

@inline @generated function tuple_map(f::F, x::Tuple, y::Tuple) where {F}
if length(x.parameters) != length(y.parameters)
return :(throw(ArgumentError("length(x) != length(y)")))
else
stmts = map(n -> :(f(getfield(x, $n), getfield(y, $n))), eachindex(x.parameters))
stmts = map(n -> :(f(getfield(x, $n), getfield(y, $n))), 1:fieldcount(x))
return Expr(:call, :tuple, stmts...)
end
end

for N in 1:128
@eval @inline function tuple_map(f::F, x::Tuple{Vararg{Any,$N}}) where {F}
return $(Expr(:call, :tuple, map(n -> :(f(getfield(x, $n))), 1:N)...))
end
@eval @inline function tuple_map(
f::F, x::NamedTuple{names,<:Tuple{Vararg{Any,$N}}}
) where {F,names}
return NamedTuple{names}(
$(Expr(:call, :tuple, map(n -> :(f(getfield(x, $n))), 1:N)...))
)
end
@eval @inline function tuple_map(f, x::Tuple{Vararg{Any,$N}}, y::Tuple{Vararg{Any,$N}})
return $(Expr(
:call, :tuple, map(n -> :(f(getfield(x, $n), getfield(y, $n))), 1:N)...
))
end
@eval @inline function tuple_map(
f::F,
x::NamedTuple{names,<:Tuple{Vararg{Any,$N}}},
y::NamedTuple{names,<:Tuple{Vararg{Any,$N}}},
) where {F,names}
return NamedTuple{names}(
$(Expr(:call, :tuple, map(n -> :(f(getfield(x, $n), getfield(y, $n))), 1:N)...))
)
# `tuple_map` for a single `NamedTuple`: applies `f` to each field of `x`, in order, and
# returns a `NamedTuple` with the same field names holding the results.
#
# Marked `@inline` and written with `f::F ... where {F}`, in line with the `Tuple` methods
# of `tuple_map`, so that the method is specialised on the type of `f`.
@inline @generated function tuple_map(f::F, x::NamedTuple{names}) where {F,names}
    # Inside the generator `x` is the type of the argument, so `fieldcount(x)` is the number
    # of fields. One call expression `f(getfield(x, n))` is emitted per field.
    field_calls = map(n -> :(f(getfield(x, $n))), 1:fieldcount(x))
    return :(NamedTuple{names}($(Expr(:call, :tuple, field_calls...))))
end

# `tuple_map` for a pair of `NamedTuple`s with identical field names: applies `f` to
# corresponding fields of `x` and `y`, in order, and returns a `NamedTuple` with the same
# field names holding the results.
#
# Marked `@inline` and written with `f::F ... where {F}`, in line with the `Tuple` methods
# of `tuple_map`, so that the method is specialised on the type of `f`.
#
# No length check is needed: `x` and `y` share the type parameter `names`, so they always
# have the same number of fields. `NamedTuple`s whose names differ do not match this method.
@inline @generated function tuple_map(
    f::F, x::NamedTuple{names}, y::NamedTuple{names}
) where {F,names}
    # Inside the generator `x` is the type of the argument, so `fieldcount(x)` is the number
    # of fields. One call expression is emitted per field.
    field_calls = map(n -> :(f(getfield(x, $n), getfield(y, $n))), 1:fieldcount(x))
    return :(NamedTuple{names}($(Expr(:call, :tuple, field_calls...))))
end

for N in 1:256
Expand Down
4 changes: 2 additions & 2 deletions test/ext/special_functions/special_functions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ using Mooncake.TestUtils: test_rule

# Rules in this file are only lightly tested, because they are all just @from_rrule rules.
@testset "special_functions" begin
@testset for (perf_flag, f, x...) in vcat(
@testset "$perf_flag, $(typeof((f, x...)))" for (perf_flag, f, x...) in vcat(
map([Float64, Float32]) do P
return Any[
(:stability, airyai, P(0.1)),
Expand Down Expand Up @@ -51,7 +51,7 @@ using Mooncake.TestUtils: test_rule
)
test_rule(StableRNG(123456), f, x...; perf_flag)
end
@testset for (perf_flag, f, x...) in vcat(
@testset "$perf_flag, $(typeof((f, x...)))" for (perf_flag, f, x...) in vcat(
map([Float64, Float32]) do P
return Any[
(:none, logerf, P(0.3), P(0.5)), # first branch
Expand Down
9 changes: 7 additions & 2 deletions test/interpreter/s2s_reverse_mode_ad.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ struct A
end
f(a, x) = dot(a.data, x)

# Reads the contents of `x`, whose element type is declared as `Any`, and applies `sin`.
function unstable_tester(x::Ref{Any})
    return sin(x[])
end

end

@testset "s2s_reverse_mode_ad" begin
Expand Down Expand Up @@ -106,8 +108,6 @@ end
@test length(stmts.fwds) == 2
@test stmts.fwds[1][2].stmt isa Expr
@test stmts.fwds[2][2].stmt isa ReturnNode
@test Meta.isexpr(only(stmts.rvs)[2].stmt, :call)
@test only(stmts.rvs)[2].stmt.args[1] == Mooncake.increment_ref!
end
@testset "literal" begin
stmt_info = make_ad_stmts!(ReturnNode(5.0), line, info)
Expand Down Expand Up @@ -344,4 +344,9 @@ end
f() = Float64
@test length(build_rrule(Tuple{typeof(f)}).fwds_oc.oc.captures) == 2
end
@testset "all `Ref`s for rdata are eliminated in type unstable code" begin
ir = Mooncake.rvs_ir(Tuple{typeof(S2SGlobals.unstable_tester),Ref{Any}})
stmts = Mooncake.stmt(ir.stmts)
@test !any(x -> Meta.isexpr(x, :new) && x.args[1] <: Base.RefValue, stmts)
end
end
Loading