# tree_scatter.py
# Scatter plot of dnapars and history sDAG tree likelihoods, saved to
# tree_scatter.pdf.
import polars as pl
import statistics as stat
import matplotlib.pyplot as plt
import matplotlib
import math
import numpy as np
from pathlib import Path
import pickle
from counterstats import countermean, countermedian
from matplotlib import colormaps
from collections import defaultdict, Counter
# Load the pickled results from the working directory.
# NOTE(review): pickle.load can execute arbitrary code embedded in the file;
# only run this against a trusted, locally generated artifact.
with open("workdir/all_dagtrees_example.p", "rb") as fh:
    data = pickle.load(fh)

# Pull out the four entries the rest of the script relies on.
n_leaves = data["NumLeaves"]
n_nodes_true_tree = data["TrueTreeNumNodes"]
dag_data = data["dagtrees"]
dnapars_data = data["dnaparsTrees"]
def _table_to_frame(table):
    """Build a polars DataFrame from a row table whose first row is the header.

    ``table[0]`` supplies the column names; every later row contributes one
    value per column, picked by position.
    """
    header = table[0]
    records = table[1:]
    columns = {}
    for position, name in enumerate(header):
        columns[name] = [record[position] for record in records]
    return pl.DataFrame(columns)


dag_df = _table_to_frame(dag_data)
dnapars_df = _table_to_frame(dnapars_data)
fig, ax = plt.subplots()


def _add_normalized_rf(frame):
    """Return *frame* with an added ``normalized_rf`` column.

    Each ``RootedRF_and_nodecount`` entry is unpacked as ``(rf, ncount)`` and
    ``rf`` is divided by ``ncount + n_nodes_true_tree - 2 * n_leaves``.
    The formula lives here once so both tree sets are normalized identically.
    """
    normalized = [
        rf / (ncount + n_nodes_true_tree - (2 * n_leaves))
        for rf, ncount in frame["RootedRF_and_nodecount"]
    ]
    return frame.with_columns(pl.Series("normalized_rf", normalized))


dag_df = _add_normalized_rf(dag_df)
dnapars_df = _add_normalized_rf(dnapars_df)

# Smallest normalized RF value among the dnapars trees.
best_dnapars_rf = dnapars_df["normalized_rf"].min()
# Dashed grid on both major and minor ticks.
_grid_style = {"which": "both", "linestyle": "--", "linewidth": 0.5, "zorder": 1}
ax.grid(True, **_grid_style)

# Columns plotted on the x and y axes, in that order.
_x_y_data = ["BPLikelihoodLogLoss", "ContextLikelihoodLogLoss"]
scatter_titles = ["History sDAG trees", "Dnapars trees"]

# Split the DAG trees by whether their normalized RF is strictly below the
# best dnapars value.
_normalized_rf_col = pl.col("normalized_rf")
dag_better_df = dag_df.filter(_normalized_rf_col < best_dnapars_rf)
dag_not_better_df = dag_df.filter(_normalized_rf_col >= best_dnapars_rf)
def _scatter_trees(frame, *, palette_index, alpha, label, **style):
    """Scatter one group of trees on ``ax``.

    The x/y values are the negated ``_x_y_data`` columns of *frame*. The
    colour is entry *palette_index* of the "Dark2" colormap with the given
    *alpha* baked into the RGBA value. Marker, size and zorder are shared by
    every group; anything group-specific (line width, edge colour) is passed
    through ``**style``.
    """
    ax.scatter(
        *(-frame[column] for column in _x_y_data),
        color=matplotlib.colors.to_rgba(
            colormaps["Dark2"].colors[palette_index], alpha=alpha
        ),
        marker=".",
        s=70,
        zorder=2,
        label=label,
        **style,
    )


# All groups share zorder=2, so the call order below is also the draw order
# and the legend order.
_scatter_trees(
    dag_not_better_df,
    palette_index=0,
    alpha=0.3,
    label="History sDAG trees",
    linewidths=0,
)
_scatter_trees(
    dnapars_df,
    palette_index=1,
    alpha=1,
    label="Dnapars trees",
    linewidths=0,
)
# DAG trees with a normalized RF below the best dnapars value get a black
# outline so they stand out from the rest of the DAG trees.
_scatter_trees(
    dag_better_df,
    palette_index=0,
    alpha=0.3,
    label="History sDAG trees (improved RF-distance)",
    edgecolor="black",
    linewidths=0.7,
)
# Title and axis labels, then the legend built from the scatter labels.
ax.set(
    title="Dnapars and History sDAG Tree Likelihoods",
    xlabel="Branching Process Log-Likelihood",
    ylabel="Poisson Context Log-Likelihood",
)
ax.legend()

# Write the figure to disk and print the output file name.
fig_name = "tree_scatter.pdf"
fig.savefig(fig_name)
print(fig_name)