From 31c884e93469dd1391bb54eb1468311c38bbccac Mon Sep 17 00:00:00 2001 From: Yuanhao Zhai Date: Wed, 10 Apr 2024 05:39:02 -0400 Subject: [PATCH] FEAT Allow load_adapter to use different device (#1631) --- src/peft/peft_model.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/src/peft/peft_model.py b/src/peft/peft_model.py index d44a2be5af..edf0b92f8e 100644 --- a/src/peft/peft_model.py +++ b/src/peft/peft_model.py @@ -799,7 +799,14 @@ def _update_offload(self, offload_index: dict[str, dict[str, str]], adapters_wei os.makedirs(base_name) safe_save_file(safe_dict, new_fname, metadata=metadata) - def load_adapter(self, model_id: str, adapter_name: str, is_trainable: bool = False, **kwargs: Any): + def load_adapter( + self, + model_id: str, + adapter_name: str, + is_trainable: bool = False, + torch_device: Optional[str] = None, + **kwargs: Any, + ): """ Load a trained adapter into the model. @@ -816,13 +823,16 @@ def load_adapter(self, model_id: str, adapter_name: str, is_trainable: bool = Fa is_trainable (`bool`, *optional*, defaults to `False`): Whether the adapter should be trainable or not. If `False`, the adapter will be frozen and can only be used for inference. + torch_device (`str`, *optional*, defaults to None): + The device to load the adapter on. If `None`, the device will be inferred. kwargs: (`optional`): Additional arguments to modify the way the adapter is loaded, e.g. the token for Hugging Face Hub. """ from .mapping import PEFT_TYPE_TO_CONFIG_MAPPING hf_hub_download_kwargs, kwargs = self._split_kwargs(kwargs) - torch_device = infer_device() + if torch_device is None: + torch_device = infer_device() if adapter_name not in self.peft_config: # load the config