Coverage for orcanet/utilities/summarize_training.py: 0%
143 statements
« prev ^ index » next coverage.py v7.2.7, created at 2024-03-28 14:22 +0000
« prev ^ index » next coverage.py v7.2.7, created at 2024-03-28 14:22 +0000
1import matplotlib.pyplot as plt
2import argparse
3import warnings
4from orcanet.history import HistoryHandler
5from orcanet.utilities.visualization import TrainValPlotter
class Summarizer:
    """
    Summarize one or more trainings by giving their orcanet folder(s).

    - Plot the training and validation curves in a single plot and show them
    - Print info about the best and worst epochs

    Parameters
    ----------
    folders : str or List, optional
        Path to an orcanet folder, or to multiple folders as a list.
        [default: CWD].
    metric : str or List, optional
        The metric to plot [default: 'loss'].
        If it's a list: Same length as folders. Plot a different metric for
        each folder.
    smooth : int, optional
        Apply gaussian blur to the train curve with given sigma.
    labels : str or List, optional
        Labels for each folder, exactly one per folder
        [default: the folder paths].
    noplot : bool
        Don't plot the train/val curves [default: False].
    width : float
        Scaling of the width of the curves and the marker size [default: 1].
    verbose : bool
        Print progress, warnings and the best/worst epochs [default: True].

    """

    def __init__(
        self,
        folders,
        metric="loss",
        smooth=None,
        labels=None,
        noplot=False,
        width=1.0,
        verbose=True,
    ):
        if not folders:
            folders = ["./"]
        elif isinstance(folders, str):
            folders = [folders]
        self.folders = list(folders)

        if not metric:
            # e.g. "--metric" given on the command line without a value:
            # fall back to the documented default instead of failing.
            metric = ["loss"]
        elif isinstance(metric, str):
            metric = [metric]
        metric = list(metric)
        if len(metric) == 1:
            self.metrics = metric * len(self.folders)
            self._unique_metrics = False
        elif len(metric) == len(self.folders):
            self.metrics = metric
            self._unique_metrics = True
        else:
            raise ValueError("Need to give exactly one metric per folder!")

        if not labels:
            # None or an empty list: label every curve with its folder path
            self.labels = list(self.folders)
        else:
            # A single string is one label, not a sequence of characters
            if isinstance(labels, str):
                labels = [labels]
            labels = list(labels)
            if len(labels) != len(self.folders):
                raise ValueError("Need to give exactly one label per folder!")
            self.labels = labels

        self.verbose = verbose
        self.smooth = smooth
        self.noplot = noplot
        self.width = width
        self._tvp = None

    def summarize(self, show=True):
        """
        Read the stats of all folders, print them and (optionally) plot them.

        Folders that can not be summarized are skipped; a warning is printed
        if ``verbose`` is set.

        Parameters
        ----------
        show : bool
            Call ``plt.show()`` after plotting [default: True]. Has no effect
            if ``noplot`` is set.

        """
        if not self.noplot:
            self._tvp = TrainValPlotter()

        min_stats, max_stats = [], []
        if self.verbose:
            print("Reading stats of {} trainings...".format(len(self.folders)))
        for folder_no, folder in enumerate(self.folders):
            try:
                min_stat, max_stat = self._summarize_folder(folder_no)
            except Exception as e:
                # Best effort: one broken folder must not abort the summary.
                if self.verbose:
                    print(f"Warning: Can not summarize {folder}, skipping... ({e})")
                continue
            if min_stat is not None:
                min_stats.append(min_stat)
            if max_stat is not None:
                max_stats.append(max_stat)

        if self._unique_metrics:
            column_title = y_label = "combined metrics"
        else:
            column_title, y_label = self._full_metrics[0], self._metric_names[0]

        if self.verbose:
            if min_stats:
                self._print_stats("Minimum", sorted(min_stats), column_title)
            if max_stats:
                self._print_stats(
                    "Maximum", sorted(max_stats, reverse=True), column_title
                )

        if not self.noplot:
            self._tvp.apply_layout(
                x_label="Epoch",
                y_label=y_label,
                grid=True,
                legend=True,
            )
            if show:
                plt.show()

    @staticmethod
    def _print_stats(title, stats, column_title):
        """Print a numbered table of ``[value, label, epoch]`` entries."""
        print("\n{}\n-------".format(title))
        print("{} \t{}\t{}\t{}".format(" ", "Epoch", column_title, "name"))
        for i, (value, label, epoch) in enumerate(stats, 1):
            print("{} | \t{}\t{}\t{}".format(i, epoch, value, label))

    @property
    def _metric_names(self):
        """Metrics with a leading train_/val_ prefix removed, e.g. [loss, ...]"""
        metric_names = []
        for metric in self.metrics:
            if metric.startswith("train_"):
                metric = metric[len("train_"):]
            elif metric.startswith("val_"):
                metric = metric[len("val_"):]
            metric_names.append(metric)
        return metric_names

    @property
    def _full_metrics(self):
        """Metrics with a train_/val_ prefix (val_ if none given), e.g. [val_loss, ...]"""
        return [
            metric if metric.startswith(("train_", "val_")) else "val_" + metric
            for metric in self.metrics
        ]

    def _summarize_folder(self, folder_no):
        """
        Read the stats of one folder and plot its curves (unless ``noplot``).

        Returns
        -------
        min_stat, max_stat : List or None
            ``[value, label, epoch]`` of the epoch where the summary metric is
            lowest / highest, or None if the summary file could not be read.

        Errors while reading the train data are not caught here.

        """
        label = self.labels[folder_no]
        folder = self.folders[folder_no]

        hist = HistoryHandler(folder)
        val_data, min_stat, max_stat = None, None, None
        # read data from summary file
        try:
            smry_met_name = self._full_metrics[folder_no]
            max_line = hist.get_best_epoch_info(metric=smry_met_name, mini=False)
            min_line = hist.get_best_epoch_info(metric=smry_met_name, mini=True)
            min_stat = [min_line[smry_met_name], label, min_line["Epoch"]]
            max_stat = [max_line[smry_met_name], label, max_line["Epoch"]]

            summary_data = hist.get_summary_data()
            val_data = [summary_data["Epoch"], summary_data[smry_met_name]]
        except OSError:
            if self.verbose:
                print(f"Warning: No summary file found for {folder}")

        except ValueError as e:
            if self.verbose:
                print(f"Error reading summary file {hist.summary_file} ({e})")

        # read data from training files
        full_train_data = hist.get_train_data()
        train_data = [
            full_train_data["Batch_float"],
            full_train_data[self._metric_names[folder_no]],
        ]

        if not self.noplot:
            if len(self.labels) == 1:
                train_label, val_label = "training", "validation"
            elif val_data is None:
                # no validation curve: put the label on the train curve
                train_label, val_label = label, None
            else:
                train_label, val_label = None, label

            self._tvp.plot_curves(
                train_data=train_data,
                val_data=val_data,
                train_label=train_label,
                val_label=val_label,
                smooth_sigma=self.smooth,
                tlw=0.5 * self.width,
                vlw=0.5 * self.width,
                vms=3 * self.width ** 0.5,
            )
        return min_stat, max_stat

    def summarize_dirs(self):
        """
        Get the best and worst epochs of all given folders as a dict.

        Folders whose summary file can not be read are skipped with a warning.

        Returns
        -------
        minima : dict
            Keys : Name of folder.
            Values : [Epoch, metric] of where the metric is lowest.
        maxima : dict
            As above, but for where the metric is highest.

        """
        minima, maxima = {}, {}
        for folder_no, folder in enumerate(self.folders):
            hist = HistoryHandler(folder)
            smry_met_name = self._full_metrics[folder_no]
            try:
                max_line = hist.get_best_epoch_info(metric=smry_met_name, mini=False)
                min_line = hist.get_best_epoch_info(metric=smry_met_name, mini=True)
            except (OSError, ValueError) as e:
                # ValueError: unreadable summary file, handled like a missing
                # one (consistent with _summarize_folder).
                warnings.warn(str(e))
                continue

            minima[folder] = [min_line["Epoch"], min_line[smry_met_name]]
            maxima[folder] = [max_line["Epoch"], max_line[smry_met_name]]
        return minima, maxima
def summarize(**kwargs):
    """
    Run a Summarizer built from the given keyword arguments.

    Arguments whose value is None are dropped, so that the Summarizer's own
    defaults apply to them (e.g. options not given on the command line).
    """
    given = {key: value for key, value in kwargs.items() if value is not None}
    Summarizer(**given).summarize()
def get_parser():
    """
    Build the command line parser for summarizing orcanet trainings.

    List-valued options use ``nargs="+"`` so that giving e.g. ``--metric``
    without any value is rejected by argparse with a clear message, instead
    of yielding an empty list the Summarizer can not use. Options that are
    not given default to None.
    """
    parser = argparse.ArgumentParser(
        description=str(Summarizer.__doc__),
        formatter_class=argparse.RawTextHelpFormatter,
    )
    parser.add_argument(
        "folders",
        type=str,
        nargs="*",
        help="Orcanet folder(s) to summarize [default: CWD].",
    )
    parser.add_argument(
        "--metric",
        type=str,
        nargs="+",
        help="Metric to plot, or one metric per folder [default: loss].",
    )
    parser.add_argument(
        "--smooth",
        nargs="?",
        type=int,
        help="Sigma of the gaussian blur applied to the train curve.",
    )
    parser.add_argument(
        "--width",
        nargs="?",
        type=float,
        help="Scaling of the curve width and marker size [default: 1].",
    )
    parser.add_argument(
        "--labels",
        nargs="+",
        type=str,
        help="One label per folder [default: the folder paths].",
    )
    parser.add_argument(
        "--noplot",
        action="store_true",
        help="Don't plot the train/val curves.",
    )
    return parser
def main():
    """Command line entry point: parse the arguments and summarize."""
    summarize(**vars(get_parser().parse_args()))


if __name__ == "__main__":
    main()