-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathutils.py
49 lines (42 loc) · 1.41 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
"""
EECS 445 - Introduction to Machine Learning
Fall 2023 - Project 2
Utility functions
"""
import ast
import json
import os

import matplotlib.pyplot as plt
import numpy as np
def config(attr):
    """
    Retrieve the queried attribute value from the config file.

    The file "config.json" (resolved against the current working directory)
    is read on the first call and the parsed result is cached on the function
    object as ``config.config``; later calls reuse that cached value.

    Args:
        attr: Dot-separated path of keys, e.g. "a.b" looks up
            ``config.config["a"]["b"]``.

    Returns:
        The value stored at the requested path.

    Raises:
        OSError: if "config.json" cannot be opened on the first call.
        ValueError, SyntaxError: if the file is neither valid JSON nor a
            plain Python literal.
        KeyError: if a component of ``attr`` is not present.
    """
    if not hasattr(config, "config"):
        with open("config.json", encoding="utf-8") as f:
            raw = f.read()
        try:
            # Real JSON parsing: handles true/false/null, which eval() cannot.
            config.config = json.loads(raw)
        except json.JSONDecodeError:
            # NOTE(review): this used to be eval(raw), which executes
            # arbitrary code from the file. literal_eval still accepts
            # Python-literal style configs (single quotes, trailing commas,
            # True/False/None) but rejects computed expressions.
            config.config = ast.literal_eval(raw)
    node = config.config
    for part in attr.split("."):
        node = node[part]
    return node
def make_training_plot(stats, name="DistilBert Fine-Tuning"):
    """Plot per-epoch validation and training metrics, then save the figure.

    Creates a 1x4 row of subplots labelled Accuracy, Loss, Recall and
    Precision. On panel ``i`` column ``i`` of ``stats`` is drawn as the
    "Validation" curve and column ``4 + i`` as the "Training" curve, both
    against the epoch index. The figure is written to disk by calling
    ``save_dbert_training_plot()``.

    Args:
        stats: 2-D array-like with one row per epoch and at least eight
            columns. Nested lists are accepted and converted to an ndarray.
        name: Title placed above the subplots.

    Returns:
        None.
    """
    # Column slicing below needs an ndarray; asarray is a no-op for one and
    # lets callers pass a plain list of per-epoch rows as well.
    stats = np.asarray(stats)
    metric_names = ("Accuracy", "Loss", "Recall", "Precision")

    fig, axes = plt.subplots(1, len(metric_names), figsize=(20, 5))
    fig.suptitle(name)

    epochs = np.arange(len(stats))
    for i, (ax, metric) in enumerate(zip(axes, metric_names)):
        ax.set_xlabel("Epoch")
        ax.set_ylabel(metric)
        ax.plot(epochs, stats[:, i], "r--", marker="o", label="Validation")
        ax.plot(epochs, stats[:, 4 + i], "b--", marker="o", label="Training")
        ax.legend()

    save_dbert_training_plot()
def save_dbert_training_plot():
    """Write the current matplotlib figure to a PNG file at 200 dpi."""
    output_file = "distilbert_finetuning_plot.png"
    plt.savefig(output_file, dpi=200)