diff --git a/invokeai/backend/image_util/hed.py b/invokeai/backend/image_util/hed.py index ec12c26b2e3..a2d3449f650 100644 --- a/invokeai/backend/image_util/hed.py +++ b/invokeai/backend/image_util/hed.py @@ -18,6 +18,7 @@ resize_image_to_resolution, safe_step, ) +from invokeai.backend.model_manager.load.model_cache.utils import get_effective_device class DoubleConvBlock(torch.nn.Module): @@ -109,7 +110,7 @@ def run( Returns: The detected edges. """ - device = next(iter(self.network.parameters())).device + device = get_effective_device(self.network) np_image = pil_to_np(input_image) np_image = normalize_image_channel_count(np_image) np_image = resize_image_to_resolution(np_image, detect_resolution) @@ -183,7 +184,7 @@ def run(self, image: Image.Image, safe: bool = False, scribble: bool = False) -> The detected edges. """ - device = next(iter(self.model.parameters())).device + device = get_effective_device(self.model) np_image = pil_to_np(image) diff --git a/invokeai/backend/image_util/infill_methods/lama.py b/invokeai/backend/image_util/infill_methods/lama.py index cd5838d1f2b..faf25e44a49 100644 --- a/invokeai/backend/image_util/infill_methods/lama.py +++ b/invokeai/backend/image_util/infill_methods/lama.py @@ -7,6 +7,7 @@ import invokeai.backend.util.logging as logger from invokeai.backend.model_manager.config import AnyModel +from invokeai.backend.model_manager.load.model_cache.utils import get_effective_device def norm_img(np_img): @@ -31,7 +32,7 @@ def __call__(self, input_image: Image.Image, *args: Any, **kwds: Any) -> Any: mask = norm_img(mask) mask = (mask > 0) * 1 - device = next(self._model.buffers()).device + device = get_effective_device(self._model) image = torch.from_numpy(image).unsqueeze(0).to(device) mask = torch.from_numpy(mask).unsqueeze(0).to(device) diff --git a/invokeai/backend/image_util/lineart.py b/invokeai/backend/image_util/lineart.py index 8fcca24b0e0..bfef6f6da08 100644 --- a/invokeai/backend/image_util/lineart.py +++ b/invokeai/backend/image_util/lineart.py @@ -17,6 +17,7 @@ pil_to_np, resize_image_to_resolution, ) +from invokeai.backend.model_manager.load.model_cache.utils import get_effective_device class ResidualBlock(nn.Module): @@ -130,7 +131,7 @@ def run( Returns: The detected lineart. """ - device = next(iter(self.model.parameters())).device + device = get_effective_device(self.model) np_image = pil_to_np(input_image) np_image = normalize_image_channel_count(np_image) @@ -201,7 +202,7 @@ def run(self, image: Image.Image) -> Image.Image: Returns: The detected edges. """ - device = next(iter(self.model.parameters())).device + device = get_effective_device(self.model) np_image = pil_to_np(image) diff --git a/invokeai/backend/image_util/lineart_anime.py b/invokeai/backend/image_util/lineart_anime.py index 09dcb6655e3..fa406cf1d4b 100644 --- a/invokeai/backend/image_util/lineart_anime.py +++ b/invokeai/backend/image_util/lineart_anime.py @@ -19,6 +19,7 @@ pil_to_np, resize_image_to_resolution, ) +from invokeai.backend.model_manager.load.model_cache.utils import get_effective_device class UnetGenerator(nn.Module): @@ -171,7 +172,7 @@ def run(self, input_image: Image.Image, detect_resolution: int = 512, image_reso Returns: The detected lineart. """ - device = next(iter(self.model.parameters())).device + device = get_effective_device(self.model) np_image = pil_to_np(input_image) np_image = normalize_image_channel_count(np_image) @@ -239,7 +240,7 @@ def to(self, device: torch.device): def run(self, image: Image.Image) -> Image.Image: """Processes an image and returns the detected edges.""" - device = next(iter(self.model.parameters())).device + device = get_effective_device(self.model) np_image = pil_to_np(image) diff --git a/invokeai/backend/image_util/mlsd/utils.py b/invokeai/backend/image_util/mlsd/utils.py index dbe9a98d09e..dbadce01a4f 100644 --- a/invokeai/backend/image_util/mlsd/utils.py +++ b/invokeai/backend/image_util/mlsd/utils.py @@ -14,6 +14,8 @@ import torch from torch.nn import functional as F +from invokeai.backend.model_manager.load.model_cache.utils import get_effective_device + def deccode_output_score_and_ptss(tpMap, topk_n = 200, ksize = 5): ''' @@ -49,7 +51,7 @@ def pred_lines(image, model, dist_thr=20.0): h, w, _ = image.shape - device = next(iter(model.parameters())).device + device = get_effective_device(model) h_ratio, w_ratio = [h / input_shape[0], w / input_shape[1]] resized_image = np.concatenate([cv2.resize(image, (input_shape[1], input_shape[0]), interpolation=cv2.INTER_AREA), @@ -108,7 +110,7 @@ def pred_squares(image, ''' h, w, _ = image.shape original_shape = [h, w] - device = next(iter(model.parameters())).device + device = get_effective_device(model) resized_image = np.concatenate([cv2.resize(image, (input_shape[0], input_shape[1]), interpolation=cv2.INTER_AREA), np.ones([input_shape[0], input_shape[1], 1])], axis=-1) diff --git a/invokeai/backend/image_util/normal_bae/__init__.py b/invokeai/backend/image_util/normal_bae/__init__.py index d0b1339113e..5ad221ecd4a 100644 --- a/invokeai/backend/image_util/normal_bae/__init__.py +++ b/invokeai/backend/image_util/normal_bae/__init__.py @@ -13,6 +13,7 @@ from invokeai.backend.image_util.normal_bae.nets.NNET import NNET from invokeai.backend.image_util.util import np_to_pil, pil_to_np, resize_to_multiple +from invokeai.backend.model_manager.load.model_cache.utils import get_effective_device class NormalMapDetector: @@ -64,7 +65,7 @@ def to(self, device: torch.device): def run(self, image: Image.Image): """Processes an image and returns the detected normal map.""" - device = next(iter(self.model.parameters())).device + device = get_effective_device(self.model) np_image = pil_to_np(image) height, width, _channels = np_image.shape diff --git a/invokeai/backend/image_util/pidi/__init__.py b/invokeai/backend/image_util/pidi/__init__.py index 8673b219140..63df7b6058e 100644 --- a/invokeai/backend/image_util/pidi/__init__.py +++ b/invokeai/backend/image_util/pidi/__init__.py @@ -11,6 +11,7 @@ from invokeai.backend.image_util.pidi.model import PiDiNet, pidinet from invokeai.backend.image_util.util import nms, normalize_image_channel_count, np_to_pil, pil_to_np, safe_step +from invokeai.backend.model_manager.load.model_cache.utils import get_effective_device class PIDINetDetector: @@ -45,7 +46,7 @@ def run( ) -> Image.Image: """Processes an image and returns the detected edges.""" - device = next(iter(self.model.parameters())).device + device = get_effective_device(self.model) np_img = pil_to_np(image) np_img = normalize_image_channel_count(np_img) diff --git a/invokeai/backend/model_manager/load/model_cache/utils.py b/invokeai/backend/model_manager/load/model_cache/utils.py new file mode 100644 index 00000000000..2b581990c69 --- /dev/null +++ b/invokeai/backend/model_manager/load/model_cache/utils.py @@ -0,0 +1,20 @@ +import itertools + +import torch + + +def get_effective_device(model: torch.nn.Module) -> torch.device: + """A utility to infer the 'effective' device of a model. + + This utility handles the case where a model is partially loaded onto the GPU, so is safer than just calling: + `next(iter(model.parameters())).device`. + + In the worst case, this utility has to check all model parameters, so if you already know the intended model device, + then it is better to avoid calling this function. + """ + # If all parameters are on the CPU, return the CPU device. Otherwise, return the first non-CPU device. + for p in itertools.chain(model.parameters(), model.buffers()): + if p.device.type != "cpu": + return p.device + + return torch.device("cpu")