
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
Code for training and validating NNs, as well as evaluating them.
"""
import os
import time

import h5py
import numpy as np

import orcanet
from orcanet.logging import BatchLogger
import orcanet.utilities.nn_utilities as nn_utilities
from orcanet.in_out import h5_get_number_of_rows
import orcanet.h5_generator as h5_generator
import orcanet.lib.dataset_modifiers as dataset_modifiers


def train_model(orga, model, epoch, batch_logger=False):
    """
    Train a model on one file and return the history.

    Parameters
    ----------
    orga : orcanet.core.Organizer
        Contains all the configurable options in the OrcaNet scripts.
    model : keras.Model
        A compiled keras model.
    epoch : tuple
        Current epoch and the number of the file to train on.
    batch_logger : bool
        Use the orcanet BatchLogger to log the training.

    Returns
    -------
    history : dict
        The history of the training on this file. A record of training
        loss values and metrics values.

    """
    callbacks = [
        nn_utilities.RaiseOnNaN(),
        nn_utilities.TimeModel(print_func=orga.io.print_log),
    ]
    if batch_logger:
        callbacks.append(BatchLogger(orga, epoch))
    if orga.cfg.callback_train is not None:
        try:
            callbacks.extend(orga.cfg.callback_train)
        except TypeError:
            # a single callback was given instead of a list of callbacks
            callbacks.append(orga.cfg.callback_train)

    training_generator = h5_generator.get_h5_generator(
        orga,
        files_dict=orga.io.get_file("train", epoch[1]),
        f_size=orga.cfg.n_events,
        phase="training",
        zero_center=orga.cfg.zero_center_folder is not None,
        shuffle=orga.cfg.shuffle_train,
    )
    # status tf 2.5: in order to use ragged tensors as input to fit,
    # we have to use a tf dataset and not a generator
    dataset = h5_generator.make_dataset(training_generator)

    history = model.fit(
        dataset,
        steps_per_epoch=len(training_generator),
        verbose=orga.cfg.verbose_train,
        max_queue_size=orga.cfg.max_queue_size,
        callbacks=callbacks,
        initial_epoch=epoch[0] - 1,
        epochs=epoch[0],
    )
    training_generator.print_timestats(print_func=orga.io.print_log)
    # get a dict with losses and metrics; only trained for one epoch,
    # so each value is a list of length 1
    history = {key: value[0] for key, value in history.history.items()}
    return history
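
# Example usage of train_model (a sketch; the Organizer setup and the
# compiled model depend on the user's project):
#
#   orga = orcanet.core.Organizer("output/", "list.toml")
#   model = ...  # any compiled keras.Model
#   # train on file no. 2 of epoch 3 and log every batch:
#   history = train_model(orga, model, epoch=(3, 2), batch_logger=True)
#
# Note that initial_epoch = epoch[0] - 1 together with epochs = epoch[0]
# makes keras run exactly one epoch over the given file.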

def validate_model(orga, model):
    """
    Validate a model on all validation files and return the history.

    Parameters
    ----------
    orga : orcanet.core.Organizer
        Contains all the configurable options in the OrcaNet scripts.
    model : keras.Model
        A compiled keras model.

    Returns
    -------
    history : dict
        The history of the validation on all files. A record of validation
        loss values and metrics values.

    """
    # one history for each val file
    histories = []
    f_sizes = orga.io.get_file_sizes("val")

    for files_dict in orga.io.yield_files("val"):
        val_generator = h5_generator.get_h5_generator(
            orga,
            files_dict,
            f_size=orga.cfg.n_events,
            phase="validation",
            zero_center=orga.cfg.zero_center_folder is not None,
        )
        # status tf 2.5: in order to use ragged tensors as input to
        # evaluate, we have to use a tf dataset and not a generator
        dataset = h5_generator.make_dataset(val_generator)
        history_file = model.evaluate(
            dataset,
            steps=len(val_generator),
            max_queue_size=orga.cfg.max_queue_size,
            verbose=orga.cfg.verbose_val,
        )
        if not isinstance(history_file, list):
            # with only one metric, evaluate returns a scalar
            history_file = [history_file]
        histories.append(history_file)

    # average over all val files, weighted with their number of events
    history = weighted_average(histories, f_sizes)
    # model.evaluate returns a list, not a dict like model.fit,
    # so transform it to a dict
    history = dict(zip(model.metrics_names, history))

    return history
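
# Example of what validate_model returns (hypothetical numbers): for a
# model compiled with metrics=["accuracy"], the result could look like
#
#   {"loss": 0.42, "accuracy": 0.87}
#
# where each value is the average over all validation files, weighted
# with their number of events via weighted_average below.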

def weighted_average(histories, f_sizes):
    """
    Average multiple histories, weighted with the file size.

    Each history can have multiple metrics, which are averaged separately.

    Parameters
    ----------
    histories : list
        List of histories, one for each file. Each history is also
        a list: each entry is a different loss or metric.
    f_sizes : list
        List of the file sizes, in the same order as the histories, i.e.
        the file of histories[0] has the length f_sizes[0].

    Returns
    -------
    wgtd_average : list
        The weighted averaged history. Has the same length as each
        history in the histories list, i.e. one entry per loss or metric.

    """
    assert len(histories) == len(f_sizes)
    rltv_fsizes = [f_size / sum(f_sizes) for f_size in f_sizes]
    wgtd_average = np.dot(np.transpose(histories), rltv_fsizes)

    return wgtd_average.tolist()
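
# Worked example for weighted_average (hypothetical numbers): two files
# with 100 and 300 events, each history holding [loss, accuracy]:
#
#   weighted_average([[0.5, 0.8], [0.3, 0.9]], [100, 300])
#   # -> approximately [0.35, 0.875]
#
# since 0.25 * 0.5 + 0.75 * 0.3 = 0.35 for the loss, and
# 0.25 * 0.8 + 0.75 * 0.9 = 0.875 for the accuracy.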

def h5_inference(
    orga, model, files_dict, output_path, samples=None, use_def_label=True
):
    """
    Let a model predict on all samples in a h5 file, and save the
    results as a h5 file.

    By default, the h5 file will contain a datagroup y_values straight
    from the given files, as well as two datagroups per output layer of
    the network, which hold the labels and the predicted values as numpy
    arrays, respectively.

    Parameters
    ----------
    orga : orcanet.core.Organizer
        Contains all the configurable options in the OrcaNet scripts.
    model : keras.Model
        Trained keras model of a neural network.
    files_dict : dict
        Dict mapping model input names to h5 file paths.
    output_path : str
        Name of the output h5 file containing the predictions.
    samples : int, optional
        Don't use all events in the file, but only the given number.
    use_def_label : bool
        If True and no label modifier is given by the user, use the
        default label modifier instead of none.

    """
    file_size = h5_get_number_of_rows(
        list(files_dict.values())[0], datasets=[orga.cfg.key_y_values]
    )
    generator = h5_generator.get_h5_generator(
        orga,
        files_dict,
        zero_center=orga.cfg.zero_center_folder is not None,
        keras_mode=False,
        use_def_label=use_def_label,
        phase="inference",
    )
    try:
        itergen = iter(generator)

        if samples is None:
            steps = len(generator)
        else:
            steps = int(samples / orga.cfg.batchsize)
        print_every = max(100, min(int(round(steps / 10, -2)), 1000))
        model_time_total = 0.0
        dataset_last_element = {}

        # predict into a temporary file first, so an aborted prediction
        # can not be mistaken for a finished one
        temp_output_path = os.path.join(
            os.path.dirname(output_path),
            "temp_"
            + os.path.basename(output_path)
            + "_"
            + time.strftime("%d-%m-%Y-%H-%M-%S", time.gmtime()),
        )
        print(f"Creating temporary file {temp_output_path}")
        with h5py.File(temp_output_path, "x") as h5_file:
            # add orcanet version and the paths of the input h5 files
            h5_file.attrs.create("orcanet", orcanet.__version__)
            for input_key, file in files_dict.items():
                h5_file.attrs.create(f"orcanet_inp_{input_key}", file)

            for s in range(steps):
                if s % print_every == 0:
                    print(
                        "Predicting in step {}/{} ({:0.2%})".format(
                            s, steps, s / steps
                        )
                    )

                info_blob = next(itergen)

                start_time = time.time()
                y_pred = model.predict_on_batch(info_blob["xs"])
                model_time_total += time.time() - start_time

                if not isinstance(y_pred, list):
                    # if only one output, transform to a list
                    y_pred = [y_pred]
                # transform y_pred to a dict
                y_pred = {
                    out: y_pred[i] for i, out in enumerate(model.output_names)
                }
                info_blob["y_pred"] = y_pred

                if info_blob.get("org_batchsize") is not None:
                    _slice_to_size(info_blob)

                if orga.cfg.dataset_modifier is None:
                    datasets = dataset_modifiers.as_array(info_blob)
                else:
                    datasets = orga.cfg.dataset_modifier(info_blob)

                for dataset_name, data in datasets.items():
                    if s == 0:
                        # create the datasets in the first step
                        h5_file.create_dataset(
                            dataset_name,
                            shape=(file_size,) + data.shape[1:],
                            dtype=data.dtype,
                            chunks=True,
                            compression="gzip",
                            compression_opts=1,
                            shuffle=True,
                        )
                        dataset_last_element[dataset_name] = 0

                    ix = dataset_last_element[dataset_name]
                    h5_file[dataset_name][ix : ix + data.shape[0]] = data
                    dataset_last_element[dataset_name] += data.shape[0]
    finally:
        generator.close()

    if os.path.exists(output_path):
        raise FileExistsError(
            f"{output_path} exists already! But file {temp_output_path} "
            f"is finished and can be safely used."
        )
    os.rename(temp_output_path, output_path)
    orga.io.print_log(
        f"\nPrediction completed on file {os.path.basename(output_path)}"
    )
    orga.io.print_log("Statistics of model prediction:")
    orga.io.print_log(f"\tTotal time:\t{model_time_total / 60:.2f} min")
    orga.io.print_log(f"\tPer batch:\t{1000 * model_time_total / steps:.5} ms")
    generator.print_timestats(print_func=orga.io.print_log)
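
# Sketch of inspecting a finished prediction file with h5py (the file
# name here is hypothetical; the dataset names depend on the dataset
# modifier that was used):
#
#   with h5py.File("prediction.h5", "r") as f:
#       print(f.attrs["orcanet"])  # orcanet version that wrote the file
#       for name, dset in f.items():
#           print(name, dset.shape, dset.dtype)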

def _slice_to_size(info_blob):
    """Cut xs, y_pred and ys in the info blob down to org_batchsize samples."""
    org_batchsize = info_blob["org_batchsize"]
    for input_key, x in info_blob["xs"].items():
        info_blob["xs"][input_key] = x[:org_batchsize]
    for output_key, y_pred in info_blob["y_pred"].items():
        info_blob["y_pred"][output_key] = y_pred[:org_batchsize]
    if info_blob.get("ys") is not None:
        for output_key, y in info_blob["ys"].items():
            info_blob["ys"][output_key] = y[:org_batchsize]


def make_model_prediction(orga, model, epoch, fileno, samples=None):
    """
    Let a model predict on all validation samples, and save the results
    as h5 files.

    By default, each h5 file will contain a datagroup y_values straight
    from the given files, as well as two datagroups per output layer of
    the network, which hold the labels and the predicted values as numpy
    arrays, respectively.

    Parameters
    ----------
    orga : orcanet.core.Organizer
        Contains all the configurable options in the OrcaNet scripts.
    model : keras.Model
        A compiled keras model.
    epoch : int
        Epoch of the last model training step in the (epoch, file_no) tuple.
    fileno : int
        File number of the last model training step in the
        (epoch, file_no) tuple.
    samples : int or None
        Number of events that should be predicted.
        If samples=None, the whole file will be used.

    """
    latest_pred_file_no = orga.io.get_latest_prediction_file_no(epoch, fileno)
    if latest_pred_file_no is None:
        latest_pred_file_no = 0

    # For every val file set (one set can have multiple files if
    # the model has multiple inputs):
    for f_number, files_dict in enumerate(orga.io.yield_files("val"), 1):
        if f_number <= latest_pred_file_no:
            # prediction for this file set already exists, skip it
            continue

        pred_filepath = orga.io.get_pred_path(epoch, fileno, f_number)
        h5_inference(orga, model, files_dict, pred_filepath, samples=samples)
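
# Example usage of make_model_prediction (a sketch; epoch and fileno
# refer to an existing training state of the given Organizer):
#
#   make_model_prediction(orga, model, epoch=5, fileno=2)
#
# Prediction files that already exist for this model state are skipped
# (via get_latest_prediction_file_no), so an interrupted prediction can
# simply be restarted.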