# Source code for orcanet.utilities.summarize_training

import matplotlib.pyplot as plt
import argparse
import warnings
from orcanet.history import HistoryHandler
from orcanet.utilities.visualization import TrainValPlotter


class Summarizer:
    """
    Summarize one or more trainings by giving their orcanet folder(s).

    - Plot the training and validation curves in a single plot and show them
    - Print info about the best and worst epochs

    Parameters
    ----------
    folders : str or List, optional
        Path to an orcanet folder, or to multiple folders as a list.
        [default: CWD].
    metric : str or List, optional
        The metric to plot [default: 'loss'].
        If it's a list: Same length as folders. Plot a different metric for
        each folder.
    smooth : int, optional
        Apply gaussian blur to the train curve with given sigma.
    labels : List, optional
        Labels for each folder. [default: the folder paths].
    noplot : bool
        Don't plot the train/val curves [default: False].
    width : float
        Scaling of the width of the curves and the marker size [default: 1].
    verbose : bool
        Print progress and the best/worst epochs [default: True].

    """

    def __init__(
        self,
        folders=None,
        metric="loss",
        smooth=None,
        labels=None,
        noplot=False,
        width=1.0,
        verbose=True,
    ):
        # None or an empty list (e.g. no positional CLI arguments) -> CWD.
        if not folders:
            folders = ["./"]
        elif isinstance(folders, str):
            folders = [folders]
        self.folders = folders

        if isinstance(metric, str):
            metric = [metric]
        if len(metric) == 1:
            # A single metric is shared by all folders.
            self.metrics = metric * len(self.folders)
            self._unique_metrics = False
        else:
            if len(metric) != len(self.folders):
                raise ValueError("Need to give exactly one metric per folder!")
            self.metrics = metric
            self._unique_metrics = True

        if not labels:
            # None or an empty list (e.g. "--labels" without values):
            # label each curve with its folder path.
            self.labels = self.folders
        elif len(labels) < len(self.folders):
            # Otherwise every folder without a label would later fail with
            # an IndexError and be silently skipped.
            raise ValueError("Need to give one label per folder!")
        else:
            self.labels = labels

        self.verbose = verbose
        self.smooth = smooth
        self.noplot = noplot
        self.width = width
        # TrainValPlotter instance, created by summarize() unless noplot.
        self._tvp = None

    def summarize(self, show=True):
        """
        Read all folders, print the best/worst epochs and plot the curves.

        Folders that can not be summarized are skipped (with a warning
        printed if verbose).

        Parameters
        ----------
        show : bool
            Call ``plt.show()`` after plotting. Has no effect if noplot.

        """
        if not self.noplot:
            self._tvp = TrainValPlotter()

        min_stats, max_stats = [], []
        if self.verbose:
            print("Reading stats of {} trainings...".format(len(self.folders)))
        for folder_no, folder in enumerate(self.folders):
            try:
                min_stat, max_stat = self._summarize_folder(folder_no)
            except Exception as e:
                # Best effort: one unreadable folder must not abort the rest.
                if self.verbose:
                    print(f"Warning: Can not summarize {folder}, skipping... ({e})")
                continue
            if min_stat is not None:
                min_stats.append(min_stat)
            if max_stat is not None:
                max_stats.append(max_stat)

        if self._unique_metrics:
            column_title, y_label = ("combined metrics",) * 2
        else:
            column_title, y_label = self._full_metrics[0], self._metric_names[0]

        if self.verbose:
            self._print_stats("Minimum", sorted(min_stats), column_title)
            self._print_stats(
                "Maximum", sorted(max_stats, reverse=True), column_title
            )

        if not self.noplot:
            self._tvp.apply_layout(
                x_label="Epoch",
                y_label=y_label,
                grid=True,
                legend=True,
            )
            if show:
                plt.show()

    @staticmethod
    def _print_stats(title, stats, column_title):
        """Print a ranked table of [value, label, epoch] stats; skip if empty."""
        if not stats:
            return
        print("\n{}\n-------".format(title))
        print("{} \t{}\t{}\t{}".format(" ", "Epoch", column_title, "name"))
        for rank, (value, label, epoch) in enumerate(stats, 1):
            print("{} | \t{}\t{}\t{}".format(rank, epoch, value, label))

    @property
    def _metric_names(self):
        """Metrics with any train_/val_ prefix removed, e.g. [loss, ...]"""
        metric_names = []
        for metric in self.metrics:
            for prefix in ("train_", "val_"):
                if metric.startswith(prefix):
                    metric = metric[len(prefix):]
                    break
            metric_names.append(metric)
        return metric_names

    @property
    def _full_metrics(self):
        """Metrics with a train_/val_ prefix (val_ if none), e.g. [val_loss, ...]"""
        return [
            metric if metric.startswith(("train_", "val_")) else "val_" + metric
            for metric in self.metrics
        ]

    def _summarize_folder(self, folder_no):
        """
        Read the history of one folder and (unless noplot) plot its curves.

        Parameters
        ----------
        folder_no : int
            Index into self.folders / self.labels / self.metrics.

        Returns
        -------
        min_stat, max_stat : list or None
            [metric value, label, epoch] of the epoch where the summary
            metric is lowest / highest. None if reading the best epoch
            info from the summary file failed.

        """
        label = self.labels[folder_no]
        folder = self.folders[folder_no]
        hist = HistoryHandler(folder)
        val_data, min_stat, max_stat = None, None, None
        smry_met_name = self._full_metrics[folder_no]

        # read data from summary file
        try:
            max_line = hist.get_best_epoch_info(metric=smry_met_name, mini=False)
            min_line = hist.get_best_epoch_info(metric=smry_met_name, mini=True)
            min_stat = [min_line[smry_met_name], label, min_line["Epoch"]]
            max_stat = [max_line[smry_met_name], label, max_line["Epoch"]]
            summary_data = hist.get_summary_data()
            val_data = [summary_data["Epoch"], summary_data[smry_met_name]]
        except OSError:
            if self.verbose:
                print(f"Warning: No summary file found for {folder}")
        except ValueError as e:
            if self.verbose:
                print(f"Error reading summary file {hist.summary_file} ({e})")

        # read data from training files
        full_train_data = hist.get_train_data()
        train_data = [
            full_train_data["Batch_float"],
            full_train_data[self._metric_names[folder_no]],
        ]

        if not self.noplot:
            if len(self.labels) == 1:
                train_label, val_label = "training", "validation"
            elif val_data is None:
                # Only the train curve exists, so it carries the label.
                train_label, val_label = label, None
            else:
                train_label, val_label = None, label
            self._tvp.plot_curves(
                train_data=train_data,
                val_data=val_data,
                train_label=train_label,
                val_label=val_label,
                smooth_sigma=self.smooth,
                tlw=0.5 * self.width,
                vlw=0.5 * self.width,
                vms=3 * self.width ** 0.5,
            )
        return min_stat, max_stat

    def summarize_dirs(self):
        """
        Get the best and worst epochs of all given folders as a dict.

        Folders whose summary can not be read (OSError) are skipped with
        a warning.

        Returns
        -------
        minima : dict
            Keys : Name of folder.
            Values : [Epoch, metric] of where the metric is lowest.
        maxima : dict
            As above, but for where the metric is highest.

        """
        minima, maxima = {}, {}
        for folder, smry_met_name in zip(self.folders, self._full_metrics):
            hist = HistoryHandler(folder)
            try:
                max_line = hist.get_best_epoch_info(metric=smry_met_name, mini=False)
                min_line = hist.get_best_epoch_info(metric=smry_met_name, mini=True)
            except OSError as e:
                warnings.warn(str(e))
                continue
            minima[folder] = [min_line["Epoch"], min_line[smry_met_name]]
            maxima[folder] = [max_line["Epoch"], max_line[smry_met_name]]
        return minima, maxima
def summarize(**kwargs):
    """
    Build a Summarizer from keyword arguments and run its summary.

    Arguments whose value is None are dropped, so the Summarizer's own
    defaults apply for them (useful for parsed command line arguments).
    """
    given = {name: value for name, value in kwargs.items() if value is not None}
    Summarizer(**given).summarize()
def get_parser():
    """Return the command line parser; its help text is the Summarizer docstring."""
    parser = argparse.ArgumentParser(
        description=str(Summarizer.__doc__),
        formatter_class=argparse.RawTextHelpFormatter,
    )
    parser.add_argument("folders", type=str, nargs="*")
    # Optional flags, registered in this order (affects --help layout).
    option_specs = (
        ("--metric", {"type": str, "nargs": "*"}),
        ("--smooth", {"nargs": "?", "type": int}),
        ("--width", {"nargs": "?", "type": float}),
        ("--labels", {"nargs": "*", "type": str}),
        ("--noplot", {"action": "store_true"}),
    )
    for flag, spec in option_specs:
        parser.add_argument(flag, **spec)
    return parser
def main():
    """Command line entry point: parse arguments and summarize the trainings."""
    cli_args = get_parser().parse_args()
    summarize(**vars(cli_args))
# Run the command line interface only when executed as a script.
if __name__ == "__main__":
    main()