Mamba fix #123

Merged
merged 33 commits into from
Dec 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
33 commits
Select commit Hold shift + click to select a range
41d96a3
make mamba
lchu6 Aug 5, 2024
81ca3c2
add quick debug
lchu6 Aug 10, 2024
0817ddc
add quick debug
lchu6 Aug 10, 2024
5d7e936
revert debug verbosity
lchu6 Aug 10, 2024
bcad3ad
Learning rate scheduler changed (Constant)
divya-kumari32 Nov 11, 2024
f8c1651
Mamba config restore
divya-kumari32 Nov 12, 2024
db679b2
Cosine 0.01 decay
divya-kumari32 Nov 17, 2024
c60500d
Add AutoHandler
daviswer Nov 12, 2024
40245af
Add Auto cfg option for AutoHandler
daviswer Nov 12, 2024
2cfda44
Len gets called before open
daviswer Nov 12, 2024
fb462de
path/filepath typo fix
daviswer Nov 12, 2024
ad5a566
Partitioning fix from mup-search
daviswer Nov 12, 2024
5aaf1eb
Warmup interval change
divya-kumari32 Nov 17, 2024
496f1b6
Schedule change
divya-kumari32 Nov 17, 2024
715b04c
Constant schedule
divya-kumari32 Nov 27, 2024
2fc0d73
LR schedule change (cool down and constant lr)
divya-kumari32 Dec 10, 2024
ab6e3b2
Update dataset_utils.py
divya-kumari32 Dec 14, 2024
e3eaa80
LR schedule change (Warmup + constant)
divya-kumari32 Dec 16, 2024
824fda1
Update main_training.py
divya-kumari32 Dec 17, 2024
70bc786
Mirror doc len check into AHandler, fix mypy in autoHandler
daviswer Dec 17, 2024
d48dc5f
Linting
daviswer Dec 17, 2024
aeb8e61
Further linting
daviswer Dec 17, 2024
bd0f1d3
More mypy type fix
daviswer Dec 17, 2024
6ab883b
Rename main_training.py to main_training_mamba.py
divya-kumari32 Dec 17, 2024
9e6b7f9
Added main_training_llama.py file
divya-kumari32 Dec 17, 2024
d8af68b
Rename fms_to_hf.py to fms_to_hf_mamba.py
divya-kumari32 Dec 17, 2024
6a39d7c
Added fms_to_hf_llama.py file
divya-kumari32 Dec 17, 2024
ccb1132
Delete fms_fsdp/utils/config_utils.py
divya-kumari32 Dec 17, 2024
dc9e68b
Added mamba variant 9.8b
divya-kumari32 Dec 17, 2024
507f5be
Incremental mypy fix
daviswer Dec 17, 2024
e0b6123
Fix imports (mypy)
daviswer Dec 17, 2024
dacd8f2
Rename adapters to work correctly
ani300 Dec 17, 2024
1b60b61
linting
ani300 Dec 17, 2024
54 changes: 39 additions & 15 deletions fms_fsdp/utils/config_utils.py
@@ -24,7 +24,7 @@ def update_config(config, **kwargs):

def get_model_config(model_variant):
if model_variant == "llama2_70b":
llama_config = LLaMAConfig(
model_config = LLaMAConfig(
emb_dim=8192,
multiple_of=4096,
nheads=64,
@@ -33,7 +33,7 @@ def get_model_config(model_variant):
hidden_grow_factor=28672 / 8192,
)
elif model_variant == "llama2_34b":
llama_config = LLaMAConfig(
model_config = LLaMAConfig(
emb_dim=8192,
nheads=64,
kvheads=8,
@@ -43,27 +43,27 @@ def get_model_config(model_variant):
rope_theta=1000000.0,
)
elif model_variant == "llama2_13b":
llama_config = LLaMAConfig(
model_config = LLaMAConfig(
emb_dim=5120,
nheads=40,
nlayers=40,
hidden_grow_factor=13824 / 5120,
)
elif model_variant == "llama2_7b":
llama_config = LLaMAConfig(
model_config = LLaMAConfig(
hidden_grow_factor=11008 / 4096,
kvheads=32,
)
elif model_variant == "llama2_1.4b":
llama_config = LLaMAConfig(
model_config = LLaMAConfig(
emb_dim=2048,
nheads=16,
nlayers=24,
hidden_grow_factor=3,
kvheads=4,
)
elif model_variant == "llama3_8b":
llama_config = LLaMAConfig(
model_config = LLaMAConfig(
src_vocab_size=128256,
emb_dim=4096,
nheads=32,
@@ -74,7 +74,7 @@ def get_model_config(model_variant):
rope_theta=500000.0,
)
elif model_variant == "llama3_8b_4k":
llama_config = LLaMAConfig(
model_config = LLaMAConfig(
src_vocab_size=128256,
emb_dim=4096,
nheads=32,
@@ -85,7 +85,7 @@ def get_model_config(model_variant):
rope_theta=500000.0,
)
elif model_variant == "llama3_1.8b":
llama_config = LLaMAConfig(
model_config = LLaMAConfig(
src_vocab_size=128256,
emb_dim=2048,
nheads=16,
@@ -96,7 +96,7 @@ def get_model_config(model_variant):
rope_theta=500000.0,
)
elif model_variant == "llama3_1.8b_4k":
llama_config = LLaMAConfig(
model_config = LLaMAConfig(
src_vocab_size=128256,
emb_dim=2048,
nheads=16,
@@ -107,7 +107,7 @@ def get_model_config(model_variant):
rope_theta=500000.0,
)
elif model_variant == "llama3_3.2b":
llama_config = LLaMAConfig(
model_config = LLaMAConfig(
src_vocab_size=128256,
emb_dim=3072,
nheads=24,
@@ -118,7 +118,7 @@ def get_model_config(model_variant):
rope_theta=500000.0,
)
elif model_variant == "llama3_3.2b_4k":
llama_config = LLaMAConfig(
model_config = LLaMAConfig(
src_vocab_size=128256,
emb_dim=3072,
nheads=24,
@@ -129,7 +129,7 @@ def get_model_config(model_variant):
rope_theta=500000.0,
)
elif model_variant == "llama3_70b":
llama_config = LLaMAConfig(
model_config = LLaMAConfig(
src_vocab_size=128256,
emb_dim=8192,
nheads=64,
@@ -140,7 +140,7 @@ def get_model_config(model_variant):
rope_theta=500000.0,
)
elif model_variant == "llama3_70b_4k":
llama_config = LLaMAConfig(
model_config = LLaMAConfig(
src_vocab_size=128256,
emb_dim=8192,
nheads=64,
@@ -151,15 +151,39 @@ def get_model_config(model_variant):
rope_theta=500000.0,
)
elif model_variant == "llama3_194m_4k":
llama_config = LLaMAConfig(
model_config = LLaMAConfig(
src_vocab_size=128256,
emb_dim=1024,
nheads=8,
nlayers=10,
max_expected_seq_len=4096,
rope_theta=500000.0,
)
elif model_variant == "mamba_9.8b":
model_config = {
"d_model": 4096,
"d_intermediate": 14336,
"n_layer": 32,
"vocab_size": 128256,
"ssm_cfg": {"layer": "Mamba2"},
"attn_layer_idx": [9, 18, 27],
"attn_cfg": {
"causal": True,
"d_conv": 0,
"head_dim": 128,
"num_heads": 32,
"num_heads_kv": 8,
"out_proj_bias": False,
"qkv_proj_bias": False,
"rotary_emb_dim": 64,
},
"rms_norm": True,
"residual_in_fp32": True,
"fused_add_norm": True,
"pad_vocab_size_multiple": 16,
"tie_embeddings": False,
}
else:
raise ValueError(f"model variant {model_variant} not supported.")

return llama_config
return model_config
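
For context (not part of the diff): get_model_config now returns an LLaMAConfig for the llama variants but a plain dict for the new mamba_9.8b variant, so mamba callers are expected to unpack it themselves. A minimal sketch, mirroring the conversion script added later in this PR:

from mamba_ssm.models.config_mamba import MambaConfig
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel

from fms_fsdp.utils.config_utils import get_model_config

config_data = get_model_config("mamba_9.8b")  # returns a dict for the mamba variant
model = MambaLMHeadModel(MambaConfig(**config_data))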
8 changes: 5 additions & 3 deletions fms_fsdp/utils/dataloader_utils.py
@@ -2,6 +2,7 @@

from fms_fsdp.utils.dataset_utils import (
ArrowHandler,
AutoHandler,
BufferDataset,
CheckpointDataset,
ParquetHandler,
@@ -16,6 +17,7 @@
_handler_map = {
"arrow": ArrowHandler,
"hf_parquet": ParquetHandler,
"auto": AutoHandler,
}


@@ -84,10 +86,10 @@ def get_data_loader(cfg, rank, world_size, postprocess=[causal_lm]):
assert (
cfg.file_type in _handler_map
), f"File type {cfg.file_type} is not recognized ({list(_handler_map.keys())})"
if cfg.file_type == "hf_parquet":
filehandler = ParquetHandler(cfg.tokenizer_path, cfg.col_name)
if cfg.file_type == "hf_parquet" or cfg.file_type == "auto":
filehandler = _handler_map[cfg.file_type](cfg.tokenizer_path, cfg.col_name)
else:
filehandler = _handler_map[cfg.file_type](cfg.col_name)
filehandler = _handler_map[cfg.file_type]
# Base reader layer
data = StreamingDocDataset(
cfg.data_path,
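Illustrative sketch (not part of the diff): with the new _handler_map entry, setting file_type to "auto" routes the loader through AutoHandler, which takes the tokenizer path and column name just like the hf_parquet handler. The cfg fields follow the attributes already used in this hunk; cfg itself and all paths are placeholders:

from fms_fsdp.utils.dataloader_utils import get_data_loader

cfg.file_type = "auto"                     # shards may mix .arrow and .parquet
cfg.data_path = "/path/to/shards"          # placeholder
cfg.tokenizer_path = "/path/to/tokenizer"  # placeholder, used when parquet files need tokenizing
cfg.col_name = "text"
loader = get_data_loader(cfg, rank=0, world_size=1)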
79 changes: 67 additions & 12 deletions fms_fsdp/utils/dataset_utils.py
@@ -357,11 +357,11 @@ def length(self, path: str):

def get(self, reader: pa.RecordBatchFileReader, index: int, drop_tokens: Set):
doc = reader.get_batch(index)[self.col_name]
if len(doc) > 0:
if doc[0].as_py() in drop_tokens:
doc = doc.slice(1, len(doc) - 1)
if doc[-1].as_py() in drop_tokens:
doc = doc.slice(0, len(doc) - 1)
if len(doc) > 0 and doc[0].as_py() in drop_tokens:
doc = doc.slice(1, len(doc) - 1)
# Recheck len for edge case where doc=[eos]
if len(doc) > 0 and doc[-1].as_py() in drop_tokens:
doc = doc.slice(0, len(doc) - 1)
return doc

def slice(self, doc: pa.UInt32Array, index: int, n_pull: int) -> List:
@@ -384,24 +384,79 @@ def is_legal(self, filepath: str):
return "parquet" in os.path.splitext(filepath)[1]

def open(self, path: str):
return pq.read_pandas(path, columns=[self.col_name])[self.col_name]
return pq.read_pandas(path, columns=[self.col_name], partitioning=None)[
self.col_name
]

def length(self, path: str):
return pq.read_pandas(path, columns=[]).num_rows
return pq.read_metadata(path).num_rows

def get(self, reader, index: int, drop_tokens: Set):
doc = self.tokenizer(str(reader[index]))["input_ids"]
if len(doc) > 0:
if doc[0] in drop_tokens:
doc = doc[1:]
if doc[-1] in drop_tokens:
doc = doc[:-1]
if len(doc) > 0 and doc[0] in drop_tokens:
doc = doc[1:]
# Recheck len for edge case where doc=[eos]
if len(doc) > 0 and doc[-1] in drop_tokens:
doc = doc[:-1]
return doc

def slice(self, doc: List, index: int, n_pull: int) -> List:
return doc[index : index + n_pull]


class AutoHandler(_ShardFileHandler):
def __init__(self, tokenizer_path: str, col_name: str = "text"):
self.PHandler = ParquetHandler(tokenizer_path, col_name)
self.AHandler = ArrowHandler()
self.current = _ShardFileHandler()

def is_legal(self, filepath: str):
return (
"parquet" in os.path.splitext(filepath)[1]
or "arrow" in os.path.splitext(filepath)[1]
)

def open(self, path: str):
"""
Open the file, to be indexed via self.get() method.
Avoid reading entire multi-Gb files when possible!
"""
if "arrow" in os.path.splitext(path)[1]:
self.current = self.AHandler
else:
self.current = self.PHandler
return self.current.open(path)

def length(self, path: str):
"""
Calculate the number of documents in the given file.
Avoid reading entire multi-Gb files when possible!
"""
if "arrow" in os.path.splitext(path)[1]:
return self.AHandler.length(path)
else:
return self.PHandler.length(path)

def get(self, reader, index: int, drop_tokens: Set):
"""
Given the output of self.open() and an index, return the document at that index.
Then, remove the first and/or last items if they appear in drop_tokens.
Try to avoid reading entire documents at a time in case of long documents,
but this is less important than avoiding reading entire files as above.
Output must support len().
"""
return self.current.get(reader, index, drop_tokens)

def slice(self, doc, index: int, n_pull: int) -> List:
"""
Given a long document, retrieve n_pull consecutive items starting from index.
Again, try to be memory-efficient when doing so, but efficiency in self.get()
and self.open() is far more important.
Must return a python list.
"""
return self.current.slice(doc, index, n_pull)


#### ------------------------- PIPELINE LAYERS ------------------------- ####


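A minimal usage sketch of the new AutoHandler (tokenizer and shard paths are placeholders): it dispatches on the file extension, so one handler instance can serve a dataset that mixes .arrow and .parquet shards.

from fms_fsdp.utils.dataset_utils import AutoHandler

handler = AutoHandler("/path/to/tokenizer", col_name="text")
shard_path = "/path/to/shards/part-00000.parquet"
if handler.is_legal(shard_path):
    reader = handler.open(shard_path)   # picks ParquetHandler or ArrowHandler internally
    n_docs = handler.length(shard_path)
    doc = handler.get(reader, 0, drop_tokens=set())
    first_tokens = handler.slice(doc, 0, 64)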
2 changes: 1 addition & 1 deletion fms_to_hf.py → fms_to_hf_llama.py
@@ -1,6 +1,6 @@
import fire
import torch
from fms.models.hf import to_hf_api
from fms.models.hf.utils import to_hf_api
from fms.models.llama import LLaMA
from torch.distributed._shard.checkpoint import FileSystemReader, load_state_dict
from transformers import LlamaConfig, LlamaForCausalLM
Expand Down
37 changes: 37 additions & 0 deletions fms_to_hf_mamba.py
@@ -0,0 +1,37 @@
import fire
from mamba_ssm.models.config_mamba import MambaConfig
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel
from torch.distributed._shard.checkpoint import FileSystemReader, load_state_dict

from fms_fsdp.utils.config_utils import get_model_config


def main(model_variant, load_path, save_path, tokenizer_name_or_path):
print("Initializing model...")
config_data = get_model_config(model_variant)
mamba_config = MambaConfig(**config_data)
model = MambaLMHeadModel(mamba_config)

print(f"Reading state dict from {load_path}")
state_dict = {"model_state": model.state_dict()}
load_state_dict(
state_dict=state_dict, storage_reader=FileSystemReader(load_path), no_dist=True
)

print("Loading state dict into the model...")
model.load_state_dict(state_dict["model_state"])

print("Saving model to HF-compatible format...")
model.save_pretrained(save_path)

print("Copying tokenizer...")
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path)
tokenizer.save_pretrained(save_path)

print(f"Model saving at {save_path}")


if __name__ == "__main__":
fire.Fire(main)
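
Since the script is exposed through fire.Fire(main), the converter can be driven from the command line or called directly. A hedged example with placeholder paths:

# equivalent CLI call:
#   python fms_to_hf_mamba.py --model_variant=mamba_9.8b \
#       --load_path=/path/to/fsdp_checkpoint --save_path=/path/to/hf_out \
#       --tokenizer_name_or_path=/path/to/tokenizer
main(
    model_variant="mamba_9.8b",
    load_path="/path/to/fsdp_checkpoint",
    save_path="/path/to/hf_out",
    tokenizer_name_or_path="/path/to/tokenizer",
)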
File renamed without changes.