-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathlogging.py
393 lines (334 loc) · 14.8 KB
/
logging.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
import datetime
import logging
import os
import sys
import time
from typing import Optional, TextIO
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.distributed as dist
from torch.distributed.algorithms.join import Join, Joinable, JoinHook
from .utils import _calc_plot_dim, _get_distributed
def setup_file_logger(save_path: str, logger_name: str, first_line: str = ""):
    """Configure a logger that writes records to a file and warnings to the console.

    Arguments:
        save_path: Path of the log file.
        logger_name: Name under which the logger is registered.
        first_line: Message logged immediately after setup.

    Returns:
        The configured ``logging.Logger``.
    """
    logger = logging.getLogger(logger_name)
    # Drop any handlers left over from a previous setup of the same logger name
    logger.handlers = []
    file_handler = logging.FileHandler(save_path)
    file_handler.setLevel(logging.DEBUG)
    console_handler = logging.StreamHandler()
    console_handler.setLevel(logging.WARNING)
    for handler in (file_handler, console_handler):
        logger.addHandler(handler)
    # INFO records reach the file handler; only WARNING and above reach the console
    logger.setLevel(logging.INFO)
    logger.info(first_line)
    return logger
class SyncedLoss(Joinable):
"""Gather loss values to a list that is averaged over parallel ranks.
Arguments:
num_losses: Number of different loss values.
"""
def __init__(self, num_losses: int):
super().__init__()
self.num_losses = num_losses
self.world_size, self.local_rank, self.global_rank, self.group = _get_distributed()
self.reset()
def reset(self):
"""Empty list of losses"""
self.losses = []
self.n_batches = 0
def __len__(self):
return len(self.losses)
def __getitem__(self, index: int):
return self.losses[index]
def mean(self) -> np.ndarray:
"""Get average loss over batches."""
return np.mean(self.losses, axis=0)
@property
def join_process_group(self):
return self.group
@property
def join_device(self):
return self.local_rank
def join_hook(self, **kwargs):
return _SyncedLossJoinHook(self)
def _sync_losses(self, losses, shadow=False):
assert self.world_size > 1, self.world_size
if not shadow:
assert len(losses) == self.num_losses, losses
# We haven't joined yet
Join.notify_join_context(self)
# Count non-joined ranks
world_size_eff = torch.ones(1, device=self.local_rank)
dist.all_reduce(world_size_eff, op=dist.ReduceOp.SUM)
# Sum losses over non-joined ranks
losses = torch.tensor(losses, device=self.local_rank)
dist.all_reduce(losses, op=dist.ReduceOp.SUM)
else: # We joined already, so shadow the reduce operations
# Don't count towards non-joined ranks
world_size_eff = torch.zeros(1, device=self.local_rank)
dist.all_reduce(world_size_eff, op=dist.ReduceOp.SUM)
# Also don't count towards sum of losses
losses = torch.zeros(self.num_losses, device=self.local_rank)
dist.all_reduce(losses, op=dist.ReduceOp.SUM)
# Add averaged losses to list
losses /= world_size_eff
losses = list(losses.cpu().numpy())
self.losses.append(losses)
return losses
def append(self, losses: torch.Tensor | np.ndarray | float | list[float]):
"""
Append a new batch of loss values.
Arguments:
losses: Loss values. Length should match ``self.num_losses``.
"""
if not isinstance(losses, list):
losses = [losses]
losses_ = []
for loss in losses:
if isinstance(loss, torch.Tensor):
if loss.size() == ():
losses_.append(loss.item())
else:
losses_ += list(loss.cpu().detach().numpy())
elif isinstance(loss, np.ndarray):
if loss.size == 1:
losses_.append(loss.item())
else:
losses_ += list(loss)
elif isinstance(loss, (int, float)):
losses_.append(loss)
else:
raise ValueError(f"Loss has unsupported type `{type(loss)}`")
losses = losses_
if np.isnan(losses).any():
raise ValueError(
f"Found a nan in losses ({losses}) at rank {self.global_rank} after {len(self.losses)} batches. "
f"Some of the previous losses were: {self.losses[-5:]}"
)
if self.world_size > 1:
losses = self._sync_losses(losses)
else:
self.losses.append(losses)
self.n_batches += 1
return losses
class _SyncedLossJoinHook(JoinHook):
"""Hook for when the number of batches does not match between processes."""
def __init__(self, synced_loss):
self.synced_loss = synced_loss
def main_hook(self):
self.synced_loss._sync_losses([], shadow=True)
def post_hook(self, is_last_joiner):
pass
class LossLogPlot:
"""
Log and plot model training loss history. Add losses for each batch with :meth:`add_train_loss` and :meth:`add_train_loss`,
and at the end of each epoch call :meth:`next_epoch` to print the status. Works with distributed training.
Arguments:
log_path: Path where loss log is saved.
plot_path: Path where plot of loss history is saved.
loss_labels: Labels for different loss components. If length > 1, an additional component ``'Total'`` is prepended to the list.
loss_weights: Weights for different loss components when there is more than one.
print_interval: Loss values are printed every **print_interval** batches.
init_epoch: Initial epoch. If not None and existing log has more epochs, discard them.
stream: Stream where log is printed to.
"""
def __init__(
self,
log_path: str,
plot_path: str,
loss_labels: list[str],
loss_weights: Optional[list[float]] = None,
print_interval: int = 10,
init_epoch: Optional[int] = None,
stream: TextIO = sys.stdout,
):
self.log_path = log_path
self.plot_path = plot_path
self.print_interval = print_interval
self.stream = stream
if len(loss_labels) > 1:
loss_labels = ["Total"] + loss_labels
self.loss_labels = loss_labels
if loss_weights is None or len(loss_weights) == 0:
loss_weights = [""] * len(self.loss_labels)
else:
if len(loss_labels) == 1:
assert len(loss_weights) == 1
else:
assert len(loss_weights) == (len(loss_labels) - 1)
loss_weights = [""] + loss_weights
self.loss_weights = loss_weights
self.train_losses = np.empty((0, len(loss_labels)))
self.val_losses = np.empty((0, len(loss_labels)))
self.world_size, self.local_rank, self.global_rank, _ = _get_distributed()
self.epoch = 1
self._synced_losses = {"train": SyncedLoss(len(self.loss_labels)), "val": SyncedLoss(len(self.loss_labels))}
self._init_log(init_epoch)
def _init_log(self, init_epoch: Optional[int]):
log_exists = os.path.isfile(self.log_path)
if self.world_size > 1:
dist.barrier()
if not log_exists:
if self.global_rank > 0:
return
self._write_log()
print(f"Created log at {self.log_path}", file=self.stream, flush=True)
else:
with open(self.log_path, "r") as f:
header = f.readline().rstrip("\r\n").split(";")
hl = (len(header) - 1) // 2
if len(self.loss_labels) != hl:
raise ValueError(
f"The length of the given list of loss names and the length of the header of the existing log at {self.log_path} do not match."
)
for line in f:
if init_epoch is not None and self.epoch >= init_epoch:
break
line = line.rstrip("\n").split(";")
if len(line) < 3:
continue
self.train_losses = np.append(self.train_losses, [[float(s) for s in line[1 : hl + 1]]], axis=0)
self.val_losses = np.append(self.val_losses, [[float(s) for s in line[hl + 1 :]]], axis=0)
self.epoch += 1
if self.global_rank == 0:
if init_epoch is not None:
self._write_log() # Make sure there are no additional rows in the log
print(f"Using existing log at {self.log_path}", file=self.stream, flush=True)
def _write_log(self):
with open(self.log_path, "w") as f:
f.write("epoch")
for i, label in enumerate(self.loss_labels):
label = f";train_{label}"
if self.loss_weights[i]:
label += f" (x {self.loss_weights[i]})"
f.write(label)
for i, label in enumerate(self.loss_labels):
label = f";val_{label}"
if self.loss_weights[i]:
label += f" (x {self.loss_weights[i]})"
f.write(label)
f.write("\n")
for epoch, (train_loss, val_loss) in enumerate(zip(self.train_losses, self.val_losses)):
f.write(str(epoch + 1))
for l in train_loss:
f.write(f";{l}")
for l in val_loss:
f.write(f";{l}")
f.write("\n")
def _add_loss(self, losses: torch.Tensor | np.ndarray | float | list[float], mode: str="train"):
synced_loss = self._synced_losses[mode]
losses = synced_loss.append(losses)
if len(losses) != len(self.loss_labels):
raise ValueError(f"Length of losses ({len(losses)}) does not match with number of loss labels ({len(self.loss_labels)}).")
if self.global_rank == 0 and len(synced_loss) % self.print_interval == 0:
self._print_losses(mode)
def _print_losses(self, mode: str = "train"):
if self.global_rank > 0:
return
synced_loss = self._synced_losses[mode]
losses = np.mean(synced_loss[-self.print_interval :], axis=0)
print(f"Epoch {self.epoch}, {mode} batch {len(synced_loss)} - Loss: " + self.loss_str(losses), file=self.stream, flush=True)
def loss_str(self, losses: list[float] | np.ndarray | torch.Tensor) -> str:
"""
Get a pretty string for loss values.
Arguments:
losses: List of losses of the same length as the number of loss labels.
Returns:
String representation of the losses.
"""
if len(losses) != len(self.loss_labels):
raise ValueError(f"Length of losses ({len(losses)}) does not match with number of loss labels ({len(self.loss_labels)}).")
if len(self.loss_labels) == 1:
msg = f"{self.loss_labels[0]}: {losses[0]:.6f}"
else:
msg = f"{losses[0]:.6f}"
msg_loss = [f"{label}: {loss:.6f}" for label, loss in zip(self.loss_labels[1:], losses[1:])]
msg += " (" + ", ".join(msg_loss) + ")"
return msg
def add_train_loss(self, losses: torch.Tensor | np.ndarray | float | list[float]):
"""Add losses for one training batch. Averaged over parallel processes.
Arguments:
losses: Losses to append to the list.
"""
if len(self._synced_losses["train"]) == 0:
self.epoch_start = time.perf_counter()
self._add_loss(losses, mode="train")
def add_val_loss(self, losses: torch.Tensor | np.ndarray | float | list[float]):
"""Add losses for one validation batch. Averaged over parallel processes.
Arguments:
losses: Losses to append to the list.
"""
if len(self._synced_losses["val"]) == 0:
self.val_start = time.perf_counter()
self._add_loss(losses, mode="val")
def next_epoch(self):
"""
Increment epoch by one, write current average batch losses to log, empty batch losses,
report epoch time to terminal, and update loss history plot.
"""
train_loss = self._synced_losses["train"].mean()
val_loss = self._synced_losses["val"].mean()
self.train_losses = np.append(self.train_losses, train_loss[None], axis=0)
self.val_losses = np.append(self.val_losses, val_loss[None], axis=0)
n_train = self._synced_losses["train"].n_batches
n_val = self._synced_losses["val"].n_batches
print(
f"Epoch {self.epoch} at rank {self.global_rank} contained {n_train} training batches " f"and {n_val} validation batches",
file=self.stream,
flush=True,
)
if self.global_rank == 0:
epoch_end = time.perf_counter()
train_step = (self.val_start - self.epoch_start) / n_train
val_step = (epoch_end - self.val_start) / n_val
print(f"Completed epoch {self.epoch} at {datetime.datetime.now()}", file=self.stream, flush=True)
print(f"Train loss: {self.loss_str(train_loss)}", file=self.stream, flush=True)
print(f"Val loss: {self.loss_str(val_loss)}", file=self.stream, flush=True)
print(
f"Epoch time: {epoch_end - self.epoch_start:.2f}s - Train step: {train_step:.5f}s " f"- Val step: {val_step:.5f}s",
file=self.stream,
flush=True,
)
self._write_log()
self.plot_history()
self.epoch += 1
self._synced_losses["train"].reset()
self._synced_losses["val"].reset()
def plot_history(self, show: bool = False):
"""
Plot history of current losses into ``self.plot_path``.
Arguments:
show: Whether to show the plot on screen.
"""
if self.global_rank > 0:
return
x = range(1, len(self.train_losses) + 1)
n_rows, n_cols = _calc_plot_dim(len(self.loss_labels), f=0)
fig, axes = plt.subplots(n_rows, n_cols, figsize=(7 * n_cols, 6 * n_rows))
if n_rows == 1 and n_cols == 1:
axes = np.expand_dims(axes, axis=0)
for i, (label, ax) in enumerate(zip(self.loss_labels, axes.flatten())):
ax.semilogy(x, self.train_losses[:, i], "-bx")
ax.semilogy(x, self.val_losses[:, i], "-gx")
ax.legend(["Training", "Validation"])
ax.set_xlabel("Epochs")
ax.set_ylabel("Loss")
if self.loss_weights[i]:
label = f"{label} (x {self.loss_weights[i]})"
ax.set_title(label)
fig.tight_layout()
plt.savefig(self.plot_path)
print(f"Loss history plot saved to {self.plot_path}", file=self.stream, flush=True)
if show:
plt.show()
else:
plt.close()
def get_joinable(self, mode: str = "train"):
"""Return a joinable for uneven training/validation inputs.
Arguments:
mode: Choose 'train or 'val'.
"""
if mode not in ["train", "val"]:
raise ValueError(f"mode should be 'train' or 'val', but got `{mode}`")
return self._synced_losses[mode]