orcanet.utilities.nn_utilities

Utility functions used for training a NN.

Module Contents

Classes

RaiseOnNaN

Callback that terminates training when a NaN loss is encountered.

TimeModel

Print how long the model took for processing batches.

Functions

load_zero_center_data(orga[, logging])

Gets the xs_mean array(s) that can be used for zero-centering.

get_xs_mean_path(zero_center_folder, train_filepaths)

Search for precalculated xs_mean arrays in the zero_center_folder.

make_xs_mean(filepaths, key_samples[, total_memory])

Calculates the zero center image of a dataset.

get_array_memsize(array)

Calculates the approximate memory size of an array.

class orcanet.utilities.nn_utilities.RaiseOnNaN

Callback that terminates training when a NaN loss is encountered.

on_batch_end(batch, logs=None)

A backwards compatibility alias for on_train_batch_end.
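A minimal sketch of such a callback, assuming it subclasses `tf.keras.callbacks.Callback` and inspects the `loss` entry of `logs` after each batch. The class name `NaNGuard` is hypothetical, and the fallback to a plain base class exists only so the sketch runs without TensorFlow installed.

```python
import math

try:
    from tensorflow import keras
    _CallbackBase = keras.callbacks.Callback
except ImportError:  # assumption: plain base so the sketch runs without TensorFlow
    _CallbackBase = object


class NaNGuard(_CallbackBase):
    """Hypothetical stand-in for RaiseOnNaN: abort training on a NaN loss."""

    def on_batch_end(self, batch, logs=None):
        # In tf.keras, on_train_batch_end forwards to on_batch_end,
        # so this check runs after every training batch.
        loss = (logs or {}).get("loss")
        if loss is not None and math.isnan(float(loss)):
            # The exception propagates out of fit() and ends training.
            raise ValueError(f"NaN loss at batch {batch}, stopping training.")
```

Raising an exception, rather than setting `model.stop_training`, ends training without finishing the current epoch.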

class orcanet.utilities.nn_utilities.TimeModel(print_func=None)

Print how long the model took for processing batches.
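The batch timing can be sketched as below. This is an illustration, not OrcaNet's implementation: it assumes that `start_time()` and `stop_time()` bracket each training batch, that durations are collected in a list, that `print_func` defaults to `print`, and that stats are printed once per epoch. The class name `BatchTimer` and its attributes are hypothetical.

```python
import time

try:
    from tensorflow import keras
    _CallbackBase = keras.callbacks.Callback
except ImportError:  # assumption: plain base so the sketch runs without TensorFlow
    _CallbackBase = object


class BatchTimer(_CallbackBase):
    """Hypothetical sketch of a TimeModel-like callback that times training batches."""

    def __init__(self, print_func=None):
        super().__init__()
        self.print_func = print if print_func is None else print_func
        self.durations = []  # seconds spent per training batch
        self._t0 = None

    def start_time(self):
        self._t0 = time.perf_counter()

    def stop_time(self):
        # Only record a duration if a matching start_time() call happened.
        if self._t0 is not None:
            self.durations.append(time.perf_counter() - self._t0)
            self._t0 = None

    def print_stats(self):
        if not self.durations:
            return
        mean = sum(self.durations) / len(self.durations)
        self.print_func(
            f"Batches timed: {len(self.durations)}, mean time per batch: {mean:.4f} s"
        )

    def on_train_batch_begin(self, batch, logs=None):
        self.start_time()

    def on_train_batch_end(self, batch, logs=None):
        self.stop_time()

    def on_epoch_end(self, epoch, logs=None):
        # Report once per epoch, then start collecting fresh timings.
        self.print_stats()
        self.durations.clear()
```

Passing a custom `print_func` (for example a logger method) redirects the output without changing the timing logic.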

start_time()
stop_time()
print_stats()
on_train_batch_begin(batch, logs=None)

Called at the beginning of a training batch in fit methods.

Subclasses should override for any actions to run.

Note that if the steps_per_execution argument to compile in tf.keras.Model is set to N, this method will only be called every N batches.

Args:

batch: Integer, index of batch within the current epoch.

logs: Dict, contains the return value of model.train_step. Typically, the values of the Model’s metrics are returned. Example: {'loss': 0.2, 'accuracy': 0.7}.

on_test_batch_begin(batch, logs=None)

Called at the beginning of a batch in evaluate methods.

Also called at the beginning of a validation batch in the fit methods, if validation data is provided.

Subclasses should override for any actions to run.

Note that if the steps_per_execution argument to compile in tf.keras.Model is set to N, this method will only be called every N batches.

Args:

batch: Integer, index of batch within the current epoch.

logs: Dict, contains the return value of model.test_step. Typically, the values of the Model’s metrics are returned. Example: {'loss': 0.2, 'accuracy': 0.7}.

on_predict_batch_begin(batch, logs=None)

Called at the beginning of a batch in predict methods.

Subclasses should override for any actions to run.

Note that if the steps_per_execution argument to compile in tf.keras.Model is set to N, this method will only be called every N batches.

Args:

batch: Integer, index of batch within the current epoch.

logs: Dict, contains the return value of model.predict_step. It typically returns a dict with a key 'outputs' containing the model’s outputs.

on_train_batch_end(batch, logs=None)

Called at the end of a training batch in fit methods.

Subclasses should override for any actions to run.

Note that if the steps_per_execution argument to compile in tf.keras.Model is set to N, this method will only be called every N batches.

Args:

batch: Integer, index of batch within the current epoch.

logs: Dict. Aggregated metric results up until this batch.

on_test_batch_end(batch, logs=None)

Called at the end of a batch in evaluate methods.

Also called at the end of a validation batch in the fit methods, if validation data is provided.

Subclasses should override for any actions to run.

Note that if the steps_per_execution argument to compile in tf.keras.Model is set to N, this method will only be called every N batches.

Args:

batch: Integer, index of batch within the current epoch.

logs: Dict. Aggregated metric results up until this batch.

on_predict_batch_end(batch, logs=None)

Called at the end of a batch in predict methods.

Subclasses should override for any actions to run.

Note that if the steps_per_execution argument to compile in tf.keras.Model is set to N, this method will only be called every N batches.

Args:

batch: Integer, index of batch within the current epoch.

logs: Dict. Aggregated metric results up until this batch.

on_epoch_end(epoch, logs=None)

Called at the end of an epoch.

Subclasses should override for any actions to run. This function should only be called during TRAIN mode.

Args:

epoch: Integer, index of epoch.

logs: Dict, metric results for this training epoch, and for the validation epoch if validation is performed. Validation result keys are prefixed with val_. For the training epoch, the values of the Model’s metrics are returned. Example: {'loss': 0.2, 'accuracy': 0.7}.
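Because validation results reuse the training metric names with a val_ prefix, an on_epoch_end override can separate the two. The helper below is a hypothetical illustration that works on plain dictionaries shaped like the example above.

```python
def split_epoch_logs(logs):
    """Split Keras-style epoch logs into (train, validation) dicts keyed by metric name."""
    train, val = {}, {}
    for key, value in (logs or {}).items():
        if key.startswith("val_"):
            # Strip the prefix so both dicts share the same metric names.
            val[key[len("val_"):]] = value
        else:
            train[key] = value
    return train, val


train, val = split_epoch_logs(
    {"loss": 0.2, "accuracy": 0.7, "val_loss": 0.25, "val_accuracy": 0.68}
)
print(train)  # {'loss': 0.2, 'accuracy': 0.7}
print(val)    # {'loss': 0.25, 'accuracy': 0.68}
```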

orcanet.utilities.nn_utilities.load_zero_center_data(orga, logging=False)

Gets the xs_mean array(s) that can be used for zero-centering.

The arrays are either loaded from a previously saved .npz file or computed on the fly as the mean value per bin over the given training files. The name of the saved image is derived from the name of the list file that was given to the cfg.

Parameters
orga : orcanet.core.Organizer

Contains all the configurable options in the OrcaNet scripts.

logging : bool

If True, the execution of this function is logged into the full summary in the output folder.

Returns
xs_mean : dict

Dict of ndarray(s) that contains the mean_image of the x dataset (1 array per list input). Can be used for zero-centering later on. Example format: {"input_A": ndarray, "input_B": ndarray}
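Once xs_mean is available, zero-centering amounts to subtracting the mean image of each input from every sample of that input. The sketch below assumes a batch is a dict keyed by the same input names as xs_mean, with samples along the first axis so that NumPy broadcasting applies; the helper `zero_center_batch` is hypothetical.

```python
import numpy as np


def zero_center_batch(batch, xs_mean):
    """Subtract the per-input mean image from every sample (broadcast over axis 0)."""
    return {name: x - xs_mean[name] for name, x in batch.items()}


xs_mean = {"input_A": np.array([[1.0, 2.0], [3.0, 4.0]])}
# Two samples: one above and one below the mean image by 1.0 everywhere.
batch = {"input_A": np.stack([xs_mean["input_A"] + 1.0, xs_mean["input_A"] - 1.0])}
centered = zero_center_batch(batch, xs_mean)
print(centered["input_A"][:, 0, 0])  # [ 1. -1.]
```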

orcanet.utilities.nn_utilities.get_xs_mean_path(zero_center_folder, train_filepaths)

Search for precalculated xs_mean arrays in the zero_center_folder.

The function opens every .npz file in the zero center folder and checks whether the files used to generate that xs_mean (stored as the subarray 'zero_center_used_ip_files') are the same as the given train_filepaths.

Parameters
zero_center_folder : str

Full path to the folder where the zero_centering arrays are stored.

train_filepaths : list

The filepaths of all train_files.

Returns
xs_mean_path : str or None

The zero center filepath for the given train_filepaths, if such a file exists in zero_center_folder. If not, returns None.
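A minimal sketch of this lookup, assuming each .npz file stores the list of training files under the key 'zero_center_used_ip_files' as described above. The function name `find_xs_mean_file` is hypothetical, and the comparison here is order-sensitive, which may differ from OrcaNet's behavior.

```python
import os

import numpy as np


def find_xs_mean_file(zero_center_folder, train_filepaths):
    """Return the path of an .npz file whose stored file list equals train_filepaths, else None."""
    for name in sorted(os.listdir(zero_center_folder)):
        if not name.endswith(".npz"):
            continue
        path = os.path.join(zero_center_folder, name)
        with np.load(path) as f:
            if "zero_center_used_ip_files" not in f.files:
                continue
            used_files = f["zero_center_used_ip_files"]  # read into memory before closing
        # Order-sensitive comparison of the stored file list with the requested one.
        if np.array_equal(used_files, np.asarray(train_filepaths)):
            return path
    return None
```

For example, a file written with `np.savez(path, xs_mean=..., zero_center_used_ip_files=np.array(files))` is found again by `find_xs_mean_file(folder, files)`.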

orcanet.utilities.nn_utilities.make_xs_mean(filepaths, key_samples, total_memory=4000000000.0)

Calculates the zero center image of a dataset.

The calculation still works if xs is larger than the available memory, and also if the file is compressed.

Parameters
filepaths : list

Filepaths of the data files with the samples for which the mean_image will be calculated.

key_samples : str

The name of the datagroup in your h5 input files that contains the samples for the network.

total_memory : int or float

Memory budget in bytes (default 4e9). The mean calculation is split into steps so that each step stays within this budget. Take at most half of what is available per GPU (16 GB), just to make sure.

Returns
xs_mean : ndarray

The zero center image.
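The step-wise calculation can be sketched as follows. Assumptions: each input is an array-like with samples along the first axis that supports slicing (for example an h5py dataset read from one of the files, or a NumPy array as in the usage note), the running sum is kept in float64, and the step size is derived from total_memory and the float64 size of one sample. The name `chunked_mean` is hypothetical.

```python
import numpy as np


def chunked_mean(datasets, total_memory=4e9):
    """Mean over axis 0 of several same-shaped datasets, read in memory-bounded steps."""
    total = None
    n_samples = 0
    for data in datasets:
        # Bytes per sample once converted to float64 for the running sum.
        sample_bytes = int(np.prod(data.shape[1:])) * np.dtype(np.float64).itemsize
        # Number of samples that fit into the budget per step (at least one).
        step = max(1, int(total_memory // max(sample_bytes, 1)))
        for start in range(0, data.shape[0], step):
            chunk = np.asarray(data[start:start + step], dtype=np.float64)
            chunk_sum = chunk.sum(axis=0)
            total = chunk_sum if total is None else total + chunk_sum
            n_samples += chunk.shape[0]
    if n_samples == 0:
        raise ValueError("No samples found.")
    return total / n_samples
```

With a 40-byte budget and 2x2 samples (32 bytes each in float64), the sketch reads one sample per step, and the result matches computing the mean over all samples at once.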

orcanet.utilities.nn_utilities.get_array_memsize(array)

Calculates the approximate memory size of an array.

Parameters
array : ndarray

An array.

Returns
memsize : float

Size of the array in bytes.
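For an in-memory NumPy array the size is simply `array.nbytes`. The sketch below instead multiplies the number of elements by the item size, so it also works for objects that only expose `shape` and `dtype` (an assumption that covers, for example, h5py datasets). The function name `approx_memsize` is hypothetical.

```python
import numpy as np


def approx_memsize(array):
    """Approximate size in bytes: number of elements times bytes per element."""
    return float(np.prod(array.shape) * np.dtype(array.dtype).itemsize)


print(approx_memsize(np.zeros((10, 5), dtype=np.float32)))  # 200.0
```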