Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add mixed barcode rescue to SpotDecoding #56

Merged
merged 7 commits into from
May 19, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions deepcell_spots/applications/polaris.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,8 +298,11 @@ def predict(self,
decoding_result = self.decoding_app.predict(
spots_intensities_vec, **decoding_training_kwargs)
else:
decoding_result = {'probability': None,
'predicted_id': None, 'predicted_name': None}
decoding_result = {'spot_index': None,
'probability': None,
'predicted_id': None,
'predicted_name': None,
'source': None}

df_spots = output_to_df(spots_locations_vec, spots_cell_assignments_vec, decoding_result)
df_intensities = pd.DataFrame(spots_intensities_vec)
Expand Down
205 changes: 171 additions & 34 deletions deepcell_spots/applications/spot_decoding.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,8 @@ class SpotDecoding(Application):

Args:
df_barcodes (pandas.DataFrame): Codebook, the first column is gene names ('Gene'),
the rest are binary barcodes, encoded using 1 and 0. Index should start at 1.
For example, for a (rounds=10, channels=2) codebook, it should look like::
the rest are binary barcodes, encoded using 1 and 0. Index should start at 1.
For example, for a (rounds=10, channels=2) codebook, it should look like::

Index:
RangeIndex (starting from 1)
Expand Down Expand Up @@ -96,7 +96,6 @@ def __init__(self,

self._validate_codebook(df_barcodes)
self.df_barcodes = self._add_bkg_unknown_to_barcodes(df_barcodes)


super(SpotDecoding, self).__init__(
model=None,
Expand All @@ -113,8 +112,8 @@ def _validate_codebook(self, df_barcodes):

Args:
df_barcodes (pandas.DataFrame): Codebook, the first column is gene names ('Gene'),
the rest are binary barcodes, encoded using 1 and 0. Index should start at 1.
For example, for a (rounds=10, channels=2) codebook, it should look like::
the rest are binary barcodes, encoded using 1 and 0. Index should start at 1.
For example, for a (rounds=10, channels=2) codebook, it should look like::

Index:
RangeIndex (starting from 1)
Expand Down Expand Up @@ -148,11 +147,25 @@ def _validate_codebook(self, df_barcodes):
'These values will be added automatically.')

def _add_bkg_unknown_to_barcodes(self, df_barcodes):
"""Add Background and Unknown category to the codebook. The barcode of Background
"""Add Background and Unknown barcodes to the codebook. The barcode of Background
is all zeros and the barcode for Unknown is all -1s.

Args:
df_barcodes (pd.DataFrame): The codebook initialized by users.
df_barcodes (pd.DataFrame): Codebook, the first column is gene names ('Gene'),
the rest are binary barcodes, encoded using 1 and 0. Index should start at 1.
For example, for a (rounds=10, channels=2) codebook, it should look like::

Index:
RangeIndex (starting from 1)
Columns:
Name: Gene, dtype: object
Name: r0c0, dtype: int64
Name: r0c1, dtype: int64
Name: r1c0, dtype: int64
Name: r1c1, dtype: int64
...
Name: r9c0, dtype: int64
Name: r9c1, dtype: int64

Returns:
pd.DataFrame: The augmented codebook.
Expand Down Expand Up @@ -188,16 +201,20 @@ def _decoding_output_to_dict(self, out):
out (dict): Dictionary with keys: 'class_probs', 'params'.

Returns:
dict: Dictionary with keys: 'probability', 'predicted_id', 'predicted_name'.
dict: Dictionary with keys: 'spot_index', 'probability', 'predicted_id',
'predicted_name', 'source'.

"""
barcodes_idx2name = dict(
zip(1 + np.arange(len(self.df_barcodes)), self.df_barcodes.Gene.values))
decoded_dict = {}
decoded_dict['probability'] = out['class_probs'].max(axis=1)
decoded_dict['spot_index'] = np.arange(len(decoded_dict['probability']))
decoded_dict['predicted_id'] = out['class_probs'].argmax(axis=1) + 1
decoded_dict['predicted_name'] = np.array(
list(map(barcodes_idx2name.get, decoded_dict['predicted_id'])))
decoded_dict['source'] = np.repeat(
'prediction', len(decoded_dict['probability'])).astype('U25')
return decoded_dict

def _threshold_unknown_by_prob(self, decoded_dict, unknown_index, thres_prob=0.5):
Expand All @@ -206,7 +223,7 @@ def _threshold_unknown_by_prob(self, decoded_dict, unknown_index, thres_prob=0.5

Args:
decoded_dict (dict): Dictionary containing decoded spot identities with
keys: 'probability', 'predicted_id', 'predicted_name'.
keys: 'spot_index', 'probability', 'predicted_id', 'predicted_name', 'source'.
unknown_index (int): The index for Unknown category.

Returns:
Expand All @@ -217,18 +234,22 @@ def _threshold_unknown_by_prob(self, decoded_dict, unknown_index, thres_prob=0.5
decoded_dict['predicted_name'][decoded_dict['probability'] < thres_prob] = 'Unknown'
return decoded_dict

def _rescue_spots(self,
decoding_dict_trunc,
spots_intensities_vec):
def _rescue_errors(self,
decoding_dict,
spots_intensities_vec):
"""Rescues decoded spots assigned as 'Background' or 'Unknown' if their rounded spot
probability values have a Hamming distance of 1 from one of the barcodes.

Args:
decoding_dict_trunc (dict): Dictionary containing decoded spot identities with
keys: 'probability', 'predicted_id', 'predicted_name'. This dictionary has already
been processed to assign low probability predictions to 'Unknown'.
decoding_dict (dict): Dictionary containing decoded spot identities with
keys: 'spot_index', 'probability', 'predicted_id', 'predicted_name', 'source'. This
dictionary has already been processed to assign low probability predictions
to 'Unknown'.
spots_intensities_vec (numpy.array): Array of spot probability values with shape
[num_spots, r*c].
Returns:
dict: Dictionary with keys: 'spot_index', 'probability', 'predicted_id',
'predicted_name', 'source'.
"""

ch_names = list(self.df_barcodes.columns)
Expand All @@ -237,35 +258,139 @@ def _rescue_spots(self,
num_barcodes = barcodes_array.shape[0]
barcode_len = barcodes_array.shape[1]

predicted_ids = decoding_dict_trunc['predicted_id']
predicted_names = decoding_dict_trunc['predicted_name']
predicted_ids = decoding_dict['predicted_id']
predicted_names = decoding_dict['predicted_name']
sources = decoding_dict['source']

attempted = 0
successful = 0
for i,pred in tqdm(enumerate(predicted_names)):
if pred in ['Background', 'Unknown']:
attempted += 1
dist_list = np.zeros(num_barcodes)
for ii in range(num_barcodes):
dist_list[ii] = distance.hamming(np.round(spots_intensities_vec[i]),
barcodes_array[ii])
scaled_dist_list = dist_list * barcode_len
if 1 in scaled_dist_list:
successful += 1

new_gene = np.argwhere(scaled_dist_list == 1)[0][0]
predicted_ids[i] = new_gene
# gene ids are 1-indexed
predicted_ids[i] = new_gene + 1
predicted_names[i] = self.df_barcodes['Gene'].values[new_gene]
sources[i] = 'error rescue'

decoding_dict_rescued = {
result = {
'spot_index': decoding_dict['spot_index'],
'predicted_id': predicted_ids,
'predicted_name': predicted_names,
'probability': decoding_dict_trunc['probability']
'probability': decoding_dict['probability'],
'source': sources
}

print('{} of {} rescue attempts were successful.'.format(successful, attempted))

return(decoding_dict_rescued)
return(result)

def _rescue_mixed_spots(self,
                        decoding_dict,
                        spots_intensities_vec,
                        prob_threshold=0.95):
    """Look for a second barcode mixed into low-probability predictions.

    For every spot whose probability is below ``prob_threshold`` and whose
    predicted id is not one of the last two codebook rows, the channels that
    are set in the predicted barcode are zeroed in a copy of the spot's
    intensity vector. If the rounded remainder is at a Hamming distance of
    exactly 1 from a codebook row, an additional entry is appended to the
    output for that spot. When several rows qualify, the first one in
    codebook order is used.

    Appended entries reuse the ``spot_index`` of the originating spot, have
    ``probability`` set to -1 and ``source`` set to ``'mixed rescue'``. The
    entries of the original predictions are left untouched, and the arrays in
    ``decoding_dict`` are not modified in place.

    Args:
        decoding_dict (dict): Dictionary containing decoded spot identities with
            keys: 'spot_index', 'probability', 'predicted_id', 'predicted_name',
            'source'.
        spots_intensities_vec (numpy.array): Array of spot probability values
            with shape [num_spots, r*c].
        prob_threshold (float): Spots with a probability below this value are
            checked for a mixed barcode. Defaults to 0.95.

    Returns:
        dict: Dictionary with keys: 'spot_index', 'probability', 'predicted_id',
            'predicted_name', 'source'. Each array holds the original entries
            followed by one entry per successfully rescued spot.
    """
    ch_names = list(self.df_barcodes.columns)
    ch_names.remove('Gene')
    barcodes_array = self.df_barcodes[ch_names].values
    gene_names = self.df_barcodes['Gene'].values
    num_barcodes = barcodes_array.shape[0]

    spot_indices = decoding_dict['spot_index']
    predicted_ids = decoding_dict['predicted_id']
    predicted_names = decoding_dict['predicted_name']
    probabilities = decoding_dict['probability']
    sources = decoding_dict['source']

    # Rescued entries are collected in lists and appended once after the loop;
    # calling np.append per hit would copy every output array each time.
    rescued_spot_indices = []
    rescued_ids = []
    rescued_names = []

    attempted = 0
    for i, prob in tqdm(enumerate(probabilities)):
        if not prob < prob_threshold:
            continue
        attempted += 1

        gene_id = predicted_ids[i]
        # Skip the last two codebook rows.
        # NOTE(review): presumably these are the 'Background' and 'Unknown'
        # rows added by _add_bkg_unknown_to_barcodes -- confirm the row order.
        if gene_id > num_barcodes - 2:
            continue

        # gene ids are 1-indexed
        barcode = barcodes_array[gene_id - 1]
        residual = spots_intensities_vec[i].copy()
        residual[barcode == 1] = 0

        # Count mismatching positions as integers. Scaling a fractional
        # Hamming distance back up by the barcode length is not exact for
        # every length (e.g. 1 / 49 * 49 != 1 in floating point), which
        # would silently drop valid rescues.
        mismatches = (np.round(residual) != barcodes_array).sum(axis=1)
        hits = np.flatnonzero(mismatches == 1)
        if hits.size == 0:
            continue

        new_id = hits[0]
        rescued_spot_indices.append(spot_indices[i])
        # gene ids are 1-indexed
        rescued_ids.append(new_id + 1)
        rescued_names.append(gene_names[new_id])

    successful = len(rescued_spot_indices)
    if successful:
        # Guarded so that appending an empty (float) list never changes the
        # dtype of the integer/string arrays when nothing was rescued.
        spot_indices = np.append(spot_indices, rescued_spot_indices)
        predicted_ids = np.append(predicted_ids, rescued_ids)
        predicted_names = np.append(predicted_names, rescued_names)
        probabilities = np.append(probabilities, np.full(successful, -1))
        sources = np.append(sources, np.repeat('mixed rescue', successful))

    result = {
        'spot_index': spot_indices,
        'predicted_id': predicted_ids,
        'predicted_name': predicted_names,
        'probability': probabilities,
        'source': sources
    }

    print('{} of {} rescue attempts were successful.'.format(successful, attempted))

    return result

def _predict(self,
spots_intensities_vec,
num_iter,
batch_size,
thres_prob,
rescue_spots):
rescue_errors,
rescue_mixed):
"""Predict the gene assignment of each spot.

Args:
Expand All @@ -274,11 +399,14 @@ def _predict(self,
num_iter (int): Number of iterations for training. Defaults to 500.
batch_size (int): Size of batches for training. Defaults to 1000.
thres_prob (float): The threshold of unknown category, within [0,1]. Defaults to 0.5.
rescue_spots (bool): Whether to check if 'Background'- and 'Unknown'-assigned spots
rescue_errors (bool): Whether to check if 'Background'- and 'Unknown'-assigned spots
have a Hamming distance of 1 to other barcodes.
rescue_mixed (bool): Whether to check if low probability predictions are the result of
two mixed barcodes.

Returns:
dict: Dictionary with keys: 'probability', 'predicted_id', 'predicted_name'.
dict: Dictionary with keys: 'spot_index', 'probability', 'predicted_id',
'predicted_name', 'source'.
"""
self._validate_spots_intensities(spots_intensities_vec)

Expand All @@ -303,19 +431,24 @@ def _predict(self,
decoding_dict_trunc = self._threshold_unknown_by_prob(
decoding_dict, unknown_index, thres_prob=thres_prob)

if rescue_spots:
decoding_dict_rescued = self._rescue_spots(decoding_dict_trunc,
spots_intensities_vec)
return decoding_dict_rescued
else:
return decoding_dict_trunc
if rescue_errors:
print('Revising errors...')
decoding_dict_trunc = self._rescue_errors(decoding_dict_trunc,
spots_intensities_vec)
if rescue_mixed:
print('Correcting mixed barcodes...')
decoding_dict_trunc = self._rescue_mixed_spots(decoding_dict_trunc,
spots_intensities_vec)

return decoding_dict_trunc

def predict(self,
spots_intensities_vec,
num_iter=500,
batch_size=1000,
thres_prob=0.5,
rescue_spots=True):
rescue_errors=True,
rescue_mixed=False):
"""Predict the gene assignment of each spot.

Args:
Expand All @@ -324,16 +457,20 @@ def predict(self,
num_iter (int): Number of iterations for training. Defaults to 500.
batch_size (int): Size of batches for training. Defaults to 1000.
thres_prob (float): The threshold of unknown category, within [0,1]. Defaults to 0.5.
rescue_spots (bool): Whether to check if 'Background' and 'Unknown' assigned spots
rescue_errors (bool): Whether to check if 'Background'- and 'Unknown'-assigned spots
have a Hamming distance of 1 to other barcodes.
rescue_mixed (bool): Whether to check if low probability predictions are the result of
two mixed barcodes.

Returns:
dict: Dictionary with keys: 'probability', 'predicted_id', 'predicted_name'.
dict: Dictionary with keys: 'spot_index', 'probability', 'predicted_id',
'predicted_name', 'source'.
"""

return self._predict(
spots_intensities_vec=spots_intensities_vec,
num_iter=num_iter,
batch_size=batch_size,
thres_prob=thres_prob,
rescue_spots=rescue_spots)
rescue_errors=rescue_errors,
rescue_mixed=rescue_mixed)
16 changes: 16 additions & 0 deletions deepcell_spots/applications/spot_decoding_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,9 +59,11 @@ def test_spot_decoding_app(self):
decoding_dict_trunc1 = app1.predict(
spots_intensities_vec=spots_intensities_vec1, num_iter=20, batch_size=100
)
self.assertEqual(decoding_dict_trunc1["spot_index"].shape, (100,))
self.assertEqual(decoding_dict_trunc1["probability"].shape, (100,))
self.assertEqual(decoding_dict_trunc1["predicted_id"].shape, (100,))
self.assertEqual(decoding_dict_trunc1["predicted_name"].shape, (100,))
self.assertEqual(decoding_dict_trunc1["source"].shape, (100,))

# simple functionality test
df_barcodes2 = pd.DataFrame(
Expand Down Expand Up @@ -188,3 +190,17 @@ def test_spot_decoding_app(self):
spots_intensities_vec[0,0] = 0.5
with self.assertRaises(ValueError):
_ = app.predict(spots_intensities_vec=spots_intensities_vec, num_iter=20, batch_size=100)

# Test mixed rescue
app = SpotDecoding(df_barcodes=df_barcodes1, rounds=2, channels=3,
distribution='Relaxed Bernoulli', params_mode='2*R*C')

spots_intensities_vec = np.random.rand(100, 6)
decoding_dict = app.predict(spots_intensities_vec=spots_intensities_vec, num_iter=20, batch_size=100,
rescue_mixed=True)

self.assertGreaterEqual(decoding_dict["spot_index"].shape, (100,))
self.assertGreaterEqual(decoding_dict["probability"].shape, (100,))
self.assertGreaterEqual(decoding_dict["predicted_id"].shape, (100,))
self.assertGreaterEqual(decoding_dict["predicted_name"].shape, (100,))
self.assertGreaterEqual(decoding_dict["source"].shape, (100,))
2 changes: 2 additions & 0 deletions deepcell_spots/decoding_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -603,9 +603,11 @@ def decoding_function(spots,
optim, loss=TraceEnum_ELBO(max_plate_nesting=1))
pyro.set_rng_seed(set_seed)

print('Training...')
losses = train(svi, num_iter, data, codes, c, r,
min(num_spots, batch_size), distribution, params_mode)

print('Estimating barcode probabilities...')
if distribution=='Gaussian':
w_star = pyro.param('weights').detach()

Expand Down