Performance Robustness in Reverse Pass #442

Merged: 18 commits, Jan 9, 2025

willtebbutt (Member) commented Jan 9, 2025

This PR contains some work to ensure that the various Refs inserted into the reverse-pass IR are always reliably SROA'd away. It is in the same vein as #430 and #438 -- change how things are implemented so that we're completely insensitive to what type inference / constant prop happens to let us do. The payoff for this kind of work is not receiving hard-to-debug issues in the future, and not having users see poor performance, not tell us about it, and give up on Mooncake entirely.

In short, a variety of Julia functions are inserted into the reverse-pass IR at the moment, and calls to them must get inlined away if SROA is to successfully remove the Refs. There are various reasons why Julia might not inline these calls, for example if the argument types are not known statically (meaning that we get dynamic dispatch). This does not just affect functions whose performance we do not care about -- suppose a particular method instance is mostly type-stable, but has one or two lines which are type-unstable and are rarely, if ever, hit in practice. At present, the presence of these type instabilities will cause the reverse pass of AD to incur additional allocations.

This PR removes the possibility of this happening by removing these generic functions altogether and inserting the code to handle the references directly into the IR, i.e. you'll see various getfield(ref, :x) and setfield!(ref, :x, val) calls inserted directly.
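
For illustration only, here is a source-level analogue of the difference described above. It is a sketch, not the code this PR touches: Mooncake makes the change at the IR level, and the names increment_ref!, sum_via_helper and sum_via_fields are hypothetical. The @noinline annotation is an assumption used to mimic a call that the compiler fails to inline (for example because of dynamic dispatch); the second function accesses the Ref's field directly, so no call has to be inlined before the Ref can be optimised away.

```julia
# Hypothetical helper standing in for a generic function inserted into the IR.
# `@noinline` mimics the case where Julia does not inline the call.
@noinline function increment_ref!(ref::Base.RefValue{Float64}, x::Float64)
    ref[] += x
    return nothing
end

# The Ref is passed to a call that is not inlined, so it escapes this function.
function sum_via_helper(xs::Vector{Float64})
    ref = Ref(0.0)
    for x in xs
        increment_ref!(ref, x)
    end
    return ref[]
end

# The Ref is only touched through getfield / setfield! on its field `x`,
# so there is no call that must be inlined for SROA to apply.
function sum_via_fields(xs::Vector{Float64})
    ref = Ref(0.0)
    for x in xs
        setfield!(ref, :x, getfield(ref, :x) + x)
    end
    return getfield(ref, :x)
end
```

Both functions compute the same result. To compare their allocation behaviour, call each once to compile it and then wrap a second call in @allocated; whether the Ref is actually eliminated depends on the Julia version and its optimiser.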

todo

  • do the same thing for phi node handling
  • do the same thing for pi node handling
  • remove any code duplication
  • remove all newly-redundant code

codecov bot commented Jan 9, 2025

Codecov Report

Attention: Patch coverage is 96.77419% with 2 lines in your changes missing coverage. Please review.

| Files with missing lines | Patch % | Lines |
|---|---|---|
| src/interpreter/zero_like_rdata.jl | 50.00% | 1 Missing ⚠️ |
| src/utils.jl | 90.00% | 1 Missing ⚠️ |

| Files with missing lines | Coverage Δ |
|---|---|
| src/fwds_rvs_data.jl | 95.93% <100.00%> (-0.02%) ⬇️ |
| src/interpreter/s2s_reverse_mode_ad.jl | 96.13% <100.00%> (+1.07%) ⬆️ |
| src/interpreter/zero_like_rdata.jl | 88.88% <50.00%> (-11.12%) ⬇️ |
| src/utils.jl | 87.77% <90.00%> (+0.68%) ⬆️ |

... and 5 files with indirect coverage changes

github-actions bot (Contributor) commented Jan 9, 2025

Performance Ratio:
Ratio of the time to compute the gradient to the time to compute the function.
Warning: results are very approximate! See here for more context.

┌────────────────────────────┬──────────┬─────────┬─────────────┬─────────┐
│                      Label │ Mooncake │  Zygote │ ReverseDiff │  Enzyme │
│                     String │   String │  String │      String │  String │
├────────────────────────────┼──────────┼─────────┼─────────────┼─────────┤
│                   sum_1000 │     83.6 │     1.1 │         5.5 │    8.21 │
│                  _sum_1000 │     6.61 │  1390.0 │        33.9 │    1.07 │
│               sum_sin_1000 │     2.28 │    1.69 │        10.9 │    1.98 │
│              _sum_sin_1000 │     2.68 │   255.0 │        13.2 │    2.44 │
│                   kron_sum │     63.9 │    3.32 │       215.0 │    8.68 │
│              kron_view_sum │     54.5 │    8.59 │       219.0 │    90.0 │
│      naive_map_sin_cos_exp │     2.45 │ missing │        7.45 │    2.33 │
│            map_sin_cos_exp │     2.71 │    1.54 │        6.13 │    2.92 │
│      broadcast_sin_cos_exp │     2.65 │    2.42 │        1.45 │    2.26 │
│                 simple_mlp │     4.93 │    3.56 │        6.73 │     3.5 │
│                     gp_lml │     12.8 │    6.49 │     missing │    8.35 │
│ turing_broadcast_benchmark │     3.22 │ missing │        26.9 │ missing │
│         large_single_block │     4.56 │  4280.0 │        32.5 │    2.24 │
└────────────────────────────┴──────────┴─────────┴─────────────┴─────────┘

@willtebbutt willtebbutt merged commit 35b432c into main Jan 9, 2025
71 checks passed
@willtebbutt willtebbutt deleted the wct/small-perf-stuff branch January 9, 2025 17:54