Skip to content

Commit

Permalink
Propagate device_map to hf model
Browse files Browse the repository at this point in the history
  • Loading branch information
stevelaskaridis committed Feb 20, 2024
1 parent 46cc4bc commit 9568e83
Showing 1 changed file with 10 additions and 4 deletions.
14 changes: 10 additions & 4 deletions dsp/modules/hf.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def __init__(self, model: str, checkpoint: Optional[str] = None, is_client: bool
model (str): HF model identifier to load and use
checkpoint (str, optional): load specific checkpoints of the model. Defaults to None.
is_client (bool, optional): whether to access models via client. Defaults to False.
hf_device_map (str, optional): HF config strategy to load the model.
hf_device_map (str, optional): HF config strategy to load the model.
Recommended to use "auto", which helps load large models using accelerate. Defaults to "auto".
"""

Expand Down Expand Up @@ -72,14 +72,20 @@ def __init__(self, model: str, checkpoint: Optional[str] = None, is_client: bool
# self.model = AutoModelClass.from_pretrained(peft_config.base_model_name_or_path, return_dict=True, load_in_8bit=True, device_map=hf_device_map)
# self.model = PeftModel.from_pretrained(self.model, checkpoint)
# else:
self.model = AutoModelClass.from_pretrained(checkpoint).to(self.device)
if self.device_map:
self.model = AutoModelClass.from_pretrained(checkpoint, device_map=self.device_map)
else:
self.model = AutoModelClass.from_pretrained(checkpoint).to(self.device)
else:
self.model = AutoModelClass.from_pretrained(model).to(self.device)
if self.device_map:
self.model = AutoModelClass.from_pretrained(model, device_map=self.device_map)
else:
self.model = AutoModelClass.from_pretrained(model).to(self.device)
self.drop_prompt_from_output = False
except ValueError:
self.model = AutoModelForCausalLM.from_pretrained(
model if checkpoint is None else checkpoint,
device_map=hf_device_map
device_map=self.device_map
)
self.drop_prompt_from_output = True
self.tokenizer = AutoTokenizer.from_pretrained(model)
Expand Down

0 comments on commit 9568e83

Please sign in to comment.