Alit/mamba #9575

Merged: 32 commits from alit/mamba into main on Jul 6, 2024.

Commits (32)
8a26848  adding mamba support (Jul 1, 2024)
73d7c4c  fix import mixins (Jul 1, 2024)
66886b5  rm convert jamba (Jul 1, 2024)
f9e2066  Apply isort and black reformatting (JRD971000, Jul 1, 2024)
96ab05c  more cleanups (Jul 2, 2024)
f24cd69  Merge branch 'alit/mamba' of https://github.com/NVIDIA/NeMo into alit… (Jul 2, 2024)
2e74b64  use GPT text gen (Jul 2, 2024)
05c377a  Apply isort and black reformatting (JRD971000, Jul 2, 2024)
59f176a  fixing gbs in TP convetor (Jul 2, 2024)
74a30de  resolve merge conflicts (Jul 2, 2024)
dfc24e2  Apply isort and black reformatting (JRD971000, Jul 2, 2024)
7edd5cc  add reqs (Jul 2, 2024)
3eee1c7  Merge branch 'alit/mamba' of https://github.com/NVIDIA/NeMo into alit… (Jul 2, 2024)
c0afdc4  add tutorial (Jul 3, 2024)
6097379  minor fix to tutorial (Jul 3, 2024)
8e7aea0  moving finetuning files (arendu, Jul 3, 2024)
1db8269  moving finetuning files (arendu, Jul 3, 2024)
0f326d6  address comments (Jul 4, 2024)
da7461a  Apply isort and black reformatting (JRD971000, Jul 4, 2024)
022622e  address comments (Jul 4, 2024)
7b67568  Apply isort and black reformatting (JRD971000, Jul 4, 2024)
7dce2bf  Merge branch 'main' into alit/mamba (JRD971000, Jul 5, 2024)
0d5cc37  address comments (Jul 5, 2024)
2cf9040  add mamba dependancies (Jul 5, 2024)
0353eb9  Merge branch 'main' into alit/mamba (JRD971000, Jul 5, 2024)
23a2d20  add mcore tag (Jul 5, 2024)
a9a24b7  merge main (Jul 5, 2024)
53792ab  Merge branch 'alit/mamba' of https://github.com/NVIDIA/NeMo into alit… (Jul 5, 2024)
2b97d0b  Merge branch 'main' into alit/mamba (JRD971000, Jul 5, 2024)
0747052  modify dockerfile ci (Jul 6, 2024)
14a8878  modify dockerfile ci (Jul 6, 2024)
0860395  Merge branch 'main' into alit/mamba (JRD971000, Jul 6, 2024)
Files changed
20 changes: 18 additions & 2 deletions Dockerfile.ci
@@ -32,9 +32,9 @@ EOF
WORKDIR /workspace

# Install NeMo requirements
ARG TE_TAG=bfe21c3d68b0a9951e5716fb520045db53419c5e
ARG TE_TAG=7d576ed25266a17a7b651f2c12e8498f67e0baea
ARG MODELOPT_VERSION=0.13.0
ARG MCORE_TAG=0ab8dd4c7520408683fdb9f8ac119eff7d38fc0e
ARG MCORE_TAG=0bc3547702464501feefeb5523b7a17e591b21fa
ARG APEX_TAG=810ffae374a2b9cb4b5c5e28eaeca7d7998fca0c
RUN \
--mount=type=bind,source=requirements,target=requirements \
@@ -61,6 +61,22 @@ git checkout ${MCORE_TAG} && \
popd && \
popd
export PYTHONPATH="${PYTHONPATH}:/workspace/Megatron-LM"

# Mamba dependancy installation
git clone https://github.com/state-spaces/mamba.git && \
cd mamba && \
git checkout v2.0.3 && \
python setup.py install && \
cd .. && \
rm -rf mamba

git clone https://github.com/Dao-AILab/causal-conv1d && \
cd causal-conv1d && \
git checkout v1.2.2.post1 && \
python setup.py install && \
cd .. && \
rm -rf causal-conv1d

EOF

# Copy over NeMo code
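With the pinned mamba-ssm and causal-conv1d wheels removed from requirements_nlp.txt later in this PR, the source builds above become the only place the Mamba kernels enter the CI image. A minimal import check along these lines (an illustration, not part of the PR; the layer dimensions are arbitrary) confirms the build succeeded:

import torch
from causal_conv1d import causal_conv1d_fn  # fused depthwise-conv kernel built from v1.2.2.post1
from mamba_ssm import Mamba                 # selective-state-space block built from v2.0.3

# Both packages require a CUDA device at runtime.
assert torch.cuda.is_available(), "Mamba kernels need a CUDA-capable GPU"

layer = Mamba(d_model=256, d_state=16, d_conv=4, expand=2).cuda().half()
x = torch.randn(1, 128, 256, device="cuda", dtype=torch.float16)
print(layer(x).shape)  # expected: torch.Size([1, 128, 256])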
@@ -48,119 +48,38 @@ exp_manager:


model:
restore_from_path: null
# model parallelism
mcore_gpt: True
micro_batch_size: 1
global_batch_size: 8
tensor_model_parallel_size: 1
pipeline_model_parallel_size: 1
virtual_pipeline_model_parallel_size: null
expert_model_parallel_size: 1 # expert model parallelism

vocab_size: 65536
# model architecture
encoder_seq_length: 4096
hybrid_override_pattern: null
max_position_embeddings: ${.encoder_seq_length}
position_embedding_type: 'none' # Position embedding type. Options ['learned_absolute', 'rope', 'alibi', 'kerple' , 'xpos', 'sandwich'] xpos and sandwich are experimental.
num_layers: 64
gated_linear_unit: False
add_bias_linear: False
num_query_groups: 8
ngroups_mamba: 8
attention_dropout: 0.0
hidden_dropout: 0.0
hidden_size: 4096
ffn_hidden_size: 14336 # Transformer FFN hidden size. Usually 4 * hidden_size.
num_attention_heads: 32
transformer_block_type: pre_ln
init_method_std: 0.02 # Standard deviation of the zero mean normal distribution used for weight initialization.')
kv_channels: null # Projection weights dimension in multi-head attention. Set to hidden_size // num_attention_heads if null
apply_query_key_layer_scaling: True # scale Q * K^T by 1 / layer-number.
normalization: RMSNorm
layernorm_epsilon: 1e-5
num_moe_experts: 16
moe_router_topk: 2
moe_aux_loss_coeff: 0.001
make_vocab_size_divisible_by: 128 # Pad the vocab size to be divisible by this value for computation efficiency.
pre_process: True # add embedding
post_process: True # add pooler
megatron_legacy: False
persist_layer_norm: True


# mixed-precision
attention_softmax_in_fp32: False

# Distributed checkpoint setup
dist_ckpt_format: 'zarr' # Set to 'torch_dist' to use PyTorch distributed checkpoint format.
dist_ckpt_load_on_device: True # whether to load checkpoint weights directly on GPU or to CPU
dist_ckpt_parallel_save: False # if true, each worker will write its own part of the dist checkpoint


tokenizer:
library: 'huggingface'
type: 'EleutherAI/gpt-neox-20b'
model: null
vocab_file: null
merge_file: null
sentencepiece_legacy: False
use_fast: True

# precision
native_amp_init_scale: 4294967296 # 2 ** 32
native_amp_growth_interval: 1000
fp32_residual_connection: False # Move residual connections to fp32
fp16_lm_cross_entropy: False # Move the cross entropy unreduced loss calculation for lm head to fp16

# Megatron O2-style half-precision
megatron_amp_O2: False # Enable O2-level automatic mixed precision using main parameters
grad_allreduce_chunk_size_mb: 125

# Fusion
grad_div_ar_fusion: True # Fuse grad division into torch.distributed.all_reduce. Only used with O2 and no pipeline parallelism..
gradient_accumulation_fusion: True # Fuse weight gradient accumulation to GEMMs. Only used with pipeline parallelism and O2.
bias_activation_fusion: False # Use a kernel that fuses the bias addition from weight matrices with the subsequent activation function.
bias_dropout_add_fusion: True # Use a kernel that fuses the bias addition, dropout and residual connection addition.
masked_softmax_fusion: True # Use a kernel that fuses the attention softmax with it's mask.
get_attention_mask_from_fusion: True # When using fused softmax it will create the attention mask so we won't copy it to the pipeline stages.
apply_rope_fusion: True # Use a kernel to add rotary positional embeddings. Only used if position_embedding_type=rope

# miscellaneous
seed: 1234
use_cpu_initialization: False # Init weights on the CPU (slow for large models)
onnx_safe: False # Use work-arounds for known problems with Torch ONNX exporter.
gradient_as_bucket_view: True # PyTorch DDP argument. Allocate gradients in a contiguous bucket to save memory (less fragmentation and buffer memory)

## Activation Checkpointing
# NeMo Megatron supports 'selective' activation checkpointing where only the memory intensive part of attention is checkpointed.
# These memory intensive activations are also less compute intensive which makes activation checkpointing more efficient for LLMs (20B+).
tensor_model_parallel_size: 1 # intra-layer model parallelism
pipeline_model_parallel_size: 1 # inter-layer model parallelism

encoder_seq_length: 1024
global_batch_size: 8
micro_batch_size: 1
restore_from_path: ??? # Path to an existing .nemo model you wish to add new tasks to or run inference with
resume_from_checkpoint: null # The path to a checkpoint file to continue the training, restores the whole state including the epoch, step, LR schedulers, apex, etc.
save_nemo_on_validation_end: False # Saves an inference ready .nemo file every time a checkpoint is saved during training.
sync_batch_comm: False
megatron_amp_O2: False

## Sequence Parallelism
# Makes tensor parallelism more memory efficient for LLMs (20B+) by parallelizing layer norms and dropout sequentially
# See Reducing Activation Recomputation in Large Transformer Models: https://arxiv.org/abs/2205.05198 for more details.
# 'full' will checkpoint the entire transformer layer.
activations_checkpoint_granularity: null # 'selective' or 'full'
activations_checkpoint_method: null # 'uniform', 'block'
sequence_parallel: False

## Activation Checkpoint
activations_checkpoint_granularity: null # 'selective' or 'full'
activations_checkpoint_method: null # 'uniform', 'block', not used with 'selective'
# 'uniform' divides the total number of transformer layers and checkpoints the input activation
# of each chunk at the specified granularity. When used with 'selective', 'uniform' checkpoints all attention blocks in the model.
# of each chunk at the specified granularity
# 'block' checkpoints the specified number of layers per pipeline stage at the specified granularity
activations_checkpoint_num_layers: null
# when using 'uniform' this creates groups of transformer layers to checkpoint. Usually set to 1. Increase to save more memory.
# when using 'block' this this will checkpoint the first activations_checkpoint_num_layers per pipeline stage.
num_micro_batches_with_partial_activation_checkpoints: null
# This feature is valid only when used with pipeline-model-parallelism.
# When an integer value is provided, it sets the number of micro-batches where only a partial number of Transformer layers get checkpointed
# and recomputed within a window of micro-batches. The rest of micro-batches in the window checkpoint all Transformer layers. The size of window is
# set by the maximum outstanding micro-batch backpropagations, which varies at different pipeline stages. The number of partial layers to checkpoint
# per micro-batch is set by 'activations_checkpoint_num_layers' with 'activations_checkpoint_method' of 'block'.
# This feature enables using activation checkpoint at a fraction of micro-batches up to the point of full GPU memory usage.
activations_checkpoint_num_layers: null # not used with 'selective'
activations_checkpoint_layers_per_pipeline: null
# This feature is valid only when used with pipeline-model-parallelism.
# When an integer value (rounded down when float is given) is provided, it sets the number of Transformer layers to skip checkpointing at later
# pipeline stages. For example, 'activations_checkpoint_layers_per_pipeline' of 3 makes pipeline stage 1 to checkpoint 3 layers less than
# stage 0 and stage 2 to checkpoint 6 layers less stage 0, and so on. This is possible because later pipeline stage
# uses less GPU memory with fewer outstanding micro-batch backpropagations. Used with 'num_micro_batches_with_partial_activation_checkpoints',
# this feature removes most of activation checkpoints at the last pipeline stage, which is the critical execution path.
sequence_parallel: False
answer_only_loss: True
gradient_as_bucket_view: False

hidden_dropout: 0.0
attention_dropout: 0.0
ffn_dropout: 0.0

peft:
peft_scheme: "lora" # can be either adapter,ia3, lora, or ptuning
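The hunk above replaces the pretraining-style model block (full architecture spec, tokenizer, precision, and fusion flags) with the leaner set of keys a fine-tuning run actually needs; restore_from_path becomes required ('???') and PEFT defaults to LoRA. As a usage sketch only, with the script path and all file locations being assumptions rather than anything this diff shows, the remaining keys would typically be supplied as Hydra-style overrides at launch:

import subprocess

# Hypothetical launcher; the tuning-script path and checkpoint location are
# placeholders, not taken from this PR.
overrides = [
    "model.restore_from_path=/checkpoints/mamba-base.nemo",  # required ('???' in the config)
    "model.peft.peft_scheme=lora",                           # the default shown in the diff
    "model.tensor_model_parallel_size=1",
    "model.global_batch_size=8",
    "model.micro_batch_size=1",
]
subprocess.run(
    ["python", "examples/nlp/language_modeling/tuning/megatron_mamba_finetuning.py", *overrides],
    check=True,
)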
@@ -39,113 +39,39 @@ exp_manager:
model_parallel_size: ${multiply:${model.tensor_model_parallel_size}, ${model.pipeline_model_parallel_size}}

model:
restore_from_path: null
# model parallelism
mcore_gpt: True
micro_batch_size: 2
global_batch_size: 2
tensor_model_parallel_size: 1
pipeline_model_parallel_size: 1
virtual_pipeline_model_parallel_size: null
expert_model_parallel_size: 1 # expert model parallelism
hybrid_override_pattern: null
vocab_size: 65536
# model architecture
encoder_seq_length: 4096
max_position_embeddings: ${.encoder_seq_length}
position_embedding_type: 'none' # Position embedding type. Options ['learned_absolute', 'rope', 'alibi', 'kerple' , 'xpos', 'sandwich'] xpos and sandwich are experimental.
num_layers: 64
gated_linear_unit: False
num_query_groups: 8
ngroups_mamba: 8
attention_dropout: 0.0
hidden_dropout: 0.0
hidden_size: 4096
ffn_hidden_size: 14336 # Transformer FFN hidden size. Usually 4 * hidden_size.
num_attention_heads: 32
transformer_block_type: pre_ln
init_method_std: 0.02 # Standard deviation of the zero mean normal distribution used for weight initialization.')
kv_channels: null # Projection weights dimension in multi-head attention. Set to hidden_size // num_attention_heads if null
apply_query_key_layer_scaling: True # scale Q * K^T by 1 / layer-number.
normalization: RMSNorm
layernorm_epsilon: 1e-5
num_moe_experts: 16
moe_router_topk: 2
moe_aux_loss_coeff: 0.001
make_vocab_size_divisible_by: 128 # Pad the vocab size to be divisible by this value for computation efficiency.
pre_process: True # add embedding
post_process: True # add pooler
megatron_legacy: False
persist_layer_norm: True
add_bias_linear: False

answer_only_loss: True

tokenizer:
library: 'huggingface'
type: 'EleutherAI/gpt-neox-20b'
model: null
vocab_file: null
merge_file: null
sentencepiece_legacy: False
use_fast: True


# precision
native_amp_init_scale: 4294967296 # 2 ** 32
native_amp_growth_interval: 1000
fp32_residual_connection: False # Move residual connections to fp32
fp16_lm_cross_entropy: False # Move the cross entropy unreduced loss calculation for lm head to fp16

# Megatron O2-style half-precision
megatron_amp_O2: False # Enable O2-level automatic mixed precision using main parameters
grad_allreduce_chunk_size_mb: 125

# Fusion
grad_div_ar_fusion: True # Fuse grad division into torch.distributed.all_reduce. Only used with O2 and no pipeline parallelism..
gradient_accumulation_fusion: True # Fuse weight gradient accumulation to GEMMs. Only used with pipeline parallelism and O2.
bias_activation_fusion: False # Use a kernel that fuses the bias addition from weight matrices with the subsequent activation function.
bias_dropout_add_fusion: True # Use a kernel that fuses the bias addition, dropout and residual connection addition.
masked_softmax_fusion: True # Use a kernel that fuses the attention softmax with it's mask.
get_attention_mask_from_fusion: True # When using fused softmax it will create the attention mask so we won't copy it to the pipeline stages.
apply_rope_fusion: True # Use a kernel to add rotary positional embeddings. Only used if position_embedding_type=rope


# miscellaneous
seed: 1234
use_cpu_initialization: False # Init weights on the CPU (slow for large models)
onnx_safe: False # Use work-arounds for known problems with Torch ONNX exporter.
gradient_as_bucket_view: True # PyTorch DDP argument. Allocate gradients in a contiguous bucket to save memory (less fragmentation and buffer memory)

## Activation Checkpointing
# NeMo Megatron supports 'selective' activation checkpointing where only the memory intensive part of attention is checkpointed.
# These memory intensive activations are also less compute intensive which makes activation checkpointing more efficient for LLMs (20B+).
tensor_model_parallel_size: 1 # intra-layer model parallelism
pipeline_model_parallel_size: 1 # inter-layer model parallelism

encoder_seq_length: 1024
global_batch_size: 8
micro_batch_size: 1
restore_from_path: ??? # Path to an existing .nemo model you wish to add new tasks to or run inference with
resume_from_checkpoint: null # The path to a checkpoint file to continue the training, restores the whole state including the epoch, step, LR schedulers, apex, etc.
save_nemo_on_validation_end: False # Saves an inference ready .nemo file every time a checkpoint is saved during training.
sync_batch_comm: False
megatron_amp_O2: False

## Sequence Parallelism
# Makes tensor parallelism more memory efficient for LLMs (20B+) by parallelizing layer norms and dropout sequentially
# See Reducing Activation Recomputation in Large Transformer Models: https://arxiv.org/abs/2205.05198 for more details.
# 'full' will checkpoint the entire transformer layer.
activations_checkpoint_granularity: null # 'selective' or 'full'
activations_checkpoint_recurrent: False # If set to True, the checkpointing is only done for rglru and conv1d and not for attention and mlp layers
activations_checkpoint_method: null # 'uniform', 'block'
sequence_parallel: False

## Activation Checkpoint
activations_checkpoint_granularity: null # 'selective' or 'full'
activations_checkpoint_method: null # 'uniform', 'block', not used with 'selective'
# 'uniform' divides the total number of transformer layers and checkpoints the input activation
# of each chunk at the specified granularity. When used with 'selective', 'uniform' checkpoints all attention blocks in the model.
# of each chunk at the specified granularity
# 'block' checkpoints the specified number of layers per pipeline stage at the specified granularity
activations_checkpoint_num_layers: null
# when using 'uniform' this creates groups of transformer layers to checkpoint. Usually set to 1. Increase to save more memory.
# when using 'block' this this will checkpoint the first activations_checkpoint_num_layers per pipeline stage.
num_micro_batches_with_partial_activation_checkpoints: null
# This feature is valid only when used with pipeline-model-parallelism.
# When an integer value is provided, it sets the number of micro-batches where only a partial number of Transformer layers get checkpointed
# and recomputed within a window of micro-batches. The rest of micro-batches in the window checkpoint all Transformer layers. The size of window is
# set by the maximum outstanding micro-batch backpropagations, which varies at different pipeline stages. The number of partial layers to checkpoint
# per micro-batch is set by 'activations_checkpoint_num_layers' with 'activations_checkpoint_method' of 'block'.
# This feature enables using activation checkpoint at a fraction of micro-batches up to the point of full GPU memory usage.
activations_checkpoint_num_layers: null # not used with 'selective'
activations_checkpoint_layers_per_pipeline: null
# This feature is valid only when used with pipeline-model-parallelism.
# When an integer value (rounded down when float is given) is provided, it sets the number of Transformer layers to skip checkpointing at later
# pipeline stages. For example, 'activations_checkpoint_layers_per_pipeline' of 3 makes pipeline stage 1 to checkpoint 3 layers less than
# stage 0 and stage 2 to checkpoint 6 layers less stage 0, and so on. This is possible because later pipeline stage
# uses less GPU memory with fewer outstanding micro-batch backpropagations. Used with 'num_micro_batches_with_partial_activation_checkpoints',
# this feature removes most of activation checkpoints at the last pipeline stage, which is the critical execution path.
sequence_parallel: False
answer_only_loss: True
gradient_as_bucket_view: False

hidden_dropout: 0.0
attention_dropout: 0.0
ffn_dropout: 0.0


peft:
peft_scheme: null # can be either adapter,ia3, lora, or ptuning
@@ -13,9 +13,8 @@
# limitations under the License.

import torch

# from megatron.core.models.mamba import MambaModel
# from megatron.core.models.mamba.mamba_layer_specs import mamba_stack_spec
from megatron.core.models.mamba import MambaModel
from megatron.core.models.mamba.mamba_layer_specs import mamba_stack_spec
from omegaconf.dictconfig import DictConfig
from pytorch_lightning.trainer.trainer import Trainer

@@ -46,16 +45,15 @@ def model_provider_func(self, pre_process, post_process):
self.transformer_config.layernorm_epsilon = self.cfg.get('layernorm_epsilon', 1e-5)

# TODO @ataghibakhsh: add mamba_ssm_ngroups=self.cfg.get('mamba_ssm_ngroups', 8) once MLM MR merged
# TODO @ataghibakhsh: add the following
'''MambaModel(

model = MambaModel(
config=self.transformer_config,
max_sequence_length=self.cfg.get('encoder_seq_length', 4096),
vocab_size=self.cfg.get('vocab_size', 65536),
mamba_ssm_ngroups=self.cfg.get('mamba_ssm_ngroups', 8),
mamba_stack_spec=mamba_stack_spec,
hybrid_override_pattern=self.hybrid_override_pattern,
)'''
# after package mismatch is resovled
model = None
)

return model

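The substantive change in this hunk is that model_provider_func now builds a real MambaModel from Megatron-Core instead of returning None behind a commented-out block. One constructor argument worth spelling out is hybrid_override_pattern: in Megatron-Core's hybrid layer allocation each character selects a layer type (to my reading, 'M' for a Mamba block, '*' for self-attention, '-' for MLP; treat the symbol table as an assumption, not something this diff defines). A small sketch:

# Illustrative only; the pattern string below is hypothetical.
pattern = "M-M-M*-M-M-M*-"  # a 14-layer hybrid stack

def count_layer_types(pattern: str) -> dict:
    """Tally how many layers of each kind a hybrid override pattern requests."""
    return {
        "mamba": pattern.count("M"),
        "attention": pattern.count("*"),
        "mlp": pattern.count("-"),
    }

print(count_layer_types(pattern))  # {'mamba': 6, 'attention': 2, 'mlp': 6}

With hybrid_override_pattern left at null (the config default), layer allocation is left to Megatron-Core's defaults.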
17 changes: 9 additions & 8 deletions nemo/collections/nlp/parts/mixins/nlp_adapter_mixins.py
@@ -127,14 +127,15 @@ def _check_and_add_adapter(self, name, module, peft_name, peft_cfg, name_key_to_
f'model.{mcore_target}',
f'model.module.{mcore_target}',
]: # simple string match for now
swap_mcore_mixin(module, mcore_mixin)
if model_utils.import_class_by_path(peft_cfg._target_) in module.get_accepted_adapter_types():
module.add_adapter(
name=peft_name,
cfg=peft_cfg,
base_model_cfg=self.cfg,
model_parallel_config=self.model_parallel_config,
)
if not isinstance(module, IdentityOp):
swap_mcore_mixin(module, mcore_mixin)
if model_utils.import_class_by_path(peft_cfg._target_) in module.get_accepted_adapter_types():
module.add_adapter(
name=peft_name,
cfg=peft_cfg,
base_model_cfg=self.cfg,
model_parallel_config=self.model_parallel_config,
)
elif isinstance(module, AdapterModuleMixin):
if model_utils.import_class_by_path(peft_cfg._target_) in module.get_accepted_adapter_types():
module.add_adapter(
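The new isinstance check is what makes PEFT safe on hybrid Mamba models: name matching can land on submodules that exist only as IdentityOp placeholders (for example attention-style slots inside Mamba layers), and swapping in an mcore mixin or attaching an adapter there is not meaningful. A minimal sketch of the guard (the helper name and the IdentityOp import path are my assumptions, not taken verbatim from the surrounding code):

from megatron.core.transformer.identity_op import IdentityOp  # assumed import path

def try_add_adapter(module, peft_name, peft_cfg, add_adapter_fn):
    """Attach a PEFT adapter only to modules that are real, adaptable layers."""
    if isinstance(module, IdentityOp):
        return False  # placeholder slot (common in Mamba layers): nothing to adapt
    add_adapter_fn(name=peft_name, cfg=peft_cfg)
    return True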
2 changes: 0 additions & 2 deletions requirements/requirements_nlp.txt
@@ -1,6 +1,5 @@
accelerated-scan
boto3
causal-conv1d==1.2.0.post2
einops
faiss-cpu
fasttext
@@ -10,7 +9,6 @@ gdown
h5py
ijson
jieba
mamba-ssm==1.2.0.post1
markdown2
matplotlib>=3.3.2
#megatron_core>0.6.0 # add back once mcore on pypi is compatible again