import float8_experimental only when fp8 is enabled and install it in CI #464

Merged (28 commits) on Jul 17, 2024
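The commits and diffs below implement a single idea: treat float8_experimental as an optional dependency. It is installed in the 4-GPU CI job so the fp8 tests can run, and in the library code every float8_experimental import is moved inside the code path that is only reached when fp8 is enabled. A minimal sketch of that guarded-import pattern (the function signature, body, and error message here are illustrative, not torchtitan's exact code):

```python
import torch.nn as nn


def build_fp8_linear(model: nn.Module, enable_fp8_linear: bool) -> nn.Module:
    """Swap nn.Linear modules for Float8Linear only when fp8 is requested."""
    if not enable_fp8_linear:
        # fp8 disabled: float8_experimental is never imported, so it does not
        # have to be installed for the default bf16/fp32 training path.
        return model
    try:
        # Lazy import, mirroring the PR: the optional dependency is only
        # touched once the user has opted in to fp8.
        from float8_experimental.float8_linear_utils import (
            swap_linear_with_float8_linear,
        )
    except ImportError as exc:
        raise ImportError(
            "enable_fp8_linear requires float8_experimental; install it from "
            "https://github.com/pytorch-labs/float8_experimental"
        ) from exc
    # Exact arguments depend on the float8_experimental version in use.
    return swap_linear_with_float8_linear(model)
```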
Commits
8d00b73  float8 tmp save (weifengpy, Jun 12, 2024)
68d9f61  Merge branch 'main' into fsdp2 (weifengpy, Jun 19, 2024)
4cd5f74  run 8b eager successfully (weifengpy, Jun 19, 2024)
05a4a06  enable compile (weifengpy, Jun 20, 2024)
f48a82e  benchmark (weifengpy, Jun 20, 2024)
14aabfb  1d setup (weifengpy, Jun 21, 2024)
b88aee9  2d setup (weifengpy, Jun 21, 2024)
2b4e0c2  2d setup (weifengpy, Jun 21, 2024)
ad63aba  Merge branch 'main' into fsdp2 (weifengpy, Jul 3, 2024)
71d4dc6  Merge branch 'main' into fsdp2 (weifengpy, Jul 11, 2024)
23536e9  fp8 all-gather FSDP (weifengpy, Jul 12, 2024)
bdb0fd0  linter (weifengpy, Jul 12, 2024)
ef0e843  add unit test and restore original toml (weifengpy, Jul 12, 2024)
c294f6a  add unit test for float8 (weifengpy, Jul 12, 2024)
b58b07b  better doc with original dtype all-gather and value error on fp8 (weifengpy, Jul 15, 2024)
7df10ae  improve config msg (weifengpy, Jul 15, 2024)
f674012  Merge branch 'pytorch:main' into fsdp2 (weifengpy, Jul 15, 2024)
7dd788c  rename config to enable_fp8_linear and improve comments (weifengpy, Jul 16, 2024)
faefe27  rename to enable_fp8_linear (weifengpy, Jul 16, 2024)
7aad066  add 2D test (weifengpy, Jul 16, 2024)
5040c31  import Optional and NotImplement for delayed scaling (weifengpy, Jul 16, 2024)
cee653e  remove TP fp8 all-gather from CI (weifengpy, Jul 16, 2024)
e164285  fix linter (weifengpy, Jul 16, 2024)
22c71ea  remove redudant check (weifengpy, Jul 16, 2024)
595f83d  install float8_experimental in CI (weifengpy, Jul 16, 2024)
68e9f19  Merge branch 'pytorch:main' into fsdp2 (weifengpy, Jul 16, 2024)
9de67ff  import float8_experimental inside enable_fp8_linear (weifengpy, Jul 16, 2024)
367507f  import float8_experimental only when needed (weifengpy, Jul 17, 2024)
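Several of the later commits are about the naming and wiring of the training-config knobs (enable_fp8_linear, enable_fsdp_fp8_all_gather, precompute_float8_dynamic_scale_for_fsdp). A hedged sketch of how those flags gate the fp8 code paths, with attribute locations assumed from the diffs below:

```python
def fp8_paths_taken(job_config) -> dict:
    """Summarize which fp8 features a given config turns on (illustrative only)."""
    t = job_config.training  # the flags appear under the training section in the diffs
    return {
        # swap nn.Linear for Float8Linear at model build time
        "swap_linear": t.enable_fp8_linear,
        # FSDP2 all-gathers fp8 weights instead of the original dtype
        "fp8_all_gather": t.enable_fp8_linear and t.enable_fsdp_fp8_all_gather,
        # one fused all-reduce to precompute dynamic amax/scale after each step
        "precompute_scales": (
            t.enable_fp8_linear
            and t.enable_fsdp_fp8_all_gather
            and t.precompute_float8_dynamic_scale_for_fsdp
        ),
    }
```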
.github/workflows/integration_test_4gpu.yaml (1 addition, 0 deletions)

@@ -39,5 +39,6 @@ jobs:
 
 python -m pip install --force-reinstall --pre torch --index-url https://download.pytorch.org/whl/nightly/cu121
 python -m pip install --pre torchdata --index-url https://download.pytorch.org/whl/nightly/
+python -m pip install git+https://github.com/pytorch-labs/float8_experimental.git
 mkdir artifacts-to-be-uploaded
 python ./test_runner.py artifacts-to-be-uploaded --ngpu 4
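With float8_experimental installed in the 4-GPU job, the fp8 unit and integration tests added in this PR can actually exercise the import. torchtitan drives its integration tests through test_runner.py, but as a general illustration, a test suite can skip fp8 cases gracefully when the optional package is absent (the pytest-style guard below is an assumption, not torchtitan's runner):

```python
import importlib.util

import pytest

# True only when the optional dependency installed by the CI step above exists.
HAS_FLOAT8 = importlib.util.find_spec("float8_experimental") is not None


@pytest.mark.skipif(not HAS_FLOAT8, reason="float8_experimental not installed")
def test_fp8_linear_swap():
    # The import succeeds only when the fp8 dependency is present.
    from float8_experimental.float8_linear_utils import (
        swap_linear_with_float8_linear,
    )

    assert callable(swap_linear_with_float8_linear)
```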
torchtitan/float8_linear.py (3 additions, 3 deletions)

@@ -15,18 +15,17 @@
 import contextlib
 from typing import Optional
 
-import float8_experimental.config as config
-
 import torch
 import torch.nn as nn
-from float8_experimental.float8_linear import TensorScalingType
 
 from torchtitan.config_manager import JobConfig
 from torchtitan.logging_utils import logger
 
 
 @contextlib.contextmanager
 def set_enable_fsdp_fp8_all_gather(enable_fsdp_fp8_all_gather: bool):
+    import float8_experimental.config as config
+
     prev = config.enable_fsdp_fp8_all_gather
     torch.distributed.barrier()
     config.enable_fsdp_fp8_all_gather = enable_fsdp_fp8_all_gather
@@ -51,6 +50,7 @@ def build_fp8_linear(
         job_config.training.enable_fsdp_fp8_all_gather and dp_enabled
     )
     try:
+        from float8_experimental.float8_linear import TensorScalingType
         from float8_experimental.float8_linear_utils import (
             swap_linear_with_float8_linear,
         )
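The key change in this file is that both float8_experimental imports now live inside functions, so torchtitan.float8_linear can be imported without the package installed. The diff does not show the tail of set_enable_fsdp_fp8_all_gather; a plausible sketch of the full context manager, assuming the hidden lines restore the previous flag value on exit:

```python
import contextlib

import torch


@contextlib.contextmanager
def set_enable_fsdp_fp8_all_gather(enable_fsdp_fp8_all_gather: bool):
    # Imported lazily so this module loads even without float8_experimental.
    import float8_experimental.config as config

    prev = config.enable_fsdp_fp8_all_gather
    torch.distributed.barrier()  # keep all ranks toggling the flag together
    config.enable_fsdp_fp8_all_gather = enable_fsdp_fp8_all_gather
    try:
        yield
    finally:
        # Assumed restore step; the tail of the function is not shown in the diff.
        torch.distributed.barrier()
        config.enable_fsdp_fp8_all_gather = prev
```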
train.py (4 additions, 1 deletion)

@@ -19,7 +19,6 @@
 
 import torch
 import torch.nn.functional as F
-from float8_experimental.fsdp_utils import precompute_float8_dynamic_scale_for_fsdp
 from torch.distributed import destroy_process_group
 from torch.distributed.checkpoint.stateful import Stateful
 from torch.distributed.elastic.multiprocessing.errors import record
@@ -404,6 +403,10 @@ def loss_fn(pred, labels):
             and job_config.training.enable_fsdp_fp8_all_gather
             and job_config.training.precompute_float8_dynamic_scale_for_fsdp
         ):
+            from float8_experimental.fsdp_utils import (
+                precompute_float8_dynamic_scale_for_fsdp,
+            )
+
             # calculate float8 dynamic amax/scale for all-parameter for FSDP2
             # it issues a single all-reduce for all parameters at once for better performance
             precompute_float8_dynamic_scale_for_fsdp(model)
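In train.py the last top-level fp8 import is removed; precompute_float8_dynamic_scale_for_fsdp is imported right where it is called, inside a guard that checks the fp8 flags. A hedged sketch of where that call sits in a training step (the loop structure and all names other than the float8_experimental import are illustrative):

```python
def train_step(model, optimizer, loss_fn, batch, labels, job_config):
    loss = loss_fn(model(batch), labels)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

    t = job_config.training
    if (
        t.enable_fp8_linear
        and t.enable_fsdp_fp8_all_gather
        and t.precompute_float8_dynamic_scale_for_fsdp
    ):
        # Imported here so runs without fp8 never touch float8_experimental.
        from float8_experimental.fsdp_utils import (
            precompute_float8_dynamic_scale_for_fsdp,
        )

        # One all-reduce computes dynamic amax/scale for every fp8 parameter
        # at once, ready for the next iteration's fp8 all-gather.
        precompute_float8_dynamic_scale_for_fsdp(model)
```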