
Batch Size Issue in MAISI Generative Model Configuration #1853

Open
gexinh opened this issue Oct 6, 2024 · 2 comments

gexinh commented Oct 6, 2024

Dear Dong Yang (@dongyang0122),

I hope this message finds you well. Thank you in advance for your time and support.

I am currently working with the MAISI generative model and am planning to accelerate training by increasing the batch size. However, I encountered an issue where, despite modifying the batch size in the configuration file, the DataLoader batch size remains set to 1.

Could you kindly advise on how to resolve this issue?

The log output is as follows:

[screenshot of the training log]
The log line is produced by the following code:

    if local_rank == 0:
        logger.info(
            "[{0}] epoch {1}, iter {2}/{3}, loss: {4:.4f}, lr: {5:.12f}.".format(
                str(datetime.now())[:19], epoch + 1, _iter, len(train_loader), loss.item(), current_lr
            )
        )

Note that the number of iterations per epoch equals the length of train_loader, and the training set contains 1000 samples. In my understanding, a larger batch size should decrease the length of train_loader. However, the length of train_loader is still 1000 (the size of the training set), which suggests the effective batch size is 1, as the quick check below illustrates.
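For reference, here is a minimal standalone sanity check of the expected relationship between batch size and loader length (a hypothetical example using a dummy dataset, not taken from the MAISI scripts):

    import torch
    from torch.utils.data import DataLoader, TensorDataset

    # 1000 dummy samples, mirroring the size of the training set above.
    ds = TensorDataset(torch.zeros(1000, 1))

    # len(loader) == ceil(num_samples / batch_size) when drop_last=False (the default).
    assert len(DataLoader(ds, batch_size=1)) == 1000
    assert len(DataLoader(ds, batch_size=4)) == 250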

Additionally, the corresponding data-loader code is in scripts.diff_model_train.py:

    def prepare_data(
        train_files: list, device: torch.device, cache_rate: float, num_workers: int = 2, batch_size: int = 1
    ) -> ThreadDataLoader:
        """
        Prepare training data.

        Args:
            train_files (list): List of training files.
            device (torch.device): Device to use for training.
            cache_rate (float): Cache rate for dataset.
            num_workers (int): Number of workers for data loading.
            batch_size (int): Mini-batch size.

        Returns:
            ThreadDataLoader: Data loader for training.
        """
        train_transforms = Compose(
            [
                monai.transforms.LoadImaged(keys=["image"]),
                monai.transforms.EnsureChannelFirstd(keys=["image"]),
                monai.transforms.Lambdad(
                    keys="top_region_index", func=lambda x: torch.FloatTensor(json.load(open(x))["top_region_index"])
                ),
                monai.transforms.Lambdad(
                    keys="bottom_region_index", func=lambda x: torch.FloatTensor(json.load(open(x))["bottom_region_index"])
                ),
                monai.transforms.Lambdad(keys="spacing", func=lambda x: torch.FloatTensor(json.load(open(x))["spacing"])),
                monai.transforms.Lambdad(keys="top_region_index", func=lambda x: x * 1e2),
                monai.transforms.Lambdad(keys="bottom_region_index", func=lambda x: x * 1e2),
                monai.transforms.Lambdad(keys="spacing", func=lambda x: x * 1e2),
            ]
        )

        train_ds = monai.data.CacheDataset(
            data=train_files, transform=train_transforms, cache_rate=cache_rate, num_workers=num_workers
        )
        return ThreadDataLoader(train_ds, num_workers=6, batch_size=batch_size, shuffle=True)

gexinh commented Oct 9, 2024

I found the problem:

In scripts.diff_model_train.py, line 51, the definition is as follows:

def prepare_data(
    train_files: list, device: torch.device, cache_rate: float, num_workers: int = 2, batch_size: int = 1
) -> ThreadDataLoader:

However, on line 359, the function call omits the num_workers argument, so the batch_size value is passed positionally into the num_workers slot and the loader falls back to the default batch_size=1:

    train_loader = prepare_data(
        train_files, device, args.diffusion_unet_train["cache_rate"], args.diffusion_unet_train["batch_size"]
    )
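One possible minimal fix (assuming the rest of the call site stays unchanged; the actual fix may differ in PR #1857) is to pass batch_size by keyword so it can no longer be consumed by the num_workers slot:

    train_loader = prepare_data(
        train_files,
        device,
        args.diffusion_unet_train["cache_rate"],
        batch_size=args.diffusion_unet_train["batch_size"],  # keyword argument avoids the positional mix-up
    )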

KumoLiu added a commit to KumoLiu/tutorials that referenced this issue Oct 9, 2024

KumoLiu commented Oct 9, 2024

Hi @gexinh, thanks for reporting this; it's fixed in this PR: #1857.
