Coverage for orcanet/utilities/visualization.py: 91%

156 statements  

« prev     ^ index     » next       coverage.py v7.2.7, created at 2024-03-28 14:22 +0000

1# -*- coding: utf-8 -*- 

2""" 

3Visualization tools used without Keras. 

4Makes performance graphs for training and validating. 

5""" 

6import os 

7import numpy as np 

8 

9import matplotlib.pyplot as plt 

10from matplotlib.backends.backend_pdf import PdfPages 

11 

12 

class TrainValPlotter:
    """
    Class for plotting train/val curves.

    Instructions
    ------------
    1. Use tvp.plot_curves(train, val) once or more on pairs of
       train/val data.
    2. When all lines are plotted, use tvp.apply_layout() once for proper
       scaling, ylims, etc.

    """

    def __init__(self):
        # White space added below and above points, as a fraction of
        # the total y range
        self.y_lim_padding = [0.10, 0.25]
        # Store all plotted points for setting x/y lims in apply_layout
        self._xpoints_train = np.array([])
        self._xpoints_val = np.array([])
        self._ypoints_train = np.array([])
        self._ypoints_val = np.array([])

    def plot_curves(
        self,
        train_data,
        val_data=None,
        train_label="training",
        val_label="validation",
        color=None,
        smooth_sigma=None,
        tlw=0.5,
        vlw=0.5,
        vms=3,
    ):
        """
        Plot a training and optionally a validation line.

        The data can contain nan's.

        Parameters
        ----------
        train_data : List
            X data [0] and y data [1] of the train curve. Will be plotted as
            connected dots.
        val_data : List, optional
            Optional X data [0] and y data [1] of the validation curve.
            Will be plotted as a faint solid line of the same color as train.
        train_label : str, optional
            Label for the train line in the legend.
        val_label : str, optional
            Label for the validation line in the legend.
        color : str, optional
            Color used for the train/val line.
        smooth_sigma : int, optional
            Apply gaussian blur to the train curve with given sigma.
        tlw : float
            Linewidth of train curve.
        vlw : float
            Linewidth of val curve.
        vms : float
            Markersize of the val curve.

        Raises
        ------
        ValueError
            If neither train nor val data is given.

        """
        if train_data is None and val_data is None:
            raise ValueError("Can not plot when no train and val data is given.")

        if train_data is not None:
            epoch, y_data = train_data
            if smooth_sigma is not None:
                y_data = gaussian_smooth(y_data, smooth_sigma)

            self._xpoints_train = np.concatenate((self._xpoints_train, epoch))
            self._ypoints_train = np.concatenate((self._ypoints_train, y_data))

            train_plot = plt.plot(
                epoch,
                y_data,
                color=color,
                ls="-",
                zorder=3,
                label=train_label,
                lw=tlw,
                alpha=0.5,
            )
            # Remember the (possibly auto-assigned) color so the val
            # curve can reuse it
            train_color = train_plot[0].get_color()
        else:
            train_color = color

        if val_data is not None:
            self._xpoints_val = np.concatenate((self._xpoints_val, val_data[0]))
            self._ypoints_val = np.concatenate((self._ypoints_val, val_data[1]))

            val_data_clean = skip_nans(val_data)
            # val plot always has the same color as the train plot
            plt.plot(
                val_data_clean[0],
                val_data_clean[1],
                color=train_color,
                marker="o",
                zorder=3,
                lw=vlw,
                markersize=vms,
                label=val_label,
            )

    def apply_layout(
        self,
        title=None,
        x_label="Epoch",
        y_label=None,
        grid=True,
        legend=True,
        x_lims=None,
        y_lims="auto",
        x_ticks="auto",
        logy=False,
    ):
        """
        Apply given layout.
        Can calculate good y_lims and x_ticks automatically.

        Parameters
        ----------
        title : str
            Title of the plot.
        x_label : str
            X label of the plot.
        y_label : str
            Y label of the plot.
        grid : bool
            If true, show a grid.
        legend : bool
            If true, show a legend.
        x_lims : List
            X limits of the data.
        y_lims : List or str
            Y limits of the data. "auto" for auto-calculation.
        x_ticks : List or str
            Positions of the major x ticks. "auto" for auto-calculation.
        logy : bool
            If true, make y axis log.

        """
        if logy:
            plt.yscale("log")

        if x_ticks is not None:
            # isinstance guard: `x_ticks == "auto"` on an ndarray would
            # compare elementwise and raise on truth testing
            if isinstance(x_ticks, str) and x_ticks == "auto":
                all_x_points = np.concatenate(
                    (self._xpoints_train, self._xpoints_val)
                )
                x_ticks = get_epoch_xticks(all_x_points)
            plt.xticks(x_ticks)

        if x_lims is not None:
            plt.xlim(x_lims)

        if y_lims is not None:
            if isinstance(y_lims, str) and y_lims == "auto":
                y_lims = get_ylims(
                    self._ypoints_train,
                    self._ypoints_val,
                    fraction=self.y_lim_padding,
                )
            plt.ylim(y_lims)

        if legend:
            plt.legend(loc="upper right")

        plt.xlabel(x_label)
        plt.ylabel(y_label)

        if title is not None:
            title = plt.title(title)
            title.set_position([0.5, 1.04])

        if grid:
            plt.grid(True, zorder=0, linestyle="dotted")

192 

193 

def gaussian_smooth(y, sigma, truncate=4):
    """
    Smooth a 1d ndarray with a gaussian filter.

    The kernel is explicitly normalized to sum to 1. Sampling the
    gaussian density alone only sums approximately to 1 (off by ~1.4%
    for sigma=0.5), which scaled all smoothed values; with the
    normalization a constant input stays exactly constant.

    Parameters
    ----------
    y : ndarray
        1d input data. Edges are replicated ("edge" padding) so the
        output has the same length as the input.
    sigma : float
        Standard deviation of the gaussian kernel, in samples.
    truncate : float
        Truncate the kernel at this many standard deviations.

    Returns
    -------
    ndarray
        The smoothed data, same length as y.

    """
    # kernel_width = 2 * sigma * truncate + 1
    kernel_x = np.arange(-truncate * sigma, truncate * sigma + 1)
    # Unnormalized gaussian; the 1/(sqrt(2 pi) sigma) prefactor cancels
    # in the normalization below
    kernel = np.exp(-np.power(kernel_x, 2.0) / (2 * np.power(sigma, 2.0)))
    kernel = kernel / kernel.sum()
    y = np.pad(np.asarray(y), int(len(kernel) / 2), "edge")
    blurred = np.convolve(y, kernel, "valid")
    return blurred

202 

203 

204def _gauss(x, mu=0, sigma=1): 

205 return (1 / (np.sqrt(2 * np.pi) * sigma)) * np.exp( 

206 -np.power(x - mu, 2.0) / (2 * np.power(sigma, 2.0)) 

207 ) 

208 

209 

def plot_history(
    train_data,
    val_data=None,
    train_label="training",
    val_label="validation",
    color=None,
    **kwargs
):
    """
    Plot the train/val curves in a single plot.

    For backward compat. Functionality moved to TrainValPlotter

    """
    plotter = TrainValPlotter()
    plotter.plot_curves(
        train_data,
        val_data,
        train_label=train_label,
        val_label=val_label,
        color=color,
    )
    plotter.apply_layout(**kwargs)

229 

230 

def skip_nans(data):
    """
    Skip over nan values, so that all dots are connected.

    Parameters
    ----------
    data : List
        Contains x and y data as ndarrays. The y values may contain nans.

    Returns
    -------
    data_clean : List
        Contains x and y data as ndarrays. Points with y=nan are skipped.

    """
    x_values, y_values = data[0], data[1]
    keep = ~np.isnan(y_values)
    return x_values[keep], y_values[keep]

249 

250 

def get_ylims(y_points_train, y_points_val=None, fraction=0.25):
    """
    Get the y limits for the summary plot.

    For the training data, limits are calculated while ignoring data points
    which are far from the median (in terms of the median distance
    from the median).
    This is because there are outliers sometimes in the training data,
    especially early on in the training.

    Parameters
    ----------
    y_points_train : List
        y data of the train curve.
    y_points_val : List or None
        Y data of the validation curve.
    fraction : float or List
        How much whitespace of the total y range is added below and above
        the lines.

    Returns
    -------
    y_lims : tuple
        Minimum, maximum of the data.

    """
    assert not (
        y_points_train is None and y_points_val is None
    ), "train and val data are None"

    def _minmax_without_outliers(values, threshold):
        # Distance of each point from the median, scaled by the median
        # of those distances (median absolute deviation).
        dev = np.abs(values - np.median(values))
        mad = np.median(dev)
        scores = dev / mad if mad else 0.0
        kept = values[scores < threshold]
        return np.amin(kept), np.amax(kept)

    lower_candidates, upper_candidates = [], []

    if y_points_train is not None and len(y_points_train) != 0:
        finite_train = y_points_train[~np.isnan(y_points_train)]
        lo, hi = _minmax_without_outliers(finite_train, 5)
        lower_candidates.append(lo)
        upper_candidates.append(hi)

    if y_points_val is not None and len(y_points_val) != 0:
        finite_val = y_points_val[~np.isnan(y_points_val)]
        if len(finite_val) == 1:
            lo, hi = finite_val[0], finite_val[0]
        else:
            lo, hi = np.amin(finite_val), np.amax(finite_val)
        lower_candidates.append(lo)
        upper_candidates.append(hi)

    if len(lower_candidates) == 1:
        y_lims = (lower_candidates[0], upper_candidates[0])
    else:
        y_lims = np.amin(lower_candidates), np.amax(upper_candidates)

    # Degenerate case: all points identical -> synthesize a range
    if y_lims[0] == y_lims[1]:
        y_range = 0.1 * y_lims[0]
    else:
        y_range = y_lims[1] - y_lims[0]

    try:
        pad_below = pad_above = float(fraction)
    except TypeError:
        # fraction is a sequence [below, above]
        pad_below, pad_above = fraction[0], fraction[1]

    if [pad_below, pad_above] != [0.0, 0.0]:
        y_lims = (
            y_lims[0] - pad_below * y_range,
            y_lims[1] + pad_above * y_range,
        )

    return y_lims

328 

329 

def get_epoch_xticks(x_points):
    """
    Calculates the xticks for the train and validation summary plot.

    One tick per epoch. Less the larger #epochs is.

    Parameters
    ----------
    x_points : List
        A list of the x coordinates of all points.

    Returns
    -------
    x_ticks_major : numpy.ndarray
        Array containing the ticks.

    Raises
    ------
    ValueError
        If x_points is empty.

    """
    if len(x_points) == 0:
        raise ValueError("x-coordinates are empty!")

    x_min, x_max = np.amin(x_points), np.amax(x_points)

    if x_max - x_min > 0.5:
        # Longer trainings: whole-epoch ticks, thinned out when there
        # are more than ~20 epochs
        first_tick, last_tick = np.floor(x_min), np.ceil(x_max)
        n_epochs = last_tick - first_tick
        step = 1 + np.floor(n_epochs / 20.0)
        return np.arange(first_tick, last_tick + step, step)

    # Early peek into a training: six evenly spaced ticks over
    # roughly twice the seen range
    first_tick = np.floor(x_min)
    last_tick = x_max + x_min - first_tick
    return np.linspace(first_tick, last_tick, 6)

367 

368 

def update_summary_plot(orga):
    """
    Plot and save all metrics of the given validation- and train-data
    into a pdf file, each metric in its own plot.

    If metric pairs of a variable and its error are found (e.g. e_loss
    and e_err_loss), they will have the same color and appear back to
    back in the plot.

    Parameters
    ----------
    orga : orcanet.core.Organizer
        Contains all the configurable options in the OrcaNet scripts.

    """
    plt.ioff()
    pdf_name = orga.io.get_subfolder("plots", create=True) + "/summary_plot.pdf"

    # Extract the names of the metrics and sort them so that err
    # metrics come right after their variable
    all_metrics = sort_metrics(orga.history.get_metrics())
    # Plot them w/ custom color cycle; ref. personal.sron.nl/~pault/
    colors = [
        "#000000",
        "#332288",
        "#88CCEE",
        "#44AA99",
        "#117733",
        "#999933",
        "#DDCC77",
        "#CC6677",
        "#882255",
        "#AA4499",
        "#661100",
        "#6699CC",
        "#AA4466",
        "#4477AA",
    ]
    color_counter = 0
    with PdfPages(pdf_name) as pdf:
        for metric_no, metric in enumerate(all_metrics):
            # If this metric is the err metric of the previous variable,
            # color it the same. The metric_no > 0 guard prevents the
            # first metric from being compared against all_metrics[-1]
            # (negative-index wraparound, e.g. a lone "loss" metric
            # would otherwise match itself).
            if metric_no > 0 and all_metrics[metric_no - 1] == metric.replace(
                "_err", ""
            ):
                color_counter -= 1
            orga.history.plot_metric(metric, color=colors[color_counter % len(colors)])
            plt.suptitle(os.path.basename(os.path.abspath(orga.cfg.output_folder)))
            color_counter += 1
            pdf.savefig()
            plt.clf()

        orga.history.plot_lr()
        color_counter += 1
        pdf.savefig()
        plt.close()

424 

425 

def sort_metrics(metric_names):
    """
    Sort a list of metrics, so that errors are right after their variable.
    The format of the metric names have to be e.g. e_loss and e_err_loss
    for this to work.

    Example
    -------
    >>> sort_metrics( ['e_loss', 'loss', 'e_err_loss', 'dx_err_loss'] )
    ['e_loss', 'e_err_loss', 'loss', 'dx_err_loss']

    Parameters
    ----------
    metric_names : List
        List of metric names.

    Returns
    -------
    sorted_metrics : List
        List of sorted metric names with the same length as the input.

    """
    sorted_metrics = [0] * len(metric_names)
    position = 0
    for name in metric_names:
        if "err_" in name:
            # Err metrics are normally inserted right after their
            # variable below; only orphans get placed here
            if name.replace("err_", "") not in metric_names:
                sorted_metrics[position] = name
                position += 1
            continue
        sorted_metrics[position] = name
        position += 1
        partner = name.split("_loss")[0] + "_err_loss"
        if partner in metric_names:
            sorted_metrics[position] = partner
            position += 1

    assert 0 not in sorted_metrics, (
        "Something went wrong with the sorting of "
        "metrics! Given was {}, output was "
        "{}. ".format(metric_names, sorted_metrics)
    )

    return sorted_metrics