Coverage for orcanet/lib/sample_modifiers.py: 82%

98 statements  

« prev     ^ index     » next       coverage.py v7.2.7, created at 2024-03-28 14:22 +0000

1""" 

2Some basic sample modifiers to use with orcanet. 

3Use them by setting .cfg.sample_modifier of the orcanet.core.Organizer. 

4 

5""" 

from abc import ABC, abstractmethod
import warnings

import numpy as np
import tensorflow as tf

from orcanet.misc import get_register

11 

# for loading via toml:
# `register` is applied below as a class decorator on the modifiers
# (e.g. Permute, Reshape, GraphEdgeConv); `smods` presumably holds the
# registered classes so they can be looked up from a toml config
# (see orcanet.misc.get_register) -- NOTE(review): confirm against get_register.
smods, register = get_register()

14 

15 

class PerInputModifier(ABC):
    """
    For modifiers that do the same operation on each input.

    Calling an instance applies ``modify`` to the x_value of every input in
    ``info_blob["x_values"]`` and returns the results as a dict with the
    same keys (in the same order).

    Subclasses must implement ``modify``. Inheriting from ``ABC`` makes the
    ``@abstractmethod`` marker effective, so a subclass that forgets to
    implement it fails at instantiation instead of at the first batch.

    """

    def __call__(self, info_blob):
        return {
            key: self.modify(x_value)
            for key, x_value in info_blob["x_values"].items()
        }

    @abstractmethod
    def modify(self, x_value):
        """x_value is a batch of input data as a numpy array."""
        raise NotImplementedError

34 

35 

class JoinedModifier(PerInputModifier):
    """
    For applying multiple sample modifiers after each other.

    Parameters
    ----------
    sample_modifiers : iterable
        Modifiers applied in the given order; each one's ``modify`` gets
        the output of the previous one. Stored as a list, so one-shot
        iterables (e.g. generators) work for every batch, not only the
        first one.

    Example
    -------
    organizer.cfg.sample_modifier = JoinedModifier([
        Reshape((11, 13, 18)), Permute((2, 1, 3))
    ])
    --> Reshape each sample, then permute axes.

    """

    def __init__(self, sample_modifiers):
        # Materialize once: a generator would be exhausted after the first
        # call to modify, silently skipping all modifiers afterwards.
        self.sample_modifiers = list(sample_modifiers)

    def modify(self, x_value):
        result = x_value
        for smod in self.sample_modifiers:
            result = smod.modify(result)
        return result

57 

58 

@register
class Permute(PerInputModifier):
    """
    Permute the axes of the samples to given order.
    Batchsize axis is excluded, i.e. start indexing with 1!

    Example
    -------
    organizer.cfg.sample_modifier = Permute((2, 1, 3))
    --> Swap first two axes of each sample.

    """

    def __init__(self, axes):
        # Order of the sample axes only; axis 0 (batch) is prepended in modify.
        self.axes = [*axes]

    def modify(self, x_value):
        """Transpose the batch, keeping the batch axis (0) in front."""
        full_order = [0, *self.axes]
        return np.transpose(x_value, full_order)

77 

78 

@register
class Reshape(PerInputModifier):
    """
    Reshape samples to given shape.
    Batchsize axis is excluded!

    Example
    -------
    organizer.cfg.sample_modifier = Reshape((11, 13, 18))
    --> Reshape each sample to that shape.

    """

    def __init__(self, newshape):
        # Target shape of a single sample; the batch size is kept in modify.
        self.newshape = [*newshape]

    def modify(self, x_value):
        """Reshape every sample of the batch to ``self.newshape``."""
        batch_size = x_value.shape[0]
        target_shape = [batch_size, *self.newshape]
        return np.reshape(x_value, target_shape)

97 

98 

@register
class GraphEdgeConv:
    """
    Read out points and coordinates, intended for the MEdgeConv layers.

    For DL files produced with OrcaSong in graph mode.

    Parameters
    ----------
    knn : int or None
        Number of nearest neighbors used in the edge conv.
        Pad events with fewer than knn + 1 hits by duping their first hit,
        and give a warning. If None, no padding is done.
    node_features : tuple
        Names of the columns used as node features.
    coord_features : tuple
        Names of the columns used as coordinates.
    ragged : bool, optional
        If True, return ragged tensors (nodes, coordinates).
        If False, return regular tensors, padded to fixed length, together
        with an "is_valid" tensor that is 1 for real hits and 0 for padding.
        n_hits_padded needs to be known in this case (given explicitly, or
        inferred from non-indexed datasets).
    with_lightspeed : bool
        Multiply time for coordinates input with lightspeed.
        Requires coord_features to have the entry 'time'.
    column_names : tuple, optional
        Name and order of the features in the last dimension of the array.
        If None is given, will attempt to auto-read the column names from
        the ``hit_info_<i>`` attributes of the dataset.
    is_valid_features : str
        Name of the column flagging valid hits. Used for non-indexed
        (padded) datasets: only hits where this column equals 1 are kept.
    n_hits_padded : int, optional
        Only for when ragged = False.
        Pad or cut to exactly this many hits using 0s.
        Non-indexed datasets will automatically set this value.

    """

    def __init__(
        self,
        knn=16,
        node_features=("pos_x", "pos_y", "pos_z", "time", "dir_x", "dir_y", "dir_z"),
        coord_features=("pos_x", "pos_y", "pos_z", "time"),
        ragged=True,
        with_lightspeed=True,
        column_names=None,
        is_valid_features="is_valid",
        n_hits_padded=None,
    ):
        self.knn = knn
        self.node_features = node_features
        self.coord_features = coord_features
        self.ragged = ragged
        self.with_lightspeed = with_lightspeed
        self.column_names = column_names
        self.lightspeed = 0.225  # in water; m/ns
        self.is_valid_features = is_valid_features
        self.n_hits_padded = n_hits_padded

    def _str_to_idx(self, which):
        """Given column name(s), get index of column(s)."""
        if isinstance(which, str):
            return self.column_names.index(which)
        return [self.column_names.index(w) for w in which]

    def _cache_column_names(self, x_dataset):
        """Read column names from the ``hit_info_<i>`` attrs of x_dataset."""
        try:
            self.column_names = [
                x_dataset.attrs[f"hit_info_{i}"] for i in range(x_dataset.shape[-1])
            ]
        except Exception as err:
            # keep the original error as __cause__ for debugging
            raise ValueError(
                "Can not read column names from dataset attributes"
            ) from err

    def reset_cache(self):
        """Clear cached column names."""
        self.column_names = None

    def __call__(self, info_blob):
        """
        Turn the x_values of the (single) graph input into nodes and coords.

        Returns
        -------
        dict
            If ragged: "nodes" and "coords" as ``tf.RaggedTensor``.
            Otherwise: "nodes", "is_valid" and "coords" as dense tensors
            with n_hits_padded entries along the hit axis, padded with 0.

        Raises
        ------
        ValueError
            If column names can not be read, or if ragged is False and
            n_hits_padded is unknown.

        """
        # graph has only one file, take it no matter the name
        input_name = list(info_blob["x_values"].keys())[0]
        datasets_meta = info_blob["meta"]["datasets"][input_name]
        is_indexed = datasets_meta.get("samples_is_indexed")
        if self.column_names is None:
            self._cache_column_names(datasets_meta["samples"])

        if is_indexed is True:
            # for indexed sets, x_values is 2d (nodes x features)
            x_values, n_items = info_blob["x_values"][input_name]
            n_hits_padded = None
        else:
            # otherwise it's 3d (batch x max_nodes x features); keep only
            # the valid hits, which flattens x_values to 2d
            x_values = info_blob["x_values"][input_name]
            is_valid = x_values[:, :, self._str_to_idx(self.is_valid_features)]
            n_hits_padded = is_valid.shape[-1]
            x_values = x_values[is_valid == 1]
            n_items = is_valid.sum(-1)

        x_values = x_values.astype("float32")
        n_items = n_items.astype("int32")

        # pad events with too few hits by duping first hit
        # (knn=None means there is no neighbor count to satisfy)
        if self.knn is not None and np.any(n_items < self.knn + 1):
            x_values, n_items = _pad_disjoint(x_values, n_items, min_items=self.knn + 1)

        # indexing with a list of columns returns copies, so scaling the
        # time of coords below does not touch the time column in nodes
        nodes = x_values[:, self._str_to_idx(self.node_features)]
        coords = x_values[:, self._str_to_idx(self.coord_features)]

        if self.with_lightspeed:
            coords[:, self.coord_features.index("time")] *= self.lightspeed

        nodes_t = tf.RaggedTensor.from_row_lengths(nodes, n_items)
        coords_t = tf.RaggedTensor.from_row_lengths(coords, n_items)

        if self.ragged is True:
            return {
                "nodes": nodes_t,
                "coords": coords_t,
            }

        # an explicitly given n_hits_padded overrides the inferred one
        if self.n_hits_padded is not None:
            n_hits_padded = self.n_hits_padded
        if n_hits_padded is None:
            raise ValueError("Have to give n_hits_padded if ragged is False!")

        sh = [nodes_t.shape[0], n_hits_padded]
        return {
            "nodes": nodes_t.to_tensor(
                default_value=0.0, shape=sh + [nodes_t.shape[-1]]
            ),
            "is_valid": tf.ones_like(nodes_t[:, :, 0]).to_tensor(
                default_value=0.0, shape=sh
            ),
            "coords": coords_t.to_tensor(
                default_value=0.0, shape=sh + [coords_t.shape[-1]]
            ),
        }

235 

236 

def _pad_disjoint(x, n_items, min_items):
    """
    Pad disjoint graphs to have a minimum number of hits per event.

    Parameters
    ----------
    x : np.ndarray
        Hits of all events, concatenated along axis 0 (e.g. nodes x features).
    n_items : array_like
        Number of hits of each event in x, in order.
    min_items : int
        Minimum number of hits each event should have afterwards.

    Returns
    -------
    x : np.ndarray
        Hits where the first hit of each too-small event is repeated at the
        end of that event until it has min_items hits.
    n_items : np.ndarray
        Updated hit counts (a new array; the input is not modified).

    Raises
    ------
    ValueError
        If an event that needs padding has no hits at all, since there is
        no first hit of that event to duplicate.

    """
    n_items = np.array(n_items)
    missing = np.clip(min_items - n_items, 0, None)
    to_pad = np.flatnonzero(missing > 0)

    # An empty event has no own hit to dupe; x[offset] would be the first
    # hit of the *next* event (or out of bounds for the last one).
    empty = to_pad[n_items[to_pad] == 0]
    if empty.size:
        raise ValueError(
            f"Can not pad events without any hits (event indices {empty.tolist()})!"
        )

    for batchno in to_pad:
        warnings.warn(
            f"Event has too few hits! Needed {min_items}, "
            f"had {n_items[batchno]}! Padding..."
        )
        # Start offset of each event in x. Recomputed every iteration,
        # because padding an earlier event shifts all later ones.
        offsets = np.concatenate([[0], n_items.cumsum()])
        first_hit = x[offsets[batchno]]
        x = np.insert(
            x,
            offsets[batchno + 1],
            np.repeat(first_hit[None, :], missing[batchno], axis=0),
            axis=0,
        )
        n_items[batchno] = min_items
    return x, n_items