forked from scaelles/DEXTR-KerasTensorflow
-
Notifications
You must be signed in to change notification settings - Fork 4
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
17 changed files
with
1,055 additions
and
206 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -103,3 +103,5 @@ ENV/ | |
|
||
# mypy | ||
.mypy_cache/ | ||
|
||
/models/*.h5 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
FROM python:3.6 | ||
|
||
WORKDIR /temp/ | ||
ADD requirements.txt /temp/ | ||
RUN pip install -r requirements.txt && \ | ||
pip install pycocotools | ||
|
||
WORKDIR /workspace/ | ||
|
||
EXPOSE 8888 | ||
CMD jupyter notebook --port=8888 --ip=0.0.0.0 --no-browser --allow-root |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
from .helpers import * | ||
from .resnet import * | ||
from .classifiers import * | ||
from .model import * |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,58 @@ | ||
#!/usr/bin/env python | ||
from os.path import splitext, join | ||
import numpy as np | ||
from scipy import misc | ||
from keras import backend as K | ||
|
||
import tensorflow as tf | ||
from .resnet import build_network | ||
from .helpers import * | ||
|
||
|
||
class DEXTR(object): | ||
"""Pyramid Scene Parsing Network by Hengshuang Zhao et al 2017""" | ||
|
||
def __init__(self, nb_classes, resnet_layers, input_shape, num_input_channels=4, | ||
classifier='psp', weights_path='models/dextr_pascal-sbd.h5', sigmoid=False): | ||
self.input_shape = input_shape | ||
self.num_input_channels = num_input_channels | ||
self.sigmoid = sigmoid | ||
self.model = build_network(nb_classes=nb_classes, resnet_layers=resnet_layers, num_input_channels=num_input_channels, | ||
input_shape=self.input_shape, classifier=classifier, sigmoid=self.sigmoid, output_size=self.input_shape) | ||
|
||
self.model.load_weights(weights_path) | ||
|
||
def feed_forward(self, data): | ||
|
||
assert data.shape == (self.input_shape[0], self.input_shape[1], self.num_input_channels) | ||
prediction = self.model.predict(np.expand_dims(data, 0))[0] | ||
|
||
return prediction | ||
|
||
def predict_mask(self, image, points, pad=50, threshold=0.8, zero_pad=True): | ||
points = np.array(points).astype(np.int) | ||
image = np.array(image) | ||
bbox = get_bbox(image, points=points, pad=pad, zero_pad=zero_pad) | ||
crop_image = crop_from_bbox(image, bbox, zero_pad=zero_pad) | ||
resize_image = fixed_resize(crop_image, (512, 512)).astype(np.float32) | ||
|
||
# Generate extreme point heat map normalized to image values | ||
extreme_points = points - [np.min(points[:, 0]), np.min(points[:, 1])] + [pad , pad] | ||
extreme_points = (512 * extreme_points * [1 / crop_image.shape[1], 1 / crop_image.shape[0]]).astype(np.int) | ||
extreme_heatmap = make_gt(resize_image, extreme_points, sigma=10) | ||
extreme_heatmap = cstm_normalize(extreme_heatmap, 255) | ||
|
||
# Concatenate inputs and convert to tensor | ||
input_dextr = np.concatenate((resize_image, extreme_heatmap[:, :, np.newaxis]), axis=2) | ||
|
||
pred = self.model.predict(input_dextr[np.newaxis, ...])[0, :, :, 0] | ||
result = crop2fullmask(pred, bbox, im_size=image.shape[:2], zero_pad=zero_pad, relax=pad) > threshold | ||
|
||
return result | ||
|
||
def predict(self, img): | ||
# Preprocess | ||
img = misc.imresize(img, self.input_shape) | ||
img = img.astype('float32') | ||
probs = self.feed_forward(img) | ||
return probs |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
version: "3" | ||
|
||
services: | ||
example: | ||
ports: | ||
- "8888:8888" | ||
volumes: | ||
- ./models:/workspace/models | ||
- ./examples:/workspace/notebooks | ||
- ./dextr:/workspace/libs/dextr | ||
- ./imgs:/workspace/imgs | ||
build: . |
Oops, something went wrong.