Skip to content

Commit

Permalink
[train_gpt2.py] synchronize based on device
Browse files Browse the repository at this point in the history
  • Loading branch information
krrishnarraj committed Apr 11, 2024
1 parent a08c11b commit 4542f89
Showing 1 changed file with 4 additions and 1 deletion.
5 changes: 4 additions & 1 deletion train_gpt2.py
Original file line number Diff line number Diff line change
Expand Up @@ -386,7 +386,10 @@ def get_batch():
write_model(model, "gpt2_124M.bin")
write_state(model, x, y, logits, loss, "gpt2_124M_debug_state.bin")
optimizer.step()
torch.cuda.synchronize()
if device == "mps":
torch.mps.synchronize()
elif device == "cuda":
torch.cuda.synchronize()
t1 = time.time()
print(f"iteration {i}, loss: {loss.item()}, time: {(t1-t0)*1000:.3f}ms")

Expand Down

0 comments on commit 4542f89

Please sign in to comment.