-
Notifications
You must be signed in to change notification settings - Fork 1.1k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[transformer] add multi warmup and learning rate for different modules #2449
Conversation
23ab408
to
81f9ee6
Compare
396b188
to
9254a0f
Compare
9254a0f
to
72afdf3
Compare
b39a8b1 这个commit 重构了cv 和train 以及 lr 的打印逻辑 step 模式单个lr+warmup , cv和trainepoch 模式 单个lr+warmup , cv和trainepoch 模式 多个个lr+warmup , cv和trainstep 模式多个lr+warmup , cv和train |
fsdp (TBD)8卡A100, step 模式 ,3000 save interval, fp32, zero2
8卡A100, step 模式 ,3000 save interval, bf16, zero2
|
强,ctc的收益是因为使用了不同的学习率吗 |
直觉上,整个模型是pretrain的 除了ctc和conv2d, 所以只让这两个较大的学习率 ‘快速的学习’ (上边这个只有ctc的lr和enc+dec的lr),其他的stable training 未来也会支持w2vbert的fintune 所以这里就引入了 |
#2412 修复了A100 和v100精度问题, 这里实验需要重新跑下 (预期还会更好些) |
#2412 使用修复后的的结果: 8卡A100, step 模式 ,3000 save interval, bf16, zero2
|
optim: adam
optim_conf:
lr: [0.001, 0.00005, 0.00001]
modules: ['ctc', 'encoder.embed']
scheduler: warmuplr
scheduler_conf:
warmup_steps: [1500, 10000, 5000] 8卡A100, step 模式 ,3000 save interval, bf16, zero2
|
if isinstance(lr, List): | ||
optim_conf['lr'] = lr[-1] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这一步是为啥,optimizer必须传一个lr参数且不能是list?
add casual model fix typo rm ckpt add topk topp sampler fix positoin [train_engine] support fsdp (wenet-e2e#2412) * [train_engine] support fsdp * [train_engine] support fsdp * unify scaler and amp * fp32&&fp16 works in fsdp env * fix fsdp in cv auto cast * try to fix wenet.join fsdp * implementing zero1 under fsdp is almost equivalent to deepspeed's zero1 * fix clip_and_grad_ * fix train summary * all wenet xxxformer works (-paraformer -transducer) * try to fix nan * add barrier for cv * add destroy group for end of all train * refactor wrap methods and ckpt works * fix ckpt * fix cv in dtype != float32 * fix ckpt in model mode * fix bf16 amp * refactor scaler and autocast, fix fp32 fp16 bf16 for fsdp * fix fp32 nullcontext to nullcontext() * modify after review * fix lint * fix lint LoRA support (wenet-e2e#2049) * support lora for v3.0.1 * format code and update lora attention && encoder * fix bug when lora_list is None --------- Co-authored-by: Xingchen Song(宋星辰) <[email protected]> [env] update python version and deepspeed version (wenet-e2e#2462) * [env] update python version and deepspeed version * [env] fix lint fix rope pos embdining (wenet-e2e#2463) * fix rope pos embdining * fix dropout * fix comment [transformer] add multi warmup and learning rate for different modules (wenet-e2e#2449) * [transformer] add multi warmup and learning rate for different modules * fix typo * it works in warmuplr * fix lr in tensorboard in step mode * fix cv log * cv works * refactor cv log * add helper lrs_to_string * fix lrstr * fix ddp multiple lr * fix initial step * revert to -1 * fix sub params dup * fix step * fix step * fix log * add assert for scheduler * add comment for log --------- Co-authored-by: Xingchen Song(宋星辰) <[email protected]> add generate add toto support sft & pretrain training forward gemm conversion works support init casual model [whisper] limit language to Chinese (wenet-e2e#2470) [train] convert tensor to scalar (wenet-e2e#2471) [workflow] upgrad python version to 3.10 (wenet-e2e#2472) * [workflow] upgrad python version to 3.10 * [workflow] try to pass refactor cache behaviour in training mode (reduce compute cost and memory) (wenet-e2e#2473) all gemma model works fix ut fix ut (wenet-e2e#2477) * fix ut * fix py version [transformer] Make MoE runnable (wenet-e2e#2474) [transformer] fix mqa (wenet-e2e#2478) enable mmap in torch.load (wenet-e2e#2479) [example] Add deespeed configs of different stages for illustration (wenet-e2e#2485) [example] Fix prefetch and step_save (wenet-e2e#2486) [ctl] simplified ctl (wenet-e2e#2483) * [ctl] simplified ctl * [ctl] unify [branchformer] simplified branchformer (wenet-e2e#2482) * [transformer] simplified branchformer * fix yaml * support mqa gradiengt ckpt sdpa * fix gradient checkponit * add deepspeed comment in layer dropout * fix comment [e_branchformer] simplified e_branchformer (wenet-e2e#2484) * [e_branchformer] simplified ctl * try to fix ut * try to fix ut * fix activation * fix att args * e-branformer works [transformer] refactor cache (wenet-e2e#2481) * [transformer] refactor cache * fix ut * unify cache type in branchformer and ebranchformer fix cache fix gradient ckpt in branchformer/ebranformer (wenet-e2e#2488) fix search after refactor cache (wenet-e2e#2490) generate works! unify chat pattern convert llama3 works [transformer] set use_reentrant=False for gradient ckpt (wenet-e2e#2491) [transformer] fix warning: ignore(True) has been deprecated (wenet-e2e#2492) * [transformer] fix warning: ignore(True) has been deprecated * [transformer] fix warning: ignore(True) has been deprecated [log] avoid reduntant logging (wenet-e2e#2493) fix w1 w2 w3 in feedforward add 70b temporarily mv LLM to wenet support llm dataset unify config add dataset yaml in script support llm dataset dynamic static bucket works [transformer] refacgtor mqa repeat (wenet-e2e#2497) [transformer] fix mqa in cross att (wenet-e2e#2498) [deepspeed] update json config (wenet-e2e#2499) training works pretrain works refactor covert fix flash att in generate llama works fix llama3 fix speed try fix ut support stop tokens in gen and support ppl support stop tokens in gen and support ppl
当微调的时候,模型的不同部分需要的学习率和warmup不一样,比如以下 微调whisper的手 因为没有ctc的weight
我门可以设置ctc的warmu up为12000, lr 0.01, 其他部分为1200,0.00001
TODO