Enhance readme for ddp cases in ldm tutorials #1857

Merged · 8 commits · Oct 9, 2024
3 changes: 3 additions & 0 deletions generation/2d_ldm/README.md
@@ -57,6 +57,7 @@ torchrun \
--master_addr=localhost --master_port=1234 \
train_autoencoder.py -c ./config/config_train_32g.json -e ./config/environment.json -g ${NUM_GPUS_PER_NODE}
```
Please note that during multi-GPU training, additional GPU memory may be required. Users might need to reduce the `batch_size` accordingly based on their available resources to ensure smooth training.

<p align="center">
<img src="./figs/train_recon.png" alt="autoencoder train curve" width="45%" >
@@ -88,6 +89,8 @@ torchrun \
--master_addr=localhost --master_port=1234 \
train_diffusion.py -c ./config/config_train_32g.json -e ./config/environment.json -g ${NUM_GPUS_PER_NODE}
```
Please note that during multi-GPU training, additional GPU memory may be required. Users might need to reduce the `batch_size` accordingly based on their available resources to ensure smooth training.

<p align="center">
<img src="./figs/train_diffusion.png" alt="latent diffusion train curve" width="45%" >
&nbsp; &nbsp; &nbsp; &nbsp;
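The batch-size note added to this README (and its twin in the 3d_ldm README below) asks users to lower `batch_size` in the training config. As a purely illustrative sketch, assuming `config_train_32g.json` exposes a top-level `batch_size` key (the real key name and nesting may differ, so check your config first):

```python
# Hypothetical helper: halve the batch size in a tutorial training config before
# launching a multi-GPU run. Assumes a top-level "batch_size" key, which may not
# match the actual config layout.
import json

cfg_path = "./config/config_train_32g.json"
with open(cfg_path) as f:
    cfg = json.load(f)

cfg["batch_size"] = max(1, cfg.get("batch_size", 1) // 2)

with open(cfg_path, "w") as f:
    json.dump(cfg, f, indent=4)

print(f"batch_size is now {cfg['batch_size']}")
```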
2 changes: 2 additions & 0 deletions generation/3d_ldm/README.md
@@ -61,6 +61,7 @@ torchrun \
--master_addr=localhost --master_port=1234 \
train_autoencoder.py -c ./config/config_train_32g.json -e ./config/environment.json -g ${NUM_GPUS_PER_NODE}
```
Please note that during multi-GPU training, additional GPU memory may be required. Users might need to reduce the `batch_size` accordingly based on their available resources to ensure smooth training.

<p align="center">
<img src="./figs/train_recon.png" alt="autoencoder train curve" width="45%" >
@@ -87,6 +88,7 @@ torchrun \
--master_addr=localhost --master_port=1234 \
train_diffusion.py -c ./config/config_train_32g.json -e ./config/environment.json -g ${NUM_GPUS_PER_NODE}
```
Please note that during multi-GPU training, additional GPU memory may be required. Users might need to reduce the `batch_size` accordingly based on their available resources to ensure smooth training.
<p align="center">
<img src="./figs/train_diffusion.png" alt="latent diffusion train curve" width="45%" >
&nbsp; &nbsp; &nbsp; &nbsp;
4 changes: 3 additions & 1 deletion generation/maisi/maisi_diff_unet_training_tutorial.ipynb
@@ -429,7 +429,9 @@
"\n",
"After all latent features have been created, we will initiate the multi-GPU script to train the latent diffusion model.\n",
"\n",
"The image generation process utilizes the [DDPM scheduler](https://arxiv.org/pdf/2006.11239) with 1,000 iterative steps. The diffusion model is optimized using L1 loss and a decayed learning rate scheduler. The batch size for this process is set to 1."
"The image generation process utilizes the [DDPM scheduler](https://arxiv.org/pdf/2006.11239) with 1,000 iterative steps. The diffusion model is optimized using L1 loss and a decayed learning rate scheduler. The batch size for this process is set to 1.\n",
"\n",
"Please be aware that using the H100 GPU may occasionally result in random segmentation faults. To avoid this issue, you can disable AMP by setting the `--no_amp` flag."
]
},
{
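The `--no_amp` note above maps onto the new `amp` argument of `diff_model_train` shown in the script diff below. A hypothetical sketch of calling that entry point with AMP disabled (the config paths are placeholders, not the tutorial's actual values):

```python
# Hypothetical usage sketch: run MAISI diffusion UNet training with AMP disabled.
# Equivalent CLI: add the --no_amp flag to the usual diff_model_train.py command.
from scripts.diff_model_train import diff_model_train

diff_model_train(
    env_config_path="./configs/environment_config.json",  # placeholder path
    model_config_path="./configs/model_config.json",      # placeholder path
    model_def_path="./configs/config_maisi.json",         # argparse default shown below
    num_gpus=1,
    amp=False,  # disable automatic mixed precision (works around H100 segfaults)
)
```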
44 changes: 27 additions & 17 deletions generation/maisi/scripts/diff_model_train.py
@@ -24,7 +24,7 @@
from torch.nn.parallel import DistributedDataParallel

import monai
from monai.data import ThreadDataLoader, partition_dataset
from monai.data import DataLoader, partition_dataset
from monai.transforms import Compose
from monai.utils import first

@@ -50,7 +50,7 @@ def load_filenames(data_list_path: str) -> list:

def prepare_data(
train_files: list, device: torch.device, cache_rate: float, num_workers: int = 2, batch_size: int = 1
) -> ThreadDataLoader:
) -> DataLoader:
"""
Prepare training data.

@@ -62,7 +62,7 @@ def prepare_data(
batch_size (int): Mini-batch size.

Returns:
ThreadDataLoader: Data loader for training.
DataLoader: Data loader for training.
"""

def _load_data_from_file(file_path, key):
@@ -90,7 +90,7 @@ def _load_data_from_file(file_path, key):
data=train_files, transform=train_transforms, cache_rate=cache_rate, num_workers=num_workers
)

return ThreadDataLoader(train_ds, num_workers=6, batch_size=batch_size, shuffle=True)
return DataLoader(train_ds, num_workers=6, batch_size=batch_size, shuffle=True)


def load_unet(args: argparse.Namespace, device: torch.device, logger: logging.Logger) -> torch.nn.Module:
@@ -124,14 +124,12 @@ def load_unet(args: argparse.Namespace, device: torch.device, logger: logging.Lo
return unet


def calculate_scale_factor(
train_loader: ThreadDataLoader, device: torch.device, logger: logging.Logger
) -> torch.Tensor:
def calculate_scale_factor(train_loader: DataLoader, device: torch.device, logger: logging.Logger) -> torch.Tensor:
"""
Calculate the scaling factor for the dataset.

Args:
train_loader (ThreadDataLoader): Data loader for training.
train_loader (DataLoader): Data loader for training.
device (torch.device): Device to use for calculation.
logger (logging.Logger): Logger for logging information.

@@ -181,7 +179,7 @@ def create_lr_scheduler(optimizer: torch.optim.Optimizer, total_steps: int) -> t
def train_one_epoch(
epoch: int,
unet: torch.nn.Module,
train_loader: ThreadDataLoader,
train_loader: DataLoader,
optimizer: torch.optim.Optimizer,
lr_scheduler: torch.optim.lr_scheduler.PolynomialLR,
loss_pt: torch.nn.L1Loss,
@@ -193,14 +191,15 @@
device: torch.device,
logger: logging.Logger,
local_rank: int,
amp: bool = True,
) -> torch.Tensor:
"""
Train the model for one epoch.

Args:
epoch (int): Current epoch number.
unet (torch.nn.Module): UNet model.
train_loader (ThreadDataLoader): Data loader for training.
train_loader (DataLoader): Data loader for training.
optimizer (torch.optim.Optimizer): Optimizer.
lr_scheduler (torch.optim.lr_scheduler.PolynomialLR): Learning rate scheduler.
loss_pt (torch.nn.L1Loss): Loss function.
@@ -212,6 +211,7 @@
device (torch.device): Device to use for training.
logger (logging.Logger): Logger for logging information.
local_rank (int): Local rank for distributed training.
amp (bool): Use automatic mixed precision training.

Returns:
torch.Tensor: Training loss for the epoch.
@@ -237,7 +237,7 @@

optimizer.zero_grad(set_to_none=True)

with autocast("cuda", enabled=True):
with autocast("cuda", enabled=amp):
noise = torch.randn(
(num_images_per_batch, 4, images.size(-3), images.size(-2), images.size(-1)), device=device
)
@@ -256,9 +256,13 @@

loss = loss_pt(noise_pred.float(), noise.float())

scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
if amp:
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
else:
loss.backward()
optimizer.step()

lr_scheduler.step()

@@ -312,14 +316,18 @@ def save_checkpoint(
)


def diff_model_train(env_config_path: str, model_config_path: str, model_def_path: str, num_gpus: int) -> None:
def diff_model_train(
env_config_path: str, model_config_path: str, model_def_path: str, num_gpus: int, amp: bool = True
) -> None:
"""
Main function to train a diffusion model.

Args:
env_config_path (str): Path to the environment configuration file.
model_config_path (str): Path to the model configuration file.
model_def_path (str): Path to the model definition file.
num_gpus (int): Number of GPUs to use for training.
amp (bool): Use automatic mixed precision training.
"""
args = load_config(env_config_path, model_config_path, model_def_path)
local_rank, world_size, device = initialize_distributed(num_gpus)
@@ -357,7 +365,7 @@ def diff_model_train(env_config_path: str, model_config_path: str, model_def_pat
)[local_rank]

train_loader = prepare_data(
train_files, device, args.diffusion_unet_train["cache_rate"], args.diffusion_unet_train["batch_size"]
train_files, device, args.diffusion_unet_train["cache_rate"], batch_size=args.diffusion_unet_train["batch_size"]
)

unet = load_unet(args, device, logger)
@@ -392,6 +400,7 @@ def diff_model_train(env_config_path: str, model_config_path: str, model_def_pat
device,
logger,
local_rank,
amp=amp,
)

loss_torch = loss_torch.tolist()
@@ -431,6 +440,7 @@ def diff_model_train(env_config_path: str, model_config_path: str, model_def_pat
"--model_def", type=str, default="./configs/config_maisi.json", help="Path to model definition file"
)
parser.add_argument("--num_gpus", type=int, default=1, help="Number of GPUs to use for training")
parser.add_argument("--no_amp", dest="amp", action="store_false", help="Disable automatic mixed precision training")

args = parser.parse_args()
diff_model_train(args.env_config, args.model_config, args.model_def, args.num_gpus)
diff_model_train(args.env_config, args.model_config, args.model_def, args.num_gpus, args.amp)
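For readers skimming the diff, the substance of the script change is the `amp` switch threaded through the training step. A minimal, self-contained sketch of that pattern under the assumed recent-PyTorch `torch.amp` API, with a toy linear model standing in for the MAISI UNet:

```python
# Minimal sketch of the AMP on/off training-step pattern added above (toy model only).
import torch
from torch.amp import GradScaler, autocast


def train_step(model, batch, target, optimizer, loss_fn, scaler, amp=True):
    optimizer.zero_grad(set_to_none=True)
    with autocast("cuda", enabled=amp):
        pred = model(batch)
        loss = loss_fn(pred.float(), target.float())
    if amp:
        # Scale the loss so fp16 gradients do not underflow, then step via the scaler.
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
    else:
        loss.backward()
        optimizer.step()
    return loss.detach()


if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"
    use_amp = device == "cuda"
    model = torch.nn.Linear(8, 8).to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
    scaler = GradScaler("cuda", enabled=use_amp)
    x = torch.randn(4, 8, device=device)
    y = torch.randn(4, 8, device=device)
    loss = train_step(model, x, y, optimizer, torch.nn.L1Loss(), scaler, amp=use_amp)
    print(f"loss: {loss.item():.4f}")
```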