Skip to content

Commit

Permalink
add automated_threshold_setting
Browse files Browse the repository at this point in the history
  • Loading branch information
Henley13 committed May 27, 2020
1 parent 490e050 commit 69eaff1
Show file tree
Hide file tree
Showing 4 changed files with 142 additions and 10 deletions.
4 changes: 3 additions & 1 deletion bigfish/detection/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from .spot_detection import detect_spots
from .spot_detection import local_maximum_detection
from .spot_detection import spots_thresholding
from .spot_detection import automated_threshold_setting

from .cluster_decomposition import decompose_cluster
from .cluster_decomposition import build_reference_spot
Expand All @@ -31,7 +32,8 @@
_spots = [
"detect_spots",
"local_maximum_detection",
"spots_thresholding"]
"spots_thresholding",
"automated_threshold_setting"]

_clusters = [
"decompose_cluster",
Expand Down
77 changes: 69 additions & 8 deletions bigfish/detection/spot_detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@

# ### LoG detection ###

def detect_spots(image, threshold, remove_duplicate=True, voxel_size_z=None,
voxel_size_yx=100, psf_z=None, psf_yx=200):
def detect_spots(image, threshold=None, remove_duplicate=True,
voxel_size_z=None, voxel_size_yx=100, psf_z=None, psf_yx=200):
"""Apply LoG filter followed by a Local Maximum algorithm to detect spots
in a 2-d or 3-d image.
Expand All @@ -35,7 +35,8 @@ def detect_spots(image, threshold, remove_duplicate=True, voxel_size_z=None,
image : np.ndarray
Image with shape (z, y, x) or (y, x).
threshold : float or int
A threshold to discriminate relevant spots from noisy blobs.
A threshold to discriminate relevant spots from noisy blobs. If None,
optimal threshold is selected automatically.
remove_duplicate : bool
Remove potential duplicate coordinates for the same spots. Slow the
running.
Expand All @@ -58,7 +59,7 @@ def detect_spots(image, threshold, remove_duplicate=True, voxel_size_z=None,
"""
# check parameters
stack.check_parameter(threshold=(float, int),
stack.check_parameter(threshold=(float, int, type(None)),
remove_duplicate=bool,
voxel_size_z=(int, float, type(None)),
voxel_size_yx=(int, float),
Expand All @@ -84,6 +85,10 @@ def detect_spots(image, threshold, remove_duplicate=True, voxel_size_z=None,
# find local maximum
mask_local_max = local_maximum_detection(image_filtered, sigma)

    # get optimal threshold if necessary
if threshold is None:
threshold = automated_threshold_setting(image_filtered, mask_local_max)

# remove spots with a low intensity and return their coordinates
spots, _ = spots_thresholding(image_filtered, mask_local_max, threshold,
remove_duplicate)
Expand Down Expand Up @@ -150,13 +155,14 @@ def spots_thresholding(image, mask_local_max, threshold,
"""Filter detected spots and get coordinates of the remaining spots.
In order to make the thresholding robust, it should be applied to a
filtered image. If the local maximum is not unique (it can happen with
connected pixels with the same value), connected component algorithm is
applied to keep only one coordinate per spot.
filtered image (bigfish.stack.log_filter for example). If the local
maximum is not unique (it can happen with connected pixels with the same
value), connected component algorithm is applied to keep only one
coordinate per spot.
Parameters
----------
image : np.ndarray, np.uint
image : np.ndarray
Image with shape (z, y, x) or (y, x).
mask_local_max : np.ndarray, bool
Mask with shape (z, y, x) or (y, x) indicating the local peaks.
Expand Down Expand Up @@ -212,3 +218,58 @@ def spots_thresholding(image, mask_local_max, threshold,
spots = np.column_stack(spots)

return spots, mask


def automated_threshold_setting(image, mask_local_max):
    """Automatically set the optimal threshold to detect spots.

    In order to make the thresholding robust, it should be applied to a
    filtered image (bigfish.stack.log_filter for example). The optimal
    threshold is selected based on the spots distribution. The latter should
    have a kink discriminating a fast decreasing stage from a more stable one
    (a plateau).

    Parameters
    ----------
    image : np.ndarray
        Image with shape (z, y, x) or (y, x).
    mask_local_max : np.ndarray, bool
        Mask with shape (z, y, x) or (y, x) indicating the local peaks.

    Returns
    -------
    threshold : int
        Optimal threshold to discriminate spots from noisy blobs.

    """
    # check parameters
    stack.check_array(image,
                      ndim=[2, 3],
                      dtype=[np.uint8, np.uint16, np.float32, np.float64])
    stack.check_array(mask_local_max,
                      ndim=[2, 3],
                      dtype=[bool])

    # get candidate threshold values x (0 up to the 99.9999th intensity
    # percentile, which discards extreme outlier pixels)
    start_range = 0
    end_range = int(np.percentile(image, 99.9999))
    x = list(range(start_range, end_range + 1))

    # get spots count y and its logarithm, smoothed with a moving average
    # NOTE(review): counts of 0 yield log(0) = -inf here — presumably
    # harmless downstream, but worth confirming.
    spots, mask_spots = spots_thresholding(
        image, mask_local_max, threshold=x[0], remove_duplicate=False)
    value_spots = image[mask_spots]
    y = np.log([(value_spots > t).sum() for t in x])
    y = stack.centered_moving_average(y, n=5)

    # select the threshold located at the kink of the distribution: the first
    # point, after the initial fast-decreasing stage, where the local slope
    # flattens back above the average slope of the whole curve
    slope = (y[-1] - y[0]) / len(y)
    y_grad = np.gradient(y)
    m = list(y_grad >= slope)
    try:
        # skip the initial plateau (if any) before the fast-decreasing stage
        j = m.index(False)
        m[:j] = [False] * j
        i = m.index(True)
        threshold = x[i]
    except ValueError:
        # no kink detected (flat or strictly decreasing distribution): fall
        # back to the lowest candidate so no spot is discarded
        threshold = x[0]

    return threshold
6 changes: 5 additions & 1 deletion bigfish/stack/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
from .utils import load_and_save_url
from .utils import check_hash
from .utils import compute_hash
from .utils import moving_average
from .utils import centered_moving_average

from .io import read_image
from .io import read_dv
Expand Down Expand Up @@ -88,7 +90,9 @@
"get_eps_float32",
"load_and_save_url",
"check_hash",
"compute_hash"]
"compute_hash",
"moving_average",
"centered_moving_average"]

_io = [
"read_image",
Expand Down
65 changes: 65 additions & 0 deletions bigfish/stack/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -745,3 +745,68 @@ def compute_hash(path):
sha256 = sha256hash.hexdigest()

return sha256


# ### Computation ###

def moving_average(array, n):
    """Compute a trailing average with a sliding window of width `n`.

    Parameters
    ----------
    array : np.ndarray
        Array used to compute moving average.
    n : int
        Window width of the moving average.

    Returns
    -------
    results : np.ndarray, np.float64
        Moving average values, with shape (len(array) - n + 1,).

    """
    # check parameter
    check_parameter(n=int)
    check_array(array, ndim=1)

    # compute moving average with a vectorized cumulative sum: one C-level
    # pass instead of a python loop; the leading 0 makes the subtraction
    # below yield exact window sums
    cumsum = np.cumsum(np.insert(np.asarray(array, dtype=np.float64), 0, 0.))
    results = (cumsum[n:] - cumsum[:-n]) / n

    return results


def centered_moving_average(array, n):
    """Compute a centered moving average with a window of width `n`.

    The input is reflect-padded so the output keeps the same length as
    `array` and every value is centered on its window.

    Parameters
    ----------
    array : np.ndarray
        Array used to compute moving average.
    n : int
        Window width of the moving average. An even width is increased by 1
        so the window can be centered.

    Returns
    -------
    results : np.ndarray, np.float64
        Centered moving average values, with the same shape as `array`.

    """
    # check parameter
    check_parameter(n=int)
    check_array(array, ndim=1)

    # pad the array (reflection) to keep the same length and center the
    # outcome; force an odd window width
    radius = n // 2
    if n % 2 == 0:
        n += 1
    array_padded = np.pad(array, pad_width=radius, mode="reflect")

    # compute centered moving average
    results = moving_average(array_padded, n)

    return results

0 comments on commit 69eaff1

Please sign in to comment.