Coverage for orcasong/core.py: 84%

140 statements  

« prev     ^ index     » next       coverage.py v7.2.7, created at 2024-10-03 18:23 +0000

1import os 

2import warnings 

3from abc import abstractmethod 

4import h5py 

5import km3pipe as kp 

6import km3modules as km 

7 

8import orcasong 

9import orcasong.modules as modules 

10import orcasong.plotting.plot_binstats as plot_binstats 

11 

12 

13__author__ = "Stefan Reck" 

14 

15 

class BaseProcessor:
    """
    Preprocess km3net/antares events for neural networks.

    This serves as a baseclass, which handles things like reading
    events, calibrating, generating labels and saving the output.
    Subclasses have to implement ``get_cmpts_main``.

    Parameters
    ----------
    extractor : function, optional
        Function that extracts desired info from a blob, which is then
        stored as the "y" datafield in the .h5 file.
        The function takes the km3pipe blob as an input, and returns
        a dict mapping str to floats.
        Examples can be found in orcasong.extractors.
    det_file : str, optional
        Path to a .detx detector geometry file, which can be used to
        calibrate the hits.
    correct_mc_time : bool
        Convert MC hit times to JTE times. Will only be done if
        mc_hits and mc_tracks are in the input file; otherwise a warning
        is issued [default: True].
    center_time : bool
        Subtract time of first triggered hit from all hit times. Will
        also be done for McHits if they are in the blob [default: True].
    calib_hits : bool
        Apply calibration to hits if det file is given [default: True].
    calib_mchits : bool
        Apply calibration to mchits if det file is given and mchits are
        found in the blob [default: True].
    correct_timeslew : bool
        If True, the time slewing of hits depending on their tot
        will be corrected during calibration [default: True].
        Only done if det file is given and calib_hits is True.
    center_hits_to : tuple, optional
        Translate the xyz positions of the hits (and mchits), as if
        the detector was centered at the given position.
        E.g., if it is (0, 0, None), the hits and mchits will be
        centered at xy = (0, 0), and z will be left untouched.
        Can only be used when a detx file is given; a ValueError is
        raised otherwise.
    add_t0 : bool
        If True, add t0 to the time of hits and mchits. If using a
        det_file, this will already have been done automatically
        [default: False].
    event_skipper : func, optional
        Function that takes the blob as an input, and returns a bool.
        If the bool is True, the blob will be skipped.
        This is placed after the binning and mc_info extractor.
    chunksize : int, optional
        Chunksize (along axis_0) used for saving the output
        to a .h5 file [default: None, i.e. auto chunking].
    keep_event_info : bool
        If True, will keep the "event_info" table [default: False].
    overwrite : bool
        If True, overwrite the output file if it exists already.
        If False, raise a FileExistsError instead [default: True].
    sort_y : bool
        Sort the columns in the y dataset alphabetically [default: True].
    y_to_float64 : bool
        Convert everything in the y dataset to float 64 [default: True].
        Hint: Not all other dtypes can store nan!

    Attributes
    ----------
    n_statusbar : int or None
        Print a statusbar every n blobs [default: 1000].
    n_memory_observer : int or None
        Print memory usage every n blobs [default: 1000].
    complib : str
        Compression library used for saving the output to a .h5 file.
        All PyTables compression filters are available, e.g. 'zlib',
        'lzf', 'blosc', ... [default: 'zlib'].
    complevel : int
        Compression level for the compression filter that is used for
        saving the output to a .h5 file [default: 1].
    flush_frequency : int
        After how many events the accumulated output should be flushed to
        the hard disk [default: 1000].
        A larger value leads to a faster orcasong execution,
        but it increases the RAM usage as well.
    seed : int or None
        Makes all random (numpy) actions reproducible. Set at the start of
        each pipeline [default: 42]. Set to None (or False) to disable
        seeding; note that 0 is a valid seed.

    """

    def __init__(
        self,
        extractor=None,
        det_file=None,
        correct_mc_time=True,
        center_time=True,
        calib_hits=True,
        calib_mchits=True,
        add_t0=False,
        correct_timeslew=True,
        center_hits_to=None,
        event_skipper=None,
        chunksize=None,
        keep_event_info=False,
        overwrite=True,
        sort_y=True,
        y_to_float64=True,
    ):
        # Centering the hits requires the detector geometry.
        if center_hits_to is not None and det_file is None:
            raise ValueError("det_file has to be given when using center_hits_to")

        self.extractor = extractor
        self.det_file = det_file
        self.correct_mc_time = correct_mc_time
        self.center_time = center_time
        self.calib_hits = calib_hits
        self.calib_mchits = calib_mchits
        self.add_t0 = add_t0
        self.correct_timeslew = correct_timeslew
        self.center_hits_to = center_hits_to
        self.event_skipper = event_skipper
        self.chunksize = chunksize
        self.keep_event_info = keep_event_info
        self.overwrite = overwrite
        self.sort_y = sort_y
        self.y_to_float64 = y_to_float64

        # Settings that are not exposed in the constructor, but can be
        # changed on the instance before calling run.
        self.n_statusbar = 1000
        self.n_memory_observer = 1000
        self.complib = "zlib"
        self.complevel = 1
        self.flush_frequency = 1000
        self.seed = 42

    @staticmethod
    def _get_default_outfile(infile, folder):
        """Return the auto-generated output path '<folder>/<infile stem>_dl.h5'."""
        stem = os.path.splitext(os.path.basename(infile))[0]
        return os.path.join(folder, f"{stem}_dl.h5")

    def run(self, infile, outfile=None):
        """
        Process the events from the infile, and save them to the outfile.

        Parameters
        ----------
        infile : str
            Path to the input file.
        outfile : str, optional
            Path to the output file (will be created). If none is given,
            will auto generate the name and save it in the cwd.

        Raises
        ------
        FileExistsError
            If ``overwrite`` is False and the output file exists already.

        """
        if outfile is None:
            outfile = self._get_default_outfile(infile, os.getcwd())
        if not self.overwrite and os.path.isfile(outfile):
            raise FileExistsError(f"File exists: {outfile}")
        # Compare by identity instead of truthiness: 0 is a valid seed and
        # must not silently disable seeding.
        if self.seed is not None and self.seed is not False:
            km.GlobalRandomState(seed=self.seed)
        pipe = self.build_pipe(infile, outfile)
        summary = pipe.drain()
        # Reopen the file written by the pipeline to add final metadata.
        with h5py.File(outfile, "a") as f:
            self.finish_file(f, summary)

    def run_multi(self, infiles, outfolder):
        """
        Process multiple files into their own output files each.
        The output file names will be generated automatically.

        Parameters
        ----------
        infiles : List
            The path to infiles as str.
        outfolder : str
            The output folder to place them in.

        Returns
        -------
        outfiles : List
            The paths of the generated output files, in the order of infiles.

        """
        outfiles = []
        for infile in infiles:
            outfile = self._get_default_outfile(infile, outfolder)
            outfiles.append(outfile)
            self.run(infile, outfile)
        return outfiles

    def build_pipe(self, infile, outfile, timeit=True):
        """Initialize and connect the modules from the different stages."""
        components = [
            *self.get_cmpts_pre(infile=infile),
            *self.get_cmpts_main(),
            *self.get_cmpts_post(outfile=outfile),
        ]
        pipe = kp.Pipeline(timeit=timeit)
        # Monitoring modules are attached first, before the actual stages.
        if self.n_statusbar is not None:
            pipe.attach(km.common.StatusBar, every=self.n_statusbar)
        if self.n_memory_observer is not None:
            pipe.attach(km.common.MemoryObserver, every=self.n_memory_observer)
        for cmpt, kwargs in components:
            pipe.attach(cmpt, **kwargs)
        return pipe

    def get_cmpts_pre(self, infile):
        """
        Modules that read and calibrate the events.

        Returns a list of (module, kwargs) tuples.
        """
        cmpts = [(kp.io.hdf5.HDF5Pump, {"filename": infile})]

        if self.correct_mc_time:
            # MC time correction needs both mc_hits and mc_tracks, so check
            # the input file for them before adding the module.
            with h5py.File(infile, "r") as f:
                if "mc_hits" in f and "mc_tracks" in f:
                    cmpts.append((km.mc.MCTimeCorrector, {}))
                else:
                    warnings.warn("Can not correct mc time: mc_hits "
                                  "and/or mc_tracks not found!")

        if self.det_file:
            cmpts.append(
                (
                    modules.DetApplier,
                    {
                        "det_file": self.det_file,
                        "correct_timeslew": self.correct_timeslew,
                        "center_hits_to": self.center_hits_to,
                        "calib_hits": self.calib_hits,
                        "calib_mchits": self.calib_mchits,
                    },
                )
            )

        if self.center_time or self.add_t0:
            cmpts.append(
                (
                    modules.TimePreproc,
                    {"add_t0": self.add_t0, "center_time": self.center_time},
                )
            )
        return cmpts

    # NOTE: BaseProcessor does not use ABCMeta, so this decorator is not
    # enforced at instantiation; calling the method raises instead.
    @abstractmethod
    def get_cmpts_main(self):
        """Produce and store the samples as 'samples' in the blob."""
        raise NotImplementedError

    def get_cmpts_post(self, outfile):
        """
        Modules that postproc and save the events.

        Returns a list of (module, kwargs) tuples.
        """
        cmpts = []
        if self.extractor is not None:
            cmpts.append(
                (
                    modules.McInfoMaker,
                    {
                        "extractor": self.extractor,
                        "to_float64": self.y_to_float64,
                        "sort_y": self.sort_y,
                        "store_as": "mc_info",
                    },
                )
            )

        if self.event_skipper is not None:
            cmpts.append((modules.EventSkipper, {"event_skipper": self.event_skipper}))

        # Only these blob entries are passed on to the sink.
        keys_keep = ["samples", "mc_info", "header", "raw_header"]
        if self.keep_event_info:
            keys_keep.append("EventInfo")
        cmpts.append((km.common.Keep, {"keys": keys_keep}))

        cmpts.append(
            (
                kp.io.HDF5Sink,
                {
                    "filename": outfile,
                    "complib": self.complib,
                    "complevel": self.complevel,
                    "chunksize": self.chunksize,
                    "flush_frequency": self.flush_frequency,
                },
            )
        )
        return cmpts

    def finish_file(self, f, summary):
        """
        Work with the output file after the pipe has finished.

        Parameters
        ----------
        f : h5py.File
            The opened output file.
        summary : km3pipe.Blob
            The output from pipe.drain().

        """
        # Add current orcasong version to h5 file
        f.attrs.create("orcasong", orcasong.__version__)

303 

304 

class FileBinner(BaseProcessor):
    """
    For making binned images and mc_infos, which can be used for conv nets.

    Can also add statistics of the binning to the h5 files, which can
    be plotted to show the distribution of hits among the bins and how
    many hits were cut off.

    Parameters
    ----------
    bin_edges_list : List
        List with the names of the fields to bin, and the respective bin
        edges, including the left- and right-most bin edge.
        Example: For 10 bins in the z direction, and 100 bins in time:
            bin_edges_list = [
                ["pos_z", np.linspace(0, 10, 11)],
                ["time", np.linspace(-50, 550, 101)],
            ]
        Some examples can be found in orcasong.bin_edges.
    add_bin_stats : bool
        Add statistics of the binning to the output file. They can be
        plotted with orcasong.plotting.plot_binstats [default: True].
    hit_weights : str, optional
        Use blob["Hits"][hit_weights] as weights for samples in histogram.
    chunksize : int
        Chunksize (along axis_0) used for saving the output to a .h5 file.
        Unlike in the BaseProcessor, the default here is 32.
    kwargs
        Options of the BaseProcessor.

    """

    def __init__(
        self,
        bin_edges_list,
        add_bin_stats=True,
        hit_weights=None,
        chunksize=32,
        **kwargs,
    ):
        self.bin_edges_list = bin_edges_list
        self.add_bin_stats = add_bin_stats
        self.hit_weights = hit_weights
        super().__init__(chunksize=chunksize, **kwargs)

    def get_cmpts_main(self):
        """Generate nD images (and optionally the binning statistics)."""
        cmpts = []
        if self.add_bin_stats:
            # Collects the statistics that finish_file writes to the output.
            cmpts.append(
                (modules.BinningStatsMaker, {"bin_edges_list": self.bin_edges_list})
            )
        cmpts.append(
            (
                modules.ImageMaker,
                {
                    "bin_edges_list": self.bin_edges_list,
                    "hit_weights": self.hit_weights,
                },
            )
        )
        return cmpts

    def finish_file(self, f, summary):
        """Add the orcasong version and, if requested, the binning stats."""
        super().finish_file(f, summary)
        if self.add_bin_stats:
            plot_binstats.add_hists_to_h5file(summary["BinningStatsMaker"], f)

    def run_multi(self, infiles, outfolder, save_plot=False):
        """
        Bin multiple files into their own output files each.
        The output file names will be generated automatically.

        Parameters
        ----------
        infiles : List
            The path to infiles as str.
        outfolder : str
            The output folder to place them in.
        save_plot : bool
            Save the binning hists as 'binning_hist.pdf' inside outfolder.
            Only possible if add_bin_stats is True.

        Returns
        -------
        outfiles : List
            The paths of the generated output files.

        Raises
        ------
        ValueError
            If save_plot is True but add_bin_stats is False.

        """
        if save_plot and not self.add_bin_stats:
            raise ValueError("Can not make plot when add_bin_stats is False")

        name, shape = self.get_names_and_shape()
        print("Generating {} images with shape {}".format(name, shape))

        outfiles = super().run_multi(infiles=infiles, outfolder=outfolder)

        if save_plot:
            # Join properly so the pdf lands inside outfolder even if it
            # was given without a trailing path separator.
            plot_binstats.plot_hist_of_files(
                files=outfiles, save_as=os.path.join(outfolder, "binning_hist.pdf")
            )
        return outfiles

    def get_names_and_shape(self):
        """
        Get names and shape of the resulting x data,
        e.g. (pos_z, time), (18, 50).
        """
        names, shape = [], []
        for bin_name, bin_edges in self.bin_edges_list:
            names.append(bin_name)
            # n edges define n - 1 bins.
            shape.append(len(bin_edges) - 1)
        return tuple(names), tuple(shape)

    def __repr__(self):
        return "<FileBinner: {} {}>".format(*self.get_names_and_shape())

406 

407 

class FileGraph(BaseProcessor):
    """
    Convert km3 events into graph (point cloud) data.

    The output file contains a dataset "x" of shape
    (total n_hits, len(hit_infos)). The names of the columns along the
    last axis (i.e. the hit_infos) are stored as attributes of that
    dataset (f["x"].attrs).

    Parameters
    ----------
    hit_infos : tuple, optional
        Which entries of the '/Hits' Table to keep, e.g. pos_x, time, ...
        Usually only dir_x/y/z, pos_x/y/z and time are needed.
        Default: Keep all entries.
    time_window : tuple, optional
        Two ints (start, end). Hits outside of this time window are
        removed (based on 'Hits/time'). Default: Keep all hits.
    only_triggered_hits : bool
        If True, use only triggered hits. Otherwise use all hits (default).
    max_n_hits : int
        Maximum number of hits saved per event. Events with more hits
        get some of them cut at random! Default: Keep all hits.
    fixed_length : bool
        Legacy option.
        If False (default), hits of events with variable length are saved
        as 2d arrays using km3pipe's indices.
        If True, the hits of each event are padded with 0s to a fixed
        length, so that they can be stored as 3d arrays like images.
        max_n_hits has to be given in that case, and a column called
        'is_valid' is added, which is 0 for padded entries and 1 otherwise.
        This is inefficient and cuts off hits, so it should not be used.
    kwargs
        Options of the BaseProcessor.

    """

    def __init__(
        self,
        max_n_hits=None,
        time_window=None,
        hit_infos=None,
        only_triggered_hits=False,
        fixed_length=False,
        **kwargs,
    ):
        # Options that are forwarded to the PointMaker module.
        self.hit_infos = hit_infos
        self.time_window = time_window
        self.max_n_hits = max_n_hits
        self.only_triggered_hits = only_triggered_hits
        self.fixed_length = fixed_length
        super().__init__(**kwargs)

    def get_cmpts_main(self):
        """Return the PointMaker module (with its kwargs) that builds the hit data."""
        point_maker_kwargs = {
            "max_n_hits": self.max_n_hits,
            "fixed_length": self.fixed_length,
            "time_window": self.time_window,
            "hit_infos": self.hit_infos,
            "dset_n_hits": "EventInfo",
            "only_triggered_hits": self.only_triggered_hits,
        }
        return [(modules.PointMaker, point_maker_kwargs)]

    def finish_file(self, f, summary):
        """Store the hit_info column names and the layout flag on f["x"]."""
        super().finish_file(f, summary)
        column_names = summary["PointMaker"]["hit_infos"]
        for col_idx, col_name in enumerate(column_names):
            f["x"].attrs.create(f"hit_info_{col_idx}", col_name)
        # Variable-length (indexed) storage is used unless padding was requested.
        is_indexed = not self.fixed_length
        f["x"].attrs.create("indexed", is_indexed)