diff --git a/gptqmodel/models/loader.py b/gptqmodel/models/loader.py index 39d2459a9..0a4a3f0b8 100644 --- a/gptqmodel/models/loader.py +++ b/gptqmodel/models/loader.py @@ -154,6 +154,9 @@ def from_quantized( verify_hash: Optional[Union[str, List[str]]] = None, **kwargs, ): + if device is not None: + device = normalize_device(device) + # TODO need to normalize backend and others in a unified api device = parse_device_map(device, device_map)