Add more Float8 description (pytorch#284)
# Summary

Add more of the possible options in the configs, and add a note at the top of the file on how to get the dependency.
drisspg authored Apr 29, 2024
1 parent 0f8eb4c commit 0c9d590
Showing 2 changed files with 15 additions and 2 deletions.
6 changes: 5 additions & 1 deletion torchtitan/config_manager.py
@@ -220,7 +220,11 @@ def __init__(self):
                 "dynamic",
                 "",
             ], # TODO: add "delayed" option back in when supported
-            help="Type of fp8 linear quantization to apply to the model ['', 'dynamic']",
+            help="""
+            Type of fp8 linear quantization to apply to the model ['', 'dynamic'].
+            This feature requires you to install 'float8_experimental' which can be found
+            here: https://github.com/pytorch-labs/float8_experimental
+            """,
         )
         self.parser.add_argument(
             "--training.gc_freq",
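For context, a minimal usage sketch of the new option. It assumes the flag is registered as `--training.fp8_linear` and that `JobConfig` exposes a `parse_args` helper, as elsewhere in this file; neither is shown in the diff above, so treat both as assumptions:

```python
from torchtitan.config_manager import JobConfig

# Hedged sketch: enable dynamic fp8 linear quantization via the CLI flag.
# The flag name and the parse_args helper are assumptions based on the
# surrounding file, not part of this diff.
config = JobConfig()
config.parse_args(["--training.fp8_linear", "dynamic"])

# The parsed value should then be readable as config.training.fp8_linear;
# passing "" (the default) leaves the model unquantized.
assert config.training.fp8_linear == "dynamic"
```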
11 changes: 10 additions & 1 deletion torchtitan/float8_linear.py
@@ -4,6 +4,15 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+# [Note] Getting the 'float8_experimental' package:
+# This script requires the 'float8_experimental' package to function correctly.
+# Please ensure you have this package installed from the appropriate repository.
+# You can obtain it from https://github.com/pytorch-labs/float8_experimental.
+# Either clone and run `pip install .` or run `pip install git+https://github.com/pytorch-labs/float8_experimental.git`
+
+# Note: Performance
+# Float8 experimental is intended to be run under `torch.compile` for competitive performance
+
 import torch.nn as nn
 
 from torchtitan.config_manager import JobConfig
@@ -14,7 +23,7 @@ def build_fp8_linear(model: nn.Module, job_config: JobConfig):
     """
     This function converts the linear layers to one of the fp8 types:
     - Float8DynamicLinear: Dynamic quantization of the weights and the activations
-    - Float8Linear: Uses a history of amaxs to quantize the weights and activations
+    - [Not Yet Supported] Float8Linear: Uses a history of amaxs to quantize the weights and activations
     This will mutate the model inplace.
     """
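To make the performance note concrete, here is a minimal sketch of the convert-then-compile flow. It assumes `float8_experimental`'s `swap_linear_with_float8_linear` utility and `Float8DynamicLinear` class from the repository linked above; the toy model and shapes are illustrative, not from this commit:

```python
import torch
import torch.nn as nn

# Assumed imports from the external dependency; see the install note above.
from float8_experimental.float8_dynamic_linear import Float8DynamicLinear
from float8_experimental.float8_linear_utils import swap_linear_with_float8_linear

# Toy model standing in for a real transformer block (illustrative only).
model = nn.Sequential(nn.Linear(512, 512), nn.ReLU(), nn.Linear(512, 512)).cuda()

# Swap every nn.Linear for its dynamically quantized fp8 counterpart in place.
swap_linear_with_float8_linear(model, Float8DynamicLinear)

# Per the performance note, wrap in torch.compile for competitive speed;
# real fp8 execution additionally requires recent CUDA hardware.
model = torch.compile(model)

out = model(torch.randn(16, 512, device="cuda"))
```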
