Coverage for orcasong/tools/shuffle2.py: 95% (167 statements) — coverage.py v7.2.7, created at 2024-10-03 18:23 +0000
1import os
2import time
3import datetime
5import numpy as np
6import psutil
7import h5py
8from km3pipe.sys import peak_memory_usage
9import awkward as ak
11from orcasong.tools.postproc import get_filepath_output, copy_used_files
12from orcasong.tools.concatenate import copy_attrs
15__author__ = "Stefan Reck"
def h5shuffle2(
    input_file,
    output_file=None,
    iterations=None,
    datasets=("x", "y"),
    max_ram_fraction=0.25,
    max_ram=None,
    seed=42,
):
    """
    Shuffle datasets in a h5file that have the same length.

    Parameters
    ----------
    input_file : str
        Path of the file that will be shuffle.
    output_file : str, optional
        If given, this will be the name of the output file.
        Otherwise, a name is auto generated.
    iterations : int, optional
        Shuffle the file this many times. For each additional iteration,
        a temporary file will be created and then deleted afterwards.
        Default: Auto choose best number based on available RAM.
    datasets : tuple
        Which datasets to include in output.
    max_ram : int, optional
        Available ram in bytes. Default: Use fraction of
        maximum available (see max_ram_fraction).
    max_ram_fraction : float
        in [0, 1]. Fraction of RAM to use for reading one batch of data
        when max_ram is None. Note: when using chunks, this should
        be <=~0.25, since lots of ram is needed for in-memory shuffling.
    seed : int or None
        Seed for randomness. None means: do not seed.

    Returns
    -------
    output_file : str
        Path to the output file.

    """
    if output_file is None:
        output_file = get_filepath_output(input_file, shuffle=True)
    if iterations is None:
        iterations = get_n_iterations(
            input_file,
            datasets=datasets,
            max_ram_fraction=max_ram_fraction,
            max_ram=max_ram,
        )
    # filenames of all iterations, in the right order
    filenames = (
        input_file,
        *_get_temp_filenames(output_file, number=iterations - 1),
        output_file,
    )
    # check against None explicitly: 0 is a valid seed, and the docstring
    # allows any int ('if seed:' would silently skip seeding for seed=0)
    if seed is not None:
        np.random.seed(seed)
    for i in range(iterations):
        print(f"\nIteration {i+1}/{iterations}")
        # intermediate files are both read and (afterwards) deleted;
        # only the very first input file is kept
        _shuffle_file(
            input_file=filenames[i],
            output_file=filenames[i + 1],
            delete=i > 0,
            datasets=datasets,
            max_ram=max_ram,
            max_ram_fraction=max_ram_fraction,
        )
    return output_file
def _shuffle_file(
    input_file,
    output_file,
    datasets=("x", "y"),
    max_ram=None,
    max_ram_fraction=0.25,
    delete=False,
):
    """
    Shuffle the given datasets from input_file into output_file (one pass).

    Parameters
    ----------
    input_file : str
        File to read from.
    output_file : str
        File to create; must not exist yet.
    datasets : tuple
        Datasets to shuffle. Associated "<name>_indices" datasets are
        discovered and added automatically.
    max_ram : int, optional
        Available ram in bytes. Default: use max_ram_fraction.
    max_ram_fraction : float
        Fraction of available RAM to use when max_ram is None.
    delete : bool
        If True, remove input_file afterwards (used for the temporary
        files of intermediate iterations).

    Returns
    -------
    output_file : str
        Path of the shuffled file.

    """
    start_time = time.time()
    if os.path.exists(output_file):
        raise FileExistsError(output_file)
    if max_ram is None:
        max_ram = get_max_ram(max_ram_fraction)
    # create file with temp name first, then rename afterwards
    temp_output_file = (
        output_file + "_temp_" + time.strftime("%d-%m-%Y-%H-%M-%S", time.gmtime())
    )
    with h5py.File(input_file, "r") as f_in:
        dsets = (*datasets, *_get_indexed_datasets(f_in, datasets))
        _check_dsets(f_in, dsets)
        # the dataset that needs the most batches determines the batching;
        # the same shuffled indices are applied to every dataset
        dset_info = _get_largest_dset(f_in, dsets, max_ram)
        print(f"Shuffling datasets {dsets}")
        indices_per_batch = _get_indices_per_batch(
            dset_info["n_batches"],
            dset_info["n_chunks"],
            dset_info["chunksize"],
        )

        with h5py.File(temp_output_file, "x") as f_out:
            for dset_name in dsets:
                print("Creating dataset", dset_name)
                _shuffle_dset(f_out, f_in, dset_name, indices_per_batch)
                print("Done!")

    # carry over provenance info and attributes before the final rename
    copy_used_files(input_file, temp_output_file)
    copy_attrs(input_file, temp_output_file)
    os.rename(temp_output_file, output_file)
    if delete:
        os.remove(input_file)
    print(
        f"Elapsed time: " f"{datetime.timedelta(seconds=int(time.time() - start_time))}"
    )
    return output_file
def get_max_ram(max_ram_fraction):
    """Return the given fraction of the currently available RAM, in bytes."""
    available_bytes = psutil.virtual_memory().available
    max_ram = available_bytes * max_ram_fraction
    print(f"Using {max_ram_fraction:.2%} of available ram = {max_ram} bytes")
    return max_ram
def get_n_iterations(
    input_file, datasets=("x", "y"), max_ram=None, max_ram_fraction=0.25
):
    """Get how often you have to shuffle with given ram to get proper randomness."""
    if max_ram is None:
        max_ram = get_max_ram(max_ram_fraction=max_ram_fraction)
    with h5py.File(input_file, "r") as f_in:
        dset_info = _get_largest_dset(f_in, datasets, max_ram)
    # one pass fully mixes chunks_per_batch chunks at a time, so full
    # randomness needs ~log_{chunks_per_batch}(n_chunks) passes (at least 1)
    mix_exponent = np.log(dset_info["n_chunks"]) / np.log(dset_info["chunks_per_batch"])
    n_iterations = np.amax((1, int(np.ceil(mix_exponent))))
    print(f"Largest dataset: {dset_info['name']}")
    print(f"Total chunks: {dset_info['n_chunks']}")
    print(f"Max. chunks per batch: {dset_info['chunks_per_batch']}")
    print(f"--> min iterations for full shuffle: {n_iterations}")
    return n_iterations
166def _get_indices_per_batch(n_batches, n_chunks, chunksize):
167 """
168 Return a list with the shuffled indices for each batch.
170 Returns
171 -------
172 indices_per_batch : List
173 Length n_batches, each element is a np.array[int].
174 Element i of the list are the indices of each sample in batch number i.
176 """
177 chunk_indices = np.arange(n_chunks)
178 np.random.shuffle(chunk_indices)
179 chunk_batches = np.array_split(chunk_indices, n_batches)
181 indices_per_batch = []
182 for bat in chunk_batches:
183 idx = (bat[:, None] * chunksize + np.arange(chunksize)[None, :]).flatten()
184 np.random.shuffle(idx)
185 indices_per_batch.append(idx)
187 return indices_per_batch
def _get_largest_dset(f, datasets, max_ram):
    """
    Get infos about the dset that needs the most batches.
    This is the dset that determines how many samples are shuffled at a time.
    """
    # max() returns the first maximum, same as np.argmax would pick
    return max(
        _get_dset_infos(f, datasets, max_ram),
        key=lambda info: info["n_batches"],
    )
def _check_dsets(f, datasets):
    """Raise ValueError if the given datasets don't all have the same length."""
    n_lines_list = []
    for dset_name in datasets:
        # for indexed datasets, the length of the indices dataset counts
        if dset_is_indexed(f, dset_name):
            dset_name = f"{dset_name}_indices"
        n_lines_list.append(len(f[dset_name]))
    if len(set(n_lines_list)) > 1:
        raise ValueError(f"Datasets have different lengths! " f"{n_lines_list}")
def _get_indexed_datasets(f, datasets):
    """Return the names of the "_indices" datasets belonging to indexed dsets."""
    return [
        f"{dset_name}_indices"
        for dset_name in datasets
        if dset_is_indexed(f, dset_name)
    ]
def _get_dset_infos(f, datasets, max_ram):
    """Retrieve infos for each dataset.

    For each non-"_indices" dataset: estimate bytes per line, derive how
    many hdf5 chunks fit into max_ram, and from that the number of batches
    needed to cover the whole dataset.

    Returns a list of dicts with keys
    "name", "n_chunks", "chunks_per_batch", "n_batches", "chunksize".
    Raises ValueError if fewer than 2 chunks fit into max_ram.
    """
    dset_infos = []
    for i, name in enumerate(datasets):
        # "_indices" datasets are handled together with their parent dset
        if name.endswith("_indices"):
            continue
        if dset_is_indexed(f, name):
            # for indexed dataset: take average bytes in x per line in x_indices
            dset_data = f[name]
            name = f"{name}_indices"
            dset = f[name]
            bytes_per_line = np.asarray(dset[0]).nbytes * len(dset_data) / len(dset)
        else:
            dset = f[name]
            bytes_per_line = np.asarray(dset[0]).nbytes

        n_lines = len(dset)
        # chunks[0]: hdf5 chunk length along axis 0 (lines per chunk)
        chunksize = dset.chunks[0]
        n_chunks = int(np.ceil(n_lines / chunksize))
        bytes_per_chunk = bytes_per_line * chunksize
        chunks_per_batch = int(np.floor(max_ram / bytes_per_chunk))
        # at least two whole chunks must fit in memory, otherwise no
        # shuffling between chunks is possible
        if chunks_per_batch < 2:
            raise ValueError(
                "Maximum usable RAM is {:.1f} MB, but one chunk of data is {:.1f} MB. "
                "This means shuffle can not be done, as at least two whole chunks "
                "need to fit in memory at a time. "
                "Try allocating more RAM, e.g. with the --max_ram option!".format(
                    max_ram/1e6, bytes_per_chunk/1e6))

        dset_infos.append(
            {
                "name": name,
                "n_chunks": n_chunks,
                "chunks_per_batch": chunks_per_batch,
                "n_batches": int(np.ceil(n_chunks / chunks_per_batch)),
                "chunksize": chunksize,
            }
        )

    return dset_infos
def dset_is_indexed(f, dset_name):
    """Whether the dataset has its "indexed" attribute set.

    Raises KeyError if it is set but the "_indices" dataset is missing.
    """
    if not f[dset_name].attrs.get("indexed"):
        return False
    if f"{dset_name}_indices" not in f:
        raise KeyError(
            f"{dset_name} is indexed, but {dset_name}_indices is missing!"
        )
    return True
def _shuffle_dset(f_out, f_in, dset_name, indices_per_batch):
    """
    Create a batchwise-shuffled dataset in the output file using given indices.

    The output dataset is created on the first batch with the input's
    chunking/compression settings and filled batch by batch afterwards.
    """
    dset_in = f_in[dset_name]
    start_idx = 0  # next write position in the output dataset
    running_index = 0  # first "index" value of the next batch (indexed dsets)
    for batch_number, indices in enumerate(indices_per_batch):
        print(f"Processing batch {batch_number+1}/{len(indices_per_batch)}")
        # remove indices outside of dset
        if dset_is_indexed(f_in, dset_name):
            max_index = len(f_in[f"{dset_name}_indices"])
        else:
            max_index = len(dset_in)
        indices = indices[indices < max_index]

        # reading has to be done with linearly increasing index
        # fancy indexing is super slow
        # so sort -> turn to slices -> read -> conc -> undo sorting
        sort_ix = np.argsort(indices)
        unsort_ix = np.argsort(sort_ix)
        fancy_indices = indices[sort_ix]
        slices = _slicify(fancy_indices)

        if dset_is_indexed(f_in, dset_name):
            # special treatment for indexed: slice based on indices dataset
            dset_name_indexed = f"{dset_name}_indices"
            slices_indices = [f_in[dset_name_indexed][slc] for slc in slices]
            data_indices = np.concatenate(slices_indices)
            # sanity check: "index" must be strictly increasing, otherwise
            # the slice resolution below would read wrong ranges
            if any(np.diff(data_indices["index"]) <= 0):
                raise ValueError(f"'index' in {dset_name_indexed} is not increasing for every event!")
            data = np.concatenate(
                [dset_in[slice(*_resolve_indexed(slc))] for slc in slices_indices]
            )
            # convert to 3d awkward array, then shuffle, then back to numpy
            data_ak = ak.unflatten(data, data_indices["n_items"])
            data = ak.flatten(data_ak[unsort_ix], axis=1).to_numpy()

        else:
            data = np.concatenate([dset_in[slc] for slc in slices])
            data = data[unsort_ix]

        if dset_name.endswith("_indices"):
            # recalculate "index" so items stay contiguous in shuffled order
            data["index"] = running_index + np.concatenate(
                [[0], np.cumsum(data["n_items"][:-1])]
            )
            # NOTE(review): sums ALL fields of the last record; equals
            # index + n_items only if those are the sole fields — confirm
            running_index = sum(data[-1])

        if batch_number == 0:
            # create output dataset with the input's storage settings,
            # then grow it to full size so later batches can be written in
            out_dset = f_out.create_dataset(
                dset_name,
                data=data,
                maxshape=dset_in.shape,
                chunks=dset_in.chunks,
                compression=dset_in.compression,
                compression_opts=dset_in.compression_opts,
                shuffle=dset_in.shuffle,
            )
            out_dset.resize(len(dset_in), axis=0)
            start_idx = len(data)
        else:
            end_idx = start_idx + len(data)
            f_out[dset_name][start_idx:end_idx] = data
            start_idx = end_idx

        print("Memory peak: {0:.3f} MB".format(peak_memory_usage()))

    if start_idx != len(dset_in):
        print(f"Warning: last index was {start_idx} not {len(dset_in)}")
346def _slicify(fancy_indices):
347 """[0,1,2, 6,7,8] --> [0:3, 6:9]"""
348 steps = np.diff(fancy_indices) != 1
349 slice_starts = np.concatenate([fancy_indices[:1], fancy_indices[1:][steps]])
350 slice_ends = np.concatenate([fancy_indices[:-1][steps], fancy_indices[-1:]]) + 1
351 return [slice(slice_starts[i], slice_ends[i]) for i in range(len(slice_starts))]
354def _resolve_indexed(ind):
355 # based on slice of x_indices, get where to slice in x
356 return ind["index"][0], ind["index"][-1] + ind["n_items"][-1]
359def _get_temp_filenames(output_file, number):
360 path, file = os.path.split(output_file)
361 return [os.path.join(path, f"temp_iteration_{i}_{file}") for i in range(number)]
def run_parser():
    """Removed CLI entry point; kept only to point users at the new command."""
    # TODO deprecated
    message = "h5shuffle2 has been renamed to orcasong h5shuffle2"
    raise NotImplementedError(message)