diff --git a/torchtitan/checkpoint.py b/torchtitan/checkpoint.py
index d41e839c1..e939437fb 100644
--- a/torchtitan/checkpoint.py
+++ b/torchtitan/checkpoint.py
@@ -20,6 +20,7 @@
     set_model_state_dict,
     set_optimizer_state_dict,
 )
+from torch.distributed.checkpoint.stateful import Stateful
 
 from torchtitan.config_manager import JobConfig
 from torchtitan.logging_utils import logger
@@ -36,7 +37,7 @@ class IntervalType(enum.Enum):
     STEPS = enum.auto()
 
 
-class ModelWrapper:
+class ModelWrapper(Stateful):
     def __init__(self, model: nn.Module) -> None:
         self.model = model
 
@@ -47,7 +48,7 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
         set_model_state_dict(self.model, state_dict)
 
 
-class OptimizerWrapper:
+class OptimizerWrapper(Stateful):
     def __init__(self, model: nn.Module, optim: torch.optim.Optimizer) -> None:
         self.model = model
         self.optim = optim
diff --git a/train.py b/train.py
index 5dddd14c4..ea6cdc3fc 100644
--- a/train.py
+++ b/train.py
@@ -20,6 +20,7 @@
 import torch
 import torch.nn.functional as F
 from torch.distributed import destroy_process_group
+from torch.distributed.checkpoint.stateful import Stateful
 from torch.distributed.elastic.multiprocessing.errors import record
 from torch.distributed.tensor.parallel import loss_parallel
 
@@ -47,7 +48,7 @@
 
 
 @dataclass
-class TrainState:
+class TrainState(Stateful):
     step: int = 0
     global_avg_losses: List[float] = field(default_factory=list)
     global_max_losses: List[float] = field(default_factory=list)
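
Why this change matters: torch.distributed.checkpoint (DCP) duck-types the values in the state dict it is handed. Any value implementing the Stateful protocol (state_dict() / load_state_dict()) has those hooks invoked automatically during save and load; subclassing Stateful makes the contract explicit for isinstance checks and type checkers rather than changing runtime behavior. Below is a minimal sketch of the protocol in use, assuming the dcp.save/dcp.load entry points available in PyTorch 2.2+; the one-field TrainState and the checkpoint path are illustrative, not the full torchtitan class.

import torch.distributed.checkpoint as dcp
from torch.distributed.checkpoint.stateful import Stateful


class TrainState(Stateful):
    # Reduced to a single field for illustration; torchtitan tracks more.
    def __init__(self) -> None:
        self.step = 0

    def state_dict(self) -> dict:
        # DCP calls this during dcp.save() for every Stateful value it finds.
        return {"step": self.step}

    def load_state_dict(self, state_dict: dict) -> None:
        # DCP calls this during dcp.load() with the restored values.
        self.step = state_dict["step"]


state = {"train_state": TrainState()}
dcp.save(state, checkpoint_id="outputs/checkpoint/step-0")  # path is illustrative
dcp.load(state, checkpoint_id="outputs/checkpoint/step-0")  # restores TrainState in place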