From ec662e4e9f8e5f4b0b8c17196acd4a5be096b7e7 Mon Sep 17 00:00:00 2001 From: John St John Date: Tue, 4 Mar 2025 01:20:12 +0000 Subject: [PATCH] Update the brca notebook with a run on an fp8 supporting machine --- .../examples/bionemo-evo2/.gitignore | 2 + .../bionemo-evo2/evo2_zeroshot_brca.ipynb | 985 +++++++++--------- 2 files changed, 512 insertions(+), 475 deletions(-) diff --git a/docs/docs/user-guide/examples/bionemo-evo2/.gitignore b/docs/docs/user-guide/examples/bionemo-evo2/.gitignore index fa7159094..465fe2329 100644 --- a/docs/docs/user-guide/examples/bionemo-evo2/.gitignore +++ b/docs/docs/user-guide/examples/bionemo-evo2/.gitignore @@ -10,3 +10,5 @@ nemo2_evo2_1b_8k/ preprocessed_data/ pretraining_demo/ +brca1_fasta_files/ +brca1/ diff --git a/docs/docs/user-guide/examples/bionemo-evo2/evo2_zeroshot_brca.ipynb b/docs/docs/user-guide/examples/bionemo-evo2/evo2_zeroshot_brca.ipynb index 02e1bbc20..7574ebc5d 100644 --- a/docs/docs/user-guide/examples/bionemo-evo2/evo2_zeroshot_brca.ipynb +++ b/docs/docs/user-guide/examples/bionemo-evo2/evo2_zeroshot_brca.ipynb @@ -13,19 +13,19 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 1, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "\u001b[33mDEPRECATION: Loading egg at /usr/local/lib/python3.12/dist-packages/opt_einsum-3.4.0-py3.12.egg is deprecated. pip 25.1 will enforce this behaviour change. A possible replacement is to use pip for package installation. Discussion can be found at https://github.com/pypa/pip/issues/12330\u001b[0m\u001b[33m\n", - "\u001b[0m\u001b[33mDEPRECATION: Loading egg at /usr/local/lib/python3.12/dist-packages/nvfuser-0.2.23a0+6627725-py3.12-linux-x86_64.egg is deprecated. pip 25.1 will enforce this behaviour change. A possible replacement is to use pip for package installation. 
Discussion can be found at https://github.com/pypa/pip/issues/12330\u001b[0m\u001b[33m\n", - "\u001b[0m\u001b[33mDEPRECATION: Loading egg at /usr/local/lib/python3.12/dist-packages/lightning_utilities-0.12.0.dev0-py3.12.egg is deprecated. pip 25.1 will enforce this behaviour change. A possible replacement is to use pip for package installation. Discussion can be found at https://github.com/pypa/pip/issues/12330\u001b[0m\u001b[33m\n", - "\u001b[0m\u001b[33mDEPRECATION: Loading egg at /usr/local/lib/python3.12/dist-packages/dill-0.3.9-py3.12.egg is deprecated. pip 25.1 will enforce this behaviour change. A possible replacement is to use pip for package installation. Discussion can be found at https://github.com/pypa/pip/issues/12330\u001b[0m\u001b[33m\n", + "\u001b[33mDEPRECATION: Loading egg at /usr/local/lib/python3.12/dist-packages/lightning_utilities-0.12.0.dev0-py3.12.egg is deprecated. pip 25.1 will enforce this behaviour change. A possible replacement is to use pip for package installation. Discussion can be found at https://github.com/pypa/pip/issues/12330\u001b[0m\u001b[33m\n", "\u001b[0m\u001b[33mDEPRECATION: Loading egg at /usr/local/lib/python3.12/dist-packages/looseversion-1.3.0-py3.12.egg is deprecated. pip 25.1 will enforce this behaviour change. A possible replacement is to use pip for package installation. Discussion can be found at https://github.com/pypa/pip/issues/12330\u001b[0m\u001b[33m\n", + "\u001b[0m\u001b[33mDEPRECATION: Loading egg at /usr/local/lib/python3.12/dist-packages/nvfuser-0.2.23a0+6627725-py3.12-linux-x86_64.egg is deprecated. pip 25.1 will enforce this behaviour change. A possible replacement is to use pip for package installation. Discussion can be found at https://github.com/pypa/pip/issues/12330\u001b[0m\u001b[33m\n", + "\u001b[0m\u001b[33mDEPRECATION: Loading egg at /usr/local/lib/python3.12/dist-packages/opt_einsum-3.4.0-py3.12.egg is deprecated. pip 25.1 will enforce this behaviour change. 
A possible replacement is to use pip for package installation. Discussion can be found at https://github.com/pypa/pip/issues/12330\u001b[0m\u001b[33m\n", "\u001b[0m\u001b[33mDEPRECATION: Loading egg at /usr/local/lib/python3.12/dist-packages/lightning_thunder-0.2.0.dev0-py3.12.egg is deprecated. pip 25.1 will enforce this behaviour change. A possible replacement is to use pip for package installation. Discussion can be found at https://github.com/pypa/pip/issues/12330\u001b[0m\u001b[33m\n", + "\u001b[0m\u001b[33mDEPRECATION: Loading egg at /usr/local/lib/python3.12/dist-packages/dill-0.3.9-py3.12.egg is deprecated. pip 25.1 will enforce this behaviour change. A possible replacement is to use pip for package installation. Discussion can be found at https://github.com/pypa/pip/issues/12330\u001b[0m\u001b[33m\n", "\u001b[0mLooking in indexes: https://pypi.org/simple, https://pypi.ngc.nvidia.com\n", "Requirement already satisfied: biopython in /usr/local/lib/python3.12/dist-packages (1.85)\n", "Requirement already satisfied: openpyxl in /usr/local/lib/python3.12/dist-packages (3.1.5)\n", @@ -43,7 +43,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 2, "metadata": {}, "outputs": [], "source": [ @@ -71,7 +71,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 3, "metadata": {}, "outputs": [], "source": [ @@ -98,7 +98,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 4, "metadata": {}, "outputs": [ { @@ -239,7 +239,7 @@ "9 17 41276132 A T -0.207552 FUNC/INT" ] }, - "execution_count": 5, + "execution_count": 4, "metadata": {}, "output_type": "execute_result" } @@ -279,7 +279,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 5, "metadata": {}, "outputs": [], "source": [ @@ -316,28 +316,28 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "To make things run faster, we'll just look at a balanced sample of our data." 
+ "To make things run faster, we'll just look at a balanced sample of our data. If you want to run on the full dataset, set `disable_sample=True`." ] }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 6, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "(84, 6)" + "(330, 6)" ] }, - "execution_count": 7, + "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "disable_sample = False\n", - "SAMPLE_FRAC = 0.05\n", + "SAMPLE_FRAC = 0.2\n", "balanced_sample = True\n", "\n", "random_state = 42\n", @@ -365,15 +365,15 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 7, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "Total unique reference sequences: 79\n", - "Total unique variant sequences: 84\n" + "Total unique reference sequences: 296\n", + "Total unique variant sequences: 330\n" ] } ], @@ -447,13 +447,13 @@ "\n", "Then, we load Evo 2 1B model, loading the Evo 2 weights from hugging face.\n", "\n", - "*Note - for better performance, load the 7b model by replacing all occurrences of `1b` below with `7b`.*\n", + "*Note - for better performance, load the 7b model by setting `MODEL_SIZE=\"7b\"` which also works well on GPUs that do not support FP8.*\n", "\n" ] }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 8, "metadata": {}, "outputs": [ { @@ -465,12 +465,13 @@ } ], "source": [ + "MODEL_SIZE = \"1b\" # also try 7b if you have a GPU with more than 32GB of memory\n", "# Define checkpoint path\n", - "checkpoint_path = Path(\"nemo2_evo2_1b_8k\")\n", + "checkpoint_path = Path(f\"nemo2_evo2_{MODEL_SIZE}_8k\")\n", "\n", "# Check if the directory does not exist or is empty\n", "if not checkpoint_path.exists() or not any(checkpoint_path.iterdir()):\n", - " !evo2_convert_to_nemo2 --model-path hf://arcinstitute/savanna_evo2_1b_base --model-size 1b --output-dir nemo2_evo2_1b_8k\n", + " !evo2_convert_to_nemo2 --model-path 
hf://arcinstitute/savanna_evo2_{MODEL_SIZE}_base --model-size {MODEL_SIZE} --output-dir nemo2_evo2_{MODEL_SIZE}_8k\n", "else:\n", " print(\"Checkpoint directory is not empty. Skipping command.\")\n" ] @@ -484,40 +485,80 @@ }, { "cell_type": "code", - "execution_count": 22, + "execution_count": 9, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "FP8 Support: True\n", + "Device: NVIDIA RTX 6000 Ada Generation, Compute Capability: 8.9\n" + ] + } + ], "source": [ "# Define output directories for prediction results\n", "predict_ref_dir = output_dir / \"reference_predictions\"\n", "predict_var_dir = output_dir / \"variant_predictions\"\n", "predict_ref_dir.mkdir(parents=True, exist_ok=True)\n", "predict_var_dir.mkdir(parents=True, exist_ok=True)\n", + "# Check if FP8 is supported on the current GPU\n", + "import torch\n", + "\n", + "def check_fp8_support():\n", + " \"\"\"\n", + " Check FP8 support on CUDA device 0; returns (is_supported, info_message).\n", + " FP8 requires compute capability 8.9+ (Ada Lovelace/Hopper architecture or newer).\n", + " \"\"\"\n", + " if not torch.cuda.is_available():\n", + " return False, \"CUDA not available\"\n", + " \n", + " device_props = torch.cuda.get_device_properties(0)\n", + " compute_capability = f\"{device_props.major}.{device_props.minor}\"\n", + " device_name = device_props.name\n", + " \n", + " # FP8 is supported on compute capability 8.9+ (Ada Lovelace/Hopper architecture)\n", + " is_supported = (device_props.major > 8) or (device_props.major == 8 and device_props.minor >= 9)\n", + " \n", + " return is_supported, f\"Device: {device_name}, Compute Capability: {compute_capability}\"\n", + "\n", + "fp8_supported, gpu_info = check_fp8_support()\n", + "print(f\"FP8 Support: {fp8_supported}\")\n", + "print(gpu_info)\n", + "\n", + "# Note: If FP8 is not supported, you may want to disable it in the model config\n", + "# The Evo2 config has 'use_fp8_input_projections: True' by default\n", + "\n", + 
"fp8_option = \"--fp8\" if fp8_supported else \"\"\n", "\n", "# Update predict commands to run on the full dataset\n", "predict_ref_command = (\n", " f\"predict_evo2 --fasta {ref_fasta_path} --ckpt-dir {checkpoint_path} \"\n", - " f\"--output-dir {predict_ref_dir} --model-size 1b --tensor-parallel-size 1 \"\n", - " \"--pipeline-model-parallel-size 1 --context-parallel-size 1 --output-log-prob-seqs\"\n", + " f\"--output-dir {predict_ref_dir} --model-size {MODEL_SIZE} --tensor-parallel-size 1 \"\n", + " f\"--pipeline-model-parallel-size 1 --context-parallel-size 1 --output-log-prob-seqs {fp8_option}\"\n", ")\n", "\n", "predict_var_command = (\n", " f\"predict_evo2 --fasta {var_fasta_path} --ckpt-dir {checkpoint_path} \"\n", - " f\"--output-dir {predict_var_dir} --model-size 1b --tensor-parallel-size 1 \"\n", - " \"--pipeline-model-parallel-size 1 --context-parallel-size 1 --output-log-prob-seqs\"\n", + " f\"--output-dir {predict_var_dir} --model-size {MODEL_SIZE} --tensor-parallel-size 1 \"\n", + " f\"--pipeline-model-parallel-size 1 --context-parallel-size 1 --output-log-prob-seqs {fp8_option}\"\n", ")" ] }, { "cell_type": "code", - "execution_count": 23, + "execution_count": 10, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "[NeMo W 2025-03-03 23:36:30 nemo_logging:405] /usr/local/lib/python3.12/dist-packages/pydub/utils.py:170: RuntimeWarning: Couldn't find ffmpeg or avconv - defaulting to ffmpeg, but may not work\n", + "Running command: predict_evo2 --fasta brca1_fasta_files/brca1_reference_sequences.fasta --ckpt-dir nemo2_evo2_1b_8k --output-dir brca1_fasta_files/reference_predictions --model-size 1b --tensor-parallel-size 1 --pipeline-model-parallel-size 1 --context-parallel-size 1 --output-log-prob-seqs --fp8\n", + "[WARNING | bitsandbytes.cextension]: Could not find the bitsandbytes CUDA binary at PosixPath('/usr/local/lib/python3.12/dist-packages/bitsandbytes/libbitsandbytes_cuda128.so')\n", + "[WARNING | 
bitsandbytes.cextension]: The installed version of bitsandbytes was compiled without GPU support. 8-bit optimizers, 8-bit multiplication, and GPU quantization are unavailable.\n", + "[NeMo W 2025-03-04 01:01:10 nemo_logging:405] /usr/local/lib/python3.12/dist-packages/pydub/utils.py:170: RuntimeWarning: Couldn't find ffmpeg or avconv - defaulting to ffmpeg, but may not work\n", " warn(\"Couldn't find ffmpeg or avconv - defaulting to ffmpeg, but may not work\", RuntimeWarning)\n", " \n", "[WARNING | py.warnings ]: /usr/local/lib/python3.12/dist-packages/mamba_ssm/ops/selective_scan_interface.py:163: FutureWarning: `torch.cuda.amp.custom_fwd(args...)` is deprecated. Please use `torch.amp.custom_fwd(args..., device_type='cuda')` instead.\n", @@ -547,318 +588,319 @@ "[INFO | pytorch_lightning.utilities.rank_zero]: GPU available: True (cuda), used: True\n", "[INFO | pytorch_lightning.utilities.rank_zero]: TPU available: False, using: 0 TPU cores\n", "[INFO | pytorch_lightning.utilities.rank_zero]: HPU available: False, using: 0 HPUs\n", - "[NeMo W 2025-03-03 23:36:31 nemo_logging:405] No version folders would be created under the log folder as 'resume_if_exists' is enabled.\n", - "[NeMo I 2025-03-03 23:36:31 nemo_logging:393] Experiments will be logged at /tmp/tmpsn4mexa6/default\n", - "[NeMo W 2025-03-03 23:36:31 nemo_logging:405] \"update_logger_directory\" is True. 
Overwriting tensorboard logger \"save_dir\" to /tmp/tmpsn4mexa6\n", - "[NeMo I 2025-03-03 23:36:31 nemo_logging:393] Using byte-level tokenization\n", - "[NeMo I 2025-03-03 23:36:31 nemo_logging:393] Rank 0 has data parallel group : [0]\n", - "[NeMo I 2025-03-03 23:36:31 nemo_logging:393] Rank 0 has combined group of data parallel and context parallel : [0]\n", - "[NeMo I 2025-03-03 23:36:31 nemo_logging:393] All data parallel group ranks with context parallel combined: [[0]]\n", - "[NeMo I 2025-03-03 23:36:31 nemo_logging:393] Ranks 0 has data parallel rank: 0\n", - "[NeMo I 2025-03-03 23:36:31 nemo_logging:393] Rank 0 has context parallel group: [0]\n", - "[NeMo I 2025-03-03 23:36:31 nemo_logging:393] All context parallel group ranks: [[0]]\n", - "[NeMo I 2025-03-03 23:36:31 nemo_logging:393] Ranks 0 has context parallel rank: 0\n", - "[NeMo I 2025-03-03 23:36:31 nemo_logging:393] Rank 0 has model parallel group: [0]\n", - "[NeMo I 2025-03-03 23:36:31 nemo_logging:393] All model parallel group ranks: [[0]]\n", - "[NeMo I 2025-03-03 23:36:31 nemo_logging:393] Rank 0 has tensor model parallel group: [0]\n", - "[NeMo I 2025-03-03 23:36:31 nemo_logging:393] All tensor model parallel group ranks: [[0]]\n", - "[NeMo I 2025-03-03 23:36:31 nemo_logging:393] Rank 0 has tensor model parallel rank: 0\n", - "[NeMo I 2025-03-03 23:36:31 nemo_logging:393] Rank 0 has pipeline model parallel group: [0]\n", - "[NeMo I 2025-03-03 23:36:31 nemo_logging:393] Rank 0 has embedding group: [0]\n", - "[NeMo I 2025-03-03 23:36:31 nemo_logging:393] All pipeline model parallel group ranks: [[0]]\n", - "[NeMo I 2025-03-03 23:36:31 nemo_logging:393] Rank 0 has pipeline model parallel rank 0\n", - "[NeMo I 2025-03-03 23:36:31 nemo_logging:393] All embedding group ranks: [[0]]\n", - "[NeMo I 2025-03-03 23:36:31 nemo_logging:393] Rank 0 has embedding rank: 0\n", + "[NeMo W 2025-03-04 01:01:11 nemo_logging:405] No version folders would be created under the log folder as 'resume_if_exists' is 
enabled.\n", + "[NeMo I 2025-03-04 01:01:11 nemo_logging:393] Experiments will be logged at /tmp/tmpupzx4lk1/default\n", + "[NeMo W 2025-03-04 01:01:11 nemo_logging:405] \"update_logger_directory\" is True. Overwriting tensorboard logger \"save_dir\" to /tmp/tmpupzx4lk1\n", + "[NeMo I 2025-03-04 01:01:11 nemo_logging:393] Using byte-level tokenization\n", + "[NeMo I 2025-03-04 01:01:11 nemo_logging:393] Rank 0 has data parallel group : [0]\n", + "[NeMo I 2025-03-04 01:01:11 nemo_logging:393] Rank 0 has combined group of data parallel and context parallel : [0]\n", + "[NeMo I 2025-03-04 01:01:11 nemo_logging:393] All data parallel group ranks with context parallel combined: [[0]]\n", + "[NeMo I 2025-03-04 01:01:11 nemo_logging:393] Ranks 0 has data parallel rank: 0\n", + "[NeMo I 2025-03-04 01:01:11 nemo_logging:393] Rank 0 has context parallel group: [0]\n", + "[NeMo I 2025-03-04 01:01:11 nemo_logging:393] All context parallel group ranks: [[0]]\n", + "[NeMo I 2025-03-04 01:01:11 nemo_logging:393] Ranks 0 has context parallel rank: 0\n", + "[NeMo I 2025-03-04 01:01:11 nemo_logging:393] Rank 0 has model parallel group: [0]\n", + "[NeMo I 2025-03-04 01:01:11 nemo_logging:393] All model parallel group ranks: [[0]]\n", + "[NeMo I 2025-03-04 01:01:11 nemo_logging:393] Rank 0 has tensor model parallel group: [0]\n", + "[NeMo I 2025-03-04 01:01:11 nemo_logging:393] All tensor model parallel group ranks: [[0]]\n", + "[NeMo I 2025-03-04 01:01:11 nemo_logging:393] Rank 0 has tensor model parallel rank: 0\n", + "[NeMo I 2025-03-04 01:01:11 nemo_logging:393] Rank 0 has pipeline model parallel group: [0]\n", + "[NeMo I 2025-03-04 01:01:11 nemo_logging:393] Rank 0 has embedding group: [0]\n", + "[NeMo I 2025-03-04 01:01:11 nemo_logging:393] All pipeline model parallel group ranks: [[0]]\n", + "[NeMo I 2025-03-04 01:01:11 nemo_logging:393] Rank 0 has pipeline model parallel rank 0\n", + "[NeMo I 2025-03-04 01:01:11 nemo_logging:393] All embedding group ranks: [[0]]\n", + "[NeMo I 
2025-03-04 01:01:11 nemo_logging:393] Rank 0 has embedding rank: 0\n", "Initializing distributed: GLOBAL_RANK: 0, MEMBER: 1/1\n", "[INFO | pytorch_lightning.utilities.rank_zero]: ----------------------------------------------------------------------------------------------------\n", "distributed_backend=nccl\n", "All distributed processes registered. Starting with 1 processes\n", "----------------------------------------------------------------------------------------------------\n", "\n", - "[NeMo I 2025-03-03 23:36:31 num_microbatches_calculator:228] setting number of microbatches to constant 1\n", - "[NeMo I 2025-03-03 23:36:32 nemo_logging:393] Padded vocab_size: 512, original vocab_size: 512, dummy tokens: 0.\n", - "[NeMo W 2025-03-03 23:36:32 random:220] CPU RNG state changed within GPU RNG context\n", - "[NeMo W 2025-03-03 23:36:32 random:220] CPU RNG state changed within GPU RNG context\n", - "[NeMo W 2025-03-03 23:36:32 random:220] CPU RNG state changed within GPU RNG context\n", - "[NeMo W 2025-03-03 23:36:32 random:220] CPU RNG state changed within GPU RNG context\n", - "[NeMo W 2025-03-03 23:36:32 random:220] CPU RNG state changed within GPU RNG context\n", - "[NeMo W 2025-03-03 23:36:32 random:220] CPU RNG state changed within GPU RNG context\n", - "[NeMo W 2025-03-03 23:36:32 random:220] CPU RNG state changed within GPU RNG context\n", - "[NeMo W 2025-03-03 23:36:32 random:220] CPU RNG state changed within GPU RNG context\n", - "[NeMo W 2025-03-03 23:36:32 random:220] CPU RNG state changed within GPU RNG context\n", - "[NeMo W 2025-03-03 23:36:32 random:220] CPU RNG state changed within GPU RNG context\n", - "[NeMo W 2025-03-03 23:36:32 random:220] CPU RNG state changed within GPU RNG context\n", - "[NeMo W 2025-03-03 23:36:32 random:220] CPU RNG state changed within GPU RNG context\n", - "[NeMo W 2025-03-03 23:36:32 random:220] CPU RNG state changed within GPU RNG context\n", - "[NeMo W 2025-03-03 23:36:32 random:220] CPU RNG state changed within GPU 
RNG context\n", - "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2]\n", - "[NeMo W 2025-03-03 23:36:32 nemo_logging:405] Could not copy Trainer's 'max_steps' to LR scheduler's 'max_steps'. If you are not using an LR scheduler, this warning can safely be ignored.\n", - "[NeMo I 2025-03-03 23:36:32 nemo_logging:393] > number of parameters on (tensor, pipeline) model parallel rank (0 ,0): 1108204800\n", - "[NeMo I 2025-03-03 23:36:32 utils:302] Setting up DistributedDataParallel with config DistributedDataParallelConfig(grad_reduce_in_fp32=False, overlap_grad_reduce=False, overlap_param_gather=False, align_param_gather=False, use_distributed_optimizer=False, num_distributed_optimizer_instances=1, check_for_nan_in_grad=True, check_for_large_grads=False, bucket_size=None, average_in_collective=False, fp8_param_gather=False)\n", - "[NeMo I 2025-03-03 23:36:32 utils:323] Number of buckets for gradient all-reduce / reduce-scatter: 1\n", + "[NeMo I 2025-03-04 01:01:11 num_microbatches_calculator:228] setting number of microbatches to constant 1\n", + "[NeMo I 2025-03-04 01:01:11 nemo_logging:393] Padded vocab_size: 512, original vocab_size: 512, dummy tokens: 0.\n", + "[NeMo W 2025-03-04 01:01:11 random:220] CPU RNG state changed within GPU RNG context\n", + "[NeMo W 2025-03-04 01:01:11 random:220] CPU RNG state changed within GPU RNG context\n", + "[NeMo W 2025-03-04 01:01:11 random:220] CPU RNG state changed within GPU RNG context\n", + "[NeMo W 2025-03-04 01:01:11 random:220] CPU RNG state changed within GPU RNG context\n", + "[NeMo W 2025-03-04 01:01:11 random:220] CPU RNG state changed within GPU RNG context\n", + "[NeMo W 2025-03-04 01:01:11 random:220] CPU RNG state changed within GPU RNG context\n", + "[NeMo W 2025-03-04 01:01:11 random:220] CPU RNG state changed within GPU RNG context\n", + "[NeMo W 2025-03-04 01:01:11 random:220] CPU RNG state changed within GPU RNG context\n", + "[NeMo W 2025-03-04 01:01:11 random:220] CPU RNG state changed within GPU RNG 
context\n", + "[NeMo W 2025-03-04 01:01:11 random:220] CPU RNG state changed within GPU RNG context\n", + "[NeMo W 2025-03-04 01:01:11 random:220] CPU RNG state changed within GPU RNG context\n", + "[NeMo W 2025-03-04 01:01:11 random:220] CPU RNG state changed within GPU RNG context\n", + "[NeMo W 2025-03-04 01:01:11 random:220] CPU RNG state changed within GPU RNG context\n", + "[NeMo W 2025-03-04 01:01:11 random:220] CPU RNG state changed within GPU RNG context\n", + "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n", + "[NeMo W 2025-03-04 01:01:11 nemo_logging:405] Could not copy Trainer's 'max_steps' to LR scheduler's 'max_steps'. If you are not using an LR scheduler, this warning can safely be ignored.\n", + "[NeMo I 2025-03-04 01:01:11 nemo_logging:393] > number of parameters on (tensor, pipeline) model parallel rank (0 ,0): 1108204800\n", + "[NeMo I 2025-03-04 01:01:11 utils:302] Setting up DistributedDataParallel with config DistributedDataParallelConfig(grad_reduce_in_fp32=False, overlap_grad_reduce=False, overlap_param_gather=False, align_param_gather=False, use_distributed_optimizer=False, num_distributed_optimizer_instances=1, check_for_nan_in_grad=True, check_for_large_grads=False, bucket_size=None, average_in_collective=False, fp8_param_gather=False)\n", + "[NeMo I 2025-03-04 01:01:11 utils:323] Number of buckets for gradient all-reduce / reduce-scatter: 1\n", " Params for bucket 1 (1108204800 elements):\n", - " \tmodule.decoder.layers.18.mlp.linear_fc2.weight\n", - " \tmodule.decoder.layers.16.mixer.dense_projection.weight\n", - " \tmodule.decoder.layers.13.mlp.linear_fc2.weight\n", - " \tmodule.decoder.layers.8.mixer.dense_projection.weight\n", - " \tmodule.decoder.layers.6.mixer.mixer.filter.gamma\n", - " \tmodule.decoder.layers.3.mlp.linear_fc1.layer_norm_weight\n", - " \tmodule.decoder.layers.3.self_attention.linear_proj.bias\n", - " \tmodule.decoder.layers.0.mixer.dense.bias\n", - " \tmodule.decoder.layers.22.mixer.mixer.conv_bias\n", - " 
\tmodule.decoder.layers.20.mixer.dense_projection.layer_norm_weight\n", - " \tmodule.decoder.layers.14.mlp.linear_fc1.weight\n", - " \tmodule.decoder.layers.9.mlp.linear_fc1.weight\n", - " \tmodule.decoder.layers.7.mixer.dense_projection.layer_norm_weight\n", - " \tmodule.decoder.layers.1.mixer.mixer.conv_bias\n", - " \tmodule.decoder.layers.23.mixer.mixer.filter.gamma\n", - " \tmodule.decoder.layers.18.mixer.dense_projection.weight\n", - " \tmodule.decoder.layers.15.mlp.linear_fc2.weight\n", - " \tmodule.decoder.layers.13.mixer.mixer.filter.p\n", - " \tmodule.decoder.layers.11.mixer.hyena_proj_conv.short_conv_weight\n", - " \tmodule.decoder.layers.7.mlp.linear_fc2.weight\n", - " \tmodule.decoder.layers.1.mlp.linear_fc1.layer_norm_weight\n", - " \tmodule.embedding.word_embeddings.weight\n", - " \tmodule.decoder.layers.24.mlp.linear_fc1.weight\n", - " \tmodule.decoder.layers.24.self_attention.linear_proj.weight\n", - " \tmodule.decoder.layers.21.mixer.mixer.short_conv.short_conv_weight\n", - " \tmodule.decoder.layers.14.mixer.dense.weight\n", - " \tmodule.decoder.layers.12.mixer.dense.bias\n", - " \tmodule.decoder.layers.9.mlp.linear_fc1.layer_norm_weight\n", - " \tmodule.decoder.layers.9.mixer.dense_projection.layer_norm_weight\n", - " \tmodule.decoder.layers.4.mixer.hyena_proj_conv.short_conv_weight\n", - " \tmodule.decoder.layers.0.mixer.mixer.short_conv.short_conv_weight\n", - " \tmodule.decoder.layers.0.mixer.dense_projection.layer_norm_weight\n", - " \tmodule.decoder.layers.20.mixer.dense_projection.weight\n", - " \tmodule.decoder.layers.17.mlp.linear_fc2.weight\n", - " \tmodule.decoder.layers.13.mixer.hyena_proj_conv.short_conv_weight\n", - " \tmodule.decoder.layers.10.self_attention.linear_qkv.weight\n", - " \tmodule.decoder.layers.7.mixer.dense_projection.weight\n", - " \tmodule.decoder.layers.5.mixer.hyena_proj_conv.short_conv_weight\n", - " \tmodule.decoder.layers.1.mixer.mixer.filter.h\n", - " 
\tmodule.decoder.layers.24.self_attention.linear_qkv.layer_norm_weight\n", - " \tmodule.decoder.layers.23.mixer.mixer.conv_bias\n", - " \tmodule.decoder.layers.18.mlp.linear_fc1.weight\n", - " \tmodule.decoder.layers.16.mixer.dense.bias\n", - " \tmodule.decoder.layers.13.mlp.linear_fc1.weight\n", - " \tmodule.decoder.layers.12.mixer.dense_projection.layer_norm_weight\n", - " \tmodule.decoder.layers.9.mixer.mixer.filter.R\n", - " \tmodule.decoder.layers.8.mixer.dense.bias\n", - " \tmodule.decoder.layers.3.mlp.linear_fc2.weight\n", - " \tmodule.decoder.layers.2.mixer.dense_projection.weight\n", - " \tmodule.decoder.layers.22.mixer.hyena_proj_conv.short_conv_weight\n", - " \tmodule.decoder.layers.19.mlp.linear_fc2.weight\n", - " \tmodule.decoder.layers.17.mlp.linear_fc1.layer_norm_weight\n", - " \tmodule.decoder.layers.15.mlp.linear_fc1.layer_norm_weight\n", - " \tmodule.decoder.layers.12.mixer.dense.weight\n", - " \tmodule.decoder.layers.11.mixer.dense.bias\n", - " \tmodule.decoder.layers.6.mlp.linear_fc2.weight\n", - " \tmodule.decoder.layers.18.mixer.dense.weight\n", - " \tmodule.decoder.layers.15.mlp.linear_fc1.weight\n", - " \tmodule.decoder.layers.13.mlp.linear_fc1.layer_norm_weight\n", - " \tmodule.decoder.layers.7.mlp.linear_fc1.weight\n", - " \tmodule.decoder.layers.1.mlp.linear_fc2.weight\n", - " \tmodule.decoder.layers.23.mlp.linear_fc2.weight\n", - " \tmodule.decoder.layers.16.mixer.dense.weight\n", - " \tmodule.decoder.layers.14.mlp.linear_fc1.layer_norm_weight\n", - " \tmodule.decoder.layers.12.mixer.dense_projection.weight\n", - " \tmodule.decoder.layers.8.mlp.linear_fc2.weight\n", - " \tmodule.decoder.layers.6.mixer.mixer.filter.p\n", - " \tmodule.decoder.layers.4.mixer.dense.weight\n", - " \tmodule.decoder.layers.20.mixer.dense.bias\n", - " \tmodule.decoder.layers.17.mlp.linear_fc1.weight\n", - " \tmodule.decoder.layers.17.self_attention.linear_proj.bias\n", - " \tmodule.decoder.layers.15.mixer.mixer.filter.h\n", - " 
\tmodule.decoder.layers.13.mixer.mixer.filter.R\n", - " \tmodule.decoder.layers.11.mixer.dense_projection.layer_norm_weight\n", - " \tmodule.decoder.layers.7.mixer.dense.weight\n", - " \tmodule.decoder.layers.6.mixer.mixer.conv_bias\n", - " \tmodule.decoder.layers.24.mlp.linear_fc1.layer_norm_weight\n", - " \tmodule.decoder.layers.23.mixer.mixer.filter.p\n", - " \tmodule.decoder.layers.21.mixer.hyena_proj_conv.short_conv_weight\n", - " \tmodule.decoder.layers.19.mlp.linear_fc1.layer_norm_weight\n", - " \tmodule.decoder.layers.11.mlp.linear_fc2.weight\n", - " \tmodule.decoder.layers.6.mixer.hyena_proj_conv.short_conv_weight\n", - " \tmodule.decoder.layers.3.self_attention.linear_qkv.weight\n", - " \tmodule.decoder.layers.0.mixer.hyena_proj_conv.short_conv_weight\n", - " \tmodule.decoder.layers.17.self_attention.linear_qkv.layer_norm_weight\n", " \tmodule.decoder.layers.22.mixer.dense.bias\n", - " \tmodule.decoder.layers.19.mlp.linear_fc1.weight\n", - " \tmodule.decoder.layers.15.mixer.mixer.conv_bias\n", - " \tmodule.decoder.layers.13.mixer.dense_projection.layer_norm_weight\n", + " \tmodule.decoder.layers.19.mixer.mixer.conv_bias\n", + " \tmodule.decoder.layers.16.mixer.mixer.conv_bias\n", + " \tmodule.decoder.layers.11.mlp.linear_fc1.weight\n", + " \tmodule.decoder.layers.9.mixer.dense.bias\n", " \tmodule.decoder.layers.6.mlp.linear_fc1.weight\n", - " \tmodule.decoder.layers.5.mixer.dense_projection.layer_norm_weight\n", - " \tmodule.decoder.layers.1.mixer.hyena_proj_conv.short_conv_weight\n", + " \tmodule.decoder.layers.3.mlp.linear_fc2.weight\n", + " \tmodule.decoder.layers.0.mixer.dense_projection.weight\n", " \tmodule.decoder.layers.23.mixer.hyena_proj_conv.short_conv_weight\n", " \tmodule.decoder.layers.20.mixer.dense.weight\n", - " \tmodule.decoder.layers.18.mlp.linear_fc1.layer_norm_weight\n", - " \tmodule.decoder.layers.16.mixer.mixer.filter.gamma\n", - " \tmodule.decoder.layers.11.mixer.dense_projection.weight\n", - " 
\tmodule.decoder.layers.8.mixer.mixer.conv_bias\n", + " \tmodule.decoder.layers.15.mixer.hyena_proj_conv.short_conv_weight\n", + " \tmodule.decoder.layers.12.mlp.linear_fc2.weight\n", + " \tmodule.decoder.layers.10.mlp.linear_fc1.layer_norm_weight\n", + " \tmodule.decoder.layers.8.mlp.linear_fc1.layer_norm_weight\n", " \tmodule.decoder.layers.5.mixer.dense.weight\n", - " \tmodule.decoder.layers.1.mlp.linear_fc1.weight\n", - " \tmodule.decoder.layers.17.self_attention.linear_proj.weight\n", - " \tmodule.decoder.layers.24.self_attention.linear_proj.bias\n", + " \tmodule.decoder.layers.2.mixer.mixer.filter.R\n", + " \tmodule.decoder.layers.1.mixer.dense_projection.weight\n", " \tmodule.decoder.layers.23.mlp.linear_fc1.weight\n", " \tmodule.decoder.layers.22.mixer.dense_projection.layer_norm_weight\n", - " \tmodule.decoder.layers.19.mixer.mixer.filter.h\n", - " \tmodule.decoder.layers.14.mixer.mixer.short_conv.short_conv_weight\n", + " \tmodule.decoder.layers.18.mixer.mixer.short_conv.short_conv_weight\n", + " \tmodule.decoder.layers.11.mixer.dense.weight\n", + " \tmodule.decoder.layers.10.self_attention.linear_proj.bias\n", " \tmodule.decoder.layers.8.mlp.linear_fc1.weight\n", " \tmodule.decoder.layers.6.mlp.linear_fc1.layer_norm_weight\n", - " \tmodule.decoder.layers.4.mixer.dense_projection.weight\n", + " \tmodule.decoder.layers.3.mlp.linear_fc1.layer_norm_weight\n", " \tmodule.decoder.layers.22.mixer.dense.weight\n", - " \tmodule.decoder.layers.13.mixer.dense_projection.weight\n", - " \tmodule.decoder.layers.10.mlp.linear_fc2.weight\n", + " \tmodule.decoder.layers.16.mlp.linear_fc2.weight\n", + " \tmodule.decoder.layers.9.mixer.dense.weight\n", " \tmodule.decoder.layers.7.mlp.linear_fc1.layer_norm_weight\n", " \tmodule.decoder.layers.5.mixer.dense_projection.weight\n", - " \tmodule.decoder.layers.2.mixer.dense.weight\n", + " \tmodule.decoder.layers.4.mlp.linear_fc1.layer_norm_weight\n", + " \tmodule.decoder.layers.0.mlp.linear_fc2.weight\n", " 
\tmodule.decoder.layers.23.mlp.linear_fc1.layer_norm_weight\n", - " \tmodule.decoder.layers.19.mixer.mixer.conv_bias\n", - " \tmodule.decoder.layers.16.mixer.mixer.conv_bias\n", - " \tmodule.decoder.layers.11.mlp.linear_fc1.weight\n", - " \tmodule.decoder.layers.9.mixer.dense.bias\n", + " \tmodule.decoder.layers.13.mixer.dense.bias\n", + " \tmodule.decoder.layers.10.mlp.linear_fc1.weight\n", " \tmodule.decoder.layers.8.mixer.mixer.filter.h\n", " \tmodule.decoder.layers.6.mixer.mixer.filter.R\n", - " \tmodule.decoder.layers.4.mixer.dense_projection.layer_norm_weight\n", - " \tmodule.decoder.layers.0.mixer.dense.weight\n", + " \tmodule.decoder.layers.3.self_attention.linear_proj.bias\n", + " \tmodule.decoder.layers.2.mlp.linear_fc2.weight\n", + " \tmodule.decoder.layers.1.mixer.dense.weight\n", + " \tmodule.decoder.layers.24.self_attention.linear_qkv.weight\n", " \tmodule.decoder.layers.22.mixer.dense_projection.weight\n", - " \tmodule.decoder.layers.20.mixer.mixer.filter.gamma\n", - " \tmodule.decoder.layers.15.mixer.hyena_proj_conv.short_conv_weight\n", - " \tmodule.decoder.layers.12.mlp.linear_fc2.weight\n", - " \tmodule.decoder.layers.10.mlp.linear_fc1.layer_norm_weight\n", - " \tmodule.decoder.layers.7.mixer.dense.bias\n", - " \tmodule.decoder.layers.4.mlp.linear_fc2.weight\n", - " \tmodule.decoder.layers.2.mixer.dense.bias\n", - " \tmodule.decoder.layers.1.mixer.dense.bias\n", - " \tmodule.decoder.layers.23.mixer.mixer.filter.R\n", - " \tmodule.decoder.layers.21.mixer.dense_projection.layer_norm_weight\n", - " \tmodule.decoder.layers.18.mixer.mixer.short_conv.short_conv_weight\n", - " \tmodule.decoder.layers.11.mixer.dense.weight\n", - " \tmodule.decoder.layers.6.mixer.dense_projection.layer_norm_weight\n", - " \tmodule.decoder.layers.3.self_attention.linear_proj.weight\n", - " \tmodule.decoder.layers.2.mixer.mixer.filter.gamma\n", - " \tmodule.decoder.layers.21.mlp.linear_fc2.weight\n", - " \tmodule.decoder.layers.16.mlp.linear_fc2.weight\n", - " 
\tmodule.decoder.layers.9.mixer.dense.weight\n", - " \tmodule.decoder.layers.9.mixer.mixer.filter.gamma\n", - " \tmodule.decoder.layers.4.mlp.linear_fc1.layer_norm_weight\n", - " \tmodule.decoder.layers.1.mixer.dense_projection.layer_norm_weight\n", - " \tmodule.decoder.layers.23.mixer.dense_projection.layer_norm_weight\n", - " \tmodule.decoder.layers.20.mixer.mixer.conv_bias\n", - " \tmodule.decoder.layers.13.mixer.dense.bias\n", - " \tmodule.decoder.layers.10.mlp.linear_fc1.weight\n", - " \tmodule.decoder.layers.7.mixer.mixer.short_conv.short_conv_weight\n", - " \tmodule.decoder.layers.5.mixer.dense.bias\n", - " \tmodule.decoder.layers.21.mixer.dense_projection.weight\n", " \tmodule.decoder.layers.19.mixer.hyena_proj_conv.short_conv_weight\n", " \tmodule.decoder.layers.16.mixer.mixer.filter.p\n", " \tmodule.decoder.layers.14.mixer.hyena_proj_conv.short_conv_weight\n", " \tmodule.decoder.layers.12.mlp.linear_fc1.layer_norm_weight\n", - " \tmodule.decoder.layers.6.mixer.dense_projection.weight\n", " \tmodule.decoder.layers.3.mlp.linear_fc1.weight\n", - " \tmodule.decoder.layers.0.mixer.dense_projection.weight\n", - " \tmodule.decoder.layers.21.mixer.dense.bias\n", + " \tmodule.decoder.layers.0.mlp.linear_fc1.layer_norm_weight\n", + " \tmodule.decoder.layers.23.mixer.mixer.filter.R\n", + " \tmodule.decoder.layers.21.mixer.dense_projection.layer_norm_weight\n", " \tmodule.decoder.layers.15.mixer.dense.bias\n", " \tmodule.decoder.layers.12.mlp.linear_fc1.weight\n", " \tmodule.decoder.layers.10.self_attention.linear_qkv.layer_norm_weight\n", - " \tmodule.decoder.layers.4.mlp.linear_fc1.weight\n", - " \tmodule.decoder.layers.2.mixer.mixer.filter.R\n", - " \tmodule.decoder.layers.1.mixer.dense_projection.weight\n", - " \tmodule.decoder.layers.23.mixer.dense_projection.weight\n", - " \tmodule.decoder.layers.20.mlp.linear_fc2.weight\n", + " \tmodule.decoder.layers.8.mixer.mixer.conv_bias\n", + " \tmodule.decoder.layers.6.mixer.dense_projection.layer_norm_weight\n", + " 
\tmodule.decoder.layers.2.mixer.mixer.filter.p\n", + " \tmodule.decoder.final_norm.weight\n", + " \tmodule.decoder.layers.21.mlp.linear_fc2.weight\n", " \tmodule.decoder.layers.16.mixer.hyena_proj_conv.short_conv_weight\n", " \tmodule.decoder.layers.13.mixer.dense.weight\n", " \tmodule.decoder.layers.11.mlp.linear_fc1.layer_norm_weight\n", - " \tmodule.decoder.layers.10.self_attention.linear_proj.bias\n", - " \tmodule.decoder.layers.8.mixer.hyena_proj_conv.short_conv_weight\n", - " \tmodule.decoder.layers.5.mlp.linear_fc2.weight\n", - " \tmodule.decoder.layers.21.mlp.linear_fc1.weight\n", + " \tmodule.decoder.layers.9.mixer.mixer.filter.gamma\n", + " \tmodule.decoder.layers.3.self_attention.linear_qkv.layer_norm_weight\n", + " \tmodule.decoder.layers.2.mixer.hyena_proj_conv.short_conv_weight\n", + " \tmodule.decoder.layers.24.self_attention.linear_proj.bias\n", + " \tmodule.decoder.layers.23.mixer.dense_projection.layer_norm_weight\n", + " \tmodule.decoder.layers.20.mixer.mixer.conv_bias\n", + " \tmodule.decoder.layers.19.mlp.linear_fc2.weight\n", " \tmodule.decoder.layers.16.mlp.linear_fc1.weight\n", " \tmodule.decoder.layers.15.mixer.dense_projection.layer_norm_weight\n", " \tmodule.decoder.layers.12.mixer.mixer.filter.h\n", " \tmodule.decoder.layers.10.self_attention.linear_proj.weight\n", - " \tmodule.decoder.layers.9.mixer.hyena_proj_conv.short_conv_weight\n", - " \tmodule.decoder.layers.5.mlp.linear_fc1.layer_norm_weight\n", - " \tmodule.decoder.layers.0.mlp.linear_fc2.weight\n", - " \tmodule.decoder.layers.22.mlp.linear_fc2.weight\n", - " \tmodule.decoder.layers.20.mixer.mixer.filter.p\n", + " \tmodule.decoder.layers.7.mixer.mixer.short_conv.short_conv_weight\n", + " \tmodule.decoder.layers.5.mixer.dense.bias\n", + " \tmodule.decoder.layers.21.mixer.dense_projection.weight\n", " \tmodule.decoder.layers.18.mixer.hyena_proj_conv.short_conv_weight\n", " \tmodule.decoder.layers.15.mixer.dense.weight\n", - " \tmodule.decoder.layers.2.mlp.linear_fc2.weight\n", - " 
\tmodule.decoder.layers.1.mixer.dense.weight\n", - " \tmodule.decoder.layers.21.mixer.dense.weight\n", + " \tmodule.decoder.layers.6.mixer.dense_projection.weight\n", + " \tmodule.decoder.layers.2.mlp.linear_fc1.weight\n", + " \tmodule.decoder.layers.2.mixer.dense_projection.layer_norm_weight\n", + " \tmodule.decoder.layers.20.mixer.dense_projection.layer_norm_weight\n", " \tmodule.decoder.layers.19.mixer.dense.bias\n", " \tmodule.decoder.layers.16.mlp.linear_fc1.layer_norm_weight\n", " \tmodule.decoder.layers.12.mixer.mixer.conv_bias\n", " \tmodule.decoder.layers.9.mixer.mixer.conv_bias\n", - " \tmodule.decoder.layers.8.mlp.linear_fc1.layer_norm_weight\n", + " \tmodule.decoder.layers.5.mixer.dense_projection.layer_norm_weight\n", " \tmodule.decoder.layers.4.mixer.dense.bias\n", - " \tmodule.decoder.layers.0.mlp.linear_fc1.layer_norm_weight\n", - " \tmodule.decoder.layers.20.mixer.hyena_proj_conv.short_conv_weight\n", + " \tmodule.decoder.layers.2.mixer.mixer.conv_bias\n", " \tmodule.decoder.layers.17.self_attention.linear_qkv.weight\n", + " \tmodule.decoder.layers.23.mixer.dense_projection.weight\n", + " \tmodule.decoder.layers.20.mlp.linear_fc2.weight\n", " \tmodule.decoder.layers.15.mixer.dense_projection.weight\n", " \tmodule.decoder.layers.13.mixer.mixer.filter.gamma\n", - " \tmodule.decoder.layers.7.mixer.hyena_proj_conv.short_conv_weight\n", - " \tmodule.decoder.layers.5.mixer.mixer.conv_bias\n", - " \tmodule.decoder.layers.2.mixer.mixer.filter.p\n", - " \tmodule.decoder.layers.23.mixer.dense.bias\n", - " \tmodule.decoder.layers.20.mlp.linear_fc1.weight\n", - " \tmodule.decoder.layers.19.mixer.dense_projection.layer_norm_weight\n", - " \tmodule.decoder.layers.16.mixer.mixer.filter.R\n", - " \tmodule.decoder.layers.14.mixer.dense_projection.layer_norm_weight\n", + " \tmodule.decoder.layers.8.mixer.hyena_proj_conv.short_conv_weight\n", + " \tmodule.decoder.layers.5.mlp.linear_fc2.weight\n", + " \tmodule.decoder.layers.2.mlp.linear_fc1.layer_norm_weight\n", + " 
\tmodule.decoder.layers.21.mlp.linear_fc1.weight\n", + " \tmodule.decoder.layers.19.mixer.dense_projection.layer_norm_weight\n", + " \tmodule.decoder.layers.16.mixer.mixer.filter.R\n", + " \tmodule.decoder.layers.14.mixer.dense_projection.layer_norm_weight\n", " \tmodule.decoder.layers.11.mixer.mixer.short_conv.short_conv_weight\n", - " \tmodule.decoder.layers.5.mlp.linear_fc1.weight\n", - " \tmodule.decoder.layers.3.self_attention.linear_qkv.layer_norm_weight\n", - " \tmodule.decoder.layers.2.mixer.hyena_proj_conv.short_conv_weight\n", - " \tmodule.decoder.layers.24.self_attention.linear_qkv.weight\n", - " \tmodule.decoder.layers.22.mlp.linear_fc1.layer_norm_weight\n", + " \tmodule.decoder.layers.0.mixer.dense.bias\n", + " \tmodule.decoder.layers.22.mlp.linear_fc2.weight\n", + " \tmodule.decoder.layers.20.mixer.mixer.filter.p\n", " \tmodule.decoder.layers.19.mixer.dense.weight\n", " \tmodule.decoder.layers.18.mixer.dense.bias\n", " \tmodule.decoder.layers.14.mlp.linear_fc2.weight\n", " \tmodule.decoder.layers.9.mlp.linear_fc2.weight\n", - " \tmodule.decoder.layers.9.mixer.dense_projection.weight\n", - " \tmodule.decoder.layers.6.mixer.dense.weight\n", + " \tmodule.decoder.layers.5.mlp.linear_fc1.layer_norm_weight\n", " \tmodule.decoder.layers.4.mixer.mixer.short_conv.short_conv_weight\n", - " \tmodule.decoder.layers.0.mlp.linear_fc1.weight\n", - " \tmodule.decoder.layers.22.mlp.linear_fc1.weight\n", - " \tmodule.decoder.layers.20.mlp.linear_fc1.layer_norm_weight\n", + " \tmodule.decoder.layers.1.mixer.mixer.conv_bias\n", + " \tmodule.decoder.layers.21.mixer.dense.weight\n", " \tmodule.decoder.layers.16.mixer.dense_projection.layer_norm_weight\n", " \tmodule.decoder.layers.13.mixer.mixer.conv_bias\n", - " \tmodule.decoder.layers.8.mixer.dense_projection.layer_norm_weight\n", - " \tmodule.decoder.layers.5.mixer.mixer.filter.h\n", - " \tmodule.decoder.layers.2.mlp.linear_fc1.weight\n", - " \tmodule.decoder.layers.2.mixer.dense_projection.layer_norm_weight\n", + " 
\tmodule.decoder.layers.6.mixer.dense.bias\n", + " \tmodule.decoder.layers.1.mlp.linear_fc1.layer_norm_weight\n", " \tmodule.decoder.layers.24.mlp.linear_fc2.weight\n", - " \tmodule.decoder.layers.23.mixer.dense.weight\n", - " \tmodule.decoder.layers.21.mlp.linear_fc1.layer_norm_weight\n", + " \tmodule.decoder.layers.21.mixer.dense.bias\n", " \tmodule.decoder.layers.19.mixer.dense_projection.weight\n", " \tmodule.decoder.layers.14.mixer.dense_projection.weight\n", " \tmodule.decoder.layers.12.mixer.hyena_proj_conv.short_conv_weight\n", " \tmodule.decoder.layers.9.mixer.mixer.filter.p\n", + " \tmodule.decoder.layers.7.mixer.hyena_proj_conv.short_conv_weight\n", + " \tmodule.decoder.layers.4.mlp.linear_fc1.weight\n", + " \tmodule.decoder.layers.0.mixer.mixer.short_conv.short_conv_weight\n", + " \tmodule.decoder.layers.23.mixer.dense.bias\n", + " \tmodule.decoder.layers.20.mlp.linear_fc1.weight\n", + " \tmodule.decoder.layers.18.mixer.dense_projection.layer_norm_weight\n", + " \tmodule.decoder.layers.14.mixer.dense.bias\n", + " \tmodule.decoder.layers.8.mixer.dense.bias\n", + " \tmodule.decoder.layers.5.mlp.linear_fc1.weight\n", + " \tmodule.decoder.layers.1.mixer.mixer.filter.h\n", + " \tmodule.decoder.layers.24.mlp.linear_fc1.layer_norm_weight\n", + " \tmodule.decoder.layers.22.mlp.linear_fc1.layer_norm_weight\n", + " \tmodule.decoder.layers.18.mlp.linear_fc2.weight\n", + " \tmodule.decoder.layers.16.mixer.dense_projection.weight\n", + " \tmodule.decoder.layers.13.mlp.linear_fc2.weight\n", + " \tmodule.decoder.layers.9.mixer.hyena_proj_conv.short_conv_weight\n", + " \tmodule.decoder.layers.6.mixer.dense.weight\n", + " \tmodule.decoder.layers.2.mixer.dense_projection.weight\n", + " \tmodule.decoder.layers.0.mixer.dense_projection.layer_norm_weight\n", + " \tmodule.decoder.layers.22.mlp.linear_fc1.weight\n", + " \tmodule.decoder.layers.20.mlp.linear_fc1.layer_norm_weight\n", + " \tmodule.decoder.layers.20.mixer.dense_projection.weight\n", + " 
\tmodule.decoder.layers.14.mlp.linear_fc1.weight\n", + " \tmodule.decoder.layers.9.mlp.linear_fc1.weight\n", + " \tmodule.decoder.layers.8.mixer.dense_projection.layer_norm_weight\n", + " \tmodule.decoder.layers.5.mixer.mixer.filter.h\n", + " \tmodule.decoder.layers.23.mixer.dense.weight\n", + " \tmodule.decoder.layers.21.mlp.linear_fc1.layer_norm_weight\n", + " \tmodule.decoder.layers.20.mixer.mixer.filter.gamma\n", + " \tmodule.decoder.layers.18.mixer.dense_projection.weight\n", + " \tmodule.decoder.layers.15.mlp.linear_fc2.weight\n", + " \tmodule.decoder.layers.13.mixer.mixer.filter.p\n", + " \tmodule.decoder.layers.11.mixer.hyena_proj_conv.short_conv_weight\n", " \tmodule.decoder.layers.8.mixer.dense.weight\n", - " \tmodule.decoder.layers.6.mixer.dense.bias\n", - " \tmodule.decoder.layers.2.mixer.mixer.conv_bias\n", - " \tmodule.decoder.final_norm.weight\n", + " \tmodule.decoder.layers.1.mlp.linear_fc2.weight\n", + " \tmodule.decoder.layers.24.mlp.linear_fc1.weight\n", " \tmodule.decoder.layers.22.mixer.mixer.filter.h\n", " \tmodule.decoder.layers.20.mixer.mixer.filter.R\n", - " \tmodule.decoder.layers.18.mixer.dense_projection.layer_norm_weight\n", - " \tmodule.decoder.layers.14.mixer.dense.bias\n", - " \tmodule.decoder.layers.2.mlp.linear_fc1.layer_norm_weight\n", - "[NeMo I 2025-03-03 23:36:32 nemo_logging:393] Doing selective restore from RestoreConfig(path='nemo2_evo2_1b_8k', adapter_path=None, load_model_state=True, load_optim_state=False, load_artifacts=False)\n", - "[NeMo I 2025-03-03 23:36:32 nemo_logging:393] Using dist-ckpt load strategy.\n", - "[WARNING | py.warnings ]: /workspace/bionemo2/3rdparty/Megatron-LM/megatron/core/dist_checkpointing/strategies/torch.py:847: FutureWarning: `load_state_dict` is deprecated and will be removed in future versions. 
Please use `load` instead.\n", + " \tmodule.decoder.layers.14.mixer.dense.weight\n", + " \tmodule.decoder.layers.12.mixer.dense.bias\n", + " \tmodule.decoder.layers.9.mlp.linear_fc1.layer_norm_weight\n", + " \tmodule.decoder.layers.5.mixer.mixer.conv_bias\n", + " \tmodule.decoder.layers.4.mixer.hyena_proj_conv.short_conv_weight\n", + " \tmodule.decoder.layers.17.self_attention.linear_proj.bias\n", + " \tmodule.decoder.layers.17.mlp.linear_fc2.weight\n", + " \tmodule.decoder.layers.13.mixer.hyena_proj_conv.short_conv_weight\n", + " \tmodule.decoder.layers.10.self_attention.linear_qkv.weight\n", + " \tmodule.decoder.layers.8.mixer.dense_projection.weight\n", + " \tmodule.decoder.layers.6.mixer.mixer.filter.gamma\n", + " \tmodule.decoder.layers.24.self_attention.linear_qkv.layer_norm_weight\n", + " \tmodule.decoder.layers.22.mixer.mixer.conv_bias\n", + " \tmodule.decoder.layers.18.mlp.linear_fc1.weight\n", + " \tmodule.decoder.layers.16.mixer.dense.bias\n", + " \tmodule.decoder.layers.13.mlp.linear_fc1.weight\n", + " \tmodule.decoder.layers.12.mixer.dense_projection.layer_norm_weight\n", + " \tmodule.decoder.layers.9.mixer.mixer.filter.R\n", + " \tmodule.decoder.layers.7.mixer.dense_projection.layer_norm_weight\n", + " \tmodule.decoder.layers.3.self_attention.linear_qkv.weight\n", + " \tmodule.decoder.layers.0.mixer.hyena_proj_conv.short_conv_weight\n", + " \tmodule.decoder.layers.23.mixer.mixer.filter.gamma\n", + " \tmodule.decoder.layers.17.mlp.linear_fc1.layer_norm_weight\n", + " \tmodule.decoder.layers.15.mlp.linear_fc1.layer_norm_weight\n", + " \tmodule.decoder.layers.12.mixer.dense.weight\n", + " \tmodule.decoder.layers.11.mixer.dense.bias\n", + " \tmodule.decoder.layers.7.mlp.linear_fc2.weight\n", + " \tmodule.decoder.layers.1.mixer.hyena_proj_conv.short_conv_weight\n", + " \tmodule.decoder.layers.24.self_attention.linear_proj.weight\n", + " \tmodule.decoder.layers.21.mixer.mixer.short_conv.short_conv_weight\n", + " 
\tmodule.decoder.layers.18.mixer.dense.weight\n", + " \tmodule.decoder.layers.15.mlp.linear_fc1.weight\n", + " \tmodule.decoder.layers.13.mlp.linear_fc1.layer_norm_weight\n", + " \tmodule.decoder.layers.9.mixer.dense_projection.layer_norm_weight\n", + " \tmodule.decoder.layers.6.mixer.mixer.conv_bias\n", + " \tmodule.decoder.layers.1.mlp.linear_fc1.weight\n", + " \tmodule.decoder.layers.20.mixer.hyena_proj_conv.short_conv_weight\n", + " \tmodule.decoder.layers.16.mixer.dense.weight\n", + " \tmodule.decoder.layers.14.mlp.linear_fc1.layer_norm_weight\n", + " \tmodule.decoder.layers.12.mixer.dense_projection.weight\n", + " \tmodule.decoder.layers.7.mixer.dense_projection.weight\n", + " \tmodule.decoder.layers.5.mixer.hyena_proj_conv.short_conv_weight\n", + " \tmodule.decoder.layers.4.mixer.dense.weight\n", + " \tmodule.decoder.layers.23.mixer.mixer.conv_bias\n", + " \tmodule.decoder.layers.17.mlp.linear_fc1.weight\n", + " \tmodule.decoder.layers.15.mixer.mixer.filter.h\n", + " \tmodule.decoder.layers.13.mixer.mixer.filter.R\n", + " \tmodule.decoder.layers.11.mixer.dense_projection.layer_norm_weight\n", + " \tmodule.decoder.layers.7.mixer.dense.bias\n", + " \tmodule.decoder.layers.2.mixer.dense.weight\n", + " \tmodule.embedding.word_embeddings.weight\n", + " \tmodule.decoder.layers.22.mixer.hyena_proj_conv.short_conv_weight\n", + " \tmodule.decoder.layers.19.mlp.linear_fc1.layer_norm_weight\n", + " \tmodule.decoder.layers.11.mlp.linear_fc2.weight\n", + " \tmodule.decoder.layers.9.mixer.dense_projection.weight\n", + " \tmodule.decoder.layers.6.mlp.linear_fc2.weight\n", + " \tmodule.decoder.layers.4.mixer.dense_projection.layer_norm_weight\n", + " \tmodule.decoder.layers.0.mixer.dense.weight\n", + " \tmodule.decoder.layers.19.mlp.linear_fc1.weight\n", + " \tmodule.decoder.layers.17.self_attention.linear_qkv.layer_norm_weight\n", + " \tmodule.decoder.layers.15.mixer.mixer.conv_bias\n", + " \tmodule.decoder.layers.13.mixer.dense_projection.layer_norm_weight\n", + " 
\tmodule.decoder.layers.7.mlp.linear_fc1.weight\n", + " \tmodule.decoder.layers.4.mlp.linear_fc2.weight\n", + " \tmodule.decoder.layers.2.mixer.dense.bias\n", + " \tmodule.decoder.layers.1.mixer.dense.bias\n", + " \tmodule.decoder.layers.23.mlp.linear_fc2.weight\n", + " \tmodule.decoder.layers.18.mlp.linear_fc1.layer_norm_weight\n", + " \tmodule.decoder.layers.16.mixer.mixer.filter.gamma\n", + " \tmodule.decoder.layers.11.mixer.dense_projection.weight\n", + " \tmodule.decoder.layers.8.mlp.linear_fc2.weight\n", + " \tmodule.decoder.layers.6.mixer.mixer.filter.p\n", + " \tmodule.decoder.layers.3.self_attention.linear_proj.weight\n", + " \tmodule.decoder.layers.2.mixer.mixer.filter.gamma\n", + " \tmodule.decoder.layers.0.mlp.linear_fc1.weight\n", + " \tmodule.decoder.layers.20.mixer.dense.bias\n", + " \tmodule.decoder.layers.19.mixer.mixer.filter.h\n", + " \tmodule.decoder.layers.17.self_attention.linear_proj.weight\n", + " \tmodule.decoder.layers.14.mixer.mixer.short_conv.short_conv_weight\n", + " \tmodule.decoder.layers.7.mixer.dense.weight\n", + " \tmodule.decoder.layers.4.mixer.dense_projection.weight\n", + " \tmodule.decoder.layers.1.mixer.dense_projection.layer_norm_weight\n", + " \tmodule.decoder.layers.23.mixer.mixer.filter.p\n", + " \tmodule.decoder.layers.21.mixer.hyena_proj_conv.short_conv_weight\n", + " \tmodule.decoder.layers.13.mixer.dense_projection.weight\n", + " \tmodule.decoder.layers.10.mlp.linear_fc2.weight\n", + " \tmodule.decoder.layers.6.mixer.hyena_proj_conv.short_conv_weight\n", + "[NeMo I 2025-03-04 01:01:11 nemo_logging:393] Doing selective restore from RestoreConfig(path='nemo2_evo2_1b_8k', adapter_path=None, load_model_state=True, load_optim_state=False, load_artifacts=False)\n", + "[NeMo I 2025-03-04 01:01:11 nemo_logging:393] Using dist-ckpt load strategy.\n", + "[WARNING | py.warnings ]: /workspaces/bionemo-framework/3rdparty/Megatron-LM/megatron/core/dist_checkpointing/strategies/torch.py:847: FutureWarning: `load_state_dict` is 
deprecated and will be removed in future versions. Please use `load` instead.\n", " checkpoint.load_state_dict(\n", "\n", "[WARNING | py.warnings ]: /usr/local/lib/python3.12/dist-packages/torch/distributed/checkpoint/planner_helpers.py:316: FutureWarning: Please use DTensor instead and we are deprecating ShardedTensor.\n", " device = getattr(value, \"device\", None)\n", "\n", - "[NeMo I 2025-03-03 23:36:33 nemo_logging:393] Global Checkpoint Load : Rank : 0 : Start time : 1741044992.504s : Time spent in load_checkpoint: 1.046s\n", - "[NeMo I 2025-03-03 23:36:33 nemo_logging:393] Restoring model weights from RestoreConfig(path='nemo2_evo2_1b_8k', adapter_path=None, load_model_state=True, load_optim_state=False, load_artifacts=False)\n", - "[NeMo I 2025-03-03 23:36:33 nemo_logging:393] Finished restoring from RestoreConfig(path='nemo2_evo2_1b_8k', adapter_path=None, load_model_state=True, load_optim_state=False, load_artifacts=False), cleaning up.\n" + "[NeMo I 2025-03-04 01:01:12 nemo_logging:393] Global Checkpoint Load : Rank : 0 : Start time : 1741050071.495s : Time spent in load_checkpoint: 0.932s\n", + "[NeMo I 2025-03-04 01:01:12 nemo_logging:393] Restoring model weights from RestoreConfig(path='nemo2_evo2_1b_8k', adapter_path=None, load_model_state=True, load_optim_state=False, load_artifacts=False)\n", + "[NeMo I 2025-03-04 01:01:12 nemo_logging:393] Finished restoring from RestoreConfig(path='nemo2_evo2_1b_8k', adapter_path=None, load_model_state=True, load_optim_state=False, load_artifacts=False), cleaning up.\n" ] } ], "source": [ + "print(f\"Running command: {predict_ref_command}\")\n", "!{predict_ref_command}" ] }, @@ -871,14 +913,17 @@ }, { "cell_type": "code", - "execution_count": 24, + "execution_count": 11, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "[NeMo W 2025-03-03 23:37:15 nemo_logging:405] /usr/local/lib/python3.12/dist-packages/pydub/utils.py:170: RuntimeWarning: Couldn't find ffmpeg or avconv - 
defaulting to ffmpeg, but may not work\n", + "Running command: predict_evo2 --fasta brca1_fasta_files/brca1_variant_sequences.fasta --ckpt-dir nemo2_evo2_1b_8k --output-dir brca1_fasta_files/variant_predictions --model-size 1b --tensor-parallel-size 1 --pipeline-model-parallel-size 1 --context-parallel-size 1 --output-log-prob-seqs --fp8\n", + "[WARNING | bitsandbytes.cextension]: Could not find the bitsandbytes CUDA binary at PosixPath('/usr/local/lib/python3.12/dist-packages/bitsandbytes/libbitsandbytes_cuda128.so')\n", + "[WARNING | bitsandbytes.cextension]: The installed version of bitsandbytes was compiled without GPU support. 8-bit optimizers, 8-bit multiplication, and GPU quantization are unavailable.\n", + "[NeMo W 2025-03-04 01:02:34 nemo_logging:405] /usr/local/lib/python3.12/dist-packages/pydub/utils.py:170: RuntimeWarning: Couldn't find ffmpeg or avconv - defaulting to ffmpeg, but may not work\n", " warn(\"Couldn't find ffmpeg or avconv - defaulting to ffmpeg, but may not work\", RuntimeWarning)\n", " \n", "[WARNING | py.warnings ]: /usr/local/lib/python3.12/dist-packages/mamba_ssm/ops/selective_scan_interface.py:163: FutureWarning: `torch.cuda.amp.custom_fwd(args...)` is deprecated. Please use `torch.amp.custom_fwd(args..., device_type='cuda')` instead.\n", @@ -908,69 +953,102 @@ "[INFO | pytorch_lightning.utilities.rank_zero]: GPU available: True (cuda), used: True\n", "[INFO | pytorch_lightning.utilities.rank_zero]: TPU available: False, using: 0 TPU cores\n", "[INFO | pytorch_lightning.utilities.rank_zero]: HPU available: False, using: 0 HPUs\n", - "[NeMo W 2025-03-03 23:37:17 nemo_logging:405] No version folders would be created under the log folder as 'resume_if_exists' is enabled.\n", - "[NeMo I 2025-03-03 23:37:17 nemo_logging:393] Experiments will be logged at /tmp/tmpcu9581ff/default\n", - "[NeMo W 2025-03-03 23:37:17 nemo_logging:405] \"update_logger_directory\" is True. 
Overwriting tensorboard logger \"save_dir\" to /tmp/tmpcu9581ff\n", - "[NeMo I 2025-03-03 23:37:17 nemo_logging:393] Using byte-level tokenization\n", - "[NeMo I 2025-03-03 23:37:17 nemo_logging:393] Rank 0 has data parallel group : [0]\n", - "[NeMo I 2025-03-03 23:37:17 nemo_logging:393] Rank 0 has combined group of data parallel and context parallel : [0]\n", - "[NeMo I 2025-03-03 23:37:17 nemo_logging:393] All data parallel group ranks with context parallel combined: [[0]]\n", - "[NeMo I 2025-03-03 23:37:17 nemo_logging:393] Ranks 0 has data parallel rank: 0\n", - "[NeMo I 2025-03-03 23:37:17 nemo_logging:393] Rank 0 has context parallel group: [0]\n", - "[NeMo I 2025-03-03 23:37:17 nemo_logging:393] All context parallel group ranks: [[0]]\n", - "[NeMo I 2025-03-03 23:37:17 nemo_logging:393] Ranks 0 has context parallel rank: 0\n", - "[NeMo I 2025-03-03 23:37:17 nemo_logging:393] Rank 0 has model parallel group: [0]\n", - "[NeMo I 2025-03-03 23:37:17 nemo_logging:393] All model parallel group ranks: [[0]]\n", - "[NeMo I 2025-03-03 23:37:17 nemo_logging:393] Rank 0 has tensor model parallel group: [0]\n", - "[NeMo I 2025-03-03 23:37:17 nemo_logging:393] All tensor model parallel group ranks: [[0]]\n", - "[NeMo I 2025-03-03 23:37:17 nemo_logging:393] Rank 0 has tensor model parallel rank: 0\n", - "[NeMo I 2025-03-03 23:37:17 nemo_logging:393] Rank 0 has pipeline model parallel group: [0]\n", - "[NeMo I 2025-03-03 23:37:17 nemo_logging:393] Rank 0 has embedding group: [0]\n", - "[NeMo I 2025-03-03 23:37:17 nemo_logging:393] All pipeline model parallel group ranks: [[0]]\n", - "[NeMo I 2025-03-03 23:37:17 nemo_logging:393] Rank 0 has pipeline model parallel rank 0\n", - "[NeMo I 2025-03-03 23:37:17 nemo_logging:393] All embedding group ranks: [[0]]\n", - "[NeMo I 2025-03-03 23:37:17 nemo_logging:393] Rank 0 has embedding rank: 0\n", + "[NeMo W 2025-03-04 01:02:35 nemo_logging:405] No version folders would be created under the log folder as 'resume_if_exists' is 
enabled.\n", + "[NeMo I 2025-03-04 01:02:35 nemo_logging:393] Experiments will be logged at /tmp/tmpf9avvfzw/default\n", + "[NeMo W 2025-03-04 01:02:35 nemo_logging:405] \"update_logger_directory\" is True. Overwriting tensorboard logger \"save_dir\" to /tmp/tmpf9avvfzw\n", + "[NeMo I 2025-03-04 01:02:35 nemo_logging:393] Using byte-level tokenization\n", + "[NeMo I 2025-03-04 01:02:35 nemo_logging:393] Rank 0 has data parallel group : [0]\n", + "[NeMo I 2025-03-04 01:02:35 nemo_logging:393] Rank 0 has combined group of data parallel and context parallel : [0]\n", + "[NeMo I 2025-03-04 01:02:35 nemo_logging:393] All data parallel group ranks with context parallel combined: [[0]]\n", + "[NeMo I 2025-03-04 01:02:35 nemo_logging:393] Ranks 0 has data parallel rank: 0\n", + "[NeMo I 2025-03-04 01:02:35 nemo_logging:393] Rank 0 has context parallel group: [0]\n", + "[NeMo I 2025-03-04 01:02:35 nemo_logging:393] All context parallel group ranks: [[0]]\n", + "[NeMo I 2025-03-04 01:02:35 nemo_logging:393] Ranks 0 has context parallel rank: 0\n", + "[NeMo I 2025-03-04 01:02:35 nemo_logging:393] Rank 0 has model parallel group: [0]\n", + "[NeMo I 2025-03-04 01:02:35 nemo_logging:393] All model parallel group ranks: [[0]]\n", + "[NeMo I 2025-03-04 01:02:35 nemo_logging:393] Rank 0 has tensor model parallel group: [0]\n", + "[NeMo I 2025-03-04 01:02:35 nemo_logging:393] All tensor model parallel group ranks: [[0]]\n", + "[NeMo I 2025-03-04 01:02:35 nemo_logging:393] Rank 0 has tensor model parallel rank: 0\n", + "[NeMo I 2025-03-04 01:02:35 nemo_logging:393] Rank 0 has pipeline model parallel group: [0]\n", + "[NeMo I 2025-03-04 01:02:35 nemo_logging:393] Rank 0 has embedding group: [0]\n", + "[NeMo I 2025-03-04 01:02:35 nemo_logging:393] All pipeline model parallel group ranks: [[0]]\n", + "[NeMo I 2025-03-04 01:02:35 nemo_logging:393] Rank 0 has pipeline model parallel rank 0\n", + "[NeMo I 2025-03-04 01:02:35 nemo_logging:393] All embedding group ranks: [[0]]\n", + "[NeMo I 
2025-03-04 01:02:35 nemo_logging:393] Rank 0 has embedding rank: 0\n", "Initializing distributed: GLOBAL_RANK: 0, MEMBER: 1/1\n", "[INFO | pytorch_lightning.utilities.rank_zero]: ----------------------------------------------------------------------------------------------------\n", "distributed_backend=nccl\n", "All distributed processes registered. Starting with 1 processes\n", "----------------------------------------------------------------------------------------------------\n", "\n", - "[NeMo I 2025-03-03 23:37:17 num_microbatches_calculator:228] setting number of microbatches to constant 1\n", - "[NeMo I 2025-03-03 23:37:17 nemo_logging:393] Padded vocab_size: 512, original vocab_size: 512, dummy tokens: 0.\n", - "[NeMo W 2025-03-03 23:37:17 random:220] CPU RNG state changed within GPU RNG context\n", - "[NeMo W 2025-03-03 23:37:17 random:220] CPU RNG state changed within GPU RNG context\n", - "[NeMo W 2025-03-03 23:37:17 random:220] CPU RNG state changed within GPU RNG context\n", - "[NeMo W 2025-03-03 23:37:17 random:220] CPU RNG state changed within GPU RNG context\n", - "[NeMo W 2025-03-03 23:37:17 random:220] CPU RNG state changed within GPU RNG context\n", - "[NeMo W 2025-03-03 23:37:17 random:220] CPU RNG state changed within GPU RNG context\n", - "[NeMo W 2025-03-03 23:37:17 random:220] CPU RNG state changed within GPU RNG context\n", - "[NeMo W 2025-03-03 23:37:17 random:220] CPU RNG state changed within GPU RNG context\n", - "[NeMo W 2025-03-03 23:37:17 random:220] CPU RNG state changed within GPU RNG context\n", - "[NeMo W 2025-03-03 23:37:17 random:220] CPU RNG state changed within GPU RNG context\n", - "[NeMo W 2025-03-03 23:37:17 random:220] CPU RNG state changed within GPU RNG context\n", - "[NeMo W 2025-03-03 23:37:17 random:220] CPU RNG state changed within GPU RNG context\n", - "[NeMo W 2025-03-03 23:37:17 random:220] CPU RNG state changed within GPU RNG context\n", - "[NeMo W 2025-03-03 23:37:17 random:220] CPU RNG state changed within GPU 
RNG context\n", - "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2]\n", - "[NeMo W 2025-03-03 23:37:17 nemo_logging:405] Could not copy Trainer's 'max_steps' to LR scheduler's 'max_steps'. If you are not using an LR scheduler, this warning can safely be ignored.\n", - "[NeMo I 2025-03-03 23:37:17 nemo_logging:393] > number of parameters on (tensor, pipeline) model parallel rank (0 ,0): 1108204800\n", - "[NeMo I 2025-03-03 23:37:17 utils:302] Setting up DistributedDataParallel with config DistributedDataParallelConfig(grad_reduce_in_fp32=False, overlap_grad_reduce=False, overlap_param_gather=False, align_param_gather=False, use_distributed_optimizer=False, num_distributed_optimizer_instances=1, check_for_nan_in_grad=True, check_for_large_grads=False, bucket_size=None, average_in_collective=False, fp8_param_gather=False)\n", - "[NeMo I 2025-03-03 23:37:17 utils:323] Number of buckets for gradient all-reduce / reduce-scatter: 1\n", + "[NeMo I 2025-03-04 01:02:35 num_microbatches_calculator:228] setting number of microbatches to constant 1\n", + "[NeMo I 2025-03-04 01:02:35 nemo_logging:393] Padded vocab_size: 512, original vocab_size: 512, dummy tokens: 0.\n", + "[NeMo W 2025-03-04 01:02:35 random:220] CPU RNG state changed within GPU RNG context\n", + "[NeMo W 2025-03-04 01:02:35 random:220] CPU RNG state changed within GPU RNG context\n", + "[NeMo W 2025-03-04 01:02:35 random:220] CPU RNG state changed within GPU RNG context\n", + "[NeMo W 2025-03-04 01:02:35 random:220] CPU RNG state changed within GPU RNG context\n", + "[NeMo W 2025-03-04 01:02:35 random:220] CPU RNG state changed within GPU RNG context\n", + "[NeMo W 2025-03-04 01:02:35 random:220] CPU RNG state changed within GPU RNG context\n", + "[NeMo W 2025-03-04 01:02:35 random:220] CPU RNG state changed within GPU RNG context\n", + "[NeMo W 2025-03-04 01:02:35 random:220] CPU RNG state changed within GPU RNG context\n", + "[NeMo W 2025-03-04 01:02:35 random:220] CPU RNG state changed within GPU RNG 
context\n", + "[NeMo W 2025-03-04 01:02:35 random:220] CPU RNG state changed within GPU RNG context\n", + "[NeMo W 2025-03-04 01:02:35 random:220] CPU RNG state changed within GPU RNG context\n", + "[NeMo W 2025-03-04 01:02:35 random:220] CPU RNG state changed within GPU RNG context\n", + "[NeMo W 2025-03-04 01:02:35 random:220] CPU RNG state changed within GPU RNG context\n", + "[NeMo W 2025-03-04 01:02:35 random:220] CPU RNG state changed within GPU RNG context\n", + "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n", + "[NeMo W 2025-03-04 01:02:35 nemo_logging:405] Could not copy Trainer's 'max_steps' to LR scheduler's 'max_steps'. If you are not using an LR scheduler, this warning can safely be ignored.\n", + "[NeMo I 2025-03-04 01:02:35 nemo_logging:393] > number of parameters on (tensor, pipeline) model parallel rank (0 ,0): 1108204800\n", + "[NeMo I 2025-03-04 01:02:35 utils:302] Setting up DistributedDataParallel with config DistributedDataParallelConfig(grad_reduce_in_fp32=False, overlap_grad_reduce=False, overlap_param_gather=False, align_param_gather=False, use_distributed_optimizer=False, num_distributed_optimizer_instances=1, check_for_nan_in_grad=True, check_for_large_grads=False, bucket_size=None, average_in_collective=False, fp8_param_gather=False)\n", + "[NeMo I 2025-03-04 01:02:35 utils:323] Number of buckets for gradient all-reduce / reduce-scatter: 1\n", " Params for bucket 1 (1108204800 elements):\n", + " \tmodule.decoder.layers.24.self_attention.linear_qkv.weight\n", + " \tmodule.decoder.layers.22.mixer.dense_projection.weight\n", + " \tmodule.decoder.layers.20.mixer.mixer.filter.gamma\n", + " \tmodule.decoder.layers.15.mixer.hyena_proj_conv.short_conv_weight\n", + " \tmodule.decoder.layers.12.mixer.mixer.conv_bias\n", + " \tmodule.decoder.layers.9.mixer.dense.bias\n", + " \tmodule.decoder.layers.7.mixer.dense.bias\n", + " \tmodule.decoder.layers.4.mlp.linear_fc2.weight\n", + " \tmodule.decoder.layers.2.mixer.dense.bias\n", + " 
\tmodule.decoder.layers.1.mixer.dense.bias\n", + " \tmodule.decoder.layers.23.mixer.mixer.filter.R\n", + " \tmodule.decoder.layers.21.mixer.dense_projection.layer_norm_weight\n", + " \tmodule.decoder.layers.18.mixer.mixer.short_conv.short_conv_weight\n", + " \tmodule.decoder.layers.10.self_attention.linear_qkv.layer_norm_weight\n", + " \tmodule.decoder.layers.6.mixer.dense_projection.layer_norm_weight\n", + " \tmodule.decoder.layers.3.mlp.linear_fc1.layer_norm_weight\n", + " \tmodule.decoder.layers.2.mixer.mixer.filter.gamma\n", + " \tmodule.decoder.layers.21.mlp.linear_fc2.weight\n", + " \tmodule.decoder.layers.16.mlp.linear_fc2.weight\n", + " \tmodule.decoder.layers.11.mlp.linear_fc1.layer_norm_weight\n", + " \tmodule.decoder.layers.9.mixer.mixer.filter.gamma\n", + " \tmodule.decoder.layers.4.mlp.linear_fc1.layer_norm_weight\n", + " \tmodule.decoder.layers.1.mixer.dense_projection.layer_norm_weight\n", + " \tmodule.decoder.final_norm.weight\n", + " \tmodule.decoder.layers.23.mixer.dense_projection.layer_norm_weight\n", + " \tmodule.decoder.layers.20.mixer.mixer.conv_bias\n", + " \tmodule.decoder.layers.13.mixer.dense.bias\n", + " \tmodule.decoder.layers.12.mixer.mixer.filter.h\n", + " \tmodule.decoder.layers.10.self_attention.linear_proj.weight\n", + " \tmodule.decoder.layers.7.mixer.mixer.short_conv.short_conv_weight\n", + " \tmodule.decoder.layers.5.mixer.dense.bias\n", " \tmodule.decoder.layers.21.mixer.dense_projection.weight\n", " \tmodule.decoder.layers.19.mixer.hyena_proj_conv.short_conv_weight\n", " \tmodule.decoder.layers.16.mixer.mixer.filter.p\n", " \tmodule.decoder.layers.14.mixer.hyena_proj_conv.short_conv_weight\n", " \tmodule.decoder.layers.11.mixer.dense.bias\n", + " \tmodule.decoder.layers.10.self_attention.linear_proj.bias\n", " \tmodule.decoder.layers.6.mixer.dense_projection.weight\n", " \tmodule.decoder.layers.3.mlp.linear_fc1.weight\n", - " \tmodule.decoder.layers.0.mlp.linear_fc1.layer_norm_weight\n", + " 
\tmodule.decoder.layers.0.mixer.dense_projection.weight\n", " \tmodule.decoder.layers.21.mixer.dense.bias\n", " \tmodule.decoder.layers.15.mixer.dense.bias\n", - " \tmodule.decoder.layers.12.mlp.linear_fc1.weight\n", + " \tmodule.decoder.layers.13.mixer.dense_projection.layer_norm_weight\n", " \tmodule.decoder.layers.4.mlp.linear_fc1.weight\n", - " \tmodule.decoder.layers.2.mixer.mixer.filter.p\n", + " \tmodule.decoder.layers.2.mixer.mixer.filter.R\n", + " \tmodule.decoder.layers.1.mixer.dense_projection.weight\n", " \tmodule.decoder.layers.23.mixer.dense_projection.weight\n", " \tmodule.decoder.layers.20.mlp.linear_fc2.weight\n", " \tmodule.decoder.layers.16.mixer.hyena_proj_conv.short_conv_weight\n", @@ -978,47 +1056,44 @@ " \tmodule.decoder.layers.8.mixer.hyena_proj_conv.short_conv_weight\n", " \tmodule.decoder.layers.5.mlp.linear_fc2.weight\n", " \tmodule.decoder.layers.3.self_attention.linear_qkv.layer_norm_weight\n", - " \tmodule.decoder.layers.2.mixer.hyena_proj_conv.short_conv_weight\n", - " \tmodule.decoder.layers.1.mixer.dense_projection.layer_norm_weight\n", " \tmodule.decoder.layers.24.self_attention.linear_proj.bias\n", " \tmodule.decoder.layers.21.mlp.linear_fc1.weight\n", " \tmodule.decoder.layers.16.mlp.linear_fc1.weight\n", " \tmodule.decoder.layers.15.mixer.dense_projection.layer_norm_weight\n", " \tmodule.decoder.layers.11.mixer.mixer.short_conv.short_conv_weight\n", " \tmodule.decoder.layers.6.mixer.mixer.conv_bias\n", - " \tmodule.decoder.layers.0.mlp.linear_fc1.weight\n", + " \tmodule.decoder.layers.0.mlp.linear_fc2.weight\n", " \tmodule.decoder.layers.22.mlp.linear_fc2.weight\n", " \tmodule.decoder.layers.20.mixer.mixer.filter.p\n", " \tmodule.decoder.layers.18.mixer.hyena_proj_conv.short_conv_weight\n", " \tmodule.decoder.layers.15.mixer.dense.weight\n", - " \tmodule.decoder.layers.10.self_attention.linear_proj.bias\n", " \tmodule.decoder.layers.9.mlp.linear_fc2.weight\n", - " \tmodule.decoder.layers.2.mlp.linear_fc1.weight\n", - " 
\tmodule.decoder.layers.2.mixer.dense_projection.layer_norm_weight\n", + " \tmodule.decoder.layers.2.mlp.linear_fc2.weight\n", + " \tmodule.decoder.layers.1.mixer.dense.weight\n", " \tmodule.decoder.layers.21.mixer.dense.weight\n", " \tmodule.decoder.layers.19.mixer.dense.bias\n", " \tmodule.decoder.layers.16.mlp.linear_fc1.layer_norm_weight\n", - " \tmodule.decoder.layers.12.mixer.dense_projection.layer_norm_weight\n", " \tmodule.decoder.layers.8.mlp.linear_fc1.layer_norm_weight\n", " \tmodule.decoder.layers.4.mixer.dense.bias\n", - " \tmodule.decoder.layers.2.mixer.mixer.conv_bias\n", + " \tmodule.decoder.layers.3.self_attention.linear_proj.bias\n", + " \tmodule.decoder.layers.0.mlp.linear_fc1.layer_norm_weight\n", " \tmodule.decoder.layers.24.mlp.linear_fc2.weight\n", " \tmodule.decoder.layers.20.mixer.hyena_proj_conv.short_conv_weight\n", " \tmodule.decoder.layers.17.self_attention.linear_qkv.weight\n", " \tmodule.decoder.layers.15.mixer.dense_projection.weight\n", - " \tmodule.decoder.layers.13.mixer.mixer.filter.gamma\n", + " \tmodule.decoder.layers.12.mixer.hyena_proj_conv.short_conv_weight\n", " \tmodule.decoder.layers.9.mixer.mixer.filter.p\n", " \tmodule.decoder.layers.7.mixer.hyena_proj_conv.short_conv_weight\n", " \tmodule.decoder.layers.5.mixer.mixer.conv_bias\n", - " \tmodule.decoder.layers.2.mlp.linear_fc1.layer_norm_weight\n", + " \tmodule.decoder.layers.2.mixer.mixer.filter.p\n", " \tmodule.decoder.layers.23.mixer.dense.bias\n", " \tmodule.decoder.layers.20.mlp.linear_fc1.weight\n", " \tmodule.decoder.layers.19.mixer.dense_projection.layer_norm_weight\n", " \tmodule.decoder.layers.16.mixer.mixer.filter.R\n", " \tmodule.decoder.layers.14.mixer.dense_projection.layer_norm_weight\n", + " \tmodule.decoder.layers.13.mixer.mixer.filter.gamma\n", " \tmodule.decoder.layers.5.mlp.linear_fc1.weight\n", - " \tmodule.decoder.layers.3.self_attention.linear_proj.bias\n", - " \tmodule.decoder.layers.0.mixer.dense.bias\n", + " 
\tmodule.decoder.layers.2.mixer.hyena_proj_conv.short_conv_weight\n", " \tmodule.decoder.layers.24.mlp.linear_fc1.layer_norm_weight\n", " \tmodule.decoder.layers.22.mlp.linear_fc1.layer_norm_weight\n", " \tmodule.decoder.layers.19.mixer.dense.weight\n", @@ -1027,63 +1102,68 @@ " \tmodule.decoder.layers.9.mixer.hyena_proj_conv.short_conv_weight\n", " \tmodule.decoder.layers.6.mixer.dense.weight\n", " \tmodule.decoder.layers.4.mixer.mixer.short_conv.short_conv_weight\n", - " \tmodule.decoder.layers.1.mixer.mixer.conv_bias\n", + " \tmodule.decoder.layers.0.mlp.linear_fc1.weight\n", " \tmodule.decoder.layers.22.mlp.linear_fc1.weight\n", " \tmodule.decoder.layers.20.mlp.linear_fc1.layer_norm_weight\n", " \tmodule.decoder.layers.16.mixer.dense_projection.layer_norm_weight\n", " \tmodule.decoder.layers.13.mixer.mixer.conv_bias\n", + " \tmodule.decoder.layers.12.mlp.linear_fc1.weight\n", " \tmodule.decoder.layers.9.mlp.linear_fc1.weight\n", " \tmodule.decoder.layers.8.mixer.dense_projection.layer_norm_weight\n", " \tmodule.decoder.layers.5.mixer.mixer.filter.h\n", - " \tmodule.decoder.layers.1.mlp.linear_fc1.layer_norm_weight\n", + " \tmodule.decoder.layers.2.mlp.linear_fc1.weight\n", + " \tmodule.decoder.layers.2.mixer.dense_projection.layer_norm_weight\n", " \tmodule.decoder.layers.23.mixer.dense.weight\n", " \tmodule.decoder.layers.21.mlp.linear_fc1.layer_norm_weight\n", " \tmodule.decoder.layers.19.mixer.dense_projection.weight\n", " \tmodule.decoder.layers.14.mixer.dense_projection.weight\n", - " \tmodule.decoder.layers.12.mixer.mixer.conv_bias\n", " \tmodule.decoder.layers.11.mixer.hyena_proj_conv.short_conv_weight\n", " \tmodule.decoder.layers.8.mixer.dense.weight\n", " \tmodule.decoder.layers.6.mixer.dense.bias\n", - " \tmodule.decoder.layers.0.mixer.mixer.short_conv.short_conv_weight\n", + " \tmodule.decoder.layers.2.mixer.mixer.conv_bias\n", " \tmodule.decoder.layers.24.mlp.linear_fc1.weight\n", " \tmodule.decoder.layers.22.mixer.mixer.filter.h\n", " 
\tmodule.decoder.layers.20.mixer.mixer.filter.R\n", " \tmodule.decoder.layers.18.mixer.dense_projection.layer_norm_weight\n", " \tmodule.decoder.layers.14.mixer.dense.bias\n", " \tmodule.decoder.layers.9.mlp.linear_fc1.layer_norm_weight\n", - " \tmodule.decoder.layers.1.mixer.mixer.filter.h\n", + " \tmodule.decoder.layers.2.mlp.linear_fc1.layer_norm_weight\n", " \tmodule.decoder.layers.18.mlp.linear_fc2.weight\n", " \tmodule.decoder.layers.16.mixer.dense_projection.weight\n", " \tmodule.decoder.layers.13.mlp.linear_fc2.weight\n", " \tmodule.decoder.layers.10.self_attention.linear_qkv.weight\n", " \tmodule.decoder.layers.8.mixer.dense_projection.weight\n", " \tmodule.decoder.layers.6.mixer.mixer.filter.gamma\n", - " \tmodule.decoder.layers.2.mixer.dense_projection.weight\n", - " \tmodule.decoder.layers.0.mixer.dense_projection.layer_norm_weight\n", + " \tmodule.decoder.layers.0.mixer.dense.bias\n", " \tmodule.decoder.layers.24.self_attention.linear_qkv.layer_norm_weight\n", " \tmodule.decoder.layers.22.mixer.mixer.conv_bias\n", " \tmodule.decoder.layers.20.mixer.dense_projection.layer_norm_weight\n", " \tmodule.decoder.layers.14.mlp.linear_fc1.weight\n", - " \tmodule.decoder.layers.12.mixer.dense_projection.weight\n", + " \tmodule.decoder.layers.12.mixer.dense_projection.layer_norm_weight\n", " \tmodule.decoder.layers.9.mixer.mixer.filter.R\n", " \tmodule.decoder.layers.7.mixer.dense_projection.layer_norm_weight\n", + " \tmodule.decoder.layers.1.mixer.mixer.conv_bias\n", " \tmodule.decoder.layers.23.mixer.mixer.filter.gamma\n", " \tmodule.decoder.layers.18.mixer.dense_projection.weight\n", " \tmodule.decoder.layers.15.mlp.linear_fc2.weight\n", " \tmodule.decoder.layers.13.mixer.mixer.filter.p\n", + " \tmodule.decoder.layers.12.mixer.dense.weight\n", " \tmodule.decoder.layers.7.mlp.linear_fc2.weight\n", - " \tmodule.decoder.layers.1.mlp.linear_fc2.weight\n", + " \tmodule.decoder.layers.1.mlp.linear_fc1.layer_norm_weight\n", + " 
\tmodule.embedding.word_embeddings.weight\n", " \tmodule.decoder.layers.24.self_attention.linear_proj.weight\n", " \tmodule.decoder.layers.21.mixer.mixer.short_conv.short_conv_weight\n", " \tmodule.decoder.layers.14.mixer.dense.weight\n", " \tmodule.decoder.layers.9.mixer.dense_projection.layer_norm_weight\n", " \tmodule.decoder.layers.4.mixer.hyena_proj_conv.short_conv_weight\n", + " \tmodule.decoder.layers.0.mixer.mixer.short_conv.short_conv_weight\n", " \tmodule.decoder.layers.20.mixer.dense_projection.weight\n", " \tmodule.decoder.layers.17.mlp.linear_fc2.weight\n", - " \tmodule.decoder.layers.17.self_attention.linear_proj.bias\n", " \tmodule.decoder.layers.13.mixer.hyena_proj_conv.short_conv_weight\n", + " \tmodule.decoder.layers.12.mixer.dense_projection.weight\n", " \tmodule.decoder.layers.7.mixer.dense_projection.weight\n", " \tmodule.decoder.layers.5.mixer.hyena_proj_conv.short_conv_weight\n", + " \tmodule.decoder.layers.1.mixer.mixer.filter.h\n", " \tmodule.decoder.layers.23.mixer.mixer.conv_bias\n", " \tmodule.decoder.layers.18.mlp.linear_fc1.weight\n", " \tmodule.decoder.layers.16.mixer.dense.bias\n", @@ -1091,55 +1171,53 @@ " \tmodule.decoder.layers.11.mixer.dense_projection.layer_norm_weight\n", " \tmodule.decoder.layers.8.mixer.dense.bias\n", " \tmodule.decoder.layers.3.self_attention.linear_qkv.weight\n", - " \tmodule.decoder.layers.0.mixer.hyena_proj_conv.short_conv_weight\n", + " \tmodule.decoder.layers.2.mixer.dense_projection.weight\n", + " \tmodule.decoder.layers.0.mixer.dense_projection.layer_norm_weight\n", " \tmodule.decoder.layers.22.mixer.hyena_proj_conv.short_conv_weight\n", " \tmodule.decoder.layers.19.mlp.linear_fc2.weight\n", " \tmodule.decoder.layers.17.mlp.linear_fc1.layer_norm_weight\n", " \tmodule.decoder.layers.15.mlp.linear_fc1.layer_norm_weight\n", - " \tmodule.decoder.layers.12.mixer.dense.weight\n", " \tmodule.decoder.layers.11.mlp.linear_fc2.weight\n", " \tmodule.decoder.layers.9.mixer.dense_projection.weight\n", " 
\tmodule.decoder.layers.6.mlp.linear_fc2.weight\n", - " \tmodule.decoder.layers.1.mixer.hyena_proj_conv.short_conv_weight\n", - " \tmodule.decoder.layers.1.mixer.dense_projection.weight\n", " \tmodule.decoder.layers.18.mixer.dense.weight\n", " \tmodule.decoder.layers.15.mlp.linear_fc1.weight\n", " \tmodule.decoder.layers.13.mlp.linear_fc1.layer_norm_weight\n", + " \tmodule.decoder.layers.12.mlp.linear_fc2.weight\n", + " \tmodule.decoder.layers.12.mlp.linear_fc1.layer_norm_weight\n", " \tmodule.decoder.layers.7.mlp.linear_fc1.weight\n", - " \tmodule.decoder.layers.1.mlp.linear_fc1.weight\n", + " \tmodule.decoder.layers.1.mlp.linear_fc2.weight\n", " \tmodule.decoder.layers.23.mlp.linear_fc2.weight\n", + " \tmodule.decoder.layers.17.self_attention.linear_proj.bias\n", " \tmodule.decoder.layers.16.mixer.dense.weight\n", " \tmodule.decoder.layers.14.mlp.linear_fc1.layer_norm_weight\n", - " \tmodule.decoder.layers.12.mixer.hyena_proj_conv.short_conv_weight\n", " \tmodule.decoder.layers.11.mixer.dense_projection.weight\n", " \tmodule.decoder.layers.8.mlp.linear_fc2.weight\n", " \tmodule.decoder.layers.6.mixer.mixer.filter.p\n", " \tmodule.decoder.layers.4.mixer.dense.weight\n", - " \tmodule.embedding.word_embeddings.weight\n", " \tmodule.decoder.layers.20.mixer.dense.bias\n", " \tmodule.decoder.layers.17.mlp.linear_fc1.weight\n", " \tmodule.decoder.layers.15.mixer.mixer.filter.h\n", " \tmodule.decoder.layers.13.mixer.mixer.filter.R\n", + " \tmodule.decoder.layers.12.mixer.dense.bias\n", " \tmodule.decoder.layers.9.mixer.mixer.conv_bias\n", " \tmodule.decoder.layers.7.mixer.dense.weight\n", - " \tmodule.decoder.layers.2.mixer.dense.weight\n", " \tmodule.decoder.layers.23.mixer.mixer.filter.p\n", " \tmodule.decoder.layers.21.mixer.hyena_proj_conv.short_conv_weight\n", " \tmodule.decoder.layers.19.mlp.linear_fc1.layer_norm_weight\n", " \tmodule.decoder.layers.10.mlp.linear_fc2.weight\n", " \tmodule.decoder.layers.6.mixer.hyena_proj_conv.short_conv_weight\n", " 
\tmodule.decoder.layers.4.mixer.dense_projection.layer_norm_weight\n", - " \tmodule.decoder.layers.0.mixer.dense.weight\n", + " \tmodule.decoder.layers.0.mixer.hyena_proj_conv.short_conv_weight\n", + " \tmodule.decoder.layers.17.self_attention.linear_qkv.layer_norm_weight\n", " \tmodule.decoder.layers.22.mixer.dense.bias\n", " \tmodule.decoder.layers.19.mlp.linear_fc1.weight\n", - " \tmodule.decoder.layers.17.self_attention.linear_qkv.layer_norm_weight\n", " \tmodule.decoder.layers.15.mixer.mixer.conv_bias\n", - " \tmodule.decoder.layers.13.mixer.dense_projection.layer_norm_weight\n", " \tmodule.decoder.layers.11.mlp.linear_fc1.weight\n", " \tmodule.decoder.layers.6.mlp.linear_fc1.weight\n", + " \tmodule.decoder.layers.5.mlp.linear_fc1.layer_norm_weight\n", " \tmodule.decoder.layers.5.mixer.dense_projection.layer_norm_weight\n", - " \tmodule.decoder.layers.2.mixer.dense.bias\n", - " \tmodule.decoder.layers.1.mixer.dense.bias\n", + " \tmodule.decoder.layers.1.mixer.hyena_proj_conv.short_conv_weight\n", " \tmodule.decoder.layers.23.mixer.hyena_proj_conv.short_conv_weight\n", " \tmodule.decoder.layers.20.mixer.dense.weight\n", " \tmodule.decoder.layers.18.mlp.linear_fc1.layer_norm_weight\n", @@ -1148,23 +1226,22 @@ " \tmodule.decoder.layers.8.mixer.mixer.conv_bias\n", " \tmodule.decoder.layers.5.mixer.dense.weight\n", " \tmodule.decoder.layers.3.self_attention.linear_proj.weight\n", - " \tmodule.decoder.layers.2.mixer.mixer.filter.gamma\n", + " \tmodule.decoder.layers.1.mlp.linear_fc1.weight\n", " \tmodule.decoder.layers.23.mlp.linear_fc1.weight\n", " \tmodule.decoder.layers.22.mixer.dense_projection.layer_norm_weight\n", " \tmodule.decoder.layers.19.mixer.mixer.filter.h\n", " \tmodule.decoder.layers.17.self_attention.linear_proj.weight\n", " \tmodule.decoder.layers.14.mixer.mixer.short_conv.short_conv_weight\n", - " \tmodule.decoder.layers.12.mixer.dense.bias\n", " \tmodule.decoder.layers.11.mixer.dense.weight\n", " \tmodule.decoder.layers.8.mlp.linear_fc1.weight\n", 
" \tmodule.decoder.layers.6.mlp.linear_fc1.layer_norm_weight\n", " \tmodule.decoder.layers.4.mixer.dense_projection.weight\n", - " \tmodule.decoder.final_norm.weight\n", " \tmodule.decoder.layers.22.mixer.dense.weight\n", " \tmodule.decoder.layers.13.mixer.dense_projection.weight\n", " \tmodule.decoder.layers.9.mixer.dense.weight\n", " \tmodule.decoder.layers.7.mlp.linear_fc1.layer_norm_weight\n", " \tmodule.decoder.layers.5.mixer.dense_projection.weight\n", + " \tmodule.decoder.layers.2.mixer.dense.weight\n", " \tmodule.decoder.layers.23.mlp.linear_fc1.layer_norm_weight\n", " \tmodule.decoder.layers.19.mixer.mixer.conv_bias\n", " \tmodule.decoder.layers.16.mixer.mixer.conv_bias\n", @@ -1172,54 +1249,23 @@ " \tmodule.decoder.layers.8.mixer.mixer.filter.h\n", " \tmodule.decoder.layers.6.mixer.mixer.filter.R\n", " \tmodule.decoder.layers.3.mlp.linear_fc2.weight\n", - " \tmodule.decoder.layers.0.mixer.dense_projection.weight\n", - " \tmodule.decoder.layers.24.self_attention.linear_qkv.weight\n", - " \tmodule.decoder.layers.22.mixer.dense_projection.weight\n", - " \tmodule.decoder.layers.20.mixer.mixer.filter.gamma\n", - " \tmodule.decoder.layers.15.mixer.hyena_proj_conv.short_conv_weight\n", - " \tmodule.decoder.layers.12.mlp.linear_fc2.weight\n", - " \tmodule.decoder.layers.12.mixer.mixer.filter.h\n", - " \tmodule.decoder.layers.9.mixer.dense.bias\n", - " \tmodule.decoder.layers.7.mixer.dense.bias\n", - " \tmodule.decoder.layers.4.mlp.linear_fc2.weight\n", - " \tmodule.decoder.layers.2.mixer.mixer.filter.R\n", - " \tmodule.decoder.layers.23.mixer.mixer.filter.R\n", - " \tmodule.decoder.layers.21.mixer.dense_projection.layer_norm_weight\n", - " \tmodule.decoder.layers.18.mixer.mixer.short_conv.short_conv_weight\n", - " \tmodule.decoder.layers.10.self_attention.linear_qkv.layer_norm_weight\n", - " \tmodule.decoder.layers.6.mixer.dense_projection.layer_norm_weight\n", - " \tmodule.decoder.layers.5.mlp.linear_fc1.layer_norm_weight\n", - " 
\tmodule.decoder.layers.3.mlp.linear_fc1.layer_norm_weight\n", - " \tmodule.decoder.layers.21.mlp.linear_fc2.weight\n", - " \tmodule.decoder.layers.16.mlp.linear_fc2.weight\n", - " \tmodule.decoder.layers.12.mlp.linear_fc1.layer_norm_weight\n", - " \tmodule.decoder.layers.11.mlp.linear_fc1.layer_norm_weight\n", - " \tmodule.decoder.layers.9.mixer.mixer.filter.gamma\n", - " \tmodule.decoder.layers.4.mlp.linear_fc1.layer_norm_weight\n", - " \tmodule.decoder.layers.0.mlp.linear_fc2.weight\n", - " \tmodule.decoder.layers.23.mixer.dense_projection.layer_norm_weight\n", - " \tmodule.decoder.layers.20.mixer.mixer.conv_bias\n", - " \tmodule.decoder.layers.13.mixer.dense.bias\n", - " \tmodule.decoder.layers.10.self_attention.linear_proj.weight\n", - " \tmodule.decoder.layers.7.mixer.mixer.short_conv.short_conv_weight\n", - " \tmodule.decoder.layers.5.mixer.dense.bias\n", - " \tmodule.decoder.layers.2.mlp.linear_fc2.weight\n", - " \tmodule.decoder.layers.1.mixer.dense.weight\n", - "[NeMo I 2025-03-03 23:37:17 nemo_logging:393] Doing selective restore from RestoreConfig(path='nemo2_evo2_1b_8k', adapter_path=None, load_model_state=True, load_optim_state=False, load_artifacts=False)\n", - "[NeMo I 2025-03-03 23:37:17 nemo_logging:393] Using dist-ckpt load strategy.\n", - "[WARNING | py.warnings ]: /workspace/bionemo2/3rdparty/Megatron-LM/megatron/core/dist_checkpointing/strategies/torch.py:847: FutureWarning: `load_state_dict` is deprecated and will be removed in future versions. 
Please use `load` instead.\n", + " \tmodule.decoder.layers.0.mixer.dense.weight\n", + "[NeMo I 2025-03-04 01:02:35 nemo_logging:393] Doing selective restore from RestoreConfig(path='nemo2_evo2_1b_8k', adapter_path=None, load_model_state=True, load_optim_state=False, load_artifacts=False)\n", + "[NeMo I 2025-03-04 01:02:35 nemo_logging:393] Using dist-ckpt load strategy.\n", + "[WARNING | py.warnings ]: /workspaces/bionemo-framework/3rdparty/Megatron-LM/megatron/core/dist_checkpointing/strategies/torch.py:847: FutureWarning: `load_state_dict` is deprecated and will be removed in future versions. Please use `load` instead.\n", " checkpoint.load_state_dict(\n", "\n", "[WARNING | py.warnings ]: /usr/local/lib/python3.12/dist-packages/torch/distributed/checkpoint/planner_helpers.py:316: FutureWarning: Please use DTensor instead and we are deprecating ShardedTensor.\n", " device = getattr(value, \"device\", None)\n", "\n", - "[NeMo I 2025-03-03 23:37:18 nemo_logging:393] Global Checkpoint Load : Rank : 0 : Start time : 1741045037.679s : Time spent in load_checkpoint: 1.103s\n", - "[NeMo I 2025-03-03 23:37:18 nemo_logging:393] Restoring model weights from RestoreConfig(path='nemo2_evo2_1b_8k', adapter_path=None, load_model_state=True, load_optim_state=False, load_artifacts=False)\n", - "[NeMo I 2025-03-03 23:37:18 nemo_logging:393] Finished restoring from RestoreConfig(path='nemo2_evo2_1b_8k', adapter_path=None, load_model_state=True, load_optim_state=False, load_artifacts=False), cleaning up.\n" + "[NeMo I 2025-03-04 01:02:36 nemo_logging:393] Global Checkpoint Load : Rank : 0 : Start time : 1741050155.807s : Time spent in load_checkpoint: 0.618s\n", + "[NeMo I 2025-03-04 01:02:36 nemo_logging:393] Restoring model weights from RestoreConfig(path='nemo2_evo2_1b_8k', adapter_path=None, load_model_state=True, load_optim_state=False, load_artifacts=False)\n", + "[NeMo I 2025-03-04 01:02:36 nemo_logging:393] Finished restoring from RestoreConfig(path='nemo2_evo2_1b_8k', 
adapter_path=None, load_model_state=True, load_optim_state=False, load_artifacts=False), cleaning up.\n" ] } ], "source": [ + "print(f\"Running command: {predict_var_command}\")\n", "!{predict_var_command}" ] }, @@ -1240,7 +1286,7 @@ }, { "cell_type": "code", - "execution_count": 25, + "execution_count": 12, "metadata": {}, "outputs": [], "source": [ @@ -1268,7 +1314,7 @@ }, { "cell_type": "code", - "execution_count": 26, + "execution_count": 13, "metadata": {}, "outputs": [ { @@ -1309,108 +1355,101 @@ " \n", " 0\n", " 17\n", - " 41199726\n", - " T\n", + " 41199729\n", " C\n", - " 0.159762\n", - " FUNC/INT\n", - " BRCA1_ref_pos_41199726_T_class_FUNC/INT\n", - " BRCA1_var_pos_41199726_TtoC_class_FUNC/INT\n", - " -1.048409\n", - " -1.048462\n", - " -0.000054\n", + " T\n", + " -2.646816\n", + " LOF\n", + " BRCA1_ref_pos_41199729_C_class_LOF\n", + " BRCA1_var_pos_41199729_CtoT_class_LOF\n", + " -0.952360\n", + " -0.953044\n", + " -0.000684\n", " \n", " \n", " 1\n", " 17\n", - " 41209074\n", + " 41215381\n", " T\n", - " A\n", - " -2.065569\n", + " G\n", + " -2.352741\n", " LOF\n", - " BRCA1_ref_pos_41209074_T_class_LOF\n", - " BRCA1_var_pos_41209074_TtoA_class_LOF\n", - " -0.826655\n", - " -0.826915\n", - " -0.000260\n", + " BRCA1_ref_pos_41215381_T_class_LOF\n", + " BRCA1_var_pos_41215381_TtoG_class_LOF\n", + " -0.848368\n", + " -0.848730\n", + " -0.000361\n", " \n", " \n", " 2\n", " 17\n", - " 41256913\n", - " A\n", + " 41215390\n", " C\n", - " -0.847753\n", - " FUNC/INT\n", - " BRCA1_ref_pos_41256913_A_class_FUNC/INT\n", - " BRCA1_var_pos_41256913_AtoC_class_FUNC/INT\n", - " -0.864035\n", - " -0.864014\n", - " 0.000021\n", + " A\n", + " -1.371155\n", + " LOF\n", + " BRCA1_ref_pos_41215390_C_class_LOF\n", + " BRCA1_var_pos_41215390_CtoA_class_LOF\n", + " -0.848341\n", + " -0.847456\n", + " 0.000885\n", " \n", " \n", " 3\n", " 17\n", - " 41219631\n", + " 41219688\n", " T\n", " A\n", - " -2.053739\n", + " -2.053136\n", " LOF\n", - " 
BRCA1_ref_pos_41219631_T_class_LOF\n", - " BRCA1_var_pos_41219631_TtoA_class_LOF\n", - " -1.091372\n", - " -1.091227\n", - " 0.000145\n", + " BRCA1_ref_pos_41219688_T_class_LOF\n", + " BRCA1_var_pos_41219688_TtoA_class_LOF\n", + " -1.027623\n", + " -1.028068\n", + " -0.000445\n", " \n", " \n", " 4\n", " 17\n", - " 41215965\n", + " 41219652\n", + " C\n", " G\n", - " A\n", - " -1.671525\n", + " -2.026390\n", " LOF\n", - " BRCA1_ref_pos_41215965_G_class_LOF\n", - " BRCA1_var_pos_41215965_GtoA_class_LOF\n", - " -0.930776\n", - " -0.930750\n", - " 0.000026\n", + " BRCA1_ref_pos_41219652_C_class_LOF\n", + " BRCA1_var_pos_41219652_CtoG_class_LOF\n", + " -1.032667\n", + " -1.032678\n", + " -0.000011\n", " \n", " \n", "\n", "" ], "text/plain": [ - " chrom pos ref alt score class \\\n", - "0 17 41199726 T C 0.159762 FUNC/INT \n", - "1 17 41209074 T A -2.065569 LOF \n", - "2 17 41256913 A C -0.847753 FUNC/INT \n", - "3 17 41219631 T A -2.053739 LOF \n", - "4 17 41215965 G A -1.671525 LOF \n", - "\n", - " ref_fasta_name \\\n", - "0 BRCA1_ref_pos_41199726_T_class_FUNC/INT \n", - "1 BRCA1_ref_pos_41209074_T_class_LOF \n", - "2 BRCA1_ref_pos_41256913_A_class_FUNC/INT \n", - "3 BRCA1_ref_pos_41219631_T_class_LOF \n", - "4 BRCA1_ref_pos_41215965_G_class_LOF \n", + " chrom pos ref alt score class \\\n", + "0 17 41199729 C T -2.646816 LOF \n", + "1 17 41215381 T G -2.352741 LOF \n", + "2 17 41215390 C A -1.371155 LOF \n", + "3 17 41219688 T A -2.053136 LOF \n", + "4 17 41219652 C G -2.026390 LOF \n", "\n", - " var_fasta_name ref_log_probs var_log_probs \\\n", - "0 BRCA1_var_pos_41199726_TtoC_class_FUNC/INT -1.048409 -1.048462 \n", - "1 BRCA1_var_pos_41209074_TtoA_class_LOF -0.826655 -0.826915 \n", - "2 BRCA1_var_pos_41256913_AtoC_class_FUNC/INT -0.864035 -0.864014 \n", - "3 BRCA1_var_pos_41219631_TtoA_class_LOF -1.091372 -1.091227 \n", - "4 BRCA1_var_pos_41215965_GtoA_class_LOF -0.930776 -0.930750 \n", + " ref_fasta_name var_fasta_name \\\n", + "0 BRCA1_ref_pos_41199729_C_class_LOF 
BRCA1_var_pos_41199729_CtoT_class_LOF \n", + "1 BRCA1_ref_pos_41215381_T_class_LOF BRCA1_var_pos_41215381_TtoG_class_LOF \n", + "2 BRCA1_ref_pos_41215390_C_class_LOF BRCA1_var_pos_41215390_CtoA_class_LOF \n", + "3 BRCA1_ref_pos_41219688_T_class_LOF BRCA1_var_pos_41219688_TtoA_class_LOF \n", + "4 BRCA1_ref_pos_41219652_C_class_LOF BRCA1_var_pos_41219652_CtoG_class_LOF \n", "\n", - " evo2_delta_score \n", - "0 -0.000054 \n", - "1 -0.000260 \n", - "2 0.000021 \n", - "3 0.000145 \n", - "4 0.000026 " + " ref_log_probs var_log_probs evo2_delta_score \n", + "0 -0.952360 -0.953044 -0.000684 \n", + "1 -0.848368 -0.848730 -0.000361 \n", + "2 -0.848341 -0.847456 0.000885 \n", + "3 -1.027623 -1.028068 -0.000445 \n", + "4 -1.032667 -1.032678 -0.000011 " ] }, - "execution_count": 26, + "execution_count": 13, "metadata": {}, "output_type": "execute_result" } @@ -1440,12 +1479,12 @@ }, { "cell_type": "code", - "execution_count": 27, + "execution_count": 14, "metadata": {}, "outputs": [ { "data": { - "image/png": "", + "image/png": "", "text/plain": [ "
" ] @@ -1493,20 +1532,23 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "We can also calculate the area under the receiver operating characteristic curve (AUROC) of this zero-shot prediction method.\n", + "We can also calculate the area under the receiver operating characteristic curve (AUROC) of this zero-shot prediction method. Note that the results are nearly random unless you use one of the following configurations:\n", + "* `--fp8` on an FP8-capable GPU with either the 1b or 7b model. The 40b model likely works as well.\n", + "* The 7b model without `--fp8`: it uniquely seems to work well in that mode, so on an older device the 7b model should produce\n", + " robust results. In that case, change `MODEL_SIZE` earlier in this tutorial and rerun for good results.\n", "\n" ] }, { "cell_type": "code", - "execution_count": 28, + "execution_count": 15, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "Zero-shot prediction AUROC: 0.4\n" + "Zero-shot prediction AUROC: 0.77\n" ] } ], @@ -1517,13 +1559,6 @@ "auroc = roc_auc_score(y_true, -brca1_df['evo2_delta_score'])\n", "print(f'Zero-shot prediction AUROC: {auroc:.2}')" ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] } ], "metadata": {