Source code for orcanet.history

import numpy as np
import os
from orcanet.utilities.visualization import plot_history
from orcanet.in_out import get_subfolder


class HistoryHandler:
    """For reading and plotting data from summary and train log files."""

    def __init__(self, main_folder):
        self.main_folder = main_folder
        self.summary_filename = "summary.txt"

    @property
    def summary_file(self):
        main_folder = self.main_folder
        if not main_folder[-1] == "/":
            main_folder += "/"
        return main_folder + self.summary_filename

    @property
    def train_log_folder(self):
        return get_subfolder(self.main_folder, "train_log")

    def plot_metric(self, metric_name, **kwargs):
        """
        Plot the training and validation history of a metric.

        This will read out data from the summary file, as well as all
        training log files, and plot them over the epoch.

        Parameters
        ----------
        metric_name : str
            Name of the metric to be plotted over the epoch. This name is
            what was written in the head line of the summary.txt file,
            except without the train_ or val_ prefix.
        kwargs
            Keyword arguments for the plot_history function.

        """
        summary_data = self.get_summary_data()
        full_train_data = self.get_train_data()
        summary_label = "val_" + metric_name

        if metric_name not in full_train_data.dtype.names:
            raise ValueError(
                "Train log metric name {} unknown, must be one of {}".format(
                    metric_name, self.get_metrics()
                )
            )
        if summary_label not in summary_data.dtype.names:
            raise ValueError(
                "Summary metric name {} unknown, must be one of {}".format(
                    summary_label, self.get_metrics()
                )
            )

        if summary_data["Epoch"].shape == (0,):
            # When no lines are present in the summary.txt file.
            val_data = None
        else:
            val_data = [summary_data["Epoch"], summary_data[summary_label]]

        if full_train_data["Batch_float"].shape == (0,):
            # When no lines are present in the training log files.
            raise ValueError(
                "Can not make summary plot: Training log files contain no data!"
            )
        train_data = [full_train_data["Batch_float"], full_train_data[metric_name]]

        # if no validation has been done yet
        if val_data is not None and np.all(np.isnan(val_data[1])):
            val_data = None

        if "y_label" not in kwargs:
            kwargs["y_label"] = metric_name

        plot_history(train_data, val_data, **kwargs)
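
    # Usage sketch (illustrative addition, not original code): assuming an
    # output folder whose summary.txt has "train_loss" and "val_loss"
    # columns, the bare metric name is passed without a prefix:
    #
    #     hh = HistoryHandler("output/my_model")  # hypothetical folder
    #     hh.plot_metric("loss")  # plots train_loss and val_loss over epochs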

    def plot_lr(self, **kwargs):
        """
        Plot the learning rate over the epochs.

        Parameters
        ----------
        kwargs
            Keyword arguments for the plot_history function.

        Returns
        -------
        fig : matplotlib.figure.Figure
            The plot.

        """
        summary_data = self.get_summary_data()
        epoch = summary_data["Epoch"]
        lr = summary_data["LR"]

        # plot learning rate like val data (connected dots)
        val_data = (epoch, lr)
        if "y_label" not in kwargs:
            kwargs["y_label"] = "Learning rate"
        if "legend" not in kwargs:
            kwargs["legend"] = False

        return plot_history(
            train_data=None, val_data=val_data, logy=True, y_lims=None, **kwargs
        )
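
    # A minimal sketch of calling plot_lr (folder name is hypothetical); the
    # learning rate is drawn like validation data, one dot per summary line:
    #
    #     HistoryHandler("output/my_model").plot_lr()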

    def get_metrics(self):
        """
        Get the name of the metrics from the first line in the summary file.

        This will be the actual name of the metric, i.e. "loss" and not
        "train_loss" or "val_loss".

        Returns
        -------
        all_metrics : List
            A list of the metrics.

        """
        summary_data = self.get_summary_data()
        all_metrics = []
        for keyword in summary_data.dtype.names:
            if keyword == "Epoch" or keyword == "LR":
                continue
            if "train_" in keyword:
                keyword = keyword.split("train_")[-1]
            else:
                keyword = keyword.split("val_")[-1]
            if keyword not in all_metrics:
                all_metrics.append(keyword)
        return all_metrics
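
    # Worked example: for summary columns ("Epoch", "LR", "train_loss",
    # "val_loss"), the train_/val_ prefixes are stripped and duplicates
    # merged, so get_metrics returns ["loss"].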

    def get_summary_data(self):
        """
        Read out the summary file in the output folder.

        Returns
        -------
        summary_data : ndarray
            Numpy structured array with the column names as datatypes.
            Its shape is the number of lines with data.

        """
        summary_data = self._load_txt(self.summary_file)
        if summary_data.shape == ():
            # When only one line is present
            summary_data = summary_data.reshape(1,)
        return summary_data
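
    # Sketch of the summary.txt layout this implies (reconstructed from
    # _load_txt below: "|"-separated columns with a header line, "n/a" for
    # missing cells, "--" lines treated as comments); the values are invented:
    #
    #     Epoch | LR    | train_loss | val_loss
    #     0.5   | 0.001 | 1.20       | n/a
    #     1.0   | 0.001 | 0.90       | 0.95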

    def get_best_epoch_info(self, metric="val_loss", mini=True):
        """
        Get the line in the summary file where the given metric is best,
        i.e. either minimal or maximal.

        Parameters
        ----------
        metric : str
            Which metric to look up.
        mini : bool
            If True, look up the minimum. Else the maximum.

        Raises
        ------
        ValueError
            If there is no best line (e.g. no validation has been done).

        """
        summary_data = self.get_summary_data()
        metric_data = summary_data[metric]

        if all(np.isnan(metric_data)):
            raise ValueError("Can not find best epoch in summary.txt")

        if mini:
            opt_loss = np.nanmin(metric_data)
        else:
            opt_loss = np.nanmax(metric_data)

        best_index = np.where(metric_data == opt_loss)[0]
        # if multiple epochs have the same loss, take the first one
        best_index = min(best_index)
        best_line = summary_data[best_index]
        return best_line

    def get_best_epoch_fileno(self, metric="val_loss", mini=True):
        """Get the (epoch, fileno) tuple where the given metric is best
        (smallest by default)."""
        best_line = self.get_best_epoch_info(metric=metric, mini=mini)
        best_epoch_float = best_line["Epoch"]
        epoch, fileno = self._transform_epoch(best_epoch_float)
        return epoch, fileno
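
    # Usage sketch (hypothetical folder): look up the summary line with the
    # lowest validation loss, and the (epoch, fileno) it corresponds to:
    #
    #     hh = HistoryHandler("output/my_model")
    #     best_line = hh.get_best_epoch_info(metric="val_loss")
    #     epoch, fileno = hh.get_best_epoch_fileno(metric="val_loss")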

    def _transform_epoch(self, epoch_float):
        """
        Transform the epoch_float read from a file to an (epoch, fileno)
        tuple (by just counting the number of lines in the given epoch).

        TODO Hacky, (epoch, fileno) should probably be written to the summary.
        """
        summary_data = self.get_summary_data()

        epoch = int(np.floor(epoch_float - 1e-8))
        # all lines in the epoch of epoch_float
        indices = np.where(np.floor(summary_data["Epoch"] - 1e-8) == epoch)[0]
        lines = summary_data[indices]
        fileno = int(np.where(lines["Epoch"] == epoch_float)[0]) + 1
        epoch += 1
        return epoch, fileno
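
    # Worked example, assuming 2 files per epoch, i.e. summary epochs
    # [0.5, 1.0, 1.5, 2.0]: for epoch_float = 1.5, floor(1.5 - 1e-8) = 1
    # selects the lines [1.5, 2.0]; 1.5 is their first entry, so fileno = 1
    # and epoch = 1 + 1 = 2, giving (2, 1), i.e. file 1 of epoch 2.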

    def get_column_names(self):
        """
        Get the column names from the first line of the summary file.

        Returns
        -------
        column_names : tuple
            The names in the same order as they appear in the summary.txt.

        """
        summary_data = self.get_summary_data()
        column_names = summary_data.dtype.names
        return column_names

    def get_train_data(self):
        """
        Read out all training logfiles in the output folder.

        Read out the data from all training log files in the train_log
        folder, which is in the same directory as the summary.txt file.

        Returns
        -------
        full_train_data : numpy.ndarray
            Structured array containing the data from all training log
            files. Its shape is the number of lines with data.

        """
        # list of all files in the train_log folder of this model
        files = os.listdir(self.train_log_folder)

        train_file_data = []
        for file in files:
            if not (file.startswith("log_epoch_") and file.endswith(".txt")):
                continue
            filepath = os.path.join(self.train_log_folder, file)
            if os.path.getsize(filepath) == 0:
                continue
            # file is sth like "log_epoch_1_file_2.txt", extract epoch & fileno:
            epoch, file_no = [int(file.split(".")[0].split("_")[i]) for i in [2, 4]]
            file_data = self._load_txt(filepath)
            train_file_data.append([[epoch, file_no], file_data])
        if len(train_file_data) == 0:
            raise OSError(f"No train files found in {self.train_log_folder}!")

        # sort so that earlier epochs come first
        train_file_data.sort(key=lambda entry: entry[0])
        full_train_data = train_file_data[0][1]
        for _, file_data in train_file_data[1:]:
            full_train_data = np.append(full_train_data, file_data)

        if full_train_data.shape == ():
            # When only one line is present
            full_train_data = full_train_data.reshape(1,)
        return full_train_data
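
    # Filename parsing used above, e.g. for "log_epoch_1_file_2.txt":
    # "log_epoch_1_file_2".split("_") gives ["log", "epoch", "1", "file", "2"],
    # so indices 2 and 4 yield epoch = 1 and file_no = 2.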

    def get_state(self):
        """
        Get the state of a training.

        For every line in the summary logfile, get a dict with the epoch
        as a float, and is_trained and is_validated bools.

        Returns
        -------
        state_dicts : List
            List of dicts.

        """
        summary_data = self.get_summary_data()
        state_dicts = []
        names = summary_data.dtype.names

        for line in summary_data:
            val_losses, train_losses = {}, {}
            for name in names:
                if name.startswith("val_"):
                    val_losses[name] = line[name]
                elif name.startswith("train_"):
                    train_losses[name] = line[name]
                elif name not in ["Epoch", "LR"]:
                    raise NameError(
                        "Invalid summary file: Invalid column name {}: must be "
                        "either Epoch, LR, or start with val_ or train_".format(name)
                    )
            # if there's any not-nan entry, consider it completed
            is_trained = any(~np.isnan(tuple(train_losses.values())))
            is_validated = any(~np.isnan(tuple(val_losses.values())))

            state_dicts.append(
                {
                    "epoch": line["Epoch"],
                    "is_trained": is_trained,
                    "is_validated": is_validated,
                }
            )
        return state_dicts
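
    # Example return value for a summary line that has train entries but
    # only n/a validation entries (values are invented):
    #
    #     [{"epoch": 0.5, "is_trained": True, "is_validated": False}, ...]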

    @staticmethod
    def _load_txt(filepath):
        # TODO suboptimal that n/a gets replaced by np.nan, because this
        #  means that legitimate, not-available cells can not be
        #  distinguished from failed 'nan' metric values produced by training.
        file_data = np.genfromtxt(
            filepath,
            names=True,
            delimiter="|",
            autostrip=True,
            comments="--",
            missing_values="n/a",
            filling_values=np.nan,
        )
        # replace inf with nan so it can be plotted
        for column_name in file_data.dtype.names:
            x = file_data[column_name]
            x[np.isinf(x)] = np.nan
        return file_data
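

# A minimal end-to-end sketch (added for illustration, not part of the
# original module). The folder name is hypothetical; it must contain a
# summary.txt and a train_log/ subfolder with log_epoch_*_file_*.txt files,
# matching what the class above reads.
if __name__ == "__main__":
    hh = HistoryHandler("output/my_model")
    print("Metrics:", hh.get_metrics())
    print("Columns:", hh.get_column_names())
    print("Best (epoch, fileno):", hh.get_best_epoch_fileno(metric="val_loss"))
    hh.plot_metric("loss")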