Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Partial Loading PR2: Add utils to support partial loading of models from CPU to GPU #7494

Merged
merged 13 commits into from
Dec 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
from typing import Any

import torch


class CachedModelOnlyFullLoad:
"""A wrapper around a PyTorch model to handle full loads and unloads between the CPU and the compute device.
Note: "VRAM" is used throughout this class to refer to the memory on the compute device. It could be CUDA memory,
MPS memory, etc.
"""

def __init__(self, model: torch.nn.Module | Any, compute_device: torch.device, total_bytes: int):
"""Initialize a CachedModelOnlyFullLoad.
Args:
model (torch.nn.Module | Any): The model to wrap. Should be on the CPU.
compute_device (torch.device): The compute device to move the model to.
total_bytes (int): The total size (in bytes) of all the weights in the model.
"""
# model is often a torch.nn.Module, but could be any model type. Throughout this class, we handle both cases.
self._model = model
self._compute_device = compute_device
self._offload_device = torch.device("cpu")

# A CPU read-only copy of the model's state dict.
self._cpu_state_dict: dict[str, torch.Tensor] | None = None
if isinstance(model, torch.nn.Module):
self._cpu_state_dict = model.state_dict()

self._total_bytes = total_bytes
self._is_in_vram = False

@property
def model(self) -> torch.nn.Module:
return self._model

def get_cpu_state_dict(self) -> dict[str, torch.Tensor] | None:
"""Get a read-only copy of the model's state dict in RAM."""
# TODO(ryand): Document this better.
return self._cpu_state_dict

def total_bytes(self) -> int:
"""Get the total size (in bytes) of all the weights in the model."""
return self._total_bytes

def cur_vram_bytes(self) -> int:
"""Get the size (in bytes) of the weights that are currently in VRAM."""
if self._is_in_vram:
return self._total_bytes
else:
return 0

def is_in_vram(self) -> bool:
"""Return true if the model is currently in VRAM."""
return self._is_in_vram

def full_load_to_vram(self) -> int:
"""Load all weights into VRAM (if supported by the model).
Returns:
The number of bytes loaded into VRAM.
"""
if self._is_in_vram:
# Already in VRAM.
return 0

if not hasattr(self._model, "to"):
# Model doesn't support moving to a device.
return 0

if self._cpu_state_dict is not None:
new_state_dict: dict[str, torch.Tensor] = {}
for k, v in self._cpu_state_dict.items():
new_state_dict[k] = v.to(self._compute_device, copy=True)
self._model.load_state_dict(new_state_dict, assign=True)
self._model.to(self._compute_device)

self._is_in_vram = True
return self._total_bytes

def full_unload_from_vram(self) -> int:
"""Unload all weights from VRAM.
Returns:
The number of bytes unloaded from VRAM.
"""
if not self._is_in_vram:
# Already in RAM.
return 0

if self._cpu_state_dict is not None:
self._model.load_state_dict(self._cpu_state_dict, assign=True)
self._model.to(self._offload_device)

self._is_in_vram = False
return self._total_bytes
Original file line number Diff line number Diff line change
@@ -0,0 +1,201 @@
import torch

from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.torch_module_autocast import (
AUTOCAST_MODULE_TYPE_MAPPING,
apply_custom_layers_to_model,
remove_custom_layers_from_model,
)
from invokeai.backend.util.calc_tensor_size import calc_tensor_size
from invokeai.backend.util.logging import InvokeAILogger


def set_nested_attr(obj: object, attr: str, value: object):
"""A helper function that extends setattr() to support nested attributes.

Example:
set_nested_attr(model, "module.encoder.conv1.weight", new_conv1_weight)
"""
attrs = attr.split(".")
for attr in attrs[:-1]:
obj = getattr(obj, attr)
setattr(obj, attrs[-1], value)


class CachedModelWithPartialLoad:
"""A wrapper around a PyTorch model to handle partial loads and unloads between the CPU and the compute device.

Note: "VRAM" is used throughout this class to refer to the memory on the compute device. It could be CUDA memory,
MPS memory, etc.
"""

def __init__(self, model: torch.nn.Module, compute_device: torch.device):
self._model = model
self._compute_device = compute_device

# A CPU read-only copy of the model's state dict.
self._cpu_state_dict: dict[str, torch.Tensor] = model.state_dict()

# TODO(ryand): Handle the case where the model sizes changes after initial load (e.g. due to dtype casting).
# Consider how we should handle this for both self._total_bytes and self._cur_vram_bytes.
self._total_bytes = sum(calc_tensor_size(p) for p in self._cpu_state_dict.values())
self._cur_vram_bytes: int | None = None

self._modules_that_support_autocast = self._find_modules_that_support_autocast()
self._keys_in_modules_that_do_not_support_autocast = self._find_keys_in_modules_that_do_not_support_autocast()

def _find_modules_that_support_autocast(self) -> dict[str, torch.nn.Module]:
"""Find all modules that support autocasting."""
return {n: m for n, m in self._model.named_modules() if type(m) in AUTOCAST_MODULE_TYPE_MAPPING}

def _find_keys_in_modules_that_do_not_support_autocast(self) -> set[str]:
keys_in_modules_that_do_not_support_autocast = set()
for key in self._cpu_state_dict.keys():
for module_name in self._modules_that_support_autocast.keys():
if key.startswith(module_name):
break
else:
keys_in_modules_that_do_not_support_autocast.add(key)
return keys_in_modules_that_do_not_support_autocast

def _move_non_persistent_buffers_to_device(self, device: torch.device):
"""Move the non-persistent buffers to the target device. These buffers are not included in the state dict,
so we need to move them manually.
"""
# HACK(ryand): Typically, non-persistent buffers are moved when calling module.to(device). We don't move entire
# modules, because we manage the devices of individual tensors using the state dict. Since non-persistent
# buffers are not included in the state dict, we need to handle them manually. The only way to do this is by
# using private torch.nn.Module attributes.
for module in self._model.modules():
for name, buffer in module.named_buffers():
if name in module._non_persistent_buffers_set:
module._buffers[name] = buffer.to(device, copy=True)

@property
def model(self) -> torch.nn.Module:
return self._model

def get_cpu_state_dict(self) -> dict[str, torch.Tensor] | None:
"""Get a read-only copy of the model's state dict in RAM."""
# TODO(ryand): Document this better.
return self._cpu_state_dict

def total_bytes(self) -> int:
"""Get the total size (in bytes) of all the weights in the model."""
return self._total_bytes

def cur_vram_bytes(self) -> int:
"""Get the size (in bytes) of the weights that are currently in VRAM."""
if self._cur_vram_bytes is None:
cur_state_dict = self._model.state_dict()
self._cur_vram_bytes = sum(
calc_tensor_size(p) for p in cur_state_dict.values() if p.device.type == self._compute_device.type
)
return self._cur_vram_bytes

def full_load_to_vram(self) -> int:
"""Load all weights into VRAM."""
return self.partial_load_to_vram(self.total_bytes())

def full_unload_from_vram(self) -> int:
"""Unload all weights from VRAM."""
return self.partial_unload_from_vram(self.total_bytes())

@torch.no_grad()
def partial_load_to_vram(self, vram_bytes_to_load: int) -> int:
"""Load more weights into VRAM without exceeding vram_bytes_to_load.

Returns:
The number of bytes loaded into VRAM.
"""
# TODO(ryand): Handle the case where an exception is thrown while loading or unloading weights. At the very
# least, we should reset self._cur_vram_bytes to None.

vram_bytes_loaded = 0

cur_state_dict = self._model.state_dict()

# First, process the keys *must* be loaded into VRAM.
for key in self._keys_in_modules_that_do_not_support_autocast:
param = cur_state_dict[key]
if param.device.type == self._compute_device.type:
continue

param_size = calc_tensor_size(param)
cur_state_dict[key] = param.to(self._compute_device, copy=True)
vram_bytes_loaded += param_size

if vram_bytes_loaded > vram_bytes_to_load:
logger = InvokeAILogger.get_logger()
logger.warning(
f"Loaded {vram_bytes_loaded / 2**20} MB into VRAM, but only {vram_bytes_to_load / 2**20} MB were "
"requested. This is the minimum set of weights in VRAM required to run the model."
)

# Next, process the keys that can optionally be loaded into VRAM.
fully_loaded = True
for key, param in cur_state_dict.items():
if param.device.type == self._compute_device.type:
continue

param_size = calc_tensor_size(param)
if vram_bytes_loaded + param_size > vram_bytes_to_load:
# TODO(ryand): Should we just break here? If we couldn't fit this parameter into VRAM, is it really
# worth continuing to search for a smaller parameter that would fit?
fully_loaded = False
continue

cur_state_dict[key] = param.to(self._compute_device, copy=True)
vram_bytes_loaded += param_size

if vram_bytes_loaded > 0:
# We load the entire state dict, not just the parameters that changed, in case there are modules that
# override _load_from_state_dict() and do some funky stuff that requires the entire state dict.
# Alternatively, in the future, grouping parameters by module could probably solve this problem.
self._model.load_state_dict(cur_state_dict, assign=True)

if self._cur_vram_bytes is not None:
self._cur_vram_bytes += vram_bytes_loaded

if fully_loaded:
remove_custom_layers_from_model(self._model)
# TODO(ryand): Warn if the self.cur_vram_bytes() and self.total_bytes() are out of sync.
else:
apply_custom_layers_to_model(self._model)

# Move all non-persistent buffers to the compute device. These are a weird edge case and do not participate in
# the vram_bytes_loaded tracking.
self._move_non_persistent_buffers_to_device(self._compute_device)

return vram_bytes_loaded

@torch.no_grad()
def partial_unload_from_vram(self, vram_bytes_to_free: int) -> int:
"""Unload weights from VRAM until vram_bytes_to_free bytes are freed. Or the entire model is unloaded.

Returns:
The number of bytes unloaded from VRAM.
"""
vram_bytes_freed = 0

offload_device = "cpu"
cur_state_dict = self._model.state_dict()
for key, param in cur_state_dict.items():
if vram_bytes_freed >= vram_bytes_to_free:
break

if param.device.type == offload_device:
continue

cur_state_dict[key] = self._cpu_state_dict[key]
vram_bytes_freed += calc_tensor_size(param)

if vram_bytes_freed > 0:
self._model.load_state_dict(cur_state_dict, assign=True)

if self._cur_vram_bytes is not None:
self._cur_vram_bytes -= vram_bytes_freed

# We may have gone from a fully-loaded model to a partially-loaded model, so we need to reapply the custom
# layers.
apply_custom_layers_to_model(self._model)
return vram_bytes_freed
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
import torch

from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.cast_to_device import cast_to_device

# This file contains custom torch.nn.Module classes that support streaming of weights to the target device.
# Each class sub-classes the original module type that is is replacing, so the following properties are preserved:
# - isinstance(m, torch.nn.OrginalModule) should still work.
# - Patching the weights (e.g. for LoRA) should still work if non-quantized.


class CustomLinear(torch.nn.Linear):
def forward(self, input: torch.Tensor) -> torch.Tensor:
weight = cast_to_device(self.weight, input.device)
bias = cast_to_device(self.bias, input.device)
return torch.nn.functional.linear(input, weight, bias)


class CustomConv1d(torch.nn.Conv1d):
def forward(self, input: torch.Tensor) -> torch.Tensor:
weight = cast_to_device(self.weight, input.device)
bias = cast_to_device(self.bias, input.device)
return self._conv_forward(input, weight, bias)


class CustomConv2d(torch.nn.Conv2d):
def forward(self, input: torch.Tensor) -> torch.Tensor:
weight = cast_to_device(self.weight, input.device)
bias = cast_to_device(self.bias, input.device)
return self._conv_forward(input, weight, bias)


class CustomGroupNorm(torch.nn.GroupNorm):
def forward(self, input: torch.Tensor) -> torch.Tensor:
weight = cast_to_device(self.weight, input.device)
bias = cast_to_device(self.bias, input.device)
return torch.nn.functional.group_norm(input, self.num_groups, weight, bias, self.eps)


class CustomEmbedding(torch.nn.Embedding):
def forward(self, input: torch.Tensor) -> torch.Tensor:
weight = cast_to_device(self.weight, input.device)
return torch.nn.functional.embedding(
input,
weight,
self.padding_idx,
self.max_norm,
self.norm_type,
self.scale_grad_by_freq,
self.sparse,
)
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
from typing import TypeVar

import torch

T = TypeVar("T", torch.Tensor, None, torch.Tensor | None)


def cast_to_device(t: T, to_device: torch.device) -> T:
"""Helper function to cast an optional tensor to a target device."""
if t is None:
return t

if t.device.type != to_device.type:
return t.to(to_device)
return t
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
import bitsandbytes as bnb
import torch

from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.cast_to_device import cast_to_device
from invokeai.backend.quantization.bnb_llm_int8 import InvokeLinear8bitLt


class CustomInvokeLinear8bitLt(InvokeLinear8bitLt):
def forward(self, x: torch.Tensor) -> torch.Tensor:
matmul_state = bnb.MatmulLtState()
matmul_state.threshold = self.state.threshold
matmul_state.has_fp16_weights = self.state.has_fp16_weights
matmul_state.use_pool = self.state.use_pool
matmul_state.is_training = self.training
# The underlying InvokeInt8Params weight must already be quantized.
assert self.weight.CB is not None
matmul_state.CB = cast_to_device(self.weight.CB, x.device)
matmul_state.SCB = cast_to_device(self.weight.SCB, x.device)

# weights are cast automatically as Int8Params, but the bias has to be cast manually.
if self.bias is not None and self.bias.dtype != x.dtype:
self.bias.data = self.bias.data.to(x.dtype)

# NOTE(ryand): The second parameter should not be needed at all given our expected inference configuration, but
# it's dtype field must be accessible, even though it's not used. We pass in self.weight even though it could be
# on the wrong device.
return bnb.matmul(x, self.weight, bias=cast_to_device(self.bias, x.device), state=matmul_state)
Loading
Loading