
Fix for gradient_accumulation_steps training slow #225

Merged: 2 commits into karpathy:master on Apr 18, 2023

Conversation

@otaviogood (Contributor) commented Mar 26, 2023

This is not thoroughly tested. I did not test on multiple machines (clusters), and I have no idea whether it would work there. My last commit fixed OWT training but made Shakespeare train really slowly for people. Sorry.

Tested with PyTorch 2 (recently released):

gpt2, 1 GPU (4090)
python3 train.py config/train_gpt2.py
step 0: train loss 10.9895, val loss 10.9900
iter 0: loss 10.9774, time 25117.76ms, mfu -100.00%
iter 10: loss 10.4046, time 4666.49ms, mfu 28.86%
iter 20: loss 9.8023, time 4627.31ms, mfu 28.88%
iter 30: loss 9.4981, time 4619.24ms, mfu 28.91%
iter 40: loss 9.2942, time 4626.03ms, mfu 28.93%
iter 50: loss 9.1234, time 4622.44ms, mfu 28.95%
iter 60: loss 8.8425, time 4753.52ms, mfu 28.89%
iter 70: loss 8.6453, time 4621.80ms, mfu 28.91%
iter 80: loss 8.4404, time 4778.06ms, mfu 28.84%
iter 90: loss 8.2229, time 4620.58ms, mfu 28.87%
iter 100: loss 7.8800, time 4627.31ms, mfu 28.89%
iter 110: loss 7.9027, time 4620.63ms, mfu 28.92%
iter 120: loss 7.4240, time 4628.48ms, mfu 28.93%
iter 130: loss 7.4718, time 4655.71ms, mfu 28.93%
iter 140: loss 7.0861, time 4622.26ms, mfu 28.95%
iter 150: loss 7.0825, time 4779.64ms, mfu 28.88%
iter 160: loss 7.0579, time 4621.46ms, mfu 28.90%
iter 170: loss 6.9085, time 4726.33ms, mfu 28.86%
iter 180: loss 6.7840, time 4622.41ms, mfu 28.89%
iter 190: loss 6.7569, time 4628.46ms, mfu 28.91%
iter 200: loss 6.6118, time 4699.47ms, mfu 28.88%
iter 210: loss 6.6385, time 4671.57ms, mfu 28.88%
iter 220: loss 6.5873, time 4619.89ms, mfu 28.90%
iter 230: loss 6.3390, time 4628.21ms, mfu 28.92%
iter 240: loss 6.2774, time 4621.85ms, mfu 28.94%
iter 250: loss 6.4382, time 4629.50ms, mfu 28.96%
iter 260: loss 6.1668, time 4717.24ms, mfu 28.92%
iter 270: loss 6.2586, time 4628.69ms, mfu 28.94%
iter 280: loss 6.3917, time 4781.76ms, mfu 28.86%
iter 290: loss 6.2656, time 4627.78ms, mfu 28.88%
iter 300: loss 6.1675, time 4633.31ms, mfu 28.90%

...
step 1000: train loss 4.5085, val loss 4.5240
step 2000: train loss 3.7875, val loss 3.7994
step 3000: train loss 3.5260, val loss 3.5630
step 4000: train loss 3.4106, val loss 3.4388
step 5000: train loss 3.3494, val loss 3.3631
step 6000: train loss 3.3062, val loss 3.3141
step 7000: train loss 3.2660, val loss 3.2750
step 8000: train loss 3.2323, val loss 3.2474
Took overnight to get to this point. About 9 hours. Looks good.

gpt2, 1 GPU (4090)
python3 train.py
The default setup. Same thing with different logging.

Shakespeare, 1 GPU (4090)
python3 train.py config/train_shakespeare_char.py
step 0: train loss 4.2859, val loss 4.2809
iter 0: loss 4.2708, time 7437.57ms, mfu -100.00%
iter 10: loss 3.2433, time 24.39ms, mfu 15.28%
iter 20: loss 2.7924, time 22.25ms, mfu 15.43%
iter 30: loss 2.6378, time 22.26ms, mfu 15.56%
iter 40: loss 2.5764, time 22.28ms, mfu 15.67%
iter 50: loss 2.5270, time 22.26ms, mfu 15.78%
...
step 250: train loss 1.9969, val loss 2.0965
step 500: train loss 1.5666, val loss 1.7654
step 750: train loss 1.3920, val loss 1.6156
step 1000: train loss 1.3010, val loss 1.5467
step 1250: train loss 1.2331, val loss 1.5115
step 1500: train loss 1.1807, val loss 1.5044
step 1750: train loss 1.1286, val loss 1.4806
1 minute 7 seconds to get to 1.48 val loss. This is slightly worse than Andrej's 1.4697, but I think he had a lucky run or something, I don't know. After this point it overfits for me.

gpt2, 8xA100 GPUs
torchrun --standalone --nproc_per_node=8 train.py config/train_gpt2.py
step 0: train loss 10.9895, val loss 10.9900
iter 0: loss 10.9916, time 30893.63ms, mfu -100.00%
iter 10: loss 10.3700, time 435.79ms, mfu 38.63%
iter 20: loss 9.7094, time 435.65ms, mfu 38.63%
iter 30: loss 9.5121, time 436.04ms, mfu 38.62%
iter 40: loss 9.3336, time 436.46ms, mfu 38.62%
iter 50: loss 9.0596, time 436.75ms, mfu 38.61%
iter 60: loss 8.8413, time 437.20ms, mfu 38.60%
iter 70: loss 8.5668, time 439.75ms, mfu 38.57%
iter 80: loss 8.3117, time 438.65ms, mfu 38.55%
iter 90: loss 8.0602, time 439.67ms, mfu 38.52%
...
step 1000: train loss 4.5096, val loss 4.5192
step 2000: train loss 3.7778, val loss 3.8028
step 3000: train loss 3.5356, val loss 3.5424
step 4000: train loss 3.4144, val loss 3.4310
step 5000: train loss 3.3394, val loss 3.3458
This looks to be the same as good training runs I did before, so I'm calling this good.

finetune Shakespeare
step 15: train loss 2.7552, val loss 2.8167
Well, I'm not sure exactly how well this is supposed to work; this was my best val loss. The output looks like Shakespeare to me, but I never really liked that stuff. :P

@dkobak commented Mar 26, 2023

Cool, thanks. I think it's important that it gets merged soon!

@karpathy (Owner) commented:

not sure i get this PR... :)

@karpathy closed this on Apr 13, 2023
@dkobak commented Apr 13, 2023

Oh no, this is a really important fix for something that was broken in #145! @karpathy, please consider merging this. The current master is 40 times slower (!!) than it should be when running tiny-shakespeare on a single GPU.

@karpathy (Owner) commented:

Oh ok, maybe I understand. The training isn't 40 times slower. It's just accumulating gradients 40 times before it updates once. This is done to simulate batches of ~0.5M tokens, and hence all the hyperparameters should work ok.
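For anyone following the thread, the accumulation loop looks roughly like the sketch below. This is not the exact train.py code (no AMP or DDP handling); it assumes nanoGPT-style `get_batch` and a model that returns `(logits, loss)`.

```python
# Rough sketch of gradient accumulation: run several micro-batches, letting
# gradients sum up, then do a single optimizer step, so that one update
# "sees" a much larger effective batch.
for micro_step in range(gradient_accumulation_steps):
    X, Y = get_batch('train')                     # one micro-batch
    logits, loss = model(X, Y)
    loss = loss / gradient_accumulation_steps     # scale so summed grads average over micro-batches
    loss.backward()                               # grads accumulate in place
optimizer.step()                                  # one weight update per accumulated batch
optimizer.zero_grad(set_to_none=True)
```

So each logged "iter" runs gradient_accumulation_steps forward/backward passes, which is why per-iteration wall-clock time scales with that setting even though the loss-vs-iteration curve looks similar.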

@dkobak commented Apr 13, 2023

No, I swear, it really takes A LOT longer (maybe not 40 times, but at least 10 times, a huge difference) by the wall clock to reach the same validation loss. I don't fully understand why this happens but it does happen! There are several issues reporting exactly this problem: #178, #179.

If you want I can measure wall clock time to reach the min validation loss with/without this change and report.

@dkobak commented Apr 13, 2023

Look, I measured it (python train.py config/train_shakespeare_char.py on a single NVIDIA RTX A6000):

With this PR I get to training loss 2.0 in 250 iterations, taking around 40ms per 10 iterations.

iter 220: loss 2.1410, time 48.04ms, mfu 8.06%
iter 230: loss 2.0778, time 67.98ms, mfu 7.80%
iter 240: loss 2.0798, time 37.94ms, mfu 8.00%

With the current master I need around 170 iterations, but now it takes around 1600ms per 10 iterations (40x slowdown).

iter 150: loss 2.1255, time 1627.96ms, mfu 9.35%
iter 160: loss 2.0543, time 1608.02ms, mfu 9.35%
iter 170: loss 2.0049, time 1590.82ms, mfu 9.35%

Overall that's (170/250) × 40 ≈ 27, so roughly a 25x slowdown in wall-clock time.

@otaviogood (Contributor, PR author) commented:

Andrej, I basically broke Shakespeare with my first PR (but fixed OWT). I think my Shakespeare runs ended up about 20x slower to train. I didn't expect grad accum to be so slow. So fixing this issue is pretty important to anyone who trains Shakespeare, which I'd guess is most people. Give me a call if you want me to talk you through any details.

@karpathy reopened this on Apr 18, 2023
@karpathy (Owner) commented:

ok reopening, let me take a look...

@karpathy (Owner) commented:

Ok, so I think I'm generally fine with this PR. The philosophy is that gradient_accumulation_steps in the config is set assuming just a single GPU, but if the world size is larger when the script actually runs, we divide it by the number of GPUs sharing the workload. I'm ok with that. Ok, merging.
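In code, the merged behaviour amounts to something like this sketch (not the exact diff; variable names follow train.py):

```python
import os

# The config value is written assuming a single GPU; under DDP each rank does
# only its share of the accumulation, so the total tokens per optimizer step
# stay the same.
gradient_accumulation_steps = 40                 # example config value, for 1 GPU
ddp = int(os.environ.get('RANK', -1)) != -1      # are we running under torchrun?

if ddp:
    ddp_world_size = int(os.environ['WORLD_SIZE'])
    assert gradient_accumulation_steps % ddp_world_size == 0
    gradient_accumulation_steps //= ddp_world_size   # e.g. 40 -> 5 per GPU on 8 GPUs
```

On 8 GPUs each rank then accumulates 5 micro-steps, so 8 × 5 = 40 micro-batches still make up one optimizer step, while a single GPU does all 40 itself.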

@karpathy merged commit 21f9bff into karpathy:master on Apr 18, 2023