-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathtraining_plots_class.py
151 lines (113 loc) · 9.13 KB
/
training_plots_class.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
# README
# Phillip Long
# August 21, 2023
# Makes plots describing the training of the key class neural network.
# python ./training_plots_class.py history_filepath percentiles_history_filepath output_filepath
# python /dfs7/adl/pnlong/artificial_dj/determine_key/training_plots_class.py "/dfs7/adl/pnlong/artificial_dj/data/key_class_nn.history.tsv" "/dfs7/adl/pnlong/artificial_dj/data/key_class_nn.percentiles_history.tsv" "/dfs7/adl/pnlong/artificial_dj/data/key_class_nn.png"
# IMPORTS
##################################################
import sys
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
# sys.argv = ("./training_plots_class.py", "/Users/philliplong/Desktop/Coding/artificial_dj/data/key_class_nn.history.tsv", "/Users/philliplong/Desktop/Coding/artificial_dj/data/key_class_nn.percentiles_history.tsv", "/Users/philliplong/Desktop/Coding/artificial_dj/data/key_class_nn.png")
# sys.argv = ("./training_plots_class.py", "/dfs7/adl/pnlong/artificial_dj/data/key_class_nn.history.tsv", "/dfs7/adl/pnlong/artificial_dj/data/key_class_nn.percentiles_history.tsv", "/dfs7/adl/pnlong/artificial_dj/data/key_class_nn.png")
##################################################
# IMPORT FILEPATHS, LOAD FILES
##################################################
# command-line arguments: training-history TSV, percentiles-history TSV, output image filepath
OUTPUT_FILEPATH_HISTORY = sys.argv[1]
OUTPUT_FILEPATH_PERCENTILES_HISTORY = sys.argv[2]
OUTPUT_FILEPATH = sys.argv[3]
# load in tsv files that have been generated
history = pd.read_csv(OUTPUT_FILEPATH_HISTORY, sep = "\t", header = 0, index_col = False, keep_default_na = False, na_values = "NA")
history["epoch"] = range(history.at[0, "epoch"], history.at[0, "epoch"] + len(history)) # make sure epochs are constantly ascending
history["train_accuracy"] = 100 * history["train_accuracy"] # convert accuracy fraction to percentage (vectorized column arithmetic; previously a per-row apply)
history["validate_accuracy"] = 100 * history["validate_accuracy"] # convert accuracy fraction to percentage
percentiles_history = pd.read_csv(OUTPUT_FILEPATH_PERCENTILES_HISTORY, sep = "\t", header = 0, index_col = False, keep_default_na = False, na_values = "NA")
percentiles_history["epoch"] = [epoch for epoch in range(percentiles_history.at[0, "epoch"], percentiles_history.at[0, "epoch"] + (len(percentiles_history) // 101)) for _ in range(101)] # make sure percentiles epochs are constantly ascending; assumes exactly 101 percentile rows (percentiles 0-100) per epoch — TODO confirm against the script that writes this file
# get list of epochs where training switches from pretrained layers to final regression layers
# freeze_pretrained_epochs collects (epoch, freeze_pretrained) change points: the first row, every
# row whose freeze_pretrained flag differs from the previous row's, plus the final row (appended so
# the last training span has a right edge to shade up to)
freeze_pretrained_epochs = [(history.at[0, "epoch"], history.at[0, "freeze_pretrained"])] + [(history.at[i, "epoch"], history.at[i, "freeze_pretrained"]) for i in range(1, len(history)) if history.at[i, "freeze_pretrained"] != history.at[i - 1, "freeze_pretrained"]] + [(history.at[len(history) - 1, "epoch"], history.at[len(history) - 1, "freeze_pretrained"])]
if freeze_pretrained_epochs[-1] == freeze_pretrained_epochs[-2]: # remove final element if duplicate (occurs when the last history row is itself a change point)
    del freeze_pretrained_epochs[-1]
freeze_pretrained_alpha = lambda freeze_pretrained: str(1 - (0.2 * int(not freeze_pretrained))) # get the color (grayscale) for the backgrounds representing freeze_pretrained: "1" (white) when frozen, "0.8" (light gray) when unfrozen
freeze_pretrained_epochs_colors = [freeze_pretrained_alpha(freeze_pretrained = epoch[1]) for epoch in freeze_pretrained_epochs[:-1]] # get color indicies from freeze_pretrained values; one color per span (the last tuple only marks a right edge)
ignore_freeze_pretrained = (len(freeze_pretrained_epochs_colors) <= 1) # if freeze_pretrained doesn't change, the background shading legend is skipped later
freeze_pretrained_epochs = [epoch[0] for epoch in freeze_pretrained_epochs] # get epoch values (span boundaries only; the flags have served their purpose)
freeze_pretrained_epochs[-1] += 1 # add one to the last epoch, since it is a repeat of the second to last; this extends the final span so the last data point falls inside a shaded region
##################################################
# CREATE PLOT
##################################################
# three panels: loss (top left), accuracy (bottom left), percentile curves (right column, both rows)
panel_layout = [["loss", "percentiles_history"],
                ["accuracy", "percentiles_history"]]
fig, axes = plt.subplot_mosaic(mosaic = panel_layout, constrained_layout = True, figsize = (12, 8))
fig.suptitle("Key Class Neural Network")
colors = ("tab:blue", "tab:red", "tab:orange", "tab:green", "tab:purple", "tab:brown", "tab:pink", "tab:gray", "tab:olive", "tab:cyan")
# when freeze_pretrained changed during training, add a figure-level legend explaining the background shading
if not ignore_freeze_pretrained:
    frozen_face = freeze_pretrained_alpha(freeze_pretrained = True)
    unfrozen_face = freeze_pretrained_alpha(freeze_pretrained = False)
    patch_edge = str(0.6 * min(float(frozen_face), float(unfrozen_face))) # outline darker than either face color
    fig.legend(handles = (mpatches.Patch(facecolor = frozen_face, label = "Pretrained Layers Frozen", edgecolor = patch_edge),
                          mpatches.Patch(facecolor = unfrozen_face, label = "Pretrained Layers Unfrozen", edgecolor = patch_edge)),
               loc = "upper left")
##################################################
# PLOT LOSS AND ACCURACY
##################################################
# plot background training type
# shade each span between consecutive change points; the grayscale encodes freeze_pretrained for that span
for i in range(len(freeze_pretrained_epochs) - 1):
    axes["loss"].axvspan(freeze_pretrained_epochs[i], freeze_pretrained_epochs[i + 1], facecolor = freeze_pretrained_epochs_colors[i])
    axes["accuracy"].axvspan(freeze_pretrained_epochs[i], freeze_pretrained_epochs[i + 1], facecolor = freeze_pretrained_epochs_colors[i])
# plot line for every training session
for i in range(len(freeze_pretrained_epochs) - 1):
    # create subset of history
    # rows whose epoch falls in the half-open span [start, end); freeze_pretrained is constant within a span, so drop it
    history_subset = history[(freeze_pretrained_epochs[i] <= history["epoch"]) & (history["epoch"] < freeze_pretrained_epochs[i + 1])].drop(columns = ("freeze_pretrained")).reset_index(drop = True)
    # add point to make the graph more aesthetically pleasing
    if len(history_subset) == 1: # if there is one point, extend straight line
        history_subset.loc[len(history_subset)] = history_subset.loc[len(history_subset) - 1].to_list() # duplicate the lone row ...
        history_subset.at[len(history_subset) - 1, "epoch"] += 1 # ... one epoch later, producing a flat segment
    elif len(history_subset) >= 2: # if more than one point, extend line with same slope
        slopes = history_subset.loc[len(history_subset) - 1].subtract(other = history_subset.loc[len(history_subset) - 2], fill_value = 0) # per-column difference of the last two rows
        slopes = slopes.multiply(other = pd.Series(data = [1,] + ([0.5,] * (len(history_subset.columns) - 1)), index = history_subset.columns), fill_value = 1) # make slopes half of what they are (compromising between flat line and continuous slope line); the epoch column keeps slope 1 so the extrapolated point is a full epoch later
        history_subset.loc[len(history_subset)] = history_subset.loc[len(history_subset) - 1].add(other = slopes, fill_value = 0).to_list() # append the extrapolated row
    # plot values
    # loss drawn solid, accuracy dashed; train vs. validate distinguished by color
    for set_type, color in zip(("train", "validate"), colors[:2]):
        axes["loss"].plot(history_subset["epoch"], history_subset[set_type + "_loss"], color = color, linestyle = "solid", label = set_type.title())
        axes["accuracy"].plot(history_subset["epoch"], history_subset[set_type + "_accuracy"], color = color, linestyle = "dashed", label = set_type.title())
# set x-axis labels
axes["accuracy"].set_xlabel("Epoch")
axes["loss"].sharex(other = axes["accuracy"]) # share the x-axis labels
# set y-axis labels
axes["loss"].set_ylabel("Loss")
axes["accuracy"].set_ylabel("Accuracy (%)")
# set legends
handles_by_label_loss = dict(zip(*(axes["loss"].get_legend_handles_labels()[::-1]))) # for removing duplicate legend values (one plot call per span produced repeated labels; the label->handle dict keeps one handle per label)
axes["loss"].legend(handles = handles_by_label_loss.values(), labels = handles_by_label_loss.keys(), loc = "upper right")
handles_by_label_accuracy = dict(zip(*(axes["accuracy"].get_legend_handles_labels()[::-1]))) # for removing duplicate legend values
axes["accuracy"].legend(handles = handles_by_label_accuracy.values(), labels = handles_by_label_accuracy.keys(), loc = "upper right")
# set titles
axes["loss"].set_title("Learning Curve")
axes["accuracy"].set_title("Accuracy")
##################################################
# PLOT PERCENTILES PER EPOCH
##################################################
# show percentile curves only for the most recent epochs: at most 5, and never more than we have colors for
n_epochs = min(5, len(pd.unique(percentiles_history["epoch"])), len(colors))
colors = colors[:n_epochs] # subset colors to the number of epochs being drawn
percentiles_history = percentiles_history[percentiles_history["epoch"] > (max(percentiles_history["epoch"]) - n_epochs)] # keep the n_epochs most recent epochs
epochs = sorted(pd.unique(percentiles_history["epoch"])) # epoch numbers that will be displayed in the plot
# draw one percentile-vs-value curve per displayed epoch
for epoch, color in zip(epochs, colors):
    rows_for_epoch = percentiles_history[percentiles_history["epoch"] == epoch]
    axes["percentiles_history"].plot(rows_for_epoch["percentile"], rows_for_epoch["value"], color = color, linestyle = "solid", label = epoch)
# axis labels
axes["percentiles_history"].set_xlabel("Percentile")
axes["percentiles_history"].set_ylabel("Error")
# legend, gridlines, title
axes["percentiles_history"].legend(title = "Epoch", loc = "upper left")
axes["percentiles_history"].grid()
axes["percentiles_history"].set_title("Validation Data Percentiles")
##################################################
# SAVE
##################################################
# write the finished figure to disk and report where it went
fig.savefig(OUTPUT_FILEPATH, dpi = 240)
print(f"Training plot saved to {OUTPUT_FILEPATH}.")