Add vortex style fp8 support to predict
Signed-off-by: John St John <[email protected]>
jstjohn committed Mar 4, 2025
1 parent 24f1db0 commit e012146
Showing 2 changed files with 17 additions and 29 deletions.
3rdparty/NeMo: 2 changes (1 addition & 1 deletion)
sub-packages/bionemo-evo2/src/bionemo/evo2/run/predict.py: 44 changes (16 additions & 28 deletions)
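
A note on how the two flags combine: --fp8 alone enables vortex-style FP8, which only applies FP8 to the projection layer of the hyena mixer via the model config, while --fp8 together with --full-fp8 routes FP8 through the Megatron mixed-precision plugin instead. Below is a minimal, illustrative sketch of that mapping (the helper name is not part of the commit; it only mirrors the fields visible in the diff):

def resolve_fp8_settings(fp8: bool, full_fp8: bool) -> dict:
    """Illustrative helper: map the two CLI flags onto the precision fields used in this diff."""
    use_full_fp8 = fp8 and full_fp8
    return {
        # MegatronMixedPrecision plugin fields (full FP8 path).
        "fp8": "hybrid" if use_full_fp8 else None,
        "fp8_amax_history_len": 16 if use_full_fp8 else 1,
        "fp8_amax_compute_algo": "max" if use_full_fp8 else "most_recent",
        # Model config field (vortex-style path: FP8 only in the hyena mixer projection layer).
        "vortex_style_fp8": fp8 and not full_fp8,
    }


# --fp8 alone       -> vortex-style FP8 in the model config
# --fp8 --full-fp8  -> full FP8 via the precision plugin
# neither flag      -> plain BF16
assert resolve_fp8_settings(True, False) == {
    "fp8": None,
    "fp8_amax_history_len": 1,
    "fp8_amax_compute_algo": "most_recent",
    "vortex_style_fp8": True,
}
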
@@ -79,6 +79,12 @@ def parse_args():
        default=None,
        help="Output dir that will contain the generated text produced by the Evo2 model. If not provided, the output will be logged.",
    )
    ap.add_argument(
        "--full-fp8",
        action="store_true",
        help="Use full FP8 precision (faster but less accurate) rather than the default vortex style, which "
        "only applies FP8 to the projection layer of the hyena mixer. Only relevant when --fp8 is set.",
    )
    ap.add_argument("--fp8", action="store_true", help="Use FP8 precision. Defaults to BF16.")
    # extra:
    ap.add_argument(
@@ -120,31 +126,6 @@ def _gather_along_cp_dim(input_, seq_dim: int = 1):
    return output


def _collect_into_dim(input_: torch.Tensor, dim: int = -1):
    """Gather tensors and concatenate along the last dimension, assuming the input shape is not split.

    This is needed when there is no sequence parallelism but tensor parallelism is enabled along the last dimension.
    """
    world_size = parallel_state.get_tensor_model_parallel_world_size()
    my_rank = parallel_state.get_tensor_model_parallel_rank()
    # Bypass the function if we are using only 1 GPU.
    if world_size == 1:
        return input_
    my_chunk_input = input_.chunk(world_size, dim=dim)[my_rank]
    dim_size = list(my_chunk_input.size())
    dim_size[0] = dim_size[0] * world_size
    output = torch.empty(dim_size, dtype=my_chunk_input.dtype, device=torch.cuda.current_device())
    # Gather all chunks into the 0th dimension of the output tensor.
    torch.distributed.all_gather_into_tensor(
        output, my_chunk_input.contiguous(), group=parallel_state.get_tensor_model_parallel_group()
    )
    # Split the output tensor back into the original chunks, now synchronized across GPUs that own each chunk.
    tensor_list = output.chunk(world_size, dim=0)
    output = torch.cat(tensor_list, dim=dim).contiguous()

    return output


class HyenaPredictor(LightningPassthroughPredictionMixin, HyenaModel):
"""A predictor for the Hyena model. This adds in the predict step and the passthrough method."""

@@ -347,6 +328,7 @@ def predict(
    model_size: str = "7b",
    ckpt_format: CheckpointFormats = "torch_dist",
    fp8: bool = False,
    full_fp8: bool = False,
    work_dir: Path | None = None,
    batch_size: int = 1,
    output_log_prob_seqs: bool = False,
@@ -406,15 +388,20 @@ def predict(
        plugins=nl.MegatronMixedPrecision(
            precision="bf16-mixed",
            params_dtype=torch.bfloat16,
            # Only enable FP8 in this precision plugin when full FP8 is requested (fp8 and full_fp8).
            # Otherwise vortex_style_fp8 is set in the model config instead.
            fp8="hybrid" if fp8 and full_fp8 else None,
            fp8_amax_history_len=16 if fp8 and full_fp8 else 1,
            fp8_amax_compute_algo="max" if fp8 and full_fp8 else "most_recent",
        ),
    )
    config = HYENA_MODEL_OPTIONS[model_size](
        forward_step_fn=hyena_predict_forward_step,
        data_step_fn=hyena_predict_data_step,  # , attention_backend=AttnBackend.fused,
        distribute_saved_activations=False if sequence_parallel and tensor_parallel_size > 1 else True,
        # Only use vortex-style FP8 in the model config when FP8 is requested without full FP8. This applies FP8
        # only to the projection layer of the hyena mixer.
        vortex_style_fp8=fp8 and not full_fp8,
    )
    trainer.strategy._setup_optimizers = False

@@ -461,6 +448,7 @@ def main():
        model_size=args.model_size,
        ckpt_format=args.ckpt_format,
        fp8=args.fp8,
        full_fp8=args.full_fp8,
        batch_size=args.batch_size,
        output_log_prob_seqs=args.output_log_prob_seqs,
        log_prob_collapse_option=args.log_prob_collapse_option,
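
Aside on the removed _collect_into_dim helper: torch.distributed.all_gather_into_tensor stacks each rank's chunk along dim 0, which is why the helper re-splits the gathered tensor along dim 0 and concatenates along the original dimension to restore rank order. A single-process sketch of that reassembly, simulating the all-gather with a plain concatenation (illustrative only, not part of the commit):

import torch

full = torch.arange(24, dtype=torch.float32).reshape(2, 12)  # the notional unsharded tensor
world_size, dim = 4, -1

# Each tensor-parallel rank would hold one chunk along the last dimension.
chunks = [c.contiguous() for c in full.chunk(world_size, dim=dim)]

# all_gather_into_tensor concatenates the per-rank chunks along dim 0.
gathered = torch.cat(chunks, dim=0)  # shape: (2 * world_size, 3)

# Re-split along dim 0 and concatenate along the original dim to recover rank order.
reassembled = torch.cat(gathered.chunk(world_size, dim=0), dim=dim)

assert torch.equal(reassembled, full)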