Skip to content

Commit

Permalink
Add get_effective_device(...) utility to aid in determining the effec…
Browse files Browse the repository at this point in the history
…tive device of models that are partially loaded.
  • Loading branch information
RyanJDick committed Dec 31, 2024
1 parent c8b4f2f commit bbc078a
Show file tree
Hide file tree
Showing 8 changed files with 39 additions and 11 deletions.
5 changes: 3 additions & 2 deletions invokeai/backend/image_util/hed.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
resize_image_to_resolution,
safe_step,
)
from invokeai.backend.model_manager.load.model_cache.utils import get_effective_device


class DoubleConvBlock(torch.nn.Module):
Expand Down Expand Up @@ -109,7 +110,7 @@ def run(
Returns:
The detected edges.
"""
device = next(iter(self.network.parameters())).device
device = get_effective_device(self.network)
np_image = pil_to_np(input_image)
np_image = normalize_image_channel_count(np_image)
np_image = resize_image_to_resolution(np_image, detect_resolution)
Expand Down Expand Up @@ -183,7 +184,7 @@ def run(self, image: Image.Image, safe: bool = False, scribble: bool = False) ->
The detected edges.
"""

device = next(iter(self.model.parameters())).device
device = get_effective_device(self.model)

np_image = pil_to_np(image)

Expand Down
3 changes: 2 additions & 1 deletion invokeai/backend/image_util/infill_methods/lama.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import invokeai.backend.util.logging as logger
from invokeai.backend.model_manager.config import AnyModel
from invokeai.backend.model_manager.load.model_cache.utils import get_effective_device


def norm_img(np_img):
Expand All @@ -31,7 +32,7 @@ def __call__(self, input_image: Image.Image, *args: Any, **kwds: Any) -> Any:
mask = norm_img(mask)
mask = (mask > 0) * 1

device = next(self._model.buffers()).device
device = get_effective_device(self._model)
image = torch.from_numpy(image).unsqueeze(0).to(device)
mask = torch.from_numpy(mask).unsqueeze(0).to(device)

Expand Down
5 changes: 3 additions & 2 deletions invokeai/backend/image_util/lineart.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
pil_to_np,
resize_image_to_resolution,
)
from invokeai.backend.model_manager.load.model_cache.utils import get_effective_device


class ResidualBlock(nn.Module):
Expand Down Expand Up @@ -130,7 +131,7 @@ def run(
Returns:
The detected lineart.
"""
device = next(iter(self.model.parameters())).device
device = get_effective_device(self.model)

np_image = pil_to_np(input_image)
np_image = normalize_image_channel_count(np_image)
Expand Down Expand Up @@ -201,7 +202,7 @@ def run(self, image: Image.Image) -> Image.Image:
Returns:
The detected edges.
"""
device = next(iter(self.model.parameters())).device
device = get_effective_device(self.model)

np_image = pil_to_np(image)

Expand Down
5 changes: 3 additions & 2 deletions invokeai/backend/image_util/lineart_anime.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
pil_to_np,
resize_image_to_resolution,
)
from invokeai.backend.model_manager.load.model_cache.utils import get_effective_device


class UnetGenerator(nn.Module):
Expand Down Expand Up @@ -171,7 +172,7 @@ def run(self, input_image: Image.Image, detect_resolution: int = 512, image_reso
Returns:
The detected lineart.
"""
device = next(iter(self.model.parameters())).device
device = get_effective_device(self.model)
np_image = pil_to_np(input_image)

np_image = normalize_image_channel_count(np_image)
Expand Down Expand Up @@ -239,7 +240,7 @@ def to(self, device: torch.device):

def run(self, image: Image.Image) -> Image.Image:
"""Processes an image and returns the detected edges."""
device = next(iter(self.model.parameters())).device
device = get_effective_device(self.model)

np_image = pil_to_np(image)

Expand Down
6 changes: 4 additions & 2 deletions invokeai/backend/image_util/mlsd/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
import torch
from torch.nn import functional as F

from invokeai.backend.model_manager.load.model_cache.utils import get_effective_device


def deccode_output_score_and_ptss(tpMap, topk_n = 200, ksize = 5):
'''
Expand Down Expand Up @@ -49,7 +51,7 @@ def pred_lines(image, model,
dist_thr=20.0):
h, w, _ = image.shape

device = next(iter(model.parameters())).device
device = get_effective_device(model)
h_ratio, w_ratio = [h / input_shape[0], w / input_shape[1]]

resized_image = np.concatenate([cv2.resize(image, (input_shape[1], input_shape[0]), interpolation=cv2.INTER_AREA),
Expand Down Expand Up @@ -108,7 +110,7 @@ def pred_squares(image,
'''
h, w, _ = image.shape
original_shape = [h, w]
device = next(iter(model.parameters())).device
device = get_effective_device(model)

resized_image = np.concatenate([cv2.resize(image, (input_shape[0], input_shape[1]), interpolation=cv2.INTER_AREA),
np.ones([input_shape[0], input_shape[1], 1])], axis=-1)
Expand Down
3 changes: 2 additions & 1 deletion invokeai/backend/image_util/normal_bae/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

from invokeai.backend.image_util.normal_bae.nets.NNET import NNET
from invokeai.backend.image_util.util import np_to_pil, pil_to_np, resize_to_multiple
from invokeai.backend.model_manager.load.model_cache.utils import get_effective_device


class NormalMapDetector:
Expand Down Expand Up @@ -64,7 +65,7 @@ def to(self, device: torch.device):
def run(self, image: Image.Image):
"""Processes an image and returns the detected normal map."""

device = next(iter(self.model.parameters())).device
device = get_effective_device(self.model)
np_image = pil_to_np(image)

height, width, _channels = np_image.shape
Expand Down
3 changes: 2 additions & 1 deletion invokeai/backend/image_util/pidi/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

from invokeai.backend.image_util.pidi.model import PiDiNet, pidinet
from invokeai.backend.image_util.util import nms, normalize_image_channel_count, np_to_pil, pil_to_np, safe_step
from invokeai.backend.model_manager.load.model_cache.utils import get_effective_device


class PIDINetDetector:
Expand Down Expand Up @@ -45,7 +46,7 @@ def run(
) -> Image.Image:
"""Processes an image and returns the detected edges."""

device = next(iter(self.model.parameters())).device
device = get_effective_device(self.model)

np_img = pil_to_np(image)
np_img = normalize_image_channel_count(np_img)
Expand Down
20 changes: 20 additions & 0 deletions invokeai/backend/model_manager/load/model_cache/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
import itertools

import torch


def get_effective_device(model: torch.nn.Module) -> torch.device:
"""A utility to infer the 'effective' device of a model.
This utility handles the case where a model is partially loaded onto the GPU, so is safer than just calling:
`next(iter(model.parameters())).device`.
In the worst case, this utility has to check all model parameters, so if you already know the intended model device,
then it is better to avoid calling this function.
"""
# If all parameters are on the CPU, return the CPU device. Otherwise, return the first non-CPU device.
for p in itertools.chain(model.parameters(), model.buffers()):
if p.device.type != "cpu":
return p.device

return torch.device("cpu")

0 comments on commit bbc078a

Please sign in to comment.