Coverage for orcasong/extractors/bundles.py: 29%

140 statements  

« prev     ^ index     » next       coverage.py v7.2.7, created at 2024-10-03 18:23 +0000

1import warnings 

2import numpy as np 

3 

4 

class BundleDataExtractor:
    """
    Get info present in real data.

    Parameters
    ----------
    infile
        Not used by this extractor.
    only_downgoing_tracks : bool
        Passed on to get_best_track: for the best track, consider only
        the ones that are downgoing.

    """

    def __init__(self, infile, only_downgoing_tracks=False):
        self.only_downgoing_tracks = only_downgoing_tracks

    def __call__(self, blob):
        """
        Build a dict with the event info, best track and hit statistics.

        Reads blob["EventInfo"] and blob["Hits"]; the track information is
        looked up by get_best_track.
        """
        event_info = blob["EventInfo"]
        if len(event_info) != 1:
            warnings.warn(f"Event info has length {len(event_info)}, not 1")
        # just take everything from (the first row of) event info
        track = dict(zip(event_info.dtype.names, event_info[0]))
        track.update(
            **get_best_track(blob, only_downgoing_tracks=self.only_downgoing_tracks)
        )

        hits = blob["Hits"]
        is_triggered = hits["triggered"].astype(bool)
        track["n_hits"] = len(hits)
        track["n_triggered_hits"] = hits["triggered"].sum()
        track["n_triggered_doms"] = len(np.unique(hits["dom_id"][is_triggered]))
        track["t_last_triggered"] = hits["time"][is_triggered].max()

        unique_hits = get_only_first_hit_per_pmt(hits)
        track["n_pmts"] = len(unique_hits)
        track["n_triggered_pmts"] = unique_hits["triggered"].sum()

        # Look the field up in dtype.names: a plain `in` on a structured
        # array compares against the stored values, not the field names.
        if "n_hits_intime" in event_info.dtype.names:
            n_hits_intime = event_info["n_hits_intime"]
        else:
            n_hits_intime = np.nan
        track["n_hits_intime"] = n_hits_intime
        return track

36 

37 

def get_only_first_hit_per_pmt(hits):
    """
    Keep only the earliest hit of each pmt.

    A pmt is identified by its (dom_id, channel_id) pair. The returned hits
    are in the same relative order as in the input.
    """
    time_order = np.argsort(hits["time"])
    pmt_ids = np.stack((hits["dom_id"], hits["channel_id"]), axis=-1)
    # np.unique reports the first occurrence of each pmt, which is its
    # earliest hit since the rows are time sorted here
    _, first_in_sorted = np.unique(pmt_ids[time_order], axis=0, return_index=True)
    # translate back to positions in the original array and restore its order
    keep = time_order[first_in_sorted]
    keep.sort()
    return hits[keep]

47 

48 

def get_best_track(blob, missing_value=np.nan, only_downgoing_tracks=False):
    """
    Get the reco parameters of the best track of an event.

    I mean first track, i.e. the one with longest chain and highest lkl/nhits.
    Can also take the best track only of those that are downgoing.

    Parameters
    ----------
    blob
        The blob. Tracks are read from blob["Tracks"] if present,
        otherwise from blob["BestJmuon"].
    missing_value : float
        Used for every entry if no suitable track is available.
    only_downgoing_tracks : bool
        Take the first track with dir_z < 0 instead of the first track.

    Returns
    -------
    dict
        One entry "jg_{name}_reco" for each of the hardcoded track fields.

    Raises
    ------
    ValueError
        If only_downgoing_tracks is set but only blob["BestJmuon"] is available.

    """
    # hardcode names here since the first blob might not have Tracks
    names = (
        "E",
        "JCOPY_Z_M",
        "JENERGY_CHI2",
        "JENERGY_ENERGY",
        "JENERGY_MUON_RANGE_METRES",
        "JENERGY_NDF",
        "JENERGY_NOISE_LIKELIHOOD",
        "JENERGY_NUMBER_OF_HITS",
        "JGANDALF_BETA0_RAD",
        "JGANDALF_BETA1_RAD",
        "JGANDALF_CHI2",
        "JGANDALF_LAMBDA",
        "JGANDALF_NUMBER_OF_HITS",
        "JGANDALF_NUMBER_OF_ITERATIONS",
        "JSHOWERFIT_ENERGY",
        "JSTART_LENGTH_METRES",
        "JSTART_NPE_MIP",
        "JSTART_NPE_MIP_TOTAL",
        "JVETO_NPE",
        "JVETO_NUMBER_OF_HITS",
        "dir_x",
        "dir_y",
        "dir_z",
        "id",
        "length",
        "likelihood",
        "pos_x",
        "pos_y",
        "pos_z",
        "rec_type",
        "t",
        "group_id",
    )
    if "Tracks" in blob:
        tracks = blob["Tracks"]
    elif "BestJmuon" in blob:
        if only_downgoing_tracks:
            raise ValueError("only_downgoing_tracks option requires all tracks!")
        tracks = blob["BestJmuon"]
    else:
        tracks = None

    index = None
    # an empty track table is treated like a missing one
    if tracks is not None and len(tracks) > 0:
        if only_downgoing_tracks:
            # item access works for record arrays and plain structured arrays
            downs = np.where(np.asarray(tracks["dir_z"]) < 0)[0]
            if len(downs) != 0:
                index = downs[0]
        else:
            index = 0

    if index is None:
        return {f"jg_{name}_reco": missing_value for name in names}
    track = tracks[index]
    return {f"jg_{name}_reco": track[name] for name in names}

112 

113 

class BundleMCExtractor:
    """
    For atmospheric muon studies on mupage or corsika simulations.

    Parameters
    ----------
    infile : str
        Handed to the BundleDataExtractor and, if with_mc_index is set,
        to get_mc_index.
    inactive_du : int, optional
        Don't count mchits in this du. E.g. for ORCA4, DU 1 is inactive.
    min_n_mchits_list : tuple
        How many mchits does a muon have to produce to be counted?
        Create a separate set of entries for each number in the tuple.
    plane_point : tuple
        For bundle diameter: XYZ coordinates of where the center of the
        plane is in which the muon positions get calculated. Should be set
        to the center of the detector!
    with_mc_index : bool
        Add a column called mc_index containing the mc run number,
        which is attempted to be read from the filename. This is for
        when the same run id/event id combination appears in mc files,
        which can happen e.g. in run by run simulations when there are
        multiple mc runs per data run.
        Requires the filename to have a very specific format, which is
        likely not future-proof.
        TODO this would ideally not be read from the filename,
        but there is currently no other way of accessing it (07/2021).
    is_corsika : bool
        Use this when using Corsika!!!
    only_downgoing_tracks : bool
        For the best track (JG reco), consider only the ones that are downgoing.
    missing_value : float
        If a value is missing, use this value instead.

    """

    def __init__(
        self,
        infile,
        inactive_du=None,
        min_n_mchits_list=(0, 1, 10),
        plane_point=(17, 17, 111),
        with_mc_index=True,
        is_corsika=False,
        only_downgoing_tracks=False,
        missing_value=np.nan,
    ):
        self.inactive_du = inactive_du
        self.min_n_mchits_list = min_n_mchits_list
        self.plane_point = plane_point
        self.with_mc_index = with_mc_index
        self.missing_value = missing_value
        self.is_corsika = is_corsika
        self.only_downgoing_tracks = only_downgoing_tracks

        # everything that is also available in real data comes from here
        self.data_extractor = BundleDataExtractor(
            infile, only_downgoing_tracks=only_downgoing_tracks
        )

        self.mc_index = None
        if self.with_mc_index:
            self.mc_index = get_mc_index(infile)
            print(f"Using mc_index {self.mc_index}")

    @staticmethod
    def _columns(tracks, keys):
        """Put the given fields of tracks side by side, shape (n_tracks, len(keys))."""
        return np.concatenate([tracks[key][:, None] for key in keys], axis=-1)

    def __call__(self, blob):
        """
        Get the data-like info plus the mc truth of the muon bundle.

        For corsika, the primary is removed from blob["McTracks"] in place.
        """
        mc_info = self.data_extractor(blob)

        if not self.is_corsika:
            # In mupage, all muons in a bundle are parallel. So just take dir of first muon
            mc_info["dir_x"] = blob["McTracks"].dir_x[0]
            mc_info["dir_y"] = blob["McTracks"].dir_y[0]
            mc_info["dir_z"] = blob["McTracks"].dir_z[0]
            plane_normal = None
        else:
            # Corsika has a primary particle, which should be track 0 with id 0
            primary = blob["McTracks"][0]
            if primary["id"] != 0:
                raise ValueError("Error finding primary: mc_tracks[0]['id'] != 0")

            # the direction of the primary also serves as the plane normal
            mc_info["dir_x"] = primary.dir_x
            mc_info["dir_y"] = primary.dir_y
            mc_info["dir_z"] = primary.dir_z
            plane_normal = np.array(primary[["dir_x", "dir_y", "dir_z"]].tolist())

            for fld in ("pos_x", "pos_y", "pos_z", "pdgid", "energy", "time"):
                mc_info[f"primary_{fld}"] = primary[fld]

            # the primary is not a muon, so drop it for everything below
            blob["McTracks"] = blob["McTracks"][1:]

        # n_mc_hits of each muon in active dus
        mchits_per_muon = get_mchits_per_muon(blob, inactive_du=self.inactive_du)

        for threshold in self.min_n_mchits_list:
            has_enough_hits = mchits_per_muon >= threshold
            if threshold == 0:
                selected, suffix = blob["McTracks"], "sim"
            else:
                selected = blob["McTracks"][has_enough_hits]
                suffix = f"{threshold}_mchits"

            # total number of mchits of all selected muons
            mc_info[f"n_mc_hits_{suffix}"] = np.sum(mchits_per_muon[has_enough_hits])
            # number of muons with at least the given number of mchits
            mc_info[f"n_muons_{suffix}"] = len(selected)
            # summed up energy of all selected muons
            mc_info[f"energy_{suffix}"] = np.sum(selected.energy)

            # bundle diameter; only makes sense for 2+ muons
            if len(selected) < 2:
                max_dist = mean_dist = self.missing_value
            else:
                in_plane = get_plane_positions(
                    positions=self._columns(selected, ("pos_x", "pos_y", "pos_z")),
                    directions=self._columns(selected, ("dir_x", "dir_y", "dir_z")),
                    plane_point=self.plane_point,
                    plane_normal=plane_normal,
                )
                distances = get_pairwise_distances(in_plane)
                max_dist, mean_dist = distances.max(), distances.mean()
            mc_info[f"max_pair_dist_{suffix}"] = max_dist
            mc_info[f"mean_pair_dist_{suffix}"] = mean_dist

        if self.with_mc_index:
            mc_info["mc_index"] = self.mc_index

        return mc_info

247 

248 

def get_plane_positions(positions, directions, plane_point, plane_normal=None):
    """
    Get the position of each muon in a 2d plane.
    Length will be preserved, i.e. 1m in 3d space is also 1m in plane space.

    Each muon is propagated along its direction onto the plane, and the
    intersection point is expressed in an orthonormal basis of the plane.

    Parameters
    ----------
    positions : np.array
        The position of each muon in 3d cartesian space, shape (n_muons, 3).
    directions : np.array
        The direction of each muon as a cartesian unit vector, shape (n_muons, 3).
    plane_point : np.array
        A 3d cartesian point on the plane. This will be (0, 0) in the plane
        coordinate system. Shape (3, ).
    plane_normal : np.array, optional
        A 3d cartesian vector perpendicular to the plane, shape (3, ).
        Default: Use directions if all muons are parallel, otherwise raise.

    Returns
    -------
    positions_plane : np.array
        The 2d position of each muon in the plane, shape (n_muons, 2).

    Raises
    ------
    ValueError
        If plane_normal is not given and the muons are not all parallel,
        or if a muon does not intersect the plane.

    """
    positions = np.asarray(positions, dtype=float)
    directions = np.asarray(directions, dtype=float)
    plane_point = np.asarray(plane_point, dtype=float)

    if plane_normal is None:
        if not np.all(directions == directions[0]):
            raise ValueError(
                "Muon tracks are not all parallel: plane_normal has to be specified!"
            )
        plane_normal = directions[0]
    plane_normal = np.asarray(plane_normal, dtype=float)

    # get the 3d points where each muon collides with the plane
    ndotu = directions @ plane_normal
    if np.any(np.abs(ndotu) < 1e-6):
        raise ValueError("no intersection or line is within plane")
    w = positions - plane_point
    si = -(w @ plane_normal) / ndotu
    # intersection points, relative to plane_point
    in_plane = w + si[:, None] * directions

    # Project onto two perpendicular unit vectors of the plane. With an
    # orthonormal basis, distances in the plane equal distances in 3d.
    u, v = _get_plane_basis(plane_normal)
    return np.stack((in_plane @ u, in_plane @ v), axis=-1)


def _get_plane_basis(plane_normal):
    """Get two perpendicular unit vectors that span the plane with the given normal."""
    # u has no y component and points towards positive x (if possible)
    u = np.array([plane_normal[2], 0.0, -plane_normal[0]])
    if plane_normal[2] < 0:
        u = -u
    u_length = np.linalg.norm(u)
    if u_length == 0:
        # the normal points along the y axis, so the x axis lies in the plane
        u = np.array([1.0, 0.0, 0.0])
    else:
        u = u / u_length

    # v is perpendicular to both the normal and u
    v = np.cross(plane_normal, u)
    v = v / np.linalg.norm(v)
    if v[1] < 0:
        v = -v
    return u, v

306 

307 

def get_pairwise_distances(positions_plane, as_matrix=False):
    """
    Get the distance between each pair of muons in the plane.

    Parameters
    ----------
    positions_plane : np.array
        The 2d position of each muon in a plane, shape (n_muons, 2).
    as_matrix : bool
        Return the whole 2D distance matrix.

    Returns
    -------
    np.array
        The distances between each pair of muons.
        1D if as_matrix is False (default), else 2D.

    """
    xs = positions_plane[:, 0]
    ys = positions_plane[:, 1]

    # broadcast a row against a column to get all pairwise differences
    delta_x = xs[np.newaxis, :] - xs[:, np.newaxis]
    delta_y = ys[np.newaxis, :] - ys[:, np.newaxis]
    distance_matrix = np.sqrt(delta_x ** 2 + delta_y ** 2)
    if as_matrix:
        return distance_matrix

    # the matrix is symmetric with a zero diagonal, so take each pair only once
    rows, cols = np.triu_indices(len(distance_matrix), k=1)
    return distance_matrix[rows, cols]

335 

336 

def get_mchits_per_muon(blob, inactive_du=None):
    """
    For each muon in McTracks, get the number of McHits.

    A hit is attributed to the muon whose id matches the hit's origin.

    Parameters
    ----------
    blob
        The blob. Reads blob["McTracks"]["id"], blob["McHits"]["origin"]
        and, if inactive_du is given, blob["McHits"]["du"].
    inactive_du : int, optional
        McHits in this DU will not be counted.

    Returns
    -------
    np.array
        n_mchits, len = number of muons --> blob["McTracks"]["id"]

    """
    ids = blob["McTracks"]["id"]
    # Origin of each mchit (as int)
    origin = blob["McHits"]["origin"]
    # compare to None explicitly, so that DU 0 can be excluded as well
    if inactive_du is not None:
        # only hits in active lines
        origin = origin[blob["McHits"]["du"] != inactive_du]
    # get how many mchits were produced per muon in the bundle
    origin_dict = dict(zip(*np.unique(origin, return_counts=True)))
    # explicit dtype, so that the counts are integers even without any muons
    return np.array([origin_dict.get(i, 0) for i in ids], dtype=np.intp)

363 

364 

def get_mc_index(aanet_filename):
    """
    Read the mc index from the name of a file.

    The index is taken to be the second to last of the dot-separated parts,
    e.g. mcv5.40.mupage_10G.sirene.jterbr00005782.jorcarec.aanet.365.h5 --> 365
    """
    dot_separated = aanet_filename.split(".")
    index_part = dot_separated[-2]
    return int(index_part)