Coverage for orcasong/modules.py: 90%

240 statements  

« prev     ^ index     » next       coverage.py v7.2.7, created at 2024-10-03 18:23 +0000

1""" 

2Custom km3pipe modules for making nn input files. 

3""" 

4 

5import numpy as np 

6import km3pipe as kp 

7import km3modules as km 

8import orcasong.plotting.plot_binstats as plot_binstats 

9 

10__author__ = "Stefan Reck" 

11 

12 

class McInfoMaker(kp.Module):
    """
    Store extracted event info as a km3pipe Table in the blob.

    Attributes
    ----------
    extractor : function
        Function to extract the info. Takes the blob as input, outputs
        a dict with the desired mc_infos.
    store_as : str
        Store the mcinfo with this name in the blob.
    to_float64 : bool
        If true (default), cast all fields except 'group_id' and
        'event_id' to float64.
    sort_y : bool
        If true (default), sort the extracted dict by key name.

    """

    def configure(self):
        self.extractor = self.require("extractor")
        self.store_as = self.require("store_as")
        self.to_float64 = self.get("to_float64", default=True)
        self.sort_y = self.get("sort_y", default=True)

    def process(self, blob):
        info = self.extractor(blob)
        if self.sort_y:
            info = {key: info[key] for key in sorted(info)}
        if self.to_float64:
            # ids keep their original type, everything else becomes float64
            dtypes = [
                (key, type(val) if key in ("group_id", "event_id") else np.float64)
                for key, val in info.items()
            ]
        else:
            dtypes = None
        table = kp.dataclasses.Table(
            info, dtype=dtypes, h5loc="y", name="event_info"
        )
        if len(table) != 1:
            self.log.warning(
                "Warning: Extracted mc_info should have len 1, "
                "but it has len {}".format(len(table))
            )
        blob[self.store_as] = table
        return blob

56 

57 

class TimePreproc(kp.Module):
    """
    Preprocess the time in the blob in various ways.

    Attributes
    ----------
    add_t0 : bool
        If true, t0 will be added to times of hits.
    center_time : bool
        If true, center hit and mchit times with the time of the first
        triggered hit.

    """

    def configure(self):
        self.add_t0 = self.get("add_t0", default=False)
        self.center_time = self.get("center_time", default=True)

        # messages already emitted, so each one is printed only once
        # instead of once per event
        self._print_flags = set()

    def process(self, blob):
        # fix: PEP 8 'not in' and logging's 'warning' instead of the
        # deprecated 'warn' alias
        if "Hits" not in blob:
            self.log.warning("One event doesn't have hits for some reason. Sad. Skipping.")
            return
        if self.add_t0:
            blob = self.add_t0_time(blob)
        if self.center_time:
            blob = self.center_hittime(blob)
        return blob

    def add_t0_time(self, blob):
        """Add the per-hit t0 calibration offset to the hit times (in place)."""
        self._print_once("Adding t0 to hit times")
        blob["Hits"].time = np.add(blob["Hits"].time, blob["Hits"].t0)
        return blob

    def center_hittime(self, blob):
        """Shift hit (and mchit) times so the first triggered hit is at t=0."""
        hits_time = blob["Hits"].time
        hits_triggered = blob["Hits"].triggered
        # NOTE: raises on events without any triggered hit (empty selection)
        t_first_trigger = np.min(hits_time[hits_triggered != 0])

        self._print_once("Centering time of Hits with first triggered hit")
        blob["Hits"].time = np.subtract(hits_time, t_first_trigger)

        if "McHits" in blob:
            self._print_once("Centering time of McHits with first triggered hit")
            mchits_time = blob["McHits"].time
            blob["McHits"].time = np.subtract(mchits_time, t_first_trigger)

        return blob

    def _print_once(self, text):
        """Print the given text only the first time it is seen."""
        if text not in self._print_flags:
            self._print_flags.add(text)
            self.cprint(text)

112 

113 

class ImageMaker(kp.Module):
    """
    Make a n-d histogram from "Hits", and store it in the blob as 'samples'.

    Attributes
    ----------
    bin_edges_list : List
        List with the names of the fields to bin, and the respective bin edges,
        including the left- and right-most bin edge.
    hit_weights : str, optional
        Use blob["Hits"][hit_weights] as weights for samples in histogram.

    """

    def configure(self):
        self.bin_edges_list = self.require("bin_edges_list")
        self.hit_weights = self.get("hit_weights")
        self.store_as = "samples"

    def process(self, blob):
        hits = blob["Hits"]
        samples = [hits[field] for field, _ in self.bin_edges_list]
        edges = [bin_edges for _, bin_edges in self.bin_edges_list]
        # e.g. "pos_x_pos_y_time_" for fields pos_x, pos_y, time
        name = "".join(field + "_" for field, _ in self.bin_edges_list)

        weights = hits[self.hit_weights] if self.hit_weights is not None else None

        histogram = np.histogramdd(samples, bins=edges, weights=weights)[0]

        # prepend a batch axis of length 1 (one event per blob)
        # NOTE(review): the uint8 cast wraps for bins with >255 hits —
        # presumably fine for the expected occupancy, but confirm.
        hist_one_event = histogram[np.newaxis, ...].astype(np.uint8)
        kp_hist = kp.dataclasses.NDArray(
            hist_one_event, h5loc="x", title=name + "event_images"
        )

        blob[self.store_as] = kp_hist
        return blob

155 

156 

class BinningStatsMaker(kp.Module):
    """
    Generate a histogram of the number of hits for each binning field name.

    E.g. if the bin_edges_list contains "pos_z", this will make a histogram
    of #Hits vs. "pos_z", together with how many hits were outside
    of the bin edges in both directions.

    Per default, the resolution of the histogram (width of bins) will be
    higher than the given bin edges, and the edges will be stored separately.
    The time is the exception: The plotted bins have exactly the
    given bin edges.

    Attributes
    ----------
    bin_edges_list : List
        List with the names of the fields to bin, and the respective bin edges,
        including the left- and right-most bin edge.
    res_increase : int
        Increase the number of bins by this much in the hists (so that one
        can see if the edges have been placed correctly). Is never used
        for the time binning (field name "time").
    bin_plot_freq : int
        Extract data for the histograms only every given number of blobs
        (reduces time the pipeline takes to complete).

    """

    def configure(self):
        self.bin_edges_list = self.require("bin_edges_list")
        self.res_increase = self.get("res_increase", default=5)
        self.bin_plot_freq = 1

        # one entry per binning field, filled up during process()
        self.hists = {}
        for bin_name, org_bin_edges in self.bin_edges_list:
            # dont space bin edges for time
            if bin_name == "time":
                bin_edges = org_bin_edges
            else:
                bin_edges = self._space_bin_edges(org_bin_edges)

            self.hists[bin_name] = {
                "hist": np.zeros(len(bin_edges) - 1),
                "hist_bin_edges": bin_edges,
                "bin_edges": org_bin_edges,
                # below smallest edge, above largest edge:
                "cut_off": np.zeros(2),
            }

        # blob counter, used together with bin_plot_freq
        self.i = 0

    def _space_bin_edges(self, bin_edges):
        """
        Increase resolution of given binning.
        """
        increased_n_bins = (len(bin_edges) - 1) * self.res_increase + 1
        bin_edges = np.linspace(bin_edges[0], bin_edges[-1], increased_n_bins)

        return bin_edges

    def process(self, blob):
        """
        Extract data from blob for the hist plots.
        """
        if self.i % self.bin_plot_freq == 0:
            # hoisted out of the loop below: same Hits table for every field
            hits = blob["Hits"]
            for bin_name, hists_data in self.hists.items():
                hist_bin_edges = hists_data["hist_bin_edges"]

                data = hits[bin_name]
                # get how much is cut off due to these limits
                out_pos = data[data > np.max(hist_bin_edges)].size
                out_neg = data[data < np.min(hist_bin_edges)].size

                # get all hits which are not cut off by other bin edges
                data = hits[bin_name][self._is_in_limits(hits, excluded=bin_name)]
                hist = np.histogram(data, bins=hist_bin_edges)[0]

                self.hists[bin_name]["hist"] += hist
                self.hists[bin_name]["cut_off"] += np.array([out_neg, out_pos])

        self.i += 1
        return blob

    def finish(self):
        """
        Append the hists, which are the stats of the binning.

        Its a dict with each binning field name containing the following
        ndarrays:

        bin_edges : The actual bin edges.
        cut_off : How many events were cut off in positive and negative
            direction due to this binning.
        hist_bin_edges : The bin edges for the plot in finer resolution than
            the actual bin edges.
        hist : The number of hits in each bin of the hist_bin_edges.

        """
        return self.hists

    def _is_in_limits(self, hits, excluded=None):
        """Get which hits are in the limits defined by ALL bin edges
        (except for given one)."""
        inside = None
        for dfield, edges in self.bin_edges_list:
            if dfield == excluded:
                continue
            is_in = np.logical_and(
                hits[dfield] >= min(edges), hits[dfield] <= max(edges)
            )
            if inside is None:
                inside = is_in
            else:
                inside = np.logical_and(inside, is_in)
        return inside

273 

274 

class PointMaker(kp.Module):
    """
    Store individual hit info from "Hits" in the blob as 'samples'.

    Used for graph networks.

    Attributes
    ----------
    hit_infos : tuple, optional
        Which entries in the '/Hits' Table will be kept. E.g. pos_x, time, ...
        Default: Keep all entries.
    time_window : tuple, optional
        Two ints (start, end). Hits outside of this time window will be cut
        away (based on 'Hits/time'). Default: Keep all hits.
    only_triggered_hits : bool
        If true, use only triggered hits. Otherwise, use all hits (default).
    max_n_hits : int
        Maximum number of hits that gets saved per event. If an event has
        more, some will get cut randomly! Default: Keep all hits.
    fixed_length : bool
        If False (default), save hits of events with variable length as
        2d arrays using km3pipe's indices.
        If True, pad hits of each event with 0s to a fixed length,
        so that they can be stored as 3d arrays like images.
        max_n_hits needs to be given in that case, and a column will be
        added called 'is_valid', which is 0 if the entry is padded,
        and 1 otherwise.
        This is inefficient and will cut off hits, so it should not be used.
    dset_n_hits : str, optional
        If given, store the number of hits that are in the time window
        as a new column called 'n_hits_intime' in the dataset with
        this name (usually this is EventInfo).

    """

    def configure(self):
        self.hit_infos = self.get("hit_infos", default=None)
        self.time_window = self.get("time_window", default=None)
        self.only_triggered_hits = self.get("only_triggered_hits", default=False)
        self.max_n_hits = self.get("max_n_hits", default=None)
        self.fixed_length = self.get("fixed_length", default=False)
        self.dset_n_hits = self.get("dset_n_hits", default=None)
        self.store_as = "samples"
        # fix: fail fast at pipeline setup instead of re-checking this
        # invariant on every single event in process()
        if self.fixed_length and self.max_n_hits is None:
            raise ValueError("Have to specify max_n_hits if fixed_length is True")

    def process(self, blob):
        # default: keep every column of the Hits table (resolved on first event)
        if self.hit_infos is None:
            self.hit_infos = blob["Hits"].dtype.names
        points, n_hits = self.get_points(blob)
        blob[self.store_as] = kp.NDArray(points, h5loc="x", title="nodes")
        if self.dset_n_hits:
            blob[self.dset_n_hits] = blob[self.dset_n_hits].append_columns(
                "n_hits_intime", n_hits
            )
        return blob

    def get_points(self, blob):
        """
        Get the desired hit infos from the blob.

        Returns
        -------
        points : np.array
            The hit infos of this event as a 2d matrix. No of rows are
            fixed to the given max_n_hits. Each of the self.extract_keys,
            is in one column + an additional column which is 1 for
            actual hits, and 0 for if its a padded row.
        n_hits : int
            Number of hits in the given time window.
            Can be stored as n_hits_intime.

        """
        hits = blob["Hits"]
        if self.only_triggered_hits:
            hits = hits[hits.triggered != 0]
        if self.time_window is not None:
            # remove hits outside of time window
            hits = hits[
                np.logical_and(
                    hits["time"] >= self.time_window[0],
                    hits["time"] <= self.time_window[1],
                )
            ]

        n_hits = len(hits)
        if self.max_n_hits is not None and n_hits > self.max_n_hits:
            # if there are too many hits, take random ones, but keep order
            indices = np.arange(n_hits)
            np.random.shuffle(indices)
            which = indices[: self.max_n_hits]
            which.sort()
            hits = hits[which]

        if self.fixed_length:
            points = np.zeros(
                (self.max_n_hits, len(self.hit_infos) + 1), dtype="float32"
            )
            for i, which in enumerate(self.hit_infos):
                points[:n_hits, i] = hits[which]
            # last column is whether there was a hit or no
            points[:n_hits, -1] = 1.0
            # store along new axis
            points = np.expand_dims(points, 0)
        else:
            # TODO points should be a Table, not a ndarray
            points = np.zeros((len(hits), len(self.hit_infos)), dtype="float32")
            for i, which in enumerate(self.hit_infos):
                points[:, i] = hits[which]
        return points, n_hits

    def finish(self):
        """Return the column names of the stored points (for file metadata)."""
        columns = tuple(self.hit_infos)
        if self.fixed_length:
            columns += ("is_valid",)
        return {"hit_infos": columns}

391 

392 

class EventSkipper(kp.Module):
    """
    Skip events based on blob content.

    Attributes
    ----------
    event_skipper : callable
        Function that takes the blob as an input, and returns a bool.
        If the bool is true, the blob will be skipped.

    """

    def configure(self):
        self.event_skipper = self.require("event_skipper")
        self._not_skipped = 0
        self._skipped = 0

    def process(self, blob):
        if self.event_skipper(blob):
            self._skipped += 1
            # returning None drops the blob from the pipeline
            return
        else:
            self._not_skipped += 1
            return blob

    def finish(self):
        tot_events = self._skipped + self._not_skipped
        # fix: guard against ZeroDivisionError when no events were processed
        if tot_events == 0:
            self.cprint("Skipped 0/0 events.")
            return
        self.cprint(
            f"Skipped {self._skipped}/{tot_events} events "
            f"({self._skipped/tot_events:.4%})."
        )

424 

425 

class DetApplier(kp.Module):
    """
    Apply detector information to the event data from a detx file, e.g.
    calibrating hits.

    Attributes
    ----------
    det_file : str
        Path to a .detx detector geometry file.
    calib_hits : bool
        Apply calibration to hits. Default: True.
    calib_mchits : bool
        Apply calibration to mchits, if mchits are in the blob. Default: True.
    correct_timeslew : bool
        If true (default), the time slewing of hits depending on their tot
        will be corrected. Only done if calib_hits is True.
    center_hits_to : tuple, optional
        Translate the xyz positions of the hits (and mchits), as if
        the detector was centered at the given position.
        E.g., if its (0, 0, None), the hits and mchits will be
        centered at xy = 00, and z will be left untouched.

    """

    def configure(self):
        self.det_file = self.require("det_file")
        self.correct_timeslew = self.get("correct_timeslew", default=True)
        self.calib_hits = self.get("calib_hits", default=True)
        self.calib_mchits = self.get("calib_mchits", default=True)
        self.center_hits_to = self.get("center_hits_to", default=None)

        self.cprint(f"Calibrating with {self.det_file}")
        self.calib = kp.calib.Calibration(filename=self.det_file)
        # double-calibration warning is emitted at most once (first event)
        self._calib_checked = False

        # dict dim_name: float
        self._vector_shift = None

        if self.center_hits_to:
            self._cache_shift_center()

    def process(self, blob):
        if self.calib_hits:
            # fix: idiomatic truthiness check and logging's 'warning'
            # instead of the deprecated 'warn' alias
            if not self._calib_checked:
                if "pos_x" in blob["Hits"]:
                    self.log.warning(
                        "Warning: Using a det file, but pos_x in Hits detected. "
                        "Is the file already calibrated? This might lead to "
                        "errors with t0."
                    )
                self._calib_checked = True

            blob["Hits"] = self.calib.apply(
                blob["Hits"], correct_slewing=self.correct_timeslew
            )
        if self.calib_mchits and "McHits" in blob:
            blob["McHits"] = self.calib.apply(blob["McHits"])
        if self.center_hits_to:
            self.shift_hits(blob)
        return blob

    def shift_hits(self, blob):
        """Translate hits by cached vector."""
        for dim_name in ("pos_x", "pos_y", "pos_z"):
            blob["Hits"][dim_name] += self._vector_shift[dim_name]
            if "McHits" in blob:
                blob["McHits"][dim_name] += self._vector_shift[dim_name]

    def _cache_shift_center(self):
        """Compute and cache the per-dimension shift from the detector's
        dom-table center to the requested center_hits_to position."""
        det_center, shift = {}, {}
        for i, dim_name in enumerate(("pos_x", "pos_y", "pos_z")):
            center = self.calib.detector.dom_table[dim_name].mean()
            det_center[dim_name] = center

            if self.center_hits_to[i] is None:
                # None means: leave this dimension untouched
                shift[dim_name] = 0
            else:
                shift[dim_name] = self.center_hits_to[i] - center

        self._vector_shift = shift
        self.cprint(f"original detector center: {det_center}")
        self.cprint(f"shift for hits: {self._vector_shift}")

508 

509 

class HitRotator(kp.Module):
    """
    Rotates hits by angle theta.

    Attributes
    ----------
    theta : float
        Angle by which hits are rotated (radian).

    """

    def configure(self):
        self.theta = self.require("theta")

    def process(self, blob):
        """Rotate the x/y position of each hit by theta around the origin."""
        x = blob["Hits"]["x"]
        y = blob["Hits"]["y"]

        # fix: replace the per-hit Python loop with 2x2 matrix products
        # by an equivalent vectorized rotation (same math, numpy-speed).
        cos_t = np.cos(self.theta)
        sin_t = np.sin(self.theta)

        # compute both components before assigning, in case x/y are
        # views into the Hits table (in-place aliasing)
        x_rot = cos_t * x - sin_t * y
        y_rot = sin_t * x + cos_t * y

        blob["Hits"]["x"] = x_rot
        blob["Hits"]["y"] = y_rot

        return blob