This repository has been archived by the owner on Jan 3, 2024. It is now read-only.

Commit: Update README.md (#9)

* Update README.md

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
hmellor and pre-commit-ci[bot] authored Sep 15, 2023
1 parent a126fe8 commit 882825c
Showing 3 changed files with 60 additions and 26 deletions.
.pre-commit-config.yaml: 2 changes (1 addition, 1 deletion)
@@ -12,4 +12,4 @@ repos:
   - repo: https://github.com/PyCQA/isort
     rev: 5.12.0
     hooks:
-      - id: isort
+      - id: isort
README.md: 4 changes (4 additions, 0 deletions)
@@ -1,3 +1,7 @@
+⚠️ The functionality provided by this repository can now be found in [PyTorch Geometric](https://pytorch-geometric.readthedocs.io/en/latest/) in the [to_superpixels()](https://pytorch-geometric.readthedocs.io/en/latest/_modules/torch_geometric/transforms/to_superpixels.html) function.
+
+---
+
 # PyTorch Superpixels
 
 - [Why use superpixels?](#why-use-superpixels)
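For readers following the new README banner: in recent PyTorch Geometric releases the `torch_geometric/transforms/to_superpixels.py` module appears to expose a `ToSLIC` transform that turns an image tensor into a superpixel graph. Below is a minimal usage sketch, assuming a current `torch_geometric` and `torchvision` install; the MNIST dataset, the `n_segments` value, and the `/tmp` path are illustrative choices, not part of this repository.

```python
import torchvision.transforms as T
from torchvision.datasets import MNIST

from torch_geometric.transforms import ToSLIC

# Segment each image into roughly 75 SLIC superpixels; extra keyword
# arguments are forwarded to skimage.segmentation.slic.
transform = T.Compose([T.ToTensor(), ToSLIC(n_segments=75)])
dataset = MNIST("/tmp/MNIST", download=True, transform=transform)

data, label = dataset[0]  # `data` is a torch_geometric.data.Data object
print(data.x.shape)       # mean colour of each superpixel
print(data.pos.shape)     # centroid position of each superpixel
```

If that sketch matches the current PyG API, the returned `Data` object carries per-superpixel mean colours in `x` and superpixel centroids in `pos`, the graph-style counterpart of the pixel-level label masks that example.py builds with skimage's `slic` below.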
example.py: 80 changes (55 additions, 25 deletions)
@@ -1,21 +1,22 @@
+from functools import partial
+from multiprocessing import Pool
+from os import cpu_count
+from pathlib import Path
+
+import matplotlib.pyplot as plt
+import numpy as np
+import torch
+import torchvision.transforms.functional as F
 from numpy.core.fromnumeric import product
+from skimage.segmentation import find_boundaries, mark_boundaries, slic
 from skimage.segmentation.boundaries import find_boundaries
-import torch
-import numpy as np
 from torchvision.io import read_image
 from torchvision.models.segmentation import fcn_resnet50
-import matplotlib.pyplot as plt
 from torchvision.transforms.functional import convert_image_dtype
-from torchvision.utils import draw_segmentation_masks
-from torchvision.utils import make_grid
+from torchvision.utils import draw_segmentation_masks, make_grid
+
 from pytorch_superpixels.runtime import superpixelise
-from skimage.segmentation import slic, mark_boundaries, find_boundaries
-from pathlib import Path
-from multiprocessing import Pool
-from os import cpu_count
-from functools import partial
 
-import torchvision.transforms.functional as F
 
 def show(imgs):
     if not isinstance(imgs, list):
@@ -32,9 +33,27 @@ def show(imgs):
 
 if __name__ == "__main__":
     sem_classes = [
-        '__background__', 'aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus',
-        'car', 'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse', 'motorbike',
-        'person', 'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor'
+        "__background__",
+        "aeroplane",
+        "bicycle",
+        "bird",
+        "boat",
+        "bottle",
+        "bus",
+        "car",
+        "cat",
+        "chair",
+        "cow",
+        "diningtable",
+        "dog",
+        "horse",
+        "motorbike",
+        "person",
+        "pottedplant",
+        "sheep",
+        "sofa",
+        "train",
+        "tvmonitor",
     ]
     sem_class_to_idx = {cls: idx for (idx, cls) in enumerate(sem_classes)}
     image_dims = [420, 640]
@@ -46,24 +65,25 @@ def show(imgs):
     batch = convert_image_dtype(batch_int, dtype=torch.float)
 
     # permute because slic expects the last dimension to be channel
-    with Pool(processes = cpu_count()-1) as pool:
+    with Pool(processes=cpu_count() - 1) as pool:
         # re-order axes for skimage
-        args = [x.permute(1,2,0) for x in batch]
+        args = [x.permute(1, 2, 0) for x in batch]
         # 100 segments
-        kwargs = {"n_segments":100, "start_label":0, "slic_zero":True}
+        kwargs = {"n_segments": 100, "start_label": 0, "slic_zero": True}
         func = partial(slic, **kwargs)
         masks_100sp = pool.map(func, args)
         # 1000 segments
         kwargs["n_segments"] = 1000
         func = partial(slic, **kwargs)
         masks_1000sp = pool.map(func, args)
-
 
     model = fcn_resnet50(pretrained=True, progress=False)
     model = model.eval()
 
-    normalized_batch = F.normalize(batch, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))
-    outputs = model(batch)['out']
+    normalized_batch = F.normalize(
+        batch, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)
+    )
+    outputs = model(batch)["out"]
 
     normalized_masks = torch.nn.functional.softmax(outputs, dim=1)
     num_classes = normalized_masks.shape[1]
@@ -73,23 +93,33 @@ def generate_all_class_masks(outputs, masks):
         masks = torch.from_numpy(masks)
         outputs_sp = superpixelise(outputs, masks)
         normalized_masks_sp = torch.nn.functional.softmax(outputs_sp, dim=1)
-        return normalized_masks_sp[i].argmax(0) == torch.arange(num_classes)[:, None, None]
+        return (
+            normalized_masks_sp[i].argmax(0) == torch.arange(num_classes)[:, None, None]
+        )
 
     to_show = []
     for i, image in enumerate(images):
         # before
-        all_classes_masks = normalized_masks[i].argmax(0) == torch.arange(num_classes)[:, None, None]
-        to_show.append(draw_segmentation_masks(image, masks=all_classes_masks, alpha=.6))
+        all_classes_masks = (
+            normalized_masks[i].argmax(0) == torch.arange(num_classes)[:, None, None]
+        )
+        to_show.append(
+            draw_segmentation_masks(image, masks=all_classes_masks, alpha=0.6)
+        )
         # after 100
         all_classes_masks_sp = generate_all_class_masks(outputs, masks_100sp)
-        to_show.append(draw_segmentation_masks(image, masks=all_classes_masks_sp, alpha=.6))
+        to_show.append(
+            draw_segmentation_masks(image, masks=all_classes_masks_sp, alpha=0.6)
+        )
         # show superpixel boundaries
         boundaries = find_boundaries(masks_100sp[i])
         to_show[-1][0:2, boundaries] = 255
         to_show[-1][2, boundaries] = 0
         # after 1000
         all_classes_masks_sp = generate_all_class_masks(outputs, masks_1000sp)
-        to_show.append(draw_segmentation_masks(image, masks=all_classes_masks_sp, alpha=.6))
+        to_show.append(
+            draw_segmentation_masks(image, masks=all_classes_masks_sp, alpha=0.6)
+        )
         # show superpixel boundaries
         boundaries = find_boundaries(masks_1000sp[i])
         to_show[-1][0:2, boundaries] = 255
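example.py relies on `superpixelise(outputs, masks)` from `pytorch_superpixels.runtime`, whose body is not part of this diff. As a purely conceptual sketch of what such a pooling step can do (the helper name `superpixel_average` and every detail below are assumptions, not the library's actual implementation), one option is to average the per-class logits over each SLIC superpixel and broadcast the mean back to the member pixels:

```python
import torch


def superpixel_average(outputs: torch.Tensor, masks: torch.Tensor) -> torch.Tensor:
    """Hypothetical sketch: average (N, C, H, W) logits over (N, H, W) superpixel labels."""
    pooled = torch.empty_like(outputs)
    for n in range(outputs.shape[0]):
        labels = masks[n].reshape(-1).long()                         # superpixel id of every pixel, (H*W,)
        membership = torch.nn.functional.one_hot(labels).T.float()   # (num_superpixels, H*W)
        counts = membership.sum(dim=1, keepdim=True).clamp(min=1)    # pixels per superpixel
        logits = outputs[n].reshape(outputs.shape[1], -1)            # (C, H*W)
        means = (logits @ membership.T) / counts.T                   # per-superpixel mean, (C, S)
        pooled[n] = (means @ membership).reshape_as(outputs[n])      # broadcast means back to pixels
    return pooled
```

Read this way, the softmax and argmax applied afterwards in `generate_all_class_masks` act on predictions that are constant within each superpixel, which would explain why the drawn masks in the example follow the superpixel boundaries.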

0 comments on commit 882825c
