Coverage for orcanet/history.py: 93%

134 statements  

« prev     ^ index     » next       coverage.py v7.2.7, created at 2024-03-28 14:22 +0000

1import numpy as np 

2import os 

3from orcanet.utilities.visualization import plot_history 

4from orcanet.in_out import get_subfolder 

5 

6 

class HistoryHandler:
    """
    For reading and plotting data from summary and train log files.

    Parameters
    ----------
    main_folder : str
        Folder containing the summary file. The train log files are
        looked up in the folder given by
        ``get_subfolder(main_folder, "train_log")``.

    Attributes
    ----------
    main_folder : str
        The folder given at construction.
    summary_filename : str
        Name of the summary file inside main_folder ("summary.txt").

    """

    def __init__(self, main_folder):
        self.main_folder = main_folder
        self.summary_filename = "summary.txt"

    @property
    def summary_file(self):
        """Path of the summary file (main_folder + "/" + summary_filename)."""
        main_folder = self.main_folder
        # An empty main_folder is treated as the current directory: adding
        # a slash would point to the file system root instead, and indexing
        # into the empty string would raise an IndexError.
        if main_folder and not main_folder.endswith("/"):
            main_folder += "/"
        return main_folder + self.summary_filename

    @property
    def train_log_folder(self):
        """The "train_log" subfolder of main_folder, as given by get_subfolder."""
        return get_subfolder(self.main_folder, "train_log")

    def plot_metric(self, metric_name, **kwargs):
        """
        Plot the training and validation history of a metric.

        This will read out data from the summary file, as well as
        all training log files, and plot them over the epoch.

        Parameters
        ----------
        metric_name : str
            Name of the metric to be plotted over the epoch. This name is what
            was written in the head line of the summary.txt file, except without
            the train_ or val_ prefix.
        kwargs
            Keyword arguments for the plot_history function. If no y_label
            is given, the metric name is used.

        Raises
        ------
        ValueError
            If the metric is not a column of the train log files, if
            "val_" + metric_name is not a column of the summary file, or
            if the train log files contain no data lines.

        """
        summary_data = self.get_summary_data()
        full_train_data = self.get_train_data()
        summary_label = "val_" + metric_name

        if metric_name not in full_train_data.dtype.names:
            raise ValueError(
                "Train log metric name {} unknown, must be one of {}".format(
                    metric_name, self.get_metrics()
                )
            )
        if summary_label not in summary_data.dtype.names:
            raise ValueError(
                "Summary metric name {} unknown, must be one of {}".format(
                    summary_label, self.get_metrics()
                )
            )

        if full_train_data["Batch_float"].shape == (0,):
            # When no lines are present
            raise ValueError(
                "Can not make summary plot: Training log files contain no data!"
            )
        train_data = [full_train_data["Batch_float"], full_train_data[metric_name]]

        val_values = summary_data[summary_label]
        # No validation curve if the summary file has no lines, or if no
        # validation has been done yet (all entries nan). Both cases are
        # checked on the column itself, so that np.isnan is never called
        # on None.
        if val_values.size == 0 or np.all(np.isnan(val_values)):
            val_data = None
        else:
            val_data = [summary_data["Epoch"], val_values]

        kwargs.setdefault("y_label", metric_name)

        plot_history(train_data, val_data, **kwargs)

    def plot_lr(self, **kwargs):
        """
        Plot the learning rate over the epochs.

        The "LR" column of the summary file is handed to plot_history as
        validation data, with a logarithmic y axis.

        Parameters
        ----------
        kwargs
            Keyword arguments for the plot_history function. Defaults used
            if not given: y_label="Learning rate", legend=False.
            logy and y_lims are set by this method and must not be given.

        Returns
        -------
        None
            This method does not return anything; the plotting is done
            by plot_history.

        """
        summary_data = self.get_summary_data()

        # plot learning rate like val data (connected dots)
        val_data = (summary_data["Epoch"], summary_data["LR"])

        kwargs.setdefault("y_label", "Learning rate")
        kwargs.setdefault("legend", False)

        plot_history(
            train_data=None, val_data=val_data, logy=True, y_lims=None, **kwargs
        )

    def get_metrics(self):
        """
        Get the name of the metrics from the first line in the summary file.

        This will be the actual name of the metric, i.e. "loss" and not
        "train_loss" or "val_loss". The columns "Epoch" and "LR" are skipped.

        Returns
        -------
        all_metrics : List
            A list of the metrics, without duplicates, in the order of
            their first appearance in the summary file.

        """
        summary_data = self.get_summary_data()
        all_metrics = []
        for keyword in summary_data.dtype.names:
            if keyword in ("Epoch", "LR"):
                continue
            # Only strip a leading prefix: a substring test would mangle
            # metrics which contain "train_" or "val_" in their own name
            # (e.g. "val_pretrain_acc" must give "pretrain_acc").
            for prefix in ("train_", "val_"):
                if keyword.startswith(prefix):
                    keyword = keyword[len(prefix):]
                    break
            if keyword not in all_metrics:
                all_metrics.append(keyword)
        return all_metrics

    def get_summary_data(self):
        """
        Read out the summary file in the output folder.

        Returns
        -------
        summary_data : ndarray
            Numpy structured array with the column names as datatypes.
            Its shape is the number of lines with data.

        """
        summary_data = self._load_txt(self.summary_file)
        # When only one line is present, genfromtxt gives a 0-d array.
        return np.atleast_1d(summary_data)

    def get_best_epoch_info(self, metric="val_loss", mini=True):
        """
        Get the line in the summary file where the given metric is best, i.e.
        either minimal or maximal.

        Parameters
        ----------
        metric : str
            Which metric to look up (full column name, including the
            train_ or val_ prefix).
        mini : bool
            If true, look up the minimum. Else the maximum.

        Returns
        -------
        best_line : numpy.void
            The line of the summary data with the best value. If multiple
            lines share the best value, the first one is returned.

        Raises
        ------
        ValueError
            If there is no best line (e.g. no validation has been done).

        """
        summary_data = self.get_summary_data()
        metric_data = summary_data[metric]

        # Also true for a summary file without any lines.
        if np.all(np.isnan(metric_data)):
            raise ValueError("Can not find best epoch in summary.txt")

        # nanargmin/nanargmax ignore nan entries and give the first
        # index in case multiple epochs have the same value.
        if mini:
            best_index = np.nanargmin(metric_data)
        else:
            best_index = np.nanargmax(metric_data)
        return summary_data[best_index]

    def get_best_epoch_fileno(self, metric="val_loss", mini=True):
        """
        Get the epoch/fileno tuple where the given metric is best.

        Parameters
        ----------
        metric : str
            Which metric to look up.
        mini : bool
            If true, look up the minimum. Else the maximum.

        Returns
        -------
        epoch : int
            The epoch of the best line (counting from 1).
        fileno : int
            The file number of the best line in this epoch (counting from 1).

        """
        best_line = self.get_best_epoch_info(metric=metric, mini=mini)
        return self._transform_epoch(best_line["Epoch"])

    def _transform_epoch(self, epoch_float):
        """
        Transform the epoch_float read from a file to a tuple.

        (By just counting the number of lines in the given epoch).
        TODO Hacky, (epoch, fileno) should probably be written to the summary.

        Parameters
        ----------
        epoch_float : float
            Value from the "Epoch" column of the summary file.

        Returns
        -------
        epoch : int
            The epoch, counting from 1.
        fileno : int
            Position of the line among all lines of this epoch,
            counting from 1.

        Raises
        ------
        ValueError
            If epoch_float does not appear in the summary file.

        """
        summary_data = self.get_summary_data()

        # The small offset makes a finished epoch (e.g. 1.0) count to the
        # epoch it finishes, not to the next one.
        epoch = int(np.floor(epoch_float - 1e-8))
        # all lines in the epoch of epoch_float
        in_epoch = np.floor(summary_data["Epoch"] - 1e-8) == epoch
        lines = summary_data[in_epoch]

        matches = np.flatnonzero(lines["Epoch"] == epoch_float)
        if matches.size == 0:
            raise ValueError(
                "Epoch {} not found in the summary file".format(epoch_float)
            )
        # Index explicitly instead of int(array): Converting an array with
        # ndim > 0 to a scalar is deprecated in numpy.
        fileno = int(matches[0]) + 1
        return epoch + 1, fileno

    def get_column_names(self):
        """
        Get the str in the first line in each column.

        Returns
        -------
        column_names : tuple
            The names in the same order as they appear in the summary.txt.

        """
        return self.get_summary_data().dtype.names

    def get_train_data(self):
        """
        Read out all training logfiles in the output folder.

        Reads every non-empty file named "log_epoch_*.txt" in the
        train_log folder, and appends their content ordered by the epoch
        and file number taken from the file names.

        Returns
        -------
        full_train_data : numpy.ndarray
            Structured array containing the data from all train log files.
            Its shape is the number of lines with data.

        Raises
        ------
        OSError
            If the train_log folder contains no non-empty train log file.

        """
        train_log_folder = self.train_log_folder
        train_file_data = []
        # go through all files in the train_log folder of this model
        for filename in os.listdir(train_log_folder):
            if not (filename.startswith("log_epoch_") and filename.endswith(".txt")):
                continue
            filepath = os.path.join(train_log_folder, filename)
            if os.path.getsize(filepath) == 0:
                continue

            # file is sth like "log_epoch_1_file_2.txt", extract epoch & fileno:
            name_parts = filename.split(".")[0].split("_")
            epoch, file_no = int(name_parts[2]), int(name_parts[4])
            train_file_data.append(((epoch, file_no), self._load_txt(filepath)))

        if len(train_file_data) == 0:
            raise OSError(f"No train files found in {train_log_folder}!")

        # Sort so that earlier epochs come first. Only the (epoch, file_no)
        # key is compared, so the arrays themselves are never compared.
        train_file_data.sort(key=lambda entry: entry[0])
        # Files with only one line give 0-d arrays, so make everything
        # 1-d, then join all files in one go.
        return np.concatenate(
            [np.atleast_1d(file_data) for _, file_data in train_file_data]
        )

    def get_state(self):
        """
        Get the state of a training.

        For every line in the summary logfile, get a dict with the epoch
        as a float, and is_trained and is_validated bools.

        Returns
        -------
        state_dicts : List
            List of dicts, one per line of the summary file, with the
            keys "epoch", "is_trained" and "is_validated".

        Raises
        ------
        NameError
            If the summary file has data lines and a column which is neither
            Epoch nor LR, and does not start with val_ or train_.

        """
        summary_data = self.get_summary_data()
        names = summary_data.dtype.names

        # The columns are the same for every line, so sort them only once.
        val_names = [name for name in names if name.startswith("val_")]
        train_names = [name for name in names if name.startswith("train_")]
        invalid_names = [
            name
            for name in names
            if not name.startswith(("val_", "train_")) and name not in ("Epoch", "LR")
        ]
        # Only complain if there are lines with data (an empty summary
        # simply gives an empty list).
        if invalid_names and len(summary_data) > 0:
            raise NameError(
                "Invalid summary file: Invalid column name {}: must be "
                "either Epoch, LR, or start with val_ or train_".format(
                    invalid_names[0]
                )
            )

        state_dicts = []
        for line in summary_data:
            train_losses = [line[name] for name in train_names]
            val_losses = [line[name] for name in val_names]
            # if theres any not-nan entry, consider it completed
            state_dicts.append(
                {
                    "epoch": line["Epoch"],
                    "is_trained": bool(np.any(~np.isnan(train_losses))),
                    "is_validated": bool(np.any(~np.isnan(val_losses))),
                }
            )

        return state_dicts

    @staticmethod
    def _load_txt(filepath):
        """
        Load a summary or train log file as a structured array.

        The first line gives the column names, columns are separated
        by "|", and everything after a "--" in a line is ignored.
        Cells containing "n/a" as well as infinite values are read as nan.

        Parameters
        ----------
        filepath : str
            Path of the file to read.

        Returns
        -------
        file_data : numpy.ndarray
            Structured array with one field per column. Its shape is ()
            if the file contains only one line with data.

        """
        # TODO suboptimal that n/a gets replaced by np.nan, because this
        #  means that legitimate, not available cells can not be distinguished
        #  from failed 'nan' metric values produced by training.
        file_data = np.genfromtxt(
            filepath,
            names=True,
            delimiter="|",
            autostrip=True,
            comments="--",
            missing_values="n/a",
            filling_values=np.nan,
        )
        # replace inf with nan so it can be plotted
        for column_name in file_data.dtype.names:
            # a view on the column, so this changes file_data in place
            column = file_data[column_name]
            column[np.isinf(column)] = np.nan
        return file_data