From 3778ac1eb0b1b4cc38fc83864879ee69c3087c07 Mon Sep 17 00:00:00 2001
From: haoliu
Date: Thu, 25 Jul 2024 01:02:59 +0000
Subject: [PATCH] use scan to control remat

---
 README.md                        |  2 +-
 lwm/llama.py                     | 33 +++++++-------------------
 lwm/train.py                     |  3 ---
 lwm/vision_chat.py               |  3 ---
 lwm/vision_generation.py         |  3 ---
 scripts/eval_needle.py           |  3 ---
 scripts/eval_needle_multi.py     |  3 ---
 scripts/run_train_text.sh        |  2 +-
 scripts/run_train_vision_text.sh |  2 +-
 scripts/run_vision_chat.sh       |  2 +-
 10 files changed, 11 insertions(+), 45 deletions(-)

diff --git a/README.md b/README.md
index 5de1c87..9c8c127 100644
--- a/README.md
+++ b/README.md
@@ -89,7 +89,7 @@ There are language-only and video-language versions, offering context sizes from
 
 ## Code structure
 
-Use `scan_query_chunk_size` and `scan_key_chunk_size` to control the block size in blockwise compute of the self-attention. Use `scan_mlp_chunk_size` to control the block size in blockwise compute of the feedforward network. Use `scan_attention=True` and `scan_mlp=True` to enable/disable blockwise compute in the self-attention and feed-forward network. Use `remat_attention` and `remat_mlp` to control the rematerialization policy with `nothing_saveable` recommended.
+Use `scan_query_chunk_size` and `scan_key_chunk_size` to control the block size in blockwise compute of the self-attention. Use `scan_mlp_chunk_size` to control the block size in blockwise compute of the feedforward network. Use `scan_attention=True` and `scan_mlp=True` to enable/disable blockwise compute in the self-attention and feed-forward network. You can use `mesh_dim=dp, fsdp, tp, sp` to control the degree of parallelism and RingAttention. It is a string of 4 integers separated by commas, representing the number of data parallelism, fully sharded data parallelism, tensor parallelism, and sequence parallelism. For example, `mesh_dim='1,64,4,1'` means 1 data parallelism, 64 fully sharded data parallelism, 4 tensor parallelism, and 1 sequence parallelism. `mesh_dim='1,1,4,64'` means 1 data parallelism, 1 fully sharded data parallelism, 4 tensor parallelism, and 64 sequence parallelism for RingAttention.
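
The README paragraph above only names the configuration knobs; the short sketch below (not part of the patch) shows how they fit together, reusing the `load_config`/`update` pattern that appears in `lwm/train.py` later in this patch. The import path and the `'debug'` preset are assumptions borrowed from `scripts/run_train_text.sh`; treat this as an illustrative sketch, not the project's documented API.

```python
# Illustrative sketch only; mirrors the config pattern used in lwm/train.py below.
from lwm.llama import LLaMAConfig  # import path assumed from the repo layout

llama_config = LLaMAConfig.load_config('debug')  # preset name taken from run_train_text.sh
llama_config.update(dict(
    scan_attention=True,        # enable blockwise self-attention
    scan_query_chunk_size=256,  # query block size
    scan_key_chunk_size=512,    # key/value block size
    scan_mlp=True,              # enable blockwise feed-forward
    scan_mlp_chunk_size=1024,   # feed-forward block size
    scan_layers=True,
))

# mesh_dim is the string 'dp,fsdp,tp,sp'; e.g. '1,64,4,1' means 1-way data
# parallelism, 64-way FSDP, 4-way tensor parallelism, and 1-way sequence
# parallelism (the sequence axis is what RingAttention shards over).
mesh_dim = '1,64,4,1'
```
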
diff --git a/lwm/llama.py b/lwm/llama.py
index 6173781..3d3efd0 100644
--- a/lwm/llama.py
+++ b/lwm/llama.py
@@ -26,7 +26,7 @@
 from transformers.utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging
 from ml_collections import ConfigDict
-from tux import function_args_to_config, load_pickle, open_file, with_sharding_constraint, get_jax_mesh, get_gradient_checkpoint_policy
+from tux import function_args_to_config, load_pickle, open_file, with_sharding_constraint, get_jax_mesh
 from ringattention import ringattention, blockwise_feedforward, ringattention_jax
@@ -150,11 +150,8 @@ def __init__(
         embd_pdrop=0.0,
         attn_pdrop=0.0,
         tie_word_embeddings=False,
-        remat_block='',
-        remat_attention='',
-        remat_mlp='',
-        scan_attention=False,
-        scan_mlp=False,
+        scan_attention=True,
+        scan_mlp=True,
         scan_query_chunk_size=1024,
         scan_key_chunk_size=1024,
         scan_mlp_chunk_size=1024,
@@ -176,9 +173,6 @@ def __init__(
         self.resid_pdrop = resid_pdrop
         self.embd_pdrop = embd_pdrop
         self.attn_pdrop = attn_pdrop
-        self.remat_block = remat_block
-        self.remat_attention = remat_attention
-        self.remat_mlp = remat_mlp
         self.scan_attention = scan_attention
         self.scan_mlp = scan_mlp
         self.scan_query_chunk_size = scan_query_chunk_size
@@ -556,7 +550,7 @@ def __call__(
                         query_chunk_size=self.config.scan_query_chunk_size,
                         key_chunk_size=self.config.scan_key_chunk_size,
                         dtype=self.dtype,
-                        policy=get_gradient_checkpoint_policy('nothing_saveable'),
+                        policy=jax.checkpoint_policies.nothing_saveable,
                         precision=self.precision,
                         prevent_cse=not self.config.scan_layers,
                     )
@@ -673,19 +667,12 @@ class FlaxLLaMABlock(nn.Module):
     def setup(self) -> None:
         attention_module = FlaxLLaMAAttention
         mlp_module = FlaxLLaMAMLP
-        if self.config.remat_attention != '':
-            attention_module = remat(
-                FlaxLLaMAAttention, static_argnums=(4, 5, 6),
-                policy=get_gradient_checkpoint_policy(self.config.remat_attention),
-                prevent_cse=not self.config.scan_layers,
-            )
-        if self.config.remat_mlp != '':
+        if self.config.scan_mlp:
             mlp_module = remat(
                 FlaxLLaMAMLP, static_argnums=(1,),
-                policy=get_gradient_checkpoint_policy(self.config.remat_mlp),
+                policy=jax.checkpoint_policies.nothing_saveable,
                 prevent_cse=not self.config.scan_layers,
             )
-
         self.attention = attention_module(
             self.config,
             dtype=self.dtype,
@@ -930,12 +917,6 @@ def __call__(
         all_hidden_states = () if output_hidden_states else None
 
         block = FlaxLLaMABlock
-        if self.config.remat_block != '':
-            block = remat(
-                FlaxLLaMABlock, static_argnums=(4, 5, 6),
-                prevent_cse=not self.config.scan_layers,
-                policy=get_gradient_checkpoint_policy(self.config.remat_block)
-            )
         if self.config.scan_layers:
             initializing = self.is_mutable_collection('params')
             params_spec = (
diff --git a/lwm/train.py b/lwm/train.py
index e8a02a2..05aaf7f 100644
--- a/lwm/train.py
+++ b/lwm/train.py
@@ -106,9 +106,6 @@ def main(argv):
         llama_config = config_cls.load_config(FLAGS.load_llama_config)
         updates = config_cls(**FLAGS.llama)
         llama_config.update(dict(
-            remat_block=updates.remat_block,
-            remat_attention=updates.remat_attention,
-            remat_mlp=updates.remat_mlp,
             scan_attention=updates.scan_attention,
             scan_mlp=updates.scan_mlp,
             scan_query_chunk_size=updates.scan_query_chunk_size,
diff --git a/lwm/vision_chat.py b/lwm/vision_chat.py
index eec6d4a..0abccaf 100644
--- a/lwm/vision_chat.py
+++ b/lwm/vision_chat.py
@@ -150,9 +150,6 @@ def _load_model(self):
             llama_config = VideoLLaMAConfig.load_config(FLAGS.load_llama_config)
             updates = VideoLLaMAConfig(**FLAGS.llama)
             llama_config.update(dict(
-                remat_block=updates.remat_block,
-                remat_attention=updates.remat_attention,
-                remat_mlp=updates.remat_mlp,
                 scan_attention=updates.scan_attention,
                 scan_mlp=updates.scan_mlp,
                 scan_query_chunk_size=updates.scan_query_chunk_size,
diff --git a/lwm/vision_generation.py b/lwm/vision_generation.py
index 3d247b0..1a0011f 100644
--- a/lwm/vision_generation.py
+++ b/lwm/vision_generation.py
@@ -62,9 +62,6 @@ def main(argv):
         llama_config = VideoLLaMAConfig.load_config(FLAGS.load_llama_config)
         updates = VideoLLaMAConfig(**FLAGS.llama)
         llama_config.update(dict(
-            remat_block=updates.remat_block,
-            remat_attention=updates.remat_attention,
-            remat_mlp=updates.remat_mlp,
             scan_attention=updates.scan_attention,
             scan_mlp=updates.scan_mlp,
             scan_query_chunk_size=updates.scan_query_chunk_size,
diff --git a/scripts/eval_needle.py b/scripts/eval_needle.py
index e94c02e..2da7fd8 100644
--- a/scripts/eval_needle.py
+++ b/scripts/eval_needle.py
@@ -329,9 +329,6 @@ def _load_model(self):
             llama_config = LLaMAConfig.load_config(FLAGS.load_llama_config)
             updates = LLaMAConfig(**FLAGS.llama)
             llama_config.update(dict(
-                remat_block=updates.remat_block,
-                remat_attention=updates.remat_attention,
-                remat_mlp=updates.remat_mlp,
                 scan_attention=updates.scan_attention,
                 scan_mlp=updates.scan_mlp,
                 scan_query_chunk_size=updates.scan_query_chunk_size,
diff --git a/scripts/eval_needle_multi.py b/scripts/eval_needle_multi.py
index 9786d20..c2f2cb6 100644
--- a/scripts/eval_needle_multi.py
+++ b/scripts/eval_needle_multi.py
@@ -338,9 +338,6 @@ def _load_model(self):
             llama_config = LLaMAConfig.load_config(FLAGS.load_llama_config)
             updates = LLaMAConfig(**FLAGS.llama)
             llama_config.update(dict(
-                remat_block=updates.remat_block,
-                remat_attention=updates.remat_attention,
-                remat_mlp=updates.remat_mlp,
                 scan_attention=updates.scan_attention,
                 scan_mlp=updates.scan_mlp,
                 scan_query_chunk_size=updates.scan_query_chunk_size,
diff --git a/scripts/run_train_text.sh b/scripts/run_train_text.sh
index cf0040a..dcd832f 100755
--- a/scripts/run_train_text.sh
+++ b/scripts/run_train_text.sh
@@ -24,7 +24,7 @@ python3 -u -m lwm.train \
     --save_model_freq=0 \
     --save_milestone_freq=10 \
     --load_llama_config='debug' \
-    --update_llama_config="dict(theta=10000,max_sequence_length=2048,scan_attention=True,scan_query_chunk_size=256,scan_key_chunk_size=512,remat_attention='nothing_saveable',scan_mlp=True,scan_mlp_chunk_size=1024,remat_mlp='nothing_saveable',remat_block='nothing_saveable',scan_layers=True)" \
+    --update_llama_config="dict(theta=10000,max_sequence_length=2048,scan_attention=True,scan_query_chunk_size=256,scan_key_chunk_size=512,scan_mlp=True,scan_mlp_chunk_size=1024,scan_layers=True)" \
     --tokenizer="$llama_tokenizer_path" \
     --optimizer.type='adamw' \
     --optimizer.accumulate_gradient_steps=1 \
diff --git a/scripts/run_train_vision_text.sh b/scripts/run_train_vision_text.sh
index 30e7822..36fd906 100755
--- a/scripts/run_train_vision_text.sh
+++ b/scripts/run_train_vision_text.sh
@@ -24,7 +24,7 @@ python3 -u -m lwm.train \
     --save_model_freq=0 \
     --save_milestone_freq=10 \
     --load_llama_config='debug' \
-    --update_llama_config="dict(theta=50000000,max_sequence_length=2048,scan_attention=True,scan_query_chunk_size=512,scan_key_chunk_size=1024,remat_attention='nothing_saveable',scan_mlp=True,scan_mlp_chunk_size=8192,remat_mlp='nothing_saveable',remat_block='nothing_saveable',scan_layers=True)" \
+    --update_llama_config="dict(theta=50000000,max_sequence_length=2048,scan_attention=True,scan_query_chunk_size=512,scan_key_chunk_size=1024,scan_mlp=True,scan_mlp_chunk_size=8192,scan_layers=True)" \
     --tokenizer="$llama_tokenizer_path" \
     --optimizer.type='adamw' \
     --optimizer.accumulate_gradient_steps=1 \
diff --git a/scripts/run_vision_chat.sh b/scripts/run_vision_chat.sh
index 286d0c9..be06b8b 100755
--- a/scripts/run_vision_chat.sh
+++ b/scripts/run_vision_chat.sh
@@ -22,7 +22,7 @@ python3 -u -m lwm.vision_chat \
     --dtype='fp32' \
     --load_llama_config='7b' \
     --max_n_frames=8 \
-    --update_llama_config="dict(sample_mode='text',theta=50000000,max_sequence_length=131072,scan_attention=False,scan_query_chunk_size=128,scan_key_chunk_size=128,remat_attention='',scan_mlp=False,scan_mlp_chunk_size=2048,remat_mlp='',remat_block='',scan_layers=True)" \
+    --update_llama_config="dict(sample_mode='text',theta=50000000,max_sequence_length=131072,scan_attention=False,scan_query_chunk_size=128,scan_key_chunk_size=128,scan_mlp=False,scan_mlp_chunk_size=2048,scan_layers=True)" \
     --load_checkpoint="params::$lwm_checkpoint" \
     --tokenizer="$llama_tokenizer_path" \
     2>&1 | tee ~/output.log
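
Taken together, the `lwm/llama.py` hunks drop the string-valued `remat_block`/`remat_attention`/`remat_mlp` options: rematerialization now simply follows the corresponding `scan_*` flag, always with JAX's `nothing_saveable` checkpoint policy (attention receives the policy directly through the blockwise `ringattention` call). The sketch below reproduces that pattern in isolation; it uses a stand-in `MLP` module rather than the patch's `FlaxLLaMAMLP`, and assumes current `flax.linen` and `jax` APIs.

```python
# Stand-alone sketch of the remat pattern in the patched FlaxLLaMABlock.setup;
# MLP here is a placeholder module, not the patch's FlaxLLaMAMLP.
import jax
import jax.numpy as jnp
import flax.linen as nn


class MLP(nn.Module):
    features: int

    @nn.compact
    def __call__(self, x, deterministic: bool = True):
        h = jax.nn.silu(nn.Dense(4 * self.features)(x))
        return nn.Dense(self.features)(h)


def make_mlp_module(scan_mlp: bool, scan_layers: bool):
    # Remat only when the blockwise (scan) path is on, and always with the
    # nothing_saveable policy, matching the behaviour the patch hard-codes.
    mlp_module = MLP
    if scan_mlp:
        mlp_module = nn.remat(
            MLP, static_argnums=(1,),  # argument 1 (`deterministic`) is static
            policy=jax.checkpoint_policies.nothing_saveable,
            prevent_cse=not scan_layers,
        )
    return mlp_module


# The wrapped class is instantiated and used exactly like the plain one.
mlp = make_mlp_module(scan_mlp=True, scan_layers=True)(features=128)
x = jnp.ones((2, 16, 128))
params = mlp.init(jax.random.PRNGKey(0), x, True)
y = mlp.apply(params, x, True)
```

As in the patched `FlaxLLaMABlock.setup`, the remat wrapper is applied to the module class before it is instantiated, so the wrapped module is constructed and called exactly like the plain one.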