Coverage for orcanet/lib/sample_modifiers.py: 82%
98 statements
« prev ^ index » next coverage.py v7.2.7, created at 2024-03-28 14:22 +0000
"""
Some basic sample modifiers to use with orcanet.

Use them by setting .cfg.sample_modifier of the orcanet.core.Organizer.
"""
from abc import ABC, abstractmethod
import warnings

import numpy as np
import tensorflow as tf

from orcanet.misc import get_register
12# for loading via toml
13smods, register = get_register()
class PerInputModifier(ABC):
    """
    Base class for modifiers that do the same operation on each input.

    ``__call__`` applies ``modify`` on the x_value of each input and
    returns the results as a dict (input name -> modified batch).

    Note: inherits from ABC so that ``@abstractmethod`` is actually
    enforced; previously the decorator had no effect because the class
    did not use the ABCMeta metaclass.
    """

    def __call__(self, info_blob):
        """Apply ``modify`` to every input batch.

        Parameters
        ----------
        info_blob : dict
            Must contain the key "x_values": a dict mapping each input
            name to a batch of input data.

        Returns
        -------
        dict
            Input name -> output of ``modify`` for that input's batch.
        """
        x_values = info_blob["x_values"]
        return {key: self.modify(x_value) for key, x_value in x_values.items()}

    @abstractmethod
    def modify(self, x_value):
        """x_value is a batch of input data as a numpy array."""
        raise NotImplementedError
class JoinedModifier(PerInputModifier):
    """
    Chain several sample modifiers and apply them one after the other.

    Example
    -------
    organizer.cfg.sample_modifier = JoinedModifier([
        Reshape((11, 13, 18)), Permute((2, 1, 3))
    ])
    --> Reshape each sample, then permute axes.

    """

    def __init__(self, sample_modifiers):
        self.sample_modifiers = sample_modifiers

    def modify(self, x_value):
        # pipe the batch through each modifier in the given order
        out = x_value
        for modifier in self.sample_modifiers:
            out = modifier.modify(out)
        return out
@register
class Permute(PerInputModifier):
    """
    Permute the axes of the samples to given order.

    Batchsize axis is excluded, i.e. start indexing with 1!

    Example
    -------
    organizer.cfg.sample_modifier = Permute((2, 1, 3))
    --> Swap first two axes of each sample.

    """

    def __init__(self, axes):
        self.axes = list(axes)

    def modify(self, x_value):
        # batch axis (0) always stays in front; user order follows
        order = [0, *self.axes]
        return np.transpose(x_value, order)
@register
class Reshape(PerInputModifier):
    """
    Reshape samples to given shape.

    Batchsize axis is excluded!

    Example
    -------
    organizer.cfg.sample_modifier = Reshape((11, 13, 18))
    --> Reshape each sample to that shape.

    """

    def __init__(self, newshape):
        self.newshape = list(newshape)

    def modify(self, x_value):
        # keep the batch dimension untouched, reshape the rest
        target_shape = [x_value.shape[0], *self.newshape]
        return np.reshape(x_value, target_shape)
@register
class GraphEdgeConv:
    """
    Read out points and coordinates, intended for the MEdgeConv layers.

    For DL files produced with OrcaSong in graph mode.

    Parameters
    ----------
    knn : int or None
        Number of nearest neighbors used in the edge conv.
        Pad events with too few hits by duping first hit, and give a warning.
    node_features : tuple
        Defines the node features.
    coord_features : tuple
        Defines the coordinates.
    ragged : bool, optional
        If True, return ragged tensors (nodes, coordinates).
        If False, return regular tensors, padded to fixed length.
        n_hits_padded and is_valid_features need to be given in this case.
    with_lightspeed : bool
        Multiply time for coordinates input with lightspeed.
        Requires coord_features to have the entry 'time'.
    column_names : tuple, optional
        Name and order of the features in the last dimension of the array.
        If None is given, will attempt to auto-read the column names from
        the attributes of the dataset.
    is_valid_features : str
        Only for when ragged = False.
        Defines the is_valid.
    n_hits_padded : int, optional
        Only for when ragged = False.
        Pad or cut to exactly this many hits using 0s.
        Non-indexed datasets will automatically set this value.

    """

    def __init__(
        self,
        knn=16,
        node_features=("pos_x", "pos_y", "pos_z", "time", "dir_x", "dir_y", "dir_z"),
        coord_features=("pos_x", "pos_y", "pos_z", "time"),
        ragged=True,
        with_lightspeed=True,
        column_names=None,
        is_valid_features="is_valid",
        n_hits_padded=None,
    ):
        self.knn = knn
        self.node_features = node_features
        self.coord_features = coord_features
        self.ragged = ragged
        self.with_lightspeed = with_lightspeed
        # None means: auto-read from the dataset attrs on first __call__
        self.column_names = column_names
        self.lightspeed = 0.225  # in water; m/ns
        self.is_valid_features = is_valid_features
        self.n_hits_padded = n_hits_padded

    def _str_to_idx(self, which):
        """Given column name(s), get index of column(s).

        Returns a single int for a str argument, else a list of ints
        (one per name) for a sequence of names.
        """
        if isinstance(which, str):
            return self.column_names.index(which)
        else:
            return [self.column_names.index(w) for w in which]

    def _cache_column_names(self, x_dataset):
        # Read column names from the dataset's attrs: one attribute
        # "hit_info_<i>" per feature column (OrcaSong convention —
        # see the class docstring); cache them on the instance.
        try:
            self.column_names = [
                x_dataset.attrs[f"hit_info_{i}"] for i in range(x_dataset.shape[-1])
            ]
        except Exception:
            raise ValueError("Can not read column names from dataset attributes")

    def reset_cache(self):
        """Clear cached column names."""
        self.column_names = None

    def __call__(self, info_blob):
        """Turn a batch from the info blob into the network input dict.

        Parameters
        ----------
        info_blob : dict
            Must have "x_values" (one input) and "meta"/"datasets" info
            for that input.

        Returns
        -------
        dict
            If ragged: keys "nodes" and "coords" (tf.RaggedTensor).
            If not ragged: "nodes", "is_valid" and "coords" as dense
            tensors padded to n_hits_padded.

        Raises
        ------
        ValueError
            If ragged is False and no n_hits_padded could be determined.
        """
        # graph has only one file, take it no matter the name
        input_name = list(info_blob["x_values"].keys())[0]
        datasets_meta = info_blob["meta"]["datasets"][input_name]
        is_indexed = datasets_meta.get("samples_is_indexed")
        if self.column_names is None:
            self._cache_column_names(datasets_meta["samples"])

        if is_indexed is True:
            # for indexed sets, x_values is 2d (nodes x features)
            x_values, n_items = info_blob["x_values"][input_name]
            n_hits_padded = None
        else:
            # otherwise it's 3d (batch x max_nodes x features)
            x_values = info_blob["x_values"][input_name]
            is_valid = x_values[:, :, self._str_to_idx(self.is_valid_features)]
            n_hits_padded = is_valid.shape[-1]
            # flatten to the disjoint 2d layout: keep only valid hits
            x_values = x_values[is_valid == 1]
            n_items = is_valid.sum(-1)

        x_values = x_values.astype("float32")
        n_items = n_items.astype("int32")

        # pad events with too few hits by duping first hit
        if np.any(n_items < self.knn + 1):
            x_values, n_items = _pad_disjoint(x_values, n_items, min_items=self.knn + 1)

        nodes = x_values[:, self._str_to_idx(self.node_features)]
        coords = x_values[:, self._str_to_idx(self.coord_features)]

        if self.with_lightspeed:
            # coords is a copy (fancy indexing above), so this in-place
            # multiply does not touch x_values or nodes
            coords[:, self.coord_features.index("time")] *= self.lightspeed

        nodes_t = tf.RaggedTensor.from_row_lengths(nodes, n_items)
        coords_t = tf.RaggedTensor.from_row_lengths(coords, n_items)

        if self.ragged is True:
            return {
                "nodes": nodes_t,
                "coords": coords_t,
            }
        else:
            # an explicitly given n_hits_padded takes precedence over the
            # value inferred from the non-indexed dataset shape
            if self.n_hits_padded is not None:
                n_hits_padded = self.n_hits_padded
            if n_hits_padded is None:
                raise ValueError("Have to give n_hits_padded if ragged is False!")

            sh = [nodes_t.shape[0], n_hits_padded]
            return {
                "nodes": nodes_t.to_tensor(
                    default_value=0.0, shape=sh + [nodes_t.shape[-1]]
                ),
                # ones for real hits, zero-padded to the fixed length
                "is_valid": tf.ones_like(nodes_t[:, :, 0]).to_tensor(
                    default_value=0.0, shape=sh
                ),
                "coords": coords_t.to_tensor(
                    default_value=0.0, shape=sh + [coords_t.shape[-1]]
                ),
            }
237def _pad_disjoint(x, n_items, min_items):
238 """Pad disjoint graphs to have a minimum number of hits per event."""
239 n_items = np.array(n_items)
240 missing = np.clip(min_items - n_items, 0, None)
241 for batchno in np.where(missing > 0)[0]:
242 warnings.warn(
243 f"Event has too few hits! Needed {min_items}, "
244 f"had {n_items[batchno]}! Padding..."
245 )
246 cumu = np.concatenate(
247 [
248 [
249 0,
250 ],
251 n_items.cumsum(),
252 ]
253 )
254 first_hit = x[cumu[batchno]]
255 x = np.insert(
256 x,
257 cumu[batchno + 1],
258 np.repeat(first_hit[None, :], missing[batchno], axis=0),
259 axis=0,
260 )
261 n_items[batchno] = min_items
262 return x, n_items