diff --git a/invokeai/backend/model_manager/load/model_cache/cached_model/cached_model_only_full_load.py b/invokeai/backend/model_manager/load/model_cache/cached_model/cached_model_only_full_load.py new file mode 100644 index 00000000000..719a559dd02 --- /dev/null +++ b/invokeai/backend/model_manager/load/model_cache/cached_model/cached_model_only_full_load.py @@ -0,0 +1,93 @@ +from typing import Any + +import torch + + +class CachedModelOnlyFullLoad: + """A wrapper around a PyTorch model to handle full loads and unloads between the CPU and the compute device. + Note: "VRAM" is used throughout this class to refer to the memory on the compute device. It could be CUDA memory, + MPS memory, etc. + """ + + def __init__(self, model: torch.nn.Module | Any, compute_device: torch.device, total_bytes: int): + """Initialize a CachedModelOnlyFullLoad. + Args: + model (torch.nn.Module | Any): The model to wrap. Should be on the CPU. + compute_device (torch.device): The compute device to move the model to. + total_bytes (int): The total size (in bytes) of all the weights in the model. + """ + # model is often a torch.nn.Module, but could be any model type. Throughout this class, we handle both cases. + self._model = model + self._compute_device = compute_device + self._offload_device = torch.device("cpu") + + # A CPU read-only copy of the model's state dict. + self._cpu_state_dict: dict[str, torch.Tensor] | None = None + if isinstance(model, torch.nn.Module): + self._cpu_state_dict = model.state_dict() + + self._total_bytes = total_bytes + self._is_in_vram = False + + @property + def model(self) -> torch.nn.Module: + return self._model + + def get_cpu_state_dict(self) -> dict[str, torch.Tensor] | None: + """Get a read-only copy of the model's state dict in RAM.""" + # TODO(ryand): Document this better. + return self._cpu_state_dict + + def total_bytes(self) -> int: + """Get the total size (in bytes) of all the weights in the model.""" + return self._total_bytes + + def cur_vram_bytes(self) -> int: + """Get the size (in bytes) of the weights that are currently in VRAM.""" + if self._is_in_vram: + return self._total_bytes + else: + return 0 + + def is_in_vram(self) -> bool: + """Return true if the model is currently in VRAM.""" + return self._is_in_vram + + def full_load_to_vram(self) -> int: + """Load all weights into VRAM (if supported by the model). + Returns: + The number of bytes loaded into VRAM. + """ + if self._is_in_vram: + # Already in VRAM. + return 0 + + if not hasattr(self._model, "to"): + # Model doesn't support moving to a device. + return 0 + + if self._cpu_state_dict is not None: + new_state_dict: dict[str, torch.Tensor] = {} + for k, v in self._cpu_state_dict.items(): + new_state_dict[k] = v.to(self._compute_device, copy=True) + self._model.load_state_dict(new_state_dict, assign=True) + self._model.to(self._compute_device) + + self._is_in_vram = True + return self._total_bytes + + def full_unload_from_vram(self) -> int: + """Unload all weights from VRAM. + Returns: + The number of bytes unloaded from VRAM. + """ + if not self._is_in_vram: + # Already in RAM. + return 0 + + if self._cpu_state_dict is not None: + self._model.load_state_dict(self._cpu_state_dict, assign=True) + self._model.to(self._offload_device) + + self._is_in_vram = False + return self._total_bytes diff --git a/invokeai/backend/model_manager/load/model_cache/cached_model/cached_model_with_partial_load.py b/invokeai/backend/model_manager/load/model_cache/cached_model/cached_model_with_partial_load.py new file mode 100644 index 00000000000..ab1a62db461 --- /dev/null +++ b/invokeai/backend/model_manager/load/model_cache/cached_model/cached_model_with_partial_load.py @@ -0,0 +1,201 @@ +import torch + +from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.torch_module_autocast import ( + AUTOCAST_MODULE_TYPE_MAPPING, + apply_custom_layers_to_model, + remove_custom_layers_from_model, +) +from invokeai.backend.util.calc_tensor_size import calc_tensor_size +from invokeai.backend.util.logging import InvokeAILogger + + +def set_nested_attr(obj: object, attr: str, value: object): + """A helper function that extends setattr() to support nested attributes. + + Example: + set_nested_attr(model, "module.encoder.conv1.weight", new_conv1_weight) + """ + attrs = attr.split(".") + for attr in attrs[:-1]: + obj = getattr(obj, attr) + setattr(obj, attrs[-1], value) + + +class CachedModelWithPartialLoad: + """A wrapper around a PyTorch model to handle partial loads and unloads between the CPU and the compute device. + + Note: "VRAM" is used throughout this class to refer to the memory on the compute device. It could be CUDA memory, + MPS memory, etc. + """ + + def __init__(self, model: torch.nn.Module, compute_device: torch.device): + self._model = model + self._compute_device = compute_device + + # A CPU read-only copy of the model's state dict. + self._cpu_state_dict: dict[str, torch.Tensor] = model.state_dict() + + # TODO(ryand): Handle the case where the model sizes changes after initial load (e.g. due to dtype casting). + # Consider how we should handle this for both self._total_bytes and self._cur_vram_bytes. + self._total_bytes = sum(calc_tensor_size(p) for p in self._cpu_state_dict.values()) + self._cur_vram_bytes: int | None = None + + self._modules_that_support_autocast = self._find_modules_that_support_autocast() + self._keys_in_modules_that_do_not_support_autocast = self._find_keys_in_modules_that_do_not_support_autocast() + + def _find_modules_that_support_autocast(self) -> dict[str, torch.nn.Module]: + """Find all modules that support autocasting.""" + return {n: m for n, m in self._model.named_modules() if type(m) in AUTOCAST_MODULE_TYPE_MAPPING} + + def _find_keys_in_modules_that_do_not_support_autocast(self) -> set[str]: + keys_in_modules_that_do_not_support_autocast = set() + for key in self._cpu_state_dict.keys(): + for module_name in self._modules_that_support_autocast.keys(): + if key.startswith(module_name): + break + else: + keys_in_modules_that_do_not_support_autocast.add(key) + return keys_in_modules_that_do_not_support_autocast + + def _move_non_persistent_buffers_to_device(self, device: torch.device): + """Move the non-persistent buffers to the target device. These buffers are not included in the state dict, + so we need to move them manually. + """ + # HACK(ryand): Typically, non-persistent buffers are moved when calling module.to(device). We don't move entire + # modules, because we manage the devices of individual tensors using the state dict. Since non-persistent + # buffers are not included in the state dict, we need to handle them manually. The only way to do this is by + # using private torch.nn.Module attributes. + for module in self._model.modules(): + for name, buffer in module.named_buffers(): + if name in module._non_persistent_buffers_set: + module._buffers[name] = buffer.to(device, copy=True) + + @property + def model(self) -> torch.nn.Module: + return self._model + + def get_cpu_state_dict(self) -> dict[str, torch.Tensor] | None: + """Get a read-only copy of the model's state dict in RAM.""" + # TODO(ryand): Document this better. + return self._cpu_state_dict + + def total_bytes(self) -> int: + """Get the total size (in bytes) of all the weights in the model.""" + return self._total_bytes + + def cur_vram_bytes(self) -> int: + """Get the size (in bytes) of the weights that are currently in VRAM.""" + if self._cur_vram_bytes is None: + cur_state_dict = self._model.state_dict() + self._cur_vram_bytes = sum( + calc_tensor_size(p) for p in cur_state_dict.values() if p.device.type == self._compute_device.type + ) + return self._cur_vram_bytes + + def full_load_to_vram(self) -> int: + """Load all weights into VRAM.""" + return self.partial_load_to_vram(self.total_bytes()) + + def full_unload_from_vram(self) -> int: + """Unload all weights from VRAM.""" + return self.partial_unload_from_vram(self.total_bytes()) + + @torch.no_grad() + def partial_load_to_vram(self, vram_bytes_to_load: int) -> int: + """Load more weights into VRAM without exceeding vram_bytes_to_load. + + Returns: + The number of bytes loaded into VRAM. + """ + # TODO(ryand): Handle the case where an exception is thrown while loading or unloading weights. At the very + # least, we should reset self._cur_vram_bytes to None. + + vram_bytes_loaded = 0 + + cur_state_dict = self._model.state_dict() + + # First, process the keys *must* be loaded into VRAM. + for key in self._keys_in_modules_that_do_not_support_autocast: + param = cur_state_dict[key] + if param.device.type == self._compute_device.type: + continue + + param_size = calc_tensor_size(param) + cur_state_dict[key] = param.to(self._compute_device, copy=True) + vram_bytes_loaded += param_size + + if vram_bytes_loaded > vram_bytes_to_load: + logger = InvokeAILogger.get_logger() + logger.warning( + f"Loaded {vram_bytes_loaded / 2**20} MB into VRAM, but only {vram_bytes_to_load / 2**20} MB were " + "requested. This is the minimum set of weights in VRAM required to run the model." + ) + + # Next, process the keys that can optionally be loaded into VRAM. + fully_loaded = True + for key, param in cur_state_dict.items(): + if param.device.type == self._compute_device.type: + continue + + param_size = calc_tensor_size(param) + if vram_bytes_loaded + param_size > vram_bytes_to_load: + # TODO(ryand): Should we just break here? If we couldn't fit this parameter into VRAM, is it really + # worth continuing to search for a smaller parameter that would fit? + fully_loaded = False + continue + + cur_state_dict[key] = param.to(self._compute_device, copy=True) + vram_bytes_loaded += param_size + + if vram_bytes_loaded > 0: + # We load the entire state dict, not just the parameters that changed, in case there are modules that + # override _load_from_state_dict() and do some funky stuff that requires the entire state dict. + # Alternatively, in the future, grouping parameters by module could probably solve this problem. + self._model.load_state_dict(cur_state_dict, assign=True) + + if self._cur_vram_bytes is not None: + self._cur_vram_bytes += vram_bytes_loaded + + if fully_loaded: + remove_custom_layers_from_model(self._model) + # TODO(ryand): Warn if the self.cur_vram_bytes() and self.total_bytes() are out of sync. + else: + apply_custom_layers_to_model(self._model) + + # Move all non-persistent buffers to the compute device. These are a weird edge case and do not participate in + # the vram_bytes_loaded tracking. + self._move_non_persistent_buffers_to_device(self._compute_device) + + return vram_bytes_loaded + + @torch.no_grad() + def partial_unload_from_vram(self, vram_bytes_to_free: int) -> int: + """Unload weights from VRAM until vram_bytes_to_free bytes are freed. Or the entire model is unloaded. + + Returns: + The number of bytes unloaded from VRAM. + """ + vram_bytes_freed = 0 + + offload_device = "cpu" + cur_state_dict = self._model.state_dict() + for key, param in cur_state_dict.items(): + if vram_bytes_freed >= vram_bytes_to_free: + break + + if param.device.type == offload_device: + continue + + cur_state_dict[key] = self._cpu_state_dict[key] + vram_bytes_freed += calc_tensor_size(param) + + if vram_bytes_freed > 0: + self._model.load_state_dict(cur_state_dict, assign=True) + + if self._cur_vram_bytes is not None: + self._cur_vram_bytes -= vram_bytes_freed + + # We may have gone from a fully-loaded model to a partially-loaded model, so we need to reapply the custom + # layers. + apply_custom_layers_to_model(self._model) + return vram_bytes_freed diff --git a/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/__init__.py b/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/autocast_modules.py b/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/autocast_modules.py new file mode 100644 index 00000000000..8a1bacf6833 --- /dev/null +++ b/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/autocast_modules.py @@ -0,0 +1,50 @@ +import torch + +from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.cast_to_device import cast_to_device + +# This file contains custom torch.nn.Module classes that support streaming of weights to the target device. +# Each class sub-classes the original module type that is is replacing, so the following properties are preserved: +# - isinstance(m, torch.nn.OrginalModule) should still work. +# - Patching the weights (e.g. for LoRA) should still work if non-quantized. + + +class CustomLinear(torch.nn.Linear): + def forward(self, input: torch.Tensor) -> torch.Tensor: + weight = cast_to_device(self.weight, input.device) + bias = cast_to_device(self.bias, input.device) + return torch.nn.functional.linear(input, weight, bias) + + +class CustomConv1d(torch.nn.Conv1d): + def forward(self, input: torch.Tensor) -> torch.Tensor: + weight = cast_to_device(self.weight, input.device) + bias = cast_to_device(self.bias, input.device) + return self._conv_forward(input, weight, bias) + + +class CustomConv2d(torch.nn.Conv2d): + def forward(self, input: torch.Tensor) -> torch.Tensor: + weight = cast_to_device(self.weight, input.device) + bias = cast_to_device(self.bias, input.device) + return self._conv_forward(input, weight, bias) + + +class CustomGroupNorm(torch.nn.GroupNorm): + def forward(self, input: torch.Tensor) -> torch.Tensor: + weight = cast_to_device(self.weight, input.device) + bias = cast_to_device(self.bias, input.device) + return torch.nn.functional.group_norm(input, self.num_groups, weight, bias, self.eps) + + +class CustomEmbedding(torch.nn.Embedding): + def forward(self, input: torch.Tensor) -> torch.Tensor: + weight = cast_to_device(self.weight, input.device) + return torch.nn.functional.embedding( + input, + weight, + self.padding_idx, + self.max_norm, + self.norm_type, + self.scale_grad_by_freq, + self.sparse, + ) diff --git a/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/cast_to_device.py b/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/cast_to_device.py new file mode 100644 index 00000000000..7a50a19953b --- /dev/null +++ b/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/cast_to_device.py @@ -0,0 +1,15 @@ +from typing import TypeVar + +import torch + +T = TypeVar("T", torch.Tensor, None, torch.Tensor | None) + + +def cast_to_device(t: T, to_device: torch.device) -> T: + """Helper function to cast an optional tensor to a target device.""" + if t is None: + return t + + if t.device.type != to_device.type: + return t.to(to_device) + return t diff --git a/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_invoke_linear_8_bit_lt.py b/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_invoke_linear_8_bit_lt.py new file mode 100644 index 00000000000..3941a2af6be --- /dev/null +++ b/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_invoke_linear_8_bit_lt.py @@ -0,0 +1,27 @@ +import bitsandbytes as bnb +import torch + +from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.cast_to_device import cast_to_device +from invokeai.backend.quantization.bnb_llm_int8 import InvokeLinear8bitLt + + +class CustomInvokeLinear8bitLt(InvokeLinear8bitLt): + def forward(self, x: torch.Tensor) -> torch.Tensor: + matmul_state = bnb.MatmulLtState() + matmul_state.threshold = self.state.threshold + matmul_state.has_fp16_weights = self.state.has_fp16_weights + matmul_state.use_pool = self.state.use_pool + matmul_state.is_training = self.training + # The underlying InvokeInt8Params weight must already be quantized. + assert self.weight.CB is not None + matmul_state.CB = cast_to_device(self.weight.CB, x.device) + matmul_state.SCB = cast_to_device(self.weight.SCB, x.device) + + # weights are cast automatically as Int8Params, but the bias has to be cast manually. + if self.bias is not None and self.bias.dtype != x.dtype: + self.bias.data = self.bias.data.to(x.dtype) + + # NOTE(ryand): The second parameter should not be needed at all given our expected inference configuration, but + # it's dtype field must be accessible, even though it's not used. We pass in self.weight even though it could be + # on the wrong device. + return bnb.matmul(x, self.weight, bias=cast_to_device(self.bias, x.device), state=matmul_state) diff --git a/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_invoke_linear_nf4.py b/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_invoke_linear_nf4.py new file mode 100644 index 00000000000..c697b3c7b43 --- /dev/null +++ b/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_invoke_linear_nf4.py @@ -0,0 +1,45 @@ +import copy + +import bitsandbytes as bnb +import torch + +from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.cast_to_device import cast_to_device +from invokeai.backend.quantization.bnb_nf4 import InvokeLinearNF4 + + +class CustomInvokeLinearNF4(InvokeLinearNF4): + def forward(self, x: torch.Tensor) -> torch.Tensor: + bnb.nn.modules.fix_4bit_weight_quant_state_from_module(self) + + # weights are cast automatically as Int8Params, but the bias has to be cast manually + if self.bias is not None and self.bias.dtype != x.dtype: + self.bias.data = self.bias.data.to(x.dtype) + + if not self.compute_type_is_set: + self.set_compute_type(x) + self.compute_type_is_set = True + + inp_dtype = x.dtype + if self.compute_dtype is not None: + x = x.to(self.compute_dtype) + + bias = None if self.bias is None else self.bias.to(self.compute_dtype) + + # HACK(ryand): Casting self.weight to the device also casts the self.weight.quant_state in-place (i.e. it + # does not follow the tensor semantics of returning a new copy when converting to a different device). This + # means that quant_state elements that started on the CPU would be left on the GPU, which we don't want. To + # avoid this side effect we make a shallow copy of the original quant_state so that we can restore it. Fixing + # this properly would require more invasive changes to the bitsandbytes library. + + # Make a shallow copy of the quant_state so that we can undo the in-place modification that occurs when casting + # to a new device. + old_quant_state = copy.copy(self.weight.quant_state) + weight = cast_to_device(self.weight, x.device) + self.weight.quant_state = old_quant_state + + # For some reason, the quant_state.to(...) implementation fails to cast the quant_state.code field. We do this + # manually here. + weight.quant_state.code = cast_to_device(weight.quant_state.code, x.device) + + bias = cast_to_device(self.bias, x.device) + return bnb.matmul_4bit(x, weight.t(), bias=bias, quant_state=weight.quant_state).to(inp_dtype) diff --git a/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/torch_module_autocast.py b/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/torch_module_autocast.py new file mode 100644 index 00000000000..825eebf64e8 --- /dev/null +++ b/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/torch_module_autocast.py @@ -0,0 +1,56 @@ +import torch + +from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.autocast_modules import ( + CustomConv1d, + CustomConv2d, + CustomEmbedding, + CustomGroupNorm, + CustomLinear, +) + +AUTOCAST_MODULE_TYPE_MAPPING: dict[type[torch.nn.Module], type[torch.nn.Module]] = { + torch.nn.Linear: CustomLinear, + torch.nn.Conv1d: CustomConv1d, + torch.nn.Conv2d: CustomConv2d, + torch.nn.GroupNorm: CustomGroupNorm, + torch.nn.Embedding: CustomEmbedding, +} + +try: + # These dependencies are not expected to be present on MacOS. + from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_invoke_linear_8_bit_lt import ( + CustomInvokeLinear8bitLt, + ) + from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_invoke_linear_nf4 import ( + CustomInvokeLinearNF4, + ) + from invokeai.backend.quantization.bnb_llm_int8 import InvokeLinear8bitLt + from invokeai.backend.quantization.bnb_nf4 import InvokeLinearNF4 + + AUTOCAST_MODULE_TYPE_MAPPING[InvokeLinear8bitLt] = CustomInvokeLinear8bitLt + AUTOCAST_MODULE_TYPE_MAPPING[InvokeLinearNF4] = CustomInvokeLinearNF4 +except ImportError: + pass + + +def apply_custom_layers_to_model(model: torch.nn.Module): + def apply_custom_layers(module: torch.nn.Module): + override_type = AUTOCAST_MODULE_TYPE_MAPPING.get(type(module), None) + if override_type is not None: + module.__class__ = override_type + + # model.apply(...) calls apply_custom_layers(...) on each module in the model. + model.apply(apply_custom_layers) + + +def remove_custom_layers_from_model(model: torch.nn.Module): + # Invert AUTOCAST_MODULE_TYPE_MAPPING. + original_module_type_mapping = {v: k for k, v in AUTOCAST_MODULE_TYPE_MAPPING.items()} + + def remove_custom_layers(module: torch.nn.Module): + override_type = original_module_type_mapping.get(type(module), None) + if override_type is not None: + module.__class__ = override_type + + # model.apply(...) calls remove_custom_layers(...) on each module in the model. + model.apply(remove_custom_layers) diff --git a/invokeai/backend/quantization/bnb_llm_int8.py b/invokeai/backend/quantization/bnb_llm_int8.py index 02f94936e96..8722a19c373 100644 --- a/invokeai/backend/quantization/bnb_llm_int8.py +++ b/invokeai/backend/quantization/bnb_llm_int8.py @@ -25,12 +25,9 @@ def cuda(self, device): self.CB = self.data self.SCB = self.SCB.cuda() else: - # we store the 8-bit rows-major weight - # we convert this weight to the turning/ampere weight during the first inference pass + # We quantize the weight and store in 8bit row-major B = self.data.contiguous().half().cuda(device) - CB, CBt, SCB, SCBt, coo_tensorB = bnb.functional.double_quant(B) - del CBt - del SCBt + CB, SCB, _ = bnb.functional.int8_vectorwise_quant(B) self.data = CB self.CB = CB self.SCB = SCB @@ -55,9 +52,10 @@ def _load_from_state_dict( # See `bnb.nn.Linear8bitLt._save_to_state_dict()` for the serialization logic of SCB and weight_format. scb = state_dict.pop(prefix + "SCB", None) - # Currently, we only support weight_format=0. weight_format = state_dict.pop(prefix + "weight_format", None) - assert weight_format == 0 + if weight_format is not None: + # Currently, we only support weight_format=0. + assert weight_format == 0 # TODO(ryand): Technically, we should be using `strict`, `missing_keys`, `unexpected_keys`, and `error_msgs` # rather than raising an exception to correctly implement this API. @@ -99,6 +97,27 @@ def _load_from_state_dict( new_state.use_pool = self.state.use_pool self.state = new_state + def forward(self, x: torch.Tensor): + # The state management in the base bnb.nn.Linear8bitLt is very convoluted. We override the forward method to + # try to simplify the state management a bit. We initialize a new MatmulLtState object for each forward pass. + # By avoiding persistent state, it is easier to move the layer between devices without worrying about keeping + # references to weights on the old device (e.g. self.state.CB). + matmul_state = bnb.MatmulLtState() + matmul_state.threshold = self.state.threshold + matmul_state.has_fp16_weights = self.state.has_fp16_weights + matmul_state.use_pool = self.state.use_pool + matmul_state.is_training = self.training + # The underlying InvokeInt8Params weight must already be quantized. + assert self.weight.CB is not None + matmul_state.CB = self.weight.CB + matmul_state.SCB = self.weight.SCB + + # weights are cast automatically as Int8Params, but the bias has to be cast manually. + if self.bias is not None and self.bias.dtype != x.dtype: + self.bias.data = self.bias.data.to(x.dtype) + + return bnb.matmul(x, self.weight, bias=self.bias, state=matmul_state) + def _convert_linear_layers_to_llm_8bit( module: torch.nn.Module, ignore_modules: set[str], outlier_threshold: float, prefix: str = "" diff --git a/pyproject.toml b/pyproject.toml index c9cca90a030..1a989e93f17 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -34,7 +34,7 @@ classifiers = [ dependencies = [ # Core generation dependencies, pinned for reproducible builds. "accelerate==1.0.1", - "bitsandbytes==0.43.3; sys_platform!='darwin'", + "bitsandbytes==0.45.0; sys_platform!='darwin'", "clip_anytorch==2.6.0", # replacing "clip @ https://github.com/openai/CLIP/archive/eaa22acb90a5876642d0507623e859909230a52d.zip", "compel==2.0.2", "controlnet-aux==0.0.7", diff --git a/tests/backend/model_manager/load/model_cache/cached_model/test_cached_model_only_full_load.py b/tests/backend/model_manager/load/model_cache/cached_model/test_cached_model_only_full_load.py new file mode 100644 index 00000000000..76a3774288c --- /dev/null +++ b/tests/backend/model_manager/load/model_cache/cached_model/test_cached_model_only_full_load.py @@ -0,0 +1,122 @@ +import torch + +from invokeai.backend.model_manager.load.model_cache.cached_model.cached_model_only_full_load import ( + CachedModelOnlyFullLoad, +) +from tests.backend.model_manager.load.model_cache.cached_model.utils import DummyModule, parameterize_mps_and_cuda + + +class NonTorchModel: + """A model that does not sub-class torch.nn.Module.""" + + def __init__(self): + self.linear = torch.nn.Linear(10, 32) + + def run_inference(self, x: torch.Tensor) -> torch.Tensor: + return self.linear(x) + + +@parameterize_mps_and_cuda +def test_cached_model_total_bytes(device: str): + model = DummyModule() + cached_model = CachedModelOnlyFullLoad(model=model, compute_device=torch.device(device), total_bytes=100) + assert cached_model.total_bytes() == 100 + + +@parameterize_mps_and_cuda +def test_cached_model_is_in_vram(device: str): + model = DummyModule() + cached_model = CachedModelOnlyFullLoad(model=model, compute_device=torch.device(device), total_bytes=100) + assert not cached_model.is_in_vram() + assert cached_model.cur_vram_bytes() == 0 + + cached_model.full_load_to_vram() + assert cached_model.is_in_vram() + assert cached_model.cur_vram_bytes() == 100 + + cached_model.full_unload_from_vram() + assert not cached_model.is_in_vram() + assert cached_model.cur_vram_bytes() == 0 + + +@parameterize_mps_and_cuda +def test_cached_model_full_load_and_unload(device: str): + model = DummyModule() + cached_model = CachedModelOnlyFullLoad(model=model, compute_device=torch.device(device), total_bytes=100) + assert cached_model.full_load_to_vram() == 100 + assert cached_model.is_in_vram() + assert all(p.device.type == device for p in cached_model.model.parameters()) + + assert cached_model.full_unload_from_vram() == 100 + assert not cached_model.is_in_vram() + assert all(p.device.type == "cpu" for p in cached_model.model.parameters()) + + +@parameterize_mps_and_cuda +def test_cached_model_get_cpu_state_dict(device: str): + model = DummyModule() + cached_model = CachedModelOnlyFullLoad(model=model, compute_device=torch.device(device), total_bytes=100) + assert not cached_model.is_in_vram() + + # The CPU state dict can be accessed and has the expected properties. + cpu_state_dict = cached_model.get_cpu_state_dict() + assert cpu_state_dict is not None + assert len(cpu_state_dict) == len(model.state_dict()) + assert all(p.device.type == "cpu" for p in cpu_state_dict.values()) + + # Full load the model into VRAM. + cached_model.full_load_to_vram() + assert cached_model.is_in_vram() + + # The CPU state dict is still available, and still on the CPU. + cpu_state_dict = cached_model.get_cpu_state_dict() + assert cpu_state_dict is not None + assert len(cpu_state_dict) == len(model.state_dict()) + assert all(p.device.type == "cpu" for p in cpu_state_dict.values()) + + +@parameterize_mps_and_cuda +def test_cached_model_full_load_and_inference(device: str): + model = DummyModule() + cached_model = CachedModelOnlyFullLoad(model=model, compute_device=torch.device(device), total_bytes=100) + assert not cached_model.is_in_vram() + + # Run inference on the CPU. + x = torch.randn(1, 10) + output1 = model(x) + assert output1.device.type == "cpu" + + # Full load the model into VRAM. + cached_model.full_load_to_vram() + assert cached_model.is_in_vram() + + # Run inference on the GPU. + output2 = model(x.to(device)) + assert output2.device.type == device + + # The outputs should be the same for both runs. + assert torch.allclose(output1, output2.to("cpu")) + + +@parameterize_mps_and_cuda +def test_non_torch_model(device: str): + model = NonTorchModel() + cached_model = CachedModelOnlyFullLoad(model=model, compute_device=torch.device(device), total_bytes=100) + assert not cached_model.is_in_vram() + + # The model does not have a CPU state dict. + assert cached_model.get_cpu_state_dict() is None + + # Attempting to load the model into VRAM should have no effect. + cached_model.full_load_to_vram() + assert not cached_model.is_in_vram() + assert cached_model.cur_vram_bytes() == 0 + + # Attempting to unload the model from VRAM should have no effect. + cached_model.full_unload_from_vram() + assert not cached_model.is_in_vram() + assert cached_model.cur_vram_bytes() == 0 + + # Running inference on the CPU should work. + output1 = model.run_inference(torch.randn(1, 10)) + assert output1.device.type == "cpu" diff --git a/tests/backend/model_manager/load/model_cache/cached_model/test_cached_model_with_partial_load.py b/tests/backend/model_manager/load/model_cache/cached_model/test_cached_model_with_partial_load.py new file mode 100644 index 00000000000..e3c99d0c34f --- /dev/null +++ b/tests/backend/model_manager/load/model_cache/cached_model/test_cached_model_with_partial_load.py @@ -0,0 +1,274 @@ +import itertools + +import torch + +from invokeai.backend.model_manager.load.model_cache.cached_model.cached_model_with_partial_load import ( + CachedModelWithPartialLoad, +) +from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.autocast_modules import CustomLinear +from invokeai.backend.util.calc_tensor_size import calc_tensor_size +from tests.backend.model_manager.load.model_cache.cached_model.utils import DummyModule, parameterize_mps_and_cuda + + +@parameterize_mps_and_cuda +def test_cached_model_total_bytes(device: str): + model = DummyModule() + cached_model = CachedModelWithPartialLoad(model=model, compute_device=torch.device(device)) + linear1_numel = 10 * 32 + 32 + linear2_numel = 32 * 64 + 64 + buffer1_numel = 64 + # Note that the non-persistent buffer (buffer2) is not included in .total_bytes() calculation. + assert cached_model.total_bytes() == (linear1_numel + linear2_numel + buffer1_numel) * 4 + + +@parameterize_mps_and_cuda +def test_cached_model_cur_vram_bytes(device: str): + model = DummyModule() + # Model starts in CPU memory. + cached_model = CachedModelWithPartialLoad(model=model, compute_device=torch.device(device)) + assert cached_model.cur_vram_bytes() == 0 + + # Full load the model into VRAM. + cached_model.full_load_to_vram() + assert cached_model.cur_vram_bytes() > 0 + assert cached_model.cur_vram_bytes() == cached_model.total_bytes() + assert all(p.device.type == device for p in model.parameters()) + assert all(p.device.type == device for p in model.buffers()) + + +@parameterize_mps_and_cuda +def test_cached_model_partial_load(device: str): + model = DummyModule() + # Model starts in CPU memory. + cached_model = CachedModelWithPartialLoad(model=model, compute_device=torch.device(device)) + model_total_bytes = cached_model.total_bytes() + assert cached_model.cur_vram_bytes() == 0 + + # Partially load the model into VRAM. + target_vram_bytes = int(model_total_bytes * 0.6) + loaded_bytes = cached_model.partial_load_to_vram(target_vram_bytes) + + # Check that the model is partially loaded into VRAM. + assert loaded_bytes > 0 + assert loaded_bytes < model_total_bytes + assert loaded_bytes == cached_model.cur_vram_bytes() + assert loaded_bytes == sum( + calc_tensor_size(p) + for n, p in itertools.chain(model.named_parameters(), model.named_buffers()) + if p.device.type == device and n != "buffer2" + ) + + # Check that the model's modules have been patched with CustomLinear layers. + assert type(model.linear1) is CustomLinear + assert type(model.linear2) is CustomLinear + + +@parameterize_mps_and_cuda +def test_cached_model_partial_unload(device: str): + model = DummyModule() + # Model starts in CPU memory. + cached_model = CachedModelWithPartialLoad(model=model, compute_device=torch.device(device)) + model_total_bytes = cached_model.total_bytes() + assert cached_model.cur_vram_bytes() == 0 + + # Full load the model into VRAM. + cached_model.full_load_to_vram() + assert cached_model.cur_vram_bytes() == model_total_bytes + + # Partially unload the model from VRAM. + bytes_to_free = int(model_total_bytes * 0.4) + freed_bytes = cached_model.partial_unload_from_vram(bytes_to_free) + + # Check that the model is partially unloaded from VRAM. + assert freed_bytes >= bytes_to_free + assert freed_bytes < model_total_bytes + assert freed_bytes == model_total_bytes - cached_model.cur_vram_bytes() + assert freed_bytes == sum( + calc_tensor_size(p) for p in itertools.chain(model.parameters(), model.buffers()) if p.device.type == "cpu" + ) + + # Check that the model's modules are still patched with CustomLinear layers. + assert type(model.linear1) is CustomLinear + assert type(model.linear2) is CustomLinear + + +@parameterize_mps_and_cuda +def test_cached_model_full_load_and_unload(device: str): + model = DummyModule() + cached_model = CachedModelWithPartialLoad(model=model, compute_device=torch.device(device)) + + # Model starts in CPU memory. + model_total_bytes = cached_model.total_bytes() + assert cached_model.cur_vram_bytes() == 0 + + # Full load the model into VRAM. + loaded_bytes = cached_model.full_load_to_vram() + assert loaded_bytes > 0 + assert loaded_bytes == model_total_bytes + assert loaded_bytes == cached_model.cur_vram_bytes() + assert all(p.device.type == device for p in itertools.chain(model.parameters(), model.buffers())) + assert type(model.linear1) is torch.nn.Linear + assert type(model.linear2) is torch.nn.Linear + + # Full unload the model from VRAM. + unloaded_bytes = cached_model.full_unload_from_vram() + + # Check that the model is fully unloaded from VRAM. + assert unloaded_bytes > 0 + assert unloaded_bytes == model_total_bytes + assert cached_model.cur_vram_bytes() == 0 + # Note that the non-persistent buffer (buffer2) is not required to be unloaded from VRAM. + assert all( + p.device.type == "cpu" + for n, p in itertools.chain(model.named_parameters(), model.named_buffers()) + if n != "buffer2" + ) + + +@parameterize_mps_and_cuda +def test_cached_model_full_load_from_partial(device: str): + model = DummyModule() + cached_model = CachedModelWithPartialLoad(model=model, compute_device=torch.device(device)) + + # Model starts in CPU memory. + model_total_bytes = cached_model.total_bytes() + assert cached_model.cur_vram_bytes() == 0 + + # Partially load the model into VRAM. + target_vram_bytes = int(model_total_bytes * 0.6) + loaded_bytes = cached_model.partial_load_to_vram(target_vram_bytes) + assert loaded_bytes > 0 + assert loaded_bytes < model_total_bytes + assert loaded_bytes == cached_model.cur_vram_bytes() + assert type(model.linear1) is CustomLinear + assert type(model.linear2) is CustomLinear + + # Full load the rest of the model into VRAM. + loaded_bytes_2 = cached_model.full_load_to_vram() + assert loaded_bytes_2 > 0 + assert loaded_bytes_2 < model_total_bytes + assert loaded_bytes + loaded_bytes_2 == cached_model.cur_vram_bytes() + assert loaded_bytes + loaded_bytes_2 == model_total_bytes + assert all(p.device.type == device for p in itertools.chain(model.parameters(), model.buffers())) + assert type(model.linear1) is torch.nn.Linear + assert type(model.linear2) is torch.nn.Linear + + +@parameterize_mps_and_cuda +def test_cached_model_full_unload_from_partial(device: str): + model = DummyModule() + cached_model = CachedModelWithPartialLoad(model=model, compute_device=torch.device(device)) + + # Model starts in CPU memory. + model_total_bytes = cached_model.total_bytes() + assert cached_model.cur_vram_bytes() == 0 + + # Partially load the model into VRAM. + target_vram_bytes = int(model_total_bytes * 0.6) + loaded_bytes = cached_model.partial_load_to_vram(target_vram_bytes) + assert loaded_bytes > 0 + assert loaded_bytes < model_total_bytes + assert loaded_bytes == cached_model.cur_vram_bytes() + + # Full unload the model from VRAM. + unloaded_bytes = cached_model.full_unload_from_vram() + assert unloaded_bytes > 0 + assert unloaded_bytes == loaded_bytes + assert cached_model.cur_vram_bytes() == 0 + # Note that the non-persistent buffer (buffer2) is not required to be unloaded from VRAM. + assert all( + p.device.type == "cpu" + for n, p in itertools.chain(model.named_parameters(), model.named_buffers()) + if n != "buffer2" + ) + + +@parameterize_mps_and_cuda +def test_cached_model_get_cpu_state_dict(device: str): + model = DummyModule() + cached_model = CachedModelWithPartialLoad(model=model, compute_device=torch.device(device)) + + # Model starts in CPU memory. + assert cached_model.cur_vram_bytes() == 0 + + # The CPU state dict can be accessed and has the expected properties. + cpu_state_dict = cached_model.get_cpu_state_dict() + assert cpu_state_dict is not None + assert len(cpu_state_dict) == len(model.state_dict()) + assert all(p.device.type == "cpu" for p in cpu_state_dict.values()) + + # Full load the model into VRAM. + cached_model.full_load_to_vram() + assert cached_model.cur_vram_bytes() == cached_model.total_bytes() + + # The CPU state dict is still available, and still on the CPU. + cpu_state_dict = cached_model.get_cpu_state_dict() + assert cpu_state_dict is not None + assert len(cpu_state_dict) == len(model.state_dict()) + assert all(p.device.type == "cpu" for p in cpu_state_dict.values()) + + +@parameterize_mps_and_cuda +def test_cached_model_full_load_and_inference(device: str): + model = DummyModule() + cached_model = CachedModelWithPartialLoad(model=model, compute_device=torch.device(device)) + # Model starts in CPU memory. + model_total_bytes = cached_model.total_bytes() + assert cached_model.cur_vram_bytes() == 0 + + # Run inference on the CPU. + x = torch.randn(1, 10) + output1 = model(x) + assert output1.device.type == "cpu" + + # Full load the model into VRAM. + loaded_bytes = cached_model.full_load_to_vram() + assert loaded_bytes > 0 + assert loaded_bytes == model_total_bytes + assert loaded_bytes == cached_model.cur_vram_bytes() + assert all(p.device.type == device for p in itertools.chain(model.parameters(), model.buffers())) + + # Run inference on the GPU. + output2 = model(x.to(device)) + assert output2.device.type == device + + # The outputs should be the same for both runs. + assert torch.allclose(output1, output2.to("cpu")) + + +@parameterize_mps_and_cuda +def test_cached_model_partial_load_and_inference(device: str): + model = DummyModule() + # Model starts in CPU memory. + cached_model = CachedModelWithPartialLoad(model=model, compute_device=torch.device(device)) + model_total_bytes = cached_model.total_bytes() + assert cached_model.cur_vram_bytes() == 0 + + # Run inference on the CPU. + x = torch.randn(1, 10) + output1 = model(x) + assert output1.device.type == "cpu" + + # Partially load the model into VRAM. + target_vram_bytes = int(model_total_bytes * 0.6) + loaded_bytes = cached_model.partial_load_to_vram(target_vram_bytes) + + # Check that the model is partially loaded into VRAM. + assert loaded_bytes > 0 + assert loaded_bytes < model_total_bytes + assert loaded_bytes == cached_model.cur_vram_bytes() + assert loaded_bytes == sum( + calc_tensor_size(p) + for n, p in itertools.chain(model.named_parameters(), model.named_buffers()) + if p.device.type == device and n != "buffer2" + ) + # Check that the model's modules have been patched with CustomLinear layers. + assert type(model.linear1) is CustomLinear + assert type(model.linear2) is CustomLinear + + # Run inference on the GPU. + output2 = model(x.to(device)) + assert output2.device.type == device + + # The output should be the same as the output from the CPU. + assert torch.allclose(output1, output2.to("cpu")) diff --git a/tests/backend/model_manager/load/model_cache/cached_model/utils.py b/tests/backend/model_manager/load/model_cache/cached_model/utils.py new file mode 100644 index 00000000000..9554299e066 --- /dev/null +++ b/tests/backend/model_manager/load/model_cache/cached_model/utils.py @@ -0,0 +1,31 @@ +import pytest +import torch + + +class DummyModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear1 = torch.nn.Linear(10, 32) + self.linear2 = torch.nn.Linear(32, 64) + self.register_buffer("buffer1", torch.ones(64)) + # Non-persistent buffers are not included in the state dict. We need to make sure that this case is handled + # correctly by the partial loading code. + self.register_buffer("buffer2", torch.ones(64), persistent=False) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.linear1(x) + x = self.linear2(x) + x = x + self.buffer1 + x = x + self.buffer2 + return x + + +parameterize_mps_and_cuda = pytest.mark.parametrize( + ("device"), + [ + pytest.param( + "mps", marks=pytest.mark.skipif(not torch.backends.mps.is_available(), reason="MPS is not available.") + ), + pytest.param("cuda", marks=pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is not available.")), + ], +) diff --git a/tests/backend/model_manager/load/model_cache/torch_module_autocast/test_autocast_modules.py b/tests/backend/model_manager/load/model_cache/torch_module_autocast/test_autocast_modules.py new file mode 100644 index 00000000000..38fa467c602 --- /dev/null +++ b/tests/backend/model_manager/load/model_cache/torch_module_autocast/test_autocast_modules.py @@ -0,0 +1,144 @@ +import pytest +import torch + +if not torch.cuda.is_available(): + pytest.skip("CUDA is not available", allow_module_level=True) +else: + from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_invoke_linear_8_bit_lt import ( + CustomInvokeLinear8bitLt, + ) + from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_invoke_linear_nf4 import ( + CustomInvokeLinearNF4, + ) + from invokeai.backend.quantization.bnb_llm_int8 import InvokeLinear8bitLt + from invokeai.backend.quantization.bnb_nf4 import InvokeLinearNF4 + + +@pytest.fixture +def linear_8bit_lt_layer(): + if not torch.cuda.is_available(): + pytest.skip("CUDA is not available") + + torch.manual_seed(1) + + orig_layer = torch.nn.Linear(32, 64) + orig_layer_state_dict = orig_layer.state_dict() + + # Prepare a quantized InvokeLinear8bitLt layer. + quantized_layer = InvokeLinear8bitLt(input_features=32, output_features=64, has_fp16_weights=False) + quantized_layer.load_state_dict(orig_layer_state_dict) + quantized_layer.to("cuda") + + # Assert that the InvokeLinear8bitLt layer is quantized. + assert quantized_layer.weight.CB is not None + assert quantized_layer.weight.SCB is not None + assert quantized_layer.weight.CB.dtype == torch.int8 + + return quantized_layer + + +def test_custom_invoke_linear_8bit_lt_all_weights_on_cuda(linear_8bit_lt_layer: InvokeLinear8bitLt): + """Test CustomInvokeLinear8bitLt inference with all weights on the GPU.""" + # Run inference on the original layer. + x = torch.randn(1, 32).to("cuda") + y_quantized = linear_8bit_lt_layer(x) + + # Wrap the InvokeLinear8bitLt layer in a CustomInvokeLinear8bitLt layer, and run inference on it. + linear_8bit_lt_layer.__class__ = CustomInvokeLinear8bitLt + y_custom = linear_8bit_lt_layer(x) + + # Assert that the quantized and custom layers produce the same output. + assert torch.allclose(y_quantized, y_custom, atol=1e-5) + + +def test_custom_invoke_linear_8bit_lt_all_weights_on_cpu(linear_8bit_lt_layer: InvokeLinear8bitLt): + """Test CustomInvokeLinear8bitLt inference with all weights on the CPU (streaming to the GPU).""" + # Run inference on the original layer. + x = torch.randn(1, 32).to("cuda") + y_quantized = linear_8bit_lt_layer(x) + + # Copy the state dict to the CPU and reload it. + state_dict = linear_8bit_lt_layer.state_dict() + state_dict = {k: v.to("cpu") for k, v in state_dict.items()} + linear_8bit_lt_layer.load_state_dict(state_dict) + + # Inference of the original layer should fail. + with pytest.raises(RuntimeError): + linear_8bit_lt_layer(x) + + # Wrap the InvokeLinear8bitLt layer in a CustomInvokeLinear8bitLt layer, and run inference on it. + linear_8bit_lt_layer.__class__ = CustomInvokeLinear8bitLt + y_custom = linear_8bit_lt_layer(x) + + # Assert that the quantized and custom layers produce the same output. + assert torch.allclose(y_quantized, y_custom, atol=1e-5) + + +@pytest.fixture +def linear_nf4_layer(): + if not torch.cuda.is_available(): + pytest.skip("CUDA is not available") + + torch.manual_seed(1) + + orig_layer = torch.nn.Linear(64, 16) + orig_layer_state_dict = orig_layer.state_dict() + + # Prepare a quantized InvokeLinearNF4 layer. + quantized_layer = InvokeLinearNF4(input_features=64, output_features=16) + quantized_layer.load_state_dict(orig_layer_state_dict) + quantized_layer.to("cuda") + + # Assert that the InvokeLinearNF4 layer is quantized. + assert quantized_layer.weight.bnb_quantized + + return quantized_layer + + +def test_custom_invoke_linear_nf4_all_weights_on_cuda(linear_nf4_layer: InvokeLinearNF4): + """Test CustomInvokeLinearNF4 inference with all weights on the GPU.""" + # Run inference on the original layer. + x = torch.randn(1, 64).to("cuda") + y_quantized = linear_nf4_layer(x) + + # Wrap the InvokeLinearNF4 layer in a CustomInvokeLinearNF4 layer, and run inference on it. + linear_nf4_layer.__class__ = CustomInvokeLinearNF4 + y_custom = linear_nf4_layer(x) + + # Assert that the quantized and custom layers produce the same output. + assert torch.allclose(y_quantized, y_custom, atol=1e-5) + + +# We run with two different input dimensions, because the NF4 layer follows a different code path depending on the +# input dimension, and this has caused issues in the past. +@pytest.mark.parametrize("input_dim_0", [1, 2]) +def test_custom_invoke_linear_nf4_all_weights_on_cpu(linear_nf4_layer: InvokeLinearNF4, input_dim_0: int): + """Test CustomInvokeLinearNF4 inference with all weights on the CPU (streaming to the GPU).""" + # Run inference on the original layer. + x = torch.randn(input_dim_0, 64).to(device="cuda") + y_quantized = linear_nf4_layer(x) + + # Copy the state dict to the CPU and reload it. + state_dict = linear_nf4_layer.state_dict() + state_dict = {k: v.to("cpu") for k, v in state_dict.items()} + linear_nf4_layer.load_state_dict(state_dict) + + # Inference of the original layer should fail. + with pytest.raises(RuntimeError): + linear_nf4_layer(x) + + # Wrap the InvokeLinearNF4 layer in a CustomInvokeLinearNF4 layer, and run inference on it. + linear_nf4_layer.__class__ = CustomInvokeLinearNF4 + y_custom = linear_nf4_layer(x) + + # Assert that the state dict (and the tensors that it references) are still on the CPU. + assert all(v.device == torch.device("cpu") for v in state_dict.values()) + + # Assert that the weight, bias, and quant_state are all on the CPU. + assert linear_nf4_layer.weight.device == torch.device("cpu") + assert linear_nf4_layer.bias.device == torch.device("cpu") + assert linear_nf4_layer.weight.quant_state.absmax.device == torch.device("cpu") + assert linear_nf4_layer.weight.quant_state.code.device == torch.device("cpu") + + # Assert that the quantized and custom layers produce the same output. + assert torch.allclose(y_quantized, y_custom, atol=1e-5) diff --git a/tests/backend/model_manager/load/model_cache/torch_module_autocast/test_torch_module_autocast.py b/tests/backend/model_manager/load/model_cache/torch_module_autocast/test_torch_module_autocast.py new file mode 100644 index 00000000000..65b9f66066d --- /dev/null +++ b/tests/backend/model_manager/load/model_cache/torch_module_autocast/test_torch_module_autocast.py @@ -0,0 +1,132 @@ +import os + +import gguf +import pytest +import torch + +from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.torch_module_autocast import ( + apply_custom_layers_to_model, + remove_custom_layers_from_model, +) +from tests.backend.quantization.gguf.test_ggml_tensor import quantize_tensor + +try: + from invokeai.backend.quantization.bnb_llm_int8 import InvokeLinear8bitLt, quantize_model_llm_int8 +except ImportError: + # This is expected to fail on MacOS + pass + +cuda_and_mps = pytest.mark.parametrize( + "device", + [ + pytest.param( + torch.device("cuda"), marks=pytest.mark.skipif(not torch.cuda.is_available(), reason="requires CUDA device") + ), + pytest.param( + torch.device("mps"), + marks=pytest.mark.skipif(not torch.backends.mps.is_available(), reason="requires MPS device"), + ), + ], +) + + +class ModelWithLinearLayer(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear = torch.nn.Linear(32, 64) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.linear(x) + + +@pytest.fixture(params=["none", "gguf"]) +def model(request: pytest.FixtureRequest) -> torch.nn.Module: + if request.param == "none": + return ModelWithLinearLayer() + elif request.param == "gguf": + # Initialize ModelWithLinearLayer and replace the linear layer weight with a GGML quantized weight. + model = ModelWithLinearLayer() + ggml_quantized_weight = quantize_tensor(model.linear.weight, gguf.GGMLQuantizationType.Q8_0) + model.linear.weight = torch.nn.Parameter(ggml_quantized_weight) + return model + else: + raise ValueError(f"Invalid quantization type: {request.param}") + + +@cuda_and_mps +@torch.no_grad() +def test_torch_module_autocast_linear_layer(device: torch.device, model: torch.nn.Module): + # Skip this test with MPS on GitHub Actions. It fails but I haven't taken the tie to figure out why. It passes + # locally on MacOS. + if os.environ.get("GITHUB_ACTIONS") == "true" and device.type == "mps": + pytest.skip("This test is flaky on GitHub Actions") + + # Model parameters should start off on the CPU. + assert all(p.device.type == "cpu" for p in model.parameters()) + + torch.manual_seed(0) + + # Run inference on the CPU. + x = torch.randn(1, 32, device="cpu") + expected = model(x) + assert expected.device.type == "cpu" + + # Apply the custom layers to the model. + apply_custom_layers_to_model(model) + + # Run the model on the device. + autocast_result = model(x.to(device)) + + # The model output should be on the device. + assert autocast_result.device.type == device.type + # The model parameters should still be on the CPU. + assert all(p.device.type == "cpu" for p in model.parameters()) + + # Remove the custom layers from the model. + remove_custom_layers_from_model(model) + + # After removing the custom layers, the model should no longer be able to run inference on the device. + with pytest.raises(RuntimeError): + _ = model(x.to(device)) + + # Run inference again on the CPU. + after_result = model(x) + + assert after_result.device.type == "cpu" + + # The results from all inference runs should be the same. + assert torch.allclose(autocast_result.to("cpu"), expected, atol=1e-5) + assert torch.allclose(after_result, expected, atol=1e-5) + + +@torch.no_grad() +def test_torch_module_autocast_bnb_llm_int8_linear_layer(): + if not torch.cuda.is_available(): + pytest.skip("requires CUDA device") + + torch.manual_seed(0) + + model = ModelWithLinearLayer() + model = quantize_model_llm_int8(model, modules_to_not_convert=set()) + # The act of moving the model to the CUDA device will trigger quantization. + model.to("cuda") + # Confirm that the layer is quantized. + assert isinstance(model.linear, InvokeLinear8bitLt) + assert model.linear.weight.CB is not None + assert model.linear.weight.SCB is not None + + # Run inference on the GPU. + x = torch.randn(1, 32) + expected = model(x.to("cuda")) + assert expected.device.type == "cuda" + + # Move the model back to the CPU and add the custom layers to the model. + model.to("cpu") + apply_custom_layers_to_model(model) + + # Run inference with weights being streamed to the GPU. + autocast_result = model(x.to("cuda")) + assert autocast_result.device.type == "cuda" + + # The results from all inference runs should be the same. + assert torch.allclose(autocast_result, expected, atol=1e-5) diff --git a/tests/backend/quantization/test_bnb_llm_int8.py b/tests/backend/quantization/test_bnb_llm_int8.py new file mode 100644 index 00000000000..481b809d03c --- /dev/null +++ b/tests/backend/quantization/test_bnb_llm_int8.py @@ -0,0 +1,85 @@ +import pytest +import torch + +try: + from invokeai.backend.quantization.bnb_llm_int8 import InvokeLinear8bitLt +except ImportError: + pass + + +def test_invoke_linear_8bit_lt_quantization(): + """Test quantization with InvokeLinear8bitLt.""" + if not torch.cuda.is_available(): + pytest.skip("CUDA is not available") + + # Set the seed for reproducibility since we are using a pretty tight atol. + torch.manual_seed(3) + + orig_layer = torch.nn.Linear(32, 64) + orig_layer_state_dict = orig_layer.state_dict() + + # Initialize a InvokeLinear8bitLt layer (it is not quantized yet). + quantized_layer = InvokeLinear8bitLt(input_features=32, output_features=64, has_fp16_weights=False) + + # Load the non-quantized layer's state dict into the quantized layer. + quantized_layer.load_state_dict(orig_layer_state_dict) + + # Move the InvokeLinear8bitLt layer to the GPU. This triggers quantization. + quantized_layer.to("cuda") + + # Assert that the InvokeLinear8bitLt layer is quantized. + assert quantized_layer.weight.CB is not None + assert quantized_layer.weight.SCB is not None + assert quantized_layer.weight.CB.dtype == torch.int8 + + # Run inference on both the original and quantized layers. + x = torch.randn(1, 32) + y = orig_layer(x) + y_quantized = quantized_layer(x.to("cuda")) + assert y.shape == y_quantized.shape + # All within ~20% of each other. + assert torch.allclose(y, y_quantized.to("cpu"), atol=0.05) + + +def test_invoke_linear_8bit_lt_state_dict_roundtrip(): + """Test that we can roundtrip the state dict of a quantized InvokeLinear8bitLt layer.""" + if not torch.cuda.is_available(): + pytest.skip("CUDA is not available") + + # Set the seed for reproducibility since we are using a pretty tight atol. + torch.manual_seed(3) + + orig_layer = torch.nn.Linear(32, 64) + orig_layer_state_dict = orig_layer.state_dict() + + # Run inference on the original layer. + x = torch.randn(1, 32) + y = orig_layer(x) + + # Prepare a quantized InvokeLinear8bitLt layer. + quantized_layer_1 = InvokeLinear8bitLt(input_features=32, output_features=64, has_fp16_weights=False) + quantized_layer_1.load_state_dict(orig_layer_state_dict) + quantized_layer_1.to("cuda") + + # Assert that the InvokeLinear8bitLt layer is quantized. + assert quantized_layer_1.weight.CB is not None + assert quantized_layer_1.weight.SCB is not None + assert quantized_layer_1.weight.CB.dtype == torch.int8 + + # Run inference on the quantized layer. + y_quantized_1 = quantized_layer_1(x.to("cuda")) + + # Save the state dict of the quantized layer. + quantized_layer_1_state_dict = quantized_layer_1.state_dict() + + # Load the state dict of the quantized layer into a new quantized layer. + quantized_layer_2 = InvokeLinear8bitLt(input_features=32, output_features=64, has_fp16_weights=False) + quantized_layer_2.load_state_dict(quantized_layer_1_state_dict) + quantized_layer_2.to("cuda") + + # Run inference on the new quantized layer. + y_quantized_2 = quantized_layer_2(x.to("cuda")) + + # Assert that the inference results are the same. + assert torch.allclose(y, y_quantized_1.to("cpu"), atol=0.05) + assert torch.allclose(y_quantized_1, y_quantized_2, atol=1e-5)