This repository has been archived by the owner on Aug 7, 2024. It is now read-only.

[bc breaking] change x, w, dL_dY variable names to input, weight, grad_output #323

Closed · wants to merge 4 commits
3 changes: 3 additions & 0 deletions .github/workflows/ufmt.yml
@@ -23,4 +23,7 @@ jobs:
pip install black==23.3.0 usort==1.0.6 ufmt==2.1.0 libcst==1.0.1
- name: Analyzing the code with ufmt
run: |
ufmt format .
git diff
git restore .
ufmt check .

Contributor Author: L26:28 is for easier debugging of differences between the local machine and CI ufmt.

Contributor: Is the ufmt config different than the one in CI?

Contributor Author: I've noticed this on a couple of PRs in the past. A better fix would be to align the ufmt versions + env, but regardless, this is useful for debugging.
10 changes: 5 additions & 5 deletions README.md
@@ -28,9 +28,9 @@ pip install -e ".[dev]"

# Single GPU User API

We provide two per-tensor scaling strategies: dynamic and delayed. See https://arxiv.org/pdf/2209.05433.pdf, Section 4.3 for more details. These strategies are configurable separately for activations (`x`), weights (`w`) and gradients (`dL_dY`).
We provide two per-tensor scaling strategies: dynamic and delayed. See https://arxiv.org/pdf/2209.05433.pdf, Section 4.3 for more details. These strategies are configurable separately for activations (`input`), weights (`weight`) and gradients (`grad_output`).

## float8 linear with dynamic scaling for `x`, `w` and `dL_dY`
## float8 linear with dynamic scaling for `input`, `weight` and `grad_output`

This is the most accurate recipe as every tensor is scaled dynamically.
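
As an illustration only (this sketch is not part of the diff), invoking the dynamic recipe with the renamed arguments might look like the following; the import paths and the `TensorScalingType.DYNAMIC` member name are assumptions inferred from the delayed-scaling example further down:

```python
import torch

# Import paths are assumptions; adjust to wherever these symbols live in your checkout.
from float8_experimental.float8_linear import TensorScalingType
from float8_experimental.float8_linear_utils import swap_linear_with_float8_linear

m = torch.nn.Sequential(torch.nn.Linear(1024, 1024), torch.nn.Linear(1024, 1024))

# Dynamic scaling, configured separately per tensor with the renamed arguments.
# TensorScalingType.DYNAMIC is assumed to back the "dynamic" string default.
swap_linear_with_float8_linear(
    m,
    scaling_type_input=TensorScalingType.DYNAMIC,
    scaling_type_weight=TensorScalingType.DYNAMIC,
    scaling_type_grad_output=TensorScalingType.DYNAMIC,
)
```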

@@ -95,9 +95,9 @@ m = Model(...)
# type
swap_linear_with_float8_linear(
m,
scaling_type_x=TensorScalingType.DELAYED,
scaling_type_w=TensorScalingType.DELAYED,
scaling_type_dL_dY=TensorScalingType.DELAYED,
scaling_type_input=TensorScalingType.DELAYED,
scaling_type_weight=TensorScalingType.DELAYED,
scaling_type_grad_output=TensorScalingType.DELAYED,
)

# optional: use FSDP. Note that workarounds gated with config.enable_amax_init and
40 changes: 21 additions & 19 deletions benchmarks/bench_linear_float8.py
@@ -95,16 +95,16 @@ def main(
n_limit: Optional[int] = None,
fast_accum_filter: Optional[bool] = None,
shape_name_filter: Optional[str] = None,
scaling_type_x: str = "dynamic",
scaling_type_w: str = "dynamic",
scaling_type_dL_dY: str = "dynamic",
scaling_type_input: str = "dynamic",
scaling_type_weight: str = "dynamic",
scaling_type_grad_output: str = "dynamic",
):
device = "cuda"
print(f"Compile is set to | {compile}")

scaling_type_x = TensorScalingType(scaling_type_x)
scaling_type_w = TensorScalingType(scaling_type_w)
scaling_type_dL_dY = TensorScalingType(scaling_type_dL_dY)
scaling_type_input = TensorScalingType(scaling_type_input)
scaling_type_weight = TensorScalingType(scaling_type_weight)
scaling_type_grad_output = TensorScalingType(scaling_type_grad_output)

# LLaMa 2 70B single-node weight shapes
# assumes fused attn.wqkv and ffn.w13
@@ -136,9 +136,9 @@ def main(
linear_float8 = Float8Linear.from_float(
copy.deepcopy(linear_ref),
emulate=False,
scaling_type_x=scaling_type_x,
scaling_type_w=scaling_type_w,
scaling_type_dL_dY=scaling_type_dL_dY,
scaling_type_input=scaling_type_input,
scaling_type_weight=scaling_type_weight,
scaling_type_grad_output=scaling_type_grad_output,
)
scaling_repr = linear_float8.scaling_repr()

@@ -153,7 +153,9 @@ def main(
ref_forw_backward = lambda: linear_ref(input_tensor).sum().backward()

def float8_forw_backward():
if linear_requires_sync(scaling_type_x, scaling_type_w, scaling_type_dL_dY):
if linear_requires_sync(
scaling_type_input, scaling_type_weight, scaling_type_grad_output
):
sync_float8_amax_and_scale_history(linear_float8)
linear_float8(input_tensor).sum().backward()
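
For reference, a minimal end-to-end sketch of the pattern this benchmark exercises with the renamed API (import paths, shapes, and dtypes below are illustrative assumptions, not part of the diff):

```python
import copy

import torch

# Import paths are assumptions; adjust to your checkout.
from float8_experimental.float8_linear import Float8Linear, TensorScalingType
from float8_experimental.float8_linear_utils import (
    linear_requires_sync,
    sync_float8_amax_and_scale_history,
)

# Hypothetical shape/dtype; the benchmark itself sweeps LLaMa 2 70B weight shapes.
linear_ref = torch.nn.Linear(4096, 4096, bias=False, device="cuda", dtype=torch.bfloat16)
linear_float8 = Float8Linear.from_float(
    copy.deepcopy(linear_ref),
    emulate=False,
    scaling_type_input=TensorScalingType.DELAYED,
    scaling_type_weight=TensorScalingType.DELAYED,
    scaling_type_grad_output=TensorScalingType.DELAYED,
)

input_tensor = torch.randn(16, 4096, device="cuda", dtype=torch.bfloat16, requires_grad=True)
for _ in range(3):
    # Delayed scaling keeps amax/scale history that must be synced before each step.
    if linear_requires_sync(
        TensorScalingType.DELAYED, TensorScalingType.DELAYED, TensorScalingType.DELAYED
    ):
        sync_float8_amax_and_scale_history(linear_float8)
    linear_float8(input_tensor).sum().backward()
```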

@@ -278,18 +280,18 @@ def invoke_main() -> None:
parser.add_argument("-n", "--n_limit", type=int, required=False)
parser.add_argument("--fast_accum_filter", type=bool, required=False)
parser.add_argument("--shape_name_filter", type=str, required=False)
parser.add_argument("--scaling_type_x", type=str, required=False)
parser.add_argument("--scaling_type_w", type=str, required=False)
parser.add_argument("--scaling_type_dL_dY", type=str, required=False)
parser.add_argument("--scaling_type_input", type=str, required=False)
parser.add_argument("--scaling_type_weight", type=str, required=False)
parser.add_argument("--scaling_type_grad_output", type=str, required=False)
args = parser.parse_args()
output_path = Path(args.output_path) if args.output_path is not None else None
kwargs = {}
if args.scaling_type_x is not None:
kwargs["scaling_type_x"] = args.scaling_type_x
if args.scaling_type_w is not None:
kwargs["scaling_type_w"] = args.scaling_type_w
if args.scaling_type_dL_dY is not None:
kwargs["scaling_type_dL_dY"] = args.scaling_type_dL_dY
if args.scaling_type_input is not None:
kwargs["scaling_type_input"] = args.scaling_type_input
if args.scaling_type_weight is not None:
kwargs["scaling_type_weight"] = args.scaling_type_weight
if args.scaling_type_grad_output is not None:
kwargs["scaling_type_grad_output"] = args.scaling_type_grad_output
main(
output_path,
not args.disable_compile,
6 changes: 3 additions & 3 deletions benchmarks/bench_multi_gpu.py
@@ -68,9 +68,9 @@ def get_model(K, N, is_fp8, base_dtype=torch.float32):
swap_linear_with_float8_linear(
m,
emulate=False,
scaling_type_x=TensorScalingType.DELAYED,
scaling_type_w=TensorScalingType.DELAYED,
scaling_type_dL_dY=TensorScalingType.DELAYED,
scaling_type_input=TensorScalingType.DELAYED,
scaling_type_weight=TensorScalingType.DELAYED,
scaling_type_grad_output=TensorScalingType.DELAYED,
)
return m

27 changes: 16 additions & 11 deletions benchmarks/profile_linear_float8.py
@@ -204,20 +204,23 @@ def profile_function(
def main(
profile_path_prefix: Path,
compile: bool = True,
scaling_type_x: str = "dynamic",
scaling_type_w: str = "dynamic",
scaling_type_dL_dY: str = "dynamic",
scaling_type_input: str = "dynamic",
scaling_type_weight: str = "dynamic",
scaling_type_grad_output: str = "dynamic",
model_type: str = "linear",
dtype_filter: str = "both",
):
assert model_type in ("linear", "ln_linear", "norm_ffn_norm"), "unsupported"
assert dtype_filter in ("both", "float8", "bfloat16")

scaling_type_x = TensorScalingType(scaling_type_x)
scaling_type_w = TensorScalingType(scaling_type_w)
scaling_type_dL_dY = TensorScalingType(scaling_type_dL_dY)
scaling_type_input = TensorScalingType(scaling_type_input)
scaling_type_weight = TensorScalingType(scaling_type_weight)
scaling_type_grad_output = TensorScalingType(scaling_type_grad_output)
scaling_repr = "_".join(
[s.short_str() for s in (scaling_type_x, scaling_type_w, scaling_type_dL_dY)]
[
s.short_str()
for s in (scaling_type_input, scaling_type_weight, scaling_type_grad_output)
]
)

print(f"Compile is set to | {compile}")
@@ -254,9 +257,9 @@ def main(
m_ref = m_ref.to(device).to(ref_dtype)

extra_kwargs = {
"scaling_type_x": scaling_type_x,
"scaling_type_w": scaling_type_w,
"scaling_type_dL_dY": scaling_type_dL_dY,
"scaling_type_input": scaling_type_input,
"scaling_type_weight": scaling_type_weight,
"scaling_type_grad_output": scaling_type_grad_output,
}

m_float8 = copy.deepcopy(m_ref)
@@ -278,7 +281,9 @@ def float8_forw_backward_wrapper(x):
# inspection of the fw+bw torch.compile without the scale
# syncing code
# TODO(future): make this better
if linear_requires_sync(scaling_type_x, scaling_type_w, scaling_type_dL_dY):
if linear_requires_sync(
scaling_type_input, scaling_type_weight, scaling_type_grad_output
):
with record_function("scale_amax_and_scales"):
sync_amax_history(m_float8)
out = float8_forw(x)