Coverage for orcanet/h5_generator.py: 80%
190 statements
« prev ^ index » next — coverage.py v7.2.7, created at 2024-03-28 14:22 +0000
1import h5py
2import time
3import numpy as np
4import tensorflow as tf
5import tensorflow.keras as ks
class Hdf5BatchGenerator(ks.utils.Sequence):
    def __init__(
        self,
        files_dict,
        batchsize=64,
        key_x_values="x",
        key_y_values="y",
        sample_modifier=None,
        label_modifier=None,
        fixed_batchsize=False,
        y_field_names=None,
        phase="training",
        xs_mean=None,
        f_size=None,
        keras_mode=True,
        shuffle=False,
        class_weights=None,
    ):
        """
        Yields batches of input data from h5 files.

        This will go through one file, or multiple files in parallel, and yield
        one batch of data, which can then be used as an input to a model.
        Since multiple filepaths can be given to read out in parallel,
        this can also be used for models with multiple inputs.

        Parameters
        ----------
        files_dict : dict
            Pathes of the files to train on.
            Keys: The name of every input (from the toml list file, can be multiple).
            Values: The filepath of a single h5py file to read data from.
        batchsize : int
            Batchsize that will be used for reading data from the files.
        key_x_values : str
            The name of the datagroup in the h5 input files which contains
            the samples for the network.
        key_y_values : str
            The name of the datagroup in the h5 input files which contains
            the info for the labels. If this name is not in the file,
            y_values will be set to None.
        sample_modifier : function or None
            Operation to be performed on batches of samples read from the input
            files before they are fed into the model.
        y_field_names : tuple or list or str, optional
            During train and val, read out only these fields from the y dataset.
            --> Speed up, especially if there are many fields.
        phase : str
            Which phase are we in? training, validation, or inference.
            Inference means both orga.predict and orga.inference, i.e.
            whenever we write a h5 file.
        label_modifier : function or None
            Operation to be performed on batches of labels read from the input files
            before they are fed into the model.
        fixed_batchsize : bool
            The last batch in the file might be smaller then the batchsize.
            Usually, this is no problem, but set to True to pad this batch to
            given batchsize.
        xs_mean : ndarray or None
            Zero center image to be subtracted from data as preprocessing.
        f_size : int or None
            Specifies the number of samples to be read from the .h5 file.
            If none, the whole .h5 file will be used.
        keras_mode : bool
            If true, yield xs and ys (samples and labels) for the keras fit
            generator function.
            If false, yield the info_blob containing the full sample and label
            info, both before and after the modifiers have been applied.
        shuffle : bool
            Randomize the order in which batches are read from the file
            (once during init). Can reduce read out speed.
        class_weights : dict, optional
            Maps class numbers (output neuron indices) to weights as floats.
            If given, per-sample weights are produced for each batch.

        """
        if phase not in ("training", "validation", "inference"):
            raise ValueError("Invalid phase")
        self.files_dict = files_dict
        self.batchsize = batchsize
        self.key_x_values = key_x_values
        self.key_y_values = key_y_values
        self.sample_modifier = sample_modifier
        self.label_modifier = label_modifier
        self.fixed_batchsize = fixed_batchsize
        self.phase = phase
        self.xs_mean = xs_mean
        self.f_size = f_size
        self.keras_mode = keras_mode
        self.shuffle = shuffle
        self.class_weights = class_weights

        # normalize y_field_names to a tuple (or leave as None)
        if y_field_names is not None:
            if isinstance(y_field_names, str):
                y_field_names = (y_field_names,)
            else:
                y_field_names = tuple(y_field_names)
        self.y_field_names = y_field_names

        # a dict with the names of list inputs as keys, and the opened
        # h5 files as values
        self._files = {}
        # start index of each batch in the file
        self._sample_pos = None
        # total number of samples per file
        self._total_f_size = None

        # for keeping track of the readout speed
        self._total_time = 0.0
        self._total_batches = 0
        # cached meta info about the files; filled lazily by get_file_meta
        self._file_meta = None

        self.open()

    def __len__(self):
        """Number of batches in the Sequence (includes queue)."""
        return len(self._sample_pos)

    def __getitem__(self, index):
        """
        Gets batch number `index`.

        Returns
        -------
        xs : dict
            Samples for the model train on.
            Keys : str
                The name(s) of the input layer(s) of the model.
            Values : ndarray or tuple
                A batch of samples for the corresponding input.
                If x is an indexed datasets, this will be a tuple instead,
                with [0] being the values, and [1] being the number of
                items per sample.
        ys : dict or None
            Labels for the model to train on. Will be None if there are
            no labels in the file.
            Keys : str
                The name(s) of the output layer(s) of the model.
            Values : ndarray
                A batch of labels for the corresponding output.

        If class_weights is not None, will return aditionally:
        sample_weights : dict
            Maps output names to weights for each sample in the batch as a
            np.array.

        If keras_mode is False, will return instead:
        info_blob : dict
            Blob containing the x_values, y_values, xs and ys, and optionally
            the sample_weights.

        """
        start_time = time.time()
        # _sample_pos maps batch number -> start index in the h5 file
        file_index = self._sample_pos[index]
        info_blob = {"phase": self.phase, "meta": self.get_file_meta()}
        info_blob["x_values"] = self.get_x_values(file_index)
        info_blob["y_values"] = self.get_y_values(file_index)

        # Modify the samples
        if self.sample_modifier is not None:
            xs = self.sample_modifier(info_blob)
        else:
            xs = info_blob["x_values"]
        info_blob["xs"] = xs

        # Modify the labels
        if info_blob["y_values"] is not None and self.label_modifier is not None:
            ys = self.label_modifier(info_blob)
        else:
            ys = None
        info_blob["ys"] = ys

        if self.fixed_batchsize:
            self.pad_to_size(info_blob)

        if self.class_weights is not None:
            info_blob["sample_weights"] = _get_sample_weights(ys, self.class_weights)

        # track readout speed for print_timestats
        self._total_time += time.time() - start_time
        self._total_batches += 1
        if self.keras_mode:
            if info_blob.get("sample_weights"):
                return info_blob["xs"], info_blob["ys"], info_blob["sample_weights"]
            else:
                return info_blob["xs"], info_blob["ys"]
        else:
            return info_blob

    def pad_to_size(self, info_blob):
        """Pad the batch to have a fixed batchsize."""
        org_batchsize = next(iter(info_blob["xs"].values())).shape[0]
        if org_batchsize == self.batchsize:
            return
        # remember the un-padded size so consumers can strip the padding again
        info_blob["org_batchsize"] = org_batchsize
        for input_key, x in info_blob["xs"].items():
            info_blob["xs"][input_key] = _pad_to_size(x, self.batchsize)
        if info_blob.get("ys") is not None:
            for output_key, y in info_blob["ys"].items():
                info_blob["ys"][output_key] = _pad_to_size(y, self.batchsize)

    def open(self):
        """Open all files and prepare for read out."""
        for input_key, file in self.files_dict.items():
            self._files[input_key] = h5py.File(file, "r")
        self._store_file_length()
        self._store_batch_indices()

    def close(self):
        """Close all files again."""
        for f in list(self._files.values()):
            f.close()

    def get_x_values(self, start_index):
        """
        Read one batch of samples from the files and zero center.

        Parameters
        ----------
        start_index : int
            The start index in the h5 files at which the batch will be read.
            The end index will be the start index + the batch size.

        Returns
        -------
        x_values : dict
            One batch of data for each input file.

        """
        x_values = {}
        for input_key, file in self._files.items():
            slc = slice(start_index, start_index + self._batchsize)

            ix_dset_name = _get_indexed_dset_name(file, self.key_x_values)
            if ix_dset_name is None:
                # normal dataset
                x_values[input_key] = file[self.key_x_values][slc]
            else:
                # indexed dataset: adjust slice according to indices
                indices = file[ix_dset_name][slc]
                slc = slice(
                    indices[0]["index"],
                    indices[-1]["index"] + indices[-1]["n_items"],
                )
                x_values[input_key] = (file[self.key_x_values][slc], indices["n_items"])

            if self.xs_mean is not None:
                # zero centering: subtract the precomputed mean image
                x_values[input_key] = np.subtract(
                    x_values[input_key], self.xs_mean[input_key]
                )

        return x_values

    def get_y_values(self, start_index):
        """
        Get y_values for the nn. Since the y_values are hopefully the same
        for all the files, use the ones from the first. TODO add check

        Parameters
        ----------
        start_index : int
            The start index in the h5 files at which the batch will be read.
            The end index will be the start index + the batch size.

        Returns
        -------
        y_values : ndarray
            The y_values, right from the files.

        """
        first_file = list(self._files.values())[0]
        try:
            slc = slice(start_index, start_index + self._batchsize)
            if self.y_field_names is not None and self.phase != "inference":
                # read only the requested fields (faster than the full dataset)
                y_values = first_file[self.key_y_values][
                    (slc,)
                    + tuple(
                        self.y_field_names,
                    )
                ]
                if len(self.y_field_names) == 1:
                    # result of slice is a ndarray; convert to structured
                    y_values = y_values.astype(
                        np.dtype([(self.y_field_names[0], y_values.dtype)])
                    )
            else:
                y_values = first_file[self.key_y_values][slc]
        except KeyError:
            # can not look up y_values, lets hope we dont need them
            y_values = None
        return y_values

    def print_timestats(self, print_func=None):
        """Print stats about how long it took to read batches."""
        if print_func is None:
            print_func = print
        print_func("Statistics of data readout:")
        print_func(f"\tTotal time:\t{self._total_time/60:.2f} min")
        if self._total_batches != 0:
            print_func(
                f"\tPer batch:\t" f"{1000 * self._total_time/self._total_batches:.5} ms"
            )

    def get_file_meta(self):
        """Meta information about the files. Only read out once."""
        if self._file_meta is None:
            self._file_meta = {}
            # sample and label dataset for each input
            datasets = {}
            for input_key, file in self._files.items():
                datasets[input_key] = {
                    "samples": file[self.key_x_values],
                    "samples_is_indexed": _get_indexed_dset_name(
                        file, self.key_x_values
                    )
                    is not None,
                    "labels": file[self.key_y_values],
                }
            self._file_meta["datasets"] = datasets
        return self._file_meta

    @property
    def _size(self):
        """Size of the files that will be read in. Can be smaller than the actual
        file size if defined by user."""
        if self.f_size is None:
            return self._total_f_size
        else:
            return self.f_size

    @property
    def _batchsize(self):
        """
        Return the effective batchsize. Can be smaller than the user defined
        one if it would be larger than the size of the file.
        """
        if self._size < self.batchsize:
            return self._size
        else:
            return self.batchsize

    def _store_file_length(self):
        """
        Make sure all files have the same length and store this length.
        """
        lengths = []
        for f in list(self._files.values()):
            # for indexed datasets, the index dataset defines the sample count
            ix_dset_name = _get_indexed_dset_name(f, self.key_x_values)
            if ix_dset_name is None:
                dset_name = self.key_x_values
            else:
                dset_name = ix_dset_name
            lengths.append(len(f[dset_name]))

        if not lengths.count(lengths[0]) == len(lengths):
            self.close()
            raise ValueError(
                "All data files must have the same length! "
                "Given were:\n " + str(lengths)
            )

        self._total_f_size = lengths[0]

    def _store_batch_indices(self):
        """
        Define the start indices of each batch in the h5 file and store this.
        """
        if self.phase == "inference":
            # for inference: take all batches
            total_no_of_batches = np.ceil(self._size / self._batchsize)
        else:
            # else: skip last batch if it has too few event for a full batch
            # this is mostly because tf datasets can't be used
            # with variable batchsize (status tf 2.5)
            total_no_of_batches = np.floor(self._size / self._batchsize)

        sample_pos = np.arange(int(total_no_of_batches)) * self._batchsize
        if self.shuffle:
            np.random.shuffle(sample_pos)

        self._sample_pos = sample_pos
387def _get_indexed_dset_name(file, dset):
388 """If this is an indexed dataset, return the name of the indexed set."""
389 dset_name_indexed = f"{dset}_indices"
390 if file[dset].attrs.get("indexed") and dset_name_indexed in file:
391 return dset_name_indexed
392 else:
393 return None
396def _get_sample_weights(ys, class_weights):
397 """
398 Produce a weight for each sample given the weight for each class.
400 Parameters
401 ----------
402 ys : dict
403 Maps output names to categorical one-hot labels as np.arrays.
404 Expected to be 2D (n_samples, n_classes).
405 class_weights : dict
406 Maps output neuron numbers to weights as floats.
408 Returns
409 -------
410 sample_weights : dict
411 Maps output names to weights for each sample in the batch as a
412 np.array.
414 """
415 sample_weights = {}
416 for output_name, labels in ys.items():
417 class_weights_arr = np.ones(labels.shape[1])
418 for k, v in class_weights.items():
419 class_weights_arr[int(k)] = v
420 labels_class = np.argmax(labels, axis=-1)
421 sample_weights[output_name] = class_weights_arr[labels_class]
422 return sample_weights
def get_h5_generator(
    orga,
    files_dict,
    f_size=None,
    zero_center=False,
    keras_mode=True,
    shuffle=False,
    use_def_label=True,
    phase="training",
):
    """
    Initialize the hdf5_batch_generator_base with the paramters in orga.cfg.

    Parameters
    ----------
    orga : orcanet.core.Organizer
        Contains all the configurable options in the OrcaNet scripts.
    files_dict : dict
        Pathes of the files to train on.
        Keys: The name of every input (from the toml list file, can be multiple).
        Values: The filepath of a single h5py file to read samples from.
    f_size : int or None
        Specifies the number of samples to be read from the .h5 file.
        If none, the whole .h5 file will be used.
    zero_center : bool
        Whether to use zero centering.
        Requires orga.zero_center_folder to be set.
    keras_mode : bool
        Specifies if mc-infos (y_values) should be yielded as well. The
        mc-infos are used for evaluation after training and testing is finished.
    shuffle : bool
        Randomize the order in which batches are read from the file.
        Significantly reduces read out speed.
    use_def_label : bool
        If True and no label modifier is given by user, use the default
        label modifier instead of none.

    Yields
    ------
    xs : dict
        Data for the model train on.
        Keys : str The name(s) of the input layer(s) of the model.
        Values : ndarray A batch of samples for the corresponding input.
    ys : dict or None
        Labels for the model to train on.
        Keys : str The name(s) of the output layer(s) of the model.
        Values : ndarray A batch of labels for the corresponding output.
        Will be None if there are no labels in the file.
    y_values : ndarray, optional
        Y values from the file. Only yielded if yield_mc_info is True.

    """
    # choose the label modifier: user-defined wins, then the auto one
    # (if allowed), otherwise none at all
    label_modifier = orga.cfg.label_modifier
    if label_modifier is None:
        if use_def_label:
            assert (
                orga._auto_label_modifier is not None
            ), "Auto label modifier has not been set up"
            label_modifier = orga._auto_label_modifier
        else:
            label_modifier = None

    # get xs_mean or load/create if not stored yet
    xs_mean = orga.get_xs_mean() if zero_center else None

    return Hdf5BatchGenerator(
        files_dict=files_dict,
        batchsize=orga.cfg.batchsize,
        key_x_values=orga.cfg.key_x_values,
        key_y_values=orga.cfg.key_y_values,
        sample_modifier=orga.cfg.sample_modifier,
        label_modifier=label_modifier,
        phase=phase,
        xs_mean=xs_mean,
        f_size=f_size,
        keras_mode=keras_mode,
        shuffle=shuffle,
        class_weights=orga.cfg.class_weight,
        fixed_batchsize=orga.cfg.fixed_batchsize,
        y_field_names=orga.cfg.y_field_names,
    )
def make_dataset(gen):
    """Build a tf.data.Dataset that iterates over the given batch generator."""
    # derive the output signature from the first batch of the generator
    first_batch = gen[0]
    output_signature = tuple(
        {name: _get_spec(value) for name, value in part.items()}
        for part in first_batch
    )
    return tf.data.Dataset.from_generator(
        lambda: gen, output_signature=output_signature
    )
def _get_spec(x):
    """Return the tf type spec matching the given tensor or array."""
    if isinstance(x, tf.RaggedTensor):
        return tf.RaggedTensorSpec.from_value(x)
    return tf.TensorSpec(shape=x.shape, dtype=x.dtype)
def _pad_to_size(x, size):
    """Pad x to given size along axis 0 by repeating last element."""
    length = x.shape[0]
    if length > size:
        raise ValueError(f"Can't pad x with shape {x.shape} to length {size}")
    if length == size:
        return x
    # pick the matching concat for tensors vs. ndarrays
    f_conc = tf.concat if tf.is_tensor(x) else np.concatenate
    # repeat the last element until the target size is reached
    return f_conc([x] + [x[-1:]] * (size - length), axis=0)