Coverage for orcanet/core.py: 80%
390 statements
« prev ^ index » next — coverage.py v7.2.7, created at 2024-03-28 14:22 +0000
1#!/usr/bin/env python
2# -*- coding: utf-8 -*-
3"""
4Core scripts for the OrcaNet package.
5"""
7import os
8import toml
9import warnings
10import time
11from datetime import timedelta
12import tensorflow as tf
14import orcanet.backend as backend
15from orcanet.utilities.visualization import update_summary_plot
16from orcanet.in_out import IOHandler
17from orcanet.history import HistoryHandler
18from orcanet.utilities.nn_utilities import load_zero_center_data
19import orcanet.lib as lib
20import orcanet.logging as logging
21import orcanet.misc
22import medgeconv
class Organizer:
    """
    Core class for working with networks in OrcaNet.

    Drives the whole train/validate/predict/inference cycle: it resolves
    which epoch/file to work on from the output folder's saved state,
    loads or reuses a keras model, and delegates the actual number
    crunching to ``orcanet.backend``.

    Attributes
    ----------
    cfg : orcanet.core.Configuration
        Contains all configurable options.
    io : orcanet.in_out.IOHandler
        Utility functions for accessing the info in cfg.
    history : orcanet.in_out.HistoryHandler
        For reading and plotting data from the log files created
        during training.

    """

    def __init__(
        self,
        output_folder,
        list_file=None,
        config_file=None,
        tf_log_level=None,
        discover_tomls=True,
    ):
        """
        Set the attributes of the Configuration object.

        Instead of using a config_file, the attributes of orga.cfg can
        also be changed directly, e.g. by calling orga.cfg.batchsize.

        Parameters
        ----------
        output_folder : str
            Name of the folder of this model in which everything will be saved,
            e.g., the summary.txt log file is located in here.
            Will be used to load saved files or to save new ones.
        list_file : str, optional
            Path to a toml list file with paths to all the h5 files that should
            be used for training and validation.
            Will be used to extract samples and labels.
            Default: Look for a file called 'list.toml' in the given output_folder.
        config_file : str, optional
            Path to a toml config file with settings that are used instead of
            the default ones.
            Default: Look for a file called 'config.toml' in the given output_folder.
        tf_log_level : int/str
            Sets the TensorFlow CPP_MIN_LOG_LEVEL environment variable.
            0 = all messages are logged (default behavior).
            1 = INFO messages are not printed.
            2 = INFO and WARNING messages are not printed.
            3 = INFO, WARNING, and ERROR messages are not printed.
        discover_tomls : bool
            If False, do not try to look for toml files in the given
            output_folder if list_file or config_file is None [Default: True].

        """
        if tf_log_level is not None:
            # must be set via env variable before TF spawns its C++ backend
            os.environ["TF_CPP_MIN_LOG_LEVEL"] = str(tf_log_level)

        if discover_tomls and list_file is None:
            list_file = orcanet.misc.find_file(output_folder, "list.toml")
        if discover_tomls and config_file is None:
            config_file = orcanet.misc.find_file(output_folder, "config.toml")

        self.cfg = Configuration(output_folder, list_file, config_file)
        self.io = IOHandler(self.cfg)
        self.history = HistoryHandler(output_folder)

        # xs_mean: zero-center images per input, lazily filled by get_xs_mean()
        self.xs_mean = None
        # fallback label modifier, built from the model in _setup_auto_lmod()
        self._auto_label_modifier = None
        # model kept in memory during train_and_validate so that train()/
        # validate()/predict() skip reloading it from disk
        self._stored_model = None
        # cached tf.distribute strategy, see get_strategy()
        self._strategy = None

    def train_and_validate(self, model=None, epochs=None, to_epoch=None):
        """
        Train a model and validate according to schedule.

        The various settings of this process can be controlled with the
        attributes of orca.cfg.
        The model will be trained on the given data, saved and validated.
        Logfiles of the training are saved in the output folder.
        Plots showing the training and validation history, as well as
        the weights and activations of the network are generated in
        the plots subfolder after every validation.
        The training can be resumed by executing this function again.

        Parameters
        ----------
        model : ks.models.Model or str, optional
            Compiled keras model to use for training. Required for the first
            epoch (the start of training).
            Can also be the path to a saved keras model, which will be loaded.
            If model is None, the most recent saved model will be
            loaded automatically to continue the training.
        epochs : int, optional
            How many epochs should be trained by running this function.
            None for infinite. This includes the current epoch in case it
            is not finished yet, i.e. 1 means complete the epoch if there
            are files left, otherwise do the next epoch.
        to_epoch : int, optional
            Train up to and including this epoch. Can not be used together with
            epochs.

        Returns
        -------
        model : ks.models.Model
            The trained keras model.

        """
        latest_epoch = self.io.get_latest_epoch()

        model = self._get_model(model, logging=False)
        # keep the model in memory so train()/validate() below reuse it
        self._stored_model = model

        # check if the validation is missing for the latest fileno
        if latest_epoch is not None:
            state = self.history.get_state()[-1]
            if state["is_validated"] is False and self.val_is_due(latest_epoch):
                self.validate()

        next_epoch = self.io.get_next_epoch(latest_epoch)
        n_train_files = self.io.get_no_of_files("train")

        if to_epoch is None:
            epochs_left = epochs
        else:
            if epochs is not None:
                raise ValueError("Can not give both 'epochs' and 'to_epoch'")
            if latest_epoch is None:
                epochs_left = to_epoch
            else:
                # +1 because the epoch about to be trained also counts
                epochs_left = max(
                    0, to_epoch - self.io.get_next_epoch(latest_epoch)[0] + 1
                )

        trained_epochs = 0
        while epochs_left is None or trained_epochs < epochs_left:
            # Train on remaining files
            for file_no in range(next_epoch[1], n_train_files + 1):
                curr_epoch = (next_epoch[0], file_no)
                self.train(model)
                if self.val_is_due(curr_epoch):
                    self.validate()

            next_epoch = (next_epoch[0] + 1, 1)
            trained_epochs += 1

        self._stored_model = None
        return model

    def train(self, model=None):
        """
        Trains a model on the next file.

        The progress of the training is also logged and plotted.

        Parameters
        ----------
        model : ks.models.Model or str, optional
            Compiled keras model to use for training. Required for the first
            epoch (the start of training).
            Can also be the path to a saved keras model, which will be loaded.
            If model is None, the most recent saved model will be
            loaded automatically to continue the training.

        Returns
        -------
        history : dict
            The history of the training on this file. A record of training
            loss values and metrics values.

        """
        # Create folder structure
        self.io.get_subfolder(create=True)
        latest_epoch = self.io.get_latest_epoch()

        model = self._get_model(model, logging=True)

        self._set_up(model, logging=True)

        # epoch about to be trained
        next_epoch = self.io.get_next_epoch(latest_epoch)
        next_epoch_float = self.io.get_epoch_float(*next_epoch)

        if latest_epoch is None:
            # first call ever for this output folder: sanity-check the model
            # inputs against the files and write the training start log
            self.io.check_connections(model)
            logging.log_start_training(self)

        model_path = self.io.get_model_path(*next_epoch)
        model_path_local = self.io.get_model_path(*next_epoch, local=True)
        if os.path.isfile(model_path):
            raise FileExistsError(
                "Can not train model in epoch {} file {}, this model has "
                "already been saved!".format(*next_epoch)
            )

        smry_logger = logging.SummaryLogger(self, model)

        if self.cfg.learning_rate is not None:
            tf.keras.backend.set_value(
                model.optimizer.lr, self.io.get_learning_rate(next_epoch)
            )

        files_dict = self.io.get_file("train", next_epoch[1])

        line = "Training in epoch {} on file {}/{}".format(
            next_epoch[0], next_epoch[1], self.io.get_no_of_files("train")
        )
        self.io.print_log(line)
        self.io.print_log("-" * len(line))
        self.io.print_log(
            "Learning rate is at {}".format(
                tf.keras.backend.get_value(model.optimizer.lr)
            )
        )
        self.io.print_log("Inputs and files:")
        for input_name, input_file in files_dict.items():
            self.io.print_log(
                "   {}: \t{}".format(input_name, os.path.basename(input_file))
            )

        start_time = time.time()
        history = backend.train_model(self, model, next_epoch, batch_logger=True)
        elapsed_s = int(time.time() - start_time)

        model.save(model_path)
        smry_logger.write_line(
            next_epoch_float,
            tf.keras.backend.get_value(model.optimizer.lr),
            history_train=history,
        )

        self.io.print_log("Training results:")
        for metric_name, loss in history.items():
            self.io.print_log(f"   {metric_name}: \t{loss}")
        self.io.print_log(f"Elapsed time: {timedelta(seconds=elapsed_s)}")
        self.io.print_log(f"Saved model to: {model_path_local}\n")

        update_summary_plot(self)
        if self.cfg.cleanup_models:
            self.cleanup_models()

        return history

    def validate(self):
        """
        Validate the most recent saved model on all validation files.

        Will also log the progress, as well as update the summary plot and
        plot weights and activations of the model.

        Returns
        -------
        history : dict
            The history of the validation on all files. A record of validation
            loss values and metrics values.

        """
        latest_epoch = self.io.get_latest_epoch()
        if latest_epoch is None:
            raise ValueError("Can not validate: No saved model found")
        if self.history.get_state()[-1]["is_validated"] is True:
            raise ValueError(
                "Can not validate in epoch {} file {}: "
                "Has already been validated".format(*latest_epoch)
            )

        if self._stored_model is None:
            model = self.load_saved_model(*latest_epoch)
        else:
            model = self._stored_model

        self._set_up(model, logging=True)

        epoch_float = self.io.get_epoch_float(*latest_epoch)
        smry_logger = logging.SummaryLogger(self, model)

        logging.log_start_validation(self)

        start_time = time.time()
        history = backend.validate_model(self, model)
        elapsed_s = int(time.time() - start_time)

        self.io.print_log("Validation results:")
        for metric_name, loss in history.items():
            self.io.print_log(f"   {metric_name}: \t{loss}")
        self.io.print_log(f"Elapsed time: {timedelta(seconds=elapsed_s)}\n")
        # learning rate column is not applicable for a validation line
        smry_logger.write_line(epoch_float, "n/a", history_val=history)

        update_summary_plot(self)

        if self.cfg.cleanup_models:
            self.cleanup_models()

        return history

    def predict(self, epoch=None, fileno=None, samples=None):
        """
        Make a prediction if it does not exist yet, and return its filepath.

        Load the model with the lowest validation loss, let it predict on
        all samples of the validation set
        in the toml list, and save this prediction together with all the
        y_values as h5 file(s) in the predictions subfolder.

        Parameters
        ----------
        epoch : int, optional
            Epoch of a model to load. Default: lowest val loss.
        fileno : int, optional
            File number of a model to load. Default: lowest val loss.
        samples : int, optional
            Don't use the full validation files, but just the given number
            of samples.

        Returns
        -------
        pred_filename : List
            List to the paths of all the prediction files.

        """
        epoch, fileno = self._get_auto_epoch(epoch, fileno)
        if self._check_if_pred_already_done(epoch, fileno):
            print("Prediction has already been done.")
            pred_filepaths = self.io.get_pred_files_list(epoch, fileno)

        else:
            if self._stored_model is None:
                model = self.load_saved_model(epoch, fileno, logging=False)
            else:
                model = self._stored_model
            self._set_up(model)

            start_time = time.time()
            backend.make_model_prediction(self, model, epoch, fileno, samples=samples)
            elapsed_s = int(time.time() - start_time)
            print("Finished predicting on all validation files.")
            print("Elapsed time: {}\n".format(timedelta(seconds=elapsed_s)))

            pred_filepaths = self.io.get_pred_files_list(epoch, fileno)

        return pred_filepaths

    def inference(self, epoch=None, fileno=None, as_generator=False):
        """
        Make an inference and return the filepaths.

        Load the model with the lowest validation loss, let
        it predict on all samples of all inference files
        in the toml list, and save these predictions as h5 files in the
        predictions subfolder. y values will only be added if they are in
        the input file, so this can be used on un-labeled data as well.

        Parameters
        ----------
        epoch : int, optional
            Epoch of a model to load. Default: lowest val loss.
        fileno : int, optional
            File number of a model to load. Default: lowest val loss.
        as_generator : bool
            If true, return a generator, which yields the output filename
            after the inference of each file.
            If false (default), do all files back to back.

        Returns
        -------
        filenames : list
            List to the paths of all created output files.

        """
        gen = self._inference(epoch=epoch, fileno=fileno)
        if as_generator:
            return gen
        else:
            # exhaust the generator so all files are processed back to back
            return [filename for filename in gen]

    def _inference(self, epoch=None, fileno=None):
        """Get generator of doing inference, file by file."""
        epoch, fileno = self._get_auto_epoch(epoch, fileno)
        if self._stored_model is None:
            model = self.load_saved_model(epoch, fileno, logging=False)
        else:
            model = self._stored_model
        self._set_up(model)

        filenames = []
        for files_dict in self.io.yield_files("inference"):
            # output filename is based on name of file in first input
            first_filename = os.path.basename(list(files_dict.values())[0])
            output_filename = "model_epoch_{}_file_{}_on_{}".format(
                epoch, fileno, first_filename
            )

            output_path = os.path.join(
                self.io.get_subfolder("inference"), output_filename
            )
            filenames.append(output_path)
            # existing outputs are kept, making reruns resumable
            if os.path.exists(output_path):
                print("File {} exists already, skipping...".format(output_filename))
                continue

            print(f"Working on file {first_filename}")
            start_time = time.time()
            backend.h5_inference(
                self, model, files_dict, output_path, use_def_label=False
            )
            elapsed_s = int(time.time() - start_time)
            print(f"Finished on file {first_filename} in {elapsed_s/60} min")
            yield output_path

    def inference_on_file(
        self, input_file, output_file=None, saved_model=None, epoch=None, fileno=None
    ):
        """
        Save the model prediction for each sample of the given input file.

        Useful for sharing a saved model, since the usual training folder
        structure is not necessarily required.

        Parameters
        ----------
        input_file : str or dict
            Path to a DL file on which the inference should be done on.
            Can also be a dict mapping input names to files.
        output_file : str, optional
            Save output to an h5 file with this name. Default: auto-generate
            name and save in same directory as the input file.
        saved_model : str, optional
            Optional path to a saved model, which will be used instead
            of loading the one with the given epoch/fileno.
        epoch : int, optional
            Epoch of a model to load from the directory. Only relevant if
            saved_model is None. Default: lowest val loss.
        fileno : int, optional
            File number of a model to load from the directory. Only relevant
            if saved_model is None. Default: lowest val loss.

        Returns
        -------
        str
            Name of the output file.

        """
        if saved_model is None:
            epoch, fileno = self._get_auto_epoch(epoch, fileno)
            model = self.load_saved_model(epoch, fileno, logging=False)
        else:
            model = self._load_model(saved_model)
        self._set_up(model)

        if isinstance(input_file, str):
            # single path: assign it to the model's first (only) input
            input_file = {model.input_names[0]: input_file}
        if output_file is None:
            out_path, first_filename = os.path.split(list(input_file.values())[0])
            output_file = os.path.join(out_path, "dl_pred_{}".format(first_filename))
        start_time = time.time()
        backend.h5_inference(
            orga=self,
            model=model,
            files_dict=input_file,
            output_path=output_file,
        )
        elapsed_s = int(time.time() - start_time)
        print(f"Finished inference in {elapsed_s / 60} min")
        return output_file

    def cleanup_models(self):
        """
        Delete all models except for the the most recent one (to continue
        training), and the ones with the highest and lowest loss/metrics.

        """
        all_epochs = self.io.get_all_epochs()
        epochs_to_keep = {
            self.io.get_latest_epoch(),
        }
        try:
            # keep best AND worst model per metric
            for metric in self.history.get_metrics():
                epochs_to_keep.add(
                    self.history.get_best_epoch_fileno(
                        metric=f"val_{metric}", mini=True
                    )
                )
                epochs_to_keep.add(
                    self.history.get_best_epoch_fileno(
                        metric=f"val_{metric}", mini=False
                    )
                )
        except ValueError:
            # no best epoch exists
            pass

        # safety check: never delete anything if the bookkeeping is inconsistent
        for epoch in epochs_to_keep:
            if epoch not in all_epochs:
                warnings.warn(
                    f"ERROR: keeping_epoch {epoch} not in available epochs {all_epochs}, "
                    f"skipping clean-up of models!"
                )
                return

        print("\nClean-up saved models:")
        for epoch in all_epochs:
            model_path = self.io.get_model_path(epoch[0], epoch[1])
            model_name = os.path.basename(model_path)
            if epoch in epochs_to_keep:
                print("Keeping model {}".format(model_name))
            else:
                print("Deleting model {}".format(model_name))
                os.remove(model_path)

    def _check_if_pred_already_done(self, epoch, fileno):
        """
        Checks if the prediction has already been done before.
        (-> predicted on all validation files)

        Returns
        -------
        pred_done : bool
            Boolean flag to specify if the prediction has
            already been fully done or not.

        """
        latest_pred_file_no = self.io.get_latest_prediction_file_no(epoch, fileno)
        total_no_of_val_files = self.io.get_no_of_files("val")

        if latest_pred_file_no is None:
            # no prediction file exists at all yet
            pred_done = False
        elif latest_pred_file_no == total_no_of_val_files:
            # predictions exist for every validation file
            return True
        else:
            # prediction was started but not finished for all files
            pred_done = False

        return pred_done

    def _get_auto_epoch(self, epoch, fileno):
        """Automatically retrieve best epoch/fileno if they are none."""
        if fileno is None and epoch is None:
            epoch, fileno = self.history.get_best_epoch_fileno()
            print("Automatically set epoch to epoch {} file {}.".format(epoch, fileno))
        elif fileno is None or epoch is None:
            raise ValueError("Either both or none of epoch and fileno must be None")
        return epoch, fileno

    def get_xs_mean(self, logging=False):
        """
        Set and return the zero center image for each list input.

        Requires the cfg.zero_center_folder to be set. If no existing
        image for the given input files is found in the folder, it will
        be calculated and saved by averaging over all samples in the
        train dataset.

        Parameters
        ----------
        logging : bool
            If true, the execution of this function will be logged into the
            full summary in the output folder if called for the first time.

        Returns
        -------
        dict
            Dict of numpy arrays that contains the mean_image of the x dataset
            (1 array per list input).
            Example format:
            { "input_A" : ndarray, "input_B" : ndarray }

        """
        if self.xs_mean is None:
            if self.cfg.zero_center_folder is None:
                raise ValueError(
                    "Can not calculate zero center: " "No zero center folder given"
                )
            self.xs_mean = load_zero_center_data(self, logging=logging)
        return self.xs_mean

    def load_saved_model(self, epoch, fileno, logging=False):
        """
        Load a saved model.

        Parameters
        ----------
        epoch : int
            Epoch of the saved model. If both this and fileno are -1,
            load the most recent model.
        fileno : int
            Fileno of the saved model.
        logging : bool
            If True, will log this function call into the log.txt file.

        Returns
        -------
        model : keras model

        """
        path_of_model = self.io.get_model_path(epoch, fileno)
        # local (relative) path is only used for nicer log output
        path_loc = self.io.get_model_path(epoch, fileno, local=True)
        self.io.print_log("Loading saved model: " + path_loc, logging=logging)
        return self._load_model(path_of_model)

    def _get_model(self, model, logging=False):
        """Load most recent saved model or use user model."""
        latest_epoch = self.io.get_latest_epoch()

        if latest_epoch is None:
            # new training, log info about model
            if model is None:
                raise ValueError(
                    "You need to provide a compiled keras model "
                    "for the start of the training! (You gave None)"
                )

            elif isinstance(model, str):
                # path to a saved model
                self.io.print_log("Loading model from " + model, logging=logging)
                model = self._load_model(model)

            if logging:
                self._save_as_json(model)
                model.summary(print_fn=self.io.print_log)

                try:
                    plots_folder = self.io.get_subfolder("plots", create=True)
                    tf.keras.utils.plot_model(
                        model, plots_folder + "/model_plot.png", show_shapes=True
                    )
                except (ImportError, AttributeError) as e:
                    # TODO remove AttributeError once https://github.com/tensorflow/tensorflow/issues/38988 is fixed
                    warnings.warn("Can not plot model: " + str(e))

        else:
            # resuming training, load model if it is not given
            if model is None:
                model = self.load_saved_model(*latest_epoch, logging=logging)

            elif isinstance(model, str):
                # path to a saved model
                self.io.print_log("Loading model from " + model, logging=logging)
                model = self._load_model(model)

        return model

    def _load_model(self, filepath):
        """Load from path, with custom objects and parallized."""
        # loading inside the strategy scope places variables on all devices
        with self.get_strategy().scope():
            model = tf.keras.models.load_model(
                filepath, custom_objects=self.cfg.get_custom_objects()
            )
        return model

    def _save_as_json(self, model):
        """Save the architecture of a model as json to fixed path."""
        json_filename = "model_arch.json"
        json_string = model.to_json(indent=1)
        model_folder = self.io.get_subfolder("saved_models", create=True)
        with open(os.path.join(model_folder, json_filename), "w") as f:
            f.write(json_string)

    def _set_up(self, model, logging=False):
        """Necessary setup for training, validating and predicting."""
        if self.cfg.label_modifier is None:
            self._setup_auto_lmod(model)

        if self.cfg.zero_center_folder is not None:
            self.get_xs_mean(logging)

    def _setup_auto_lmod(self, model):
        """Set up the auto label modifier for the given model."""
        self._auto_label_modifier = lib.label_modifiers.ColumnLabels(model)

    def val_is_due(self, epoch=None):
        """
        True if validation is due on given epoch according to schedule.
        Does not check if it has been done already.

        """
        if epoch is None:
            epoch = self.io.get_latest_epoch()
        n_train_files = self.io.get_no_of_files("train")
        # always validate on the last file of an epoch, plus every
        # validate_interval files if that option is set
        val_sched = (epoch[1] == n_train_files) or (
            self.cfg.validate_interval is not None
            and epoch[1] % self.cfg.validate_interval == 0
        )
        return val_sched

    def get_strategy(self):
        """Get the strategy for distributed training."""
        if self._strategy is None:
            if self.cfg.multi_gpu and len(tf.config.list_physical_devices("GPU")) > 1:
                self._strategy = tf.distribute.MirroredStrategy()
                print(f"Number of GPUs: {self._strategy.num_replicas_in_sync}")
            else:
                # default (no-op) strategy for single-device setups
                self._strategy = tf.distribute.get_strategy()
        return self._strategy
class Configuration(object):
    """
    Contains all the configurable options in the OrcaNet scripts.

    All of these public attributes (the ones without a
    leading underscore) can be changed either directly or with a
    .toml config file via the method update_config().

    Parameters
    ----------
    output_folder : str
        Name of the folder of this model in which everything will be saved,
        e.g., the summary.txt log file is located in here.
    list_file : str or None
        Path to a toml list file with paths to all the h5 files that should
        be used for training and validation.
    config_file : str or None
        Path to a toml config file with attributes that are used instead of
        the default ones.
    kwargs
        Overwrites the values given in the config file.

    Attributes
    ----------
    batchsize : int
        Batchsize that will be used for the training, validation and inference of
        the network.
        During training and validation, the last batch in each file will be
        skipped if it has fewer samples than the batchsize.
    callback_train : keras callback or list or None
        Callback or list of callbacks to use during training.
    class_weight : dict or None
        Optional dictionary mapping class indices (integers) to a weight
        (float) value, used for weighting the loss function (during
        training only). This can be useful to tell the model to
        "pay more attention" to samples from an under-represented class.
    cleanup_models : bool
        If true, will only keep the best (in terms of val loss) and the most
        recent from all saved models in order to save disk space.
    custom_objects : dict, optional
        Optional dictionary mapping names (strings) to custom classes or
        functions to be considered by keras during deserialization of models.
    dataset_modifier : function or None
        For orga.predict: Function that determines which datasets get created
        in the resulting h5 file. Default: save as array, i.e. every output layer
        will get one dataset each for both the label and the prediction,
        and one dataset containing the y_values from the validation files.
    fixed_batchsize : bool
        The last batch in the file might be smaller then the batchsize.
        Usually, this is no problem, but set to True to skip this batch
        [default: False].
    key_x_values : str
        The name of the datagroup in the h5 input files which contains
        the samples for the network.
    key_y_values : str
        The name of the datagroup in the h5 input files which contains
        the info for the labels.
    label_modifier : function or None
        Operation to be performed on batches of y_values read from the input
        files before they are fed into the model as labels. If None is given,
        all y_values with the same name as the output layers will be passed
        to the model as a dict, with the keys being the dtype names.
    learning_rate : float, tuple, function, str (optional)
        The learning rate for the training.
        If None is given, don't change the learning rate at all.
        If it is a float: The learning rate will be constantly this value.
        If it is a tuple of two floats: The first float gives the learning rate
        in epoch 1 file 1, and the second float gives the decrease of the
        learning rate per file (e.g. 0.1 for 10% decrease per file).
        If it is a function: Takes as an input the epoch and the
        file number (in this order), and returns the learning rate.
        Both epoch and fileno start at 1, i.e. 1, 1 is the start of the
        training.
        If it is a str: Path to a csv file inside the main folder, containing
        3 columns with the epoch, fileno, and the value the lr will be set
        to when reaching this epoch/fileno.
    max_queue_size : int
        max_queue_size option of the keras training and evaluation generator
        methods. How many batches get preloaded from the generator.
    multi_gpu : bool
        Use all available GPUs (distributed training if there is more than one).
    n_events : None or int
        For testing purposes. If not the whole .h5 file should be used for
        training, define the number of samples.
    sample_modifier : function or None
        Operation to be performed on batches of x_values read from the input
        files before they are fed into the model as samples.
    shuffle_train : bool
        If true, the order in which batches are read out from the files during
        training are randomized each time they are read out.
    train_logger_display : int
        How many batches should be averaged for one line in the training log files.
    train_logger_flush : int
        After how many lines the training log file should be flushed (updated on
        the disk). -1 for flush at the end of the file only.
    use_scratch_ssd : bool
        Declares if the input files should be copied to a local temp dir,
        i.e. the path defined in the 'TMPDIR' environment variable.
    validate_interval : int or None
        Validate the model after this many training files have been trained on
        in an epoch. There will always be a validation at the end of an epoch.
        None for only validate at the end of an epoch.
        Example: validate_interval=3 --> Validate after file 3, 6, 9, ...
    verbose_train : int
        verbose option of keras.model.fit_generator.
        0 = silent, 1 = progress bar, 2 = one line per epoch.
    verbose_val : int
        verbose option of evaluate_generator.
        0 = silent, 1 = progress bar.
    y_field_names : tuple or list or str, optional
        During train and val, read out only these fields from the y dataset.
        --> Speed up, especially if there are many fields.
    zero_center_folder : None or str
        Path to a folder in which zero centering images are stored.
        If this path is set, zero centering images for the given dataset will
        either be calculated and saved automatically at the start of the
        training, or loaded if they have been saved before.

    """

    # TODO add a clobber script that properly deletes models + logfiles
    def __init__(self, output_folder, list_file=None, config_file=None, **kwargs):
        # Default values of all public options:
        self.batchsize = 64
        self.callback_train = []
        self.class_weight = None
        self.cleanup_models = False
        self.custom_objects = {}
        self.dataset_modifier = None
        self.fixed_batchsize = False
        self.key_x_values = "x"
        self.key_y_values = "y"
        self.label_modifier = None
        self.learning_rate = None
        self.make_weight_plots = False  # Removed in v0.11.1
        self.max_queue_size = 10
        self.multi_gpu = True
        self.n_events = None
        self.sample_modifier = None
        self.shuffle_train = False
        self.train_logger_display = 100
        self.train_logger_flush = -1
        self.use_scratch_ssd = False
        self.validate_interval = None
        self.verbose_train = 1
        self.verbose_val = 0
        self.y_field_names = None
        self.zero_center_folder = None

        # snapshot of the defaults before any user changes are applied
        self._default_values = dict(self.__dict__)

        # Main folder, stored with a trailing slash.
        # endswith instead of output_folder[-1] avoids an IndexError for "".
        if output_folder.endswith("/"):
            self.output_folder = output_folder
        else:
            self.output_folder = output_folder + "/"

        # Private attributes:
        self._files_dict = {
            "train": None,
            "val": None,
            "inference": None,
        }
        self._list_file = None

        # Load the optionally given list and config files.
        if list_file is not None:
            self.import_list_file(list_file)
        if config_file is not None:
            self.update_config(config_file)

        # set given kwargs:
        for key, val in kwargs.items():
            if hasattr(self, key):
                setattr(self, key, val)
            else:
                raise AttributeError("Unknown attribute {}".format(key))

        # deprecation warning TODO remove in the future
        if self.make_weight_plots:
            warnings.warn("make_weight_plots was removed in version v0.11.1")

    def import_list_file(self, list_file):
        """
        Import the filepaths of the h5 files from a toml list file.

        Parameters
        ----------
        list_file : str
            Path to the toml list file.

        Raises
        ------
        ValueError
            If a list file has already been loaded for this configuration.

        """
        if self._list_file is not None:
            raise ValueError(
                "Can not load list file: Has already been loaded! "
                "({})".format(self._list_file)
            )

        file_content = toml.load(list_file)

        # toml section names -> keys of self._files_dict
        name_mapping = {
            "train_files": "train",
            "validation_files": "val",
            "inference_files": "inference",
        }

        for toml_name, files_dict_name in name_mapping.items():
            files = _extract_filepaths(file_content, toml_name)
            # empty dict means this category was not given in the toml file
            self._files_dict[files_dict_name] = files or None

        self._list_file = list_file

    def update_config(self, config_file):
        """
        Update the default cfg parameters with values from a toml config file.

        Parameters
        ----------
        config_file : str
            Path to a toml config file.

        Raises
        ------
        AttributeError
            If the config file contains an unknown option.

        """
        user_values = toml.load(config_file)["config"]
        for key, value in user_values.items():
            if hasattr(self, key):
                # modifiers given as strings are looked up in their registers
                if key == "sample_modifier":
                    value = orcanet.misc.from_register(
                        toml_entry=value, register=lib.sample_modifiers.smods
                    )
                elif key == "dataset_modifier":
                    value = orcanet.misc.from_register(
                        toml_entry=value, register=lib.dataset_modifiers.dmods
                    )
                elif key == "label_modifier":
                    value = orcanet.misc.from_register(
                        toml_entry=value, register=lib.label_modifiers.lmods
                    )
                setattr(self, key, value)
            else:
                raise AttributeError(f"Unknown attribute {key} in config file")

    def get_list_file(self):
        """
        Returns the path to the list file that was used to set the training
        and validation files. None if no list file has been used.

        """
        return self._list_file

    def get_files(self, which):
        """
        Get the training or validation file paths for each list input set.

        Parameters
        ----------
        which : str
            Either "train", "val" or "inference".

        Returns
        -------
        dict
            A dict containing the paths to the training or validation files on
            which the model will be trained on. Example for the format for
            two input sets with two files each:
                    {
             "input_A" : ('path/to/set_A_file_1.h5', 'path/to/set_A_file_2.h5'),
             "input_B" : ('path/to/set_B_file_1.h5', 'path/to/set_B_file_2.h5'),
                    }

        """
        if which not in self._files_dict.keys():
            raise NameError("Unknown fileset name ", which)
        if self._files_dict[which] is None:
            raise AttributeError("No {} files have been specified!".format(which))
        return self._files_dict[which]

    def get_custom_objects(self):
        """Get user custom objects + orcanet internal ones."""
        orcanet_co = medgeconv.custom_objects
        orcanet_loss_functions = lib.losses.loss_functions
        # user-supplied custom_objects win on name clashes (merged last)
        return {**orcanet_co, **orcanet_loss_functions, **self.custom_objects}
1003def _get_h5_files(folder):
1004 h5files = []
1005 for f in os.listdir(folder):
1006 if f.endswith(".h5"):
1007 h5files.append(os.path.join(folder, f))
1008 h5files.sort()
1009 if not h5files:
1010 warnings.warn(f"No .h5 files in dir {folder}!")
1011 return h5files
1014def _extract_filepaths(file_content, which):
1015 """
1016 Get train, val or inf filepaths of all inputs from a toml readout.
1017 Makes sure that all input have the same number of files.
1019 """
1020 # alternative names to write in the toml file
1021 aliases = {
1022 "train_files": ("training_files", "train", "training"),
1023 "validation_files": ("val_files", "val", "validation"),
1024 "inference_files": ("inf_files", "inf", "inference"),
1025 }
1026 assert which in aliases.keys(), f"{which} not in {list(aliases.keys())}"
1028 def get_alias(ident):
1029 for k, v in aliases.items():
1030 if ident == k or ident in v:
1031 return k
1032 else:
1033 raise NameError(
1034 f"Unknown argument '{ident}' in toml file: "
1035 f"Must be either of {list(aliases.keys())}"
1036 )
1038 files = {}
1039 n_files = []
1040 for input_name, input_files in file_content.items():
1041 for filetype, filetyp_files in input_files.items():
1042 if get_alias(filetype) != which:
1043 continue
1044 # if a dir is given as a filepath, use all h5 files in that dir instead
1045 expanded_files = []
1046 for path in filetyp_files:
1047 if os.path.isdir(path):
1048 expanded_files.extend(_get_h5_files(path))
1049 else:
1050 expanded_files.append(path)
1051 files[input_name] = tuple(expanded_files)
1052 # store number of files for this output
1053 n_files.append(len(expanded_files))
1055 if n_files and n_files.count(n_files[0]) != len(n_files):
1056 raise ValueError("Input with different number of {} in toml list".format(which))
1058 return files