Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Jan 29, 2024
1 parent 46aca6b commit 1d75450
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
import logging
import os
import re
from typing import Any, Dict, Optional, List
from typing import Any, Dict, List, Optional

import torch
from fsspec.core import url_to_fs
Expand Down Expand Up @@ -521,7 +521,8 @@ def _get_loops_state_dict(self) -> Dict[str, Any]:

def _get_dataloader_state_dicts(self) -> List[Dict[str, Any]]:
return [
train_dataloader.state_dict() for train_dataloader in (self.trainer.train_dataloader or [])
train_dataloader.state_dict()
for train_dataloader in (self.trainer.train_dataloader or [])
if isinstance(train_dataloader, _Stateful)
]

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,20 +12,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from copy import deepcopy
from unittest import mock
from unittest.mock import Mock

import pytest
import torch
from torch.utils.data import DataLoader

from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.demos.boring_classes import BoringModel, RandomDataset
from lightning.pytorch.trainer.states import TrainerFn
from lightning.pytorch.utilities.migration.utils import _set_version
from lightning.pytorch.trainer.connectors.checkpoint_connector import _CheckpointConnector
from torch.utils.data import DataLoader


def test_preloaded_checkpoint_lifecycle(tmpdir):
Expand Down Expand Up @@ -235,12 +232,14 @@ def load_state_dict(self, state_dict):
self._counter = state_dict["counter"]


@pytest.mark.parametrize(("train_dataloaders", "expected_states"), [
([], None),
(StatefulDataLoader(RandomDataset(32, 64)), [{"counter": 0}]),
])
@pytest.mark.parametrize(
("train_dataloaders", "expected_states"),
[
([], None),
(StatefulDataLoader(RandomDataset(32, 64)), [{"counter": 0}]),
],
)
def test_train_dataloaders_restore(train_dataloaders, expected_states, tmp_path):

class TestModel(BoringModel):
def train_dataloader(self):
return train_dataloaders
Expand Down

0 comments on commit 1d75450

Please sign in to comment.