Start forward mode AD #389
This is great. I've left a few comments, but if you're planning to do a bunch of additional stuff, then maybe they're redundant. Either way, don't feel the need to respond to them.
Co-authored-by: Will Tebbutt <[email protected]> Signed-off-by: Guillaume Dalle <[email protected]>
@willtebbutt following our discussion yesterday I scratched my head some more, and I decided that it would be infinitely simpler to enforce the invariant that one line of primal IR maps to one line of dual IR. While this may require additional fallbacks in the Julia code itself, I hope it will make our lives much easier on the IR side. What do you think?
I think this could work. You could just replace the
@inline function call_frule!!(rule::R, fargs::Vararg{Any,N}) where {R,N}
    return rule(map(x -> x isa Dual ? x : zero_dual(x), fargs)...)
end
The optimisation pass will lower this to what we were thinking about writing out in the IR anyway. I think the other important kinds of nodes would be largely straightforward to handle.
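To make the lifting behaviour of the snippet above concrete, here is a minimal self-contained sketch. The `Dual` struct, `primal`, `tangent` and `zero_dual` below are toy stand-ins written for this illustration, and `mul_rule` is a hypothetical forward rule; none of them are Mooncake's actual definitions.

```julia
# Toy stand-ins for illustration only; Mooncake's real types differ.
struct Dual{P,T}
    primal::P
    tangent::T
end
primal(d::Dual) = d.primal
tangent(d::Dual) = d.tangent

# Assumed helper: wrap a plain value with a zero tangent.
zero_dual(x::Real) = Dual(x, zero(x))
zero_dual(x) = Dual(x, nothing)  # e.g. functions carry no tangent in this toy model

# Lift every non-Dual argument, so the rule only ever sees Duals.
@inline function call_frule!!(rule::R, fargs::Vararg{Any,N}) where {R,N}
    return rule(map(x -> x isa Dual ? x : zero_dual(x), fargs)...)
end

# Hypothetical forward rule for multiplication (product rule).
function mul_rule(::Dual, x::Dual, y::Dual)
    p = primal(x) * primal(y)
    t = tangent(x) * primal(y) + primal(x) * tangent(y)
    return Dual(p, t)
end

# The function `*` and the literal 2.0 are not Duals, so both get lifted.
out = call_frule!!(mul_rule, *, Dual(3.0, 1.0), 2.0)
```

In this sketch the literal `2.0` contributes a zero tangent, so only the first argument's tangent propagates to the output.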
I think we might need to be slightly more subtle. If an argument to the |
Yes. I think my proposed code handles this though, or am I missing something?
In the spirit of higher-order AD, we may encounter |
Very good point.
Agreed. Specifically, I think we need to distinguish between literals / |
I still need to dig into the different node types we might encounter (and I still don't understand |
I was reviewing the design docs and realised that, sadly, the "one line of primal IR maps to one line of dual IR" won't work for |
I think that's okay, the main trouble is adding new lines which insert new variables because it requires manual renumbering. A GoTo should be much simpler.
Were the difficulties around renumbering etc not resolved by not |
No they weren't. I experimented with |
Ah, right, but we do need to insert a new SSAValue. Suppose that the GotoIfNot(%5, #3) i.e. jump to block 3 if not
%new_ssa = Expr(:call, primal, %5)
GotoIfNot(%new_ssa, #3)
Does this not cause the same kind of problems?
Oh yes, you're probably right. Although it might be slightly less of a hassle because the new SSA is only used in one spot, right after. I'll take a look.
Do you know what I should do about expressions of type |
Yup -- I just strip them out of the IR entirely in reverse-mode. See
The way to remove an instruction from an |
I think this works for
MWE (requires this branch of Mooncake):
const CC = Core.Compiler
using Mooncake
using MistyClosures
f(x) = x > 1 ? 2x : 3 + x
ir = Base.code_ircode(f, (Float64,))[1][1]
initial_ir = copy(ir)
# Placeholder for get_primal: insert a dummy call just before the GotoIfNot at %3
get_primal_inst = CC.NewInstruction(Expr(:call, +, 1, 2), Any)
CC.insert_node!(ir, CC.SSAValue(3), get_primal_inst, false)
ir = CC.compact!(ir)
# Make each GotoIfNot branch on the statement inserted right before it
for k in 1:length(ir.stmts)
    inst = ir[CC.SSAValue(k)][:stmt]
    if inst isa Core.GotoIfNot
        Mooncake.replace_call!(ir, CC.SSAValue(k), Core.GotoIfNot(CC.SSAValue(k - 1), inst.dest))
    end
end
ir
julia> initial_ir
5 1 ─ %1 = Base.lt_float(1.0, _2)::Bool │╻╷╷ >
│ %2 = Base.or_int(%1, false)::Bool ││╻ <
└── goto #3 if not %2 │
2 ─ %4 = Base.mul_float(2.0, _2)::Float64 ││╻ *
└── return %4 │
3 ─ %6 = Base.add_float(3.0, _2)::Float64 ││╻ +
└── return %6 │
julia> ir
5 1 ─ %1 = Base.lt_float(1.0, _2)::Bool │╻╷╷ >
│ Base.or_int(%1, false)::Bool ││╻ <
│ %3 = (+)(1, 2)::Any │
└── goto #3 if not %3 │
2 ─ %5 = Base.mul_float(2.0, _2)::Float64 ││╻ *
└── return %5 │
3 ─ %7 = Base.add_float(3.0, _2)::Float64 ││╻ +
└── return %7
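For intuition about why the branch condition has to be unwrapped, here is a hand-written sketch of what a forward-mode version of the same `f` could compute. It assumes a toy `Dual` struct and a `false` tangent for `Bool` values, both invented for this illustration; the actual transformation operates on IR rather than on source code.

```julia
# Toy Dual type for illustration only; not Mooncake's definition.
struct Dual{P,T}
    primal::P
    tangent::T
end
primal(d::Dual) = d.primal
tangent(d::Dual) = d.tangent

# Hand-written dual of f(x) = x > 1 ? 2x : 3 + x
function f_dual(x::Dual)
    # In the dual program the comparison result is itself a Dual
    # (here with an assumed `false` tangent).
    cond = Dual(primal(x) > 1, false)
    # Control flow needs a plain Bool, hence the extra primal(...) call
    # that must appear before the conditional jump.
    if primal(cond)
        return Dual(2 * primal(x), 2 * tangent(x))
    else
        return Dual(3 + primal(x), tangent(x))
    end
end

y_hi = f_dual(Dual(2.0, 1.0))  # takes the 2x branch
y_lo = f_dual(Dual(0.5, 1.0))  # takes the 3 + x branch
```

The `primal(cond)` call plays the role of the extra statement inserted before `goto #3 if not %3` in the MWE.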
@willtebbutt I made some progress with control flow but I have an issue with |
IMO one of the reasons why Zygote and its related tooling failed is that they used their own IR format. I do agree that IRCode is a hassle to work with sometimes, but it would be much better for Julia long-term, and also for the sustainability of Mooncake, if it used the IR from Base throughout. Then the knowledge you acquire here can be translated into fixing things in Base that you notice, and people like me can look at your code and actually help you, without having to learn yet another IR. We do need better tools for CFG transforms in Base. I haven't found the time to polish JuliaLang/julia#45305 but I used it in my LoopInfo prototype https://github.com/vchuravy/Loops.jl/blob/main/src/ir.jl
I totally understand where you're coming from here. I took quite a lot of time when I first started trying to write Mooncake trying to avoid creating my own version of I would add though that there's a bijection between |
IIRC, BBCode is only meant to provide some extra utility currently missing from IRCode, which makes life easier (more compact, readable code). Can we try to upstream BBCode functionality to IRCode?
Core.GotoIfNot(CC.SSAValue(i), stmt.dest), #
Any,
CC.NoCallInfo(),
Int32(1), # meaningless
You likely want nothing since that will try to figure out the right line number instead of just 1.
We tried with nothing but the incremental compact insertion threw an error
@gdalle tests now run and seem to be sensitive to all the things they ought to be sensitive to. You'll see I've chucked a loop testset in |
Awesome! I'm gonna try and work my way through them.
@willtebbutt I've made some more headway and I need help for:
|
Hmm is this not handled by our existing
I think you could probably just add a function like
@inline function zero_derivative(f::Dual, x::Vararg{Dual,N}) where {N}
    return zero_dual(primal(f)(map(primal, x)...))
end
which echoes the existing function
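As a rough illustration of how such a `zero_derivative` could behave, the sketch below runs it on a comparison. The `Dual` struct and the `zero_dual` methods are toy definitions assumed for this example, not Mooncake's.

```julia
# Toy stand-ins for illustration only; Mooncake's real types differ.
struct Dual{P,T}
    primal::P
    tangent::T
end
primal(d::Dual) = d.primal
tangent(d::Dual) = d.tangent

# Assumed helper: wrap a plain value with a zero tangent.
zero_dual(x::Real) = Dual(x, zero(x))
zero_dual(x) = Dual(x, nothing)  # functions carry no tangent in this toy model

# Run the primal function on the primal arguments and attach a zero tangent.
@inline function zero_derivative(f::Dual, x::Vararg{Dual,N}) where {N}
    return zero_dual(primal(f)(map(primal, x)...))
end

# `>` has no useful derivative, so the input tangents are simply dropped.
out = zero_derivative(zero_dual(>), Dual(2.0, 1.0), Dual(1.0, 0.0))
```

Here the output wraps the `Bool` result of the comparison, with `zero(true)`, i.e. `false`, as its tangent.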
This is a very rough backbone of forward mode AD, based on #386 and the existing reverse mode implementation.