Coverage for orcanet/utilities/nn_utilities.py: 92%

118 statements  

« prev     ^ index     » next       coverage.py v7.2.7, created at 2024-03-28 14:22 +0000

1#!/usr/bin/env python 

2# -*- coding: utf-8 -*- 

3"""Utility functions used for training a NN.""" 

4 

5import warnings 

6import numpy as np 

7import h5py 

8import os 

9import time 

10import tensorflow.keras as ks 

11from functools import reduce 

12 

13 

class RaiseOnNaN(ks.callbacks.Callback):
    """
    Keras callback that aborts training as soon as the loss is NaN or inf.

    After every batch the ``"loss"`` entry of ``logs`` is inspected. If it
    is present and NaN or +/-inf, a warning containing the ``batch``
    argument and the logs is emitted, and a ``ValueError`` is raised.
    """

    def on_batch_end(self, batch, logs=None):
        """
        Check the loss reported for the batch that just finished.

        Parameters
        ----------
        batch : int
            The ``batch`` argument as passed by Keras (the batch index).
        logs : dict or None
            Metric results of this batch. ``None`` or an empty value is
            treated as an empty dict.

        Raises
        ------
        ValueError
            If ``logs["loss"]`` is NaN or infinite.
        """
        current_logs = logs if logs else {}
        batch_loss = current_logs.get("loss")

        # Nothing to check if the loss was not reported for this batch.
        if batch_loss is None:
            return
        # A finite loss means training can go on.
        if not (np.isnan(batch_loss) or np.isinf(batch_loss)):
            return

        warnings.warn(f"Input values:\n{batch}\n\nLogs:\n{current_logs}")
        raise ValueError(
            f"Batch {batch}: Invalid loss {batch_loss}, terminating training"
        )

28 

29 

class TimeModel(ks.callbacks.Callback):
    """
    Print how long the model took for processing batches.

    The duration of every train, test and predict batch is measured and
    accumulated. At the end of every epoch the statistics are printed. They
    are cumulative: nothing is reset between epochs.

    Parameters
    ----------
    print_func : callable, optional
        Function used for printing the statistics. Defaults to the builtin
        ``print`` if None.

    """

    def __init__(self, print_func=None):
        super().__init__()
        self.print_func = print_func
        self._total_time = 0.0  # seconds, summed over all timed batches
        self._total_batches = 0  # number of timed batches
        self._t_start = 0.0  # timestamp of the last start_time() call

    def start_time(self):
        """Start timing a batch."""
        # perf_counter is monotonic and high-resolution, unlike time.time(),
        # which can jump if the system clock is adjusted during training.
        self._t_start = time.perf_counter()

    def stop_time(self):
        """Stop timing the current batch and add it to the statistics."""
        self._total_time += time.perf_counter() - self._t_start
        self._total_batches += 1

    def print_stats(self):
        """Print the total time and the average time per batch."""
        if self.print_func is None:
            print_func = print
        else:
            print_func = self.print_func
        print_func("Statistics of model calculations:")
        print_func(f"\tTotal time:\t{self._total_time/60:.2f} min")
        if self._total_batches != 0:
            print_func(
                f"\tPer batch:\t" f"{1000 * self._total_time/self._total_batches:.5} ms"
            )

    def on_train_batch_begin(self, batch, logs=None):
        self.start_time()

    def on_test_batch_begin(self, batch, logs=None):
        self.start_time()

    def on_predict_batch_begin(self, batch, logs=None):
        self.start_time()

    def on_train_batch_end(self, batch, logs=None):
        self.stop_time()

    def on_test_batch_end(self, batch, logs=None):
        self.stop_time()

    def on_predict_batch_end(self, batch, logs=None):
        self.stop_time()

    def on_epoch_end(self, epoch, logs=None):
        self.print_stats()

79 

80 

81# ------------- Zero center functions -------------# 

82 

83 

def load_zero_center_data(orga, logging=False):
    """
    Gets the xs_mean array(s) that can be used for zero-centering.

    The arrays are either loaded from a previously saved .npz file or they
    are calculated on the fly by calculating the mean value per bin for the
    given training files. A newly calculated array is saved in the zero
    center folder. Its file name is derived from the name of the list file
    given to the cfg, plus the input key.

    Parameters
    ----------
    orga : orcanet.core.Organizer
        Contains all the configurable options in the OrcaNet scripts.
    logging : bool
        If true, will log the execution of this function into the
        full summary in the output folder.

    Returns
    -------
    xs_mean : dict
        Dict of ndarray(s) that contains the mean_image of the x dataset
        (1 array per list input). Can be used for zero-centering later on.
        Example format:
        { "input_A" : ndarray, "input_B" : ndarray }

    """
    all_train_files = orga.cfg.get_files("train")
    zero_center_folder = orga.cfg.zero_center_folder
    if not zero_center_folder.endswith("/"):
        zero_center_folder += "/"
    train_files_list_name = os.path.basename(orga.cfg.get_list_file())
    key_samples = orga.cfg.key_x_values

    orga.io.print_log("Zero centering", logging)
    orga.io.print_log("--------------", logging)
    orga.io.print_log("Zero center folder: " + zero_center_folder, logging)

    xs_mean = {}
    for input_key, train_filepaths in all_train_files.items():
        # Reuse a saved xs_mean if one was made from exactly these files.
        xs_mean_path = get_xs_mean_path(zero_center_folder, train_filepaths)

        if xs_mean_path is not None:
            orga.io.print_log(
                "{}: Loading saved zero centering".format(input_key), logging
            )
            # The context manager closes the file handle that np.load keeps
            # open for .npz archives.
            with np.load(xs_mean_path) as npz_file:
                xs_mean_ip_i = npz_file["xs_mean"]
            orga.io.print_log(
                "\tLoaded file: {}".format(os.path.basename(xs_mean_path)), logging
            )

        else:
            orga.io.print_log(
                "{}: Making new zero centering".format(input_key), logging
            )

            xs_mean_ip_i = make_xs_mean(train_filepaths, key_samples)
            filename = os.path.join(
                zero_center_folder,
                train_files_list_name + "_input_" + str(input_key) + ".npz",
            )
            # Store the used input files as well, so that get_xs_mean_path
            # can find this array again for the same train files.
            np.savez(
                filename,
                xs_mean=xs_mean_ip_i,
                zero_center_used_ip_files=train_filepaths,
            )

            orga.io.print_log(
                "\tSaved as {} with shape {}".format(
                    os.path.basename(filename), xs_mean_ip_i.shape
                ),
                logging,
            )

        xs_mean[input_key] = xs_mean_ip_i

    orga.io.print_log("", logging)
    return xs_mean

164 

165 

def get_xs_mean_path(zero_center_folder, train_filepaths):
    """
    Search for precalculated xs_mean arrays in the zero_center_folder.

    The function opens every .npz file in the zero center folder and checks
    if the files used to generate this xs_mean (stored as subarray
    'zero_center_used_ip_files') is the same as the given train_filepaths.
    .npz files without this subarray are skipped. The folder (including
    missing parent folders) is created if it does not exist yet.

    Parameters
    ----------
    zero_center_folder : str
        Full path to the folder where the zero_centering arrays are stored.
    train_filepaths : list
        The filepaths of all train_files.

    Returns
    -------
    xs_mean_path : str or None
        Path of the zero center .npz file made from the given
        train_filepaths, if it exists in the zero_center_folder.
        If not, returns None.

    """
    os.makedirs(zero_center_folder, exist_ok=True)

    # sorted() makes the result independent of the arbitrary os.listdir
    # order in case several files match.
    for filename in sorted(os.listdir(zero_center_folder)):
        if not filename.endswith(".npz"):
            continue
        filepath = os.path.join(zero_center_folder, filename)
        # The context manager closes the file handle np.load keeps open.
        with np.load(filepath) as npz_file:
            if "zero_center_used_ip_files" not in npz_file.files:
                # Some other .npz file, not a zero center array.
                continue
            used_ip_files = npz_file["zero_center_used_ip_files"]
        if np.array_equal(used_ip_files, train_filepaths):
            return filepath

    return None

203 

204 

def make_xs_mean(filepaths, key_samples, total_memory=4e9):
    """
    Calculates the zero center image of a dataset.

    The result is the per-bin mean over all samples of all given files.
    Every file is read in chunks, so the calculation still works if xs is
    larger than the available memory and also if the file is compressed.

    Parameters
    ----------
    filepaths : List
        Filepaths of the data files with the samples for which the
        mean_image will be calculated.
    key_samples : str
        The name of the datagroup in your h5 input files which contains
        the samples to the network.
    total_memory : int or float
        Approximate maximum number of bytes of (uncompressed) sample data
        to load at once. Each file is split into as many chunks as needed
        to stay below this size. Default: 4e9 bytes, i.e. max. 1/2 of
        what is available per GPU (16G), just to make sure.

    Returns
    -------
    xs_mean : ndarray
        The zero center image (float64).

    Raises
    ------
    ValueError
        If the files contain no samples at all, or if the sample shapes
        of the files differ.

    """
    total_sum = None  # float64 sum over all samples seen so far
    total_rows = 0  # number of samples seen so far

    for filepath in filepaths:
        with h5py.File(filepath, "r") as file:
            dset = file[key_samples]
            n_rows = dset.shape[0]
            sample_shape = dset.shape[1:]

            if total_sum is None:
                total_sum = np.zeros(sample_shape, dtype=np.float64)
            elif total_sum.shape != sample_shape:
                raise ValueError(
                    f"Sample shape {sample_shape} in {filepath} does not match "
                    f"the shape {total_sum.shape} of the previous files!"
                )

            # Bytes of the whole dataset once loaded (uncompressed).
            n_bytes = dset.size * dset.dtype.itemsize
            steps = max(1, int(np.ceil(n_bytes / total_memory)))
            # ceil: `steps` chunks cover all rows; at least 1 row per chunk.
            stepsize = max(1, int(np.ceil(n_rows / steps)))
            starts = range(0, n_rows, stepsize)

            print("\tCalculating for file: " + filepath)
            for i, start in enumerate(starts):
                if i % 5 == 0:
                    print("\t Step " + str(i) + " of " + str(len(starts)))
                # Accumulate sums (not means): chunks of different size
                # are then weighted correctly in the final mean.
                total_sum += np.sum(
                    dset[start : start + stepsize], axis=0, dtype=np.float64
                )
            print("\tDone!")
            total_rows += n_rows

    if total_rows == 0:
        raise ValueError(
            "No samples found in the given files, can not calculate xs_mean!"
        )
    return total_sum / total_rows

277 

278 

def get_array_memsize(array):
    """
    Calculates the approximate memory size of an array.

    :param ndarray array: an array, i.e. an object with a ``shape`` (e.g. a
        numpy array or an h5py dataset). If it also has a ``dtype``, its
        item size is used; otherwise 1 byte per entry (uint8) is assumed.
    :return: float memsize: size of the array in bytes.
    """
    # Number of entries. The initial value 1 makes 0-d arrays (shape ())
    # count as one entry instead of making reduce() fail.
    n_numbers = reduce(lambda x, y: x * y, array.shape, 1)
    # Bytes per entry. Fall back to 1 byte (uint8), the typical type of xs
    # datasets, for objects without a dtype.
    itemsize = getattr(getattr(array, "dtype", None), "itemsize", 1)
    return float(n_numbers * itemsize)