Coverage for orcanet/utilities/summarize_training.py: 0%

143 statements  

« prev     ^ index     » next       coverage.py v7.2.7, created at 2024-03-28 14:22 +0000

1import matplotlib.pyplot as plt 

2import argparse 

3import warnings 

4from orcanet.history import HistoryHandler 

5from orcanet.utilities.visualization import TrainValPlotter 

6 

7 

class Summarizer:
    """
    Summarize one or more trainings by giving their orcanet folder(s).

    - Plot the training and validation curves in a single plot and show them
    - Print info about the best and worst epochs

    Parameters
    ----------
    folders : str or List, optional
        Path to an orcanet folder, or to multiple folders as a list.
        [default: CWD].
    metric : str or List, optional
        The metric to plot [default: 'loss'].
        If it's a list: Same length as folders. Plot a different metric for
        each folder.
    smooth : int, optional
        Apply Gaussian blur to the train curve with given sigma.
    labels : str or List, optional
        Labels for each folder (exactly one per folder).
        [default: the folder paths].
    noplot : bool
        Don't plot the train/val curves [default: False].
    width : float
        Scaling of the width of the curves and the marker size [default: 1].
    verbose : bool
        Print progress, warnings and the best/worst epochs [default: True].

    """

    def __init__(
        self,
        folders,
        metric="loss",
        smooth=None,
        labels=None,
        noplot=False,
        width=1.0,
        verbose=True,
    ):
        if not folders:
            folders = ["./"]
        elif isinstance(folders, str):
            folders = [folders]
        # Copy, so that later changes to the caller's list do not leak in.
        self.folders = list(folders)

        if metric is None:
            metric = "loss"
        if isinstance(metric, str):
            metric = [metric]
        if len(metric) == 1:
            # One metric shared by all folders.
            self.metrics = list(metric) * len(self.folders)
            self._unique_metrics = False
        else:
            if len(metric) != len(self.folders):
                raise ValueError("Need to give exactly one metric per folder!")
            self.metrics = list(metric)
            self._unique_metrics = True

        if labels is None:
            self.labels = list(self.folders)
        else:
            if isinstance(labels, str):
                # A bare string would otherwise be indexed char by char.
                labels = [labels]
            if len(labels) != len(self.folders):
                # Without this check, folders lacking a label were silently
                # skipped later on with an IndexError.
                raise ValueError("Need to give exactly one label per folder!")
            self.labels = list(labels)

        self.verbose = verbose
        self.smooth = smooth
        self.noplot = noplot
        self.width = width
        self._tvp = None

    def summarize(self, show=True):
        """
        Read all folders, print their best/worst epochs and plot the curves.

        Folders that can not be read are skipped (with a message if verbose).

        Parameters
        ----------
        show : bool
            Call ``plt.show()`` at the end. Only used if plotting is enabled.

        """
        if not self.noplot:
            self._tvp = TrainValPlotter()

        min_stats, max_stats = [], []
        if self.verbose:
            print("Reading stats of {} trainings...".format(len(self.folders)))
        for folder_no, folder in enumerate(self.folders):
            try:
                min_stat, max_stat = self._summarize_folder(folder_no)
            except Exception as e:
                # Best effort: one broken folder must not abort the summary.
                if self.verbose:
                    print(
                        f"Warning: Can not summarize {folder}"
                        f", skipping... ({e})"
                    )
                continue
            if min_stat is not None:
                min_stats.append(min_stat)
            if max_stat is not None:
                max_stats.append(max_stat)

        if self._unique_metrics:
            column_title, y_label = ("combined metrics",) * 2
        else:
            column_title, y_label = self._full_metrics[0], self._metric_names[0]

        if self.verbose:
            self._print_stats("Minimum", sorted(min_stats), column_title)
            self._print_stats(
                "Maximum", sorted(max_stats, reverse=True), column_title
            )

        if not self.noplot:
            self._tvp.apply_layout(
                x_label="Epoch",
                y_label=y_label,
                grid=True,
                legend=True,
            )
            if show:
                plt.show()

    @staticmethod
    def _print_stats(title, stats, column_title):
        """
        Print a ranked table of ``[value, label, epoch]`` entries.

        Nothing is printed if ``stats`` is empty.

        """
        if not stats:
            return
        print("\n{}\n-------".format(title))
        print("{} \t{}\t{}\t{}".format(" ", "Epoch", column_title, "name"))
        for i, stat in enumerate(stats, 1):
            print("{} | \t{}\t{}\t{}".format(i, stat[2], stat[0], stat[1]))

    @property
    def _metric_names(self):
        """Metrics with a 'train_' or 'val_' prefix removed, e.g. [loss, ...]."""
        names = []
        for metric in self.metrics:
            for prefix in ("train_", "val_"):
                if metric.startswith(prefix):
                    metric = metric[len(prefix):]
                    break
            names.append(metric)
        return names

    @property
    def _full_metrics(self):
        """
        Metrics as used for the summary file, e.g. [val_loss, ...].

        Metrics without an explicit 'train_'/'val_' prefix get 'val_' prepended.

        """
        return [
            metric if metric.startswith(("train_", "val_")) else "val_" + metric
            for metric in self.metrics
        ]

    def _summarize_folder(self, folder_no):
        """
        Read the history of one folder and add its curves to the plot.

        Parameters
        ----------
        folder_no : int
            Index into ``self.folders``, ``self.labels`` and ``self.metrics``.

        Returns
        -------
        min_stat, max_stat : list or None
            ``[metric value, label, epoch]`` of the epoch where the summary
            metric is lowest / highest, or None if the summary file could
            not be read.

        """
        label = self.labels[folder_no]
        folder = self.folders[folder_no]

        hist = HistoryHandler(folder)
        smry_met_name = self._full_metrics[folder_no]
        val_data, min_stat, max_stat = None, None, None
        # read data from summary file
        try:
            max_line = hist.get_best_epoch_info(metric=smry_met_name, mini=False)
            min_line = hist.get_best_epoch_info(metric=smry_met_name, mini=True)
            min_stat = [min_line[smry_met_name], label, min_line["Epoch"]]
            max_stat = [max_line[smry_met_name], label, max_line["Epoch"]]

            summary_data = hist.get_summary_data()
            val_data = [summary_data["Epoch"], summary_data[smry_met_name]]
        except OSError:
            if self.verbose:
                print(f"Warning: No summary file found for {folder}")

        except ValueError as e:
            if self.verbose:
                print(f"Error reading summary file {hist.summary_file} ({e})")

        # read data from training files
        full_train_data = hist.get_train_data()
        train_data = [
            full_train_data["Batch_float"],
            full_train_data[self._metric_names[folder_no]],
        ]

        if not self.noplot:
            if len(self.labels) == 1:
                train_label, val_label = "training", "validation"
            elif val_data is None:
                # No validation curve: put the legend entry on the train curve.
                train_label, val_label = label, None
            else:
                train_label, val_label = None, label

            self._tvp.plot_curves(
                train_data=train_data,
                val_data=val_data,
                train_label=train_label,
                val_label=val_label,
                smooth_sigma=self.smooth,
                tlw=0.5 * self.width,
                vlw=0.5 * self.width,
                vms=3 * self.width ** 0.5,
            )
        return min_stat, max_stat

    def summarize_dirs(self):
        """
        Get the best and worst epochs of all given folders as a dict.

        Folders whose summary can not be read (OSError) are skipped with a
        warning.

        Returns
        -------
        minima : dict
            Keys : Name of folder.
            Values : [Epoch, metric] of where the metric is lowest.
        maxima : dict
            As above, but for where the metric is highest.

        """
        minima, maxima = {}, {}
        for folder_no, folder in enumerate(self.folders):
            hist = HistoryHandler(folder)
            smry_met_name = self._full_metrics[folder_no]
            try:
                max_line = hist.get_best_epoch_info(metric=smry_met_name, mini=False)
                min_line = hist.get_best_epoch_info(metric=smry_met_name, mini=True)
            except OSError as e:
                warnings.warn(str(e))
                continue

            minima[folder] = [min_line["Epoch"], min_line[smry_met_name]]
            maxima[folder] = [max_line["Epoch"], max_line[smry_met_name]]
        return minima, maxima

230 

231 

def summarize(**kwargs):
    """
    Create a Summarizer from the given keyword arguments and run it.

    Keyword arguments set to None are left out entirely, so the
    Summarizer's own defaults apply to them (e.g. unset CLI options).
    """
    given = {name: value for name, value in kwargs.items() if value is not None}
    Summarizer(**given).summarize()

237 

238 

def get_parser():
    """Return the argparse parser for the training summary command line tool."""
    parser = argparse.ArgumentParser(
        description=str(Summarizer.__doc__),
        formatter_class=argparse.RawTextHelpFormatter,
    )
    # (name, add_argument options), in the order they are registered
    argument_specs = (
        ("folders", dict(type=str, nargs="*")),
        ("--metric", dict(type=str, nargs="*")),
        ("--smooth", dict(nargs="?", type=int)),
        ("--width", dict(nargs="?", type=float)),
        ("--labels", dict(nargs="*", type=str)),
        ("--noplot", dict(action="store_true")),
    )
    for name, options in argument_specs:
        parser.add_argument(name, **options)
    return parser

251 

252 

def main():
    """Command line entry point: parse the arguments and run the summary."""
    summarize(**vars(get_parser().parse_args()))


if __name__ == "__main__":
    main()