add Sequence Parallelism #6506

Open · wants to merge 24 commits into main
1 change: 1 addition & 0 deletions setup.py
@@ -62,6 +62,7 @@ def get_console_scripts() -> List[str]:
"modelscope": ["modelscope"],
"openmind": ["openmind"],
"dev": ["pre-commit", "ruff", "pytest"],
"sp": ["ring-flash-attn", "flash-attn"],
}


8 changes: 8 additions & 0 deletions src/llamafactory/data/collator.py
@@ -114,8 +114,11 @@ class SFTDataCollatorWith4DAttentionMask(MultiModalDataCollatorForSeq2Seq):
block_diag_attn: bool = False
attn_implementation: Literal["eager", "sdpa", "flash_attention_2"] = "eager"
compute_dtype: "torch.dtype" = torch.float32
require_position_ids: bool = False

def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, "torch.Tensor"]:
if not self.require_position_ids:
features = [{k: v for k, v in d.items() if k != "position_ids"} for d in features]
features = super().__call__(features)
if self.block_diag_attn and self.attn_implementation != "flash_attention_2":
features["attention_mask"] = prepare_4d_attention_mask(features["attention_mask"], self.compute_dtype)
@@ -129,6 +132,8 @@ class PairwiseDataCollatorWithPadding(MultiModalDataCollatorForSeq2Seq):
Data collator for pairwise data.
"""

require_position_ids: bool = False

def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, "torch.Tensor"]:
r"""
Pads batched data to the longest sequence in the batch.
@@ -146,6 +151,9 @@ def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, "torch.Tensor"]:
"images": feature["images"],
"videos": feature["videos"],
}
if self.require_position_ids:
# if required, position_ids have already been padded to cutoff_len during preprocessing
target_feature["position_ids"] = feature[f"{key}_position_ids"]
concatenated_features.append(target_feature)

return super().__call__(concatenated_features)
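
A toy illustration (not part of the PR) of what the new require_position_ids flag controls in the collators above:

features = [{"input_ids": [1, 2, 3], "attention_mask": [1, 1, 1], "position_ids": [0, 1, 2]}]
# With require_position_ids=False (the default), the SFT collator strips "position_ids" from each feature before padding.
# With require_position_ids=True the field is kept (and the pairwise collator forwards the chosen/rejected position_ids),
# which sequence parallelism relies on because position_ids are padded and split per shard during preprocessing.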
13 changes: 13 additions & 0 deletions src/llamafactory/data/data_utils.py
@@ -90,3 +90,16 @@ def split_dataset(
val_size = int(data_args.val_size) if data_args.val_size > 1 else data_args.val_size
dataset = dataset.train_test_split(test_size=val_size, seed=seed)
return DatasetDict({"train": dataset["train"], "validation": dataset["test"]})


# modified from https://github.com/jzhang38/EasyContext/
def preprocess_sp_dataset(seq_ids, world_size, sequence_parallel_mode):
if sequence_parallel_mode == "zigzag-ring":
step = len(seq_ids) // (2 * world_size)
value_chunks = [seq_ids[s : s + step] for s in range(0, len(seq_ids), step)]
local_values = list()
for rank in range(world_size):
local_values.append(value_chunks[rank] + value_chunks[2 * world_size - rank - 1])
return local_values
else:
raise NotImplementedError("Other sequence parallel modes are to be implemented.")
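
A minimal worked example (illustrative only) of the zigzag-ring split above, assuming 2 sequence-parallel ranks and an 8-token sequence:

seq_ids = list(range(8))  # [0, 1, 2, 3, 4, 5, 6, 7]
shards = preprocess_sp_dataset(seq_ids, world_size=2, sequence_parallel_mode="zigzag-ring")
# step = 8 // (2 * 2) = 2, so the sequence is cut into 4 chunks of 2 tokens
# shards[0] -> [0, 1, 6, 7]  (first and last chunk)
# shards[1] -> [2, 3, 4, 5]  (the two middle chunks)
# Pairing one early chunk with one late chunk balances the causal-attention workload across ranks.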
74 changes: 72 additions & 2 deletions src/llamafactory/data/loader.py
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import functools
import os
import sys
from typing import TYPE_CHECKING, Dict, Literal, Optional, Sequence, Union
@@ -21,10 +22,10 @@
from transformers.utils.versions import require_version

from ..extras import logging
from ..extras.constants import FILEEXT2TYPE
from ..extras.constants import FILEEXT2TYPE, IGNORE_INDEX
from ..extras.misc import has_tokenized_data
from .aligner import align_dataset
from .data_utils import merge_dataset, split_dataset
from .data_utils import merge_dataset, preprocess_sp_dataset, split_dataset
from .parser import get_dataset_list
from .preprocess import get_preprocess_and_print_func

@@ -222,6 +223,75 @@ def _get_preprocessed_dataset(
return dataset


def sequence_parallel_decorator(get_dataset):
@functools.wraps(get_dataset)
def sequence_parallel_processor(*args, **kwargs):
dataset_module = get_dataset(*args, **kwargs)
# get arguments
# NOTE: hard-coded positional indexing seems unavoidable in this decorator-style implementation
model_args, data_args, training_args = args[1], args[2], args[3]
tokenizer = kwargs["tokenizer"]
if model_args.sequence_parallel_size > 1:

def pad_sequence(examples):
max_length = data_args.cutoff_len
input_pad_token_id = tokenizer.pad_token_id
assert data_args.ignore_pad_token_for_loss
label_pad_token_id = IGNORE_INDEX if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id

for k, v in examples.items():
if k.endswith("input_ids"):
pad_token_id = input_pad_token_id
elif k.endswith("labels"):
pad_token_id = label_pad_token_id
# shift labels by one here, before padding and splitting, so the loss can be computed per shard
v = [seq[1:] for seq in v]
elif k.endswith("attention_mask"):
pad_token_id = 0
elif k.endswith("position_ids"):
pad_token_id = max_length - 1  # pad with the maximum position id
elif k == "images" or k == "videos":
pad_token_id = -1  # placeholder; not used because multi-modal fields are skipped
continue  # TODO: multi-modal inputs have not been tested with sequence parallelism yet
else:
raise NotImplementedError(f"Unexpected dataset key: {k}")
examples[k] = [seq + [pad_token_id] * (max_length - len(seq)) for seq in v]

return examples

# sp for Sequence Parallel
def sp_split(examples):
for k, v in examples.items():
chunks = list()
for row in v:
if row is None:
chunks += [None] * model_args.sequence_parallel_size
else:
chunks += preprocess_sp_dataset(
row, model_args.sequence_parallel_size, model_args.sequence_parallel_mode
)
examples[k] = chunks
return examples

# padding then split
for k in dataset_module:
dataset = dataset_module[k]
if data_args.shuffle_for_sequence_parallel:
dataset = dataset.shuffle(seed=training_args.seed)
padded_dataset = dataset.map(pad_sequence, batched=True)
sp_dataset = padded_dataset.map(sp_split, batched=True)
dataset_module[k] = sp_dataset

else:
# no sequence parallelism
pass

return dataset_module

return sequence_parallel_processor


@sequence_parallel_decorator
def get_dataset(
template: "Template",
model_args: "ModelArguments",
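
To make the pad-then-split flow in sequence_parallel_decorator above concrete, here is a self-contained sketch (not part of the PR) with assumed values: cutoff_len=8, pad token 0, sequence_parallel_size=2, zigzag-ring mode.

cutoff_len, sp_size = 8, 2
input_ids = [11, 12, 13, 14, 15]                   # 5 real tokens
input_ids += [0] * (cutoff_len - len(input_ids))   # pad_sequence pads every field up to cutoff_len
shards = preprocess_sp_dataset(input_ids, sp_size, "zigzag-ring")
# shards[0] -> [11, 12, 0, 0]   (chunks 0 and 3)
# shards[1] -> [13, 14, 15, 0]  (chunks 1 and 2)
# Labels are shifted by one before padding and position_ids are padded with cutoff_len - 1,
# so each shard keeps values consistent with the full sequence.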
2 changes: 2 additions & 0 deletions src/llamafactory/data/processors/pairwise.py
@@ -96,9 +96,11 @@ def preprocess_pairwise_dataset(
)
model_inputs["chosen_input_ids"].append(chosen_input_ids)
model_inputs["chosen_attention_mask"].append([1] * len(chosen_input_ids))
model_inputs["chosen_position_ids"].append(list(range(len(chosen_input_ids))))
model_inputs["chosen_labels"].append(chosen_labels)
model_inputs["rejected_input_ids"].append(rejected_input_ids)
model_inputs["rejected_attention_mask"].append([1] * len(rejected_input_ids))
model_inputs["rejected_position_ids"].append(list(range(len(rejected_input_ids))))
model_inputs["rejected_labels"].append(rejected_labels)
model_inputs["images"].append(examples["_images"][i])
model_inputs["videos"].append(examples["_videos"][i])
2 changes: 2 additions & 0 deletions src/llamafactory/data/processors/supervised.py
@@ -120,6 +120,7 @@ def preprocess_supervised_dataset(
)
model_inputs["input_ids"].append(input_ids)
model_inputs["attention_mask"].append([1] * len(input_ids))
model_inputs["position_ids"].append(list(range(len(input_ids))))
model_inputs["labels"].append(labels)
model_inputs["images"].append(examples["_images"][i])
model_inputs["videos"].append(examples["_videos"][i])
@@ -204,6 +205,7 @@ def preprocess_packed_supervised_dataset(

model_inputs["input_ids"].append(packed_input_ids)
model_inputs["attention_mask"].append(packed_attention_masks)
model_inputs["position_ids"].append(list(range(len(packed_input_ids))))
model_inputs["labels"].append(packed_labels)
model_inputs["images"].append(packed_images or None)
model_inputs["videos"].append(packed_videos or None)
6 changes: 6 additions & 0 deletions src/llamafactory/hparams/data_args.py
@@ -109,6 +109,12 @@ class DataArguments:
default=False,
metadata={"help": "Enable sequence packing without cross-attention."},
)
shuffle_for_sequence_parallel: bool = field(
default=True,
metadata={
"help": "Shuffle dataset before sequence parallel preprocessing (should shuffle before pad & split)."
},
)
tool_format: Optional[str] = field(
default=None,
metadata={"help": "Tool format to use for constructing function calling examples."},
10 changes: 10 additions & 0 deletions src/llamafactory/hparams/model_args.py
@@ -229,6 +229,16 @@ class ModelArguments(QuantizationArguments, ProcessorArguments, ExportArguments,
default=False,
metadata={"help": "Whether or not to enable liger kernel for faster training."},
)
sequence_parallel_size: int = field(
default=1,
metadata={
"help": "Number of GPUs to process one data sequence. Values greater than 1 means enabling sequence parallelism."
},
)
sequence_parallel_mode: Literal["zigzag-ring", "llama3", "ulysses"] = field(
default="zigzag-ring",
metadata={"help": "Specific mode of sequence parallel implementation."},
)
moe_aux_loss_coef: Optional[float] = field(
default=None,
metadata={"help": "Coefficient of the auxiliary router loss in mixture-of-experts model."},
18 changes: 18 additions & 0 deletions src/llamafactory/hparams/parser.py
@@ -259,6 +259,24 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
if model_args.use_unsloth and is_deepspeed_zero3_enabled():
raise ValueError("Unsloth is incompatible with DeepSpeed ZeRO-3.")

if data_args.cutoff_len % model_args.sequence_parallel_size != 0:
raise ValueError("cutoff_len must be a multiple of sequence_parallel_size.")

if model_args.sequence_parallel_size > 1:
if (data_args.cutoff_len // model_args.sequence_parallel_size) % 8 != 0:
tmp_sp_len = data_args.cutoff_len // model_args.sequence_parallel_size
closest_cutoff_len = int((tmp_sp_len + (8 - tmp_sp_len % 8)) * model_args.sequence_parallel_size)
logger.warning_rank0(
f"cutoff_len must be a multiple of 8 after dividing sequence_parallel_size. With sequence parallel, we first pad to cutoff_len and then split the sequence. \nAll the DataCollators pad to multiple of 8, which is hard-coded in LLaMA-Factory. If the splitted sequences are not already mutliple of 8, padding it to be would effectively change the original sequence and is wrong. \nWe automatically increase the cutoff_len = {data_args.cutoff_len} you set to the larger but closest number satifying this condition to be {closest_cutoff_len}."
)
data_args.cutoff_len = closest_cutoff_len
# raise ValueError(f"cutoff_len must be a multiple of 8 after dividing sequence_parallel_size. With sequence parallel, we first pad to cutoff_len and then split the sequence. \nAll the DataCollators pad to multiple of 8, which is hard-coded in LLaMA-Factory. If the splitted sequences are not already mutliple of 8, padding it to be would effectively change the original sequence and is wrong. \nThe closest cutoff_len satifying this condition is {closest_cutoff_len}. Try setting --cutoff_len {closest_cutoff_len}")

if model_args.sequence_parallel_mode == "zigzag-ring" and data_args.neat_packing:
raise ValueError(
"zigzag ring attention does not support neat_packing. Disable neat_packing or use other sequence_parallel_mode."
)

if data_args.neat_packing and not data_args.packing:
logger.warning_rank0("`neat_packing` requires `packing` is True. Change `packing` to True.")
data_args.packing = True
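
A worked example of the automatic cutoff_len adjustment above, with illustrative numbers only:

cutoff_len, sp_size = 1000, 4
per_rank = cutoff_len // sp_size                     # 250 tokens per rank, not a multiple of 8
closest = (per_rank + (8 - per_rank % 8)) * sp_size  # (250 + 6) * 4 = 1024
# cutoff_len would therefore be raised from 1000 to 1024, giving each rank 256 tokens.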
11 changes: 11 additions & 0 deletions src/llamafactory/model/loader.py
@@ -24,6 +24,7 @@
from .model_utils.liger_kernel import apply_liger_kernel
from .model_utils.misc import register_autoclass
from .model_utils.mod import convert_pretrained_model_to_mod, load_mod_pretrained_model
from .model_utils.sequence_parallel import apply_sequence_parallel
from .model_utils.unsloth import load_unsloth_pretrained_model
from .model_utils.valuehead import load_valuehead_params
from .patcher import patch_config, patch_model, patch_processor, patch_tokenizer, patch_valuehead_model
@@ -132,7 +133,16 @@ def load_model(
init_kwargs = _get_init_kwargs(model_args)
config = load_config(model_args)
patch_config(config, tokenizer, model_args, init_kwargs, is_trainable)
if (
model_args.sequence_parallel_size > 1
and hasattr(config, "attention_dropout")
and config.attention_dropout != 0.0
):
logger.warning_rank0("Sequence Parallel doesn't support attention_dropout yet. Forcing it to zero.")
config.attention_dropout = 0.0

apply_liger_kernel(config, model_args, is_trainable, require_logits=(finetuning_args.stage not in ["pt", "sft"]))
sequence_parallel_group = apply_sequence_parallel(model_args) # monkey patching, similar to liger_kernel

model = None
lazy_load = False
@@ -210,4 +220,5 @@
)
)

model.sequence_parallel_group = sequence_parallel_group
return model
71 changes: 71 additions & 0 deletions src/llamafactory/model/model_utils/sequence_parallel.py
@@ -0,0 +1,71 @@
# modified from
# 1. https://github.com/zhuzilin/ring-flash-attention/blob/main/ring_flash_attn/adapters/hf_adapter.py
# 2. https://github.com/jzhang38/EasyContext/
from functools import partial

import torch.distributed as dist
import transformers
import transformers.modeling_flash_attention_utils
from ring_flash_attn import zigzag_ring_flash_attn_func


def new_flash_attn_forward(
query_states,
key_states,
value_states,
attention_mask,
q_len,
dropout=0,
sliding_window=None,
is_causal=True,
group=None,
**kwargs,
):
attn_output = zigzag_ring_flash_attn_func(
query_states, key_states, value_states, dropout, causal=is_causal, group=group
)

return attn_output


def init_sp_group(sp_size):
assert dist.is_initialized()
world_size = dist.get_world_size()
assert world_size % sp_size == 0, "Total number of GPUs must be a multiple of sequence_parallel_size."

sp_group_num = world_size // sp_size
sp_ranks_list = [list(range(i * sp_size, i * sp_size + sp_size)) for i in range(sp_group_num)]

sp_groups = [dist.new_group(sp_ranks_this) for sp_ranks_this in sp_ranks_list]

global_rank_this = dist.get_rank()
sp_idx = global_rank_this // sp_size
return sp_groups[sp_idx]
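
# Illustration (not part of the PR): the group layout above can be previewed without torch.distributed.
# Assuming world_size = 8 and sequence_parallel_size = 4:
_world_size, _sp_size = 8, 4
_sp_ranks_list = [list(range(i * _sp_size, (i + 1) * _sp_size)) for i in range(_world_size // _sp_size)]
# -> [[0, 1, 2, 3], [4, 5, 6, 7]]; global rank 5 maps to group index 5 // 4 = 1,
# i.e. the group covering ranks 4-7.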


def apply_sequence_parallel(model_args):
if model_args.sequence_parallel_size == 1:
return None # no sequence parallelism

# init sequence-parallel groups here
group_this = init_sp_group(model_args.sequence_parallel_size)

try:
# old_flash_attention_forward = transformers.modeling_flash_attention_utils._flash_attention_forward
if model_args.sequence_parallel_mode == "zigzag-ring":
new_flash_attention_forward = partial(new_flash_attn_forward, group=group_this)
# assert check_params(old_flash_attention_forward, new_flash_attention_forward)
else:
raise NotImplementedError("Other sequence parallel modes are to be implemented.")

# monkey patching
transformers.modeling_flash_attention_utils._flash_attention_forward = new_flash_attention_forward
except Exception:
raise ValueError(
f"The current transformer version {transformers.__version__} is not supported. "
"please pip install transformers within the versions that llama-factory requires. "
"If the code failed with the latest version, "
"please file an issue to https://github.com/Qihoo360/360-llama-factory"
)

return group_this
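
For context, a condensed sketch of how loader.py above wires this in. The call order mirrors the diff; it is only meaningful inside an initialized torch.distributed job, and the model id is a hypothetical placeholder:

from transformers import AutoModelForCausalLM

sequence_parallel_group = apply_sequence_parallel(model_args)  # monkey-patches flash attention; returns None if size == 1
# flash_attention_2 is required so that the patched _flash_attention_forward is actually used;
# the model id below is for illustration only
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.1-8B", attn_implementation="flash_attention_2")
model.sequence_parallel_group = sequence_parallel_group  # downstream code can read the group from the model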