
BatchSizeFinder limits number of validation batches for the whole training process #18834

Closed
BoringDonut opened this issue Oct 20, 2023 · 3 comments
Labels: bug, duplicate, tuner, ver: 1.8.x, ver: 2.0.x

Comments

@BoringDonut (Contributor) commented Oct 20, 2023

Bug description

Using BatchSizeFinder seems to limit the number of validation batches to BatchSizeFinder._steps_per_trial for the rest of training, not just during the finder's trials.

As a result, the validation set is effectively reduced to a few dozen samples and the reported metrics are inadequate.

It seems this can be fixed by calling _reset_dataloaders one additional time after the finder has run.
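A minimal sketch of that workaround (the same idea as the CustomBatchSizeFinder in the reproduction below; the class name here is mine): subclass BatchSizeFinder and reset the dataloaders once the finder's trials are over.

from lightning.pytorch.callbacks import BatchSizeFinder
from lightning.pytorch.tuner.batch_size_scaling import _reset_dataloaders


class ResettingBatchSizeFinder(BatchSizeFinder):
    def on_validation_start(self, trainer, pl_module) -> None:
        # Let the finder run its scaling trials first ...
        super().on_validation_start(trainer, pl_module)
        # ... then restore the full dataloaders it truncated to _steps_per_trial batches.
        _reset_dataloaders(trainer)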

What version are you seeing the problem on?

v1.8, v2.0

How to reproduce the bug

import os

import torch
from lightning.pytorch import LightningModule, Trainer, LightningDataModule
from lightning.pytorch.callbacks import BatchSizeFinder
from lightning.pytorch.tuner.batch_size_scaling import _reset_dataloaders
from torch.utils.data import DataLoader, Dataset


class RandomDataset(Dataset):
    def __init__(self, size, length):
        self.len = length
        self.data = torch.randn(length, size)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return self.len


class Data(LightningDataModule):
    def __init__(self, ds_size: int, batch_size: int):
        super().__init__()
        self.ds_size = ds_size
        self.batch_size = batch_size

    def train_dataloader(self):
        return DataLoader(RandomDataset(27, self.ds_size), batch_size=self.batch_size)

    def val_dataloader(self):
        return DataLoader(RandomDataset(27, self.ds_size), batch_size=self.batch_size)


class BoringModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(27, 128)
        self.dropout = torch.nn.Dropout(0.5)
        self.val_sample_counter = 0

    def forward(self, x):
        return self.dropout(self.layer(x))

    def training_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("train_loss", loss)
        return {"loss": loss}

    def validation_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("valid_loss", loss)
        if not self.trainer.sanity_checking:
            self.val_sample_counter += len(batch)

    def test_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("test_loss", loss)

    def configure_optimizers(self):
        return torch.optim.SGD(self.layer.parameters(), lr=0.1)

    def on_validation_epoch_start(self) -> None:
        self.val_sample_counter = 0

    def on_validation_epoch_end(self) -> None:
        if not self.trainer.sanity_checking:
            print(f"VALIDATED {self.val_sample_counter} SAMPLES ON EPOCH {self.trainer.current_epoch}")


class CustomBatchSizeFinder(BatchSizeFinder):
    def on_validation_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        super().on_validation_start(trainer, pl_module)
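        # Workaround: restore the full dataloaders that BatchSizeFinder truncated for its trials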
        _reset_dataloaders(trainer)


def run():
    DATASET_SIZE = 123
    BATCH_SIZE = 2
    steps_per_trial = 3

    model = BoringModel()
    trainer = Trainer(
        default_root_dir=os.getcwd(),
        max_epochs=2,
        enable_progress_bar=False,
        enable_model_summary=False,
    )
    print("Without BatchSizeFinder:")
    trainer.fit(model, datamodule=Data(DATASET_SIZE, BATCH_SIZE))
    print("-" * 20)

    model = BoringModel()
    callbacks = [BatchSizeFinder(steps_per_trial=steps_per_trial)]
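    # Default BatchSizeFinder: after its trials, validation stays limited to steps_per_trial batches (the bug)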
    trainer = Trainer(
        default_root_dir=os.getcwd(),
        max_epochs=2,
        enable_progress_bar=False,
        enable_model_summary=False,
        callbacks=callbacks,
    )
    print("With BatchSizeFinder:")
    trainer.fit(model, datamodule=Data(DATASET_SIZE, BATCH_SIZE))
    print("-"*20)

    model = BoringModel()
    callbacks = [CustomBatchSizeFinder(steps_per_trial=steps_per_trial)]
    trainer = Trainer(
        default_root_dir=os.getcwd(),
        max_epochs=2,
        enable_progress_bar=False,
        enable_model_summary=False,
        callbacks=callbacks,
    )
    print("With BatchSizeFinder that calls to `_reset_dataloaders`:")
    trainer.fit(model, datamodule=Data(DATASET_SIZE, BATCH_SIZE))


if __name__ == "__main__":
    torch.manual_seed(1)
    run()

Error messages and logs

Here is a log showing the number of validated samples per epoch.
Val dataset size: 123, num epochs: 2, batch size: 2 (see code above).

Without BatchSizeFinder:
VALIDATED 123 SAMPLES ON EPOCH 0
VALIDATED 123 SAMPLES ON EPOCH 1
--------------------
With BatchSizeFinder:
VALIDATED 6 SAMPLES ON EPOCH 0
VALIDATED 6 SAMPLES ON EPOCH 0
VALIDATED 6 SAMPLES ON EPOCH 0
VALIDATED 6 SAMPLES ON EPOCH 0
VALIDATED 6 SAMPLES ON EPOCH 0
VALIDATED 6 SAMPLES ON EPOCH 0
VALIDATED 6 SAMPLES ON EPOCH 0
VALIDATED 6 SAMPLES ON EPOCH 0
VALIDATED 6 SAMPLES ON EPOCH 1
--------------------
With BatchSizeFinder that calls to `_reset_dataloaders`:
VALIDATED 6 SAMPLES ON EPOCH 0
VALIDATED 6 SAMPLES ON EPOCH 0
VALIDATED 6 SAMPLES ON EPOCH 0
VALIDATED 6 SAMPLES ON EPOCH 0
VALIDATED 6 SAMPLES ON EPOCH 0
VALIDATED 6 SAMPLES ON EPOCH 0
VALIDATED 6 SAMPLES ON EPOCH 0
VALIDATED 123 SAMPLES ON EPOCH 0
VALIDATED 123 SAMPLES ON EPOCH 1

As you can see, the first and last runs validated all 123 samples on each of the two epochs, while the second run (with the default BatchSizeFinder) validated only 6 samples on both epochs.
Here 6 = steps_per_trial * BATCH_SIZE = 3 * 2.

Environment

Current environment
  • CUDA:
    - GPU:
    - NVIDIA GeForce RTX 3050 Laptop GPU
    - available: True
    - version: 12.1
  • Lightning:
    - lightning: 2.1.0
    - lightning-utilities: 0.9.0
    - pytorch-lightning: 2.1.0
    - torch: 2.1.0
    - torchmetrics: 1.2.0
  • Packages:
    - aiohttp: 3.8.6
    - aiosignal: 1.3.1
    - async-timeout: 4.0.3
    - attrs: 23.1.0
    - certifi: 2023.7.22
    - charset-normalizer: 3.3.0
    - filelock: 3.12.4
    - frozenlist: 1.4.0
    - fsspec: 2023.9.2
    - idna: 3.4
    - jinja2: 3.1.2
    - lightning: 2.1.0
    - lightning-utilities: 0.9.0
    - markupsafe: 2.1.3
    - mpmath: 1.3.0
    - multidict: 6.0.4
    - networkx: 3.1
    - numpy: 1.24.4
    - nvidia-cublas-cu12: 12.1.3.1
    - nvidia-cuda-cupti-cu12: 12.1.105
    - nvidia-cuda-nvrtc-cu12: 12.1.105
    - nvidia-cuda-runtime-cu12: 12.1.105
    - nvidia-cudnn-cu12: 8.9.2.26
    - nvidia-cufft-cu12: 11.0.2.54
    - nvidia-curand-cu12: 10.3.2.106
    - nvidia-cusolver-cu12: 11.4.5.107
    - nvidia-cusparse-cu12: 12.1.0.106
    - nvidia-nccl-cu12: 2.18.1
    - nvidia-nvjitlink-cu12: 12.2.140
    - nvidia-nvtx-cu12: 12.1.105
    - packaging: 23.2
    - pip: 23.2.1
    - pytorch-lightning: 2.1.0
    - pyyaml: 6.0.1
    - requests: 2.31.0
    - setuptools: 68.1.2
    - sympy: 1.12
    - torch: 2.1.0
    - torchmetrics: 1.2.0
    - tqdm: 4.66.1
    - triton: 2.1.0
    - typing-extensions: 4.8.0
    - urllib3: 2.0.7
    - wheel: 0.41.2
    - yarl: 1.9.2
  • System:
    - OS: Linux
    - architecture:
    - 64bit
    - ELF
    - processor: x86_64
    - python: 3.8.18
    - release: 5.15.0-83-generic
    - version: #92-Ubuntu SMP Mon Aug 14 09:30:42 UTC 2023

More info

@tanaymeh can you maybe add a related fix to #18826? It seems to touch the same parts of the code and only requires a few additional lines.

BoringDonut added the bug and needs triage labels on Oct 20, 2023
@awaelchli (Contributor)

I think this is a duplicate of #18394. Can you confirm?

awaelchli added the duplicate and tuner labels and removed the needs triage label on Oct 21, 2023
@BoringDonut (Contributor, Author)

> I think this is a duplicate of #18394. Can you confirm?

Yes, indeed. Sorry for that.

@BoringDonut (Contributor, Author)

Duplicate, and fixed with #18854.
