Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

dev #29

Merged
merged 1 commit into from
Apr 26, 2021
Merged

dev #29

Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion musical/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
"""

from .utils import beta_divergence
from .plot import sigplot_bar
from .plot import sigplot_bar, plot_silhouettes
from .nmf import NMF
from .mvnmf import MVNMF, wrappedMVNMF
from .denovo import DenovoSig
Expand All @@ -11,6 +11,7 @@

__all__ = ['beta_divergence',
'sigplot_bar',
'plot_silhouettes',
'NMF',
'MVNMF',
'wrappedMVNMF',
Expand Down
99 changes: 99 additions & 0 deletions musical/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,11 @@

import seaborn as sns
import numpy as np
import pandas as pd
import matplotlib as mpl
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import matplotlib.ticker as ticker
from sklearn.preprocessing import normalize

from .utils import snv_types_96_str
Expand Down Expand Up @@ -202,3 +205,99 @@ def sigplot_bar(sig, norm=True, figsize=None, title=None,

if outfile is not None:
plt.savefig(outfile, bbox_inches='tight')


def plot_silhouettes(model, title_tag=None, plotpvalues=True,
tick_fontsize=12, label_fontsize=14,
outfile=None):

"""Plotting function for silhouette scores.

model is a MuSiCal object that has been processed with model.fit().

Parameters
----------
model : MuSiCal object after fitting using NMF or mnNMF

title_tag : string
String to add to plot titles, e.g. disease type
"""

#Convert dictionaries of all mean silhouette scores and all reconstruction errors to arrays
sil_score_mean_array = np.array(list(model.sil_score_mean_all.values()))
reconstruction_error_array = np.array(list(model.reconstruction_error_all.values()))

#Create DF from silhouette scores for heatmap and rename columns
sil_score_all_df = pd.DataFrame.from_dict(model.sil_score_all,orient = 'index')
sil_score_all_df.columns=[i for i in range(1, sil_score_all_df.shape[1] + 1)]

#Plot

# Set up the axes with gridspec
fig = plt.figure(figsize=(16, 4))
grid = plt.GridSpec(1, 5, hspace=0.5, wspace=4)
host = fig.add_subplot(grid[0, 0:3])
plt2 = host.twinx()

if plotpvalues:
plt3 = host.twinx()

heat_map = fig.add_subplot(grid[0, 3:])

#Generate line plot
host.set_xlabel("n components")
host.set_ylabel("Mean silhouette score")
plt2.set_ylabel("Reconstruction error")
if plotpvalues:
plt3.set_ylabel("p-value")

color1 = '#E94E1B'
color2 = '#1D71B8'
if plotpvalues:
color3 = '#2FAC66'

p1, = host.plot(model.n_components_all, sil_score_mean_array, color=color1, label="Mean silhouette score", linestyle='--', marker='o',)
p2, = plt2.plot(model.n_components_all, reconstruction_error_array, color=color2, label="Reconstruction error", linestyle=':', marker='D')
if plotpvalues:
p3, = plt3.plot(model.n_components_all[1:], model.pvalue_all, color=color3, label="p-value", linestyle='-.', marker='.', alpha=0.5)
if plotpvalues:
lns = [p1, p2, p3]
else:
lns = [p1, p2]

host.legend(handles=lns, loc="lower center", bbox_to_anchor=(0.5, -0.5))

host.yaxis.label.set_color(p1.get_color())
plt2.yaxis.label.set_color(p2.get_color())

if plotpvalues:
plt3.yaxis.label.set_color(p3.get_color())

#Adjust p-value spine position
if plotpvalues:
plt3.spines['right'].set_position(('outward', 65))

#Set ticks interval to 1
host.xaxis.set_major_locator(ticker.MultipleLocator(1))

#Higlight suggested signature
host.axvspan(model.n_components-0.25, model.n_components+0.25, color='grey', alpha=0.3)

#Set title
if title_tag is not None:
host.set_title('Silhouette scores and reconstruction errors for '+title_tag)
else:
host.set_title('Silhouette scores and reconstruction errors')

#Generate heatmap
heat_map = sns.heatmap(sil_score_all_df,vmin=0, vmax=1, cmap="YlGnBu")
heat_map.set_xlabel("Signatures")
heat_map.set_ylabel("n components")

if title_tag is not None:
heat_map.set_title('Silhouette scores for '+title_tag)
else:
heat_map.set_title('Silhouette scores')

if outfile is not None:
plt.savefig(outfile, bbox_inches='tight')