Source code for orcanet.utilities.visualization

# -*- coding: utf-8 -*-
"""
Visualization tools that are used without Keras.
Makes performance graphs for training and validation.
"""
import os
import numpy as np

import matplotlib.pyplot as plt
from matplotlib.backends.backend_pdf import PdfPages

class TrainValPlotter:
    """
    Class for plotting train/val curves.

    Instructions
    ------------
    1. Use tvp.plot_curves(train, val) once or more on pairs of train/val data.
    2. When all lines are plotted, use tvp.apply_layout() once for proper
       scaling, ylims, etc.

    """

    def __init__(self):
        # White space added below and above points, as fractions of the
        # y range (handed to get_ylims as ``fraction`` in apply_layout)
        self.y_lim_padding = [0.10, 0.25]
        # Store all plotted points for setting x/y lims
        self._xpoints_train = np.array([])
        self._xpoints_val = np.array([])
        self._ypoints_train = np.array([])
        self._ypoints_val = np.array([])

    def plot_curves(
        self,
        train_data,
        val_data=None,
        train_label="training",
        val_label="validation",
        color=None,
        smooth_sigma=None,
        tlw=0.5,
        vlw=0.5,
        vms=3,
    ):
        """
        Plot a training and optionally a validation line.

        The data can contain nan's.

        Parameters
        ----------
        train_data : List or None
            X data [0] and y data [1] of the train curve. Will be plotted
            as a faint (alpha=0.5) solid line.
        val_data : List, optional
            Optional X data [0] and y data [1] of the validation curve.
            Will be plotted as dots connected by a line, in the same color
            as the train curve. Points with y=nan are skipped.
        train_label : str, optional
            Label for the train line in the legend.
        val_label : str, optional
            Label for the validation line in the legend.
        color : str, optional
            Color used for the train/val line.
        smooth_sigma : int, optional
            Apply gaussian blur to the train curve with given sigma.
        tlw : float
            Linewidth of train curve.
        vlw : float
            Linewidth of val curve.
        vms : float
            Markersize of the val curve.

        Raises
        ------
        ValueError
            If both train_data and val_data are None.

        """
        if train_data is None and val_data is None:
            raise ValueError("Can not plot when no train and val data is given.")

        if train_data is not None:
            epoch, y_data = train_data
            if smooth_sigma is not None:
                y_data = gaussian_smooth(y_data, smooth_sigma)

            # Remember the points for the automatic ticks/limits later on
            self._xpoints_train = np.concatenate((self._xpoints_train, epoch))
            self._ypoints_train = np.concatenate((self._ypoints_train, y_data))

            train_plot = plt.plot(
                epoch,
                y_data,
                color=color,
                ls="-",
                zorder=3,
                label=train_label,
                lw=tlw,
                alpha=0.5,
            )
            # If color was None, matplotlib picked one; reuse it for val
            train_color = train_plot[0].get_color()
        else:
            train_color = color

        if val_data is not None:
            self._xpoints_val = np.concatenate((self._xpoints_val, val_data[0]))
            self._ypoints_val = np.concatenate((self._ypoints_val, val_data[1]))
            # Drop nan points so that the remaining dots are connected
            val_data_clean = skip_nans(val_data)
            # val plot always has the same color as the train plot
            plt.plot(
                val_data_clean[0],
                val_data_clean[1],
                color=train_color,
                marker="o",
                zorder=3,
                lw=vlw,
                markersize=vms,
                label=val_label,
            )

    def apply_layout(
        self,
        title=None,
        x_label="Epoch",
        y_label=None,
        grid=True,
        legend=True,
        x_lims=None,
        y_lims="auto",
        x_ticks="auto",
        logy=False,
    ):
        """
        Apply given layout.

        Can calculate good y_lims and x_ticks automatically from all
        points plotted so far with plot_curves.

        Parameters
        ----------
        title : str
            Title of the plot.
        x_label : str
            X label of the plot.
        y_label : str
            Y label of the plot.
        grid : bool
            If true, show a grid.
        legend : bool
            If true, show a legend.
        x_lims : List
            X limits of the data.
        y_lims : List or str
            Y limits of the data. "auto" for auto-calculation,
            None to leave them untouched.
        x_ticks : List or str
            Positions of the major x ticks. "auto" for auto-calculation,
            None to leave them untouched.
        logy : bool
            If true, make y axis log.

        """
        if logy:
            plt.yscale("log")

        if x_ticks is not None:
            # Only compare strings with "auto": an ndarray of ticks would be
            # compared elementwise, and its truth value is ambiguous
            if isinstance(x_ticks, str) and x_ticks == "auto":
                all_x_points = np.concatenate((self._xpoints_train, self._xpoints_val))
                x_ticks = get_epoch_xticks(all_x_points)
            plt.xticks(x_ticks)

        if x_lims is not None:
            plt.xlim(x_lims)

        if y_lims is not None:
            if isinstance(y_lims, str) and y_lims == "auto":
                y_lims = get_ylims(
                    self._ypoints_train,
                    self._ypoints_val,
                    fraction=self.y_lim_padding,
                )
            plt.ylim(y_lims)

        if legend:
            plt.legend(loc="upper right")
        plt.xlabel(x_label)
        plt.ylabel(y_label)
        if title is not None:
            title = plt.title(title)
            # Lift the title slightly above the axes
            title.set_position([0.5, 1.04])
        if grid:
            plt.grid(True, zorder=0, linestyle="dotted")
def gaussian_smooth(y, sigma, truncate=4):
    """
    Smooth a 1d array with a normalized gaussian filter.

    The data is padded by repeating its edge values, so the output has
    the same length as the input.

    Parameters
    ----------
    y : array_like
        1d data to smooth.
    sigma : float
        Standard deviation of the gaussian kernel, in samples. If it is
        not positive, the data is returned unsmoothed (as a float array).
    truncate : float
        The kernel is cut off at this many standard deviations.

    Returns
    -------
    blurred : numpy.ndarray
        The smoothed data, with the same length as y.

    """
    y = np.asarray(y, dtype=float)
    if y.size == 0 or sigma <= 0:
        return y.copy()
    # An integer half-width keeps the kernel symmetric around 0 with odd
    # length 2 * radius + 1, also for non-integer sigma or truncate, so the
    # padded "valid" convolution returns exactly len(y) points.
    radius = int(truncate * sigma + 0.5)
    kernel_x = np.arange(-radius, radius + 1)
    kernel = np.exp(-0.5 * (kernel_x / sigma) ** 2)
    # Normalize, so that smoothing does not rescale the data (matters for
    # small sigma, where the sampled density does not sum to 1)
    kernel /= kernel.sum()
    y_padded = np.pad(y, radius, mode="edge")
    return np.convolve(y_padded, kernel, mode="valid")
def _gauss(x, mu=0, sigma=1): return (1 / (np.sqrt(2 * np.pi) * sigma)) * np.exp( -np.power(x - mu, 2.0) / (2 * np.power(sigma, 2.0)) )
def plot_history(
    train_data,
    val_data=None,
    train_label="training",
    val_label="validation",
    color=None,
    **kwargs
):
    """
    Plot the train/val curves in a single plot.

    Kept for backward compatibility; the functionality now lives in
    TrainValPlotter. The curve arguments go to
    TrainValPlotter.plot_curves, and all extra keyword arguments go to
    TrainValPlotter.apply_layout.
    """
    plotter = TrainValPlotter()
    curve_options = dict(train_label=train_label, val_label=val_label, color=color)
    plotter.plot_curves(train_data, val_data, **curve_options)
    plotter.apply_layout(**kwargs)
def skip_nans(data):
    """
    Skip over nan values, so that all dots are connected.

    Parameters
    ----------
    data : List
        Contains x and y data (ndarrays or sequences convertible to
        ndarrays). The y values may contain nans.

    Returns
    -------
    data_clean : tuple
        Contains x and y data as ndarrays. Points with y=nan are skipped.

    """
    # asarray so that plain lists work too (list[bool_mask] would fail)
    x_data = np.asarray(data[0])
    y_data = np.asarray(data[1])
    not_nan = ~np.isnan(y_data)
    return x_data[not_nan], y_data[not_nan]
def get_ylims(y_points_train, y_points_val=None, fraction=0.25):
    """
    Get the y limits for the summary plot.

    For the training data, limits are calculated while ignoring data points
    which are far from the median (in terms of the median distance
    from the median).
    This is because there are outliers sometimes in the training data,
    especially early on in the training.
    Nan values are ignored.

    Parameters
    ----------
    y_points_train : List or None
        y data of the train curve.
    y_points_val : List or None
        Y data of the validation curve.
    fraction : float or List
        How much whitespace of the total y range is added below and above
        the lines. A single number is used for both sides, a pair gives
        (below, above).

    Returns
    -------
    y_lims : tuple
        Minimum, maximum of the data, including the padding.

    Raises
    ------
    ValueError
        If neither train nor val data contain any non-nan values.

    """
    assert not (
        y_points_train is None and y_points_val is None
    ), "train and val data are None"

    def reject_outliers(data, threshold):
        # Distance from the median, in units of the median distance
        d = np.abs(data - np.median(data))
        mdev = np.median(d)
        s = d / mdev if mdev else np.zeros_like(d)
        no_outliers = data[s < threshold]
        return np.amin(no_outliers), np.amax(no_outliers)

    def non_nan(points):
        # Accept lists as well as arrays; drop nan entries
        points = np.asarray(points, dtype=float)
        return points[~np.isnan(points)]

    mins, maxs = [], []
    if y_points_train is not None:
        y_train = non_nan(y_points_train)
        if y_train.size:
            train_min, train_max = reject_outliers(y_train, 5)
            mins.append(train_min)
            maxs.append(train_max)

    if y_points_val is not None:
        y_val = non_nan(y_points_val)
        if y_val.size:
            mins.append(np.amin(y_val))
            maxs.append(np.amax(y_val))

    if not mins:
        raise ValueError("No non-nan y data given to calculate the y limits.")

    y_min, y_max = min(mins), max(maxs)
    if y_min == y_max:
        # Single value: pad relative to its magnitude. abs() keeps the limits
        # ascending for negative values; 1.0 avoids zero-width limits at 0.
        y_range = 0.1 * abs(y_min) or 1.0
    else:
        y_range = y_max - y_min

    if np.ndim(fraction) == 0:
        pad_below = pad_above = float(fraction)
    else:
        pad_below, pad_above = fraction

    return y_min - pad_below * y_range, y_max + pad_above * y_range
def get_epoch_xticks(x_points):
    """
    Calculate the major x ticks for the train/validation summary plot.

    For spans longer than half an epoch, ticks sit on whole epochs; the
    step grows by one for every 20 epochs in the span. For shorter spans,
    six evenly spaced ticks are used.

    Parameters
    ----------
    x_points : List
        The x coordinates of all points.

    Returns
    -------
    x_ticks_major : numpy.ndarray
        Array containing the ticks.

    Raises
    ------
    ValueError
        If x_points is empty.

    """
    if not len(x_points):
        raise ValueError("x-coordinates are empty!")
    lowest = np.amin(x_points)
    highest = np.amax(x_points)
    first_tick = np.floor(lowest)

    if highest - lowest > 0.5:
        # Longer trainings: whole epochs, thinned out when there are many
        last_tick = np.ceil(highest)
        step = 1 + np.floor((last_tick - first_tick) / 20.0)
        return np.arange(first_tick, last_tick + step, step)

    # Early peeks: six evenly spaced ticks
    last_tick = highest + lowest - first_tick
    return np.linspace(first_tick, last_tick, 6)
def update_summary_plot(orga):
    """
    Plot and save all metrics of the given validation- and train-data
    into a pdf file, each metric in its own plot.

    If metric pairs of a variable and its error are found (e.g. e_loss
    and e_err_loss), they will have the same color and appear back to
    back in the plot. A final page shows the learning rate
    (orga.history.plot_lr).

    Parameters
    ----------
    orga : orcanet.core.Organizer
        Contains all the configurable options in the OrcaNet scripts.

    """
    plt.ioff()
    # NOTE(review): this line was garbled in the source listing; reconstructed
    # as the organizer's "plots" subfolder -- confirm against orga.io.
    pdf_name = os.path.join(
        orga.io.get_subfolder("plots", create=True), "summary_plot.pdf"
    )

    # Extract the names of the metrics and sort errors behind their variable
    all_metrics = sort_metrics(orga.history.get_metrics())

    # Plot them w/ custom color cycle
    colors = [
        "#000000",
        "#332288",
        "#88CCEE",
        "#44AA99",
        "#117733",
        "#999933",
        "#DDCC77",
        "#CC6677",
        "#882255",
        "#AA4499",
        "#661100",
        "#6699CC",
        "#AA4466",
        "#4477AA",
    ]
    color_counter = 0
    with PdfPages(pdf_name) as pdf:
        for metric_no, metric in enumerate(all_metrics):
            # An error metric directly after its variable reuses the
            # variable's color. The metric_no > 0 check keeps index -1 from
            # wrapping around to the last metric; "err_" matches the naming
            # rule used by sort_metrics.
            if metric_no > 0 and all_metrics[metric_no - 1] == metric.replace(
                "err_", ""
            ):
                color_counter -= 1

            orga.history.plot_metric(metric, color=colors[color_counter % len(colors)])
            plt.suptitle(os.path.basename(os.path.abspath(orga.cfg.output_folder)))
            color_counter += 1
            pdf.savefig()
            plt.clf()

        orga.history.plot_lr()
        pdf.savefig()
        plt.close()
def sort_metrics(metric_names):
    """
    Sort a list of metrics, so that errors are right after their variable.

    An error metric is a name containing "err_"; its variable is the same
    name with every "err_" removed (e.g. e_err_loss belongs to e_loss).
    All other metrics, and error metrics whose variable is not in the list,
    keep their relative order.

    Example
    ----------
    >>> sort_metrics( ['e_loss', 'loss', 'e_err_loss', 'dx_err_loss'] )
    ['e_loss', 'e_err_loss', 'loss', 'dx_err_loss']

    Parameters
    ----------
    metric_names : List
        List of metric names.

    Returns
    -------
    sorted_metrics : List
        List of sorted metric names with the same length as the input.

    """
    names = list(metric_names)
    available = set(names)

    def variable_of(name):
        # The variable an error metric belongs to, or None if it has none here
        if "err_" in name:
            variable = name.replace("err_", "")
            if variable in available:
                return variable
        return None

    # Collect the error metrics of each variable, in input order
    errors_of = {}
    for name in names:
        variable = variable_of(name)
        if variable is not None:
            errors_of.setdefault(variable, []).append(name)

    sorted_metrics = []
    for name in names:
        if variable_of(name) is not None:
            continue  # inserted right after its variable instead
        sorted_metrics.append(name)
        sorted_metrics.extend(errors_of.get(name, []))

    assert len(sorted_metrics) == len(names), (
        "Something went wrong with the sorting of "
        "metrics! Given was {}, output was "
        "{}. ".format(metric_names, sorted_metrics)
    )
    return sorted_metrics