diff --git a/README.md b/README.md
index b32dbec..ea75fe7 100644
--- a/README.md
+++ b/README.md
@@ -43,8 +43,19 @@ from float8_experimental.fsdp_utils import precompute_float8_dynamic_scale_for_f
 # create model
 m = Model(...)
 
+# optional: filter layers from being eligible for float8 conversion
+def layer_filter_fn(fqn: str, mod: torch.nn.Module):
+    # don't convert the output layer
+    if fqn == "output":
+        return False
+    # don't convert linear layers with weight dimensions not divisible by 16
+    if isinstance(mod, torch.nn.Linear):
+        if mod.in_features % 16 != 0 or mod.out_features % 16 != 0:
+            return False
+    return True
+
 # convert all `torch.nn.Linear` modules to `Float8Linear`
-swap_linear_with_float8_linear(m)
+swap_linear_with_float8_linear(m, layer_filter_fn=layer_filter_fn)
 
 # optional: use FSDP
 model = FSDP(model, use_orig_params=True)
diff --git a/float8_experimental/float8_linear_utils.py b/float8_experimental/float8_linear_utils.py
index 818fef0..f9245fb 100644
--- a/float8_experimental/float8_linear_utils.py
+++ b/float8_experimental/float8_linear_utils.py
@@ -59,26 +59,11 @@ def _update_history_stack(
     amax_history_stack.copy_(new_amax_history_stack)
 
 
-def filter_out_small_unaligned_layers(size_limit: int) -> Callable[[nn.Linear], bool]:
-    """
-    Returns a callable that filters out small (dimensions less than the given `size_limit`)
-    and unaligned (dimenstions not divisible by 16) layers.
-    It can be passed as the `linear_layer_filter` argument to `swap_linear_with_float8_linear`.
-    """
-    return (
-        lambda linear_layer: linear_layer.in_features >= size_limit
-        and linear_layer.out_features >= size_limit
-        and linear_layer.in_features % 16 == 0
-        and linear_layer.out_features % 16 == 0
-    )
-
-
 def swap_linear_layers(
     module: nn.Module,
     from_float_func: Callable[[nn.Linear], nn.Linear],
     *,
-    skip_fqn_list: Optional[List[str]] = None,
-    linear_layer_filter: Optional[Callable[[nn.Linear], bool]] = None,
+    layer_filter_fn: Optional[Callable[[str, nn.Module], bool]] = None,
 ) -> Optional[nn.Module]:
     """
     Generic function to swap linear layers in a module with a new type of linear layer.
@@ -90,18 +75,17 @@ def swap_linear_layers(
     Args:
         module: Module to modify.
         from_float_func: Function that accepts a linear layer and returns a new type of linear layer.
-        skip_fqn_list: If specified, a list of module FQNs to skip.
-        linear_layer_filter: If specified, only the linear layers
-            that pass the filter function will be swapped.
-        from_float_kwargs: Additional keyword arguments for from_float_func.
+        layer_filter_fn: If specified, only the modules that pass the
+            filter function will be swapped. The inputs to the filter
+            function are the FQN and module instance.
 
     Returns:
         nn.Module: The modified module with swapped linear layers.
""" - module_names_to_skip = set(skip_fqn_list or []) - if isinstance(module, nn.Linear) and ( - linear_layer_filter is None or linear_layer_filter(module) + # linear_layer_filter is None or linear_layer_filter(module) + layer_filter_fn is None + or layer_filter_fn("", module) ): if len(list(module.children())) > 0: raise AssertionError( @@ -112,43 +96,44 @@ def swap_linear_layers( ) root_module = module - visited_modules = {root_module} - - for module_name, module in root_module.named_modules(): - if module_name in module_names_to_skip: - visited_modules.add(module) def post_order_traversal( - module: nn.Module, module_name: str, parent_module: Optional[nn.Module] + module: nn.Module, + cur_fqn: Optional[str] = None, + parent_module: Optional[nn.Module] = None, ): - nonlocal visited_modules + if cur_fqn is None: + cur_fqn = "" + for child_module_name, child_module in module.named_children(): - if child_module not in visited_modules: - visited_modules.add(child_module) - post_order_traversal(child_module, child_module_name, module) + if cur_fqn == "": + new_fqn = child_module_name + else: + new_fqn = f"{cur_fqn}.{child_module_name}" + + post_order_traversal(child_module, new_fqn, module) if isinstance(module, nn.Linear) and ( - linear_layer_filter is None or linear_layer_filter(module) + # linear_layer_filter is None or linear_layer_filter(module) + layer_filter_fn is None + or layer_filter_fn(cur_fqn, module) ): assert ( parent_module is not None ), f"Linear root module should return early: {module}" new_linear_module = from_float_func(module) - setattr(parent_module, module_name, new_linear_module) + cur_module_name = cur_fqn.split(".")[-1] + setattr(parent_module, cur_module_name, new_linear_module) - post_order_traversal(root_module, "", None) - # Without this explicit `del`, this set only gets deleted upon an explicit - # garbage collection (not from when its refcount hits zero) - del visited_modules + post_order_traversal(root_module) return root_module def swap_linear_with_float8_linear( module: nn.Module, *, - skip_fqn_list: Optional[List[str]] = None, emulate: bool = False, - linear_layer_filter: Optional[Callable[[nn.Linear], bool]] = None, + layer_filter_fn: Optional[Callable[[str, nn.Module], bool]] = None, scaling_type_x: TensorScalingType = TensorScalingType.DYNAMIC, scaling_type_w: TensorScalingType = TensorScalingType.DYNAMIC, scaling_type_dL_dY: TensorScalingType = TensorScalingType.DYNAMIC, @@ -158,10 +143,10 @@ def swap_linear_with_float8_linear( Args: module: Module to modify. - skip_fqn_list: If specified, a list of module FQNs to skip. emulate: If True, emulation is used instead of hardware accelerated gemm - linear_layer_filter: If specified, only the linear layers - that pass the filter function will be swapped. + layer_filter_fn: If specified, only the modules that + that pass the filter function will be swapped. The inputs to the + filter function are the FQN and module instance. 
         scaling_type_x (TensorScalingType): scaling type for `x`
         scaling_type_w (TensorScalingType): scaling type for `w`
         scaling_type_dL_dY (TensorScalingType): scaling type for `dL_dY`
@@ -179,8 +164,7 @@ def swap_linear_with_float8_linear(
     return swap_linear_layers(
         module,
         from_float,
-        skip_fqn_list=skip_fqn_list,
-        linear_layer_filter=linear_layer_filter,
+        layer_filter_fn=layer_filter_fn,
     )
diff --git a/float8_experimental/inference.py b/float8_experimental/inference.py
index 2bb593b..342b6a1 100644
--- a/float8_experimental/inference.py
+++ b/float8_experimental/inference.py
@@ -10,7 +10,7 @@
 from dataclasses import dataclass
 from enum import auto, Enum
-from typing import List, Optional
+from typing import Callable, List, Optional
 
 import float8_experimental.config as config
@@ -209,7 +209,7 @@ def quantize_to_float8(
     module: nn.Module,
     quant_config: QuantConfig,
     *,
-    skip_fqn_list: Optional[List[str]] = None,
+    layer_filter_fn: Optional[Callable[[str, nn.Module], bool]] = None,
     use_fast_accum: bool = True,
 ) -> Optional[nn.Module]:
     """
@@ -222,7 +222,9 @@
     Args:
         module (nn.Module): The module to modify.
         quant_config (QuantConfig): Quantization configuration for Float8 conversion.
-        skip_fqn_list (List[str], optional): List of module FQNs to skip during conversion.
+        layer_filter_fn: If specified, only the modules that pass the
+            filter function will be swapped. The inputs to the filter
+            function are the FQN and module instance.
         use_fast_accum : Whether to enable fast accumulation for the Float8InferenceLinear. Defaults to True.
 
     Returns:
@@ -234,5 +236,5 @@ def quantize_to_float8(
     return swap_linear_layers(
         module,
         lambda m: Float8InferenceLinear.from_float(m, quant_config, use_fast_accum),
-        skip_fqn_list=skip_fqn_list,
+        layer_filter_fn=layer_filter_fn,
     )
diff --git a/test/test_base.py b/test/test_base.py
index 381fb4e..5431afe 100644
--- a/test/test_base.py
+++ b/test/test_base.py
@@ -18,7 +18,6 @@
 from float8_experimental.float8_linear import Float8Linear, TensorScalingType
 from float8_experimental.float8_linear_utils import (
-    filter_out_small_unaligned_layers,
     linear_requires_sync,
     swap_linear_with_float8_linear,
     sync_float8_amax_and_scale_history,
@@ -631,24 +630,34 @@ def __init__(self, dim: int):
                 self.lin1 = nn.Linear(dim, 4 * dim)
                 self.lin2 = nn.Linear(4 * dim, 4 * dim)
 
-        for emulate in [True, False]:
-            model = nn.Sequential(MLP(8), nn.Linear(32, 32), MLP(40))
-            # filter out the linear layers whose shape is smaller than 32 or non-divisible by 16.
-            model = swap_linear_with_float8_linear(
-                model,
-                emulate=emulate,
-                linear_layer_filter=filter_out_small_unaligned_layers(32),
+        model = nn.Sequential(MLP(8), nn.Linear(32, 32), MLP(40))
+        # filter out the linear layers whose shape is smaller than 32 or non-divisible by 16.
+
+        size_limit = 32
+
+        def layer_filter_fn(fqn, mod):
+            return (
+                mod.in_features >= size_limit
+                and mod.out_features >= size_limit
+                and mod.in_features % 16 == 0
+                and mod.out_features % 16 == 0
             )
-            # in_features=8, out_features=32, 8 is less than 32.
-            self.assertNotIsInstance(model[0].lin1, Float8Linear)
-            # in_features=32, out_features=32,
-            self.assertIsInstance(model[0].lin2, Float8Linear)
-            # in_features=32, out_features=32,
-            self.assertIsInstance(model[1], Float8Linear)
-            # in_features=40, out_features=160, 40 is not divisible by 16.
-            self.assertNotIsInstance(model[2].lin1, Float8Linear)
-            # in_features=160, out_features=160,
-            self.assertIsInstance(model[2].lin2, Float8Linear)
+
+        model = swap_linear_with_float8_linear(
+            model,
+            emulate=True,
+            layer_filter_fn=layer_filter_fn,
+        )
+        # in_features=8, out_features=32, 8 is less than 32.
+        self.assertNotIsInstance(model[0].lin1, Float8Linear)
+        # in_features=32, out_features=32,
+        self.assertIsInstance(model[0].lin2, Float8Linear)
+        # in_features=32, out_features=32,
+        self.assertIsInstance(model[1], Float8Linear)
+        # in_features=40, out_features=160, 40 is not divisible by 16.
+        self.assertNotIsInstance(model[2].lin1, Float8Linear)
+        # in_features=160, out_features=160,
+        self.assertIsInstance(model[2].lin2, Float8Linear)
 
     def test_swap_submodule_linears_with_skip(self):
         class MLP(nn.Module):
@@ -657,20 +666,21 @@ def __init__(self, dim: int):
                 self.lin1 = nn.Linear(dim, 4 * dim)
                 self.lin2 = nn.Linear(4 * dim, dim)
 
-        for emulate in [True, False]:
-            model = nn.Sequential(MLP(3), nn.Linear(3, 3), MLP(3))
-            skip_fqn_list = ["2", "0.lin2"]
-            model = swap_linear_with_float8_linear(
-                model, emulate=emulate, skip_fqn_list=skip_fqn_list
-            )
-            self.assertIsInstance(model[0].lin1, Float8Linear)
-            self.assertNotIsInstance(model[0].lin2, Float8Linear)
-            self.assertIsInstance(model[0].lin2, nn.Linear)
-            self.assertIsInstance(model[1], Float8Linear)
-            self.assertNotIsInstance(model[2].lin2, Float8Linear)
-            self.assertNotIsInstance(model[2].lin2, Float8Linear)
-            self.assertIsInstance(model[2].lin1, nn.Linear)
-            self.assertIsInstance(model[2].lin2, nn.Linear)
+        model = nn.Sequential(MLP(3), nn.Linear(3, 3), MLP(3))
+        layer_filter_fn = lambda fqn, mod: fqn not in [
+            "0.lin2",
+            "2.lin1",
+        ]
+        model = swap_linear_with_float8_linear(
+            model,
+            emulate=True,
+            layer_filter_fn=layer_filter_fn,
+        )
+        self.assertTrue(type(model[0].lin1) is Float8Linear)
+        self.assertTrue(type(model[0].lin2) is nn.Linear)
+        self.assertTrue(type(model[1]) is Float8Linear)
+        self.assertTrue(type(model[2].lin1) is nn.Linear)
+        self.assertTrue(type(model[2].lin2) is Float8Linear)
 
     def test_fp8_tensor_statistics(self):
         hp_dtypes = (torch.float32, torch.float16, torch.bfloat16)
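Migration note (editor's addition, not part of the patch): call sites that relied on the removed `skip_fqn_list` or `linear_layer_filter` arguments can express the same behavior through the new `layer_filter_fn`, which receives the module FQN and the module instance. The sketch below is a minimal example under assumptions: the toy model, the FQN in `fqns_to_skip`, and the `size_limit` value are made up, and `emulate=True` is passed only so the example mirrors the updated test and does not require fp8-capable hardware.

```python
import torch.nn as nn

from float8_experimental.float8_linear_utils import swap_linear_with_float8_linear

# toy module tree, for illustration only; linear FQNs will be "0", "1.0", "1.1"
model = nn.Sequential(
    nn.Linear(32, 64),
    nn.Sequential(nn.Linear(64, 64), nn.Linear(64, 48)),
)

fqns_to_skip = {"1.1"}  # old: skip_fqn_list=["1.1"]
size_limit = 32         # old: linear_layer_filter=filter_out_small_unaligned_layers(32)


def layer_filter_fn(fqn: str, mod: nn.Module) -> bool:
    # skip explicitly listed FQNs
    if fqn in fqns_to_skip:
        return False
    # keep only linear layers that are large enough and 16-aligned
    if isinstance(mod, nn.Linear):
        return (
            mod.in_features >= size_limit
            and mod.out_features >= size_limit
            and mod.in_features % 16 == 0
            and mod.out_features % 16 == 0
        )
    return True


model = swap_linear_with_float8_linear(
    model, emulate=True, layer_filter_fn=layer_filter_fn
)
```

Returning `True` everywhere reproduces the default: with `layer_filter_fn=None`, `swap_linear_layers` converts every `nn.Linear` it visits.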
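The inference path takes the same callback, since `quantize_to_float8` now forwards `layer_filter_fn` to `swap_linear_layers`. A hedged sketch, assuming `model` is an `nn.Module` and `quant_config` is a `QuantConfig` instance constructed elsewhere (neither is defined by this diff), and that the "output" FQN exists in that model:

```python
from float8_experimental.inference import quantize_to_float8


# skip the output projection; the "output" FQN is illustrative and depends
# on the actual module tree of `model`
def layer_filter_fn(fqn, mod):
    return fqn != "output"


model = quantize_to_float8(
    model,          # an nn.Module created elsewhere (assumed)
    quant_config,   # a QuantConfig created elsewhere (assumed)
    layer_filter_fn=layer_filter_fn,
    use_fast_accum=True,  # default value, shown for clarity
)
```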