Coverage for orcanet/logging.py: 97%
242 statements

"""
Scripts for writing the logfiles.
"""

import numpy as np
import os
import tensorflow.keras as ks
from datetime import datetime
from shutil import move

import orcanet


class TrainfileLogger:
    def __init__(self, log_file, column_names):
        """
        For writing the training log file in a nice format.

        Parameters
        ----------
        log_file : opened file
            The logfile.
        column_names : List
            A list of column names for the file.

        """
        # Minimum width of the cells in characters.
        self.minimum_cell_width = 11
        # Precision to which floats are rounded if they appear in data.
        self.float_precision = 6

        self.log_file = log_file
        self.column_names = column_names
        self._widths = None

    def level_file(self):
        """
        Make a file with only the header lines.

        An existing file will be overwritten.
        """
        headline, widths = self._gen_line_str(self.column_names)
        vline = ["-" * width for width in widths]
        vertical_line = self._gen_line_str(vline, widths, seperator="-+-")[0]
        self.log_file.write(headline + "\n")
        self.log_file.write(vertical_line + "\n")

        self._widths = widths

    def write_line(self, values):
        """
        Write a line with data to the file.

        Parameters
        ----------
        values : List
            The data, in the same order as the column names.

        """
        if self._widths is None:
            raise ValueError("Can not log: .level_file has to be called first")
        if len(values) != len(self.column_names):
            raise ValueError(
                "Can not log: Expected {} values, but got "
                "{}".format(len(self.column_names), len(values))
            )

        line = self._gen_line_str(values, self._widths)[0]
        self.log_file.write(line + "\n")

    def _gen_line_str(self, data, widths=None, seperator=" | "):
        line, widths = gen_line_str(
            data,
            widths=widths,
            seperator=seperator,
            float_precision=self.float_precision,
            minimum_cell_width=self.minimum_cell_width,
        )
        return line, widths
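

# Example usage of TrainfileLogger (an illustrative sketch, not part of the
# module itself; the file name and column names are made up):
#
#     with open("train.log", "w") as f:
#         logger = TrainfileLogger(f, ["Batch", "loss"])
#         logger.level_file()          # writes header and separator line
#         logger.write_line([1, 0.532871])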


def gen_line_str(
    data, widths=None, seperator=" | ", float_precision=4, minimum_cell_width=9
):
    """
    Generate a line in a nice, human-readable format,
    consisting of multiple spaced and separated cells.

    Parameters
    ----------
    data : tuple
        Strings or floats of what is in each cell. It must be in the
        same order and have the same length as the column names.
    widths : List or None
        Optional: The width of every cell. If None, it will be set
        automatically, depending on the data.
        If widths is given, but the data is wider than the given
        width, the cell will expand without notice. Must have the
        same length as the column names.
    seperator : str
        String that separates two adjacent cells.
    float_precision : int
        Precision to which floats are rounded if they appear in data.
        The length of the resulting number can be up to 5 characters longer
        than this value (due to the decimal point and exponents like e-09).
    minimum_cell_width : int
        Minimum width of the cells in characters.

    Returns
    -------
    line : str
        The line.
    new_widths : List
        The widths of the cells.

    """
    cells, new_widths = gen_line_cells(
        data, widths, float_precision, minimum_cell_width
    )
    line = seperator.join(str(cell) for cell in cells)
    return line, new_widths
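

# Illustrative example of the produced formatting (values chosen arbitrarily):
#
#     line, widths = gen_line_str(["Epoch", 0.123456789], float_precision=4)
#     # line   -> 'Epoch     | 0.1235   '   (cells padded to >= 9 characters)
#     # widths -> [9, 9]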


def gen_line_cells(data, widths=None, float_precision=4, minimum_cell_width=9):
    """
    Generate the content of the cells for a line in the summary file.

    See gen_line_str (above) for doc.

    Returns
    -------
    cells : List
    new_widths : List

    """
    if widths is None:
        new_widths = []
    else:
        new_widths = widths

    cells = []
    for i, entry in enumerate(data):
        # If entry is a number, round it to the given precision and make it a string
        if not isinstance(entry, str):
            entry = format(float(entry), "." + str(float_precision) + "g")

        if widths is None:
            cell_width = max(minimum_cell_width, len(entry))
            new_widths.append(cell_width)
        else:
            cell_width = widths[i]

        cell_cont = format(entry, "<" + str(cell_width))
        cells.append(cell_cont)

    return cells, new_widths


class SummaryLogger:
    """
    For writing the summary logfile made during training.

    Parameters
    ----------
    orga : orcanet.core.Organizer
        Contains all the configurable options in the OrcaNet scripts.
    model : ks.model.Model or None
        Keras model containing the metrics to plot.

    """

    def __init__(self, orga, model):
        self.orga = orga
        self.model = model
        # Minimum width of the cells in characters.
        self.minimum_cell_width = 11
        # Precision to which floats are rounded if they appear in data.
        self.float_precision = 6

        self.logfile_name = os.path.join(orga.cfg.output_folder, "summary.txt")
        self.temp_filepath = os.path.join(orga.cfg.output_folder, ".temp_summary.txt")

    def write_line(self, epoch_float, lr, history_train=None, history_val=None):
        """
        Write a line to the summary.txt file in the trained model folder.

        Will update an existing line if possible.

        Notes
        -----
        In tf 2.2, model.metrics_names is only filled after the model
        has been used on data, i.e. only after that point can this method
        be run. Otherwise, _get_column_names will throw a NameError.

        Parameters
        ----------
        epoch_float : float
            The current epoch and fileno as a float.
        lr : float/str
            The current learning rate of the model.
        history_train : dict or None
            Dict containing the history of the training, averaged over files.
            Keys: Metric names, e.g. "loss", "accuracy", ...
            Values: Value of the metric during training as a float.
        history_val : dict or None
            Dict of validation losses for all the metrics, averaged over
            all validation files.
            Keys: Metric names, e.g. "loss", "accuracy", ...
            Values: Value of the metric during validation as a float.

        """
        if history_val is None and history_train is None:
            raise ValueError(
                "Can not summary log when both train and val history are None"
            )

        widths = self._init_writing()

        # Format the content: (Epoch, LR, train_1, val_1, ...)
        data = [epoch_float, lr]
        for i, metric_name in enumerate(self.model.metrics_names):
            if history_train is None:
                data.append("n/a")
            else:
                data.append(history_train[metric_name])
            if history_val is None:
                data.append("n/a")
            else:
                data.append(history_val[metric_name])

        # If the epoch is already in the file, its line will get updated
        update_line = False
        summary_data = self.orga.history.get_summary_data()
        if len(summary_data) > 0:
            last_line = summary_data[-1]
            # get the epoch to the same length as it appears in the file
            # TODO this is bad; epoch and fileno should probably be in the
            #  summary.txt in their own columns
            data[0] = float(self._gen_line_cells(data, widths)[0][0])

            if last_line["Epoch"] == data[0]:
                # merge the arrays, but ignore the LR
                data = merge_arrays(last_line, data, exclude=1)
                update_line = True

        line = self._gen_line_str(data, widths)[0]
        self._save_line(line, update_line)
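
    # The resulting summary.txt looks roughly like this (illustrative sketch;
    # the actual columns depend on the model's metrics):
    #
    #     Epoch       | LR          | train_loss  | val_loss
    #     ------------+-------------+-------------+------------
    #     0.1         | 0.001       | 1.305112    | n/a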

    def _get_column_names(self):
        column_names = [
            "Epoch",
            "LR",
        ]
        for metric_name in self.model.metrics_names:
            column_names.append("train_" + str(metric_name))
            column_names.append("val_" + str(metric_name))
        column_names = tuple(column_names)

        if os.path.isfile(self.logfile_name):
            # if the summary exists already, check if the model metrics match
            file_column_names = self.orga.history.get_column_names()
            if not set(column_names) == set(file_column_names):
                raise NameError(
                    "Can not log to summary: column names differ (from model: "
                    "{}, from summary file: {})".format(column_names, file_column_names)
                )
            column_names = file_column_names

        return column_names

    def _save_line(self, line, update=False):
        """Write a line in the summary file. If update, overwrite the last line."""
        if not update:
            with open(self.logfile_name, "a+") as logfile:
                logfile.write(line + "\n")

        else:
            # replace the last line by rewriting the whole summary file :-(
            with open(self.logfile_name, "r") as old_file:
                lines = old_file.readlines()
            lines[-1] = line + "\n"

            # write the new summary file as a temp
            with open(self.temp_filepath, "w") as temp_file:
                for old_line in lines:
                    temp_file.write(old_line)

            # remove the original file
            os.remove(self.logfile_name)
            # move the new file into place
            move(self.temp_filepath, self.logfile_name)

    def _init_writing(self):
        """
        Get the widths of the columns, and write the head if the file is new.

        The widths have the length of the metric names, but are at least
        self.minimum_cell_width.

        Returns
        -------
        widths : list
            The width of every cell in characters.

        """
        column_names = self._get_column_names()
        headline, widths = self._gen_line_str(column_names)
        if (
            not os.path.isfile(self.logfile_name)
            or os.stat(self.logfile_name).st_size == 0
        ):
            vline = ["-" * width for width in widths]
            vertical_line = self._gen_line_str(vline, widths, seperator="-+-")[0]
            with open(self.logfile_name, "a+") as logfile:
                logfile.write(headline + "\n")
                logfile.write(vertical_line + "\n")

        return widths

    def _gen_line_str(self, data, widths=None, seperator=" | "):
        line, widths = gen_line_str(
            data,
            widths=widths,
            seperator=seperator,
            float_precision=self.float_precision,
            minimum_cell_width=self.minimum_cell_width,
        )
        return line, widths

    def _gen_line_cells(self, data, widths=None):
        cells, widths = gen_line_cells(
            data,
            widths=widths,
            float_precision=self.float_precision,
            minimum_cell_width=self.minimum_cell_width,
        )
        return cells, widths


def merge_arrays(base, supp, exclude=None):
    """
    Fill nans and "n/a" entries in a list with values from another list.

    Parameters
    ----------
    base : List
        The list whose missing entries get filled in.
    supp : List
        The list supplying the missing values.
    exclude : List or int, optional
        Which indices to ignore.

    Returns
    -------
    base : List
        The base, with nan/"n/a" entries replaced by the values from supp.

    """
    try:
        iter(exclude)
    except TypeError:
        exclude = [exclude]

    for i in range(len(base)):
        if exclude is not None and i in exclude:
            continue

        if base[i] == supp[i]:
            continue

        elif base[i] == "n/a" or np.isnan(base[i]):
            base[i] = supp[i]

        elif supp[i] == "n/a" or np.isnan(supp[i]):
            continue

        else:
            raise ValueError(
                "Cannot merge arrays at index {}: Base {}, supplement {}".format(
                    i, base[i], supp[i]
                )
            )
    return base
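

# Illustrative sketch of the merge behaviour (values chosen arbitrarily):
#
#     merged = merge_arrays([0.1, 0.01, 1.5, "n/a"],
#                           [0.1, 0.02, "n/a", 0.9], exclude=1)
#     # -> [0.1, 0.01, 1.5, 0.9]   (index 1 excluded, "n/a" filled from supp)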


class BatchLogger(ks.callbacks.Callback):
    """
    Write logfiles during training.

    Averages the losses of the model over some number of batches,
    and then writes that as a line in the logfile.
    The Batch_float entry in the logfiles gives the absolute position
    of the batch in the epoch (i.e. taking all files into account).
    This class is intended to be used only for one epoch = one file.

    Parameters
    ----------
    orga : orcanet.core.Organizer
        Contains all the configurable options in the OrcaNet scripts.
    epoch : tuple
        Epoch and file number.
    reset_metrics : bool
        Reset the internal state of the metrics after every batch?

    """

    def __init__(self, orga, epoch, reset_metrics=True):
        super().__init__()
        self.reset_metrics = reset_metrics

        self.epoch_number = epoch[0]
        self.f_number = epoch[1]

        # settings (read from orga)
        self.display = orga.cfg.train_logger_display
        self.flush = orga.cfg.train_logger_flush
        self.logfile_name = "{}/log_epoch_{}_file_{}.txt".format(
            orga.io.get_subfolder("train_log", create=True),
            self.epoch_number,
            self.f_number,
        )
        self.batchsize = orga.cfg.batchsize
        self.file_sizes = np.array(orga.io.get_file_sizes("train"))

        # Get the total number of batches over all files (not just the current
        # one). This is for calculating the batch_float number in the logs.
        file_batches = np.ceil(self.file_sizes / self.batchsize)
        self.total_batches = np.sum(file_batches)
        # number of batches seen in previous files
        if self.f_number == 1:
            self.previous_batches = 0.0
        elif self.f_number > 1:
            self.previous_batches = np.cumsum(file_batches)[self.f_number - 2]
        else:
            raise AssertionError("f_number not >= 1 ({})".format(self.f_number))

        self.seen = None
        self.lines = None
        self.cum_metrics = None
        self.file = None
        self.epoch_initialized = False
        self._stored_metrics = False
        self._logger = None

    def on_epoch_begin(self, epoch, logs=None):
        self.epoch_initialized = False

    def initialize_epoch(self):
        """Start a new logfile and prepare the logger."""
        # number of seen batches in this epoch
        self.seen = 0
        # list of stored lines, so that multiple can be written at once
        self.lines = []
        # store the various metrics to be able to average over multiple batches
        self.cum_metrics = {}
        for metric in self.model.metrics_names:
            self.cum_metrics[metric] = 0
        self.file = open(self.logfile_name, "w")
        self._write_head()

    def on_batch_end(self, batch, logs=None):
        # self.params:
        #   {'epochs': 1, 'steps': 50, 'verbose': 1, 'do_validation': False,
        #    'metrics': ['loss', 'dx_loss', 'dx_err_loss',
        #                'val_loss', 'val_dx_loss', 'val_dx_err_loss']}
        # logs:
        #   {'batch': 7, 'size': 5, 'loss': 2.06344,
        #    'dx_loss': 0.19809794, 'dx_err_loss': 0.08246058}
        logs = logs or {}
        if not self.epoch_initialized:
            self.initialize_epoch()
            self.epoch_initialized = True

        self.seen += 1
        for metric in self.model.metrics_names:
            self.cum_metrics[metric] += logs.get(metric)
        if not self._stored_metrics:
            self._stored_metrics = True

        if self.seen % self.display == 0:
            self._write_line()

        # flush the file every self.flush seen batches
        if self.flush != -1 and self.seen % self.flush == 0:
            self._flush_file()
        if self.reset_metrics:
            self.model.reset_metrics()

    def on_epoch_end(self, batch, logs=None):
        # on epoch end here means that this is called after one fit_generator
        # loop in Keras is finished, so after one file in our case.
        """
        if self._stored_metrics:
            # write stats of remaining batches
            self._write_line()
        """
        self.file.close()

    def _write_line(self):
        """Write a line with the metrics for current status and reset metrics."""
        # The fraction is shifted by self.display / 2., so that it is in
        # the middle of the samples
        batch_frctn = (
            self.epoch_number
            - 1
            + (self.previous_batches + self.seen - self.display / 2.0)
            / self.total_batches
        )
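        # Worked example (illustrative): with two files of 100 batches each,
        # display=10, epoch 1, file 2, after seen=10 batches:
        # batch_frctn = 0 + (100 + 10 - 5) / 200 = 0.525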

        line_data = [self.seen, batch_frctn]
        for metric in self.model.metrics_names:
            line_data.append(self.cum_metrics[metric] / self.display)
            self.cum_metrics[metric] = 0
        self._logger.write_line(line_data)
        self._stored_metrics = False

    def _flush_file(self):
        self.file.flush()
        os.fsync(self.file.fileno())

    def _write_head(self):
        """Write the column names for all losses / metrics."""
        column_names = ["Batch", "Batch_float"]
        for metric in self.model.metrics_names:
            column_names.append(metric)
        self._logger = TrainfileLogger(self.file, column_names)
        self._logger.level_file()
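

# Typical wiring of the BatchLogger (an illustrative sketch; assumes a set-up
# orcanet Organizer `orga`, a compiled keras model, and a train generator):
#
#     batch_logger = BatchLogger(orga, epoch=(1, 1))
#     model.fit(train_generator, callbacks=[batch_logger])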


def log_start_training(orga):
    """
    When a training is started for the first time, this logs all the
    input parameters to the log.txt file.

    Parameters
    ----------
    orga : orcanet.core.Organizer
        Contains all the configurable options in the OrcaNet scripts.

    """
    lines = []
    log = lines.append

    log("-" * 60)
    time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    log("-" * 19 + " {} ".format(time) + "-" * 19)
    log("-" * 10 + " orcanet version " + str(orcanet.__version__))
    log("\nTraining run started with the following configuration:\n")

    log("Output folder:\t" + orga.cfg.output_folder)
    log("List file path:\t" + orga.cfg.get_list_file() + "\n")

    log("Given trainfiles in the .list file:")
    for input_name, input_files in orga.cfg.get_files("train").items():
        log(" " + input_name + ":")
        for input_file in input_files:
            log("\t" + input_file)

    log("\nGiven validation files in the .list file:")
    for input_name, input_files in orga.cfg.get_files("val").items():
        log(" " + input_name + ":")
        for input_file in input_files:
            log("\t" + input_file)

    log("\nSettings used:")
    for key, value in vars(orga.cfg).items():
        if key == "output_folder" or key.startswith("_"):
            continue
        log(" {}:\t{}".format(key, value))

    log("")
    orga.io.print_log(lines)


def log_start_validation(orga):
    """Log filenames used for validation."""
    line = "Validation"
    orga.io.print_log(line)
    orga.io.print_log("-" * len(line))
    lines = [
        "Inputs and files:",
    ]
    for input_name, input_files in orga.io.get_local_files("val").items():
        line = " " + input_name + ":\t"
        for i, input_file in enumerate(input_files):
            if i != 0:
                line += ", "
            line += os.path.basename(input_file)
        lines.append(line)
    orga.io.print_log(lines)


# class TensorBoardWrapper(ks.callbacks.TensorBoard):
#     """Up to now (05.10.17), Keras doesn't accept TensorBoard callbacks with
#     validation data that is fed by a generator. Supplying the validation data
#     is needed for the histogram_freq > 1 argument in the TB callback.
#     Without a workaround, only scalar values (e.g. loss, accuracy) and the
#     computational graph of the model can be saved.
#
#     This class acts as a wrapper for the ks.callbacks.TensorBoard class in
#     such a way that the whole validation data is put into a single array by
#     using the generator. Then, the single array is used in the validation
#     steps. This workaround is experimental!"""
#     def __init__(self, batch_gen, nb_steps, **kwargs):
#         super(TensorBoardWrapper, self).__init__(**kwargs)
#         self.batch_gen = batch_gen  # The generator.
#         self.nb_steps = nb_steps  # Number of times to call next() on the generator.
#
#     def on_epoch_end(self, epoch, logs):
#         # Fill in the `validation_data` property.
#         # After it's filled in, the regular on_epoch_end method has access
#         # to the validation_data.
#         imgs, tags = None, None
#         for s in range(self.nb_steps):
#             ib, tb = next(self.batch_gen)
#             if imgs is None and tags is None:
#                 imgs = np.zeros(((self.nb_steps * ib.shape[0],) + ib.shape[1:]), dtype=np.float32)
#                 tags = np.zeros(((self.nb_steps * tb.shape[0],) + tb.shape[1:]), dtype=np.uint8)
#             imgs[s * ib.shape[0]:(s + 1) * ib.shape[0]] = ib
#             tags[s * tb.shape[0]:(s + 1) * tb.shape[0]] = tb
#         self.validation_data = [imgs, tags, np.ones(imgs.shape[0]), 0.0]
#         return super(TensorBoardWrapper, self).on_epoch_end(epoch, logs)