Coverage for orcanet/history.py: 93%
134 statements
« prev ^ index » next coverage.py v7.2.7, created at 2024-03-28 14:22 +0000
« prev ^ index » next coverage.py v7.2.7, created at 2024-03-28 14:22 +0000
1import numpy as np
2import os
3from orcanet.utilities.visualization import plot_history
4from orcanet.in_out import get_subfolder
class HistoryHandler:
    """
    For reading and plotting data from summary and train log files.

    Parameters
    ----------
    main_folder : str
        Path to the output folder of the model, which contains the
        summary.txt file and the train_log subfolder.

    """

    def __init__(self, main_folder):
        self.main_folder = main_folder
        # Name of the summary log file inside main_folder.
        self.summary_filename = "summary.txt"

    @property
    def summary_file(self):
        """str: Full path to the summary file of this model."""
        # os.path.join yields the same path for "folder" and "folder/",
        # replacing the manual trailing-slash handling.
        return os.path.join(self.main_folder, self.summary_filename)

    @property
    def train_log_folder(self):
        """str: Path to the train_log subfolder of this model."""
        return get_subfolder(self.main_folder, "train_log")

    def plot_metric(self, metric_name, **kwargs):
        """
        Plot the training and validation history of a metric.

        This will read out data from the summary file, as well as
        all training log files, and plot them over the epoch.

        Parameters
        ----------
        metric_name : str
            Name of the metric to be plotted over the epoch. This name is what
            was written in the head line of the summary.txt file, except without
            the train_ or val_ prefix.
        kwargs
            Keyword arguments for the plot_history function.

        Raises
        ------
        ValueError
            If the metric name is unknown, or if the training log files
            contain no data.

        """
        summary_data = self.get_summary_data()
        full_train_data = self.get_train_data()
        summary_label = "val_" + metric_name

        if metric_name not in full_train_data.dtype.names:
            raise ValueError(
                "Train log metric name {} unknown, must be one of {}".format(
                    metric_name, self.get_metrics()
                )
            )
        if summary_label not in summary_data.dtype.names:
            raise ValueError(
                "Summary metric name {} unknown, must be one of {}".format(
                    summary_label, self.get_metrics()
                )
            )

        if summary_data["Epoch"].shape == (0,):
            # When no lines are present in the summary.txt file.
            val_data = None
        else:
            val_data = [summary_data["Epoch"], summary_data[summary_label]]

        if full_train_data["Batch_float"].shape == (0,):
            # When no lines are present
            raise ValueError(
                "Can not make summary plot: Training log files " "contain no data!"
            )
        train_data = [full_train_data["Batch_float"], full_train_data[metric_name]]

        # If no validation has been done yet, hide the val curve.
        # Fix: guard against val_data being None (empty summary file);
        # np.isnan(None) would raise a TypeError otherwise.
        if val_data is not None and np.all(np.isnan(val_data[1])):
            val_data = None

        if "y_label" not in kwargs:
            kwargs["y_label"] = metric_name

        plot_history(train_data, val_data, **kwargs)

    def plot_lr(self, **kwargs):
        """
        Plot the learning rate over the epochs.

        Parameters
        ----------
        kwargs
            Keyword arguments for the plot_history function.

        Returns
        -------
        fig : matplotlib.figure.Figure
            The plot.

        """
        summary_data = self.get_summary_data()

        epoch = summary_data["Epoch"]
        lr = summary_data["LR"]
        # plot learning rate like val data (connected dots)
        val_data = (epoch, lr)

        if "y_label" not in kwargs:
            kwargs["y_label"] = "Learning rate"
        if "legend" not in kwargs:
            kwargs["legend"] = False

        # Fix: pass the result back to the caller, as promised by the
        # docstring (previously the figure was silently dropped).
        return plot_history(
            train_data=None, val_data=val_data, logy=True, y_lims=None, **kwargs
        )

    def get_metrics(self):
        """
        Get the name of the metrics from the first line in the summary file.

        This will be the actual name of the metric, i.e. "loss" and not
        "train_loss" or "val_loss".

        Returns
        -------
        all_metrics : List
            A list of the metrics.

        """
        summary_data = self.get_summary_data()
        all_metrics = []
        for keyword in summary_data.dtype.names:
            if keyword == "Epoch" or keyword == "LR":
                continue
            # Strip the train_/val_ prefix to get the bare metric name.
            if "train_" in keyword:
                keyword = keyword.split("train_")[-1]
            else:
                keyword = keyword.split("val_")[-1]
            if keyword not in all_metrics:
                all_metrics.append(keyword)
        return all_metrics

    def get_summary_data(self):
        """
        Read out the summary file in the output folder.

        Returns
        -------
        summary_data : ndarray
            Numpy structured array with the column names as datatypes.
            Its shape is the number of lines with data.

        """
        summary_data = self._load_txt(self.summary_file)
        if summary_data.shape == ():
            # When only one line is present, genfromtxt returns a 0-d
            # array; normalize to shape (1,) so callers can always index.
            summary_data = summary_data.reshape(
                1,
            )
        return summary_data

    def get_best_epoch_info(self, metric="val_loss", mini=True):
        """
        Get the line in the summary file where the given metric is best, i.e.
        either minimal or maximal.

        Parameters
        ----------
        metric : str
            Which metric to look up.
        mini : bool
            If true, look up the minimum. Else the maximum.

        Returns
        -------
        best_line : numpy.void
            The complete summary line with the best metric value.

        Raises
        ------
        ValueError
            If there is no best line (e.g. no validation has been done).

        """
        summary_data = self.get_summary_data()
        metric_data = summary_data[metric]

        if all(np.isnan(metric_data)):
            raise ValueError("Can not find best epoch in summary.txt")

        if mini:
            opt_loss = np.nanmin(metric_data)
        else:
            opt_loss = np.nanmax(metric_data)

        best_index = np.where(metric_data == opt_loss)[0]
        # if multiple epochs with same loss, take first
        # (np.where returns indices in ascending order)
        best_index = best_index[0]
        best_line = summary_data[best_index]
        return best_line

    def get_best_epoch_fileno(self, metric="val_loss", mini=True):
        """
        Get the epoch/fileno tuple where the given metric is smallest.
        """
        best_line = self.get_best_epoch_info(metric=metric, mini=mini)
        best_epoch_float = best_line["Epoch"]
        epoch, fileno = self._transform_epoch(best_epoch_float)
        return epoch, fileno

    def _transform_epoch(self, epoch_float):
        """
        Transform the epoch_float read from a file to an (epoch, fileno) tuple.

        (By just counting the number of lines in the given epoch).
        TODO Hacky, (epoch, filno) should probably be written to the summary.
        """
        summary_data = self.get_summary_data()

        # - 1e-8 guards against float rounding at integer boundaries
        epoch = int(np.floor(epoch_float - 1e-8))
        # all lines in the epoch of epoch_float
        indices = np.where(np.floor(summary_data["Epoch"] - 1e-8) == epoch)[0]
        lines = summary_data[indices]
        # Fix: take the first match explicitly with [0][0]; int() on a
        # size-1 ndarray is deprecated (NumPy >= 1.25) and errors later.
        fileno = int(np.where(lines["Epoch"] == epoch_float)[0][0]) + 1
        epoch += 1
        return epoch, fileno

    def get_column_names(self):
        """
        Get the str in the first line in each column.

        Returns
        -------
        tuple : column_names
            The names in the same order as they appear in the summary.txt.

        """
        summary_data = self.get_summary_data()
        column_names = summary_data.dtype.names
        return column_names

    def get_train_data(self):
        """
        Read out all training logfiles in the output folder.

        Read out the data from all training log files in the train_log
        folder, which is in the same directory as the summary.txt file,
        and concatenate them sorted by (epoch, fileno).

        Returns
        -------
        full_train_data : numpy.ndarray
            Structured array containing the data from all train log files.
            Its shape is the number of lines with data.

        Raises
        ------
        OSError
            If no non-empty train log files are found.

        """
        # list of all files in the train_log folder of this model
        files = os.listdir(self.train_log_folder)
        train_file_data = []
        for file in files:
            if not (file.startswith("log_epoch_") and file.endswith(".txt")):
                continue
            filepath = os.path.join(self.train_log_folder, file)
            # skip files without any content
            if os.path.getsize(filepath) == 0:
                continue

            # file is sth like "log_epoch_1_file_2.txt", extract epoch & fileno:
            epoch, file_no = [int(file.split(".")[0].split("_")[i]) for i in [2, 4]]
            train_file_data.append(((epoch, file_no), self._load_txt(filepath)))

        if len(train_file_data) == 0:
            raise OSError(f"No train files found in {self.train_log_folder}!")

        # sort so that earlier epochs come first. Fix: sort only on the
        # (epoch, fileno) key; a plain sort() would fall back to comparing
        # the ndarrays themselves on identical keys.
        train_file_data.sort(key=lambda entry: entry[0])
        # Fix: one concatenate instead of np.append in a loop (quadratic).
        # atleast_1d promotes 0-d arrays (single-line log files) to (1,),
        # so the result always has shape (number of lines,).
        full_train_data = np.concatenate(
            [np.atleast_1d(file_data) for _, file_data in train_file_data]
        )
        return full_train_data

    def get_state(self):
        """
        Get the state of a training.

        For every line in the summary logfile, get a dict with the epoch
        as a float, and is_trained and is_validated bools.

        Returns
        -------
        state_dicts : List
            List of dicts.

        Raises
        ------
        NameError
            If the summary file contains a column that is neither Epoch,
            LR, nor prefixed with val_ or train_.

        """
        summary_data = self.get_summary_data()
        state_dicts = []
        names = summary_data.dtype.names

        for line in summary_data:
            val_losses, train_losses = {}, {}
            for name in names:
                if name.startswith("val_"):
                    val_losses[name] = line[name]
                elif name.startswith("train_"):
                    train_losses[name] = line[name]
                elif name not in ["Epoch", "LR"]:
                    raise NameError(
                        "Invalid summary file: Invalid column name {}: must be "
                        "either Epoch, LR, or start with val_ or train_".format(name)
                    )
            # if theres any not-nan entry, consider it completed
            is_trained = any(~np.isnan(tuple(train_losses.values())))
            is_validated = any(~np.isnan(tuple(val_losses.values())))

            state_dicts.append(
                {
                    "epoch": line["Epoch"],
                    "is_trained": is_trained,
                    "is_validated": is_validated,
                }
            )

        return state_dicts

    @staticmethod
    def _load_txt(filepath):
        """
        Load a "|"-delimited log file into a numpy structured array.

        Column names are taken from the header line; "n/a" cells and
        inf values both become np.nan.
        """
        # TODO suboptimal that n/a gets replaced by np.nan, because this
        # means that legitimate, not available cells can not be distinguished
        # from failed 'nan' metric values produced by training.
        file_data = np.genfromtxt(
            filepath,
            names=True,
            delimiter="|",
            autostrip=True,
            comments="--",
            missing_values="n/a",
            filling_values=np.nan,
        )
        # replace inf with nan so it can be plotted
        for column_name in file_data.dtype.names:
            column = file_data[column_name]
            column[np.isinf(column)] = np.nan
        return file_data