Start forward mode AD #389

Draft
wants to merge 34 commits into
base: main

gdalle
Collaborator

@gdalle gdalle commented Nov 24, 2024

This is a very rough backbone of forward mode AD, based on #386 and the existing reverse mode implementation.


codecov bot commented Nov 24, 2024

Codecov Report

Attention: Patch coverage is 54.95495% with 150 lines in your changes missing coverage. Please review.

| Files with missing lines | Patch % | Lines |
|---|---|---|
| src/interpreter/diffractor_compiler_utils.jl | 13.63% | 57 Missing ⚠️ |
| src/interpreter/s2s_forward_mode_ad.jl | 57.85% | 51 Missing ⚠️ |
| src/test_utils.jl | 88.63% | 10 Missing ⚠️ |
| src/dual.jl | 55.00% | 9 Missing ⚠️ |
| src/interpreter/bbcode.jl | 0.00% | 9 Missing ⚠️ |
| src/rrules/misc.jl | 0.00% | 7 Missing ⚠️ |
| src/rrules/builtins.jl | 66.66% | 4 Missing ⚠️ |
| src/rrules/low_level_maths.jl | 66.66% | 2 Missing ⚠️ |
| src/debug_mode.jl | 0.00% | 1 Missing ⚠️ |

| Files with missing lines | Coverage Δ |
|---|---|
| src/Mooncake.jl | 100.00% <ø> (ø) |
| src/interpreter/ir_normalisation.jl | 84.39% <ø> (-5.68%) ⬇️ |
| src/interpreter/ir_utils.jl | 82.50% <100.00%> (-5.00%) ⬇️ |
| src/tools_for_rules.jl | 94.73% <100.00%> (-4.22%) ⬇️ |
| src/debug_mode.jl | 97.22% <0.00%> (-2.78%) ⬇️ |
| src/rrules/low_level_maths.jl | 30.30% <66.66%> (-69.70%) ⬇️ |
| src/rrules/builtins.jl | 38.33% <66.66%> (-59.60%) ⬇️ |
| src/rrules/misc.jl | 54.11% <0.00%> (-43.32%) ⬇️ |
| src/dual.jl | 55.00% <55.00%> (ø) |
| src/interpreter/bbcode.jl | 92.06% <0.00%> (-4.02%) ⬇️ |

... and 3 more

... and 22 files with indirect coverage changes

Member

@willtebbutt willtebbutt left a comment

This is great. I've left a few comments, but if you're planning to do a bunch of additional stuff, then maybe they're redundant. Either way, don't feel the need to respond to them.

@gdalle
Collaborator Author

gdalle commented Nov 26, 2024

@willtebbutt following our discussion yesterday I scratched my head some more, and I decided that it would be infinitely simpler to enforce the invariant that one line of primal IR maps to one line of dual IR. While this may require additional fallbacks in the Julia code itself, I hope it will make our lives much easier on the IR side. What do you think?

@willtebbutt
Member

I think this could work.

You could just replace the frule!! calls with a call to a function call_frule!! which would be something like

@inline function call_frule!!(rule::R, fargs::Vararg{Any,N}) where {R,N}
    return rule(map(x -> x isa Dual ? x : zero_dual(x), fargs)...)
end

The optimisation pass will lower this to what we were thinking about writing out in the IR anyway.
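
To make the wrapping behaviour concrete, here is a minimal self-contained sketch. The `Dual` struct, `zero_dual`, and `mul_frule` below are toy stand-ins written for illustration only; they are assumptions, not Mooncake's actual definitions.

```julia
# Toy stand-in for a dual number; NOT Mooncake's actual `Dual` type.
struct Dual{P,T}
    primal::P
    tangent::T
end

# Hypothetical helper: pair a value with a zero tangent (numbers only in this sketch).
zero_dual(x::Number) = Dual(x, zero(x))

# Wrap every non-Dual argument in a zero-tangent Dual, then call the rule.
@inline function call_frule!!(rule::R, fargs::Vararg{Any,N}) where {R,N}
    return rule(map(x -> x isa Dual ? x : zero_dual(x), fargs)...)
end

# Toy forward rule for multiplication: d(a * b) = da * b + a * db.
mul_frule(a::Dual, b::Dual) =
    Dual(a.primal * b.primal, a.tangent * b.primal + a.primal * b.tangent)

# The literal 2.0 gets wrapped automatically; `x` is passed through unchanged.
x = Dual(3.0, 1.0)
y = call_frule!!(mul_frule, 2.0, x)
```

In this sketch the wrapping decision is made at run time from the value's type, which is exactly the point the following comments go on to question.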

I think the other important kinds of nodes would be largely straightforward to handle.

@gdalle
Collaborator Author

gdalle commented Nov 26, 2024

I think we might need to be slightly more subtle. If an argument to the :call or :invoke expression is a CC.Argument or a CC.SSAValue, we don't wrap it in a Dual because we assume it will already be one, right?

@willtebbutt
Member

willtebbutt commented Nov 26, 2024

Yes. I think my proposed code handles this though, or am I missing something?

@gdalle
Collaborator Author

gdalle commented Nov 26, 2024

In the spirit of higher-order AD, we may encounter Dual inputs that we want to wrap with a second Dual, and Dual inputs that we want to leave as-is. So I think this wrapping needs to be decided from the type of each argument in the IR?

@willtebbutt
Member

Very good point.

> So I think this wrapping needs to be decided from the type of each argument in the IR?

Agreed. Specifically, I think we need to distinguish between literals / QuoteNodes / GlobalRefs, and Argument / SSAValues?
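
One way this distinction could be expressed is by dispatching on the kind of IR node rather than on the run-time type of the value. The helper name `needs_wrapping` is hypothetical and used here only for illustration:

```julia
# Arguments and SSA values refer to values computed by the transformed code,
# so in this sketch they are assumed to already be duals.
needs_wrapping(::Core.Argument) = false
needs_wrapping(::Core.SSAValue) = false

# Everything else (literals, QuoteNodes, GlobalRefs, ...) is a constant in the IR
# and would need to be wrapped before being passed to a rule.
needs_wrapping(::Any) = true

args = Any[Core.Argument(2), Core.SSAValue(5), 1.0, QuoteNode(:a), GlobalRef(Base, :sin)]
flags = map(needs_wrapping, args)
```

Because the decision depends only on where the value comes from in the IR, a constant that happens to be a `Dual` (as in higher-order AD) would still be wrapped, while an `SSAValue` would be left alone.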

@gdalle
Collaborator Author

gdalle commented Nov 26, 2024

I still need to dig into the different node types we might encounter (and I still don't understand QuoteNodes) but yeah, Argument and SSAValue don't need to be wrapped.

@gdalle gdalle mentioned this pull request Nov 27, 2024
@willtebbutt
Member

I was reviewing the design docs and realised that, sadly, the "one line of primal IR maps to one line of dual IR" invariant won't work for Core.GotoIfNot nodes. See https://compintell.github.io/Mooncake.jl/previews/PR386/developer_documentation/forwards_mode_design/#Statement-Transformation

@gdalle
Collaborator Author

gdalle commented Nov 27, 2024

I think that's okay. The main trouble is adding new lines that insert new variables, because that requires manual renumbering. A GoTo should be much simpler.

@willtebbutt
Member

Were the difficulties around renumbering etc not resolved by not compact!ing until the end? I feel like I might be missing something.

@gdalle
Collaborator Author

gdalle commented Nov 27, 2024

No, they weren't. I experimented with compact! in various places and I was struggling a lot, so I asked Frames for advice. She agreed that insertion should usually be avoided.
If we have to insert something for GoTo, I think it will still be easier because we're not defining a new SSAValue so we don't have to adapt future statements that refer to it.

@willtebbutt
Member

willtebbutt commented Nov 27, 2024

Ah, right, but we do need to insert a new SSAValue. Suppose that the GotoIfNot of interest is

GotoIfNot(%5, #3)

i.e. jump to block 3 if not %5. In the forwards-mode IR this would become

%new_ssa = Expr(:call, primal, %5)
GotoIfNot(%new_ssa, #3)

Does this not cause the same kind of problems?
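
The transformation described above can be sketched as a small function that turns one `GotoIfNot` into a pair of statements. The names `get_primal` and `split_gotoifnot` are hypothetical placeholders, and the sketch builds the statements only; it does not insert them into an `IRCode`.

```julia
# Hypothetical stand-in for the function that extracts the primal from a dual.
get_primal(x) = x

# Given a GotoIfNot and the SSA id reserved for the new statement, return
# (1) a call that extracts the primal of the original condition and
# (2) a GotoIfNot that branches on that new SSA value instead.
function split_gotoifnot(stmt::Core.GotoIfNot, new_id::Int)
    primal_stmt = Expr(:call, get_primal, stmt.cond)
    new_goto = Core.GotoIfNot(Core.SSAValue(new_id), stmt.dest)
    return primal_stmt, new_goto
end

# GotoIfNot(%5, #3) becomes: %6 = get_primal(%5); GotoIfNot(%6, #3)
primal_stmt, new_goto = split_gotoifnot(Core.GotoIfNot(Core.SSAValue(5), 3), 6)
```

The new SSA value is consumed only by the statement immediately after it, which is why the renumbering burden may be lighter here than for a general insertion.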

@gdalle
Collaborator Author

gdalle commented Nov 27, 2024

Oh yes, you're probably right. Although it might be slightly less of a hassle because the new SSA is only used in one spot, right after. I'll take a look.

@gdalle
Collaborator Author

gdalle commented Nov 27, 2024

Do you know what I should do about expressions of type :code_coverage_effect? I assume they're inserted automatically and they're alone on their lines?

@willtebbutt
Member

willtebbutt commented Nov 27, 2024

Yup -- I just strip them out of the IR entirely in reverse-mode. See

elseif Meta.isexpr(stmt, :code_coverage_effect)

The way to remove an instruction from an IRCode is just to replace the instruction with nothing.
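
As a rough illustration of that idea, the sketch below works on a plain vector of statements rather than a real `IRCode` (an assumption made to keep the example self-contained); the helper name `strip_coverage` is hypothetical.

```julia
# Replace every :code_coverage_effect expression with `nothing`,
# leaving the number and order of statements unchanged.
strip_coverage(stmts) =
    Any[Meta.isexpr(s, :code_coverage_effect) ? nothing : s for s in stmts]

stmts = Any[
    Expr(:code_coverage_effect),
    Expr(:call, :sin, 1.0),
    Core.ReturnNode(Core.SSAValue(2)),
]
cleaned = strip_coverage(stmts)
```

Keeping a `nothing` placeholder instead of deleting the entry means no SSA value needs to be renumbered.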

@gdalle
Collaborator Author

gdalle commented Nov 27, 2024

I think this works for GotoIfNot:

  1. make all the insertions necessary
  2. compact! once to make sure they are applied
  3. shift the conditions of all GotoIfNot nodes to refer to the node right before them (where we get the primal value of the condition)

MWE (requires this branch of Mooncake):

const CC = Core.Compiler
using Mooncake
using MistyClosures

f(x) = x > 1 ? 2x : 3 + x
ir = Base.code_ircode(f, (Float64,))[1][1]
initial_ir = copy(ir)

# Step 1: insert a placeholder node just before statement %3 (the GotoIfNot).
get_primal_inst = CC.NewInstruction(Expr(:call, +, 1, 2), Any)  # placeholder for get_primal
CC.insert_node!(ir, CC.SSAValue(3), get_primal_inst, false)

# Step 2: compact once so the insertion is applied and SSA values are renumbered.
ir = CC.compact!(ir)

# Step 3: make each GotoIfNot branch on the statement immediately before it.
for k in 1:length(ir.stmts)
    inst = ir[CC.SSAValue(k)][:stmt]
    if inst isa Core.GotoIfNot
        Mooncake.replace_call!(ir, CC.SSAValue(k), Core.GotoIfNot(CC.SSAValue(k - 1), inst.dest))
    end
end
ir

julia> initial_ir
5 1 ─ %1 = Base.lt_float(1.0, _2)::Bool
  │   %2 = Base.or_int(%1, false)::Bool
  └──      goto #3 if not %2
  2 ─ %4 = Base.mul_float(2.0, _2)::Float64
  └──      return %4
  3 ─ %6 = Base.add_float(3.0, _2)::Float64
  └──      return %6

julia> ir
5 1 ─ %1 = Base.lt_float(1.0, _2)::Bool
  │        Base.or_int(%1, false)::Bool
  │   %3 = (+)(1, 2)::Any
  └──      goto #3 if not %3
  2 ─ %5 = Base.mul_float(2.0, _2)::Float64
  └──      return %5
  3 ─ %7 = Base.add_float(3.0, _2)::Float64
  └──      return %7

@gdalle
Collaborator Author

gdalle commented Nov 28, 2024

@willtebbutt I made some progress with control flow but I have an issue with DerivedFRule, which receives things that are not Dual and so primal fails. Can you take a look?

@vchuravy

vchuravy commented Nov 30, 2024

> On a separate note: if it does turn out (for some reason) that it's really awkward to do this insertion stuff directly on IRCode, I show in this gist how to use BBCode to manage IR transformations involving insertions, because this kind of thing is what BBCode excels at.

IMO one of the reasons why Zygote and its related tooling failed is that they used their own IR format. I do agree that IRCode is a hassle to work with sometimes, but it would be much better for Julia long-term, and also for the sustainability of Mooncake, if it used the IR from Base throughout.

Then the knowledge you acquire here can be translated into fixing things in Base that you notice, and people like me can look at your code and actually help you, without having to learn yet another IR.

We do need better tools for CFG transforms in Base. I haven't found the time to polish JuliaLang/julia#45305, but I used it in my LoopInfo prototype: https://github.com/vchuravy/Loops.jl/blob/main/src/ir.jl

https://vchuravy.dev/talks/licm/

@willtebbutt
Member

willtebbutt commented Nov 30, 2024

I totally understand where you're coming from here. When I first started writing Mooncake, I spent quite a lot of time trying to avoid creating my own version of IRCode for exactly this reason: I really didn't want to roll my own thing. Unfortunately, IRCode just isn't up to the kinds of transformations I've had to do to make Mooncake work. Since @gdalle started working on this, I've been thinking about writing up some specific things that are not at all easy to do with IRCode, and how the BBCode thing I use makes life easier, to try and start some discussion around this. I'm reasonably confident it'll be clear why working on IRCode directly just wasn't an option for me.

I would add though that there's a bijection between IRCode and BBCode -- they contain exactly the same information, just organised slightly differently.

@yebai
Contributor

yebai commented Dec 1, 2024

IIRC, BBCode is only meant to provide some extra utility currently missing from IRCode, which makes life easier (more compact, readable code). Can we try to upstream BBCode functionality to IRCode?

Core.GotoIfNot(CC.SSAValue(i), stmt.dest), #
Any,
CC.NoCallInfo(),
Int32(1), # meaningless

You likely want nothing since that will try to figure out the right line number instead of just 1.

Collaborator Author

We tried with nothing, but the incremental compact insertion threw an error.

@willtebbutt
Member

@gdalle tests now run and seem to be sensitive to all the things they ought to be sensitive to.

You'll see I've chucked a loop testset in s2s_forward_mode_ad.jl -- this hooks into the 100 or so unit-like tests that I use to try out reverse-mode on all of the language features that are supported. I'm currently only running the third and fourth tests (I don't think the first two will run correctly at the minute). This is a pretty good battery of tests to work through -- if you can get forwards-mode working on all of these (which will involve writing a bunch of frules, which I'm very happy to help with), then we should be most of the way to having a working forwards-mode AD. At that point, it should just be a case of adding a range of additional hand-written frules.

@gdalle
Collaborator Author

gdalle commented Dec 6, 2024

Awesome! I'm gonna try and work my way through them.

@gdalle
Collaborator Author

gdalle commented Dec 6, 2024

@willtebbutt I've made some more headway and I need help with:

  • 7: how to define frule!! of getfield
  • 10: how to adapt @zero_adjoint to forward mode

@willtebbutt
Member

> 7: how to define frule!! of getfield

Hmm, is this not handled by our existing lgetfield rule?

> 10: how to adapt @zero_adjoint to forward mode

I think you could probably just add a function like

@inline function zero_derivative(f::Dual, x::Vararg{Dual,N}) where {N}
    return zero_dual(primal(f)(map(primal, x)...))
end

which echoes the existing function Mooncake.zero_adjoint, and then shove this in a macro that looks a lot like the @zero_adjoint macro?
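
For illustration, the suggested `zero_derivative` function can be exercised with toy definitions. The `Dual` struct, `primal`, and `zero_dual` below are simplified assumptions rather than Mooncake's real implementations, and the macro wrapper is not sketched here.

```julia
# Toy stand-ins; NOT Mooncake's actual definitions.
struct Dual{P,T}
    primal::P
    tangent::T
end
primal(d::Dual) = d.primal
zero_dual(x::Number) = Dual(x, zero(x))

# Run the primal computation and attach a zero tangent to the result,
# ignoring the tangents of all inputs.
@inline function zero_derivative(f::Dual, x::Vararg{Dual,N}) where {N}
    return zero_dual(primal(f)(map(primal, x)...))
end

# `length` does not depend differentiably on the array's values, so its
# output tangent is zero whatever the input tangent is. The function itself
# carries `nothing` as a placeholder tangent in this sketch.
out = zero_derivative(Dual(length, nothing), Dual([1.0, 2.0], [1.0, 1.0]))
```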

4 participants