Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Dev #52

Merged
merged 22 commits into from
Apr 4, 2022
Merged

Dev #52

Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
c798555
keep the exposures of individual sigs at the matching for single matc…
dgulhan-bio Apr 2, 2022
9dcc8d9
Revert "keep the exposures of individual sigs at the matching for sin…
dgulhan-bio Apr 2, 2022
e5c5d46
Retain information on the mapping between de novo sigs and matched sigs
Hu-JIN Apr 2, 2022
e32749b
Avoid using new signatures if possible during grid search
Hu-JIN Apr 2, 2022
aeea2b6
filter addressing flat background issues in matching
dgulhan-bio Apr 2, 2022
99b24c9
Change method for selecting the best grid point during grid search.
Hu-JIN Apr 2, 2022
14927a4
Small update in terms of the best grid point
Hu-JIN Apr 2, 2022
b1ee802
Update denovo.py
Hu-JIN Apr 2, 2022
b1abe2f
Change constant name sigs_associated to SIGS_ASSOCIATED
Hu-JIN Apr 3, 2022
05cc146
Update get_sig_indices_associated()
Hu-JIN Apr 3, 2022
869b6dc
Make sure signature output of get_sig_indices_associated is np.array
Hu-JIN Apr 3, 2022
61bd523
Update nnls_sparse.py
Hu-JIN Apr 3, 2022
cd642a6
Update get_sig_indices_associated
Hu-JIN Apr 3, 2022
7165a81
Update comment
Hu-JIN Apr 3, 2022
22dc177
Add SIGS_ASSOCIATED_DICT
Hu-JIN Apr 3, 2022
7385b5d
Modify match in terms of connected_sigs
Hu-JIN Apr 3, 2022
e73b741
Handle connected_sigs in refitting
Hu-JIN Apr 3, 2022
8713258
Fix a bug in handling connected_sigs in refitting
Hu-JIN Apr 3, 2022
a654a57
Fix another bug in handling connected_sigs in refitting
Hu-JIN Apr 3, 2022
fa6d1d3
Update method for selecting the best grid point
Hu-JIN Apr 4, 2022
70227c2
Update denovo.py
Hu-JIN Apr 4, 2022
3a77510
Add some comments
Hu-JIN Apr 4, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
183 changes: 144 additions & 39 deletions musical/denovo.py
Original file line number Diff line number Diff line change
Expand Up @@ -1004,7 +1004,7 @@ def plot_selection(self, title=None, plot_pvalues=True, outfile=None, figsize=No
plt.savefig(outfile, bbox_inches='tight')

def assign(self, W_catalog, method_assign='likelihood_bidirectional',
thresh_match=None, thresh_refit=None, thresh_new_sig=0.8, indices_associated_sigs=None):
thresh_match=None, thresh_refit=None, thresh_new_sig=0.8, connected_sigs=False):
# Check if fit has been run
if not hasattr(self, 'W'):
raise ValueError('The model has not been fitted.')
Expand Down Expand Up @@ -1032,10 +1032,10 @@ def assign(self, W_catalog, method_assign='likelihood_bidirectional',
self.thresh_match = thresh_match
self.thresh_refit = thresh_refit
self.thresh_new_sig = thresh_new_sig
self.indices_associated_sigs = indices_associated_sigs
self.W_s, self.H_s = assign(self.X_df, self.W_df, self.W_catalog, method=self.method_assign,
thresh_match=self.thresh_match, thresh_refit=self.thresh_refit, thresh_new_sig=self.thresh_new_sig,
indices_associated_sigs=self.indices_associated_sigs)
self.connected_sigs = connected_sigs
self.W_s, self.H_s, self.sig_map = assign(self.X_df, self.W_df, self.W_catalog, method=self.method_assign,
thresh_match=self.thresh_match, thresh_refit=self.thresh_refit, thresh_new_sig=self.thresh_new_sig,
connected_sigs=self.connected_sigs)
self.sigs_assigned = self.H_s.index[self.H_s.sum(1) > 0].values
self.n_sigs_assigned = len(self.sigs_assigned)
self.W_s = pd.DataFrame.copy(self.W_s[self.sigs_assigned])
Expand All @@ -1044,7 +1044,7 @@ def assign(self, W_catalog, method_assign='likelihood_bidirectional',
return self

def assign_grid(self, W_catalog, method_assign='likelihood_bidirectional',
thresh_match_grid=None, thresh_refit_grid=None, thresh_new_sig=0.8, indices_associated_sigs=None):
thresh_match_grid=None, thresh_refit_grid=None, thresh_new_sig=0.8, connected_sigs=False):
# Check if fit has been run
if not hasattr(self, 'W'):
raise ValueError('The model has not been fitted.')
Expand Down Expand Up @@ -1072,14 +1072,15 @@ def assign_grid(self, W_catalog, method_assign='likelihood_bidirectional',
self.thresh_match_grid = thresh_match_grid
self.thresh_refit_grid = thresh_refit_grid
self.thresh_new_sig = thresh_new_sig
self.indices_associated_sigs = indices_associated_sigs
W_s_grid_1d, H_s_grid, thresh_match_grid_unique = assign_grid(
self.connected_sigs = connected_sigs
W_s_grid_1d, H_s_grid, sig_map_grid_1d, thresh_match_grid_unique = assign_grid(
self.X_df, self.W_df, self.W_catalog, method=self.method_assign,
thresh_match_grid=self.thresh_match_grid, thresh_refit_grid=self.thresh_refit_grid,
thresh_new_sig=self.thresh_new_sig,
indices_associated_sigs=self.indices_associated_sigs,
connected_sigs=self.connected_sigs,
ncpu=self.ncpu, verbose=self.verbose
)
self.sig_map_grid_1d = sig_map_grid_1d
self.W_s_grid = {}
self.H_s_grid = {}
self.sigs_assigned_grid = {}
Expand Down Expand Up @@ -1208,17 +1209,148 @@ def validate(self, W_s=None, H_s=None, validate_n_replicates=1):
self.H_frobenius_dist_mean = np.mean(self.H_frobenius_dist)
return self

def validate_grid(self, validate_n_replicates=1, W_cos_dist_thresh=0.02):
def _select_best_grid_point(self):
"""
TODO:
1. Right now grid points with new signatures are removed independent of the errors. In the future, modify it so that
new signatures are allowed if it improves the error a lot. Then it can be used as an indicator of discovery of a new signature.
"""
# 1. Avoid using new signatures
# Select those solutions where sigs_assigned does not contain any de novo signatures.
candidate_grid_points = [key for key in self.W_simul_grid.keys() if len(set(self.sigs_assigned_grid[key]).intersection(self.signatures)) == 0]
if len(candidate_grid_points) == 0:
warnings.warn('During grid search, all solutions contain new signatures. This could potentially mean that '
'a new signature does exist. Or, try decreasing thresh_new_sig.',
UserWarning)
candidate_grid_points = list(self.W_simul_grid.keys())
else:
pass
### Looking at distance between (W_simul, H_simul) and (W_data, H_data)
if self.grid_selection_method == 'distance':
# 2. Smallest W_cos_dist
W_cos_dist_min = np.min([self.W_cos_dist_mean_grid[key] for key in candidate_grid_points])
# Candidates: W_cos_dist within W_cos_dist_min + self.grid_selection_cos_thresh
candidate_grid_points = [key for key in candidate_grid_points if self.W_cos_dist_mean_grid[key] <= W_cos_dist_min + self.grid_selection_cos_thresh]
if len(candidate_grid_points) == 1:
self.best_grid_point = candidate_grid_points[0]
else:
# 3. Look at number of assigned sigs
n_sigs_assigned_min = np.min([self.n_sigs_assigned_grid[key] for key in candidate_grid_points])
candidate_grid_points = [key for key in candidate_grid_points if self.n_sigs_assigned_grid[key] == n_sigs_assigned_min]
if len(candidate_grid_points) == 1:
self.best_grid_point = candidate_grid_points[0]
else:
if self.grid_selection_use_H:
# 4. Look at H error
tmp = [[key, self.H_frobenius_dist_mean_grid[key]] for key in candidate_grid_points]
tmp = sorted(tmp, key=itemgetter(1))
self.best_grid_point = tmp[0][0]
else:
# 4. Select the sparest result
tmp = [[key, key[0], key[1]] for key in candidate_grid_points]
# First select largest thresh_refit, then largest thresh_match
tmp = sorted(tmp, key=itemgetter(2, 1), reverse=True)
self.best_grid_point = tmp[0][0]
### Looking at p-values
elif self.grid_selection_method == 'pvalue':
# 2. Element wise errors between W_simul and W_data.
elementwise_errors_W = []
for key in candidate_grid_points:
W_simul = np.mean(self.W_simul_grid[key], 0) # We take the average signatures of multiple replicates of simulation results here.
error = np.absolute((self.W - W_simul).flatten())
#elementwise_errors_W.append([key, np.mean(error), error])
elementwise_errors_W.append([key, self.W_cos_dist_mean_grid[key], error]) # We can decide how to identify the 'best' one.
# Best result
elementwise_errors_W = sorted(elementwise_errors_W, key=itemgetter(1))
error_best = elementwise_errors_W[0][2]
# Test differences
self.W_simul_error_pvalue_grid = {}
self.W_simul_error_pvalue_tail_grid = {}
candidate_grid_points = []
for key, _, error in elementwise_errors_W:
pvalue = stats.mannwhitneyu(error_best, error, alternative='less')[1]
pvalue_tail = differential_tail_test(error_best, error, percentile=95, alternative='less')[1]
self.W_simul_error_pvalue_grid[key] = pvalue
self.W_simul_error_pvalue_tail_grid[key] = pvalue_tail
if pvalue > self.grid_selection_pvalue_thresh and pvalue_tail > self.grid_selection_pvalue_thresh:
candidate_grid_points.append(key)
########
if len(candidate_grid_points) == 1:
self.best_grid_point = candidate_grid_points[0]
else:
# 3. Look at number of assigned sigs
n_sigs_assigned_min = np.min([self.n_sigs_assigned_grid[key] for key in candidate_grid_points])
candidate_grid_points = [key for key in candidate_grid_points if self.n_sigs_assigned_grid[key] == n_sigs_assigned_min]
if len(candidate_grid_points) == 1:
self.best_grid_point = candidate_grid_points[0]
else:
if self.grid_selection_use_H:
# 4. Look at p-values for H matrix
elementwise_errors_H = []
for key in candidate_grid_points:
H_simul = np.mean(self.H_simul_grid[key], 0)
error = np.absolute((self.H - H_simul).flatten())
elementwise_errors_H.append([key, self.H_frobenius_dist_mean_grid[key], error])
elementwise_errors_H = sorted(elementwise_errors_H, key=itemgetter(1))
error_best = elementwise_errors_H[0][2]
# Test differences
self.H_simul_error_pvalue_grid = {}
self.H_simul_error_pvalue_tail_grid = {}
candidate_grid_points = []
for key, _, error in elementwise_errors_H:
pvalue = stats.mannwhitneyu(error_best, error, alternative='less')[1]
pvalue_tail = differential_tail_test(error_best, error, percentile=95, alternative='less')[1]
self.H_simul_error_pvalue_grid[key] = pvalue
self.H_simul_error_pvalue_tail_grid[key] = pvalue_tail
if pvalue > self.grid_selection_pvalue_thresh and pvalue_tail > self.grid_selection_pvalue_thresh:
candidate_grid_points.append(key)
if len(candidate_grid_points) == 1:
self.best_grid_point = candidate_grid_points[0]
else:
# 5. Select the sparsest result
tmp = [[key, key[0], key[1]] for key in candidate_grid_points]
# First select largest thresh_refit, then largest thresh_match
tmp = sorted(tmp, key=itemgetter(2, 1), reverse=True)
self.best_grid_point = tmp[0][0]
else:
# 4. Select the sparsest result
tmp = [[key, key[0], key[1]] for key in candidate_grid_points]
# First select largest thresh_refit, then largest thresh_match
tmp = sorted(tmp, key=itemgetter(2, 1), reverse=True)
self.best_grid_point = tmp[0][0]
###
self.thresh_match = self.best_grid_point[0]
self.thresh_refit = self.best_grid_point[1]
self.W_s = self.W_s_grid[self.best_grid_point]
self.H_s = self.H_s_grid[self.best_grid_point]
self.sigs_assigned = self.sigs_assigned_grid[self.best_grid_point]
self.n_sigs_assigned = self.n_sigs_assigned_grid[self.best_grid_point]
self.W_cos_dist = self.W_cos_dist_grid[self.best_grid_point]
self.W_cos_dist_mean = self.W_cos_dist_mean_grid[self.best_grid_point]
self.H_frobenius_dist = self.H_frobenius_dist_grid[self.best_grid_point]
self.H_frobenius_dist_mean = self.H_frobenius_dist_mean_grid[self.best_grid_point]
return self

def validate_grid(self, validate_n_replicates=1,
grid_selection_method='pvalue', grid_selection_use_H=False,
grid_selection_pvalue_thresh=0.05, grid_selection_cos_thresh=0.02):
"""Validation on a grid.

grid_selection_method: 'pvalue' or 'distance'
"""
################# Check running status and input
if not hasattr(self, 'W'):
raise ValueError('The model has not been fitted.')
if not hasattr(self, '_assign_grid_is_run'):
raise ValueError('Run assign_grid first.')
if grid_selection_method not in ['pvalue', 'distance']:
raise ValueError('Bad input for grid_selection_method.')
################# Run validation
self.validate_n_replicates = validate_n_replicates
self.W_cos_dist_thresh = W_cos_dist_thresh
self.grid_selection_method = grid_selection_method
self.grid_selection_use_H = grid_selection_use_H
self.grid_selection_pvalue_thresh = grid_selection_pvalue_thresh
self.grid_selection_cos_thresh = grid_selection_cos_thresh
self.X_simul_grid = {}
self.W_simul_grid = {}
self.H_simul_grid = {}
Expand Down Expand Up @@ -1265,34 +1397,7 @@ def validate_grid(self, validate_n_replicates=1, W_cos_dist_thresh=0.02):
self.W_cos_dist_mean_grid[(thresh_match, thresh_refit)] = np.mean(W_cos_dist)
self.H_frobenius_dist_mean_grid[(thresh_match, thresh_refit)] = np.mean(H_frobenius_dist)
################# Select best result
# Smallest W_cos_dist
W_cos_dist_min = np.min(list(self.W_cos_dist_mean_grid.values()))
# Candidates: W_cos_dist within W_cos_dist_min + self.W_cos_dist_thresh
candidate_grid_points = [key for key, value in self.W_cos_dist_mean_grid.items() if value <= W_cos_dist_min + self.W_cos_dist_thresh]
if len(candidate_grid_points) == 1:
self.best_grid_point = candidate_grid_points[0]
else:
# Look at number of assigned sigs
n_sigs_assigned_min = np.min([self.n_sigs_assigned_grid[key] for key in candidate_grid_points])
candidate_grid_points = [key for key in candidate_grid_points if self.n_sigs_assigned_grid[key] == n_sigs_assigned_min]
if len(candidate_grid_points) == 1:
self.best_grid_point = candidate_grid_points[0]
else:
# Look at H error
# Or choose the one with the strongest sparsity here.
tmp = [[key, self.H_frobenius_dist_mean_grid[key]] for key in candidate_grid_points]
tmp = sorted(tmp, key=itemgetter(1))
self.best_grid_point = tmp[0][0]
self.thresh_match = self.best_grid_point[0]
self.thresh_refit = self.best_grid_point[1]
self.W_s = self.W_s_grid[self.best_grid_point]
self.H_s = self.H_s_grid[self.best_grid_point]
self.sigs_assigned = self.sigs_assigned_grid[self.best_grid_point]
self.n_sigs_assigned = self.n_sigs_assigned_grid[self.best_grid_point]
self.W_cos_dist = self.W_cos_dist_grid[self.best_grid_point]
self.W_cos_dist_mean = self.W_cos_dist_mean_grid[self.best_grid_point]
self.H_frobenius_dist = self.H_frobenius_dist_grid[self.best_grid_point]
self.H_frobenius_dist_mean = self.H_frobenius_dist_mean_grid[self.best_grid_point]
self._select_best_grid_point()
return self

###########################################################################
Expand Down
Loading