Coverage for orcanet/utilities/nn_utilities.py: 92%
118 statements
coverage.py v7.2.7, created at 2024-03-28 14:22 +0000
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""Utility functions used for training a NN."""

import warnings
import numpy as np
import h5py
import os
import time
import tensorflow.keras as ks
from functools import reduce


class RaiseOnNaN(ks.callbacks.Callback):
    """Callback that terminates training when a NaN or Inf loss is encountered."""

    def on_batch_end(self, batch, logs=None):
        logs = logs or {}
        loss = logs.get("loss")
        if loss is not None:
            if np.isnan(loss) or np.isinf(loss):
                warnings.warn(f"Input values:\n{batch}\n\nLogs:\n{logs}")
                raise ValueError(
                    f"Batch {batch}: Invalid loss {loss}, terminating training"
                )
class TimeModel(ks.callbacks.Callback):
    """Print how long the model takes to process batches."""

    def __init__(self, print_func=None):
        super().__init__()
        self.print_func = print_func
        self._total_time = 0.0
        self._total_batches = 0
        self._t_start = 0.0

    def start_time(self):
        self._t_start = time.time()

    def stop_time(self):
        self._total_time += time.time() - self._t_start
        self._total_batches += 1

    def print_stats(self):
        if self.print_func is None:
            print_func = print
        else:
            print_func = self.print_func
        print_func("Statistics of model calculations:")
        print_func(f"\tTotal time:\t{self._total_time / 60:.2f} min")
        if self._total_batches != 0:
            print_func(
                f"\tPer batch:\t"
                f"{1000 * self._total_time / self._total_batches:.5} ms"
            )

    def on_train_batch_begin(self, batch, logs=None):
        self.start_time()

    def on_test_batch_begin(self, batch, logs=None):
        self.start_time()

    def on_predict_batch_begin(self, batch, logs=None):
        self.start_time()

    def on_train_batch_end(self, batch, logs=None):
        self.stop_time()

    def on_test_batch_end(self, batch, logs=None):
        self.stop_time()

    def on_predict_batch_end(self, batch, logs=None):
        self.stop_time()

    def on_epoch_end(self, epoch, logs=None):
        self.print_stats()
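# --- Usage sketch (illustrative only, not part of the original module) ----- #
# Both classes above are ordinary Keras callbacks, so they can be passed
# directly to model.fit(). The model, x_train and y_train arguments are
# hypothetical placeholders supplied by the caller.
def _example_fit_with_callbacks(model, x_train, y_train):
    model.fit(
        x_train,
        y_train,
        epochs=2,
        callbacks=[RaiseOnNaN(), TimeModel()],
    )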
# ------------- Zero center functions -------------#


def load_zero_center_data(orga, logging=False):
    """
    Get the xs_mean array(s) that can be used for zero-centering.

    The arrays are either loaded from a previously saved .npz file or they
    are calculated on the fly by calculating the mean value per bin for the
    given training files. The name of the saved file is derived from the
    name of the list file which was given to the cfg.

    Parameters
    ----------
    orga : orcanet.core.Organizer
        Contains all the configurable options in the OrcaNet scripts.
    logging : bool
        If true, will log the execution of this function into the
        full summary in the output folder.

    Returns
    -------
    xs_mean : dict
        Dict of ndarray(s) that contains the mean_image of the x dataset
        (1 array per list input). Can be used for zero-centering later on.
        Example format:
        { "input_A" : ndarray, "input_B" : ndarray }

    """
    all_train_files = orga.cfg.get_files("train")
    zero_center_folder = orga.cfg.zero_center_folder
    if not zero_center_folder.endswith("/"):
        zero_center_folder += "/"
    train_files_list_name = os.path.basename(orga.cfg.get_list_file())
    key_samples = orga.cfg.key_x_values

    orga.io.print_log("Zero centering", logging)
    orga.io.print_log("--------------", logging)
    orga.io.print_log("Zero center folder: " + zero_center_folder, logging)

    xs_mean = {}
    for input_key, train_filepaths in all_train_files.items():
        xs_mean_path = get_xs_mean_path(zero_center_folder, train_filepaths)

        if xs_mean_path is not None:
            orga.io.print_log(
                "{}: Loading saved zero centering".format(input_key), logging
            )
            xs_mean_ip_i = np.load(xs_mean_path)["xs_mean"]
            orga.io.print_log(
                "\tLoaded file: {}".format(os.path.basename(xs_mean_path)), logging
            )

        else:
            orga.io.print_log(
                "{}: Making new zero centering".format(input_key), logging
            )

            xs_mean_ip_i = make_xs_mean(train_filepaths, key_samples)
            filename = (
                zero_center_folder
                + train_files_list_name
                + "_input_"
                + str(input_key)
                + ".npz"
            )
            np.savez(
                filename,
                xs_mean=xs_mean_ip_i,
                zero_center_used_ip_files=train_filepaths,
            )

            orga.io.print_log(
                "\tSaved as {} with shape {}".format(
                    os.path.basename(filename), xs_mean_ip_i.shape
                ),
                logging,
            )

        xs_mean[input_key] = xs_mean_ip_i

    orga.io.print_log("", logging)
    return xs_mean
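# Sketch of how the returned dict could be applied (illustrative only, not
# part of the original module): the mean image is simply subtracted from every
# sample of the matching network input. The xs_batch name is a hypothetical
# placeholder for a dict of input batches keyed like xs_mean.
def _example_apply_zero_center(xs_batch, xs_mean):
    return {key: xs_batch[key] - xs_mean[key] for key in xs_batch}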
def get_xs_mean_path(zero_center_folder, train_filepaths):
    """
    Search for precalculated xs_mean arrays in the zero_center_folder.

    The function opens every .npz file in the zero center folder and checks
    if the files used to generate this xs_mean (stored as subarray
    'zero_center_used_ip_files') are the same as the given train_filepaths.

    Parameters
    ----------
    zero_center_folder : str
        Full path to the folder where the zero_centering arrays are stored.
    train_filepaths : list
        The filepaths of all train_files.

    Returns
    -------
    xs_mean_path : None or str
        The zero center filepath for the given train_filepaths if
        it exists in the zero_center_files. If not, returns None.

    """
    xs_mean_path = None

    if not os.path.isdir(zero_center_folder):
        os.mkdir(zero_center_folder)

    for file in os.listdir(zero_center_folder):
        if not file.endswith(".npz"):
            continue
        file = zero_center_folder + file
        used_ip_files = np.load(file)["zero_center_used_ip_files"]
        if np.array_equal(used_ip_files, train_filepaths):
            xs_mean_path = file
            break

    return xs_mean_path
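# Illustrative helper (assumed, not part of the original module): peek into a
# saved zero-centering .npz file. It holds the mean image and the list of
# training files it was built from, i.e. exactly the keys that
# load_zero_center_data writes out above.
def _example_inspect_zero_center_file(npz_path):
    with np.load(npz_path) as f:
        return f["xs_mean"].shape, list(f["zero_center_used_ip_files"])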
def make_xs_mean(filepaths, key_samples, total_memory=4e9):
    """
    Calculates the zero center image of a dataset.

    Calculating still works if xs is larger than the available memory
    and also if the file is compressed.

    Parameters
    ----------
    filepaths : List
        Filepaths of the data files with the samples for which the
        mean_image will be calculated.
    key_samples : str
        The name of the datagroup in your h5 input files which contains
        the samples to the network.
    total_memory : float
        Maximum amount of memory in bytes to use per step. The mean
        calculation is split into steps so that each chunk fits into
        this budget. The default of 4e9 (about 4 GB) is roughly half of
        what is typically available per GPU (16 GB), just to be safe.

    Returns
    -------
    xs_mean : ndarray
        The zero center image.

    """
    xs_means = []
    file_sizes = []

    for filepath in filepaths:

        with h5py.File(filepath, "r") as file:
            filesize = get_array_memsize(file[key_samples])
            steps = int(np.ceil(filesize / total_memory))
            n_rows = file[key_samples].shape[0]
            stepsize = int(n_rows / float(steps))

            # create xs_mean_arr that stores intermediate mean_temp results
            xs_mean_arr = np.zeros(
                (steps,) + file[key_samples].shape[1:], dtype=np.float64
            )
            print("\tCalculating for file: " + filepath)
            for i in range(steps):
                if i % 5 == 0:
                    print("\t Step " + str(i) + " of " + str(steps))

                # for the last step, calculate mean till the end of the file
                if i == steps - 1 or steps == 1:
                    xs_mean_temp = np.mean(
                        file[key_samples][i * stepsize : n_rows],
                        axis=0,
                        dtype=np.float64,
                    )
                else:
                    xs_mean_temp = np.mean(
                        file[key_samples][i * stepsize : (i + 1) * stepsize],
                        axis=0,
                        dtype=np.float64,
                    )

                xs_mean_arr[i] = xs_mean_temp

            print("\tDone!")
            # The mean for this file
            xs_means.append(
                np.mean(xs_mean_arr, axis=0, dtype=np.float64).astype(np.float32)
            )
            # the number of samples in this file
            file_sizes.append(n_rows)

    # calculate weighted average depending on no of samples in the files
    file_sizes = [size / np.sum(file_sizes) for size in file_sizes]
    xs_mean = np.average(xs_means, weights=file_sizes, axis=0)
    return xs_mean
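# Illustrative check of the weighted-average step above (assumed numbers, not
# part of the original module): two files with 100 and 300 samples get weights
# 0.25 and 0.75, so the combined mean is 0.25 * mean_1 + 0.75 * mean_2.
def _example_weighted_mean():
    means = [np.full((2, 2), 1.0), np.full((2, 2), 5.0)]
    sizes = [100, 300]
    weights = [size / np.sum(sizes) for size in sizes]
    return np.average(means, weights=weights, axis=0)  # every entry is 4.0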
def get_array_memsize(array):
    """
    Calculates the approximate memory size of an array.

    :param ndarray array: an array.
    :return: float memsize: size of the array in bytes.
    """
    shape = array.shape
    n_numbers = reduce(lambda x, y: x * y, shape)  # number of entries in the array
    precision = 8  # assumed precision of each entry in bits (uint8, typical for xs datasets)
    memsize = (n_numbers * precision) / float(8)  # in bytes

    return memsize
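# Worked example (assumed shape, illustrative only): for a dataset of shape
# (1000, 11, 13, 18) there are 1000 * 11 * 13 * 18 = 2,574,000 entries; at the
# assumed 8 bits per entry this corresponds to 2,574,000 bytes, i.e. ~2.6 MB.
def _example_memsize():
    return get_array_memsize(np.zeros((1000, 11, 13, 18), dtype=np.uint8))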