Coverage for orcanet/utilities/visualization.py: 91%
156 statements
« prev ^ index » next coverage.py v7.2.7, created at 2024-03-28 14:22 +0000
« prev ^ index » next coverage.py v7.2.7, created at 2024-03-28 14:22 +0000
1# -*- coding: utf-8 -*-
2"""
3Visualization tools used without Keras.
4Makes performance graphs for training and validating.
5"""
6import os
7import numpy as np
9import matplotlib.pyplot as plt
10from matplotlib.backends.backend_pdf import PdfPages
class TrainValPlotter:
    """
    Plot pairs of training and validation curves into one figure.

    Instructions
    ------------
    1. Call tvp.plot_curves(train, val) one or more times on pairs of
       train/val data.
    2. After all lines are drawn, call tvp.apply_layout() exactly once
       to set proper scaling, ylims, ticks, etc.

    """

    def __init__(self):
        # Fraction of the y range added as whitespace [below, above] the curves
        self.y_lim_padding = [0.10, 0.25]
        # Every plotted coordinate is recorded here; used later by
        # apply_layout to derive x ticks and y limits automatically.
        self._xpoints_train = np.array([])
        self._xpoints_val = np.array([])
        self._ypoints_train = np.array([])
        self._ypoints_val = np.array([])

    def plot_curves(
        self,
        train_data,
        val_data=None,
        train_label="training",
        val_label="validation",
        color=None,
        smooth_sigma=None,
        tlw=0.5,
        vlw=0.5,
        vms=3,
    ):
        """
        Plot a training line and optionally a validation line.

        The data may contain nan's.

        Parameters
        ----------
        train_data : List
            X data [0] and y data [1] of the train curve. Drawn as a
            faint solid line.
        val_data : List, optional
            Optional X data [0] and y data [1] of the validation curve.
            Drawn as connected dots in the same color as the train line.
        train_label : str, optional
            Legend label of the train line.
        val_label : str, optional
            Legend label of the validation line.
        color : str, optional
            Color used for the train/val line.
        smooth_sigma : int, optional
            Apply gaussian blur to the train curve with given sigma.
        tlw : float
            Linewidth of train curve.
        vlw : float
            Linewidth of val curve.
        vms : float
            Markersize of the val curve.

        """
        if train_data is None and val_data is None:
            raise ValueError("Can not plot when no train and val data is given.")

        train_color = color
        if train_data is not None:
            epoch, y_data = train_data
            if smooth_sigma is not None:
                y_data = gaussian_smooth(y_data, smooth_sigma)

            self._xpoints_train = np.concatenate((self._xpoints_train, epoch))
            self._ypoints_train = np.concatenate((self._ypoints_train, y_data))

            train_line = plt.plot(
                epoch,
                y_data,
                color=color,
                ls="-",
                zorder=3,
                label=train_label,
                lw=tlw,
                alpha=0.5,
            )[0]
            # Remember the auto-assigned color so the val curve can match it
            train_color = train_line.get_color()

        if val_data is not None:
            self._xpoints_val = np.concatenate((self._xpoints_val, val_data[0]))
            self._ypoints_val = np.concatenate((self._ypoints_val, val_data[1]))

            x_clean, y_clean = skip_nans(val_data)
            # val plot always has the same color as the train plot
            plt.plot(
                x_clean,
                y_clean,
                color=train_color,
                marker="o",
                zorder=3,
                lw=vlw,
                markersize=vms,
                label=val_label,
            )

    def apply_layout(
        self,
        title=None,
        x_label="Epoch",
        y_label=None,
        grid=True,
        legend=True,
        x_lims=None,
        y_lims="auto",
        x_ticks="auto",
        logy=False,
    ):
        """
        Apply the given layout to the current figure.

        Good y_lims and x_ticks can be determined automatically from the
        points plotted so far.

        Parameters
        ----------
        title : str
            Title of the plot.
        x_label : str
            X label of the plot.
        y_label : str
            Y label of the plot.
        grid : bool
            If true, show a grid.
        legend : bool
            If true, show a legend.
        x_lims : List
            X limits of the data.
        y_lims : List or str
            Y limits of the data. "auto" for auto-calculation.
        x_ticks : List
            Positions of the major x ticks.
        logy : bool
            If true, make y axis log.

        """
        if logy:
            plt.yscale("log")

        if x_ticks is not None:
            if x_ticks == "auto":
                all_x_points = np.concatenate(
                    (self._xpoints_train, self._xpoints_val)
                )
                x_ticks = get_epoch_xticks(all_x_points)
            plt.xticks(x_ticks)

        if x_lims is not None:
            plt.xlim(x_lims)

        if y_lims is not None:
            if y_lims == "auto":
                y_lims = get_ylims(
                    self._ypoints_train,
                    self._ypoints_val,
                    fraction=self.y_lim_padding,
                )
            plt.ylim(y_lims)

        if legend:
            plt.legend(loc="upper right")

        plt.xlabel(x_label)
        plt.ylabel(y_label)

        if title is not None:
            plt.title(title).set_position([0.5, 1.04])

        if grid:
            plt.grid(True, zorder=0, linestyle="dotted")
def gaussian_smooth(y, sigma, truncate=4):
    """
    Smooth a 1d ndarray with a gaussian filter.

    The kernel is normalized to sum to exactly 1, so the filter has unit
    gain: a constant input comes out unchanged. (Without normalization,
    the sampled + truncated kernel sums only approximately to 1, which
    rescales the whole output slightly.)

    Parameters
    ----------
    y : ndarray
        1d input data.
    sigma : int
        Standard deviation of the gaussian kernel.
    truncate : int
        The kernel is cut off at this many sigmas on each side.

    Returns
    -------
    ndarray
        The smoothed data, same length as the input.

    """
    # Sample the gaussian on integer offsets; kernel_width = 2 * sigma * truncate + 1
    kernel_x = np.arange(-truncate * sigma, truncate * sigma + 1)
    kernel = np.exp(-np.power(kernel_x, 2.0) / (2 * np.power(float(sigma), 2.0)))
    # Normalize to unit sum (the 1/(sqrt(2 pi) sigma) pdf factor cancels here)
    kernel /= kernel.sum()
    # Edge-pad so the 'valid' convolution returns the original length
    y = np.pad(np.asarray(y), int(len(kernel) / 2), "edge")
    blurred = np.convolve(y, kernel, "valid")
    return blurred
204def _gauss(x, mu=0, sigma=1):
205 return (1 / (np.sqrt(2 * np.pi) * sigma)) * np.exp(
206 -np.power(x - mu, 2.0) / (2 * np.power(sigma, 2.0))
207 )
def plot_history(
    train_data,
    val_data=None,
    train_label="training",
    val_label="validation",
    color=None,
    **kwargs
):
    """
    Draw train/val curves into a single plot.

    Kept for backward compatibility; the actual work is done by
    TrainValPlotter.
    """
    plotter = TrainValPlotter()
    plotter.plot_curves(
        train_data,
        val_data,
        train_label=train_label,
        val_label=val_label,
        color=color,
    )
    plotter.apply_layout(**kwargs)
def skip_nans(data):
    """
    Drop points whose y value is nan, so that all dots are connected.

    Parameters
    ----------
    data : List
        Contains x and y data as ndarrays. The y values may contain nans.

    Returns
    -------
    data_clean : tuple
        x and y data as ndarrays, with the y=nan points removed.

    """
    x_vals, y_vals = data[0], data[1]
    keep = np.logical_not(np.isnan(y_vals))
    return x_vals[keep], y_vals[keep]
def get_ylims(y_points_train, y_points_val=None, fraction=0.25):
    """
    Get the y limits for the summary plot.

    For the training data, limits are calculated while ignoring data points
    which are far from the median (in terms of the median distance
    from the median).
    This is because there are outliers sometimes in the training data,
    especially early on in the training.

    Parameters
    ----------
    y_points_train : ndarray
        y data of the train curve.
    y_points_val : ndarray or None
        y data of the validation curve.
    fraction : float or List
        How much whitespace of the total y range is added below and above
        the lines.

    Returns
    -------
    y_lims : tuple
        Minimum, maximum of the data.

    """
    assert not (
        y_points_train is None and y_points_val is None
    ), "train and val data are None"

    def reject_outliers(data, threshold):
        # Ignore points further than `threshold` median-distances away
        # from the median when determining the limits.
        d = np.abs(data - np.median(data))
        mdev = np.median(d)
        s = d / mdev if mdev else 0.0
        no_outliers = data[s < threshold]
        lims = np.amin(no_outliers), np.amax(no_outliers)
        return lims

    mins, maxs = [], []
    if y_points_train is not None and len(y_points_train) != 0:
        y_train = y_points_train[~np.isnan(y_points_train)]
        y_lims_train = reject_outliers(y_train, 5)
        mins.append(y_lims_train[0])
        maxs.append(y_lims_train[1])

    if y_points_val is not None and len(y_points_val) != 0:
        y_val = y_points_val[~np.isnan(y_points_val)]
        if len(y_val) == 1:
            y_lim_val = y_val[0], y_val[0]
        else:
            y_lim_val = np.amin(y_val), np.amax(y_val)
        mins.append(y_lim_val[0])
        maxs.append(y_lim_val[1])

    if len(mins) == 1:
        y_lims = (mins[0], maxs[0])
    else:
        y_lims = np.amin(mins), np.amax(maxs)

    if y_lims[0] == y_lims[1]:
        # Degenerate range (all points equal): pad relative to the value's
        # magnitude. abs() keeps the padding positive for negative values;
        # without it the limits would come out inverted.
        y_range = 0.1 * abs(y_lims[0])
    else:
        y_range = y_lims[1] - y_lims[0]

    try:
        fraction = float(fraction)
        padding = [fraction, fraction]
    except TypeError:
        # is a [below, above] list
        padding = fraction

    if padding != [0.0, 0.0]:
        y_lims = (y_lims[0] - padding[0] * y_range, y_lims[1] + padding[1] * y_range)

    return y_lims
def get_epoch_xticks(x_points):
    """
    Calculate the x ticks for the train and validation summary plot.

    Roughly one tick per epoch; fewer ticks the more epochs there are.

    Parameters
    ----------
    x_points : List
        The x coordinates of all plotted points.

    Returns
    -------
    x_ticks_major : numpy.ndarray
        Array containing the tick positions.

    Raises
    ------
    ValueError
        If no x coordinates are given.

    """
    if len(x_points) == 0:
        raise ValueError("x-coordinates are empty!")

    minimum, maximum = np.amin(x_points), np.amax(x_points)
    if maximum - minimum > 0.5:
        # Longer trainings: integer ticks, thinned out for many epochs
        first, last = np.floor(minimum), np.ceil(maximum)
        step = 1 + np.floor((last - first) / 20.0)
        x_ticks_major = np.arange(first, last + step, step)
    else:
        # Early peek into the training: six evenly spaced ticks
        first = np.floor(minimum)
        x_ticks_major = np.linspace(first, maximum + minimum - first, 6)
    return x_ticks_major
def update_summary_plot(orga):
    """
    Plot and save all metrics of the given validation- and train-data
    into a pdf file, each metric in its own plot.

    If metric pairs of a variable and its error are found (e.g. e_loss
    and e_err_loss), they will have the same color and appear back to
    back in the plot.

    Parameters
    ----------
    orga : orcanet.core.Organizer
        Contains all the configurable options in the OrcaNet scripts.

    """
    plt.ioff()
    pdf_name = orga.io.get_subfolder("plots", create=True) + "/summary_plot.pdf"

    # Extract the names of the metrics and sort them so that err metrics
    # come directly after their variable
    all_metrics = sort_metrics(orga.history.get_metrics())
    # Custom color cycle; ref. personal.sron.nl/~pault/
    colors = [
        "#000000",
        "#332288",
        "#88CCEE",
        "#44AA99",
        "#117733",
        "#999933",
        "#DDCC77",
        "#CC6677",
        "#882255",
        "#AA4499",
        "#661100",
        "#6699CC",
        "#AA4466",
        "#4477AA",
    ]
    color_counter = 0
    with PdfPages(pdf_name) as pdf:
        for metric_no, metric in enumerate(all_metrics):
            # If this metric is the err metric of the previous variable,
            # reuse that variable's color. The metric_no > 0 guard prevents
            # the first metric from being compared against the LAST one
            # via python's negative indexing.
            if metric_no > 0 and all_metrics[metric_no - 1] == metric.replace(
                "_err", ""
            ):
                color_counter -= 1
            orga.history.plot_metric(metric, color=colors[color_counter % len(colors)])
            plt.suptitle(os.path.basename(os.path.abspath(orga.cfg.output_folder)))
            color_counter += 1
            pdf.savefig()
            plt.clf()

        orga.history.plot_lr()
        pdf.savefig()
        plt.close()
def sort_metrics(metric_names):
    """
    Sort a list of metrics, so that errors are right after their variable.
    The format of the metric names have to be e.g. e_loss and e_err_loss
    for this to work.

    Example
    ----------
    >>> sort_metrics( ['e_loss', 'loss', 'e_err_loss', 'dx_err_loss'] )
    ['e_loss', 'e_err_loss', 'loss', 'dx_err_loss']

    Parameters
    ----------
    metric_names : List
        List of metric names.

    Returns
    -------
    sorted_metrics : List
        List of sorted metric names with the same length as the input.

    """
    sorted_metrics = [0] * len(metric_names)
    pos = 0
    for name in metric_names:
        if "err_" in name:
            # Err metrics with a matching base metric get inserted next to
            # it below; only orphaned err metrics are placed here.
            if name.replace("err_", "") not in metric_names:
                sorted_metrics[pos] = name
                pos += 1
        else:
            sorted_metrics[pos] = name
            pos += 1
            # Pull the matching err metric (if any) in right behind
            partner = name.split("_loss")[0] + "_err_loss"
            if partner in metric_names:
                sorted_metrics[pos] = partner
                pos += 1

    assert 0 not in sorted_metrics, (
        "Something went wrong with the sorting of "
        "metrics! Given was {}, output was "
        "{}. ".format(metric_names, sorted_metrics)
    )

    return sorted_metrics