bc breaking - unify filtering functions
Summary:

BC-breaking, but we don't have BC guarantees yet, so just mentioning this upfront.

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:

ghstack-source-id: a6664ad758ca9fa4f8a81b4d4c065c61f18cb983
Pull Request resolved: #322
vkuzo committed Jul 19, 2024
1 parent c58fb5d commit 0c01953
Showing 4 changed files with 90 additions and 83 deletions.
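For orientation before the per-file diffs: this change collapses the two previous filtering knobs (`skip_fqn_list` and `linear_layer_filter`) into a single `layer_filter_fn` callback that receives the module's FQN and the module instance. Below is a minimal before/after sketch, with hypothetical model and FQN names; the exact call sites appear in the diffs that follow.

import torch.nn as nn

# Before (as removed in this diff): two separate knobs.
# swap_linear_with_float8_linear(
#     model,
#     skip_fqn_list=["output"],
#     linear_layer_filter=filter_out_small_unaligned_layers(32),
# )

# After: one callback that sees both the FQN and the module.
def layer_filter_fn(fqn: str, mod: nn.Module) -> bool:
    if fqn == "output":  # old skip_fqn_list behavior
        return False
    if isinstance(mod, nn.Linear):  # old linear_layer_filter behavior
        return mod.in_features % 16 == 0 and mod.out_features % 16 == 0
    return True

# swap_linear_with_float8_linear(model, layer_filter_fn=layer_filter_fn)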
13 changes: 12 additions & 1 deletion README.md
@@ -43,8 +43,19 @@ from float8_experimental.fsdp_utils import precompute_float8_dynamic_scale_for_f
# create model
m = Model(...)

# optional: filter layers from being eligible for float8 conversion
def layer_filter_fn(fqn: str, mod: torch.nn.Module):
    # don't convert the output layer
    if fqn == "output":
        return False
    # don't convert linear layers with weight dimensions not divisible by 16
    if isinstance(mod, torch.nn.Linear):
        if mod.in_features % 16 != 0 or mod.out_features % 16 != 0:
            return False
    return True

# convert all `torch.nn.Linear` modules to `Float8Linear`
swap_linear_with_float8_linear(m)
swap_linear_with_float8_linear(m, layer_filter_fn=layer_filter_fn)

# optional: use FSDP
model = FSDP(model, use_orig_params=True)
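The `skip_fqn_list` argument is removed further down in this diff; the same skip-by-FQN behavior can be expressed through the unified callback. A minimal sketch with placeholder FQNs, in the style used by the tests below:

import torch.nn as nn

skip_fqns = {"output", "0.lin2"}  # placeholder FQNs

def layer_filter_fn(fqn: str, mod: nn.Module) -> bool:
    # swap everything except modules whose FQN is explicitly skipped
    return fqn not in skip_fqns

# usage, same entry point as above:
# swap_linear_with_float8_linear(m, layer_filter_fn=layer_filter_fn)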
76 changes: 30 additions & 46 deletions float8_experimental/float8_linear_utils.py
@@ -59,26 +59,11 @@ def _update_history_stack(
    amax_history_stack.copy_(new_amax_history_stack)


def filter_out_small_unaligned_layers(size_limit: int) -> Callable[[nn.Linear], bool]:
    """
    Returns a callable that filters out small (dimensions less than the given `size_limit`)
    and unaligned (dimensions not divisible by 16) layers.
    It can be passed as the `linear_layer_filter` argument to `swap_linear_with_float8_linear`.
    """
    return (
        lambda linear_layer: linear_layer.in_features >= size_limit
        and linear_layer.out_features >= size_limit
        and linear_layer.in_features % 16 == 0
        and linear_layer.out_features % 16 == 0
    )


def swap_linear_layers(
    module: nn.Module,
    from_float_func: Callable[[nn.Linear], nn.Linear],
    *,
    skip_fqn_list: Optional[List[str]] = None,
    linear_layer_filter: Optional[Callable[[nn.Linear], bool]] = None,
    layer_filter_fn: Optional[Callable[[str, nn.Module], bool]] = None,
) -> Optional[nn.Module]:
    """
    Generic function to swap linear layers in a module with a new type of linear layer.
@@ -90,18 +75,17 @@ def swap_linear_layers(
    Args:
        module: Module to modify.
        from_float_func: Function that accepts a linear layer and returns a new type of linear layer.
        skip_fqn_list: If specified, a list of module FQNs to skip.
        linear_layer_filter: If specified, only the linear layers
            that pass the filter function will be swapped.
        from_float_kwargs: Additional keyword arguments for from_float_func.
        layer_filter_fn: If specified, only the modules that
            pass the filter function will be swapped. The inputs to the
            filter function are the FQN and module instance.
    Returns:
        nn.Module: The modified module with swapped linear layers.
    """
    module_names_to_skip = set(skip_fqn_list or [])

    if isinstance(module, nn.Linear) and (
        linear_layer_filter is None or linear_layer_filter(module)
        layer_filter_fn is None
        or layer_filter_fn("", module)
    ):
        if len(list(module.children())) > 0:
            raise AssertionError(
@@ -112,43 +96,44 @@
            )

    root_module = module
    visited_modules = {root_module}

    for module_name, module in root_module.named_modules():
        if module_name in module_names_to_skip:
            visited_modules.add(module)

    def post_order_traversal(
        module: nn.Module, module_name: str, parent_module: Optional[nn.Module]
        module: nn.Module,
        cur_fqn: Optional[str] = None,
        parent_module: Optional[nn.Module] = None,
    ):
        nonlocal visited_modules
        if cur_fqn is None:
            cur_fqn = ""

        for child_module_name, child_module in module.named_children():
            if child_module not in visited_modules:
                visited_modules.add(child_module)
                post_order_traversal(child_module, child_module_name, module)
            if cur_fqn == "":
                new_fqn = child_module_name
            else:
                new_fqn = f"{cur_fqn}.{child_module_name}"

            post_order_traversal(child_module, new_fqn, module)

        if isinstance(module, nn.Linear) and (
            linear_layer_filter is None or linear_layer_filter(module)
            layer_filter_fn is None
            or layer_filter_fn(cur_fqn, module)
        ):
            assert (
                parent_module is not None
            ), f"Linear root module should return early: {module}"
            new_linear_module = from_float_func(module)
            setattr(parent_module, module_name, new_linear_module)
            cur_module_name = cur_fqn.split(".")[-1]
            setattr(parent_module, cur_module_name, new_linear_module)

    post_order_traversal(root_module, "", None)
    # Without this explicit `del`, this set only gets deleted upon an explicit
    # garbage collection (not from when its refcount hits zero)
    del visited_modules
    post_order_traversal(root_module)
    return root_module


def swap_linear_with_float8_linear(
    module: nn.Module,
    *,
    skip_fqn_list: Optional[List[str]] = None,
    emulate: bool = False,
    linear_layer_filter: Optional[Callable[[nn.Linear], bool]] = None,
    layer_filter_fn: Optional[Callable[[str, nn.Module], bool]] = None,
    scaling_type_x: TensorScalingType = TensorScalingType.DYNAMIC,
    scaling_type_w: TensorScalingType = TensorScalingType.DYNAMIC,
    scaling_type_dL_dY: TensorScalingType = TensorScalingType.DYNAMIC,
@@ -158,10 +143,10 @@ def swap_linear_with_float8_linear(
    Args:
        module: Module to modify.
        skip_fqn_list: If specified, a list of module FQNs to skip.
        emulate: If True, emulation is used instead of hardware accelerated gemm
        linear_layer_filter: If specified, only the linear layers
            that pass the filter function will be swapped.
        layer_filter_fn: If specified, only the modules that
            pass the filter function will be swapped. The inputs to the
            filter function are the FQN and module instance.
        scaling_type_x (TensorScalingType): scaling type for `x`
        scaling_type_w (TensorScalingType): scaling type for `w`
        scaling_type_dL_dY (TensorScalingType): scaling type for `dL_dY`
@@ -179,8 +164,7 @@ def swap_linear_with_float8_linear(
    return swap_linear_layers(
        module,
        from_float,
        skip_fqn_list=skip_fqn_list,
        linear_layer_filter=linear_layer_filter,
        layer_filter_fn=layer_filter_fn,
    )


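To make the new traversal concrete: `post_order_traversal` builds the FQN it passes to `layer_filter_fn` by dot-joining `named_children()` names from the root, so the strings line up with what `named_modules()` reports, and a root-level `nn.Linear` is filtered with the empty string `""`. A small illustrative sketch; the `MLP` shape mirrors the tests below.

import torch.nn as nn

class MLP(nn.Module):
    def __init__(self, dim: int):
        super().__init__()
        self.lin1 = nn.Linear(dim, 4 * dim)
        self.lin2 = nn.Linear(4 * dim, dim)

model = nn.Sequential(MLP(16), nn.Linear(64, 64), MLP(16))

# FQNs seen by layer_filter_fn mirror named_modules() relative to the root:
for fqn, mod in model.named_modules():
    if isinstance(mod, nn.Linear):
        print(fqn)  # 0.lin1, 0.lin2, 1, 2.lin1, 2.lin2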
10 changes: 6 additions & 4 deletions float8_experimental/inference.py
@@ -10,7 +10,7 @@
from dataclasses import dataclass

from enum import auto, Enum
from typing import List, Optional
from typing import Callable, List, Optional

import float8_experimental.config as config

@@ -209,7 +209,7 @@ def quantize_to_float8(
    module: nn.Module,
    quant_config: QuantConfig,
    *,
    skip_fqn_list: Optional[List[str]] = None,
    layer_filter_fn: Optional[Callable[[str, nn.Module], bool]] = None,
    use_fast_accum: bool = True,
) -> Optional[nn.Module]:
    """
@@ -222,7 +222,9 @@
    Args:
        module (nn.Module): The module to modify.
        quant_config (QuantConfig): Quantization configuration for Float8 conversion.
        skip_fqn_list (List[str], optional): List of module FQNs to skip during conversion.
        layer_filter_fn: If specified, only the modules that
            pass the filter function will be swapped. The inputs to the
            filter function are the FQN and module instance.
        use_fast_accum: Whether to enable fast accumulation for the Float8InferenceLinear. Defaults to True.
    Returns:
@@ -234,5 +236,5 @@
    return swap_linear_layers(
        module,
        lambda m: Float8InferenceLinear.from_float(m, quant_config, use_fast_accum),
        skip_fqn_list=skip_fqn_list,
        layer_filter_fn=layer_filter_fn,
    )
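For the inference path, the call shape now mirrors the training-time swap. An illustrative sketch follows; the `QuantConfig` construction is elided, see its definition in `float8_experimental/inference.py`, and the model here is a made-up toy.

import torch.nn as nn
from float8_experimental.inference import quantize_to_float8

def layer_filter_fn(fqn: str, mod: nn.Module) -> bool:
    # same unified (fqn, module) -> bool signature as the training-time swap
    if not isinstance(mod, nn.Linear):
        return True
    return mod.in_features % 16 == 0 and mod.out_features % 16 == 0

model = nn.Sequential(nn.Linear(64, 64), nn.Linear(64, 64))
# quant_config = QuantConfig(...)  # construction elided
# model = quantize_to_float8(model, quant_config, layer_filter_fn=layer_filter_fn)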
74 changes: 42 additions & 32 deletions test/test_base.py
@@ -18,7 +18,6 @@

from float8_experimental.float8_linear import Float8Linear, TensorScalingType
from float8_experimental.float8_linear_utils import (
filter_out_small_unaligned_layers,
linear_requires_sync,
swap_linear_with_float8_linear,
sync_float8_amax_and_scale_history,
@@ -631,24 +630,34 @@ def __init__(self, dim: int):
                self.lin1 = nn.Linear(dim, 4 * dim)
                self.lin2 = nn.Linear(4 * dim, 4 * dim)

        for emulate in [True, False]:
            model = nn.Sequential(MLP(8), nn.Linear(32, 32), MLP(40))
            # filter out the linear layers whose shape is smaller than 32 or non-divisible by 16.
            model = swap_linear_with_float8_linear(
                model,
                emulate=emulate,
                linear_layer_filter=filter_out_small_unaligned_layers(32),
        model = nn.Sequential(MLP(8), nn.Linear(32, 32), MLP(40))
        # filter out the linear layers whose shape is smaller than 32 or non-divisible by 16.

        size_limit = 32

        def layer_filter_fn(fqn, mod):
            return (
                mod.in_features >= size_limit
                and mod.out_features >= size_limit
                and mod.in_features % 16 == 0
                and mod.out_features % 16 == 0
            )
            # in_features=8, out_features=32, 8 is less than 32.
            self.assertNotIsInstance(model[0].lin1, Float8Linear)
            # in_features=32, out_features=32,
            self.assertIsInstance(model[0].lin2, Float8Linear)
            # in_features=32, out_features=32,
            self.assertIsInstance(model[1], Float8Linear)
            # in_features=40, out_features=160, 40 is not divisible by 16.
            self.assertNotIsInstance(model[2].lin1, Float8Linear)
            # in_features=160, out_features=160,
            self.assertIsInstance(model[2].lin2, Float8Linear)

        model = swap_linear_with_float8_linear(
            model,
            emulate=True,
            layer_filter_fn=layer_filter_fn,
        )
        # in_features=8, out_features=32, 8 is less than 32.
        self.assertNotIsInstance(model[0].lin1, Float8Linear)
        # in_features=32, out_features=32,
        self.assertIsInstance(model[0].lin2, Float8Linear)
        # in_features=32, out_features=32,
        self.assertIsInstance(model[1], Float8Linear)
        # in_features=40, out_features=160, 40 is not divisible by 16.
        self.assertNotIsInstance(model[2].lin1, Float8Linear)
        # in_features=160, out_features=160,
        self.assertIsInstance(model[2].lin2, Float8Linear)

    def test_swap_submodule_linears_with_skip(self):
        class MLP(nn.Module):
@@ -657,20 +666,21 @@ def __init__(self, dim: int):
                self.lin1 = nn.Linear(dim, 4 * dim)
                self.lin2 = nn.Linear(4 * dim, dim)

        for emulate in [True, False]:
            model = nn.Sequential(MLP(3), nn.Linear(3, 3), MLP(3))
            skip_fqn_list = ["2", "0.lin2"]
            model = swap_linear_with_float8_linear(
                model, emulate=emulate, skip_fqn_list=skip_fqn_list
            )
            self.assertIsInstance(model[0].lin1, Float8Linear)
            self.assertNotIsInstance(model[0].lin2, Float8Linear)
            self.assertIsInstance(model[0].lin2, nn.Linear)
            self.assertIsInstance(model[1], Float8Linear)
            self.assertNotIsInstance(model[2].lin2, Float8Linear)
            self.assertNotIsInstance(model[2].lin2, Float8Linear)
            self.assertIsInstance(model[2].lin1, nn.Linear)
            self.assertIsInstance(model[2].lin2, nn.Linear)
        model = nn.Sequential(MLP(3), nn.Linear(3, 3), MLP(3))
        layer_filter_fn = lambda fqn, mod: fqn not in [
            "0.lin2",
            "2.lin1",
        ]
        model = swap_linear_with_float8_linear(
            model,
            emulate=True,
            layer_filter_fn=layer_filter_fn,
        )
        self.assertTrue(type(model[0].lin1) is Float8Linear)
        self.assertTrue(type(model[0].lin2) is nn.Linear)
        self.assertTrue(type(model[1]) is Float8Linear)
        self.assertTrue(type(model[2].lin1) is nn.Linear)
        self.assertTrue(type(model[2].lin2) is Float8Linear)

    def test_fp8_tensor_statistics(self):
        hp_dtypes = (torch.float32, torch.float16, torch.bfloat16)