Skip to content

Commit

Permalink
Fix bug when running on XPU
Browse files Browse the repository at this point in the history
  • Loading branch information
skywalker2012 committed Nov 23, 2023
1 parent a0901d2 commit c0b6e7a
Showing 1 changed file with 7 additions and 3 deletions.
10 changes: 7 additions & 3 deletions tools/program.py
Original file line number Diff line number Diff line change
Expand Up @@ -368,8 +368,12 @@ def train(config,
eta_sec = ((epoch_num + 1 - epoch) * \
len(train_dataloader) - idx - 1) * eta_meter.avg
eta_sec_format = str(datetime.timedelta(seconds=int(eta_sec)))
max_mem_reserved_str = f"max_mem_reserved: {paddle.device.cuda.max_memory_reserved()} B"
max_mem_allocated_str = f"max_mem_allocated: {paddle.device.cuda.max_memory_allocated()} B"
# Peak GPU memory counters are CUDA-specific; querying them on other
# backends (e.g. XPU) raises, so guard on the build configuration.
if paddle.device.is_compiled_with_cuda():
    max_mem_reserved_str = f"max_mem_reserved: {paddle.device.cuda.max_memory_reserved()} B"
    max_mem_allocated_str = f"max_mem_allocated: {paddle.device.cuda.max_memory_allocated()} B"
else:
    # Plain literals (no f-prefix, no trailing "B" unit): there is no
    # byte value to report on non-CUDA devices.
    max_mem_reserved_str = "max_mem_reserved: not supported on non-CUDA device"
    max_mem_allocated_str = "max_mem_allocated: not supported on non-CUDA device"
strs = 'epoch: [{}/{}], global_step: {}, {}, avg_reader_cost: ' \
'{:.5f} s, avg_batch_cost: {:.5f} s, avg_samples: {}, ' \
'ips: {:.5f} samples/s, eta: {}, {}, {}'.format(
Expand All @@ -379,7 +383,7 @@ def train(config,
total_samples / print_batch_step,
total_samples / train_batch_cost, eta_sec_format, max_mem_reserved_str, max_mem_allocated_str)
logger.info(strs)

total_samples = 0
train_reader_cost = 0.0
train_batch_cost = 0.0
Expand Down

0 comments on commit c0b6e7a

Please sign in to comment.