diff --git a/.github/workflows/integration_test_4gpu.yaml b/.github/workflows/integration_test_4gpu.yaml
index 3816f404..7c913b07 100644
--- a/.github/workflows/integration_test_4gpu.yaml
+++ b/.github/workflows/integration_test_4gpu.yaml
@@ -39,5 +39,6 @@ jobs:
           python -m pip install --force-reinstall --pre torch --index-url https://download.pytorch.org/whl/nightly/cu121
           python -m pip install --pre torchdata --index-url https://download.pytorch.org/whl/nightly/
+          python -m pip install git+https://github.com/pytorch-labs/float8_experimental.git
           mkdir artifacts-to-be-uploaded
           python ./test_runner.py artifacts-to-be-uploaded --ngpu 4
diff --git a/torchtitan/float8_linear.py b/torchtitan/float8_linear.py
index f41a812d..496b590a 100644
--- a/torchtitan/float8_linear.py
+++ b/torchtitan/float8_linear.py
@@ -15,11 +15,8 @@
 import contextlib
 from typing import Optional
 
-import float8_experimental.config as config
-
 import torch
 import torch.nn as nn
-from float8_experimental.float8_linear import TensorScalingType
 
 from torchtitan.config_manager import JobConfig
 from torchtitan.logging_utils import logger
@@ -27,6 +24,8 @@
 
 @contextlib.contextmanager
 def set_enable_fsdp_fp8_all_gather(enable_fsdp_fp8_all_gather: bool):
+    import float8_experimental.config as config
+
     prev = config.enable_fsdp_fp8_all_gather
     torch.distributed.barrier()
     config.enable_fsdp_fp8_all_gather = enable_fsdp_fp8_all_gather
@@ -51,6 +50,7 @@ def build_fp8_linear(
         job_config.training.enable_fsdp_fp8_all_gather and dp_enabled
     )
     try:
+        from float8_experimental.float8_linear import TensorScalingType
         from float8_experimental.float8_linear_utils import (
             swap_linear_with_float8_linear,
         )
diff --git a/train.py b/train.py
index 14008525..2c63e299 100644
--- a/train.py
+++ b/train.py
@@ -19,7 +19,6 @@
 
 import torch
 import torch.nn.functional as F
-from float8_experimental.fsdp_utils import precompute_float8_dynamic_scale_for_fsdp
 from torch.distributed import destroy_process_group
 from torch.distributed.checkpoint.stateful import Stateful
 from torch.distributed.elastic.multiprocessing.errors import record
@@ -404,6 +403,10 @@ def loss_fn(pred, labels):
             and job_config.training.enable_fsdp_fp8_all_gather
             and job_config.training.precompute_float8_dynamic_scale_for_fsdp
         ):
+            from float8_experimental.fsdp_utils import (
+                precompute_float8_dynamic_scale_for_fsdp,
+            )
+
             # calculate float8 dynamic amax/scale for all-parameter for FSDP2
             # it issues a single all-reduce for all parameters at once for better performance
             precompute_float8_dynamic_scale_for_fsdp(model)
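
Below is a minimal, self-contained sketch of the deferred-import pattern this diff applies: `float8_experimental` is imported only inside the fp8 code path, so the rest of the project imports cleanly when the package is absent, and CI installs it explicitly for the integration test. The helper name `maybe_swap_to_float8` and the exact call signature of `swap_linear_with_float8_linear` are illustrative assumptions, not torchtitan's actual code.

```python
import torch.nn as nn


def maybe_swap_to_float8(model: nn.Module, enable_fp8: bool) -> nn.Module:
    """Illustrative helper: swap Linear layers to float8 only when requested."""
    # No-op unless fp8 training was explicitly requested.
    if not enable_fp8:
        return model
    try:
        # Deferred import: a missing float8_experimental only fails this path,
        # mirroring the try/except around the import in build_fp8_linear above.
        from float8_experimental.float8_linear_utils import (
            swap_linear_with_float8_linear,
        )
    except ImportError as exc:
        raise ImportError(
            "fp8 training requested but float8_experimental is not installed; "
            "install it via the pip line added to the CI workflow above"
        ) from exc
    # Exact call signature is an assumption; consult the library for details.
    swap_linear_with_float8_linear(model)
    return model
```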