Torch Compile with Float8Linear #106
Okay so @bdhirsh is back for less than a day and we pretty much have everything (albeit not landed yet) working, which is awesome!
Float8 changes:
Core changes:
Numbers: Better than no tensor subclass! (Likely because I worked on removing graph breaks, which could also be done for the no-tensor-subclass version.)

Tensor Subclass
name shape ref_dtype ... te_fp8_time_sec pt_fp8_speedup te_fp8_speedup
0 attn.wqkv (16384, 8192, 1280) torch.bfloat16 ... 0.002041 0.938752 1.040107
1 attn.w0 (16384, 1024, 8192) torch.bfloat16 ... 0.001765 1.082501 1.076772
2 ffn.w13 (16384, 8192, 7168) torch.bfloat16 ... 0.006432 1.476127 1.537317
3 ffn.w2 (16384, 3584, 8192) torch.bfloat16 ... 0.003637 1.390212 1.432203
4 attn.wqkv (16384, 8192, 1280) torch.float16 ... 0.002005 0.990544 1.112229
5 attn.w0 (16384, 1024, 8192) torch.float16 ... 0.001718 1.124236 1.125813
6 ffn.w13 (16384, 8192, 7168) torch.float16 ... 0.006646 1.578117 1.569713
7 ffn.w2 (16384, 3584, 8192) torch.float16 ... 0.003713 1.464948 1.474404

No Tensor Subclass
name shape ref_dtype ... te_fp8_time_sec pt_fp8_speedup te_fp8_speedup
0 attn.wqkv (16384, 8192, 1280) torch.bfloat16 ... 0.001976 0.923693 1.077280
1 attn.w0 (16384, 1024, 8192) torch.bfloat16 ... 0.001773 1.062716 1.073277
2 ffn.w13 (16384, 8192, 7168) torch.bfloat16 ... 0.006451 1.437040 1.528630
3 ffn.w2 (16384, 3584, 8192) torch.bfloat16 ... 0.003633 1.373209 1.441463
4 attn.wqkv (16384, 8192, 1280) torch.float16 ... 0.001984 0.945713 1.129551
5 attn.w0 (16384, 1024, 8192) torch.float16 ... 0.001643 1.116112 1.186204
6 ffn.w13 (16384, 8192, 7168) torch.float16 ... 0.006483 1.509603 1.604298
7 ffn.w2 (16384, 3584, 8192) torch.float16 ... 0.003706 1.445889 1.476486
Silent Correctness Issue
I first discovered this when I tried to re-train llama7b on a single node using torch.compile. Although performance improved significantly, it failed to converge. I created a sample script to explore the cause and compared running it with torch.compile against eager (left is compile, right is eager). I can't be entirely sure that this is the same problem that is affecting the llama train, but it is very likely.

Why is this happening
We currently use a module buffer that gets updated in place during the backward. The offending line that I do not think is working the same between eager and compile is this fill_. I think that a majority (all?) of backwards are expected to be functional, and this likely breaks some assumption somewhere.
@bdhirsh for more in-place things.
The problem is that we are mutating some global state (a module buffer) inside of the backward. Today, AOTAutograd can handle input/buffer mutations, but only from the forward. I think that there are roughly two options (that are both some amount of work):

Option 1: Make AOTAutograd smarter. Detect when the backward graph contains input mutations, and include them in the graph.

Option 2: Re-write the mutation using backward hooks, and only support this case in compiled autograd. Upside: might just work out-of-the-box. Downside: compiled autograd is not on by default, so anyone using Float8 would need to turn on compiled autograd or risk wrong results.

I briefly tried Option 2 (compiled autograd), just to see if it would immediately work, by re-writing the mutation as a backward hook. The first issue I ran into is that dynamo only supports a limited set of hooks. It seems to support both, but I couldn't get either to work for this case.

My current thought is that:
(a) compiled autograd doesn't work out-of-the-box (it seems like it would be a decent amount of work to make this work properly, although I could be wrong), and
(b) ideally, vanilla AOT autograd would just work, so we don't risk someone trying to compile Float8 code without compiled autograd and hitting correctness issues.

So next I'm going to try to prototype AOTAutograd handling for buffer mutations in the backward.
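To make Option 2 concrete, here is a rough sketch of what rewriting the buffer mutation as a backward hook could look like. All names here (HookedAmax, update_amax, amax_buffer) are hypothetical, and, as noted above, dynamo's limited hook support means this does not currently work out of the box under compile.

```python
# Rough sketch of the Option 2 rewrite (hypothetical names): the buffer update
# moves out of a custom autograd.Function backward and into a tensor hook that
# fires during the backward pass. Compiled autograd would then be responsible
# for capturing the hook.
import torch
import torch.nn as nn


class HookedAmax(nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer("amax_buffer", torch.tensor(0.0))

    def forward(self, x):
        out = x * 2.0

        def update_amax(grad):
            # In-place buffer mutation, now expressed as a backward hook.
            self.amax_buffer.fill_(grad.abs().max())
            # Returning None leaves the gradient unchanged.

        if out.requires_grad:
            out.register_hook(update_amax)
        return out


m = HookedAmax()
x = torch.randn(4, requires_grad=True)
m(x).sum().backward()  # works in eager; the hook runs during the backward
print(m.amax_buffer)

# Capturing the hook in a compiled backward would additionally require
# enabling compiled autograd around the backward call (API may vary across
# nightlies), e.g. torch._dynamo.compiled_autograd.enable(...).
```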
Tentative minimal example code that I'm going to use to prototype the AOTAutograd changes:
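A rough sketch of what such a minimal example could look like (hypothetical names, not the original snippet): a custom autograd.Function that fill_s a module buffer in its backward, compared between eager and torch.compile.

```python
# Hypothetical minimal repro sketch: a module buffer is mutated with fill_
# inside autograd.Function.backward, and the result is compared between
# eager and torch.compile.
import torch
import torch.nn as nn


class ScaleAndRecordAmax(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, amax_buffer):
        ctx.amax_buffer = amax_buffer
        return x * 2.0

    @staticmethod
    def backward(ctx, grad_out):
        # The problematic pattern: an in-place mutation of module state inside
        # the backward, which AOTAutograd historically only handled in the forward.
        ctx.amax_buffer.fill_(grad_out.abs().max())
        return grad_out * 2.0, None


class M(nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer("amax_buffer", torch.tensor(0.0))

    def forward(self, x):
        return ScaleAndRecordAmax.apply(x, self.amax_buffer)


m = M()
x = torch.randn(4, requires_grad=True)

m(x).sum().backward()
eager_amax = m.amax_buffer.clone()

m.amax_buffer.zero_()
x2 = x.detach().clone().requires_grad_(True)
torch.compile(m, backend="aot_eager")(x2).sum().backward()

# Before pytorch/pytorch#115195 the buffer mutation in the backward could
# fail to be applied under compile; after it, both values should match.
print(eager_amax, m.amax_buffer)
```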
@bdhirsh: PR pytorch/pytorch#115195 has since landed in nightly. I have been able to verify that the mutation of the buffers in the backward is now working correctly under torch.compile.
Summary
I will be using this as a top-level tracker and link to sub-issues with smaller repros to tackle this problem.
PRs
#56 Brian has done some initial work getting subclasses to compile for fp8
Issues
Problem summaries
All the problems are based on this implementation of Float8Tensor.
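For orientation, here is a rough, simplified sketch of what a Float8Tensor-style wrapper subclass looks like; the names (SimpleFloat8Tensor, to_float8) and details are illustrative, not the repo's actual implementation.

```python
# Simplified, illustrative sketch of a Float8Tensor-style wrapper subclass
# (hypothetical names). The wrapper carries the raw fp8 payload plus a scale
# as inner tensors, and exposes the __tensor_flatten__/__tensor_unflatten__
# protocol that torch.compile uses to trace through tensor subclasses.
import torch
from torch.utils._pytree import tree_map


class SimpleFloat8Tensor(torch.Tensor):
    @staticmethod
    def __new__(cls, data, scale, orig_dtype):
        # The outer wrapper advertises the original high-precision dtype.
        return torch.Tensor._make_wrapper_subclass(
            cls, data.shape, dtype=orig_dtype, device=data.device
        )

    def __init__(self, data, scale, orig_dtype):
        self._data = data        # fp8 payload, e.g. torch.float8_e4m3fn
        self._scale = scale      # fp32 scale used to dequantize
        self._orig_dtype = orig_dtype

    def __tensor_flatten__(self):
        return ["_data", "_scale"], self._orig_dtype

    @staticmethod
    def __tensor_unflatten__(inner_tensors, ctx, outer_size, outer_stride):
        return SimpleFloat8Tensor(inner_tensors["_data"], inner_tensors["_scale"], ctx)

    def to_original_precision(self):
        return self._data.to(self._orig_dtype) / self._scale

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        # Minimal fallback: dequantize and run every op in high precision.
        def unwrap(t):
            return t.to_original_precision() if isinstance(t, SimpleFloat8Tensor) else t

        return func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs or {}))


def to_float8(hp_tensor, scale, float8_dtype=torch.float8_e4m3fn):
    # Construction via a module-level function rather than a classmethod
    # (see the graph-break note further down).
    data = (hp_tensor * scale).to(float8_dtype)
    return SimpleFloat8Tensor(data, scale, hp_tensor.dtype)
```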
#128
Add this repro script to surface compile issues: https://gist.github.com/drisspg/6e76d3d99dc932e2287f19123f6339d1
Backend = "eager"
Adding the following to the sourceless builder:
PR: pytorch/pytorch#112284
This cleans up both errors.
Graph Breaks
So, using TORCH_LOGS="graph_breaks" python ../scripts/fp8/eager_compile_debug.py, we were getting graph breaks whenever we tried to construct fp8 tensors with the class method. I found out that moving the construction to a function fixed the graph breaks, and now we have none for this script; see:
#131
Backend = "aot_eager"
With the fix to not have any graph breaks, we now get a more helpful error message:
I suspect this error is because, for the matmul, we output a regular tensor and not a tensor subclass, and then during the backward we have the autograd function that converts the gradient to the different fp8 format.
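To illustrate the pattern being described, here is a heavily simplified sketch (hypothetical names): the matmul's forward output is a plain tensor, and an autograd.Function that is a no-op in the forward casts the incoming gradient to the e5m2 fp8 format in the backward. The real code would wrap the gradient in the fp8 tensor subclass and apply a scale; that is omitted here.

```python
# Heavily simplified illustration (hypothetical names): the forward emits a
# regular high-precision tensor, and the gradient is cast through the e5m2
# fp8 format only in the backward.
import torch


class NoopFwCastGradToE5M2(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        return x  # no-op; the output is a plain tensor, not a subclass

    @staticmethod
    def backward(ctx, grad_out):
        # Gradients use a different fp8 format (e5m2) than activations/weights.
        return grad_out.to(torch.float8_e5m2).to(grad_out.dtype)


a = torch.randn(8, 16, requires_grad=True)
b = torch.randn(16, 4)
out = NoopFwCastGradToE5M2.apply(torch.mm(a, b))
out.sum().backward()
```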
Backend = "inductor"
With the tangle of PRs and changes, and by not running the backward on the subclass linear, I can actually compile with inductor!
However, it fails when the "high_precision" dtype is not float32. I suspect this is because we are storing amax in fp32 (needed for scaled_mm), and the inductor scatter produces the following error (the update pattern is sketched at the end of this section):
Old error:
When attempting to compile with "aot_eager" using the above two fixes, we get:
UPDATE: I was able to trigger a more helpful error message by iterating through the fake modes of the inner tensors:
https://gist.github.com/drisspg/ed916d144e819d7eb0be6728e0e807a7
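For context on the dtype mismatch described above, here is a rough, hypothetical sketch of the kind of amax-history update involved: the activation runs in bf16 while the amax state is kept in fp32 (as the text notes, scaled_mm needs fp32 amax), so the indexed write into the history buffer carries an implicit cast and shows up as a scatter once the graph is functionalized for inductor.

```python
# Rough, hypothetical sketch of a delayed-scaling amax-history update with
# mixed dtypes: bf16 activation, fp32 history buffer.
import torch

amax_history = torch.zeros(16, dtype=torch.float32)


@torch.compile
def update_amax_history(x: torch.Tensor) -> torch.Tensor:
    # Shift the history by one slot and record the new amax in slot 0.
    amax_history.copy_(torch.roll(amax_history, 1))
    # x is bf16 while the buffer is fp32, so this indexed write carries an
    # implicit dtype cast; this is the kind of scatter suspected above.
    amax_history[0] = x.abs().max()
    return amax_history


update_amax_history(torch.randn(32, 32, dtype=torch.bfloat16))
```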