Skip to content

Commit

Permalink
Merge pull request #23 from anton-bushuiev/main
Browse files Browse the repository at this point in the history
Add leakage analysis
  • Loading branch information
anton-bushuiev authored Jun 4, 2024
2 parents 84b117e + c8d04c6 commit a88e97b
Show file tree
Hide file tree
Showing 2 changed files with 1,340 additions and 0 deletions.
71 changes: 71 additions & 0 deletions massspecgym/utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
import matplotlib.colors
import matplotlib.cm as cm
import matplotlib.colors as mcolors
import matplotlib.ticker as ticker
import pandas as pd
import typing as T
from rdkit.Chem import AllChem as Chem
Expand Down Expand Up @@ -150,3 +154,70 @@ def get_smiles_bpe_tokenizer() -> ByteLevelBPETokenizer:
],
)
return smiles_tokenizer


def parse_spec_array(arr: str) -> np.ndarray:
return list(map(float, arr.split(",")))


def plot_spectrum(spec, hue=None, xlim=None, ylim=None, mirror_spec=None, highl_idx=None,
figsize=(6, 2), colors=None, save_pth=None):

if colors is not None:
assert len(colors) >= 3
else:
colors = ['blue', 'green', 'red']

# Normalize input spectrum
def norm_spec(spec):
assert len(spec.shape) == 2
if spec.shape[0] != 2:
spec = spec.T
mzs, ins = spec[0], spec[1]
return mzs, ins / max(ins) * 100
mzs, ins = norm_spec(spec)

# Initialize plotting
init_plotting(figsize=figsize)
fig, ax = plt.subplots(1, 1)

# Setup color palette
if hue is not None:
norm = matplotlib.colors.Normalize(vmin=min(hue), vmax=max(hue), clip=True)
mapper = cm.ScalarMappable(norm=norm, cmap=cm.cool)
plt.colorbar(mapper, ax=ax)

# Plot spectrum
for i in range(len(mzs)):
if hue is not None:
color = mcolors.to_hex(mapper.to_rgba(hue[i]))
else:
color = colors[0]
plt.plot([mzs[i], mzs[i]], [0, ins[i]], color=color, marker='o', markevery=(1, 2), mfc='white', zorder=2)

# Plot mirror spectrum
if mirror_spec is not None:
mzs_m, ins_m = norm_spec(mirror_spec)

@ticker.FuncFormatter
def major_formatter(x, pos):
label = str(round(-x)) if x < 0 else str(round(x))
return label

for i in range(len(mzs_m)):
plt.plot([mzs_m[i], mzs_m[i]], [0, -ins_m[i]], color=colors[2], marker='o', markevery=(1, 2), mfc='white',
zorder=1)
ax.yaxis.set_major_formatter(major_formatter)

# Setup axes
if xlim is not None:
plt.xlim(xlim[0], xlim[1])
else:
plt.xlim(0, max(mzs) + 10)
if ylim is not None:
plt.ylim(ylim[0], ylim[1])
plt.xlabel('m/z')
plt.ylabel('Intensity [%]')

if save_pth is not None:
raise NotImplementedError()
Loading

0 comments on commit a88e97b

Please sign in to comment.