Commit: use scan to control remat

haoliu committed Jul 25, 2024
1 parent 0c20bfd, commit 3778ac1
Showing 10 changed files with 11 additions and 45 deletions.
2 changes: 1 addition & 1 deletion README.md
@@ -89,7 +89,7 @@ There are language-only and video-language versions, offering context sizes from


## Code structure
- Use `scan_query_chunk_size` and `scan_key_chunk_size` to control the block size in blockwise compute of the self-attention. Use `scan_mlp_chunk_size` to control the block size in blockwise compute of the feedforward network. Use `scan_attention=True` and `scan_mlp=True` to enable/disable blockwise compute in the self-attention and feed-forward network. Use `remat_attention` and `remat_mlp` to control the rematerialization policy with `nothing_saveable` recommended.
+ Use `scan_query_chunk_size` and `scan_key_chunk_size` to control the block size in blockwise compute of the self-attention. Use `scan_mlp_chunk_size` to control the block size in blockwise compute of the feedforward network. Use `scan_attention=True` and `scan_mlp=True` to enable/disable blockwise compute in the self-attention and feed-forward network.

You can use `mesh_dim=dp, fsdp, tp, sp` to control the degree of parallelism and RingAttention. It is a string of 4 integers separated by commas, representing the number of data parallelism, fully sharded data parallelism, tensor parallelism, and sequence parallelism.
For example, `mesh_dim='1,64,4,1'` means 1 data parallelism, 64 fully sharded data parallelism, 4 tensor parallelism, and 1 sequence parallelism. `mesh_dim='1,1,4,64'` means 1 data parallelism, 1 fully sharded data parallelism, 4 tensor parallelism, and 64 sequence parallelism for RingAttention.
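Putting the surviving options together, a minimal sketch of a config update might look like the following; the option names come from the paragraph above, and the values are simply those used in `scripts/run_train_text.sh` later in this commit:

```python
# Illustrative values only (borrowed from scripts/run_train_text.sh below);
# tune the chunk sizes to your sequence length and memory budget.
llama_config_updates = dict(
    scan_attention=True,         # blockwise self-attention
    scan_query_chunk_size=256,   # query block size
    scan_key_chunk_size=512,     # key block size
    scan_mlp=True,               # blockwise feed-forward
    scan_mlp_chunk_size=1024,    # feed-forward block size
    scan_layers=True,
)

# Parallelism layout: data, fully sharded data, tensor, and sequence
# parallelism degrees as a comma-separated string (the README example above).
mesh_dim = '1,64,4,1'
```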
33 changes: 7 additions & 26 deletions lwm/llama.py
@@ -26,7 +26,7 @@
from transformers.utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging

from ml_collections import ConfigDict
- from tux import function_args_to_config, load_pickle, open_file, with_sharding_constraint, get_jax_mesh, get_gradient_checkpoint_policy
+ from tux import function_args_to_config, load_pickle, open_file, with_sharding_constraint, get_jax_mesh
from ringattention import ringattention, blockwise_feedforward, ringattention_jax


@@ -150,11 +150,8 @@ def __init__(
embd_pdrop=0.0,
attn_pdrop=0.0,
tie_word_embeddings=False,
- remat_block='',
- remat_attention='',
- remat_mlp='',
- scan_attention=False,
- scan_mlp=False,
+ scan_attention=True,
+ scan_mlp=True,
scan_query_chunk_size=1024,
scan_key_chunk_size=1024,
scan_mlp_chunk_size=1024,
@@ -176,9 +173,6 @@
self.resid_pdrop = resid_pdrop
self.embd_pdrop = embd_pdrop
self.attn_pdrop = attn_pdrop
- self.remat_block = remat_block
- self.remat_attention = remat_attention
- self.remat_mlp = remat_mlp
self.scan_attention = scan_attention
self.scan_mlp = scan_mlp
self.scan_query_chunk_size = scan_query_chunk_size
@@ -556,7 +550,7 @@ def __call__(
query_chunk_size=self.config.scan_query_chunk_size,
key_chunk_size=self.config.scan_key_chunk_size,
dtype=self.dtype,
- policy=get_gradient_checkpoint_policy('nothing_saveable'),
+ policy=jax.checkpoint_policies.nothing_saveable,
precision=self.precision,
prevent_cse=not self.config.scan_layers,
)
@@ -673,19 +667,12 @@ class FlaxLLaMABlock(nn.Module):
def setup(self) -> None:
attention_module = FlaxLLaMAAttention
mlp_module = FlaxLLaMAMLP
- if self.config.remat_attention != '':
- attention_module = remat(
- FlaxLLaMAAttention, static_argnums=(4, 5, 6),
- policy=get_gradient_checkpoint_policy(self.config.remat_attention),
- prevent_cse=not self.config.scan_layers,
- )
- if self.config.remat_mlp != '':
+ if self.scan_map:
mlp_module = remat(
- FlaxLLaMAMLP, static_argnums=(1,),
- policy=get_gradient_checkpoint_policy(self.config.remat_mlp),
+ self.feed_forward, static_argnums=(1,),
+ policy=jax.checkpoint_policies.nothing_saveable,
prevent_cse=not self.config.scan_layers,
)

self.attention = attention_module(
self.config,
dtype=self.dtype,
@@ -930,12 +917,6 @@ def __call__(
all_hidden_states = () if output_hidden_states else None

block = FlaxLLaMABlock
- if self.config.remat_block != '':
- block = remat(
- FlaxLLaMABlock, static_argnums=(4, 5, 6),
- prevent_cse=not self.config.scan_layers,
- policy=get_gradient_checkpoint_policy(self.config.remat_block)
- )
if self.config.scan_layers:
initializing = self.is_mutable_collection('params')
params_spec = (
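Taken together, the llama.py changes replace the name-based `get_gradient_checkpoint_policy` lookup with the fixed `jax.checkpoint_policies.nothing_saveable` policy and let the scan flags decide when rematerialization is applied. A minimal, self-contained sketch of that wrapping pattern is shown below; `TinyMLP` is a hypothetical stand-in, not the repository's `FlaxLLaMAMLP`.

```python
import jax
import jax.numpy as jnp
import flax.linen as nn


class TinyMLP(nn.Module):
    """Hypothetical stand-in for FlaxLLaMAMLP, for illustration only."""
    hidden: int = 128

    @nn.compact
    def __call__(self, x, deterministic: bool = True):
        h = nn.silu(nn.Dense(self.hidden)(x))
        return nn.Dense(x.shape[-1])(h)


# Rematerialized variant: save nothing on the forward pass and recompute the
# activations during the backward pass. static_argnums=(1,) marks the
# `deterministic` flag as static, mirroring how the diff wraps the MLP module.
RematMLP = nn.remat(
    TinyMLP,
    static_argnums=(1,),
    policy=jax.checkpoint_policies.nothing_saveable,
    prevent_cse=True,
)

x = jnp.ones((4, 256))
params = RematMLP().init(jax.random.PRNGKey(0), x, True)
y = RematMLP().apply(params, x, True)
```

With `nothing_saveable`, no residuals are kept from the forward pass, so the backward pass recomputes them; this is the memory-for-compute trade-off the blockwise long-context code paths are designed around.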
3 changes: 0 additions & 3 deletions lwm/train.py
@@ -106,9 +106,6 @@ def main(argv):
llama_config = config_cls.load_config(FLAGS.load_llama_config)
updates = config_cls(**FLAGS.llama)
llama_config.update(dict(
- remat_block=updates.remat_block,
- remat_attention=updates.remat_attention,
- remat_mlp=updates.remat_mlp,
scan_attention=updates.scan_attention,
scan_mlp=updates.scan_mlp,
scan_query_chunk_size=updates.scan_query_chunk_size,
3 changes: 0 additions & 3 deletions lwm/vision_chat.py
@@ -150,9 +150,6 @@ def _load_model(self):
llama_config = VideoLLaMAConfig.load_config(FLAGS.load_llama_config)
updates = VideoLLaMAConfig(**FLAGS.llama)
llama_config.update(dict(
- remat_block=updates.remat_block,
- remat_attention=updates.remat_attention,
- remat_mlp=updates.remat_mlp,
scan_attention=updates.scan_attention,
scan_mlp=updates.scan_mlp,
scan_query_chunk_size=updates.scan_query_chunk_size,
3 changes: 0 additions & 3 deletions lwm/vision_generation.py
@@ -62,9 +62,6 @@ def main(argv):
llama_config = VideoLLaMAConfig.load_config(FLAGS.load_llama_config)
updates = VideoLLaMAConfig(**FLAGS.llama)
llama_config.update(dict(
- remat_block=updates.remat_block,
- remat_attention=updates.remat_attention,
- remat_mlp=updates.remat_mlp,
scan_attention=updates.scan_attention,
scan_mlp=updates.scan_mlp,
scan_query_chunk_size=updates.scan_query_chunk_size,
3 changes: 0 additions & 3 deletions scripts/eval_needle.py
@@ -329,9 +329,6 @@ def _load_model(self):
llama_config = LLaMAConfig.load_config(FLAGS.load_llama_config)
updates = LLaMAConfig(**FLAGS.llama)
llama_config.update(dict(
- remat_block=updates.remat_block,
- remat_attention=updates.remat_attention,
- remat_mlp=updates.remat_mlp,
scan_attention=updates.scan_attention,
scan_mlp=updates.scan_mlp,
scan_query_chunk_size=updates.scan_query_chunk_size,
3 changes: 0 additions & 3 deletions scripts/eval_needle_multi.py
@@ -338,9 +338,6 @@ def _load_model(self):
llama_config = LLaMAConfig.load_config(FLAGS.load_llama_config)
updates = LLaMAConfig(**FLAGS.llama)
llama_config.update(dict(
- remat_block=updates.remat_block,
- remat_attention=updates.remat_attention,
- remat_mlp=updates.remat_mlp,
scan_attention=updates.scan_attention,
scan_mlp=updates.scan_mlp,
scan_query_chunk_size=updates.scan_query_chunk_size,
2 changes: 1 addition & 1 deletion scripts/run_train_text.sh
@@ -24,7 +24,7 @@ python3 -u -m lwm.train \
--save_model_freq=0 \
--save_milestone_freq=10 \
--load_llama_config='debug' \
- --update_llama_config="dict(theta=10000,max_sequence_length=2048,scan_attention=True,scan_query_chunk_size=256,scan_key_chunk_size=512,remat_attention='nothing_saveable',scan_mlp=True,scan_mlp_chunk_size=1024,remat_mlp='nothing_saveable',remat_block='nothing_saveable',scan_layers=True)" \
+ --update_llama_config="dict(theta=10000,max_sequence_length=2048,scan_attention=True,scan_query_chunk_size=256,scan_key_chunk_size=512,scan_mlp=True,scan_mlp_chunk_size=1024,scan_layers=True)" \
--tokenizer="$llama_tokenizer_path" \
--optimizer.type='adamw' \
--optimizer.accumulate_gradient_steps=1 \
2 changes: 1 addition & 1 deletion scripts/run_train_vision_text.sh
@@ -24,7 +24,7 @@ python3 -u -m lwm.train \
--save_model_freq=0 \
--save_milestone_freq=10 \
--load_llama_config='debug' \
- --update_llama_config="dict(theta=50000000,max_sequence_length=2048,scan_attention=True,scan_query_chunk_size=512,scan_key_chunk_size=1024,remat_attention='nothing_saveable',scan_mlp=True,scan_mlp_chunk_size=8192,remat_mlp='nothing_saveable',remat_block='nothing_saveable',scan_layers=True)" \
+ --update_llama_config="dict(theta=50000000,max_sequence_length=2048,scan_attention=True,scan_query_chunk_size=512,scan_key_chunk_size=1024,scan_mlp=True,scan_mlp_chunk_size=8192,scan_layers=True)" \
--tokenizer="$llama_tokenizer_path" \
--optimizer.type='adamw' \
--optimizer.accumulate_gradient_steps=1 \
2 changes: 1 addition & 1 deletion scripts/run_vision_chat.sh
@@ -22,7 +22,7 @@ python3 -u -m lwm.vision_chat \
--dtype='fp32' \
--load_llama_config='7b' \
--max_n_frames=8 \
- --update_llama_config="dict(sample_mode='text',theta=50000000,max_sequence_length=131072,scan_attention=False,scan_query_chunk_size=128,scan_key_chunk_size=128,remat_attention='',scan_mlp=False,scan_mlp_chunk_size=2048,remat_mlp='',remat_block='',scan_layers=True)" \
+ --update_llama_config="dict(sample_mode='text',theta=50000000,max_sequence_length=131072,scan_attention=False,scan_query_chunk_size=128,scan_key_chunk_size=128,scan_mlp=False,scan_mlp_chunk_size=2048,scan_layers=True)" \
--load_checkpoint="params::$lwm_checkpoint" \
--tokenizer="$llama_tokenizer_path" \
2>&1 | tee ~/output.log
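Finally, the script-level effect of the commit can be seen by writing the old and new `--update_llama_config` values from `run_vision_chat.sh` as plain Python dicts. The entries are copied from the diff above; only the closing comparison is added for illustration.

```python
# --update_llama_config before this commit (run_vision_chat.sh).
old_update = dict(
    sample_mode='text', theta=50000000, max_sequence_length=131072,
    scan_attention=False, scan_query_chunk_size=128, scan_key_chunk_size=128,
    remat_attention='', scan_mlp=False, scan_mlp_chunk_size=2048,
    remat_mlp='', remat_block='', scan_layers=True,
)

# --update_llama_config after this commit: identical, except that the three
# remat_* knobs are gone; remat is now controlled by the scan flags alone.
new_update = dict(
    sample_mode='text', theta=50000000, max_sequence_length=131072,
    scan_attention=False, scan_query_chunk_size=128, scan_key_chunk_size=128,
    scan_mlp=False, scan_mlp_chunk_size=2048, scan_layers=True,
)

# Keys dropped by this commit.
assert set(old_update) - set(new_update) == {'remat_attention', 'remat_mlp', 'remat_block'}
```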