Skip to content

Commit

Permalink
Add docs to the requantize(...) function explaining why it was copied from optimum-quanto.
Browse files Browse the repository at this point in the history
  • Loading branch information
RyanJDick committed Aug 21, 2024
1 parent d11dc6d commit 38c2e78
Showing 1 changed file with 10 additions and 7 deletions.
17 changes: 10 additions & 7 deletions invokeai/backend/requantize.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,19 +3,21 @@
import torch
from optimum.quanto.quantize import _quantize_submodule

# def custom_freeze(model: torch.nn.Module):
# for name, m in model.named_modules():
# if isinstance(m, QModuleMixin):
# m.weight =
# m.freeze()


def requantize(
model: torch.nn.Module,
state_dict: Dict[str, Any],
quantization_map: Dict[str, Dict[str, str]],
device: torch.device = None,
device: torch.device | None = None,
):
"""This function was initially copied from:
https://github.com/huggingface/optimum-quanto/blob/832f7f5c3926c91fe4f923aaaf037a780ac3e6c3/optimum/quanto/quantize.py#L101
The function was modified to remove the `freeze()` call. The `freeze()` call is very slow and unnecessary when the
weights are about to be loaded from a state_dict.
TODO(ryand): Unless I'm overlooking something, this should be contributed upstream to the `optimum-quanto` library.
"""
if device is None:
device = next(model.parameters()).device
if device.type == "meta":
Expand Down Expand Up @@ -45,6 +47,7 @@ def move_tensor(t, device):
setattr(m, name, torch.nn.Parameter(move_tensor(param, "cpu")))
for name, param in m.named_buffers(recurse=False):
setattr(m, name, move_tensor(param, "cpu"))

# Freeze model and move to target device
# freeze(model)
# model.to(device)
Expand Down

0 comments on commit 38c2e78

Please sign in to comment.