:py:mod:`orcanet.utilities.nn_utilities`
========================================

.. py:module:: orcanet.utilities.nn_utilities

.. autoapi-nested-parse::

   Utility functions used for training a NN.

   .. !! processed by numpydoc !!


Module Contents
---------------

Classes
~~~~~~~

.. autoapisummary::

   orcanet.utilities.nn_utilities.RaiseOnNaN
   orcanet.utilities.nn_utilities.TimeModel


Functions
~~~~~~~~~

.. autoapisummary::

   orcanet.utilities.nn_utilities.load_zero_center_data
   orcanet.utilities.nn_utilities.get_xs_mean_path
   orcanet.utilities.nn_utilities.make_xs_mean
   orcanet.utilities.nn_utilities.get_array_memsize


.. py:class:: RaiseOnNaN

   Callback that terminates training when a NaN loss is encountered.

   .. !! processed by numpydoc !!

   .. py:method:: on_batch_end(batch, logs=None)

      A backwards compatibility alias for ``on_train_batch_end``.

      .. !! processed by numpydoc !!


.. py:class:: TimeModel(print_func=None)

   Print how long the model took to process batches.

   .. !! processed by numpydoc !!

   .. py:method:: start_time()

   .. py:method:: stop_time()

   .. py:method:: print_stats()

   .. py:method:: on_train_batch_begin(batch, logs=None)

      Called at the beginning of a training batch in ``fit`` methods.

      Subclasses should override for any actions to run.

      Note that if the ``steps_per_execution`` argument to ``compile`` in
      ``tf.keras.Model`` is set to ``N``, this method will only be called
      every ``N`` batches.

      Args:
          batch: Integer, index of batch within the current epoch.

          logs: Dict, contains the return value of ``model.train_step``.
          Typically, the values of the ``Model``'s metrics are returned.
          Example: ``{'loss': 0.2, 'accuracy': 0.7}``.

      .. !! processed by numpydoc !!

   .. py:method:: on_test_batch_begin(batch, logs=None)

      Called at the beginning of a batch in ``evaluate`` methods.

      Also called at the beginning of a validation batch in the ``fit``
      methods, if validation data is provided.

      Subclasses should override for any actions to run.
      Note that if the ``steps_per_execution`` argument to ``compile`` in
      ``tf.keras.Model`` is set to ``N``, this method will only be called
      every ``N`` batches.

      Args:
          batch: Integer, index of batch within the current epoch.

          logs: Dict, contains the return value of ``model.test_step``.
          Typically, the values of the ``Model``'s metrics are returned.
          Example: ``{'loss': 0.2, 'accuracy': 0.7}``.

      .. !! processed by numpydoc !!

   .. py:method:: on_predict_batch_begin(batch, logs=None)

      Called at the beginning of a batch in ``predict`` methods.

      Subclasses should override for any actions to run.

      Note that if the ``steps_per_execution`` argument to ``compile`` in
      ``tf.keras.Model`` is set to ``N``, this method will only be called
      every ``N`` batches.

      Args:
          batch: Integer, index of batch within the current epoch.

          logs: Dict, contains the return value of ``model.predict_step``,
          which typically returns a dict with a key ``'outputs'`` containing
          the model's outputs.

      .. !! processed by numpydoc !!

   .. py:method:: on_train_batch_end(batch, logs=None)

      Called at the end of a training batch in ``fit`` methods.

      Subclasses should override for any actions to run.

      Note that if the ``steps_per_execution`` argument to ``compile`` in
      ``tf.keras.Model`` is set to ``N``, this method will only be called
      every ``N`` batches.

      Args:
          batch: Integer, index of batch within the current epoch.

          logs: Dict. Aggregated metric results up until this batch.

      .. !! processed by numpydoc !!

   .. py:method:: on_test_batch_end(batch, logs=None)

      Called at the end of a batch in ``evaluate`` methods.

      Also called at the end of a validation batch in the ``fit`` methods,
      if validation data is provided.

      Subclasses should override for any actions to run.

      Note that if the ``steps_per_execution`` argument to ``compile`` in
      ``tf.keras.Model`` is set to ``N``, this method will only be called
      every ``N`` batches.

      Args:
          batch: Integer, index of batch within the current epoch.

          logs: Dict. Aggregated metric results up until this batch.

      .. !! processed by numpydoc !!

   .. py:method:: on_predict_batch_end(batch, logs=None)

      Called at the end of a batch in ``predict`` methods.

      Subclasses should override for any actions to run.

      Note that if the ``steps_per_execution`` argument to ``compile`` in
      ``tf.keras.Model`` is set to ``N``, this method will only be called
      every ``N`` batches.

      Args:
          batch: Integer, index of batch within the current epoch.

          logs: Dict. Aggregated metric results up until this batch.

      .. !! processed by numpydoc !!

   .. py:method:: on_epoch_end(epoch, logs=None)

      Called at the end of an epoch.

      Subclasses should override for any actions to run. This function
      should only be called during TRAIN mode.

      Args:
          epoch: Integer, index of epoch.

          logs: Dict, metric results for this training epoch, and for the
          validation epoch if validation is performed. Validation result keys
          are prefixed with ``val_``. For training epoch, the values of the
          ``Model``'s metrics are returned.
          Example: ``{'loss': 0.2, 'accuracy': 0.7}``.

      .. !! processed by numpydoc !!


.. py:function:: load_zero_center_data(orga, logging=False)

   Get the xs_mean array(s) that can be used for zero-centering.

   The arrays are either loaded from a previously saved .npz file or
   calculated on the fly as the mean value per bin over the given training
   files. The name of the saved mean image is derived from the name of the
   list file that was given to the cfg.

   :Parameters:

       **orga** : orcanet.core.Organizer
           Contains all the configurable options in the OrcaNet scripts.

       **logging** : bool
           If True, log the execution of this function to the full summary
           in the output folder.

   :Returns:

       **xs_mean** : dict
           Dict of ndarray(s) containing the mean image of the x dataset,
           one array per list input. Can be used for zero-centering later on.
           Example format: ``{"input_A": ndarray, "input_B": ndarray}``.

   .. !! processed by numpydoc !!


.. py:function:: get_xs_mean_path(zero_center_folder, train_filepaths)

   Search for precalculated xs_mean arrays in the zero_center_folder.
   The function opens every .npz file in the zero center folder and checks
   whether the files used to generate that xs_mean (stored as the subarray
   ``'zero_center_used_ip_files'``) are the same as the given
   train_filepaths.

   :Parameters:

       **zero_center_folder** : str
           Full path to the folder where the zero-centering arrays are
           stored.

       **train_filepaths** : list
           The filepaths of all train_files.

   :Returns:

       **xs_mean_path** : str or None
           The zero center filepath for the given train_filepaths if it
           exists among the zero center files, otherwise None.

   .. !! processed by numpydoc !!


.. py:function:: make_xs_mean(filepaths, key_samples, total_memory=4000000000.0)

   Calculate the zero center image of a dataset.

   The calculation still works if xs (the samples) is larger than the
   available memory, and also if the file is compressed.

   :Parameters:

       **filepaths** : list
           Filepaths of the data files with the samples for which the mean
           image will be calculated.

       **key_samples** : str
           Name of the dataset in the h5 input files that contains the
           samples fed to the network.

       **total_memory** : float
           Memory budget in bytes (default 4e9). The mean calculation is
           divided into steps based on this value. To be safe, use at most
           half of what is available per GPU (16 GB).

   :Returns:

       **xs_mean** : ndarray
           The zero center image.

   .. !! processed by numpydoc !!


.. py:function:: get_array_memsize(array)

   Calculate the approximate memory size of an array.

   :param ndarray array: The array.
   :return: The approximate size of the array in bytes.
   :rtype: float

   .. !! processed by numpydoc !!
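
The NaN check performed by ``RaiseOnNaN`` can be illustrated with the framework-free sketch below. It is not OrcaNet's implementation: the class ``RaiseOnNaNSketch`` is hypothetical, and it assumes that the loss is read from the ``'loss'`` entry of ``logs`` and that a ``ValueError`` is raised. The real class is a Keras callback, so ``fit`` calls ``on_batch_end`` rather than the user.

.. code-block:: python

   import math


   class RaiseOnNaNSketch:
       """Hypothetical, framework-free stand-in for RaiseOnNaN."""

       def on_batch_end(self, batch, logs=None):
           # Assumption: the batch loss is found under the "loss" key of logs.
           loss = (logs or {}).get("loss")
           if loss is not None and math.isnan(loss):
               # Fail loudly instead of continuing with a broken model.
               raise ValueError(f"NaN loss encountered at batch {batch}")


   callback = RaiseOnNaNSketch()
   callback.on_batch_end(0, {"loss": 0.25})  # finite loss: nothing happens
   try:
       callback.on_batch_end(1, {"loss": float("nan")})
   except ValueError as err:
       print(err)  # NaN loss encountered at batch 1

Keras also ships ``tf.keras.callbacks.TerminateOnNaN``, which stops training when the loss becomes NaN. Raising an exception, as sketched here, additionally makes the failure visible to the calling script.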
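
The timing logic behind ``TimeModel`` can be sketched as follows. This is an assumption-based illustration rather than OrcaNet's code: ``TimeModelSketch`` is hypothetical, it does not subclass ``tf.keras.callbacks.Callback`` so that it runs without TensorFlow, and reporting the mean batch time once per epoch is only a guess at what ``print_stats`` prints.

.. code-block:: python

   import time


   class TimeModelSketch:
       """Hypothetical stand-in for TimeModel that records batch durations."""

       def __init__(self, print_func=None):
           # Fall back to the built-in print if no print function is given.
           self.print_func = print_func if print_func is not None else print
           self.batch_times = []
           self._t0 = None

       def start_time(self):
           self._t0 = time.perf_counter()

       def stop_time(self):
           # Duration since the matching start_time() call, in seconds.
           self.batch_times.append(time.perf_counter() - self._t0)

       def print_stats(self):
           if not self.batch_times:
               self.print_func("No batches timed.")
               return
           n = len(self.batch_times)
           mean_ms = 1000 * sum(self.batch_times) / n
           self.print_func(f"Mean batch time: {mean_ms:.2f} ms over {n} batch(es)")

       def on_train_batch_begin(self, batch, logs=None):
           self.start_time()

       def on_train_batch_end(self, batch, logs=None):
           self.stop_time()

       # Test (validation) and predict batches are timed the same way.
       on_test_batch_begin = on_predict_batch_begin = on_train_batch_begin
       on_test_batch_end = on_predict_batch_end = on_train_batch_end

       def on_epoch_end(self, epoch, logs=None):
           # Assumption: report once per epoch, then start a fresh record.
           self.print_stats()
           self.batch_times = []


   messages = []
   timer = TimeModelSketch(print_func=messages.append)
   for batch in range(3):
       timer.on_train_batch_begin(batch)
       time.sleep(0.01)  # stand-in for the actual training step
       timer.on_train_batch_end(batch)
   timer.on_epoch_end(0)
   print(messages[0])

In this sketch, passing ``print_func`` (here ``messages.append``) lets the caller redirect the statistics, for example into a log instead of stdout.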
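
The memory-bounded averaging described for ``make_xs_mean`` and ``get_array_memsize`` can be sketched with NumPy as below. The helper names are hypothetical, and two details are assumptions rather than documented behavior: the memory size is estimated as the number of entries times ``dtype.itemsize``, and the per-file means are weighted by sample count so that the result equals the mean over all samples pooled together. Plain NumPy arrays stand in for h5py datasets, which support the same ``shape``, ``dtype`` and slicing used here.

.. code-block:: python

   import math

   import numpy as np


   def get_array_memsize_sketch(array):
       # Approximate size in bytes: number of entries times bytes per entry.
       # Anything with .shape and .dtype works, e.g. an h5py dataset.
       return float(math.prod(array.shape) * np.dtype(array.dtype).itemsize)


   def chunked_mean(samples, total_memory=4e9):
       """Mean over the sample axis, reading roughly total_memory bytes per step."""
       n_samples = samples.shape[0]
       n_steps = max(1, math.ceil(get_array_memsize_sketch(samples) / total_memory))
       step_size = math.ceil(n_samples / n_steps)
       total = np.zeros(samples.shape[1:], dtype=np.float64)
       for start in range(0, n_samples, step_size):
           # For an h5py dataset, this slice reads only one chunk from disk.
           chunk = samples[start:start + step_size]
           total += chunk.sum(axis=0, dtype=np.float64)
       return total / n_samples


   def combined_mean(sample_arrays, total_memory=4e9):
       # Assumption: weight each file by its number of samples.
       counts = [a.shape[0] for a in sample_arrays]
       means = [chunked_mean(a, total_memory) for a in sample_arrays]
       return sum(c * m for c, m in zip(counts, means)) / sum(counts)


   rng = np.random.default_rng(0)
   a = rng.integers(0, 10, size=(100, 4, 3), dtype=np.uint8)
   b = rng.integers(0, 10, size=(50, 4, 3), dtype=np.uint8)
   # A tiny budget forces many steps; the result still matches the pooled mean.
   xs_mean = combined_mean([a, b], total_memory=64)
   print(np.allclose(xs_mean, np.concatenate([a, b]).mean(axis=0)))  # True

Accumulating the sums in ``float64`` limits rounding errors when many samples are added up, independent of the dtype in which the samples are stored.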
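
The search done by ``get_xs_mean_path`` can be sketched as follows. ``find_xs_mean_path`` is a hypothetical name, and the sketch assumes that ``zero_center_used_ip_files`` is stored as a string array and that the order of the files does not matter; neither detail is stated above.

.. code-block:: python

   import os
   import tempfile

   import numpy as np


   def find_xs_mean_path(zero_center_folder, train_filepaths):
       """Hypothetical re-implementation of the search in get_xs_mean_path."""
       # Assumption: file order is irrelevant, so compare sorted lists.
       wanted = sorted(str(p) for p in train_filepaths)
       for name in sorted(os.listdir(zero_center_folder)):
           if not name.endswith(".npz"):
               continue
           path = os.path.join(zero_center_folder, name)
           with np.load(path) as data:
               if "zero_center_used_ip_files" not in data.files:
                   continue
               used = sorted(str(p) for p in data["zero_center_used_ip_files"])
           if used == wanted:
               return path
       return None  # no precalculated xs_mean for these train files


   with tempfile.TemporaryDirectory() as folder:
       files = ["train_1.h5", "train_2.h5"]
       np.savez(os.path.join(folder, "xs_mean_a.npz"),
                xs_mean=np.zeros(3), zero_center_used_ip_files=np.array(files))
       print(find_xs_mean_path(folder, files[::-1]))  # path ending in xs_mean_a.npz
       print(find_xs_mean_path(folder, ["other.h5"]))  # None

Comparing sorted lists means that the same training files passed in a different order still match an existing .npz file; if the order mattered, the sorting would have to be dropped.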