From 4ea01c482cc38cfbc03adf522191960f7324dee7 Mon Sep 17 00:00:00 2001 From: Attila Afra Date: Tue, 11 Jun 2024 12:31:51 +0000 Subject: [PATCH] training: more efficient inference --- training/infer.py | 2 +- training/preprocess.py | 2 +- training/train.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/training/infer.py b/training/infer.py index 5746dbf9..0fefdb11 100755 --- a/training/infer.py +++ b/training/infer.py @@ -157,7 +157,7 @@ def save_images(path, image, image_srgb, num_channels, feature_ext=infer.main_fe else: save_image(filename_prefix + format, image_srgb, num_channels=num_channels[feature_ext]) - with torch.no_grad(): + with torch.inference_mode(): for group, input_names, target_name in image_sample_groups: # Create the output directory if it does not exist output_group_dir = os.path.join(output_dir, os.path.dirname(group)) diff --git a/training/preprocess.py b/training/preprocess.py index 27a8038b..dc8db995 100755 --- a/training/preprocess.py +++ b/training/preprocess.py @@ -139,7 +139,7 @@ def preprocess_dataset(data_name): save_config(output_dir, cfg) # Preprocess all datasets - with torch.no_grad(): + with torch.inference_mode(): for dataset in [cfg.train_data, cfg.valid_data]: if dataset: preprocess_dataset(dataset) diff --git a/training/train.py b/training/train.py index 6db48e48..ab1ba0ad 100755 --- a/training/train.py +++ b/training/train.py @@ -244,7 +244,7 @@ def main_worker(rank, cfg): valid_loss = 0. # Iterate over the batches - with torch.no_grad(): + with torch.inference_mode(): for _, batch in enumerate(valid_loader, 0): if cfg.device == 'cuda' and cfg.compile in {'reduce-overhead', 'max-autotune'}: torch.compiler.cudagraph_mark_step_begin()