Skip to content

Commit

Permalink
add elbow plot function
Browse files Browse the repository at this point in the history
  • Loading branch information
Henley13 committed Jun 14, 2021
1 parent e367a2f commit 3e47346
Show file tree
Hide file tree
Showing 4 changed files with 121 additions and 98 deletions.
4 changes: 3 additions & 1 deletion bigfish/detection/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from .spot_detection import local_maximum_detection
from .spot_detection import spots_thresholding
from .spot_detection import automated_threshold_setting
from .spot_detection import get_elbow_values

from .dense_decomposition import decompose_dense
from .dense_decomposition import get_dense_region
Expand All @@ -33,7 +34,8 @@
"detect_spots",
"local_maximum_detection",
"spots_thresholding",
"automated_threshold_setting"]
"automated_threshold_setting",
"get_elbow_values"]

_dense = [
"decompose_dense",
Expand Down
91 changes: 91 additions & 0 deletions bigfish/detection/spot_detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -544,3 +544,94 @@ def _get_breaking_point(x, y):
breaking_point = float(x[i])

return breaking_point, x, y


def get_elbow_values(images, voxel_size_z=None, voxel_size_yx=100, psf_z=None,
psf_yx=200):
# check parameters
stack.check_parameter(voxel_size_z=(int, float, type(None)),
voxel_size_yx=(int, float),
psf_z=(int, float, type(None)),
psf_yx=(int, float))

# if one image is provided we enlist it
if not isinstance(images, list):
stack.check_array(images,
ndim=[2, 3],
dtype=[np.uint8, np.uint16,
np.float32, np.float64])
ndim = images.ndim
images = [images]
n = 1
else:
ndim = None
for i, image in enumerate(images):
stack.check_array(image,
ndim=[2, 3],
dtype=[np.uint8, np.uint16,
np.float32, np.float64])
if i == 0:
ndim = image.ndim
else:
if ndim != image.ndim:
raise ValueError("Provided images should have the same "
"number of dimensions.")
n = len(images)

# check consistency between parameters
if ndim == 3 and voxel_size_z is None:
raise ValueError("Provided images has {0} dimensions but "
"'voxel_size_z' parameter is missing.".format(ndim))
if ndim == 3 and psf_z is None:
raise ValueError("Provided images has {0} dimensions but "
"'psf_z' parameter is missing.".format(ndim))
if ndim == 2:
voxel_size_z = None
psf_z = None

# compute sigma
sigma = stack.get_sigma(voxel_size_z, voxel_size_yx, psf_z, psf_yx)

# apply LoG filter and find local maximum
images_filtered = []
pixel_values = []
masks = []
for image in images:
# filter image
image_filtered = stack.log_filter(image, sigma)
images_filtered.append(image_filtered)

# get pixels value
pixel_values += list(image_filtered.ravel())

# find local maximum
mask_local_max = local_maximum_detection(
image_filtered, sigma)
masks.append(mask_local_max)

# get threshold values we want to test
thresholds = _get_candidate_thresholds(pixel_values)

# get spots count and its logarithm
all_value_spots = []
minimum_threshold = float(thresholds[0])
for i in range(n):
image_filtered = images_filtered[i]
mask_local_max = masks[i]
spots, mask_spots = spots_thresholding(
image_filtered, mask_local_max,
threshold=minimum_threshold,
remove_duplicate=False)
value_spots = image_filtered[mask_spots]
all_value_spots.append(value_spots)
all_value_spots = np.concatenate(all_value_spots)
thresholds, count_spots = _get_spot_counts(
thresholds, all_value_spots)

# select threshold where the kink of the distribution is located
if count_spots.size > 0:
threshold, _, _ = _get_breaking_point(thresholds, count_spots)
else:
threshold = None

return thresholds, count_spots, threshold
4 changes: 2 additions & 2 deletions bigfish/plot/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from .plot_images import plot_cell

from .plot_quality import plot_sharpness
from .plot_quality import plot_snr_spots
from .plot_quality import plot_elbow

from .plot_classification import plot_confusion_matrix
from .plot_classification import plot_2d_projection
Expand All @@ -43,7 +43,7 @@

_quality = [
"plot_sharpness",
"plot_snr_spots"]
"plot_elbow"]

_utils = [
"save_plot",
Expand Down
120 changes: 25 additions & 95 deletions bigfish/plot/plot_quality.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,16 @@
"""

import bigfish.stack as stack
import bigfish.detection as detection

import matplotlib.pyplot as plt
import numpy as np

from .utils import save_plot


# ### Focus - sharpness ###

def plot_sharpness(focus_measures, labels=None, title=None, framesize=(5, 5),
size_title=20, size_axes=15, size_legend=15,
path_output=None, ext="png", show=True):
Expand Down Expand Up @@ -97,53 +100,13 @@ def plot_sharpness(focus_measures, labels=None, title=None, framesize=(5, 5),
return


def plot_snr_spots(snr_spots, labels=None, colors=None, x_lim=None, y_lim=None,
title=None, framesize=(10, 5), size_title=20, size_axes=15,
size_legend=15, path_output=None, ext="png", show=True):
"""Plot Signal-to-Noise Ratio computed for all detected spots.
Parameters
----------
snr_spots : List[np.ndarray] or np.ndarray
A list of 1-d arrays with the SNR computed for every spot of an image.
One array per image.
labels : List[str], str or None
Labels or the curves.
colors : List[str], str or None
Colors or the curves.
x_lim : tuple or None
Limits of the x-axis.
y_lim : tuple or None
Limits of the y-axis.
title : str or None
Title of the plot.
framesize : tuple
Size of the frame used to plot with 'plt.figure(figsize=framesize)'.
size_title : int
Size of the title.
size_axes : int
Size of the axes label.
size_legend : int
Size of the legend.
path_output : str or None
Path to save the image (without extension).
ext : str or List[str]
Extension used to save the plot. If it is a list of strings, the plot
will be saved several times.
show : bool
Show the figure or not.
Returns
-------
# ### Elbow plot ###

"""
def plot_elbow(images, voxel_size_z, voxel_size_yx, psf_z, psf_yx, title=None,
framesize=(5, 5), size_title=20, size_axes=15, size_legend=15,
path_output=None, ext="png", show=True):
# check parameters
stack.check_parameter(snr_spots=(list, np.ndarray),
labels=(list, str, type(None)),
colors=(list, str, type(None)),
x_lim=(tuple, type(None)),
y_lim=(tuple, type(None)),
title=(str, type(None)),
stack.check_parameter(title=(str, list, type(None)),
framesize=tuple,
size_title=int,
size_axes=int,
Expand All @@ -152,61 +115,30 @@ def plot_snr_spots(snr_spots, labels=None, colors=None, x_lim=None, y_lim=None,
ext=(str, list),
show=bool)

# enlist values if necessary
if isinstance(snr_spots, np.ndarray):
snr_spots = [snr_spots]
if labels is not None and isinstance(labels, str):
labels = [labels]
if colors is not None and isinstance(colors, str):
colors = [colors]

# check arrays
for snr_spots_ in snr_spots:
stack.check_array(snr_spots_,
ndim=1,
dtype=[np.float32, np.float64])

# check number of parameters
if labels is not None and len(snr_spots) != len(labels):
raise ValueError("The number of labels provided ({0}) differs from "
"the number of arrays to plot ({1})."
.format(len(labels), len(snr_spots)))
if colors is not None and len(snr_spots) != len(colors):
raise ValueError("The number of colors provided ({0}) differs from "
"the number of arrays to plot ({1})."
.format(len(colors), len(snr_spots)))

# frame
plt.figure(figsize=framesize)
# get candidate thresholds and spots count to plot the elbow curve
thresholds, count_spots, threshold = detection.get_elbow_values(
images=images,
voxel_size_z=voxel_size_z,
voxel_size_yx=voxel_size_yx,
psf_z=psf_z,
psf_yx=psf_yx)

# plot
for i, snr_spots_ in enumerate(snr_spots):
values = sorted(snr_spots_, reverse=True)
if labels is None and colors is None:
plt.plot(values, lw=2)
elif labels is None and colors is not None:
color = colors[i]
plt.plot(values, lw=2, c=color)
elif labels is not None and colors is None:
label = labels[i]
plt.plot(values, lw=2, label=label)
else:
label = labels[i]
color = colors[i]
plt.plot(values, lw=2, c=color, label=label)
plt.figure(figsize=framesize)
plt.plot(thresholds, count_spots, c="#2c7bb6", lw=2)
if threshold is not None:
i_threshold = np.argmax(thresholds == threshold)
plt.scatter(threshold, count_spots[i_threshold],
marker="D", c="#d7191c", s=60, label="Selected threshold")

# axes
if title is not None:
plt.title(title, fontweight="bold", fontsize=size_title)
plt.xlabel("Detected Spots", fontweight="bold", fontsize=size_axes)
plt.ylabel("Signal-to-Noise Ratio", fontweight="bold", fontsize=size_axes)
if x_lim is not None:
plt.xlim(x_lim)
if y_lim is not None:
plt.ylim(y_lim)
if labels is not None:
plt.xlabel("Thresholds", fontweight="bold", fontsize=size_axes)
plt.ylabel("Number of mRNAs detected (log scale)",
fontweight="bold", fontsize=size_axes)
if threshold is not None:
plt.legend(prop={'size': size_legend})

plt.tight_layout()
if path_output is not None:
save_plot(path_output, ext)
Expand All @@ -216,5 +148,3 @@ def plot_snr_spots(snr_spots, labels=None, colors=None, x_lim=None, y_lim=None,
plt.close()

return


0 comments on commit 3e47346

Please sign in to comment.