Skip to content

Commit

Permalink
precompute SAM scale
Browse files Browse the repository at this point in the history
  • Loading branch information
nilsmechtel committed Feb 3, 2025
1 parent 0670be1 commit c1b09df
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 14 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ tmp/
output/
.ipynb_checkpoints
.DS_Store
*/__MACOSX
*.pt
*.pth
*zip
Expand Down
40 changes: 27 additions & 13 deletions bioimageio_colab/models/sam_image_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@ def _to_image_format(self, array: np.ndarray) -> np.ndarray:
if not isinstance(array, np.ndarray):
array = np.array(array)

# Convert grayscale image to RGB
if array.ndim == 2:
# Convert grayscale image to RGB
array = np.concatenate([array[..., None]] * 3, axis=-1)
Expand All @@ -54,16 +53,28 @@ def _to_image_format(self, array: np.ndarray) -> np.ndarray:

return array

def _preprocess(self, array: np.array):
def _preprocess(self, array: np.array) -> tuple:
"""
Preprocess the input image before feeding it to the model.
Args:
array (np.ndarray): Input image in either 2-channel grayscale (H, W) or 3-channel RGB (H, W, 3) format.
Returns:
tuple: A tuple containing the following elements:
- input_image_torch (torch.Tensor): The preprocessed input image as a torch tensor (BCHW)
- original_image_shape (tuple): The size of the original input image (H, W)
- sam_scale (float): The scale factor used to resize the image to the correct size
"""
import torch

# Validate image shape and dtype
# Convert input to image format (HWC, uint8)
original_image = self._to_image_format(array)

input_image = self._transform.apply_image(original_image) # input: np.array

input_image_torch = torch.as_tensor(input_image, device=self.device)

original_image_shape = original_image.shape[:2]
# Resize the image to the correct size
resized_image = self._transform.apply_image(original_image)
sam_scale = self._transform.target_length / max(original_image_shape)
# Convert the image to a torch tensor
input_image_torch = torch.as_tensor(resized_image, device=self.device)
# Ensure the input tensor is in the correct shape (BCHW)
input_image_torch = input_image_torch.permute(2, 0, 1).contiguous()[
None, :, :, :
Expand All @@ -78,7 +89,7 @@ def _preprocess(self, array: np.array):
# Preprocess the image before feeding it to the model
input_image_torch = self._normalize_and_pad(input_image_torch)

return input_image_torch
return input_image_torch, original_image_shape, sam_scale

def encode(self, array: np.ndarray) -> dict:
"""
Expand All @@ -95,14 +106,17 @@ def encode(self, array: np.ndarray) -> dict:
import torch

# Preprocess the input image
input_image = self._preprocess(array)
input_size = tuple(input_image.shape[-2:])
input_image_torch, original_image_shape, sam_scale = self._preprocess(array)

# Run inference
with torch.no_grad():
features = self.image_encoder(input_image)
features = self.image_encoder(input_image_torch)

return {"features": features.cpu().numpy(), "input_size": input_size}
return {
"features": features.cpu().numpy(),
"original_image_shape": original_image_shape,
"sam_scale": sam_scale,
}


if __name__ == "__main__":
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ requires = ["setuptools", "wheel"]

[project]
name = "bioimageio-colab"
version = "0.2.8"
version = "0.2.9"
readme = "README.md"
description = "Collaborative image annotation and model training with human in the loop."
dependencies = [
Expand Down

0 comments on commit c1b09df

Please sign in to comment.