disable metric when model parallel (#701)
### Description
Previously, configuring metric logging blocked model parallelism. Disable the metric instead of raising an error.

### Type of changes
- [x]  Bug fix (non-breaking change which fixes an issue)

Signed-off-by: sichu <[email protected]>
sichu2023 authored Feb 25, 2025
1 parent fca2cda commit 3519ecf
Showing 2 changed files with 14 additions and 16 deletions.
@@ -282,7 +282,10 @@ def train_model(
     )
     # Configure the model
     train_metric = None
-    if task_type == "regression":
+    is_model_parallel = tensor_model_parallel_size * pipeline_model_parallel_size > 1
+    if is_model_parallel:
+        valid_metric = None  # metric logging under model parallelism is not supported yet
+    elif task_type == "regression":
         valid_metric = TorchmetricsConfig(class_path="MeanSquaredError", task="regression", metric_name="val_mse")
     else:
         valid_metric = TorchmetricsConfig(
@@ -296,11 +299,6 @@ def train_model(
             metric_name="val_acc",
         )
 
-    if tensor_model_parallel_size * pipeline_model_parallel_size > 1 and (
-        train_metric is not None or valid_metric is not None
-    ):
-        raise NotImplementedError("Metric logging under model parallelism is not supported yet.")
-
     config = config_class(
         task_type=task_type,
         encoder_frozen=encoder_frozen,
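For readers outside the repo, the logic introduced in the hunks above can be restated as a standalone sketch. Here `TorchmetricsConfig` is a minimal stand-in dataclass (not the real BioNeMo class), the helper name `select_valid_metric` is hypothetical, and the classification branch is omitted because it is truncated in the diff.

```python
from dataclasses import dataclass, field
from typing import Optional


@dataclass
class TorchmetricsConfig:
    """Stand-in mirroring only the fields used in the diff above; not the real BioNeMo class."""
    class_path: str
    task: str
    metric_name: str
    kwargs: dict = field(default_factory=dict)


def select_valid_metric(
    tensor_model_parallel_size: int,
    pipeline_model_parallel_size: int,
    task_type: str,
) -> Optional[TorchmetricsConfig]:
    """Disable validation metric logging under model parallelism instead of raising an error."""
    is_model_parallel = tensor_model_parallel_size * pipeline_model_parallel_size > 1
    if is_model_parallel:
        # Metric logging under model parallelism is not supported yet, so return no metric.
        return None
    if task_type == "regression":
        return TorchmetricsConfig(class_path="MeanSquaredError", task="regression", metric_name="val_mse")
    # The classification branch (val_acc) is elided here; see the diff above for the real configuration.
    raise NotImplementedError(f"sketch only covers the regression branch, got {task_type!r}")


# Example: with TP=2 x PP=1 the metric is silently disabled rather than raising NotImplementedError.
assert select_valid_metric(2, 1, "regression") is None
assert select_valid_metric(1, 1, "regression").metric_name == "val_mse"
```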
20 changes: 10 additions & 10 deletions sub-packages/bionemo-esm2/src/bionemo/esm2/scripts/train_esm2.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,16 +265,16 @@ def main(
)
# Configure the model
train_metric = None
valid_metric = TorchmetricsConfig(
class_path="text.Perplexity",
task="pretraining",
kwargs={"ignore_index": MLM_LOSS_IGNORE_INDEX},
metric_name="val_ppl",
)
if tensor_model_parallel_size * pipeline_model_parallel_size > 1 and (
train_metric is not None or valid_metric is not None
):
raise NotImplementedError("Metric logging under model parallelism is not supported yet.")
is_model_parallel = tensor_model_parallel_size * pipeline_model_parallel_size > 1
if is_model_parallel:
valid_metric = None # metric logging under model parallelism is not supported yet
else:
valid_metric = TorchmetricsConfig(
class_path="text.Perplexity",
task="pretraining",
kwargs={"ignore_index": MLM_LOSS_IGNORE_INDEX},
metric_name="val_ppl",
)

esm2_config = ESM2Config(
seq_length=max_seq_length,
Expand Down
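Similarly, the resulting pretraining logic in `train_esm2.py` can be read as the sketch below. The `TorchmetricsConfig` stand-in mirrors the previous sketch, the helper name is hypothetical, and `MLM_LOSS_IGNORE_INDEX` is set to -100 purely as an assumption for illustration.

```python
from dataclasses import dataclass, field
from typing import Optional

MLM_LOSS_IGNORE_INDEX = -100  # assumed value for illustration; the real constant lives in bionemo-esm2


@dataclass
class TorchmetricsConfig:
    """Stand-in mirroring only the fields used in the diff above; not the real BioNeMo class."""
    class_path: str
    task: str
    metric_name: str
    kwargs: dict = field(default_factory=dict)


def select_pretraining_valid_metric(
    tensor_model_parallel_size: int,
    pipeline_model_parallel_size: int,
) -> Optional[TorchmetricsConfig]:
    """Return a validation-perplexity metric config, or None under model parallelism."""
    if tensor_model_parallel_size * pipeline_model_parallel_size > 1:
        return None  # metric logging under model parallelism is not supported yet
    return TorchmetricsConfig(
        class_path="text.Perplexity",
        task="pretraining",
        kwargs={"ignore_index": MLM_LOSS_IGNORE_INDEX},
        metric_name="val_ppl",
    )


# Example: pipeline parallelism of 2 disables validation perplexity logging.
assert select_pretraining_valid_metric(1, 2) is None
assert select_pretraining_valid_metric(1, 1).metric_name == "val_ppl"
```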
