Coverage for orcanet/logging.py: 97%

242 statements  

« prev     ^ index     » next       coverage.py v7.2.7, created at 2024-03-28 14:22 +0000

1""" 

2Scripts for writing the logfiles. 

3""" 

4 

5import numpy as np 

6import os 

7import tensorflow.keras as ks 

8from datetime import datetime 

9from shutil import move 

10import orcanet 

11 

12 

class TrainfileLogger:
    def __init__(self, log_file, column_names):
        """
        For writing the training log file in a nice format.

        Parameters
        ----------
        log_file : opened file
            The logfile.
        column_names : List
            A list of column names for the file.

        """
        # Minimum width of the cells in characters.
        self.minimum_cell_width = 11
        # Precision to which floats are rounded if they appear in data.
        self.float_precision = 6

        self.log_file = log_file
        self.column_names = column_names
        # Cell widths; determined once the header is written.
        self._widths = None

    def level_file(self):
        """
        Make file with only the head lines.

        Existing file will be overwritten.
        """
        headline, widths = self._gen_line_str(self.column_names)
        dashes = ["-" * w for w in widths]
        divider = self._gen_line_str(dashes, widths, seperator="-+-")[0]
        for text in (headline, divider):
            self.log_file.write(text + "\n")
        self._widths = widths

    def write_line(self, values):
        """
        Write a line with data to the file.

        Parameters
        ----------
        values : List
            The data, in the same order as the column names.

        """
        if self._widths is None:
            raise ValueError("Can not log: .level_file has to be called first")
        n_expected = len(self.column_names)
        if len(values) != n_expected:
            raise ValueError(
                "Can not log: Expected {} values, but got "
                "{}".format(n_expected, len(values))
            )
        self.log_file.write(self._gen_line_str(values, self._widths)[0] + "\n")

    def _gen_line_str(self, data, widths=None, seperator=" | "):
        # Delegate to the module-level helper with this logger's settings.
        return gen_line_str(
            data,
            widths=widths,
            seperator=seperator,
            float_precision=self.float_precision,
            minimum_cell_width=self.minimum_cell_width,
        )

79 

80 

def gen_line_str(
    data, widths=None, seperator=" | ", float_precision=4, minimum_cell_width=9
):
    """
    Generate a line in nice human readable format,
    consisting of multiple spaced and seperated cells.

    Parameters
    ----------
    data : tuple
        Strings or floats of what is in each cell. It must be in the
        same order and have the same length as the column names.
    widths : List or None
        Optional: The width of every cell. If None, will set it
        automatically, depending on the data.
        If widths is given, but what is given in data is wider than
        the width, the cell will expand without notice. Must have the
        same length as the column names.
    seperator : str
        String that seperates two adjacent cells.
    float_precision : int
        Precision to which floats are rounded if they appear in data.
        The length of the resulting number can be up to 5 characters longer
        than this value (due to . and e-09)
    minimum_cell_width : int
        Minimum width of the cells in characters.

    Returns
    -------
    line : str
        The line.
    new_widths : List
        The widths of the cells.

    """
    cells, new_widths = gen_line_cells(
        data, widths, float_precision, minimum_cell_width
    )
    return seperator.join(str(cell) for cell in cells), new_widths


def gen_line_cells(data, widths=None, float_precision=4, minimum_cell_width=9):
    """
    Generate the content of the cells for a line in the summary file.

    See gen_line_str (above) for doc.

    Returns
    -------
    cells : List
    new_widths : List

    """
    auto_width = widths is None
    new_widths = [] if auto_width else widths

    cells = []
    for i, entry in enumerate(data):
        # Numbers are rounded to the given precision and stringified.
        if not isinstance(entry, str):
            entry = format(float(entry), ".{}g".format(float_precision))

        if auto_width:
            cell_width = max(minimum_cell_width, len(entry))
            new_widths.append(cell_width)
        else:
            cell_width = widths[i]

        # Left-align; a too-long entry simply overflows its cell.
        cells.append(entry.ljust(cell_width))

    return cells, new_widths

157 

158 

class SummaryLogger:
    """
    For writing the summary logfile made during training.

    Parameters
    ----------
    orga : orcanet.core.Organizer
        Contains all the configurable options in the OrcaNet scripts.
    model : ks.model.Model or None
        Keras model containing the metrics to plot.

    """

    def __init__(self, orga, model):
        self.orga = orga
        self.model = model
        # Minimum width of the cells in characters.
        self.minimum_cell_width = 11
        # Precision to which floats are rounded if they appear in data.
        self.float_precision = 6

        # Fix: use os.path.join so both paths are correct whether or not
        # output_folder carries a trailing slash. The old concatenation
        # produced e.g. "outputsummary.txt" when the slash was missing,
        # while the temp path hard-coded an extra "/".
        self.logfile_name = os.path.join(orga.cfg.output_folder, "summary.txt")
        self.temp_filepath = os.path.join(orga.cfg.output_folder, ".temp_summary.txt")

    def write_line(self, epoch_float, lr, history_train=None, history_val=None):
        """
        Write a line to the summary.txt file in the trained model folder.

        Will update an existing line if possible.

        Notes
        -----
        In tf 2.2, model.metrics_names is only filled after the model
        has been used on data, i.e. only after that point this line can
        be run. Otherwise, _get_column_names will throw a NameError.

        Parameters
        ----------
        epoch_float : float
            The current epoch and fileno as a float.
        lr : float/str
            The current learning rate of the model.
        history_train : dict
            Dict containing the history of the training, averaged over files.
            Keys: Metric names, e.g. "loss", "accuracy", ...
            Values: Value of the metric during validation as a float.
        history_val : dict or None
            Dict of validation losses for all the metrics, averaged over
            all validation files.
            Keys: Metric names, e.g. "loss", "accuracy", ...
            Values: Value of the metric during validation as a float.

        Raises
        ------
        ValueError
            If both history dicts are None, as there is nothing to log.

        """
        if history_val is None and history_train is None:
            raise ValueError(
                "Can not summary log when both train and val history are None"
            )

        widths = self._init_writing()

        # Format the content: (Epoch, LR, train_1, val_1, ...)
        data = [epoch_float, lr]
        for metric_name in self.model.metrics_names:
            if history_train is None:
                data.append("n/a")
            else:
                data.append(history_train[metric_name])
            if history_val is None:
                data.append("n/a")
            else:
                data.append(history_val[metric_name])

        # if the epoch is already in the file, its line will get updated
        update_line = False
        summary_data = self.orga.history.get_summary_data()
        if len(summary_data) > 0:
            last_line = summary_data[-1]
            # Round the epoch to the same precision it has in the file, so
            # it can be compared against the stored value.
            # TODO this is bad, epoch, fileno should probably be in the
            #  summary.txt in their own columns
            data[0] = float(self._gen_line_cells(data, widths)[0][0])

            if last_line["Epoch"] == data[0]:
                # merge arrays but ignore LR (index 1)
                data = merge_arrays(last_line, data, exclude=1)
                update_line = True

        line = self._gen_line_str(data, widths)[0]
        self._save_line(line, update_line)

    def _get_column_names(self):
        """
        Build the column names from the model's metrics, and verify them
        against an already existing summary file if there is one.
        """
        column_names = ["Epoch", "LR"]
        for metric_name in self.model.metrics_names:
            column_names.append("train_" + str(metric_name))
            column_names.append("val_" + str(metric_name))
        column_names = tuple(column_names)

        if os.path.isfile(self.logfile_name):
            # if summary exists already, check if model metrics match
            file_column_names = self.orga.history.get_column_names()
            if not set(column_names) == set(file_column_names):
                raise NameError(
                    "Can not log to summary: column names differ (from model: "
                    "{}, from summary file: {}".format(column_names, file_column_names)
                )
            column_names = file_column_names

        return column_names

    def _save_line(self, line, update=False):
        """Write a line in the summary file. If update, overwrite the last line."""
        if not update:
            with open(self.logfile_name, "a+") as logfile:
                logfile.write(line + "\n")

        else:
            # replace last line by rewriting whole summary file :-(
            with open(self.logfile_name, "r") as old_file:
                lines = old_file.readlines()
                lines[-1] = line + "\n"

            # make new summary file as a temp
            with open(self.temp_filepath, "w") as temp_file:
                for old_line in lines:
                    temp_file.write(old_line)

            # Remove original file
            os.remove(self.logfile_name)
            # Move new file
            move(self.temp_filepath, self.logfile_name)

    def _init_writing(self):
        """
        Get the widths of the columns, and write the head if the file is new.

        The widths have the length of the metric names, but at least the
        self.minimum_cell_width.

        Returns
        -------
        widths : list
            The width of every cell in characters.

        """
        column_names = self._get_column_names()
        headline, widths = self._gen_line_str(column_names)
        if (
            not os.path.isfile(self.logfile_name)
            or os.stat(self.logfile_name).st_size == 0
        ):
            vline = ["-" * width for width in widths]
            vertical_line = self._gen_line_str(vline, widths, seperator="-+-")[0]
            with open(self.logfile_name, "a+") as logfile:
                logfile.write(headline + "\n")
                logfile.write(vertical_line + "\n")

        return widths

    def _gen_line_str(self, data, widths=None, seperator=" | "):
        """Format one summary line with this logger's precision/width settings."""
        line, widths = gen_line_str(
            data,
            widths=widths,
            seperator=seperator,
            float_precision=self.float_precision,
            minimum_cell_width=self.minimum_cell_width,
        )
        return line, widths

    def _gen_line_cells(self, data, widths=None):
        """Format summary cells with this logger's precision/width settings."""
        cells, widths = gen_line_cells(
            data,
            widths=widths,
            float_precision=self.float_precision,
            minimum_cell_width=self.minimum_cell_width,
        )
        return cells, widths

339 

340 

def merge_arrays(base, supp, exclude=None):
    """
    Fill nans in a list with values from another list.

    Parameters
    ----------
    base : List
    supp : List
    exclude : List or int
        Which indices to ignore.

    Returns
    -------
    np.array

    """
    # Normalize a scalar (or None) exclude to a list via the iter probe.
    try:
        iter(exclude)
    except TypeError:
        exclude = [exclude]

    for i, base_val in enumerate(base):
        if exclude is not None and i in exclude:
            continue

        supp_val = supp[i]
        if base_val == supp_val:
            continue

        if base_val == "n/a" or np.isnan(base_val):
            # Missing in base: take the supplementary value.
            base[i] = supp_val
        elif supp_val == "n/a" or np.isnan(supp_val):
            # Missing in the supplement: keep what base has.
            continue
        else:
            # Two different, real values cannot be merged.
            raise ValueError(
                "Cannot merge arrays at index {}: Base {}, supplement {}".format(
                    i, base_val, supp_val
                )
            )
    return base

382 

383 

class BatchLogger(ks.callbacks.Callback):
    """
    Write logfiles during training.

    Averages the losses of the model over some number of batches,
    and then writes that in a line in the logfile.
    The Batch_float entry in the logfiles gives the absolute position
    of the batch in the epoch (i.e. taking all files into account).
    This class is intended to be used only for one epoch = one file.

    Parameters
    ----------
    orga : orcanet.core.Organizer
        Contains all the configurable options in the OrcaNet scripts.
    epoch : tuple
        Epoch and file number.
    reset_metrics : bool
        Reset internal state of metric after every batch?

    """

    def __init__(self, orga, epoch, reset_metrics=True):
        super().__init__()
        self.reset_metrics = reset_metrics

        self.epoch_number = epoch[0]
        self.f_number = epoch[1]

        # settings (read from orga)
        self.display = orga.cfg.train_logger_display
        self.flush = orga.cfg.train_logger_flush
        self.logfile_name = "{}/log_epoch_{}_file_{}.txt".format(
            orga.io.get_subfolder("train_log", create=True),
            self.epoch_number,
            self.f_number,
        )
        self.batchsize = orga.cfg.batchsize
        self.file_sizes = np.array(orga.io.get_file_sizes("train"))

        # get the total no of batches over all files (not just the current one)
        # This is for calculating the batch_float number in the logs
        file_batches = np.ceil(self.file_sizes / self.batchsize)
        self.total_batches = np.sum(file_batches)
        # no of batches seen in previous files (file numbers are 1-based)
        if self.f_number == 1:
            self.previous_batches = 0.0
        elif self.f_number > 1:
            self.previous_batches = np.cumsum(file_batches)[self.f_number - 2]
        else:
            raise AssertionError("f_number not >= 1 ({})".format(self.f_number))

        # Per-epoch state; set up lazily in initialize_epoch.
        self.seen = None
        self.lines = None
        self.cum_metrics = None
        self.file = None
        self.epoch_initialized = False
        self._stored_metrics = False
        self._logger = None

    def on_epoch_begin(self, epoch, logs=None):
        # Defer the actual setup to the first batch: in tf 2.2,
        # model.metrics_names is only populated once the model has
        # seen data.
        self.epoch_initialized = False

    def initialize_epoch(self):
        """Start a new logfile and prepare the logger."""
        # no of seen batches in this epoch
        self.seen = 0
        # list of stored lines, so that multiple can be written at once
        self.lines = []
        # store the various metrices to be able to average over multiple batches
        self.cum_metrics = {}
        for metric in self.model.metrics_names:
            self.cum_metrics[metric] = 0
        self.file = open(self.logfile_name, "w")
        self._write_head()

    def on_batch_end(self, batch, logs=None):
        """Accumulate metrics; write an averaged line every `display` batches."""
        # logs, e.g.:
        # {'batch': 7, 'size': 5, 'loss': 2.06344,
        #  'dx_loss': 0.19809794, 'dx_err_loss': 0.08246058}
        logs = logs or {}
        if not self.epoch_initialized:
            self.initialize_epoch()
            self.epoch_initialized = True

        self.seen += 1
        for metric in self.model.metrics_names:
            # NOTE(review): assumes every metric name is present in logs;
            # a missing key would make .get return None and this line crash.
            self.cum_metrics[metric] += logs.get(metric)
        if not self._stored_metrics:
            self._stored_metrics = True

        if self.seen % self.display == 0:
            self._write_line()

        # Fix: flush every `flush` batches. The previous condition
        # (self.display % self.flush == 0) did not depend on the batch
        # counter at all, so it either flushed after every single batch
        # or never. TODO confirm whether train_logger_flush counts
        # batches or written lines.
        if self.flush != -1 and self.seen % self.flush == 0:
            self._flush_file()
        if self.reset_metrics:
            self.model.reset_metrics()

    def on_epoch_end(self, batch, logs=None):
        # on epoch end here means that this is called after one fit_generator
        # loop in Keras is finished, so after one file in our case.
        # NOTE: writing out the remaining (< display) batches here is
        # intentionally disabled; their average over self.display would
        # be skewed.
        self.file.close()

    def _write_line(self):
        """Write a line with the metrics for current status and reset metrics."""
        # The fraction is shifted by self.display / 2., so that it is in
        # the middle of the samples
        batch_frctn = (
            self.epoch_number
            - 1
            + (self.previous_batches + self.seen - self.display / 2.0)
            / self.total_batches
        )

        line_data = [self.seen, batch_frctn]
        for metric in self.model.metrics_names:
            line_data.append(self.cum_metrics[metric] / self.display)
            self.cum_metrics[metric] = 0
        self._logger.write_line(line_data)
        self._stored_metrics = False

    def _flush_file(self):
        """Force the OS to write the logfile to disk."""
        self.file.flush()
        os.fsync(self.file.fileno())

    def _write_head(self):
        """write column names for all losses / metrics"""
        column_names = ["Batch", "Batch_float"]
        for metric in self.model.metrics_names:
            column_names.append(metric)
        self._logger = TrainfileLogger(self.file, column_names)
        self._logger.level_file()

525 

526 

def log_start_training(orga):
    """
    When a training is started for the first time, this logs all the
    input parameters to the log.txt file.

    Parameters
    ----------
    orga : orcanet.core.Organizer
        Contains all the configurable options in the OrcaNet scripts.

    """
    timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    lines = [
        "-" * 60,
        "-" * 19 + " {} ".format(timestamp) + "-" * 19,
        "-" * 10 + " orcanet version " + str(orcanet.__version__),
        "\nTraining run started with the following configuration:\n",
        "Output folder:\t" + orga.cfg.output_folder,
        "List file path:\t" + orga.cfg.get_list_file() + "\n",
    ]

    lines.append("Given trainfiles in the .list file:")
    for input_name, input_files in orga.cfg.get_files("train").items():
        lines.append(" " + input_name + ":")
        lines.extend("\t" + input_file for input_file in input_files)

    lines.append("\nGiven validation files in the .list file:")
    for input_name, input_files in orga.cfg.get_files("val").items():
        lines.append(" " + input_name + ":")
        lines.extend("\t" + input_file for input_file in input_files)

    lines.append("\nSettings used:")
    for key, value in vars(orga.cfg).items():
        # output_folder is logged above; underscore attrs are internals.
        if key == "output_folder" or key.startswith("_"):
            continue
        lines.append(" {}:\t{}".format(key, value))

    lines.append("")
    orga.io.print_log(lines)

568 

569 

def log_start_validation(orga):
    """Log filenames used for validation."""
    heading = "Validation"
    orga.io.print_log(heading)
    orga.io.print_log("-" * len(heading))

    lines = ["Inputs and files:"]
    for input_name, input_files in orga.io.get_local_files("val").items():
        basenames = ", ".join(os.path.basename(f) for f in input_files)
        lines.append(" " + input_name + ":\t" + basenames)
    orga.io.print_log(lines)

586 

587 

588# class TensorBoardWrapper(ks.callbacks.TensorBoard): 

589# """Up to now (05.10.17), Keras doesn't accept TensorBoard callbacks with validation data that is fed by a generator. 

590# Supplying the validation data is needed for the histogram_freq > 1 argument in the TB callback. 

591# Without a workaround, only scalar values (e.g. loss, accuracy) and the computational graph of the model can be saved. 

592# 

593# This class acts as a Wrapper for the ks.callbacks.TensorBoard class in such a way, 

594# that the whole validation data is put into a single array by using the generator. 

595# Then, the single array is used in the validation steps. This workaround is experimental!""" 

596# def __init__(self, batch_gen, nb_steps, **kwargs): 

597# super(TensorBoardWrapper, self).__init__(**kwargs) 

598# self.batch_gen = batch_gen # The generator. 

599# self.nb_steps = nb_steps # Number of times to call next() on the generator. 

600# 

601# def on_epoch_end(self, epoch, logs): 

602# # Fill in the `validation_data` property. 

603# # After it's filled in, the regular on_epoch_end method has access to the validation_data. 

604# imgs, tags = None, None 

605# for s in range(self.nb_steps): 

606# ib, tb = next(self.batch_gen) 

607# if imgs is None and tags is None: 

608# imgs = np.zeros(((self.nb_steps * ib.shape[0],) + ib.shape[1:]), dtype=np.float32) 

609# tags = np.zeros(((self.nb_steps * tb.shape[0],) + tb.shape[1:]), dtype=np.uint8) 

610# imgs[s * ib.shape[0]:(s + 1) * ib.shape[0]] = ib 

611# tags[s * tb.shape[0]:(s + 1) * tb.shape[0]] = tb 

612# self.validation_data = [imgs, tags, np.ones(imgs.shape[0]), 0.0] 

613# return super(TensorBoardWrapper, self).on_epoch_end(epoch, logs)