Commit

Sc 41 optimize flowsom clustering (#81)
* SC_41_z_score

* SC_41 plot clusters

* SC_41 flowsom

* SC_41 flowsom clustering visualization

* SC_41 test flowsom

* SC_41 docs

* SC_41 update docstring

* SC_41 docs
ArneDefauw authored Jan 8, 2025
1 parent 53a717f commit 1d5fc0f
Showing 9 changed files with 518 additions and 1,317 deletions.
3 changes: 3 additions & 0 deletions docs/api.md
@@ -167,6 +167,9 @@ Plotting functions.
.. autosummary::
:toctree: generated
pl.pixel_clusters
pl.pixel_clusters_heatmap
pl.snr_ratio
pl.group_snr_ratio
pl.snr_clustermap
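
For reference, a minimal usage sketch of the two new plotting functions (based on the test added in this commit; `sdata` is assumed to already hold the FlowSOM label layers and the `counts_clusters` table):

from harpy.plot import pixel_clusters, pixel_clusters_heatmap

# per-pixel metacluster labels for a region of interest
pixel_clusters(sdata, labels_layer="blobs_image_metaclusters", crd=(100, 300, 100, 300))

# z-scored mean channel intensities per metacluster
pixel_clusters_heatmap(sdata, table_layer="counts_clusters", metaclusters=True, z_score=True)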
1,409 changes: 102 additions & 1,307 deletions docs/tutorials/general/FlowSOM_for_pixel_and_cell_clustering.ipynb

Large diffs are not rendered by default.

68 changes: 68 additions & 0 deletions src/harpy/_tests/test_plot/test_plot_flowsom.py
@@ -0,0 +1,68 @@
import importlib
import os

import matplotlib
import pytest

from harpy.plot._flowsom import pixel_clusters, pixel_clusters_heatmap


@pytest.mark.skipif(not importlib.util.find_spec("flowsom"), reason="requires the flowSOM library")
def test_plot_pixel_clusters(sdata_blobs, tmp_path):
from harpy.image.pixel_clustering._clustering import flowsom
from harpy.table.pixel_clustering._cluster_intensity import cluster_intensity

matplotlib.use("Agg")

img_layer = "blobs_image"
channels = ["lineage_0", "lineage_1", "lineage_5", "lineage_9"]
fraction = 0.1

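# train a FlowSOM model on a fraction of the pixels and write per-pixel
# cluster and metacluster label layers back to the SpatialData object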
sdata_blobs, fsom, mapping = flowsom(
sdata_blobs,
img_layer=[img_layer],
output_layer_clusters=[f"{img_layer}_clusters"],
output_layer_metaclusters=[f"{img_layer}_metaclusters"],
channels=channels,
fraction=fraction,
n_clusters=20,
random_state=100,
chunks=(1, 200, 200),
overwrite=True,
)

assert f"{img_layer}_clusters" in sdata_blobs.labels
assert f"{img_layer}_metaclusters" in sdata_blobs.labels
assert fsom.model._is_fitted

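# plot the metacluster labels for a cropped region and save the figure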
pixel_clusters(
sdata_blobs,
labels_layer="blobs_image_metaclusters",
figsize=(10, 10),
coordinate_systems="global",
crd=(100, 300, 100, 300),
output=os.path.join(tmp_path, "pixel_clusters.png"),
)

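# aggregate channel intensities per SOM cluster into a table layer, using
# the cluster-to-metacluster mapping returned by flowsom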
sdata_blobs = cluster_intensity(
sdata_blobs,
mapping=mapping,
img_layer=[img_layer],
labels_layer=[f"{img_layer}_clusters"],
to_coordinate_system=["global"],
output_layer="counts_clusters",
chunks="auto",
overwrite=True,
)

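# heatmap of z-scored mean channel intensities, once per cluster and once per metacluster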
for _metaclusters in [True, False]:
pixel_clusters_heatmap(
sdata_blobs,
table_layer="counts_clusters",
figsize=(30, 20),
fig_kwargs={"dpi": 100},
linewidths=0.01,
metaclusters=_metaclusters,
z_score=True,
output=os.path.join(tmp_path, f"pixel_clusters_heatmap_{_metaclusters}.png"),
)
59 changes: 55 additions & 4 deletions src/harpy/image/pixel_clustering/_preprocess.py
@@ -1,9 +1,13 @@
from __future__ import annotations

import os
import shutil
import uuid
from collections.abc import Iterable

import dask.array as da
import numpy as np
from dask import compute, persist
from dask.array import Array
from dask_image.ndfilters import gaussian_filter
from spatialdata import SpatialData
@@ -12,6 +16,9 @@

from harpy.image._image import _get_spatial_element, add_image_layer
from harpy.image._normalize import _nonzero_nonnan_percentile, _nonzero_nonnan_percentile_axis_0
from harpy.utils.pylogger import get_pylogger

log = get_pylogger(__name__)


def pixel_clustering_preprocess(
@@ -26,6 +33,8 @@ def pixel_clustering_preprocess(
norm_sum: bool = True,
chunks: str | int | tuple[int, ...] | None = None,
scale_factors: ScaleFactors_t | None = None,
cast_dtype: type | None = np.float32,
persist_intermediate: bool = True,
overwrite: bool = False,
) -> SpatialData:
"""
@@ -59,6 +68,15 @@
Chunk sizes for processing. If provided as a tuple, it should contain chunk sizes for `c`, `(z)`, `y`, `x`.
scale_factors
Scale factors to apply for multiscale
persist_intermediate
If set to `True`, all preprocessed elements in `img_layer` will be persisted in memory.
If the elements in `img_layer` are large, this could lead to increased RAM usage.
Set to `False` to write to an intermediate zarr store instead, which will reduce RAM usage, but will slightly increase computation time.
Persisting or writing to an intermediate zarr store is needed; otherwise Dask will not be able to optimize the computation graph for the multiple `img_layer` use case.
Ignored if `sdata` is not backed by a zarr store, or if there is only one element in `img_layer`.
cast_dtype
Image data in `img_layer` will be cast to `cast_dtype` before preprocessing starts. If set to `None` and the input image is of integer type, the quantile
normalizations will produce data of type `numpy.float64`, leading to increased memory usage.
overwrite
If `True`, overwrites existing data in `output_layer`.
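
A minimal sketch of a call using the two new parameters (illustrative only; the import path is assumed from the file layout, the `img_layer` and `output_layer` arguments from the docstring, and the layer names are made up):

import numpy as np

from harpy.image.pixel_clustering._preprocess import pixel_clustering_preprocess

# cast to float32 up front to avoid the float64 blow-up from quantile
# normalization, and spill intermediates to zarr instead of holding them in RAM
sdata = pixel_clustering_preprocess(
sdata,
img_layer=["image_0", "image_1"],
output_layer=["image_0_preprocessed", "image_1_preprocessed"],
cast_dtype=np.float32,
persist_intermediate=False,
overwrite=True,
)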
@@ -116,6 +134,8 @@
# add trivial z dimension for 2D case
arr = arr[:, None, ...]
to_squeeze = True
if cast_dtype is not None:
arr = arr.astype(cast_dtype)
_arr_list.append(arr)

if q is not None:
@@ -126,6 +146,7 @@
results_arr_percentile.append(arr_percentile)
arr_percentile = da.stack(results_arr_percentile, axis=0)
arr_percentile_mean = da.mean(arr_percentile, axis=0) # mean over all images
# arr_percentile_mean has shape (#n_channels,)
# for multiple img_layer, in ark one uses np.mean( arr_percentile ) as the percentile to normalize

# 2) calculate norm sum percentile for img_layer
@@ -191,14 +212,37 @@

if q_post is not None:
arr_percentile_post_norm = da.stack(results_arr_percentile_post_norm, axis=0)
arr_percentile_post_norm_mean = da.mean(arr_percentile_post_norm, axis=0)
arr_percentile_post_norm_mean = da.mean(
arr_percentile_post_norm, axis=0
) # arr_percentile_post_norm_mean is of shape (#n_channels,)

# Now normalize each image layer by arr_percentile_post_norm and add to spatialdata object
for i in range(len(_arr_list)):
if q_post is not None:
if q_post is not None:
for i in range(len(_arr_list)):
_arr_list[i] = _arr_list[i] / da.asarray(arr_percentile_post_norm_mean[..., None, None, None])
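# the trailing None axes broadcast the per-channel percentiles of shape
# (#n_channels,) against the (c, z, y, x) image array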

# save the preprocessed images, in this way we get the preprocessed images from which we sample
# need to let dask optimize the computation graph in case there are multiple images,
# otherwise the whole preprocessing would be recalculated every time we add an image layer to the sdata zarr store
# should only do this if there is more than one image
clean_up = False
if len(_arr_list) > 1:
if sdata.is_backed() and not persist_intermediate:
clean_up = True
_uuid = uuid.uuid4()
_write_tasks = []
for i, _arr in enumerate(_arr_list):
_intermediate_zarr_store = os.path.join(os.path.dirname(sdata.path), f"{i}_{_uuid}.zarr")
log.info(f"Preparing to write to intermediate zarr store {_intermediate_zarr_store}.")
_write_tasks.append(_arr.to_zarr(_intermediate_zarr_store, compute=False))
# write to the intermediate zarr stores, and let dask optimize the computation graph
compute(*_write_tasks)
# load the dask arrays back lazily
for i in range(len(_arr_list)):
_arr_list[i] = da.from_zarr(os.path.join(os.path.dirname(sdata.path), f"{i}_{_uuid}.zarr"))
else:
_arr_list = persist(*_arr_list)

# save the preprocessed images, in this way we get the preprocessed images from which we sample
for i in range(len(_arr_list)):
sdata = add_image_layer(
sdata,
arr=_arr_list[i].squeeze(1) if to_squeeze else _arr_list[i],
Expand All @@ -209,6 +253,13 @@ def pixel_clustering_preprocess(
overwrite=overwrite,
)

if clean_up:
# clean up the intermediate zarr stores.
for i in range(len(_arr_list)):
_intermediate_zarr_store = os.path.join(os.path.dirname(sdata.path), f"{i}_{_uuid}.zarr")
log.info(f"Removing intermediate zarr store {_intermediate_zarr_store}")
shutil.rmtree(_intermediate_zarr_store)

return sdata


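The spill-to-zarr pattern used above, shown in isolation (an illustrative sketch, not harpy code; array contents and store paths are made up):

import dask.array as da
from dask import compute

# two lazy arrays that share an expensive upstream graph
a = da.random.random((4, 1024, 1024), chunks=(1, 512, 512)) ** 2
b = a + 1

# build the zarr writes lazily, then compute them together so dask
# evaluates the shared graph once instead of once per array
writes = [a.to_zarr("a.zarr", compute=False), b.to_zarr("b.zarr", compute=False)]
compute(*writes)

# reload lazily; downstream steps now read from zarr instead of
# re-running the upstream graph
a = da.from_zarr("a.zarr")
b = da.from_zarr("b.zarr")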
2 changes: 0 additions & 2 deletions src/harpy/io/_visium_hd.py
@@ -21,8 +21,6 @@ def visium_hd(
Read *10x Genomics* Visium HD formatted dataset.
Wrapper around `spatialdata.io.readers.visium_hd.visium_hd`, but with the resulting table annotated by a labels layer.
To use this function, please install `spatialdata_io` via this fork: https://github.com/ArneDefauw/spatialdata-io.git@visium_hd.
E.g. `pip install git+https://github.com/ArneDefauw/spatialdata-io.git@visium_hd`.
.. seealso::
1 change: 1 addition & 0 deletions src/harpy/plot/__init__.py
@@ -4,6 +4,7 @@
from ._cluster_cleanliness import cluster_cleanliness
from ._clustering import cluster
from ._enrichment import nhood_enrichment
from ._flowsom import pixel_clusters, pixel_clusters_heatmap
from ._plot import plot_image, plot_labels, plot_shapes
from ._preprocess import preprocess_transcriptomics
from ._qc_cells import plot_adata, ridgeplot_channel, ridgeplot_channel_sample