diff --git a/.github/labeler.yml b/.github/labeler.yml index 618fe693c4562..70134b84e5fea 100644 --- a/.github/labeler.yml +++ b/.github/labeler.yml @@ -34,6 +34,13 @@ TTS: - tests/collections/tts/** - tests/collections/common/tokenizers/text_to_speech/** +Audio: +- nemo/collections/audio/**/* +- examples/audio/**/* +- tutorials/audio/**/* +- docs/source/audio/**/* +- tests/collections/audio/** + core: - nemo/core/**/* - tests/core/** diff --git a/.github/workflows/_test_template.yml b/.github/workflows/_test_template.yml index 5956a23bdd67f..ebdc99cef8471 100644 --- a/.github/workflows/_test_template.yml +++ b/.github/workflows/_test_template.yml @@ -34,9 +34,15 @@ on: description: Last 2000 characters of the test step's log value: ${{ jobs.main.outputs.log }} jobs: + runner-auto-clean: + runs-on: ${{ inputs.RUNNER }} + steps: + - name: Docker system cleanup + run: | + docker system prune -a --filter "until=48h" --force + main: runs-on: ${{ inputs.RUNNER }} - timeout-minutes: ${{ inputs.TIMEOUT }} outputs: conclusion: ${{ steps.main.conclusion }} log: ${{ steps.main.outputs.log }} @@ -54,6 +60,7 @@ jobs: uses: actions/checkout@v4 - id: main name: Run main script + timeout-minutes: ${{ inputs.TIMEOUT }} run: | set +e ( diff --git a/.github/workflows/cicd-main.yml b/.github/workflows/cicd-main.yml index 77d97fd6e061e..10cd8d1e6561b 100644 --- a/.github/workflows/cicd-main.yml +++ b/.github/workflows/cicd-main.yml @@ -95,12 +95,12 @@ jobs: ### \'\' - OPTIONAL_L0_Unit_Tests_GPU: + L0_Unit_Tests_GPU: needs: [cicd-test-container-setup] uses: ./.github/workflows/_test_template.yml with: RUNNER: self-hosted-azure - TIMEOUT: 30 + TIMEOUT: 60 SCRIPT: | NEMO_NUMBA_MINVER=0.53 pytest -m "not pleasefixme" --with_downloads IS_OPTIONAL: true @@ -213,7 +213,7 @@ jobs: with: RUNNER: self-hosted-azure SCRIPT: | - python examples/nlp/language_modeling/megatron_gpt_quantization.py \ + python examples/nlp/language_modeling/megatron_gpt_ptq.py \ model.restore_from_path=/home/TestData/nlp/megatron_llama/llama_ci.nemo \ quantization.algorithm=null \ export.save_path=/home/TestData/nlp/megatron_llama/ci_baseline @@ -226,7 +226,7 @@ jobs: with: RUNNER: self-hosted-azure SCRIPT: | - python examples/nlp/language_modeling/megatron_gpt_quantization.py \ + python examples/nlp/language_modeling/megatron_gpt_ptq.py \ model.restore_from_path=/home/TestData/nlp/megatron_llama/llama_ci.nemo \ model.tensor_model_parallel_size=2 \ trainer.devices=2 \ @@ -245,7 +245,7 @@ jobs: with: RUNNER: self-hosted-azure SCRIPT: | - python examples/nlp/language_modeling/megatron_gpt_quantization.py \ + python examples/nlp/language_modeling/megatron_gpt_ptq.py \ model.restore_from_path=/home/TestData/nlp/megatron_llama/llama_ci.nemo \ quantization.calib_dataset=/home/TestData/nlp/test_quantization/test.json \ quantization.algorithm=int8_sq \ @@ -274,7 +274,7 @@ jobs: # - name: Checkout repository # uses: actions/checkout@v4 # - run: | - # python examples/nlp/language_modeling/megatron_gpt_quantization.py \ + # python examples/nlp/language_modeling/megatron_gpt_ptq.py \ # model.restore_from_path=/home/TestData/nlp/megatron_llama/llama_ci.nemo \ # model.tensor_model_parallel_size=1 \ # trainer.devices=1 \ @@ -288,6 +288,45 @@ jobs: #- uses: "NVIDIA/NeMo/.github/actions/cancel-workflow@main" # if: "failure()" + L2_QAT_Llama2_INT4: + needs: [cicd-test-container-setup] + runs-on: self-hosted-azure + timeout-minutes: 10 + container: + image: nemoci.azurecr.io/nemo_container_${{ github.run_id }} + options: + # --user 0:128 + --device=/dev/nvidia0 + --gpus all + --shm-size=8g + --env TRANSFORMERS_OFFLINE=0 + --env HYDRA_FULL_ERROR=1 + --volume /mnt/datadrive/TestData:/home/TestData + steps: + - name: Checkout repository + uses: actions/checkout@v4 + - run: | + python examples/nlp/language_modeling/tuning/megatron_gpt_qat.py \ + quantization.algorithm=int4 \ + quantization.num_calib_size=8 \ + trainer.devices=1 \ + trainer.num_nodes=1 \ + trainer.max_steps=4 \ + trainer.val_check_interval=4 \ + +trainer.limit_val_batches=2 \ + exp_manager.explicit_log_dir=llama2_qat_results \ + model.restore_from_path=/home/TestData/nlp/megatron_llama/llama_ci.nemo \ + model.tensor_model_parallel_size=1 \ + model.pipeline_model_parallel_size=1 \ + model.global_batch_size=2 \ + model.data.train_ds.file_names=[/home/TestData/nlp/megatron_sft/quarel.jsonl] \ + model.data.train_ds.concat_sampling_probabilities=[1.0] \ + model.data.validation_ds.file_names=[/home/TestData/nlp/megatron_sft/quarel.jsonl] + + rm -rf llama2_qat_results + - uses: "NVIDIA/NeMo/.github/actions/cancel-workflow@main" + if: "failure()" + # L2: ASR dev run ASR_dev_run_Speech_to_Text: needs: [cicd-test-container-setup] @@ -2352,7 +2391,7 @@ jobs: L2_Megatron_GPT_Pretraining_and_Resume_Training_TP2: needs: [cicd-test-container-setup] - runs-on: self-hosted-azure + runs-on: self-hosted-azure-gpus-2-h100 timeout-minutes: 10 container: image: nemoci.azurecr.io/nemo_container_${{ github.run_id }} @@ -2364,6 +2403,21 @@ jobs: --env TRANSFORMERS_OFFLINE=0 --env HYDRA_FULL_ERROR=1 --volume /mnt/datadrive/TestData:/home/TestData + env: + # This is to improve p2p overlap on H100 + NVTE_FWD_LAYERNORM_SM_MARGIN: 8 + NVTE_BWD_LAYERNORM_SM_MARGIN: 8 + TORCH_NCCL_AVOID_RECORD_STREAMS: 1 + NCCL_MIN_NCHANNELS: 4 + # TP overlap is not supported in docker environment + #NVTE_UB_SPLIT_RS: 0 + #NVTE_UB_ATOMIC_GEMM_RS: 1 + #NVTE_RS_STRIDED_ATOMIC: 1 + #NVTE_UB_FP8_RS: 1 + # Increase p2p chunksize to 2MB + NCCL_P2P_NET_CHUNKSIZE: 2097152 + # Disable gc when switching to/from validation steps + NEMO_MANUAL_GC_IN_VALIDATION: 0 steps: - name: Checkout repository uses: actions/checkout@v4 @@ -2378,8 +2432,17 @@ jobs: trainer.max_steps=3 \ trainer.gradient_clip_val=1.0 \ exp_manager.exp_dir=examples/nlp/language_modeling/gpt_pretrain_results \ + ++model.transformer_engine=True \ + ++model.fp8=True \ + ++model.fp8_hybrid=True \ + ++model.fp8_amax_history_len=1024 \ + ++model.fp8_amax_compute_algo=max \ + ++model.reduce_amax=True \ + ++model.use_te_rng_tracker=True \ + ++model.name=megatron_gpt_full_te_layer_autocast \ + model.ub_tp_comm_overlap=False \ model.tensor_model_parallel_size=2 \ - model.optim.name=fused_adam \ + model.optim.name=distributed_fused_adam \ model.optim.lr=2e-4 \ model.optim.sched.warmup_steps=1 \ model.optim.sched.constant_steps=1 \ @@ -2413,8 +2476,17 @@ jobs: trainer.gradient_clip_val=1.0 \ exp_manager.exp_dir=examples/nlp/language_modeling/gpt_pretrain_results \ exp_manager.resume_if_exists=True \ + ++model.transformer_engine=True \ + ++model.fp8=True \ + ++model.fp8_hybrid=True \ + ++model.fp8_amax_history_len=1024 \ + ++model.fp8_amax_compute_algo=max \ + ++model.reduce_amax=True \ + ++model.use_te_rng_tracker=True \ + ++model.name=megatron_gpt_full_te_layer_autocast \ + model.ub_tp_comm_overlap=False \ model.tensor_model_parallel_size=2 \ - model.optim.name=fused_adam \ + model.optim.name=distributed_fused_adam \ model.optim.lr=2e-4 \ model.optim.sched.warmup_steps=2 \ model.optim.sched.constant_steps=2 \ @@ -2630,6 +2702,89 @@ jobs: # } # } + L2_Megatron_GPT_with_ResetLR_Pretraining_and_Resume_Training_TP2: + needs: [cicd-test-container-setup] + runs-on: self-hosted-azure + timeout-minutes: 10 + container: + image: nemoci.azurecr.io/nemo_container_${{ github.run_id }} + options: + # --user 0:128 + --device=/dev/nvidia0 + --gpus all + --shm-size=8g + --env TRANSFORMERS_OFFLINE=0 + --env HYDRA_FULL_ERROR=1 + --volume /mnt/datadrive/TestData:/home/TestData + steps: + - name: Checkout repository + uses: actions/checkout@v4 + - run: | + python examples/nlp/language_modeling/megatron_gpt_pretraining.py \ + trainer.devices=2 \ + trainer.accelerator=gpu \ + trainer.log_every_n_steps=1 \ + trainer.val_check_interval=3 \ + trainer.limit_val_batches=2 \ + trainer.accumulate_grad_batches=1 \ + trainer.max_steps=3 \ + trainer.precision=bf16 \ + trainer.gradient_clip_val=1.0 \ + exp_manager.exp_dir=examples/nlp/language_modeling/gpt_pretrain_results \ + model.tensor_model_parallel_size=2 \ + model.megatron_amp_O2=True \ + model.optim.name=distributed_fused_adam \ + model.optim.lr=2e-4 \ + model.optim.sched.warmup_steps=2 \ + model.optim.sched.constant_steps=2 \ + model.optim.sched.min_lr=8e-5 \ + model.max_position_embeddings=128 \ + model.encoder_seq_length=128 \ + model.data.seq_length=128 \ + model.tokenizer.vocab_file=/home/TestData/nlp/megatron_gpt/data/gpt/vocab.json \ + model.tokenizer.merge_file=/home/TestData/nlp/megatron_gpt/data/gpt/merges.txt \ + model.num_layers=8 \ + model.hidden_size=256 \ + model.num_attention_heads=8 \ + model.data.data_prefix=[.5,/home/TestData/nlp/megatron_gpt/data/gpt/simple_wiki_gpt_preproc_text_document,.5,/home/TestData/nlp/megatron_gpt/data/gpt/simple_wiki_gpt_preproc_text_document] \ + model.data.index_mapping_dir=examples/nlp/language_modeling/gpt_index_mappings + + python examples/nlp/language_modeling/megatron_gpt_pretraining.py \ + trainer.devices=2 \ + trainer.accelerator=gpu \ + trainer.log_every_n_steps=1 \ + trainer.val_check_interval=3 \ + trainer.limit_val_batches=2 \ + trainer.accumulate_grad_batches=1 \ + trainer.max_steps=6 \ + trainer.precision=bf16 \ + trainer.gradient_clip_val=1.0 \ + exp_manager.exp_dir=examples/nlp/language_modeling/gpt_pretrain_results \ + exp_manager.resume_if_exists=True \ + model.reset_lr=True \ + model.tensor_model_parallel_size=2 \ + model.megatron_amp_O2=True \ + model.optim.name=distributed_fused_adam \ + model.optim.lr=2e-4 \ + model.optim.sched.warmup_steps=2 \ + model.optim.sched.constant_steps=2 \ + model.optim.sched.min_lr=8e-5 \ + model.max_position_embeddings=128 \ + model.encoder_seq_length=128 \ + model.data.seq_length=128 \ + model.tokenizer.vocab_file=/home/TestData/nlp/megatron_gpt/data/gpt/vocab.json \ + model.tokenizer.merge_file=/home/TestData/nlp/megatron_gpt/data/gpt/merges.txt \ + model.num_layers=8 \ + model.hidden_size=256 \ + model.num_attention_heads=8 \ + model.data.data_prefix=[.5,/home/TestData/nlp/megatron_gpt/data/gpt/simple_wiki_gpt_preproc_text_document,.5,/home/TestData/nlp/megatron_gpt/data/gpt/simple_wiki_gpt_preproc_text_document] \ + model.data.index_mapping_dir=examples/nlp/language_modeling/gpt_index_mappings + + rm -rf examples/nlp/language_modeling/gpt_pretrain_results + rm -rf examples/nlp/language_modeling/gpt_index_mappings + - uses: "NVIDIA/NeMo/.github/actions/cancel-workflow@main" + if: "failure()" + L2_Megatron_GPT_with_ALiBi_Pretraining_and_Resume_Training_TP2: needs: [cicd-test-container-setup] runs-on: self-hosted-azure @@ -2823,10 +2978,11 @@ jobs: needs: [cicd-test-container-setup] uses: ./.github/workflows/_test_template.yml with: - RUNNER: self-hosted-azure + RUNNER: self-hosted-azure-gpus-2-h100 SCRIPT: | python examples/nlp/language_modeling/megatron_gpt_pretraining.py \ trainer.devices=2 \ + trainer.accelerator=gpu \ trainer.log_every_n_steps=1 \ trainer.val_check_interval=2 \ trainer.limit_val_batches=2 \ @@ -2835,6 +2991,15 @@ jobs: trainer.precision=bf16 \ trainer.gradient_clip_val=1.0 \ exp_manager.exp_dir=examples/nlp/language_modeling/gpt_pretrain_results \ + ++model.transformer_engine=True \ + ++model.fp8=True \ + ++model.fp8_hybrid=True \ + ++model.fp8_amax_history_len=1024 \ + ++model.fp8_amax_compute_algo=max \ + ++model.reduce_amax=True \ + ++model.use_te_rng_tracker=True \ + ++model.name=megatron_gpt_full_te_layer_autocast \ + model.ub_tp_comm_overlap=False \ model.pipeline_model_parallel_size=2 \ model.tensor_model_parallel_size=1 \ model.mcore_gpt=True \ @@ -2859,12 +3024,15 @@ jobs: model.hidden_size=256 \ model.num_attention_heads=8 \ model.activations_checkpoint_method=block \ + model.activations_checkpoint_granularity=full \ model.activations_checkpoint_num_layers=1 \ + model.data.validation_drop_last=False \ model.data.data_prefix=[.5,/home/TestData/nlp/megatron_gpt/data/gpt/simple_wiki_gpt_preproc_text_document,.5,/home/TestData/nlp/megatron_gpt/data/gpt/simple_wiki_gpt_preproc_text_document] \ model.data.index_mapping_dir=examples/nlp/language_modeling/gpt_index_mappings python examples/nlp/language_modeling/megatron_gpt_pretraining.py \ trainer.devices=2 \ + trainer.accelerator=gpu \ trainer.log_every_n_steps=1 \ trainer.val_check_interval=2 \ trainer.limit_val_batches=2 \ @@ -2876,6 +3044,15 @@ jobs: model.megatron_amp_O2=True \ exp_manager.exp_dir=examples/nlp/language_modeling/gpt_pretrain_results \ exp_manager.resume_if_exists=True \ + ++model.transformer_engine=True \ + ++model.fp8=True \ + ++model.fp8_hybrid=True \ + ++model.fp8_amax_history_len=1024 \ + ++model.fp8_amax_compute_algo=max \ + ++model.reduce_amax=True \ + ++model.use_te_rng_tracker=True \ + ++model.name=megatron_gpt_full_te_layer_autocast \ + model.ub_tp_comm_overlap=False \ model.pipeline_model_parallel_size=2 \ model.tensor_model_parallel_size=1 \ model.optim.name=distributed_fused_adam \ @@ -2898,7 +3075,9 @@ jobs: model.hidden_size=256 \ model.num_attention_heads=8 \ model.activations_checkpoint_method=block \ + model.activations_checkpoint_granularity=full \ model.activations_checkpoint_num_layers=1 \ + model.data.validation_drop_last=False \ model.data.data_prefix=[.5,/home/TestData/nlp/megatron_gpt/data/gpt/simple_wiki_gpt_preproc_text_document,.5,/home/TestData/nlp/megatron_gpt/data/gpt/simple_wiki_gpt_preproc_text_document] \ model.data.index_mapping_dir=examples/nlp/language_modeling/gpt_index_mappings AFTER_SCRIPT: | @@ -3019,6 +3198,47 @@ jobs: - uses: "NVIDIA/NeMo/.github/actions/cancel-workflow@main" if: "failure()" + L2_Megatron_GPT_Reranker: + needs: [cicd-test-container-setup] + runs-on: self-hosted-azure + timeout-minutes: 10 + container: + image: nemoci.azurecr.io/nemo_container_${{ github.run_id }} + options: + # --user 0:128 + --device=/dev/nvidia0 + --gpus all + --shm-size=8g + --env TRANSFORMERS_OFFLINE=0 + --env HYDRA_FULL_ERROR=1 + --volume /mnt/datadrive/TestData:/home/TestData + steps: + - name: Checkout repository + uses: actions/checkout@v4 + - run: | + rm -rf /home/TestData/nlp/megatron_ir/working_dir + + python examples/nlp/information_retrieval/megatron_gpt_reranker_finetuning.py \ + exp_manager.exp_dir='/home/TestData/nlp/megatron_ir/working_dir' \ + model.global_batch_size=4 \ + model.micro_batch_size=4 \ + trainer.devices=1 \ + trainer.num_nodes=1 \ + trainer.max_epochs=null \ + trainer.max_steps=20 \ + trainer.val_check_interval=10 \ + model.restore_from_path='/home/TestData/nlp/megatron_gpt/mcore_45M/megatron_llama.nemo' \ + model.peft.lora_tuning.adapter_dim=8 \ + model.data.validation_ds.file_names=[/home/TestData/nlp/megatron_ir/train.jsonl] \ + model.data.validation_ds.write_embeddings_to_file=True \ + model.data.validation_ds.output_file_path_prefix='/home/TestData/nlp/megatron_ir/working_dir/val_embs' \ + model.data.train_ds.file_names=[/home/TestData/nlp/megatron_ir/train.jsonl] + + + rm -rf /home/TestData/nlp/megatron_ir/working_dir + - uses: "NVIDIA/NeMo/.github/actions/cancel-workflow@main" + if: "failure()" + L2_Megatron_GPT_Embedding: needs: [cicd-test-container-setup] uses: ./.github/workflows/_test_template.yml @@ -3366,6 +3586,80 @@ jobs: rm -rf examples/nlp/language_modeling/t5_pretrain_results rm -rf examples/nlp/language_modeling/t5_index_mappings + L2_Megatron_Core_T5_Pretraining_and_Resume_Training_TP2: + needs: [cicd-test-container-setup] + uses: ./.github/workflows/_test_template.yml + with: + RUNNER: self-hosted-azure + SCRIPT: | + NVTE_FUSED_ATTN=0 NVTE_FLASH_ATTN=0 python examples/nlp/language_modeling/megatron_t5_pretraining.py \ + trainer.devices=2 \ + trainer.log_every_n_steps=1 \ + trainer.max_epochs=null \ + trainer.max_steps=10 \ + trainer.val_check_interval=10 \ + trainer.accumulate_grad_batches=1 \ + trainer.precision=bf16 \ + model.megatron_amp_O2=True \ + exp_manager.exp_dir=examples/nlp/language_modeling/t5_pretrain_results \ + model.mcore_t5=True \ + model.transformer_engine=True \ + model.tensor_model_parallel_size=2 \ + model.micro_batch_size=4 \ + model.global_batch_size=4 \ + model.seq_length=128 \ + model.encoder.num_layers=4 \ + model.encoder.hidden_size=64 \ + model.encoder.num_attention_heads=8 \ + model.decoder.num_layers=4 \ + model.decoder.hidden_size=64 \ + model.decoder.num_attention_heads=8 \ + model.encoder.transformer_block_type='pre_ln' \ + model.decoder.transformer_block_type='pre_ln' \ + model.data.data_prefix=[.5,/home/TestData/nlp/nmt/toy_data/wmt14-de-en.src,.5,/home/TestData/nlp/nmt/toy_data/wmt14-de-en.ref] \ + model.data.index_mapping_dir=examples/nlp/language_modeling/t5_index_mappings \ + model.data.data_impl=text_mmap \ + +model.data.data_impl_kwargs.newline_int=10 \ + +model.data.data_impl_kwargs.header_lines=0 \ + +model.data.data_impl_kwargs.workers=null \ + +model.data.data_impl_kwargs.sort_dataset_paths=False + + NVTE_FUSED_ATTN=0 NVTE_FLASH_ATTN=0 python examples/nlp/language_modeling/megatron_t5_pretraining.py \ + trainer.devices=2 \ + trainer.log_every_n_steps=1 \ + trainer.max_epochs=null \ + trainer.max_steps=10 \ + trainer.val_check_interval=10 \ + trainer.accumulate_grad_batches=1 \ + trainer.precision=bf16 \ + model.megatron_amp_O2=True \ + exp_manager.exp_dir=examples/nlp/language_modeling/t5_pretrain_results \ + exp_manager.resume_if_exists=True \ + model.mcore_t5=True \ + model.transformer_engine=True \ + model.tensor_model_parallel_size=2 \ + model.micro_batch_size=4 \ + model.global_batch_size=4 \ + model.seq_length=128 \ + model.encoder.num_layers=4 \ + model.encoder.hidden_size=64 \ + model.encoder.num_attention_heads=8 \ + model.decoder.num_layers=4 \ + model.decoder.hidden_size=64 \ + model.decoder.num_attention_heads=8 \ + model.encoder.transformer_block_type='pre_ln' \ + model.decoder.transformer_block_type='pre_ln' \ + model.data.data_prefix=[.5,/home/TestData/nlp/nmt/toy_data/wmt14-de-en.src,.5,/home/TestData/nlp/nmt/toy_data/wmt14-de-en.ref] \ + model.data.index_mapping_dir=examples/nlp/language_modeling/t5_index_mappings \ + model.data.data_impl=text_mmap \ + +model.data.data_impl_kwargs.newline_int=10 \ + +model.data.data_impl_kwargs.header_lines=0 \ + +model.data.data_impl_kwargs.workers=null \ + +model.data.data_impl_kwargs.sort_dataset_paths=False + AFTER_SCRIPT: | + rm -rf examples/nlp/language_modeling/t5_pretrain_results + rm -rf examples/nlp/language_modeling/t5_index_mappings + L2_Megatron_T5_with_ALiBi_Pretraining_and_Resume_Training_TP2: needs: [cicd-test-container-setup] uses: ./.github/workflows/_test_template.yml @@ -4236,7 +4530,7 @@ jobs: Nemo_CICD_Test: needs: - #- OPTIONAL_L0_Unit_Tests_GPU + - L0_Unit_Tests_GPU - L0_Unit_Tests_CPU - L2_Community_LLM_Checkpoints_tests_Llama - L2_Community_LLM_Checkpoints_tests_StarCoder @@ -4296,6 +4590,7 @@ jobs: - L2_BioMegatron_Bert_NER_Task - L2_Megatron_GPT_Pretraining_and_Resume_Training_TP2 - L2_Megatron_GPT_with_Rope_Pretraining_and_Resume_Training_TP2 + - L2_Megatron_GPT_with_ResetLR_Pretraining_and_Resume_Training_TP2 - L2_Megatron_GPT_with_ALiBi_Pretraining_and_Resume_Training_TP2 - L2_Megatron_GPT_with_KERPLE_Pretraining_and_Resume_Training_TP2 - L2_Megatron_GPT_Pretraining_and_Resume_Training_PP2 @@ -4310,6 +4605,7 @@ jobs: - L2_Megatron_Change_Partitions_Reduce_TP_Num_Partitions_-2_to_1-_and_PP_Num_Partitions_-1_to_2 - L2_Megatron_Change_Partitions_Increase_TP_Num_Partitions_-2_to_4-_and_PP_Num_Partitions_-1_to_2 - L2_Megatron_T5_Pretraining_and_Resume_Training_TP2 + - L2_Megatron_Core_T5_Pretraining_and_Resume_Training_TP2 - L2_Megatron_T5_with_ALiBi_Pretraining_and_Resume_Training_TP2 - L2_Megatron_T5_with_KERPLE_Pretraining_and_Resume_Training_TP2 - L2_Megatron_T5_Pretraining_and_Resume_Training_PP2 @@ -4351,7 +4647,9 @@ jobs: name: Checkout repository uses: actions/checkout@v4 - - if: ${{ always() && steps.pipeline-conclusion.outputs.FAILED == 'true' }} + - if: ${{ always() && steps.pipeline-conclusion.outputs.FAILED == 'true' && env.SLACK_WEBHOOK != '' }} + env: + SLACK_WEBHOOK: ${{ secrets.SLACK_WEBHOOK }} run: | set -x diff --git a/.gitignore b/.gitignore index 1ff2a92cac64c..1aa5ef00de5ee 100644 --- a/.gitignore +++ b/.gitignore @@ -3,6 +3,7 @@ *.pkl #*.ipynb output +output_2048 result *.pt tests/data/asr @@ -179,3 +180,4 @@ examples/neural_graphs/*.yml .hydra/ nemo_experiments/ +slurm*.out diff --git a/Dockerfile b/Dockerfile index b03c3414e5055..a42ae592a9bd5 100644 --- a/Dockerfile +++ b/Dockerfile @@ -167,12 +167,12 @@ COPY tutorials /workspace/nemo/tutorials RUN printf "#!/bin/bash\njupyter lab --no-browser --allow-root --ip=0.0.0.0" >> start-jupyter.sh && \ chmod +x start-jupyter.sh -# If required, install AIS CLI -RUN if [ "${REQUIRE_AIS_CLI}" = true ]; then \ - INSTALL_MSG=$(/bin/bash scripts/installers/install_ais_cli_latest.sh); INSTALL_CODE=$?; \ +# If required, install AIS CLI and Python AIS SDK +RUN INSTALL_MSG=$(/bin/bash /tmp/nemo/scripts/installers/install_ais_cli_latest.sh && pip install aistore); INSTALL_CODE=$?; \ echo ${INSTALL_MSG}; \ if [ ${INSTALL_CODE} -ne 0 ]; then \ echo "AIS CLI installation failed"; \ + if [ "${REQUIRE_AIS_CLI}" = true ]; then \ exit ${INSTALL_CODE}; \ - else echo "AIS CLI installed successfully"; fi \ - else echo "Skipping AIS CLI installation"; fi + else echo "Skipping AIS CLI installation"; fi \ + else echo "AIS CLI installed successfully"; fi diff --git a/Dockerfile.ci b/Dockerfile.ci index 04ba9df13c7a8..55c31e47f6d3c 100644 --- a/Dockerfile.ci +++ b/Dockerfile.ci @@ -32,9 +32,9 @@ EOF WORKDIR /workspace # Install NeMo requirements -ARG TE_TAG=bfe21c3d68b0a9951e5716fb520045db53419c5e -ARG MODELOPT_VERSION=0.11.0 -ARG MCORE_TAG=02871b4df8c69fac687ab6676c4246e936ce92d0 +ARG TE_TAG=7d576ed25266a17a7b651f2c12e8498f67e0baea +ARG MODELOPT_VERSION=0.13.0 +ARG MCORE_TAG=0bc3547702464501feefeb5523b7a17e591b21fa ARG APEX_TAG=810ffae374a2b9cb4b5c5e28eaeca7d7998fca0c RUN \ --mount=type=bind,source=requirements,target=requirements \ @@ -47,7 +47,9 @@ pip install --no-cache-dir --no-build-isolation --extra-index-url https://pypi.n "megatron_core @ git+https://github.com/NVIDIA/Megatron-LM.git@${MCORE_TAG}" \ "nvidia-modelopt[torch]~=${MODELOPT_VERSION}" \ "apex @ git+https://github.com/NVIDIA/apex.git@${APEX_TAG}" \ +"unstructured==0.14.9" \ "llama-index==0.10.43" \ +"onnxscript @ git+https://github.com/microsoft/onnxscript" \ -r tools/ctc_segmentation/requirements.txt \ ".[all]" @@ -60,6 +62,22 @@ git checkout ${MCORE_TAG} && \ popd && \ popd export PYTHONPATH="${PYTHONPATH}:/workspace/Megatron-LM" + +# Mamba dependancy installation +git clone https://github.com/state-spaces/mamba.git && \ + cd mamba && \ + git checkout v2.0.3 && \ + python setup.py install && \ + cd .. && \ + rm -rf mamba + +git clone https://github.com/Dao-AILab/causal-conv1d && \ + cd causal-conv1d && \ + git checkout v1.2.2.post1 && \ + python setup.py install && \ + cd .. && \ + rm -rf causal-conv1d + EOF # Copy over NeMo code diff --git a/docs/source/asr/speaker_recognition/api.rst b/docs/source/asr/speaker_recognition/api.rst index 0f95cb281145a..cdadc4dd5f1d0 100644 --- a/docs/source/asr/speaker_recognition/api.rst +++ b/docs/source/asr/speaker_recognition/api.rst @@ -6,6 +6,6 @@ Model Classes ------------- .. autoclass:: nemo.collections.asr.models.label_models.EncDecSpeakerLabelModel :show-inheritance: - :members: setup_finetune_model, get_embedding, verify_speakers + :members: setup_finetune_model, get_embedding, verify_speakers, verify_speakers_batch diff --git a/docs/source/asr/speaker_recognition/results.rst b/docs/source/asr/speaker_recognition/results.rst index a6029595823fd..e607a35a49e68 100644 --- a/docs/source/asr/speaker_recognition/results.rst +++ b/docs/source/asr/speaker_recognition/results.rst @@ -91,7 +91,7 @@ Speaker Verification Inference Speaker Verification is a task of verifying if two utterances are from the same speaker or not. -We provide a helper function to verify the audio files and return True if two provided audio files are from the same speaker, False otherwise. +We provide a helper function to verify the audio files (also in a batch) and return True if provided pair of audio files is from the same speaker, False otherwise. The audio files should be 16KHz mono channel wav files. @@ -99,6 +99,12 @@ The audio files should be 16KHz mono channel wav files. speaker_model = EncDecSpeakerLabelModel.from_pretrained(model_name="titanet_large") decision = speaker_model.verify_speakers('path/to/one/audio_file','path/to/other/audio_file') + decisions = speaker_model.verify_speakers_batch([ + ('/path/to/audio_0_0', '/path/to/audio_0_1'), + ('/path/to/audio_1_0', '/path/to/audio_1_1'), + ('/path/to/audio_2_0', '/path/to/audio_2_1'), + ('/path/to/audio_3_0', '/path/to/audio_3_1') + ], batch_size=4, device='cuda') NGC Pretrained Checkpoints diff --git a/docs/source/core/exp_manager.rst b/docs/source/core/exp_manager.rst index efb55b0feabb9..ce5f7a9cb087c 100644 --- a/docs/source/core/exp_manager.rst +++ b/docs/source/core/exp_manager.rst @@ -203,10 +203,163 @@ file followed by a graceful exit from the run. The checkpoint saved upon preempt This feature is useful to increase utilization on clusters. The ``PreemptionCallback`` is enabled by default. To disable it simply add ``create_preemption_callback: False`` under exp_manager in the config YAML file. +Stragglers Detection +---------------------- -.. _nemo_multirun-label: +.. _exp_manager_straggler_det_support-label: + +.. note:: + Stragglers Detection feature is included in the optional NeMo resiliency package. + +Distributed training can be affected by stragglers, which are slow workers that slow down the overall training process. +NeMo provides a straggler detection feature that can identify slower GPUs. + +This feature is implemented in the ``StragglerDetectionCallback``, which is disabled by default. + +The callback computes normalized GPU performance scores, which are scalar values ranging from 0.0 (worst) to 1.0 (best). +A performance score can be interpreted as the ratio of current performance to reference performance. + +There are two types of performance scores provided by the callback: + - Relative GPU performance score: The best-performing GPU in the workload is used as a reference. + - Individual GPU performance score: The best historical performance of the GPU is used as a reference. + +Examples: + - If the relative performance score is 0.5, it means that a GPU is twice slower than the fastest GPU. + - If the individual performance score is 0.5, it means that a GPU is twice slower than its best observed performance. + +If a GPU performance score drops below the specified threshold, it is identified as a straggler. + +To enable straggler detection, add ``create_straggler_detection_callback: True`` under exp_manager in the config YAML file. +You might also want to adjust the callback parameters: + +.. code-block:: yaml + + exp_manager: + ... + create_straggler_detection_callback: True + straggler_detection_callback_params: + report_time_interval: 300 # Interval [seconds] of the straggler check + calc_relative_gpu_perf: True # Calculate relative GPU performance + calc_individual_gpu_perf: True # Calculate individual GPU performance + num_gpu_perf_scores_to_log: 5 # Log 5 best and 5 worst GPU performance scores, even if no stragglers are detected + gpu_relative_perf_threshold: 0.7 # Threshold for relative GPU performance scores + gpu_individual_perf_threshold: 0.7 # Threshold for individual GPU performance scores + stop_if_detected: True # Terminate the workload if stragglers are detected + +Straggler detection might involve inter-rank synchronization, and should be invoked with reasonable frequency (e.g. every few minutes). + +.. _exp_manager_straggler_det_support-label: + +.. note:: + Stragglers Detection feature is included in the optional NeMo resiliency package. + +Distributed training can be affected by stragglers, which are slow workers that slow down the overall training process. +NeMo provides a straggler detection feature that can identify slower GPUs. + +This feature is implemented in the ``StragglerDetectionCallback``, which is disabled by default. + +The callback computes normalized GPU performance scores, which are scalar values ranging from 0.0 (worst) to 1.0 (best). +A performance score can be interpreted as the ratio of current performance to reference performance. + +There are two types of performance scores provided by the callback: + - Relative GPU performance score: The best-performing GPU in the workload is used as a reference. + - Individual GPU performance score: The best historical performance of the GPU is used as a reference. +Examples: + - If the relative performance score is 0.5, it means that a GPU is twice slower than the fastest GPU. + - If the individual performance score is 0.5, it means that a GPU is twice slower than its best observed performance. +If a GPU performance score drops below the specified threshold, it is identified as a straggler. + +To enable straggler detection, add ``create_straggler_detection_callback: True`` under exp_manager in the config YAML file. +You might also want to adjust the callback parameters: + +.. code-block:: yaml + + exp_manager: + ... + create_straggler_detection_callback: True + straggler_detection_callback_params: + report_time_interval: 300 # Interval [seconds] of the straggler check + calc_relative_gpu_perf: True # Calculate relative GPU performance + calc_individual_gpu_perf: True # Calculate individual GPU performance + num_gpu_perf_scores_to_log: 5 # Log 5 best and 5 worst GPU performance scores, even if no stragglers are detected + gpu_relative_perf_threshold: 0.7 # Threshold for relative GPU performance scores + gpu_individual_perf_threshold: 0.7 # Threshold for individual GPU performance scores + stop_if_detected: True # Terminate the workload if stragglers are detected + +Straggler detection might involve inter-rank synchronization, and should be invoked with reasonable frequency (e.g. every few minutes). + +Fault Tolerance +--------------- + +.. _exp_manager_fault_tolerance_support-label: + +.. note:: + Fault Tolerance feature is included in the optional NeMo resiliency package. + +When training DNN models, faults may occur, hindering the progress of the entire training process. +This is particularly common in distributed, multi-node training scenarios, with many nodes and GPUs involved. + +NeMo incorporates a fault tolerance mechanism to detect training halts. +In response, it can terminate a hung workload and, if requested, restart it from the last checkpoint. + +Fault tolerance ("FT") relies on a special launcher (``ft_launcher``), which is a modified ``torchrun``. +The FT launcher runs background processes called rank monitors. **You need to use ft_launcher to start +your workload if you are using FT**. I.e., `NeMo-Framework-Launcher `_ +can be used to generate SLURM batch scripts with FT support. + +Each training process (rank) sends `heartbeats` to its monitor during training and validation steps. +If a rank monitor stops receiving `heartbeats`, a training failure is detected. + +Fault detection is implemented in the ``FaultToleranceCallback`` and is disabled by default. +To enable it, add a ``create_fault_tolerance_callback: True`` option under ``exp_manager`` in the +config YAML file. Additionally, you can customize FT parameters by adding ``fault_tolerance`` section: + +.. code-block:: yaml + + exp_manager: + ... + create_fault_tolerance_callback: True + fault_tolerance: + initial_rank_heartbeat_timeout: 600 # wait for 10 minutes for the initial heartbeat + rank_heartbeat_timeout: 300 # wait for 5 minutes for subsequent heartbeats + calculate_timeouts: True # estimate more accurate timeouts based on observed intervals + +Timeouts for fault detection need to be adjusted for a given workload: + * ``initial_rank_heartbeat_timeout`` should be long enough to allow for workload initialization. + * ``rank_heartbeat_timeout`` should be at least as long as the longest possible interval between steps. + +**Importantly, `heartbeats` are not sent during checkpoint loading and saving**, so time for +checkpointing related operations should be taken into account. + +If ``calculate_timeouts: True`` timeouts will be automatically estimated based on observed intervals. +Estimated timeouts take precedence over timeouts defined in the config file. **Timeouts are estimated after +checkpoint loading and saving was observed**. For example, in multi-part training started from scratch, +estimated timeouts won't be available during the first run. Estimated timeouts are stored in the checkpoint. + +``max_subsequent_job_failures`` allows for the automatic continuation of training on a SLURM cluster. +This feature requires SLURM job to be scheduled with ``NeMo-Framework-Launcher``. If ``max_subsequent_job_failures`` +value is `>0` continuation job is prescheduled. It will continue the work until ``max_subsequent_job_failures`` +subsequent jobs failed (SLURM job exit code is `!= 0`) or the training is completed successfully +("end of training" marker file is produced by the ``FaultToleranceCallback``, i.e. due to iters or time limit reached). + +All FT configuration items summary: + * ``workload_check_interval`` (float, default=5.0) Periodic workload check interval [seconds] in the workload monitor. + * ``initial_rank_heartbeat_timeout`` (Optional[float], default=60.0 * 60.0) Timeout for the first heartbeat from a rank. + * ``rank_heartbeat_timeout`` (Optional[float], default=45.0 * 60.0) Timeout for subsequent heartbeats from a rank. + * ``calculate_timeouts`` (bool, default=True) Try to calculate ``rank_heartbeat_timeout`` and ``initial_rank_heartbeat_timeout`` + based on the observed heartbeat intervals. + * ``rank_termination_signal`` (signal.Signals, default=signal.SIGKILL) Signal used to terminate the rank when failure is detected. + * ``log_level`` (str, default='INFO') Log level for the FT client and server(rank monitor). + * ``max_rank_restarts`` (int, default=0) Used by FT launcher. Max number of restarts for a rank. + If ``>0`` ranks will be restarted on existing nodes in case of a failure. + * ``max_subsequent_job_failures`` (int, default=0) Used by FT launcher. How many subsequent job failures are allowed until stopping autoresuming. + ``0`` means do not autoresume. + * ``additional_ft_launcher_args`` (str, default='') Additional FT launcher params (for advanced use). + + +.. _nemo_multirun-label: Hydra Multi-Run with NeMo ------------------------- diff --git a/docs/source/features/memory_optimizations.rst b/docs/source/features/memory_optimizations.rst index 4d363670fedf3..1fe8215864a97 100644 --- a/docs/source/features/memory_optimizations.rst +++ b/docs/source/features/memory_optimizations.rst @@ -105,3 +105,24 @@ Implement MQA or GQA NeMo's support for GQA and MQA is enabled through the integration of Megatron Core's Attention mechanism. The underlying implementation details can be explored within the Attention class of Megatron Core, which provides the functional backbone for these advanced attention methods. To understand the specific modifications and implementations of MQA and GQA, refer to the source code in the Attention class: Check implementation details from Attention Class in Megatron Core Repo: https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/transformer/attention.py#L49 + + +CPU Offloading +-------------- + +Overview +^^^^^^^^ + +CPU Offloading in NeMo is a feature that reduces the peak memory usage of the GPU by offloading activations and inactive weights to CPU storage. NeMo supports offloading at the transformer layer level, allowing users to specify the number of transformer layers in their language model that require CPU offloading. During the forward pass, NeMo offloads activations at the optimal time and reloads them as needed during the backward pass. + +Features +^^^^^^^^ +> Supports training models with long sequence lengths by managing activation memory efficiently. +> Enables high batch sizes per GPU by offloading activation memory. +> Overlaps computation with data transfers (Host2Device and Device2Host) during offloading and reloading. + +Usage +^^^^^ +> Set cpu_offloading to True to enable CPU offloading. +> Set cpu_offloading_num_layers to a value between 0 and the total number of layers in the model minus one. +> Set cpu_offloading_activations and cpu_offloading_weights based on your needs to offload activations only, weights only, or both. diff --git a/docs/source/index.rst b/docs/source/index.rst index f3d68500f44dd..f10ae126267bc 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -12,7 +12,7 @@ NVIDIA NeMo Framework is an end-to-end, cloud-native framework designed to build - Flash Attention - Activation Recomputation - Positional Embeddings and Positional Interpolation -- Post-Training Quantization (PTQ) with ModelOpt +- Post-Training Quantization (PTQ) and Quantization Aware Training (QAT) with `TensorRT Model Optimizer `_ - Sequence Packing `NVIDIA NeMo Framework `_ has separate collections for: diff --git a/docs/source/nlp/quantization.rst b/docs/source/nlp/quantization.rst index 747938bebedd1..1d016dd0c3a8a 100644 --- a/docs/source/nlp/quantization.rst +++ b/docs/source/nlp/quantization.rst @@ -55,6 +55,10 @@ Table below presents verified model support matrix for popular LLM architectures - ✅ - ✅ - ✅ + * - `Nemotron-4 340b `_ (Base, Instruct, Reward) + - ✅ + - ✅ + - ✅ * - StarCoder 2 - ✅ - ✅ @@ -67,14 +71,14 @@ Table below presents verified model support matrix for popular LLM architectures Example ^^^^^^^ -The example below shows how to quantize the Llama2 70b model into FP8 precision, using tensor parallelism of 8 on a single DGX H100 node. The quantized model is designed for serving using 2 GPUs specified with the ``export.inference_tensor_parallel`` parameter. +The example below shows how to quantize the Llama3 70b model into FP8 precision, using tensor parallelism of 8 on a single DGX H100 node. The quantized model is designed for serving using 2 GPUs specified with the ``export.inference_tensor_parallel`` parameter. The script must be launched correctly with the number of processes equal to tensor parallelism. This is achieved with the ``torchrun`` command below: .. code-block:: bash - torchrun --nproc-per-node 8 examples/nlp/language_modeling/megatron_gpt_quantization.py \ - model.restore_from_path=llama2-70b-base-bf16.nemo \ + torchrun --nproc-per-node 8 examples/nlp/language_modeling/megatron_gpt_ptq.py \ + model.restore_from_path=llama3-70b-base-bf16.nemo \ model.tensor_model_parallel_size=8 \ model.pipeline_model_parallel_size=1 \ trainer.num_nodes=1 \ @@ -83,15 +87,15 @@ The script must be launched correctly with the number of processes equal to tens quantization.algorithm=fp8 \ export.decoder_type=llama \ export.inference_tensor_parallel=2 \ - export.save_path=llama2-70b-base-fp8-qnemo - + export.save_path=llama3-70b-base-fp8-qnemo +For large models, the command can be used in multi-node setting. For example, this can be done with `NeMo Framework Launcher `_ using Slurm. The output directory stores the following files: .. code-block:: bash - llama2-70b-base-fp8-qnemo/ + llama3-70b-base-fp8-qnemo/ ├── config.json ├── rank0.safetensors ├── rank1.safetensors @@ -103,12 +107,12 @@ The TensorRT-LLM engine can be conveniently built and run using ``TensorRTLLM`` .. code-block:: python - from nemo.export import TensorRTLLM + from nemo.export.tensorrt_llm import TensorRTLLM trt_llm_exporter = TensorRTLLM(model_dir="/path/to/trt_llm_engine_folder") trt_llm_exporter.export( - nemo_checkpoint_path="llama2-70b-base-fp8-qnemo", + nemo_checkpoint_path="llama3-70b-base-fp8-qnemo", model_type="llama", ) trt_llm_exporter.forward(["Hi, how are you?", "I am good, thanks, how about you?"]) @@ -119,7 +123,7 @@ Alternatively, it can also be built directly using ``trtllm-build`` command, see .. code-block:: bash trtllm-build \ - --checkpoint_dir llama2-70b-base-fp8-qnemo \ + --checkpoint_dir llama3-70b-base-fp8-qnemo \ --output_dir /path/to/trt_llm_engine_folder \ --max_batch_size 8 \ --max_input_len 2048 \ @@ -129,19 +133,64 @@ Alternatively, it can also be built directly using ``trtllm-build`` command, see Known issues ^^^^^^^^^^^^ -* Currently in NeMo, quantizing and building TensorRT-LLM engines is limited to single-node use cases. -* The supported and tested model family is Llama2. Quantizing other model types is experimental and may not be fully supported. +* Currently with ``nemo.export`` module building TensorRT-LLM engines for quantized "qnemo" models is limited to single-node deployments. -Please refer to the following papers for more details on quantization techniques. +Quantization-Aware Training (QAT) +--------------------------------- -References ----------- +QAT is the technique of fine-tuning a quantized model to recover model quality degradation due to quantization. +During QAT, the quantization scaling factors computed during PTQ are frozen and the model weights are fine-tuned. +While QAT requires much more compute resources than PTQ, it is highly effective in recovering model quality. +To perform QAT on a calibrated model from PTQ, you need to further fine-tune the model on a downstream task using a small dataset before exporting to TensorRT-LLM. +You can reuse your training pipeline for QAT. +As a rule of thumb, we recommend QAT for 1-10% original training duration and a small learning rate, e.g. 1e-5 for Adam optimizer. +If you are doing QAT on an SFT model where learning rates and finetuning dataset size are already small, you can continue using the same SFT learning rate and dataset size as a starting point for QAT. +Since QAT is done after PTQ, the supported model families are the same as for PTQ. + + +Example +^^^^^^^ + +The example below shows how to perform PTQ and QAT on a Supervised Finetuned Llama2 7B model to INT4 precision. +The script is tested using tensor parallelism of 8 on 8x RTX 6000 Ada 48GB GPUs. Alternatively, a single DGX A100 node with 8x 40GB GPUs can be used for the same purpose. +For bigger models like Llama2 70B, you may need to use one or more DGX H100 nodes with 8x 80GB GPUs each. + +The example is a modified version of the `SFT with Llama 2 playbook `_. +Please refer to the playbook for more details on setting up a BF16 NeMo model and the ``databricks-dolly-15k`` instruction dataset. -`Integer Quantization for Deep Learning Inference: Principles and Empirical Evaluation, 2020 `_ +First we will run the SFT example command from the playbook as-is to train a Llama2 7B SFT model for 100 steps. +Make sure to change ``trainer.max_steps=50`` to ``trainer.max_steps=100`` for the ``examples/nlp/language_modeling/tuning/megatron_gpt_finetuning.py`` script. +This will take ~2 hours to produce a model checkpoint with validation loss approximately ``1.15`` that we will use for PTQ and QAT next. -`FP8 Formats for Deep Learning, 2022 `_ +For Quantization, we use a modified version of the sft script and config file which includes the quantization and TensorRT-LLM export support. +Along with the new parameters, make sure to pass the same parameters you passed for SFT training except the model restore path will be the SFT output ``.nemo`` file. +The below example command will perform PTQ on the SFT model checkpoint followed by SFT again (QAT) which can then be exported for TensorRT-LLM inference. The script will take ~2-3 hours to complete. + +.. code-block:: bash + + torchrun --nproc-per-node 8 examples/nlp/language_modeling/tuning/megatron_gpt_qat.py \ + trainer.num_nodes=1 \ + trainer.devices=8 \ + trainer.precision=bf16 \ + trainer.max_steps=100 \ + model.restore_from_path= \ + model.global_batch_size=128 \ + quantization.algorithm=int4 \ + # other parameters from sft training + +As you can see from the logs, the INT4 PTQ model has a validation loss of approximately ``1.31`` and the QAT model has a validation loss of approximately ``1.17`` which is very close to the BF16 model loss of ``1.15``. +This script will produce a quantized ``.nemo`` checkpoint at the experiment manager log directory (in the config yaml file) that can be used for further training. +It can also optionally produce an exported TensorRT-LLM engine directory or a ``.qnemo`` file that can be used for inference by setting the ``export`` parameters similar to the PTQ example. +Note that you may tweak the QAT trainer steps and learning rate if needed to achieve better model quality. + + +References +---------- -`SmoothQuant: Accurate and Efficient Post-Training Quantization for Large Language Models, 2022 `_ +Please refer to the following papers for more details on quantization techniques: -`AWQ: Activation-aware Weight Quantization for LLM Compression and Acceleration, 2023 `_ +* `Integer Quantization for Deep Learning Inference: Principles and Empirical Evaluation, 2020 `_ +* `FP8 Formats for Deep Learning, 2022 `_ +* `SmoothQuant: Accurate and Efficient Post-Training Quantization for Large Language Models, 2022 `_ +* `AWQ: Activation-aware Weight Quantization for LLM Compression and Acceleration, 2023 `_ diff --git a/docs/source/starthere/intro.rst b/docs/source/starthere/intro.rst index ebbe1551c39ee..8edb435bec62f 100644 --- a/docs/source/starthere/intro.rst +++ b/docs/source/starthere/intro.rst @@ -96,13 +96,13 @@ This section details the steps to clone and install the Megatron Core. git checkout a5415fcfacef2a37416259bd38b7c4b673583675 && \ pip install . -Model Optimizer Installation +TensorRT Model Optimizer Installation -This final step involves installing the Model Optimizer package. +This final step involves installing the TensorRT Model Optimizer package. .. code-block:: bash - pip install nvidia-modelopt[torch]~=0.11.0 --extra-index-url https://pypi.nvidia.com + pip install nvidia-modelopt[torch]~=0.13.0 --extra-index-url https://pypi.nvidia.com .. code-block:: bash diff --git a/docs/source/starthere/tutorials.rst b/docs/source/starthere/tutorials.rst index 0298dbdf6d4b5..6f31b9398d47e 100644 --- a/docs/source/starthere/tutorials.rst +++ b/docs/source/starthere/tutorials.rst @@ -65,7 +65,7 @@ Tutorial Overview - `DreamBooth Tutorial `_ * - Multimodal - Preparations and Advanced Applications: Stable Diffusion XL Quantization Tutorial - - `DreamBooth Tutorial `_ + - `SDXL Quantization Tutorial `_ .. list-table:: **Automatic Speech Recognition (ASR) Tutorials** :widths: 15 30 55 diff --git a/examples/audio_tasks/audio_to_audio_eval.py b/examples/audio/audio_to_audio_eval.py similarity index 96% rename from examples/audio_tasks/audio_to_audio_eval.py rename to examples/audio/audio_to_audio_eval.py index ab6623df298d0..4e60b2ec2b528 100644 --- a/examples/audio_tasks/audio_to_audio_eval.py +++ b/examples/audio/audio_to_audio_eval.py @@ -73,9 +73,9 @@ from torchmetrics.audio.stoi import ShortTimeObjectiveIntelligibility from tqdm import tqdm -from nemo.collections.asr.data import audio_to_audio_dataset -from nemo.collections.asr.data.audio_to_audio_lhotse import LhotseAudioToTargetDataset -from nemo.collections.asr.metrics.audio import AudioMetricWrapper +from nemo.collections.audio.data import audio_to_audio_dataset +from nemo.collections.audio.data.audio_to_audio_lhotse import LhotseAudioToTargetDataset +from nemo.collections.audio.metrics.audio import AudioMetricWrapper from nemo.collections.common.data.lhotse import get_lhotse_dataloader_from_config from nemo.collections.common.parts.preprocessing import manifest from nemo.core.config import hydra_runner @@ -107,8 +107,7 @@ class AudioEvaluationConfig(process_audio.ProcessConfig): def get_evaluation_dataloader(config): - """Prepare a dataloader for evaluation. - """ + """Prepare a dataloader for evaluation.""" if config.get("use_lhotse", False): return get_lhotse_dataloader_from_config( config, global_rank=0, world_size=1, dataset=LhotseAudioToTargetDataset() @@ -128,8 +127,7 @@ def get_evaluation_dataloader(config): def get_metrics(cfg: AudioEvaluationConfig): - """Prepare a dictionary with metrics. - """ + """Prepare a dictionary with metrics.""" available_metrics = ['sdr', 'sisdr', 'stoi', 'estoi', 'pesq'] metrics = dict() @@ -203,9 +201,10 @@ def main(cfg: AudioEvaluationConfig): num_files = 0 - with open(process_cfg.output_filename, 'r') as f_processed, open( - temporary_manifest_filepath, 'w', encoding='utf-8' - ) as f_tmp: + with ( + open(process_cfg.output_filename, 'r') as f_processed, + open(temporary_manifest_filepath, 'w', encoding='utf-8') as f_tmp, + ): for line_processed in f_processed: data_processed = json.loads(line_processed) diff --git a/examples/audio_tasks/speech_enhancement.py b/examples/audio/audio_to_audio_train.py similarity index 93% rename from examples/audio_tasks/speech_enhancement.py rename to examples/audio/audio_to_audio_train.py index 33a25c1c107c7..2dc91036234fe 100644 --- a/examples/audio_tasks/speech_enhancement.py +++ b/examples/audio/audio_to_audio_train.py @@ -16,7 +16,7 @@ # Training the model Basic run (on CPU for 50 epochs): - python examples/audio_tasks/speech_enhancement.py \ + python examples/audio/audio_to_audio_train.py \ # (Optional: --config-path= --config-name=) \ model.train_ds.manifest_filepath="" \ model.validation_ds.manifest_filepath="" \ @@ -32,7 +32,7 @@ import torch from omegaconf import OmegaConf -from nemo.collections.asr.models.enhancement_models import ( +from nemo.collections.audio.models.enhancement import ( EncMaskDecAudioToAudioModel, PredictiveAudioToAudioModel, ScoreBasedGenerativeAudioToAudioModel, @@ -43,8 +43,7 @@ class ModelType(str, Enum): - """Enumeration with the available model types. - """ + """Enumeration with the available model types.""" MaskBased = 'mask_based' Predictive = 'predictive' @@ -52,8 +51,7 @@ class ModelType(str, Enum): def get_model_class(model_type: ModelType): - """Get model class for a given model type. - """ + """Get model class for a given model type.""" if model_type == ModelType.MaskBased: return EncMaskDecAudioToAudioModel elif model_type == ModelType.Predictive: diff --git a/examples/audio_tasks/conf/beamforming.yaml b/examples/audio/conf/beamforming.yaml similarity index 91% rename from examples/audio_tasks/conf/beamforming.yaml rename to examples/audio/conf/beamforming.yaml index 3abc4f134e642..9b1b743e60e51 100644 --- a/examples/audio_tasks/conf/beamforming.yaml +++ b/examples/audio/conf/beamforming.yaml @@ -41,17 +41,17 @@ model: pin_memory: true encoder: - _target_: nemo.collections.asr.modules.audio_preprocessing.AudioToSpectrogram + _target_: nemo.collections.audio.modules.transforms.AudioToSpectrogram fft_length: 512 # Length of the window and FFT for calculating spectrogram hop_length: 256 # Hop length for calculating spectrogram decoder: - _target_: nemo.collections.asr.modules.audio_preprocessing.SpectrogramToAudio + _target_: nemo.collections.audio.modules.transforms.SpectrogramToAudio fft_length: 512 # Length of the window and FFT for calculating spectrogram hop_length: 256 # Hop length for calculating spectrogram mask_estimator: - _target_: nemo.collections.asr.modules.audio_modules.MaskEstimatorRNN + _target_: nemo.collections.audio.modules.masking.MaskEstimatorRNN num_outputs: ${model.num_outputs} num_subbands: 257 # Number of subbands of the input spectrogram num_features: 256 # Number of features at RNN input @@ -59,11 +59,11 @@ model: bidirectional: true # Use bi-directional RNN mask_processor: - _target_: nemo.collections.asr.modules.audio_modules.MaskBasedBeamformer # Mask-based multi-channel processing + _target_: nemo.collections.audio.modules.masking.MaskBasedBeamformer # Mask-based multi-channel processing ref_channel: 0 # Reference channel for the output loss: - _target_: nemo.collections.asr.losses.SDRLoss + _target_: nemo.collections.audio.losses.SDRLoss scale_invariant: true # Use scale-invariant SDR metrics: diff --git a/examples/audio_tasks/conf/beamforming_flex_channels.yaml b/examples/audio/conf/beamforming_flex_channels.yaml similarity index 93% rename from examples/audio_tasks/conf/beamforming_flex_channels.yaml rename to examples/audio/conf/beamforming_flex_channels.yaml index 29fc87acf93d5..8a22bf4598120 100644 --- a/examples/audio_tasks/conf/beamforming_flex_channels.yaml +++ b/examples/audio/conf/beamforming_flex_channels.yaml @@ -39,17 +39,17 @@ model: permute_channels: true encoder: - _target_: nemo.collections.asr.modules.audio_preprocessing.AudioToSpectrogram + _target_: nemo.collections.audio.modules.transforms.AudioToSpectrogram fft_length: 512 # Length of the window and FFT for calculating spectrogram hop_length: 256 # Hop length for calculating spectrogram decoder: - _target_: nemo.collections.asr.modules.audio_preprocessing.SpectrogramToAudio + _target_: nemo.collections.audio.modules.transforms.SpectrogramToAudio fft_length: ${model.encoder.fft_length} hop_length: ${model.encoder.hop_length} mask_estimator: - _target_: nemo.collections.asr.modules.audio_modules.MaskEstimatorFlexChannels + _target_: nemo.collections.audio.modules.masking.MaskEstimatorFlexChannels num_outputs: ${model.num_outputs} # number of output masks num_subbands: 257 # number of subbands for the input spectrogram num_blocks: 5 # number of blocks in the model @@ -67,7 +67,7 @@ model: mask_processor: # Mask-based multi-channel processor - _target_: nemo.collections.asr.modules.audio_modules.MaskBasedBeamformer + _target_: nemo.collections.audio.modules.masking.MaskBasedBeamformer filter_type: pmwf # parametric multichannel wiener filter filter_beta: 0.0 # mvdr filter_rank: one @@ -78,7 +78,7 @@ model: num_subbands: ${model.mask_estimator.num_subbands} loss: - _target_: nemo.collections.asr.losses.SDRLoss + _target_: nemo.collections.audio.losses.SDRLoss convolution_invariant: true # convolution-invariant loss sdr_max: 30 # soft threshold for SDR diff --git a/examples/audio_tasks/conf/masking.yaml b/examples/audio/conf/masking.yaml similarity index 91% rename from examples/audio_tasks/conf/masking.yaml rename to examples/audio/conf/masking.yaml index 68adca116aa50..3f1c7a6a6e3c2 100644 --- a/examples/audio_tasks/conf/masking.yaml +++ b/examples/audio/conf/masking.yaml @@ -39,17 +39,17 @@ model: pin_memory: true encoder: - _target_: nemo.collections.asr.modules.audio_preprocessing.AudioToSpectrogram + _target_: nemo.collections.audio.modules.transforms.AudioToSpectrogram fft_length: 512 # Length of the window and FFT for calculating spectrogram hop_length: 256 # Hop length for calculating spectrogram decoder: - _target_: nemo.collections.asr.modules.audio_preprocessing.SpectrogramToAudio + _target_: nemo.collections.audio.modules.transforms.SpectrogramToAudio fft_length: 512 # Length of the window and FFT for calculating spectrogram hop_length: 256 # Hop length for calculating spectrogram mask_estimator: - _target_: nemo.collections.asr.modules.audio_modules.MaskEstimatorRNN + _target_: nemo.collections.audio.modules.masking.MaskEstimatorRNN num_outputs: ${model.num_outputs} num_subbands: 257 # Number of subbands of the input spectrogram num_features: 256 # Number of features at RNN input @@ -57,11 +57,11 @@ model: bidirectional: true # Use bi-directional RNN mask_processor: - _target_: nemo.collections.asr.modules.audio_modules.MaskReferenceChannel # Apply mask on the reference channel + _target_: nemo.collections.audio.modules.masking.MaskReferenceChannel # Apply mask on the reference channel ref_channel: 0 # Reference channel for the output loss: - _target_: nemo.collections.asr.losses.SDRLoss + _target_: nemo.collections.audio.losses.SDRLoss scale_invariant: true # Use scale-invariant SDR metrics: diff --git a/examples/audio_tasks/conf/predictive.yaml b/examples/audio/conf/predictive.yaml similarity index 91% rename from examples/audio_tasks/conf/predictive.yaml rename to examples/audio/conf/predictive.yaml index b141ba6fd1ee0..a4f6bfe904002 100644 --- a/examples/audio_tasks/conf/predictive.yaml +++ b/examples/audio/conf/predictive.yaml @@ -29,21 +29,21 @@ model: pin_memory: true encoder: - _target_: nemo.collections.asr.modules.audio_preprocessing.AudioToSpectrogram + _target_: nemo.collections.audio.modules.transforms.AudioToSpectrogram fft_length: 510 # Number of subbands in the STFT = fft_length // 2 + 1 = 256 hop_length: 128 magnitude_power: 0.5 scale: 0.33 decoder: - _target_: nemo.collections.asr.modules.audio_preprocessing.SpectrogramToAudio + _target_: nemo.collections.audio.modules.transforms.SpectrogramToAudio fft_length: ${model.encoder.fft_length} hop_length: ${model.encoder.hop_length} magnitude_power: ${model.encoder.magnitude_power} scale: ${model.encoder.scale} estimator: - _target_: nemo.collections.asr.parts.submodules.diffusion.SpectrogramNoiseConditionalScoreNetworkPlusPlus + _target_: nemo.collections.audio.parts.submodules.ncsnpp.SpectrogramNoiseConditionalScoreNetworkPlusPlus in_channels: 1 # single-channel noisy input out_channels: 1 # single-channel estimate num_res_blocks: 3 # increased number of res blocks @@ -51,7 +51,7 @@ model: pad_dimension_to: 0 # no padding in the frequency dimension loss: - _target_: nemo.collections.asr.losses.MSELoss # computed in the time domain + _target_: nemo.collections.audio.losses.MSELoss # computed in the time domain metrics: val: diff --git a/examples/audio_tasks/conf/score_based_generative.yaml b/examples/audio/conf/score_based_generative.yaml similarity index 90% rename from examples/audio_tasks/conf/score_based_generative.yaml rename to examples/audio/conf/score_based_generative.yaml index c0b36bd750a28..aa55b13d0963b 100644 --- a/examples/audio_tasks/conf/score_based_generative.yaml +++ b/examples/audio/conf/score_based_generative.yaml @@ -31,21 +31,21 @@ model: pin_memory: true encoder: - _target_: nemo.collections.asr.modules.audio_preprocessing.AudioToSpectrogram + _target_: nemo.collections.audio.modules.transforms.AudioToSpectrogram fft_length: 510 # Number of subbands in the STFT = fft_length // 2 + 1 = 256 hop_length: 128 magnitude_power: 0.5 scale: 0.33 decoder: - _target_: nemo.collections.asr.modules.audio_preprocessing.SpectrogramToAudio + _target_: nemo.collections.audio.modules.transforms.SpectrogramToAudio fft_length: ${model.encoder.fft_length} hop_length: ${model.encoder.hop_length} magnitude_power: ${model.encoder.magnitude_power} scale: ${model.encoder.scale} estimator: - _target_: nemo.collections.asr.parts.submodules.diffusion.SpectrogramNoiseConditionalScoreNetworkPlusPlus + _target_: nemo.collections.audio.parts.submodules.ncsnpp.SpectrogramNoiseConditionalScoreNetworkPlusPlus in_channels: 2 # concatenation of single-channel perturbed and noisy out_channels: 1 # single-channel score estimate conditioned_on_time: true @@ -54,14 +54,14 @@ model: pad_dimension_to: 0 # no padding in the frequency dimension sde: - _target_: nemo.collections.asr.parts.submodules.diffusion.OrnsteinUhlenbeckVarianceExplodingSDE + _target_: nemo.collections.audio.parts.submodules.diffusion.OrnsteinUhlenbeckVarianceExplodingSDE stiffness: 1.5 std_min: 0.05 std_max: 0.5 num_steps: 1000 sampler: - _target_: nemo.collections.asr.parts.submodules.diffusion.PredictorCorrectorSampler + _target_: nemo.collections.audio.parts.submodules.diffusion.PredictorCorrectorSampler predictor: reverse_diffusion corrector: annealed_langevin_dynamics num_steps: 50 @@ -69,7 +69,7 @@ model: snr: 0.5 loss: - _target_: nemo.collections.asr.losses.MSELoss + _target_: nemo.collections.audio.losses.MSELoss ndim: 4 # loss is calculated on the score in the encoded domain (batch, channel, dimension, time) metrics: diff --git a/examples/audio_tasks/process_audio.py b/examples/audio/process_audio.py similarity index 99% rename from examples/audio_tasks/process_audio.py rename to examples/audio/process_audio.py index e73831fe7a5f5..6cf7a8499122b 100644 --- a/examples/audio_tasks/process_audio.py +++ b/examples/audio/process_audio.py @@ -24,7 +24,7 @@ import torch from omegaconf import OmegaConf -from nemo.collections.asr.models import AudioToAudioModel +from nemo.collections.audio.models import AudioToAudioModel from nemo.core.config import hydra_runner from nemo.utils import logging, model_utils diff --git a/examples/multimodal/convert_ckpt_to_nemo.py b/examples/multimodal/convert_ckpt_to_nemo.py index 2bc0f5d7ab623..573bdc0bc0401 100644 --- a/examples/multimodal/convert_ckpt_to_nemo.py +++ b/examples/multimodal/convert_ckpt_to_nemo.py @@ -165,14 +165,6 @@ def convert(local_rank, rank, world_size, args): model = MegatronControlNet.load_from_checkpoint( checkpoint_path, hparams_file=args.hparams_file, trainer=trainer ) - elif args.model_type == 'kosmos': - model = MegatronKosmosModel.load_from_checkpoint( - checkpoint_path, hparams_file=args.hparams_file, trainer=trainer - ) - elif args.model_type == 'neva': - model = MegatronNevaModel.load_from_checkpoint( - checkpoint_path, hparams_file=args.hparams_file, trainer=trainer - ) else: raise ValueError(f"Unrecognized model_type {args.model_type}.") diff --git a/examples/multimodal/multimodal_llm/neva/conf/neva_export.yaml b/examples/multimodal/multimodal_llm/neva/conf/neva_export.yaml new file mode 100644 index 0000000000000..5a163b2505664 --- /dev/null +++ b/examples/multimodal/multimodal_llm/neva/conf/neva_export.yaml @@ -0,0 +1,15 @@ +name: nemo_neva +infer: + output_dir: ./neva + max_batch_size: 1 + tensor_parallelism: 1 + max_input_len: 4096 + max_output_len: 256 + max_multimodal_len: 3072 + +model: + type: neva + precision: bfloat16 + visual_model_path: /path/to/visual.nemo + llm_model_path: /path/to/llm.nemo + llm_model_type: llama diff --git a/examples/multimodal/multimodal_llm/neva/conf/neva_trt_infer.yaml b/examples/multimodal/multimodal_llm/neva/conf/neva_trt_infer.yaml new file mode 100644 index 0000000000000..14e6f98c06764 --- /dev/null +++ b/examples/multimodal/multimodal_llm/neva/conf/neva_trt_infer.yaml @@ -0,0 +1,12 @@ +name: nemo_neva +engine_dir: ./neva +input_media: ./test.jpg +input_text: "Hi! What is in this image?" +batch_size: 1 +infer: + top_k: 1 # The number of highest probability vocabulary tokens to keep for top-k-filtering. + top_p: 0.0 # If set to float < 1, only the most probable tokens with probabilities that add up to top_p or higher are kept for generation. + temperature: 1.0 # sampling temperature + repetition_penalty: 1.0 # The parameter for repetition penalty. 1.0 means no penalty. + num_beams: 1 + max_new_tokens: 30 diff --git a/examples/multimodal/multimodal_llm/neva/neva_export.py b/examples/multimodal/multimodal_llm/neva/neva_export.py new file mode 100644 index 0000000000000..2c081d00a003b --- /dev/null +++ b/examples/multimodal/multimodal_llm/neva/neva_export.py @@ -0,0 +1,38 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the 'License'); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an 'AS IS' BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from nemo.core.config import hydra_runner +from nemo.export.tensorrt_mm_exporter import TensorRTMMExporter + + +@hydra_runner(config_path='conf', config_name='neva_export') +def main(cfg): + exporter = TensorRTMMExporter(model_dir=cfg.infer.output_dir, load_model=False) + exporter.export( + visual_checkpoint_path=cfg.model.visual_model_path, + llm_checkpoint_path=cfg.model.llm_model_path, + model_type=cfg.model.type, + llm_model_type=cfg.model.llm_model_type, + tensor_parallel_size=cfg.infer.tensor_parallelism, + max_input_len=cfg.infer.max_input_len, + max_output_len=cfg.infer.max_output_len, + max_batch_size=cfg.infer.max_batch_size, + max_multimodal_len=cfg.infer.max_multimodal_len, + dtype=cfg.model.precision, + load_model=False, + ) + + +if __name__ == '__main__': + main() diff --git a/examples/multimodal/multimodal_llm/neva/neva_finetune.py b/examples/multimodal/multimodal_llm/neva/neva_finetune.py index 8db107134bdf1..e94308ad89f3e 100644 --- a/examples/multimodal/multimodal_llm/neva/neva_finetune.py +++ b/examples/multimodal/multimodal_llm/neva/neva_finetune.py @@ -42,6 +42,7 @@ def main(cfg) -> None: override_config_path=cfg.model, save_restore_connector=NLPSaveRestoreConnector(), strict=False, + validate_access_integrity=False if cfg.model.pipeline_model_parallel_size > 1 else True, ) trainer.fit(model) diff --git a/examples/multimodal/multimodal_llm/neva/neva_trt_run.py b/examples/multimodal/multimodal_llm/neva/neva_trt_run.py new file mode 100644 index 0000000000000..b26d4e83432fc --- /dev/null +++ b/examples/multimodal/multimodal_llm/neva/neva_trt_run.py @@ -0,0 +1,42 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the 'License'); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an 'AS IS' BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os + +from nemo.core.config import hydra_runner +from nemo.export.tensorrt_mm_exporter import TensorRTMMExporter + + +@hydra_runner(config_path='conf', config_name='neva_trt_infer') +def main(cfg): + os.environ["TOKENIZERS_PARALLELISM"] = "false" + + exporter = TensorRTMMExporter(cfg.engine_dir) + output = exporter.forward( + input_text=cfg.input_text, + input_media=cfg.input_media, + batch_size=cfg.batch_size, + max_output_len=cfg.infer.max_new_tokens, + top_k=cfg.infer.top_k, + top_p=cfg.infer.top_p, + temperature=cfg.infer.temperature, + repetition_penalty=cfg.infer.repetition_penalty, + num_beams=cfg.infer.num_beams, + ) + + print(output) + + +if __name__ == '__main__': + main() diff --git a/examples/multimodal/text_to_image/stable_diffusion/conf/sd_train.yaml b/examples/multimodal/text_to_image/stable_diffusion/conf/sd_train.yaml index dff9635908648..da03a1de96cf8 100644 --- a/examples/multimodal/text_to_image/stable_diffusion/conf/sd_train.yaml +++ b/examples/multimodal/text_to_image/stable_diffusion/conf/sd_train.yaml @@ -17,7 +17,6 @@ trainer: enable_model_summary: True limit_val_batches: 0 - exp_manager: exp_dir: null name: ${name} diff --git a/examples/multimodal/text_to_image/stable_diffusion/conf/sd_xl_base.yaml b/examples/multimodal/text_to_image/stable_diffusion/conf/sd_xl_base.yaml index c536bae15926f..7e83093eb780c 100644 --- a/examples/multimodal/text_to_image/stable_diffusion/conf/sd_xl_base.yaml +++ b/examples/multimodal/text_to_image/stable_diffusion/conf/sd_xl_base.yaml @@ -58,8 +58,6 @@ model: lossconfig: target: torch.nn.Identity - - conditioner_config: _target_: nemo.collections.multimodal.modules.stable_diffusion.encoders.modules.GeneralConditioner emb_models: diff --git a/examples/multimodal/text_to_image/stable_diffusion/conf/sd_xl_base_train.yaml b/examples/multimodal/text_to_image/stable_diffusion/conf/sd_xl_base_train.yaml index 7aa765db2e5f7..aa1d2782d15b4 100644 --- a/examples/multimodal/text_to_image/stable_diffusion/conf/sd_xl_base_train.yaml +++ b/examples/multimodal/text_to_image/stable_diffusion/conf/sd_xl_base_train.yaml @@ -125,7 +125,6 @@ model: target: torch.nn.Identity - conditioner_config: _target_: nemo.collections.multimodal.modules.stable_diffusion.encoders.modules.GeneralConditioner emb_models: diff --git a/examples/multimodal/text_to_image/stable_diffusion/conf/sd_xl_infer.yaml b/examples/multimodal/text_to_image/stable_diffusion/conf/sd_xl_infer.yaml index eb1f6d7ccb8e7..632f1634af50a 100644 --- a/examples/multimodal/text_to_image/stable_diffusion/conf/sd_xl_infer.yaml +++ b/examples/multimodal/text_to_image/stable_diffusion/conf/sd_xl_infer.yaml @@ -31,9 +31,9 @@ infer: sampling: base: sampler: EulerEDMSampler - width: 256 - height: 256 - steps: 40 + width: 512 + height: 512 + steps: 50 discretization: "LegacyDDPMDiscretization" guider: "VanillaCFG" thresholder: "None" @@ -48,8 +48,8 @@ sampling: s_noise: 1.0 eta: 1.0 order: 4 - orig_width: 1024 - orig_height: 1024 + orig_width: 512 + orig_height: 512 crop_coords_top: 0 crop_coords_left: 0 aesthetic_score: 5.0 diff --git a/examples/multimodal/text_to_image/stable_diffusion/conf/sd_xl_infer_v2.yaml b/examples/multimodal/text_to_image/stable_diffusion/conf/sd_xl_infer_v2.yaml new file mode 100644 index 0000000000000..9dc838dcc5c59 --- /dev/null +++ b/examples/multimodal/text_to_image/stable_diffusion/conf/sd_xl_infer_v2.yaml @@ -0,0 +1,189 @@ +trainer: + devices: 1 + num_nodes: 1 + accelerator: gpu + precision: 32 + logger: False # logger provided by exp_manager + enable_checkpointing: False + use_distributed_sampler: False + max_epochs: -1 # PTL default. In practice, max_steps will be reached first. + max_steps: -1 # consumed_samples = global_step * micro_batch_size * data_parallel_size * accumulate_grad_batches + log_every_n_steps: 10 + accumulate_grad_batches: 1 # do not modify, grad acc is automatic for training megatron models + gradient_clip_val: 1.0 + benchmark: False + enable_model_summary: True + limit_val_batches: 0 + + +infer: + num_samples_per_batch: 1 + num_samples: 4 + prompt: + - "A professional photograph of an astronaut riding a pig" + - 'A photo of a Shiba Inu dog with a backpack riding a bike. It is wearing sunglasses and a beach hat.' + - 'A cute corgi lives in a house made out of sushi.' + - 'A high contrast portrait of a very happy fuzzy panda dressed as a chef in a high end kitchen making dough. There is a painting of flowers on the wall behind him.' + - 'A brain riding a rocketship heading towards the moon.' + negative_prompt: "" + seed: 123 + + +sampling: + base: + sampler: EulerEDMSampler + width: 512 + height: 512 + steps: 50 + discretization: "LegacyDDPMDiscretization" + guider: "VanillaCFG" + thresholder: "None" + scale: 5.0 + img2img_strength: 1.0 + sigma_min: 0.0292 + sigma_max: 14.6146 + rho: 3.0 + s_churn: 0.0 + s_tmin: 0.0 + s_tmax: 999.0 + s_noise: 1.0 + eta: 1.0 + order: 4 + orig_width: 512 + orig_height: 512 + crop_coords_top: 0 + crop_coords_left: 0 + aesthetic_score: 5.0 + negative_aesthetic_score: 5.0 + +# model: +# is_legacy: False + +use_refiner: False +use_fp16: False # use fp16 model weights +out_path: ./output + +base_model_config: /opt/NeMo/examples/multimodal/generative/stable_diffusion/conf/sd_xl_base.yaml +refiner_config: /opt/NeMo/examples/multimodal/generative/stable_diffusion/conf/sd_xl_refiner.yaml + +model: + scale_factor: 0.13025 + disable_first_stage_autocast: True + is_legacy: False + restore_from_path: "" + + fsdp: False + fsdp_set_buffer_dtype: null + fsdp_sharding_strategy: 'full' + use_cpu_initialization: True + # hidden_size: 4 + # pipeline_model_parallel_size: 4 + + optim: + name: fused_adam + lr: 1e-4 + weight_decay: 0.0 + betas: + - 0.9 + - 0.999 + sched: + name: WarmupHoldPolicy + warmup_steps: 10 + hold_steps: 10000000000000 # Incredibly large value to hold the lr as constant + + denoiser_config: + _target_: nemo.collections.multimodal.modules.stable_diffusion.diffusionmodules.denoiser.DiscreteDenoiser + num_idx: 1000 + + weighting_config: + _target_: nemo.collections.multimodal.modules.stable_diffusion.diffusionmodules.denoiser_weighting.EpsWeighting + scaling_config: + _target_: nemo.collections.multimodal.modules.stable_diffusion.diffusionmodules.denoiser_scaling.EpsScaling + discretization_config: + _target_: nemo.collections.multimodal.modules.stable_diffusion.diffusionmodules.discretizer.LegacyDDPMDiscretization + + unet_config: + _target_: nemo.collections.multimodal.modules.stable_diffusion.diffusionmodules.openaimodel.UNetModel + from_pretrained: /opt/nemo-aligner/checkpoints/sdxl/unet_nemo.ckpt + from_NeMo: True + adm_in_channels: 2816 + num_classes: sequential + use_checkpoint: False + in_channels: 4 + out_channels: 4 + model_channels: 320 + attention_resolutions: [ 4, 2 ] + num_res_blocks: 2 + channel_mult: [ 1, 2, 4 ] + num_head_channels: 64 + use_spatial_transformer: True + use_linear_in_transformer: True + transformer_depth: [ 1, 2, 10 ] # note: the first is unused (due to attn_res starting at 2) 32, 16, 8 --> 64, 32, 16 + context_dim: 2048 + image_size: 64 # unused +# spatial_transformer_attn_type: softmax #note: only default softmax is supported now + legacy: False + use_flash_attention: False + + first_stage_config: + # _target_: nemo.collections.multimodal.models.stable_diffusion.ldm.autoencoder.AutoencoderKLInferenceWrapper + _target_: nemo.collections.multimodal.models.text_to_image.stable_diffusion.ldm.autoencoder.AutoencoderKLInferenceWrapper + from_pretrained: /opt/nemo-aligner/checkpoints/sdxl/vae_nemo.ckpt + from_NeMo: True + embed_dim: 4 + monitor: val/rec_loss + ddconfig: + attn_type: vanilla + double_z: true + z_channels: 4 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: [ 1, 2, 4, 4 ] + num_res_blocks: 2 + attn_resolutions: [ ] + dropout: 0.0 + lossconfig: + target: torch.nn.Identity + + conditioner_config: + _target_: nemo.collections.multimodal.modules.stable_diffusion.encoders.modules.GeneralConditioner + emb_models: + # crossattn cond + - is_trainable: False + input_key: txt + emb_model: + _target_: nemo.collections.multimodal.modules.stable_diffusion.encoders.modules.FrozenCLIPEmbedder + layer: hidden + layer_idx: 11 + # crossattn and vector cond + - is_trainable: False + input_key: txt + emb_model: + _target_: nemo.collections.multimodal.modules.stable_diffusion.encoders.modules.FrozenOpenCLIPEmbedder2 + arch: ViT-bigG-14 + version: laion2b_s39b_b160k + freeze: True + layer: penultimate + always_return_pooled: True + legacy: False + # vector cond + - is_trainable: False + input_key: original_size_as_tuple + emb_model: + _target_: nemo.collections.multimodal.modules.stable_diffusion.encoders.modules.ConcatTimestepEmbedderND + outdim: 256 # multiplied by two + # vector cond + - is_trainable: False + input_key: crop_coords_top_left + emb_model: + _target_: nemo.collections.multimodal.modules.stable_diffusion.encoders.modules.ConcatTimestepEmbedderND + outdim: 256 # multiplied by two + # vector cond + - is_trainable: False + input_key: target_size_as_tuple + emb_model: + _target_: nemo.collections.multimodal.modules.stable_diffusion.encoders.modules.ConcatTimestepEmbedderND + outdim: 256 # multiplied by two + diff --git a/examples/multimodal/text_to_image/stable_diffusion/sd_train.py b/examples/multimodal/text_to_image/stable_diffusion/sd_train.py index 968d9bec28842..7e151699b38c4 100644 --- a/examples/multimodal/text_to_image/stable_diffusion/sd_train.py +++ b/examples/multimodal/text_to_image/stable_diffusion/sd_train.py @@ -74,7 +74,11 @@ def main(cfg) -> None: n, c, h = cfg.model.micro_batch_size, cfg.model.channels, cfg.model.image_size x = torch.randn((n, c, h, h), dtype=torch.float32, device="cuda") t = torch.randint(77, (n,), device="cuda") - cc = torch.randn((n, 77, cfg.model.unet_config.context_dim), dtype=torch.float32, device="cuda",) + cc = torch.randn( + (n, 77, cfg.model.unet_config.context_dim), + dtype=torch.float32, + device="cuda", + ) if cfg.model.precision in [16, '16']: x = x.type(torch.float16) cc = cc.type(torch.float16) @@ -93,9 +97,7 @@ def main(cfg) -> None: model.zero_grad() if cfg.model.get('peft', None): - peft_cfg_cls = PEFT_CONFIG_MAP[cfg.model.peft.peft_scheme] - if cfg.model.peft.restore_from_path is not None: # initialize peft weights from a checkpoint instead of randomly # This is not the same as resume training because optimizer states are not restored. diff --git a/examples/multimodal/text_to_image/stable_diffusion/sd_xl_infer.py b/examples/multimodal/text_to_image/stable_diffusion/sd_xl_infer.py index 8d18be517c695..981e83ec95c40 100644 --- a/examples/multimodal/text_to_image/stable_diffusion/sd_xl_infer.py +++ b/examples/multimodal/text_to_image/stable_diffusion/sd_xl_infer.py @@ -26,32 +26,44 @@ def model_cfg_modifier(model_cfg): model_cfg.precision = cfg.trainer.precision model_cfg.ckpt_path = None model_cfg.inductor = False - model_cfg.unet_config.from_pretrained = None - model_cfg.first_stage_config.from_pretrained = None + model_cfg.unet_config.from_pretrained = "/opt/nemo-aligner/checkpoints/sdxl/unet_nemo.ckpt" + model_cfg.unet_config.from_NeMo = True + model_cfg.first_stage_config.from_pretrained = "/opt/nemo-aligner/checkpoints/sdxl/vae_nemo.ckpt" + model_cfg.first_stage_config.from_NeMo = True model_cfg.first_stage_config._target_ = 'nemo.collections.multimodal.models.text_to_image.stable_diffusion.ldm.autoencoder.AutoencoderKLInferenceWrapper' - model_cfg.fsdp = False + # model_cfg.fsdp = True torch.backends.cuda.matmul.allow_tf32 = True trainer, megatron_diffusion_model = setup_trainer_and_model_for_inference( model_provider=MegatronDiffusionEngine, cfg=cfg, model_cfg_modifier=model_cfg_modifier ) + ### Manually configure sharded model + # model = megatron_diffusion_model + # model = trainer.strategy._setup_model(model) + # model = model.cuda(torch.cuda.current_device()) + # get the diffusion part only model = megatron_diffusion_model.model model.cuda().eval() - base = SamplingPipeline(model, use_fp16=cfg.use_fp16, is_legacy=cfg.model.is_legacy) - use_refiner = cfg.get('use_refiner', False) - for i, prompt in enumerate(cfg.infer.prompt): - samples = base.text_to_image( - params=cfg.sampling.base, - prompt=[prompt], - negative_prompt=cfg.infer.negative_prompt, - samples=cfg.infer.num_samples, - return_latents=True if use_refiner else False, - seed=int(cfg.infer.seed + i * 100), - ) - - perform_save_locally(cfg.out_path, samples) + with torch.no_grad(): + base = SamplingPipeline(model, use_fp16=cfg.use_fp16, is_legacy=cfg.model.is_legacy) + use_refiner = cfg.get('use_refiner', False) + num_samples_per_batch = cfg.infer.get('num_samples_per_batch', cfg.infer.num_samples) + num_batches = cfg.infer.num_samples // num_samples_per_batch + + for i, prompt in enumerate(cfg.infer.prompt): + for batchid in range(num_batches): + samples = base.text_to_image( + params=cfg.sampling.base, + prompt=[prompt], + negative_prompt=cfg.infer.negative_prompt, + samples=num_samples_per_batch, + return_latents=True if use_refiner else False, + seed=int(cfg.infer.seed + i * 100 + batchid * 200), + ) + # samples=cfg.infer.num_samples, + perform_save_locally(cfg.out_path, samples) if __name__ == "__main__": diff --git a/examples/multimodal/text_to_image/stable_diffusion/sd_xl_train.py b/examples/multimodal/text_to_image/stable_diffusion/sd_xl_train.py index a91beca93761e..44412aee0d146 100644 --- a/examples/multimodal/text_to_image/stable_diffusion/sd_xl_train.py +++ b/examples/multimodal/text_to_image/stable_diffusion/sd_xl_train.py @@ -41,7 +41,10 @@ def _training_strategy(self) -> NLPDDPStrategy: _IS_INTERACTIVE = hasattr(sys, "ps1") or bool(sys.flags.interactive) if _IS_INTERACTIVE and self.cfg.trainer.devices == 1: logging.info("Detected interactive environment, using NLPDDPStrategyNotebook") - return NLPDDPStrategyNotebook(no_ddp_communication_hook=True, find_unused_parameters=False,) + return NLPDDPStrategyNotebook( + no_ddp_communication_hook=True, + find_unused_parameters=False, + ) if self.cfg.model.get('fsdp', False): assert ( @@ -81,9 +84,7 @@ def main(cfg) -> None: model = MegatronDiffusionEngine(cfg.model, trainer) if cfg.model.get('peft', None): - peft_cfg_cls = PEFT_CONFIG_MAP[cfg.model.peft.peft_scheme] - if cfg.model.peft.restore_from_path is not None: # initialize peft weights from a checkpoint instead of randomly # This is not the same as resume training because optimizer states are not restored. diff --git a/examples/multimodal/vision_language_foundation/clip/conf/megatron_clip_VIT-L-14.yaml b/examples/multimodal/vision_language_foundation/clip/conf/megatron_clip_VIT-L-14.yaml index d8740bb98eb25..bfee36b6c099b 100644 --- a/examples/multimodal/vision_language_foundation/clip/conf/megatron_clip_VIT-L-14.yaml +++ b/examples/multimodal/vision_language_foundation/clip/conf/megatron_clip_VIT-L-14.yaml @@ -1,3 +1,50 @@ +trainer: + devices: 1 + num_nodes: 1 + accelerator: gpu + precision: 32 + logger: False # logger provided by exp_manager + enable_checkpointing: False + use_distributed_sampler: False + max_epochs: -1 # PTL default. In practice, max_steps will be reached first. + max_steps: 375000 # consumed_samples = global_step * micro_batch_size * data_parallel_size * accumulate_grad_batches + log_every_n_steps: 10 + val_check_interval: 100 + check_val_every_n_epoch: null + limit_val_batches: 50 + limit_test_batches: 500 + accumulate_grad_batches: 1 # do not modify, grad acc is automatic for training megatron models + gradient_clip_val: 1.0 + benchmark: False + enable_model_summary: False # default PTL callback for this does not support model parallelism, instead we log manually + +exp_manager: + explicit_log_dir: null + exp_dir: null + name: megatron_clip + create_wandb_logger: False + wandb_logger_kwargs: + project: null + name: null + resume_if_exists: True + resume_ignore_no_checkpoint: True + resume_from_checkpoint: ${model.resume_from_checkpoint} + create_checkpoint_callback: True + checkpoint_callback_params: + monitor: val_loss + save_top_k: 10 + mode: min + always_save_nemo: False # saves nemo file during validation, not implemented for model parallel + save_nemo_on_train_end: False # not recommended when training large models on clusters with short time limits + filename: 'megatron_clip--{val_loss:.2f}-{step}-{consumed_samples}' + model_parallel_size: ${multiply:${model.tensor_model_parallel_size}, ${model.pipeline_model_parallel_size}} + ema: + enable: False + decay: 0.9999 + validate_original_weights: False + every_n_steps: 1 + cpu_offload: False + model: precision: 32 # specify micro_batch_size, global_batch_size, and model parallelism @@ -19,6 +66,9 @@ model: local_loss: False # calculate loss w/ local features @ global (instead of realizing full global @ global matrix) gather_with_grad: True # enable full distributed gradient for feature gather, set this to False may cause convergence issue + mcore_gpt: False + transformer_engine: False + vision: precision: 32 # vision configs @@ -135,7 +185,6 @@ model: bias_activation_fusion: False megatron_legacy: True - transformer_engine: False fp8: False # enables fp8 in TransformerLayer forward fp8_e4m3: False # sets fp8_format = recipe.Format.E4M3 fp8_hybrid: False # sets fp8_format = recipe.Format.HYBRID diff --git a/examples/multimodal/vision_language_foundation/clip/conf/megatron_clip_config.yaml b/examples/multimodal/vision_language_foundation/clip/conf/megatron_clip_config.yaml index a6b1928ef13fe..f75a163a5ed2c 100644 --- a/examples/multimodal/vision_language_foundation/clip/conf/megatron_clip_config.yaml +++ b/examples/multimodal/vision_language_foundation/clip/conf/megatron_clip_config.yaml @@ -68,6 +68,8 @@ model: # numerical results as the naïve method. local_loss: False # calculate loss w/ local features @ global (instead of realizing full global @ global matrix) gather_with_grad: True # enable full distributed gradient for feature gather, set this to False may cause convergence issue + mcore_gpt: True + transformer_engine: True vision: precision: ${trainer.precision} @@ -183,7 +185,6 @@ model: bias_activation_fusion: False megatron_legacy: False - transformer_engine: False fp8: False # enables fp8 in TransformerLayer forward fp8_e4m3: False # sets fp8_format = recipe.Format.E4M3 fp8_hybrid: False # sets fp8_format = recipe.Format.HYBRID diff --git a/examples/multimodal/vision_language_foundation/clip/conf/megatron_clip_infer.yaml b/examples/multimodal/vision_language_foundation/clip/conf/megatron_clip_infer.yaml index 215cd17841aed..3e127aa6d86a2 100755 --- a/examples/multimodal/vision_language_foundation/clip/conf/megatron_clip_infer.yaml +++ b/examples/multimodal/vision_language_foundation/clip/conf/megatron_clip_infer.yaml @@ -6,7 +6,7 @@ trainer: num_nodes: 1 accelerator: gpu logger: False # logger provided by exp_manager - precision: 16 # 16, 32, or bf16 + precision: 32 # 16, 32, or bf16 model: restore_from_path: null # Path to a trained ViT .nemo file diff --git a/examples/multimodal/vision_language_foundation/clip/conf/megatron_siglip_so400m_14_384.yaml b/examples/multimodal/vision_language_foundation/clip/conf/megatron_siglip_so400m_14_384.yaml new file mode 100644 index 0000000000000..6c5be3a2bcd68 --- /dev/null +++ b/examples/multimodal/vision_language_foundation/clip/conf/megatron_siglip_so400m_14_384.yaml @@ -0,0 +1,251 @@ +trainer: + devices: 1 + num_nodes: 1 + accelerator: gpu + precision: 32 + logger: False # logger provided by exp_manager + enable_checkpointing: False + use_distributed_sampler: False + max_epochs: -1 # PTL default. In practice, max_steps will be reached first. + max_steps: 375000 # consumed_samples = global_step * micro_batch_size * data_parallel_size * accumulate_grad_batches + log_every_n_steps: 10 + val_check_interval: 100 + check_val_every_n_epoch: null + limit_val_batches: 50 + limit_test_batches: 500 + accumulate_grad_batches: 1 # do not modify, grad acc is automatic for training megatron models + gradient_clip_val: 1.0 + benchmark: False + enable_model_summary: False # default PTL callback for this does not support model parallelism, instead we log manually + +exp_manager: + explicit_log_dir: null + exp_dir: null + name: megatron_clip + create_wandb_logger: False + wandb_logger_kwargs: + project: null + name: null + resume_if_exists: True + resume_ignore_no_checkpoint: True + resume_from_checkpoint: ${model.resume_from_checkpoint} + create_checkpoint_callback: True + checkpoint_callback_params: + monitor: val_loss + save_top_k: 10 + mode: min + always_save_nemo: False # saves nemo file during validation, not implemented for model parallel + save_nemo_on_train_end: False # not recommended when training large models on clusters with short time limits + filename: 'megatron_clip--{val_loss:.2f}-{step}-{consumed_samples}' + model_parallel_size: ${multiply:${model.tensor_model_parallel_size}, ${model.pipeline_model_parallel_size}} + ema: + enable: False + decay: 0.9999 + validate_original_weights: False + every_n_steps: 1 + cpu_offload: False + +model: + precision: 32 + # specify micro_batch_size, global_batch_size, and model parallelism + # gradient accumulation will be done automatically based on data_parallel_size + micro_batch_size: 1 # limited by GPU memory + global_batch_size: 1 # will use more micro batches to reach global batch size + tensor_model_parallel_size: 1 # intra-layer model parallelism + pipeline_model_parallel_size: 1 # inter-layer model parallelism + virtual_pipeline_model_parallel_size: null # interleaved pipeline + + restore_from_pretrained: null # used in fine-tuning + # multimodal configs + output_dim: 1152 + # As the number of devices used to train increases, so does the space complexity of + # the logit matrix. Using a naïve all-gather scheme, space complexity will be + # `O(n^2)`. Instead, complexity may become effectively linear if the flags + # `--gather-with-grad` and `--local-loss` are used. This alteration results in one-to-one + # numerical results as the naïve method. + + use_siglip: True + mcore_gpt: True + transformer_engine: True + + vision: + precision: 32 + # vision configs + patch_dim: 14 + img_h: 378 + img_w: 378 + image_mean: null + image_std: null + num_channels: 3 + drop_patch_rate: 0.0 + drop_path_rate: 0.0 + global_average_pool: False + output_dim: ${model.output_dim} + class_token_length: 0 + preprocess_layernorm: True # apply layer norm to embedded tokens + + # model architecture + encoder_seq_length: 196 + max_position_embeddings: ${.encoder_seq_length} + position_embedding_type: learned_absolute + num_layers: 27 + hidden_size: 1152 + ffn_hidden_size: 4304 # Transformer FFN hidden size. Usually 4 * hidden_size. + num_attention_heads: 16 + init_method_std: 0.02 # Standard deviation of the zero mean normal distribution used for weight initialization.') + use_scaled_init_method: True # use scaled residuals initialization + hidden_dropout: 0. # Dropout probability for hidden state transformer. + attention_dropout: 0. + kv_channels: null # Projection weights dimension in multi-head attention. Set to hidden_size // num_attention_heads if null + apply_query_key_layer_scaling: True # scale Q * K^T by 1 / layer-number. + normalization: layernorm # Type of normalization layers + layernorm_epsilon: 1e-5 + do_layer_norm_weight_decay: False # True means weight decay on all params + pre_process: True # add embedding + post_process: True # add pooler + persist_layer_norm: True # Use of persistent fused layer norm kernel. + + ## Activation Checkpointing + activations_checkpoint_granularity: null # 'selective' or 'full' + activations_checkpoint_method: null # 'uniform', 'block', not used with 'selective' + activations_checkpoint_num_layers: null # not used with 'selective' + sequence_parallel: False + + # precision + native_amp_init_scale: 4294967296 # 2 ** 32 + native_amp_growth_interval: 1000 + hysteresis: 2 # Gradient scale hysteresis + fp32_residual_connection: False # Move residual connections to fp32 + fp16_lm_cross_entropy: False # Move the cross entropy unreduced loss calculation for lm head to fp16 + + # model fusions + masked_softmax_fusion: True # Use a kernel that fuses the attention softmax with it's mask. + bias_dropout_add_fusion: True # Use a kernel that fuses the bias addition, dropout and residual connection addition. + + use_cpu_initialization: False # Init weights on the CPU (slow for large models) + onnx_safe: False # Use work-arounds for known problems with Torch ONNX exporter. + gradient_accumulation_fusion: False # Fuse weight gradient accumulation to GEMMs. Only used with pipeline parallelism. + openai_gelu: True + bias_activation_fusion: False + megatron_legacy: True + activation: approx-gelu + + + + text: + precision: 32 + # text configs + output_dim: ${model.output_dim} + + # model architecture + encoder_seq_length: 64 + max_position_embeddings: ${.encoder_seq_length} + position_embedding_type: learned_absolute + num_layers: 27 + hidden_size: 1152 + ffn_hidden_size: 4304 # Transformer FFN hidden size. Usually 4 * hidden_size. + num_attention_heads: 16 + init_method_std: 0.02 # Standard deviation of the zero mean normal distribution used for weight initialization.') + use_scaled_init_method: True # use scaled residuals initialization + hidden_dropout: 0. # Dropout probability for hidden state transformer. + attention_dropout: 0. + kv_channels: null # Projection weights dimension in multi-head attention. Set to hidden_size // num_attention_heads if null + apply_query_key_layer_scaling: True # scale Q * K^T by 1 / layer-number. + normalization: layernorm # Type of normalization layers + layernorm_epsilon: 1e-5 + do_layer_norm_weight_decay: False # True means weight decay on all params + pre_process: True # add embedding + post_process: True # add pooler + persist_layer_norm: True # Use of persistent fused layer norm kernel. + + ## Activation Checkpointing + activations_checkpoint_granularity: null # 'selective' or 'full' + activations_checkpoint_method: null # 'uniform', 'block', not used with 'selective' + activations_checkpoint_num_layers: null # not used with 'selective' + num_micro_batches_with_partial_activation_checkpoints: null + activations_checkpoint_layers_per_pipeline: null + sequence_parallel: False + + # precision + native_amp_init_scale: 4294967296 # 2 ** 32 + native_amp_growth_interval: 1000 + hysteresis: 2 # Gradient scale hysteresis + fp32_residual_connection: False # Move residual connections to fp32 + fp16_lm_cross_entropy: False # Move the cross entropy unreduced loss calculation for lm head to fp16 + + # model fusions + masked_softmax_fusion: True # Use a kernel that fuses the attention softmax with it's mask. + bias_dropout_add_fusion: True # Use a kernel that fuses the bias addition, dropout and residual connection addition. + + use_cpu_initialization: False # Init weights on the CPU (slow for large models) + onnx_safe: False # Use work-arounds for known problems with Torch ONNX exporter. + gradient_accumulation_fusion: False # Fuse weight gradient accumulation to GEMMs. Only used with pipeline parallelism. + openai_gelu: True + bias_activation_fusion: False + megatron_legacy: True + + fp8: False # enables fp8 in TransformerLayer forward + fp8_e4m3: False # sets fp8_format = recipe.Format.E4M3 + fp8_hybrid: False # sets fp8_format = recipe.Format.HYBRID + fp8_margin: 0 # scaling margin + fp8_interval: 1 # scaling update interval + fp8_amax_history_len: 1 # Number of steps for which amax history is recorded per tensor + fp8_amax_compute_algo: most_recent # 'most_recent' or 'max'. Algorithm for computing amax from history + use_emha: False # Use fused multi-head attention for large sequence-length. Note this is not yet supported. Please set to False. + activation: approx-gelu + + # Megatron O2-style half-precision + megatron_amp_O2: False # Enable O2-level automatic mixed precision using main parameters + grad_allreduce_chunk_size_mb: 125 + grad_div_ar_fusion: True # Fuse grad division into torch.distributed.all_reduce + + # miscellaneous + seed: 1234 + resume_from_checkpoint: null # manually set the checkpoint file to load from + apex_transformer_log_level: 30 # Python logging level displays logs with severity greater than or equal to this + gradient_as_bucket_view: True # PyTorch DDP argument. Allocate gradients in a contiguous bucket to save memory (less fragmentation and buffer memory) + + tokenizer: + library: 'huggingface' + type: 'google/siglip-so400m-patch14-384' + model: null + vocab_file: null + merge_file: null + delimiter: null # only used for tabular tokenizer + sentencepiece_legacy: False # Legacy=True allows you to add special tokens to sentencepiece tokenizers. + make_vocab_size_divisible_by: 128 # Pad the vocab size to be divisible by this value for computation efficiency. + + data: + num_workers: 8 + train: + dataset_path: # List of paths to pkl files or tar files + - /datasets/coyo/test.pkl + validation: # List of paths to pkl files or tar files + dataset_path: + - /datasets/coyo/test.pkl + webdataset: + infinite_sampler: False + local_root_path: /datasets/coyo + + imagenet_val: null # Path to imagenet val set for conducting zero shot evaluation. + + # Nsys profiling options + nsys_profile: + enabled: False + start_step: 10 # Global batch to start profiling + end_step: 10 # Global batch to end profiling + ranks: [ 0 ] # Global rank IDs to profile + gen_shape: False # Generate model and kernel details including input shapes + + optim: + name: fused_adam + lr: 1e-3 + weight_decay: 0.2 + betas: + - 0.9 + - 0.98 + sched: + name: CosineAnnealing + warmup_steps: 2000 + constant_steps: 0 + min_lr: 1e-5 \ No newline at end of file diff --git a/examples/multimodal/vision_language_foundation/clip/convert_external_clip_to_nemo.py b/examples/multimodal/vision_language_foundation/clip/convert_external_clip_to_nemo.py index b9b9ab917173f..9af25181d07e2 100644 --- a/examples/multimodal/vision_language_foundation/clip/convert_external_clip_to_nemo.py +++ b/examples/multimodal/vision_language_foundation/clip/convert_external_clip_to_nemo.py @@ -283,6 +283,7 @@ def convert(local_rank, rank, world_size, args): if __name__ == '__main__': + logging.warning("This script is going to be deprecated soon. Please use ") args = get_args() local_rank, rank, world_size = initialize_distributed(args) convert(local_rank, rank, world_size, args) diff --git a/examples/multimodal/vision_language_foundation/clip/megatron_clip_pretrain.py b/examples/multimodal/vision_language_foundation/clip/megatron_clip_pretrain.py index 4462649a5861b..abca470e5843b 100644 --- a/examples/multimodal/vision_language_foundation/clip/megatron_clip_pretrain.py +++ b/examples/multimodal/vision_language_foundation/clip/megatron_clip_pretrain.py @@ -22,8 +22,6 @@ from nemo.utils import logging from nemo.utils.exp_manager import exp_manager -mp.set_start_method("spawn", force=True) - @hydra_runner(config_path="conf", config_name="megatron_clip_config") def main(cfg) -> None: @@ -31,7 +29,10 @@ def main(cfg) -> None: logging.info(f'\n{OmegaConf.to_yaml(cfg)}') assert ( - cfg.trainer.devices * cfg.trainer.num_nodes + cfg.trainer.devices + * cfg.trainer.num_nodes + // cfg.model.tensor_model_parallel_size + // cfg.model.pipeline_model_parallel_size ) * cfg.model.micro_batch_size == cfg.model.global_batch_size, ( "Gradient accumulation is not supported in CLIP yet." ) diff --git a/examples/nlp/information_retrieval/conf/megatron_gpt_embedder_generate_config.yaml b/examples/nlp/information_retrieval/conf/megatron_gpt_embedder_generate_config.yaml index 1a81d21dd9a83..e407aec167e91 100644 --- a/examples/nlp/information_retrieval/conf/megatron_gpt_embedder_generate_config.yaml +++ b/examples/nlp/information_retrieval/conf/megatron_gpt_embedder_generate_config.yaml @@ -120,7 +120,6 @@ model: tunable_base_param_names: ["self_attention", "word_embeddings"] # TODO: regex support @adithyre data: - return_output_tensors: True test_ds: query_file_names: ??? # Path to a list of JSONL files corresponding to the query data. Data format is identical to validation_ds. doc_file_names: ??? # Path to a list of JSONL files corresponding to the doc data. Data format is identical to validation_ds. diff --git a/examples/nlp/information_retrieval/conf/megatron_gpt_embedder_tuning_config.yaml b/examples/nlp/information_retrieval/conf/megatron_gpt_embedder_tuning_config.yaml index 6677dc2ed46ca..1c2db1a862f4c 100644 --- a/examples/nlp/information_retrieval/conf/megatron_gpt_embedder_tuning_config.yaml +++ b/examples/nlp/information_retrieval/conf/megatron_gpt_embedder_tuning_config.yaml @@ -84,6 +84,7 @@ model: use_flash_attention: True precision: bf16 apply_rope_fusion: False + reward_model_loss: False # Set this to true to perform RLHF style reward model loss -log(sigmoid(accept_logit - reject_logit)) peft: peft_scheme: "lora" # can be either adapter,ia3, or ptuning @@ -126,7 +127,6 @@ model: tunable_base_param_names: ["self_attention", "word_embeddings"] # TODO: regex support @adithyre data: - return_output_tensors: True train_ds: # Example of how to specify paths to multiple datasets # file_names: diff --git a/examples/nlp/information_retrieval/conf/megatron_gpt_reranker_tuning_config.yaml b/examples/nlp/information_retrieval/conf/megatron_gpt_reranker_tuning_config.yaml new file mode 100644 index 0000000000000..863b5fb475a07 --- /dev/null +++ b/examples/nlp/information_retrieval/conf/megatron_gpt_reranker_tuning_config.yaml @@ -0,0 +1,222 @@ +name: megatron_gpt_peft_reranker_tuning + +trainer: + devices: 1 + accelerator: gpu + num_nodes: 1 + precision: bf16 + logger: False # logger provided by exp_manager + enable_checkpointing: False + use_distributed_sampler: False + max_epochs: null + max_steps: 20000 # consumed_samples = global_step * micro_batch_size * data_parallel_size * accumulate_grad_batches + log_every_n_steps: 10 # frequency with which training steps are logged + val_check_interval: ${trainer.max_steps} # If is an int n > 1, will run val every n training steps, if a float 0.0 - 1.0 will run val every epoch fraction, e.g. 0.25 will run val every quarter epoch + gradient_clip_val: null + num_sanity_val_steps: 0 + +exp_manager: + explicit_log_dir: null + exp_dir: null + name: ${name} + create_wandb_logger: False + wandb_logger_kwargs: + project: null + name: null + resume_if_exists: True + resume_ignore_no_checkpoint: True + create_checkpoint_callback: True + checkpoint_callback_params: + monitor: validation_${model.data.validation_ds.metric.name} + save_top_k: 1 + mode: min + save_nemo_on_train_end: True + filename: '${name}--{${exp_manager.checkpoint_callback_params.monitor}:.3f}-{step}-{consumed_samples}' + model_parallel_size: ${model.tensor_model_parallel_size} + always_save_nemo: False + save_best_model: True + create_early_stopping_callback: False + early_stopping_callback_params: + monitor: "val_loss" + mode: "min" + min_delta: 0.001 + patience: 10 + verbose: True + strict: False # Should be False to avoid a runtime error where EarlyStopping says monitor is unavailable, which sometimes happens with resumed training. + +model: + seed: 1234 + tensor_model_parallel_size: 1 # intra-layer model parallelism + pipeline_model_parallel_size: 1 # inter-layer model parallelism + + global_batch_size: 128 + micro_batch_size: 4 + restore_from_path: ??? # Path to an existing .nemo model you wish to add new tasks to or run inference with + resume_from_checkpoint: null # The path to a checkpoint file to continue the training, restores the whole state including the epoch, step, LR schedulers, apex, etc. + save_nemo_on_validation_end: False # Saves an inference ready .nemo file every time a checkpoint is saved during training. + sync_batch_comm: False + megatron_amp_O2: True + + ## Sequence Parallelism + # Makes tensor parallelism more memory efficient for LLMs (20B+) by parallelizing layer norms and dropout sequentially + # See Reducing Activation Recomputation in Large Transformer Models: https://arxiv.org/abs/2205.05198 for more details. + sequence_parallel: False + + ## Activation Checkpoint + activations_checkpoint_granularity: selective # 'selective' or 'full' + activations_checkpoint_method: uniform # 'uniform', 'block', not used with 'selective' + # 'uniform' divides the total number of transformer layers and checkpoints the input activation + # of each chunk at the specified granularity + # 'block' checkpoints the specified number of layers per pipeline stage at the specified granularity + activations_checkpoint_num_layers: null # not used with 'selective' + activations_checkpoint_layers_per_pipeline: null + gradient_as_bucket_view: False + + hidden_dropout: 0.0 + attention_dropout: 0.0 + ffn_dropout: 0.0 + temperature: 0.02 + num_soft_negatives: 0 # Number of soft negatives to use for contrastive loss,it should be max(batch_size - 1), 0 means use hard negatives only + use_all_possible_negatives: False # If True, use all possible negatives for contrastive loss, otherwise use num_soft_negatives, if num_soft_negatives is 0, use hard negatives only + post_process: False # should be False. + apply_rope_fusion: False + transformer_engine: True # required to be True for newer versions of Megatron-LM based models + mcore_gpt: True # required to be True for newer versions of Megatron-LM based models + use_flash_attention: True + precision: bf16 + + peft: + peft_scheme: "mlp_head,lora" # can be either adapter,ia3, or ptuning + restore_from_path: null + + # Used for adapter peft training + adapter_tuning: + type: 'parallel_adapter' # this should be either 'parallel_adapter' or 'linear_adapter' + adapter_dim: 32 + adapter_dropout: 0.0 + norm_position: 'pre' # This can be set to 'pre', 'post' or null, 'pre' is normally what is used. + column_init_method: 'xavier' # IGNORED if linear_adapter is used, options: xavier, zero or normal + row_init_method: 'zero' # IGNORED if linear_adapter is used, options: xavier, zero or normal + norm_type: 'mixedfusedlayernorm' # IGNORED if layer_adapter is used, options are ['layernorm', 'mixedfusedlayernorm'] + layer_selection: null # selects in which layers to add adapters, e.g. [1,12] will add adapters to layer 1 (lowest) and 12. null will apply adapters to all layers + weight_tying: False + position_embedding_strategy: null # used only when weight_tying is True + + lora_tuning: + target_modules: ['attention_qkv', 'attention_dense', 'mlp_fc1', 'mlp_fc2'] # + adapter_dim: 32 + adapter_dropout: 0.0 + column_init_method: 'xavier' # IGNORED if linear_adapter is used, options: xavier, zero or normal + row_init_method: 'zero' # IGNORED if linear_adapter is used, options: xavier, zero or normal + layer_selection: null # selects in which layers to add lora adapters. e.g. [1,12] will add lora to layer 1 (lowest) and 12. null will apply adapters to all layers + weight_tying: False + position_embedding_strategy: null # used only when weight_tying is True + + # Used for p-tuning peft training + p_tuning: + virtual_tokens: 10 # The number of virtual tokens the prompt encoder should add at the start of the sequence + bottleneck_dim: 1024 # the size of the prompt encoder mlp bottleneck + embedding_dim: 1024 # the size of the prompt encoder embeddings + init_std: 0.023 + + # Instead of using the GPT LM Head, we can use a custom head for the reranking task + mlp_head_tuning: + out_features: 1 + + ia3_tuning: + layer_selection: null # selects in which layers to add ia3 adapters. e.g. [1,12] will add lora to layer 1 (lowest) and 12. null will apply adapters to all layers + + selective_tuning: + tunable_base_param_names: ["self_attention", "word_embeddings"] # TODO: regex support @adithyre + + data: + train_ds: + # Example of how to specify paths to multiple datasets + # file_names: + # - /path/to/squad.jsonl + # - /path/to/mnli.jsonl + # - /path/to/boolq.jsonl + # Example of how each dataset is formatted + # {'input': 'John von Neumann\nVon Neumann made fundamental contributions .... Q: What did the math of artificial viscosity do?', 'output': 'smoothed the shock transition without sacrificing basic physics'} + file_names: ??? # Path to a list of JSONL files corresponding to the source data. + global_batch_size: ${model.global_batch_size} + micro_batch_size: ${model.micro_batch_size} + shuffle: True + num_workers: 0 + memmap_workers: 2 + pin_memory: True + max_seq_length: 512 # Even if the base model can handle longer sequences, 512 is generally a good choice for training efficiency. + min_seq_length: 1 + drop_last: True + # Example of how to specify concat_sampling_probabilities + # concat_sampling_probabilities: + # - 0.5 + # - 0.25 + # - 0.25 + concat_sampling_probabilities: + - 1.0 + label_key: 'output' + add_eos: True + add_bos: False + index_mapping_dir: null # Path to a directory to write index mapping files. + truncation_method: 'right' # Truncation from which position, Options: ['left', 'right'] + validation_ds: + file_names: ??? # Path to a list of JSONL files corresponding to the source data. Data format is identical to train_ds. + names: ["validation"] # Names of the corresponding datasets used to log metrics. + global_batch_size: ${model.global_batch_size} + micro_batch_size: ${model.micro_batch_size} + shuffle: False + num_workers: 0 + memmap_workers: ${model.data.train_ds.memmap_workers} + pin_memory: True + max_seq_length: ${model.data.train_ds.max_seq_length} + min_seq_length: 1 + drop_last: False + label_key: ${model.data.train_ds.label_key} + add_eos: ${model.data.train_ds.add_eos} + add_bos: ${model.data.train_ds.add_bos} + write_embeddings_to_file: False + output_file_path_prefix: "validation_rankings" # Prefix of the file to write predictions to. + index_mapping_dir: null # Path to a directory to write index mapping files. + truncation_method: 'right' # Truncation from which position, Options: ['left', 'right'] + metric: + name: "loss" # Name of the evaluation metric to use. Options: ['exact_string_match', 'loss'] + average: null # Average the metric over the dataset. Options: ['macro', 'micro']. Works only for 'F1', 'accuracy' etc. Refer to torchmetrics for metrics where this is supported. + num_classes: null + test_ds: + file_names: null # Path to a list of JSONL files corresponding to the source data. Data format is identical to train_ds. + names: null # Names of the corresponding datasets used to log metrics. + global_batch_size: ${model.global_batch_size} + micro_batch_size: ${model.micro_batch_size} + shuffle: False + num_workers: 0 + memmap_workers: ${model.data.train_ds.memmap_workers} + pin_memory: True + max_seq_length: ${model.data.train_ds.max_seq_length} + min_seq_length: 1 + drop_last: False + add_eos: ${model.data.train_ds.add_eos} + add_bos: ${model.data.train_ds.add_bos} + write_predictions_to_file: True + output_file_path_prefix: "test_embeddings" # Prefix of the file to write predictions to. + index_mapping_dir: null # Path to a directory to write index mapping files. + truncation_method: 'right' # Truncation from which position, Options: ['left', 'right'] + metric: + name: "loss" # Name of the evaluation metric to use. Options: ['exact_string_match', 'loss'] + average: null # Average the metric over the dataset. Options: ['macro', 'micro']. Works only for 'F1', 'accuracy' etc. Refer to torchmetrics for metrics where this is supported. + num_classes: null + + optim: + name: fused_adam + lr: 1e-4 + weight_decay: 0.01 + betas: + - 0.9 + - 0.98 + sched: + name: CosineAnnealing + warmup_steps: 50 + min_lr: 0.0 # min_lr must be 0.0 for prompt learning when pipeline parallel > 1 + constant_steps: 0 # Constant steps should also be 0 when min_lr=0 + monitor: val_loss + reduce_on_plateau: false \ No newline at end of file diff --git a/examples/nlp/information_retrieval/megatron_gpt_embedding_generate.py b/examples/nlp/information_retrieval/megatron_gpt_embedding_generate.py index 8cddcebbab62b..d66ddb3397735 100644 --- a/examples/nlp/information_retrieval/megatron_gpt_embedding_generate.py +++ b/examples/nlp/information_retrieval/megatron_gpt_embedding_generate.py @@ -68,7 +68,9 @@ def use_inference_server(cfg, model, trainer): web_ui = get_demo loop = asyncio.new_event_loop() thread = threading.Thread( - target=web_ui, daemon=True, args=(cfg.share, cfg.username, cfg.password, cfg.port, cfg.web_port, loop), + target=web_ui, + daemon=True, + args=(cfg.share, cfg.username, cfg.password, cfg.port, cfg.web_port, loop), ) thread.start() server = MegatronServer(model.cuda()) @@ -93,7 +95,6 @@ def main(cfg) -> None: model_cfg = MegatronGPTEmbeddingModel.merge_inference_cfg(cfg.model.restore_from_path, cfg) with open_dict(model_cfg): - model_cfg.data.return_output_tensors = True model_cfg.post_process = False model = MegatronGPTEmbeddingModel.restore_from(cfg.model.restore_from_path, model_cfg, trainer=trainer) diff --git a/examples/nlp/information_retrieval/megatron_gpt_reranker_finetuning.py b/examples/nlp/information_retrieval/megatron_gpt_reranker_finetuning.py new file mode 100644 index 0000000000000..cf65840bb843f --- /dev/null +++ b/examples/nlp/information_retrieval/megatron_gpt_reranker_finetuning.py @@ -0,0 +1,76 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from collections.abc import MutableMapping + +import torch.multiprocessing as mp +from omegaconf.omegaconf import OmegaConf +from pytorch_lightning.loggers import WandbLogger + +from nemo.collections.nlp.models.information_retrieval.megatron_gpt_reranker_model import MegatronGPTRerankerModel +from nemo.collections.nlp.parts.megatron_trainer_builder import MegatronLMPPTrainerBuilder +from nemo.collections.nlp.parts.peft_config import PEFT_CONFIG_MAP +from nemo.core.config import hydra_runner +from nemo.utils import logging +from nemo.utils.exp_manager import exp_manager + +mp.set_start_method("spawn", force=True) + + +def flatten_dict(d: MutableMapping, parent_key: str = '', sep: str = '.') -> MutableMapping: + items = [] + for k, v in d.items(): + new_key = parent_key + sep + k if parent_key else k + if isinstance(v, MutableMapping): + items.extend(flatten_dict(v, new_key, sep=sep).items()) + else: + items.append((new_key, v)) + return dict(items) + + +@hydra_runner(config_path="conf", config_name="megatron_gpt_reranker_tuning_config") +def main(cfg) -> None: + logging.info("\n\n************** Experiment configuration ***********") + logging.info(f'\n{OmegaConf.to_yaml(cfg)}') + + trainer = MegatronLMPPTrainerBuilder(cfg).create_trainer() + exp_manager(trainer, cfg.exp_manager) + + model_cfg = MegatronGPTRerankerModel.merge_cfg_with(cfg.model.restore_from_path, cfg) + if trainer.global_rank == 0: + for logger in trainer.loggers: + if isinstance(logger, WandbLogger): + fd = flatten_dict(dict(model_cfg), sep="/") + logger.experiment.config.update(fd) + model = MegatronGPTRerankerModel.restore_from(cfg.model.restore_from_path, model_cfg, trainer=trainer) + peft_cfg_cls_lst = [PEFT_CONFIG_MAP[s] for s in cfg.model.peft.peft_scheme.split(",")] + peft_cfg_cls = [_peft_cfg(model_cfg) for _peft_cfg in peft_cfg_cls_lst] + + if cfg.model.peft.restore_from_path is not None: + # initialize peft weights from a checkpoint instead of randomly + # This is not the same as resume training because optimizer states are not restored. + logging.info("PEFT Weights will be loaded from", cfg.model.peft.restore_from_path) + model.load_adapters(cfg.model.peft.restore_from_path, peft_cfg_cls) + elif peft_cfg_cls is not None: + logging.info("Adding adapter weights to the model for PEFT") + # model.add_adapter(peft_cfg_cls(model_cfg)) + model.add_adapter(peft_cfg_cls) + else: + logging.info(f"Running full finetuning since no peft scheme is given.\n{model.summarize()}") + + trainer.fit(model) + + +if __name__ == '__main__': + main() diff --git a/examples/nlp/information_retrieval/megatron_gpt_reranker_generate.py b/examples/nlp/information_retrieval/megatron_gpt_reranker_generate.py new file mode 100644 index 0000000000000..a91449c3deda4 --- /dev/null +++ b/examples/nlp/information_retrieval/megatron_gpt_reranker_generate.py @@ -0,0 +1,138 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import asyncio +import os +import threading +from functools import partial + +import torch +import torch.multiprocessing as mp +from omegaconf.omegaconf import OmegaConf, open_dict + +from nemo.collections.nlp.models.information_retrieval.megatron_gpt_reranker_model import MegatronGPTRerankerModel +from nemo.collections.nlp.modules.common.text_generation_server import MegatronServer +from nemo.collections.nlp.modules.common.text_generation_utils import generate +from nemo.collections.nlp.parts.megatron_trainer_builder import MegatronLMPPTrainerBuilder +from nemo.collections.nlp.parts.peft_config import PEFT_CONFIG_MAP +from nemo.core.config import hydra_runner +from nemo.utils import logging +from nemo.utils.model_utils import inject_model_parallel_rank + +try: + from megatron.core import parallel_state + + HAVE_MEGATRON_CORE = True +except (ImportError, ModuleNotFoundError): + + HAVE_MEGATRON_CORE = False + +mp.set_start_method("spawn", force=True) + + +def use_inference_server(cfg, model, trainer): + if not HAVE_MEGATRON_CORE: + raise ValueError('Megatron-core needs to be installed to use this feature!') + + from nemo.collections.nlp.modules.common.megatron_web_server import get_chatbot_demo, get_demo + + trainer.test(model, dataloaders=None) + + if parallel_state.is_pipeline_first_stage() and parallel_state.get_tensor_model_parallel_rank() == 0: + if cfg.web_server: + if cfg.chat: + defaults = { + 'user': cfg.chatbot_config.user, + 'assistant': cfg.chatbot_config.assistant, + 'system': cfg.chatbot_config.system, + } + web_ui = partial( + get_chatbot_demo, + defaults=defaults, + value=cfg.chatbot_config.value, + attributes=cfg.chatbot_config.attributes, + ) + else: + web_ui = get_demo + loop = asyncio.new_event_loop() + thread = threading.Thread( + target=web_ui, + daemon=True, + args=(cfg.share, cfg.username, cfg.password, cfg.port, cfg.web_port, loop), + ) + thread.start() + server = MegatronServer(model.cuda()) + server.run("0.0.0.0", port=cfg.port) + + while True: + choice = torch.cuda.LongTensor(1) + torch.distributed.broadcast(choice, 0) + if choice[0].item() == 0: + generate(model.cuda()) + + +@hydra_runner(config_path="conf", config_name="megatron_gpt_reranker_generate_config") +def main(cfg) -> None: + logging.info("\n\n************** Experiment configuration ***********") + logging.info(f"\n{OmegaConf.to_yaml(cfg)}") + trainer = MegatronLMPPTrainerBuilder(cfg).create_trainer() + + if cfg.model.peft.restore_from_path: + model_cfg = MegatronGPTRerankerModel.merge_inference_cfg(cfg.model.peft.restore_from_path, cfg) + else: + model_cfg = MegatronGPTRerankerModel.merge_inference_cfg(cfg.model.restore_from_path, cfg) + + with open_dict(model_cfg): + model_cfg.post_process = False + + model = MegatronGPTRerankerModel.restore_from(cfg.model.restore_from_path, model_cfg, trainer=trainer) + + if cfg.model.peft.restore_from_path: + model.load_adapters(cfg.model.peft.restore_from_path) + elif cfg.model.peft.restore_from_ckpt.checkpoint_dir and cfg.model.peft.restore_from_ckpt.checkpoint_name: + peft_cfg_cls_lst = [PEFT_CONFIG_MAP[s] for s in cfg.model.peft.peft_scheme.split(",")] + peft_cfg_cls = [_peft_cfg(model_cfg) for _peft_cfg in peft_cfg_cls_lst] + + checkpoint_path = os.path.join( + cfg.model.peft.restore_from_ckpt.checkpoint_dir, cfg.model.peft.restore_from_ckpt.checkpoint_name + ) + # checkpoint_path is a dir in case of distributed checkpointing + if not os.path.isdir(checkpoint_path): + # legacy checkpoint needs model parallel rank injection + checkpoint_path = inject_model_parallel_rank( + os.path.join( + cfg.model.peft.restore_from_ckpt.checkpoint_dir, cfg.model.peft.restore_from_ckpt.checkpoint_name + ) + ) + model.load_adapters(checkpoint_path, peft_cfgs=peft_cfg_cls) + else: + raise NotImplementedError("distributed checkpointing of PEFT weights is not supported") + + model.freeze() + logging.info(f"Freezing parameters for PEFT eval:\n{model.summarize()}") + + if not cfg.model.get('use_flash_attention', False): + cfg.inference.compute_attention_mask = True + config = OmegaConf.to_container(cfg.inference, resolve=True) + model.set_inference_config(config) + + if not cfg.server: + trainer.test(model) + else: + use_inference_server(cfg, model, trainer) + + +if __name__ == "__main__": + main() diff --git a/examples/nlp/language_modeling/conf/megatron_gpt_config.yaml b/examples/nlp/language_modeling/conf/megatron_gpt_config.yaml index ccdddcbc22724..ac1f4a37b2322 100755 --- a/examples/nlp/language_modeling/conf/megatron_gpt_config.yaml +++ b/examples/nlp/language_modeling/conf/megatron_gpt_config.yaml @@ -3,7 +3,6 @@ defaults: - optional tp_overlap@model.ub_tp_comm_overlap_cfg: name: megatron_gpt -restore_from_path: null # used when starting from a .nemo file trainer: devices: 1 @@ -66,6 +65,10 @@ exp_manager: async_save: False # Set to True to enable async checkpoint save. Currently works only with distributed checkpoints model: + # The following two settings are used for continual training: + restore_from_path: null # Set this to a .nemo file path to restore only the model weights + restore_from_ckpt: null # Set this to a training ckpt path to restore both model weights and optimizer states + # use GPTModel from megatron.core mcore_gpt: True @@ -115,6 +118,14 @@ model: seq_len_interpolation_factor: null # RoPE Interpolation factor for sequence length. This is used to build long-context models with RoPE ex: https://arxiv.org/abs/2306.15595. num_query_groups: null # Number of query groups for group query attention. If None, normal attention is used. + ## Reset learning rate schedule. + # 1. reset_lr=True, reset_lr_steps=False. When pre-training an existing checkpoint "from scratch" on a different dataset. + # 2. reset_lr=True, reset_lr_steps=True. When continuing training from an existing checkpoint with the same configuration. + # Learning rate's max_steps and decay_steps will be recalculated as follows: max_steps -= completed_steps, decay_steps -= completed_steps where completed_steps is the number of steps already completed at the checkpoint. + # This will help to reach the min_lr value by the end of training without changing trainer.max_steps. + reset_lr: False # Set to True to reset learning rate to initial learning rate. Only supported with distributed optmizer and megatron_amp_O2. + reset_lr_steps: False # Set to True to adjust learning rate's max_steps and decay_steps by subtracting number of steps already completed at the checkpoint. + tokenizer: library: 'megatron' type: 'GPT2BPETokenizer' @@ -166,6 +177,10 @@ model: dist_ckpt_format: 'zarr' # Set to 'torch_dist' to use PyTorch distributed checkpoint format. dist_ckpt_load_on_device: True # whether to load checkpoint weights directly on GPU or to CPU dist_ckpt_parallel_save: False # if true, each worker will write its own part of the dist checkpoint + dist_ckpt_parallel_load: False # if true, each worker will load part of the dist checkpoint and exchange with NCCL. Might use some extra GPU memory + dist_ckpt_torch_dist_multiproc: 2 # number of extra processes per rank used during ckpt save with PyTorch distributed format + dist_ckpt_assume_constant_structure: False # set to True only if the state dict structure doesn't change within a single job. Allows caching some computation across checkpoint saves. + dist_ckpt_parallel_dist_opt: True # parallel save/load of a DistributedOptimizer. 'True' allows performant save and reshardable checkpoints. Set to 'False' only in order to minimize the number of checkpoint files. ## Activation Checkpointing # NeMo Megatron supports 'selective' activation checkpointing where only the memory intensive part of attention is checkpointed. diff --git a/examples/nlp/language_modeling/conf/megatron_gpt_inference.yaml b/examples/nlp/language_modeling/conf/megatron_gpt_inference.yaml index 2570251bcdee2..056f9eb9c6ecf 100644 --- a/examples/nlp/language_modeling/conf/megatron_gpt_inference.yaml +++ b/examples/nlp/language_modeling/conf/megatron_gpt_inference.yaml @@ -1,3 +1,4 @@ +# NOTE : This config and megatron_gpt_eval.py will be deprecated soon. Use megatron_gpt_inference_batch_mcore.yaml inference: greedy: False # Whether or not to use sampling ; use greedy decoding otherwise top_k: 0 # The number of highest probability vocabulary tokens to keep for top-k-filtering. @@ -31,6 +32,7 @@ hparams_file: null # model configuration file, only used for PTL checkpoint load prompts: # prompts for GPT inference - "Q: How are you?" - "Q: How big is the universe?" +prompts_jsonl: null server: False # whether launch the API server port: 5555 # the port number for the inference server web_server: False # whether launch the web inference server diff --git a/examples/nlp/language_modeling/conf/megatron_gpt_inference_batch_mcore.yaml b/examples/nlp/language_modeling/conf/megatron_gpt_inference_batch_mcore.yaml new file mode 100644 index 0000000000000..1b34a8b5abc3d --- /dev/null +++ b/examples/nlp/language_modeling/conf/megatron_gpt_inference_batch_mcore.yaml @@ -0,0 +1,29 @@ +common_inference_params: + top_k: 1 # The number of highest probability vocabulary tokens to keep for top-k-filtering. + top_p: 0.0 # If set to float < 1, only the most probable tokens with probabilities that add up to top_p or higher are kept for generation. + temperature: 1.0 # sampling temperature + tokens_to_generate: 30 # The minimum length of the sequence to be generated. + return_log_probs: False # whether return the log prob for the sampled tokens + repetition_penalty: 1.2 # The parameter for repetition penalty. 1.0 means no penalty. + +trainer: + devices: 1 + num_nodes: 1 + accelerator: gpu + logger: False # logger provided by exp_manager + precision: 16 # 16, 32, or bf16 + use_distributed_sampler: False + +tensor_model_parallel_size: -1 +pipeline_model_parallel_size: -1 +inference_batch_times_seq_len_threshold: 1000 # If batch_size * sequence-length is smaller than this threshold we will not use pipelining, otherwise we will. +max_batch_size: 4 # Input prompts are batched using max_batch_size and sent to inference + +megatron_amp_O2: False # Enable O2-level automatic mixed precision to save memory +gpt_model_file: null # GPT nemo file path +checkpoint_dir: null # checkpoint file dir. This is used to load the PTL checkpoint generated during the GPT training +checkpoint_name: null # PTL checkpoint file name, only used for PTL checkpoint loading +hparams_file: null # model configuration file, only used for PTL checkpoint loading +prompts: # prompts for GPT inference + - "Q: How are you?" + - "Q: How big is the universe?" diff --git a/examples/nlp/language_modeling/conf/megatron_gpt_quantization.yaml b/examples/nlp/language_modeling/conf/megatron_gpt_ptq.yaml similarity index 93% rename from examples/nlp/language_modeling/conf/megatron_gpt_quantization.yaml rename to examples/nlp/language_modeling/conf/megatron_gpt_ptq.yaml index d93331439d82f..c70719f51210a 100644 --- a/examples/nlp/language_modeling/conf/megatron_gpt_quantization.yaml +++ b/examples/nlp/language_modeling/conf/megatron_gpt_ptq.yaml @@ -36,6 +36,7 @@ quantization: num_calib_size: 512 # number of samples used for calibration awq_block_size: 128 # block size for scaling factors (only used in AWQ algorithms) sq_alpha: 1.0 # alpha parameter (only used in SmoothQuant algorithms) + enable_kv_cache: null # Enable FP8 KV cache quantization. Set to null for automatic selection. export: decoder_type: llama # gptnext, gpt2, llama @@ -43,3 +44,4 @@ export: inference_pipeline_parallel: 1 # Default using 1 PP for inference dtype: ${trainer.precision} # Default precision data type save_path: llama2-7b-${quantization.algorithm}.qnemo # Path where the quantized model will be saved + compress: false # Wheter save_path should be a tarball or a directory diff --git a/examples/nlp/language_modeling/conf/megatron_mamba_config.yaml b/examples/nlp/language_modeling/conf/megatron_mamba_config.yaml new file mode 100644 index 0000000000000..f4f37d7c4ce08 --- /dev/null +++ b/examples/nlp/language_modeling/conf/megatron_mamba_config.yaml @@ -0,0 +1,191 @@ +name: megatron_mamba +restore_from_path: null # used when starting from a .nemo file + +trainer: + devices: 1 + num_nodes: 1 + accelerator: gpu + precision: bf16 + logger: False # logger provided by exp_manager + enable_checkpointing: False + use_distributed_sampler: False + max_epochs: -1 # PTL default. In practice we don't usually train for more than 1 epoch. + max_steps: 100000 # consumed_samples = global_step * micro_batch_size * data_parallel_size * accumulate_grad_batches + log_every_n_steps: 10 + val_check_interval: 100 + limit_val_batches: 50 + limit_test_batches: 500 + accumulate_grad_batches: 1 + gradient_clip_val: 1.0 + benchmark: False + +exp_manager: + explicit_log_dir: null + exp_dir: null + name: megatron_mamba + create_wandb_logger: False + wandb_logger_kwargs: + project: null + name: null + resume_if_exists: True + resume_ignore_no_checkpoint: True + create_checkpoint_callback: True + checkpoint_callback_params: + monitor: val_loss + save_top_k: 10 + mode: min + always_save_nemo: False # saves nemo file during validation, not implemented for model parallel + filename: 'megatron_mamba--{val_loss:.2f}-{step}-{consumed_samples}' + model_parallel_size: ${multiply:${model.tensor_model_parallel_size}, ${model.pipeline_model_parallel_size}} + + +model: + restore_from_path: null + # model parallelism + mcore_gpt: True + micro_batch_size: 1 + global_batch_size: 8 + tensor_model_parallel_size: 1 + pipeline_model_parallel_size: 1 + virtual_pipeline_model_parallel_size: null + expert_model_parallel_size: 1 # expert model parallelism + hybrid_override_pattern: null + vocab_size: 256000 + # model architecture + encoder_seq_length: 4096 + max_position_embeddings: ${.encoder_seq_length} + position_embedding_type: 'none' # Position embedding type. Options ['learned_absolute', 'rope', 'alibi', 'kerple' , 'xpos', 'sandwich'] xpos and sandwich are experimental. + num_layers: 56 + gated_linear_unit: False + add_bias_linear: False + num_query_groups: 8 + mamba_ssm_ngroups: 8 + attention_dropout: 0.0 + hidden_dropout: 0.0 + hidden_size: 4096 + ffn_hidden_size: 14336 # Transformer FFN hidden size. Usually 4 * hidden_size. + num_attention_heads: 32 + transformer_block_type: pre_ln + init_method_std: 0.02 # Standard deviation of the zero mean normal distribution used for weight initialization.') + kv_channels: null # Projection weights dimension in multi-head attention. Set to hidden_size // num_attention_heads if null + apply_query_key_layer_scaling: True # scale Q * K^T by 1 / layer-number. + normalization: RMSNorm + layernorm_epsilon: 1e-5 + num_moe_experts: 16 + moe_router_topk: 2 + moe_aux_loss_coeff: 0.001 + make_vocab_size_divisible_by: 128 # Pad the vocab size to be divisible by this value for computation efficiency. + pre_process: True # add embedding + post_process: True # add pooler + megatron_legacy: False + persist_layer_norm: True + + tokenizer: + library: 'huggingface' + type: 'EleutherAI/gpt-neox-20b' + model: null + vocab_file: null + merge_file: null + sentencepiece_legacy: False + use_fast: True + + # Distributed checkpoint setup + dist_ckpt_format: 'zarr' # Set to 'torch_dist' to use PyTorch distributed checkpoint format. + dist_ckpt_load_on_device: True # whether to load checkpoint weights directly on GPU or to CPU + dist_ckpt_parallel_save: False # if true, each worker will write its own part of the dist checkpoint + + # precision + native_amp_init_scale: 4294967296 # 2 ** 32 + native_amp_growth_interval: 1000 + fp32_residual_connection: False # Move residual connections to fp32 + fp16_lm_cross_entropy: False # Move the cross entropy unreduced loss calculation for lm head to fp16 + + # Megatron O2-style half-precision + megatron_amp_O2: False # Enable O2-level automatic mixed precision using main parameters + grad_allreduce_chunk_size_mb: 125 + + + # Fusion + grad_div_ar_fusion: True # Fuse grad division into torch.distributed.all_reduce. Only used with O2 and no pipeline parallelism.. + gradient_accumulation_fusion: False # Fuse weight gradient accumulation to GEMMs. Only used with pipeline parallelism and O2. + bias_activation_fusion: False # Use a kernel that fuses the bias addition from weight matrices with the subsequent activation function. + bias_dropout_add_fusion: True # Use a kernel that fuses the bias addition, dropout and residual connection addition. + masked_softmax_fusion: True # Use a kernel that fuses the attention softmax with it's mask. + get_attention_mask_from_fusion: True # When using fused softmax it will create the attention mask so we won't copy it to the pipeline stages. + apply_rope_fusion: True # Use a kernel to add rotary positional embeddings. Only used if position_embedding_type=rope + + + # miscellaneous + seed: 1234 + use_cpu_initialization: False # Init weights on the CPU (slow for large models) + onnx_safe: False # Use work-arounds for known problems with Torch ONNX exporter. + gradient_as_bucket_view: True # PyTorch DDP argument. Allocate gradients in a contiguous bucket to save memory (less fragmentation and buffer memory) + + ## Activation Checkpointing + # NeMo Megatron supports 'selective' activation checkpointing where only the memory intensive part of attention is checkpointed. + # These memory intensive activations are also less compute intensive which makes activation checkpointing more efficient for LLMs (20B+). + # See Reducing Activation Recomputation in Large Transformer Models: https://arxiv.org/abs/2205.05198 for more details. + # 'full' will checkpoint the entire transformer layer. + activations_checkpoint_granularity: null # 'selective' or 'full' + activations_checkpoint_recurrent: False # If set to True, the checkpointing is only done for rglru and conv1d and not for attention and mlp layers + activations_checkpoint_method: null # 'uniform', 'block' + # 'uniform' divides the total number of transformer layers and checkpoints the input activation + # of each chunk at the specified granularity. When used with 'selective', 'uniform' checkpoints all attention blocks in the model. + # 'block' checkpoints the specified number of layers per pipeline stage at the specified granularity + activations_checkpoint_num_layers: null + # when using 'uniform' this creates groups of transformer layers to checkpoint. Usually set to 1. Increase to save more memory. + # when using 'block' this this will checkpoint the first activations_checkpoint_num_layers per pipeline stage. + num_micro_batches_with_partial_activation_checkpoints: null + # This feature is valid only when used with pipeline-model-parallelism. + # When an integer value is provided, it sets the number of micro-batches where only a partial number of Transformer layers get checkpointed + # and recomputed within a window of micro-batches. The rest of micro-batches in the window checkpoint all Transformer layers. The size of window is + # set by the maximum outstanding micro-batch backpropagations, which varies at different pipeline stages. The number of partial layers to checkpoint + # per micro-batch is set by 'activations_checkpoint_num_layers' with 'activations_checkpoint_method' of 'block'. + # This feature enables using activation checkpoint at a fraction of micro-batches up to the point of full GPU memory usage. + activations_checkpoint_layers_per_pipeline: null + # This feature is valid only when used with pipeline-model-parallelism. + # When an integer value (rounded down when float is given) is provided, it sets the number of Transformer layers to skip checkpointing at later + # pipeline stages. For example, 'activations_checkpoint_layers_per_pipeline' of 3 makes pipeline stage 1 to checkpoint 3 layers less than + # stage 0 and stage 2 to checkpoint 6 layers less stage 0, and so on. This is possible because later pipeline stage + # uses less GPU memory with fewer outstanding micro-batch backpropagations. Used with 'num_micro_batches_with_partial_activation_checkpoints', + # this feature removes most of activation checkpoints at the last pipeline stage, which is the critical execution path. + sequence_parallel: False + + data: + # Path to data must be specified by the user. + # can override from the CLI: "model.data.data_prefix=[.5,/raid/data/pile/my-gpt3_00_text_document,.5,/raid/data/pile/my-gpt3_01_text_document]", + # Or see example below: + # data_prefix: + # - .5 + # - /raid/data/pile/my-gpt3_00_text_document + # - .5 + # - /raid/data/pile/my-gpt3_01_text_document + data_prefix: [1.0, /path/to/data] + index_mapping_dir: null # path to save index mapping .npy files, by default will save in the same location as data_prefix + data_impl: mmap + splits_string: 900,50,50 + seq_length: ${model.encoder_seq_length} + skip_warmup: True + num_workers: 0 + dataloader_type: single # cyclic, LDDL + reset_position_ids: False # Reset position ids after end-of-document token + reset_attention_mask: False # Reset attention mask after end-of-document token + eod_mask_loss: False # Mask loss for the end of document tokens + masked_lm_prob: 0.15 # Probability of replacing a token with mask. + short_seq_prob: 0.1 # Probability of producing a short sequence. + ceil_to_power_2: True + get_attention_mask_from_fusion: True + pad_to_max_length: True + + optim: + name: distributed_fused_adam + lr: 2e-4 + weight_decay: 0.01 + betas: + - 0.9 + - 0.98 + sched: + name: CosineAnnealing + warmup_steps: 500 + constant_steps: 50000 + min_lr: 2e-5 diff --git a/examples/nlp/language_modeling/conf/megatron_t5_config.yaml b/examples/nlp/language_modeling/conf/megatron_t5_config.yaml index e51cfff420a37..439a0f1533bd0 100644 --- a/examples/nlp/language_modeling/conf/megatron_t5_config.yaml +++ b/examples/nlp/language_modeling/conf/megatron_t5_config.yaml @@ -43,6 +43,10 @@ exp_manager: model_parallel_size: ${multiply:${model.tensor_model_parallel_size}, ${model.pipeline_model_parallel_size}} model: + # use T5 model from megatron.core + mcore_t5: False + transformer_engine: False + # model parallelism micro_batch_size: 4 global_batch_size: 8 # will use more micro batches to reach global batch size diff --git a/examples/nlp/language_modeling/mamba_change_num_partition.py b/examples/nlp/language_modeling/mamba_change_num_partition.py new file mode 100644 index 0000000000000..bc76b3215a741 --- /dev/null +++ b/examples/nlp/language_modeling/mamba_change_num_partition.py @@ -0,0 +1,696 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import re +import tarfile +import tempfile +from argparse import ArgumentParser + +import torch +from omegaconf import open_dict +from pytorch_lightning import Trainer + +from nemo.collections.nlp.models.language_modeling.megatron_mamba_model import MegatronMambaModel +from nemo.collections.nlp.parts.nlp_overrides import ( + NEMO_MEGATRON_MODEL_PARALLEL_APPSTATE_OVERRIDE, + GradScaler, + MegatronHalfPrecisionPlugin, + NLPDDPStrategy, + NLPSaveRestoreConnector, + PipelineMixedPrecisionPlugin, +) +from nemo.utils import logging +from nemo.utils.app_state import AppState + +""" +Usage: + +### Tensor Parallelism conversion ### + +# Megatron Mamba +python /opt/NeMo/examples/nlp/language_modeling/mamba_change_num_partition.py \ + --model_file= \ + --target_file= \ + --tensor_model_parallel_size=1 \ + --target_tensor_model_parallel_size=4 \ + --precision=bf16 \ + --d-model=4096 \ + --mamba-version=2 \ + --mamba2-n-groups=8 \ + --mamba2-head-dim=64 +""" + +tp_split_dim = { + 'word_embeddings.weight': 0, + 'norm.weight': -1, + 'final_norm.weight': -1, + 'output_layer.weight': 0, + # mamba1/2 + 'A_log': 0, + 'D': 0, + 'dt_bias': 0, + 'in_proj.weight': 0, + 'conv1d.weight': 0, + 'conv1d.bias': 0, + 'x_proj.weight': 1, + 'dt_proj.weight': 0, + 'dt_proj.bias': 0, + 'out_proj.weight': 1, + 'mixer.norm.weight': 0, + # mlp + 'linear_fc1.layer_norm_weight': -1, + 'linear_fc1.weight': 0, + 'linear_fc2.weight': 1, + # attention + 'self_attention.linear_proj.weight': 1, + 'self_attention.linear_qkv.layer_norm_weight': -1, + 'self_attention.linear_qkv.weight': 0, +} + + +def get_split_dim(tensor_name): + # norm.weight will match tensor_name of mixer.norm.weight and norm.weight, need to distinguish + if 'norm.weight' in tensor_name: + if 'mixer.norm.weight' in tensor_name: + return tp_split_dim['mixer.norm.weight'] + else: + return tp_split_dim['norm.weight'] + + for key in tp_split_dim.keys(): + if key in tensor_name: + return tp_split_dim[key] + raise Exception("Unknown tensor name {}".format(tensor_name)) + + +def split_tensor_for_tp(params, key, dim, tensor): + + tp_size = params.target_tensor_model_parallel_size + tensor_sliced = [] + if dim == -1: + tensor_sliced = [tensor for i in range(tp_size)] + else: + if 'mixer.in_proj.weight' in key and params.mamba_version == 1: + x, z = torch.split(tensor, [params.mamba_d_inner, params.mamba_d_inner], dim=dim) + x_sliced = torch.chunk(x, tp_size, dim=dim) + z_sliced = torch.chunk(z, tp_size, dim=dim) + for x, z in zip(x_sliced, z_sliced): + tensor_sliced.append(torch.cat((x, z), dim=dim)) + + elif 'mixer.in_proj.weight' in key and params.mamba_version == 2: + x, z, B, C, dt = torch.split( + tensor, + [ + params.mamba_d_inner, + params.mamba_d_inner, + params.mamba2_n_groups * params.mamba_d_state, + params.mamba2_n_groups * params.mamba_d_state, + params.mamba2_n_heads, + ], + dim=dim, + ) + B = torch.reshape(B, (-1, params.mamba_d_state, B.shape[-1])) + C = torch.reshape(C, (-1, params.mamba_d_state, C.shape[-1])) + + B_sliced = torch.chunk(B, tp_size, dim=dim) + C_sliced = torch.chunk(C, tp_size, dim=dim) + x_sliced = torch.chunk(x, tp_size, dim=dim) + z_sliced = torch.chunk(z, tp_size, dim=dim) + dt_sliced = torch.chunk(dt, tp_size, dim=dim) + + tensor_sliced = [] + for x, z, B, C, dt in zip(x_sliced, z_sliced, B_sliced, C_sliced, dt_sliced): + tensor_sliced.append(torch.cat((x, z, B.flatten(0, 1), C.flatten(0, 1), dt), dim=dim)) + + elif 'mixer.conv1d' in key and params.mamba_version == 2: + x, B, C = torch.split( + tensor, + [ + params.mamba_d_inner, + params.mamba2_n_groups * params.mamba_d_state, + params.mamba2_n_groups * params.mamba_d_state, + ], + dim=dim, + ) + if 'weight' in key: + B = torch.reshape(B, (-1, params.mamba_d_state, B.shape[-2], B.shape[-1])) + C = torch.reshape(C, (-1, params.mamba_d_state, C.shape[-2], C.shape[-1])) + elif 'bias' in key: + B = torch.reshape(B, (-1, params.mamba_d_state)) + C = torch.reshape(C, (-1, params.mamba_d_state)) + else: + raise Exception("Unknown key") + + B_sliced = torch.chunk(B, tp_size, dim=dim) + C_sliced = torch.chunk(C, tp_size, dim=dim) + x_sliced = torch.chunk(x, tp_size, dim=dim) + + tensor_sliced = [] + for x, B, C in zip(x_sliced, B_sliced, C_sliced): + tensor_sliced.append(torch.cat((x, B.flatten(0, 1), C.flatten(0, 1)), dim=dim)) + elif '_extra_state' in key: + pass + else: + tensor_sliced = torch.chunk(tensor, tp_size, dim=dim) + + return tensor_sliced + + +################# +### Utilities ### +################# + + +def force_cpu_model(cfg): + with open_dict(cfg): + # temporarily set to cpu + original_cpu_init = cfg.get('use_cpu_initialization', False) + if 'megatron_amp_O2' in cfg: + amp_o2_key = 'megatron_amp_O2' + original_amp_o2 = cfg.megatron_amp_O2 + elif 'megatron_amp_02' in cfg: + amp_o2_key = 'megatron_amp_02' + original_amp_o2 = cfg.megatron_amp_02 + else: + amp_o2_key, original_amp_o2 = None, None + + # Set new values + cfg.use_cpu_initialization = True + if amp_o2_key is not None: + cfg[amp_o2_key] = False + + # Disable sequence parallelism - Not disabling this gives error when converting the the model to TP=1 + original_sequence_parallel = cfg.get('sequence_parallel', None) + cfg.sequence_parallel = False + + # Setup restore dict + restore_dict = {'use_cpu_initialization': original_cpu_init} # 'megatron_amp_O2': original_amp_o2 + if amp_o2_key is not None: + restore_dict[amp_o2_key] = original_amp_o2 + if original_sequence_parallel is not None: + restore_dict['sequence_parallel'] = original_sequence_parallel + + return cfg, restore_dict + + +def restore_model_config(cfg, original_dict): + with open_dict(cfg): + for key, val in original_dict.items(): + logging.info(f"Restoring model config key ({key}) from {cfg[key]} to original value of {val}") + cfg[key] = val + return cfg + + +def write_tp_pp_split(model, splits, app_state, tp_size, pp_rank, write_path): + """ + Function to write the given TP PP split to NeMo File. + + Save each of the TP ranks in reverse order + This is done so that the last PP rank will save the last TP rank only after all other PP TP ranks are saved + The final rank will then save a new NeMo file with all other ranks inside. + + Args: + model: The model corresponding to the current TP PP split. Contains partial parameters. + splits: Nested List of tensors containing the TP splits of the current model given current PP rank. + Indexed as splits[idx][tp_rank]. + app_state: AppState object. + tp_size: The global tensor-parallel size of the final model. + pp_rank: The local pipeline parallel rank of the final model. + write_path: The path to save the NeMo file. + """ + for tp_rank in range(tp_size - 1, -1, -1): + app_state.pipeline_model_parallel_rank = pp_rank + app_state.tensor_model_parallel_rank = tp_rank + + idx = 0 + for name, param in model.named_parameters(): + split_val = splits[idx][tp_rank].clone() + + if param.shape != split_val.shape: + raise RuntimeError( + f"Can not handle parameter {name}, required shape: {param.shape}, split shape: {split_val.shape}." + ) + + param.data = split_val + idx += 1 + + if write_path is not None: + logging.info(f"Writing pp rank {pp_rank} tp rank {tp_rank} to file {write_path}") + model.save_to(write_path) + + +################## +### Converters ### +################## + + +def split_tp_partition_only(args, model, original_model, tp_size, write_path=None, megatron_legacy=False): + + if tp_size < 1: + raise ValueError("TP size must to be >= 1.") + + app_state = AppState() + app_state.data_parallel_rank = 0 + app_state.pipeline_model_parallel_size = 1 + app_state.tensor_model_parallel_size = tp_size + app_state.model_parallel_size = app_state.pipeline_model_parallel_size * app_state.tensor_model_parallel_size + + app_state.pipeline_model_parallel_rank = 0 + app_state.tensor_model_parallel_rank = tp_size - 1 + + idx = 0 + splits = [] + + for ii, (key, original_tensor) in enumerate(original_model.model.state_dict().items()): + try: + layer_num = int(re.findall(r'\d+', key)[0]) + new_key = key.replace(str(layer_num), str(layer_num), 1) + except: + new_key = key + + if '_extra_state' not in new_key: + split_dim = get_split_dim(new_key) + split = split_tensor_for_tp(args, new_key, split_dim, original_tensor) + + splits.append(split) + idx += 1 + + # Save each of the TP ranks in reverse order + # This is done so that the last PP rank will save the last TP rank only after all other PP TP ranks are saved + # The final rank will then save a new NeMo file with all other ranks inside. + write_tp_pp_split(model, splits, app_state, tp_size, pp_rank=0, write_path=write_path) + + with tarfile.open(write_path, 'r') as tar: + # Extract all contents to the specified path + tar.extractall(path=os.path.dirname(write_path)) + + +def main(): + parser = ArgumentParser() + parser.add_argument("--model_file", type=str, default=None, required=False, help="Path to source .nemo file") + parser.add_argument("--target_file", type=str, required=True, help="Path to write target .nemo file") + parser.add_argument( + "--tensor_model_parallel_size", type=int, default=-1, required=False, help="TP size of source model" + ) + parser.add_argument("--target_tensor_model_parallel_size", type=int, required=True, help="TP size of target model") + parser.add_argument( + '--pipeline_model_parallel_size', type=int, default=1, required=False, help='PP size of source model' + ) + parser.add_argument( + '--target_pipeline_model_parallel_size', type=int, required=False, default=1, help='PP size of target model' + ) + parser.add_argument( + '--target_pipeline_model_parallel_split_rank', type=int, default=0, help='PP rank to split for Enc-Dec models' + ) + parser.add_argument( + '--virtual_pipeline_model_parallel_size', type=int, default=None, help='Virtual Pipeline parallelism size' + ) + parser.add_argument( + '--ckpt_name', type=str, default=None, help='Checkpoint name to load from for Virtual Parallel' + ) + parser.add_argument( + "--model_class", + type=str, + default="nemo.collections.nlp.models.language_modeling.megatron_mamba_model.MegatronMambaModel", + help="NeMo model class. This script should support all NeMo megatron models that use Tensor Parallel", + ) + parser.add_argument("--precision", default=16, help="PyTorch Lightning Trainer precision flag") + parser.add_argument('--num_gpu_per_node', default=8, type=int, help='Number of GPUs per node') + parser.add_argument( + "--megatron_legacy", + action="store_true", + help="Converter for legacy megatron modles that have different q,k,v weight splits", + ) + parser.add_argument( + "--tokenizer_model_path", + type=str, + required=False, + default=None, + help="Path to the tokenizer model path if your model uses a tokenizer model as an artifact. This is needed if your model uses a sentencepiece tokenizer.", + ) + parser.add_argument( + "--tokenizer_vocab_file", + type=str, + required=False, + default=None, + help="Path to the tokenizer model path if your model uses a tokenizer model as an artifact. This is needed if your model uses a sentencepiece tokenizer.", + ) + parser.add_argument('--hparams_file', type=str, default=None, help='Path to hparams file from PTL training') + parser.add_argument( + '--tp_conversion_only', default=True, action='store_true', help='Only convert TP model to TP model' + ) + parser.add_argument('--model_extracted_dir', type=str, default=None, help='Path to pre-extracted model directory') + + parser.add_argument('--d-model', type=int, default=4096) + parser.add_argument('--mamba-version', type=int, default=2) + parser.add_argument('--mamba-d-state', type=int, default=128) + parser.add_argument('--mamba2-n-groups', type=int, default=8) + parser.add_argument('--mamba2-head-dim', type=int, default=64) + + args = parser.parse_args() + + args.mamba_d_inner = args.d_model * 2 + args.mamba2_n_heads = args.mamba_d_inner // args.mamba2_head_dim + + precision = args.precision + num_gpu_per_node = int(args.num_gpu_per_node) + if args.precision in ["32", "16"]: + precision = int(float(args.precision)) + + if precision in ["bf16", "bf16-mixed"]: + if torch.cuda.is_available() and torch.cuda.is_bf16_supported(): + pass + else: + logging.warning("BF16 is not supported on this device. Using FP16 instead.") + precision = precision[2:] + + if precision == 32: + dtype = torch.float32 + elif precision in [16, "16", "16-mixed"]: + dtype = torch.float16 + elif precision in ["bf16", "bf16-mixed"]: + dtype = torch.bfloat16 + else: + dtype = torch.float32 # fallback + + # Built target directory if it does not exist + target_dir = os.path.split(args.target_file)[0] + if not os.path.exists(target_dir): + os.makedirs(target_dir, exist_ok=True) + + tp_size = args.tensor_model_parallel_size + tgt_tp_size = args.target_tensor_model_parallel_size + pp_size = args.pipeline_model_parallel_size + tgt_pp_size = args.target_pipeline_model_parallel_size + pipeline_model_parallel_split_rank = args.target_pipeline_model_parallel_split_rank + vp_size = args.virtual_pipeline_model_parallel_size + if vp_size is None: + vp_size = 1 + + convert_vp = vp_size > 1 + if convert_vp: + from megatron.core import parallel_state + + parallel_state.set_virtual_pipeline_model_parallel_world_size(vp_size) + + hparams_filepath = args.hparams_file + if hparams_filepath is None: + logging.warning( + '\n\n\n!!!!!!!!!\n' + 'You are converting a model with virtual pipeline parallelism enabled, \n' + 'but have not passed `hparams_file` argument. \n' + 'This will cause each ckpt file to be temporarily laoded onto GPU memory!\n\n' + 'It is highly recommended to pass `hparams_file` argument to avoid this.\n' + ) + + # Import the class of the model + + if args.model_file is None and args.model_extracted_dir is None: + raise ValueError("Cannot pass model_file and model_extracted_dir as None at the same time.") + + tmp_cfg = MegatronMambaModel.restore_from( + restore_path=args.model_file, + trainer=Trainer(devices=1, strategy=NLPDDPStrategy(), accelerator="cpu", precision=precision), + map_location=torch.device("cpu"), + return_config=True, + ) + plugins = [] + if precision in [16, '16', 'bf16', '16-mixed', 'bf16-mixed']: + scaler = None + if precision in [16, '16', '16-mixed']: + scaler = GradScaler( + init_scale=tmp_cfg.get('native_amp_init_scale', 2**32), + growth_interval=tmp_cfg.get('native_amp_growth_interval', 1000), + hysteresis=tmp_cfg.get('hysteresis', 2), + ) + # MixedPrecisionPlugin in PTL >= 2.0 requires precision to be 16-mixed or bf16-mixed + plugin_precision = '16-mixed' + else: + plugin_precision = 'bf16-mixed' + + if tmp_cfg.get('megatron_amp_O2', False): + plugins.append(MegatronHalfPrecisionPlugin(precision=plugin_precision, device='cuda', scaler=scaler)) + else: + plugins.append(PipelineMixedPrecisionPlugin(precision=plugin_precision, device='cuda', scaler=scaler)) + # Set precision None after precision plugins are created as PTL >= 2.1 does not allow both + # precision plugins and precision to exist + trainer = Trainer(plugins=plugins, devices=1, strategy=NLPDDPStrategy(), accelerator="cpu") + + if tp_size < 0 or pp_size < 0: + logging.info(f"Loading model config from {args.model_file} to get TP and PP size") + model_config_internal = MegatronMambaModel.restore_from( + restore_path=args.model_file, + trainer=trainer, + map_location=torch.device("cpu"), + return_config=True, + ) + + tp_size = model_config_internal.get('tensor_model_parallel_size', 1) + pp_size = model_config_internal.get('pipeline_model_parallel_size', 1) + + # Check if TP conversion only + tp_conversion_only = args.tp_conversion_only + if tp_conversion_only: + logging.info("Converting TP model to TP model only") + + if pp_size > 1: + raise ValueError("Provided `--tp_conversion_only` but `--pipeline_model_parallel_size` > 1") + + if tgt_pp_size > 1: + raise ValueError("Provided `--tp_conversion_only` but `--target_pipeline_model_parallel_size` > 1") + + if pipeline_model_parallel_split_rank > 0: + raise ValueError("Provided `--tp_conversion_only` but `--target_pipeline_model_parallel_split_rank` > 0") + + # Force PP size to 1 + pp_size = 1 + tgt_pp_size = 1 + pipeline_model_parallel_split_rank = 0 + + if vp_size is None or vp_size < 0: + vp_size = 1 + + app_state = AppState() + app_state.data_parallel_rank = 0 + app_state.pipeline_model_parallel_size = pp_size + app_state.tensor_model_parallel_size = tp_size + + if vp_size > 1: + app_state.virtual_pipeline_model_parallel_size = vp_size + app_state.model_parallel_size = app_state.pipeline_model_parallel_size * app_state.tensor_model_parallel_size + + world_size = pp_size * tp_size # pseudo world size for simulating load of a specific rank on a single gpu + + app_state.tensor_model_parallel_rank = 0 + app_state.pipeline_model_parallel_rank = 0 + + # Extract tokenizer artifact from the model to temp directory + logging.info("Extracting tokenizer artifact from NeMo file...") + temp_dir = tempfile.mkdtemp() + tokenizer_model_path = None + with tarfile.open(args.model_file, "r") as tar: + for member in tar.getmembers(): + if '.model' in member.name: + extracted_file = tar.extractfile(member) + extracted_file_path = os.path.join(temp_dir, member.name) + + if tokenizer_model_path is None: + logging.info(f"Found tokenizer. Extracting {member.name} to {extracted_file_path}") + + tokenizer_model_path = extracted_file_path + with open(extracted_file_path, "wb") as f: + f.write(extracted_file.read()) + else: + if args.tokenizer_model_path is None: + logging.warning( + f"\n\nFound multiple tokenizer artifacts in the model file.\n" + f"Using only {tokenizer_model_path}.\n" + f"If this is incorrect, manually pass the correct tokenizer using " + f"`--tokenizer_model_path`.\n\n" + ) + + # If input model has TP > 1 or PP > 1 + # Reconstruct the model to have TP = 1 and PP = 1 + # Note that this is a forward loop that will process PP [0..N] TP [0..M] in sequential order. + + # If input model has TP = 1 and PP = 1 + app_state.model_parallel_size = 1 + + save_restore_connector = NLPSaveRestoreConnector() + + if args.model_extracted_dir is not None: + logging.info(f"Using extracted model directory: {args.model_extracted_dir}") + save_restore_connector.model_extracted_dir = args.model_extracted_dir + + if args.model_file is not None: + model_filepath = args.model_file + else: + model_filepath = args.model_extracted_dir + + tmp_cfg = MegatronMambaModel.restore_from( + restore_path=model_filepath, + trainer=trainer, + map_location=torch.device("cpu"), + save_restore_connector=save_restore_connector, + return_config=True, + ) + + tmp_cfg, restore_dict = force_cpu_model(tmp_cfg) + + model = MegatronMambaModel.restore_from( + restore_path=model_filepath, + trainer=trainer, + map_location=torch.device("cpu"), + save_restore_connector=save_restore_connector, + override_config_path=tmp_cfg, + ) + + original_model = MegatronMambaModel.restore_from( + restore_path=model_filepath, + trainer=trainer, + map_location=torch.device("cpu"), + save_restore_connector=save_restore_connector, + override_config_path=tmp_cfg, + ) + original_model = original_model.to('cpu') + original_model._save_restore_connector = NLPSaveRestoreConnector() + original_model.freeze() + original_model.to(dtype=dtype) + + model.to(dtype=dtype) + + restore_model_config(model.cfg, restore_dict) + + # If target model has TP > 1 or PP > 1 + if tgt_pp_size > 1 or tgt_tp_size > 1: + + # Preserve the TP 1 PP 1 model parameters and names + global_params = [] + global_params.append([p for n, p in model.named_parameters()]) # params + global_params.append([n for n, p in model.named_parameters()]) # names + + logging.debug("Global parameters:") + for idx, (name, p) in enumerate(zip(global_params[1], global_params[0])): + logging.debug(f"{name} - {p.shape}") + + logging.info(f"TP 1 PP 1 Number of Parameters : {len(global_params[0])}") + + world_size = ( + tgt_pp_size * tgt_tp_size + ) # pseudo world size for simulating load of a specific rank on a single gpu + new_global_batch_size = model.cfg.micro_batch_size * world_size + old_global_batch_size = model.cfg.get('global_batch_size', model.cfg.micro_batch_size) + + global_offset = len(global_params[0]) - 1 # -1 cause this indexes the array, range [0, L-1] + logging.info(f"Final layer offset for parameters: {global_offset}") + + for pp_rank in range(tgt_pp_size - 1, -1, -1): # reverse order + + with open_dict(model.cfg): + model.cfg.pipeline_model_parallel_size = tgt_pp_size + model.cfg.tensor_model_parallel_size = tgt_tp_size + + if 'pipeline_model_parallel_split_rank' in model.cfg: + if pipeline_model_parallel_split_rank > 0: + model.cfg.pipeline_model_parallel_split_rank = pipeline_model_parallel_split_rank + elif pp_size > 1: + logging.warning( + f"Model config has `pipeline_model_parallel_split_rank` set to " + f"{model.cfg.pipeline_model_parallel_split_rank} and target PP " + f"size is {tgt_pp_size}. " + f"Provided `pipeline_model_parallel_split_rank` is " + f"{pipeline_model_parallel_split_rank}. " + f"Be careful that the model config is correct " + f"if encoder-decoder models are being converted." + ) + + model.cfg.global_batch_size = old_global_batch_size # Used for restoration + + # Override flag that forces Model to use AppState instead of Trainer + # to determine the world size, global and local rank + # Used for simulating load of a specific rank on a single gpu + os.environ[NEMO_MEGATRON_MODEL_PARALLEL_APPSTATE_OVERRIDE] = "true" + + # Compute the global rank + global_rank = ( + pp_rank * tgt_tp_size + 0 + ) # tp_rank = 0 needed just for modules, all TP will be merged to this PP rank + + # Update AppState + app_state.world_size = world_size + app_state.global_rank = global_rank + app_state.local_rank = global_rank % num_gpu_per_node + app_state.pipeline_model_parallel_size = tgt_pp_size + app_state.tensor_model_parallel_size = tgt_tp_size + app_state.model_parallel_size = ( + app_state.pipeline_model_parallel_size * app_state.tensor_model_parallel_size + ) + + trainer = Trainer(plugins=plugins, devices=1, strategy=NLPDDPStrategy(), accelerator="cpu") + if args.tokenizer_model_path is not None: + with open_dict(model.cfg): + model.cfg.tokenizer.model = args.tokenizer_model_path + + else: + if tokenizer_model_path is None: + logging.warning("Could not extract tokenizer model file from checkpoint.") + + else: + # Extract tokenizer info + with open_dict(model.cfg): + model.cfg.tokenizer.model = tokenizer_model_path + + model.cfg, restore_dict = force_cpu_model(model.cfg) + + from apex.transformer.pipeline_parallel.utils import _GLOBAL_NUM_MICROBATCHES_CALCULATOR + + _GLOBAL_NUM_MICROBATCHES_CALCULATOR.current_global_batch_size = 1 + _GLOBAL_NUM_MICROBATCHES_CALCULATOR.current_micro_batch_size = 1 + model.cfg.global_batch_size = 1 + model.cfg.micro_batch_size = 1 + + model = MegatronMambaModel(model.cfg, trainer) + model = model.to('cpu') + model._save_restore_connector = NLPSaveRestoreConnector() + model.freeze() + model.to(dtype=dtype) + + restore_model_config(model.cfg, restore_dict) + + # Update global batch size + if old_global_batch_size % new_global_batch_size != 0 or old_global_batch_size < new_global_batch_size: + logging.info( + f"Global batch size {old_global_batch_size} is not divisible by new global batch size {new_global_batch_size}." + f" The model config will be updated with new global batch size {new_global_batch_size}." + ) + with open_dict(model.cfg): + model.cfg.global_batch_size = new_global_batch_size + + logging.info(f"Global rank: {global_rank} Local rank: {app_state.local_rank} World size: {world_size}") + logging.info(f"PP rank: {pp_rank} TP rank: {0}") + logging.info(f"TP 1 PP 1 Number of Layers : {len(global_params[0])}") + logging.info(f"Remaining layer offset for parameters: {global_offset}") + logging.info("\n") + + # Special case for TP conversion only mode + if tp_conversion_only: + logging.info(f"Skipping PP split due to flag `--tp_conversion_only`") + split_tp_partition_only( + args, model, original_model, tgt_tp_size, args.target_file, args.megatron_legacy + ) + break + + +if __name__ == '__main__': + main() diff --git a/examples/nlp/language_modeling/megatron_gpt_continue_training.py b/examples/nlp/language_modeling/megatron_gpt_continue_training.py deleted file mode 100755 index 73cbb2abcce8d..0000000000000 --- a/examples/nlp/language_modeling/megatron_gpt_continue_training.py +++ /dev/null @@ -1,198 +0,0 @@ -# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import os -import tempfile - -from omegaconf.omegaconf import OmegaConf, open_dict -from pytorch_lightning import Trainer -from pytorch_lightning.plugins.environments import TorchElasticEnvironment -from pytorch_lightning.trainer.connectors.checkpoint_connector import _CheckpointConnector - -from nemo.collections.nlp.models.language_modeling.megatron_gpt_model import MegatronGPTModel -from nemo.collections.nlp.modules.common.megatron.megatron_init import fake_initialize_model_parallel -from nemo.collections.nlp.parts.nlp_overrides import ( - CustomProgressBar, - GradScaler, - MegatronHalfPrecisionPlugin, - NLPDDPStrategy, - NLPSaveRestoreConnector, - PipelineMixedPrecisionPlugin, -) -from nemo.core.config import hydra_runner -from nemo.utils import AppState, logging -from nemo.utils.exp_manager import exp_manager -from nemo.utils.model_utils import inject_model_parallel_rank - - -def _modify_config(gpt_cfg, cfg, add_cfg_to_tree=False): - """ - This function modifies the original gpt pre-training config (t5_cfg) with attributes from the finetuning config (cfg). - The `add_cfg_to_tree` arg adds `cfg` to the top of the yaml tree which is needed for all `hparams.yaml` files when passed as an arg to `load_from_checkpoint()`. - """ - OmegaConf.set_struct(gpt_cfg, True) - OmegaConf.resolve(cfg) - with open_dict(gpt_cfg): - gpt_cfg.megatron_amp_O2 = cfg.model.get('megatron_amp_O2', False) - gpt_cfg.micro_batch_size = cfg.model.micro_batch_size - gpt_cfg.global_batch_size = cfg.model.global_batch_size - gpt_cfg.sequence_parallel = cfg.model.get("sequence_parallel", False) - gpt_cfg.activations_checkpoint_granularity = cfg.model.get("activations_checkpoint_granularity", None) - gpt_cfg.activations_checkpoint_num_layers = cfg.model.get("activations_checkpoint_num_layers", None) - gpt_cfg.activations_checkpoint_method = cfg.model.get("activations_checkpoint_method", None) - gpt_cfg.data = cfg.model.data - gpt_cfg.optim = cfg.model.optim - gpt_cfg.precision = cfg.trainer.precision - gpt_cfg.restore_from_path = cfg.restore_from_path - gpt_cfg.resume_from_checkpoint = cfg.model.resume_from_checkpoint - gpt_cfg.gradient_as_bucket_view = cfg.model.gradient_as_bucket_view - gpt_cfg.encoder_seq_length = cfg.model.encoder_seq_length - gpt_cfg.max_position_embeddings = cfg.model.max_position_embeddings - gpt_cfg.seq_len_interpolation_factor = cfg.model.seq_len_interpolation_factor - gpt_cfg.use_flash_attention = cfg.model.use_flash_attention - gpt_cfg.tensor_model_parallel_size = cfg.model.get('tensor_model_parallel_size', 1) - gpt_cfg.pipeline_model_parallel_size = cfg.model.get('pipeline_model_parallel_size', 1) - gpt_cfg.pipeline_model_parallel_split_rank = cfg.model.get('pipeline_model_parallel_split_rank', 0) - - # This is needed when modifying a hparam file directly to load `.ckpt` files. - # This is not needed to modify the cfg in `.nemo` files. - if add_cfg_to_tree: - OmegaConf.resolve(gpt_cfg) - gpt_cfg.cfg = gpt_cfg - - return gpt_cfg - - -def load_from_nemo(cls, cfg, trainer, gpt_cfg, modify_confg_fn): - gpt_cfg = modify_confg_fn(gpt_cfg, cfg, add_cfg_to_tree=False) - save_restore_connector = NLPSaveRestoreConnector() - if os.path.isdir(cfg.restore_from_path): - save_restore_connector.model_extracted_dir = cfg.restore_from_path - model = cls.restore_from( - restore_path=cfg.restore_from_path, - trainer=trainer, - override_config_path=gpt_cfg, - save_restore_connector=save_restore_connector, - ) - return model - - -def load_from_checkpoint_dir(cls, cfg, trainer, modify_confg_fn): - app_state = AppState() - if cfg.model.tensor_model_parallel_size > 1 or cfg.model.pipeline_model_parallel_size > 1: - app_state.model_parallel_size = cfg.model.tensor_model_parallel_size * cfg.model.pipeline_model_parallel_size - app_state.tensor_model_parallel_size = cfg.model.tensor_model_parallel_size - app_state.pipeline_model_parallel_size = cfg.model.pipeline_model_parallel_size - ( - app_state.tensor_model_parallel_rank, - app_state.pipeline_model_parallel_rank, - app_state.model_parallel_size, - app_state.data_parallel_size, - app_state.pipeline_model_parallel_split_rank, - app_state.virtual_pipeline_model_parallel_rank, - ) = fake_initialize_model_parallel( - world_size=app_state.model_parallel_size, - rank=trainer.global_rank, - tensor_model_parallel_size_=cfg.model.tensor_model_parallel_size, - pipeline_model_parallel_size_=cfg.model.pipeline_model_parallel_size, - pipeline_model_parallel_split_rank_=cfg.model.pipeline_model_parallel_split_rank, - ) - checkpoint_path = inject_model_parallel_rank( - os.path.join(cfg.model.pretrained_checkpoint.checkpoint_dir, cfg.model.pretrained_checkpoint.checkpoint_name) - ) - hparams_file = OmegaConf.load(cfg.model.pretrained_checkpoint.hparams_file) - gpt_cfg = modify_confg_fn(hparams_file.cfg, cfg, add_cfg_to_tree=True) - with tempfile.NamedTemporaryFile(suffix='.yaml') as f: - OmegaConf.save(config=gpt_cfg, f=f.name) - model = cls.load_from_checkpoint(checkpoint_path=checkpoint_path, trainer=trainer, hparams_file=f.name,) - return model - - -def validate_checkpoint_loading_args(cfg): - if cfg.checkpoint_dir is None or not os.path.isdir(cfg.checkpoint_dir): - raise ValueError(f'Checkpoint directory {cfg.checkpoint_dir} does not exist or is not a directory.') - if cfg.checkpoint_name is None: - raise ValueError(f'Checkpoint name {cfg.checkpoint_name} is not valid.') - if cfg.hparams_file is None or not os.path.isfile(cfg.hparams_file): - raise ValueError(f'Hparams file {cfg.hparams_file} does not exist or is not a file.') - - -@hydra_runner(config_path="conf", config_name="megatron_gpt_config") -def main(cfg) -> None: - logging.info("\n\n************** Experiment configuration ***********") - logging.info(f'\n{OmegaConf.to_yaml(cfg)}') - - megatron_amp_O2 = cfg.model.get('megatron_amp_O2', False) - with_distributed_adam = cfg.model.optim.get('name', 'fused_adam') == 'distributed_fused_adam' - plugins = [] - strategy = NLPDDPStrategy( - no_ddp_communication_hook=True, - gradient_as_bucket_view=cfg.model.gradient_as_bucket_view, - find_unused_parameters=False, - ) - if cfg.trainer.precision in [16, '16', 'bf16', '16-mixed', 'bf16-mixed']: - scaler = None - if cfg.trainer.precision in [16, '16', '16-mixed']: - scaler = GradScaler( - init_scale=cfg.model.get('native_amp_init_scale', 2 ** 32), - growth_interval=cfg.model.get('native_amp_growth_interval', 1000), - hysteresis=cfg.model.get('hysteresis', 2), - ) - plugin_precision = '16-mixed' - else: - plugin_precision = 'bf16-mixed' - if megatron_amp_O2 and not with_distributed_adam: - plugins.append(MegatronHalfPrecisionPlugin(precision=plugin_precision, device='cuda', scaler=scaler)) - else: - plugins.append(PipelineMixedPrecisionPlugin(precision=plugin_precision, device='cuda', scaler=scaler)) - - if cfg.get('cluster_type', None) == 'BCP': - plugins.append(TorchElasticEnvironment()) - - callbacks = [] - # enable_progress_bar is True by default. If cfg.trainer.enable_progress_bar=False, CustomProgressBar is not appended to callbacks - if 'enable_progress_bar' not in cfg.trainer or cfg.trainer.enable_progress_bar: - callbacks.append(CustomProgressBar()) - trainer = Trainer(plugins=plugins, strategy=strategy, **cfg.trainer, callbacks=callbacks) - - exp_manager(trainer, cfg.exp_manager) - - # update resume from checkpoint found by exp_manager - if cfg.model.resume_from_checkpoint is not None: - trainer.ckpt_path = cfg.model.resume_from_checkpoint - - logging.info(f'Resuming training from checkpoint: {trainer.ckpt_path}') - - if cfg.restore_from_path: - save_restore_connector = NLPSaveRestoreConnector() - if os.path.isdir(cfg.restore_from_path): - save_restore_connector.model_extracted_dir = cfg.restore_from_path - gpt_cfg = MegatronGPTModel.restore_from( - restore_path=cfg.restore_from_path, - trainer=trainer, - return_config=True, - save_restore_connector=save_restore_connector, - ) - model = load_from_nemo(MegatronGPTModel, cfg, trainer, gpt_cfg, modify_confg_fn=_modify_config) - elif cfg.model.get("pretrained_checkpoint", None) is not None: - validate_checkpoint_loading_args(cfg.model.pretrained_checkpoint) - model = load_from_checkpoint_dir(MegatronGPTModel, cfg, trainer, modify_confg_fn=_modify_config) - else: - print(' > WARNING: No checkpoint provided. Starting from scratch.') - model = MegatronGPTModel(cfg.model, trainer) - trainer.fit(model) - - -if __name__ == '__main__': - main() diff --git a/examples/nlp/language_modeling/megatron_gpt_eval.py b/examples/nlp/language_modeling/megatron_gpt_eval.py index f3413a5fa92ef..b9b0d2973094a 100644 --- a/examples/nlp/language_modeling/megatron_gpt_eval.py +++ b/examples/nlp/language_modeling/megatron_gpt_eval.py @@ -14,6 +14,7 @@ import asyncio import datetime +import json import os import threading from functools import partial @@ -30,6 +31,7 @@ from nemo.collections.nlp.modules.common.transformer.text_generation import LengthParam, SamplingParam from nemo.collections.nlp.parts.nlp_overrides import CustomProgressBar, NLPDDPStrategy, NLPSaveRestoreConnector from nemo.core.config import hydra_runner +from nemo.utils import logging from nemo.utils.app_state import AppState from nemo.utils.model_utils import inject_model_parallel_rank @@ -166,19 +168,7 @@ def remove_padded_prompts(response, nb_paddings): return result -@hydra_runner(config_path="conf", config_name="megatron_gpt_inference") -def main(cfg) -> None: - - callbacks = [] - # enable_progress_bar is True by default. If cfg.trainer.enable_progress_bar=False, CustomProgressBar is not appended to callbacks - if 'enable_progress_bar' not in cfg.trainer or cfg.trainer.enable_progress_bar: - callbacks.append(CustomProgressBar()) - # trainer required for restoring model parallel models - trainer = Trainer( - strategy=NLPDDPStrategy(timeout=datetime.timedelta(seconds=18000)), - **cfg.trainer, - callbacks=callbacks, - ) +def load_model_from_config(trainer, cfg): if cfg.gpt_model_file is not None: if ( @@ -285,7 +275,51 @@ def main(cfg) -> None: model = MegatronGPTModel.load_from_checkpoint(checkpoint_path, hparams_file=cfg.hparams_file, trainer=trainer) else: raise ValueError("need at least a nemo file or checkpoint dir") + return model + + +def load_prompts(cfg): + prompts = [] + if (cfg_prompts := getattr(cfg, 'prompts', None)) is not None: + prompts = OmegaConf.to_container(cfg_prompts) + if (prompts_jsonl := getattr(cfg, 'prompts_jsonl', None)) is not None: + with open(prompts_jsonl, 'rt') as fp: + try: + prompts += list(map(json.loads, map(str.rstrip, fp))) + except: + prompts += list(map(str.rstrip, fp)) + # Make sure non-empty input + assert len(prompts) > 0, "Expected at least one prompt" + # Make sure all have the same type + assert all( + map(lambda x: isinstance(x, type(prompts[0])), prompts) + ), "Expected all prompts to have the same datatype" + return prompts + + +def round_to_mult(n, mult=8): + """ + Rounds number n to be a multiple of mult + """ + return ((n + mult - 1) // mult) * mult + +@hydra_runner(config_path="conf", config_name="megatron_gpt_inference") +def main(cfg) -> None: + + callbacks = [] + logging.warning("This file will be depreacted soon. Use the megatron_gpt_mcore_batch_eval.py file instead.") + # enable_progress_bar is True by default. If cfg.trainer.enable_progress_bar=False, CustomProgressBar is not appended to callbacks + if 'enable_progress_bar' not in cfg.trainer or cfg.trainer.enable_progress_bar: + callbacks.append(CustomProgressBar()) + # trainer required for restoring model parallel models + trainer = Trainer( + strategy=NLPDDPStrategy(timeout=datetime.timedelta(seconds=18000)), + **cfg.trainer, + callbacks=callbacks, + ) + + model = load_model_from_config(trainer, cfg) model.freeze() # Have to turn off activations_checkpoint_method for inference @@ -311,17 +345,17 @@ def main(cfg) -> None: "end_strings": cfg.inference.end_strings, } + prompts = load_prompts(cfg) + fp8_enabled = hasattr(model.cfg, "fp8") and (model.cfg.fp8 == True) - if fp8_enabled: - nb_paddings = 0 - while len(cfg.prompts) % 8 != 0: - cfg.prompts.append("") - nb_paddings += 1 + if fp8_enabled and len(prompts) > 0: + padded_len = round_to_mult(len(prompts), 8) + nb_paddings = padded_len - len(prompts) + if nb_paddings > 0: + nb_paddings += [''] * nb_paddings # First method of running text generation, call model.generate method - response = model.generate( - inputs=OmegaConf.to_container(cfg.prompts), length_params=length_params, sampling_params=sampling_params - ) + response = model.generate(inputs=prompts, length_params=length_params, sampling_params=sampling_params) if fp8_enabled: response = remove_padded_prompts(response, nb_paddings) @@ -331,7 +365,7 @@ def main(cfg) -> None: # Second method of running text generation, call trainer.predict [recommended] bs = 8 if fp8_enabled else 2 - ds = RequestDataSet(OmegaConf.to_container(cfg.prompts)) + ds = RequestDataSet(prompts) request_dl = DataLoader(dataset=ds, batch_size=bs) config = OmegaConf.to_container(cfg.inference) model.set_inference_config(config) diff --git a/examples/nlp/language_modeling/megatron_gpt_mcore_batch_eval.py b/examples/nlp/language_modeling/megatron_gpt_mcore_batch_eval.py new file mode 100644 index 0000000000000..988a5f8588ff9 --- /dev/null +++ b/examples/nlp/language_modeling/megatron_gpt_mcore_batch_eval.py @@ -0,0 +1,222 @@ +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import datetime +import os +from argparse import Namespace + +from megatron.core.inference.common_inference_params import CommonInferenceParams +from megatron.core.inference.engines.mcore_engine import MCoreEngine +from megatron.core.inference.inference_model_wrappers.gpt.gpt_inference_wrapper import GPTInferenceWrapper +from megatron.core.inference.text_generation_controllers.simple_text_generation_controller import ( + SimpleTextGenerationController, +) +from omegaconf import OmegaConf, open_dict +from pytorch_lightning.trainer.trainer import Trainer + +from nemo.collections.nlp.models.language_modeling.megatron_gpt_model import MegatronGPTModel +from nemo.collections.nlp.modules.common.megatron.megatron_init import fake_initialize_model_parallel +from nemo.collections.nlp.parts.nlp_overrides import CustomProgressBar, NLPDDPStrategy, NLPSaveRestoreConnector +from nemo.core.config import hydra_runner +from nemo.utils.app_state import AppState +from nemo.utils.model_utils import inject_model_parallel_rank + +""" +This is the script to run GPT text generation in batch mode using Megatron Core Generate function. +""" + + +@hydra_runner(config_path="conf", config_name="megatron_gpt_inference_batch_mcore") +def main(cfg) -> None: + callbacks = [] + # enable_progress_bar is True by default. If cfg.trainer.enable_progress_bar=False, CustomProgressBar is not appended to callbacks + if 'enable_progress_bar' not in cfg.trainer or cfg.trainer.enable_progress_bar: + callbacks.append(CustomProgressBar()) + # trainer required for restoring model parallel models + trainer = Trainer( + strategy=NLPDDPStrategy(timeout=datetime.timedelta(seconds=18000)), + **cfg.trainer, + callbacks=callbacks, + ) + + if cfg.gpt_model_file is not None: + if ( + cfg.tensor_model_parallel_size < 0 + or cfg.pipeline_model_parallel_size < 0 + or cfg.get('pipeline_model_parallel_split_rank', -1) < 0 + ): + save_restore_connector = NLPSaveRestoreConnector() + if os.path.isdir(cfg.gpt_model_file): + save_restore_connector.model_extracted_dir = cfg.gpt_model_file + model_config = MegatronGPTModel.restore_from( + restore_path=cfg.gpt_model_file, + trainer=trainer, + return_config=True, + save_restore_connector=save_restore_connector, + ) + + # with dist checkpointing we don't need to set this + if not model_config.get('mcore_gpt', False): + with open_dict(cfg): + cfg.tensor_model_parallel_size = model_config.get('tensor_model_parallel_size', 1) + cfg.pipeline_model_parallel_size = model_config.get('pipeline_model_parallel_size', 1) + cfg.pipeline_model_parallel_split_rank = model_config.get('pipeline_model_parallel_split_rank', 0) + + assert ( + cfg.trainer.devices * cfg.trainer.num_nodes + == cfg.tensor_model_parallel_size + * cfg.pipeline_model_parallel_size + * max(1, cfg.get('expert_model_parallel_size', 1)) + ), "devices * num_nodes should equal tensor_model_parallel_size * pipeline_model_parallel_size" + + if cfg.gpt_model_file: + save_restore_connector = NLPSaveRestoreConnector() + if os.path.isdir(cfg.gpt_model_file): + save_restore_connector.model_extracted_dir = cfg.gpt_model_file + + pretrained_cfg = MegatronGPTModel.restore_from( + restore_path=cfg.gpt_model_file, + trainer=trainer, + return_config=True, + save_restore_connector=save_restore_connector, + ) + OmegaConf.set_struct(pretrained_cfg, True) + with open_dict(pretrained_cfg): + pretrained_cfg.sequence_parallel = False + pretrained_cfg.activations_checkpoint_granularity = None + pretrained_cfg.activations_checkpoint_method = None + pretrained_cfg.precision = trainer.precision + pretrained_cfg["use_flash_attention"] = cfg.get("use_flash_attention", False) + pretrained_cfg["apply_rope_fusion"] = False + if pretrained_cfg.get('mcore_gpt', False): + # with dist checkpointing we can use the model parallel config specified by the user + pretrained_cfg.tensor_model_parallel_size = cfg.tensor_model_parallel_size + pretrained_cfg.pipeline_model_parallel_size = cfg.pipeline_model_parallel_size + pretrained_cfg.expert_model_parallel_size = cfg.get('expert_model_parallel_size', 1) + pretrained_cfg.micro_batch_size = 1 + if trainer.precision == "16": + pretrained_cfg.megatron_amp_O2 = False + elif trainer.precision in ['bf16', 'bf16-mixed'] and cfg.get('megatron_amp_O2', False): + pretrained_cfg.megatron_amp_O2 = True + model = MegatronGPTModel.restore_from( + restore_path=cfg.gpt_model_file, + trainer=trainer, + override_config_path=pretrained_cfg, + save_restore_connector=save_restore_connector, + map_location=f'cuda:{trainer.local_rank}', # map_location is needed for converted models + ) + elif cfg.checkpoint_dir: + app_state = AppState() + if ( + cfg.tensor_model_parallel_size > 1 + or cfg.pipeline_model_parallel_size > 1 + or cfg.get('expert_model_parallel_size', 1) > 1 + ): + app_state.model_parallel_size = ( + cfg.tensor_model_parallel_size + * cfg.pipeline_model_parallel_size + * cfg.get('expert_model_parallel_size', 1) + ) + app_state.tensor_model_parallel_size = cfg.tensor_model_parallel_size + app_state.pipeline_model_parallel_size = cfg.pipeline_model_parallel_size + app_state.expert_model_parallel_size = cfg.get('expert_model_parallel_size', 1) + ( + app_state.tensor_model_parallel_rank, + app_state.pipeline_model_parallel_rank, + app_state.expert_model_parallel_rank, + app_state.model_parallel_size, + app_state.data_parallel_size, + app_state.pipeline_model_parallel_split_rank, + app_state.virtual_pipeline_model_parallel_rank, + ) = fake_initialize_model_parallel( + world_size=app_state.model_parallel_size, + rank=trainer.global_rank, + tensor_model_parallel_size_=cfg.tensor_model_parallel_size, + pipeline_model_parallel_size_=cfg.pipeline_model_parallel_size, + pipeline_model_parallel_split_rank_=cfg.pipeline_model_parallel_split_rank, + expert_model_parallel_size_=cfg.get('expert_model_parallel_size', 1), + ) + checkpoint_path = os.path.join(cfg.checkpoint_dir, cfg.checkpoint_name) + # checkpoint_path is a dir in case of distributed checkpointing + if not os.path.isdir(checkpoint_path): + # legacy checkpoint needs model parallel rank injection + checkpoint_path = inject_model_parallel_rank(os.path.join(cfg.checkpoint_dir, cfg.checkpoint_name)) + model = MegatronGPTModel.load_from_checkpoint(checkpoint_path, hparams_file=cfg.hparams_file, trainer=trainer) + else: + raise ValueError("need at least a nemo file or checkpoint dir") + + model.freeze() + + # Have to turn off activations_checkpoint_method for inference + try: + model.model.language_model.encoder.activations_checkpoint_method = None + except AttributeError: + pass + + args = Namespace + args.inference_batch_times_seq_len_threshold = cfg.inference_batch_times_seq_len_threshold + args.padded_vocab_size = model.padded_vocab_size + args.fp32_residual_connection = model.cfg.fp32_residual_connection + args.hidden_size = model.cfg.hidden_size + args.params_dtype = model.cfg.precision + args.max_batch_size = cfg.max_batch_size + + # We need this wrapper since mcore generate uses tokenizer.detokenize, tokenizer.tokenize to encode and decode prompts + class MCoreTokenizerWrappper: + def __init__(self, tokenizer): + self.tokenizer = tokenizer + self.eod = tokenizer.eod + self.vocab_size = tokenizer.vocab_size + + def detokenize(self, tokens): + return self.tokenizer.ids_to_text(tokens) + + def tokenize(self, prompt): + return self.tokenizer.text_to_ids(prompt) + + tokenizer = MCoreTokenizerWrappper(model.tokenizer) + + inference_wrapped_model = GPTInferenceWrapper(model.model, args) + text_generation_controller = SimpleTextGenerationController( + inference_wrapped_model=inference_wrapped_model, tokenizer=tokenizer + ) + mcore_engine = MCoreEngine( + text_generation_controller=text_generation_controller, max_batch_size=args.max_batch_size + ) + + common_inference_params = CommonInferenceParams( + temperature=cfg.common_inference_params.temperature, + top_k=cfg.common_inference_params.top_k, + top_p=cfg.common_inference_params.top_p, + return_log_probs=cfg.common_inference_params.return_log_probs, + num_tokens_to_generate=cfg.common_inference_params.tokens_to_generate, + ) + + results = mcore_engine.generate( + prompts=OmegaConf.to_container(cfg.prompts), common_inference_params=common_inference_params + ) + + for idx, result in enumerate(results): + print(f' \n------------- RESULT FOR PROMPT {idx} --------------- ') + result = { + 'id': result.request_id, + 'input_prompt': result.prompt, + 'generated_text': result.generated_text, + 'generated_tokens': result.generated_tokens, + } + print(result) + + +if __name__ == '__main__': + main() # noqa pylint: disable=no-value-for-parameter diff --git a/examples/nlp/language_modeling/megatron_gpt_pretraining.py b/examples/nlp/language_modeling/megatron_gpt_pretraining.py index 80158446d95a9..422319a382c83 100644 --- a/examples/nlp/language_modeling/megatron_gpt_pretraining.py +++ b/examples/nlp/language_modeling/megatron_gpt_pretraining.py @@ -13,6 +13,8 @@ # limitations under the License. +from pathlib import Path + # To suppress BF16 compile related issue in the CI runs with turing/V100 import torch._dynamo import torch.multiprocessing as mp @@ -20,6 +22,7 @@ from nemo.collections.nlp.models.language_modeling.megatron_gpt_model import MegatronGPTModel from nemo.collections.nlp.parts.megatron_trainer_builder import MegatronTrainerBuilder +from nemo.collections.nlp.parts.nlp_overrides import NLPSaveRestoreConnector from nemo.core.config import hydra_runner from nemo.utils import logging from nemo.utils.exp_manager import exp_manager @@ -37,7 +40,25 @@ def main(cfg) -> None: trainer = MegatronTrainerBuilder(cfg).create_trainer() exp_manager(trainer, cfg.exp_manager) - model = MegatronGPTModel(cfg.model, trainer) + # Continual training + if cfg.model.get("restore_from_path") is not None: + # Option 1: Restore only the model weights from a .nemo file + logging.info(f"Continual training: loading weights from {cfg.model.restore_from_path}") + model = MegatronGPTModel.restore_from( + restore_path=cfg.model.restore_from_path, + override_config_path=cfg.model, + trainer=trainer, + save_restore_connector=NLPSaveRestoreConnector(), + ) + elif cfg.model.get("restore_from_ckpt") is not None: + # Option 2: Restore both model weights and optimizer states from a PTL checkpoint + logging.info(f"Continual training: loading weights and optimizer states from {cfg.model.restore_from_ckpt}") + trainer.ckpt_path = Path(cfg.model.restore_from_ckpt) + model = MegatronGPTModel(cfg.model, trainer) + + # Start new pretraining or resume from a checkpoint if it exists + else: + model = MegatronGPTModel(cfg.model, trainer) trainer.fit(model) diff --git a/examples/nlp/language_modeling/megatron_gpt_quantization.py b/examples/nlp/language_modeling/megatron_gpt_ptq.py similarity index 94% rename from examples/nlp/language_modeling/megatron_gpt_quantization.py rename to examples/nlp/language_modeling/megatron_gpt_ptq.py index faf442ecd22c1..e41becc2d8e00 100644 --- a/examples/nlp/language_modeling/megatron_gpt_quantization.py +++ b/examples/nlp/language_modeling/megatron_gpt_ptq.py @@ -31,12 +31,12 @@ Nemo quantization example script. Please consult nemo.export.quantize.Quantizer class -and examples/nlp/language_modeling/conf/megatron_gpt_quantization.yaml config on available quantization methods, +and examples/nlp/language_modeling/conf/megatron_gpt_ptq.yaml config on available quantization methods, models supported as well as how to set up data and inference for calibration (with defaults recommended). Example usage: ``` -python examples/nlp/language_modeling/megatron_gpt_quantization.py \ +python examples/nlp/language_modeling/megatron_gpt_ptq.py \ model.restore_from_path=llama2-7b-fp16.nemo \ quantization.algorithm=fp8 \ export.decoder_type=llama \ @@ -65,7 +65,7 @@ def get_calib_data_iter(data="cnn_dailymail", batch_size=64, calib_size=512, max yield batch -@hydra_runner(config_path="conf", config_name="megatron_gpt_quantization") +@hydra_runner(config_path="conf", config_name="megatron_gpt_ptq") def main(cfg) -> None: if not torch.cuda.is_available(): raise EnvironmentError("GPU is required for the quantization.") diff --git a/examples/nlp/language_modeling/tuning/conf/megatron_gpt_qat_config.yaml b/examples/nlp/language_modeling/tuning/conf/megatron_gpt_qat_config.yaml new file mode 100644 index 0000000000000..09e00f8be1108 --- /dev/null +++ b/examples/nlp/language_modeling/tuning/conf/megatron_gpt_qat_config.yaml @@ -0,0 +1,206 @@ +name: llama2-7b + +trainer: + devices: 1 + accelerator: gpu + num_nodes: 1 + precision: bf16 + logger: False # logger provided by exp_manager + enable_checkpointing: False + use_distributed_sampler: False + max_epochs: 9999 + max_steps: 100 # consumed_samples = global_step * micro_batch_size * data_parallel_size * accumulate_grad_batches + log_every_n_steps: 10 # frequency with which training steps are logged + val_check_interval: 0.25 # If is an int n > 1, will run val every n training steps, if a float 0.0 - 1.0 will run val every epoch fraction, e.g. 0.25 will run val every quarter epoch + gradient_clip_val: 1.0 + +exp_manager: + explicit_log_dir: ${name}-${trainer.precision}-sft-${quantization.algorithm} # Path to the directory where logs and checkpoints will be saved + exp_dir: null + name: ${name} + create_wandb_logger: False + wandb_logger_kwargs: + project: null + name: null + resume_if_exists: True + resume_ignore_no_checkpoint: True + create_checkpoint_callback: True + checkpoint_callback_params: + monitor: validation_${model.data.validation_ds.metric.name} + save_top_k: 1 + mode: min + save_nemo_on_train_end: True + filename: "${name}--{${exp_manager.checkpoint_callback_params.monitor}:.3f}-{step}-{consumed_samples}" + model_parallel_size: ${model.tensor_model_parallel_size} + always_save_nemo: False + save_best_model: False + create_early_stopping_callback: True + early_stopping_callback_params: + monitor: "val_loss" + mode: "min" + min_delta: 0.001 + patience: 10 + verbose: True + strict: False # Should be False to avoid a runtime error where EarlyStopping says monitor is unavailable, which sometimes happens with resumed training. + +model: + seed: 1234 + tensor_model_parallel_size: 1 # intra-layer model parallelism + pipeline_model_parallel_size: 1 # inter-layer model parallelism + + global_batch_size: 128 + micro_batch_size: 1 + restore_from_path: ??? # Path to an existing .nemo model you wish to quantize + resume_from_checkpoint: null # The path to a checkpoint file to continue the training, restores the whole state including the epoch, step, LR schedulers, apex, etc. + save_nemo_on_validation_end: False # Saves an inference ready .nemo file every time a checkpoint is saved during training. + sync_batch_comm: False + megatron_amp_O2: True + + ## Sequence Parallelism + # Makes tensor parallelism more memory efficient for LLMs (20B+) by parallelizing layer norms and dropout sequentially + # See Reducing Activation Recomputation in Large Transformer Models: https://arxiv.org/abs/2205.05198 for more details. + sequence_parallel: False + + ## Activation Checkpoint + activations_checkpoint_granularity: selective # 'selective' or 'full' + activations_checkpoint_method: uniform # 'uniform', 'block', not used with 'selective' + # 'uniform' divides the total number of transformer layers and checkpoints the input activation + # of each chunk at the specified granularity + # 'block' checkpoints the specified number of layers per pipeline stage at the specified granularity + activations_checkpoint_num_layers: null # not used with 'selective' + activations_checkpoint_layers_per_pipeline: null + answer_only_loss: True + gradient_as_bucket_view: False + + hidden_dropout: 0.0 + attention_dropout: 0.0 + ffn_dropout: 0.0 + + # FSDP + fsdp: False # Enable training with torch FSDP. + fsdp_sharding_strategy: "full" # Method to shard model states. Available options are 'full', 'hybrid', and 'grad'. + fsdp_grad_reduce_dtype: "fp32" # Gradient reduction data type. + fsdp_sharded_checkpoint: False # Store and load FSDP shared checkpoint. + fsdp_use_orig_params: False # Set to True to use FSDP for specific peft scheme. + + peft: + peft_scheme: "none" # Should be none for QAT as we are doing SFT on all parameters + + data: + train_ds: + # Example of how to specify paths to multiple datasets + # file_names: + # - /path/to/squad.jsonl + # - /path/to/mnli.jsonl + # - /path/to/boolq.jsonl + # Example of how each dataset is formatted + # {'input': 'John von Neumann\nVon Neumann made fundamental contributions .... Q: What did the math of artificial viscosity do?', 'output': 'smoothed the shock transition without sacrificing basic physics'} + file_names: ??? # Path to a list of JSONL files corresponding to the source data. + global_batch_size: ${model.global_batch_size} + micro_batch_size: ${model.micro_batch_size} + shuffle: True + num_workers: 0 + memmap_workers: 2 + pin_memory: True + max_seq_length: 2048 + min_seq_length: 1 + drop_last: True + # Example of how to specify concat_sampling_probabilities + # concat_sampling_probabilities: + # - 0.5 + # - 0.25 + # - 0.25 + concat_sampling_probabilities: null # When providing a list of datasets, this arg defines the sampling probabilities from each dataset when strategy='random' + label_key: "output" + add_eos: True + add_sep: False + add_bos: False + truncation_field: "input" # # Can be multiple keys separated with ',' Options: keys in prompt_template + index_mapping_dir: null # Path to a directory to write index mapping files. + prompt_template: "{input} {output}" # fstring to use for assistant prompt. Example: "Q: {input}\nA: {output}" + truncation_method: "right" # Truncation from which position, Options: ['left', 'right'] + validation_ds: + file_names: ??? # Path to a list of JSONL files corresponding to the source data. Data format is identical to train_ds. + names: null # Names of the corresponding datasets used to log metrics. + global_batch_size: ${model.global_batch_size} + micro_batch_size: ${model.micro_batch_size} + shuffle: False + num_workers: 0 + memmap_workers: ${model.data.train_ds.memmap_workers} + pin_memory: True + max_seq_length: 2048 + min_seq_length: 1 + drop_last: False + label_key: ${model.data.train_ds.label_key} + add_eos: ${model.data.train_ds.add_eos} + add_sep: ${model.data.train_ds.add_sep} + add_bos: ${model.data.train_ds.add_bos} + write_predictions_to_file: False + output_file_path_prefix: null # Prefix of the file to write predictions to. + truncation_field: ${model.data.train_ds.truncation_field} # Options: keys in prompt_template + index_mapping_dir: null # Path to a directory to write index mapping files. + prompt_template: ${model.data.train_ds.prompt_template} # fstring to use for assistant prompt. Example: "Q: {input}\nA: {output}" + tokens_to_generate: 32 # decide how many tokens we want to generate to evaluate performance with string metrics + truncation_method: "right" # Truncation from which position, Options: ['left', 'right'] + metric: + name: "loss" # Name of the evaluation metric to use. Options: ['exact_string_match', 'loss'] + average: null # Average the metric over the dataset. Options: ['macro', 'micro']. Works only for 'F1', 'accuracy' etc. Refer to torchmetrics for metrics where this is supported. + num_classes: null + test_ds: + file_names: null # Path to a list of JSONL files corresponding to the source data. Data format is identical to train_ds. + names: null # Names of the corresponding datasets used to log metrics. + global_batch_size: ${model.global_batch_size} + micro_batch_size: ${model.micro_batch_size} + shuffle: False + num_workers: 0 + memmap_workers: ${model.data.train_ds.memmap_workers} + pin_memory: True + max_seq_length: 2048 + min_seq_length: 1 + drop_last: False + label_key: ${model.data.train_ds.label_key} + add_eos: ${model.data.train_ds.add_eos} + add_sep: ${model.data.train_ds.add_sep} + add_bos: ${model.data.train_ds.add_bos} + write_predictions_to_file: False + output_file_path_prefix: null # Prefix of the file to write predictions to. + truncation_field: ${model.data.train_ds.truncation_field} # Options: keys in prompt_template + index_mapping_dir: null # Path to a directory to write index mapping files. + prompt_template: ${model.data.train_ds.prompt_template} + tokens_to_generate: 32 # decide how many tokens we want to generate to evaluate performance with string metrics + truncation_method: "right" # Truncation from which position, Options: ['left', 'right'] + metric: + name: "loss" # Name of the evaluation metric to use. Options: ['exact_string_match', 'loss'] + average: null # Average the metric over the dataset. Options: ['macro', 'micro']. Works only for 'F1', 'accuracy' etc. Refer to torchmetrics for metrics where this is supported. + num_classes: null + + optim: + name: distributed_fused_adam + lr: 5e-6 + weight_decay: 0.01 + betas: + - 0.9 + - 0.98 + sched: + name: CosineAnnealing + warmup_steps: 50 + min_lr: 0.0 # min_lr must be 0.0 for prompt learning when pipeline parallel > 1 + constant_steps: 0 # Constant steps should also be 0 when min_lr=0 + monitor: val_loss + reduce_on_plateau: false + +quantization: + decoder_type: ${export.decoder_type} # gptnext, gpt2, llama + algorithm: int4 # null, int8_sq, fp8, int4_awq, int4 + num_calib_size: 512 # number of samples used for calibration + awq_block_size: 128 # block size for scaling factors (only used in AWQ algorithms) + sq_alpha: 1.0 # alpha parameter (only used in SmoothQuant algorithms) + enable_kv_cache: false # Enable FP8 KV cache quantization. Set to null for automatic selection. + +export: + decoder_type: llama # gptnext, gpt2, llama + inference_tensor_parallel: 1 # Default using 1 TP for inference + inference_pipeline_parallel: 1 # Default using 1 PP for inference + dtype: ${trainer.precision} # Default precision data type + save_path: ${exp_manager.explicit_log_dir}/${name}-sft-${quantization.algorithm}.qnemo # Path where the quantized model will be saved + compress: false # Wheter save_path should be a tarball or a directory \ No newline at end of file diff --git a/examples/nlp/language_modeling/tuning/conf/megatron_mamba_finetuning_config.yaml b/examples/nlp/language_modeling/tuning/conf/megatron_mamba_finetuning_config.yaml new file mode 100644 index 0000000000000..33498540a3d53 --- /dev/null +++ b/examples/nlp/language_modeling/tuning/conf/megatron_mamba_finetuning_config.yaml @@ -0,0 +1,234 @@ +name: megatron_mamba +restore_from_path: ${model.restore_from_path} # used when starting from a .nemo file + +trainer: + devices: 1 + accelerator: gpu + num_nodes: 1 + precision: bf16 + logger: False # logger provided by exp_manager + enable_checkpointing: False + use_distributed_sampler: False + max_epochs: 9999 + max_steps: 10000 # consumed_samples = global_step * micro_batch_size * data_parallel_size * accumulate_grad_batches + log_every_n_steps: 1 # frequency with which training steps are logged + val_check_interval: 200 # If is an int n > 1, will run val every n training steps, if a float 0.0 - 1.0 will run val every epoch fraction, e.g. 0.25 will run val every quarter epoch + gradient_clip_val: 1.0 + limit_val_batches: 1024 + limit_test_batches: 500 + +exp_manager: + explicit_log_dir: null + exp_dir: null + name: ${name} + create_wandb_logger: True + wandb_logger_kwargs: + project: griffin + name: sft-test + resume_if_exists: False + resume_ignore_no_checkpoint: True + create_checkpoint_callback: True + checkpoint_callback_params: + monitor: validation_${model.data.validation_ds.metric.name} + save_top_k: 1 + mode: min + save_nemo_on_train_end: True + filename: '${name}--{${exp_manager.checkpoint_callback_params.monitor}:.3f}-{step}-{consumed_samples}' + model_parallel_size: ${model.tensor_model_parallel_size} + always_save_nemo: False + save_best_model: True + create_early_stopping_callback: True + early_stopping_callback_params: + monitor: "val_loss" + mode: "min" + min_delta: 0.001 + patience: 10 + verbose: True + strict: False # Should be False to avoid a runtime error where EarlyStopping says monitor is unavailable, which sometimes happens with resumed training. + + +model: + seed: 1234 + tensor_model_parallel_size: 1 # intra-layer model parallelism + pipeline_model_parallel_size: 1 # inter-layer model parallelism + + encoder_seq_length: 1024 + global_batch_size: 8 + micro_batch_size: 1 + restore_from_path: ??? # Path to an existing .nemo model you wish to add new tasks to or run inference with + resume_from_checkpoint: null # The path to a checkpoint file to continue the training, restores the whole state including the epoch, step, LR schedulers, apex, etc. + save_nemo_on_validation_end: False # Saves an inference ready .nemo file every time a checkpoint is saved during training. + sync_batch_comm: False + megatron_amp_O2: False + + ## Sequence Parallelism + # Makes tensor parallelism more memory efficient for LLMs (20B+) by parallelizing layer norms and dropout sequentially + # See Reducing Activation Recomputation in Large Transformer Models: https://arxiv.org/abs/2205.05198 for more details. + sequence_parallel: False + + ## Activation Checkpoint + activations_checkpoint_granularity: null # 'selective' or 'full' + activations_checkpoint_method: null # 'uniform', 'block', not used with 'selective' + # 'uniform' divides the total number of transformer layers and checkpoints the input activation + # of each chunk at the specified granularity + # 'block' checkpoints the specified number of layers per pipeline stage at the specified granularity + activations_checkpoint_num_layers: null # not used with 'selective' + activations_checkpoint_layers_per_pipeline: null + answer_only_loss: True + gradient_as_bucket_view: False + + hidden_dropout: 0.0 + attention_dropout: 0.0 + ffn_dropout: 0.0 + + peft: + peft_scheme: "lora" # can be either adapter,ia3, lora, or ptuning + restore_from_path: null + + # Used for adapter peft training + adapter_tuning: + type: 'parallel_adapter' # this should be either 'parallel_adapter' or 'linear_adapter' + adapter_dim: 32 + adapter_dropout: 0.0 + norm_position: 'pre' # This can be set to 'pre', 'post' or null, 'pre' is normally what is used. + column_init_method: 'xavier' # IGNORED if linear_adapter is used, options: xavier, zero or normal + row_init_method: 'zero' # IGNORED if linear_adapter is used, options: xavier, zero or normal + norm_type: 'mixedfusedlayernorm' # IGNORED if layer_adapter is used, options are ['layernorm', 'mixedfusedlayernorm'] + layer_selection: null # selects in which layers to add adapters, e.g. [1,12] will add adapters to layer 1 (lowest) and 12. null will apply adapters to all layers + weight_tying: False + position_embedding_strategy: null # used only when weight_tying is True + + lora_tuning: + target_modules: ['all'] # this can either be 'attention_qkv','attention_dense','mlp_fc1','mlp_fc2', attention (qkv & dense), mlp (fc1 & fc2) + adapter_dim: 32 + alpha: 32 + adapter_dropout: 0.0 + column_init_method: 'xavier' # IGNORED if linear_adapter is used, options: xavier, zero or normal + row_init_method: 'zero' # IGNORED if linear_adapter is used, options: xavier, zero or normal + layer_selection: null # selects in which layers to add lora adapters. e.g. [1,12] will add lora to layer 1 (lowest) and 12. null will apply adapters to all layers + weight_tying: False + position_embedding_strategy: null # used only when weight_tying is True + + # Used for p-tuning peft training + p_tuning: + virtual_tokens: 10 # The number of virtual tokens the prompt encoder should add at the start of the sequence + bottleneck_dim: 1024 # the size of the prompt encoder mlp bottleneck + embedding_dim: 1024 # the size of the prompt encoder embeddings + init_std: 0.023 + + ia3_tuning: + layer_selection: null # selects in which layers to add ia3 adapters. e.g. [1,12] will add lora to layer 1 (lowest) and 12. null will apply adapters to all layers + + selective_tuning: + tunable_base_param_names: ["self_attention", "word_embeddings"] # TODO: regex support @adithyre + + + data: + train_ds: + # Example of how to specify paths to multiple datasets + # file_names: + # - /path/to/squad.jsonl + # - /path/to/mnli.jsonl + # - /path/to/boolq.jsonl + # Example of how each dataset is formatted + # {'input': 'John von Neumann\nVon Neumann made fundamental contributions .... Q: What did the math of artificial viscosity do?', 'output': 'smoothed the shock transition without sacrificing basic physics'} + file_names: null # Path to a list of JSONL files corresponding to the source data. + global_batch_size: ${model.global_batch_size} + micro_batch_size: ${model.micro_batch_size} + shuffle: True + num_workers: 0 + memmap_workers: 2 + pin_memory: True + max_seq_length: 2048 + min_seq_length: 1 + drop_last: True + # Example of how to specify concat_sampling_probabilities + # concat_sampling_probabilities: + # - 0.5 + # - 0.25 + # - 0.25 + concat_sampling_probabilities: [1.0] # When providing a list of datasets, this arg defines the sampling probabilities from each dataset when strategy='random' + label_key: 'output' + add_eos: True + add_sep: False + add_bos: True + truncation_field: "input" # # Can be multiple keys separated with ',' Options: keys in prompt_template + index_mapping_dir: null # Path to a directory to write index mapping files. + prompt_template: "{input} {output}" # fstring to use for assistant prompt. Example: "Q: {input}\nA: {output}" + truncation_method: 'right' # Truncation from which position, Options: ['left', 'right'] + ceil_to_power_2: True + get_attention_mask_from_fusion: True + pad_to_max_length: True + validation_ds: + file_names: null # Path to a list of JSONL files corresponding to the source data. Data format is identical to train_ds. + names: null # Names of the corresponding datasets used to log metrics. + global_batch_size: ${model.global_batch_size} + micro_batch_size: ${model.micro_batch_size} + shuffle: False + num_workers: 0 + memmap_workers: ${model.data.train_ds.memmap_workers} + pin_memory: True + max_seq_length: 2048 + min_seq_length: 1 + drop_last: False + label_key: ${model.data.train_ds.label_key} + add_eos: ${model.data.train_ds.add_eos} + add_sep: ${model.data.train_ds.add_sep} + add_bos: ${model.data.train_ds.add_bos} + write_predictions_to_file: False + output_file_path_prefix: null # Prefix of the file to write predictions to. + truncation_field: ${model.data.train_ds.truncation_field} # Options: keys in prompt_template + index_mapping_dir: null # Path to a directory to write index mapping files. + prompt_template: ${model.data.train_ds.prompt_template} # fstring to use for assistant prompt. Example: "Q: {input}\nA: {output}" + tokens_to_generate: 32 # decide how many tokens we want to generate to evaluate performance with string metrics + truncation_method: 'right' # Truncation from which position, Options: ['left', 'right'] + ceil_to_power_2: True + get_attention_mask_from_fusion: True + pad_to_max_length: True + metric: + name: "loss" # Name of the evaluation metric to use. Options: ['exact_string_match', 'loss'] + average: null # Average the metric over the dataset. Options: ['macro', 'micro']. Works only for 'F1', 'accuracy' etc. Refer to torchmetrics for metrics where this is supported. + num_classes: null + test_ds: + file_names: null # Path to a list of JSONL files corresponding to the source data. Data format is identical to train_ds. + names: null # Names of the corresponding datasets used to log metrics. + global_batch_size: ${model.global_batch_size} + micro_batch_size: ${model.micro_batch_size} + shuffle: False + num_workers: 0 + memmap_workers: ${model.data.train_ds.memmap_workers} + pin_memory: True + max_seq_length: 2048 + min_seq_length: 1 + drop_last: False + label_key: ${model.data.train_ds.label_key} + add_eos: ${model.data.train_ds.add_eos} + add_sep: ${model.data.train_ds.add_sep} + add_bos: ${model.data.train_ds.add_bos} + write_predictions_to_file: False + output_file_path_prefix: null # Prefix of the file to write predictions to. + truncation_field: ${model.data.train_ds.truncation_field} # Options: keys in prompt_template + index_mapping_dir: null # Path to a directory to write index mapping files. + prompt_template: ${model.data.train_ds.prompt_template} + tokens_to_generate: 32 # decide how many tokens we want to generate to evaluate performance with string metrics + truncation_method: 'right' # Truncation from which position, Options: ['left', 'right'] + ceil_to_power_2: True + get_attention_mask_from_fusion: True + pad_to_max_length: True + metric: + name: "loss" # Name of the evaluation metric to use. Options: ['exact_string_match', 'loss'] + average: null # Average the metric over the dataset. Options: ['macro', 'micro']. Works only for 'F1', 'accuracy' etc. Refer to torchmetrics for metrics where this is supported. + num_classes: null + + optim: + name: distributed_fused_adam + lr: 2e-4 + weight_decay: 0.01 + betas: + - 0.9 + - 0.98 + sched: + name: CosineAnnealing + warmup_steps: 500 + constant_steps: 50000 + min_lr: 2e-5 diff --git a/examples/nlp/language_modeling/tuning/conf/megatron_mamba_generate_config.yaml b/examples/nlp/language_modeling/tuning/conf/megatron_mamba_generate_config.yaml new file mode 100644 index 0000000000000..fddfa16c8c09f --- /dev/null +++ b/examples/nlp/language_modeling/tuning/conf/megatron_mamba_generate_config.yaml @@ -0,0 +1,224 @@ +name: megatron_mamba +restore_from_path: ${model.restore_from_path} # used when starting from a .nemo file + +trainer: + devices: 1 + num_nodes: 1 + accelerator: gpu + precision: bf16 + logger: False # logger provided by exp_manager + enable_checkpointing: False + use_distributed_sampler: False + max_epochs: -1 # PTL default. In practice we don't usually train for more than 1 epoch. + max_steps: 100000 # consumed_samples = global_step * micro_batch_size * data_parallel_size * accumulate_grad_batches + log_every_n_steps: 10 + val_check_interval: 100 + limit_val_batches: 50 + limit_test_batches: 500 + accumulate_grad_batches: 1 + gradient_clip_val: 1.0 + benchmark: False + +exp_manager: + explicit_log_dir: null + exp_dir: null + name: megatron_mamba + create_wandb_logger: False + wandb_logger_kwargs: + project: null + name: null + resume_if_exists: True + resume_ignore_no_checkpoint: True + create_checkpoint_callback: True + checkpoint_callback_params: + monitor: val_loss + save_top_k: 10 + mode: min + always_save_nemo: False # saves nemo file during validation, not implemented for model parallel + filename: 'megatron_mamba--{val_loss:.2f}-{step}-{consumed_samples}' + model_parallel_size: ${multiply:${model.tensor_model_parallel_size}, ${model.pipeline_model_parallel_size}} + +model: + seed: 1234 + tensor_model_parallel_size: 1 # intra-layer model parallelism + pipeline_model_parallel_size: 1 # inter-layer model parallelism + + encoder_seq_length: 1024 + global_batch_size: 8 + micro_batch_size: 1 + restore_from_path: ??? # Path to an existing .nemo model you wish to add new tasks to or run inference with + resume_from_checkpoint: null # The path to a checkpoint file to continue the training, restores the whole state including the epoch, step, LR schedulers, apex, etc. + save_nemo_on_validation_end: False # Saves an inference ready .nemo file every time a checkpoint is saved during training. + sync_batch_comm: False + megatron_amp_O2: False + + ## Sequence Parallelism + # Makes tensor parallelism more memory efficient for LLMs (20B+) by parallelizing layer norms and dropout sequentially + # See Reducing Activation Recomputation in Large Transformer Models: https://arxiv.org/abs/2205.05198 for more details. + sequence_parallel: False + + ## Activation Checkpoint + activations_checkpoint_granularity: null # 'selective' or 'full' + activations_checkpoint_method: null # 'uniform', 'block', not used with 'selective' + # 'uniform' divides the total number of transformer layers and checkpoints the input activation + # of each chunk at the specified granularity + # 'block' checkpoints the specified number of layers per pipeline stage at the specified granularity + activations_checkpoint_num_layers: null # not used with 'selective' + activations_checkpoint_layers_per_pipeline: null + answer_only_loss: True + gradient_as_bucket_view: False + + hidden_dropout: 0.0 + attention_dropout: 0.0 + ffn_dropout: 0.0 + + + peft: + peft_scheme: null # can be either adapter,ia3, lora, or ptuning + restore_from_path: null + + # Used for adapter peft training + adapter_tuning: + type: 'parallel_adapter' # this should be either 'parallel_adapter' or 'linear_adapter' + adapter_dim: 32 + adapter_dropout: 0.0 + norm_position: 'pre' # This can be set to 'pre', 'post' or null, 'pre' is normally what is used. + column_init_method: 'xavier' # IGNORED if linear_adapter is used, options: xavier, zero or normal + row_init_method: 'zero' # IGNORED if linear_adapter is used, options: xavier, zero or normal + norm_type: 'mixedfusedlayernorm' # IGNORED if layer_adapter is used, options are ['layernorm', 'mixedfusedlayernorm'] + layer_selection: null # selects in which layers to add adapters, e.g. [1,12] will add adapters to layer 1 (lowest) and 12. null will apply adapters to all layers + weight_tying: False + position_embedding_strategy: null # used only when weight_tying is True + + lora_tuning: + target_modules: ['all'] # this can either be 'attention_qkv','attention_dense','mlp_fc1','mlp_fc2', attention (qkv & dense), mlp (fc1 & fc2) + adapter_dim: 32 + alpha: 32 + adapter_dropout: 0.0 + column_init_method: 'xavier' # IGNORED if linear_adapter is used, options: xavier, zero or normal + row_init_method: 'zero' # IGNORED if linear_adapter is used, options: xavier, zero or normal + layer_selection: null # selects in which layers to add lora adapters. e.g. [1,12] will add lora to layer 1 (lowest) and 12. null will apply adapters to all layers + weight_tying: False + position_embedding_strategy: null # used only when weight_tying is True + + # Used for p-tuning peft training + p_tuning: + virtual_tokens: 10 # The number of virtual tokens the prompt encoder should add at the start of the sequence + bottleneck_dim: 1024 # the size of the prompt encoder mlp bottleneck + embedding_dim: 1024 # the size of the prompt encoder embeddings + init_std: 0.023 + + ia3_tuning: + layer_selection: null # selects in which layers to add ia3 adapters. e.g. [1,12] will add lora to layer 1 (lowest) and 12. null will apply adapters to all layers + + selective_tuning: + tunable_base_param_names: ["self_attention", "word_embeddings"] # TODO: regex support @adithyre + + data: + test_ds: + file_names: ??? # Path to a list of JSONL files corresponding to the source data. Data format is identical to train_ds. + names: ??? # Names of the corresponding datasets used to log metrics. + global_batch_size: 1 + micro_batch_size: 1 + shuffle: False + num_workers: 0 + pin_memory: True + max_seq_length: 2048 + min_seq_length: 1 + drop_last: False + context_key: 'input' + label_key: 'output' + add_eos: True + add_sep: False + add_bos: True + write_predictions_to_file: False + output_file_path_prefix: null # Prefix of the file to write predictions to. + truncation_field: "input" # Options: keys in prompt_template + index_mapping_dir: null # Path to a directory to write index mapping files. + prompt_template: "{input} {output}" + tokens_to_generate: 32 # decide how many tokens we want to generate to evaluate performance with string metrics + truncation_method: 'right' # Truncation from which position, Options: ['left', 'right'] + ceil_to_power_2: True + get_attention_mask_from_fusion: True + pad_to_max_length: True + + metric: + name: "loss" # Name of the evaluation metric to use. Options: ['exact_string_match', 'loss'] + average: null # Average the metric over the dataset. Options: ['macro', 'micro']. Works only for 'F1', 'accuracy' etc. Refer to torchmetrics for metrics where this is supported. + num_classes: null + +inference: + greedy: True # Whether or not to use sampling ; use greedy decoding otherwise + top_k: 0 # The number of highest probability vocabulary tokens to keep for top-k-filtering. + top_p: 0.9 # If set to float < 1, only the most probable tokens with probabilities that add up to top_p or higher are kept for generation. + temperature: 1.0 # sampling temperature + all_probs: False # whether return the log prob for all the tokens in vocab + repetition_penalty: 1.2 # The parameter for repetition penalty. 1.0 means no penalty. + min_tokens_to_generate: 0 # The minimum length of the sequence to be generated. + compute_logprob: False # a flag used to compute logprob of all the input text, a very special case of running inference, default False + outfile_path: output.txt + compute_attention_mask: True + +# server-related configs +server: False # whether launch the API server +port: 5555 # the port number for the inference server +web_server: False # whether launch the web inference server +share: True # whether create a public URL +username: test # user name for web client +password: test2 # password for web client +web_port: 9889 # the port number of the web server 1058 +chat: False # use the chat interface +chatbot_config: + value: False # whether to inject the value attributes + attributes: + - name: Quality + min: 0 + max: 4 + key: quality + type: int + default: 4 + - name: Toxicity + min: 0 + max: 4 + key: toxcity + type: int + default: 0 + - name: Humor + min: 0 + max: 4 + key: humor + type: int + default: 0 + - name: Creativity + min: 0 + max: 4 + key: creativity + type: int + default: 0 + - name: Violence + min: 0 + max: 4 + key: violence + type: int + default: 0 + - name: Helpfulness + min: 0 + max: 4 + key: helpfulness + type: int + default: 4 + - name: Not_Appropriate + min: 0 + max: 4 + key: not_appropriate + type: int + default: 0 + - name: Language + choices: ['ar', 'bg', 'bn', 'ca', 'cs', 'da', 'de', 'el', 'en', 'eo', 'es', 'eu', 'fa', 'fi', 'fr', 'gl', 'he', 'hu', 'id', 'it', 'ja', 'ko', 'nb', 'nl', 'pl', 'pt', 'ro', 'ru', 'sk', 'sv', 'th', 'tr', 'uk', 'vi', 'zh'] + key: lang + type: list + default: en + + user: User + assistant: Assistant + system: "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.\n\n" \ No newline at end of file diff --git a/examples/nlp/language_modeling/tuning/megatron_gpt_finetuning.py b/examples/nlp/language_modeling/tuning/megatron_gpt_finetuning.py index aaa087a46623b..bfe8ea35960ed 100644 --- a/examples/nlp/language_modeling/tuning/megatron_gpt_finetuning.py +++ b/examples/nlp/language_modeling/tuning/megatron_gpt_finetuning.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/examples/nlp/language_modeling/tuning/megatron_gpt_qat.py b/examples/nlp/language_modeling/tuning/megatron_gpt_qat.py new file mode 100644 index 0000000000000..23e1b358d06e1 --- /dev/null +++ b/examples/nlp/language_modeling/tuning/megatron_gpt_qat.py @@ -0,0 +1,93 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from itertools import islice + +import torch.multiprocessing as mp +from omegaconf.omegaconf import OmegaConf +from tqdm import tqdm + +from nemo.collections.nlp.models.language_modeling.megatron_gpt_sft_model import MegatronGPTSFTModel +from nemo.collections.nlp.parts.megatron_trainer_builder import MegatronLMPPTrainerBuilder +from nemo.core.config import hydra_runner +from nemo.export.quantize import Quantizer +from nemo.utils import logging +from nemo.utils.exp_manager import exp_manager + +mp.set_start_method("spawn", force=True) + +""" +This is a modified version of `megatron_gpt_finetuning.py` to perform PTQ and QAT on a SFT Model like Llama2-7b. +Please see docs/source/nlp/quantization.rst for more details on the usage. +""" + + +def get_forward_loop(fwd_bwd_step, dataloader, num_batches): + if len(dataloader) < num_batches: + logging.warning( + f"Dataloader has fewer batches ({len(dataloader)}) than required ({num_batches}) for calibration." + ) + num_batches = len(dataloader) + + def forward_loop(model): + data_iter = islice(iter(dataloader), num_batches) + for _ in tqdm(range(num_batches), desc="Calibrating"): + fwd_bwd_step(data_iter, forward_only=True) + + return forward_loop + + +@hydra_runner(config_path="conf", config_name="megatron_gpt_qat_config") +def main(cfg) -> None: + logging.info("\n\n************** Experiment configuration ***********") + logging.info(f'\n{OmegaConf.to_yaml(cfg)}') + + trainer = MegatronLMPPTrainerBuilder(cfg).create_trainer() + exp_manager(trainer, cfg.exp_manager) + + quantizer = Quantizer(cfg.quantization, cfg.export) + + model_cfg = MegatronGPTSFTModel.merge_cfg_with(cfg.model.restore_from_path, cfg) + model_cfg = quantizer.modify_model_config(model_cfg) + + model = MegatronGPTSFTModel.restore_from(cfg.model.restore_from_path, model_cfg, trainer=trainer) + assert model.mcore_gpt, "Only MCoreGPTModel is supported with nvidia-modelopt for QAT." + + # Setup dataloaders + model.setup() + + # Perform PTQ on the SFT Model + if cfg.quantization.algorithm is not None: + model_module_list = model.get_model_module_list() + assert len(model_module_list) == 1 + unwrapped_model = model_module_list[0] + + num_batches = cfg.quantization.num_calib_size // cfg.model.global_batch_size + forward_loop = get_forward_loop(model.fwd_bwd_step, model.train_dataloader(), num_batches) + quantizer.quantize(unwrapped_model, forward_loop) + + logging.info("Validating model after PTQ...") + trainer.validate(model) + + # Perform QAT on the PTQ Model + trainer.fit(model) + + # Export the quantized model for TensorRT-LLM inference + # INT4 export is not supported yet + if cfg.quantization.algorithm != "int4": + quantizer.export(model) + + +if __name__ == '__main__': + main() diff --git a/examples/nlp/language_modeling/tuning/megatron_mamba_finetuning.py b/examples/nlp/language_modeling/tuning/megatron_mamba_finetuning.py new file mode 100644 index 0000000000000..0613ef486ec3d --- /dev/null +++ b/examples/nlp/language_modeling/tuning/megatron_mamba_finetuning.py @@ -0,0 +1,60 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch.multiprocessing as mp +from omegaconf.omegaconf import OmegaConf + +from nemo.collections.nlp.models.language_modeling.megatron_mamba_sft_model import MegatronMambaSFTModel +from nemo.collections.nlp.parts.megatron_trainer_builder import MegatronLMPPTrainerBuilder +from nemo.collections.nlp.parts.peft_config import PEFT_CONFIG_MAP +from nemo.core.config import hydra_runner +from nemo.utils import logging +from nemo.utils.exp_manager import exp_manager + +mp.set_start_method("spawn", force=True) + + +@hydra_runner(config_path="conf", config_name="megatron_mamba_finetuning_config") +def main(cfg) -> None: + + logging.info("\n\n************** Experiment configuration ***********") + logging.info(f'\n{OmegaConf.to_yaml(cfg)}') + + precision = cfg.trainer.precision + trainer = MegatronLMPPTrainerBuilder(cfg).create_trainer() + # Restore the precision value after Trainer is built. + cfg.trainer.precision = precision + exp_manager(trainer, cfg.exp_manager) + + model_cfg = MegatronMambaSFTModel.merge_cfg_with(cfg.model.restore_from_path, cfg) + model = MegatronMambaSFTModel.restore_from(cfg.model.restore_from_path, model_cfg, trainer=trainer) + + peft_cfg_cls = PEFT_CONFIG_MAP[cfg.model.peft.peft_scheme] + + if cfg.model.peft.restore_from_path is not None: + # initialize peft weights from a check`point instead of randomly + # This is not the same as resume training because optimizer states are not restored. + logging.info("PEFT Weights will be loaded from", cfg.model.peft.restore_from_path) + model.load_adapters(cfg.model.peft.restore_from_path, peft_cfg_cls(model_cfg)) + elif peft_cfg_cls is not None: + logging.info("Adding adapter weights to the model for PEFT") + model.add_adapter(peft_cfg_cls(model_cfg)) + else: + logging.info(f"Running full finetuning since no peft scheme is given.\n{model.summarize()}") + + trainer.fit(model) + + +if __name__ == '__main__': + main() diff --git a/examples/nlp/language_modeling/tuning/megatron_mamba_generate.py b/examples/nlp/language_modeling/tuning/megatron_mamba_generate.py new file mode 100644 index 0000000000000..6f660d552fc6b --- /dev/null +++ b/examples/nlp/language_modeling/tuning/megatron_mamba_generate.py @@ -0,0 +1,69 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import os +import torch.multiprocessing as mp +from omegaconf.omegaconf import OmegaConf +from nemo.collections.nlp.models.language_modeling.megatron_mamba_sft_model import MegatronMambaSFTModel +from nemo.collections.nlp.parts.megatron_trainer_builder import MegatronLMPPTrainerBuilder +from nemo.collections.nlp.parts.peft_config import PEFT_CONFIG_MAP +from nemo.core.config import hydra_runner +from nemo.utils import logging +from nemo.utils.model_utils import inject_model_parallel_rank + + +mp.set_start_method("spawn", force=True) + + +@hydra_runner(config_path="conf", config_name="megatron_mamba_generate_config") +def main(cfg) -> None: + logging.info("\n\n************** Experiment configuration ***********") + logging.info(f"\n{OmegaConf.to_yaml(cfg)}") + trainer = MegatronLMPPTrainerBuilder(cfg).create_trainer() + + if cfg.model.peft.restore_from_path: + model_cfg = MegatronMambaSFTModel.merge_inference_cfg(cfg.model.peft.restore_from_path, cfg) + else: + model_cfg = MegatronMambaSFTModel.merge_inference_cfg(cfg.model.restore_from_path, cfg) + + model = MegatronMambaSFTModel.restore_from(cfg.model.restore_from_path, model_cfg, trainer=trainer) + + if cfg.model.peft.restore_from_path: + model.load_adapters(cfg.model.peft.restore_from_path) + elif cfg.model.peft.restore_from_ckpt.checkpoint_dir and cfg.model.peft.restore_from_ckpt.checkpoint_name: + peft_cfg_cls = PEFT_CONFIG_MAP[cfg.model.peft.peft_scheme] + checkpoint_path = os.path.join( + cfg.model.peft.restore_from_ckpt.checkpoint_dir, cfg.model.peft.restore_from_ckpt.checkpoint_name + ) + # checkpoint_path is a dir in case of distributed checkpointing + if not os.path.isdir(checkpoint_path): + # legacy checkpoint needs model parallel rank injection + checkpoint_path = inject_model_parallel_rank( + os.path.join( + cfg.model.peft.restore_from_ckpt.checkpoint_dir, cfg.model.peft.restore_from_ckpt.checkpoint_name + ) + ) + model.load_adapters(checkpoint_path, peft_cfgs=peft_cfg_cls(model_cfg)) + else: + raise NotImplementedError("distributed checkpointing of PEFT weights is not supported") + + model.freeze() + logging.info(f"Freezing parameters for PEFT eval:\n{model.summarize()}") + + trainer.test(model) + + +if __name__ == "__main__": + main() diff --git a/nemo/README.md b/nemo/README.md index 91b734b643611..869ce2f50031a 100644 --- a/nemo/README.md +++ b/nemo/README.md @@ -9,3 +9,4 @@ NeMo (**Ne**ural **Mo**dules) is a toolkit for creating AI applications built ar * NLP - collection of modules and models for building NLP networks * Vision - collection of modules and models for building computer vision networks * Multimodal - collection of modules and models for building multimodal networks +* Audio - collection of modules and models for building audio processing networks diff --git a/nemo/collections/asr/data/audio_to_label.py b/nemo/collections/asr/data/audio_to_label.py index 4ff27f91ed0f9..decd6beaa961a 100644 --- a/nemo/collections/asr/data/audio_to_label.py +++ b/nemo/collections/asr/data/audio_to_label.py @@ -118,12 +118,12 @@ def _speech_collate_fn(batch, pad_id): def _fixed_seq_collate_fn(self, batch): """collate batch of audio sig, audio len, tokens, tokens len - Args: - batch (Optional[FloatTensor], Optional[LongTensor], LongTensor, - LongTensor): A tuple of tuples of signal, signal lengths, - encoded tokens, and encoded tokens length. This collate func - assumes the signals are 1d torch tensors (i.e. mono audio). - """ + Args: + batch (Optional[FloatTensor], Optional[LongTensor], LongTensor, + LongTensor): A tuple of tuples of signal, signal lengths, + encoded tokens, and encoded tokens length. This collate func + assumes the signals are 1d torch tensors (i.e. mono audio). + """ _, audio_lengths, _, tokens_lengths = zip(*batch) has_audio = audio_lengths[0] is not None @@ -232,19 +232,23 @@ class _AudioLabelDataset(Dataset): Defaults to None. trim (bool): Whether to use trim silence from beginning and end of audio signal using librosa.effects.trim(). Defaults to False. + channel selector (Union[str, int, List[int]]): string denoting the downmix mode, an integer denoting the channel to be selected, or an iterable + of integers denoting a subset of channels. Channel selector is using zero-based indexing. + If set to `None`, the original signal will be used. """ @property def output_types(self) -> Optional[Dict[str, NeuralType]]: - """Returns definitions of module output ports. - """ + """Returns definitions of module output ports.""" output_types = { 'audio_signal': NeuralType( ('B', 'T'), - AudioSignal(freq=self._sample_rate) - if self is not None and hasattr(self, '_sample_rate') - else AudioSignal(), + ( + AudioSignal(freq=self._sample_rate) + if self is not None and hasattr(self, '_sample_rate') + else AudioSignal() + ), ), 'a_sig_length': NeuralType(tuple('B'), LengthsType()), } @@ -259,7 +263,10 @@ def output_types(self) -> Optional[Dict[str, NeuralType]]: else: output_types.update( - {'label': NeuralType(tuple('B'), LabelsType()), 'label_length': NeuralType(tuple('B'), LengthsType()),} + { + 'label': NeuralType(tuple('B'), LabelsType()), + 'label_length': NeuralType(tuple('B'), LengthsType()), + } ) return output_types @@ -273,6 +280,7 @@ def __init__( min_duration: Optional[float] = 0.1, max_duration: Optional[float] = None, trim: bool = False, + channel_selector: Union[str, int, List[int]] = None, is_regression_task: bool = False, cal_labels_occurrence: Optional[bool] = False, ): @@ -290,6 +298,7 @@ def __init__( self.featurizer = featurizer self.trim = trim + self.channel_selector = channel_selector self.is_regression_task = is_regression_task if not is_regression_task: @@ -325,7 +334,13 @@ def __getitem__(self, index): if offset is None: offset = 0 - features = self.featurizer.process(sample.audio_file, offset=offset, duration=sample.duration, trim=self.trim) + features = self.featurizer.process( + sample.audio_file, + offset=offset, + duration=sample.duration, + trim=self.trim, + channel_selector=self.channel_selector, + ) f, fl = features, torch.tensor(features.shape[0]).long() if not self.is_regression_task: @@ -392,6 +407,9 @@ class AudioToSpeechLabelDataset(_AudioLabelDataset): trim (bool): Whether to use trim silence from beginning and end of audio signal using librosa.effects.trim(). Defaults to False. + channel selector (Union[str, int, List[int]]): string denoting the downmix mode, an integer denoting the channel to be selected, or an iterable + of integers denoting a subset of channels. Channel selector is using zero-based indexing. + If set to `None`, the original signal will be used. window_length_in_sec (float): length of window/slice (in seconds) Use this for speaker recognition and VAD tasks. shift_length_in_sec (float): amount of shift of window for generating the frame for VAD task in a batch @@ -413,6 +431,7 @@ def __init__( min_duration: Optional[float] = 0.1, max_duration: Optional[float] = None, trim: bool = False, + channel_selector: Optional[Union[str, int, List[int]]] = None, window_length_in_sec: Optional[float] = 8, shift_length_in_sec: Optional[float] = 1, normalize_audio: bool = False, @@ -433,6 +452,7 @@ def __init__( min_duration=min_duration, max_duration=max_duration, trim=trim, + channel_selector=channel_selector, is_regression_task=is_regression_task, cal_labels_occurrence=cal_labels_occurrence, ) @@ -631,8 +651,7 @@ def _internal_generator(self): return TarredAudioFilter(self.collection, self.file_occurence) def _build_sample(self, tup): - """Builds the training sample by combining the data from the WebDataset with the manifest info. - """ + """Builds the training sample by combining the data from the WebDataset with the manifest info.""" audio_bytes, audio_filename = tup # Grab manifest entry from self.collection file_id, _ = os.path.splitext(os.path.basename(audio_filename)) @@ -647,7 +666,10 @@ def _build_sample(self, tup): # Convert audio bytes to IO stream for processing (for SoundFile to read) audio_filestream = io.BytesIO(audio_bytes) features = self.featurizer.process( - audio_filestream, offset=offset, duration=manifest_entry.duration, trim=self.trim, + audio_filestream, + offset=offset, + duration=manifest_entry.duration, + trim=self.trim, ) audio_filestream.close() @@ -879,9 +901,12 @@ class AudioToMultiLabelDataset(Dataset): All training files which have a duration more than max_duration are dropped. Note: Duration is read from the manifest JSON. Defaults to None. - trim (bool): Whether to use trim silence from beginning and end + trim_silence (bool): Whether to use trim silence from beginning and end of audio signal using librosa.effects.trim(). Defaults to False. + channel selector (Union[str, int, List[int]]): string denoting the downmix mode, an integer denoting the channel to be selected, or an iterable + of integers denoting a subset of channels. Channel selector is using zero-based indexing. + If set to `None`, the original signal will be used. window_length_in_sec (float): length of window/slice (in seconds) Use this for speaker recognition and VAD tasks. shift_length_in_sec (float): amount of shift of window for generating the frame for VAD task in a batch @@ -898,15 +923,16 @@ class AudioToMultiLabelDataset(Dataset): @property def output_types(self) -> Optional[Dict[str, NeuralType]]: - """Returns definitions of module output ports. - """ + """Returns definitions of module output ports.""" output_types = { 'audio_signal': NeuralType( ('B', 'T'), - AudioSignal(freq=self._sample_rate) - if self is not None and hasattr(self, '_sample_rate') - else AudioSignal(), + ( + AudioSignal(freq=self._sample_rate) + if self is not None and hasattr(self, '_sample_rate') + else AudioSignal() + ), ), 'a_sig_length': NeuralType(tuple('B'), LengthsType()), } @@ -920,7 +946,10 @@ def output_types(self) -> Optional[Dict[str, NeuralType]]: ) else: output_types.update( - {'label': NeuralType(('B', 'T'), LabelsType()), 'label_length': NeuralType(tuple('B'), LengthsType()),} + { + 'label': NeuralType(('B', 'T'), LabelsType()), + 'label_length': NeuralType(tuple('B'), LengthsType()), + } ) return output_types @@ -936,6 +965,7 @@ def __init__( min_duration: Optional[float] = 0.1, max_duration: Optional[float] = None, trim_silence: bool = False, + channel_selector: Optional[Union[str, int, List[int]]] = None, is_regression_task: bool = False, cal_labels_occurrence: Optional[bool] = False, delimiter: Optional[str] = None, @@ -959,6 +989,7 @@ def __init__( self.featurizer = WaveformFeaturizer(sample_rate=sample_rate, int_values=int_values, augmentor=augmentor) self.trim = trim_silence + self.channel_selector = channel_selector self.is_regression_task = is_regression_task self.id2occurrence = {} self.labels_occurrence = None @@ -1016,6 +1047,7 @@ def __getitem__(self, index): offset=offset, duration=sample.duration, trim=self.trim, + channel_selector=self.channel_selector, normalize_db=self.normalize_audio_db, ) @@ -1245,8 +1277,7 @@ def _internal_generator(self): return TarredAudioFilter(self.collection, self.file_occurence) def _build_sample(self, tup): - """Builds the training sample by combining the data from the WebDataset with the manifest info. - """ + """Builds the training sample by combining the data from the WebDataset with the manifest info.""" audio_bytes, audio_filename = tup # Grab manifest entry from self.collection file_id, _ = os.path.splitext(os.path.basename(audio_filename)) diff --git a/nemo/collections/asr/data/audio_to_text.py b/nemo/collections/asr/data/audio_to_text.py index e0bb63ad18cd4..28dc168481ed9 100644 --- a/nemo/collections/asr/data/audio_to_text.py +++ b/nemo/collections/asr/data/audio_to_text.py @@ -27,8 +27,8 @@ from tqdm import tqdm from nemo.collections.asr.parts.preprocessing.features import WaveformFeaturizer +from nemo.collections.asr.parts.preprocessing.segment import ChannelSelectorType from nemo.collections.asr.parts.preprocessing.segment import available_formats as valid_sf_formats -from nemo.collections.asr.parts.utils.audio_utils import ChannelSelectorType from nemo.collections.common import tokenizers from nemo.collections.common.parts.preprocessing import collections, parsers from nemo.core.classes import Dataset, IterableDataset diff --git a/nemo/collections/asr/data/data_simulation.py b/nemo/collections/asr/data/data_simulation.py index 5bbdcdfb56057..5ee2ad19b9514 100644 --- a/nemo/collections/asr/data/data_simulation.py +++ b/nemo/collections/asr/data/data_simulation.py @@ -13,29 +13,19 @@ # limitations under the License. import concurrent -import itertools -import multiprocessing import os -import random import warnings -from typing import Dict, Iterable, List, Optional, Tuple, Union +from typing import Dict, List, Tuple -import h5py -import librosa -import matplotlib.pyplot as plt import numpy as np import soundfile as sf import torch -from numpy.random import default_rng -from omegaconf import DictConfig, OmegaConf +from omegaconf import OmegaConf from scipy.signal import convolve from scipy.signal.windows import cosine, hamming, hann -from scipy.spatial.transform import Rotation from tqdm import tqdm from nemo.collections.asr.parts.preprocessing.perturb import process_augmentations -from nemo.collections.asr.parts.preprocessing.segment import AudioSegment -from nemo.collections.asr.parts.utils.audio_utils import db2mag, generate_approximate_noise_field, mag2db, pow2db, rms from nemo.collections.asr.parts.utils.data_simulation_utils import ( DataAnnotator, SpeechSampler, @@ -53,7 +43,7 @@ read_audio_from_buffer, read_noise_manifest, ) -from nemo.collections.asr.parts.utils.manifest_utils import read_manifest, write_manifest +from nemo.collections.asr.parts.utils.manifest_utils import read_manifest from nemo.collections.asr.parts.utils.speaker_utils import get_overlap_range, is_overlap, merge_float_intervals from nemo.utils import logging @@ -74,16 +64,16 @@ class MultiSpeakerSimulator(object): """ - Multispeaker Audio Session Simulator - Simulates multispeaker audio sessions using single-speaker audio files and + Multispeaker Audio Session Simulator - Simulates multispeaker audio sessions using single-speaker audio files and corresponding word alignments. Change Log: v1.0: Dec 2022 - First working verison, supports multispeaker simulation with overlaps, silence and RIR v1.0.1: Feb 2023 - - Multi-GPU support for speed up - - Faster random sampling routine - - Fixed sentence duration bug + - Multi-GPU support for speed up + - Faster random sampling routine + - Fixed sentence duration bug - Silence and overlap length sampling algorithms are updated to guarantee `mean_silence` approximation v1.0.2: March 2023 - Added support for segment-level gain perturbation and session-level white-noise perturbation @@ -108,65 +98,65 @@ class MultiSpeakerSimulator(object): session_config: num_speakers (int): Number of unique speakers per multispeaker audio session num_sessions (int): Number of sessions to simulate - session_length (int): Length of each simulated multispeaker audio session (seconds). Short sessions + session_length (int): Length of each simulated multispeaker audio session (seconds). Short sessions (e.g. ~240 seconds) tend to fall short of the expected overlap-ratio and silence-ratio. - + session_params: - max_audio_read_sec (int): The maximum audio length in second when loading an audio file. + max_audio_read_sec (int): The maximum audio length in second when loading an audio file. The bigger the number, the slower the reading speed. Should be greater than 2.5 second. - sentence_length_params (list): k,p values for a negative_binomial distribution which is sampled to get the + sentence_length_params (list): k,p values for a negative_binomial distribution which is sampled to get the sentence length (in number of words) - dominance_var (float): Variance in speaker dominance (where each speaker's dominance is sampled from a normal - distribution centered on 1/`num_speakers`, and then the dominance values are together + dominance_var (float): Variance in speaker dominance (where each speaker's dominance is sampled from a normal + distribution centered on 1/`num_speakers`, and then the dominance values are together normalized to 1) - min_dominance (float): Minimum percentage of speaking time per speaker (note that this can cause the dominance of + min_dominance (float): Minimum percentage of speaking time per speaker (note that this can cause the dominance of the other speakers to be slightly reduced) turn_prob (float): Probability of switching speakers after each utterance mean_silence (float): Mean proportion of silence to speaking time in the audio session. Should be in range [0, 1). - mean_silence_var (float): Variance for mean silence in all audio sessions. + mean_silence_var (float): Variance for mean silence in all audio sessions. This value should be 0 <= mean_silence_var < mean_silence * (1 - mean_silence). per_silence_var (float): Variance for each silence in an audio session, set large values (e.g., 20) for de-correlation. per_silence_min (float): Minimum duration for each silence, default to 0. per_silence_max (float): Maximum duration for each silence, default to -1 for no maximum. - mean_overlap (float): Mean proportion of overlap in the overall non-silence duration. Should be in range [0, 1) and + mean_overlap (float): Mean proportion of overlap in the overall non-silence duration. Should be in range [0, 1) and recommend [0, 0.15] range for accurate results. - mean_overlap_var (float): Variance for mean overlap in all audio sessions. + mean_overlap_var (float): Variance for mean overlap in all audio sessions. This value should be 0 <= mean_overlap_var < mean_overlap * (1 - mean_overlap). - per_overlap_var (float): Variance for per overlap in each session, set large values to de-correlate silence lengths + per_overlap_var (float): Variance for per overlap in each session, set large values to de-correlate silence lengths with the latest speech segment lengths per_overlap_min (float): Minimum per overlap duration in seconds per_overlap_max (float): Maximum per overlap duration in seconds, set -1 for no maximum - start_window (bool): Whether to window the start of sentences to smooth the audio signal (and remove silence at + start_window (bool): Whether to window the start of sentences to smooth the audio signal (and remove silence at the start of the clip) window_type (str): Type of windowing used when segmenting utterances ("hamming", "hann", "cosine") window_size (float): Length of window at the start or the end of segmented utterance (seconds) - start_buffer (float): Buffer of silence before the start of the sentence (to avoid cutting off speech or starting + start_buffer (float): Buffer of silence before the start of the sentence (to avoid cutting off speech or starting abruptly) - split_buffer (float): Split RTTM labels if greater than twice this amount of silence (to avoid long gaps between + split_buffer (float): Split RTTM labels if greater than twice this amount of silence (to avoid long gaps between utterances as being labelled as speech) release_buffer (float): Buffer before window at end of sentence (to avoid cutting off speech or ending abruptly) normalize (bool): Normalize speaker volumes - normalization_type (str): Normalizing speakers ("equal" - same volume per speaker, "var" - variable volume per + normalization_type (str): Normalizing speakers ("equal" - same volume per speaker, "var" - variable volume per speaker) normalization_var (str): Variance in speaker volume (sample from standard deviation centered at 1) min_volume (float): Minimum speaker volume (only used when variable normalization is used) max_volume (float): Maximum speaker volume (only used when variable normalization is used) end_buffer (float): Buffer at the end of the session to leave blank - + outputs: output_dir (str): Output directory for audio sessions and corresponding label files output_filename (str): Output filename for the wav and RTTM files overwrite_output (bool): If true, delete the output directory if it exists output_precision (int): Number of decimal places in output files - - background_noise: + + background_noise: add_bg (bool): Add ambient background noise if true background_manifest (str): Path to background noise manifest file snr (int): SNR for background noise (using average speaker power), set `snr_min` and `snr_max` values to enable random SNR snr_min (int): Min random SNR for background noise (using average speaker power), set `null` to use fixed SNR snr_max (int): Max random SNR for background noise (using average speaker power), set `null` to use fixed SNR - + segment_augmentor: add_seg_aug (bool): Set True to enable augmentation on each speech segment (Default: False) segmentor: @@ -185,12 +175,12 @@ class MultiSpeakerSimulator(object): speaker_enforcement: enforce_num_speakers (bool): Enforce that all requested speakers are present in the output wav file - enforce_time (list): Percentage of the way through the audio session that enforcement mode is triggered (sampled + enforce_time (list): Percentage of the way through the audio session that enforcement mode is triggered (sampled between time 1 and 2) - + segment_manifest: (parameters for regenerating the segment manifest file) window (float): Window length for segmentation - shift (float): Shift length for segmentation + shift (float): Shift length for segmentation step_count (int): Number of the unit segments you want to create per utterance deci (int): Rounding decimals for segment manifest file """ @@ -266,8 +256,8 @@ def _init_speaker_permutations(self, num_sess: int, num_speakers: int, all_speak """ Initialize the speaker permutations for the number of speakers in the session. When generating the simulated sessions, we want to include as many speakers as possible. - This function generates a set of permutations that can be used to sweep all speakers in - the source dataset to make sure we maximize the total number of speakers included in + This function generates a set of permutations that can be used to sweep all speakers in + the source dataset to make sure we maximize the total number of speakers included in the simulated sessions. Args: @@ -276,7 +266,7 @@ def _init_speaker_permutations(self, num_sess: int, num_speakers: int, all_speak all_speaker_ids (list): List of all speaker IDs Returns: - permuted_inds (np.array): + permuted_inds (np.array): Array of permuted speaker indices to use for each session Dimensions: (num_sess, num_speakers) """ @@ -308,8 +298,8 @@ def _init_speaker_permutations(self, num_sess: int, num_speakers: int, all_speak def _init_chunk_count(self): """ Initialize the chunk count for multi-processing to prevent over-flow of job counts. - The multi-processing pipeline can freeze if there are more than approximately 10,000 jobs - in the pipeline at the same time. + The multi-processing pipeline can freeze if there are more than approximately 10,000 jobs + in the pipeline at the same time. """ return int(np.ceil(self._params.data_simulator.session_config.num_sessions / self.multiprocessing_chunksize)) @@ -653,7 +643,7 @@ def _add_file( random_offset: bool = False, ) -> Tuple[int, torch.Tensor]: """ - Add audio file to current sentence (up to the desired number of words). + Add audio file to current sentence (up to the desired number of words). Uses the alignments to segment the audio file. NOTE: 0 index is always silence in `audio_manifest['words']`, so we choose `offset_idx=1` as the first word @@ -663,7 +653,7 @@ def _add_file( sentence_word_count (int): Running count for number of words in sentence max_word_count_in_sentence (int): Maximum count for number of words in sentence max_samples_in_sentence (int): Maximum length for sentence in terms of samples - + Returns: sentence_word_count+current_word_count (int): Running word count len(self._sentence) (tensor): Current length of the audio file @@ -739,7 +729,11 @@ def _add_file( 0, ) self._sentence = torch.cat( - (self._sentence, audio_file[start_cutoff + start_window_amount : start_cutoff + prev_dur_samples],), 0, + ( + self._sentence, + audio_file[start_cutoff + start_window_amount : start_cutoff + prev_dur_samples], + ), + 0, ).to(self._device) else: @@ -752,7 +746,9 @@ def _add_file( word_idx < len(audio_manifest['words']) ) and self._params.data_simulator.session_params.window_type is not None: release_buffer, end_window_amount = self._get_end_buffer_and_window( - prev_dur_samples, remaining_dur_samples, len(audio_file[start_cutoff + prev_dur_samples :]), + prev_dur_samples, + remaining_dur_samples, + len(audio_file[start_cutoff + prev_dur_samples :]), ) self._sentence = torch.cat( ( @@ -780,7 +776,7 @@ def _build_sentence( max_samples_in_sentence: int, ): """ - Build a new sentence by attaching utterance samples together until the sentence has reached a desired length. + Build a new sentence by attaching utterance samples together until the sentence has reached a desired length. While generating the sentence, alignment information is used to segment the audio. Args: @@ -936,7 +932,7 @@ def _get_session_meta_data(self, array: np.ndarray, snr: float) -> dict: snr (float): signal-to-noise ratio Returns: - dict: meta data + dict: meta data """ meta_data = { "duration": array.shape[0] / self._params.data_simulator.sr, @@ -1093,7 +1089,10 @@ def _generate_session( ) # step 5: add sentence to array array, is_speech, end = self._add_sentence_to_array( - start=start, length=length, array=array, is_speech=is_speech, + start=start, + length=length, + array=array, + is_speech=is_speech, ) # Step 6: Build entries for output files @@ -1174,7 +1173,9 @@ def _generate_session( sf.write(os.path.join(basepath, filename + '.wav'), array, self._params.data_simulator.sr) self.annotator.write_annotation_files( - basepath=basepath, filename=filename, meta_data=self._get_session_meta_data(array=array, snr=snr), + basepath=basepath, + filename=filename, + meta_data=self._get_session_meta_data(array=array, snr=snr), ) # Step 8: Clean up memory @@ -1262,7 +1263,9 @@ def generate_sessions(self, random_seed: int = None): if self.num_workers > 1: basepath, filename = future.result() else: - self._noise_samples = self.sampler.sample_noise_manifest(noise_manifest=source_noise_manifest,) + self._noise_samples = self.sampler.sample_noise_manifest( + noise_manifest=source_noise_manifest, + ) basepath, filename = self._generate_session(*future) self.annotator.add_to_filename_lists(basepath=basepath, filename=filename) @@ -1277,7 +1280,7 @@ def generate_sessions(self, random_seed: int = None): class RIRMultiSpeakerSimulator(MultiSpeakerSimulator): """ - RIR Augmented Multispeaker Audio Session Simulator - simulates multispeaker audio sessions using single-speaker + RIR Augmented Multispeaker Audio Session Simulator - simulates multispeaker audio sessions using single-speaker audio files and corresponding word alignments, as well as simulated RIRs for augmentation. Args: @@ -1288,17 +1291,17 @@ class RIRMultiSpeakerSimulator(MultiSpeakerSimulator): use_rir (bool): Whether to generate synthetic RIR toolkit (str): Which toolkit to use ("pyroomacoustics", "gpuRIR") room_config: - room_sz (list): Size of the shoebox room environment (1d array for specific, 2d array for random range to be + room_sz (list): Size of the shoebox room environment (1d array for specific, 2d array for random range to be sampled from) - pos_src (list): Positions of the speakers in the simulated room environment (2d array for specific, 3d array + pos_src (list): Positions of the speakers in the simulated room environment (2d array for specific, 3d array for random ranges to be sampled from) noise_src_pos (list): Position in room for the ambient background noise source mic_config: num_channels (int): Number of output audio channels - pos_rcv (list): Microphone positions in the simulated room environment (1d/2d array for specific, 2d/3d array + pos_rcv (list): Microphone positions in the simulated room environment (1d/2d array for specific, 2d/3d array for range assuming num_channels is 1/2+) orV_rcv (list or null): Microphone orientations (needed for non-omnidirectional microphones) - mic_pattern (str): Microphone type ("omni" - omnidirectional) - currently only omnidirectional microphones are + mic_pattern (str): Microphone type ("omni" - omnidirectional) - currently only omnidirectional microphones are supported for pyroomacoustics absorbtion_params: (Note that only `T60` is used for pyroomacoustics simulations) abs_weights (list): Absorption coefficient ratios for each surface @@ -1463,7 +1466,10 @@ def _generate_rir_pyroomacoustics(self) -> Tuple[torch.Tensor, int]: if self._params.data_simulator.rir_generation.mic_config.mic_pattern == 'omni': mic_pattern = DirectivityPattern.OMNI dir_vec = DirectionVector(azimuth=0, colatitude=90, degrees=True) - dir_obj = CardioidFamily(orientation=dir_vec, pattern_enum=mic_pattern,) + dir_obj = CardioidFamily( + orientation=dir_vec, + pattern_enum=mic_pattern, + ) mic_pos_tmp = np.array(self._params.data_simulator.rir_generation.mic_config.pos_rcv) if mic_pos_tmp.ndim == 3: # randomize @@ -1684,2354 +1690,11 @@ def _generate_session( sf.write(os.path.join(basepath, filename + '.wav'), array, self._params.data_simulator.sr) self.annotator.write_annotation_files( - basepath=basepath, filename=filename, meta_data=self._get_session_meta_data(array=array, snr=snr), + basepath=basepath, + filename=filename, + meta_data=self._get_session_meta_data(array=array, snr=snr), ) del array self.clean_up() return basepath, filename - - -def check_angle(key: str, val: Union[float, Iterable[float]]) -> bool: - """Check if the angle value is within the expected range. Input - values are in degrees. - - Note: - azimuth: angle between a projection on the horizontal (xy) plane and - positive x axis. Increases counter-clockwise. Range: [-180, 180]. - elevation: angle between a vector an its projection on the horizontal (xy) plane. - Positive above, negative below, i.e., north=+90, south=-90. Range: [-90, 90] - yaw: rotation around the z axis. Defined accoding to right-hand rule. - Range: [-180, 180] - pitch: rotation around the yʹ axis. Defined accoding to right-hand rule. - Range: [-90, 90] - roll: rotation around the xʺ axis. Defined accoding to right-hand rule. - Range: [-180, 180] - - Args: - key: angle type - val: values in degrees - - Returns: - True if all values are within the expected range. - """ - if np.isscalar(val): - min_val = max_val = val - else: - min_val = min(val) - max_val = max(val) - - if key == 'azimuth' and -180 <= min_val <= max_val <= 180: - return True - if key == 'elevation' and -90 <= min_val <= max_val <= 90: - return True - if key == 'yaw' and -180 <= min_val <= max_val <= 180: - return True - if key == 'pitch' and -90 <= min_val <= max_val <= 90: - return True - if key == 'roll' and -180 <= min_val <= max_val <= 180: - return True - - raise ValueError(f'Invalid value for angle {key} = {val}') - - -def wrap_to_180(angle: float) -> float: - """Wrap an angle to range ±180 degrees. - - Args: - angle: angle in degrees - - Returns: - Angle in degrees wrapped to ±180 degrees. - """ - return angle - np.floor(angle / 360 + 1 / 2) * 360 - - -class ArrayGeometry(object): - """A class to simplify handling of array geometry. - - Supports translation and rotation of the array and calculation of - spherical coordinates of a given point relative to the internal - coordinate system of the array. - - Args: - mic_positions: 3D coordinates, with shape (num_mics, 3) - center: optional position of the center of the array. Defaults to the average of the coordinates. - internal_cs: internal coordinate system for the array relative to the global coordinate system. - Defaults to (x, y, z), and is rotated with the array. - """ - - def __init__( - self, - mic_positions: Union[np.ndarray, List], - center: Optional[np.ndarray] = None, - internal_cs: Optional[np.ndarray] = None, - ): - if isinstance(mic_positions, Iterable): - mic_positions = np.array(mic_positions) - - if not mic_positions.ndim == 2: - raise ValueError( - f'Expecting a 2D array specifying mic positions, but received {mic_positions.ndim}-dim array' - ) - - if not mic_positions.shape[1] == 3: - raise ValueError(f'Expecting 3D positions, but received {mic_positions.shape[1]}-dim positions') - - mic_positions_center = np.mean(mic_positions, axis=0) - self.centered_positions = mic_positions - mic_positions_center - self.center = mic_positions_center if center is None else center - - # Internal coordinate system - if internal_cs is None: - # Initially aligned with the global - self.internal_cs = np.eye(3) - else: - self.internal_cs = internal_cs - - @property - def num_mics(self): - """Return the number of microphones for the current array. - """ - return self.centered_positions.shape[0] - - @property - def positions(self): - """Absolute positions of the microphones. - """ - return self.centered_positions + self.center - - @property - def internal_positions(self): - """Positions in the internal coordinate system. - """ - return np.matmul(self.centered_positions, self.internal_cs.T) - - @property - def radius(self): - """Radius of the array, relative to the center. - """ - return max(np.linalg.norm(self.centered_positions, axis=1)) - - @staticmethod - def get_rotation(yaw: float = 0, pitch: float = 0, roll: float = 0) -> Rotation: - """Get a Rotation object for given angles. - - All angles are defined according to the right-hand rule. - - Args: - yaw: rotation around the z axis - pitch: rotation around the yʹ axis - roll: rotation around the xʺ axis - - Returns: - A rotation object constructed using the provided angles. - """ - check_angle('yaw', yaw) - check_angle('pitch', pitch) - check_angle('roll', roll) - - return Rotation.from_euler('ZYX', [yaw, pitch, roll], degrees=True) - - def translate(self, to: np.ndarray): - """Translate the array center to a new point. - - Translation does not change the centered positions or the internal coordinate system. - - Args: - to: 3D point, shape (3,) - """ - self.center = to - - def rotate(self, yaw: float = 0, pitch: float = 0, roll: float = 0): - """Apply rotation on the mic array. - - This rotates the centered microphone positions and the internal - coordinate system, it doesn't change the center of the array. - - All angles are defined according to the right-hand rule. - For example, this means that a positive pitch will result in a rotation from z - to x axis, which will result in a reduced elevation with respect to the global - horizontal plane. - - Args: - yaw: rotation around the z axis - pitch: rotation around the yʹ axis - roll: rotation around the xʺ axis - """ - # construct rotation using TB angles - rotation = self.get_rotation(yaw=yaw, pitch=pitch, roll=roll) - - # rotate centered positions - self.centered_positions = rotation.apply(self.centered_positions) - - # apply the same transformation on the internal coordinate system - self.internal_cs = rotation.apply(self.internal_cs) - - def new_rotated_array(self, yaw: float = 0, pitch: float = 0, roll: float = 0): - """Create a new array by rotating this array. - - Args: - yaw: rotation around the z axis - pitch: rotation around the yʹ axis - roll: rotation around the xʺ axis - - Returns: - A new ArrayGeometry object constructed using the provided angles. - """ - new_array = ArrayGeometry(mic_positions=self.positions, center=self.center, internal_cs=self.internal_cs) - new_array.rotate(yaw=yaw, pitch=pitch, roll=roll) - return new_array - - def spherical_relative_to_array( - self, point: np.ndarray, use_internal_cs: bool = True - ) -> Tuple[float, float, float]: - """Return spherical coordinates of a point relative to the internal coordinate system. - - Args: - point: 3D coordinate, shape (3,) - use_internal_cs: Calculate position relative to the internal coordinate system. - If `False`, the positions will be calculated relative to the - external coordinate system centered at `self.center`. - - Returns: - A tuple (distance, azimuth, elevation) relative to the mic array. - """ - rel_position = point - self.center - distance = np.linalg.norm(rel_position) - - if use_internal_cs: - # transform from the absolute coordinate system to the internal coordinate system - rel_position = np.matmul(self.internal_cs, rel_position) - - # get azimuth - azimuth = np.arctan2(rel_position[1], rel_position[0]) / np.pi * 180 - # get elevation - elevation = np.arcsin(rel_position[2] / distance) / np.pi * 180 - - return distance, azimuth, elevation - - def __str__(self): - with np.printoptions(precision=3, suppress=True): - desc = f"{type(self)}:\ncenter =\n{self.center}\ncentered positions =\n{self.centered_positions}\nradius = \n{self.radius:.3}\nabsolute positions =\n{self.positions}\ninternal coordinate system =\n{self.internal_cs}\n\n" - return desc - - def plot(self, elev=30, azim=-55, mic_size=25): - """Plot microphone positions. - - Args: - elev: elevation for the view of the plot - azim: azimuth for the view of the plot - mic_size: size of the microphone marker in the plot - """ - fig = plt.figure() - ax = fig.add_subplot(projection='3d') - - # show mic positions - for m in range(self.num_mics): - # show mic - ax.scatter( - self.positions[m, 0], - self.positions[m, 1], - self.positions[m, 2], - marker='o', - c='black', - s=mic_size, - depthshade=False, - ) - # add label - ax.text(self.positions[m, 0], self.positions[m, 1], self.positions[m, 2], str(m), c='red', zorder=10) - - # show the internal coordinate system - ax.quiver( - self.center[0], - self.center[1], - self.center[2], - self.internal_cs[:, 0], - self.internal_cs[:, 1], - self.internal_cs[:, 2], - length=self.radius, - label='internal cs', - normalize=False, - linestyle=':', - linewidth=1.0, - ) - for dim, label in enumerate(['x′', 'y′', 'z′']): - label_pos = self.center + self.radius * self.internal_cs[dim] - ax.text(label_pos[0], label_pos[1], label_pos[2], label, tuple(self.internal_cs[dim]), c='blue') - try: - # Unfortunately, equal aspect ratio has been added very recently to Axes3D - ax.set_aspect('equal') - except NotImplementedError: - logging.warning('Equal aspect ratio not supported by Axes3D') - # Set view - ax.view_init(elev=elev, azim=azim) - # Set reasonable limits for all axes, even for the case of an unequal aspect ratio - ax.set_xlim([self.center[0] - self.radius, self.center[0] + self.radius]) - ax.set_ylim([self.center[1] - self.radius, self.center[1] + self.radius]) - ax.set_zlim([self.center[2] - self.radius, self.center[2] + self.radius]) - - ax.set_xlabel('x/m') - ax.set_ylabel('y/m') - ax.set_zlabel('z/m') - ax.set_title('Microphone positions') - ax.legend() - plt.show() - - -def convert_placement_to_range( - placement: dict, room_dim: Iterable[float], object_radius: float = 0 -) -> List[List[float]]: - """Given a placement dictionary, return ranges for each dimension. - - Args: - placement: dictionary containing x, y, height, and min_to_wall - room_dim: dimensions of the room, shape (3,) - object_radius: radius of the object to be placed - - Returns - List with a range of values for each dimensions. - """ - if not np.all(np.array(room_dim) > 0): - raise ValueError(f'Room dimensions must be positive: {room_dim}') - - if object_radius < 0: - raise ValueError(f'Object radius must be non-negative: {object_radius}') - - placement_range = [None] * 3 - min_to_wall = placement.get('min_to_wall', 0) - - if min_to_wall < 0: - raise ValueError(f'Min distance to wall must be positive: {min_to_wall}') - - for idx, key in enumerate(['x', 'y', 'height']): - # Room dimension - dim = room_dim[idx] - # Construct the range - val = placement.get(key) - if val is None: - # No constrained specified on the coordinate of the mic center - min_val, max_val = 0, dim - elif np.isscalar(val): - min_val = max_val = val - else: - if len(val) != 2: - raise ValueError(f'Invalid value for placement for dim {idx}/{key}: {str(placement)}') - min_val, max_val = val - - # Make sure the array is not too close to a wall - min_val = max(min_val, min_to_wall + object_radius) - max_val = min(max_val, dim - min_to_wall - object_radius) - - if min_val > max_val or min(min_val, max_val) < 0: - raise ValueError(f'Invalid range dim {idx}/{key}: min={min_val}, max={max_val}') - - placement_range[idx] = [min_val, max_val] - - return placement_range - - -class RIRCorpusGenerator(object): - """Creates a corpus of RIRs based on a defined configuration of rooms and microphone array. - - RIRs are generated using `generate` method. - """ - - def __init__(self, cfg: DictConfig): - """ - Args: - cfg: dictionary with parameters of the simulation - """ - logging.info("Initialize RIRCorpusGenerator") - self._cfg = cfg - self.check_cfg() - - @property - def cfg(self): - """Property holding the internal config of the object. - - Note: - Changes to this config are not reflected in the state of the object. - Please create a new model with the updated config. - """ - return self._cfg - - @property - def sample_rate(self): - return self._cfg.sample_rate - - @cfg.setter - def cfg(self, cfg): - """Property holding the internal config of the object. - - Note: - Changes to this config are not reflected in the state of the object. - Please create a new model with the updated config. - """ - self._cfg = cfg - - def check_cfg(self): - """ - Checks provided configuration to ensure it has the minimal required - configuration the values are in a reasonable range. - """ - # sample rate - sample_rate = self.cfg.get('sample_rate') - if sample_rate is None: - raise ValueError('Sample rate not provided.') - elif sample_rate < 0: - raise ValueError(f'Sample rate must to be positive: {sample_rate}') - - # room configuration - room_cfg = self.cfg.get('room') - if room_cfg is None: - raise ValueError('Room configuration not provided') - - if room_cfg.get('num') is None: - raise ValueError('Number of rooms per subset not provided') - - if room_cfg.get('dim') is None: - raise ValueError('Room dimensions not provided') - - for idx, key in enumerate(['width', 'length', 'height']): - dim = room_cfg.dim.get(key) - - if dim is None: - # not provided - raise ValueError(f'Room {key} needs to be a scalar or a range, currently it is None') - elif np.isscalar(dim) and dim <= 0: - # fixed dimension - raise ValueError(f'A fixed dimension must be positive for {key}: {dim}') - elif len(dim) != 2 or not 0 < dim[0] < dim[1]: - # not a valid range - raise ValueError(f'Range must be specified with two positive increasing elements for {key}: {dim}') - - rt60 = room_cfg.get('rt60') - if rt60 is None: - # not provided - raise ValueError(f'RT60 needs to be a scalar or a range, currently it is None') - elif np.isscalar(rt60) and rt60 <= 0: - # fixed dimension - raise ValueError(f'RT60 must be positive: {rt60}') - elif len(rt60) != 2 or not 0 < rt60[0] < rt60[1]: - # not a valid range - raise ValueError(f'RT60 range must be specified with two positive increasing elements: {rt60}') - - # mic array - mic_cfg = self.cfg.get('mic_array') - if mic_cfg is None: - raise ValueError('Mic configuration not provided') - - if mic_cfg.get('positions') == 'random': - # Only num_mics and placement are required - mic_cfg_keys = ['num_mics', 'placement'] - else: - mic_cfg_keys = ['positions', 'placement', 'orientation'] - - for key in mic_cfg_keys: - if key not in mic_cfg: - raise ValueError(f'Mic array {key} not provided') - - # source - source_cfg = self.cfg.get('source') - if source_cfg is None: - raise ValueError('Source configuration not provided') - - if source_cfg.get('num') is None: - raise ValueError('Number of sources per room not provided') - elif source_cfg.num <= 0: - raise ValueError(f'Number of sources must be positive: {source_cfg.num}') - - if 'placement' not in source_cfg: - raise ValueError('Source placement dictionary not provided') - - # anechoic - if self.cfg.get('anechoic') is None: - raise ValueError(f'Anechoic configuratio not provided.') - - def generate_room_params(self) -> dict: - """Generate randomized room parameters based on the provided - configuration. - """ - # Prepare room sim parameters - if not PRA: - raise ImportError('pyroomacoustics is required for room simulation') - - room_cfg = self.cfg.room - - # Prepare rt60 - if room_cfg.rt60 is None: - raise ValueError(f'Room RT60 needs to be a scalar or a range, currently it is None') - - if np.isscalar(room_cfg.rt60): - assert room_cfg.rt60 > 0, f'RT60 should be positive: {room_cfg.rt60}' - rt60 = room_cfg.rt60 - elif len(room_cfg.rt60) == 2: - assert ( - 0 < room_cfg.rt60[0] <= room_cfg.rt60[1] - ), f'Expecting two non-decreasing values for RT60, received {room_cfg.rt60}' - rt60 = self.random.uniform(low=room_cfg.rt60[0], high=room_cfg.rt60[1]) - else: - raise ValueError(f'Unexpected value for RT60: {room_cfg.rt60}') - - # Generate a room with random dimensions - num_retries = self.cfg.get('num_retries', 20) - - for n in range(num_retries): - - # width, length, height - room_dim = np.zeros(3) - - # prepare dimensions - for idx, key in enumerate(['width', 'length', 'height']): - # get configured dimension - dim = room_cfg.dim[key] - - # set a value - if dim is None: - raise ValueError(f'Room {key} needs to be a scalar or a range, currently it is None') - elif np.isscalar(dim): - assert dim > 0, f'Dimension should be positive for {key}: {dim}' - room_dim[idx] = dim - elif len(dim) == 2: - assert 0 < dim[0] <= dim[1], f'Expecting two non-decreasing values for {key}, received {dim}' - # Reduce dimension if the previous attempt failed - room_dim[idx] = self.random.uniform(low=dim[0], high=dim[1] - n * (dim[1] - dim[0]) / num_retries) - else: - raise ValueError(f'Unexpected value for {key}: {dim}') - - try: - # Get parameters from size and RT60 - room_absorption, room_max_order = pra.inverse_sabine(rt60, room_dim) - break - except Exception as e: - logging.debug('Inverse sabine failed: %s', str(e)) - # Inverse sabine may fail if the room is too large for the selected RT60. - # Try again by generate a smaller room. - room_absorption = room_max_order = None - continue - - if room_absorption is None or room_max_order is None: - raise RuntimeError(f'Evaluation of parameters failed for RT60 {rt60}s and room size {room_dim}.') - - # Return the required values - room_params = { - 'dim': room_dim, - 'absorption': room_absorption, - 'max_order': room_max_order, - 'rt60_theoretical': rt60, - 'anechoic_absorption': self.cfg.anechoic.absorption, - 'anechoic_max_order': self.cfg.anechoic.max_order, - 'sample_rate': self.cfg.sample_rate, - } - return room_params - - def generate_array(self, room_dim: Iterable[float]) -> ArrayGeometry: - """Generate array placement for the current room and config. - - Args: - room_dim: dimensions of the room, [width, length, height] - - Returns: - Randomly placed microphone array. - """ - mic_cfg = self.cfg.mic_array - - if mic_cfg.positions == 'random': - # Create a radom set of microphones - num_mics = mic_cfg.num_mics - mic_positions = [] - - # Each microphone is placed individually - placement_range = convert_placement_to_range( - placement=mic_cfg.placement, room_dim=room_dim, object_radius=0 - ) - - # Randomize mic placement - for m in range(num_mics): - position_m = [None] * 3 - for idx in range(3): - position_m[idx] = self.random.uniform(low=placement_range[idx][0], high=placement_range[idx][1]) - mic_positions.append(position_m) - - mic_array = ArrayGeometry(mic_positions) - - else: - mic_array = ArrayGeometry(mic_cfg.positions) - - # Randomize center placement - center = np.zeros(3) - placement_range = convert_placement_to_range( - placement=mic_cfg.placement, room_dim=room_dim, object_radius=mic_array.radius - ) - - for idx in range(len(center)): - center[idx] = self.random.uniform(low=placement_range[idx][0], high=placement_range[idx][1]) - - # Place the array at the configured center point - mic_array.translate(to=center) - - # Randomize orientation - orientation = dict() - for key in ['yaw', 'roll', 'pitch']: - # angle for current orientation - angle = mic_cfg.orientation[key] - - if angle is None: - raise ValueError(f'Mic array {key} should be a scalar or a range, currently it is set to None.') - - # check it's within the expected range - check_angle(key, angle) - - if np.isscalar(angle): - orientation[key] = angle - elif len(angle) == 2: - assert angle[0] <= angle[1], f"Expecting two non-decreasing values for {key}, received {angle}" - # generate integer values, for easier bucketing, if necessary - orientation[key] = self.random.uniform(low=angle[0], high=angle[1]) - else: - raise ValueError(f'Unexpected value for orientation {key}: {angle}') - - # Rotate the array to match the selected orientation - mic_array.rotate(**orientation) - - return mic_array - - def generate_source_position(self, room_dim: Iterable[float]) -> List[List[float]]: - """Generate position for all sources in a room. - - Args: - room_dim: dimensions of a 3D shoebox room - - Returns: - List of source positions, with each position characterized with a 3D coordinate - """ - source_cfg = self.cfg.source - placement_range = convert_placement_to_range(placement=source_cfg.placement, room_dim=room_dim) - source_position = [] - - for n in range(source_cfg.num): - # generate a random point withing the range - s_pos = [None] * 3 - for idx in range(len(s_pos)): - s_pos[idx] = self.random.uniform(low=placement_range[idx][0], high=placement_range[idx][1]) - source_position.append(s_pos) - - return source_position - - def generate(self): - """Generate RIR corpus. - - This method will prepare randomized examples based on the current configuration, - run room simulations and save results to output_dir. - """ - logging.info("Generate RIR corpus") - - # Initialize - self.random = default_rng(seed=self.cfg.random_seed) - - # Prepare output dir - output_dir = self.cfg.output_dir - if output_dir.endswith('.yaml'): - output_dir = output_dir[:-5] - - # Create absolute path - logging.info('Output dir set to: %s', output_dir) - - # Generate all cases - for subset, num_rooms in self.cfg.room.num.items(): - - output_dir_subset = os.path.join(output_dir, subset) - examples = [] - - if not os.path.exists(output_dir_subset): - logging.info('Creating output directory: %s', output_dir_subset) - os.makedirs(output_dir_subset) - elif os.path.isdir(output_dir_subset) and len(os.listdir(output_dir_subset)) > 0: - raise RuntimeError(f'Output directory {output_dir_subset} is not empty.') - - # Generate examples - for n_room in range(num_rooms): - - # room info - room_params = self.generate_room_params() - - # array placement - mic_array = self.generate_array(room_params['dim']) - - # source placement - source_position = self.generate_source_position(room_params['dim']) - - # file name for the file - room_filepath = os.path.join(output_dir_subset, f'{subset}_room_{n_room:06d}.h5') - - # prepare example - example = { - 'room_params': room_params, - 'mic_array': mic_array, - 'source_position': source_position, - 'room_filepath': room_filepath, - } - examples.append(example) - - # Simulation - if (num_workers := self.cfg.get('num_workers')) is None: - num_workers = os.cpu_count() - 1 - - if num_workers > 1: - logging.info(f'Simulate using {num_workers} workers') - with multiprocessing.Pool(processes=num_workers) as pool: - metadata = list(tqdm(pool.imap(simulate_room_kwargs, examples), total=len(examples))) - - else: - logging.info('Simulate using a single worker') - metadata = [] - for example in tqdm(examples, total=len(examples)): - metadata.append(simulate_room(**example)) - - # Save manifest - manifest_filepath = os.path.join(output_dir, f'{subset}_manifest.json') - - if os.path.exists(manifest_filepath) and os.path.isfile(manifest_filepath): - raise RuntimeError(f'Manifest config file exists: {manifest_filepath}') - - # Make all paths in the manifest relative to the output dir - for data in metadata: - data['room_filepath'] = os.path.relpath(data['room_filepath'], start=output_dir) - - write_manifest(manifest_filepath, metadata) - - # Generate plots with information about generated data - plot_filepath = os.path.join(output_dir, f'{subset}_info.png') - - if os.path.exists(plot_filepath) and os.path.isfile(plot_filepath): - raise RuntimeError(f'Plot file exists: {plot_filepath}') - - plot_rir_manifest_info(manifest_filepath, plot_filepath=plot_filepath) - - # Save used configuration for reference - config_filepath = os.path.join(output_dir, 'config.yaml') - if os.path.exists(config_filepath) and os.path.isfile(config_filepath): - raise RuntimeError(f'Output config file exists: {config_filepath}') - - OmegaConf.save(self.cfg, config_filepath, resolve=True) - - -def simulate_room_kwargs(kwargs: dict) -> dict: - """Wrapper around `simulate_room` to handle kwargs. - - `pool.map(simulate_room_kwargs, examples)` would be - equivalent to `pool.starstarmap(simulate_room, examples)` - if `starstarmap` would exist. - - Args: - kwargs: kwargs that are forwarded to `simulate_room` - - Returns: - Dictionary with metadata, see `simulate_room` - """ - return simulate_room(**kwargs) - - -def simulate_room( - room_params: dict, mic_array: ArrayGeometry, source_position: Iterable[Iterable[float]], room_filepath: str, -) -> dict: - """Simulate room - - Args: - room_params: parameters of the room to be simulated - mic_array: defines positions of the microphones - source_positions: positions for all sources to be simulated - room_filepath: results are saved to this path - - Returns: - Dictionary with metadata based on simulation setup - and simulation results. Used to create the corresponding - manifest file. - """ - # room with the selected parameters - room_sim = pra.ShoeBox( - room_params['dim'], - fs=room_params['sample_rate'], - materials=pra.Material(room_params['absorption']), - max_order=room_params['max_order'], - ) - - # same geometry for generating anechoic responses - room_anechoic = pra.ShoeBox( - room_params['dim'], - fs=room_params['sample_rate'], - materials=pra.Material(room_params['anechoic_absorption']), - max_order=room_params['anechoic_max_order'], - ) - - # Compute RIRs - for room in [room_sim, room_anechoic]: - # place the array - room.add_microphone_array(mic_array.positions.T) - - # place the sources - for s_pos in source_position: - room.add_source(s_pos) - - # generate RIRs - room.compute_rir() - - # Get metadata for sources - source_distance = [] - source_azimuth = [] - source_elevation = [] - for s_pos in source_position: - distance, azimuth, elevation = mic_array.spherical_relative_to_array(s_pos) - source_distance.append(distance) - source_azimuth.append(azimuth) - source_elevation.append(elevation) - - # RIRs - rir_dataset = { - 'rir': convert_rir_to_multichannel(room_sim.rir), - 'anechoic': convert_rir_to_multichannel(room_anechoic.rir), - } - - # Prepare metadata dict and return - metadata = { - 'room_filepath': room_filepath, - 'sample_rate': room_params['sample_rate'], - 'dim': room_params['dim'], - 'rir_absorption': room_params['absorption'], - 'rir_max_order': room_params['max_order'], - 'rir_rt60_theory': room_sim.rt60_theory(), - 'rir_rt60_measured': room_sim.measure_rt60().mean(axis=0), # average across mics for each source - 'anechoic_rt60_theory': room_anechoic.rt60_theory(), - 'anechoic_rt60_measured': room_anechoic.measure_rt60().mean(axis=0), # average across mics for each source - 'anechoic_absorption': room_params['anechoic_absorption'], - 'anechoic_max_order': room_params['anechoic_max_order'], - 'mic_positions': mic_array.positions, - 'mic_center': mic_array.center, - 'source_position': source_position, - 'source_distance': source_distance, - 'source_azimuth': source_azimuth, - 'source_elevation': source_elevation, - 'num_sources': len(source_position), - } - - # Save simulated RIR - save_rir_simulation(room_filepath, rir_dataset, metadata) - - return convert_numpy_to_serializable(metadata) - - -def save_rir_simulation(filepath: str, rir_dataset: Dict[str, List[np.array]], metadata: dict): - """Save simulated RIRs and metadata. - - Args: - filepath: Path to the file where the data will be saved. - rir_dataset: Dictionary with RIR data. Each item is a set of multi-channel RIRs. - metadata: Dictionary with related metadata. - """ - if os.path.exists(filepath): - raise RuntimeError(f'Output file exists: {room_filepath}') - - num_sources = metadata['num_sources'] - - with h5py.File(filepath, 'w') as h5f: - # Save RIRs, each RIR set in a separate group - for rir_key, rir_value in rir_dataset.items(): - if len(rir_value) != num_sources: - raise ValueError( - f'Each RIR dataset should have exactly {num_sources} elements. Current RIR {key} has {len(rir_value)} elements' - ) - - rir_group = h5f.create_group(rir_key) - - # RIRs for different sources are saved under [group]['idx'] - for idx, rir in enumerate(rir_value): - rir_group.create_dataset(f'{idx}', data=rir_value[idx]) - - # Save metadata - metadata_group = h5f.create_group('metadata') - for key, value in metadata.items(): - metadata_group.create_dataset(key, data=value) - - -def load_rir_simulation(filepath: str, source: int = 0, rir_key: str = 'rir') -> Tuple[np.ndarray, float]: - """Load simulated RIRs and metadata. - - Args: - filepath: Path to simulated RIR data - source: Index of a source. - rir_key: String to denote which RIR to load, if there are multiple available. - - Returns: - Multichannel RIR as ndarray with shape (num_samples, num_channels) and scalar sample rate. - """ - with h5py.File(filepath, 'r') as h5f: - # Load RIR - rir = h5f[rir_key][f'{source}'][:] - - # Load metadata - sample_rate = h5f['metadata']['sample_rate'][()] - - return rir, sample_rate - - -def convert_numpy_to_serializable(data: Union[dict, float, np.ndarray]) -> Union[dict, float, np.ndarray]: - """Convert all numpy estries to list. - Can be used to preprocess data before writing to a JSON file. - - Args: - data: Dictionary, array or scalar. - - Returns: - The same structure, but converted to list if - the input is np.ndarray, so `data` can be seralized. - """ - if isinstance(data, dict): - for key, val in data.items(): - data[key] = convert_numpy_to_serializable(val) - elif isinstance(data, list): - data = [convert_numpy_to_serializable(d) for d in data] - elif isinstance(data, np.ndarray): - data = data.tolist() - elif isinstance(data, np.integer): - data = int(data) - elif isinstance(data, np.floating): - data = float(data) - elif isinstance(data, np.generic): - data = data.item() - - return data - - -def convert_rir_to_multichannel(rir: List[List[np.ndarray]]) -> List[np.ndarray]: - """Convert RIR to a list of arrays. - - Args: - rir: list of lists, each element is a single-channel RIR - - Returns: - List of multichannel RIRs - """ - num_mics = len(rir) - num_sources = len(rir[0]) - - mc_rir = [None] * num_sources - - for n_source in range(num_sources): - rir_len = [len(rir[m][n_source]) for m in range(num_mics)] - max_len = max(rir_len) - mc_rir[n_source] = np.zeros((max_len, num_mics)) - for n_mic, len_mic in enumerate(rir_len): - mc_rir[n_source][:len_mic, n_mic] = rir[n_mic][n_source] - - return mc_rir - - -def plot_rir_manifest_info(filepath: str, plot_filepath: str = None): - """Plot distribution of parameters from manifest file. - - Args: - filepath: path to a RIR corpus manifest file - plot_filepath: path to save the plot at - """ - metadata = read_manifest(filepath) - - # source placement - source_distance = [] - source_azimuth = [] - source_elevation = [] - source_height = [] - - # room config - rir_rt60_theory = [] - rir_rt60_measured = [] - anechoic_rt60_theory = [] - anechoic_rt60_measured = [] - - # get the required data - for data in metadata: - # source config - source_distance += data['source_distance'] - source_azimuth += data['source_azimuth'] - source_elevation += data['source_elevation'] - source_height += [s_pos[2] for s_pos in data['source_position']] - - # room config - rir_rt60_theory.append(data['rir_rt60_theory']) - rir_rt60_measured += data['rir_rt60_measured'] - anechoic_rt60_theory.append(data['anechoic_rt60_theory']) - anechoic_rt60_measured += data['anechoic_rt60_measured'] - - # plot - plt.figure(figsize=(12, 6)) - - plt.subplot(2, 4, 1) - plt.hist(source_distance, label='distance') - plt.xlabel('distance / m') - plt.ylabel('# examples') - plt.title('Source-to-array center distance') - - plt.subplot(2, 4, 2) - plt.hist(source_azimuth, label='azimuth') - plt.xlabel('azimuth / deg') - plt.ylabel('# examples') - plt.title('Source-to-array center azimuth') - - plt.subplot(2, 4, 3) - plt.hist(source_elevation, label='elevation') - plt.xlabel('elevation / deg') - plt.ylabel('# examples') - plt.title('Source-to-array center elevation') - - plt.subplot(2, 4, 4) - plt.hist(source_height, label='source height') - plt.xlabel('height / m') - plt.ylabel('# examples') - plt.title('Source height') - - plt.subplot(2, 4, 5) - plt.hist(rir_rt60_theory, label='theory') - plt.xlabel('RT60 / s') - plt.ylabel('# examples') - plt.title('RT60 theory') - - plt.subplot(2, 4, 6) - plt.hist(rir_rt60_measured, label='measured') - plt.xlabel('RT60 / s') - plt.ylabel('# examples') - plt.title('RT60 measured') - - plt.subplot(2, 4, 7) - plt.hist(anechoic_rt60_theory, label='theory') - plt.xlabel('RT60 / s') - plt.ylabel('# examples') - plt.title('RT60 theory (anechoic)') - - plt.subplot(2, 4, 8) - plt.hist(anechoic_rt60_measured, label='measured') - plt.xlabel('RT60 / s') - plt.ylabel('# examples') - plt.title('RT60 measured (anechoic)') - - for n in range(8): - plt.subplot(2, 4, n + 1) - plt.grid() - plt.legend(loc='lower left') - - plt.tight_layout() - - if plot_filepath is not None: - plt.savefig(plot_filepath) - plt.close() - logging.info('Plot saved at %s', plot_filepath) - - -class RIRMixGenerator(object): - """Creates a dataset of mixed signals at the microphone - by combining target speech, background noise and interference. - - Correspnding signals are are generated and saved - using the `generate` method. - - Input configuration is expexted to have the following structure - ``` - sample_rate: sample rate used for simulation - room: - subset: manifest for RIR data - target: - subset: manifest for target source data - noise: - subset: manifest for noise data - interference: - subset: manifest for interference data - interference_probability: probability that interference is present - max_num_interferers: max number of interferers, randomly selected between 0 and max - mix: - subset: - num: number of examples to generate - rsnr: range of RSNR - rsir: range of RSIR - ref_mic: reference microphone - ref_mic_rms: desired RMS at ref_mic - ``` - """ - - def __init__(self, cfg: DictConfig): - """ - Instantiate a RIRMixGenerator object. - - Args: - cfg: generator configuration defining data for room, - target signal, noise, interference and mixture - """ - logging.info("Initialize RIRMixGenerator") - self._cfg = cfg - self.check_cfg() - - self.subsets = self.cfg.room.keys() - logging.info('Initialized with %d subsets: %s', len(self.subsets), str(self.subsets)) - - # load manifests - self.metadata = dict() - for subset in self.subsets: - subset_data = dict() - - logging.info('Loading data for %s', subset) - for key in ['room', 'target', 'noise', 'interference']: - try: - subset_data[key] = read_manifest(self.cfg[key][subset]) - logging.info('\t%-*s: \t%d files', 15, key, len(subset_data[key])) - except Exception as e: - subset_data[key] = None - logging.info('\t%-*s: \t0 files', 15, key) - logging.warning('\t\tManifest data not loaded. Exception: %s', str(e)) - - self.metadata[subset] = subset_data - - logging.info('Loaded all manifests') - - self.num_retries = self.cfg.get('num_retries', 5) - - @property - def cfg(self): - """Property holding the internal config of the object. - - Note: - Changes to this config are not reflected in the state of the object. - Please create a new model with the updated config. - """ - return self._cfg - - @property - def sample_rate(self): - return self._cfg.sample_rate - - @cfg.setter - def cfg(self, cfg): - """Property holding the internal config of the object. - - Note: - Changes to this config are not reflected in the state of the object. - Please create a new model with the updated config. - """ - self._cfg = cfg - - def check_cfg(self): - """ - Checks provided configuration to ensure it has the minimal required - configuration the values are in a reasonable range. - """ - # sample rate - sample_rate = self.cfg.get('sample_rate') - if sample_rate is None: - raise ValueError('Sample rate not provided.') - elif sample_rate < 0: - raise ValueError(f'Sample rate must be positive: {sample_rate}') - - # room configuration - room_cfg = self.cfg.get('room') - if not room_cfg: - raise ValueError( - 'Room configuration not provided. Expecting RIR manifests in format {subset: path_to_manifest}' - ) - - # target configuration - target_cfg = self.cfg.get('target') - if not target_cfg: - raise ValueError( - 'Target configuration not provided. Expecting audio manifests in format {subset: path_to_manifest}' - ) - - for key in ['azimuth', 'elevation', 'distance']: - value = target_cfg.get(key) - - if value is None or np.isscalar(value): - # no constraint or a fixed dimension is ok - pass - elif len(value) != 2 or not value[0] < value[1]: - # not a valid range - raise ValueError(f'Range must be specified with two positive increasing elements for {key}: {value}') - - # noise configuration - noise_cfg = self.cfg.get('noise') - if not noise_cfg: - raise ValueError( - 'Noise configuration not provided. Expecting audio manifests in format {subset: path_to_manifest}' - ) - - # interference configuration - interference_cfg = self.cfg.get('interference') - if not interference_cfg: - logging.info('Interference configuration not provided.') - else: - interference_probability = interference_cfg.get('interference_probability', 0) - max_num_interferers = interference_cfg.get('max_num_interferers', 0) - min_azimuth_to_target = interference_cfg.get('min_azimuth_to_target', 0) - if interference_probability is not None: - if interference_probability < 0: - raise ValueError( - f'Interference probability must be non-negative. Current value: {interference_prob}' - ) - elif interference_probability > 0: - assert ( - max_num_interferers is not None and max_num_interferers > 0 - ), f'Max number of interferers must be positive. Current value: {max_num_interferers}' - assert ( - min_azimuth_to_target is not None and min_azimuth_to_target >= 0 - ), f'Min azimuth to target must be non-negative' - - # mix configuration - mix_cfg = self.cfg.get('mix') - if not mix_cfg: - raise ValueError('Mix configuration not provided. Expecting configuration for each subset.') - if 'ref_mic' not in mix_cfg: - raise ValueError('Reference microphone not defined.') - if 'ref_mic_rms' not in mix_cfg: - raise ValueError('Reference microphone RMS not defined.') - - def generate_target(self, subset: str) -> dict: - """ - Prepare a dictionary with target configuration. - - The output dictionary contains the following information - ``` - room_index: index of the selected room from the RIR corpus - room_filepath: path to the room simulation file - source: index of the selected source for the target - rt60: reverberation time of the selected room - num_mics: number of microphones - azimuth: azimuth of the target source, relative to the microphone array - elevation: elevation of the target source, relative to the microphone array - distance: distance of the target source, relative to the microphone array - audio_filepath: path to the audio file for the target source - text: text for the target source audio signal, if available - duration: duration of the target source audio signal - ``` - - Args: - subset: string denoting a subset which will be used to selected target - audio and room parameters. - - Returns: - Dictionary with target configuration, including room, source index, and audio information. - """ - # Utility function - def select_target_source(room_metadata, room_indices): - """Find a room and a source that satisfies the constraints. - """ - for room_index in room_indices: - # Select room - room_data = room_metadata[room_index] - - # Candidate sources - sources = self.random.choice(room_data['num_sources'], size=self.num_retries, replace=False) - - # Select target source in this room - for source in sources: - # Check constraints - constraints_met = [] - for constraint in ['azimuth', 'elevation', 'distance']: - if self.cfg.target.get(constraint) is not None: - # Check that the selected source is in the range - source_value = room_data[f'source_{constraint}'][source] - if self.cfg.target[constraint][0] <= source_value <= self.cfg.target[constraint][1]: - constraints_met.append(True) - else: - constraints_met.append(False) - # No need to check the remaining constraints - break - - # Check if a feasible source is found - if all(constraints_met): - # A feasible source has been found - return source, room_index - - return None, None - - # Prepare room & source position - room_metadata = self.metadata[subset]['room'] - room_indices = self.random.choice(len(room_metadata), size=self.num_retries, replace=False) - source, room_index = select_target_source(room_metadata, room_indices) - - if source is None: - raise RuntimeError(f'Could not find a feasible source given target constraints {self.cfg.target}') - - room_data = room_metadata[room_index] - - # Optional: select subset of channels - num_available_mics = len(room_data['mic_positions']) - if 'mic_array' in self.cfg: - num_mics = self.cfg.mic_array['num_mics'] - mic_selection = self.cfg.mic_array['selection'] - - if mic_selection == 'random': - logging.debug('Randomly selecting %d mics', num_mics) - selected_mics = self.random.choice(num_available_mics, size=num_mics, replace=False) - elif isinstance(mic_selection, Iterable): - logging.debug('Using explicitly selected mics: %s', str(mic_selection)) - assert ( - 0 <= min(mic_selection) < num_available_mics - ), f'Expecting mic_selection in range [0,{num_available_mics}), current value: {mic_selection}' - selected_mics = np.array(mic_selection) - else: - raise ValueError(f'Unexpected value for mic_selection: {mic_selection}') - else: - logging.debug('Using all %d available mics', num_available_mics) - num_mics = num_available_mics - selected_mics = np.arange(num_mics) - - # Double-check the number of mics is as expected - assert ( - len(selected_mics) == num_mics - ), f'Expecting {num_mics} mics, but received {len(selected_mics)} mics: {selected_mics}' - logging.debug('Selected mics: %s', str(selected_mics)) - - # Calculate distance from the source to each microphone - mic_positions = np.array(room_data['mic_positions'])[selected_mics] - source_position = np.array(room_data['source_position'][source]) - distance_source_to_mic = np.linalg.norm(mic_positions - source_position, axis=1) - - # Handle relative paths - room_filepath = room_data['room_filepath'] - if not os.path.isabs(room_filepath): - manifest_dir = os.path.dirname(self.cfg.room[subset]) - room_filepath = os.path.join(manifest_dir, room_filepath) - - target_cfg = { - 'room_index': int(room_index), - 'room_filepath': room_filepath, - 'source': source, - 'rt60': room_data['rir_rt60_measured'][source], - 'selected_mics': selected_mics.tolist(), - # Positions - 'source_position': source_position.tolist(), - 'mic_positions': mic_positions.tolist(), - # Relative to center of the array - 'azimuth': room_data['source_azimuth'][source], - 'elevation': room_data['source_elevation'][source], - 'distance': room_data['source_distance'][source], - # Relative to mics - 'distance_source_to_mic': distance_source_to_mic, - } - - return target_cfg - - def generate_interference(self, subset: str, target_cfg: dict) -> List[dict]: - """ - Prepare a list of dictionaries with interference configuration. - - Args: - subset: string denoting a subset which will be used to select interference audio. - target_cfg: dictionary with target configuration. This is used to determine - the minimal required duration for the noise signal. - - Returns: - List of dictionary with interference configuration, including source index and audio information - for one or more interference sources. - """ - if (interference_metadata := self.metadata[subset]['interference']) is None: - # No interference to be configured - return None - - # Configure interfering sources - max_num_sources = self.cfg.interference.get('max_num_interferers', 0) - interference_probability = self.cfg.interference.get('interference_probability', 0) - - if ( - max_num_sources >= 1 - and interference_probability > 0 - and self.random.uniform(low=0.0, high=1.0) < interference_probability - ): - # interference present - num_interferers = self.random.integers(low=1, high=max_num_sources + 1) - else: - # interference not present - return None - - # Room setup: same room as target - room_index = target_cfg['room_index'] - room_data = self.metadata[subset]['room'][room_index] - feasible_sources = list(range(room_data['num_sources'])) - # target source is not eligible - feasible_sources.remove(target_cfg['source']) - - # Constraints for interfering sources - min_azimuth_to_target = self.cfg.interference.get('min_azimuth_to_target', 0) - - # Prepare interference configuration - interference_cfg = [] - for n in range(num_interferers): - - # Select a source - source = None - while len(feasible_sources) > 0 and source is None: - - # Select a potential source for the target - source = self.random.choice(feasible_sources) - feasible_sources.remove(source) - - # Check azimuth separation - if min_azimuth_to_target > 0: - source_azimuth = room_data['source_azimuth'][source] - azimuth_diff = wrap_to_180(source_azimuth - target_cfg['azimuth']) - if abs(azimuth_diff) < min_azimuth_to_target: - # Try again - source = None - continue - - if source is None: - logging.warning('Could not select a feasible interference source %d of %s', n, num_interferers) - - # Return what we have for now or None - return interference_cfg if interference_cfg else None - - # Current source setup - interfering_source = { - 'source': source, - 'selected_mics': target_cfg['selected_mics'], - 'position': room_data['source_position'][source], - 'azimuth': room_data['source_azimuth'][source], - 'elevation': room_data['source_elevation'][source], - 'distance': room_data['source_distance'][source], - } - - # Done with interference for this source - interference_cfg.append(interfering_source) - - return interference_cfg - - def generate_mix(self, subset: str, target_cfg: dict) -> dict: - """Generate scaling parameters for mixing - the target speech at the microphone, background noise - and interference signal at the microphone. - - The output dictionary contains the following information - ``` - rsnr: reverberant signal-to-noise ratio - rsir: reverberant signal-to-interference ratio - ref_mic: reference microphone for calculating the metrics - ref_mic_rms: RMS of the signal at the reference microphone - ``` - - Args: - subset: string denoting the subset of configuration - target_cfg: dictionary with target configuration - - Returns: - Dictionary containing configured RSNR, RSIR, ref_mic - and RMS on ref_mic. - """ - mix_cfg = dict() - - for key in ['rsnr', 'rsir', 'ref_mic', 'ref_mic_rms', 'min_duration']: - if key in self.cfg.mix[subset]: - # Take the value from subset config - value = self.cfg.mix[subset].get(key) - else: - # Take the global value - value = self.cfg.mix.get(key) - - if value is None: - mix_cfg[key] = None - elif np.isscalar(value): - mix_cfg[key] = value - elif len(value) == 2: - # Select from the given range, including the upper bound - mix_cfg[key] = self.random.integers(low=value[0], high=value[1] + 1) - else: - # Select one of the multiple values - mix_cfg[key] = self.random.choice(value) - - if mix_cfg['ref_mic'] == 'closest': - # Select the closest mic as the reference - mix_cfg['ref_mic'] = np.argmin(target_cfg['distance_source_to_mic']) - - # Configuration for saving individual components - mix_cfg['save'] = OmegaConf.to_object(self.cfg.mix['save']) if 'save' in self.cfg.mix else {} - - return mix_cfg - - def generate(self): - """Generate a corpus of microphone signals by mixing target, background noise - and interference signals. - - This method will prepare randomized examples based on the current configuration, - run simulations and save results to output_dir. - """ - logging.info('Generate mixed signals') - - # Initialize - self.random = default_rng(seed=self.cfg.random_seed) - - # Prepare output dir - output_dir = self.cfg.output_dir - if output_dir.endswith('.yaml'): - output_dir = output_dir[:-5] - - # Create absolute path - logging.info('Output dir set to: %s', output_dir) - - # Generate all cases - for subset in self.subsets: - - output_dir_subset = os.path.join(output_dir, subset) - examples = [] - - if not os.path.exists(output_dir_subset): - logging.info('Creating output directory: %s', output_dir_subset) - os.makedirs(output_dir_subset) - elif os.path.isdir(output_dir_subset) and len(os.listdir(output_dir_subset)) > 0: - raise RuntimeError(f'Output directory {output_dir_subset} is not empty.') - - num_examples = self.cfg.mix[subset].num - logging.info('Preparing %d examples for subset %s', num_examples, subset) - - # Generate examples - for n_example in tqdm(range(num_examples), total=num_examples, desc=f'Preparing {subset}'): - # prepare configuration - target_cfg = self.generate_target(subset) - interference_cfg = self.generate_interference(subset, target_cfg) - mix_cfg = self.generate_mix(subset, target_cfg) - - # base file name - base_output_filepath = os.path.join(output_dir_subset, f'{subset}_example_{n_example:09d}') - - # prepare example - example = { - 'sample_rate': self.sample_rate, - 'target_cfg': target_cfg, - 'interference_cfg': interference_cfg, - 'mix_cfg': mix_cfg, - 'base_output_filepath': base_output_filepath, - } - - examples.append(example) - - # Audio data - audio_metadata = { - 'target': self.metadata[subset]['target'], - 'target_dir': os.path.dirname(self.cfg.target[subset]), # manifest_dir - 'noise': self.metadata[subset]['noise'], - 'noise_dir': os.path.dirname(self.cfg.noise[subset]), # manifest_dir - } - - if interference_cfg is not None: - audio_metadata.update( - { - 'interference': self.metadata[subset]['interference'], - 'interference_dir': os.path.dirname(self.cfg.interference[subset]), # manifest_dir - } - ) - - # Simulation - if (num_workers := self.cfg.get('num_workers')) is None: - num_workers = os.cpu_count() - 1 - - if num_workers is not None and num_workers > 1: - logging.info(f'Simulate using {num_workers} workers') - examples_and_audio_metadata = zip(examples, itertools.repeat(audio_metadata, len(examples))) - with multiprocessing.Pool(processes=num_workers) as pool: - metadata = list( - tqdm( - pool.imap(simulate_room_mix_helper, examples_and_audio_metadata), - total=len(examples), - desc=f'Simulating {subset}', - ) - ) - else: - logging.info('Simulate using a single worker') - metadata = [] - for example in tqdm(examples, total=len(examples), desc=f'Simulating {subset}'): - metadata.append(simulate_room_mix(**example, audio_metadata=audio_metadata)) - - # Save manifest - manifest_filepath = os.path.join(output_dir, f'{os.path.basename(output_dir)}_{subset}.json') - - if os.path.exists(manifest_filepath) and os.path.isfile(manifest_filepath): - raise RuntimeError(f'Manifest config file exists: {manifest_filepath}') - - # Make all paths in the manifest relative to the output dir - for data in tqdm(metadata, total=len(metadata), desc=f'Making filepaths relative {subset}'): - for key, val in data.items(): - if key.endswith('_filepath') and val is not None: - data[key] = os.path.relpath(val, start=output_dir) - - write_manifest(manifest_filepath, metadata) - - # Generate plots with information about generated data - plot_filepath = os.path.join(output_dir, f'{os.path.basename(output_dir)}_{subset}_info.png') - - if os.path.exists(plot_filepath) and os.path.isfile(plot_filepath): - raise RuntimeError(f'Plot file exists: {plot_filepath}') - - plot_mix_manifest_info(manifest_filepath, plot_filepath=plot_filepath) - - # Save used configuration for reference - config_filepath = os.path.join(output_dir, 'config.yaml') - if os.path.exists(config_filepath) and os.path.isfile(config_filepath): - raise RuntimeError(f'Output config file exists: {config_filepath}') - - OmegaConf.save(self.cfg, config_filepath, resolve=True) - - -def convolve_rir(signal: np.ndarray, rir: np.ndarray) -> np.ndarray: - """Convolve signal with a possibly multichannel IR in rir, i.e., - calculate the following for each channel m: - - signal_m = rir_m \ast signal - - Args: - signal: single-channel signal (samples,) - rir: single- or multi-channel IR, (samples,) or (samples, channels) - - Returns: - out: same length as signal, same number of channels as rir, shape (samples, channels) - """ - num_samples = len(signal) - if rir.ndim == 1: - # convolve and trim to length - out = convolve(signal, rir)[:num_samples] - elif rir.ndim == 2: - num_channels = rir.shape[1] - out = np.zeros((num_samples, num_channels)) - for m in range(num_channels): - out[:, m] = convolve(signal, rir[:, m])[:num_samples] - - else: - raise RuntimeError(f'RIR with {rir.ndim} not supported') - - return out - - -def calculate_drr(rir: np.ndarray, sample_rate: float, n_direct: List[int], n_0_ms=2.5) -> List[float]: - """Calculate direct-to-reverberant ratio (DRR) from the measured RIR. - - Calculation is done as in eq. (3) from [1]. - - Args: - rir: room impulse response, shape (num_samples, num_channels) - sample_rate: sample rate for the impulse response - n_direct: direct path delay - n_0_ms: window around n_direct for calculating the direct path energy - - Returns: - Calculated DRR for each channel of the input RIR. - - References: - [1] Eaton et al, The ACE challenge: Corpus description and performance evaluation, WASPAA 2015 - """ - # Define a window around the direct path delay - n_0 = int(n_0_ms * sample_rate / 1000) - - len_rir, num_channels = rir.shape - drr = [None] * num_channels - for m in range(num_channels): - - # Window around the direct path - dir_start = max(n_direct[m] - n_0, 0) - dir_end = n_direct[m] + n_0 - - # Power of the direct component - pow_dir = np.sum(np.abs(rir[dir_start:dir_end, m]) ** 2) / len_rir - - # Power of the reverberant component - pow_reverberant = (np.sum(np.abs(rir[0:dir_start, m]) ** 2) + np.sum(np.abs(rir[dir_end:, m]) ** 2)) / len_rir - - # DRR in dB - drr[m] = pow2db(pow_dir / pow_reverberant) - - return drr - - -def normalize_max(x: np.ndarray, max_db: float = 0, eps: float = 1e-16) -> np.ndarray: - """Normalize max input value to max_db full scale (±1). - - Args: - x: input signal - max_db: desired max magnitude compared to full scale - eps: small regularization constant - - Returns: - Normalized signal with max absolute value max_db. - """ - max_val = db2mag(max_db) - return max_val * x / (np.max(np.abs(x)) + eps) - - -def simultaneously_active_rms( - x: np.ndarray, - y: np.ndarray, - sample_rate: float, - rms_threshold_db: float = -60, - window_len_ms: float = 200, - min_active_duration: float = 0.5, -) -> Tuple[float, float]: - """Calculate RMS over segments where both input signals are active. - - Args: - x: first input signal - y: second input signal - sample_rate: sample rate for input signals in Hz - rms_threshold_db: threshold for determining activity of the signal, relative - to max absolute value - window_len_ms: window length in milliseconds, used for calculating segmental RMS - min_active_duration: minimal duration of the active segments - - Returns: - RMS value over active segments for x and y. - """ - if len(x) != len(y): - raise RuntimeError(f'Expecting signals of same length: len(x)={len(x)}, len(y)={len(y)}') - window_len = int(window_len_ms * sample_rate / 1000) - rms_threshold = db2mag(rms_threshold_db) # linear scale - - x_normalized = normalize_max(x) - y_normalized = normalize_max(y) - - x_active_power = y_active_power = active_len = 0 - for start in range(0, len(x) - window_len, window_len): - window = slice(start, start + window_len) - - # check activity on the scaled signal - x_window_rms = rms(x_normalized[window]) - y_window_rms = rms(y_normalized[window]) - - if x_window_rms > rms_threshold and y_window_rms > rms_threshold: - # sum the power of the original non-scaled signal - x_active_power += np.sum(np.abs(x[window]) ** 2) - y_active_power += np.sum(np.abs(y[window]) ** 2) - active_len += window_len - - if active_len < int(min_active_duration * sample_rate): - raise RuntimeError( - f'Signals are simultaneously active less than {min_active_duration} s: only {active_len/sample_rate} s' - ) - - # normalize - x_active_power /= active_len - y_active_power /= active_len - - return np.sqrt(x_active_power), np.sqrt(y_active_power) - - -def scaled_disturbance( - signal: np.ndarray, - disturbance: np.ndarray, - sdr: float, - sample_rate: float = None, - ref_channel: int = 0, - eps: float = 1e-16, -) -> np.ndarray: - """ - Args: - signal: numpy array, shape (num_samples, num_channels) - disturbance: numpy array, same shape as signal - sdr: desired signal-to-disturbance ration - sample_rate: sample rate of the input signals - ref_channel: ref mic used to calculate RMS - eps: regularization constant - - Returns: - Scaled disturbance, so that signal-to-disturbance ratio at ref_channel - is approximately equal to input SDR during simultaneously active - segment of signal and disturbance. - """ - if signal.shape != disturbance.shape: - raise ValueError(f'Signal and disturbance shapes do not match: {signal.shape} != {disturbance.shape}') - - # set scaling based on RMS at ref_mic - signal_rms, disturbance_rms = simultaneously_active_rms( - signal[:, ref_channel], disturbance[:, ref_channel], sample_rate=sample_rate - ) - disturbance_gain = db2mag(-sdr) * signal_rms / (disturbance_rms + eps) - # scale disturbance - scaled_disturbance = disturbance_gain * disturbance - return scaled_disturbance - - -def prepare_source_signal( - signal_type: str, - sample_rate: int, - audio_data: List[dict], - audio_dir: Optional[str] = None, - min_duration: Optional[int] = None, - ref_signal: Optional[np.ndarray] = None, - mic_positions: Optional[np.ndarray] = None, - num_retries: int = 10, -) -> tuple: - """Prepare an audio signal for a source. - - Args: - signal_type: 'point' or 'diffuse' - sample_rate: Sampling rate for the signal - audio_data: List of audio items, each is a dictionary with audio_filepath, duration, offset and optionally text - audio_dir: Base directory for resolving paths, e.g., manifest basedir - min_duration: Minimal duration to be loaded if ref_signal is not provided, in seconds - ref_signal: Optional, used to determine the length of the signal - mic_positions: Optional, used to prepare approximately diffuse signal - num_retries: Number of retries when selecting the source files - - Returns: - (audio_signal, metadata), where audio_signal is an ndarray and metadata is a dictionary - with audio filepaths, durations and offsets - """ - if not signal_type in ['point', 'diffuse']: - raise ValueError(f'Unexpected signal type {signal_type}.') - - if audio_data is None: - # No data to load - return None - - metadata = {} - - if ref_signal is None: - audio_signal = None - # load at least one sample if min_duration is not provided - samples_to_load = int(min_duration * sample_rate) if min_duration is not None else 1 - source_signals_metadata = {'audio_filepath': [], 'duration': [], 'offset': [], 'text': []} - - while samples_to_load > 0: - # Select a random item and load the audio - item = random.choice(audio_data) - - audio_filepath = item['audio_filepath'] - if not os.path.isabs(audio_filepath) and audio_dir is not None: - audio_filepath = os.path.join(audio_dir, audio_filepath) - - # Load audio - check_min_sample_rate(audio_filepath, sample_rate) - audio_segment = AudioSegment.from_file( - audio_file=audio_filepath, - target_sr=sample_rate, - duration=item['duration'], - offset=item.get('offset', 0), - ) - - if signal_type == 'point': - if audio_segment.num_channels > 1: - raise RuntimeError( - f'Expecting single-channel source signal, but received {audio_segment.num_channels}. File: {audio_filepath}' - ) - else: - raise ValueError(f'Unexpected signal type {signal_type}.') - - source_signals_metadata['audio_filepath'].append(audio_filepath) - source_signals_metadata['duration'].append(item['duration']) - source_signals_metadata['duration'].append(item.get('offset', 0)) - source_signals_metadata['text'].append(item.get('text')) - - # not perfect, since different files may have different distributions - segment_samples = normalize_max(audio_segment.samples) - # concatenate - audio_signal = ( - np.concatenate((audio_signal, segment_samples)) if audio_signal is not None else segment_samples - ) - # remaining samples - samples_to_load -= len(segment_samples) - - # Finally, we need only the metadata for the complete signal - metadata = { - 'duration': sum(source_signals_metadata['duration']), - 'offset': 0, - } - - # Add text only if all source signals have text - if all([isinstance(tt, str) for tt in source_signals_metadata['text']]): - metadata['text'] = ' '.join(source_signals_metadata['text']) - else: - # Load a signal with total_len samples and ensure it has enough simultaneous activity/overlap with ref_signal - # Concatenate multiple files if necessary - total_len = len(ref_signal) - - for n in range(num_retries): - - audio_signal = None - source_signals_metadata = {'audio_filepath': [], 'duration': [], 'offset': []} - - if signal_type == 'point': - samples_to_load = total_len - elif signal_type == 'diffuse': - # Load longer signal so it can be reshaped into (samples, mics) and - # used to generate approximately diffuse noise field - num_mics = len(mic_positions) - samples_to_load = num_mics * total_len - - while samples_to_load > 0: - # Select an audio file - item = random.choice(audio_data) - - audio_filepath = item['audio_filepath'] - if not os.path.isabs(audio_filepath) and audio_dir is not None: - audio_filepath = os.path.join(audio_dir, audio_filepath) - - # Load audio signal - check_min_sample_rate(audio_filepath, sample_rate) - - if (max_offset := item['duration'] - np.ceil(samples_to_load / sample_rate)) > 0: - # Load with a random offset if the example is longer than samples_to_load - offset = random.uniform(0, max_offset) - duration = -1 - else: - # Load the whole file - offset, duration = 0, item['duration'] - audio_segment = AudioSegment.from_file( - audio_file=audio_filepath, target_sr=sample_rate, duration=duration, offset=offset - ) - - # Prepare a single-channel signal - if audio_segment.num_channels == 1: - # Take all samples - segment_samples = audio_segment.samples - else: - # Take a random channel - selected_channel = random.choice(range(audio_segment.num_channels)) - segment_samples = audio_segment.samples[:, selected_channel] - - source_signals_metadata['audio_filepath'].append(audio_filepath) - source_signals_metadata['duration'].append(len(segment_samples) / sample_rate) - source_signals_metadata['offset'].append(offset) - - # not perfect, since different files may have different distributions - segment_samples = normalize_max(segment_samples) - # concatenate - audio_signal = ( - np.concatenate((audio_signal, segment_samples)) if audio_signal is not None else segment_samples - ) - # remaining samples - samples_to_load -= len(segment_samples) - - if signal_type == 'diffuse' and num_mics > 1: - try: - # Trim and reshape to num_mics to prepare num_mics source signals - audio_signal = audio_signal[: num_mics * total_len].reshape(num_mics, -1).T - - # Make spherically diffuse noise - audio_signal = generate_approximate_noise_field( - mic_positions=np.array(mic_positions), noise_signal=audio_signal, sample_rate=sample_rate - ) - except Exception as e: - logging.info('Failed to generate approximate noise field: %s', str(e)) - logging.info('Try again.') - # Try again - audio_signal, source_signals_metadata = None, {} - continue - - # Trim to length - audio_signal = audio_signal[:total_len, ...] - - # Include the channel dimension if the reference includes it - if ref_signal.ndim == 2 and audio_signal.ndim == 1: - audio_signal = audio_signal[:, None] - - try: - # Signal and ref_signal should be simultaneously active - simultaneously_active_rms(ref_signal, audio_signal, sample_rate=sample_rate) - # We have enough overlap - break - except Exception as e: - # Signal and ref_signal are not overlapping, try again - logging.info('Exception: %s', str(e)) - logging.info('Signals are not overlapping, try again.') - audio_signal, source_signals_metadata = None, {} - continue - - if audio_signal is None: - logging.warning('Audio signal not set: %s.', signal_type) - - metadata['source_signals'] = source_signals_metadata - - return audio_signal, metadata - - -def check_min_sample_rate(filepath: str, sample_rate: float): - """Make sure the file's sample rate is at least sample_rate. - This will make sure that we have only downsampling if loading - this file, while upsampling is not permitted. - - Args: - filepath: path to a file - sample_rate: desired sample rate - """ - file_sample_rate = librosa.get_samplerate(path=filepath) - if file_sample_rate < sample_rate: - raise RuntimeError( - f'Sample rate ({file_sample_rate}) is lower than the desired sample rate ({sample_rate}). File: {filepath}.' - ) - - -def simulate_room_mix( - sample_rate: int, - target_cfg: dict, - interference_cfg: dict, - mix_cfg: dict, - audio_metadata: dict, - base_output_filepath: str, - max_amplitude: float = 0.999, - eps: float = 1e-16, -) -> dict: - """Simulate mixture signal at the microphone, including target, noise and - interference signals and mixed at specific RSNR and RSIR. - - Args: - sample_rate: Sample rate for all signals - target_cfg: Dictionary with configuration of the target. Includes - room_filepath, source index, audio_filepath, duration - noise_cfg: List of dictionaries, where each item includes audio_filepath, - offset and duration. - interference_cfg: List of dictionaries, where each item contains source - index - mix_cfg: Dictionary with the mixture configuration. Includes RSNR, RSIR, - ref_mic and ref_mic_rms. - audio_metadata: Dictionary with a list of files for target, noise and interference - base_output_filepath: All output audio files will be saved with this prefix by - adding a diffierent suffix for each component, e.g., _mic.wav. - max_amplitude: Maximum amplitude of the mic signal, used to prevent clipping. - eps: Small regularization constant. - - Returns: - Dictionary with metadata based on the mixture setup and - simulation results. This corresponds to a line of the - output manifest file. - """ - # Local utilities - def load_rir( - room_filepath: str, source: int, selected_mics: list, sample_rate: float, rir_key: str = 'rir' - ) -> np.ndarray: - """Load a RIR and check that the sample rate is matching the desired sample rate - - Args: - room_filepath: Path to a room simulation in an h5 file - source: Index of the desired source - sample_rate: Sample rate of the simulation - rir_key: Key of the RIR to load from the simulation. - - Returns: - Numpy array with shape (num_samples, num_channels) - """ - rir, rir_sample_rate = load_rir_simulation(room_filepath, source=source, rir_key=rir_key) - if rir_sample_rate != sample_rate: - raise RuntimeError( - f'RIR sample rate ({sample_rate}) is not matching the expected sample rate ({sample_rate}). File: {room_filepath}' - ) - return rir[:, selected_mics] - - def get_early_rir( - rir: np.ndarray, rir_anechoic: np.ndarray, sample_rate: int, early_duration: float = 0.050 - ) -> np.ndarray: - """Return only the early part of the RIR. - """ - early_len = int(early_duration * sample_rate) - direct_path_delay = np.min(np.argmax(rir_anechoic, axis=0)) - rir_early = rir.copy() - rir_early[direct_path_delay + early_len :, :] = 0 - return rir_early - - def save_audio( - base_path: str, - tag: str, - audio_signal: Optional[np.ndarray], - sample_rate: int, - save: str = 'all', - ref_mic: Optional[int] = None, - format: str = 'wav', - subtype: str = 'float', - ): - """Save audio signal and return filepath. - """ - if (audio_signal is None) or (not save): - return None - - if save == 'ref_mic': - # save only ref_mic - audio_signal = audio_signal[:, ref_mic] - - audio_filepath = base_path + f'_{tag}.{format}' - sf.write(audio_filepath, audio_signal, sample_rate, subtype) - - return audio_filepath - - # Target RIRs - target_rir = load_rir( - target_cfg['room_filepath'], - source=target_cfg['source'], - selected_mics=target_cfg['selected_mics'], - sample_rate=sample_rate, - ) - target_rir_anechoic = load_rir( - target_cfg['room_filepath'], - source=target_cfg['source'], - sample_rate=sample_rate, - selected_mics=target_cfg['selected_mics'], - rir_key='anechoic', - ) - target_rir_early = get_early_rir(rir=target_rir, rir_anechoic=target_rir_anechoic, sample_rate=sample_rate) - - # Target signals - target_signal, target_metadata = prepare_source_signal( - signal_type='point', - sample_rate=sample_rate, - audio_data=audio_metadata['target'], - audio_dir=audio_metadata['target_dir'], - min_duration=mix_cfg['min_duration'], - ) - source_signals_metadata = {'target': target_metadata['source_signals']} - - # Convolve target - target_reverberant = convolve_rir(target_signal, target_rir) - target_anechoic = convolve_rir(target_signal, target_rir_anechoic) - target_early = convolve_rir(target_signal, target_rir_early) - - # Prepare noise signal - noise, noise_metadata = prepare_source_signal( - signal_type='diffuse', - sample_rate=sample_rate, - mic_positions=target_cfg['mic_positions'], - audio_data=audio_metadata['noise'], - audio_dir=audio_metadata['noise_dir'], - ref_signal=target_reverberant, - ) - source_signals_metadata['noise'] = noise_metadata['source_signals'] - - # Prepare interference signal - if interference_cfg is None: - interference = None - else: - # Load interference signals - interference = 0 - source_signals_metadata['interference'] = [] - for i_cfg in interference_cfg: - # Load single-channel signal for directional interference - i_signal, i_metadata = prepare_source_signal( - signal_type='point', - sample_rate=sample_rate, - audio_data=audio_metadata['interference'], - audio_dir=audio_metadata['interference_dir'], - ref_signal=target_signal, - ) - source_signals_metadata['interference'].append(i_metadata['source_signals']) - # Load RIR from the same room as the target, but a difference source - i_rir = load_rir( - target_cfg['room_filepath'], - source=i_cfg['source'], - selected_mics=i_cfg['selected_mics'], - sample_rate=sample_rate, - ) - # Convolve interference - i_reverberant = convolve_rir(i_signal, i_rir) - # Sum - interference += i_reverberant - - # Scale and add components of the signal - mic = target_reverberant.copy() - - if noise is not None: - noise = scaled_disturbance( - signal=target_reverberant, - disturbance=noise, - sdr=mix_cfg['rsnr'], - sample_rate=sample_rate, - ref_channel=mix_cfg['ref_mic'], - ) - # Update mic signal - mic += noise - - if interference is not None: - interference = scaled_disturbance( - signal=target_reverberant, - disturbance=interference, - sdr=mix_cfg['rsir'], - sample_rate=sample_rate, - ref_channel=mix_cfg['ref_mic'], - ) - # Update mic signal - mic += interference - - # Set the final mic signal level - mic_rms = rms(mic[:, mix_cfg['ref_mic']]) - global_gain = db2mag(mix_cfg['ref_mic_rms']) / (mic_rms + eps) - mic_max = np.max(np.abs(mic)) - if (clipped_max := mic_max * global_gain) > max_amplitude: - # Downscale the global gain to prevent clipping + adjust ref_mic_rms accordingly - clipping_prevention_gain = max_amplitude / clipped_max - global_gain *= clipping_prevention_gain - mix_cfg['ref_mic_rms'] += mag2db(clipping_prevention_gain) - - logging.debug( - 'Clipping prevented for example %s (protection gain: %.2f dB)', - base_output_filepath, - mag2db(clipping_prevention_gain), - ) - - # save signals - signals = { - 'mic': mic, - 'target_reverberant': target_reverberant, - 'target_anechoic': target_anechoic, - 'target_early': target_early, - 'noise': noise, - 'interference': interference, - } - - metadata = {} - - for tag, signal in signals.items(): - - if signal is not None: - # scale all signal components with the global gain - signal = global_gain * signal - - audio_filepath = save_audio( - base_path=base_output_filepath, - tag=tag, - audio_signal=signal, - sample_rate=sample_rate, - save=mix_cfg['save'].get(tag, 'all'), - ref_mic=mix_cfg['ref_mic'], - format=mix_cfg['save'].get('format', 'wav'), - subtype=mix_cfg['save'].get('subtype', 'float'), - ) - - if tag == 'mic': - metadata['audio_filepath'] = audio_filepath - else: - metadata[tag + '_filepath'] = audio_filepath - - # Add metadata - metadata.update( - { - 'text': target_metadata.get('text'), - 'duration': target_metadata['duration'], - 'target_cfg': target_cfg, - 'interference_cfg': interference_cfg, - 'mix_cfg': mix_cfg, - 'ref_channel': mix_cfg.get('ref_mic'), - 'rt60': target_cfg.get('rt60'), - 'drr': calculate_drr(target_rir, sample_rate, n_direct=np.argmax(target_rir_anechoic, axis=0)), - 'rsnr': None if noise is None else mix_cfg['rsnr'], - 'rsir': None if interference is None else mix_cfg['rsir'], - 'source_signals': source_signals_metadata, - } - ) - - return convert_numpy_to_serializable(metadata) - - -def simulate_room_mix_helper(example_and_audio_metadata: tuple) -> dict: - """Wrapper around `simulate_room_mix` for pool.imap. - - Args: - args: example and audio_metadata that are forwarded to `simulate_room_mix` - - Returns: - Dictionary with metadata, see `simulate_room_mix` - """ - example, audio_metadata = example_and_audio_metadata - return simulate_room_mix(**example, audio_metadata=audio_metadata) - - -def plot_mix_manifest_info(filepath: str, plot_filepath: str = None): - """Plot distribution of parameters from the manifest file. - - Args: - filepath: path to a RIR corpus manifest file - plot_filepath: path to save the plot at - """ - metadata = read_manifest(filepath) - - # target info - target_distance = [] - target_azimuth = [] - target_elevation = [] - target_duration = [] - - # room config - rt60 = [] - drr = [] - - # noise - rsnr = [] - rsir = [] - - # get the required data - for data in metadata: - # target info - target_distance.append(data['target_cfg']['distance']) - target_azimuth.append(data['target_cfg']['azimuth']) - target_elevation.append(data['target_cfg']['elevation']) - target_duration.append(data['duration']) - - # room config - rt60.append(data['rt60']) - drr += data['drr'] # average DRR across all mics - - # noise - if data['rsnr'] is not None: - rsnr.append(data['rsnr']) - - if data['rsir'] is not None: - rsir.append(data['rsir']) - - # plot - plt.figure(figsize=(12, 6)) - - plt.subplot(2, 4, 1) - plt.hist(target_distance, label='distance') - plt.xlabel('distance / m') - plt.ylabel('# examples') - plt.title('Target-to-array distance') - - plt.subplot(2, 4, 2) - plt.hist(target_azimuth, label='azimuth') - plt.xlabel('azimuth / deg') - plt.ylabel('# examples') - plt.title('Target-to-array azimuth') - - plt.subplot(2, 4, 3) - plt.hist(target_elevation, label='elevation') - plt.xlabel('elevation / deg') - plt.ylabel('# examples') - plt.title('Target-to-array elevation') - - plt.subplot(2, 4, 4) - plt.hist(target_duration, label='duration') - plt.xlabel('time / s') - plt.ylabel('# examples') - plt.title('Target duration') - - plt.subplot(2, 4, 5) - plt.hist(rt60, label='RT60') - plt.xlabel('RT60 / s') - plt.ylabel('# examples') - plt.title('RT60') - - plt.subplot(2, 4, 6) - plt.hist(drr, label='DRR') - plt.xlabel('DRR / dB') - plt.ylabel('# examples') - plt.title('DRR [avg over mics]') - - if len(rsnr) > 0: - plt.subplot(2, 4, 7) - plt.hist(rsnr, label='RSNR') - plt.xlabel('RSNR / dB') - plt.ylabel('# examples') - plt.title(f'RSNR [{100 * len(rsnr) / len(rt60):.0f}% ex]') - - if len(rsir): - plt.subplot(2, 4, 8) - plt.hist(rsir, label='RSIR') - plt.xlabel('RSIR / dB') - plt.ylabel('# examples') - plt.title(f'RSIR [{100 * len(rsir) / len(rt60):.0f}% ex]') - - for n in range(8): - plt.subplot(2, 4, n + 1) - plt.grid() - plt.legend(loc='lower left') - - plt.tight_layout() - - if plot_filepath is not None: - plt.savefig(plot_filepath) - plt.close() - logging.info('Plot saved at %s', plot_filepath) diff --git a/nemo/collections/asr/data/feature_to_text.py b/nemo/collections/asr/data/feature_to_text.py index a7e295051ae80..b0b524d374f17 100644 --- a/nemo/collections/asr/data/feature_to_text.py +++ b/nemo/collections/asr/data/feature_to_text.py @@ -19,7 +19,7 @@ from nemo.collections.asr.data.feature_to_label import _audio_feature_collate_fn from nemo.collections.asr.parts.preprocessing.feature_loader import ExternalFeatureLoader from nemo.collections.asr.parts.preprocessing.features import normalize_batch -from nemo.collections.asr.parts.utils.audio_utils import ChannelSelectorType +from nemo.collections.asr.parts.preprocessing.segment import ChannelSelectorType from nemo.collections.asr.parts.utils.vad_utils import load_speech_segments_from_rttm from nemo.collections.common import tokenizers from nemo.collections.common.parts.preprocessing import collections, parsers @@ -80,7 +80,7 @@ class _FeatureTextDataset(Dataset): """ Dataset that loads tensors via a json file containing paths to audio feature files, transcripts, durations (in seconds) and optional RTTM files. Each new line is a different sample. Example below: - {"feature_filepath": "/path/to/audio_feature.pt", "text_filepath": "/path/to/audio.txt", + {"feature_filepath": "/path/to/audio_feature.pt", "text_filepath": "/path/to/audio.txt", "rttm_filepath": "/path/to/audio_rttm.rttm", "duration": 23.147} ... {"feature_filepath": "/path/to/audio_feature.pt", "text": "the transcription", "offset": 301.75, "duration": 0.82, "utt": @@ -115,8 +115,7 @@ class _FeatureTextDataset(Dataset): @property def output_types(self) -> Optional[Dict[str, NeuralType]]: - """Returns definitions of module output ports. - """ + """Returns definitions of module output ports.""" return { 'features': NeuralType(('B', 'D', 'T'), AcousticEncodedRepresentation()), 'feature_length': NeuralType(tuple('B'), LengthsType()), @@ -264,7 +263,7 @@ def _collate_fn(self, batch): def normalize_feature(self, feat): """ Args: - feat: feature tensor of shape [M, T] + feat: feature tensor of shape [M, T] """ feat = feat.unsqueeze(0) # add batch dim feat, _, _ = normalize_batch(feat, torch.tensor([feat.size(-1)]), self.normalize_type) @@ -369,7 +368,7 @@ def __init__( class FeatureToBPEDataset(_FeatureTextDataset): """ Dataset that loads tensors via a json file containing paths to audio feature - files, transcripts, durations (in seconds) and optional RTTM files. Each new line is a different sample. + files, transcripts, durations (in seconds) and optional RTTM files. Each new line is a different sample. Example below: {"audio_filepath": "/path/to/audio.wav", "text_filepath": "/path/to/audio.txt", "duration": 23.147, "rttm_filepath": "/path/to/audio_rttm.rttm",} diff --git a/nemo/collections/asr/data/huggingface/hf_audio_to_text.py b/nemo/collections/asr/data/huggingface/hf_audio_to_text.py index f0a3f8376049b..da4aeb3f888c0 100644 --- a/nemo/collections/asr/data/huggingface/hf_audio_to_text.py +++ b/nemo/collections/asr/data/huggingface/hf_audio_to_text.py @@ -22,8 +22,7 @@ from nemo.collections.asr.data.audio_to_text import _speech_collate_fn from nemo.collections.asr.parts.preprocessing.perturb import AudioAugmentor -from nemo.collections.asr.parts.preprocessing.segment import AudioSegment -from nemo.collections.asr.parts.utils.audio_utils import ChannelSelectorType +from nemo.collections.asr.parts.preprocessing.segment import AudioSegment, ChannelSelectorType from nemo.collections.common import tokenizers from nemo.collections.common.parts.preprocessing import parsers from nemo.core.classes import Dataset, IterableDataset @@ -33,8 +32,8 @@ class HFTextProcessor: """ - Text processor for huggingface datasets, mimicing the behavior of - `nemo.collections.asr.data.audio_to_text.ASRManifestProcessor`. + Text processor for huggingface datasets, mimicing the behavior of + `nemo.collections.asr.data.audio_to_text.ASRManifestProcessor`. Basic text cleaning is also supported. Args: parser: Str for a language specific preprocessor or a callable. @@ -124,7 +123,7 @@ class _HFAudioTextDataset(Dataset): ref_channel: Reference channel for normalization. id_key: key to access sample id from the dataset normalize_text: If true, normalizes text in HFTextProcessor - symbols_to_keep: If not None, only keeps symbols in this list when normalizing text + symbols_to_keep: If not None, only keeps symbols in this list when normalizing text """ def __init__( @@ -222,8 +221,7 @@ class HFAudioToCharDataset(_HFAudioTextDataset): @property def output_types(self) -> Optional[Dict[str, NeuralType]]: - """Returns definitions of module output ports. - """ + """Returns definitions of module output ports.""" return { 'audio_signal': NeuralType(('B', 'T'), AudioSignal()), 'a_sig_length': NeuralType(tuple('B'), LengthsType()), @@ -292,8 +290,7 @@ class HFAudioToBPEDataset(_HFAudioTextDataset): @property def output_types(self) -> Optional[Dict[str, NeuralType]]: - """Returns definitions of module output ports. - """ + """Returns definitions of module output ports.""" return { 'audio_signal': NeuralType(('B', 'T'), AudioSignal()), 'a_sig_length': NeuralType(tuple('B'), LengthsType()), @@ -378,7 +375,7 @@ def __call__(self, *args): class _HFIterableAudioTextDataset(IterableDataset): """ - Wrapper class for loading HuggingFace IterableDataset and converts to NeMo compatible format. + Wrapper class for loading HuggingFace IterableDataset and converts to NeMo compatible format. Args: audio_key: key to access audio data from the dataset text_key: key to access text data from the dataset @@ -528,8 +525,7 @@ class HFIterableAudioToCharDataset(_HFIterableAudioTextDataset): @property def output_types(self) -> Optional[Dict[str, NeuralType]]: - """Returns definitions of module output ports. - """ + """Returns definitions of module output ports.""" return { 'audio_signal': NeuralType(('B', 'T'), AudioSignal()), 'a_sig_length': NeuralType(tuple('B'), LengthsType()), @@ -606,8 +602,7 @@ class HFIterableAudioToBPEDataset(_HFIterableAudioTextDataset): @property def output_types(self) -> Optional[Dict[str, NeuralType]]: - """Returns definitions of module output ports. - """ + """Returns definitions of module output ports.""" return { 'audio_signal': NeuralType(('B', 'T'), AudioSignal()), 'a_sig_length': NeuralType(tuple('B'), LengthsType()), diff --git a/nemo/collections/asr/losses/__init__.py b/nemo/collections/asr/losses/__init__.py index c03f7a48ffe32..0747e9a37bea0 100644 --- a/nemo/collections/asr/losses/__init__.py +++ b/nemo/collections/asr/losses/__init__.py @@ -13,7 +13,6 @@ # limitations under the License. from nemo.collections.asr.losses.angularloss import AngularSoftmaxLoss -from nemo.collections.asr.losses.audio_losses import MSELoss, SDRLoss from nemo.collections.asr.losses.ctc import CTCLoss from nemo.collections.asr.losses.lattice_losses import LatticeLoss from nemo.collections.asr.losses.ssl_losses.contrastive import ContrastiveLoss diff --git a/nemo/collections/asr/models/__init__.py b/nemo/collections/asr/models/__init__.py index 23c759afc80de..9b339df44f18e 100644 --- a/nemo/collections/asr/models/__init__.py +++ b/nemo/collections/asr/models/__init__.py @@ -14,7 +14,6 @@ from nemo.collections.asr.models.aed_multitask_models import EncDecMultiTaskModel from nemo.collections.asr.models.asr_model import ASRModel -from nemo.collections.asr.models.audio_to_audio_model import AudioToAudioModel from nemo.collections.asr.models.classification_models import ( ClassificationInferConfig, EncDecClassificationModel, @@ -23,11 +22,6 @@ from nemo.collections.asr.models.clustering_diarizer import ClusteringDiarizer from nemo.collections.asr.models.ctc_bpe_models import EncDecCTCModelBPE from nemo.collections.asr.models.ctc_models import EncDecCTCModel -from nemo.collections.asr.models.enhancement_models import ( - EncMaskDecAudioToAudioModel, - PredictiveAudioToAudioModel, - ScoreBasedGenerativeAudioToAudioModel, -) from nemo.collections.asr.models.hybrid_rnnt_ctc_bpe_models import EncDecHybridRNNTCTCBPEModel from nemo.collections.asr.models.hybrid_rnnt_ctc_models import EncDecHybridRNNTCTCModel from nemo.collections.asr.models.k2_sequence_models import ( diff --git a/nemo/collections/asr/models/aed_multitask_models.py b/nemo/collections/asr/models/aed_multitask_models.py index edb5919217821..5ec7a8298beef 100644 --- a/nemo/collections/asr/models/aed_multitask_models.py +++ b/nemo/collections/asr/models/aed_multitask_models.py @@ -14,13 +14,14 @@ import os import warnings +from collections.abc import Mapping, Sequence from dataclasses import dataclass, field from math import ceil from typing import Any, Dict, List, Optional, Union import numpy as np import torch -from omegaconf import DictConfig, OmegaConf, open_dict +from omegaconf import DictConfig, ListConfig, OmegaConf, open_dict from pytorch_lightning import Trainer from torch.utils.data import DataLoader @@ -30,16 +31,16 @@ ) from nemo.collections.asr.metrics import BLEU, WER from nemo.collections.asr.models.asr_model import ASRModel, ExportableEncDecModel -from nemo.collections.asr.parts.mixins import ASRBPEMixin, ASRTranscriptionMixin +from nemo.collections.asr.parts.mixins import ASRBPEMixin, ASRModuleMixin, ASRTranscriptionMixin from nemo.collections.asr.parts.mixins.transcription import ( GenericTranscriptionType, InternalTranscribeConfig, TranscribeConfig, ) +from nemo.collections.asr.parts.preprocessing.segment import ChannelSelectorType from nemo.collections.asr.parts.submodules.multitask_decoding import MultiTaskDecoding, MultiTaskDecodingConfig from nemo.collections.asr.parts.submodules.token_classifier import TokenClassifier from nemo.collections.asr.parts.utils import manifest_utils -from nemo.collections.asr.parts.utils.audio_utils import ChannelSelectorType from nemo.collections.asr.parts.utils.rnnt_utils import Hypothesis from nemo.collections.common import tokenizers from nemo.collections.common.data.lhotse.dataloader import get_lhotse_dataloader_from_config @@ -114,7 +115,7 @@ def __post_init__(self): self.prompt = parse_multitask_prompt(self.prompt) -class EncDecMultiTaskModel(ASRModel, ExportableEncDecModel, ASRBPEMixin, ASRTranscriptionMixin): +class EncDecMultiTaskModel(ASRModel, ExportableEncDecModel, ASRBPEMixin, ASRModuleMixin, ASRTranscriptionMixin): """Base class for AED multi-task models""" def __init__(self, cfg: DictConfig, trainer: Trainer = None): @@ -224,6 +225,9 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None): self.decoding, tokenize=self.cfg.get('bleu_tokenizer', "13a"), log_prediction=False ) # Wer is handling logging + # Setup encoder adapters (from ASRAdapterModelMixin) + self.setup_adapters() + def change_decoding_strategy(self, decoding_cfg: DictConfig): """ Changes decoding strategy used during Multi Task decoding process. @@ -387,6 +391,59 @@ def change_vocabulary( logging.info(f"Changed decoder to output to {vocabulary} vocabulary.") + def change_prompt( + self, prompt_format: Optional[str] = None, prompt_defaults: Optional[List[Dict[str, Any]]] = None + ): + """ + Changes the prompt format used during Multi Task decoding process. + + Args: + prompt_format: A string alias of the object that represents the prompt structure. + If not None, it will be used to update the prompt format. + prompt_defaults: A dictionary of default values for the prompt format. + """ + if prompt_format is not None: + self.prompt_format = prompt_format + + if prompt_defaults is not None: + # Perform some assertions on the prompt defaults contents + # Must be a list-like object + if not isinstance(prompt_defaults, Sequence): + raise ValueError("`prompt_defaults` must be a list of dictionaries") + + # Must contain dict-like objects + for item in prompt_defaults: + if not isinstance(item, Mapping): + raise ValueError("`prompt_defaults` must be a list of dictionaries") + + # Each dict item must have a `role` key + if 'role' not in item: + raise ValueError( + "`prompt_defaults` must have a `role` key for each item in the list of dictionaries" + ) + + if 'slots' not in item: + raise ValueError( + "`prompt_defaults` must have a `slots` key for each item in the list of dictionaries" + ) + + # Cast to OmegaConf if not already + if not isinstance(prompt_defaults, ListConfig): + prompt_defaults = OmegaConf.create(prompt_defaults) + + prompt_cls = PromptFormatter.resolve(self.prompt_format) + self.prompt = prompt_cls( + tokenizer=self.tokenizer, + defaults=OmegaConf.to_container(pd) if (pd := self.cfg.prompt_defaults) is not None else None, + ) + + # Update config + with open_dict(self.cfg): + self.cfg.prompt_format = self.prompt_format + self.cfg.prompt_defaults = prompt_defaults + + logging.info(f"Changed prompt format to `{self.prompt_format}`") + @torch.no_grad() def transcribe( self, @@ -1003,6 +1060,10 @@ def predict_step(self, batch, batch_idx=0, dataloader_idx=0, has_processed_signa text = [self.decoding.strip_special_tokens(t) for t in text] return text + @property + def adapter_module_names(self) -> List[str]: + return ['', 'encoder', 'transf_encoder', 'transf_decoder'] + def parse_multitask_prompt(prompt: dict | None) -> list[dict]: if prompt is None or not prompt: diff --git a/nemo/collections/asr/models/asr_model.py b/nemo/collections/asr/models/asr_model.py index 0539f961a1cae..24e300aff1120 100644 --- a/nemo/collections/asr/models/asr_model.py +++ b/nemo/collections/asr/models/asr_model.py @@ -240,12 +240,12 @@ def output_names(self): if getattr(self.input_module, 'export_cache_support', False): in_types = self.input_module.output_types otypes = {n: t for (n, t) in list(otypes.items())[:1]} - for (n, t) in list(in_types.items())[1:]: + for n, t in list(in_types.items())[1:]: otypes[n] = t return get_io_names(otypes, self.disabled_deployment_output_names) def forward_for_export( - self, input, length=None, cache_last_channel=None, cache_last_time=None, cache_last_channel_len=None + self, audio_signal, length=None, cache_last_channel=None, cache_last_time=None, cache_last_channel_len=None ): """ This forward is used when we need to export the model to ONNX format. @@ -264,12 +264,12 @@ def forward_for_export( """ enc_fun = getattr(self.input_module, 'forward_for_export', self.input_module.forward) if cache_last_channel is None: - encoder_output = enc_fun(audio_signal=input, length=length) + encoder_output = enc_fun(audio_signal=audio_signal, length=length) if isinstance(encoder_output, tuple): encoder_output = encoder_output[0] else: encoder_output, length, cache_last_channel, cache_last_time, cache_last_channel_len = enc_fun( - audio_signal=input, + audio_signal=audio_signal, length=length, cache_last_channel=cache_last_channel, cache_last_time=cache_last_time, diff --git a/nemo/collections/asr/models/clustering_diarizer.py b/nemo/collections/asr/models/clustering_diarizer.py index 93913a43c1b56..98e56a7be48dc 100644 --- a/nemo/collections/asr/models/clustering_diarizer.py +++ b/nemo/collections/asr/models/clustering_diarizer.py @@ -392,13 +392,6 @@ def _extract_embeddings(self, manifest_file: str, scale_idx: int, num_scales: in pkl.dump(self.embeddings, open(self._embeddings_file, 'wb')) logging.info("Saved embedding files to {}".format(embedding_dir)) - def path2audio_files_to_manifest(self, paths2audio_files, manifest_filepath): - with open(manifest_filepath, 'w', encoding='utf-8') as fp: - for audio_file in paths2audio_files: - audio_file = audio_file.strip() - entry = {'audio_filepath': audio_file, 'offset': 0.0, 'duration': None, 'text': '-', 'label': 'infer'} - fp.write(json.dumps(entry) + '\n') - def diarize(self, paths2audio_files: List[str] = None, batch_size: int = 0): """ Diarize files provided through paths2audio_files or manifest file diff --git a/nemo/collections/asr/models/confidence_ensemble.py b/nemo/collections/asr/models/confidence_ensemble.py index dcbb0a05976cc..9ae3bc3fbb5d4 100644 --- a/nemo/collections/asr/models/confidence_ensemble.py +++ b/nemo/collections/asr/models/confidence_ensemble.py @@ -23,13 +23,13 @@ from nemo.collections.asr.models.asr_model import ASRModel from nemo.collections.asr.models.hybrid_rnnt_ctc_models import EncDecHybridRNNTCTCModel +from nemo.collections.asr.parts.preprocessing.segment import ChannelSelectorType from nemo.collections.asr.parts.utils.asr_confidence_utils import ( ConfidenceConfig, ConfidenceMethodConfig, get_confidence_aggregation_bank, get_confidence_measure_bank, ) -from nemo.collections.asr.parts.utils.audio_utils import ChannelSelectorType from nemo.collections.asr.parts.utils.rnnt_utils import Hypothesis from nemo.core.classes import ModelPT from nemo.utils import model_utils @@ -62,7 +62,10 @@ def to_confidence_config(self) -> ConfidenceConfig: exclude_blank=self.exclude_blank, aggregation=self.aggregation, method_cfg=ConfidenceMethodConfig( - name=name, entropy_type=entropy_type, alpha=self.alpha, entropy_norm=entropy_norm, + name=name, + entropy_type=entropy_type, + alpha=self.alpha, + entropy_norm=entropy_norm, ), ) @@ -159,7 +162,9 @@ class ConfidenceEnsembleModel(ModelPT): """ def __init__( - self, cfg: DictConfig, trainer: 'Trainer' = None, + self, + cfg: DictConfig, + trainer: 'Trainer' = None, ): super().__init__(cfg=cfg, trainer=trainer) @@ -180,7 +185,9 @@ def __init__( model_cfg = self.cfg[cfg_field] model_class = model_utils.import_class_by_path(model_cfg['target']) self.register_nemo_submodule( - name=cfg_field, config_field=cfg_field, model=model_class(model_cfg, trainer=trainer), + name=cfg_field, + config_field=cfg_field, + model=model_class(model_cfg, trainer=trainer), ) else: self.num_models = len(cfg.load_models) @@ -196,7 +203,9 @@ def __init__( ) else: self.register_nemo_submodule( - cfg_field, config_field=cfg_field, model=ASRModel.from_pretrained(model, map_location="cpu"), + cfg_field, + config_field=cfg_field, + model=ASRModel.from_pretrained(model, map_location="cpu"), ) # registering model selection block - this is expected to be a joblib-saved diff --git a/nemo/collections/asr/models/configs/classification_models_config.py b/nemo/collections/asr/models/configs/classification_models_config.py index 33408f591c8e7..76c6022e22e2d 100644 --- a/nemo/collections/asr/models/configs/classification_models_config.py +++ b/nemo/collections/asr/models/configs/classification_models_config.py @@ -13,7 +13,7 @@ # limitations under the License. from dataclasses import dataclass, field -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Optional, Union from omegaconf import MISSING @@ -46,6 +46,7 @@ class EncDecClassificationDatasetConfig(nemo.core.classes.dataset.DatasetConfig) max_duration: Optional[float] = None min_duration: Optional[float] = None cal_labels_occurrence: Optional[bool] = False + channel_selector: Optional[Union[str, int, List[int]]] = None # VAD Optional vad_stream: Optional[bool] = None diff --git a/nemo/collections/asr/models/ctc_models.py b/nemo/collections/asr/models/ctc_models.py index 093419c3ca0ca..b6d8945b6c6b0 100644 --- a/nemo/collections/asr/models/ctc_models.py +++ b/nemo/collections/asr/models/ctc_models.py @@ -34,9 +34,9 @@ from nemo.collections.asr.models.asr_model import ASRModel, ExportableEncDecModel from nemo.collections.asr.parts.mixins import ASRModuleMixin, ASRTranscriptionMixin, InterCTCMixin, TranscribeConfig from nemo.collections.asr.parts.mixins.transcription import GenericTranscriptionType, TranscriptionReturnType +from nemo.collections.asr.parts.preprocessing.segment import ChannelSelectorType from nemo.collections.asr.parts.submodules.ctc_decoding import CTCDecoding, CTCDecodingConfig from nemo.collections.asr.parts.utils.asr_batching import get_semi_sorted_batch_sampler -from nemo.collections.asr.parts.utils.audio_utils import ChannelSelectorType from nemo.collections.common.data.lhotse import get_lhotse_dataloader_from_config from nemo.collections.common.parts.preprocessing.parsers import make_parser from nemo.core.classes.common import PretrainedModelInfo, typecheck @@ -879,6 +879,10 @@ def list_available_models(cls) -> List[PretrainedModelInfo]: return results + @property + def adapter_module_names(self) -> List[str]: + return ['', 'encoder', 'decoder'] + @property def wer(self): return self._wer diff --git a/nemo/collections/asr/models/hybrid_rnnt_ctc_models.py b/nemo/collections/asr/models/hybrid_rnnt_ctc_models.py index 9a5c4188aebd2..c7c09739be647 100644 --- a/nemo/collections/asr/models/hybrid_rnnt_ctc_models.py +++ b/nemo/collections/asr/models/hybrid_rnnt_ctc_models.py @@ -29,8 +29,8 @@ from nemo.collections.asr.models.rnnt_models import EncDecRNNTModel from nemo.collections.asr.parts.mixins import ASRBPEMixin, InterCTCMixin, TranscribeConfig from nemo.collections.asr.parts.mixins.transcription import TranscriptionReturnType +from nemo.collections.asr.parts.preprocessing.segment import ChannelSelectorType from nemo.collections.asr.parts.submodules.ctc_decoding import CTCDecoding, CTCDecodingConfig -from nemo.collections.asr.parts.utils.audio_utils import ChannelSelectorType from nemo.core.classes.common import PretrainedModelInfo from nemo.core.classes.mixins import AccessMixin from nemo.utils import logging, model_utils diff --git a/nemo/collections/asr/models/label_models.py b/nemo/collections/asr/models/label_models.py index 071c53417ae2f..62cf2e4608d0a 100644 --- a/nemo/collections/asr/models/label_models.py +++ b/nemo/collections/asr/models/label_models.py @@ -13,6 +13,8 @@ # limitations under the License. import copy import itertools +import os +import tempfile from collections import Counter from math import ceil from typing import Dict, List, Optional, Union @@ -34,6 +36,7 @@ ) from nemo.collections.asr.data.audio_to_text_dataset import convert_to_config_list from nemo.collections.asr.models.asr_model import ExportableEncDecModel +from nemo.collections.asr.parts.mixins.mixins import VerificationMixin from nemo.collections.asr.parts.preprocessing.features import WaveformFeaturizer from nemo.collections.asr.parts.preprocessing.perturb import process_augmentations from nemo.collections.common.metrics import TopKClassificationAccuracy @@ -46,7 +49,7 @@ __all__ = ['EncDecSpeakerLabelModel'] -class EncDecSpeakerLabelModel(ModelPT, ExportableEncDecModel): +class EncDecSpeakerLabelModel(ModelPT, ExportableEncDecModel, VerificationMixin): """ Encoder decoder class for speaker label models. Model class creates training, validation methods for setting up data @@ -242,6 +245,7 @@ def __setup_dataloader_from_config(self, config: Optional[Dict]): max_duration=config.get('max_duration', None), min_duration=config.get('min_duration', None), trim=config.get('trim_silence', False), + channel_selector=config.get('channel_selector', None), normalize_audio=config.get('normalize_audio', False), cal_labels_occurrence=config.get('cal_labels_occurrence', False), ) @@ -333,8 +337,8 @@ def output_types(self) -> Optional[Dict[str, NeuralType]]: "embs": NeuralType(('B', 'D'), AcousticEncodedRepresentation()), } - def forward_for_export(self, processed_signal, processed_signal_len): - encoded, length = self.encoder(audio_signal=processed_signal, length=processed_signal_len) + def forward_for_export(self, audio_signal, length): + encoded, length = self.encoder(audio_signal=audio_signal, length=length) logits, embs = self.decoder(encoder_output=encoded, length=length) return logits, embs @@ -583,6 +587,7 @@ def verify_speakers(self, path2audio_file1, path2audio_file2, threshold=0.7): # Score similarity_score = torch.dot(X, Y) / ((torch.dot(X, X) * torch.dot(Y, Y)) ** 0.5) similarity_score = (similarity_score + 1) / 2 + # Decision if similarity_score >= threshold: logging.info(" two audio files are from same speaker") @@ -591,6 +596,58 @@ def verify_speakers(self, path2audio_file1, path2audio_file2, threshold=0.7): logging.info(" two audio files are from different speakers") return False + @torch.no_grad() + def verify_speakers_batch(self, audio_files_pairs, threshold=0.7, batch_size=32, sample_rate=16000, device='cuda'): + """ + Verify if audio files from the first and second manifests are from the same speaker or not. + + Args: + audio_files_pairs: list of tuples with audio_files pairs to be verified + threshold: cosine similarity score used as a threshold to distinguish two embeddings (default = 0.7) + batch_size: batch size to perform batch inference + sample_rate: sample rate of audio files in manifest file + device: compute device to perform operations. + + Returns: + True if both audio pair is from same speaker, False otherwise + """ + + if type(audio_files_pairs) is list: + tmp_dir = tempfile.TemporaryDirectory() + manifest_filepath1 = os.path.join(tmp_dir.name, 'tmp_manifest1.json') + manifest_filepath2 = os.path.join(tmp_dir.name, 'tmp_manifest2.json') + self.path2audio_files_to_manifest([p[0] for p in audio_files_pairs], manifest_filepath1) + self.path2audio_files_to_manifest([p[1] for p in audio_files_pairs], manifest_filepath2) + else: + raise ValueError("audio_files_pairs must be of type list of tuples containing a pair of audio files") + + embs1, _, _, _ = self.batch_inference( + manifest_filepath1, batch_size=batch_size, sample_rate=sample_rate, device=device + ) + embs2, _, _, _ = self.batch_inference( + manifest_filepath2, batch_size=batch_size, sample_rate=sample_rate, device=device + ) + + embs1 = torch.Tensor(embs1).to(device) + embs2 = torch.Tensor(embs2).to(device) + # Length Normalize + embs1 = torch.div(embs1, torch.linalg.norm(embs1, dim=1).unsqueeze(dim=1)) + embs2 = torch.div(embs2, torch.linalg.norm(embs2, dim=1).unsqueeze(dim=1)) + + X = embs1.unsqueeze(dim=1) + Y = embs2.unsqueeze(dim=2) + # Score + similarity_scores = torch.matmul(X, Y).squeeze() / ( + (torch.matmul(X, X.permute(0, 2, 1)).squeeze() * torch.matmul(Y.permute(0, 2, 1), Y).squeeze()) ** 0.5 + ) + similarity_scores = (similarity_scores + 1) / 2 + + # Decision + decision = similarity_scores >= threshold + + tmp_dir.cleanup() + return decision.cpu().numpy() + @torch.no_grad() def batch_inference(self, manifest_filepath, batch_size=32, sample_rate=16000, device='cuda'): """ @@ -623,15 +680,15 @@ def batch_inference(self, manifest_filepath, batch_size=32, sample_rate=16000, d if trained_labels is not None: trained_labels = list(trained_labels) - featurizer = WaveformFeaturizer(sample_rate=sample_rate) - - dataset = AudioToSpeechLabelDataset(manifest_filepath=manifest_filepath, labels=None, featurizer=featurizer) - - dataloader = torch.utils.data.DataLoader( - dataset=dataset, - batch_size=batch_size, - collate_fn=dataset.fixed_seq_collate_fn, - ) + dl_config = { + 'manifest_filepath': manifest_filepath, + 'sample_rate': sample_rate, + 'channel_selector': 0, + 'batch_size': batch_size, + } + self.labels = self.extract_labels(dl_config) + dl_config['labels'] = self.labels + dataloader = self.__setup_dataloader_from_config(config=dl_config) logits = [] embs = [] @@ -647,7 +704,7 @@ def batch_inference(self, manifest_filepath, batch_size=32, sample_rate=16000, d gt_labels.extend(labels.cpu().numpy()) embs.extend(emb.cpu().numpy()) - gt_labels = list(map(lambda t: dataset.id2label[t], gt_labels)) + gt_labels = list(map(lambda t: dataloader.dataset.id2label[t], gt_labels)) self.train(mode=mode) if mode is True: diff --git a/nemo/collections/asr/models/msdd_models.py b/nemo/collections/asr/models/msdd_models.py index 01926eb4ae792..c88275dcacd34 100644 --- a/nemo/collections/asr/models/msdd_models.py +++ b/nemo/collections/asr/models/msdd_models.py @@ -163,8 +163,7 @@ def add_speaker_model_config(self, cfg): del cfg.speaker_model_cfg.validation_ds def _init_segmentation_info(self): - """Initialize segmentation settings: window, shift and multiscale weights. - """ + """Initialize segmentation settings: window, shift and multiscale weights.""" self._diarizer_params = self.cfg_msdd_model.diarizer self.multiscale_args_dict = parse_scale_configs( self._diarizer_params.speaker_embeddings.parameters.window_length_in_sec, @@ -275,10 +274,14 @@ def __setup_dataloader_from_config_infer( ) def setup_training_data(self, train_data_config: Optional[Union[DictConfig, Dict]]): - self._train_dl = self.__setup_dataloader_from_config(config=train_data_config,) + self._train_dl = self.__setup_dataloader_from_config( + config=train_data_config, + ) def setup_validation_data(self, val_data_layer_config: Optional[Union[DictConfig, Dict]]): - self._validation_dl = self.__setup_dataloader_from_config(config=val_data_layer_config,) + self._validation_dl = self.__setup_dataloader_from_config( + config=val_data_layer_config, + ) def setup_test_data(self, test_data_config: Optional[Union[DictConfig, Dict]]): if self.pairwise_infer: @@ -338,32 +341,32 @@ def get_ms_emb_seq( Merged embeddings without zero-padding in the batch. See `ms_seg_counts` for details. Shape: (Total number of segments in the batch, emb_dim) scale_mapping (Tensor): - The element at the m-th row and the n-th column of the scale mapping matrix indicates the (m+1)-th scale - segment index which has the closest center distance with (n+1)-th segment in the base scale. - Example: - scale_mapping_argmat[2][101] = 85 - In the above example, it means that 86-th segment in the 3rd scale (python index is 2) is mapped with - 102-th segment in the base scale. Thus, the longer segments bound to have more repeating numbers since - multiple base scale segments (since the base scale has the shortest length) fall into the range of the - longer segments. At the same time, each row contains N numbers of indices where N is number of - segments in the base-scale (i.e., the finest scale). + The element at the m-th row and the n-th column of the scale mapping matrix indicates the (m+1)-th scale + segment index which has the closest center distance with (n+1)-th segment in the base scale. + Example: + scale_mapping_argmat[2][101] = 85 + In the above example, it means that 86-th segment in the 3rd scale (python index is 2) is mapped with + 102-th segment in the base scale. Thus, the longer segments bound to have more repeating numbers since + multiple base scale segments (since the base scale has the shortest length) fall into the range of the + longer segments. At the same time, each row contains N numbers of indices where N is number of + segments in the base-scale (i.e., the finest scale). Shape: (batch_size, scale_n, self.diar_window_length) ms_seg_counts (Tensor): Cumulative sum of the number of segments in each scale. This information is needed to reconstruct the multi-scale input matrix during forward propagating. - Example: `batch_size=3, scale_n=6, emb_dim=192` - ms_seg_counts = - [[8, 9, 12, 16, 25, 51], - [11, 13, 14, 17, 25, 51], - [ 9, 9, 11, 16, 23, 50]] + Example: `batch_size=3, scale_n=6, emb_dim=192` + ms_seg_counts = + [[8, 9, 12, 16, 25, 51], + [11, 13, 14, 17, 25, 51], + [ 9, 9, 11, 16, 23, 50]] - In this function, `ms_seg_counts` is used to get the actual length of each embedding sequence without - zero-padding. + In this function, `ms_seg_counts` is used to get the actual length of each embedding sequence without + zero-padding. Returns: ms_emb_seq (Tensor): - Multi-scale embedding sequence that is mapped, matched and repeated. The longer scales are less repeated, + Multi-scale embedding sequence that is mapped, matched and repeated. The longer scales are less repeated, while shorter scales are more frequently repeated following the scale mapping tensor. """ scale_n, batch_size = scale_mapping[0].shape[0], scale_mapping.shape[0] @@ -409,9 +412,9 @@ def get_cluster_avg_embs_model( [ 9, 9, 11, 16, 23, 50] ] - Counts of merged segments: (121, 131, 118) - embs has shape of (370, 192) - clus_label_index has shape of (3, 131) + Counts of merged segments: (121, 131, 118) + embs has shape of (370, 192) + clus_label_index has shape of (3, 131) Shape: (batch_size, scale_n) @@ -553,7 +556,7 @@ def forward( with torch.no_grad(): self.msdd._speaker_model.eval() logits, embs_d = self.msdd._speaker_model.forward_for_export( - processed_signal=audio_signal[detach_ids[1]], processed_signal_len=audio_signal_len[detach_ids[1]] + audio_signal=audio_signal[detach_ids[1]], length=audio_signal_len[detach_ids[1]] ) embs = torch.zeros(audio_signal.shape[0], embs_d.shape[1]).to(embs_d.device) embs[detach_ids[1], :] = embs_d.detach() @@ -562,7 +565,7 @@ def forward( self.msdd._speaker_model.train() if len(detach_ids[0]) > 1: logits, embs_a = self.msdd._speaker_model.forward_for_export( - processed_signal=audio_signal[detach_ids[0]], processed_signal_len=audio_signal_len[detach_ids[0]] + audio_signal=audio_signal[detach_ids[0]], length=audio_signal_len[detach_ids[0]] ) embs[detach_ids[0], :] = embs_a @@ -854,9 +857,9 @@ def run_clustering_diarizer(self, manifest_filepath: str, emb_dir: str): os.makedirs(self.out_rttm_dir, exist_ok=True) self.clus_diar_model._cluster_params = self.cfg_diar_infer.diarizer.clustering.parameters - self.clus_diar_model.multiscale_args_dict[ - "multiscale_weights" - ] = self.cfg_diar_infer.diarizer.speaker_embeddings.parameters.multiscale_weights + self.clus_diar_model.multiscale_args_dict["multiscale_weights"] = ( + self.cfg_diar_infer.diarizer.speaker_embeddings.parameters.multiscale_weights + ) self.clus_diar_model._diarizer_params.speaker_embeddings.parameters = ( self.cfg_diar_infer.diarizer.speaker_embeddings.parameters ) @@ -1076,7 +1079,6 @@ def extract_standalone_speaker_model(self, prefix: str = 'msdd._speaker_model.') return _speaker_model def _init_msdd_model(self, cfg: Union[DictConfig, NeuralDiarizerInferenceConfig]): - """ Initialized MSDD model with the provided config. Load either from `.nemo` file or `.ckpt` checkpoint files. """ @@ -1128,7 +1130,7 @@ def get_pred_mat(self, data_list: List[Union[Tuple[int], List[torch.Tensor]]]) - digit_map = dict(zip(sorted(set(all_tups)), range(n_est_spks))) total_len = max([sess[1].shape[1] for sess in data_list]) sum_pred = torch.zeros(total_len, n_est_spks) - for (_dim_tup, pred_mat) in data_list: + for _dim_tup, pred_mat in data_list: dim_tup = [digit_map[x] for x in _dim_tup] if len(pred_mat.shape) == 3: pred_mat = pred_mat.squeeze(0) @@ -1167,8 +1169,7 @@ def get_integrated_preds_list( return output_list def get_emb_clus_infer(self, cluster_embeddings): - """Assign dictionaries containing the clustering results from the class instance `cluster_embeddings`. - """ + """Assign dictionaries containing the clustering results from the class instance `cluster_embeddings`.""" self.msdd_model.emb_sess_test_dict = cluster_embeddings.emb_sess_test_dict self.msdd_model.clus_test_label_dict = cluster_embeddings.clus_test_label_dict self.msdd_model.emb_seq_test = cluster_embeddings.emb_seq_test @@ -1456,7 +1457,10 @@ def from_pretrained( """ logging.setLevel(logging.INFO if verbose else logging.WARNING) cfg = NeuralDiarizerInferenceConfig.init_config( - diar_model_path=model_name, vad_model_path=vad_model_name, map_location=map_location, verbose=verbose, + diar_model_path=model_name, + vad_model_path=vad_model_name, + map_location=map_location, + verbose=verbose, ) return cls(cfg) diff --git a/nemo/collections/asr/models/rnnt_models.py b/nemo/collections/asr/models/rnnt_models.py index cb2505fbadbff..d58e4f7db8f2e 100644 --- a/nemo/collections/asr/models/rnnt_models.py +++ b/nemo/collections/asr/models/rnnt_models.py @@ -37,9 +37,9 @@ TranscribeConfig, TranscriptionReturnType, ) +from nemo.collections.asr.parts.preprocessing.segment import ChannelSelectorType from nemo.collections.asr.parts.submodules.rnnt_decoding import RNNTDecoding, RNNTDecodingConfig from nemo.collections.asr.parts.utils.asr_batching import get_semi_sorted_batch_sampler -from nemo.collections.asr.parts.utils.audio_utils import ChannelSelectorType from nemo.collections.common.data.lhotse import get_lhotse_dataloader_from_config from nemo.collections.common.parts.preprocessing.parsers import make_parser from nemo.core.classes.common import PretrainedModelInfo, typecheck diff --git a/nemo/collections/asr/models/transformer_bpe_models.py b/nemo/collections/asr/models/transformer_bpe_models.py index e7e67f8fbb2f4..79de83f1d4a19 100644 --- a/nemo/collections/asr/models/transformer_bpe_models.py +++ b/nemo/collections/asr/models/transformer_bpe_models.py @@ -38,8 +38,8 @@ get_nemo_transformer, ) from nemo.collections.asr.parts.mixins import ASRBPEMixin, ASRTranscriptionMixin, TranscribeConfig +from nemo.collections.asr.parts.preprocessing.segment import ChannelSelectorType from nemo.collections.asr.parts.submodules.token_classifier import TokenClassifier -from nemo.collections.asr.parts.utils.audio_utils import ChannelSelectorType from nemo.collections.asr.parts.utils.rnnt_utils import Hypothesis from nemo.collections.common.data.lhotse import get_lhotse_dataloader_from_config from nemo.collections.common.losses import SmoothedCrossEntropyLoss diff --git a/nemo/collections/asr/modules/__init__.py b/nemo/collections/asr/modules/__init__.py index 0265d9e306878..a412040a3b67c 100644 --- a/nemo/collections/asr/modules/__init__.py +++ b/nemo/collections/asr/modules/__init__.py @@ -12,20 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. -from nemo.collections.asr.modules.audio_modules import ( - MaskBasedBeamformer, - MaskEstimatorFlexChannels, - MaskEstimatorRNN, - MaskReferenceChannel, -) from nemo.collections.asr.modules.audio_preprocessing import ( AudioToMelSpectrogramPreprocessor, AudioToMFCCPreprocessor, - AudioToSpectrogram, CropOrPadSpectrogramAugmentation, MaskedPatchAugmentation, SpectrogramAugmentation, - SpectrogramToAudio, ) from nemo.collections.asr.modules.beam_search_decoder import BeamSearchDecoderWithLM from nemo.collections.asr.modules.conformer_encoder import ConformerEncoder, ConformerEncoderAdapter diff --git a/nemo/collections/asr/modules/audio_preprocessing.py b/nemo/collections/asr/modules/audio_preprocessing.py index 33143364ede1e..f567e3f5c8ffa 100644 --- a/nemo/collections/asr/modules/audio_preprocessing.py +++ b/nemo/collections/asr/modules/audio_preprocessing.py @@ -16,17 +16,13 @@ import random from abc import ABC, abstractmethod from dataclasses import dataclass -from typing import Any, Dict, Optional, Tuple +from typing import Any, Optional import torch from packaging import version from nemo.collections.asr.parts.numba.spec_augment import SpecAugmentNumba, spec_augment_launch_heuristics -from nemo.collections.asr.parts.preprocessing.features import ( - FilterbankFeatures, - FilterbankFeaturesTA, - make_seq_mask_like, -) +from nemo.collections.asr.parts.preprocessing.features import FilterbankFeatures, FilterbankFeaturesTA from nemo.collections.asr.parts.submodules.spectr_augment import SpecAugment, SpecCutout from nemo.core.classes import Exportable, NeuralModule, typecheck from nemo.core.neural_types import ( @@ -55,8 +51,6 @@ __all__ = [ 'AudioToMelSpectrogramPreprocessor', - 'AudioToSpectrogram', - 'SpectrogramToAudio', 'AudioToMFCCPreprocessor', 'SpectrogramAugmentation', 'MaskedPatchAugmentation', @@ -726,253 +720,6 @@ def restore_from(cls, restore_path: str): pass -class AudioToSpectrogram(NeuralModule): - """Transform a batch of input multi-channel signals into a batch of - STFT-based spectrograms. - - Args: - fft_length: length of FFT - hop_length: length of hops/shifts of the sliding window - power: exponent for magnitude spectrogram. Default `None` will - return a complex-valued spectrogram - magnitude_power: Transform magnitude of the spectrogram as x^magnitude_power. - scale: Positive scaling of the spectrogram. - """ - - def __init__(self, fft_length: int, hop_length: int, magnitude_power: float = 1.0, scale: float = 1.0): - if not HAVE_TORCHAUDIO: - logging.error('Could not import torchaudio. Some features might not work.') - - raise ModuleNotFoundError( - f"torchaudio is not installed but is necessary to instantiate a {self.__class__.__name__}" - ) - - super().__init__() - - # For now, assume FFT length is divisible by two - if fft_length % 2 != 0: - raise ValueError(f'fft_length = {fft_length} must be divisible by 2') - - self.stft = torchaudio.transforms.Spectrogram( - n_fft=fft_length, hop_length=hop_length, power=None, pad_mode='constant' - ) - - # number of subbands - self.F = fft_length // 2 + 1 - - if magnitude_power <= 0: - raise ValueError(f'Magnitude power needs to be positive: current value {magnitude_power}') - self.magnitude_power = magnitude_power - - if scale <= 0: - raise ValueError(f'Scale needs to be positive: current value {scale}') - self.scale = scale - - logging.debug('Initialized %s with:', self.__class__.__name__) - logging.debug('\tfft_length: %s', fft_length) - logging.debug('\thop_length: %s', hop_length) - logging.debug('\tmagnitude_power: %s', magnitude_power) - logging.debug('\tscale: %s', scale) - - @property - def num_subbands(self) -> int: - return self.F - - @property - def input_types(self) -> Dict[str, NeuralType]: - """Returns definitions of module output ports.""" - return { - "input": NeuralType(('B', 'C', 'T'), AudioSignal()), - "input_length": NeuralType(('B',), LengthsType(), optional=True), - } - - @property - def output_types(self) -> Dict[str, NeuralType]: - """Returns definitions of module output ports.""" - return { - "output": NeuralType(('B', 'C', 'D', 'T'), SpectrogramType()), - "output_length": NeuralType(('B',), LengthsType()), - } - - @typecheck() - def forward( - self, input: torch.Tensor, input_length: Optional[torch.Tensor] = None - ) -> Tuple[torch.Tensor, torch.Tensor]: - """Convert a batch of C-channel input signals - into a batch of complex-valued spectrograms. - - Args: - input: Time-domain input signal with C channels, shape (B, C, T) - input_length: Length of valid entries along the time dimension, shape (B,) - - Returns: - Output spectrogram with F subbands and N time frames, shape (B, C, F, N) - and output length with shape (B,). - """ - B, T = input.size(0), input.size(-1) - input = input.view(B, -1, T) - - # STFT output (B, C, F, N) - with torch.cuda.amp.autocast(enabled=False): - output = self.stft(input.float()) - - if self.magnitude_power != 1: - # apply power on the magnitude - output = torch.pow(output.abs(), self.magnitude_power) * torch.exp(1j * output.angle()) - - if self.scale != 1: - # apply scaling of the coefficients - output = self.scale * output - - if input_length is not None: - # Mask padded frames - output_length = self.get_output_length(input_length=input_length) - - length_mask: torch.Tensor = make_seq_mask_like( - lengths=output_length, like=output, time_dim=-1, valid_ones=False - ) - output = output.masked_fill(length_mask, 0.0) - else: - # Assume all frames are valid for all examples in the batch - output_length = output.size(-1) * torch.ones(B, device=output.device).long() - - return output, output_length - - def get_output_length(self, input_length: torch.Tensor) -> torch.Tensor: - """Get length of valid frames for the output. - - Args: - input_length: number of valid samples, shape (B,) - - Returns: - Number of valid frames, shape (B,) - """ - output_length = input_length.div(self.stft.hop_length, rounding_mode='floor').add(1).long() - return output_length - - -class SpectrogramToAudio(NeuralModule): - """Transform a batch of input multi-channel spectrograms into a batch of - time-domain multi-channel signals. - - Args: - fft_length: length of FFT - hop_length: length of hops/shifts of the sliding window - magnitude_power: Transform magnitude of the spectrogram as x^(1/magnitude_power). - scale: Spectrogram will be scaled with 1/scale before the inverse transform. - """ - - def __init__(self, fft_length: int, hop_length: int, magnitude_power: float = 1.0, scale: float = 1.0): - if not HAVE_TORCHAUDIO: - logging.error('Could not import torchaudio. Some features might not work.') - - raise ModuleNotFoundError( - f"torchaudio is not installed but is necessary to instantiate a {self.__class__.__name__}" - ) - - super().__init__() - - # For now, assume FFT length is divisible by two - if fft_length % 2 != 0: - raise ValueError(f'fft_length = {fft_length} must be divisible by 2') - - self.istft = torchaudio.transforms.InverseSpectrogram( - n_fft=fft_length, hop_length=hop_length, pad_mode='constant' - ) - - self.F = fft_length // 2 + 1 - - if magnitude_power <= 0: - raise ValueError(f'Magnitude power needs to be positive: current value {magnitude_power}') - self.magnitude_power = magnitude_power - - if scale <= 0: - raise ValueError(f'Scale needs to be positive: current value {scale}') - self.scale = scale - - logging.debug('Initialized %s with:', self.__class__.__name__) - logging.debug('\tfft_length: %s', fft_length) - logging.debug('\thop_length: %s', hop_length) - logging.debug('\tmagnitude_power: %s', magnitude_power) - logging.debug('\tscale: %s', scale) - - @property - def num_subbands(self) -> int: - return self.F - - @property - def input_types(self) -> Dict[str, NeuralType]: - """Returns definitions of module output ports.""" - return { - "input": NeuralType(('B', 'C', 'D', 'T'), SpectrogramType()), - "input_length": NeuralType(('B',), LengthsType(), optional=True), - } - - @property - def output_types(self) -> Dict[str, NeuralType]: - """Returns definitions of module output ports.""" - return { - "output": NeuralType(('B', 'C', 'T'), AudioSignal()), - "output_length": NeuralType(('B',), LengthsType()), - } - - @typecheck() - def forward(self, input: torch.Tensor, input_length: Optional[torch.Tensor] = None) -> torch.Tensor: - """Convert input complex-valued spectrogram to a time-domain - signal. Multi-channel IO is supported. - - Args: - input: Input spectrogram for C channels, shape (B, C, F, N) - input_length: Length of valid entries along the time dimension, shape (B,) - - Returns: - Time-domain signal with T time-domain samples and C channels, (B, C, T) - and output length with shape (B,). - """ - B, F, N = input.size(0), input.size(-2), input.size(-1) - assert F == self.F, f'Number of subbands F={F} not matching self.F={self.F}' - input = input.view(B, -1, F, N) - - # iSTFT output (B, C, T) - with torch.cuda.amp.autocast(enabled=False): - output = input.cfloat() - - if self.scale != 1: - # apply 1/scale on the coefficients - output = output / self.scale - - if self.magnitude_power != 1: - # apply 1/power on the magnitude - output = torch.pow(output.abs(), 1 / self.magnitude_power) * torch.exp(1j * output.angle()) - output = self.istft(output) - - if input_length is not None: - # Mask padded samples - output_length = self.get_output_length(input_length=input_length) - - length_mask: torch.Tensor = make_seq_mask_like( - lengths=output_length, like=output, time_dim=-1, valid_ones=False - ) - output = output.masked_fill(length_mask, 0.0) - else: - # Assume all frames are valid for all examples in the batch - output_length = output.size(-1) * torch.ones(B, device=output.device).long() - - return output, output_length - - def get_output_length(self, input_length: torch.Tensor) -> torch.Tensor: - """Get length of valid samples for the output. - - Args: - input_length: number of valid frames, shape (B,) - - Returns: - Number of valid samples, shape (B,) - """ - output_length = input_length.sub(1).mul(self.istft.hop_length).long() - return output_length - - @dataclass class AudioToMelSpectrogramPreprocessorConfig: _target_: str = "nemo.collections.asr.modules.AudioToMelSpectrogramPreprocessor" diff --git a/nemo/collections/asr/modules/conformer_encoder.py b/nemo/collections/asr/modules/conformer_encoder.py index d723ce85d2ce7..245404a7601cd 100644 --- a/nemo/collections/asr/modules/conformer_encoder.py +++ b/nemo/collections/asr/modules/conformer_encoder.py @@ -501,6 +501,7 @@ def streaming_post_process(self, rets, keep_all_outputs=True): def forward( self, audio_signal, length, cache_last_channel=None, cache_last_time=None, cache_last_channel_len=None ): + self.update_max_seq_length(seq_length=audio_signal.size(2), device=audio_signal.device) return self.forward_internal( audio_signal, length, @@ -512,8 +513,6 @@ def forward( def forward_internal( self, audio_signal, length, cache_last_channel=None, cache_last_time=None, cache_last_channel_len=None ): - self.update_max_seq_length(seq_length=audio_signal.size(2), device=audio_signal.device) - if length is None: length = audio_signal.new_full( (audio_signal.size(0),), audio_signal.size(-1), dtype=torch.int64, device=audio_signal.device diff --git a/nemo/collections/asr/modules/transformer/transformer.py b/nemo/collections/asr/modules/transformer/transformer.py index 718448aa1c7c6..0ea376340d186 100644 --- a/nemo/collections/asr/modules/transformer/transformer.py +++ b/nemo/collections/asr/modules/transformer/transformer.py @@ -13,18 +13,21 @@ # limitations under the License. from dataclasses import dataclass -from typing import Dict, Optional +from typing import Dict, List, Optional import torch -from omegaconf.omegaconf import MISSING +from omegaconf.omegaconf import MISSING, DictConfig from nemo.collections.asr.modules.transformer.decoder_module import DecoderModule from nemo.collections.asr.modules.transformer.encoder_module import EncoderModule -from nemo.collections.asr.modules.transformer.transformer_decoders import TransformerDecoder +from nemo.collections.asr.modules.transformer.transformer_decoders import TransformerDecoder, TransformerDecoderAdapter from nemo.collections.asr.modules.transformer.transformer_encoders import TransformerEncoder from nemo.collections.asr.modules.transformer.transformer_modules import TransformerEmbedding +from nemo.collections.asr.parts.submodules.adapters.attention_adapter_mixin import AttentionAdapterModuleMixin +from nemo.collections.asr.parts.utils import adapter_utils from nemo.core.classes.common import typecheck from nemo.core.classes.exportable import Exportable +from nemo.core.classes.mixins import adapter_mixins from nemo.core.neural_types import ChannelType, NeuralType @@ -155,6 +158,8 @@ def input_example(self, max_batch=1, max_dim=256): class TransformerDecoderNM(DecoderModule, Exportable): + DECODER_TYPE: type = TransformerDecoder + def __init__( self, vocab_size: int, @@ -192,7 +197,7 @@ def __init__( learn_positional_encodings=learn_positional_encodings, ) - self._decoder = TransformerDecoder( + self._decoder = self.DECODER_TYPE( hidden_size=self.hidden_size, num_layers=num_layers, inner_size=inner_size, @@ -207,7 +212,12 @@ def __init__( @typecheck() def forward( - self, input_ids, decoder_mask, encoder_embeddings, encoder_mask, decoder_mems=None, + self, + input_ids, + decoder_mask, + encoder_embeddings, + encoder_mask, + decoder_mems=None, ): start_pos = 0 if decoder_mems is not None: @@ -274,3 +284,36 @@ def output_types(self) -> Optional[Dict[str, NeuralType]]: return {"last_hidden_states": NeuralType(('B', 'D', 'T', 'D'), ChannelType())} else: return {"last_hidden_states": NeuralType(('B', 'T', 'D'), ChannelType())} + + +class TransformerDecoderNMAdapter(TransformerDecoderNM, adapter_mixins.AdapterModuleMixin): + DECODER_TYPE: type = TransformerDecoderAdapter + + # Higher level forwarding + def add_adapter(self, name: str, cfg: dict): + cfg = self._update_adapter_cfg_input_dim(cfg) + self._decoder.add_adapter(name, cfg) # type: adapter_mixins.AdapterModuleMixin + + def is_adapter_available(self) -> bool: + return self._decoder.is_adapter_available() # type: adapter_mixins.AdapterModuleMixin + + def set_enabled_adapters(self, name: Optional[str] = None, enabled: bool = True): + self._decoder.set_enabled_adapters(name=name, enabled=enabled) # # type: adapter_mixins.AdapterModuleMixin + + def get_enabled_adapters(self) -> List[str]: + names = set([]) + names.update(self._decoder.get_enabled_adapters()) # type: adapter_mixins.AdapterModuleMixin + + names = sorted(list(names)) + return names + + def _update_adapter_cfg_input_dim(self, cfg: DictConfig): + cfg = adapter_utils.update_adapter_cfg_input_dim(self, cfg, module_dim=self._hidden_size) + return cfg + + +""" +Register any additional information +""" +if adapter_mixins.get_registered_adapter(TransformerDecoderNM) is None: + adapter_mixins.register_adapter(base_class=TransformerDecoderNM, adapter_class=TransformerDecoderNMAdapter) diff --git a/nemo/collections/asr/modules/transformer/transformer_decoders.py b/nemo/collections/asr/modules/transformer/transformer_decoders.py index a5b2c299393cf..30c6179b85a6a 100644 --- a/nemo/collections/asr/modules/transformer/transformer_decoders.py +++ b/nemo/collections/asr/modules/transformer/transformer_decoders.py @@ -13,17 +13,22 @@ # limitations under the License. import copy +from typing import List, Optional, Set import torch import torch.nn as nn +from omegaconf import DictConfig from nemo.collections.asr.modules.transformer.transformer_modules import MultiHeadAttention, PositionWiseFF +from nemo.collections.asr.parts.submodules.adapters.attention_adapter_mixin import AttentionAdapterModuleMixin +from nemo.collections.asr.parts.utils import adapter_utils from nemo.collections.common.parts import form_attention_mask +from nemo.core.classes.mixins import adapter_mixins __all__ = ["TransformerDecoder"] -class TransformerDecoderBlock(nn.Module): +class TransformerDecoderBlock(nn.Module, AttentionAdapterModuleMixin): """ Building block of Transformer decoder. @@ -63,6 +68,9 @@ def __init__( self.layer_norm_3 = nn.LayerNorm(hidden_size, eps=1e-5) self.third_sub_layer = PositionWiseFF(hidden_size, inner_size, ffn_dropout, hidden_act) + # Information for the adapter module mixin + self.self_attention_model = "transf_abs" + def forward_preln(self, decoder_query, decoder_mask, decoder_keys, encoder_states, encoder_mask): """ Pre-LayerNorm block @@ -74,6 +82,17 @@ def forward_preln(self, decoder_query, decoder_mask, decoder_keys, encoder_state self_attn_output = self.first_sub_layer(decoder_query, decoder_keys, decoder_keys, decoder_mask) self_attn_output += residual + if self.is_adapter_available(): + # Call the MHA adapters + pack_input = { + 'x': self_attn_output, + 'loc': 'mha', + 'att_mask': decoder_mask, + 'pos_emb': None, + } + pack_input = self.forward_enabled_adapters(pack_input) + self_attn_output = pack_input['x'] + residual = self_attn_output self_attn_output = self.layer_norm_2(self_attn_output) enc_dec_attn_output = self.second_sub_layer(self_attn_output, encoder_states, encoder_states, encoder_mask) @@ -84,6 +103,15 @@ def forward_preln(self, decoder_query, decoder_mask, decoder_keys, encoder_state output_states = self.third_sub_layer(enc_dec_attn_output) output_states += residual + if self.is_adapter_available(): + # Call the Linear adapters + pack_input = { + 'x': output_states, + 'loc': 'post', + } + pack_input = self.forward_enabled_adapters(pack_input) + output_states = pack_input['x'] + return output_states def forward_postln(self, decoder_query, decoder_mask, decoder_keys, encoder_states, encoder_mask): @@ -93,6 +121,18 @@ def forward_postln(self, decoder_query, decoder_mask, decoder_keys, encoder_stat """ self_attn_output = self.first_sub_layer(decoder_query, decoder_keys, decoder_keys, decoder_mask) self_attn_output += decoder_query + + if self.is_adapter_available(): + # Call the MHA adapters + pack_ip = { + 'x': self_attn_output, + 'loc': 'mha', + 'att_mask': decoder_mask, + 'pos_emb': None, + } + pack_ip = self.forward_enabled_adapters(pack_ip) + self_attn_output = pack_ip['x'] + self_attn_output = self.layer_norm_1(self_attn_output) enc_dec_attn_output = self.second_sub_layer(self_attn_output, encoder_states, encoder_states, encoder_mask) @@ -101,6 +141,16 @@ def forward_postln(self, decoder_query, decoder_mask, decoder_keys, encoder_stat output_states = self.third_sub_layer(enc_dec_attn_output) output_states += enc_dec_attn_output + + if self.is_adapter_available(): + # Call the linear adapters + pack_ip = { + 'x': output_states, + 'loc': 'post', + } + pack_ip = self.forward_enabled_adapters(pack_ip) + output_states = pack_ip['x'] + return self.layer_norm_3(output_states) def forward(self, decoder_query, decoder_mask, decoder_keys, encoder_states, encoder_mask): @@ -109,6 +159,19 @@ def forward(self, decoder_query, decoder_mask, decoder_keys, encoder_states, enc else: return self.forward_postln(decoder_query, decoder_mask, decoder_keys, encoder_states, encoder_mask) + def get_accepted_adapter_types(self) -> Set[type]: + types = super().get_accepted_adapter_types() + + if len(types) == 0: + self.set_accepted_adapter_types( + [ + adapter_utils.LINEAR_ADAPTER_CLASSPATH, + adapter_utils.TRANSFORMER_MHA_ADAPTER_CLASSPATH, + ] + ) + types = self.get_accepted_adapter_types() + return types + class TransformerDecoder(nn.Module): def __init__( @@ -131,6 +194,8 @@ def __init__( else: self.final_layer_norm = None + self.d_model = hidden_size + layer = TransformerDecoderBlock( hidden_size, inner_size, @@ -219,3 +284,38 @@ def input_example(self, max_batch=1, max_dim=256): input_ids = torch.randint(low=0, high=2048, size=(max_batch, max_dim, 1024), device=sample.device) encoder_mask = torch.randint(low=0, high=1, size=(max_batch, max_dim), device=sample.device) return tuple([input_ids, encoder_mask, input_ids, encoder_mask]) + + +class TransformerDecoderAdapter(TransformerDecoder, adapter_mixins.AdapterModuleMixin): + + # Higher level forwarding + def add_adapter(self, name: str, cfg: dict): + cfg = self._update_adapter_cfg_input_dim(cfg) + for transformer_layer in self.layers: # type: adapter_mixins.AdapterModuleMixin + transformer_layer.add_adapter(name, cfg) + + def is_adapter_available(self) -> bool: + return any([transformer_layer.is_adapter_available() for transformer_layer in self.layers]) + + def set_enabled_adapters(self, name: Optional[str] = None, enabled: bool = True): + for transformer_layer in self.layers: # type: adapter_mixins.AdapterModuleMixin + transformer_layer.set_enabled_adapters(name=name, enabled=enabled) + + def get_enabled_adapters(self) -> List[str]: + names = set([]) + for transformer_layer in self.layers: # type: adapter_mixins.AdapterModuleMixin + names.update(transformer_layer.get_enabled_adapters()) + + names = sorted(list(names)) + return names + + def _update_adapter_cfg_input_dim(self, cfg: DictConfig): + cfg = adapter_utils.update_adapter_cfg_input_dim(self, cfg, module_dim=self.d_model) + return cfg + + +""" +Register any additional information +""" +if adapter_mixins.get_registered_adapter(TransformerDecoder) is None: + adapter_mixins.register_adapter(base_class=TransformerDecoder, adapter_class=TransformerDecoderAdapter) diff --git a/nemo/collections/asr/modules/transformer/transformer_encoders.py b/nemo/collections/asr/modules/transformer/transformer_encoders.py index 544d561267cff..d3116db82482e 100644 --- a/nemo/collections/asr/modules/transformer/transformer_encoders.py +++ b/nemo/collections/asr/modules/transformer/transformer_encoders.py @@ -13,17 +13,22 @@ # limitations under the License. import copy +from typing import List, Optional, Set import torch import torch.nn as nn +from omegaconf import DictConfig from nemo.collections.asr.modules.transformer.transformer_modules import MultiHeadAttention, PositionWiseFF +from nemo.collections.asr.parts.submodules.adapters.attention_adapter_mixin import AttentionAdapterModuleMixin +from nemo.collections.asr.parts.utils import adapter_utils from nemo.collections.common.parts import form_attention_mask +from nemo.core.classes.mixins import adapter_mixins __all__ = ["TransformerEncoder"] -class TransformerEncoderBlock(nn.Module): +class TransformerEncoderBlock(nn.Module, AttentionAdapterModuleMixin): """ Building block of Transformer encoder. @@ -59,6 +64,9 @@ def __init__( self.layer_norm_2 = nn.LayerNorm(hidden_size, eps=1e-5) self.second_sub_layer = PositionWiseFF(hidden_size, inner_size, ffn_dropout, hidden_act) + # Information for the adapter module mixin + self.self_attention_model = "transf_abs" + def forward_preln(self, encoder_query, encoder_mask, encoder_keys): """ Pre-LayerNorm block @@ -70,11 +78,31 @@ def forward_preln(self, encoder_query, encoder_mask, encoder_keys): self_attn_output = self.first_sub_layer(encoder_query, encoder_keys, encoder_keys, encoder_mask) self_attn_output += residual + if self.is_adapter_available(): + # Call the MHA adapters + pack_input = { + 'x': self_attn_output, + 'loc': 'mha', + 'att_mask': encoder_mask, + 'pos_emb': None, + } + pack_input = self.forward_enabled_adapters(pack_input) + self_attn_output = pack_input['x'] + residual = self_attn_output self_attn_output = self.layer_norm_2(self_attn_output) output_states = self.second_sub_layer(self_attn_output) output_states += residual + if self.is_adapter_available(): + # Call the Linear adapters + pack_input = { + 'x': output_states, + 'loc': 'post', + } + pack_input = self.forward_enabled_adapters(pack_input) + output_states = pack_input['x'] + return output_states def forward_postln(self, encoder_query, encoder_mask, encoder_keys): @@ -84,10 +112,32 @@ def forward_postln(self, encoder_query, encoder_mask, encoder_keys): """ self_attn_output = self.first_sub_layer(encoder_query, encoder_keys, encoder_keys, encoder_mask) self_attn_output += encoder_query + + if self.is_adapter_available(): + # Call the MHA adapters + pack_ip = { + 'x': self_attn_output, + 'loc': 'mha', + 'att_mask': encoder_mask, + 'pos_emb': None, + } + pack_ip = self.forward_enabled_adapters(pack_ip) + self_attn_output = pack_ip['x'] + self_attn_output = self.layer_norm_1(self_attn_output) output_states = self.second_sub_layer(self_attn_output) output_states += self_attn_output + + if self.is_adapter_available(): + # Call the linear adapters + pack_ip = { + 'x': output_states, + 'loc': 'post', + } + pack_ip = self.forward_enabled_adapters(pack_ip) + output_states = pack_ip['x'] + output_states = self.layer_norm_2(output_states) return output_states @@ -98,6 +148,19 @@ def forward(self, encoder_query, encoder_mask, encoder_keys): else: return self.forward_postln(encoder_query, encoder_mask, encoder_keys) + def get_accepted_adapter_types(self) -> Set[type]: + types = super().get_accepted_adapter_types() + + if len(types) == 0: + self.set_accepted_adapter_types( + [ + adapter_utils.LINEAR_ADAPTER_CLASSPATH, + adapter_utils.TRANSFORMER_MHA_ADAPTER_CLASSPATH, + ] + ) + types = self.get_accepted_adapter_types() + return types + class TransformerEncoder(nn.Module): def __init__( @@ -121,6 +184,8 @@ def __init__( else: self.final_layer_norm = None + self.d_model = hidden_size + layer = TransformerEncoderBlock( hidden_size, inner_size, @@ -172,3 +237,38 @@ def forward(self, encoder_states, encoder_mask, encoder_mems_list=None, return_m return cached_mems_list else: return cached_mems_list[-1] + + +class TransformerEncoderAdapter(TransformerEncoder, adapter_mixins.AdapterModuleMixin): + + # Higher level forwarding + def add_adapter(self, name: str, cfg: dict): + cfg = self._update_adapter_cfg_input_dim(cfg) + for transformer_layer in self.layers: # type: adapter_mixins.AdapterModuleMixin + transformer_layer.add_adapter(name, cfg) + + def is_adapter_available(self) -> bool: + return any([transformer_layer.is_adapter_available() for transformer_layer in self.layers]) + + def set_enabled_adapters(self, name: Optional[str] = None, enabled: bool = True): + for transformer_layer in self.layers: # type: adapter_mixins.AdapterModuleMixin + transformer_layer.set_enabled_adapters(name=name, enabled=enabled) + + def get_enabled_adapters(self) -> List[str]: + names = set([]) + for transformer_layer in self.layers: # type: adapter_mixins.AdapterModuleMixin + names.update(transformer_layer.get_enabled_adapters()) + + names = sorted(list(names)) + return names + + def _update_adapter_cfg_input_dim(self, cfg: DictConfig): + cfg = adapter_utils.update_adapter_cfg_input_dim(self, cfg, module_dim=self.d_model) + return cfg + + +""" +Register any additional information +""" +if adapter_mixins.get_registered_adapter(TransformerEncoder) is None: + adapter_mixins.register_adapter(base_class=TransformerEncoder, adapter_class=TransformerEncoderAdapter) diff --git a/nemo/collections/asr/modules/transformer/transformer_generators.py b/nemo/collections/asr/modules/transformer/transformer_generators.py index 4061f54a907a4..1a38e7fa4b6c2 100644 --- a/nemo/collections/asr/modules/transformer/transformer_generators.py +++ b/nemo/collections/asr/modules/transformer/transformer_generators.py @@ -173,7 +173,7 @@ def _forward( def __call__( self, decoder_input_ids=None, encoder_hidden_states=None, encoder_input_mask=None, return_beam_scores=False ): - with self.as_frozen(): + with torch.inference_mode(): results = self._forward( decoder_input_ids, encoder_hidden_states, encoder_input_mask, return_beam_scores=return_beam_scores ) @@ -188,8 +188,7 @@ def __call__( return prefixes, scores, tgt def freeze(self) -> None: - """Freeze weights of embedding, decoder, and classification layers to prevent memory leak. - """ + """Freeze weights of embedding, decoder, and classification layers to prevent memory leak.""" for param in self.embedding.parameters(): param.requires_grad = False self.embedding.eval() @@ -201,8 +200,7 @@ def freeze(self) -> None: self.log_softmax.eval() def unfreeze(self) -> None: - """Unfreeze weights of embedding, decoder, and classification layers. - """ + """Unfreeze weights of embedding, decoder, and classification layers.""" for param in self.embedding.parameters(): param.requires_grad = True self.embedding.train() @@ -357,13 +355,13 @@ def _forward( # choose top-k hypotheses with length penalty applied len_penalties = self.compute_len_penalty(prefixes_len, self.len_pen) scores = scores / len_penalties - scores, indices_i = torch.topk(scores.view(-1, self.beam_size ** 2), self.beam_size, dim=1) + scores, indices_i = torch.topk(scores.view(-1, self.beam_size**2), self.beam_size, dim=1) scores = scores.view(-1, 1) * len_penalties # select prefixes which correspond to the chosen hypotheses prefixes = prefixes.unsqueeze(1).repeat(1, self.beam_size, 1) prefixes = torch.cat((prefixes, prefixes_i.unsqueeze(2)), dim=2) - prefixes = prefixes.view(batch_size, self.beam_size ** 2, -1) + prefixes = prefixes.view(batch_size, self.beam_size**2, -1) p_len = prefixes.size(2) prefixes_ids = indices_i.unsqueeze(2).repeat(1, 1, p_len) prefixes = prefixes.gather(1, prefixes_ids).view(-1, p_len) @@ -463,7 +461,10 @@ def _one_step_forward_lm(self, decoder_input_ids=None, lm_mems_list=None, pos=0) input_mask = mask_padded_tokens(decoder_input_ids, self.pad).float() lm_hidden_states = self.language_model.encoder.embedding.forward(decoder_input_ids, start_pos=pos) lm_mems_list = self.language_model.encoder.encoder.forward( - lm_hidden_states, input_mask, lm_mems_list, return_mems=True, + lm_hidden_states, + input_mask, + lm_mems_list, + return_mems=True, ) lm_log_probs = self.language_model.log_softmax.forward(hidden_states=lm_mems_list[-1][:, -1:]) return lm_log_probs, lm_mems_list @@ -639,13 +640,13 @@ def _forward(self, src_ids, encoder_input_mask, decoder_input_ids=None, return_b # choose top-k hypotheses with length penalty applied len_penalties = self.compute_len_penalty(prefixes_len, self.len_pen) scores = scores / len_penalties - scores, indices_i = torch.topk(scores.view(-1, self.beam_size ** 2), self.beam_size, dim=1) + scores, indices_i = torch.topk(scores.view(-1, self.beam_size**2), self.beam_size, dim=1) scores = scores.view(-1, 1) * len_penalties # select prefixes which correspond to the chosen hypotheses prefixes = prefixes.unsqueeze(1).repeat(1, self.beam_size, 1) prefixes = torch.cat((prefixes, prefixes_i.unsqueeze(2)), dim=2) - prefixes = prefixes.view(batch_size, self.beam_size ** 2, -1) + prefixes = prefixes.view(batch_size, self.beam_size**2, -1) p_len = prefixes.size(2) prefixes_ids = indices_i.unsqueeze(2).repeat(1, 1, p_len) prefixes = prefixes.gather(1, prefixes_ids).view(-1, p_len) @@ -697,12 +698,11 @@ def _forward(self, src_ids, encoder_input_mask, decoder_input_ids=None, return_b return tgt def __call__(self, src_ids, encoder_input_mask, decoder_input_ids=None, return_beam_scores=False): - with self.as_frozen(): + with torch.inference_mode(): return self._forward(src_ids, encoder_input_mask, decoder_input_ids, return_beam_scores) def freeze(self) -> None: - """Freeze weights of embedding, decoder, and classification layers to prevent memory leak. - """ + """Freeze weights of embedding, decoder, and classification layers to prevent memory leak.""" for model_num in range(self.num_models): for param in self.embeddings[model_num].parameters(): param.requires_grad = False @@ -718,8 +718,7 @@ def freeze(self) -> None: self.encoders[model_num].eval() def unfreeze(self) -> None: - """Unfreeze weights of embedding, decoder, and classification layers. - """ + """Unfreeze weights of embedding, decoder, and classification layers.""" for model_num in range(self.num_models): for param in self.embeddings[model_num].parameters(): param.requires_grad = True @@ -781,13 +780,20 @@ def _one_step_forward( ): nmt_log_probs, decoder_mems_list = super()._one_step_forward( - decoder_input_ids, encoder_hidden_states, encoder_input_mask, decoder_mems_list, pos, + decoder_input_ids, + encoder_hidden_states, + encoder_input_mask, + decoder_mems_list, + pos, ) input_mask = mask_padded_tokens(decoder_input_ids, self.pad).float() lm_hidden_states = self.language_model.encoder.embedding.forward(decoder_input_ids, start_pos=pos) lm_mems_list = self.language_model.encoder.encoder.forward( - lm_hidden_states, input_mask, lm_mems_list, return_mems=True, + lm_hidden_states, + input_mask, + lm_mems_list, + return_mems=True, ) lm_log_probs = self.language_model.log_softmax.forward(hidden_states=lm_mems_list[-1][:, -1:]) @@ -863,13 +869,13 @@ def _forward( # choose top-k hypotheses with length penalty applied len_penalties = self.compute_len_penalty(prefixes_len, self.len_pen) scores = scores / len_penalties - scores, indices_i = torch.topk(scores.view(-1, self.beam_size ** 2), self.beam_size, dim=1) + scores, indices_i = torch.topk(scores.view(-1, self.beam_size**2), self.beam_size, dim=1) scores = scores.view(-1, 1) * len_penalties # select prefixes which correspond to the chosen hypotheses prefixes = prefixes.unsqueeze(1).repeat(1, self.beam_size, 1) prefixes = torch.cat((prefixes, prefixes_i.unsqueeze(2)), dim=2) - prefixes = prefixes.view(batch_size, self.beam_size ** 2, -1) + prefixes = prefixes.view(batch_size, self.beam_size**2, -1) p_len = prefixes.size(2) prefixes_ids = indices_i.unsqueeze(2).repeat(1, 1, p_len) prefixes = prefixes.gather(1, prefixes_ids).view(-1, p_len) diff --git a/nemo/collections/asr/modules/transformer/transformer_modules.py b/nemo/collections/asr/modules/transformer/transformer_modules.py index 25fb781f0cd44..d090604287cb5 100644 --- a/nemo/collections/asr/modules/transformer/transformer_modules.py +++ b/nemo/collections/asr/modules/transformer/transformer_modules.py @@ -65,7 +65,9 @@ def forward(self, position_ids): f'Max position id {max_pos_id} is greater than max sequence length {self._max_sequence_length}. Expanding position embeddings just for this batch. This is not expected to work very well. Consider chunking your input into smaller sequences.' ) self._build_pos_enc( - hidden_size=self._hidden_size, max_sequence_length=max_pos_id + 1, device=position_ids.device, + hidden_size=self._hidden_size, + max_sequence_length=max_pos_id + 1, + device=position_ids.device, ) embeddings = torch.embedding(self.pos_enc, position_ids) @@ -203,8 +205,9 @@ def forward(self, queries, keys, values, attention_mask): attention_probs = self.attn_dropout(attention_probs) context = torch.matmul(attention_probs, value) + context_hidden_size = context.size()[-1] * self.num_attention_heads context = context.permute(0, 2, 1, 3).contiguous() - new_context_shape = context.size()[:-2] + (self.hidden_size,) + new_context_shape = context.size()[:-2] + (context_hidden_size,) context = context.view(*new_context_shape) # output projection diff --git a/nemo/collections/asr/modules/transformer/transformer_utils.py b/nemo/collections/asr/modules/transformer/transformer_utils.py index da9ffb8fbd002..5de1652ee1b0f 100644 --- a/nemo/collections/asr/modules/transformer/transformer_utils.py +++ b/nemo/collections/asr/modules/transformer/transformer_utils.py @@ -113,6 +113,7 @@ def get_nemo_transformer( else: raise ValueError(f"Unknown arch = {arch}") else: + model = TransformerDecoderNM( vocab_size=cfg.get('vocab_size'), hidden_size=cfg.get('hidden_size'), diff --git a/nemo/collections/asr/parts/mixins/asr_adapter_mixins.py b/nemo/collections/asr/parts/mixins/asr_adapter_mixins.py index f452acd19847e..bd0607f2c4f35 100644 --- a/nemo/collections/asr/parts/mixins/asr_adapter_mixins.py +++ b/nemo/collections/asr/parts/mixins/asr_adapter_mixins.py @@ -21,7 +21,7 @@ class ASRAdapterModelMixin(AdapterModelPTMixin): - """ ASR Adapter Mixin that can augment any Encoder module with Adapter module support. + """ASR Adapter Mixin that can augment any Encoder module with Adapter module support. This mixin class should be used only with a top level ModelPT subclass, that includes an `encoder` submodule. This mixin class adds several utility methods which are propagated to the `encoder`. @@ -54,14 +54,10 @@ def setup_adapters(self): supports_adapters = False # At least the encoder must extend AdapterModuleMixin - if hasattr(self, 'encoder') and isinstance(self.encoder, AdapterModuleMixin): - supports_adapters |= True - - if hasattr(self, 'decoder') and isinstance(self.decoder, AdapterModuleMixin): - supports_adapters |= True - - if hasattr(self, 'joint') and isinstance(self.joint, AdapterModuleMixin): - supports_adapters |= True + valid_adapter_names = [x for x in self.adapter_module_names if x != ''] + for module_name in valid_adapter_names: + if hasattr(self, module_name) and isinstance(getattr(self, module_name), AdapterModuleMixin): + supports_adapters |= True # If adapters are supported, setup the adapter config + any modules (pre-existing adapter modules) if supports_adapters: @@ -87,24 +83,30 @@ def add_adapter(self, name: str, cfg: DictConfig): else: module_names = [module_name] + valid_module_names = [x for x in self.adapter_module_names if x != ''] + default_module_name = self.default_adapter_module_name + + # Check if default module name is None or not + if default_module_name is None: + raise ValueError( + f"Default module name is None. Class {self.__class__.__name__} must implement " + f"`default_adapter_module_name`" + ) + # Update the model.cfg with information about the new adapter from cfg with open_dict(self.cfg): for module_name in module_names: # Check if encoder adapters should be added - if module_name in ('', 'encoder'): - # Dispatch the call to the encoder. - self.encoder.add_adapter(name=name, cfg=cfg) - - # Check if decoder adapters should be added - if module_name == 'decoder': - # Dispatch call to the decoder. - self.decoder.add_adapter(name=name, cfg=cfg) + if module_name == '': + if hasattr(self, default_module_name): + # Dispatch the call to the default model. + getattr(self, default_module_name).add_adapter(name=name, cfg=cfg) - # Check if joint adapters should be added; - # Note: We need additional check if joint even exists in model (for CTC models) - if hasattr(self, 'joint') and module_name == 'joint': - # Dispatch call to the joint. - self.joint.add_adapter(name=name, cfg=cfg) + elif module_name in valid_module_names: + # Check if module exists + if hasattr(self, module_name): + # Dispatch the call to the module. + getattr(self, module_name).add_adapter(name=name, cfg=cfg) def is_adapter_available(self) -> bool: """ @@ -116,15 +118,12 @@ def is_adapter_available(self) -> bool: """ config_contains_adapter = super().is_adapter_available() - # Forward the method call to the individual modules - if hasattr(self, 'encoder') and isinstance(self.encoder, AdapterModuleMixin): - config_contains_adapter |= self.encoder.is_adapter_available() - - if hasattr(self, 'decoder') and isinstance(self.decoder, AdapterModuleMixin): - config_contains_adapter |= self.decoder.is_adapter_available() + valid_module_names = [x for x in self.adapter_module_names if x != ''] - if hasattr(self, 'joint') and isinstance(self.joint, AdapterModuleMixin): - config_contains_adapter |= self.joint.is_adapter_available() + # Forward the method call to the individual modules + for module_name in valid_module_names: + if hasattr(self, module_name) and isinstance(getattr(self, module_name), AdapterModuleMixin): + config_contains_adapter |= getattr(self, module_name).is_adapter_available() return config_contains_adapter @@ -160,23 +159,29 @@ def set_enabled_adapters(self, name: Optional[str] = None, enabled: bool = True) else: module_names = [module_name] + valid_module_names = [x for x in self.adapter_module_names if x != ''] + default_module_name = self.default_adapter_module_name + + # Check if default module name is None or not + if default_module_name is None: + raise ValueError( + f"Default module name is None. Class {self.__class__.__name__} must implement " + f"`default_adapter_module_name`" + ) + + # Forward the method call to the individual modules if they exist for module_name in module_names: # Check if encoder adapters should be used - # Dispatch the call to the encoder. - if name is None or module_name in ('', 'encoder'): - if self.encoder.is_adapter_available(): - self.encoder.set_enabled_adapters(name=name, enabled=enabled) - - # Dispatch the call to the decoder. - if name is None or module_name == 'decoder': - if self.decoder.is_adapter_available(): - self.decoder.set_enabled_adapters(name=name, enabled=enabled) - - # Dispatch the call to the joint. - # Note: We need additional check for joint, since it may not exist (CTC models). - if name is None or module_name == 'joint': - if hasattr(self, 'joint') and self.joint.is_adapter_available(): - self.joint.set_enabled_adapters(name=name, enabled=enabled) + + if module_name == '': + if hasattr(self, default_module_name): + # Dispatch the call to the default model. + getattr(self, default_module_name).set_enabled_adapters(name=name, enabled=enabled) + + elif module_name in valid_module_names: + if hasattr(self, module_name): + # Dispatch the call to the module. + getattr(self, module_name).set_enabled_adapters(name=name, enabled=enabled) def get_enabled_adapters(self) -> List[str]: """ @@ -187,15 +192,12 @@ def get_enabled_adapters(self) -> List[str]: """ enabled_adapters = super().get_enabled_adapters() - # Check if encoder adapters should be used or are enabled - if hasattr(self, 'encoder') and isinstance(self.encoder, AdapterModuleMixin): - enabled_adapters.extend(self.encoder.get_enabled_adapters()) + valid_module_names = [x for x in self.adapter_module_names if x != ''] - if hasattr(self, 'decoder') and isinstance(self.decoder, AdapterModuleMixin): - enabled_adapters.extend(self.decoder.get_enabled_adapters()) - - if hasattr(self, 'joint') and isinstance(self.joint, AdapterModuleMixin): - enabled_adapters.extend(self.joint.get_enabled_adapters()) + # Check if encoder adapters should be used or are enabled + for module_name in valid_module_names: + if hasattr(self, module_name) and isinstance(getattr(self, module_name), AdapterModuleMixin): + enabled_adapters.extend(getattr(self, module_name).get_enabled_adapters()) enabled_adapters = list(sorted(list(set(enabled_adapters)))) @@ -208,44 +210,19 @@ def check_valid_model_with_adapter_support_(self): # Obtain the global adapter config if possible, otherwise use sensible defaults. global_cfg = self._get_global_cfg() - # Test whether the encoder supports adapters - use_encoder_adapter = global_cfg.get('check_encoder_adapter', True) - if use_encoder_adapter: - if not hasattr(self, 'encoder'): - logging.warning( - "Cannot add adapter to this object as it does not have an `encoder` sub-module!", - mode=logging_mode.ONCE, - ) - - if hasattr(self, 'encoder') and not isinstance(self.encoder, AdapterModuleMixin): - logging.warning( - f'{self.encoder.__class__.__name__} does not implement `AdapterModuleMixin`', - mode=logging_mode.ONCE, - ) - - # Test whether the decoder supports adapters - use_decoder_adapter = global_cfg.get('check_decoder_adapter', True) - if use_decoder_adapter: - if not hasattr(self, 'decoder'): - logging.warning( - "Cannot add adapter to this object as it does not have an `decoder` sub-module!", - mode=logging_mode.ONCE, - ) - - if hasattr(self, 'decoder') and not isinstance(self.decoder, AdapterModuleMixin): - logging.warning( - f'{self.decoder.__class__.__name__} does not implement `AdapterModuleMixin`', - mode=logging_mode.ONCE, - ) - - # Test whether the joint supports adapters - use_joint_adapter = global_cfg.get('check_joint_adapter', True) - if use_joint_adapter: - # Joint is only for RNNT models, skip assertion that it must always exist. - if hasattr(self, 'joint') and not isinstance(self.joint, AdapterModuleMixin): - logging.warning( - f'{self.joint.__class__.__name__} does not implement `AdapterModuleMixin`', mode=logging_mode.ONCE - ) + valid_module_names = [x for x in self.adapter_module_names if x != ''] + + for module_name in valid_module_names: + check_adapter_support = global_cfg.get(f'check_{module_name}_adapter', True) + + if check_adapter_support: + # Test whether the module supports adapters + if hasattr(self, module_name) and not isinstance(getattr(self, module_name), AdapterModuleMixin): + logging.warning( + f'Module `{module_name}` exists, but {getattr(self, module_name).__class__.__name__} ' + f'does not implement `AdapterModuleMixin`', + mode=logging_mode.ONCE, + ) def resolve_adapter_module_name_(self, name: str) -> Tuple[str, str]: """ @@ -293,3 +270,7 @@ def _get_global_cfg(self): def adapter_module_names(self) -> List[str]: valid_module_names = ['', 'encoder', 'decoder', 'joint'] return valid_module_names + + @property + def default_adapter_module_name(self) -> str: + return 'encoder' diff --git a/nemo/collections/asr/parts/mixins/mixins.py b/nemo/collections/asr/parts/mixins/mixins.py index 1ec4066220361..f5b4381f7fb73 100644 --- a/nemo/collections/asr/parts/mixins/mixins.py +++ b/nemo/collections/asr/parts/mixins/mixins.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import json import os import shutil import tarfile @@ -31,7 +32,7 @@ class ASRBPEMixin(ABC): - """ ASR BPE Mixin class that sets up a Tokenizer via a config + """ASR BPE Mixin class that sets up a Tokenizer via a config This mixin class adds the method `_setup_tokenizer(...)`, which can be used by ASR models which depend on subword tokenization. @@ -204,7 +205,12 @@ def _setup_aggregate_tokenizer(self, tokenizer_cfg: DictConfig): tokenizers_dict = {} # init each of the monolingual tokenizers found in the config and assemble into AggregateTokenizer for lang, tokenizer_config in self.tokenizer_cfg[self.AGGREGATE_TOKENIZERS_DICT_PREFIX].items(): - (tokenizer, model_path, vocab_path, spe_vocab_path,) = self._make_tokenizer(tokenizer_config, lang) + ( + tokenizer, + model_path, + vocab_path, + spe_vocab_path, + ) = self._make_tokenizer(tokenizer_config, lang) tokenizers_dict[lang] = tokenizer if hasattr(self, 'cfg'): @@ -845,7 +851,23 @@ def _setup_streaming_transcribe_dataloader( streaming_buffer.reset_buffer() -class DiarizationMixin(ABC): +class VerificationMixin(ABC): + @staticmethod + def path2audio_files_to_manifest(paths2audio_files, manifest_filepath): + """ + Takes paths to audio files and manifest filepath and creates manifest file with the audios + Args: + paths2audio_files: paths to audio fragment to be verified + manifest_filepath: path to manifest file to bre created + """ + with open(manifest_filepath, 'w', encoding='utf-8') as fp: + for audio_file in paths2audio_files: + audio_file = audio_file.strip() + entry = {'audio_filepath': audio_file, 'offset': 0.0, 'duration': None, 'text': '-', 'label': 'infer'} + fp.write(json.dumps(entry) + '\n') + + +class DiarizationMixin(VerificationMixin): @abstractmethod def diarize(self, paths2audio_files: List[str], batch_size: int = 1) -> List[str]: """ diff --git a/nemo/collections/asr/parts/mixins/transcription.py b/nemo/collections/asr/parts/mixins/transcription.py index 5b9461d0a3896..b6238cad4534a 100644 --- a/nemo/collections/asr/parts/mixins/transcription.py +++ b/nemo/collections/asr/parts/mixins/transcription.py @@ -28,8 +28,7 @@ from tqdm import tqdm from nemo.collections.asr.parts.preprocessing.perturb import process_augmentations -from nemo.collections.asr.parts.preprocessing.segment import AudioSegment -from nemo.collections.asr.parts.utils.audio_utils import ChannelSelectorType +from nemo.collections.asr.parts.preprocessing.segment import AudioSegment, ChannelSelectorType from nemo.utils import logging, logging_mode TranscriptionReturnType = Union[List[str], List['Hypothesis'], Tuple[List[str]], Tuple[List['Hypothesis']]] diff --git a/nemo/collections/asr/parts/preprocessing/features.py b/nemo/collections/asr/parts/preprocessing/features.py index dccc81b1816ca..d70737b5135b5 100644 --- a/nemo/collections/asr/parts/preprocessing/features.py +++ b/nemo/collections/asr/parts/preprocessing/features.py @@ -131,7 +131,7 @@ def clean_spectrogram_batch(spectrogram: torch.Tensor, spectrogram_len: torch.Te def splice_frames(x, frame_splicing): - """ Stacks frames together across feature dim + """Stacks frames together across feature dim input is batch_size, feature_dim, num_frames output is batch_size, feature_dim*frame_splicing, num_frames @@ -261,7 +261,7 @@ def __init__( highfreq=None, log=True, log_zero_guard_type="add", - log_zero_guard_value=2 ** -24, + log_zero_guard_value=2**-24, dither=CONSTANT, pad_to=16, max_duration=16.7, @@ -308,6 +308,7 @@ def __init__( self.hop_length = n_window_stride self.n_fft = n_fft or 2 ** math.ceil(math.log2(self.win_length)) self.stft_pad_amount = (self.n_fft - self.hop_length) // 2 if exact_pad else None + self.exact_pad = exact_pad if exact_pad: logging.info("STFT using exact pad") @@ -321,15 +322,6 @@ def __init__( window_fn = torch_windows.get(window, None) window_tensor = window_fn(self.win_length, periodic=False) if window_fn else None self.register_buffer("window", window_tensor) - self.stft = lambda x: torch.stft( - x, - n_fft=self.n_fft, - hop_length=self.hop_length, - win_length=self.win_length, - center=False if exact_pad else True, - window=self.window.to(dtype=torch.float), - return_complex=True, - ) self.normalize = normalize self.log = log @@ -388,6 +380,17 @@ def __init__( logging.debug(f"using grads: {use_grads}") logging.debug(f"nb_augmentation_prob: {nb_augmentation_prob}") + def stft(self, x): + return torch.stft( + x, + n_fft=self.n_fft, + hop_length=self.hop_length, + win_length=self.win_length, + center=False if self.exact_pad else True, + window=self.window.to(dtype=torch.float), + return_complex=True, + ) + def log_zero_guard_value_fn(self, x): if isinstance(self.log_zero_guard_value, str): if self.log_zero_guard_value == "tiny": @@ -508,7 +511,7 @@ def __init__( highfreq: Optional[float] = None, log: bool = True, log_zero_guard_type: str = "add", - log_zero_guard_value: Union[float, str] = 2 ** -24, + log_zero_guard_value: Union[float, str] = 2**-24, dither: float = 1e-5, window: str = "hann", pad_to: int = 0, @@ -579,7 +582,7 @@ def __init__( @property def filter_banks(self): - """ Matches the analogous class """ + """Matches the analogous class""" return self._mel_spec_extractor.mel_scale.fb def _resolve_log_zero_guard_value(self, dtype: torch.dtype) -> float: diff --git a/nemo/collections/asr/parts/preprocessing/segment.py b/nemo/collections/asr/parts/preprocessing/segment.py index be78ac74b71d6..310e76cfd0b0f 100644 --- a/nemo/collections/asr/parts/preprocessing/segment.py +++ b/nemo/collections/asr/parts/preprocessing/segment.py @@ -36,13 +36,13 @@ import math import os import random -from typing import Optional +from typing import Iterable, Optional, Union import librosa import numpy as np +import numpy.typing as npt import soundfile as sf -from nemo.collections.asr.parts.utils.audio_utils import select_channels from nemo.utils import logging # TODO @blisc: Perhaps refactor instead of import guarding @@ -50,6 +50,10 @@ try: from pydub import AudioSegment as Audio from pydub.exceptions import CouldntDecodeError + + # FFMPEG for some formats needs explicitly defined coding-decoding strategy + ffmpeg_codecs = {'opus': 'opus'} + except ModuleNotFoundError: HAVE_PYDUB = False @@ -58,6 +62,92 @@ sf_supported_formats = ["." + i.lower() for i in available_formats.keys()] +ChannelSelectorType = Union[int, Iterable[int], str] + + +def select_channels(signal: npt.NDArray, channel_selector: Optional[ChannelSelectorType] = None) -> npt.NDArray: + """ + Convert a multi-channel signal to a single-channel signal by averaging over channels or selecting a single channel, + or pass-through multi-channel signal when channel_selector is `None`. + + Args: + signal: numpy array with shape (..., num_channels) + channel selector: string denoting the downmix mode, an integer denoting the channel to be selected, or an iterable + of integers denoting a subset of channels. Channel selector is using zero-based indexing. + If set to `None`, the original signal will be returned. Uses zero-based indexing. + + Returns: + numpy array + """ + if signal.ndim == 1: + # For one-dimensional input, return the input signal. + if channel_selector not in [None, 0, 'average']: + raise ValueError( + 'Input signal is one-dimensional, channel selector (%s) cannot not be used.', str(channel_selector) + ) + return signal + + num_channels = signal.shape[-1] + num_samples = signal.size // num_channels # handle multi-dimensional signals + + if num_channels >= num_samples: + logging.warning( + 'Number of channels (%d) is greater or equal than number of samples (%d). Check for possible transposition.', + num_channels, + num_samples, + ) + + # Samples are arranged as (num_channels, ...) + if channel_selector is None: + # keep the original multi-channel signal + pass + elif channel_selector == 'average': + # default behavior: downmix by averaging across channels + signal = np.mean(signal, axis=-1) + elif isinstance(channel_selector, int): + # select a single channel + if channel_selector >= num_channels: + raise ValueError(f'Cannot select channel {channel_selector} from a signal with {num_channels} channels.') + signal = signal[..., channel_selector] + elif isinstance(channel_selector, Iterable): + # select multiple channels + if max(channel_selector) >= num_channels: + raise ValueError( + f'Cannot select channel subset {channel_selector} from a signal with {num_channels} channels.' + ) + signal = signal[..., channel_selector] + # squeeze the channel dimension if a single-channel is selected + # this is done to have the same shape as when using integer indexing + if len(channel_selector) == 1: + signal = np.squeeze(signal, axis=-1) + else: + raise ValueError(f'Unexpected value for channel_selector ({channel_selector})') + + return signal + + +def get_samples(audio_file: str, target_sr: int = 16000, dtype: str = 'float32'): + """ + Read the samples from the given audio_file path. If not specified, the input audio file is automatically + resampled to 16kHz. + + Args: + audio_file (str): + Path to the input audio file + target_sr (int): + Targeted sampling rate + Returns: + samples (numpy.ndarray): + Time-series sample data from the given audio file + """ + with sf.SoundFile(audio_file, 'r') as f: + samples = f.read(dtype=dtype) + if f.samplerate != target_sr: + samples = librosa.core.resample(samples, orig_sr=f.samplerate, target_sr=target_sr) + samples = samples.transpose() + return samples + + class AudioSegment(object): """Audio segment abstraction. :param samples: Audio samples [num_samples x num_channels]. @@ -256,14 +346,14 @@ def from_file( if HAVE_PYDUB and samples is None: try: - samples = Audio.from_file(audio_file) + samples = Audio.from_file(audio_file, codec=ffmpeg_codecs.get(os.path.splitext(audio_file)[-1])) sample_rate = samples.frame_rate num_channels = samples.channels if offset > 0: # pydub does things in milliseconds seconds = offset * 1000 samples = samples[int(seconds) :] - if duration > 0: + if duration is not None and duration > 0: seconds = duration * 1000 samples = samples[: int(seconds)] samples = np.array(samples.get_array_of_samples()) @@ -370,7 +460,13 @@ def from_file_list( sample_rate = target_sr return cls( - samples, sample_rate, target_sr=target_sr, trim=trim, channel_selector=channel_selector, *args, **kwargs, + samples, + sample_rate, + target_sr=target_sr, + trim=trim, + channel_selector=channel_selector, + *args, + **kwargs, ) @classmethod @@ -468,9 +564,8 @@ def duration(self): @property def rms_db(self): - """Return per-channel RMS value. - """ - mean_square = np.mean(self._samples ** 2, axis=0) + """Return per-channel RMS value.""" + mean_square = np.mean(self._samples**2, axis=0) return 10 * np.log10(mean_square) @property @@ -481,7 +576,7 @@ def gain_db(self, gain): self._samples *= 10.0 ** (gain / 20.0) def normalize_db(self, target_db=-20, ref_channel=None): - """Normalize the signal to a target RMS value in decibels. + """Normalize the signal to a target RMS value in decibels. For multi-channel audio, the RMS value is determined by the reference channel (if not None), otherwise it will be the maximum RMS across all channels. """ @@ -509,7 +604,11 @@ def pad(self, pad_size, symmetric=False): f"Padding not implemented for signals with more that 2 dimensions. Current samples dimension: {samples_ndim}." ) # apply padding - self._samples = np.pad(self._samples, pad_width, mode='constant',) + self._samples = np.pad( + self._samples, + pad_width, + mode='constant', + ) def subsegment(self, start_time=None, end_time=None): """Cut the AudioSegment between given boundaries. diff --git a/nemo/collections/asr/parts/submodules/adapters/__init__.py b/nemo/collections/asr/parts/submodules/adapters/__init__.py index 6aa05d07dea1d..c51d935bddd47 100644 --- a/nemo/collections/asr/parts/submodules/adapters/__init__.py +++ b/nemo/collections/asr/parts/submodules/adapters/__init__.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +# fmt: off +from nemo.collections.asr.parts.submodules.adapters.attention_adapter_mixin import AttentionAdapterModuleMixin from nemo.collections.asr.parts.submodules.adapters.multi_head_attention_adapter_module import ( MHAResidualAddAdapterStrategy, MHAResidualAddAdapterStrategyConfig, @@ -24,3 +26,9 @@ RelPositionMultiHeadAttentionAdapter, RelPositionMultiHeadAttentionAdapterConfig, ) +from nemo.collections.asr.parts.submodules.adapters.transformer_multi_head_attention_adapter_module import ( + TransformerMultiHeadAttentionAdapter, + TransformerMultiHeadAttentionAdapterConfig, +) + +# fmt: on diff --git a/nemo/collections/asr/parts/submodules/adapters/attention_adapter_mixin.py b/nemo/collections/asr/parts/submodules/adapters/attention_adapter_mixin.py new file mode 100644 index 0000000000000..0c1852773072b --- /dev/null +++ b/nemo/collections/asr/parts/submodules/adapters/attention_adapter_mixin.py @@ -0,0 +1,119 @@ +import torch + +from nemo.core.classes.mixins import adapter_mixins +from nemo.utils import logging, logging_mode + + +class AttentionAdapterModuleMixin(adapter_mixins.AdapterModuleMixin): + """ + Utility class that implements a custom forward method for Modules that are attention based. + Attention based adapters can support either linear adapters, and Multi-Head Attention adapters. + + However, Multi Head Attention adapters require additional arguments, such as `att_mask` and `pos_emb`. + This utility class unifies the adapter forward pass for both types of adapters. + + .. Usage: + + To use this class, inherit from this class, and when calling self.foward_enabled_adapters() pass the following: + + .. code-block:: python + + if self.is_adapter_available(): + # Call the MHA adapters + pack_ip = { + 'x': residual, + 'loc': 'mha', + 'att_mask': att_mask, + 'pos_emb': pos_emb, + } + pack_ip = self.forward_enabled_adapters(pack_ip) + residual = pack_ip['x'] + + if self.is_adapter_available(): + # Call the Linear adapters + pack_ip = { + 'x': x, + 'loc': 'post', + } + pack_ip = self.forward_enabled_adapters(pack_ip) + x = pack_ip['x'] + """ + + def forward_single_enabled_adapter_( + self, + input: dict, + adapter_module: torch.nn.Module, + *, + adapter_name: str, + adapter_strategy: 'nemo.core.classes.mixins.adapter_mixin_strategies.AbstractAdapterStrategy', + ): + """ + Perform the forward step of a single adapter module on some input data. + + **Note**: Subclasses can override this method to accommodate more complicate adapter forward steps. + + Args: + input: Dictionary of packed tensors. The dict should contain at least + `x`: output tensor + `loc`: Semantic location in module where this adapter was called. Can be 'mha' or 'post'. + `att_mask`: Optional, Attention mask + `pos_emb`: Optional, Positional Embedding for Relative Positional Encoding. + The output tensor of the calling module is the input to the first adapter, whose output + is then chained to the next adapter until all adapters are consumed. + adapter_module: The adapter module that is currently required to perform the forward pass. + adapter_name: The resolved name of the adapter that is undergoing the current forward pass. + adapter_strategy: A subclass of `AbstractAdapterStrategy`, that determines how the + output of the adapter should be merged with the input, or if it should be merged at all. + + Returns: + The result tensor, after the current active adapter has finished its forward pass. + """ + if not hasattr(self, 'self_attention_model'): + raise RuntimeError( + "self_attention_model attribute not found in the module! Please set in the module " + "a string attribute 'self_attention_model' with value 'abs_pos', 'rel_pos' or " + "other supported self-attention model types." + ) + + # Collect imports to prevent circular imports + from nemo.collections.asr.modules.transformer import transformer_modules as transformer_mha + from nemo.collections.asr.parts.submodules import multi_head_attention as conformer_mha + + # (input: torch.Tensor, adapter: torch.nn.Module, *, module: 'AdapterModuleMixin') + x = input['x'] + loc = input['loc'] + att_mask = input.get('att_mask', None) + pos_emb = input.get('pos_emb', None) + + from nemo.collections.common.parts import adapter_modules + + if isinstance(adapter_module, adapter_modules.LinearAdapter) and loc == 'post': + output = adapter_strategy(x, adapter_module, module=self) + + elif isinstance(adapter_module, conformer_mha.MultiHeadAttention) and loc == 'mha': + if self.self_attention_model == 'rel_pos': + x = dict(query=x, key=x, value=x, mask=att_mask, pos_emb=pos_emb) + output = adapter_strategy(x, adapter_module, module=self) + + elif self.self_attention_model == 'abs_pos': + x = dict(query=x, key=x, value=x, mask=att_mask) + output = adapter_strategy(x, adapter_module, module=self) + + else: + raise ValueError(f"Unsupported value of self_attention_model , provided {self.self_attention_model}!") + + elif isinstance(adapter_module, transformer_mha.MultiHeadAttention) and loc == 'mha': + x = dict(queries=x, keys=x, values=x, attention_mask=att_mask) + output = adapter_strategy(x, adapter_module, module=self) + + else: + # No adapter compatible, skip + logging.warning( + "No adapter compatible with the current module. Skipping adapter forward pass.", mode=logging_mode.ONCE + ) + + output = x + + input['x'] = output + + return input diff --git a/nemo/collections/asr/parts/submodules/adapters/multi_head_attention_adapter_module.py b/nemo/collections/asr/parts/submodules/adapters/multi_head_attention_adapter_module.py index 3df51092ac4b7..2617ed6f575b9 100644 --- a/nemo/collections/asr/parts/submodules/adapters/multi_head_attention_adapter_module.py +++ b/nemo/collections/asr/parts/submodules/adapters/multi_head_attention_adapter_module.py @@ -29,7 +29,7 @@ class MHAResidualAddAdapterStrategy(adapter_mixin_strategies.ResidualAddAdapterS An implementation of residual addition of an adapter module with its input for the MHA Adapters. """ - def forward(self, input: torch.Tensor, adapter: torch.nn.Module, *, module: 'AdapterModuleMixin'): + def forward(self, input: dict, adapter: torch.nn.Module, *, module: 'AdapterModuleMixin'): """ A basic strategy, comprising of a residual connection over the input, after forward pass by the underlying adapter. Additional work is done to pack and unpack the dictionary of inputs and outputs. @@ -55,18 +55,29 @@ def forward(self, input: torch.Tensor, adapter: torch.nn.Module, *, module: 'Ada """ out = self.compute_output(input, adapter, module=module) + value_name = None + if 'value' in input: + value_name = 'value' + elif 'values' in input: + value_name = 'values' + else: + raise ValueError( + "Input dictionary must contain 'value' or 'values' key for residual connection. Input " + f"dictionary keys: {input.keys()}" + ) + # If not in training mode, or probability of stochastic depth is 0, skip step. p = self.stochastic_depth if not module.training or p == 0.0: pass else: - out = self.apply_stochastic_depth(out, input['value'], adapter, module=module) + out = self.apply_stochastic_depth(out, input[value_name], adapter, module=module) # Return the residual connection output = input + adapter(input) - result = input['value'] + out + result = input[value_name] + out # If l2_lambda is activated, register the loss value - self.compute_auxiliary_losses(result, input['value'], adapter, module=module) + self.compute_auxiliary_losses(result, input[value_name], adapter, module=module) return result @@ -105,16 +116,16 @@ class MHAResidualAddAdapterStrategyConfig(adapter_mixin_strategies.ResidualAddAd class MultiHeadAttentionAdapter(mha.MultiHeadAttention, adapter_modules.AdapterModuleUtil): """Multi-Head Attention layer of Transformer. - Args: - n_head (int): number of heads - n_feat (int): size of the features - dropout_rate (float): dropout rate - proj_dim (int, optional): Optional integer value for projection before computing attention. - If None, then there is no projection (equivalent to proj_dim = n_feat). - If > 0, then will project the n_feat to proj_dim before calculating attention. - If <0, then will equal n_head, so that each head has a projected dimension of 1. - adapter_strategy: By default, MHAResidualAddAdapterStrategyConfig. An adapter composition function object. - """ + Args: + n_head (int): number of heads + n_feat (int): size of the features + dropout_rate (float): dropout rate + proj_dim (int, optional): Optional integer value for projection before computing attention. + If None, then there is no projection (equivalent to proj_dim = n_feat). + If > 0, then will project the n_feat to proj_dim before calculating attention. + If <0, then will equal n_head, so that each head has a projected dimension of 1. + adapter_strategy: By default, MHAResidualAddAdapterStrategyConfig. An adapter composition function object. + """ def __init__( self, @@ -300,7 +311,6 @@ class RelPositionMultiHeadAttentionAdapterConfig: class PositionalEncodingAdapter(mha.PositionalEncoding, adapter_modules.AdapterModuleUtil): - """ Absolute positional embedding adapter. @@ -327,7 +337,11 @@ def __init__( ): super().__init__( - d_model=d_model, dropout_rate=0.0, max_len=max_len, xscale=xscale, dropout_rate_emb=0.0, + d_model=d_model, + dropout_rate=0.0, + max_len=max_len, + xscale=xscale, + dropout_rate_emb=0.0, ) # Setup adapter strategy diff --git a/nemo/collections/asr/parts/submodules/adapters/transformer_multi_head_attention_adapter_module.py b/nemo/collections/asr/parts/submodules/adapters/transformer_multi_head_attention_adapter_module.py new file mode 100644 index 0000000000000..4319a6962f4fc --- /dev/null +++ b/nemo/collections/asr/parts/submodules/adapters/transformer_multi_head_attention_adapter_module.py @@ -0,0 +1,128 @@ +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from dataclasses import dataclass, field +from typing import Any, Optional + +import torch +from torch import nn as nn + +from nemo.collections.asr.modules.transformer import transformer_modules +from nemo.collections.asr.parts.submodules.adapters.multi_head_attention_adapter_module import ( + MHAResidualAddAdapterStrategy, + MHAResidualAddAdapterStrategyConfig, +) +from nemo.collections.common.parts import adapter_modules +from nemo.core.classes.mixins import adapter_mixin_strategies, adapter_mixins + + +class TransformerMultiHeadAttentionAdapter(transformer_modules.MultiHeadAttention, adapter_modules.AdapterModuleUtil): + """Multi-Head Attention layer of Transformer Encoder. + + Args: + hidden_size (int): number of heads + num_attention_heads (int): size of the features + attn_score_dropout (float): dropout rate for the attention scores + attn_layer_dropout (float): dropout rate for the layer + proj_dim (int, optional): Optional integer value for projection before computing attention. + If None, then there is no projection (equivalent to proj_dim = n_feat). + If > 0, then will project the n_feat to proj_dim before calculating attention. + If <0, then will equal n_head, so that each head has a projected dimension of 1. + adapter_strategy: By default, MHAResidualAddAdapterStrategyConfig. An adapter composition function object. + """ + + def __init__( + self, + hidden_size: int, + num_attention_heads: int, + attn_score_dropout: float = 0.0, + attn_layer_dropout: float = 0.0, + proj_dim: Optional[int] = None, + adapter_strategy: MHAResidualAddAdapterStrategy = None, + ): + super().__init__( + hidden_size=hidden_size, + num_attention_heads=num_attention_heads, + attn_score_dropout=attn_score_dropout, + attn_layer_dropout=attn_layer_dropout, + ) + + self.pre_norm = nn.LayerNorm(hidden_size) + + # Set the projection dim to number of heads automatically + if proj_dim is not None and proj_dim < 1: + proj_dim = num_attention_heads + + self.proj_dim = proj_dim + + # Recompute weights for projection dim + if self.proj_dim is not None: + if self.proj_dim % num_attention_heads != 0: + raise ValueError(f"proj_dim ({proj_dim}) is not divisible by n_head ({num_attention_heads})") + + self.attn_head_size = self.proj_dim // num_attention_heads + self.attn_scale = math.sqrt(math.sqrt(self.attn_head_size)) + self.query_net = nn.Linear(hidden_size, self.proj_dim) + self.key_net = nn.Linear(hidden_size, self.proj_dim) + self.value_net = nn.Linear(hidden_size, self.proj_dim) + self.out_projection = nn.Linear(self.proj_dim, hidden_size) + + # Setup adapter strategy + self.setup_adapter_strategy(adapter_strategy) + + # reset parameters for Q to be identity operation + self.reset_parameters() + + def forward(self, queries, keys, values, attention_mask): + """Compute 'Scaled Dot Product Attention'. + Args: + query (torch.Tensor): (batch, time1, size) + key (torch.Tensor): (batch, time2, size) + value(torch.Tensor): (batch, time2, size) + mask (torch.Tensor): (batch, time1, time2) + cache (torch.Tensor) : (batch, time_cache, size) + + returns: + output (torch.Tensor): transformed `value` (batch, time1, d_model) weighted by the query dot key attention + cache (torch.Tensor) : (batch, time_cache_next, size) + """ + # Need to perform duplicate computations as at this point the tensors have been + # separated by the adapter forward + query = self.pre_norm(queries) + key = self.pre_norm(keys) + value = self.pre_norm(values) + + return super().forward(query, key, value, attention_mask) + + def reset_parameters(self): + with torch.no_grad(): + nn.init.zeros_(self.out_projection.weight) + nn.init.zeros_(self.out_projection.bias) + + def get_default_strategy_config(self) -> 'dataclass': + return MHAResidualAddAdapterStrategyConfig() + + +@dataclass +class TransformerMultiHeadAttentionAdapterConfig: + hidden_size: int + num_attention_heads: int + attn_score_dropout: float = 0.0 + attn_layer_dropout: float = 0.0 + proj_dim: Optional[int] = None + adapter_strategy: Optional[Any] = field(default_factory=lambda: MHAResidualAddAdapterStrategyConfig()) + _target_: str = "{0}.{1}".format( + TransformerMultiHeadAttentionAdapter.__module__, TransformerMultiHeadAttentionAdapter.__name__ + ) diff --git a/nemo/collections/asr/parts/submodules/conformer_modules.py b/nemo/collections/asr/parts/submodules/conformer_modules.py index 093cde63c4393..c2d897d632255 100644 --- a/nemo/collections/asr/parts/submodules/conformer_modules.py +++ b/nemo/collections/asr/parts/submodules/conformer_modules.py @@ -17,6 +17,7 @@ from torch import nn as nn from torch.nn import LayerNorm +from nemo.collections.asr.parts.submodules.adapters.attention_adapter_mixin import AttentionAdapterModuleMixin from nemo.collections.asr.parts.submodules.batchnorm import FusedBatchNorm1d from nemo.collections.asr.parts.submodules.causal_convs import CausalConv1D from nemo.collections.asr.parts.submodules.multi_head_attention import ( @@ -25,15 +26,13 @@ RelPositionMultiHeadAttentionLongformer, ) from nemo.collections.asr.parts.utils.activations import Swish -from nemo.collections.common.parts import adapter_modules from nemo.collections.common.parts.utils import activation_registry from nemo.core.classes.mixins import AccessMixin -from nemo.core.classes.mixins.adapter_mixins import AdapterModuleMixin __all__ = ['ConformerConvolution', 'ConformerFeedForward', 'ConformerLayer'] -class ConformerLayer(torch.nn.Module, AdapterModuleMixin, AccessMixin): +class ConformerLayer(torch.nn.Module, AttentionAdapterModuleMixin, AccessMixin): """A single block of the Conformer encoder. Args: @@ -184,14 +183,14 @@ def forward(self, x, att_mask=None, pos_emb=None, pad_mask=None, cache_last_chan if self.is_adapter_available(): # Call the MHA adapters - pack_ip = { + pack_input = { 'x': residual, 'loc': 'mha', 'att_mask': att_mask, 'pos_emb': pos_emb, } - pack_ip = self.forward_enabled_adapters(pack_ip) - residual = pack_ip['x'] + pack_input = self.forward_enabled_adapters(pack_input) + residual = pack_input['x'] x = self.norm_conv(residual) x = self.conv(x, pad_mask=pad_mask, cache=cache_last_time) @@ -207,12 +206,12 @@ def forward(self, x, att_mask=None, pos_emb=None, pad_mask=None, cache_last_chan if self.is_adapter_available(): # Call the adapters - pack_ip = { + pack_input = { 'x': x, 'loc': 'post', } - pack_ip = self.forward_enabled_adapters(pack_ip) - x = pack_ip['x'] + pack_input = self.forward_enabled_adapters(pack_input) + x = pack_input['x'] if self.is_access_enabled(getattr(self, "model_guid", None)) and self.access_cfg.get( 'save_encoder_tensors', False @@ -223,64 +222,6 @@ def forward(self, x, att_mask=None, pos_emb=None, pad_mask=None, cache_last_chan else: return x, cache_last_channel, cache_last_time - def forward_single_enabled_adapter_( - self, - input: dict, - adapter_module: torch.nn.Module, - *, - adapter_name: str, - adapter_strategy: 'nemo.core.classes.mixins.adapter_mixin_strategies.AbstractAdapterStrategy', - ): - """ - Perform the forward step of a single adapter module on some input data. - - **Note**: Subclasses can override this method to accommodate more complicate adapter forward steps. - - Args: - input: Dictionary of packed tensors. The dict should contain at least - `x`: output tensor - `loc`: Semantic location in module where this adapter was called - `att_mask`: Optional, Attention mask - `pos_emb`: Optional, Positional Embedding for Relative Positional Encoding. - The output tensor of the calling module is the input to the first adapter, whose output - is then chained to the next adapter until all adapters are consumed. - adapter_module: The adapter module that is currently required to perform the forward pass. - adapter_name: The resolved name of the adapter that is undergoing the current forward pass. - adapter_strategy: A subclass of `AbstractAdapterStrategy`, that determines how the - output of the adapter should be merged with the input, or if it should be merged at all. - - Returns: - The result tensor, after the current active adapter has finished its forward pass. - """ - # (input: torch.Tensor, adapter: torch.nn.Module, *, module: 'AdapterModuleMixin') - x = input['x'] - loc = input['loc'] - att_mask = input.get('att_mask', None) - pos_emb = input.get('pos_emb', None) - - if isinstance(adapter_module, adapter_modules.LinearAdapter) and loc == 'post': - output = adapter_strategy(x, adapter_module, module=self) - - elif isinstance(adapter_module, MultiHeadAttention) and loc == 'mha': - if self.self_attention_model == 'rel_pos': - x = dict(query=x, key=x, value=x, mask=att_mask, pos_emb=pos_emb) - output = adapter_strategy(x, adapter_module, module=self) - - elif self.self_attention_model == 'abs_pos': - x = dict(query=x, key=x, value=x, mask=att_mask) - output = adapter_strategy(x, adapter_module, module=self) - - else: - raise ValueError(f"Unsupported value of self_attention_model , provided {self.self_attention_model}!") - - else: - # No adapter compatible, skip - output = x - - input['x'] = output - - return input - class ConformerConvolution(nn.Module): """The convolution module for the Conformer model. diff --git a/nemo/collections/asr/parts/submodules/jasper.py b/nemo/collections/asr/parts/submodules/jasper.py index e53f6299b08aa..78f81ee555bc3 100644 --- a/nemo/collections/asr/parts/submodules/jasper.py +++ b/nemo/collections/asr/parts/submodules/jasper.py @@ -478,7 +478,7 @@ def forward_for_export(self, x, lengths): mask = self.make_pad_mask(lengths, max_audio_length=max_len, device=x.device) mask = ~mask # 0 represents value, 1 represents pad x = x.float() # For stable AMP, SE must be computed at fp32. - x.masked_fill_(mask, 0.0) # mask padded values explicitly to 0 + x = x.masked_fill(mask, 0.0) # mask padded values explicitly to 0 y = self._se_pool_step(x, mask) # [B, C, 1] y = y.transpose(1, -1) # [B, 1, C] y = self.fc(y) # [B, 1, C] @@ -510,8 +510,8 @@ def _se_pool_step(self, x, mask): return y def set_max_len(self, max_len, seq_range=None): - """ Sets maximum input length. - Pre-calculates internal seq_range mask. + """Sets maximum input length. + Pre-calculates internal seq_range mask. """ self.max_len = max_len if seq_range is None: diff --git a/nemo/collections/asr/parts/submodules/rnnt_beam_decoding.py b/nemo/collections/asr/parts/submodules/rnnt_beam_decoding.py index ef3a0cddb286c..25becda6fa751 100644 --- a/nemo/collections/asr/parts/submodules/rnnt_beam_decoding.py +++ b/nemo/collections/asr/parts/submodules/rnnt_beam_decoding.py @@ -201,8 +201,7 @@ class BeamRNNTInfer(Typing): @property def input_types(self): - """Returns definitions of module input ports. - """ + """Returns definitions of module input ports.""" return { "encoder_output": NeuralType(('B', 'D', 'T'), AcousticEncodedRepresentation()), "encoded_lengths": NeuralType(tuple('B'), LengthsType()), @@ -211,8 +210,7 @@ def input_types(self): @property def output_types(self): - """Returns definitions of module output ports. - """ + """Returns definitions of module output ports.""" return {"predictions": [NeuralType(elements_type=HypothesisType())]} def __init__( @@ -369,7 +367,7 @@ def __call__( return_hat_ilm_default = self.joint.return_hat_ilm self.joint.return_hat_ilm = self.hat_subtract_ilm - with torch.no_grad(): + with torch.inference_mode(): # Apply optional preprocessing encoder_output = encoder_output.transpose(1, 2) # (B, T, D) @@ -384,38 +382,34 @@ def __call__( unit='sample', ) as idx_gen: - # Freeze the decoder and joint to prevent recording of gradients - # during the beam loop. - with self.decoder.as_frozen(), self.joint.as_frozen(): - - _p = next(self.joint.parameters()) - dtype = _p.dtype + _p = next(self.joint.parameters()) + dtype = _p.dtype - # Decode every sample in the batch independently. - for batch_idx in idx_gen: - inseq = encoder_output[batch_idx : batch_idx + 1, : encoded_lengths[batch_idx], :] # [1, T, D] - logitlen = encoded_lengths[batch_idx] + # Decode every sample in the batch independently. + for batch_idx in idx_gen: + inseq = encoder_output[batch_idx : batch_idx + 1, : encoded_lengths[batch_idx], :] # [1, T, D] + logitlen = encoded_lengths[batch_idx] - if inseq.dtype != dtype: - inseq = inseq.to(dtype=dtype) + if inseq.dtype != dtype: + inseq = inseq.to(dtype=dtype) - # Extract partial hypothesis if exists - partial_hypothesis = partial_hypotheses[batch_idx] if partial_hypotheses is not None else None + # Extract partial hypothesis if exists + partial_hypothesis = partial_hypotheses[batch_idx] if partial_hypotheses is not None else None - # Execute the specific search strategy - nbest_hyps = self.search_algorithm( - inseq, logitlen, partial_hypotheses=partial_hypothesis - ) # sorted list of hypothesis + # Execute the specific search strategy + nbest_hyps = self.search_algorithm( + inseq, logitlen, partial_hypotheses=partial_hypothesis + ) # sorted list of hypothesis - # Prepare the list of hypotheses - nbest_hyps = pack_hypotheses(nbest_hyps) + # Prepare the list of hypotheses + nbest_hyps = pack_hypotheses(nbest_hyps) - # Pack the result - if self.return_best_hypothesis: - best_hypothesis = nbest_hyps[0] # type: Hypothesis - else: - best_hypothesis = NBestHypotheses(nbest_hyps) # type: NBestHypotheses - hypotheses.append(best_hypothesis) + # Pack the result + if self.return_best_hypothesis: + best_hypothesis = nbest_hyps[0] # type: Hypothesis + else: + best_hypothesis = NBestHypotheses(nbest_hyps) # type: NBestHypotheses + hypotheses.append(best_hypothesis) self.decoder.train(decoder_training_state) self.joint.train(joint_training_state) @@ -639,7 +633,10 @@ def default_beam_search( # keep those hypothesis that have scores greater than next search generation hyps_max = float(max(hyps, key=lambda x: x.score).score) - kept_most_prob = sorted([hyp for hyp in kept_hyps if hyp.score > hyps_max], key=lambda x: x.score,) + kept_most_prob = sorted( + [hyp for hyp in kept_hyps if hyp.score > hyps_max], + key=lambda x: x.score, + ) # If enough hypothesis have scores greater than next search generation, # stop beam search. diff --git a/nemo/collections/asr/parts/submodules/rnnt_greedy_decoding.py b/nemo/collections/asr/parts/submodules/rnnt_greedy_decoding.py index 420e49c961420..70ab74e7b0148 100644 --- a/nemo/collections/asr/parts/submodules/rnnt_greedy_decoding.py +++ b/nemo/collections/asr/parts/submodules/rnnt_greedy_decoding.py @@ -383,14 +383,13 @@ def forward( hypotheses = [] # Process each sequence independently - with self.decoder.as_frozen(), self.joint.as_frozen(): - for batch_idx in range(encoder_output.size(0)): - inseq = encoder_output[batch_idx, :, :].unsqueeze(1) # [T, 1, D] - logitlen = encoded_lengths[batch_idx] + for batch_idx in range(encoder_output.size(0)): + inseq = encoder_output[batch_idx, :, :].unsqueeze(1) # [T, 1, D] + logitlen = encoded_lengths[batch_idx] - partial_hypothesis = partial_hypotheses[batch_idx] if partial_hypotheses is not None else None - hypothesis = self._greedy_decode(inseq, logitlen, partial_hypotheses=partial_hypothesis) - hypotheses.append(hypothesis) + partial_hypothesis = partial_hypotheses[batch_idx] if partial_hypotheses is not None else None + hypothesis = self._greedy_decode(inseq, logitlen, partial_hypotheses=partial_hypothesis) + hypotheses.append(hypothesis) # Pack results into Hypotheses packed_result = pack_hypotheses(hypotheses, encoded_lengths) @@ -720,12 +719,11 @@ def forward( self.decoder.eval() self.joint.eval() - with self.decoder.as_frozen(), self.joint.as_frozen(): - inseq = encoder_output # [B, T, D] + inseq = encoder_output # [B, T, D] - hypotheses = self._greedy_decode( - inseq, logitlen, device=inseq.device, partial_hypotheses=partial_hypotheses - ) + hypotheses = self._greedy_decode( + inseq, logitlen, device=inseq.device, partial_hypotheses=partial_hypotheses + ) # Pack the hypotheses results packed_result = pack_hypotheses(hypotheses, logitlen) @@ -2487,14 +2485,13 @@ def forward( hypotheses = [] # Process each sequence independently - with self.decoder.as_frozen(), self.joint.as_frozen(): - for batch_idx in range(encoder_output.size(0)): - inseq = encoder_output[batch_idx, :, :].unsqueeze(1) # [T, 1, D] - logitlen = encoded_lengths[batch_idx] + for batch_idx in range(encoder_output.size(0)): + inseq = encoder_output[batch_idx, :, :].unsqueeze(1) # [T, 1, D] + logitlen = encoded_lengths[batch_idx] - partial_hypothesis = partial_hypotheses[batch_idx] if partial_hypotheses is not None else None - hypothesis = self._greedy_decode(inseq, logitlen, partial_hypotheses=partial_hypothesis) - hypotheses.append(hypothesis) + partial_hypothesis = partial_hypotheses[batch_idx] if partial_hypotheses is not None else None + hypothesis = self._greedy_decode(inseq, logitlen, partial_hypotheses=partial_hypothesis) + hypotheses.append(hypothesis) # Pack results into Hypotheses packed_result = pack_hypotheses(hypotheses, encoded_lengths) @@ -2775,11 +2772,10 @@ def forward( self.decoder.eval() self.joint.eval() - with self.decoder.as_frozen(), self.joint.as_frozen(): - inseq = encoder_output # [B, T, D] - hypotheses = self._greedy_decode( - inseq, logitlen, device=inseq.device, partial_hypotheses=partial_hypotheses - ) + inseq = encoder_output # [B, T, D] + hypotheses = self._greedy_decode( + inseq, logitlen, device=inseq.device, partial_hypotheses=partial_hypotheses + ) # Pack the hypotheses results packed_result = pack_hypotheses(hypotheses, logitlen) diff --git a/nemo/collections/asr/parts/submodules/squeezeformer_modules.py b/nemo/collections/asr/parts/submodules/squeezeformer_modules.py index ff2cf7c5b3cc4..212320e1f76fc 100644 --- a/nemo/collections/asr/parts/submodules/squeezeformer_modules.py +++ b/nemo/collections/asr/parts/submodules/squeezeformer_modules.py @@ -16,14 +16,13 @@ from torch import nn as nn from torch.nn import LayerNorm +from nemo.collections.asr.parts.submodules.adapters.attention_adapter_mixin import AttentionAdapterModuleMixin from nemo.collections.asr.parts.submodules.conformer_modules import ConformerConvolution, ConformerFeedForward from nemo.collections.asr.parts.submodules.multi_head_attention import ( MultiHeadAttention, RelPositionMultiHeadAttention, ) -from nemo.collections.common.parts import adapter_modules from nemo.core.classes.mixins import AccessMixin -from nemo.core.classes.mixins.adapter_mixins import AdapterModuleMixin __all__ = ['SqueezeformerLayer', 'ConformerFeedForward', 'SqueezeformerLayer'] @@ -57,7 +56,7 @@ def forward(self, x): return x * scale + bias -class SqueezeformerLayer(torch.nn.Module, AdapterModuleMixin, AccessMixin): +class SqueezeformerLayer(torch.nn.Module, AttentionAdapterModuleMixin, AccessMixin): """A single block of the Squeezeformer encoder. Args: @@ -197,64 +196,6 @@ def forward(self, x, att_mask=None, pos_emb=None, pad_mask=None): return x - def forward_single_enabled_adapter_( - self, - input: dict, - adapter_module: torch.nn.Module, - *, - adapter_name: str, - adapter_strategy: 'nemo.core.classes.mixins.adapter_mixin_strategies.AbstractAdapterStrategy', - ): - """ - Perform the forward step of a single adapter module on some input data. - - **Note**: Subclasses can override this method to accommodate more complicate adapter forward steps. - - Args: - input: Dictionary of packed tensors. The dict should contain at least - `x`: output tensor - `loc`: Semantic location in module where this adapter was called - `att_mask`: Optional, Attention mask - `pos_emb`: Optional, Positional Embedding for Relative Positional Encoding. - The output tensor of the calling module is the input to the first adapter, whose output - is then chained to the next adapter until all adapters are consumed. - adapter_module: The adapter module that is currently required to perform the forward pass. - adapter_name: The resolved name of the adapter that is undergoing the current forward pass. - adapter_strategy: A subclass of `AbstractAdapterStrategy`, that determines how the - output of the adapter should be merged with the input, or if it should be merged at all. - - Returns: - The result tensor, after the current active adapter has finished its forward pass. - """ - # (input: torch.Tensor, adapter: torch.nn.Module, *, module: 'AdapterModuleMixin') - x = input['x'] - loc = input['loc'] - att_mask = input.get('att_mask', None) - pos_emb = input.get('pos_emb', None) - - if isinstance(adapter_module, adapter_modules.LinearAdapter) and loc == 'post': - output = adapter_strategy(x, adapter_module, module=self) - - elif isinstance(adapter_module, MultiHeadAttention) and loc == 'mha': - if self.self_attention_model == 'rel_pos': - x = dict(query=x, key=x, value=x, mask=att_mask, pos_emb=pos_emb) - output = adapter_strategy(x, adapter_module, module=self) - - elif self.self_attention_model == 'abs_pos': - x = dict(query=x, key=x, value=x, mask=att_mask) - output = adapter_strategy(x, adapter_module, module=self) - - else: - raise ValueError(f"Unsupported value of self_attention_model , provided {self.self_attention_model}!") - - else: - # No adapter compatible, skip - output = x - - input['x'] = output - - return input - def reset_parameters(self): # Used for Squeezeformer initialization only self.feed_forward1.reset_parameters_ff() diff --git a/nemo/collections/asr/parts/utils/adapter_utils.py b/nemo/collections/asr/parts/utils/adapter_utils.py index 5b74a296419a9..b85bdee7051ad 100644 --- a/nemo/collections/asr/parts/utils/adapter_utils.py +++ b/nemo/collections/asr/parts/utils/adapter_utils.py @@ -21,6 +21,8 @@ # Constants LINEAR_ADAPTER_CLASSPATH = "nemo.collections.common.parts.adapter_modules.LinearAdapter" + +# Conformer Adapters MHA_ADAPTER_CLASSPATH = ( "nemo.collections.asr.parts.submodules.adapters.multi_head_attention_adapter_module.MultiHeadAttentionAdapter" ) @@ -32,6 +34,9 @@ "nemo.collections.asr.parts.submodules.adapters.multi_head_attention_adapter_module.RelPositionalEncodingAdapter" ) +# Transformer Adapters +TRANSFORMER_MHA_ADAPTER_CLASSPATH = "nemo.collections.asr.parts.submodules.adapters.transformer_multi_head_attention_adapter_module.TransformerMultiHeadAttentionAdapter" + def convert_adapter_cfg_to_dict_config(cfg: DictConfig): # Convert to DictConfig from dict or Dataclass @@ -58,7 +63,7 @@ def update_adapter_cfg_input_dim(module: torch.nn.Module, cfg: DictConfig, *, mo """ cfg = convert_adapter_cfg_to_dict_config(cfg) - input_dim_valid_keys = ['in_features', 'n_feat'] + input_dim_valid_keys = ['in_features', 'n_feat', 'hidden_size'] input_key = None for key in input_dim_valid_keys: diff --git a/nemo/collections/asr/parts/utils/decoder_timestamps_utils.py b/nemo/collections/asr/parts/utils/decoder_timestamps_utils.py index 8ed143d3c2216..a740f899ca67d 100644 --- a/nemo/collections/asr/parts/utils/decoder_timestamps_utils.py +++ b/nemo/collections/asr/parts/utils/decoder_timestamps_utils.py @@ -23,13 +23,13 @@ import nemo.collections.asr as nemo_asr from nemo.collections.asr.metrics.wer import WER from nemo.collections.asr.models import EncDecCTCModel, EncDecCTCModelBPE +from nemo.collections.asr.parts.preprocessing.segment import get_samples from nemo.collections.asr.parts.submodules.ctc_decoding import ( CTCBPEDecoding, CTCBPEDecodingConfig, CTCDecoding, CTCDecodingConfig, ) -from nemo.collections.asr.parts.utils.audio_utils import get_samples from nemo.collections.asr.parts.utils.speaker_utils import audio_rttm_map, get_uniqname_from_filepath from nemo.collections.asr.parts.utils.streaming_utils import AudioFeatureIterator, FrameBatchASR from nemo.collections.common.tokenizers.tokenizer_spec import TokenizerSpec @@ -197,7 +197,9 @@ def decode_ids_to_tokens_with_ts(self, tokens: List[int], timestamps: List[int]) return token_list, timestamp_list def ctc_decoder_predictions_tensor_with_ts( - self, predictions: torch.Tensor, predictions_len: torch.Tensor = None, + self, + predictions: torch.Tensor, + predictions_len: torch.Tensor = None, ) -> List[str]: """ A shortened version of the original function ctc_decoder_predictions_tensor(). @@ -286,7 +288,9 @@ def _get_batch_preds(self, keep_logits): del predictions def transcribe_with_ts( - self, tokens_per_chunk: int, delay: int, + self, + tokens_per_chunk: int, + delay: int, ): self.infer_logits() self.unmerged = [] @@ -720,7 +724,10 @@ def get_word_ts_from_spaces(self, char_ts: List[float], spaces_in_sec: List[floa elif len(spaces_in_sec) > 0: # word_timetamps_middle should be an empty list if len(spaces_in_sec) == 1. word_timetamps_middle = [ - [round(spaces_in_sec[k][1], 2), round(spaces_in_sec[k + 1][0], 2),] + [ + round(spaces_in_sec[k][1], 2), + round(spaces_in_sec[k + 1][0], 2), + ] for k in range(len(spaces_in_sec) - 1) ] word_timestamps = ( diff --git a/nemo/collections/asr/parts/utils/streaming_utils.py b/nemo/collections/asr/parts/utils/streaming_utils.py index 51a46184e66fc..bae2c9ffdc67e 100644 --- a/nemo/collections/asr/parts/utils/streaming_utils.py +++ b/nemo/collections/asr/parts/utils/streaming_utils.py @@ -24,7 +24,7 @@ from nemo.collections.asr.models.ctc_bpe_models import EncDecCTCModelBPE from nemo.collections.asr.parts.mixins.streaming import StreamingEncoder from nemo.collections.asr.parts.preprocessing.features import normalize_batch -from nemo.collections.asr.parts.utils.audio_utils import get_samples +from nemo.collections.asr.parts.preprocessing.segment import get_samples from nemo.core.classes import IterableDataset from nemo.core.neural_types import LengthsType, MelSpectrogramType, NeuralType diff --git a/nemo/collections/audio/README.md b/nemo/collections/audio/README.md new file mode 100644 index 0000000000000..45a0adc931dfe --- /dev/null +++ b/nemo/collections/audio/README.md @@ -0,0 +1,10 @@ +# Audio processing collection + +The NeMo Audio Collection supports a range of models tailored for audio processing tasks, including single- and multi-channel speech enhancement and restoration. + +* Mask-based speech processing: single-channel masking and guided source separation (GSS) +* Predictive speech processing: NCSN++ +* Score-based generative models: SGMSE+ +* Multi-channel audio processing: mask-based beamforming (MVDR) and dereverberation (WPE) + +More details can be found in [NeMo documentation](https://docs.nvidia.com/nemo-framework/user-guide/latest/index.html). diff --git a/nemo/collections/audio/__init__.py b/nemo/collections/audio/__init__.py new file mode 100644 index 0000000000000..f3d1566094871 --- /dev/null +++ b/nemo/collections/audio/__init__.py @@ -0,0 +1,25 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from nemo.collections.audio import data, losses, metrics, models, modules +from nemo.package_info import __version__ + +# Set collection version equal to NeMo version. +__version = __version__ + +# Authorship. +__author__ = "NVIDIA Corporation" + +# Set collection name. +__description__ = "Audio Processing collection" diff --git a/nemo/collections/audio/data/__init__.py b/nemo/collections/audio/data/__init__.py new file mode 100644 index 0000000000000..d9155f923f186 --- /dev/null +++ b/nemo/collections/audio/data/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/nemo/collections/asr/data/audio_to_audio.py b/nemo/collections/audio/data/audio_to_audio.py similarity index 97% rename from nemo/collections/asr/data/audio_to_audio.py rename to nemo/collections/audio/data/audio_to_audio.py index 4f4727239a4b3..78d863e312d17 100644 --- a/nemo/collections/asr/data/audio_to_audio.py +++ b/nemo/collections/audio/data/audio_to_audio.py @@ -23,8 +23,7 @@ import numpy as np import torch -from nemo.collections.asr.parts.preprocessing.segment import AudioSegment -from nemo.collections.asr.parts.utils.audio_utils import ChannelSelectorType +from nemo.collections.asr.parts.preprocessing.segment import AudioSegment, ChannelSelectorType from nemo.collections.common.parts.preprocessing import collections from nemo.collections.common.parts.utils import flatten from nemo.core.classes import Dataset @@ -137,7 +136,11 @@ class ASRAudioProcessor: """ def __init__( - self, sample_rate: float, random_offset: bool, normalization_signal: Optional[str] = None, eps: float = 1e-8, + self, + sample_rate: float, + random_offset: bool, + normalization_signal: Optional[str] = None, + eps: float = 1e-8, ): self.sample_rate = sample_rate self.random_offset = random_offset @@ -226,8 +229,7 @@ def async_setup(self, value: Optional[SignalSetup]): @property def embedding_setup(self) -> SignalSetup: - """Setup signals corresponding to an embedding vector. - """ + """Setup signals corresponding to an embedding vector.""" return self._embedding_setup @embedding_setup.setter @@ -477,7 +479,7 @@ def get_samples_synchronized( available_duration = min_audio_duration - fixed_offset if available_duration <= 0: - raise ValueError(f'Fixed offset {fixed_offset}s is larger than shortest file {min_duration}s.') + raise ValueError(f'Fixed offset {fixed_offset}s is larger than shortest file {min_audio_duration}s.') if duration + fixed_offset > min_audio_duration: # The shortest file is shorter than the requested duration @@ -584,11 +586,14 @@ def get_segment_from_file( channel_selector: Select a subset of available channels. Returns: - An array with shape (samples,) or (channels, samples) + An array with shape (samples,) or (channels, samples) """ if num_samples is None: segment = AudioSegment.from_file( - audio_file=audio_file, target_sr=sample_rate, offset=offset, channel_selector=channel_selector, + audio_file=audio_file, + target_sr=sample_rate, + offset=offset, + channel_selector=channel_selector, ) else: @@ -682,7 +687,7 @@ def load_embedding_vector(filepath: str) -> np.ndarray: Args: filepath: path to a file storing a vector. Currently, it is assumed the file is a npy file. - + Returns: Array loaded from filepath. """ @@ -709,12 +714,10 @@ class BaseAudioDataset(Dataset): @property @abc.abstractmethod def output_types(self) -> Optional[Dict[str, NeuralType]]: - """Returns definitions of module output ports. - """ + """Returns definitions of module output ports.""" def __init__(self, collection: collections.Audio, audio_processor: Callable, output_type: Type[namedtuple]): - """Instantiates an audio dataset. - """ + """Instantiates an audio dataset.""" super().__init__() self.collection = collection @@ -732,7 +735,7 @@ def num_channels(self, signal_key) -> int: NOTE: This assumes that all examples have the same number of channels. - + Args: signal_key: string, used to select a signal from the dictionary output by __getitem__ @@ -774,13 +777,11 @@ def __getitem__(self, index: int) -> Dict[str, torch.Tensor]: return output def __len__(self) -> int: - """Return the number of examples in the dataset. - """ + """Return the number of examples in the dataset.""" return len(self.collection) def _collate_fn(self, batch) -> Tuple[torch.Tensor]: - """Collate items in a batch. - """ + """Collate items in a batch.""" return self.output_type(*_audio_collate_fn(batch)) @@ -865,7 +866,9 @@ def __init__( ) audio_processor = ASRAudioProcessor( - sample_rate=sample_rate, random_offset=random_offset, normalization_signal=normalization_signal, + sample_rate=sample_rate, + random_offset=random_offset, + normalization_signal=normalization_signal, ) audio_processor.sync_setup = SignalSetup( signals=['input_signal', 'target_signal'], @@ -886,7 +889,7 @@ def output_types(self) -> Optional[Dict[str, NeuralType]]: 'input_signal': batched single- or multi-channel format, 'input_length': batched original length of each input signal 'target_signal': batched single- or multi-channel format, - 'target_length': batched original length of each target signal + 'target_length': batched original length of each target signal } ``` """ @@ -996,7 +999,9 @@ def __init__( ) audio_processor = ASRAudioProcessor( - sample_rate=sample_rate, random_offset=random_offset, normalization_signal=normalization_signal, + sample_rate=sample_rate, + random_offset=random_offset, + normalization_signal=normalization_signal, ) if reference_is_synchronized: @@ -1130,7 +1135,9 @@ def __init__( ) audio_processor = ASRAudioProcessor( - sample_rate=sample_rate, random_offset=random_offset, normalization_signal=normalization_signal, + sample_rate=sample_rate, + random_offset=random_offset, + normalization_signal=normalization_signal, ) audio_processor.sync_setup = SignalSetup( signals=['input_signal', 'target_signal'], diff --git a/nemo/collections/asr/data/audio_to_audio_dataset.py b/nemo/collections/audio/data/audio_to_audio_dataset.py similarity index 98% rename from nemo/collections/asr/data/audio_to_audio_dataset.py rename to nemo/collections/audio/data/audio_to_audio_dataset.py index 46e47020fda0b..38ea5ef9cd39e 100644 --- a/nemo/collections/asr/data/audio_to_audio_dataset.py +++ b/nemo/collections/audio/data/audio_to_audio_dataset.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from nemo.collections.asr.data import audio_to_audio +from nemo.collections.audio.data import audio_to_audio def get_audio_to_target_dataset(config: dict) -> audio_to_audio.AudioToTargetDataset: diff --git a/nemo/collections/asr/data/audio_to_audio_lhotse.py b/nemo/collections/audio/data/audio_to_audio_lhotse.py similarity index 98% rename from nemo/collections/asr/data/audio_to_audio_lhotse.py rename to nemo/collections/audio/data/audio_to_audio_lhotse.py index 6317d8a929c20..27d8a0ed28d74 100644 --- a/nemo/collections/asr/data/audio_to_audio_lhotse.py +++ b/nemo/collections/audio/data/audio_to_audio_lhotse.py @@ -104,7 +104,12 @@ def create_array(path: str) -> Array: assert path.endswith(".npy"), f"Currently only conversion of numpy files is supported (got: {path})" arr = np.load(path) parent, path = os.path.split(path) - return Array(storage_type="numpy_files", storage_path=parent, storage_key=path, shape=list(arr.shape),) + return Array( + storage_type="numpy_files", + storage_path=parent, + storage_key=path, + shape=list(arr.shape), + ) def convert_manifest_nemo_to_lhotse( @@ -118,7 +123,7 @@ def convert_manifest_nemo_to_lhotse( ): """ Convert an audio-to-audio manifest from NeMo format to Lhotse format. - + Args: input_manifest: Path to the input NeMo manifest. output_manifest: Path where we'll write the output Lhotse manifest (supported extensions: .jsonl.gz and .jsonl). diff --git a/nemo/collections/audio/data/data_simulation.py b/nemo/collections/audio/data/data_simulation.py new file mode 100644 index 0000000000000..d03c5c64d307d --- /dev/null +++ b/nemo/collections/audio/data/data_simulation.py @@ -0,0 +1,2385 @@ +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import itertools +import multiprocessing +import os +import random +from typing import Dict, Iterable, List, Optional, Tuple, Union + +import h5py +import librosa +import matplotlib.pyplot as plt +import numpy as np +import soundfile as sf +from numpy.random import default_rng +from omegaconf import DictConfig, OmegaConf +from scipy.signal import convolve +from scipy.spatial.transform import Rotation +from tqdm import tqdm + +from nemo.collections.asr.parts.preprocessing.segment import AudioSegment +from nemo.collections.asr.parts.utils.manifest_utils import read_manifest, write_manifest +from nemo.collections.audio.parts.utils.audio import db2mag, generate_approximate_noise_field, mag2db, pow2db, rms +from nemo.utils import logging + +try: + import pyroomacoustics as pra + + PRA = True +except ImportError: + PRA = False + + +def check_angle(key: str, val: Union[float, Iterable[float]]) -> bool: + """Check if the angle value is within the expected range. Input + values are in degrees. + + Note: + azimuth: angle between a projection on the horizontal (xy) plane and + positive x axis. Increases counter-clockwise. Range: [-180, 180]. + elevation: angle between a vector an its projection on the horizontal (xy) plane. + Positive above, negative below, i.e., north=+90, south=-90. Range: [-90, 90] + yaw: rotation around the z axis. Defined accoding to right-hand rule. + Range: [-180, 180] + pitch: rotation around the yʹ axis. Defined accoding to right-hand rule. + Range: [-90, 90] + roll: rotation around the xʺ axis. Defined accoding to right-hand rule. + Range: [-180, 180] + + Args: + key: angle type + val: values in degrees + + Returns: + True if all values are within the expected range. + """ + if np.isscalar(val): + min_val = max_val = val + else: + min_val = min(val) + max_val = max(val) + + if key == 'azimuth' and -180 <= min_val <= max_val <= 180: + return True + if key == 'elevation' and -90 <= min_val <= max_val <= 90: + return True + if key == 'yaw' and -180 <= min_val <= max_val <= 180: + return True + if key == 'pitch' and -90 <= min_val <= max_val <= 90: + return True + if key == 'roll' and -180 <= min_val <= max_val <= 180: + return True + + raise ValueError(f'Invalid value for angle {key} = {val}') + + +def wrap_to_180(angle: float) -> float: + """Wrap an angle to range ±180 degrees. + + Args: + angle: angle in degrees + + Returns: + Angle in degrees wrapped to ±180 degrees. + """ + return angle - np.floor(angle / 360 + 1 / 2) * 360 + + +class ArrayGeometry(object): + """A class to simplify handling of array geometry. + + Supports translation and rotation of the array and calculation of + spherical coordinates of a given point relative to the internal + coordinate system of the array. + + Args: + mic_positions: 3D coordinates, with shape (num_mics, 3) + center: optional position of the center of the array. Defaults to the average of the coordinates. + internal_cs: internal coordinate system for the array relative to the global coordinate system. + Defaults to (x, y, z), and is rotated with the array. + """ + + def __init__( + self, + mic_positions: Union[np.ndarray, List], + center: Optional[np.ndarray] = None, + internal_cs: Optional[np.ndarray] = None, + ): + if isinstance(mic_positions, Iterable): + mic_positions = np.array(mic_positions) + + if not mic_positions.ndim == 2: + raise ValueError( + f'Expecting a 2D array specifying mic positions, but received {mic_positions.ndim}-dim array' + ) + + if not mic_positions.shape[1] == 3: + raise ValueError(f'Expecting 3D positions, but received {mic_positions.shape[1]}-dim positions') + + mic_positions_center = np.mean(mic_positions, axis=0) + self.centered_positions = mic_positions - mic_positions_center + self.center = mic_positions_center if center is None else center + + # Internal coordinate system + if internal_cs is None: + # Initially aligned with the global + self.internal_cs = np.eye(3) + else: + self.internal_cs = internal_cs + + @property + def num_mics(self): + """Return the number of microphones for the current array.""" + return self.centered_positions.shape[0] + + @property + def positions(self): + """Absolute positions of the microphones.""" + return self.centered_positions + self.center + + @property + def internal_positions(self): + """Positions in the internal coordinate system.""" + return np.matmul(self.centered_positions, self.internal_cs.T) + + @property + def radius(self): + """Radius of the array, relative to the center.""" + return max(np.linalg.norm(self.centered_positions, axis=1)) + + @staticmethod + def get_rotation(yaw: float = 0, pitch: float = 0, roll: float = 0) -> Rotation: + """Get a Rotation object for given angles. + + All angles are defined according to the right-hand rule. + + Args: + yaw: rotation around the z axis + pitch: rotation around the yʹ axis + roll: rotation around the xʺ axis + + Returns: + A rotation object constructed using the provided angles. + """ + check_angle('yaw', yaw) + check_angle('pitch', pitch) + check_angle('roll', roll) + + return Rotation.from_euler('ZYX', [yaw, pitch, roll], degrees=True) + + def translate(self, to: np.ndarray): + """Translate the array center to a new point. + + Translation does not change the centered positions or the internal coordinate system. + + Args: + to: 3D point, shape (3,) + """ + self.center = to + + def rotate(self, yaw: float = 0, pitch: float = 0, roll: float = 0): + """Apply rotation on the mic array. + + This rotates the centered microphone positions and the internal + coordinate system, it doesn't change the center of the array. + + All angles are defined according to the right-hand rule. + For example, this means that a positive pitch will result in a rotation from z + to x axis, which will result in a reduced elevation with respect to the global + horizontal plane. + + Args: + yaw: rotation around the z axis + pitch: rotation around the yʹ axis + roll: rotation around the xʺ axis + """ + # construct rotation using TB angles + rotation = self.get_rotation(yaw=yaw, pitch=pitch, roll=roll) + + # rotate centered positions + self.centered_positions = rotation.apply(self.centered_positions) + + # apply the same transformation on the internal coordinate system + self.internal_cs = rotation.apply(self.internal_cs) + + def new_rotated_array(self, yaw: float = 0, pitch: float = 0, roll: float = 0): + """Create a new array by rotating this array. + + Args: + yaw: rotation around the z axis + pitch: rotation around the yʹ axis + roll: rotation around the xʺ axis + + Returns: + A new ArrayGeometry object constructed using the provided angles. + """ + new_array = ArrayGeometry(mic_positions=self.positions, center=self.center, internal_cs=self.internal_cs) + new_array.rotate(yaw=yaw, pitch=pitch, roll=roll) + return new_array + + def spherical_relative_to_array( + self, point: np.ndarray, use_internal_cs: bool = True + ) -> Tuple[float, float, float]: + """Return spherical coordinates of a point relative to the internal coordinate system. + + Args: + point: 3D coordinate, shape (3,) + use_internal_cs: Calculate position relative to the internal coordinate system. + If `False`, the positions will be calculated relative to the + external coordinate system centered at `self.center`. + + Returns: + A tuple (distance, azimuth, elevation) relative to the mic array. + """ + rel_position = point - self.center + distance = np.linalg.norm(rel_position) + + if use_internal_cs: + # transform from the absolute coordinate system to the internal coordinate system + rel_position = np.matmul(self.internal_cs, rel_position) + + # get azimuth + azimuth = np.arctan2(rel_position[1], rel_position[0]) / np.pi * 180 + # get elevation + elevation = np.arcsin(rel_position[2] / distance) / np.pi * 180 + + return distance, azimuth, elevation + + def __str__(self): + with np.printoptions(precision=3, suppress=True): + desc = f"{type(self)}:\ncenter =\n{self.center}\ncentered positions =\n{self.centered_positions}\nradius = \n{self.radius:.3}\nabsolute positions =\n{self.positions}\ninternal coordinate system =\n{self.internal_cs}\n\n" + return desc + + def plot(self, elev=30, azim=-55, mic_size=25): + """Plot microphone positions. + + Args: + elev: elevation for the view of the plot + azim: azimuth for the view of the plot + mic_size: size of the microphone marker in the plot + """ + fig = plt.figure() + ax = fig.add_subplot(projection='3d') + + # show mic positions + for m in range(self.num_mics): + # show mic + ax.scatter( + self.positions[m, 0], + self.positions[m, 1], + self.positions[m, 2], + marker='o', + c='black', + s=mic_size, + depthshade=False, + ) + # add label + ax.text(self.positions[m, 0], self.positions[m, 1], self.positions[m, 2], str(m), c='red', zorder=10) + + # show the internal coordinate system + ax.quiver( + self.center[0], + self.center[1], + self.center[2], + self.internal_cs[:, 0], + self.internal_cs[:, 1], + self.internal_cs[:, 2], + length=self.radius, + label='internal cs', + normalize=False, + linestyle=':', + linewidth=1.0, + ) + for dim, label in enumerate(['x′', 'y′', 'z′']): + label_pos = self.center + self.radius * self.internal_cs[dim] + ax.text(label_pos[0], label_pos[1], label_pos[2], label, tuple(self.internal_cs[dim]), c='blue') + try: + # Unfortunately, equal aspect ratio has been added very recently to Axes3D + ax.set_aspect('equal') + except NotImplementedError: + logging.warning('Equal aspect ratio not supported by Axes3D') + # Set view + ax.view_init(elev=elev, azim=azim) + # Set reasonable limits for all axes, even for the case of an unequal aspect ratio + ax.set_xlim([self.center[0] - self.radius, self.center[0] + self.radius]) + ax.set_ylim([self.center[1] - self.radius, self.center[1] + self.radius]) + ax.set_zlim([self.center[2] - self.radius, self.center[2] + self.radius]) + + ax.set_xlabel('x/m') + ax.set_ylabel('y/m') + ax.set_zlabel('z/m') + ax.set_title('Microphone positions') + ax.legend() + plt.show() + + +def convert_placement_to_range( + placement: dict, room_dim: Iterable[float], object_radius: float = 0 +) -> List[List[float]]: + """Given a placement dictionary, return ranges for each dimension. + + Args: + placement: dictionary containing x, y, height, and min_to_wall + room_dim: dimensions of the room, shape (3,) + object_radius: radius of the object to be placed + + Returns + List with a range of values for each dimensions. + """ + if not np.all(np.array(room_dim) > 0): + raise ValueError(f'Room dimensions must be positive: {room_dim}') + + if object_radius < 0: + raise ValueError(f'Object radius must be non-negative: {object_radius}') + + placement_range = [None] * 3 + min_to_wall = placement.get('min_to_wall', 0) + + if min_to_wall < 0: + raise ValueError(f'Min distance to wall must be positive: {min_to_wall}') + + for idx, key in enumerate(['x', 'y', 'height']): + # Room dimension + dim = room_dim[idx] + # Construct the range + val = placement.get(key) + if val is None: + # No constrained specified on the coordinate of the mic center + min_val, max_val = 0, dim + elif np.isscalar(val): + min_val = max_val = val + else: + if len(val) != 2: + raise ValueError(f'Invalid value for placement for dim {idx}/{key}: {str(placement)}') + min_val, max_val = val + + # Make sure the array is not too close to a wall + min_val = max(min_val, min_to_wall + object_radius) + max_val = min(max_val, dim - min_to_wall - object_radius) + + if min_val > max_val or min(min_val, max_val) < 0: + raise ValueError(f'Invalid range dim {idx}/{key}: min={min_val}, max={max_val}') + + placement_range[idx] = [min_val, max_val] + + return placement_range + + +class RIRCorpusGenerator(object): + """Creates a corpus of RIRs based on a defined configuration of rooms and microphone array. + + RIRs are generated using `generate` method. + """ + + def __init__(self, cfg: DictConfig): + """ + Args: + cfg: dictionary with parameters of the simulation + """ + logging.info("Initialize RIRCorpusGenerator") + self._cfg = cfg + self.check_cfg() + + @property + def cfg(self): + """Property holding the internal config of the object. + + Note: + Changes to this config are not reflected in the state of the object. + Please create a new model with the updated config. + """ + return self._cfg + + @property + def sample_rate(self): + return self._cfg.sample_rate + + @cfg.setter + def cfg(self, cfg): + """Property holding the internal config of the object. + + Note: + Changes to this config are not reflected in the state of the object. + Please create a new model with the updated config. + """ + self._cfg = cfg + + def check_cfg(self): + """ + Checks provided configuration to ensure it has the minimal required + configuration the values are in a reasonable range. + """ + # sample rate + sample_rate = self.cfg.get('sample_rate') + if sample_rate is None: + raise ValueError('Sample rate not provided.') + elif sample_rate < 0: + raise ValueError(f'Sample rate must to be positive: {sample_rate}') + + # room configuration + room_cfg = self.cfg.get('room') + if room_cfg is None: + raise ValueError('Room configuration not provided') + + if room_cfg.get('num') is None: + raise ValueError('Number of rooms per subset not provided') + + if room_cfg.get('dim') is None: + raise ValueError('Room dimensions not provided') + + for idx, key in enumerate(['width', 'length', 'height']): + dim = room_cfg.dim.get(key) + + if dim is None: + # not provided + raise ValueError(f'Room {key} needs to be a scalar or a range, currently it is None') + elif np.isscalar(dim) and dim <= 0: + # fixed dimension + raise ValueError(f'A fixed dimension must be positive for {key}: {dim}') + elif len(dim) != 2 or not 0 < dim[0] < dim[1]: + # not a valid range + raise ValueError(f'Range must be specified with two positive increasing elements for {key}: {dim}') + + rt60 = room_cfg.get('rt60') + if rt60 is None: + # not provided + raise ValueError('RT60 needs to be a scalar or a range, currently it is None') + elif np.isscalar(rt60) and rt60 <= 0: + # fixed dimension + raise ValueError(f'RT60 must be positive: {rt60}') + elif len(rt60) != 2 or not 0 < rt60[0] < rt60[1]: + # not a valid range + raise ValueError(f'RT60 range must be specified with two positive increasing elements: {rt60}') + + # mic array + mic_cfg = self.cfg.get('mic_array') + if mic_cfg is None: + raise ValueError('Mic configuration not provided') + + if mic_cfg.get('positions') == 'random': + # Only num_mics and placement are required + mic_cfg_keys = ['num_mics', 'placement'] + else: + mic_cfg_keys = ['positions', 'placement', 'orientation'] + + for key in mic_cfg_keys: + if key not in mic_cfg: + raise ValueError(f'Mic array {key} not provided') + + # source + source_cfg = self.cfg.get('source') + if source_cfg is None: + raise ValueError('Source configuration not provided') + + if source_cfg.get('num') is None: + raise ValueError('Number of sources per room not provided') + elif source_cfg.num <= 0: + raise ValueError(f'Number of sources must be positive: {source_cfg.num}') + + if 'placement' not in source_cfg: + raise ValueError('Source placement dictionary not provided') + + # anechoic + if self.cfg.get('anechoic') is None: + raise ValueError('Anechoic configuratio not provided.') + + def generate_room_params(self) -> dict: + """Generate randomized room parameters based on the provided + configuration. + """ + # Prepare room sim parameters + if not PRA: + raise ImportError('pyroomacoustics is required for room simulation') + + room_cfg = self.cfg.room + + # Prepare rt60 + if room_cfg.rt60 is None: + raise ValueError('Room RT60 needs to be a scalar or a range, currently it is None') + + if np.isscalar(room_cfg.rt60): + assert room_cfg.rt60 > 0, f'RT60 should be positive: {room_cfg.rt60}' + rt60 = room_cfg.rt60 + elif len(room_cfg.rt60) == 2: + assert ( + 0 < room_cfg.rt60[0] <= room_cfg.rt60[1] + ), f'Expecting two non-decreasing values for RT60, received {room_cfg.rt60}' + rt60 = self.random.uniform(low=room_cfg.rt60[0], high=room_cfg.rt60[1]) + else: + raise ValueError(f'Unexpected value for RT60: {room_cfg.rt60}') + + # Generate a room with random dimensions + num_retries = self.cfg.get('num_retries', 20) + + for n in range(num_retries): + + # width, length, height + room_dim = np.zeros(3) + + # prepare dimensions + for idx, key in enumerate(['width', 'length', 'height']): + # get configured dimension + dim = room_cfg.dim[key] + + # set a value + if dim is None: + raise ValueError(f'Room {key} needs to be a scalar or a range, currently it is None') + elif np.isscalar(dim): + assert dim > 0, f'Dimension should be positive for {key}: {dim}' + room_dim[idx] = dim + elif len(dim) == 2: + assert 0 < dim[0] <= dim[1], f'Expecting two non-decreasing values for {key}, received {dim}' + # Reduce dimension if the previous attempt failed + room_dim[idx] = self.random.uniform(low=dim[0], high=dim[1] - n * (dim[1] - dim[0]) / num_retries) + else: + raise ValueError(f'Unexpected value for {key}: {dim}') + + try: + # Get parameters from size and RT60 + room_absorption, room_max_order = pra.inverse_sabine(rt60, room_dim) + break + except Exception as e: + logging.debug('Inverse sabine failed: %s', str(e)) + # Inverse sabine may fail if the room is too large for the selected RT60. + # Try again by generate a smaller room. + room_absorption = room_max_order = None + continue + + if room_absorption is None or room_max_order is None: + raise RuntimeError(f'Evaluation of parameters failed for RT60 {rt60}s and room size {room_dim}.') + + # Return the required values + room_params = { + 'dim': room_dim, + 'absorption': room_absorption, + 'max_order': room_max_order, + 'rt60_theoretical': rt60, + 'anechoic_absorption': self.cfg.anechoic.absorption, + 'anechoic_max_order': self.cfg.anechoic.max_order, + 'sample_rate': self.cfg.sample_rate, + } + return room_params + + def generate_array(self, room_dim: Iterable[float]) -> ArrayGeometry: + """Generate array placement for the current room and config. + + Args: + room_dim: dimensions of the room, [width, length, height] + + Returns: + Randomly placed microphone array. + """ + mic_cfg = self.cfg.mic_array + + if mic_cfg.positions == 'random': + # Create a radom set of microphones + num_mics = mic_cfg.num_mics + mic_positions = [] + + # Each microphone is placed individually + placement_range = convert_placement_to_range( + placement=mic_cfg.placement, room_dim=room_dim, object_radius=0 + ) + + # Randomize mic placement + for m in range(num_mics): + position_m = [None] * 3 + for idx in range(3): + position_m[idx] = self.random.uniform(low=placement_range[idx][0], high=placement_range[idx][1]) + mic_positions.append(position_m) + + mic_array = ArrayGeometry(mic_positions) + + else: + mic_array = ArrayGeometry(mic_cfg.positions) + + # Randomize center placement + center = np.zeros(3) + placement_range = convert_placement_to_range( + placement=mic_cfg.placement, room_dim=room_dim, object_radius=mic_array.radius + ) + + for idx in range(len(center)): + center[idx] = self.random.uniform(low=placement_range[idx][0], high=placement_range[idx][1]) + + # Place the array at the configured center point + mic_array.translate(to=center) + + # Randomize orientation + orientation = dict() + for key in ['yaw', 'roll', 'pitch']: + # angle for current orientation + angle = mic_cfg.orientation[key] + + if angle is None: + raise ValueError(f'Mic array {key} should be a scalar or a range, currently it is set to None.') + + # check it's within the expected range + check_angle(key, angle) + + if np.isscalar(angle): + orientation[key] = angle + elif len(angle) == 2: + assert angle[0] <= angle[1], f"Expecting two non-decreasing values for {key}, received {angle}" + # generate integer values, for easier bucketing, if necessary + orientation[key] = self.random.uniform(low=angle[0], high=angle[1]) + else: + raise ValueError(f'Unexpected value for orientation {key}: {angle}') + + # Rotate the array to match the selected orientation + mic_array.rotate(**orientation) + + return mic_array + + def generate_source_position(self, room_dim: Iterable[float]) -> List[List[float]]: + """Generate position for all sources in a room. + + Args: + room_dim: dimensions of a 3D shoebox room + + Returns: + List of source positions, with each position characterized with a 3D coordinate + """ + source_cfg = self.cfg.source + placement_range = convert_placement_to_range(placement=source_cfg.placement, room_dim=room_dim) + source_position = [] + + for n in range(source_cfg.num): + # generate a random point withing the range + s_pos = [None] * 3 + for idx in range(len(s_pos)): + s_pos[idx] = self.random.uniform(low=placement_range[idx][0], high=placement_range[idx][1]) + source_position.append(s_pos) + + return source_position + + def generate(self): + """Generate RIR corpus. + + This method will prepare randomized examples based on the current configuration, + run room simulations and save results to output_dir. + """ + logging.info("Generate RIR corpus") + + # Initialize + self.random = default_rng(seed=self.cfg.random_seed) + + # Prepare output dir + output_dir = self.cfg.output_dir + if output_dir.endswith('.yaml'): + output_dir = output_dir[:-5] + + # Create absolute path + logging.info('Output dir set to: %s', output_dir) + + # Generate all cases + for subset, num_rooms in self.cfg.room.num.items(): + + output_dir_subset = os.path.join(output_dir, subset) + examples = [] + + if not os.path.exists(output_dir_subset): + logging.info('Creating output directory: %s', output_dir_subset) + os.makedirs(output_dir_subset) + elif os.path.isdir(output_dir_subset) and len(os.listdir(output_dir_subset)) > 0: + raise RuntimeError(f'Output directory {output_dir_subset} is not empty.') + + # Generate examples + for n_room in range(num_rooms): + + # room info + room_params = self.generate_room_params() + + # array placement + mic_array = self.generate_array(room_params['dim']) + + # source placement + source_position = self.generate_source_position(room_params['dim']) + + # file name for the file + room_filepath = os.path.join(output_dir_subset, f'{subset}_room_{n_room:06d}.h5') + + # prepare example + example = { + 'room_params': room_params, + 'mic_array': mic_array, + 'source_position': source_position, + 'room_filepath': room_filepath, + } + examples.append(example) + + # Simulation + if (num_workers := self.cfg.get('num_workers')) is None: + num_workers = os.cpu_count() - 1 + + if num_workers > 1: + logging.info(f'Simulate using {num_workers} workers') + with multiprocessing.Pool(processes=num_workers) as pool: + metadata = list(tqdm(pool.imap(simulate_room_kwargs, examples), total=len(examples))) + + else: + logging.info('Simulate using a single worker') + metadata = [] + for example in tqdm(examples, total=len(examples)): + metadata.append(simulate_room(**example)) + + # Save manifest + manifest_filepath = os.path.join(output_dir, f'{subset}_manifest.json') + + if os.path.exists(manifest_filepath) and os.path.isfile(manifest_filepath): + raise RuntimeError(f'Manifest config file exists: {manifest_filepath}') + + # Make all paths in the manifest relative to the output dir + for data in metadata: + data['room_filepath'] = os.path.relpath(data['room_filepath'], start=output_dir) + + write_manifest(manifest_filepath, metadata) + + # Generate plots with information about generated data + plot_filepath = os.path.join(output_dir, f'{subset}_info.png') + + if os.path.exists(plot_filepath) and os.path.isfile(plot_filepath): + raise RuntimeError(f'Plot file exists: {plot_filepath}') + + plot_rir_manifest_info(manifest_filepath, plot_filepath=plot_filepath) + + # Save used configuration for reference + config_filepath = os.path.join(output_dir, 'config.yaml') + if os.path.exists(config_filepath) and os.path.isfile(config_filepath): + raise RuntimeError(f'Output config file exists: {config_filepath}') + + OmegaConf.save(self.cfg, config_filepath, resolve=True) + + +def simulate_room_kwargs(kwargs: dict) -> dict: + """Wrapper around `simulate_room` to handle kwargs. + + `pool.map(simulate_room_kwargs, examples)` would be + equivalent to `pool.starstarmap(simulate_room, examples)` + if `starstarmap` would exist. + + Args: + kwargs: kwargs that are forwarded to `simulate_room` + + Returns: + Dictionary with metadata, see `simulate_room` + """ + return simulate_room(**kwargs) + + +def simulate_room( + room_params: dict, + mic_array: ArrayGeometry, + source_position: Iterable[Iterable[float]], + room_filepath: str, +) -> dict: + """Simulate room + + Args: + room_params: parameters of the room to be simulated + mic_array: defines positions of the microphones + source_positions: positions for all sources to be simulated + room_filepath: results are saved to this path + + Returns: + Dictionary with metadata based on simulation setup + and simulation results. Used to create the corresponding + manifest file. + """ + # room with the selected parameters + room_sim = pra.ShoeBox( + room_params['dim'], + fs=room_params['sample_rate'], + materials=pra.Material(room_params['absorption']), + max_order=room_params['max_order'], + ) + + # same geometry for generating anechoic responses + room_anechoic = pra.ShoeBox( + room_params['dim'], + fs=room_params['sample_rate'], + materials=pra.Material(room_params['anechoic_absorption']), + max_order=room_params['anechoic_max_order'], + ) + + # Compute RIRs + for room in [room_sim, room_anechoic]: + # place the array + room.add_microphone_array(mic_array.positions.T) + + # place the sources + for s_pos in source_position: + room.add_source(s_pos) + + # generate RIRs + room.compute_rir() + + # Get metadata for sources + source_distance = [] + source_azimuth = [] + source_elevation = [] + for s_pos in source_position: + distance, azimuth, elevation = mic_array.spherical_relative_to_array(s_pos) + source_distance.append(distance) + source_azimuth.append(azimuth) + source_elevation.append(elevation) + + # RIRs + rir_dataset = { + 'rir': convert_rir_to_multichannel(room_sim.rir), + 'anechoic': convert_rir_to_multichannel(room_anechoic.rir), + } + + # Prepare metadata dict and return + metadata = { + 'room_filepath': room_filepath, + 'sample_rate': room_params['sample_rate'], + 'dim': room_params['dim'], + 'rir_absorption': room_params['absorption'], + 'rir_max_order': room_params['max_order'], + 'rir_rt60_theory': room_sim.rt60_theory(), + 'rir_rt60_measured': room_sim.measure_rt60().mean(axis=0), # average across mics for each source + 'anechoic_rt60_theory': room_anechoic.rt60_theory(), + 'anechoic_rt60_measured': room_anechoic.measure_rt60().mean(axis=0), # average across mics for each source + 'anechoic_absorption': room_params['anechoic_absorption'], + 'anechoic_max_order': room_params['anechoic_max_order'], + 'mic_positions': mic_array.positions, + 'mic_center': mic_array.center, + 'source_position': source_position, + 'source_distance': source_distance, + 'source_azimuth': source_azimuth, + 'source_elevation': source_elevation, + 'num_sources': len(source_position), + } + + # Save simulated RIR + save_rir_simulation(room_filepath, rir_dataset, metadata) + + return convert_numpy_to_serializable(metadata) + + +def save_rir_simulation(filepath: str, rir_dataset: Dict[str, List[np.array]], metadata: dict): + """Save simulated RIRs and metadata. + + Args: + filepath: Path to the file where the data will be saved. + rir_dataset: Dictionary with RIR data. Each item is a set of multi-channel RIRs. + metadata: Dictionary with related metadata. + """ + if os.path.exists(filepath): + raise RuntimeError(f'Output file exists: {filepath}') + + num_sources = metadata['num_sources'] + + with h5py.File(filepath, 'w') as h5f: + # Save RIRs, each RIR set in a separate group + for rir_key, rir_value in rir_dataset.items(): + if len(rir_value) != num_sources: + raise ValueError( + f'Each RIR dataset should have exactly {num_sources} elements. Current RIR {rir_key} has {len(rir_value)} elements' + ) + + rir_group = h5f.create_group(rir_key) + + # RIRs for different sources are saved under [group]['idx'] + for idx, rir in enumerate(rir_value): + rir_group.create_dataset(f'{idx}', data=rir_value[idx]) + + # Save metadata + metadata_group = h5f.create_group('metadata') + for key, value in metadata.items(): + metadata_group.create_dataset(key, data=value) + + +def load_rir_simulation(filepath: str, source: int = 0, rir_key: str = 'rir') -> Tuple[np.ndarray, float]: + """Load simulated RIRs and metadata. + + Args: + filepath: Path to simulated RIR data + source: Index of a source. + rir_key: String to denote which RIR to load, if there are multiple available. + + Returns: + Multichannel RIR as ndarray with shape (num_samples, num_channels) and scalar sample rate. + """ + with h5py.File(filepath, 'r') as h5f: + # Load RIR + rir = h5f[rir_key][f'{source}'][:] + + # Load metadata + sample_rate = h5f['metadata']['sample_rate'][()] + + return rir, sample_rate + + +def convert_numpy_to_serializable(data: Union[dict, float, np.ndarray]) -> Union[dict, float, np.ndarray]: + """Convert all numpy estries to list. + Can be used to preprocess data before writing to a JSON file. + + Args: + data: Dictionary, array or scalar. + + Returns: + The same structure, but converted to list if + the input is np.ndarray, so `data` can be seralized. + """ + if isinstance(data, dict): + for key, val in data.items(): + data[key] = convert_numpy_to_serializable(val) + elif isinstance(data, list): + data = [convert_numpy_to_serializable(d) for d in data] + elif isinstance(data, np.ndarray): + data = data.tolist() + elif isinstance(data, np.integer): + data = int(data) + elif isinstance(data, np.floating): + data = float(data) + elif isinstance(data, np.generic): + data = data.item() + + return data + + +def convert_rir_to_multichannel(rir: List[List[np.ndarray]]) -> List[np.ndarray]: + """Convert RIR to a list of arrays. + + Args: + rir: list of lists, each element is a single-channel RIR + + Returns: + List of multichannel RIRs + """ + num_mics = len(rir) + num_sources = len(rir[0]) + + mc_rir = [None] * num_sources + + for n_source in range(num_sources): + rir_len = [len(rir[m][n_source]) for m in range(num_mics)] + max_len = max(rir_len) + mc_rir[n_source] = np.zeros((max_len, num_mics)) + for n_mic, len_mic in enumerate(rir_len): + mc_rir[n_source][:len_mic, n_mic] = rir[n_mic][n_source] + + return mc_rir + + +def plot_rir_manifest_info(filepath: str, plot_filepath: str = None): + """Plot distribution of parameters from manifest file. + + Args: + filepath: path to a RIR corpus manifest file + plot_filepath: path to save the plot at + """ + metadata = read_manifest(filepath) + + # source placement + source_distance = [] + source_azimuth = [] + source_elevation = [] + source_height = [] + + # room config + rir_rt60_theory = [] + rir_rt60_measured = [] + anechoic_rt60_theory = [] + anechoic_rt60_measured = [] + + # get the required data + for data in metadata: + # source config + source_distance += data['source_distance'] + source_azimuth += data['source_azimuth'] + source_elevation += data['source_elevation'] + source_height += [s_pos[2] for s_pos in data['source_position']] + + # room config + rir_rt60_theory.append(data['rir_rt60_theory']) + rir_rt60_measured += data['rir_rt60_measured'] + anechoic_rt60_theory.append(data['anechoic_rt60_theory']) + anechoic_rt60_measured += data['anechoic_rt60_measured'] + + # plot + plt.figure(figsize=(12, 6)) + + plt.subplot(2, 4, 1) + plt.hist(source_distance, label='distance') + plt.xlabel('distance / m') + plt.ylabel('# examples') + plt.title('Source-to-array center distance') + + plt.subplot(2, 4, 2) + plt.hist(source_azimuth, label='azimuth') + plt.xlabel('azimuth / deg') + plt.ylabel('# examples') + plt.title('Source-to-array center azimuth') + + plt.subplot(2, 4, 3) + plt.hist(source_elevation, label='elevation') + plt.xlabel('elevation / deg') + plt.ylabel('# examples') + plt.title('Source-to-array center elevation') + + plt.subplot(2, 4, 4) + plt.hist(source_height, label='source height') + plt.xlabel('height / m') + plt.ylabel('# examples') + plt.title('Source height') + + plt.subplot(2, 4, 5) + plt.hist(rir_rt60_theory, label='theory') + plt.xlabel('RT60 / s') + plt.ylabel('# examples') + plt.title('RT60 theory') + + plt.subplot(2, 4, 6) + plt.hist(rir_rt60_measured, label='measured') + plt.xlabel('RT60 / s') + plt.ylabel('# examples') + plt.title('RT60 measured') + + plt.subplot(2, 4, 7) + plt.hist(anechoic_rt60_theory, label='theory') + plt.xlabel('RT60 / s') + plt.ylabel('# examples') + plt.title('RT60 theory (anechoic)') + + plt.subplot(2, 4, 8) + plt.hist(anechoic_rt60_measured, label='measured') + plt.xlabel('RT60 / s') + plt.ylabel('# examples') + plt.title('RT60 measured (anechoic)') + + for n in range(8): + plt.subplot(2, 4, n + 1) + plt.grid() + plt.legend(loc='lower left') + + plt.tight_layout() + + if plot_filepath is not None: + plt.savefig(plot_filepath) + plt.close() + logging.info('Plot saved at %s', plot_filepath) + + +class RIRMixGenerator(object): + """Creates a dataset of mixed signals at the microphone + by combining target speech, background noise and interference. + + Correspnding signals are are generated and saved + using the `generate` method. + + Input configuration is expexted to have the following structure + ``` + sample_rate: sample rate used for simulation + room: + subset: manifest for RIR data + target: + subset: manifest for target source data + noise: + subset: manifest for noise data + interference: + subset: manifest for interference data + interference_probability: probability that interference is present + max_num_interferers: max number of interferers, randomly selected between 0 and max + mix: + subset: + num: number of examples to generate + rsnr: range of RSNR + rsir: range of RSIR + ref_mic: reference microphone + ref_mic_rms: desired RMS at ref_mic + ``` + """ + + def __init__(self, cfg: DictConfig): + """ + Instantiate a RIRMixGenerator object. + + Args: + cfg: generator configuration defining data for room, + target signal, noise, interference and mixture + """ + logging.info("Initialize RIRMixGenerator") + self._cfg = cfg + self.check_cfg() + + self.subsets = self.cfg.room.keys() + logging.info('Initialized with %d subsets: %s', len(self.subsets), str(self.subsets)) + + # load manifests + self.metadata = dict() + for subset in self.subsets: + subset_data = dict() + + logging.info('Loading data for %s', subset) + for key in ['room', 'target', 'noise', 'interference']: + try: + subset_data[key] = read_manifest(self.cfg[key][subset]) + logging.info('\t%-*s: \t%d files', 15, key, len(subset_data[key])) + except Exception as e: + subset_data[key] = None + logging.info('\t%-*s: \t0 files', 15, key) + logging.warning('\t\tManifest data not loaded. Exception: %s', str(e)) + + self.metadata[subset] = subset_data + + logging.info('Loaded all manifests') + + self.num_retries = self.cfg.get('num_retries', 5) + + @property + def cfg(self): + """Property holding the internal config of the object. + + Note: + Changes to this config are not reflected in the state of the object. + Please create a new model with the updated config. + """ + return self._cfg + + @property + def sample_rate(self): + return self._cfg.sample_rate + + @cfg.setter + def cfg(self, cfg): + """Property holding the internal config of the object. + + Note: + Changes to this config are not reflected in the state of the object. + Please create a new model with the updated config. + """ + self._cfg = cfg + + def check_cfg(self): + """ + Checks provided configuration to ensure it has the minimal required + configuration the values are in a reasonable range. + """ + # sample rate + sample_rate = self.cfg.get('sample_rate') + if sample_rate is None: + raise ValueError('Sample rate not provided.') + elif sample_rate < 0: + raise ValueError(f'Sample rate must be positive: {sample_rate}') + + # room configuration + room_cfg = self.cfg.get('room') + if not room_cfg: + raise ValueError( + 'Room configuration not provided. Expecting RIR manifests in format {subset: path_to_manifest}' + ) + + # target configuration + target_cfg = self.cfg.get('target') + if not target_cfg: + raise ValueError( + 'Target configuration not provided. Expecting audio manifests in format {subset: path_to_manifest}' + ) + + for key in ['azimuth', 'elevation', 'distance']: + value = target_cfg.get(key) + + if value is None or np.isscalar(value): + # no constraint or a fixed dimension is ok + pass + elif len(value) != 2 or not value[0] < value[1]: + # not a valid range + raise ValueError(f'Range must be specified with two positive increasing elements for {key}: {value}') + + # noise configuration + noise_cfg = self.cfg.get('noise') + if not noise_cfg: + raise ValueError( + 'Noise configuration not provided. Expecting audio manifests in format {subset: path_to_manifest}' + ) + + # interference configuration + interference_cfg = self.cfg.get('interference') + if not interference_cfg: + logging.info('Interference configuration not provided.') + else: + interference_probability = interference_cfg.get('interference_probability', 0) + max_num_interferers = interference_cfg.get('max_num_interferers', 0) + min_azimuth_to_target = interference_cfg.get('min_azimuth_to_target', 0) + if interference_probability is not None: + if interference_probability < 0: + raise ValueError( + f'Interference probability must be non-negative. Current value: {interference_probability}' + ) + elif interference_probability > 0: + assert ( + max_num_interferers is not None and max_num_interferers > 0 + ), f'Max number of interferers must be positive. Current value: {max_num_interferers}' + assert ( + min_azimuth_to_target is not None and min_azimuth_to_target >= 0 + ), 'Min azimuth to target must be non-negative' + + # mix configuration + mix_cfg = self.cfg.get('mix') + if not mix_cfg: + raise ValueError('Mix configuration not provided. Expecting configuration for each subset.') + if 'ref_mic' not in mix_cfg: + raise ValueError('Reference microphone not defined.') + if 'ref_mic_rms' not in mix_cfg: + raise ValueError('Reference microphone RMS not defined.') + + def generate_target(self, subset: str) -> dict: + """ + Prepare a dictionary with target configuration. + + The output dictionary contains the following information + ``` + room_index: index of the selected room from the RIR corpus + room_filepath: path to the room simulation file + source: index of the selected source for the target + rt60: reverberation time of the selected room + num_mics: number of microphones + azimuth: azimuth of the target source, relative to the microphone array + elevation: elevation of the target source, relative to the microphone array + distance: distance of the target source, relative to the microphone array + audio_filepath: path to the audio file for the target source + text: text for the target source audio signal, if available + duration: duration of the target source audio signal + ``` + + Args: + subset: string denoting a subset which will be used to selected target + audio and room parameters. + + Returns: + Dictionary with target configuration, including room, source index, and audio information. + """ + + # Utility function + def select_target_source(room_metadata, room_indices): + """Find a room and a source that satisfies the constraints.""" + for room_index in room_indices: + # Select room + room_data = room_metadata[room_index] + + # Candidate sources + sources = self.random.choice(room_data['num_sources'], size=self.num_retries, replace=False) + + # Select target source in this room + for source in sources: + # Check constraints + constraints_met = [] + for constraint in ['azimuth', 'elevation', 'distance']: + if self.cfg.target.get(constraint) is not None: + # Check that the selected source is in the range + source_value = room_data[f'source_{constraint}'][source] + if self.cfg.target[constraint][0] <= source_value <= self.cfg.target[constraint][1]: + constraints_met.append(True) + else: + constraints_met.append(False) + # No need to check the remaining constraints + break + + # Check if a feasible source is found + if all(constraints_met): + # A feasible source has been found + return source, room_index + + return None, None + + # Prepare room & source position + room_metadata = self.metadata[subset]['room'] + room_indices = self.random.choice(len(room_metadata), size=self.num_retries, replace=False) + source, room_index = select_target_source(room_metadata, room_indices) + + if source is None: + raise RuntimeError(f'Could not find a feasible source given target constraints {self.cfg.target}') + + room_data = room_metadata[room_index] + + # Optional: select subset of channels + num_available_mics = len(room_data['mic_positions']) + if 'mic_array' in self.cfg: + num_mics = self.cfg.mic_array['num_mics'] + mic_selection = self.cfg.mic_array['selection'] + + if mic_selection == 'random': + logging.debug('Randomly selecting %d mics', num_mics) + selected_mics = self.random.choice(num_available_mics, size=num_mics, replace=False) + elif isinstance(mic_selection, Iterable): + logging.debug('Using explicitly selected mics: %s', str(mic_selection)) + assert ( + 0 <= min(mic_selection) < num_available_mics + ), f'Expecting mic_selection in range [0,{num_available_mics}), current value: {mic_selection}' + selected_mics = np.array(mic_selection) + else: + raise ValueError(f'Unexpected value for mic_selection: {mic_selection}') + else: + logging.debug('Using all %d available mics', num_available_mics) + num_mics = num_available_mics + selected_mics = np.arange(num_mics) + + # Double-check the number of mics is as expected + assert ( + len(selected_mics) == num_mics + ), f'Expecting {num_mics} mics, but received {len(selected_mics)} mics: {selected_mics}' + logging.debug('Selected mics: %s', str(selected_mics)) + + # Calculate distance from the source to each microphone + mic_positions = np.array(room_data['mic_positions'])[selected_mics] + source_position = np.array(room_data['source_position'][source]) + distance_source_to_mic = np.linalg.norm(mic_positions - source_position, axis=1) + + # Handle relative paths + room_filepath = room_data['room_filepath'] + if not os.path.isabs(room_filepath): + manifest_dir = os.path.dirname(self.cfg.room[subset]) + room_filepath = os.path.join(manifest_dir, room_filepath) + + target_cfg = { + 'room_index': int(room_index), + 'room_filepath': room_filepath, + 'source': source, + 'rt60': room_data['rir_rt60_measured'][source], + 'selected_mics': selected_mics.tolist(), + # Positions + 'source_position': source_position.tolist(), + 'mic_positions': mic_positions.tolist(), + # Relative to center of the array + 'azimuth': room_data['source_azimuth'][source], + 'elevation': room_data['source_elevation'][source], + 'distance': room_data['source_distance'][source], + # Relative to mics + 'distance_source_to_mic': distance_source_to_mic, + } + + return target_cfg + + def generate_interference(self, subset: str, target_cfg: dict) -> List[dict]: + """ + Prepare a list of dictionaries with interference configuration. + + Args: + subset: string denoting a subset which will be used to select interference audio. + target_cfg: dictionary with target configuration. This is used to determine + the minimal required duration for the noise signal. + + Returns: + List of dictionary with interference configuration, including source index and audio information + for one or more interference sources. + """ + if self.metadata[subset]['interference'] is None: + # No interference to be configured + return None + + # Configure interfering sources + max_num_sources = self.cfg.interference.get('max_num_interferers', 0) + interference_probability = self.cfg.interference.get('interference_probability', 0) + + if ( + max_num_sources >= 1 + and interference_probability > 0 + and self.random.uniform(low=0.0, high=1.0) < interference_probability + ): + # interference present + num_interferers = self.random.integers(low=1, high=max_num_sources + 1) + else: + # interference not present + return None + + # Room setup: same room as target + room_index = target_cfg['room_index'] + room_data = self.metadata[subset]['room'][room_index] + feasible_sources = list(range(room_data['num_sources'])) + # target source is not eligible + feasible_sources.remove(target_cfg['source']) + + # Constraints for interfering sources + min_azimuth_to_target = self.cfg.interference.get('min_azimuth_to_target', 0) + + # Prepare interference configuration + interference_cfg = [] + for n in range(num_interferers): + + # Select a source + source = None + while len(feasible_sources) > 0 and source is None: + + # Select a potential source for the target + source = self.random.choice(feasible_sources) + feasible_sources.remove(source) + + # Check azimuth separation + if min_azimuth_to_target > 0: + source_azimuth = room_data['source_azimuth'][source] + azimuth_diff = wrap_to_180(source_azimuth - target_cfg['azimuth']) + if abs(azimuth_diff) < min_azimuth_to_target: + # Try again + source = None + continue + + if source is None: + logging.warning('Could not select a feasible interference source %d of %s', n, num_interferers) + + # Return what we have for now or None + return interference_cfg if interference_cfg else None + + # Current source setup + interfering_source = { + 'source': source, + 'selected_mics': target_cfg['selected_mics'], + 'position': room_data['source_position'][source], + 'azimuth': room_data['source_azimuth'][source], + 'elevation': room_data['source_elevation'][source], + 'distance': room_data['source_distance'][source], + } + + # Done with interference for this source + interference_cfg.append(interfering_source) + + return interference_cfg + + def generate_mix(self, subset: str, target_cfg: dict) -> dict: + """Generate scaling parameters for mixing + the target speech at the microphone, background noise + and interference signal at the microphone. + + The output dictionary contains the following information + ``` + rsnr: reverberant signal-to-noise ratio + rsir: reverberant signal-to-interference ratio + ref_mic: reference microphone for calculating the metrics + ref_mic_rms: RMS of the signal at the reference microphone + ``` + + Args: + subset: string denoting the subset of configuration + target_cfg: dictionary with target configuration + + Returns: + Dictionary containing configured RSNR, RSIR, ref_mic + and RMS on ref_mic. + """ + mix_cfg = dict() + + for key in ['rsnr', 'rsir', 'ref_mic', 'ref_mic_rms', 'min_duration']: + if key in self.cfg.mix[subset]: + # Take the value from subset config + value = self.cfg.mix[subset].get(key) + else: + # Take the global value + value = self.cfg.mix.get(key) + + if value is None: + mix_cfg[key] = None + elif np.isscalar(value): + mix_cfg[key] = value + elif len(value) == 2: + # Select from the given range, including the upper bound + mix_cfg[key] = self.random.integers(low=value[0], high=value[1] + 1) + else: + # Select one of the multiple values + mix_cfg[key] = self.random.choice(value) + + if mix_cfg['ref_mic'] == 'closest': + # Select the closest mic as the reference + mix_cfg['ref_mic'] = np.argmin(target_cfg['distance_source_to_mic']) + + # Configuration for saving individual components + mix_cfg['save'] = OmegaConf.to_object(self.cfg.mix['save']) if 'save' in self.cfg.mix else {} + + return mix_cfg + + def generate(self): + """Generate a corpus of microphone signals by mixing target, background noise + and interference signals. + + This method will prepare randomized examples based on the current configuration, + run simulations and save results to output_dir. + """ + logging.info('Generate mixed signals') + + # Initialize + self.random = default_rng(seed=self.cfg.random_seed) + + # Prepare output dir + output_dir = self.cfg.output_dir + if output_dir.endswith('.yaml'): + output_dir = output_dir[:-5] + + # Create absolute path + logging.info('Output dir set to: %s', output_dir) + + # Generate all cases + for subset in self.subsets: + + output_dir_subset = os.path.join(output_dir, subset) + examples = [] + + if not os.path.exists(output_dir_subset): + logging.info('Creating output directory: %s', output_dir_subset) + os.makedirs(output_dir_subset) + elif os.path.isdir(output_dir_subset) and len(os.listdir(output_dir_subset)) > 0: + raise RuntimeError(f'Output directory {output_dir_subset} is not empty.') + + num_examples = self.cfg.mix[subset].num + logging.info('Preparing %d examples for subset %s', num_examples, subset) + + # Generate examples + for n_example in tqdm(range(num_examples), total=num_examples, desc=f'Preparing {subset}'): + # prepare configuration + target_cfg = self.generate_target(subset) + interference_cfg = self.generate_interference(subset, target_cfg) + mix_cfg = self.generate_mix(subset, target_cfg) + + # base file name + base_output_filepath = os.path.join(output_dir_subset, f'{subset}_example_{n_example:09d}') + + # prepare example + example = { + 'sample_rate': self.sample_rate, + 'target_cfg': target_cfg, + 'interference_cfg': interference_cfg, + 'mix_cfg': mix_cfg, + 'base_output_filepath': base_output_filepath, + } + + examples.append(example) + + # Audio data + audio_metadata = { + 'target': self.metadata[subset]['target'], + 'target_dir': os.path.dirname(self.cfg.target[subset]), # manifest_dir + 'noise': self.metadata[subset]['noise'], + 'noise_dir': os.path.dirname(self.cfg.noise[subset]), # manifest_dir + } + + if interference_cfg is not None: + audio_metadata.update( + { + 'interference': self.metadata[subset]['interference'], + 'interference_dir': os.path.dirname(self.cfg.interference[subset]), # manifest_dir + } + ) + + # Simulation + if (num_workers := self.cfg.get('num_workers')) is None: + num_workers = os.cpu_count() - 1 + + if num_workers is not None and num_workers > 1: + logging.info(f'Simulate using {num_workers} workers') + examples_and_audio_metadata = zip(examples, itertools.repeat(audio_metadata, len(examples))) + with multiprocessing.Pool(processes=num_workers) as pool: + metadata = list( + tqdm( + pool.imap(simulate_room_mix_helper, examples_and_audio_metadata), + total=len(examples), + desc=f'Simulating {subset}', + ) + ) + else: + logging.info('Simulate using a single worker') + metadata = [] + for example in tqdm(examples, total=len(examples), desc=f'Simulating {subset}'): + metadata.append(simulate_room_mix(**example, audio_metadata=audio_metadata)) + + # Save manifest + manifest_filepath = os.path.join(output_dir, f'{os.path.basename(output_dir)}_{subset}.json') + + if os.path.exists(manifest_filepath) and os.path.isfile(manifest_filepath): + raise RuntimeError(f'Manifest config file exists: {manifest_filepath}') + + # Make all paths in the manifest relative to the output dir + for data in tqdm(metadata, total=len(metadata), desc=f'Making filepaths relative {subset}'): + for key, val in data.items(): + if key.endswith('_filepath') and val is not None: + data[key] = os.path.relpath(val, start=output_dir) + + write_manifest(manifest_filepath, metadata) + + # Generate plots with information about generated data + plot_filepath = os.path.join(output_dir, f'{os.path.basename(output_dir)}_{subset}_info.png') + + if os.path.exists(plot_filepath) and os.path.isfile(plot_filepath): + raise RuntimeError(f'Plot file exists: {plot_filepath}') + + plot_mix_manifest_info(manifest_filepath, plot_filepath=plot_filepath) + + # Save used configuration for reference + config_filepath = os.path.join(output_dir, 'config.yaml') + if os.path.exists(config_filepath) and os.path.isfile(config_filepath): + raise RuntimeError(f'Output config file exists: {config_filepath}') + + OmegaConf.save(self.cfg, config_filepath, resolve=True) + + +def convolve_rir(signal: np.ndarray, rir: np.ndarray) -> np.ndarray: + """Convolve signal with a possibly multichannel IR in rir, i.e., + calculate the following for each channel m: + + signal_m = rir_m \ast signal + + Args: + signal: single-channel signal (samples,) + rir: single- or multi-channel IR, (samples,) or (samples, channels) + + Returns: + out: same length as signal, same number of channels as rir, shape (samples, channels) + """ + num_samples = len(signal) + if rir.ndim == 1: + # convolve and trim to length + out = convolve(signal, rir)[:num_samples] + elif rir.ndim == 2: + num_channels = rir.shape[1] + out = np.zeros((num_samples, num_channels)) + for m in range(num_channels): + out[:, m] = convolve(signal, rir[:, m])[:num_samples] + + else: + raise RuntimeError(f'RIR with {rir.ndim} not supported') + + return out + + +def calculate_drr(rir: np.ndarray, sample_rate: float, n_direct: List[int], n_0_ms=2.5) -> List[float]: + """Calculate direct-to-reverberant ratio (DRR) from the measured RIR. + + Calculation is done as in eq. (3) from [1]. + + Args: + rir: room impulse response, shape (num_samples, num_channels) + sample_rate: sample rate for the impulse response + n_direct: direct path delay + n_0_ms: window around n_direct for calculating the direct path energy + + Returns: + Calculated DRR for each channel of the input RIR. + + References: + [1] Eaton et al, The ACE challenge: Corpus description and performance evaluation, WASPAA 2015 + """ + # Define a window around the direct path delay + n_0 = int(n_0_ms * sample_rate / 1000) + + len_rir, num_channels = rir.shape + drr = [None] * num_channels + for m in range(num_channels): + + # Window around the direct path + dir_start = max(n_direct[m] - n_0, 0) + dir_end = n_direct[m] + n_0 + + # Power of the direct component + pow_dir = np.sum(np.abs(rir[dir_start:dir_end, m]) ** 2) / len_rir + + # Power of the reverberant component + pow_reverberant = (np.sum(np.abs(rir[0:dir_start, m]) ** 2) + np.sum(np.abs(rir[dir_end:, m]) ** 2)) / len_rir + + # DRR in dB + drr[m] = pow2db(pow_dir / pow_reverberant) + + return drr + + +def normalize_max(x: np.ndarray, max_db: float = 0, eps: float = 1e-16) -> np.ndarray: + """Normalize max input value to max_db full scale (±1). + + Args: + x: input signal + max_db: desired max magnitude compared to full scale + eps: small regularization constant + + Returns: + Normalized signal with max absolute value max_db. + """ + max_val = db2mag(max_db) + return max_val * x / (np.max(np.abs(x)) + eps) + + +def simultaneously_active_rms( + x: np.ndarray, + y: np.ndarray, + sample_rate: float, + rms_threshold_db: float = -60, + window_len_ms: float = 200, + min_active_duration: float = 0.5, +) -> Tuple[float, float]: + """Calculate RMS over segments where both input signals are active. + + Args: + x: first input signal + y: second input signal + sample_rate: sample rate for input signals in Hz + rms_threshold_db: threshold for determining activity of the signal, relative + to max absolute value + window_len_ms: window length in milliseconds, used for calculating segmental RMS + min_active_duration: minimal duration of the active segments + + Returns: + RMS value over active segments for x and y. + """ + if len(x) != len(y): + raise RuntimeError(f'Expecting signals of same length: len(x)={len(x)}, len(y)={len(y)}') + window_len = int(window_len_ms * sample_rate / 1000) + rms_threshold = db2mag(rms_threshold_db) # linear scale + + x_normalized = normalize_max(x) + y_normalized = normalize_max(y) + + x_active_power = y_active_power = active_len = 0 + for start in range(0, len(x) - window_len, window_len): + window = slice(start, start + window_len) + + # check activity on the scaled signal + x_window_rms = rms(x_normalized[window]) + y_window_rms = rms(y_normalized[window]) + + if x_window_rms > rms_threshold and y_window_rms > rms_threshold: + # sum the power of the original non-scaled signal + x_active_power += np.sum(np.abs(x[window]) ** 2) + y_active_power += np.sum(np.abs(y[window]) ** 2) + active_len += window_len + + if active_len < int(min_active_duration * sample_rate): + raise RuntimeError( + f'Signals are simultaneously active less than {min_active_duration} s: only {active_len/sample_rate} s' + ) + + # normalize + x_active_power /= active_len + y_active_power /= active_len + + return np.sqrt(x_active_power), np.sqrt(y_active_power) + + +def scaled_disturbance( + signal: np.ndarray, + disturbance: np.ndarray, + sdr: float, + sample_rate: float = None, + ref_channel: int = 0, + eps: float = 1e-16, +) -> np.ndarray: + """ + Args: + signal: numpy array, shape (num_samples, num_channels) + disturbance: numpy array, same shape as signal + sdr: desired signal-to-disturbance ration + sample_rate: sample rate of the input signals + ref_channel: ref mic used to calculate RMS + eps: regularization constant + + Returns: + Scaled disturbance, so that signal-to-disturbance ratio at ref_channel + is approximately equal to input SDR during simultaneously active + segment of signal and disturbance. + """ + if signal.shape != disturbance.shape: + raise ValueError(f'Signal and disturbance shapes do not match: {signal.shape} != {disturbance.shape}') + + # set scaling based on RMS at ref_mic + signal_rms, disturbance_rms = simultaneously_active_rms( + signal[:, ref_channel], disturbance[:, ref_channel], sample_rate=sample_rate + ) + disturbance_gain = db2mag(-sdr) * signal_rms / (disturbance_rms + eps) + # scale disturbance + scaled_disturbance = disturbance_gain * disturbance + return scaled_disturbance + + +def prepare_source_signal( + signal_type: str, + sample_rate: int, + audio_data: List[dict], + audio_dir: Optional[str] = None, + min_duration: Optional[int] = None, + ref_signal: Optional[np.ndarray] = None, + mic_positions: Optional[np.ndarray] = None, + num_retries: int = 10, +) -> tuple: + """Prepare an audio signal for a source. + + Args: + signal_type: 'point' or 'diffuse' + sample_rate: Sampling rate for the signal + audio_data: List of audio items, each is a dictionary with audio_filepath, duration, offset and optionally text + audio_dir: Base directory for resolving paths, e.g., manifest basedir + min_duration: Minimal duration to be loaded if ref_signal is not provided, in seconds + ref_signal: Optional, used to determine the length of the signal + mic_positions: Optional, used to prepare approximately diffuse signal + num_retries: Number of retries when selecting the source files + + Returns: + (audio_signal, metadata), where audio_signal is an ndarray and metadata is a dictionary + with audio filepaths, durations and offsets + """ + if signal_type not in ['point', 'diffuse']: + raise ValueError(f'Unexpected signal type {signal_type}.') + + if audio_data is None: + # No data to load + return None + + metadata = {} + + if ref_signal is None: + audio_signal = None + # load at least one sample if min_duration is not provided + samples_to_load = int(min_duration * sample_rate) if min_duration is not None else 1 + source_signals_metadata = {'audio_filepath': [], 'duration': [], 'offset': [], 'text': []} + + while samples_to_load > 0: + # Select a random item and load the audio + item = random.choice(audio_data) + + audio_filepath = item['audio_filepath'] + if not os.path.isabs(audio_filepath) and audio_dir is not None: + audio_filepath = os.path.join(audio_dir, audio_filepath) + + # Load audio + check_min_sample_rate(audio_filepath, sample_rate) + audio_segment = AudioSegment.from_file( + audio_file=audio_filepath, + target_sr=sample_rate, + duration=item['duration'], + offset=item.get('offset', 0), + ) + + if signal_type == 'point': + if audio_segment.num_channels > 1: + raise RuntimeError( + f'Expecting single-channel source signal, but received {audio_segment.num_channels}. File: {audio_filepath}' + ) + else: + raise ValueError(f'Unexpected signal type {signal_type}.') + + source_signals_metadata['audio_filepath'].append(audio_filepath) + source_signals_metadata['duration'].append(item['duration']) + source_signals_metadata['duration'].append(item.get('offset', 0)) + source_signals_metadata['text'].append(item.get('text')) + + # not perfect, since different files may have different distributions + segment_samples = normalize_max(audio_segment.samples) + # concatenate + audio_signal = ( + np.concatenate((audio_signal, segment_samples)) if audio_signal is not None else segment_samples + ) + # remaining samples + samples_to_load -= len(segment_samples) + + # Finally, we need only the metadata for the complete signal + metadata = { + 'duration': sum(source_signals_metadata['duration']), + 'offset': 0, + } + + # Add text only if all source signals have text + if all([isinstance(tt, str) for tt in source_signals_metadata['text']]): + metadata['text'] = ' '.join(source_signals_metadata['text']) + else: + # Load a signal with total_len samples and ensure it has enough simultaneous activity/overlap with ref_signal + # Concatenate multiple files if necessary + total_len = len(ref_signal) + + for n in range(num_retries): + + audio_signal = None + source_signals_metadata = {'audio_filepath': [], 'duration': [], 'offset': []} + + if signal_type == 'point': + samples_to_load = total_len + elif signal_type == 'diffuse': + # Load longer signal so it can be reshaped into (samples, mics) and + # used to generate approximately diffuse noise field + num_mics = len(mic_positions) + samples_to_load = num_mics * total_len + + while samples_to_load > 0: + # Select an audio file + item = random.choice(audio_data) + + audio_filepath = item['audio_filepath'] + if not os.path.isabs(audio_filepath) and audio_dir is not None: + audio_filepath = os.path.join(audio_dir, audio_filepath) + + # Load audio signal + check_min_sample_rate(audio_filepath, sample_rate) + + if (max_offset := item['duration'] - np.ceil(samples_to_load / sample_rate)) > 0: + # Load with a random offset if the example is longer than samples_to_load + offset = random.uniform(0, max_offset) + duration = -1 + else: + # Load the whole file + offset, duration = 0, item['duration'] + audio_segment = AudioSegment.from_file( + audio_file=audio_filepath, target_sr=sample_rate, duration=duration, offset=offset + ) + + # Prepare a single-channel signal + if audio_segment.num_channels == 1: + # Take all samples + segment_samples = audio_segment.samples + else: + # Take a random channel + selected_channel = random.choice(range(audio_segment.num_channels)) + segment_samples = audio_segment.samples[:, selected_channel] + + source_signals_metadata['audio_filepath'].append(audio_filepath) + source_signals_metadata['duration'].append(len(segment_samples) / sample_rate) + source_signals_metadata['offset'].append(offset) + + # not perfect, since different files may have different distributions + segment_samples = normalize_max(segment_samples) + # concatenate + audio_signal = ( + np.concatenate((audio_signal, segment_samples)) if audio_signal is not None else segment_samples + ) + # remaining samples + samples_to_load -= len(segment_samples) + + if signal_type == 'diffuse' and num_mics > 1: + try: + # Trim and reshape to num_mics to prepare num_mics source signals + audio_signal = audio_signal[: num_mics * total_len].reshape(num_mics, -1).T + + # Make spherically diffuse noise + audio_signal = generate_approximate_noise_field( + mic_positions=np.array(mic_positions), noise_signal=audio_signal, sample_rate=sample_rate + ) + except Exception as e: + logging.info('Failed to generate approximate noise field: %s', str(e)) + logging.info('Try again.') + # Try again + audio_signal, source_signals_metadata = None, {} + continue + + # Trim to length + audio_signal = audio_signal[:total_len, ...] + + # Include the channel dimension if the reference includes it + if ref_signal.ndim == 2 and audio_signal.ndim == 1: + audio_signal = audio_signal[:, None] + + try: + # Signal and ref_signal should be simultaneously active + simultaneously_active_rms(ref_signal, audio_signal, sample_rate=sample_rate) + # We have enough overlap + break + except Exception as e: + # Signal and ref_signal are not overlapping, try again + logging.info('Exception: %s', str(e)) + logging.info('Signals are not overlapping, try again.') + audio_signal, source_signals_metadata = None, {} + continue + + if audio_signal is None: + logging.warning('Audio signal not set: %s.', signal_type) + + metadata['source_signals'] = source_signals_metadata + + return audio_signal, metadata + + +def check_min_sample_rate(filepath: str, sample_rate: float): + """Make sure the file's sample rate is at least sample_rate. + This will make sure that we have only downsampling if loading + this file, while upsampling is not permitted. + + Args: + filepath: path to a file + sample_rate: desired sample rate + """ + file_sample_rate = librosa.get_samplerate(path=filepath) + if file_sample_rate < sample_rate: + raise RuntimeError( + f'Sample rate ({file_sample_rate}) is lower than the desired sample rate ({sample_rate}). File: {filepath}.' + ) + + +def simulate_room_mix( + sample_rate: int, + target_cfg: dict, + interference_cfg: dict, + mix_cfg: dict, + audio_metadata: dict, + base_output_filepath: str, + max_amplitude: float = 0.999, + eps: float = 1e-16, +) -> dict: + """Simulate mixture signal at the microphone, including target, noise and + interference signals and mixed at specific RSNR and RSIR. + + Args: + sample_rate: Sample rate for all signals + target_cfg: Dictionary with configuration of the target. Includes + room_filepath, source index, audio_filepath, duration + noise_cfg: List of dictionaries, where each item includes audio_filepath, + offset and duration. + interference_cfg: List of dictionaries, where each item contains source + index + mix_cfg: Dictionary with the mixture configuration. Includes RSNR, RSIR, + ref_mic and ref_mic_rms. + audio_metadata: Dictionary with a list of files for target, noise and interference + base_output_filepath: All output audio files will be saved with this prefix by + adding a diffierent suffix for each component, e.g., _mic.wav. + max_amplitude: Maximum amplitude of the mic signal, used to prevent clipping. + eps: Small regularization constant. + + Returns: + Dictionary with metadata based on the mixture setup and + simulation results. This corresponds to a line of the + output manifest file. + """ + + # Local utilities + def load_rir( + room_filepath: str, source: int, selected_mics: list, sample_rate: float, rir_key: str = 'rir' + ) -> np.ndarray: + """Load a RIR and check that the sample rate is matching the desired sample rate + + Args: + room_filepath: Path to a room simulation in an h5 file + source: Index of the desired source + sample_rate: Sample rate of the simulation + rir_key: Key of the RIR to load from the simulation. + + Returns: + Numpy array with shape (num_samples, num_channels) + """ + rir, rir_sample_rate = load_rir_simulation(room_filepath, source=source, rir_key=rir_key) + if rir_sample_rate != sample_rate: + raise RuntimeError( + f'RIR sample rate ({sample_rate}) is not matching the expected sample rate ({sample_rate}). File: {room_filepath}' + ) + return rir[:, selected_mics] + + def get_early_rir( + rir: np.ndarray, rir_anechoic: np.ndarray, sample_rate: int, early_duration: float = 0.050 + ) -> np.ndarray: + """Return only the early part of the RIR.""" + early_len = int(early_duration * sample_rate) + direct_path_delay = np.min(np.argmax(rir_anechoic, axis=0)) + rir_early = rir.copy() + rir_early[direct_path_delay + early_len :, :] = 0 + return rir_early + + def save_audio( + base_path: str, + tag: str, + audio_signal: Optional[np.ndarray], + sample_rate: int, + save: str = 'all', + ref_mic: Optional[int] = None, + format: str = 'wav', + subtype: str = 'float', + ): + """Save audio signal and return filepath.""" + if (audio_signal is None) or (not save): + return None + + if save == 'ref_mic': + # save only ref_mic + audio_signal = audio_signal[:, ref_mic] + + audio_filepath = base_path + f'_{tag}.{format}' + sf.write(audio_filepath, audio_signal, sample_rate, subtype) + + return audio_filepath + + # Target RIRs + target_rir = load_rir( + target_cfg['room_filepath'], + source=target_cfg['source'], + selected_mics=target_cfg['selected_mics'], + sample_rate=sample_rate, + ) + target_rir_anechoic = load_rir( + target_cfg['room_filepath'], + source=target_cfg['source'], + sample_rate=sample_rate, + selected_mics=target_cfg['selected_mics'], + rir_key='anechoic', + ) + target_rir_early = get_early_rir(rir=target_rir, rir_anechoic=target_rir_anechoic, sample_rate=sample_rate) + + # Target signals + target_signal, target_metadata = prepare_source_signal( + signal_type='point', + sample_rate=sample_rate, + audio_data=audio_metadata['target'], + audio_dir=audio_metadata['target_dir'], + min_duration=mix_cfg['min_duration'], + ) + source_signals_metadata = {'target': target_metadata['source_signals']} + + # Convolve target + target_reverberant = convolve_rir(target_signal, target_rir) + target_anechoic = convolve_rir(target_signal, target_rir_anechoic) + target_early = convolve_rir(target_signal, target_rir_early) + + # Prepare noise signal + noise, noise_metadata = prepare_source_signal( + signal_type='diffuse', + sample_rate=sample_rate, + mic_positions=target_cfg['mic_positions'], + audio_data=audio_metadata['noise'], + audio_dir=audio_metadata['noise_dir'], + ref_signal=target_reverberant, + ) + source_signals_metadata['noise'] = noise_metadata['source_signals'] + + # Prepare interference signal + if interference_cfg is None: + interference = None + else: + # Load interference signals + interference = 0 + source_signals_metadata['interference'] = [] + for i_cfg in interference_cfg: + # Load single-channel signal for directional interference + i_signal, i_metadata = prepare_source_signal( + signal_type='point', + sample_rate=sample_rate, + audio_data=audio_metadata['interference'], + audio_dir=audio_metadata['interference_dir'], + ref_signal=target_signal, + ) + source_signals_metadata['interference'].append(i_metadata['source_signals']) + # Load RIR from the same room as the target, but a difference source + i_rir = load_rir( + target_cfg['room_filepath'], + source=i_cfg['source'], + selected_mics=i_cfg['selected_mics'], + sample_rate=sample_rate, + ) + # Convolve interference + i_reverberant = convolve_rir(i_signal, i_rir) + # Sum + interference += i_reverberant + + # Scale and add components of the signal + mic = target_reverberant.copy() + + if noise is not None: + noise = scaled_disturbance( + signal=target_reverberant, + disturbance=noise, + sdr=mix_cfg['rsnr'], + sample_rate=sample_rate, + ref_channel=mix_cfg['ref_mic'], + ) + # Update mic signal + mic += noise + + if interference is not None: + interference = scaled_disturbance( + signal=target_reverberant, + disturbance=interference, + sdr=mix_cfg['rsir'], + sample_rate=sample_rate, + ref_channel=mix_cfg['ref_mic'], + ) + # Update mic signal + mic += interference + + # Set the final mic signal level + mic_rms = rms(mic[:, mix_cfg['ref_mic']]) + global_gain = db2mag(mix_cfg['ref_mic_rms']) / (mic_rms + eps) + mic_max = np.max(np.abs(mic)) + if (clipped_max := mic_max * global_gain) > max_amplitude: + # Downscale the global gain to prevent clipping + adjust ref_mic_rms accordingly + clipping_prevention_gain = max_amplitude / clipped_max + global_gain *= clipping_prevention_gain + mix_cfg['ref_mic_rms'] += mag2db(clipping_prevention_gain) + + logging.debug( + 'Clipping prevented for example %s (protection gain: %.2f dB)', + base_output_filepath, + mag2db(clipping_prevention_gain), + ) + + # save signals + signals = { + 'mic': mic, + 'target_reverberant': target_reverberant, + 'target_anechoic': target_anechoic, + 'target_early': target_early, + 'noise': noise, + 'interference': interference, + } + + metadata = {} + + for tag, signal in signals.items(): + + if signal is not None: + # scale all signal components with the global gain + signal = global_gain * signal + + audio_filepath = save_audio( + base_path=base_output_filepath, + tag=tag, + audio_signal=signal, + sample_rate=sample_rate, + save=mix_cfg['save'].get(tag, 'all'), + ref_mic=mix_cfg['ref_mic'], + format=mix_cfg['save'].get('format', 'wav'), + subtype=mix_cfg['save'].get('subtype', 'float'), + ) + + if tag == 'mic': + metadata['audio_filepath'] = audio_filepath + else: + metadata[tag + '_filepath'] = audio_filepath + + # Add metadata + metadata.update( + { + 'text': target_metadata.get('text'), + 'duration': target_metadata['duration'], + 'target_cfg': target_cfg, + 'interference_cfg': interference_cfg, + 'mix_cfg': mix_cfg, + 'ref_channel': mix_cfg.get('ref_mic'), + 'rt60': target_cfg.get('rt60'), + 'drr': calculate_drr(target_rir, sample_rate, n_direct=np.argmax(target_rir_anechoic, axis=0)), + 'rsnr': None if noise is None else mix_cfg['rsnr'], + 'rsir': None if interference is None else mix_cfg['rsir'], + 'source_signals': source_signals_metadata, + } + ) + + return convert_numpy_to_serializable(metadata) + + +def simulate_room_mix_helper(example_and_audio_metadata: tuple) -> dict: + """Wrapper around `simulate_room_mix` for pool.imap. + + Args: + args: example and audio_metadata that are forwarded to `simulate_room_mix` + + Returns: + Dictionary with metadata, see `simulate_room_mix` + """ + example, audio_metadata = example_and_audio_metadata + return simulate_room_mix(**example, audio_metadata=audio_metadata) + + +def plot_mix_manifest_info(filepath: str, plot_filepath: str = None): + """Plot distribution of parameters from the manifest file. + + Args: + filepath: path to a RIR corpus manifest file + plot_filepath: path to save the plot at + """ + metadata = read_manifest(filepath) + + # target info + target_distance = [] + target_azimuth = [] + target_elevation = [] + target_duration = [] + + # room config + rt60 = [] + drr = [] + + # noise + rsnr = [] + rsir = [] + + # get the required data + for data in metadata: + # target info + target_distance.append(data['target_cfg']['distance']) + target_azimuth.append(data['target_cfg']['azimuth']) + target_elevation.append(data['target_cfg']['elevation']) + target_duration.append(data['duration']) + + # room config + rt60.append(data['rt60']) + drr += data['drr'] # average DRR across all mics + + # noise + if data['rsnr'] is not None: + rsnr.append(data['rsnr']) + + if data['rsir'] is not None: + rsir.append(data['rsir']) + + # plot + plt.figure(figsize=(12, 6)) + + plt.subplot(2, 4, 1) + plt.hist(target_distance, label='distance') + plt.xlabel('distance / m') + plt.ylabel('# examples') + plt.title('Target-to-array distance') + + plt.subplot(2, 4, 2) + plt.hist(target_azimuth, label='azimuth') + plt.xlabel('azimuth / deg') + plt.ylabel('# examples') + plt.title('Target-to-array azimuth') + + plt.subplot(2, 4, 3) + plt.hist(target_elevation, label='elevation') + plt.xlabel('elevation / deg') + plt.ylabel('# examples') + plt.title('Target-to-array elevation') + + plt.subplot(2, 4, 4) + plt.hist(target_duration, label='duration') + plt.xlabel('time / s') + plt.ylabel('# examples') + plt.title('Target duration') + + plt.subplot(2, 4, 5) + plt.hist(rt60, label='RT60') + plt.xlabel('RT60 / s') + plt.ylabel('# examples') + plt.title('RT60') + + plt.subplot(2, 4, 6) + plt.hist(drr, label='DRR') + plt.xlabel('DRR / dB') + plt.ylabel('# examples') + plt.title('DRR [avg over mics]') + + if len(rsnr) > 0: + plt.subplot(2, 4, 7) + plt.hist(rsnr, label='RSNR') + plt.xlabel('RSNR / dB') + plt.ylabel('# examples') + plt.title(f'RSNR [{100 * len(rsnr) / len(rt60):.0f}% ex]') + + if len(rsir): + plt.subplot(2, 4, 8) + plt.hist(rsir, label='RSIR') + plt.xlabel('RSIR / dB') + plt.ylabel('# examples') + plt.title(f'RSIR [{100 * len(rsir) / len(rt60):.0f}% ex]') + + for n in range(8): + plt.subplot(2, 4, n + 1) + plt.grid() + plt.legend(loc='lower left') + + plt.tight_layout() + + if plot_filepath is not None: + plt.savefig(plot_filepath) + plt.close() + logging.info('Plot saved at %s', plot_filepath) diff --git a/nemo/collections/audio/losses/__init__.py b/nemo/collections/audio/losses/__init__.py new file mode 100644 index 0000000000000..b2968b7b1ad0a --- /dev/null +++ b/nemo/collections/audio/losses/__init__.py @@ -0,0 +1,15 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from nemo.collections.audio.losses.audio import MSELoss, SDRLoss diff --git a/nemo/collections/asr/losses/audio_losses.py b/nemo/collections/audio/losses/audio.py similarity index 95% rename from nemo/collections/asr/losses/audio_losses.py rename to nemo/collections/audio/losses/audio.py index b0214375a7136..635b02c5d1fe4 100644 --- a/nemo/collections/asr/losses/audio_losses.py +++ b/nemo/collections/audio/losses/audio.py @@ -19,7 +19,7 @@ import torch from nemo.collections.asr.parts.preprocessing.features import make_seq_mask_like -from nemo.collections.asr.parts.utils.audio_utils import toeplitz +from nemo.collections.audio.parts.utils.audio import toeplitz from nemo.core.classes import Loss, Typing, typecheck from nemo.core.neural_types import AudioSignal, LengthsType, LossType, MaskType, NeuralType, VoidType from nemo.utils import logging @@ -253,7 +253,7 @@ def calculate_sdr_batch( SDR in dB for each channel, shape (B, C) """ if scale_invariant and convolution_invariant: - raise ValueError(f'Arguments scale_invariant and convolution_invariant cannot be used simultaneously.') + raise ValueError('Arguments scale_invariant and convolution_invariant cannot be used simultaneously.') assert ( estimate.shape == target.shape @@ -277,7 +277,11 @@ def calculate_sdr_batch( target = scale_invariant_target(estimate=estimate, target=target, mask=mask, eps=eps) elif convolution_invariant: target = convolution_invariant_target( - estimate=estimate, target=target, mask=mask, filter_length=convolution_filter_length, eps=eps, + estimate=estimate, + target=target, + mask=mask, + filter_length=convolution_filter_length, + eps=eps, ) distortion = estimate - target @@ -327,9 +331,9 @@ def __init__( elif not np.isclose(sum(weight), 1, atol=1e-6): raise ValueError(f'Weight should add to one, current weight: {weight}') weight = torch.tensor(weight).reshape(1, -1) - logging.info(f'Channel weight set to %s', weight) + logging.info('Channel weight set to %s', weight) self.register_buffer('weight', weight) - self.weight: Optional[Tensor] + self.weight: Optional[torch.Tensor] # Batch reduction self.reduction = reduction @@ -352,8 +356,7 @@ def __init__( @property def input_types(self): - """Input types definitions for SDRLoss. - """ + """Input types definitions for SDRLoss.""" signal_shape = ('B', 'C', 'T') return { "estimate": NeuralType(signal_shape, AudioSignal()), @@ -481,7 +484,10 @@ class MSELoss(Loss, Typing): """ def __init__( - self, weight: Optional[List[float]] = None, reduction: str = 'mean', ndim: int = 3, + self, + weight: Optional[List[float]] = None, + reduction: str = 'mean', + ndim: int = 3, ): super().__init__() @@ -492,9 +498,9 @@ def __init__( elif not np.isclose(sum(weight), 1, atol=1e-6): raise ValueError(f'Weight should add to one, current weight: {weight}') weight = torch.tensor(weight).reshape(1, -1) - logging.info(f'Channel weight set to %s', weight) + logging.info('Channel weight set to %s', weight) self.register_buffer('weight', weight) - self.weight: Optional[Tensor] + self.weight: Optional[torch.Tensor] # Batch reduction self.reduction = reduction @@ -523,8 +529,7 @@ def __init__( @property def input_types(self): - """Input types definitions for SDRLoss. - """ + """Input types definitions for SDRLoss.""" return { "estimate": NeuralType(self.signal_shape, VoidType()), "target": NeuralType(self.signal_shape, VoidType()), @@ -560,7 +565,12 @@ def forward( Returns: Scalar loss. """ - mse = calculate_mse_batch(estimate=estimate, target=target, input_length=input_length, mask=mask,) + mse = calculate_mse_batch( + estimate=estimate, + target=target, + input_length=input_length, + mask=mask, + ) # channel averaging if self.weight is None: diff --git a/nemo/collections/audio/metrics/__init__.py b/nemo/collections/audio/metrics/__init__.py new file mode 100644 index 0000000000000..d9155f923f186 --- /dev/null +++ b/nemo/collections/audio/metrics/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/nemo/collections/asr/metrics/audio.py b/nemo/collections/audio/metrics/audio.py similarity index 97% rename from nemo/collections/asr/metrics/audio.py rename to nemo/collections/audio/metrics/audio.py index db63ac19c098c..096700eff24a0 100644 --- a/nemo/collections/asr/metrics/audio.py +++ b/nemo/collections/audio/metrics/audio.py @@ -149,8 +149,7 @@ def update(self, preds: torch.Tensor, target: torch.Tensor, input_length: Option self.num_examples += preds.size(0) def compute(self) -> torch.Tensor: - """Compute the underlying metric. - """ + """Compute the underlying metric.""" return self._metric.compute() def forward( @@ -181,22 +180,19 @@ def forward( return self._batch_reduction(batch_values) def reset(self) -> None: - """Reset the underlying metric. - """ + """Reset the underlying metric.""" # reset the internal states super().reset() # reset the underlying metric self._metric.reset() def __repr__(self) -> str: - """Return string representation of the object. - """ + """Return string representation of the object.""" _op_metric = f"(metric: {repr(self._metric)}, channel: {self._channel})" repr_str = self.__class__.__name__ + _op_metric return repr_str def _wrap_compute(self, compute: Callable) -> Callable: - """Overwrite to do nothing, as in CompositionalMetric. - """ + """Overwrite to do nothing, as in CompositionalMetric.""" return compute diff --git a/nemo/collections/audio/models/__init__.py b/nemo/collections/audio/models/__init__.py new file mode 100644 index 0000000000000..a8d801fdd0e04 --- /dev/null +++ b/nemo/collections/audio/models/__init__.py @@ -0,0 +1,20 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from nemo.collections.audio.models.audio_to_audio import AudioToAudioModel +from nemo.collections.audio.models.enhancement import ( + EncMaskDecAudioToAudioModel, + PredictiveAudioToAudioModel, + ScoreBasedGenerativeAudioToAudioModel, +) diff --git a/nemo/collections/asr/models/audio_to_audio_model.py b/nemo/collections/audio/models/audio_to_audio.py similarity index 78% rename from nemo/collections/asr/models/audio_to_audio_model.py rename to nemo/collections/audio/models/audio_to_audio.py index 094dbc38b72a1..b12f9ce73cbe7 100644 --- a/nemo/collections/asr/models/audio_to_audio_model.py +++ b/nemo/collections/audio/models/audio_to_audio.py @@ -26,11 +26,11 @@ from pytorch_lightning import Trainer from tqdm import tqdm -from nemo.collections.asr.data import audio_to_audio_dataset -from nemo.collections.asr.data.audio_to_audio_lhotse import LhotseAudioToTargetDataset from nemo.collections.asr.data.audio_to_text_dataset import inject_dataloader_value_from_model_config -from nemo.collections.asr.metrics.audio import AudioMetricWrapper -from nemo.collections.asr.parts.utils.audio_utils import ChannelSelectorType +from nemo.collections.asr.parts.preprocessing.segment import ChannelSelectorType +from nemo.collections.audio.data import audio_to_audio_dataset +from nemo.collections.audio.data.audio_to_audio_lhotse import LhotseAudioToTargetDataset +from nemo.collections.audio.metrics.audio import AudioMetricWrapper from nemo.collections.common.data.lhotse import get_lhotse_dataloader_from_config from nemo.core.classes import ModelPT from nemo.utils import logging, model_utils @@ -45,8 +45,7 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None): self._setup_loss() def _setup_loss(self): - """Setup loss for this model. - """ + """Setup loss for this model.""" self.loss = AudioToAudioModel.from_config_dict(self._cfg.loss) def _get_num_dataloaders(self, tag: str = 'val'): @@ -169,120 +168,6 @@ def multi_validation_epoch_end(self, outputs, dataloader_idx: int = 0): def multi_test_epoch_end(self, outputs, dataloader_idx: int = 0): return self.multi_evaluation_epoch_end(outputs, dataloader_idx, 'test') - @torch.no_grad() - def process( - self, - paths2audio_files: List[str], - output_dir: str, - batch_size: int = 1, - num_workers: Optional[int] = None, - input_channel_selector: Optional[ChannelSelectorType] = None, - ) -> List[str]: - """ - Process audio files provided in paths2audio_files. - Processed signals will be saved in output_dir. - - Args: - paths2audio_files: (a list) of paths to audio files. \ - Recommended length per file is between 5 and 25 seconds. \ - But it is possible to pass a few hours long file if enough GPU memory is available. - output_dir: - batch_size: (int) batch size to use during inference. - Bigger will result in better throughput performance but would use more memory. - num_workers: Number of workers for the dataloader - input_channel_selector (int | Iterable[int] | str): select a single channel or a subset of channels from multi-channel audio. If set to `'average'`, it performs averaging across channels. Disabled if set to `None`. Defaults to `None`. - - Returns: - """ - if paths2audio_files is None or len(paths2audio_files) == 0: - return {} - - if num_workers is None: - num_workers = min(batch_size, os.cpu_count() - 1) - - # Output - paths2processed_files = [] - - # Model's mode and device - mode = self.training - device = next(self.parameters()).device - - try: - # Switch model to evaluation mode - self.eval() - # Freeze weights - self.freeze() - - logging_level = logging.get_verbosity() - logging.set_verbosity(logging.WARNING) - - # Processing - with tempfile.TemporaryDirectory() as tmpdir: - # Save temporary manifest - temporary_manifest_filepath = os.path.join(tmpdir, 'manifest.json') - with open(temporary_manifest_filepath, 'w', encoding='utf-8') as fp: - for audio_file in paths2audio_files: - entry = {'input_filepath': audio_file, 'duration': librosa.get_duration(path=audio_file)} - fp.write(json.dumps(entry) + '\n') - - config = { - 'manifest_filepath': temporary_manifest_filepath, - 'input_key': 'input_filepath', - 'input_channel_selector': input_channel_selector, - 'batch_size': min(batch_size, len(paths2audio_files)), - 'num_workers': num_workers, - } - - # Create output dir if necessary - if not os.path.isdir(output_dir): - os.makedirs(output_dir) - - # DataLoader for the input files - temporary_dataloader = self._setup_process_dataloader(config) - - # Indexing of the original files, used to form the output file name - file_idx = 0 - - # Process batches - for test_batch in tqdm(temporary_dataloader, desc="Processing"): - input_signal = test_batch[0] - input_length = test_batch[1] - - # Expand channel dimension, if necessary - # For consistency, the model uses multi-channel format, even if the channel dimension is 1 - if input_signal.ndim == 2: - input_signal = input_signal.unsqueeze(1) - - processed_batch, _ = self.forward( - input_signal=input_signal.to(device), input_length=input_length.to(device) - ) - - for example_idx in range(processed_batch.size(0)): - # This assumes the data loader is not shuffling files - file_name = os.path.basename(paths2audio_files[file_idx]) - # Prepare output file - output_file = os.path.join(output_dir, f'processed_{file_name}') - # Crop the output signal to the actual length - output_signal = processed_batch[example_idx, :, : input_length[example_idx]].cpu().numpy() - # Write audio - sf.write(output_file, output_signal.T, self.sample_rate, 'float') - # Update the file counter - file_idx += 1 - # Save processed file - paths2processed_files.append(output_file) - - del test_batch - del processed_batch - - finally: - # set mode back to its original value - self.train(mode=mode) - if mode is True: - self.unfreeze() - logging.set_verbosity(logging_level) - - return paths2processed_files - def _setup_dataloader_from_config(self, config: Optional[Dict]): if config.get("use_lhotse", False): @@ -593,5 +478,5 @@ def on_after_backward(self): torch.distributed.all_reduce(valid_gradients, op=torch.distributed.ReduceOp.MIN) if valid_gradients < 1: - logging.warning(f'detected inf or nan values in gradients! Setting gradients to zero.') + logging.warning('detected inf or nan values in gradients! Setting gradients to zero.') self.zero_grad() diff --git a/nemo/collections/asr/models/enhancement_models.py b/nemo/collections/audio/models/enhancement.py similarity index 98% rename from nemo/collections/asr/models/enhancement_models.py rename to nemo/collections/audio/models/enhancement.py index b765ae0fddad4..f605537041834 100644 --- a/nemo/collections/asr/models/enhancement_models.py +++ b/nemo/collections/audio/models/enhancement.py @@ -11,22 +11,16 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import json -import os -import tempfile -from typing import Dict, List, Optional, Union + +from typing import Dict, Optional import einops import hydra -import librosa -import soundfile as sf import torch from omegaconf import DictConfig from pytorch_lightning import Trainer -from tqdm import tqdm - -from nemo.collections.asr.models.audio_to_audio_model import AudioToAudioModel +from nemo.collections.audio.models.audio_to_audio import AudioToAudioModel from nemo.core.classes.common import PretrainedModelInfo, typecheck from nemo.core.neural_types import AudioSignal, LengthsType, LossType, NeuralType from nemo.utils import logging @@ -261,11 +255,11 @@ def output_types(self) -> Dict[str, NeuralType]: @typecheck() def forward(self, input_signal, input_length=None): """Forward pass of the model. - + Args: input_signal: time-domain signal input_length: valid length of each example in the batch - + Returns: Output signal `output` in the time domain and the length of the output signal `output_length`. """ @@ -361,7 +355,7 @@ def evaluation_step(self, batch, batch_idx, dataloader_idx: int = 0, tag: str = class ScoreBasedGenerativeAudioToAudioModel(AudioToAudioModel): """This models is using a score-based diffusion process to generate an encoded representation of the enhanced signal. - + The model consists of the following blocks: - encoder: transforms input multi-channel audio signal into an encoded representation (analysis transform) - estimator: neural model, estimates a score for the diffusion process @@ -481,7 +475,9 @@ def forward(self, input_signal, input_length=None): "input_signal": NeuralType(('B', 'C', 'T'), AudioSignal()), "input_length": NeuralType(tuple('B'), LengthsType()), }, - output_types={"loss": NeuralType(None, LossType()),}, + output_types={ + "loss": NeuralType(None, LossType()), + }, ) def _step(self, target_signal, input_signal, input_length=None): """Randomly generate a time step for each example in the batch, estimate diff --git a/nemo/collections/audio/modules/__init__.py b/nemo/collections/audio/modules/__init__.py new file mode 100644 index 0000000000000..d9155f923f186 --- /dev/null +++ b/nemo/collections/audio/modules/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/nemo/collections/audio/modules/features.py b/nemo/collections/audio/modules/features.py new file mode 100644 index 0000000000000..ce6cedf0c533a --- /dev/null +++ b/nemo/collections/audio/modules/features.py @@ -0,0 +1,279 @@ +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Dict, Optional + +import torch + +from nemo.collections.audio.losses.audio import calculate_mean +from nemo.collections.audio.parts.utils.audio import wrap_to_pi +from nemo.core.classes import NeuralModule, typecheck +from nemo.core.neural_types import LengthsType, NeuralType, SpectrogramType +from nemo.utils import logging + + +class SpectrogramToMultichannelFeatures(NeuralModule): + """Convert a complex-valued multi-channel spectrogram to + multichannel features. + + Args: + num_subbands: Expected number of subbands in the input signal + num_input_channels: Optional, provides the number of channels + of the input signal. Used to infer the number + of output channels. + mag_reduction: Reduction across channels. Default `None`, will calculate + magnitude of each channel. + mag_power: Optional, apply power on the magnitude. + use_ipd: Use inter-channel phase difference (IPD). + mag_normalization: Normalization for magnitude features + ipd_normalization: Normalization for IPD features + eps: Small regularization constant. + """ + + def __init__( + self, + num_subbands: int, + num_input_channels: Optional[int] = None, + mag_reduction: Optional[str] = None, + mag_power: Optional[float] = None, + use_ipd: bool = False, + mag_normalization: Optional[str] = None, + ipd_normalization: Optional[str] = None, + eps: float = 1e-8, + ): + super().__init__() + self.mag_reduction = mag_reduction + self.mag_power = mag_power + self.use_ipd = use_ipd + + if mag_normalization not in [None, 'mean', 'mean_var']: + raise NotImplementedError(f'Unknown magnitude normalization {mag_normalization}') + self.mag_normalization = mag_normalization + + if ipd_normalization not in [None, 'mean', 'mean_var']: + raise NotImplementedError(f'Unknown ipd normalization {ipd_normalization}') + self.ipd_normalization = ipd_normalization + + if self.use_ipd: + self._num_features = 2 * num_subbands + self._num_channels = num_input_channels + else: + self._num_features = num_subbands + self._num_channels = num_input_channels if self.mag_reduction is None else 1 + + self.eps = eps + + logging.debug('Initialized %s with', self.__class__.__name__) + logging.debug('\tnum_subbands: %d', num_subbands) + logging.debug('\tmag_reduction: %s', self.mag_reduction) + logging.debug('\tmag_power: %s', self.mag_power) + logging.debug('\tuse_ipd: %s', self.use_ipd) + logging.debug('\tmag_normalization: %s', self.mag_normalization) + logging.debug('\tipd_normalization: %s', self.ipd_normalization) + logging.debug('\teps: %f', self.eps) + logging.debug('\t_num_features: %s', self._num_features) + logging.debug('\t_num_channels: %s', self._num_channels) + + @property + def input_types(self) -> Dict[str, NeuralType]: + """Returns definitions of module output ports.""" + return { + "input": NeuralType(('B', 'C', 'D', 'T'), SpectrogramType()), + "input_length": NeuralType(('B',), LengthsType()), + } + + @property + def output_types(self) -> Dict[str, NeuralType]: + """Returns definitions of module output ports.""" + return { + "output": NeuralType(('B', 'C', 'D', 'T'), SpectrogramType()), + "output_length": NeuralType(('B',), LengthsType()), + } + + @property + def num_features(self) -> int: + """Configured number of features""" + return self._num_features + + @property + def num_channels(self) -> int: + """Configured number of channels""" + if self._num_channels is not None: + return self._num_channels + else: + raise ValueError( + 'Num channels is not configured. To configure this, `num_input_channels` ' + 'must be provided when constructing the object.' + ) + + @staticmethod + def get_mean_time_channel(input: torch.Tensor, input_length: Optional[torch.Tensor] = None) -> torch.Tensor: + """Calculate mean across time and channel dimensions. + + Args: + input: tensor with shape (B, C, F, T) + input_length: tensor with shape (B,) + + Returns: + Mean of `input` calculated across time and channel dimension + with shape (B, 1, F, 1) + """ + assert input.ndim == 4, f'Expected input to have 4 dimensions, got {input.ndim}' + + if input_length is None: + mean = torch.mean(input, dim=(-1, -3), keepdim=True) + else: + # temporal mean + mean = calculate_mean(input, input_length, dim=-1, keepdim=True) + # channel mean + mean = torch.mean(mean, dim=-3, keepdim=True) + + return mean + + @classmethod + def get_mean_std_time_channel( + cls, input: torch.Tensor, input_length: Optional[torch.Tensor] = None, eps: float = 1e-10 + ) -> torch.Tensor: + """Calculate mean and standard deviation across time and channel dimensions. + + Args: + input: tensor with shape (B, C, F, T) + input_length: tensor with shape (B,) + + Returns: + Mean and standard deviation of the `input` calculated across time and + channel dimension, each with shape (B, 1, F, 1). + """ + assert input.ndim == 4, f'Expected input to have 4 dimensions, got {input.ndim}' + + if input_length is None: + std, mean = torch.std_mean(input, dim=(-1, -3), unbiased=False, keepdim=True) + else: + mean = cls.get_mean_time_channel(input, input_length) + std = (input - mean).pow(2) + # temporal mean + std = calculate_mean(std, input_length, dim=-1, keepdim=True) + # channel mean + std = torch.mean(std, dim=-3, keepdim=True) + # final value + std = torch.sqrt(std.clamp(eps)) + + return mean, std + + @typecheck( + input_types={ + 'input': NeuralType(('B', 'C', 'D', 'T'), SpectrogramType()), + 'input_length': NeuralType(tuple('B'), LengthsType()), + }, + output_types={ + 'output': NeuralType(('B', 'C', 'D', 'T'), SpectrogramType()), + }, + ) + def normalize_mean(self, input: torch.Tensor, input_length: torch.Tensor) -> torch.Tensor: + """Mean normalization for the input tensor. + + Args: + input: input tensor + input_length: valid length for each example + + Returns: + Mean normalized input. + """ + mean = self.get_mean_time_channel(input=input, input_length=input_length) + output = input - mean + return output + + @typecheck( + input_types={ + 'input': NeuralType(('B', 'C', 'D', 'T'), SpectrogramType()), + 'input_length': NeuralType(tuple('B'), LengthsType()), + }, + output_types={ + 'output': NeuralType(('B', 'C', 'D', 'T'), SpectrogramType()), + }, + ) + def normalize_mean_var(self, input: torch.Tensor, input_length: torch.Tensor) -> torch.Tensor: + """Mean and variance normalization for the input tensor. + + Args: + input: input tensor + input_length: valid length for each example + + Returns: + Mean and variance normalized input. + """ + mean, std = self.get_mean_std_time_channel(input=input, input_length=input_length, eps=self.eps) + output = (input - mean) / std + return output + + @typecheck() + def forward(self, input: torch.Tensor, input_length: torch.Tensor) -> torch.Tensor: + """Convert input batch of C-channel spectrograms into + a batch of time-frequency features with dimension num_feat. + The output number of channels may be the same as input, or + reduced to 1, e.g., if averaging over magnitude and not appending individual IPDs. + + Args: + input: Spectrogram for C channels with F subbands and N time frames, (B, C, F, N) + input_length: Length of valid entries along the time dimension, shape (B,) + + Returns: + num_feat_channels channels with num_feat features, shape (B, num_feat_channels, num_feat, N) + """ + # Magnitude spectrum + if self.mag_reduction is None: + mag = torch.abs(input) + elif self.mag_reduction == 'abs_mean': + mag = torch.abs(torch.mean(input, axis=1, keepdim=True)) + elif self.mag_reduction == 'mean_abs': + mag = torch.mean(torch.abs(input), axis=1, keepdim=True) + elif self.mag_reduction == 'rms': + mag = torch.sqrt(torch.mean(torch.abs(input) ** 2, axis=1, keepdim=True)) + else: + raise ValueError(f'Unexpected magnitude reduction {self.mag_reduction}') + + if self.mag_power is not None: + mag = torch.pow(mag, self.mag_power) + + if self.mag_normalization == 'mean': + # normalize mean across channels and time steps + mag = self.normalize_mean(input=mag, input_length=input_length) + elif self.mag_normalization == 'mean_var': + mag = self.normalize_mean_var(input=mag, input_length=input_length) + + features = mag + + if self.use_ipd: + # Calculate IPD relative to the average spec + spec_mean = torch.mean(input, axis=1, keepdim=True) # channel average + ipd = torch.angle(input) - torch.angle(spec_mean) + # Modulo to [-pi, pi] + ipd = wrap_to_pi(ipd) + + if self.ipd_normalization == 'mean': + # normalize mean across channels and time steps + # mean across time + ipd = self.normalize_mean(input=ipd, input_length=input_length) + elif self.ipd_normalization == 'mean_var': + ipd = self.normalize_mean_var(input=ipd, input_length=input_length) + + # Concatenate to existing features + features = torch.cat([features.expand(ipd.shape), ipd], axis=2) + + if self._num_channels is not None and features.size(1) != self._num_channels: + raise RuntimeError( + f'Number of channels in features {features.size(1)} is different than the configured number of channels {self._num_channels}' + ) + + return features, input_length diff --git a/nemo/collections/asr/modules/audio_modules.py b/nemo/collections/audio/modules/masking.py similarity index 61% rename from nemo/collections/asr/modules/audio_modules.py rename to nemo/collections/audio/modules/masking.py index 67a923099cdee..cfb575eea8797 100644 --- a/nemo/collections/asr/modules/audio_modules.py +++ b/nemo/collections/audio/modules/masking.py @@ -14,289 +14,23 @@ from typing import Dict, List, Optional, Tuple -import numpy as np import torch -from nemo.collections.asr.losses.audio_losses import calculate_mean from nemo.collections.asr.modules.conformer_encoder import ConformerEncoder from nemo.collections.asr.parts.preprocessing.features import make_seq_mask_like -from nemo.collections.asr.parts.submodules.multichannel_modules import ( +from nemo.collections.audio.modules.features import SpectrogramToMultichannelFeatures +from nemo.collections.audio.parts.submodules.multichannel import ( ChannelAttentionPool, ChannelAveragePool, ParametricMultichannelWienerFilter, TransformAttendConcatenate, TransformAverageConcatenate, + WPEFilter, ) -from nemo.collections.asr.parts.utils.audio_utils import db2mag, wrap_to_pi +from nemo.collections.audio.parts.utils.audio import db2mag from nemo.core.classes import NeuralModule, typecheck from nemo.core.neural_types import FloatType, LengthsType, NeuralType, SpectrogramType from nemo.utils import logging -from nemo.utils.decorators import experimental - -__all__ = [ - 'MaskEstimatorRNN', - 'MaskEstimatorFlexChannels', - 'MaskReferenceChannel', - 'MaskBasedBeamformer', - 'MaskBasedDereverbWPE', - 'MixtureConsistencyProjection', -] - - -class SpectrogramToMultichannelFeatures(NeuralModule): - """Convert a complex-valued multi-channel spectrogram to - multichannel features. - - Args: - num_subbands: Expected number of subbands in the input signal - num_input_channels: Optional, provides the number of channels - of the input signal. Used to infer the number - of output channels. - mag_reduction: Reduction across channels. Default `None`, will calculate - magnitude of each channel. - mag_power: Optional, apply power on the magnitude. - use_ipd: Use inter-channel phase difference (IPD). - mag_normalization: Normalization for magnitude features - ipd_normalization: Normalization for IPD features - eps: Small regularization constant. - """ - - def __init__( - self, - num_subbands: int, - num_input_channels: Optional[int] = None, - mag_reduction: Optional[str] = None, - mag_power: Optional[float] = None, - use_ipd: bool = False, - mag_normalization: Optional[str] = None, - ipd_normalization: Optional[str] = None, - eps: float = 1e-8, - ): - super().__init__() - self.mag_reduction = mag_reduction - self.mag_power = mag_power - self.use_ipd = use_ipd - - if mag_normalization not in [None, 'mean', 'mean_var']: - raise NotImplementedError(f'Unknown magnitude normalization {mag_normalization}') - self.mag_normalization = mag_normalization - - if ipd_normalization not in [None, 'mean', 'mean_var']: - raise NotImplementedError(f'Unknown ipd normalization {ipd_normalization}') - self.ipd_normalization = ipd_normalization - - if self.use_ipd: - self._num_features = 2 * num_subbands - self._num_channels = num_input_channels - else: - self._num_features = num_subbands - self._num_channels = num_input_channels if self.mag_reduction is None else 1 - - self.eps = eps - - logging.debug('Initialized %s with', self.__class__.__name__) - logging.debug('\tnum_subbands: %d', num_subbands) - logging.debug('\tmag_reduction: %s', self.mag_reduction) - logging.debug('\tmag_power: %s', self.mag_power) - logging.debug('\tuse_ipd: %s', self.use_ipd) - logging.debug('\tmag_normalization: %s', self.mag_normalization) - logging.debug('\tipd_normalization: %s', self.ipd_normalization) - logging.debug('\teps: %f', self.eps) - logging.debug('\t_num_features: %s', self._num_features) - logging.debug('\t_num_channels: %s', self._num_channels) - - @property - def input_types(self) -> Dict[str, NeuralType]: - """Returns definitions of module output ports. - """ - return { - "input": NeuralType(('B', 'C', 'D', 'T'), SpectrogramType()), - "input_length": NeuralType(('B',), LengthsType()), - } - - @property - def output_types(self) -> Dict[str, NeuralType]: - """Returns definitions of module output ports. - """ - return { - "output": NeuralType(('B', 'C', 'D', 'T'), SpectrogramType()), - "output_length": NeuralType(('B',), LengthsType()), - } - - @property - def num_features(self) -> int: - """Configured number of features - """ - return self._num_features - - @property - def num_channels(self) -> int: - """Configured number of channels - """ - if self._num_channels is not None: - return self._num_channels - else: - raise ValueError( - 'Num channels is not configured. To configure this, `num_input_channels` ' - 'must be provided when constructing the object.' - ) - - @staticmethod - def get_mean_time_channel(input: torch.Tensor, input_length: Optional[torch.Tensor] = None) -> torch.Tensor: - """Calculate mean across time and channel dimensions. - - Args: - input: tensor with shape (B, C, F, T) - input_length: tensor with shape (B,) - - Returns: - Mean of `input` calculated across time and channel dimension - with shape (B, 1, F, 1) - """ - assert input.ndim == 4, f'Expected input to have 4 dimensions, got {input.ndim}' - - if input_length is None: - mean = torch.mean(input, dim=(-1, -3), keepdim=True) - else: - # temporal mean - mean = calculate_mean(input, input_length, dim=-1, keepdim=True) - # channel mean - mean = torch.mean(mean, dim=-3, keepdim=True) - - return mean - - @classmethod - def get_mean_std_time_channel( - cls, input: torch.Tensor, input_length: Optional[torch.Tensor] = None, eps: float = 1e-10 - ) -> torch.Tensor: - """Calculate mean and standard deviation across time and channel dimensions. - - Args: - input: tensor with shape (B, C, F, T) - input_length: tensor with shape (B,) - - Returns: - Mean and standard deviation of the `input` calculated across time and - channel dimension, each with shape (B, 1, F, 1). - """ - assert input.ndim == 4, f'Expected input to have 4 dimensions, got {input.ndim}' - - if input_length is None: - std, mean = torch.std_mean(input, dim=(-1, -3), unbiased=False, keepdim=True) - else: - mean = cls.get_mean_time_channel(input, input_length) - std = (input - mean).pow(2) - # temporal mean - std = calculate_mean(std, input_length, dim=-1, keepdim=True) - # channel mean - std = torch.mean(std, dim=-3, keepdim=True) - # final value - std = torch.sqrt(std.clamp(eps)) - - return mean, std - - @typecheck( - input_types={ - 'input': NeuralType(('B', 'C', 'D', 'T'), SpectrogramType()), - 'input_length': NeuralType(tuple('B'), LengthsType()), - }, - output_types={'output': NeuralType(('B', 'C', 'D', 'T'), SpectrogramType()),}, - ) - def normalize_mean(self, input: torch.Tensor, input_length: torch.Tensor) -> torch.Tensor: - """Mean normalization for the input tensor. - - Args: - input: input tensor - input_length: valid length for each example - - Returns: - Mean normalized input. - """ - mean = self.get_mean_time_channel(input=input, input_length=input_length) - output = input - mean - return output - - @typecheck( - input_types={ - 'input': NeuralType(('B', 'C', 'D', 'T'), SpectrogramType()), - 'input_length': NeuralType(tuple('B'), LengthsType()), - }, - output_types={'output': NeuralType(('B', 'C', 'D', 'T'), SpectrogramType()),}, - ) - def normalize_mean_var(self, input: torch.Tensor, input_length: torch.Tensor) -> torch.Tensor: - """Mean and variance normalization for the input tensor. - - Args: - input: input tensor - input_length: valid length for each example - - Returns: - Mean and variance normalized input. - """ - mean, std = self.get_mean_std_time_channel(input=input, input_length=input_length, eps=self.eps) - output = (input - mean) / std - return output - - @typecheck() - def forward(self, input: torch.Tensor, input_length: torch.Tensor) -> torch.Tensor: - """Convert input batch of C-channel spectrograms into - a batch of time-frequency features with dimension num_feat. - The output number of channels may be the same as input, or - reduced to 1, e.g., if averaging over magnitude and not appending individual IPDs. - - Args: - input: Spectrogram for C channels with F subbands and N time frames, (B, C, F, N) - input_length: Length of valid entries along the time dimension, shape (B,) - - Returns: - num_feat_channels channels with num_feat features, shape (B, num_feat_channels, num_feat, N) - """ - # Magnitude spectrum - if self.mag_reduction is None: - mag = torch.abs(input) - elif self.mag_reduction == 'abs_mean': - mag = torch.abs(torch.mean(input, axis=1, keepdim=True)) - elif self.mag_reduction == 'mean_abs': - mag = torch.mean(torch.abs(input), axis=1, keepdim=True) - elif self.mag_reduction == 'rms': - mag = torch.sqrt(torch.mean(torch.abs(input) ** 2, axis=1, keepdim=True)) - else: - raise ValueError(f'Unexpected magnitude reduction {self.mag_reduction}') - - if self.mag_power is not None: - mag = torch.pow(mag, self.mag_power) - - if self.mag_normalization == 'mean': - # normalize mean across channels and time steps - mag = self.normalize_mean(input=mag, input_length=input_length) - elif self.mag_normalization == 'mean_var': - mag = self.normalize_mean_var(input=mag, input_length=input_length) - - features = mag - - if self.use_ipd: - # Calculate IPD relative to the average spec - spec_mean = torch.mean(input, axis=1, keepdim=True) # channel average - ipd = torch.angle(input) - torch.angle(spec_mean) - # Modulo to [-pi, pi] - ipd = wrap_to_pi(ipd) - - if self.ipd_normalization == 'mean': - # normalize mean across channels and time steps - # mean across time - ipd = self.normalize_mean(input=ipd, input_length=input_length) - elif self.ipd_normalization == 'mean_var': - ipd = self.normalize_mean_var(input=ipd, input_length=input_length) - - # Concatenate to existing features - features = torch.cat([features.expand(ipd.shape), ipd], axis=2) - - if self._num_channels is not None and features.size(1) != self._num_channels: - raise RuntimeError( - f'Number of channels in features {features.size(1)} is different than the configured number of channels {self._num_channels}' - ) - - return features, input_length class MaskEstimatorRNN(NeuralModule): @@ -389,8 +123,7 @@ def __init__( @property def input_types(self) -> Dict[str, NeuralType]: - """Returns definitions of module output ports. - """ + """Returns definitions of module output ports.""" return { "input": NeuralType(('B', 'C', 'D', 'T'), SpectrogramType()), "input_length": NeuralType(('B',), LengthsType()), @@ -398,8 +131,7 @@ def input_types(self) -> Dict[str, NeuralType]: @property def output_types(self) -> Dict[str, NeuralType]: - """Returns definitions of module output ports. - """ + """Returns definitions of module output ports.""" return { "output": NeuralType(('B', 'C', 'D', 'T'), FloatType()), "output_length": NeuralType(('B',), LengthsType()), @@ -638,8 +370,7 @@ def __init__( @property def input_types(self) -> Dict[str, NeuralType]: - """Returns definitions of module output ports. - """ + """Returns definitions of module output ports.""" return { "input": NeuralType(('B', 'C', 'D', 'T'), SpectrogramType()), "input_length": NeuralType(('B',), LengthsType()), @@ -647,8 +378,7 @@ def input_types(self) -> Dict[str, NeuralType]: @property def output_types(self) -> Dict[str, NeuralType]: - """Returns definitions of module output ports. - """ + """Returns definitions of module output ports.""" return { "output": NeuralType(('B', 'C', 'D', 'T'), FloatType()), "output_length": NeuralType(('B',), LengthsType()), @@ -656,8 +386,7 @@ def output_types(self) -> Dict[str, NeuralType]: @typecheck() def forward(self, input: torch.Tensor, input_length: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: - """Estimate `num_outputs` masks from the input spectrogram. - """ + """Estimate `num_outputs` masks from the input spectrogram.""" # get input features from a complex-valued spectrogram, (B, C, F, T) output, output_length = self.features(input=input, input_length=input_length) @@ -786,7 +515,9 @@ def normalize(self, x: torch.Tensor, dim: int = 1) -> torch.Tensor: 'activity': NeuralType(('B', 'C', 'T')), 'log_pdf': NeuralType(('B', 'C', 'D', 'T')), }, - output_types={'gamma': NeuralType(('B', 'C', 'D', 'T')),}, + output_types={ + 'gamma': NeuralType(('B', 'C', 'D', 'T')), + }, ) def update_masks(self, alpha: torch.Tensor, activity: torch.Tensor, log_pdf: torch.Tensor) -> torch.Tensor: """Update masks for the cACGMM. @@ -814,7 +545,12 @@ def update_masks(self, alpha: torch.Tensor, activity: torch.Tensor, log_pdf: tor return gamma @typecheck( - input_types={'gamma': NeuralType(('B', 'C', 'D', 'T')),}, output_types={'alpha': NeuralType(('B', 'C', 'D')),}, + input_types={ + 'gamma': NeuralType(('B', 'C', 'D', 'T')), + }, + output_types={ + 'alpha': NeuralType(('B', 'C', 'D')), + }, ) def update_weights(self, gamma: torch.Tensor) -> torch.Tensor: """Update weights for the individual components @@ -835,7 +571,10 @@ def update_weights(self, gamma: torch.Tensor) -> torch.Tensor: 'gamma': NeuralType(('B', 'C', 'D', 'T')), 'zH_invBM_z': NeuralType(('B', 'C', 'D', 'T')), }, - output_types={'log_pdf': NeuralType(('B', 'C', 'D', 'T')), 'zH_invBM_z': NeuralType(('B', 'C', 'D', 'T')),}, + output_types={ + 'log_pdf': NeuralType(('B', 'C', 'D', 'T')), + 'zH_invBM_z': NeuralType(('B', 'C', 'D', 'T')), + }, ) def update_pdf( self, z: torch.Tensor, gamma: torch.Tensor, zH_invBM_z: torch.Tensor @@ -903,8 +642,7 @@ def update_pdf( @property def input_types(self) -> Dict[str, NeuralType]: - """Returns definitions of module output ports. - """ + """Returns definitions of module output ports.""" return { "input": NeuralType(('B', 'C', 'D', 'T'), SpectrogramType()), "activity": NeuralType(('B', 'C', 'T')), @@ -912,8 +650,7 @@ def input_types(self) -> Dict[str, NeuralType]: @property def output_types(self) -> Dict[str, NeuralType]: - """Returns definitions of module output ports. - """ + """Returns definitions of module output ports.""" return { "gamma": NeuralType(('B', 'C', 'D', 'T')), } @@ -995,8 +732,7 @@ def __init__(self, ref_channel: int = 0, mask_min_db: float = -200, mask_max_db: @property def input_types(self) -> Dict[str, NeuralType]: - """Returns definitions of module output ports. - """ + """Returns definitions of module output ports.""" return { "input": NeuralType(('B', 'C', 'D', 'T'), SpectrogramType()), "input_length": NeuralType(('B',), LengthsType()), @@ -1005,8 +741,7 @@ def input_types(self) -> Dict[str, NeuralType]: @property def output_types(self) -> Dict[str, NeuralType]: - """Returns definitions of module output ports. - """ + """Returns definitions of module output ports.""" return { "output": NeuralType(('B', 'C', 'D', 'T'), SpectrogramType()), "output_length": NeuralType(('B',), LengthsType()), @@ -1014,7 +749,10 @@ def output_types(self) -> Dict[str, NeuralType]: @typecheck() def forward( - self, input: torch.Tensor, input_length: torch.Tensor, mask: torch.Tensor, + self, + input: torch.Tensor, + input_length: torch.Tensor, + mask: torch.Tensor, ) -> Tuple[torch.Tensor, torch.Tensor]: """Apply mask on `ref_channel` of the input signal. This can be used to generate multi-channel output. @@ -1124,8 +862,7 @@ def __init__( @property def input_types(self) -> Dict[str, NeuralType]: - """Returns definitions of module output ports. - """ + """Returns definitions of module output ports.""" return { "input": NeuralType(('B', 'C', 'D', 'T'), SpectrogramType()), "mask": NeuralType(('B', 'C', 'D', 'T'), FloatType()), @@ -1135,8 +872,7 @@ def input_types(self) -> Dict[str, NeuralType]: @property def output_types(self) -> Dict[str, NeuralType]: - """Returns definitions of module output ports. - """ + """Returns definitions of module output ports.""" return { "output": NeuralType(('B', 'C', 'D', 'T'), SpectrogramType()), "output_length": NeuralType(('B',), LengthsType(), optional=True), @@ -1161,7 +897,7 @@ def forward( input: Input signal complex-valued spectrogram, shape (B, C, F, N) mask: Mask for M output signals, shape (B, num_masks, F, N) input_length: Length of valid entries along the time dimension, shape (B,) - + Returns: Multichannel output signal complex-valued spectrogram, shape (B, num_masks * M, F, N) """ @@ -1216,296 +952,6 @@ def forward( return output, input_length -class WPEFilter(NeuralModule): - """A weighted prediction error filter. - Given input signal, and expected power of the desired signal, this - class estimates a multiple-input multiple-output prediction filter - and returns the filtered signal. Currently, estimation of statistics - and processing is performed in batch mode. - - Args: - filter_length: Length of the prediction filter in frames, per channel - prediction_delay: Prediction delay in frames - diag_reg: Diagonal regularization for the correlation matrix Q, applied as diag_reg * trace(Q) + eps - eps: Small positive constant for regularization - - References: - - Yoshioka and Nakatani, Generalization of Multi-Channel Linear Prediction - Methods for Blind MIMO Impulse Response Shortening, 2012 - - Jukić et al, Group sparsity for MIMO speech dereverberation, 2015 - """ - - def __init__(self, filter_length: int, prediction_delay: int, diag_reg: Optional[float] = 1e-6, eps: float = 1e-8): - super().__init__() - self.filter_length = filter_length - self.prediction_delay = prediction_delay - self.diag_reg = diag_reg - self.eps = eps - - logging.debug('Initialized %s', self.__class__.__name__) - logging.debug('\tfilter_length: %d', self.filter_length) - logging.debug('\tprediction_delay: %d', self.prediction_delay) - logging.debug('\tdiag_reg: %g', self.diag_reg) - logging.debug('\teps: %g', self.eps) - - @property - def input_types(self) -> Dict[str, NeuralType]: - """Returns definitions of module output ports. - """ - return { - "input": NeuralType(('B', 'C', 'D', 'T'), SpectrogramType()), - "power": NeuralType(('B', 'C', 'D', 'T'), SpectrogramType()), - "input_length": NeuralType(('B',), LengthsType(), optional=True), - } - - @property - def output_types(self) -> Dict[str, NeuralType]: - """Returns definitions of module output ports. - """ - return { - "output": NeuralType(('B', 'C', 'D', 'T'), SpectrogramType()), - "output_length": NeuralType(('B',), LengthsType(), optional=True), - } - - @typecheck() - def forward( - self, input: torch.Tensor, power: torch.Tensor, input_length: Optional[torch.Tensor] = None - ) -> torch.Tensor: - """Given input and the predicted power for the desired signal, estimate - the WPE filter and return the processed signal. - - Args: - input: Input signal, shape (B, C, F, N) - power: Predicted power of the desired signal, shape (B, C, F, N) - input_length: Optional, length of valid frames in `input`. Defaults to `None` - - Returns: - Tuple of (processed_signal, output_length). Processed signal has the same - shape as the input signal (B, C, F, N), and the output length is the same - as the input length. - """ - # Temporal weighting: average power over channels, output shape (B, F, N) - weight = torch.mean(power, dim=1) - # Use inverse power as the weight - weight = 1 / (weight + self.eps) - - # Multi-channel convolution matrix for each subband - tilde_input = self.convtensor(input, filter_length=self.filter_length, delay=self.prediction_delay) - - # Estimate correlation matrices - Q, R = self.estimate_correlations( - input=input, weight=weight, tilde_input=tilde_input, input_length=input_length - ) - - # Estimate prediction filter - G = self.estimate_filter(Q=Q, R=R) - - # Apply prediction filter - undesired_signal = self.apply_filter(filter=G, tilde_input=tilde_input) - - # Dereverberation - desired_signal = input - undesired_signal - - if input_length is not None: - # Mask padded frames - length_mask: torch.Tensor = make_seq_mask_like( - lengths=input_length, like=desired_signal, time_dim=-1, valid_ones=False - ) - desired_signal = desired_signal.masked_fill(length_mask, 0.0) - - return desired_signal, input_length - - @classmethod - def convtensor( - cls, x: torch.Tensor, filter_length: int, delay: int = 0, n_steps: Optional[int] = None - ) -> torch.Tensor: - """Create a tensor equivalent of convmtx_mc for each example in the batch. - The input signal tensor `x` has shape (B, C, F, N). - Convtensor returns a view of the input signal `x`. - - Note: We avoid reshaping the output to collapse channels and filter taps into - a single dimension, e.g., (B, F, N, -1). In this way, the output is a view of the input, - while an additional reshape would result in a contiguous array and more memory use. - - Args: - x: input tensor, shape (B, C, F, N) - filter_length: length of the filter, determines the shape of the convolution tensor - delay: delay to add to the input signal `x` before constructing the convolution tensor - n_steps: Optional, number of time steps to keep in the out. Defaults to the number of - time steps in the input tensor. - - Returns: - Return a convolutional tensor with shape (B, C, F, n_steps, filter_length) - """ - if x.ndim != 4: - raise RuntimeError(f'Expecting a 4-D input. Received input with shape {x.shape}') - - B, C, F, N = x.shape - - if n_steps is None: - # Keep the same length as the input signal - n_steps = N - - # Pad temporal dimension - x = torch.nn.functional.pad(x, (filter_length - 1 + delay, 0)) - - # Build Toeplitz-like matrix view by unfolding across time - tilde_X = x.unfold(-1, filter_length, 1) - - # Trim to the set number of time steps - tilde_X = tilde_X[:, :, :, :n_steps, :] - - return tilde_X - - @classmethod - def permute_convtensor(cls, x: torch.Tensor) -> torch.Tensor: - """Reshape and permute columns to convert the result of - convtensor to be equal to convmtx_mc. This is used for verification - purposes and it is not required to use the filter. - - Args: - x: output of self.convtensor, shape (B, C, F, N, filter_length) - - Returns: - Output has shape (B, F, N, C*filter_length) that corresponds to - the layout of convmtx_mc. - """ - B, C, F, N, filter_length = x.shape - - # .view will not work, so a copy will have to be created with .reshape - # That will result in more memory use, since we don't use a view of the original - # multi-channel signal - x = x.permute(0, 2, 3, 1, 4) - x = x.reshape(B, F, N, C * filter_length) - - permute = [] - for m in range(C): - permute[m * filter_length : (m + 1) * filter_length] = m * filter_length + np.flip( - np.arange(filter_length) - ) - return x[..., permute] - - def estimate_correlations( - self, - input: torch.Tensor, - weight: torch.Tensor, - tilde_input: torch.Tensor, - input_length: Optional[torch.Tensor] = None, - ) -> Tuple[torch.Tensor]: - """ - Args: - input: Input signal, shape (B, C, F, N) - weight: Time-frequency weight, shape (B, F, N) - tilde_input: Multi-channel convolution tensor, shape (B, C, F, N, filter_length) - input_length: Length of each input example, shape (B) - - Returns: - Returns a tuple of correlation matrices for each batch. - - Let `X` denote the input signal in a single subband, - `tilde{X}` the corresponding multi-channel correlation matrix, - and `w` the vector of weights. - - The first output is - Q = tilde{X}^H * diag(w) * tilde{X} (1) - for each (b, f). - The matrix calculated in (1) has shape (C * filter_length, C * filter_length) - The output is returned in a tensor with shape (B, F, C, filter_length, C, filter_length). - - The second output is - R = tilde{X}^H * diag(w) * X (2) - for each (b, f). - The matrix calculated in (2) has shape (C * filter_length, C) - The output is returned in a tensor with shape (B, F, C, filter_length, C). The last - dimension corresponds to output channels. - """ - if input_length is not None: - # Take only valid samples into account - length_mask: torch.Tensor = make_seq_mask_like( - lengths=input_length, like=weight, time_dim=-1, valid_ones=False - ) - weight = weight.masked_fill(length_mask, 0.0) - - # Calculate (1) - # result: (B, F, C, filter_length, C, filter_length) - Q = torch.einsum('bjfik,bmfin->bfjkmn', tilde_input.conj(), weight[:, None, :, :, None] * tilde_input) - - # Calculate (2) - # result: (B, F, C, filter_length, C) - R = torch.einsum('bjfik,bmfi->bfjkm', tilde_input.conj(), weight[:, None, :, :] * input) - - return Q, R - - def estimate_filter(self, Q: torch.Tensor, R: torch.Tensor) -> torch.Tensor: - """Estimate the MIMO prediction filter as - G(b,f) = Q(b,f) \ R(b,f) - for each subband in each example in the batch (b, f). - - Args: - Q: shape (B, F, C, filter_length, C, filter_length) - R: shape (B, F, C, filter_length, C) - - Returns: - Complex-valued prediction filter, shape (B, C, F, C, filter_length) - """ - B, F, C, filter_length, _, _ = Q.shape - assert ( - filter_length == self.filter_length - ), f'Shape of Q {Q.shape} is not matching filter length {self.filter_length}' - - # Reshape to analytical dimensions for each (b, f) - Q = Q.reshape(B, F, C * self.filter_length, C * filter_length) - R = R.reshape(B, F, C * self.filter_length, C) - - # Diagonal regularization - if self.diag_reg: - # Regularization: diag_reg * trace(Q) + eps - diag_reg = self.diag_reg * torch.diagonal(Q, dim1=-2, dim2=-1).sum(-1).real + self.eps - # Apply regularization on Q - Q = Q + torch.diag_embed(diag_reg.unsqueeze(-1) * torch.ones(Q.shape[-1], device=Q.device)) - - # Solve for the filter - G = torch.linalg.solve(Q, R) - - # Reshape to desired representation: (B, F, input channels, filter_length, output channels) - G = G.reshape(B, F, C, filter_length, C) - # Move output channels to front: (B, output channels, F, input channels, filter_length) - G = G.permute(0, 4, 1, 2, 3) - - return G - - def apply_filter( - self, filter: torch.Tensor, input: Optional[torch.Tensor] = None, tilde_input: Optional[torch.Tensor] = None - ) -> torch.Tensor: - """Apply a prediction filter `filter` on the input `input` as - - output(b,f) = tilde{input(b,f)} * filter(b,f) - - If available, directly use the convolution matrix `tilde_input`. - - Args: - input: Input signal, shape (B, C, F, N) - tilde_input: Convolution matrix for the input signal, shape (B, C, F, N, filter_length) - filter: Prediction filter, shape (B, C, F, C, filter_length) - - Returns: - Multi-channel signal obtained by applying the prediction filter on - the input signal, same shape as input (B, C, F, N) - """ - if input is None and tilde_input is None: - raise RuntimeError(f'Both inputs cannot be None simultaneously.') - if input is not None and tilde_input is not None: - raise RuntimeError(f'Both inputs cannot be provided simultaneously.') - - if tilde_input is None: - tilde_input = self.convtensor(input, filter_length=self.filter_length, delay=self.prediction_delay) - - # For each (batch, output channel, f, time step), sum across (input channel, filter tap) - output = torch.einsum('bjfik,bmfjk->bmfi', tilde_input, filter) - - return output - - class MaskBasedDereverbWPE(NeuralModule): """Multi-channel linear prediction-based dereverberation using weighted prediction error for filter estimation. @@ -1562,8 +1008,7 @@ def __init__( @property def input_types(self) -> Dict[str, NeuralType]: - """Returns definitions of module output ports. - """ + """Returns definitions of module output ports.""" return { "input": NeuralType(('B', 'C', 'D', 'T'), SpectrogramType()), "input_length": NeuralType(('B',), LengthsType(), optional=True), @@ -1572,8 +1017,7 @@ def input_types(self) -> Dict[str, NeuralType]: @property def output_types(self) -> Dict[str, NeuralType]: - """Returns definitions of module output ports. - """ + """Returns definitions of module output ports.""" return { "output": NeuralType(('B', 'C', 'D', 'T'), SpectrogramType()), "output_length": NeuralType(('B',), LengthsType(), optional=True), @@ -1610,77 +1054,8 @@ def forward( # Mask magnitude magnitude = mask * magnitude # Calculate power - power = magnitude ** 2 + power = magnitude**2 # Apply filter output, output_length = self.filter(input=output, input_length=input_length, power=power) return output.to(io_dtype), output_length - - -class MixtureConsistencyProjection(NeuralModule): - """Ensure estimated sources are consistent with the input mixture. - Note that the input mixture is assume to be a single-channel signal. - - Args: - weighting: Optional weighting mode for the consistency constraint. - If `None`, use uniform weighting. If `power`, use the power of the - estimated source as the weight. - eps: Small positive value for regularization - - Reference: - Wisdom et al, Differentiable consistency constraints for improved deep speech enhancement, 2018 - """ - - def __init__(self, weighting: Optional[str] = None, eps: float = 1e-8): - super().__init__() - self.weighting = weighting - self.eps = eps - - if self.weighting not in [None, 'power']: - raise NotImplementedError(f'Weighting mode {self.weighting} not implemented') - - @property - def input_types(self) -> Dict[str, NeuralType]: - """Returns definitions of module output ports. - """ - return { - "mixture": NeuralType(('B', 'C', 'D', 'T'), SpectrogramType()), - "estimate": NeuralType(('B', 'C', 'D', 'T'), SpectrogramType()), - } - - @property - def output_types(self) -> Dict[str, NeuralType]: - """Returns definitions of module output ports. - """ - return { - "output": NeuralType(('B', 'C', 'D', 'T'), SpectrogramType()), - } - - @typecheck() - def forward(self, mixture: torch.Tensor, estimate: torch.Tensor) -> torch.Tensor: - """Enforce mixture consistency on the estimated sources. - Args: - mixture: Single-channel mixture, shape (B, 1, F, N) - estimate: M estimated sources, shape (B, M, F, N) - - Returns: - Source estimates consistent with the mixture, shape (B, M, F, N) - """ - # number of sources - M = estimate.size(-3) - # estimated mixture based on the estimated sources - estimated_mixture = torch.sum(estimate, dim=-3, keepdim=True) - - # weighting - if self.weighting is None: - weight = 1 / M - elif self.weighting == 'power': - weight = estimate.abs().pow(2) - weight = weight / (weight.sum(dim=-3, keepdim=True) + self.eps) - else: - raise NotImplementedError(f'Weighting mode {self.weighting} not implemented') - - # consistent estimate - consistent_estimate = estimate + weight * (mixture - estimated_mixture) - - return consistent_estimate diff --git a/nemo/collections/audio/modules/projections.py b/nemo/collections/audio/modules/projections.py new file mode 100644 index 0000000000000..9012432287dbd --- /dev/null +++ b/nemo/collections/audio/modules/projections.py @@ -0,0 +1,87 @@ +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Dict, Optional + +import torch + +from nemo.core.classes import NeuralModule, typecheck +from nemo.core.neural_types import NeuralType, SpectrogramType + + +class MixtureConsistencyProjection(NeuralModule): + """Ensure estimated sources are consistent with the input mixture. + Note that the input mixture is assume to be a single-channel signal. + + Args: + weighting: Optional weighting mode for the consistency constraint. + If `None`, use uniform weighting. If `power`, use the power of the + estimated source as the weight. + eps: Small positive value for regularization + + Reference: + Wisdom et al, Differentiable consistency constraints for improved deep speech enhancement, 2018 + """ + + def __init__(self, weighting: Optional[str] = None, eps: float = 1e-8): + super().__init__() + self.weighting = weighting + self.eps = eps + + if self.weighting not in [None, 'power']: + raise NotImplementedError(f'Weighting mode {self.weighting} not implemented') + + @property + def input_types(self) -> Dict[str, NeuralType]: + """Returns definitions of module output ports.""" + return { + "mixture": NeuralType(('B', 'C', 'D', 'T'), SpectrogramType()), + "estimate": NeuralType(('B', 'C', 'D', 'T'), SpectrogramType()), + } + + @property + def output_types(self) -> Dict[str, NeuralType]: + """Returns definitions of module output ports.""" + return { + "output": NeuralType(('B', 'C', 'D', 'T'), SpectrogramType()), + } + + @typecheck() + def forward(self, mixture: torch.Tensor, estimate: torch.Tensor) -> torch.Tensor: + """Enforce mixture consistency on the estimated sources. + Args: + mixture: Single-channel mixture, shape (B, 1, F, N) + estimate: M estimated sources, shape (B, M, F, N) + + Returns: + Source estimates consistent with the mixture, shape (B, M, F, N) + """ + # number of sources + M = estimate.size(-3) + # estimated mixture based on the estimated sources + estimated_mixture = torch.sum(estimate, dim=-3, keepdim=True) + + # weighting + if self.weighting is None: + weight = 1 / M + elif self.weighting == 'power': + weight = estimate.abs().pow(2) + weight = weight / (weight.sum(dim=-3, keepdim=True) + self.eps) + else: + raise NotImplementedError(f'Weighting mode {self.weighting} not implemented') + + # consistent estimate + consistent_estimate = estimate + weight * (mixture - estimated_mixture) + + return consistent_estimate diff --git a/nemo/collections/audio/modules/transforms.py b/nemo/collections/audio/modules/transforms.py new file mode 100644 index 0000000000000..ecbdca88e22b5 --- /dev/null +++ b/nemo/collections/audio/modules/transforms.py @@ -0,0 +1,277 @@ +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Dict, Optional, Tuple + +import torch + +from nemo.collections.asr.parts.preprocessing.features import make_seq_mask_like +from nemo.core.classes import NeuralModule, typecheck +from nemo.core.neural_types import AudioSignal, LengthsType, NeuralType, SpectrogramType +from nemo.utils import logging + +try: + import torchaudio + import torchaudio.functional + import torchaudio.transforms + + HAVE_TORCHAUDIO = True +except ModuleNotFoundError: + HAVE_TORCHAUDIO = False + + +class AudioToSpectrogram(NeuralModule): + """Transform a batch of input multi-channel signals into a batch of + STFT-based spectrograms. + + Args: + fft_length: length of FFT + hop_length: length of hops/shifts of the sliding window + power: exponent for magnitude spectrogram. Default `None` will + return a complex-valued spectrogram + magnitude_power: Transform magnitude of the spectrogram as x^magnitude_power. + scale: Positive scaling of the spectrogram. + """ + + def __init__(self, fft_length: int, hop_length: int, magnitude_power: float = 1.0, scale: float = 1.0): + if not HAVE_TORCHAUDIO: + logging.error('Could not import torchaudio. Some features might not work.') + + raise ModuleNotFoundError( + f"torchaudio is not installed but is necessary to instantiate a {self.__class__.__name__}" + ) + + super().__init__() + + # For now, assume FFT length is divisible by two + if fft_length % 2 != 0: + raise ValueError(f'fft_length = {fft_length} must be divisible by 2') + + self.stft = torchaudio.transforms.Spectrogram( + n_fft=fft_length, hop_length=hop_length, power=None, pad_mode='constant' + ) + + # number of subbands + self.F = fft_length // 2 + 1 + + if magnitude_power <= 0: + raise ValueError(f'Magnitude power needs to be positive: current value {magnitude_power}') + self.magnitude_power = magnitude_power + + if scale <= 0: + raise ValueError(f'Scale needs to be positive: current value {scale}') + self.scale = scale + + logging.debug('Initialized %s with:', self.__class__.__name__) + logging.debug('\tfft_length: %s', fft_length) + logging.debug('\thop_length: %s', hop_length) + logging.debug('\tmagnitude_power: %s', magnitude_power) + logging.debug('\tscale: %s', scale) + + @property + def num_subbands(self) -> int: + return self.F + + @property + def input_types(self) -> Dict[str, NeuralType]: + """Returns definitions of module output ports.""" + return { + "input": NeuralType(('B', 'C', 'T'), AudioSignal()), + "input_length": NeuralType(('B',), LengthsType(), optional=True), + } + + @property + def output_types(self) -> Dict[str, NeuralType]: + """Returns definitions of module output ports.""" + return { + "output": NeuralType(('B', 'C', 'D', 'T'), SpectrogramType()), + "output_length": NeuralType(('B',), LengthsType()), + } + + @typecheck() + def forward( + self, input: torch.Tensor, input_length: Optional[torch.Tensor] = None + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Convert a batch of C-channel input signals + into a batch of complex-valued spectrograms. + + Args: + input: Time-domain input signal with C channels, shape (B, C, T) + input_length: Length of valid entries along the time dimension, shape (B,) + + Returns: + Output spectrogram with F subbands and N time frames, shape (B, C, F, N) + and output length with shape (B,). + """ + B, T = input.size(0), input.size(-1) + input = input.view(B, -1, T) + + # STFT output (B, C, F, N) + with torch.cuda.amp.autocast(enabled=False): + output = self.stft(input.float()) + + if self.magnitude_power != 1: + # apply power on the magnitude + output = torch.pow(output.abs(), self.magnitude_power) * torch.exp(1j * output.angle()) + + if self.scale != 1: + # apply scaling of the coefficients + output = self.scale * output + + if input_length is not None: + # Mask padded frames + output_length = self.get_output_length(input_length=input_length) + + length_mask: torch.Tensor = make_seq_mask_like( + lengths=output_length, like=output, time_dim=-1, valid_ones=False + ) + output = output.masked_fill(length_mask, 0.0) + else: + # Assume all frames are valid for all examples in the batch + output_length = output.size(-1) * torch.ones(B, device=output.device).long() + + return output, output_length + + def get_output_length(self, input_length: torch.Tensor) -> torch.Tensor: + """Get length of valid frames for the output. + + Args: + input_length: number of valid samples, shape (B,) + + Returns: + Number of valid frames, shape (B,) + """ + output_length = input_length.div(self.stft.hop_length, rounding_mode='floor').add(1).long() + return output_length + + +class SpectrogramToAudio(NeuralModule): + """Transform a batch of input multi-channel spectrograms into a batch of + time-domain multi-channel signals. + + Args: + fft_length: length of FFT + hop_length: length of hops/shifts of the sliding window + magnitude_power: Transform magnitude of the spectrogram as x^(1/magnitude_power). + scale: Spectrogram will be scaled with 1/scale before the inverse transform. + """ + + def __init__(self, fft_length: int, hop_length: int, magnitude_power: float = 1.0, scale: float = 1.0): + if not HAVE_TORCHAUDIO: + logging.error('Could not import torchaudio. Some features might not work.') + + raise ModuleNotFoundError( + f"torchaudio is not installed but is necessary to instantiate a {self.__class__.__name__}" + ) + + super().__init__() + + # For now, assume FFT length is divisible by two + if fft_length % 2 != 0: + raise ValueError(f'fft_length = {fft_length} must be divisible by 2') + + self.istft = torchaudio.transforms.InverseSpectrogram( + n_fft=fft_length, hop_length=hop_length, pad_mode='constant' + ) + + self.F = fft_length // 2 + 1 + + if magnitude_power <= 0: + raise ValueError(f'Magnitude power needs to be positive: current value {magnitude_power}') + self.magnitude_power = magnitude_power + + if scale <= 0: + raise ValueError(f'Scale needs to be positive: current value {scale}') + self.scale = scale + + logging.debug('Initialized %s with:', self.__class__.__name__) + logging.debug('\tfft_length: %s', fft_length) + logging.debug('\thop_length: %s', hop_length) + logging.debug('\tmagnitude_power: %s', magnitude_power) + logging.debug('\tscale: %s', scale) + + @property + def num_subbands(self) -> int: + return self.F + + @property + def input_types(self) -> Dict[str, NeuralType]: + """Returns definitions of module output ports.""" + return { + "input": NeuralType(('B', 'C', 'D', 'T'), SpectrogramType()), + "input_length": NeuralType(('B',), LengthsType(), optional=True), + } + + @property + def output_types(self) -> Dict[str, NeuralType]: + """Returns definitions of module output ports.""" + return { + "output": NeuralType(('B', 'C', 'T'), AudioSignal()), + "output_length": NeuralType(('B',), LengthsType()), + } + + @typecheck() + def forward(self, input: torch.Tensor, input_length: Optional[torch.Tensor] = None) -> torch.Tensor: + """Convert input complex-valued spectrogram to a time-domain + signal. Multi-channel IO is supported. + + Args: + input: Input spectrogram for C channels, shape (B, C, F, N) + input_length: Length of valid entries along the time dimension, shape (B,) + + Returns: + Time-domain signal with T time-domain samples and C channels, (B, C, T) + and output length with shape (B,). + """ + B, F, N = input.size(0), input.size(-2), input.size(-1) + assert F == self.F, f'Number of subbands F={F} not matching self.F={self.F}' + input = input.view(B, -1, F, N) + + # iSTFT output (B, C, T) + with torch.cuda.amp.autocast(enabled=False): + output = input.cfloat() + + if self.scale != 1: + # apply 1/scale on the coefficients + output = output / self.scale + + if self.magnitude_power != 1: + # apply 1/power on the magnitude + output = torch.pow(output.abs(), 1 / self.magnitude_power) * torch.exp(1j * output.angle()) + output = self.istft(output) + + if input_length is not None: + # Mask padded samples + output_length = self.get_output_length(input_length=input_length) + + length_mask: torch.Tensor = make_seq_mask_like( + lengths=output_length, like=output, time_dim=-1, valid_ones=False + ) + output = output.masked_fill(length_mask, 0.0) + else: + # Assume all frames are valid for all examples in the batch + output_length = output.size(-1) * torch.ones(B, device=output.device).long() + + return output, output_length + + def get_output_length(self, input_length: torch.Tensor) -> torch.Tensor: + """Get length of valid samples for the output. + + Args: + input_length: number of valid frames, shape (B,) + + Returns: + Number of valid samples, shape (B,) + """ + output_length = input_length.sub(1).mul(self.istft.hop_length).long() + return output_length diff --git a/nemo/collections/audio/parts/__init__.py b/nemo/collections/audio/parts/__init__.py new file mode 100644 index 0000000000000..d9155f923f186 --- /dev/null +++ b/nemo/collections/audio/parts/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/nemo/collections/audio/parts/submodules/__init__.py b/nemo/collections/audio/parts/submodules/__init__.py new file mode 100644 index 0000000000000..d9155f923f186 --- /dev/null +++ b/nemo/collections/audio/parts/submodules/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/nemo/collections/asr/parts/submodules/diffusion.py b/nemo/collections/audio/parts/submodules/diffusion.py similarity index 57% rename from nemo/collections/asr/parts/submodules/diffusion.py rename to nemo/collections/audio/parts/submodules/diffusion.py index db3d30f497019..c8b3e803e373a 100644 --- a/nemo/collections/asr/parts/submodules/diffusion.py +++ b/nemo/collections/audio/parts/submodules/diffusion.py @@ -12,33 +12,20 @@ # See the License for the specific language governing permissions and # limitations under the License. -import math from abc import ABC, abstractmethod -from typing import Dict, Optional, Sequence, Tuple, Type +from typing import Optional, Tuple, Type -import einops -import einops.layers.torch import numpy as np import torch -import torch.nn.functional as F -from nemo.collections.common.parts.utils import activation_registry from nemo.collections.tts.parts.utils.helpers import mask_sequence_tensor from nemo.core.classes import NeuralModule, typecheck from nemo.core.neural_types import FloatType, LengthsType, NeuralType, SpectrogramType, VoidType from nemo.utils import logging -__all__ = [ - 'OrnsteinUhlenbeckVarianceExplodingSDE', - 'SpectrogramNoiseConditionalScoreNetworkPlusPlus', - 'NoiseConditionalScoreNetworkPlusPlus', - 'PredictorCorrectorSampler', -] - class StochasticDifferentialEquation(NeuralModule, ABC): - """Base class for stochastic differential equations. - """ + """Base class for stochastic differential equations.""" def __init__(self, time_min: float, time_max: float, num_steps: int): super().__init__() @@ -68,8 +55,7 @@ def dt(self) -> float: @property def time_delta(self) -> float: - """Time range for this SDE. - """ + """Time range for this SDE.""" return self.time_max - self.time_min def generate_time(self, size: int, device: torch.device) -> torch.Tensor: @@ -100,8 +86,12 @@ def coefficients(self, state: torch.Tensor, time: torch.Tensor, **kwargs) -> Tup pass @typecheck( - input_types={"prior_mean": NeuralType(('B', 'C', 'D', 'T'), VoidType()),}, - output_types={"sample": NeuralType(('B', 'C', 'D', 'T'), VoidType()),}, + input_types={ + "prior_mean": NeuralType(('B', 'C', 'D', 'T'), VoidType()), + }, + output_types={ + "sample": NeuralType(('B', 'C', 'D', 'T'), VoidType()), + }, ) @abstractmethod def prior_sampling(self, prior_mean: torch.Tensor) -> torch.Tensor: @@ -156,8 +146,7 @@ def discretize( @abstractmethod def copy(self): - """Create a copy of this SDE. - """ + """Create a copy of this SDE.""" pass def __repr__(self): @@ -235,7 +224,9 @@ def log_std_ratio(self) -> float: "prior_mean": NeuralType(('B', 'C', 'D', 'T'), VoidType()), "time": NeuralType(tuple('B'), FloatType()), }, - output_types={"mean": NeuralType(('B', 'C', 'D', 'T'), FloatType()),}, + output_types={ + "mean": NeuralType(('B', 'C', 'D', 'T'), FloatType()), + }, ) def perturb_kernel_mean(self, state: torch.Tensor, prior_mean: torch.Tensor, time: torch.Tensor) -> torch.Tensor: """Return the mean of the perturbation kernel for this SDE. @@ -260,8 +251,12 @@ def perturb_kernel_mean(self, state: torch.Tensor, prior_mean: torch.Tensor, tim return mean @typecheck( - input_types={"time": NeuralType(tuple('B'), FloatType()),}, - output_types={"std": NeuralType(tuple('B'), FloatType()),}, + input_types={ + "time": NeuralType(tuple('B'), FloatType()), + }, + output_types={ + "std": NeuralType(tuple('B'), FloatType()), + }, ) def perturb_kernel_std(self, time: torch.Tensor) -> torch.Tensor: """Return the standard deviation of the perturbation kernel for this SDE. @@ -275,7 +270,7 @@ def perturb_kernel_std(self, time: torch.Tensor) -> torch.Tensor: Returns: A tensor of shape (B,) """ - var = (self.std_min ** 2) * self.log_std_ratio + var = (self.std_min**2) * self.log_std_ratio var *= torch.pow(self.std_ratio, 2 * time) - torch.exp(-2 * self.stiffness * time) var /= self.stiffness + self.log_std_ratio std = torch.sqrt(var) @@ -429,8 +424,7 @@ def coefficients( raise NotImplementedError('Coefficients not necessary for the reverse SDE.') def prior_sampling(self, shape: torch.Size, device: torch.device) -> torch.Tensor: - """Prior sampling is not necessary for the reverse SDE. - """ + """Prior sampling is not necessary for the reverse SDE.""" raise NotImplementedError('Prior sampling not necessary for the reverse SDE.') def discretize( @@ -482,493 +476,6 @@ def __repr__(self): return desc -class SpectrogramNoiseConditionalScoreNetworkPlusPlus(NeuralModule): - """This model handles complex-valued inputs by stacking real and imaginary components. - Stacked tensor is processed using NCSN++ and the output is projected to generate real - and imaginary components of the output channels. - - Args: - in_channels: number of input complex-valued channels - out_channels: number of output complex-valued channels - """ - - def __init__(self, *, in_channels: int = 1, out_channels: int = 1, **kwargs): - super().__init__() - - # Number of input signals for this estimator - if in_channels < 1: - raise ValueError( - f'Number of input channels needs to be larger or equal to one, current value {in_channels}' - ) - - self.in_channels = in_channels - - # Number of output signals for this estimator - if out_channels < 1: - raise ValueError( - f'Number of output channels needs to be larger or equal to one, current value {out_channels}' - ) - - self.out_channels = out_channels - - # Instantiate noise conditional score network NCSN++ - ncsnpp_params = kwargs.copy() - ncsnpp_params['in_channels'] = ncsnpp_params['out_channels'] = 2 * self.in_channels # stack real and imag - self.ncsnpp = NoiseConditionalScoreNetworkPlusPlus(**ncsnpp_params) - - # Output projection to generate real and imaginary components of the output channels - self.output_projection = torch.nn.Conv2d( - in_channels=2 * self.in_channels, out_channels=2 * self.out_channels, kernel_size=1 - ) - - logging.debug('Initialized %s with', self.__class__.__name__) - logging.debug('\tin_channels: %s', self.in_channels) - logging.debug('\tout_channels: %s', self.out_channels) - - @property - def input_types(self) -> Dict[str, NeuralType]: - """Returns definitions of module output ports. - """ - return { - "input": NeuralType(('B', 'C', 'D', 'T'), SpectrogramType()), - "input_length": NeuralType(('B',), LengthsType(), optional=True), - "condition": NeuralType(('B',), FloatType(), optional=True), - } - - @property - def output_types(self) -> Dict[str, NeuralType]: - """Returns definitions of module output ports. - """ - return { - "output": NeuralType(('B', 'C', 'D', 'T'), SpectrogramType()), - "output_length": NeuralType(('B',), LengthsType(), optional=True), - } - - @typecheck() - def forward(self, input, input_length=None, condition=None): - # Stack real and imaginary components - B, C_in, D, T = input.shape - - if C_in != self.in_channels: - raise RuntimeError(f'Unexpected input channel size {C_in}, expected {self.in_channels}') - - # Stack real and imaginary parts - input_real_imag = torch.stack([input.real, input.imag], dim=2) - input = einops.rearrange(input_real_imag, 'B C RI F T -> B (C RI) F T') - - # Process using NCSN++ - output, output_length = self.ncsnpp(input=input, input_length=input_length, condition=condition) - - # Output projection - output = self.output_projection(output) - - # Convert to complex-valued signal - output = output.reshape(B, 2, self.out_channels, D, T) - # Move real/imag dimension to the end - output = output.permute(0, 2, 3, 4, 1) - output = torch.view_as_complex(output.contiguous()) - - return output, output_length - - -class NoiseConditionalScoreNetworkPlusPlus(NeuralModule): - """Implementation of Noise Conditional Score Network (NCSN++) architecture. - - References: - - Song et al., Score-Based Generative Modeling through Stochastic Differential Equations, NeurIPS 2021 - - Brock et al., Large scale GAN training for high fidelity natural image synthesis, ICLR 2018 - """ - - def __init__( - self, - nonlinearity: str = "swish", - in_channels: int = 2, # number of channels in the input image - out_channels: int = 2, # number of channels in the output image - channels: Sequence[int] = (128, 128, 256, 256, 256), # number of channels at start + at every resolution - num_res_blocks: int = 2, - num_resolutions: int = 4, - init_scale: float = 1e-5, - conditioned_on_time: bool = False, - fourier_embedding_scale: float = 16.0, - dropout_rate: float = 0.0, - pad_time_to: Optional[int] = None, - pad_dimension_to: Optional[int] = None, - **_, - ): - # Network topology is a flavor of UNet, example chart for num_resolutions=4 - # - # 1: Image → Image/2 → Image/4 → Image/8 - # ↓ ↓ ↓ ↓ - # 2: Hidden → Hidden/2 → Hidden/4 → Hidden/8 - # ↓ ↓ ↓ ↓ - # 3: Hidden ← Hidden/2 ← Hidden/4 ← Hidden/8 - # ↓ ↓ ↓ ↓ - # 4: Image ← Image/2 ← Image/4 ← Image/8 - - # Horizontal arrows in (1) are downsampling - # Vertical arrows from (1) to (2) are channel upconversions - # - # Horizontal arrows in (2) are blocks with downsampling where necessary - # Horizontal arrows in (3) are blocks with upsampling where necessary - # - # Vertical arrows from (1) to (2) are downsampling and channel upconversioins - # Vertical arrows from (2) to (3) are sums connections (also with / sqrt(2)) - # Vertical arrows from (3) to (4) are channel downconversions - # Horizontal arrows in (4) are upsampling and addition - super().__init__() - - # same nonlinearity is used throughout the whole network - self.activation: torch.nn.Module = activation_registry[nonlinearity]() - self.init_scale: float = init_scale - - self.downsample = torch.nn.Upsample(scale_factor=0.5, mode="bilinear") - self.upsample = torch.nn.Upsample(scale_factor=2, mode="bilinear") - - self.in_channels = in_channels - self.out_channels = out_channels - self.channels = channels - self.num_res_blocks = num_res_blocks - self.num_resolutions = num_resolutions - self.conditioned_on_time = conditioned_on_time - - # padding setup - self.pad_time_to = pad_time_to or 2 ** self.num_resolutions - self.pad_dimension_to = pad_dimension_to or 2 ** self.num_resolutions - - if self.conditioned_on_time: - self.time_embedding = torch.nn.Sequential( - GaussianFourierProjection(embedding_size=self.channels[0], scale=fourier_embedding_scale), - torch.nn.Linear(self.channels[0] * 2, self.channels[0] * 4), - self.activation, - torch.nn.Linear(self.channels[0] * 4, self.channels[0] * 4), - ) - - self.input_pyramid = torch.nn.ModuleList() - for ch in self.channels[:-1]: - self.input_pyramid.append(torch.nn.Conv2d(in_channels=self.in_channels, out_channels=ch, kernel_size=1)) - - # each block takes an image and outputs an image - # possibly changes number of channels - # output blocks ("reverse" path of the unet) reuse outputs of input blocks ("forward" path) - # so great care must be taken to in/out channels of each block - # resolutions are handled in `forward` - block_params = { - "activation": self.activation, - "dropout_rate": dropout_rate, - "init_scale": self.init_scale, - "diffusion_step_embedding_dim": channels[0] * 4 if self.conditioned_on_time else None, - } - self.input_blocks = torch.nn.ModuleList() - for in_ch, out_ch in zip(self.channels[:-1], self.channels[1:]): - for n in range(num_res_blocks): - block = ResnetBlockBigGANPlusPlus(in_ch=in_ch if n == 0 else out_ch, out_ch=out_ch, **block_params) - self.input_blocks.append(block) - - self.output_blocks = torch.nn.ModuleList() - for in_ch, out_ch in zip(reversed(self.channels[1:]), reversed(self.channels[:-1])): - for n in reversed(range(num_res_blocks)): - block = ResnetBlockBigGANPlusPlus(in_ch=in_ch, out_ch=out_ch if n == 0 else in_ch, **block_params) - self.output_blocks.append(block) - - self.projection_blocks = torch.nn.ModuleList() - for ch in self.channels[:-1]: - self.projection_blocks.append(torch.nn.Conv2d(ch, out_channels, kernel_size=1)) - - assert len(self.input_pyramid) == self.num_resolutions - assert len(self.input_blocks) == self.num_resolutions * self.num_res_blocks - assert len(self.output_blocks) == self.num_resolutions * self.num_res_blocks - assert len(self.projection_blocks) == self.num_resolutions - - self.init_weights_() - - logging.debug('Initialized %s with', self.__class__.__name__) - logging.debug('\tin_channels: %s', self.in_channels) - logging.debug('\tout_channels: %s', self.out_channels) - logging.debug('\tchannels: %s', self.channels) - logging.debug('\tnum_res_blocks: %s', self.num_res_blocks) - logging.debug('\tnum_resolutions: %s', self.num_resolutions) - logging.debug('\tconditioned_on_time: %s', self.conditioned_on_time) - logging.debug('\tpad_time_to: %s', self.pad_time_to) - logging.debug('\tpad_dimension_to: %s', self.pad_dimension_to) - - def init_weights_(self): - for module in self.modules(): - if isinstance(module, (torch.nn.Linear, torch.nn.Conv2d)): - torch.nn.init.xavier_uniform_(module.weight) - if module.bias is not None: - torch.nn.init.zeros_(module.bias) - - # torch.nn submodules with scaled init - for module in self.projection_blocks: - torch.nn.init.xavier_uniform_(module.weight, gain=self.init_scale) - - # non-torch.nn submodules can have their own init schemes - for module in self.modules(): - if module is self: - continue - - if hasattr(module, "init_weights_"): - module.init_weights_() - - @typecheck( - input_types={"input": NeuralType(('B', 'C', 'D', 'T')),}, - output_types={"output": NeuralType(('B', 'C', 'D', 'T')),}, - ) - def pad_input(self, input: torch.Tensor) -> torch.Tensor: - """Pad input tensor to match the required dimensions across `T` and `D`. - """ - *_, D, T = input.shape - output = input - - # padding across time - if T % self.pad_time_to != 0: - output = F.pad(output, (0, self.pad_time_to - T % self.pad_time_to)) - - # padding across dimension - if D % self.pad_dimension_to != 0: - output = F.pad(output, (0, 0, 0, self.pad_dimension_to - D % self.pad_dimension_to)) - - return output - - @property - def input_types(self) -> Dict[str, NeuralType]: - """Returns definitions of module output ports. - """ - return { - "input": NeuralType(('B', 'C', 'D', 'T'), VoidType()), - "input_length": NeuralType(('B',), LengthsType(), optional=True), - "condition": NeuralType(('B',), FloatType(), optional=True), - } - - @property - def output_types(self) -> Dict[str, NeuralType]: - """Returns definitions of module output ports. - """ - return { - "output": NeuralType(('B', 'C', 'D', 'T'), VoidType()), - "output_length": NeuralType(('B',), LengthsType(), optional=True), - } - - @typecheck() - def forward( - self, *, input: torch.Tensor, input_length: Optional[torch.Tensor], condition: Optional[torch.Tensor] = None - ): - """Forward pass of the model. - - Args: - input: input tensor, shjae (B, C, D, T) - input_length: length of the valid time steps for each example in the batch, shape (B,) - condition: scalar condition (time) for the model, will be embedded using `self.time_embedding` - """ - assert input.shape[1] == self.in_channels - - # apply padding at the input - *_, D, T = input.shape - input = self.pad_input(input=input) - - if input_length is None: - # assume all time frames are valid - input_length = torch.LongTensor([input.shape[-1]] * input.shape[0]).to(input.device) - - lengths = input_length - - if condition is not None: - if len(condition.shape) != 1: - raise ValueError( - f"Expected conditon to be a 1-dim tensor, got a {len(condition.shape)}-dim tensor of shape {tuple(condition.shape)}" - ) - if condition.shape[0] != input.shape[0]: - raise ValueError( - f"Condition {tuple(condition.shape)} and input {tuple(input.shape)} should match along the batch dimension" - ) - - condition = self.time_embedding(torch.log(condition)) - - # downsample and project input image to add later in the downsampling path - pyramid = [input] - for resolution_num in range(self.num_resolutions - 1): - pyramid.append(self.downsample(pyramid[-1])) - pyramid = [block(image) for image, block in zip(pyramid, self.input_pyramid)] - - # downsampling path - history = [] - hidden = torch.zeros_like(pyramid[0]) - input_blocks = iter(self.input_blocks) - for resolution_num, image in enumerate(pyramid): - hidden = (hidden + image) / math.sqrt(2.0) - hidden = mask_sequence_tensor(hidden, lengths) - - for _ in range(self.num_res_blocks): - hidden = next(input_blocks)(hidden, condition) - hidden = mask_sequence_tensor(hidden, lengths) - history.append(hidden) - - final_resolution = resolution_num == self.num_resolutions - 1 - if not final_resolution: - hidden = self.downsample(hidden) - lengths = (lengths / 2).ceil().long() - - # upsampling path - to_project = [] - for residual, block in zip(reversed(history), self.output_blocks): - if hidden.shape != residual.shape: - to_project.append(hidden) - hidden = self.upsample(hidden) - lengths = (lengths * 2).long() - - hidden = (hidden + residual) / math.sqrt(2.0) - hidden = block(hidden, condition) - hidden = mask_sequence_tensor(hidden, lengths) - - to_project.append(hidden) - - # projecting to images - images = [] - for tensor, projection in zip(to_project, reversed(self.projection_blocks)): - image = projection(tensor) - images.append(F.interpolate(image, size=input.shape[-2:])) # TODO write this loop using self.upsample - - result = sum(images) - - assert result.shape[-2:] == input.shape[-2:] - - # remove padding - result = result[:, :, :D, :T] - return result, input_length - - -class GaussianFourierProjection(NeuralModule): - """Gaussian Fourier embeddings for input scalars. - - The input scalars are typically time or noise levels. - """ - - def __init__(self, embedding_size: int = 256, scale: float = 1.0): - super().__init__() - self.W = torch.nn.Parameter(torch.randn(embedding_size) * scale, requires_grad=False) - - @property - def input_types(self) -> Dict[str, NeuralType]: - """Returns definitions of module output ports. - """ - return { - "input": NeuralType(('B',), FloatType()), - } - - @property - def output_types(self) -> Dict[str, NeuralType]: - """Returns definitions of module output ports. - """ - return { - "output": NeuralType(('B', 'D'), VoidType()), - } - - def forward(self, input): - x_proj = input[:, None] * self.W[None, :] * 2 * math.pi - return torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1) - - -class ResnetBlockBigGANPlusPlus(torch.nn.Module): - """Implementation of a ResNet block for the BigGAN model. - - References: - - Song et al., Score-Based Generative Modeling through Stochastic Differential Equations, NeurIPS 2021 - - Brock et al., Large scale GAN training for high fidelity natural image synthesis, ICLR 2018 - """ - - def __init__( - self, - activation: torch.nn.Module, - in_ch: int, - out_ch: int, - diffusion_step_embedding_dim: Optional[int] = None, - init_scale: float = 1e-5, - dropout_rate: float = 0.1, - in_num_groups: Optional[int] = None, - out_num_groups: Optional[int] = None, - eps: float = 1e-6, - ): - """ - Args: - activation (torch.nn.Module): activation layer (ReLU, SiLU, etc) - in_ch (int): number of channels in the input image - out_ch (int, optional): number of channels in the output image - diffusion_step_embedding_dim (int, optional): dimension of diffusion timestep embedding. Defaults to None (no embedding). - dropout_rate (float, optional): dropout rate. Defaults to 0.1. - init_scale (float, optional): scaling for weight initialization. Defaults to 0.0. - in_num_groups (int, optional): num_groups in the first GroupNorm. Defaults to min(in_ch // 4, 32) - out_num_groups (int, optional): num_groups in the second GroupNorm. Defaults to min(out_ch // 4, 32) - eps (float, optional): eps parameter of GroupNorms. Defaults to 1e-6. - """ - super().__init__() - in_num_groups = in_num_groups or min(in_ch // 4, 32) - out_num_groups = out_num_groups or min(out_ch // 4, 32) - - self.init_scale = init_scale - - self.input_block = torch.nn.Sequential( - torch.nn.GroupNorm(num_groups=in_num_groups, num_channels=in_ch, eps=eps), activation, - ) - - self.middle_conv = torch.nn.Conv2d(in_channels=in_ch, out_channels=out_ch, kernel_size=3, padding=1) - if diffusion_step_embedding_dim is not None: - self.diffusion_step_projection = torch.nn.Sequential( - activation, - torch.nn.Linear(diffusion_step_embedding_dim, out_ch), - einops.layers.torch.Rearrange("batch dim -> batch dim 1 1"), - ) - - self.output_block = torch.nn.Sequential( - torch.nn.GroupNorm(num_groups=out_num_groups, num_channels=out_ch, eps=eps), - activation, - torch.nn.Dropout(dropout_rate), - torch.nn.Conv2d(in_channels=out_ch, out_channels=out_ch, kernel_size=3, padding=1), - ) - - if in_ch != out_ch: - self.residual_projection = torch.nn.Conv2d(in_channels=in_ch, out_channels=out_ch, kernel_size=1) - - self.act = activation - self.in_ch = in_ch - self.out_ch = out_ch - - self.init_weights_() - - def init_weights_(self): - """Weight initialization - """ - for module in self.modules(): - if isinstance(module, (torch.nn.Conv2d, torch.nn.Linear)): - torch.nn.init.xavier_uniform_(module.weight) - if module.bias is not None: - torch.nn.init.zeros_(module.bias) - - # a single Conv2d is initialized with gain - torch.nn.init.xavier_uniform_(self.output_block[-1].weight, gain=self.init_scale) - - def forward(self, x: torch.Tensor, diffusion_time_embedding: Optional[torch.Tensor] = None): - """Forward pass of the model. - - Args: - x: input tensor - diffusion_time_embedding: embedding of the diffusion time step - - Returns: - Output tensor - """ - h = self.input_block(x) - h = self.middle_conv(h) - - if diffusion_time_embedding is not None: - h = h + self.diffusion_step_projection(diffusion_time_embedding) - - h = self.output_block(h) - - if x.shape != h.shape: # matching number of channels - x = self.residual_projection(x) - return (x + h) / math.sqrt(2.0) - - class PredictorCorrectorSampler(NeuralModule): """Predictor-Corrector sampler for the reverse SDE. @@ -1233,7 +740,9 @@ def __init__( "score_condition": NeuralType(('B', 'C', 'D', 'T'), VoidType(), optional=True), "state_length": NeuralType(tuple('B'), LengthsType(), optional=True), }, - output_types={"state": NeuralType(('B', 'C', 'D', 'T'), VoidType()),}, + output_types={ + "state": NeuralType(('B', 'C', 'D', 'T'), VoidType()), + }, ) @torch.inference_mode() def forward(self, state, time, score_condition=None, state_length=None): diff --git a/nemo/collections/asr/parts/submodules/multichannel_modules.py b/nemo/collections/audio/parts/submodules/multichannel.py similarity index 67% rename from nemo/collections/asr/parts/submodules/multichannel_modules.py rename to nemo/collections/audio/parts/submodules/multichannel.py index 04ab9985d6415..aff0f28cfc3a0 100644 --- a/nemo/collections/asr/parts/submodules/multichannel_modules.py +++ b/nemo/collections/audio/parts/submodules/multichannel.py @@ -13,13 +13,15 @@ # limitations under the License. import random -from typing import Callable, Optional +from typing import Callable, Dict, Optional, Tuple +import numpy as np import torch +from nemo.collections.asr.parts.preprocessing.features import make_seq_mask_like from nemo.collections.asr.parts.submodules.multi_head_attention import MultiHeadAttention from nemo.core.classes import NeuralModule, typecheck -from nemo.core.neural_types import AudioSignal, FloatType, NeuralType, SpectrogramType +from nemo.core.neural_types import AudioSignal, FloatType, LengthsType, NeuralType, SpectrogramType from nemo.utils import logging try: @@ -68,16 +70,14 @@ def __init__( @property def input_types(self): - """Returns definitions of module input types - """ + """Returns definitions of module input types""" return { 'input': NeuralType(('B', 'C', 'T'), AudioSignal()), } @property def output_types(self): - """Returns definitions of module output types - """ + """Returns definitions of module output types""" return { 'output': NeuralType(('B', 'C', 'T'), AudioSignal()), } @@ -86,7 +86,7 @@ def output_types(self): @torch.no_grad() def forward(self, input: torch.Tensor) -> torch.Tensor: # Expecting (B, C, T) - assert input.ndim == 3, f'Expecting input with shape (B, C, T)' + assert input.ndim == 3, 'Expecting input with shape (B, C, T)' num_channels_in = input.size(1) if num_channels_in < self.num_channels_min: @@ -143,16 +143,14 @@ def __init__(self, in_features: int, out_features: Optional[int] = None): @property def input_types(self): - """Returns definitions of module input types - """ + """Returns definitions of module input types""" return { 'input': NeuralType(('B', 'C', 'D', 'T'), SpectrogramType()), } @property def output_types(self): - """Returns definitions of module output types - """ + """Returns definitions of module output types""" return { 'output': NeuralType(('B', 'C', 'D', 'T'), SpectrogramType()), } @@ -231,16 +229,14 @@ def __init__(self, in_features: int, out_features: Optional[int] = None, n_head: @property def input_types(self): - """Returns definitions of module input types - """ + """Returns definitions of module input types""" return { 'input': NeuralType(('B', 'C', 'D', 'T'), SpectrogramType()), } @property def output_types(self): - """Returns definitions of module output types - """ + """Returns definitions of module output types""" return { 'output': NeuralType(('B', 'C', 'D', 'T'), SpectrogramType()), } @@ -281,8 +277,7 @@ def forward(self, input: torch.Tensor) -> torch.Tensor: class ChannelAveragePool(NeuralModule): - """Apply average pooling across channels. - """ + """Apply average pooling across channels.""" def __init__(self): super().__init__() @@ -290,16 +285,14 @@ def __init__(self): @property def input_types(self): - """Returns definitions of module input types - """ + """Returns definitions of module input types""" return { 'input': NeuralType(('B', 'C', 'D', 'T'), SpectrogramType()), } @property def output_types(self): - """Returns definitions of module output types - """ + """Returns definitions of module output types""" return { 'output': NeuralType(('B', 'D', 'T'), SpectrogramType()), } @@ -343,16 +336,14 @@ def __init__(self, in_features: int, n_head: int = 1, dropout_rate: float = 0): @property def input_types(self): - """Returns definitions of module input types - """ + """Returns definitions of module input types""" return { 'input': NeuralType(('B', 'C', 'D', 'T'), SpectrogramType()), } @property def output_types(self): - """Returns definitions of module output types - """ + """Returns definitions of module output types""" return { 'output': NeuralType(('B', 'D', 'T'), SpectrogramType()), } @@ -523,7 +514,7 @@ def apply_filter(self, input: torch.Tensor, filter: torch.Tensor) -> torch.Tenso Args: input: batch with C input channels, shape (B, C, F, T) filter: batch of C-input, M-output filters, shape (B, F, C, M) - + Returns: M-channel filter output, shape (B, M, F, T) """ @@ -551,7 +542,7 @@ def apply_ban(self, input: torch.Tensor, filter: torch.Tensor, psd_n: torch.Tens input: batch with M output channels (B, M, F, T) filter: batch of C-input, M-output filters, shape (B, F, C, M) psd_n: batch of noise PSDs, shape (B, F, C, C) - + Returns: Filtere input, shape (B, M, F, T) @@ -576,8 +567,7 @@ def apply_ban(self, input: torch.Tensor, filter: torch.Tensor, psd_n: torch.Tens @property def input_types(self): - """Returns definitions of module input types - """ + """Returns definitions of module input types""" return { 'input': NeuralType(('B', 'C', 'D', 'T'), SpectrogramType()), 'mask_s': NeuralType(('B', 'D', 'T'), FloatType()), @@ -586,8 +576,7 @@ def input_types(self): @property def output_types(self): - """Returns definitions of module output types - """ + """Returns definitions of module output types""" return { 'output': NeuralType(('B', 'C', 'D', 'T'), SpectrogramType()), } @@ -714,8 +703,7 @@ def __init__( @property def input_types(self): - """Returns definitions of module input types - """ + """Returns definitions of module input types""" return { 'W': NeuralType(('B', 'D', 'C', 'C'), SpectrogramType()), 'psd_s': NeuralType(('B', 'D', 'C', 'C'), SpectrogramType()), @@ -724,8 +712,7 @@ def input_types(self): @property def output_types(self): - """Returns definitions of module output types - """ + """Returns definitions of module output types""" return { 'output': NeuralType(('B', 'C'), FloatType()), } @@ -778,3 +765,291 @@ def forward(self, W: torch.Tensor, psd_s: torch.Tensor, psd_n: torch.Tensor) -> ref = ref_soft return ref + + +class WPEFilter(NeuralModule): + """A weighted prediction error filter. + Given input signal, and expected power of the desired signal, this + class estimates a multiple-input multiple-output prediction filter + and returns the filtered signal. Currently, estimation of statistics + and processing is performed in batch mode. + + Args: + filter_length: Length of the prediction filter in frames, per channel + prediction_delay: Prediction delay in frames + diag_reg: Diagonal regularization for the correlation matrix Q, applied as diag_reg * trace(Q) + eps + eps: Small positive constant for regularization + + References: + - Yoshioka and Nakatani, Generalization of Multi-Channel Linear Prediction + Methods for Blind MIMO Impulse Response Shortening, 2012 + - Jukić et al, Group sparsity for MIMO speech dereverberation, 2015 + """ + + def __init__(self, filter_length: int, prediction_delay: int, diag_reg: Optional[float] = 1e-6, eps: float = 1e-8): + super().__init__() + self.filter_length = filter_length + self.prediction_delay = prediction_delay + self.diag_reg = diag_reg + self.eps = eps + + logging.debug('Initialized %s', self.__class__.__name__) + logging.debug('\tfilter_length: %d', self.filter_length) + logging.debug('\tprediction_delay: %d', self.prediction_delay) + logging.debug('\tdiag_reg: %g', self.diag_reg) + logging.debug('\teps: %g', self.eps) + + @property + def input_types(self) -> Dict[str, NeuralType]: + """Returns definitions of module output ports.""" + return { + "input": NeuralType(('B', 'C', 'D', 'T'), SpectrogramType()), + "power": NeuralType(('B', 'C', 'D', 'T'), SpectrogramType()), + "input_length": NeuralType(('B',), LengthsType(), optional=True), + } + + @property + def output_types(self) -> Dict[str, NeuralType]: + """Returns definitions of module output ports.""" + return { + "output": NeuralType(('B', 'C', 'D', 'T'), SpectrogramType()), + "output_length": NeuralType(('B',), LengthsType(), optional=True), + } + + @typecheck() + def forward( + self, input: torch.Tensor, power: torch.Tensor, input_length: Optional[torch.Tensor] = None + ) -> torch.Tensor: + """Given input and the predicted power for the desired signal, estimate + the WPE filter and return the processed signal. + + Args: + input: Input signal, shape (B, C, F, N) + power: Predicted power of the desired signal, shape (B, C, F, N) + input_length: Optional, length of valid frames in `input`. Defaults to `None` + + Returns: + Tuple of (processed_signal, output_length). Processed signal has the same + shape as the input signal (B, C, F, N), and the output length is the same + as the input length. + """ + # Temporal weighting: average power over channels, output shape (B, F, N) + weight = torch.mean(power, dim=1) + # Use inverse power as the weight + weight = 1 / (weight + self.eps) + + # Multi-channel convolution matrix for each subband + tilde_input = self.convtensor(input, filter_length=self.filter_length, delay=self.prediction_delay) + + # Estimate correlation matrices + Q, R = self.estimate_correlations( + input=input, weight=weight, tilde_input=tilde_input, input_length=input_length + ) + + # Estimate prediction filter + G = self.estimate_filter(Q=Q, R=R) + + # Apply prediction filter + undesired_signal = self.apply_filter(filter=G, tilde_input=tilde_input) + + # Dereverberation + desired_signal = input - undesired_signal + + if input_length is not None: + # Mask padded frames + length_mask: torch.Tensor = make_seq_mask_like( + lengths=input_length, like=desired_signal, time_dim=-1, valid_ones=False + ) + desired_signal = desired_signal.masked_fill(length_mask, 0.0) + + return desired_signal, input_length + + @classmethod + def convtensor( + cls, x: torch.Tensor, filter_length: int, delay: int = 0, n_steps: Optional[int] = None + ) -> torch.Tensor: + """Create a tensor equivalent of convmtx_mc for each example in the batch. + The input signal tensor `x` has shape (B, C, F, N). + Convtensor returns a view of the input signal `x`. + + Note: We avoid reshaping the output to collapse channels and filter taps into + a single dimension, e.g., (B, F, N, -1). In this way, the output is a view of the input, + while an additional reshape would result in a contiguous array and more memory use. + + Args: + x: input tensor, shape (B, C, F, N) + filter_length: length of the filter, determines the shape of the convolution tensor + delay: delay to add to the input signal `x` before constructing the convolution tensor + n_steps: Optional, number of time steps to keep in the out. Defaults to the number of + time steps in the input tensor. + + Returns: + Return a convolutional tensor with shape (B, C, F, n_steps, filter_length) + """ + if x.ndim != 4: + raise RuntimeError(f'Expecting a 4-D input. Received input with shape {x.shape}') + + B, C, F, N = x.shape + + if n_steps is None: + # Keep the same length as the input signal + n_steps = N + + # Pad temporal dimension + x = torch.nn.functional.pad(x, (filter_length - 1 + delay, 0)) + + # Build Toeplitz-like matrix view by unfolding across time + tilde_X = x.unfold(-1, filter_length, 1) + + # Trim to the set number of time steps + tilde_X = tilde_X[:, :, :, :n_steps, :] + + return tilde_X + + @classmethod + def permute_convtensor(cls, x: torch.Tensor) -> torch.Tensor: + """Reshape and permute columns to convert the result of + convtensor to be equal to convmtx_mc. This is used for verification + purposes and it is not required to use the filter. + + Args: + x: output of self.convtensor, shape (B, C, F, N, filter_length) + + Returns: + Output has shape (B, F, N, C*filter_length) that corresponds to + the layout of convmtx_mc. + """ + B, C, F, N, filter_length = x.shape + + # .view will not work, so a copy will have to be created with .reshape + # That will result in more memory use, since we don't use a view of the original + # multi-channel signal + x = x.permute(0, 2, 3, 1, 4) + x = x.reshape(B, F, N, C * filter_length) + + permute = [] + for m in range(C): + permute[m * filter_length : (m + 1) * filter_length] = m * filter_length + np.flip( + np.arange(filter_length) + ) + return x[..., permute] + + def estimate_correlations( + self, + input: torch.Tensor, + weight: torch.Tensor, + tilde_input: torch.Tensor, + input_length: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor]: + """ + Args: + input: Input signal, shape (B, C, F, N) + weight: Time-frequency weight, shape (B, F, N) + tilde_input: Multi-channel convolution tensor, shape (B, C, F, N, filter_length) + input_length: Length of each input example, shape (B) + + Returns: + Returns a tuple of correlation matrices for each batch. + + Let `X` denote the input signal in a single subband, + `tilde{X}` the corresponding multi-channel correlation matrix, + and `w` the vector of weights. + + The first output is + Q = tilde{X}^H * diag(w) * tilde{X} (1) + for each (b, f). + The matrix calculated in (1) has shape (C * filter_length, C * filter_length) + The output is returned in a tensor with shape (B, F, C, filter_length, C, filter_length). + + The second output is + R = tilde{X}^H * diag(w) * X (2) + for each (b, f). + The matrix calculated in (2) has shape (C * filter_length, C) + The output is returned in a tensor with shape (B, F, C, filter_length, C). The last + dimension corresponds to output channels. + """ + if input_length is not None: + # Take only valid samples into account + length_mask: torch.Tensor = make_seq_mask_like( + lengths=input_length, like=weight, time_dim=-1, valid_ones=False + ) + weight = weight.masked_fill(length_mask, 0.0) + + # Calculate (1) + # result: (B, F, C, filter_length, C, filter_length) + Q = torch.einsum('bjfik,bmfin->bfjkmn', tilde_input.conj(), weight[:, None, :, :, None] * tilde_input) + + # Calculate (2) + # result: (B, F, C, filter_length, C) + R = torch.einsum('bjfik,bmfi->bfjkm', tilde_input.conj(), weight[:, None, :, :] * input) + + return Q, R + + def estimate_filter(self, Q: torch.Tensor, R: torch.Tensor) -> torch.Tensor: + """Estimate the MIMO prediction filter as + G(b,f) = Q(b,f) \ R(b,f) + for each subband in each example in the batch (b, f). + + Args: + Q: shape (B, F, C, filter_length, C, filter_length) + R: shape (B, F, C, filter_length, C) + + Returns: + Complex-valued prediction filter, shape (B, C, F, C, filter_length) + """ + B, F, C, filter_length, _, _ = Q.shape + assert ( + filter_length == self.filter_length + ), f'Shape of Q {Q.shape} is not matching filter length {self.filter_length}' + + # Reshape to analytical dimensions for each (b, f) + Q = Q.reshape(B, F, C * self.filter_length, C * filter_length) + R = R.reshape(B, F, C * self.filter_length, C) + + # Diagonal regularization + if self.diag_reg: + # Regularization: diag_reg * trace(Q) + eps + diag_reg = self.diag_reg * torch.diagonal(Q, dim1=-2, dim2=-1).sum(-1).real + self.eps + # Apply regularization on Q + Q = Q + torch.diag_embed(diag_reg.unsqueeze(-1) * torch.ones(Q.shape[-1], device=Q.device)) + + # Solve for the filter + G = torch.linalg.solve(Q, R) + + # Reshape to desired representation: (B, F, input channels, filter_length, output channels) + G = G.reshape(B, F, C, filter_length, C) + # Move output channels to front: (B, output channels, F, input channels, filter_length) + G = G.permute(0, 4, 1, 2, 3) + + return G + + def apply_filter( + self, filter: torch.Tensor, input: Optional[torch.Tensor] = None, tilde_input: Optional[torch.Tensor] = None + ) -> torch.Tensor: + """Apply a prediction filter `filter` on the input `input` as + + output(b,f) = tilde{input(b,f)} * filter(b,f) + + If available, directly use the convolution matrix `tilde_input`. + + Args: + input: Input signal, shape (B, C, F, N) + tilde_input: Convolution matrix for the input signal, shape (B, C, F, N, filter_length) + filter: Prediction filter, shape (B, C, F, C, filter_length) + + Returns: + Multi-channel signal obtained by applying the prediction filter on + the input signal, same shape as input (B, C, F, N) + """ + if input is None and tilde_input is None: + raise RuntimeError('Both inputs cannot be None simultaneously.') + if input is not None and tilde_input is not None: + raise RuntimeError('Both inputs cannot be provided simultaneously.') + + if tilde_input is None: + tilde_input = self.convtensor(input, filter_length=self.filter_length, delay=self.prediction_delay) + + # For each (batch, output channel, f, time step), sum across (input channel, filter tap) + output = torch.einsum('bjfik,bmfjk->bmfi', tilde_input, filter) + + return output diff --git a/nemo/collections/audio/parts/submodules/ncsnpp.py b/nemo/collections/audio/parts/submodules/ncsnpp.py new file mode 100644 index 0000000000000..adbeccc0dc023 --- /dev/null +++ b/nemo/collections/audio/parts/submodules/ncsnpp.py @@ -0,0 +1,511 @@ +# Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from typing import Dict, Optional, Sequence + +import einops +import einops.layers.torch +import torch +import torch.nn.functional as F + +from nemo.collections.common.parts.utils import activation_registry +from nemo.collections.tts.parts.utils.helpers import mask_sequence_tensor +from nemo.core.classes import NeuralModule, typecheck +from nemo.core.neural_types import FloatType, LengthsType, NeuralType, SpectrogramType, VoidType +from nemo.utils import logging + + +class SpectrogramNoiseConditionalScoreNetworkPlusPlus(NeuralModule): + """This model handles complex-valued inputs by stacking real and imaginary components. + Stacked tensor is processed using NCSN++ and the output is projected to generate real + and imaginary components of the output channels. + + Args: + in_channels: number of input complex-valued channels + out_channels: number of output complex-valued channels + """ + + def __init__(self, *, in_channels: int = 1, out_channels: int = 1, **kwargs): + super().__init__() + + # Number of input signals for this estimator + if in_channels < 1: + raise ValueError( + f'Number of input channels needs to be larger or equal to one, current value {in_channels}' + ) + + self.in_channels = in_channels + + # Number of output signals for this estimator + if out_channels < 1: + raise ValueError( + f'Number of output channels needs to be larger or equal to one, current value {out_channels}' + ) + + self.out_channels = out_channels + + # Instantiate noise conditional score network NCSN++ + ncsnpp_params = kwargs.copy() + ncsnpp_params['in_channels'] = ncsnpp_params['out_channels'] = 2 * self.in_channels # stack real and imag + self.ncsnpp = NoiseConditionalScoreNetworkPlusPlus(**ncsnpp_params) + + # Output projection to generate real and imaginary components of the output channels + self.output_projection = torch.nn.Conv2d( + in_channels=2 * self.in_channels, out_channels=2 * self.out_channels, kernel_size=1 + ) + + logging.debug('Initialized %s with', self.__class__.__name__) + logging.debug('\tin_channels: %s', self.in_channels) + logging.debug('\tout_channels: %s', self.out_channels) + + @property + def input_types(self) -> Dict[str, NeuralType]: + """Returns definitions of module output ports.""" + return { + "input": NeuralType(('B', 'C', 'D', 'T'), SpectrogramType()), + "input_length": NeuralType(('B',), LengthsType(), optional=True), + "condition": NeuralType(('B',), FloatType(), optional=True), + } + + @property + def output_types(self) -> Dict[str, NeuralType]: + """Returns definitions of module output ports.""" + return { + "output": NeuralType(('B', 'C', 'D', 'T'), SpectrogramType()), + "output_length": NeuralType(('B',), LengthsType(), optional=True), + } + + @typecheck() + def forward(self, input, input_length=None, condition=None): + # Stack real and imaginary components + B, C_in, D, T = input.shape + + if C_in != self.in_channels: + raise RuntimeError(f'Unexpected input channel size {C_in}, expected {self.in_channels}') + + # Stack real and imaginary parts + input_real_imag = torch.stack([input.real, input.imag], dim=2) + input = einops.rearrange(input_real_imag, 'B C RI F T -> B (C RI) F T') + + # Process using NCSN++ + output, output_length = self.ncsnpp(input=input, input_length=input_length, condition=condition) + + # Output projection + output = self.output_projection(output) + + # Convert to complex-valued signal + output = output.reshape(B, 2, self.out_channels, D, T) + # Move real/imag dimension to the end + output = output.permute(0, 2, 3, 4, 1) + output = torch.view_as_complex(output.contiguous()) + + return output, output_length + + +class NoiseConditionalScoreNetworkPlusPlus(NeuralModule): + """Implementation of Noise Conditional Score Network (NCSN++) architecture. + + References: + - Song et al., Score-Based Generative Modeling through Stochastic Differential Equations, NeurIPS 2021 + - Brock et al., Large scale GAN training for high fidelity natural image synthesis, ICLR 2018 + """ + + def __init__( + self, + nonlinearity: str = "swish", + in_channels: int = 2, # number of channels in the input image + out_channels: int = 2, # number of channels in the output image + channels: Sequence[int] = (128, 128, 256, 256, 256), # number of channels at start + at every resolution + num_res_blocks: int = 2, + num_resolutions: int = 4, + init_scale: float = 1e-5, + conditioned_on_time: bool = False, + fourier_embedding_scale: float = 16.0, + dropout_rate: float = 0.0, + pad_time_to: Optional[int] = None, + pad_dimension_to: Optional[int] = None, + **_, + ): + # Network topology is a flavor of UNet, example chart for num_resolutions=4 + # + # 1: Image → Image/2 → Image/4 → Image/8 + # ↓ ↓ ↓ ↓ + # 2: Hidden → Hidden/2 → Hidden/4 → Hidden/8 + # ↓ ↓ ↓ ↓ + # 3: Hidden ← Hidden/2 ← Hidden/4 ← Hidden/8 + # ↓ ↓ ↓ ↓ + # 4: Image ← Image/2 ← Image/4 ← Image/8 + + # Horizontal arrows in (1) are downsampling + # Vertical arrows from (1) to (2) are channel upconversions + # + # Horizontal arrows in (2) are blocks with downsampling where necessary + # Horizontal arrows in (3) are blocks with upsampling where necessary + # + # Vertical arrows from (1) to (2) are downsampling and channel upconversioins + # Vertical arrows from (2) to (3) are sums connections (also with / sqrt(2)) + # Vertical arrows from (3) to (4) are channel downconversions + # Horizontal arrows in (4) are upsampling and addition + super().__init__() + + # same nonlinearity is used throughout the whole network + self.activation: torch.nn.Module = activation_registry[nonlinearity]() + self.init_scale: float = init_scale + + self.downsample = torch.nn.Upsample(scale_factor=0.5, mode="bilinear") + self.upsample = torch.nn.Upsample(scale_factor=2, mode="bilinear") + + self.in_channels = in_channels + self.out_channels = out_channels + self.channels = channels + self.num_res_blocks = num_res_blocks + self.num_resolutions = num_resolutions + self.conditioned_on_time = conditioned_on_time + + # padding setup + self.pad_time_to = pad_time_to or 2**self.num_resolutions + self.pad_dimension_to = pad_dimension_to or 2**self.num_resolutions + + if self.conditioned_on_time: + self.time_embedding = torch.nn.Sequential( + GaussianFourierProjection(embedding_size=self.channels[0], scale=fourier_embedding_scale), + torch.nn.Linear(self.channels[0] * 2, self.channels[0] * 4), + self.activation, + torch.nn.Linear(self.channels[0] * 4, self.channels[0] * 4), + ) + + self.input_pyramid = torch.nn.ModuleList() + for ch in self.channels[:-1]: + self.input_pyramid.append(torch.nn.Conv2d(in_channels=self.in_channels, out_channels=ch, kernel_size=1)) + + # each block takes an image and outputs an image + # possibly changes number of channels + # output blocks ("reverse" path of the unet) reuse outputs of input blocks ("forward" path) + # so great care must be taken to in/out channels of each block + # resolutions are handled in `forward` + block_params = { + "activation": self.activation, + "dropout_rate": dropout_rate, + "init_scale": self.init_scale, + "diffusion_step_embedding_dim": channels[0] * 4 if self.conditioned_on_time else None, + } + self.input_blocks = torch.nn.ModuleList() + for in_ch, out_ch in zip(self.channels[:-1], self.channels[1:]): + for n in range(num_res_blocks): + block = ResnetBlockBigGANPlusPlus(in_ch=in_ch if n == 0 else out_ch, out_ch=out_ch, **block_params) + self.input_blocks.append(block) + + self.output_blocks = torch.nn.ModuleList() + for in_ch, out_ch in zip(reversed(self.channels[1:]), reversed(self.channels[:-1])): + for n in reversed(range(num_res_blocks)): + block = ResnetBlockBigGANPlusPlus(in_ch=in_ch, out_ch=out_ch if n == 0 else in_ch, **block_params) + self.output_blocks.append(block) + + self.projection_blocks = torch.nn.ModuleList() + for ch in self.channels[:-1]: + self.projection_blocks.append(torch.nn.Conv2d(ch, out_channels, kernel_size=1)) + + assert len(self.input_pyramid) == self.num_resolutions + assert len(self.input_blocks) == self.num_resolutions * self.num_res_blocks + assert len(self.output_blocks) == self.num_resolutions * self.num_res_blocks + assert len(self.projection_blocks) == self.num_resolutions + + self.init_weights_() + + logging.debug('Initialized %s with', self.__class__.__name__) + logging.debug('\tin_channels: %s', self.in_channels) + logging.debug('\tout_channels: %s', self.out_channels) + logging.debug('\tchannels: %s', self.channels) + logging.debug('\tnum_res_blocks: %s', self.num_res_blocks) + logging.debug('\tnum_resolutions: %s', self.num_resolutions) + logging.debug('\tconditioned_on_time: %s', self.conditioned_on_time) + logging.debug('\tpad_time_to: %s', self.pad_time_to) + logging.debug('\tpad_dimension_to: %s', self.pad_dimension_to) + + def init_weights_(self): + for module in self.modules(): + if isinstance(module, (torch.nn.Linear, torch.nn.Conv2d)): + torch.nn.init.xavier_uniform_(module.weight) + if module.bias is not None: + torch.nn.init.zeros_(module.bias) + + # torch.nn submodules with scaled init + for module in self.projection_blocks: + torch.nn.init.xavier_uniform_(module.weight, gain=self.init_scale) + + # non-torch.nn submodules can have their own init schemes + for module in self.modules(): + if module is self: + continue + + if hasattr(module, "init_weights_"): + module.init_weights_() + + @typecheck( + input_types={ + "input": NeuralType(('B', 'C', 'D', 'T')), + }, + output_types={ + "output": NeuralType(('B', 'C', 'D', 'T')), + }, + ) + def pad_input(self, input: torch.Tensor) -> torch.Tensor: + """Pad input tensor to match the required dimensions across `T` and `D`.""" + *_, D, T = input.shape + output = input + + # padding across time + if T % self.pad_time_to != 0: + output = F.pad(output, (0, self.pad_time_to - T % self.pad_time_to)) + + # padding across dimension + if D % self.pad_dimension_to != 0: + output = F.pad(output, (0, 0, 0, self.pad_dimension_to - D % self.pad_dimension_to)) + + return output + + @property + def input_types(self) -> Dict[str, NeuralType]: + """Returns definitions of module output ports.""" + return { + "input": NeuralType(('B', 'C', 'D', 'T'), VoidType()), + "input_length": NeuralType(('B',), LengthsType(), optional=True), + "condition": NeuralType(('B',), FloatType(), optional=True), + } + + @property + def output_types(self) -> Dict[str, NeuralType]: + """Returns definitions of module output ports.""" + return { + "output": NeuralType(('B', 'C', 'D', 'T'), VoidType()), + "output_length": NeuralType(('B',), LengthsType(), optional=True), + } + + @typecheck() + def forward( + self, *, input: torch.Tensor, input_length: Optional[torch.Tensor], condition: Optional[torch.Tensor] = None + ): + """Forward pass of the model. + + Args: + input: input tensor, shjae (B, C, D, T) + input_length: length of the valid time steps for each example in the batch, shape (B,) + condition: scalar condition (time) for the model, will be embedded using `self.time_embedding` + """ + assert input.shape[1] == self.in_channels + + # apply padding at the input + *_, D, T = input.shape + input = self.pad_input(input=input) + + if input_length is None: + # assume all time frames are valid + input_length = torch.LongTensor([input.shape[-1]] * input.shape[0]).to(input.device) + + lengths = input_length + + if condition is not None: + if len(condition.shape) != 1: + raise ValueError( + f"Expected conditon to be a 1-dim tensor, got a {len(condition.shape)}-dim tensor of shape {tuple(condition.shape)}" + ) + if condition.shape[0] != input.shape[0]: + raise ValueError( + f"Condition {tuple(condition.shape)} and input {tuple(input.shape)} should match along the batch dimension" + ) + + condition = self.time_embedding(torch.log(condition)) + + # downsample and project input image to add later in the downsampling path + pyramid = [input] + for resolution_num in range(self.num_resolutions - 1): + pyramid.append(self.downsample(pyramid[-1])) + pyramid = [block(image) for image, block in zip(pyramid, self.input_pyramid)] + + # downsampling path + history = [] + hidden = torch.zeros_like(pyramid[0]) + input_blocks = iter(self.input_blocks) + for resolution_num, image in enumerate(pyramid): + hidden = (hidden + image) / math.sqrt(2.0) + hidden = mask_sequence_tensor(hidden, lengths) + + for _ in range(self.num_res_blocks): + hidden = next(input_blocks)(hidden, condition) + hidden = mask_sequence_tensor(hidden, lengths) + history.append(hidden) + + final_resolution = resolution_num == self.num_resolutions - 1 + if not final_resolution: + hidden = self.downsample(hidden) + lengths = (lengths / 2).ceil().long() + + # upsampling path + to_project = [] + for residual, block in zip(reversed(history), self.output_blocks): + if hidden.shape != residual.shape: + to_project.append(hidden) + hidden = self.upsample(hidden) + lengths = (lengths * 2).long() + + hidden = (hidden + residual) / math.sqrt(2.0) + hidden = block(hidden, condition) + hidden = mask_sequence_tensor(hidden, lengths) + + to_project.append(hidden) + + # projecting to images + images = [] + for tensor, projection in zip(to_project, reversed(self.projection_blocks)): + image = projection(tensor) + images.append(F.interpolate(image, size=input.shape[-2:])) # TODO write this loop using self.upsample + + result = sum(images) + + assert result.shape[-2:] == input.shape[-2:] + + # remove padding + result = result[:, :, :D, :T] + return result, input_length + + +class GaussianFourierProjection(NeuralModule): + """Gaussian Fourier embeddings for input scalars. + + The input scalars are typically time or noise levels. + """ + + def __init__(self, embedding_size: int = 256, scale: float = 1.0): + super().__init__() + self.W = torch.nn.Parameter(torch.randn(embedding_size) * scale, requires_grad=False) + + @property + def input_types(self) -> Dict[str, NeuralType]: + """Returns definitions of module output ports.""" + return { + "input": NeuralType(('B',), FloatType()), + } + + @property + def output_types(self) -> Dict[str, NeuralType]: + """Returns definitions of module output ports.""" + return { + "output": NeuralType(('B', 'D'), VoidType()), + } + + def forward(self, input): + x_proj = input[:, None] * self.W[None, :] * 2 * math.pi + return torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1) + + +class ResnetBlockBigGANPlusPlus(torch.nn.Module): + """Implementation of a ResNet block for the BigGAN model. + + References: + - Song et al., Score-Based Generative Modeling through Stochastic Differential Equations, NeurIPS 2021 + - Brock et al., Large scale GAN training for high fidelity natural image synthesis, ICLR 2018 + """ + + def __init__( + self, + activation: torch.nn.Module, + in_ch: int, + out_ch: int, + diffusion_step_embedding_dim: Optional[int] = None, + init_scale: float = 1e-5, + dropout_rate: float = 0.1, + in_num_groups: Optional[int] = None, + out_num_groups: Optional[int] = None, + eps: float = 1e-6, + ): + """ + Args: + activation (torch.nn.Module): activation layer (ReLU, SiLU, etc) + in_ch (int): number of channels in the input image + out_ch (int, optional): number of channels in the output image + diffusion_step_embedding_dim (int, optional): dimension of diffusion timestep embedding. Defaults to None (no embedding). + dropout_rate (float, optional): dropout rate. Defaults to 0.1. + init_scale (float, optional): scaling for weight initialization. Defaults to 0.0. + in_num_groups (int, optional): num_groups in the first GroupNorm. Defaults to min(in_ch // 4, 32) + out_num_groups (int, optional): num_groups in the second GroupNorm. Defaults to min(out_ch // 4, 32) + eps (float, optional): eps parameter of GroupNorms. Defaults to 1e-6. + """ + super().__init__() + in_num_groups = in_num_groups or min(in_ch // 4, 32) + out_num_groups = out_num_groups or min(out_ch // 4, 32) + + self.init_scale = init_scale + + self.input_block = torch.nn.Sequential( + torch.nn.GroupNorm(num_groups=in_num_groups, num_channels=in_ch, eps=eps), + activation, + ) + + self.middle_conv = torch.nn.Conv2d(in_channels=in_ch, out_channels=out_ch, kernel_size=3, padding=1) + if diffusion_step_embedding_dim is not None: + self.diffusion_step_projection = torch.nn.Sequential( + activation, + torch.nn.Linear(diffusion_step_embedding_dim, out_ch), + einops.layers.torch.Rearrange("batch dim -> batch dim 1 1"), + ) + + self.output_block = torch.nn.Sequential( + torch.nn.GroupNorm(num_groups=out_num_groups, num_channels=out_ch, eps=eps), + activation, + torch.nn.Dropout(dropout_rate), + torch.nn.Conv2d(in_channels=out_ch, out_channels=out_ch, kernel_size=3, padding=1), + ) + + if in_ch != out_ch: + self.residual_projection = torch.nn.Conv2d(in_channels=in_ch, out_channels=out_ch, kernel_size=1) + + self.act = activation + self.in_ch = in_ch + self.out_ch = out_ch + + self.init_weights_() + + def init_weights_(self): + """Weight initialization""" + for module in self.modules(): + if isinstance(module, (torch.nn.Conv2d, torch.nn.Linear)): + torch.nn.init.xavier_uniform_(module.weight) + if module.bias is not None: + torch.nn.init.zeros_(module.bias) + + # a single Conv2d is initialized with gain + torch.nn.init.xavier_uniform_(self.output_block[-1].weight, gain=self.init_scale) + + def forward(self, x: torch.Tensor, diffusion_time_embedding: Optional[torch.Tensor] = None): + """Forward pass of the model. + + Args: + x: input tensor + diffusion_time_embedding: embedding of the diffusion time step + + Returns: + Output tensor + """ + h = self.input_block(x) + h = self.middle_conv(h) + + if diffusion_time_embedding is not None: + h = h + self.diffusion_step_projection(diffusion_time_embedding) + + h = self.output_block(h) + + if x.shape != h.shape: # matching number of channels + x = self.residual_projection(x) + return (x + h) / math.sqrt(2.0) diff --git a/nemo/collections/audio/parts/utils/__init__.py b/nemo/collections/audio/parts/utils/__init__.py new file mode 100644 index 0000000000000..d9155f923f186 --- /dev/null +++ b/nemo/collections/audio/parts/utils/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/nemo/collections/asr/parts/utils/audio_utils.py b/nemo/collections/audio/parts/utils/audio.py similarity index 81% rename from nemo/collections/asr/parts/utils/audio_utils.py rename to nemo/collections/audio/parts/utils/audio.py index 8188dbed003b5..25ab66468c825 100644 --- a/nemo/collections/asr/parts/utils/audio_utils.py +++ b/nemo/collections/audio/parts/utils/audio.py @@ -13,7 +13,7 @@ # limitations under the License. import math -from typing import Iterable, Optional, Union +from typing import Optional import librosa import numpy as np @@ -23,103 +23,18 @@ import torch from scipy.spatial.distance import pdist, squareform -from nemo.utils import logging SOUND_VELOCITY = 343.0 # m/s -ChannelSelectorType = Union[int, Iterable[int], str] - - -def get_samples(audio_file: str, target_sr: int = 16000, dtype: str = 'float32'): - """ - Read the samples from the given audio_file path. If not specified, the input audio file is automatically - resampled to 16kHz. - - Args: - audio_file (str): - Path to the input audio file - target_sr (int): - Targeted sampling rate - Returns: - samples (numpy.ndarray): - Time-series sample data from the given audio file - """ - with sf.SoundFile(audio_file, 'r') as f: - samples = f.read(dtype=dtype) - if f.samplerate != target_sr: - samples = librosa.core.resample(samples, orig_sr=f.samplerate, target_sr=target_sr) - samples = samples.transpose() - return samples - - -def select_channels(signal: npt.NDArray, channel_selector: Optional[ChannelSelectorType] = None) -> npt.NDArray: - """ - Convert a multi-channel signal to a single-channel signal by averaging over channels or selecting a single channel, - or pass-through multi-channel signal when channel_selector is `None`. - - Args: - signal: numpy array with shape (..., num_channels) - channel selector: string denoting the downmix mode, an integer denoting the channel to be selected, or an iterable - of integers denoting a subset of channels. Channel selector is using zero-based indexing. - If set to `None`, the original signal will be returned. Uses zero-based indexing. - - Returns: - numpy array - """ - if signal.ndim == 1: - # For one-dimensional input, return the input signal. - if channel_selector not in [None, 0, 'average']: - raise ValueError( - 'Input signal is one-dimensional, channel selector (%s) cannot not be used.', str(channel_selector) - ) - return signal - - num_channels = signal.shape[-1] - num_samples = signal.size // num_channels # handle multi-dimensional signals - - if num_channels >= num_samples: - logging.warning( - 'Number of channels (%d) is greater or equal than number of samples (%d). Check for possible transposition.', - num_channels, - num_samples, - ) - - # Samples are arranged as (num_channels, ...) - if channel_selector is None: - # keep the original multi-channel signal - pass - elif channel_selector == 'average': - # default behavior: downmix by averaging across channels - signal = np.mean(signal, axis=-1) - elif isinstance(channel_selector, int): - # select a single channel - if channel_selector >= num_channels: - raise ValueError(f'Cannot select channel {channel_selector} from a signal with {num_channels} channels.') - signal = signal[..., channel_selector] - elif isinstance(channel_selector, Iterable): - # select multiple channels - if max(channel_selector) >= num_channels: - raise ValueError( - f'Cannot select channel subset {channel_selector} from a signal with {num_channels} channels.' - ) - signal = signal[..., channel_selector] - # squeeze the channel dimension if a single-channel is selected - # this is done to have the same shape as when using integer indexing - if len(channel_selector) == 1: - signal = np.squeeze(signal, axis=-1) - else: - raise ValueError(f'Unexpected value for channel_selector ({channel_selector})') - - return signal def sinc_unnormalized(x: float) -> float: """Unnormalized sinc. - + Args: x: input value - + Returns: - Calculates sin(x)/x + Calculates sin(x)/x """ return np.sinc(x / np.pi) @@ -132,14 +47,14 @@ def theoretical_coherence( sound_velocity: float = SOUND_VELOCITY, ) -> npt.NDArray: """Calculate a theoretical coherence matrix for given mic positions and field type. - + Args: mic_positions: 3D Cartesian coordinates of microphone positions, shape (num_mics, 3) field: string denoting the type of the soundfield sample_rate: sampling rate of the input signal in Hz fft_length: length of the fft in samples sound_velocity: speed of sound in m/s - + Returns: Calculated coherence with shape (num_subbands, num_mics, num_mics) """ @@ -171,11 +86,11 @@ def theoretical_coherence( def estimated_coherence(S: npt.NDArray, eps: float = 1e-16) -> npt.NDArray: """Estimate complex-valued coherence for the input STFT-domain signal. - + Args: S: STFT of the signal with shape (num_subbands, num_frames, num_channels) eps: small regularization constant - + Returns: Estimated coherence with shape (num_subbands, num_channels, num_channels) """ @@ -220,10 +135,10 @@ def generate_approximate_noise_field( fft_length: length of the fft in samples method: coherence decomposition method sound_velocity: speed of sound in m/s - + Returns: Signal with coherence approximately matching the desired coherence, shape (num_samples, num_channels) - + References: E.A.P. Habets, I. Cohen and S. Gannot, 'Generating nonstationary multisensor signals under a spatial coherence constraint', Journal of the Acoustical Society @@ -254,16 +169,16 @@ def transform_to_match_coherence( corrcoef_threshold: float = 0.2, ) -> npt.NDArray: """Transform the input multichannel signal to match the desired coherence. - + Note: It's assumed that channels are independent. - + Args: signal: independent noise signals with shape (num_samples, num_channels) desired_coherence: desired coherence with shape (num_subbands, num_channels, num_channels) method: decomposition method used to construct the transformation matrix ref_channel: reference channel for power normalization of the input signal corrcoef_threshold: used to detect input signals with high correlation between channels - + Returns: Signal with coherence approximately matching the desired coherence, shape (num_samples, num_channels) @@ -358,7 +273,7 @@ def mag2db(mag: float, eps: Optional[float] = 1e-16) -> float: def db2mag(db: float) -> float: """Convert value in dB to linear magnitude ratio. - + Args: db: magnitude ratio in dB @@ -374,7 +289,7 @@ def pow2db(power: float, eps: Optional[float] = 1e-16) -> float: Args: power: power ratio in linear scale eps: small regularization constant - + Returns: Power in dB. """ @@ -521,7 +436,7 @@ def convmtx_mc_numpy(x: np.ndarray, filter_length: int, delay: int = 0, n_steps: def scale_invariant_target_numpy(estimate: np.ndarray, target: np.ndarray, eps: float = 1e-8) -> np.ndarray: """Calculate convolution-invariant target for a given estimated signal. - + Calculate scaled target obtained by solving min_scale || scale * target - estimate ||^2 @@ -534,7 +449,7 @@ def scale_invariant_target_numpy(estimate: np.ndarray, target: np.ndarray, eps: Returns: Scaled target signal, shape (T,) """ - assert target.ndim == estimate.ndim == 1, f'Only one-dimensional inputs supported' + assert target.ndim == estimate.ndim == 1, 'Only one-dimensional inputs supported' estimate_dot_target = np.mean(estimate * target) target_pow = np.mean(np.abs(target) ** 2) @@ -546,7 +461,7 @@ def convolution_invariant_target_numpy( estimate: np.ndarray, target: np.ndarray, filter_length, diag_reg: float = 1e-6, eps: float = 1e-8 ) -> np.ndarray: """Calculate convolution-invariant target for a given estimated signal. - + Calculate target filtered with a linear f obtained by solving min_filter || conv(filter, target) - estimate ||^2 @@ -558,7 +473,7 @@ def convolution_invariant_target_numpy( diag_reg: multiplicative factor for relative diagonal loading eps: absolute diagonal loading """ - assert target.ndim == estimate.ndim == 1, f'Only one-dimensional inputs supported' + assert target.ndim == estimate.ndim == 1, 'Only one-dimensional inputs supported' n_fft = 2 ** math.ceil(math.log2(len(target) + len(estimate) - 1)) diff --git a/nemo/collections/common/data/lhotse/dataloader.py b/nemo/collections/common/data/lhotse/dataloader.py index 01bf51b0e2c63..5533b50922f8f 100644 --- a/nemo/collections/common/data/lhotse/dataloader.py +++ b/nemo/collections/common/data/lhotse/dataloader.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import os +import random import warnings from dataclasses import dataclass from functools import partial @@ -319,6 +320,7 @@ def get_lhotse_dataloader_from_config( ReverbWithImpulseResponse( rir_recordings=RecordingSet.from_file(config.rir_path) if config.rir_path is not None else None, p=config.rir_prob, + randgen=random.Random(seed), ) ) diff --git a/nemo/collections/common/parts/preprocessing/collections.py b/nemo/collections/common/parts/preprocessing/collections.py index 24ca6cffe4589..0cb81c115d059 100644 --- a/nemo/collections/common/parts/preprocessing/collections.py +++ b/nemo/collections/common/parts/preprocessing/collections.py @@ -702,18 +702,23 @@ def __init__( output_type = self.OUTPUT_TYPE data, duration_filtered = [], 0.0 total_duration = 0.0 + duration_undefined = True + for audio_file, duration, command, offset in zip(audio_files, durations, labels, offsets): # Duration filters. - if min_duration is not None and duration < min_duration: + if duration is not None and min_duration is not None and duration < min_duration: duration_filtered += duration continue - if max_duration is not None and duration > max_duration: + if duration is not None and max_duration is not None and duration > max_duration: duration_filtered += duration continue data.append(output_type(audio_file, duration, command, offset)) - total_duration += duration + + if duration is not None: + total_duration += duration + duration_undefined = False if index_by_file_id: file_id, _ = os.path.splitext(os.path.basename(audio_file)) @@ -729,8 +734,14 @@ def __init__( else: data.sort(key=lambda entity: entity.duration) - logging.info(f"Filtered duration for loading collection is {duration_filtered / 3600: .2f} hours.") - logging.info(f"Dataset loaded with {len(data)} items, total duration of {total_duration / 3600: .2f} hours.") + if duration_undefined: + logging.info(f"Dataset loaded with {len(data)} items. The durations were not provided.") + else: + logging.info(f"Filtered duration for loading collection is {duration_filtered / 3600: .2f} hours.") + logging.info( + f"Dataset successfully loaded with {len(data)} items and total duration provided from manifest is {total_duration / 3600: .2f} hours." + ) + self.uniq_labels = sorted(set(map(lambda x: x.label, data))) logging.info("# {} files loaded accounting to # {} labels".format(len(data), len(self.uniq_labels))) diff --git a/nemo/collections/common/prompts/canary.py b/nemo/collections/common/prompts/canary.py index aadc976ba4742..e511368a1edfb 100644 --- a/nemo/collections/common/prompts/canary.py +++ b/nemo/collections/common/prompts/canary.py @@ -16,9 +16,9 @@ class CanaryPromptFormatter(PromptFormatter): "template": f"{CANARY_BOS}|source_lang||task||target_lang||pnc|", "slots": { "source_lang": Modality.Text, - "task": Modality.Text, + "task": Modality.TextLiteral("asr", "ast", "s2t_translation", "<|transcribe|>", "<|translate|>"), "target_lang": Modality.Text, - "pnc": Modality.Text, + "pnc": Modality.TextLiteral("yes", "no", "<|pnc|>", "<|nopnc|>"), }, }, OUTPUT_ROLE: { diff --git a/nemo/collections/common/prompts/formatter.py b/nemo/collections/common/prompts/formatter.py index 524b2e62c5a37..8a82563ebbaa6 100644 --- a/nemo/collections/common/prompts/formatter.py +++ b/nemo/collections/common/prompts/formatter.py @@ -20,22 +20,38 @@ EOS_SLOT = "|eos|" -class Modality(Enum): +class BaseModalityType: + @staticmethod + def matches(value: Any) -> bool: + raise NotImplementedError + + +class Text(BaseModalityType): + """Modality for text values.""" + + @staticmethod + def matches(value: str) -> bool: + return isinstance(value, str) + + +class TextLiteral(BaseModalityType): + def __init__(self, *items): + self.allowed_values = items + + def matches(self, value: str) -> bool: + return isinstance(value, str) and value in self.allowed_values + + def __repr__(self): + return f"{self.__class__.__name__}({self.allowed_values})" + + +class Modality: """ Modalities supported as PromptFormatter slot values. """ - Text = "text" - - def matches(self, value: Any) -> bool: - """ - Checks if the provided value is compatible with an instance of Modality. - """ - match self: - case Modality.Text: - return isinstance(value, str) - case _: - return False + Text = Text + TextLiteral = TextLiteral class PromptFormatter(ABC): diff --git a/nemo/collections/common/tokenizers/__init__.py b/nemo/collections/common/tokenizers/__init__.py index 750398670d0c2..6a71920bf6d41 100644 --- a/nemo/collections/common/tokenizers/__init__.py +++ b/nemo/collections/common/tokenizers/__init__.py @@ -21,3 +21,16 @@ from nemo.collections.common.tokenizers.sentencepiece_tokenizer import SentencePieceTokenizer from nemo.collections.common.tokenizers.tokenizer_spec import TokenizerSpec from nemo.collections.common.tokenizers.word_tokenizer import WordTokenizer + + +__all__ = [ + "AggregateTokenizer", + "ByteLevelTokenizer", + "CanaryTokenizer", + "CharTokenizer", + "AutoTokenizer", + "RegExTokenizer", + "SentencePieceTokenizer", + "TokenizerSpec", + "WordTokenizer", +] diff --git a/nemo/collections/common/tokenizers/chat_template_mixin.py b/nemo/collections/common/tokenizers/chat_template_mixin.py new file mode 100644 index 0000000000000..83a5e537519cd --- /dev/null +++ b/nemo/collections/common/tokenizers/chat_template_mixin.py @@ -0,0 +1,179 @@ +import re +from functools import cache + +TEMPLATE_VAR_VALIDATION_PAT = re.compile(r'^\{_[A-Za-z][A-Za-z0-9_]*_\}$') +TEMPLATE_VAR_SEARCH_PAT = re.compile('({_[^}]+_})') + + +class ChatTemplateMixin: + def apply_chat_template(self, messages): + assert self.chat_template is not None + return tokenize_with_chat_template(self, messages, self.chat_template) + + @property + def has_chat_template(self): + return self.chat_template is not None + + +@cache +def is_template_var(s): + # It should start with {_ and end with _}, be non-empty and not contain { or } within. + return re.match(TEMPLATE_VAR_VALIDATION_PAT, s) + + +def extract_template_parts(template, skip_empty=True): + for part in re.split(TEMPLATE_VAR_SEARCH_PAT, template): + # skip empty parts + if skip_empty and part == '': + continue + yield part + + +def strip_template_wrap(s): + if not is_template_var(s): + return s + # Strip the "{_" prefix and the "_}" suffix + return s[2:-2] + + +def render_chat_turn(message, template): + """Renders a chat turn based on template + + Args: + message (Dict) + e.g. {'role': ['user'], 'content': ['What is your favourite fruit?']}, + template (Str): + "[INST] {_content_} [/INST]", + + Returns: + (str, token_id/None): the template formatted message + e.g. + "[INST] What is your favourite fruit? [/INST]", None + """ + ans = [] + for i, template_part in enumerate(extract_template_parts(template)): + if is_template_var(template_part): + template_part = strip_template_wrap(template_part) + if template_part == 'content': + ans.append(message['content']) + else: + # assert i == len(template_parts) - 1, "unsupported" + yield ''.join(ans), template_part + ans = [] + else: + # Otherwise it is literal string + ans.append(template_part) + yield ''.join(ans), None + + +def encode_string_with_special_token(tokenizer, inputs, special_token): + """ + Tokenizes a string or a list of string into their corresponding token_ids + and appends (at the end) a special_token if present. + + Args: + tokenizer: (SPM) + inputs: (Str, List[Str]) + e.g. "Alex" or ["Alex", "nvidia"] + special_token: (Str): + e.g. "eos" + + Returns: + (list[int]): list of token_ids + e.g. + input="Alex", special_token="eos" + Alex->[3413] + eos->[2] + + Will return the following: + [3413, 2] + """ + ans = [] + if isinstance(inputs, str) and inputs != '': + ans += tokenizer.text_to_ids(inputs) + elif isinstance(inputs, list) and len(inputs) > 0: + ans += tokenizer.text_to_ids(''.join(inputs)) + if special_token is not None: + # TODO(@akoumparouli): limit which attributes user-defined string can query. + assert hasattr(tokenizer, special_token), f"Special_token {special_token} is not part of tokenizer" + ans += [getattr(tokenizer, special_token)] + return ans + + +def tokenize_with_chat_template(tokenizer, messages, template): + assert is_chat_input(messages), "Expected input to be chat-template" + assert len(messages) > 0, "Expected non-empty messages" + assert 'roles' in template, "Expected template to have key `roles`." + ans = [] + encode = lambda x, y: encode_string_with_special_token(tokenizer, x, y) + if 'prefix' in template: + for part, special_token in render_chat_turn('', template['prefix']): + ans += encode(part, special_token) + buffer = [] + for message in messages: + assert message['role'] in template['roles'], (message['role'], template['roles']) + msg_template = template['roles'][message['role']] + for templated_messages, special_token in render_chat_turn(message, msg_template): + buffer += [templated_messages] + if special_token is not None: + ans += encode(buffer, special_token) + buffer = [] + # handle tail + ans += encode(buffer, None) + assert len(ans) > 0, 'Expected non-empty output' + return ans + + +def extract_turns(messages, axis): + """ + a collated messages can have multiple chat messages in each dict, + this extracts (vertically) one of them, for example: + + messages = [ + {'role': ['user', 'user'], 'content': ['What is your favourite condiment?', 'What is your favourite fruit?']}, + {'role': ['assistant', 'assistant'], 'content': ["Well, I'm quite partial to a ", "good squeeze of fresh lemon"]}, + {'role': ['user', 'user'], 'content': ['Do you have mayonnaise recipes?', 'Do you have tomato salad recipes?']} + ] + ans = extract_turns(messages, axis=1) + + ans = [ + {'role': ['user'], 'content': ['What is your favourite fruit?']}, + {'role': ['assistant'], 'content': ["good squeeze of fresh lemon"]}, + {'role': ['user'], 'content': ['Do you have tomato salad recipes?']} + ] + """ + ans = [] + for turn in messages: + ans.append({k: v[axis] for k, v in turn.items()}) + return ans + + +def explode_chat_template_input(messages): + """ + Example input + [ + {'role': ['user', 'user'], 'content': ['What is your favourite condiment?', 'What is your favourite fruit?']}, + {'role': ['assistant', 'assistant'], 'content': ["Well, I'm quite partial to a ", "good squeeze of fresh lemon"]}, + {'role': ['user', 'user'], 'content': ['Do you have mayonnaise recipes?', 'Do you have tomato salad recipes?']} + ] + + Notice the 2D axis system of the messages variable, one for the list and one for each item in the list (i.e. + the 'content' contains multiple messages). + """ + assert isinstance(messages, list), "Expected messages to be a list" + assert len(messages) > 0, "Expected non empty messages" + assert all(map(lambda x: isinstance(x, dict), messages)), "Expected messages to contain dicts" + assert all( + map(lambda x: 'role' in x and 'content' in x, messages) + ), "Expected messages each dict to contain 'role' and 'content' fields" + n = len(messages[0]['role']) + assert all( + map(lambda x: len(x['role']) == n, messages) + ), "Expected all batch messages to contain equal number of roles in all turns" + for i in range(n): + yield extract_turns(messages, axis=i) + + +def is_chat_input(messages): + # TOOD(@akoumparouli): improve validation. + return isinstance(messages, list) and len(messages) > 0 and isinstance(messages[0], dict) diff --git a/nemo/collections/common/tokenizers/sentencepiece_tokenizer.py b/nemo/collections/common/tokenizers/sentencepiece_tokenizer.py index 4a47f0e49b1e7..00893b6f379f6 100644 --- a/nemo/collections/common/tokenizers/sentencepiece_tokenizer.py +++ b/nemo/collections/common/tokenizers/sentencepiece_tokenizer.py @@ -20,13 +20,14 @@ import torch from nemo.collections.common.parts.utils import if_exist +from nemo.collections.common.tokenizers.chat_template_mixin import ChatTemplateMixin from nemo.collections.common.tokenizers.tokenizer_spec import TokenizerSpec from nemo.utils import logging __all__ = ['SentencePieceTokenizer', 'create_spt_model'] -class SentencePieceTokenizer(TokenizerSpec): +class SentencePieceTokenizer(TokenizerSpec, ChatTemplateMixin): """ Sentencepiecetokenizer https://github.com/google/sentencepiece. @@ -38,8 +39,13 @@ class SentencePieceTokenizer(TokenizerSpec): """ def __init__( - self, model_path: str, special_tokens: Optional[Union[Dict[str, str], List[str]]] = None, legacy: bool = False + self, + model_path: str, + special_tokens: Optional[Union[Dict[str, str], List[str]]] = None, + legacy: bool = False, + chat_template: Optional[Dict] = None, ): + self.chat_template = chat_template if not model_path or not os.path.exists(model_path): raise ValueError(f"model_path: {model_path} is invalid") self.tokenizer = sentencepiece.SentencePieceProcessor() @@ -89,6 +95,14 @@ def text_to_tokens(self, text): return self.tokenizer.encode_as_pieces(text) def text_to_ids(self, text, sample_alpha=None): + if isinstance(text, str): + return self._text_to_ids(text, sample_alpha) + elif isinstance(text, list): + return self.apply_chat_template(text) + else: + raise ValueError(f"Expected either str or list input, but got {type(text)}") + + def _text_to_ids(self, text, sample_alpha=None): if self.legacy: ids = [] idx = 0 diff --git a/nemo/collections/llm/__init__.py b/nemo/collections/llm/__init__.py index 19911b544f437..83c0a3af48c07 100644 --- a/nemo/collections/llm/__init__.py +++ b/nemo/collections/llm/__init__.py @@ -4,7 +4,8 @@ except ImportError: pass -from nemo.collections.llm.api import export_ckpt, import_ckpt, pretrain, train, validate +from nemo.collections.llm import peft, tokenizer +from nemo.collections.llm.api import export_ckpt, finetune, import_ckpt, pretrain, train, validate from nemo.collections.llm.gpt.data import ( DollyDataModule, FineTuningDataModule, @@ -12,6 +13,7 @@ PreTrainingDataModule, SquadDataModule, ) +from nemo.collections.llm.gpt.data.api import dolly, mock, squad from nemo.collections.llm.gpt.model import ( CodeGemmaConfig2B, CodeGemmaConfig7B, @@ -33,13 +35,31 @@ LlamaConfig, LlamaModel, MaskedTokenLossReduction, - Mistral7BConfig, - Mistral7BModel, - MixtralConfig, + MistralConfig7B, + MistralModel, + MixtralConfig8x7B, MixtralModel, gpt_data_step, gpt_forward_step, ) +from nemo.collections.llm.gpt.model.api import ( + code_gemma_2b, + code_gemma_7b, + code_llama_7b, + code_llama_13b, + code_llama_34b, + code_llama_70b, + gemma, + gemma_2b, + gemma_7b, + llama2_7b, + llama2_13b, + llama2_70b, + llama3_8b, + llama3_70b, + mistral, + mixtral, +) __all__ = [ "MockDataModule", @@ -48,9 +68,9 @@ "gpt_data_step", "gpt_forward_step", "MaskedTokenLossReduction", - "Mistral7BConfig", - "Mistral7BModel", - "MixtralConfig", + "MistralConfig7B", + "MistralModel", + "MixtralConfig8x7B", "MixtralModel", "LlamaConfig", "Llama2Config7B", @@ -78,4 +98,26 @@ "export_ckpt", "pretrain", "validate", + "finetune", + "tokenizer", + "mock", + "squad", + "dolly", + "mistral", + "mixtral", + "llama2_7b", + "llama3_8b", + "llama2_13b", + "llama2_70b", + "llama3_70b", + "code_llama_7b", + "code_llama_13b", + "code_llama_34b", + "code_llama_70b", + "gemma", + "gemma_2b", + "gemma_7b", + "code_gemma_2b", + "code_gemma_7b", + "peft", ] diff --git a/nemo/collections/llm/api.py b/nemo/collections/llm/api.py index 90166d895a1e4..5c9703497597e 100644 --- a/nemo/collections/llm/api.py +++ b/nemo/collections/llm/api.py @@ -1,11 +1,17 @@ +from copy import deepcopy from pathlib import Path -from typing import Callable, Optional +from typing import Any, Callable, Optional, Union import pytorch_lightning as pl from typing_extensions import Annotated from nemo.collections.llm.utils import Config, task -from nemo.lightning import AutoResume, MegatronStrategy, NeMoLogger, OptimizerModule, Trainer, io, teardown +from nemo.lightning import AutoResume, NeMoLogger, OptimizerModule, Trainer, io +from nemo.lightning.pytorch.callbacks import PEFT, ModelTransform +from nemo.utils import logging + + +TokenizerType = Any @task(namespace="llm") @@ -15,8 +21,9 @@ def train( trainer: Trainer, log: Annotated[Optional[NeMoLogger], Config[NeMoLogger]] = None, resume: Annotated[Optional[AutoResume], Config[AutoResume]] = None, - opt: Optional[OptimizerModule] = None, - tokenizer: Optional[str] = None, + optim: Optional[OptimizerModule] = None, + tokenizer: Optional[TokenizerType] = None, + model_transform: Optional[Union[PEFT, ModelTransform, Callable]] = None, # TODO: Fix export export: Optional[str] = None, ) -> Path: """ @@ -28,46 +35,40 @@ def train( trainer (Trainer): The trainer instance configured with a MegatronStrategy. log (NeMoLogger): A nemologger instance. resume (Optional[Union[AutoResume, Resume]]): Resume training from a checkpoint. - opt (Optional[OptimizerModule]): The optimizer module to be used. If not provided, the default optimizer + optim (Optional[OptimizerModule]): The optimizer module to be used. If not provided, the default optimizer from the model will be used. - tokenizer (Optional[str]): Tokenizer setting to be applied. Can be 'data' or 'model'. + tokenizer (Optional[TokenizerType]): Tokenizer setting to be applied. Can be 'data' or 'model' or an instance of TokenizerSpec. export (Optional[str]): Filename to save the exported checkpoint after training. + model_transform (Optional[Union[Callable[[nn.Module], nn.Module], PEFT]]): A model transform to be applied. Returns ------- Path: The directory path where training artifacts are saved. - Raises - ------ - ValueError: If the trainer's strategy is not MegatronStrategy. - Examples -------- - >>> model = MyModel() - >>> data = MyDataModule() - >>> trainer = Trainer(strategy=MegatronStrategy()) - >>> train(model, data, trainer, tokenizer='data', source='path/to/ckpt.ckpt', export='final.ckpt') + >>> from nemo.collections import llm + >>> from nemo import lightning as nl + >>> model = llm.MistralModel() + >>> data = llm.SquadDataModule(seq_length=4096, global_batch_size=16, micro_batch_size=2) + >>> precision = nl.MegatronMixedPrecision(precision="bf16-mixed") + >>> trainer = nl.Trainer(strategy=nl.MegatronStrategy(tensor_model_parallel_size=2), plugins=precision) + >>> train(model, data, trainer, tokenizer="data") PosixPath('/path/to/log_dir') """ - _log = log or NeMoLogger() - app_state = _log.setup( - trainer, - resume_if_exists=getattr(resume, "resume_if_exists", False), + app_state = _setup( + model=model, + data=data, + trainer=trainer, + log=log, + resume=resume, + optim=optim, + tokenizer=tokenizer, + model_transform=model_transform, ) - if resume is not None: - resume.setup(model, trainer) - if opt: - opt.connect(model) - if tokenizer: # TODO: Improve this - _use_tokenizer(model, data, tokenizer) - - if hasattr(train, "__io__"): - _save_config_img(app_state.exp_dir, train.__io__) trainer.fit(model, data) - _log.teardown() - return app_state.exp_dir @@ -76,41 +77,152 @@ def pretrain( model: pl.LightningModule, data: pl.LightningDataModule, trainer: Trainer, - source: Optional[str] = None, - # export: Optional[str] = None + log: Annotated[Optional[NeMoLogger], Config[NeMoLogger]] = None, + resume: Annotated[Optional[AutoResume], Config[AutoResume]] = None, + optim: Optional[OptimizerModule] = None, ) -> Path: - return train(model=model, data=data, trainer=trainer, tokenizer="data", source=source) + """ + Pretrains a model using the specified data and trainer, with optional logging, resuming, and optimization. + + This function is a wrapper around the `train` function, specifically configured for pretraining tasks. + Note, by default it will use the tokenizer from the model. + + Args: + model (pl.LightningModule): The model to be pretrained. + data (pl.LightningDataModule): The data module containing pretraining data. + trainer (Trainer): The trainer instance configured with a MegatronStrategy. + log (NeMoLogger): A nemologger instance. + resume (Optional[AutoResume]): Resume training from a checkpoint. + optim (Optional[OptimizerModule]): The optimizer module to be used. If not provided, the default + optimizer from the model will be used. + + Returns: + Path: The directory path where pretraining artifacts are saved. + + Examples: + >>> from nemo.collections import llm + >>> from nemo import lightning as nl + >>> model = llm.MistralModel() + >>> data = llm.PretrainingDataModule(paths=[...], seq_length=4096, global_batch_size=16, micro_batch_size=2) + >>> precision = nl.MegatronMixedPrecision(precision="bf16-mixed") + >>> trainer = nl.Trainer(strategy=nl.MegatronStrategy(tensor_model_parallel_size=2), plugins=precision) + >>> llm.pretrain(model, data, trainer) + PosixPath('/path/to/log_dir') + """ + return train( + model=model, + data=data, + trainer=trainer, + log=log, + resume=resume, + optim=optim, + tokenizer="data", + ) @task(namespace="llm") -def validate( +def finetune( model: pl.LightningModule, data: pl.LightningDataModule, trainer: Trainer, - tokenizer: Optional[str] = None, - source: Optional[str] = None, - export: Optional[str] = None, + log: Annotated[Optional[NeMoLogger], Config[NeMoLogger]] = None, + resume: Annotated[Optional[AutoResume], Config[AutoResume]] = None, + optim: Optional[OptimizerModule] = None, + peft: Optional[Union[PEFT, ModelTransform, Callable]] = None, ) -> Path: - if not isinstance(trainer.strategy, MegatronStrategy): - raise ValueError("Only MegatronStrategy is supported") + """ + Finetunes a model using the specified data and trainer, with optional logging, resuming, and PEFT. - validate_kwargs = {} - run_dir = Path(trainer.logger.log_dir) - export_dir = run_dir / "export" + Note, by default it will use the tokenizer from the model. + + Args: + model (pl.LightningModule): The model to be finetuned. + data (pl.LightningDataModule): The data module containing finetuning data. + trainer (Trainer): The trainer instance configured with a MegatronStrategy. + log (NeMoLogger): A nemologger instance. + resume (Optional[AutoResume]): Resume training from a checkpoint. + optim (Optional[OptimizerModule]): The optimizer module to be used. If not provided, the default + optimizer from the model will be used. + peft (Optional[PEFT]): A PEFT (Parameter-Efficient Fine-Tuning) configuration to be applied. + + Returns: + Path: The directory path where finetuning artifacts are saved. + + Examples: + >>> from nemo.collections import llm + >>> from nemo import lightning as nl + >>> model = llm.MistralModel() + >>> data = llm.SquadDataModule(seq_length=4096, global_batch_size=16, micro_batch_size=2) + >>> precision = nl.MegatronMixedPrecision(precision="bf16-mixed") + >>> trainer = nl.Trainer(strategy=nl.MegatronStrategy(tensor_model_parallel_size=2), plugins=precision) + >>> finetune(model, data, trainer, peft=llm.peft.LoRA()]) + PosixPath('/path/to/log_dir') + """ + + return train( + model=model, + data=data, + trainer=trainer, + log=log, + resume=resume, + optim=optim, + tokenizer="model", + model_transform=peft, + ) - if tokenizer: # TODO: Improve this - _use_tokenizer(model, data, tokenizer) - if source: - _add_ckpt_path(source, model, validate_kwargs) - trainer.validate(model, data, **validate_kwargs) - trainer.save_checkpoint(export_dir) - if export: - teardown(trainer) - del trainer, model, data - export_ckpt(export_dir, export) +@task(namespace="llm") +def validate( + model: pl.LightningModule, + data: pl.LightningDataModule, + trainer: Trainer, + log: Annotated[Optional[NeMoLogger], Config[NeMoLogger]] = None, + resume: Annotated[Optional[AutoResume], Config[AutoResume]] = None, + optim: Optional[OptimizerModule] = None, + tokenizer: Optional[TokenizerType] = None, + model_transform: Optional[Union[PEFT, ModelTransform, Callable]] = None, +) -> Path: + """ + Validates a model using the specified data and trainer, with optional logging, resuming, and model transformations. - return run_dir + Args: + model (pl.LightningModule): The model to be validated. + data (pl.LightningDataModule): The data module containing validation data. + trainer (Trainer): The trainer instance configured with a MegatronStrategy. + log (NeMoLogger): A nemologger instance. + resume (Optional[AutoResume]): Resume from a checkpoint for validation. + optim (Optional[OptimizerModule]): The optimizer module to be used. If not provided, the default optimizer + from the model will be used. + tokenizer (Optional[TokenizerType]): Tokenizer setting to be applied. Can be 'data' or 'model' or an instance of TokenizerSpec. + model_transform (Optional[Union[Callable[[nn.Module], nn.Module], PEFT]]): A model transform to be applied. + + Returns: + Path: The directory path where validation artifacts are saved. + + Examples: + >>> from nemo.collections import llm + >>> from nemo import lightning as nl + >>> model = llm.MistralModel() + >>> data = llm.SquadDataModule(seq_length=4096, global_batch_size=16, micro_batch_size=2) + >>> precision = nl.MegatronMixedPrecision(precision="bf16-mixed") + >>> trainer = nl.Trainer(strategy=nl.MegatronStrategy(tensor_model_parallel_size=2), plugins=precision) + >>> validate(model, data, trainer, tokenizer="data") + PosixPath('/path/to/log_dir') + """ + app_state = _setup( + model=model, + data=data, + trainer=trainer, + log=log, + resume=resume, + optim=optim, + tokenizer=tokenizer, + model_transform=model_transform, + ) + + trainer.validate(model, data) + + return app_state.exp_dir @task(name="import", namespace="llm") @@ -124,7 +236,7 @@ def import_ckpt( def load_connector_from_trainer_ckpt(path: Path, target: str) -> io.ModelConnector: - return io.load_ckpt(path).model.exporter(target, path) + return io.load_context(path).model.exporter(target, path) @task(name="export", namespace="llm") @@ -138,24 +250,67 @@ def export_ckpt( return io.export_ckpt(path, target, output_path, overwrite, load_connector) -def _use_tokenizer(model: pl.LightningModule, data: pl.LightningDataModule, tokenizer: str) -> None: +def _use_tokenizer(model: pl.LightningModule, data: pl.LightningDataModule, tokenizer: TokenizerType) -> None: if tokenizer == "data": - model.tokenizer = data.tokenizer + _set_with_io(model, "tokenizer", data.tokenizer) elif tokenizer == "model": - data.tokenizer = model.tokenizer + _set_with_io(data, "tokenizer", model.tokenizer) + else: + try: + from nemo.collections.common.tokenizers.tokenizer_spec import TokenizerSpec + if isinstance(tokenizer, TokenizerSpec): + _set_with_io(model, "tokenizer", tokenizer) + _set_with_io(data, "tokenizer", tokenizer) + else: + raise ValueError(f"Expected TokenizerSpec or 'data' or 'model', got: {tokenizer}") + except ImportError: + raise ValueError("TokenizerSpec is not available") -def _add_ckpt_path(source, model, kwargs) -> None: - if io.is_distributed_ckpt(source): - kwargs["ckpt_path"] = source - else: - kwargs["ckpt_path"] = model.import_ckpt(source) +def _setup( + model: pl.LightningModule, + data: pl.LightningDataModule, + trainer: Trainer, + log: Optional[NeMoLogger], + resume: Optional[AutoResume], + optim: Optional[OptimizerModule], + tokenizer: Optional[TokenizerType], + model_transform: Optional[Union[PEFT, ModelTransform, Callable]], +) -> Any: # Return type is Any because app_state's type is not specified + _log = log or NeMoLogger() + if resume and resume.adapter_path and _log.ckpt: + logging.info("Disabling try_restore_best_ckpt restoration for adapters") + _log.ckpt.try_restore_best_ckpt = False + + app_state = _log.setup( + trainer, + resume_if_exists=getattr(resume, "resume_if_exists", False), + task_config=getattr(train, "__io__", None), + ) + if resume is not None: + resume.setup(model, trainer) + + if optim: + optim.connect(model) + if tokenizer: # TODO: Improve this + _use_tokenizer(model, data, tokenizer) + + if model_transform: + _set_with_io(model, "model_transform", model_transform) + + # Add ModelTransform callback to Trainer if needed + if getattr(model, "model_transform", None): + if not any(isinstance(cb, ModelTransform) for cb in trainer.callbacks): + if isinstance(model_transform, ModelTransform): + trainer.callbacks.append(model_transform) + else: + trainer.callbacks.append(ModelTransform()) + + return app_state -def _save_config_img(*args, **kwargs): - try: - from nemo_sdk.utils import save_config_img - save_config_img(*args, **kwargs) - except ImportError: - pass +def _set_with_io(obj, attr, value): + setattr(obj, attr, value) + if hasattr(obj, "__io__") and hasattr(value, "__io__"): + setattr(obj.__io__, attr, deepcopy(value.__io__)) diff --git a/nemo/collections/llm/fn/activation.py b/nemo/collections/llm/fn/activation.py new file mode 100644 index 0000000000000..89b5ba93f0f64 --- /dev/null +++ b/nemo/collections/llm/fn/activation.py @@ -0,0 +1,11 @@ +import torch + + +@torch.jit.script +def gelu_impl(x): + """OpenAI's gelu implementation.""" + return 0.5 * x * (1.0 + torch.tanh(0.7978845608028654 * x * (1.0 + 0.044715 * x * x))) + + +def openai_gelu(x): + return gelu_impl(x) diff --git a/nemo/collections/llm/gpt/data/api.py b/nemo/collections/llm/gpt/data/api.py new file mode 100644 index 0000000000000..e674fea91b793 --- /dev/null +++ b/nemo/collections/llm/gpt/data/api.py @@ -0,0 +1,24 @@ +import pytorch_lightning as pl + +from nemo.collections.llm.gpt.data.dolly import DollyDataModule +from nemo.collections.llm.gpt.data.mock import MockDataModule +from nemo.collections.llm.gpt.data.squad import SquadDataModule +from nemo.collections.llm.utils import factory + + +@factory +def mock() -> pl.LightningDataModule: + return MockDataModule(seq_length=4096, global_batch_size=16, micro_batch_size=2) + + +@factory +def squad() -> pl.LightningDataModule: + return SquadDataModule(seq_length=4096, global_batch_size=16, micro_batch_size=2) + + +@factory +def dolly() -> pl.LightningDataModule: + return DollyDataModule(seq_length=4096, global_batch_size=16, micro_batch_size=2) + + +__all__ = ["mock", "squad", "dolly"] diff --git a/nemo/collections/llm/gpt/data/mock.py b/nemo/collections/llm/gpt/data/mock.py index ccc1acfd6a2a4..37e255bf5aec0 100644 --- a/nemo/collections/llm/gpt/data/mock.py +++ b/nemo/collections/llm/gpt/data/mock.py @@ -53,12 +53,18 @@ def setup(self, stage: str = "") -> None: self._test_ds = _MockGPTDataset(self.tokenizer, "test", self.num_test_samples, self.seq_length) def train_dataloader(self) -> TRAIN_DATALOADERS: + if not hasattr(self, "_train_ds"): + self.setup() return self._create_dataloader(self._train_ds) def val_dataloader(self) -> EVAL_DATALOADERS: + if not hasattr(self, "_validation_ds"): + self.setup() return self._create_dataloader(self._validation_ds) def test_dataloader(self) -> EVAL_DATALOADERS: + if not hasattr(self, "_test_ds"): + self.setup() return self._create_dataloader(self._test_ds) def _create_dataloader(self, dataset, **kwargs) -> DataLoader: diff --git a/nemo/collections/llm/gpt/data/pre_training.py b/nemo/collections/llm/gpt/data/pre_training.py index a659823b085e2..46b407410d316 100644 --- a/nemo/collections/llm/gpt/data/pre_training.py +++ b/nemo/collections/llm/gpt/data/pre_training.py @@ -1,5 +1,5 @@ from pathlib import Path -from typing import TYPE_CHECKING, List, Optional +from typing import TYPE_CHECKING, Any, Dict, List, Optional import pytorch_lightning as pl from pytorch_lightning.utilities.types import EVAL_DATALOADERS, TRAIN_DATALOADERS @@ -17,7 +17,8 @@ class PreTrainingDataModule(pl.LightningDataModule): def __init__( self, - path: Path, + paths: Path | List[Path], + weights: Optional[List[float]] = None, seq_length: int = 2048, tokenizer: Optional["TokenizerSpec"] = None, micro_batch_size: int = 4, @@ -34,9 +35,19 @@ def __init__( eod_mask_loss: bool = False, seed: int = 1234, split: str = "900,50,50", + index_mapping_dir: Optional[str] = None, ) -> None: super().__init__() - self.path = path + if not isinstance(paths, (list, tuple)): + paths = [paths] + if weights is not None: + assert len(weights) == len(paths) + if len(weights) == 1: + # weights must be None if there is only one dataset + weights = None + + self.paths = paths + self.weights = weights self.seq_length = seq_length self.tokenizer = tokenizer self.num_train_samples = num_train_samples @@ -50,6 +61,8 @@ def __init__( self.eod_mask_loss = eod_mask_loss self.seed = seed self.split = split + self.index_mapping_dir = index_mapping_dir + self.init_global_step = 0 from nemo.collections.nlp.modules.common.tokenizer_utils import get_nmt_tokenizer @@ -74,13 +87,13 @@ def setup(self, stage: str = "") -> None: assert max_train_steps > 0, "Please specify trainer.max_steps" eval_iters = (max_train_steps // self.trainer.val_check_interval + 1) * self.trainer.limit_val_batches test_iters = self.trainer.limit_test_batches - num_train_samples = max_train_steps * self.data_sampler.global_batch_size - num_val_samples = eval_iters * self.data_sampler.global_batch_size - num_test_samples = test_iters * self.data_sampler.global_batch_size + num_train_samples = int(max_train_steps * self.data_sampler.global_batch_size) + num_val_samples = int(eval_iters * self.data_sampler.global_batch_size) + num_test_samples = int(test_iters * self.data_sampler.global_batch_size) if self.trainer.limit_val_batches <= 1.0 and isinstance(self.trainer.limit_val_batches, float): # This is to make sure we only have one epoch on every validation iteration - num_val_samples = 1 + num_val_samples = None if self.weights is None else 1 train_valid_test_num_samples = [num_train_samples, num_val_samples, num_test_samples] self._train_ds, self._validation_ds, self._test_ds = BlendedMegatronDatasetBuilder( @@ -117,6 +130,7 @@ def test_dataloader(self) -> EVAL_DATALOADERS: return self._create_dataloader(self._test_ds) def _create_dataloader(self, dataset, **kwargs) -> DataLoader: + self.init_global_step = self.trainer.global_step return DataLoader( dataset, num_workers=self.num_workers, @@ -131,13 +145,53 @@ def gpt_dataset_config(self) -> "GPTDatasetConfig": from megatron.core.datasets.gpt_dataset import GPTDatasetConfig return GPTDatasetConfig( - blend=[[str(self.path)], [1.0]], + blend=[[str(path) for path in self.paths], self.weights], random_seed=self.seed, sequence_length=self.seq_length, tokenizer=self.tokenizer, split=self.split, - path_to_cache=None, + path_to_cache=self.index_mapping_dir, reset_position_ids=self.reset_position_ids, reset_attention_mask=self.reset_attention_mask, eod_mask_loss=self.eod_mask_loss, ) + + def state_dict(self) -> Dict[str, Any]: + """Called when saving a checkpoint, implement to generate and save datamodule state. + + Returns: + A dictionary containing datamodule state. + + """ + consumed_samples = self.data_sampler.compute_consumed_samples(self.trainer.global_step - self.init_global_step) + return {'consumed_samples': consumed_samples} + + def load_state_dict(self, state_dict: Dict[str, Any]) -> None: + """Called when loading a checkpoint, implement to reload datamodule state given datamodule stat + + Args: + state_dict: the datamodule state returned by ``state_dict``. + + """ + try: + from apex.transformer.pipeline_parallel.utils import _GLOBAL_NUM_MICROBATCHES_CALCULATOR + except ModuleNotFoundError: + from nemo.lightning.apex_utils import _GLOBAL_NUM_MICROBATCHES_CALCULATOR + consumed_samples = state_dict['consumed_samples'] + self.data_sampler.init_consumed_samples = consumed_samples + self.data_sampler.prev_consumed_samples = consumed_samples + num_microbatch_calculator = _GLOBAL_NUM_MICROBATCHES_CALCULATOR # noqa: SLF001 + + num_microbatch_calculator.update( + consumed_samples=consumed_samples, + consistency_check=False, + ) + current_global_batch_size = num_microbatch_calculator.current_global_batch_size + '''pl_module.log( + "global_batch_size", + current_global_batch_size, + prog_bar=True, + rank_zero_only=True, + batch_size=1, + )''' + self.if_first_step = 1 diff --git a/nemo/collections/llm/gpt/model/__init__.py b/nemo/collections/llm/gpt/model/__init__.py index 2da72539fd158..4391a41293eed 100644 --- a/nemo/collections/llm/gpt/model/__init__.py +++ b/nemo/collections/llm/gpt/model/__init__.py @@ -4,18 +4,39 @@ MaskedTokenLossReduction, gpt_data_step, gpt_forward_step, + local_layer_spec, + transformer_engine_layer_spec, ) -from nemo.collections.llm.gpt.model.gemma import * -from nemo.collections.llm.gpt.model.llama import * -from nemo.collections.llm.gpt.model.mistral_7b import Mistral7BConfig, Mistral7BModel -from nemo.collections.llm.gpt.model.mixtral import MixtralConfig, MixtralModel +from nemo.collections.llm.gpt.model.gemma import ( + CodeGemmaConfig2B, + CodeGemmaConfig7B, + GemmaConfig, + GemmaConfig2B, + GemmaConfig7B, + GemmaModel, +) +from nemo.collections.llm.gpt.model.llama import ( + CodeLlamaConfig7B, + CodeLlamaConfig13B, + CodeLlamaConfig34B, + CodeLlamaConfig70B, + Llama2Config7B, + Llama2Config13B, + Llama2Config70B, + Llama3Config8B, + Llama3Config70B, + LlamaConfig, + LlamaModel, +) +from nemo.collections.llm.gpt.model.mistral import MistralConfig7B, MistralModel +from nemo.collections.llm.gpt.model.mixtral import MixtralConfig8x7B, MixtralModel __all__ = [ "GPTConfig", "GPTModel", - "Mistral7BConfig", - "Mistral7BModel", - "MixtralConfig", + "MistralConfig7B", + "MistralModel", + "MixtralConfig8x7B", "MixtralModel", "LlamaConfig", "Llama2Config7B", @@ -37,4 +58,6 @@ "MaskedTokenLossReduction", "gpt_data_step", "gpt_forward_step", + "transformer_engine_layer_spec", + "local_layer_spec", ] diff --git a/nemo/collections/llm/gpt/model/api.py b/nemo/collections/llm/gpt/model/api.py new file mode 100644 index 0000000000000..7c8cbf4d02e69 --- /dev/null +++ b/nemo/collections/llm/gpt/model/api.py @@ -0,0 +1,125 @@ +import pytorch_lightning as pl + +from nemo.collections.llm.gpt.model.gemma import ( + CodeGemmaConfig2B, + CodeGemmaConfig7B, + GemmaConfig, + GemmaConfig2B, + GemmaConfig7B, + GemmaModel, +) +from nemo.collections.llm.gpt.model.llama import ( + CodeLlamaConfig7B, + CodeLlamaConfig13B, + CodeLlamaConfig34B, + CodeLlamaConfig70B, + Llama2Config7B, + Llama2Config13B, + Llama2Config70B, + Llama3Config8B, + Llama3Config70B, + LlamaModel, +) +from nemo.collections.llm.gpt.model.mistral import MistralConfig7B, MistralModel +from nemo.collections.llm.gpt.model.mixtral import MixtralConfig8x7B, MixtralModel +from nemo.collections.llm.utils import factory + + +@factory +def mistral() -> pl.LightningModule: + return MistralModel(MistralConfig7B()) + + +@factory +def mixtral() -> pl.LightningModule: + return MixtralModel(MixtralConfig8x7B()) + + +@factory +def llama2_7b() -> pl.LightningModule: + return LlamaModel(Llama2Config7B()) + + +@factory +def llama3_8b() -> pl.LightningModule: + return LlamaModel(Llama3Config8B()) + + +@factory +def llama2_13b() -> pl.LightningModule: + return LlamaModel(Llama2Config13B()) + + +@factory +def llama2_70b() -> pl.LightningModule: + return LlamaModel(Llama2Config70B()) + + +@factory +def llama3_70b() -> pl.LightningModule: + return LlamaModel(Llama3Config70B()) + + +@factory +def code_llama_7b() -> pl.LightningModule: + return LlamaModel(CodeLlamaConfig7B()) + + +@factory +def code_llama_13b() -> pl.LightningModule: + return LlamaModel(CodeLlamaConfig13B()) + + +@factory +def code_llama_34b() -> pl.LightningModule: + return LlamaModel(CodeLlamaConfig34B()) + + +@factory +def code_llama_70b() -> pl.LightningModule: + return LlamaModel(CodeLlamaConfig70B()) + + +@factory +def gemma() -> pl.LightningModule: + return GemmaModel(GemmaConfig()) + + +@factory +def gemma_2b() -> pl.LightningModule: + return GemmaModel(GemmaConfig2B()) + + +@factory +def gemma_7b() -> pl.LightningModule: + return GemmaModel(GemmaConfig7B()) + + +@factory +def code_gemma_2b() -> pl.LightningModule: + return GemmaModel(CodeGemmaConfig2B()) + + +@factory +def code_gemma_7b() -> pl.LightningModule: + return GemmaModel(CodeGemmaConfig7B()) + + +__all__ = [ + "mistral", + "mixtral", + "llama2_7b", + "llama3_8b", + "llama2_13b", + "llama2_70b", + "llama3_70b", + "code_llama_7b", + "code_llama_13b", + "code_llama_34b", + "code_llama_70b", + "gemma", + "gemma_2b", + "gemma_7b", + "code_gemma_2b", + "code_gemma_7b", +] diff --git a/nemo/collections/llm/gpt/model/base.py b/nemo/collections/llm/gpt/model/base.py index 1a3b5c754a396..4c1f425d7f993 100644 --- a/nemo/collections/llm/gpt/model/base.py +++ b/nemo/collections/llm/gpt/model/base.py @@ -1,16 +1,19 @@ from dataclasses import dataclass -from typing import TYPE_CHECKING, Dict, Literal, Optional +from typing import TYPE_CHECKING, Callable, Dict, Literal, Optional, Union import pytorch_lightning as L import torch import torch.distributed +from megatron.core.models.gpt import gpt_layer_specs from megatron.core.optimizer import OptimizerConfig +from megatron.core.transformer.spec_utils import ModuleSpec from megatron.core.transformer.transformer_config import TransformerConfig +from torch import nn from nemo.collections.llm import fn from nemo.lightning import get_vocab_size, io from nemo.lightning.megatron_parallel import MaskedTokenLossReduction -from nemo.lightning.pytorch.opt import MegatronOptimizerModule, OptimizerModule +from nemo.lightning.pytorch.optim import MegatronOptimizerModule, OptimizerModule if TYPE_CHECKING: from megatron.core.models.gpt.gpt_model import GPTModel as MCoreGPTModel @@ -18,8 +21,64 @@ from nemo.collections.common.tokenizers.tokenizer_spec import TokenizerSpec +def gpt_data_step(dataloader_iter) -> Dict[str, torch.Tensor]: + from megatron.core import parallel_state + + # Based on: https://github.com/NVIDIA/Megatron-LM/blob/main/pretrain_gpt.py#L87 + # https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py#L828-L842 + + batch = next(dataloader_iter) + + _batch: dict + if isinstance(batch, tuple) and len(batch) == 3: + _batch = batch[0] + else: + _batch = batch + + required_keys = set() + required_keys.add("attention_mask") + if parallel_state.is_pipeline_first_stage(): + required_keys.update(("tokens", "position_ids")) + if parallel_state.is_pipeline_last_stage(): + required_keys.update(("labels", "loss_mask")) + # if self.get_attention_mask_from_fusion: + # required_keys.remove('attention_mask') + + _batch = {key: val.cuda(non_blocking=True) if key in required_keys else None for key, val in _batch.items()} + # slice batch along sequence dimension for context parallelism + output = get_batch_on_this_context_parallel_rank(_batch) + + return output + + +def gpt_forward_step(model, batch) -> torch.Tensor: + forward_args = { + "input_ids": batch["tokens"], + "position_ids": batch["position_ids"], + "attention_mask": batch["attention_mask"], + "labels": batch["labels"], + } + + if 'cu_seqlens' in batch: + forward_args['packed_seq_params'] = get_packed_seq_params(batch) + + return model(**forward_args) + + +def transformer_engine_layer_spec(config: "GPTConfig") -> ModuleSpec: + return gpt_layer_specs.get_gpt_layer_with_transformer_engine_spec( + num_experts=config.num_moe_experts, moe_grouped_gemm=config.moe_grouped_gemm, qk_layernorm=config.qk_layernorm + ) + + +def local_layer_spec(config: "GPTConfig") -> ModuleSpec: + return gpt_layer_specs.get_gpt_layer_local_spec( + num_experts=config.num_moe_experts, moe_grouped_gemm=config.moe_grouped_gemm, qk_layernorm=config.qk_layernorm + ) + + @dataclass -class GPTConfig(TransformerConfig): +class GPTConfig(TransformerConfig, io.IOMixin): # From megatron.core.models.gpt.gpt_model.GPTModel fp16_lm_cross_entropy: bool = False parallel_output: bool = True @@ -34,6 +93,10 @@ class GPTConfig(TransformerConfig): # TODO: Move this to better places? get_attention_mask_from_fusion: bool = False + transformer_layer_spec: Union[ModuleSpec, Callable[["GPTConfig"], ModuleSpec]] = transformer_engine_layer_spec + forward_step_fn: Callable = gpt_forward_step + data_step_fn: Callable = gpt_data_step + def configure_model(self, tokenizer) -> "MCoreGPTModel": vp_size = self.virtual_pipeline_model_parallel_size if vp_size: @@ -43,12 +106,15 @@ def configure_model(self, tokenizer) -> "MCoreGPTModel": ) % vp_size == 0, "Make sure the number of model chunks is the same across all pipeline stages." from megatron.core import parallel_state - from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_with_transformer_engine_spec from megatron.core.models.gpt.gpt_model import GPTModel as MCoreGPTModel + transformer_layer_spec = self.transformer_layer_spec + if not isinstance(transformer_layer_spec, ModuleSpec): + transformer_layer_spec = transformer_layer_spec(self) + return MCoreGPTModel( self, - transformer_layer_spec=get_gpt_layer_with_transformer_engine_spec(self.num_moe_experts), + transformer_layer_spec=transformer_layer_spec, vocab_size=get_vocab_size(self, tokenizer.vocab_size, self.make_vocab_size_divisible_by), max_sequence_length=self.seq_length, fp16_lm_cross_entropy=self.fp16_lm_cross_entropy, @@ -70,15 +136,18 @@ def __init__( # TODO: Add transformer_layer_spec when we update mcore optim: Optional[OptimizerModule] = None, tokenizer: Optional["TokenizerSpec"] = None, + model_transform: Optional[Callable[[nn.Module], nn.Module]] = None, ): super().__init__() self.config = config self.tokenizer = tokenizer self.optim = optim or MegatronOptimizerModule(config=OptimizerConfig(lr=1e-4, use_distributed_optimizer=True)) self.optim.connect(self) # This will bind the `configure_optimizers` method + self.model_transform = model_transform def configure_model(self) -> None: - self.module = self.config.configure_model(self.tokenizer) + if not hasattr(self, "module"): + self.module = self.config.configure_model(self.tokenizer) def forward( self, @@ -101,14 +170,13 @@ def forward( return output_tensor def data_step(self, dataloader_iter) -> Dict[str, torch.Tensor]: - return gpt_data_step(dataloader_iter) + return self.config.data_step_fn(dataloader_iter) def forward_step(self, batch) -> torch.Tensor: - return gpt_forward_step(self, batch) + return self.config.forward_step_fn(self, batch) def training_step(self, batch, batch_idx=None) -> torch.Tensor: # In mcore the loss-function is part of the forward-pass (when labels are provided) - return self.forward_step(batch) def validation_step(self, batch, batch_idx=None) -> torch.Tensor: @@ -123,50 +191,6 @@ def validation_loss_reduction(self) -> MaskedTokenLossReduction: return MaskedTokenLossReduction(validation_step=True) -def gpt_data_step(dataloader_iter) -> Dict[str, torch.Tensor]: - from megatron.core import parallel_state - - # Based on: https://github.com/NVIDIA/Megatron-LM/blob/main/pretrain_gpt.py#L87 - # https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py#L828-L842 - - batch = next(dataloader_iter) - - _batch: dict - if isinstance(batch, tuple) and len(batch) == 3: - _batch = batch[0] - else: - _batch = batch - - required_keys = set() - required_keys.add("attention_mask") - if parallel_state.is_pipeline_first_stage(): - required_keys.update(("tokens", "position_ids")) - if parallel_state.is_pipeline_last_stage(): - required_keys.update(("labels", "loss_mask")) - # if self.get_attention_mask_from_fusion: - # required_keys.remove('attention_mask') - - _batch = {key: val.cuda(non_blocking=True) if key in required_keys else None for key, val in _batch.items()} - # slice batch along sequence dimension for context parallelism - output = get_batch_on_this_context_parallel_rank(_batch) - - return output - - -def gpt_forward_step(model, batch) -> torch.Tensor: - forward_args = { - "input_ids": batch["tokens"], - "position_ids": batch["position_ids"], - "attention_mask": batch["attention_mask"], - "labels": batch["labels"], - } - - if 'cu_seqlens' in batch: - forward_args['packed_seq_params'] = get_packed_seq_params(batch) - - return model(**forward_args) - - def get_batch_on_this_context_parallel_rank(batch): from megatron.core import parallel_state @@ -219,4 +243,11 @@ def get_packed_seq_params(batch): ) -__all__ = ["GPTModel", "GPTConfig", "gpt_data_step", "gpt_forward_step"] +__all__ = [ + "GPTModel", + "GPTConfig", + "gpt_data_step", + "gpt_forward_step", + "transformer_engine_layer_spec", + "local_layer_spec", +] diff --git a/nemo/collections/llm/gpt/model/gemma.py b/nemo/collections/llm/gpt/model/gemma.py index ff9772b1b74c9..6493bb0dfad7a 100644 --- a/nemo/collections/llm/gpt/model/gemma.py +++ b/nemo/collections/llm/gpt/model/gemma.py @@ -3,10 +3,11 @@ from typing import TYPE_CHECKING, Annotated, Callable, Optional import torch +from torch import nn +from nemo.collections.llm.fn.activation import openai_gelu from nemo.collections.llm.gpt.model.base import GPTConfig, GPTModel from nemo.collections.llm.utils import Config -from nemo.collections.nlp.modules.common.megatron.utils import openai_gelu from nemo.lightning import OptimizerModule, io, teardown if TYPE_CHECKING: @@ -68,8 +69,9 @@ def __init__( config: Annotated[Optional[GemmaConfig], Config[GemmaConfig]] = None, optim: Optional[OptimizerModule] = None, tokenizer: Optional["TokenizerSpec"] = None, + model_transform: Optional[Callable[[nn.Module], nn.Module]] = None, ): - super().__init__(config or GemmaConfig(), optim=optim, tokenizer=tokenizer) + super().__init__(config or GemmaConfig(), optim=optim, tokenizer=tokenizer, model_transform=model_transform) @io.model_importer(GemmaModel, "hf") @@ -172,11 +174,11 @@ def convert_state(self, source, target): @property def tokenizer(self): - return io.load_ckpt(str(self)).model.tokenizer.tokenizer + return io.load_context(str(self)).model.tokenizer.tokenizer @property def config(self) -> "GemmaConfig": - source: GemmaConfig = io.load_ckpt(str(self)).model.config + source: GemmaConfig = io.load_context(str(self)).model.config from transformers import GemmaConfig as HFGemmaConfig diff --git a/nemo/collections/llm/gpt/model/llama.py b/nemo/collections/llm/gpt/model/llama.py index aa089b0770416..c7add828b7f42 100644 --- a/nemo/collections/llm/gpt/model/llama.py +++ b/nemo/collections/llm/gpt/model/llama.py @@ -4,6 +4,7 @@ import torch import torch.nn.functional as F +from torch import nn from nemo.collections.llm.gpt.model.base import GPTConfig, GPTModel from nemo.collections.llm.utils import Config @@ -103,8 +104,9 @@ def __init__( config: Annotated[Optional[LlamaConfig], Config[LlamaConfig]] = None, optim: Optional[OptimizerModule] = None, tokenizer: Optional["TokenizerSpec"] = None, + model_transform: Optional[Callable[[nn.Module], nn.Module]] = None, ): - super().__init__(config or LlamaConfig(), optim=optim, tokenizer=tokenizer) + super().__init__(config or LlamaConfig(), optim=optim, tokenizer=tokenizer, model_transform=model_transform) @io.model_importer(LlamaModel, "hf") @@ -209,11 +211,11 @@ def convert_state(self, source, target): @property def tokenizer(self): - return io.load_ckpt(str(self)).model.tokenizer.tokenizer + return io.load_context(str(self)).model.tokenizer.tokenizer @property def config(self) -> "HFLlamaConfig": - source: LlamaConfig = io.load_ckpt(str(self)).model.config + source: LlamaConfig = io.load_context(str(self)).model.config from transformers import LlamaConfig as HFLlamaConfig diff --git a/nemo/collections/llm/gpt/model/mistral_7b.py b/nemo/collections/llm/gpt/model/mistral.py similarity index 89% rename from nemo/collections/llm/gpt/model/mistral_7b.py rename to nemo/collections/llm/gpt/model/mistral.py index ff9591581f86f..d1049cfe77cec 100644 --- a/nemo/collections/llm/gpt/model/mistral_7b.py +++ b/nemo/collections/llm/gpt/model/mistral.py @@ -5,12 +5,13 @@ import pytorch_lightning as pl import torch import torch.nn.functional as F +from torch import nn from typing_extensions import Annotated from nemo.collections.llm.gpt.model.base import GPTConfig, GPTModel from nemo.collections.llm.utils import Config from nemo.lightning import io, teardown -from nemo.lightning.pytorch.opt import OptimizerModule +from nemo.lightning.pytorch.optim import OptimizerModule if TYPE_CHECKING: from transformers import MistralConfig, MistralForCausalLM @@ -20,7 +21,7 @@ @dataclass -class Mistral7BConfig(GPTConfig): +class MistralConfig7B(GPTConfig): normalization: str = "RMSNorm" activation_func: Callable = F.silu position_embedding_type: str = "rope" @@ -40,20 +41,23 @@ class Mistral7BConfig(GPTConfig): window_size: List[int] = field(default_factory=lambda: [4096, 0]) -class Mistral7BModel(GPTModel): +class MistralModel(GPTModel): def __init__( self, - config: Annotated[Optional[Mistral7BConfig], Config[Mistral7BConfig]] = None, + config: Annotated[Optional[MistralConfig7B], Config[MistralConfig7B]] = None, optim: Optional[OptimizerModule] = None, tokenizer: Optional["TokenizerSpec"] = None, + model_transform: Optional[Callable[[nn.Module], nn.Module]] = None, ): - super().__init__(config or Mistral7BConfig(), optim=optim, tokenizer=tokenizer) + super().__init__( + config or MistralConfig7B(), optim=optim, tokenizer=tokenizer, model_transform=model_transform + ) -@io.model_importer(Mistral7BModel, "hf") -class HFMistral7BImporter(io.ModelConnector["MistralForCausalLM", Mistral7BModel]): - def init(self) -> Mistral7BModel: - return Mistral7BModel(self.config, tokenizer=self.tokenizer) +@io.model_importer(MistralModel, "hf") +class HFMistralImporter(io.ModelConnector["MistralForCausalLM", MistralModel]): + def init(self) -> MistralModel: + return MistralModel(self.config, tokenizer=self.tokenizer) def apply(self, output_path: Path) -> Path: from transformers import MistralForCausalLM @@ -91,7 +95,7 @@ def tokenizer(self) -> "AutoTokenizer": return AutoTokenizer(str(self)) @property - def config(self) -> Mistral7BConfig: + def config(self) -> MistralConfig7B: from transformers import MistralConfig source = MistralConfig.from_pretrained(str(self)) @@ -102,7 +106,7 @@ def make_vocab_size_divisible_by(mistral_vocab_size): base //= 2 return base - output = Mistral7BConfig( + output = MistralConfig7B( seq_length=source.sliding_window, num_layers=source.num_hidden_layers, hidden_size=source.hidden_size, @@ -122,8 +126,8 @@ def make_vocab_size_divisible_by(mistral_vocab_size): return output -@io.model_exporter(Mistral7BModel, "hf") -class HFMistral7BExporter(io.ModelConnector[Mistral7BModel, "MistralForCausalLM"]): +@io.model_exporter(MistralModel, "hf") +class HFMistralExporter(io.ModelConnector[MistralModel, "MistralForCausalLM"]): def init(self) -> "MistralForCausalLM": from transformers import AutoModelForCausalLM @@ -159,15 +163,15 @@ def convert_state(self, source, target): @property def tokenizer(self): - return io.load_ckpt(str(self)).model.tokenizer.tokenizer + return io.load_context(str(self)).model.tokenizer.tokenizer @property def config(self) -> "MistralConfig": - source: Mistral7BConfig = io.load_ckpt(str(self)).model.config + source: MistralConfig7B = io.load_context(str(self)).model.config - from transformers import MistralConfig + from transformers import MistralConfig as HfMistralConfig - return MistralConfig( + return HfMistralConfig( sliding_window=source.window_size[0], num_hidden_layers=source.num_layers, hidden_size=source.hidden_size, diff --git a/nemo/collections/llm/gpt/model/mixtral.py b/nemo/collections/llm/gpt/model/mixtral.py index 424fab8c37982..6256b67515ee2 100644 --- a/nemo/collections/llm/gpt/model/mixtral.py +++ b/nemo/collections/llm/gpt/model/mixtral.py @@ -4,19 +4,21 @@ import torch import torch.nn.functional as F +from torch import nn from nemo.collections.llm.gpt.model.base import GPTConfig, GPTModel from nemo.lightning import io, teardown -from nemo.lightning.pytorch.opt import OptimizerModule +from nemo.lightning.pytorch.optim import OptimizerModule if TYPE_CHECKING: - from transformers import MistralConfig, MistralForCausalLM + from transformers import MixtralForCausalLM from nemo.collections.common.tokenizers.huggingface.auto_tokenizer import AutoTokenizer + from nemo.collections.common.tokenizers.tokenizer_spec import TokenizerSpec @dataclass -class MixtralConfig(GPTConfig): +class MixtralConfig8x7B(GPTConfig): """ Config for Mixtral-8x7B model Official announcement: https://mistral.ai/news/mixtral-of-experts/ @@ -50,11 +52,14 @@ class MixtralConfig(GPTConfig): class MixtralModel(GPTModel): def __init__( self, - config: Optional[MixtralConfig] = None, + config: Optional[MixtralConfig8x7B] = None, optim: Optional[OptimizerModule] = None, tokenizer: Optional["TokenizerSpec"] = None, + model_transform: Optional[Callable[[nn.Module], nn.Module]] = None, ): - super().__init__(config or MixtralConfig(), optim=optim, tokenizer=tokenizer) + super().__init__( + config or MixtralConfig8x7B(), optim=optim, tokenizer=tokenizer, model_transform=model_transform + ) @io.model_importer(MixtralModel, ext="hf") @@ -99,11 +104,11 @@ def tokenizer(self) -> "AutoTokenizer": return AutoTokenizer(str(self)) @property - def config(self) -> MixtralConfig: + def config(self) -> MixtralConfig8x7B: from transformers import MixtralConfig as HfMixtralConfig config = HfMixtralConfig.from_pretrained(str(self)) - return MixtralConfig( + return MixtralConfig8x7B( activation_func=F.silu, # network num_layers=config.num_hidden_layers, @@ -181,3 +186,122 @@ def _import_qkv(ctx: io.TransformCTX, q, k, v): ) def _import_moe_w1_w3(gate_proj, up_proj): return torch.cat((gate_proj, up_proj), axis=0) + + +@io.model_exporter(MixtralModel, "hf") +class HFMixtralExporter(io.ModelConnector[MixtralModel, "MixtralForCausalLM"]): + def init(self) -> "MixtralForCausalLM": + from transformers import AutoModelForCausalLM + + return AutoModelForCausalLM.from_config(self.config) + + def apply(self, output_path: Path) -> Path: + # TODO: Make it work with lazy init + # with torch.device("meta"): + # target = self.init() + target = self.init() + source, _ = self.nemo_load(str(self)) + target = self.convert_state(source, target) + + # TODO: Make sure we don't need to do this + target = target.cpu() + target.save_pretrained(output_path) + self.tokenizer.save_pretrained(output_path) + + return output_path + + def convert_state(self, source, target): + mapping = { + "embedding.word_embeddings.weight": "model.embed_tokens.weight", + "decoder.layers.*.self_attention.linear_proj.weight": "model.layers.*.self_attn.o_proj.weight", + "decoder.layers.*.self_attention.linear_qkv.layer_norm_weight": "model.layers.*.input_layernorm.weight", + "decoder.layers.*.pre_mlp_layernorm.weight": "model.layers.*.post_attention_layernorm.weight", + # MoE + "decoder.layers.*.mlp.experts.local_experts.*.linear_fc2.weight": "model.layers.*.block_sparse_moe.experts.*.w2.weight", + "decoder.layers.*.mlp.router.weight": "model.layers.*.block_sparse_moe.gate.weight", + # lm-head + "decoder.final_layernorm.weight": "model.norm.weight", + "output_layer.weight": "lm_head.weight", + } + + return io.apply_transforms(source, target, mapping=mapping, transforms=[_export_qkv, _export_moe_w1_w3]) + + @property + def tokenizer(self): + return io.load_ckpt(str(self)).model.tokenizer.tokenizer + + @property + def config(self) -> "MixtralConfig": + source: MixtralConfig7B = io.load_ckpt(str(self)).model.config + + from transformers import MixtralConfig as HfMixtralConfig + + return HfMixtralConfig( + num_hidden_layers=source.num_layers, + hidden_size=source.hidden_size, + intermediate_size=source.ffn_hidden_size, + max_position_embeddings=source.max_position_embeddings, + seq_length=source.max_position_embeddings, + # RoPe + rope_theta=source.rotary_base, + # transformer config + num_attention_heads=source.num_attention_heads, + num_key_value_heads=source.num_query_groups, + num_local_experts=config.num_moe_experts, + num_experts_per_tok=config.moe_router_topk, + # norm + rms_norm_eps=source.layernorm_epsilon, + # init + initializer_range=source.init_method_std, + # vocab + vocab_size=self.tokenizer.vocab_size, + ) + + +@io.state_transform( + source_key="decoder.layers.*.self_attention.linear_qkv.weight", + target_key=( + "model.layers.*.self_attn.q_proj.weight", + "model.layers.*.self_attn.k_proj.weight", + "model.layers.*.self_attn.v_proj.weight", + ), +) +def _export_qkv(ctx: io.TransformCTX, linear_qkv): + megatron_config = ctx.source.config + + head_num = megatron_config.num_attention_heads + num_query_groups = megatron_config.num_query_groups + heads_per_group = head_num // num_query_groups + hidden_size = megatron_config.hidden_size + head_num = megatron_config.num_attention_heads + head_size = hidden_size // head_num + qkv_total_dim = head_num + 2 * num_query_groups + + linear_qkv = linear_qkv.reshape([qkv_total_dim, head_size, hidden_size]) + q_slice = torch.cat( + [ + torch.arange((heads_per_group + 2) * i, (heads_per_group + 2) * i + heads_per_group) + for i in range(num_query_groups) + ] + ) + k_slice = torch.arange(heads_per_group, qkv_total_dim, (heads_per_group + 2)) + v_slice = torch.arange(heads_per_group + 1, qkv_total_dim, (heads_per_group + 2)) + + q_proj = linear_qkv[q_slice].reshape(-1, hidden_size).cpu() + k_proj = linear_qkv[k_slice].reshape(-1, hidden_size).cpu() + v_proj = linear_qkv[v_slice].reshape(-1, hidden_size).cpu() + + return q_proj, k_proj, v_proj + + +@io.state_transform( + source_key="decoder.layers.*.mlp.experts.local_experts.*.linear_fc1.weight", + target_key=( + "model.layers.*.block_sparse_moe.experts.*.w1.weight", + "model.layers.*.block_sparse_moe.experts.*.w3.weight", + ), +) +def _export_moe_w1_w3(linear_fc1): + gate_proj, up_proj = torch.chunk(linear_fc1, 2, dim=0) + + return gate_proj, up_proj diff --git a/nemo/collections/llm/peft/__init__.py b/nemo/collections/llm/peft/__init__.py new file mode 100644 index 0000000000000..69855f6f9c532 --- /dev/null +++ b/nemo/collections/llm/peft/__init__.py @@ -0,0 +1,4 @@ +from nemo.collections.llm.peft.api import gpt_lora +from nemo.collections.llm.peft.lora import LoRA + +__all__ = ["LoRA", "gpt_lora"] diff --git a/nemo/collections/llm/peft/api.py b/nemo/collections/llm/peft/api.py new file mode 100644 index 0000000000000..dc8fc76c752e0 --- /dev/null +++ b/nemo/collections/llm/peft/api.py @@ -0,0 +1,11 @@ +from nemo.collections.llm.peft.lora import LoRA +from nemo.collections.llm.utils import factory +from nemo.lightning.pytorch.callbacks.peft import PEFT + + +@factory +def gpt_lora() -> PEFT: + return LoRA() + + +__all__ = ["gpt_lora"] diff --git a/nemo/collections/llm/peft/lora.py b/nemo/collections/llm/peft/lora.py new file mode 100644 index 0000000000000..913144d1bf5fa --- /dev/null +++ b/nemo/collections/llm/peft/lora.py @@ -0,0 +1,123 @@ +from dataclasses import dataclass, field +from typing import List, Literal + +from megatron.core import parallel_state +from torch import nn + +from nemo.lightning.pytorch.callbacks.peft import PEFT, AdapterWrapper +from nemo.utils import logging + + +class AdapterParallelAdd(AdapterWrapper): + """An adapter wrapper that adds the output of the adapter to the output of the wrapped module. + + This class is designed to be used with LoRA (Low-Rank Adaptation) and similar techniques + where the adapter's output is added to the main module's output. It extends the AdapterWrapper + class to provide a specific implementation of the forward method. + """ + + def forward(self, x): + linear_output, bias = self.to_wrap(x) + if isinstance(linear_output, tuple) and len(linear_output) == 2: + linear_output, layernorm_output = linear_output + adapter_output = self.adapter(layernorm_output) + else: + adapter_output = self.adapter(x) + return linear_output + adapter_output, bias + + +@dataclass +class LoRA(PEFT): + """ + Implements the LoRA (Low-Rank Adaptation) module for parameter-efficient fine-tuning. + + LoRA uses a low-rank projection to adapt the weights of a pre-trained model to a new downstream task. + This class facilitates the application of LoRA to specific modules within the model architecture. + + Args: + target_modules (List[str], optional): A list of module names to apply LoRA to. + Defaults to all linear layers ['linear_qkv', 'linear_proj', 'linear_fc1', 'linear_fc2']. + - 'linear_qkv': Apply LoRA to the fused linear layer used for query, key, and value projections + in self-attention modules. + - 'linear_proj': Apply LoRA to the linear layer used for projecting the output of self-attention modules. + - 'linear_fc1': Apply LoRA to the first fully-connected layer in MLP. + - 'linear_fc2': Apply LoRA to the second fully-connected layer in MLP. + dim (int): Dimension of the low-rank projection space. Defaults to 32. + alpha (int): Weighting factor for the low-rank projection. Defaults to 32. + dropout (float): Dropout rate for the low-rank projection. Defaults to 0.0. + dropout_position (Literal['pre', 'post'], optional): Position for applying dropout. + Can be 'pre' (before the low-rank projection) or 'post' (after). Defaults to 'post'. + + Example: + -------- + >>> from nemo.collections import llm + >>> lora = llm.peft.LoRA(target_modules=['linear_qkv', 'linear_proj'], dim=32) + >>> model = llm.Mistral7BModel(model_transform=lora) + >>> # (set up trainer and data) + >>> trainer.fit(model, data) + + References: + ----------- + Hu, E. J., Shen, Y., Wallis, P., Allen-Zhu, Z., Li, Y., Wang, S., Wang, L., & Chen, W. (2021). + LoRA: Low-Rank Adaptation of Large Language Models. arXiv preprint arXiv:2106.09685. + https://arxiv.org/abs/2106.09685 + + ) + """ + + target_modules: List[str] = field( + default_factory=lambda: ['linear_qkv', 'linear_proj', 'linear_fc1', 'linear_fc2'] + ) + dim: int = 32 + alpha: int = 32 + dropout: float = 0.0 + dropout_position: Literal['pre', 'post'] = 'post' + + def transform(self, m: nn.Module, name=None, prefix=None): + """ + Applies LoRA to a specific module within the model architecture. + + Args: + m (nn.Module): The module to apply LoRA to. + name (str, optional): Name of the module (if applicable). Defaults to None. + prefix (str, optional): Prefix for the module name (if applicable). Defaults to None. + + Returns: + nn.Module: The modified module with LoRA applied, or the original module if not a target. + """ + from nemo.collections.nlp.modules.common.megatron.adapters.parallel_adapters import ParallelLinearAdapter + + tp_size = parallel_state.get_tensor_model_parallel_world_size() + if name in self.target_modules: + # m.in_features and m.out_features are divided by tp_size already, + # but in_features and out_features passed to ParallelLinearAdapter are not. + if name in ['linear_qkv', 'linear_fc1']: + # Column Parallel Linear + input_is_parallel = False + in_features = m.in_features + out_features = m.out_features * tp_size + else: # name in ['linear_proj', 'linear_fc2'] + # Row Parallel Linear + input_is_parallel = True + in_features = m.in_features * tp_size + out_features = m.out_features + + logging.info(f"Adding lora to: {prefix}.{name}") + adapter = ParallelLinearAdapter( + in_features, + out_features, + self.dim, + activation='identity', + norm_position=None, + norm_type=None, + column_init_method="normal", + row_init_method="zero", + gather_output=False, + input_is_parallel=input_is_parallel, + dropout=self.dropout, + dropout_position=self.dropout_position, + model_parallel_config=getattr(m, "config", None), + alpha=self.alpha, + ) + return AdapterParallelAdd(m, adapter) + return m diff --git a/nemo/collections/llm/tokenizer.py b/nemo/collections/llm/tokenizer.py new file mode 100644 index 0000000000000..3943e24ba7991 --- /dev/null +++ b/nemo/collections/llm/tokenizer.py @@ -0,0 +1,27 @@ +from nemo.lightning.io.artifact import FileArtifact +from nemo.lightning.io.mixin import track_io + +__all__ = [] + +try: + from nemo.collections.common.tokenizers import AutoTokenizer + + track_io( + AutoTokenizer, + artifacts=[ + FileArtifact("vocab_file"), + FileArtifact("merges_file"), + ], + ) + __all__.append("AutoTokenizer") +except ImportError: + pass + + +try: + from nemo.collections.common.tokenizers import SentencePieceTokenizer + + track_io(SentencePieceTokenizer, artifacts=[FileArtifact("model_path")]) + __all__.append("SentencePieceTokenizer") +except ImportError: + pass diff --git a/nemo/collections/llm/utils.py b/nemo/collections/llm/utils.py index c108d86c2e1b8..b4382d0afd5f6 100644 --- a/nemo/collections/llm/utils.py +++ b/nemo/collections/llm/utils.py @@ -1,4 +1,4 @@ -from typing import Any, Callable, Generic, TypeVar +from typing import Any, Callable, Generic, TypeVar, Union, overload T = TypeVar('T', bound=Callable[..., Any]) @@ -28,3 +28,32 @@ def noop_decorator(func: T) -> T: return func return noop_decorator + + +@overload +def factory() -> Callable[[T], T]: ... + + +@overload +def factory(*args: Any, **kwargs: Any) -> Callable[[T], T]: ... + + +def factory(*args: Any, **kwargs: Any) -> Union[Callable[[T], T], T]: + try: + import nemo_sdk as sdk + + if not args and not kwargs: + # Used as @factory without arguments + return sdk.factory() + else: + # Used as @factory(*args, **kwargs) + return sdk.factory(*args, **kwargs) + except ImportError: + # Return a no-op function + def noop_decorator(func: T) -> T: + return func + + if not args and not kwargs: + return noop_decorator + else: + return noop_decorator diff --git a/nemo/collections/multimodal/data/clip/clip_dataset.py b/nemo/collections/multimodal/data/clip/clip_dataset.py index 7e263e19dcc95..6b63d546194a5 100644 --- a/nemo/collections/multimodal/data/clip/clip_dataset.py +++ b/nemo/collections/multimodal/data/clip/clip_dataset.py @@ -76,11 +76,18 @@ def get_preprocess_fns(model_cfg, tokenizer=None, is_train=True): img_size = (model_cfg.vision.get("img_h"), model_cfg.vision.get("img_w")) img_mean = model_cfg.vision.get("img_mean") img_std = model_cfg.vision.get("img_std") - img_transform = image_transform(img_size, is_train=is_train, mean=img_mean, std=img_std,) + img_transform = image_transform( + img_size, + is_train=is_train, + mean=img_mean, + std=img_std, + ) text_transform = lambda x: x if tokenizer is not None: text_transform = partial( - tokenize, tokenizer=tokenizer, context_length=model_cfg.text.get("max_position_embeddings"), + tokenize, + tokenizer=tokenizer, + context_length=model_cfg.text.get("max_position_embeddings"), ) return img_transform, text_transform @@ -100,7 +107,9 @@ def transform_fn(sample, img_transform, text_transform): def build_train_valid_datasets( - model_cfg, consumed_samples, tokenizer=None, + model_cfg, + consumed_samples, + tokenizer=None, ): data_cfg = model_cfg.data @@ -127,6 +136,13 @@ def build_train_valid_datasets( return train_data, val_data +def custom_collate(batch): + if len(batch) == 0: + return None, None + else: + return default_collate(batch) + + # For zero-shot imagenet validation def build_imagenet_validation_dataloader(model_cfg, tokenizer=None): val_image_transform, text_transform = get_preprocess_fns(model_cfg, tokenizer, is_train=False) @@ -138,7 +154,10 @@ def build_imagenet_validation_dataloader(model_cfg, tokenizer=None): if imagenet_path is None: return None - image_dataset = ImageFolder(root=imagenet_path, transform=val_image_transform,) + image_dataset = ImageFolder( + root=imagenet_path, + transform=val_image_transform, + ) image_batch_sampler = MegatronPretrainingSampler( total_samples=len(image_dataset), @@ -150,12 +169,6 @@ def build_imagenet_validation_dataloader(model_cfg, tokenizer=None): drop_last=False, ) - def custom_collate(batch): - if len(batch) == 0: - return None, None - else: - return default_collate(batch) - imagenet_val["images"] = torch.utils.data.DataLoader( image_dataset, batch_sampler=image_batch_sampler, diff --git a/nemo/collections/multimodal/data/neva/conversation.py b/nemo/collections/multimodal/data/neva/conversation.py index 43b1977aa993a..10a6c9e7283dc 100644 --- a/nemo/collections/multimodal/data/neva/conversation.py +++ b/nemo/collections/multimodal/data/neva/conversation.py @@ -43,6 +43,7 @@ class SeparatorStyle(Enum): PLAIN = auto() LLAMA_2 = auto() LLAMA_3 = auto() + MISTRAL = auto() NVGPT = auto() @@ -94,11 +95,15 @@ def get_prompt(self): ret += " " else: ret += role + ":" - elif self.sep_style == SeparatorStyle.LLAMA_2: - wrap_sys = lambda msg: f"<>\n{msg}\n<>\n\n" + elif self.sep_style == SeparatorStyle.LLAMA_2 or self.sep_style == SeparatorStyle.MISTRAL: + if self.sep_style == SeparatorStyle.LLAMA_2: + wrap_sys = lambda msg: f"<>\n{msg}\n<>\n\n" + else: + wrap_sys = lambda msg: f"{msg}" + ("\n" if msg else "") wrap_inst = lambda msg: f"[INST] {msg} [/INST]" ret = "" - + if self.sep_style == SeparatorStyle.MISTRAL: + ret += DEFAULT_BOS_TOKEN for i, (role, message) in enumerate(messages): if i == 0: assert message, "first message should not be none" @@ -112,7 +117,10 @@ def get_prompt(self): message = wrap_inst(message) ret += self.sep + " " + message else: - ret += " " + message + " " + self.sep2 + if self.sep_style == SeparatorStyle.LLAMA_2: + ret += " " + message + " " + self.sep2 + else: + ret += message + self.sep2 else: ret += "" ret = ret.lstrip(self.sep) @@ -449,6 +457,17 @@ def dict(self): version="v1_mmtag", ) +conv_mistral = Conversation( + system="", + roles=("USER", "ASSISTANT"), + version="mistral", + messages=(), + offset=0, + sep_style=SeparatorStyle.MISTRAL, + sep="", + sep2=DEFAULT_EOS_TOKEN, +) + default_conversation = conv_vicuna_v1 conv_templates = { "default": conv_vicuna_v0, @@ -466,6 +485,7 @@ def dict(self): "nvgpt": conv_nvgpt, "nv_steerlm": conv_nvgpt, "nv_dpo": conv_nv_dpo, + "mistral": conv_mistral, } if __name__ == "__main__": diff --git a/nemo/collections/multimodal/data/neva/neva_dataset.py b/nemo/collections/multimodal/data/neva/neva_dataset.py index 86d45ded54cfd..7eef677e13a8b 100644 --- a/nemo/collections/multimodal/data/neva/neva_dataset.py +++ b/nemo/collections/multimodal/data/neva/neva_dataset.py @@ -426,6 +426,7 @@ def preprocess_llama_2( sources: dict, tokenizer, cfg, + is_mistral: bool = False, ) -> Dict: """ Preprocesses sources for the LLaMA 2 model configuration. @@ -442,7 +443,10 @@ def preprocess_llama_2( - Dict: A dictionary containing tokenized and labeled data suitable for the LLaMA 2 model. This includes tokens, labels, and any special processing as defined in the configuration. """ - conv = conversation_lib.conv_llava_llama_2.copy() + if is_mistral: + conv = conversation_lib.conv_mistral.copy() + else: + conv = conversation_lib.conv_llava_llama_2.copy() roles = {"human": conv.roles[0], "gpt": conv.roles[1]} # Apply prompt templates @@ -477,7 +481,10 @@ def preprocess_llama_2( labels = tokens.clone().detach() # Mask labels - sep = "[/INST] " + if is_mistral: + sep = "[/INST]" + else: + sep = "[/INST] " for conversation, target in zip(conversations, labels): rounds = conversation.split(conv.sep2) cur_len = 0 @@ -492,18 +499,23 @@ def preprocess_llama_2( parts[0] += sep round_len = len(tokenizer.text_to_ids(rou + conv.sep2)) - instruction_len = len(tokenizer.text_to_ids(parts[0])) - 2 + + if is_mistral: + instruction_len = len(tokenizer.text_to_ids(parts[0])) - 1 + else: + instruction_len = len(tokenizer.text_to_ids(parts[0])) - 2 + if i > 0: round_len -= 1 # Remove extra token added by sp tokenizer else: instruction_len += 1 target[cur_len : cur_len + instruction_len] = IGNORE_INDEX - cur_len += round_len target[cur_len:] = IGNORE_INDEX # Check if masking working correctly - # print([x for x in zip(tokens[0].numpy().tolist(), labels[0].numpy().tolist())]) + # masking_test =[x for x in zip(tokens[0].numpy().tolist(), labels[0].numpy().tolist())] + # print(masking_test) if add_extra_token: tokens = tokens[:, :-1].contiguous() @@ -990,7 +1002,10 @@ def expand2square(pil_img, background_color): result.paste(pil_img, ((height - width) // 2, 0)) return result - frames = expand2square(frames, tuple(int(x * 255) for x in self.processor.image_mean)) + frames = [ + expand2square(frame, tuple(int(x * 255) for x in self.processor.image_mean)) + for frame in frames + ] frames = self.processor.preprocess(frames, return_tensors='pt')['pixel_values'] else: frames = self.processor.preprocess(frames, return_tensors='pt')['pixel_values'] @@ -1057,6 +1072,13 @@ def expand2square(pil_img, background_color): self.tokenizer, self.multimodal_cfg, ) + elif self.conv_template == "mistral": + data_dict = preprocess_llama_2( + sources, + self.tokenizer, + self.multimodal_cfg, + is_mistral=True, + ) elif self.conv_template == "plain": data_dict = preprocess_plain( sources, diff --git a/nemo/collections/multimodal/losses/siglip_loss.py b/nemo/collections/multimodal/losses/siglip_loss.py new file mode 100644 index 0000000000000..a7d2ec9b46ce9 --- /dev/null +++ b/nemo/collections/multimodal/losses/siglip_loss.py @@ -0,0 +1,220 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This file contains code artifacts adapted from the original implementation: +# https://github.com/mlfoundations/open_clip/blob/main/src/open_clip/loss.py + +import torch +import torch.nn.functional as F + + +def neighbour_exchange(from_rank, to_rank, tensor, group=None): + tensor_recv = torch.zeros_like(tensor) + send_op = torch.distributed.P2POp( + torch.distributed.isend, + tensor, + to_rank, + group=group, + ) + recv_op = torch.distributed.P2POp( + torch.distributed.irecv, + tensor_recv, + from_rank, + group=group, + ) + reqs = torch.distributed.batch_isend_irecv([send_op, recv_op]) + for req in reqs: + req.wait() + return tensor_recv + + +def neighbour_exchange_bidir(left_rank, right_rank, tensor_to_left, tensor_to_right, group=None): + tensor_from_left = torch.zeros_like(tensor_to_right) + tensor_from_right = torch.zeros_like(tensor_to_left) + send_op_left = torch.distributed.P2POp( + torch.distributed.isend, + tensor_to_left, + left_rank, + group=group, + ) + send_op_right = torch.distributed.P2POp( + torch.distributed.isend, + tensor_to_right, + right_rank, + group=group, + ) + recv_op_left = torch.distributed.P2POp( + torch.distributed.irecv, + tensor_from_left, + left_rank, + group=group, + ) + recv_op_right = torch.distributed.P2POp( + torch.distributed.irecv, + tensor_from_right, + right_rank, + group=group, + ) + reqs = torch.distributed.batch_isend_irecv([send_op_right, send_op_left, recv_op_right, recv_op_left]) + for req in reqs: + req.wait() + return tensor_from_right, tensor_from_left + + +class NeighbourExchange(torch.autograd.Function): + @staticmethod + def forward(ctx, from_rank, to_rank, group, tensor): + ctx.group = group + ctx.from_rank = from_rank + ctx.to_rank = to_rank + return neighbour_exchange(from_rank, to_rank, tensor, group=group) + + @staticmethod + def backward(ctx, grad_output): + return (None, None, None) + (NeighbourExchange.apply(ctx.to_rank, ctx.from_rank, ctx.group, grad_output),) + + +def neighbour_exchange_with_grad(from_rank, to_rank, tensor, group=None): + return NeighbourExchange.apply(from_rank, to_rank, group, tensor) + + +class NeighbourExchangeBidir(torch.autograd.Function): + @staticmethod + def forward(ctx, left_rank, right_rank, group, tensor_to_left, tensor_to_right): + ctx.group = group + ctx.left_rank = left_rank + ctx.right_rank = right_rank + return neighbour_exchange_bidir(left_rank, right_rank, tensor_to_left, tensor_to_right, group=group) + + @staticmethod + def backward(ctx, *grad_outputs): + return (None, None, None) + NeighbourExchangeBidir.apply( + ctx.right_rank, ctx.left_rank, ctx.group, *grad_outputs + ) + + +def neighbour_exchange_bidir_with_grad(left_rank, right_rank, tensor_to_left, tensor_to_right, group=None): + return NeighbourExchangeBidir.apply(left_rank, right_rank, group, tensor_to_left, tensor_to_right) + + +class SigLipLoss(torch.nn.Module): + """Sigmoid Loss for Language Image Pre-Training (SigLIP) - https://arxiv.org/abs/2303.15343 + + @article{zhai2023sigmoid, + title={Sigmoid loss for language image pre-training}, + author={Zhai, Xiaohua and Mustafa, Basil and Kolesnikov, Alexander and Beyer, Lucas}, + journal={arXiv preprint arXiv:2303.15343}, + year={2023} + } + """ + + def __init__( + self, + cache_labels=False, + rank=0, + world_size=1, + group=None, + bidir=True, + ): + super().__init__() + self.cache_labels = cache_labels + self.rank = rank + self.world_size = world_size + self.group = group + self.bidir = bidir + + def get_ground_truth(self, device, dtype, num_logits, negative_only=False) -> torch.Tensor: + labels = -torch.ones((num_logits, num_logits), device=device, dtype=dtype) + if not negative_only: + labels = 2 * torch.eye(num_logits, device=device, dtype=dtype) + labels + return labels + + def get_logits(self, image_features, text_features, logit_scale, logit_bias=None): + logits = logit_scale * image_features @ text_features.T + if logit_bias is not None: + logits += logit_bias + return logits + + def _loss(self, image_features, text_features, logit_scale, logit_bias=None, negative_only=False): + logits = self.get_logits(image_features, text_features, logit_scale, logit_bias) + labels = self.get_ground_truth( + image_features.device, + image_features.dtype, + image_features.shape[0], + negative_only=negative_only, + ) + loss = -F.logsigmoid(labels * logits).sum() / image_features.shape[0] + return loss + + def forward( + self, + output_tensor, + ): + image_features, text_features, logit_scale, logit_bias = output_tensor + loss = self._loss(image_features, text_features, logit_scale, logit_bias) + + if self.world_size > 1: + # exchange text features w/ neighbour world_size - 1 times + right_rank = (self.rank + 1) % self.world_size + left_rank = (self.rank - 1 + self.world_size) % self.world_size + if self.bidir: + text_features_to_right = text_features_to_left = text_features + num_bidir, remainder = divmod(self.world_size - 1, 2) + for i in range(num_bidir): + text_features_recv = neighbour_exchange_bidir_with_grad( + left_rank, + right_rank, + text_features_to_left, + text_features_to_right, + group=self.group, + ) + + for f in text_features_recv: + loss += self._loss( + image_features, + f, + logit_scale, + logit_bias, + negative_only=True, + ) + text_features_to_left, text_features_to_right = text_features_recv + + if remainder: + text_features_recv = neighbour_exchange_with_grad( + left_rank, right_rank, text_features_to_right, group=self.group + ) + + loss += self._loss( + image_features, + text_features_recv, + logit_scale, + logit_bias, + negative_only=True, + ) + else: + text_features_to_right = text_features + for i in range(self.world_size - 1): + text_features_from_left = neighbour_exchange_with_grad( + left_rank, right_rank, text_features_to_right, group=self.group + ) + + loss += self._loss( + image_features, + text_features_from_left, + logit_scale, + logit_bias, + negative_only=True, + ) + text_features_to_right = text_features_from_left + return loss, {"loss": loss} diff --git a/nemo/collections/multimodal/models/multimodal_llm/neva/neva_model.py b/nemo/collections/multimodal/models/multimodal_llm/neva/neva_model.py index cce40da457253..376237e89ecc6 100644 --- a/nemo/collections/multimodal/models/multimodal_llm/neva/neva_model.py +++ b/nemo/collections/multimodal/models/multimodal_llm/neva/neva_model.py @@ -75,7 +75,7 @@ HAVE_APEX = False try: - from megatron.core import InferenceParams, dist_checkpointing, parallel_state + from megatron.core import InferenceParams, dist_checkpointing, parallel_state, tensor_parallel from megatron.core.models.gpt import GPTModel as MCoreGPTModel from megatron.core.pipeline_parallel.schedules import get_forward_backward_func from megatron.core.transformer.utils import make_sharded_tensors_for_checkpoint @@ -154,10 +154,34 @@ def set_media(self, media): self.media = media def forward(self, input_ids, **kwargs): - media = self.media # avoid change the signature of embedding forward function + media = self.media # avoid changing the signature of embedding forward function + + # TODO: Refactor replace_media_embedding to account for MCore's embedding communication optimization + # https://github.com/NVIDIA/Megatron-LM/commit/ee423e7 changes the way we handle embeddings with sequence parallelism + # When using reduce_scatter_embeddings, word_embedding_tensor is now in the following shape: [sequence/tp, batch_size, hidden_size] + # replace_media_embedding currently expects [batch_size, sequence, hidden_size] + + # Check if reduce_scatter_embeddings is enabled in the embedding forward function + apply_reduce_scatter = getattr(self, 'reduce_scatter_embeddings', False) + + # Set reduce_scatter_embeddings to false to keep words_embedding's + # tensor dimesion the same for replace_media_embedding + if apply_reduce_scatter: + self.reduce_scatter_embeddings = False + words_embeddings = super().forward(input_ids, **kwargs) + words_embeddings = self.replace_media_embeddings(input_ids, words_embeddings, media) - return self.replace_media_embeddings(input_ids, words_embeddings, media) + # Scatter embeddings back to each TP rank if reduce_scatter_embeddings is enabled + if apply_reduce_scatter: + words_embeddings = self._apply_reduce_scatter(words_embeddings) + self.reduce_scatter_embeddings = True + + return words_embeddings + + def _apply_reduce_scatter(self, embeddings): + embeddings = embeddings.transpose(0, 1).contiguous() + return tensor_parallel.mappings.scatter_to_sequence_parallel_region(embeddings) def encode_vision_x(self, vision_x: torch.Tensor): """ @@ -193,7 +217,6 @@ def encode_vision_x(self, vision_x: torch.Tensor): def replace_media_embeddings(self, input_ids, inputs_embeds, media): if media is None: return inputs_embeds - batch_size, sequence_length, hidden_size = inputs_embeds.shape # calculate media features without gradients @@ -550,7 +573,12 @@ def dummy(): media_end_id=media_end_id, mcore_gpt=self.mcore_gpt, config=self.transformer_config, - transformer_layer_spec=get_specs(self.spec_name), + transformer_layer_spec=get_specs( + self.spec_name, + self.transformer_config.num_moe_experts, + self.transformer_config.moe_grouped_gemm, + self.transformer_engine, + ), vocab_size=self.cfg.get('override_vocab_size', self.padded_vocab_size), max_sequence_length=self.cfg.get('encoder_seq_length', 512), pre_process=pre_process, diff --git a/nemo/collections/multimodal/models/text_to_image/stable_diffusion/diffusion_engine.py b/nemo/collections/multimodal/models/text_to_image/stable_diffusion/diffusion_engine.py index efc1550113a01..755588202ef0d 100644 --- a/nemo/collections/multimodal/models/text_to_image/stable_diffusion/diffusion_engine.py +++ b/nemo/collections/multimodal/models/text_to_image/stable_diffusion/diffusion_engine.py @@ -119,7 +119,9 @@ def __init__(self, cfg, model_parallel_config): self._init_first_stage(first_stage_config) self.model_type = None - self.rng = torch.Generator(device=torch.cuda.current_device(),) + self.rng = torch.Generator( + device=torch.cuda.current_device(), + ) self.use_ema = False # TODO use_ema need to switch to NeMo style if self.use_ema: @@ -158,6 +160,13 @@ def decode_first_stage(self, z): out = self.first_stage_model.decode(z) return out + # same as above but differentiable + def differentiable_decode_first_stage(self, z): + z = 1.0 / self.scale_factor * z + with torch.autocast("cuda", enabled=not self.disable_first_stage_autocast): + out = self.first_stage_model.decode(z) + return out + @torch.no_grad() def encode_first_stage(self, x): with torch.autocast("cuda", enabled=not self.disable_first_stage_autocast): @@ -185,7 +194,12 @@ def training_step(self, batch, batch_idx): self.log_dict(loss_dict, prog_bar=True, logger=True, on_step=True, on_epoch=False) self.log( - "global_step", self.global_step, prog_bar=True, logger=True, on_step=True, on_epoch=False, + "global_step", + self.global_step, + prog_bar=True, + logger=True, + on_step=True, + on_epoch=False, ) if self.scheduler_config is not None: @@ -231,7 +245,11 @@ def configure_optimizers(self): scheduler = DiffusionEngine.from_config_dict(self.scheduler_config) print("Setting up LambdaLR scheduler...") scheduler = [ - {"scheduler": LambdaLR(opt, lr_lambda=scheduler.schedule), "interval": "step", "frequency": 1,} + { + "scheduler": LambdaLR(opt, lr_lambda=scheduler.schedule), + "interval": "step", + "frequency": 1, + } ] return [opt], scheduler return opt @@ -291,7 +309,14 @@ def set_input_tensor(self, input_tensor): pass @torch.no_grad() - def log_images(self, batch: Dict, N: int = 8, sample: bool = True, ucg_keys: List[str] = None, **kwargs,) -> Dict: + def log_images( + self, + batch: Dict, + N: int = 8, + sample: bool = True, + ucg_keys: List[str] = None, + **kwargs, + ) -> Dict: conditioner_input_keys = [e.input_key for e in self.conditioner.embedders] if ucg_keys: assert all(map(lambda x: x in conditioner_input_keys, ucg_keys)), ( @@ -305,7 +330,8 @@ def log_images(self, batch: Dict, N: int = 8, sample: bool = True, ucg_keys: Lis x = self.get_input(batch) c, uc = self.conditioner.get_unconditional_conditioning( - batch, force_uc_zero_embeddings=ucg_keys if len(self.conditioner.embedders) > 0 else [], + batch, + force_uc_zero_embeddings=ucg_keys if len(self.conditioner.embedders) > 0 else [], ) sampling_kwargs = {} @@ -400,7 +426,10 @@ def fwd_bwd_step(self, dataloader_iter, forward_only): # handle asynchronous grad reduction no_sync_func = None if not forward_only and self.with_distributed_adam: - no_sync_func = partial(self._optimizer.no_sync, greedy_grad_copy=self.megatron_amp_O2,) + no_sync_func = partial( + self._optimizer.no_sync, + greedy_grad_copy=self.megatron_amp_O2, + ) # pipeline schedules will get these from self.model.config for module in self.get_module_list(): @@ -438,12 +467,12 @@ def fwd_bwd_step(self, dataloader_iter, forward_only): def training_step(self, dataloader_iter): """ - Our dataloaders produce a micro-batch and then we fetch - a number of microbatches depending on the global batch size and model parallel size - from the dataloader to produce a list of microbatches. - Batch should be a list of microbatches and those microbatches should on CPU. - Microbatches are then moved to GPU during the pipeline. - The list of microbatches is then piped through the pipeline using Apex fwd/bwd functions. + Our dataloaders produce a micro-batch and then we fetch + a number of microbatches depending on the global batch size and model parallel size + from the dataloader to produce a list of microbatches. + Batch should be a list of microbatches and those microbatches should on CPU. + Microbatches are then moved to GPU during the pipeline. + The list of microbatches is then piped through the pipeline using Apex fwd/bwd functions. """ self._optimizer.zero_grad() @@ -491,20 +520,20 @@ def training_step(self, dataloader_iter): return loss_mean def backward(self, *args, **kwargs): - """ LightningModule hook to do backward. - We want this to do nothing since we run backward in the fwd/bwd functions from apex. - No need to call it here. + """LightningModule hook to do backward. + We want this to do nothing since we run backward in the fwd/bwd functions from apex. + No need to call it here. """ pass def optimizer_zero_grad(self, *args, **kwargs): - """ LightningModule hook to zero grad. - We want this to do nothing as we are zeroing grads during the training_step. + """LightningModule hook to zero grad. + We want this to do nothing as we are zeroing grads during the training_step. """ pass def _append_sequence_parallel_module_grads(self, module, grads): - """ Helper method for allreduce_sequence_parallel_gradients""" + """Helper method for allreduce_sequence_parallel_gradients""" for param in module.parameters(): sequence_parallel_param = getattr(param, 'sequence_parallel', False) @@ -517,12 +546,13 @@ def _append_sequence_parallel_module_grads(self, module, grads): def get_forward_output_and_loss_func(self): def process_batch(batch): - """ Prepares the global batch for apex fwd/bwd functions. - Global batch is a list of micro batches. + """Prepares the global batch for apex fwd/bwd functions. + Global batch is a list of micro batches. """ # SD has more dedicated structure for encoding, so we enable autocasting here as well with torch.cuda.amp.autocast( - self.autocast_dtype in (torch.half, torch.bfloat16), dtype=self.autocast_dtype, + self.autocast_dtype in (torch.half, torch.bfloat16), + dtype=self.autocast_dtype, ): if self.model.precache_mode == 'both': x = batch[self.model.input_key].to(torch.cuda.current_device()) @@ -565,7 +595,7 @@ def validation_step(self, dataloader_iter, batch_idx): return loss def setup(self, stage=None): - """ PTL hook that is executed after DDP spawns. + """PTL hook that is executed after DDP spawns. We setup datasets here as megatron datasets require DDP to instantiate. See https://pytorch-lightning.readthedocs.io/en/latest/common/lightning_module.html#setup for more information. Args: @@ -678,20 +708,23 @@ def setup_test_data(self, cfg): f'Setting up test dataloader with len(len(self._test_ds)): {len(self._test_ds)} and consumed samples: {consumed_samples}' ) self._test_dl = torch.utils.data.DataLoader( - self._test_ds, batch_size=self._micro_batch_size, num_workers=cfg.num_workers, pin_memory=True, + self._test_ds, + batch_size=self._micro_batch_size, + num_workers=cfg.num_workers, + pin_memory=True, ) def transfer_batch_to_device(self, batch: Any, device: torch.device, dataloader_idx: int) -> Any: - """ PTL hook: https://pytorch-lightning.readthedocs.io/en/latest/common/lightning_module.html#transfer-batch-to-device - When using pipeline parallelism, we need the global batch to remain on the CPU, - since the memory overhead will be too high when using a large number of microbatches. - Microbatches are transferred from CPU to GPU inside the pipeline. + """PTL hook: https://pytorch-lightning.readthedocs.io/en/latest/common/lightning_module.html#transfer-batch-to-device + When using pipeline parallelism, we need the global batch to remain on the CPU, + since the memory overhead will be too high when using a large number of microbatches. + Microbatches are transferred from CPU to GPU inside the pipeline. """ return batch def _validate_trainer(self): - """ Certain trainer configurations can break training. - Here we try to catch them and raise an error. + """Certain trainer configurations can break training. + Here we try to catch them and raise an error. """ if self.trainer.accumulate_grad_batches > 1: raise ValueError( diff --git a/nemo/collections/multimodal/models/text_to_image/stable_diffusion/ldm/autoencoder.py b/nemo/collections/multimodal/models/text_to_image/stable_diffusion/ldm/autoencoder.py index 6bd47a78fbcf4..d79d85c2e026d 100644 --- a/nemo/collections/multimodal/models/text_to_image/stable_diffusion/ldm/autoencoder.py +++ b/nemo/collections/multimodal/models/text_to_image/stable_diffusion/ldm/autoencoder.py @@ -16,6 +16,7 @@ import pytorch_lightning as pl import torch import torch.nn.functional as F +from nemo.utils import logging try: from taming.modules.vqvae.quantize import VectorQuantizer2 as VectorQuantizer @@ -316,6 +317,7 @@ def __init__( ignore_keys=[], image_key="image", colorize_nlabels=None, + from_NeMo=False, monitor=None, from_pretrained: str = None, ): @@ -337,6 +339,7 @@ def __init__( self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys) if from_pretrained is not None: + logging.info(f"Attempting to load vae weights from {from_pretrained}") if from_pretrained.endswith('safetensors'): from safetensors.torch import load_file as load_safetensors @@ -345,7 +348,7 @@ def __init__( state_dict = torch.load(from_pretrained) if 'state_dict' in state_dict: state_dict = state_dict['state_dict'] - missing_key, unexpected_key, _, _ = self._load_pretrained_model(state_dict) + missing_key, unexpected_key, _, _ = self._load_pretrained_model(state_dict, from_NeMo=from_NeMo) if len(missing_key) > 0: print( f'{self.__class__.__name__}: Following keys are missing during loading VAE weights, which may lead to compromised image quality for a resumed training. Please check the checkpoint you provided.' @@ -395,8 +398,9 @@ def _state_key_mapping(self, state_dict: dict): res_dict[key_] = val_ return res_dict - def _load_pretrained_model(self, state_dict, ignore_mismatched_sizes=False): - state_dict = self._state_key_mapping(state_dict) + def _load_pretrained_model(self, state_dict, ignore_mismatched_sizes=False, from_NeMo=False): + if not from_NeMo: + state_dict = self._state_key_mapping(state_dict) model_state_dict = self.state_dict() loaded_keys = [k for k in state_dict.keys()] expected_keys = list(model_state_dict.keys()) @@ -405,7 +409,10 @@ def _load_pretrained_model(self, state_dict, ignore_mismatched_sizes=False): unexpected_keys = list(set(loaded_keys) - set(expected_keys)) def _find_mismatched_keys( - state_dict, model_state_dict, loaded_keys, ignore_mismatched_sizes, + state_dict, + model_state_dict, + loaded_keys, + ignore_mismatched_sizes, ): mismatched_keys = [] if ignore_mismatched_sizes: @@ -440,7 +447,10 @@ def _find_mismatched_keys( if state_dict is not None: # Whole checkpoint mismatched_keys = _find_mismatched_keys( - state_dict, model_state_dict, original_loaded_keys, ignore_mismatched_sizes, + state_dict, + model_state_dict, + original_loaded_keys, + ignore_mismatched_sizes, ) error_msgs = self._load_state_dict_into_model(state_dict) return missing_keys, unexpected_keys, mismatched_keys, error_msgs diff --git a/nemo/collections/multimodal/models/vision_language_foundation/clip/megatron_clip_models.py b/nemo/collections/multimodal/models/vision_language_foundation/clip/megatron_clip_models.py index 7be7407b98ae0..a83960307672a 100644 --- a/nemo/collections/multimodal/models/vision_language_foundation/clip/megatron_clip_models.py +++ b/nemo/collections/multimodal/models/vision_language_foundation/clip/megatron_clip_models.py @@ -13,12 +13,17 @@ # limitations under the License. import itertools -from functools import partial +import os +import warnings +from contextlib import nullcontext +from dataclasses import fields +from functools import cache, partial from typing import Any, Optional import numpy as np import torch import torch.nn.functional as F +from omegaconf import OmegaConf from omegaconf.dictconfig import DictConfig from pytorch_lightning.accelerators import CPUAccelerator from pytorch_lightning.trainer.trainer import Trainer @@ -29,7 +34,9 @@ build_train_valid_datasets, ) from nemo.collections.multimodal.losses.clip_loss import ClipLoss +from nemo.collections.multimodal.losses.siglip_loss import SigLipLoss from nemo.collections.nlp.models.language_modeling.megatron_base_model import MegatronBaseModel +from nemo.collections.nlp.models.language_modeling.megatron_gpt_model import get_specs, mcore_supports_moe from nemo.collections.nlp.modules.common.megatron.build_model import build_model from nemo.collections.nlp.modules.common.megatron.language_model import get_language_model from nemo.collections.nlp.modules.common.megatron.module import Float16Module, MegatronModule @@ -40,7 +47,7 @@ init_method_normal, scaled_init_method_normal, ) -from nemo.collections.nlp.parts.utils_funcs import get_last_rank, torch_dtype_from_precision +from nemo.collections.nlp.parts.utils_funcs import activation_to_func, get_last_rank from nemo.collections.vision.modules.vit.vit_backbone import VitBackbone from nemo.core.classes.common import PretrainedModelInfo from nemo.utils import logging @@ -55,7 +62,33 @@ try: from megatron.core import parallel_state + from megatron.core.distributed import DistributedDataParallel as McoreDDP + from megatron.core.distributed import DistributedDataParallelConfig + from megatron.core.fusions.fused_bias_dropout import get_bias_dropout_add + from megatron.core.models.gpt import GPTModel as MCoreGPTModel + from megatron.core.models.vision.clip_vit_model import CLIPViTModel from megatron.core.pipeline_parallel.schedules import get_forward_backward_func + from megatron.core.transformer.attention import CrossAttention, CrossAttentionSubmodules + from megatron.core.transformer.custom_layers.transformer_engine import ( + TEColumnParallelLinear, + TEDotProductAttention, + TELayerNormColumnParallelLinear, + TENorm, + TERowParallelLinear, + ) + from megatron.core.transformer.enums import AttnMaskType as MCoreAttnMaskType + from megatron.core.transformer.identity_op import IdentityOp + from megatron.core.transformer.mlp import MLP, MLPSubmodules + from megatron.core.transformer.module import Float16Module as MCoreFloat16Module + from megatron.core.transformer.spec_utils import ModuleSpec + from megatron.core.transformer.transformer_config import TransformerConfig + from megatron.core.transformer.transformer_layer import TransformerLayer, TransformerLayerSubmodules + from megatron.core.utils import ( + drain_embedding_wgrad_compute, + get_model_config, + init_method_normal, + scaled_init_method_normal, + ) HAVE_MEGATRON_CORE = True @@ -63,6 +96,28 @@ HAVE_MEGATRON_CORE = False +try: + import transformer_engine + from transformer_engine.pytorch import module as te_module + + HAVE_TE = True + +except (ImportError, ModuleNotFoundError): + HAVE_TE = False + + +@cache +def mcore_supports_moe() -> bool: + global HAVE_MEGATRON_CORE + if not HAVE_MEGATRON_CORE: + return False + try: + from megatron.core.transformer.moe.router import TopKRouter + + return True + except ImportError: + return False + class CLIPVisionTransformer(MegatronModule): """Vision Transformer Model.""" @@ -100,7 +155,11 @@ def __init__(self, model_cfg, model_parallel_config, pre_process=True, post_proc if self.post_process and not skip_head: self.output_dim = model_cfg.output_dim - self.head = torch.nn.Linear(self.hidden_size, self.output_dim, bias=False,) + self.head = torch.nn.Linear( + self.hidden_size, + self.output_dim, + bias=False, + ) def set_input_tensor(self, input_tensor): """See megatron.model.transformer.set_input_tensor()""" @@ -129,7 +188,6 @@ def __init__(self, model_cfg, model_parallel_config, padded_vocab_size, pre_proc self.pre_process = pre_process self.post_process = post_process self.fp16_lm_cross_entropy = model_cfg.fp16_lm_cross_entropy - self.sequence_parallel = model_cfg.sequence_parallel self.gradient_accumulation_fusion = model_cfg.gradient_accumulation_fusion scaled_init_method = ( @@ -173,7 +231,7 @@ def __init__(self, model_cfg, model_parallel_config, padded_vocab_size, pre_proc openai_gelu=model_cfg.openai_gelu, onnx_safe=model_cfg.onnx_safe, megatron_legacy=model_cfg.megatron_legacy, - transformer_engine=model_cfg.transformer_engine, + transformer_engine=False, fp8=model_cfg.fp8, fp8_e4m3=model_cfg.fp8_e4m3, fp8_hybrid=model_cfg.fp8_hybrid, @@ -193,14 +251,17 @@ def __init__(self, model_cfg, model_parallel_config, padded_vocab_size, pre_proc hidden_size=model_cfg.hidden_size, ) - # TODO (yuya): check this position id self.position_ids = None if self.pre_process: self.position_ids = torch.arange(model_cfg.max_position_embeddings).expand(1, -1).cuda() if self.post_process: self.output_dim = model_cfg.output_dim - self.head = torch.nn.Linear(model_cfg.hidden_size, self.output_dim, bias=False,) + self.head = torch.nn.Linear( + model_cfg.hidden_size, + self.output_dim, + bias=False, + ) self.attn_mask = self.build_attention_mask(model_cfg.max_position_embeddings) @@ -217,7 +278,8 @@ def build_attention_mask(self, max_position_embeddings): return mask def forward( - self, input_ids, + self, + input_ids, ): # input_ids: [b, s] # position_ids: [b, s] @@ -245,27 +307,263 @@ def forward( return hidden_states +class SiglipMHAPoolingHead(TransformerLayer): + """Multihead Attention Pooling.""" + + def __init__( + self, + config: TransformerConfig, + submodules: TransformerLayerSubmodules, + ): + super().__init__(config, submodules) + + self.probe = torch.nn.Parameter(torch.randn(1, 1, config.hidden_size)) + + def forward(self, hidden_state): + batch_size = hidden_state.shape[0] + # [s, b, h] + probe = self.probe.repeat(1, batch_size, 1) + hidden_state = hidden_state.transpose(0, 1) + hidden_state, context = super().forward( + probe, + attention_mask=None, + context=hidden_state, + ) + + return hidden_state[0] + + +class MCoreSiglipViTModel(CLIPViTModel): + def __init__(self, *args, **kwargs): + # TODO (yuya): need to handle post_process correctly in order to enable PP + self.output_dim = kwargs.pop('output_dim') + kwargs['ln_pre_impl'] = IdentityOp + super().__init__(*args, **kwargs) + assert self.output_dim == self.config.hidden_size, "Siglip output_dim needs to be the same as hidden_size." + self.conv1 = torch.nn.Conv2d( + in_channels=3, + out_channels=self.visual_hidden_size, + kernel_size=self.patch_dim, + stride=self.patch_dim, + bias=True, + ) + self.final_layernorm = TENorm( + config=self.config, + hidden_size=self.config.hidden_size, + eps=self.config.layernorm_epsilon, + ) + + self.head = SiglipMHAPoolingHead( + self.config, + submodules=TransformerLayerSubmodules( + cross_attention=ModuleSpec( + module=CrossAttention, + params={"attn_mask_type": MCoreAttnMaskType.no_mask}, + submodules=CrossAttentionSubmodules( + linear_q=TEColumnParallelLinear, + linear_kv=TEColumnParallelLinear, + core_attention=TEDotProductAttention, + linear_proj=TERowParallelLinear, + ), + ), + cross_attn_bda=get_bias_dropout_add, + mlp=ModuleSpec( + module=MLP, + submodules=MLPSubmodules( + linear_fc1=TELayerNormColumnParallelLinear, + linear_fc2=TERowParallelLinear, + ), + ), + mlp_bda=get_bias_dropout_add, + ), + ) + + def forward(self, x): + x = super().forward( + x, + ) + x = self.final_layernorm(x) + x = self.head(x) + return x + + +class MCoreSiglipTextModel(MCoreGPTModel): + def __init__(self, *args, **kwargs): + # TODO (yuya): need to handle post_process correctly in order to enable PP + self.output_dim = kwargs.pop('output_dim') + kwargs['transformer_layer_spec'].submodules.self_attention.params['attn_mask_type'] = MCoreAttnMaskType.no_mask + + super().__init__(*args, **kwargs) + self.final_layernorm = TENorm( + config=self.config, + hidden_size=self.config.hidden_size, + eps=self.config.layernorm_epsilon, + ) + self.head = torch.nn.Linear( + self.config.hidden_size, + self.output_dim, + bias=True, + ) + + self.position_ids = None + if self.pre_process: + self.position_ids = torch.arange(kwargs['max_sequence_length']).expand(1, -1).cuda() + + def forward(self, input_ids): + + x = super().forward(input_ids, position_ids=self.position_ids, attention_mask=None) + x = self.final_layernorm(x) + x = x[-1] + x = self.head(x) + return x + + +class MCoreCLIPViTModel(CLIPViTModel): + def __init__(self, *args, **kwargs): + # TODO (yuya): need to handle post_process correctly in order to enable PP + self.output_dim = kwargs.pop('output_dim') + super().__init__(*args, **kwargs) + self.final_layernorm = TENorm( + config=self.config, + hidden_size=self.config.hidden_size, + eps=self.config.layernorm_epsilon, + ) + self.head = torch.nn.Linear( + self.config.hidden_size, + self.output_dim, + bias=False, + ) + + def forward(self, x): + x = super().forward( + x, + ) + x = self.final_layernorm(x) + x = x[:, 0] + x = self.head(x) + return x + + +class MCoreCLIPTextModel(MCoreGPTModel): + def __init__(self, *args, **kwargs): + # TODO (yuya): need to handle post_process correctly in order to enable PP + self.output_dim = kwargs.pop('output_dim') + + super().__init__(*args, **kwargs) + self.final_layernorm = TENorm( + config=self.config, + hidden_size=self.config.hidden_size, + eps=self.config.layernorm_epsilon, + ) + self.head = torch.nn.Linear( + self.config.hidden_size, + self.output_dim, + bias=False, + ) + self.position_ids = None + if self.pre_process: + self.position_ids = torch.arange(kwargs['max_sequence_length']).expand(1, -1).cuda() + + def forward(self, input_ids): + x = super().forward(input_ids, position_ids=self.position_ids, attention_mask=None) + x = self.final_layernorm(x) + x = x[input_ids.argmax(dim=-1), torch.arange(x.shape[1])] + x = self.head(x) + return x + + class CLIPModel(MegatronModule): """CLIP Model""" - def __init__(self, model_cfg, model_parallel_config, padded_vocab_size, pre_process=True, post_process=True): + def __init__( + self, + model_cfg, + model_parallel_config, + vision_transformer_config, + text_transformer_config, + padded_vocab_size, + pre_process=True, + post_process=True, + ): super(CLIPModel, self).__init__() self.config = model_parallel_config + self.use_siglip = model_cfg.get("use_siglip", False) self.pre_process = pre_process self.post_process = post_process - self.vision_encoder = CLIPVisionTransformer( - model_cfg.vision, model_parallel_config, pre_process=self.pre_process, post_process=self.post_process, - ) - self.text_encoder = CLIPTextTransformer( - model_cfg.text, - model_parallel_config, - padded_vocab_size, - pre_process=self.pre_process, - post_process=self.post_process, - ) + self.output_dim = model_cfg.output_dim + self.get_attention_mask_from_fusion = model_cfg.get('get_attention_mask_from_fusion', True) - self.logit_scale = torch.nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) + if model_cfg.get("mcore_gpt", False): + if model_cfg.vision.get("class_token_length") is None or model_cfg.vision.get("class_token_length") <= 0: + add_class_token = False + else: + add_class_token = True + vision_layer_spec = get_specs( + model_cfg.text.get('name', ''), + vision_transformer_config.num_moe_experts, + vision_transformer_config.moe_grouped_gemm, + model_cfg.get('transformer_engine', True), + ) + vision_layer_spec.submodules.self_attention.params['attn_mask_type'] = MCoreAttnMaskType.no_mask + + if model_cfg.get("use_siglip", False): + vision_module = MCoreSiglipViTModel + text_module = MCoreSiglipTextModel + else: + vision_module = MCoreCLIPViTModel + text_module = MCoreCLIPTextModel + self.vision_encoder = vision_module( + transformer_config=vision_transformer_config, + transformer_layer_spec=vision_layer_spec, + patch_dim=model_cfg.vision.get('patch_dim', 16), + img_h=model_cfg.vision.get('img_h', 224), + img_w=model_cfg.vision.get('img_w', 224), + add_class_token=add_class_token, + class_token_len=model_cfg.vision.get('class_token_length'), + output_dim=model_cfg.output_dim, + ) + self.text_encoder = text_module( + config=text_transformer_config, + transformer_layer_spec=get_specs( + model_cfg.text.get('name', ''), + text_transformer_config.num_moe_experts, + text_transformer_config.moe_grouped_gemm, + model_cfg.get('transformer_engine', True), + ), + vocab_size=model_cfg.text.get('override_vocab_size', padded_vocab_size), + max_sequence_length=model_cfg.text.get('encoder_seq_length', 512), + pre_process=pre_process, + post_process=False, + parallel_output=True, + share_embeddings_and_output_weights=False, + position_embedding_type=model_cfg.text.get('position_embedding_type', 'learned_absolute'), + rotary_percent=model_cfg.text.get('rotary_percentage', 1.0), + seq_len_interpolation_factor=model_cfg.text.get('seq_len_interpolation_factor', None), + rotary_base=model_cfg.text.get('rotary_base', 10000), + output_dim=model_cfg.output_dim, + ) + + else: + self.vision_encoder = CLIPVisionTransformer( + model_cfg.vision, + model_parallel_config, + pre_process=self.pre_process, + post_process=self.post_process, + ) + self.text_encoder = CLIPTextTransformer( + model_cfg.text, + model_parallel_config, + padded_vocab_size, + pre_process=self.pre_process, + post_process=self.post_process, + ) + + if self.use_siglip: + self.logit_scale = torch.nn.Parameter(torch.ones([]) * np.log(10)) + self.logit_bias = torch.nn.Parameter(torch.ones([]) * (-10)) + else: + self.logit_scale = torch.nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) def set_input_tensor(self, input_tensor): """See megatron.model.transformer.set_input_tensor()""" @@ -277,10 +575,89 @@ def forward(self, images, captions): text_features = self.text_encoder(captions) if self.post_process: + if self.use_siglip: + return ( + F.normalize(image_features, dim=-1), + F.normalize(text_features, dim=-1), + self.logit_scale.exp(), + self.logit_bias, + ) return F.normalize(image_features, dim=-1), F.normalize(text_features, dim=-1), self.logit_scale.exp() return image_features, text_features + def build_transformer_config(self) -> TransformerConfig: + """Builds the megatron core gpt transformer config for the model. + For attributes in the nemo model config that are the same + as the megatron core TransformerConfig, we will use the value from the nemo model config. + For attributes in TransformerConfig that are not in the nemo model config, we add custom logic. + """ + + normalization = self.cfg.get('normalization', 'layernorm').lower() + layernorm_zero_centered_gamma = self.cfg.get('normalization', 'layernorm') == 'layernorm1p' + if normalization == 'layernorm': + normalization = 'LayerNorm' + elif normalization == 'rmsnorm': + normalization = 'RMSNorm' + elif normalization == 'layernorm1p': + normalization = 'LayerNorm' + layernorm_zero_centered_gamma = True + else: + logging.warning( + f"The normalization type: {normalization} might not be supported in megatron core." + f"Supported types are LayerNorm and RMSNorm." + ) + + ub_tp_comm_overlap = self.cfg.get('ub_tp_comm_overlap', False) + + if not self.cfg.get('fp8', False): + fp8 = None + elif self.cfg.get('fp8_e4m3', False): + fp8 = 'e4m3' + elif self.cfg.get('fp8_hybrid', False): + fp8 = 'hybrid' + else: + raise ValueError(f"fp8 enabled but fp8_format (fp8_e4m3 | fp8_hybrid) is not set.") + + # any configs that are not in the nemo model config will be added here + model_specific_configs = { + 'layernorm_zero_centered_gamma': layernorm_zero_centered_gamma, + 'normalization': normalization, + 'fp8': fp8, + 'tp_comm_overlap': ub_tp_comm_overlap, + # MoE related + 'num_moe_experts': self.cfg.get('num_moe_experts', None), + 'moe_router_load_balancing_type': self.cfg.get('moe_router_load_balancing_type', 'aux_loss'), + 'moe_router_topk': self.cfg.get('moe_router_topk', 2), + 'moe_grouped_gemm': self.cfg.get('moe_grouped_gemm', False), + 'moe_aux_loss_coeff': self.cfg.get( + 'moe_aux_loss_coeff', 0 + ), # 1e-2 would be a good start value for load balance loss. + 'moe_z_loss_coeff': self.cfg.get('moe_z_loss_coeff', None), # 1e-3 would be a good start value for z-loss + 'moe_input_jitter_eps': self.cfg.get('moe_input_jitter_eps', None), + 'moe_token_dropping': self.cfg.get('moe_token_dropping', False), # TODO: Support token dropping. + } + if model_specific_configs['num_moe_experts'] is not None: + assert mcore_supports_moe(), 'Megatron-core >= v0.5.0 is required for MoE' + elif not mcore_supports_moe(): + if 'num_moe_experts' in model_specific_configs: + del model_specific_configs['num_moe_experts'] + moe_keys = list(filter(lambda x: x.startswith('moe_'), model_specific_configs.keys())) + for k in moe_keys: + del model_specific_configs[k] + + transformer_config = super().build_transformer_config() + + for key, value in model_specific_configs.items(): + setattr(transformer_config, key, value) + + # pass mcore customization configs directly to mcore + mcore_customization_config_dict = self.cfg.get('mcore_customization_config', {}) + for key, value in mcore_customization_config_dict.items(): + setattr(transformer_config, key, value) + + return transformer_config + class MegatronCLIPModel(MegatronBaseModel): """Megatron CLIP Model.""" @@ -302,11 +679,21 @@ def __init__(self, cfg: DictConfig, trainer: Trainer): self._validate_trainer() + # placeholder for O2 wrapper + self.transformer_config = self.build_transformer_config(self.cfg.text) + self.megatron_amp_O2 = cfg.get('megatron_amp_O2', False) + self.mcore_gpt = cfg.get('mcore_gpt', False) + if cfg.get('fp8', False): + self.prev_step_training = True if not self.megatron_amp_O2 and self.cfg.get('virtual_pipeline_model_parallel_size', None): raise ValueError('Virtual pipeline model parallel is only supported when using megatron_amp_O2') + self.transformer_engine = cfg.get('transformer_engine', False) + if self.megatron_amp_O2 and not self.transformer_engine: + logging.warning('megatron_amp_O2 is enabled but transformer-engine is not.') + # build_model returns a list of modules which are used for interleaved pipeline parallelism if isinstance(self.trainer.accelerator, CPUAccelerator): self.model = build_model( @@ -316,19 +703,24 @@ def __init__(self, cfg: DictConfig, trainer: Trainer): virtual_pipeline_model_parallel_size=self.cfg.get('virtual_pipeline_model_parallel_size', None), ) else: - self.model = build_model( - model_provider_func=self.model_provider_func, - wrap_with_ddp=False, - virtual_pipeline_model_parallel_size=self.cfg.get('virtual_pipeline_model_parallel_size', None), - ) + build_model_context = nullcontext + if HAVE_TE and self.cfg.get('fp8', False) and self.cfg.get('fp8_params', False): + build_model_context = transformer_engine.pytorch.fp8_model_init + with build_model_context(): + self.model = build_model( + model_provider_func=self.model_provider_func, + wrap_with_ddp=False, + virtual_pipeline_model_parallel_size=self.cfg.get('virtual_pipeline_model_parallel_size', None), + on_cpu=cfg.get('fsdp', False) and cfg.get('use_cpu_initialization', False), + ) # if we're not using interleaved, then self.model is a module. - if self.cfg.get('virtual_pipeline_model_parallel_size', None) is None: + if self.cfg.get('virtual_pipeline_model_parallel_size', None) is None and (not self.use_mcore_dist_optim): self.model = self.model[0] if self.megatron_amp_O2: - if not self.with_distributed_adam: + if not self.with_distributed_adam and not self.cfg.get("use_cpu_initialization", False): # Pre-allocate the model on GPU to have master parameters allocated on the same device with matching data type if isinstance(self.model, list): for module in self.model: @@ -336,31 +728,17 @@ def __init__(self, cfg: DictConfig, trainer: Trainer): else: self.model.cuda(torch.cuda.current_device()) - # Model wrapper to convert both model and inputs to half precision - # TODO (yuya): check this; FP16 Module might not work; when self.model is a list? - if isinstance(self.model, list): - converted_model = [] - for module in self.model: - converted_model.append( - Float16Module(config=self.model_parallel_config, module=module, precision=cfg.precision) - ) - self.model = converted_model - else: - self.model = Float16Module( - config=self.model_parallel_config, module=self.model, precision=cfg.precision - ) + self._wrap_model_for_O2() - self.autocast_dtype = torch_dtype_from_precision(self.trainer.precision) self.enable_autocast = ( True if (not self.megatron_amp_O2) and (self.autocast_dtype in [torch.float16, torch.bfloat16]) else False ) - self.transformer_engine = cfg.get('transformer_engine', False) - # Convert the global-batch-based profile index to micro-batch index if hasattr(self, '_nsys_profile_enabled') or hasattr(self, '_memory_profile_enabled'): mp_size = cfg.get('tensor_model_parallel_size', 1) * cfg.get('pipeline_model_parallel_size', 1) - data_parallel_world_size = trainer.world_size // mp_size + cp_size = cfg.get('context_parallel_size', 1) + data_parallel_world_size = trainer.world_size // (mp_size * cp_size) grad_accum_steps = cfg.get('global_batch_size') // (cfg.get('micro_batch_size') * data_parallel_world_size) if hasattr(self, '_nsys_profile_enabled'): self._nsys_profile_start_step *= grad_accum_steps @@ -368,22 +746,36 @@ def __init__(self, cfg: DictConfig, trainer: Trainer): if hasattr(self, '_memory_profile_enabled'): self._memory_profile_start_step *= grad_accum_steps self._memory_profile_end_step *= grad_accum_steps - self.get_attention_mask_from_fusion = self.cfg.get('get_attention_mask_from_fusion', True) - self.initialize_ub = self.cfg.get('ub_tp_comm_overlap', False) - def get_module_list(self): - if isinstance(self.model, list): - return [model.module if isinstance(model, Float16Module) else model for model in self.model] - elif isinstance(self.model, Float16Module): - return [self.model.module] - else: - return [self.model] + self.initialize_ub = self.cfg.get('ub_tp_comm_overlap', False) + self.log_train_loss = bool(int(os.getenv("NEMO_LOG_TRAIN_LOSS", 1))) + self.log_memory_usage = bool(int(os.getenv("NEMO_LOG_MEMORY_USAGE", 0))) + self.loss_broadcast_src_rank = None + data_cfg = cfg.get('data', {}) + self.return_output_tensors = data_cfg.get('return_output_tensors', False) + self.validation_drop_last = data_cfg.get('validation_drop_last', True) + self.sample_weight = data_cfg.get('sample_weight', 'token') + self.validation_param_sync_overlap = self.cfg.get('validation_param_sync_overlap', False) def model_provider_func(self, pre_process, post_process): """Model depends on pipeline paralellism.""" + vision_transformer_config = self.build_transformer_config(self.cfg.vision) if self.mcore_gpt else None + text_transformer_config = self.build_transformer_config(self.cfg.text) if self.mcore_gpt else None + + if self.mcore_gpt and not parallel_state.is_initialized(): + + def dummy(): + return + + if self.trainer.strategy.launcher is not None: + self.trainer.strategy.launcher.launch(dummy, trainer=self.trainer) + self.trainer.strategy.setup_environment() + model = CLIPModel( model_cfg=self.cfg, model_parallel_config=self.model_parallel_config, + vision_transformer_config=vision_transformer_config, + text_transformer_config=text_transformer_config, padded_vocab_size=self.padded_vocab_size, pre_process=pre_process, post_process=post_process, @@ -401,9 +793,40 @@ def setup_optimizer_param_groups(self): else: self._optimizer_param_groups = get_params_for_weight_decay_optimization(self.model) + def setup_mcore_distributed_parallel(self): + """Set up mcore distributed data parallel""" + if self.with_distributed_adam and self.use_mcore_dist_optim: + config = get_model_config(self.model[0]) + ddp_config = DistributedDataParallelConfig( + grad_reduce_in_fp32=(self.cfg.optim.get('grad_sync_dtype', 'fp32') == 'fp32'), + overlap_grad_reduce=self.cfg.optim.get('overlap_grad_sync', False), + use_distributed_optimizer=True, + check_for_nan_in_grad=self.cfg.optim.get('check_for_nan_in_grad', False), + # mcore bucket_size is based on num of parameters, therefore not + # using bucket_cap_mb to configure bucket_size here + bucket_size=self.cfg.optim.get('ddp_bucket_size', None), + ) + + self.model = [ + McoreDDP( + config, + ddp_config, + model_chunk, + data_parallel_group=parallel_state.get_data_parallel_group(with_context_parallel=True), + expert_data_parallel_group=parallel_state.get_data_modulo_expert_parallel_group(), + # Turn off bucketing for model_chunk 2 onwards, since communication for these + # model chunks is overlapped with compute anyway. + disable_bucketing=(model_chunk_idx > 0), + ) + for (model_chunk_idx, model_chunk) in enumerate(self.model) + ] + + # (TODO) Broadcast params from data parallel src rank to other data parallel ranks. + # by calling model_module.broadcast_params() if the model is randomly initialized. + def configure_optimizers(self): - if self.with_distributed_adam: + if self.with_distributed_adam and not self.use_mcore_dist_optim: # Disable overlapped grad sync for layer norm grads when # sequence parallelism is enabled @@ -462,13 +885,16 @@ def fwd_bwd_step(self, dataloader_iter, forward_only): no_sync_func = None grad_sync_func = None param_sync_func = None - if not forward_only and self.with_distributed_adam: - no_sync_func = partial(self._optimizer.no_sync, greedy_grad_copy=self.megatron_amp_O2,) + if not forward_only and self.with_distributed_adam and not self.use_mcore_dist_optim: + no_sync_func = partial( + self._optimizer.no_sync, + greedy_grad_copy=self.megatron_amp_O2, + ) grad_sync_func = self.reduce_overlap_gradients param_sync_func = self.sync_overlap_parameters # pipeline schedules will get these from self.model.config - for module in self.get_module_list(): + for module in self.get_model_module_list(): module.config.no_sync_func = no_sync_func module.config.grad_sync_func = grad_sync_func module.config.param_sync_func = param_sync_func @@ -515,7 +941,9 @@ def initialize_ub_func(self): ) input_shape = [ - self.cfg.get('encoder_seq_length') * self.cfg.get('micro_batch_size'), + self.cfg.get('encoder_seq_length') + * self.cfg.get('micro_batch_size') + // self.cfg.get('context_parallel_size', 1), self.cfg.get('hidden_size'), ] @@ -529,12 +957,12 @@ def initialize_ub_func(self): def training_step(self, dataloader_iter): """ - Our dataloaders produce a micro-batch and then we fetch - a number of microbatches depending on the global batch size and model parallel size - from the dataloader to produce a list of microbatches. - Batch should be a list of microbatches and those microbatches should on CPU. - Microbatches are then moved to GPU during the pipeline. - The list of microbatches is then piped through the pipeline using Apex fwd/bwd functions. + Our dataloaders produce a micro-batch and then we fetch + a number of microbatches depending on the global batch size and model parallel size + from the dataloader to produce a list of microbatches. + Batch should be a list of microbatches and those microbatches should on CPU. + Microbatches are then moved to GPU during the pipeline. + The list of microbatches is then piped through the pipeline using Apex fwd/bwd functions. """ # Initialize userbuffer communicators. if self.initialize_ub: @@ -543,7 +971,7 @@ def training_step(self, dataloader_iter): # we zero grads here because we also call backward in the megatron-core fwd/bwd functions self._optimizer.zero_grad() - if self.with_distributed_adam: + if self.with_distributed_adam and not self.use_mcore_dist_optim: # hack to enable overlapping param sync and forward compute # note: the distributed optimizer monkey-patches each # parameter's __getattribute__ function so that it can @@ -554,9 +982,10 @@ def training_step(self, dataloader_iter): # manually interact with the parameter. modules = self.model if isinstance(self.model, list) else [self.model] for module in modules: - if isinstance(module, Float16Module): + if isinstance(module, (Float16Module, MCoreFloat16Module)): module = module.module - module = module.text_encoder.language_model + if not self.mcore_gpt: + module = module.language_model if hasattr(module, 'embedding'): for param in module.embedding.parameters(): param.data_ptr() @@ -567,38 +996,115 @@ def training_step(self, dataloader_iter): if self.cfg.get('tensor_model_parallel_size', 1) > 1 and self.cfg.get('sequence_parallel', False): self.allreduce_sequence_parallel_gradients() - if self.with_distributed_adam: - # synchronize asynchronous grad reductions - # note: not necessary, but reduces performance degradation - # from multiple simultaneous NCCL calls - self._optimizer._finish_bucket_grad_sync() + if self.cfg.get('fp8', False): + self.prev_step_training = self.training + + # Optimization: Defer the embedding GEMM Wgrads of the last PP stage to pipeline flush waiting time + if self.cfg.get('pipeline_model_parallel_size', 1) > 1 and parallel_state.is_pipeline_last_stage( + ignore_virtual=True + ): + if ( + self.cfg.get('defer_embedding_wgrad_compute', False) and self.mcore_gpt + ): # Silently ignore the optimization if MCORE is not used + module_list = self.get_model_module_list() + if len(module_list) > 1: + embedding_module = module_list[-1] + else: + embedding_module = module_list[0] + + embedding_activation_buffer = embedding_module.embedding_activation_buffer + grad_output_buffer = embedding_module.grad_output_buffer + weight = embedding_module.output_layer.weight + + drain_embedding_wgrad_compute( + embedding_module.config, embedding_activation_buffer, grad_output_buffer, weight + ) + + # when using sequence parallelism, the sequence parallel layernorm grads must be all-reduced + if self.cfg.get('tensor_model_parallel_size', 1) > 1 and self.cfg.get('sequence_parallel', False): + self.megatron_timer_start('allreduce_sequence_parallel_gradients', log_level=1) + self.allreduce_sequence_parallel_gradients() + self.megatron_timer_stop('allreduce_sequence_parallel_gradients') + + self.megatron_timer_start('gradient_allreduce', log_level=1) + if self.use_fsdp: + # Reduce the gradients omitted from FSDP-sharding + self.allreduce_fsdp_sharding_omitted_gradients() + elif self.with_distributed_adam: + if not self.use_mcore_dist_optim: + # synchronize asynchronous grad reductions + # note: not necessary, but reduces performance degradation + # from multiple simultaneous NCCL calls + self._optimizer._finish_bucket_grad_sync() + # else: Mcore distributed optim calls finalize_model_grads to finish grad sync elif self.megatron_amp_O2: # when using pipeline parallelism grads must be all-reduced after the pipeline (not asynchronously) - # if self.cfg.get('pipeline_model_parallel_size', 1) > 1 or self.cfg.get('sequence_parallel', False): - # # main grads are stored in the MainParamsOptimizer wrapper - self._optimizer.allreduce_main_grads() + if ( + self.cfg.get('pipeline_model_parallel_size', 1) > 1 + or self.cfg.get('sequence_parallel', False) + or not self.cfg.get('async_grad_allreduce', True) + ): + # main grads are stored in the MainParamsOptimizer wrapper + self._optimizer.allreduce_main_grads() else: # async grad allreduce is not currently implemented for O1/autocasting mixed precision training # so we all-reduce gradients after the pipeline self.allreduce_gradients() # @sangkug we think this is causing memory to blow up (hurts perf) + self.megatron_timer_stop('gradient_allreduce') + + if ( + not self.use_mcore_dist_optim + and self.cfg.get('pipeline_model_parallel_size', 1) > 1 + and self.cfg.get('share_embeddings_and_output_weights', True) + ): + self.megatron_timer_start('allreduce_first_last_embeddings', log_level=1) + # when using pipeline parallelism the first and last stage must keep embeddings in sync + self.allreduce_first_last_embeddings() + self.megatron_timer_stop('allreduce_first_last_embeddings') + + if self.log_memory_usage: + mem_reserved = torch.cuda.max_memory_reserved() + self.log( + 'peak_memory_usage', + mem_reserved, + prog_bar=True, + rank_zero_only=True, + batch_size=1, + ) ## logging - # we can only log on one rank if it is rank zero so we broadcast from last rank - # we can avoid this broadcast by updating the PTL log function to accept specific ranks - torch.distributed.broadcast(loss_mean, get_last_rank()) - - if self.cfg.precision in [16, '16', '16-mixed']: - loss_scale = self.trainer.precision_plugin.scaler._scale - if loss_scale is not None: - self.log('loss_scale', loss_scale, batch_size=1) + if self.log_train_loss: + # When using pipeline parallelism, loss is calculated only in the last pipeline stage and + # it should be casted to other pipeline stages for logging. + # we can avoid this broadcast by updating the PTL log function to accept specific ranks + if parallel_state.get_pipeline_model_parallel_world_size() > 1: + if torch.distributed.get_rank() == get_last_rank(): + torch.distributed.send(loss_mean, 0) + elif torch.distributed.get_rank() == 0: + torch.distributed.recv(loss_mean, get_last_rank()) + self.log('reduced_train_loss', loss_mean, prog_bar=True, rank_zero_only=True, batch_size=1) + + # (@adithyare) we need to check for the _scaler attribute to enable pp>1 for adapter training + if self.cfg.precision == 16 and hasattr(self.trainer.precision_plugin.scaler, "_scale"): + loss_scale = self.trainer.precision_plugin.scaler._scale + if loss_scale is not None: + self.log('loss_scale', loss_scale, batch_size=1) - self.log('reduced_train_loss', loss_mean, prog_bar=True, rank_zero_only=True, batch_size=1) lr = self._optimizer.param_groups[0]['lr'] self.log('lr', lr, rank_zero_only=True, batch_size=1) - self.log('global_step', self.trainer.global_step + 1, prog_bar=True, rank_zero_only=True, batch_size=1) + self.log( + 'global_step', + self.trainer.global_step + 1, + prog_bar=True, + rank_zero_only=True, + batch_size=1, + ) + + consumed_samples = self._compute_consumed_samples_after_training_step() + # TODO: make sure compute_consumed_samples works for pipeline parallelism self.log( 'consumed_samples', - self.compute_consumed_samples(self.trainer.global_step + 1 - self.init_global_step), + consumed_samples, prog_bar=True, rank_zero_only=True, batch_size=1, @@ -607,20 +1113,20 @@ def training_step(self, dataloader_iter): return loss_mean def backward(self, *args, **kwargs): - """ LightningModule hook to do backward. - We want this to do nothing since we run backward in the fwd/bwd functions from apex. - No need to call it here. + """LightningModule hook to do backward. + We want this to do nothing since we run backward in the fwd/bwd functions from apex. + No need to call it here. """ pass def optimizer_zero_grad(self, *args, **kwargs): - """ LightningModule hook to zero grad. - We want this to do nothing as we are zeroing grads during the training_step. + """LightningModule hook to zero grad. + We want this to do nothing as we are zeroing grads during the training_step. """ pass def _append_sequence_parallel_module_grads(self, module, grads): - """ Helper method for allreduce_sequence_parallel_gradients""" + """Helper method for allreduce_sequence_parallel_gradients""" for param in module.parameters(): sequence_parallel_param = getattr(param, 'sequence_parallel', False) @@ -632,9 +1138,9 @@ def _append_sequence_parallel_module_grads(self, module, grads): grads.append(grad.data) def allreduce_sequence_parallel_gradients(self): - """ All-reduce layernorm parameters across model parallel nodes when sequence parallelism is used. - Modified from megatron-lm: - https://gitlab-master.nvidia.com/ADLR/megatron-lm/-/blob/3f91f09bb2ab32f9904b47f46f19d2fc3f518ed8/megatron/training.py#L425 + """All-reduce layernorm parameters across model parallel nodes when sequence parallelism is used. + Modified from megatron-lm: + https://gitlab-master.nvidia.com/ADLR/megatron-lm/-/blob/3f91f09bb2ab32f9904b47f46f19d2fc3f518ed8/megatron/training.py#L425 """ grads = [] @@ -650,7 +1156,18 @@ def allreduce_sequence_parallel_gradients(self): buf.copy_(synced) def get_forward_output_and_loss_func(self): - loss_func = ClipLoss(local_loss=self.cfg.local_loss, gather_with_grad=self.cfg.gather_with_grad,) + if self.cfg.get("use_siglip", False): + # TODO(yuya): fix rank + loss_func = SigLipLoss( + rank=parallel_state.get_data_parallel_rank(), + world_size=parallel_state.get_data_parallel_world_size(), + group=parallel_state.get_data_parallel_group(), + ) + else: + loss_func = ClipLoss( + local_loss=self.cfg.local_loss, + gather_with_grad=self.cfg.gather_with_grad, + ) def fwd_output_and_loss_func(dataloader_iter, model): batch, _, _ = next(dataloader_iter) @@ -690,7 +1207,8 @@ def zero_shot_classifier(self): texts = texts.cuda(non_blocking=True) # TODO (yuya): distributed not working with torch.cuda.amp.autocast( - enabled=self.autocast_dtype in (torch.half, torch.bfloat16), dtype=self.autocast_dtype, + enabled=self.autocast_dtype in (torch.half, torch.bfloat16), + dtype=self.autocast_dtype, ): class_embeddings = text_encoder(texts) class_embedding = F.normalize(class_embeddings, dim=-1).mean(dim=0) @@ -726,7 +1244,8 @@ def accuracy(output, target, topk=(1,)): target = target.cuda(non_blocking=True) # predict with torch.cuda.amp.autocast( - enabled=self.autocast_dtype in (torch.half, torch.bfloat16), dtype=self.autocast_dtype, + enabled=self.autocast_dtype in (torch.half, torch.bfloat16), + dtype=self.autocast_dtype, ): image_features = vision_encoder(images) image_features = F.normalize(image_features, dim=-1) @@ -745,10 +1264,10 @@ def accuracy(output, target, topk=(1,)): def validation_step(self, dataloader_iter): """ - Our dataloaders produce a micro-batch and then we fetch - a number of microbatches depending on the global batch size and model parallel size - from the dataloader to produce a list of microbatches. - The list of microbatches is then piped through the pipeline using megatron-core fwd/bwd functions. """ + Our dataloaders produce a micro-batch and then we fetch + a number of microbatches depending on the global batch size and model parallel size + from the dataloader to produce a list of microbatches. + The list of microbatches is then piped through the pipeline using megatron-core fwd/bwd functions.""" # Initialize userbuffer communicators. if self.initialize_ub: self.initialize_ub_func() @@ -801,7 +1320,9 @@ def build_train_valid_test_datasets(self): raise ValueError("limit_val_batches must be an integer or float less than or equal to 1.0.") self._train_ds, self._validation_ds = build_train_valid_datasets( - model_cfg=self.cfg, consumed_samples=self.compute_consumed_samples(0), tokenizer=self.tokenizer, + model_cfg=self.cfg, + consumed_samples=self.compute_consumed_samples(0), + tokenizer=self.tokenizer, ) self._test_ds = None @@ -816,7 +1337,7 @@ def build_train_valid_test_datasets(self): return self._train_ds, self._validation_ds, self._test_ds def setup(self, stage=None): - """ PTL hook that is executed after DDP spawns. + """PTL hook that is executed after DDP spawns. We setup datasets here as megatron datasets require DDP to instantiate. See https://pytorch-lightning.readthedocs.io/en/latest/common/lightning_module.html#setup for more information. Args: @@ -909,23 +1430,18 @@ def setup_test_data(self, cfg): f'Setting up test dataloader with len(len(self._test_ds)): {len(self._test_ds)} and consumed samples: {consumed_samples}' ) self._test_dl = torch.utils.data.DataLoader( - self._test_ds, batch_size=self._micro_batch_size, num_workers=cfg.num_workers, pin_memory=True, + self._test_ds, + batch_size=self._micro_batch_size, + num_workers=cfg.num_workers, + pin_memory=True, ) def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: Optional[int] = None) -> Any: raise NotImplementedError - def transfer_batch_to_device(self, batch: Any, device: torch.device, dataloader_idx: int) -> Any: - """ PTL hook: https://pytorch-lightning.readthedocs.io/en/latest/common/lightning_module.html#transfer-batch-to-device - When using pipeline parallelism, we need the global batch to remain on the CPU, - since the memory overhead will be too high when using a large number of microbatches. - Microbatches are transferred from CPU to GPU inside the pipeline. - """ - return batch - def _validate_trainer(self): - """ Certain trainer configurations can break training. - Here we try to catch them and raise an error. + """Certain trainer configurations can break training. + Here we try to catch them and raise an error. """ if self.trainer.accumulate_grad_batches > 1: raise ValueError( @@ -961,3 +1477,178 @@ def parameters(self): return itertools.chain.from_iterable(module.parameters() for module in self.model) else: return self.model.parameters() + + def build_transformer_config(self, model_cfg=None) -> TransformerConfig: + """Builds the megatron core gpt transformer config for the model. + For attributes in the nemo model config that are the same + as the megatron core TransformerConfig, we will use the value from the nemo model config. + For attributes in TransformerConfig that are not in the nemo model config, we add custom logic. + """ + if model_cfg is None: + model_cfg = self.cfg + normalization = model_cfg.get('normalization', 'layernorm').lower() + layernorm_zero_centered_gamma = model_cfg.get('normalization', 'layernorm') == 'layernorm1p' + if normalization == 'layernorm': + normalization = 'LayerNorm' + elif normalization == 'rmsnorm': + normalization = 'RMSNorm' + elif normalization == 'layernorm1p': + normalization = 'LayerNorm' + layernorm_zero_centered_gamma = True + else: + logging.warning( + f"The normalization type: {normalization} might not be supported in megatron core." + f"Supported types are LayerNorm and RMSNorm." + ) + + ub_tp_comm_overlap = model_cfg.get('ub_tp_comm_overlap', False) + + if not model_cfg.get('fp8', False): + fp8 = None + elif model_cfg.get('fp8_e4m3', False): + fp8 = 'e4m3' + elif model_cfg.get('fp8_hybrid', False): + fp8 = 'hybrid' + else: + raise ValueError(f"fp8 enabled but fp8_format (fp8_e4m3 | fp8_hybrid) is not set.") + + # any configs that are not in the nemo model config will be added here + model_specific_configs = { + 'layernorm_zero_centered_gamma': layernorm_zero_centered_gamma, + 'normalization': normalization, + 'fp8': fp8, + 'tp_comm_overlap': ub_tp_comm_overlap, + # MoE related + 'num_moe_experts': model_cfg.get('num_moe_experts', None), + 'moe_router_load_balancing_type': model_cfg.get('moe_router_load_balancing_type', 'aux_loss'), + 'moe_router_topk': model_cfg.get('moe_router_topk', 2), + 'moe_grouped_gemm': model_cfg.get('moe_grouped_gemm', False), + 'moe_aux_loss_coeff': model_cfg.get( + 'moe_aux_loss_coeff', 0 + ), # 1e-2 would be a good start value for load balance loss. + 'moe_z_loss_coeff': model_cfg.get('moe_z_loss_coeff', None), # 1e-3 would be a good start value for z-loss + 'moe_input_jitter_eps': model_cfg.get('moe_input_jitter_eps', None), + 'moe_token_dropping': model_cfg.get('moe_token_dropping', False), # TODO: Support token dropping. + } + if model_specific_configs['num_moe_experts'] is not None: + assert mcore_supports_moe(), 'Megatron-core >= v0.5.0 is required for MoE' + elif not mcore_supports_moe(): + if 'num_moe_experts' in model_specific_configs: + del model_specific_configs['num_moe_experts'] + moe_keys = list(filter(lambda x: x.startswith('moe_'), model_specific_configs.keys())) + for k in moe_keys: + del model_specific_configs[k] + + # create a dictionary copy of the model config + cfg = OmegaConf.to_container(model_cfg, resolve=True) + + # create a dict to store the transformer config arguments + transformer_config_dict = {} + + # get model parallel configs from the base class + model_parallel_config = self.build_model_parallel_config() + + add_bias_linear = model_cfg.get('bias', True) + add_qkv_bias = model_cfg.get('qkv_bias', False) + + activation = model_cfg.get('activation', 'gelu') + gated_linear_unit = activation.endswith('glu') + # TODO: need to check which activation functions are supported in mcore + activation_func = activation_to_func(activation, openai_gelu=model_cfg.get("openai_gelu", False)) + + normalization = model_cfg.get('normalization', 'LayerNorm') + + init_method_std = model_cfg.get('init_method_std', 0.02) + # default used in mcore + init_method = init_method_normal(init_method_std) + + output_layer_init_method = init_method + num_layers = model_cfg.get('num_layers', 1) + use_scaled_init_method = model_cfg.get('use_scaled_init_method', True) + if use_scaled_init_method: + output_layer_init_method = scaled_init_method_normal(init_method_std, num_layers=num_layers) + + attention_softmax_in_fp32 = False # not currently used in NeMo unless apply_query_key_layer_scaling is True + apply_query_key_layer_scaling = model_cfg.get('apply_query_key_layer_scaling', False) + + rotary_interleaved = model_cfg.get('rotary_interleaved', False) + + fp16_enabled = self.trainer.precision in [16, '16', '16-mixed'] + if apply_query_key_layer_scaling: + if fp16_enabled: + os.environ["NVTE_APPLY_QK_LAYER_SCALING"] = "1" + else: + logging.warning( + "apply_query_key_layer_scaling is only enabled when using FP16, setting it to False " + "and setting NVTE_APPLY_QK_LAYER_SCALING=0" + ) + os.environ["NVTE_APPLY_QK_LAYER_SCALING"] = "0" + apply_query_key_layer_scaling = False + + if apply_query_key_layer_scaling: + attention_softmax_in_fp32 = True + + bias_activation_fusion = model_cfg.get('bias_activation_fusion', True) + + bias_dropout_fusion = model_cfg.get('bias_dropout_add_fusion', True) + + apply_rope_fusion = model_cfg.get('apply_rope_fusion', False) + + # TODO: need to check if recompute APIs are matching up properly + recompute_granularity = model_cfg.get('activations_checkpoint_granularity', None) + recompute_method = model_cfg.get('activations_checkpoint_method', None) + recompute_num_layers = model_cfg.get('activations_checkpoint_num_layers', None) + + # any configs that are not in the nemo model config will be added here + config_mapping = { + 'apply_query_key_layer_scaling': apply_query_key_layer_scaling, + 'apply_residual_connection_post_layernorm': False, # we don't use this in NeMo + 'layernorm_zero_centered_gamma': False, + 'add_bias_linear': add_bias_linear, + 'add_qkv_bias': add_qkv_bias, + 'gated_linear_unit': gated_linear_unit, + 'activation_func': activation_func, + 'normalization': normalization, + 'init_method': init_method, + 'output_layer_init_method': output_layer_init_method, + 'attention_softmax_in_fp32': attention_softmax_in_fp32, + 'bias_activation_fusion': bias_activation_fusion, + 'bias_dropout_fusion': bias_dropout_fusion, + 'apply_rope_fusion': apply_rope_fusion, + 'recompute_granularity': recompute_granularity, + 'recompute_method': recompute_method, + 'recompute_num_layers': recompute_num_layers, + 'distribute_saved_activations': False, # not currently used in NeMo + 'fp8': None, + 'rotary_interleaved': rotary_interleaved, + 'deallocate_pipeline_outputs': True, + } + + # populate the transformer config dict + for field in fields(TransformerConfig): + # config mapping has second highest priority + if field.name in config_mapping: + transformer_config_dict[field.name] = config_mapping[field.name] + # then config + elif field.name in cfg: + transformer_config_dict[field.name] = cfg[field.name] + # then model parallel config + elif field in fields(model_parallel_config): + transformer_config_dict[field.name] = getattr(model_parallel_config, field.name) + else: + logging.warning( + f"The model: {self} does not have field.name: {field.name} in its cfg. " + f"Add this key to cfg or config_mapping to make to make it configurable." + ) + + transformer_config = TransformerConfig(**transformer_config_dict) + + for key, value in model_specific_configs.items(): + setattr(transformer_config, key, value) + + # pass mcore customization configs directly to mcore + mcore_customization_config_dict = model_cfg.get('mcore_customization_config', {}) + for key, value in mcore_customization_config_dict.items(): + setattr(transformer_config, key, value) + + return transformer_config diff --git a/nemo/collections/multimodal/modules/stable_diffusion/attention.py b/nemo/collections/multimodal/modules/stable_diffusion/attention.py index 2eeed97db7810..e748bcbf93a08 100644 --- a/nemo/collections/multimodal/modules/stable_diffusion/attention.py +++ b/nemo/collections/multimodal/modules/stable_diffusion/attention.py @@ -227,6 +227,10 @@ def __init__(self, in_features, out_features, bias=True, lora_network_alpha=None def forward(self, x): mixed_x = super().forward(x) if self.is_adapter_available(): + # return this output if lora is not enabled + cfg = self.get_adapter_cfg(AdapterName.PARALLEL_LINEAR_ADAPTER) + if not cfg['enabled']: + return mixed_x lora_linear_adapter = self.get_adapter_module(AdapterName.PARALLEL_LINEAR_ADAPTER) lora_mixed_x = lora_linear_adapter(x) # This value has the same meaning as the `--network_alpha` option in the kohya-ss trainer script. diff --git a/nemo/collections/multimodal/modules/stable_diffusion/diffusionmodules/denoiser.py b/nemo/collections/multimodal/modules/stable_diffusion/diffusionmodules/denoiser.py index df1f27449bd19..a358bb08f92d4 100644 --- a/nemo/collections/multimodal/modules/stable_diffusion/diffusionmodules/denoiser.py +++ b/nemo/collections/multimodal/modules/stable_diffusion/diffusionmodules/denoiser.py @@ -33,13 +33,18 @@ def possibly_quantize_c_noise(self, c_noise): def w(self, sigma): return self.weighting(sigma) - def __call__(self, network, input, sigma, cond): + def __call__(self, network, input, sigma, cond, return_noise=False): sigma = self.possibly_quantize_sigma(sigma) sigma_shape = sigma.shape sigma = append_dims(sigma, input.ndim) c_skip, c_out, c_in, c_noise = self.scaling(sigma) c_noise = self.possibly_quantize_c_noise(c_noise.reshape(sigma_shape)) - return network(input * c_in, c_noise, cond) * c_out + input * c_skip + # predict noise from network + noise_pred = network(input * c_in, c_noise, cond) + denoised = noise_pred * c_out + input * c_skip + if return_noise: + return denoised, noise_pred + return denoised class DiscreteDenoiser(Denoiser): diff --git a/nemo/collections/multimodal/modules/stable_diffusion/diffusionmodules/openaimodel.py b/nemo/collections/multimodal/modules/stable_diffusion/diffusionmodules/openaimodel.py index 804eb9a2753ac..3fb3a8cee3a2f 100644 --- a/nemo/collections/multimodal/modules/stable_diffusion/diffusionmodules/openaimodel.py +++ b/nemo/collections/multimodal/modules/stable_diffusion/diffusionmodules/openaimodel.py @@ -789,6 +789,7 @@ def __init__( self.input_blocks.append(TimestepEmbedSequential(*layers)) self._feature_size += ch input_block_chans.append(ch) + if level != len(channel_mult) - 1: out_ch = ch self.input_blocks.append( @@ -954,6 +955,7 @@ def __init__( ) if from_pretrained is not None: + logging.info(f"Attempting to load pretrained unet from {from_pretrained}") if from_pretrained.endswith('safetensors'): from safetensors.torch import load_file as load_safetensors @@ -1023,6 +1025,16 @@ def _input_blocks_mapping(self, input_dict): .replace('conv2', 'out_layers.3') .replace('conv_shortcut', 'skip_connection') ) + ## Rohit: I've changed this to make sure it is compatible + # post_fix = ( + # key_[25:] + # .replace('time_emb_proj', 'emb_layers.1') + # .replace('norm1', 'in_layers.0') + # .replace('norm2', 'out_layers.0') + # .replace('conv1', 'in_layers.1') + # .replace('conv2', 'out_layers.2') + # .replace('conv_shortcut', 'skip_connection') + # ) res_dict["input_blocks." + str(target_id) + '.0.' + post_fix] = value_ elif "attentions" in key_: id_1 = int(key_[26]) @@ -1170,7 +1182,7 @@ def te_fp8_key_mapping(self, unet_dict): return new_state_dict def _state_key_mapping(self, state_dict: dict): - + # state_dict is a HF model res_dict = {} input_dict = {} mid_dict = {} diff --git a/nemo/collections/multimodal/modules/stable_diffusion/diffusionmodules/sampling.py b/nemo/collections/multimodal/modules/stable_diffusion/diffusionmodules/sampling.py index c636ffec345d1..bfae8790eeb26 100644 --- a/nemo/collections/multimodal/modules/stable_diffusion/diffusionmodules/sampling.py +++ b/nemo/collections/multimodal/modules/stable_diffusion/diffusionmodules/sampling.py @@ -47,7 +47,12 @@ def __init__( ): self.num_steps = num_steps self.discretization = instantiate_from_config(discretization_config) - self.guider = instantiate_from_config(default(guider_config, DEFAULT_GUIDER,)) + self.guider = instantiate_from_config( + default( + guider_config, + DEFAULT_GUIDER, + ) + ) self.verbose = verbose self.device = device @@ -93,35 +98,50 @@ def euler_step(self, x, d, dt): class EDMSampler(SingleStepDiffusionSampler): def __init__(self, s_churn=0.0, s_tmin=0.0, s_tmax=float("inf"), s_noise=1.0, *args, **kwargs): super().__init__(*args, **kwargs) - self.s_churn = s_churn self.s_tmin = s_tmin self.s_tmax = s_tmax self.s_noise = s_noise - def sampler_step(self, sigma, next_sigma, denoiser, x, cond, uc=None, gamma=0.0): + def sampler_step(self, sigma, next_sigma, denoiser, x, cond, uc=None, gamma=0.0, return_noise=False): + # x is actually \bar{x} as in the DDIM paper sigma_hat = sigma * (gamma + 1.0) if gamma > 0: eps = torch.randn_like(x) * self.s_noise - x = x + eps * append_dims(sigma_hat ** 2 - sigma ** 2, x.ndim) ** 0.5 + x = x + eps * append_dims(sigma_hat**2 - sigma**2, x.ndim) ** 0.5 denoised = self.denoise(x, denoiser, sigma_hat, cond, uc) + # this is the noise (e_t) d = to_d(x, sigma_hat, denoised) dt = append_dims(next_sigma - sigma_hat, x.ndim) - euler_step = self.euler_step(x, d, dt) + euler_step = self.euler_step(x, d, dt) # this is x_{t-\delta{t}} x = self.possible_correction_step(euler_step, x, d, dt, next_sigma, denoiser, cond, uc) + if return_noise: + return x, d return x + def get_gamma(self, sigmas, num_sigmas, index): + gamma = ( + min(self.s_churn / (num_sigmas - 1), 2**0.5 - 1) if self.s_tmin <= sigmas[index] <= self.s_tmax else 0.0 + ) + return gamma + def __call__(self, denoiser, x, cond, uc=None, num_steps=None): + # prepare_sampling_loop converts x into \bar{x} = x / \sqrt{\tilde{\alpha_t}} x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(x, cond, uc, num_steps) for i in self.get_sigma_gen(num_sigmas): - gamma = ( - min(self.s_churn / (num_sigmas - 1), 2 ** 0.5 - 1) if self.s_tmin <= sigmas[i] <= self.s_tmax else 0.0 + gamma = self.get_gamma(sigmas, num_sigmas, i) + x = self.sampler_step( + s_in * sigmas[i], + s_in * sigmas[i + 1], + denoiser, + x, + cond, + uc, + gamma, ) - x = self.sampler_step(s_in * sigmas[i], s_in * sigmas[i + 1], denoiser, x, cond, uc, gamma,) - return x @@ -151,14 +171,24 @@ def __call__(self, denoiser, x, cond, uc=None, num_steps=None): x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(x, cond, uc, num_steps) for i in self.get_sigma_gen(num_sigmas): - x = self.sampler_step(s_in * sigmas[i], s_in * sigmas[i + 1], denoiser, x, cond, uc,) + x = self.sampler_step( + s_in * sigmas[i], + s_in * sigmas[i + 1], + denoiser, + x, + cond, + uc, + ) return x class LinearMultistepSampler(BaseDiffusionSampler): def __init__( - self, order=4, *args, **kwargs, + self, + order=4, + *args, + **kwargs, ): super().__init__(*args, **kwargs) @@ -276,7 +306,15 @@ def get_mult(self, h, r, t, t_next, previous_sigma): return mult1, mult2 def sampler_step( - self, old_denoised, previous_sigma, sigma, next_sigma, denoiser, x, cond, uc=None, + self, + old_denoised, + previous_sigma, + sigma, + next_sigma, + denoiser, + x, + cond, + uc=None, ): denoised = self.denoise(x, denoiser, sigma, cond, uc) diff --git a/nemo/collections/multimodal/modules/stable_diffusion/diffusionmodules/wrappers.py b/nemo/collections/multimodal/modules/stable_diffusion/diffusionmodules/wrappers.py index 0d465c1275c63..24e2124e6f83a 100644 --- a/nemo/collections/multimodal/modules/stable_diffusion/diffusionmodules/wrappers.py +++ b/nemo/collections/multimodal/modules/stable_diffusion/diffusionmodules/wrappers.py @@ -37,6 +37,11 @@ class OpenAIWrapper(IdentityWrapper): def forward(self, x: torch.Tensor, t: torch.Tensor, c: dict, **kwargs) -> torch.Tensor: if c.get("concat", None): x = torch.cat((x, c.get("concat")), dim=1) + return self.diffusion_model( - x, timesteps=t, context=c.get("crossattn", None), y=c.get("vector", None), **kwargs, + x, + timesteps=t, + context=c.get("crossattn", None), + y=c.get("vector", None), + **kwargs, ) diff --git a/nemo/collections/multimodal/modules/stable_diffusion/encoders/modules.py b/nemo/collections/multimodal/modules/stable_diffusion/encoders/modules.py index bff579bbca4fb..ab33532c3c1fa 100644 --- a/nemo/collections/multimodal/modules/stable_diffusion/encoders/modules.py +++ b/nemo/collections/multimodal/modules/stable_diffusion/encoders/modules.py @@ -298,7 +298,7 @@ def encode(self, x): class BERTTokenizer(AbstractEncoder): - """ Uses a pretrained BERT tokenizer by huggingface. Vocab size: 30522 (?)""" + """Uses a pretrained BERT tokenizer by huggingface. Vocab size: 30522 (?)""" def __init__(self, device="cuda", vq_interface=True, max_length=77): super().__init__() @@ -530,7 +530,10 @@ def __init__( print(f"Downloading clip with", arch, version, cache_dir) self.device = device model, _, _ = open_clip.create_model_and_transforms( - arch, device=torch.device("cpu"), pretrained=version, cache_dir=cache_dir, + arch, + device=torch.device("cpu"), + pretrained=version, + cache_dir=cache_dir, ) del model.visual self.model = model @@ -669,7 +672,11 @@ def build_tokenizer(self, cfg): legacy=legacy, ) - _, self.text_transform = get_preprocess_fns(cfg, self.tokenizer, is_train=False,) + _, self.text_transform = get_preprocess_fns( + cfg, + self.tokenizer, + is_train=False, + ) self.max_length = cfg.text.get("max_position_embeddings") def load_model(self, cfg, state_dict): @@ -699,8 +706,7 @@ def load_model(self, cfg, state_dict): def _vocab_size_with_padding(self, orig_vocab_size, make_vocab_size_divisible_by, tensor_model_parallel_size): after = orig_vocab_size multiple = make_vocab_size_divisible_by * tensor_model_parallel_size - while (after % multiple) != 0: - after += 1 + after = ((after + multiple - 1) // multiple) * multiple return after def forward(self, text): @@ -765,7 +771,11 @@ def __init__( super().__init__() assert layer in self.LAYERS self.projection_dim = 1280 - model, _, _ = open_clip.create_model_and_transforms(arch, device=torch.device("cpu"), pretrained=version,) + model, _, _ = open_clip.create_model_and_transforms( + arch, + device=torch.device("cpu"), + pretrained=version, + ) del model.visual self.model = model diff --git a/nemo/collections/multimodal/parts/utils.py b/nemo/collections/multimodal/parts/utils.py index 9ad8856daa63f..5a01e8702a9ee 100644 --- a/nemo/collections/multimodal/parts/utils.py +++ b/nemo/collections/multimodal/parts/utils.py @@ -23,11 +23,11 @@ from pytorch_lightning import Trainer from pytorch_lightning.plugins.environments import TorchElasticEnvironment from transformers import CLIPImageProcessor, SiglipImageProcessor -from nemo.collections.multimodal.data.clip.augmentations.augmentations import image_transform +from nemo.collections.multimodal.data.clip.augmentations.augmentations import image_transform from nemo.collections.multimodal.data.neva.neva_dataset import process_image from nemo.collections.nlp.modules.common.megatron.megatron_init import fake_initialize_model_parallel -from nemo.collections.nlp.parts.nlp_overrides import NLPDDPStrategy, NLPSaveRestoreConnector +from nemo.collections.nlp.parts.nlp_overrides import NLPDDPStrategy, NLPFSDPStrategy, NLPSaveRestoreConnector from nemo.collections.nlp.parts.peft_config import PEFT_CONFIG_MAP from nemo.collections.nlp.parts.utils_funcs import torch_dtype_from_precision from nemo.utils import AppState, logging @@ -135,8 +135,10 @@ def load_nemo_model_weights(nemo_path, sharded_state_dict=None): # distributed checkpointing if state_dict is None and sharded_state_dict is not None: + is_dist_ckpt = True checkpoint = dict(state_dict=sharded_state_dict) + tmp_model_weights_ckpt = os.path.join(tmpdir, save_restore_connector.model_weights_ckpt) tmp_model_weights_dir = os.path.splitext(tmp_model_weights_ckpt)[0] assert os.path.isdir(tmp_model_weights_dir), f'Expected {tmp_model_weights_dir} to be a directory.' @@ -274,10 +276,23 @@ def setup_trainer_and_model_for_inference( # Use the NLPDDPStrategy for the distributed data parallel strategy. # We don't use DDP for async grad allreduce and don't find unused parameters. - strategy = NLPDDPStrategy( - no_ddp_communication_hook=True, - find_unused_parameters=False, - ) + if not cfg.model.get('fsdp', False): + logging.info("FSDP is False, using DDP strategy.") + strategy = NLPDDPStrategy( + no_ddp_communication_hook=True, + find_unused_parameters=False, + ) + else: + logging.info("Using FSDP strategy.") + strategy = NLPFSDPStrategy( + limit_all_gathers=cfg.model.get('fsdp_limit_all_gathers', True), + sharding_strategy=cfg.model.get('fsdp_sharding_strategy', 'full'), + cpu_offload=cfg.model.get('fsdp_cpu_offload', True), + grad_reduce_dtype=cfg.model.get('fsdp_grad_reduce_dtype', 32), + precision=cfg.trainer.precision, + # use_orig_params=cfg.model.inductor, + set_buffer_dtype=cfg.get('fsdp_set_buffer_dtype', None), + ) # Set up the trainer with the specified plugins and strategy. trainer = Trainer(plugins=plugins, strategy=strategy, **cfg.trainer) @@ -321,7 +336,9 @@ def setup_trainer_and_model_for_inference( ) else: - raise ValueError(f"Unrecognized checkpoint type: {cfg.model.restore_from_path}") + # load a model from scratch + logging.warning("Loading a model from scratch for inference. Tread carefully.") + model = model_provider(cfg=cfg.model, trainer=trainer) # initialize apex DDP strategy def dummy(): @@ -501,7 +518,7 @@ def expand2square(pil_img, background_color): result.paste(pil_img, ((height - width) // 2, 0)) return result - frames = expand2square(frames, tuple(int(x * 255) for x in processor.image_mean)) + frames = [expand2square(frame, tuple(int(x * 255) for x in self.processor.image_mean)) for frame in frames] frames = processor.preprocess(frames, return_tensors='pt')['pixel_values'] else: frames = processor.preprocess(frames, return_tensors='pt')['pixel_values'] @@ -525,8 +542,8 @@ def create_image_processor(mm_cfg): else: raise (ValueError("Currently only support CLIPImageProcessor and SiglipImageProcessor from Huggingface")) - crop_size = mm_cfg.vision_encoder.get("crop_size", (224, 224)) - if hasattr(image_processor, 'crop_size'): + crop_size = mm_cfg.vision_encoder.get("crop_size") + if hasattr(image_processor, 'crop_size') and crop_size is not None: assert crop_size == ( image_processor.crop_size['height'], image_processor.crop_size['width'], diff --git a/nemo/collections/multimodal/speech_cv/data/video_to_text.py b/nemo/collections/multimodal/speech_cv/data/video_to_text.py index a20d6e5bb9a8c..2034e554d7a1e 100644 --- a/nemo/collections/multimodal/speech_cv/data/video_to_text.py +++ b/nemo/collections/multimodal/speech_cv/data/video_to_text.py @@ -19,7 +19,7 @@ import webdataset as wds from nemo.collections.asr.data.audio_to_text import cache_datastore_manifests, expand_sharded_filepaths -from nemo.collections.asr.parts.utils.audio_utils import ChannelSelectorType +from nemo.collections.asr.parts.preprocessing.segment import ChannelSelectorType from nemo.collections.common import tokenizers from nemo.collections.common.parts.preprocessing import collections, parsers from nemo.collections.multimodal.speech_cv.parts.preprocessing.features import VideoFeaturizer @@ -123,8 +123,7 @@ class _VideoTextDataset(Dataset): @property def output_types(self) -> Optional[Dict[str, NeuralType]]: - """Returns definitions of module output ports. - """ + """Returns definitions of module output ports.""" return { 'video_signal': NeuralType(('B', 'C', 'T', 'H', 'W'), VideoSignal()), 'video_sig_length': NeuralType(tuple('B'), LengthsType()), @@ -307,8 +306,7 @@ class VideoToBPEDataset(_VideoTextDataset): @property def output_types(self) -> Optional[Dict[str, NeuralType]]: - """Returns definitions of module output ports. - """ + """Returns definitions of module output ports.""" return { 'video_signal': NeuralType(('B', 'C', 'T', 'H', 'W'), VideoSignal()), 'video_sig_length': NeuralType(tuple('B'), LengthsType()), @@ -411,8 +409,7 @@ class VideoToCharDataset(_VideoTextDataset): @property def output_types(self) -> Optional[Dict[str, NeuralType]]: - """Returns definitions of module output ports. - """ + """Returns definitions of module output ports.""" return { 'video_signal': NeuralType(('B', 'C', 'T', 'H', 'W'), VideoSignal()), 'video_sig_length': NeuralType(tuple('B'), LengthsType()), @@ -641,8 +638,7 @@ def __next__(self): return TarredAudioFilter(self.manifest_processor.collection) def _loop_offsets(self, iterator): - """This function is used to iterate through utterances with different offsets for each file. - """ + """This function is used to iterate through utterances with different offsets for each file.""" class TarredAudioLoopOffsets: def __init__(self, collection): @@ -675,8 +671,7 @@ def _collate_fn(self, batch): return _video_speech_collate_fn(batch, self.pad_id) def _build_sample(self, tup): - """Builds the training sample by combining the data from the WebDataset with the manifest info. - """ + """Builds the training sample by combining the data from the WebDataset with the manifest info.""" video_tuple, audio_filename, offset_id = tup # Grab manifest entry from self.manifest_preprocessor.collection diff --git a/nemo/collections/multimodal/speech_cv/models/visual_ctc_models.py b/nemo/collections/multimodal/speech_cv/models/visual_ctc_models.py index a8226c3fc403e..13f92f1acb14a 100644 --- a/nemo/collections/multimodal/speech_cv/models/visual_ctc_models.py +++ b/nemo/collections/multimodal/speech_cv/models/visual_ctc_models.py @@ -29,8 +29,8 @@ from nemo.collections.asr.metrics.wer import WER from nemo.collections.asr.models.asr_model import ASRModel, ExportableEncDecModel from nemo.collections.asr.parts.mixins import ASRModuleMixin, InterCTCMixin +from nemo.collections.asr.parts.preprocessing.segment import ChannelSelectorType from nemo.collections.asr.parts.submodules.ctc_decoding import CTCDecoding, CTCDecodingConfig -from nemo.collections.asr.parts.utils.audio_utils import ChannelSelectorType from nemo.collections.multimodal.speech_cv.data import video_to_text_dataset from nemo.core.classes.common import PretrainedModelInfo, typecheck from nemo.core.classes.mixins import AccessMixin @@ -210,7 +210,9 @@ def transcribe( hypotheses.append(lg.cpu().numpy()) else: current_hypotheses, all_hyp = self.decoding.ctc_decoder_predictions_tensor( - logits, decoder_lengths=logits_len, return_hypotheses=return_hypotheses, + logits, + decoder_lengths=logits_len, + return_hypotheses=return_hypotheses, ) if return_hypotheses: @@ -579,7 +581,9 @@ def predict_step(self, batch, batch_idx, dataloader_idx=0): ) transcribed_texts, _ = self.wer.decoding.ctc_decoder_predictions_tensor( - decoder_outputs=log_probs, decoder_lengths=encoded_len, return_hypotheses=False, + decoder_outputs=log_probs, + decoder_lengths=encoded_len, + return_hypotheses=False, ) sample_id = sample_id.cpu().detach().numpy() @@ -598,7 +602,12 @@ def validation_step(self, batch, batch_idx, dataloader_idx=0): log_probs=log_probs, targets=transcript, input_lengths=encoded_len, target_lengths=transcript_len ) loss_value, metrics = self.add_interctc_losses( - loss_value, transcript, transcript_len, compute_wer=True, log_wer_num_denom=True, log_prefix="val_", + loss_value, + transcript, + transcript_len, + compute_wer=True, + log_wer_num_denom=True, + log_prefix="val_", ) self.wer.update( diff --git a/nemo/collections/multimodal/speech_cv/models/visual_hybrid_rnnt_ctc_models.py b/nemo/collections/multimodal/speech_cv/models/visual_hybrid_rnnt_ctc_models.py index 07dc46d3e061c..1b30263985daf 100644 --- a/nemo/collections/multimodal/speech_cv/models/visual_hybrid_rnnt_ctc_models.py +++ b/nemo/collections/multimodal/speech_cv/models/visual_hybrid_rnnt_ctc_models.py @@ -26,8 +26,8 @@ from nemo.collections.asr.losses.ctc import CTCLoss from nemo.collections.asr.metrics.wer import WER from nemo.collections.asr.parts.mixins import ASRBPEMixin, InterCTCMixin +from nemo.collections.asr.parts.preprocessing.segment import ChannelSelectorType from nemo.collections.asr.parts.submodules.ctc_decoding import CTCDecoding, CTCDecodingConfig -from nemo.collections.asr.parts.utils.audio_utils import ChannelSelectorType from nemo.collections.multimodal.speech_cv.models.visual_rnnt_models import VisualEncDecRNNTModel from nemo.core.classes.common import PretrainedModelInfo from nemo.core.classes.mixins import AccessMixin @@ -178,7 +178,9 @@ def transcribe( logits = self.ctc_decoder(encoder_output=encoded) best_hyp, all_hyp = self.ctc_decoding.ctc_decoder_predictions_tensor( - logits, encoded_len, return_hypotheses=return_hypotheses, + logits, + encoded_len, + return_hypotheses=return_hypotheses, ) if return_hypotheses: # dump log probs per file @@ -550,7 +552,12 @@ def validation_step(self, batch, batch_idx, dataloader_idx=0): # Add interCTC losses ctc_loss, interctc_tensorboard_logs = self.add_interctc_losses( - ctc_loss, transcript, transcript_len, compute_wer=True, log_wer_num_denom=True, log_prefix="val_", + ctc_loss, + transcript, + transcript_len, + compute_wer=True, + log_wer_num_denom=True, + log_prefix="val_", ) tensorboard_logs.update(interctc_tensorboard_logs) @@ -559,7 +566,10 @@ def validation_step(self, batch, batch_idx, dataloader_idx=0): loss_value = (1 - self.ctc_loss_weight) * loss_value + self.ctc_loss_weight * ctc_loss tensorboard_logs['val_loss'] = loss_value self.ctc_wer.update( - predictions=log_probs, targets=transcript, target_lengths=transcript_len, predictions_lengths=encoded_len, + predictions=log_probs, + targets=transcript, + target_lengths=transcript_len, + predictions_lengths=encoded_len, ) ctc_wer, ctc_wer_num, ctc_wer_denom = self.ctc_wer.compute() self.ctc_wer.reset() diff --git a/nemo/collections/multimodal/speech_cv/models/visual_rnnt_models.py b/nemo/collections/multimodal/speech_cv/models/visual_rnnt_models.py index f5519b4808281..5a86eed93019c 100644 --- a/nemo/collections/multimodal/speech_cv/models/visual_rnnt_models.py +++ b/nemo/collections/multimodal/speech_cv/models/visual_rnnt_models.py @@ -30,8 +30,8 @@ from nemo.collections.asr.models.asr_model import ASRModel from nemo.collections.asr.modules.rnnt import RNNTDecoderJoint from nemo.collections.asr.parts.mixins import ASRModuleMixin +from nemo.collections.asr.parts.preprocessing.segment import ChannelSelectorType from nemo.collections.asr.parts.submodules.rnnt_decoding import RNNTDecoding, RNNTDecodingConfig -from nemo.collections.asr.parts.utils.audio_utils import ChannelSelectorType from nemo.collections.multimodal.speech_cv.data import video_to_text_dataset from nemo.core.classes import Exportable from nemo.core.classes.common import PretrainedModelInfo, typecheck @@ -89,7 +89,10 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None): # Setup decoding objects self.decoding = RNNTDecoding( - decoding_cfg=self.cfg.decoding, decoder=self.decoder, joint=self.joint, vocabulary=self.joint.vocabulary, + decoding_cfg=self.cfg.decoding, + decoder=self.decoder, + joint=self.joint, + vocabulary=self.joint.vocabulary, ) # Setup WER calculation self.wer = WER( @@ -364,7 +367,10 @@ def change_vocabulary(self, new_vocabulary: List[str], decoding_cfg: Optional[Di decoding_cfg = OmegaConf.merge(decoding_cls, decoding_cfg) self.decoding = RNNTDecoding( - decoding_cfg=decoding_cfg, decoder=self.decoder, joint=self.joint, vocabulary=self.joint.vocabulary, + decoding_cfg=decoding_cfg, + decoder=self.decoder, + joint=self.joint, + vocabulary=self.joint.vocabulary, ) self.wer = WER( @@ -419,7 +425,10 @@ def change_decoding_strategy(self, decoding_cfg: DictConfig): decoding_cfg = OmegaConf.merge(decoding_cls, decoding_cfg) self.decoding = RNNTDecoding( - decoding_cfg=decoding_cfg, decoder=self.decoder, joint=self.joint, vocabulary=self.joint.vocabulary, + decoding_cfg=decoding_cfg, + decoder=self.decoder, + joint=self.joint, + vocabulary=self.joint.vocabulary, ) self.wer = WER( diff --git a/nemo/collections/multimodal/speech_llm/data/audio_text_dataset.py b/nemo/collections/multimodal/speech_llm/data/audio_text_dataset.py index 94d2cd50a240b..a433a5a6badfe 100644 --- a/nemo/collections/multimodal/speech_llm/data/audio_text_dataset.py +++ b/nemo/collections/multimodal/speech_llm/data/audio_text_dataset.py @@ -29,7 +29,7 @@ ) from nemo.collections.asr.data.audio_to_text_dataset import ConcatDataset, convert_to_config_list, get_chain_dataset from nemo.collections.asr.parts.preprocessing.features import WaveformFeaturizer -from nemo.collections.asr.parts.utils.audio_utils import ChannelSelectorType +from nemo.collections.asr.parts.preprocessing.segment import ChannelSelectorType from nemo.collections.common.parts.preprocessing import collections from nemo.collections.multimodal.speech_llm.parts.utils.data_utils import ( TextProcessing, diff --git a/nemo/collections/nlp/data/information_retrieval/gpt_embedding_dataset.py b/nemo/collections/nlp/data/information_retrieval/gpt_embedding_dataset.py index e697d5ec3bf64..3a2a8152313e1 100644 --- a/nemo/collections/nlp/data/information_retrieval/gpt_embedding_dataset.py +++ b/nemo/collections/nlp/data/information_retrieval/gpt_embedding_dataset.py @@ -27,7 +27,7 @@ from nemo.core.classes import Dataset from nemo.utils import logging -__all__ = ['GPTEmbeddingDataset'] +__all__ = ['GPTEmbeddingDataset', 'GPTRerankerDataset'] class GPTEmbeddingDataset(Dataset): @@ -49,7 +49,7 @@ def __init__( data_type: str = 'train', # train, query or doc ): """ - file_path: Path to a JSONL dataset with (query,pos_doc,neg_doc) triplets in jsonl format. + file_path: Path to a JSONL dataset with (query,pos_doc,neg_doc) triplets in jsonl format. tokenizer: Tokenizer for the dataset. Instance of a class that inherits TokenizerSpec (ex: YTTM, SentencePiece). max_seq_length (int): maximum sequence length for each dataset examples. Examples will either be truncated to fit this length or dropped if they cannot be truncated. min_seq_length (int): min length of each data example in the dataset. Data examples will be dropped if they do not meet the min length requirements. @@ -279,3 +279,138 @@ def collate_fn(self, batch): } return processed_batch + + +class GPTRerankerDataset(GPTEmbeddingDataset): + def __init__( + self, + file_path: str, + tokenizer: TokenizerSpec, + max_seq_length: int = 1024, + min_seq_length: int = 1, + add_bos: bool = False, + add_eos: bool = True, + max_num_samples: int = None, + seed: int = 1234, + index_mapping_dir: str = None, + virtual_tokens: int = 0, + memmap_workers: Optional[int] = None, + truncation_method: str = 'right', + special_tokens: Optional[Mapping[str, str]] = None, # special tokens, a dictory of {token_type: token} + data_type: str = 'train', # train, query or doc + ): + """ + file_path: Path to a JSONL dataset with (query,pos_doc,neg_doc) triplets in jsonl format. + tokenizer: Tokenizer for the dataset. Instance of a class that inherits TokenizerSpec (ex: YTTM, SentencePiece). + max_seq_length (int): maximum sequence length for each dataset examples. Examples will either be truncated to fit this length or dropped if they cannot be truncated. + min_seq_length (int): min length of each data example in the dataset. Data examples will be dropped if they do not meet the min length requirements. + add_bos (bool): Whether to add a beginning of sentence token to each data example + add_eos (bool): Whether to add an end of sentence token to each data example + seed: Random seed for data shuffling. + max_num_samples: Maximum number of samples to load. This can be > dataset length if you want to oversample data. If None, all samples will be loaded. + index_mapping_dir: Directory to save the index mapping to. If None, will write to the same folder as the dataset. + truncation_method: Truncation from which position. Options: ['left', 'right'] + special_tokens: special tokens for the chat prompts, a dictionary of {token_type: token}. Default: {'system_turn_start': '', 'turn_start': '', 'label_start': '', 'end_of_turn': '\n', "end_of_name": "\n"} + """ + super().__init__( + file_path=file_path, + tokenizer=tokenizer, + max_seq_length=max_seq_length, + min_seq_length=min_seq_length, + add_bos=add_bos, + add_eos=add_eos, + max_num_samples=max_num_samples, + seed=seed, + index_mapping_dir=index_mapping_dir, + virtual_tokens=virtual_tokens, + memmap_workers=memmap_workers, + truncation_method=truncation_method, + special_tokens=special_tokens, + data_type=data_type, + ) + + def _process_example(self, example): + """ + Create an example by concatenating text and answer. + Truncation is carried out when needed, but it is performed only on the prompt side. + BOS, EOS, and SEP, are added if specified. + """ + metadata = {k: v for k, v in example.items()} + if self.data_type == 'train': + qd = self.tokenizer.text_to_ids( + "query: " + example['query'].strip() + " passage: " + example['pos_doc'].strip() + ) + qnd = self.tokenizer.text_to_ids( + "query: " + example['query'].strip() + " passage: " + example['neg_doc'].strip() + ) + else: + qd = self.tokenizer.text_to_ids( + "query: " + example['query'].strip() + " passage: " + example['pos_doc'].strip() + ) + qnd = [] + + if self.virtual_tokens: + # (@adithyare) we are going to insert "pad/eos" tokens in the beginning of the text and context + # these pad/eos tokens are placeholders for virtual tokens for ptuning (if used) + qd = [self.tokenizer.eos_id] * self.virtual_tokens + qd # type: ignore + qnd = [self.tokenizer.eos_id] * self.virtual_tokens + qnd # type: ignore + + if self.add_bos: + qd = [self.tokenizer.bos_id] + qd # type: ignore + qnd = [self.tokenizer.bos_id] + qnd # type: ignore + + # TODO: (@adithyare) should probably add a warning before truncation + qd = qd[: self.max_seq_length - 1] + qnd = qnd[: self.max_seq_length - 1] + + if self.add_eos: + qd = qd + [self.tokenizer.eos_id] # type: ignore + qnd = qnd + [self.tokenizer.eos_id] # type: ignore + + processed_example = { + 'query_pos_doc': qd, + 'query_neg_doc': qnd, + 'metadata': metadata, + } + + return processed_example + + def collate_fn(self, batch): + input_ids = [] + metadata = [] + lengths = [] + max_length = -1 + for item in batch: + metadata.append(item['metadata']) + if self.data_type == 'train': + input_ids.append(item['query_pos_doc']) + lengths.append(len(item['query_pos_doc'])) + input_ids.append(item['query_neg_doc']) + lengths.append(len(item['query_neg_doc'])) + max_length = max(max_length, len(item['query_pos_doc']), len(item['query_neg_doc'])) + else: + input_ids.append(item['query_pos_doc']) + lengths.append(len(item['query_pos_doc'])) + max_length = max(max_length, len(item['query_pos_doc'])) + + max_length = min(self.max_seq_length, self._ceil_to_nearest(max_length, 16)) + assert max_length <= self.max_seq_length + + attention_mask = [self._create_attention_mask(max_length) for _ in input_ids] + attention_mask = torch.stack(attention_mask) + position_ids = [list(range(max_length)) for _ in input_ids] + position_ids = torch.LongTensor(position_ids) + input_ids = torch.LongTensor( + self._collate_item(input_ids, max_length=max_length, pad_id=self.tokenizer.eos_id) + ) + lengths = torch.LongTensor(lengths) - 1 # subtract 1 to account for the eos token + + processed_batch = { + 'tokens': input_ids, + 'attention_mask': attention_mask, + 'loss_mask': lengths, + 'position_ids': position_ids, + 'metadata': metadata, + } + + return processed_batch diff --git a/nemo/collections/nlp/data/language_modeling/megatron/retro_dataset.py b/nemo/collections/nlp/data/language_modeling/megatron/retro_dataset.py index 0f8d3410398d5..7d604c0b51bcc 100644 --- a/nemo/collections/nlp/data/language_modeling/megatron/retro_dataset.py +++ b/nemo/collections/nlp/data/language_modeling/megatron/retro_dataset.py @@ -122,7 +122,11 @@ def __getitem__(self, idx): def build_train_valid_test_datasets( - cfg, retro_config: RetroConfig, train_valid_test_num_samples, seq_length, tokenizer, + cfg, + retro_config: RetroConfig, + train_valid_test_num_samples, + seq_length, + tokenizer, ): # gpt dataset @@ -135,7 +139,10 @@ def build_train_valid_test_datasets( } retro_train_ds, retro_valid_ds, retro_test_ds = get_retro_datasets( - config=retro_config, gpt_datasets=gpt_datasets, sample_length=seq_length, eod_token_id=tokenizer.eos_id, + config=retro_config, + gpt_datasets=gpt_datasets, + sample_length=seq_length, + eod_token_id=tokenizer.eos_id, ) train_ds = ( diff --git a/nemo/collections/nlp/models/information_retrieval/megatron_gpt_embedding_model.py b/nemo/collections/nlp/models/information_retrieval/megatron_gpt_embedding_model.py index 67fd2b1b6c62e..c7565f45358eb 100644 --- a/nemo/collections/nlp/models/information_retrieval/megatron_gpt_embedding_model.py +++ b/nemo/collections/nlp/models/information_retrieval/megatron_gpt_embedding_model.py @@ -36,11 +36,6 @@ except (ImportError, ModuleNotFoundError): HAVE_MEGATRON_CORE = False -try: - - HAVE_APEX = True -except (ImportError, ModuleNotFoundError): - HAVE_APEX = False def listify(tensor): @@ -52,6 +47,17 @@ def listify(tensor): return l_tensor +def _gather_global_inbatch_representations(local_eos_tensor): + local_eos_tensor = local_eos_tensor.contiguous() + global_eos_tensors = [ + torch.zeros_like(local_eos_tensor) for _ in range(parallel_state.get_data_parallel_world_size()) + ] + torch.distributed.all_gather(global_eos_tensors, local_eos_tensor, group=parallel_state.get_data_parallel_group()) + global_eos_tensors[parallel_state.get_data_parallel_rank()] = local_eos_tensor + global_eos_tensors = torch.cat(global_eos_tensors, dim=0) + return global_eos_tensors + + class MegatronGPTEmbeddingModel(MegatronGPTSFTModel): def __init__(self, cfg: DictConfig, trainer: Trainer): super().__init__(cfg, trainer=trainer) @@ -412,25 +418,20 @@ def inference_loss_func(self, loss_mask, num_valid_tokens_in_ub, eos_tensors): hs = eos_tensors hs = torch.nn.functional.normalize(hs, dim=1) _blank = torch.zeros(1, device=hs.device, dtype=hs.dtype)[0] - return _blank, hs, hs, _blank, _blank, _blank - - def _gather_global_inbatch_representations(self, local_eos_tensor): - local_eos_tensor = local_eos_tensor.contiguous() - global_eos_tensors = [ - torch.zeros_like(local_eos_tensor) for _ in range(parallel_state.get_data_parallel_world_size()) - ] - torch.distributed.all_gather( - global_eos_tensors, local_eos_tensor, group=parallel_state.get_data_parallel_group() - ) - global_eos_tensors[parallel_state.get_data_parallel_rank()] = local_eos_tensor - global_eos_tensors = torch.cat(global_eos_tensors, dim=0) - return global_eos_tensors + return { + "loss": _blank, + "query_hs": hs, + "pos_doc_hs": hs, + "pos_cs": _blank, + "neg_cs": _blank, + "diff_cs": _blank, + } def loss_func(self, loss_mask, num_valid_tokens_in_ub, output_tensor): idx = torch.arange(output_tensor.shape[1], device=output_tensor.device) eos_tensors = output_tensor[loss_mask, idx, :] if self.global_inbatch_negatives and self.trainer.training: - eos_tensors = self._gather_global_inbatch_representations(eos_tensors) + eos_tensors = _gather_global_inbatch_representations(eos_tensors) if not self.trainer.training: return self.inference_loss_func(loss_mask, num_valid_tokens_in_ub, eos_tensors) bs = eos_tensors.shape[0] // 3 @@ -464,4 +465,11 @@ def loss_func(self, loss_mask, num_valid_tokens_in_ub, output_tensor): query_hs = query_hs.clone().detach() pos_doc_hs = pos_doc_hs.clone().detach() diff_cs = pos_cs - neg_cs - return loss, query_hs, pos_doc_hs, pos_cs, neg_cs, diff_cs + return { + "loss": loss, + "query_hs": query_hs, + "pos_doc_hs": pos_doc_hs, + "pos_cs": pos_cs, + "neg_cs": neg_cs, + "diff_cs": diff_cs, + } diff --git a/nemo/collections/nlp/models/information_retrieval/megatron_gpt_reranker_model.py b/nemo/collections/nlp/models/information_retrieval/megatron_gpt_reranker_model.py new file mode 100644 index 0000000000000..e316871fe6070 --- /dev/null +++ b/nemo/collections/nlp/models/information_retrieval/megatron_gpt_reranker_model.py @@ -0,0 +1,301 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import itertools +import os + +import numpy as np +import torch +from omegaconf import DictConfig, ListConfig +from pytorch_lightning.trainer.trainer import Trainer + +from nemo.collections.nlp.data.information_retrieval.gpt_embedding_dataset import GPTRerankerDataset +from nemo.collections.nlp.data.language_modeling.megatron.base_dataset_utils import ( + get_datasets_weights_and_num_samples, +) +from nemo.collections.nlp.data.language_modeling.megatron.blendable_dataset import BlendableDataset +from nemo.collections.nlp.models.information_retrieval.megatron_gpt_embedding_model import ( + MegatronGPTEmbeddingModel, + _gather_global_inbatch_representations, +) +from nemo.utils import logging + +try: + from megatron.core import parallel_state + + HAVE_MEGATRON_CORE = True + +except (ImportError, ModuleNotFoundError): + + HAVE_MEGATRON_CORE = False + + +def listify(tensor): + l_tensor = [] + for t in tensor: + for rid in range(t.shape[0]): + r = t[rid, :].unsqueeze(0).cpu() + l_tensor.append(r) + return l_tensor + + +class MegatronGPTRerankerModel(MegatronGPTEmbeddingModel): + def __init__(self, cfg: DictConfig, trainer: Trainer): + self.reward_model_loss = cfg.get("reward_model_loss", False) + super().__init__(cfg, trainer=trainer) + + def model_provider_func(self, pre_process, post_process): + # (@adithyare) We need post_process to be False to get hidden states in the loss_func + return super().model_provider_func(pre_process, post_process=False) + + def maybe_setup_test(self): + if hasattr(self.cfg.data, 'test_ds') and self.cfg.data.test_ds.get('file_names', None) is not None: + self._test_dl = self.setup_eval_dataloader(self._test_ds, self.cfg.data.test_ds) + return + + def maybe_build_test(self): + if hasattr(self.cfg.data, 'test_ds') and self.cfg.data.test_ds.get('file_names', None) is not None: + logging.info('Building GPT Reranker test datasets.') + # Wrap this in a list since the general finetuning parent class supports multi-validation. + self._test_ds = self._build_dataset(self.cfg.data.test_ds, is_train=False) + + def _build_dataset(self, data_cfg, is_train=True): + packed_sequence = data_cfg.get("packed_sequence", False) + + # Determine if we are using a single dataset or a list of datasets. + if is_train: + # Construct the data prefix list for `get_datasets_weights_and_num_samples()` + # that is of the format [weight1,file_name1,weight2,file_name2,...] + if data_cfg.concat_sampling_probabilities is None or not isinstance( + data_cfg.concat_sampling_probabilities, ListConfig + ): + raise ValueError( + ( + f"concat_sampling_probabilities must be a ListConfig with the same number of files in file_names." + f"Found: {data_cfg.concat_sampling_probabilities}" + ) + ) + + if len(data_cfg.get('concat_sampling_probabilities', None)) != len(data_cfg.file_names): + raise ValueError( + ( + f"concat_sampling_probabilities must be of the same size as file_names.", + f"Provided size {len(data_cfg.concat_sampling_probabilities)}, number of datasets {len(data_cfg.file_names)}", + ) + ) + + data_prefix = [] + for weight, prefix in zip(data_cfg.concat_sampling_probabilities, data_cfg.file_names): + data_prefix.append(weight) + data_prefix.append(prefix) + + if self.trainer.max_steps is None or self.trainer.max_steps <= 0: + raise ValueError( + f'Trainer max_steps must be set to a positive integer. Found {self.trainer.max_steps}' + ) + num_train_samples = [self.trainer.max_steps * data_cfg.global_batch_size] + _, _, num_train_samples_per_dataset = get_datasets_weights_and_num_samples(data_prefix, num_train_samples) + num_train_samples_after_blend = sum([x[0] for x in num_train_samples_per_dataset]) + else: + num_train_samples_per_dataset = [[None]] * len(data_cfg.file_names) + + # Check dataset max_seq_legnth and max_position_embeddings size + if ( + self.cfg.get('position_embedding_type', None) in [None, 'learned_absolute'] + and data_cfg.max_seq_length > self.cfg.max_position_embeddings + ): + logging.warning( + f"Set dataset max_seq_length to max_position_embeddings {self.cfg.max_position_embeddings} if using learned_absolute position embedding" + ) + data_cfg.max_seq_length = self.cfg.max_position_embeddings + + # TE requires that the first input dim is divisible by 8 and the second by 16 for fp8 + # When using sequence parallel, sequence will further be split by TP size + pad_seq_length_to_mult = ( + 8 * self.cfg.get('tensor_model_parallel_size', 1) if self.cfg.get('sequence_parallel', False) else 16 + ) + pad_seq_length_to_mult *= self.cfg.get('context_parallel_size', 1) + + datasets = [] + for file_path, num_samples in zip(data_cfg.file_names, num_train_samples_per_dataset): + dataset = GPTRerankerDataset( + file_path=file_path, + tokenizer=self.tokenizer, + max_seq_length=data_cfg.max_seq_length, + min_seq_length=data_cfg.min_seq_length, + add_bos=data_cfg.get('add_bos', False), + add_eos=data_cfg.get('add_eos', True), + max_num_samples=num_samples[0], + seed=data_cfg.get('seed', 1234), + index_mapping_dir=data_cfg.get('index_mapping_dir', None), + virtual_tokens=self.virtual_tokens, + memmap_workers=data_cfg.get( + 'memmap_workers', None + ), # used to set num. of workers to create the memmap index files + truncation_method=data_cfg.get( + 'truncation_method', 'right' + ), # used to choose truncation method. Options: ['random', 'left', 'right'] + special_tokens=self.cfg.data.get( + 'chat_prompt_tokens', None + ), # special tokens for the chat prompts, a dictionary of {token_type: token}. Default: {'system_turn_start': '', 'turn_start': '', 'label_start': '', 'end_of_turn': '\n', "end_of_name": "\n"} + data_type="train" if is_train else "validation", + ) + datasets.append(dataset) + if is_train: + if packed_sequence: + num_train_samples_after_blend = sum(len(dataset) for dataset in datasets) + dataset = BlendableDataset( + datasets=datasets, weights=data_cfg.concat_sampling_probabilities, size=num_train_samples_after_blend + ) + return dataset + else: + return datasets + + def training_step_fwd_bwd_step_call(self, dataloader_iter, forward_only): + loss_mean, non_loss_tensors = self.fwd_bwd_step(dataloader_iter, forward_only) + logit_diff = non_loss_tensors['logit_diff'][0].item() + self.log("logit_diff", logit_diff, prog_bar=True, rank_zero_only=True, batch_size=1) + return loss_mean + + def inference_step_validation_call(self, batch, batch_idx, data_cfg, dataloader_idx=0): + metadata = batch.get('metadata', [{}] * len(batch['tokens'])) + loss, non_loss_tensors = self.local_validation_step(itertools.chain([dataloader_idx], [batch])) + outputs = { + 'loss': loss, + 'metadata': metadata, # [dict] + 'query_pos_doc_logit': non_loss_tensors['query_pos_doc_logit'], # [batch_size, hidden_size] + } + return outputs + + def inference_loss_func(self, loss_mask, num_valid_tokens_in_ub, eos_tensors): + query_pos_doc_hs = eos_tensors + _blank = torch.zeros(1, device=query_pos_doc_hs.device, dtype=query_pos_doc_hs.dtype)[0] + return { + "loss": _blank, + "query_pos_doc_logit": query_pos_doc_hs, + "query_neg_doc_logit": _blank, + "logit_diff": _blank, + } + + def loss_func(self, loss_mask, num_valid_tokens_in_ub, output_tensor): + idx = torch.arange(output_tensor.shape[1], device=output_tensor.device) + eos_tensors = output_tensor[loss_mask, idx, :] # (bs x 1) + if self.global_inbatch_negatives and self.trainer.training: + eos_tensors = _gather_global_inbatch_representations(eos_tensors) + if not self.trainer.training: + return self.inference_loss_func(loss_mask, num_valid_tokens_in_ub, eos_tensors) + bs = eos_tensors.shape[0] // 2 + query_pos_doc_hs = eos_tensors[::2, :] # every second tensor from idx 0 is a query w pos_doc (bs x 1) + query_neg_doc_hs = eos_tensors[1::2, :] # every second tensor from idx 1 is a query w negative doc (bs x 1) + + if self.reward_model_loss: + loss = -torch.nn.functional.logsigmoid(query_pos_doc_hs - query_neg_doc_hs).mean() + else: + cs = torch.cat([query_pos_doc_hs, query_neg_doc_hs], dim=1) # (bs x 2) + cs = cs / self.temperature + labels = torch.zeros(bs, device=cs.device).long() + loss = torch.nn.functional.cross_entropy(cs, labels) + + cp_size = self.cfg.get('context_parallel_size', 1) + if cp_size > 1: + torch.distributed.all_reduce(loss, group=parallel_state.get_context_parallel_group()) + query_pos_doc_hs = query_pos_doc_hs.clone().detach() + query_neg_doc_hs = query_neg_doc_hs.clone().detach() + logit_diffs = torch.mean(query_pos_doc_hs - query_neg_doc_hs) + return { + "loss": loss, + "query_pos_doc_logit": query_pos_doc_hs, + "query_neg_doc_logit": query_neg_doc_hs, + "logit_diff": logit_diffs, + } + + def gather_and_maybe_write_predictions(self, output, data_cfg, mode, averaged_metric, dataloader_idx=0): + if not data_cfg.get("write_embeddings_to_file", False): + return True + gathered_output_batches = [None for _ in range(parallel_state.get_data_parallel_world_size())] + torch.distributed.all_gather_object( + gathered_output_batches, + [ + { + 'query_pos_doc_logit': batch['query_pos_doc_logit'], + 'metadata': batch['metadata'], + } + for batch in output + ], + group=parallel_state.get_data_parallel_group(), + ) + + # Remove duplicate examples due to distributed sampler. + deduplicated_outputs = { + 'query_pos_doc_logit': [], + 'metadata': [], + } + total_size, skipped = 0, 0 + for rank in range(0, parallel_state.get_data_parallel_world_size()): + for batch in gathered_output_batches[rank]: + l_q_hs = listify(batch['query_pos_doc_logit']) + l_m = batch['metadata'] + assert len(l_m) == len(l_q_hs) + for q_hs, metadata in zip( + l_q_hs, + l_m, + ): + total_size += 1 + if not metadata.get("__AUTOGENERATED__", False): + deduplicated_outputs['query_pos_doc_logit'].append(q_hs) + deduplicated_outputs['metadata'].append(metadata) + else: + skipped += 1 + + logging.info( + f"{total_size-skipped} deduplicated outputs in dataloader:{dataloader_idx}, (skipped {skipped} autogenerated examples)." + ) + # Compute metric score + metric_name = self.val_metric_name if mode == 'validation' else self.test_metric_name + assert metric_name == "loss", "Only loss is supported for now." + # avg_pos_cs = torch.tensor(deduplicated_outputs['avg_pos_cs']).mean().item() + # avg_neg_cs = torch.tensor(deduplicated_outputs['avg_neg_cs']).mean().item() + # diff_cs = torch.tensor(deduplicated_outputs['diff_cs']).mean().item() + # self.log('val_avg_pos_cs', avg_pos_cs, prog_bar=True, rank_zero_only=True, batch_size=1) + # self.log('val_avg_neg_cs', avg_neg_cs, prog_bar=True, rank_zero_only=True, batch_size=1) + # self.log('val_diff_cs', diff_cs, prog_bar=True, rank_zero_only=True, batch_size=1) + + # Write predictions to file + if self.global_rank == 0 and data_cfg.get("write_embeddings_to_file", False): + logging.info( + f"Total deduplicated inference data size: {total_size} to {len(deduplicated_outputs['metadata'])}" + ) + + # Check if the user provided a prefix path to the file(s) they want to write. + if not hasattr(data_cfg, "output_file_path_prefix") or data_cfg.output_file_path_prefix is None: + raise ValueError( + f"Cannot write predictions to file when output_file_path_prefix is not set or present in the yaml config file." + ) + # (@adithyare) We are not using the log key to write the embeddings to file + filename_log_key = self._determine_log_key(data_cfg, dataloader_idx, None, mode) + consumed_samples = self._compute_consumed_samples_after_training_step() + fldr_path = f"{data_cfg.output_file_path_prefix}/consumed_samples{consumed_samples}/{filename_log_key}" + self.write_embeddings_to_file(deduplicated_outputs, fldr_path, dataloader_idx) + return deduplicated_outputs, total_size + + def write_embeddings_to_file(self, outputs, output_file_path, d_idx): + hs = torch.cat(outputs['query_pos_doc_logit'], dim=0) + hs_npy = hs.float().numpy() + emb_fldr = f"{output_file_path}" + os.makedirs(emb_fldr, exist_ok=True) + with open(f"{output_file_path}/logits.ids", "w") as f: + for m in outputs['metadata']: + f.write(f"{m['query_id'].strip()} {m['doc_id']}\n") + np.save(f"{emb_fldr}/logits.npy", hs_npy) + return True diff --git a/nemo/collections/nlp/models/language_modeling/megatron/gpt_layer_modelopt_spec.py b/nemo/collections/nlp/models/language_modeling/megatron/gpt_layer_modelopt_spec.py index f9ba58736cbd3..f001e8f58d259 100644 --- a/nemo/collections/nlp/models/language_modeling/megatron/gpt_layer_modelopt_spec.py +++ b/nemo/collections/nlp/models/language_modeling/megatron/gpt_layer_modelopt_spec.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from nemo.collections.nlp.modules.common.megatron.utils import ApexGuardDefaults + try: from megatron.core.fusions.fused_bias_dropout import get_bias_dropout_add from megatron.core.tensor_parallel.layers import ColumnParallelLinear, RowParallelLinear @@ -21,6 +23,7 @@ from megatron.core.transformer.enums import AttnMaskType from megatron.core.transformer.identity_op import IdentityOp from megatron.core.transformer.mlp import MLP, MLPSubmodules + from megatron.core.transformer.moe.moe_layer import MoELayer from megatron.core.transformer.spec_utils import ModuleSpec from megatron.core.transformer.transformer_layer import TransformerLayer, TransformerLayerSubmodules @@ -38,7 +41,7 @@ # Use this spec for Model Optimizer PTQ and TensorRT-LLM export -def get_gpt_layer_modelopt_spec() -> ModuleSpec: +def get_gpt_layer_modelopt_spec(num_experts: int = None) -> ModuleSpec: """Mix the native spec with TENorm. This is essentially the native local spec except for the layernorm implementation @@ -65,18 +68,38 @@ def get_gpt_layer_modelopt_spec() -> ModuleSpec: ), self_attn_bda=get_bias_dropout_add, pre_mlp_layernorm=TENorm, - mlp=ModuleSpec( - module=MLP, - submodules=MLPSubmodules( - linear_fc1=ColumnParallelLinear, - linear_fc2=RowParallelLinear, - ), - ), + mlp=_get_mlp_module_spec(num_experts=num_experts), mlp_bda=get_bias_dropout_add, # Map TE-layernorm-fusion keys back sharded_state_dict_keys_map={ 'input_layernorm.': 'self_attention.linear_qkv.layer_norm_', - 'pre_mlp_layernorm.': 'mlp.linear_fc1.layer_norm_', + **({'pre_mlp_layernorm.': 'mlp.linear_fc1.layer_norm_'} if num_experts is None else {}), }, ), ) + + +# Helper function to get module spec for MLP/MoE +def _get_mlp_module_spec(num_experts: int = None, moe_grouped_gemm: bool = False) -> ModuleSpec: + if num_experts is None: + # Dense MLP w/ or w/o TE modules. + return ModuleSpec( + module=MLP, + submodules=MLPSubmodules( + linear_fc1=ColumnParallelLinear, + linear_fc2=RowParallelLinear, + ), + ) + else: + # Mixture of experts with modules in megatron core. + return ModuleSpec( + module=MoELayer, + submodules=( + MLPSubmodules( + linear_fc1=ColumnParallelLinear, + linear_fc2=RowParallelLinear, + ) + if not moe_grouped_gemm + else None + ), + ) diff --git a/nemo/collections/nlp/models/language_modeling/megatron_base_model.py b/nemo/collections/nlp/models/language_modeling/megatron_base_model.py index 0828d88a81333..e1641a81c0dca 100644 --- a/nemo/collections/nlp/models/language_modeling/megatron_base_model.py +++ b/nemo/collections/nlp/models/language_modeling/megatron_base_model.py @@ -290,7 +290,11 @@ def _wrap_model_for_O2(self): Returns: The wrapped model. Returns a list of wrapped modules or a single wrapped module. """ - is_mcore_model = self.__dict__.get('mcore_gpt', False) or self.__dict__.get('mcore_bert', False) + is_mcore_model = ( + self.__dict__.get('mcore_gpt', False) + or self.__dict__.get('mcore_bert', False) + or self.__dict__.get('mcore_t5', False) + ) Float16Wrapper = MCoreFloat16Module if is_mcore_model else Float16Module @@ -305,15 +309,21 @@ def _wrap_model_for_O2(self): args = mcore_args if is_mcore_model else nemo_args # Model wrapper to convert both model and inputs to half precision - if isinstance(self.model, list): + if isinstance((self.enc_dec_model if hasattr(self, "enc_dec_model") else self.model), list): converted_model = [] - for module in self.model: + for module in self.enc_dec_model if hasattr(self, "enc_dec_model") else self.model: args['module'] = module converted_model.append(Float16Wrapper(**args)) - self.model = converted_model + if hasattr(self, "enc_dec_model"): + self.enc_dec_model = converted_model + else: + self.model = converted_model else: - args['module'] = self.model - self.model = Float16Wrapper(**args) + args['module'] = self.enc_dec_model if hasattr(self, "enc_dec_model") else self.model + if hasattr(self, "enc_dec_model"): + self.enc_dec_model = Float16Wrapper(**args) + else: + self.model = Float16Wrapper(**args) args.pop('module') def get_model_module_list(self): @@ -323,10 +333,10 @@ def extract_module(model): else: return model - if isinstance(self.model, list): - return list(map(extract_module, self.model)) + if isinstance((self.enc_dec_model if hasattr(self, "enc_dec_model") else self.model), list): + return list(map(extract_module, (self.enc_dec_model if hasattr(self, "enc_dec_model") else self.model))) else: - return [extract_module(self.model)] + return [extract_module(self.enc_dec_model if hasattr(self, "enc_dec_model") else self.model)] def _reconfigure_limit_batches(self, limit_batches, dataloader, mode): """ @@ -431,6 +441,7 @@ def _build_tokenizer(self): special_tokens=self.cfg.tokenizer.get('special_tokens', None), trust_remote_code=self.cfg.tokenizer.get('trust_remote_code', False), legacy=legacy, + chat_template=getattr(self._cfg.tokenizer, "chat_template", None), ) if self._cfg.tokenizer.get('additional_special_tokens', None) is not None: @@ -473,7 +484,7 @@ def build_transformer_config(self) -> TransformerConfig: activation = self.cfg.get('activation', 'gelu') gated_linear_unit = activation.endswith('glu') # TODO: need to check which activation functions are supported in mcore - activation_func = activation_to_func(activation) + activation_func = activation_to_func(activation, openai_gelu=self.cfg.get("openai_gelu", False)) normalization = self.cfg.get('normalization', 'LayerNorm') @@ -581,8 +592,7 @@ def _vocab_size_with_padding(self, orig_vocab_size, make_vocab_size_divisible_by after = orig_vocab_size multiple = make_vocab_size_divisible_by * tensor_model_parallel_size - while (after % multiple) != 0: - after += 1 + after = ((after + multiple - 1) // multiple) * multiple logging.info( f'Padded vocab_size: {after}, original vocab_size: {orig_vocab_size}, dummy tokens: {after - orig_vocab_size}.' ) @@ -846,7 +856,9 @@ def configure_optimizers(self): if hasattr(self._cfg.optim, 'sched'): sched_config = self._cfg.optim.sched self._scheduler = prepare_lr_scheduler( - optimizer=self._optimizer, scheduler_config=sched_config, train_dataloader=self._train_dl + optimizer=self._optimizer, + scheduler_config=sched_config, + train_dataloader=self._train_dl, ) if getattr(self._cfg.optim, 'sched', None) is not None and self._scheduler is None: @@ -1020,7 +1032,11 @@ def is_data_parallel_rank_zero(self): def _get_total_params_across_model_parallel_groups_gpt_bert(self): """Returns the total number of parameters across all model parallel groups.""" - is_mcore_model = self.__dict__.get('mcore_gpt', False) or self.__dict__.get('mcore_bert', False) + is_mcore_model = ( + self.__dict__.get('mcore_gpt', False) + or self.__dict__.get('mcore_bert', False) + or self.__dict__.get('mcore_t5', False) + ) # log number of parameters model = self.get_model_module_list() if isinstance(model, list): @@ -1255,6 +1271,8 @@ def find_frozen_submodules(model): # TODO: Currently the main parameter data type is kept in fp32 (when O2=False). This needs to be # extended to support lower precision main parameters. frozen_submodule_names, frozen_submodules = find_frozen_submodules(self.model) + for submodule in frozen_submodule_names: + logging.debug(f"Ignoring state {submodule} in FSDP.") self.trainer.strategy.kwargs['ignored_states'] = frozen_submodules # FSDP requires uniform status of require_grads # Diffusion models like SD has frozen parts and needs to be added to 'ignored_states' from sharding for FSDP to work diff --git a/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py b/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py index f603e853cb103..69cd06021f500 100644 --- a/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py +++ b/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py @@ -155,7 +155,7 @@ def get_specs(spec_name, num_experts=None, moe_grouped_gemm=False, use_te=True, "te_gpt": get_gpt_layer_with_transformer_engine_spec(num_experts, moe_grouped_gemm), "megatron_falcon_gpt": get_falcon_layer_spec(), "megatron_gpt_full_te_layer_autocast": get_gpt_full_te_layer_autocast_spec(), - "modelopt": get_gpt_layer_modelopt_spec(), + "modelopt": get_gpt_layer_modelopt_spec(num_experts), "te_gpt_hyena": get_gpt_layer_with_te_and_hyena_spec(hyena_cfg), } if spec_name not in name_spec_dict: @@ -300,6 +300,7 @@ def __init__(self, cfg: DictConfig, trainer: Trainer): self.spec_name = cfg.get('name', '') if cfg.get('fp8', False): self.prev_step_training = True + self.continue_training = True if cfg.get("restore_from_ckpt") else False self.rampup_batch_size = self.cfg.get('rampup_batch_size', None) if self.rampup_batch_size: @@ -390,13 +391,21 @@ def __init__(self, cfg: DictConfig, trainer: Trainer): self.log_memory_usage = bool(int(os.getenv("NEMO_LOG_MEMORY_USAGE", 0))) self.loss_broadcast_src_rank = None data_cfg = cfg.get('data', {}) - self.return_output_tensors = data_cfg.get('return_output_tensors', False) self.validation_drop_last = data_cfg.get('validation_drop_last', True) self.sample_weight = data_cfg.get('sample_weight', 'token') self.validation_param_sync_overlap = self.cfg.get('validation_param_sync_overlap', False) self.inference_params = None + # Reset learning rate params + self.if_init_step = True + self.reset_lr = self.cfg.get('reset_lr', False) + self.reset_lr_steps = self.cfg.get('reset_lr_steps', False) + if self.reset_lr and (not self.with_distributed_adam or not self.megatron_amp_O2): + raise ValueError( + 'Learning rate reset feature is only supported with the distributed optmizer and megatron_amp_O2 for now.' + ) + # default to false since this doesn't work with sequence parallelism currently self.use_loss_mask = self.cfg.get('use_loss_mask', False) @@ -763,6 +772,20 @@ def training_step(self, dataloader_iter): if self.initialize_ub: self.initialize_ub_func() + # Reset learning rate + if self.if_init_step and self.reset_lr: + num_groups = len(self._optimizer.param_groups) + for group in range(num_groups): + self._optimizer.param_groups[group]['lr'] = ( + 0.0 if self.cfg.optim.sched.warmup_steps > 0 else self.cfg.optim.lr + ) + self._optimizer.param_groups[0]['reset_lr'] = { + 'num_steps': self.trainer.global_step, + 'reset_lr_steps': True if self.reset_lr_steps else False, + 'if_init_step': self.if_init_step, + } + self.if_init_step = False + if self.rampup_batch_size: num_microbatch_calculator = apex.transformer.pipeline_parallel.utils._GLOBAL_NUM_MICROBATCHES_CALCULATOR current_global_batch_size = num_microbatch_calculator.current_global_batch_size @@ -1251,24 +1274,47 @@ def loss_func(output_tensor): # Loss for a micro-batch (ub) loss_for_ub = self.loss_func(batch['loss_mask'], batch['num_valid_tokens_in_ub'], output_tensor) cp_size = parallel_state.get_context_parallel_world_size() - if self.return_output_tensors: + if isinstance(loss_for_ub, dict): # TODO: need a better way to check if loss_func is returning more stuff than just loss... (@adithyare) - loss_for_ub, q_hs, d_hs, pos_cs, neg_cs, diff_cs = loss_for_ub - reduced_loss = average_losses_across_data_parallel_group([loss_for_ub]) - pos_cs = average_losses_across_data_parallel_group([pos_cs]) - neg_cs = average_losses_across_data_parallel_group([neg_cs]) - diff_cs = average_losses_across_data_parallel_group([diff_cs]) - return ( - loss_for_ub * cp_size, - { - 'avg': reduced_loss, - 'query_hs': q_hs, - 'doc_hs': d_hs, - 'avg_pos_cs': pos_cs, - 'avg_neg_cs': neg_cs, - 'diff_cs': diff_cs, - }, - ) + + if set(loss_for_ub.keys()) == set( + ["loss", "query_hs", "pos_doc_hs", "pos_cs", "neg_cs", "diff_cs"] + ): # (adithyare) this check will be True for GPT Embedding models + loss = loss_for_ub['loss'] + reduced_loss = average_losses_across_data_parallel_group([loss]) + pos_cs = average_losses_across_data_parallel_group([loss_for_ub['pos_cs']]) + neg_cs = average_losses_across_data_parallel_group([loss_for_ub['neg_cs']]) + diff_cs = average_losses_across_data_parallel_group([loss_for_ub['diff_cs']]) + return ( + loss * cp_size, + { + 'avg': reduced_loss, + 'query_hs': loss_for_ub['query_hs'], + 'doc_hs': loss_for_ub['pos_doc_hs'], + 'avg_pos_cs': pos_cs, + 'avg_neg_cs': neg_cs, + 'diff_cs': diff_cs, + }, + ) + elif set(loss_for_ub.keys()) == set( + ["loss", "query_pos_doc_logit", "query_neg_doc_logit", "logit_diff"] + ): # (adithyare) this check will be True for GPT Reranker models + + loss = loss_for_ub['loss'] + reduced_loss = average_losses_across_data_parallel_group([loss]) + logit_diff = average_losses_across_data_parallel_group([loss_for_ub['logit_diff']]) + return ( + loss * cp_size, + { + 'avg': reduced_loss, + 'query_pos_doc_logit': loss_for_ub['query_pos_doc_logit'], + 'query_neg_doc_logit': loss_for_ub['query_neg_doc_logit'], + 'logit_diff': logit_diff, + }, + ) + else: + raise RuntimeError(f"Dict loss_for_ub has unknown key set {loss_for_ub.keys()}") + elif validation_step and not self.validation_drop_last: num_valid_tokens_in_ub = batch['num_valid_tokens_in_ub'] if loss_for_ub.isnan(): @@ -1472,15 +1518,16 @@ def build_train_valid_test_datasets(self): # E = argmin_e e * N_d >= N, or equivalently E = ceildiv(N, N_d) # Where N_d is the total number of samples in a dataset (files), and N is the requested number of samples (provided for every split in the list below). # Setting N = 1 we force E to be 1 as well + legacy_dataset = self.cfg.data.get("legacy_dataset", False) if self.trainer.limit_val_batches <= 1.0 and isinstance(self.trainer.limit_val_batches, float): - train_valid_test_num_samples[1] = None + train_valid_test_num_samples[1] = 1 if legacy_dataset else None # Add extra FIM tokens to tokenizer if self.cfg.data.get('add_fim', False) and self.cfg.tokenizer.library == 'megatron': fim_tokens = self.cfg.data.fim.extra_tokens fim_tokens = [fim_tokens.prefix, fim_tokens.middle, fim_tokens.suffix, fim_tokens.pad, fim_tokens.eod] self.tokenizer.add_special_tokens({'additional_special_tokens': fim_tokens}) - if self.cfg.data.get("legacy_dataset", False): + if legacy_dataset: self._train_ds, self._validation_ds, self._test_ds = build_train_valid_test_datasets( cfg=self.cfg, trainer=self.trainer, @@ -1611,7 +1658,7 @@ def setup(self, stage=None): ) resume_checkpoint_path = self.trainer.ckpt_path - if resume_checkpoint_path: + if resume_checkpoint_path and not self.continue_training: init_consumed_samples = self._extract_consumed_samples_from_ckpt(resume_checkpoint_path) else: init_consumed_samples = 0 diff --git a/nemo/collections/nlp/models/language_modeling/megatron_lm_encoder_decoder_model.py b/nemo/collections/nlp/models/language_modeling/megatron_lm_encoder_decoder_model.py index 8fe215bcc9af6..6609b1aff3037 100644 --- a/nemo/collections/nlp/models/language_modeling/megatron_lm_encoder_decoder_model.py +++ b/nemo/collections/nlp/models/language_modeling/megatron_lm_encoder_decoder_model.py @@ -32,11 +32,13 @@ from nemo.collections.nlp.modules.common.megatron.build_model import build_model from nemo.collections.nlp.modules.common.megatron.module import Float16Module from nemo.collections.nlp.modules.common.megatron.token_level_encoder_decoder import ( + AttnMaskType, MegatronTokenLevelEncoderDecoderModule, ) from nemo.collections.nlp.modules.common.megatron.utils import ( ApexGuardDefaults, average_losses_across_data_parallel_group, + build_attention_mask_3d, get_params_for_weight_decay_optimization, ) from nemo.collections.nlp.modules.common.text_generation_utils import ( @@ -62,7 +64,16 @@ try: from megatron.core import parallel_state, tensor_parallel from megatron.core.enums import ModelType + from megatron.core.models.T5 import T5Model as MCoreT5Model + from megatron.core.models.T5.t5_spec import ( + get_t5_decoder_with_local_block_spec, + get_t5_decoder_with_transformer_engine_block_spec, + get_t5_encoder_with_local_block_spec, + get_t5_encoder_with_transformer_engine_block_spec, + ) from megatron.core.pipeline_parallel.schedules import get_forward_backward_func + from megatron.core.transformer.module import Float16Module as MCoreFloat16Module + from megatron.core.transformer.transformer_config import TransformerConfig HAVE_MEGATRON_CORE = True @@ -96,6 +107,13 @@ def __init__(self, cfg: DictConfig, trainer: Trainer): # Make sure trainer.accumulate_grad_batches is 1. self._validate_trainer() + self.mcore_t5 = cfg.get('mcore_t5', False) + + if self.mcore_t5: + self.transformer_config = self.build_transformer_config() + + self.megatron_amp_O2 = cfg.get('megatron_amp_O2', False) + # TODO: Currently does not support interleaved pipeline parallelism. # This means we can only use pipeline parallelism without the interleaved schedule. if isinstance(self.trainer.accelerator, CPUAccelerator): @@ -116,18 +134,18 @@ def __init__(self, cfg: DictConfig, trainer: Trainer): # We don't need to call it explicitly? Since it is a pytorch lightning hook function # self.setup_optimizer_param_groups() - self.megatron_amp_O2 = cfg.get('megatron_amp_O2', False) - if self.megatron_amp_O2: if not self.with_distributed_adam: # Pre-allocate the model on GPU to have master parameters allocated on the same device with matching data type - self.enc_dec_model.cuda(torch.cuda.current_device()) + if isinstance(self.enc_dec_model, list): + for module in self.enc_dec_model: + module.cuda(torch.cuda.current_device()) + else: + self.enc_dec_model.cuda(torch.cuda.current_device()) # Model wrapper to convert both model and inputs to half precision - self.enc_dec_model = Float16Module( - config=self.model_parallel_config, module=self.enc_dec_model, precision=self.cfg.precision - ) + self._wrap_model_for_O2() self.enable_autocast = ( True if (not self.megatron_amp_O2) and (self.autocast_dtype in [torch.float16, torch.bfloat16]) else False @@ -250,38 +268,74 @@ def model_provider_func(self, pre_process, post_process, add_encoder, add_decode if parallel_state.get_pipeline_model_parallel_world_size() > 1 and self.cfg.encoder.arch == 'perceiver': raise ValueError(f"Perceivers with pipeline parallel > 1 is not supported yet.") - if not hasattr(self.cfg, 'embedding_init_method_std'): - embedding_init_method_std = self.cfg.encoder.init_method_std - else: - embedding_init_method_std = self.cfg.embedding_init_method_std + if hasattr(self, 'mcore_t5') and self.mcore_t5: + assert HAVE_MEGATRON_CORE, "Cannot use MCore T5 since Megatron Core is not found" + assert self.cfg.get( + 'share_token_embeddings', True + ), "share_token_embeddings must be True if using MCore T5 model" + if self.cfg.get('transformer_engine', False): + enc_dec_spec_fns = ( + get_t5_encoder_with_transformer_engine_block_spec, + get_t5_decoder_with_transformer_engine_block_spec, + ) + else: + enc_dec_spec_fns = ( + get_t5_encoder_with_local_block_spec, + get_t5_decoder_with_local_block_spec, + ) + + en_block_spec = enc_dec_spec_fns[0](self.cfg.encoder.num_layers) + de_block_spec = enc_dec_spec_fns[1](self.cfg.decoder.num_layers) + model = MCoreT5Model( + config=self.transformer_config, + transformer_encoder_layer_spec=en_block_spec, + transformer_decoder_layer_spec=de_block_spec, + vocab_size=self.padded_vocab_size, + max_sequence_length=self.cfg.max_position_embeddings, + pre_process=pre_process, + post_process=post_process, + fp16_lm_cross_entropy=self.cfg.get('fp16_lm_cross_entropy', False), + parallel_output=True, + share_embeddings_and_output_weights=self.cfg.get('share_decoder_tokens_head_embeddings', True), + position_embedding_type=self.cfg.get('position_embedding_type', 'learned_absolute'), + rotary_percent=self.cfg.get('rotary_percentage', 1.0), + seq_len_interpolation_factor=self.cfg.get('seq_len_interpolation_factor', None), + ) - if not hasattr(self.cfg, 'embedding_dropout'): - embedding_dropout = self.cfg.encoder.hidden_dropout else: - embedding_dropout = self.cfg.embedding_dropout - - model = MegatronTokenLevelEncoderDecoderModule( - config=self.model_parallel_config, - encoder_cfg=self.cfg.encoder, - decoder_cfg=self.cfg.decoder, - vocab_size=self.padded_vocab_size, - max_position_embeddings=self.cfg.max_position_embeddings, - num_tokentypes=0, - parallel_output=True, - pre_process=pre_process, - post_process=post_process, - fp16_cross_entropy=self.cfg.get('fp16_lm_cross_entropy', False), - precision=self.cfg.get('precision', 16), - embedding_init_method_std=embedding_init_method_std, - embedding_dropout=embedding_dropout, - label_smoothing=self.cfg.get('label_smoothing', 0.0), - add_encoder=add_encoder, - add_decoder=add_decoder, - share_token_embeddings=self.cfg.get('share_token_embeddings', True), - share_decoder_tokens_head_embeddings=self.cfg.get('share_decoder_tokens_head_embeddings', True), - tokens_head_bias=self.cfg.get('tokens_head_bias', True), - hiddens_cfg=self.cfg.get('hiddens', None), - ) + if not hasattr(self.cfg, 'embedding_init_method_std'): + embedding_init_method_std = self.cfg.encoder.init_method_std + else: + embedding_init_method_std = self.cfg.embedding_init_method_std + + if not hasattr(self.cfg, 'embedding_dropout'): + embedding_dropout = self.cfg.encoder.hidden_dropout + else: + embedding_dropout = self.cfg.embedding_dropout + + model = MegatronTokenLevelEncoderDecoderModule( + config=self.model_parallel_config, + encoder_cfg=self.cfg.encoder, + decoder_cfg=self.cfg.decoder, + vocab_size=self.padded_vocab_size, + max_position_embeddings=self.cfg.max_position_embeddings, + num_tokentypes=0, + parallel_output=True, + pre_process=pre_process, + post_process=post_process, + fp16_cross_entropy=self.cfg.get('fp16_lm_cross_entropy', False), + precision=self.cfg.get('precision', 16), + embedding_init_method_std=embedding_init_method_std, + embedding_dropout=embedding_dropout, + label_smoothing=self.cfg.get('label_smoothing', 0.0), + add_encoder=add_encoder, + add_decoder=add_decoder, + share_token_embeddings=self.cfg.get('share_token_embeddings', True), + share_decoder_tokens_head_embeddings=self.cfg.get('share_decoder_tokens_head_embeddings', True), + tokens_head_bias=self.cfg.get('tokens_head_bias', True), + hiddens_cfg=self.cfg.get('hiddens', None), + ) + return model def forward( @@ -372,6 +426,25 @@ def training_step(self, dataloader_iter): # we zero grads here because we also call backward in the megatron fwd/bwd functions self._optimizer.zero_grad() + if self.with_distributed_adam: + # hack to enable overlapping param sync and forward compute + # note: the distributed optimizer monkey-patches each + # parameter's __getattribute__ function so that it can + # launch parameter all-gathers the first time the + # parameter is accessed after the optimizer step. However, + # PyTorch directly passes embedding parameters into a C++, + # bypassing this process. A quick-and-dirty hack is to + # manually interact with the parameter. + modules = self.enc_dec_model if isinstance(self.enc_dec_model, list) else [self.enc_dec_model] + for module in modules: + if isinstance(module, (Float16Module, MCoreFloat16Module)): + module = module.module + if not self.mcore_t5: + module = module.language_model + if hasattr(module, 'embedding'): + for param in module.embedding.parameters(): + param.data_ptr() + loss_dict = self.fwd_bwd_step(dataloader_iter, False) if self.with_distributed_adam: @@ -380,8 +453,12 @@ def training_step(self, dataloader_iter): # from multiple simultaneous NCCL calls self._optimizer._finish_bucket_grad_sync() elif self.megatron_amp_O2: - # when using pipeline parallelism grads must be reduced after the pipeline (not asynchronously) - if self.cfg.get('pipeline_model_parallel_size', 1) > 1: + # when using pipeline parallelism grads must be all-reduced after the pipeline (not asynchronously) + if ( + self.cfg.get('pipeline_model_parallel_size', 1) > 1 + or self.cfg.get('sequence_parallel', False) + or not self.cfg.get('async_grad_allreduce', True) + ): # main grads are stored in the MainParamsOptimizer wrapper self._optimizer.allreduce_main_grads() else: @@ -596,15 +673,37 @@ def fwd_output_and_loss_func(dataloader_iter, model): batch_data, ) = batch - output = model( - encoder_input_ids, # enc_input_ids - encoder_attn_mask, # enc_attn_mask - decoder_input_ids, # dec_input_ids - decoder_attn_mask, # dec_attn_mask - None, # token_type_ids - lm_labels, # labels - batch_data, # batch_data - ) + if self.mcore_t5: + # attn mask logic follows megatron.data.t5_dataset.py in Megatron-LM + encoder_attn_mask_3d = build_attention_mask_3d( + encoder_attn_mask, encoder_attn_mask, AttnMaskType.padding + ) + decoder_attn_mask_3d = build_attention_mask_3d( + decoder_attn_mask, decoder_attn_mask, AttnMaskType.causal + ) + enc_dec_attn_mask_3d = build_attention_mask_3d( + decoder_attn_mask, encoder_attn_mask, AttnMaskType.padding + ) + + output = model( # model is MCoreT5Model + encoder_input_ids, # encoder_input_ids + decoder_input_ids, # decoder_input_ids + encoder_attn_mask_3d, # encoder_attn_mask + decoder_attn_mask_3d, # decoder_attn_mask + enc_dec_attn_mask_3d, # encoder_decoder_attn_mask + lm_labels, # lm_labels + ) + + else: + output = model( + encoder_input_ids, # enc_input_ids + encoder_attn_mask, # enc_attn_mask + decoder_input_ids, # dec_input_ids + decoder_attn_mask, # dec_attn_mask + None, # token_type_ids + lm_labels, # labels + batch_data, # batch_data + ) def loss_func(output_tensor): if isinstance(output_tensor, dict): @@ -983,6 +1082,36 @@ def setup(self, stage=None): ) == 'relative' and not self.cfg.decoder.get('relative_position_bias_self_attention_only', True): self.enc_dec_model.sync_initial_decoder_cross_attention_relative_position_embeddings() + if self.cfg.get('transformer_engine', False) or self.cfg.get('mcore_t5', False): + self.setup_transformer_engine_tp_groups() + + def setup_transformer_engine_tp_groups(self): + """This should be called after model parallel groups have been initialized + and only needs to be called when using Transformer Engine. + """ + for module in self.get_t5_module_list(): + """Set TP group + Copied from: https://github.com/NVIDIA/TransformerEngine/blob/main/transformer_engine/pytorch/transformer.py#L398 + """ + # Deep iterate but skip self to avoid infinite recursion. + for index, child in enumerate(module.modules()): + if index == 0: + continue + if hasattr(child, "set_tensor_parallel_group"): + tp_group = parallel_state.get_tensor_model_parallel_group() + child.set_tensor_parallel_group(tp_group) + + def get_t5_module_list(self): + if isinstance(self.enc_dec_model, list): + return [ + model.module if isinstance(model, (Float16Module, MCoreFloat16Module)) else model + for model in self.enc_dec_model + ] + elif isinstance(self.enc_dec_model, (Float16Module, MCoreFloat16Module)): + return [self.enc_dec_model.module] + else: + return [self.enc_dec_model] + def setup_training_data(self, cfg): if hasattr(self, '_train_ds'): consumed_samples = self.compute_consumed_samples(0) @@ -1536,3 +1665,149 @@ def build_model_parallel_config(self): f'encoder.hidden_size not found in {self.cfg}. Set this in model_parallel_config if using pipeline parallelism.' ) return model_parallel_config + + def sharded_state_dict(self, prefix: str = '') -> Dict[str, Any]: + """ + Creates the sharded state dict which is used by dist_checkpoint to save the sharded tensors to disk. + When given the sharded_stated_dict, dist_checkpoint.load will load the tensors corresponding to + self.state_dict(). + The sharded tensor mapping is defined in the GPTModel class from mcore. + """ + if self.mcore_t5: + module_prefix = f'{prefix}model.' + sharded_state_dict = {} + for index, module in enumerate(self.get_model_module_list()): + if parallel_state.get_virtual_pipeline_model_parallel_world_size() is not None: + # virtual pipline rank must be set so that GPTModel returns the correct sharded state dict + parallel_state.set_virtual_pipeline_model_parallel_rank(index) + module_sharded_state_dict = module.sharded_state_dict(prefix=module_prefix) + sharded_state_dict[f'model_{index}'] = module_sharded_state_dict + else: + module_sharded_state_dict = module.sharded_state_dict(prefix=module_prefix) + sharded_state_dict.update(module_sharded_state_dict) + + # reset vp rank + if parallel_state.get_virtual_pipeline_model_parallel_world_size() is not None: + parallel_state.set_virtual_pipeline_model_parallel_rank(0) + + return sharded_state_dict + + def on_save_checkpoint(self, checkpoint) -> None: + """LightningModule hook: + https://pytorch-lightning.readthedocs.io/en/stable/common/lightning_module.html#on-save-checkpoint + """ + if self.mcore_t5: + checkpoint['sharded_state_dict'] = self.sharded_state_dict() + else: + if isinstance(self.enc_dec_model, list): + for i in range(len(self.enc_dec_model)): + parallel_state.set_virtual_pipeline_model_parallel_rank(i) + checkpoint[f'model{i}'] = self.enc_dec_model[i].module.state_dict_for_save_checkpoint() + parallel_state.set_virtual_pipeline_model_parallel_rank(0) + + def on_load_checkpoint(self, checkpoint) -> None: + """LightningModule hook: + https://pytorch-lightning.readthedocs.io/en/stable/common/lightning_module.html#on-load-checkpoint + """ + if self.mcore_t5: + if 'state_dict' in checkpoint and checkpoint['state_dict']: + for index, module in enumerate(self.get_model_module_list()): + if parallel_state.get_virtual_pipeline_model_parallel_world_size() is not None: + checkpoint_state_dict = checkpoint['state_dict'][f'model_{index}'] + else: + checkpoint_state_dict = checkpoint['state_dict'] + # checkpoint_state_dict has "model." but module does not so we need to remove it when loading + checkpoint_state_dict = { + key.replace('model.', ''): checkpoint_state_dict.pop(key) + for key in list(checkpoint_state_dict.keys()) + } + + # addressing the current T5 mcore version's implementation of sharded_state_dict + checkpoint_state_dict['lm_head.output_layer.bias'] = checkpoint_state_dict['output_layer.bias'] + + module.load_state_dict(checkpoint_state_dict, strict=True) + else: + checkpoint['state_dict'] = {} + else: + if isinstance(self.enc_dec_model, list): + for i in range(len(self.enc_dec_model)): + parallel_state.set_virtual_pipeline_model_parallel_rank(i) + self.enc_dec_model[i].module.load_state_dict(checkpoint[f'model{i}'], strict=True) + parallel_state.set_virtual_pipeline_model_parallel_rank(0) + + def build_transformer_config(self) -> TransformerConfig: + """Builds the megatron core gpt transformer config for the model. + For attributes in the nemo model config that are the same + as the megatron core TransformerConfig, we will use the value from the nemo model config. + For attributes in TransformerConfig that are not in the nemo model config, we add custom logic. + """ + + # for T5 model, transformers hyperparameters are stored in self.cfg.encoder/self.cfg.decoder + with open_dict(self.cfg): + for key in self.cfg.encoder: + print("{}: {}".format(key, self.cfg.encoder.get(key))) + OmegaConf.update(self.cfg, key, self.cfg.encoder.get(key)) + + normalization = self.cfg.get('normalization', 'layernorm') + + layernorm_zero_centered_gamma = self.cfg.get('normalization', 'layernorm') == 'layernorm1p' + if normalization == 'layernorm': + normalization = 'LayerNorm' + elif normalization == 'rmsnorm': + normalization = 'RMSNorm' + elif normalization == 'layernorm1p': + normalization = 'LayerNorm' + layernorm_zero_centered_gamma = True + else: + logging.warning( + f"The normalization type: {normalization} might not be supported in megatron core." + f"Supported types are LayerNorm and RMSNorm." + ) + + # any configs that are not in the nemo model config will be added here + model_specific_configs = { + 'layernorm_zero_centered_gamma': layernorm_zero_centered_gamma, + 'normalization': normalization, + } + + transformer_config = super().build_transformer_config() + + for key, value in model_specific_configs.items(): + setattr(transformer_config, key, value) + + # pass mcore customization configs directly to mcore + mcore_customization_config_dict = self.cfg.get('mcore_customization_config', {}) + for key, value in mcore_customization_config_dict.items(): + setattr(transformer_config, key, value) + + return transformer_config + + def setup_mcore_distributed_parallel(self): + """Set up mcore distributed data parallel""" + if self.with_distributed_adam and self.use_mcore_dist_optim: + config = get_model_config(self.enc_dec_model[0]) + ddp_config = DistributedDataParallelConfig( + grad_reduce_in_fp32=(self.cfg.optim.get('grad_sync_dtype', 'fp32') == 'fp32'), + overlap_grad_reduce=self.cfg.optim.get('overlap_grad_sync', False), + use_distributed_optimizer=True, + check_for_nan_in_grad=self.cfg.optim.get('check_for_nan_in_grad', False), + # mcore bucket_size is based on num of parameters, therefore not + # using bucket_cap_mb to configure bucket_size here + bucket_size=self.cfg.optim.get('ddp_bucket_size', None), + ) + self.enc_dec_model = [ + McoreDDP( + config, + ddp_config, + model_chunk, + data_parallel_group=parallel_state.get_data_parallel_group(with_context_parallel=True), + expert_data_parallel_group=parallel_state.get_data_modulo_expert_parallel_group(), + # Turn off bucketing for model_chunk 2 onwards, since communication for these + # model chunks is overlapped with compute anyway. + disable_bucketing=(model_chunk_idx > 0), + ) + for (model_chunk_idx, model_chunk) in enumerate(self.enc_dec_model) + ] + + # (TODO) Broadcast params from data parallel src rank to other data parallel ranks. + # by calling model_module.broadcast_params() if the model is randomly initialized. diff --git a/nemo/collections/nlp/models/language_modeling/megatron_mamba_model.py b/nemo/collections/nlp/models/language_modeling/megatron_mamba_model.py new file mode 100644 index 0000000000000..5180bd12b35e4 --- /dev/null +++ b/nemo/collections/nlp/models/language_modeling/megatron_mamba_model.py @@ -0,0 +1,89 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +from megatron.core.models.mamba import MambaModel +from megatron.core.models.mamba.mamba_layer_specs import mamba_stack_spec +from omegaconf.dictconfig import DictConfig +from pytorch_lightning.trainer.trainer import Trainer + +from nemo.collections.nlp.models.language_modeling.megatron_gpt_model import MegatronGPTModel +from nemo.utils import logging + + +class MegatronMambaModel(MegatronGPTModel): + """ + Megatron Mamba pretraining. + """ + + def __init__(self, cfg: DictConfig, trainer: Trainer): + + self.vocab_size = cfg.get('vocab_size', 65536) + self.cfg = cfg + super().__init__(cfg=cfg, trainer=trainer) + logging.warning("Overriding mcore_gpt=True") + self.mcore_gpt = True + + def model_provider_func(self, pre_process, post_process): + + self.hybrid_override_pattern = self.cfg.get( + 'hybrid_override_pattern', "M" * self.transformer_config.num_layers + ) + self.transformer_config.add_bias_linear = self.cfg.get('add_bias_linear', False) + self.transformer_config.gated_linear_unit = self.cfg.get('gated_linear_unit', False) + self.transformer_config.layernorm_epsilon = self.cfg.get('layernorm_epsilon', 1e-5) + + # TODO @ataghibakhsh: add mamba_ssm_ngroups=self.cfg.get('mamba_ssm_ngroups', 8) once MLM MR merged + + model = MambaModel( + config=self.transformer_config, + max_sequence_length=self.cfg.get('encoder_seq_length', 4096), + vocab_size=self.cfg.get('vocab_size', 65536), + mamba_ssm_ngroups=self.cfg.get('mamba_ssm_ngroups', 8), + mamba_stack_spec=mamba_stack_spec, + hybrid_override_pattern=self.hybrid_override_pattern, + ) + + return model + + def forward(self, input_ids, position_ids=None, attention_mask=None, labels=None): + + output_tensor = self.model( + input_ids=input_ids, position_ids=position_ids, attention_mask=attention_mask, labels=labels + ) + return output_tensor + + def build_transformer_config(self): + transformer_config = super().build_transformer_config() + return transformer_config + + def on_validation_epoch_end(self): + + averaged_loss = torch.tensor(0.0, dtype=torch.float32).cuda() + return averaged_loss + + def sharded_state_dict(self, prefix: str = ''): + return None + + def _reset_activation_checkpointing_args(self): + return + + def _restore_activation_checkpointing_args(self): + return + + def _reset_sequence_parallelism_args(self): + return + + def _restore_sequence_parallelism_args(self): + return diff --git a/nemo/collections/nlp/models/language_modeling/megatron_mamba_sft_model.py b/nemo/collections/nlp/models/language_modeling/megatron_mamba_sft_model.py new file mode 100644 index 0000000000000..ebcc470047115 --- /dev/null +++ b/nemo/collections/nlp/models/language_modeling/megatron_mamba_sft_model.py @@ -0,0 +1,47 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from omegaconf import DictConfig +from omegaconf.dictconfig import DictConfig +from pytorch_lightning.trainer.trainer import Trainer + +from nemo.collections.nlp.models.language_modeling.megatron_base_model import MegatronBaseModel +from nemo.collections.nlp.models.language_modeling.megatron_gpt_sft_model import MegatronGPTSFTModel +from nemo.collections.nlp.models.language_modeling.megatron_mamba_model import MegatronMambaModel + + +__all__ = ['MegatronMambaSFTModel'] + + +class MegatronMambaSFTModel(MegatronGPTSFTModel, MegatronMambaModel): + """ + Megatron Jamba Supervised Fine-Tuning + """ + + def __init__(self, cfg: DictConfig, trainer: Trainer): + + super().__init__(cfg, trainer=trainer) + self.mcore_gpt = True + self.validation_param_sync_overlap = self.cfg.get('validation_param_sync_overlap', False) + + def _reset_activation_checkpointing_args(self): + pass + + def on_validation_model_zero_grad(self) -> None: + """ + Skip gradient zeroing at the beginning of validation routine. + This is needed when overlapping the AllGather of the updated parameters with the following valdation step. + """ + if not self.validation_param_sync_overlap: + MegatronBaseModel.on_validation_model_zero_grad(self) diff --git a/nemo/collections/nlp/models/nlp_model.py b/nemo/collections/nlp/models/nlp_model.py index 2380ed15cc45c..b27c00c5d7c35 100644 --- a/nemo/collections/nlp/models/nlp_model.py +++ b/nemo/collections/nlp/models/nlp_model.py @@ -462,6 +462,7 @@ def restore_from( return_config: bool = False, save_restore_connector: SaveRestoreConnector = None, trainer: Optional[Trainer] = None, + validate_access_integrity: bool = True, ): if save_restore_connector is None: save_restore_connector = NLPSaveRestoreConnector() @@ -475,5 +476,12 @@ def restore_from( logging.info('use_cpu_initialization is True, loading checkpoint on CPU') map_location = 'cpu' return super().restore_from( - restore_path, override_config_path, map_location, strict, return_config, save_restore_connector, trainer + restore_path, + override_config_path, + map_location, + strict, + return_config, + save_restore_connector, + trainer, + validate_access_integrity, ) diff --git a/nemo/collections/nlp/modules/common/huggingface/huggingface_utils.py b/nemo/collections/nlp/modules/common/huggingface/huggingface_utils.py index cf692e07749d0..d8f6936f71261 100644 --- a/nemo/collections/nlp/modules/common/huggingface/huggingface_utils.py +++ b/nemo/collections/nlp/modules/common/huggingface/huggingface_utils.py @@ -16,12 +16,6 @@ from typing import List, Optional from transformers import ( - ALBERT_PRETRAINED_MODEL_ARCHIVE_LIST, - BERT_PRETRAINED_MODEL_ARCHIVE_LIST, - CAMEMBERT_PRETRAINED_MODEL_ARCHIVE_LIST, - DISTILBERT_PRETRAINED_MODEL_ARCHIVE_LIST, - GPT2_PRETRAINED_MODEL_ARCHIVE_LIST, - ROBERTA_PRETRAINED_MODEL_ARCHIVE_LIST, AlbertConfig, AutoModel, BertConfig, @@ -41,6 +35,74 @@ __all__ = ["get_huggingface_lm_model", "get_huggingface_pretrained_lm_models_list", "VOCAB_FILE_NAME"] +# Manually specify the model archive lists since these are now removed in HF +# https://github.com/huggingface/transformers/blob/v4.40-release/src/transformers/models/deprecated/_archive_maps.py +ALBERT_PRETRAINED_MODEL_ARCHIVE_LIST = [ + "albert/albert-base-v1", + "albert/albert-large-v1", + "albert/albert-xlarge-v1", + "albert/albert-xxlarge-v1", + "albert/albert-base-v2", + "albert/albert-large-v2", + "albert/albert-xlarge-v2", + "albert/albert-xxlarge-v2", +] + +BERT_PRETRAINED_MODEL_ARCHIVE_LIST = [ + "google-bert/bert-base-uncased", + "google-bert/bert-large-uncased", + "google-bert/bert-base-cased", + "google-bert/bert-large-cased", + "google-bert/bert-base-multilingual-uncased", + "google-bert/bert-base-multilingual-cased", + "google-bert/bert-base-chinese", + "google-bert/bert-base-german-cased", + "google-bert/bert-large-uncased-whole-word-masking", + "google-bert/bert-large-cased-whole-word-masking", + "google-bert/bert-large-uncased-whole-word-masking-finetuned-squad", + "google-bert/bert-large-cased-whole-word-masking-finetuned-squad", + "google-bert/bert-base-cased-finetuned-mrpc", + "google-bert/bert-base-german-dbmdz-cased", + "google-bert/bert-base-german-dbmdz-uncased", + "cl-tohoku/bert-base-japanese", + "cl-tohoku/bert-base-japanese-whole-word-masking", + "cl-tohoku/bert-base-japanese-char", + "cl-tohoku/bert-base-japanese-char-whole-word-masking", + "TurkuNLP/bert-base-finnish-cased-v1", + "TurkuNLP/bert-base-finnish-uncased-v1", + "wietsedv/bert-base-dutch-cased", +] +CAMEMBERT_PRETRAINED_MODEL_ARCHIVE_LIST = [ + "almanach/camembert-base", + "Musixmatch/umberto-commoncrawl-cased-v1", + "Musixmatch/umberto-wikipedia-uncased-v1", +] + +DISTILBERT_PRETRAINED_MODEL_ARCHIVE_LIST = [ + "distilbert-base-uncased", + "distilbert-base-uncased-distilled-squad", + "distilbert-base-cased", + "distilbert-base-cased-distilled-squad", + "distilbert-base-german-cased", + "distilbert-base-multilingual-cased", + "distilbert-base-uncased-finetuned-sst-2-english", +] +GPT2_PRETRAINED_MODEL_ARCHIVE_LIST = [ + "openai-community/gpt2", + "openai-community/gpt2-medium", + "openai-community/gpt2-large", + "openai-community/gpt2-xl", + "distilbert/distilgpt2", +] +ROBERTA_PRETRAINED_MODEL_ARCHIVE_LIST = [ + "FacebookAI/roberta-base", + "FacebookAI/roberta-large", + "FacebookAI/roberta-large-mnli", + "distilbert/distilroberta-base", + "openai-community/roberta-base-openai-detector", + "openai-community/roberta-large-openai-detector", +] + HUGGINGFACE_MODELS = { "BertModel": { @@ -94,7 +156,9 @@ def get_huggingface_lm_model( - pretrained_model_name: str, config_dict: Optional[dict] = None, config_file: Optional[str] = None, + pretrained_model_name: str, + config_dict: Optional[dict] = None, + config_file: Optional[str] = None, ): """ Returns lm model instantiated with Huggingface @@ -135,7 +199,9 @@ def get_huggingface_lm_model( raise ValueError(f"Use HuggingFace API directly in NeMo for {pretrained_model_name}") -def get_huggingface_pretrained_lm_models_list(include_external: bool = False,) -> List[str]: +def get_huggingface_pretrained_lm_models_list( + include_external: bool = False, +) -> List[str]: """ Returns the list of pretrained HuggingFace language models diff --git a/nemo/collections/nlp/modules/common/megatron/adapters/mcore_mixins.py b/nemo/collections/nlp/modules/common/megatron/adapters/mcore_mixins.py index bcfe07f702a0d..48b6afa788aee 100644 --- a/nemo/collections/nlp/modules/common/megatron/adapters/mcore_mixins.py +++ b/nemo/collections/nlp/modules/common/megatron/adapters/mcore_mixins.py @@ -14,18 +14,21 @@ import torch import torch.nn.functional as F +from megatron.core import InferenceParams from megatron.core.fusions.fused_bias_geglu import bias_geglu_impl from megatron.core.fusions.fused_bias_gelu import bias_gelu_impl from megatron.core.fusions.fused_bias_swiglu import bias_swiglu_impl from megatron.core.models.common.embeddings.language_model_embedding import LanguageModelEmbedding from megatron.core.models.common.embeddings.rotary_pos_embedding import apply_rotary_pos_emb -from megatron.core.tensor_parallel import ColumnParallelLinear +from megatron.core.packed_seq_params import PackedSeqParams from megatron.core.transformer.attention import SelfAttention from megatron.core.transformer.custom_layers.transformer_engine import SplitAlongDim from megatron.core.transformer.mlp import MLP from megatron.core.transformer.moe.experts import SequentialMLP +from megatron.core.transformer.transformer_block import TransformerBlock from megatron.core.transformer.transformer_layer import TransformerLayer from megatron.core.utils import make_viewless_tensor +from torch import Tensor from nemo.collections.nlp.modules.common.megatron.adapters.parallel_adapters import ( AdapterName, @@ -38,6 +41,7 @@ LoraMoeHto4HAdapterConfig, LoraUnfusedHto4HAdapterConfig, LoraUnfusedKQVAdapterConfig, + MLPHeadAdapterConfig, MLPInfusedAdapterConfig, ParallelLinearAdapterConfig, PromptEncoderAdapterConfig, @@ -62,6 +66,34 @@ def mcore_register_adapters(self): raise NotImplementedError("Mcore mixins should implement setup_adapters on a subclass of MyBase") +class MCoreTransformerBlockMixin(TransformerBlock, MCoreAdapterModuleMixin): + def mcore_register_adapters(self): + """ + Setup NeMo (canonical) Adapter to this MCore layer. + """ + self.set_accepted_adapter_types([MLPHeadAdapterConfig._target_]) + + def forward( + self, + hidden_states: Tensor, + attention_mask: Tensor, + context: Tensor = None, + context_mask: Tensor = None, + rotary_pos_emb: Tensor = None, + inference_params: InferenceParams = None, + packed_seq_params: PackedSeqParams = None, + ): + hidden_states = super().forward( + hidden_states, attention_mask, context, context_mask, rotary_pos_emb, inference_params, packed_seq_params + ) + + mlp_head_adapter = self.get_adapter_module(AdapterName.MLP_HEAD_ADAPTER) + if mlp_head_adapter and self.adapter_cfg[AdapterName.MLP_HEAD_ADAPTER]['enabled']: + hidden_states = mlp_head_adapter(hidden_states) + + return hidden_states + + class MCoreSelfAttentionMixin(SelfAttention, MCoreAdapterModuleMixin): def mcore_register_adapters(self): """ @@ -305,14 +337,16 @@ def mcore_register_adapters(self): def forward(self, hidden_states, expert_idx=None): # [s, b, 4 * h/p] - if isinstance(self.linear_fc1, ColumnParallelLinear): - layernorm_output = hidden_states - intermediate_parallel, bias_parallel = self.linear_fc1(hidden_states) - elif self.linear_fc1.te_return_bias: - intermediate_parallel, bias_parallel, layernorm_output = self.linear_fc1(hidden_states) + output = self.linear_fc1(hidden_states) + if isinstance(output, tuple) and len(output) == 2: + intermediate_parallel, bias_parallel = output + if isinstance(intermediate_parallel, tuple) and len(intermediate_parallel) == 2: + intermediate_parallel, layernorm_output = intermediate_parallel + else: + layernorm_output = hidden_states else: - # bias_parallel is None - (intermediate_parallel, layernorm_output), bias_parallel = self.linear_fc1(hidden_states) + # self.linear_fc1.te_return_bias == True + intermediate_parallel, bias_parallel, layernorm_output = output # LoRA logic if self.is_adapter_available(): diff --git a/nemo/collections/nlp/modules/common/megatron/adapters/parallel_adapters.py b/nemo/collections/nlp/modules/common/megatron/adapters/parallel_adapters.py index 21dace0088776..8d2d77c55cf23 100644 --- a/nemo/collections/nlp/modules/common/megatron/adapters/parallel_adapters.py +++ b/nemo/collections/nlp/modules/common/megatron/adapters/parallel_adapters.py @@ -24,6 +24,7 @@ import torch.nn as nn import torch.nn.init as init +from megatron.core.dist_checkpointing.mapping import ShardedStateDict from nemo.collections.common.parts.adapter_modules import AdapterModuleUtil from nemo.collections.common.parts.utils import activation_registry from nemo.collections.nlp.modules.common.megatron.fused_bias_gelu import fused_bias_gelu @@ -76,6 +77,7 @@ class AdapterName(str, enum.Enum): PTUNING_ADAPTER = "ptuning_adapter" LORA_KQV_ADAPTER = "lora_kqv_adapter" LORA_UNFUSED_KQV_ADAPTER = "lora_unfused_kqv_adapter" + MLP_HEAD_ADAPTER = "mlp_head_adapter" LORA_KV_ADAPTER = "lora_kv_adapter" LORA_Q_ADAPTER = "lora_q_adapter" MM_LINEAR_ADAPTER = "mm_linear_adapter" @@ -322,6 +324,16 @@ def forward(self, x): return x + def sharded_state_dict( + self, prefix: str = '', sharded_offsets: tuple = (), metadata: Optional[dict] = None + ) -> ShardedStateDict: + sharded_state_dict = {} + sharded_state_dict.update(self.linear_in.sharded_state_dict(f"{prefix}linear_in.", sharded_offsets, metadata)) + sharded_state_dict.update( + self.linear_out.sharded_state_dict(f"{prefix}linear_out.", sharded_offsets, metadata) + ) + return sharded_state_dict + class _All2AllHp2Sp(torch.autograd.Function): """ @@ -377,6 +389,57 @@ class ParallelLinearAdapterConfig(AdapterConfig): _target_: str = "{0}.{1}".format(ParallelLinearAdapter.__module__, ParallelLinearAdapter.__name__) +class MLPHeadAdapter(nn.Module, AdapterModuleUtil): + def __init__( + self, + in_features: int, + out_features: int, + input_is_parallel: bool = False, + model_parallel_config: Optional[ModelParallelConfig] = None, + **kwargs, + ): + super().__init__() + if model_parallel_config is None: + model_parallel_config = ModelParallelConfig() + self._sequence_parallel = model_parallel_config.sequence_parallel + model_parallel_config.sequence_parallel = False # SP is irrelevant for the lora linear layer + + if input_is_parallel: + self.linear = RowParallelLinear( + in_features, + out_features, + config=model_parallel_config, + input_is_parallel=True, + skip_bias_add=True, + bias=False, + init_method=init.xavier_normal_, + ) + else: + self.linear = ColumnParallelLinear( + in_features, + out_features, + config=model_parallel_config, + bias=False, + gather_output=True, + init_method=init.xavier_normal_, + disable_grad_reduce=self._sequence_parallel, + ) + + # Setup adapter strategy + self.setup_adapter_strategy(adapter_mixin_strategies.ReturnResultAdapterStrategy()) + + def forward(self, x): + x, _ = self.linear(x) + return x + + +@dataclass +class MLPHeadAdapterConfig(AdapterConfig): + in_features: int + out_features: int + _target_: str = "{0}.{1}".format(MLPHeadAdapter.__module__, MLPHeadAdapter.__name__) + + class LoraKQVAdapter(ParallelLinearAdapter): """ Lora Adapters are the same arch as regular adapters but with potentially different input and output feature sizes @@ -766,14 +829,21 @@ def set_inference_table(self, prompt_representation: torch.Tensor): self.is_inference_ready = True return True - def clear_inference_table(self): + def clear_inference_table( + self, + ): self.inference_table.fill_(0.0) self.is_inference_ready = False - def get_inference_table(self): + def get_inference_table( + self, + ): return self.inference_table.data - def inner_forward(self): + def inner_forward( + self, + ): + input_embeds = self.embedding(self.indices).unsqueeze(0) intermediate_parallel, bias_parallel = self.first(input_embeds) intermediate_parallel = fused_bias_gelu(intermediate_parallel, bias_parallel) diff --git a/nemo/collections/nlp/modules/common/megatron/adapters/qlora.py b/nemo/collections/nlp/modules/common/megatron/adapters/qlora.py index e29744ce4d4da..a834b9a3fb49a 100644 --- a/nemo/collections/nlp/modules/common/megatron/adapters/qlora.py +++ b/nemo/collections/nlp/modules/common/megatron/adapters/qlora.py @@ -103,6 +103,10 @@ def backward(ctx, grad_output): return grad_output @ weight.dequantize().to(grad_output.device), None +def nf4_quantize(x: torch.Tensor): + return NF4Weight(x).cuda() + + class NF4LinearWrapper(nn.Module): """ NF4 Linear Layer for QLoRA as introduced in `QLORA: Efficient Finetuning of Quantized LLMs `_. @@ -117,7 +121,7 @@ def __init__(self, bf16_linear_weight: torch.Tensor): super().__init__() # quantize the weight upon initialization - self.weight = NF4Weight(bf16_linear_weight).cuda() + self.weight = nf4_quantize(bf16_linear_weight) def forward(self, x: torch.Tensor): """ @@ -224,12 +228,12 @@ def qlora_load_model(model: 'MCoreGPTModel', model_cfg: 'DictConfig', checkpoint def replace_linear(module: nn.Module, prefix=""): for name, child in module.named_children(): if name in qlora_targets: - bf16_weight = checkpoint[f"{prefix}.{name}.weight"] + bf16_weight = checkpoint[f"{prefix}.{name}.weight"].to(torch.bfloat16) logging.info(f'QLoRA: Quantizing linear layer: {prefix}.{name}') - if name in ['linear_proj', 'linear_fc2']: + layer_norm_weight = checkpoint.get(f"{prefix}.{name}.layer_norm_weight", None) + if layer_norm_weight is None: setattr(module, name, NF4LinearWrapper(bf16_weight)) - else: # name in ['linear_qkv', 'linear_fc1'] - layer_norm_weight = checkpoint[f"{prefix}.{name}.layer_norm_weight"] + else: layer_norm_bias = checkpoint.get(f"{prefix}.{name}.layer_norm_bias", None) normalization = module.config.normalization zero_centered_gamma = module.config.layernorm_zero_centered_gamma diff --git a/nemo/collections/nlp/modules/common/text_generation_strategy.py b/nemo/collections/nlp/modules/common/text_generation_strategy.py index e8e2859e439fd..8f8fe313a5e3d 100644 --- a/nemo/collections/nlp/modules/common/text_generation_strategy.py +++ b/nemo/collections/nlp/modules/common/text_generation_strategy.py @@ -21,6 +21,8 @@ import torch from transformers import CLIPImageProcessor + +from nemo.collections.common.tokenizers.chat_template_mixin import explode_chat_template_input, is_chat_input from nemo.collections.nlp.modules.common.lm_utils import pad_batch from nemo.collections.nlp.modules.common.megatron.module import Float16Module from nemo.collections.nlp.modules.common.megatron.utils import get_ltor_masks_and_position_ids @@ -94,7 +96,12 @@ def tokenize_batch(self, sentences, max_len, add_BOS): Tuple[torch.Tensor], the tokenized and padded torch tensor and the token context length tensor. """ tokenizer = self.model.tokenizer - if add_BOS: + if is_chat_input(sentences): + assert getattr( + tokenizer, 'has_chat_template', False + ), "Got chat-template input but tokenizer does not support chat template formating." + context_tokens = list(map(tokenizer.text_to_ids, explode_chat_template_input(sentences))) + elif add_BOS: context_tokens = [[tokenizer.bos_id] + tokenizer.text_to_ids(s) for s in sentences] elif hasattr(tokenizer.tokenizer, "get_prefix_tokens"): # chatglm: add tokenizer.gmask_id, tokenizer.sop_id @@ -501,6 +508,27 @@ def neva_process_prompts(prompt, tokenizer, multimodal_cfg, num_media_latents, c copy.deepcopy(list_data_dict), multimodal_cfg, num_media_latents ) # HARDCODED FOR NOW data_dict = preprocess_llama_3(sources, tokenizer, multimodal_cfg) + elif multimodal_cfg["conv_template"] == "mistral": + record = { + 'conversations': [ + { + 'from': 'human', + 'value': prompt, + }, + { + 'from': 'gpt', + 'value': '', + }, + ], + } + for turn in record['conversations']: + if turn.get('value') is not None: + turn['value'] = re.sub('', f'{DEFAULT_IMAGE_TOKEN}\n', turn['value']) + list_data_dict.append(record) + sources = preprocess_multimodal( + copy.deepcopy(list_data_dict), multimodal_cfg, num_media_latents + ) # HARDCODED FOR NOW + data_dict = preprocess_llama_2(sources, tokenizer, multimodal_cfg, is_mistral=True) elif multimodal_cfg["conv_template"] == "v1": record = { 'conversations': [ @@ -981,6 +1009,7 @@ def model_inference_strategy_dispatcher(model, **args): MegatronGPTPromptLearningModel, ) from nemo.collections.nlp.models.language_modeling.megatron_griffin_model import MegatronGriffinModel + from nemo.collections.nlp.models.language_modeling.megatron_mamba_model import MegatronMambaModel from nemo.collections.nlp.models.language_modeling.megatron_retrieval_model import MegatronRetrievalModel from nemo.collections.nlp.models.language_modeling.megatron_retro_model import MegatronRetroModel from nemo.collections.nlp.modules.common.retro_inference_strategies import ( @@ -991,6 +1020,8 @@ def model_inference_strategy_dispatcher(model, **args): if isinstance(model, MegatronGriffinModel): return GriffinModelTextGenerationStrategy(model) + if isinstance(model, MegatronMambaModel): + return GPTModelTextGenerationStrategy(model) if isinstance(model, MegatronNevaModel): return NevaModelTextGenerationStrategy(model) if isinstance(model, MegatronGPTPromptLearningModel): diff --git a/nemo/collections/nlp/modules/common/text_generation_utils.py b/nemo/collections/nlp/modules/common/text_generation_utils.py index 498d9e9a09dad..cd02f54096793 100644 --- a/nemo/collections/nlp/modules/common/text_generation_utils.py +++ b/nemo/collections/nlp/modules/common/text_generation_utils.py @@ -122,31 +122,26 @@ def megatron_gpt_generate(model, inputs, tokenizer, length_params, sampling_para compute_prob_response = get_computeprob_response(tokenizer, response, inputs) return compute_prob_response - if isinstance(inputs, (list, tuple)): - if isinstance(inputs[0], (str, torch.Tensor)): - output = generate( - model, - inputs=inputs, - tokens_to_generate=length_params['max_length'], - all_probs=sampling_params['all_probs'], - compute_logprob=sampling_params['compute_logprob'], - temperature=sampling_params['temperature'], - add_BOS=sampling_params['add_BOS'], - top_k=sampling_params['top_k'], - top_p=sampling_params['top_p'], - greedy=sampling_params['use_greedy'], - repetition_penalty=sampling_params['repetition_penalty'], - end_strings=sampling_params['end_strings'], - min_tokens_to_generate=length_params['min_length'], - **strategy_args, - ) - return output - elif isinstance(inputs[0], dict): - raise NotImplementedError("json object not implemented") - else: - raise NotImplementedError("unknown type is not implemented") - else: - raise NotImplementedError("unknown type is not implemented") + if not isinstance(inputs, (list, tuple)): + raise NotImplementedError(f"unknown type {type(inputs)} is not implemented") + + output = generate( + model, + inputs=inputs, + tokens_to_generate=length_params['max_length'], + all_probs=sampling_params['all_probs'], + compute_logprob=sampling_params['compute_logprob'], + temperature=sampling_params['temperature'], + add_BOS=sampling_params['add_BOS'], + top_k=sampling_params['top_k'], + top_p=sampling_params['top_p'], + greedy=sampling_params['use_greedy'], + repetition_penalty=sampling_params['repetition_penalty'], + end_strings=sampling_params['end_strings'], + min_tokens_to_generate=length_params['min_length'], + **strategy_args, + ) + return output def megatron_neva_generate(model, prompt_dict_list, length_params, sampling_params, inference_config, **strategy_args): diff --git a/nemo/collections/nlp/modules/common/tokenizer_utils.py b/nemo/collections/nlp/modules/common/tokenizer_utils.py index 67c94ae5d608f..d3ee69f75b25c 100644 --- a/nemo/collections/nlp/modules/common/tokenizer_utils.py +++ b/nemo/collections/nlp/modules/common/tokenizer_utils.py @@ -78,6 +78,7 @@ def get_tokenizer( special_tokens: Optional[Dict[str, str]] = None, use_fast: Optional[bool] = False, bpe_dropout: Optional[float] = 0.0, + chat_template: Optional[Dict] = None, ): """ Args: @@ -91,7 +92,7 @@ def get_tokenizer( use_fast: (only for HuggingFace AutoTokenizer) set to True to use fast HuggingFace tokenizer bpe_dropout: (experimental) BPE dropout tries to corrupt the standard segmentation procedure of BPE to help - model better learn word compositionality and become robust to segmentation errors. + model better learn word compositionality and become robust to segmentation errors. It has emperically been shown to improve inference time BLEU scores. """ if special_tokens is None: @@ -116,7 +117,10 @@ def get_tokenizer( if tokenizer_name == 'sentencepiece': logging.info("tokenizer_model: " + str(tokenizer_model)) return nemo.collections.common.tokenizers.sentencepiece_tokenizer.SentencePieceTokenizer( - model_path=tokenizer_model, special_tokens=special_tokens, legacy=True + model_path=tokenizer_model, + special_tokens=special_tokens, + legacy=True, + chat_template=chat_template, ) elif tokenizer_name == 'word': return WordTokenizer(vocab_file=vocab_file, **special_tokens_dict) @@ -151,6 +155,7 @@ def get_nmt_tokenizer( legacy: Optional[bool] = False, delimiter: Optional[str] = None, trust_remote_code: Optional[bool] = False, + chat_template: Optional[Dict] = None, ): """ Args: @@ -187,7 +192,9 @@ def get_nmt_tokenizer( elif library == 'sentencepiece': logging.info(f'Getting SentencePiece with model: {tokenizer_model}') return nemo.collections.common.tokenizers.sentencepiece_tokenizer.SentencePieceTokenizer( - model_path=tokenizer_model, legacy=legacy + model_path=tokenizer_model, + legacy=legacy, + chat_template=chat_template, ) elif library == 'byte-level': logging.info(f'Using byte-level tokenization') @@ -209,7 +216,9 @@ def get_nmt_tokenizer( logging.info( f'Getting Megatron tokenizer for pretrained model name: {model_name}, custom vocab file: {vocab_file}, and merges file: {merges_file}' ) - return get_tokenizer(tokenizer_name=model_name, vocab_file=vocab_file, merges_file=merges_file) + return get_tokenizer( + tokenizer_name=model_name, vocab_file=vocab_file, merges_file=merges_file, chat_template=chat_template + ) elif library == 'tabular': return TabularTokenizer(vocab_file, delimiter=delimiter) else: diff --git a/nemo/collections/nlp/modules/common/transformer/transformer_generators.py b/nemo/collections/nlp/modules/common/transformer/transformer_generators.py index 6e17151dcd1b1..9bac89f611350 100644 --- a/nemo/collections/nlp/modules/common/transformer/transformer_generators.py +++ b/nemo/collections/nlp/modules/common/transformer/transformer_generators.py @@ -179,8 +179,7 @@ def __call__( ) def freeze(self) -> None: - """Freeze weights of embedding, decoder, and classification layers to prevent memory leak. - """ + """Freeze weights of embedding, decoder, and classification layers to prevent memory leak.""" for param in self.embedding.parameters(): param.requires_grad = False self.embedding.eval() @@ -192,8 +191,7 @@ def freeze(self) -> None: self.log_softmax.eval() def unfreeze(self) -> None: - """Unfreeze weights of embedding, decoder, and classification layers. - """ + """Unfreeze weights of embedding, decoder, and classification layers.""" for param in self.embedding.parameters(): param.requires_grad = True self.embedding.train() @@ -347,13 +345,13 @@ def _forward( # choose top-k hypotheses with length penalty applied len_penalties = self.compute_len_penalty(prefixes_len, self.len_pen) scores = scores / len_penalties - scores, indices_i = torch.topk(scores.view(-1, self.beam_size ** 2), self.beam_size, dim=1) + scores, indices_i = torch.topk(scores.view(-1, self.beam_size**2), self.beam_size, dim=1) scores = scores.view(-1, 1) * len_penalties # select prefixes which correspond to the chosen hypotheses prefixes = prefixes.unsqueeze(1).repeat(1, self.beam_size, 1) prefixes = torch.cat((prefixes, prefixes_i.unsqueeze(2)), dim=2) - prefixes = prefixes.view(batch_size, self.beam_size ** 2, -1) + prefixes = prefixes.view(batch_size, self.beam_size**2, -1) p_len = prefixes.size(2) prefixes_ids = indices_i.unsqueeze(2).repeat(1, 1, p_len) prefixes = prefixes.gather(1, prefixes_ids).view(-1, p_len) @@ -453,7 +451,10 @@ def _one_step_forward_lm(self, decoder_input_ids=None, lm_mems_list=None, pos=0) input_mask = mask_padded_tokens(decoder_input_ids, self.pad).float() lm_hidden_states = self.language_model.encoder.embedding.forward(decoder_input_ids, start_pos=pos) lm_mems_list = self.language_model.encoder.encoder.forward( - lm_hidden_states, input_mask, lm_mems_list, return_mems=True, + lm_hidden_states, + input_mask, + lm_mems_list, + return_mems=True, ) lm_log_probs = self.language_model.log_softmax.forward(hidden_states=lm_mems_list[-1][:, -1:]) return lm_log_probs, lm_mems_list @@ -629,13 +630,13 @@ def _forward(self, src_ids, encoder_input_mask, decoder_input_ids=None, return_b # choose top-k hypotheses with length penalty applied len_penalties = self.compute_len_penalty(prefixes_len, self.len_pen) scores = scores / len_penalties - scores, indices_i = torch.topk(scores.view(-1, self.beam_size ** 2), self.beam_size, dim=1) + scores, indices_i = torch.topk(scores.view(-1, self.beam_size**2), self.beam_size, dim=1) scores = scores.view(-1, 1) * len_penalties # select prefixes which correspond to the chosen hypotheses prefixes = prefixes.unsqueeze(1).repeat(1, self.beam_size, 1) prefixes = torch.cat((prefixes, prefixes_i.unsqueeze(2)), dim=2) - prefixes = prefixes.view(batch_size, self.beam_size ** 2, -1) + prefixes = prefixes.view(batch_size, self.beam_size**2, -1) p_len = prefixes.size(2) prefixes_ids = indices_i.unsqueeze(2).repeat(1, 1, p_len) prefixes = prefixes.gather(1, prefixes_ids).view(-1, p_len) @@ -691,8 +692,7 @@ def __call__(self, src_ids, encoder_input_mask, decoder_input_ids=None, return_b return self._forward(src_ids, encoder_input_mask, decoder_input_ids, return_beam_scores) def freeze(self) -> None: - """Freeze weights of embedding, decoder, and classification layers to prevent memory leak. - """ + """Freeze weights of embedding, decoder, and classification layers to prevent memory leak.""" for model_num in range(self.num_models): for param in self.embeddings[model_num].parameters(): param.requires_grad = False @@ -708,8 +708,7 @@ def freeze(self) -> None: self.encoders[model_num].eval() def unfreeze(self) -> None: - """Unfreeze weights of embedding, decoder, and classification layers. - """ + """Unfreeze weights of embedding, decoder, and classification layers.""" for model_num in range(self.num_models): for param in self.embeddings[model_num].parameters(): param.requires_grad = True @@ -730,6 +729,40 @@ def as_frozen(self): Context manager which temporarily freezes embedding, decoder, and log_softmax modules, yields control and finally unfreezes the modules. """ + grad_module_list = {'embeddings': {}, 'decoders': {}, 'log_softmaxes': {}, 'encoders': {}} + training_mode_module_list = {'embeddings': {}, 'decoders': {}, 'log_softmaxes': {}, 'encoders': {}} + + def gather_grad_values(module_name): + map_values = [{} for _ in range(self.num_models)] + for model_num in range(self.num_models): + for name, param in getattr(self, module_name)[model_num].named_parameters(): + map_values[model_num][name].append(param.requires_grad) + return map_values + + def reset_grad_values(module_name, map_values, require_grad_default: bool): + for model_num in range(self.num_models): + for name, param in getattr(self, module_name)[model_num].named_parameters(): + if name in map_values[model_num]: + param.requires_grad = map_values[model_num].pop() + else: + param.requires_grad = require_grad_default + + def gather_reset_training_mode_values(module_name, map_values: dict = None): + map_values = [{} for _ in range(self.num_models)] if not map_values else map_values + get_values = len(map_values) == 0 + + for model_num in range(self.num_models): + if get_values: + map_values[model_num] = getattr(self, module_name)[model_num].training + else: + getattr(self, module_name)[model_num].train(map_values[model_num]) + return map_values + + # Cache the param.require_grad state of each module + for module_name in grad_module_list.keys(): + grad_module_list[module_name] = gather_grad_values(module_name) + training_mode_module_list[module_name] = gather_reset_training_mode_values(module_name) + self.freeze() try: @@ -737,6 +770,11 @@ def as_frozen(self): finally: self.unfreeze() + # Reset the param.require_grad state of each module + for module_name in grad_module_list.keys(): + reset_grad_values(module_name, grad_module_list[module_name], require_grad_default=True) + gather_reset_training_mode_values(module_name, map_values=training_mode_module_list[module_name]) + class BeamSearchSequenceGeneratorWithLanguageModel(GreedySequenceGenerator): def __init__( @@ -771,13 +809,20 @@ def _one_step_forward( ): nmt_log_probs, decoder_mems_list = super()._one_step_forward( - decoder_input_ids, encoder_hidden_states, encoder_input_mask, decoder_mems_list, pos, + decoder_input_ids, + encoder_hidden_states, + encoder_input_mask, + decoder_mems_list, + pos, ) input_mask = mask_padded_tokens(decoder_input_ids, self.pad).float() lm_hidden_states = self.language_model.encoder.embedding.forward(decoder_input_ids, start_pos=pos) lm_mems_list = self.language_model.encoder.encoder.forward( - lm_hidden_states, input_mask, lm_mems_list, return_mems=True, + lm_hidden_states, + input_mask, + lm_mems_list, + return_mems=True, ) lm_log_probs = self.language_model.log_softmax.forward(hidden_states=lm_mems_list[-1][:, -1:]) @@ -853,13 +898,13 @@ def _forward( # choose top-k hypotheses with length penalty applied len_penalties = self.compute_len_penalty(prefixes_len, self.len_pen) scores = scores / len_penalties - scores, indices_i = torch.topk(scores.view(-1, self.beam_size ** 2), self.beam_size, dim=1) + scores, indices_i = torch.topk(scores.view(-1, self.beam_size**2), self.beam_size, dim=1) scores = scores.view(-1, 1) * len_penalties # select prefixes which correspond to the chosen hypotheses prefixes = prefixes.unsqueeze(1).repeat(1, self.beam_size, 1) prefixes = torch.cat((prefixes, prefixes_i.unsqueeze(2)), dim=2) - prefixes = prefixes.view(batch_size, self.beam_size ** 2, -1) + prefixes = prefixes.view(batch_size, self.beam_size**2, -1) p_len = prefixes.size(2) prefixes_ids = indices_i.unsqueeze(2).repeat(1, 1, p_len) prefixes = prefixes.gather(1, prefixes_ids).view(-1, p_len) diff --git a/nemo/collections/nlp/parts/megatron_trainer_builder.py b/nemo/collections/nlp/parts/megatron_trainer_builder.py index 194168008dc4f..f4276fd1b8f9a 100644 --- a/nemo/collections/nlp/parts/megatron_trainer_builder.py +++ b/nemo/collections/nlp/parts/megatron_trainer_builder.py @@ -90,7 +90,7 @@ def _training_strategy(self) -> Union[NLPDDPStrategy, NLPFSDPStrategy]: find_unused_parameters=False, nccl_communicator_config_path=self.cfg.model.get('nccl_communicator_config_path', None), sharp=self.cfg.model.get('sharp', False), - dist_ckpt_parallel_save=self.cfg.model.get('dist_ckpt_parallel_save', False), + dist_ckpt_parallel_save=self.cfg.model.get('dist_ckpt_parallel_dist_opt', True), ) def _grad_scaler(self) -> GradScaler: diff --git a/nemo/collections/nlp/parts/mixins/nlp_adapter_mixins.py b/nemo/collections/nlp/parts/mixins/nlp_adapter_mixins.py index 7d294f6085bbb..90b3912784c81 100644 --- a/nemo/collections/nlp/parts/mixins/nlp_adapter_mixins.py +++ b/nemo/collections/nlp/parts/mixins/nlp_adapter_mixins.py @@ -17,6 +17,7 @@ from typing import List, Optional, Union import torch +from megatron.core.transformer.identity_op import IdentityOp from omegaconf import DictConfig, OmegaConf, open_dict from nemo.utils.model_utils import inject_model_parallel_rank @@ -29,8 +30,13 @@ HAVE_MEGATRON_CORE = False -from nemo.collections.nlp.modules.common.megatron.adapters.parallel_adapters import PromptEncoderAdapterConfig +from nemo.collections.nlp.modules.common.megatron.adapters.parallel_adapters import ( + MLPHeadAdapterConfig, + PromptEncoderAdapterConfig, +) + from nemo.collections.nlp.parts.nlp_overrides import NLPSaveRestoreConnector + from nemo.collections.nlp.parts.peft_config import ( PEFT_CONFIG_MAP, CanonicalAdaptersPEFTConfig, @@ -126,14 +132,15 @@ def _check_and_add_adapter(self, name, module, peft_name, peft_cfg, name_key_to_ f'model.{mcore_target}', f'model.module.{mcore_target}', ]: # simple string match for now - swap_mcore_mixin(module, mcore_mixin) - if model_utils.import_class_by_path(peft_cfg._target_) in module.get_accepted_adapter_types(): - module.add_adapter( - name=peft_name, - cfg=peft_cfg, - base_model_cfg=self.cfg, - model_parallel_config=self.model_parallel_config, - ) + if not isinstance(module, IdentityOp): + swap_mcore_mixin(module, mcore_mixin) + if model_utils.import_class_by_path(peft_cfg._target_) in module.get_accepted_adapter_types(): + module.add_adapter( + name=peft_name, + cfg=peft_cfg, + base_model_cfg=self.cfg, + model_parallel_config=self.model_parallel_config, + ) elif isinstance(module, AdapterModuleMixin): if model_utils.import_class_by_path(peft_cfg._target_) in module.get_accepted_adapter_types(): module.add_adapter( @@ -159,7 +166,6 @@ def _get_layers_from_model(self, model): def _check_and_add_peft_cfg(self, peft_cfg): layer_selection = peft_cfg.layer_selection - assert not self.use_mcore_gpt or hasattr( peft_cfg, 'name_key_to_mcore_mixins' ), f"{peft_cfg.__class__.__name__} is not supported in megatron core mode yet." @@ -167,7 +173,11 @@ def _check_and_add_peft_cfg(self, peft_cfg): for adapter_name, adapter_cfg in peft_cfg.get_config_dict().items(): # self.mcore_gpt means is GPT and not T5 - if hasattr(self, 'mcore_gpt') and not isinstance(adapter_cfg, PromptEncoderAdapterConfig): + if ( + hasattr(self, 'mcore_gpt') + and not isinstance(adapter_cfg, PromptEncoderAdapterConfig) + and not isinstance(adapter_cfg, MLPHeadAdapterConfig) + ): if layer_selection is not None: logging.info( f"Layer selection {layer_selection} is enabled for the current model (" @@ -178,9 +188,10 @@ def _check_and_add_peft_cfg(self, peft_cfg): for layer in layers: if layer.layer_number in (layer_selection or list(range(1, self.cfg.num_layers + 1))): for name, module in layer.named_modules(): - self._check_and_add_adapter( - name, module, adapter_name, adapter_cfg, name_key_to_mcore_mixins - ) + if not isinstance(module, IdentityOp): + self._check_and_add_adapter( + name, module, adapter_name, adapter_cfg, name_key_to_mcore_mixins + ) else: # Non GPT models, as well as GPT+PTuning do not support layer selection if layer_selection is not None: @@ -349,8 +360,10 @@ def load_adapters( assert filepath.endswith( '.nemo' ), "Inferring peft scheme is only supported for .nemo checkpoints. Please supply the `peft_cfgs` argument." - peft_cfgs = [PEFT_CONFIG_MAP[conf.peft.peft_scheme](conf)] + peft_cfg_cls_lst = [PEFT_CONFIG_MAP[s] for s in conf.peft.peft_scheme.split(",")] + peft_cfgs = [_peft_cfg(conf) for _peft_cfg in peft_cfg_cls_lst] if getattr(self, 'megatron_amp_O2', False): + state_dict = {replace_prefix(k, 'model.', 'model.module.'): v for k, v in state_dict.items()} self.add_adapter(peft_cfgs) if not self.ptuning_only_and_non_first_stage: diff --git a/nemo/collections/nlp/parts/nlp_overrides.py b/nemo/collections/nlp/parts/nlp_overrides.py index 2fdb1906c31fd..b003e310baeba 100644 --- a/nemo/collections/nlp/parts/nlp_overrides.py +++ b/nemo/collections/nlp/parts/nlp_overrides.py @@ -116,6 +116,15 @@ HAVE_MEGATRON_CORE = False + +try: + from modelopt.torch.opt.plugins import restore_sharded_modelopt_state, save_sharded_modelopt_state + + HAVE_MODELOPT = True + +except Exception: + HAVE_MODELOPT = False + NEMO_MEGATRON_MODEL_PARALLEL_APPSTATE_OVERRIDE = "NEMO_MEGATRON_MODEL_PARALLEL_APPSTATE_OVERRIDE" @@ -381,6 +390,14 @@ def save_checkpoint( checkpoint['state_dict'] = OrderedDict([]) self.checkpoint_io.save_checkpoint(checkpoint, ckpt_to_dir(filepath), storage_options=storage_options) + + if HAVE_MODELOPT and hasattr(self.lightning_module, "get_model_module_list"): + save_sharded_modelopt_state( + self.lightning_module.get_model_module_list(), + ckpt_to_dir(filepath), + self.unwrapped_checkpoint_io.save_sharded_strategy, + prefix="model.", + ) else: # PTL override to accomodate model parallel checkpoints filepath = inject_model_parallel_rank(filepath) @@ -511,6 +528,11 @@ def load_checkpoint(self, checkpoint_path: Union[str, Path]) -> Dict[str, Any]: if not fs.isdir(checkpoint_path): raise ValueError(f'Distributed checkpoints should be a directory. Found: {checkpoint_path}.') + if HAVE_MODELOPT and hasattr(self.lightning_module, "get_model_module_list"): + restore_sharded_modelopt_state( + self.lightning_module.get_model_module_list(), checkpoint_path, prefix="model." + ) + sharded_state_dict = self.lightning_module.sharded_state_dict() checkpoint = {} @@ -518,10 +540,14 @@ def load_checkpoint(self, checkpoint_path: Union[str, Path]) -> Dict[str, Any]: # after dist_checkpointing.load, sharded tensors will be replaced with tensors checkpoint['state_dict'] = sharded_state_dict checkpoint['optimizer_states'] = [self.optimizer_sharded_state_dict(is_loading=True)] - if self._check_param_groups_mismatch(checkpoint_path, checkpoint): - return self._fix_param_groups(checkpoint_path, checkpoint) - return self.checkpoint_io.load_checkpoint(checkpoint_path, sharded_state_dict=checkpoint) + checkpoint = self._fix_param_groups(checkpoint_path, checkpoint) + else: + checkpoint = self.checkpoint_io.load_checkpoint(checkpoint_path, sharded_state_dict=checkpoint) + + if getattr(self.lightning_module, 'continue_training', False): + checkpoint = self._integrate_original_checkpoint_data(checkpoint) + return checkpoint # Legacy model parallel checkpointing logic, does not use megatron core else: @@ -532,6 +558,26 @@ def load_checkpoint(self, checkpoint_path: Union[str, Path]) -> Dict[str, Any]: torch.cuda.empty_cache() return self.checkpoint_io.load_checkpoint(checkpoint_path) + def _integrate_original_checkpoint_data(self, checkpoint: Dict[str, Any]) -> Dict[str, Any]: + """ + Ensures that model and optimizer weights are loaded from the checkpoint. + All other metadata are reinitialized. + """ + original_checkpoint = self.lightning_module.trainer._checkpoint_connector.dump_checkpoint() + for key in checkpoint: + if key not in ['state_dict', 'optimizer_states']: + checkpoint[key] = original_checkpoint[key] + if 'optimizer' in checkpoint['optimizer_states'][0]: + checkpoint['optimizer_states'][0]['optimizer']['param_groups'] = original_checkpoint['optimizer_states'][ + 0 + ]['optimizer']['param_groups'] + else: + checkpoint['optimizer_states'][0]['param_groups'] = original_checkpoint['optimizer_states'][0][ + 'param_groups' + ] + + return checkpoint + def remove_checkpoint(self, filepath: Union[str, Path]) -> None: # check if filepath is a distributed checkpoint if self.use_distributed_checkpointing: @@ -549,10 +595,7 @@ def remove_checkpoint(self, filepath: Union[str, Path]) -> None: @property def use_distributed_checkpointing(self): - checkpoint_io = self.checkpoint_io - while isinstance(checkpoint_io, _WrappingCheckpointIO): - checkpoint_io = checkpoint_io.checkpoint_io - has_dist_ckpt_io = HAVE_MEGATRON_CORE and isinstance(checkpoint_io, DistributedCheckpointIO) + has_dist_ckpt_io = HAVE_MEGATRON_CORE and isinstance(self.unwrapped_checkpoint_io, DistributedCheckpointIO) has_sharded_state_dict = ( hasattr(self.lightning_module, 'sharded_state_dict') and self.lightning_module.sharded_state_dict() is not None @@ -592,6 +635,14 @@ def restore_checkpoint_after_setup(self) -> bool: """ return True + @property + def unwrapped_checkpoint_io(self) -> CheckpointIO: + """Returns CheckpointIO unwrapped from any _WrappedCheckpointIO wrappers.""" + checkpoint_io = self.checkpoint_io + while isinstance(checkpoint_io, _WrappingCheckpointIO): + checkpoint_io = checkpoint_io.checkpoint_io + return checkpoint_io + class NLPDDPStrategyNotebook(NLPDDPStrategy): """Version of NLPDDPStrategy to be used in a Jupyter Notebook @@ -650,6 +701,7 @@ def __init__( nccl_communicator_config_path: Optional[str] = None, sharp: bool = False, set_buffer_dtype: Optional[str] = None, + extra_fsdp_wrap_module: Optional[set] = None, **kwargs: Union[Any, Dict[str, Any]], ) -> None: if not HAVE_APEX: @@ -679,6 +731,11 @@ def __init__( ParallelTransformerLayer, BasicTransformerBlock, } + + # if extra wrap modules are provided, use them + if extra_fsdp_wrap_module is not None: + self.fsdp_wrap_module.update(extra_fsdp_wrap_module) + kwargs['auto_wrap_policy'] = functools.partial( transformer_auto_wrap_policy, transformer_layer_cls=self.fsdp_wrap_module ) @@ -961,9 +1018,19 @@ def dummy(): model.trainer.strategy.launcher.launch(dummy, trainer=model.trainer) model.trainer.strategy.setup_environment() sharded_state_dict = model.sharded_state_dict() - checkpoint_io = DistributedCheckpointIO(model.cfg.get('dist_ckpt_format', 'zarr')) + checkpoint_io = DistributedCheckpointIO.from_config(model.cfg, async_save=False) checkpoint_io.save_checkpoint(sharded_state_dict, dist_ckpt_dir) + if HAVE_MODELOPT and hasattr(model, "get_model_module_list"): + while isinstance(checkpoint_io, _WrappingCheckpointIO): + checkpoint_io = checkpoint_io.checkpoint_io + save_sharded_modelopt_state( + model.get_model_module_list(), + dist_ckpt_dir, + checkpoint_io.save_sharded_strategy, + prefix="model.", + ) + else: # first we save the weights for each model parallel rank @@ -1179,6 +1246,7 @@ def restore_from( strict: bool = True, return_config: bool = False, trainer: Trainer = None, + validate_access_integrity: bool = True, ): """ Restores model instance (weights and configuration) into .nemo file @@ -1213,6 +1281,7 @@ def restore_from( strict, return_config, trainer, + validate_access_integrity, ) if not isinstance(loaded_params, tuple) or return_config is True: return loaded_params @@ -1246,16 +1315,26 @@ def dummy(): self._unpack_nemo_file( path2file=restore_path, out_folder=tmpdir, extract_config_only=return_config is True ) - checkpoint = {} - sharded_state_dict = instance.sharded_state_dict() - checkpoint['state_dict'] = sharded_state_dict # remove model weights extension tmp_model_weights_ckpt = os.path.join(tmpdir, self.model_weights_ckpt) tmp_model_weights_dir = os.path.splitext(tmp_model_weights_ckpt)[0] assert os.path.isdir(tmp_model_weights_dir), f'Expected {tmp_model_weights_dir} to be a directory.' + + if HAVE_MODELOPT and hasattr(instance, "get_model_module_list"): + restore_sharded_modelopt_state( + instance.get_model_module_list(), tmp_model_weights_dir, prefix="model." + ) + + checkpoint = {} + sharded_state_dict = instance.sharded_state_dict() + checkpoint['state_dict'] = sharded_state_dict + checkpoint_io = DistributedCheckpointIO.from_config(conf) checkpoint = checkpoint_io.load_checkpoint( - tmp_model_weights_dir, sharded_state_dict=checkpoint, strict=strict + tmp_model_weights_dir, + sharded_state_dict=checkpoint, + strict=strict, + validate_access_integrity=validate_access_integrity, ) instance.on_load_checkpoint(checkpoint) if hasattr(instance, 'setup_transformer_engine_tp_groups'): diff --git a/nemo/collections/nlp/parts/peft_config.py b/nemo/collections/nlp/parts/peft_config.py index 50c97e3498855..25f303fc22fb2 100644 --- a/nemo/collections/nlp/parts/peft_config.py +++ b/nemo/collections/nlp/parts/peft_config.py @@ -24,6 +24,7 @@ MCoreMLPMixin, MCoreSelfAttentionMixin, MCoreSequentialMLPMixin, + MCoreTransformerBlockMixin, MCoreTransformerLayerMixin, ) except (ImportError, ModuleNotFoundError): @@ -41,6 +42,7 @@ LoraMoeHto4HAdapterConfig, LoraUnfusedHto4HAdapterConfig, LoraUnfusedKQVAdapterConfig, + MLPHeadAdapterConfig, MLPInfusedAdapterConfig, ParallelLinearAdapterConfig, ParallelLinearAdapterWeightTyingConfig, @@ -127,6 +129,21 @@ def __init__(self, cfg): self.tunable_base_param_names = selective_cfg.get("tunable_base_param_names", []) +class MLPHeadPEFTConfig(PEFTConfig): + def __init__(self, cfg): + config_args = {"in_features": cfg.hidden_size, "out_features": cfg.peft.mlp_head_tuning.out_features} + mlp_head_cfg = MLPHeadAdapterConfig(**config_args) + + name_key_to_cfg = { + AdapterName.MLP_HEAD_ADAPTER: mlp_head_cfg, + } + self.name_key_to_mcore_mixins = { + AdapterName.MLP_HEAD_ADAPTER: [("decoder", MCoreTransformerBlockMixin)], + } + + super().__init__(cfg.peft.mlp_head_tuning, name_key_to_cfg) + + class LoraPEFTConfig(PEFTConfig): def __init__(self, cfg): lora_cfg = cfg.peft.lora_tuning @@ -170,7 +187,7 @@ def __init__(self, cfg): elif module == PEFT_MODULE_MAP["dense_module"]: adapter_cfg = self._create_lora_config( - cfg, lora_cfg, cfg.hidden_size, cfg.hidden_size, LoraDenseAttentionAdapterConfig + cfg, lora_cfg, projection_size, cfg.hidden_size, LoraDenseAttentionAdapterConfig ) name_key_to_cfg[AdapterName.LORA_DENSE_ATTENTION_ADAPTER] = adapter_cfg name_key_to_mcore_mixins[AdapterName.LORA_DENSE_ATTENTION_ADAPTER] = [ @@ -401,6 +418,7 @@ def __init__(self, cfg): "ia3": IA3PEFTConfig, "ptuning": PtuningPEFTConfig, "lora": LoraPEFTConfig, + "mlp_head": MLPHeadPEFTConfig, "qlora": QLoraPEFTConfig, "selective": SelectivePEFTConfig, 'none': None, diff --git a/nemo/collections/nlp/parts/utils_funcs.py b/nemo/collections/nlp/parts/utils_funcs.py index c00df5de1a988..a989ff3f606ca 100644 --- a/nemo/collections/nlp/parts/utils_funcs.py +++ b/nemo/collections/nlp/parts/utils_funcs.py @@ -34,14 +34,14 @@ from sklearn.metrics import classification_report, confusion_matrix from torch import Tensor -from nemo.collections.nlp.modules.common.megatron.utils import erf_gelu +from nemo.collections.nlp.modules.common.megatron.utils import ApproxGELUActivation, erf_gelu from nemo.collections.nlp.modules.common.megatron.utils import openai_gelu as openai_gelu_func from nemo.collections.nlp.modules.common.megatron.utils import squared_relu from nemo.utils import logging def torch_dtype_from_precision(precision: Union[int, str], megatron_amp_O2: Optional[bool] = None) -> torch.dtype: - """ Mapping from PTL precision types to corresponding PyTorch parameter datatype.""" + """Mapping from PTL precision types to corresponding PyTorch parameter datatype.""" if megatron_amp_O2 is not None and megatron_amp_O2 is False: return torch.float32 @@ -56,12 +56,12 @@ def torch_dtype_from_precision(precision: Union[int, str], megatron_amp_O2: Opti def list2str(l: List[int]) -> str: - """ Converts list to a string""" + """Converts list to a string""" return ' '.join([str(x) for x in l]) def tensor2list(tensor: Tensor) -> List[Union[int, float]]: - """ Converts tensor to a list """ + """Converts tensor to a list""" return tensor.detach().cpu().tolist() @@ -168,13 +168,13 @@ def get_last_rank(): def activation_to_func(activation: str, openai_gelu: bool = False, onnx_safe: bool = False) -> Callable: - """ Converts an activation function represented as a string to a function. + """Converts an activation function represented as a string to a function. Args: activation (str): string representation of an activation function, typically gotten from the model config. openai_gelu (bool): whether to use the OpenAI GELU implementation. Used with HF compatibility. onnx_safe (bool): whether to use the ONNX-compatible implementation of GELU. - + Returns: Callable: the activation function. """ @@ -188,6 +188,7 @@ def activation_to_func(activation: str, openai_gelu: bool = False, onnx_safe: bo 'fast-geglu', 'fast-swiglu', 'fast-reglu', + 'approx-gelu', ] if activation not in supported_activations: @@ -208,6 +209,8 @@ def activation_to_func(activation: str, openai_gelu: bool = False, onnx_safe: bo activation_func = F.silu elif activation == 'squared-relu': activation_func = squared_relu + elif activation == 'approx-gelu': + activation_func = ApproxGELUActivation return activation_func diff --git a/nemo/collections/tts/modules/transformer.py b/nemo/collections/tts/modules/transformer.py index 728b583919ff6..25c177d221cc1 100644 --- a/nemo/collections/tts/modules/transformer.py +++ b/nemo/collections/tts/modules/transformer.py @@ -102,7 +102,7 @@ def __init__(self, n_head, d_model, d_head, dropout, dropatt=0.1, pre_lnorm=Fals self.n_head = n_head self.d_model = d_model self.d_head = d_head - self.scale = 1 / (d_head ** 0.5) + self.scale = 1 / (d_head**0.5) self.pre_lnorm = pre_lnorm self.qkv_net = nn.Linear(d_model, 3 * n_head * d_head) @@ -125,13 +125,17 @@ def _forward(self, inp, attn_mask=None, conditioning=None): head_q, head_k, head_v = torch.chunk(self.qkv_net(inp), 3, dim=2) - head_q = head_q.view(inp.size(0), inp.size(1), n_head, d_head) - head_k = head_k.view(inp.size(0), inp.size(1), n_head, d_head) - head_v = head_v.view(inp.size(0), inp.size(1), n_head, d_head) + s0 = inp.size(0) + s1 = inp.size(1) + s2 = s0 * n_head - q = head_q.permute(2, 0, 1, 3).reshape(-1, inp.size(1), d_head) - k = head_k.permute(2, 0, 1, 3).reshape(-1, inp.size(1), d_head) - v = head_v.permute(2, 0, 1, 3).reshape(-1, inp.size(1), d_head) + head_q = head_q.view(s0, s1, n_head, d_head) + head_k = head_k.view(s0, s1, n_head, d_head) + head_v = head_v.view(s0, s1, n_head, d_head) + + q = head_q.permute(2, 0, 1, 3).reshape(s2, s1, d_head) + k = head_k.permute(2, 0, 1, 3).reshape(s2, s1, d_head) + v = head_v.permute(2, 0, 1, 3).reshape(s2, s1, d_head) attn_score = torch.bmm(q, k.transpose(1, 2)) attn_score.mul_(self.scale) @@ -145,8 +149,8 @@ def _forward(self, inp, attn_mask=None, conditioning=None): attn_prob = self.dropatt(attn_prob) attn_vec = torch.bmm(attn_prob, v) - attn_vec = attn_vec.view(n_head, inp.size(0), inp.size(1), d_head) - attn_vec = attn_vec.permute(1, 2, 0, 3).contiguous().view(inp.size(0), inp.size(1), n_head * d_head) + attn_vec = attn_vec.view(n_head, s0, s1, d_head) + attn_vec = attn_vec.permute(1, 2, 0, 3).contiguous().view(s0, s1, n_head * d_head) # linear projection attn_out = self.o_net(attn_vec) diff --git a/nemo/core/classes/common.py b/nemo/core/classes/common.py index 97757b2e3826e..60f842dbfb68c 100644 --- a/nemo/core/classes/common.py +++ b/nemo/core/classes/common.py @@ -1015,8 +1015,14 @@ def __init__( self.ignore_collections = ignore_collections + def __call__(self, wrapped): + return self.wrapped_call(wrapped) + + def unwrapped_call(self, wrapped): + return wrapped + @wrapt.decorator(enabled=is_typecheck_enabled) - def __call__(self, wrapped, instance: Typing, args, kwargs): + def wrapped_call(self, wrapped, instance: Typing, args, kwargs): """ Wrapper method that can be used on any function of a class that implements :class:`~nemo.core.Typing`. By default, it will utilize the `input_types` and `output_types` properties of the class inheriting Typing. @@ -1125,3 +1131,11 @@ def disable_semantic_checks(): yield finally: typecheck.set_semantic_check_enabled(enabled=True) + + @staticmethod + def enable_wrapping(enabled: bool = True): + typecheck.set_typecheck_enabled(enabled) + if enabled: + typecheck.__call__ = nemo.core.classes.common.typecheck.wrapped_call + else: + typecheck.__call__ = nemo.core.classes.common.typecheck.unwrapped_call diff --git a/nemo/core/classes/exportable.py b/nemo/core/classes/exportable.py index 5bd1bb813ba3b..aab09d42d9078 100644 --- a/nemo/core/classes/exportable.py +++ b/nemo/core/classes/exportable.py @@ -20,12 +20,13 @@ from nemo.core.classes import typecheck from nemo.core.neural_types import NeuralType from nemo.core.utils.neural_type_utils import get_dynamic_axes, get_io_names -from nemo.utils import logging +from nemo.utils import logging, monkeypatched from nemo.utils.export_utils import ( ExportFormat, augment_filename, get_export_format, parse_input_example, + rename_onnx_io, replace_for_export, verify_runtime, verify_torchscript, @@ -68,6 +69,7 @@ def export( check_tolerance=0.01, export_modules_as_functions=False, keep_initializers_as_inputs=None, + use_dynamo=False, ): """ Exports the model to the specified format. The format is inferred from the file extension of the output file. @@ -99,6 +101,7 @@ def export( ONNX specific. keep_initializers_as_inputs (bool): If True, will keep the model's initializers as inputs in the onnx graph. This is ONNX specific. + use_dynamo (bool): If True, use onnx.dynamo_export() instead of onnx.export(). This is ONNX specific. Returns: A tuple of two outputs. @@ -122,6 +125,7 @@ def export( check_tolerance=check_tolerance, export_modules_as_functions=export_modules_as_functions, keep_initializers_as_inputs=keep_initializers_as_inputs, + use_dynamo=use_dynamo, ) # Propagate input example (default scenario, may need to be overriden) if input_example is not None: @@ -143,6 +147,7 @@ def _export( check_tolerance=0.01, export_modules_as_functions=False, keep_initializers_as_inputs=None, + use_dynamo=False, ): my_args = locals().copy() my_args.pop('self') @@ -162,7 +167,7 @@ def _export( # Pytorch's default opset version is too low, using reasonable latest one if onnx_opset_version is None: - onnx_opset_version = 16 + onnx_opset_version = 17 try: # Disable typechecks @@ -189,14 +194,16 @@ def _export( input_list, input_dict = parse_input_example(input_example) input_names = self.input_names output_names = self.output_names - output_example = tuple(self.forward(*input_list, **input_dict)) + output_example = self.forward(*input_list, **input_dict) + if not isinstance(output_example, tuple): + output_example = (output_example,) if check_trace: if isinstance(check_trace, bool): check_trace_input = [input_example] else: check_trace_input = check_trace - jitted_model = self + if format == ExportFormat.TORCHSCRIPT: jitted_model = torch.jit.trace_module( self, @@ -216,27 +223,64 @@ def _export( elif format == ExportFormat.ONNX: # dynamic axis is a mapping from input/output_name => list of "dynamic" indices if dynamic_axes is None: - dynamic_axes = get_dynamic_axes(self.input_module.input_types_for_export, input_names) - dynamic_axes.update(get_dynamic_axes(self.output_module.output_types_for_export, output_names)) - torch.onnx.export( - jitted_model, - input_example, - output, - input_names=input_names, - output_names=output_names, - verbose=verbose, - do_constant_folding=do_constant_folding, - dynamic_axes=dynamic_axes, - opset_version=onnx_opset_version, - keep_initializers_as_inputs=keep_initializers_as_inputs, - export_modules_as_functions=export_modules_as_functions, - ) + dynamic_axes = self.dynamic_shapes_for_export(use_dynamo) + if use_dynamo: + typecheck.enable_wrapping(enabled=False) + # https://github.com/pytorch/pytorch/issues/126339 + with monkeypatched(torch.nn.RNNBase, "flatten_parameters", lambda *args: None): + logging.info(f"Running export.export, dynamic shapes:{dynamic_axes}\n") + + # We have to use different types of arguments for dynamo_export to achieve + # same external weights behaviour as onnx.export : + # https://github.com/pytorch/pytorch/issues/126479 + # https://github.com/pytorch/pytorch/issues/126269 + mem_params = sum([param.nelement() * param.element_size() for param in self.parameters()]) + mem_bufs = sum([buf.nelement() * buf.element_size() for buf in self.buffers()]) + mem = mem_params + mem_bufs + + if mem > 2 * 1000 * 1000 * 1000: + ex_model = torch.export.export( + self, + tuple(input_list), + kwargs=input_dict, + dynamic_shapes=dynamic_axes, + strict=False, + ) + ex_model = ex_model.run_decompositions() + model_state = ex_model.state_dict + else: + model_state = None + ex_model = self + + options = torch.onnx.ExportOptions(dynamic_shapes=True, op_level_debug=True) + ex = torch.onnx.dynamo_export(ex_model, *input_list, **input_dict, export_options=options) + ex.save(output, model_state=model_state) + + del ex + del ex_model + # Rename I/O after save - don't want to risk modifying ex._model_proto + rename_onnx_io(output, input_names, output_names) + else: + torch.onnx.export( + self, + input_example, + output, + input_names=input_names, + output_names=output_names, + verbose=verbose, + do_constant_folding=do_constant_folding, + dynamic_axes=dynamic_axes, + opset_version=onnx_opset_version, + keep_initializers_as_inputs=keep_initializers_as_inputs, + export_modules_as_functions=export_modules_as_functions, + ) if check_trace: verify_runtime(self, output, check_trace_input, input_names, check_tolerance=check_tolerance) else: raise ValueError(f'Encountered unknown export format {format}.') finally: + typecheck.enable_wrapping(enabled=True) typecheck.set_typecheck_enabled(enabled=True) if forward_method: type(self).forward = old_forward_method @@ -288,9 +332,12 @@ def input_types_for_export(self) -> Optional[Dict[str, NeuralType]]: def output_types_for_export(self): return self.output_types + def dynamic_shapes_for_export(self, use_dynamo=False): + return get_dynamic_axes(self.input_module.input_types_for_export, self.input_names, use_dynamo) + def get_export_subnet(self, subnet=None): """ - Returns Exportable subnet model/module to export + Returns Exportable subnet model/module to export """ if subnet is None or subnet == 'self': return self diff --git a/nemo/core/classes/mixins/adapter_mixins.py b/nemo/core/classes/mixins/adapter_mixins.py index 2a05f374d4641..7b5d02c86bf7b 100644 --- a/nemo/core/classes/mixins/adapter_mixins.py +++ b/nemo/core/classes/mixins/adapter_mixins.py @@ -15,7 +15,7 @@ import inspect from abc import ABC from dataclasses import dataclass, is_dataclass -from typing import List, Optional, Set, Tuple, Union +from typing import Iterable, List, Optional, Set, Tuple, Union import torch import torch.nn as nn @@ -123,8 +123,72 @@ def _prepare_default_adapter_config(*, global_key: str, meta_key: str, cfg: Dict return cfg +def update_module_class_with_adapter_class( + module: nn.Module, cfg: DictConfig, update_config: bool = True, verbose: bool = True +): + """ + Recursively walks through the module and its children, checking if the class is registered in the adapter registry. + If it is, the module's class is swapped with the registered adapter class. + Also updates the config with the adapter classpath, if required. + + Args: + module: torch.nn.Module to recurse through. + cfg: DictConfig object or dict that contains the config of the module. + update_config: Bool, whether to update the config with the adapter classpath. + verbose: Bool, whether to log the changes made to the module and config. + """ + + def inplace_recursive_walk_dict(d: Union[dict, DictConfig], base_class_path: str, adapter_class_path: str): + """ + Utility function to recursively walk through a dictionary and update the classpath if required. + Update is done inplace + + Args: + d: Dict to recurse through. + base_class_path: The str classpath of the base class. + adapter_class_path: The str classpath of the adapter class. + """ + for k, v in d.items(): # Loop through all k, v pairs + if isinstance(v, (dict, DictConfig)): # If value is a dict, recurse through it + inplace_recursive_walk_dict(v, base_class_path, adapter_class_path) + + # If key is target and value is base class, update the value to adapter class + elif k in ('target', '_target_') and isinstance(v, str) and v == base_class_path: + if verbose: + logging.info( + f"Updating config from {v} (base class) to {adapter_class_path} (adapter compatible " f"class)" + ) + + # Update the value inplace + d[k] = adapter_class_path + + if not isinstance(module, AdapterModuleMixin): + info = get_registered_adapter(module.__class__) + if info is not None: + if verbose: + logging.info( + f"Swapping class {info.base_class_path} with adapter compatible class: " + f"{info.adapter_class_path}" + ) + + # Swap the registered class with its registered adapter class. + # Due to direct inheritance of the Adapter subclass from the original class, + # the module's class container will be replaced with the adapter class. + + adapter_cls = info.adapter_class + module.__class__ = adapter_cls + + if update_config: + # Update the adapter config with the registered adapter config + # Find the location where the original module was registered in config + # and replace it with the adapter classpath. + original_classpath = info.base_class_path + adapter_classpath = info.adapter_class_path + inplace_recursive_walk_dict(cfg, original_classpath, adapter_classpath) + + class AdapterModuleMixin(ABC): - """ Generic Adapter Mixin that can augment any torch.nn.Module with Adapter module support. + """Generic Adapter Mixin that can augment any torch.nn.Module with Adapter module support. This mixin class adds a hierarchical way to add any type of Adapter modules to a pre-existing module. Since Models are inherently also nn.Module, this mixin can be attached to any Model or Module. @@ -171,21 +235,7 @@ def add_adapter(self, name: str, cfg: Union[DictConfig, AdapterConfig], **kwargs cfg = DictConfig(cfg) adapter_types = self.get_accepted_adapter_types() - _pass_types = False - if len(adapter_types) > 0: - test = model_utils.import_class_by_path(cfg._target_) - for _type in adapter_types: - # TODO: (@adithyare) should revisit if subclass is the best check... - if issubclass(test, _type): - _pass_types = True - break - if not _pass_types: - raise ValueError( - f"Config: \n{OmegaConf.to_yaml(cfg)}\n" - f"It creates adapter class {test} \n" - f"that is not in the list of accepted adapter types.\n" - f"Accepted adapters: {[t for t in adapter_types]}" - ) + self.check_supported_adapter_type_(cfg, adapter_types) # Convert to DictConfig from dict or Dataclass if is_dataclass(cfg): @@ -341,6 +391,14 @@ def get_adapter_module(self, name: str): return self.adapter_layer[name] if name in self.adapter_layer else None return None + def get_adapter_cfg(self, name: str): + """Same logic as `get_adapter_module` but to get the config""" + _, name = self.resolve_adapter_module_name_(name) + + if hasattr(self, "adapter_cfg"): + return self.adapter_cfg[name] if name in self.adapter_cfg else None + return None + def set_accepted_adapter_types(self, adapter_types: List[Union[type, str]]) -> None: """ The module with this mixin can define a list of adapter names that it will accept. @@ -363,7 +421,9 @@ def set_accepted_adapter_types(self, adapter_types: List[Union[type, str]]) -> N self._accepted_adapter_types = set(types) - def get_accepted_adapter_types(self,) -> Set[type]: + def get_accepted_adapter_types( + self, + ) -> Set[type]: """ Utility function to get the set of all classes that are accepted by the module. @@ -543,9 +603,38 @@ def forward_single_enabled_adapter_( output = adapter_strategy(input, adapter_module, module=self) return output + def check_supported_adapter_type_( + self, adapter_cfg: DictConfig, supported_adapter_types: Optional[Iterable[type]] = None + ): + """ + Utility method to check if the adapter module is a supported type by the module. + + This method should be called by the subclass to ensure that the adapter module is a supported type. + """ + _pass_types = False + + if supported_adapter_types is None: + supported_adapter_types = self.get_accepted_adapter_types() + + if len(supported_adapter_types) > 0: + test = model_utils.import_class_by_path(adapter_cfg['_target_']) + for _type in supported_adapter_types: + # TODO: (@adithyare) should revisit if subclass is the best check... + if issubclass(test, _type): + _pass_types = True + break + + if not _pass_types: + raise ValueError( + f"Config: \n{OmegaConf.to_yaml(adapter_cfg)}\n" + f"It creates adapter class {test} \n" + f"that is not in the list of accepted adapter types.\n" + f"Accepted adapters: {[t for t in supported_adapter_types]}" + ) + class AdapterModelPTMixin(AdapterModuleMixin): - """ Adapter Mixin that can augment a ModelPT subclass with Adapter support. + """Adapter Mixin that can augment a ModelPT subclass with Adapter support. This mixin class should be used only with a top level ModelPT subclass. This mixin class adds several utility methods which should be subclassed and overriden to @@ -641,7 +730,9 @@ def add_adapter(self, name: str, cfg: Union[DictConfig, AdapterConfig]): self.cfg.adapters = OmegaConf.create({}) self.cfg.adapters = _prepare_default_adapter_config( - global_key=self.adapter_global_cfg_key, meta_key=self.adapter_metadata_cfg_key, cfg=self.cfg.adapters, + global_key=self.adapter_global_cfg_key, + meta_key=self.adapter_metadata_cfg_key, + cfg=self.cfg.adapters, ) # If the adapter is not being restored, force unique name to be provided for all adapters. @@ -970,6 +1061,19 @@ def update_adapter_cfg(self, cfg: DictConfig): if isinstance(module, AdapterModuleMixin): module.adapter_cfg = cfg + def replace_adapter_compatible_modules(self, update_config: bool = True, verbose: bool = True): + """ + Utility method to replace all child modules with Adapter variants, if they exist. + Does NOT recurse through children of children modules (only immediate children). + + Args: + update_config: A flag that determines if the config should be updated or not. + verbose: A flag that determines if the method should log the changes made or not. + """ + # Update the given module itself, and then all its children modules + for name, mod in self.named_modules(): + update_module_class_with_adapter_class(mod, cfg=self.cfg, update_config=update_config, verbose=verbose) + @property def adapter_module_names(self) -> List[str]: """ @@ -982,6 +1086,22 @@ def adapter_module_names(self) -> List[str]: Returns: A list of str, one for each of the adapter modules that are supported. By default, the subclass - should support the "global adapter" (''). + should support the "default adapter" (''). """ return [''] + + @property + def default_adapter_module_name(self) -> Optional[str]: + """ + Name of the adapter module that is used as "default" if a name of '' is provided. + + .. note:: + + Subclasses should override this property and return a str name of the module + that they wish to denote as the default. + + Returns: + A str name of a module, which is denoted as 'default' adapter or None. If None, then no default + adapter is supported. + """ + return None diff --git a/nemo/core/classes/modelPT.py b/nemo/core/classes/modelPT.py index f5d61a8edb157..2bfd4e5cd695b 100644 --- a/nemo/core/classes/modelPT.py +++ b/nemo/core/classes/modelPT.py @@ -422,6 +422,7 @@ def restore_from( return_config: bool = False, save_restore_connector: SaveRestoreConnector = None, trainer: Optional[Trainer] = None, + validate_access_integrity: bool = True, ): """ Restores model instance (weights and configuration) from .nemo file. @@ -465,7 +466,14 @@ def restore_from( cls.update_save_restore_connector(save_restore_connector) instance = cls._save_restore_connector.restore_from( - cls, restore_path, override_config_path, map_location, strict, return_config, trainer + cls, + restore_path, + override_config_path, + map_location, + strict, + return_config, + trainer, + validate_access_integrity, ) if isinstance(instance, ModelPT): instance._save_restore_connector = save_restore_connector diff --git a/nemo/core/connectors/save_restore_connector.py b/nemo/core/connectors/save_restore_connector.py index 70d91066b7f05..23b38510bb001 100644 --- a/nemo/core/connectors/save_restore_connector.py +++ b/nemo/core/connectors/save_restore_connector.py @@ -92,6 +92,7 @@ def load_config_and_state_dict( strict: bool = True, return_config: bool = False, trainer: Trainer = None, + validate_access_integrity: bool = True, ): """ Restores model instance (weights and configuration) into .nemo file @@ -226,6 +227,7 @@ def restore_from( strict: bool = True, return_config: bool = False, trainer: Trainer = None, + validate_access_integrity: bool = True, ): """ Restores model instance (weights and configuration) into .nemo file @@ -253,7 +255,14 @@ def restore_from( # Get path where the command is executed - the artifacts will be "retrieved" there # (original .nemo behavior) loaded_params = self.load_config_and_state_dict( - calling_cls, restore_path, override_config_path, map_location, strict, return_config, trainer, + calling_cls, + restore_path, + override_config_path, + map_location, + strict, + return_config, + trainer, + validate_access_integrity, ) if not isinstance(loaded_params, tuple) or return_config is True: return loaded_params diff --git a/nemo/core/optim/lr_scheduler.py b/nemo/core/optim/lr_scheduler.py index 473ca0f5c416c..cfb3068b1cc8a 100644 --- a/nemo/core/optim/lr_scheduler.py +++ b/nemo/core/optim/lr_scheduler.py @@ -97,7 +97,14 @@ class SquareRootConstantPolicy(_LRScheduler): """ def __init__( - self, optimizer, *, constant_steps=None, constant_ratio=None, max_steps=None, min_lr=0.0, last_epoch=-1 + self, + optimizer, + *, + constant_steps=None, + constant_ratio=None, + max_steps=None, + min_lr=0.0, + last_epoch=-1, ): assert not ( constant_steps is not None and constant_ratio is not None @@ -114,7 +121,7 @@ def __init__( else: self.constant_steps = 0 - self.constant_lr = 1 / (constant_steps ** 0.5) + self.constant_lr = 1 / (constant_steps**0.5) self.min_lr = min_lr super().__init__(optimizer, last_epoch) @@ -280,6 +287,16 @@ def get_lr(self): step = self.last_epoch + # Reset learning rate + if 'reset_lr' in self.optimizer.param_groups[0].keys(): + reset_lr = self.optimizer.param_groups[0]['reset_lr'] + num_steps = reset_lr['num_steps'] + step -= num_steps + if reset_lr['if_init_step'] and reset_lr['reset_lr_steps']: + self.decay_steps -= num_steps + self.max_steps -= num_steps + self.optimizer.param_groups[0]['reset_lr']['if_init_step'] = False + # Warmup steps if self.warmup_steps > 0 and step <= self.warmup_steps: return self._get_warmup_lr(step) @@ -364,7 +381,7 @@ def _poly_decay(initial_lr, step, decay_steps, power, min_lr, cycle): def _noam_hold_annealing(initial_lr, step, warmup_steps, hold_steps, decay_rate, min_lr): # hold_steps = total number of steps to hold the LR, not the warmup + hold steps. - T_warmup_decay = max(1, warmup_steps ** decay_rate) + T_warmup_decay = max(1, warmup_steps**decay_rate) T_hold_decay = max(1, (step - hold_steps) ** decay_rate) lr = (initial_lr * T_warmup_decay) / T_hold_decay lr = max(lr, min_lr) @@ -453,7 +470,15 @@ def _get_linear_warmup_with_cosine_annealing_lr(self, step): class NoamAnnealing(_LRScheduler): def __init__( - self, optimizer, *, d_model, warmup_steps=None, warmup_ratio=None, max_steps=None, min_lr=0.0, last_epoch=-1 + self, + optimizer, + *, + d_model, + warmup_steps=None, + warmup_ratio=None, + max_steps=None, + min_lr=0.0, + last_epoch=-1, ): self._normalize = d_model ** (-0.5) assert not ( @@ -593,7 +618,7 @@ def __init__(self, optimizer, *, max_steps, last_epoch=-1, min_lr=0.0, **kwargs) super().__init__(optimizer=optimizer, max_steps=max_steps, **kwargs, last_epoch=last_epoch, min_lr=min_lr) def _get_lr(self, step): - return [1 / (step ** 0.5) for _ in self.base_lrs] + return [1 / (step**0.5) for _ in self.base_lrs] class PolynomialDecayAnnealing(WarmupPolicy): diff --git a/nemo/core/optim/mcore_optim.py b/nemo/core/optim/mcore_optim.py index 234680f492497..9feb70cc90a18 100644 --- a/nemo/core/optim/mcore_optim.py +++ b/nemo/core/optim/mcore_optim.py @@ -58,9 +58,7 @@ def load_state_dict(self, state_dict): def sharded_state_dict( self, model_sharded_state_dict, optimizer_state_dict=None, is_loading=False, dist_ckpt_parallel_save=False ): - # TODO(@akoumparouli, @mikolajblaz): switch to sharding_type once support for fully_sharded_model_space merged in mcore. - # sharding_type = 'fully_sharded_model_space' if dist_ckpt_parallel_save else 'dp_zero_gather_scatter' - sharding_type = 'dp_zero_gather_scatter' + sharding_type = 'fully_sharded_model_space' if dist_ckpt_parallel_save else 'dp_zero_gather_scatter' return self.mcore_optimizer.sharded_state_dict( model_sharded_state_dict, is_loading=is_loading, sharding_type=sharding_type ) diff --git a/nemo/core/optim/optimizer_with_main_params.py b/nemo/core/optim/optimizer_with_main_params.py index 7d47b7e895f7b..412332adef907 100755 --- a/nemo/core/optim/optimizer_with_main_params.py +++ b/nemo/core/optim/optimizer_with_main_params.py @@ -119,7 +119,7 @@ def zero(self): self.data.zero_() def allreduce_buffer(self): - """Synchronous buffer data allreduce """ + """Synchronous buffer data allreduce""" self.data.div_(get_data_parallel_world_size()) torch.distributed.all_reduce(self.data, group=self._data_group) @@ -175,7 +175,7 @@ class MainParamsOptimizerWrapper(torch.optim.Optimizer): Arguments: optimizer: base optimizer such as Adam or SGD. fp32_grad_accum: to enable the use of fp32 in gradient accumulation and allreduce. - contiguous_grad_bucket: to enable allocating the master gradients in the + contiguous_grad_bucket: to enable allocating the master gradients in the contiguous memory space to reduce memory fragmentation. async_grad_allreduce: enable asynchronous gradient allreduce that is executed along with the training step backprop. @@ -339,6 +339,7 @@ def __init__( def _make_param_hook(self, param, main_param, i, grad_chunk_info, is_expert_group): """Create the grad accumulation and all-reduce hook for backprop.""" + # Hook used for back-prop. def param_hook(*unused): # Accumulates gradients on main gradients @@ -361,7 +362,9 @@ def allreduce_grads(use_fused_div, tensor, data_group, grad_mult): else: tensor.div_(grad_mult) torch.distributed.all_reduce( - tensor, group=data_group, async_op=True, + tensor, + group=data_group, + async_op=True, ) # Asynchronous gradients allreduce accross data_parallel ranks @@ -473,12 +476,16 @@ def load_state_dict(self, state_dict): if optimizer_key not in state_dict: optimizer_key = 'optimizer_state_dict' logging.info('***WARNING*** loading optimizer from ' 'an old checkpoint ...') + if 'state' not in state_dict[optimizer_key]: + state_dict[optimizer_key]['state'] = {} self.optimizer.load_state_dict(state_dict[optimizer_key]) # Copy data for the main params. fp32_from_float16_params_key = 'fp32_from_fp16_params' if fp32_from_float16_params_key not in state_dict: fp32_from_float16_params_key = 'fp32_from_fp16' + if fp32_from_float16_params_key not in state_dict: + state_dict[fp32_from_float16_params_key] = [] for current_group, saved_group in zip(self.fp32_from_float16_groups, state_dict[fp32_from_float16_params_key]): for current_param, saved_param in zip(current_group, saved_group): current_param.data.copy_(saved_param.data) @@ -489,7 +496,7 @@ def allreduce_main_grads(self): @contextmanager def no_sync(self): - """ A context manager to disable gradient synchronizations across + """A context manager to disable gradient synchronizations across data-parallel ranks.""" old_require_backward_grad_sync = self._require_backward_grad_sync self._require_backward_grad_sync = False diff --git a/nemo/core/utils/neural_type_utils.py b/nemo/core/utils/neural_type_utils.py index 98ae442b9aa78..5a634dad3d57c 100644 --- a/nemo/core/utils/neural_type_utils.py +++ b/nemo/core/utils/neural_type_utils.py @@ -14,7 +14,7 @@ from collections import defaultdict from typing import Dict, List, Optional - +import torch from nemo.core.neural_types import AxisKind, NeuralType @@ -30,19 +30,19 @@ def get_io_names(types: Optional[Dict[str, NeuralType]], disabled_names: List[st def extract_dynamic_axes(name: str, ntype: NeuralType): """ - This method will extract BATCH and TIME dimension ids from each provided input/output name argument. - - For example, if module/model accepts argument named "input_signal" with type corresponding to [Batch, Time, Dim] - shape, then the returned result should contain "input_signal" -> [0, 1] because Batch and Time are dynamic axes - as they can change from call to call during inference. - - Args: - name: Name of input or output parameter - ntype: Corresponding Neural Type - - Returns: + This method will extract BATCH and TIME dimension ids from each provided input/output name argument. + + For example, if module/model accepts argument named "input_signal" with type corresponding to [Batch, Time, Dim] + shape, then the returned result should contain "input_signal" -> [0, 1] because Batch and Time are dynamic axes + as they can change from call to call during inference. + + Args: + name: Name of input or output parameter + ntype: Corresponding Neural Type - """ + Returns: + + """ def unpack_nested_neural_type(neural_type): if type(neural_type) in (list, tuple): @@ -60,10 +60,23 @@ def unpack_nested_neural_type(neural_type): return dynamic_axes -def get_dynamic_axes(types, names): +def get_dynamic_axes(types, names, use_dynamo=False): dynamic_axes = defaultdict(list) if names is not None: for name in names: if name in types: dynamic_axes.update(extract_dynamic_axes(name, types[name])) + if use_dynamo: + dynamic_shapes = {} + batch = torch.export.Dim("batch") + for name, dims in dynamic_axes.items(): + ds = {} + for d in dims: + if d == 0: + ds[d] = batch + # this currently has issues: https://github.com/pytorch/pytorch/issues/126127 + else: + ds[d] = torch.export.Dim(name + '__' + str(d)) + dynamic_shapes[name] = ds + dynamic_axes = dynamic_shapes return dynamic_axes diff --git a/nemo/deploy/deploy_pytriton.py b/nemo/deploy/deploy_pytriton.py index 25e09cf3eacca..1e1333f03b553 100644 --- a/nemo/deploy/deploy_pytriton.py +++ b/nemo/deploy/deploy_pytriton.py @@ -29,7 +29,7 @@ class DeployPyTriton(DeployBase): Example: from nemo.deploy import DeployPyTriton, NemoQueryLLM - from nemo.export import TensorRTLLM + from nemo.export.tensorrt_llm import TensorRTLLM trt_llm_exporter = TensorRTLLM(model_dir="/path/for/model/files") trt_llm_exporter.export( diff --git a/nemo/deploy/multimodal/__init__.py b/nemo/deploy/multimodal/__init__.py new file mode 100644 index 0000000000000..b75e37007ab97 --- /dev/null +++ b/nemo/deploy/multimodal/__init__.py @@ -0,0 +1,16 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from nemo.deploy.multimodal.query_multimodal import NemoQueryMultimodal diff --git a/nemo/deploy/multimodal/query_multimodal.py b/nemo/deploy/multimodal/query_multimodal.py new file mode 100644 index 0000000000000..9f747ff6d3061 --- /dev/null +++ b/nemo/deploy/multimodal/query_multimodal.py @@ -0,0 +1,115 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +from decord import VideoReader +from PIL import Image + +from nemo.deploy.utils import str_list2numpy + +use_pytriton = True +try: + from pytriton.client import ModelClient +except Exception: + use_pytriton = False + + +class NemoQueryMultimodal: + """ + Sends a query to Triton for Multimodal inference + + Example: + from nemo.deploy.multimodal import NemoQueryMultimodal + + nq = NemoQueryMultimodal(url="localhost", model_name="neva", model_type="neva") + + input_text = "Hi! What is in this image?" + output = nq.query( + input_text=input_text, + input_media="/path/to/image.jpg", + max_output_len=30, + top_k=1, + top_p=0.0, + temperature=1.0, + ) + print("prompts: ", prompts) + """ + + def __init__(self, url, model_name, model_type): + self.url = url + self.model_name = model_name + self.model_type = model_type + + def setup_media(self, input_media): + if self.model_type == "video-neva": + vr = VideoReader(input_media) + frames = [f.asnumpy() for f in vr] + return np.array(frames) + elif self.model_type == "neva": + media = Image.open(input_media).convert('RGB') + return np.expand_dims(np.array(media), axis=0) + else: + raise RuntimeError(f"Invalid model type {self.model_type}") + + def query( + self, + input_text, + input_media, + batch_size=1, + max_output_len=30, + top_k=1, + top_p=0.0, + temperature=1.0, + repetition_penalty=1.0, + num_beams=1, + init_timeout=60.0, + ): + + prompts = str_list2numpy([input_text]) + inputs = {"input_text": prompts} + + media = self.setup_media(input_media) + + inputs["input_media"] = np.repeat(media[np.newaxis, :, :, :, :], prompts.shape[0], axis=0) + + if batch_size is not None: + inputs["batch_size"] = np.full(prompts.shape, batch_size, dtype=np.int_) + + if max_output_len is not None: + inputs["max_output_len"] = np.full(prompts.shape, max_output_len, dtype=np.int_) + + if top_k is not None: + inputs["top_k"] = np.full(prompts.shape, top_k, dtype=np.int_) + + if top_p is not None: + inputs["top_p"] = np.full(prompts.shape, top_p, dtype=np.single) + + if temperature is not None: + inputs["temperature"] = np.full(prompts.shape, temperature, dtype=np.single) + + if repetition_penalty is not None: + inputs["repetition_penalty"] = np.full(prompts.shape, repetition_penalty, dtype=np.single) + + if num_beams is not None: + inputs["num_beams"] = np.full(prompts.shape, num_beams, dtype=np.int_) + + with ModelClient(self.url, self.model_name, init_timeout_s=init_timeout) as client: + result_dict = client.infer_batch(**inputs) + output_type = client.model_config.outputs[0].dtype + + if output_type == np.bytes_: + sentences = np.char.decode(result_dict["outputs"].astype("bytes"), "utf-8") + return sentences + else: + return result_dict["outputs"] diff --git a/nemo/deploy/nlp/__init__.py b/nemo/deploy/nlp/__init__.py index ae4db1ce6f2a2..5ebbe68166649 100644 --- a/nemo/deploy/nlp/__init__.py +++ b/nemo/deploy/nlp/__init__.py @@ -15,8 +15,12 @@ use_query_llm = True try: - from nemo.deploy.nlp.query_llm import NemoQueryLLM + from nemo.deploy.nlp.query_llm import NemoQueryLLM, NemoQueryLLMPyTorch except Exception: use_query_llm = False -from nemo.deploy.nlp.megatronllm_deployable import MegatronLLMDeployable +use_megatron_llm = True +try: + from nemo.deploy.nlp.megatronllm_deployable import MegatronLLMDeployable +except Exception: + use_megatron_llm = False diff --git a/nemo/deploy/nlp/megatronllm_deployable.py b/nemo/deploy/nlp/megatronllm_deployable.py index c27bbbd0102b0..1fe029f9faded 100644 --- a/nemo/deploy/nlp/megatronllm_deployable.py +++ b/nemo/deploy/nlp/megatronllm_deployable.py @@ -15,6 +15,7 @@ import logging from enum import IntEnum, auto from pathlib import Path +from typing import List import numpy as np import torch @@ -129,6 +130,12 @@ def _load_from_nemo_checkpoint(self, nemo_checkpoint_filepath: str, num_devices: nemo_checkpoint_filepath, trainer=trainer, return_config=True ) # transformer_engine should always be true according to EricH, but GPT-2B model will fail if it is enabled + if not custom_config.transformer_engine: + LOGGER.warning( + "MegatronLLMDeployable expects model config transformer_engine=True, but this model has it =False. " + "Overriding it to =True, but this may break certain checkpoints converted on older Nemo versions. " + "If your model breaks, please try re-converting the checkpoint on the current Nemo version." + ) custom_config.transformer_engine = True # using multi-gpu for tensor parallelism directly for now, could do pipeline parallel instead or a combination custom_config.tensor_model_parallel_size = num_devices @@ -233,9 +240,7 @@ def _length_params_from_triton_inputs(**inputs: np.ndarray): length_params[length_param_field] = inputs.pop(length_param_field)[0][0] return length_params - @batch - def triton_infer_fn(self, **inputs: np.ndarray): - """Triton server inference function that actually runs the model""" + def generate(self, inputs: List[str], length_params: LengthParam, sampling_params: SamplingParam): if torch.distributed.is_initialized(): distributed_rank = torch.distributed.get_rank() if distributed_rank != 0: @@ -245,13 +250,16 @@ def triton_infer_fn(self, **inputs: np.ndarray): signal_value = ServerSync.SIGNAL.to_long_tensor() torch.distributed.broadcast(signal_value, 0) + return self.model.generate(inputs=inputs, length_params=length_params, sampling_params=sampling_params) + + @batch + def triton_infer_fn(self, **inputs: np.ndarray): + """Triton server inference function that actually runs the model""" input_strings = str_ndarray2list(inputs.pop("prompts")) sampling_params = self._sampling_params_from_triton_inputs(**inputs) length_params = self._length_params_from_triton_inputs(**inputs) - model_output = self.model.generate( - inputs=input_strings, length_params=length_params, sampling_params=sampling_params - ) + model_output = self.generate(input_strings, length_params, sampling_params) ''' model_output['sentences'] will be a list of strings (one per prompt) other fields will either be a list of lists (tokens, for example) diff --git a/nemo/deploy/nlp/query_llm.py b/nemo/deploy/nlp/query_llm.py index 940a927c7a540..71492520bf0ab 100644 --- a/nemo/deploy/nlp/query_llm.py +++ b/nemo/deploy/nlp/query_llm.py @@ -30,23 +30,99 @@ def __init__(self, url, model_name): self.url = url self.model_name = model_name - @abstractmethod + +class NemoQueryLLMPyTorch(NemoQueryLLMBase): + """ + Sends a query to Triton for LLM inference + + Example: + from nemo.deploy import NemoTritonQueryLLMPyTorch + + nq = NemoTritonQueryLLMPyTorch(url="localhost", model_name="GPT-2B") + + prompts = ["hello, testing GPT inference", "another GPT inference test?"] + output = nq.query_llm( + prompts=prompts, + max_length=100, + top_k=1, + top_p=0.0, + temperature=0.0, + ) + print("prompts: ", prompts) + """ + + def __init__(self, url, model_name): + super().__init__( + url=url, + model_name=model_name, + ) + + # these arguments are explicitly defined in order to make it clear to user what they can pass + # names and optionality should exactly match the get_triton_input() results for MegatronGPTDeployable def query_llm( self, prompts, - stop_words_list=None, - bad_words_list=None, - no_repeat_ngram_size=None, - max_output_len=512, - top_k=1, - top_p=0.0, - temperature=1.0, - random_seed=None, - task_id=None, - lora_uids=None, + use_greedy: bool = None, + temperature: float = None, + top_k: int = None, + top_p: float = None, + repetition_penalty: float = None, + add_BOS: bool = None, + all_probs: bool = None, + compute_logprob: bool = None, + end_strings=None, + min_length: int = None, + max_length: int = None, init_timeout=60.0, ): - pass + """ + Query the Triton server synchronously and return a list of responses. + + Args: + prompts (List(str)): list of sentences. + use_greedy (bool): use greedy sampling, effectively the same as top_k=1 + temperature (float): A parameter of the softmax function, which is the last layer in the network. + top_k (int): limits us to a certain number (K) of the top tokens to consider. + top_p (float): limits us to the top tokens within a certain probability mass (p). + repetition_penalty (float): penalty applied to repeated sequences, 1.0 means no penalty. + add_BOS (bool): whether or not to add a BOS (beginning of sentence) token. + all_probs (bool): when using compute_logprob, returns probabilities for all tokens in vocabulary. + compute_logprob (bool): get back probabilities of all tokens in the sequence. + end_strings (List(str)): list of strings which will terminate generation when they appear in the output. + min_length (int): min generated tokens. + max_length (int): max generated tokens. + init_timeout (flat): timeout for the connection. + """ + prompts = str_list2numpy(prompts) + inputs = { + "prompts": prompts, + } + if use_greedy is not None: + inputs["use_greedy"] = np.full(prompts.shape, use_greedy, dtype=np.bool_) + if temperature is not None: + inputs["temperature"] = np.full(prompts.shape, temperature, dtype=np.single) + if top_k is not None: + inputs["top_k"] = np.full(prompts.shape, top_k, dtype=np.int_) + if top_p is not None: + inputs["top_p"] = np.full(prompts.shape, top_p, dtype=np.single) + if repetition_penalty is not None: + inputs["repetition_penalty"] = np.full(prompts.shape, repetition_penalty, dtype=np.single) + if add_BOS is not None: + inputs["add_BOS"] = np.full(prompts.shape, add_BOS, dtype=np.bool_) + if all_probs is not None: + inputs["all_probs"] = np.full(prompts.shape, all_probs, dtype=np.bool_) + if compute_logprob is not None: + inputs["compute_logprob"] = np.full(prompts.shape, compute_logprob, dtype=np.bool_) + if end_strings is not None: + inputs["end_strings"] = str_list2numpy(end_strings) + if min_length is not None: + inputs["min_length"] = np.full(prompts.shape, min_length, dtype=np.int_) + if max_length is not None: + inputs["max_length"] = np.full(prompts.shape, max_length, dtype=np.int_) + + with ModelClient(self.url, self.model_name, init_timeout_s=init_timeout) as client: + result_dict = client.infer_batch(**inputs) + return result_dict class NemoQueryLLM(NemoQueryLLMBase): diff --git a/nemo/deploy/service/__init__.py b/nemo/deploy/service/__init__.py new file mode 100644 index 0000000000000..0349454da9e13 --- /dev/null +++ b/nemo/deploy/service/__init__.py @@ -0,0 +1,14 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from .rest_model_api import app diff --git a/nemo/deploy/service/config.json b/nemo/deploy/service/config.json new file mode 100644 index 0000000000000..d3b3440dd97b5 --- /dev/null +++ b/nemo/deploy/service/config.json @@ -0,0 +1,5 @@ +{ + "triton_service_port": 8000, + "triton_service_ip": "0.0.0.0", + "triton_request_timeout": 60 + } \ No newline at end of file diff --git a/nemo/deploy/service/rest_model_api.py b/nemo/deploy/service/rest_model_api.py new file mode 100644 index 0000000000000..5c49370fd45f8 --- /dev/null +++ b/nemo/deploy/service/rest_model_api.py @@ -0,0 +1,87 @@ +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import json +import os +from pathlib import Path + +from fastapi import FastAPI +from pydantic import BaseModel +from pydantic_settings import BaseSettings + +from nemo.deploy.nlp import NemoQueryLLM + + +class TritonSettings(BaseSettings): + _triton_service_port: int + _triton_service_ip: str + _triton_request_timeout: str + + def __init__(self): + super(TritonSettings, self).__init__() + try: + with open(os.path.join(Path.cwd(), 'nemo/deploy/service/config.json')) as config: + config_json = json.load(config) + self._triton_service_port = config_json["triton_service_port"] + self._triton_service_ip = config_json["triton_service_ip"] + self._triton_request_timeout = config_json["triton_request_timeout"] + except Exception as error: + print("An exception occurred:", error) + return + + @property + def triton_service_port(self): + return self._triton_service_port + + @property + def triton_service_ip(self): + return self._triton_service_ip + + @property + def triton_request_timeout(self): + return self._triton_request_timeout + + +app = FastAPI() +triton_settings = TritonSettings() + + +class CompletionRequest(BaseModel): + model: str + prompt: str + max_tokens: int = 512 + temperature: float = 1.0 + top_p: float = 0.0 + n: int = 1 + stream: bool = False + stop: str | None = None + frequency_penalty: float = 1.0 + + +@app.post("/v1/completions/") +def completions_v1(request: CompletionRequest): + try: + url = triton_settings.triton_service_ip + ":" + str(triton_settings.triton_service_port) + nq = NemoQueryLLM(url=url, model_name=request.model) + output = nq.query_llm( + prompts=[request.prompt], + max_output_len=request.max_tokens, + top_k=request.n, + top_p=request.top_p, + temperature=request.temperature, + init_timeout=triton_settings.triton_request_timeout, + ) + return { + "output": output[0][0], + } + except Exception as error: + print("An exception occurred:", error) + return {"error": "An exception occurred"} diff --git a/nemo/deploy/utils.py b/nemo/deploy/utils.py index fe770debe7392..650770e771525 100644 --- a/nemo/deploy/utils.py +++ b/nemo/deploy/utils.py @@ -16,6 +16,7 @@ import numpy as np import torch +from PIL import Image from pytriton.model_config import Tensor @@ -64,6 +65,11 @@ def str_ndarray2list(str_ndarray: np.ndarray) -> typing.List[str]: return str_ndarray.tolist() +def ndarray2img(img_ndarray: np.ndarray) -> typing.List[Image.Image]: + img_list = [Image.fromarray(i) for i in img_ndarray] + return img_list + + def cast_output(data, required_dtype): if isinstance(data, torch.Tensor): data = data.cpu().numpy() diff --git a/nemo/export/__init__.py b/nemo/export/__init__.py index 55712d98852cb..d9155f923f186 100644 --- a/nemo/export/__init__.py +++ b/nemo/export/__init__.py @@ -11,15 +11,3 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - - -import logging - -LOGGER = logging.getLogger("NeMo") - - -use_TensorRTLLM = True -try: - from nemo.export.tensorrt_llm import TensorRTLLM -except Exception as e: - LOGGER.warning("TensorRTLLM could not be imported.") diff --git a/nemo/export/multimodal/__init__.py b/nemo/export/multimodal/__init__.py new file mode 100644 index 0000000000000..d9155f923f186 --- /dev/null +++ b/nemo/export/multimodal/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/nemo/export/multimodal/build.py b/nemo/export/multimodal/build.py new file mode 100644 index 0000000000000..b21e5383b57f7 --- /dev/null +++ b/nemo/export/multimodal/build.py @@ -0,0 +1,300 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import os +import shutil +import tarfile +import tempfile +from time import time + +import tensorrt as trt +import torch +import yaml +from tensorrt_llm.builder import Builder +from transformers import AutoModel + +from nemo.export.tensorrt_llm import TensorRTLLM +from nemo.export.trt_llm.nemo_ckpt_loader.nemo_file import load_nemo_model + +logger = trt.Logger(trt.Logger.INFO) + + +def build_trtllm_engine( + model_dir: str, + visual_checkpoint_path: str, + llm_checkpoint_path: str = None, + model_type: str = "neva", + llm_model_type: str = "llama", + tensor_parallel_size: int = 1, + max_input_len: int = 256, + max_output_len: int = 256, + max_batch_size: int = 1, + max_multimodal_len: int = 1024, + dtype: str = "bfloat16", +): + trt_llm_exporter = TensorRTLLM(model_dir=model_dir, load_model=False) + trt_llm_exporter.export( + nemo_checkpoint_path=visual_checkpoint_path if model_type == "neva" else llm_checkpoint_path, + model_type=llm_model_type, + tensor_parallel_size=tensor_parallel_size, + max_input_len=max_input_len, + max_output_len=max_output_len, + max_batch_size=max_batch_size, + max_prompt_embedding_table_size=max_multimodal_len, + dtype=dtype, + load_model=False, + ) + + +def export_visual_wrapper_onnx( + visual_wrapper, input, output_dir, input_names=['input'], dynamic_axes={'input': {0: 'batch'}} +): + logger.log(trt.Logger.INFO, "Exporting onnx") + os.makedirs(f'{output_dir}/onnx', exist_ok=True) + torch.onnx.export( + visual_wrapper, + input, + f'{output_dir}/onnx/visual_encoder.onnx', + opset_version=17, + input_names=input_names, + output_names=['output'], + dynamic_axes=dynamic_axes, + ) + + +def build_trt_engine( + model_type, input_sizes, output_dir, max_batch_size, dtype=torch.bfloat16, image_size=None, num_frames=None +): + part_name = 'visual_encoder' + onnx_file = '%s/onnx/%s.onnx' % (output_dir, part_name) + engine_file = '%s/%s.engine' % (output_dir, part_name) + config_file = '%s/%s' % (output_dir, "config.json") + logger.log(trt.Logger.INFO, "Building TRT engine for %s" % part_name) + + builder = trt.Builder(logger) + network = builder.create_network(1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)) + profile = builder.create_optimization_profile() + + config_args = {"precision": str(dtype).split('.')[-1], "model_type": model_type} + if image_size is not None: + config_args["image_size"] = image_size + if num_frames is not None: + config_args["num_frames"] = num_frames + + config_wrapper = Builder().create_builder_config(**config_args) + config = config_wrapper.trt_builder_config + + parser = trt.OnnxParser(network, logger) + + with open(onnx_file, 'rb') as model: + if not parser.parse(model.read(), os.path.abspath(onnx_file)): + logger.log(trt.Logger.ERROR, "Failed parsing %s" % onnx_file) + for error in range(parser.num_errors): + logger.log(trt.Logger.ERROR, parser.get_error(error)) + logger.log(trt.Logger.INFO, "Succeeded parsing %s" % onnx_file) + + # Delete onnx files since we don't need them now + shutil.rmtree(f'{output_dir}/onnx') + + nBS = -1 + nMinBS = 1 + nOptBS = max(nMinBS, int(max_batch_size / 2)) + nMaxBS = max_batch_size + + inputT = network.get_input(0) + + # input sizes can be a list of ints (e.g., [3, H, W]) when inputs are images, + # or a list of three int lists (e.g., [[1, 1, 2700], [1, 500, 2700], [1, 4096, 2700]]). + assert isinstance(input_sizes, list), "input_sizes must be a list" + if isinstance(input_sizes[0], int): + logger.log(trt.Logger.INFO, f"Processed input sizes {input_sizes}") + inputT.shape = [nBS, *input_sizes] + min_size = opt_size = max_size = input_sizes + elif len(input_sizes) == 3 and isinstance(input_sizes[0], list): + min_size, opt_size, max_size = input_sizes + logger.log(trt.Logger.INFO, f"Processed min/opt/max input sizes {min_size}/{opt_size}/{max_size}") + else: + raise ValueError(f"invalid input sizes: {input_sizes}") + + profile.set_shape(inputT.name, [nMinBS, *min_size], [nOptBS, *opt_size], [nMaxBS, *max_size]) + config.add_optimization_profile(profile) + + t0 = time() + engine_string = builder.build_serialized_network(network, config) + t1 = time() + if engine_string is None: + raise RuntimeError("Failed building %s" % (engine_file)) + else: + logger.log(trt.Logger.INFO, "Succeeded building %s in %d s" % (engine_file, t1 - t0)) + with open(engine_file, 'wb') as f: + f.write(engine_string) + + Builder.save_config(config_wrapper, config_file) + + +def build_neva_engine( + model_dir: str, + visual_checkpoint_path: str, + max_batch_size: int = 1, +): + device = torch.device("cuda") if torch.cuda.is_available() else "cpu" + # extract NeMo checkpoint + with tempfile.TemporaryDirectory() as temp: + mp0_weights, nemo_config, _ = load_nemo_model(visual_checkpoint_path, temp) + + vision_config = nemo_config["mm_cfg"]["vision_encoder"] + + class VisionEncoderWrapper(torch.nn.Module): + + def __init__(self, encoder, connector): + super().__init__() + self.encoder = encoder + self.connector = connector + + def forward(self, images): + vision_x = self.encoder(pixel_values=images, output_hidden_states=True) + vision_x = vision_x.hidden_states[-2] + vision_x = vision_x[:, 1:] + vision_x = self.connector(vision_x) + return vision_x + + encoder = AutoModel.from_pretrained( + vision_config["from_pretrained"], torch_dtype=torch.bfloat16, trust_remote_code=True + ) + vision_encoder = encoder.vision_model + hf_config = encoder.config + dtype = hf_config.torch_dtype + + # connector + assert nemo_config["mm_cfg"]["mm_mlp_adapter_type"] == "mlp2x_gelu" + vision_connector = torch.nn.Sequential( + torch.nn.Linear(vision_config["hidden_size"], nemo_config["hidden_size"], bias=True), + torch.nn.GELU(), + torch.nn.Linear(nemo_config["hidden_size"], nemo_config["hidden_size"], bias=True), + ).to(dtype=dtype) + + key_prefix = "model.embedding.word_embeddings.adapter_layer.mm_projector_adapter.mm_projector" + for layer in range(0, 3, 2): + vision_connector[layer].load_state_dict( + { + 'weight': mp0_weights[f"{key_prefix}.{layer}.weight"].to(dtype), + 'bias': mp0_weights[f"{key_prefix}.{layer}.bias"].to(dtype), + } + ) + + # export the whole wrapper + wrapper = VisionEncoderWrapper(vision_encoder, vision_connector).to(device, dtype) + image_size = hf_config.vision_config.image_size + dummy_image = torch.empty( + 1, 3, image_size, image_size, dtype=dtype, device=device + ) # dummy image shape [B, C, H, W] + + export_visual_wrapper_onnx(wrapper, dummy_image, model_dir) + build_trt_engine( + "neva", + [3, image_size, image_size], + model_dir, + max_batch_size, + dtype, + image_size=image_size, + ) + + +def build_video_neva_engine( + model_dir: str, + visual_checkpoint_path: str, + max_batch_size: int = 1, +): + device = torch.device("cuda") if torch.cuda.is_available() else "cpu" + # extract NeMo checkpoint + with tarfile.open(visual_checkpoint_path) as tar: + nemo_config = yaml.safe_load(tar.extractfile("./model_config.yaml")) + try: + # trained without TP + mp0_weights = torch.load(tar.extractfile("./model_weights.ckpt"), map_location=device) + except KeyError: + # trained with TP + mp0_weights = torch.load(tar.extractfile("./mp_rank_00/model_weights.ckpt"), map_location=device) + + vision_config = nemo_config["mm_cfg"]["vision_encoder"] + + class VisionEncoderWrapper(torch.nn.Module): + + def __init__(self, encoder, connector): + super().__init__() + self.encoder = encoder + self.connector = connector + + def forward(self, images): + b, num_frames, c, h, w = images.shape + images = images.view(b * num_frames, c, h, w) + vision_x = self.encoder(pixel_values=images, output_hidden_states=True) # [(B num_frames), C, H, W] + vision_x = vision_x.hidden_states[-2] + vision_x = vision_x[:, 1:] + + # reshape back to [B, num_frames, img_size, hidden_size] + vision_x = vision_x.view(b, num_frames, -1, vision_x.shape[-1]) + + vision_x = self.connector(vision_x) + return vision_x + + encoder = AutoModel.from_pretrained( + vision_config["from_pretrained"], torch_dtype=torch.bfloat16, trust_remote_code=True + ) + vision_encoder = encoder.vision_model + hf_config = encoder.config + dtype = hf_config.torch_dtype + + # connector + assert nemo_config["mm_cfg"]["mm_mlp_adapter_type"] == "linear" + vision_connector = torch.nn.Linear(vision_config["hidden_size"], nemo_config["hidden_size"], bias=True) + + key_prefix = "model.embedding.word_embeddings.adapter_layer.mm_projector_adapter.mm_projector" + vision_connector.load_state_dict( + { + 'weight': mp0_weights[f"{key_prefix}.weight"].to(dtype), + 'bias': mp0_weights[f"{key_prefix}.bias"].to(dtype), + } + ) + + # export the whole wrapper + wrapper = VisionEncoderWrapper(vision_encoder, vision_connector).to(device, dtype) + image_size = hf_config.vision_config.image_size + num_frames = nemo_config['data']['num_frames'] + dummy_video = torch.empty(1, num_frames, 3, image_size, image_size, dtype=dtype, device=device) # dummy image + export_visual_wrapper_onnx(wrapper, dummy_video, model_dir) + build_trt_engine( + "video-neva", + [num_frames, 3, image_size, image_size], # [num_frames, 3, H, W] + model_dir, + max_batch_size, + dtype, + image_size=image_size, + num_frames=num_frames, + ) + + +def build_visual_engine( + model_dir: str, + visual_checkpoint_path: str, + model_type: str = "neva", + max_batch_size: int = 1, +): + if model_type == "neva": + build_neva_engine(model_dir, visual_checkpoint_path, max_batch_size) + elif model_type == "video-neva": + build_video_neva_engine(model_dir, visual_checkpoint_path, max_batch_size) + else: + raise RuntimeError(f"Invalid model type {model_type}") diff --git a/nemo/export/multimodal/run.py b/nemo/export/multimodal/run.py new file mode 100644 index 0000000000000..f94c2e3f39447 --- /dev/null +++ b/nemo/export/multimodal/run.py @@ -0,0 +1,483 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import json +import os + +import numpy as np +import tensorrt as trt +import tensorrt_llm +import tensorrt_llm.profiler as profiler +import torch +from PIL import Image +from tensorrt_llm import logger +from tensorrt_llm._utils import str_dtype_to_trt +from tensorrt_llm.runtime import ModelRunner, Session, TensorInfo +from torchvision import transforms +from transformers import CLIPImageProcessor + + +def trt_dtype_to_torch(dtype): + if dtype == trt.float16: + return torch.float16 + elif dtype == trt.float32: + return torch.float32 + elif dtype == trt.int32: + return torch.int32 + elif dtype == trt.bfloat16: + return torch.bfloat16 + else: + raise TypeError("%s is not supported" % dtype) + + +class MultimodalModelRunner: + + def __init__(self, visual_engine_dir, llm_engine_dir): + self.runtime_rank = tensorrt_llm.mpi_rank() + device_id = self.runtime_rank % torch.cuda.device_count() + torch.cuda.set_device(device_id) + self.device = "cuda:%d" % (device_id) + + self.stream = torch.cuda.Stream(torch.cuda.current_device()) + torch.cuda.set_stream(self.stream) + + # parse model type from visual engine config + with open(os.path.join(visual_engine_dir, "config.json"), "r") as f: + config = json.load(f) + self.model_type = config['builder_config']['model_type'] + self.vision_precision = config['builder_config']['precision'] + + self.num_frames = config['builder_config'].get('num_frames', None) + self.image_size = config['builder_config'].get('image_size', None) + + self.profiling_iterations = 20 + + self.init_image_encoder(visual_engine_dir) + self.init_tokenizer(llm_engine_dir) + self.init_llm(llm_engine_dir) + + def init_tokenizer(self, llm_engine_dir): + if os.path.exists(os.path.join(llm_engine_dir, 'huggingface_tokenizer')): + from transformers import AutoTokenizer + + self.tokenizer = AutoTokenizer.from_pretrained(os.path.join(llm_engine_dir, 'huggingface_tokenizer')) + self.tokenizer.pad_token = self.tokenizer.eos_token + else: + from sentencepiece import SentencePieceProcessor + + sp = SentencePieceProcessor(os.path.join(llm_engine_dir, 'tokenizer.model')) + + class return_obj: + + def __init__(self, input_ids): + self.input_ids = input_ids + + def __getitem__(self, name): + if name in "input_ids": + return self.input_ids + else: + raise AttributeError(f"'return_obj' has no item '{name}'") + + # sentencepiece does not follow the same interface as HF + class HFTokenizerInterface: + + def encode(self, x, return_tensors=None, **kwargs): + out = sp.encode(x) + if return_tensors == "pt": + out = torch.tensor(out) + return return_obj(out) + + def __call__(self, x, return_tensors=None, **kwargs): + return self.encode(x, return_tensors, **kwargs) + + def decode(self, x, **kwargs): + return sp.decode(x.tolist()) + + def batch_decode(self, x, **kwargs): + return self.decode(x, **kwargs) + + self.tokenizer = HFTokenizerInterface() + self.tokenizer.eos_token_id = sp.eos_id() + self.tokenizer.bos_token_id = sp.bos_id() + self.tokenizer.pad_token_id = sp.pad_id() + + self.tokenizer.padding_side = "right" + + def init_image_encoder(self, visual_engine_dir): + vision_encoder_path = os.path.join(visual_engine_dir, 'visual_encoder.engine') + logger.info(f'Loading engine from {vision_encoder_path}') + with open(vision_encoder_path, 'rb') as f: + engine_buffer = f.read() + logger.info(f'Creating session from engine {vision_encoder_path}') + self.visual_encoder_session = Session.from_serialized_engine(engine_buffer) + + def init_llm(self, llm_engine_dir): + self.model = ModelRunner.from_dir( + llm_engine_dir, rank=tensorrt_llm.mpi_rank(), debug_mode=False, stream=self.stream + ) + self.model_config = self.model.session._model_config + self.runtime_mapping = self.model.session.mapping + + def video_preprocess(self, video_path): + from decord import VideoReader + + if isinstance(video_path, str): + vr = VideoReader(video_path) + num_frames = self.num_frames + if num_frames == -1: + frames = [Image.fromarray(frame.asnumpy()[:, :, ::-1]).convert('RGB') for frame in vr] + else: + # equally sliced frames into self.num_frames frames + # if self.num_frames is greater than the number of frames in the video, we will repeat the last frame + num_frames = min(num_frames, len(vr)) + indices = np.linspace(0, len(vr) - 1, num=num_frames, dtype=int) + frames = [Image.fromarray(vr[idx].asnumpy()[:, :, ::-1]).convert('RGB') for idx in indices] + if len(frames) < num_frames: + frames += [frames[-1]] * (num_frames - len(frames)) + elif isinstance(video_path, np.ndarray): + num_frames = self.num_frames + if num_frames == -1: + frames = [Image.fromarray(frame[:, :, ::-1]).convert('RGB') for frame in video_path] + else: + # equally sliced frames into self.num_frames frames + # if self.num_frames is greater than the number of frames in the video, we will repeat the last frame + num_frames = min(num_frames, video_path.shape[0]) + indices = np.linspace(0, video_path.shape[0] - 1, num=num_frames, dtype=int) + frames = [Image.fromarray(video_path[idx][:, :, ::-1]).convert('RGB') for idx in indices] + if len(frames) < num_frames: + frames += [frames[-1]] * (num_frames - len(frames)) + else: + frames = self.video_path + + processor = CLIPImageProcessor.from_pretrained("openai/clip-vit-large-patch14", torch_dtype=torch.bfloat16) + frames = processor.preprocess(frames, return_tensors="pt")['pixel_values'] + # make dtype consistent with vision encoder + media_tensors = frames.to( + tensorrt_llm._utils.str_dtype_to_torch(self.vision_precision) + ) # [num_frames, 3, H, W] + return media_tensors.unsqueeze(0) # [1, num_frames, 3, H, W] + + def preprocess(self, warmup, pre_prompt, post_prompt, image, attention_mask, batch_size): + if not warmup: + profiler.start("Vision") + + visual_features, visual_atts = self.get_visual_features(image, attention_mask) + + if not warmup: + profiler.stop("Vision") + + pre_input_ids = self.tokenizer(pre_prompt, return_tensors="pt", padding=True).input_ids + if post_prompt[0] is not None: + post_input_ids = self.tokenizer(post_prompt, return_tensors="pt", padding=True).input_ids + if self.model_type == 'video-neva': + length = pre_input_ids.shape[1] + post_input_ids.shape[1] + visual_atts.shape[2] * visual_atts.shape[1] + else: + length = pre_input_ids.shape[1] + post_input_ids.shape[1] + visual_atts.shape[1] + else: + post_input_ids = None + length = pre_input_ids.shape[1] + visual_atts.shape[1] + + input_lengths = torch.IntTensor([length] * batch_size).to(torch.int32) + + input_ids, ptuning_args = self.setup_fake_prompts( + visual_features, pre_input_ids, post_input_ids, input_lengths + ) + + return input_ids, input_lengths, ptuning_args, visual_features + + def generate( + self, + pre_prompt, + post_prompt, + image, + decoder_input_ids, + max_new_tokens, + attention_mask, + warmup, + batch_size, + top_k, + top_p, + temperature, + repetition_penalty, + num_beams, + ): + if not warmup: + profiler.start("Generate") + + input_ids, input_lengths, ptuning_args, visual_features = self.preprocess( + warmup, pre_prompt, post_prompt, image, attention_mask, batch_size + ) + + if warmup: + return None + + profiler.start("LLM") + end_id = self.tokenizer.eos_token_id + + ptuning_args[0] = torch.stack([ptuning_args[0]]) + output_ids = self.model.generate( + input_ids, + sampling_config=None, + prompt_table=ptuning_args[0], + max_new_tokens=max_new_tokens, + end_id=end_id, + pad_id=( + self.tokenizer.pad_token_id + if self.tokenizer.pad_token_id is not None + else self.tokenizer.all_special_ids[0] + ), + top_k=top_k, + top_p=top_p, + temperature=temperature, + repetition_penalty=repetition_penalty, + num_beams=num_beams, + output_sequence_lengths=False, + return_dict=False, + ) + + profiler.stop("LLM") + + if tensorrt_llm.mpi_rank() == 0: + # Extract a list of tensors of shape beam_width x output_ids. + output_beams_list = [ + self.tokenizer.batch_decode( + output_ids[batch_idx, :, input_lengths[batch_idx] :], skip_special_tokens=True + ) + for batch_idx in range(batch_size) + ] + + stripped_text = [ + [output_beams_list[batch_idx][beam_idx].strip() for beam_idx in range(num_beams)] + for batch_idx in range(batch_size) + ] + profiler.stop("Generate") + return stripped_text + else: + profiler.stop("Generate") + return None + + def get_visual_features(self, image, attention_mask): + visual_features = {'input': image.to(tensorrt_llm._utils.str_dtype_to_torch(self.vision_precision))} + if attention_mask is not None: + visual_features['attention_mask'] = attention_mask + tensor_info = [TensorInfo('input', str_dtype_to_trt(self.vision_precision), image.shape)] + if attention_mask is not None: + tensor_info.append(TensorInfo('attention_mask', trt.DataType.INT32, attention_mask.shape)) + + visual_output_info = self.visual_encoder_session.infer_shapes(tensor_info) + + visual_outputs = { + t.name: torch.empty(tuple(t.shape), dtype=trt_dtype_to_torch(t.dtype), device=image.device) + for t in visual_output_info + } + + ok = self.visual_encoder_session.run(visual_features, visual_outputs, self.stream.cuda_stream) + assert ok, "Runtime execution failed for vision encoder session" + self.stream.synchronize() + + image_embeds = visual_outputs['output'] + image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(image.device) + + return image_embeds, image_atts + + def setup_fake_prompts(self, visual_features, pre_input_ids, post_input_ids, input_lengths): + # Assemble fake prompts which points to image embedding actually + if hasattr(self, 'num_frames') and (visual_features.shape[1] == self.num_frames): + visual_features = visual_features.view(visual_features.shape[0], -1, visual_features.shape[-1]) + + fake_prompt_id = torch.arange( + self.model_config.vocab_size, + self.model_config.vocab_size + visual_features.shape[0] * visual_features.shape[1], + ) + fake_prompt_id = fake_prompt_id.reshape(visual_features.shape[0], visual_features.shape[1]) + + if post_input_ids is not None: + input_ids = [pre_input_ids, fake_prompt_id, post_input_ids] + else: + input_ids = [fake_prompt_id, pre_input_ids] + input_ids = torch.cat(input_ids, dim=1).contiguous().to(torch.int32) + + ptuning_args = self.ptuning_setup(visual_features, input_ids, input_lengths) + + return input_ids, ptuning_args + + def ptuning_setup(self, prompt_table, input_ids, input_lengths): + hidden_size = self.model_config.hidden_size * self.runtime_mapping.tp_size + if prompt_table is not None: + task_vocab_size = torch.tensor( + [prompt_table.shape[1]], + dtype=torch.int32, + ).cuda() + prompt_table = prompt_table.view((prompt_table.shape[0] * prompt_table.shape[1], prompt_table.shape[2])) + + assert prompt_table.shape[1] == hidden_size, "Prompt table dimensions do not match hidden size" + + prompt_table = prompt_table.cuda().to( + dtype=tensorrt_llm._utils.str_dtype_to_torch(self.model_config.dtype) + ) + else: + prompt_table = torch.empty([1, hidden_size]).cuda() + task_vocab_size = torch.zeros([1]).cuda() + + if self.model_config.remove_input_padding: + tasks = torch.zeros([torch.sum(input_lengths)], dtype=torch.int32).cuda() + else: + tasks = torch.zeros(input_ids.shape, dtype=torch.int32).cuda() + + return [prompt_table, tasks, task_vocab_size] + + def setup_inputs(self, input_text, raw_image, batch_size): + attention_mask = None + + if self.model_type == "neva": + image_size = self.image_size + dtype = torch.float32 + transform = transforms.Compose( + [ + transforms.Resize((image_size, image_size), interpolation=transforms.InterpolationMode.BICUBIC), + transforms.ToTensor(), + transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), + ] + ) + image = transform(raw_image).to(dtype).unsqueeze(0) + + if input_text is None: + input_text = "Hi! What is in this image?" + + pre_prompt = "System\n\nUser\n" + post_prompt = f"\n{input_text}\nAssistant\n" + elif self.model_type == "video-neva": + image = self.video_preprocess(raw_image) # shape (1, num_frames, 3, H, W) + + if input_text is None: + input_text = "Hi! What is in this video?" + + # SteerLM prompt template + pre_prompt = """System\nA chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions.\n\nUser""" + post_prompt = ( + f"\n{input_text}\nAssistant\nquality:4,toxicity:0,humor:0,creativity:0,helpfulness:4,correctness:4,coherence:4,complexity:4,verbosity:4\n" + "" + ) + else: + raise RuntimeError(f"Invalid model type {self.model_type}") + + # Repeat inputs to match batch size + pre_prompt = [pre_prompt] * batch_size + post_prompt = [post_prompt] * batch_size + if image.dim() == 5: + image = image.expand(batch_size, -1, -1, -1, -1).contiguous() + else: + image = image.expand(batch_size, -1, -1, -1).contiguous() + image = image.to(self.device) + + # Generate decoder_input_ids for enc-dec models + # Custom prompts can be added as: + # decoder_input_ids = model.tokenizer(decoder_prompt).input_ids + decoder_input_ids = None + + return input_text, pre_prompt, post_prompt, image, decoder_input_ids, attention_mask + + def run( + self, + input_text, + input_image, + max_new_tokens, + batch_size, + top_k, + top_p, + temperature, + repetition_penalty, + num_beams, + run_profiling=False, + check_accuracy=False, + ): + input_text, pre_prompt, post_prompt, processed_image, decoder_input_ids, attention_mask = self.setup_inputs( + input_text, input_image, batch_size + ) + + self.generate( + pre_prompt, + post_prompt, + processed_image, + decoder_input_ids, + max_new_tokens, + attention_mask=attention_mask, + warmup=True, + batch_size=batch_size, + top_k=top_k, + top_p=top_p, + temperature=temperature, + repetition_penalty=repetition_penalty, + num_beams=num_beams, + ) + num_iters = self.profiling_iterations if run_profiling else 1 + for _ in range(num_iters): + output_text = self.generate( + pre_prompt, + post_prompt, + processed_image, + decoder_input_ids, + max_new_tokens, + attention_mask=attention_mask, + warmup=False, + batch_size=batch_size, + top_k=top_k, + top_p=top_p, + temperature=temperature, + repetition_penalty=repetition_penalty, + num_beams=num_beams, + ) + if self.runtime_rank == 0: + self.print_result(input_text, output_text, batch_size, num_beams, run_profiling, check_accuracy) + return output_text + + def print_result(self, input_text, output_text, batch_size, num_beams, run_profiling, check_accuracy): + if not run_profiling and not check_accuracy: + return + logger.info("---------------------------------------------------------") + if self.model_type != 'nougat': + logger.info(f"\n[Q] {input_text}") + logger.info(f"\n[A] {output_text[0]}") + + if num_beams == 1: + output_ids = self.tokenizer(output_text[0][0], add_special_tokens=False)['input_ids'] + logger.info(f"Generated {len(output_ids)} tokens") + + if check_accuracy: + for i in range(batch_size - 1): + if not (output_text[i] == output_text[i + 1]): + logger.info(f"Output {i} and {i + 1} do not match") + assert False + + assert 'robot' in output_text[0][0].lower() + + if run_profiling: + msec_per_batch = lambda name: 1000 * profiler.elapsed_time_in_sec(name) / self.profiling_iterations + logger.info('Latencies per batch (msec)') + logger.info('TRT vision encoder: %.1f' % (msec_per_batch('Vision'))) + logger.info('TRTLLM LLM generate: %.1f' % (msec_per_batch('LLM'))) + logger.info('Multimodal generate: %.1f' % (msec_per_batch('Generate'))) + + logger.info("---------------------------------------------------------") + + def load_test_media(self, input_media): + if self.model_type == "video-neva": + media = input_media + elif self.model_type == "neva": + media = Image.open(input_media).convert('RGB') + else: + raise RuntimeError(f"Invalid model type {self.model_type}") + + return media diff --git a/nemo/export/quantize/quantizer.py b/nemo/export/quantize/quantizer.py index dee1e85345e49..e645ed8971c36 100644 --- a/nemo/export/quantize/quantizer.py +++ b/nemo/export/quantize/quantizer.py @@ -71,7 +71,7 @@ class Quantizer: Available quantization methods are listed in `QUANT_CFG_CHOICES` dictionary above. Please consult Model Optimizer documentation https://nvidia.github.io/TensorRT-Model-Optimizer/ for details. - You can also inspect different choices in examples/nlp/language_modeling/conf/megatron_gpt_quantization.yaml + You can also inspect different choices in examples/nlp/language_modeling/conf/megatron_gpt_ptq.yaml for quantization algorithms and calibration data as well as recommended settings. Quantization algorithm can also be conveniently set to 'null' to perform only weights export step @@ -86,6 +86,7 @@ def __init__(self, quantization_config: Optional[DictConfig], export_config: Opt - decoder_type: str - awq_block_size: int (only for awq algorithms) - sq_alpha: float (only for smooth quant algorithms) + - enable_kv_cache: bool (default: None i.e. auto-detect based on algorithm and decoder_type) Expected keys in `export_config`: - dtype: str/int @@ -116,9 +117,11 @@ def __init__(self, quantization_config: Optional[DictConfig], export_config: Opt # Always turn on FP8 kv cache to save memory footprint. # For int8_sq, we use int8 kv cache. # TODO: Investigate why enabling FP8 kv cache will cause accuracy regressions for Nemotron. - enable_quant_kv_cache = ( - "int8" not in quantization_config.algorithm and quantization_config.decoder_type != "gptnext" - ) + enable_quant_kv_cache = quantization_config.get("enable_kv_cache", None) + if enable_quant_kv_cache is None: + enable_quant_kv_cache = ( + "int8" not in quantization_config.algorithm and quantization_config.decoder_type != "gptnext" + ) logging.info(f'{"Enabled" if enable_quant_kv_cache else "Disabled"} KV cache quantization') quant_cfg["quant_cfg"]["*output_quantizer"] = { "num_bits": 8 if quantization_config.algorithm == "int8_sq" else (4, 3), @@ -229,9 +232,8 @@ def export(self, model: MegatronGPTModel): # Setup model export handling: temporary directory for # '.qnemo' tarball or directly write to export_config.save_path - # TODO [later]: consider a flag like `export_config.compress` - save_qnemo = self.export_config.save_path.endswith(".qnemo") - if save_qnemo: + compress = self.export_config.get("compress", False) + if compress: export_handler = temporary_directory() else: export_handler = nullcontext(enter_result=self.export_config.save_path) @@ -252,6 +254,6 @@ def export(self, model: MegatronGPTModel): ) if dist.get_rank() == 0: save_artifacts(model, export_dir) - if save_qnemo: + if compress: with tarfile.open(self.export_config.save_path, "w:gz") as tar: tar.add(export_dir, arcname="./") diff --git a/nemo/export/trt_llm/nemo_ckpt_loader/sentencepiece_tokenizer.py b/nemo/export/sentencepiece_tokenizer.py similarity index 93% rename from nemo/export/trt_llm/nemo_ckpt_loader/sentencepiece_tokenizer.py rename to nemo/export/sentencepiece_tokenizer.py index 1f86c5887a5e8..e47b1c665af51 100644 --- a/nemo/export/trt_llm/nemo_ckpt_loader/sentencepiece_tokenizer.py +++ b/nemo/export/sentencepiece_tokenizer.py @@ -22,7 +22,7 @@ class SentencePieceTokenizer: """ - Sentencepiecetokenizer https://github.com/google/sentencepiece + SentencePieceTokenizer https://github.com/google/sentencepiece Args: model_path: path to sentence piece tokenizer model. @@ -247,3 +247,21 @@ def vocab(self): for i in range(self.vocab_size - self.original_vocab_size) ] return main_vocab + special_tokens + + ### Below are a few methods that mimic transformers.PreTrainedTokenizer for vLLM + + def convert_ids_to_tokens(self, ids, skip_special_tokens: bool = False): + return self.ids_to_tokens(ids) # TODO: support skip_special_tokens + + def convert_tokens_to_string(self, tokens: List[str]): + return self.tokens_to_text(tokens) + + def __len__(self): + return self.vocab_size + + @property + def is_fast(self): + return True + + def get_added_vocab(self): + return None diff --git a/nemo/export/tensorrt_llm.py b/nemo/export/tensorrt_llm.py index 7cc92f0ca588e..b4299dfd8945a 100644 --- a/nemo/export/tensorrt_llm.py +++ b/nemo/export/tensorrt_llm.py @@ -30,12 +30,19 @@ from nemo.deploy import ITritonDeployable from nemo.export.tarutils import TarPath, unpack_tarball from nemo.export.trt_llm.converter.model_converter import model_to_trtllm_ckpt -from nemo.export.trt_llm.nemo_ckpt_loader.nemo_file import get_tokenzier, is_nemo_file, load_nemo_model +from nemo.export.trt_llm.converter.model_to_trt_llm_ckpt import dist_model_to_trt_llm_ckpt +from nemo.export.trt_llm.converter.utils import init_model_parallel_from_nemo +from nemo.export.trt_llm.nemo_ckpt_loader.nemo_file import ( + build_tokenizer, + get_tokenzier, + is_nemo_file, + load_nemo_model, +) from nemo.export.trt_llm.qnemo import qnemo_to_tensorrt_llm from nemo.export.trt_llm.qnemo.tokenizer_utils import get_nmt_tokenizer from nemo.export.trt_llm.qnemo.utils import is_qnemo_checkpoint from nemo.export.trt_llm.tensorrt_llm_build import build_and_save_engine -from nemo.export.trt_llm.tensorrt_llm_run import generate, generate_streaming, load +from nemo.export.trt_llm.tensorrt_llm_run import generate, generate_streaming, load, load_distributed, refit use_deploy = True try: @@ -68,7 +75,7 @@ class TensorRTLLM(ITritonDeployable): Exports nemo checkpoints to TensorRT-LLM and run fast inference. Example: - from nemo.export import TensorRTLLM + from nemo.export.tensorrt_llm import TensorRTLLM trt_llm_exporter = TensorRTLLM(model_dir="/path/for/model/files") trt_llm_exporter.export( @@ -116,11 +123,11 @@ def __init__( def export( self, nemo_checkpoint_path: str, - model_type: str, + model_type: Optional[str] = None, delete_existing_files: bool = True, - n_gpus: int = 1, - tensor_parallel_size: int = None, - pipeline_parallel_size: int = None, + n_gpus: int = None, + tensor_parallelism_size: int = 1, + pipeline_parallelism_size: int = 1, gpus_per_node: int = None, max_input_len: int = 256, max_output_len: int = 256, @@ -132,6 +139,7 @@ def export( use_embedding_sharing: bool = False, paged_kv_cache: bool = True, remove_input_padding: bool = True, + paged_context_fmha: bool = False, dtype: str = "bfloat16", load_model: bool = True, enable_multi_block_mode: bool = False, @@ -140,18 +148,17 @@ def export( max_lora_rank: int = 64, max_num_tokens: int = None, opt_num_tokens: int = None, - save_nemo_model_config: bool = False, ): """ Exports nemo checkpoints to TensorRT-LLM. Args: nemo_checkpoint_path (str): path for the nemo checkpoint. - model_type (str): type of the model. Currently, "llama", "gptnext", "falcon", and "starcoder" are supported. - delete_existing_files (bool): if Truen, deletes all the files in model_dir. + model_type (str): type of the model (optional for quantized checkpoints). + delete_existing_files (bool): if True, deletes all the files in model_dir. n_gpus (int): number of GPUs to use for inference. - tensor_parallel_size (int): tensor parallelism. - pipeline_parallel_size (int): pipeline parallelism. + tensor_parallelism_size (int): tensor parallelism. + pipeline_parallelism_size (int): pipeline parallelism. gpus_per_node (int): number of gpus per node. max_input_len (int): max input length. max_output_len (int): max output length. @@ -162,6 +169,7 @@ def export( use_parallel_embedding (bool): whether to use parallel embedding feature of TRT-LLM or not use_embedding_sharing (bool): paged_kv_cache (bool): if True, uses kv cache feature of the TensorRT-LLM. + paged_context_fmha (bool): whether to use paged context fmha feature of TRT-LLM or not remove_input_padding (bool): enables removing input padding or not. dtype (str): Floating point type for model weights (Supports BFloat16/Float16). load_model (bool): load TensorRT-LLM model after the export. @@ -171,29 +179,18 @@ def export( max_lora_rank (int): maximum lora rank. max_num_tokens (int): opt_num_tokens (int): - save_nemo_model_config (bool): """ - if model_type not in self.get_supported_models_list: - raise Exception( - "Model {0} is not currently a supported model type. " - "Supported model types are llama, gptnext, falcon, and starcoder.".format(model_type) + if n_gpus is not None: + warnings.warn( + "Parameter n_gpus is deprecated and will be removed in the next release. " + "Please use tensor_parallelism_size and pipeline_parallelism_size parameters instead.", + DeprecationWarning, + stacklevel=2, ) + tensor_parallelism_size = n_gpus - if model_type == "gpt" or model_type == "starcoder": - model_type = "gptnext" - - if model_type == "mixtral": - model_type = "llama" - - if pipeline_parallel_size is None: - tensor_parallel_size = n_gpus - pipeline_parallel_size = 1 - elif tensor_parallel_size is None: - tensor_parallel_size = 1 - pipeline_parallel_size = n_gpus - - gpus_per_node = tensor_parallel_size if gpus_per_node is None else gpus_per_node + gpus_per_node = tensor_parallelism_size if gpus_per_node is None else gpus_per_node if Path(self.model_dir).exists(): if delete_existing_files and len(os.listdir(self.model_dir)) > 0: @@ -251,8 +248,8 @@ def export( max_output_len=max_output_len, max_batch_size=max_batch_size, max_prompt_embedding_table_size=max_prompt_embedding_table_size, - tensor_parallel_size=tensor_parallel_size, - pipeline_parallel_size=pipeline_parallel_size, + tensor_parallel_size=tensor_parallelism_size, + pipeline_parallel_size=pipeline_parallelism_size, use_parallel_embedding=use_parallel_embedding, paged_kv_cache=paged_kv_cache, remove_input_padding=remove_input_padding, @@ -264,6 +261,21 @@ def export( opt_num_tokens=opt_num_tokens, ) else: + if model_type is None: + raise Exception("model_type needs to be specified, got None.") + + if model_type not in self.get_supported_models_list: + raise Exception( + "Model {0} is not currently a supported model type. " + "Supported model types are: {1}.".format(model_type, self.get_supported_models_list) + ) + + if model_type == "gpt" or model_type == "starcoder": + model_type = "gptnext" + + if model_type == "mixtral": + model_type = "llama" + model, model_configs, self.tokenizer = load_nemo_model(nemo_checkpoint_path, nemo_export_dir) weights_dicts, model_configs = model_to_trtllm_ckpt( model=model, @@ -271,8 +283,8 @@ def export( nemo_export_dir=nemo_export_dir, decoder_type=model_type, dtype=dtype, - tensor_parallel_size=tensor_parallel_size, - pipeline_parallel_size=pipeline_parallel_size, + tensor_parallel_size=tensor_parallelism_size, + pipeline_parallel_size=pipeline_parallelism_size, gpus_per_node=gpus_per_node, use_parallel_embedding=use_parallel_embedding, use_embedding_sharing=use_embedding_sharing, @@ -295,6 +307,7 @@ def export( enable_multi_block_mode=enable_multi_block_mode, paged_kv_cache=paged_kv_cache, remove_input_padding=remove_input_padding, + paged_context_fmha=paged_context_fmha, max_num_tokens=max_num_tokens, opt_num_tokens=opt_num_tokens, ) @@ -317,6 +330,80 @@ def export( if load_model: self._load() + def build( + self, + model, + model_config, + model_type, + gpus_per_node, + tokenizer, + max_input_len: int = 1024, + max_output_len: int = 1024, + max_batch_size: int = 4, + use_refit: bool = True, + reshard_model: bool = False, + ): + """ + Convert a model parallel nemo model to TensorRT-LLM. + """ + assert tensorrt_llm.mpi_rank() == torch.distributed.get_rank() + self.use_refit, self.model_type, self.gpus_per_node = use_refit, model_type, gpus_per_node + self.mp_rank, self.dp_rank, self.tp_size, self.pp_size, self.dp_size = init_model_parallel_from_nemo( + reshard_model + ) + self.tokenizer = build_tokenizer(tokenizer) + + if self.dp_size > 1: + self.model_dir = os.path.join(self.model_dir, f"dp_rank{self.dp_rank}") + + weights, model_config = model_to_trtllm_ckpt( + model=model, + nemo_model_config=model_config, + nemo_export_dir=self.model_dir, + decoder_type=model_type, + tensor_parallel_size=self.tp_size, + pipeline_parallel_size=self.pp_size, + gpus_per_node=gpus_per_node, + use_parallel_embedding=True, + use_distributed_convert=True, + model_parallel_rank=self.mp_rank, + vocab_size=self.tokenizer.vocab_size, + ) + + engine = build_and_save_engine( + max_input_len=max_input_len, + max_output_len=max_output_len, + max_batch_size=max_batch_size, + model_config=model_config[0], + model_weights=weights[0], + model_dir=self.model_dir, + model_type=model_type, + custom_all_reduce=False, + use_refit=use_refit, + ) + torch.distributed.barrier() + + cfg_path = Path(os.path.join(self.model_dir, f'config_{torch.distributed.get_rank()}.json')) + with open(cfg_path, "w", encoding="utf-8") as f: + json.dump(engine.config.to_dict(), f, indent=4) + + load_distributed(self.model_dir, self.mp_rank, gpus_per_node) + + def refit(self, model, model_config): + """ + Refits an TensorRT engine using an instantiated nemo model. + This function should only be used after calling build() + """ + weights_dict = dist_model_to_trt_llm_ckpt( + model=model, + nemo_model_config=model_config, + inference_tp_size=self.tp_size, + inference_pp_size=self.pp_size, + tokenizer_vocab_size=self.tokenizer.vocab_size, + ) + load_distributed(self.model_dir, self.mp_rank, self.gpus_per_node) + refit(weights_dict) + def forward( self, input_texts: List[str], diff --git a/nemo/export/tensorrt_mm_exporter.py b/nemo/export/tensorrt_mm_exporter.py new file mode 100644 index 0000000000000..13bc82b393343 --- /dev/null +++ b/nemo/export/tensorrt_mm_exporter.py @@ -0,0 +1,225 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import os +import shutil +from pathlib import Path + +import numpy as np +import wrapt + +from nemo.deploy import ITritonDeployable +from nemo.export.multimodal.build import build_trtllm_engine, build_visual_engine +from nemo.export.multimodal.run import MultimodalModelRunner + +use_deploy = True +try: + from nemo.deploy.utils import cast_output, ndarray2img, str_ndarray2list +except Exception: + use_deploy = False + + +@wrapt.decorator +def noop_decorator(func): + def wrapper(*args, **kwargs): + return func(*args, **kwargs) + + return wrapper + + +use_pytriton = True +batch = noop_decorator +try: + from pytriton.decorators import batch + from pytriton.model_config import Tensor +except Exception: + use_pytriton = False + + +LOGGER = logging.getLogger("NeMo") + + +class TensorRTMMExporter(ITritonDeployable): + """ + Exports nemo checkpoints to TensorRT and run fast inference. + + Example: + from nemo.export import TensorRTMMExporter + + exporter = TensorRTMMExporter(model_dir="/path/for/model/files") + exporter.export( + visual_checkpoint_path="/path/for/nemo/checkpoint", + model_type="neva", + tensor_parallel_size=1, + ) + + output = exporter.forward("Hi! What is in this image?", "/path/for/input_media") + print("output: ", output) + + """ + + def __init__( + self, + model_dir: str, + load_model: bool = True, + ): + self.model_dir = model_dir + self.runner = None + + if load_model: + self._load() + + def export( + self, + visual_checkpoint_path: str, + llm_checkpoint_path: str = None, + model_type: str = "neva", + llm_model_type: str = "llama", + tensor_parallel_size: int = 1, + max_input_len: int = 4096, + max_output_len: int = 256, + max_batch_size: int = 1, + max_multimodal_len: int = 3072, + dtype: str = "bfloat16", + delete_existing_files: bool = True, + load_model: bool = True, + ): + if Path(self.model_dir).exists(): + if delete_existing_files and len(os.listdir(self.model_dir)) > 0: + for files in os.listdir(self.model_dir): + path = os.path.join(self.model_dir, files) + try: + shutil.rmtree(path) + except OSError: + os.remove(path) + + if len(os.listdir(self.model_dir)) > 0: + raise Exception("Couldn't delete all files.") + elif len(os.listdir(self.model_dir)) > 0: + raise Exception("There are files in this folder. Try setting delete_existing_files=True.") + else: + Path(self.model_dir).mkdir(parents=True, exist_ok=True) + + llm_dir = os.path.join(self.model_dir, "llm_engine") + build_trtllm_engine( + model_dir=llm_dir, + visual_checkpoint_path=visual_checkpoint_path, + llm_checkpoint_path=llm_checkpoint_path, + model_type=model_type, + llm_model_type=llm_model_type, + tensor_parallel_size=tensor_parallel_size, + max_input_len=max_input_len, + max_output_len=max_output_len, + max_batch_size=max_batch_size, + max_multimodal_len=max_multimodal_len, + dtype=dtype, + ) + + visual_dir = os.path.join(self.model_dir, "visual_engine") + build_visual_engine(visual_dir, visual_checkpoint_path, model_type, max_batch_size) + + if load_model: + self._load() + + def forward( + self, + input_text: str, + input_media: str, + batch_size: int = 1, + max_output_len: int = 30, + top_k: int = 1, + top_p: float = 0.0, + temperature: float = 1.0, + repetition_penalty: float = 1.0, + num_beams: int = 1, + ): + if self.runner is None: + raise Exception( + "A nemo checkpoint should be exported and " "then it should be loaded first to run inference." + ) + + input_media = self.runner.load_test_media(input_media) + return self.runner.run( + input_text, + input_media, + max_output_len, + batch_size, + top_k, + top_p, + temperature, + repetition_penalty, + num_beams, + ) + + @property + def get_triton_input(self): + inputs = ( + Tensor(name="input_text", shape=(-1,), dtype=bytes), + Tensor(name="input_media", shape=(-1, -1, -1, 3), dtype=np.uint8), + Tensor(name="batch_size", shape=(-1,), dtype=np.int_, optional=True), + Tensor(name="max_output_len", shape=(-1,), dtype=np.int_, optional=True), + Tensor(name="top_k", shape=(-1,), dtype=np.int_, optional=True), + Tensor(name="top_p", shape=(-1,), dtype=np.single, optional=True), + Tensor(name="temperature", shape=(-1,), dtype=np.single, optional=True), + Tensor(name="repetition_penalty", shape=(-1,), dtype=np.single, optional=True), + Tensor(name="num_beams", shape=(-1,), dtype=np.int_, optional=True), + ) + return inputs + + @property + def get_triton_output(self): + outputs = (Tensor(name="outputs", shape=(-1,), dtype=bytes),) + return outputs + + @batch + def triton_infer_fn(self, **inputs: np.ndarray): + try: + if self.runner is None: + raise Exception( + "A nemo checkpoint should be exported and " "then it should be loaded first to run inference." + ) + + infer_input = {"input_text": str_ndarray2list(inputs.pop("input_text")[0])} + if self.runner.model_type == "neva": + infer_input["input_image"] = ndarray2img(inputs.pop("input_media")[0])[0] + elif self.runner.model_type == "video-neva": + infer_input["input_image"] = inputs.pop("input_media")[0] + if "batch_size" in inputs: + infer_input["batch_size"] = inputs.pop("batch_size")[0][0] + if "max_output_len" in inputs: + infer_input["max_new_tokens"] = inputs.pop("max_output_len")[0][0] + if "top_k" in inputs: + infer_input["top_k"] = inputs.pop("top_k")[0][0] + if "top_p" in inputs: + infer_input["top_p"] = inputs.pop("top_p")[0][0] + if "temperature" in inputs: + infer_input["temperature"] = inputs.pop("temperature")[0][0] + if "repetition_penalty" in inputs: + infer_input["repetition_penalty"] = inputs.pop("repetition_penalty")[0][0] + if "num_beams" in inputs: + infer_input["num_beams"] = inputs.pop("num_beams")[0][0] + + output_texts = self.runner.run(**infer_input) + output = cast_output(output_texts, np.bytes_) + except Exception as error: + err_msg = "An error occurred: {0}".format(str(error)) + output = cast_output([err_msg], np.bytes_) + + return {"outputs": output} + + def _load(self): + llm_dir = os.path.join(self.model_dir, "llm_engine") + visual_dir = os.path.join(self.model_dir, "visual_engine") + self.runner = MultimodalModelRunner(visual_dir, llm_dir) diff --git a/nemo/export/trt_llm/converter/model_converter.py b/nemo/export/trt_llm/converter/model_converter.py index da13449160f95..2a78f68337821 100644 --- a/nemo/export/trt_llm/converter/model_converter.py +++ b/nemo/export/trt_llm/converter/model_converter.py @@ -24,7 +24,10 @@ from tensorrt_llm.layers import MoeConfig from tensorrt_llm.models.modeling_utils import PretrainedConfig -from nemo.export.trt_llm.converter.model_to_trt_llm_ckpt import convert_model_to_trt_llm_ckpt +from nemo.export.trt_llm.converter.model_to_trt_llm_ckpt import ( + convert_model_to_trt_llm_ckpt, + dist_model_to_trt_llm_ckpt, +) from nemo.export.trt_llm.converter.utils import DECODER_MODEL_TYPE, split LOGGER = logging.getLogger("NeMo") @@ -75,6 +78,9 @@ def model_to_trtllm_ckpt( gpus_per_node: int = None, use_parallel_embedding: bool = False, use_embedding_sharing: bool = False, + use_distributed_convert: bool = False, + model_parallel_rank: int = None, + vocab_size: int = None, ) -> Tuple[List[Dict], List[PretrainedConfig]]: if nemo_model_config.get("share_embeddings_and_output_weights", False) and not use_embedding_sharing: @@ -83,30 +89,40 @@ def model_to_trtllm_ckpt( ) use_embedding_sharing = True - weights_dict = convert_model_to_trt_llm_ckpt( - model=model, - nemo_model_config=nemo_model_config, - nemo_export_dir=nemo_export_dir, - inference_tp_size=tensor_parallel_size, - processes=1, - storage_type=dtype, - use_parallel_embedding=use_parallel_embedding, - decoder_type=decoder_type, - ) - - world_size = tensor_parallel_size * pipeline_parallel_size - - has_lm_head = "lm_head.weight" in weights_dict - if has_lm_head: - lm_head_weight = weights_dict["lm_head.weight"] + # If the model has been sharded with model parallelism, convert the model in a gpu-distributed manner + if use_distributed_convert: + weights_dict = dist_model_to_trt_llm_ckpt( + model=model, + nemo_model_config=nemo_model_config, + inference_tp_size=tensor_parallel_size, + inference_pp_size=pipeline_parallel_size, + tokenizer_vocab_size=vocab_size, + ) + vocab_size_padded = vocab_size + else: + weights_dict = convert_model_to_trt_llm_ckpt( + model=model, + nemo_model_config=nemo_model_config, + nemo_export_dir=nemo_export_dir, + inference_tp_size=tensor_parallel_size, + processes=1, + storage_type=dtype, + use_parallel_embedding=use_parallel_embedding, + decoder_type=decoder_type, + ) - vocab_size = weights_dict["transformer.vocab_embedding.weight"].shape[0] - vocab_size_padded = pad_vocab_size(vocab_size, tensor_parallel_size) if has_lm_head else vocab_size + has_lm_head = "lm_head.weight" in weights_dict + if has_lm_head: + lm_head_weight = weights_dict["lm_head.weight"] + if vocab_size is None: + vocab_size = weights_dict["transformer.vocab_embedding.weight"].shape[0] + vocab_size_padded = pad_vocab_size(vocab_size, tensor_parallel_size) if has_lm_head else vocab_size - if has_lm_head and vocab_size_padded != vocab_size: - pad_width = vocab_size_padded - vocab_size - lm_head_weight = np.pad(lm_head_weight, ((0, pad_width), (0, 0)), "constant", constant_values=0) + if has_lm_head and vocab_size_padded != vocab_size: + pad_width = vocab_size_padded - vocab_size + lm_head_weight = np.pad(lm_head_weight, ((0, pad_width), (0, 0)), "constant", constant_values=0) + world_size = tensor_parallel_size * pipeline_parallel_size hidden_act = nemo_model_config.get('activation') hidden_act = ( hidden_act.split("-")[-1] if nemo_model_config.get('num_moe_experts', 0) else non_gated_version(hidden_act) @@ -150,7 +166,6 @@ def model_to_trtllm_ckpt( 'tp_size': tensor_parallel_size, 'pp_size': pipeline_parallel_size, } - model_configs = [] weights_dicts = [] num_layers = nemo_model_config.get('num_layers') @@ -162,6 +177,18 @@ def model_to_trtllm_ckpt( if rotary_scaling is not None: config["rotary_scaling"] = {"type": "linear", "factor": float(rotary_scaling)} + if use_distributed_convert: + config["gpus_per_node"] = gpus_per_node + model_configs.append(PretrainedConfig(**config)) + model_configs[0].mapping = tensorrt_llm.Mapping( + world_size=world_size, + rank=model_parallel_rank, + tp_size=tensor_parallel_size, + pp_size=pipeline_parallel_size, + ) + weights_dicts.append(weights_dict) + return weights_dicts, model_configs + pp_key = { "transformer.vocab_embedding.weight", "transformer.position_embedding.weight", diff --git a/nemo/export/trt_llm/converter/model_to_trt_llm_ckpt.py b/nemo/export/trt_llm/converter/model_to_trt_llm_ckpt.py index c29edc87353e7..0345f979b8c27 100644 --- a/nemo/export/trt_llm/converter/model_to_trt_llm_ckpt.py +++ b/nemo/export/trt_llm/converter/model_to_trt_llm_ckpt.py @@ -24,7 +24,8 @@ from tensorrt_llm._utils import pad_vocab_size, str_dtype_to_torch, torch_to_numpy from tqdm import tqdm -from nemo.export.trt_llm.converter.utils import split_and_save_weight +from nemo.collections.nlp.parts.utils_funcs import torch_dtype_from_precision +from nemo.export.trt_llm.converter.utils import save_val, split_and_save_weight, weights_dict LOGGER = logging.getLogger("NeMo") @@ -68,26 +69,29 @@ def get_layer_prefix(layer_names, is_mcore): return model_prefix, transformer_layer_prefix +def rename_key(new_key: str): + if "self_attention" in new_key: + new_key = new_key.replace("self_attention", "attention") + if "attention.linear_qkv.layer_norm_weight" in new_key: + new_key = new_key.replace("attention.linear_qkv.layer_norm_weight", "input_layernorm.weight") + if "attention.linear_qkv.layer_norm_bias" in new_key: + new_key = new_key.replace("attention.linear_qkv.layer_norm_bias", "input_layernorm.bias") + if "mlp.linear_fc1.layer_norm_weight" in new_key: + new_key = new_key.replace("mlp.linear_fc1.layer_norm_weight", "post_attention_layernorm.weight") + if "mlp.linear_fc1.layer_norm_bias" in new_key: + new_key = new_key.replace("mlp.linear_fc1.layer_norm_bias", "post_attention_layernorm.bias") + + return new_key + + def rename_key_dist_ckpt(old_key: str, layer: int): new_key = old_key - if "layers." in old_key: split_key = old_key.split(".") split_key.insert(1, str(layer)) new_key = ".".join(split_key) - if "self_attention" in new_key: - new_key = new_key.replace("self_attention", "attention") - if "attention.linear_qkv.layer_norm_weight" in new_key: - new_key = new_key.replace("attention.linear_qkv.layer_norm_weight", "input_layernorm.weight") - if "attention.linear_qkv.layer_norm_bias" in new_key: - new_key = new_key.replace("attention.linear_qkv.layer_norm_bias", "input_layernorm.bias") - if "mlp.linear_fc1.layer_norm_weight" in new_key: - new_key = new_key.replace("mlp.linear_fc1.layer_norm_weight", "post_attention_layernorm.weight") - if "mlp.linear_fc1.layer_norm_bias" in new_key: - new_key = new_key.replace("mlp.linear_fc1.layer_norm_bias", "post_attention_layernorm.bias") - - return new_key + return rename_key(new_key) @torch.no_grad() @@ -238,6 +242,223 @@ def handle_model_level_weights(model, tp_idx: int, pp_idx: int): return weights_dict +def _get_layer_index(split_key): + for index, key in enumerate(split_key): + if key == "layers": + return index + 1 + raise ValueError(f"Unknown layer name format: {split_key}") + + +def rename_layer_num(param_name, layer_num): + split_key = param_name.split(".") + layer_index = int(_get_layer_index(split_key)) + split_key[layer_index] = str(layer_num) + return ".".join(split_key) + + +def get_layer_num(param_name): + split_key = param_name.split(".") + layer_index = int(_get_layer_index(split_key)) + return int(split_key[layer_index]) + + +@torch.no_grad() +def dist_model_to_trt_llm_ckpt( + model, + nemo_model_config, + inference_tp_size, + inference_pp_size, + tokenizer_vocab_size, +): + from megatron.core import parallel_state + from megatron.core.tensor_parallel.utils import VocabUtility + + tp_rank = parallel_state.get_tensor_model_parallel_rank() + tp_size = parallel_state.get_tensor_model_parallel_world_size() + tp_group = parallel_state.get_tensor_model_parallel_group() + pp_rank = parallel_state.get_pipeline_model_parallel_rank() + pp_first_rank = parallel_state.get_pipeline_model_parallel_first_rank() + pp_last_rank = parallel_state.get_pipeline_model_parallel_last_rank() + pp_size = parallel_state.get_pipeline_model_parallel_world_size() + pp_group = parallel_state.get_pipeline_model_parallel_group() + pp_is_last = parallel_state.is_pipeline_last_stage(ignore_virtual=True) + pp_is_first = parallel_state.is_pipeline_first_stage(ignore_virtual=True) + vp_size = parallel_state.get_virtual_pipeline_model_parallel_world_size() + if not vp_size: + vp_size = 1 + + reshard_model = False + if inference_tp_size != tp_size or inference_pp_size != pp_size: + LOGGER.info("Training/Generation model parallelism resharding enabled") + if inference_pp_size == 1 and pp_size > 1 and inference_tp_size == tp_size: + reshard_model = True + else: + raise NotImplementedError( + f"NeMo currently only supports PP>1 -> PP=1 resharding, other types of resharding will come in future releases." + ) + + num_layers = nemo_model_config["num_layers"] + is_mcore = nemo_model_config.get("mcore_gpt", False) + storage_type = torch_dtype_from_precision(nemo_model_config.precision) + sample_state_dict = model[0].state_dict() if vp_size > 1 else model.state_dict() + prefix, transformer_layer_prefix = get_layer_prefix(sample_state_dict, is_mcore) + assert is_mcore, "Only megatron-core inflight model conversion is supported" + + export_config = { + "apply_layernorm_1p": nemo_model_config.get("normalization", "") == "layernorm1p", + "tp_size": tp_size, + "split_gated_activation": nemo_model_config.get("activation", "gelu") + in ["swiglu", "geglu", "fast-swiglu", "fast-geglu"], + "num_attention_heads": nemo_model_config["num_attention_heads"], + "num_kv_heads": nemo_model_config.get('num_query_groups', nemo_model_config['num_attention_heads']), + "convert_on_device": True, + "use_attention_nemo_shape": True, + "transpose_weights": True, + } + + starmap_config = { + "tp_rank": None, + "saved_dir": None, # unused + "split_factor": 0, + "storage_type": storage_type, + "act_range": None, + "config": export_config, + } + + tl_params = {} + model_level_params = {} + starmap_args = [] + layers_per_pp = num_layers // pp_size + layers_per_chunk = layers_per_pp // vp_size + + if vp_size > 1: # consolidate params across model chunks + for idx, model_chunk in enumerate(model): + for key, val in model_chunk.state_dict().items(): + if torch.is_tensor(val): + if 'layers' in key: + key2 = rename_layer_num(key, get_layer_num(key) + idx * pp_size * layers_per_chunk) + tl_params[key2] = val + else: + model_level_params[key] = val + else: + for key, val in model.state_dict().items(): + if torch.is_tensor(val): + if 'decoder.layers' in key: + tl_params[key] = val + else: + model_level_params[key] = val + + if vp_size > 1 or reshard_model: + # gather layers across pp ranks + gathered_params = {} + for key, val in tl_params.items(): + weight_list = [torch.zeros_like(val) for _ in range(pp_size)] + torch.distributed.all_gather(weight_list, val, group=pp_group) + for idx in range(pp_size): + layer_num = get_layer_num(key) + idx * layers_per_chunk + key2 = rename_layer_num(key, layer_num) + if not reshard_model: # Save only layers of 1 single PP stage + layers_start = layers_per_pp * pp_rank + layers_end = layers_per_pp * (pp_rank + 1) - 1 + if layer_num >= layers_start and layer_num <= layers_end: + key2 = rename_layer_num(key, layer_num % layers_per_pp) + gathered_params[key2] = weight_list[idx] + else: + gathered_params[key2] = weight_list[idx] + tl_params = gathered_params + + # ----------------Convert layer level weights---------------- + layer_params = extract_layers_with_prefix(tl_params, transformer_layer_prefix) + layer_params = {k: v for k, v in layer_params.items() if k.startswith("layers.")} + for key, val in layer_params.items(): + starmap_args.append(starmap_config | {'key': rename_key(key), 'vals': val}) + + def broadcast_item(item, group, src_rank): + item = [item] + torch.distributed.broadcast_object_list(item, src_rank, group=group) + return item[0] + + def try_get_model_level_weight(src_key_or_tensor, pp_src_idx): + have_tensor = False + if torch.distributed.get_rank() == pp_src_idx: + if isinstance(src_key_or_tensor, str): + tensor = model_level_params.get(src_key_or_tensor, None) + have_tensor = torch.is_tensor(tensor) + else: + assert torch.is_tensor(src_key_or_tensor) + tensor = src_key_or_tensor + have_tensor = True + if reshard_model: + have_tensor = broadcast_item(have_tensor, pp_group, pp_src_idx) + if not have_tensor: + return None + + if reshard_model: # Broadcast tensor to all PP groups + if torch.distributed.get_rank() == pp_src_idx: + shape = tensor.shape + else: + shape = [None] + shape = broadcast_item(shape, pp_group, pp_src_idx) + if torch.distributed.get_rank() != pp_src_idx: + tensor = torch.zeros(shape, dtype=storage_type).cuda() + torch.distributed.broadcast(tensor.contiguous(), pp_src_idx, group=pp_group) + return tensor + + # ----------------Convert Final Layernorm---------------- + if pp_is_last or reshard_model: + ln_f = try_get_model_level_weight( + get_layer_name("final_layernorm.weight", transformer_layer_prefix), pp_last_rank + ) + if ln_f is not None: + starmap_args.append(starmap_config | {'key': "final_layernorm.weight", 'vals': ln_f}) + + ln_f_bias = try_get_model_level_weight( + get_layer_name("final_layernorm.bias", transformer_layer_prefix), pp_last_rank + ) + if ln_f_bias is not None: + starmap_args.append(starmap_config | {'key': "final_layernorm.bias", 'vals': ln_f_bias}) + + # ----------------Convert Embeddings---------------- + def get_remove_vocab_padding(tensor_name): + tensor = model_level_params.get(tensor_name, None) + if tensor is None: + return None + + if tp_size > 1: # Gather padded tensor chunks + vocab_size_padded = tensor.shape[0] * tp_size + vocab_start_index, vocab_end_index = VocabUtility.vocab_range_from_global_vocab_size( + vocab_size_padded, tp_rank, tp_size + ) + dim_size = list(tensor.size()) + dim_size[0] = vocab_size_padded + gathered_tensor = torch.zeros(dim_size, dtype=tensor.dtype, device=torch.cuda.current_device()) + gathered_tensor[vocab_start_index:vocab_end_index] = tensor + torch.distributed.all_reduce(gathered_tensor, group=tp_group) + tensor = gathered_tensor + unpadded = tensor[:tokenizer_vocab_size] + if tp_size > 1: # Split gathered tensor for tensor parallel embedding + vocab_start_index, vocab_end_index = VocabUtility.vocab_range_from_global_vocab_size( + tokenizer_vocab_size, tp_rank, tp_size + ) + unpadded = unpadded[vocab_start_index:vocab_end_index] + return unpadded.T # TRTLLM expects (vocab_size, hidden_size) so need extra transpose + + if pp_is_first or reshard_model: + vocab_embed = get_remove_vocab_padding(get_layer_name("word_embedding", prefix)) + vocab_embed = try_get_model_level_weight(vocab_embed, pp_first_rank) + save_val(vocab_embed, dir=None, key='transformer.vocab_embedding.weight', tp_num=None) + + if pp_is_last or reshard_model: + lm_head = get_remove_vocab_padding(get_layer_name("output_layer", prefix)) + lm_head = try_get_model_level_weight(lm_head, pp_last_rank) + save_val(lm_head, dir=None, key='lm_head.weight', tp_num=None) + + for starmap_arg in tqdm(starmap_args, desc="saving weights"): + split_and_save_weight(**starmap_arg) + + return weights_dict + + def create_export_dir(nemo_export_dir): out_dir = Path(nemo_export_dir) if not out_dir.exists(): diff --git a/nemo/export/trt_llm/converter/utils.py b/nemo/export/trt_llm/converter/utils.py index 469d624bdb18b..3768ff4b28448 100644 --- a/nemo/export/trt_llm/converter/utils.py +++ b/nemo/export/trt_llm/converter/utils.py @@ -14,6 +14,7 @@ import numpy as np +import tensorrt_llm import torch from tensorrt_llm._utils import torch_to_numpy @@ -33,11 +34,23 @@ def save_val(val, dir, key, tp_num=None): suffix = "" if tp_num is None else f".{tp_num}.bin" - # Transpose linear layer weights to the correct shape. - if len(val.shape) >= 2: - val = np.ascontiguousarray(np.transpose(val.reshape(val.shape[0], -1), [1, 0])) global weights_dict - weights_dict[f"{key}{suffix}"] = val + + # Transpose linear layer weights to the correct shape. + if torch.is_tensor(val): + val = val.detach().contiguous() + if len(val.shape) >= 2: + val = val.reshape(val.shape[0], -1) + val = torch.transpose(val, 0, 1) + if key not in weights_dict: + weights_dict[f"{key}{suffix}"] = torch.empty( + val.size(), dtype=val.dtype, layout=val.layout, device="cpu", pin_memory=True + ) + weights_dict[f"{key}{suffix}"].copy_(val, non_blocking=True) + else: + if len(val.shape) >= 2: + val = np.ascontiguousarray(np.transpose(val.reshape(val.shape[0], -1), [1, 0])) + weights_dict[f"{key}{suffix}"] = val def save_split(split_vals, dir, key, i, split_factor): @@ -173,6 +186,7 @@ def split_and_save_weight(tp_rank, saved_dir, split_factor, key, vals, storage_t multi_query_mode = config.get("multi_query_mode", False) num_kv_heads = config.get("num_kv_heads", num_attention_heads) size_per_head = config.get("kv_channels", None) + convert_on_device = config.get("convert_on_device", False) save_int8 = int8_outputs == "all" or int8_outputs == "kv_cache_only" @@ -185,10 +199,14 @@ def split_and_save_weight(tp_rank, saved_dir, split_factor, key, vals, storage_t if config.get("transpose_weights", False) and vals[0].ndim == 2: vals = [val.T for val in vals] if "layernorm.weight" in key and config.get("apply_layernorm_1p", False): - vals = [val + 1.0 for val in vals] + vals = [val.float() + 1.0 for val in vals] - if torch.is_tensor(vals[0]): - vals = [torch_to_numpy(val.cpu().to(storage_type)) for val in vals] + vals = [val.to(storage_type) for val in vals] + if convert_on_device: + assert len(vals) == 1 # Should only convert a single device param per call + assert torch.is_tensor(vals[0]) + elif torch.is_tensor(vals[0]): + vals = [torch_to_numpy(val.cpu()) for val in vals] if ( "input_layernorm.weight" in key @@ -227,7 +245,7 @@ def split_and_save_weight(tp_rank, saved_dir, split_factor, key, vals, storage_t key = f'{layer_prefix}.post_layernorm.weight' else: key = f'{layer_prefix}.post_layernorm.bias' - if tp_rank == 0: + if tp_rank == 0 or convert_on_device: save_val(vals[0], saved_dir, key) elif ( @@ -236,14 +254,19 @@ def split_and_save_weight(tp_rank, saved_dir, split_factor, key, vals, storage_t or "attention.linear_proj.weight" in key or "mlp.linear_fc2.weight" in key ): - cat_dim = 0 - val = np.concatenate(vals, axis=cat_dim) - split_vals = np.split(val, split_factor, axis=cat_dim) if "attention.linear_proj.weight" in key or "attention.dense.weight" in key: key = f'{layer_prefix}.attention.dense.weight' elif "mlp.linear_fc2.weight" in key or "mlp.dense_4h_to_h.weight" in key: key = f'{layer_prefix}.mlp.proj.weight' - save_split(split_vals, saved_dir, key, tp_rank, split_factor) + + if convert_on_device: + save_val(vals[0], saved_dir, key) + else: + cat_dim = 0 + val = np.concatenate(vals, axis=cat_dim) + split_vals = np.split(val, split_factor, axis=cat_dim) + save_split(split_vals, saved_dir, key, tp_rank, split_factor) + if act_range is not None and int8_outputs == "all": base_key = key.replace(".weight", "") vals_i8 = generate_int8(val, act_range, multi_query_mode=multi_query_mode) @@ -255,18 +278,26 @@ def split_and_save_weight(tp_rank, saved_dir, split_factor, key, vals, storage_t or "mlp.linear_fc1.weight" in key or "mlp.linear_fc1.bias" in key ): - if split_gated_activation: - splits = [np.split(val, 2, axis=-1) for val in vals] - vals, gates = list(zip(*splits)) - cat_dim = -1 - val = np.concatenate(vals, axis=cat_dim) - split_vals = np.split(val, split_factor, axis=cat_dim) - if key.endswith("weight"): key = f'{layer_prefix}.mlp.fc.weight' else: key = f'{layer_prefix}.mlp.fc.bias' - save_split(split_vals, saved_dir, key, tp_rank, split_factor) + + if split_gated_activation: + if convert_on_device: + vals, gates = [[n] for n in torch.chunk(vals[0], 2, axis=-1)] + else: + splits = [np.split(val, 2, axis=-1) for val in vals] + vals, gates = list(zip(*splits)) + + if convert_on_device: + save_val(vals[0], saved_dir, key) + else: + cat_dim = -1 + val = np.concatenate(vals, axis=cat_dim) + split_vals = np.split(val, split_factor, axis=cat_dim) + save_split(split_vals, saved_dir, key, tp_rank, split_factor) + if act_range is not None and int8_outputs == "all": base_key = key.replace(".weight", "") vals_i8 = generate_int8(val, act_range, multi_query_mode=multi_query_mode) @@ -279,47 +310,61 @@ def split_and_save_weight(tp_rank, saved_dir, split_factor, key, vals, storage_t else: key = f'{layer_prefix}.mlp.gate.bias' - gate = np.concatenate(gates, axis=cat_dim) - split_vals = np.split(gate, split_factor, axis=cat_dim) - save_split(split_vals, saved_dir, key, tp_rank, split_factor) + if convert_on_device: + save_val(gates[0], saved_dir, key) + else: + gate = np.concatenate(gates, axis=cat_dim) + split_vals = np.split(gate, split_factor, axis=cat_dim) + save_split(split_vals, saved_dir, key, tp_rank, split_factor) elif "mlp.dense_h_to_4h_2.weight" in key or "mlp.dense_h_to_4h_2.bias" in key: - cat_dim = -1 - val = np.concatenate(vals, axis=cat_dim) - split_vals = np.split(val, split_factor, axis=cat_dim) - save_split(split_vals, saved_dir, key, tp_rank, split_factor) + if convert_on_device: + save_val(vals[0], saved_dir, key) + else: + cat_dim = -1 + val = np.concatenate(vals, axis=cat_dim) + split_vals = np.split(val, split_factor, axis=cat_dim) + save_split(split_vals, saved_dir, key, tp_rank, split_factor) + if act_range is not None and int8_outputs == "all": base_key = key.replace(".weight", "") vals_i8 = generate_int8(val, act_range, multi_query_mode=multi_query_mode) write_int8(vals_i8, saved_dir, base_key, cat_dim, tp_rank, split_factor) elif "attention.query_key_value.bias" in key or "attention.linear_qkv.bias" in key: + key = f'{layer_prefix}.attention.qkv.bias' qkv_hidden_dim = vals[0].shape[0] size_per_head = qkv_hidden_dim // (num_attention_heads + 2 * num_kv_heads) q_num = num_attention_heads // num_kv_heads # We first concat all sub weights per tp rank together. - len_vals = len(vals) - val = np.concatenate(vals, axis=0) + if convert_on_device: + val = vals[0] + else: + val = np.concatenate(vals, axis=0) val = val.reshape(num_kv_heads * len_vals // tp_size, q_num + 2, size_per_head) # Split the QKV to separate variables. - - qkv = np.split(val, [q_num, q_num + 1], axis=1) - q_split = np.split(qkv[0], split_factor, axis=0) - k_split = np.split(qkv[1], split_factor, axis=0) - v_split = np.split(qkv[2], split_factor, axis=0) - - # Concatenate Q, K, and V together - split_vals = [ - np.concatenate([q_split[i].reshape(-1), k_split[i].reshape(-1), v_split[i].reshape(-1)], axis=0) - for i in range(split_factor) - ] - key = f'{layer_prefix}.attention.qkv.bias' - save_split(split_vals, saved_dir, key, tp_rank, split_factor) + if convert_on_device: + qkv = torch.split(val, [q_num, 1, 1], dim=1) + split_vals = torch.concatenate([qkv[0].reshape(-1), qkv[1].reshape(-1), qkv[2].reshape(-1)], dim=1) + save_val(split_vals, saved_dir, key) + else: + qkv = np.split(val, [q_num, q_num + 1], axis=1) + q_split = np.split(qkv[0], split_factor, axis=0) + k_split = np.split(qkv[1], split_factor, axis=0) + v_split = np.split(qkv[2], split_factor, axis=0) + + # Concatenate Q, K, and V together + split_vals = [ + np.concatenate([q_split[i].reshape(-1), k_split[i].reshape(-1), v_split[i].reshape(-1)], axis=0) + for i in range(split_factor) + ] + save_split(split_vals, saved_dir, key, tp_rank, split_factor) elif "attention.query_key_value.weight" in key or "attention.linear_qkv.weight" in key: + key = f'{layer_prefix}.attention.qkv.weight' assert use_attention_nemo_shape, "Only support NEMO shape for QKV weights" hidden_dim = vals[0].shape[0] if size_per_head is None: @@ -328,35 +373,49 @@ def split_and_save_weight(tp_rank, saved_dir, split_factor, key, vals, storage_t # When the merge factor exceeds 1, the 'vals' list will have multiple entries. # Depending on the format, 'vals' can look like either [QQQQ..KV, QQQQ..KV, ...](for GQA) or [QKV, QKV, ...](for MHA). - # We first concat all sub weights per tp rank together. - len_vals = len(vals) - val = np.concatenate(vals, axis=1) - - val = val.reshape(hidden_dim, num_kv_heads * len_vals // tp_size, q_num + 2, size_per_head) - - # Split the QKV to separate variables. - qkv = np.split(val, [q_num, q_num + 1], axis=2) - - q_split = np.split(qkv[0], split_factor, axis=1) - k_split = np.split(qkv[1], split_factor, axis=1) - v_split = np.split(qkv[2], split_factor, axis=1) - - # Concatenate Q, K, and V together - split_vals = [ - np.concatenate( - [ - q_split[i].reshape(hidden_dim, -1), - k_split[i].reshape(hidden_dim, -1), - v_split[i].reshape(hidden_dim, -1), - ], - axis=1, + if convert_on_device: + val = vals[0].reshape(hidden_dim, num_kv_heads // tp_size, q_num + 2, size_per_head) + qkv = torch.split(val, [q_num, 1, 1], dim=2) + split_vals = torch.concatenate( + [qkv[0].reshape(hidden_dim, -1), qkv[1].reshape(hidden_dim, -1), qkv[2].reshape(hidden_dim, -1)], dim=1 ) - for i in range(split_factor) - ] + save_val(split_vals, saved_dir, key) + else: + len_vals = len(vals) + val = np.concatenate(vals, axis=1) + val = val.reshape(hidden_dim, num_kv_heads * len_vals // tp_size, q_num + 2, size_per_head) + + # Split the QKV to separate variables. + qkv = np.split(val, [q_num, q_num + 1], axis=2) + + query_groups_shape = qkv[0].shape + if len(query_groups_shape) > 1: + if (query_groups_shape[1] % split_factor) != 0: + raise Exception( + "Number of query groups of the models is {0}. Please select tensor parallelism size " + "that can split the number of query groups to equal number of query matrices in the " + "each GPU.".format(query_groups_shape[1]) + ) + + q_split = np.split(qkv[0], split_factor, axis=1) + k_split = np.split(qkv[1], split_factor, axis=1) + v_split = np.split(qkv[2], split_factor, axis=1) + + # Concatenate Q, K, and V together + split_vals = [ + np.concatenate( + [ + q_split[i].reshape(hidden_dim, -1), + k_split[i].reshape(hidden_dim, -1), + v_split[i].reshape(hidden_dim, -1), + ], + axis=1, + ) + for i in range(split_factor) + ] + save_split(split_vals, saved_dir, key, tp_rank, split_factor) - key = f'{layer_prefix}.attention.qkv.weight' - save_split(split_vals, saved_dir, key, tp_rank, split_factor) if save_int8: base_key = key.replace(".weight", "") vals_i8 = generate_int8(val, act_range, is_qkv=True, multi_query_mode=multi_query_mode) @@ -390,14 +449,14 @@ def split_and_save_weight(tp_rank, saved_dir, split_factor, key, vals, storage_t split_w3s = np.split(w3, split_factor, axis=1) split_vals = [np.concatenate(item, axis=1) for item in zip(split_w3s, split_w1s)] - key = f'{layer_prefix}.mlp.experts_weight_1' + key = f'{layer_prefix}.mlp.fc.weight' save_expert_split(split_vals, saved_dir, key, tp_rank, split_factor) elif "experts.linear_fc2.weight" in key: cat_dim = -1 val = np.concatenate(vals, axis=cat_dim) split_vals = np.split(val, split_factor, axis=cat_dim) - key = f'{layer_prefix}.mlp.experts_weight_2' + key = f'{layer_prefix}.mlp.proj.weight' save_expert_split(split_vals, saved_dir, key, tp_rank, split_factor) else: print(f"[WARNING] {key} not handled by converter") @@ -414,3 +473,25 @@ def split(v, tp_size, idx, dim=0): return np.ascontiguousarray(np.split(v, tp_size)[idx]) else: return np.ascontiguousarray(np.split(v, tp_size, axis=dim)[idx]) + + +def init_model_parallel_from_nemo(reshard_model): + from megatron.core import parallel_state + + pp_size = parallel_state.get_pipeline_model_parallel_world_size() + tp_size = parallel_state.get_tensor_model_parallel_world_size() + dp_size = parallel_state.get_data_parallel_world_size() + tp_rank = parallel_state.get_tensor_model_parallel_rank() + pp_rank = parallel_state.get_pipeline_model_parallel_rank() + dp_rank = parallel_state.get_data_parallel_rank() + + if reshard_model and pp_size > 1: + dp_size = dp_size * pp_size + dp_rank = torch.distributed.get_rank() // tp_size + pp_rank = 0 + pp_size = 1 + + mp_rank = tp_size * pp_rank + tp_rank + tensorrt_llm.bindings.MpiComm.split(dp_rank, mp_rank) + + return mp_rank, dp_rank, tp_size, pp_size, dp_size diff --git a/nemo/export/trt_llm/nemo_ckpt_loader/__init__.py b/nemo/export/trt_llm/nemo_ckpt_loader/__init__.py index c9c6f65d27e0d..d9155f923f186 100644 --- a/nemo/export/trt_llm/nemo_ckpt_loader/__init__.py +++ b/nemo/export/trt_llm/nemo_ckpt_loader/__init__.py @@ -11,6 +11,3 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - - -from nemo.export.trt_llm.nemo_ckpt_loader.sentencepiece_tokenizer import SentencePieceTokenizer diff --git a/nemo/export/trt_llm/nemo_ckpt_loader/nemo_file.py b/nemo/export/trt_llm/nemo_ckpt_loader/nemo_file.py index 09eae628999a8..1d473f497f519 100644 --- a/nemo/export/trt_llm/nemo_ckpt_loader/nemo_file.py +++ b/nemo/export/trt_llm/nemo_ckpt_loader/nemo_file.py @@ -28,8 +28,8 @@ from torch.distributed.checkpoint import FileSystemReader from transformers import AutoTokenizer, PreTrainedTokenizer +from nemo.export.sentencepiece_tokenizer import SentencePieceTokenizer from nemo.export.tarutils import TarPath, ZarrPathStore -from nemo.export.trt_llm.nemo_ckpt_loader.sentencepiece_tokenizer import SentencePieceTokenizer LOGGER = logging.getLogger("NeMo") diff --git a/nemo/export/trt_llm/qnemo/tokenizer_utils.py b/nemo/export/trt_llm/qnemo/tokenizer_utils.py index 4b0775a0aa2a6..c3dd5c2befc9e 100644 --- a/nemo/export/trt_llm/qnemo/tokenizer_utils.py +++ b/nemo/export/trt_llm/qnemo/tokenizer_utils.py @@ -17,7 +17,7 @@ from omegaconf import OmegaConf from transformers import AutoTokenizer -from nemo.export.trt_llm.nemo_ckpt_loader.sentencepiece_tokenizer import SentencePieceTokenizer +from nemo.export.sentencepiece_tokenizer import SentencePieceTokenizer # TODO: use get_nmt_tokenizer helper below to instantiate tokenizer once environment / dependencies get stable # from nemo.collections.nlp.modules.common.tokenizer_utils import get_nmt_tokenizer diff --git a/nemo/export/trt_llm/tensorrt_llm_build.py b/nemo/export/trt_llm/tensorrt_llm_build.py index ef9a14c1d582a..b329de2a3b183 100644 --- a/nemo/export/trt_llm/tensorrt_llm_build.py +++ b/nemo/export/trt_llm/tensorrt_llm_build.py @@ -44,6 +44,9 @@ def build_and_save_engine( enable_multi_block_mode: bool = False, paged_kv_cache: bool = True, remove_input_padding: bool = True, + paged_context_fmha: bool = False, + custom_all_reduce: bool = True, + use_refit: bool = False, max_num_tokens: int = None, opt_num_tokens: int = None, max_beam_width: int = 1, @@ -59,12 +62,14 @@ def build_and_save_engine( plugin_config = PluginConfig() plugin_config.set_gpt_attention_plugin(dtype=str_dtype) plugin_config.set_gemm_plugin(dtype=str_dtype) + plugin_config.use_custom_all_reduce = custom_all_reduce plugin_config.set_plugin("multi_block_mode", enable_multi_block_mode) if paged_kv_cache: plugin_config.enable_paged_kv_cache(tokens_per_block=tokens_per_block) else: plugin_config.paged_kv_cache = False plugin_config.remove_input_padding = remove_input_padding + plugin_config.use_paged_context_fmha = paged_context_fmha max_num_tokens, opt_num_tokens = check_max_num_tokens( max_num_tokens=max_num_tokens, @@ -89,6 +94,7 @@ def build_and_save_engine( 'gather_generation_logits': False, 'strongly_typed': False, 'builder_opt': None, + 'use_refit': use_refit, } build_config = BuildConfig.from_dict(build_dict, plugin_config=plugin_config) diff --git a/nemo/export/trt_llm/tensorrt_llm_run.py b/nemo/export/trt_llm/tensorrt_llm_run.py index 8fdd747dcb90f..dbbf40cc3cf13 100644 --- a/nemo/export/trt_llm/tensorrt_llm_run.py +++ b/nemo/export/trt_llm/tensorrt_llm_run.py @@ -26,12 +26,13 @@ import tensorrt_llm import torch from mpi4py.futures import MPIPoolExecutor +from tensorrt_llm.bindings import GptJsonConfig, GptSession, GptSessionConfig, KvCacheConfig, WorldConfig from tensorrt_llm.lora_manager import LoraManager from tensorrt_llm.quantization import QuantMode from tensorrt_llm.runtime import ModelConfig, ModelRunner, ModelRunnerCpp, SamplingConfig +from tensorrt_llm.runtime.model_runner_cpp import ModelRunnerCppGptSession from transformers import PreTrainedTokenizer - LOGGER = logging.getLogger("NeMo") @@ -399,6 +400,77 @@ def forward( raise RuntimeError("Internal error") +def load_distributed(engine_dir, model_parallel_rank, gpus_per_node): + """Loads TRTLLM engines in a distributed gpu environment, in particular + this function creates a custom mapping of device_id to WorldConfig + """ + global tensorrt_llm_worker_context + if isinstance(tensorrt_llm_worker_context.decoder, ModelRunnerCppGptSession): + return + + config_path = Path(engine_dir) / f"config_{torch.distributed.get_rank()}.json" + json_config = GptJsonConfig.parse_file(config_path) + model_config = json_config.model_config + + max_beam_width = model_config.max_beam_width + max_batch_size = model_config.max_batch_size + max_input_len = model_config.max_input_len + max_seq_len = model_config.max_seq_len + + tp_size = json_config.tensor_parallelism + pp_size = json_config.pipeline_parallelism + assert tp_size <= gpus_per_node, "Multinode TP is not unsupported" + + # TRTLLM asserts that rank equals the device num however this + # is not true for the megatron mapping of TP->DP->PP. + # So we manipulate TRTLLM to emulate a TP->PP single node setup + # TRTLLM is expected to fix this in future releases + offset = (torch.cuda.current_device() - model_parallel_rank % gpus_per_node + gpus_per_node) % gpus_per_node + device_ids = [i for i in range(gpus_per_node)] + for _ in range(offset): + device_ids.append(device_ids.pop(0)) + world_config = WorldConfig.mpi( + gpus_per_node=gpus_per_node, tensor_parallelism=tp_size, pipeline_parallelism=pp_size, device_ids=device_ids + ) + engine_filename = json_config.engine_filename(world_config) + serialize_path = Path(engine_dir) / engine_filename + assert torch.cuda.current_device() == world_config.device + + session_config = GptSessionConfig( + max_batch_size=max_batch_size, max_beam_width=max_beam_width, max_sequence_length=max_seq_len + ) + session_config.gen_micro_batch_size = max_batch_size + session_config.ctx_micro_batch_size = max_batch_size + session_config.kv_cache_config = KvCacheConfig( + max_tokens=max_seq_len * max_batch_size, max_attention_window=max_seq_len + ) + + with open(serialize_path, "rb") as f: + engine_data = bytearray(f.read()) + + session = GptSession(session_config, model_config, world_config, engine_data) + decoder = ModelRunnerCppGptSession( + session, + lora_manager=None, + max_batch_size=max_batch_size, + max_input_len=max_input_len, + max_seq_len=max_seq_len, + max_beam_width=max_beam_width, + ) + + tensorrt_llm_worker_context.decoder = decoder + tensorrt_llm_worker_context.max_batch_size = max_batch_size + tensorrt_llm_worker_context.max_input_len = max_input_len + # Save the model config in case for refit + tensorrt_llm_worker_context.model_config = model_config + + +def refit(weights_dict): + global tensorrt_llm_worker_context + dtype = tensorrt_llm_worker_context.model_config.data_type + tensorrt_llm_worker_context.decoder.session.refit_engine(weights_dict, dtype) + + def prepare_input_tensors( input_texts: List[str], host_context: TensorrtLLMHostContext, diff --git a/nemo/export/vllm/__init__.py b/nemo/export/vllm/__init__.py new file mode 100644 index 0000000000000..d9155f923f186 --- /dev/null +++ b/nemo/export/vllm/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/nemo/export/vllm/engine.py b/nemo/export/vllm/engine.py new file mode 100644 index 0000000000000..0ce0e5083916f --- /dev/null +++ b/nemo/export/vllm/engine.py @@ -0,0 +1,103 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +from pathlib import Path + +from vllm import LLMEngine +from vllm.transformers_utils.tokenizer_group.tokenizer_group import TokenizerGroup + +from nemo.export.sentencepiece_tokenizer import SentencePieceTokenizer +from nemo.export.tarutils import TarPath +from nemo.export.vllm.tokenizer_group import NemoTokenizerGroup + +LOGGER = logging.getLogger("NeMo") + + +class NemoLLMEngine(LLMEngine): + """ + Overrides some functionality from vllm.LLMEngine to use our custom tokenizer + instead of one from Transformers. + """ + + def _init_tokenizer(self, **tokenizer_init_kwargs): + # Find the tokenizer file name in the Nemo checkpoint config + tokenizer_config = self.model_config.nemo_model_config.get('tokenizer', {}) + tokenizer_model = tokenizer_config.get('model', tokenizer_config.get('tokenizer_model', None)) + + # If there is no tokenizer file specified but there's a reference to an HF tokenizer, use that + if tokenizer_model is None and tokenizer_config.get('library') == 'huggingface': + tokenizer_type = tokenizer_config.get('type') + if tokenizer_type is not None: + tokenizer_group = TokenizerGroup( + tokenizer_id=tokenizer_type, + enable_lora=bool(self.lora_config), + max_num_seqs=self.scheduler_config.max_num_seqs, + max_input_length=None, + ) + + # Update the HF config fields that come from the tokenizer in NeMo + self.model_config.hf_config.vocab_size = len( + tokenizer_group.tokenizer.vocab + ) # this may be greater than vocab_size + self.model_config.hf_config.bos_token_id = tokenizer_group.tokenizer.bos_token_id + self.model_config.hf_config.eos_token_id = tokenizer_group.tokenizer.eos_token_id + self.model_config.hf_config.pad_token_id = tokenizer_group.tokenizer.pad_token_id + + return tokenizer_group + + # Open the checkpoint archive + with TarPath(self.model_config.nemo_checkpoint) as archive: + tokenizer_model_file = None + if isinstance(tokenizer_model, str) and tokenizer_model.startswith('nemo:'): + tokenizer_model = tokenizer_model[len('nemo:') :] + tokenizer_model_file = archive / tokenizer_model + if not tokenizer_model_file.exists(): + LOGGER.warn( + f'Tokenizer model file {tokenizer_model} specified in the model_config does not ' + + 'exist in the checkpoint.' + ) + tokenizer_model_file = None + + if tokenizer_model_file is None: + for path in archive.glob('*tokenizer*.model'): + LOGGER.info(f'Found tokenizer model file {path}.') + tokenizer_model_file = path + break + + if tokenizer_model_file is None: + raise RuntimeError('No tokenizer model file found, aborting.') + + # Extract the tokenizer model file into the model directory, + # because sentencepiece cannot load it directly from TarPath. + extracted_tokenizer_model = Path(self.model_config.model) / 'tokenizer.model' + with tokenizer_model_file.open('rb') as infile: + with extracted_tokenizer_model.open('wb') as outfile: + outfile.write(infile.read()) + + # Construct the tokenizer object and wrapper + tokenizer = SentencePieceTokenizer(str(extracted_tokenizer_model)) + + # Determine if the model needs a bos token (which is not stored in Nemo checkpoints) + add_bos_token = self.model_config.model_converter.requires_bos_token() + + tokenizer_group = NemoTokenizerGroup(tokenizer, add_bos_token=add_bos_token) + + # Update the HF config fields that come from the tokenizer in NeMo + self.model_config.hf_config.vocab_size = tokenizer.vocab_size + self.model_config.hf_config.bos_token_id = tokenizer.bos_token_id + self.model_config.hf_config.eos_token_id = tokenizer.eos_token_id + self.model_config.hf_config.pad_token_id = tokenizer.pad_id + + return tokenizer_group diff --git a/nemo/export/vllm/model_config.py b/nemo/export/vllm/model_config.py new file mode 100644 index 0000000000000..0a98a9180c1dd --- /dev/null +++ b/nemo/export/vllm/model_config.py @@ -0,0 +1,135 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional, Union + +import torch +import yaml +from transformers import AutoConfig +from vllm.config import ModelConfig, _get_and_verify_dtype, _get_and_verify_max_len +from vllm.transformers_utils.config import get_hf_text_config + +from nemo.export.tarutils import TarPath +from nemo.export.vllm.model_converters import get_model_converter + + +class NemoModelConfig(ModelConfig): + """ + This class pretents to be a vllm.config.ModelConfig (with extra fields) but skips + some of its initialization code, and initializes the configuration from a Nemo checkpoint instead. + """ + + def __init__( + self, + nemo_checkpoint: str, + model_dir: str, + model_type: str, + tokenizer_mode: str, + dtype: Union[str, torch.dtype], + seed: int, + revision: Optional[str] = None, + code_revision: Optional[str] = None, + rope_scaling: Optional[dict] = None, + rope_theta: Optional[float] = None, + tokenizer_revision: Optional[str] = None, + max_model_len: Optional[int] = None, + quantization: Optional[str] = None, + quantization_param_path: Optional[str] = None, + enforce_eager: bool = False, + max_seq_len_to_capture: Optional[int] = None, + max_logprobs: int = 5, + disable_sliding_window: bool = False, + ) -> None: + # Don't call ModelConfig.__init__ because we don't want it to call + # transformers.AutoConfig.from_pretrained(...) + + # TODO: Do something about vLLM's call to _load_generation_config_dict in LLMEngine.__init__ + # because it calls transformers.GenerationConfig.from_pretrained(...), which tries to download things + + self.nemo_checkpoint = nemo_checkpoint + self.model = model_dir + self.model_type = model_type + self.tokenizer = None + self.tokenizer_mode = tokenizer_mode + self.skip_tokenizer_init = False + self.trust_remote_code = False + self.seed = seed + self.revision = revision + self.code_revision = code_revision + self.rope_scaling = rope_scaling + self.rope_theta = rope_theta + self.tokenizer_revision = tokenizer_revision + self.quantization = quantization + self.quantization_param_path = quantization_param_path + self.enforce_eager = enforce_eager + self.max_seq_len_to_capture = max_seq_len_to_capture + self.max_logprobs = max_logprobs + self.disable_sliding_window = disable_sliding_window + self.served_model_name = nemo_checkpoint + + self.model_converter = get_model_converter(model_type) + if self.model_converter is None: + raise RuntimeError(f'Unknown model type "{model_type}"') + + hf_to_nemo_dict = { + 'hidden_size': 'hidden_size', + 'intermediate_size': 'ffn_hidden_size', + 'num_hidden_layers': 'num_layers', + 'num_attention_heads': 'num_attention_heads', + 'num_key_value_heads': 'num_query_groups', + # 'hidden_act': 'activation', ## <- vLLM has good defaults for the models, nemo values are wrong + 'max_position_embeddings': ['max_position_embeddings', 'encoder_seq_length'], + 'rms_norm_eps': 'layernorm_epsilon', + 'attention_dropout': 'attention_dropout', + 'initializer_range': 'init_method_std', + 'norm_epsilon': 'layernorm_epsilon', + 'rope_theta': 'rotary_base', + 'use_bias': 'bias', + } + + with TarPath(nemo_checkpoint) as archive: + with (archive / "model_config.yaml").open("r") as model_config_file: + self.nemo_model_config = yaml.load(model_config_file, Loader=yaml.SafeLoader) + + hf_args = {} + for hf_arg, nemo_arg in hf_to_nemo_dict.items(): + if not isinstance(nemo_arg, list): + nemo_arg = [nemo_arg] + + for nemo_arg_option in nemo_arg: + value = self.nemo_model_config.get(nemo_arg_option) + if value is not None: + hf_args[hf_arg] = value + break + + self.model_converter.convert_config(self.nemo_model_config, hf_args) + + self.hf_config = AutoConfig.for_model(model_type, **hf_args) + + self.hf_config.architectures = [self.model_converter.get_architecture()] + if self.rope_scaling is not None: + self.hf_config['rope_scaling'] = rope_scaling + + self.hf_text_config = get_hf_text_config(self.hf_config) + self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype) + self.max_model_len = _get_and_verify_max_len( + hf_config=self.hf_text_config, + max_model_len=max_model_len, + disable_sliding_window=self.disable_sliding_window, + sliding_window_len=self.get_hf_config_sliding_window(), + ) + self._verify_tokenizer_mode() + self._verify_embedding_mode() + self._verify_quantization() + self._verify_cuda_graph() diff --git a/nemo/export/vllm/model_converters.py b/nemo/export/vllm/model_converters.py new file mode 100644 index 0000000000000..595ceecf0b186 --- /dev/null +++ b/nemo/export/vllm/model_converters.py @@ -0,0 +1,410 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from abc import ABC, abstractmethod +from typing import Optional, Sequence, Tuple + +import torch + + +class ModelConverter(ABC): + """ + Abstract class that defines the interface for a converter that implements model-specific conversion functions + for deploying NeMo checkpoints on vLLM. + """ + + def __init__(self, model_type: str): + self.model_type = model_type + + @abstractmethod + def get_architecture(self) -> Optional[str]: + """ + Returns the HF architecture name for the current model, such as 'LlamaForCausalLM'. + """ + pass + + def convert_config(self, nemo_model_config: dict, hf_config: dict) -> None: + """ + Implements any custom HF configuration adjustments in the 'hf_config' dict that are necessary + for this model after the common translation takes place in NemoModelConfig's constructor. + """ + pass + + @abstractmethod + def convert_weights(self, nemo_model_config: dict, state_dict: dict) -> Sequence[Tuple[str, torch.tensor]]: + """ + Returns or yields a sequence of (name, tensor) tuples that contain model weights in the HF format. + """ + pass + + def requires_bos_token(self) -> bool: + """ + Returns True if the model requires a 'bos' token to be used at the beginning of the input sequence. + NeMo checkpoints do not store this information. + """ + return False + + +class LlamaConverter(ModelConverter): + + def get_architecture(self): + if self.model_type == 'llama': + return 'LlamaForCausalLM' + if self.model_type == 'mistral': + return 'MistralForCausalLM' + return None + + def convert_weights(self, nemo_model_config, state_dict): + hidden_size = nemo_model_config["hidden_size"] + head_num = nemo_model_config["num_attention_heads"] + num_query_groups = nemo_model_config["num_query_groups"] + num_layers = nemo_model_config["num_layers"] + head_size = hidden_size // head_num + heads_per_group = head_num // num_query_groups + qkv_total_dim = head_num + 2 * num_query_groups + + yield ('model.embed_tokens.weight', state_dict['model.embedding.word_embeddings.weight']) + yield ('model.norm.weight', state_dict['model.decoder.final_layernorm.weight']) + yield ('lm_head.weight', state_dict['model.output_layer.weight']) + + for layer in range(int(num_layers)): + qkv_weights = state_dict['model.decoder.layers.self_attention.linear_qkv.weight'][layer] + qkv_weights = qkv_weights.reshape([qkv_total_dim, head_size, hidden_size]) + + q_slice = torch.cat( + [ + torch.arange((heads_per_group + 2) * i, (heads_per_group + 2) * i + heads_per_group) + for i in range(num_query_groups) + ] + ) + k_slice = torch.arange(heads_per_group, qkv_total_dim, (heads_per_group + 2)) + v_slice = torch.arange(heads_per_group + 1, qkv_total_dim, (heads_per_group + 2)) + + for name, slice in [('q_proj', q_slice), ('k_proj', k_slice), ('v_proj', v_slice)]: + weight_name = f'model.layers.{layer}.self_attn.{name}.weight' + yield (weight_name, qkv_weights[slice].reshape(-1, hidden_size)) + + linear_proj_weight = state_dict['model.decoder.layers.self_attention.linear_proj.weight'][layer] + yield (f'model.layers.{layer}.self_attn.o_proj.weight', linear_proj_weight) + + gate_proj_weight, up_proj_weight = torch.chunk( + state_dict['model.decoder.layers.mlp.linear_fc1.weight'][layer], 2, dim=0 + ) + yield (f'model.layers.{layer}.mlp.gate_proj.weight', gate_proj_weight) + yield (f'model.layers.{layer}.mlp.up_proj.weight', up_proj_weight) + + mlp_up_weight = state_dict['model.decoder.layers.mlp.linear_fc2.weight'][layer] + yield (f'model.layers.{layer}.mlp.down_proj.weight', mlp_up_weight) + + input_layernorm_weight = state_dict['model.decoder.layers.self_attention.linear_qkv.layer_norm_weight'][ + layer + ] + yield (f'model.layers.{layer}.input_layernorm.weight', input_layernorm_weight) + + post_attn_layernorm_weight = state_dict['model.decoder.layers.mlp.linear_fc1.layer_norm_weight'][layer] + yield (f'model.layers.{layer}.post_attention_layernorm.weight', post_attn_layernorm_weight) + + def requires_bos_token(self): + return True + + +class MixtralConverter(ModelConverter): + + def get_architecture(self): + if self.model_type == 'mixtral': + return 'MixtralForCausalLM' + return None + + def convert_weights(self, nemo_model_config, state_dict): + hidden_size = nemo_model_config["hidden_size"] + head_num = nemo_model_config["num_attention_heads"] + num_query_groups = nemo_model_config["num_query_groups"] + num_layers = nemo_model_config["num_layers"] + num_moe_experts = nemo_model_config["num_moe_experts"] + head_size = hidden_size // head_num + heads_per_group = head_num // num_query_groups + qkv_total_dim = head_num + 2 * num_query_groups + + yield ('model.embed_tokens.weight', state_dict['model.embedding.word_embeddings.weight']) + yield ('model.norm.weight', state_dict['model.decoder.final_layernorm.weight']) + yield ('lm_head.weight', state_dict['model.output_layer.weight']) + + for layer in range(int(num_layers)): + qkv_weights = state_dict['model.decoder.layers.self_attention.linear_qkv.weight'][layer] + qkv_weights = qkv_weights.reshape([qkv_total_dim, head_size, hidden_size]) + + q_slice = torch.cat( + [ + torch.arange((heads_per_group + 2) * i, (heads_per_group + 2) * i + heads_per_group) + for i in range(num_query_groups) + ] + ) + k_slice = torch.arange(heads_per_group, qkv_total_dim, (heads_per_group + 2)) + v_slice = torch.arange(heads_per_group + 1, qkv_total_dim, (heads_per_group + 2)) + + for name, slice in [('q_proj', q_slice), ('k_proj', k_slice), ('v_proj', v_slice)]: + weight_name = f'model.layers.{layer}.self_attn.{name}.weight' + yield (weight_name, qkv_weights[slice].reshape(-1, hidden_size)) + + linear_proj_weight = state_dict['model.decoder.layers.self_attention.linear_proj.weight'][layer] + yield (f'model.layers.{layer}.self_attn.o_proj.weight', linear_proj_weight) + + mlp_router_weight = state_dict['model.decoder.layers.mlp.router.weight'][layer] + yield (f'model.layers.{layer}.block_sparse_moe.gate.weight', mlp_router_weight) + + for expert in range(num_moe_experts): + linear_fc1_weight = state_dict['model.decoder.layers.mlp.experts.experts.linear_fc1.weight'][layer][ + expert + ] + gate_proj_weight, up_proj_weight = torch.chunk(linear_fc1_weight, 2, dim=0) + yield (f'model.layers.{layer}.block_sparse_moe.experts.{expert}.w1.weight', gate_proj_weight) + yield (f'model.layers.{layer}.block_sparse_moe.experts.{expert}.w3.weight', up_proj_weight) + + linear_fc2_weight = state_dict['model.decoder.layers.mlp.experts.experts.linear_fc2.weight'][layer][ + expert + ] + yield (f'model.layers.{layer}.block_sparse_moe.experts.{expert}.w2.weight', linear_fc2_weight) + + input_layernorm_weight = state_dict['model.decoder.layers.self_attention.linear_qkv.layer_norm_weight'][ + layer + ] + yield (f'model.layers.{layer}.input_layernorm.weight', input_layernorm_weight) + + post_attn_layernorm_weight = state_dict['model.decoder.layers.pre_mlp_layernorm.weight'][layer] + yield (f'model.layers.{layer}.post_attention_layernorm.weight', post_attn_layernorm_weight) + + def requires_bos_token(self): + return True + + +class GemmaConverter(ModelConverter): + + def get_architecture(self): + if self.model_type == 'gemma': + return 'GemmaForCausalLM' + return None + + def convert_weights(self, nemo_model_config, state_dict): + num_layers = nemo_model_config["num_layers"] + num_query_groups = nemo_model_config["num_query_groups"] + head_num = nemo_model_config["num_attention_heads"] + head_size = nemo_model_config["kv_channels"] + hidden_size = nemo_model_config["hidden_size"] + heads_per_group = head_num // num_query_groups + + yield ('model.embed_tokens.weight', state_dict['model.embedding.word_embeddings.weight']) + + final_layernorm_weight = state_dict['model.decoder.final_layernorm.weight'] + final_layernorm_weight -= 1.0 + yield ('model.norm.weight', final_layernorm_weight) + + for layer in range(int(num_layers)): + input_layernorm_weight = state_dict['model.decoder.layers.self_attention.linear_qkv.layer_norm_weight'][ + layer + ] + input_layernorm_weight -= 1.0 + yield (f'model.layers.{layer}.input_layernorm.weight', input_layernorm_weight) + + post_attention_layernorm_weight = state_dict['model.decoder.layers.mlp.linear_fc1.layer_norm_weight'][ + layer + ] + post_attention_layernorm_weight -= 1.0 + yield (f'model.layers.{layer}.post_attention_layernorm.weight', post_attention_layernorm_weight) + + gate_up_combined_weight = state_dict['model.decoder.layers.mlp.linear_fc1.weight'][layer] + gate_size = gate_up_combined_weight.shape[0] // 2 + yield (f'model.layers.{layer}.mlp.gate_proj.weight', gate_up_combined_weight[:gate_size, :]) + yield (f'model.layers.{layer}.mlp.up_proj.weight', gate_up_combined_weight[gate_size:, :]) + + down_proj_weight = state_dict['model.decoder.layers.mlp.linear_fc2.weight'][layer] + yield (f'model.layers.{layer}.mlp.down_proj.weight', down_proj_weight) + + self_attn_o_proj_weight = state_dict['model.decoder.layers.self_attention.linear_proj.weight'][layer] + yield (f'model.layers.{layer}.self_attn.o_proj.weight', self_attn_o_proj_weight) + + qkv_weight = state_dict['model.decoder.layers.self_attention.linear_qkv.weight'][layer] + qkv_intermediate_size = head_num + 2 * num_query_groups + qkv_weight = qkv_weight.reshape(qkv_intermediate_size, head_size, hidden_size) + + q_weight = torch.empty((head_num, head_size, hidden_size), dtype=qkv_weight.dtype) + k_weight = torch.empty((num_query_groups, head_size, hidden_size), dtype=qkv_weight.dtype) + v_weight = torch.empty((num_query_groups, head_size, hidden_size), dtype=qkv_weight.dtype) + + ptr = 0 + for i in range(num_query_groups): + q_weight[i * heads_per_group : (i + 1) * heads_per_group, :, :] = qkv_weight[ + ptr : ptr + heads_per_group, :: + ] + ptr += heads_per_group + k_weight[i : i + 1, :, :] = qkv_weight[ptr : ptr + 1, :, :] + ptr += 1 + v_weight[i : i + 1, :, :] = qkv_weight[ptr : ptr + 1, :, :] + ptr += 1 + assert ptr == qkv_intermediate_size + + q_weight = q_weight.reshape(head_num * head_size, hidden_size) + k_weight = k_weight.reshape(num_query_groups * head_size, hidden_size) + v_weight = v_weight.reshape(num_query_groups * head_size, hidden_size) + + yield (f'model.layers.{layer}.self_attn.q_proj.weight', q_weight) + yield (f'model.layers.{layer}.self_attn.k_proj.weight', k_weight) + yield (f'model.layers.{layer}.self_attn.v_proj.weight', v_weight) + + def requires_bos_token(self): + return True + + +class Starcoder2Converter(ModelConverter): + + def get_architecture(self): + if self.model_type == 'starcoder2': + return 'Starcoder2ForCausalLM' + return None + + def convert_config(self, nemo_model_config, hf_config): + window_sizes = nemo_model_config.get('window_size') + if window_sizes is not None: + hf_config['sliding_window'] = window_sizes[0] + + # 'tie_word_embeddings = False' means that there is a 'lm_head.weight' tensor. + # This converter assumes that it's always there. + # If there is a version of starcoder2 where it's not there, we'll need to copy + # 'model.embed_tokens.weight' into 'lm_head.weight' and still set 'tie_word_embeddings = False' + # because at this point we don't know if the weight is there or not, and this configuration + # is not stored in NeMo checkpoints. + hf_config['tie_word_embeddings'] = False + + def convert_weights(self, nemo_model_config, state_dict): + num_layers = nemo_model_config["num_layers"] + num_query_groups = nemo_model_config["num_query_groups"] + head_num = nemo_model_config["num_attention_heads"] + hidden_size = nemo_model_config["hidden_size"] + head_size = hidden_size // head_num + heads_per_group = head_num // num_query_groups + qkv_total_dim = head_num + 2 * num_query_groups + has_bias = nemo_model_config["bias"] + + yield ('model.embed_tokens.weight', state_dict['model.embedding.word_embeddings.weight']) + + yield ('model.norm.weight', state_dict['model.decoder.final_layernorm.weight']) + if has_bias: + yield ('model.norm.bias', state_dict['model.decoder.final_layernorm.bias']) + + yield ('lm_head.weight', state_dict['model.output_layer.weight']) + + for layer in range(int(num_layers)): + # q,k,v + qkv_weights = state_dict['model.decoder.layers.self_attention.linear_qkv.weight'][layer] + qkv_weights = qkv_weights.reshape([qkv_total_dim, head_size, hidden_size]) + if has_bias: + qkv_bias = state_dict['model.decoder.layers.self_attention.linear_qkv.bias'][layer] + qkv_bias = qkv_bias.reshape([qkv_total_dim, head_size]) + + q_slice = torch.cat( + [ + torch.arange((heads_per_group + 2) * i, (heads_per_group + 2) * i + heads_per_group) + for i in range(num_query_groups) + ] + ) + k_slice = torch.arange(heads_per_group, qkv_total_dim, (heads_per_group + 2)) + v_slice = torch.arange(heads_per_group + 1, qkv_total_dim, (heads_per_group + 2)) + + for name, slice in [('q_proj', q_slice), ('k_proj', k_slice), ('v_proj', v_slice)]: + qkv_weights_slice = qkv_weights[slice].reshape(-1, hidden_size) + yield (f'model.layers.{layer}.self_attn.{name}.weight', qkv_weights_slice) + if has_bias: + qkv_bias_slice = qkv_bias[slice].reshape(-1) + yield (f'model.layers.{layer}.self_attn.{name}.bias', qkv_bias_slice) + + # Attention dense + yield ( + f'model.layers.{layer}.self_attn.o_proj.weight', + state_dict[f'model.decoder.layers.self_attention.linear_proj.weight'][layer], + ) + if has_bias: + yield ( + f'model.layers.{layer}.self_attn.o_proj.bias', + state_dict['model.decoder.layers.self_attention.linear_proj.bias'][layer], + ) + + # MLP FC1 + yield ( + f'model.layers.{layer}.mlp.c_fc.weight', + state_dict['model.decoder.layers.mlp.linear_fc1.weight'][layer], + ) + if has_bias: + yield ( + f'model.layers.{layer}.mlp.c_fc.bias', + state_dict['model.decoder.layers.mlp.linear_fc1.bias'][layer], + ) + + # MLP FC2 + yield ( + f'model.layers.{layer}.mlp.c_proj.weight', + state_dict['model.decoder.layers.mlp.linear_fc2.weight'][layer], + ) + if has_bias: + yield ( + f'model.layers.{layer}.mlp.c_proj.bias', + state_dict['model.decoder.layers.mlp.linear_fc2.bias'][layer], + ) + + # Input LayerNorm + yield ( + f'model.layers.{layer}.input_layernorm.weight', + state_dict['model.decoder.layers.self_attention.linear_qkv.layer_norm_weight'][layer], + ) + if has_bias: + yield ( + f'model.layers.{layer}.input_layernorm.bias', + state_dict['model.decoder.layers.self_attention.linear_qkv.layer_norm_bias'][layer], + ) + + # Post-attention LayerNorm + yield ( + f'model.layers.{layer}.post_attention_layernorm.weight', + state_dict['model.decoder.layers.mlp.linear_fc1.layer_norm_weight'][layer], + ) + if has_bias: + yield ( + f'model.layers.{layer}.post_attention_layernorm.bias', + state_dict['model.decoder.layers.mlp.linear_fc1.layer_norm_bias'][layer], + ) + + +_MODEL_CONVERTERS = { + 'llama': LlamaConverter, + 'mistral': LlamaConverter, + 'mixtral': MixtralConverter, + 'gemma': GemmaConverter, + 'starcoder2': Starcoder2Converter, +} + + +def register_model_converter(model_type, cls): + """ + Establishes a mapping from short model type to a class that converts the model from Nemo format + to a vLLM compatible format. + """ + _MODEL_CONVERTERS[model_type] = cls + + +def get_model_converter(model_type) -> ModelConverter: + """ + Returns an instance of the the model conversion class for the given model type, or None. + """ + cls = _MODEL_CONVERTERS.get(model_type, None) + if cls is None: + return None + return cls(model_type) diff --git a/nemo/export/vllm/model_loader.py b/nemo/export/vllm/model_loader.py new file mode 100644 index 0000000000000..e7f3f1d1569f0 --- /dev/null +++ b/nemo/export/vllm/model_loader.py @@ -0,0 +1,120 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import gc +import logging +import os.path +from typing import Optional + +import numpy +import safetensors.torch +import tensorstore # needed to register 'bfloat16' dtype with numpy for zarr compatibility +import torch +import zarr +from vllm.config import CacheConfig, DeviceConfig, LoRAConfig, ParallelConfig, SchedulerConfig, VisionLanguageConfig +from vllm.model_executor.model_loader.loader import BaseModelLoader, _initialize_model +from vllm.model_executor.model_loader.utils import set_default_torch_dtype + +from nemo.export.tarutils import TarPath, ZarrPathStore +from nemo.export.vllm.model_config import NemoModelConfig + +LOGGER = logging.getLogger("NeMo") + + +class NemoModelLoader(BaseModelLoader): + """ + Implements a custom ModelLoader for vLLM that reads the weights from a Nemo checkpoint + and converts them to a vLLM compatible format at load time. + + Also supports an ahead-of-time conversion that stores new weights in a Safetensors file, + see convert_and_store_nemo_weights(...) + """ + + @staticmethod + def _load_nemo_checkpoint_state(nemo_file: str): + sharded_state_dict = {} + + LOGGER.info(f'Loading weights from {nemo_file}...') + + with TarPath(nemo_file) as archive: + for subdir in archive.iterdir(): + if not subdir.is_dir() or not (subdir / '.zarray').exists(): + continue + key = subdir.name + + zstore = ZarrPathStore(subdir) + arr = zarr.open(zstore, 'r') + + if arr.dtype.name == "bfloat16": + sharded_state_dict[key] = torch.from_numpy(arr[:].view(numpy.int16)).view(torch.bfloat16) + else: + sharded_state_dict[key] = torch.from_numpy(arr[:]) + + arr = None + gc.collect() + + LOGGER.debug(f'Loaded tensor "{key}": {sharded_state_dict[key].shape}') + + return sharded_state_dict + + def load_model( + self, + *, + model_config: NemoModelConfig, + device_config: DeviceConfig, + lora_config: Optional[LoRAConfig], + vision_language_config: Optional[VisionLanguageConfig], + parallel_config: ParallelConfig, + scheduler_config: SchedulerConfig, + cache_config: CacheConfig, + ) -> torch.nn.Module: + """ + Overrides the load_model function from BaseModelLoader to convert Nemo weights at load time. + """ + + assert isinstance(model_config, NemoModelConfig) + state_dict = NemoModelLoader._load_nemo_checkpoint_state(model_config.nemo_checkpoint) + + with set_default_torch_dtype(model_config.dtype): + with torch.device(device_config.device): + model = _initialize_model( + model_config, self.load_config, lora_config, vision_language_config, cache_config + ) + + weights_iterator = model_config.model_converter.convert_weights(model_config.nemo_model_config, state_dict) + + model.load_weights(weights_iterator) + + return model.eval() + + @staticmethod + def convert_and_store_nemo_weights(model_config: NemoModelConfig, safetensors_file: str): + """ + Converts Nemo weights and stores the converted weights in a Safetensors file. + """ + + assert isinstance(model_config, NemoModelConfig) + assert os.path.exists(model_config.model) + + state_dict = NemoModelLoader._load_nemo_checkpoint_state(model_config.nemo_checkpoint) + + tensors = { + name: tensor + for name, tensor in model_config.model_converter.convert_weights( + model_config.nemo_model_config, state_dict + ) + } + + LOGGER.info(f'Saving weights to {safetensors_file}...') + safetensors.torch.save_file(tensors, safetensors_file) diff --git a/nemo/export/vllm/tokenizer_group.py b/nemo/export/vllm/tokenizer_group.py new file mode 100644 index 0000000000000..6e4aedc14acbe --- /dev/null +++ b/nemo/export/vllm/tokenizer_group.py @@ -0,0 +1,55 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import List, Optional + +from vllm.lora.request import LoRARequest +from vllm.transformers_utils.tokenizer_group.base_tokenizer_group import BaseTokenizerGroup + +from nemo.export.sentencepiece_tokenizer import SentencePieceTokenizer + + +class NemoTokenizerGroup(BaseTokenizerGroup): + """ + Implements a custom tokenizer for vLLM, based on SentencePieceTokenizer. + """ + + def __init__(self, tokenizer: SentencePieceTokenizer, add_bos_token: bool = False): + self.tokenizer = tokenizer + self.add_bos_token = add_bos_token + + def ping(self) -> bool: + return True + + def get_max_input_len(self, lora_request: Optional[LoRARequest] = None) -> Optional[int]: + return None + + def encode( + self, prompt: str, request_id: Optional[str] = None, lora_request: Optional[LoRARequest] = None + ) -> List[int]: + ids = self.tokenizer.encode(prompt) + if self.add_bos_token: + ids = [self.tokenizer.bos_token_id] + ids + return ids + + async def encode_async( + self, prompt: str, request_id: Optional[str] = None, lora_request: Optional[LoRARequest] = None + ) -> List[int]: + return self.tokenizer.encode(prompt) # TODO: not sure how this is supposed to work + + def get_lora_tokenizer(self, lora_request: Optional[LoRARequest] = None) -> SentencePieceTokenizer: + return self.tokenizer + + async def get_lora_tokenizer_async(self, lora_request: Optional[LoRARequest] = None) -> SentencePieceTokenizer: + return self.tokenizer diff --git a/nemo/export/vllm_exporter.py b/nemo/export/vllm_exporter.py new file mode 100644 index 0000000000000..f3dd6c8a248b5 --- /dev/null +++ b/nemo/export/vllm_exporter.py @@ -0,0 +1,417 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import os.path +from typing import Iterable, List, Optional, Union + +import numpy +import wrapt +from vllm import RequestOutput, SamplingParams +from vllm.config import CacheConfig, DeviceConfig, LoadConfig, LoadFormat, ParallelConfig, SchedulerConfig +from vllm.executor.ray_utils import initialize_ray_cluster + +from nemo.deploy import ITritonDeployable +from nemo.deploy.utils import cast_output +from nemo.export.vllm.engine import NemoLLMEngine +from nemo.export.vllm.model_config import NemoModelConfig +from nemo.export.vllm.model_loader import NemoModelLoader + +LOGGER = logging.getLogger("NeMo") + + +@wrapt.decorator +def noop_decorator(func): + def wrapper(*args, **kwargs): + return func(*args, **kwargs) + + return wrapper + + +use_pytriton = True +try: + from pytriton.decorators import batch + from pytriton.model_config import Tensor +except Exception: + use_pytriton = False + + +class vLLMExporter(ITritonDeployable): + """ + The Exporter class implements conversion from a Nemo checkpoint format to something compatible with vLLM, + loading the model in vLLM, and binding that model to a Triton server. + + Example: + from nemo.export.vllm import Exporter + from nemo.deploy import DeployPyTriton + + exporter = Exporter() + exporter.export( + nemo_checkpoint='/path/to/checkpoint.nemo', + model_dir='/path/to/temp_dir', + model_type='llama') + + server = DeployPyTriton( + model=exporter, + triton_model_name='LLAMA') + + server.deploy() + server.serve() + server.stop() + """ + + def __init__(self): + self.request_id = 0 + + def export( + self, + nemo_checkpoint: str, + model_dir: str, + model_type: str, + device: str = 'auto', + tensor_parallel_size: int = 1, + pipeline_parallel_size: int = 1, + max_model_len: int = None, + dtype: str = 'auto', + seed: int = 0, + log_stats: bool = True, + weight_storage: str = 'auto', + gpu_memory_utilization: float = 0.9, + ): + """ + Exports the Nemo checkpoint to vLLM and initializes the engine. + + Args: + nemo_checkpoint (str): path to the nemo checkpoint. + model_dir (str): path to a temporary directory to store weights and the tokenizer model. + The temp dir may persist between subsequent export operations, in which case + converted weights may be reused to speed up the export. + model_type (str): type of the model, such as "llama", "mistral", "mixtral". + Needs to be compatible with transformers.AutoConfig. + device (str): type of the device to use by the vLLM engine. + Supported values are "auto", "cuda", "cpu", "neuron". + tensor_parallel_size (int): tensor parallelism. + pipeline_parallel_size (int): pipeline parallelism. + Values over 1 are not currently supported by vLLM. + max_model_len (int): model context length. + dtype (str): data type for model weights and activations. + Possible choices: auto, half, float16, bfloat16, float, float32 + "auto" will use FP16 precision for FP32 and FP16 models, + and BF16 precision for BF16 models. + seed (int): random seed value. + log_stats (bool): enables logging inference performance statistics by vLLM. + weight_storage (str): controls how converted weights are stored: + "file" - always write weights into a file inside 'model_dir', + "memory" - always do an in-memory conversion, + "cache" - reuse existing files if they are newer than the nemo checkpoint, + "auto" - use "cache" for multi-GPU runs and "memory" for single-GPU runs. + gpu_memory_utilization (float): The fraction of GPU memory to be used for the model + executor, which can range from 0 to 1. + """ + + # Pouplate the basic configuration structures + device_config = DeviceConfig(device) + + model_config = NemoModelConfig( + nemo_checkpoint, + model_dir, + model_type, + tokenizer_mode='auto', + dtype=dtype, + seed=seed, + revision=None, + code_revision=None, + tokenizer_revision=None, + max_model_len=max_model_len, + quantization=None, # TODO ??? + quantization_param_path=None, + enforce_eager=False, + max_seq_len_to_capture=None, + ) + + parallel_config = ParallelConfig( + pipeline_parallel_size=pipeline_parallel_size, tensor_parallel_size=tensor_parallel_size + ) + + # See if we have an up-to-date safetensors file + safetensors_file = os.path.join(model_config.model, 'model.safetensors') + safetensors_file_valid = os.path.exists(safetensors_file) and os.path.getmtime( + safetensors_file + ) > os.path.getmtime(nemo_checkpoint) + + # Decide how we're going to convert the weights + if weight_storage == 'auto': + if parallel_config.distributed_executor_backend is not None: + save_weights = not safetensors_file_valid + inmemory_weight_conversion = False + else: + save_weights = False + inmemory_weight_conversion = True + + elif weight_storage == 'cache': + save_weights = not safetensors_file_valid + inmemory_weight_conversion = False + + elif weight_storage == 'file': + save_weights = True + inmemory_weight_conversion = False + + elif weight_storage == 'memory': + save_weights = False + inmemory_weight_conversion = True + + else: + raise ValueError(f'Unsupported value for weight_storage: "{weight_storage}"') + + # Convert the weights ahead-of-time, if needed + if save_weights: + NemoModelLoader.convert_and_store_nemo_weights(model_config, safetensors_file) + elif not inmemory_weight_conversion: + LOGGER.info(f'Using cached weights in {safetensors_file}') + + # TODO: these values are the defaults from vllm.EngineArgs. + cache_config = CacheConfig( + block_size=16, + gpu_memory_utilization=gpu_memory_utilization, + swap_space=4, + cache_dtype='auto', + sliding_window=model_config.get_sliding_window(), + ) + + # TODO: these values are the defaults from vllm.EngineArgs. + scheduler_config = SchedulerConfig( + max_num_batched_tokens=None, + max_num_seqs=256, + # Note: max_model_len can be derived by model_config if the input value is None + max_model_len=model_config.max_model_len, + use_v2_block_manager=False, + num_lookahead_slots=0, + delay_factor=0.0, + enable_chunked_prefill=False, + ) + + load_config = LoadConfig( + load_format=NemoModelLoader if inmemory_weight_conversion else LoadFormat.SAFETENSORS, + download_dir=None, + model_loader_extra_config=None, + ) + + # Initialize the cluster and specify the executor class. + if device_config.device_type == "neuron": + from vllm.executor.neuron_executor import NeuronExecutor + + executor_class = NeuronExecutor + elif device_config.device_type == "cpu": + from vllm.executor.cpu_executor import CPUExecutor + + executor_class = CPUExecutor + elif parallel_config.distributed_executor_backend == "ray": + initialize_ray_cluster(parallel_config) + from vllm.executor.ray_gpu_executor import RayGPUExecutor + + executor_class = RayGPUExecutor + elif parallel_config.distributed_executor_backend == "mp": + from vllm.executor.multiproc_gpu_executor import MultiprocessingGPUExecutor + + executor_class = MultiprocessingGPUExecutor + else: + assert parallel_config.world_size == 1, "Ray is required if parallel_config.world_size > 1." + from vllm.executor.gpu_executor import GPUExecutor + + executor_class = GPUExecutor + + # Initialize the engine + self.engine = NemoLLMEngine( + model_config=model_config, + cache_config=cache_config, + parallel_config=parallel_config, + scheduler_config=scheduler_config, + device_config=device_config, + load_config=load_config, + lora_config=None, + vision_language_config=None, + speculative_config=None, + decoding_config=None, + executor_class=executor_class, + log_stats=log_stats, + ) + + def _add_request_to_engine( + self, prompt: str, max_output_len: int, temperature: float = 1.0, top_k: int = 1, top_p: float = 0.0 + ) -> str: + if top_p <= 0.0: + top_p = 1.0 + + sampling_params = SamplingParams(max_tokens=max_output_len, temperature=temperature, top_k=top_k, top_p=top_p) + + request_id = str(self.request_id) + self.request_id += 1 + + self.engine.add_request(request_id, prompt, sampling_params) + + return request_id + + def _forward_regular(self, request_ids: List[str]): + responses = [None] * len(request_ids) + finished = [False] * len(request_ids) + + while not all(finished): + request_outputs: List[RequestOutput] = self.engine.step() + + for request_output in request_outputs: + if not request_output.finished: + continue + + try: + request_index = request_ids.index(request_output.request_id) + except ValueError: + continue + + finished[request_index] = request_output.finished + output_text = request_output.outputs[-1].text + responses[request_index] = output_text + + return [[response] for response in responses] + + def _forward_streaming(self, request_ids: List[str]): + responses = [None] * len(request_ids) + finished = [False] * len(request_ids) + + while not all(finished): + request_outputs: List[RequestOutput] = self.engine.step() + + for request_output in request_outputs: + try: + request_index = request_ids.index(request_output.request_id) + except ValueError: + continue + + finished[request_index] = request_output.finished + output_text = request_output.outputs[-1].text + responses[request_index] = output_text + + yield [[response] for response in responses] + + def _add_triton_request_to_engine(self, inputs: numpy.ndarray, index: int) -> str: + return self._add_request_to_engine( + prompt=inputs['prompts'][index][0].decode('UTF-8'), + max_output_len=inputs['max_output_len'][index][0], + temperature=inputs['temperature'][index][0], + top_k=inputs['top_k'][index][0], + top_p=inputs['top_p'][index][0], + ) + + @property + def get_triton_input(self): + inputs = ( + Tensor(name="prompts", shape=(-1,), dtype=bytes), + Tensor(name="max_output_len", shape=(-1,), dtype=numpy.int_, optional=True), + Tensor(name="top_k", shape=(-1,), dtype=numpy.int_, optional=True), + Tensor(name="top_p", shape=(-1,), dtype=numpy.single, optional=True), + Tensor(name="temperature", shape=(-1,), dtype=numpy.single, optional=True), + ) + return inputs + + @property + def get_triton_output(self): + outputs = (Tensor(name="outputs", shape=(-1,), dtype=bytes),) + return outputs + + @batch + def triton_infer_fn(self, **inputs: numpy.ndarray): + request_ids = [] + num_requests = len(inputs["prompts"]) + for index in range(num_requests): + request_id = self._add_triton_request_to_engine(inputs, index) + request_ids.append(request_id) + + responses = self._forward_regular(request_ids) + responses = [r[0] for r in responses] + + output_tensor = cast_output(responses, numpy.bytes_) + return {'outputs': output_tensor} + + @batch + def triton_infer_fn_streaming(self, **inputs: numpy.ndarray): + request_ids = [] + num_requests = len(inputs["prompts"]) + for index in range(num_requests): + request_id = self._add_triton_request_to_engine(inputs, index) + request_ids.append(request_id) + + for responses in self._forward_streaming(request_ids): + responses = [r[0] for r in responses] + output_tensor = cast_output(responses, numpy.bytes_) + yield {'outputs': output_tensor} + + # Mimic the TensorRTLLM exporter's forward function, even though we don't support many of its features. + def forward( + self, + input_texts: List[str], + max_output_len: int = 64, + top_k: int = 1, + top_p: float = 0.0, + temperature: float = 1.0, + stop_words_list: Optional[List[str]] = None, + bad_words_list: Optional[List[str]] = None, + no_repeat_ngram_size: Optional[int] = None, + task_ids: Optional[List[str]] = None, + lora_uids: Optional[List[str]] = None, + prompt_embeddings_table=None, + prompt_embeddings_checkpoint_path: Optional[str] = None, + streaming: bool = False, + output_log_probs: bool = False, + ) -> Union[List[List[str]], Iterable[List[List[str]]]]: + """ + The forward function performs LLM evaluation on the provided array of prompts with other parameters shared, + and returns the generated texts. If 'streaming' is True, the output texts are returned incrementally + with a generator: one token appended to each output at a time. If 'streaming' is false, the final output texts + are returned as a single list of responses. + """ + + if stop_words_list is not None and stop_words_list != []: + raise NotImplementedError("stop_words_list is not supported") + + if bad_words_list is not None and bad_words_list != []: + raise NotImplementedError("bad_words_list is not supported") + + if no_repeat_ngram_size is not None: + raise NotImplementedError("no_repeat_ngram_size is not supported") + + if task_ids is not None and task_ids != []: + raise NotImplementedError("task_ids is not supported") + + if lora_uids is not None and lora_uids != []: + raise NotImplementedError("lora_uids is not supported") + + if prompt_embeddings_table is not None: + raise NotImplementedError("prompt_embeddings_table is not supported") + + if prompt_embeddings_checkpoint_path is not None: + raise NotImplementedError("prompt_embeddings_checkpoint_path is not supported") + + if output_log_probs: + raise NotImplementedError("output_log_probs is not supported") + + request_ids = [] + for prompt in input_texts: + request_id = self._add_request_to_engine( + prompt=prompt, max_output_len=max_output_len, temperature=temperature, top_k=top_k, top_p=top_p + ) + request_ids.append(request_id) + + if streaming: + return self._forward_streaming(request_ids) + else: + return self._forward_regular(request_ids) diff --git a/nemo/lightning/__init__.py b/nemo/lightning/__init__.py index 0c5379fb6e82a..e9674ed1e2128 100644 --- a/nemo/lightning/__init__.py +++ b/nemo/lightning/__init__.py @@ -10,9 +10,12 @@ pass from nemo.lightning.base import get_vocab_size, teardown +from nemo.lightning.fabric.fabric import Fabric +from nemo.lightning.fabric.plugins import FabricMegatronMixedPrecision +from nemo.lightning.fabric.strategies import FabricMegatronStrategy from nemo.lightning.nemo_logger import NeMoLogger -from nemo.lightning.pytorch.callbacks.megatron_model_checkpoint import ModelCheckpoint -from nemo.lightning.pytorch.opt import LRSchedulerModule, MegatronOptimizerModule, OptimizerModule +from nemo.lightning.pytorch.callbacks.model_checkpoint import ModelCheckpoint +from nemo.lightning.pytorch.optim import LRSchedulerModule, MegatronOptimizerModule, OptimizerModule, lr_scheduler from nemo.lightning.pytorch.plugins import MegatronDataSampler, MegatronMixedPrecision from nemo.lightning.pytorch.plugins import data_sampler as _data_sampler from nemo.lightning.pytorch.strategies import MegatronStrategy @@ -34,11 +37,15 @@ def _is_slurm_interactive_mode(): __all__ = [ "AutoResume", + "Fabric", + "FabricMegatronMixedPrecision", + "FabricMegatronStrategy", "LRSchedulerModule", "MegatronStrategy", "MegatronDataSampler", "MegatronMixedPrecision", "MegatronOptimizerModule", + "lr_scheduler", "NeMoLogger", "ModelCheckpoint", "OptimizerModule", diff --git a/nemo/lightning/_strategy_lib.py b/nemo/lightning/_strategy_lib.py index 9dd36ba54dbee..e6452de165128 100644 --- a/nemo/lightning/_strategy_lib.py +++ b/nemo/lightning/_strategy_lib.py @@ -2,7 +2,7 @@ import os from collections import defaultdict from contextlib import contextmanager -from typing import TYPE_CHECKING, Any, Dict, Generator, Optional, Protocol, TypeVar +from typing import TYPE_CHECKING, Any, Dict, Generator, Mapping, Optional, Protocol, TypeVar import torch from torch import nn @@ -119,6 +119,29 @@ def init_model_parallel(model: Optional[nn.Module] = None) -> None: child.set_tensor_parallel_group(tp_group) +def set_model_parallel_attributes(model, parallelism): + # Right now mcore sub-classes ModelParellelConfig, we should remove that + # Given Lightning's structure it would be better if parallelism is a different object + # Since then it can be passed to the Strategy + + from megatron.core.transformer.transformer_config import TransformerConfig + + has_mcore_config = isinstance(getattr(model, "config", None), TransformerConfig) + if has_mcore_config and hasattr(model, "configure_model"): + config: TransformerConfig = model.config + config.tensor_model_parallel_size = parallelism.tensor_model_parallel_size + config.pipeline_model_parallel_size = parallelism.pipeline_model_parallel_size + config.virtual_pipeline_model_parallel_size = parallelism.virtual_pipeline_model_parallel_size + config.context_parallel_size = parallelism.context_parallel_size + config.expert_model_parallel_size = parallelism.expert_model_parallel_size + config.moe_extended_tp = parallelism.moe_extended_tp + config.sequence_parallel = parallelism.sequence_parallel + + return config + + return None + + @contextmanager def megatron_lazy_init_context(config) -> Generator[None, None, None]: def monkey_patched(c): @@ -375,7 +398,12 @@ def enable_nvidia_optimizations() -> None: pass -def optimizer_sharded_state_dict(model: SharedStateDictProtocol, optimizer: "Optimizable") -> Dict[str, torch.Tensor]: +def optimizer_sharded_state_dict( + model: SharedStateDictProtocol, + optimizer: "Optimizable", + is_loading=False, + sharding_type='fully_sharded_model_space', +) -> Dict[str, torch.Tensor]: """ Sharded state dictionary for an MainParamsOptimizerWrapper. Used to save and load the optimizer state when training with distributed_checkpoint. @@ -403,7 +431,9 @@ def optimizer_sharded_state_dict(model: SharedStateDictProtocol, optimizer: "Opt } if hasattr(optimizer, "sharded_state_dict"): - return optimizer.sharded_state_dict(model_sharded_state_dict) + return optimizer.sharded_state_dict( + model_sharded_state_dict, is_loading=is_loading, sharding_type=sharding_type + ) if not isinstance(optimizer, MainParamsOptimizerWrapper): # Regular optimizer, e.g. Adam or FusedAdam @@ -447,3 +477,42 @@ def get_safe(param_id): optim_state_to_sharding_state(optimizer_state_dict["optimizer"], id_to_sharded_param_map) return optimizer_state_dict + + +def load_model_state_dict(megatron_parallel, checkpoint: Mapping[str, Any], strict: bool = True) -> None: + from megatron.core import parallel_state + + for index, module in enumerate(megatron_parallel): + if parallel_state.get_virtual_pipeline_model_parallel_world_size() is not None: + if "state_dict" in checkpoint: + checkpoint_state_dict = checkpoint["state_dict"][f"model_{index}"] + else: + checkpoint_state_dict = checkpoint[f"model_{index}"] + else: + if "state_dict" in checkpoint: + checkpoint_state_dict = checkpoint["state_dict"] + else: + checkpoint_state_dict = checkpoint + + n_nesting = 0 + mcore_model = megatron_parallel.module + while hasattr(mcore_model, "module"): + mcore_model = mcore_model.module + n_nesting += 1 + + _state_dict = {} + for key, value in checkpoint_state_dict.items(): + # Count the number of "module." at the start of the key + count, _key = 0, key + while _key.startswith("module."): + _key = _key[len("module.") :] + count += 1 + + # Adjust the number of "module." prefixes + if count < n_nesting: + to_add = "module." * (n_nesting - count) + _state_dict[f"{to_add}{key}"] = value + elif count > n_nesting: + to_remove = "module." * (count - n_nesting) + _state_dict[key[len(to_remove) :]] = value + module.load_state_dict(_state_dict, strict=strict) diff --git a/nemo/lightning/base.py b/nemo/lightning/base.py index ba5daf12f95fc..128ecb661efd8 100644 --- a/nemo/lightning/base.py +++ b/nemo/lightning/base.py @@ -26,8 +26,7 @@ def get_vocab_size( after = vocab_size multiple = make_vocab_size_divisible_by * config.tensor_model_parallel_size - while (after % multiple) != 0: - after += 1 + after = ((after + multiple - 1) // multiple) * multiple logging.info( f"Padded vocab_size: {after}, original vocab_size: {vocab_size}, dummy tokens:" f" {after - vocab_size}." ) diff --git a/nemo/lightning/data.py b/nemo/lightning/data.py index adfc0aa14d295..d83f5ba3b7282 100644 --- a/nemo/lightning/data.py +++ b/nemo/lightning/data.py @@ -183,9 +183,12 @@ def __len__(self): num_available_samples: int = self.total_samples - self.consumed_samples if self.global_batch_size is not None: if self.drop_last: - return num_available_samples // self.global_batch_size + num_global_batches = num_available_samples // self.global_batch_size else: - return (num_available_samples + self.global_batch_size - 1) // self.global_batch_size + num_global_batches = (num_available_samples + self.global_batch_size - 1) // self.global_batch_size + # return len of dataloader in terms of micro batches to avoid discrepancy between len of dataloader and + # num of batches fetched (as training step fetches in terms of micro batches) + return num_global_batches * (self.global_batch_size // self.micro_batch_times_data_parallel_size) else: return (num_available_samples - 1) // self.micro_batch_times_data_parallel_size + 1 diff --git a/nemo/lightning/fabric/__init__.py b/nemo/lightning/fabric/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/nemo/lightning/fabric/conversion.py b/nemo/lightning/fabric/conversion.py new file mode 100644 index 0000000000000..cc2b074940dd8 --- /dev/null +++ b/nemo/lightning/fabric/conversion.py @@ -0,0 +1,110 @@ +from functools import singledispatch +from typing import Any, TypeVar + +from lightning_fabric import plugins as fl_plugins +from lightning_fabric import strategies as fl_strategies +from pytorch_lightning import plugins as pl_plugins +from pytorch_lightning import strategies as pl_strategies + +T = TypeVar('T') +FabricT = TypeVar('FabricT') + + +@singledispatch +def to_fabric(obj: Any) -> Any: + """ + Convert a PyTorch Lightning object to its Fabric equivalent. + + Args: + obj: The object to convert. + + Returns: + The Fabric equivalent of the input object. + + Raises: + NotImplementedError: If no converter is registered for the object's type. + + Example: + >>> from pytorch_lightning.strategies import Strategy as PLStrategy + >>> from lightning_fabric.strategies import Strategy as FabricStrategy + >>> from nemo.lightning.fabric.conversion import to_fabric + >>> + >>> # Define a custom PyTorch Lightning strategy + >>> class CustomPLStrategy(PLStrategy): + ... def __init__(self, custom_param: str): + ... super().__init__() + ... self.custom_param = custom_param + >>> + >>> # Define a custom Fabric strategy + >>> class CustomFabricStrategy(FabricStrategy): + ... def __init__(self, custom_param: str): + ... super().__init__() + ... self.custom_param = custom_param + >>> + >>> # Register a custom conversion + >>> @to_fabric.register(CustomPLStrategy) + ... def _custom_converter(strategy: CustomPLStrategy) -> CustomFabricStrategy: + ... return CustomFabricStrategy(custom_param=strategy.custom_param) + >>> + >>> # Use the custom conversion + >>> pl_strategy = CustomPLStrategy(custom_param="test") + >>> fabric_strategy = to_fabric(pl_strategy) + >>> assert isinstance(fabric_strategy, CustomFabricStrategy) + >>> assert fabric_strategy.custom_param == "test" + """ + raise NotImplementedError( + f"No Fabric converter registered for {type(obj).__name__}. " + f"To register a new conversion, use the @to_fabric.register decorator:\n\n" + f"from nemo.lightning.fabric.conversion import to_fabric\n" + f"from lightning_fabric import strategies as fl_strategies\n\n" + f"@to_fabric.register({type(obj).__name__})\n" + f"def _{type(obj).__name__.lower()}_converter(obj: {type(obj).__name__}) -> fl_strategies.Strategy:\n" + f" return fl_strategies.SomeStrategy(\n" + f" # Map relevant attributes from 'obj' to Fabric equivalent\n" + f" param1=obj.param1,\n" + f" param2=obj.param2,\n" + f" # ... other parameters ...\n" + f" )\n\n" + f"Add this code to the appropriate module (e.g., nemo/lightning/fabric/conversion.py)." + ) + + +@to_fabric.register(pl_strategies.DDPStrategy) +def _ddp_converter(strategy: pl_strategies.DDPStrategy) -> fl_strategies.DDPStrategy: + return fl_strategies.DDPStrategy( + accelerator=strategy.accelerator, + parallel_devices=strategy.parallel_devices, + cluster_environment=strategy.cluster_environment, + process_group_backend=strategy.process_group_backend, + timeout=strategy._timeout, + start_method=strategy._start_method, + **strategy._ddp_kwargs, + ) + + +@to_fabric.register(pl_strategies.FSDPStrategy) +def _fsdp_converter(strategy: pl_strategies.FSDPStrategy) -> fl_strategies.FSDPStrategy: + return fl_strategies.FSDPStrategy( + cpu_offload=strategy.cpu_offload, + parallel_devices=strategy.parallel_devices, + cluster_environment=strategy.cluster_environment, + process_group_backend=strategy.process_group_backend, + timeout=strategy._timeout, + **strategy.kwargs, + ) + + +@to_fabric.register(pl_plugins.MixedPrecision) +def _mixed_precision_converter(plugin: pl_plugins.MixedPrecision) -> fl_plugins.MixedPrecision: + return fl_plugins.MixedPrecision( + precision=plugin.precision, + device=plugin.device, + scaler=plugin.scaler, + ) + + +@to_fabric.register(pl_plugins.FSDPPrecision) +def _fsdp_precision_converter(plugin: pl_plugins.FSDPPrecision) -> fl_plugins.FSDPPrecision: + return fl_plugins.FSDPPrecision( + precision=plugin.precision, + ) diff --git a/nemo/lightning/fabric/fabric.py b/nemo/lightning/fabric/fabric.py new file mode 100644 index 0000000000000..ced57af5adef5 --- /dev/null +++ b/nemo/lightning/fabric/fabric.py @@ -0,0 +1,132 @@ +from copy import deepcopy +from pathlib import Path +from typing import Optional, Protocol, Type, TypeVar, Union, runtime_checkable + +import fiddle as fdl +import lightning_fabric as lb +from torch import nn +from typing_extensions import Self, override + +from nemo.lightning.io.mixin import IOMixin, serialization, track_io + +ModelT = TypeVar("ModelT", bound=nn.Module) + + +class Fabric(lb.Fabric, IOMixin): + def io_init(self, **kwargs) -> fdl.Config[Self]: + # Each argument of the trainer can be stateful so we copy them + cfg_kwargs = {k: deepcopy(v) for k, v in kwargs.items()} + + for val in cfg_kwargs.values(): + if not serialization.find_node_traverser(type(val)): + track_io(type(val)) + + return fdl.Config(type(self), **cfg_kwargs) + + def load_model( + self, + path: Union[str, Path], + model: Optional[ModelT] = None, + ) -> "DistributedModel[ModelT]": + """Load and set up a model for distributed training. + + This method loads a model from the given path, sets it up for distributed training + using the current Fabric instance, and returns a DistributedModel. + + Args: + path (Union[str, Path]): The path to the saved model checkpoint. + model (Optional[ModelT], optional): An optional pre-instantiated model. If not + provided, the model will be loaded from the checkpoint. Defaults to None. + + Returns: + DistributedModel[ModelT]: The loaded and distributed model. + + Example: + >>> from nemo import lightning as nl + >>> + >>> trainer = nl.Trainer( + ... devices=2, + ... strategy=nl.MegatronStrategy(tensor_model_parallel_size=2), + ... plugins=nl.MegatronMixedPrecision(precision='16-mixed') + ... ) + >>> fabric = trainer.to_fabric() + >>> distributed_model = fabric.load_model("path/to/checkpoint/dir") + >>> + >>> # You can now interact with the parallel model + """ + self.launch() + + from nemo.lightning.io import load_context + + if model is None: + context = load_context(path) + model = context.model + + dist_model = self.setup_module(model) + self.load(path, {"state_dict": dist_model}) + + return dist_model + + def import_model( + self, + path: Union[str, Path], + model_type: Type[ModelT], + ) -> "DistributedModel[ModelT]": + """ + Import a model from a given path and set it up for distributed training. + + This method imports a model of the specified type from the given path, loads it, + and sets it up for distributed training using the current Fabric instance. + + Args: + path (Union[str, Path]): The path to the model. Can be a local path or a + Hugging Face model identifier. + model_type (Type[ModelT]): The type of the model to import. Must be a subclass + of ConnectorMixin. + + Returns: + DistributedModel[ModelT]: The imported and distributed model. + + Raises: + TypeError: If the provided model_type is not a subclass of ConnectorMixin. + + Example: + >>> from nemo import lightning as nl + >>> from nemo.collections.llm import MistralModel + >>> + >>> trainer = nl.Trainer( + ... devices=2, + ... strategy=nl.MegatronStrategy(tensor_model_parallel_size=2), + ... plugins=nl.MegatronMixedPrecision(precision='16-mixed') + ... ) + >>> fabric = trainer.to_fabric() + >>> model = fabric.import_model("hf://mistralai/Mistral-7B-v0.1", MistralModel) + >>> + >>> # You can now interact with the parallel model + """ + from nemo.lightning.io import ConnectorMixin + + if not issubclass(model_type, ConnectorMixin): + raise TypeError("The provided model class must be a subclass of ConnectorMixin") + + model: ModelT = model_type.import_from(path) + + return self.load_model(model.ckpt_path, model) + + @override + def setup_module(self, module: nn.Module, move_to_device: bool = True, _reapply_compile: bool = True): + from nemo.lightning.fabric.strategies import FabricMegatronStrategy + + out = super().setup_module(module, move_to_device=move_to_device, _reapply_compile=_reapply_compile) + + # We don't want to return a _FabricModule for megatron since we only want to precision convert + # at the beginning and end of the pipeline + if isinstance(self.strategy, FabricMegatronStrategy): + return out._forward_module + + return out + + +@runtime_checkable +class DistributedModel(Protocol[ModelT]): + module: ModelT diff --git a/nemo/lightning/fabric/plugins.py b/nemo/lightning/fabric/plugins.py new file mode 100644 index 0000000000000..79e1455cb33f3 --- /dev/null +++ b/nemo/lightning/fabric/plugins.py @@ -0,0 +1,129 @@ +from contextlib import contextmanager +from typing import Any, Generator, Literal, Optional, TypeVar, Union + +import torch +from lightning_fabric.plugins.precision import MixedPrecision +from lightning_fabric.utilities.types import Optimizable +from torch import nn +from torch.optim import Optimizer + +from nemo.lightning._strategy_lib import GradScaler +from nemo.lightning.fabric.conversion import to_fabric +from nemo.lightning.pytorch.plugins.mixed_precision import MegatronMixedPrecision + +AnyT = TypeVar("AnyT") + + +class FabricMegatronMixedPrecision(MixedPrecision): + def __init__( + self, + precision: Literal["16-mixed", "bf16-mixed"] = "16-mixed", + amp_02: bool = True, + device="cuda", + scaler: Optional[Union[torch.cuda.amp.GradScaler, str]] = None, + ) -> None: + if precision == "bf16-mixed": + scaler = None + else: + scaler = GradScaler( + init_scale=2**32, + growth_interval=1000, + hysteresis=2, + ) + + super().__init__(precision, device, scaler) + self.amp_02 = amp_02 + + def convert_input(self, data: AnyT) -> AnyT: + """Convert model inputs (forward) to the floating point precision type of this plugin. + + Note: MegatronStrategy will take care of only doing this when: + mpu.is_pipeline_first_stage() + + """ + return data + + def convert_output(self, data: AnyT) -> AnyT: + """Convert outputs to the floating point precision type expected after model's forward. + + Note: MegatronStrategy will take care of only doing this when: + mpu.is_pipeline_first_stage() + + """ + return data + + def setup_optimizer(self, optimizer: Optimizer) -> Optimizer: + from nemo.core.optim import MainParamsOptimizerWrapper + + return MainParamsOptimizerWrapper( + optimizer, + # https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/nlp/models/language_modeling/megatron_base_model.py#L496 + fp32_grad_accum=True, + contiguous_grad_bucket=True, + ) + + def convert_module(self, module: nn.Module) -> nn.Module: + """Convert the module parameters to the precision type this plugin handles. + + This is optional and depends on the precision limitations during optimization. + + """ + if not hasattr(module, "module"): + return module + + from megatron.core.transformer.module import Float16Module + from megatron.core.utils import get_model_config + + if self.precision in ["16-mixed", "bf16-mixed"]: + config = get_model_config(module.module) + config.fp16 = self.precision == "16-mixed" + config.bf16 = self.precision == "bf16-mixed" + if not isinstance(module.module, Float16Module): + module.module = Float16Module(config, module.module) + + return module + + def optimizer_step( + self, + optimizer: Optimizable, + **kwargs: Any, + ) -> None: + from nemo.core.optim import MainParamsOptimizerWrapper + + assert isinstance( + optimizer, MainParamsOptimizerWrapper + ), "MegatronHalfPrecisionPlugin supports only the optimizer with master parameters" + + if self.scaler is None: + assert optimizer.fp32_grad_accumulation, "BF16 uses FP32 grad accumulation" + + # skip scaler logic, as bfloat16 does not require scaler + return super().optimizer_step(optimizer, **kwargs) + + assert not optimizer.fp32_grad_accumulation, "FP16 uses FP16 grad accumulation" + + # cast fp16 grads to fp32 and copy to main grads, which are used for unscale and param update + optimizer.copy_model_grads_to_main_grads() + + # note: the scaler will skip the `optimizer.step` if nonfinite gradients are found + step_output = self.scaler.step(optimizer, **kwargs) + self.scaler.update() + + return step_output + + @contextmanager + def forward_context(self) -> Generator[None, None, None]: + """No explicit precision casting. Inputs are supposed to be manually casted.""" + try: + yield + finally: + pass + + +@to_fabric.register(MegatronMixedPrecision) +def _convert_megatron_mixed_precision(plugin: MegatronMixedPrecision) -> FabricMegatronMixedPrecision: + return FabricMegatronMixedPrecision( + precision=plugin.precision, + device=plugin.device, + scaler=plugin.scaler, + ) diff --git a/nemo/lightning/fabric/strategies.py b/nemo/lightning/fabric/strategies.py new file mode 100644 index 0000000000000..a662386a91196 --- /dev/null +++ b/nemo/lightning/fabric/strategies.py @@ -0,0 +1,427 @@ +from contextlib import ExitStack, contextmanager +from datetime import timedelta +from typing import ( + TYPE_CHECKING, + Any, + Callable, + ContextManager, + Dict, + Generator, + Iterator, + List, + Literal, + Optional, + Union, +) + +import torch +from lightning_fabric.accelerators import CPUAccelerator +from lightning_fabric.accelerators.accelerator import Accelerator +from lightning_fabric.plugins.collectives.torch_collective import default_pg_timeout +from lightning_fabric.plugins.environments.cluster_environment import ClusterEnvironment +from lightning_fabric.plugins.io.checkpoint_io import CheckpointIO +from lightning_fabric.plugins.precision import Precision +from lightning_fabric.strategies import DDPStrategy +from lightning_fabric.strategies.strategy import _validate_keys_for_strict_loading +from lightning_fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_1 +from lightning_fabric.utilities.types import _PATH, _Stateful +from megatron.core.distributed import DistributedDataParallelConfig +from pytorch_lightning.loops.fetchers import _DataFetcher +from pytorch_lightning.plugins.io.wrapper import _WrappingCheckpointIO +from pytorch_lightning.utilities.combined_loader import CombinedLoader +from torch import Tensor, nn +from torch.distributed.algorithms.ddp_comm_hooks.debugging_hooks import noop_hook +from torch.nn import Module +from torch.optim import Optimizer +from torch.utils.data import DataLoader +from typing_extensions import override + +from nemo.lightning import _strategy_lib +from nemo.lightning.fabric.conversion import to_fabric +from nemo.lightning.io.pl import MegatronCheckpointIO +from nemo.lightning.megatron_parallel import CallbackConnector, MegatronParallel +from nemo.lightning.pytorch.strategies import MegatronStrategy + +if TYPE_CHECKING: + from megatron.core.model_parallel_config import ModelParallelConfig + + from nemo.lightning.pytorch.plugins.data_sampler import DataSampler + + +DDPLiteral = Literal["megatron", "pytorch"] + + +class FabricMegatronStrategy(DDPStrategy): + def __init__( + self, + tensor_model_parallel_size: int = 1, + pipeline_model_parallel_size: int = 1, + virtual_pipeline_model_parallel_size: Optional[int] = None, + context_parallel_size: int = 1, + sequence_parallel: bool = False, + expert_model_parallel_size: int = 1, + moe_extended_tp: bool = False, + data_sampler: Optional["DataSampler"] = None, + accelerator: Optional[Accelerator] = None, + parallel_devices: Optional[List[torch.device]] = None, + cluster_environment: Optional[ClusterEnvironment] = None, + checkpoint_io: Optional[CheckpointIO] = None, + precision: Optional[Precision] = None, + megatron_callbacks: Optional[CallbackConnector] = None, + ddp: Union[DDPLiteral, DistributedDataParallelConfig] = "megatron", + process_group_backend: Optional[str] = None, + timeout: Optional[timedelta] = default_pg_timeout, + start_method: Literal["popen", "spawn", "fork", "forkserver"] = "popen", + no_ddp_communication_hook: bool = True, + output_data_idx: bool = False, + pipeline_dtype: Optional[torch.dtype] = None, + **kwargs: Any, + ) -> None: + super().__init__( + accelerator=accelerator, + parallel_devices=parallel_devices, + cluster_environment=cluster_environment, + checkpoint_io=checkpoint_io, + precision=precision, + process_group_backend=process_group_backend, + timeout=timeout, + start_method=start_method, + **kwargs, + ) + self.megatron_callbacks = CallbackConnector() + self.data_sampler: Optional['DataSampler'] = data_sampler + self.tensor_model_parallel_size = tensor_model_parallel_size + self.pipeline_model_parallel_size = pipeline_model_parallel_size + self.context_parallel_size = context_parallel_size + self.expert_model_parallel_size = expert_model_parallel_size + self.moe_extended_tp = moe_extended_tp + self.virtual_pipeline_model_parallel_size = virtual_pipeline_model_parallel_size + self.sequence_parallel = sequence_parallel + self.pipeline_dtype = pipeline_dtype + + self.no_ddp_communication_hook = no_ddp_communication_hook + self.megatron_callbacks = CallbackConnector() + if megatron_callbacks: + self.megatron_callbacks.add(megatron_callbacks) + self.output_data_idx = output_data_idx + + # used in NVIDIA NGC PyTorch containers + _strategy_lib.enable_nvidia_optimizations() + + self._ddp = ddp + if ddp == "megatron": + self.ddp_config = DistributedDataParallelConfig() + elif isinstance(ddp, DistributedDataParallelConfig): + self.ddp_config = ddp + elif ddp == "pytorch": + self.ddp_config = None + self.no_ddp_communication_hook = False + else: + raise ValueError(f"Invalid DDP type: {ddp}") + + @override + def _setup_distributed(self) -> None: + self._set_world_ranks() + + assert self.cluster_environment is not None + _strategy_lib.init_parallel_ranks( + world_size=self.cluster_environment.world_size(), + global_rank=self.cluster_environment.global_rank(), + local_rank=self.cluster_environment.local_rank(), + parallel_config=self.parallelism, + ) + + super()._setup_distributed() + torch.cuda.set_device(self.cluster_environment.local_rank()) + + # TODO: Fix this: + # if self.data_config is not None: + # _strategy_lib.initialize_data(self.cluster_environment.global_rank(), self.data_config) + _strategy_lib.init_model_parallel() + + @override + def process_dataloader(self, dataloader: DataLoader) -> Iterator: + loader = _strategy_lib.process_dataloader(dataloader, self.data_config) + + # Code taken from: https://github.com/Lightning-AI/pytorch-lightning/blob/6cbe9ceb560d798892bdae9186291acf9bf5d2e3/src/lightning/pytorch/loops/fit_loop.py#L258-L260 + output = _MegatronDataLoaderIterDataFetcher(self.data_config, output_data_idx=self.output_data_idx) + output.setup(CombinedLoader(loader, "max_size_cycle")) + iter(output) + + return output + + @override + def setup_optimizer(self, optimizer: Optimizer) -> Optimizer: + """Pass the optimizer to the precision-plugin if needed & add it as callback.""" + if hasattr(self._precision, "setup_optimizer"): + optimizer = self._precision.setup_optimizer(optimizer) + + self.megatron_callbacks.add(optimizer) + + return optimizer + + @override + def setup_module(self, module: Module) -> MegatronParallel: + _strategy_lib.set_model_parallel_attributes(module, self.parallelism) + + # Call configure_model if it's overridden (relevant for LightningModules with lazy initialization) + if hasattr(module, "configure_model"): + module.configure_model() + + convert_module_fn = None + if hasattr(self.precision, "convert_module"): + convert_module_fn = self.precision.convert_module + + megatron_parallel = MegatronParallel( + module, + precision_plugin=self.precision, + vp_size=self.virtual_pipeline_model_parallel_size, + cpu=isinstance(self.accelerator, CPUAccelerator), + ddp_config=self.ddp_config, + convert_module_fn=convert_module_fn, + ) + + if not self.ddp_config: + from megatron.core import mpu + + from nemo.utils import AppState + + app_state = AppState() + + if app_state.model_parallel_size is not None: + self._ddp_kwargs["process_group"] = mpu.get_data_parallel_group() + + dist_data_parallel = super().setup_module(megatron_parallel) + if self.no_ddp_communication_hook: + # When using custom gradient accumulation and allreduce, disable + # DDP communication hook that works on the gradient bucket. + # Instead, use the custom gradient function and communication hook, + # which is defined in the master optimizer wrapper. + dist_data_parallel.require_backward_grad_sync = False + dist_data_parallel.register_comm_hook(None, noop_hook) + + return dist_data_parallel + + return megatron_parallel + + def module_init_context(self, empty_init: Optional[bool] = None) -> ContextManager: + precision_init_ctx = self.precision.module_init_context() + module_sharded_ctx = self.megatron_context() + stack = ExitStack() + if _TORCH_GREATER_EQUAL_2_1 and empty_init: + # Materialization happens in `setup`. When modules get wrapped by FSDP, the sequence of operations is: + # 1) materialize module 2) call `reset_parameters()` 3) shard the module. + # These operations are applied to each submodule 'bottom up' in the module hierarchy. + stack.enter_context(torch.device("meta")) + stack.enter_context(precision_init_ctx) + stack.enter_context(module_sharded_ctx) + + return stack + + def module_to_device(self, module: nn.Module) -> None: + pass + + @override + def save_checkpoint( + self, + path: _PATH, + state: Dict[str, Union[Module, Optimizer, Any]], + storage_options: Optional[Any] = None, + filter_dict: Optional[Dict[str, Callable[[str, Any], bool]]] = None, + ) -> None: + """Save model, optimizer, and other state as a checkpoint file. + + Args: + path: A path to where the file(s) should be saved + state: A dictionary with contents to be saved. If the dict contains modules or optimizers, their + state-dict will be retrieved and converted automatically. + storage_options: Additional options for the ``CheckpointIO`` plugin + filter: An optional dictionary containing filter callables that return a boolean indicating whether the + given item should be saved (``True``) or filtered out (``False``). Each filter key should match a + state key, where its filter will be applied to the ``state_dict`` generated. + + """ + state = self._convert_stateful_objects_in_state(state, filter=(filter_dict or {})) + self.checkpoint_io.save_checkpoint(checkpoint=state, path=path, storage_options=storage_options) + + def load_checkpoint( + self, + path: _PATH, + state: Optional[Union[Module, Optimizer, Dict[str, Union[Module, Optimizer, Any]]]] = None, + strict: bool = True, + ) -> Dict[str, Any]: + if isinstance(state, Optimizer): + raise NotImplementedError("Optimizer loading is not supported, pass it as a dict including the model") + + torch.cuda.empty_cache() + + # After dist_checkpointing.load, sharded tensors will be replaced with tensors + sharded_state_dict = {} + if isinstance(state, Module): + sharded_state_dict["state_dict"] = state.sharded_state_dict() + elif strict: + sharded_state_dict["state_dict"] = state["state_dict"].sharded_state_dict() + if "optimizer" in state: + sharded_state_dict["optimizer"] = _strategy_lib.optimizer_sharded_state_dict( + state["state_dict"], state["optimizer"], is_loading=True + ) + else: + for obj in state.items(): + if isinstance(obj, Module): + sharded_state_dict["state_dict"] = obj.sharded_state_dict() + elif isinstance(obj, Optimizer): + sharded_state_dict["optimizer"] = _strategy_lib.optimizer_sharded_state_dict(obj, is_loading=True) + + checkpoint = self.checkpoint_io.load_checkpoint(path, sharded_state_dict=sharded_state_dict) + + if isinstance(state, Module): + self.load_module_state_dict(module=state, state_dict=checkpoint, strict=strict) + return {} + + _validate_keys_for_strict_loading(state.keys(), checkpoint.keys(), strict=strict) + for name, obj in state.copy().items(): + if name not in checkpoint: + continue + if isinstance(obj, _Stateful): + if isinstance(obj, Module): + self.load_module_state_dict(module=obj, state_dict=checkpoint.pop(name), strict=strict) + else: + obj.load_state_dict(checkpoint.pop(name)) + else: + state[name] = checkpoint.pop(name) + + return checkpoint + + @override + def load_module_state_dict( + self, module: Module, state_dict: Dict[str, Union[Any, Tensor]], strict: bool = True + ) -> None: + _strategy_lib.load_model_state_dict(module, state_dict, strict=strict) + + @contextmanager + def megatron_context(self) -> Generator[None, None, None]: + def monkey_patched(config): + return {"device": "meta"} + + from megatron.core.transformer.custom_layers import transformer_engine as _te + + original = _te._get_extra_te_kwargs # noqa: SLF001 + _te._get_extra_te_kwargs = monkey_patched # noqa: SLF001 + + self.parallelism.perform_initialization = False + self.parallelism.use_cpu_initialization = True + + yield + + _te._get_extra_te_kwargs = original # noqa: SLF001 + + @property + @override + def checkpoint_io(self) -> CheckpointIO: + if self._checkpoint_io is None: + self._checkpoint_io = MegatronCheckpointIO() + elif isinstance(self._checkpoint_io, _WrappingCheckpointIO): + self._checkpoint_io.checkpoint_io = MegatronCheckpointIO() + + return self._checkpoint_io + + @property + def parallelism(self): + from megatron.core.model_parallel_config import ModelParallelConfig + + return ModelParallelConfig( + tensor_model_parallel_size=self.tensor_model_parallel_size, + pipeline_model_parallel_size=self.pipeline_model_parallel_size, + virtual_pipeline_model_parallel_size=self.virtual_pipeline_model_parallel_size, + context_parallel_size=self.context_parallel_size, + sequence_parallel=self.sequence_parallel, + expert_model_parallel_size=self.expert_model_parallel_size, + moe_extended_tp=self.moe_extended_tp, + pipeline_dtype=self.pipeline_dtype, + ) + + +# TODO: Fix this +class _MegatronDataLoaderIterDataFetcher(_DataFetcher): + def __init__(self, data_config, *args: Any, output_data_idx: bool = False, **kwargs: Any) -> None: + super().__init__(*args, **kwargs) + self.data_config = data_config + self.output_data_idx = output_data_idx + self._batch: Any = None + self._batch_idx: int = 0 + self._dataloader_idx: int = 0 + + def __iter__(self) -> "_MegatronDataLoaderIterDataFetcher": + super().__iter__() + self.iterator_wrapper = iter(_DataFetcherWrapper(self, output_data_idx=self.output_data_idx)) + return self + + def __next__(self) -> Iterator["_DataFetcherWrapper"]: # type: ignore[override] + if self.done: + raise StopIteration + return self.iterator_wrapper + + def reset(self) -> None: + super().reset() + self._batch = None + self._batch_idx = 0 + self._dataloader_idx = 0 + + +class _DataFetcherWrapper(Iterator): + def __init__( + self, + data_fetcher: _MegatronDataLoaderIterDataFetcher, + output_data_idx: bool = False, + ) -> None: + self.data_fetcher = data_fetcher + self.output_data_idx = output_data_idx + + @property + def done(self) -> bool: + return self.data_fetcher.done + + @property + def fetched(self) -> int: + return self.data_fetcher.fetched + + @property + def length(self) -> Optional[int]: + return self.data_fetcher.length + + @property + def data_config(self): + return self.data_fetcher.data_config + + def __next__(self): + fetcher = self.data_fetcher + if fetcher.done: + raise StopIteration + batch, batch_idx, dataloader_idx = super(_MegatronDataLoaderIterDataFetcher, fetcher).__next__() + # save the state so the loops can access it + fetcher._batch = batch # noqa: SLF001 + fetcher._batch_idx = batch_idx # noqa: SLF001 + fetcher._dataloader_idx = dataloader_idx # noqa: SLF001 + + if not self.output_data_idx: + return batch + + return batch, batch_idx, dataloader_idx + + +@to_fabric.register(MegatronStrategy) +def convert_megatron_strategy(strategy: MegatronStrategy) -> FabricMegatronStrategy: + return FabricMegatronStrategy( + tensor_model_parallel_size=strategy.tensor_model_parallel_size, + pipeline_model_parallel_size=strategy.pipeline_model_parallel_size, + virtual_pipeline_model_parallel_size=strategy.virtual_pipeline_model_parallel_size, + context_parallel_size=strategy.context_parallel_size, + sequence_parallel=strategy.sequence_parallel, + expert_model_parallel_size=strategy.expert_model_parallel_size, + moe_extended_tp=strategy.moe_extended_tp, + pipeline_dtype=strategy.pipeline_dtype, + ddp=strategy._ddp, + process_group_backend=strategy.process_group_backend, + timeout=strategy._timeout, + start_method=strategy._start_method, + ) diff --git a/nemo/lightning/io/__init__.py b/nemo/lightning/io/__init__.py index d1a193c5e7283..2dcc53945fff3 100644 --- a/nemo/lightning/io/__init__.py +++ b/nemo/lightning/io/__init__.py @@ -1,25 +1,27 @@ -from nemo.lightning.io.api import export_ckpt, import_ckpt, load, load_ckpt, model_exporter, model_importer +from nemo.lightning.io.api import export_ckpt, import_ckpt, load, load_context, model_exporter, model_importer from nemo.lightning.io.capture import reinit from nemo.lightning.io.connector import Connector, ModelConnector -from nemo.lightning.io.mixin import ConnectorMixin, IOMixin -from nemo.lightning.io.pl import TrainerCheckpoint, is_distributed_ckpt +from nemo.lightning.io.mixin import ConnectorMixin, IOMixin, track_io +from nemo.lightning.io.pl import TrainerContext, is_distributed_ckpt from nemo.lightning.io.state import TransformCTX, apply_transforms, state_transform + __all__ = [ "apply_transforms", "Connector", "ConnectorMixin", "IOMixin", + "track_io", "import_ckpt", "is_distributed_ckpt", "export_ckpt", "load", - "load_ckpt", + "load_context", "ModelConnector", "model_importer", "model_exporter", 'reinit', "state_transform", - "TrainerCheckpoint", + "TrainerContext", "TransformCTX", ] diff --git a/nemo/lightning/io/api.py b/nemo/lightning/io/api.py index fbe764d67e3de..cc594b562cffa 100644 --- a/nemo/lightning/io/api.py +++ b/nemo/lightning/io/api.py @@ -1,12 +1,12 @@ -import pickle from pathlib import Path from typing import Any, Callable, Optional, Type, TypeVar import fiddle as fdl import pytorch_lightning as pl +from fiddle._src.experimental import serialization from nemo.lightning.io.mixin import ConnectorMixin, ConnT, ModelConnector -from nemo.lightning.io.pl import TrainerCheckpoint +from nemo.lightning.io.pl import TrainerContext CkptType = TypeVar("CkptType") @@ -34,34 +34,34 @@ def load(path: Path, output_type: Type[CkptType] = Any) -> CkptType: _path = Path(path) if hasattr(_path, 'is_dir') and _path.is_dir(): - _path = Path(_path) / "io.pkl" + _path = Path(_path) / "io.json" elif hasattr(_path, 'isdir') and _path.isdir: - _path = Path(_path) / "io.pkl" + _path = Path(_path) / "io.json" if not _path.is_file(): raise FileNotFoundError(f"No such file: '{_path}'") with open(_path, "rb") as f: - config = pickle.load(f) + config = serialization.load_json(f.read()) return fdl.build(config) -def load_ckpt(path: Path) -> TrainerCheckpoint: +def load_context(path: Path) -> TrainerContext: """ - Loads a TrainerCheckpoint from a pickle file or directory. + Loads a TrainerContext from a json-file or directory. Args: - path (Path): The path to the pickle file or directory containing 'io.pkl'. + path (Path): The path to the json-file or directory containing 'io.json'. Returns ------- - TrainerCheckpoint: The loaded TrainerCheckpoint instance. + TrainerContext: The loaded TrainerContext instance. Example: - checkpoint: TrainerCheckpoint = load_ckpt("/path/to/checkpoint") + checkpoint: TrainerContext = load_ckpt("/path/to/checkpoint") """ - return load(path, output_type=TrainerCheckpoint) + return load(path, output_type=TrainerContext) def model_importer(target: Type[ConnectorMixin], ext: str) -> Callable[[Type[ConnT]], Type[ConnT]]: @@ -167,7 +167,7 @@ def import_ckpt( def load_connector_from_trainer_ckpt(path: Path, target: str) -> ModelConnector: - model: pl.LightningModule = load_ckpt(path).model + model: pl.LightningModule = load_context(path).model if not isinstance(model, ConnectorMixin): raise ValueError("Model must be an instance of ConnectorMixin") diff --git a/nemo/lightning/io/artifact/__init__.py b/nemo/lightning/io/artifact/__init__.py new file mode 100644 index 0000000000000..572bd37c0be81 --- /dev/null +++ b/nemo/lightning/io/artifact/__init__.py @@ -0,0 +1,4 @@ +from nemo.lightning.io.artifact.base import Artifact +from nemo.lightning.io.artifact.file import FileArtifact, PathArtifact + +__all__ = ["Artifact", "FileArtifact", "PathArtifact"] diff --git a/nemo/lightning/io/artifact/base.py b/nemo/lightning/io/artifact/base.py new file mode 100644 index 0000000000000..4025634ebe28a --- /dev/null +++ b/nemo/lightning/io/artifact/base.py @@ -0,0 +1,18 @@ +from abc import ABC, abstractmethod +from pathlib import Path +from typing import Generic, TypeVar + +ValueT = TypeVar("ValueT") + + +class Artifact(ABC, Generic[ValueT]): + def __init__(self, attr: str): + self.attr = attr + + @abstractmethod + def dump(self, value: ValueT, path: Path) -> ValueT: + pass + + @abstractmethod + def load(self, path: Path) -> ValueT: + pass diff --git a/nemo/lightning/io/artifact/file.py b/nemo/lightning/io/artifact/file.py new file mode 100644 index 0000000000000..0bd4f48dc17f0 --- /dev/null +++ b/nemo/lightning/io/artifact/file.py @@ -0,0 +1,29 @@ +import shutil +from pathlib import Path +from typing import Union + +from nemo.lightning.io.artifact.base import Artifact + + +class PathArtifact(Artifact[Path]): + def dump(self, value: Path, path: Path) -> Path: + new_value = copy_file(value, path) + return new_value + + def load(self, path: Path) -> Path: + return path + + +class FileArtifact(Artifact[str]): + def dump(self, value: str, path: Path) -> str: + new_value = copy_file(value, path) + return str(new_value) + + def load(self, path: str) -> str: + return path + + +def copy_file(src: Union[Path, str], dst: Union[Path, str]): + output = Path(dst) / Path(src).name + shutil.copy2(src, output) + return output diff --git a/nemo/lightning/io/artifact/pickle.py b/nemo/lightning/io/artifact/pickle.py new file mode 100644 index 0000000000000..31ed7e36ac93c --- /dev/null +++ b/nemo/lightning/io/artifact/pickle.py @@ -0,0 +1,22 @@ +from pathlib import Path +from typing import Any + +from cloudpickle import dump, load + +from nemo.lightning.io.artifact.base import Artifact + + +class PickleArtifact(Artifact[Any]): + def dump(self, value: Any, path: Path) -> Path: + file = self.file_path(path) + with open(file, "wb") as f: + dump(value, f) + + return file + + def load(self, path: Path) -> Any: + with open(self.file_path(path), "rb") as f: + return load(f) + + def file_path(self, path: Path) -> Path: + return path / self.attr + ".pkl" diff --git a/nemo/lightning/io/connector.py b/nemo/lightning/io/connector.py index 41c81582bb631..500d0203cfd47 100644 --- a/nemo/lightning/io/connector.py +++ b/nemo/lightning/io/connector.py @@ -184,9 +184,9 @@ def nemo_load( Tuple[pl.LightningModule, pl.Trainer]: The loaded model and the trainer configured with the model. """ from nemo.lightning import MegatronStrategy, Trainer, _strategy_lib - from nemo.lightning.io.api import load_ckpt + from nemo.lightning.io.api import load_context - model = load_ckpt(path).model + model = load_context(path).model _trainer = trainer or Trainer( devices=1, accelerator="cpu" if cpu else "gpu", strategy=MegatronStrategy(ddp="pytorch") ) @@ -218,4 +218,7 @@ def local_path(self, base_path: Optional[Path] = None) -> Path: return _base / str(self).replace("://", "/") def on_import_ckpt(self, model: pl.LightningModule): - model.tokenizer = self.tokenizer + if hasattr(self, "tokenizer"): + model.tokenizer = self.tokenizer + if hasattr(model, "__io__"): + model.__io__.tokenizer = self.tokenizer diff --git a/nemo/lightning/io/fdl_torch.py b/nemo/lightning/io/fdl_torch.py new file mode 100644 index 0000000000000..c74e48e1c411c --- /dev/null +++ b/nemo/lightning/io/fdl_torch.py @@ -0,0 +1,116 @@ +"""Fiddle extensions to handle PyTorch code more elegantly. + +This module provides extensions for better handling of PyTorch types and functions +in codegen, graphviz, and other debugging functions. +""" + +import types + +import libcst as cst +import torch +import torch.nn as nn +from fiddle._src import daglish_extensions +from fiddle._src.codegen import import_manager, py_val_to_cst_converter, special_value_codegen +from fiddle._src.experimental import serialization + + +def _make_torch_importable(name: str) -> special_value_codegen.Importable: + return special_value_codegen.SingleImportable("torch", lambda torch_name: f"{torch_name}.{name}") + + +_torch_type_importables = ( + (torch.bool, _make_torch_importable("bool")), + (torch.uint8, _make_torch_importable("uint8")), + (torch.int8, _make_torch_importable("int8")), + (torch.int16, _make_torch_importable("int16")), + (torch.int32, _make_torch_importable("int32")), + (torch.int64, _make_torch_importable("int64")), + (torch.float16, _make_torch_importable("float16")), + (torch.bfloat16, _make_torch_importable("bfloat16")), + (torch.float32, _make_torch_importable("float32")), + (torch.float64, _make_torch_importable("float64")), + (torch.complex64, _make_torch_importable("complex64")), + (torch.complex128, _make_torch_importable("complex128")), +) + +_torch_initializers = ( + nn.init.constant_, + nn.init.dirac_, + nn.init.xavier_normal_, + nn.init.xavier_uniform_, + nn.init.kaiming_normal_, + nn.init.kaiming_uniform_, + nn.init.normal_, + nn.init.ones_, + nn.init.orthogonal_, + nn.init.uniform_, + nn.init.zeros_, +) + +_import_aliases = (("torch.nn.init", "from torch.nn import init"),) + + +def _make_torch_nn_importable(name: str) -> special_value_codegen.Importable: + return special_value_codegen.SingleImportable("torch", lambda torch_mod_name: f"{torch_mod_name}.nn.{name}") + + +_nn_type_importables = ( + (nn.ReLU, _make_torch_nn_importable("ReLU")), + (nn.GELU, _make_torch_nn_importable("GELU")), + (nn.ReLU6, _make_torch_nn_importable("ReLU6")), + (nn.SiLU, _make_torch_nn_importable("SiLU")), + (nn.Sigmoid, _make_torch_nn_importable("Sigmoid")), + (nn.SELU, _make_torch_nn_importable("SELU")), + (nn.Hardtanh, _make_torch_nn_importable("Hardtanh")), + (nn.Tanh, _make_torch_nn_importable("Tanh")), +) + + +def is_torch_tensor(value): + """Returns true if `value` is a PyTorch Tensor.""" + return isinstance(value, torch.Tensor) + + +def convert_torch_tensor_to_cst(value, convert_child): + return cst.Call( + func=cst.Attribute(value=convert_child(torch), attr=cst.Name("tensor")), + args=[ + cst.Arg(convert_child(value.tolist())), + py_val_to_cst_converter.kwarg_to_cst("dtype", convert_child(value.dtype)), + ], + ) + + +def enable(): + """Registers PyTorch fiddle extensions. + + This allows for things like nicer handling of torch dtypes. + """ + for value, importable in _torch_type_importables: + special_value_codegen.register_exact_value(value, importable) + + for value, importable in _nn_type_importables: + special_value_codegen.register_exact_value(value, importable) + + for module_str, import_stmt in _import_aliases: + import_manager.register_import_alias(module_str, import_stmt) + + py_val_to_cst_converter.register_py_val_to_cst_converter(is_torch_tensor)(convert_torch_tensor_to_cst) + + for dtype, _ in _torch_type_importables: + daglish_extensions.register_immutable(dtype) + lib, symbol = str(dtype).split(".") + serialization.register_constant(lib, symbol, compare_by_identity=True) + + for init in _torch_initializers: + daglish_extensions.register_immutable(init) + daglish_extensions.register_function_with_immutable_return_value(init) + + # Monkey-patch the Serialization class to handle things like activation-functions + def _modified_serialize(self, value, current_path, all_paths=None): + if isinstance(value, types.BuiltinFunctionType): + return self._pyref(value, current_path) + return self._original_serialize(value, current_path, all_paths) + + serialization.Serialization._original_serialize = serialization.Serialization._serialize + serialization.Serialization._serialize = _modified_serialize diff --git a/nemo/lightning/io/mixin.py b/nemo/lightning/io/mixin.py index 54b6e7195bc9b..dfc78c30a929c 100644 --- a/nemo/lightning/io/mixin.py +++ b/nemo/lightning/io/mixin.py @@ -1,17 +1,31 @@ import functools import inspect +import shutil +import threading +import types +import uuid +from copy import deepcopy from dataclasses import is_dataclass from pathlib import Path -from typing import Any, Callable, Dict, Optional, Type, TypeVar, Union +from typing import Any, Callable, Dict, List, Optional, Type, TypeVar, Union import fiddle as fdl -from cloudpickle import dump +import fiddle._src.experimental.dataclasses as fdl_dc +from cloudpickle import dump, load +from fiddle._src.experimental import serialization from typing_extensions import Self +from nemo.lightning.io.artifact.base import Artifact from nemo.lightning.io.capture import IOProtocol from nemo.lightning.io.connector import ModelConnector +from nemo.lightning.io.fdl_torch import enable as _enable_ext ConnT = TypeVar('ConnT', bound=ModelConnector) +_enable_ext() + + +# Thread-local storage for artifacts directory +_thread_local = threading.local() class IOMixin: @@ -54,7 +68,7 @@ def __init__(self, param1, param2): """ - __io__ = fdl.Config[Self] + __io__: fdl.Config[Self] def __new__(cls, *args, **kwargs): """ @@ -69,19 +83,14 @@ def __new__(cls, *args, **kwargs): ------- The newly created object instance. """ - original_init = cls.__init__ - - @functools.wraps(original_init) - def wrapped_init(self, *args, **kwargs): - cfg_kwargs = self.io_transform_args(original_init, *args, **kwargs) - self.__io__ = self.io_init(**cfg_kwargs) - original_init(self, *args, **kwargs) - - cls.__init__ = wrapped_init + cls = _io_wrap_init(cls) output = object().__new__(cls) return output + def __init_subclass__(cls): + _io_register_serialization(cls) + def io_transform_args(self, init_fn, *args, **kwargs) -> Dict[str, Any]: """ Transforms and captures the arguments passed to the `__init__` method, filtering out @@ -97,24 +106,7 @@ def io_transform_args(self, init_fn, *args, **kwargs) -> Dict[str, Any]: ------- Dict[str, Any]: A dictionary of the captured and transformed arguments. """ - sig = inspect.signature(init_fn) - bound_args = sig.bind_partial(self, *args, **kwargs) - bound_args.apply_defaults() - config_kwargs = {k: v for k, v in bound_args.arguments.items() if k != "self"} - - to_del = [] - for key in config_kwargs: - if isinstance(config_kwargs[key], IOProtocol): - config_kwargs[key] = config_kwargs[key].__io__ - if is_dataclass(self): - # Check if the arg is a factory (dataclasses.field) - if config_kwargs[key].__class__.__name__ == "_HAS_DEFAULT_FACTORY_CLASS": - to_del.append(key) - - for key in to_del: - del config_kwargs[key] - - return config_kwargs + return _io_transform_args(self, init_fn, *args, **kwargs) def io_init(self, **kwargs) -> fdl.Config[Self]: """ @@ -127,19 +119,42 @@ def io_init(self, **kwargs) -> fdl.Config[Self]: ------- fdl.Config[Self]: The initialized configuration object. """ - return fdl.Config(type(self), **kwargs) + return _io_init(self, **kwargs) + + @classmethod + def io_artifacts(cls) -> List[Artifact]: + return [] def io_dump(self, output: Path): """ Serializes the configuration object (`__io__`) to a file, allowing the object state to be - saved and later restored. + saved and later restored. Also creates an artifacts directory and stores it in a thread-local + global variable. If the artifacts directory is empty at the end, it is deleted. Args: - output (Path): The path to the file where the configuration object will be serialized. + output (Path): The path to the directory where the configuration object and artifacts + will be stored. """ - config_path = Path(output) / "io.pkl" - with open(config_path, "wb") as f: - dump(self.__io__, f) + output_path = Path(output) + artifacts_dir = output_path / "artifacts" + artifacts_dir.mkdir(parents=True, exist_ok=True) + + # Store artifacts directory in thread-local storage + _thread_local.artifacts_dir = artifacts_dir + + config_path = output_path / "io.json" + with open(config_path, "w") as f: + io = deepcopy(self.__io__) + _artifact_transform(io, artifacts_dir) + json = serialization.dump_json(io) + f.write(json) + + # Clear thread-local storage after io_dump is complete + del _thread_local.artifacts_dir + + # Check if artifacts directory is empty and delete if so + if not any(artifacts_dir.iterdir()): + shutil.rmtree(artifacts_dir) class ConnectorMixin: @@ -178,7 +193,7 @@ def import_from(cls, path: str) -> Self: Self: An instance of the model initialized from the imported data. """ output = cls._get_connector(path).init() - output.ckpt_path = output.import_ckpt_path(path) + output.ckpt_path = output.import_ckpt(path) return output @@ -321,3 +336,175 @@ def _get_connector(cls, ext, path=None, importer=True) -> ModelConnector: return connector() return connector(_path) + + +def track_io(target, artifacts: Optional[List[Artifact]] = None): + """ + Adds IO functionality to the target object or eligible classes in the target module + by wrapping __init__ and registering serialization methods. + + Args: + target (object or types.ModuleType): The target object or module to modify. + + Returns: + object or types.ModuleType: The modified target with IO functionality added to eligible classes. + + Examples: + >>> from nemo.collections.common import tokenizers + >>> modified_tokenizers = track_io(tokenizers) + >>> ModifiedWordTokenizer = track_io(tokenizers.WordTokenizer) + """ + + def _add_io_to_class(cls): + if inspect.isclass(cls) and hasattr(cls, '__init__') and not hasattr(cls, '__io__'): + if cls in [str, int, float, tuple, list, dict, bool, type(None)]: + return cls + + cls = _io_wrap_init(cls) + _io_register_serialization(cls) + cls.__io_artifacts__ = artifacts or [] + return cls + + def _process_module(module): + for name, obj in inspect.getmembers(module): + if inspect.isclass(obj) and _is_defined_in_module_or_submodules(obj, module): + setattr(module, name, _add_io_to_class(obj)) + return module + + def _is_defined_in_module_or_submodules(obj, module): + return obj.__module__ == module.__name__ or obj.__module__.startswith(f"{module.__name__}.") + + if isinstance(target, types.ModuleType): + return _process_module(target) + elif inspect.isclass(target): + return _add_io_to_class(target) + else: + raise TypeError("Target must be a module or a class") + + +def _io_transform_args(self, init_fn, *args, **kwargs) -> Dict[str, Any]: + """ + Transforms and captures the arguments passed to the `__init__` method, filtering out + any arguments that are instances of `IOProtocol` or are dataclass fields with default + factories. + + Args: + init_fn (Callable): The original `__init__` method of the class. + *args: Variable length argument list for the `__init__` method. + **kwargs: Arbitrary keyword arguments for the `__init__` method. + + Returns + ------- + Dict[str, Any]: A dictionary of the captured and transformed arguments. + """ + sig = inspect.signature(init_fn) + bound_args = sig.bind_partial(self, *args, **kwargs) + bound_args.apply_defaults() + config_kwargs = {k: v for k, v in bound_args.arguments.items() if k != "self"} + + to_del = [] + for key in config_kwargs: + if isinstance(config_kwargs[key], IOProtocol): + config_kwargs[key] = config_kwargs[key].__io__ + if is_dataclass(config_kwargs[key]): + config_kwargs[key] = fdl_dc.convert_dataclasses_to_configs(config_kwargs[key], allow_post_init=True) + # Check if the arg is a factory (dataclasses.field) + if config_kwargs[key].__class__.__name__ == "_HAS_DEFAULT_FACTORY_CLASS": + to_del.append(key) + + for key in to_del: + del config_kwargs[key] + + return config_kwargs + + +def _io_init(self, **kwargs) -> fdl.Config[Self]: + """ + Initializes the configuration object (`__io__`) with the captured arguments. + + Args: + **kwargs: A dictionary of arguments that were captured during object initialization. + + Returns + ------- + fdl.Config[Self]: The initialized configuration object. + """ + return fdl.Config(type(self), **kwargs) + + +def _io_wrap_init(cls): + """Wraps the __init__ method of a class to add IO functionality.""" + original_init = cls.__init__ + + @functools.wraps(original_init) + def wrapped_init(self, *args, **kwargs): + if hasattr(self, "io_transform_args"): + cfg_kwargs = self.io_transform_args(original_init, *args, **kwargs) + else: + cfg_kwargs = _io_transform_args(self, original_init, *args, **kwargs) + if hasattr(self, "io_init"): + self.__io__ = self.io_init(**cfg_kwargs) + else: + self.__io__ = _io_init(self, **cfg_kwargs) + + original_init(self, *args, **kwargs) + + cls.__init__ = wrapped_init + return cls + + +def _io_register_serialization(cls): + serialization.register_node_traverser( + cls, + flatten_fn=_io_flatten_object, + unflatten_fn=_io_unflatten_object, + path_elements_fn=_io_path_elements_fn, + ) + + +def _io_flatten_object(instance): + try: + serialization.dump_json(instance.__io__) + except (serialization.UnserializableValueError, AttributeError) as e: + if not hasattr(_thread_local, "artifacts_dir"): + raise e + + artifact_dir = _thread_local.artifacts_dir + artifact_path = artifact_dir / f"{uuid.uuid4()}" + with open(artifact_path, "wb") as f: + dump(getattr(instance, "__io__", instance), f) + return (str(artifact_path),), None + + return instance.__io__.__flatten__() + + +def _io_unflatten_object(values, metadata): + if len(values) == 1: + pickle_path = values[0] + with open(pickle_path, "rb") as f: + return load(f) + + return fdl.Config.__unflatten__(values, metadata) + + +def _io_path_elements_fn(x): + try: + serialization.dump_json(x.__io__) + except (serialization.UnserializableValueError, AttributeError) as e: + return (serialization.IdentityElement(),) + + return x.__io__.__path_elements__() + + +def _artifact_transform(cfg: fdl.Config, output_path: Path): + for artifact in getattr(cfg.__fn_or_cls__, "__io_artifacts__", []): + current_val = getattr(cfg, artifact.attr) + new_val = artifact.dump(current_val, output_path) + setattr(cfg, artifact.attr, new_val) + + for attr in dir(cfg): + try: + if isinstance(getattr(cfg, attr), fdl.Config): + _artifact_transform(getattr(cfg, attr), output_path=output_path) + except ValueError: + pass diff --git a/nemo/lightning/io/pl.py b/nemo/lightning/io/pl.py index 72490c5d17a55..2cadc56e59b42 100644 --- a/nemo/lightning/io/pl.py +++ b/nemo/lightning/io/pl.py @@ -1,21 +1,34 @@ import logging from dataclasses import dataclass, field from pathlib import Path -from typing import TYPE_CHECKING, Any, Callable, Dict, Generic, Optional, Protocol, TypeVar, Union +from typing import Any, Callable, Dict, Generic, Optional, TypeVar, Union import pytorch_lightning as pl import torch from lightning_fabric.plugins.io.checkpoint_io import CheckpointIO from lightning_fabric.utilities.cloud_io import get_filesystem from lightning_fabric.utilities.types import _PATH +from megatron.core.dist_checkpointing.serialization import ( + get_default_load_sharded_strategy, + get_default_save_sharded_strategy, +) + +# from nemo.utils.callbacks.torch_dist_async import TorchDistAsyncSaveShardedStrategy +from megatron.core.dist_checkpointing.strategies import tensorstore +from megatron.core.dist_checkpointing.strategies.async_utils import AsyncCallsQueue, AsyncRequest +from megatron.core.dist_checkpointing.strategies.base import SaveShardedStrategy +from megatron.core.dist_checkpointing.strategies.fully_parallel import ( + FullyParallelLoadStrategyWrapper, + FullyParallelSaveStrategyWrapper, +) +from megatron.core.dist_checkpointing.strategies.torch import TorchDistSaveShardedStrategy +from megatron.core.parallel_state import get_data_parallel_group from torch import nn from typing_extensions import Self, override from nemo.lightning.io.capture import IOProtocol from nemo.lightning.io.mixin import IOMixin - -if TYPE_CHECKING: - from nemo.lightning.pytorch.strategies import MegatronStrategy +from nemo.utils.callbacks.dist_ckpt_io import AsyncCompatibleCheckpointIO log = logging.getLogger(__name__) @@ -25,40 +38,30 @@ @dataclass -class TrainerCheckpoint(IOMixin, Generic[LightningModuleT]): +class TrainerContext(IOMixin, Generic[LightningModuleT]): model: LightningModuleT trainer: pl.Trainer extra: Dict[str, Any] = field(default_factory=dict) @classmethod - def from_strategy(cls, strategy: "MegatronStrategy") -> Self: - if not isinstance(strategy.trainer, IOProtocol): + def from_trainer(cls, trainer: pl.Trainer) -> Self: + if not hasattr(trainer, "__io__"): raise ValueError(f"Trainer must be an instance of {IOProtocol}. Please use the Trainer from nemo.") - - if not isinstance(strategy.lightning_module, IOProtocol): + if not hasattr(trainer.lightning_module, "__io__"): raise ValueError("LightningModule must extend IOMixin.") - return cls(trainer=strategy.trainer, model=strategy.lightning_module, extra=cls.construct_extra(strategy)) + return cls(trainer=trainer, model=trainer.lightning_module, extra=cls.construct_extra(trainer)) @classmethod - def construct_extra(cls, strategy: "MegatronStrategy") -> Dict[str, Any]: + def construct_extra(cls, trainer: pl.Trainer) -> Dict[str, Any]: extra = {} - if hasattr(strategy.trainer, "datamodule") and isinstance(strategy.trainer.datamodule, IOProtocol): - extra["datamodule"] = strategy.trainer.datamodule.__io__ - - # TODO: Add optimizer to extra + if hasattr(trainer, "datamodule") and hasattr(trainer.datamodule, "__io__"): + extra["datamodule"] = trainer.datamodule.__io__ return extra -class TrainerCkptProtocol(Protocol): - @classmethod - def from_strategy(cls, strategy: "MegatronStrategy") -> Self: ... - - def io_dump(self, output: Path): ... - - -class MegatronCheckpointIO(CheckpointIO): +class MegatronCheckpointIO(AsyncCompatibleCheckpointIO, IOMixin): """CheckpointIO that utilizes :func:`torch.save` and :func:`torch.load` to save and load checkpoints respectively, common for most use cases. @@ -68,10 +71,24 @@ class MegatronCheckpointIO(CheckpointIO): def __init__( self, - save_ckpt_format: str = 'zarr', + save_ckpt_format: str = 'torch_dist', + load_directly_on_device: bool = True, + async_save: bool = False, + torch_dist_multiproc: Optional[int] = None, + assume_constant_structure: bool = False, + parallel_save: bool = True, + parallel_load: bool = False, ): self.save_ckpt_format = save_ckpt_format - self.save_sharded_strategy = self._determine_dist_ckpt_save_strategy() + self.load_directly_on_device = load_directly_on_device + self.async_save = async_save + self.torch_dist_multiproc = torch_dist_multiproc + self.assume_constant_structure = assume_constant_structure + self.parallel_save = parallel_save + self.parallel_load = parallel_load + + self._save_sharded_strategy = None + self.validated_consistency = False @override def save_checkpoint(self, checkpoint: Dict[str, Any], path: _PATH, storage_options: Optional[Any] = None) -> None: @@ -90,11 +107,11 @@ def save_checkpoint(self, checkpoint: Dict[str, Any], path: _PATH, storage_optio """ from megatron.core import dist_checkpointing - if storage_options is not None: - raise TypeError( - "`Trainer.save_checkpoint(..., storage_options=...)` with `storage_options` arg" - f" is not supported for `{self.__class__.__name__}`. Please implement your custom `CheckpointIO`" - " to define how you'd like to use `storage_options`." + if storage_options is not None and len(storage_options) > 0: + logging.warning( + f"{self.__class__.__name__} does not support" + f" storage_options, but {storage_options=} was provided." + f" Ignoring given storage_options" ) checkpoint_dir = ckpt_to_dir(path) fs = get_filesystem(checkpoint_dir) @@ -103,10 +120,14 @@ def save_checkpoint(self, checkpoint: Dict[str, Any], path: _PATH, storage_optio return fs.makedirs(checkpoint_dir, exist_ok=True) - dist_checkpointing.save( - checkpoint, - checkpoint_dir=str(checkpoint_dir), + validate_sharding_integrity = not (self.validated_consistency and self.assume_constant_structure) + self.validated_consistency = True + return dist_checkpointing.save( + sharded_state_dict=checkpoint, + checkpoint_dir=checkpoint_dir, sharded_strategy=self.save_sharded_strategy, + validate_access_integrity=validate_sharding_integrity, + async_sharded_save=self.async_save, ) @override @@ -139,7 +160,24 @@ def load_checkpoint( if not fs.isdir(path): raise ValueError(f"Distributed checkpoints should be a directory. Found: {path}.") - checkpoint = dist_checkpointing.load(sharded_state_dict=sharded_state_dict, checkpoint_dir=str(path)) + if self.save_ckpt_format == 'zarr' and self.load_directly_on_device: + sharded_strategy = tensorstore.TensorStoreLoadShardedStrategy(load_directly_on_device=True) + else: + sharded_strategy = None + + if self.parallel_load: + if sharded_strategy is None: + sharded_strategy = get_default_load_sharded_strategy(path) + sharded_strategy = FullyParallelLoadStrategyWrapper( + sharded_strategy, get_data_parallel_group(with_context_parallel=True) + ) + + if sharded_strategy is not None: + logging.info(f'Using {sharded_strategy} dist-ckpt load strategy.') + + checkpoint = dist_checkpointing.load( + sharded_state_dict=sharded_state_dict, checkpoint_dir=str(path), sharded_strategy=sharded_strategy + ) checkpoint = _fix_tensors_device(checkpoint) return checkpoint @@ -159,14 +197,38 @@ def remove_checkpoint(self, path: _PATH) -> None: def _determine_dist_ckpt_save_strategy(self): """Determine the saving strategy based on constructor args. - If self.async_save is True instantiates an async PyT Dist strategy, - otherwise relies on MCore to create a proper strategy based on ckpt format. + + Relies on the default MCore strategy unless extra PyT Distributed format arguments + are passed in config or in case of a fully parallel save in which case + a parallelization wrapper is applied. """ - save_strategy = (self.save_ckpt_format, 1) + if self.async_save and self.save_ckpt_format != 'torch_dist': + raise ValueError('Async dist-ckpt save supported only for torch_dist format') + + torch_dist_kwargs = {} if self.torch_dist_multiproc is None else dict(thread_count=self.torch_dist_multiproc) + if self.save_ckpt_format == 'torch_dist' and torch_dist_kwargs: + save_strategy = TorchDistSaveShardedStrategy(self.save_ckpt_format, 1, **torch_dist_kwargs) + else: + save_strategy = get_default_save_sharded_strategy(self.save_ckpt_format, 1) + + # MCore v0.8 introduces `use_cached_ckpt_structure` attribute + if hasattr(save_strategy, 'use_cached_ckpt_structure'): + save_strategy.use_cached_ckpt_structure = self.assume_constant_structure + + if self.parallel_save: + save_strategy = FullyParallelSaveStrategyWrapper( + save_strategy, get_data_parallel_group(with_context_parallel=True), self.assume_constant_structure + ) logging.info(f'Using {save_strategy} dist-ckpt save strategy.') return save_strategy + @property + def save_sharded_strategy(self) -> 'SaveShardedStrategy': + if self._save_sharded_strategy is None: + self._save_sharded_strategy = self._determine_dist_ckpt_save_strategy() + return self._save_sharded_strategy + def _fix_tensors_device(ckpt: Dict) -> Dict: """Ensure checkpoint tensors are on the correct device.""" diff --git a/nemo/lightning/megatron_parallel.py b/nemo/lightning/megatron_parallel.py index 4eab2fc4ea387..2f23087170040 100644 --- a/nemo/lightning/megatron_parallel.py +++ b/nemo/lightning/megatron_parallel.py @@ -12,6 +12,7 @@ Iterable, Iterator, List, + Mapping, Optional, Protocol, Sequence, @@ -28,8 +29,10 @@ from megatron.core.distributed import DistributedDataParallelConfig from megatron.core.transformer.transformer_config import TransformerConfig from torch import Tensor, nn +from typing_extensions import override DataT = TypeVar("DataT", Tensor, Dict[str, Tensor], Sequence[Tensor]) +ModelT = TypeVar("ModelT", bound=nn.Module) @runtime_checkable @@ -46,7 +49,7 @@ def default_data_step(dataloader_iter: Iterator[DataT]) -> DataT: batch = batch[0] if isinstance(batch, dict): - batch = {k: v.cuda() for k, v in batch.items()} + batch = {k: v.cuda(non_blocking=True) for k, v in batch.items()} return batch @@ -55,7 +58,21 @@ def default_forward_step(model: nn.Module, batch, *args, **kwargs) -> torch.Tens return model(batch, *args, **kwargs) -class MegatronParallel(nn.ModuleList): +def extract_ddp_funcs(ddp_config, pipeline): + no_sync_func, grad_sync_func = None, None + + if getattr(ddp_config, "overlap_grad_reduce", False): + no_sync_func = [model_chunk.no_sync for model_chunk in pipeline] + no_sync_func = no_sync_func[0] if len(pipeline) == 1 else no_sync_func + # TODO(@akoumparouli): why is True default here? + if getattr(ddp_config, "delay_grad_reduce", True): + grad_sync_func = [model_chunk.start_grad_sync for model_chunk in pipeline] + grad_sync_func = grad_sync_func[0] if len(pipeline) == 1 else grad_sync_func + + return no_sync_func, grad_sync_func + + +class MegatronParallel(nn.ModuleList, Generic[ModelT]): """Implements distributed model parallelism that is based on Megatron-LM. This supports various forms of parallelism: @@ -101,16 +118,16 @@ class MegatronParallel(nn.ModuleList): def __init__( self, - pipeline: Union[nn.Module, Iterable[nn.Module]], + pipeline: Union[ModelT, Iterable[ModelT]], precision_plugin: Optional[PrecisionPluginProtocol] = None, callbacks: Optional["CallbackConnector"] = None, data_step: Optional[Callable[[Iterator[DataT]], DataT]] = None, - forward_step: Optional[Callable[[nn.Module, DataT], Tensor]] = None, - loss_reduction: Optional[Callable[[nn.Module], "MegatronLossReduction"]] = None, + forward_step: Optional[Callable[[ModelT, DataT], Tensor]] = None, + loss_reduction: Optional[Callable[[ModelT], "MegatronLossReduction"]] = None, vp_size: Optional[int] = None, ddp_config: Optional[DistributedDataParallelConfig] = None, cpu: bool = False, - convert_module_fn: Optional[Callable[[nn.Module], nn.Module]] = None, + convert_module_fn: Optional[Callable[[ModelT], nn.Module]] = None, ) -> None: from apex.transformer.tensor_parallel.layers import set_defaults_if_not_set_tensor_model_parallel_attributes from megatron.core import parallel_state @@ -157,6 +174,12 @@ def __init__( model_chunk.buffers = ddp.buffers # We need to do this explicitly since this is a attr pytorch uses model_chunk.__class__.__getattr__ = getattr_proxy # type: ignore + # param_sync_func is set in nemo.lightning.pytorch.optim.megatron + no_sync_func, grad_sync_func = extract_ddp_funcs(ddp_config, _pipeline) + for module in _pipeline: + module.config.no_sync_func = no_sync_func + module.config.grad_sync_func = grad_sync_func + for i, model_module in enumerate(_pipeline): if not cpu: model_module.cuda(torch.cuda.current_device()) @@ -277,7 +300,7 @@ def forward( if forward_only: loss_mean = cast(torch.Tensor, []) else: - loss_mean = torch.tensor(0.0).cuda() + loss_mean = torch.tensor(0.0, device=torch.cuda.current_device()) self.callbacks.event("on_megatron_log_step_end", **context) self.callbacks.event("on_megatron_step_end", **context) @@ -503,7 +526,7 @@ def sharded_state_dict(self, prefix: str = "") -> Dict[str, Any]: # virtual pipline rank must be set so that GPTModel returns the correct sharded state dict parallel_state.set_virtual_pipeline_model_parallel_rank(index) module_sharded_state_dict = self._module_sharded_state_dict(module) - sharded_state_dict[f"megatron_module_{index}"] = module_sharded_state_dict + sharded_state_dict[f"model_{index}"] = module_sharded_state_dict else: module_sharded_state_dict = self._module_sharded_state_dict(module) sharded_state_dict.update(module_sharded_state_dict) @@ -524,18 +547,37 @@ def _module_sharded_state_dict(self, module, *args, **kwargs) -> Dict[str, Any]: raise ValueError("Could not find sharded state dict") @property - def pipeline(self) -> Union[nn.Module, List[nn.Module]]: + def pipeline(self) -> Union[ModelT, List[ModelT]]: if len(self) == 1: return self[0] else: return list(self) + @property + def module(self) -> ModelT: + return self[0] + @property def forward_backward_func(self) -> "MegatronStepProtocol": from megatron.core.pipeline_parallel.schedules import get_forward_backward_func return get_forward_backward_func() + @override + def __getattr__(self, item: Any) -> Any: + if len(self) == 0: + return super().__getattr__(item) + + try: + # __getattr__ gets called as a last resort if the attribute does not exist + # call nn.Module's implementation first + return super().__getattr__(item) + except AttributeError: + # If the attribute is not available on the _FabricModule wrapper, redirect to the wrapped nn.Module + attr = getattr(self._modules[self._get_abs_string_index(0)], item) + + return attr + class _ModuleStepFunction: def __init__(self, name: str, is_property: bool = False, includes_self: bool = False): @@ -976,7 +1018,7 @@ def forward( loss_sum_and_ub_size_all_gpu = torch.cat( [ loss_sum_for_ub.clone().detach().view(1), - torch.tensor([num_valid_tokens_in_ub]).cuda().clone().detach(), + torch.tensor([num_valid_tokens_in_ub], device=torch.cuda.current_device()).clone().detach(), ] ) torch.distributed.all_reduce(loss_sum_and_ub_size_all_gpu, group=parallel_state.get_data_parallel_group()) @@ -1003,11 +1045,11 @@ def reduce(self, losses_reduced_per_micro_batch) -> torch.Tensor: loss_sum = ( torch.vstack(loss_sum_tensors_list).sum(dim=0) if len(loss_sum_tensors_list) > 0 - else torch.tensor([0.0, 0.0]).cuda() + else torch.tensor([0.0, 0.0], device=torch.cuda.current_device()) ) return loss_sum - return torch.tensor(0.0).cuda() + return torch.tensor(0.0, device=torch.cuda.current_device()) def masked_token_loss(tensor: Tensor, mask: Tensor): diff --git a/nemo/lightning/nemo_logger.py b/nemo/lightning/nemo_logger.py index fbf9298dfec48..5ed783fdbefec 100644 --- a/nemo/lightning/nemo_logger.py +++ b/nemo/lightning/nemo_logger.py @@ -1,21 +1,24 @@ import os import sys import time -from dataclasses import dataclass +from dataclasses import dataclass, field from pathlib import Path from typing import List, Optional, Union import lightning_fabric as fl import pytorch_lightning as pl +from fiddle._src.experimental import serialization from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint as PTLModelCheckpoint +from pytorch_lightning.loggers import Logger, TensorBoardLogger, WandbLogger +from nemo.lightning.io.mixin import IOMixin from nemo.lightning.pytorch.callbacks import ModelCheckpoint from nemo.utils import logging from nemo.utils.app_state import AppState @dataclass -class NeMoLogger: +class NeMoLogger(IOMixin): """Logger for NeMo runs. Args: @@ -41,6 +44,9 @@ class NeMoLogger: files_to_copy: Optional[List[str]] = None update_logger_directory: bool = True ckpt: Optional[ModelCheckpoint] = None + tensorboard: Optional[TensorBoardLogger] = None + wandb: Optional[WandbLogger] = None + extra_loggers: List[Logger] = field(default_factory=list) def __post_init__(self): if self.log_local_rank_0_only is True and self.log_global_rank_0_only is True: @@ -48,11 +54,7 @@ def __post_init__(self): f"Cannot set both log_local_rank_0_only and log_global_rank_0_only to True. Please set either one or neither." ) - def setup( - self, - trainer: Union[pl.Trainer, fl.Fabric], - resume_if_exists: bool = False, - ): + def setup(self, trainer: Union[pl.Trainer, fl.Fabric], resume_if_exists: bool = False, task_config=None): """Setup the logger for the experiment. Args: @@ -62,15 +64,13 @@ def setup( Returns: AppState: The application state with updated log directory and other settings. """ - from nemo.constants import NEMO_ENV_VARNAME_TESTING, NEMO_ENV_VARNAME_VERSION - from nemo.utils.env_var_parsing import get_envbool + from nemo.constants import NEMO_ENV_VARNAME_VERSION from nemo.utils.exp_manager import check_explicit_log_dir from nemo.utils.get_rank import is_global_rank_zero - from nemo.utils.mcore_logger import add_handlers_to_mcore_logger - local_rank = int(os.environ.get("LOCAL_RANK", 0)) - global_rank = trainer.node_rank * trainer.world_size + local_rank - logging.rank = global_rank + self.local_rank = int(os.environ.get("LOCAL_RANK", 0)) + self.global_rank = trainer.node_rank * trainer.world_size + self.local_rank + logging.rank = self.global_rank if self.explicit_log_dir and isinstance(trainer, pl.Trainer): # If explicit log_dir was passed, short circuit return check_explicit_log_dir(trainer, self.explicit_log_dir, self.dir, self.name, self.version) @@ -83,14 +83,6 @@ def setup( if not self.name: self.name = "default" - if isinstance(trainer, pl.Trainer) and trainer.logger is not None: - if self.update_logger_directory: - logging.warning( - f'"update_logger_directory" is True. Overwriting logger "save_dir" to {_dir} and "name" to {self.name}' - ) - trainer.logger._root_dir = _dir - trainer.logger._name = self.name - version = self.version or os.environ.get(NEMO_ENV_VARNAME_VERSION, None) if is_global_rank_zero(): if self.use_datetime_version: @@ -100,7 +92,6 @@ def setup( "No version folders would be created under the log folder as 'resume_if_exists' is enabled." ) version = None - trainer.logger._version = version or "" if version: if is_global_rank_zero(): os.environ[NEMO_ENV_VARNAME_VERSION] = version @@ -112,80 +103,120 @@ def setup( app_state.exp_dir = _dir app_state.name = self.name app_state.version = version + app_state.cmd_args = sys.argv os.makedirs(log_dir, exist_ok=True) # Cannot limit creation to global zero as all ranks write to own log file logging.info(f'Experiments will be logged at {log_dir}') + if task_config and is_global_rank_zero(): + self._handle_task_config(task_config, log_dir) + if isinstance(trainer, pl.Trainer): - if self.ckpt: - _overwrite_i = None - for i, callback in enumerate(trainer.callbacks): - if isinstance(callback, PTLModelCheckpoint): - logging.warning( - "The Trainer already contains a ModelCheckpoint callback. " "This will be overwritten." - ) - _overwrite_i = i - break - if _overwrite_i is not None: - trainer.callbacks[_overwrite_i] = self.ckpt - else: - trainer.callbacks.append(self.ckpt) - - if self.ckpt.monitor and "val" in self.ckpt.monitor: - if ( - trainer.max_epochs is not None - and trainer.max_epochs != -1 - and trainer.max_epochs < trainer.check_val_every_n_epoch - ): - logging.error( - "The checkpoint callback was told to monitor a validation value but trainer.max_epochs(" - f"{trainer.max_epochs}) was less than trainer.check_val_every_n_epoch({trainer.check_val_every_n_epoch}" - f"). It is very likely this run will fail with ModelCheckpoint(monitor='{self.ckpt.monitor}') not found " - "in the returned metrics. Please ensure that validation is run within trainer.max_epochs." - ) - elif trainer.max_steps is not None and trainer.max_steps != -1: - logging.warning( - "The checkpoint callback was told to monitor a validation value and trainer's max_steps was set to " - f"{trainer.max_steps}. Please ensure that max_steps will run for at least " - f"{trainer.check_val_every_n_epoch} epochs to ensure that checkpointing will not error out." - ) - - for callback in trainer.callbacks: + self._setup_trainer_loggers(trainer, _dir, version) + self._setup_trainer_model_checkpoint(trainer, log_dir=log_dir, ckpt=self.ckpt) + + self._setup_files_to_move(log_dir, app_state) + self._setup_file_logging(log_dir) + + return app_state + + def _setup_trainer_loggers(self, trainer, dir, version): + loggers = [self.tensorboard, self.wandb, *self.extra_loggers] + loggers = [logger for logger in loggers if logger is not None] + + if self.update_logger_directory and self.wandb: + self.wandb._save_dir = dir + self.wandb._wandb_init["dir"] = dir + self.wandb._wandb_init["name"] = self.name + self.wandb._name = self.name + + if loggers: + if trainer.logger is not None and not self.tensorboard: + loggers = [trainer.logger] + loggers + trainer._logger_connector.configure_logger(loggers) + + if trainer.logger is not None: + trainer.logger._version = version or "" + if self.update_logger_directory: + logging.warning( + f'"update_logger_directory" is True. Overwriting logger "save_dir" to {dir} and "name" to {self.name}' + ) + trainer.logger._root_dir = dir + trainer.logger._name = self.name + + def _setup_trainer_model_checkpoint(self, trainer, log_dir, ckpt=None): + if ckpt: + _overwrite_i = None + for i, callback in enumerate(trainer.callbacks): if isinstance(callback, PTLModelCheckpoint): - if callback.dirpath is None: - callback.dirpath = Path(log_dir / "checkpoints") - if callback.filename is None: - callback.filename = f'{self.name}--{{{callback.monitor}:.4f}}-{{epoch}}' - ModelCheckpoint.CHECKPOINT_NAME_LAST = callback.filename + '-last' + logging.warning( + "The Trainer already contains a ModelCheckpoint callback. " "This will be overwritten." + ) + _overwrite_i = i + break + if _overwrite_i is not None: + trainer.callbacks[_overwrite_i] = ckpt + else: + trainer.callbacks.append(ckpt) + + if ckpt.monitor and "val" in ckpt.monitor: + if ( + trainer.max_epochs is not None + and trainer.max_epochs != -1 + and trainer.max_epochs < trainer.check_val_every_n_epoch + ): + logging.error( + "The checkpoint callback was told to monitor a validation value but trainer.max_epochs(" + f"{trainer.max_epochs}) was less than trainer.check_val_every_n_epoch({trainer.check_val_every_n_epoch}" + f"). It is very likely this run will fail with ModelCheckpoint(monitor='{ckpt.monitor}') not found " + "in the returned metrics. Please ensure that validation is run within trainer.max_epochs." + ) + elif trainer.max_steps is not None and trainer.max_steps != -1: + logging.warning( + "The checkpoint callback was told to monitor a validation value and trainer's max_steps was set to " + f"{trainer.max_steps}. Please ensure that max_steps will run for at least " + f"{trainer.check_val_every_n_epoch} epochs to ensure that checkpointing will not error out." + ) + + for callback in trainer.callbacks: + if isinstance(callback, PTLModelCheckpoint): + if callback.dirpath is None: + callback.dirpath = Path(log_dir / "checkpoints") + if callback.filename is None: + callback.filename = f'{self.name}--{{{callback.monitor}:.4f}}-{{epoch}}' + ModelCheckpoint.CHECKPOINT_NAME_LAST = callback.filename + '-last' + + def _handle_task_config(self, task_config, log_dir): + task_config.save_config_img(log_dir / "task.png") + task_json = serialization.dump_json(task_config) + with open(log_dir / "task.json", "w") as f: + f.write(task_json) + + def _setup_file_logging(self, log_dir): + """Set up file logging based on rank settings.""" + from nemo.constants import NEMO_ENV_VARNAME_TESTING + from nemo.utils.env_var_parsing import get_envbool + from nemo.utils.mcore_logger import add_handlers_to_mcore_logger # This is set if the env var NEMO_TESTING is set to True. nemo_testing = get_envbool(NEMO_ENV_VARNAME_TESTING, False) + log_file = log_dir / f'nemo_log_globalrank-{self.global_rank}_localrank-{self.local_rank}.txt' + + if self.log_local_rank_0_only and not nemo_testing and self.local_rank == 0: + logging.add_file_handler(log_file) + elif self.log_global_rank_0_only and not nemo_testing and self.global_rank == 0: + logging.add_file_handler(log_file) + elif not (self.log_local_rank_0_only or self.log_global_rank_0_only): + logging.add_file_handler(log_file) + + add_handlers_to_mcore_logger() + def _setup_files_to_move(self, log_dir, app_state): files_to_move = [] if Path(log_dir).exists(): for child in Path(log_dir).iterdir(): if child.is_file(): files_to_move.append(child) - # Handle logging to file - log_file = log_dir / f'nemo_log_globalrank-{global_rank}_localrank-{local_rank}.txt' - if self.log_local_rank_0_only is True and not nemo_testing: - if local_rank == 0: - logging.add_file_handler(log_file) - elif self.log_global_rank_0_only is True and not nemo_testing: - if global_rank == 0: - logging.add_file_handler(log_file) - else: - # Logs on all ranks. - logging.add_file_handler(log_file) - - add_handlers_to_mcore_logger() - app_state.files_to_move = files_to_move app_state.files_to_copy = self.files_to_copy - app_state.cmd_args = sys.argv - - return app_state - - def teardown(self): - pass diff --git a/nemo/lightning/pytorch/callbacks/__init__.py b/nemo/lightning/pytorch/callbacks/__init__.py index 1525ab21b8357..ee0e777d739e7 100644 --- a/nemo/lightning/pytorch/callbacks/__init__.py +++ b/nemo/lightning/pytorch/callbacks/__init__.py @@ -1,7 +1,9 @@ -from nemo.lightning.pytorch.callbacks.megatron_model_checkpoint import ModelCheckpoint +from nemo.lightning.pytorch.callbacks.model_checkpoint import ModelCheckpoint +from nemo.lightning.pytorch.callbacks.model_transform import ModelTransform +from nemo.lightning.pytorch.callbacks.nsys import NsysCallback +from nemo.lightning.pytorch.callbacks.peft import PEFT +from nemo.lightning.pytorch.callbacks.preemption import PreemptionCallback from nemo.lightning.pytorch.callbacks.progress import MegatronProgressBar -__all__ = [ - "MegatronProgressBar", - "ModelCheckpoint", -] + +__all__ = ["ModelCheckpoint", "ModelTransform", "PEFT", "NsysCallback", "MegatronProgressBar", "PreemptionCallback"] diff --git a/nemo/lightning/pytorch/callbacks/megatron_model_checkpoint.py b/nemo/lightning/pytorch/callbacks/model_checkpoint.py similarity index 83% rename from nemo/lightning/pytorch/callbacks/megatron_model_checkpoint.py rename to nemo/lightning/pytorch/callbacks/model_checkpoint.py index 44b1ab238198a..83e750ff281ea 100644 --- a/nemo/lightning/pytorch/callbacks/megatron_model_checkpoint.py +++ b/nemo/lightning/pytorch/callbacks/model_checkpoint.py @@ -26,12 +26,15 @@ from pytorch_lightning.callbacks.model_checkpoint import _is_local_file_protocol from pytorch_lightning.utilities import rank_zero_info +from nemo.lightning.io.mixin import IOMixin +from nemo.lightning.io.pl import TrainerContext from nemo.utils import logging from nemo.utils.app_state import AppState +from nemo.utils.callbacks.dist_ckpt_io import AsyncFinalizableCheckpointIO from nemo.utils.model_utils import ckpt_to_dir -class ModelCheckpoint(PTLModelCheckpoint): +class ModelCheckpoint(PTLModelCheckpoint, IOMixin): UNFINISHED_CHECKPOINT_SUFFIX = "-unfinished" @@ -48,10 +51,21 @@ def __init__( train_time_interval: Optional[timedelta] = None, save_best_model: bool = False, save_on_train_epoch_end: Optional[bool] = False, # Save after training, not after validation + enable_nemo_ckpt_io: bool = True, + async_save: bool = False, + try_restore_best_ckpt: bool = True, **kwargs, ): self.save_best_model = save_best_model self.previous_best_path = "" + self.enable_nemo_ckpt_io = enable_nemo_ckpt_io + self.async_save = async_save + # Checkpoints which removal is deferred until async save is done. + # Each element of `deferred_ckpts_to_remove` is a growing list + # that `self._remove_checkpoint` adds to. Once `self._save_checkpoint` + # is called, the last element is frozen and a new element is added. + self.deferred_ckpts_to_remove: List[List[str]] = [] + self.try_restore_best_ckpt = try_restore_best_ckpt # Call the parent class constructor with the remaining kwargs. super().__init__( @@ -92,26 +106,34 @@ def on_train_start(self, trainer, pl_module): if fold.is_dir(): run_count += 1 new_run_dir = Path(Path(log_dir) / f"run_{run_count}") - new_run_dir.mkdir() - for _file in files_to_move: - shutil.move(str(_file), str(new_run_dir)) + if not new_run_dir.exists(): + new_run_dir.mkdir() + for _file in files_to_move: + shutil.move(str(_file), str(new_run_dir)) # Move files_to_copy to folder and add git information if present if app_state.files_to_copy: for _file in app_state.files_to_copy: - shutil.copy(Path(_file), log_dir) + src_path = Path(_file) + dst_path = Path(log_dir) / src_path.name + if not dst_path.exists(): + shutil.copy(src_path, dst_path) # Create files for cmd args and git info if app_state.cmd_args: - with open(log_dir / 'cmd-args.log', 'w', encoding='utf-8') as _file: - _file.write(" ".join(app_state.cmd_args)) + cmd_args_file = log_dir / 'cmd-args.log' + if not cmd_args_file.exists(): + with open(cmd_args_file, 'w', encoding='utf-8') as _file: + _file.write(" ".join(app_state.cmd_args)) # Try to get git hash git_repo, git_hash = get_git_hash() if git_repo: - with open(log_dir / 'git-info.log', 'w', encoding='utf-8') as _file: - _file.write(f'commit hash: {git_hash}') - _file.write(get_git_diff()) + git_info_file = log_dir / 'git-info.log' + if not git_info_file.exists(): + with open(git_info_file, 'w', encoding='utf-8') as _file: + _file.write(f'commit hash: {git_hash}\n') + _file.write(get_git_diff()) # Add err_file logging to global_rank zero logging.add_err_file_handler(log_dir / 'nemo_error_log.txt') @@ -220,13 +242,7 @@ def on_train_end(self, trainer, pl_module): return None # check if we need to save a last checkpoint manually as validation isn't always run based on the interval - ## TODO: there is some sort of bug in this code. - ## this is what is causing the failure with async checkpointing when "epoch" is part of the ckpt name - ## I think this is unnecessary because we will automatically save a final checkpoint - ## during on_train_batch_end - ## see https://github.com/Lightning-AI/pytorch-lightning/blob/f6fd046552a1504023cb3386a8a0df418a810e4f/src/lightning/pytorch/callbacks/model_checkpoint.py#L315 - ## we should change the logic to only save a final checkpoint if it wasn't just saveds - '''if self.save_last and trainer.val_check_interval != 0: + if self.save_last and trainer.val_check_interval != 0: should_save_last_checkpoint = False if isinstance(trainer.val_check_interval, float) and trainer.val_check_interval % trainer.global_step != 0: should_save_last_checkpoint = True @@ -237,7 +253,7 @@ def on_train_end(self, trainer, pl_module): if self.last_model_path == self.format_checkpoint_name(monitor_candidates, self.CHECKPOINT_NAME_LAST): logging.debug(f'Last checkpoint {self.last_model_path} already saved') else: - super()._save_last_checkpoint(trainer, monitor_candidates)''' + super()._save_last_checkpoint(trainer, monitor_candidates) # Call parent on_train_end() to save the -last checkpoint super().on_train_end(trainer, pl_module) @@ -254,8 +270,9 @@ def on_train_end(self, trainer, pl_module): else: if os.path.isdir(self.best_model_path.split('.ckpt')[0]): self.best_model_path = self.best_model_path.split('.ckpt')[0] - self.best_model_path = trainer.strategy.broadcast(self.best_model_path) - trainer._checkpoint_connector.restore(self.best_model_path) + if self.try_restore_best_ckpt: + self.best_model_path = trainer.strategy.broadcast(self.best_model_path) + trainer._checkpoint_connector.restore(self.best_model_path) def _del_model_without_trainer(self, filepath: str) -> None: from nemo.utils.get_rank import is_global_rank_zero @@ -363,7 +380,10 @@ def _save_checkpoint(self, trainer: 'pytorch_lightning.Trainer', filepath: str) # if anything goes wrong during checkpointing, we should be able to detect that data is incomplete. self.set_checkpoint_unfinished_marker(filepath, barrier_after=True) ema_callback = self._ema_callback(trainer) + if ema_callback is not None: + if self.async_save: + raise ValueError('async_save with EMA not supported') with ema_callback.save_original_optimizer_state(trainer): super()._save_checkpoint(trainer, filepath) @@ -376,10 +396,23 @@ def _save_checkpoint(self, trainer: 'pytorch_lightning.Trainer', filepath: str) super()._save_checkpoint(trainer, filepath) self.remove_checkpoint_unfinished_marker(filepath, barrier_before=True) else: + # Async save passes the finalization function to checkpoint_io, + # sync save calls the finalization function immediately after save. finalize_fn = self._get_finalize_save_checkpoint_callback(trainer, filepath, trainer.global_step) - storage_options = None + if self.async_save: + checkpoint_io = trainer.strategy.checkpoint_io + if not isinstance(checkpoint_io, AsyncFinalizableCheckpointIO): + raise ValueError('Async save requires async compatible CheckpointIO') + storage_options = dict(finalize_fn=finalize_fn) + # Each upcoming ckpt removal request will be executed as part of this save finalization + self.deferred_ckpts_to_remove.append([]) + else: + storage_options = None trainer.save_checkpoint(filepath, self.save_weights_only, storage_options=storage_options) - finalize_fn() + if self.async_save: + logging.info(f'Scheduled async checkpoint save for {filepath}') + else: + finalize_fn() def _get_finalize_save_checkpoint_callback( self, trainer: 'pytorch_lightning.Trainer', filepath: str, global_step: int @@ -391,6 +424,11 @@ def _cb(): self._last_global_step_saved = global_step self._last_checkpoint_saved = filepath + from nemo.utils.get_rank import is_global_rank_zero + + if self.enable_nemo_ckpt_io and is_global_rank_zero(): + TrainerContext.from_trainer(trainer).io_dump(ckpt_to_dir(filepath)) + # notify loggers if trainer.is_global_zero: for logger in trainer.loggers: @@ -400,10 +438,32 @@ def _cb(): # we don't want to remove the marker until all checkpointing is done. self.remove_checkpoint_unfinished_marker(filepath, barrier_before=True) + if not self.async_save: + return + + logging.info(f'Async checkpoint save for step {global_step} ({filepath}) finalized successfully.') + + # Remove checkpoints marked for removal by `self._remove_checkpoint` + # For each finalization there is exactly one entry in self.deferred_ckpts_to_remove + assert self.deferred_ckpts_to_remove + ckpts_to_remove = self.deferred_ckpts_to_remove.pop(0) + logging.debug(f'Checkpoints to remove: {ckpts_to_remove}') + for ckpt_to_remove in ckpts_to_remove: + self._remove_checkpoint(trainer, ckpt_to_remove, override_async=True) + return _cb def _remove_checkpoint(self, trainer: "pytorch_lightning.Trainer", filepath: str, override_async=False) -> None: - """Performs checkpoint removal.""" + """Performs checkpoint removal. + + With async save, `self._remove_checkpoint` is called before the checkpoint + is actually finished so we can't remove it. Instead we add it to + `self.deferred_ckpts_to_remove` for future removal. + """ + if self.async_save and not override_async: + # Register checkpoint removal in the last (active) checkpoint removal list + self.deferred_ckpts_to_remove[-1].append(filepath) + return # barrier_after=True, so all ranks continue after the unfinished checkpoint marker is placed. # if anything goes wrong during removal, we should be able to detect that data is incomplete. self.set_checkpoint_unfinished_marker(filepath, barrier_after=True) @@ -411,6 +471,7 @@ def _remove_checkpoint(self, trainer: "pytorch_lightning.Trainer", filepath: str ema_callback = self._ema_callback(trainer) if ema_callback is not None: # remove EMA copy of the state dict as well. + filepath = self._ema_format_filepath(filepath) super()._remove_checkpoint(trainer, filepath) # barrier_before=True, so all ranks synchronize before removing the unfinished checkpoint marker diff --git a/nemo/lightning/pytorch/callbacks/model_transform.py b/nemo/lightning/pytorch/callbacks/model_transform.py new file mode 100644 index 0000000000000..68b3db16f4734 --- /dev/null +++ b/nemo/lightning/pytorch/callbacks/model_transform.py @@ -0,0 +1,98 @@ +from functools import wraps +from typing import Any, Callable, Optional, TypeVar + +import pytorch_lightning as pl +from torch import nn + +from nemo.lightning.io.mixin import IOMixin +from nemo.utils import logging + + +class ModelTransform(pl.Callback, IOMixin): + """ + A PyTorch Lightning callback that applies a model transformation function at the start of fitting or validation. + + This callback is designed to apply a transformation to the model when fitting or validation begins. + This design allows for loading the original checkpoint first and then applying the transformation, + which is particularly useful for techniques like Parameter-Efficient Fine-Tuning (PEFT). + + The transformation function is expected to be defined on the LightningModule + as an attribute called 'model_transform'. + + Key Features: + - Applies transformation at the start of fit or validation, not during initialization. + - Allows loading of original checkpoints before transformation. + - Supports PEFT and similar techniques that modify model structure. + + Example: + >>> class MyLightningModule(pl.LightningModule): + ... def __init__(self): + ... super().__init__() + ... self.model = SomeModel() + ... self.model_transform = lambda m: SomePEFTMethod()(m) + ... + >>> model = MyLightningModule() + >>> # Load original checkpoint here if needed + >>> model.load_state_dict(torch.load('original_checkpoint.pth')) + >>> trainer = pl.Trainer(callbacks=[ModelTransform()]) + >>> # The model will be transformed when trainer.fit() or trainer.validate() is called + >>> trainer.fit(model) + + Note: + The transformation is applied only once, at the start of fitting or validation, + whichever comes first. This ensures that the model structure is modified before + any forward passes or parameter updates occur, but after the original weights + have been loaded. + """ + + def __init__(self): + super().__init__() + self.model_transform: Optional[Callable[[nn.Module], nn.Module]] = None + + def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: str) -> None: + logging.info(f"Setting up ModelTransform for stage: {stage}") + + if hasattr(pl_module, 'model_transform'): + logging.info("Found model_transform attribute on pl_module") + self.model_transform = _call_counter(pl_module.model_transform) + pl_module.model_transform = self.model_transform + logging.info(f"Set model_transform to: {self.model_transform}") + else: + logging.info("No model_transform attribute found on pl_module") + + def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: + self._maybe_apply_transform(trainer) + + def _maybe_apply_transform(self, trainer): + if self._needs_to_call: + self.model_transform(trainer.model) + + @property + def _needs_to_call(self) -> bool: + return self.model_transform and self.model_transform.__num_calls__ == 0 + + +T = TypeVar('T', bound=Callable[..., Any]) + + +def _call_counter(func: T) -> T: + """ + A decorator that counts the number of times a function is called. + + This decorator wraps a function and adds a '__num_calls__' attribute to it, + which is incremented each time the function is called. + + Args: + func (Callable): The function to be wrapped. + + Returns: + Callable: The wrapped function with a call counter. + """ + + @wraps(func) + def wrapper(*args, **kwargs): + wrapper.__num_calls__ += 1 + return func(*args, **kwargs) + + wrapper.__num_calls__ = 0 + return wrapper # type: ignore diff --git a/nemo/lightning/pytorch/callbacks/nsys.py b/nemo/lightning/pytorch/callbacks/nsys.py index f50fe0481e9d4..d24d7fd974be8 100644 --- a/nemo/lightning/pytorch/callbacks/nsys.py +++ b/nemo/lightning/pytorch/callbacks/nsys.py @@ -1,13 +1,33 @@ -from typing import Any, List, Optional +from typing import List, Optional import torch from pytorch_lightning.callbacks.callback import Callback +from nemo.lightning.io.mixin import IOMixin from nemo.utils import logging from nemo.utils.get_rank import get_rank -class NsysCallback(Callback): +class NsysCallback(Callback, IOMixin): + """ + A PyTorch Lightning callback for NVIDIA Nsight Systems (Nsys) profiling. + + This callback enables profiling of specific steps during training using NVIDIA Nsys. + It allows for precise control over when profiling starts and ends, which ranks are profiled, + and whether to generate detailed shape information. + + More info about nsys can be found [here](https://developer.nvidia.com/nsight-systems). + + Args: + start_step (int): Global batch to start profiling + end_step (int): Global batch to end profiling + ranks (List[int]): Global rank IDs to profile + gen_shape (bool): Generate model and kernel details including input shapes + + Example: + >>> callback = NsysCallback(start_step=100, end_step=200, ranks=[0, 1], gen_shape=True) + >>> trainer = Trainer(callbacks=[callback]) + """ def __init__( self, @@ -16,13 +36,6 @@ def __init__( ranks: List[int] = [0], gen_shape: bool = False, ): - """ - Args: - start_step (int): Global batch to start profiling - end_step (int): Global batch to end profiling - ranks (List[int]): Global rank IDs to profile - gen_shape (bool): Generate model and kernel details including input shapes - """ assert type(start_step) == int, f'Nsys start_step must be of type int. Found: {type(start_step)}' self._nsys_profile_start_step = start_step @@ -54,6 +67,8 @@ def on_train_batch_start(self, trainer, pl_module, batch, batch_idx: int) -> Opt torch.cuda.cudart().cudaProfilerStart() if self._nsys_profile_gen_shape: torch.autograd.profiler.emit_nvtx(record_shapes=True).__enter__() + else: + torch.autograd.profiler.emit_nvtx().__enter__() def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx: int) -> None: """PyTorch Lightning hook: @@ -63,7 +78,7 @@ def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx: int) device = trainer.strategy.root_device if device.type == 'cuda': - print(f'batch idx: {batch_idx}') if batch_idx == self._nsys_profile_end_step and get_rank() in self._nsys_profile_ranks: logging.info("====== End nsys profiling ======") torch.cuda.cudart().cudaProfilerStop() + torch.autograd.profiler.emit_nvtx().__exit__(None, None, None) diff --git a/nemo/lightning/pytorch/callbacks/peft.py b/nemo/lightning/pytorch/callbacks/peft.py new file mode 100644 index 0000000000000..26325bf549d03 --- /dev/null +++ b/nemo/lightning/pytorch/callbacks/peft.py @@ -0,0 +1,261 @@ +import json +from abc import ABC, abstractmethod +from pathlib import Path +from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Tuple + +import pytorch_lightning as pl +import torch.nn as nn +from lightning_fabric.utilities.types import _PATH +from pytorch_lightning.plugins.io.wrapper import _WrappingCheckpointIO +from typing_extensions import override + +from nemo.lightning.io.pl import ckpt_to_dir +from nemo.lightning.pytorch.callbacks.model_transform import ModelTransform +from nemo.utils import logging + +if TYPE_CHECKING: + from megatron.core.dist_checkpointing.mapping import ShardedStateDict + + +_ADAPTER_META_FILENAME = "adapter_metadata.json" + + +class PEFT(ABC, ModelTransform): + """Abstract base class for Parameter-Efficient Fine-Tuning (PEFT) methods. + + This class defines the interface for PEFT methods, which are used to fine-tune + large language models efficiently by modifying only a small subset of the model's + parameters. + + Example: + class MyPEFT(PEFT): + def transform(self, module, name=None, prefix=None): + # Implement the transform logic + pass + + + peft = MyPEFT() + peft_model = LargeLanguageModel(model_transform=peft) + """ + + @abstractmethod + def transform(self, module, name=None, prefix=None): + """Transform a single module according to the PEFT method. + + This method is called for each module in the model during the PEFT application process. + It should be implemented by subclasses to define how individual modules are transformed + for the specific PEFT technique. + + Args: + module (nn.Module): The individual module to be transformed. + name (Optional[str]): The name of the module within the model structure. Defaults to None. + prefix (Optional[str]): A prefix to be added to the module name, typically used for + nested modules. Defaults to None. + + Returns: + nn.Module: The transformed module. This can be the original module with modifications, + a new module replacing the original, or the original module if no + transformation is needed for this specific module. + + Note: + This method is automatically called for each module in the model when the PEFT + instance is applied to the model using the __call__ method. + """ + raise NotImplementedError("The transform method should be implemented by subclasses.") + + def __call__(self, model: nn.Module) -> nn.Module: + """Apply the PEFT method to the entire model. + + This method freezes the model parameters and walks through the model + structure, applying the transform method to each module. + + Args: + model (nn.Module): The model to be fine-tuned. + + Returns: + nn.Module: The transformed model with PEFT applied. + """ + + model.freeze() + model.walk(self.transform) + + return model + + def setup(self, trainer: pl.Trainer, pl_module: pl.LightningModule, stage: str) -> None: + super().setup(trainer, pl_module, stage=stage) + + self.wrapped_io = WrappedAdapterIO(trainer.strategy.checkpoint_io) + trainer.strategy._checkpoint_io = self.wrapped_io + + def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: + needs_to_call = self._needs_to_call + self._maybe_apply_transform(trainer) + + # Check if we need to load the adapters + if needs_to_call and self.wrapped_io.adapter_ckpt_path is not None: + logging.info(f"Loading adapters from {self.wrapped_io.adapter_ckpt_path}") + adapter_state = self.wrapped_io.load_checkpoint(self.wrapped_io.adapter_ckpt_path) + trainer.strategy.load_model_state_dict(adapter_state, strict=False) + + def on_load_checkpoint( + self, trainer: pl.Trainer, pl_module: pl.LightningModule, checkpoint: Dict[str, Any] + ) -> None: + pl_module.strict_loading = False + + +class AdapterWrapper(nn.Module): + """Abstract base class for wrapping modules with adapters in Parameter-Efficient Fine-Tuning (PEFT). + + This class wraps a module and its associated adapter, providing methods for + managing the state dictionaries of both the main module and the adapter. It does not + implement the forward method, which must be implemented by concrete subclasses. + + Attributes: + to_wrap (nn.Module): The main module to be wrapped. + adapter (nn.Module): The adapter module to be applied. + + Note: + This class is abstract and cannot be instantiated directly. Subclasses must + implement the forward method. + + Example: + class AdapterParallelAdd(AdapterWrapper): + def __init__(self, to_wrap, adapter): + super().__init__(to_wrap, adapter) + + def forward(self, x): + return self.to_wrap(x) + self.adapter(x) + + main_module = nn.Linear(100, 100) + adapter = nn.Linear(100, 100) + parallel_adapter = AdapterParallelAdd(main_module, adapter) + """ + + def __init__(self, to_wrap: nn.Module, adapter: nn.Module): + super(AdapterWrapper, self).__init__() + self.to_wrap = to_wrap + self.adapter = adapter + + def state_dict(self, destination=None, prefix='', keep_vars=False): + """Retrieve the state dictionary of the wrapped module and adapter. + + This method overrides the default state_dict behavior to include both + the main module's state and the adapter's state under a special 'adapters' key. + + Args: + destination (Optional[dict]): A dictionary to store the state. If None, a new + dictionary is created. Defaults to None. + prefix (str): A prefix added to parameter and buffer names. Defaults to ''. + keep_vars (bool): If True, returns variables instead of tensor values. + Defaults to False. + + Returns: + dict: The state dictionary containing both the main module and adapter states. + """ + + if destination is None: + destination = {} + + # Get state dict of the main module + main_state_dict = self.to_wrap.state_dict(destination, prefix, keep_vars) + + # Store adapter state dict under the special "adapters" key in the destination dict + adapter_state_dict = self.adapter.state_dict(None, prefix, keep_vars) + destination[f'{prefix}adapters'] = adapter_state_dict + return main_state_dict + + def sharded_state_dict( + self, + prefix: str = '', + sharded_offsets: Tuple[Tuple[int, int, int]] = (), + metadata: Optional[dict] = None, + ) -> "ShardedStateDict": + """Retrieve the sharded state dictionary of the wrapped module and adapter. + + This method is used for distributed checkpointing, combining the sharded states + of both the main module and the adapter. + + Args: + prefix (str): A prefix added to parameter and buffer names. Defaults to ''. + sharded_offsets (Tuple[Tuple[int, int, int]]): Offsets for sharded parameters. + Defaults to an empty tuple. + metadata (Optional[dict]): Additional metadata for the sharded state. + Defaults to None. + + Returns: + ShardedStateDict: The combined sharded state dictionary. + """ + sharded_state_dict = {} + sharded_state_dict.update(self.to_wrap.sharded_state_dict(prefix, sharded_offsets, metadata)) + sharded_state_dict.update(self.adapter.sharded_state_dict(f"{prefix}adapter.", sharded_offsets, metadata)) + return sharded_state_dict + + def load_state_dict(self, state_dict, strict=True): + """Load a state dictionary into the wrapped module and adapter. + + This method overrides the default load_state_dict behavior to handle + loading states for both the main module and the adapter. + + Args: + state_dict (dict): The state dictionary to load. + strict (bool): Whether to strictly enforce that the keys in state_dict + match the keys returned by this module's state_dict() + function. Defaults to True. + """ + # Check if the 'adapters' key is present in the state_dict + if 'adapters' in state_dict: + adapter_state_dict = state_dict.pop('adapters') + else: + adapter_state_dict = {} + + # Load the main module state dict + self.to_wrap.load_state_dict(state_dict, strict) + + # Load the adapter module state dict if present + if adapter_state_dict: + self.adapter.load_state_dict(adapter_state_dict, strict) + + +class WrappedAdapterIO(_WrappingCheckpointIO): + model_ckpt_path: Optional[Path] = None + adapter_ckpt_path: Optional[Path] = None + + @override + def save_checkpoint(self, checkpoint: Dict[str, Any], path: _PATH, storage_options: Optional[Any] = None) -> None: + assert self.checkpoint_io is not None + + key = "sharded_state_dict" if "sharded_state_dict" in checkpoint else "state_dict" + checkpoint[key] = dict(filter(lambda x: ".adapter." in x[0], checkpoint[key].items())) + + self.checkpoint_io.save_checkpoint(checkpoint, path, storage_options=storage_options) + + from nemo.utils.get_rank import is_global_rank_zero + + if is_global_rank_zero(): + metadata = {"model_ckpt_path": str(self.model_ckpt_path)} + adapter_meta_path = ckpt_to_dir(path) / _ADAPTER_META_FILENAME + with open(adapter_meta_path, "w") as f: + json.dump(metadata, f) + + @override + def load_checkpoint( + self, path: _PATH, sharded_state_dict=None, map_location: Optional[Callable] = None + ) -> Dict[str, Any]: + assert self.checkpoint_io is not None + + adapter_meta_path = ckpt_to_dir(path) / _ADAPTER_META_FILENAME + if getattr(path, "adapter_path", None): + self.model_ckpt_path = path + self.adapter_ckpt_path = path.adapter_path + elif adapter_meta_path.exists(): + with open(adapter_meta_path, "r") as f: + metadata = json.load(f) + self.model_ckpt_path = Path(metadata['model_ckpt_path']) + self.adapter_ckpt_path = path + else: + self.model_ckpt_path = path + + # Note: this will include the Trainer-state of the model-checkpoint + model_ckpt = self.checkpoint_io.load_checkpoint(path, sharded_state_dict, map_location) + + return model_ckpt diff --git a/nemo/lightning/pytorch/callbacks/preemption.py b/nemo/lightning/pytorch/callbacks/preemption.py new file mode 100644 index 0000000000000..7f1dd94256d29 --- /dev/null +++ b/nemo/lightning/pytorch/callbacks/preemption.py @@ -0,0 +1,115 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import contextlib +import signal +from typing import Optional + +import torch +from pytorch_lightning.callbacks import Callback +from pytorch_lightning.trainer.trainer import Trainer + +from nemo.utils import logging + + +class PreemptionCallback(Callback): + """ + PreemptionCallback checks for preemption during training at the end of every step. + Upon preemption, it signals the trainer to stop gracefully. + + Args: + sig (int, optional): The signal to listen for. Defaults to signal.SIGTERM. + + Example: + >>> from nemo.lightning.pytorch.callbacks import PreemptionCallback + >>> callback = PreemptionCallback() + >>> trainer = Trainer(callbacks=[callback]) + """ + + def __init__(self, sig: Optional[int] = None): + self.sig = sig if sig is not None else signal.SIGTERM + self._interrupted = False + self._handler_context = None + self._preemption_supported = None + + def on_train_start(self, trainer: Trainer, pl_module) -> None: + if self.preemption_supported: + self._handler_context = self._preemption_handler() + self._handler_context.__enter__() + + def on_train_batch_start(self, trainer: Trainer, pl_module, batch, batch_idx: int) -> None: + if not self.preemption_supported: + self._preemption_supported = self._check_preemption_support() + if self.preemption_supported: + self._handler_context = self._preemption_handler() + self._handler_context.__enter__() + + def on_train_end(self, trainer: Trainer, pl_module) -> None: + if self._handler_context: + self._handler_context.__exit__(None, None, None) + + def on_train_batch_end(self, trainer: Trainer, pl_module, outputs, batch, batch_idx: int) -> None: + if self.interrupted: + logging.info("Preemption detected, signaling trainer to stop") + trainer.should_stop = True + + def on_exception(self, trainer: Trainer, pl_module, exception: BaseException) -> None: + if isinstance(exception, PreemptionException): + logging.info("Handling PreemptionException") + trainer.should_stop = True + + @contextlib.contextmanager + def _preemption_handler(self): + if not self.preemption_supported: + logging.warning("Preemption requires torch distributed to be initialized, preemption may be disabled") + yield + return + + original_handler = signal.getsignal(self.sig) + + def master_handler(signum, frame): + logging.info(f"Received signal {signum}, initiating graceful stop") + self._interrupted = True + raise PreemptionException("Preemption signal received") + + def ignoring_handler(signum, frame): + logging.debug(f"Received signal {signum} on non-master rank, ignoring") + + try: + private_rank = torch.distributed.get_rank() + signal.signal(self.sig, master_handler if private_rank == 0 else ignoring_handler) + yield + finally: + signal.signal(self.sig, original_handler) + + @property + def preemption_supported(self) -> bool: + if self._preemption_supported is None: + self._preemption_supported = self._check_preemption_support() + return self._preemption_supported + + def _check_preemption_support(self) -> bool: + return torch.distributed.is_available() and torch.distributed.is_initialized() + + @property + def interrupted(self) -> bool: + if not self.preemption_supported: + return False + interrupted = torch.tensor(self._interrupted, device=torch.cuda.current_device(), dtype=torch.int32) + torch.distributed.broadcast(interrupted, 0) + return bool(interrupted.item()) + + +class PreemptionException(Exception): + """Custom exception for preemption events.""" diff --git a/nemo/lightning/pytorch/callbacks/progress.py b/nemo/lightning/pytorch/callbacks/progress.py index 9d4d9b385da8d..17178618852f6 100644 --- a/nemo/lightning/pytorch/callbacks/progress.py +++ b/nemo/lightning/pytorch/callbacks/progress.py @@ -26,19 +26,13 @@ def init_train_tqdm(self): return self.bar def on_train_epoch_start(self, trainer, *_): - if trainer.max_steps > 0 and (trainer.ckpt_path is not None): + if trainer.max_steps > 0: # and (trainer.ckpt_path is not None): # while resuming from a ckpt use trainer.max_steps as the total for progress bar as trainer.num_training_batches # is truncated to max_steps - step being resumed at num_training_batches = trainer.max_steps else: num_training_batches = trainer.num_training_batches - # from nemo.utils import AppState - # app_state = AppState() - # app_state. - - num_training_batches = num_training_batches // calculate_data_parallel_groups() - self.train_progress_bar.reset(num_training_batches) self.train_progress_bar.initial = 0 self.train_progress_bar.set_description(f"Epoch {trainer.current_epoch}") diff --git a/nemo/lightning/pytorch/opt/__init__.py b/nemo/lightning/pytorch/optim/__init__.py similarity index 81% rename from nemo/lightning/pytorch/opt/__init__.py rename to nemo/lightning/pytorch/optim/__init__.py index ded886bf1e6c1..d23494a96a5fa 100644 --- a/nemo/lightning/pytorch/opt/__init__.py +++ b/nemo/lightning/pytorch/optim/__init__.py @@ -1,5 +1,5 @@ -from nemo.lightning.pytorch.opt.base import LRSchedulerModule, OptimizerModule -from nemo.lightning.pytorch.opt.lr_scheduler import ( +from nemo.lightning.pytorch.optim.base import LRSchedulerModule, OptimizerModule +from nemo.lightning.pytorch.optim.lr_scheduler import ( CosineAnnealingScheduler, InverseSquareRootAnnealingScheduler, NoamAnnealingScheduler, @@ -13,7 +13,7 @@ WarmupHoldPolicyScheduler, WarmupPolicyScheduler, ) -from nemo.lightning.pytorch.opt.megatron import MegatronOptimizerModule +from nemo.lightning.pytorch.optim.megatron import MegatronOptimizerModule __all__ = [ "OptimizerModule", diff --git a/nemo/lightning/pytorch/opt/base.py b/nemo/lightning/pytorch/optim/base.py similarity index 94% rename from nemo/lightning/pytorch/opt/base.py rename to nemo/lightning/pytorch/optim/base.py index 5f5704beaf6ed..8e857a1566492 100644 --- a/nemo/lightning/pytorch/opt/base.py +++ b/nemo/lightning/pytorch/optim/base.py @@ -1,15 +1,17 @@ import types from abc import ABC, abstractmethod +from copy import deepcopy from typing import List, Optional import pytorch_lightning as L from pytorch_lightning.utilities.types import OptimizerLRScheduler from torch.optim import Optimizer +from nemo.lightning.io.mixin import IOMixin from nemo.lightning.megatron_parallel import CallbackMethods -class LRSchedulerModule(L.Callback, CallbackMethods, ABC): +class LRSchedulerModule(L.Callback, CallbackMethods, IOMixin, ABC): """A module to standardize the learning rate scheduler setup and configuration. This class decouples the learning rate scheduler from the model, similar to how the LightningDataModule @@ -77,7 +79,7 @@ def __call__(self, model, optimizers): return self._scheduler -class OptimizerModule(L.Callback, CallbackMethods, ABC): +class OptimizerModule(L.Callback, CallbackMethods, IOMixin, ABC): """A module to standardize the optimizer setup and configuration. This class decouples the optimizer from the model, similar to how the LightningDataModule @@ -131,6 +133,10 @@ def custom_configure_optimizers(lightning_module_self, megatron_parallel=None): model.configure_optimizers = types.MethodType(custom_configure_optimizers, model) model.optim = self + if hasattr(self, "__io__") and hasattr(model, "__io__"): + if hasattr(model.__io__, "optim"): + model.__io__.optim = deepcopy(self.__io__) + @abstractmethod def optimizers(self, model) -> List[Optimizer]: """Abstract method to define the optimizers. diff --git a/nemo/lightning/pytorch/opt/lr_scheduler.py b/nemo/lightning/pytorch/optim/lr_scheduler.py similarity index 81% rename from nemo/lightning/pytorch/opt/lr_scheduler.py rename to nemo/lightning/pytorch/optim/lr_scheduler.py index 689eb2faa839a..298a6e7a7f452 100644 --- a/nemo/lightning/pytorch/opt/lr_scheduler.py +++ b/nemo/lightning/pytorch/optim/lr_scheduler.py @@ -13,7 +13,7 @@ WarmupHoldPolicy, WarmupPolicy, ) -from nemo.lightning.pytorch.opt.base import LRSchedulerModule +from nemo.lightning.pytorch.optim.base import LRSchedulerModule class WarmupPolicyScheduler(LRSchedulerModule): @@ -48,9 +48,11 @@ def scheduler(self, model, optimizer): ) return { "optimizer": optimizer, - "scheduler": lr_scheduler, - "interval": self.interval, - "frequency": self.frequency, + "lr_scheduler": { + "scheduler": lr_scheduler, + "interval": self.interval, + "frequency": self.frequency, + }, "monitor": self.monitor, } @@ -93,9 +95,11 @@ def scheduler(self, model, optimizer): ) return { "optimizer": optimizer, - "scheduler": lr_scheduler, - "interval": self.interval, - "frequency": self.frequency, + "lr_scheduler": { + "scheduler": lr_scheduler, + "interval": self.interval, + "frequency": self.frequency, + }, "monitor": self.monitor, } @@ -122,9 +126,11 @@ def scheduler(self, model, optimizer): lr_scheduler = SquareAnnealing(optimizer, max_steps=self.max_steps, min_lr=self.min_lr) return { "optimizer": optimizer, - "scheduler": lr_scheduler, - "interval": self.interval, - "frequency": self.frequency, + "lr_scheduler": { + "scheduler": lr_scheduler, + "interval": self.interval, + "frequency": self.frequency, + }, "monitor": self.monitor, } @@ -151,9 +157,11 @@ def scheduler(self, model, optimizer): lr_scheduler = SquareRootAnnealing(optimizer, max_steps=self.max_steps, min_lr=self.min_lr) return { "optimizer": optimizer, - "scheduler": lr_scheduler, - "interval": self.interval, - "frequency": self.frequency, + "lr_scheduler": { + "scheduler": lr_scheduler, + "interval": self.interval, + "frequency": self.frequency, + }, "monitor": self.monitor, } @@ -193,9 +201,11 @@ def scheduler(self, model, optimizer): ) return { "optimizer": optimizer, - "scheduler": lr_scheduler, - "interval": self.interval, - "frequency": self.frequency, + "lr_scheduler": { + "scheduler": lr_scheduler, + "interval": self.interval, + "frequency": self.frequency, + }, "monitor": self.monitor, } @@ -226,9 +236,11 @@ def scheduler(self, model, optimizer): ) return { "optimizer": optimizer, - "scheduler": lr_scheduler, - "interval": self.interval, - "frequency": self.frequency, + "lr_scheduler": { + "scheduler": lr_scheduler, + "interval": self.interval, + "frequency": self.frequency, + }, "monitor": self.monitor, } @@ -255,9 +267,11 @@ def scheduler(self, model, optimizer): lr_scheduler = WarmupAnnealing(optimizer, max_steps=self.max_steps, min_lr=self.min_lr) return { "optimizer": optimizer, - "scheduler": lr_scheduler, - "interval": self.interval, - "frequency": self.frequency, + "lr_scheduler": { + "scheduler": lr_scheduler, + "interval": self.interval, + "frequency": self.frequency, + }, "monitor": self.monitor, } @@ -284,9 +298,11 @@ def scheduler(self, model, optimizer): lr_scheduler = InverseSquareRootAnnealing(optimizer, max_steps=self.max_steps, min_lr=self.min_lr) return { "optimizer": optimizer, - "scheduler": lr_scheduler, - "interval": self.interval, - "frequency": self.frequency, + "lr_scheduler": { + "scheduler": lr_scheduler, + "interval": self.interval, + "frequency": self.frequency, + }, "monitor": self.monitor, } @@ -313,9 +329,11 @@ def scheduler(self, model, optimizer): lr_scheduler = T5InverseSquareRootAnnealing(optimizer, max_steps=self.max_steps, min_lr=self.min_lr) return { "optimizer": optimizer, - "scheduler": lr_scheduler, - "interval": self.interval, - "frequency": self.frequency, + "lr_scheduler": { + "scheduler": lr_scheduler, + "interval": self.interval, + "frequency": self.frequency, + }, "monitor": self.monitor, } @@ -348,9 +366,11 @@ def scheduler(self, model, optimizer): ) return { "optimizer": optimizer, - "scheduler": lr_scheduler, - "interval": self.interval, - "frequency": self.frequency, + "lr_scheduler": { + "scheduler": lr_scheduler, + "interval": self.interval, + "frequency": self.frequency, + }, "monitor": self.monitor, } @@ -383,9 +403,11 @@ def scheduler(self, model, optimizer): ) return { "optimizer": optimizer, - "scheduler": lr_scheduler, - "interval": self.interval, - "frequency": self.frequency, + "lr_scheduler": { + "scheduler": lr_scheduler, + "interval": self.interval, + "frequency": self.frequency, + }, "monitor": self.monitor, } @@ -423,16 +445,19 @@ def scheduler(self, model, optimizer): return { "optimizer": optimizer, - # REQUIRED: The scheduler instance "scheduler": lr_scheduler, - # The unit of the scheduler's step size, could also be 'step'. - # 'epoch' updates the scheduler on epoch end whereas 'step' - # updates it after a optimizer update. - "interval": self.interval, - # How many epochs/steps should pass between calls to - # `scheduler.step()`. 1 corresponds to updating the learning - # rate after every epoch/step. - "frequency": self.frequency, + "lr_scheduler": { + # REQUIRED: The scheduler instance + "scheduler": lr_scheduler, + # The unit of the scheduler's step size, could also be 'step'. + # 'epoch' updates the scheduler on epoch end whereas 'step' + # updates it after a optimizer update. + "interval": self.interval, + # How many epochs/steps should pass between calls to + # `scheduler.step()`. 1 corresponds to updating the learning + # rate after every epoch/step. + "frequency": self.frequency, + }, # Metric to to monitor for schedulers like `ReduceLROnPlateau` "monitor": self.monitor, } diff --git a/nemo/lightning/pytorch/opt/megatron.py b/nemo/lightning/pytorch/optim/megatron.py similarity index 78% rename from nemo/lightning/pytorch/opt/megatron.py rename to nemo/lightning/pytorch/optim/megatron.py index a841148b1a3ba..7faa53f32b651 100644 --- a/nemo/lightning/pytorch/opt/megatron.py +++ b/nemo/lightning/pytorch/optim/megatron.py @@ -7,7 +7,7 @@ from torch.optim import Optimizer from nemo.lightning.megatron_parallel import MegatronParallel -from nemo.lightning.pytorch.opt.base import LRSchedulerModule, OptimizerModule +from nemo.lightning.pytorch.optim.base import LRSchedulerModule, OptimizerModule class MegatronOptimizerModule(OptimizerModule): @@ -54,7 +54,7 @@ def __init__( self.scale_lr_cond = scale_lr_cond self.lr_mult = lr_mult - def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: str): + def on_fit_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule"): """We will add the finalize_model_grads function to the model config. Args: @@ -90,9 +90,12 @@ def sharded_state_dict( model_sharded_state_dict, optimizer_state_dict=None, is_loading=False, - dist_ckpt_parallel_save=False, + sharding_type='fully_sharded_model_space', ): - return self.mcore_optimizer.sharded_state_dict(model_sharded_state_dict, is_loading=is_loading) + state_dict = self.mcore_optimizer.sharded_state_dict( + model_sharded_state_dict, is_loading=is_loading, sharding_type=sharding_type + ) + return state_dict mcore_opt = get_megatron_optimizer( self.config, @@ -102,6 +105,17 @@ def sharded_state_dict( lr_mult=self.lr_mult, ) + if getattr(model.ddp_config, "overlap_param_sync", False) and getattr( + model.ddp_config, "delay_param_gather", False + ): + param_sync_func = [ + lambda x, model_index=model_index: mcore_opt.finish_param_sync(model_index, x) + for model_index in range(len(pipeline)) + ] + param_sync_func = param_sync_func[0] if len(pipeline) == 1 else param_sync_func + for module in model: + module.config.param_sync_func = param_sync_func + return [McoreOpt(mcore_opt)] def finalize_model_grads(self, *args, **kwargs): diff --git a/nemo/lightning/pytorch/plugins/data_sampler.py b/nemo/lightning/pytorch/plugins/data_sampler.py index c6ff3b7ccaaa3..378375e3bc0ca 100644 --- a/nemo/lightning/pytorch/plugins/data_sampler.py +++ b/nemo/lightning/pytorch/plugins/data_sampler.py @@ -23,14 +23,15 @@ def __init__( global_batch_size: int = 8, rampup_batch_size: Optional[List[int]] = None, dataloader_type: Literal["single", "cyclic"] = "single", + init_consumed_samples: int = 0, ): self.seq_len = seq_len self.micro_batch_size = micro_batch_size self.global_batch_size = global_batch_size self.rampup_batch_size = rampup_batch_size self.dataloader_type = dataloader_type - self.init_consumed_samples: int = 0 - self.prev_consumed_samples = 0 + self.init_consumed_samples = init_consumed_samples + self.prev_consumed_samples = self.init_consumed_samples self.if_first_step = 0 self.prev_global_batch_size = None @@ -47,7 +48,7 @@ def transform_dataloader(self, dataloader: DataLoader, consumed_samples: int = 0 micro_batch_size=self.micro_batch_size, global_batch_size=self.global_batch_size, rampup_batch_size=self.rampup_batch_size, - consumed_samples=consumed_samples, + consumed_samples=self.init_consumed_samples, dataloader_type=self.dataloader_type, ) diff --git a/nemo/lightning/pytorch/plugins/mixed_precision.py b/nemo/lightning/pytorch/plugins/mixed_precision.py index 923bd625da62e..751141d8111b1 100644 --- a/nemo/lightning/pytorch/plugins/mixed_precision.py +++ b/nemo/lightning/pytorch/plugins/mixed_precision.py @@ -13,7 +13,6 @@ # limitations under the License. from contextlib import contextmanager -from types import SimpleNamespace from typing import Any, Callable, Generator, List, Literal, Tuple, TypeVar, Union import pytorch_lightning as pl @@ -40,26 +39,6 @@ def __init__( scaler = GradScaler(init_scale=2**32, growth_interval=1000, hysteresis=2) super().__init__(precision, device, scaler) - - # MixedPrecisionPlugin class in PTL >= 2.0 takes only "16-mixed" or "bf16-mixed" for precision arg - if precision == "16-mixed": - dtype = torch.float16 - - def float16_convertor(val): - return val.half() - - elif precision == "bf16-mixed": - dtype = torch.bfloat16 - - def float16_convertor(val): - return val.bfloat16() - - else: - raise ValueError("precision must be '16-mixed' or 'bf16-mixed'") - - self.dtype = dtype - # torch.set_autocast_gpu_dtype(dtype) - self.float16_convertor = float16_convertor self.amp_O2 = amp_O2 def connect( @@ -90,7 +69,8 @@ def convert_module(self, module: Module) -> Module: config = get_model_config(module.module) config.fp16 = self.precision == "16-mixed" config.bf16 = self.precision == "bf16-mixed" - module.module = Float16Module(config, module.module) + if not isinstance(module.module, Float16Module): + module.module = Float16Module(config, module.module) return module @@ -120,10 +100,6 @@ def convert_input(self, data: AnyT) -> AnyT: """ return data - from megatron.core.transformer.module import fp32_to_float16 - - return fp32_to_float16(data, self.float16_convertor) - def convert_output(self, data: AnyT) -> AnyT: """Convert outputs to the floating point precision type expected after model's forward. @@ -133,10 +109,6 @@ def convert_output(self, data: AnyT) -> AnyT: """ return data - from megatron.core.transformer.module import float16_to_fp32 - - return float16_to_fp32(data) - def optimizer_step( self, optimizer: torch.optim.Optimizer, diff --git a/nemo/lightning/pytorch/strategies.py b/nemo/lightning/pytorch/strategies.py index f62de77f62889..6a84319b4fa29 100644 --- a/nemo/lightning/pytorch/strategies.py +++ b/nemo/lightning/pytorch/strategies.py @@ -12,8 +12,9 @@ import torch import torch.distributed from lightning_fabric.plugins import CheckpointIO, ClusterEnvironment -from lightning_fabric.utilities.optimizer import _optimizers_to_device +from lightning_fabric.utilities.optimizer import _optimizer_to_device, _optimizers_to_device from megatron.core.distributed import DistributedDataParallelConfig +from megatron.core.optimizer import OptimizerConfig from pytorch_lightning.accelerators import CPUAccelerator from pytorch_lightning.callbacks.progress import TQDMProgressBar from pytorch_lightning.loops import _AutomaticOptimization, evaluation_loop, fit_loop, prediction_loop @@ -22,7 +23,6 @@ from pytorch_lightning.plugins.io.wrapper import _WrappingCheckpointIO from pytorch_lightning.strategies.ddp import DDPStrategy from pytorch_lightning.trainer.states import RunningStage, TrainerFn -from pytorch_lightning.utilities.model_helpers import is_overridden from pytorch_lightning.utilities.types import STEP_OUTPUT from torch import nn from torch.distributed.algorithms.ddp_comm_hooks.debugging_hooks import noop_hook @@ -31,9 +31,10 @@ from typing_extensions import override from nemo.lightning import _strategy_lib, io -from nemo.lightning.io.pl import MegatronCheckpointIO, TrainerCheckpoint, TrainerCkptProtocol +from nemo.lightning.io.pl import MegatronCheckpointIO from nemo.lightning.megatron_parallel import CallbackConnector, MegatronParallel, _ModuleStepFunction -from nemo.lightning.pytorch.callbacks import MegatronProgressBar +from nemo.lightning.pytorch.callbacks import MegatronProgressBar, ModelTransform +from nemo.utils.callbacks.dist_ckpt_io import AsyncFinalizableCheckpointIO, AsyncFinalizerCallback if TYPE_CHECKING: from nemo.lightning.pytorch.plugins.data_sampler import DataSampler @@ -99,18 +100,22 @@ def __init__( cluster_environment=None, # TODO: Add type-hint checkpoint_io=None, # TODO: Add type-hint find_unused_parameters: bool = False, - enable_nemo_ckpt_io: bool = True, - ckpt_type: TrainerCkptProtocol = TrainerCheckpoint, ckpt_include_optimizer: bool = False, ddp: Union[DDPLiteral, DistributedDataParallelConfig] = "megatron", lazy_init: bool = False, pipeline_dtype: Optional[torch.dtype] = None, + save_ckpt_format='torch_dist', + ckpt_torch_dist_multiproc=None, ## TODO(ashors): put elsewhere? + ckpt_assume_constant_structure=False, + ckpt_parallel_save=True, + ckpt_parallel_load=False, + ckpt_parallel_save_optim=True, **kwargs, ) -> None: super().__init__( - parallel_devices, - cluster_environment, - checkpoint_io, + parallel_devices=parallel_devices, + cluster_environment=cluster_environment, + checkpoint_io=checkpoint_io, find_unused_parameters=find_unused_parameters, **kwargs, ) @@ -124,16 +129,22 @@ def __init__( self.moe_extended_tp = moe_extended_tp self.virtual_pipeline_model_parallel_size = virtual_pipeline_model_parallel_size self.sequence_parallel = sequence_parallel - self.enable_nemo_ckpt_io = enable_nemo_ckpt_io - self.ckpt_type = ckpt_type self.lazy_init = lazy_init self.ckpt_include_optimizer = ckpt_include_optimizer self.pipeline_dtype = pipeline_dtype self.log_train_loss = bool(int(os.getenv("NEMO_LOG_TRAIN_LOSS", 1))) self.log_memory_usage = bool(int(os.getenv("NEMO_LOG_MEMORY_USAGE", 0))) + self.save_ckpt_format = save_ckpt_format + self.torch_dist_multiproc = ckpt_torch_dist_multiproc + self.assume_constant_structure = ckpt_assume_constant_structure + self.parallel_save = ckpt_parallel_save + self.parallel_load = ckpt_parallel_load + self.parallel_save_optim = ckpt_parallel_save_optim + + self._ddp = ddp if ddp == "megatron": - self.ddp_config = DistributedDataParallelConfig(use_distributed_optimizer=True) + self.ddp_config = DistributedDataParallelConfig() elif isinstance(ddp, DistributedDataParallelConfig): self.ddp_config = ddp elif ddp == "pytorch": @@ -149,23 +160,24 @@ def __init__( def connect(self, model: pl.LightningModule) -> None: super().connect(model) - # Right now mcore sub-classes ModelParellelConfig, we should remove that - # Given Lightning's structure it would be better if parallelism is a different object - # Since then it can be passed to the Strategy + _maybe_mcore_config = _strategy_lib.set_model_parallel_attributes(model, self.parallelism) + if _maybe_mcore_config: + self._mcore_config = _maybe_mcore_config + + has_optim = getattr(model, "optim", None) + if has_optim: + opt_config = getattr(model.optim, "config", None) + if isinstance(opt_config, OptimizerConfig): + mcore_opt_config: OptimizerConfig = cast(OptimizerConfig, opt_config) + if not self.ddp_config: + raise ValueError("PyTorch DDP is not enabled for mcore optimizer") + ddp_config = cast(DistributedDataParallelConfig, self.ddp_config) - from megatron.core.transformer.transformer_config import TransformerConfig + if mcore_opt_config.use_distributed_optimizer != ddp_config.use_distributed_optimizer: + from nemo.utils import logging - has_mcore_config = isinstance(getattr(model, "config", None), TransformerConfig) - if has_mcore_config and is_overridden("configure_model", model): - config: TransformerConfig = model.config - config.tensor_model_parallel_size = self.tensor_model_parallel_size - config.pipeline_model_parallel_size = self.pipeline_model_parallel_size - config.virtual_pipeline_model_parallel_size = self.virtual_pipeline_model_parallel_size - config.context_parallel_size = self.context_parallel_size - config.expert_model_parallel_size = self.expert_model_parallel_size - config.moe_extended_tp = self.moe_extended_tp - config.sequence_parallel = self.sequence_parallel - self._mcore_config = config + logging.info("Fixing mis-match between ddp-config & mcore-optimizer config") + ddp_config.use_distributed_optimizer = mcore_opt_config.use_distributed_optimizer @override def setup(self, trainer: pl.Trainer, setup_optimizers: bool = True) -> None: @@ -195,6 +207,18 @@ def setup(self, trainer: pl.Trainer, setup_optimizers: bool = True) -> None: self.setup_megatron_parallel(trainer, setup_optimizers=setup_optimizers) self.setup_precision_plugin() + if getattr(self.lightning_module, "model_transform", None): + # Ensure the ModelTransform callback is pass to the trainer. + # Callback.setup() is called before the current Strategy.setup(), so we can + # only perform a check here; adding the callback here would not be sufficient + if not any(isinstance(cb, ModelTransform) for cb in trainer.callbacks): + raise ValueError( + "You specified a model_transform function in the model, but no" + "ModelTransform callback was found in the trainer. " + "Please initialize the trainer with " + "`trainer = Trainer(..., callbacks=[ModelTransform()])`" + ) + if trainer.num_sanity_val_steps > 1 and self.pipeline_model_parallel_size > 1: # TODO: log here trainer.num_sanity_val_steps = 0 @@ -354,6 +378,11 @@ def training_step(self, dataloader_iter, *args: Any, **kwargs: Any) -> STEP_OUTP batch_size=1, ) + self.lightning_module.log( + 'step', + self.trainer.global_step, + ) + if self.log_memory_usage: max_memory_reserved = torch.cuda.max_memory_reserved() memory_allocated = torch.cuda.memory_allocated() @@ -454,7 +483,7 @@ def _fix_progress_bar(self, trainer: pl.Trainer) -> None: callback.__class__ = MegatronProgressBar break - def optimizer_sharded_state_dict(self): + def optimizer_sharded_state_dict(self, is_loading=False): """ Sharded state dictionary for an MainParamsOptimizerWrapper. Used to save and load the optimizer state when training with distributed_checkpoint. @@ -468,8 +497,11 @@ def optimizer_sharded_state_dict(self): # TODO: Fix when MainParamsOptimizerWrapper is not used optimizer = self.lightning_module.optimizers(use_pl_optimizer=False) + sharding_type = 'fully_sharded_model_space' if self.parallel_save_optim else 'dp_zero_gather_scatter' - return _strategy_lib.optimizer_sharded_state_dict(self.megatron_parallel, optimizer) + return _strategy_lib.optimizer_sharded_state_dict( + self.megatron_parallel, optimizer, is_loading=is_loading, sharding_type=sharding_type + ) @override def save_checkpoint( @@ -477,12 +509,10 @@ def save_checkpoint( ) -> None: checkpoint["state_dict"] = OrderedDict([]) # remove device state_dict checkpoint["sharded_state_dict"] = self.megatron_parallel.sharded_state_dict() - if self.trainer.state.fn == TrainerFn.FITTING: + if self.trainer.state.fn == TrainerFn.FITTING and self.ckpt_include_optimizer: checkpoint["optimizer"] = [self.optimizer_sharded_state_dict()] self.checkpoint_io.save_checkpoint(checkpoint, filepath, storage_options=storage_options) - if self.enable_nemo_ckpt_io and self.is_global_zero and self.ckpt_type: - self.ckpt_type.from_strategy(self).io_dump(ckpt_to_dir(filepath)) @override def load_checkpoint(self, checkpoint_path: Union[str, Path]) -> Dict[str, Any]: @@ -499,62 +529,63 @@ def load_checkpoint(self, checkpoint_path: Union[str, Path]) -> Dict[str, Any]: if self.ckpt_include_optimizer and self.trainer.state.fn == TrainerFn.FITTING: if self.lightning_module.optimizers(use_pl_optimizer=False): - sharded_state_dict["optimizer"] = [self.optimizer_sharded_state_dict()] + sharded_state_dict["optimizer"] = [self.optimizer_sharded_state_dict(is_loading=True)] checkpoint = self.checkpoint_io.load_checkpoint(checkpoint_path, sharded_state_dict=sharded_state_dict) return checkpoint + @override + def load_optimizer_state_dict(self, checkpoint: Mapping[str, Any]) -> None: + if not self.ckpt_include_optimizer: + return + + optimizer_states = checkpoint["optimizer"] + for optimizer, opt_state in zip(self.optimizers, optimizer_states): + optimizer.load_state_dict(opt_state) + _optimizer_to_device(optimizer, self.root_device) + def remove_checkpoint(self, filepath: Union[str, Path]) -> None: if self.is_global_zero: shutil.rmtree(ckpt_to_dir(filepath)) def load_model_state_dict(self, checkpoint: Mapping[str, Any], strict: bool = True) -> None: assert self.megatron_parallel is not None - from megatron.core import parallel_state - for index, module in enumerate(self.megatron_parallel): - if parallel_state.get_virtual_pipeline_model_parallel_world_size() is not None: - checkpoint_state_dict = checkpoint['state_dict'][f'model_{index}'] - else: - checkpoint_state_dict = checkpoint['state_dict'] - - mcore_model = self.lightning_module.module - current = self.model[0] - n_nesting = 2 - while current != mcore_model: - current = current.module - n_nesting += 1 - - _state_dict = {} - for key, value in checkpoint_state_dict.items(): - # Count the number of "module." at the start of the key - count, _key = 0, key - while _key.startswith("module."): - _key = _key[len("module.") :] - count += 1 - - # Adjust the number of "module." prefixes - if count < n_nesting: - to_add = "module." * (n_nesting - count) - _state_dict[f"{to_add}{key}"] = value - elif count > n_nesting: - to_remove = "module." * (count - n_nesting) - _state_dict[key[len(to_remove) :]] = value - checkpoint_state_dict = _state_dict - - module.load_state_dict(checkpoint_state_dict, strict=strict) + _strategy_lib.load_model_state_dict(self.megatron_parallel, checkpoint, strict=strict) @property @override def checkpoint_io(self) -> CheckpointIO: if self._checkpoint_io is None: - self._checkpoint_io = MegatronCheckpointIO() + checkpoint_callback = self.trainer.checkpoint_callback + async_save = getattr(checkpoint_callback, "async_save", False) + self._checkpoint_io = MegatronCheckpointIO( + save_ckpt_format=self.save_ckpt_format, + async_save=async_save, + torch_dist_multiproc=self.torch_dist_multiproc, + assume_constant_structure=self.assume_constant_structure, + parallel_save=self.parallel_save, + parallel_load=self.parallel_load, + ) + if async_save: + self._checkpoint_io = AsyncFinalizableCheckpointIO(self._checkpoint_io) + have_async_callback = False + for callback in self.trainer.callbacks: + if isinstance(callback, AsyncFinalizerCallback): + have_async_callback = True + break + if not have_async_callback: + self.trainer.callbacks.append(AsyncFinalizerCallback()) elif isinstance(self._checkpoint_io, _WrappingCheckpointIO): self._checkpoint_io.checkpoint_io = MegatronCheckpointIO() return self._checkpoint_io + @checkpoint_io.setter + def checkpoint_io(self, io: CheckpointIO) -> None: + self._checkpoint_io = io + def _get_data_step(self, step_type: str) -> Optional[_ModuleStepFunction]: for fn_name in [f"{step_type}_data_step", "data_step"]: if hasattr(self.lightning_module, fn_name): @@ -624,6 +655,10 @@ def parallelism(self): tensor_model_parallel_size=self.tensor_model_parallel_size, pipeline_model_parallel_size=self.pipeline_model_parallel_size, virtual_pipeline_model_parallel_size=self.virtual_pipeline_model_parallel_size, + context_parallel_size=self.context_parallel_size, + sequence_parallel=self.sequence_parallel, + expert_model_parallel_size=self.expert_model_parallel_size, + moe_extended_tp=self.moe_extended_tp, pipeline_dtype=self.pipeline_dtype, ) diff --git a/nemo/lightning/pytorch/trainer.py b/nemo/lightning/pytorch/trainer.py index b4483d4af4b9f..8b453832d56e5 100644 --- a/nemo/lightning/pytorch/trainer.py +++ b/nemo/lightning/pytorch/trainer.py @@ -4,7 +4,9 @@ import pytorch_lightning as pl from typing_extensions import Self -from nemo.lightning.io.mixin import IOMixin +from nemo.lightning.fabric.conversion import to_fabric +from nemo.lightning.fabric.fabric import Fabric +from nemo.lightning.io.mixin import IOMixin, serialization, track_io class Trainer(pl.Trainer, IOMixin): @@ -12,4 +14,37 @@ def io_init(self, **kwargs) -> fdl.Config[Self]: # Each argument of the trainer can be stateful so we copy them cfg_kwargs = {k: deepcopy(v) for k, v in kwargs.items()} + for val in cfg_kwargs.values(): + if not serialization.find_node_traverser(type(val)): + track_io(type(val)) + return fdl.Config(type(self), **cfg_kwargs) + + def to_fabric(self, callbacks=None, loggers=None) -> Fabric: + accelerator, devices, strategy, plugins = None, None, None, None + if hasattr(self.__io__, "devices"): + devices = self.__io__.devices + if hasattr(self.__io__, "accelerator"): + accelerator = self.__io__.accelerator + if hasattr(self.__io__, "strategy"): + strategy = self.__io__.strategy + if isinstance(strategy, fdl.Config): + strategy = fdl.build(strategy) + + strategy = to_fabric(strategy) + if hasattr(self.__io__, "plugins"): + plugins = self.__io__.plugins + if isinstance(plugins, fdl.Config): + plugins = fdl.build(plugins) + plugins = to_fabric(plugins) + + out = Fabric( + devices=devices, + accelerator=accelerator, + strategy=strategy, + plugins=plugins, + callbacks=callbacks, + loggers=loggers, + ) + + return out diff --git a/nemo/lightning/resume.py b/nemo/lightning/resume.py index fc4f7ec9fab8f..fc2e21eb37fd7 100644 --- a/nemo/lightning/resume.py +++ b/nemo/lightning/resume.py @@ -1,14 +1,24 @@ -from pathlib import Path +import os +from pathlib import Path, PosixPath, WindowsPath from typing import Optional, Union import lightning_fabric as fl import pytorch_lightning as pl +from nemo.lightning import io +from nemo.lightning.io.mixin import IOMixin from nemo.utils import logging from nemo.utils.app_state import AppState +from nemo.utils.model_utils import uninject_model_parallel_rank +# Dynamically inherit from the correct Path subclass based on the operating system. +if os.name == 'nt': + BasePath = WindowsPath +else: + BasePath = PosixPath -class Resume: + +class Resume(IOMixin): def nemo_path(self, model) -> Optional[Path]: raise NotImplementedError @@ -22,7 +32,7 @@ def setup(self, model, trainer: Union[pl.Trainer, fl.Fabric]): trainer.checkpoint_callback.last_model_path = ckpt_path -class AutoResume(Resume): +class AutoResume(Resume, io.IOMixin): """Class that handles the logic for setting checkpoint paths and restoring from checkpoints in NeMo. """ @@ -32,6 +42,7 @@ def __init__( path: Optional[str] = None, ## old resume_from_checkpoint dirpath: Optional[str] = None, ## optional path to checkpoint directory import_path: Optional[str] = None, ## for importing from hf or other checkpoint formats + adapter_path: Optional[str] = None, resume_if_exists: bool = False, resume_past_end: bool = False, resume_ignore_no_checkpoint: bool = False, @@ -64,6 +75,7 @@ def __init__( self.path = path self.dirpath = dirpath self.import_path = import_path + self.adapter_path = adapter_path self.resume_if_exists = resume_if_exists self.resume_past_end = resume_past_end self.resume_ignore_no_checkpoint = resume_ignore_no_checkpoint @@ -74,7 +86,10 @@ def nemo_path(self, model=None) -> Optional[Path]: if self.import_path: if model is None: raise ValueError("Model is needed to import checkpoint from HF or other non-NeMo checkpoint format.") - return model.import_ckpt(self.import_path) + output = model.import_ckpt(self.import_path) + if self.adapter_path: + return AdapterPath(output, adapter_path=Path(self.adapter_path)) + return output ### refactored from exp_manager checkpoint = None @@ -101,15 +116,15 @@ def nemo_path(self, model=None) -> Optional[Path]: warn = f"There were no checkpoints found in checkpoint_dir or no checkpoint folder at checkpoint_dir :{checkpoint_dir}. " if checkpoint is None: warn += "Training from scratch." - elif checkpoint == resume_from_checkpoint: - warn += f"Training from {resume_from_checkpoint}." + elif checkpoint == self.path: + warn += f"Training from {self.path}." logging.warning(warn) else: raise NotFoundError( f"There were no checkpoints found in checkpoint_dir or no checkpoint folder at checkpoint_dir :{checkpoint_dir}. Cannot resume." ) elif len(end_checkpoints) > 0: - if resume_past_end: + if self.resume_past_end: if len(end_checkpoints) > 1: if 'mp_rank' in str(end_checkpoints[0]): checkpoint = end_checkpoints[0] @@ -129,6 +144,17 @@ def nemo_path(self, model=None) -> Optional[Path]: checkpoint = last_checkpoints[0] if checkpoint: + if self.adapter_path: + return AdapterPath(checkpoint, adapter_path=Path(self.adapter_path)) return Path(checkpoint) return None + + +class AdapterPath(BasePath): + adapter_path: Optional[Path] + + def __new__(cls, *args, adapter_path: Optional[Path] = None, **kwargs): + output = super().__new__(cls, *args, **kwargs) + output.adapter_path = adapter_path + return output diff --git a/nemo/utils/__init__.py b/nemo/utils/__init__.py index ebf892927723d..a1e59646ae131 100644 --- a/nemo/utils/__init__.py +++ b/nemo/utils/__init__.py @@ -21,6 +21,7 @@ avoid_float16_autocast_context, cast_all, cast_tensor, + monkeypatched, ) from nemo.utils.dtype import str_to_dtype from nemo.utils.nemo_logging import Logger as _Logger diff --git a/nemo/utils/callbacks/dist_ckpt_io.py b/nemo/utils/callbacks/dist_ckpt_io.py index b95be90274e3e..144c07addaa8a 100644 --- a/nemo/utils/callbacks/dist_ckpt_io.py +++ b/nemo/utils/callbacks/dist_ckpt_io.py @@ -32,16 +32,29 @@ from megatron.core import dist_checkpointing from megatron.core.dist_checkpointing.dict_utils import extract_matching_values from megatron.core.dist_checkpointing.mapping import ShardedBase + from megatron.core.dist_checkpointing.serialization import ( + get_default_load_sharded_strategy, + get_default_save_sharded_strategy, + ) from megatron.core.dist_checkpointing.strategies import tensorstore - - from nemo.utils.callbacks.torch_dist_async import AsyncCallsQueue, AsyncRequest, TorchDistAsyncSaveShardedStrategy + from megatron.core.dist_checkpointing.strategies.async_utils import AsyncCallsQueue, AsyncRequest + from megatron.core.dist_checkpointing.strategies.base import SaveShardedStrategy + from megatron.core.dist_checkpointing.strategies.fully_parallel import ( + FullyParallelLoadStrategyWrapper, + FullyParallelSaveStrategyWrapper, + ) + from megatron.core.dist_checkpointing.strategies.torch import TorchDistSaveShardedStrategy + from megatron.core.parallel_state import get_data_parallel_group HAVE_MEGATRON_CORE = True -except (ImportError, ModuleNotFoundError) as IMPORT_ERROR_EXC: +except (ImportError, ModuleNotFoundError) as e: HAVE_MEGATRON_CORE = False - IMPORT_ERROR = "megatron-core was not found. Please see the NeMo README for installation instructions: https://github.com/NVIDIA/NeMo#megatron-gpt." + IMPORT_ERROR = ( + "megatron-core was not found. Please see the NeMo README for installation instructions: https://github.com/NVIDIA/NeMo#megatron-gpt." + f" Exact error: {e}" + ) @contextmanager @@ -87,7 +100,7 @@ class AsyncFinalizableCheckpointIO(_WrappingCheckpointIO): def __init__(self, checkpoint_io: AsyncCompatibleCheckpointIO) -> None: if not HAVE_MEGATRON_CORE: - raise ImportError(IMPORT_ERROR) from IMPORT_ERROR_EXC + raise ImportError(IMPORT_ERROR) if not isinstance(checkpoint_io, AsyncCompatibleCheckpointIO): raise ValueError(f'Incompatible wrapped checkpoint_io type: {type(checkpoint_io)}') @@ -177,6 +190,12 @@ class DistributedCheckpointIO(AsyncCompatibleCheckpointIO): always loads on device). Defaults to True. async_save (bool): whether to save asynchronously. Should be set to True if this class will be wrapped with AsyncFinalizableCheckpointIO. + torch_dist_multiproc (int, optional): number of extra processes per rank + used during ckpt save with PyTorch distributed format. Defaults, to None + which means using an MCore default (2). + parallel_save (bool): parallelizes the save across ranks. Defaults to True + parallel_load (bool): parallelizes the load across ranks (followed by params all gather). + Defaults to False due to some extra memory usage requirement. """ def __init__( @@ -184,15 +203,25 @@ def __init__( save_ckpt_format: str, load_directly_on_device: bool = True, async_save: bool = False, + torch_dist_multiproc: Optional[int] = None, + assume_constant_structure: bool = False, + parallel_save: bool = False, + parallel_load: bool = False, ): super().__init__() if not HAVE_MEGATRON_CORE: - raise ImportError(IMPORT_ERROR) from IMPORT_ERROR_EXC + raise ImportError(IMPORT_ERROR) self.save_ckpt_format = save_ckpt_format self.load_directly_on_device = load_directly_on_device self.async_save = async_save - self.save_sharded_strategy = self._determine_dist_ckpt_save_strategy() + self.torch_dist_multiproc = torch_dist_multiproc + self.assume_constant_structure = assume_constant_structure + self.parallel_save = parallel_save + self.parallel_load = parallel_load + + self._save_sharded_strategy = None + self.validated_consistency = False @classmethod def from_config(cls, model_cfg: dict, async_save: bool = False): @@ -208,6 +237,9 @@ def from_config(cls, model_cfg: dict, async_save: bool = False): save_ckpt_format=model_cfg.get('dist_ckpt_format', 'zarr'), load_directly_on_device=model_cfg.get('dist_ckpt_load_on_device', True), async_save=async_save, + torch_dist_multiproc=model_cfg.get('dist_ckpt_torch_dist_multiproc', None), + parallel_save=model_cfg.get('dist_ckpt_parallel_save', False), + parallel_load=model_cfg.get('dist_ckpt_parallel_load', False), ) @_debug_time('DistributedCheckpointIO.save_checkpoint') @@ -224,16 +256,15 @@ def save_checkpoint( fs = get_filesystem(path) fs.makedirs(path, exist_ok=True) - dist_checkpointing.save( - sharded_state_dict=checkpoint, checkpoint_dir=path, sharded_strategy=self.save_sharded_strategy + validate_sharding_integrity = not (self.validated_consistency and self.assume_constant_structure) + self.validated_consistency = True + return dist_checkpointing.save( + sharded_state_dict=checkpoint, + checkpoint_dir=path, + sharded_strategy=self.save_sharded_strategy, + validate_access_integrity=validate_sharding_integrity, + async_sharded_save=self.async_save, ) - if not self.async_save: - return None - # NOTE: this logic will be simplified in MCore v0.7 - assert self.save_sharded_strategy.async_request is not None - async_request = self.save_sharded_strategy.async_request - self.save_sharded_strategy.async_request = None - return async_request @_debug_time('DistributedCheckpointIO.load_checkpoint') def load_checkpoint( @@ -242,6 +273,7 @@ def load_checkpoint( map_location: Optional[Any] = None, sharded_state_dict: Dict[str, Any] = None, strict: Optional[bool] = True, + validate_access_integrity: Optional[bool] = True, ) -> Dict[str, Any]: """Loads a distributed checkpoint. @@ -266,11 +298,24 @@ def load_checkpoint( else: sharded_strategy = None + if self.parallel_load: + if sharded_strategy is None: + sharded_strategy = get_default_load_sharded_strategy(path) + sharded_strategy = FullyParallelLoadStrategyWrapper( + sharded_strategy, get_data_parallel_group(with_context_parallel=True) + ) + + if sharded_strategy is not None: + logging.info(f'Using {sharded_strategy} dist-ckpt load strategy.') + if not strict: sharded_state_dict = self.adjust_non_strict_load(path, sharded_state_dict) return dist_checkpointing.load( - sharded_state_dict=sharded_state_dict, checkpoint_dir=path, sharded_strategy=sharded_strategy + sharded_state_dict=sharded_state_dict, + checkpoint_dir=path, + sharded_strategy=sharded_strategy, + validate_access_integrity=validate_access_integrity, ) def adjust_non_strict_load(self, path: _PATH, sharded_state_dict: Dict[str, Any]): @@ -305,17 +350,36 @@ def remove_checkpoint(self, path: _PATH) -> None: """ shutil.rmtree(path, ignore_errors=True) + @property + def save_sharded_strategy(self) -> 'SaveShardedStrategy': + if self._save_sharded_strategy is None: + self._save_sharded_strategy = self._determine_dist_ckpt_save_strategy() + return self._save_sharded_strategy + def _determine_dist_ckpt_save_strategy(self): """Determine the saving strategy based on constructor args. - If self.async_save is True instantiates an async PyT Dist strategy, - otherwise relies on MCore to create a proper strategy based on ckpt format. + Relies on the default MCore strategy unless extra PyT Distributed format arguments + are passed in config or in case of a fully parallel save in which case + a parallelization wrapper is applied. """ - save_strategy = (self.save_ckpt_format, 1) - if self.async_save: - if save_strategy[0] != 'torch_dist': - raise ValueError('Async dist-ckpt save supported only for torch_dist format') - save_strategy = TorchDistAsyncSaveShardedStrategy('torch_dist', 1) + if self.async_save and self.save_ckpt_format != 'torch_dist': + raise ValueError('Async dist-ckpt save supported only for torch_dist format') + + torch_dist_kwargs = {} if self.torch_dist_multiproc is None else dict(thread_count=self.torch_dist_multiproc) + if self.save_ckpt_format == 'torch_dist' and torch_dist_kwargs: + save_strategy = TorchDistSaveShardedStrategy(self.save_ckpt_format, 1, **torch_dist_kwargs) + else: + save_strategy = get_default_save_sharded_strategy(self.save_ckpt_format, 1) + + # MCore v0.8 introduces `use_cached_ckpt_structure` attribute + if hasattr(save_strategy, 'use_cached_ckpt_structure'): + save_strategy.use_cached_ckpt_structure = self.assume_constant_structure + + if self.parallel_save: + save_strategy = FullyParallelSaveStrategyWrapper( + save_strategy, get_data_parallel_group(with_context_parallel=True), self.assume_constant_structure + ) logging.info(f'Using {save_strategy} dist-ckpt save strategy.') return save_strategy diff --git a/nemo/utils/callbacks/torch_dist_async.py b/nemo/utils/callbacks/torch_dist_async.py deleted file mode 100644 index 1cd226af9cdbe..0000000000000 --- a/nemo/utils/callbacks/torch_dist_async.py +++ /dev/null @@ -1,298 +0,0 @@ -# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from collections import deque -from pathlib import Path -from time import time -from typing import Callable, List, NamedTuple, Optional, Tuple - -import torch -from megatron.core.dist_checkpointing.mapping import ShardedStateDict -from megatron.core.dist_checkpointing.strategies.filesystem_async import FileSystemWriterAsync -from megatron.core.dist_checkpointing.strategies.state_dict_saver import ( - save_state_dict_async_finalize, - save_state_dict_async_plan, -) -from megatron.core.dist_checkpointing.strategies.torch import ( - MCoreSavePlanner, - TorchDistSaveShardedStrategy, - _replace_state_dict_keys_with_sharded_keys, - mcore_to_pyt_state_dict, -) -from torch import multiprocessing as mp - -from nemo.utils import logging - - -class TorchDistAsyncSaveShardedStrategy(TorchDistSaveShardedStrategy): - """Async save strategy for the PyT Distributed format. - - NOTE: this class will be removed and replaced with an MCore version - """ - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.async_request = None - - def save(self, sharded_state_dict: ShardedStateDict, checkpoint_dir: Path): - """Translates MCore ShardedTensors to PyT ShardedTensors and saves in PyT Distributed format. - - Args: - sharded_state_dict (ShardedStateDict): sharded state dict to save - checkpoint_dir (Path): checkpoint directory - - Returns: None - """ - # Translate the state dict - ( - sharded_state_dict, - flat_mapping, - rename_mapping, - ) = _replace_state_dict_keys_with_sharded_keys(sharded_state_dict, self.keep_only_main_replica) - pyt_state_dict = mcore_to_pyt_state_dict(sharded_state_dict, False) - # Use PyT saving mechanism - writer = FileSystemWriterAsync(checkpoint_dir, thread_count=self.thread_count) - - save_state_dict_ret = save_state_dict_async_plan( - pyt_state_dict, - writer, - None, - planner=MCoreSavePlanner(), - ) - self.async_request = self._get_save_and_finalize_callbacks(writer, save_state_dict_ret) - return self.async_request - - def _get_save_and_finalize_callbacks(self, writer, save_state_dict_ret): - save_fn_args = writer.get_save_function_and_args() - if save_fn_args is None: # this check can be removed with MCore v0.7 - save_fn_args = None, () - save_fn, save_args = save_fn_args - - def finalize_fn(): - save_state_dict_async_finalize(*save_state_dict_ret) - torch.distributed.barrier() - - return AsyncRequest(save_fn, save_args, [finalize_fn]) - - -class AsyncRequest(NamedTuple): - """Represents an async request that needs to be scheduled for execution. - - NOTE: this class will be removed and replaced with an MCore version - - Args: - async_fn (Callable, optional): async function to call. None represents noop. - async_fn_args (Tuple): args to pass to `async_fn`. - finalize_fns (List[Callable]): list of functions to call to finalize the request. - These functions will be called synchronously after `async_fn` is done - *on all ranks*. - """ - - async_fn: Optional[Callable] - async_fn_args: Tuple - finalize_fns: List[Callable] - is_frozen: bool = False - - def add_finalize_fn(self, fn: Callable) -> None: - """Adds a new finalize function to the request. - - Args: - fn (Callable): function to add to the async request. This function - will be called *after* existing finalization functions. - - Returns: - None - """ - if self.is_frozen: - raise RuntimeError('Cannot add finalization functions to a frozen AsyncRequest') - self.finalize_fns.append(fn) - - def execute_sync(self) -> None: - """Helper to synchronously execute the request. - - This logic is equivalent to what should happen in case of the async call. - """ - if self.async_fn is not None: - self.async_fn(*self.async_fn_args) - torch.distributed.barrier() - for finalize_fn in self.finalize_fns: - finalize_fn() - - def freeze(self) -> 'AsyncRequest': - """Freezes the async request, disallowing adding new finalization functions. - - Returns: - AsyncRequest: new async request with all same fields except for the - `is_frozen` flag. - """ - return self._replace(is_frozen=True) - - -class DistributedAsyncCaller: - """Wrapper around mp.Process that ensures correct semantic of distributed finalization. - - NOTE: this class will be removed and replaced with an MCore version - - Starts process asynchronously and allows checking if all processes on all ranks are done. - """ - - def __init__(self): - self.process: Optional[mp.Process] = None - self.start_time: Optional[float] = None - - def schedule_async_call( - self, - async_fn: Optional[Callable], - save_args: Tuple, - ) -> None: - """Spawn a process with `async_fn` as the target. - - This method must be called on all ranks. - - Args: - async_fn (Callable, optional): async function to call. If None, - no process will be started. - save_args (Tuple): async function args. - """ - if async_fn is None: - return # nothing to do - torch.cuda.synchronize() - ctx = mp.get_context('fork') - self.start_time = time() - self.process = ctx.Process( - target=async_fn, - args=save_args, - ) - self.process.start() - - def is_current_async_call_done(self, blocking=False) -> bool: - """Check if async save is finished on all ranks. - - For semantic correctness, requires rank synchronization in each check. - This method must be called on all ranks. - - Args: - blocking (bool, optional): if True, will wait until the call is done - on all ranks. Otherwise, returns immediately if at least one rank - is still active. Defaults to False. - - Returns: - bool: True if all ranks are done (immediately of after active wait - if `blocking` is True), False if at least one rank is still active. - """ - # The following takes the same overhead as torch.distributed.barrier (single integer all-reduce) - is_alive = int(self.process.is_alive()) if self.process is not None else 0 - ten = torch.tensor([is_alive], dtype=torch.int, device=torch.cuda.current_device()) - logging.debug(f"[rank {torch.distributed.get_rank()}] DistributedAsyncCaller is_alive:{is_alive}") - torch.distributed.all_reduce(ten) - if ten[0] > 0 and not blocking: - return False - else: - if self.process is not None: - logging.debug(f"rank: {torch.distributed.get_rank()}, joining self.process") - self.process.join() - self.process = None - - logging.debug( - f"DistributedAsyncCaller: Async process join finished after {time() - self.start_time:.2f}s from forking" - ) - self.start_time = None - return True - - -class _ActiveAsyncRequest(NamedTuple): - """Helper to represent an active async call. - - NOTE: this class will be removed and replaced with an MCore version - - Args: - idx (int): index of the call (starting from 0) - async_caller (DistributedAsyncCaller): async caller instance that represents - the async process handling the async request - async_request (AsyncRequest): async request that is being called - """ - - idx: int - async_caller: DistributedAsyncCaller - async_request: AsyncRequest - - -class AsyncCallsQueue: - """Manages a queue of async calls. - - NOTE: this class will be removed and replaced with an MCore version - - Allows adding a new async call with `schedule_async_request` and finalizing - active calls with `maybe_finalize_async_calls`. - """ - - def __init__(self): - self.async_calls: deque[_ActiveAsyncRequest] = deque([]) - self.call_idx: int = -1 - - def schedule_async_request(self, async_request: AsyncRequest) -> int: - """Start a new async call and add it to a queue of active async calls. - - This method must be called on all ranks. - - Args: - async_request (AsyncRequest): async request to start. - - Returns: - int: index of the async call that was started. - This can help the user keep track of the async calls. - """ - self.call_idx += 1 - async_caller = DistributedAsyncCaller() - async_request = async_request.freeze() - async_caller.schedule_async_call(async_request.async_fn, async_request.async_fn_args) - self.async_calls.append(_ActiveAsyncRequest(self.call_idx, async_caller, async_request)) - return self.call_idx - - def maybe_finalize_async_calls(self, blocking=False) -> List[int]: - """Finalizes all available calls. - - This method must be called on all ranks. - - Args: - blocking (bool, optional): if True, will wait until all active requests - are done. Otherwise, finalizes only the async request that already - finished. Defaults to False. - Returns: - List[int]: list of indices (as returned by `schedule_async_request`) - of async calls that have been successfully finalized. - """ - call_idx_finalized = [] - while self.async_calls: - next_async_done = self.async_calls[0].async_caller.is_current_async_call_done(blocking) - if not next_async_done: - break - call_idx, _, async_request = self.async_calls.popleft() - for finalize_fn in async_request.finalize_fns: - finalize_fn() - ten = torch.tensor([call_idx], dtype=torch.int, device=torch.cuda.current_device()) - torch.distributed.all_reduce(ten, op=torch.distributed.ReduceOp.MAX) - assert ( - ten.item() == call_idx - ), 'Unmatched async calls. That probably means not all ranks are participating in async finalization' - call_idx_finalized.append(call_idx) - return call_idx_finalized - - def get_num_unfinalized_calls(self): - """Get the number of active async calls.""" - return len(self.async_calls) - - def close(self): - """Finalize all calls upon closing.""" - self.maybe_finalize_async_calls(blocking=True) diff --git a/nemo/utils/cast_utils.py b/nemo/utils/cast_utils.py index 21e977ec494d8..a7960be4cc4d9 100644 --- a/nemo/utils/cast_utils.py +++ b/nemo/utils/cast_utils.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from contextlib import nullcontext +from contextlib import contextmanager, nullcontext import torch @@ -91,3 +91,12 @@ def forward(self, *args): return cast_all(ret, from_dtype=torch.float32, to_dtype=from_dtype) else: return self.mod.forward(*args) + + +@contextmanager +def monkeypatched(object, name, patch): + """Temporarily monkeypatches an object.""" + pre_patched_value = getattr(object, name) + setattr(object, name, patch) + yield object + setattr(object, name, pre_patched_value) diff --git a/nemo/utils/exp_manager.py b/nemo/utils/exp_manager.py index 13cf62d699a47..f4bfb8ec95c47 100644 --- a/nemo/utils/exp_manager.py +++ b/nemo/utils/exp_manager.py @@ -14,6 +14,7 @@ import glob import os +import signal import subprocess import sys import time @@ -51,6 +52,21 @@ from nemo.utils.mcore_logger import add_handlers_to_mcore_logger from nemo.utils.model_utils import uninject_model_parallel_rank +try: + # `ptl_resiliency` is included in `gwe_resiliency_pkg` package + from ptl_resiliency import StragglerDetectionCallback + + HAVE_STRAGGLER_DET = True +except (ImportError, ModuleNotFoundError): + HAVE_STRAGGLER_DET = False + +try: + from ptl_resiliency import FaultToleranceCallback + + HAVE_FT = True +except (ImportError, ModuleNotFoundError): + HAVE_FT = False + class NotFoundError(NeMoBaseException): """Raised when a file or folder is not found""" @@ -129,6 +145,34 @@ class EMAParams: every_n_steps: int = 1 +@dataclass +class StragglerDetectionParams: + report_time_interval: float = 300 + calc_relative_gpu_perf: bool = True + calc_individual_gpu_perf: bool = True + num_gpu_perf_scores_to_log: int = 5 + gpu_relative_perf_threshold: float = 0.7 + gpu_individual_perf_threshold: float = 0.7 + stop_if_detected: bool = False + + +@dataclass +class FaultToleranceParams: + # NOTE: This config section is also read by the launcher. + # NOTE: Default values should match fault_tolerance.FaultToleranceConfig. + + workload_check_interval: float = 5.0 + initial_rank_heartbeat_timeout: Optional[float] = 60.0 * 60.0 + rank_heartbeat_timeout: Optional[float] = 45.0 * 60.0 + calculate_timeouts: bool = True + rank_termination_signal: signal.Signals = signal.SIGKILL + log_level: str = 'INFO' + max_rank_restarts: int = 0 + max_subsequent_job_failures: int = 0 + additional_ft_launcher_args: str = '' + simulated_fault: Optional[Any] = None + + @dataclass class ExpManagerConfig: """Experiment Manager config for validation of passed arguments.""" @@ -179,6 +223,12 @@ class ExpManagerConfig: max_time_per_run: Optional[str] = None # time to sleep non 0 ranks during initialization seconds_to_sleep: float = 5 + # Straggler detection + create_straggler_detection_callback: Optional[bool] = False + straggler_detection_params: Optional[StragglerDetectionParams] = field(default_factory=StragglerDetectionParams) + # Fault tolrance + create_fault_tolerance_callback: Optional[bool] = False + fault_tolerance: Optional[FaultToleranceParams] = field(default_factory=FaultToleranceParams) class TimingCallback(Callback): @@ -309,6 +359,8 @@ def exp_manager(trainer: 'pytorch_lightning.Trainer', cfg: Optional[Union[DictCo See EarlyStoppingParams dataclass above. - create_preemption_callback (bool): Flag to decide whether to enable preemption callback to save checkpoints and exit training immediately upon preemption. Default is True. + - create_straggler_detection_callback (bool): Use straggler detection callback. Default is False. + - create_fault_tolerance_callback (bool): Use fault tolerance callback. Default is False. - files_to_copy (list): A list of files to copy to the experiment logging directory. Defaults to None which copies no files. - log_local_rank_0_only (bool): Whether to only create log files for local rank 0. Defaults to False. @@ -502,6 +554,35 @@ def exp_manager(trainer: 'pytorch_lightning.Trainer', cfg: Optional[Union[DictCo trainer.max_time = cfg.max_time_per_run trainer.callbacks.append(StatelessTimer(cfg.max_time_per_run)) + if cfg.create_straggler_detection_callback: + if HAVE_STRAGGLER_DET: + logging.info("Enabling straggler detection...") + straggler_det_args_dict = dict(cfg.straggler_detection_params) + straggler_det_callback = StragglerDetectionCallback(**straggler_det_args_dict, logger=logging) + trainer.callbacks.append(straggler_det_callback) + else: + raise ValueError( + "`create_straggler_detection_callback` is True, but there is no Straggler Det. package installed." + ) + + if cfg.create_fault_tolerance_callback: + if HAVE_FT: + logging.info("Enabling fault tolerance...") + ft_params = cfg.fault_tolerance + # job failures are handled by the ft_launcher, + # here we only need to know if the autoresume is enabled. + ft_use_autoresume = ft_params.max_subsequent_job_failures > 0 + fault_tol_callback = FaultToleranceCallback( + autoresume=ft_use_autoresume, + calculate_timeouts=ft_params.calculate_timeouts, + simulated_fault_params=ft_params.simulated_fault, + ) + trainer.callbacks.append(fault_tol_callback) + else: + raise ValueError( + 'FaultToleranceCallback was enabled with create_fault_tolerance_callback, but fault_tolerance package is not installed.' + ) + if is_global_rank_zero(): # Move files_to_copy to folder and add git information if present if cfg.files_to_copy: diff --git a/nemo/utils/export_utils.py b/nemo/utils/export_utils.py index 4c7a166437cc8..534598097bf45 100644 --- a/nemo/utils/export_utils.py +++ b/nemo/utils/export_utils.py @@ -72,10 +72,12 @@ def __init__(self, weight, bias, skip_bias_add): self.weight = weight self.skip_bias_add = skip_bias_add - def forward(self, x): + def forward(self, x, weight=None): + if weight is None: + weight = self.weight if self.skip_bias_add: - return F.linear(x, self.weight), self.bias - return F.linear(x, self.weight, self.bias), None + return F.linear(x, weight), self.bias + return F.linear(x, weight, self.bias), None def get_export_format(filename: str): @@ -126,6 +128,11 @@ def parse_input_example(input_example): def to_onnxrt_input(ort_input_names, input_names, input_dict, input_list): odict = {} + if not input_names: + input_list.extend(input_dict.values()) + for k, v in zip(ort_input_names, input_list): + odict[k] = v.cpu().numpy() + return odict for k in reversed(input_names): val = None if k in input_dict: @@ -172,6 +179,8 @@ def verify_runtime(model, output, input_examples, input_names, check_tolerance=0 for input_example in input_examples: input_list, input_dict = parse_input_example(input_example) output_example = model.forward(*input_list, **input_dict) + if not isinstance(output_example, tuple): + output_example = (output_example,) ort_input = to_onnxrt_input(ort_input_names, input_names, input_dict, input_list) all_good = all_good and run_ort_and_compare(sess, ort_input, output_example, check_tolerance) status = "SUCCESS" if all_good else "FAIL" @@ -216,10 +225,12 @@ def run_ort_and_compare(sess, ort_input, output_example, check_tolerance=0.01): try: if not torch.allclose(tout, expected.cpu(), rtol=check_tolerance, atol=100 * check_tolerance): this_good = False - except Exception: # there may ne size mismatch and it may be OK + except Exception: # there may be size mismatch and it may be OK this_good = False if not this_good: - logging.info(f"onnxruntime results mismatch! PyTorch(expected):\n{expected}\nONNXruntime:\n{tout}") + logging.info( + f"onnxruntime results mismatch! PyTorch(expected, {expected.shape}):\n{expected}\nONNXruntime, {tout.shape}:\n{tout}" + ) all_good = False return all_good @@ -230,7 +241,8 @@ def run_ort_and_compare(sess, ort_input, output_example, check_tolerance=0.01): from apex.contrib.layer_norm.layer_norm import FastLayerNorm from apex.normalization import MixedFusedRMSNorm from apex.normalization.fused_layer_norm import FusedLayerNorm, MixedFusedLayerNorm - from apex.transformer.functional.fused_softmax import FusedScaleMaskSoftmax + from megatron.core.fusions.fused_layer_norm import FusedLayerNorm as MCoreFusedLayerNorm + from megatron.core.fusions.fused_softmax import FusedScaleMaskSoftmax from megatron.core.tensor_parallel.layers import ColumnParallelLinear, RowParallelLinear def replace_FusedLayerNorm(n: nn.Module) -> Optional[nn.LayerNorm]: @@ -246,21 +258,17 @@ def replace_FusedLayerNorm(n: nn.Module) -> Optional[nn.LayerNorm]: if isinstance(n, FusedLayerNorm) or isinstance(n, MixedFusedLayerNorm): shape, eps, affine = n.normalized_shape, n.eps, n.elementwise_affine - n_state = n.state_dict() + elif isinstance(n, MCoreFusedLayerNorm): + shape, eps, affine = n.weight.shape, n.eps, True elif isinstance(n, FastLayerNorm): shape, eps, affine = n.weight.shape, n.epsilon, True - n_state = n.state_dict() - elif isinstance(n, MixedFusedRMSNorm): - shape, eps, affine = n.normalized_shape, n.eps, n.elementwise_affine - tmp_n_state = n.state_dict() - n_state = {'weight': tmp_n_state['weight'], 'bias': torch.zeros_like(tmp_n_state['weight'])} else: return None n_state = n.state_dict() mod = nn.LayerNorm(shape, eps=eps, elementwise_affine=affine, device=p.device, dtype=p.dtype) - mod.load_state_dict(n_state) + mod.load_state_dict(n_state, strict=True) return mod @@ -297,7 +305,7 @@ def replace_ParallelLinear(n: nn.Module) -> Optional[nn.Linear]: mod = LinearWithBiasSkip(n.weight, n.bias, n.skip_bias_add).to(dev) n_state = n.state_dict() - mod.load_state_dict(n_state) + mod.load_state_dict(n_state, strict=False) return mod def replace_FusedScaleMaskSoftmax(n: nn.Module) -> Optional[nn.Linear]: @@ -309,7 +317,7 @@ def replace_FusedScaleMaskSoftmax(n: nn.Module) -> Optional[nn.Linear]: Equivalent LayerNorm module """ if not isinstance(n, FusedScaleMaskSoftmax): - logging.warning("This function can only change the FusedScaleMaskSoftmax module.") + logging.warning(f"This function can only change the FusedScaleMaskSoftmax module, got: {n.__class__}") return n # disable the fusion only @@ -322,6 +330,7 @@ def replace_FusedScaleMaskSoftmax(n: nn.Module) -> Optional[nn.Linear]: default_Apex_replacements = { "FusedLayerNorm": replace_FusedLayerNorm, "MixedFusedLayerNorm": replace_FusedLayerNorm, + "MCoreFusedLayerNorm": replace_FusedLayerNorm, "FastLayerNorm": replace_FusedLayerNorm, "RowParallelLinear": replace_ParallelLinear, "ColumnParallelLinear": replace_ParallelLinear, @@ -374,7 +383,7 @@ def replace_MatchedScaleMaskSoftmax(n: nn.Module) -> Optional[nn.Linear]: def wrap_module(BaseT: Type[nn.Module], DestT: Type[nn.Module]) -> Callable[[nn.Module], Optional[nn.Module]]: """ - Generic function generator to replace BaseT module with DestT wrapper. + Generic function generator to replace BaseT module with DestT wrapper. Args: BaseT : module type to replace DestT : destination module type @@ -441,7 +450,7 @@ def script_module(m: nn.Module): def replace_for_export(model: nn.Module) -> nn.Module: """ - Top-level function to replace 'default set' of modules in model, called from _prepare_for_export. + Top-level function to replace 'default set' of modules in model, called from _prepare_for_export. NOTE: This occurs in place, if you want to preserve model then make sure to copy it first. Args: model : top level module @@ -474,3 +483,25 @@ def add_casts_around_norms(model: nn.Module): "MaskedInstanceNorm1d": wrap_module(MaskedInstanceNorm1d, CastToFloatAll), } replace_modules(model, default_cast_replacements) + + +def rename_onnx_io(output, input_names, output_names): + onnx_model = onnx.load(output) + rename_map = {} + for inp, name in zip(onnx_model.graph.input, input_names): + rename_map[inp.name] = name + for out, name in zip(onnx_model.graph.output, output_names): + rename_map[out.name] = name + for n in onnx_model.graph.node: + for inp in range(len(n.input)): + if n.input[inp] in rename_map: + n.input[inp] = rename_map[n.input[inp]] + for out in range(len(n.output)): + if n.output[out] in rename_map: + n.output[out] = rename_map[n.output[out]] + + for i in range(len(input_names)): + onnx_model.graph.input[i].name = input_names[i] + for i in range(len(output_names)): + onnx_model.graph.output[i].name = output_names[i] + onnx.save(onnx_model, output) diff --git a/requirements/requirements_asr.txt b/requirements/requirements_asr.txt index 30e839fd2ca8a..7745f5326047a 100644 --- a/requirements/requirements_asr.txt +++ b/requirements/requirements_asr.txt @@ -2,14 +2,12 @@ braceexpand editdistance einops g2p_en -ipywidgets jiwer kaldi-python-io kaldiio -lhotse>=1.22.0 +lhotse>=1.24.2 librosa>=0.10.0 marshmallow -matplotlib packaging pyannote.core pyannote.metrics diff --git a/requirements/requirements_audio.txt b/requirements/requirements_audio.txt new file mode 100644 index 0000000000000..9e6f07624c9ac --- /dev/null +++ b/requirements/requirements_audio.txt @@ -0,0 +1,9 @@ +einops +lhotse>=1.22.0 +librosa>=0.10.0 +matplotlib +pesq +pystoi +scipy>=0.14 +soundfile +sox diff --git a/requirements/requirements_infer.txt b/requirements/requirements_infer.txt index c18f4e81ade39..5380398c278b0 100644 --- a/requirements/requirements_infer.txt +++ b/requirements/requirements_infer.txt @@ -1,4 +1,6 @@ +fastapi nvidia-pytriton +pydantic-settings tensorstore==0.1.45 +uvicorn zarr - diff --git a/requirements/requirements_lightning.txt b/requirements/requirements_lightning.txt index cf996584da23d..1b3397f690339 100644 --- a/requirements/requirements_lightning.txt +++ b/requirements/requirements_lightning.txt @@ -2,8 +2,8 @@ cloudpickle fiddle hydra-core>1.3,<=1.3.2 omegaconf<=2.3 -pytorch-lightning>=2.2.1 +pytorch-lightning>2.2.1 torchmetrics>=0.11.0 -transformers>=4.36.0,<=4.40.2 +transformers wandb webdataset>=0.2.86 diff --git a/requirements/requirements_nlp.txt b/requirements/requirements_nlp.txt index 494a9ab6d6720..a1dad5b64a8af 100644 --- a/requirements/requirements_nlp.txt +++ b/requirements/requirements_nlp.txt @@ -1,6 +1,5 @@ accelerated-scan boto3 -causal-conv1d==1.2.0.post2 einops faiss-cpu fasttext diff --git a/requirements/requirements_test.txt b/requirements/requirements_test.txt index f0a35f5b087e1..8c356cf3e4614 100644 --- a/requirements/requirements_test.txt +++ b/requirements/requirements_test.txt @@ -1,5 +1,5 @@ black~=24.3 -click==8.0.2 +click>=8.1 isort>5.1.0,<6.0.0 parameterized pytest diff --git a/requirements/requirements_vllm.txt b/requirements/requirements_vllm.txt new file mode 100644 index 0000000000000..a603b3c4ec53f --- /dev/null +++ b/requirements/requirements_vllm.txt @@ -0,0 +1 @@ +vllm==0.5.0 diff --git a/scripts/audio_to_audio/convert_nemo_to_lhotse.py b/scripts/audio_to_audio/convert_nemo_to_lhotse.py index e498a3b2d4609..a9923451286cf 100644 --- a/scripts/audio_to_audio/convert_nemo_to_lhotse.py +++ b/scripts/audio_to_audio/convert_nemo_to_lhotse.py @@ -14,7 +14,7 @@ import argparse -from nemo.collections.asr.data.audio_to_audio_lhotse import convert_manifest_nemo_to_lhotse +from nemo.collections.audio.data.audio_to_audio_lhotse import convert_manifest_nemo_to_lhotse def parse_args(): diff --git a/scripts/checkpoint_converters/convert_clip_hf_to_nemo.py b/scripts/checkpoint_converters/convert_clip_hf_to_nemo.py new file mode 100644 index 0000000000000..690fa74abccd2 --- /dev/null +++ b/scripts/checkpoint_converters/convert_clip_hf_to_nemo.py @@ -0,0 +1,248 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Usage example: + torchrun --nproc-per-node=1 /opt/NeMo/scripts/checkpoint_converters/convert_clip_hf_to_nemo.py \ + --input_name_or_path=openai/clip-vit-large-patch14 \ + --output_path=openai_clip.nemo \ + --hparams_file=/opt/NeMo/examples/multimodal/vision_language_foundation/clip/conf/megatron_clip_VIT-L-14.yaml + +Additionally, provide a NeMo hparams file with the correct model architecture arguments. Refer to examples/multimodal/foundation/clip/conf/megatron_clip_config.yaml. + +After conversion, you can verify with the following command: + + wget https://upload.wikimedia.org/wikipedia/commons/0/0f/1665_Girl_with_a_Pearl_Earring.jpg + torchrun --nproc-per-node=1 /opt/NeMo/examples/multimodal/vision_language_foundation/clip/megatron_clip_infer.py \ + model.restore_from_path=./openai_clip.nemo \ + image_path=./1665_Girl_with_a_Pearl_Earring.jpg \ + texts='["a dog", "a boy", "a girl"]' + +It should generate a high probability for "a girl" tag, e.g. +Given image's CLIP text probability: [('a dog', 0.0049710185), ('a boy', 0.002258187), ('a girl', 0.99277073)] + +""" + +import os +from argparse import ArgumentParser + +import torch +from omegaconf import OmegaConf +from pytorch_lightning.plugins.environments import TorchElasticEnvironment +from pytorch_lightning.trainer.trainer import Trainer +from transformers import CLIPModel + +from nemo.collections.multimodal.models.vision_language_foundation.clip.megatron_clip_models import MegatronCLIPModel +from nemo.utils import AppState, logging +from nemo.utils.distributed import initialize_distributed + +try: + from megatron.core import parallel_state + + HAVE_MEGATRON_CORE = True + +except (ImportError, ModuleNotFoundError): + + HAVE_MEGATRON_CORE = False + + +def get_args(): + parser = ArgumentParser() + parser.add_argument("--input_name_or_path", type=str, default="openai/clip-vit-base-patch32") + + parser.add_argument( + "--hparams_file", + type=str, + default=None, + required=True, + help="Path config for restoring. It's created during training and may need to be modified during restore if restore environment is different than training. Ex: /opt/NeMo/examples/multimodal/vision_language_foundation/clip/conf/megatron_clip_VIT-L-14.yaml", + ) + parser.add_argument("--output_path", type=str, default=None, required=True, help="Path to output .nemo file.") + + parser.add_argument("--gpus_per_node", type=int, required=False, default=1) + parser.add_argument("--tensor_model_parallel_size", type=int, required=False, default=1) + parser.add_argument("--pipeline_model_parallel_size", type=int, required=False, default=1) + parser.add_argument( + "--pipeline_model_parallel_split_rank", + type=int, + required=False, + default=None, + help="If pipeline parallel size > 1, this is the rank at which the encoder ends and the decoder begins.", + ) + parser.add_argument("--local_rank", type=int, required=False, default=os.getenv('LOCAL_RANK', -1)) + + args = parser.parse_args() + return args + + +def mapping_hf_state_dict(hf_model): + hf_state_dict = hf_model.state_dict() + hf_config = hf_model.config + key_mapping = { + "text_projection.weight": "text_encoder.head.weight", + "visual_projection.weight": "vision_encoder.head.weight", + } + + layer_mapping = { + ".layer_norm1.weight": ".self_attention.linear_qkv.layer_norm_weight", + ".layer_norm1.bias": ".self_attention.linear_qkv.layer_norm_bias", + ".layer_norm2.weight": ".mlp.linear_fc1.layer_norm_weight", + ".layer_norm2.bias": ".mlp.linear_fc1.layer_norm_bias", + ".self_attn.out_proj.weight": ".self_attention.linear_proj.weight", + ".self_attn.out_proj.bias": ".self_attention.linear_proj.bias", + ".mlp.fc1.weight": ".mlp.linear_fc1.weight", + ".mlp.fc1.bias": ".mlp.linear_fc1.bias", + ".mlp.fc2.weight": ".mlp.linear_fc2.weight", + ".mlp.fc2.bias": ".mlp.linear_fc2.bias", + ".pre_layrnorm.weight": ".ln_pre.weight", + ".pre_layrnorm.bias": ".ln_pre.bias", + ".post_layernorm.weight": ".final_layernorm.weight", + ".post_layernorm.bias": ".final_layernorm.bias", + ".embeddings.patch_embedding.weight": ".conv1.weight", + ".embeddings.class_embedding": ".class_token", + ".final_layer_norm.weight": ".final_layernorm.weight", + ".final_layer_norm.bias": ".final_layernorm.bias", + ".embeddings.token_embedding.weight": ".embedding.word_embeddings.weight", + "vision_encoder.embeddings.position_embedding.weight": "vision_encoder.position_embeddings.weight", + "text_encoder.embeddings.position_embedding.weight": "text_encoder.embedding.position_embeddings.weight", + } + + nemo_state_dict = {} + for key in hf_state_dict.keys(): + if key.startswith("text_model.encoder.layers"): + key_ = key.replace("text_model.encoder.layers", "text_encoder.decoder.layers") + elif key.startswith("vision_model.encoder.layers"): + key_ = key.replace("vision_model.encoder.layers", "vision_encoder.decoder.layers") + elif key.startswith('vision_model.'): + key_ = key.replace("vision_model.", "vision_encoder.") + elif key.startswith('text_model.'): + key_ = key.replace('text_model.', 'text_encoder.') + else: + key_ = key + for pat in key_mapping: + if key_ == pat: + key_ = key_.replace(pat, key_mapping[pat]) + for pat in layer_mapping: + if key_.endswith(pat): + key_ = key_[: -len(pat)] + layer_mapping[pat] + break + if "vision" in key_: + config = hf_config.vision_config + else: + config = hf_config.text_config + head_num = num_query_groups = config.num_attention_heads + hidden_size = config.hidden_size + head_size = hidden_size // head_num + heads_per_group = head_num // num_query_groups + + if 'q_proj.weight' in key_: + key_k = key.replace('q_proj', 'k_proj') + key_v = key.replace('q_proj', 'v_proj') + key_new = key_.replace('self_attn.q_proj', 'self_attention.linear_qkv') + q_weight, k_weight, v_weight = hf_state_dict[key], hf_state_dict[key_k], hf_state_dict[key_v] + + q_weight = q_weight.reshape(head_num, head_size, hidden_size) + k_weight = k_weight.reshape(num_query_groups, head_size, hidden_size) + v_weight = v_weight.reshape(num_query_groups, head_size, hidden_size) + qkv_weight = torch.empty((0, head_size, hidden_size), device=q_weight.device) + for i in range(num_query_groups): + qkv_weight = torch.cat((qkv_weight, q_weight[i * heads_per_group : (i + 1) * heads_per_group, :, :])) + qkv_weight = torch.cat((qkv_weight, k_weight[i : i + 1, :, :])) + qkv_weight = torch.cat((qkv_weight, v_weight[i : i + 1, :, :])) + qkv_weight = qkv_weight.reshape([head_size * (head_num + 2 * num_query_groups), hidden_size]) + nemo_state_dict[key_new] = qkv_weight + + elif 'q_proj.bias' in key_: + key_k = key.replace('q_proj', 'k_proj') + key_v = key.replace('q_proj', 'v_proj') + key_new = key_.replace('self_attn.q_proj', 'self_attention.linear_qkv') + q_bias, k_bias, v_bias = hf_state_dict[key], hf_state_dict[key_k], hf_state_dict[key_v] + + q_bias = q_bias.reshape(head_num, head_size) + k_bias = k_bias.reshape(num_query_groups, head_size) + v_bias = v_bias.reshape(num_query_groups, head_size) + qkv_bias = torch.empty((0, head_size), device=q_bias.device) + for i in range(num_query_groups): + qkv_bias = torch.cat((qkv_bias, q_bias[i * heads_per_group : (i + 1) * heads_per_group, :])) + qkv_bias = torch.cat((qkv_bias, k_bias[i : i + 1, :])) + qkv_bias = torch.cat((qkv_bias, v_bias[i : i + 1, :])) + qkv_bias = qkv_bias.reshape([head_size * (head_num + 2 * num_query_groups)]) + nemo_state_dict[key_new] = qkv_bias + elif not ('k_proj' in key_ or 'v_proj' in key_ or 'position_ids' in key_): + nemo_state_dict[key_] = hf_state_dict[key] + + nemo_state_dict["vision_encoder.class_token"] = nemo_state_dict["vision_encoder.class_token"].reshape(1, 1, -1) + + return nemo_state_dict + + +def convert(local_rank, rank, world_size, args): + app_state = AppState() + app_state.data_parallel_rank = 0 + num_nodes = world_size // args.gpus_per_node + trainer = Trainer( + devices=args.gpus_per_node, num_nodes=num_nodes, accelerator='gpu', plugins=[TorchElasticEnvironment()] + ) + + app_state.pipeline_model_parallel_size = args.pipeline_model_parallel_size + app_state.tensor_model_parallel_size = args.tensor_model_parallel_size + + # no use atm, use to split ranks in encoder/decoder models. + if args.pipeline_model_parallel_size > 1 and args.model_type in []: + if args.pipeline_model_parallel_split_rank is not None: + app_state.pipeline_model_parallel_split_rank = args.pipeline_model_parallel_split_rank + else: + if args.pipeline_model_parallel_size % 2 != 0: + raise ValueError( + f"Pipeline model parallel size {args.pipeline_model_parallel_size} must be even if split rank is not specified." + ) + else: + # If split rank is not set, then we set it to be pipeline_model_parallel_size // 2 - this is because in most cases we have the same number of enc/dec layers. + app_state.pipeline_model_parallel_split_rank = args.pipeline_model_parallel_size // 2 + else: + app_state.pipeline_model_parallel_split_rank = None + + app_state.model_parallel_size = app_state.tensor_model_parallel_size * app_state.pipeline_model_parallel_size + + parallel_state.initialize_model_parallel( + tensor_model_parallel_size=app_state.tensor_model_parallel_size, + pipeline_model_parallel_size=app_state.pipeline_model_parallel_size, + pipeline_model_parallel_split_rank=app_state.pipeline_model_parallel_split_rank, + ) + + app_state.pipeline_model_parallel_rank = parallel_state.get_pipeline_model_parallel_rank() + app_state.tensor_model_parallel_rank = parallel_state.get_tensor_model_parallel_rank() + + cfg = OmegaConf.load(args.hparams_file) + cfg.model.mcore_gpt = True + cfg.model.transformer_engine = True + cfg.model.text.position_embedding_type = "learned_absolute" + cfg.model.vision.position_embedding_type = "learned_absolute" + + model = MegatronCLIPModel(cfg.model, trainer) + + hf_model = CLIPModel.from_pretrained(args.input_name_or_path) + state_dict = mapping_hf_state_dict(hf_model) + + model.model.load_state_dict(state_dict, strict=False) + + model.save_to(args.output_path) + + logging.info(f'NeMo model saved to: {args.output_path}') + + +if __name__ == '__main__': + args = get_args() + local_rank, rank, world_size = initialize_distributed(args) + convert(local_rank, rank, world_size, args) diff --git a/scripts/checkpoint_converters/convert_gpt_nemo_to_mcore.py b/scripts/checkpoint_converters/convert_gpt_nemo_to_mcore.py index 70c323553eb79..1f8c69b5b2403 100644 --- a/scripts/checkpoint_converters/convert_gpt_nemo_to_mcore.py +++ b/scripts/checkpoint_converters/convert_gpt_nemo_to_mcore.py @@ -88,6 +88,9 @@ def get_mcore_model_from_nemo_file(nemo_restore_from_path, cpu_only=False): model_cfg.mcore_gpt = True model_cfg.use_cpu_initialization = cpu_only + # The key mappings use TE spec, hence set the TE flag to True + model_cfg.transformer_engine = True + logging.info("*** initializing mcore model with the following config") logging.info(OmegaConf.to_yaml(model_cfg)) trainer = Trainer(devices=1, accelerator='cpu', strategy=NLPDDPStrategy()) @@ -125,9 +128,9 @@ def build_key_mapping(nemo_cfg): f"{model_str}.decoder.final_layernorm.weight": "model.language_model.encoder.final_layernorm.weight", } if has_layernorm_bias: - mcore_to_nemo_mapping[ - f"{model_str}.decoder.final_layernorm.bias" - ] = "model.language_model.encoder.final_layernorm.bias" + mcore_to_nemo_mapping[f"{model_str}.decoder.final_layernorm.bias"] = ( + "model.language_model.encoder.final_layernorm.bias" + ) if not nemo_cfg.get("share_embeddings_and_output_weights", True): mcore_to_nemo_mapping[f"{model_str}.output_layer.weight"] = "model.language_model.output_layer.weight" @@ -135,9 +138,9 @@ def build_key_mapping(nemo_cfg): if nemo_cfg.get("position_embedding_type", 'learned_absolute') == 'rope': mcore_to_nemo_mapping[f"{model_str}.rotary_pos_emb.inv_freq"] = "model.language_model.rotary_pos_emb.inv_freq" else: - mcore_to_nemo_mapping[ - f"{model_str}.embedding.position_embeddings.weight" - ] = "model.language_model.embedding.position_embeddings.weight" + mcore_to_nemo_mapping[f"{model_str}.embedding.position_embeddings.weight"] = ( + "model.language_model.embedding.position_embeddings.weight" + ) nemo_prefix = "model.language_model.encoder.layers" mcore_prefix = f"{model_str}.decoder.layers" @@ -335,5 +338,7 @@ def run_sanity_checks(nemo_file, mcore_file, cpu_only=False, ignore_if_missing=t try: run_sanity_checks(input_nemo_file, output_nemo_file, cpu_only=cpu_only, ignore_if_missing=ignore_if_missing) except torch.cuda.OutOfMemoryError: - logging.info("✅ Conversion was successful, but could not run sanity check due to torch.cuda.OutOfMemoryError.") + logging.info( + "✅ Conversion was successful, but could not run sanity check due to torch.cuda.OutOfMemoryError." + ) logging.info("Please run the script with the same command again to run sanity check.") diff --git a/scripts/checkpoint_converters/convert_mamba2_pyt_to_nemo.py b/scripts/checkpoint_converters/convert_mamba2_pyt_to_nemo.py new file mode 100644 index 0000000000000..9dfd9565179d9 --- /dev/null +++ b/scripts/checkpoint_converters/convert_mamba2_pyt_to_nemo.py @@ -0,0 +1,176 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import re +from argparse import ArgumentParser +from collections import defaultdict +import torch +from omegaconf.omegaconf import OmegaConf +from nemo.collections.nlp.models.language_modeling.megatron_mamba_model import MegatronMambaModel +from nemo.collections.nlp.parts.megatron_trainer_builder import MegatronLMPPTrainerBuilder +from nemo.collections.nlp.parts.utils_funcs import torch_dtype_from_precision +from nemo.utils import logging + +''' +Example + +CUDA_VISIBLE_DEVICES="0" python /NeMo/scripts/checkpoint_converters/convert_mamba2_pyt_to_nemo.py \ + --input_name_or_path \ + --output_path \ + --mamba_ssm_ngroups 8 \ + --precision bf16 \ + --tokenizer_model_dir +''' + + +def get_args(): + parser = ArgumentParser() + parser.add_argument( + "--hparams_file", + type=str, + default=f"{os.path.dirname(__file__)}/../../examples/nlp/language_modeling/conf/megatron_mamba_config.yaml", + required=False, + help="Path config for restoring. It's created during training and may need to be modified during restore if restore environment is different than training. Ex: /raid/nemo_experiments/megatron_gpt/hparams.yaml", + ) + parser.add_argument("--output_path", type=str, default=None, required=True, help="Path to output .nemo file.") + parser.add_argument( + "--input_name_or_path", + type=str, + required=True, + ) + parser.add_argument("--mamba_ssm_ngroups", type=int, default=8, help="ngroups for Mamba model") + parser.add_argument( + "--precision", type=str, default="bf16", choices=["bf16", "32"], help="Precision for checkpoint weights saved" + ) + parser.add_argument( + "--tokenizer_model_dir", type=str, default=None, help="Path to the tokenizer.model, required for 8b models" + ) + args = parser.parse_args() + return args + + +def convert(args): + + checkpoint_weights = torch.load(args.input_name_or_path, map_location='cpu') + new_state_dict = {} + + if 'backbone' in list(checkpoint_weights.keys())[0]: + + layer_keys = [key for key in checkpoint_weights.keys() if re.match(r'backbone\.layers\.\d+\.', key)] + layer_numbers = set(int(re.search(r'backbone\.layers\.(\d+)\.', key).group(1)) for key in layer_keys) + num_layers = max(layer_numbers) + 1 + + direct_mappings = { + 'model.embedding.word_embeddings.weight': 'backbone.embedding.weight', + 'model.decoder.final_norm.weight': 'backbone.norm_f.weight', + 'model.output_layer.weight': 'lm_head.weight', + } + + for new_key, old_key in direct_mappings.items(): + new_state_dict[new_key] = checkpoint_weights[old_key] + + layer_attributes = [ + 'mixer.A_log', + 'mixer.D', + 'mixer.conv1d.weight', + 'mixer.conv1d.bias', + 'mixer.in_proj.weight', + 'mixer.dt_bias', + 'mixer.out_proj.weight', + 'mixer.norm.weight', + 'norm.weight', + ] + + for i in range(num_layers): + for attr in layer_attributes: + new_key = f'model.decoder.layers.{i}.{attr}' + old_key = f'backbone.layers.{i}.{attr}' + new_state_dict[new_key] = checkpoint_weights[old_key] + + # Tokenizer settings + tokenizer_library = 'huggingface' + tokenizer_type = 'EleutherAI/gpt-neox-20b' + tokenizer_model = None + + else: + + layer_keys = [key for key in checkpoint_weights.keys() if re.match(r'decoder\.layers\.\d+\.', key)] + layer_numbers = set(int(re.search(r'decoder\.layers\.(\d+)\.', key).group(1)) for key in layer_keys) + num_layers = max(layer_numbers) + 1 + + new_state_dict = {"model." + key: value for key, value in checkpoint_weights.items()} + + # Tokenizer settings + tokenizer_library = 'megatron' + tokenizer_type = 'GPTSentencePieceTokenizer' + tokenizer_model = args.tokenizer_model_dir + + layers = defaultdict(list) + + for key in new_state_dict.keys(): + match = re.match(r'model\.decoder\.layers\.(\d+)\.(\w+)', key) + if match: + index, layer_type = match.groups() + layers[index].append(layer_type) + + layer_pattern = '' + for i in range(max(map(int, layers.keys())) + 1): + index_str = str(i) + layer_types = layers.get(index_str, []) + if 'mixer' in layer_types: + layer_pattern += 'M' + elif 'self_attention' in layer_types: + layer_pattern += '*' + elif 'mlp' in layer_types: + layer_pattern += '-' + else: + raise AssertionError("Layer not found. Each layer must be eiher MLP, Mamba, or Attention") + + nemo_config = OmegaConf.load(args.hparams_file) + nemo_config.trainer["precision"] = args.precision + nemo_config.model.vocab_size, nemo_config.model.hidden_size = new_state_dict[ + 'model.embedding.word_embeddings.weight' + ].shape + nemo_config.model.num_layers = num_layers + nemo_config.model.hybrid_override_pattern = layer_pattern + nemo_config.model.mamba_ssm_ngroups = args.mamba_ssm_ngroups + nemo_config.model.tokenizer.library = tokenizer_library + nemo_config.model.tokenizer.type = tokenizer_type + nemo_config.model.tokenizer.model = tokenizer_model + + if "-" in layer_pattern: + nemo_config.model.ffn_hidden_size = new_state_dict[ + f'model.decoder.layers.{layer_pattern.index("-")}.mlp.linear_fc1.weight' + ].shape[0] + else: + nemo_config.model.ffn_hidden_size = nemo_config.model.hidden_size + + nemo_config.model.use_cpu_initialization = True + + logging.info(f"Loading Mamba2 Pytorch checkpoint : `{args.input_name_or_path}`") + + trainer = MegatronLMPPTrainerBuilder(nemo_config).create_trainer() + nemo_model_from_pyt = MegatronMambaModel(nemo_config.model, trainer) + + nemo_model_from_pyt.load_state_dict(new_state_dict, strict=True) + dtype = torch_dtype_from_precision(args.precision) + nemo_model_from_pyt = nemo_model_from_pyt.to(dtype=dtype) + nemo_model_from_pyt.save_to(args.output_path) + logging.info(f'Mamba2 NeMo model saved to: {args.output_path}') + + +if __name__ == '__main__': + args = get_args() + convert(args) diff --git a/scripts/checkpoint_converters/convert_mistral_7b_hf_to_nemo.py b/scripts/checkpoint_converters/convert_mistral_7b_hf_to_nemo.py index cb11bb5da5641..3a72661499bf1 100644 --- a/scripts/checkpoint_converters/convert_mistral_7b_hf_to_nemo.py +++ b/scripts/checkpoint_converters/convert_mistral_7b_hf_to_nemo.py @@ -54,7 +54,7 @@ def get_args(): help="Path to Huggingface Mistral-7b checkpoints", ) parser.add_argument("--output_path", type=str, default=None, required=True, help="Path to output .nemo file.") - parser.add_argument("--precision", type=str, default="32", help="Model precision") + parser.add_argument("--precision", type=str, default="bf16", help="Model precision") args = parser.parse_args() return args @@ -167,7 +167,7 @@ def convert(args): scaler = None if precision in [16, '16', '16-mixed']: scaler = GradScaler( - init_scale=nemo_config.get('native_amp_init_scale', 2 ** 32), + init_scale=nemo_config.get('native_amp_init_scale', 2**32), growth_interval=nemo_config.get('native_amp_growth_interval', 1000), hysteresis=nemo_config.get('hysteresis', 2), ) @@ -329,6 +329,22 @@ def convert(args): model = model.to(dtype=dtype) model.cfg.use_cpu_initialization = False + if getattr(tokenizer, 'chat_template', None) is not None: + import hashlib + + assert ( + hashlib.md5(tokenizer.chat_template.encode('utf-8')).hexdigest() == "0b629f783db54e02509999196956ff40" + ), "Got unkown chat template" + from omegaconf import OmegaConf, open_dict + + with open_dict(model.cfg): + model.cfg.tokenizer.chat_template = OmegaConf.create( + { + 'prefix': "{_bos_}", + 'roles': {'User': "[INST] {_content_} [/INST]", 'Assistant': "{_content_}{_eos_}"}, + } + ) + model.save_to(args.output_path) logging.info(f'NeMo model saved to: {args.output_path}') diff --git a/scripts/checkpoint_converters/convert_mixtral_hf_to_nemo.py b/scripts/checkpoint_converters/convert_mixtral_hf_to_nemo.py index 8183b0d142c1d..1bf23224357fa 100644 --- a/scripts/checkpoint_converters/convert_mixtral_hf_to_nemo.py +++ b/scripts/checkpoint_converters/convert_mixtral_hf_to_nemo.py @@ -50,11 +50,17 @@ def get_args(): parser = ArgumentParser() parser.add_argument( - "--input_name_or_path", type=str, default=None, required=True, help="Path to Huggingface Mixtral checkpoints", + "--input_name_or_path", + type=str, + default=None, + required=True, + help="Path to Huggingface Mixtral checkpoints", ) parser.add_argument("--output_path", type=str, default=None, required=True, help="Path to output .nemo file.") - valid_precision_values = [16, '16', 'bf16', '16-mixed', 'bf16-mixed', 32, '32'] - parser.add_argument("--precision", type=str, default="32", choices=valid_precision_values, help="Model precision") + valid_precision_values = [16, '16', 'bf16', '16-mixed', 'bf16-mixed'] + parser.add_argument( + "--precision", type=str, default="bf16", choices=valid_precision_values, help="Model precision" + ) parser.add_argument('--low-ram', action='store_true') parser.add_argument('--tmp-dir', default='/tmp/mixtral_ckpt_parts/') args = parser.parse_args() @@ -185,7 +191,7 @@ def make_trainer(args, nemo_config): scaler = None if precision in [16, '16', '16-mixed']: scaler = GradScaler( - init_scale=nemo_config.get('native_amp_init_scale', 2 ** 32), + init_scale=nemo_config.get('native_amp_init_scale', 2**32), growth_interval=nemo_config.get('native_amp_growth_interval', 1000), hysteresis=nemo_config.get('hysteresis', 2), ) diff --git a/scripts/checkpoint_converters/convert_siglip_hf_to_nemo.py b/scripts/checkpoint_converters/convert_siglip_hf_to_nemo.py new file mode 100644 index 0000000000000..97a9d557f78b6 --- /dev/null +++ b/scripts/checkpoint_converters/convert_siglip_hf_to_nemo.py @@ -0,0 +1,380 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Requires HF transformers updated to support Gemma Models + python3 /opt/NeMo/scripts/nlp_language_modeling/convert_gemma_hf_to_nemo.py \ + --input_name_or_path /path/to/gemma/checkpoints/hf/7b \ + --output_path /path/to/gemma-7b.nemo \ + --tokenizer_path /path/to/tokenizer.model +""" + +import os +from argparse import ArgumentParser + +import torch +from omegaconf import OmegaConf +from transformers import AutoModel, AutoProcessor + +from nemo.collections.multimodal.models.vision_language_foundation.clip.megatron_clip_models import MegatronCLIPModel +from nemo.collections.nlp.parts.megatron_trainer_builder import MegatronTrainerBuilder +from nemo.collections.nlp.parts.utils_funcs import torch_dtype_from_precision +from nemo.utils import logging + + +def create_rename_keys(num_hidden_layers): + rename_keys = [] + for i in range(num_hidden_layers): + rename_keys.extend( + [ + ( + f"text_model.encoder.layers.{i}.self_attn.k_proj.weight", + f"model.text_encoder.decoder.layers.{i}.self_attention.linear_k.weight", + ), + ( + f"text_model.encoder.layers.{i}.self_attn.k_proj.bias", + f"model.text_encoder.decoder.layers.{i}.self_attention.linear_k.bias", + ), + ( + f"text_model.encoder.layers.{i}.self_attn.q_proj.weight", + f"model.text_encoder.decoder.layers.{i}.self_attention.linear_q.weight", + ), + ( + f"text_model.encoder.layers.{i}.self_attn.q_proj.bias", + f"model.text_encoder.decoder.layers.{i}.self_attention.linear_q.bias", + ), + ( + f"text_model.encoder.layers.{i}.self_attn.v_proj.weight", + f"model.text_encoder.decoder.layers.{i}.self_attention.linear_v.weight", + ), + ( + f"text_model.encoder.layers.{i}.self_attn.v_proj.bias", + f"model.text_encoder.decoder.layers.{i}.self_attention.linear_v.bias", + ), + ( + f"text_model.encoder.layers.{i}.self_attn.out_proj.weight", + f"model.text_encoder.decoder.layers.{i}.self_attention.linear_proj.weight", + ), + ( + f"text_model.encoder.layers.{i}.self_attn.out_proj.bias", + f"model.text_encoder.decoder.layers.{i}.self_attention.linear_proj.bias", + ), + ( + f"text_model.encoder.layers.{i}.layer_norm1.weight", + f"model.text_encoder.decoder.layers.{i}.self_attention.linear_qkv.layer_norm_weight", + ), + ( + f"text_model.encoder.layers.{i}.layer_norm1.bias", + f"model.text_encoder.decoder.layers.{i}.self_attention.linear_qkv.layer_norm_bias", + ), + ( + f"text_model.encoder.layers.{i}.mlp.fc1.weight", + f"model.text_encoder.decoder.layers.{i}.mlp.linear_fc1.weight", + ), + ( + f"text_model.encoder.layers.{i}.mlp.fc1.bias", + f"model.text_encoder.decoder.layers.{i}.mlp.linear_fc1.bias", + ), + ( + f"text_model.encoder.layers.{i}.mlp.fc2.weight", + f"model.text_encoder.decoder.layers.{i}.mlp.linear_fc2.weight", + ), + ( + f"text_model.encoder.layers.{i}.mlp.fc2.bias", + f"model.text_encoder.decoder.layers.{i}.mlp.linear_fc2.bias", + ), + ( + f"text_model.encoder.layers.{i}.layer_norm2.weight", + f"model.text_encoder.decoder.layers.{i}.mlp.linear_fc1.layer_norm_weight", + ), + ( + f"text_model.encoder.layers.{i}.layer_norm2.bias", + f"model.text_encoder.decoder.layers.{i}.mlp.linear_fc1.layer_norm_bias", + ), + ( + f"vision_model.encoder.layers.{i}.self_attn.k_proj.weight", + f"model.vision_encoder.decoder.layers.{i}.self_attention.linear_k.weight", + ), + ( + f"vision_model.encoder.layers.{i}.self_attn.k_proj.bias", + f"model.vision_encoder.decoder.layers.{i}.self_attention.linear_k.bias", + ), + ( + f"vision_model.encoder.layers.{i}.self_attn.v_proj.weight", + f"model.vision_encoder.decoder.layers.{i}.self_attention.linear_v.weight", + ), + ( + f"vision_model.encoder.layers.{i}.self_attn.v_proj.bias", + f"model.vision_encoder.decoder.layers.{i}.self_attention.linear_v.bias", + ), + ( + f"vision_model.encoder.layers.{i}.self_attn.q_proj.weight", + f"model.vision_encoder.decoder.layers.{i}.self_attention.linear_q.weight", + ), + ( + f"vision_model.encoder.layers.{i}.self_attn.q_proj.bias", + f"model.vision_encoder.decoder.layers.{i}.self_attention.linear_q.bias", + ), + ( + f"vision_model.encoder.layers.{i}.self_attn.out_proj.weight", + f"model.vision_encoder.decoder.layers.{i}.self_attention.linear_proj.weight", + ), + ( + f"vision_model.encoder.layers.{i}.self_attn.out_proj.bias", + f"model.vision_encoder.decoder.layers.{i}.self_attention.linear_proj.bias", + ), + ( + f"vision_model.encoder.layers.{i}.layer_norm1.weight", + f"model.vision_encoder.decoder.layers.{i}.self_attention.linear_qkv.layer_norm_weight", + ), + ( + f"vision_model.encoder.layers.{i}.layer_norm1.bias", + f"model.vision_encoder.decoder.layers.{i}.self_attention.linear_qkv.layer_norm_bias", + ), + ( + f"vision_model.encoder.layers.{i}.mlp.fc1.weight", + f"model.vision_encoder.decoder.layers.{i}.mlp.linear_fc1.weight", + ), + ( + f"vision_model.encoder.layers.{i}.mlp.fc1.bias", + f"model.vision_encoder.decoder.layers.{i}.mlp.linear_fc1.bias", + ), + ( + f"vision_model.encoder.layers.{i}.mlp.fc2.weight", + f"model.vision_encoder.decoder.layers.{i}.mlp.linear_fc2.weight", + ), + ( + f"vision_model.encoder.layers.{i}.mlp.fc2.bias", + f"model.vision_encoder.decoder.layers.{i}.mlp.linear_fc2.bias", + ), + ( + f"vision_model.encoder.layers.{i}.layer_norm2.weight", + f"model.vision_encoder.decoder.layers.{i}.mlp.linear_fc1.layer_norm_weight", + ), + ( + f"vision_model.encoder.layers.{i}.layer_norm2.bias", + f"model.vision_encoder.decoder.layers.{i}.mlp.linear_fc1.layer_norm_bias", + ), + ] + ) + + rename_keys.extend( + [ + ("logit_scale", "model.logit_scale"), + ("logit_bias", "model.logit_bias"), + ("vision_model.embeddings.patch_embedding.weight", "model.vision_encoder.conv1.weight"), + ("vision_model.embeddings.patch_embedding.bias", "model.vision_encoder.conv1.bias"), + ("vision_model.embeddings.position_embedding.weight", "model.vision_encoder.position_embeddings.weight"), + ("vision_model.post_layernorm.weight", "model.vision_encoder.final_layernorm.weight"), + ("vision_model.post_layernorm.bias", "model.vision_encoder.final_layernorm.bias"), + ("vision_model.head.probe", "model.vision_encoder.head.probe"), + ( + "vision_model.head.attention.in_proj_weight", + "model.vision_encoder.head.cross_attention.linear_qkv.weight", + ), + ("vision_model.head.attention.in_proj_bias", "model.vision_encoder.head.cross_attention.linear_qkv.bias"), + ( + "vision_model.head.attention.out_proj.weight", + "model.vision_encoder.head.cross_attention.linear_proj.weight", + ), + ( + "vision_model.head.attention.out_proj.bias", + "model.vision_encoder.head.cross_attention.linear_proj.bias", + ), + ("vision_model.head.layernorm.weight", "model.vision_encoder.head.mlp.linear_fc1.layer_norm_weight"), + ("vision_model.head.layernorm.bias", "model.vision_encoder.head.mlp.linear_fc1.layer_norm_bias"), + ("vision_model.head.mlp.fc1.weight", "model.vision_encoder.head.mlp.linear_fc1.weight"), + ("vision_model.head.mlp.fc1.bias", "model.vision_encoder.head.mlp.linear_fc1.bias"), + ("vision_model.head.mlp.fc2.weight", "model.vision_encoder.head.mlp.linear_fc2.weight"), + ("vision_model.head.mlp.fc2.bias", "model.vision_encoder.head.mlp.linear_fc2.bias"), + ("text_model.embeddings.token_embedding.weight", "model.text_encoder.embedding.word_embeddings.weight"), + ( + "text_model.embeddings.position_embedding.weight", + "model.text_encoder.embedding.position_embeddings.weight", + ), + ("text_model.final_layer_norm.weight", "model.text_encoder.final_layernorm.weight"), + ("text_model.final_layer_norm.bias", "model.text_encoder.final_layernorm.bias"), + ("text_model.head.weight", "model.text_encoder.head.weight"), + ("text_model.head.bias", "model.text_encoder.head.bias"), + ] + ) + + return rename_keys + + +def rename_model_keys(model_state_dict, rename_keys): + """ + Rename keys in the model's state dictionary based on the provided mappings. + + Parameters: + model_state_dict (dict): The state dictionary of the model. + rename_keys (list): A list of tuples with the mapping (old_key, new_key). + + Returns: + dict: A new state dictionary with updated key names. + """ + + # Create a new state dictionary with updated key names + new_state_dict = {} + + # Track keys from the original state dict to ensure all are processed + remaining_keys = set(model_state_dict.keys()) + + # Iterate over the rename mappings + for old_key, new_key in rename_keys: + if old_key in model_state_dict: + # Rename the key and remove it from the tracking set + new_state_dict[new_key] = model_state_dict[old_key] + remaining_keys.remove(old_key) + + # Check if any keys were not converted from old to new + for old_key in remaining_keys: + print(f"Warning: Key '{old_key}' was not converted.") + + return new_state_dict + + +def adjust_tensor_shapes(model, nemo_state_dict): + """ + Adapt tensor shapes in the state dictionary to ensure compatibility with a different model structure. + + Parameters: + nemo_state_dict (dict): The state dictionary of the model. + + Returns: + dict: The updated state dictionary with modified tensor shapes for compatibility. + """ + model_config = model.cfg + + # Note: For 'key' and 'value' weight and biases, NeMo uses a consolidated tensor 'query_key_value'. + for key_ in list(nemo_state_dict.keys()): + if "vision" in key_: + config = model_config["vision"] + else: + config = model_config["text"] + num_query_groups = head_num = config["num_attention_heads"] + hidden_size = config["hidden_size"] + head_size = hidden_size // head_num + heads_per_group = head_num // num_query_groups + if "bias" in key_: + hidden_size = 1 + + if 'head.cross_attention.linear_qkv.' in key_: + key_q = key_.replace('linear_qkv', 'linear_q') + key_kv = key_.replace('linear_qkv', 'linear_kv') + q_weight, k_weight, v_weight = nemo_state_dict[key_].chunk(3) + k_weight = k_weight.reshape(num_query_groups, head_size, hidden_size) + v_weight = v_weight.reshape(num_query_groups, head_size, hidden_size) + kv_weight = torch.empty((0, head_size, hidden_size), device=q_weight.device) + for i in range(num_query_groups): + kv_weight = torch.cat((kv_weight, k_weight[i : i + 1, :, :])) + kv_weight = torch.cat((kv_weight, v_weight[i : i + 1, :, :])) + kv_weight = kv_weight.reshape([head_size * 2 * num_query_groups, hidden_size]) + if "bias" in key_: + kv_weight = kv_weight.squeeze(-1) + nemo_state_dict[key_q] = q_weight + nemo_state_dict[key_kv] = kv_weight + del nemo_state_dict[key_] + + if 'self_attention.linear_q.' in key_: + key_q = key_ + key_k = key_.replace('linear_q', 'linear_k') + key_v = key_.replace('linear_q', 'linear_v') + key_qkv = key_.replace('linear_q', 'linear_qkv') + + # [(head_num + 2 * num_query_groups) * head_size, hidden_size] + # -> [head_num, head_size, hidden_size], 2 * [num_query_groups, head_size, hidden_size] + q_weight, k_weight, v_weight = nemo_state_dict[key_q], nemo_state_dict[key_k], nemo_state_dict[key_v] + q_weight = q_weight.reshape(head_num, head_size, hidden_size) + k_weight = k_weight.reshape(num_query_groups, head_size, hidden_size) + v_weight = v_weight.reshape(num_query_groups, head_size, hidden_size) + + qkv_weight = torch.empty((0, head_size, hidden_size), device=q_weight.device) + for i in range(num_query_groups): + qkv_weight = torch.cat((qkv_weight, q_weight[i * heads_per_group : (i + 1) * heads_per_group, :, :])) + qkv_weight = torch.cat((qkv_weight, k_weight[i : i + 1, :, :])) + qkv_weight = torch.cat((qkv_weight, v_weight[i : i + 1, :, :])) + qkv_weight = qkv_weight.reshape([head_size * (head_num + 2 * num_query_groups), hidden_size]) + if "bias" in key_: + qkv_weight = qkv_weight.squeeze(-1) + nemo_state_dict[key_qkv] = qkv_weight + del nemo_state_dict[key_q], nemo_state_dict[key_k], nemo_state_dict[key_v] + + return nemo_state_dict + + +def adjust_nemo_config(model_config, ref_config): + model_config["encoder_seq_length"] = ref_config["max_position_embeddings"] + model_config["num_layers"] = ref_config["num_hidden_layers"] + model_config["ffn_hidden_size"] = ref_config["intermediate_size"] + model_config["hidden_size"] = ref_config["hidden_size"] + model_config["num_attention_heads"] = ref_config["num_attention_heads"] + model_config["num_query_groups"] = ref_config["num_key_value_heads"] + model_config["kv_channels"] = ref_config["head_dim"] + model_config["layernorm_epsilon"] = ref_config["rms_norm_eps"] + return model_config + + +def get_args(): + parser = ArgumentParser() + parser.add_argument("--input_name_or_path", type=str) + parser.add_argument("--tokenizer_path", type=str) + parser.add_argument( + "--hparams_file", + type=str, + default=os.path.join( + os.path.dirname(__file__), + '../../examples/multimodal/vision_language_foundation/clip/conf/megatron_siglip_so400m_14_384.yaml', + ), + required=False, + help="Path config for restoring. It's created during training and may need to be modified during restore if restore environment is different than training. Ex: /raid/nemo_experiments/megatron_gpt/hparams.yaml", + ) + parser.add_argument("--output_path", type=str, default=None, help="Path to output .nemo file.") + parser.add_argument( + "--precision", type=str, default="bf16", choices=["bf16", "32"], help="Precision for checkpoint weight saved" + ) + + args = parser.parse_args() + return args + + +def convert(args): + logging.info(f"Loading checkpoint from HF: `{args.input_name_or_path}`") + hf_model = AutoModel.from_pretrained(args.input_name_or_path) + # hf_processor = AutoProcessor.from_pretrained(args.input_name_or_path) + logging.info("HF Model loading done.") + + nemo_config = OmegaConf.load(args.hparams_file) + + nemo_config.trainer["precision"] = args.precision + trainer = MegatronTrainerBuilder(nemo_config).create_trainer() + model = MegatronCLIPModel(nemo_config.model, trainer) + + assert nemo_config.model.text.num_layers == nemo_config.model.vision.num_layers + rename_keys = create_rename_keys(nemo_config.model.text.num_layers) + old_state_dict = hf_model.state_dict() + new_state_dict = rename_model_keys(model_state_dict=old_state_dict, rename_keys=rename_keys) + + nemo_state_dict = adjust_tensor_shapes(model, new_state_dict) + model.load_state_dict(nemo_state_dict, strict=False) + + dtype = torch_dtype_from_precision(args.precision) + model = model.to(dtype=dtype) + model.save_to(args.output_path) + logging.info(f'NeMo model saved to: {args.output_path}') + + +if __name__ == '__main__': + args = get_args() + convert(args) diff --git a/scripts/checkpoint_converters/convert_stablediffusion_hf_to_nemo.py b/scripts/checkpoint_converters/convert_stablediffusion_hf_to_nemo.py index 86d46e8b535c4..e56298f4e2d1c 100644 --- a/scripts/checkpoint_converters/convert_stablediffusion_hf_to_nemo.py +++ b/scripts/checkpoint_converters/convert_stablediffusion_hf_to_nemo.py @@ -12,46 +12,18 @@ # See the License for the specific language governing permissions and # limitations under the License. -r""" -Conversion script to convert HuggingFace Starcoder2 checkpoints into nemo checkpoint. - Example to run this conversion script: - python convert_hf_starcoder2_to_nemo.py \ - --input_name_or_path \ - --output_path -""" +r""" Conversion script to convert HuggingFace StableDiffusion checkpoints (unet and vae) into checkpoints with nemo naming convention. """ import torch import numpy as np -import json -from pprint import pprint -from safetensors import torch as torch_s import safetensors - -import json import os -from argparse import ArgumentParser -from collections import OrderedDict +from argparse import ArgumentParser import torch import torch.nn -from omegaconf import OmegaConf -from pytorch_lightning.core.saving import _load_state as ptl_load_state -from pytorch_lightning.trainer.trainer import Trainer -from transformers import AutoModelForCausalLM, AutoTokenizer, AutoModel -from diffusers import StableDiffusionPipeline, UNet2DConditionModel, AutoencoderKL - -from nemo.collections.nlp.models.language_modeling.megatron_gpt_model import MegatronGPTModel -from nemo.collections.nlp.parts.nlp_overrides import ( - GradScaler, - MegatronHalfPrecisionPlugin, - NLPDDPStrategy, - NLPSaveRestoreConnector, - PipelineMixedPrecisionPlugin, -) from nemo.utils import logging -intkey = lambda x: int(x) - def filter_keys(rule, dict): keys = list(dict.keys()) nd = {k: dict[k] for k in keys if rule(k)} @@ -122,9 +94,6 @@ def model_to_tree(model): keys = list(model.keys()) tree = SegTree() for k in keys: - # wk = model.get(wk, torch.tensor([])) - # bk = model.get(bk, torch.tensor([])) - # tree.add(k, (wk, bk)) tree.add(k, "leaf") return tree @@ -147,32 +116,15 @@ def get_args(): def make_tiny_config(config): ''' dial down the config file to make things tractable ''' - # TODO return config def load_hf_ckpt(in_dir, args): - # takes a directory as input - # params_file = os.path.join(in_dir, 'config.json') - # assert os.path.exists(params_file) - # with open(params_file, 'r') as fp: - # model_args = json.load(fp) - # if args.debug: - # model_args = make_tiny_config(model_args) - - # # model = AutoModel.from_pretrained(in_dir) - # model = AutoModel.from_config(model_args) - # if args.model == 'unet': - # model = UNet2DConditionModel.from_pretrained(in_dir) - # elif args.model == 'vae': - # model = AutoencoderKL.from_pretrained(in_dir) - # model = torch_s.load(in_dir + "/diffusion") - # print(model) + # takes a directory as input, loads the checkpoint into a dict ckpt = {} + assert os.path.isdir(in_dir), "Expected directory with safetensors in it." with safetensors.safe_open(in_dir + "/diffusion_pytorch_model.safetensors", framework="pt") as f: for k in f.keys(): ckpt[k] = f.get_tensor(k) - # input("enter to continue...") - # ckpt = model.state_dict() return args, ckpt def dup_convert_name_recursive(tree: SegTree, convert_name=None): @@ -203,7 +155,7 @@ def convert_input_keys(hf_tree: SegTree): # start counting blocks from now on nemo_inp_blk = 1 down_blocks = hf_tree['down_blocks'] - down_blocks_keys = sorted(list(down_blocks.nodes.keys()), key=intkey) + down_blocks_keys = sorted(list(down_blocks.nodes.keys()), key=int) for downblockid in down_blocks_keys: block = down_blocks[str(downblockid)] # compute number of resnets, attentions, downsamplers in this block @@ -212,14 +164,14 @@ def convert_input_keys(hf_tree: SegTree): downsamplers = block.nodes.get('downsamplers', SegTree()) if len(attentions) == 0: # no attentions, this is a DownBlock2d - for resid in sorted(list(resnets.nodes.keys()), key=intkey): + for resid in sorted(list(resnets.nodes.keys()), key=int): resid = str(resid) resnets[resid].convert_name = f"input_blocks.{nemo_inp_blk}.0" map_resnet_block(resnets[resid]) nemo_inp_blk += 1 elif len(attentions) == len(resnets): # there are attention blocks here -- each resnet+attention becomes a block - for resid in sorted(list(resnets.nodes.keys()), key=intkey): + for resid in sorted(list(resnets.nodes.keys()), key=int): resid = str(resid) resnets[resid].convert_name = f"input_blocks.{nemo_inp_blk}.0" map_resnet_block(resnets[resid]) @@ -244,9 +196,8 @@ def clean_convert_names(tree): def map_attention_block(att_tree: SegTree): ''' this HF tree can either be an AttentionBlock or a DualAttention block currently assumed AttentionBlock - ''' - # TODO: Add check for dual attention block + # TODO(@rohitrango): Add check for dual attention block, but right now this works with SD and SDXL def check_att_type(tree): return "att_block" @@ -323,7 +274,7 @@ def convert_output_keys(hf_tree: SegTree): ''' output keys is similar to input keys ''' nemo_inp_blk = 0 up_blocks = hf_tree['up_blocks'] - up_blocks_keys = sorted(list(up_blocks.nodes.keys()), key=intkey) + up_blocks_keys = sorted(list(up_blocks.nodes.keys()), key=int) for downblockid in up_blocks_keys: block = up_blocks[str(downblockid)] @@ -333,7 +284,7 @@ def convert_output_keys(hf_tree: SegTree): upsamplers = block.nodes.get('upsamplers', SegTree()) if len(attentions) == 0: # no attentions, this is a UpBlock2D - for resid in sorted(list(resnets.nodes.keys()), key=intkey): + for resid in sorted(list(resnets.nodes.keys()), key=int): resid = str(resid) resnets[resid].convert_name = f"output_blocks.{nemo_inp_blk}.0" map_resnet_block(resnets[resid]) @@ -341,7 +292,7 @@ def convert_output_keys(hf_tree: SegTree): elif len(attentions) == len(resnets): # there are attention blocks here -- each resnet+attention becomes a block - for resid in sorted(list(resnets.nodes.keys()), key=intkey): + for resid in sorted(list(resnets.nodes.keys()), key=int): resid = str(resid) resnets[resid].convert_name = f"output_blocks.{nemo_inp_blk}.0" map_resnet_block(resnets[resid]) @@ -351,9 +302,8 @@ def convert_output_keys(hf_tree: SegTree): else: logging.warning("number of attention blocks is not the same as resnets - whats going on?") - # if there is a downsampler, then also append it + # if there is an upsampler, then also append it if len(upsamplers) > 0: - # for k in upsamplers.nodes.keys(): nemo_inp_blk -= 1 upsamplenum = 1 if len(attentions) == 0 else 2 # if there are attention modules, upsample is module2, else it is module 1 (to stay consistent with SD) upsamplers['0'].convert_name = f"output_blocks.{nemo_inp_blk}.{upsamplenum}" @@ -387,17 +337,10 @@ def convert_encoder(hf_tree: SegTree): # map the `mid_block` ( NeMo's mid layer is hardcoded in terms of number of modules) encoder['mid_block'].convert_name = 'mid' - # encoder['mid_block.resnets.0'].convert_name = 'block_1' - # encoder['mid_block.resnets.1'].convert_name = 'block_2' - # map_resnet_block(encoder['mid_block.resnets.0']) - # map_resnet_block(encoder['mid_block.resnets.1']) - # for reskey in {'conv1', 'conv2', 'norm1', 'norm2'}: - # dup_convert_name_recursive(encoder[f'mid_block.resnets.0.{reskey}'], reskey) - # dup_convert_name_recursive(encoder[f'mid_block.resnets.1.{reskey}'], reskey) dup_convert_name_recursive(encoder[f'mid_block.resnets.0'], 'block_1') dup_convert_name_recursive(encoder[f'mid_block.resnets.1'], 'block_2') - # attention part + # attention part of the mid block att = encoder['mid_block.attentions.0'] att.convert_name = 'attn_1' dup_convert_name_recursive(att['group_norm'], 'norm') @@ -418,13 +361,7 @@ def convert_decoder(hf_tree: SegTree): decoder['mid_block'].convert_name = 'mid' dup_convert_name_recursive(decoder[f'mid_block.resnets.0'], 'block_1') dup_convert_name_recursive(decoder[f'mid_block.resnets.1'], 'block_2') - # decoder['mid_block.resnets.0'].convert_name = 'block_1' - # decoder['mid_block.resnets.1'].convert_name = 'block_2' - # map_resnet_block(encoder['mid_block.resnets.0']) - # map_resnet_block(encoder['mid_block.resnets.1']) - # for reskey in {'conv1', 'conv2', 'norm1', 'norm2'}: - # dup_convert_name_recursive(decoder[f'mid_block.resnets.0.{reskey}'], reskey) - # dup_convert_name_recursive(decoder[f'mid_block.resnets.1.{reskey}'], reskey) + # map the attention part of decoder's midblock att = decoder['mid_block.attentions.0'] att.convert_name = 'attn_1' dup_convert_name_recursive(att['group_norm'], 'norm') @@ -482,7 +419,6 @@ def convert(args): for hf_key, nemo_key in mapping.items(): nemo_ckpt[nemo_key] = hf_ckpt[hf_key] # save this - # torch.save(args.output_path, nemo_ckpt) torch.save(nemo_ckpt, args.output_path) logging.info(f"Saved nemo file to {args.output_path}") diff --git a/scripts/checkpoint_converters/quantize_model_to_nf4.py b/scripts/checkpoint_converters/quantize_model_to_nf4.py new file mode 100644 index 0000000000000..05d9c4010c026 --- /dev/null +++ b/scripts/checkpoint_converters/quantize_model_to_nf4.py @@ -0,0 +1,77 @@ +from argparse import ArgumentParser +from typing import List + +import torch +from pytorch_lightning import Trainer +from torch import nn + +from nemo.collections.nlp.models.language_modeling.megatron_gpt_sft_model import MegatronGPTSFTModel +from nemo.collections.nlp.modules.common.megatron.adapters.qlora import nf4_quantize +from nemo.collections.nlp.parts.nlp_overrides import MegatronHalfPrecisionPlugin, NLPDDPStrategy +from nemo.utils import logging + +''' +This script quantizes the weights of linear layers to NF4 precision, then saves them in BF16 precision. +The resulting model will have the same format as the input, but have weights compatible with adapters trained +with QLoRA. +Flow of QLoRA inference +- Path 1 (online quantize): similar to training, set eval peft_scheme to 'qlora' and linear layers will be quantized + immediately after model loading. This is applicable to framework inference only. +- Path 2 (offline quantize): run this script to get a new pretrained base model, then set eval `peft_scheme` to `lora`. +Path 1 and Path 2 yield identical inference results, but Path 2 enables deployment of a QLoRA model without further +changes downstream. + +Example usage: +python scripts/checkpoint_converters/quantize_model_to_nf4.py \ +--input_name_or_path \ +--output_path \ +--target_modules linear_qkv,linear_proj,linear_fc1,linear_fc2 +''' + + +def corrupt_linear_weight_(model: nn.Module, target_modules: List[str]): + """ + Corrupt the linear weights of a model as specified by quantize_targets + "Corrupting" refers to quantizing the linear weights to NF4 then casting back to BF16 + """ + state_dict = model.state_dict() + keys = state_dict.keys() + for k in keys: + if any(f"{l}.weight" in k for l in target_modules): + # Convert a BF16 tensor to NF4 then back to BF16 + state_dict[k] = nf4_quantize(state_dict[k]).dequantize() + model.load_state_dict(state_dict) + + +def get_args(): + parser = ArgumentParser() + parser.add_argument( + "--input_name_or_path", + type=str, + required=True, + help="Path to .nemo base model checkpoint", + ) + parser.add_argument("--output_path", type=str, required=True, help="Path to output quantized .nemo file.") + parser.add_argument( + "--target_modules", + type=str, + default="linear_qkv,linear_proj,linear_fc1,linear_fc2", + help="Comma separated list of which linear module(s) to quantize", + ) + args = parser.parse_args() + return args + + +if __name__ == '__main__': + args = get_args() + dummy_trainer = Trainer( + devices=1, + accelerator='gpu', + strategy=NLPDDPStrategy(), + plugins=[MegatronHalfPrecisionPlugin(precision='bf16-mixed', device='cuda')], + ) + model = MegatronGPTSFTModel.restore_from(args.input_name_or_path, trainer=dummy_trainer).to(torch.bfloat16) + corrupt_linear_weight_(model, args.target_modules.split(',')) + + model.save_to(args.output_path) + logging.info(f"Quantized model saved to {args.output_path}") diff --git a/scripts/deploy/multimodal/deploy_triton.py b/scripts/deploy/multimodal/deploy_triton.py new file mode 100755 index 0000000000000..1e339b3405cf3 --- /dev/null +++ b/scripts/deploy/multimodal/deploy_triton.py @@ -0,0 +1,183 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import logging +import os +import sys +from pathlib import Path + +from nemo.deploy import DeployPyTriton + +LOGGER = logging.getLogger("NeMo") + +multimodal_supported = True +try: + from nemo.export.tensorrt_mm_exporter import TensorRTMMExporter +except Exception as e: + LOGGER.warning(f"Cannot import the TensorRTMMExporter exporter, it will not be available. {type(e).__name__}: {e}") + multimodal_supported = False + + +def get_args(argv): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + description=f"Deploy nemo models to Triton", + ) + parser.add_argument("-vc", "--visual_checkpoint", type=str, help="Source .nemo file for visual model") + parser.add_argument( + "-lc", + "--llm_checkpoint", + type=str, + required=False, + help="Source .nemo file for llm", + ) + parser.add_argument( + "-mt", + "--model_type", + type=str, + required=True, + choices=["neva", "video-neva"], + help="Type of the model. neva and video-neva are only supported.", + ) + parser.add_argument( + "-lmt", + "--llm_model_type", + type=str, + required=True, + choices=["gptnext", "gpt", "llama", "falcon", "starcoder", "mixtral", "gemma"], + help="Type of LLM. gptnext, gpt, llama, falcon, and starcoder are only supported." + " gptnext and gpt are the same and keeping it for backward compatibility", + ) + parser.add_argument("-tmn", "--triton_model_name", required=True, type=str, help="Name for the service") + parser.add_argument("-tmv", "--triton_model_version", default=1, type=int, help="Version for the service") + parser.add_argument( + "-trp", "--triton_port", default=8000, type=int, help="Port for the Triton server to listen for requests" + ) + parser.add_argument( + "-tha", "--triton_http_address", default="0.0.0.0", type=str, help="HTTP address for the Triton server" + ) + parser.add_argument( + "-tmr", "--triton_model_repository", default=None, type=str, help="Folder for the trt-llm conversion" + ) + parser.add_argument("-ng", "--num_gpus", default=1, type=int, help="Number of GPUs for the deployment") + parser.add_argument( + "-dt", + "--dtype", + choices=["bfloat16", "float16"], + default="bfloat16", + type=str, + help="dtype of the model on TensorRT", + ) + parser.add_argument("-mil", "--max_input_len", default=4096, type=int, help="Max input length of the model") + parser.add_argument("-mol", "--max_output_len", default=256, type=int, help="Max output length of the model") + parser.add_argument("-mbs", "--max_batch_size", default=1, type=int, help="Max batch size of the model") + parser.add_argument("-mml", "--max_multimodal_len", default=3072, type=int, help="Max length of multimodal input") + args = parser.parse_args(argv) + return args + + +def get_trt_deployable(args): + if args.triton_model_repository is None: + trt_path = "/tmp/trt_model_dir/" + LOGGER.info( + "/tmp/trt_model_dir/ path will be used as the TensorRT folder. " + "Please set the --triton_model_repository parameter if you'd like to use a path that already " + "includes the TensorRT model files." + ) + Path(trt_path).mkdir(parents=True, exist_ok=True) + else: + trt_path = args.triton_model_repository + + if args.visual_checkpoint is None and args.triton_model_repository is None: + raise ValueError( + "The provided model repository is not a valid TensorRT model " + "directory. Please provide a --visual_checkpoint." + ) + + if args.visual_checkpoint is None and not os.path.isdir(args.triton_model_repository): + raise ValueError( + "The provided model repository is not a valid TensorRT model " + "directory. Please provide a --visual_checkpoint." + ) + + if args.visual_checkpoint is not None and args.model_type is None: + raise ValueError("Model type is required to be defined if a nemo checkpoint is provided.") + + exporter = TensorRTMMExporter( + model_dir=trt_path, + load_model=(args.visual_checkpoint is None), + ) + + if args.visual_checkpoint is not None: + try: + LOGGER.info("Export operation will be started to export the nemo checkpoint to TensorRT.") + exporter.export( + visual_checkpoint_path=args.visual_checkpoint, + llm_checkpoint_path=args.llm_checkpoint, + model_type=args.model_type, + llm_model_type=args.llm_model_type, + tensor_parallel_size=args.num_gpus, + max_input_len=args.max_input_len, + max_output_len=args.max_output_len, + max_batch_size=args.max_batch_size, + max_multimodal_len=args.max_multimodal_len, + dtype=args.dtype, + ) + except Exception as error: + raise RuntimeError("An error has occurred during the model export. Error message: " + str(error)) + + return exporter + + +def nemo_deploy(argv): + args = get_args(argv) + + loglevel = logging.INFO + + LOGGER.setLevel(loglevel) + LOGGER.info("Logging level set to {}".format(loglevel)) + LOGGER.info(args) + + triton_deployable = get_trt_deployable(args) + + try: + nm = DeployPyTriton( + model=triton_deployable, + triton_model_name=args.triton_model_name, + triton_model_version=args.triton_model_version, + max_batch_size=args.max_batch_size, + port=args.triton_port, + address=args.triton_http_address, + ) + + LOGGER.info("Triton deploy function will be called.") + nm.deploy() + except Exception as error: + LOGGER.error("Error message has occurred during deploy function. Error message: " + str(error)) + return + + try: + LOGGER.info("Model serving on Triton is will be started.") + nm.serve() + except Exception as error: + LOGGER.error("Error message has occurred during deploy function. Error message: " + str(error)) + return + + LOGGER.info("Model serving will be stopped.") + nm.stop() + + +if __name__ == '__main__': + nemo_deploy(sys.argv[1:]) diff --git a/scripts/deploy/multimodal/query.py b/scripts/deploy/multimodal/query.py new file mode 100644 index 0000000000000..955d708730ac4 --- /dev/null +++ b/scripts/deploy/multimodal/query.py @@ -0,0 +1,59 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import sys + +from nemo.deploy.multimodal import NemoQueryMultimodal + + +def get_args(argv): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + description=f"Query Triton Multimodal server", + ) + parser.add_argument("-u", "--url", default="0.0.0.0", type=str, help="url for the triton server") + parser.add_argument("-mn", "--model_name", required=True, type=str, help="Name of the triton model") + parser.add_argument("-mt", "--model_type", required=True, type=str, help="Type of the triton model") + parser.add_argument("-int", "--input_text", required=True, type=str, help="Input text") + parser.add_argument("-im", "--input_media", required=True, type=str, help="File path of input media") + parser.add_argument("-bs", "--batch_size", default=1, type=int, help="Batch size") + parser.add_argument("-mol", "--max_output_len", default=128, type=int, help="Max output token length") + parser.add_argument("-tk", "--top_k", default=1, type=int, help="top_k") + parser.add_argument("-tpp", "--top_p", default=0.0, type=float, help="top_p") + parser.add_argument("-t", "--temperature", default=1.0, type=float, help="temperature") + parser.add_argument("-rp", "--repetition_penalty", default=1.0, type=float, help="repetition_penalty") + parser.add_argument("-nb", "--num_beams", default=1, type=int, help="num_beams") + parser.add_argument("-it", "--init_timeout", default=60.0, type=float, help="init timeout for the triton server") + + args = parser.parse_args(argv) + return args + + +if __name__ == '__main__': + args = get_args(sys.argv[1:]) + nq = NemoQueryMultimodal(url=args.url, model_name=args.model_name, model_type=args.model_type) + output = nq.query( + input_text=args.input_text, + input_media=args.input_media, + batch_size=args.batch_size, + max_output_len=args.max_output_len, + top_k=args.top_k, + top_p=args.top_p, + temperature=args.temperature, + repetition_penalty=args.repetition_penalty, + num_beams=args.num_beams, + init_timeout=args.init_timeout, + ) + print(output) diff --git a/scripts/deploy/nlp/deploy_inframework_triton.py b/scripts/deploy/nlp/deploy_inframework_triton.py new file mode 100755 index 0000000000000..b698e4cbacfd8 --- /dev/null +++ b/scripts/deploy/nlp/deploy_inframework_triton.py @@ -0,0 +1,103 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import logging +import sys + +from nemo.deploy import DeployPyTriton + +LOGGER = logging.getLogger("NeMo") + +megatron_llm_supported = True +try: + from nemo.deploy.nlp import MegatronLLMDeployable +except Exception as e: + LOGGER.warning(f"Cannot import MegatronLLMDeployable, it will not be available. {type(e).__name__}: {e}") + megatron_llm_supported = False + + +def get_args(argv): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + description=f"Deploy nemo models to Triton", + ) + parser.add_argument("-nc", "--nemo_checkpoint", type=str, help="Source .nemo file") + parser.add_argument("-tmn", "--triton_model_name", required=True, type=str, help="Name for the service") + parser.add_argument("-tmv", "--triton_model_version", default=1, type=int, help="Version for the service") + parser.add_argument( + "-trp", "--triton_port", default=8000, type=int, help="Port for the Triton server to listen for requests" + ) + parser.add_argument( + "-tha", "--triton_http_address", default="0.0.0.0", type=str, help="HTTP address for the Triton server" + ) + parser.add_argument("-ng", "--num_gpus", default=1, type=int, help="Number of GPUs for the deployment") + parser.add_argument("-mbs", "--max_batch_size", default=8, type=int, help="Max batch size of the model") + parser.add_argument("-dm", "--debug_mode", default=False, action='store_true', help="Enable debug mode") + args = parser.parse_args(argv) + return args + + +def get_nemo_deployable(args): + if args.nemo_checkpoint is None: + raise ValueError("In-Framework deployment requires a .nemo checkpoint") + + return MegatronLLMDeployable(args.nemo_checkpoint, args.num_gpus) + + +def nemo_deploy(argv): + args = get_args(argv) + + if args.debug_mode: + loglevel = logging.DEBUG + else: + loglevel = logging.INFO + + LOGGER.setLevel(loglevel) + LOGGER.info("Logging level set to {}".format(loglevel)) + LOGGER.info(args) + + if not megatron_llm_supported: + raise ValueError("MegatronLLMDeployable is not supported in this environment.") + triton_deployable = get_nemo_deployable(args) + + try: + nm = DeployPyTriton( + model=triton_deployable, + triton_model_name=args.triton_model_name, + triton_model_version=args.triton_model_version, + max_batch_size=args.max_batch_size, + port=args.triton_port, + address=args.triton_http_address, + ) + + LOGGER.info("Triton deploy function will be called.") + nm.deploy() + except Exception as error: + LOGGER.error("Error message has occurred during deploy function. Error message: " + str(error)) + return + + try: + LOGGER.info("Model serving on Triton is will be started.") + nm.serve() + except Exception as error: + LOGGER.error("Error message has occurred during deploy function. Error message: " + str(error)) + return + + LOGGER.info("Model serving will be stopped.") + nm.stop() + + +if __name__ == '__main__': + nemo_deploy(sys.argv[1:]) diff --git a/scripts/deploy/nlp/deploy_triton.py b/scripts/deploy/nlp/deploy_triton.py index d0854916cd381..a306231bcd611 100755 --- a/scripts/deploy/nlp/deploy_triton.py +++ b/scripts/deploy/nlp/deploy_triton.py @@ -1,4 +1,4 @@ -# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2023-2024, NVIDIA CORPORATION. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -18,12 +18,26 @@ import sys from pathlib import Path +import uvicorn + from nemo.deploy import DeployPyTriton -from nemo.deploy.nlp import MegatronLLMDeployable -from nemo.export import TensorRTLLM LOGGER = logging.getLogger("NeMo") +megatron_llm_supported = True +try: + from nemo.deploy.nlp import MegatronLLMDeployable +except Exception as e: + LOGGER.warning(f"Cannot import MegatronLLMDeployable, it will not be available. {type(e).__name__}: {e}") + megatron_llm_supported = False + +trt_llm_supported = True +try: + from nemo.export.tensorrt_llm import TensorRTLLM +except Exception as e: + LOGGER.warning(f"Cannot import the TensorRTLLM exporter, it will not be available. {type(e).__name__}: {e}") + trt_llm_supported = False + def get_args(argv): parser = argparse.ArgumentParser( @@ -63,6 +77,8 @@ def get_args(argv): "-tmr", "--triton_model_repository", default=None, type=str, help="Folder for the trt-llm conversion" ) parser.add_argument("-ng", "--num_gpus", default=1, type=int, help="Number of GPUs for the deployment") + parser.add_argument("-tps", "--tensor_parallelism_size", default=1, type=int, help="Tensor parallelism size") + parser.add_argument("-pps", "--pipeline_parallelism_size", default=1, type=int, help="Pipeline parallelism size") parser.add_argument( "-dt", "--dtype", @@ -89,6 +105,13 @@ def get_args(argv): action='store_true', help="Disables the remove input padding option.", ) + parser.add_argument( + "-upe", + "--use_parallel_embedding", + default=False, + action='store_true', + help='Use parallel embedding feature of TensorRT-LLM.', + ) parser.add_argument( "-mbm", '--multi_block_mode', @@ -146,11 +169,21 @@ def get_args(argv): nargs='?', const=None, default='TensorRT-LLM', - choices=['TensorRT-LLM', 'vLLM', 'In-Framework'], + choices=['TensorRT-LLM', 'In-Framework'], help="Different options to deploy nemo model.", ) + parser.add_argument( + "-srs", + "--start_rest_service", + default="False", + type=str, + help="Starts the REST service for OpenAI API support", + ) + parser.add_argument( + "-sha", "--service_http_address", default="0.0.0.0", type=str, help="HTTP address for the REST Service" + ) + parser.add_argument("-sp", "--service_port", default=8080, type=int, help="Port for the REST Service") parser.add_argument("-dm", "--debug_mode", default=False, action='store_true', help="Enable debug mode") - args = parser.parse_args(argv) return args @@ -160,8 +193,8 @@ def get_trtllm_deployable(args): trt_llm_path = "/tmp/trt_llm_model_dir/" LOGGER.info( "/tmp/trt_llm_model_dir/ path will be used as the TensorRT LLM folder. " - "Please set this parameter if you'd like to use a path that has already " - "included the TensorRT LLM model files." + "Please set the --triton_model_repository parameter if you'd like to use a path that already " + "includes the TensorRT LLM model files." ) Path(trt_llm_path).mkdir(parents=True, exist_ok=True) else: @@ -204,6 +237,11 @@ def get_trtllm_deployable(args): "There are {0} tables and {1} task ids.".format(len(ptuning_tables_files), len(args.task_ids)) ) + if args.start_rest_service: + if args.service_port == args.triton_port: + logging.error("REST service port and Triton server port cannot use the same port.") + return + trt_llm_exporter = TensorRTLLM( model_dir=trt_llm_path, lora_ckpt_list=args.lora_ckpt, @@ -218,13 +256,14 @@ def get_trtllm_deployable(args): nemo_checkpoint_path=args.nemo_checkpoint, model_type=args.model_type, n_gpus=args.num_gpus, - tensor_parallel_size=args.num_gpus, - pipeline_parallel_size=1, + tensor_parallelism_size=args.tensor_parallelism_size, + pipeline_parallelism_size=args.pipeline_parallelism_size, max_input_len=args.max_input_len, max_output_len=args.max_output_len, max_batch_size=args.max_batch_size, max_num_tokens=args.max_num_tokens, opt_num_tokens=args.opt_num_tokens, + use_parallel_embedding=args.use_parallel_embedding, max_prompt_embedding_table_size=args.max_prompt_embedding_table_size, paged_kv_cache=(not args.no_paged_kv_cache), remove_input_padding=(not args.disable_remove_input_padding), @@ -233,7 +272,6 @@ def get_trtllm_deployable(args): use_lora_plugin=args.use_lora_plugin, lora_target_modules=args.lora_target_modules, max_lora_rank=args.max_lora_rank, - save_nemo_model_config=True, ) except Exception as error: raise RuntimeError("An error has occurred during the model export. Error message: " + str(error)) @@ -282,11 +320,13 @@ def nemo_deploy(argv): backend = args.backend.lower() if backend == 'tensorrt-llm': + if not trt_llm_supported: + raise ValueError("TensorRT-LLM engine is not supported in this environment.") triton_deployable = get_trtllm_deployable(args) elif backend == 'in-framework': + if not megatron_llm_supported: + raise ValueError("MegatronLLMDeployable is not supported in this environment.") triton_deployable = get_nemo_deployable(args) - elif backend == 'vllm': - raise ValueError("vLLM will be supported in the next release.") else: raise ValueError("Backend: {0} is not supported.".format(backend)) @@ -309,11 +349,21 @@ def nemo_deploy(argv): try: LOGGER.info("Model serving on Triton is will be started.") + if args.start_rest_service == "True": + try: + LOGGER.info("REST service will be started.") + uvicorn.run( + 'nemo.deploy.service.rest_model_api:app', + host=args.service_http_address, + port=args.service_port, + reload=True, + ) + except Exception as error: + logging.error("Error message has occurred during REST service start. Error message: " + str(error)) nm.serve() except Exception as error: LOGGER.error("Error message has occurred during deploy function. Error message: " + str(error)) return - LOGGER.info("Model serving will be stopped.") nm.stop() diff --git a/scripts/deploy/nlp/deploy_vllm_triton.py b/scripts/deploy/nlp/deploy_vllm_triton.py new file mode 100755 index 0000000000000..a6a861575f698 --- /dev/null +++ b/scripts/deploy/nlp/deploy_vllm_triton.py @@ -0,0 +1,172 @@ +# Copyright (c) 2023-2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import logging +import os +import sys +import tempfile + +from nemo.deploy import DeployPyTriton + +LOGGER = logging.getLogger("NeMo") + +try: + from nemo.export.vllm_exporter import vLLMExporter +except Exception as e: + LOGGER.error(f"Cannot import the vLLM exporter. {type(e).__name__}: {e}") + sys.exit(1) + + +def get_args(argv): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + description=f"Export NeMo models to vLLM and deploy them on Triton", + ) + parser.add_argument("-nc", "--nemo_checkpoint", type=str, help="Source .nemo file") + parser.add_argument( + "-mt", + "--model_type", + type=str, + required=False, + choices=["llama", "mistral", "mixtral", "starcoder2", "gemma"], + help="Type of the model", + ) + parser.add_argument("-tmn", "--triton_model_name", required=True, type=str, help="Name for the service") + parser.add_argument("-tmv", "--triton_model_version", default=1, type=int, help="Version for the service") + parser.add_argument( + "-trp", "--triton_port", default=8000, type=int, help="Port for the Triton server to listen for requests" + ) + parser.add_argument( + "-tha", "--triton_http_address", default="0.0.0.0", type=str, help="HTTP address for the Triton server" + ) + parser.add_argument( + "-tmr", "--triton_model_repository", default=None, type=str, help="Folder for the vLLM conversion" + ) + parser.add_argument("-tps", "--tensor_parallelism_size", default=1, type=int, help="Tensor parallelism size") + parser.add_argument( + "-dt", + "--dtype", + choices=["bfloat16", "float16", "fp8", "int8"], + default="bfloat16", + type=str, + help="dtype of the model on TensorRT-LLM or vLLM", + ) + parser.add_argument( + "-mml", "--max_model_len", default=512, type=int, help="Max input + ouptut length of the model" + ) + parser.add_argument("-mbs", "--max_batch_size", default=8, type=int, help="Max batch size of the model") + parser.add_argument( + "-es", '--enable_streaming', default=False, action='store_true', help="Enables streaming sentences." + ) + parser.add_argument("-dm", "--debug_mode", default=False, action='store_true', help="Enable debug mode") + parser.add_argument( + '-ws', + '--weight_storage', + default='auto', + choices=['auto', 'cache', 'file', 'memory'], + help='Strategy for storing converted weights for vLLM: "file" - always write weights into a file, ' + '"memory" - always do an in-memory conversion, "cache" - reuse existing files if they are ' + 'newer than the nemo checkpoint, "auto" - use "cache" for multi-GPU runs and "memory" ' + 'for single-GPU runs.', + ) + parser.add_argument( + "-gmu", + '--gpu_memory_utilization', + default=0.9, + type=float, + help="GPU memory utilization percentage for vLLM.", + ) + args = parser.parse_args(argv) + return args + + +def get_vllm_deployable(args): + tempdir = None + model_dir = args.triton_model_repository + if model_dir is None: + tempdir = tempfile.TemporaryDirectory() + model_dir = tempdir.name + LOGGER.info( + f"{model_dir} path will be used as the vLLM intermediate folder. " + + "Please set the --triton_model_repository parameter if you'd like to use a path that already " + + "includes the vLLM model files." + ) + elif not os.path.exists(model_dir): + os.makedirs(model_dir) + + try: + exporter = vLLMExporter() + exporter.export( + nemo_checkpoint=args.nemo_checkpoint, + model_dir=model_dir, + model_type=args.model_type, + tensor_parallel_size=args.tensor_parallelism_size, + max_model_len=args.max_model_len, + dtype=args.dtype, + weight_storage=args.weight_storage, + gpu_memory_utilization=args.gpu_memory_utilization, + ) + return exporter + except Exception as error: + raise RuntimeError("An error has occurred during the model export. Error message: " + str(error)) + finally: + if tempdir is not None: + tempdir.cleanup() + + +def nemo_deploy(argv): + args = get_args(argv) + + if args.debug_mode: + loglevel = logging.DEBUG + else: + loglevel = logging.INFO + + LOGGER.setLevel(loglevel) + LOGGER.info("Logging level set to {}".format(loglevel)) + LOGGER.info(args) + + triton_deployable = get_vllm_deployable(args) + + try: + nm = DeployPyTriton( + model=triton_deployable, + triton_model_name=args.triton_model_name, + triton_model_version=args.triton_model_version, + max_batch_size=args.max_batch_size, + port=args.triton_port, + address=args.triton_http_address, + streaming=args.enable_streaming, + ) + + LOGGER.info("Triton deploy function will be called.") + nm.deploy() + except Exception as error: + LOGGER.error("Error message has occurred during deploy function. Error message: " + str(error)) + return + + try: + LOGGER.info("Model serving on Triton is will be started.") + nm.serve() + except Exception as error: + LOGGER.error("Error message has occurred during deploy function. Error message: " + str(error)) + return + + LOGGER.info("Model serving will be stopped.") + nm.stop() + + +if __name__ == '__main__': + nemo_deploy(sys.argv[1:]) diff --git a/scripts/deploy/nlp/query_inframework.py b/scripts/deploy/nlp/query_inframework.py new file mode 100644 index 0000000000000..e77ab72a1f04b --- /dev/null +++ b/scripts/deploy/nlp/query_inframework.py @@ -0,0 +1,83 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import sys + +from nemo.deploy.nlp.query_llm import NemoQueryLLMPyTorch + + +def get_args(argv): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + description=f"Queries Triton server running an in-framework Nemo model", + ) + parser.add_argument("-u", "--url", default="0.0.0.0", type=str, help="url for the triton server") + parser.add_argument("-mn", "--model_name", required=True, type=str, help="Name of the triton model") + prompt_group = parser.add_mutually_exclusive_group(required=True) + prompt_group.add_argument("-p", "--prompt", required=False, type=str, help="Prompt") + prompt_group.add_argument("-pf", "--prompt_file", required=False, type=str, help="File to read the prompt from") + parser.add_argument("-mol", "--max_output_len", default=128, type=int, help="Max output token length") + parser.add_argument("-tk", "--top_k", default=1, type=int, help="top_k") + parser.add_argument("-tpp", "--top_p", default=0.0, type=float, help="top_p") + parser.add_argument("-t", "--temperature", default=1.0, type=float, help="temperature") + parser.add_argument("-it", "--init_timeout", default=60.0, type=float, help="init timeout for the triton server") + + args = parser.parse_args(argv) + return args + + +def query_llm( + url, + model_name, + prompts, + max_output_len=128, + top_k=1, + top_p=0.0, + temperature=1.0, + init_timeout=60.0, +): + nemo_query = NemoQueryLLMPyTorch(url, model_name) + return nemo_query.query_llm( + prompts=prompts, + max_length=max_output_len, + top_k=top_k, + top_p=top_p, + temperature=temperature, + init_timeout=init_timeout, + ) + + +def query(argv): + args = get_args(argv) + + if args.prompt_file is not None: + with open(args.prompt_file, "r") as f: + args.prompt = f.read() + + outputs = query_llm( + url=args.url, + model_name=args.model_name, + prompts=[args.prompt], + max_output_len=args.max_output_len, + top_k=args.top_k, + top_p=args.top_p, + temperature=args.temperature, + init_timeout=args.init_timeout, + ) + print(outputs["sentences"][0][0]) + + +if __name__ == '__main__': + query(sys.argv[1:]) diff --git a/scripts/export/export_to_trt_llm.py b/scripts/export/export_to_trt_llm.py index a0c70c8bbd857..a9b9d92c172b2 100644 --- a/scripts/export/export_to_trt_llm.py +++ b/scripts/export/export_to_trt_llm.py @@ -16,7 +16,7 @@ import logging import sys -from nemo.export import TensorRTLLM +from nemo.export.tensorrt_llm import TensorRTLLM LOGGER = logging.getLogger("NeMo") @@ -40,8 +40,8 @@ def get_args(argv): "-mr", "--model_repository", required=True, default=None, type=str, help="Folder for the trt-llm model files" ) parser.add_argument("-ng", "--num_gpus", default=1, type=int, help="Number of GPUs for the deployment") - parser.add_argument("-tps", "--tensor_parallelism_size", type=int, help="Tensor parallelism size") - parser.add_argument("-pps", "--pipeline_parallelism_size", type=int, help="Pipeline parallelism size") + parser.add_argument("-tps", "--tensor_parallelism_size", default=1, type=int, help="Tensor parallelism size") + parser.add_argument("-pps", "--pipeline_parallelism_size", default=1, type=int, help="Pipeline parallelism size") parser.add_argument( "-dt", "--dtype", @@ -138,8 +138,8 @@ def nemo_export_trt_llm(argv): nemo_checkpoint_path=args.nemo_checkpoint, model_type=args.model_type, n_gpus=args.num_gpus, - tensor_parallel_size=args.tensor_parallelism_size, - pipeline_parallel_size=args.pipeline_parallelism_size, + tensor_parallelism_size=args.tensor_parallelism_size, + pipeline_parallelism_size=args.pipeline_parallelism_size, max_input_len=args.max_input_len, max_output_len=args.max_output_len, max_batch_size=args.max_batch_size, @@ -153,7 +153,6 @@ def nemo_export_trt_llm(argv): use_lora_plugin=args.use_lora_plugin, lora_target_modules=args.lora_target_modules, max_lora_rank=args.max_lora_rank, - save_nemo_model_config=True, ) LOGGER.info("Export is successful.") diff --git a/scripts/nlp_language_modeling/preprocess_data_for_megatron.py b/scripts/nlp_language_modeling/preprocess_data_for_megatron.py index 945b9e7b68a2b..e1f89182279b2 100644 --- a/scripts/nlp_language_modeling/preprocess_data_for_megatron.py +++ b/scripts/nlp_language_modeling/preprocess_data_for_megatron.py @@ -104,6 +104,7 @@ except ImportError: nltk_available = False + # https://stackoverflow.com/questions/33139531/preserve-empty-lines-with-nltks-punkt-tokenizer class CustomLanguageVars(nltk.tokenize.punkt.PunktLanguageVars): @@ -221,10 +222,16 @@ def get_args(): help='What tokenizer library to use.', ) group.add_argument( - '--tokenizer-type', type=str, default=None, help='What type of tokenizer to use.', + '--tokenizer-type', + type=str, + default=None, + help='What type of tokenizer to use.', ) group.add_argument( - '--tokenizer-model', type=str, default=None, help='Path to tokenizer model.', + '--tokenizer-model', + type=str, + default=None, + help='Path to tokenizer model.', ) group.add_argument('--vocab-file', type=str, default=None, help='Path to the vocab file') group.add_argument('--files-filter', type=str, default='**/*.json*', help='files filter str') @@ -248,7 +255,7 @@ def get_args(): group.add_argument( '--preproc-folder', action='store_true', - help='If set, will preprocess all .json or .json.gz files into a single .bin and .idx file. Folder path provided via the --input arg', + help='If set, will preprocess all .json or .jsonl or json.gz or .jsonl.gz files into a single .bin and .idx file. Folder path provided via the --input arg', ) group.add_argument('--apply-ftfy', action='store_true', help='If set, will apply ftfy to the input text') args = parser.parse_args() @@ -272,14 +279,18 @@ def main(): args = get_args() startup_start = time.time() if args.preproc_folder: - print('Searching folder for .json or .json.gz files...') + print('Searching folder for .json or .jsonl or json.gz or .jsonl.gz files...') assert os.path.exists(args.input), f'Folder does not exist: {args.input}' json_files = (str(f) for f in pathlib.Path(args.input).glob(args.files_filter)) - json_files = [f for f in json_files if f.endswith('.json') or f.endswith('.json.gz')] + json_files = [ + f + for f in json_files + if f.endswith('.json') or f.endswith('.jsonl') or f.endswith('.json.gz') or f.endswith('.jsonl.gz') + ] if len(json_files) == 0: - raise FileNotFoundError('No .json or .json.gz files found in folder.') + raise FileNotFoundError('No .json or .jsonl or json.gz or .jsonl.gz files found in folder.') else: - print(f'Found {len(json_files)} .json or .json.gz files.') + print(f'Found {len(json_files)} .json or .jsonl or json.gz or .jsonl.gz files.') else: assert os.path.exists(args.input), f'File does not exist: {args.input}' json_files = [args.input] diff --git a/setup.py b/setup.py index 180e5ab4f0833..292be13e65df8 100644 --- a/setup.py +++ b/setup.py @@ -90,6 +90,7 @@ def req_file(filename, folder="requirements"): 'tts': req_file("requirements_tts.txt"), 'slu': req_file("requirements_slu.txt"), 'multimodal': req_file("requirements_multimodal.txt"), + 'audio': req_file("requirements_audio.txt"), } @@ -135,6 +136,7 @@ def req_file(filename, folder="requirements"): ] ) ) +extras_require['audio'] = list(chain([extras_require['audio'], extras_require['core'], extras_require['common']])) # TTS has extra dependencies extras_require['tts'] = list(chain([extras_require['tts'], extras_require['asr']])) @@ -284,4 +286,9 @@ def finalize_options(self): keywords=__keywords__, # Custom commands. cmdclass={'style': StyleCommand}, + entry_points={ + "sdk.factories": [ + "llm = nemo.collections.llm", + ], + }, ) diff --git a/tests/collections/asr/mixins/adapters/test_asr_adapter_mixin.py b/tests/collections/asr/mixins/adapters/test_asr_adapter_mixin.py index c520bd4c12925..cac1eb2fcdf38 100644 --- a/tests/collections/asr/mixins/adapters/test_asr_adapter_mixin.py +++ b/tests/collections/asr/mixins/adapters/test_asr_adapter_mixin.py @@ -12,12 +12,17 @@ # See the License for the specific language governing permissions and # limitations under the License. +import os + import pytest import torch from omegaconf import DictConfig, ListConfig, OmegaConf -from nemo.collections.asr.models import ASRModel, EncDecCTCModel, EncDecRNNTModel -from nemo.collections.asr.parts.submodules.adapters import multi_head_attention_adapter_module +from nemo.collections.asr.models import ASRModel, EncDecCTCModel, EncDecMultiTaskModel, EncDecRNNTModel +from nemo.collections.asr.parts.submodules.adapters import ( + multi_head_attention_adapter_module, + transformer_multi_head_attention_adapter_module, +) from nemo.collections.asr.parts.utils import adapter_utils from nemo.collections.common.parts import adapter_modules from nemo.core.classes.mixins.access_mixins import AccessMixin @@ -286,8 +291,130 @@ def rnnt_model(): return model_instance +@pytest.fixture() +def multitask_model(test_data_dir): + preprocessor = {'cls': 'nemo.collections.asr.modules.AudioToMelSpectrogramPreprocessor', 'params': dict({})} + + # fmt: off + tokenizer = { + 'dir': None, + 'type': 'agg', + 'langs': { + 'spl_tokens': { + 'dir': os.path.join(test_data_dir, 'asr', 'tokenizers', 'canary'), + 'type': 'bpe', + }, + 'en': { + 'dir': os.path.join(test_data_dir, 'asr', 'tokenizers', 'an4_spe_128'), + 'type': 'bpe', + } + }, + 'custom_tokenizer': { + '_target_': 'nemo.collections.common.tokenizers.canary_tokenizer.CanaryTokenizer', + 'tokenizers': None, + } + } + # fmt: on + + model_defaults = {"asr_enc_hidden": 128, "lm_enc_hidden": 128, "lm_dec_hidden": 128} + + # Test case where Encoder (default) is not adapter compatible + encoder = { + '_target_': 'nemo.collections.asr.modules.ConformerEncoder', + 'feat_in': 64, + 'feat_out': -1, + 'n_layers': 2, + 'd_model': 128, + 'subsampling': 'striding', + 'subsampling_factor': 4, + 'self_attention_model': 'rel_pos', + 'n_heads': 4, + 'conv_kernel_size': 31, + } + + transf_encoder = { + "_target_": "nemo.collections.asr.modules.transformer.transformer_encoders.TransformerEncoder", + "num_layers": 1, + "hidden_size": "${model_defaults.lm_enc_hidden}", + "inner_size": int(4 * model_defaults['lm_enc_hidden']), + "num_attention_heads": 8, + "ffn_dropout": 0.1, + "attn_score_dropout": 0.1, + "attn_layer_dropout": 0.1, + "mask_future": False, + "pre_ln": True, + "pre_ln_final_layer_norm": True, + } + + transf_decoder = { + "_target_": "nemo.collections.asr.modules.transformer.get_nemo_transformer", + "model_name": None, + "pretrained": False, + "encoder": None, + "pre_ln_final_layer_norm": True, + "config_dict": { + "max_sequence_length": 512, + "num_token_types": 0, + "embedding_dropout": 0.1, + "learn_positional_encodings": False, + "hidden_size": "${model_defaults.lm_dec_hidden}", + "inner_size": "${multiply:${model_defaults.lm_dec_hidden}, 4}", + "num_layers": 2, + "num_attention_heads": 8, + "ffn_dropout": 0.1, + "attn_score_dropout": 0.1, + "attn_layer_dropout": 0.1, + "hidden_act": "relu", + "pre_ln": True, + "vocab_size": None, # Will be set by the model at runtime + "adapter": True, # Add support for adapter class + }, + } + + head = { + "_target_": "nemo.collections.asr.parts.submodules.token_classifier.TokenClassifier", + "num_layers": 1, + "activation": "relu", + "log_softmax": True, + "hidden_size": "${transf_decoder.config_dict.hidden_size}", + "num_classes": None, # Will be set by the model at runtime + "dropout": 0.0, + "use_transformer_init": True, + } + + decoding = {'strategy': 'beam', 'beam': {'beam_size': 1, 'len_pen': 0.0, 'max_generation_delta': 50}} + + loss = { + "_target_": "nemo.collections.common.losses.smoothed_cross_entropy.SmoothedCrossEntropyLoss", + "label_smoothing": 0.0, + "pad_id": None, + } + + modelConfig = DictConfig( + { + 'sample_rate': 16000, + 'prompt_format': 'canary', + 'preprocessor': DictConfig(preprocessor), + 'model_defaults': DictConfig(model_defaults), + 'tokenizer': DictConfig(tokenizer), + 'encoder': DictConfig(encoder), + 'transf_encoder': DictConfig(transf_encoder), + 'transf_decoder': DictConfig(transf_decoder), + 'head': DictConfig(head), + 'decoding': DictConfig(decoding), + 'loss': DictConfig(loss), + } + ) + + model_instance = EncDecMultiTaskModel(cfg=modelConfig) + + # Execute the model class swap logic + model_instance.replace_adapter_compatible_modules() + return model_instance + + def get_adapter_cfg(in_features=50, dim=100, norm_pos='pre', atype='linear', **kwargs): - valid_types = ['linear', 'mha', 'relmha'] + valid_types = ['linear', 'mha', 'relmha', 'transf_mha'] if atype not in valid_types: raise ValueError(f"Invalid type. Valid types = {atype}") @@ -295,7 +422,15 @@ def get_adapter_cfg(in_features=50, dim=100, norm_pos='pre', atype='linear', **k cfg = adapter_modules.LinearAdapterConfig(in_features=in_features, dim=dim, norm_position=norm_pos) elif atype == 'mha': cfg = multi_head_attention_adapter_module.MultiHeadAttentionAdapterConfig( - n_head=kwargs.get('n_head', 1), n_feat=in_features + n_head=kwargs.get('n_head', 1), + n_feat=in_features, + proj_dim=kwargs.get('proj_dim', None), + ) + elif atype == 'transf_mha': + cfg = transformer_multi_head_attention_adapter_module.TransformerMultiHeadAttentionAdapterConfig( + num_attention_heads=kwargs.get('n_head', 1), + hidden_size=in_features, + proj_dim=kwargs.get('proj_dim', None), ) elif atype == 'relmha': cfg = multi_head_attention_adapter_module.RelPositionMultiHeadAttentionAdapterConfig( @@ -375,12 +510,14 @@ def test_asr_model_constructor_joint_module_ctc_skip(self, model): original_num_params = model.num_weights # this step should exit without adding adapters and without errors - model.add_adapter(name='joint:adapter_0', cfg=get_adapter_cfg()) + with pytest.raises(ValueError): + model.add_adapter(name='joint:adapter_0', cfg=get_adapter_cfg()) new_num_params = model.num_weights assert new_num_params == original_num_params @pytest.mark.skipif( - not NUMBA_RNNT_LOSS_AVAILABLE, reason='RNNTLoss has not been compiled with appropriate numba version.', + not NUMBA_RNNT_LOSS_AVAILABLE, + reason='RNNTLoss has not been compiled with appropriate numba version.', ) @pytest.mark.unit def test_asr_model_constructor_joint_module_rnnt(self, rnnt_model): @@ -467,6 +604,74 @@ def test_squeezeformer_forward_mha(self, squeezeformer_ctc_adapter, name): assert torch.mean(torch.abs(origial_output - new_output)) < 1e-5 + @pytest.mark.unit + @pytest.mark.parametrize('adapter_type', ['linear', 'attn']) + @pytest.mark.parametrize( + 'name', ['adapter_0', 'encoder:adapter_0', 'transf_encoder:adapter_0', 'transf_decoder:adapter_0'] + ) + def test_canary_forward_mha(self, multitask_model, name, adapter_type): + multitask_model.eval() + torch.random.manual_seed(0) + input_signal = torch.randn(2, 512) + input_signal_length = torch.tensor([512, 512], dtype=torch.int32) + transcript = torch.randint(0, multitask_model.tokenizer.vocab_size, size=(2, 10)) + transcript_len = torch.tensor([10, 9], dtype=torch.int32) + + origial_output = multitask_model( + input_signal=input_signal, + input_signal_length=input_signal_length, + transcript=transcript, + transcript_length=transcript_len, + ) + og_logprob = origial_output[0] + og_enc_out = origial_output[2] + + if adapter_type == 'attn': + adapter_type = 'transf_mha' if 'transf' in name else 'mha' + + multitask_model.add_adapter(name=name, cfg=get_adapter_cfg(in_features=128, atype=adapter_type, proj_dim=4)) + + new_output = multitask_model( + input_signal=input_signal, + input_signal_length=input_signal_length, + transcript=transcript, + transcript_length=transcript_len, + ) + + new_logprob = new_output[0] + new_enc_out = new_output[2] + + assert torch.mean(torch.abs(og_logprob - new_logprob)) < 1e-5 + assert torch.mean(torch.abs(og_enc_out - new_enc_out)) < 1e-5 + + if 'linear' in adapter_type: + mod_name = name.split(":")[-1] + for mod in multitask_model.modules(): + if isinstance(mod, AdapterModuleMixin): + amodule = mod.get_adapter_module(mod_name) + if amodule is not None: + assert isinstance(amodule, adapter_modules.LinearAdapter) + + # Try to use incorrect adapter + with pytest.raises(ValueError): + multitask_model.add_adapter( + name="transf_encoder:adapter_1", cfg=get_adapter_cfg(in_features=128, atype='mha') + ) + + @pytest.mark.unit + @pytest.mark.parametrize('name', ['transf_decoder:adapter_0']) + def test_canary_forward_mha_decoder_fails_without_support(self, multitask_model, name): + multitask_model.eval() + torch.random.manual_seed(0) + + # Change internal class of transf_decoder module + adapter_class = multitask_model.transf_decoder.__class__ + multitask_model.transf_decoder.__class__ = get_registered_adapter(adapter_class).base_class + + with pytest.raises(AttributeError): + adapter_type = 'transf_mha' if 'transf' in name else 'mha' + multitask_model.add_adapter(name=name, cfg=get_adapter_cfg(in_features=128, atype=adapter_type)) + @pytest.mark.unit @pytest.mark.parametrize('name1', ['adapter_0', 'encoder:adapter_0', 'decoder:adapter_0']) @pytest.mark.parametrize('name2', ['adapter_1', 'encoder:adapter_1', 'decoder:adapter_1']) @@ -488,7 +693,8 @@ def test_asr_multi_adapter_forward(self, model, name1, name2): assert torch.mean(torch.abs(origial_output - new_output)) < 1e-5 @pytest.mark.skipif( - not NUMBA_RNNT_LOSS_AVAILABLE, reason='RNNTLoss has not been compiled with appropriate numba version.', + not NUMBA_RNNT_LOSS_AVAILABLE, + reason='RNNTLoss has not been compiled with appropriate numba version.', ) @pytest.mark.parametrize('name1', ['decoder:adapter_0', 'joint:adapter_0']) @pytest.mark.parametrize('name2', ['decoder:adapter_1', 'joint:adapter_1']) @@ -582,7 +788,8 @@ def test_constructor_pretrained(self): assert model.num_weights < 1e5 @pytest.mark.skipif( - not NUMBA_RNNT_LOSS_AVAILABLE, reason='RNNTLoss has not been compiled with appropriate numba version.', + not NUMBA_RNNT_LOSS_AVAILABLE, + reason='RNNTLoss has not been compiled with appropriate numba version.', ) @pytest.mark.with_downloads() @pytest.mark.unit diff --git a/tests/collections/asr/mixins/adapters/test_asr_adapter_modules.py b/tests/collections/asr/mixins/adapters/test_asr_adapter_modules.py index c4ee4b97a2a6a..ffaf1e640f3e8 100644 --- a/tests/collections/asr/mixins/adapters/test_asr_adapter_modules.py +++ b/tests/collections/asr/mixins/adapters/test_asr_adapter_modules.py @@ -111,6 +111,22 @@ def test_rel_pos_encoding_adapter_config(self): assert cls_subset is None assert dataclass_subset is None + @pytest.mark.unit + def test_transformer_mha_adapter_config(self): + IGNORED_ARGS = ['_target_'] + + result = config_utils.assert_dataclass_signature_match( + adapter_modules.TransformerMultiHeadAttentionAdapter, + adapter_modules.TransformerMultiHeadAttentionAdapterConfig, + ignore_args=IGNORED_ARGS, + ) + + signatures_match, cls_subset, dataclass_subset = result + + assert signatures_match + assert cls_subset is None + assert dataclass_subset is None + @pytest.mark.unit @pytest.mark.parametrize('n_head', [1, 2, 10]) @pytest.mark.parametrize('proj_dim', [None, -1]) @@ -194,6 +210,31 @@ def test_relpos_encoding_init(self): assert (out - x).sum().abs() <= 1e-8 assert out.shape == x.shape + @pytest.mark.unit + @pytest.mark.parametrize('n_head', [1, 2, 10]) + @pytest.mark.parametrize('proj_dim', [None, -1]) + def test_transformer_mha_adapter_init(self, n_head, proj_dim): + torch.random.manual_seed(0) + x = torch.randn(2, 32, 50) + lengths = torch.randint(1, x.size(1), size=(x.size(0),)) + lengths[torch.randint(0, x.size(0), size=(1,))[0]] = x.size(1) + + adapter = adapter_modules.TransformerMultiHeadAttentionAdapter( + num_attention_heads=n_head, hidden_size=50, attn_layer_dropout=0.0, proj_dim=proj_dim + ) + + pad_mask, att_mask = get_mask(lengths) + att_mask = att_mask.unsqueeze(1) + + with torch.no_grad(): + assert adapter.out_projection.weight.sum() == 0 + if hasattr(adapter.out_projection, 'bias') and adapter.out_projection.bias is not None: + assert adapter.out_projection.bias.sum() == 0 + + out = adapter(x, x, x, att_mask) + assert out.sum().abs() <= 1e-8 + assert out.shape == x.shape + @pytest.mark.unit def test_mha_adapter_strategy(self): adapter = adapter_modules.MultiHeadAttentionAdapter(n_head=1, n_feat=50, dropout_rate=0.0) @@ -225,3 +266,13 @@ def test_relpos_encoding_adapter_strategy(self): assert adapter.adapter_strategy is not None # assert default strategy is set assert isinstance(adapter.adapter_strategy, adapter_mixin_strategies.ReturnResultAdapterStrategy) + + @pytest.mark.unit + def test_transformer_mha_adapter_strategy(self): + adapter = adapter_modules.TransformerMultiHeadAttentionAdapter( + num_attention_heads=1, hidden_size=50, attn_layer_dropout=0.0 + ) + assert hasattr(adapter, 'adapter_strategy') + assert adapter.adapter_strategy is not None + # assert default strategy is set + assert isinstance(adapter.adapter_strategy, adapter_modules.MHAResidualAddAdapterStrategy) diff --git a/tests/collections/asr/test_asr_datasets.py b/tests/collections/asr/test_asr_datasets.py index a2e39628e4cb0..d5c5be8b44ad8 100644 --- a/tests/collections/asr/test_asr_datasets.py +++ b/tests/collections/asr/test_asr_datasets.py @@ -26,15 +26,7 @@ from omegaconf import DictConfig, OmegaConf from torch.utils.data import DataLoader -from nemo.collections.asr.data import audio_to_audio_dataset, audio_to_text_dataset -from nemo.collections.asr.data.audio_to_audio import ( - ASRAudioProcessor, - AudioToTargetDataset, - AudioToTargetWithEmbeddingDataset, - AudioToTargetWithReferenceDataset, - _audio_collate_fn, -) -from nemo.collections.asr.data.audio_to_audio_lhotse import LhotseAudioToTargetDataset, convert_manifest_nemo_to_lhotse +from nemo.collections.asr.data import audio_to_text_dataset from nemo.collections.asr.data.audio_to_text import ( DataStoreObject, TarredAudioToBPEDataset, @@ -50,7 +42,6 @@ from nemo.collections.asr.data.audio_to_text_dataset import inject_dataloader_value_from_model_config from nemo.collections.asr.data.feature_to_text import FeatureToBPEDataset, FeatureToCharDataset from nemo.collections.asr.models.ctc_models import EncDecCTCModel -from nemo.collections.asr.parts.utils.audio_utils import get_segment_start from nemo.collections.asr.parts.utils.manifest_utils import write_manifest from nemo.collections.common import tokenizers from nemo.collections.common.data.lhotse import get_lhotse_dataloader_from_config @@ -141,7 +132,7 @@ def test_tarred_dataset(self, test_data_dir): @pytest.mark.unit def test_tarred_dataset_filter(self, test_data_dir): """ - Checks for + Checks for 1. file count when manifest len is less than tarred dataset 2. Ignoring files in manifest that are not in tarred balls @@ -431,7 +422,9 @@ def test_dali_char_vs_ref_dataset(self, test_data_dir): world_size=1, preprocessor_cfg=preprocessor_cfg, ) - ref_dataset = audio_to_text_dataset.get_char_dataset(config=dataset_cfg,) + ref_dataset = audio_to_text_dataset.get_char_dataset( + config=dataset_cfg, + ) ref_dataloader = DataLoader( dataset=ref_dataset, batch_size=batch_size, @@ -785,1134 +778,11 @@ def test_feature_with_rttm_to_text_bpe_dataset(self, test_data_dir): assert cnt == num_samples -class TestAudioDatasets: - @pytest.mark.unit - @pytest.mark.parametrize('num_channels', [1, 2]) - @pytest.mark.parametrize('num_targets', [1, 3]) - def test_list_to_multichannel(self, num_channels, num_targets): - """Test conversion of a list of arrays into - """ - random_seed = 42 - num_samples = 1000 - - # Generate random signals - _rng = np.random.default_rng(seed=random_seed) - - # Multi-channel signal - golden_target = _rng.normal(size=(num_channels * num_targets, num_samples)) - - # Create a list of num_targets signals with num_channels channels - target_list = [golden_target[n * num_channels : (n + 1) * num_channels, :] for n in range(num_targets)] - - # Check the original signal is not modified - assert (ASRAudioProcessor.list_to_multichannel(golden_target) == golden_target).all() - # Check the list is converted back to the original signal - assert (ASRAudioProcessor.list_to_multichannel(target_list) == golden_target).all() - - @pytest.mark.unit - @pytest.mark.parametrize('num_channels', [1, 2]) - def test_processor_process_audio(self, num_channels): - """Test signal normalization in process_audio. - """ - num_samples = 1000 - num_examples = 30 - - signals = ['input_signal', 'target_signal', 'reference_signal'] - - for normalization_signal in [None] + signals: - # Create processor - processor = ASRAudioProcessor( - sample_rate=16000, random_offset=False, normalization_signal=normalization_signal - ) - - # Generate random signals - for n in range(num_examples): - example = {signal: torch.randn(num_channels, num_samples) for signal in signals} - processed_example = processor.process_audio(example) - - # Expected scale - if normalization_signal: - scale = 1.0 / (example[normalization_signal].abs().max() + processor.eps) - else: - scale = 1.0 - - # Make sure all signals are scaled as expected - for signal in signals: - assert torch.allclose( - processed_example[signal], example[signal] * scale - ), f'Failed example {n} signal {signal}' - - @pytest.mark.unit - def test_audio_collate_fn(self): - """Test `_audio_collate_fn` - """ - batch_size = 16 - random_seed = 42 - atol = 1e-5 - - # Generate random signals - _rng = np.random.default_rng(seed=random_seed) - - signal_to_channels = { - 'input_signal': 2, - 'target_signal': 1, - 'reference_signal': 1, - } - - signal_to_length = { - 'input_signal': _rng.integers(low=5, high=25, size=batch_size), - 'target_signal': _rng.integers(low=5, high=25, size=batch_size), - 'reference_signal': _rng.integers(low=5, high=25, size=batch_size), - } - - # Generate batch - batch = [] - for n in range(batch_size): - item = dict() - for signal, num_channels in signal_to_channels.items(): - random_signal = _rng.normal(size=(num_channels, signal_to_length[signal][n])) - random_signal = np.squeeze(random_signal) # get rid of channel dimention for single-channel - item[signal] = torch.tensor(random_signal) - batch.append(item) - - # Run UUT - batched = _audio_collate_fn(batch) - - batched_signals = { - 'input_signal': batched[0].cpu().detach().numpy(), - 'target_signal': batched[2].cpu().detach().numpy(), - 'reference_signal': batched[4].cpu().detach().numpy(), - } - - batched_lengths = { - 'input_signal': batched[1].cpu().detach().numpy(), - 'target_signal': batched[3].cpu().detach().numpy(), - 'reference_signal': batched[5].cpu().detach().numpy(), - } - - # Check outputs - for signal, b_signal in batched_signals.items(): - for n in range(batch_size): - # Check length - uut_length = batched_lengths[signal][n] - golden_length = signal_to_length[signal][n] - assert ( - uut_length == golden_length - ), f'Example {n} signal {signal} length mismatch: batched ({uut_length}) != golden ({golden_length})' - - uut_signal = b_signal[n][:uut_length, ...] - golden_signal = batch[n][signal][:uut_length, ...].cpu().detach().numpy() - assert np.allclose( - uut_signal, golden_signal, atol=atol - ), f'Example {n} signal {signal} value mismatch.' - - @pytest.mark.unit - def test_audio_to_target_dataset(self): - """Test AudioWithTargetDataset in different configurations. - - Test below cover the following: - 1) no constraints - 2) filtering based on signal duration - 3) use with channel selector - 4) use with fixed audio duration and random subsegments - 5) collate a batch of items - - In this use case, each line of the manifest file has the following format: - ``` - { - 'input_filepath': 'path/to/input.wav', - 'target_filepath': 'path/to/path_to_target.wav', - 'duration': duration_of_input, - } - ``` - """ - # Data setup - random_seed = 42 - sample_rate = 16000 - num_examples = 25 - data_num_channels = { - 'input_signal': 4, - 'target_signal': 2, - } - data_min_duration = 2.0 - data_max_duration = 8.0 - data_key = { - 'input_signal': 'input_filepath', - 'target_signal': 'target_filepath', - } - - # Tolerance - atol = 1e-6 - - # Generate random signals - _rng = np.random.default_rng(seed=random_seed) - - # Input and target signals have the same duration - data_duration = np.round(_rng.uniform(low=data_min_duration, high=data_max_duration, size=num_examples), 3) - data_duration_samples = np.floor(data_duration * sample_rate).astype(int) - - data = dict() - for signal, num_channels in data_num_channels.items(): - data[signal] = [] - for n in range(num_examples): - if num_channels == 1: - random_signal = _rng.uniform(low=-0.5, high=0.5, size=(data_duration_samples[n])) - else: - random_signal = _rng.uniform(low=-0.5, high=0.5, size=(num_channels, data_duration_samples[n])) - data[signal].append(random_signal) - - with tempfile.TemporaryDirectory() as test_dir: - - # Build metadata for manifest - metadata = [] - - for n in range(num_examples): - - meta = dict() - - for signal in data: - # filenames - signal_filename = f'{signal}_{n:02d}.wav' - - # write audio files - sf.write(os.path.join(test_dir, signal_filename), data[signal][n].T, sample_rate, 'float') - - # update metadata - meta[data_key[signal]] = signal_filename - - meta['duration'] = data_duration[n] - metadata.append(meta) - - # Save manifest - manifest_filepath = os.path.join(test_dir, 'manifest.json') - write_manifest(manifest_filepath, metadata) - - # Test 1 - # - No constraints on channels or duration - dataset = AudioToTargetDataset( - manifest_filepath=manifest_filepath, - input_key=data_key['input_signal'], - target_key=data_key['target_signal'], - sample_rate=sample_rate, - ) - - # Also test the corresponding factory - config = { - 'manifest_filepath': manifest_filepath, - 'input_key': data_key['input_signal'], - 'target_key': data_key['target_signal'], - 'sample_rate': sample_rate, - } - dataset_factory = audio_to_audio_dataset.get_audio_to_target_dataset(config) - - # Prepare lhotse manifest - cuts_path = manifest_filepath.replace('.json', '_cuts.jsonl') - convert_manifest_nemo_to_lhotse( - input_manifest=manifest_filepath, - output_manifest=cuts_path, - input_key=data_key['input_signal'], - target_key=data_key['target_signal'], - ) - - # Prepare lhotse dataset - config_lhotse = { - 'cuts_path': cuts_path, - 'use_lhotse': True, - 'sample_rate': sample_rate, - 'batch_size': 1, - } - dl_lhotse = get_lhotse_dataloader_from_config( - OmegaConf.create(config_lhotse), global_rank=0, world_size=1, dataset=LhotseAudioToTargetDataset() - ) - dataset_lhotse = [item for item in dl_lhotse] - - # Test number of channels - for signal in data: - assert data_num_channels[signal] == dataset.num_channels( - signal - ), f'Num channels not correct for signal {signal}' - assert data_num_channels[signal] == dataset_factory.num_channels( - signal - ), f'Num channels not correct for signal {signal}' - - # Test returned examples - for n in range(num_examples): - for signal in data: - golden_signal = data[signal][n] - - for use_lhotse in [False, True]: - item_signal = ( - dataset_lhotse[n][signal].squeeze(0) if use_lhotse else dataset.__getitem__(n)[signal] - ) - item_factory_signal = dataset_factory.__getitem__(n)[signal] - - assert ( - item_signal.shape == golden_signal.shape - ), f'Test 1, use_lhotse={use_lhotse}: Signal {signal} item shape {item_signal.shape} not matching reference shape {golden_signal.shape}' - assert np.allclose( - item_signal, golden_signal, atol=atol - ), f'Test 1, use_lhotse={use_lhotse}: Failed for example {n}, signal {signal} (random seed {random_seed})' - - assert np.allclose( - item_factory_signal, golden_signal, atol=atol - ), f'Test 1, use_lhotse={use_lhotse}: Failed for factory example {n}, signal {signal} (random seed {random_seed})' - - # Test 2 - # - Filtering based on signal duration - min_duration = 3.5 - max_duration = 7.5 - - dataset = AudioToTargetDataset( - manifest_filepath=manifest_filepath, - input_key=data_key['input_signal'], - target_key=data_key['target_signal'], - min_duration=min_duration, - max_duration=max_duration, - sample_rate=sample_rate, - ) - - # Prepare lhotse dataset - config_lhotse = { - 'cuts_path': cuts_path, - 'use_lhotse': True, - 'min_duration': min_duration, - 'max_duration': max_duration, - 'sample_rate': sample_rate, - 'batch_size': 1, - } - dl_lhotse = get_lhotse_dataloader_from_config( - OmegaConf.create(config_lhotse), global_rank=0, world_size=1, dataset=LhotseAudioToTargetDataset() - ) - dataset_lhotse = [item for item in dl_lhotse] - - filtered_examples = [n for n, val in enumerate(data_duration) if min_duration <= val <= max_duration] - - for n in range(len(dataset)): - for use_lhotse in [False, True]: - for signal in data: - item_signal = ( - dataset_lhotse[n][signal].squeeze(0) if use_lhotse else dataset.__getitem__(n)[signal] - ) - golden_signal = data[signal][filtered_examples[n]] - assert ( - item_signal.shape == golden_signal.shape - ), f'Test 2, use_lhotse={use_lhotse}: Signal {signal} item shape {item_signal.shape} not matching reference shape {golden_signal.shape}' - - assert np.allclose( - item_signal, golden_signal, atol=atol - ), f'Test 2, use_lhotse={use_lhotse}: Failed for example {n}, signal {signal} (random seed {random_seed})' - - # Test 3 - # - Use channel selector - channel_selector = { - 'input_signal': [0, 2], - 'target_signal': 1, - } - - dataset = AudioToTargetDataset( - manifest_filepath=manifest_filepath, - input_key=data_key['input_signal'], - target_key=data_key['target_signal'], - input_channel_selector=channel_selector['input_signal'], - target_channel_selector=channel_selector['target_signal'], - sample_rate=sample_rate, - ) - - for n in range(len(dataset)): - item = dataset.__getitem__(n) - - for signal in data: - cs = channel_selector[signal] - item_signal = item[signal].cpu().detach().numpy() - golden_signal = data[signal][n][cs, ...] - assert ( - item_signal.shape == golden_signal.shape - ), f'Signal {signal}: item shape {item_signal.shape} not matching reference shape {golden_signal.shape}' - assert np.allclose( - item_signal, golden_signal, atol=atol - ), f'Test 3: Failed for example {n}, signal {signal} (random seed {random_seed})' - - # Test 4 - # - Use fixed duration (random segment selection) - audio_duration = 4.0 - audio_duration_samples = int(np.floor(audio_duration * sample_rate)) - - filtered_examples = [n for n, val in enumerate(data_duration) if val >= audio_duration] - - for random_offset in [True, False]: - # Test subsegments with the default fixed offset and a random offset - - dataset = AudioToTargetDataset( - manifest_filepath=manifest_filepath, - input_key=data_key['input_signal'], - target_key=data_key['target_signal'], - sample_rate=sample_rate, - min_duration=audio_duration, - audio_duration=audio_duration, - random_offset=random_offset, # random offset when selecting subsegment - ) - - # Prepare lhotse dataset - config_lhotse = { - 'cuts_path': cuts_path, - 'use_lhotse': True, - 'min_duration': audio_duration, - 'truncate_duration': audio_duration, - 'truncate_offset_type': 'random' if random_offset else 'start', - 'sample_rate': sample_rate, - 'batch_size': 1, - } - dl_lhotse = get_lhotse_dataloader_from_config( - OmegaConf.create(config_lhotse), global_rank=0, world_size=1, dataset=LhotseAudioToTargetDataset() - ) - dataset_lhotse = [item for item in dl_lhotse] - - for n in range(len(dataset)): - for use_lhotse in [False, True]: - item = dataset_lhotse[n] if use_lhotse else dataset.__getitem__(n) - golden_start = golden_end = None - for signal in data: - item_signal = item[signal].squeeze(0) if use_lhotse else item[signal] - full_golden_signal = data[signal][filtered_examples[n]] - - # Find random segment using correlation on the first channel - # of the first signal, and then use it fixed for other signals - if golden_start is None: - golden_start = get_segment_start( - signal=full_golden_signal[0, :], segment=item_signal[0, :] - ) - if not random_offset: - assert ( - golden_start == 0 - ), f'Test 4, use_lhotse={use_lhotse}: Expecting the signal to start at 0 when random_offset is False' - - golden_end = golden_start + audio_duration_samples - golden_signal = full_golden_signal[..., golden_start:golden_end] - - # Test length is correct - assert ( - item_signal.shape[-1] == audio_duration_samples - ), f'Test 4, use_lhotse={use_lhotse}: Signal length ({item_signal.shape[-1]}) not matching the expected length ({audio_duration_samples})' - - assert ( - item_signal.shape == golden_signal.shape - ), f'Test 4, use_lhotse={use_lhotse}: Signal {signal} item shape {item_signal.shape} not matching reference shape {golden_signal.shape}' - # Test signal values - assert np.allclose( - item_signal, golden_signal, atol=atol - ), f'Test 4, use_lhotse={use_lhotse}: Failed for example {n}, signal {signal} (random seed {random_seed})' - - # Test 5: - # - Test collate_fn - batch_size = 16 - - for use_lhotse in [False, True]: - if use_lhotse: - # Get batch from lhotse dataloader - config_lhotse['batch_size'] = batch_size - dl_lhotse = get_lhotse_dataloader_from_config( - OmegaConf.create(config_lhotse), - global_rank=0, - world_size=1, - dataset=LhotseAudioToTargetDataset(), - ) - batched = next(iter(dl_lhotse)) - else: - # Get examples from dataset and collate into a batch - batch = [dataset.__getitem__(n) for n in range(batch_size)] - batched = dataset.collate_fn(batch) - - # Test all shapes and lengths - for n, signal in enumerate(data.keys()): - length = signal.replace('_signal', '_length') - - if isinstance(batched, dict): - signal_shape = batched[signal].shape - signal_len = batched[length] - else: - signal_shape = batched[2 * n].shape - signal_len = batched[2 * n + 1] - - assert signal_shape == ( - batch_size, - data_num_channels[signal], - audio_duration_samples, - ), f'Test 5, use_lhotse={use_lhotse}: Unexpected signal {signal} shape {signal_shape}' - assert ( - len(signal_len) == batch_size - ), f'Test 5, use_lhotse={use_lhotse}: Unexpected length of signal_len ({len(signal_len)})' - assert all( - signal_len == audio_duration_samples - ), f'Test 5, use_lhotse={use_lhotse}: Unexpected signal_len {signal_len}' - - @pytest.mark.unit - def test_audio_to_target_dataset_with_target_list(self): - """Test AudioWithTargetDataset when the input manifest has a list - of audio files in the target key. - - In this use case, each line of the manifest file has the following format: - ``` - { - 'input_filepath': 'path/to/input.wav', - 'target_filepath': ['path/to/path_to_target_ch0.wav', 'path/to/path_to_target_ch1.wav'], - 'duration': duration_of_input, - } - ``` - """ - # Data setup - random_seed = 42 - sample_rate = 16000 - num_examples = 25 - data_num_channels = { - 'input_signal': 4, - 'target_signal': 2, - } - data_min_duration = 2.0 - data_max_duration = 8.0 - data_key = { - 'input_signal': 'input_filepath', - 'target_signal': 'target_filepath', - } - - # Tolerance - atol = 1e-6 - - # Generate random signals - _rng = np.random.default_rng(seed=random_seed) - - # Input and target signals have the same duration - data_duration = np.round(_rng.uniform(low=data_min_duration, high=data_max_duration, size=num_examples), 3) - data_duration_samples = np.floor(data_duration * sample_rate).astype(int) - - data = dict() - for signal, num_channels in data_num_channels.items(): - data[signal] = [] - for n in range(num_examples): - if num_channels == 1: - random_signal = _rng.uniform(low=-0.5, high=0.5, size=(data_duration_samples[n])) - else: - random_signal = _rng.uniform(low=-0.5, high=0.5, size=(num_channels, data_duration_samples[n])) - data[signal].append(random_signal) - - with tempfile.TemporaryDirectory() as test_dir: - - # Build metadata for manifest - metadata = [] - - for n in range(num_examples): - - meta = dict() - - for signal in data: - if signal == 'target_signal': - # Save targets as individual files - signal_filename = [] - for ch in range(data_num_channels[signal]): - # add current filename - signal_filename.append(f'{signal}_{n:02d}_ch_{ch}.wav') - # write audio file - sf.write( - os.path.join(test_dir, signal_filename[-1]), - data[signal][n][ch, :], - sample_rate, - 'float', - ) - else: - # single file - signal_filename = f'{signal}_{n:02d}.wav' - - # write audio files - sf.write(os.path.join(test_dir, signal_filename), data[signal][n].T, sample_rate, 'float') - - # update metadata - meta[data_key[signal]] = signal_filename - - meta['duration'] = data_duration[n] - metadata.append(meta) - - # Save manifest - manifest_filepath = os.path.join(test_dir, 'manifest.json') - write_manifest(manifest_filepath, metadata) - - # Test 1 - # - No constraints on channels or duration - dataset = AudioToTargetDataset( - manifest_filepath=manifest_filepath, - input_key=data_key['input_signal'], - target_key=data_key['target_signal'], - sample_rate=sample_rate, - ) - - config = { - 'manifest_filepath': manifest_filepath, - 'input_key': data_key['input_signal'], - 'target_key': data_key['target_signal'], - 'sample_rate': sample_rate, - } - dataset_factory = audio_to_audio_dataset.get_audio_to_target_dataset(config) - - # Prepare lhotse manifest - cuts_path = manifest_filepath.replace('.json', '_cuts.jsonl') - convert_manifest_nemo_to_lhotse( - input_manifest=manifest_filepath, - output_manifest=cuts_path, - input_key=data_key['input_signal'], - target_key=data_key['target_signal'], - ) - - # Prepare lhotse dataset - config_lhotse = { - 'cuts_path': cuts_path, - 'use_lhotse': True, - 'sample_rate': sample_rate, - 'batch_size': 1, - } - dl_lhotse = get_lhotse_dataloader_from_config( - OmegaConf.create(config_lhotse), global_rank=0, world_size=1, dataset=LhotseAudioToTargetDataset() - ) - dataset_lhotse = [item for item in dl_lhotse] - - for n in range(num_examples): - for use_lhotse in [False, True]: - item = dataset_lhotse[n] if use_lhotse else dataset.__getitem__(n) - item_factory = dataset_factory.__getitem__(n) - for signal in data: - item_signal = item[signal].squeeze(0) if use_lhotse else item[signal] - golden_signal = data[signal][n] - assert ( - item_signal.shape == golden_signal.shape - ), f'Test 1, use_lhotse={use_lhotse}: Signal {signal} item shape {item_signal.shape} not matching reference shape {golden_signal.shape}' - assert np.allclose( - item_signal, golden_signal, atol=atol - ), f'Test 1, use_lhotse={use_lhotse}: Failed for example {n}, signal {signal} (random seed {random_seed})' - - assert np.allclose( - item_factory[signal], golden_signal, atol=atol - ), f'Test 1, use_lhotse={use_lhotse}: Failed for factory example {n}, signal {signal} (random seed {random_seed})' - - # Test 2 - # Set target as the first channel of input_filepath and all files listed in target_filepath. - # In this case, the target will have 3 channels. - # Note: this is currently not supported by lhotse, so we only test the default dataset here. - dataset = AudioToTargetDataset( - manifest_filepath=manifest_filepath, - input_key=data_key['input_signal'], - target_key=[data_key['input_signal'], data_key['target_signal']], - target_channel_selector=0, - sample_rate=sample_rate, - ) - - for n in range(num_examples): - item = dataset.__getitem__(n) - - for signal in data: - item_signal = item[signal].cpu().detach().numpy() - golden_signal = data[signal][n] - if signal == 'target_signal': - # add the first channel of the input - golden_signal = np.concatenate([data['input_signal'][n][0:1, ...], golden_signal], axis=0) - assert ( - item_signal.shape == golden_signal.shape - ), f'Signal {signal}: item shape {item_signal.shape} not matching reference shape {golden_signal.shape}' - assert np.allclose( - item_signal, golden_signal, atol=atol - ), f'Test 2: Failed for example {n}, signal {signal} (random seed {random_seed})' - - @pytest.mark.unit - def test_audio_to_target_dataset_for_inference(self): - """Test AudioWithTargetDataset when target_key is - not set, i.e., it is `None`. This is the case, e.g., when - running inference, and a target is not available. - - In this use case, each line of the manifest file has the following format: - ``` - { - 'input_filepath': 'path/to/input.wav', - 'duration': duration_of_input, - } - ``` - """ - # Data setup - random_seed = 42 - sample_rate = 16000 - num_examples = 25 - data_num_channels = { - 'input_signal': 4, - } - data_min_duration = 2.0 - data_max_duration = 8.0 - data_key = { - 'input_signal': 'input_filepath', - } - - # Tolerance - atol = 1e-6 - - # Generate random signals - _rng = np.random.default_rng(seed=random_seed) - - # Input and target signals have the same duration - data_duration = np.round(_rng.uniform(low=data_min_duration, high=data_max_duration, size=num_examples), 3) - data_duration_samples = np.floor(data_duration * sample_rate).astype(int) - - data = dict() - for signal, num_channels in data_num_channels.items(): - data[signal] = [] - for n in range(num_examples): - if num_channels == 1: - random_signal = _rng.uniform(low=-0.5, high=0.5, size=(data_duration_samples[n])) - else: - random_signal = _rng.uniform(low=-0.5, high=0.5, size=(num_channels, data_duration_samples[n])) - data[signal].append(random_signal) - - with tempfile.TemporaryDirectory() as test_dir: - # Build metadata for manifest - metadata = [] - for n in range(num_examples): - meta = dict() - for signal in data: - # filenames - signal_filename = f'{signal}_{n:02d}.wav' - # write audio files - sf.write(os.path.join(test_dir, signal_filename), data[signal][n].T, sample_rate, 'float') - # update metadata - meta[data_key[signal]] = signal_filename - meta['duration'] = data_duration[n] - metadata.append(meta) - - # Save manifest - manifest_filepath = os.path.join(test_dir, 'manifest.json') - write_manifest(manifest_filepath, metadata) - - # Test 1 - # - No constraints on channels or duration - dataset = AudioToTargetDataset( - manifest_filepath=manifest_filepath, - input_key=data_key['input_signal'], - target_key=None, # target_signal will be empty - sample_rate=sample_rate, - ) - - # Also test the corresponding factory - config = { - 'manifest_filepath': manifest_filepath, - 'input_key': data_key['input_signal'], - 'target_key': None, - 'sample_rate': sample_rate, - } - dataset_factory = audio_to_audio_dataset.get_audio_to_target_dataset(config) - - # Prepare lhotse manifest - cuts_path = manifest_filepath.replace('.json', '_cuts.jsonl') - convert_manifest_nemo_to_lhotse( - input_manifest=manifest_filepath, - output_manifest=cuts_path, - input_key=data_key['input_signal'], - target_key=None, - ) - - # Prepare lhotse dataset - config_lhotse = { - 'cuts_path': cuts_path, - 'use_lhotse': True, - 'sample_rate': sample_rate, - 'batch_size': 1, - } - dl_lhotse = get_lhotse_dataloader_from_config( - OmegaConf.create(config_lhotse), global_rank=0, world_size=1, dataset=LhotseAudioToTargetDataset() - ) - dataset_lhotse = [item for item in dl_lhotse] - - for n in range(num_examples): - - for label in ['original', 'factory', 'lhotse']: - - if label == 'original': - item = dataset.__getitem__(n) - elif label == 'factory': - item = dataset_factory.__getitem__(n) - elif label == 'lhotse': - item = dataset_lhotse[n] - else: - raise ValueError(f'Unknown label {label}') - - # Check target is None - if 'target_signal' in item: - assert item['target_signal'].numel() == 0, f'{label}: target_signal is expected to be empty.' - - # Check valid signals - for signal in data: - - item_signal = item[signal].squeeze(0) if label == 'lhotse' else item[signal] - golden_signal = data[signal][n] - assert ( - item_signal.shape == golden_signal.shape - ), f'{label} -- Signal {signal}: item shape {item_signal.shape} not matching reference shape {golden_signal.shape}' - assert np.allclose( - item_signal, golden_signal, atol=atol - ), f'{label} -- Test 1: Failed for example {n}, signal {signal} (random seed {random_seed})' - - @pytest.mark.unit - def test_audio_to_target_with_reference_dataset(self): - """Test AudioWithTargetWithReferenceDataset in different configurations. - - 1) reference synchronized with input and target - 2) reference not synchronized - - In this use case, each line of the manifest file has the following format: - ``` - { - 'input_filepath': 'path/to/input.wav', - 'target_filepath': 'path/to/path_to_target.wav', - 'reference_filepath': 'path/to/path_to_reference.wav', - 'duration': duration_of_input, - } - ``` - """ - # Data setup - random_seed = 42 - sample_rate = 16000 - num_examples = 25 - data_num_channels = { - 'input_signal': 4, - 'target_signal': 2, - 'reference_signal': 1, - } - data_min_duration = 2.0 - data_max_duration = 8.0 - data_key = { - 'input_signal': 'input_filepath', - 'target_signal': 'target_filepath', - 'reference_signal': 'reference_filepath', - } - - # Tolerance - atol = 1e-6 - - # Generate random signals - _rng = np.random.default_rng(seed=random_seed) - - # Input and target signals have the same duration - data_duration = np.round(_rng.uniform(low=data_min_duration, high=data_max_duration, size=num_examples), 3) - data_duration_samples = np.floor(data_duration * sample_rate).astype(int) - - data = dict() - for signal, num_channels in data_num_channels.items(): - data[signal] = [] - for n in range(num_examples): - if num_channels == 1: - random_signal = _rng.uniform(low=-0.5, high=0.5, size=(data_duration_samples[n])) - else: - random_signal = _rng.uniform(low=-0.5, high=0.5, size=(num_channels, data_duration_samples[n])) - data[signal].append(random_signal) - - with tempfile.TemporaryDirectory() as test_dir: - - # Build metadata for manifest - metadata = [] - - for n in range(num_examples): - - meta = dict() - - for signal in data: - # filenames - signal_filename = f'{signal}_{n:02d}.wav' - - # write audio files - sf.write(os.path.join(test_dir, signal_filename), data[signal][n].T, sample_rate, 'float') - - # update metadata - meta[data_key[signal]] = signal_filename - - meta['duration'] = data_duration[n] - metadata.append(meta) - - # Save manifest - manifest_filepath = os.path.join(test_dir, 'manifest.json') - write_manifest(manifest_filepath, metadata) - - # Test 1 - # - No constraints on channels or duration - # - Reference is not synchronized with input and target, so whole reference signal will be loaded - dataset = AudioToTargetWithReferenceDataset( - manifest_filepath=manifest_filepath, - input_key=data_key['input_signal'], - target_key=data_key['target_signal'], - reference_key=data_key['reference_signal'], - reference_is_synchronized=False, - sample_rate=sample_rate, - ) - - # Also test the corresponding factory - config = { - 'manifest_filepath': manifest_filepath, - 'input_key': data_key['input_signal'], - 'target_key': data_key['target_signal'], - 'reference_key': data_key['reference_signal'], - 'reference_is_synchronized': False, - 'sample_rate': sample_rate, - } - dataset_factory = audio_to_audio_dataset.get_audio_to_target_with_reference_dataset(config) - - for n in range(num_examples): - item = dataset.__getitem__(n) - item_factory = dataset_factory.__getitem__(n) - - for signal in data: - item_signal = item[signal].cpu().detach().numpy() - golden_signal = data[signal][n] - assert ( - item_signal.shape == golden_signal.shape - ), f'Signal {signal}: item shape {item_signal.shape} not matching reference shape {golden_signal.shape}' - assert np.allclose( - item_signal, golden_signal, atol=atol - ), f'Test 1: Failed for example {n}, signal {signal} (random seed {random_seed})' - - item_factory_signal = item_factory[signal].cpu().detach().numpy() - assert np.allclose( - item_factory_signal, golden_signal, atol=atol - ), f'Test 1: Failed for factory example {n}, signal {signal} (random seed {random_seed})' - - # Test 2 - # - Use fixed duration (random segment selection) - # - Reference is synchronized with input and target, so the same segment of reference signal will be loaded - audio_duration = 4.0 - audio_duration_samples = int(np.floor(audio_duration * sample_rate)) - dataset = AudioToTargetWithReferenceDataset( - manifest_filepath=manifest_filepath, - input_key=data_key['input_signal'], - target_key=data_key['target_signal'], - reference_key=data_key['reference_signal'], - reference_is_synchronized=True, - sample_rate=sample_rate, - min_duration=audio_duration, - audio_duration=audio_duration, - random_offset=True, - ) - - filtered_examples = [n for n, val in enumerate(data_duration) if val >= audio_duration] - - for n in range(len(dataset)): - item = dataset.__getitem__(n) - - golden_start = golden_end = None - for signal in data: - item_signal = item[signal].cpu().detach().numpy() - full_golden_signal = data[signal][filtered_examples[n]] - - # Find random segment using correlation on the first channel - # of the first signal, and then use it fixed for other signals - if golden_start is None: - golden_start = get_segment_start(signal=full_golden_signal[0, :], segment=item_signal[0, :]) - golden_end = golden_start + audio_duration_samples - golden_signal = full_golden_signal[..., golden_start:golden_end] - - # Test length is correct - assert ( - item_signal.shape[-1] == audio_duration_samples - ), f'Test 2: Signal {signal} length ({item_signal.shape[-1]}) not matching the expected length ({audio_duration_samples})' - - # Test signal values - assert np.allclose( - item_signal, golden_signal, atol=atol - ), f'Test 2: Failed for example {n}, signal {signal} (random seed {random_seed})' - - # Test 3 - # - Use fixed duration (random segment selection) - # - Reference is not synchronized with input and target, so whole reference signal will be loaded - audio_duration = 4.0 - audio_duration_samples = int(np.floor(audio_duration * sample_rate)) - dataset = AudioToTargetWithReferenceDataset( - manifest_filepath=manifest_filepath, - input_key=data_key['input_signal'], - target_key=data_key['target_signal'], - reference_key=data_key['reference_signal'], - reference_is_synchronized=False, - sample_rate=sample_rate, - min_duration=audio_duration, - audio_duration=audio_duration, - random_offset=True, - ) - - filtered_examples = [n for n, val in enumerate(data_duration) if val >= audio_duration] - - for n in range(len(dataset)): - item = dataset.__getitem__(n) - - golden_start = golden_end = None - for signal in data: - item_signal = item[signal].cpu().detach().numpy() - full_golden_signal = data[signal][filtered_examples[n]] - - if signal == 'reference_signal': - # Complete signal is loaded for reference - golden_signal = full_golden_signal - else: - # Find random segment using correlation on the first channel - # of the first signal, and then use it fixed for other signals - if golden_start is None: - golden_start = get_segment_start( - signal=full_golden_signal[0, :], segment=item_signal[0, :] - ) - golden_end = golden_start + audio_duration_samples - golden_signal = full_golden_signal[..., golden_start:golden_end] - - # Test length is correct - assert ( - item_signal.shape[-1] == audio_duration_samples - ), f'Test 3: Signal {signal} length ({item_signal.shape[-1]}) not matching the expected length ({audio_duration_samples})' - assert ( - item_signal.shape == golden_signal.shape - ), f'Signal {signal}: item shape {item_signal.shape} not matching reference shape {golden_signal.shape}' - # Test signal values - assert np.allclose( - item_signal, golden_signal, atol=atol - ), f'Test 3: Failed for example {n}, signal {signal} (random seed {random_seed})' - - # Test 4: - # - Test collate_fn - batch_size = 16 - batch = [dataset.__getitem__(n) for n in range(batch_size)] - _ = dataset.collate_fn(batch) - - @pytest.mark.unit - def test_audio_to_target_with_embedding_dataset(self): - """Test AudioWithTargetWithEmbeddingDataset. - - In this use case, each line of the manifest file has the following format: - ``` - { - 'input_filepath': 'path/to/input.wav', - 'target_filepath': 'path/to/path_to_target.wav', - 'embedding_filepath': 'path/to/path_to_embedding.npy', - 'duration': duration_of_input, - } - ``` - """ - # Data setup - random_seed = 42 - sample_rate = 16000 - num_examples = 25 - data_num_channels = { - 'input_signal': 4, - 'target_signal': 2, - 'embedding_vector': 1, - } - data_min_duration = 2.0 - data_max_duration = 8.0 - embedding_length = 64 # 64-dimensional embedding vector - data_key = { - 'input_signal': 'input_filepath', - 'target_signal': 'target_filepath', - 'embedding_vector': 'embedding_filepath', - } - - # Tolerance - atol = 1e-6 - - # Generate random signals - _rng = np.random.default_rng(seed=random_seed) - - # Input and target signals have the same duration - data_duration = np.round(_rng.uniform(low=data_min_duration, high=data_max_duration, size=num_examples), 3) - data_duration_samples = np.floor(data_duration * sample_rate).astype(int) - - data = dict() - for signal, num_channels in data_num_channels.items(): - data[signal] = [] - for n in range(num_examples): - data_length = embedding_length if signal == 'embedding_vector' else data_duration_samples[n] - - if num_channels == 1: - random_signal = _rng.uniform(low=-0.5, high=0.5, size=(data_length)) - else: - random_signal = _rng.uniform(low=-0.5, high=0.5, size=(num_channels, data_length)) - data[signal].append(random_signal) - - with tempfile.TemporaryDirectory() as test_dir: - - # Build metadata for manifest - metadata = [] - - for n in range(num_examples): - - meta = dict() - - for signal in data: - if signal == 'embedding_vector': - signal_filename = f'{signal}_{n:02d}.npy' - np.save(os.path.join(test_dir, signal_filename), data[signal][n]) - - else: - # filenames - signal_filename = f'{signal}_{n:02d}.wav' - - # write audio files - sf.write(os.path.join(test_dir, signal_filename), data[signal][n].T, sample_rate, 'float') - - # update metadata - meta[data_key[signal]] = signal_filename - - meta['duration'] = data_duration[n] - metadata.append(meta) - - # Save manifest - manifest_filepath = os.path.join(test_dir, 'manifest.json') - write_manifest(manifest_filepath, metadata) - - # Test 1 - # - No constraints on channels or duration - dataset = AudioToTargetWithEmbeddingDataset( - manifest_filepath=manifest_filepath, - input_key=data_key['input_signal'], - target_key=data_key['target_signal'], - embedding_key=data_key['embedding_vector'], - sample_rate=sample_rate, - ) - - # Also test the corresponding factory - config = { - 'manifest_filepath': manifest_filepath, - 'input_key': data_key['input_signal'], - 'target_key': data_key['target_signal'], - 'embedding_key': data_key['embedding_vector'], - 'sample_rate': sample_rate, - } - dataset_factory = audio_to_audio_dataset.get_audio_to_target_with_embedding_dataset(config) - - for n in range(num_examples): - item = dataset.__getitem__(n) - item_factory = dataset_factory.__getitem__(n) - - for signal in data: - item_signal = item[signal].cpu().detach().numpy() - golden_signal = data[signal][n] - assert ( - item_signal.shape == golden_signal.shape - ), f'Signal {signal}: item shape {item_signal.shape} not matching reference shape {golden_signal.shape}' - assert np.allclose( - item_signal, golden_signal, atol=atol - ), f'Test 1: Failed for example {n}, signal {signal} (random seed {random_seed})' - - item_factory_signal = item_factory[signal].cpu().detach().numpy() - assert np.allclose( - item_factory_signal, golden_signal, atol=atol - ), f'Test 1: Failed for factory example {n}, signal {signal} (random seed {random_seed})' - - # Test 2: - # - Test collate_fn - batch_size = 16 - batch = [dataset.__getitem__(n) for n in range(batch_size)] - _ = dataset.collate_fn(batch) - - class TestUtilityFunctions: @pytest.mark.unit @pytest.mark.parametrize('cache_audio', [False, True]) def test_cache_datastore_manifests(self, cache_audio: bool): - """Test caching of manifest and audio files. - """ + """Test caching of manifest and audio files.""" # Data setup random_seed = 42 sample_rate = 16000 @@ -1974,9 +844,10 @@ def fake_get(self): # Return path as in the original get return self.local_path - with mock.patch( - 'nemo.collections.asr.data.audio_to_text.is_datastore_path', lambda x: True - ), mock.patch.object(DataStoreObject, 'get', fake_get): + with ( + mock.patch('nemo.collections.asr.data.audio_to_text.is_datastore_path', lambda x: True), + mock.patch.object(DataStoreObject, 'get', fake_get), + ): # Use a single worker for this test to avoid failure with mock & multiprocessing (#5607) cache_datastore_manifests(manifest_filepaths, cache_audio=cache_audio, num_workers=1) diff --git a/tests/collections/asr/test_asr_metrics.py b/tests/collections/asr/test_asr_metrics.py index 134d96f522b13..daee554a6585e 100644 --- a/tests/collections/asr/test_asr_metrics.py +++ b/tests/collections/asr/test_asr_metrics.py @@ -21,9 +21,7 @@ import pytest import torch -from torchmetrics.audio.snr import SignalNoiseRatio -from nemo.collections.asr.metrics.audio import AudioMetricWrapper from nemo.collections.asr.metrics.wer import WER, word_error_rate, word_error_rate_detail, word_error_rate_per_utt from nemo.collections.asr.parts.submodules.ctc_decoding import ( CTCBPEDecoding, @@ -128,7 +126,13 @@ def test_wer_function(self): float("inf"), float("inf"), ) - assert word_error_rate_detail(hypotheses=['cat', ''], references=['', 'gpu']) == (2.0, 1, 1.0, 1.0, 0.0,) + assert word_error_rate_detail(hypotheses=['cat', ''], references=['', 'gpu']) == ( + 2.0, + 1, + 1.0, + 1.0, + 0.0, + ) assert word_error_rate_detail(hypotheses=['cat'], references=['cot']) == (1.0, 1, 0.0, 0.0, 1.0) assert word_error_rate_detail(hypotheses=['G P U'], references=['GPU']) == (3.0, 1, 2.0, 0.0, 1.0) assert word_error_rate_detail(hypotheses=[''], references=['ducuti motorcycle'], use_cer=True) == ( @@ -540,130 +544,3 @@ def test_subword_decoding_labels(self): assert hyp.text != '' assert len(hyp.timestep) == 3 assert hyp.alignments is None - - -class TestAudioMetricWrapper: - def test_metric_full_batch(self): - """Test metric on batches where all examples have equal length. - """ - ref_metric = SignalNoiseRatio() - wrapped_metric = AudioMetricWrapper(metric=SignalNoiseRatio()) - - num_resets = 5 - num_batches = 10 - batch_size = 8 - num_channels = 2 - num_samples = 200 - - batch_shape = (batch_size, num_channels, num_samples) - - for nr in range(num_resets): - for nb in range(num_batches): - target = torch.rand(*batch_shape) - preds = target + torch.rand(1) * torch.rand(*batch_shape) - - # test forward for a single batch - batch_value_wrapped = wrapped_metric(preds=preds, target=target) - batch_value_ref = ref_metric(preds=preds, target=target) - - assert torch.allclose( - batch_value_wrapped, batch_value_ref - ), f'Metric forward not matching for batch {nb}, reset {nr}' - - # test compute (over num_batches) - assert torch.allclose( - wrapped_metric.compute(), ref_metric.compute() - ), f'Metric compute not matching for batch {nb}, reset {nr}' - - ref_metric.reset() - wrapped_metric.reset() - - def test_input_length(self): - """Test metric on batches where examples have different length. - """ - ref_metric = SignalNoiseRatio() - wrapped_metric = AudioMetricWrapper(metric=SignalNoiseRatio()) - - num_resets = 5 - num_batches = 10 - batch_size = 8 - num_channels = 2 - num_samples = 200 - - batch_shape = (batch_size, num_channels, num_samples) - - for nr in range(num_resets): - for nb in range(num_batches): - target = torch.rand(*batch_shape) - preds = target + torch.rand(1) * torch.rand(*batch_shape) - - input_length = torch.randint(low=num_samples // 2, high=num_samples, size=(batch_size,)) - - # test forward for a single batch - batch_value_wrapped = wrapped_metric(preds=preds, target=target, input_length=input_length) - - # compute reference value, assuming batch reduction using averaging - batch_value_ref = 0 - for b_idx, b_len in enumerate(input_length): - batch_value_ref += ref_metric(preds=preds[b_idx, ..., :b_len], target=target[b_idx, ..., :b_len]) - batch_value_ref /= batch_size # average - - assert torch.allclose( - batch_value_wrapped, batch_value_ref - ), f'Metric forward not matching for batch {nb}, reset {nr}' - - # test compute (over num_batches) - assert torch.allclose( - wrapped_metric.compute(), ref_metric.compute() - ), f'Metric compute not matching for batch {nb}, reset {nr}' - - ref_metric.reset() - wrapped_metric.reset() - - @pytest.mark.unit - @pytest.mark.parametrize('channel', [0, 1]) - def test_channel(self, channel): - """Test metric on a single channel from a batch. - """ - ref_metric = SignalNoiseRatio() - # select only a single channel - wrapped_metric = AudioMetricWrapper(metric=SignalNoiseRatio(), channel=channel) - - num_resets = 5 - num_batches = 10 - batch_size = 8 - num_channels = 2 - num_samples = 200 - - batch_shape = (batch_size, num_channels, num_samples) - - for nr in range(num_resets): - for nb in range(num_batches): - target = torch.rand(*batch_shape) - preds = target + torch.rand(1) * torch.rand(*batch_shape) - - # varying length - input_length = torch.randint(low=num_samples // 2, high=num_samples, size=(batch_size,)) - - # test forward for a single batch - batch_value_wrapped = wrapped_metric(preds=preds, target=target, input_length=input_length) - - # compute reference value, assuming batch reduction using averaging - batch_value_ref = 0 - for b_idx, b_len in enumerate(input_length): - batch_value_ref += ref_metric( - preds=preds[b_idx, channel, :b_len], target=target[b_idx, channel, :b_len] - ) - batch_value_ref /= batch_size # average - - assert torch.allclose( - batch_value_wrapped, batch_value_ref - ), f'Metric forward not matching for batch {nb}, reset {nr}' - - # test compute (over num_batches) - assert torch.allclose( - wrapped_metric.compute(), ref_metric.compute() - ), f'Metric compute not matching for batch {nb}, reset {nr}' - - ref_metric.reset() - wrapped_metric.reset() diff --git a/tests/collections/asr/test_asr_multitask_model_bpe.py b/tests/collections/asr/test_asr_multitask_model_bpe.py index 986df09deacbd..4e805c8f34dee 100644 --- a/tests/collections/asr/test_asr_multitask_model_bpe.py +++ b/tests/collections/asr/test_asr_multitask_model_bpe.py @@ -22,6 +22,7 @@ from nemo.collections.asr.models.aed_multitask_models import EncDecMultiTaskModel from nemo.collections.asr.parts.submodules import multitask_beam_decoding as beam_decode from nemo.collections.asr.parts.utils.rnnt_utils import Hypothesis +from nemo.collections.common.prompts.canary import CanaryPromptFormatter from nemo.collections.common.tokenizers import CanaryTokenizer @@ -275,6 +276,51 @@ def test_decoding_change(self, asr_model): assert isinstance(asr_model.decoding.decoding, beam_decode.TransformerAEDBeamInfer) assert asr_model.decoding.decoding.search_type == "default" + @pytest.mark.unit + def test_prompt_change(self, asr_model): + assert asr_model.prompt_format == 'canary' + assert isinstance(asr_model.prompt, CanaryPromptFormatter) + + # Default change prompt + asr_model.change_prompt() + assert asr_model.cfg.prompt_defaults is None + + prompt_defaults = asr_model.prompt.get_default_dialog_slots() + prompt_defaults[0]['slots']['pnc'] = 'no' + asr_model.change_prompt(prompt_defaults=prompt_defaults) + + assert asr_model.cfg.prompt_defaults[0]['slots']['pnc'] == 'no' + + @pytest.mark.unit + def test_prompt_change_subclass(self, asr_model): + assert asr_model.prompt_format == 'canary' + assert isinstance(asr_model.prompt, CanaryPromptFormatter) + + class CanaryPromptFormatterSubclass(CanaryPromptFormatter): + NAME = "canary2" + + # Default change prompt + asr_model.change_prompt() + assert asr_model.cfg.prompt_defaults is None + + prompt_defaults = asr_model.prompt.get_default_dialog_slots() + prompt_defaults[0]['slots']['pnc'] = 'no' + asr_model.change_prompt(prompt_format='canary2', prompt_defaults=prompt_defaults) + + assert asr_model.cfg.prompt_format == 'canary2' + assert asr_model.cfg.prompt_defaults[0]['slots']['pnc'] == 'no' + assert isinstance(asr_model.prompt, CanaryPromptFormatterSubclass) + + user_prompt = asr_model.prompt.get_default_dialog_slots()[0] + slots = user_prompt['slots'] + slots['source_lang'] = 'en' + slots['target_lang'] = 'en' + slots['task'] = 'asr' + slots['pnc'] = 'no' + ans = asr_model.prompt.encode_dialog([user_prompt]) + recovered = asr_model.tokenizer.ids_to_text(ans["input_ids"]) + assert recovered == "<|startoftranscript|><|en|><|transcribe|><|en|><|nopnc|>" + @pytest.mark.unit def test_transcribe_single_file(self, asr_model, test_data_dir): audio_file = os.path.join(test_data_dir, "asr", "train", "an4", "wav", "an46-mmap-b.wav") diff --git a/tests/collections/asr/test_preprocessing_segment.py b/tests/collections/asr/test_preprocessing_segment.py index 20e05e4964dc6..9f6144bad017e 100644 --- a/tests/collections/asr/test_preprocessing_segment.py +++ b/tests/collections/asr/test_preprocessing_segment.py @@ -15,6 +15,7 @@ import json import os import tempfile +from collections import namedtuple from typing import List, Type, Union import numpy as np @@ -22,8 +23,73 @@ import soundfile as sf from nemo.collections.asr.parts.preprocessing.perturb import NoisePerturbation, SilencePerturbation -from nemo.collections.asr.parts.preprocessing.segment import AudioSegment -from nemo.collections.asr.parts.utils.audio_utils import select_channels +from nemo.collections.asr.parts.preprocessing.segment import AudioSegment, select_channels + + +class TestSelectChannels: + num_samples = 1000 + max_diff_tol = 1e-9 + + @pytest.mark.unit + @pytest.mark.parametrize("channel_selector", [None, 'average', 0, 1, [0, 1]]) + def test_single_channel_input(self, channel_selector: Type[Union[str, int, List[int]]]): + """Cover the case with single-channel input signal. + Channel selector should not do anything in this case. + """ + golden_out = signal_in = np.random.rand(self.num_samples) + + if channel_selector not in [None, 0, 'average']: + # Expect a failure if looking for a different channel when input is 1D + with pytest.raises(ValueError): + # UUT + select_channels(signal_in, channel_selector) + else: + # UUT + signal_out = select_channels(signal_in, channel_selector) + + # Check difference + max_diff = np.max(np.abs(signal_out - golden_out)) + assert max_diff < self.max_diff_tol + + @pytest.mark.unit + @pytest.mark.parametrize("num_channels", [2, 4]) + @pytest.mark.parametrize("channel_selector", [None, 'average', 0, [1], [0, 1]]) + def test_multi_channel_input(self, num_channels: int, channel_selector: Type[Union[str, int, List[int]]]): + """Cover the case with multi-channel input signal and single- + or multi-channel output. + """ + signal_in = np.random.rand(self.num_samples, num_channels) + + # calculate golden output + if channel_selector is None: + golden_out = signal_in + elif channel_selector == 'average': + golden_out = np.mean(signal_in, axis=1) + else: + golden_out = signal_in[:, channel_selector].squeeze() + + # UUT + signal_out = select_channels(signal_in, channel_selector) + + # Check difference + max_diff = np.max(np.abs(signal_out - golden_out)) + assert max_diff < self.max_diff_tol + + @pytest.mark.unit + @pytest.mark.parametrize("num_channels", [1, 2]) + @pytest.mark.parametrize("channel_selector", [2, [1, 2]]) + def test_select_more_channels_than_available( + self, num_channels: int, channel_selector: Type[Union[str, int, List[int]]] + ): + """This test is expecting the UUT to fail because we ask for more channels + than available in the input signal. + """ + signal_in = np.random.rand(self.num_samples, num_channels) + + # expect failure since we ask for more channels than available + with pytest.raises(ValueError): + # UUT + select_channels(signal_in, channel_selector) class TestAudioSegment: @@ -40,8 +106,7 @@ def num_samples(self): @pytest.mark.parametrize("num_channels", [1, 4]) @pytest.mark.parametrize("channel_selector", [None, 'average', 0, 1, [0, 1]]) def test_init_single_channel(self, num_channels: int, channel_selector: Type[Union[str, int, List[int]]]): - """Test the constructor directly. - """ + """Test the constructor directly.""" if num_channels == 1: # samples is a one-dimensional vector for single-channel signal samples = np.random.rand(self.num_samples) @@ -95,8 +160,7 @@ def test_init_single_channel(self, num_channels: int, channel_selector: Type[Uni @pytest.mark.parametrize("num_channels", [1, 4]) @pytest.mark.parametrize("channel_selector", [None, 'average', 0]) def test_from_file(self, num_channels, channel_selector): - """Test loading a signal from a file. - """ + """Test loading a signal from a file.""" with tempfile.TemporaryDirectory() as test_dir: # Prepare a wav file audio_file = os.path.join(test_dir, 'audio.wav') @@ -127,8 +191,7 @@ def test_from_file(self, num_channels, channel_selector): @pytest.mark.parametrize("data_channels", [1, 4]) @pytest.mark.parametrize("noise_channels", [1, 4]) def test_noise_perturb_channels(self, data_channels, noise_channels): - """Test loading a signal from a file. - """ + """Test loading a signal from a file.""" with tempfile.TemporaryDirectory() as test_dir: # Prepare a wav file audio_file = os.path.join(test_dir, 'audio.wav') @@ -179,8 +242,7 @@ def test_noise_perturb_channels(self, data_channels, noise_channels): _ = perturber.perturb_with_foreground_noise(audio, noise) def test_silence_perturb(self): - """Test loading a signal from a file and apply silence perturbation - """ + """Test loading a signal from a file and apply silence perturbation""" with tempfile.TemporaryDirectory() as test_dir: # Prepare a wav file audio_file = os.path.join(test_dir, 'audio.wav') @@ -201,3 +263,225 @@ def test_silence_perturb(self): _ = perturber.perturb(audio) assert len(audio._samples) == ori_audio_len + 2 * dur * self.sample_rate + + @pytest.mark.unit + @pytest.mark.parametrize( + "num_channels, channel_selectors", + [ + (1, [None, 'average', 0]), + (3, [None, 'average', 0, 1, [0, 1]]), + ], + ) + @pytest.mark.parametrize("sample_rate", [8000, 16000, 22500]) + def test_audio_segment_from_file(self, tmpdir, num_channels, channel_selectors, sample_rate): + """Test loading and audio signal from a file.""" + signal_len_sec = 4 + num_samples = signal_len_sec * sample_rate + num_examples = 10 + rtol, atol = 1e-5, 1e-6 + + for n in range(num_examples): + # Create a test vector + audio_file = os.path.join(tmpdir, f'test_audio_{n:02}.wav') + samples = np.random.randn(num_samples, num_channels) + sf.write(audio_file, samples, sample_rate, 'float') + + for channel_selector in channel_selectors: + if channel_selector is None: + ref_samples = samples + elif isinstance(channel_selector, int) or isinstance(channel_selector, list): + ref_samples = samples[:, channel_selector] + elif channel_selector == 'average': + ref_samples = np.mean(samples, axis=1) + else: + raise ValueError(f'Unexpected value of channel_selector {channel_selector}') + + # 1) Load complete audio + # Reference + ref_samples = ref_samples.squeeze() + ref_channels = 1 if ref_samples.ndim == 1 else ref_samples.shape[1] + + # UUT + audio_segment = AudioSegment.from_file(audio_file, channel_selector=channel_selector) + + # Test + assert ( + audio_segment.sample_rate == sample_rate + ), f'channel_selector {channel_selector}, sample rate not matching: {audio_segment.sample_rate} != {sample_rate}' + assert ( + audio_segment.num_channels == ref_channels + ), f'channel_selector {channel_selector}, num channels not matching: {audio_segment.num_channels} != {ref_channels}' + assert audio_segment.num_samples == len( + ref_samples + ), f'channel_selector {channel_selector}, num samples not matching: {audio_segment.num_samples} != {len(ref_samples)}' + assert np.allclose( + audio_segment.samples, ref_samples, rtol=rtol, atol=atol + ), f'channel_selector {channel_selector}, samples not matching' + + # 2) Load a with duration=None and offset=None, should load the whole audio + + # UUT + audio_segment = AudioSegment.from_file( + audio_file, offset=None, duration=None, channel_selector=channel_selector + ) + + # Test + assert ( + audio_segment.sample_rate == sample_rate + ), f'channel_selector {channel_selector}, offset {offset}, duration {duration}, sample rate not matching: {audio_segment.sample_rate} != {sample_rate}' + assert ( + audio_segment.num_channels == ref_channels + ), f'channel_selector {channel_selector}, offset {offset}, duration {duration}, num channels not matching: {audio_segment.num_channels} != {ref_channels}' + assert audio_segment.num_samples == len( + ref_samples + ), f'channel_selector {channel_selector}, offset {offset}, duration {duration}, num samples not matching: {audio_segment.num_samples} != {len(ref_samples)}' + assert np.allclose( + audio_segment.samples, ref_samples, rtol=rtol, atol=atol + ), f'channel_selector {channel_selector}, offset {offset}, duration {duration}, samples not matching' + + # 3) Load a random segment + offset = 0.45 * np.random.rand() * signal_len_sec + duration = 0.45 * np.random.rand() * signal_len_sec + + # Reference + start = int(offset * sample_rate) + end = start + int(duration * sample_rate) + ref_samples = ref_samples[start:end, ...] + + # UUT + audio_segment = AudioSegment.from_file( + audio_file, offset=offset, duration=duration, channel_selector=channel_selector + ) + + # Test + assert ( + audio_segment.sample_rate == sample_rate + ), f'channel_selector {channel_selector}, offset {offset}, duration {duration}, sample rate not matching: {audio_segment.sample_rate} != {sample_rate}' + assert ( + audio_segment.num_channels == ref_channels + ), f'channel_selector {channel_selector}, offset {offset}, duration {duration}, num channels not matching: {audio_segment.num_channels} != {ref_channels}' + assert audio_segment.num_samples == len( + ref_samples + ), f'channel_selector {channel_selector}, offset {offset}, duration {duration}, num samples not matching: {audio_segment.num_samples} != {len(ref_samples)}' + assert np.allclose( + audio_segment.samples, ref_samples, rtol=rtol, atol=atol + ), f'channel_selector {channel_selector}, offset {offset}, duration {duration}, samples not matching' + + @pytest.mark.unit + @pytest.mark.parametrize( + "num_channels, channel_selectors", + [ + (1, [None, 'average', 0]), + (3, [None, 'average', 0, 1, [0, 1]]), + ], + ) + @pytest.mark.parametrize("offset", [0, 1.5]) + @pytest.mark.parametrize("duration", [1, 2]) + def test_audio_segment_multichannel_with_list(self, tmpdir, num_channels, channel_selectors, offset, duration): + """Test loading an audio signal from a list of single-channel files.""" + sample_rate = 16000 + signal_len_sec = 5 + num_samples = signal_len_sec * sample_rate + rtol, atol = 1e-5, 1e-6 + + # Random samples + samples = np.random.rand(num_samples, num_channels) + + # Save audio + audio_files = [] + for m in range(num_channels): + a_file = os.path.join(tmpdir, f'ch_{m}.wav') + sf.write(a_file, samples[:, m], sample_rate) + audio_files.append(a_file) + mc_file = os.path.join(tmpdir, f'mc.wav') + sf.write(mc_file, samples, sample_rate) + + for channel_selector in channel_selectors: + + # UUT: loading audio from a list of files + uut_segment = AudioSegment.from_file( + audio_file=audio_files, offset=offset, duration=duration, channel_selector=channel_selector + ) + + # Reference: load from the original file + ref_segment = AudioSegment.from_file( + audio_file=mc_file, offset=offset, duration=duration, channel_selector=channel_selector + ) + + # Check + assert ( + uut_segment.sample_rate == ref_segment.sample_rate + ), f'channel_selector {channel_selector}: expecting {ref_segment.sample_rate}, but UUT segment has {uut_segment.sample_rate}' + assert ( + uut_segment.num_samples == ref_segment.num_samples + ), f'channel_selector {channel_selector}: expecting {ref_segment.num_samples}, but UUT segment has {uut_segment.num_samples}' + assert np.allclose( + uut_segment.samples, ref_segment.samples, rtol=rtol, atol=atol + ), f'channel_selector {channel_selector}: samples not matching' + + # Try to get a channel that is out of range. + with pytest.raises(RuntimeError, match="Channel cannot be selected"): + AudioSegment.from_file(audio_file=audio_files, channel_selector=num_channels) + + if num_channels > 1: + # Try to load a list of multichannel files + # This is expected to fail since we only support loading a single-channel signal + # from each file when audio_file is a list + with pytest.raises(RuntimeError, match="Expecting a single-channel audio signal"): + AudioSegment.from_file(audio_file=[mc_file, mc_file]) + + with pytest.raises(RuntimeError, match="Expecting a single-channel audio signal"): + AudioSegment.from_file(audio_file=[mc_file, mc_file], channel_selector=0) + + @pytest.mark.unit + @pytest.mark.parametrize("target_sr", [8000, 16000]) + def test_audio_segment_trim_match(self, tmpdir, target_sr): + """Test loading and audio signal from a file matches when using a path and a list + for different target_sr, int_values and trim setups. + """ + sample_rate = 24000 + signal_len_sec = 2 + num_samples = signal_len_sec * sample_rate + num_examples = 10 + + TrimSetup = namedtuple("TrimSetup", "ref top_db frame_length hop_length") + trim_setups = [] + trim_setups.append(TrimSetup(np.max, 10, 2048, 1024)) + trim_setups.append(TrimSetup(1.0, 35, 2048, 1024)) + trim_setups.append(TrimSetup(0.8, 45, 2048, 1024)) + + for n in range(num_examples): + # Create a test vector + audio_file = os.path.join(tmpdir, f'test_audio_{n:02}.wav') + samples = np.random.randn(num_samples) + # normalize + samples = samples / np.max(samples) + # apply random scaling and window to have some samples cut by trim + samples = np.random.rand() * np.hanning(num_samples) * samples + sf.write(audio_file, samples, sample_rate, 'float') + + for trim_setup in trim_setups: + # UUT 1: load from a path + audio_segment_1 = AudioSegment.from_file( + audio_file, + target_sr=target_sr, + trim=True, + trim_ref=trim_setup.ref, + trim_top_db=trim_setup.top_db, + trim_frame_length=trim_setup.frame_length, + trim_hop_length=trim_setup.hop_length, + ) + + # UUT 2: load from a list + audio_segment_2 = AudioSegment.from_file( + [audio_file], + target_sr=target_sr, + trim=True, + trim_ref=trim_setup.ref, + trim_top_db=trim_setup.top_db, + trim_frame_length=trim_setup.frame_length, + trim_hop_length=trim_setup.hop_length, + ) + + # Test + assert audio_segment_1 == audio_segment_2, f'trim setup {trim_setup}, loaded segments not matching' diff --git a/tests/collections/asr/utils/test_audio_utils.py b/tests/collections/asr/utils/test_audio_utils.py deleted file mode 100644 index 58f3a2ef7cedf..0000000000000 --- a/tests/collections/asr/utils/test_audio_utils.py +++ /dev/null @@ -1,657 +0,0 @@ -# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import os -from collections import namedtuple -from typing import List, Type, Union - -import librosa -import matplotlib.pyplot as plt -import numpy as np -import pytest -import scipy -import soundfile as sf -import torch - -from nemo.collections.asr.parts.preprocessing.segment import AudioSegment -from nemo.collections.asr.parts.utils.audio_utils import SOUND_VELOCITY as sound_velocity -from nemo.collections.asr.parts.utils.audio_utils import ( - calculate_sdr_numpy, - convmtx_mc_numpy, - db2mag, - estimated_coherence, - generate_approximate_noise_field, - get_segment_start, - mag2db, - pow2db, - rms, - select_channels, - theoretical_coherence, - toeplitz, -) - - -class TestAudioSegment: - @pytest.mark.unit - @pytest.mark.parametrize( - "num_channels, channel_selectors", [(1, [None, 'average', 0]), (3, [None, 'average', 0, 1, [0, 1]]),] - ) - @pytest.mark.parametrize("sample_rate", [8000, 16000, 22500]) - def test_audio_segment_from_file(self, tmpdir, num_channels, channel_selectors, sample_rate): - """Test loading and audio signal from a file. - """ - signal_len_sec = 4 - num_samples = signal_len_sec * sample_rate - num_examples = 10 - rtol, atol = 1e-5, 1e-6 - - for n in range(num_examples): - # Create a test vector - audio_file = os.path.join(tmpdir, f'test_audio_{n:02}.wav') - samples = np.random.randn(num_samples, num_channels) - sf.write(audio_file, samples, sample_rate, 'float') - - for channel_selector in channel_selectors: - if channel_selector is None: - ref_samples = samples - elif isinstance(channel_selector, int) or isinstance(channel_selector, list): - ref_samples = samples[:, channel_selector] - elif channel_selector == 'average': - ref_samples = np.mean(samples, axis=1) - else: - raise ValueError(f'Unexpected value of channel_selector {channel_selector}') - - # 1) Load complete audio - # Reference - ref_samples = ref_samples.squeeze() - ref_channels = 1 if ref_samples.ndim == 1 else ref_samples.shape[1] - - # UUT - audio_segment = AudioSegment.from_file(audio_file, channel_selector=channel_selector) - - # Test - assert ( - audio_segment.sample_rate == sample_rate - ), f'channel_selector {channel_selector}, sample rate not matching: {audio_segment.sample_rate} != {sample_rate}' - assert ( - audio_segment.num_channels == ref_channels - ), f'channel_selector {channel_selector}, num channels not matching: {audio_segment.num_channels} != {ref_channels}' - assert audio_segment.num_samples == len( - ref_samples - ), f'channel_selector {channel_selector}, num samples not matching: {audio_segment.num_samples} != {len(ref_samples)}' - assert np.allclose( - audio_segment.samples, ref_samples, rtol=rtol, atol=atol - ), f'channel_selector {channel_selector}, samples not matching' - - # 2) Load a with duration=None and offset=None, should load the whole audio - - # UUT - audio_segment = AudioSegment.from_file( - audio_file, offset=None, duration=None, channel_selector=channel_selector - ) - - # Test - assert ( - audio_segment.sample_rate == sample_rate - ), f'channel_selector {channel_selector}, offset {offset}, duration {duration}, sample rate not matching: {audio_segment.sample_rate} != {sample_rate}' - assert ( - audio_segment.num_channels == ref_channels - ), f'channel_selector {channel_selector}, offset {offset}, duration {duration}, num channels not matching: {audio_segment.num_channels} != {ref_channels}' - assert audio_segment.num_samples == len( - ref_samples - ), f'channel_selector {channel_selector}, offset {offset}, duration {duration}, num samples not matching: {audio_segment.num_samples} != {len(ref_samples)}' - assert np.allclose( - audio_segment.samples, ref_samples, rtol=rtol, atol=atol - ), f'channel_selector {channel_selector}, offset {offset}, duration {duration}, samples not matching' - - # 3) Load a random segment - offset = 0.45 * np.random.rand() * signal_len_sec - duration = 0.45 * np.random.rand() * signal_len_sec - - # Reference - start = int(offset * sample_rate) - end = start + int(duration * sample_rate) - ref_samples = ref_samples[start:end, ...] - - # UUT - audio_segment = AudioSegment.from_file( - audio_file, offset=offset, duration=duration, channel_selector=channel_selector - ) - - # Test - assert ( - audio_segment.sample_rate == sample_rate - ), f'channel_selector {channel_selector}, offset {offset}, duration {duration}, sample rate not matching: {audio_segment.sample_rate} != {sample_rate}' - assert ( - audio_segment.num_channels == ref_channels - ), f'channel_selector {channel_selector}, offset {offset}, duration {duration}, num channels not matching: {audio_segment.num_channels} != {ref_channels}' - assert audio_segment.num_samples == len( - ref_samples - ), f'channel_selector {channel_selector}, offset {offset}, duration {duration}, num samples not matching: {audio_segment.num_samples} != {len(ref_samples)}' - assert np.allclose( - audio_segment.samples, ref_samples, rtol=rtol, atol=atol - ), f'channel_selector {channel_selector}, offset {offset}, duration {duration}, samples not matching' - - @pytest.mark.unit - @pytest.mark.parametrize( - "num_channels, channel_selectors", [(1, [None, 'average', 0]), (3, [None, 'average', 0, 1, [0, 1]]),] - ) - @pytest.mark.parametrize("offset", [0, 1.5]) - @pytest.mark.parametrize("duration", [1, 2]) - def test_audio_segment_multichannel_with_list(self, tmpdir, num_channels, channel_selectors, offset, duration): - """Test loading an audio signal from a list of single-channel files. - """ - sample_rate = 16000 - signal_len_sec = 5 - num_samples = signal_len_sec * sample_rate - rtol, atol = 1e-5, 1e-6 - - # Random samples - samples = np.random.rand(num_samples, num_channels) - - # Save audio - audio_files = [] - for m in range(num_channels): - a_file = os.path.join(tmpdir, f'ch_{m}.wav') - sf.write(a_file, samples[:, m], sample_rate) - audio_files.append(a_file) - mc_file = os.path.join(tmpdir, f'mc.wav') - sf.write(mc_file, samples, sample_rate) - - for channel_selector in channel_selectors: - - # UUT: loading audio from a list of files - uut_segment = AudioSegment.from_file( - audio_file=audio_files, offset=offset, duration=duration, channel_selector=channel_selector - ) - - # Reference: load from the original file - ref_segment = AudioSegment.from_file( - audio_file=mc_file, offset=offset, duration=duration, channel_selector=channel_selector - ) - - # Check - assert ( - uut_segment.sample_rate == ref_segment.sample_rate - ), f'channel_selector {channel_selector}: expecting {ref_segment.sample_rate}, but UUT segment has {uut_segment.sample_rate}' - assert ( - uut_segment.num_samples == ref_segment.num_samples - ), f'channel_selector {channel_selector}: expecting {ref_segment.num_samples}, but UUT segment has {uut_segment.num_samples}' - assert np.allclose( - uut_segment.samples, ref_segment.samples, rtol=rtol, atol=atol - ), f'channel_selector {channel_selector}: samples not matching' - - # Try to get a channel that is out of range. - with pytest.raises(RuntimeError, match="Channel cannot be selected"): - AudioSegment.from_file(audio_file=audio_files, channel_selector=num_channels) - - if num_channels > 1: - # Try to load a list of multichannel files - # This is expected to fail since we only support loading a single-channel signal - # from each file when audio_file is a list - with pytest.raises(RuntimeError, match="Expecting a single-channel audio signal"): - AudioSegment.from_file(audio_file=[mc_file, mc_file]) - - with pytest.raises(RuntimeError, match="Expecting a single-channel audio signal"): - AudioSegment.from_file(audio_file=[mc_file, mc_file], channel_selector=0) - - @pytest.mark.unit - @pytest.mark.parametrize("target_sr", [8000, 16000]) - def test_audio_segment_trim_match(self, tmpdir, target_sr): - """Test loading and audio signal from a file matches when using a path and a list - for different target_sr, int_values and trim setups. - """ - sample_rate = 24000 - signal_len_sec = 2 - num_samples = signal_len_sec * sample_rate - num_examples = 10 - rtol, atol = 1e-5, 1e-6 - - TrimSetup = namedtuple("TrimSetup", "ref top_db frame_length hop_length") - trim_setups = [] - trim_setups.append(TrimSetup(np.max, 10, 2048, 1024)) - trim_setups.append(TrimSetup(1.0, 35, 2048, 1024)) - trim_setups.append(TrimSetup(0.8, 45, 2048, 1024)) - - for n in range(num_examples): - # Create a test vector - audio_file = os.path.join(tmpdir, f'test_audio_{n:02}.wav') - samples = np.random.randn(num_samples) - # normalize - samples = samples / np.max(samples) - # apply random scaling and window to have some samples cut by trim - samples = np.random.rand() * np.hanning(num_samples) * samples - sf.write(audio_file, samples, sample_rate, 'float') - - for trim_setup in trim_setups: - # UUT 1: load from a path - audio_segment_1 = AudioSegment.from_file( - audio_file, - target_sr=target_sr, - trim=True, - trim_ref=trim_setup.ref, - trim_top_db=trim_setup.top_db, - trim_frame_length=trim_setup.frame_length, - trim_hop_length=trim_setup.hop_length, - ) - - # UUT 2: load from a list - audio_segment_2 = AudioSegment.from_file( - [audio_file], - target_sr=target_sr, - trim=True, - trim_ref=trim_setup.ref, - trim_top_db=trim_setup.top_db, - trim_frame_length=trim_setup.frame_length, - trim_hop_length=trim_setup.hop_length, - ) - - # Test - assert audio_segment_1 == audio_segment_2, f'trim setup {trim_setup}, loaded segments not matching' - - -class TestSelectChannels: - num_samples = 1000 - max_diff_tol = 1e-9 - - @pytest.mark.unit - @pytest.mark.parametrize("channel_selector", [None, 'average', 0, 1, [0, 1]]) - def test_single_channel_input(self, channel_selector: Type[Union[str, int, List[int]]]): - """Cover the case with single-channel input signal. - Channel selector should not do anything in this case. - """ - golden_out = signal_in = np.random.rand(self.num_samples) - - if channel_selector not in [None, 0, 'average']: - # Expect a failure if looking for a different channel when input is 1D - with pytest.raises(ValueError): - # UUT - signal_out = select_channels(signal_in, channel_selector) - else: - # UUT - signal_out = select_channels(signal_in, channel_selector) - - # Check difference - max_diff = np.max(np.abs(signal_out - golden_out)) - assert max_diff < self.max_diff_tol - - @pytest.mark.unit - @pytest.mark.parametrize("num_channels", [2, 4]) - @pytest.mark.parametrize("channel_selector", [None, 'average', 0, [1], [0, 1]]) - def test_multi_channel_input(self, num_channels: int, channel_selector: Type[Union[str, int, List[int]]]): - """Cover the case with multi-channel input signal and single- - or multi-channel output. - """ - num_samples = 1000 - signal_in = np.random.rand(self.num_samples, num_channels) - - # calculate golden output - if channel_selector is None: - golden_out = signal_in - elif channel_selector == 'average': - golden_out = np.mean(signal_in, axis=1) - else: - golden_out = signal_in[:, channel_selector].squeeze() - - # UUT - signal_out = select_channels(signal_in, channel_selector) - - # Check difference - max_diff = np.max(np.abs(signal_out - golden_out)) - assert max_diff < self.max_diff_tol - - @pytest.mark.unit - @pytest.mark.parametrize("num_channels", [1, 2]) - @pytest.mark.parametrize("channel_selector", [2, [1, 2]]) - def test_select_more_channels_than_available( - self, num_channels: int, channel_selector: Type[Union[str, int, List[int]]] - ): - """This test is expecting the UUT to fail because we ask for more channels - than available in the input signal. - """ - num_samples = 1000 - signal_in = np.random.rand(self.num_samples, num_channels) - - # expect failure since we ask for more channels than available - with pytest.raises(ValueError): - # UUT - signal_out = select_channels(signal_in, channel_selector) - - -class TestGenerateApproximateNoiseField: - @pytest.mark.unit - @pytest.mark.parametrize('num_mics', [5]) - @pytest.mark.parametrize('mic_spacing', [0.05]) - @pytest.mark.parametrize('fft_length', [512, 2048]) - @pytest.mark.parametrize('sample_rate', [8000, 16000]) - @pytest.mark.parametrize('field', ['spherical']) - def test_theoretical_coherence_matrix( - self, num_mics: int, mic_spacing: float, fft_length: int, sample_rate: float, field: str - ): - """Test calculation of a theoretical coherence matrix. - """ - # test setup - max_diff_tol = 1e-9 - - # golden reference: spherical coherence - num_subbands = fft_length // 2 + 1 - angular_freq = 2 * np.pi * sample_rate * np.arange(0, num_subbands) / fft_length - golden_coherence = np.zeros((num_subbands, num_mics, num_mics)) - - for p in range(num_mics): - for q in range(num_mics): - if p == q: - golden_coherence[:, p, q] = 1.0 - else: - if field == 'spherical': - dist_pq = abs(p - q) * mic_spacing - sinc_arg = angular_freq * dist_pq / sound_velocity - golden_coherence[:, p, q] = np.sinc(sinc_arg / np.pi) - else: - raise NotImplementedError(f'Field {field} not supported.') - - # assume linear arrray - mic_positions = np.zeros((num_mics, 3)) - mic_positions[:, 0] = mic_spacing * np.arange(num_mics) - - # UUT - uut_coherence = theoretical_coherence( - mic_positions, sample_rate=sample_rate, fft_length=fft_length, field='spherical' - ) - - # Check difference - max_diff = np.max(np.abs(uut_coherence - golden_coherence)) - assert max_diff < max_diff_tol - - @pytest.mark.unit - @pytest.mark.parametrize('num_mics', [5]) - @pytest.mark.parametrize('mic_spacing', [0.10]) - @pytest.mark.parametrize('fft_length', [256, 512]) - @pytest.mark.parametrize('sample_rate', [8000, 16000]) - @pytest.mark.parametrize('field', ['spherical']) - def test_generate_approximate_noise_field( - self, - num_mics: int, - mic_spacing: float, - fft_length: int, - sample_rate: float, - field: str, - save_figures: bool = False, - ): - """Test approximate noise field with white noise as the input noise. - """ - duration_in_sec = 20 - relative_mse_tol_dB = -30 - relative_mse_tol = 10 ** (relative_mse_tol_dB / 10) - - num_samples = sample_rate * duration_in_sec - noise_signal = np.random.rand(num_samples, num_mics) - # random channel-wise power scaling - noise_signal *= np.random.randn(num_mics) - - # assume linear arrray - mic_positions = np.zeros((num_mics, 3)) - mic_positions[:, 0] = mic_spacing * np.arange(num_mics) - - # UUT - noise_field = generate_approximate_noise_field( - mic_positions, noise_signal, sample_rate=sample_rate, field=field, fft_length=fft_length - ) - - # Compare the estimated coherence with the theoretical coherence - - # reference - golden_coherence = theoretical_coherence( - mic_positions, sample_rate=sample_rate, field=field, fft_length=fft_length - ) - - # estimated - N = librosa.stft(noise_field.transpose(), n_fft=fft_length) - # (channel, subband, frame) -> (subband, frame, channel) - N = N.transpose(1, 2, 0) - uut_coherence = estimated_coherence(N) - - # Check difference - relative_mse_real = np.mean((uut_coherence.real - golden_coherence) ** 2) - assert relative_mse_real < relative_mse_tol - relative_mse_imag = np.mean((uut_coherence.imag) ** 2) - assert relative_mse_imag < relative_mse_tol - - if save_figures: - # For debugging and visualization template - figure_dir = os.path.expanduser('~/_coherence') - if not os.path.exists(figure_dir): - os.mkdir(figure_dir) - - freq = librosa.fft_frequencies(sr=sample_rate, n_fft=fft_length) - freq = freq / 1e3 # kHz - - plt.figure(figsize=(7, 10)) - for n in range(1, num_mics): - plt.subplot(num_mics - 1, 2, 2 * n - 1) - plt.plot(freq, golden_coherence[:, 0, n].real, label='golden') - plt.plot(freq, uut_coherence[:, 0, n].real, label='estimated') - plt.title(f'Real(coherence), p=0, q={n}') - plt.xlabel('f / kHz') - plt.grid() - plt.legend(loc='upper right') - - plt.subplot(num_mics - 1, 2, 2 * n) - plt.plot(golden_coherence[:, 0, n].imag, label='golden') - plt.plot(uut_coherence[:, 0, n].imag, label='estimated') - plt.title(f'Imag(coherence), p=0, q={n}') - plt.xlabel('f / kHz') - plt.grid() - plt.legend(loc='upper right') - - plt.tight_layout() - plt.savefig( - os.path.join( - figure_dir, f'num_mics_{num_mics}_sample_rate_{sample_rate}_fft_length_{fft_length}_{field}.png' - ) - ) - plt.close() - - -class TestAudioUtilsElements: - @pytest.mark.unit - def test_rms(self): - """Test RMS calculation - """ - # setup - A = np.random.rand() - omega = 100 - n_points = 1000 - rms_threshold = 1e-4 - # prep data - t = np.linspace(0, 2 * np.pi, n_points) - x = A * np.cos(2 * np.pi * omega * t) - # test - x_rms = rms(x) - golden_rms = A / np.sqrt(2) - assert ( - np.abs(x_rms - golden_rms) < rms_threshold - ), f'RMS not matching for A={A}, omega={omega}, n_point={n_points}' - - @pytest.mark.unit - def test_db_conversion(self): - """Test conversions to and from dB. - """ - num_examples = 10 - abs_threshold = 1e-6 - - mag = np.random.rand(num_examples) - mag_db = mag2db(mag) - - assert all(np.abs(mag - 10 ** (mag_db / 20)) < abs_threshold) - assert all(np.abs(db2mag(mag_db) - 10 ** (mag_db / 20)) < abs_threshold) - assert all(np.abs(pow2db(mag ** 2) - mag_db) < abs_threshold) - - @pytest.mark.unit - def test_get_segment_start(self): - random_seed = 42 - num_examples = 50 - num_samples = 2000 - - _rng = np.random.default_rng(seed=random_seed) - - for n in range(num_examples): - # Generate signal - signal = _rng.normal(size=num_samples) - # Random start in the first half - start = _rng.integers(low=0, high=num_samples // 2) - # Random length - end = _rng.integers(low=start, high=num_samples) - # Selected segment - segment = signal[start:end] - - # UUT - estimated_start = get_segment_start(signal=signal, segment=segment) - - assert ( - estimated_start == start - ), f'Example {n}: estimated start ({estimated_start}) not matching the actual start ({start})' - - @pytest.mark.unit - def test_calculate_sdr_numpy(self): - atol = 1e-6 - random_seed = 42 - num_examples = 50 - num_samples = 2000 - - _rng = np.random.default_rng(seed=random_seed) - - for n in range(num_examples): - # Generate signal - target = _rng.normal(size=num_samples) - # Adjust the estimate - golden_sdr = _rng.integers(low=-10, high=10) - estimate = target * (1 + 10 ** (-golden_sdr / 20)) - - # UUT - estimated_sdr = calculate_sdr_numpy(estimate=estimate, target=target, remove_mean=False) - - assert np.isclose( - estimated_sdr, golden_sdr, atol=atol - ), f'Example {n}: estimated ({estimated_sdr}) not matching the actual value ({golden_sdr})' - - # Add random mean and use remove_mean=True - # SDR should not change - target += _rng.uniform(low=-10, high=10) - estimate += _rng.uniform(low=-10, high=10) - - # UUT - estimated_sdr = calculate_sdr_numpy(estimate=estimate, target=target, remove_mean=True) - - assert np.isclose( - estimated_sdr, golden_sdr, atol=atol - ), f'Example {n}: estimated ({estimated_sdr}) not matching the actual value ({golden_sdr})' - - @pytest.mark.unit - def test_calculate_sdr_numpy_scale_invariant(self): - atol = 1e-6 - random_seed = 42 - num_examples = 50 - num_samples = 2000 - - _rng = np.random.default_rng(seed=random_seed) - - for n in range(num_examples): - # Generate signal - target = _rng.normal(size=num_samples) - # Adjust the estimate - estimate = target + _rng.uniform(low=0.01, high=1) * _rng.normal(size=target.size) - - # scaled target - target_scaled = target / (np.linalg.norm(target) + 1e-16) - target_scaled = np.sum(estimate * target_scaled) * target_scaled - - golden_sdr = calculate_sdr_numpy( - estimate=estimate, target=target_scaled, scale_invariant=False, remove_mean=False - ) - - # UUT - estimated_sdr = calculate_sdr_numpy( - estimate=estimate, target=target, scale_invariant=True, remove_mean=False - ) - - print(golden_sdr, estimated_sdr) - - assert np.isclose( - estimated_sdr, golden_sdr, atol=atol - ), f'Example {n}: estimated ({estimated_sdr}) not matching the actual value ({golden_sdr})' - - @pytest.mark.unit - @pytest.mark.parametrize('num_channels', [1, 3]) - @pytest.mark.parametrize('filter_length', [10]) - @pytest.mark.parametrize('delay', [0, 5]) - def test_convmtx_mc(self, num_channels: int, filter_length: int, delay: int): - """Test convmtx against convolve and sum. - Multiplication of convmtx_mc of input with a vectorized multi-channel filter - should match the sum of convolution of each input channel with the corresponding - filter. - """ - atol = 1e-6 - random_seed = 42 - num_examples = 10 - num_samples = 2000 - - _rng = np.random.default_rng(seed=random_seed) - - for n in range(num_examples): - x = _rng.normal(size=(num_samples, num_channels)) - f = _rng.normal(size=(filter_length, num_channels)) - - CM = convmtx_mc_numpy(x=x, filter_length=filter_length, delay=delay) - - # Multiply convmtx_mc with the vectorized filter - uut = CM @ f.transpose().reshape(-1, 1) - uut = uut.squeeze(1) - - # Calculate reference as sum of convolutions - golden_ref = 0 - for m in range(num_channels): - x_m_delayed = np.hstack([np.zeros(delay), x[:, m]]) - golden_ref += np.convolve(x_m_delayed, f[:, m], mode='full')[: len(x)] - - assert np.allclose(uut, golden_ref, atol=atol), f'Example {n}: UUT not matching the reference.' - - @pytest.mark.unit - @pytest.mark.parametrize('num_channels', [1, 3]) - @pytest.mark.parametrize('filter_length', [10]) - @pytest.mark.parametrize('num_samples', [10, 100]) - def test_toeplitz(self, num_channels: int, filter_length: int, num_samples: int): - """Test construction of a Toeplitz matrix for a given signal. - """ - atol = 1e-6 - random_seed = 42 - num_batches = 10 - batch_size = 8 - - _rng = np.random.default_rng(seed=random_seed) - - for n in range(num_batches): - x = _rng.normal(size=(batch_size, num_channels, num_samples)) - - # Construct Toeplitz matrix - Tx = toeplitz(x=torch.tensor(x)) - - # Compare against the reference - for b in range(batch_size): - for m in range(num_channels): - T_ref = scipy.linalg.toeplitz(x[b, m, ...]) - - assert np.allclose( - Tx[b, m, ...].cpu().numpy(), T_ref, atol=atol - ), f'Example {n}: not matching the reference for (b={b}, m={m}), .' diff --git a/tests/collections/asr/test_asr_data_simulation.py b/tests/collections/audio/test_audio_data_simulation.py similarity index 98% rename from tests/collections/asr/test_asr_data_simulation.py rename to tests/collections/audio/test_audio_data_simulation.py index 3cddf44f7657d..fed3ea2c3ea46 100644 --- a/tests/collections/asr/test_asr_data_simulation.py +++ b/tests/collections/audio/test_audio_data_simulation.py @@ -19,7 +19,8 @@ import pytest from numpy.random import default_rng -from nemo.collections.asr.data.data_simulation import ( +from nemo.collections.asr.parts.preprocessing.segment import AudioSegment +from nemo.collections.audio.data.data_simulation import ( ArrayGeometry, check_angle, convert_placement_to_range, @@ -27,14 +28,12 @@ simulate_room_mix, wrap_to_180, ) -from nemo.collections.asr.parts.preprocessing.segment import AudioSegment class TestDataSimulationUtils: @pytest.mark.unit def test_check_angle(self): - """Test angle checks. - """ + """Test angle checks.""" num_examples = 100 random = default_rng() @@ -61,8 +60,7 @@ def test_check_angle(self): @pytest.mark.unit def test_wrap_to_180(self): - """Test wrap. - """ + """Test wrap.""" test_cases = [] test_cases.append({'angle': 0, 'wrapped': 0}) test_cases.append({'angle': 45, 'wrapped': 45}) @@ -81,8 +79,7 @@ def test_wrap_to_180(self): @pytest.mark.unit def test_placement_range(self): - """Test placement range conversion. - """ + """Test placement range conversion.""" # Setup 1: test_cases = [] test_cases.append( @@ -181,8 +178,7 @@ def test_placement_range(self): @pytest.mark.parametrize("num_mics", [2, 4]) @pytest.mark.parametrize("num_sources", [1, 3]) def test_convert_rir_to_mc(self, num_mics: int, num_sources: int): - """Test conversion of a RIR from list of lists to multichannel array. - """ + """Test conversion of a RIR from list of lists to multichannel array.""" len_range = [50, 1000] random = default_rng() @@ -335,8 +331,7 @@ class TestRoomSimulation: @pytest.mark.unit def test_simulate_room_mix(self, test_data_dir): - """Test room simulation for fixed parameters. - """ + """Test room simulation for fixed parameters.""" # Test setup data_dir = os.path.join(test_data_dir, 'asr', 'data_simulation') diff --git a/tests/collections/audio/test_audio_datasets.py b/tests/collections/audio/test_audio_datasets.py new file mode 100644 index 0000000000000..d957234fc90b2 --- /dev/null +++ b/tests/collections/audio/test_audio_datasets.py @@ -0,0 +1,1156 @@ +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +import tempfile + +import numpy as np +import pytest +import soundfile as sf +import torch.cuda +from omegaconf import OmegaConf + +from nemo.collections.asr.parts.utils.manifest_utils import write_manifest +from nemo.collections.audio.data import audio_to_audio_dataset +from nemo.collections.audio.data.audio_to_audio import ( + ASRAudioProcessor, + AudioToTargetDataset, + AudioToTargetWithEmbeddingDataset, + AudioToTargetWithReferenceDataset, + _audio_collate_fn, +) +from nemo.collections.audio.data.audio_to_audio_lhotse import ( + LhotseAudioToTargetDataset, + convert_manifest_nemo_to_lhotse, +) +from nemo.collections.audio.parts.utils.audio import get_segment_start +from nemo.collections.common.data.lhotse import get_lhotse_dataloader_from_config + + +class TestAudioDatasets: + @pytest.mark.unit + @pytest.mark.parametrize('num_channels', [1, 2]) + @pytest.mark.parametrize('num_targets', [1, 3]) + def test_list_to_multichannel(self, num_channels, num_targets): + """Test conversion of a list of arrays into""" + random_seed = 42 + num_samples = 1000 + + # Generate random signals + _rng = np.random.default_rng(seed=random_seed) + + # Multi-channel signal + golden_target = _rng.normal(size=(num_channels * num_targets, num_samples)) + + # Create a list of num_targets signals with num_channels channels + target_list = [golden_target[n * num_channels : (n + 1) * num_channels, :] for n in range(num_targets)] + + # Check the original signal is not modified + assert (ASRAudioProcessor.list_to_multichannel(golden_target) == golden_target).all() + # Check the list is converted back to the original signal + assert (ASRAudioProcessor.list_to_multichannel(target_list) == golden_target).all() + + @pytest.mark.unit + @pytest.mark.parametrize('num_channels', [1, 2]) + def test_processor_process_audio(self, num_channels): + """Test signal normalization in process_audio.""" + num_samples = 1000 + num_examples = 30 + + signals = ['input_signal', 'target_signal', 'reference_signal'] + + for normalization_signal in [None] + signals: + # Create processor + processor = ASRAudioProcessor( + sample_rate=16000, random_offset=False, normalization_signal=normalization_signal + ) + + # Generate random signals + for n in range(num_examples): + example = {signal: torch.randn(num_channels, num_samples) for signal in signals} + processed_example = processor.process_audio(example) + + # Expected scale + if normalization_signal: + scale = 1.0 / (example[normalization_signal].abs().max() + processor.eps) + else: + scale = 1.0 + + # Make sure all signals are scaled as expected + for signal in signals: + assert torch.allclose( + processed_example[signal], example[signal] * scale + ), f'Failed example {n} signal {signal}' + + @pytest.mark.unit + def test_audio_collate_fn(self): + """Test `_audio_collate_fn`""" + batch_size = 16 + random_seed = 42 + atol = 1e-5 + + # Generate random signals + _rng = np.random.default_rng(seed=random_seed) + + signal_to_channels = { + 'input_signal': 2, + 'target_signal': 1, + 'reference_signal': 1, + } + + signal_to_length = { + 'input_signal': _rng.integers(low=5, high=25, size=batch_size), + 'target_signal': _rng.integers(low=5, high=25, size=batch_size), + 'reference_signal': _rng.integers(low=5, high=25, size=batch_size), + } + + # Generate batch + batch = [] + for n in range(batch_size): + item = dict() + for signal, num_channels in signal_to_channels.items(): + random_signal = _rng.normal(size=(num_channels, signal_to_length[signal][n])) + random_signal = np.squeeze(random_signal) # get rid of channel dimention for single-channel + item[signal] = torch.tensor(random_signal) + batch.append(item) + + # Run UUT + batched = _audio_collate_fn(batch) + + batched_signals = { + 'input_signal': batched[0].cpu().detach().numpy(), + 'target_signal': batched[2].cpu().detach().numpy(), + 'reference_signal': batched[4].cpu().detach().numpy(), + } + + batched_lengths = { + 'input_signal': batched[1].cpu().detach().numpy(), + 'target_signal': batched[3].cpu().detach().numpy(), + 'reference_signal': batched[5].cpu().detach().numpy(), + } + + # Check outputs + for signal, b_signal in batched_signals.items(): + for n in range(batch_size): + # Check length + uut_length = batched_lengths[signal][n] + golden_length = signal_to_length[signal][n] + assert ( + uut_length == golden_length + ), f'Example {n} signal {signal} length mismatch: batched ({uut_length}) != golden ({golden_length})' + + uut_signal = b_signal[n][:uut_length, ...] + golden_signal = batch[n][signal][:uut_length, ...].cpu().detach().numpy() + assert np.allclose( + uut_signal, golden_signal, atol=atol + ), f'Example {n} signal {signal} value mismatch.' + + @pytest.mark.unit + def test_audio_to_target_dataset(self): + """Test AudioWithTargetDataset in different configurations. + + Test below cover the following: + 1) no constraints + 2) filtering based on signal duration + 3) use with channel selector + 4) use with fixed audio duration and random subsegments + 5) collate a batch of items + + In this use case, each line of the manifest file has the following format: + ``` + { + 'input_filepath': 'path/to/input.wav', + 'target_filepath': 'path/to/path_to_target.wav', + 'duration': duration_of_input, + } + ``` + """ + # Data setup + random_seed = 42 + sample_rate = 16000 + num_examples = 25 + data_num_channels = { + 'input_signal': 4, + 'target_signal': 2, + } + data_min_duration = 2.0 + data_max_duration = 8.0 + data_key = { + 'input_signal': 'input_filepath', + 'target_signal': 'target_filepath', + } + + # Tolerance + atol = 1e-6 + + # Generate random signals + _rng = np.random.default_rng(seed=random_seed) + + # Input and target signals have the same duration + data_duration = np.round(_rng.uniform(low=data_min_duration, high=data_max_duration, size=num_examples), 3) + data_duration_samples = np.floor(data_duration * sample_rate).astype(int) + + data = dict() + for signal, num_channels in data_num_channels.items(): + data[signal] = [] + for n in range(num_examples): + if num_channels == 1: + random_signal = _rng.uniform(low=-0.5, high=0.5, size=(data_duration_samples[n])) + else: + random_signal = _rng.uniform(low=-0.5, high=0.5, size=(num_channels, data_duration_samples[n])) + data[signal].append(random_signal) + + with tempfile.TemporaryDirectory() as test_dir: + + # Build metadata for manifest + metadata = [] + + for n in range(num_examples): + + meta = dict() + + for signal in data: + # filenames + signal_filename = f'{signal}_{n:02d}.wav' + + # write audio files + sf.write(os.path.join(test_dir, signal_filename), data[signal][n].T, sample_rate, 'float') + + # update metadata + meta[data_key[signal]] = signal_filename + + meta['duration'] = data_duration[n] + metadata.append(meta) + + # Save manifest + manifest_filepath = os.path.join(test_dir, 'manifest.json') + write_manifest(manifest_filepath, metadata) + + # Test 1 + # - No constraints on channels or duration + dataset = AudioToTargetDataset( + manifest_filepath=manifest_filepath, + input_key=data_key['input_signal'], + target_key=data_key['target_signal'], + sample_rate=sample_rate, + ) + + # Also test the corresponding factory + config = { + 'manifest_filepath': manifest_filepath, + 'input_key': data_key['input_signal'], + 'target_key': data_key['target_signal'], + 'sample_rate': sample_rate, + } + dataset_factory = audio_to_audio_dataset.get_audio_to_target_dataset(config) + + # Prepare lhotse manifest + cuts_path = manifest_filepath.replace('.json', '_cuts.jsonl') + convert_manifest_nemo_to_lhotse( + input_manifest=manifest_filepath, + output_manifest=cuts_path, + input_key=data_key['input_signal'], + target_key=data_key['target_signal'], + ) + + # Prepare lhotse dataset + config_lhotse = { + 'cuts_path': cuts_path, + 'use_lhotse': True, + 'sample_rate': sample_rate, + 'batch_size': 1, + } + dl_lhotse = get_lhotse_dataloader_from_config( + OmegaConf.create(config_lhotse), global_rank=0, world_size=1, dataset=LhotseAudioToTargetDataset() + ) + dataset_lhotse = [item for item in dl_lhotse] + + # Test number of channels + for signal in data: + assert data_num_channels[signal] == dataset.num_channels( + signal + ), f'Num channels not correct for signal {signal}' + assert data_num_channels[signal] == dataset_factory.num_channels( + signal + ), f'Num channels not correct for signal {signal}' + + # Test returned examples + for n in range(num_examples): + for signal in data: + golden_signal = data[signal][n] + + for use_lhotse in [False, True]: + item_signal = ( + dataset_lhotse[n][signal].squeeze(0) if use_lhotse else dataset.__getitem__(n)[signal] + ) + item_factory_signal = dataset_factory.__getitem__(n)[signal] + + assert ( + item_signal.shape == golden_signal.shape + ), f'Test 1, use_lhotse={use_lhotse}: Signal {signal} item shape {item_signal.shape} not matching reference shape {golden_signal.shape}' + assert np.allclose( + item_signal, golden_signal, atol=atol + ), f'Test 1, use_lhotse={use_lhotse}: Failed for example {n}, signal {signal} (random seed {random_seed})' + + assert np.allclose( + item_factory_signal, golden_signal, atol=atol + ), f'Test 1, use_lhotse={use_lhotse}: Failed for factory example {n}, signal {signal} (random seed {random_seed})' + + # Test 2 + # - Filtering based on signal duration + min_duration = 3.5 + max_duration = 7.5 + + dataset = AudioToTargetDataset( + manifest_filepath=manifest_filepath, + input_key=data_key['input_signal'], + target_key=data_key['target_signal'], + min_duration=min_duration, + max_duration=max_duration, + sample_rate=sample_rate, + ) + + # Prepare lhotse dataset + config_lhotse = { + 'cuts_path': cuts_path, + 'use_lhotse': True, + 'min_duration': min_duration, + 'max_duration': max_duration, + 'sample_rate': sample_rate, + 'batch_size': 1, + } + dl_lhotse = get_lhotse_dataloader_from_config( + OmegaConf.create(config_lhotse), global_rank=0, world_size=1, dataset=LhotseAudioToTargetDataset() + ) + dataset_lhotse = [item for item in dl_lhotse] + + filtered_examples = [n for n, val in enumerate(data_duration) if min_duration <= val <= max_duration] + + for n in range(len(dataset)): + for use_lhotse in [False, True]: + for signal in data: + item_signal = ( + dataset_lhotse[n][signal].squeeze(0) if use_lhotse else dataset.__getitem__(n)[signal] + ) + golden_signal = data[signal][filtered_examples[n]] + assert ( + item_signal.shape == golden_signal.shape + ), f'Test 2, use_lhotse={use_lhotse}: Signal {signal} item shape {item_signal.shape} not matching reference shape {golden_signal.shape}' + + assert np.allclose( + item_signal, golden_signal, atol=atol + ), f'Test 2, use_lhotse={use_lhotse}: Failed for example {n}, signal {signal} (random seed {random_seed})' + + # Test 3 + # - Use channel selector + channel_selector = { + 'input_signal': [0, 2], + 'target_signal': 1, + } + + dataset = AudioToTargetDataset( + manifest_filepath=manifest_filepath, + input_key=data_key['input_signal'], + target_key=data_key['target_signal'], + input_channel_selector=channel_selector['input_signal'], + target_channel_selector=channel_selector['target_signal'], + sample_rate=sample_rate, + ) + + for n in range(len(dataset)): + item = dataset.__getitem__(n) + + for signal in data: + cs = channel_selector[signal] + item_signal = item[signal].cpu().detach().numpy() + golden_signal = data[signal][n][cs, ...] + assert ( + item_signal.shape == golden_signal.shape + ), f'Signal {signal}: item shape {item_signal.shape} not matching reference shape {golden_signal.shape}' + assert np.allclose( + item_signal, golden_signal, atol=atol + ), f'Test 3: Failed for example {n}, signal {signal} (random seed {random_seed})' + + # Test 4 + # - Use fixed duration (random segment selection) + audio_duration = 4.0 + audio_duration_samples = int(np.floor(audio_duration * sample_rate)) + + filtered_examples = [n for n, val in enumerate(data_duration) if val >= audio_duration] + + for random_offset in [True, False]: + # Test subsegments with the default fixed offset and a random offset + + dataset = AudioToTargetDataset( + manifest_filepath=manifest_filepath, + input_key=data_key['input_signal'], + target_key=data_key['target_signal'], + sample_rate=sample_rate, + min_duration=audio_duration, + audio_duration=audio_duration, + random_offset=random_offset, # random offset when selecting subsegment + ) + + # Prepare lhotse dataset + config_lhotse = { + 'cuts_path': cuts_path, + 'use_lhotse': True, + 'min_duration': audio_duration, + 'truncate_duration': audio_duration, + 'truncate_offset_type': 'random' if random_offset else 'start', + 'sample_rate': sample_rate, + 'batch_size': 1, + } + dl_lhotse = get_lhotse_dataloader_from_config( + OmegaConf.create(config_lhotse), global_rank=0, world_size=1, dataset=LhotseAudioToTargetDataset() + ) + dataset_lhotse = [item for item in dl_lhotse] + + for n in range(len(dataset)): + for use_lhotse in [False, True]: + item = dataset_lhotse[n] if use_lhotse else dataset.__getitem__(n) + golden_start = golden_end = None + for signal in data: + item_signal = item[signal].squeeze(0) if use_lhotse else item[signal] + full_golden_signal = data[signal][filtered_examples[n]] + + # Find random segment using correlation on the first channel + # of the first signal, and then use it fixed for other signals + if golden_start is None: + golden_start = get_segment_start( + signal=full_golden_signal[0, :], segment=item_signal[0, :] + ) + if not random_offset: + assert ( + golden_start == 0 + ), f'Test 4, use_lhotse={use_lhotse}: Expecting the signal to start at 0 when random_offset is False' + + golden_end = golden_start + audio_duration_samples + golden_signal = full_golden_signal[..., golden_start:golden_end] + + # Test length is correct + assert ( + item_signal.shape[-1] == audio_duration_samples + ), f'Test 4, use_lhotse={use_lhotse}: Signal length ({item_signal.shape[-1]}) not matching the expected length ({audio_duration_samples})' + + assert ( + item_signal.shape == golden_signal.shape + ), f'Test 4, use_lhotse={use_lhotse}: Signal {signal} item shape {item_signal.shape} not matching reference shape {golden_signal.shape}' + # Test signal values + assert np.allclose( + item_signal, golden_signal, atol=atol + ), f'Test 4, use_lhotse={use_lhotse}: Failed for example {n}, signal {signal} (random seed {random_seed})' + + # Test 5: + # - Test collate_fn + batch_size = 16 + + for use_lhotse in [False, True]: + if use_lhotse: + # Get batch from lhotse dataloader + config_lhotse['batch_size'] = batch_size + dl_lhotse = get_lhotse_dataloader_from_config( + OmegaConf.create(config_lhotse), + global_rank=0, + world_size=1, + dataset=LhotseAudioToTargetDataset(), + ) + batched = next(iter(dl_lhotse)) + else: + # Get examples from dataset and collate into a batch + batch = [dataset.__getitem__(n) for n in range(batch_size)] + batched = dataset.collate_fn(batch) + + # Test all shapes and lengths + for n, signal in enumerate(data.keys()): + length = signal.replace('_signal', '_length') + + if isinstance(batched, dict): + signal_shape = batched[signal].shape + signal_len = batched[length] + else: + signal_shape = batched[2 * n].shape + signal_len = batched[2 * n + 1] + + assert signal_shape == ( + batch_size, + data_num_channels[signal], + audio_duration_samples, + ), f'Test 5, use_lhotse={use_lhotse}: Unexpected signal {signal} shape {signal_shape}' + assert ( + len(signal_len) == batch_size + ), f'Test 5, use_lhotse={use_lhotse}: Unexpected length of signal_len ({len(signal_len)})' + assert all( + signal_len == audio_duration_samples + ), f'Test 5, use_lhotse={use_lhotse}: Unexpected signal_len {signal_len}' + + @pytest.mark.unit + def test_audio_to_target_dataset_with_target_list(self): + """Test AudioWithTargetDataset when the input manifest has a list + of audio files in the target key. + + In this use case, each line of the manifest file has the following format: + ``` + { + 'input_filepath': 'path/to/input.wav', + 'target_filepath': ['path/to/path_to_target_ch0.wav', 'path/to/path_to_target_ch1.wav'], + 'duration': duration_of_input, + } + ``` + """ + # Data setup + random_seed = 42 + sample_rate = 16000 + num_examples = 25 + data_num_channels = { + 'input_signal': 4, + 'target_signal': 2, + } + data_min_duration = 2.0 + data_max_duration = 8.0 + data_key = { + 'input_signal': 'input_filepath', + 'target_signal': 'target_filepath', + } + + # Tolerance + atol = 1e-6 + + # Generate random signals + _rng = np.random.default_rng(seed=random_seed) + + # Input and target signals have the same duration + data_duration = np.round(_rng.uniform(low=data_min_duration, high=data_max_duration, size=num_examples), 3) + data_duration_samples = np.floor(data_duration * sample_rate).astype(int) + + data = dict() + for signal, num_channels in data_num_channels.items(): + data[signal] = [] + for n in range(num_examples): + if num_channels == 1: + random_signal = _rng.uniform(low=-0.5, high=0.5, size=(data_duration_samples[n])) + else: + random_signal = _rng.uniform(low=-0.5, high=0.5, size=(num_channels, data_duration_samples[n])) + data[signal].append(random_signal) + + with tempfile.TemporaryDirectory() as test_dir: + + # Build metadata for manifest + metadata = [] + + for n in range(num_examples): + + meta = dict() + + for signal in data: + if signal == 'target_signal': + # Save targets as individual files + signal_filename = [] + for ch in range(data_num_channels[signal]): + # add current filename + signal_filename.append(f'{signal}_{n:02d}_ch_{ch}.wav') + # write audio file + sf.write( + os.path.join(test_dir, signal_filename[-1]), + data[signal][n][ch, :], + sample_rate, + 'float', + ) + else: + # single file + signal_filename = f'{signal}_{n:02d}.wav' + + # write audio files + sf.write(os.path.join(test_dir, signal_filename), data[signal][n].T, sample_rate, 'float') + + # update metadata + meta[data_key[signal]] = signal_filename + + meta['duration'] = data_duration[n] + metadata.append(meta) + + # Save manifest + manifest_filepath = os.path.join(test_dir, 'manifest.json') + write_manifest(manifest_filepath, metadata) + + # Test 1 + # - No constraints on channels or duration + dataset = AudioToTargetDataset( + manifest_filepath=manifest_filepath, + input_key=data_key['input_signal'], + target_key=data_key['target_signal'], + sample_rate=sample_rate, + ) + + config = { + 'manifest_filepath': manifest_filepath, + 'input_key': data_key['input_signal'], + 'target_key': data_key['target_signal'], + 'sample_rate': sample_rate, + } + dataset_factory = audio_to_audio_dataset.get_audio_to_target_dataset(config) + + # Prepare lhotse manifest + cuts_path = manifest_filepath.replace('.json', '_cuts.jsonl') + convert_manifest_nemo_to_lhotse( + input_manifest=manifest_filepath, + output_manifest=cuts_path, + input_key=data_key['input_signal'], + target_key=data_key['target_signal'], + ) + + # Prepare lhotse dataset + config_lhotse = { + 'cuts_path': cuts_path, + 'use_lhotse': True, + 'sample_rate': sample_rate, + 'batch_size': 1, + } + dl_lhotse = get_lhotse_dataloader_from_config( + OmegaConf.create(config_lhotse), global_rank=0, world_size=1, dataset=LhotseAudioToTargetDataset() + ) + dataset_lhotse = [item for item in dl_lhotse] + + for n in range(num_examples): + for use_lhotse in [False, True]: + item = dataset_lhotse[n] if use_lhotse else dataset.__getitem__(n) + item_factory = dataset_factory.__getitem__(n) + for signal in data: + item_signal = item[signal].squeeze(0) if use_lhotse else item[signal] + golden_signal = data[signal][n] + assert ( + item_signal.shape == golden_signal.shape + ), f'Test 1, use_lhotse={use_lhotse}: Signal {signal} item shape {item_signal.shape} not matching reference shape {golden_signal.shape}' + assert np.allclose( + item_signal, golden_signal, atol=atol + ), f'Test 1, use_lhotse={use_lhotse}: Failed for example {n}, signal {signal} (random seed {random_seed})' + + assert np.allclose( + item_factory[signal], golden_signal, atol=atol + ), f'Test 1, use_lhotse={use_lhotse}: Failed for factory example {n}, signal {signal} (random seed {random_seed})' + + # Test 2 + # Set target as the first channel of input_filepath and all files listed in target_filepath. + # In this case, the target will have 3 channels. + # Note: this is currently not supported by lhotse, so we only test the default dataset here. + dataset = AudioToTargetDataset( + manifest_filepath=manifest_filepath, + input_key=data_key['input_signal'], + target_key=[data_key['input_signal'], data_key['target_signal']], + target_channel_selector=0, + sample_rate=sample_rate, + ) + + for n in range(num_examples): + item = dataset.__getitem__(n) + + for signal in data: + item_signal = item[signal].cpu().detach().numpy() + golden_signal = data[signal][n] + if signal == 'target_signal': + # add the first channel of the input + golden_signal = np.concatenate([data['input_signal'][n][0:1, ...], golden_signal], axis=0) + assert ( + item_signal.shape == golden_signal.shape + ), f'Signal {signal}: item shape {item_signal.shape} not matching reference shape {golden_signal.shape}' + assert np.allclose( + item_signal, golden_signal, atol=atol + ), f'Test 2: Failed for example {n}, signal {signal} (random seed {random_seed})' + + @pytest.mark.unit + def test_audio_to_target_dataset_for_inference(self): + """Test AudioWithTargetDataset when target_key is + not set, i.e., it is `None`. This is the case, e.g., when + running inference, and a target is not available. + + In this use case, each line of the manifest file has the following format: + ``` + { + 'input_filepath': 'path/to/input.wav', + 'duration': duration_of_input, + } + ``` + """ + # Data setup + random_seed = 42 + sample_rate = 16000 + num_examples = 25 + data_num_channels = { + 'input_signal': 4, + } + data_min_duration = 2.0 + data_max_duration = 8.0 + data_key = { + 'input_signal': 'input_filepath', + } + + # Tolerance + atol = 1e-6 + + # Generate random signals + _rng = np.random.default_rng(seed=random_seed) + + # Input and target signals have the same duration + data_duration = np.round(_rng.uniform(low=data_min_duration, high=data_max_duration, size=num_examples), 3) + data_duration_samples = np.floor(data_duration * sample_rate).astype(int) + + data = dict() + for signal, num_channels in data_num_channels.items(): + data[signal] = [] + for n in range(num_examples): + if num_channels == 1: + random_signal = _rng.uniform(low=-0.5, high=0.5, size=(data_duration_samples[n])) + else: + random_signal = _rng.uniform(low=-0.5, high=0.5, size=(num_channels, data_duration_samples[n])) + data[signal].append(random_signal) + + with tempfile.TemporaryDirectory() as test_dir: + # Build metadata for manifest + metadata = [] + for n in range(num_examples): + meta = dict() + for signal in data: + # filenames + signal_filename = f'{signal}_{n:02d}.wav' + # write audio files + sf.write(os.path.join(test_dir, signal_filename), data[signal][n].T, sample_rate, 'float') + # update metadata + meta[data_key[signal]] = signal_filename + meta['duration'] = data_duration[n] + metadata.append(meta) + + # Save manifest + manifest_filepath = os.path.join(test_dir, 'manifest.json') + write_manifest(manifest_filepath, metadata) + + # Test 1 + # - No constraints on channels or duration + dataset = AudioToTargetDataset( + manifest_filepath=manifest_filepath, + input_key=data_key['input_signal'], + target_key=None, # target_signal will be empty + sample_rate=sample_rate, + ) + + # Also test the corresponding factory + config = { + 'manifest_filepath': manifest_filepath, + 'input_key': data_key['input_signal'], + 'target_key': None, + 'sample_rate': sample_rate, + } + dataset_factory = audio_to_audio_dataset.get_audio_to_target_dataset(config) + + # Prepare lhotse manifest + cuts_path = manifest_filepath.replace('.json', '_cuts.jsonl') + convert_manifest_nemo_to_lhotse( + input_manifest=manifest_filepath, + output_manifest=cuts_path, + input_key=data_key['input_signal'], + target_key=None, + ) + + # Prepare lhotse dataset + config_lhotse = { + 'cuts_path': cuts_path, + 'use_lhotse': True, + 'sample_rate': sample_rate, + 'batch_size': 1, + } + dl_lhotse = get_lhotse_dataloader_from_config( + OmegaConf.create(config_lhotse), global_rank=0, world_size=1, dataset=LhotseAudioToTargetDataset() + ) + dataset_lhotse = [item for item in dl_lhotse] + + for n in range(num_examples): + + for label in ['original', 'factory', 'lhotse']: + + if label == 'original': + item = dataset.__getitem__(n) + elif label == 'factory': + item = dataset_factory.__getitem__(n) + elif label == 'lhotse': + item = dataset_lhotse[n] + else: + raise ValueError(f'Unknown label {label}') + + # Check target is None + if 'target_signal' in item: + assert item['target_signal'].numel() == 0, f'{label}: target_signal is expected to be empty.' + + # Check valid signals + for signal in data: + + item_signal = item[signal].squeeze(0) if label == 'lhotse' else item[signal] + golden_signal = data[signal][n] + assert ( + item_signal.shape == golden_signal.shape + ), f'{label} -- Signal {signal}: item shape {item_signal.shape} not matching reference shape {golden_signal.shape}' + assert np.allclose( + item_signal, golden_signal, atol=atol + ), f'{label} -- Test 1: Failed for example {n}, signal {signal} (random seed {random_seed})' + + @pytest.mark.unit + def test_audio_to_target_with_reference_dataset(self): + """Test AudioWithTargetWithReferenceDataset in different configurations. + + 1) reference synchronized with input and target + 2) reference not synchronized + + In this use case, each line of the manifest file has the following format: + ``` + { + 'input_filepath': 'path/to/input.wav', + 'target_filepath': 'path/to/path_to_target.wav', + 'reference_filepath': 'path/to/path_to_reference.wav', + 'duration': duration_of_input, + } + ``` + """ + # Data setup + random_seed = 42 + sample_rate = 16000 + num_examples = 25 + data_num_channels = { + 'input_signal': 4, + 'target_signal': 2, + 'reference_signal': 1, + } + data_min_duration = 2.0 + data_max_duration = 8.0 + data_key = { + 'input_signal': 'input_filepath', + 'target_signal': 'target_filepath', + 'reference_signal': 'reference_filepath', + } + + # Tolerance + atol = 1e-6 + + # Generate random signals + _rng = np.random.default_rng(seed=random_seed) + + # Input and target signals have the same duration + data_duration = np.round(_rng.uniform(low=data_min_duration, high=data_max_duration, size=num_examples), 3) + data_duration_samples = np.floor(data_duration * sample_rate).astype(int) + + data = dict() + for signal, num_channels in data_num_channels.items(): + data[signal] = [] + for n in range(num_examples): + if num_channels == 1: + random_signal = _rng.uniform(low=-0.5, high=0.5, size=(data_duration_samples[n])) + else: + random_signal = _rng.uniform(low=-0.5, high=0.5, size=(num_channels, data_duration_samples[n])) + data[signal].append(random_signal) + + with tempfile.TemporaryDirectory() as test_dir: + + # Build metadata for manifest + metadata = [] + + for n in range(num_examples): + + meta = dict() + + for signal in data: + # filenames + signal_filename = f'{signal}_{n:02d}.wav' + + # write audio files + sf.write(os.path.join(test_dir, signal_filename), data[signal][n].T, sample_rate, 'float') + + # update metadata + meta[data_key[signal]] = signal_filename + + meta['duration'] = data_duration[n] + metadata.append(meta) + + # Save manifest + manifest_filepath = os.path.join(test_dir, 'manifest.json') + write_manifest(manifest_filepath, metadata) + + # Test 1 + # - No constraints on channels or duration + # - Reference is not synchronized with input and target, so whole reference signal will be loaded + dataset = AudioToTargetWithReferenceDataset( + manifest_filepath=manifest_filepath, + input_key=data_key['input_signal'], + target_key=data_key['target_signal'], + reference_key=data_key['reference_signal'], + reference_is_synchronized=False, + sample_rate=sample_rate, + ) + + # Also test the corresponding factory + config = { + 'manifest_filepath': manifest_filepath, + 'input_key': data_key['input_signal'], + 'target_key': data_key['target_signal'], + 'reference_key': data_key['reference_signal'], + 'reference_is_synchronized': False, + 'sample_rate': sample_rate, + } + dataset_factory = audio_to_audio_dataset.get_audio_to_target_with_reference_dataset(config) + + for n in range(num_examples): + item = dataset.__getitem__(n) + item_factory = dataset_factory.__getitem__(n) + + for signal in data: + item_signal = item[signal].cpu().detach().numpy() + golden_signal = data[signal][n] + assert ( + item_signal.shape == golden_signal.shape + ), f'Signal {signal}: item shape {item_signal.shape} not matching reference shape {golden_signal.shape}' + assert np.allclose( + item_signal, golden_signal, atol=atol + ), f'Test 1: Failed for example {n}, signal {signal} (random seed {random_seed})' + + item_factory_signal = item_factory[signal].cpu().detach().numpy() + assert np.allclose( + item_factory_signal, golden_signal, atol=atol + ), f'Test 1: Failed for factory example {n}, signal {signal} (random seed {random_seed})' + + # Test 2 + # - Use fixed duration (random segment selection) + # - Reference is synchronized with input and target, so the same segment of reference signal will be loaded + audio_duration = 4.0 + audio_duration_samples = int(np.floor(audio_duration * sample_rate)) + dataset = AudioToTargetWithReferenceDataset( + manifest_filepath=manifest_filepath, + input_key=data_key['input_signal'], + target_key=data_key['target_signal'], + reference_key=data_key['reference_signal'], + reference_is_synchronized=True, + sample_rate=sample_rate, + min_duration=audio_duration, + audio_duration=audio_duration, + random_offset=True, + ) + + filtered_examples = [n for n, val in enumerate(data_duration) if val >= audio_duration] + + for n in range(len(dataset)): + item = dataset.__getitem__(n) + + golden_start = golden_end = None + for signal in data: + item_signal = item[signal].cpu().detach().numpy() + full_golden_signal = data[signal][filtered_examples[n]] + + # Find random segment using correlation on the first channel + # of the first signal, and then use it fixed for other signals + if golden_start is None: + golden_start = get_segment_start(signal=full_golden_signal[0, :], segment=item_signal[0, :]) + golden_end = golden_start + audio_duration_samples + golden_signal = full_golden_signal[..., golden_start:golden_end] + + # Test length is correct + assert ( + item_signal.shape[-1] == audio_duration_samples + ), f'Test 2: Signal {signal} length ({item_signal.shape[-1]}) not matching the expected length ({audio_duration_samples})' + + # Test signal values + assert np.allclose( + item_signal, golden_signal, atol=atol + ), f'Test 2: Failed for example {n}, signal {signal} (random seed {random_seed})' + + # Test 3 + # - Use fixed duration (random segment selection) + # - Reference is not synchronized with input and target, so whole reference signal will be loaded + audio_duration = 4.0 + audio_duration_samples = int(np.floor(audio_duration * sample_rate)) + dataset = AudioToTargetWithReferenceDataset( + manifest_filepath=manifest_filepath, + input_key=data_key['input_signal'], + target_key=data_key['target_signal'], + reference_key=data_key['reference_signal'], + reference_is_synchronized=False, + sample_rate=sample_rate, + min_duration=audio_duration, + audio_duration=audio_duration, + random_offset=True, + ) + + filtered_examples = [n for n, val in enumerate(data_duration) if val >= audio_duration] + + for n in range(len(dataset)): + item = dataset.__getitem__(n) + + golden_start = golden_end = None + for signal in data: + item_signal = item[signal].cpu().detach().numpy() + full_golden_signal = data[signal][filtered_examples[n]] + + if signal == 'reference_signal': + # Complete signal is loaded for reference + golden_signal = full_golden_signal + else: + # Find random segment using correlation on the first channel + # of the first signal, and then use it fixed for other signals + if golden_start is None: + golden_start = get_segment_start( + signal=full_golden_signal[0, :], segment=item_signal[0, :] + ) + golden_end = golden_start + audio_duration_samples + golden_signal = full_golden_signal[..., golden_start:golden_end] + + # Test length is correct + assert ( + item_signal.shape[-1] == audio_duration_samples + ), f'Test 3: Signal {signal} length ({item_signal.shape[-1]}) not matching the expected length ({audio_duration_samples})' + assert ( + item_signal.shape == golden_signal.shape + ), f'Signal {signal}: item shape {item_signal.shape} not matching reference shape {golden_signal.shape}' + # Test signal values + assert np.allclose( + item_signal, golden_signal, atol=atol + ), f'Test 3: Failed for example {n}, signal {signal} (random seed {random_seed})' + + # Test 4: + # - Test collate_fn + batch_size = 16 + batch = [dataset.__getitem__(n) for n in range(batch_size)] + _ = dataset.collate_fn(batch) + + @pytest.mark.unit + def test_audio_to_target_with_embedding_dataset(self): + """Test AudioWithTargetWithEmbeddingDataset. + + In this use case, each line of the manifest file has the following format: + ``` + { + 'input_filepath': 'path/to/input.wav', + 'target_filepath': 'path/to/path_to_target.wav', + 'embedding_filepath': 'path/to/path_to_embedding.npy', + 'duration': duration_of_input, + } + ``` + """ + # Data setup + random_seed = 42 + sample_rate = 16000 + num_examples = 25 + data_num_channels = { + 'input_signal': 4, + 'target_signal': 2, + 'embedding_vector': 1, + } + data_min_duration = 2.0 + data_max_duration = 8.0 + embedding_length = 64 # 64-dimensional embedding vector + data_key = { + 'input_signal': 'input_filepath', + 'target_signal': 'target_filepath', + 'embedding_vector': 'embedding_filepath', + } + + # Tolerance + atol = 1e-6 + + # Generate random signals + _rng = np.random.default_rng(seed=random_seed) + + # Input and target signals have the same duration + data_duration = np.round(_rng.uniform(low=data_min_duration, high=data_max_duration, size=num_examples), 3) + data_duration_samples = np.floor(data_duration * sample_rate).astype(int) + + data = dict() + for signal, num_channels in data_num_channels.items(): + data[signal] = [] + for n in range(num_examples): + data_length = embedding_length if signal == 'embedding_vector' else data_duration_samples[n] + + if num_channels == 1: + random_signal = _rng.uniform(low=-0.5, high=0.5, size=(data_length)) + else: + random_signal = _rng.uniform(low=-0.5, high=0.5, size=(num_channels, data_length)) + data[signal].append(random_signal) + + with tempfile.TemporaryDirectory() as test_dir: + + # Build metadata for manifest + metadata = [] + + for n in range(num_examples): + + meta = dict() + + for signal in data: + if signal == 'embedding_vector': + signal_filename = f'{signal}_{n:02d}.npy' + np.save(os.path.join(test_dir, signal_filename), data[signal][n]) + + else: + # filenames + signal_filename = f'{signal}_{n:02d}.wav' + + # write audio files + sf.write(os.path.join(test_dir, signal_filename), data[signal][n].T, sample_rate, 'float') + + # update metadata + meta[data_key[signal]] = signal_filename + + meta['duration'] = data_duration[n] + metadata.append(meta) + + # Save manifest + manifest_filepath = os.path.join(test_dir, 'manifest.json') + write_manifest(manifest_filepath, metadata) + + # Test 1 + # - No constraints on channels or duration + dataset = AudioToTargetWithEmbeddingDataset( + manifest_filepath=manifest_filepath, + input_key=data_key['input_signal'], + target_key=data_key['target_signal'], + embedding_key=data_key['embedding_vector'], + sample_rate=sample_rate, + ) + + # Also test the corresponding factory + config = { + 'manifest_filepath': manifest_filepath, + 'input_key': data_key['input_signal'], + 'target_key': data_key['target_signal'], + 'embedding_key': data_key['embedding_vector'], + 'sample_rate': sample_rate, + } + dataset_factory = audio_to_audio_dataset.get_audio_to_target_with_embedding_dataset(config) + + for n in range(num_examples): + item = dataset.__getitem__(n) + item_factory = dataset_factory.__getitem__(n) + + for signal in data: + item_signal = item[signal].cpu().detach().numpy() + golden_signal = data[signal][n] + assert ( + item_signal.shape == golden_signal.shape + ), f'Signal {signal}: item shape {item_signal.shape} not matching reference shape {golden_signal.shape}' + assert np.allclose( + item_signal, golden_signal, atol=atol + ), f'Test 1: Failed for example {n}, signal {signal} (random seed {random_seed})' + + item_factory_signal = item_factory[signal].cpu().detach().numpy() + assert np.allclose( + item_factory_signal, golden_signal, atol=atol + ), f'Test 1: Failed for factory example {n}, signal {signal} (random seed {random_seed})' + + # Test 2: + # - Test collate_fn + batch_size = 16 + batch = [dataset.__getitem__(n) for n in range(batch_size)] + _ = dataset.collate_fn(batch) diff --git a/tests/collections/asr/test_asr_losses.py b/tests/collections/audio/test_audio_losses.py similarity index 95% rename from tests/collections/asr/test_asr_losses.py rename to tests/collections/audio/test_audio_losses.py index e050e7cc07c3d..8c8dbdb475983 100644 --- a/tests/collections/asr/test_asr_losses.py +++ b/tests/collections/audio/test_audio_losses.py @@ -16,7 +16,7 @@ import pytest import torch -from nemo.collections.asr.losses.audio_losses import ( +from nemo.collections.audio.losses.audio import ( MSELoss, SDRLoss, calculate_mse_batch, @@ -24,7 +24,7 @@ convolution_invariant_target, scale_invariant_target, ) -from nemo.collections.asr.parts.utils.audio_utils import ( +from nemo.collections.audio.parts.utils.audio import ( calculate_sdr_numpy, convolution_invariant_target_numpy, scale_invariant_target_numpy, @@ -35,8 +35,7 @@ class TestAudioLosses: @pytest.mark.unit @pytest.mark.parametrize('num_channels', [1, 4]) def test_sdr(self, num_channels: int): - """Test SDR calculation - """ + """Test SDR calculation""" test_eps = [0, 1e-16, 1e-1] batch_size = 8 num_samples = 50 @@ -73,12 +72,18 @@ def test_sdr(self, num_channels: int): for b in range(batch_size): for m in range(num_channels): golden_sdr[b, m] = calculate_sdr_numpy( - estimate=estimate[b, m, :], target=target[b, m, :], remove_mean=remove_mean, eps=eps, + estimate=estimate[b, m, :], + target=target[b, m, :], + remove_mean=remove_mean, + eps=eps, ) # Calculate SDR in torch uut_sdr = calculate_sdr_batch( - estimate=tensor_estimate, target=tensor_target, remove_mean=remove_mean, eps=eps, + estimate=tensor_estimate, + target=tensor_target, + remove_mean=remove_mean, + eps=eps, ) # Calculate SDR loss @@ -97,8 +102,7 @@ def test_sdr(self, num_channels: int): @pytest.mark.unit @pytest.mark.parametrize('num_channels', [1, 4]) def test_sdr_weighted(self, num_channels: int): - """Test SDR calculation with weighting for channels - """ + """Test SDR calculation with weighting for channels""" batch_size = 8 num_samples = 50 num_batches = 10 @@ -147,8 +151,7 @@ def test_sdr_weighted(self, num_channels: int): @pytest.mark.unit @pytest.mark.parametrize('num_channels', [1, 4]) def test_sdr_input_length(self, num_channels): - """Test SDR calculation with input length. - """ + """Test SDR calculation with input length.""" batch_size = 8 max_num_samples = 50 num_batches = 10 @@ -198,8 +201,7 @@ def test_sdr_input_length(self, num_channels): @pytest.mark.unit @pytest.mark.parametrize('num_channels', [1, 4]) def test_sdr_scale_invariant(self, num_channels: int): - """Test SDR calculation with scale invariant option. - """ + """Test SDR calculation with scale invariant option.""" batch_size = 8 max_num_samples = 50 num_batches = 10 @@ -251,8 +253,7 @@ def test_sdr_scale_invariant(self, num_channels: int): @pytest.mark.unit @pytest.mark.parametrize('num_channels', [1, 4]) def test_sdr_binary_mask(self, num_channels): - """Test SDR calculation with temporal mask. - """ + """Test SDR calculation with temporal mask.""" batch_size = 8 max_num_samples = 50 num_batches = 10 @@ -305,8 +306,7 @@ def test_sdr_binary_mask(self, num_channels): @pytest.mark.parametrize('num_channels', [1]) @pytest.mark.parametrize('sdr_max', [10, 0]) def test_sdr_max(self, num_channels: int, sdr_max: float): - """Test SDR calculation with soft max threshold. - """ + """Test SDR calculation with soft max threshold.""" batch_size = 8 max_num_samples = 50 num_batches = 10 @@ -357,8 +357,7 @@ def test_sdr_max(self, num_channels: int, sdr_max: float): @pytest.mark.parametrize('filter_length', [1, 32]) @pytest.mark.parametrize('num_channels', [1, 4]) def test_target_calculation(self, num_channels: int, filter_length: int): - """Test target calculation with scale and convolution invariance. - """ + """Test target calculation with scale and convolution invariance.""" batch_size = 8 max_num_samples = 50 num_batches = 10 @@ -422,8 +421,7 @@ def test_target_calculation(self, num_channels: int, filter_length: int): @pytest.mark.parametrize('filter_length', [1, 32]) @pytest.mark.parametrize('num_channels', [1, 4]) def test_sdr_convolution_invariant(self, num_channels: int, filter_length: int): - """Test SDR calculation with convolution invariant option. - """ + """Test SDR calculation with convolution invariant option.""" batch_size = 8 max_num_samples = 50 num_batches = 10 @@ -476,8 +474,7 @@ def test_sdr_convolution_invariant(self, num_channels: int, filter_length: int): @pytest.mark.parametrize('num_channels', [1, 4]) @pytest.mark.parametrize('ndim', [3, 4]) def test_mse(self, num_channels: int, ndim: int): - """Test SDR calculation - """ + """Test SDR calculation""" batch_size = 8 num_samples = 50 num_features = 123 @@ -539,8 +536,7 @@ def test_mse(self, num_channels: int, ndim: int): @pytest.mark.parametrize('num_channels', [1, 4]) @pytest.mark.parametrize('ndim', [3, 4]) def test_mse_weighted(self, num_channels: int, ndim: int): - """Test SDR calculation with weighting for channels - """ + """Test SDR calculation with weighting for channels""" batch_size = 8 num_samples = 50 num_features = 123 @@ -599,8 +595,7 @@ def test_mse_weighted(self, num_channels: int, ndim: int): @pytest.mark.parametrize('num_channels', [1, 4]) @pytest.mark.parametrize('ndim', [3, 4]) def test_mse_input_length(self, num_channels: int, ndim: int): - """Test SDR calculation with input length. - """ + """Test SDR calculation with input length.""" batch_size = 8 max_num_samples = 50 num_features = 123 diff --git a/tests/collections/audio/test_audio_metrics.py b/tests/collections/audio/test_audio_metrics.py new file mode 100644 index 0000000000000..2d693bc4ab209 --- /dev/null +++ b/tests/collections/audio/test_audio_metrics.py @@ -0,0 +1,142 @@ +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import pytest +import torch +from torchmetrics.audio.snr import SignalNoiseRatio + +from nemo.collections.audio.metrics.audio import AudioMetricWrapper + + +class TestAudioMetricWrapper: + def test_metric_full_batch(self): + """Test metric on batches where all examples have equal length.""" + ref_metric = SignalNoiseRatio() + wrapped_metric = AudioMetricWrapper(metric=SignalNoiseRatio()) + + num_resets = 5 + num_batches = 10 + batch_size = 8 + num_channels = 2 + num_samples = 200 + + batch_shape = (batch_size, num_channels, num_samples) + + for nr in range(num_resets): + for nb in range(num_batches): + target = torch.rand(*batch_shape) + preds = target + torch.rand(1) * torch.rand(*batch_shape) + + # test forward for a single batch + batch_value_wrapped = wrapped_metric(preds=preds, target=target) + batch_value_ref = ref_metric(preds=preds, target=target) + + assert torch.allclose( + batch_value_wrapped, batch_value_ref + ), f'Metric forward not matching for batch {nb}, reset {nr}' + + # test compute (over num_batches) + assert torch.allclose( + wrapped_metric.compute(), ref_metric.compute() + ), f'Metric compute not matching for batch {nb}, reset {nr}' + + ref_metric.reset() + wrapped_metric.reset() + + def test_input_length(self): + """Test metric on batches where examples have different length.""" + ref_metric = SignalNoiseRatio() + wrapped_metric = AudioMetricWrapper(metric=SignalNoiseRatio()) + + num_resets = 5 + num_batches = 10 + batch_size = 8 + num_channels = 2 + num_samples = 200 + + batch_shape = (batch_size, num_channels, num_samples) + + for nr in range(num_resets): + for nb in range(num_batches): + target = torch.rand(*batch_shape) + preds = target + torch.rand(1) * torch.rand(*batch_shape) + + input_length = torch.randint(low=num_samples // 2, high=num_samples, size=(batch_size,)) + + # test forward for a single batch + batch_value_wrapped = wrapped_metric(preds=preds, target=target, input_length=input_length) + + # compute reference value, assuming batch reduction using averaging + batch_value_ref = 0 + for b_idx, b_len in enumerate(input_length): + batch_value_ref += ref_metric(preds=preds[b_idx, ..., :b_len], target=target[b_idx, ..., :b_len]) + batch_value_ref /= batch_size # average + + assert torch.allclose( + batch_value_wrapped, batch_value_ref + ), f'Metric forward not matching for batch {nb}, reset {nr}' + + # test compute (over num_batches) + assert torch.allclose( + wrapped_metric.compute(), ref_metric.compute() + ), f'Metric compute not matching for batch {nb}, reset {nr}' + + ref_metric.reset() + wrapped_metric.reset() + + @pytest.mark.unit + @pytest.mark.parametrize('channel', [0, 1]) + def test_channel(self, channel): + """Test metric on a single channel from a batch.""" + ref_metric = SignalNoiseRatio() + # select only a single channel + wrapped_metric = AudioMetricWrapper(metric=SignalNoiseRatio(), channel=channel) + + num_resets = 5 + num_batches = 10 + batch_size = 8 + num_channels = 2 + num_samples = 200 + + batch_shape = (batch_size, num_channels, num_samples) + + for nr in range(num_resets): + for nb in range(num_batches): + target = torch.rand(*batch_shape) + preds = target + torch.rand(1) * torch.rand(*batch_shape) + + # varying length + input_length = torch.randint(low=num_samples // 2, high=num_samples, size=(batch_size,)) + + # test forward for a single batch + batch_value_wrapped = wrapped_metric(preds=preds, target=target, input_length=input_length) + + # compute reference value, assuming batch reduction using averaging + batch_value_ref = 0 + for b_idx, b_len in enumerate(input_length): + batch_value_ref += ref_metric( + preds=preds[b_idx, channel, :b_len], target=target[b_idx, channel, :b_len] + ) + batch_value_ref /= batch_size # average + + assert torch.allclose( + batch_value_wrapped, batch_value_ref + ), f'Metric forward not matching for batch {nb}, reset {nr}' + + # test compute (over num_batches) + assert torch.allclose( + wrapped_metric.compute(), ref_metric.compute() + ), f'Metric compute not matching for batch {nb}, reset {nr}' + + ref_metric.reset() + wrapped_metric.reset() diff --git a/tests/collections/asr/test_audio_modules.py b/tests/collections/audio/test_audio_modules.py similarity index 96% rename from tests/collections/asr/test_audio_modules.py rename to tests/collections/audio/test_audio_modules.py index d789e97c3348d..ff90044d0e5c4 100644 --- a/tests/collections/asr/test_audio_modules.py +++ b/tests/collections/audio/test_audio_modules.py @@ -19,16 +19,16 @@ import pytest import torch -from nemo.collections.asr.modules.audio_modules import ( +from nemo.collections.audio.modules.features import SpectrogramToMultichannelFeatures +from nemo.collections.audio.modules.masking import ( MaskBasedDereverbWPE, MaskEstimatorFlexChannels, MaskEstimatorGSS, MaskReferenceChannel, - SpectrogramToMultichannelFeatures, - WPEFilter, ) -from nemo.collections.asr.modules.audio_preprocessing import AudioToSpectrogram -from nemo.collections.asr.parts.utils.audio_utils import convmtx_mc_numpy +from nemo.collections.audio.modules.transforms import AudioToSpectrogram +from nemo.collections.audio.parts.submodules.multichannel import WPEFilter +from nemo.collections.audio.parts.utils.audio import convmtx_mc_numpy from nemo.utils import logging try: @@ -46,8 +46,7 @@ class TestSpectrogramToMultichannelFeatures: @pytest.mark.parametrize('num_channels', [1, 4]) @pytest.mark.parametrize('mag_reduction', [None, 'rms', 'abs_mean', 'mean_abs']) def test_magnitude(self, fft_length: int, num_channels: int, mag_reduction: Optional[str]): - """Test calculation of spatial features for multi-channel audio. - """ + """Test calculation of spatial features for multi-channel audio.""" atol = 1e-6 batch_size = 8 num_samples = fft_length * 50 @@ -60,7 +59,10 @@ def test_magnitude(self, fft_length: int, num_channels: int, mag_reduction: Opti audio2spec = AudioToSpectrogram(fft_length=fft_length, hop_length=hop_length) spec2feat = SpectrogramToMultichannelFeatures( - num_subbands=audio2spec.num_subbands, mag_reduction=mag_reduction, use_ipd=False, mag_normalization=None, + num_subbands=audio2spec.num_subbands, + mag_reduction=mag_reduction, + use_ipd=False, + mag_normalization=None, ) for n in range(num_examples): @@ -96,8 +98,7 @@ def test_magnitude(self, fft_length: int, num_channels: int, mag_reduction: Opti @pytest.mark.parametrize('fft_length', [256]) @pytest.mark.parametrize('num_channels', [1, 4]) def test_ipd(self, fft_length: int, num_channels: int): - """Test calculation of IPD spatial features for multi-channel audio. - """ + """Test calculation of IPD spatial features for multi-channel audio.""" atol = 1e-5 batch_size = 8 num_samples = fft_length * 50 @@ -147,8 +148,7 @@ class TestMaskBasedProcessor: @pytest.mark.parametrize('num_channels', [1, 4]) @pytest.mark.parametrize('num_masks', [1, 2]) def test_mask_reference_channel(self, fft_length: int, num_channels: int, num_masks: int): - """Test masking of the reference channel. - """ + """Test masking of the reference channel.""" if num_channels == 1: # Only one channel available ref_channels = [0] @@ -245,8 +245,7 @@ def test_wpe_convtensor(self, num_channels: int, filter_length: int, delay: int) @pytest.mark.parametrize('filter_length', [10]) @pytest.mark.parametrize('delay', [0, 5]) def test_wpe_filter(self, num_channels: int, filter_length: int, delay: int): - """Test estimation of correlation matrices, filter and filtering. - """ + """Test estimation of correlation matrices, filter and filtering.""" atol = 1e-6 random_seed = 42 num_examples = 10 @@ -323,8 +322,7 @@ def test_wpe_filter(self, num_channels: int, filter_length: int, delay: int): @pytest.mark.parametrize('filter_length', [5]) @pytest.mark.parametrize('delay', [0, 2]) def test_mask_based_dereverb_init(self, num_channels: int, filter_length: int, delay: int): - """Test that dereverb can be initialized and can process audio. - """ + """Test that dereverb can be initialized and can process audio.""" num_examples = 10 batch_size = 8 num_subbands = 15 @@ -361,8 +359,7 @@ class TestMaskEstimator: def test_flex_channels( self, channel_reduction_position: int, channel_reduction_type: str, channel_block_type: str ): - """Test initialization of the mask estimator and make sure it can process input tensor. - """ + """Test initialization of the mask estimator and make sure it can process input tensor.""" # Model parameters num_subbands_tests = [32, 65] num_outputs_tests = [1, 2] diff --git a/tests/collections/asr/test_asr_part_submodules_multichannel.py b/tests/collections/audio/test_audio_part_submodules_multichannel.py similarity index 95% rename from tests/collections/asr/test_asr_part_submodules_multichannel.py rename to tests/collections/audio/test_audio_part_submodules_multichannel.py index f53d140277319..9c3b23a58d52f 100644 --- a/tests/collections/asr/test_asr_part_submodules_multichannel.py +++ b/tests/collections/audio/test_audio_part_submodules_multichannel.py @@ -15,7 +15,7 @@ import pytest import torch -from nemo.collections.asr.parts.submodules.multichannel_modules import ( +from nemo.collections.audio.parts.submodules.multichannel import ( ChannelAttentionPool, ChannelAugment, ChannelAveragePool, @@ -52,8 +52,7 @@ class TestTAC: @pytest.mark.unit @pytest.mark.parametrize('num_channels', [1, 2, 6]) def test_average(self, num_channels): - """Test transform-average-concatenate. - """ + """Test transform-average-concatenate.""" num_examples = 10 batch_size = 4 in_features = 128 @@ -115,8 +114,7 @@ class TestChannelPool: @pytest.mark.unit @pytest.mark.parametrize('num_channels', [1, 2, 6]) def test_average(self, num_channels): - """Test average channel pooling. - """ + """Test average channel pooling.""" num_examples = 10 batch_size = 4 in_features = 128 @@ -136,8 +134,7 @@ def test_average(self, num_channels): @pytest.mark.unit @pytest.mark.parametrize('num_channels', [2, 6]) def test_attention(self, num_channels): - """Test attention for channel pooling. - """ + """Test attention for channel pooling.""" num_examples = 10 batch_size = 4 in_features = 128 diff --git a/tests/collections/asr/test_audio_preprocessing.py b/tests/collections/audio/test_audio_transforms.py similarity index 98% rename from tests/collections/asr/test_audio_preprocessing.py rename to tests/collections/audio/test_audio_transforms.py index 600b9fed44fa1..342bb16e5b140 100644 --- a/tests/collections/asr/test_audio_preprocessing.py +++ b/tests/collections/audio/test_audio_transforms.py @@ -18,7 +18,7 @@ import pytest import torch -from nemo.collections.asr.modules.audio_preprocessing import AudioToSpectrogram, SpectrogramToAudio +from nemo.collections.audio.modules.transforms import AudioToSpectrogram, SpectrogramToAudio try: importlib.import_module('torchaudio') @@ -160,8 +160,7 @@ def test_spec_to_audio(self, fft_length: int, num_channels: int): def test_audio_to_spectrogram_reconstruction( self, fft_length: int, num_channels: int, magnitude_power: float, scale: float ): - """Test analysis and synthesis transform result in a perfect reconstruction. - """ + """Test analysis and synthesis transform result in a perfect reconstruction.""" batch_size = 4 num_samples = fft_length * 50 num_examples = 25 diff --git a/tests/collections/audio/utils/test_audio_utils.py b/tests/collections/audio/utils/test_audio_utils.py new file mode 100644 index 0000000000000..b108465f87352 --- /dev/null +++ b/tests/collections/audio/utils/test_audio_utils.py @@ -0,0 +1,360 @@ +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os + +import librosa +import matplotlib.pyplot as plt +import numpy as np +import pytest +import scipy +import torch + +from nemo.collections.audio.parts.utils.audio import SOUND_VELOCITY as sound_velocity +from nemo.collections.audio.parts.utils.audio import ( + calculate_sdr_numpy, + convmtx_mc_numpy, + db2mag, + estimated_coherence, + generate_approximate_noise_field, + get_segment_start, + mag2db, + pow2db, + rms, + theoretical_coherence, + toeplitz, +) + + +class TestGenerateApproximateNoiseField: + @pytest.mark.unit + @pytest.mark.parametrize('num_mics', [5]) + @pytest.mark.parametrize('mic_spacing', [0.05]) + @pytest.mark.parametrize('fft_length', [512, 2048]) + @pytest.mark.parametrize('sample_rate', [8000, 16000]) + @pytest.mark.parametrize('field', ['spherical']) + def test_theoretical_coherence_matrix( + self, num_mics: int, mic_spacing: float, fft_length: int, sample_rate: float, field: str + ): + """Test calculation of a theoretical coherence matrix.""" + # test setup + max_diff_tol = 1e-9 + + # golden reference: spherical coherence + num_subbands = fft_length // 2 + 1 + angular_freq = 2 * np.pi * sample_rate * np.arange(0, num_subbands) / fft_length + golden_coherence = np.zeros((num_subbands, num_mics, num_mics)) + + for p in range(num_mics): + for q in range(num_mics): + if p == q: + golden_coherence[:, p, q] = 1.0 + else: + if field == 'spherical': + dist_pq = abs(p - q) * mic_spacing + sinc_arg = angular_freq * dist_pq / sound_velocity + golden_coherence[:, p, q] = np.sinc(sinc_arg / np.pi) + else: + raise NotImplementedError(f'Field {field} not supported.') + + # assume linear arrray + mic_positions = np.zeros((num_mics, 3)) + mic_positions[:, 0] = mic_spacing * np.arange(num_mics) + + # UUT + uut_coherence = theoretical_coherence( + mic_positions, sample_rate=sample_rate, fft_length=fft_length, field='spherical' + ) + + # Check difference + max_diff = np.max(np.abs(uut_coherence - golden_coherence)) + assert max_diff < max_diff_tol + + @pytest.mark.unit + @pytest.mark.parametrize('num_mics', [5]) + @pytest.mark.parametrize('mic_spacing', [0.10]) + @pytest.mark.parametrize('fft_length', [256, 512]) + @pytest.mark.parametrize('sample_rate', [8000, 16000]) + @pytest.mark.parametrize('field', ['spherical']) + def test_generate_approximate_noise_field( + self, + num_mics: int, + mic_spacing: float, + fft_length: int, + sample_rate: float, + field: str, + save_figures: bool = False, + ): + """Test approximate noise field with white noise as the input noise.""" + duration_in_sec = 20 + relative_mse_tol_dB = -30 + relative_mse_tol = 10 ** (relative_mse_tol_dB / 10) + + num_samples = sample_rate * duration_in_sec + noise_signal = np.random.rand(num_samples, num_mics) + # random channel-wise power scaling + noise_signal *= np.random.randn(num_mics) + + # assume linear arrray + mic_positions = np.zeros((num_mics, 3)) + mic_positions[:, 0] = mic_spacing * np.arange(num_mics) + + # UUT + noise_field = generate_approximate_noise_field( + mic_positions, noise_signal, sample_rate=sample_rate, field=field, fft_length=fft_length + ) + + # Compare the estimated coherence with the theoretical coherence + + # reference + golden_coherence = theoretical_coherence( + mic_positions, sample_rate=sample_rate, field=field, fft_length=fft_length + ) + + # estimated + N = librosa.stft(noise_field.transpose(), n_fft=fft_length) + # (channel, subband, frame) -> (subband, frame, channel) + N = N.transpose(1, 2, 0) + uut_coherence = estimated_coherence(N) + + # Check difference + relative_mse_real = np.mean((uut_coherence.real - golden_coherence) ** 2) + assert relative_mse_real < relative_mse_tol + relative_mse_imag = np.mean((uut_coherence.imag) ** 2) + assert relative_mse_imag < relative_mse_tol + + if save_figures: + # For debugging and visualization template + figure_dir = os.path.expanduser('~/_coherence') + if not os.path.exists(figure_dir): + os.mkdir(figure_dir) + + freq = librosa.fft_frequencies(sr=sample_rate, n_fft=fft_length) + freq = freq / 1e3 # kHz + + plt.figure(figsize=(7, 10)) + for n in range(1, num_mics): + plt.subplot(num_mics - 1, 2, 2 * n - 1) + plt.plot(freq, golden_coherence[:, 0, n].real, label='golden') + plt.plot(freq, uut_coherence[:, 0, n].real, label='estimated') + plt.title(f'Real(coherence), p=0, q={n}') + plt.xlabel('f / kHz') + plt.grid() + plt.legend(loc='upper right') + + plt.subplot(num_mics - 1, 2, 2 * n) + plt.plot(golden_coherence[:, 0, n].imag, label='golden') + plt.plot(uut_coherence[:, 0, n].imag, label='estimated') + plt.title(f'Imag(coherence), p=0, q={n}') + plt.xlabel('f / kHz') + plt.grid() + plt.legend(loc='upper right') + + plt.tight_layout() + plt.savefig( + os.path.join( + figure_dir, f'num_mics_{num_mics}_sample_rate_{sample_rate}_fft_length_{fft_length}_{field}.png' + ) + ) + plt.close() + + +class TestAudioUtilsElements: + @pytest.mark.unit + def test_rms(self): + """Test RMS calculation""" + # setup + A = np.random.rand() + omega = 100 + n_points = 1000 + rms_threshold = 1e-4 + # prep data + t = np.linspace(0, 2 * np.pi, n_points) + x = A * np.cos(2 * np.pi * omega * t) + # test + x_rms = rms(x) + golden_rms = A / np.sqrt(2) + assert ( + np.abs(x_rms - golden_rms) < rms_threshold + ), f'RMS not matching for A={A}, omega={omega}, n_point={n_points}' + + @pytest.mark.unit + def test_db_conversion(self): + """Test conversions to and from dB.""" + num_examples = 10 + abs_threshold = 1e-6 + + mag = np.random.rand(num_examples) + mag_db = mag2db(mag) + + assert all(np.abs(mag - 10 ** (mag_db / 20)) < abs_threshold) + assert all(np.abs(db2mag(mag_db) - 10 ** (mag_db / 20)) < abs_threshold) + assert all(np.abs(pow2db(mag**2) - mag_db) < abs_threshold) + + @pytest.mark.unit + def test_get_segment_start(self): + random_seed = 42 + num_examples = 50 + num_samples = 2000 + + _rng = np.random.default_rng(seed=random_seed) + + for n in range(num_examples): + # Generate signal + signal = _rng.normal(size=num_samples) + # Random start in the first half + start = _rng.integers(low=0, high=num_samples // 2) + # Random length + end = _rng.integers(low=start, high=num_samples) + # Selected segment + segment = signal[start:end] + + # UUT + estimated_start = get_segment_start(signal=signal, segment=segment) + + assert ( + estimated_start == start + ), f'Example {n}: estimated start ({estimated_start}) not matching the actual start ({start})' + + @pytest.mark.unit + def test_calculate_sdr_numpy(self): + atol = 1e-6 + random_seed = 42 + num_examples = 50 + num_samples = 2000 + + _rng = np.random.default_rng(seed=random_seed) + + for n in range(num_examples): + # Generate signal + target = _rng.normal(size=num_samples) + # Adjust the estimate + golden_sdr = _rng.integers(low=-10, high=10) + estimate = target * (1 + 10 ** (-golden_sdr / 20)) + + # UUT + estimated_sdr = calculate_sdr_numpy(estimate=estimate, target=target, remove_mean=False) + + assert np.isclose( + estimated_sdr, golden_sdr, atol=atol + ), f'Example {n}: estimated ({estimated_sdr}) not matching the actual value ({golden_sdr})' + + # Add random mean and use remove_mean=True + # SDR should not change + target += _rng.uniform(low=-10, high=10) + estimate += _rng.uniform(low=-10, high=10) + + # UUT + estimated_sdr = calculate_sdr_numpy(estimate=estimate, target=target, remove_mean=True) + + assert np.isclose( + estimated_sdr, golden_sdr, atol=atol + ), f'Example {n}: estimated ({estimated_sdr}) not matching the actual value ({golden_sdr})' + + @pytest.mark.unit + def test_calculate_sdr_numpy_scale_invariant(self): + atol = 1e-6 + random_seed = 42 + num_examples = 50 + num_samples = 2000 + + _rng = np.random.default_rng(seed=random_seed) + + for n in range(num_examples): + # Generate signal + target = _rng.normal(size=num_samples) + # Adjust the estimate + estimate = target + _rng.uniform(low=0.01, high=1) * _rng.normal(size=target.size) + + # scaled target + target_scaled = target / (np.linalg.norm(target) + 1e-16) + target_scaled = np.sum(estimate * target_scaled) * target_scaled + + golden_sdr = calculate_sdr_numpy( + estimate=estimate, target=target_scaled, scale_invariant=False, remove_mean=False + ) + + # UUT + estimated_sdr = calculate_sdr_numpy( + estimate=estimate, target=target, scale_invariant=True, remove_mean=False + ) + + print(golden_sdr, estimated_sdr) + + assert np.isclose( + estimated_sdr, golden_sdr, atol=atol + ), f'Example {n}: estimated ({estimated_sdr}) not matching the actual value ({golden_sdr})' + + @pytest.mark.unit + @pytest.mark.parametrize('num_channels', [1, 3]) + @pytest.mark.parametrize('filter_length', [10]) + @pytest.mark.parametrize('delay', [0, 5]) + def test_convmtx_mc(self, num_channels: int, filter_length: int, delay: int): + """Test convmtx against convolve and sum. + Multiplication of convmtx_mc of input with a vectorized multi-channel filter + should match the sum of convolution of each input channel with the corresponding + filter. + """ + atol = 1e-6 + random_seed = 42 + num_examples = 10 + num_samples = 2000 + + _rng = np.random.default_rng(seed=random_seed) + + for n in range(num_examples): + x = _rng.normal(size=(num_samples, num_channels)) + f = _rng.normal(size=(filter_length, num_channels)) + + CM = convmtx_mc_numpy(x=x, filter_length=filter_length, delay=delay) + + # Multiply convmtx_mc with the vectorized filter + uut = CM @ f.transpose().reshape(-1, 1) + uut = uut.squeeze(1) + + # Calculate reference as sum of convolutions + golden_ref = 0 + for m in range(num_channels): + x_m_delayed = np.hstack([np.zeros(delay), x[:, m]]) + golden_ref += np.convolve(x_m_delayed, f[:, m], mode='full')[: len(x)] + + assert np.allclose(uut, golden_ref, atol=atol), f'Example {n}: UUT not matching the reference.' + + @pytest.mark.unit + @pytest.mark.parametrize('num_channels', [1, 3]) + @pytest.mark.parametrize('filter_length', [10]) + @pytest.mark.parametrize('num_samples', [10, 100]) + def test_toeplitz(self, num_channels: int, filter_length: int, num_samples: int): + """Test construction of a Toeplitz matrix for a given signal.""" + atol = 1e-6 + random_seed = 42 + num_batches = 10 + batch_size = 8 + + _rng = np.random.default_rng(seed=random_seed) + + for n in range(num_batches): + x = _rng.normal(size=(batch_size, num_channels, num_samples)) + + # Construct Toeplitz matrix + Tx = toeplitz(x=torch.tensor(x)) + + # Compare against the reference + for b in range(batch_size): + for m in range(num_channels): + T_ref = scipy.linalg.toeplitz(x[b, m, ...]) + + assert np.allclose( + Tx[b, m, ...].cpu().numpy(), T_ref, atol=atol + ), f'Example {n}: not matching the reference for (b={b}, m={m}), .' diff --git a/tests/collections/common/test_lhotse_dataloading.py b/tests/collections/common/test_lhotse_dataloading.py index 111c00df392ac..31a8d332814e2 100644 --- a/tests/collections/common/test_lhotse_dataloading.py +++ b/tests/collections/common/test_lhotse_dataloading.py @@ -32,10 +32,6 @@ from nemo.collections.common.data.lhotse.text_adapters import TextExample from nemo.collections.common.tokenizers.sentencepiece_tokenizer import SentencePieceTokenizer, create_spt_model -requires_torchaudio = pytest.mark.skipif( - not lhotse.utils.is_torchaudio_available(), reason="Lhotse Shar format support requires torchaudio." -) - @pytest.fixture(scope="session") def cutset_path(tmp_path_factory) -> Path: @@ -348,7 +344,6 @@ def test_dataloader_from_lhotse_cuts_channel_selector(mc_cutset_path: Path): assert torch.equal(b_cs["audio"], batches[n]["audio"][:, channel_selector, :]) -@requires_torchaudio def test_dataloader_from_lhotse_shar_cuts(cutset_shar_path: Path): config = OmegaConf.create( { @@ -682,7 +677,6 @@ def test_dataloader_from_tarred_nemo_manifest_concat(nemo_tarred_manifest_path: torch.testing.assert_close(b["audio_lens"], expected_audio_lens) -@requires_torchaudio def test_dataloader_from_lhotse_shar_cuts_combine_datasets_unweighted( cutset_shar_path: Path, cutset_shar_path_other: Path ): @@ -723,19 +717,18 @@ def test_dataloader_from_lhotse_shar_cuts_combine_datasets_unweighted( assert len([cid for cid in b["ids"] if cid.startswith("other")]) == 2 # dataset 2 b = batches[1] - assert len([cid for cid in b["ids"] if cid.startswith("dummy")]) == 2 # dataset 1 - assert len([cid for cid in b["ids"] if cid.startswith("other")]) == 1 # dataset 2 + assert len([cid for cid in b["ids"] if cid.startswith("dummy")]) == 0 # dataset 1 + assert len([cid for cid in b["ids"] if cid.startswith("other")]) == 3 # dataset 2 b = batches[2] - assert len([cid for cid in b["ids"] if cid.startswith("dummy")]) == 1 # dataset 1 - assert len([cid for cid in b["ids"] if cid.startswith("other")]) == 2 # dataset 2 + assert len([cid for cid in b["ids"] if cid.startswith("dummy")]) == 2 # dataset 1 + assert len([cid for cid in b["ids"] if cid.startswith("other")]) == 1 # dataset 2 b = batches[3] assert len([cid for cid in b["ids"] if cid.startswith("dummy")]) == 1 # dataset 1 assert len([cid for cid in b["ids"] if cid.startswith("other")]) == 2 # dataset 2 -@requires_torchaudio def test_dataloader_from_lhotse_shar_cuts_combine_datasets_weighted( cutset_shar_path: Path, cutset_shar_path_other: Path ): @@ -776,12 +769,12 @@ def test_dataloader_from_lhotse_shar_cuts_combine_datasets_weighted( assert len([cid for cid in b["ids"] if cid.startswith("other")]) == 0 # dataset 2 b = batches[1] - assert len([cid for cid in b["ids"] if cid.startswith("dummy")]) == 3 # dataset 1 - assert len([cid for cid in b["ids"] if cid.startswith("other")]) == 0 # dataset 2 + assert len([cid for cid in b["ids"] if cid.startswith("dummy")]) == 1 # dataset 1 + assert len([cid for cid in b["ids"] if cid.startswith("other")]) == 2 # dataset 2 b = batches[2] - assert len([cid for cid in b["ids"] if cid.startswith("dummy")]) == 3 # dataset 1 - assert len([cid for cid in b["ids"] if cid.startswith("other")]) == 0 # dataset 2 + assert len([cid for cid in b["ids"] if cid.startswith("dummy")]) == 2 # dataset 1 + assert len([cid for cid in b["ids"] if cid.startswith("other")]) == 1 # dataset 2 b = batches[3] assert len([cid for cid in b["ids"] if cid.startswith("dummy")]) == 3 # dataset 1 @@ -792,8 +785,8 @@ def test_dataloader_from_lhotse_shar_cuts_combine_datasets_weighted( assert len([cid for cid in b["ids"] if cid.startswith("other")]) == 0 # dataset 2 b = batches[5] - assert len([cid for cid in b["ids"] if cid.startswith("dummy")]) == 1 # dataset 1 - assert len([cid for cid in b["ids"] if cid.startswith("other")]) == 2 # dataset 2 + assert len([cid for cid in b["ids"] if cid.startswith("dummy")]) == 3 # dataset 1 + assert len([cid for cid in b["ids"] if cid.startswith("other")]) == 0 # dataset 2 class TextDataset(torch.utils.data.Dataset): diff --git a/tests/collections/nlp/test_nlp_exportables.py b/tests/collections/nlp/test_nlp_exportables.py index c0b97caea4edd..b404764e7eed5 100644 --- a/tests/collections/nlp/test_nlp_exportables.py +++ b/tests/collections/nlp/test_nlp_exportables.py @@ -21,6 +21,12 @@ import wget from omegaconf import DictConfig, OmegaConf +# WAR for https://github.com/pytorch/pytorch/issues/125462 +# Has to be applied before first import of NeMo +from nemo.core.classes import typecheck + +typecheck.enable_wrapping(enabled=False) + from nemo.collections import nlp as nemo_nlp from nemo.collections.nlp.models import IntentSlotClassificationModel from nemo.collections.nlp.modules.common import ( @@ -35,22 +41,25 @@ def classifier_export(obj): with tempfile.TemporaryDirectory() as tmpdir: filename = os.path.join(tmpdir, obj.__class__.__name__ + '.onnx') obj = obj.cuda() - obj.export(output=filename) + obj.export(output=filename, use_dynamo=True, check_trace=True) class TestExportableClassifiers: + @pytest.mark.pleasefixme @pytest.mark.run_only_on('GPU') @pytest.mark.unit def test_token_classifier_export_to_onnx(self): for num_layers in [1, 2, 4]: classifier_export(TokenClassifier(hidden_size=256, num_layers=num_layers, num_classes=16)) + @pytest.mark.pleasefixme @pytest.mark.run_only_on('GPU') @pytest.mark.unit def test_bert_pretraining_export_to_onnx(self): for num_layers in [1, 2, 4]: classifier_export(TokenClassifier(hidden_size=256, num_layers=num_layers, num_classes=16)) + @pytest.mark.pleasefixme @pytest.mark.run_only_on('GPU') @pytest.mark.unit def test_sequence_token_classifier_export_to_onnx(self): @@ -59,12 +68,14 @@ def test_sequence_token_classifier_export_to_onnx(self): SequenceTokenClassifier(hidden_size=256, num_slots=8, num_intents=8, num_layers=num_layers) ) + @pytest.mark.pleasefixme @pytest.mark.run_only_on('GPU') @pytest.mark.unit def test_sequence_classifier_export_to_onnx(self): for num_layers in [1, 2, 4]: classifier_export(SequenceClassifier(hidden_size=256, num_classes=16, num_layers=num_layers)) + @pytest.mark.pleasefixme @pytest.mark.run_only_on('GPU') @pytest.mark.unit def test_sequence_regression_export_to_onnx(self): @@ -165,6 +176,7 @@ def setup_method(self): } ) + @pytest.mark.pleasefixme @pytest.mark.run_only_on('GPU') @pytest.mark.unit def test_IntentSlotClassificationModel_export_to_onnx(self, dummy_data): @@ -175,7 +187,8 @@ def test_IntentSlotClassificationModel_export_to_onnx(self, dummy_data): trainer = pl.Trainer(**config.trainer) model = IntentSlotClassificationModel(config.model, trainer=trainer) filename = os.path.join(tmpdir, 'isc.onnx') - model.export(output=filename, check_trace=True) + model.export(output=filename, check_trace=True, use_dynamo=False) + model.export(output=filename, check_trace=True, use_dynamo=True) onnx_model = onnx.load(filename) onnx.checker.check_model(onnx_model, full_check=True) # throws when failed assert onnx_model.graph.input[0].name == 'input_ids' @@ -184,6 +197,7 @@ def test_IntentSlotClassificationModel_export_to_onnx(self, dummy_data): assert onnx_model.graph.output[0].name == 'intent_logits' assert onnx_model.graph.output[1].name == 'slot_logits' + @pytest.mark.pleasefixme @pytest.mark.with_downloads() @pytest.mark.run_only_on('GPU') @pytest.mark.unit @@ -191,7 +205,8 @@ def test_TokenClassificationModel_export_to_onnx(self): model = nemo_nlp.models.TokenClassificationModel.from_pretrained(model_name="ner_en_bert") with tempfile.TemporaryDirectory() as tmpdir: filename = os.path.join(tmpdir, 'ner.onnx') - model.export(output=filename, check_trace=True) + model.export(output=filename, check_trace=True, use_dynamo=False) + model.export(output=filename, check_trace=True, use_dynamo=True) onnx_model = onnx.load(filename) onnx.checker.check_model(onnx_model, full_check=True) # throws when failed assert onnx_model.graph.input[0].name == 'input_ids' @@ -199,6 +214,7 @@ def test_TokenClassificationModel_export_to_onnx(self): assert onnx_model.graph.input[2].name == 'token_type_ids' assert onnx_model.graph.output[0].name == 'logits' + @pytest.mark.pleasefixme @pytest.mark.with_downloads() @pytest.mark.run_only_on('GPU') @pytest.mark.unit @@ -206,7 +222,9 @@ def test_PunctuationCapitalizationModel_export_to_onnx(self): model = nemo_nlp.models.PunctuationCapitalizationModel.from_pretrained(model_name="punctuation_en_distilbert") with tempfile.TemporaryDirectory() as tmpdir: filename = os.path.join(tmpdir, 'puncap.onnx') - model.export(output=filename, check_trace=True) + model.export(output=filename, check_trace=True, use_dynamo=False) + # Unsupported FX nodes: {'call_function': ['aten.detach_.default']}. + # model.export(output=filename, check_trace=True, use_dynamo=True) onnx_model = onnx.load(filename) onnx.checker.check_model(onnx_model, full_check=True) # throws when failed assert onnx_model.graph.input[0].name == 'input_ids' @@ -214,6 +232,7 @@ def test_PunctuationCapitalizationModel_export_to_onnx(self): assert onnx_model.graph.output[0].name == 'punct_logits' assert onnx_model.graph.output[1].name == 'capit_logits' + @pytest.mark.pleasefixme @pytest.mark.with_downloads() @pytest.mark.run_only_on('GPU') @pytest.mark.unit @@ -221,7 +240,8 @@ def test_QAModel_export_to_onnx(self): model = nemo_nlp.models.QAModel.from_pretrained(model_name="qa_squadv2.0_bertbase") with tempfile.TemporaryDirectory() as tmpdir: filename = os.path.join(tmpdir, 'qa.onnx') - model.export(output=filename, check_trace=True) + model.export(output=filename, check_trace=True, use_dynamo=False) + model.export(output=filename, check_trace=True, use_dynamo=True) onnx_model = onnx.load(filename) assert onnx_model.graph.input[0].name == 'input_ids' assert onnx_model.graph.input[1].name == 'attention_mask' diff --git a/tests/collections/tts/test_tts_exportables.py b/tests/collections/tts/test_tts_exportables.py index 67f016b0c2af3..4d7c852132840 100644 --- a/tests/collections/tts/test_tts_exportables.py +++ b/tests/collections/tts/test_tts_exportables.py @@ -26,7 +26,7 @@ def fastpitch_model(): model = FastPitchModel.from_pretrained(model_name="tts_en_fastpitch") model.export_config['enable_volume'] = True - model.export_config['enable_ragged_batches'] = True + # model.export_config['enable_ragged_batches'] = True return model @@ -59,14 +59,16 @@ def radtts_model(): class TestExportable: + @pytest.mark.pleasefixme @pytest.mark.run_only_on('GPU') @pytest.mark.unit def test_FastPitchModel_export_to_onnx(self, fastpitch_model): model = fastpitch_model.cuda() with tempfile.TemporaryDirectory() as tmpdir: filename = os.path.join(tmpdir, 'fp.onnx') - model.export(output=filename, verbose=True, onnx_opset_version=14, check_trace=True) + model.export(output=filename, verbose=True, onnx_opset_version=14, check_trace=True, use_dynamo=True) + @pytest.mark.pleasefixme @pytest.mark.with_downloads() @pytest.mark.run_only_on('GPU') @pytest.mark.unit @@ -75,7 +77,7 @@ def test_HifiGanModel_export_to_onnx(self, hifigan_model): assert hifigan_model.generator is not None with tempfile.TemporaryDirectory() as tmpdir: filename = os.path.join(tmpdir, 'hfg.onnx') - model.export(output=filename, verbose=True, check_trace=True) + model.export(output=filename, use_dynamo=True, verbose=True, check_trace=True) @pytest.mark.pleasefixme @pytest.mark.run_only_on('GPU') diff --git a/tests/core/mixins/adapters/test_adapter_model_mixin.py b/tests/core/mixins/adapters/test_adapter_model_mixin.py index 87c6b4e4cfb3d..20ced653ceb6e 100644 --- a/tests/core/mixins/adapters/test_adapter_model_mixin.py +++ b/tests/core/mixins/adapters/test_adapter_model_mixin.py @@ -14,12 +14,12 @@ import os import shutil import tempfile -from typing import Tuple +from typing import List, Optional, Tuple import pytest import torch from hydra.utils import instantiate -from omegaconf import DictConfig, OmegaConf +from omegaconf import DictConfig, OmegaConf, open_dict from nemo.core import ModelPT, NeuralModule from nemo.core.classes.mixins import adapter_mixin_strategies, adapter_mixins @@ -28,7 +28,7 @@ class DefaultModule(NeuralModule): - """ Define a default neural module (without adapter support)""" + """Define a default neural module (without adapter support)""" def __init__(self): super().__init__() @@ -51,7 +51,7 @@ def num_params(self): class DefaultModuleAdapter(DefaultModule, AdapterModuleMixin): - """ Subclass the DefaultModule, adding adapter module support""" + """Subclass the DefaultModule, adding adapter module support""" def forward(self, x): x = super(DefaultModuleAdapter, self).forward(x) @@ -66,7 +66,7 @@ def forward(self, x): class DefaultModelAdapterMixin(AdapterModelPTMixin): - """ Mixin class that implements this model's specific overrides to AdapterModelPTMixin + """Mixin class that implements this model's specific overrides to AdapterModelPTMixin It will container two modules, an encoder and a decoder, and both can have adapters. By default, encoder adapters are enabled, and decoder adapters are diabled. Decoder adapters can be enabled via the global_cfg in model.cfg.adapters. @@ -79,13 +79,13 @@ class DefaultModelAdapterMixin(AdapterModelPTMixin): def setup_adapters(self): supports_adapters = False - # Check the inheriting class' modules supports adapters or not - if hasattr(self, 'encoder') and isinstance(self.encoder, AdapterModuleMixin): - supports_adapters |= True - - if hasattr(self, 'decoder') and isinstance(self.decoder, AdapterModuleMixin): - supports_adapters |= True + # At least the encoder must extend AdapterModuleMixin + valid_adapter_names = [x for x in self.adapter_module_names if x != ''] + for module_name in valid_adapter_names: + if hasattr(self, module_name) and isinstance(getattr(self, module_name), AdapterModuleMixin): + supports_adapters |= True + # If adapters are supported, setup the adapter config + any modules (pre-existing adapter modules) if supports_adapters: super().setup_adapters() @@ -96,66 +96,98 @@ def add_adapter(self, name: str, cfg: DictConfig): # Resolve module name and adapter name module_name, adapter_name = self.resolve_adapter_module_name_(name) - # Try to retrieve global adapter config - global_config = self._get_global_cfg() - - # forward the method call to the individual modules - # If module name is empty, it is a global adapter, otherwise it is a local adapter - if (module_name == '' and global_config.get('encoder_adapter', True)) or (module_name == 'encoder'): - if hasattr(self, 'encoder'): - self.encoder.add_adapter(name, cfg) - - if (module_name == '' and global_config.get('decoder_adapter', False)) or (module_name == 'decoder'): - if hasattr(self, 'decoder'): - self.decoder.add_adapter(name, cfg) + # Use + as a splitter, in order to share one name across multiple modules + if '+' in module_name: + module_names = module_name.split('+') + else: + module_names = [module_name] + + valid_module_names = [x for x in self.adapter_module_names if x != ''] + default_module_name = self.default_adapter_module_name + + # Update the model.cfg with information about the new adapter from cfg + for module_name in module_names: + # Check if encoder adapters should be added + if module_name == '': + for default in default_module_name: # This model has multiple default modules + if hasattr(self, default): + # Dispatch the call to the default model. + getattr(self, default).add_adapter(name=name, cfg=cfg) + + elif module_name in valid_module_names: + # Check if module exists + if hasattr(self, module_name): + # Dispatch the call to the module. + getattr(self, module_name).add_adapter(name=name, cfg=cfg) def set_enabled_adapters(self, name=None, enabled: bool = True): # check if valid model with some adapter support super().set_enabled_adapters(name, enabled) - # Resolve module name and adapter name + # Resolve the module name and adapter name if name is not None: module_name, _ = self.resolve_adapter_module_name_(name) else: module_name = None - # Try to retrieve global adapter config - global_config = self._get_global_cfg() - - # Forward the method call to the individual modules - if name is None or global_config.get('encoder_adapter', True) or module_name in ('', 'encoder'): - if hasattr(self, 'encoder') and self.encoder.is_adapter_available(): - self.encoder.set_enabled_adapters(name, enabled) - - if name is None or global_config.get('decoder_adapter', False) or module_name == 'decoder': - if hasattr(self, 'decoder') and self.decoder.is_adapter_available(): - self.decoder.set_enabled_adapters(name, enabled) + # Use + as a splitter, in order to share one name across multiple modules + if module_name is not None and '+' in module_name: + module_names = module_name.split('+') + else: + module_names = [module_name] + + valid_module_names = [x for x in self.adapter_module_names if x != ''] + default_module_name = self.default_adapter_module_name + + # Check if default module name is None or not + if default_module_name is None: + raise ValueError( + f"Default module name is None. Class {self.__class__.__name__} must implement " + f"`default_adapter_module_name`" + ) + + # Forward the method call to the individual modules if they exist + for module_name in module_names: + # Check if encoder adapters should be used + + if module_name == '': + for default in default_module_name: + if hasattr(self, default) and isinstance(getattr(self, default), AdapterModuleMixin): + if getattr(self, default).is_adapter_available(): + # Dispatch the call to the default model. + getattr(self, default).set_enabled_adapters(name=name, enabled=enabled) + + elif module_name in valid_module_names: + if hasattr(self, module_name) and isinstance(getattr(self, module_name), AdapterModuleMixin): + if getattr(self, module_name).is_adapter_available(): + # Dispatch the call to the module. + getattr(self, module_name).set_enabled_adapters(name=name, enabled=enabled) def get_enabled_adapters(self) -> list: enabled_adapters = super().get_enabled_adapters() - # Forward the method call to the individual modules - if hasattr(self, 'encoder') and isinstance(self.encoder, AdapterModuleMixin): - encoder_adapters = self.encoder.get_enabled_adapters() - enabled_adapters.extend(encoder_adapters) + valid_module_names = [x for x in self.adapter_module_names if x != ''] - if hasattr(self, 'decoder') and isinstance(self.decoder, AdapterModuleMixin): - decoder_adapters = self.decoder.get_enabled_adapters() - enabled_adapters.extend(decoder_adapters) + # Check if encoder adapters should be used or are enabled + for module_name in valid_module_names: + if hasattr(self, module_name) and isinstance(getattr(self, module_name), AdapterModuleMixin): + enabled_adapters.extend(getattr(self, module_name).get_enabled_adapters()) + + enabled_adapters = list(sorted(list(set(enabled_adapters)))) return enabled_adapters def is_adapter_available(self) -> bool: adapters_available = super().is_adapter_available() - # Try to retrieve global adapter config - # Forward the method call to the individual modules - if hasattr(self, 'encoder') and isinstance(self.encoder, AdapterModuleMixin): - print("Encoder is adapter available", self.encoder.is_adapter_available()) - adapters_available |= self.encoder.is_adapter_available() + valid_module_names = [x for x in self.adapter_module_names if x != ''] - if hasattr(self, 'decoder') and isinstance(self.decoder, AdapterModuleMixin): - adapters_available |= self.decoder.is_adapter_available() + # Forward the method call to the individual modules + for module_name in valid_module_names: + print("Module name", module_name) + if hasattr(self, module_name) and isinstance(getattr(self, module_name), AdapterModuleMixin): + adapters_available |= getattr(self, module_name).is_adapter_available() + print("Adapter available for module", module_name, getattr(self, module_name).is_adapter_available()) return adapters_available @@ -198,6 +230,19 @@ def adapter_module_names(self) -> list: valid_adapter_modules = ['', 'encoder', 'decoder'] return valid_adapter_modules + @property + def default_adapter_module_name(self) -> Optional[List[str]]: + global_config = self._get_global_cfg() + default_modules = [] + encoder_adapter = global_config.get('encoder_adapter', True) + decoder_adapter = global_config.get('decoder_adapter', False) + + if encoder_adapter: + default_modules.append('encoder') + if decoder_adapter: + default_modules.append('decoder') + return default_modules + class DefaultAdapterModel(ModelPT, DefaultModelAdapterMixin): def __init__(self, cfg, trainer=None): @@ -302,6 +347,23 @@ def test_base_model_no_support_for_adapters(self, caplog): logging._logger.propagate = False logging.set_verbosity(original_verbosity) + @pytest.mark.unit + def test_base_model_replace_adapter_compatible_modules(self, caplog): + cfg = get_model_config(in_features=50, update_adapter_cfg=False) + model = DefaultAdapterModel(cfg) + + with pytest.raises(AttributeError): + model.add_adapter(name='adapter_0', cfg=get_adapter_cfg()) + + # Replace the modules of the model dynamically to support adapters + model.replace_adapter_compatible_modules() + + assert isinstance(model.encoder, AdapterModuleMixin) + assert model.encoder.is_adapter_available() is False + + model.add_adapter(name='encoder:adapter_0', cfg=get_adapter_cfg()) + assert model.encoder.is_adapter_available() is True + @pytest.mark.unit def test_single_adapter(self): cfg = get_model_config(in_features=50) @@ -934,8 +996,18 @@ def test_multiple_decoder_save_load_adapter_only_exact_name(self): assert (original_state_dict[ogkey] - restored_state_dict[newkey]).abs().mean() < 1e-6 @pytest.mark.unit - @pytest.mark.parametrize("decoder", ["adapter_0",]) # "decoder:adapter_0" - @pytest.mark.parametrize("encoder", ["adapter_1",]) # "encoder:adapter_1" + @pytest.mark.parametrize( + "decoder", + [ + "adapter_0", + ], + ) # "decoder:adapter_0" + @pytest.mark.parametrize( + "encoder", + [ + "adapter_1", + ], + ) # "encoder:adapter_1" def test_multiple_save_load_adapter_with_multiple_load(self, decoder, encoder): # create a model config, but do not add global_cfg to it # we want to test just module level adapter diff --git a/tests/core/test_fault_tolerance.py b/tests/core/test_fault_tolerance.py new file mode 100644 index 0000000000000..5b4e0ecba4aa3 --- /dev/null +++ b/tests/core/test_fault_tolerance.py @@ -0,0 +1,62 @@ +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os + +import pytest +import pytorch_lightning as pl + +from nemo.utils.exp_manager import exp_manager + +try: + from ptl_resiliency import FaultToleranceCallback + + HAVE_FT = True +except (ImportError, ModuleNotFoundError): + HAVE_FT = False + + +@pytest.mark.skipif(not HAVE_FT, reason="requires resiliency package to be installed.") +class TestFaultTolerance: + + @pytest.mark.unit + def test_fault_tol_callback_not_created_by_default(self): + """There should be no FT callback by default""" + test_conf = {"create_tensorboard_logger": False, "create_checkpoint_callback": False} + test_trainer = pl.Trainer(accelerator='cpu') + ft_callback_found = None + exp_manager(test_trainer, test_conf) + for cb in test_trainer.callbacks: + if isinstance(cb, FaultToleranceCallback): + ft_callback_found = cb + assert ft_callback_found is None + + @pytest.mark.unit + def test_fault_tol_callback_created(self): + """Verify that fault tolerance callback is created""" + try: + os.environ['FAULT_TOL_CFG_PATH'] = "/tmp/dummy" + test_conf = { + "create_tensorboard_logger": False, + "create_checkpoint_callback": False, + "create_fault_tolerance_callback": True, + } + test_trainer = pl.Trainer(accelerator='cpu') + ft_callback_found = None + exp_manager(test_trainer, test_conf) + for cb in test_trainer.callbacks: + if isinstance(cb, FaultToleranceCallback): + ft_callback_found = cb + assert ft_callback_found is not None + finally: + del os.environ['FAULT_TOL_CFG_PATH'] diff --git a/tests/core/test_straggler_det.py b/tests/core/test_straggler_det.py new file mode 100644 index 0000000000000..53ba37ac28bb6 --- /dev/null +++ b/tests/core/test_straggler_det.py @@ -0,0 +1,139 @@ +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import sys + +import pytest +import pytorch_lightning as pl +import torch +from omegaconf import OmegaConf + +from nemo.core.classes import ModelPT +from nemo.utils.exp_manager import exp_manager + +try: + # `ptl_resiliency` is included in `gwe_resiliency_pkg` package + from ptl_resiliency import StragglerDetectionCallback + + HAVE_STRAGGLER_DET = True +except (ImportError, ModuleNotFoundError): + HAVE_STRAGGLER_DET = False + + +class OnesDataset(torch.utils.data.Dataset): + def __init__(self, dataset_len): + super().__init__() + self.__dataset_len = dataset_len + + def __getitem__(self, *args): + return torch.ones(2) + + def __len__(self): + return self.__dataset_len + + +class ExampleModel(ModelPT): + def __init__(self, log_dir, **kwargs): + cfg = OmegaConf.structured({}) + super().__init__(cfg) + pl.seed_everything(1234) + self.l1 = torch.nn.modules.Linear(in_features=2, out_features=1) + self.log_dir = log_dir + + def on_train_start(self): + super().on_train_start() + rank = torch.distributed.get_rank() + + def train_dataloader(self): + dataset = OnesDataset(128) + return torch.utils.data.DataLoader(dataset, batch_size=2, num_workers=8) + + def val_dataloader(self): + dataset = OnesDataset(128) + return torch.utils.data.DataLoader(dataset, batch_size=2, num_workers=8) + + def forward(self, batch): + output = self.l1(batch) + output = torch.nn.functional.l1_loss(output, torch.zeros(output.size()).to(output.device)) + return output + + def validation_step(self, batch, batch_idx): + self.loss = self(batch) + return self.loss + + def training_step(self, batch, batch_idx): + return self(batch) + + def configure_optimizers(self): + return torch.optim.Adam(self.parameters(), lr=0.1) + + def list_available_models(self, *args, **kwargs): + pass + + def setup_training_data(self, *args, **kwargs): + pass + + def setup_validation_data(self, *args, **kwargs): + pass + + def on_validation_epoch_end(self): + self.log("val_loss", torch.stack([self.loss]).mean()) + + +@pytest.mark.skipif(not HAVE_STRAGGLER_DET, reason="requires resiliency package to be installed.") +class TestStragglerDetection: + + @pytest.mark.run_only_on('GPU') + def test_prints_perf_scores(self, tmp_path): + # Run dummy 1 rank DDP training + # Training time is limited to 3 seconds and straggler reporting is set to 1 second + # Check if there are straggler related logs in the captured log + max_steps = 1_000_000 + tmp_path = tmp_path / "test_1" + print("TMP PATH", tmp_path) + + trainer = pl.Trainer( + strategy='ddp', + devices=1, + accelerator='gpu', + enable_checkpointing=False, + logger=False, + max_steps=max_steps, + val_check_interval=0.33, + ) + exp_manager( + trainer, + { + "max_time_per_run": "00:00:00:03", + "explicit_log_dir": str(tmp_path), + "create_checkpoint_callback": False, + "create_straggler_detection_callback": True, + "straggler_detection_params": { + "report_time_interval": 1.0, + "calc_relative_gpu_perf": True, + "calc_individual_gpu_perf": True, + "num_gpu_perf_scores_to_log": 1, + }, + }, + ) + model = ExampleModel(log_dir=tmp_path) + trainer.fit(model) + + # assume that NeMo logs are written into "nemo_log_globalrank-0_localrank-0.txt" + rank0_log_content = None + with open(tmp_path / "nemo_log_globalrank-0_localrank-0.txt") as f: + rank0_log_content = f.read() + + assert "GPU relative performance" in rank0_log_content + assert "GPU individual performance" in rank0_log_content diff --git a/tests/deploy/nemo_deploy.py b/tests/deploy/nemo_deploy.py index f188b6e2bac8b..5193fe9511386 100644 --- a/tests/deploy/nemo_deploy.py +++ b/tests/deploy/nemo_deploy.py @@ -27,7 +27,7 @@ run_export_tests = True try: from nemo.deploy import DeployPyTriton - from nemo.deploy.nlp import NemoQueryLLM + from nemo.deploy.nlp import NemoQueryLLM, NemoQueryLLMPyTorch from nemo.export import TensorRTLLM except Exception as e: run_export_tests = False @@ -140,7 +140,7 @@ def run_in_framework_inference( ) nm.deploy() nm.run() - nq = NemoQueryLLM(url="localhost:8000", model_name=model_name) + nq = NemoQueryLLMPyTorch(url="localhost:8000", model_name=model_name) output_deployed = nq.query_llm( prompts=prompt, @@ -241,8 +241,8 @@ def run_trt_llm_inference( nemo_checkpoint_path=checkpoint_path, model_type=model_type, n_gpus=n_gpu, - tensor_parallel_size=tp_size, - pipeline_parallel_size=pp_size, + tensor_parallelism_size=tp_size, + pipeline_parallelism_size=pp_size, max_input_len=max_input_len, max_output_len=max_output_len, max_batch_size=max_batch_size, @@ -252,7 +252,6 @@ def run_trt_llm_inference( max_num_tokens=int(max_input_len * max_batch_size * 0.2), opt_num_tokens=60, use_embedding_sharing=use_embedding_sharing, - save_nemo_model_config=True, ) if ptuning: diff --git a/tests/export/nemo_export.py b/tests/export/nemo_export.py index 5541cc0f8673b..6a296fdb92eb7 100644 --- a/tests/export/nemo_export.py +++ b/tests/export/nemo_export.py @@ -14,129 +14,222 @@ import argparse import json +import logging import shutil +import sys import time +from dataclasses import dataclass from pathlib import Path +from typing import Dict, List, Optional, Tuple + import torch -from tests.infer_data_path import get_infer_test_data +# Import infer_data_path from the parent folder assuming that the 'tests' package is not installed. +sys.path.append(str(Path(__file__).parent.parent)) +from infer_data_path import get_infer_test_data + +LOGGER = logging.getLogger("NeMo") -run_export_tests = True +triton_supported = True try: from nemo.deploy import DeployPyTriton from nemo.deploy.nlp import NemoQueryLLM - from nemo.export import TensorRTLLM except Exception as e: - run_export_tests = False + LOGGER.warning(f"Cannot import Triton, deployment will not be available. {type(e).__name__}: {e}") + triton_supported = False + +in_framework_supported = True +try: + from nemo.deploy.nlp import MegatronLLMDeployable, NemoQueryLLMPyTorch +except Exception as e: + LOGGER.warning( + f"Cannot import MegatronLLMDeployable, in-framework inference will not be available. {type(e).__name__}: {e}" + ) + in_framework_supported = False +trt_llm_supported = True +try: + from nemo.export.tensorrt_llm import TensorRTLLM +except Exception as e: + LOGGER.warning(f"Cannot import the TensorRTLLM exporter, it will not be available. {type(e).__name__}: {e}") + trt_llm_supported = False + +vllm_supported = True +try: + from nemo.export.vllm_exporter import vLLMExporter +except Exception as e: + LOGGER.warning(f"Cannot import the vLLM exporter, it will not be available. {type(e).__name__}: {e}") + vllm_supported = False -def get_accuracy_with_lambada(model, nq, task_ids, lora_uids, test_data_path=None): + +class UsageError(Exception): + pass + + +@dataclass +class FunctionalResult: + regular_pass: Optional[bool] = None + deployed_pass: Optional[bool] = None + + +@dataclass +class AccuracyResult: + accuracy: float + accuracy_relaxed: float + deployed_accuracy: float + deployed_accuracy_relaxed: float + evaluation_time: float + + +def get_accuracy_with_lambada(model, nq, task_ids, lora_uids, test_data_path): # lambada dataset based accuracy test, which includes more than 5000 sentences. # Use generated last token with original text's last token for accuracy comparison. # If the generated last token start with the original token, trtllm_correct make an increment. # It generates a CSV file for text comparison detail. - if test_data_path is None: - raise Exception("test_data_path cannot be None.") - - trtllm_correct = 0 - trtllm_deployed_correct = 0 - trtllm_correct_relaxed = 0 - trtllm_deployed_correct_relaxed = 0 + correct_answers = 0 + correct_answers_deployed = 0 + correct_answers_relaxed = 0 + correct_answers_deployed_relaxed = 0 all_expected_outputs = [] - all_trtllm_outputs = [] + all_actual_outputs = [] with open(test_data_path, 'r') as file: records = json.load(file) - eval_start = time.perf_counter() + eval_start = time.monotonic() for record in records: prompt = record["text_before_last_word"] expected_output = record["last_word"].strip().lower() - trtllm_output = model.forward( - input_texts=[prompt], - max_output_len=1, - top_k=1, - top_p=0, - temperature=0.1, - task_ids=task_ids, - lora_uids=lora_uids, - ) - trtllm_output = trtllm_output[0][0].strip().lower() - all_expected_outputs.append(expected_output) - all_trtllm_outputs.append(trtllm_output) - - if expected_output == trtllm_output: - trtllm_correct += 1 + if model is not None: + if isinstance(model, MegatronLLMDeployable): + model_output = model.generate( + inputs=[prompt], + length_params={"min_length": 1, "max_length": 1}, + sampling_params={ + "use_greedy": True, + "temperature": 0.1, + "top_k": 1, + "top_p": 0, + "repetition_penalty": 1.0, + "add_BOS": True, + "all_probs": False, + "compute_logprob": False, + "end_strings": ["<|endoftext|>", ""], + }, + ) + # MegatronLLMDeployable returns prompt + generated output, so need to slice off prompt + model_output = model_output["sentences"][0][len(prompt) :].strip().lower() + else: + model_output = model.forward( + input_texts=[prompt], + max_output_len=1, + top_k=1, + top_p=0, + temperature=0.1, + task_ids=task_ids, + lora_uids=lora_uids, + ) + model_output = model_output[0][0].strip().lower() + all_actual_outputs.append(model_output) + + if expected_output == model_output: + correct_answers += 1 - if ( - expected_output == trtllm_output - or trtllm_output.startswith(expected_output) - or expected_output.startswith(trtllm_output) - ): - if len(trtllm_output) == 1 and len(expected_output) > 1: - continue - trtllm_correct_relaxed += 1 + if ( + expected_output == model_output + or model_output.startswith(expected_output) + or expected_output.startswith(model_output) + ): + if len(model_output) == 1 and len(expected_output) > 1: + continue + correct_answers_relaxed += 1 if nq is not None: - trtllm_deployed_output = nq.query_llm( - prompts=[prompt], - max_output_len=1, - top_k=1, - top_p=0, - temperature=0.1, - task_id=task_ids, - ) - trtllm_deployed_output = trtllm_deployed_output[0][0].strip().lower() - - if expected_output == trtllm_deployed_output: - trtllm_deployed_correct += 1 + if isinstance(nq, NemoQueryLLMPyTorch): + deployed_output = nq.query_llm( + prompts=[prompt], + max_length=1, + top_k=1, + top_p=0, + temperature=0.1, + ) + # MegatronLLMDeployable returns prompt + generated output, so need to slice off prompt + deployed_output = deployed_output["sentences"][0][0][len(prompt) :].decode().strip().lower() + else: + deployed_output = nq.query_llm( + prompts=[prompt], + max_output_len=1, + top_k=1, + top_p=0, + temperature=0.1, + task_id=task_ids, + ) + deployed_output = deployed_output[0][0].strip().lower() + + if expected_output == deployed_output: + correct_answers_deployed += 1 if ( - expected_output == trtllm_deployed_output - or trtllm_deployed_output.startswith(expected_output) - or expected_output.startswith(trtllm_deployed_output) + expected_output == deployed_output + or deployed_output.startswith(expected_output) + or expected_output.startswith(deployed_output) ): - if len(trtllm_deployed_output) == 1 and len(expected_output) > 1: + if len(deployed_output) == 1 and len(expected_output) > 1: continue - trtllm_deployed_correct_relaxed += 1 - eval_end = time.perf_counter() + correct_answers_deployed_relaxed += 1 + eval_end = time.monotonic() + + return AccuracyResult( + accuracy=correct_answers / len(all_expected_outputs), + accuracy_relaxed=correct_answers_relaxed / len(all_expected_outputs), + deployed_accuracy=correct_answers_deployed / len(all_expected_outputs), + deployed_accuracy_relaxed=correct_answers_deployed_relaxed / len(all_expected_outputs), + evaluation_time=eval_end - eval_start, + ) - trtllm_accuracy = trtllm_correct / len(all_expected_outputs) - trtllm_accuracy_relaxed = trtllm_correct_relaxed / len(all_expected_outputs) - trtllm_deployed_accuracy = trtllm_deployed_correct / len(all_expected_outputs) - trtllm_deployed_accuracy_relaxed = trtllm_deployed_correct_relaxed / len(all_expected_outputs) +# Tests if the model outputs contain the expected keywords. +def check_model_outputs(streaming: bool, model_outputs, expected_outputs: List[str]) -> bool: - evaluation_time = eval_end - eval_start + # In streaming mode, we get a list of lists of lists, and we only care about the last item in that list + if streaming: + if len(model_outputs) == 0: + return False + model_outputs = model_outputs[-1] - return ( - trtllm_accuracy, - trtllm_accuracy_relaxed, - trtllm_deployed_accuracy, - trtllm_deployed_accuracy_relaxed, - evaluation_time, - ) + # See if we have the right number of final answers. + if len(model_outputs) != len(expected_outputs): + return False + + # Check the presence of keywords in the final answers. + for i in range(len(model_outputs)): + if expected_outputs[i] not in model_outputs[i][0]: + return False + + return True -def run_trt_llm_inference( +def run_inference( model_name, model_type, - prompt, + prompts, + expected_outputs, checkpoint_path, - trt_llm_model_dir, - n_gpu=1, + model_dir, + use_vllm, max_batch_size=8, use_embedding_sharing=False, max_input_len=128, max_output_len=128, + use_parallel_embedding=False, ptuning=False, p_tuning_checkpoint=None, lora=False, lora_checkpoint=None, - tp_size=None, - pp_size=None, + tp_size=1, + pp_size=1, top_k=1, top_p=0.0, temperature=1.0, @@ -144,20 +237,21 @@ def run_trt_llm_inference( debug=True, streaming=False, stop_words_list=None, + test_cpp_runtime=False, test_deployment=False, test_data_path=None, save_trt_engine=False, -): +) -> Tuple[Optional[FunctionalResult], Optional[AccuracyResult]]: if Path(checkpoint_path).exists(): - if n_gpu > torch.cuda.device_count(): + if tp_size > torch.cuda.device_count(): print( - "Path: {0} and model: {1} with {2} gpus won't be tested since available # of gpus = {3}".format( - checkpoint_path, model_name, n_gpu, torch.cuda.device_count() + "Path: {0} and model: {1} with {2} tps won't be tested since available # of gpus = {3}".format( + checkpoint_path, model_name, tp_size, torch.cuda.device_count() ) ) - return None, None, None, None, None + return (None, None) - Path(trt_llm_model_dir).mkdir(parents=True, exist_ok=True) + Path(model_dir).mkdir(parents=True, exist_ok=True) if debug: print("") @@ -167,7 +261,7 @@ def run_trt_llm_inference( ) print("") - print("Path: {0} and model: {1} with {2} gpus will be tested".format(checkpoint_path, model_name, n_gpu)) + print("Path: {0} and model: {1} with {2} tps will be tested".format(checkpoint_path, model_name, tp_size)) prompt_embeddings_checkpoint_path = None task_ids = None @@ -182,7 +276,7 @@ def run_trt_llm_inference( print("---- PTuning enabled.") else: print("---- PTuning could not be enabled and skipping the test.") - return None, None, None, None, None + return (None, None) lora_ckpt_list = None lora_uids = None @@ -199,36 +293,47 @@ def run_trt_llm_inference( print("---- LoRA enabled.") else: print("---- LoRA could not be enabled and skipping the test.") - return None, None, None, None, None - - trt_llm_exporter = TensorRTLLM(trt_llm_model_dir, lora_ckpt_list, load_model=False) - - trt_llm_exporter.export( - nemo_checkpoint_path=checkpoint_path, - model_type=model_type, - n_gpus=n_gpu, - tensor_parallel_size=tp_size, - pipeline_parallel_size=pp_size, - max_input_len=max_input_len, - max_output_len=max_output_len, - max_batch_size=max_batch_size, - max_prompt_embedding_table_size=max_prompt_embedding_table_size, - use_lora_plugin=use_lora_plugin, - lora_target_modules=lora_target_modules, - max_num_tokens=int(max_input_len * max_batch_size * 0.2), - opt_num_tokens=60, - use_embedding_sharing=use_embedding_sharing, - save_nemo_model_config=True, - ) + return (None, None) + + if use_vllm: + exporter = vLLMExporter() + + exporter.export( + nemo_checkpoint=checkpoint_path, + model_dir=model_dir, + model_type=model_type, + tensor_parallel_size=tp_size, + pipeline_parallel_size=pp_size, + max_model_len=max_input_len + max_output_len, + gpu_memory_utilization=args.gpu_memory_utilization, + ) + else: + exporter = TensorRTLLM(model_dir, lora_ckpt_list, load_model=False) + + exporter.export( + nemo_checkpoint_path=checkpoint_path, + model_type=model_type, + tensor_parallelism_size=tp_size, + pipeline_parallelism_size=pp_size, + max_input_len=max_input_len, + max_output_len=max_output_len, + max_batch_size=max_batch_size, + use_parallel_embedding=use_parallel_embedding, + max_prompt_embedding_table_size=max_prompt_embedding_table_size, + use_lora_plugin=use_lora_plugin, + lora_target_modules=lora_target_modules, + max_num_tokens=int(max_input_len * max_batch_size * 0.2), + use_embedding_sharing=use_embedding_sharing, + ) if ptuning: - trt_llm_exporter.add_prompt_table( + exporter.add_prompt_table( task_name="0", prompt_embeddings_checkpoint_path=prompt_embeddings_checkpoint_path, ) - output = trt_llm_exporter.forward( - input_texts=prompt, + output = exporter.forward( + input_texts=prompts, max_output_len=max_output_len, top_k=top_k, top_p=top_p, @@ -239,12 +344,33 @@ def run_trt_llm_inference( stop_words_list=stop_words_list, ) - if not use_lora_plugin and not ptuning: - test_cpp_runtime( - engine_path=trt_llm_model_dir, - prompt=prompt, + # Unwrap the generator if needed + output = list(output) + + functional_result = FunctionalResult() + + # Check non-deployed funcitonal correctness + if args.functional_test: + functional_result.regular_pass = True + if not check_model_outputs(streaming, output, expected_outputs): + LOGGER.warning("Model outputs don't match the expected result.") + functional_result.regular_pass = False + + output_cpp = "" + if test_cpp_runtime and not use_lora_plugin and not ptuning and not use_vllm: + # This may cause OOM for large models as it creates 2nd instance of a model + exporter_cpp = TensorRTLLM( + model_dir, + load_model=True, + use_python_runtime=False, + ) + + output_cpp = exporter_cpp.forward( + input_texts=prompts, max_output_len=max_output_len, - debug=True, + top_k=top_k, + top_p=top_p, + temperature=temperature, ) nq = None @@ -252,7 +378,7 @@ def run_trt_llm_inference( output_deployed = "" if test_deployment: nm = DeployPyTriton( - model=trt_llm_exporter, + model=exporter, triton_model_name=model_name, port=8000, ) @@ -261,7 +387,7 @@ def run_trt_llm_inference( nq = NemoQueryLLM(url="localhost:8000", model_name=model_name) output_deployed = nq.query_llm( - prompts=prompt, + prompts=prompts, max_output_len=max_output_len, top_k=1, top_p=0.0, @@ -269,75 +395,66 @@ def run_trt_llm_inference( lora_uids=lora_uids, ) - if debug: + # Unwrap the generator if needed + output_deployed = list(output_deployed) + + # Check deployed funcitonal correctness + if args.functional_test: + functional_result.deployed_pass = True + if not check_model_outputs(streaming, output_deployed, expected_outputs): + LOGGER.warning("Deployed model outputs don't match the expected result.") + functional_result.deployed_pass = False + + if debug or functional_result.regular_pass == False or functional_result.deployed_pass == False: print("") - print("--- Prompt: ", prompt) + print("--- Prompt: ", prompts) print("") - print("--- Output: ", output) + print("--- Expected keywords: ", expected_outputs) print("") + print("--- Output: ", output) print("") print("--- Output deployed: ", output_deployed) print("") + print("") + print("--- Output with C++ runtime: ", output_cpp) + print("") + accuracy_result = None if run_accuracy: print("Start model accuracy testing ...") - result = get_accuracy_with_lambada(trt_llm_exporter, nq, task_ids, lora_uids, test_data_path) - if test_deployment: - nm.stop() - - if not save_trt_engine: - shutil.rmtree(trt_llm_model_dir) - return result + accuracy_result = get_accuracy_with_lambada(exporter, nq, task_ids, lora_uids, test_data_path) if test_deployment: nm.stop() if not save_trt_engine: - shutil.rmtree(trt_llm_model_dir) + shutil.rmtree(model_dir) - return None, None, None, None, None + return (functional_result, accuracy_result) else: raise Exception("Checkpoint {0} could not be found.".format(checkpoint_path)) -def test_cpp_runtime( - engine_path, - prompt, - max_output_len, - debug, -): - trt_llm_exporter = TensorRTLLM(engine_path, load_model=True) - output = trt_llm_exporter.forward( - input_texts=prompt, - max_output_len=max_output_len, - top_k=1, - top_p=0.0, - temperature=1.0, - ) - - if debug: - print("") - print("--- Output deployed with cpp runtime: ", output) - print("") - - def run_existing_checkpoints( model_name, - n_gpus, - tp_size=None, - pp_size=None, + use_vllm, + tp_size, + pp_size, + use_parallel_embedding=False, ptuning=False, lora=False, streaming=False, run_accuracy=False, + test_cpp_runtime=False, test_deployment=False, stop_words_list=None, test_data_path=None, save_trt_engine=False, -): - if n_gpus > torch.cuda.device_count(): + in_framework=False, +) -> Tuple[Optional[FunctionalResult], Optional[AccuracyResult]]: + if tp_size > torch.cuda.device_count(): print("Skipping the test due to not enough number of GPUs") - return None, None, None, None, None + return (None, None) test_data = get_infer_test_data() if not (model_name in test_data.keys()): @@ -345,9 +462,9 @@ def run_existing_checkpoints( model_info = test_data[model_name] - if n_gpus < model_info["min_gpus"]: - print("Min n_gpus for this model is {0}".format(n_gpus)) - return None, None, None, None, None + if tp_size < model_info["min_tps"]: + print("Min tps for this model is {0}".format(tp_size)) + return (None, None) p_tuning_checkpoint = None if ptuning: @@ -369,34 +486,109 @@ def run_existing_checkpoints( else: use_embedding_sharing = False - return run_trt_llm_inference( - model_name=model_name, - model_type=model_info["model_type"], - prompt=model_info["prompt_template"], - checkpoint_path=model_info["checkpoint"], - trt_llm_model_dir=model_info["trt_llm_model_dir"], - n_gpu=n_gpus, - max_batch_size=model_info["max_batch_size"], - use_embedding_sharing=use_embedding_sharing, - max_input_len=512, - max_output_len=model_info["max_output_len"], - ptuning=ptuning, - p_tuning_checkpoint=p_tuning_checkpoint, - lora=lora, - lora_checkpoint=lora_checkpoint, - tp_size=tp_size, - pp_size=pp_size, - top_k=1, - top_p=0.0, - temperature=1.0, - run_accuracy=run_accuracy, - debug=True, - streaming=streaming, - stop_words_list=stop_words_list, - test_deployment=test_deployment, - test_data_path=test_data_path, - save_trt_engine=save_trt_engine, - ) + if in_framework: + return run_in_framework_inference( + model_name=model_name, + prompts=model_info["prompt_template"], + checkpoint_path=model_info["checkpoint"], + num_gpus=tp_size, + max_output_len=model_info["max_output_len"], + run_accuracy=run_accuracy, + debug=True, + test_data_path=test_data_path, + ) + else: + return run_inference( + model_name=model_name, + model_type=model_info["model_type"], + prompts=model_info["prompt_template"], + expected_outputs=model_info["expected_keyword"], + checkpoint_path=model_info["checkpoint"], + model_dir=model_info["model_dir"], + use_vllm=use_vllm, + max_batch_size=model_info["max_batch_size"], + use_embedding_sharing=use_embedding_sharing, + use_parallel_embedding=use_parallel_embedding, + max_input_len=512, + max_output_len=model_info["max_output_len"], + ptuning=ptuning, + p_tuning_checkpoint=p_tuning_checkpoint, + lora=lora, + lora_checkpoint=lora_checkpoint, + tp_size=tp_size, + pp_size=pp_size, + top_k=1, + top_p=0.0, + temperature=1.0, + run_accuracy=run_accuracy, + debug=True, + streaming=streaming, + stop_words_list=stop_words_list, + test_cpp_runtime=test_cpp_runtime, + test_deployment=test_deployment, + test_data_path=test_data_path, + save_trt_engine=save_trt_engine, + ) + + +def run_in_framework_inference( + model_name, + prompts, + checkpoint_path, + num_gpus=1, + max_output_len=128, + top_k=1, + top_p=0.0, + temperature=1.0, + run_accuracy=False, + debug=True, + test_data_path=None, +) -> Tuple[Optional[FunctionalResult], Optional[AccuracyResult]]: + if Path(checkpoint_path).exists(): + if debug: + print("") + print("") + print( + "################################################## NEW TEST ##################################################" + ) + print("") + + print("Path: {0} and model: {1} will be tested".format(checkpoint_path, model_name)) + + deployed_model = MegatronLLMDeployable(checkpoint_path, num_gpus) + + nm = DeployPyTriton( + model=deployed_model, + triton_model_name=model_name, + port=8000, + ) + nm.deploy() + nm.run() + nq = NemoQueryLLMPyTorch(url="localhost:8000", model_name=model_name) + + output_deployed = nq.query_llm( + prompts=prompts, top_k=top_k, top_p=top_p, temperature=temperature, max_length=max_output_len + ) + output_deployed = output_deployed["sentences"] + # MegatronLLMDeployable will return the prompt + generated output, so cut off the prompt + for i, output in enumerate(output_deployed): + output = output[len(prompts[i]) :] + + # Unwrap the generator if needed + output_deployed = list(output_deployed) + print("\n --------- Output: ", output_deployed) + + accuracy_result = None + if run_accuracy: + print("Start model accuracy testing ...") + # This script is not written with torch.distributed support in mind, so running non-deployed in-framework models on multiple devices will not work + accuracy_result = get_accuracy_with_lambada(deployed_model, nq, None, None, test_data_path) + + nm.stop() + + return (None, accuracy_result) + else: + raise Exception("Checkpoint {0} could not be found.".format(checkpoint_path)) def get_args(): @@ -421,14 +613,19 @@ def get_args(): required=False, ) parser.add_argument( - "--min_gpus", + "--min_tps", type=int, default=1, required=True, ) parser.add_argument( - "--max_gpus", + "--max_tps", + type=int, + ) + parser.add_argument( + "--pps", type=int, + default=1, ) parser.add_argument( "--checkpoint_dir", @@ -437,7 +634,7 @@ def get_args(): required=False, ) parser.add_argument( - "--trt_llm_model_dir", + "--model_dir", type=str, ) parser.add_argument( @@ -455,6 +652,11 @@ def get_args(): type=int, default=128, ) + parser.add_argument( + "--use_parallel_embedding", + type=str, + default="False", + ) parser.add_argument( "--p_tuning_checkpoint", type=str, @@ -473,14 +675,6 @@ def get_args(): default=False, action='store_true', ) - parser.add_argument( - "--tp_size", - type=int, - ) - parser.add_argument( - "--pp_size", - type=int, - ) parser.add_argument( "--top_k", type=int, @@ -502,18 +696,23 @@ def get_args(): default="False", ) parser.add_argument("--streaming", default=False, action="store_true") + parser.add_argument( + "--test_cpp_runtime", + type=str, + default="False", + ) parser.add_argument( "--test_deployment", type=str, default="False", ) parser.add_argument( - "--debug", - default=False, - action='store_true', + "--functional_test", + type=str, + default="False", ) parser.add_argument( - "--ci_upload_test_results_to_cloud", + "--debug", default=False, action='store_true', ) @@ -527,114 +726,213 @@ def get_args(): type=str, default="False", ) + parser.add_argument( + "--use_vllm", + type=str, + default="False", + ) + parser.add_argument( + "--in_framework", + type=str, + default="False", + ) + parser.add_argument( + "-gmu", + '--gpu_memory_utilization', + default=0.95, # 0.95 is needed to run Mixtral-8x7B on 2x48GB GPUs + type=float, + help="GPU memory utilization percentage for vLLM.", + ) + + args = parser.parse_args() - return parser.parse_args() + def str_to_bool(name: str, s: str) -> bool: + true_strings = ["true", "1"] + false_strings = ["false", "0"] + if s.lower() in true_strings: + return True + if s.lower() in false_strings: + return False + raise UsageError(f"Invalid boolean value for argument --{name}: '{s}'") + + args.test_cpp_runtime = str_to_bool("test_cpp_runtime", args.test_cpp_runtime) + args.test_deployment = str_to_bool("test_deployment", args.test_deployment) + args.functional_test = str_to_bool("functional_test", args.functional_test) + args.save_trt_engine = str_to_bool("save_trt_engin", args.save_trt_engine) + args.run_accuracy = str_to_bool("run_accuracy", args.run_accuracy) + args.use_vllm = str_to_bool("use_vllm", args.use_vllm) + args.use_parallel_embedding = str_to_bool("use_parallel_embedding", args.use_parallel_embedding) + args.in_framework = str_to_bool("in_framework", args.in_framework) + + return args def run_inference_tests(args): - if args.test_deployment == "True": - args.test_deployment = True - else: - args.test_deployment = False + if not args.use_vllm and not args.in_framework and not trt_llm_supported: + raise UsageError("TensorRT-LLM engine is not supported in this environment.") - if args.save_trt_engine == "True": - args.save_trt_engine = True - else: - args.save_trt_engine = False + if args.use_vllm and not vllm_supported: + raise UsageError("vLLM engine is not supported in this environment.") - if args.run_accuracy == "True": - args.run_accuracy = True - else: - args.run_accuracy = False + if args.in_framework and not in_framework_supported: + raise UsageError("In-framework inference is not supported in this environment.") - if args.run_accuracy: - if args.test_data_path is None: - raise Exception("test_data_path param cannot be None.") + if args.use_vllm and (args.ptuning or args.lora): + raise UsageError("The vLLM integration currently does not support P-tuning or LoRA.") + + if args.test_deployment and not triton_supported: + raise UsageError("Deployment tests are not available because Triton is not supported in this environment.") + + if args.run_accuracy and args.test_data_path is None: + raise UsageError("Accuracy testing requires the --test_data_path argument.") + + if args.max_tps is None: + args.max_tps = args.min_tps + + if args.use_vllm and args.min_tps != args.max_tps: + raise UsageError( + "vLLM doesn't support changing tensor parallel group size without relaunching the process. " + "Use the same value for --min_tps and --max_tps." + ) - result_dic = {} + result_dic: Dict[int, Tuple[FunctionalResult, Optional[AccuracyResult]]] = {} if args.existing_test_models: - n_gpus = args.min_gpus - if args.max_gpus is None: - args.max_gpus = args.min_gpus + tps = args.min_tps - while n_gpus <= args.max_gpus: - result_dic[n_gpus] = run_existing_checkpoints( + while tps <= args.max_tps: + result_dic[tps] = run_existing_checkpoints( model_name=args.model_name, - n_gpus=n_gpus, + use_vllm=args.use_vllm, ptuning=args.ptuning, lora=args.lora, - tp_size=args.tp_size, - pp_size=args.pp_size, + tp_size=tps, + pp_size=args.pps, + use_parallel_embedding=args.use_parallel_embedding, streaming=args.streaming, test_deployment=args.test_deployment, + test_cpp_runtime=args.test_cpp_runtime, run_accuracy=args.run_accuracy, test_data_path=args.test_data_path, save_trt_engine=args.save_trt_engine, + in_framework=args.in_framework, ) - n_gpus = n_gpus * 2 + tps = tps * 2 else: - prompt_template = ["The capital of France is", "Largest animal in the sea is"] - n_gpus = args.min_gpus - if args.max_gpus is None: - args.max_gpus = args.min_gpus - - while n_gpus <= args.max_gpus: - result_dic[n_gpus] = run_trt_llm_inference( - model_name=args.model_name, - model_type=args.model_type, - prompt=prompt_template, - checkpoint_path=args.checkpoint_dir, - trt_llm_model_dir=args.trt_llm_model_dir, - n_gpu=n_gpus, - max_batch_size=args.max_batch_size, - max_input_len=args.max_input_len, - max_output_len=args.max_output_len, - ptuning=args.ptuning, - p_tuning_checkpoint=args.p_tuning_checkpoint, - lora=args.lora, - lora_checkpoint=args.lora_checkpoint, - tp_size=args.tp_size, - pp_size=args.pp_size, - top_k=args.top_k, - top_p=args.top_p, - temperature=args.temperature, - run_accuracy=args.run_accuracy, - debug=args.debug, - streaming=args.streaming, - test_deployment=args.test_deployment, - test_data_path=args.test_data_path, - save_trt_engine=args.save_trt_engine, - ) + if not args.in_framework and args.model_dir is None: + raise Exception("When using custom checkpoints, --model_dir is required.") + + prompts = ["The capital of France is", "Largest animal in the sea is"] + expected_outputs = ["Paris", "blue whale"] + tps = args.min_tps + + while tps <= args.max_tps: + if args.in_framework: + result_dic[tps] = run_in_framework_inference( + model_name=args.model_name, + prompts=prompts, + checkpoint_path=args.checkpoint_dir, + num_gpus=tps, + max_output_len=args.max_output_len, + top_k=args.top_k, + top_p=args.top_p, + temperature=args.temperature, + run_accuracy=args.run_accuracy, + debug=True, + test_data_path=args.test_data_path, + ) + else: + result_dic[tps] = run_inference( + model_name=args.model_name, + model_type=args.model_type, + prompts=prompts, + expected_outputs=expected_outputs, + checkpoint_path=args.checkpoint_dir, + model_dir=args.model_dir, + use_vllm=args.use_vllm, + tp_size=tps, + pp_size=args.pps, + max_batch_size=args.max_batch_size, + max_input_len=args.max_input_len, + max_output_len=args.max_output_len, + use_parallel_embedding=args.use_parallel_embedding, + ptuning=args.ptuning, + p_tuning_checkpoint=args.p_tuning_checkpoint, + lora=args.lora, + lora_checkpoint=args.lora_checkpoint, + top_k=args.top_k, + top_p=args.top_p, + temperature=args.temperature, + run_accuracy=args.run_accuracy, + debug=args.debug, + streaming=args.streaming, + test_deployment=args.test_deployment, + test_cpp_runtime=args.test_cpp_runtime, + test_data_path=args.test_data_path, + save_trt_engine=args.save_trt_engine, + ) - n_gpus = n_gpus * 2 + tps = tps * 2 - test_result = "PASS" + functional_test_result = "PASS" + accuracy_test_result = "PASS" print_separator = False print("============= Test Summary ============") - for i, results in result_dic.items(): - if not results[0] is None and not results[1] is None: - if print_separator: - print("---------------------------------------") - print( - "Number of GPUS: {}\n" - "Model Accuracy: {:.4f}\n" - "Relaxed Model Accuracy: {:.4f}\n" - "Deployed Model Accuracy: {:.4f}\n" - "Deployed Relaxed Model Accuracy: {:.4f}\n" - "Evaluation Time [s]: {:.2f}".format(i, *results) - ) - print_separator = True - if results[1] < 0.5: - test_result = "FAIL" + # in-framework tests will only return deployed model accuracy results for tps > 1 + deployed_tests_only = args.in_framework and args.max_tps > 1 + for num_tps, results in result_dic.items(): + functional_result, accuracy_result = results + + if print_separator: + print("---------------------------------------") + print_separator = True + + def optional_bool_to_pass_fail(b: Optional[bool]): + if b is None: + return "N/A" + return "PASS" if b else "FAIL" + + print(f"Tensor Parallelism: {num_tps}") + + if args.functional_test and functional_result is not None: + print(f"Functional Test: {optional_bool_to_pass_fail(functional_result.regular_pass)}") + print(f"Deployed Functional Test: {optional_bool_to_pass_fail(functional_result.deployed_pass)}") + + if functional_result.regular_pass == False: + functional_test_result = "FAIL" + if functional_result.deployed_pass == False: + functional_test_result = "FAIL" + + if args.run_accuracy and accuracy_result is not None: + print(f"Model Accuracy: {accuracy_result.accuracy:.4f}") + print(f"Relaxed Model Accuracy: {accuracy_result.accuracy_relaxed:.4f}") + print(f"Deployed Model Accuracy: {accuracy_result.deployed_accuracy:.4f}") + print(f"Deployed Relaxed Model Accuracy: {accuracy_result.deployed_accuracy_relaxed:.4f}") + print(f"Evaluation Time [s]: {accuracy_result.evaluation_time:.2f}") + if (deployed_tests_only and accuracy_result.deployed_accuracy_relaxed < 0.5) or ( + not deployed_tests_only and accuracy_result.accuracy_relaxed < 0.5 + ): + accuracy_test_result = "FAIL" print("=======================================") - print("TEST: " + test_result) - if test_result == "FAIL": + if args.functional_test: + print(f"Functional: {functional_test_result}") + if args.run_accuracy: + print(f"Acccuracy: {accuracy_test_result}") + + if functional_test_result == "FAIL": + raise Exception("Functional test failed") + + if accuracy_test_result == "FAIL": raise Exception("Model accuracy is below 0.5") if __name__ == '__main__': - args = get_args() - run_inference_tests(args) + try: + args = get_args() + run_inference_tests(args) + except UsageError as e: + LOGGER.error(f"{e}") + except argparse.ArgumentError as e: + LOGGER.error(f"{e}") diff --git a/tests/export/run.sh b/tests/export/run.sh index b3badd25a8f91..e534e4e87ee95 100644 --- a/tests/export/run.sh +++ b/tests/export/run.sh @@ -20,32 +20,28 @@ for i in $(env | grep ^PMIX_ | cut -d"=" -f 1); do unset -v $i; done set +x -python tests/export/nemo_export.py --model_name LLAMA2-7B-base --existing_test_models --min_gpus 1 --max_gpus 2 -python tests/export/nemo_export.py --model_name LLAMA2-7B-base --existing_test_models --min_gpus 1 --streaming -python tests/export/nemo_export.py --model_name LLAMA2-7B-base --existing_test_models --min_gpus 2 --tp_size 1 --pp_size 2 -python tests/export/nemo_export.py --model_name LLAMA2-7B-base --existing_test_models --min_gpus 4 --tp_size 2 --pp_size 2 -python tests/export/nemo_export.py --model_name LLAMA2-7B-base --existing_test_models --min_gpus 8 --tp_size 1 --pp_size 8 -python tests/export/nemo_export.py --model_name LLAMA2-7B-base --existing_test_models --ptuning --min_gpus 1 --max_gpus 2 -python tests/export/nemo_export.py --model_name LLAMA2-7B-base --existing_test_models --lora --min_gpus 1 --max_gpus 2 -python tests/export/nemo_export.py --model_name LLAMA2-7B-code --existing_test_models --min_gpus 1 --max_gpus 2 -python tests/export/nemo_export.py --model_name LLAMA2-7B-base-fp8 --existing_test_models --min_gpus 1 --max_gpus 1 -python tests/export/nemo_export.py --model_name LLAMA2-7B-base-int4 --existing_test_models --min_gpus 1 --max_gpus 1 -python tests/export/nemo_export.py --model_name LLAMA2-7B-base-int8 --existing_test_models --min_gpus 1 --max_gpus 1 -python tests/export/nemo_export.py --model_name LLAMA2-13B-base --existing_test_models --min_gpus 1 --max_gpus 2 -python tests/export/nemo_export.py --model_name LLAMA2-13B-base --existing_test_models --ptuning --min_gpus 1 --max_gpus 2 -python tests/export/nemo_export.py --model_name LLAMA2-13B-base-fp8 --existing_test_models --min_gpus 2 --max_gpus 2 -python tests/export/nemo_export.py --model_name LLAMA2-13B-base-int4 --existing_test_models --min_gpus 2 --max_gpus 2 -python tests/export/nemo_export.py --model_name LLAMA2-70B-base --existing_test_models --min_gpus 2 --max_gpus 8 -python tests/export/nemo_export.py --model_name LLAMA2-70B-base-fp8 --existing_test_models --min_gpus 8 --max_gpus 8 -python tests/export/nemo_export.py --model_name LLAMA2-70B-base-int4 --existing_test_models --min_gpus 8 --max_gpus 8 -python tests/export/nemo_export.py --model_name NV-GPT-8B-Base-4k --existing_test_models --min_gpus 1 --max_gpus 8 -python tests/export/nemo_export.py --model_name NV-GPT-8B-QA-4k --existing_test_models --min_gpus 1 --max_gpus 8 -python tests/export/nemo_export.py --model_name NV-GPT-8B-Chat-4k-SFT --existing_test_models --min_gpus 1 --max_gpus 8 -python tests/export/nemo_export.py --model_name NV-GPT-8B-Chat-4k-RLHF --existing_test_models --min_gpus 1 --max_gpus 8 -python tests/export/nemo_export.py --model_name NV-GPT-8B-Chat-4k-SteerLM --existing_test_models --min_gpus 1 --max_gpus 8 -python tests/export/nemo_export.py --model_name GPT-43B-Base --existing_test_models --min_gpus 2 --max_gpus 8 -python tests/export/nemo_export.py --model_name FALCON-7B-base --existing_test_models --min_gpus 1 --max_gpus 2 -python tests/export/nemo_export.py --model_name FALCON-40B-base --existing_test_models --min_gpus 2 --max_gpus 8 -python tests/export/nemo_export.py --model_name FALCON-180B-base --existing_test_models --min_gpus 8 --max_gpus 8 -python tests/export/nemo_export.py --model_name STARCODER1-15B-base --existing_test_models --min_gpus 1 --max_gpus 1 -python tests/export/nemo_export.py --model_name GEMMA-base --existing_test_models --min_gpus 1 --max_gpus 1 \ No newline at end of file + +python tests/export/nemo_export.py --model_name LLAMA2-7B-base --existing_test_models --min_tps 1 +python tests/export/nemo_export.py --model_name LLAMA2-7B-base --existing_test_models --min_tps 2 +python tests/export/nemo_export.py --model_name LLAMA2-7B-base --existing_test_models --ptuning --min_tps 1 --max_tps 2 +python tests/export/nemo_export.py --model_name LLAMA2-7B-base --existing_test_models --lora --min_tps 1 --max_tps 2 +python tests/export/nemo_export.py --model_name LLAMA2-7B-code --existing_test_models --min_tps 1 --max_tps 2 +python tests/export/nemo_export.py --model_name LLAMA2-7B-base-fp8 --existing_test_models --min_tps 1 --max_tps 1 +python tests/export/nemo_export.py --model_name LLAMA2-7B-base-int4 --existing_test_models --min_tps 1 --max_tps 1 +python tests/export/nemo_export.py --model_name LLAMA2-7B-base-int8 --existing_test_models --min_tps 1 --max_tps 1 +python tests/export/nemo_export.py --model_name LLAMA2-13B-base --existing_test_models --min_tps 1 --max_tps 2 +python tests/export/nemo_export.py --model_name LLAMA2-13B-base --existing_test_models --ptuning --min_tps 1 --max_tps 2 +python tests/export/nemo_export.py --model_name LLAMA2-13B-base-fp8 --existing_test_models --min_tps 2 --max_tps 2 +python tests/export/nemo_export.py --model_name LLAMA2-13B-base-int4 --existing_test_models --min_tps 2 --max_tps 2 +python tests/export/nemo_export.py --model_name LLAMA2-70B-base --existing_test_models --min_tps 2 --max_tps 8 +python tests/export/nemo_export.py --model_name LLAMA2-70B-base-fp8 --existing_test_models --min_tps 8 --max_tps 8 +python tests/export/nemo_export.py --model_name LLAMA2-70B-base-int4 --existing_test_models --min_tps 8 --max_tps 8 +python tests/export/nemo_export.py --model_name NV-GPT-8B-Base-4k --existing_test_models --min_tps 1 --max_tps 8 +python tests/export/nemo_export.py --model_name NV-GPT-8B-QA-4k --existing_test_models --min_tps 1 --max_tps 8 +python tests/export/nemo_export.py --model_name NV-GPT-8B-Chat-4k-SFT --existing_test_models --min_tps 1 --max_tps 8 +python tests/export/nemo_export.py --model_name NV-GPT-8B-Chat-4k-RLHF --existing_test_models --min_tps 1 --max_tps 8 +python tests/export/nemo_export.py --model_name NV-GPT-8B-Chat-4k-SteerLM --existing_test_models --min_tps 1 --max_tps 8 +python tests/export/nemo_export.py --model_name FALCON-7B-base --existing_test_models --min_tps 1 --max_tps 2 +python tests/export/nemo_export.py --model_name FALCON-40B-base --existing_test_models --min_tps 2 --max_tps 8 +python tests/export/nemo_export.py --model_name STARCODER1-15B-base --existing_test_models --min_tps 1 --max_tps 1 +python tests/export/nemo_export.py --model_name GEMMA-base --existing_test_models --min_tps 1 --max_tps 1 \ No newline at end of file diff --git a/tests/infer_data_path.py b/tests/infer_data_path.py index d7e6f231a58f6..45850dcb366a2 100644 --- a/tests/infer_data_path.py +++ b/tests/infer_data_path.py @@ -21,9 +21,9 @@ def get_infer_test_data(): test_data["NV-GPT-8B-Base-4k"] = {} test_data["NV-GPT-8B-Base-4k"]["model_type"] = "gptnext" - test_data["NV-GPT-8B-Base-4k"]["min_gpus"] = 1 + test_data["NV-GPT-8B-Base-4k"]["min_tps"] = 1 test_data["NV-GPT-8B-Base-4k"]["location"] = "Local" - test_data["NV-GPT-8B-Base-4k"]["trt_llm_model_dir"] = "/tmp/NV-GPT-8B-Base-4k/nv-gpt-8b-base-4k_v1.0/" + test_data["NV-GPT-8B-Base-4k"]["model_dir"] = "/tmp/NV-GPT-8B-Base-4k/nv-gpt-8b-base-4k_v1.0/" test_data["NV-GPT-8B-Base-4k"][ "checkpoint" ] = "/opt/checkpoints/NV-GPT-8B-Base-4k/nv-gpt-8b-base-4k_v1.0/NV-GPT-8B-Base-4k.nemo" @@ -39,9 +39,9 @@ def get_infer_test_data(): test_data["NV-GPT-8B-Base-16k"] = {} test_data["NV-GPT-8B-Base-16k"]["model_type"] = "gptnext" - test_data["NV-GPT-8B-Base-16k"]["min_gpus"] = 1 + test_data["NV-GPT-8B-Base-16k"]["min_tps"] = 1 test_data["NV-GPT-8B-Base-16k"]["location"] = "Local" - test_data["NV-GPT-8B-Base-16k"]["trt_llm_model_dir"] = "/tmp/NV-GPT-8B-Base-16k/nv-gpt-8b-base-16k_v1.0/" + test_data["NV-GPT-8B-Base-16k"]["model_dir"] = "/tmp/NV-GPT-8B-Base-16k/nv-gpt-8b-base-16k_v1.0/" test_data["NV-GPT-8B-Base-16k"][ "checkpoint" ] = "/opt/checkpoints/NV-GPT-8B-Base-16k/nv-gpt-8b-base-16k_v1.0/NV-GPT-8B-Base-16k.nemo" @@ -56,9 +56,9 @@ def get_infer_test_data(): test_data["NV-GPT-8B-QA-4k"] = {} test_data["NV-GPT-8B-QA-4k"]["model_type"] = "gptnext" - test_data["NV-GPT-8B-QA-4k"]["min_gpus"] = 1 + test_data["NV-GPT-8B-QA-4k"]["min_tps"] = 1 test_data["NV-GPT-8B-QA-4k"]["location"] = "Local" - test_data["NV-GPT-8B-QA-4k"]["trt_llm_model_dir"] = "/tmp/NV-GPT-8B-QA-4k/nv-gpt-8b-qa-4k_v1.0/" + test_data["NV-GPT-8B-QA-4k"]["model_dir"] = "/tmp/NV-GPT-8B-QA-4k/nv-gpt-8b-qa-4k_v1.0/" test_data["NV-GPT-8B-QA-4k"][ "checkpoint" ] = "/opt/checkpoints/NV-GPT-8B-QA-4k/nv-gpt-8b-qa-4k_v1.0/NV-GPT-8B-QA-4k.nemo" @@ -73,9 +73,9 @@ def get_infer_test_data(): test_data["NV-GPT-8B-Chat-4k-SFT"] = {} test_data["NV-GPT-8B-Chat-4k-SFT"]["model_type"] = "gptnext" - test_data["NV-GPT-8B-Chat-4k-SFT"]["min_gpus"] = 1 + test_data["NV-GPT-8B-Chat-4k-SFT"]["min_tps"] = 1 test_data["NV-GPT-8B-Chat-4k-SFT"]["location"] = "Local" - test_data["NV-GPT-8B-Chat-4k-SFT"]["trt_llm_model_dir"] = "/tmp/NV-GPT-8B-Chat-4k-SFT/nv-gpt-8b-chat-4k-sft_v1.0/" + test_data["NV-GPT-8B-Chat-4k-SFT"]["model_dir"] = "/tmp/NV-GPT-8B-Chat-4k-SFT/nv-gpt-8b-chat-4k-sft_v1.0/" test_data["NV-GPT-8B-Chat-4k-SFT"][ "checkpoint" ] = "/opt/checkpoints/NV-GPT-8B-Chat-4k-SFT/nv-gpt-8b-chat-4k-sft_v1.0/NV-GPT-8B-Chat-4k-SFT.nemo" @@ -90,11 +90,9 @@ def get_infer_test_data(): test_data["NV-GPT-8B-Chat-4k-RLHF"] = {} test_data["NV-GPT-8B-Chat-4k-RLHF"]["model_type"] = "gptnext" - test_data["NV-GPT-8B-Chat-4k-RLHF"]["min_gpus"] = 1 + test_data["NV-GPT-8B-Chat-4k-RLHF"]["min_tps"] = 1 test_data["NV-GPT-8B-Chat-4k-RLHF"]["location"] = "Local" - test_data["NV-GPT-8B-Chat-4k-RLHF"][ - "trt_llm_model_dir" - ] = "/tmp/NV-GPT-8B-Chat-4k-RLHF/nv-gpt-8b-chat-4k-rlhf_v1.0/" + test_data["NV-GPT-8B-Chat-4k-RLHF"]["model_dir"] = "/tmp/NV-GPT-8B-Chat-4k-RLHF/nv-gpt-8b-chat-4k-rlhf_v1.0/" test_data["NV-GPT-8B-Chat-4k-RLHF"][ "checkpoint" ] = "/opt/checkpoints/NV-GPT-8B-Chat-4k-RLHF/nv-gpt-8b-chat-4k-rlhf_v1.0/NV-GPT-8B-Chat-4k-RLHF.nemo" @@ -109,10 +107,10 @@ def get_infer_test_data(): test_data["NV-GPT-8B-Chat-4k-SteerLM"] = {} test_data["NV-GPT-8B-Chat-4k-SteerLM"]["model_type"] = "gptnext" - test_data["NV-GPT-8B-Chat-4k-SteerLM"]["min_gpus"] = 1 + test_data["NV-GPT-8B-Chat-4k-SteerLM"]["min_tps"] = 1 test_data["NV-GPT-8B-Chat-4k-SteerLM"]["location"] = "Local" test_data["NV-GPT-8B-Chat-4k-SteerLM"][ - "trt_llm_model_dir" + "model_dir" ] = "/tmp/NV-GPT-8B-Chat-4k-SteerLM/nv-gpt-8b-chat-4k-steerlm_v1.0/" test_data["NV-GPT-8B-Chat-4k-SteerLM"][ "checkpoint" @@ -128,9 +126,9 @@ def get_infer_test_data(): test_data["GPT-43B-Base"] = {} test_data["GPT-43B-Base"]["model_type"] = "gptnext" - test_data["GPT-43B-Base"]["min_gpus"] = 2 + test_data["GPT-43B-Base"]["min_tps"] = 2 test_data["GPT-43B-Base"]["location"] = "Local" - test_data["GPT-43B-Base"]["trt_llm_model_dir"] = "/tmp/GPT-43B-Base/gpt-43B-base/" + test_data["GPT-43B-Base"]["model_dir"] = "/tmp/GPT-43B-Base/gpt-43B-base/" test_data["GPT-43B-Base"]["checkpoint"] = "/opt/checkpoints/GPT-43B-Base/gpt-43B-base.nemo" test_data["GPT-43B-Base"]["prompt_template"] = [ "The capital of France is", @@ -143,9 +141,9 @@ def get_infer_test_data(): test_data["LLAMA2-7B-base"] = {} test_data["LLAMA2-7B-base"]["model_type"] = "llama" - test_data["LLAMA2-7B-base"]["min_gpus"] = 1 + test_data["LLAMA2-7B-base"]["min_tps"] = 1 test_data["LLAMA2-7B-base"]["location"] = "Local" - test_data["LLAMA2-7B-base"]["trt_llm_model_dir"] = "/tmp/LLAMA2-7B-base/trt_llm_model-1/" + test_data["LLAMA2-7B-base"]["model_dir"] = "/tmp/LLAMA2-7B-base/trt_llm_model-1/" test_data["LLAMA2-7B-base"]["checkpoint"] = "/opt/checkpoints/LLAMA2-7B-base/LLAMA2-7B-base-1.nemo" test_data["LLAMA2-7B-base"]["p_tuning_checkpoint"] = "/opt/checkpoints/LLAMA2-7B-PTuning/LLAMA2-7B-PTuning-1.nemo" test_data["LLAMA2-7B-base"]["lora_checkpoint"] = "/opt/checkpoints/LLAMA2-7B-Lora/LLAMA2-7B-Lora-1.nemo" @@ -160,9 +158,9 @@ def get_infer_test_data(): test_data["LLAMA2-13B-base"] = {} test_data["LLAMA2-13B-base"]["model_type"] = "llama" - test_data["LLAMA2-13B-base"]["min_gpus"] = 1 + test_data["LLAMA2-13B-base"]["min_tps"] = 1 test_data["LLAMA2-13B-base"]["location"] = "Local" - test_data["LLAMA2-13B-base"]["trt_llm_model_dir"] = "/tmp/LLAMA2-13B-base/trt_llm_model-1/" + test_data["LLAMA2-13B-base"]["model_dir"] = "/tmp/LLAMA2-13B-base/trt_llm_model-1/" test_data["LLAMA2-13B-base"]["checkpoint"] = "/opt/checkpoints/LLAMA2-13B-base/LLAMA2-13B-base-1.nemo" test_data["LLAMA2-13B-base"][ "p_tuning_checkpoint" @@ -178,9 +176,9 @@ def get_infer_test_data(): test_data["LLAMA2-70B-base"] = {} test_data["LLAMA2-70B-base"]["model_type"] = "llama" - test_data["LLAMA2-70B-base"]["min_gpus"] = 2 + test_data["LLAMA2-70B-base"]["min_tps"] = 2 test_data["LLAMA2-70B-base"]["location"] = "Local" - test_data["LLAMA2-70B-base"]["trt_llm_model_dir"] = "/tmp/LLAMA2-70B-base/trt_llm_model-1/" + test_data["LLAMA2-70B-base"]["model_dir"] = "/tmp/LLAMA2-70B-base/trt_llm_model-1/" test_data["LLAMA2-70B-base"]["checkpoint"] = "/opt/checkpoints/LLAMA2-70B-base/LLAMA2-70B-base-1.nemo" test_data["LLAMA2-70B-base"]["prompt_template"] = [ "The capital of France is", @@ -193,9 +191,9 @@ def get_infer_test_data(): test_data["LLAMA2-7B-code"] = {} test_data["LLAMA2-7B-code"]["model_type"] = "llama" - test_data["LLAMA2-7B-code"]["min_gpus"] = 1 + test_data["LLAMA2-7B-code"]["min_tps"] = 1 test_data["LLAMA2-7B-code"]["location"] = "Local" - test_data["LLAMA2-7B-code"]["trt_llm_model_dir"] = "/tmp/LLAMA2-7B-code/trt_llm_model-1/" + test_data["LLAMA2-7B-code"]["model_dir"] = "/tmp/LLAMA2-7B-code/trt_llm_model-1/" test_data["LLAMA2-7B-code"]["checkpoint"] = "/opt/checkpoints/LLAMA2-7B-code/LLAMA2-7B-code-1.nemo" test_data["LLAMA2-7B-code"]["prompt_template"] = [ "You are an expert programmer that writes simple, concise code and explanations. Write a python function to generate the nth fibonacci number." @@ -206,9 +204,9 @@ def get_infer_test_data(): test_data["LLAMA2-7B-base-fp8"] = {} test_data["LLAMA2-7B-base-fp8"]["model_type"] = "llama" - test_data["LLAMA2-7B-base-fp8"]["min_gpus"] = 1 + test_data["LLAMA2-7B-base-fp8"]["min_tps"] = 1 test_data["LLAMA2-7B-base-fp8"]["location"] = "Local" - test_data["LLAMA2-7B-base-fp8"]["trt_llm_model_dir"] = "/tmp/LLAMA2-7B-base-fp8/trt_llm_model-1/" + test_data["LLAMA2-7B-base-fp8"]["model_dir"] = "/tmp/LLAMA2-7B-base-fp8/trt_llm_model-1/" test_data["LLAMA2-7B-base-fp8"]["checkpoint"] = "/opt/checkpoints/LLAMA2-7B-base-fp8/LLAMA2-7B-base-fp8-1.qnemo" test_data["LLAMA2-7B-base-fp8"]["prompt_template"] = [ "The capital of France is", @@ -221,9 +219,9 @@ def get_infer_test_data(): test_data["LLAMA2-7B-base-int4"] = {} test_data["LLAMA2-7B-base-int4"]["model_type"] = "llama" - test_data["LLAMA2-7B-base-int4"]["min_gpus"] = 1 + test_data["LLAMA2-7B-base-int4"]["min_tps"] = 1 test_data["LLAMA2-7B-base-int4"]["location"] = "Local" - test_data["LLAMA2-7B-base-int4"]["trt_llm_model_dir"] = "/tmp/LLAMA2-7B-base-int4/trt_llm_model-1/" + test_data["LLAMA2-7B-base-int4"]["model_dir"] = "/tmp/LLAMA2-7B-base-int4/trt_llm_model-1/" test_data["LLAMA2-7B-base-int4"]["checkpoint"] = "/opt/checkpoints/LLAMA2-7B-base-int4/LLAMA2-7B-base-int4-1.qnemo" test_data["LLAMA2-7B-base-int4"]["prompt_template"] = [ "The capital of France is", @@ -236,9 +234,9 @@ def get_infer_test_data(): test_data["LLAMA2-7B-base-int8"] = {} test_data["LLAMA2-7B-base-int8"]["model_type"] = "llama" - test_data["LLAMA2-7B-base-int8"]["min_gpus"] = 1 + test_data["LLAMA2-7B-base-int8"]["min_tps"] = 1 test_data["LLAMA2-7B-base-int8"]["location"] = "Local" - test_data["LLAMA2-7B-base-int8"]["trt_llm_model_dir"] = "/tmp/LLAMA2-7B-base-int8/trt_llm_model-1/" + test_data["LLAMA2-7B-base-int8"]["model_dir"] = "/tmp/LLAMA2-7B-base-int8/trt_llm_model-1/" test_data["LLAMA2-7B-base-int8"]["checkpoint"] = "/opt/checkpoints/LLAMA2-7B-base-int8/LLAMA2-7B-base-int8-1.qnemo" test_data["LLAMA2-7B-base-int8"]["prompt_template"] = [ "The capital of France is", @@ -251,9 +249,9 @@ def get_infer_test_data(): test_data["LLAMA2-13B-base-fp8"] = {} test_data["LLAMA2-13B-base-fp8"]["model_type"] = "llama" - test_data["LLAMA2-13B-base-fp8"]["min_gpus"] = 2 + test_data["LLAMA2-13B-base-fp8"]["min_tps"] = 2 test_data["LLAMA2-13B-base-fp8"]["location"] = "Local" - test_data["LLAMA2-13B-base-fp8"]["trt_llm_model_dir"] = "/tmp/LLAMA2-13B-base-fp8/trt_llm_model-1/" + test_data["LLAMA2-13B-base-fp8"]["model_dir"] = "/tmp/LLAMA2-13B-base-fp8/trt_llm_model-1/" test_data["LLAMA2-13B-base-fp8"]["checkpoint"] = "/opt/checkpoints/LLAMA2-13B-base-fp8/LLAMA2-13B-base-fp8-1-qnemo" test_data["LLAMA2-13B-base-fp8"]["prompt_template"] = [ "The capital of France is", @@ -266,9 +264,9 @@ def get_infer_test_data(): test_data["LLAMA2-13B-base-int4"] = {} test_data["LLAMA2-13B-base-int4"]["model_type"] = "llama" - test_data["LLAMA2-13B-base-int4"]["min_gpus"] = 2 + test_data["LLAMA2-13B-base-int4"]["min_tps"] = 2 test_data["LLAMA2-13B-base-int4"]["location"] = "Local" - test_data["LLAMA2-13B-base-int4"]["trt_llm_model_dir"] = "/tmp/LLAMA2-13B-base-int4/trt_llm_model-1/" + test_data["LLAMA2-13B-base-int4"]["model_dir"] = "/tmp/LLAMA2-13B-base-int4/trt_llm_model-1/" test_data["LLAMA2-13B-base-int4"][ "checkpoint" ] = "/opt/checkpoints/LLAMA2-13B-base-int4/LLAMA2-13B-base-int4-1-qnemo" @@ -283,9 +281,9 @@ def get_infer_test_data(): test_data["LLAMA2-70B-base-fp8"] = {} test_data["LLAMA2-70B-base-fp8"]["model_type"] = "llama" - test_data["LLAMA2-70B-base-fp8"]["min_gpus"] = 8 + test_data["LLAMA2-70B-base-fp8"]["min_tps"] = 8 test_data["LLAMA2-70B-base-fp8"]["location"] = "Local" - test_data["LLAMA2-70B-base-fp8"]["trt_llm_model_dir"] = "/tmp/LLAMA2-70B-base-fp8/trt_llm_model-1/" + test_data["LLAMA2-70B-base-fp8"]["model_dir"] = "/tmp/LLAMA2-70B-base-fp8/trt_llm_model-1/" test_data["LLAMA2-70B-base-fp8"]["checkpoint"] = "/opt/checkpoints/LLAMA2-70B-base-fp8/LLAMA2-70B-base-fp8-1-qnemo" test_data["LLAMA2-70B-base-fp8"]["prompt_template"] = [ "The capital of France is", @@ -298,9 +296,9 @@ def get_infer_test_data(): test_data["LLAMA2-70B-base-int4"] = {} test_data["LLAMA2-70B-base-int4"]["model_type"] = "llama" - test_data["LLAMA2-70B-base-int4"]["min_gpus"] = 8 + test_data["LLAMA2-70B-base-int4"]["min_tps"] = 8 test_data["LLAMA2-70B-base-int4"]["location"] = "Local" - test_data["LLAMA2-70B-base-int4"]["trt_llm_model_dir"] = "/tmp/LLAMA2-70B-base-int4/trt_llm_model-1/" + test_data["LLAMA2-70B-base-int4"]["model_dir"] = "/tmp/LLAMA2-70B-base-int4/trt_llm_model-1/" test_data["LLAMA2-70B-base-int4"][ "checkpoint" ] = "/opt/checkpoints/LLAMA2-70B-base-int4/LLAMA2-70B-base-int4-1-qnemo" @@ -315,9 +313,9 @@ def get_infer_test_data(): test_data["FALCON-7B-base"] = {} test_data["FALCON-7B-base"]["model_type"] = "falcon" - test_data["FALCON-7B-base"]["min_gpus"] = 1 + test_data["FALCON-7B-base"]["min_tps"] = 1 test_data["FALCON-7B-base"]["location"] = "Local" - test_data["FALCON-7B-base"]["trt_llm_model_dir"] = "/tmp/FALCON-7B-base/trt_llm_model-1/" + test_data["FALCON-7B-base"]["model_dir"] = "/tmp/FALCON-7B-base/trt_llm_model-1/" test_data["FALCON-7B-base"]["checkpoint"] = "/opt/checkpoints/FALCON-7B-base/FALCON-7B-base-1.nemo" test_data["FALCON-7B-base"]["prompt_template"] = [ "The capital of France is", @@ -330,9 +328,9 @@ def get_infer_test_data(): test_data["FALCON-40B-base"] = {} test_data["FALCON-40B-base"]["model_type"] = "falcon" - test_data["FALCON-40B-base"]["min_gpus"] = 2 + test_data["FALCON-40B-base"]["min_tps"] = 2 test_data["FALCON-40B-base"]["location"] = "Local" - test_data["FALCON-40B-base"]["trt_llm_model_dir"] = "/tmp/FALCON-40B-base/trt_llm_model-1/" + test_data["FALCON-40B-base"]["model_dir"] = "/tmp/FALCON-40B-base/trt_llm_model-1/" test_data["FALCON-40B-base"]["checkpoint"] = "/opt/checkpoints/FALCON-40B-base/FALCON-40B-base-1.nemo" test_data["FALCON-40B-base"]["prompt_template"] = [ "The capital of France is", @@ -345,9 +343,9 @@ def get_infer_test_data(): test_data["FALCON-180B-base"] = {} test_data["FALCON-180B-base"]["model_type"] = "falcon" - test_data["FALCON-180B-base"]["min_gpus"] = 8 + test_data["FALCON-180B-base"]["min_tps"] = 8 test_data["FALCON-180B-base"]["location"] = "Local" - test_data["FALCON-180B-base"]["trt_llm_model_dir"] = "/tmp/FALCON-180B-base/trt_llm_model-1/" + test_data["FALCON-180B-base"]["model_dir"] = "/tmp/FALCON-180B-base/trt_llm_model-1/" test_data["FALCON-180B-base"]["checkpoint"] = "/opt/checkpoints/FALCON-180B-base/FALCON-180B-base-1.nemo" test_data["FALCON-180B-base"]["prompt_template"] = [ "The capital of France is", @@ -360,9 +358,9 @@ def get_infer_test_data(): test_data["STARCODER1-15B-base"] = {} test_data["STARCODER1-15B-base"]["model_type"] = "starcoder" - test_data["STARCODER1-15B-base"]["min_gpus"] = 1 + test_data["STARCODER1-15B-base"]["min_tps"] = 1 test_data["STARCODER1-15B-base"]["location"] = "Local" - test_data["STARCODER1-15B-base"]["trt_llm_model_dir"] = "/tmp/STARCODER1-15B-base/trt_llm_model-1/" + test_data["STARCODER1-15B-base"]["model_dir"] = "/tmp/STARCODER1-15B-base/trt_llm_model-1/" test_data["STARCODER1-15B-base"]["checkpoint"] = "/opt/checkpoints/STARCODER1-15B-base/STARCODER1-15B-base-1.nemo" test_data["STARCODER1-15B-base"]["prompt_template"] = ["def fibonnaci(n"] test_data["STARCODER1-15B-base"]["expected_keyword"] = ["fibonnaci"] @@ -371,9 +369,9 @@ def get_infer_test_data(): test_data["GEMMA-base"] = {} test_data["GEMMA-base"]["model_type"] = "gemma" - test_data["GEMMA-base"]["min_gpus"] = 1 + test_data["GEMMA-base"]["min_tps"] = 1 test_data["GEMMA-base"]["location"] = "Local" - test_data["GEMMA-base"]["trt_llm_model_dir"] = "/tmp/GEMMA-base/trt_llm_model-1/" + test_data["GEMMA-base"]["model_dir"] = "/tmp/GEMMA-base/trt_llm_model-1/" test_data["GEMMA-base"]["checkpoint"] = "/opt/checkpoints/GEMMA-base/GEMMA-base-1.nemo" test_data["GEMMA-base"]["prompt_template"] = [ "The capital of France is", diff --git a/tests/lightning/fabric/__init__.py b/tests/lightning/fabric/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/tests/lightning/fabric/test_conversion.py b/tests/lightning/fabric/test_conversion.py new file mode 100644 index 0000000000000..53d8d1a2dd49d --- /dev/null +++ b/tests/lightning/fabric/test_conversion.py @@ -0,0 +1,76 @@ +import pytest +from lightning_fabric import plugins as fl_plugins +from lightning_fabric import strategies as fl_strategies +from pytorch_lightning import plugins as pl_plugins +from pytorch_lightning import strategies as pl_strategies + +from nemo import lightning as nl +from nemo.lightning.fabric.conversion import to_fabric + + +class TestConversion: + def test_ddp_strategy_conversion(self): + pl_strategy = pl_strategies.DDPStrategy() + fabric_strategy = to_fabric(pl_strategy) + + assert isinstance(fabric_strategy, fl_strategies.DDPStrategy) + + def test_fsdp_strategy_conversion(self): + pl_strategy = pl_strategies.FSDPStrategy( + cpu_offload=True, + ) + fabric_strategy = to_fabric(pl_strategy) + + assert isinstance(fabric_strategy, fl_strategies.FSDPStrategy) + assert fabric_strategy.cpu_offload.offload_params is True + + def test_mixed_precision_plugin_conversion(self): + pl_plugin = pl_plugins.MixedPrecision(precision='16-mixed', device='cpu') + fabric_plugin = to_fabric(pl_plugin) + + assert isinstance(fabric_plugin, fl_plugins.MixedPrecision) + assert fabric_plugin.precision == '16-mixed' + + def test_fsdp_precision_plugin_conversion(self): + pl_plugin = pl_plugins.FSDPPrecision(precision='16-mixed') + fabric_plugin = to_fabric(pl_plugin) + + assert isinstance(fabric_plugin, fl_plugins.FSDPPrecision) + assert fabric_plugin.precision == '16-mixed' + + def test_unsupported_object_conversion(self): + class UnsupportedObject: + pass + + with pytest.raises(NotImplementedError) as excinfo: + to_fabric(UnsupportedObject()) + + assert "No Fabric converter registered for UnsupportedObject" in str(excinfo.value) + + def test_megatron_strategy_conversion(self): + pl_strategy = nl.MegatronStrategy( + tensor_model_parallel_size=2, + pipeline_model_parallel_size=2, + virtual_pipeline_model_parallel_size=2, + context_parallel_size=2, + sequence_parallel=True, + expert_model_parallel_size=2, + moe_extended_tp=True, + ) + fabric_strategy = to_fabric(pl_strategy) + + assert isinstance(fabric_strategy, nl.FabricMegatronStrategy) + assert fabric_strategy.tensor_model_parallel_size == 2 + assert fabric_strategy.pipeline_model_parallel_size == 2 + assert fabric_strategy.virtual_pipeline_model_parallel_size == 2 + assert fabric_strategy.context_parallel_size == 2 + assert fabric_strategy.sequence_parallel is True + assert fabric_strategy.expert_model_parallel_size == 2 + assert fabric_strategy.moe_extended_tp is True + + def test_megatron_precision_conversion(self): + pl_plugin = nl.MegatronMixedPrecision(precision='16-mixed') + fabric_plugin = to_fabric(pl_plugin) + + assert isinstance(fabric_plugin, nl.FabricMegatronMixedPrecision) + assert fabric_plugin.precision == '16-mixed' diff --git a/tests/lightning/io/test_api.py b/tests/lightning/io/test_api.py index 9872d08601936..44e2dd9e2c211 100644 --- a/tests/lightning/io/test_api.py +++ b/tests/lightning/io/test_api.py @@ -1,23 +1,35 @@ +import transformer_engine as te +from pytorch_lightning.loggers import TensorBoardLogger + from nemo import lightning as nl from nemo.collections import llm +from nemo.collections.nlp.modules.common.tokenizer_utils import get_nmt_tokenizer from nemo.lightning import io class TestLoad: def test_reload_ckpt(self, tmpdir): - trainer = nl.Trainer(devices=1, accelerator="cpu", strategy=nl.MegatronStrategy()) - # model = llm.Mistral7BModel() + trainer = nl.Trainer( + devices=1, + accelerator="cpu", + strategy=nl.MegatronStrategy(), + logger=TensorBoardLogger("tb_logs", name="my_model"), + ) + tokenizer = get_nmt_tokenizer("megatron", "GPT2BPETokenizer") model = llm.GPTModel( llm.GPTConfig( num_layers=2, hidden_size=1024, ffn_hidden_size=4096, num_attention_heads=8, - ) + ), + tokenizer=tokenizer, ) - ckpt = io.TrainerCheckpoint(model, trainer) + ckpt = io.TrainerContext(model, trainer) ckpt.io_dump(tmpdir) - loaded = io.load_ckpt(tmpdir) + loaded = io.load_context(tmpdir) assert loaded.model.config.seq_length == ckpt.model.config.seq_length + assert loaded.model.__io__.tokenizer.vocab_file.startswith(str(tmpdir)) + assert loaded.model.__io__.tokenizer.merges_file.startswith(str(tmpdir)) diff --git a/tests/lightning/pytorch/__init__.py b/tests/lightning/pytorch/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/tests/lightning/pytorch/callbacks/__init__.py b/tests/lightning/pytorch/callbacks/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/tests/lightning/pytorch/callbacks/test_model_transform.py b/tests/lightning/pytorch/callbacks/test_model_transform.py new file mode 100644 index 0000000000000..9894f7d7bc585 --- /dev/null +++ b/tests/lightning/pytorch/callbacks/test_model_transform.py @@ -0,0 +1,48 @@ +import pytest +import pytorch_lightning as pl +from torch import nn + +from nemo.lightning.pytorch.callbacks.model_transform import ModelTransform + + +class TestModelTransformCallback: + @pytest.fixture + def callback(self): + return ModelTransform() + + @pytest.fixture + def pl_module(self): + return MockLightningModule() + + @pytest.fixture + def trainer(self): + return pl.Trainer() + + def test_setup_stores_transform(self, callback, pl_module, trainer, caplog): + callback.setup(trainer, pl_module, 'fit') + + assert callback.model_transform is not None, "callback.model_transform should be set after setup" + assert hasattr( + callback.model_transform, '__num_calls__' + ), "callback.model_transform should have __num_calls__ attribute" + assert callback.model_transform.__num_calls__ == 0, "callback.model_transform should not have been called yet" + assert pl_module.model_transform == callback.model_transform, "pl_module.model_transform should be updated" + + +class MockModel(nn.Module): + def __init__(self): + super().__init__() + self.linear = nn.Linear(10, 10) + + def forward(self, x): + return self.linear(x) + + +class MockLightningModule(pl.LightningModule): + def __init__(self): + super().__init__() + self.model = MockModel() + self.model_transform = lambda m: nn.Sequential(m, nn.ReLU()) + + def forward(self, x): + return self.model(x) diff --git a/tests/lightning/pytorch/callbacks/test_nsys.py b/tests/lightning/pytorch/callbacks/test_nsys.py new file mode 100644 index 0000000000000..e8734ad1c1ac9 --- /dev/null +++ b/tests/lightning/pytorch/callbacks/test_nsys.py @@ -0,0 +1,195 @@ +from unittest.mock import MagicMock, patch + +import pytest +import torch +from nemo.lightning.pytorch.callbacks.nsys import NsysCallback + + +class TestNsysCallback: + @pytest.fixture(autouse=True) + def setup_mocks(self): + self.cuda_mock = patch('torch.cuda') + self.cudart_mock = patch('torch.cuda.cudart') + self.emit_nvtx_mock = patch('torch.autograd.profiler.emit_nvtx') + self.get_rank_mock = patch('nemo.lightning.pytorch.callbacks.nsys.get_rank') + + self.cuda_mock.start() + self.cudart_mock.start() + self.emit_nvtx_mock.start() + self.get_rank_mock.start() + + # Mock CUDA availability + torch.cuda.is_available = MagicMock(return_value=True) + torch.cuda.current_device = MagicMock(return_value=0) + + yield + + self.cuda_mock.stop() + self.cudart_mock.stop() + self.emit_nvtx_mock.stop() + self.get_rank_mock.stop() + + @pytest.fixture + def mock_trainer(self): + trainer = MagicMock() + trainer.strategy.root_device.type = 'cuda' + return trainer + + @pytest.fixture + def mock_pl_module(self): + return MagicMock() + + def test_init_valid_params(self): + """Test initialization with valid parameters.""" + callback = NsysCallback(start_step=10, end_step=20, ranks=[0, 1], gen_shape=True) + assert callback._nsys_profile_start_step == 10 + assert callback._nsys_profile_end_step == 20 + assert callback._nsys_profile_ranks == [0, 1] + assert callback._nsys_profile_gen_shape == True + + def test_init_invalid_params(self): + """Test initialization with invalid parameters.""" + with pytest.raises(AssertionError): + NsysCallback(start_step='10', end_step=20) + + with pytest.raises(AssertionError): + NsysCallback(start_step=10, end_step='20') + + with pytest.raises(AssertionError): + NsysCallback(start_step=20, end_step=10) + + @patch('nemo.lightning.pytorch.callbacks.nsys.get_rank') + @patch('torch.cuda.cudart') + @patch('torch.autograd.profiler.emit_nvtx') + def test_on_train_batch_start_profiling( + self, mock_emit_nvtx, mock_cudart, mock_get_rank, mock_trainer, mock_pl_module + ): + """Test on_train_batch_start when profiling should start.""" + mock_get_rank.return_value = 0 + callback = NsysCallback(start_step=10, end_step=20, ranks=[0], gen_shape=True) + + callback.on_train_batch_start(mock_trainer, mock_pl_module, None, 10) + + mock_cudart().cudaProfilerStart.assert_called_once() + mock_emit_nvtx.assert_called_once_with(record_shapes=True) + + @patch('nemo.lightning.pytorch.callbacks.nsys.get_rank') + @patch('torch.cuda.cudart') + def test_on_train_batch_start_no_profiling(self, mock_cudart, mock_get_rank, mock_trainer, mock_pl_module): + """Test on_train_batch_start when profiling should not start.""" + mock_get_rank.return_value = 0 + callback = NsysCallback(start_step=10, end_step=20, ranks=[0]) + + callback.on_train_batch_start(mock_trainer, mock_pl_module, None, 9) + + mock_cudart().cudaProfilerStart.assert_not_called() + + @patch('nemo.lightning.pytorch.callbacks.nsys.get_rank') + @patch('torch.cuda.cudart') + @patch('torch.autograd.profiler.emit_nvtx') + def test_on_train_batch_end_profiling( + self, mock_emit_nvtx, mock_cudart, mock_get_rank, mock_trainer, mock_pl_module + ): + """Test on_train_batch_end when profiling should end.""" + mock_get_rank.return_value = 0 + callback = NsysCallback(start_step=10, end_step=20, ranks=[0]) + + callback.on_train_batch_end(mock_trainer, mock_pl_module, None, None, 20) + + mock_cudart().cudaProfilerStop.assert_called_once() + + @patch('nemo.lightning.pytorch.callbacks.nsys.get_rank') + @patch('torch.cuda.cudart') + @patch('torch.autograd.profiler.emit_nvtx') + def test_on_train_batch_end_no_profiling( + self, mock_emit_nvtx, mock_cudart, mock_get_rank, mock_trainer, mock_pl_module + ): + """Test on_train_batch_end when profiling should not end.""" + mock_get_rank.return_value = 0 + callback = NsysCallback(start_step=10, end_step=20, ranks=[0]) + + callback.on_train_batch_end(mock_trainer, mock_pl_module, None, None, 19) + + mock_cudart().cudaProfilerStop.assert_not_called() + + def test_non_cuda_device(self, mock_trainer, mock_pl_module): + """Test behavior when the device is not CUDA.""" + mock_trainer.strategy.root_device.type = 'cpu' + callback = NsysCallback(start_step=10, end_step=20, ranks=[0]) + + callback.on_train_batch_start(mock_trainer, mock_pl_module, None, 10) + callback.on_train_batch_end(mock_trainer, mock_pl_module, None, None, 20) + + # No exceptions should be raised, and no profiling calls should be made + + @patch('nemo.lightning.pytorch.callbacks.nsys.get_rank') + def test_rank_not_in_profile_ranks(self, mock_get_rank, mock_trainer, mock_pl_module): + """Test behavior when the current rank is not in the profile ranks.""" + mock_get_rank.return_value = 1 + callback = NsysCallback(start_step=10, end_step=20, ranks=[0]) + callback = NsysCallback(start_step=10, end_step=20, ranks=[0]) + + callback.on_train_batch_start(mock_trainer, mock_pl_module, None, 10) + callback.on_train_batch_end(mock_trainer, mock_pl_module, None, None, 20) + + # No profiling calls should be made + + @pytest.mark.parametrize( + "start_step,end_step,batch_idx,expected_call", + [ + (10, 20, 9, False), + (10, 20, 10, True), + (10, 20, 15, False), + (10, 20, 20, False), + (10, 20, 21, False), + ], + ) + @patch('nemo.lightning.pytorch.callbacks.nsys.get_rank') + @patch('torch.cuda.cudart') + @patch('torch.autograd.profiler.emit_nvtx') + def test_profiling_range( + self, + mock_emit_nvtx, + mock_cudart, + mock_get_rank, + start_step, + end_step, + batch_idx, + expected_call, + mock_trainer, + mock_pl_module, + ): + """Test profiling behavior across different batch indices.""" + mock_get_rank.return_value = 0 + callback = NsysCallback(start_step=start_step, end_step=end_step, ranks=[0]) + + callback.on_train_batch_start(mock_trainer, mock_pl_module, None, batch_idx) + + if expected_call: + mock_cudart().cudaProfilerStart.assert_called_once() + mock_emit_nvtx.assert_called_once() + else: + mock_cudart().cudaProfilerStart.assert_not_called() + mock_emit_nvtx.assert_not_called() + + @patch('nemo.lightning.pytorch.callbacks.nsys.get_rank') + @patch('torch.cuda.cudart') + def test_single_profile_range(self, mock_cudart, mock_get_rank, mock_trainer, mock_pl_module): + """Test behavior with a single profile range.""" + mock_get_rank.return_value = 0 + callback = NsysCallback(start_step=10, end_step=40, ranks=[0]) + + # Ensure the device type is 'cuda' + mock_trainer.strategy.root_device.type = 'cuda' + + # Start of range + callback.on_train_batch_start(mock_trainer, mock_pl_module, None, 10) + assert mock_cudart().cudaProfilerStart.call_count == 1, "cudaProfilerStart was not called" + + # Middle of range + callback.on_train_batch_start(mock_trainer, mock_pl_module, None, 25) + assert mock_cudart().cudaProfilerStart.call_count == 1, "cudaProfilerStart was called again" + + # End of range + callback.on_train_batch_end(mock_trainer, mock_pl_module, None, None, 40) + assert mock_cudart().cudaProfilerStop.call_count == 1, "cudaProfilerStop was not called" diff --git a/tests/lightning/pytorch/callbacks/test_peft.py b/tests/lightning/pytorch/callbacks/test_peft.py new file mode 100644 index 0000000000000..81dc7f85bc086 --- /dev/null +++ b/tests/lightning/pytorch/callbacks/test_peft.py @@ -0,0 +1,68 @@ +from unittest.mock import MagicMock, patch + +import torch.nn as nn +from nemo.collections.llm import fn +from nemo.lightning.pytorch.callbacks.peft import PEFT, WrappedAdapterIO + + +class TestPEFT: + class DummyPEFT(PEFT): + def transform(self, module, name=None, prefix=None): + return module # No-op transform for testing + + class DummyModel(nn.Module, fn.FNMixin): + def __init__(self): + super().__init__() + self.linear = nn.Linear(10, 10) + self.conv = nn.Conv2d(3, 3, 3) + + def test_peft_call(self): + model = self.DummyModel() + peft = self.DummyPEFT() + + transformed_model = peft(model) + + assert transformed_model.linear.weight.requires_grad == False + assert transformed_model.conv.weight.requires_grad == False + + def test_peft_setup(self): + peft = self.DummyPEFT() + trainer = MagicMock() + pl_module = MagicMock() + + pl_module.model_transform = peft + peft.setup(trainer, pl_module, "fit") + + assert isinstance(trainer.strategy._checkpoint_io, WrappedAdapterIO) + assert peft.model_transform is not None + assert peft._needs_to_call is True + + @patch('nemo.lightning.pytorch.callbacks.peft.logging') + def test_peft_on_train_epoch_start_with_adapter(self, mock_logging): + peft = self.DummyPEFT() + trainer = MagicMock() + pl_module = MagicMock() + pl_module.model_transform = peft + + peft.setup(trainer, pl_module, "fit") + + assert peft.model_transform is not None + assert peft._needs_to_call is True + + peft.wrapped_io = MagicMock() + peft.wrapped_io.adapter_ckpt_path = "dummy_path" + peft.wrapped_io.load_checkpoint.return_value = {"dummy_state": "dummy_value"} + peft.on_train_epoch_start(trainer, pl_module) + + mock_logging.info.assert_called_once_with("Loading adapters from dummy_path") + trainer.strategy.load_model_state_dict.assert_called_once_with({"dummy_state": "dummy_value"}, strict=False) + + def test_peft_on_load_checkpoint(self): + peft = self.DummyPEFT() + trainer = MagicMock() + pl_module = MagicMock() + checkpoint = {} + + peft.on_load_checkpoint(trainer, pl_module, checkpoint) + + assert pl_module.strict_loading == False diff --git a/tests/lightning/pytorch/callbacks/test_preemption.py b/tests/lightning/pytorch/callbacks/test_preemption.py new file mode 100644 index 0000000000000..5fcb4a1458eef --- /dev/null +++ b/tests/lightning/pytorch/callbacks/test_preemption.py @@ -0,0 +1,114 @@ +import logging +import signal +from unittest.mock import MagicMock, PropertyMock, patch + +import pytest +import torch +from pytorch_lightning import Trainer + +from nemo.lightning.pytorch.callbacks.preemption import PreemptionCallback, PreemptionException + + +class TestPreemptionCallback: + + @pytest.fixture + def callback(self): + return PreemptionCallback() + + @pytest.fixture + def mock_trainer(self): + trainer = MagicMock(spec=Trainer) + trainer.should_stop = False + return trainer + + def test_init(self, callback): + assert callback.sig == signal.SIGTERM + assert not callback._interrupted + assert callback._handler_context is None + + def test_custom_signal(self): + custom_callback = PreemptionCallback(sig=signal.SIGUSR1) + assert custom_callback.sig == signal.SIGUSR1 + + @pytest.mark.parametrize("initially_supported,becomes_supported", [(False, True), (False, False), (True, True)]) + def test_on_train_batch_start_distributed_init( + self, callback, mock_trainer, initially_supported, becomes_supported + ): + with ( + patch.object(PreemptionCallback, '_check_preemption_support') as mock_check, + patch.object(callback, '_preemption_handler') as mock_handler, + ): + + mock_check.side_effect = [initially_supported, becomes_supported] + + callback.on_train_start(mock_trainer, None) + callback.on_train_batch_start(mock_trainer, None, None, 0) + + expected_call_count = 1 if initially_supported else (1 if becomes_supported else 0) + assert mock_handler.call_count == expected_call_count + + if initially_supported: + mock_handler.assert_called_once_with() + elif becomes_supported: + mock_handler.assert_called_once_with() + else: + mock_handler.assert_not_called() + + @pytest.mark.parametrize( + "is_supported,interrupted,expected", + [ + (True, True, True), + (True, False, False), + (False, True, False), + (False, False, False), + ], + ) + def test_interrupted_property(self, callback, is_supported, interrupted, expected): + with ( + patch.object(PreemptionCallback, '_check_preemption_support', return_value=is_supported), + patch('torch.distributed.broadcast'), + patch('torch.tensor', return_value=torch.tensor(interrupted)), + patch('torch.cuda.is_available', return_value=True), + patch('torch.cuda.current_device', return_value=0), + ): + callback._interrupted = interrupted + assert callback.interrupted == expected + + def test_on_train_start(self, callback, mock_trainer): + with ( + patch.object(PreemptionCallback, 'preemption_supported', new_callable=PropertyMock) as mock_supported, + patch.object(callback, '_preemption_handler') as mock_handler, + ): + + # Test when preemption is supported + mock_supported.return_value = True + callback.on_train_start(mock_trainer, None) + mock_handler.assert_called_once() + mock_handler.reset_mock() + + # Test when preemption is not supported + mock_supported.return_value = False + callback.on_train_start(mock_trainer, None) + mock_handler.assert_not_called() + + def test_on_train_end(self, callback, mock_trainer): + mock_context = MagicMock() + callback._handler_context = mock_context + callback.on_train_end(mock_trainer, None) + mock_context.__exit__.assert_called_once_with(None, None, None) + + @pytest.mark.parametrize("interrupted", [True, False]) + def test_on_train_batch_end(self, callback, mock_trainer, interrupted): + with patch.object(PreemptionCallback, 'interrupted', new_callable=lambda: property(lambda self: interrupted)): + callback.on_train_batch_end(mock_trainer, None, None, None, 0) + assert mock_trainer.should_stop == interrupted + + def test_on_exception_preemption(self, callback, mock_trainer): + exception = PreemptionException("Test preemption") + callback.on_exception(mock_trainer, None, exception) + assert mock_trainer.should_stop + + def test_on_exception_other(self, callback, mock_trainer): + exception = ValueError("Some other exception") + callback.on_exception(mock_trainer, None, exception) + assert not mock_trainer.should_stop diff --git a/tests/lightning/pytorch/test_trainer.py b/tests/lightning/pytorch/test_trainer.py new file mode 100644 index 0000000000000..65c247eae0ef3 --- /dev/null +++ b/tests/lightning/pytorch/test_trainer.py @@ -0,0 +1,18 @@ +from nemo import lightning as nl + + +class TestFabricConversion: + def test_simple_conversion(self): + trainer = nl.Trainer( + devices=1, + accelerator="cpu", + strategy=nl.MegatronStrategy(tensor_model_parallel_size=2), + plugins=nl.MegatronMixedPrecision(precision='16-mixed'), + ) + + fabric = trainer.to_fabric() + + assert isinstance(fabric.strategy, nl.FabricMegatronStrategy) + assert fabric.strategy.tensor_model_parallel_size == 2 + assert isinstance(fabric._precision, nl.FabricMegatronMixedPrecision) + assert fabric._precision.precision == '16-mixed' diff --git a/tests/lightning/test_megatron_parallel.py b/tests/lightning/test_megatron_parallel.py index fafd25e49f5af..e504c7eb5c7cd 100644 --- a/tests/lightning/test_megatron_parallel.py +++ b/tests/lightning/test_megatron_parallel.py @@ -1,4 +1,5 @@ from collections import defaultdict +from unittest.mock import MagicMock import pytest from megatron.core import parallel_state @@ -123,13 +124,14 @@ def test_add_callbacks(self) -> None: assert callback in callback_connector.callbacks["on_megatron_step_start"] assert callback in callback_connector.callbacks["on_megatron_microbatch_start"] - def test_event(self, mocker) -> None: + def test_event(self) -> None: callback_connector = mp.CallbackConnector() callback = TestCallback() callback_connector.add(callback) - mocker.spy(callback, "on_megatron_step_start") - mocker.spy(callback, "on_megatron_microbatch_start") + # Replace mocker.spy with manual mocking + callback.on_megatron_step_start = MagicMock() + callback.on_megatron_microbatch_start = MagicMock() callback_connector.event("on_megatron_step_start") callback_connector.event("on_megatron_microbatch_start") diff --git a/tests/lightning/test_nemo_logger.py b/tests/lightning/test_nemo_logger.py new file mode 100644 index 0000000000000..0dd49838d9e4f --- /dev/null +++ b/tests/lightning/test_nemo_logger.py @@ -0,0 +1,60 @@ +from unittest.mock import patch + +import pytest +from pytorch_lightning.callbacks import ModelCheckpoint as PTLModelCheckpoint +from pytorch_lightning.loggers import WandbLogger + +from nemo import lightning as nl + + +class TestNeMoLogger: + @pytest.fixture + def trainer(self): + return nl.Trainer(accelerator="cpu") + + def test_loggers(self): + trainer = nl.Trainer(accelerator="cpu") + logger = nl.NeMoLogger( + update_logger_directory=True, + wandb=WandbLogger(save_dir="test", offline=True), + ) + + logger.setup(trainer) + assert logger.tensorboard is None + assert len(logger.extra_loggers) == 0 + assert len(trainer.loggers) == 2 + assert isinstance(trainer.loggers[1], WandbLogger) + assert str(trainer.loggers[1].save_dir).endswith("nemo_experiments") + assert trainer.loggers[1]._name == "default" + + def test_explicit_log_dir(self, trainer): + explicit_dir = "explicit_test_dir" + logger = nl.NeMoLogger(name="test", explicit_log_dir=explicit_dir) + + with patch("nemo.utils.exp_manager.check_explicit_log_dir") as mock_check: + logger.setup(trainer) + mock_check.assert_called_once_with(trainer, explicit_dir, None, "test", None) + + def test_custom_version(self, trainer): + custom_version = "v1.0" + logger = nl.NeMoLogger(name="test", version=custom_version, use_datetime_version=False) + + app_state = logger.setup(trainer) + assert app_state.version == custom_version + + def test_file_logging_setup(self, trainer): + logger = nl.NeMoLogger(name="test") + + with patch("nemo.lightning.nemo_logger.logging.add_file_handler") as mock_add_handler: + logger.setup(trainer) + mock_add_handler.assert_called_once() + + def test_model_checkpoint_setup(self, trainer): + ckpt = PTLModelCheckpoint(dirpath="test_ckpt", filename="test-{epoch:02d}-{val_loss:.2f}") + logger = nl.NeMoLogger(name="test", ckpt=ckpt) + + logger.setup(trainer) + assert any(isinstance(cb, PTLModelCheckpoint) for cb in trainer.callbacks) + ptl_ckpt = next(cb for cb in trainer.callbacks if isinstance(cb, PTLModelCheckpoint)) + assert str(ptl_ckpt.dirpath).endswith("test_ckpt") + assert ptl_ckpt.filename == "test-{epoch:02d}-{val_loss:.2f}" diff --git a/tools/rir_corpus_generator/rir_corpus_generator.py b/tools/rir_corpus_generator/rir_corpus_generator.py index d6e153ab3959d..e3f1e05a70f08 100644 --- a/tools/rir_corpus_generator/rir_corpus_generator.py +++ b/tools/rir_corpus_generator/rir_corpus_generator.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from nemo.collections.asr.data.data_simulation import RIRCorpusGenerator +from nemo.collections.audio.data.data_simulation import RIRCorpusGenerator from nemo.core.config import hydra_runner diff --git a/tools/rir_corpus_generator/rir_mix_generator.py b/tools/rir_corpus_generator/rir_mix_generator.py index 170c0285e86d1..a1e2856f94c4f 100644 --- a/tools/rir_corpus_generator/rir_mix_generator.py +++ b/tools/rir_corpus_generator/rir_mix_generator.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from nemo.collections.asr.data.data_simulation import RIRMixGenerator +from nemo.collections.audio.data.data_simulation import RIRMixGenerator from nemo.core.config import hydra_runner diff --git a/tutorials/audio_tasks/README.md b/tutorials/audio/README.md similarity index 100% rename from tutorials/audio_tasks/README.md rename to tutorials/audio/README.md diff --git a/tutorials/audio_tasks/speech_enhancement/Speech_Enhancement_with_NeMo.ipynb b/tutorials/audio/speech_enhancement/Speech_Enhancement_with_NeMo.ipynb similarity index 98% rename from tutorials/audio_tasks/speech_enhancement/Speech_Enhancement_with_NeMo.ipynb rename to tutorials/audio/speech_enhancement/Speech_Enhancement_with_NeMo.ipynb index 535d67921e23e..ffd630824bdbe 100644 --- a/tutorials/audio_tasks/speech_enhancement/Speech_Enhancement_with_NeMo.ipynb +++ b/tutorials/audio/speech_enhancement/Speech_Enhancement_with_NeMo.ipynb @@ -494,7 +494,7 @@ "config_path = config_dir / 'masking.yaml'\n", "\n", "if not config_path.is_file():\n", - " !wget https://raw.githubusercontent.com/{GIT_USER}/NeMo/{GIT_BRANCH}/examples/audio_tasks/conf/masking.yaml -P {config_dir.as_posix()}\n", + " !wget https://raw.githubusercontent.com/{GIT_USER}/NeMo/{GIT_BRANCH}/examples/audio/conf/masking.yaml -P {config_dir.as_posix()}\n", "\n", "config = OmegaConf.load(config_path)\n", "config = OmegaConf.to_container(config, resolve=True)\n", @@ -717,9 +717,9 @@ }, "outputs": [], "source": [ - "from nemo.collections import asr as nemo_asr\n", + "from nemo.collections import audio as nemo_audio\n", "\n", - "enhancement_model = nemo_asr.models.EncMaskDecAudioToAudioModel(cfg=config.model, trainer=trainer)" + "enhancement_model = nemo_audio.models.EncMaskDecAudioToAudioModel(cfg=config.model, trainer=trainer)" ] }, { @@ -905,7 +905,7 @@ }, "outputs": [], "source": [ - "from nemo.collections.asr.parts.utils.audio_utils import db2mag\n", + "from nemo.collections.audio.parts.utils.audio import db2mag\n", "\n", "# Limit suppression to 10dB\n", "min_mask_db = -10\n", @@ -1064,7 +1064,7 @@ "# Add a mixture consistency projection\n", "with open_dict(config_dual_output):\n", " config_dual_output.model.mixture_consistency = OmegaConf.create({\n", - " '_target_': 'nemo.collections.asr.modules.audio_modules.MixtureConsistencyProjection',\n", + " '_target_': 'nemo.collections.audio.modules.projections.MixtureConsistencyProjection',\n", " 'weighting': 'power',\n", " })" ] @@ -1172,7 +1172,7 @@ }, "outputs": [], "source": [ - "dual_output_model = nemo_asr.models.EncMaskDecAudioToAudioModel(cfg=config_dual_output.model, trainer=trainer)\n", + "dual_output_model = nemo_audio.models.EncMaskDecAudioToAudioModel(cfg=config_dual_output.model, trainer=trainer)\n", "trainer.fit(dual_output_model)" ] }, @@ -1288,6 +1288,12 @@ } ], "metadata": { + "accelerator": "GPU", + "colab": { + "gpuType": "T4", + "provenance": [] + }, + "gpuClass": "standard", "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", @@ -1304,13 +1310,7 @@ "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.10" - }, - "colab": { - "provenance": [], - "gpuType": "T4" - }, - "accelerator": "GPU", - "gpuClass": "standard" + } }, "nbformat": 4, "nbformat_minor": 5 diff --git a/tutorials/llm/mamba/mamba.rst b/tutorials/llm/mamba/mamba.rst new file mode 100644 index 0000000000000..c09a6ae03087a --- /dev/null +++ b/tutorials/llm/mamba/mamba.rst @@ -0,0 +1,301 @@ +Mamba2 and Mamba2-Transformer Hybrid Models Fine-Tuning +======================================================= + +`State Space Models (SSMs) `__ have recently emerged as a promising alternative to transformers. SSMs offer advantages such as linear time complexity relative to sequence length and a constant cache size for inference. These features enable the processing of longer sequences and higher throughput. Despite these benefits, SSMs alone may fall short compared to transformers on tasks that demand strong copying or in-context learning capabilities. + +To harness the strengths of both approaches, SSM-Hybrid models incorporate MLP, Transformer, and SSM blocks in their architecture. As highlighted in `a study by NVIDIA `__, these hybrid models outperform traditional transformers of the same size by achieving faster inference times due to the inclusion of SSM blocks. Based on experimental results, Mamba2-Hybrid models not only surpass transformer baselines in performance but also benefit from increased computational efficiency. + +The Mamba2 models discussed in the `Transformers are SSMs `__ paper are available in five different sizes: 130 million, 370 million, 780 million, 1.3 billion, and 2.7 billion parameters. The Mamba2-Hybrid models, along with their Mamba2 baseline as released by `NVIDIA `__, are provided in an 8 billion parameter size. + +`Low-Rank Adaptation (LoRA) `__ has emerged as a popular Parameter Efficient Fine-Tuning (PEFT) technique that tunes a very small number of additional parameters as compared to full fine-tuning, thereby reducing the compute required. LoRA tuning can be applied to the linear layers in the Transformer and MLP blocks for the Mamba2-Hybrid models. + +`NVIDIA NeMo +Framework `__ provides tools to perform Fine-tuning on Mamba2 and Mamba2-Hybrid to fit your use case. + +Requirements +------------- + +In order to proceed, ensure that you have met the following requirements: + +* Full Fine-Tuning System Configuration + * Small models (130m, 370m, 780m) + * Access to at least 1 NVIDIA GPU with a cumulative memory of at least 40GB, for example: 1 x A6000-40GB. + + * Mid-size models (1.3b, 2.7b) + * Access to at least 1 NVIDIA GPU with a cumulative memory of at least 80GB, for example: 1 x H100-80GB or 1 x A100-80GB. + + * Large models (8b) + * Access to at least 2 NVIDIA GPUs with a cumulative memory of at least 80GB, for example: 2 x H100-80GB or 2 x A100-80GB. + +* LoRA Fine-Tuning (Mamba2-Hybrid only) System Configuration + * Access to at least 1 NVIDIA GPU with a cumulative memory of at least 80GB, for example: 1 x H100-80GB or 1 x A100-80GB. + + + +* A Docker-enabled environment, with `NVIDIA Container Runtime `_ installed, which will make the container GPU-aware. + + +* `Authenticate with NVIDIA NGC `_, and download `NGC CLI Tool `_. + + +Step-by-step Guide for Fine-Tuning +---------------------------------- + +Checkpoints from HuggingFace +^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +Obtain the desired checkpoint from HuggigFace. + +* `Repository `__ for the Mamba2 models from the `Transformers are SSMs paper `__. +* `Repository `__ for the Mamba2 and Mamba2-Hybrid models by `NVIDIA `__. + + +Convert the Pytorch Checkpoint to a NeMo Checkpoint +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +1. Get into NVIDIA Container + +2. Run the conversion script from . For this conversion script, you should provide the PyTorch state dictionary of the model for ``input_name_or_path``, i.e. this argument only accepts a single ``state_dict``. + +.. code:: bash + + CUDA_VISIBLE_DEVICES="0" python /NeMo/scripts/checkpoint_converters/convert_mamba2_pyt_to_nemo.py \ + --input_name_or_path \ + --output_path \ + --ngroups_mamba 8 \ + --precision bf16 + +* Note: the ``ngroups_mamba`` parameter should be 1 for the Mamba2 models from the `Transformers are SSMs paper `__ (130m, 370m, 780m, 1.3b, and 2.7b) and 8 for the Mamba2 and Mamba2-Hybrid models by `NVIDIA `__ (both 8b). + +Model (Tensor) Parallelism for the 8b Models +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +* Note: Distributed checkpointing for the Mamba2 and Mamba2-Hybrid models will be implemented in the near future. For now, you should use the method below for converting to Tensor Parallel (TP) of different sizes. + +The HuggingFace checkpoint for the 8b model is for TP of size 1, and so is the ``.nemo`` checkpoint obtained for the previous step. To shard the model weights for a larger TP size, use the script from