Jit with peft (#11586)
* move jitransform at the end

Signed-off-by: Alexandros Koumparoulis <[email protected]>

* add docstring & post-init

Signed-off-by: Alexandros Koumparoulis <[email protected]>

* Add remove_extra_batch_keys and remove align_labels

Signed-off-by: Alexandros Koumparoulis <[email protected]>

* Run JitTransform on_train_epoch_start

Signed-off-by: Alexandros Koumparoulis <[email protected]>

* add --use-torch-jit option

Signed-off-by: Alexandros Koumparoulis <[email protected]>

* add docstrings

Signed-off-by: Alexandros Koumparoulis <[email protected]>

* fix

Signed-off-by: Alexandros Koumparoulis <[email protected]>

* pep8

Signed-off-by: Alexandros Koumparoulis <[email protected]>

* Apply isort and black reformatting

Signed-off-by: akoumpa <[email protected]>

---------

Signed-off-by: Alexandros Koumparoulis <[email protected]>
Signed-off-by: akoumpa <[email protected]>
Co-authored-by: akoumpa <[email protected]>
akoumpa and akoumpa authored Dec 17, 2024
1 parent 993e575 commit de0b2e2
Showing 5 changed files with 155 additions and 29 deletions.
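For orientation, the diffs below wire torch.compile into training through a new callback. A minimal sketch of the resulting usage pattern (the import path matches the examples below; the Trainer arguments and the model/datamodule are illustrative placeholders):

# Minimal sketch (illustrative): enable torch.compile through the new JitTransform callback.
import lightning.pytorch as pl
from nemo.lightning.pytorch.callbacks import JitConfig, JitTransform

# Compile with torch.compile (dynamic shapes on); thunder stays disabled.
jit_config = JitConfig(use_torch=True, torch_kwargs={'dynamic': True}, use_thunder=False)

trainer = pl.Trainer(
    max_steps=100,
    callbacks=[JitTransform(jit_config)],  # applied on top of PEFT adapters, see api.py below
)
# trainer.fit(model, datamodule)  # model/datamodule are placeholders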
17 changes: 15 additions & 2 deletions examples/llm/peft/hf.py
@@ -16,6 +16,7 @@
from lightning.pytorch.loggers import WandbLogger
from nemo import lightning as nl
from nemo.collections import llm
from nemo.lightning.pytorch.callbacks import JitConfig, JitTransform


def make_squad_hf_dataset(tokenizer):
@@ -53,7 +54,7 @@ def formatting_prompts_func(examples):
return datamodule


if __name__ == '__main__':
def main():
import argparse

parser = argparse.ArgumentParser()
@@ -63,6 +64,7 @@ def formatting_prompts_func(examples):
parser.add_argument('--accelerator', default='gpu', choices=['gpu'])
parser.add_argument('--max-steps', type=int, default=100)
parser.add_argument('--wandb-project', type=str, default=None)
parser.add_argument('--use-torch-jit', action='store_true')
args = parser.parse_args()

wandb = None
@@ -74,11 +76,17 @@ def formatting_prompts_func(examples):
)
grad_clip = 0.5
if args.strategy == 'fsdp':
# See: https://github.com/Lightning-AI/pytorch-lightning/blob/8ad3e29816a63d8ce5c00ac104b14729a4176f4f/src/lightning/pytorch/plugins/precision/fsdp.py#L81
# See:
# https://github.com/Lightning-AI/pytorch-lightning/blob/8ad3e29816a63d8ce5c00ac104b14729a4176f4f/src/lightning/pytorch/plugins/precision/fsdp.py#L81
grad_clip = None
use_dist_samp = False
tokenizer = llm.HFAutoModelForCausalLM.configure_tokenizer(args.model)

callbacks = []
if args.use_torch_jit:
jit_config = JitConfig(use_torch=True, torch_kwargs={'dynamic': True}, use_thunder=False)
callbacks = [JitTransform(jit_config)]

llm.api.finetune(
model=llm.HFAutoModelForCausalLM(args.model),
data=make_squad_hf_dataset(tokenizer.tokenizer),
@@ -94,6 +102,7 @@ def formatting_prompts_func(examples):
gradient_clip_val=grad_clip,
use_distributed_sampler=use_dist_samp,
logger=wandb,
callbacks=callbacks,
),
optim=fdl.build(llm.adam.pytorch_adam_with_flat_lr(lr=1e-5)),
log=None,
@@ -102,3 +111,7 @@ def formatting_prompts_func(examples):
dim=32,
),
)


if __name__ == '__main__':
main()
35 changes: 30 additions & 5 deletions examples/llm/sft/hf.py
@@ -20,10 +20,12 @@
from nemo import lightning as nl
from nemo.collections import llm
from nemo.lightning.pytorch.accelerate.transformer_engine import is_te_accelerated
from nemo.lightning.pytorch.callbacks import ModelCallback
from nemo.lightning.pytorch.callbacks import JitConfig, JitTransform


class SquadDataModuleWithPthDataloader(llm.SquadDataModule):
"""Creates a squad dataset with a PT dataloader"""

def _create_dataloader(self, dataset, mode, **kwargs) -> DataLoader:
return DataLoader(
dataset,
@@ -37,17 +39,30 @@ def _create_dataloader(self, dataset, mode, **kwargs) -> DataLoader:


def squad(tokenizer) -> pl.LightningDataModule:
"""Instantiates a SquadDataModuleWithPthDataloader and return it
Args:
tokenizer (AutoTokenizer): the tokenizer to use
Returns:
pl.LightningDataModule: the dataset to train with.
"""
return SquadDataModuleWithPthDataloader(
tokenizer=tokenizer,
seq_length=2048,
seq_length=512,
micro_batch_size=2,
global_batch_size=128, # assert gbs == mbs * accumulate_grad_batches
num_workers=0,
dataset_kwargs={"sanity_check_dist_workers": False},
dataset_kwargs={
"sanity_check_dist_workers": False,
"pad_to_max_length": True,
"get_attention_mask_from_fusion": True,
},
)


if __name__ == '__main__':
def main():
"""Example script to run SFT with a HF transformers-instantiated model on squad."""
import argparse

parser = argparse.ArgumentParser()
@@ -60,6 +75,7 @@ def squad(tokenizer) -> pl.LightningDataModule:
parser.add_argument("--fp8-autocast", default=False, action='store_true')
parser.add_argument('--wandb-project', type=str, default=None)
parser.add_argument('--model-save-path', type=str, default=None)
parser.add_argument('--use-torch-jit', action='store_true')
args = parser.parse_args()

wandb = None
@@ -87,6 +103,11 @@ def squad(tokenizer) -> pl.LightningDataModule:
model = llm.HFAutoModelForCausalLM(model_name=args.model, model_accelerator=model_accelerator)
tokenizer = model.tokenizer

callbacks = []
if args.use_torch_jit:
jit_config = JitConfig(use_torch=True, torch_kwargs={'dynamic': False}, use_thunder=False)
callbacks = [JitTransform(jit_config)]

llm.api.finetune(
model=model,
data=squad(tokenizer),
@@ -101,8 +122,8 @@ def squad(tokenizer) -> pl.LightningDataModule:
accumulate_grad_batches=10,
gradient_clip_val=grad_clip,
use_distributed_sampler=use_dist_samp,
callbacks=[],
logger=wandb,
callbacks=callbacks,
),
optim=fdl.build(llm.adam.pytorch_adam_with_flat_lr(lr=1e-5)),
log=None,
@@ -116,3 +137,7 @@ def squad(tokenizer) -> pl.LightningDataModule:

if args.model_save_path is not None:
model.save_pretrained(args.model_save_path)


if __name__ == '__main__':
main()
11 changes: 9 additions & 2 deletions nemo/collections/llm/api.py
@@ -37,7 +37,7 @@
io,
)
from nemo.lightning.base import NEMO_MODELS_CACHE
from nemo.lightning.pytorch.callbacks import PEFT, ModelTransform
from nemo.lightning.pytorch.callbacks import PEFT, JitTransform, ModelTransform
from nemo.utils import logging
from nemo.utils.get_rank import is_global_rank_zero

@@ -875,7 +875,14 @@ def _setup(
trainer.callbacks.append(model_transform)
else:
trainer.callbacks.append(ModelTransform())

# Move the jit callback to the end to ensure it's applied on top of any model transformations (PEFT)
jit_cb = None
for i, cb in enumerate(trainer.callbacks):
if isinstance(cb, JitTransform):
assert jit_cb is None
jit_cb = trainer.callbacks.pop(i)
if jit_cb is not None:
trainer.callbacks.append(jit_cb)
return app_state


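The effect of the reordering above, shown in isolation; a minimal sketch with stand-in classes (the two class stubs below are placeholders, not the NeMo implementations):

# Minimal sketch: move the (single) JitTransform to the end of the callback list,
# so compilation runs on top of the adapters PEFT injects.
class PEFT: ...
class JitTransform: ...

callbacks = [JitTransform(), PEFT()]

jit_cb = None
for i, cb in enumerate(callbacks):
    if isinstance(cb, JitTransform):
        assert jit_cb is None  # expect at most one JitTransform
        jit_cb = callbacks.pop(i)
if jit_cb is not None:
    callbacks.append(jit_cb)

print([type(cb).__name__ for cb in callbacks])  # ['PEFT', 'JitTransform']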
53 changes: 35 additions & 18 deletions nemo/collections/llm/gpt/model/hf_auto_model_for_causal_lm.py
@@ -26,24 +26,11 @@
def masked_cross_entropy(logits, targets, mask=None):
if mask is not None:
loss = F.cross_entropy(logits, targets, reduction='none')
return torch.mean(loss[mask == 1])
return torch.mean(loss * mask.view(-1))
else:
return F.cross_entropy(logits, targets)
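A standalone sketch of the updated masked loss with toy tensors (shapes and values below are made up):

# Minimal sketch: masked (padded) positions contribute zero to the loss term.
import torch
import torch.nn.functional as F

logits = torch.randn(4, 10)            # 4 flattened token positions, 10 classes
targets = torch.randint(0, 10, (4,))   # toy labels
mask = torch.tensor([1., 1., 0., 0.])  # last two positions are padding

loss = F.cross_entropy(logits, targets, reduction='none')
# Note: the mean is taken over all positions (padded ones contribute 0),
# whereas the removed variant averaged only over the unmasked positions.
masked_loss = torch.mean(loss * mask.view(-1))
print(masked_loss)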


def align_labels(logits, labels):
logits = logits.float()
n_cls = logits.shape[-1]
if logits.shape[-2] == labels.shape[-1]:
logits = logits[..., :-1, :].contiguous()
labels = labels[..., 1:].contiguous()
elif logits.shape[-2] == labels.shape[-1] + 1:
logits = logits[..., :-1, :].contiguous()
else:
raise ValueError("Mismatched labels and logits shapes (" + str(labels.shape) + " " + str(logits.shape))
return logits.view(-1, n_cls), labels.view(-1)


class HFAutoModelForCausalLM(pl.LightningModule, io.IOMixin, fn.FNMixin):
def __init__(
self,
@@ -111,27 +98,42 @@ def training_step(self, batch):
labels = batch.pop('labels').to(self.model.device)
loss_mask = batch.pop('loss_mask', None)

# GPTSFTDataset emits `tokens` instead of `input_ids`
if not 'input_ids' in batch and 'tokens' in batch:
batch['input_ids'] = batch['tokens']
batch = self._remove_extra_batch_keys(batch)

outputs = self.forward(batch)

# Prepare for loss calculation
logits, labels = align_labels(outputs.logits.float(), labels)
logits = outputs.logits.float()
n_cls = logits.shape[-1]
logits, labels = logits.view(-1, n_cls), labels.view(-1)
assert logits.shape[-2] == labels.shape[-1]

loss = self.loss_fn(logits, labels, loss_mask)
self.log('train_log', loss, on_step=True, on_epoch=True, prog_bar=True)
self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True)
return loss

@torch.no_grad
def validation_step(self, batch, batch_idx):
labels = batch.pop('labels').to(self.model.device)
loss_mask = batch.pop('loss_mask', None)

# GPTSFTDataset emits `tokens` instead of `input_ids`
if not 'input_ids' in batch and 'tokens' in batch:
batch['input_ids'] = batch['tokens']
batch = self._remove_extra_batch_keys(batch)

outputs = self.forward(**batch)

logits, labels = align_labels(outputs.logits.float(), labels)
# Prepare for loss calculation
logits = outputs.logits.float()
n_cls = logits.shape[-1]
logits, labels = logits.view(-1, n_cls), labels.view(-1)
assert logits.shape[-2] == labels.shape[-1]
loss = self.loss_fn(logits, labels, loss_mask)

loss = self.loss_fn(logits, labels, loss_mask)
self.log('val_loss', loss, on_step=True, on_epoch=True, prog_bar=True)

def save_pretrained(self, path):
@@ -141,3 +143,18 @@ def save_pretrained(self, path):
self._tokenizer.save_pretrained(path)
else:
logging.warning("A tokenizer wasn't created before to save.")

def _remove_extra_batch_keys(self, batch, reserved_keys=['labels', 'loss_mask']):
"""Remove extra keys from batch that are not kwargs in model's forward
Args:
batch (dict): dictionary of tensors.
Returns:
dict: dictionary of tensors; keys that are not in model's forward are removed.
"""
import inspect

fwd_signature = inspect.signature(self.model.forward)
allowed_keys = list(fwd_signature.parameters.keys()) + reserved_keys
return {k: batch[k] for k in allowed_keys if k in batch}
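A standalone sketch of the signature-based filtering used in _remove_extra_batch_keys (the toy model and batch below are made up):

# Minimal sketch: keep only keys accepted by the model's forward, plus reserved keys.
import inspect
import torch.nn as nn

class ToyModel(nn.Module):
    def forward(self, input_ids, attention_mask=None):
        return input_ids

batch = {'input_ids': 1, 'attention_mask': 2, 'loss_mask': 3, 'token_count': 4}
reserved_keys = ['labels', 'loss_mask']

fwd_signature = inspect.signature(ToyModel().forward)
allowed_keys = list(fwd_signature.parameters.keys()) + reserved_keys
print({k: batch[k] for k in allowed_keys if k in batch})
# {'input_ids': 1, 'attention_mask': 2, 'loss_mask': 3} -- 'token_count' is dropped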
68 changes: 66 additions & 2 deletions nemo/lightning/pytorch/callbacks/jit_transform.py
@@ -22,6 +22,17 @@


def extract_module_attr_name(pl_module: "pl.LightningModule") -> str:
"""Extracts the held nn.Module from a pl.LightningModule, will try "module", "model", or fail.
Args:
pl_module (pl.LightningModule): the LightningModule used in training.
Raises:
ValueError: if the pl_module has neither a .model nor a .module attribute.
Returns:
str: the attr-name of the nn.Module
"""
if hasattr(pl_module, 'module'):
return 'module'
elif hasattr(pl_module, 'model'):
@@ -31,12 +42,34 @@ def extract_module_attr_name(pl_module: "pl.LightningModule") -> str:


def listify(x):
"""Wraps input in a list, if not already a list.
Args:
x (Anything): the input, can be anything.
Returns:
list: x itself if it is already a list, otherwise [x].
"""
if not isinstance(x, list):
return [x]
return x


def get_modules_from_selector(model, module_selector):
"""Iterator over model's modules whose FQN match the module_selector.
Args:
model (nn.Module): the model to iterate over.
module_selector (str): module selector; if empty or '*', the whole model is yielded. An
asterisk in a selector segment is treated as a wildcard and matched as a regexp.
Raises:
AttributeError: if the user provides an invalid selector.
AttributeError: if user's selector selects a non-nn.Module attribute.
Yields:
Iterator(nn.Module): iterator over modules whose FQN matches module_selector
"""
if module_selector is None or module_selector == '' or module_selector == '*':
yield model
return
@@ -50,7 +83,7 @@ def get_modules_from_selector(model, module_selector):
# handle wildcard selector
# TODO(@akoumparouli): support more complex selectors e.g. net_b.*.net_c.*.conv
for name, module in tmp.named_children():
if re.match(item, name):
if re.match(item.replace('*', '.*'), name):
yield module
return
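The wildcard fix above expands '*' to the regexp '.*' before matching child-module names; a small standalone sketch (the module names below are made up):

# Minimal sketch: a 'layer_*' selector segment matches all children named layer_<something>.
import re
import torch.nn as nn

model = nn.ModuleDict({'layer_0': nn.Linear(4, 4), 'layer_1': nn.Linear(4, 4), 'head': nn.Linear(4, 2)})

item = 'layer_*'
pattern = item.replace('*', '.*')
matched = [name for name, _ in model.named_children() if re.match(pattern, name)]
print(matched)  # ['layer_0', 'layer_1']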

@@ -65,6 +98,15 @@ def get_modules_from_selector(model, module_selector):


def compile_module(config, module):
"""Jit-compiles an nn.Module
Args:
config (JitConfig): jit config
module (nn.Module): the module to be compiled
Returns:
nn.Module: the (potentially) compiled module
"""
if config.use_torch:
module.compile(**config.torch_kwargs)
return True
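For reference, nn.Module.compile compiles the module in place with torch.compile using the given kwargs; a minimal sketch (assumes a recent PyTorch 2.x build where torch.compile and nn.Module.compile are available):

# Minimal sketch: compile a module in place with dynamic shapes enabled.
import torch

m = torch.nn.Linear(8, 8)
m.compile(dynamic=True)      # same kwargs that JitConfig.torch_kwargs forwards
out = m(torch.randn(4, 8))   # first call triggers compilation
print(out.shape)             # torch.Size([4, 8])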
@@ -88,12 +130,26 @@ def compile_module(config, module):

@dataclass
class JitConfig:
"""Config POD for Jit transforms (e.g. torch.compile or thunder)
Options:
- module_selector (str): regexp selecting the modules JitTransform is applied to; useful for multi-trunk
models where you want to apply it to only one of them. If empty, the transform is applied to the
root module.
- use_torch (bool): whether to use torch.compile or not.
- torch_kwargs (dict): kwargs to pass to torch.compile.
- use_thunder (bool): whether to use thunder or not.
- profile_thunder (bool): toggle for thunder's profiler.
"""

module_selector: str = ''
use_torch: bool = False
torch_kwargs: dict = field(default_factory=dict)
use_thunder: bool = False
profile_thunder: bool = False

def __post_init__(self):
assert not (self.use_torch and self.use_thunder), "use_torch cannot be used at the same time with use_thunder"
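A small sketch of the validation added in __post_init__ (the import path matches the one used in the example scripts above):

# Minimal sketch: torch.compile and thunder cannot be enabled at the same time.
from nemo.lightning.pytorch.callbacks import JitConfig

cfg = JitConfig(use_torch=True, torch_kwargs={'dynamic': False})  # fine
try:
    JitConfig(use_torch=True, use_thunder=True)  # rejected by __post_init__
except AssertionError as err:
    print('rejected:', err)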


class JitTransform(Callback, IOMixin):
"""
@@ -112,7 +168,15 @@ def __init__(self, config: JitConfig):
self.config = config
assert not (self.config.use_torch and self.config.use_thunder)

def on_train_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
"""Jit-compiles the model at the start of the epoch.
While other events such as on_train_start might seem more suitable, we use on_train_epoch_start
since that is the hook PEFT uses (we want to jit after the adapters have been added).
Args:
trainer (pl.Trainer): PTL trainer
pl_module (pl.LightningModule): PTL module
"""
if self.config is None:
return
if not self.config.use_thunder and not self.config.use_torch: