Coverage for orcanet/lib/label_modifiers.py: 86%

125 statements  

« prev     ^ index     » next       coverage.py v7.2.7, created at 2024-03-28 14:22 +0000

1import warnings 

2import numpy as np 

3import orcanet.misc as misc 

4 

5# for loading via toml 

6lmods, register = misc.get_register() 

7 

8 

class ColumnLabels:
    """
    Default label modifier: use the identically named label columns.

    For every output of the model, the label is the entry with the same
    name in ``info_blob["y_values"]``.

    Example
    -------
    Model has output "energy" --> label is column "energy" from the label
    dataset in the h5 file.

    Parameters
    ----------
    model : ks.Model
        A keras model. Only its ``output_names`` attribute is read.

    """

    def __init__(self, model):
        # Only the output names are kept, not the model itself.
        self.output_names = model.output_names

    def __call__(self, info_blob):
        """Return a dict mapping each model output name to its label column."""
        labels = {}
        for output_name in self.output_names:
            # "y_values" is looked up once per output, so nothing is
            # accessed at all when the model has no outputs.
            labels[output_name] = info_blob["y_values"][output_name]
        return labels

32 

33 

@register
class RegressionLabels:
    """
    Generate labels for regression.

    Parameters
    ----------
    columns : str or list
        Name(s) of the columns in the label dataset that contain the labels.
    model_output : str, optional
        Name of the output of the network.
        Default: Same as columns (only valid if columns is a single name).
    log10 : bool
        Take log10 of the labels. Label values that are not greater than
        zero will produce 1 and a warning.
    stacks : int, optional
        Stack copies of the label this many times along a new axis at position 1.
        E.g. if the label is shape (?, 3), it will become
        shape (?, stacks, 3). Used for lkl regression.

    Raises
    ------
    ValueError
        If model_output is not given and columns holds more than one name.

    Examples
    --------
    >>> RegressionLabels(columns=['dir_x', 'dir_y', 'dir_z'], model_output='dir')
    or in the config.toml:
    label_modifier = {name='RegressionLabels', columns=['dir_x','dir_y','dir_z'], model_output='dir'}
    Will produce array of shape (bs, 3) for model output 'dir'.
    >>> RegressionLabels(columns='dir_x')
    Will produce array of shape (bs, 1) for model output 'dir_x'.

    """

    def __init__(self, columns, model_output=None, log10=False, stacks=None):
        # Always work with a list of column names internally.
        if isinstance(columns, str):
            columns = [columns]
        else:
            columns = list(columns)
        if model_output is None:
            if len(columns) != 1:
                raise ValueError(
                    "If model_output is not given, columns must be length 1!"
                )
            model_output = columns[0]

        self.columns = columns
        self.model_output = model_output
        self.stacks = stacks
        self.log10 = log10
        # Makes sure the "can not generate labels" warning is shown only once.
        self._warned = False

    def __call__(self, info_blob):
        """
        Build the label dict for one batch.

        Returns
        -------
        dict or None
            ``{model_output: label array}``, or None (with a one-time
            warning) if there are no y_values or the columns are missing.

        """
        y_values = info_blob["y_values"]
        if y_values is None:
            if not self._warned:
                warnings.warn("Can not generate labels: No y_values available!")
                self._warned = True
            return None
        try:
            y_value = y_values[self.columns]
        except (KeyError, ValueError):
            # Missing columns can surface as either exception, depending on
            # the container and the form of the index.
            if not self._warned:
                warnings.warn(
                    f"Can not generate labels: {self.columns} not found in y_values"
                )
                self._warned = True
            return None
        # Presumably turns the selected columns into a plain float32 array
        # -- confirm against misc.to_ndarray.
        y_value = misc.to_ndarray(y_value, dtype="float32")
        return {self.model_output: self.process_label(y_value)}

    def process_label(self, y_value):
        """Apply the optional log10 and stacking steps to the label array."""
        ys = y_value
        if self.log10:
            gr_zero = ys > 0
            if not np.all(gr_zero):
                warnings.warn(
                    "invalid value encountered in log10, setting result to 1",
                    category=RuntimeWarning,
                )
            # log10 is only evaluated where the label is > 0; everywhere else
            # the pre-filled value 1 of the output buffer is kept.
            ys = np.log10(ys, where=gr_zero, out=np.ones_like(ys, dtype="float32"))
        if self.stacks:
            # Insert a new axis at position 1 and repeat the label along it.
            ys = np.repeat(ys[:, None], repeats=self.stacks, axis=1)

        return ys

118 

119 

@register
class RegressionLabelsSplit(RegressionLabels):
    """
    Generate labels for regression, with an extra label for the errors.

    Intended for networks that output recos and errs in separate towers
    (for example when using OutputRegNormalSplit as output layer block).

    Example
    -------
    >>> RegressionLabelsSplit(columns=['dir_x', 'dir_y', 'dir_z'], model_output='dir')
    Will produce label 'dir' of shape (bs, 3),
    and label 'dir_err' of shape (bs, 2, 3).

    'dir_err' is just the label twice, along a new axis at -2.
    Necessary because pred and truth must be the same shape.

    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Name template of the additional error label.
        self.err_output_format = "{}_err"
        if self.stacks is not None:
            warnings.warn(
                "Can not use stacks option with RegressionLabelsSplit, ignoring..."
            )
            self.stacks = None
        self._warned = False

    def __call__(self, info_blob):
        """Return the regression labels plus a doubled copy under '<name>_err'."""
        labels = super().__call__(info_blob)
        if labels is None:
            return None
        # For each label, add a new axis at -2 and repeat the label twice
        # along it.
        doubled = {
            self.err_output_format.format(key): np.repeat(
                np.expand_dims(array, axis=-2), repeats=2, axis=-2
            )
            for key, array in labels.items()
        }
        labels.update(doubled)
        return labels

160 

161 

@register
class ClassificationLabels:
    """
    One-hot encoding for general purpose classification labels based on one mc label column.

    Parameters
    ----------
    column : str
        Identifier of which mc info to create the labels from.
    classes : dict
        Specify for each class the values of the column that belong to it.
        The keys have to be named "class1", "class2", etc
    model_output : str, optional
        The name of the output layer's outputs.
        Default: Same as column.

    Raises
    ------
    KeyError
        If classes has no key "class1".
    ValueError
        If the entry of "class1" is empty.

    Example
    -------
    2-class cf for signal and background; put this into the config.toml:
    label_modifier = {name="ClassificationLabels", column="particle_type", classes={class1 = [12, -12, 14, -14], class2 = [13, -13, 0]}, model_output="bg_output"}

    """

    def __init__(
        self,
        column,
        classes,
        model_output=None,
    ):
        self.column = column
        self.classes = classes
        self.model_output = column if model_output is None else model_output
        # Makes sure the "can not generate labels" warning is shown only once.
        self._warned = False

        if "class1" not in self.classes:
            raise KeyError("Class names must be named 'class1', 'class2',...")
        if not len(self.classes["class1"]) > 0:
            raise ValueError("Not a valid list for a class")

    def __call__(self, info_blob):
        """
        Build the one-hot label dict for one batch.

        Returns
        -------
        dict or None
            ``{model_output: float32 array of shape (batchsize, n_classes)}``,
            or None (with a one-time warning) if there are no y_values or
            the column is missing.

        """
        y_values = info_blob["y_values"]

        if y_values is None:
            if not self._warned:
                warnings.warn("Can not generate labels: No y_values available!")
                self._warned = True
            return None

        try:
            y_value = y_values[self.column]
        except (KeyError, ValueError):
            # Missing columns can surface as either exception, depending on
            # the container holding the y_values.
            if not self._warned:
                warnings.warn(
                    f"Can not generate labels: {self.column} not found in y_values"
                )
                self._warned = True
            # let this pass by for real data
            return None

        # create an array of the final shape, initialized with zeros
        n_classes = len(self.classes)
        batchsize = y_values.shape[0]
        categories = np.zeros((batchsize, n_classes), dtype="bool")

        # iterate over every class and set entries to 1 if condition is fulfilled
        for i in range(n_classes):
            members = self.classes[f"class{i + 1}"]
            # np.isin supersedes the deprecated np.in1d; the ravel keeps the
            # flattening np.in1d applied to its first argument.
            categories[:, i] = np.ravel(np.isin(y_value, members))

        return {self.model_output: categories.astype(np.float32)}

236 

237 

@register
class TSClassifier:
    """
    One-hot encoding for track/shower classifier. Muon neutrino CC are tracks, the rest
    of neutrinos is shower. This means, this has to be extended for tau neutrinos. Atm.
    muon events, if any, are tracks.

    The label has two columns: column 0 is 1 for tracks, column 1 is 1 for showers.

    Parameters
    ----------
    is_cc_convention : int
        The convention used in the MC prod to indicate a charged current interaction.
        For post 2020 productions this is 2.
    model_output : str, optional
        Name of the output of the network.
        Default: "ts_output".

    Example
    -------
    label_modifier = {name='TSClassifier', is_cc_convention=2}

    """

    def __init__(
        self,
        is_cc_convention,
        model_output="ts_output",
    ):
        self.is_cc_convention = is_cc_convention
        self.model_output = model_output
        # Makes sure the "can not generate labels" warning is shown only once.
        self._warned = False

    def __call__(self, info_blob):
        """
        Build the track/shower label dict for one batch.

        Returns
        -------
        dict or None
            ``{model_output: float32 array of shape (batchsize, 2)}``,
            or None (with a one-time warning) if there are no y_values or
            the columns particle_type / is_cc are missing.

        """
        y_values = info_blob["y_values"]

        # Same handling of missing y_values as in the other label modifiers.
        if y_values is None:
            if not self._warned:
                warnings.warn("Can not generate labels: No y_values available!")
                self._warned = True
            return None

        try:
            particle_type = y_values["particle_type"]
            is_cc = y_values["is_cc"] == self.is_cc_convention
        except (KeyError, ValueError):
            # Missing columns can surface as either exception, depending on
            # the container holding the y_values.
            if not self._warned:
                warnings.warn(
                    "Can not generate labels: particle_type or is_cc not found in y_values"
                )
                self._warned = True
            # let this pass by for real data
            return None

        # create conditions from particle_type and is cc
        is_muon_cc = np.logical_and(np.abs(particle_type) == 14, is_cc)

        # in case there are atm. muon events in the mix as well, declare them to be tracks
        is_track = np.logical_or(is_muon_cc, np.abs(particle_type) == 13)

        is_shower = np.invert(is_track)

        batchsize = y_values.shape[0]
        # categorical [track, shower] -> [1,0] = track, [0,1] = shower
        categorical_ts = np.zeros((batchsize, 2), dtype="bool")

        categorical_ts[:, 0] = is_track
        categorical_ts[:, 1] = is_shower

        return {self.model_output: categorical_ts.astype(np.float32)}