diff --git a/musical/denovo.py b/musical/denovo.py
index 5bc2240..bcf83b5 100644
--- a/musical/denovo.py
+++ b/musical/denovo.py
@@ -1004,7 +1004,7 @@ def plot_selection(self, title=None, plot_pvalues=True, outfile=None, figsize=No
             plt.savefig(outfile, bbox_inches='tight')
 
     def assign(self, W_catalog, method_assign='likelihood_bidirectional',
-               thresh_match=None, thresh_refit=None, thresh_new_sig=0.8, indices_associated_sigs=None):
+               thresh_match=None, thresh_refit=None, thresh_new_sig=0.8, connected_sigs=False):
         # Check if fit has been run
         if not hasattr(self, 'W'):
             raise ValueError('The model has not been fitted.')
@@ -1032,10 +1032,10 @@ def assign(self, W_catalog, method_assign='likelihood_bidirectional',
         self.thresh_match = thresh_match
         self.thresh_refit = thresh_refit
         self.thresh_new_sig = thresh_new_sig
-        self.indices_associated_sigs = indices_associated_sigs
-        self.W_s, self.H_s = assign(self.X_df, self.W_df, self.W_catalog, method=self.method_assign,
-                                    thresh_match=self.thresh_match, thresh_refit=self.thresh_refit, thresh_new_sig=self.thresh_new_sig,
-                                    indices_associated_sigs=self.indices_associated_sigs)
+        self.connected_sigs = connected_sigs
+        self.W_s, self.H_s, self.sig_map = assign(self.X_df, self.W_df, self.W_catalog, method=self.method_assign,
+                                                  thresh_match=self.thresh_match, thresh_refit=self.thresh_refit, thresh_new_sig=self.thresh_new_sig,
+                                                  connected_sigs=self.connected_sigs)
         self.sigs_assigned = self.H_s.index[self.H_s.sum(1) > 0].values
         self.n_sigs_assigned = len(self.sigs_assigned)
         self.W_s = pd.DataFrame.copy(self.W_s[self.sigs_assigned])
@@ -1044,7 +1044,7 @@ def assign(self, W_catalog, method_assign='likelihood_bidirectional',
         return self
 
     def assign_grid(self, W_catalog, method_assign='likelihood_bidirectional',
-                    thresh_match_grid=None, thresh_refit_grid=None, thresh_new_sig=0.8, indices_associated_sigs=None):
+                    thresh_match_grid=None, thresh_refit_grid=None, thresh_new_sig=0.8, connected_sigs=False):
         # Check if fit has been run
         if not hasattr(self, 'W'):
             raise ValueError('The model has not been fitted.')
@@ -1072,14 +1072,15 @@ def assign_grid(self, W_catalog, method_assign='likelihood_bidirectional',
         self.thresh_match_grid = thresh_match_grid
         self.thresh_refit_grid = thresh_refit_grid
         self.thresh_new_sig = thresh_new_sig
-        self.indices_associated_sigs = indices_associated_sigs
-        W_s_grid_1d, H_s_grid, thresh_match_grid_unique = assign_grid(
+        self.connected_sigs = connected_sigs
+        W_s_grid_1d, H_s_grid, sig_map_grid_1d, thresh_match_grid_unique = assign_grid(
             self.X_df, self.W_df, self.W_catalog, method=self.method_assign,
             thresh_match_grid=self.thresh_match_grid, thresh_refit_grid=self.thresh_refit_grid, thresh_new_sig=self.thresh_new_sig,
-            indices_associated_sigs=self.indices_associated_sigs,
+            connected_sigs=self.connected_sigs,
             ncpu=self.ncpu, verbose=self.verbose
         )
+        self.sig_map_grid_1d = sig_map_grid_1d
         self.W_s_grid = {}
         self.H_s_grid = {}
         self.sigs_assigned_grid = {}
@@ -1208,17 +1209,148 @@ def validate(self, W_s=None, H_s=None, validate_n_replicates=1):
         self.H_frobenius_dist_mean = np.mean(self.H_frobenius_dist)
         return self
 
-    def validate_grid(self, validate_n_replicates=1, W_cos_dist_thresh=0.02):
+    def _select_best_grid_point(self):
+        """
+        TODO:
+        1. Right now, grid points with new signatures are removed independently of the errors. In the future,
+        modify this so that new signatures are allowed when they improve the error substantially. Then this can
+        serve as an indicator that a new signature has been discovered.
+        """
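+        # Selection cascade implemented below, in brief:
+        #   1. Drop grid points whose assignment uses de novo (new) signatures.
+        #   2. Keep grid points whose W error is indistinguishable from the best one
+        #      ('distance': within grid_selection_cos_thresh of the smallest mean cosine
+        #      distance; 'pvalue': Mann-Whitney U and tail tests on elementwise errors).
+        #   3. Among those, prefer the fewest assigned signatures.
+        #   4./5. Break remaining ties by H error (if grid_selection_use_H) or by
+        #      sparsity (largest thresh_refit, then largest thresh_match).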
+        # 1. Avoid using new signatures
+        # Select those solutions where sigs_assigned does not contain any de novo signatures.
+        candidate_grid_points = [key for key in self.W_simul_grid.keys() if len(set(self.sigs_assigned_grid[key]).intersection(self.signatures)) == 0]
+        if len(candidate_grid_points) == 0:
+            warnings.warn('During grid search, all solutions contain new signatures. This could mean that '
+                          'a new signature does exist; alternatively, try decreasing thresh_new_sig.',
+                          UserWarning)
+            candidate_grid_points = list(self.W_simul_grid.keys())
+        ### Looking at distance between (W_simul, H_simul) and (W_data, H_data)
+        if self.grid_selection_method == 'distance':
+            # 2. Smallest W_cos_dist
+            W_cos_dist_min = np.min([self.W_cos_dist_mean_grid[key] for key in candidate_grid_points])
+            # Candidates: W_cos_dist within W_cos_dist_min + self.grid_selection_cos_thresh
+            candidate_grid_points = [key for key in candidate_grid_points if self.W_cos_dist_mean_grid[key] <= W_cos_dist_min + self.grid_selection_cos_thresh]
+            if len(candidate_grid_points) == 1:
+                self.best_grid_point = candidate_grid_points[0]
+            else:
+                # 3. Look at the number of assigned sigs
+                n_sigs_assigned_min = np.min([self.n_sigs_assigned_grid[key] for key in candidate_grid_points])
+                candidate_grid_points = [key for key in candidate_grid_points if self.n_sigs_assigned_grid[key] == n_sigs_assigned_min]
+                if len(candidate_grid_points) == 1:
+                    self.best_grid_point = candidate_grid_points[0]
+                else:
+                    if self.grid_selection_use_H:
+                        # 4. Look at H error
+                        tmp = [[key, self.H_frobenius_dist_mean_grid[key]] for key in candidate_grid_points]
+                        tmp = sorted(tmp, key=itemgetter(1))
+                        self.best_grid_point = tmp[0][0]
+                    else:
+                        # 4. Select the sparsest result
+                        tmp = [[key, key[0], key[1]] for key in candidate_grid_points]
+                        # First select the largest thresh_refit, then the largest thresh_match
+                        tmp = sorted(tmp, key=itemgetter(2, 1), reverse=True)
+                        self.best_grid_point = tmp[0][0]
+        ### Looking at p-values
+        elif self.grid_selection_method == 'pvalue':
+            # 2. Elementwise errors between W_simul and W_data.
+            elementwise_errors_W = []
+            for key in candidate_grid_points:
+                W_simul = np.mean(self.W_simul_grid[key], 0)  # We average the signatures over the simulation replicates here.
+                error = np.absolute((self.W - W_simul).flatten())
+                #elementwise_errors_W.append([key, np.mean(error), error])
+                elementwise_errors_W.append([key, self.W_cos_dist_mean_grid[key], error])  # The sort key below decides how the 'best' grid point is identified.
+            # Best result
+            elementwise_errors_W = sorted(elementwise_errors_W, key=itemgetter(1))
+            error_best = elementwise_errors_W[0][2]
+            # Test differences
+            self.W_simul_error_pvalue_grid = {}
+            self.W_simul_error_pvalue_tail_grid = {}
+            candidate_grid_points = []
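+            # Keep every grid point whose elementwise W errors are not significantly
+            # larger than those of the best grid point, judged by a Mann-Whitney U
+            # test and a 95th-percentile differential tail test.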
+            for key, _, error in elementwise_errors_W:
+                pvalue = stats.mannwhitneyu(error_best, error, alternative='less')[1]
+                pvalue_tail = differential_tail_test(error_best, error, percentile=95, alternative='less')[1]
+                self.W_simul_error_pvalue_grid[key] = pvalue
+                self.W_simul_error_pvalue_tail_grid[key] = pvalue_tail
+                if pvalue > self.grid_selection_pvalue_thresh and pvalue_tail > self.grid_selection_pvalue_thresh:
+                    candidate_grid_points.append(key)
+            ########
+            if len(candidate_grid_points) == 1:
+                self.best_grid_point = candidate_grid_points[0]
+            else:
+                # 3. Look at the number of assigned sigs
+                n_sigs_assigned_min = np.min([self.n_sigs_assigned_grid[key] for key in candidate_grid_points])
+                candidate_grid_points = [key for key in candidate_grid_points if self.n_sigs_assigned_grid[key] == n_sigs_assigned_min]
+                if len(candidate_grid_points) == 1:
+                    self.best_grid_point = candidate_grid_points[0]
+                else:
+                    if self.grid_selection_use_H:
+                        # 4. Look at p-values for the H matrix
+                        elementwise_errors_H = []
+                        for key in candidate_grid_points:
+                            H_simul = np.mean(self.H_simul_grid[key], 0)
+                            error = np.absolute((self.H - H_simul).flatten())
+                            elementwise_errors_H.append([key, self.H_frobenius_dist_mean_grid[key], error])
+                        elementwise_errors_H = sorted(elementwise_errors_H, key=itemgetter(1))
+                        error_best = elementwise_errors_H[0][2]
+                        # Test differences
+                        self.H_simul_error_pvalue_grid = {}
+                        self.H_simul_error_pvalue_tail_grid = {}
+                        candidate_grid_points = []
+                        for key, _, error in elementwise_errors_H:
+                            pvalue = stats.mannwhitneyu(error_best, error, alternative='less')[1]
+                            pvalue_tail = differential_tail_test(error_best, error, percentile=95, alternative='less')[1]
+                            self.H_simul_error_pvalue_grid[key] = pvalue
+                            self.H_simul_error_pvalue_tail_grid[key] = pvalue_tail
+                            if pvalue > self.grid_selection_pvalue_thresh and pvalue_tail > self.grid_selection_pvalue_thresh:
+                                candidate_grid_points.append(key)
+                        if len(candidate_grid_points) == 1:
+                            self.best_grid_point = candidate_grid_points[0]
+                        else:
+                            # 5. Select the sparsest result
+                            tmp = [[key, key[0], key[1]] for key in candidate_grid_points]
+                            # First select the largest thresh_refit, then the largest thresh_match
+                            tmp = sorted(tmp, key=itemgetter(2, 1), reverse=True)
+                            self.best_grid_point = tmp[0][0]
+                    else:
+                        # 4. Select the sparsest result
+                        tmp = [[key, key[0], key[1]] for key in candidate_grid_points]
+                        # First select the largest thresh_refit, then the largest thresh_match
+                        tmp = sorted(tmp, key=itemgetter(2, 1), reverse=True)
+                        self.best_grid_point = tmp[0][0]
+        ###
+        self.thresh_match = self.best_grid_point[0]
+        self.thresh_refit = self.best_grid_point[1]
+        self.W_s = self.W_s_grid[self.best_grid_point]
+        self.H_s = self.H_s_grid[self.best_grid_point]
+        self.sigs_assigned = self.sigs_assigned_grid[self.best_grid_point]
+        self.n_sigs_assigned = self.n_sigs_assigned_grid[self.best_grid_point]
+        self.W_cos_dist = self.W_cos_dist_grid[self.best_grid_point]
+        self.W_cos_dist_mean = self.W_cos_dist_mean_grid[self.best_grid_point]
+        self.H_frobenius_dist = self.H_frobenius_dist_grid[self.best_grid_point]
+        self.H_frobenius_dist_mean = self.H_frobenius_dist_mean_grid[self.best_grid_point]
+        return self
+
+    def validate_grid(self, validate_n_replicates=1,
+                      grid_selection_method='pvalue', grid_selection_use_H=False,
+                      grid_selection_pvalue_thresh=0.05, grid_selection_cos_thresh=0.02):
         """Validation on a grid.
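+
+        Roughly: for each grid point, data are simulated from the assigned solution,
+        reconstruction errors are collected, and the best grid point is then chosen
+        via _select_best_grid_point.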
+
+        grid_selection_method: 'pvalue' or 'distance'. With 'distance', grid points are
+        compared by mean cosine distance between simulated and de novo W; with 'pvalue',
+        by significance tests on the elementwise errors.
         """
         ################# Check running status and input
         if not hasattr(self, 'W'):
             raise ValueError('The model has not been fitted.')
         if not hasattr(self, '_assign_grid_is_run'):
             raise ValueError('Run assign_grid first.')
+        if grid_selection_method not in ['pvalue', 'distance']:
+            raise ValueError("grid_selection_method must be 'pvalue' or 'distance'.")
         ################# Run validation
         self.validate_n_replicates = validate_n_replicates
-        self.W_cos_dist_thresh = W_cos_dist_thresh
+        self.grid_selection_method = grid_selection_method
+        self.grid_selection_use_H = grid_selection_use_H
+        self.grid_selection_pvalue_thresh = grid_selection_pvalue_thresh
+        self.grid_selection_cos_thresh = grid_selection_cos_thresh
         self.X_simul_grid = {}
         self.W_simul_grid = {}
         self.H_simul_grid = {}
@@ -1265,34 +1397,7 @@ def validate_grid(self, validate_n_replicates=1, W_cos_dist_thresh=0.02):
             self.W_cos_dist_mean_grid[(thresh_match, thresh_refit)] = np.mean(W_cos_dist)
             self.H_frobenius_dist_mean_grid[(thresh_match, thresh_refit)] = np.mean(H_frobenius_dist)
         ################# Select best result
-        # Smallest W_cos_dist
-        W_cos_dist_min = np.min(list(self.W_cos_dist_mean_grid.values()))
-        # Candidates: W_cos_dist within W_cos_dist_min + self.W_cos_dist_thresh
-        candidate_grid_points = [key for key, value in self.W_cos_dist_mean_grid.items() if value <= W_cos_dist_min + self.W_cos_dist_thresh]
-        if len(candidate_grid_points) == 1:
-            self.best_grid_point = candidate_grid_points[0]
-        else:
-            # Look at number of assigned sigs
-            n_sigs_assigned_min = np.min([self.n_sigs_assigned_grid[key] for key in candidate_grid_points])
-            candidate_grid_points = [key for key in candidate_grid_points if self.n_sigs_assigned_grid[key] == n_sigs_assigned_min]
-            if len(candidate_grid_points) == 1:
-                self.best_grid_point = candidate_grid_points[0]
-            else:
-                # Look at H error
-                # Or choose the one with the strongest sparsity here.
-                tmp = [[key, self.H_frobenius_dist_mean_grid[key]] for key in candidate_grid_points]
-                tmp = sorted(tmp, key=itemgetter(1))
-                self.best_grid_point = tmp[0][0]
-        self.thresh_match = self.best_grid_point[0]
-        self.thresh_refit = self.best_grid_point[1]
-        self.W_s = self.W_s_grid[self.best_grid_point]
-        self.H_s = self.H_s_grid[self.best_grid_point]
-        self.sigs_assigned = self.sigs_assigned_grid[self.best_grid_point]
-        self.n_sigs_assigned = self.n_sigs_assigned_grid[self.best_grid_point]
-        self.W_cos_dist = self.W_cos_dist_grid[self.best_grid_point]
-        self.W_cos_dist_mean = self.W_cos_dist_mean_grid[self.best_grid_point]
-        self.H_frobenius_dist = self.H_frobenius_dist_grid[self.best_grid_point]
-        self.H_frobenius_dist_mean = self.H_frobenius_dist_mean_grid[self.best_grid_point]
+        self._select_best_grid_point()
         return self
 
 ###########################################################################
diff --git a/musical/nnls_sparse.py b/musical/nnls_sparse.py
index 14cb1de..05fec2a 100644
--- a/musical/nnls_sparse.py
+++ b/musical/nnls_sparse.py
@@ -49,7 +49,7 @@ def nnls_thresh_naive(x, W, thresh=0.05, thresh_agnostic=0.0, indices_associated
     if len(indices_retained) == 0:
         indices_retained = np.array([np.argmax(h)])
 
-    if indices_associated_sigs != None:
+    if indices_associated_sigs is not None:
         for ind_pair in indices_associated_sigs:
             if any(item in ind_pair for item in indices_retained):
                 indices_retained = np.unique(np.append(indices_retained, ind_pair))
@@ -59,7 +59,7 @@
     return h
 
 
-def nnls_thresh(x, W, thresh=0.05, thresh_agnostic=0.0):
+def nnls_thresh(x, W, thresh=0.05, thresh_agnostic=0.0, indices_associated_sigs=None):
     """Thresholded nnls
 
    An initial NNLS is first done. Based on the initial result,
@@ -103,6 +103,12 @@
         if np.array_equal(indices_retained_next, indices_retained):
             break
         indices_retained = indices_retained_next
+    # Final NNLS
+    if indices_associated_sigs is not None:
+        for ind_pair in indices_associated_sigs:
+            if any(item in ind_pair for item in indices_retained):
+                indices_retained = np.unique(np.append(indices_retained, ind_pair))
+
     h, _ = sp.optimize.nnls(W[:, indices_retained], x)
     h = _fill_vector(h, indices_retained, n_sigs)
     return h
@@ -181,7 +187,7 @@ def nnls_likelihood_backward(x, W, thresh=0.001, per_trial=True, indices_associa
         indices_retained = np.array([i for i in indices_retained if i != index_remove])
 
     ### Final NNLS
-    if indices_associated_sigs != None:
+    if indices_associated_sigs is not None:
         for ind_pair in indices_associated_sigs:
             if any(item in ind_pair for item in indices_retained):
                 indices_retained = np.unique(np.append(indices_retained, ind_pair))
@@ -288,7 +294,7 @@ def nnls_likelihood_bidirectional(x, W, thresh_backward=0.001, thresh_forward=No
         warnings.warn('Max_iter reached, suggesting that the problem may not converge. Or try increasing max_iter.',
                       UserWarning)
 
     ### Final NNLS
-    if indices_associated_sigs != None:
+    if indices_associated_sigs is not None:
         for ind_pair in indices_associated_sigs:
             if any(item in ind_pair for item in indices_retained):
                 indices_retained = np.unique(np.append(indices_retained, ind_pair))
@@ -298,7 +304,7 @@
     return h
 
 
-def nnls_cosine_bidirectional(x, W, thresh_backward=0.01, thresh_forward=None, max_iter=1000):
+def nnls_cosine_bidirectional(x, W, thresh_backward=0.01, thresh_forward=None, max_iter=1000, indices_associated_sigs=None):
     if thresh_forward is None:
         thresh_forward = thresh_backward
     if thresh_backward > thresh_forward:
@@ -373,6 +379,10 @@ def nnls_cosine_bidirectional(x, W, thresh_backward=0.01, thresh_forward=None, m
         warnings.warn('Max_iter reached, suggesting that the problem may not converge. Or try increasing max_iter.',
                       UserWarning)
     ### Final NNLS
+    if indices_associated_sigs is not None:
+        for ind_pair in indices_associated_sigs:
+            if any(item in ind_pair for item in indices_retained):
+                indices_retained = np.unique(np.append(indices_retained, ind_pair))
     h, _ = sp.optimize.nnls(W[:, indices_retained], x)
     h = _fill_vector(h, indices_retained, n_sigs)
     return h
@@ -411,7 +421,7 @@ def nnls_likelihood_backward_relaxed(x, W, thresh=0.001, per_trial=True, indices
         index_remove = indices_retained[np.argmin(loglikelihoods)]
         indices_retained = np.array([i for i in indices_retained if i != index_remove])
     ### Final NNLS
-    if indices_associated_sigs != None:
+    if indices_associated_sigs is not None:
         for ind_pair in indices_associated_sigs:
             if any(item in ind_pair for item in indices_retained):
                 indices_retained = np.unique(np.append(indices_retained, ind_pair))
@@ -504,7 +514,7 @@ def nnls_likelihood_bidirectional_relaxed(x, W, thresh_backward=0.001, thresh_fo
         warnings.warn('Max_iter reached, suggesting that the problem may not converge. Or try increasing max_iter.',
                       UserWarning)
 
     ### Final NNLS
-    if indices_associated_sigs != None:
+    if indices_associated_sigs is not None:
         for ind_pair in indices_associated_sigs:
             if any(item in ind_pair for item in indices_retained):
                 indices_retained = np.unique(np.append(indices_retained, ind_pair))
@@ -582,7 +592,7 @@ def fit(self, X, W):
                 self.thresh2 = 0.0
             self.thresh_agnostic = self.thresh2
             self.H = [
-                nnls_thresh(x, self.W.values, thresh=self.thresh, thresh_agnostic=self.thresh_agnostic) for x in self._X_in.T.values
+                nnls_thresh(x, self.W.values, thresh=self.thresh, thresh_agnostic=self.thresh_agnostic, indices_associated_sigs=self.indices_associated_sigs) for x in self._X_in.T.values
             ]
         elif self.method == 'likelihood_backward':
             if self.thresh1 is None:
@@ -646,7 +656,7 @@
             if self.max_iter is None:
                 self.max_iter = 1000
             self.H = [
-                nnls_cosine_bidirectional(x, self.W.values, thresh_backward=self.thresh_backward, thresh_forward=self.thresh_forward, max_iter=self.max_iter) for x in self._X_in.T.values
+                nnls_cosine_bidirectional(x, self.W.values, thresh_backward=self.thresh_backward, thresh_forward=self.thresh_forward, max_iter=self.max_iter, indices_associated_sigs=self.indices_associated_sigs) for x in self._X_in.T.values
             ]
         else:
             raise ValueError('Invalid method for SparseNNLS.')
@@ -659,6 +669,12 @@
         indices_all = np.arange(0, n_sigs)
         for x, h in zip(self.X.T.values, self.H):
             indices_retained = indices_all[h > 0]
+            ### This may be unnecessary, since the solvers above already retain associated signatures; kept for safety.
+            if self.indices_associated_sigs is not None:
+                for ind_pair in self.indices_associated_sigs:
+                    if any(item in ind_pair for item in indices_retained):
+                        indices_retained = np.unique(np.append(indices_retained, ind_pair))
+            ###
             h_new, _ = sp.optimize.nnls(self.W.iloc[:, indices_retained], x)
             h_new = _fill_vector(h_new, indices_retained, n_sigs)
             tmp.append(h_new)
@@ -755,6 +771,11 @@
             self.thresh1_grid = np.concatenate([np.array([0.0]), np.geomspace(0.0001, 5, 50)])
             if self.thresh2_grid is None:
                 self.thresh2_grid = np.array([None])
+        elif self.method == 'cosine_bidirectional':
+            if self.thresh1_grid is None:
+                self.thresh1_grid = np.concatenate([np.array([0.0]), np.geomspace(0.0001, 5, 50)])
+            if self.thresh2_grid is None:
+                self.thresh2_grid = np.array([None])
         else:
             raise ValueError('Invalid method for SparseNNLSGrid.')
         ##########
diff --git a/musical/refit.py b/musical/refit.py
index 4bf2a3e..104a202 100644
--- a/musical/refit.py
+++ b/musical/refit.py
@@ -11,18 +11,22 @@
 import numpy as np
 import scipy as sp
 import pandas as pd
+import warnings
+
 from .nnls import nnls
 from .nnls_sparse import SparseNNLS, SparseNNLSGrid
-from .utils import match_signature_to_catalog_nnls_sparse, beta_divergence, get_sig_indices_associated
+from .utils import match_signature_to_catalog_nnls_sparse, beta_divergence, get_sig_indices_associated, SIGS_ASSOCIATED_DICT, SIGS_ASSOCIATED
 from .catalog import load_catalog
 
 
 def refit(X, W, method='likelihood_bidirectional', thresh=None,
-          indices_associated_sigs=None):
+          connected_sigs=False):
     """Wrapper around SparseNNLS for refitting
 
     Note that only one parameter thresh1 is allowed here.
     Both X and W should be pd.DataFrame.
+    If connected_sigs is True, missing connected signatures are not filled in, although a warning is printed.
+    So make sure W contains all connected signatures when connected_sigs is set to True.
     """
     # Check input
     if X.shape[0] != W.shape[0]:
@@ -30,12 +34,29 @@ def refit(X, W, method='likelihood_bidirectional', thresh=None,
     if (X.index == W.index).sum() != X.shape[0]:
         raise ValueError('X and W have different indices.')
     # SparseNNLS
+    if connected_sigs:
+        indices_associated_sigs, _ = get_sig_indices_associated(W.columns.values, W.columns.values)
+        # Give some informative warnings
+        missing_sigs = []
+        W_sigs = W.columns.values
+        for key in W_sigs:
+            if key in SIGS_ASSOCIATED_DICT.keys():
+                for sig in SIGS_ASSOCIATED_DICT[key]:
+                    if sig not in W_sigs:
+                        missing_sigs.append(sig)
+        if len(missing_sigs) > 0:
+            warnings.warn(('In refit: connected_sigs is set to True and the input W contains signatures that have connected signatures. ' +
+                           'However, some of those connected signatures are missing from W. Specifically, W is missing: ' +
+                           ','.join(missing_sigs) + '. Please add these missing signatures to W or make sure this is indeed what is wanted.'), UserWarning)
+    else:
+        indices_associated_sigs = None
     model = SparseNNLS(method=method, thresh1=thresh, indices_associated_sigs=indices_associated_sigs)
     model.fit(X, W)
     return model.H, model
 
 
 def refit_grid(X, W, method='likelihood_bidirectional', thresh_grid=None, ncpu=1, verbose=0,
-               indices_associated_sigs=None):
+               connected_sigs=False):
     """Refitting on a grid of thresholds.
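+
+    A minimal usage sketch (hypothetical inputs; X and W are pd.DataFrames with
+    matching indices, and the first returned value is keyed by the thresholds in thresh_grid):
+
+        H_s_grid, _ = refit_grid(X, W, thresh_grid=np.array([0.001, 0.01]))
+        H_s = H_s_grid[0.001]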
""" # Check input @@ -46,6 +67,24 @@ def refit_grid(X, W, method='likelihood_bidirectional', thresh_grid=None, ncpu=1 # SparseNNLSGrid if thresh_grid is None: thresh_grid = np.array([0.001]) + # connected_sigs + if connected_sigs: + indices_associated_sigs, _ = get_sig_indices_associated(W.columns.values, W.columns.values) + # Give some informative warnings + missing_sigs = [] + W_sigs = W.columns.values + for key in W_sigs: + if key in SIGS_ASSOCIATED_DICT.keys(): + for sig in SIGS_ASSOCIATED_DICT[key]: + if sig not in W_sigs: + missing_sigs.append(sig) + if len(missing_sigs) > 0: + warnings.warn(('In refit: connected_sigs is set to True. The input W contains signatures with connected signatures. ' + + 'However, W is missing some connected signatures. Specifically, W is missing: ' + + ','.join(missing_sigs) + '. Please fill in these missing sigs in W or make sure this is indeed what is wanted.'), + UserWarning) + else: + indices_associated_sigs = None model = SparseNNLSGrid(method=method, thresh1_grid=thresh_grid, ncpu=ncpu, verbose=verbose, indices_associated_sigs=indices_associated_sigs) model.fit(X, W) # Results @@ -68,16 +107,51 @@ def _get_W_s(W, W_catalog, H_reduced, cos_similarities, thresh_new_sig): W_s = pd.DataFrame.copy(W.iloc[:, inds_new_sig]) # Let's not rename these new signatures and keep their original name in W, so that we can distinguish them. #W_s.columns = ['Sig_N' + str(i) for i in range(1, len(inds_new_sig) + 1)] + sig_map = pd.DataFrame(np.identity(len(inds_new_sig)), columns=W_s.columns, index=W_s.columns) H_tmp = H_reduced.iloc[:, inds_not_new_sig] W_s = pd.concat([W_s, W_catalog[H_tmp.index[H_tmp.sum(1) > 0]]], axis=1) + sig_map = pd.concat([sig_map, H_tmp.loc[H_tmp.index[H_tmp.sum(1) > 0]]]).fillna(0.0) elif len(inds_new_sig) > 0 and len(inds_not_new_sig) == 0: W_s = pd.DataFrame.copy(W.iloc[:, inds_new_sig]) + sig_map = pd.DataFrame(np.identity(len(inds_new_sig)), columns=W_s.columns, index=W_s.columns) elif len(inds_new_sig) == 0 and len(inds_not_new_sig) > 0: W_s = pd.DataFrame.copy(W_catalog[H_reduced.index]) + sig_map = H_reduced + return W_s, sig_map + +def _clear_W_s(W, W_s, sig_map, min_sum_0p01 = 0.15, min_sig_contrib_ratio = 0.25): + for i in range(sig_map.shape[1]): + set_to_zero = False + for j in range(sig_map.shape[0]): + fij = sig_map.iat[j,i] + if fij < 1 and fij != 0: + contrib = np.dot(W_s[sig_map.index[j]], fij) + sum_above_0p01 = np.sum(contrib[contrib > 0.01]) + max_sig_contrib_ratio = max(contrib)/max(W[sig_map.columns[i]]) + if sum_above_0p01 < min_sum_0p01 and max_sig_contrib_ratio < min_sig_contrib_ratio: + set_to_zero = True + sig_map.iat[j,i] = 0 + if set_to_zero: + weights, _ = sp.optimize.nnls(np.array(W_s.loc[:,sig_map.index[sig_map.loc[:, sig_map.columns[i]] > 0 ]]), np.array(W.loc[:,sig_map.columns[i]])) + sig_map.loc[sig_map.loc[:, sig_map.columns[i]] > 0, sig_map.columns[i]] = weights + sig_map = sig_map[sig_map.sum(1) > 0] + W_s = W_s[sig_map.index] + return W_s, sig_map + +def _add_missing_connected_sigs(W_s, W_catalog): + missing_sigs = [] + W_s_sigs = W_s.columns.values + W_catalog_sigs = W_catalog.columns.values + for key in W_s_sigs: + if key in SIGS_ASSOCIATED_DICT.keys(): + for sig in SIGS_ASSOCIATED_DICT[key]: + if sig not in W_s_sigs and sig in W_catalog_sigs: + missing_sigs.append(sig) + W_s = pd.concat([W_s, W_catalog[missing_sigs]], axis=1) return W_s def match(W, W_catalog, thresh_new_sig=0.8, method='likelihood_bidirectional', thresh=None, - indices_associated_sigs=None): + connected_sigs=False): """Wrapper around 
     Note that only one parameter thresh1 is allowed here.
@@ -97,14 +171,18 @@
     if len(set(W.columns).intersection(W_catalog.columns)) > 0:
         raise ValueError('W and W_catalog cannot contain signatures with the same name.')
     # SparseNNLS
-    model = SparseNNLS(method=method, thresh1=thresh, indices_associated_sigs=indices_associated_sigs)
+    model = SparseNNLS(method=method, thresh1=thresh, indices_associated_sigs=None)
     model.fit(W, W_catalog)
     # Identify new signatures not in the catalog.
-    W_s = _get_W_s(W, W_catalog, model.H_reduced, model.cos_similarities, thresh_new_sig)
-    return W_s, model
+    W_s, sig_map = _get_W_s(W, W_catalog, model.H_reduced, model.cos_similarities, thresh_new_sig)
+    W_s, sig_map = _clear_W_s(W, W_s, sig_map)
+    # If connected_sigs is True, add missing connected signatures to W_s
+    if connected_sigs:
+        W_s = _add_missing_connected_sigs(W_s, W_catalog)
+    return W_s, sig_map, model
 
 
 def match_grid(W, W_catalog, thresh_new_sig=0.8, method='likelihood_bidirectional', thresh_grid=None, ncpu=1, verbose=0,
-               indices_associated_sigs=None):
+               connected_sigs=False):
     """Matching on a grid of thresholds.
     """
     # Check input
@@ -117,7 +195,7 @@
     # SparseNNLSGrid
     if thresh_grid is None:
         thresh_grid = np.array([0.001])
-    model = SparseNNLSGrid(method=method, thresh1_grid=thresh_grid, ncpu=ncpu, verbose=verbose, indices_associated_sigs=indices_associated_sigs)
+    model = SparseNNLSGrid(method=method, thresh1_grid=thresh_grid, ncpu=ncpu, verbose=verbose, indices_associated_sigs=None)
     model.fit(W, W_catalog)
     # Identify new signatures not in the catalog.
     thresh2 = model.thresh2_grid[0]
@@ -128,31 +206,37 @@
     else:
         raise ValueError('thresh2 is modified unexpectedly.')
     W_s_grid = {}
+    sig_map_grid = {}
     for thresh in thresh_grid:
         key = (thresh, thresh2)
-        W_s = _get_W_s(W, W_catalog, model.H_reduced_grid[key], model.cos_similarities_grid[key], thresh_new_sig)
+        W_s, sig_map = _get_W_s(W, W_catalog, model.H_reduced_grid[key], model.cos_similarities_grid[key], thresh_new_sig)
+        W_s, sig_map = _clear_W_s(W, W_s, sig_map)
+        ## Add missing connected signatures
+        if connected_sigs:
+            W_s = _add_missing_connected_sigs(W_s, W_catalog)
         W_s_grid[thresh] = W_s
-    return W_s_grid, model
+        sig_map_grid[thresh] = sig_map
+    return W_s_grid, sig_map_grid, model
 
 
 def assign(X, W, W_catalog, method='likelihood_bidirectional',
            thresh_match=None, thresh_refit=None, thresh_new_sig=0.8,
-           indices_associated_sigs=None):
+           connected_sigs=False):
     """Assign = match + refit.
 
     The same method will be used for both match and refit.
     Match and refit can have different thresholds. But only one threshold is allowed for each.
     If you want to skip matching, set thresh_new_sig to a value > 1.
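+
+    A minimal usage sketch (hypothetical pd.DataFrame inputs with matching channel indices):
+
+        W_s, H_s, sig_map = assign(X, W, W_catalog,
+                                   thresh_match=0.001, thresh_refit=0.001,
+                                   connected_sigs=True)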
""" - W_s, _ = match(W, W_catalog, thresh_new_sig=thresh_new_sig, method=method, thresh=thresh_match, indices_associated_sigs=indices_associated_sigs) - H_s, _ = refit(X, W_s, method=method, thresh=thresh_refit, indices_associated_sigs=indices_associated_sigs) - return W_s, H_s + W_s, sig_map, _ = match(W, W_catalog, thresh_new_sig=thresh_new_sig, method=method, thresh=thresh_match, connected_sigs=connected_sigs) + H_s, _ = refit(X, W_s, method=method, thresh=thresh_refit, connected_sigs=connected_sigs) + return W_s, H_s, sig_map def assign_grid(X, W, W_catalog, method='likelihood_bidirectional', thresh_match_grid=None, thresh_refit_grid=None, - thresh_new_sig=0.8, indices_associated_sigs=None, + thresh_new_sig=0.8, connected_sigs=False, ncpu=1, verbose=0): """Match and refit on a grid""" if thresh_match_grid is None: @@ -160,9 +244,9 @@ def assign_grid(X, W, W_catalog, method='likelihood_bidirectional', if thresh_refit_grid is None: thresh_refit_grid = np.array([0.001]) # First, matching on a grid - W_s_grid_1d, _ = match_grid(W, W_catalog, thresh_new_sig=thresh_new_sig, method=method, - thresh_grid=thresh_match_grid, ncpu=ncpu, verbose=verbose, - indices_associated_sigs=indices_associated_sigs) + W_s_grid_1d, sig_map_grid_1d, _ = match_grid(W, W_catalog, thresh_new_sig=thresh_new_sig, method=method, + thresh_grid=thresh_match_grid, ncpu=ncpu, verbose=verbose, + connected_sigs=connected_sigs) # Second, refitting on a grid. # When a matching result is already calculated before, do not do refitting again. H_s_grid = {} @@ -175,11 +259,11 @@ def assign_grid(X, W, W_catalog, method='likelihood_bidirectional', else: H_s_grid_1d, _ = refit_grid(X, W_s_grid_1d[thresh_match], method=method, thresh_grid=thresh_refit_grid, ncpu=ncpu, verbose=verbose, - indices_associated_sigs=indices_associated_sigs) + connected_sigs=connected_sigs) H_s_grid[thresh_match] = H_s_grid_1d calculated_matching_results[sigs] = thresh_match thresh_match_grid_unique.append(thresh_match) - return W_s_grid_1d, H_s_grid, np.array(thresh_match_grid_unique) + return W_s_grid_1d, H_s_grid, sig_map_grid_1d, np.array(thresh_match_grid_unique) diff --git a/musical/utils.py b/musical/utils.py index 46fac79..2525bbb 100644 --- a/musical/utils.py +++ b/musical/utils.py @@ -69,7 +69,18 @@ [["T>C", item] for item in trinucleotides_T] + [["T>G", item] for item in trinucleotides_T]) -sigs_associated = [['SBS2','SBS13'], ['SBS17a','SBS17b'], ['SBS10a','SBS10b','SBS10c','SBS10d','SBS28']] +SIGS_ASSOCIATED = [['SBS2','SBS13'], ['SBS17a','SBS17b'], ['SBS10a','SBS10b','SBS10c','SBS10d','SBS28']] +SIGS_ASSOCIATED_DICT = { + 'SBS2':['SBS2', 'SBS13'], + 'SBS13': ['SBS2', 'SBS13'], + 'SBS17a': ['SBS17a', 'SBS17b'], + 'SBS17b': ['SBS17a', 'SBS17b'], + 'SBS10a': ['SBS10a','SBS10b','SBS10c','SBS10d','SBS28'], + 'SBS10b': ['SBS10a','SBS10b','SBS10c','SBS10d','SBS28'], + 'SBS10c': ['SBS10a','SBS10b','SBS10c','SBS10d','SBS28'], + 'SBS10d': ['SBS10a','SBS10b','SBS10c','SBS10d','SBS28'], + 'SBS28': ['SBS10a','SBS10b','SBS10c','SBS10d','SBS28'] +} # Need to update indel_types_83_str = [ @@ -644,11 +655,15 @@ def classification_statistics(confusion_matrix=None, P=None, PP=None, All=None): return statistics -def get_sig_indices_associated(signatures, signatures_catalog): +def get_sig_indices_associated(signatures, signatures_catalog=None): + """ + signatures_catalog is used to get orders of the signature names. + It is also used for filtering out signatures not considered (i.e., those not in signatures_catalog). 
+ """ signatures = np.array(signatures) nsig = signatures.size - for entry in sigs_associated: + for entry in SIGS_ASSOCIATED: has_entry = False missing_item = [] for item in entry: @@ -660,12 +675,13 @@ def get_sig_indices_associated(signatures, signatures_catalog): if has_entry and len(missing_item) > 0: signatures = np.append(signatures, missing_item) - signatures = np.sort(signatures) - signatures = [item for index,item in enumerate(signatures_catalog) if item in signatures] + if signatures_catalog is not None: + signatures = [item for index,item in enumerate(signatures_catalog) if item in signatures] + signatures = np.array(signatures) indices_associated = [] - for entry in sigs_associated: + for entry in SIGS_ASSOCIATED: indices_this = [] has_entry = False for item in entry: