DataSpec.device_transforms do not run on device #3699

Closed
Ghelfi opened this issue Nov 4, 2024 · 7 comments
Labels
bug Something isn't working

Comments

@Ghelfi
Contributor

Ghelfi commented Nov 4, 2024

Description

Since release 0.25.0, DataSpec.device_transforms no longer run on the device.

Two successive PRs changed, in inconsistent ways, where batches / microbatches are moved to the device and where device_transforms are applied:

  • The first PR moved device_transforms and the device copy to the microbatch level.
  • The second PR moved device_transforms back to before the batch is moved to the device.

As a result, device_transforms are applied per batch, but always on the CPU.
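
The ordering problem described above can be illustrated with a minimal, framework-free sketch. `FakeTensor`, `step_transform_then_move`, and `step_move_then_transform` are hypothetical stand-ins, not Composer code; they only model where the transform call sits relative to the device copy.

```python
from dataclasses import dataclass, replace
from typing import Callable, List


@dataclass(frozen=True)
class FakeTensor:
    """Hypothetical stand-in for a tensor; it only tracks its device."""
    device: str = "cpu"

    def to(self, device: str) -> "FakeTensor":
        return replace(self, device=device)


Transform = Callable[[FakeTensor], FakeTensor]


def recording_transform(seen: List[str]) -> Transform:
    """Build a transform that records the device of each batch it receives."""
    def _transform(batch: FakeTensor) -> FakeTensor:
        seen.append(batch.device)
        return batch
    return _transform


def step_transform_then_move(batch: FakeTensor, transform: Transform, device: str) -> FakeTensor:
    # Order described in this issue: transform first, copy to device afterwards.
    return transform(batch).to(device)


def step_move_then_transform(batch: FakeTensor, transform: Transform, device: str) -> FakeTensor:
    # Order the name `device_transforms` suggests: copy first, then transform.
    return transform(batch.to(device))


seen_regressed: List[str] = []
step_transform_then_move(FakeTensor(), recording_transform(seen_regressed), "cuda")

seen_expected: List[str] = []
step_move_then_transform(FakeTensor(), recording_transform(seen_expected), "cuda")

print(seen_regressed)  # ['cpu']
print(seen_expected)   # ['cuda']
```

In both cases the batch ends up on the device; the only difference is which device the transform itself observes.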

@Ghelfi Ghelfi added the bug Something isn't working label Nov 4, 2024
@Ghelfi
Contributor Author

Ghelfi commented Nov 4, 2024

Here is a snippet to reproduce the error:

from collections.abc import Callable
from typing import cast

import torch.nn.functional as F
from composer import Trainer
from composer.core import DataSpec
from composer.models import ComposerClassifier
from torch import Tensor, nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms


class Model(nn.Module):
    def __init__(self, num_classes: int = 10):
        super().__init__()

        self.num_classes = num_classes
        self.conv = nn.Conv2d(1, 16, (3, 3), padding=0)
        self.fc = nn.Linear(16, num_classes)

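    def forward(self, x: Tensor) -> Tensor:
        out = self.conv(x)
        out = F.relu(out)
        # Pool to shape (N, 16) so the input matches nn.Linear(16, num_classes).
        out = F.adaptive_avg_pool2d(out, 1).flatten(1)
        return cast(Tensor, self.fc(out))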


def get_device_transform(device: str) -> Callable[[tuple[Tensor, Tensor]], tuple[Tensor, Tensor]]:
    def _transform(batch: tuple[Tensor, Tensor]) -> tuple[Tensor, Tensor]:
        """Dummy on-device transform to check that `batch` is actually on the right device."""
        assert (
            batch[0].device.type == device
        ), f"found data on device {batch[0].device.type} while expecting device_transform to run on {device}"
        assert (
            batch[1].device.type == device
        ), f"found data on device {batch[1].device.type} while expecting device_transform to run on {device}"

        return batch

    return _transform


transform = transforms.Compose([transforms.ToTensor()])
dataset = datasets.MNIST("data", train=True, download=True, transform=transform)
train_dataloader = DataSpec(
    dataloader=DataLoader(dataset, batch_size=16),
    device_transforms=get_device_transform(device="cuda"),
    get_num_samples_in_batch=lambda batch: len(batch[0]),
)

trainer = Trainer(
    model=ComposerClassifier(module=Model(), num_classes=10),
    train_dataloader=train_dataloader,
    max_duration="2ep",
    device="gpu",
)
trainer.fit()

This outputs

AssertionError: found data on device cpu while expecting device_transform to run on cuda

So the device transforms run on the CPU.

This requires a GPU to run, but replacing both "cuda" and "gpu" with "mps" reproduces the error locally on a Mac accelerator.

We currently measure a ~10x slowdown on heavy transform pipelines.
If this is reproducible, it would be nice to ship a patch rather than wait for the next release cycle, since this appears to be a regression.

@Ghelfi
Contributor Author

Ghelfi commented Nov 4, 2024

@mvpatel2000 I saw that you patched some algorithms to move data to the device ahead of time.

Is this something that should be done for every transform?
That would be unfortunate, since it breaks the native switch between on-CPU and on-device transforms.

@mvpatel2000
Contributor

Thanks for flagging this!

The original motivation here is that large batches of images and multimodal data add significant memory pressure if we do not move them incrementally per microbatch.

We decided to avoid transforms at microbatch level because transforms may use batch statistics, and at the time, we didn't think there would be many tradeoffs to leaving it on CPU (our workloads are rarely CPU bottlenecked). Given this issue, I see a few options:

  • move transforms to per microbatch -- we originally thought leaving them on CPU would be safer than moving per microbatch, but perhaps this is wrong
  • add a flag to revert to the old behavior (move each batch)
  • add a transform in Composer that users can import to move the batch (similar to option 2)

We are still discussing, but if you have any feedback @Ghelfi feel free to chime in. I'm especially curious if option 1 affects you
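
For illustration, option 3 could look roughly like the sketch below. This is not an existing Composer helper: `move_to_device` and `on_device` are hypothetical names, and the code is duck-typed on a `.to(device)` method (which `torch.Tensor` provides) so that it runs here with a small stand-in class instead of real tensors.

```python
from typing import Any, Callable


def move_to_device(batch: Any, device: str) -> Any:
    """Recursively move anything exposing `.to(device)`, e.g. a torch.Tensor."""
    if isinstance(batch, tuple):
        return tuple(move_to_device(item, device) for item in batch)
    if isinstance(batch, list):
        return [move_to_device(item, device) for item in batch]
    if isinstance(batch, dict):
        return {key: move_to_device(value, device) for key, value in batch.items()}
    if hasattr(batch, "to"):
        return batch.to(device)
    return batch  # leave other leaves (ints, strings, ...) untouched


def on_device(transform: Callable[[Any], Any], device: str) -> Callable[[Any], Any]:
    """Wrap `transform` so it always receives a batch already on `device`."""
    def _wrapped(batch: Any) -> Any:
        return transform(move_to_device(batch, device))
    return _wrapped


class FakeTensor:
    """Hypothetical stand-in for a tensor; it only tracks its device."""
    def __init__(self, device: str = "cpu") -> None:
        self.device = device

    def to(self, device: str) -> "FakeTensor":
        return FakeTensor(device)


def check_devices(batch: Any) -> Any:
    # Plays the role of a real transform; it just reports what it sees.
    return tuple(item.device for item in batch)


wrapped = on_device(check_devices, "cuda")
print(wrapped((FakeTensor(), FakeTensor())))  # ('cuda', 'cuda')
```

With real tensors, the wrapped callable would presumably be passed as `DataSpec(device_transforms=on_device(my_transform, "cuda"))`. Note that this moves the whole batch at once, so it brings back the memory pressure that motivated the per-microbatch copy.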

@Ghelfi
Contributor Author

Ghelfi commented Nov 4, 2024

We do have some transforms that are coherent across a batch, meaning the same parameters are applied to the whole batch. Moving the transforms to per-microbatch might change this. This more or less falls under what you describe as "batch statistics".
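
A minimal sketch of why this matters, using plain Python lists instead of tensors. `coherent_scale` and `split` are hypothetical helpers invented for this illustration; the point is only that a transform drawing its parameters once per call yields one parameter set per batch, but several when it is called per microbatch.

```python
import random
from typing import List, Sequence


def coherent_scale(samples: Sequence[float], rng: random.Random) -> List[float]:
    """Draw ONE random scale and apply it to every sample in this call."""
    scale = rng.uniform(0.5, 1.5)
    return [x * scale for x in samples]


def split(samples: Sequence[float], microbatch_size: int) -> List[Sequence[float]]:
    """Cut a batch into consecutive microbatches."""
    return [samples[i:i + microbatch_size] for i in range(0, len(samples), microbatch_size)]


batch = [1.0] * 4

# Applied once per batch: every sample shares the same scale.
per_batch = coherent_scale(batch, random.Random(0))

# Applied once per microbatch: each microbatch draws its own scale.
rng = random.Random(0)
per_microbatch = [y for mb in split(batch, 2) for y in coherent_scale(mb, rng)]

print(len(set(per_batch)))       # 1
print(len(set(per_microbatch)))  # 2
```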

Options 2 and 3 could work. An opt-in trainer flag stating that the move happens at the batch level would let people have it both ways.

Did you keep the name device_transforms for backward compatibility? Wouldn't batched_transforms be a less confusing name?

@mvpatel2000
Contributor

We do have some transforms that are coherent across a batch, meaning the same parameters are applied to the whole batch. Moving the transforms to per-microbatch might change this. This more or less falls under what you describe as "batch statistics".

How intensive are these, and do you care about ordering? For example, what if we had batch_transforms run on CPU and microbatch_transforms run on GPU?

Did you keep the name device_transforms for backward compatibility? Wouldn't batched_transforms be a less confusing name?

Yes, primarily because of this

@Ghelfi
Contributor Author

Ghelfi commented Nov 5, 2024

We use intensive transforms that are well accelerated on GPU.

I'd be more in favor of a flag that enables the transfer at the batch level as an opt-in. This gives an easy way to fall back to the previous behaviour, is opt-in, and enables "true" device_transforms.

If breaking backward compatibility is OK, we could have batch_transforms on CPU and device_transforms coupled with the data transfer at either the batch or the microbatch level, depending on a flag data_transfer_stage: Literal["batch", "microbatch"] = "microbatch".
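
The proposal above could be sketched as follows. Everything here is hypothetical: `Batch` and `prepare_microbatches` are stand-ins written for this illustration, and only the flag name `data_transfer_stage` and its two values come from the proposal. It is not how Composer implements anything.

```python
from dataclasses import dataclass, replace
from typing import Callable, List, Literal, Tuple


@dataclass(frozen=True)
class Batch:
    """Hypothetical stand-in for a batch of tensors; tracks samples and device."""
    samples: Tuple[int, ...]
    device: str = "cpu"

    def to(self, device: str) -> "Batch":
        return replace(self, device=device)

    def split(self, size: int) -> List["Batch"]:
        return [replace(self, samples=self.samples[i:i + size])
                for i in range(0, len(self.samples), size)]


Transform = Callable[[Batch], Batch]


def prepare_microbatches(
    batch: Batch,
    batch_transforms: Transform,
    device_transforms: Transform,
    device: str,
    microbatch_size: int,
    data_transfer_stage: Literal["batch", "microbatch"] = "microbatch",
) -> List[Batch]:
    batch = batch_transforms(batch)  # always on CPU, sees the whole batch
    if data_transfer_stage == "batch":
        # Opt-in: one copy, device transforms see the whole batch on device.
        return device_transforms(batch.to(device)).split(microbatch_size)
    if data_transfer_stage == "microbatch":
        # Default: copy and transform per microbatch to limit device memory.
        # (A real training loop would do this lazily, one microbatch at a time.)
        return [device_transforms(mb.to(device)) for mb in batch.split(microbatch_size)]
    raise ValueError(f"unknown data_transfer_stage: {data_transfer_stage!r}")


seen: List[Tuple[int, str]] = []


def record(b: Batch) -> Batch:
    """Device transform that records the size and device of what it receives."""
    seen.append((len(b.samples), b.device))
    return b


def identity(b: Batch) -> Batch:
    return b


batch = Batch(samples=(0, 1, 2, 3))
prepare_microbatches(batch, identity, record, "cuda", 2, "batch")
prepare_microbatches(batch, identity, record, "cuda", 2, "microbatch")
print(seen)  # [(4, 'cuda'), (2, 'cuda'), (2, 'cuda')]
```

Under either setting the device transform runs on the device; the flag only decides whether it sees the whole batch or one microbatch at a time.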

@Ghelfi
Contributor Author

Ghelfi commented Jan 6, 2025

Closing since handled in #3703

@Ghelfi Ghelfi closed this as completed Jan 6, 2025