Skip to content

Commit

Permalink
different colours for centroids + visualisation utils for devograph (#87
Browse files Browse the repository at this point in the history
)

* Delete .gitignore

* added files

* add files

* Update cell_membrane_segmentor.py

* Update cell_membrane_segmentor.py

---------

Co-authored-by: Apple <[email protected]>
  • Loading branch information
sushmanthreddy and Apple authored Mar 20, 2023
1 parent a3eca33 commit 122e172
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 13 deletions.
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,5 @@ devolearn.egg-info/
dist/
centroids.csv
.vscode/
.venv/
.venv/
.DS_Store/
37 changes: 25 additions & 12 deletions devolearn/cell_membrane_segmentor/cell_membrane_segmentor.py
100755 → 100644
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,12 @@
3d segmentation model for C elegans embryo
"""

def generate_centroid_image(thresh):
def generate_centroid_image(thresh, color_mode=False):
"""Used when centroid_mode is set to True
Args:
thresh (np.array): 2d numpy array that is returned from the segmentation model
color_mode (bool, optional): If True, returns a 3 channel colored image. Dafaults to False.
Returns:
np.array : image containing the contours and their respective centroids
Expand All @@ -39,7 +40,11 @@ def generate_centroid_image(thresh):

thresh = cv2.blur(thresh, (5,5))
thresh = thresh.astype(np.uint8)
centroid_image = np.zeros(thresh.shape)
if color_mode == False:
centroid_image = np.zeros(thresh.shape)
else:
centroid_image = np.zeros((thresh.shape[0], thresh.shape[1], 3))

cnts = cv2.findContours(thresh, cv2.RETR_EXTERNAL,cv2.CHAIN_APPROX_SIMPLE)
cnts = imutils.grab_contours(cnts)
centroids = []
Expand All @@ -50,8 +55,12 @@ def generate_centroid_image(thresh):
cX = int(M["m10"] / M["m00"])
cY = int(M["m01"] / M["m00"])
# draw the contour and center of the shape on the image
cv2.drawContours(centroid_image, [c], -1, (255, 255, 255), 2)
cv2.circle(centroid_image, (cX, cY), 2, (255, 255, 255), -1)
if color_mode == False:
cv2.drawContours(centroid_image, [c], -1, (255, 255, 255), 2)
cv2.circle(centroid_image, (cX, cY), 2, (255, 255, 255), -1)
else:
cv2.drawContours(centroid_image, [c], -1, (0, 0, 255), 2) #blue
cv2.circle(centroid_image, (cX, cY), 2, (0, 255, 0), -1) #green
centroids.append((cX, cY))
except:
pass
Expand Down Expand Up @@ -115,7 +124,7 @@ def preprocess(self, image_grayscale_numpy):
tensor = self.mini_transform(image_grayscale_numpy).unsqueeze(0).to(self.device)
return tensor

def predict(self, image_path, pred_size = (350,250), centroid_mode = False):
def predict(self, image_path, pred_size = (350,250), centroid_mode = False, color_mode = False):
"""Loads an image from image_path and converts it to grayscale,
then passes it through the model and returns centroids of the segmented features.
reference{
Expand All @@ -131,8 +140,12 @@ def predict(self, image_path, pred_size = (350,250), centroid_mode = False):
centroid_mode set to False:
np.array : 1 channel image.
centroid_mode set to True:
np.array : 1 channel image,
list : list of centroids.
color_mode set to False:
np.array : 1 channel image,
list : list of centroids.
color_mode set to True:
np.array : 3 channel image,
list : list of centroids.
"""

im = cv2.imread(image_path,0)
Expand All @@ -149,11 +162,10 @@ def predict(self, image_path, pred_size = (350,250), centroid_mode = False):
if centroid_mode == False:
return res
else:
centroid_image, centroids = generate_centroid_image(res)
centroid_image, centroids = generate_centroid_image(res, color_mode = color_mode)
return centroid_image, centroids


def predict_from_video(self, video_path, pred_size = (350,250), save_folder = "preds", centroid_mode = False, notebook_mode = False):
def predict_from_video(self, video_path, pred_size = (350,250), save_folder = "preds", centroid_mode = False, color_mode = False, notebook_mode = False):
"""Splits a video from video_path into frames and passes the
frames through the model for predictions. Saves predicted images in save_folder.
And optionally saves all the centroid predictions into a pandas.DataFrame.
Expand All @@ -163,6 +175,7 @@ def predict_from_video(self, video_path, pred_size = (350,250), save_folder = "p
pred_size (tuple, optional): size of output image,(width,height). Defaults to (350,250).
save_folder (str, optional): path to folder to be saved in. Defaults to "preds".
centroid_mode (bool, optional): set to true to return both the segmented image and the list of centroids. Defaults to False.
color_mode (bool, optional): set to true to return a color image. Defaults to False.
notebook_mode (bool, optional): toogle between script(False) and notebook(True), for better user interface. Defaults to False.
Returns:
Expand Down Expand Up @@ -202,7 +215,7 @@ def predict_from_video(self, video_path, pred_size = (350,250), save_folder = "p
res = self.model(tensor).detach().cpu().numpy()[0][0]

if centroid_mode == True:
res, centroids = generate_centroid_image(res)
res, centroids = generate_centroid_image(res, color_mode = color_mode)
filenames_centroids.append([save_name, centroids])

res = cv2.resize(res,pred_size)
Expand All @@ -214,7 +227,7 @@ def predict_from_video(self, video_path, pred_size = (350,250), save_folder = "p
res = self.model(tensor).detach().cpu().numpy()[0][0]

if centroid_mode == True:
res, centroids = generate_centroid_image(res)
res, centroids = generate_centroid_image(res, color_mode = color_mode)
filenames_centroids.append([save_name, centroids])

res = cv2.resize(res,pred_size)
Expand Down

0 comments on commit 122e172

Please sign in to comment.