Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Alejandro #30

Merged
merged 10 commits into from
May 26, 2022
20 changes: 17 additions & 3 deletions pysmFISH/cell_assignment.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,15 +59,22 @@ def expand_labels_mask(self, mask, distance, out_file_name='expanded',
overlap_percentage = overlap_percentage / 100
overlap = int(chunk_size * overlap_percentage)

#Iterate over chunks and assign dots to cells
#Iterate over chunks
for i, (view_min, view_max) in enumerate(sliding_window_view(chunks, 2)):
print(i, view_min, view_max)
if i == 0:
view_min_extra = view_min
view_max_extra = view_max + overlap
overlap_right = overlap
elif (view_max + overlap) > mask.shape[0]:
print('got here')
view_min_extra = view_min - overlap
view_max_extra = mask.shape[0]
overlap_right = mask.shape[0] - view_max
else:
view_min_extra = view_min - overlap
view_max_extra = view_max + overlap
view_max_extra = view_max + overlap
overlap_right = overlap

#Expand labels
img = mask[view_min_extra:view_max_extra, :]
Expand All @@ -81,7 +88,14 @@ def expand_labels_mask(self, mask, distance, out_file_name='expanded',
exp[view_min:, :] = img[overlap:]

else: #Middle chunks
exp[view_min:view_max, :] = img[overlap:-overlap,:]
print(f' {overlap}, {overlap_right}')
print(f' {view_min}, {view_max}')
print(f' {view_min_extra}, {view_max_extra}')
print(f' {img.shape}')
print(f' {overlap}, {chunk_size+overlap}')
print(f' {img[overlap:chunk_size+overlap,:].shape}')
#exp[view_min:view_max, :] = img[overlap:-overlap_right,:]
exp[view_min:view_max, :] = img[overlap:chunk_size+overlap,:]

return out_file_name

Expand Down
203 changes: 119 additions & 84 deletions pysmFISH/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,9 @@ class Pipeline:
diameter_size (int): Size of the diameter of the cells to segment using cellpose
min_overlapping_pixels_segmentation (int): Size of the overlapping label objects
fov_alignement_mode (str): clip or merged (default clip).
clip_size (int): only used when fov_alignement_mode is merge; the borders are clipped by clip_size. Defaults to 0.
remove_distinct_genes (bool): when stitching, also remove overlapping dots of different genes if set to
True. Defaults to False.


Attributes:
Expand Down Expand Up @@ -257,6 +260,9 @@ def __init__(

self.max_expansion_radius = kwarg.pop("max_expansion_radius", 18)
self.fov_alignement_mode = kwarg.pop("fov_alignement_mode", "clip")
self.clip_size = kwarg.pop("clip_size", 0)
self.remove_distinct_genes = kwarg.pop("remove_distinct_genes", False)


# -----------------------------------
# PROCESSING STEPS
Expand Down Expand Up @@ -1164,8 +1170,8 @@ def stitch_and_remove_dots_eel_graph_old_room_step(self):
# # ----------------------------------------------------------------

def stitch_and_remove_dots_eel_graph_step(self,
remove_distinct_genes=False,
clip_size=0,):
remove_distinct_genes=self.remove_distinct_genes,
clip_size=self.clip_size,):

"""
Function to stitch the different fovs and remove the duplicated
Expand Down Expand Up @@ -1249,8 +1255,8 @@ def stitch_and_remove_dots_eel_graph_step(self,
matching_dot_radius=self.same_dot_radius_duplicate_dots,
out_folder=folder,
exp_name=self.metadata["experiment_name"],
remove_distinct_genes=remove_distinct_genes,
clip_size=clip_size,
remove_distinct_genes=self.remove_distinct_genes,
clip_size=self.clip_size,
verbose=False,
) # Set to False in pipeline

Expand All @@ -1260,9 +1266,9 @@ def stitch_and_remove_dots_eel_graph_step(self,
def processing_fresh_tissue_step(
self,
parsing=True,
reprocessing=True,
tag_ref_beads="_ChannelEuropium_Cy3_",
tag_nuclei="_ChannelCy3_",
centering_mode='middle',
):
"""
This function creates and runs a processing graph that parses and filters the nuclei staining in fresh tissue
Expand All @@ -1285,86 +1291,86 @@ def processing_fresh_tissue_step(
f"cannot process fresh tissue because missing running_functions attr"
)

'''(
self.ds_beads,
if reprocessing:
(
self.ds_beads,
self.ds_nuclei,
self.nuclei_metadata,
) = fov_processing.process_fresh_sample_graph(
self.experiment_fpath,
self.running_functions,
self.analysis_parameters,
self.client,
self.chunk_size,
tag_ref_beads=tag_ref_beads,
tag_nuclei=tag_nuclei,
eel_metadata=self.metadata,
fresh_tissue_segmentation_engine=self.fresh_tissue_segmentation_engine,
diameter_size=self.diameter_size,
parsing=parsing,
save_steps_output=self.save_intermediate_steps,
)

pickle.dump(
[
self.ds_beads,
self.ds_nuclei,
self.metadata,
],
open(
Path(self.experiment_fpath)
/ "fresh_tissue"
/ "segmentation"
/ "ds_tmp_data.pkl",
"wb",
),
)

(self.ds_beads, self.ds_nuclei, self.nuclei_metadata) = pickle.load(
open(
Path(self.experiment_fpath)
/ "fresh_tissue"
/ "segmentation"
/ "ds_tmp_data.pkl",
"rb",
),
)

# Segmentation
fov_processing.segmentation_graph(
self.ds_nuclei,
self.nuclei_metadata,
) = fov_processing.process_fresh_sample_graph(
self.chunk_size,
self.experiment_fpath,
self.fresh_tissue_segmentation_engine,
self.diameter_size,
)
(
self.nuclei_org_tiles,
self.nuclei_adjusted_coords,
) = stitching.stitched_beads_on_nuclei_fresh_tissue(
self.experiment_fpath,
self.running_functions,
self.analysis_parameters,
self.client,
self.chunk_size,
tag_ref_beads=tag_ref_beads,
tag_nuclei=tag_nuclei,
eel_metadata=self.metadata,
fresh_tissue_segmentation_engine=self.fresh_tissue_segmentation_engine,
diameter_size=self.diameter_size,
parsing=parsing,
save_steps_output=self.save_intermediate_steps,
)'''

# pickle.dump(
# [
# self.ds_beads,
# self.ds_nuclei,
# self.metadata,
# ],
# open(
# Path(self.experiment_fpath)
# / "fresh_tissue"
# / "segmentation"
# / "ds_tmp_data.pkl",
# "wb",
# ),
# )

# (self.ds_beads, self.ds_nuclei, self.nuclei_metadata) = pickle.load(
# open(
# Path(self.experiment_fpath)
# / "fresh_tissue"
# / "segmentation"
# / "ds_tmp_data.pkl",
# "rb",
# ),
# )

# # Segmentation
# fov_processing.segmentation_graph(
# self.ds_nuclei,
# self.chunk_size,
# self.experiment_fpath,
# self.fresh_tissue_segmentation_engine,
# self.diameter_size,
# )

# (
# self.nuclei_org_tiles,
# self.nuclei_adjusted_coords,
# ) = stitching.stitched_beads_on_nuclei_fresh_tissue(
# self.experiment_fpath,
# self.client,
# self.ds_nuclei,
# self.ds_beads,
# round_num=1,
# )

# pickle.dump(
# [
# self.ds_beads,
# self.ds_nuclei,
# self.nuclei_metadata,
# self.nuclei_org_tiles,
# self.nuclei_adjusted_coords,
# ],
# open(
# Path(self.experiment_fpath)
# / "fresh_tissue"
# / "segmentation"
# / "tmp_data.pkl",
# "wb",
# ),
# )
self.ds_nuclei,
self.ds_beads,
round_num=1,
)

pickle.dump(
[
self.ds_beads,
self.ds_nuclei,
self.nuclei_metadata,
self.nuclei_org_tiles,
self.nuclei_adjusted_coords,
],
open(
Path(self.experiment_fpath)
/ "fresh_tissue"
/ "segmentation"
/ "tmp_data.pkl",
"wb",
),
)

(
self.ds_beads,
Expand All @@ -1381,7 +1387,7 @@ def processing_fresh_tissue_step(
"rb",
),
)

segmentation_output_path = (
Path(self.experiment_fpath) / "fresh_tissue" / "segmentation"
)
Expand All @@ -1395,7 +1401,36 @@ def processing_fresh_tissue_step(
self.client,
self.min_overlapping_pixels_segmentation,
)
gc.collect()

def processing_assign_dots(self):
gc.collect()
segmentation_output_path = (
Path(self.experiment_fpath) / "fresh_tissue" / "segmentation"
)

(
self.ds_beads,
self.ds_nuclei,
self.nuclei_metadata,
self.nuclei_org_tiles,
self.nuclei_adjusted_coords,
) = pickle.load(
open(
Path(self.experiment_fpath)
/ "fresh_tissue"
/ "segmentation"
/ "tmp_data.pkl",
"rb",
),
)

segmented_object_dict_recalculated = pickle.load(
open(
segmentation_output_path / ("segmented_objects_dict_recalculated_ids.pkl"),
"rb",
),
)
segmentation.register_assign(
self.experiment_fpath,
segmented_object_dict_recalculated,
Expand All @@ -1407,7 +1442,6 @@ def processing_fresh_tissue_step(
segmentation_output_path,
self.max_expansion_radius,
self.hamming_distance,
centering_mode=centering_mode,
)

# --------------------------------
Expand Down Expand Up @@ -1663,6 +1697,7 @@ def run_full(self):

step_start = datetime.now()
self.processing_fresh_tissue_step()
self.processing_assign_dots()
self.logger.info(
f"{self.experiment_fpath.stem} timing: \
Processing fresh tissue completed in {utils.nice_deltastring(datetime.now() - step_start)}."
Expand Down
13 changes: 6 additions & 7 deletions pysmFISH/segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -484,12 +484,11 @@ def create_label_image(
),
)

'''zarr_fpath = segmentation_output_path / "image_segmented_labels.zarr"
zarr_fpath = segmentation_output_path / "image_segmented_labels.zarr"
store = zarr.DirectoryStore(zarr_fpath, "w")
grp = zarr.group(store=store, overwrite=True)
grp.create_dataset(name="segmented_labels_image", data=img)'''

np.save(os.path.join(segmentation_output_path,'segmented_labels_image.npy'),img)
grp.create_dataset(name="segmented_labels_image", data=img)
#np.save(os.path.join(segmentation_output_path,'segmented_labels_image.npy'),img)

return segmented_object_dict_recalculated

Expand Down Expand Up @@ -708,11 +707,11 @@ def register_assign(
source_RNA_df.loc[:, ["r_transformed", "c_transformed"]] = transformed_points

# Replace this with chunk loading in the expanding function
'''zarr_fpath = segmentation_output_path / "image_segmented_labels.zarr"
zarr_fpath = segmentation_output_path / "image_segmented_labels.zarr"
store = zarr.DirectoryStore(zarr_fpath, "r")
grp = zarr.group(store=store, overwrite=False)
segmented_img = grp["segmented_labels_image"][...]'''
segmented_img = np.load(os.path.join(segmentation_output_path,'segmented_labels_image.npy'))
segmented_img = grp["segmented_labels_image"][...]
#segmented_img = np.load(os.path.join(segmentation_output_path,'segmented_labels_image.npy'))

# instantiate model
CA = Cell_Assignment()
Expand Down
8 changes: 8 additions & 0 deletions pysmFISH/stitching.py
Original file line number Diff line number Diff line change
Expand Up @@ -1342,6 +1342,14 @@ def stitching_graph_fresh_nuclei(
k: v for (k, v) in all_registrations_dict.items() if np.all(np.abs(v[0]) < 20)
}

# Alejandro version
all_registrations_removed_large_shift = {}
for (k, v) in all_registrations_dict.items():
if np.all(np.abs(v[0]) < 20):
all_registrations_removed_large_shift[k] = v
else:
all_registrations_removed_large_shift[k] = [np.array([0,0]), 1.0]

cpls = all_registrations_removed_large_shift.keys()
# cpls = list(unfolded_overlapping_regions_dict.keys())
total_cpls = len(cpls)
Expand Down