This repository has been archived by the owner on Aug 7, 2024. It is now read-only.

Torch Compile with Float8Linear #106

Closed
drisspg opened this issue Sep 22, 2023 · 7 comments
Labels
Compile Issues related with subclass compilation

Comments

@drisspg
Contributor

drisspg commented Sep 22, 2023

Summary

I will be using this as a top-level tracker and will link to sub-issues with smaller repros to tackle this problem.

PRs

#56 Brian has done some initial work getting subclasses to compile for fp8

Issues

Problem summaries

All the problems are based on this implementation of Float8Tensor:
#128

I added this repro script to surface compile issues: https://gist.github.com/drisspg/6e76d3d99dc932e2287f19123f6339d1

Backend = "eager"

  1. When attempting to compile Float8Linear with the "eager" backend, we currently fail during automatic_dynamic_dims creation. For a more detailed analysis and a potential fix, see: Allow traceable_subclass_tensors to have multiple dynamic tensor attributes pytorch/pytorch#112185
  2. After the above PR there are no hard errors, but compiling with backend = "eager" gives the following two warnings:
[2023-10-27 09:36:36,705] [2/1] torch._dynamo.variables.higher_order_ops: [WARNING] speculate_subgraph: while introspecting the user-defined autograd.Function, we were unable to trace function `trampoline_autograd_fwd` into a single graph. This means that Dynamo was unable to prove safety for this API and will fall back to eager-mode PyTorch, which could lead to a slowdown.
[2023-10-27 09:36:36,705] [2/1] torch._dynamo.variables.higher_order_ops: [ERROR] Unexpected type in sourceless builder <class 'torch.dtype'>

Adding the following to the sourceless builder:

    elif isinstance(value, torch.dtype):
        return ConstantVariable.create(value)

PR pytorch/pytorch#112284 cleans up both errors.

Graph Breaks
Using TORCH_LOGS="graph_breaks" python ../scripts/fp8/eager_compile_debug.py, we were hitting graph breaks whenever we tried to construct fp8 tensors with the class method. I found that moving the construction to a standalone function fixed the graph breaks, and we now have none for this script; see:
#131
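For illustration, a hypothetical sketch of the kind of refactor described above (names are made up and this is not the actual change in #131): construction moves from a classmethod on the subclass to a standalone function that dynamo can trace without a graph break.

import torch

def to_fp8_standalone(tensor, scale, float8_dtype=torch.float8_e4m3fn):
    # Standalone construction helper: scale and cast in one traceable function.
    # The real code additionally wraps the result in the Float8Tensor subclass
    # together with the scale and the original dtype.
    return (tensor * scale).to(float8_dtype)

# Previously (graph-breaking, hypothetical name):
#   fp8_tensor = Float8Tensor.from_float32(tensor, scale)
# Now:
#   fp8_bits = to_fp8_standalone(tensor, scale)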

Backend = "aot_eager"

With the fix removing all graph breaks, we now get a more helpful error message:

Traceback (most recent call last):
  File "/home/drisspg/meta/float8_experimental/../scripts/fp8/eager_compile_debug.py", line 41, in <module>
    main()
  File "/home/drisspg/meta/float8_experimental/../scripts/fp8/eager_compile_debug.py", line 31, in main
    y_fp8.sum().backward()
  File "/home/drisspg/meta/pytorch/torch/_tensor.py", line 503, in backward
    torch.autograd.backward(
  File "/home/drisspg/meta/pytorch/torch/autograd/__init__.py", line 254, in backward
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
  File "/home/drisspg/meta/pytorch/torch/autograd/function.py", line 289, in apply
    return user_fn(self, *args)
  File "/home/drisspg/meta/pytorch/torch/_functorch/aot_autograd.py", line 4009, in backward
    assert grad_output_types_ == CompiledFunction.metadata.output_types, f"""\
AssertionError: We incorrectly attempted to compile the backward with incorrect subclass metadata.
If you run into this error, please file an issue.
Expected grad_output types: [<class 'float8_experimental.float8_tensor.Float8Tensor'>]
Got grad_output types: [<class 'torch.Tensor'>]

I suspect this error occurs because the matmul outputs a regular tensor rather than a tensor subclass, and then during the backward pass the autograd function converts the gradient to a different fp8 format.

Backend = "inductor"

With the tangle of PRs and changes, and by not running backward on the subclass linear, I can actually compile with inductor!

However, it fails when the "high_precision" dtype is not float32. I suspect this is because we are storing amax in fp32 (needed for scaled_mm), and the inductor select_scatter lowering produces the following error (a minimal sketch of the pattern follows the traceback below):

  File "/home/drisspg/meta/pytorch/torch/_inductor/lowering.py", line 289, in wrapped
    out = decomp_fn(*args, **kwargs)
  File "/home/drisspg/meta/pytorch/torch/_inductor/lowering.py", line 2219, in select_scatter
    assert x.get_dtype() == src.get_dtype()
torch._dynamo.exc.BackendCompilerFailed: backend='inductor' raised:
LoweringException: AssertionError:
  target: aten.select_scatter.default
  args[0]: TensorBox(StorageBox(
    InputBuffer(name='primals_3', layout=FixedLayout('cuda', torch.float32, size=[16], stride=[1]))
  ))
  args[1]: TensorBox(StorageBox(
    ComputedBuffer(name='buf0', layout=FlexibleLayout('cuda', torch.bfloat16, size=[], stride=[]), data=Reduction(
      'cuda',
      torch.bfloat16,
      def inner_fn(index, rindex):
          r0, r1 = rindex
          tmp0 = ops.load(primals_1, r1 + 16 * r0)
          tmp1 = ops.abs(tmp0)
          return tmp1
      ,
      ranges=[],
      reduction_ranges=[16, 16],
      reduction_type=max,
      origin_node=max_1,
      origins={max_1, abs_1}
    ))
  ))
  args[2]: 0
  args[3]: 0
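For reference, here is a minimal, hypothetical sketch of the pattern I believe triggers this (an assumption about the root cause, not the actual Float8Linear code; behavior may differ on other builds): a float32 amax buffer updated in place from the max of a bfloat16 tensor, which functionalization turns into a select_scatter whose destination and source dtypes disagree.

import torch

def update_amax(x, amax_history):
    # amax_history is float32 (as needed for scaled_mm scales), but x.abs().max()
    # inherits x's bfloat16 dtype, so the select_scatter that this indexing
    # lowers to sees mismatched dtypes.
    amax_history[0] = x.abs().max()
    return amax_history

amax_history = torch.zeros(16, device="cuda")  # float32 buffer
x = torch.randn(16, 16, device="cuda", dtype=torch.bfloat16)
torch.compile(update_amax, backend="inductor")(x, amax_history)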
Old error:

When attempting to compile with "aot_eager" with the above two fixes, we got:

  File "/home/drisspg/meta/pytorch/torch/_functorch/aot_autograd.py", line 4230, in <listcomp>
    return [convert(idx, x) for idx, x in enumerate(flat_args)]
  File "/home/drisspg/meta/pytorch/torch/_functorch/aot_autograd.py", line 4219, in convert
    assert all(getattr(x, attr).fake_mode is fake_mode for attr in attrs)
torch._dynamo.exc.BackendCompilerFailed: backend='aot_eager' raised:
AssertionError:

UPDATE: I was able to trigger a more helpful error message by iterating through the fake modes of the inner tensors:
https://gist.github.com/drisspg/ed916d144e819d7eb0be6728e0e807a7
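Roughly, the kind of debug loop used there (a hypothetical sketch based on the assert above, not the actual contents of the gist):

def report_fake_modes(subclass_tensor, expected_fake_mode):
    # Walk the inner tensors of a traceable tensor subclass and report which one
    # carries a FakeTensorMode that differs from the expected one, instead of
    # failing a single combined assert.
    attrs, _ = type(subclass_tensor).__tensor_flatten__(subclass_tensor)
    for attr in attrs:
        inner = getattr(subclass_tensor, attr)
        if inner.fake_mode is not expected_fake_mode:
            print(f"{attr}: expected {expected_fake_mode}, got {inner.fake_mode}")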

@drisspg drisspg added the Compile Issues related with subclass compilation label Sep 22, 2023
facebook-github-bot pushed a commit that referenced this issue Oct 27, 2023
Summary:
For more detailed understanding of status see:
#106

But this removes all graph breaks on the main work branch

Pull Request resolved: #131

Reviewed By: albanD

Differential Revision: D50758815

Pulled By: drisspg

fbshipit-source-id: 1502601099988b1eba666306e327eb724eb14989
@drisspg
Contributor Author

drisspg commented Oct 30, 2023

3 problems:

  • Dynamic dims issue; see PR above, with added tree structure
  • Custom functions and methods for the subclass, e.g. to_float8. The autograd function takes in a plain tensor and returns a subclass tensor; we are not doing this. to_float8 and from_float8 should be methods on the subclass, and we should teach dynamo to create proxies for methods on subclasses in the subgraphs. COME BACK TO THIS. If we run into problems, follow dtensor.
  • According to aot_autograd, the output of the forward is a Float8Tensor, and we don't think it should be.

@drisspg
Contributor Author

drisspg commented Oct 31, 2023

Latest

Okay, so @bdhirsh is back for less than a day and we pretty much have everything working (albeit not landed yet), which is awesome!

Float8Changes:

Core changes:

Numbers:

Better than no tensor subclass! (likely because I worked on removing graph breaks, which could also be done for the no-tensor-subclass path)

        name                shape       ref_dtype  ...  te_fp8_time_sec  pt_fp8_speedup  te_fp8_speedup
0  attn.wqkv  (16384, 8192, 1280)  torch.bfloat16  ...         0.002041        0.938752        1.040107
1    attn.w0  (16384, 1024, 8192)  torch.bfloat16  ...         0.001765        1.082501        1.076772
2    ffn.w13  (16384, 8192, 7168)  torch.bfloat16  ...         0.006432        1.476127        1.537317
3     ffn.w2  (16384, 3584, 8192)  torch.bfloat16  ...         0.003637        1.390212        1.432203
4  attn.wqkv  (16384, 8192, 1280)   torch.float16  ...         0.002005        0.990544        1.112229
5    attn.w0  (16384, 1024, 8192)   torch.float16  ...         0.001718        1.124236        1.125813
6    ffn.w13  (16384, 8192, 7168)   torch.float16  ...         0.006646        1.578117        1.569713
7     ffn.w2  (16384, 3584, 8192)   torch.float16  ...         0.003713        1.464948        1.474404

No Tensor Subclass

        name                shape       ref_dtype  ...  te_fp8_time_sec  pt_fp8_speedup  te_fp8_speedup
0  attn.wqkv  (16384, 8192, 1280)  torch.bfloat16  ...         0.001976        0.923693        1.077280
1    attn.w0  (16384, 1024, 8192)  torch.bfloat16  ...         0.001773        1.062716        1.073277
2    ffn.w13  (16384, 8192, 7168)  torch.bfloat16  ...         0.006451        1.437040        1.528630
3     ffn.w2  (16384, 3584, 8192)  torch.bfloat16  ...         0.003633        1.373209        1.441463
4  attn.wqkv  (16384, 8192, 1280)   torch.float16  ...         0.001984        0.945713        1.129551
5    attn.w0  (16384, 1024, 8192)   torch.float16  ...         0.001643        1.116112        1.186204
6    ffn.w13  (16384, 8192, 7168)   torch.float16  ...         0.006483        1.509603        1.604298
7     ffn.w2  (16384, 3584, 8192)   torch.float16  ...         0.003706        1.445889        1.476486

@drisspg
Contributor Author

drisspg commented Nov 3, 2023

Silent Correctness Issue

I first discovered this when I tried to re-train llama7b on a single node using torch.compile. Although performance improved significantly, training failed to converge.

I created this sample script to explore the cause:
https://gist.github.com/drisspg/f8a37121a67d9c08500bf678af298554

When running with backend = "aot_eager", the following differences can be observed in the state_dict:
[Screenshot: comparison of the two state_dicts, 2023-11-02]

Left is compile, right is eager

The fp8_amax_dL_dY buffer, and by extension the fp8_scale_dL_dY buffer, were not getting updated when the module was compiled with aot_eager.

I can't be entirely sure that this is the same problem affecting the llama training run, but it is very likely.

Why is this happening

We currently use a "no-op forward" torch.autograd.Function to handle converting the high-precision backpropagating gradient to the e5m2 format needed for the backward matmul compute. The autograd function can be found here.

The offending line, which I do not think behaves the same between eager and compile, is this fill_.

I think that the majority of (all?) backwards are expected to be functional, and this likely breaks an assumption somewhere.
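For context, a simplified sketch of what such a "no-op forward" function looks like (a hypothetical reconstruction, not the actual float8_experimental implementation, which returns a Float8Tensor subclass rather than a plain tensor; the float8 cast may also require a recent build and CUDA):

import torch

class NoopFwCastBwSketch(torch.autograd.Function):
    # Identity in the forward; in the backward, record the gradient's amax into
    # a module buffer (the in-place fill_ discussed above) and cast the gradient
    # to the e5m2 format used for the backward matmul.
    @staticmethod
    def forward(ctx, tensor, fp8_amax_dL_dY, fp8_scale_dL_dY):
        ctx.save_for_backward(fp8_amax_dL_dY, fp8_scale_dL_dY)
        return tensor

    @staticmethod
    def backward(ctx, grad_output):
        fp8_amax_dL_dY, fp8_scale_dL_dY = ctx.saved_tensors
        # This is the buffer mutation that is silently dropped under aot_eager.
        fp8_amax_dL_dY.fill_(grad_output.abs().max())
        grad_e5m2 = (grad_output * fp8_scale_dL_dY).to(torch.float8_e5m2)
        # The real code wraps grad_e5m2 in a Float8Tensor; casting back keeps
        # this sketch self-contained.
        return grad_e5m2.to(grad_output.dtype), None, None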

@voznesenskym

@bdhirsh for more inplace things.

@bdhirsh
Contributor

bdhirsh commented Nov 3, 2023

The problem is that we are mutating some global state (a module buffer) inside of the backward. Today, AOTAutograd can handle input/buffer mutations, but only from the forward.

I think that there are roughly two options (that are both some amount of work):

Option 1: Make AOTAutograd smarter. Detect when the backward graph contains input mutations, and include them in the graph as copy_() nodes for inductor to optimize.
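As a rough illustration of what Option 1 would produce (hypothetical, not actual AOTAutograd output): the functionalized backward graph would re-apply the buffer mutation as an explicit copy_() node at the end, which inductor can then optimize back into an in-place write.

import torch

def functionalized_backward(grad_output, buf):
    # Functional form of the user's `buf.add_(grad_output)` ...
    new_buf = torch.add(buf, grad_output)
    # ... with the mutation re-applied as an explicit copy_() node that the
    # backend sees inside the graph.
    buf.copy_(new_buf)
    return grad_output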

Option 2: Re-write the mutation using backward hooks, and only support this case in compiled autograd. Upside: might just work out-of-the-box. Downside: compiled-autograd is not on by default, so anyone using Float8 would need to turn on compiled autograd or risk wrong results.

I briefly tried Option 2 (compiled autograd), just to see if it would immediately work, by re-writing NoopFwToFloat8E5M2Bw as a backward hook (that mutates the module buffer), and running compiled autograd.

The first issue I ran into is that dynamo only supports a limited set of hooks (code). It seems to support both UserFunctionVariable and FuncToolsPartialVariable, so ideally we could make our hook one of those. From some playing around, I mostly hit issues due to the fact that the hook tries to mutate nn.Module state.

(a) We can't make it a UserFunctionVariable, because that would require defining the hook outside of Float8Linear. We could do this, but the buffer we want to mutate is a piece of state on the Float8Linear module, not a global value. We also can't pass the buffer as an input to the hook, since hooks are expected to have a particular schema (they take in a single argument corresponding to the current gradient). This could be solved by using functools.partial.

(b) We can't make it a FuncToolsPartialVariable, because dynamo expects all kwargs on the functools.partial call to be python constants (code).
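For reference, a hypothetical sketch of the hook rewrite being discussed (not the actual experiment): the backward mutation moves into a tensor hook, with the module buffer bound via functools.partial, which is exactly the argument that dynamo currently requires to be a python constant.

import functools
import torch

def amax_backward_hook(grad, *, amax_buffer):
    # Runs during the backward pass; records the gradient's amax into the
    # module buffer instead of doing it inside an autograd.Function.
    amax_buffer.fill_(grad.abs().max())
    return grad

# Inside Float8Linear.forward, the output would get something like:
#   out.register_hook(
#       functools.partial(amax_backward_hook, amax_buffer=self.fp8_amax_dL_dY))
# FuncToolsPartialVariable rejects this because amax_buffer is a tensor, not a
# python constant.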

My current thought is that:

(a) Compiled autograd doesn't work out of the box (it seems like it would be a decent amount of work to make it work properly, although I could be wrong).

(b) Ideally, vanilla aot autograd would just work, so we don't risk someone trying to compile Float8 code without compiled autograd and hitting correctness issues.

So next I'm going to try to prototype AOTAutograd handling for buffer mutations in the backward.

@bdhirsh
Contributor

bdhirsh commented Nov 3, 2023

Tentative minimal example code that I'm going to use to prototype the AOTAutograd changes:

import torch

class MutatingAutogradFn(torch.autograd.Function):

    @staticmethod
    def forward(ctx, x, buf):
        ctx.save_for_backward(buf)
        return x

    @staticmethod
    def backward(ctx, x_grad):
        buf = ctx.saved_tensors[0]
        buf.add_(x_grad)
        return x_grad, None

class Mod(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.buf = torch.ones(2)

    @torch._dynamo.allow_in_graph
    def backward_mutating_fn(self, x, buf):
        return MutatingAutogradFn.apply(x, buf)

    def forward(self, x):
        tmp = self.backward_mutating_fn(x, self.buf)
        return tmp * self.buf

m = Mod()
m_compiled = torch.compile(m, backend="aot_eager")

x = torch.ones(2, requires_grad=True)
out = m_compiled(x)
# at the end of the fw, buf has not been mutated yet
print(m.buf)
out.sum().backward()
# we ran the backward, so buf should be mutated
print(m.buf)

@drisspg
Contributor Author

drisspg commented Dec 7, 2023

@bdhirsh PR pytorch/pytorch#115195 has since landed in nightly. I have been able to verify that the mutation of the buffers in the backward is now working correctly under torch.compile.

@drisspg drisspg closed this as completed Dec 7, 2023