def plot_silhouettes(model, title_tag=None, plotpvalues=True,
                     tick_fontsize=12, label_fontsize=14,
                     outfile=None):
    """Plot silhouette scores and reconstruction errors of a fitted model.

    Builds one 16x4 inch figure with two panels:

    * left: a line plot of the mean silhouette score and the reconstruction
      error (and, optionally, the p-values) against the number of
      components, with the selected ``model.n_components`` shaded in grey;
    * right: a heatmap of the individual silhouette scores, one row per
      key of ``model.sil_score_all``.

    model is a MuSiCal object that has been processed with model.fit().

    Parameters
    ----------
    model : MuSiCal object after fitting using NMF or mvNMF
        The attributes read here are ``n_components_all``, ``n_components``,
        ``sil_score_mean_all``, ``reconstruction_error_all``,
        ``sil_score_all`` and, when ``plotpvalues`` is True, ``pvalue_all``.

    title_tag : string, optional
        String to add to plot titles, e.g. disease type.

    plotpvalues : bool, default True
        If True, add a third y-axis showing ``model.pvalue_all`` plotted
        against ``model.n_components_all[1:]``.

    tick_fontsize : int, default 12
        Font size of the tick labels on all axes.

    label_fontsize : int, default 14
        Font size of the axis labels on all axes.

    outfile : string, optional
        If given, the figure is saved to this path.

    Returns
    -------
    None
    """
    # Dict values are taken in iteration order.
    # NOTE(review): this presumes the dictionaries are ordered like
    # model.n_components_all -- confirm against the fitting code.
    mean_sil = np.array(list(model.sil_score_mean_all.values()))
    recon_err = np.array(list(model.reconstruction_error_all.values()))

    # One heatmap row per dictionary key; columns are renumbered from 1.
    sil_df = pd.DataFrame.from_dict(model.sil_score_all, orient='index')
    sil_df.columns = range(1, sil_df.shape[1] + 1)

    # Layout: line plot on the first three grid columns, heatmap on the rest.
    fig = plt.figure(figsize=(16, 4))
    grid = plt.GridSpec(1, 5, hspace=0.5, wspace=4)
    host = fig.add_subplot(grid[0, 0:3])
    err_ax = host.twinx()
    pval_ax = host.twinx() if plotpvalues else None
    heat_ax = fig.add_subplot(grid[0, 3:])

    # Line plot
    sil_line, = host.plot(model.n_components_all, mean_sil,
                          color='#E94E1B', label="Mean silhouette score",
                          linestyle='--', marker='o')
    err_line, = err_ax.plot(model.n_components_all, recon_err,
                            color='#1D71B8', label="Reconstruction error",
                            linestyle=':', marker='D')
    handles = [sil_line, err_line]

    host.set_xlabel("n components", fontsize=label_fontsize)
    host.set_ylabel("Mean silhouette score", fontsize=label_fontsize)
    err_ax.set_ylabel("Reconstruction error", fontsize=label_fontsize)

    # Colour each y-axis label like the curve it belongs to.
    host.yaxis.label.set_color(sil_line.get_color())
    err_ax.yaxis.label.set_color(err_line.get_color())

    host.tick_params(labelsize=tick_fontsize)
    err_ax.tick_params(labelsize=tick_fontsize)

    if plotpvalues:
        # The first entry of n_components_all is skipped for the p-values.
        pval_line, = pval_ax.plot(model.n_components_all[1:],
                                  model.pvalue_all,
                                  color='#2FAC66', label="p-value",
                                  linestyle='-.', marker='.', alpha=0.5)
        handles.append(pval_line)
        pval_ax.set_ylabel("p-value", fontsize=label_fontsize)
        pval_ax.yaxis.label.set_color(pval_line.get_color())
        pval_ax.tick_params(labelsize=tick_fontsize)
        # Push the third axis outward so it does not overlap the
        # reconstruction-error axis on the same side.
        pval_ax.spines['right'].set_position(('outward', 65))

    host.legend(handles=handles, loc="lower center",
                bbox_to_anchor=(0.5, -0.5))

    # Set ticks interval to 1
    host.xaxis.set_major_locator(ticker.MultipleLocator(1))

    # Highlight suggested signature
    host.axvspan(model.n_components - 0.25, model.n_components + 0.25,
                 color='grey', alpha=0.3)

    title_suffix = '' if title_tag is None else ' for ' + title_tag
    host.set_title('Silhouette scores and reconstruction errors'
                   + title_suffix)

    # Heatmap: pass the target axes explicitly instead of relying on
    # pyplot's notion of the "current" axes.
    sns.heatmap(sil_df, vmin=0, vmax=1, cmap="YlGnBu", ax=heat_ax)
    heat_ax.set_xlabel("Signatures", fontsize=label_fontsize)
    heat_ax.set_ylabel("n components", fontsize=label_fontsize)
    heat_ax.tick_params(labelsize=tick_fontsize)
    heat_ax.set_title('Silhouette scores' + title_suffix)

    if outfile is not None:
        # Save this figure specifically, not whatever figure is current.
        fig.savefig(outfile, bbox_inches='tight')