diff --git a/deepcell_spots/applications/polaris.py b/deepcell_spots/applications/polaris.py
index 34e34c4..3c4c3dc 100644
--- a/deepcell_spots/applications/polaris.py
+++ b/deepcell_spots/applications/polaris.py
@@ -298,8 +298,11 @@ def predict(self,
             decoding_result = self.decoding_app.predict(
                 spots_intensities_vec, **decoding_training_kwargs)
         else:
-            decoding_result = {'probability': None,
-                               'predicted_id': None, 'predicted_name': None}
+            decoding_result = {'spot_index': None,
+                               'probability': None,
+                               'predicted_id': None,
+                               'predicted_name': None,
+                               'source': None}

         df_spots = output_to_df(spots_locations_vec, spots_cell_assignments_vec, decoding_result)
         df_intensities = pd.DataFrame(spots_intensities_vec)
diff --git a/deepcell_spots/applications/spot_decoding.py b/deepcell_spots/applications/spot_decoding.py
index e8b6bf7..5719696 100644
--- a/deepcell_spots/applications/spot_decoding.py
+++ b/deepcell_spots/applications/spot_decoding.py
@@ -56,8 +56,8 @@ class SpotDecoding(Application):

     Args:
         df_barcodes (pandas.DataFrame): Codebook, the first column is gene names ('Gene'),
-            the rest are binary barcodes, encoded using 1 and 0. Index should start at 1.
-            For exmaple, for a (rounds=10, channels=2) codebook, it should look like::
+            the rest are binary barcodes, encoded using 1 and 0. Index should start at 1.
+            For example, for a (rounds=10, channels=2) codebook, it should look like::

            Index:
                RangeIndex (starting from 1)
@@ -96,7 +96,6 @@ def __init__(self,

         self._validate_codebook(df_barcodes)
         self.df_barcodes = self._add_bkg_unknown_to_barcodes(df_barcodes)
-

         super(SpotDecoding, self).__init__(
             model=None,
@@ -113,8 +112,8 @@ def _validate_codebook(self, df_barcodes):

         Args:
             df_barcodes (pandas.DataFrame): Codebook, the first column is gene names ('Gene'),
-                the rest are binary barcodes, encoded using 1 and 0. Index should start at 1.
-                For exmaple, for a (rounds=10, channels=2) codebook, it should look like::
+                the rest are binary barcodes, encoded using 1 and 0. Index should start at 1.
+                For example, for a (rounds=10, channels=2) codebook, it should look like::

                Index:
                    RangeIndex (starting from 1)
@@ -148,11 +147,25 @@ def _validate_codebook(self, df_barcodes):
                 'These values will be added automatically.')

     def _add_bkg_unknown_to_barcodes(self, df_barcodes):
-        """Add Background and Unknown category to the codebook. The barcode of Background
+        """Add Background and Unknown barcodes to the codebook. The barcode of Background
         is all zeros and the barcode for Unknown is all -1s.

         Args:
-            df_barcodes (pd.DataFrame): The codebook initialized by users.
+            df_barcodes (pd.DataFrame): Codebook, the first column is gene names ('Gene'),
+                the rest are binary barcodes, encoded using 1 and 0. Index should start at 1.
+                For example, for a (rounds=10, channels=2) codebook, it should look like::
+
+                Index:
+                    RangeIndex (starting from 1)
+                Columns:
+                    Name: Gene, dtype: object
+                    Name: r0c0, dtype: int64
+                    Name: r0c1, dtype: int64
+                    Name: r1c0, dtype: int64
+                    Name: r1c1, dtype: int64
+                    ...
+                    Name: r9c0, dtype: int64
+                    Name: r9c1, dtype: int64

         Returns:
             pd.DataFrame: The augmented codebook.
@@ -188,16 +201,20 @@ def _decoding_output_to_dict(self, out):
             out (dict): Dictionary with keys: 'class_probs', 'params'.

         Returns:
-            dict: Dictionary with keys: 'probability', 'predicted_id', 'predicted_name'.
+            dict: Dictionary with keys: 'spot_index', 'probability', 'predicted_id',
+                'predicted_name', 'source'.
""" barcodes_idx2name = dict( zip(1 + np.arange(len(self.df_barcodes)), self.df_barcodes.Gene.values)) decoded_dict = {} decoded_dict['probability'] = out['class_probs'].max(axis=1) + decoded_dict['spot_index'] = np.arange(len(decoded_dict['probability'])) decoded_dict['predicted_id'] = out['class_probs'].argmax(axis=1) + 1 decoded_dict['predicted_name'] = np.array( list(map(barcodes_idx2name.get, decoded_dict['predicted_id']))) + decoded_dict['source'] = np.repeat( + 'prediction', len(decoded_dict['probability'])).astype('U25') return decoded_dict def _threshold_unknown_by_prob(self, decoded_dict, unknown_index, thres_prob=0.5): @@ -206,7 +223,7 @@ def _threshold_unknown_by_prob(self, decoded_dict, unknown_index, thres_prob=0.5 Args: decoded_dict (dict): Dictionary containing decoded spot identities with - keys: 'probability', 'predicted_id', 'predicted_name'. + keys: 'spot_index', 'probability', 'predicted_id', 'predicted_name', 'source'. unknown_index (int): The index for Unknown category. Returns: @@ -217,18 +234,22 @@ def _threshold_unknown_by_prob(self, decoded_dict, unknown_index, thres_prob=0.5 decoded_dict['predicted_name'][decoded_dict['probability'] < thres_prob] = 'Unknown' return decoded_dict - def _rescue_spots(self, - decoding_dict_trunc, - spots_intensities_vec): + def _rescue_errors(self, + decoding_dict, + spots_intensities_vec): """Rescues decoded spots assigned as 'Background' or 'Unknown' by if their spot probability values have a Hamming distance of 1 from each of the barcodes. Args: - decoding_dict_trunc (dict): Dictionary containing decoded spot identities with - keys: 'probability', 'predicted_id', 'predicted_name'. This dictionary has already - been processed to assign low probability predictions to 'Unknown'. + decoding_dict (dict): Dictionary containing decoded spot identities with + keys: 'spot_index', 'probability', 'predicted_id', 'predicted_name', 'source'. This + dictionary has already been processed to assign low probability predictions + to 'Unknown'. spots_intensities_vec (numpy.array): Array of spot probability values with shape [num_spots, r*c]. + Returns: + dict: Dictionary with keys: 'spot_index', 'probability', 'predicted_id', + 'predicted_name', 'source'. 
""" ch_names = list(self.df_barcodes.columns) @@ -237,35 +258,139 @@ def _rescue_spots(self, num_barcodes = barcodes_array.shape[0] barcode_len = barcodes_array.shape[1] - predicted_ids = decoding_dict_trunc['predicted_id'] - predicted_names = decoding_dict_trunc['predicted_name'] + predicted_ids = decoding_dict['predicted_id'] + predicted_names = decoding_dict['predicted_name'] + sources = decoding_dict['source'] + attempted = 0 + successful = 0 for i,pred in tqdm(enumerate(predicted_names)): if pred in ['Background', 'Unknown']: + attempted += 1 dist_list = np.zeros(num_barcodes) for ii in range(num_barcodes): dist_list[ii] = distance.hamming(np.round(spots_intensities_vec[i]), barcodes_array[ii]) scaled_dist_list = dist_list * barcode_len if 1 in scaled_dist_list: + successful += 1 + new_gene = np.argwhere(scaled_dist_list == 1)[0][0] - predicted_ids[i] = new_gene + # gene ids are 1-indexed + predicted_ids[i] = new_gene + 1 predicted_names[i] = self.df_barcodes['Gene'].values[new_gene] + sources[i] = 'error rescue' - decoding_dict_rescued = { + result = { + 'spot_index': decoding_dict['spot_index'], 'predicted_id': predicted_ids, 'predicted_name': predicted_names, - 'probability': decoding_dict_trunc['probability'] + 'probability': decoding_dict['probability'], + 'source': sources } + + print('{} of {} rescue attempts were successful.'.format(successful, attempted)) - return(decoding_dict_rescued) + return(result) + + def _rescue_mixed_spots(self, + decoding_dict, + spots_intensities_vec, + prob_threshold=0.95): + """Rescues decoded spots assigned as 'Background' or 'Unknown' by if their spot + probability values have a Hamming distance of 1 from each of the barcodes. + + Args: + decoding_dict (dict): Dictionary containing decoded spot identities with + keys: 'spot_index', 'probability', 'predicted_id', 'predicted_name', 'source'. + spots_intensities_vec (numpy.array): Array of spot probability values with shape + [num_spots, r*c]. + df_barcodes (pd.DataFrame): Codebook, the first column is gene names ('Gene'), + the rest are binary barcodes, encoded using 1 and 0. Index should start at 1. + For exmaple, for a (rounds=10, channels=2) codebook, it should look like:: + + Index: + RangeIndex (starting from 1) + Columns: + Name: Gene, dtype: object + Name: r0c0, dtype: int64 + Name: r0c1, dtype: int64 + Name: r1c0, dtype: int64 + Name: r1c1, dtype: int64 + ... + Name: r9c0, dtype: int64 + Name: r9c1, dtype: int64 + + Returns: + dict: Dictionary with keys: 'spot_index', 'probability', 'predicted_id', + 'predicted_name', 'source'. 
+        """
+
+        ch_names = list(self.df_barcodes.columns)
+        ch_names.remove('Gene')
+        barcodes_array = self.df_barcodes[ch_names].values
+        num_barcodes = barcodes_array.shape[0]
+        barcode_len = barcodes_array.shape[1]
+
+        spot_indices = decoding_dict['spot_index']
+        predicted_ids = decoding_dict['predicted_id']
+        predicted_names = decoding_dict['predicted_name']
+        probabilities = decoding_dict['probability']
+        sources = decoding_dict['source']
+
+        attempted = 0
+        successful = 0
+        for i,prob in tqdm(enumerate(probabilities)):
+            if prob < prob_threshold:
+                attempted += 1
+                gene_id = predicted_ids[i]
+                if gene_id > num_barcodes-2:
+                    continue
+
+                # gene ids are 1-indexed
+                barcode = barcodes_array[gene_id-1]
+                intensities_updated = spots_intensities_vec[i].copy()
+                intensities_updated[barcode==1] = 0
+
+                dist_list = np.zeros(num_barcodes)
+                for ii in range(num_barcodes):
+                    dist_list[ii] = distance.hamming(np.round(intensities_updated),
+                                                     barcodes_array[ii])
+                scaled_dist_list = dist_list * barcode_len
+                if 1 in scaled_dist_list:
+                    successful += 1
+                    spot_indices = np.append(spot_indices, [spot_indices[i]])
+
+                    new_id = np.argwhere(scaled_dist_list == 1)[0][0]
+                    # gene ids are 1-indexed
+                    predicted_ids = np.append(predicted_ids, [new_id + 1])
+
+                    new_name = self.df_barcodes['Gene'].values[new_id]
+                    predicted_names = np.append(predicted_names, [new_name])
+
+                    probabilities = np.append(probabilities, [-1])
+
+                    sources = np.append(sources, 'mixed rescue')
+
+        result = {
+            'spot_index': spot_indices,
+            'predicted_id': predicted_ids,
+            'predicted_name': predicted_names,
+            'probability': probabilities,
+            'source': sources
+        }
+
+        print('{} of {} rescue attempts were successful.'.format(successful, attempted))
+
+        return(result)

     def _predict(self,
                  spots_intensities_vec,
                  num_iter,
                  batch_size,
                  thres_prob,
-                 rescue_spots):
+                 rescue_errors,
+                 rescue_mixed):
         """Predict the gene assignment of each spot.

         Args:
@@ -274,11 +399,14 @@ def _predict(self,
             num_iter (int): Number of iterations for training. Defaults to 500.
             batch_size (int): Size of batches for training. Defaults to 1000.
             thres_prob (float): The threshold of unknown category, within [0,1]. Defaults to 0.5.
-            rescue_spots (bool): Whether to check if 'Background'- and 'Unknown'-assigned spots
+            rescue_errors (bool): Whether to check if 'Background'- and 'Unknown'-assigned spots
                 have a Hamming distance of 1 to other barcodes.
+            rescue_mixed (bool): Whether to check if low probability predictions are the result of
+                two mixed barcodes.

         Returns:
-            dict: Dictionary with keys: 'probability', 'predicted_id', 'predicted_name'.
+            dict: Dictionary with keys: 'spot_index', 'probability', 'predicted_id',
+                'predicted_name', 'source'.
""" self._validate_spots_intensities(spots_intensities_vec) @@ -303,19 +431,24 @@ def _predict(self, decoding_dict_trunc = self._threshold_unknown_by_prob( decoding_dict, unknown_index, thres_prob=thres_prob) - if rescue_spots: - decoding_dict_rescued = self._rescue_spots(decoding_dict_trunc, - spots_intensities_vec) - return decoding_dict_rescued - else: - return decoding_dict_trunc + if rescue_errors: + print('Revising errors...') + decoding_dict_trunc = self._rescue_errors(decoding_dict_trunc, + spots_intensities_vec) + if rescue_mixed: + print('Correcting mixed barcodes...') + decoding_dict_trunc = self._rescue_mixed_spots(decoding_dict_trunc, + spots_intensities_vec) + + return decoding_dict_trunc def predict(self, spots_intensities_vec, num_iter=500, batch_size=1000, thres_prob=0.5, - rescue_spots=True): + rescue_errors=True, + rescue_mixed=False): """Predict the gene assignment of each spot. Args: @@ -324,11 +457,14 @@ def predict(self, num_iter (int): Number of iterations for training. Defaults to 500. batch_size (int): Size of batches for training. Defaults to 1000. thres_prob (float): The threshold of unknown category, within [0,1]. Defaults to 0.5. - rescue_spots (bool): Whether to check if 'Background' and 'Unknown' assigned spots + rescue_errors (bool): Whether to check if 'Background'- and 'Unknown'-assigned spots have a Hamming distance of 1 to other barcodes. + rescue_mixed (bool): Whether to check if low probability predictions are the result of + two mixed barcodes. Returns: - dict: Dictionary with keys: 'probability', 'predicted_id', 'predicted_name'. + dict: Dictionary with keys: 'spot_index', 'probability', 'predicted_id', + 'predicted_name', 'source'. """ return self._predict( @@ -336,4 +472,5 @@ def predict(self, num_iter=num_iter, batch_size=batch_size, thres_prob=thres_prob, - rescue_spots=rescue_spots) + rescue_errors=rescue_errors, + rescue_mixed=rescue_mixed) diff --git a/deepcell_spots/applications/spot_decoding_test.py b/deepcell_spots/applications/spot_decoding_test.py index 39c0ea8..c1de046 100644 --- a/deepcell_spots/applications/spot_decoding_test.py +++ b/deepcell_spots/applications/spot_decoding_test.py @@ -59,9 +59,11 @@ def test_spot_decoding_app(self): decoding_dict_trunc1 = app1.predict( spots_intensities_vec=spots_intensities_vec1, num_iter=20, batch_size=100 ) + self.assertEqual(decoding_dict_trunc1["spot_index"].shape, (100,)) self.assertEqual(decoding_dict_trunc1["probability"].shape, (100,)) self.assertEqual(decoding_dict_trunc1["predicted_id"].shape, (100,)) self.assertEqual(decoding_dict_trunc1["predicted_name"].shape, (100,)) + self.assertEqual(decoding_dict_trunc1["source"].shape, (100,)) # simple functionality test df_barcodes2 = pd.DataFrame( @@ -188,3 +190,17 @@ def test_spot_decoding_app(self): spots_intensities_vec[0,0] = 0.5 with self.assertRaises(ValueError): _ = app.predict(spots_intensities_vec=spots_intensities_vec, num_iter=20, batch_size=100) + + # Test mixed rescue + app = SpotDecoding(df_barcodes=df_barcodes1, rounds=2, channels=3, + distribution='Relaxed Bernoulli', params_mode='2*R*C') + + spots_intensities_vec = np.random.rand(100, 6) + decoding_dict = app.predict(spots_intensities_vec=spots_intensities_vec, num_iter=20, batch_size=100, + rescue_mixed=True) + + self.assertGreaterEqual(decoding_dict["spot_index"].shape, (100,)) + self.assertGreaterEqual(decoding_dict["probability"].shape, (100,)) + self.assertGreaterEqual(decoding_dict["predicted_id"].shape, (100,)) + 
+        self.assertGreaterEqual(decoding_dict["predicted_name"].shape, (100,))
+        self.assertGreaterEqual(decoding_dict["source"].shape, (100,))
diff --git a/deepcell_spots/decoding_functions.py b/deepcell_spots/decoding_functions.py
index ca5da43..a325df5 100644
--- a/deepcell_spots/decoding_functions.py
+++ b/deepcell_spots/decoding_functions.py
@@ -603,9 +603,11 @@ def decoding_function(spots,
               optim,
               loss=TraceEnum_ELBO(max_plate_nesting=1))

     pyro.set_rng_seed(set_seed)
+    print('Training...')
     losses = train(svi, num_iter, data, codes, c, r, min(num_spots, batch_size),
                    distribution, params_mode)
+    print('Estimating barcode probabilities...')

     if distribution=='Gaussian':
         w_star = pyro.param('weights').detach()
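
Reviewer note: the snippet below is a minimal, illustrative sketch of the updated SpotDecoding.predict API introduced by this change. The constructor arguments mirror the ones used in spot_decoding_test.py; the toy codebook and random intensities are invented for illustration and are not part of the patch.

    import numpy as np
    import pandas as pd

    from deepcell_spots.applications.spot_decoding import SpotDecoding

    # Toy (rounds=2, channels=3) codebook: 'Gene' first, then r*c* columns, index starting at 1.
    df_barcodes = pd.DataFrame(
        {
            'Gene': ['geneA', 'geneB', 'geneC'],
            'r0c0': [1, 0, 0], 'r0c1': [0, 1, 0], 'r0c2': [0, 0, 1],
            'r1c0': [1, 0, 0], 'r1c1': [0, 1, 0], 'r1c2': [0, 0, 1],
        },
        index=np.arange(1, 4))

    app = SpotDecoding(df_barcodes=df_barcodes, rounds=2, channels=3,
                       distribution='Relaxed Bernoulli', params_mode='2*R*C')

    # Random per-round/channel spot probabilities stand in for real detections.
    spots_intensities_vec = np.random.rand(100, 6)
    decoding_dict = app.predict(spots_intensities_vec=spots_intensities_vec,
                                num_iter=20, batch_size=100,
                                rescue_errors=True, rescue_mixed=True)

    # The result now carries 'spot_index' and 'source' alongside 'probability',
    # 'predicted_id', and 'predicted_name'; 'source' is 'prediction', 'error rescue',
    # or 'mixed rescue', and mixed rescues append extra entries with probability -1.
    print({key: value.shape for key, value in decoding_dict.items()})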
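Reviewer note: the standalone helper below restates the distance-1 rule shared by _rescue_errors and _rescue_mixed_spots, purely as an illustration (it is not part of the package). It counts mismatched rounds/channels directly, which for binary barcodes is equivalent to the scipy distance.hamming value scaled by the barcode length used in spot_decoding.py above.

    import numpy as np

    def hamming1_candidate(spot_probs, barcodes_array, gene_names):
        """Return the first gene whose barcode differs from the rounded spot
        probabilities in exactly one round/channel, or None if no barcode does."""
        rounded = np.round(spot_probs)
        # Number of mismatched bits per barcode (equivalent to
        # distance.hamming(rounded, barcode) * barcode_len for binary barcodes).
        mismatches = np.sum(rounded != barcodes_array, axis=1)
        hits = np.argwhere(mismatches == 1)
        if len(hits) > 0:
            return gene_names[hits[0][0]]
        return None

    barcodes = np.array([[1, 0, 0, 1, 0, 0],
                         [0, 1, 0, 0, 1, 0]])
    # Rounded vector [1, 0, 0, 1, 0, 1] differs from the first barcode in one position.
    print(hamming1_candidate(np.array([0.9, 0.1, 0.2, 0.8, 0.1, 0.7]),
                             barcodes, ['geneA', 'geneB']))  # -> 'geneA'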