Skip to content

Commit

Permalink
training: more efficient inference
Browse files Browse the repository at this point in the history
  • Loading branch information
atafra committed Jun 11, 2024
1 parent 716e687 commit 4ea01c4
Show file tree
Hide file tree
Showing 3 changed files with 3 additions and 3 deletions.
2 changes: 1 addition & 1 deletion training/infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ def save_images(path, image, image_srgb, num_channels, feature_ext=infer.main_fe
else:
save_image(filename_prefix + format, image_srgb, num_channels=num_channels[feature_ext])

- with torch.no_grad():
+ with torch.inference_mode():
for group, input_names, target_name in image_sample_groups:
# Create the output directory if it does not exist
output_group_dir = os.path.join(output_dir, os.path.dirname(group))
Expand Down
2 changes: 1 addition & 1 deletion training/preprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ def preprocess_dataset(data_name):
save_config(output_dir, cfg)

# Preprocess all datasets
- with torch.no_grad():
+ with torch.inference_mode():
for dataset in [cfg.train_data, cfg.valid_data]:
if dataset:
preprocess_dataset(dataset)
Expand Down
2 changes: 1 addition & 1 deletion training/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,7 +244,7 @@ def main_worker(rank, cfg):
valid_loss = 0.

# Iterate over the batches
- with torch.no_grad():
+ with torch.inference_mode():
for _, batch in enumerate(valid_loader, 0):
if cfg.device == 'cuda' and cfg.compile in {'reduce-overhead', 'max-autotune'}:
torch.compiler.cudagraph_mark_step_begin()
Expand Down

0 comments on commit 4ea01c4

Please sign in to comment.