loading_faster.py
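"""Benchmark ImageNet data-loading throughput with an NVIDIA DALI pipeline.

Builds a DALI training pipeline (sharded file reading, JPEG decode,
augmentation, normalization), wraps it in a DALIClassificationIterator, and
iterates over the dataset while logging per-batch loading times. No model is
trained; the script only measures how fast batches can be produced.
"""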
import argparse
import os

import torch
import torch.distributed as dist
from nvidia.dali.plugin.pytorch import DALIClassificationIterator, LastBatchPolicy
from nvidia.dali.pipeline import pipeline_def
import nvidia.dali.types as types
import nvidia.dali.fn as fn

from utils import make_logger, Benchmark

logger = make_logger('imagenet', 'logs')


def parse_args():
    parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')
    parser.add_argument('data', metavar='DIR', help='path to dataset')
    parser.add_argument('--local_rank', metavar='RANK', type=int, default=0)
    parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
                        help='number of data loading workers (default: 4)')
    parser.add_argument('--epochs', default=90, type=int, metavar='N',
                        help='number of total epochs to run')
    parser.add_argument('-b', '--batch-size', default=256, type=int,
                        metavar='N',
                        help='mini-batch size (default: 256), this is the total '
                             'batch size of all GPUs on the current node when '
                             'using Data Parallel or Distributed Data Parallel')
    parser.add_argument('--dali_cpu', action='store_true',
                        help='Runs CPU based version of DALI pipeline.')
    return parser.parse_args()
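

# DALI pipeline: read files sharded by rank, decode JPEGs (on the GPU via the
# "mixed" device unless --dali_cpu is set), apply training or validation
# transforms, and normalize with ImageNet statistics.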
@pipeline_def
def create_dali_pipeline(data_dir, crop, size, shard_id, num_shards, dali_cpu=False, is_training=True):
    images, labels = fn.readers.file(file_root=data_dir,
                                     shard_id=shard_id,
                                     num_shards=num_shards,
                                     random_shuffle=is_training,
                                     pad_last_batch=True,
                                     name="Reader")
    dali_device = 'cpu' if dali_cpu else 'gpu'
    decoder_device = 'cpu' if dali_cpu else 'mixed'
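    # Pre-allocation hints for the nvJPEG ("mixed") decoder to avoid buffer
    # re-allocations on large images; unused on the CPU path.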
    device_memory_padding = 211025920 if decoder_device == 'mixed' else 0
    host_memory_padding = 140544512 if decoder_device == 'mixed' else 0
    if is_training:
        # Training: random-area crop during decode, resize to crop x crop,
        # and a 50% chance of horizontal mirroring.
        images = fn.decoders.image_random_crop(images,
                                               device=decoder_device, output_type=types.RGB,
                                               device_memory_padding=device_memory_padding,
                                               host_memory_padding=host_memory_padding,
                                               random_aspect_ratio=[0.8, 1.25],
                                               random_area=[0.1, 1.0],
                                               num_attempts=100)
        images = fn.resize(images,
                           device=dali_device,
                           resize_x=crop,
                           resize_y=crop,
                           interp_type=types.INTERP_TRIANGULAR)
        mirror = fn.random.coin_flip(probability=0.5)
    else:
        # Validation: full decode, resize so the shorter side matches `size`,
        # and no mirroring; the center crop happens below.
        images = fn.decoders.image(images,
                                   device=decoder_device,
                                   output_type=types.RGB)
        images = fn.resize(images,
                           device=dali_device,
                           size=size,
                           mode="not_smaller",
                           interp_type=types.INTERP_TRIANGULAR)
        mirror = False
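    # Crop to the final size (a no-op after the training resize), optionally
    # mirror, convert to CHW float, and normalize with ImageNet mean/std in the
    # 0-255 range. images.gpu() keeps normalization on the GPU even when the
    # decode ran on the CPU.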
    images = fn.crop_mirror_normalize(images.gpu(),
                                      dtype=types.FLOAT,
                                      output_layout="CHW",
                                      crop=(crop, crop),
                                      mean=[0.485 * 255, 0.456 * 255, 0.406 * 255],
                                      std=[0.229 * 255, 0.224 * 255, 0.225 * 255],
                                      mirror=mirror)
    labels = labels.gpu()
    return images, labels
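

# Iterate over the loader and log the time spent fetching each batch; this is a
# pure data-loading benchmark, so no forward or backward pass is performed.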
def train(train_loader, epoch, args):
    load_time = Benchmark()
    for i, data in enumerate(train_loader):
        images = data[0]["data"]
        target = data[0]["label"].squeeze(-1).long()
        logger.info(f'Epoch #{epoch} [{i}/{len(train_loader)}] {load_time.elapsed():>.3f}')
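

# One process per GPU: initialize NCCL, pin this process to its GPU, shard the
# dataset by rank, and split the requested batch size across the visible GPUs.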
if __name__ == '__main__':
    assert torch.cuda.is_available(), 'CUDA IS NOT AVAILABLE!!'
    args = parse_args()
    args.batch_size = int(args.batch_size / torch.cuda.device_count())
    logger.info(args)
    dist.init_process_group('nccl')
    torch.cuda.set_device(args.local_rank)
    pipe = create_dali_pipeline(batch_size=args.batch_size,
                                num_threads=args.workers,
                                device_id=args.local_rank,
                                seed=12 + args.local_rank,
                                data_dir=os.path.join(args.data, 'train'),
                                crop=224,
                                size=256,
                                dali_cpu=args.dali_cpu,
                                shard_id=args.local_rank,
                                num_shards=dist.get_world_size(),
                                is_training=True)
    pipe.build()
    # reader_name lets the iterator size each epoch from the "Reader" op;
    # PARTIAL keeps the last, smaller batch instead of dropping or padding it.
    train_loader = DALIClassificationIterator(pipe, reader_name="Reader", last_batch_policy=LastBatchPolicy.PARTIAL)
    benchmark = Benchmark()
    for epoch in range(0, args.epochs):
        train(train_loader, epoch, args)
        # The DALI iterator does not rewind automatically; reset it before the next epoch.
        train_loader.reset()
    logger.info(f'{benchmark.elapsed():>.3f}')
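
# NOTE: the script expects one worker process per GPU. A typical launch
# (assumed here, not prescribed by this file) uses the legacy launcher, which
# supplies --local_rank to each process, e.g.:
#   python -m torch.distributed.launch --nproc_per_node=<num_gpus> \
#       loading_faster.py /path/to/imagenet
# where /path/to/imagenet is a placeholder for a directory with a train/ split.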