[poc] scratch work on getting torch_dispatch subclass to play well wi… #56
Closed
Python-only float8 data type + bare bones UEX
Summary: This is a copy of facebookexperimental/protoquant#23. Many things will change based on recent discussions!
Test Plan:
```
python float8_playground/test.py
```
add bias support to float8linear
Summary: Skipped it before, going back to it now; this will be useful for the transformer block.
Test Plan:
```
python float8_playground/test.py
```
fix naming nit
Summary: forgot on last PR
switch from emulated to real float8 dtypes
Summary: Now that pytorch/pytorch#104242 landed, we can stop emulation - this simplifies the code quite a bit.
Test Plan:
```
python float8_playground/test.py
```
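The switch away from emulation rests on the native float8 dtypes landed in pytorch/pytorch#104242. A minimal round-trip sketch (not code from this repo), assuming a PyTorch build that exposes `torch.float8_e4m3fn`:

```python
# Minimal sketch of the native float8 dtypes that replace emulation: values are
# now stored in 1-byte torch.float8_e4m3fn tensors instead of clamped
# higher-precision tensors.
import torch

x = torch.randn(16, 16, dtype=torch.float32)

# scale into the representable range, cast down, then cast back up
scale = torch.finfo(torch.float8_e4m3fn).max / x.abs().max()
x_fp8 = (x * scale).to(torch.float8_e4m3fn)   # real float8 storage
x_back = x_fp8.to(torch.float32) / scale      # dequantize for inspection

print(x_fp8.dtype, x_fp8.element_size())      # torch.float8_e4m3fn 1
print((x - x_back).abs().max())               # small quantization error
```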
add numerical test on SAM encoder
Summary: Adds a check that using `Float8Linear` on a `SAM` encoder results in reasonable spot check accuracy. Note that grad accuracy is not tested, as it is all over the place; this is probably expected, but the investigation is saved for later. Specifically, the layernorm and positional encoding grads have a large error for the fp8 version vs the reference.
Test Plan:
```
python float8_playground/test.py
with-proxy python float8_playground/test_sam.py
```
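A spot check like the one described above is commonly expressed as an SQNR-style comparison between the reference and fp8 outputs. A hedged sketch, where `ref_model`, `fp8_model`, and the 20 dB threshold are illustrative stand-ins rather than the repo's actual test:

```python
# Hedged sketch of the spot check described above; names and the threshold are
# illustrative, not the repo's test.
import torch

def sqnr_db(ref: torch.Tensor, out: torch.Tensor) -> torch.Tensor:
    # signal-to-quantization-noise ratio in dB; higher means a closer match
    return 20 * torch.log10(ref.norm() / (ref - out).norm())

@torch.no_grad()
def spot_check(ref_model, fp8_model, x, threshold_db: float = 20.0):
    assert sqnr_db(ref_model(x), fp8_model(x)) > threshold_db, \
        "fp8 output too far from the reference"
```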
moves the repo to cuda
Summary: Don't test cpu for now to maximize dev speed, we should add that back later. Requires pytorch/pytorch#105807
Test Plan:
```
python float8_playground/test.py
python float8_playground/test_sam.py
```
add a repro of TransformerEngine's quickstart guide
Summary: Just having a copy of this in a script vs notebook is useful
Test Plan:
```
python te_examples/quickstart_guide.py
```
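For context, TransformerEngine's public quickstart follows roughly the shape below; exact recipe arguments vary across TE versions, so treat this as an approximation rather than the script added in this commit:

```python
# Approximate shape of TransformerEngine's quickstart (recipe arguments may
# differ across TE versions).
import torch
import transformer_engine.pytorch as te
from transformer_engine.common import recipe

model = te.Linear(768, 768, bias=True).cuda()
inp = torch.randn(64, 768, device="cuda")

fp8_recipe = recipe.DelayedScaling(fp8_format=recipe.Format.HYBRID)
with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
    out = model(inp)
out.sum().backward()
```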
add simple finetuning check with fp8 emulation
Summary: Creates a simple image classification network and finetunes it on MNIST. Baseline is fp32, and fp8 training can be enabled with a flag. Verified that fp8 training converges on this simple example. Note that fp8 compute is emulated for now as we don't have a hookup to the real fp8 matmul kernel yet.
Test Plan:
```
with-proxy python finetune/mnist.py --batch-size 4096
// https://gist.github.com/vkuzo/0e8cbb3df1f0610e528ac3ad15da3ace
with-proxy python finetune/mnist.py --batch-size 4096 --use-pt-fp
// https://gist.github.com/vkuzo/99b0cf2c1492a5f605c9f028f12340c3
```
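A typical way to wire an fp8-training flag like the one above is to swap `nn.Linear` modules for the repo's `Float8Linear`. The sketch below is illustrative; the `from_float` constructor name is an assumption, not necessarily the exact API in `finetune/mnist.py`:

```python
# Illustrative sketch of how an fp8 flag like --use-pt-fp is typically wired;
# the `from_float` constructor name on Float8Linear is an assumption.
import torch.nn as nn

def swap_linear_with_float8(model: nn.Module, float8_linear_cls) -> nn.Module:
    # recursively replace every nn.Linear with its float8 counterpart
    for name, child in model.named_children():
        if isinstance(child, nn.Linear):
            setattr(model, name, float8_linear_cls.from_float(child))
        else:
            swap_linear_with_float8(child, float8_linear_cls)
    return model

# usage sketch:
# if args.use_pt_fp:
#     model = swap_linear_with_float8(model, Float8Linear)
```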
add numerical test for FSDP
Summary: Adds a test for numerical equivalence of single GPU vs FSDP for a toy model. Note: this is not related to fp8 yet, a future PR will add a test that this still holds for fp8.
Test Plan:
```
./float8_playground/test_fsdp.sh
```
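A minimal version of the single-GPU vs FSDP equivalence check might look like the following sketch with a toy fp32 model (not the repo's `test_fsdp.sh` harness), launched with e.g. `torchrun --nproc_per_node=2`:

```python
# Minimal sketch of a single-GPU vs FSDP forward equivalence check for a toy
# fp32 model; launch with: torchrun --nproc_per_node=2 this_file.py
import copy
import torch
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

def main():
    dist.init_process_group("nccl")
    rank = dist.get_rank()
    torch.cuda.set_device(rank)

    # same seed on every rank -> identical reference weights everywhere
    torch.manual_seed(0)
    ref = torch.nn.Sequential(
        torch.nn.Linear(64, 64), torch.nn.ReLU(), torch.nn.Linear(64, 32)
    ).cuda()
    fsdp_model = FSDP(copy.deepcopy(ref))

    x = torch.randn(8, 64, device="cuda")
    dist.broadcast(x, src=0)  # identical input on every rank

    torch.testing.assert_close(ref(x), fsdp_model(x))
    if rank == 0:
        print("single-GPU and FSDP forward outputs match")
    dist.destroy_process_group()

if __name__ == "__main__":
    main()
```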
refactor tests into separate dir
Test Plan:
```
python tests/test.py
with-proxy python tests/test_sam.py
./tests/test_fsdp.sh
```
test FSDP with fp8 linear toy model
Summary: Note that we can't match the weight gradient and its scale because of gradient reduction.
Test Plan:
```
./tests/test_fsdp.sh
```
clean up float8 differentiable constructor
Test Plan:
```
with-proxy ./tests/test_everything.sh
```
Make Float8Tensor remember original precision
Summary:
before: Float8Tensor always converted back to float32
after: Float8Tensor remembers original dtype
this will be useful for autocast support
Test Plan:
```
with-proxy ./tests/test_everything.sh
```
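The idea of "remembering the original precision" can be illustrated with a minimal wrapper subclass that advertises the original dtype while storing the fp8 payload and scale inside; field and method names below are illustrative, not the repo's exact `Float8Tensor`:

```python
# Minimal illustration (not the repo's Float8Tensor) of carrying the original
# dtype alongside the fp8 payload and scale so it can be restored later.
import torch
from torch.utils._pytree import tree_map

class Float8TensorSketch(torch.Tensor):
    __torch_function__ = torch._C._disabled_torch_function_impl

    @staticmethod
    def __new__(cls, data_fp8, scale, orig_dtype):
        # advertise the remembered original dtype; store the fp8 data inside
        return torch.Tensor._make_wrapper_subclass(
            cls, data_fp8.shape, dtype=orig_dtype, device=data_fp8.device)

    def __init__(self, data_fp8, scale, orig_dtype):
        self._data = data_fp8
        self._scale = scale
        self._orig_dtype = orig_dtype

    def to_original_precision(self) -> torch.Tensor:
        return self._data.to(self._orig_dtype) / self._scale

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        # fallback: run any op in the original precision (a real subclass
        # would override specific ops instead of always falling back)
        unwrap = lambda t: (t.to_original_precision()
                            if isinstance(t, Float8TensorSketch) else t)
        return func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs or {}))

x = torch.randn(4, 4)
scale = torch.finfo(torch.float8_e4m3fn).max / x.abs().max()
x8 = Float8TensorSketch((x * scale).to(torch.float8_e4m3fn), scale, x.dtype)
print(x8.dtype)                          # torch.float32: the remembered precision
print(x8.to_original_precision().dtype)  # torch.float32
```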
make grad addition happen in original precision
Summary: Adding grads and casting back to fp8 doesn't have a clear use case yet; we can add this back if needed. For now, simplify.
Test Plan:
```
with-proxy ./tests/test_everything.sh
```
fix tests for real scaled matmul
Summary: We just duplicate autocast logic for F.linear to have Float8Linear do the right thing.
Test Plan:
```
with-proxy ./tests/test_everything.sh
```
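Duplicating `F.linear`'s autocast behavior essentially means checking the autocast state and casting the input before the float8 conversion. A hedged sketch of that check (not the repo's exact code):

```python
# Hedged sketch of mimicking F.linear's autocast behavior before the float8
# conversion: if autocast is active, cast the input to the autocast dtype first.
import torch

def maybe_apply_autocast(x: torch.Tensor) -> torch.Tensor:
    if torch.is_autocast_enabled():
        autocast_dtype = torch.get_autocast_gpu_dtype()  # e.g. torch.bfloat16
        if x.dtype != autocast_dtype:
            x = x.to(autocast_dtype)
    return x
```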
add simple benchmark for single GPU LLaMa 70B shapes
Summary:
1. get the LLaMa 70B single GPU shapes
2. benchmark fp16 matmul vs float8 matmul for those shapes (for now just the forward)
3. display the speedup and pct peak tops
Test Plan:
```
name       shape                bf16_time_s  fp8_fw_time_s  fp8_fw_sp
---------  -------------------  -----------  -------------  ---------
attn.wqkv  (16384, 8192, 1280)  6.04E-04     2.43E-04       2.48
attn.w0    (16384, 1024, 8192)  5.00E-04     2.51E-04       1.99
ffn.w13    (16384, 8192, 7168)  3.57E-03     1.39E-03       2.56
ffn.w2     (16384, 3584, 8192)  1.59E-03     7.45E-04       2.13

// full logs: https://gist.github.com/vkuzo/ce31a3fef86816036441b98be060849a
```
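The timing pattern for such a benchmark usually relies on CUDA events; the sketch below times only the bf16 baseline for the listed shapes (interpreted as `(M, K, N)`), since the fp8 path goes through PyTorch's private scaled-matmul kernel, whose signature has changed across releases:

```python
# Sketch of the timing pattern only (not the repo's benchmark): time the bf16
# baseline for the listed (M, K, N) shapes with CUDA events.
import torch

shapes = {
    "attn.wqkv": (16384, 8192, 1280),
    "attn.w0":   (16384, 1024, 8192),
    "ffn.w13":   (16384, 8192, 7168),
    "ffn.w2":    (16384, 3584, 8192),
}

def bench_bf16_s(m: int, k: int, n: int, iters: int = 20) -> float:
    a = torch.randn(m, k, device="cuda", dtype=torch.bfloat16)
    b = torch.randn(k, n, device="cuda", dtype=torch.bfloat16)
    for _ in range(3):  # warmup
        a @ b
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        a @ b
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters / 1e3  # ms -> s

for name, (m, k, n) in shapes.items():
    print(f"{name:10s} {bench_bf16_s(m, k, n):.2e} s")
```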
adds requirements.txt for easier checkout on other machines
Update test call
fix tests on an A100
Summary: Adds the right skip logic to skip/emulate when we are not on an H100.
Test Plan:
```
with-proxy ./tests/test_everything.sh
```
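The skip/emulate gating typically reduces to a compute-capability check; a sketch (the helper name is illustrative):

```python
# Sketch of the gating described above: only run the real fp8 kernels on
# sm90+ (H100-class) GPUs, otherwise skip or emulate.
import torch

def is_sm90_or_newer() -> bool:
    return torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0)

# example use in a test:
# import pytest
# @pytest.mark.skipif(not is_sm90_or_newer(), reason="real fp8 matmul needs sm90+")
# def test_real_scaled_mm(): ...
```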
pt2 compatibility: stop passing DelayedScaling around
Summary: Passing a user defined object such as `DelayedScaling` to `torch.autograd.Function` is not traceable by PT2.0 yet. For now, just pass the underlying value instead since that is traceable.
Test Plan:
```
TORCH_LOGS="aot,dynamo" pytest tests/test.py -k pt2_nots -s

// before this PR, got the error below
[2023-08-15 08:58:22,657] torch._dynamo.variables.higher_order_ops: [ERROR] HigherOrderOperator with body that accepts non-Tensors as input

// after this PR, the error is no longer shown
```
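The pattern is to unpack the config before crossing the `torch.autograd.Function` boundary so only traceable primitives are passed. A simplified sketch with illustrative names:

```python
# Simplified sketch of the pattern (names are illustrative): unpack the config
# into traceable primitives instead of passing the DelayedScaling-style object
# itself into the autograd.Function.
import torch

class NoopFp8Cast(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, amax_history_len: int):  # plain int, which PT2 can trace
        ctx.amax_history_len = amax_history_len
        return x.clone()

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output, None

# call site sketch:
# y = NoopFp8Cast.apply(x, config.amax_history_len)  # not .apply(x, config)
```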
work around pt2 error of unexpected type in sourceless builder
Summary: Apparently having `torch.dtype` as an argument to `torch.autograd.Function` does not work yet. Note: I was not able to reproduce this with a standalone repro, although I did not try too hard.
Test Plan:
```
TORCH_LOGS="dynamo" pytest tests/test.py -k pt2_nots -s
```
simplify init buffers
Summary: We don't need two buffers to track init state, can just use one. This will simplify making this control flow traceable in a future PR.
Test Plan:
```
with-proxy ./tests/test_everything.sh
```
switch init logic to Python, make nots pt2.0 traceable
Summary: This PR moves the init logic from PyTorch tensors to a Python class attribute. After this, we can enable `fullgraph=True` with eager mode PT2.0 tracing without tensor subclass.
Test Plan:
```
with-proxy ./tests/test_everything.sh
```
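A sketch of the before/after idea, with illustrative names: track the "initialized yet?" state as a plain Python bool instead of a registered tensor buffer, so PT2 can specialize on it rather than trace tensor-dependent control flow:

```python
# Illustrative sketch (not the repo's exact field names) of moving init state
# from a tensor buffer to a plain Python attribute.
import torch
import torch.nn as nn

class Fp8LinearLike(nn.Module):
    def __init__(self, in_features, out_features):
        super().__init__()
        self.linear = nn.Linear(in_features, out_features)
        # before: self.register_buffer("fp8_initialized", torch.zeros(1))
        self.is_fp8_initialized = False  # after: plain Python attribute

    def forward(self, x):
        if not self.is_fp8_initialized:
            # ... seed amax/scale history on the first iteration ...
            self.is_fp8_initialized = True
        return self.linear(x)
```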
fix nit in nots test
fix tests after #39
Summary: Ensures tests pass after #39; we should fix them in a future PR.
Add transformers to reqs
Not for landing
This is just some scratch work I did today trying to get Float8Tensor (as a torch_dispatch subclass) to play well with AOTAutograd.

AOTAutograd bits
I mostly got things working, with one annoying issue: the custom fp8 linear as written today will return a Float8Tensor output in the forward. However, when we call .sum().backward(), the sum call returns a plain torch.Tensor. This causes the grad_out of the compiled function (in the backward) to be a plain tensor. This is a problem, because:
(1) AOTAutograd eagerly traces out a backward graph to do partitioning, during the forward
(2) it needs to make guesses on whether or not each grad_out will be a subclass
(3) Right now, when it sees a fw_out that is a subclass, it assumes that the corresponding grad_out will also be a subclass
(4) If we guessed wrong, then we'll need to re-trace the backward. This requires backward guards, which are not implemented yet.
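To make the grad_out mismatch concrete, here is a minimal hypothetical wrapper subclass (not this PR's `Float8Tensor`) where most ops stay wrapped but `sum()` deliberately returns a plain tensor, which is exactly the case AOTAutograd has to guess about:

```python
# Minimal hypothetical wrapper subclass: most ops stay wrapped, but sum()
# returns a plain tensor, mirroring the grad_out type mismatch above.
import torch
from torch.utils._pytree import tree_map

class ToyFloat8Tensor(torch.Tensor):
    __torch_function__ = torch._C._disabled_torch_function_impl

    @staticmethod
    def __new__(cls, elem):
        return torch.Tensor._make_wrapper_subclass(
            cls, elem.shape, dtype=elem.dtype, device=elem.device)

    def __init__(self, elem):
        self._elem = elem

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        unwrap = lambda t: t._elem if isinstance(t, ToyFloat8Tensor) else t
        out = func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs or {}))
        if func is torch.ops.aten.sum.default:
            return out  # reductions escape the subclass, like .sum() above
        return tree_map(
            lambda t: ToyFloat8Tensor(t) if isinstance(t, torch.Tensor) else t, out)

x = ToyFloat8Tensor(torch.randn(4, 4))
y = x * 2      # still a ToyFloat8Tensor (a stand-in for the fp8 linear output)
s = y.sum()    # plain torch.Tensor -> the backward's grad_out is plain too
print(type(y).__name__, type(s).__name__)
```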
I had a workaround hack locally, where I just special-case FP8 in aot autograd to always assume that grad_outputs are not subclasses. That seemed to work, and I was able to use torch.compile and generate fw and bw graphs.

Dynamo bits
However, that isn't enough to test correctness. Why? The custom module in the test has a flag that gets flipped after the first iteration, which dynamo has to specialize the graph on.
I made a mild attempt to get things working in dynamo - seemed to get part of the way through, but eventually I gave up. Some progress here: 814239133
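A hypothetical reduction of that dynamo problem: a module that flips a plain Python flag after the first iteration, so the guard on that flag forces a re-specialization on the second call:

```python
# Hypothetical reduction of the dynamo issue above: the flag flip after the
# first iteration invalidates the guard and triggers a second specialization.
import torch

class FlagModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.is_first_iteration = True

    def forward(self, x):
        if self.is_first_iteration:        # dynamo guards on this bool
            self.is_first_iteration = False
            return x * 2.0
        return x * 3.0

m = torch.compile(FlagModule())
x = torch.randn(4)
m(x)  # traces and specializes on is_first_iteration == True
m(x)  # the guard fails, so dynamo re-specializes for the flipped flag
```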