Coverage for orcasong/core.py: 84%
140 statements
« prev ^ index » next coverage.py v7.2.7, created at 2024-10-03 18:23 +0000
1import os
2import warnings
3from abc import abstractmethod
4import h5py
5import km3pipe as kp
6import km3modules as km
8import orcasong
9import orcasong.modules as modules
10import orcasong.plotting.plot_binstats as plot_binstats
13__author__ = "Stefan Reck"
class BaseProcessor:
    """
    Preprocess km3net/antares events for neural networks.

    This serves as a baseclass, which handles things like reading
    events, calibrating, generating labels and saving the output.

    Parameters
    ----------
    extractor : function, optional
        Function that extracts desired info from a blob, which is then
        stored as the "y" datafield in the .h5 file.
        The function takes the km3pipe blob as an input, and returns
        a dict mapping str to floats.
        Examples can be found in orcasong.extractors.
    det_file : str, optional
        Path to a .detx detector geometry file, which can be used to
        calibrate the hits.
    correct_mc_time : bool
        Convert MC hit times to JTE times. Will only be done if
        mc_hits and mc_tracks are there.
    center_time : bool
        Subtract time of first triggered hit from all hit times. Will
        also be done for McHits if they are in the blob [default: True].
    calib_hits : bool
        Apply calibration to hits if det file is given. Default: True.
    calib_mchits : bool
        Apply calibration to mchits if det file is given and mchits are
        found in the blob. Default: True.
    correct_timeslew : bool
        If true (default), the time slewing of hits depending on their tot
        will be corrected during calibration.
        Only done if det file is given and calib_hits is True.
    center_hits_to : tuple, optional
        Translate the xyz positions of the hits (and mchits), as if
        the detector was centered at the given position.
        E.g., if its (0, 0, None), the hits and mchits will be
        centered at xy = 00, and z will be left untouched.
        Can only be used when a detx file is given.
    add_t0 : bool
        If true, add t0 to the time of hits and mchits. If using a
        det_file, this will already have been done automatically
        [default: False].
    event_skipper : func, optional
        Function that takes the blob as an input, and returns a bool.
        If the bool is true, the blob will be skipped.
        This is placed after the binning and mc_info extractor.
    chunksize : int, optional
        Chunksize (along axis_0) used for saving the output
        to a .h5 file [default: None, i.e. auto chunking].
    keep_event_info : bool
        If True, will keep the "event_info" table [default: False].
    overwrite : bool
        If True, overwrite the output file if it exists already.
        If False, throw an error instead.
    sort_y : bool
        Sort the columns in the y dataset alphabetically.
    y_to_float64 : bool
        Convert everything in the y dataset to float 64 (Default: True).
        Hint: Not all other dtypes can store nan!

    Attributes
    ----------
    n_statusbar : int or None
        Print a statusbar every n blobs.
    n_memory_observer : int or None
        Print memory usage every n blobs.
    complib : str
        Compression library used for saving the output to a .h5 file.
        All PyTables compression filters are available, e.g. 'zlib',
        'lzf', 'blosc', ... .
    complevel : int
        Compression level for the compression filter that is used for
        saving the output to a .h5 file.
    flush_frequency : int
        After how many events the accumulated output should be flushed to
        the harddisk.
        A larger value leads to a faster orcasong execution,
        but it increases the RAM usage as well.
    seed : int or None
        Makes all random (numpy) actions reproducable. Set at the start of
        each pipeline. Set to None to disable seeding.

    """

    def __init__(
        self,
        extractor=None,
        det_file=None,
        correct_mc_time=True,
        center_time=True,
        calib_hits=True,
        calib_mchits=True,
        add_t0=False,
        correct_timeslew=True,
        center_hits_to=None,
        event_skipper=None,
        chunksize=None,
        keep_event_info=False,
        overwrite=True,
        sort_y=True,
        y_to_float64=True,
    ):
        # Centering requires the detector geometry, so fail early and loudly
        if center_hits_to is not None and det_file is None:
            raise ValueError("det_file has to be given when using center_hits_to")
        self.extractor = extractor
        self.det_file = det_file
        self.correct_mc_time = correct_mc_time
        self.center_time = center_time
        self.calib_hits = calib_hits
        self.calib_mchits = calib_mchits
        self.add_t0 = add_t0
        self.correct_timeslew = correct_timeslew
        self.center_hits_to = center_hits_to
        self.event_skipper = event_skipper
        self.chunksize = chunksize
        self.keep_event_info = keep_event_info
        self.overwrite = overwrite
        self.sort_y = sort_y
        self.y_to_float64 = y_to_float64

        # Tunable processing options (see class docstring, Attributes section)
        self.n_statusbar = 1000
        self.n_memory_observer = 1000
        self.complib = "zlib"
        self.complevel = 1
        self.flush_frequency = 1000
        self.seed = 42

    def run(self, infile, outfile=None):
        """
        Process the events from the infile, and save them to the outfile.

        Parameters
        ----------
        infile : str
            Path to the input file.
        outfile : str, optional
            Path to the output file (will be created). If none is given,
            will auto generate the name and save it in the cwd.

        Raises
        ------
        FileExistsError
            If overwrite is False and the output file exists already.

        """
        if outfile is None:
            outfile = os.path.join(
                os.getcwd(),
                "{}_dl.h5".format(os.path.splitext(os.path.basename(infile))[0]),
            )
        if not self.overwrite:
            if os.path.isfile(outfile):
                raise FileExistsError(f"File exists: {outfile}")
        # Bugfix: compare against None instead of truthiness, so that
        # seed=0 (a perfectly valid seed) does not silently disable seeding
        if self.seed is not None:
            km.GlobalRandomState(seed=self.seed)
        pipe = self.build_pipe(infile, outfile)
        summary = pipe.drain()
        # re-open the finished file to attach post-run metadata
        with h5py.File(outfile, "a") as f:
            self.finish_file(f, summary)

    def run_multi(self, infiles, outfolder):
        """
        Process multiple files into their own output files each.
        The output file names will be generated automatically.

        Parameters
        ----------
        infiles : List
            The path to infiles as str.
        outfolder : str
            The output folder to place them in.

        Returns
        -------
        outfiles : List
            Paths of the generated output files.

        """
        outfiles = []
        for infile in infiles:
            outfile = os.path.join(
                outfolder, f"{os.path.splitext(os.path.basename(infile))[0]}_dl.h5"
            )
            outfiles.append(outfile)
            self.run(infile, outfile)
        return outfiles

    def build_pipe(self, infile, outfile, timeit=True):
        """Initialize and connect the modules from the different stages."""
        components = [
            *self.get_cmpts_pre(infile=infile),
            *self.get_cmpts_main(),
            *self.get_cmpts_post(outfile=outfile),
        ]
        pipe = kp.Pipeline(timeit=timeit)
        if self.n_statusbar is not None:
            pipe.attach(km.common.StatusBar, every=self.n_statusbar)
        if self.n_memory_observer is not None:
            pipe.attach(km.common.MemoryObserver, every=self.n_memory_observer)
        for cmpt, kwargs in components:
            pipe.attach(cmpt, **kwargs)
        return pipe

    def get_cmpts_pre(self, infile):
        """Modules that read and calibrate the events."""
        cmpts = [(kp.io.hdf5.HDF5Pump, {"filename": infile})]

        if self.correct_mc_time:
            # only possible if the mc info is actually present in the file
            with h5py.File(infile, "r") as f:
                if "mc_hits" in f and "mc_tracks" in f:
                    cmpts.append((km.mc.MCTimeCorrector, {}))
                else:
                    warnings.warn("Can not correct mc time: mc_hits "
                                  "and/or mc_tracks not found!")

        if self.det_file:
            cmpts.append(
                (
                    modules.DetApplier,
                    {
                        "det_file": self.det_file,
                        "correct_timeslew": self.correct_timeslew,
                        "center_hits_to": self.center_hits_to,
                        "calib_hits": self.calib_hits,
                        "calib_mchits": self.calib_mchits,
                    },
                )
            )

        if any((self.center_time, self.add_t0)):
            cmpts.append(
                (
                    modules.TimePreproc,
                    {"add_t0": self.add_t0, "center_time": self.center_time},
                )
            )
        return cmpts

    @abstractmethod
    def get_cmpts_main(self):
        """Produce and store the samples as 'samples' in the blob."""
        raise NotImplementedError

    def get_cmpts_post(self, outfile):
        """Modules that postproc and save the events."""
        cmpts = []
        if self.extractor is not None:
            cmpts.append(
                (
                    modules.McInfoMaker,
                    {
                        "extractor": self.extractor,
                        "to_float64": self.y_to_float64,
                        "sort_y": self.sort_y,
                        "store_as": "mc_info",
                    },
                )
            )

        if self.event_skipper is not None:
            cmpts.append((modules.EventSkipper, {"event_skipper": self.event_skipper}))

        # drop everything that is not supposed to end up in the output file
        keys_keep = ["samples", "mc_info", "header", "raw_header"]
        if self.keep_event_info:
            keys_keep.append("EventInfo")
        cmpts.append((km.common.Keep, {"keys": keys_keep}))

        cmpts.append(
            (
                kp.io.HDF5Sink,
                {
                    "filename": outfile,
                    "complib": self.complib,
                    "complevel": self.complevel,
                    "chunksize": self.chunksize,
                    "flush_frequency": self.flush_frequency,
                },
            )
        )
        return cmpts

    def finish_file(self, f, summary):
        """
        Work with the output file after the pipe has finished.

        Parameters
        ----------
        f : h5py.File
            The opened output file.
        summary : km3pipe.Blob
            The output from pipe.drain().

        """
        # Add current orcasong version to h5 file
        f.attrs.create("orcasong", orcasong.__version__)
class FileBinner(BaseProcessor):
    """
    For making binned images and mc_infos, which can be used for conv nets.

    Can also add statistics of the binning to the h5 files, which can
    be plotted to show the distribution of hits among the bins and how
    many hits were cut off.

    Parameters
    ----------
    bin_edges_list : List
        List with the names of the fields to bin, and the respective bin
        edges, including the left- and right-most bin edge.
        Example: For 10 bins in the z direction, and 100 bins in time:
            bin_edges_list = [
                ["pos_z", np.linspace(0, 10, 11)],
                ["time", np.linspace(-50, 550, 101)],
            ]
        Some examples can be found in orcasong.bin_edges.
    add_bin_stats : bool
        Add statistics of the binning to the output file. They can be
        plotted with util/bin_stats_plot.py [default: True].
    hit_weights : str, optional
        Use blob["Hits"][hit_weights] as weights for samples in histogram.
    chunksize : int
        Chunksize (along axis_0) used for saving the output [default: 32].
    kwargs
        Options of the BaseProcessor.

    """

    def __init__(
        self,
        bin_edges_list,
        add_bin_stats=True,
        hit_weights=None,
        chunksize=32,
        **kwargs,
    ):
        self.bin_edges_list = bin_edges_list
        self.add_bin_stats = add_bin_stats
        self.hit_weights = hit_weights
        super().__init__(chunksize=chunksize, **kwargs)

    def get_cmpts_main(self):
        """Generate nD images."""
        cmpts = []
        if self.add_bin_stats:
            cmpts.append(
                (modules.BinningStatsMaker, {"bin_edges_list": self.bin_edges_list})
            )
        cmpts.append(
            (
                modules.ImageMaker,
                {
                    "bin_edges_list": self.bin_edges_list,
                    "hit_weights": self.hit_weights,
                },
            )
        )
        return cmpts

    def finish_file(self, f, summary):
        """Store the binning statistics in the output file."""
        super().finish_file(f, summary)
        if self.add_bin_stats:
            plot_binstats.add_hists_to_h5file(summary["BinningStatsMaker"], f)

    def run_multi(self, infiles, outfolder, save_plot=False):
        """
        Bin multiple files into their own output files each.
        The output file names will be generated automatically.

        Parameters
        ----------
        infiles : List
            The path to infiles as str.
        outfolder : str
            The output folder to place them in.
        save_plot : bool
            Save the binning hists as a pdf. Only possible if add_bin_stats
            is True.

        Returns
        -------
        outfiles : List
            Paths of the generated output files.

        """
        if save_plot and not self.add_bin_stats:
            raise ValueError("Can not make plot when add_bin_stats is False")

        name, shape = self.get_names_and_shape()
        print("Generating {} images with shape {}".format(name, shape))

        outfiles = super().run_multi(infiles=infiles, outfolder=outfolder)

        if save_plot:
            # Bugfix: use os.path.join instead of plain string concatenation,
            # which produced a wrong path when outfolder had no trailing "/"
            plot_binstats.plot_hist_of_files(
                files=outfiles, save_as=os.path.join(outfolder, "binning_hist.pdf")
            )
        return outfiles

    def get_names_and_shape(self):
        """
        Get names and shape of the resulting x data,
        e.g. (pos_z, time), (18, 50).
        """
        names, shape = [], []
        for bin_name, bin_edges in self.bin_edges_list:
            names.append(bin_name)
            # n bins = n edges - 1
            shape.append(len(bin_edges) - 1)
        return tuple(names), tuple(shape)

    def __repr__(self):
        return "<FileBinner: {} {}>".format(*self.get_names_and_shape())
class FileGraph(BaseProcessor):
    """
    Turn km3 events to graph data.

    The resulting file will have a dataset "x" of shape
    (total n_hits, len(hit_infos)).
    The column names of the last axis (i.e. hit_infos) are saved
    as attributes of the dataset (f["x"].attrs).

    Parameters
    ----------
    hit_infos : tuple, optional
        Which entries in the '/Hits' Table will be kept. E.g. pos_x, time, ...
        Often, only dir_x/y/z, pos_x/y/z and time are required.
        Default: Keep all entries.
    time_window : tuple, optional
        Two ints (start, end). Hits outside of this time window will be cut
        away (based on 'Hits/time'). Default: Keep all hits.
    only_triggered_hits : bool
        If true, use only triggered hits. Otherwise, use all hits (default).
    max_n_hits : int
        Maximum number of hits that gets saved per event. If an event has
        more, some will get cut randomly! Default: Keep all hits.
    fixed_length : bool
        Legacy option.
        If False (default), save hits of events with variable length as
        2d arrays using km3pipe's indices.
        If True, pad hits of each event with 0s to a fixed length,
        so that they can be stored as 3d arrays like images.
        max_n_hits needs to be given in that case, and a column will be
        added called 'is_valid', which is 0 if the entry is padded,
        and 1 otherwise.
        This is inefficient and will cut off hits, so it should not be used.
    kwargs
        Options of the BaseProcessor.

    """

    def __init__(
        self,
        max_n_hits=None,
        time_window=None,
        hit_infos=None,
        only_triggered_hits=False,
        fixed_length=False,
        **kwargs,
    ):
        self.max_n_hits = max_n_hits
        self.time_window = time_window
        self.hit_infos = hit_infos
        self.only_triggered_hits = only_triggered_hits
        self.fixed_length = fixed_length
        super().__init__(**kwargs)

    def get_cmpts_main(self):
        """Attach the PointMaker, which produces the hit-level samples."""
        point_maker_config = {
            "max_n_hits": self.max_n_hits,
            "fixed_length": self.fixed_length,
            "time_window": self.time_window,
            "hit_infos": self.hit_infos,
            "dset_n_hits": "EventInfo",
            "only_triggered_hits": self.only_triggered_hits,
        }
        return [(modules.PointMaker, point_maker_config)]

    def finish_file(self, f, summary):
        """Store the column names and storage mode as attributes of 'x'."""
        super().finish_file(f, summary)
        columns = summary["PointMaker"]["hit_infos"]
        for col_no, col_name in enumerate(columns):
            f["x"].attrs.create(f"hit_info_{col_no}", col_name)
        f["x"].attrs.create("indexed", not self.fixed_length)