This repository has been archived by the owner on Aug 7, 2024. It is now read-only.

[poc] scratch work on getting torch_dispatch subclass to play well wi… #56

Closed
wants to merge 102 commits

Conversation

bdhirsh (Contributor) commented Aug 25, 2023

Not for landing

This is just some scratch work I did today trying to get Float8Tensor (as a torch_dispatch subclass) to play well with AOTAutograd.

AOTAutograd bits

I mostly got things working, with one annoying issue: the custom fp8 linear as written today returns a Float8Tensor output in the forward. However, when we call .sum().backward(), the sum call returns a plain torch.Tensor, so the grad_out of the compiled function (in the backward) is a plain tensor. This is a problem, because:

(1) AOTAutograd eagerly traces out a backward graph during the forward in order to do partitioning;

(2) it therefore needs to guess whether or not each grad_out will be a subclass;

(3) right now, when it sees a fw_out that is a subclass, it assumes that the corresponding grad_out will also be a subclass;

(4) if we guess wrong, we need to re-trace the backward, which requires backward guards that are not implemented yet.

I had a workaround hack locally, where I just special-case FP8 in aot autograd to always assume that grad_outputs are not subclasses. That seemed to work, and I was able to use torch.compile and generate fw and bw graphs.

Dynamo bits

However, that isn't enough to test correctness. Why? The custom module in the test has a flag that gets flipped after the first iteration, which dynamo has to specialize the graph on.

I made a mild attempt to get things working in dynamo - seemed to get part of the way through, but eventually I gave up. Some progress here: 814239133

vkuzo and others added 30 commits July 19, 2023 15:32
Summary:

This is a copy of
facebookexperimental/protoquant#23

Many things will change based on recent discussions!

Test Plan:

```
python float8_playground/test.py
```

Python-only float8 data type + bare bones UEX
Summary:

Skipped this before, going back to it now.

This will be useful for the transformer block.

Test Plan:

```
python float8_playground/test.py
```

add bias support to float8linear
Summary:

Forgot this in the last PR.

Summary:

Now that pytorch/pytorch#104242 landed, we can
stop emulation - this simplifies the code quite a bit.

Test Plan:

```
python float8_playground/test.py
```

switch from emulated to real float8 dtypes
Summary:

Adds a check that using `Float8Linear` on a `SAM` encoder
results in reasonable spot check accuracy.

Note that grad accuracy is not tested because it is all over the place; this is probably expected, but we are saving the investigation for later. Specifically, the layernorm and positional encoding grads have a large error in the fp8 version vs the reference.

Test Plan:

```
python float8_playground/test.py
with-proxy python float8_playground/test_sam.py
```

add numerical test on SAM encoder
Summary:

Don't test cpu for now to maximize dev speed, we should add that back
later.

Requires pytorch/pytorch#105807

Test Plan:

```
python float8_playground/test.py
python float8_playground/test_sam.py
```

Summary:

Having a copy of this in a script (vs a notebook) is useful.

Test Plan:

```
python te_examples/quickstart_guide.py
```

add a repro of TransformerEngine's quickstart guide
Summary:

Creates a simple image classification network and finetunes it
on MNIST.  Baseline is fp32, and fp8 training can be enabled with
a flag.  Verified that fp8 training converges on this simple example.
Note that fp8 compute is emulated for now as we don't have a hookup
to the real fp8 matmul kernel yet.

Test Plan:

```
with-proxy python finetune/mnist.py --batch-size 4096
// https://gist.github.com/vkuzo/0e8cbb3df1f0610e528ac3ad15da3ace
with-proxy python finetune/mnist.py --batch-size 4096 --use-pt-fp
// https://gist.github.com/vkuzo/99b0cf2c1492a5f605c9f028f12340c3
```

add simple finetuning check with fp8 emulation
Summary:

Adds a test for numerical equivalence of single GPU vs FSDP for a toy
model.

Note: this is not related to fp8 yet, a future PR will add a test that
this still holds for fp8.

Test Plan:

```
./float8_playground/test_fsdp.sh
```

Summary:

Test Plan:

```
python tests/test.py
with-proxy python tests/test_sam.py
./tests/test_fsdp.sh
```

refactor tests into separate dir
Summary:

Note that we can't match the weight gradient and its scale
because of gradient reduction.

Test Plan:

```
./tests/test_fsdp.sh
```

Summary:

Test Plan:

```
with-proxy ./tests/test_everything.sh
```

clean up float8 differentiable constructor
Summary:

before: Float8Tensor always converted back to float32
after: Float8Tensor remembers original dtype

this will be useful for autocast support

Test Plan:

```
with-proxy ./tests/test_everything.sh
```

Make Float8Tensor remember original precision
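A rough sketch of the idea, using an emulated stand-in class rather than the repo's actual `Float8Tensor` (all names here are assumptions): record the dtype the tensor was created from, and restore it on conversion back instead of always producing float32.

```python
import torch

class Float8ish:
    def __init__(self, data, scale, orig_dtype):
        self._data = data              # emulated fp8 payload
        self._scale = scale
        self._orig_dtype = orig_dtype  # new: remember the original precision

    @classmethod
    def from_float(cls, t, scale):
        # emulated cast; the real code quantizes to an fp8 dtype here
        return cls(t.float() * scale, scale, t.dtype)

    def to_original(self):
        # before this change: always returned float32
        # after: restore whatever dtype the tensor was created from
        return (self._data / self._scale).to(self._orig_dtype)

t = torch.randn(4, dtype=torch.bfloat16)
f8 = Float8ish.from_float(t, 2.0)
restored = f8.to_original()  # bfloat16 again, not float32
```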
Summary:

Adding grads and casting back to fp8 doesn't have a clear use case yet; we can add this back if needed. For now, simplify.

Test Plan:

```
with-proxy ./tests/test_everything.sh
```

make grad addition happen in original precision
Summary:

We just duplicate autocast logic for F.linear to have Float8Linear
do the right thing.

Test Plan:

```
with-proxy ./tests/test_everything.sh
```

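A sketch of what "duplicating the autocast logic" can look like (the helper name is assumed; the real `Float8Linear` also handles weights and device-specific dtypes): when autocast is active, cast the input to the autocast dtype before the fp8 path, since a custom linear bypasses the built-in F.linear rule.

```python
import torch

def maybe_autocast_input(t):
    # mirror of F.linear's autocast rule (sketch, helper name assumed):
    # when autocast is active, cast to the autocast dtype before the fp8
    # cast; otherwise pass through unchanged.
    if torch.is_autocast_enabled():
        return t.to(torch.get_autocast_gpu_dtype())
    return t

x = torch.randn(2, 3)
y = maybe_autocast_input(x)  # no autocast context active here -> unchanged
```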
vkuzo and others added 26 commits August 18, 2023 15:37
fix tests for real scaled matmul
Summary:

1. get the LLaMa 70B single GPU shapes
2. benchmark fp16 matmul vs float8 matmul for those shapes (for now just
the forward)
3. display the speedup and pct peak tops

Test Plan:

```
name       shape                  bf16_time_s    fp8_fw_time_s   fp8_fw_sp
---------  -------------------  -------------  ---------------  ----------
attn.wqkv  (16384, 8192, 1280)       6.04E-04         2.43E-04        2.48
attn.w0    (16384, 1024, 8192)       5.00E-04         2.51E-04        1.99
ffn.w13    (16384, 8192, 7168)       3.57E-03         1.39E-03        2.56
ffn.w2     (16384, 3584, 8192)       1.59E-03         7.45E-04        2.13

// full logs: https://gist.github.com/vkuzo/ce31a3fef86816036441b98be060849a
```

add simple benchmark for single GPU LLaMa 70B shapes
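The harness amounts to timing the same matmul shape under both dtypes and reporting the ratio; a minimal CPU-sized sketch (shapes and helper names are assumptions, not the benchmark script's code):

```python
import time
import torch

def bench(fn, iters=10):
    fn()  # warmup
    start = time.perf_counter()
    for _ in range(iters):
        fn()
    return (time.perf_counter() - start) / iters

# tiny stand-in shape; the real script uses e.g. (16384, 8192, 1280)
M, K, N = 64, 32, 48
a, b = torch.randn(M, K), torch.randn(K, N)
ref_time = bench(lambda: a @ b)

def speedup(ref_s, new_s):
    # the fp8_fw_sp column: reference time divided by fp8 forward time
    return ref_s / new_s
```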
Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:
adds requirements.txt for easier checkout on other machines
Summary:

Adds the right skip logic to skip/emulate when we are not on an H100.

Test Plan:

```
with-proxy ./tests/test_everything.sh
```

Summary:

Passing a user defined object such as `DelayedScaling` to
`torch.autograd.Function` is not traceable by PT2.0 yet.

For now, just pass the underlying value instead since that
is traceable.

Test Plan:

```
TORCH_LOGS="aot,dynamo" pytest tests/test.py -k pt2_nots -s
// before this PR, got the error below
[2023-08-15 08:58:22,657] torch._dynamo.variables.higher_order_ops: [ERROR] HigherOrderOperator with body that accepts non-Tensors as input
// after this PR, the error is no longer shown
```

pt2 compatibility: stop passing DelayedScaling around
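The workaround looks roughly like this (a tiny dataclass stands in for the real `DelayedScaling` config): pass the underlying scalar into the `autograd.Function` rather than the object, since PT2 can trace scalars but not arbitrary user-defined objects.

```python
import torch
from dataclasses import dataclass

@dataclass
class DelayedScalingish:   # stand-in for the real config object
    history_len: int = 16

class Scale(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, history_len):  # plain int, not the config object
        return x * history_len

    @staticmethod
    def backward(ctx, g):
        return g, None  # no grad for the non-tensor argument

cfg = DelayedScalingish()
out = Scale.apply(torch.ones(2), cfg.history_len)  # pass the value, not cfg
```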
Summary:

Apparently having `torch.dtype` as an argument to
`torch.autograd.Function` does not work yet.

Note: I was not able to reproduce this with a standalone repro,
although I did not try too hard.

Test Plan:

```
TORCH_LOGS="dynamo" pytest tests/test.py -k pt2_nots -s
```

work around pt2 error of unexpected type in sourceless builder
Summary:

We don't need two buffers to track init state; one is enough.

This will simplify making this control flow traceable in a future PR.

Test Plan:

```
with-proxy ./tests/test_everything.sh
```

Summary:

This PR moves the init logic from PyTorch tensors to a Python
class attribute.

After this, we can enable `fullgraph=True` with eager mode
PT2.0 tracing without tensor subclass.

Test Plan:

```
with-proxy ./tests/test_everything.sh
```

switch init logic to Python, make nots pt2.0 traceable
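A sketch of the shape of this change (class and attribute names are assumptions): the init state lives in a plain Python attribute instead of a tensor buffer, so dynamo can treat it as a constant rather than tracing data-dependent control flow.

```python
import torch

class ScaledLinearish(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # before: self.register_buffer("fp8_initialized", torch.zeros(1))
        self.fp8_initialized = False  # after: plain Python bool

    def forward(self, x):
        # a Python bool branch, not a tensor branch -> fullgraph-friendly
        if not self.fp8_initialized:
            self.fp8_initialized = True
            # (first-iteration scale init would happen here)
        return x

m = ScaledLinearish()
out = m(torch.ones(1))
```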
Summary:

Ensures tests pass after #39; we should fix them properly in a future PR.

facebook-github-bot added the CLA Signed label Aug 25, 2023
@drisspg drisspg closed this Nov 16, 2023