diff --git a/examples/aishell/s0/conf/ds_stage2.json b/examples/aishell/s0/conf/ds_stage2.json new file mode 100644 index 000000000..49884009a --- /dev/null +++ b/examples/aishell/s0/conf/ds_stage2.json @@ -0,0 +1,57 @@ +{ + "train_micro_batch_size_per_gpu": 1, + "gradient_accumulation_steps": 1, + "steps_per_print": 100, + "gradient_clipping": 0.0001, + "fp16": { + "enabled": false, + "auto_cast": false, + "loss_scale": 0, + "initial_scale_power": 8, + "loss_scale_window": 1000, + "hysteresis": 2, + "min_loss_scale": 1 + }, + "bf16": { + "enabled": false + }, + "zero_force_ds_cpu_optimizer": false, + "zero_optimization": { + "stage": 2, + "offload_optimizer": { + "device": "none", + "pin_memory": true + }, + "offload_param": { + "device": "none", + "pin_memory": true + }, + "allgather_partitions": true, + "allgather_bucket_size": 1e7, + "overlap_comm": true, + "reduce_scatter": true, + "reduce_bucket_size": 1e7, + "contiguous_gradients" : true + }, + "activation_checkpointing": { + "partition_activations": false, + "cpu_checkpointing": false, + "contiguous_memory_optimization": false, + "number_checkpoints": null, + "synchronize_checkpoint_boundary": false, + "profile": true + }, + "flops_profiler": { + "enabled": false, + "profile_step": 100, + "module_depth": -1, + "top_modules": 1, + "detailed": true, + "output_file": null + }, + "tensorboard": { + "enabled": true, + "output_path": "tensorboard/ds_logs/", + "job_name": "deepspeed" + } +} diff --git a/examples/aishell/s0/conf/train_u2++_conformer_1.8B.yaml b/examples/aishell/s0/conf/train_u2++_conformer_1.8B.yaml new file mode 100644 index 000000000..c13b4b295 --- /dev/null +++ b/examples/aishell/s0/conf/train_u2++_conformer_1.8B.yaml @@ -0,0 +1,90 @@ +# network architecture +# encoder related +encoder: conformer +encoder_conf: + output_size: 2048 # dimension of attention + attention_heads: 16 + linear_units: 8192 # the number of units of position-wise feed forward + num_blocks: 12 # the number of encoder blocks + dropout_rate: 0.1 + positional_dropout_rate: 0.1 + attention_dropout_rate: 0.1 + input_layer: conv2d8 # encoder input type, you can chose conv2d, conv2d6 and conv2d8 + normalize_before: true + cnn_module_kernel: 8 + use_cnn_module: True + activation_type: 'swish' + pos_enc_layer_type: 'rel_pos' + selfattention_layer_type: 'rel_selfattn' + causal: true + use_dynamic_chunk: true + cnn_module_norm: 'layer_norm' # using nn.LayerNorm makes model converge faster + use_dynamic_left_chunk: false + +# decoder related +decoder: bitransformer +decoder_conf: + attention_heads: 16 + linear_units: 8192 + num_blocks: 3 + r_num_blocks: 3 + dropout_rate: 0.1 + positional_dropout_rate: 0.1 + self_attention_dropout_rate: 0.1 + src_attention_dropout_rate: 0.1 + +# hybrid CTC/attention +model_conf: + ctc_weight: 0.3 + lsm_weight: 0.1 # label smoothing option + length_normalized_loss: false + reverse_weight: 0.3 + +dataset_conf: + filter_conf: + max_length: 40960 + min_length: 0 + token_max_length: 200 + token_min_length: 1 + resample_conf: + resample_rate: 16000 + speed_perturb: true + fbank_conf: + num_mel_bins: 80 + frame_shift: 10 + frame_length: 25 + dither: 1.0 + spec_aug: true + spec_aug_conf: + num_t_mask: 2 + num_f_mask: 2 + max_t: 50 + max_f: 10 + spec_sub: true + spec_sub_conf: + num_t_sub: 3 + max_t: 30 + spec_trim: false + spec_trim_conf: + max_t: 50 + shuffle: true + shuffle_conf: + shuffle_size: 1500 + sort: true + sort_conf: + sort_size: 500 # sort_size should be less than shuffle_size + batch_conf: + batch_type: 'static' # static 
or dynamic + batch_size: 16 + +grad_clip: 5 +accum_grad: 1 +max_epoch: 100 +log_interval: 100 + +optim: adam +optim_conf: + lr: 0.001 +scheduler: warmuplr # pytorch v1.1.0+ required +scheduler_conf: + warmup_steps: 25000 diff --git a/examples/aishell/s0/run.sh b/examples/aishell/s0/run.sh index fb41f653c..d3ff2ddfa 100644 --- a/examples/aishell/s0/run.sh +++ b/examples/aishell/s0/run.sh @@ -48,6 +48,8 @@ train_config=conf/train_conformer.yaml cmvn=true dir=exp/conformer checkpoint= +num_workers=8 +prefetch=500 # use average_checkpoint will get better result average_checkpoint=true @@ -55,6 +57,10 @@ decode_checkpoint=$dir/final.pt average_num=30 decode_modes="ctc_greedy_search ctc_prefix_beam_search attention attention_rescoring" +deepspeed=false +deepspeed_config=conf/ds_stage2.json +deepspeed_save_states="model_only" + . tools/parse_options.sh || exit 1; if [ ${stage} -le -1 ] && [ ${stop_stage} -ge -1 ]; then @@ -116,11 +122,12 @@ if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then # You have to rm `INIT_FILE` manually when you resume or restart a # multi-machine training. INIT_FILE=$dir/ddp_init + rm -f ${INIT_FILE} # remove previous INIT_FILE init_method=file://$(readlink -f $INIT_FILE) echo "$0: init method is $init_method" num_gpus=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}') # Use "nccl" if it works, otherwise use "gloo" - dist_backend="gloo" + dist_backend="nccl" world_size=`expr $num_gpus \* $num_nodes` echo "total gpus is: $world_size" cmvn_opts= @@ -130,30 +137,60 @@ if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then # train.py rewrite $train_config to $dir/train.yaml with model input # and output dimension, and $dir/train.yaml will be used for inference # and export. - for ((i = 0; i < $num_gpus; ++i)); do - { - gpu_id=$(echo $CUDA_VISIBLE_DEVICES | cut -d',' -f$[$i+1]) - # Rank of each gpu/process used for knowing whether it is - # the master of a worker. - rank=`expr $node_rank \* $num_gpus + $i` - python wenet/bin/train.py --gpu $gpu_id \ - --config $train_config \ - --data_type $data_type \ - --symbol_table $dict \ - --train_data data/$train_set/data.list \ - --cv_data data/dev/data.list \ - ${checkpoint:+--checkpoint $checkpoint} \ - --model_dir $dir \ - --ddp.init_method $init_method \ - --ddp.world_size $world_size \ - --ddp.rank $rank \ - --ddp.dist_backend $dist_backend \ - --num_workers 1 \ - $cmvn_opts \ - --pin_memory - } & - done - wait + if [ ${deepspeed} == true ]; then + echo "using deepspeed" + # NOTE(xcsong): deepspeed fails with gloo, see + # https://github.com/microsoft/DeepSpeed/issues/2818 + dist_backend="nccl" + [ ! 
-f data/$train_set/data.list.filter ] && \ + python tools/filter_uneven_data.py data/$train_set/data.list \ + $data_type $num_gpus $num_utts_per_shard data/$train_set/data.list.filter + deepspeed --include localhost:$CUDA_VISIBLE_DEVICES \ + wenet/bin/train.py \ + --deepspeed \ + --deepspeed_config ${deepspeed_config} \ + --deepspeed.save_states ${deepspeed_save_states} \ + --ddp.dist_backend $dist_backend \ + --ddp.init_method $init_method \ + --data_type $data_type \ + --config $train_config \ + --symbol_table data/dict/lang_char.txt \ + --train_data data/$train_set/data.list.filter \ + --cv_data data/dev/data.list \ + ${checkpoint:+--checkpoint $checkpoint} \ + --model_dir $dir \ + --num_workers ${num_workers} \ + --prefetch ${prefetch} \ + $cmvn_opts \ + --pin_memory + else + echo "using torch ddp" + for ((i = 0; i < $num_gpus; ++i)); do + { + gpu_id=$(echo $CUDA_VISIBLE_DEVICES | cut -d',' -f$[$i+1]) + # Rank of each gpu/process used for knowing whether it is + # the master of a worker. + rank=`expr $node_rank \* $num_gpus + $i` + python wenet/bin/train.py --gpu $gpu_id \ + --config $train_config \ + --data_type $data_type \ + --symbol_table $dict \ + --train_data data/$train_set/data.list \ + --cv_data data/dev/data.list \ + ${checkpoint:+--checkpoint $checkpoint} \ + --model_dir $dir \ + --ddp.init_method $init_method \ + --ddp.world_size $world_size \ + --ddp.rank $rank \ + --ddp.dist_backend $dist_backend \ + --num_workers ${num_workers} \ + --prefetch ${prefetch} \ + $cmvn_opts \ + --pin_memory + } & + done + wait + fi fi if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then @@ -171,8 +208,8 @@ if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then # non-streaming model. The default value is -1, which is full chunk # for non-streaming inference. 
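The data.list.filter step in the deepspeed branch above exists because DeepSpeed hangs when ranks receive different numbers of batches, so the list is truncated to a multiple of the number of GPUs. A minimal sketch of that truncation rule for raw mode (illustrative list and GPU count only; the real logic lives in tools/filter_uneven_data.py further down in this patch):

import random

def filter_even(datalist, num_gpus):
    # Keep the largest multiple of num_gpus so every rank sees the same
    # number of utterances; mirrors the raw-mode branch of the filter script.
    random.seed(1024)
    valid_num = len(datalist) // num_gpus * num_gpus
    shuffled = list(datalist)
    random.shuffle(shuffled)
    return shuffled[:valid_num]

# Hypothetical example: 8 utterances across 3 GPUs -> keep 6, drop 2.
print(len(filter_even(["utt{}".format(i) for i in range(8)], num_gpus=3)))  # 6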
decoding_chunk_size= - ctc_weight=0.5 - reverse_weight=0.0 + ctc_weight=0.3 + reverse_weight=0.5 for mode in ${decode_modes}; do { test_dir=$dir/test_${mode} @@ -298,4 +335,4 @@ if [ ${stage} -le 9 ] && [ ${stop_stage} -ge 9 ]; then # --lfmmi_dir data/local/lfmmi # 9.3 Run HLG decode from stage 8.2 -fi \ No newline at end of file +fi diff --git a/requirements.txt b/requirements.txt index 2886b6c80..f86d6f953 100644 --- a/requirements.txt +++ b/requirements.txt @@ -15,3 +15,4 @@ pycodestyle==2.6.0 pyflakes==2.2.0 torch==1.13.0 torchaudio==0.13.0 +deepspeed diff --git a/tools/filter_uneven_data.py b/tools/filter_uneven_data.py new file mode 100755 index 000000000..443f87300 --- /dev/null +++ b/tools/filter_uneven_data.py @@ -0,0 +1,51 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +# Copyright [2023-04-27] + +import os +import random +import tarfile + +random.seed(1024) + +# parse arg from command line +datalist = os.sys.argv[1] +datatype = os.sys.argv[2] +num_gpus = int(os.sys.argv[3]) +num_samples_per_tar = int(os.sys.argv[4]) # only used in shard mode +new_datalist = os.sys.argv[5] + +assert datatype in ["shard", "raw"] + + +filtered_list = [] +with open(datalist, "r") as f: + lines = f.readlines() + lines = [l.strip() for l in lines] + if datatype == "raw": + valid_num = len(lines) // num_gpus * num_gpus + random.shuffle(lines) + filtered_list = lines[:valid_num] + else: + for line in lines: + cnt = 0 + with open(line, "rb") as tar: + stream = tarfile.open(fileobj=tar, mode="r|*") + for tarinfo in stream: + name = tarinfo.name + pos = name.rfind('.') + assert pos > 0 + prefix, postfix = name[:pos], name[pos + 1:] + if postfix == 'txt': + cnt += 1 + if cnt == num_samples_per_tar: + filtered_list.append(line) + valid_num = len(filtered_list) // num_gpus * num_gpus + random.shuffle(filtered_list) + filtered_list = filtered_list[:valid_num] + filtered_list.sort() + print("before filter: {} after filter: {}".format(len(lines), len(filtered_list))) + +with open(new_datalist, "w") as f: + for line in filtered_list: + f.writelines("{}\n".format(line)) diff --git a/wenet/bin/train.py b/wenet/bin/train.py index cd6b04b1e..da9a6f6bb 100644 --- a/wenet/bin/train.py +++ b/wenet/bin/train.py @@ -16,6 +16,9 @@ import argparse import copy +import datetime +import deepspeed +import json import logging import os @@ -23,6 +26,10 @@ import torch.distributed as dist import torch.optim as optim import yaml + +from deepspeed.runtime.zero.stage_1_and_2 import estimate_zero2_model_states_mem_needs_all_live # noqa +from deepspeed.runtime.zero.stage3 import estimate_zero3_model_states_mem_needs_all_live # noqa +from deepspeed.utils.zero_to_fp32 import convert_zero_checkpoint_to_fp32_state_dict from tensorboardX import SummaryWriter from torch.utils.data import DataLoader @@ -121,6 +128,19 @@ def get_args(): required=False, help='LF-MMI dir') + # Begin deepspeed related config + parser.add_argument('--local_rank', type=int, default=-1, + help='local rank passed from distributed launcher') + parser.add_argument('--deepspeed.save_states', + dest='save_states', + default='model_only', + choices=['model_only', 'model+optimizer'], + help='save model/optimizer states') + # End deepspeed related config + + + # DeepSpeed automaticly add '--deepspeed' and '--deepspeed_config' to parser + parser = deepspeed.add_config_arguments(parser) args = parser.parse_args() return args @@ -129,7 +149,9 @@ def main(): args = get_args() logging.basicConfig(level=logging.DEBUG, format='%(asctime)s %(levelname)s %(message)s') - 
os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu)
+    # NOTE(xcsong): deepspeed sets CUDA_VISIBLE_DEVICES internally
+    os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu) if not args.deepspeed \
+        else os.environ['CUDA_VISIBLE_DEVICES']
     # Set random seed
     torch.manual_seed(777)
@@ -137,14 +159,37 @@ def main():
         configs = yaml.load(fin, Loader=yaml.FullLoader)
     if len(args.override_config) > 0:
         configs = override_config(configs, args.override_config)
+    if args.deepspeed:
+        with open(args.deepspeed_config, 'r') as fin:
+            ds_configs = json.load(fin)
+        if "fp16" in ds_configs and ds_configs["fp16"]["enabled"]:
+            configs["ds_dtype"] = "fp16"
+        elif "bf16" in ds_configs and ds_configs["bf16"]["enabled"]:
+            configs["ds_dtype"] = "bf16"
+        else:
+            configs["ds_dtype"] = "fp32"
+    # deepspeed reads world_size from env
+    if args.deepspeed:
+        assert args.world_size == -1
+    # distributed means pytorch native ddp, it parses world_size from args
     distributed = args.world_size > 1
+    local_rank = args.rank
+    world_size = args.world_size
     if distributed:
         logging.info('training on multiple gpus, this gpu {}'.format(args.gpu))
         dist.init_process_group(args.dist_backend,
                                 init_method=args.init_method,
-                                world_size=args.world_size,
-                                rank=args.rank)
+                                world_size=world_size,
+                                rank=local_rank)
+    elif args.deepspeed:
+        # Update local_rank & world_size from environment variables
+        local_rank = int(os.environ['LOCAL_RANK'])
+        world_size = int(os.environ['WORLD_SIZE'])
+        deepspeed.init_distributed(dist_backend=args.dist_backend,
+                                   init_method=args.init_method,
+                                   rank=local_rank,
+                                   world_size=world_size)
     symbol_table = read_symbol_table(args.symbol_table)
@@ -157,6 +202,34 @@ def main():
     cv_conf['shuffle'] = False
     non_lang_syms = read_non_lang_symbols(args.non_lang_syms)
+    # NOTE(xcsong): DeepSpeed does not support uneven data. When using a custom
+    #   dataset, we need to manually ensure that the data is evenly distributed
+    #   across all processes. We implement `tools/filter_uneven_data.py` for this.
+    #   ref: https://github.com/microsoft/DeepSpeed/issues/2223
+    #
+    # NOTE(xcsong): We also need to keep
+    #       `train_micro_batch_size_per_gpu == 1`
+    #   and
+    #       `accum_grad (in train_conformer.yaml)
+    #           == gradient_accumulation_steps (in ds_config.json)`
+    #   The reason for this consistency check is that deepspeed's dataloader
+    #   uses PyTorch's torch.utils.data.DistributedSampler, which does not
+    #   support IterableDataset. IterableDataset is extremely useful in
+    #   large-scale training because it lets you stream the data without
+    #   having to download the complete dataset.
+    #   ref: https://github.com/microsoft/DeepSpeed/issues/1371
+    #        https://github.com/microsoft/DeepSpeed/issues/285
+    #
+    #   To make deepspeed training compatible with IterableDataset, we have to
+    #   use a custom dataloader instead of deepspeed's native loader, and thus
+    #   we should configure the batch size in train_conformer.yaml instead of
+    #   ds_config.json. In contrast, gradient accumulation steps should be
+    #   configured in ds_config.json since it will be handled by deepspeed.
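To make the constraints in the note above concrete, here is a small standalone sketch of the checks train.py performs before handing control to DeepSpeed, using the config paths added by this recipe (run it from the repo root; adjust paths if your configs differ):

import json
import yaml

with open("examples/aishell/s0/conf/ds_stage2.json") as f:
    ds_configs = json.load(f)
with open("examples/aishell/s0/conf/train_u2++_conformer_1.8B.yaml") as f:
    configs = yaml.safe_load(f)

# Batching stays in the wenet yaml (static batches, micro batch size 1 on the
# DeepSpeed side); only gradient accumulation is delegated to ds_config.json.
assert configs["dataset_conf"]["batch_conf"]["batch_type"] == "static"
assert ds_configs["train_micro_batch_size_per_gpu"] == 1
configs["accum_grad"] = ds_configs["gradient_accumulation_steps"]

# Mixed-precision mode follows the ds_config flags, as in wenet/bin/train.py.
if ds_configs.get("fp16", {}).get("enabled", False):
    configs["ds_dtype"] = "fp16"
elif ds_configs.get("bf16", {}).get("enabled", False):
    configs["ds_dtype"] = "bf16"
else:
    configs["ds_dtype"] = "fp32"
print(configs["accum_grad"], configs["ds_dtype"])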
+ # ref: https://github.com/microsoft/DeepSpeed/issues/62 + if args.deepspeed: + assert train_conf['batch_conf']['batch_type'] == "static" + assert ds_configs["train_micro_batch_size_per_gpu"] == 1 + configs['accum_grad'] = ds_configs["gradient_accumulation_steps"] train_dataset = Dataset(args.data_type, args.train_data, symbol_table, train_conf, args.bpe_model, non_lang_syms, True) cv_dataset = Dataset(args.data_type, @@ -191,7 +264,7 @@ def main(): configs['is_json_cmvn'] = True configs['lfmmi_dir'] = args.lfmmi_dir - if args.rank == 0: + if local_rank == 0: saved_config_path = os.path.join(args.model_dir, 'train.yaml') with open(saved_config_path, 'w') as fout: data = yaml.dump(configs) @@ -199,14 +272,14 @@ def main(): # Init asr model from configs model = init_model(configs) - print(model) + print(model) if local_rank == 0 else None num_params = sum(p.numel() for p in model.parameters()) - print('the number of model params: {:,d}'.format(num_params)) + print('the number of model params: {:,d}'.format(num_params)) if local_rank == 0 else None # noqa # !!!IMPORTANT!!! # Try to export the model by script, if fails, we should refine # the code to satisfy the script export requirements - if args.rank == 0: + if local_rank == 0: script_model = torch.jit.script(model) script_model.save(os.path.join(args.model_dir, 'init.zip')) executor = Executor() @@ -225,12 +298,12 @@ def main(): num_epochs = configs.get('max_epoch', 100) model_dir = args.model_dir writer = None - if args.rank == 0: + if local_rank == 0: os.makedirs(model_dir, exist_ok=True) exp_id = os.path.basename(model_dir) writer = SummaryWriter(os.path.join(args.tensorboard_dir, exp_id)) - if distributed: + if distributed: # native pytorch ddp assert (torch.cuda.is_available()) # cuda model is required for nn.parallel.DistributedDataParallel model.cuda() @@ -244,6 +317,18 @@ def main(): model.register_comm_hook( state=None, hook=comm_hooks.fp16_compress_hook ) + elif args.deepspeed: # deepspeed + # NOTE(xcsong): look in detail how the memory estimator API works: + # https://deepspeed.readthedocs.io/en/latest/memory.html#discussion + if local_rank == 0: + logging.info("Estimating model states memory needs (zero2)...") + estimate_zero2_model_states_mem_needs_all_live( + model, num_gpus_per_node=world_size, num_nodes=1) + logging.info("Estimating model states memory needs (zero3)...") + estimate_zero3_model_states_mem_needs_all_live( + model, num_gpus_per_node=world_size, num_nodes=1) + device = None # Init device later + pass # Init DeepSpeed later else: use_cuda = args.gpu >= 0 and torch.cuda.is_available() device = torch.device('cuda' if use_cuda else 'cpu') @@ -255,18 +340,51 @@ def main(): optimizer = optim.AdamW(model.parameters(), **configs['optim_conf']) else: raise ValueError("unknown optimizer: " + configs['optim']) + scheduler_type = None if configs['scheduler'] == 'warmuplr': + scheduler_type = WarmupLR scheduler = WarmupLR(optimizer, **configs['scheduler_conf']) elif configs['scheduler'] == 'NoamHoldAnnealing': + scheduler_type = NoamHoldAnnealing scheduler = NoamHoldAnnealing(optimizer, **configs['scheduler_conf']) else: raise ValueError("unknown scheduler: " + configs['scheduler']) + # NOTE(xcsong): Custom optimizer might yield poor performance when + # zero-offload is enabled, if you do want to offload optimizer to CPU, + # please set optimizer in ds_config.json, see: + # (https://www.deepspeed.ai/docs/config-json/#optimizer-parameters) + if args.deepspeed: + if "optimizer" in ds_configs: + # NOTE(xcsong): Disable 
custom optimizer if it is set in ds_config, + # extremely useful when enable cpu_offload, DeepspeedCpuAdam + # could be 4~5x faster than torch native adam + optimizer = None + if "scheduler" in ds_configs: + scheduler = None + else: + def scheduler(opt): + return scheduler_type(opt, **configs['scheduler_conf']) + model, optimizer, _, scheduler = deepspeed.initialize( + args=args, model=model, optimizer=optimizer, + lr_scheduler=scheduler, model_parameters=model.parameters()) + final_epoch = None - configs['rank'] = args.rank - configs['is_distributed'] = distributed + configs['rank'] = local_rank + configs['is_distributed'] = distributed # pytorch native ddp + configs['is_deepspeed'] = args.deepspeed # deepspeed configs['use_amp'] = args.use_amp - if start_epoch == 0 and args.rank == 0: + if args.deepspeed and start_epoch == 0: + # NOTE(xcsong): All ranks should call this API, but only rank 0 + # save the general model params. see: + # https://github.com/microsoft/DeepSpeed/issues/2993 + with torch.no_grad(): + model.save_checkpoint(save_dir=model_dir, tag='init') + if args.save_states == "model_only" and local_rank == 0: + convert_zero_checkpoint_to_fp32_state_dict( + model_dir, "{}/init.pt".format(model_dir), tag='init') + os.system("rm -rf {}/{}".format(model_dir, "init")) + elif not args.deepspeed and start_epoch == 0 and local_rank == 0: save_model_path = os.path.join(model_dir, 'init.pt') save_checkpoint(model, save_model_path) @@ -283,6 +401,7 @@ def main(): configs['epoch'] = epoch lr = optimizer.param_groups[0]['lr'] logging.info('Epoch {} TRAIN info lr {}'.format(epoch, lr)) + device = model.local_rank if args.deepspeed else device executor.train(model, optimizer, scheduler, train_data_loader, device, writer, configs, scaler) total_loss, num_seen_utts = executor.cv(model, cv_data_loader, device, @@ -290,20 +409,35 @@ def main(): cv_loss = total_loss / num_seen_utts logging.info('Epoch {} CV info cv_loss {}'.format(epoch, cv_loss)) - if args.rank == 0: - save_model_path = os.path.join(model_dir, '{}.pt'.format(epoch)) - save_checkpoint( - model, save_model_path, { - 'epoch': epoch, - 'lr': lr, - 'cv_loss': cv_loss, - 'step': executor.step - }) + infos = { + 'epoch': epoch, 'lr': lr, 'cv_loss': cv_loss, 'step': executor.step, + 'save_time': datetime.datetime.now().strftime('%d/%m/%Y %H:%M:%S') + } + if local_rank == 0: writer.add_scalar('epoch/cv_loss', cv_loss, epoch) writer.add_scalar('epoch/lr', lr, epoch) + with open("{}/{}.yaml".format(model_dir, epoch), 'w') as fout: + data = yaml.dump(infos) + fout.write(data) + if args.deepspeed: + # NOTE(xcsong): All ranks should call this API, but only rank 0 + # save the general model params. 
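For reference, the "model_only" conversion used above (and again in the per-epoch save below) can also be run offline on any saved tag. A hedged sketch with a hypothetical experiment directory and tag, using the same API the patch imports (behavior may differ in newer DeepSpeed releases):

import torch
from deepspeed.utils.zero_to_fp32 import convert_zero_checkpoint_to_fp32_state_dict

model_dir = "exp/conformer"   # run.sh's $dir (hypothetical)
tag = "10"                    # an epoch tag written by model.save_checkpoint (hypothetical)

# Gather the ZeRO-partitioned weights into a single fp32 state_dict file,
# exactly as the "model_only" branch does for init.pt and {epoch}.pt.
convert_zero_checkpoint_to_fp32_state_dict(
    model_dir, "{}/{}.pt".format(model_dir, tag), tag=tag)

# The resulting .pt is a plain state_dict, loadable without DeepSpeed,
# which is what the decoding stages of run.sh rely on.
state_dict = torch.load("{}/{}.pt".format(model_dir, tag), map_location="cpu")
print(len(state_dict), "tensors")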
see: + # https://github.com/microsoft/DeepSpeed/issues/2993 + with torch.no_grad(): + model.save_checkpoint(save_dir=model_dir, + tag='{}'.format(epoch), + client_state=infos) + if args.save_states == "model_only" and local_rank == 0: + convert_zero_checkpoint_to_fp32_state_dict( + model_dir, "{}/{}.pt".format(model_dir, epoch), + tag='{}'.format(epoch)) + os.system("rm -rf {}/{}".format(model_dir, epoch)) + elif not args.deepspeed and local_rank == 0: + save_model_path = os.path.join(model_dir, '{}.pt'.format(epoch)) + save_checkpoint(model, save_model_path, infos) final_epoch = epoch - if final_epoch is not None and args.rank == 0: + if final_epoch is not None and local_rank == 0: final_model_path = os.path.join(model_dir, 'final.pt') os.remove(final_model_path) if os.path.exists(final_model_path) else None os.symlink('{}.pt'.format(final_epoch), final_model_path) diff --git a/wenet/utils/executor.py b/wenet/utils/executor.py index dc0b69e6e..a128f6d6c 100644 --- a/wenet/utils/executor.py +++ b/wenet/utils/executor.py @@ -37,7 +37,15 @@ def train(self, model, optimizer, scheduler, data_loader, device, writer, epoch = args.get('epoch', 0) accum_grad = args.get('accum_grad', 1) is_distributed = args.get('is_distributed', True) + is_deepspeed = args.get('is_deepspeed', False) use_amp = args.get('use_amp', False) + ds_dtype = args.get('ds_dtype', "fp32") + if ds_dtype == "fp16": + ds_dtype = torch.float16 + elif ds_dtype == "bf16": + ds_dtype = torch.bfloat16 + else: + ds_dtype = None logging.info('using accumulate grad, new batch size is {} times' ' larger than before'.format(accum_grad)) if use_amp: @@ -71,20 +79,44 @@ def train(self, model, optimizer, scheduler, data_loader, device, writer, else: context = nullcontext with context(): - # autocast context - # The more details about amp can be found in - # https://pytorch.org/docs/stable/notes/amp_examples.html - with torch.cuda.amp.autocast(scaler is not None): - loss_dict = model(feats, feats_lengths, target, - target_lengths) - loss = loss_dict['loss'] / accum_grad - if use_amp: - scaler.scale(loss).backward() - else: - loss.backward() + if is_deepspeed: # deepspeed + with torch.cuda.amp.autocast( + enabled=ds_dtype is not None, + dtype=ds_dtype, cache_enabled=False + ): + loss_dict = model(feats, feats_lengths, target, + target_lengths) + loss = loss_dict['loss'] + # NOTE(xcsong): Zeroing the gradients is handled automatically by DeepSpeed after the weights # noqa + # have been updated using a mini-batch. DeepSpeed also performs gradient averaging automatically # noqa + # at the gradient accumulation boundaries and addresses clip_grad_norm internally. 
In other words # noqa + # `model.backward(loss)` is equivalent to `loss.backward() + clip_grad_norm_() + optimizer.zero_grad() + accum_grad` # noqa + # ref: https://www.deepspeed.ai/tutorials/megatron/#using-the-training-api # noqa + model.backward(loss) + else: # pytorch native ddp + # autocast context + # The more details about amp can be found in + # https://pytorch.org/docs/stable/notes/amp_examples.html + with torch.cuda.amp.autocast(scaler is not None): + loss_dict = model(feats, feats_lengths, target, + target_lengths) + loss = loss_dict['loss'] / accum_grad + if use_amp: + scaler.scale(loss).backward() + else: + loss.backward() num_seen_utts += num_utts - if batch_idx % accum_grad == 0: + if is_deepspeed: + if rank == 0 and writer is not None \ + and model.is_gradient_accumulation_boundary(): + writer.add_scalar('train_loss', loss.item(), self.step) + # NOTE(xcsong): The step() function in DeepSpeed engine updates the model parameters as well as the learning rate. There is # noqa + # no need to manually perform scheduler.step(). In other words: `ds_model.step() = optimizer.step() + scheduler.step()` # noqa + # ref: https://www.deepspeed.ai/tutorials/megatron/#using-the-training-api # noqa + model.step() + self.step += 1 + elif not is_deepspeed and batch_idx % accum_grad == 0: if rank == 0 and writer is not None: writer.add_scalar('train_loss', loss, self.step) # Use mixed precision training @@ -125,6 +157,14 @@ def cv(self, model, data_loader, device, args): rank = args.get('rank', 0) epoch = args.get('epoch', 0) log_interval = args.get('log_interval', 10) + is_deepspeed = args.get('is_deepspeed', False) + ds_dtype = args.get('ds_dtype', "fp32") + if ds_dtype == "fp16": + ds_dtype = torch.float16 + elif ds_dtype == "bf16": + ds_dtype = torch.bfloat16 + else: # fp32 + ds_dtype = None # in order to avoid division by 0 num_seen_utts = 1 total_loss = 0.0 @@ -138,7 +178,15 @@ def cv(self, model, data_loader, device, args): num_utts = target_lengths.size(0) if num_utts == 0: continue - loss_dict = model(feats, feats_lengths, target, target_lengths) + if is_deepspeed: + with torch.cuda.amp.autocast( + enabled=ds_dtype is not None, + dtype=ds_dtype, cache_enabled=False + ): + loss_dict = model(feats, feats_lengths, + target, target_lengths) + else: + loss_dict = model(feats, feats_lengths, target, target_lengths) loss = loss_dict['loss'] if torch.isfinite(loss): num_seen_utts += num_utts
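The NOTE comments in the executor above state that model.backward(loss) plus model.step() subsume the manual accumulation logic of the native-DDP branch. A toy sketch of that native expansion, with an illustrative model and hyperparameters and no DeepSpeed required:

import torch

model = torch.nn.Linear(8, 8)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda step: 1.0)
accum_grad, clip = 4, 5.0

for batch_idx in range(8):
    x = torch.randn(2, 8)
    loss = model(x).pow(2).mean() / accum_grad    # scale loss for accumulation
    loss.backward()                               # ~ engine.backward(loss)
    if (batch_idx + 1) % accum_grad == 0:         # accumulation boundary
        torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
        optimizer.step()                          # ~ engine.step() ...
        scheduler.step()                          #   ... which also steps the LR
        optimizer.zero_grad()                     #   ... and zeroes the grads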