Skip to content

Commit

Permalink
inherit stateful protocol where appropriate
Browse files Browse the repository at this point in the history
ghstack-source-id: d410f30ec715bfb4206459becb95abeed5a4ae02
Pull Request resolved: #281
  • Loading branch information
tianyu-l committed Apr 26, 2024
1 parent 42549a9 commit 0d09a32
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 3 deletions.
5 changes: 3 additions & 2 deletions torchtitan/checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
set_model_state_dict,
set_optimizer_state_dict,
)
from torch.distributed.checkpoint.stateful import Stateful
from torchtitan.config_manager import JobConfig
from torchtitan.logging_utils import logger

Expand All @@ -36,7 +37,7 @@ class IntervalType(enum.Enum):
STEPS = enum.auto()


class ModelWrapper:
class ModelWrapper(Stateful):
def __init__(self, model: nn.Module) -> None:
self.model = model

Expand All @@ -47,7 +48,7 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
set_model_state_dict(self.model, state_dict)


class OptimizerWrapper:
class OptimizerWrapper(Stateful):
def __init__(self, model: nn.Module, optim: torch.optim.Optimizer) -> None:
self.model = model
self.optim = optim
Expand Down
3 changes: 2 additions & 1 deletion train.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import torch
import torch.nn.functional as F
from torch.distributed import destroy_process_group
from torch.distributed.checkpoint.stateful import Stateful
from torch.distributed.elastic.multiprocessing.errors import record
from torch.distributed.tensor.parallel import loss_parallel

Expand Down Expand Up @@ -47,7 +48,7 @@


@dataclass
class TrainState:
class TrainState(Stateful):
step: int = 0
global_avg_losses: List[float] = field(default_factory=list)
global_max_losses: List[float] = field(default_factory=list)
Expand Down

0 comments on commit 0d09a32

Please sign in to comment.