diff --git a/pysmFISH/cell_assignment.py b/pysmFISH/cell_assignment.py index bf945a4..a105718 100644 --- a/pysmFISH/cell_assignment.py +++ b/pysmFISH/cell_assignment.py @@ -59,15 +59,22 @@ def expand_labels_mask(self, mask, distance, out_file_name='expanded', overlap_percentage = overlap_percentage / 100 overlap = int(chunk_size * overlap_percentage) - #Iterate over chunks and asign dots to cells + #Iterate over chunks for i, (view_min, view_max) in enumerate(sliding_window_view(chunks, 2)): print(i, view_min, view_max) if i == 0: view_min_extra = view_min view_max_extra = view_max + overlap + overlap_right = overlap + elif (view_max + overlap) > mask.shape[0]: + print('got here') + view_min_extra = view_min - overlap + view_max_extra = mask.shape[0] + overlap_right = mask.shape[0] - view_max else: view_min_extra = view_min - overlap - view_max_extra = view_max + overlap + view_max_extra = view_max + overlap + overlap_right = overlap #Expand labels img = mask[view_min_extra:view_max_extra, :] @@ -81,7 +88,14 @@ def expand_labels_mask(self, mask, distance, out_file_name='expanded', exp[view_min:, :] = img[overlap:] else: #Middle chunks - exp[view_min:view_max, :] = img[overlap:-overlap,:] + print(f' {overlap}, {overlap_right}') + print(f' {view_min}, {view_max}') + print(f' {view_min_extra}, {view_max_extra}') + print(f' {img.shape}') + print(f' {overlap}, {chunk_size+overlap}') + print(f' {img[overlap:chunk_size+overlap,:].shape}') + #exp[view_min:view_max, :] = img[overlap:-overlap_right,:] + exp[view_min:view_max, :] = img[overlap:chunk_size+overlap,:] return out_file_name diff --git a/pysmFISH/pipeline.py b/pysmFISH/pipeline.py index f0bc37b..abdca0d 100644 --- a/pysmFISH/pipeline.py +++ b/pysmFISH/pipeline.py @@ -131,6 +131,9 @@ class Pipeline: diameter_size (int): Size of the diameter of the cells to segment using cellpose min_overlapping_pixels_segmentation (int): Size of the overlapping label objects fov_alignement_mode (str): clip or merged (default 
clipped). + clip_size (int): only if fov_alignement_mode is merge it will clip the clip_size length of borders. + remove_distinct_genes (bool): when stitching it will also remove overlapping dots of different genes if set to + true. Defaults to False Attributes: @@ -257,6 +260,9 @@ def __init__( self.max_expansion_radius = kwarg.pop("max_expansion_radius", 18) self.fov_alignement_mode = kwarg.pop("fov_alignement_mode", "clip") + self.clip_size = kwarg.pop("clip_size", 0) + self.remove_distinct_genes = kwarg.pop("remove_distinct_genes", False) + # ----------------------------------- # PROCESSING STEPS @@ -1164,8 +1170,8 @@ def stitch_and_remove_dots_eel_graph_old_room_step(self): # # ---------------------------------------------------------------- def stitch_and_remove_dots_eel_graph_step(self, - remove_distinct_genes=False, - clip_size=0,): + remove_distinct_genes=self.remove_distinct_genes, + clip_size=self.clip_size,): """ Function to stitch the different fovs and remove the duplicated @@ -1249,8 +1255,8 @@ def stitch_and_remove_dots_eel_graph_step(self, matching_dot_radius=self.same_dot_radius_duplicate_dots, out_folder=folder, exp_name=self.metadata["experiment_name"], - remove_distinct_genes=remove_distinct_genes, - clip_size=clip_size, + remove_distinct_genes=self.remove_distinct_genes, + clip_size=self.clip_size, verbose=False, ) # Set to False in pipeline @@ -1260,9 +1266,9 @@ def stitch_and_remove_dots_eel_graph_step(self, def processing_fresh_tissue_step( self, parsing=True, + reprocessing=True, tag_ref_beads="_ChannelEuropium_Cy3_", tag_nuclei="_ChannelCy3_", - centering_mode='middle', ): """ This function create and run a processing graph that parse and filter the nuclei staining in fresh tissue @@ -1285,86 +1291,86 @@ def processing_fresh_tissue_step( f"cannot process fresh tissue because missing running_functions attr" ) - '''( - self.ds_beads, + if reprocessing: + ( + self.ds_beads, + self.ds_nuclei, + self.nuclei_metadata, + ) = 
fov_processing.process_fresh_sample_graph( + self.experiment_fpath, + self.running_functions, + self.analysis_parameters, + self.client, + self.chunk_size, + tag_ref_beads=tag_ref_beads, + tag_nuclei=tag_nuclei, + eel_metadata=self.metadata, + fresh_tissue_segmentation_engine=self.fresh_tissue_segmentation_engine, + diameter_size=self.diameter_size, + parsing=parsing, + save_steps_output=self.save_intermediate_steps, + ) + + pickle.dump( + [ + self.ds_beads, + self.ds_nuclei, + self.metadata, + ], + open( + Path(self.experiment_fpath) + / "fresh_tissue" + / "segmentation" + / "ds_tmp_data.pkl", + "wb", + ), + ) + + (self.ds_beads, self.ds_nuclei, self.nuclei_metadata) = pickle.load( + open( + Path(self.experiment_fpath) + / "fresh_tissue" + / "segmentation" + / "ds_tmp_data.pkl", + "rb", + ), + ) + + # Segmentation + fov_processing.segmentation_graph( self.ds_nuclei, - self.nuclei_metadata, - ) = fov_processing.process_fresh_sample_graph( + self.chunk_size, + self.experiment_fpath, + self.fresh_tissue_segmentation_engine, + self.diameter_size, + ) + ( + self.nuclei_org_tiles, + self.nuclei_adjusted_coords, + ) = stitching.stitched_beads_on_nuclei_fresh_tissue( self.experiment_fpath, - self.running_functions, - self.analysis_parameters, self.client, - self.chunk_size, - tag_ref_beads=tag_ref_beads, - tag_nuclei=tag_nuclei, - eel_metadata=self.metadata, - fresh_tissue_segmentation_engine=self.fresh_tissue_segmentation_engine, - diameter_size=self.diameter_size, - parsing=parsing, - save_steps_output=self.save_intermediate_steps, - )''' - - # pickle.dump( - # [ - # self.ds_beads, - # self.ds_nuclei, - # self.metadata, - # ], - # open( - # Path(self.experiment_fpath) - # / "fresh_tissue" - # / "segmentation" - # / "ds_tmp_data.pkl", - # "wb", - # ), - # ) - - # (self.ds_beads, self.ds_nuclei, self.nuclei_metadata) = pickle.load( - # open( - # Path(self.experiment_fpath) - # / "fresh_tissue" - # / "segmentation" - # / "ds_tmp_data.pkl", - # "rb", - # ), - # ) - - # # 
Segmentation - # fov_processing.segmentation_graph( - # self.ds_nuclei, - # self.chunk_size, - # self.experiment_fpath, - # self.fresh_tissue_segmentation_engine, - # self.diameter_size, - # ) - - # ( - # self.nuclei_org_tiles, - # self.nuclei_adjusted_coords, - # ) = stitching.stitched_beads_on_nuclei_fresh_tissue( - # self.experiment_fpath, - # self.client, - # self.ds_nuclei, - # self.ds_beads, - # round_num=1, - # ) - - # pickle.dump( - # [ - # self.ds_beads, - # self.ds_nuclei, - # self.nuclei_metadata, - # self.nuclei_org_tiles, - # self.nuclei_adjusted_coords, - # ], - # open( - # Path(self.experiment_fpath) - # / "fresh_tissue" - # / "segmentation" - # / "tmp_data.pkl", - # "wb", - # ), - # ) + self.ds_nuclei, + self.ds_beads, + round_num=1, + ) + + pickle.dump( + [ + self.ds_beads, + self.ds_nuclei, + self.nuclei_metadata, + self.nuclei_org_tiles, + self.nuclei_adjusted_coords, + ], + open( + Path(self.experiment_fpath) + / "fresh_tissue" + / "segmentation" + / "tmp_data.pkl", + "wb", + ), + ) ( self.ds_beads, @@ -1381,7 +1387,7 @@ def processing_fresh_tissue_step( "rb", ), ) - + segmentation_output_path = ( Path(self.experiment_fpath) / "fresh_tissue" / "segmentation" ) @@ -1395,7 +1401,36 @@ def processing_fresh_tissue_step( self.client, self.min_overlapping_pixels_segmentation, ) + gc.collect() + + def processing_assign_dots(self): + gc.collect() + segmentation_output_path = ( + Path(self.experiment_fpath) / "fresh_tissue" / "segmentation" + ) + + ( + self.ds_beads, + self.ds_nuclei, + self.nuclei_metadata, + self.nuclei_org_tiles, + self.nuclei_adjusted_coords, + ) = pickle.load( + open( + Path(self.experiment_fpath) + / "fresh_tissue" + / "segmentation" + / "tmp_data.pkl", + "rb", + ), + ) + segmented_object_dict_recalculated = pickle.load( + open( + segmentation_output_path / ("segmented_objects_dict_recalculated_ids.pkl"), + "rb", + ), + ) segmentation.register_assign( self.experiment_fpath, segmented_object_dict_recalculated, @@ -1407,7 +1442,6 @@ 
def processing_fresh_tissue_step( segmentation_output_path, self.max_expansion_radius, self.hamming_distance, - centering_mode=centering_mode, ) # -------------------------------- @@ -1663,6 +1697,7 @@ def run_full(self): step_start = datetime.now() self.processing_fresh_tissue_step() + self.processing_assign_dots() self.logger.info( f"{self.experiment_fpath.stem} timing: \ Processing fresh tissue completed in {utils.nice_deltastring(datetime.now() - step_start)}." diff --git a/pysmFISH/segmentation.py b/pysmFISH/segmentation.py index ca71c21..8f46d5b 100644 --- a/pysmFISH/segmentation.py +++ b/pysmFISH/segmentation.py @@ -484,12 +484,11 @@ def create_label_image( ), ) - '''zarr_fpath = segmentation_output_path / "image_segmented_labels.zarr" + zarr_fpath = segmentation_output_path / "image_segmented_labels.zarr" store = zarr.DirectoryStore(zarr_fpath, "w") grp = zarr.group(store=store, overwrite=True) - grp.create_dataset(name="segmented_labels_image", data=img)''' - - np.save(os.path.join(segmentation_output_path,'segmented_labels_image.npy'),img) + grp.create_dataset(name="segmented_labels_image", data=img) + #np.save(os.path.join(segmentation_output_path,'segmented_labels_image.npy'),img) return segmented_object_dict_recalculated @@ -708,11 +707,11 @@ def register_assign( source_RNA_df.loc[:, ["r_transformed", "c_transformed"]] = transformed_points # Replace this with chunk loading in the expanding function - '''zarr_fpath = segmentation_output_path / "image_segmented_labels.zarr" + zarr_fpath = segmentation_output_path / "image_segmented_labels.zarr" store = zarr.DirectoryStore(zarr_fpath, "r") grp = zarr.group(store=store, overwrite=False) - segmented_img = grp["segmented_labels_image"][...]''' - segmented_img = np.load(os.path.join(segmentation_output_path,'segmented_labels_image.npy')) + segmented_img = grp["segmented_labels_image"][...] 
+ #segmented_img = np.load(os.path.join(segmentation_output_path,'segmented_labels_image.npy')) # instantiate model CA = Cell_Assignment() diff --git a/pysmFISH/stitching.py b/pysmFISH/stitching.py index a4f6dc4..65c1b19 100644 --- a/pysmFISH/stitching.py +++ b/pysmFISH/stitching.py @@ -1342,6 +1342,14 @@ def stitching_graph_fresh_nuclei( k: v for (k, v) in all_registrations_dict.items() if np.all(np.abs(v[0]) < 20) } + # Alejandro version + all_registrations_removed_large_shift = {} + for (k, v) in all_registrations_dict.items(): + if np.all(np.abs(v[0]) < 20): + all_registrations_removed_large_shift[k] = v + else: + all_registrations_removed_large_shift[k] = [np.array([0,0]), 1.0] + cpls = all_registrations_removed_large_shift.keys() # cpls = list(unfolded_overlapping_regions_dict.keys()) total_cpls = len(cpls)