Coverage for orcanet/h5_generator.py: 80% (190 statements, coverage.py v7.2.7, created at 2024-03-28 14:22 +0000)


import h5py
import time
import numpy as np
import tensorflow as tf
import tensorflow.keras as ks


class Hdf5BatchGenerator(ks.utils.Sequence):
    def __init__(
        self,
        files_dict,
        batchsize=64,
        key_x_values="x",
        key_y_values="y",
        sample_modifier=None,
        label_modifier=None,
        fixed_batchsize=False,
        y_field_names=None,
        phase="training",
        xs_mean=None,
        f_size=None,
        keras_mode=True,
        shuffle=False,
        class_weights=None,
    ):
        """
        Yields batches of input data from h5 files.

        This will go through one file, or multiple files in parallel, and yield
        one batch of data, which can then be used as an input to a model.
        Since multiple filepaths can be given to read out in parallel,
        this can also be used for models with multiple inputs.

        Parameters
        ----------
        files_dict : dict
            Paths of the files to train on.
            Keys: The name of every input (from the toml list file, can be multiple).
            Values: The filepath of a single h5py file to read data from.
        batchsize : int
            Batchsize that will be used for reading data from the files.
        key_x_values : str
            The name of the datagroup in the h5 input files which contains
            the samples for the network.
        key_y_values : str
            The name of the datagroup in the h5 input files which contains
            the info for the labels. If this name is not in the file,
            y_values will be set to None.
        sample_modifier : function or None
            Operation to be performed on batches of samples read from the input
            files before they are fed into the model.
        label_modifier : function or None
            Operation to be performed on batches of labels read from the input
            files before they are fed into the model.
        fixed_batchsize : bool
            The last batch in the file might be smaller than the batchsize.
            Usually, this is no problem, but set to True to pad this batch to
            the given batchsize.
        y_field_names : tuple or list or str, optional
            During training and validation, read out only these fields from
            the y dataset. This speeds up the readout, especially if there
            are many fields.
        phase : str
            Which phase are we in: "training", "validation", or "inference"?
            Inference means both orga.predict and orga.inference, i.e.
            whenever we write an h5 file.
        xs_mean : ndarray or None
            Zero center image to be subtracted from data as preprocessing.
        f_size : int or None
            Specifies the number of samples to be read from the .h5 file.
            If None, the whole .h5 file will be used.
        keras_mode : bool
            If True, yield xs and ys (samples and labels) for the keras fit
            generator function.
            If False, yield the info_blob containing the full sample and label
            info, both before and after the modifiers have been applied.
        shuffle : bool
            Randomize the order in which batches are read from the file
            (once during init). Can reduce readout speed.
        class_weights : dict or None
            Maps class indices to weights as floats. If given, per-sample
            weights are computed from the labels and returned with each batch.

        """
        if phase not in ("training", "validation", "inference"):
            raise ValueError("Invalid phase")
        self.files_dict = files_dict
        self.batchsize = batchsize
        self.key_x_values = key_x_values
        self.key_y_values = key_y_values
        self.sample_modifier = sample_modifier
        self.label_modifier = label_modifier
        self.fixed_batchsize = fixed_batchsize
        self.phase = phase
        self.xs_mean = xs_mean
        self.f_size = f_size
        self.keras_mode = keras_mode
        self.shuffle = shuffle
        self.class_weights = class_weights

        if y_field_names is not None:
            if isinstance(y_field_names, str):
                y_field_names = (y_field_names,)
            else:
                y_field_names = tuple(y_field_names)
        self.y_field_names = y_field_names

        # a dict with the names of list inputs as keys, and the opened
        # h5 files as values
        self._files = {}
        # start index of each batch in the file
        self._sample_pos = None
        # total number of samples per file
        self._total_f_size = None

        # for keeping track of the readout speed
        self._total_time = 0.0
        self._total_batches = 0
        self._file_meta = None

        self.open()

    def __len__(self):
        """Number of batches in the Sequence (includes queue)."""
        return len(self._sample_pos)

    def __getitem__(self, index):
        """
        Gets batch number `index`.

        Returns
        -------
        xs : dict
            Samples for the model to train on.
            Keys : str
                The name(s) of the input layer(s) of the model.
            Values : ndarray or tuple
                A batch of samples for the corresponding input.
                If x is an indexed dataset, this will be a tuple instead,
                with [0] being the values, and [1] being the number of
                items per sample.
        ys : dict or None
            Labels for the model to train on. Will be None if there are
            no labels in the file.
            Keys : str
                The name(s) of the output layer(s) of the model.
            Values : ndarray
                A batch of labels for the corresponding output.

        If class_weights is not None, will additionally return:
        sample_weights : dict
            Maps output names to weights for each sample in the batch as a
            np.array.

        If keras_mode is False, will return instead:
        info_blob : dict
            Blob containing the x_values, y_values, xs and ys, and optionally
            the sample_weights.

        """
        start_time = time.time()
        file_index = self._sample_pos[index]
        info_blob = {"phase": self.phase, "meta": self.get_file_meta()}
        info_blob["x_values"] = self.get_x_values(file_index)
        info_blob["y_values"] = self.get_y_values(file_index)

        # Modify the samples
        if self.sample_modifier is not None:
            xs = self.sample_modifier(info_blob)
        else:
            xs = info_blob["x_values"]
        info_blob["xs"] = xs

        # Modify the labels
        if info_blob["y_values"] is not None and self.label_modifier is not None:
            ys = self.label_modifier(info_blob)
        else:
            ys = None
        info_blob["ys"] = ys

        if self.fixed_batchsize:
            self.pad_to_size(info_blob)

        if self.class_weights is not None:
            info_blob["sample_weights"] = _get_sample_weights(ys, self.class_weights)

        self._total_time += time.time() - start_time
        self._total_batches += 1
        if self.keras_mode:
            if info_blob.get("sample_weights") is not None:
                return info_blob["xs"], info_blob["ys"], info_blob["sample_weights"]
            else:
                return info_blob["xs"], info_blob["ys"]
        else:
            return info_blob

    def pad_to_size(self, info_blob):
        """Pad the batch to have a fixed batchsize."""
        org_batchsize = next(iter(info_blob["xs"].values())).shape[0]
        if org_batchsize == self.batchsize:
            return
        info_blob["org_batchsize"] = org_batchsize
        for input_key, x in info_blob["xs"].items():
            info_blob["xs"][input_key] = _pad_to_size(x, self.batchsize)
        if info_blob.get("ys") is not None:
            for output_key, y in info_blob["ys"].items():
                info_blob["ys"][output_key] = _pad_to_size(y, self.batchsize)

    def open(self):
        """Open all files and prepare for read out."""
        for input_key, file in self.files_dict.items():
            self._files[input_key] = h5py.File(file, "r")
        self._store_file_length()
        self._store_batch_indices()

    def close(self):
        """Close all files again."""
        for f in list(self._files.values()):
            f.close()

    def get_x_values(self, start_index):
        """
        Read one batch of samples from the files and zero center.

        Parameters
        ----------
        start_index : int
            The start index in the h5 files at which the batch will be read.
            The end index will be the start index + the batch size.

        Returns
        -------
        x_values : dict
            One batch of data for each input file.

        """
        x_values = {}
        for input_key, file in self._files.items():
            slc = slice(start_index, start_index + self._batchsize)

            ix_dset_name = _get_indexed_dset_name(file, self.key_x_values)
            if ix_dset_name is None:
                # normal dataset
                x_values[input_key] = file[self.key_x_values][slc]
            else:
                # indexed dataset: adjust slice according to indices
                indices = file[ix_dset_name][slc]
                slc = slice(
                    indices[0]["index"],
                    indices[-1]["index"] + indices[-1]["n_items"],
                )
                x_values[input_key] = (file[self.key_x_values][slc], indices["n_items"])

            if self.xs_mean is not None:
                x_values[input_key] = np.subtract(
                    x_values[input_key], self.xs_mean[input_key]
                )

        return x_values

    def get_y_values(self, start_index):
        """
        Get y_values for the network. The y_values are assumed to be the
        same for all the files, so the ones from the first file are used.
        TODO add check

        Parameters
        ----------
        start_index : int
            The start index in the h5 files at which the batch will be read.
            The end index will be the start index + the batch size.

        Returns
        -------
        y_values : ndarray or None
            The y_values, read straight from the files. None if the label
            dataset is not in the file.

        """
        first_file = list(self._files.values())[0]
        try:
            slc = slice(start_index, start_index + self._batchsize)
            if self.y_field_names is not None and self.phase != "inference":
                y_values = first_file[self.key_y_values][
                    (slc,) + tuple(self.y_field_names)
                ]
                if len(self.y_field_names) == 1:
                    # result of the slice is a plain ndarray; convert it
                    # to a structured array
                    y_values = y_values.astype(
                        np.dtype([(self.y_field_names[0], y_values.dtype)])
                    )
            else:
                y_values = first_file[self.key_y_values][slc]
        except KeyError:
            # cannot look up y_values; hopefully they are not needed
            y_values = None
        return y_values

    def print_timestats(self, print_func=None):
        """Print stats about how long it took to read batches."""
        if print_func is None:
            print_func = print
        print_func("Statistics of data readout:")
        print_func(f"\tTotal time:\t{self._total_time / 60:.2f} min")
        if self._total_batches != 0:
            print_func(
                f"\tPer batch:\t{1000 * self._total_time / self._total_batches:.5} ms"
            )

    def get_file_meta(self):
        """Meta information about the files. Only read out once."""
        if self._file_meta is None:
            self._file_meta = {}
            # sample and label dataset for each input
            datasets = {}
            for input_key, file in self._files.items():
                datasets[input_key] = {
                    "samples": file[self.key_x_values],
                    "samples_is_indexed": _get_indexed_dset_name(
                        file, self.key_x_values
                    )
                    is not None,
                    # None if the label dataset is not in the file
                    "labels": file.get(self.key_y_values),
                }
            self._file_meta["datasets"] = datasets
        return self._file_meta

    @property
    def _size(self):
        """Size of the files that will be read in. Can be smaller than the
        actual file size if defined by the user."""
        if self.f_size is None:
            return self._total_f_size
        else:
            return self.f_size

    @property
    def _batchsize(self):
        """
        Return the effective batchsize. Can be smaller than the user-defined
        one if it would be larger than the size of the file.
        """
        if self._size < self.batchsize:
            return self._size
        else:
            return self.batchsize

    def _store_file_length(self):
        """
        Make sure all files have the same length and store this length.
        """
        lengths = []
        for f in list(self._files.values()):
            ix_dset_name = _get_indexed_dset_name(f, self.key_x_values)
            if ix_dset_name is None:
                dset_name = self.key_x_values
            else:
                dset_name = ix_dset_name
            lengths.append(len(f[dset_name]))

        if lengths.count(lengths[0]) != len(lengths):
            self.close()
            raise ValueError(
                "All data files must have the same length! "
                "Given were:\n " + str(lengths)
            )

        self._total_f_size = lengths[0]

    def _store_batch_indices(self):
        """
        Define the start indices of each batch in the h5 file and store this.
        """
        if self.phase == "inference":
            # for inference: take all batches
            total_no_of_batches = np.ceil(self._size / self._batchsize)
        else:
            # otherwise: skip the last batch if it has too few events for a
            # full batch; this is mostly because tf datasets can't be used
            # with a variable batchsize (status tf 2.5)
            total_no_of_batches = np.floor(self._size / self._batchsize)

        sample_pos = np.arange(int(total_no_of_batches)) * self._batchsize
        if self.shuffle:
            np.random.shuffle(sample_pos)

        self._sample_pos = sample_pos

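# A minimal usage sketch for Hdf5BatchGenerator, assuming an h5 file with an
# "x" dataset (samples) and a "y" dataset (labels) of equal length. The file
# name "samples.h5" and input name "input_A" are hypothetical placeholders.
def _example_generator_usage():
    gen = Hdf5BatchGenerator(
        files_dict={"input_A": "samples.h5"},
        batchsize=32,
        shuffle=True,
    )
    try:
        for batch_no in range(len(gen)):
            # without a label_modifier, ys is None here
            xs, ys = gen[batch_no]
            batch = xs["input_A"]  # for a plain dataset: shape (32, ...)
    finally:
        gen.close()
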

def _get_indexed_dset_name(file, dset):
    """If this is an indexed dataset, return the name of the indexed set."""
    dset_name_indexed = f"{dset}_indices"
    if file[dset].attrs.get("indexed") and dset_name_indexed in file:
        return dset_name_indexed
    else:
        return None

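# Sketch of the indexed layout that _get_indexed_dset_name detects: a flat
# "x" dataset plus an "x_indices" dataset whose structured rows hold, per
# sample, the start offset ("index") and length ("n_items") in "x". The file
# name and the integer dtypes here are illustrative assumptions.
def _example_indexed_layout(path="indexed_example.h5"):
    n_items = np.array([3, 1, 2])  # items per sample
    indices = np.empty(
        len(n_items), dtype=np.dtype([("index", "<i8"), ("n_items", "<i8")])
    )
    indices["n_items"] = n_items
    indices["index"] = np.concatenate(([0], np.cumsum(n_items)[:-1]))
    with h5py.File(path, "w") as f:
        dset = f.create_dataset("x", data=np.random.rand(n_items.sum(), 4))
        dset.attrs["indexed"] = True  # flag checked by _get_indexed_dset_name
        f.create_dataset("x_indices", data=indices)
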

def _get_sample_weights(ys, class_weights):
    """
    Produce a weight for each sample given the weight for each class.

    Parameters
    ----------
    ys : dict
        Maps output names to categorical one-hot labels as np.arrays.
        Expected to be 2D (n_samples, n_classes).
    class_weights : dict
        Maps output neuron numbers to weights as floats.

    Returns
    -------
    sample_weights : dict
        Maps output names to weights for each sample in the batch as a
        np.array.

    """
    sample_weights = {}
    for output_name, labels in ys.items():
        class_weights_arr = np.ones(labels.shape[1])
        for k, v in class_weights.items():
            class_weights_arr[int(k)] = v
        labels_class = np.argmax(labels, axis=-1)
        sample_weights[output_name] = class_weights_arr[labels_class]
    return sample_weights

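# Worked example for _get_sample_weights: with one-hot labels and a weight
# of 5.0 for class 1, class-1 samples get weight 5.0 and all others keep
# the default 1.0. The output name "out" is an arbitrary placeholder.
def _example_sample_weights():
    ys = {"out": np.array([[1, 0], [0, 1], [0, 1]])}
    weights = _get_sample_weights(ys, class_weights={1: 5.0})
    assert np.array_equal(weights["out"], np.array([1.0, 5.0, 5.0]))
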

def get_h5_generator(
    orga,
    files_dict,
    f_size=None,
    zero_center=False,
    keras_mode=True,
    shuffle=False,
    use_def_label=True,
    phase="training",
):
    """
    Initialize the Hdf5BatchGenerator with the parameters in orga.cfg.

    Parameters
    ----------
    orga : orcanet.core.Organizer
        Contains all the configurable options in the OrcaNet scripts.
    files_dict : dict
        Paths of the files to train on.
        Keys: The name of every input (from the toml list file, can be multiple).
        Values: The filepath of a single h5py file to read samples from.
    f_size : int or None
        Specifies the number of samples to be read from the .h5 file.
        If None, the whole .h5 file will be used.
    zero_center : bool
        Whether to use zero centering.
        Requires orga.zero_center_folder to be set.
    keras_mode : bool
        Specifies if mc-infos (y_values) should be yielded as well. The
        mc-infos are used for evaluation after training and testing is
        finished.
    shuffle : bool
        Randomize the order in which batches are read from the file.
        Significantly reduces readout speed.
    use_def_label : bool
        If True and no label modifier is given by the user, use the default
        label modifier instead of None.

    Yields
    ------
    xs : dict
        Data for the model to train on.
        Keys : str The name(s) of the input layer(s) of the model.
        Values : ndarray A batch of samples for the corresponding input.
    ys : dict or None
        Labels for the model to train on.
        Keys : str The name(s) of the output layer(s) of the model.
        Values : ndarray A batch of labels for the corresponding output.
        Will be None if there are no labels in the file.
    y_values : ndarray, optional
        Y values from the file. Only yielded (as part of the info_blob)
        if keras_mode is False.

    """
    if orga.cfg.label_modifier is not None:
        label_modifier = orga.cfg.label_modifier
    elif use_def_label:
        assert (
            orga._auto_label_modifier is not None
        ), "Auto label modifier has not been set up"
        label_modifier = orga._auto_label_modifier
    else:
        label_modifier = None

    # get xs_mean, or load/create it if not stored yet
    if zero_center:
        xs_mean = orga.get_xs_mean()
    else:
        xs_mean = None

    generator = Hdf5BatchGenerator(
        files_dict=files_dict,
        batchsize=orga.cfg.batchsize,
        key_x_values=orga.cfg.key_x_values,
        key_y_values=orga.cfg.key_y_values,
        sample_modifier=orga.cfg.sample_modifier,
        label_modifier=label_modifier,
        phase=phase,
        xs_mean=xs_mean,
        f_size=f_size,
        keras_mode=keras_mode,
        shuffle=shuffle,
        class_weights=orga.cfg.class_weight,
        fixed_batchsize=orga.cfg.fixed_batchsize,
        y_field_names=orga.cfg.y_field_names,
    )

    return generator

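# Sketch of consuming the generator with keras_mode=False, as done for
# inference. Assumes an already configured orcanet.core.Organizer is at
# hand; files_dict has the layout described in the docstring above, and
# the loop body is illustrative.
def _example_inference_readout(orga, files_dict):
    gen = get_h5_generator(orga, files_dict, keras_mode=False, phase="inference")
    for batch_no in range(len(gen)):
        info_blob = gen[batch_no]
        # raw file contents and modified model inputs travel together
        x_raw, xs = info_blob["x_values"], info_blob["xs"]
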

def make_dataset(gen):
    """Wrap a generator into a tf.data.Dataset, inferring the output
    signature from its first batch."""
    output_signature = tuple([{k: _get_spec(v) for k, v in d.items()} for d in gen[0]])
    return tf.data.Dataset.from_generator(
        lambda: gen, output_signature=output_signature
    )

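# Sketch of feeding batches to tf.data via make_dataset. Since the output
# signature is taken from gen[0], every batch must match it, which is why
# the generator is assumed to use fixed_batchsize; prefetching is optional.
def _example_make_dataset(gen):
    dataset = make_dataset(gen).prefetch(2)
    return dataset  # can be passed directly to keras Model.fit
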

def _get_spec(x):
    """Get the tf spec (TensorSpec or RaggedTensorSpec) matching x."""
    if isinstance(x, tf.RaggedTensor):
        return tf.RaggedTensorSpec.from_value(x)
    else:
        return tf.TensorSpec(shape=x.shape, dtype=x.dtype)

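# _get_spec in action: a float32 numpy batch maps to a TensorSpec with the
# same shape and dtype. The array here is a stand-in for one model input.
def _example_get_spec():
    spec = _get_spec(np.zeros((64, 5), dtype="float32"))
    assert spec.shape == (64, 5)
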

def _pad_to_size(x, size):
    """Pad x to the given size along axis 0 by repeating the last element."""
    length = x.shape[0]
    if length > size:
        raise ValueError(f"Can't pad x with shape {x.shape} to length {size}")
    elif length == size:
        return x
    else:
        if tf.is_tensor(x):
            f_conc = tf.concat
        else:
            f_conc = np.concatenate
        return f_conc([x] + [x[-1:]] * (size - length), axis=0)

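# Worked example for _pad_to_size: a 3-sample batch padded to 5 rows ends
# with two extra copies of its last row. Works the same for tf tensors.
def _example_pad_to_size():
    batch = np.arange(6).reshape(3, 2)
    padded = _pad_to_size(batch, 5)
    assert padded.shape == (5, 2)
    assert np.array_equal(padded[-1], batch[-1])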