def assign_barcodes(df_spots, segmentation_results):
    """Assigns a barcode identity to each segmented cell for Polaris predictions on data
    from optical pooled screens. This function does not support multi-batch inputs.

    For every cell, spots labelled `Background` or `Unknown` are discarded, the remaining
    spots are grouped by `predicted_name`, and the barcode with the largest summed
    `probability` is assigned to the cell. Ties go to the barcode that appears first in
    `df_spots`.

    Args:
        df_spots (pandas.DataFrame): Polaris result, columns are `x`, `y`, `batch_id`,
            `cell_id`, `probability`, `predicted_id`, `predicted_name`, `spot_index`,
            `source`, and `masked`. `batch_id` should only have one unique value.
        segmentation_results (numpy.array): Segmentation result from Polaris with shape
            `(1,x,y,1)`. Pixel values should match `cell_id` values in `df_spots`. The
            background pixels are assumed to have the value 0.

    Returns:
        pandas.DataFrame: Barcode assignment for each cell, columns are `cell_id`, `x`, `y`,
            `predicted_name`, `predicted_id`, `spot_counts`, `spot_fraction`. There is one
            row per positive integer label that is actually present in
            `segmentation_results`, in ascending `cell_id` order; labels missing from the
            mask (gaps in the numbering) do not produce rows. `x` and `y` are the centroid
            of the cell with value `cell_id` in `segmentation_results`. `predicted_name`
            and `predicted_id` correspond to the assigned barcode. `spot_counts` is the
            number of spots detected in a cell with the assigned barcode. `spot_fraction`
            is the fraction of detections in a cell with the assigned barcode. Cells
            without any barcode spot get `predicted_name='None'`, `predicted_id=-1`,
            `spot_counts=0` and `spot_fraction=0`.

    Raises:
        ValueError: If `segmentation_results` is not 4-dimensional.
        ValueError: If the batch dimension of `segmentation_results` is not of size 1.
    """
    columns = ['cell_id', 'x', 'y', 'predicted_name', 'predicted_id',
               'spot_counts', 'spot_fraction']

    if len(segmentation_results.shape) != 4:
        raise ValueError('Input data must have {} dimensions. '
                         'Input data only has {} dimensions'.format(
                             4, len(segmentation_results.shape)))
    if segmentation_results.shape[0] != 1:
        raise ValueError('Input data must have a batch dimension of size 1. '
                         'Input data only has a batch dimension of size {}.'.format(
                             segmentation_results.shape[0]))

    # Centroids of all cells in one pass over the foreground pixels, instead of scanning
    # the whole image once per cell. `inverse` maps each foreground pixel to the index of
    # its label in `labels`, so bincount accumulates the coordinate sums per label.
    foreground = segmentation_results > 0
    pixel_coords = np.nonzero(foreground)
    labels, inverse, pixel_counts = np.unique(
        segmentation_results[foreground], return_inverse=True, return_counts=True)
    inverse = np.ravel(inverse)
    x_centroids = np.bincount(inverse, weights=pixel_coords[1]) / pixel_counts
    y_centroids = np.bincount(inverse, weights=pixel_coords[2]) / pixel_counts

    # Drop non-barcode detections once and split the remaining spots by cell, so the
    # spot table is not re-filtered for every cell.
    df_barcoded = df_spots.loc[~df_spots.predicted_name.isin(['Background', 'Unknown'])]
    spots_by_cell = {cell_id: df_cell
                     for cell_id, df_cell in df_barcoded.groupby('cell_id', sort=False)}

    cells = list(zip(labels.tolist(), x_centroids.tolist(), y_centroids.tolist()))

    rows = []
    for label, x, y in tqdm(cells):
        # Only integer-valued pixel values can correspond to a `cell_id`.
        if not float(label).is_integer():
            continue
        cell_id = int(label)

        df_cell = spots_by_cell.get(cell_id)
        if df_cell is None:
            # Cell exists in the mask but contains no barcode spot.
            rows.append([cell_id, x, y, 'None', -1, 0, 0])
            continue

        # sort=False keeps first-appearance order so that ties resolve to the barcode
        # encountered first, as idxmax returns the first maximum.
        barcode_scores = df_cell.groupby('predicted_name', sort=False)['probability'].sum()
        assignment = barcode_scores.idxmax()
        df_correct = df_cell.loc[df_cell.predicted_name == assignment]
        counts = len(df_correct)
        rows.append([cell_id, x, y, assignment, df_correct.predicted_id.values[0],
                     counts, counts / len(df_cell)])

    # Building the frame once avoids the quadratic cost of row-by-row `.loc` enlargement
    # and gives the columns proper numeric dtypes.
    return pd.DataFrame(rows, columns=columns)
def test_assign_barcodes(self):
    """Exercise `assign_barcodes` on single-spot, multi-spot and two-cell inputs, and
    check that malformed segmentation shapes are rejected."""
    spot_columns = ['x', 'y', 'batch_id', 'cell_id', 'probability', 'predicted_id',
                    'predicted_name', 'spot_index', 'source', 'masked']

    # Three 'A' spots and two 'B' spots, all inside cell 1.
    mixed_spots = [
        [8, 8, 0, 1, 0.95, 1, 'A', 0, 'prediction', 0],
        [9, 9, 0, 1, 0.95, 1, 'A', 0, 'prediction', 0],
        [10, 10, 0, 1, 0.95, 1, 'A', 0, 'prediction', 0],
        [11, 11, 0, 1, 0.9, 2, 'B', 0, 'prediction', 0],
        [12, 12, 0, 1, 0.9, 2, 'B', 0, 'prediction', 0]
    ]

    def check_row(result, row, name, barcode_id, counts, fraction):
        # Compare one row of the assignment table against the expected barcode call.
        self.assertEqual(result.predicted_name.values[row], name)
        self.assertEqual(result.predicted_id.values[row], barcode_id)
        self.assertEqual(result.spot_counts.values[row], counts)
        self.assertEqual(result.spot_fraction.values[row], fraction)

    # Test one spot example
    df_spots = pd.DataFrame(
        [[10, 10, 0, 1, 0.95, 1, 'A', 0, 'prediction', 0]],
        columns=spot_columns
    )
    single_cell_mask = np.ones((1, 20, 20, 1))
    df_assignments = assign_barcodes(df_spots, single_cell_mask)
    self.assertEqual(len(df_assignments), 1)
    check_row(df_assignments, 0, 'A', 1, 1, 1)

    # Test multi spot example
    df_spots = pd.DataFrame(mixed_spots, columns=spot_columns)
    single_cell_mask = np.ones((1, 20, 20, 1))
    df_assignments = assign_barcodes(df_spots, single_cell_mask)
    self.assertEqual(len(df_assignments), 1)
    check_row(df_assignments, 0, 'A', 1, 3, 0.6)

    # Test cell with no spots/two cells
    df_spots = pd.DataFrame(mixed_spots, columns=spot_columns)
    two_cell_mask = np.ones((1, 20, 20, 1))
    # Relabel the first image row as cell 2; none of the spots fall in that cell.
    two_cell_mask[0, 0] = two_cell_mask[0, 0] + 1
    df_assignments = assign_barcodes(df_spots, two_cell_mask)
    self.assertEqual(len(df_assignments), 2)
    check_row(df_assignments, 1, 'None', -1, 0, 0)

    # Test raises errors
    for bad_shape in ((1, 20, 20), (2, 20, 20, 1)):
        with self.assertRaises(ValueError):
            _ = assign_barcodes(df_spots, np.ones(bad_shape))