skip fp8 CI on non-H100 GPUs #465

Merged Jul 17, 2024 (31 commits)

Changes from all commits

Commits (31)
8d00b73  float8 tmp save (weifengpy, Jun 12, 2024)
68d9f61  Merge branch 'main' into fsdp2 (weifengpy, Jun 19, 2024)
4cd5f74  run 8b eager successfully (weifengpy, Jun 19, 2024)
05a4a06  enable compile (weifengpy, Jun 20, 2024)
f48a82e  benchmark (weifengpy, Jun 20, 2024)
14aabfb  1d setup (weifengpy, Jun 21, 2024)
b88aee9  2d setup (weifengpy, Jun 21, 2024)
2b4e0c2  2d setup (weifengpy, Jun 21, 2024)
ad63aba  Merge branch 'main' into fsdp2 (weifengpy, Jul 3, 2024)
71d4dc6  Merge branch 'main' into fsdp2 (weifengpy, Jul 11, 2024)
23536e9  fp8 all-gather FSDP (weifengpy, Jul 12, 2024)
bdb0fd0  linter (weifengpy, Jul 12, 2024)
ef0e843  add unit test and restore original toml (weifengpy, Jul 12, 2024)
c294f6a  add unit test for float8 (weifengpy, Jul 12, 2024)
b58b07b  better doc with original dtype all-gather and value error on fp8 (weifengpy, Jul 15, 2024)
7df10ae  improve config msg (weifengpy, Jul 15, 2024)
f674012  Merge branch 'pytorch:main' into fsdp2 (weifengpy, Jul 15, 2024)
7dd788c  rename config to enable_fp8_linear and improve comments (weifengpy, Jul 16, 2024)
faefe27  rename to enable_fp8_linear (weifengpy, Jul 16, 2024)
7aad066  add 2D test (weifengpy, Jul 16, 2024)
5040c31  import Optional and NotImplement for delayed scaling (weifengpy, Jul 16, 2024)
cee653e  remove TP fp8 all-gather from CI (weifengpy, Jul 16, 2024)
e164285  fix linter (weifengpy, Jul 16, 2024)
22c71ea  remove redudant check (weifengpy, Jul 16, 2024)
595f83d  install float8_experimental in CI (weifengpy, Jul 16, 2024)
68e9f19  Merge branch 'pytorch:main' into fsdp2 (weifengpy, Jul 16, 2024)
9de67ff  import float8_experimental inside enable_fp8_linear (weifengpy, Jul 16, 2024)
367507f  import float8_experimental only when needed (weifengpy, Jul 17, 2024)
de3de0e  Merge branch 'pytorch:main' into fsdp2 (weifengpy, Jul 17, 2024)
1ed8dab  skip CI on non-H100 GPUs (weifengpy, Jul 17, 2024)
2be380d  warning about sm90 (weifengpy, Jul 17, 2024)
44 changes: 40 additions & 4 deletions torchtitan/float8_linear.py
@@ -13,10 +13,12 @@
 # Note: Performance
 # Float8 experimental is intended to be ran under `torch.compile`` for competitive performance
 import contextlib
+import functools
 from typing import Optional

 import torch
 import torch.nn as nn
+from torch._logging import warning_once

 from torchtitan.config_manager import JobConfig
 from torchtitan.logging_utils import logger
@@ -36,7 +38,13 @@ def set_enable_fsdp_fp8_all_gather(enable_fsdp_fp8_all_gather: bool):
         config.enable_fsdp_fp8_all_gather = prev


-def build_fp8_linear(
+@functools.lru_cache(None)
+def is_sm90_or_later():
+    # Float8 is only supported on H100+ GPUs
+    return torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0)
+
+
+def maybe_build_fp8_linear(
     model: nn.Module, job_config: JobConfig, dp_enabled: Optional[bool] = False
 ):
     """
@@ -46,16 +54,24 @@ def build_fp8_linear(
     This will mutate the model inplace.
     """
     enable_fp8_linear = job_config.training.enable_fp8_linear
-    enable_fsdp_fp8_all_gather = (
-        job_config.training.enable_fsdp_fp8_all_gather and dp_enabled
-    )
+    if not enable_fp8_linear:
+        return
+    if not is_sm90_or_later():
+        warning_once(
+            logger,
+            "Failed to swap to Float8Linear because SM90 or later is not available",
+        )
+        return
     try:
         from float8_experimental.float8_linear import TensorScalingType
         from float8_experimental.float8_linear_utils import (
             swap_linear_with_float8_linear,
         )

         # Mutates the model inplace replacing instances of torch.nn.Linear with Float8Linear
+        enable_fsdp_fp8_all_gather = (
+            job_config.training.enable_fsdp_fp8_all_gather and dp_enabled
+        )
         with set_enable_fsdp_fp8_all_gather(enable_fsdp_fp8_all_gather):
             swap_linear_with_float8_linear(
                 model, scaling_type_w=TensorScalingType.DYNAMIC
@@ -67,3 +83,23 @@ def build_fp8_linear(
         raise ImportError(
             "float8_experimental is not installed. Please install it to use fp8 linear layers."
         ) from exc
+
+
+def maybe_precompute_fp8_dynamic_scale_for_fsdp(
+    model: nn.Module, job_config: JobConfig
+):
+    if not (
+        job_config.training.enable_fp8_linear
+        and job_config.training.enable_fsdp_fp8_all_gather
+        and job_config.training.precompute_float8_dynamic_scale_for_fsdp
+    ):
+        return
+    if not is_sm90_or_later():
+        warning_once(
+            logger,
+            "Skipped precomputing fp8 scales because SM90 or later is not available",
+        )
+        return
+    from float8_experimental.fsdp_utils import precompute_float8_dynamic_scale_for_fsdp
+
+    precompute_float8_dynamic_scale_for_fsdp(model)
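As context for the PR title, the same `is_sm90_or_later()` capability check can be used to skip fp8 tests on non-H100 machines. The following is a minimal, hypothetical sketch, not taken from this PR's test files: the test class and method names are illustrative, and the helper is re-defined locally to mirror the one added in torchtitan/float8_linear.py.

# Hypothetical sketch: skip fp8 tests on GPUs older than SM90 (H100).
# TestFloat8Swap and test_swap_to_float8_linear are illustrative names only.
import unittest

import torch


def is_sm90_or_later() -> bool:
    # H100-class GPUs report compute capability (9, 0) or higher.
    return torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0)


class TestFloat8Swap(unittest.TestCase):
    @unittest.skipIf(not is_sm90_or_later(), "fp8 requires an SM90+ (H100) GPU")
    def test_swap_to_float8_linear(self):
        # On SM90+ hardware this is where the Float8Linear swap would be exercised;
        # on older GPUs the test is skipped instead of failing in CI.
        self.assertTrue(is_sm90_or_later())


if __name__ == "__main__":
    unittest.main()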
26 changes: 10 additions & 16 deletions train.py
@@ -27,7 +27,10 @@
 from torchtitan.checkpoint import CheckpointManager
 from torchtitan.config_manager import JobConfig
 from torchtitan.datasets import build_hf_data_loader, create_tokenizer
-from torchtitan.float8_linear import build_fp8_linear
+from torchtitan.float8_linear import (
+    maybe_build_fp8_linear,
+    maybe_precompute_fp8_dynamic_scale_for_fsdp,
+)
 from torchtitan.logging_utils import init_logger, logger
 from torchtitan.lr_scheduling import get_lr_schedulers
 from torchtitan.metrics import build_gpu_memory_monitor, build_metric_logger
@@ -215,9 +218,8 @@ def loss_fn(pred, labels):
     with torch.device("meta"):
         whole_model = model_cls.from_model_args(model_config)

-    # apply fp8 linear module swap
-    if job_config.training.enable_fp8_linear:
-        build_fp8_linear(whole_model, job_config, parallel_dims.dp_enabled)
+    # swap to Float8Linear based on fp8 config
+    maybe_build_fp8_linear(whole_model, job_config, parallel_dims.dp_enabled)

     # log model size
     model_param_count = get_num_params(whole_model)
@@ -398,18 +400,10 @@ def loss_fn(pred, labels):
             optimizers.step()
             lr_schedulers.step()

-            if (
-                job_config.training.enable_fp8_linear
-                and job_config.training.enable_fsdp_fp8_all_gather
-                and job_config.training.precompute_float8_dynamic_scale_for_fsdp
-            ):
-                from float8_experimental.fsdp_utils import (
-                    precompute_float8_dynamic_scale_for_fsdp,
-                )
-
-                # calculate float8 dynamic amax/scale for all-parameter for FSDP2
-                # it issues a single all-reduce for all parameters at once for better performance
-                precompute_float8_dynamic_scale_for_fsdp(model)
+            # when fp8 config is on,
+            # calculate float8 dynamic amax/scale for all-parameter for FSDP2
+            # it issues a single all-reduce for all parameters at once for better performance
+            maybe_precompute_fp8_dynamic_scale_for_fsdp(model, job_config)

             losses_since_last_log.append(loss)

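To show how the two helpers fit together at the call sites this train.py diff touches, here is a condensed sketch. The SimpleNamespace job config and the toy model are stand-ins, not torchtitan's real JobConfig; it assumes the torchtitan repo from this PR is importable, and the real script wraps these calls in parallelization, checkpointing, profiling, and logging.

# Condensed sketch of the call sites left in train.py by this PR. With
# enable_fp8_linear=False both helpers return early, which is the point of
# the "maybe_" naming; flip the flags on an SM90+ GPU with float8_experimental
# installed to exercise the fp8 path.
from types import SimpleNamespace

import torch.nn as nn

from torchtitan.float8_linear import (
    maybe_build_fp8_linear,
    maybe_precompute_fp8_dynamic_scale_for_fsdp,
)

# Stand-in for job_config: only the three training flags the helpers read.
job_config = SimpleNamespace(
    training=SimpleNamespace(
        enable_fp8_linear=False,
        enable_fsdp_fp8_all_gather=False,
        precompute_float8_dynamic_scale_for_fsdp=False,
    )
)

model = nn.Sequential(nn.Linear(16, 16), nn.Linear(16, 16))

# Before training: swap nn.Linear -> Float8Linear only if the config and GPU allow it.
maybe_build_fp8_linear(model, job_config, dp_enabled=True)

# Inside the training loop, after the optimizer step: refresh fp8 dynamic scales
# for FSDP2 with a single all-reduce, or do nothing if the fp8 path is off.
maybe_precompute_fp8_dynamic_scale_for_fsdp(model, job_config)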