skip fp8 CI on non-H100 GPUs #465

Merged Jul 17, 2024 (31 commits; changes shown from 30 commits)

Commits
8d00b73: float8 tmp save (weifengpy, Jun 12, 2024)
68d9f61: Merge branch 'main' into fsdp2 (weifengpy, Jun 19, 2024)
4cd5f74: run 8b eager successfully (weifengpy, Jun 19, 2024)
05a4a06: enable compile (weifengpy, Jun 20, 2024)
f48a82e: benchmark (weifengpy, Jun 20, 2024)
14aabfb: 1d setup (weifengpy, Jun 21, 2024)
b88aee9: 2d setup (weifengpy, Jun 21, 2024)
2b4e0c2: 2d setup (weifengpy, Jun 21, 2024)
ad63aba: Merge branch 'main' into fsdp2 (weifengpy, Jul 3, 2024)
71d4dc6: Merge branch 'main' into fsdp2 (weifengpy, Jul 11, 2024)
23536e9: fp8 all-gather FSDP (weifengpy, Jul 12, 2024)
bdb0fd0: linter (weifengpy, Jul 12, 2024)
ef0e843: add unit test and restore original toml (weifengpy, Jul 12, 2024)
c294f6a: add unit test for float8 (weifengpy, Jul 12, 2024)
b58b07b: better doc with original dtype all-gather and value error on fp8 (weifengpy, Jul 15, 2024)
7df10ae: improve config msg (weifengpy, Jul 15, 2024)
f674012: Merge branch 'pytorch:main' into fsdp2 (weifengpy, Jul 15, 2024)
7dd788c: rename config to enable_fp8_linear and improve comments (weifengpy, Jul 16, 2024)
faefe27: rename to enable_fp8_linear (weifengpy, Jul 16, 2024)
7aad066: add 2D test (weifengpy, Jul 16, 2024)
5040c31: import Optional and NotImplement for delayed scaling (weifengpy, Jul 16, 2024)
cee653e: remove TP fp8 all-gather from CI (weifengpy, Jul 16, 2024)
e164285: fix linter (weifengpy, Jul 16, 2024)
22c71ea: remove redudant check (weifengpy, Jul 16, 2024)
595f83d: install float8_experimental in CI (weifengpy, Jul 16, 2024)
68e9f19: Merge branch 'pytorch:main' into fsdp2 (weifengpy, Jul 16, 2024)
9de67ff: import float8_experimental inside enable_fp8_linear (weifengpy, Jul 16, 2024)
367507f: import float8_experimental only when needed (weifengpy, Jul 17, 2024)
de3de0e: Merge branch 'pytorch:main' into fsdp2 (weifengpy, Jul 17, 2024)
1ed8dab: skip CI on non-H100 GPUs (weifengpy, Jul 17, 2024)
2be380d: warning about sm90 (weifengpy, Jul 17, 2024)
3 changes: 3 additions & 0 deletions torchtitan/float8_linear.py
@@ -21,6 +21,9 @@
from torchtitan.config_manager import JobConfig
from torchtitan.logging_utils import logger

# Float8 is only supported on H100+ GPUs
SM90OrLater = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0)
@weifengpy (Contributor, Author) commented on Jul 17, 2024:
Borrowed from torch/testing/_internal/common_cuda.py (https://github.com/pytorch/pytorch/blob/main/torch/testing/_internal/common_cuda.py#L32).

I want to avoid a dependency on torch/testing, so I copied it here.
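For reference, a minimal, hypothetical sketch (test class and method names are illustrative, not taken from this PR) of how such a capability guard is commonly used to skip fp8 unit tests on non-H100 GPUs:

```python
# Hypothetical illustration, not code from this PR.
import unittest

import torch

# Same capability check as in torchtitan/float8_linear.py above.
SM90OrLater = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0)


class Float8LinearTest(unittest.TestCase):
    @unittest.skipIf(not SM90OrLater, "float8 linear requires SM90 or later (H100+)")
    def test_fp8_linear_swap(self):
        # fp8-specific setup and assertions would go here
        self.assertTrue(SM90OrLater)


if __name__ == "__main__":
    unittest.main()
```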



@contextlib.contextmanager
def set_enable_fsdp_fp8_all_gather(enable_fsdp_fp8_all_gather: bool):
7 changes: 4 additions & 3 deletions train.py
@@ -27,7 +27,7 @@
from torchtitan.checkpoint import CheckpointManager
from torchtitan.config_manager import JobConfig
from torchtitan.datasets import build_hf_data_loader, create_tokenizer
-from torchtitan.float8_linear import build_fp8_linear
+from torchtitan.float8_linear import build_fp8_linear, SM90OrLater
from torchtitan.logging_utils import init_logger, logger
from torchtitan.lr_scheduling import get_lr_schedulers
from torchtitan.metrics import build_gpu_memory_monitor, build_metric_logger
@@ -216,7 +216,7 @@ def loss_fn(pred, labels):
whole_model = model_cls.from_model_args(model_config)

# apply fp8 linear module swap
-if job_config.training.enable_fp8_linear:
+if SM90OrLater and job_config.training.enable_fp8_linear:
Contributor commented:
Hmmm, I think we don't need to expose this to train.py? We can check this directly in the build_fp8_linear function, and if it's not SM90OrLater, throw a warning and don't swap the linear layers. Similar to the other fp8 APIs in float8_linear.py: only apply when SM90OrLater, otherwise throw warnings.

@weifengpy (Contributor, Author) replied:

Got you. I will move it to the fp8 API with warnings.

Contributor commented:

Alternatively, you could consider calling it maybe_build_fp8_linear and reading enable_fp8_linear from job_config. Within that function, you can log whether the build was successful so the user knows.

@weifengpy (Contributor, Author) replied:

Created maybe_build_fp8_linear and maybe_precompute_fp8_dynamic_scale_for_fsdp. The config and SM90 checks are moved from train.py into the maybe_ functions.
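For readers following the thread, here is a rough sketch of what such a wrapper could look like. The body is an assumption based on this discussion, not the exact torchtitan implementation; build_fp8_linear, JobConfig, and logger are the existing imports shown in this diff.

```python
# Rough sketch of the "maybe_" wrapper idea; internals are assumed from this thread.
import torch

from torchtitan.config_manager import JobConfig
from torchtitan.float8_linear import build_fp8_linear
from torchtitan.logging_utils import logger

# H100+ capability check, as defined in torchtitan/float8_linear.py above
SM90OrLater = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0)


def maybe_build_fp8_linear(model: torch.nn.Module, job_config: JobConfig, dp_enabled: bool) -> None:
    """Swap to fp8 linear layers only when enabled in the config and supported by the GPU."""
    if not job_config.training.enable_fp8_linear:
        return
    if not SM90OrLater:
        logger.warning("enable_fp8_linear is set, but float8 needs SM90 or later (H100+); skipping the swap")
        return
    build_fp8_linear(model, job_config, dp_enabled)
    logger.info("Swapped nn.Linear modules to float8 linear layers")
```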

build_fp8_linear(whole_model, job_config, parallel_dims.dp_enabled)

# log model size
@@ -399,7 +399,8 @@ def loss_fn(pred, labels):
lr_schedulers.step()

if (
-job_config.training.enable_fp8_linear
+SM90OrLater
Contributor commented:

For a quick fix, you could have maybe_build_fp8_linear return a value fp8_enabled and use that here instead of job_config.training.enable_fp8_linear.

+and job_config.training.enable_fp8_linear
and job_config.training.enable_fsdp_fp8_all_gather
and job_config.training.precompute_float8_dynamic_scale_for_fsdp
):
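As a companion sketch, the condition above could be folded into the maybe_precompute_fp8_dynamic_scale_for_fsdp helper mentioned in the thread. The body and the float8_experimental import path are assumptions, not the exact torchtitan code.

```python
# Hypothetical sketch of the companion "maybe_" helper gating the precompute step.
import torch

from torchtitan.config_manager import JobConfig

SM90OrLater = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0)


def maybe_precompute_fp8_dynamic_scale_for_fsdp(model: torch.nn.Module, job_config: JobConfig) -> None:
    """Precompute fp8 dynamic scales after the optimizer step, only when it applies."""
    if not (
        SM90OrLater
        and job_config.training.enable_fp8_linear
        and job_config.training.enable_fsdp_fp8_all_gather
        and job_config.training.precompute_float8_dynamic_scale_for_fsdp
    ):
        return
    # Imported lazily so runs without fp8 do not require float8_experimental
    # (import path assumed from the library at the time of this PR).
    from float8_experimental.fsdp_utils import precompute_float8_dynamic_scale_for_fsdp

    precompute_float8_dynamic_scale_for_fsdp(model)
```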