Skip to content

Commit

Permalink
tune specific params in the base model (#7745)
Browse files Browse the repository at this point in the history
* tune specific params in the base model

Signed-off-by: arendu <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* renamed

Signed-off-by: arendu <[email protected]>

* check for attr before calling

Signed-off-by: arendu <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Signed-off-by: arendu <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
2 people authored and jubick1337 committed Jan 17, 2024
1 parent 1fede57 commit 166adee
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,9 @@ model:

ia3_tuning:
layer_selection: null # selects in which layers to add ia3 adapters. e.g. [1,12] will add lora to layer 1 (lowest) and 12. null will apply adapters to all layers

selective_tuning:
tunable_base_param_names: ["self_attention", "word_embeddings"] # TODO: regex support @adithyre

data:
train_ds:
Expand Down
27 changes: 23 additions & 4 deletions nemo/collections/nlp/parts/mixins/nlp_adapter_mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ class NLPAdapterModelMixin:

def __init__(self, *args, **kwargs):
self.use_peft = False
self.tunable_base_param_names = []
self.setup_complete = False
self.use_ptuning_only = False
super().__init__(*args, **kwargs)
Expand Down Expand Up @@ -192,10 +193,14 @@ def add_adapter(self, peft_cfgs: Union[PEFTConfig, List[PEFTConfig]]):

logging.info(f"After adding PEFT params:\n{self.summarize()}")
self.adapter_keys = self._get_all_keys() - self.base_keys
self.tunable_base_param_keys = set()

for cfg in peft_cfgs:
if cfg.weight_tying:
if hasattr(cfg, "weight_tying") and cfg.weight_tying:
self.tie_weights(cfg)

if hasattr(cfg, "tunable_base_param_names") and cfg.tunable_base_param_names:
self.set_tunable_base_params(cfg)
self.use_peft = True

def _get_config_and_state_dict_from_nemo(self, filepath, map_location):
Expand Down Expand Up @@ -239,6 +244,12 @@ def setup_optimizer_param_groups(self):
module.set_enabled_adapters(enabled=True)
module.unfreeze_enabled_adapters() # selectively unfreeze the adapter modules.
opt_params += [p for p in module.parameters() if p.requires_grad]

for name, param in self.named_parameters():
if name in self.tunable_base_param_keys:
param.requires_grad = True
opt_params += [param]

self._optimizer_param_groups = ({"params": opt_params},)
logging.info(f"Optimizer groups set:\n{self.summarize()}")
else:
Expand Down Expand Up @@ -282,9 +293,17 @@ def load_adapters(
), "Inferring peft scheme is only supported for .nemo checkpoints. Please supply the `peft_cfgs` argument."
peft_cfgs = [PEFT_CONFIG_MAP[conf.peft.peft_scheme](conf)]
self.add_adapter(peft_cfgs)
assert set(state_dict.keys()) == self.adapter_keys
assert set(state_dict.keys()) == self.adapter_keys.union(self.tunable_base_param_keys)
super().load_state_dict(state_dict, strict=False)

def set_tunable_base_params(self, peft_cfg):
for n, p in self.named_parameters():
for tpn in peft_cfg.tunable_base_param_names:
# TODO: simplistic param name matching, should support regex-like syntax @adithyare
if f".{tpn}." in n:
self.tunable_base_param_keys.add(n)
p.requires_grad = True # We set these to true to trigger setup_optimizer_param_groups

def tie_weights(self, peft_cfg):
pos_idx = 0

Expand Down Expand Up @@ -328,7 +347,7 @@ def get_peft_state_dict(self):
"""
state_dict = self.model.state_dict(prefix=self.model_prefix)
peft_state_dict = {}
for k in self.adapter_keys:
for k in self.adapter_keys.union(self.tunable_base_param_keys):
# state_dict keys needs to be in non-O2 format and will be corrected in PEFTSaveRestoreConnector if O2=True
new_k = k.replace("model.module.", "model.", 1)
peft_state_dict[new_k] = state_dict[k]
Expand Down Expand Up @@ -360,7 +379,7 @@ def load_state_dict(self, state_dict, strict: bool = True):
# setting strict=False will ignore the missing keys (which are not being updated anyway)
# explicitly check if state_dict.keys matches all the expected self.adapter_keys since we don't have the
# safety in strict=True anymore.
assert set(state_dict.keys()) == self.adapter_keys
assert set(state_dict.keys()) == self.adapter_keys.union(self.tunable_base_param_keys)
super().load_state_dict(state_dict, strict=False)
else:
super().load_state_dict(state_dict, strict=True)
Expand Down
12 changes: 11 additions & 1 deletion nemo/collections/nlp/parts/peft_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,12 +44,21 @@ def __init__(self, peft_cfg: DictConfig, name_key_to_cfg: Dict):
self.name_key_to_cfg = name_key_to_cfg

self.layer_selection = peft_cfg.get("layer_selection", None)
self.weight_tying = peft_cfg.get("weight_tying", False)
self.weight_tying = peft_cfg.get(
"weight_tying", False
) # TODO: move this attr to LoraPEFTConfig and AdapterPEFTConfig classes

def get_config_dict(self):
return self.name_key_to_cfg


class SelectivePEFTConfig(PEFTConfig):
def __init__(self, cfg):
selective_cfg = cfg.peft.selective_tuning
super().__init__(selective_cfg, name_key_to_cfg={})
self.tunable_base_param_names = selective_cfg.get("tunable_base_param_names", [])


class LoraPEFTConfig(PEFTConfig):
def __init__(self, cfg):
lora_cfg = cfg.peft.lora_tuning
Expand Down Expand Up @@ -195,6 +204,7 @@ def __init__(self, cfg):
"ia3": IA3PEFTConfig,
"ptuning": PtuningPEFTConfig,
"lora": LoraPEFTConfig,
"selective": SelectivePEFTConfig,
'none': None,
None: None,
}

0 comments on commit 166adee

Please sign in to comment.