Coverage for orcasong/modules.py: 90%
240 statements
« prev ^ index » next coverage.py v7.2.7, created at 2024-10-03 18:23 +0000
1"""
2Custom km3pipe modules for making nn input files.
3"""
5import numpy as np
6import km3pipe as kp
7import km3modules as km
8import orcasong.plotting.plot_binstats as plot_binstats
10__author__ = "Stefan Reck"
class McInfoMaker(kp.Module):
    """
    Stores info as float64 in the blob.

    Attributes
    ----------
    extractor : function
        Function to extract the info. Takes the blob as input, outputs
        a dict with the desired mc_infos.
    store_as : str
        Store the mcinfo with this name in the blob.

    """

    def configure(self):
        self.extractor = self.require("extractor")
        self.store_as = self.require("store_as")
        self.to_float64 = self.get("to_float64", default=True)
        self.sort_y = self.get("sort_y", default=True)

    def process(self, blob):
        track = self.extractor(blob)
        if self.sort_y:
            # deterministic column order, sorted by field name
            track = {key: track[key] for key in sorted(track)}
        dtypes = self._get_dtypes(track) if self.to_float64 else None
        kp_hist = kp.dataclasses.Table(
            track, dtype=dtypes, h5loc="y", name="event_info"
        )
        if len(kp_hist) != 1:
            self.log.warning(
                "Warning: Extracted mc_info should have len 1, "
                "but it has len {}".format(len(kp_hist))
            )
        blob[self.store_as] = kp_hist
        return blob

    @staticmethod
    def _get_dtypes(track):
        # Cast everything to float64, except the id fields, which keep
        # the type of their extracted value.
        id_fields = ("group_id", "event_id")
        return [
            (key, type(val) if key in id_fields else np.float64)
            for key, val in track.items()
        ]
class TimePreproc(kp.Module):
    """
    Preprocess the time in the blob in various ways.

    Attributes
    ----------
    add_t0 : bool
        If true, t0 will be added to times of hits.
    center_time : bool
        If true, center hit and mchit times with the time of the first
        triggered hit.

    """

    def configure(self):
        self.add_t0 = self.get("add_t0", default=False)
        self.center_time = self.get("center_time", default=True)

        # messages that have already been printed once (see _print_once)
        self._print_flags = set()

    def process(self, blob):
        # Fix: `not "Hits" in blob` -> `"Hits" not in blob` (idiom),
        # and `log.warn` -> `log.warning` (warn is a deprecated alias).
        if "Hits" not in blob:
            self.log.warning("One event doesn't have hits for some reason. Sad. Skipping.")
            return
        if self.add_t0:
            blob = self.add_t0_time(blob)
        if self.center_time:
            blob = self.center_hittime(blob)
        return blob

    def add_t0_time(self, blob):
        """Add the per-hit t0 offset to the hit times."""
        self._print_once("Adding t0 to hit times")
        blob["Hits"].time = np.add(blob["Hits"].time, blob["Hits"].t0)
        return blob

    def center_hittime(self, blob):
        """Subtract the time of the first triggered hit from all (mc)hit times.

        NOTE(review): raises on events without any triggered hit
        (np.min of an empty array) — presumably such events don't occur here.
        """
        hits_time = blob["Hits"].time
        hits_triggered = blob["Hits"].triggered
        t_first_trigger = np.min(hits_time[hits_triggered != 0])

        self._print_once("Centering time of Hits with first triggered hit")
        blob["Hits"].time = np.subtract(hits_time, t_first_trigger)

        if "McHits" in blob:
            self._print_once("Centering time of McHits with first triggered hit")
            mchits_time = blob["McHits"].time
            blob["McHits"].time = np.subtract(mchits_time, t_first_trigger)

        return blob

    def _print_once(self, text):
        """Print the given message only the first time it is seen."""
        if text not in self._print_flags:
            self._print_flags.add(text)
            self.cprint(text)
class ImageMaker(kp.Module):
    """
    Make a n-d histogram from "Hits", and store it in the blob as 'samples'.

    Attributes
    ----------
    bin_edges_list : List
        List with the names of the fields to bin, and the respective bin edges,
        including the left- and right-most bin edge.
    hit_weights : str, optional
        Use blob["Hits"][hit_weights] as weights for samples in histogram.

    """

    def configure(self):
        self.bin_edges_list = self.require("bin_edges_list")
        self.hit_weights = self.get("hit_weights")
        self.store_as = "samples"

    def process(self, blob):
        hits = blob["Hits"]
        # one column of data + one set of edges per binning field
        data = [hits[field] for field, _ in self.bin_edges_list]
        bins = [edges for _, edges in self.bin_edges_list]
        name = "".join(field + "_" for field, _ in self.bin_edges_list)

        weights = hits[self.hit_weights] if self.hit_weights is not None else None

        histogram = np.histogramdd(data, bins=bins, weights=weights)[0]

        # add a leading batch axis of size 1
        # NOTE(review): uint8 wraps around for bins with >255 hits — confirm
        # that counts this large cannot occur with the chosen binning.
        hist_one_event = histogram[np.newaxis, ...].astype(np.uint8)
        kp_hist = kp.dataclasses.NDArray(
            hist_one_event, h5loc="x", title=name + "event_images"
        )

        blob[self.store_as] = kp_hist
        return blob
class BinningStatsMaker(kp.Module):
    """
    Generate a histogram of the number of hits for each binning field name.

    E.g. if the bin_edges_list contains "pos_z", this will make a histogram
    of #Hits vs. "pos_z", together with how many hits were outside
    of the bin edges in both directions.

    Per default, the resolution of the histogram (width of bins) will be
    higher then the given bin edges, and the edges will be stored seperatly.
    The time is the exception: The plotted bins have exactly the
    given bin edges.

    Attributes
    ----------
    bin_edges_list : List
        List with the names of the fields to bin, and the respective bin edges,
        including the left- and right-most bin edge.
    res_increase : int
        Increase the number of bins by this much in the hists (so that one
        can see if the edges have been placed correctly). Is never used
        for the time binning (field name "time").
    bin_plot_freq : int
        Extract data for the histograms only every given number of blobs
        (reduces time the pipeline takes to complete).

    """

    def configure(self):
        self.bin_edges_list = self.require("bin_edges_list")
        self.res_increase = self.get("res_increase", default=5)
        # Fix: was hard-coded to 1 although the docstring documents this
        # as a configurable attribute; default 1 keeps old behavior.
        self.bin_plot_freq = self.get("bin_plot_freq", default=1)

        self.hists = {}
        for bin_name, org_bin_edges in self.bin_edges_list:
            # dont space bin edges for time
            if bin_name == "time":
                bin_edges = org_bin_edges
            else:
                bin_edges = self._space_bin_edges(org_bin_edges)

            self.hists[bin_name] = {
                "hist": np.zeros(len(bin_edges) - 1),
                "hist_bin_edges": bin_edges,
                "bin_edges": org_bin_edges,
                # below smallest edge, above largest edge:
                "cut_off": np.zeros(2),
            }

        # blob counter, used together with bin_plot_freq
        self.i = 0

    def _space_bin_edges(self, bin_edges):
        """
        Increase resolution of given binning.
        """
        increased_n_bins = (len(bin_edges) - 1) * self.res_increase + 1
        bin_edges = np.linspace(bin_edges[0], bin_edges[-1], increased_n_bins)

        return bin_edges

    def process(self, blob):
        """
        Extract data from blob for the hist plots.
        """
        if self.i % self.bin_plot_freq == 0:
            for bin_name, hists_data in self.hists.items():
                hist_bin_edges = hists_data["hist_bin_edges"]

                hits = blob["Hits"]
                data = hits[bin_name]
                # get how much is cut off due to these limits
                out_pos = data[data > np.max(hist_bin_edges)].size
                out_neg = data[data < np.min(hist_bin_edges)].size

                # get all hits which are not cut off by other bin edges
                data = hits[bin_name][self._is_in_limits(hits, excluded=bin_name)]
                hist = np.histogram(data, bins=hist_bin_edges)[0]

                self.hists[bin_name]["hist"] += hist
                self.hists[bin_name]["cut_off"] += np.array([out_neg, out_pos])

        self.i += 1
        return blob

    def finish(self):
        """
        Append the hists, which are the stats of the binning.

        Its a dict with each binning field name containing the following
        ndarrays:

        bin_edges : The actual bin edges.
        cut_off : How many events were cut off in positive and negative
            direction due to this binning.
        hist_bin_edges : The bin edges for the plot in finer resolution then
            the actual bin edges.
        hist : The number of hist in each bin of the hist_bin_edges.

        """
        return self.hists

    def _is_in_limits(self, hits, excluded=None):
        """Get which hits are in the limits defined by ALL bin edges
        (except for given one)."""
        inside = None
        for dfield, edges in self.bin_edges_list:
            if dfield == excluded:
                continue
            is_in = np.logical_and(
                hits[dfield] >= min(edges), hits[dfield] <= max(edges)
            )
            if inside is None:
                inside = is_in
            else:
                inside = np.logical_and(inside, is_in)
        return inside
class PointMaker(kp.Module):
    """
    Store individual hit info from "Hits" in the blob as 'samples'.

    Used for graph networks.

    Attributes
    ----------
    hit_infos : tuple, optional
        Which entries in the '/Hits' Table will be kept. E.g. pos_x, time, ...
        Default: Keep all entries.
    time_window : tuple, optional
        Two ints (start, end). Hits outside of this time window will be cut
        away (based on 'Hits/time'). Default: Keep all hits.
    only_triggered_hits : bool
        If true, use only triggered hits. Otherwise, use all hits (default).
    max_n_hits : int
        Maximum number of hits that gets saved per event. If an event has
        more, some will get cut randomly! Default: Keep all hits.
    fixed_length : bool
        If False (default), save hits of events with variable length as
        2d arrays using km3pipe's indices.
        If True, pad hits of each event with 0s to a fixed length,
        so that they can be stored as 3d arrays like images.
        max_n_hits needs to be given in that case, and a column will be
        added called 'is_valid', which is 0 if the entry is padded,
        and 1 otherwise.
        This is inefficient and will cut off hits, so it should not be used.
    dset_n_hits : str, optional
        If given, store the number of hits that are in the time window
        as a new column called 'n_hits_intime' in the dataset with
        this name (usually this is EventInfo).

    """

    def configure(self):
        self.hit_infos = self.get("hit_infos", default=None)
        self.time_window = self.get("time_window", default=None)
        self.only_triggered_hits = self.get("only_triggered_hits", default=False)
        self.max_n_hits = self.get("max_n_hits", default=None)
        self.fixed_length = self.get("fixed_length", default=False)
        self.dset_n_hits = self.get("dset_n_hits", default=None)
        self.store_as = "samples"

    def process(self, blob):
        if self.fixed_length and self.max_n_hits is None:
            raise ValueError("Have to specify max_n_hits if fixed_length is True")
        if self.hit_infos is None:
            # default: keep every field of the Hits table
            self.hit_infos = blob["Hits"].dtype.names
        points, n_hits = self.get_points(blob)
        blob[self.store_as] = kp.NDArray(points, h5loc="x", title="nodes")
        if self.dset_n_hits:
            blob[self.dset_n_hits] = blob[self.dset_n_hits].append_columns(
                "n_hits_intime", n_hits
            )
        return blob

    def get_points(self, blob):
        """
        Get the desired hit infos from the blob.

        Returns
        -------
        points : np.array
            The hit infos of this event as a 2d matrix. No of rows are
            fixed to the given max_n_hits. Each of the self.extract_keys,
            is in one column + an additional column which is 1 for
            actual hits, and 0 for if its a padded row.
        n_hits : int
            Number of hits in the given time window.
            Can be stored as n_hits_intime.

        """
        hits = blob["Hits"]
        if self.only_triggered_hits:
            hits = hits[hits.triggered != 0]
        if self.time_window is not None:
            # drop hits outside of the given time window
            t_start, t_end = self.time_window[0], self.time_window[1]
            hits = hits[
                np.logical_and(hits["time"] >= t_start, hits["time"] <= t_end)
            ]

        n_hits = len(hits)
        if self.max_n_hits is not None and n_hits > self.max_n_hits:
            # too many hits: pick a random subset, but keep the original order
            order = np.arange(n_hits)
            np.random.shuffle(order)
            keep = order[: self.max_n_hits]
            keep.sort()
            hits = hits[keep]

        n_columns = len(self.hit_infos)
        if self.fixed_length:
            points = np.zeros((self.max_n_hits, n_columns + 1), dtype="float32")
            for col, field in enumerate(self.hit_infos):
                points[:n_hits, col] = hits[field]
            # last column marks real hits (1) vs padded rows (0)
            points[:n_hits, -1] = 1.0
            # store along new axis
            points = points[np.newaxis, ...]
        else:
            # TODO points should be a Table, not a ndarray
            points = np.zeros((len(hits), n_columns), dtype="float32")
            for col, field in enumerate(self.hit_infos):
                points[:, col] = hits[field]
        return points, n_hits

    def finish(self):
        """Return the column names of the produced point arrays."""
        columns = tuple(self.hit_infos)
        if self.fixed_length:
            columns += ("is_valid",)
        return {"hit_infos": columns}
class EventSkipper(kp.Module):
    """
    Skip events based on blob content.

    Attributes
    ----------
    event_skipper : callable
        Function that takes the blob as an input, and returns a bool.
        If the bool is true, the blob will be skipped.

    """

    def configure(self):
        self.event_skipper = self.require("event_skipper")
        self._not_skipped = 0
        self._skipped = 0

    def process(self, blob):
        # returning None drops the blob from the pipeline
        if self.event_skipper(blob):
            self._skipped += 1
            return None
        self._not_skipped += 1
        return blob

    def finish(self):
        tot_events = self._not_skipped + self._skipped
        self.cprint(
            f"Skipped {self._skipped}/{tot_events} events "
            f"({self._skipped/tot_events:.4%})."
        )
class DetApplier(kp.Module):
    """
    Apply detector information to the event data from a detx file, e.g.
    calibrating hits.

    Attributes
    ----------
    det_file : str
        Path to a .detx detector geometry file.
    calib_hits : bool
        Apply calibration to hits. Default: True.
    calib_mchits : bool
        Apply calibration to mchits, if mchits are in the blob. Default: True.
    correct_timeslew : bool
        If true (default), the time slewing of hits depending on their tot
        will be corrected. Only done if calib_hits is True.
    center_hits_to : tuple, optional
        Translate the xyz positions of the hits (and mchits), as if
        the detector was centered at the given position.
        E.g., if its (0, 0, None), the hits and mchits will be
        centered at xy = 00, and z will be left untouched.

    """

    def configure(self):
        self.det_file = self.require("det_file")
        self.correct_timeslew = self.get("correct_timeslew", default=True)
        self.calib_hits = self.get("calib_hits", default=True)
        self.calib_mchits = self.get("calib_mchits", default=True)
        self.center_hits_to = self.get("center_hits_to", default=None)

        self.cprint(f"Calibrating with {self.det_file}")
        self.calib = kp.calib.Calibration(filename=self.det_file)
        # used to warn only once if the input looks already calibrated
        self._calib_checked = False

        # dict dim_name: float
        self._vector_shift = None
        if self.center_hits_to:
            self._cache_shift_center()

    def process(self, blob):
        if self.calib_hits:
            # Fix: `is False` comparison -> `not` (idiom),
            # and `log.warn` -> `log.warning` (warn is a deprecated alias).
            if not self._calib_checked:
                if "pos_x" in blob["Hits"]:
                    self.log.warning(
                        "Warning: Using a det file, but pos_x in Hits detected. "
                        "Is the file already calibrated? This might lead to "
                        "errors with t0."
                    )
                self._calib_checked = True

            blob["Hits"] = self.calib.apply(
                blob["Hits"], correct_slewing=self.correct_timeslew
            )
        if self.calib_mchits and "McHits" in blob:
            blob["McHits"] = self.calib.apply(blob["McHits"])
        if self.center_hits_to:
            self.shift_hits(blob)
        return blob

    def shift_hits(self, blob):
        """Translate hits by cached vector."""
        for dim_name in ("pos_x", "pos_y", "pos_z"):
            blob["Hits"][dim_name] += self._vector_shift[dim_name]
            if "McHits" in blob:
                blob["McHits"][dim_name] += self._vector_shift[dim_name]

    def _cache_shift_center(self):
        """Compute and store the per-dimension shift that moves the
        detector center to ``center_hits_to``."""
        det_center, shift = {}, {}
        for i, dim_name in enumerate(("pos_x", "pos_y", "pos_z")):
            center = self.calib.detector.dom_table[dim_name].mean()
            det_center[dim_name] = center

            if self.center_hits_to[i] is None:
                # None means: leave this dimension untouched
                shift[dim_name] = 0
            else:
                shift[dim_name] = self.center_hits_to[i] - center
        self._vector_shift = shift
        self.cprint(f"original detector center: {det_center}")
        self.cprint(f"shift for hits: {self._vector_shift}")
class HitRotator(kp.Module):
    """
    Rotates hits by angle theta.

    Attributes
    ----------
    theta : float
        Angle by which hits are rotated (radian).

    """

    def configure(self):
        self.theta = self.require("theta")

    def process(self, blob):
        """Rotate the x/y positions of all hits by theta around the origin.

        Fix: replaced the per-hit Python loop with a 2x2 matrix product
        over the vectorized rotation formula — same results, O(n) numpy work
        instead of n Python-level matrix multiplications.
        """
        x = blob["Hits"]["x"]
        y = blob["Hits"]["y"]

        cos_t = np.cos(self.theta)
        sin_t = np.sin(self.theta)

        # standard counter-clockwise rotation:
        # [x']   [cos -sin] [x]
        # [y'] = [sin  cos] [y]
        x_rot = cos_t * x - sin_t * y
        y_rot = sin_t * x + cos_t * y

        blob["Hits"]["x"] = x_rot
        blob["Hits"]["y"] = y_rot

        return blob