This repository has been archived by the owner on Aug 7, 2024. It is now read-only.

[bc-breaking] rename config.enable_fsdp_fp8_all_gather to use float8 #332

Closed · 1 commit
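
Because the change is marked bc-breaking, downstream code that constructs a Float8LinearConfig with the old keyword has to switch to the new spelling; the diff renames the config field, its use in Float8Linear.from_float, and every reference in the FSDP2 tests. A minimal before/after sketch (importing Float8LinearConfig from float8_experimental/config.py, the file changed below):

from float8_experimental.config import Float8LinearConfig

# Before this PR:
#   config = Float8LinearConfig(enable_fsdp_fp8_all_gather=True)
# After this PR, the same option uses "float8" in its name:
config = Float8LinearConfig(enable_fsdp_float8_all_gather=True)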
float8_experimental/config.py (3 additions, 4 deletions)
@@ -58,10 +58,9 @@ class Float8LinearConfig:
     # option is useful for safety, but not strictly necessary.
     enable_pre_and_post_forward: bool = True
 
-    # If True, then uses a tensor subclass for the fp8 linear module's weight that
-    # implements pre/post-all-gather methods to do fp8 all-gather with FSDP2.
-    # Only dynamic scaling is supported for now.
-    enable_fsdp_fp8_all_gather: bool = False
+    # If True, then uses a tensor subclass for the float8 linear module's weight that
+    # implements pre/post-all-gather methods to do float8 all-gather with FSDP2.
+    enable_fsdp_float8_all_gather: bool = False
 
     # If True, then prior to performing the fp8 scaled mamtmul we will pad the
     # inner dimension of a (dim 1) and b (dim 2) with 0s. This is needed for matmuls
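
For context, this flag takes effect when the converted model is sharded with FSDP2: the float8 linear weights are then all-gathered in float8 rather than in their original precision. A rough usage sketch under the new name; the convert_to_float8_training and fully_shard import paths are assumptions based on this PR's tests and the FSDP2 API of the period, and the snippet is meant to run under torchrun with a CUDA process group already initialized:

import torch.nn as nn
from torch.distributed._composable.fsdp import fully_shard  # FSDP2 entry point (assumed path)
from float8_experimental import convert_to_float8_training  # assumed top-level re-export
from float8_experimental.config import Float8LinearConfig

model = nn.Sequential(nn.Linear(4096, 4096), nn.Linear(4096, 4096)).cuda()
config = Float8LinearConfig(enable_fsdp_float8_all_gather=True)
model = convert_to_float8_training(model, config=config)
for layer in model:
    fully_shard(layer)  # per-layer sharding; weights are all-gathered in float8
fully_shard(model)
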
float8_experimental/float8_linear.py (1 addition, 1 deletion)
@@ -467,7 +467,7 @@ def from_float(
         # 1. weight needs to be on the correct device to create the buffers
         # 2. buffers need to be already created for the delayed scaling version
         # of the weight wrapper to be initialized
-        if config.enable_fsdp_fp8_all_gather:
+        if config.enable_fsdp_float8_all_gather:
             if config.cast_config_weight.scaling_type is TensorScalingType.DYNAMIC:
                 new_mod.weight = torch.nn.Parameter(
                     WeightWithDynamicFloat8CastTensor(
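
As the hunk above shows, when the renamed flag is set and weight scaling is dynamic, from_float wraps the weight in WeightWithDynamicFloat8CastTensor at conversion time; that subclass is what provides the pre/post-all-gather methods FSDP2 calls. A small sanity-check sketch, with emulate=True so it does not require float8 hardware; the float8_experimental.fsdp_utils module path for the subclass is an assumption about the repository layout:

import torch.nn as nn
from float8_experimental import convert_to_float8_training  # assumed top-level re-export
from float8_experimental.config import Float8LinearConfig
from float8_experimental.fsdp_utils import WeightWithDynamicFloat8CastTensor  # assumed module path

config = Float8LinearConfig(enable_fsdp_float8_all_gather=True, emulate=True)
linear = convert_to_float8_training(nn.Linear(64, 64).cuda(), config=config)
# With dynamic weight scaling (the default cast config), the weight should now
# be the tensor subclass that implements the FSDP2 all-gather extension.
assert isinstance(linear.weight, WeightWithDynamicFloat8CastTensor)
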
test/test_fsdp2/test_fsdp2.py (20 additions, 20 deletions)
@@ -83,7 +83,7 @@ def world_size(self) -> int:
     def test_transformer_parity(self):
         self.run_subtests(
             {
-                "enable_fsdp_fp8_all_gather": [False, True],
+                "enable_fsdp_float8_all_gather": [False, True],
                 "precompute": [False, True],
                 "scaling_type_weight": [
                     TensorScalingType.DYNAMIC,
@@ -96,12 +96,12 @@ def test_transformer_parity(self):
 
     def _test_transformer_parity(
         self,
-        enable_fsdp_fp8_all_gather: bool,
+        enable_fsdp_float8_all_gather: bool,
         precompute: bool,
         scaling_type_weight: TensorScalingType,
         compile_transformer_block: bool,
     ):
-        if not enable_fsdp_fp8_all_gather and precompute:
+        if not enable_fsdp_float8_all_gather and precompute:
             return
         elif scaling_type_weight is TensorScalingType.DELAYED and precompute:
             return
@@ -110,7 +110,7 @@ def _test_transformer_parity(
         # embedding weight and output linear weight are tied but only the
         # latter uses fp8 compute. With fp8 all-gather, FSDP would pre-cast to
         # fp8 for that tied weight, incorrectly using fp8 for the embedding.
-        weight_tying = not enable_fsdp_fp8_all_gather
+        weight_tying = not enable_fsdp_float8_all_gather
         module = self.init_transformer(weight_tying=weight_tying).cuda()
         ref_module = copy.deepcopy(module)
         float8_linear_config1 = Float8LinearConfig(
@@ -125,7 +125,7 @@ def _test_transformer_parity(
                 transformer_block = torch.compile(transformer_block, dynamic=False)
                 ref_module.layers.register_module(layer_id, transformer_block)
         float8_linear_config2 = Float8LinearConfig(
-            enable_fsdp_fp8_all_gather=enable_fsdp_fp8_all_gather,
+            enable_fsdp_float8_all_gather=enable_fsdp_float8_all_gather,
             cast_config_weight=Float8TensorCastConfig(scaling_type=scaling_type_weight),
         )
         convert_to_float8_training(
@@ -158,10 +158,10 @@ def _test_transformer_parity(
     @skip_if_lt_x_gpu(2)
     def test_transformer_memory(self):
         """Tests peak active memory in the forward and backward passes."""
-        for enable_fsdp_fp8_all_gather in [False, True]:
-            self._test_transformer_memory(enable_fsdp_fp8_all_gather)
+        for enable_fsdp_float8_all_gather in [False, True]:
+            self._test_transformer_memory(enable_fsdp_float8_all_gather)
 
-    def _test_transformer_memory(self, enable_fsdp_fp8_all_gather: bool):
+    def _test_transformer_memory(self, enable_fsdp_float8_all_gather: bool):
         torch.manual_seed(42)
         # Pre-run a linear forward (gemm and bias) and backward (gemm) to
         # allocate the cuBLAS workspaces before measuring the memory usage
@@ -184,7 +184,7 @@ def _test_transformer_memory(self, enable_fsdp_fp8_all_gather: bool):
         # Emulate the fp8 matmul to bypass the scaled matmul op's divisibility
         # requirement to use a smaller activation size
         float8_linear_config = Float8LinearConfig(
-            enable_fsdp_fp8_all_gather=enable_fsdp_fp8_all_gather,
+            enable_fsdp_float8_all_gather=enable_fsdp_float8_all_gather,
             emulate=True,
         )
         convert_to_float8_training(model, config=float8_linear_config)
@@ -231,7 +231,7 @@ def _test_transformer_memory(self, enable_fsdp_fp8_all_gather: bool):
         # number is kept much smaller than the actual memory usage, which is on
         # the order of 100-200+ MB)
         buffer_mb = 16
-        if enable_fsdp_fp8_all_gather:
+        if enable_fsdp_float8_all_gather:
             # Non-block parameters (fp32), 3x block non-linear-weight
             # parameters (fp32) and block linear-weight parameters (fp8)
             # (current all-gather, copy-out, and next all-gather), and other
@@ -255,7 +255,7 @@ def _test_transformer_memory(self, enable_fsdp_fp8_all_gather: bool):
         # Backward:
         loss.sum().backward()
         mem_mb = self._get_peak_active_memory_mb()
-        if enable_fsdp_fp8_all_gather:
+        if enable_fsdp_float8_all_gather:
             # Non-block parameters (fp32), 2x block non-linear weight
             # parameters (fp32) and block linear-weight parameters (fp8)
             # (current copy-out and next all-gather), 1x block gradients (fp32)
@@ -294,7 +294,7 @@ def test_weight_subclass_dynamic(self):
         # Check for a single FSDP paramter group
         module_fp32 = self.init_single_module()
         float8_linear_config = Float8LinearConfig(
-            enable_fsdp_fp8_all_gather=True,
+            enable_fsdp_float8_all_gather=True,
             emulate=True,
         )
         module = convert_to_float8_training(
@@ -360,7 +360,7 @@ def get_expected_all_gather_size(module: nn.Module):
         module_fp32 = self.init_single_module()
         ref_module = copy.deepcopy(module_fp32)
         float8_linear_config = Float8LinearConfig(
-            enable_fsdp_fp8_all_gather=True,
+            enable_fsdp_float8_all_gather=True,
         )
         module_fp32 = convert_to_float8_training(
             module_fp32, config=float8_linear_config
@@ -418,15 +418,15 @@ def test_fp32_fp8_single_module_parity(self):
             [False, True],
             [TensorScalingType.DYNAMIC, TensorScalingType.DELAYED],
         )
-        for enable_fsdp_fp8_all_gather, scaling_type_weight in choices:
+        for enable_fsdp_float8_all_gather, scaling_type_weight in choices:
             float8_linear_config1 = Float8LinearConfig(
-                enable_fsdp_fp8_all_gather=False,
+                enable_fsdp_float8_all_gather=False,
                 cast_config_weight=Float8TensorCastConfig(
                     scaling_type=scaling_type_weight
                 ),
             )
             float8_linear_config2 = Float8LinearConfig(
-                enable_fsdp_fp8_all_gather=enable_fsdp_fp8_all_gather,
+                enable_fsdp_float8_all_gather=enable_fsdp_float8_all_gather,
                 cast_config_weight=Float8TensorCastConfig(
                     scaling_type=scaling_type_weight
                 ),
@@ -466,15 +466,15 @@ def test_fp32_fp8_multi_module_parity(self):
             [False, True],
             [TensorScalingType.DYNAMIC, TensorScalingType.DELAYED],
        )
-        for enable_fsdp_fp8_all_gather, scaling_type_weight in choices:
+        for enable_fsdp_float8_all_gather, scaling_type_weight in choices:
             float8_linear_config1 = Float8LinearConfig(
-                enable_fsdp_fp8_all_gather=False,
+                enable_fsdp_float8_all_gather=False,
                 cast_config_weight=Float8TensorCastConfig(
                     scaling_type=scaling_type_weight
                 ),
             )
             float8_linear_config2 = Float8LinearConfig(
-                enable_fsdp_fp8_all_gather=enable_fsdp_fp8_all_gather,
+                enable_fsdp_float8_all_gather=enable_fsdp_float8_all_gather,
                 cast_config_weight=Float8TensorCastConfig(
                     scaling_type=scaling_type_weight
                 ),
@@ -545,7 +545,7 @@ def test_delayed_scaling_inplace_update(self):
         """
         module = self.init_single_module()
         float8_linear_config = Float8LinearConfig(
-            enable_fsdp_fp8_all_gather=True,
+            enable_fsdp_float8_all_gather=True,
             cast_config_weight=Float8TensorCastConfig(
                 scaling_type=TensorScalingType.DELAYED
             ),