Coverage for orcasong/tools/shuffle2.py: 95%

167 statements  

« prev     ^ index     » next       coverage.py v7.2.7, created at 2024-10-03 18:23 +0000

1import os 

2import time 

3import datetime 

4 

5import numpy as np 

6import psutil 

7import h5py 

8from km3pipe.sys import peak_memory_usage 

9import awkward as ak 

10 

11from orcasong.tools.postproc import get_filepath_output, copy_used_files 

12from orcasong.tools.concatenate import copy_attrs 

13 

14 

15__author__ = "Stefan Reck" 

16 

17 

def h5shuffle2(
    input_file,
    output_file=None,
    iterations=None,
    datasets=("x", "y"),
    max_ram_fraction=0.25,
    max_ram=None,
    seed=42,
):
    """
    Shuffle datasets in a h5file that have the same length.

    Parameters
    ----------
    input_file : str
        Path of the file that will be shuffled.
    output_file : str, optional
        If given, this will be the name of the output file.
        Otherwise, a name is auto generated.
    iterations : int, optional
        Shuffle the file this many times. For each additional iteration,
        a temporary file will be created and then deleted afterwards.
        Default: Auto choose best number based on available RAM.
    datasets : tuple
        Which datasets to include in output.
    max_ram_fraction : float
        in [0, 1]. Fraction of RAM to use for reading one batch of data
        when max_ram is None. Note: when using chunks, this should
        be <=~0.25, since lots of ram is needed for in-memory shuffling.
    max_ram : int, optional
        Available ram in bytes. Default: Use fraction of
        maximum available (see max_ram_fraction).
    seed : int or None
        Seed for randomness.

    Returns
    -------
    output_file : str
        Path to the output file.

    """
    if output_file is None:
        output_file = get_filepath_output(input_file, shuffle=True)
    if iterations is None:
        iterations = get_n_iterations(
            input_file,
            datasets=datasets,
            max_ram_fraction=max_ram_fraction,
            max_ram=max_ram,
        )
    # filenames of all iterations, in the right order:
    # input --> (iterations - 1) temp files --> final output
    filenames = (
        input_file,
        *_get_temp_filenames(output_file, number=iterations - 1),
        output_file,
    )
    # fix: was "if seed:", which silently ignored seed=0 even though the
    # docstring declares any int (or None) to be a valid seed
    if seed is not None:
        np.random.seed(seed)
    for i in range(iterations):
        print(f"\nIteration {i+1}/{iterations}")
        _shuffle_file(
            input_file=filenames[i],
            output_file=filenames[i + 1],
            delete=i > 0,  # never delete the original input, only temp files
            datasets=datasets,
            max_ram=max_ram,
            max_ram_fraction=max_ram_fraction,
        )
    return output_file

87 

88 

def _shuffle_file(
    input_file,
    output_file,
    datasets=("x", "y"),
    max_ram=None,
    max_ram_fraction=0.25,
    delete=False,
):
    """Shuffle the given datasets of input_file once, writing to output_file."""
    t_start = time.time()
    if os.path.exists(output_file):
        raise FileExistsError(output_file)
    if max_ram is None:
        max_ram = get_max_ram(max_ram_fraction)
    # write to a temporary filename first; rename only once everything worked
    temp_name = "{}_temp_{}".format(
        output_file, time.strftime("%d-%m-%Y-%H-%M-%S", time.gmtime())
    )
    with h5py.File(input_file, "r") as f_in:
        all_dsets = (*datasets, *_get_indexed_datasets(f_in, datasets))
        _check_dsets(f_in, all_dsets)
        largest = _get_largest_dset(f_in, all_dsets, max_ram)
        print(f"Shuffling datasets {all_dsets}")
        batch_indices = _get_indices_per_batch(
            largest["n_batches"],
            largest["n_chunks"],
            largest["chunksize"],
        )

        with h5py.File(temp_name, "x") as f_out:
            for name in all_dsets:
                print("Creating dataset", name)
                _shuffle_dset(f_out, f_in, name, batch_indices)
                print("Done!")

    copy_used_files(input_file, temp_name)
    copy_attrs(input_file, temp_name)
    os.rename(temp_name, output_file)
    if delete:
        os.remove(input_file)
    print(f"Elapsed time: {datetime.timedelta(seconds=int(time.time() - t_start))}")
    return output_file

132 

133 

def get_max_ram(max_ram_fraction):
    """Return the given fraction of the currently available RAM, in bytes."""
    available = psutil.virtual_memory().available
    max_ram = max_ram_fraction * available
    print(f"Using {max_ram_fraction:.2%} of available ram = {max_ram} bytes")
    return max_ram

138 

139 

def get_n_iterations(
    input_file, datasets=("x", "y"), max_ram=None, max_ram_fraction=0.25
):
    """Get how often you have to shuffle with given ram to get proper randomness."""
    if max_ram is None:
        max_ram = get_max_ram(max_ram_fraction=max_ram_fraction)
    with h5py.File(input_file, "r") as f_in:
        dset_info = _get_largest_dset(f_in, datasets, max_ram)
    # each pass mixes chunks_per_batch chunks at a time, so fully mixing
    # n_chunks chunks takes log_{chunks_per_batch}(n_chunks) passes
    needed = np.log(dset_info["n_chunks"]) / np.log(dset_info["chunks_per_batch"])
    n_iterations = max(1, int(np.ceil(needed)))
    print(f"Largest dataset: {dset_info['name']}")
    print(f"Total chunks: {dset_info['n_chunks']}")
    print(f"Max. chunks per batch: {dset_info['chunks_per_batch']}")
    print(f"--> min iterations for full shuffle: {n_iterations}")
    return n_iterations

164 

165 

166def _get_indices_per_batch(n_batches, n_chunks, chunksize): 

167 """ 

168 Return a list with the shuffled indices for each batch. 

169 

170 Returns 

171 ------- 

172 indices_per_batch : List 

173 Length n_batches, each element is a np.array[int]. 

174 Element i of the list are the indices of each sample in batch number i. 

175 

176 """ 

177 chunk_indices = np.arange(n_chunks) 

178 np.random.shuffle(chunk_indices) 

179 chunk_batches = np.array_split(chunk_indices, n_batches) 

180 

181 indices_per_batch = [] 

182 for bat in chunk_batches: 

183 idx = (bat[:, None] * chunksize + np.arange(chunksize)[None, :]).flatten() 

184 np.random.shuffle(idx) 

185 indices_per_batch.append(idx) 

186 

187 return indices_per_batch 

188 

189 

def _get_largest_dset(f, datasets, max_ram):
    """
    Get infos about the dset that needs the most batches.
    This is the dset that determines how many samples are shuffled at a time.
    """
    infos = _get_dset_infos(f, datasets, max_ram)
    # ties resolve to the first maximum, like np.argmax did
    return max(infos, key=lambda info: info["n_batches"])

197 

198 

def _check_dsets(f, datasets):
    """Raise ValueError unless all given datasets have the same length."""
    lengths = []
    for name in datasets:
        if dset_is_indexed(f, name):
            # indexed datasets are counted via their _indices companion
            name = f"{name}_indices"
        lengths.append(len(f[name]))

    if len(set(lengths)) > 1:
        raise ValueError(f"Datasets have different lengths! {lengths}")

209 

210 

def _get_indexed_datasets(f, datasets):
    """Return the names of the _indices companions of all indexed datasets."""
    return [f"{name}_indices" for name in datasets if dset_is_indexed(f, name)]

217 

218 

def _get_dset_infos(f, datasets, max_ram):
    """
    Retrieve infos for each dataset.

    Parameters
    ----------
    f : h5py.File
        The opened input file.
    datasets : tuple
        Names of the datasets. ``*_indices`` entries are skipped, since
        they are handled together with their parent dataset.
    max_ram : int
        Available ram in bytes.

    Returns
    -------
    dset_infos : list
        One dict per dataset with keys name, n_chunks, chunks_per_batch,
        n_batches and chunksize.

    Raises
    ------
    ValueError
        If a dataset is not chunked, or a single chunk is too big for max_ram.

    """
    dset_infos = []
    for name in datasets:
        if name.endswith("_indices"):
            continue
        if dset_is_indexed(f, name):
            # for indexed dataset: take average bytes in x per line in x_indices
            dset_data = f[name]
            name = f"{name}_indices"
            dset = f[name]
            bytes_per_line = np.asarray(dset[0]).nbytes * len(dset_data) / len(dset)
        else:
            dset = f[name]
            bytes_per_line = np.asarray(dset[0]).nbytes

        if dset.chunks is None:
            # fix: contiguous datasets used to fail below with a cryptic
            # "'NoneType' object is not subscriptable"
            raise ValueError(
                f"Dataset {name} is not chunked! Can only shuffle chunked datasets."
            )
        n_lines = len(dset)
        chunksize = dset.chunks[0]
        n_chunks = int(np.ceil(n_lines / chunksize))
        bytes_per_chunk = bytes_per_line * chunksize
        # how many chunks fit into ram at once; at least 2 are needed to shuffle
        chunks_per_batch = int(np.floor(max_ram / bytes_per_chunk))
        if chunks_per_batch < 2:
            raise ValueError(
                "Maximum usable RAM is {:.1f} MB, but one chunk of data is {:.1f} MB. "
                "This means shuffle can not be done, as at least two whole chunks "
                "need to fit in memory at a time. "
                "Try allocating more RAM, e.g. with the --max_ram option!".format(
                    max_ram / 1e6, bytes_per_chunk / 1e6
                )
            )

        dset_infos.append(
            {
                "name": name,
                "n_chunks": n_chunks,
                "chunks_per_batch": chunks_per_batch,
                "n_batches": int(np.ceil(n_chunks / chunks_per_batch)),
                "chunksize": chunksize,
            }
        )

    return dset_infos

259 

260 

def dset_is_indexed(f, dset_name):
    """Whether the dataset is indexed, i.e. has an '_indices' companion dataset."""
    if not f[dset_name].attrs.get("indexed"):
        return False
    if f"{dset_name}_indices" not in f:
        raise KeyError(
            f"{dset_name} is indexed, but {dset_name}_indices is missing!"
        )
    return True

270 

271 

def _shuffle_dset(f_out, f_in, dset_name, indices_per_batch):
    """
    Create a batchwise-shuffled dataset in the output file using given indices.

    Parameters
    ----------
    f_out : h5py.File
        Output file, must be writable.
    f_in : h5py.File
        Input file.
    dset_name : str
        Name of the dataset to shuffle.
    indices_per_batch : List
        One np.array[int] of sample indices per batch
        (see _get_indices_per_batch).

    """
    dset_in = f_in[dset_name]
    # next write position in the output dataset
    start_idx = 0
    # total items written so far; used to recompute 'index' for *_indices dsets
    running_index = 0
    for batch_number, indices in enumerate(indices_per_batch):
        print(f"Processing batch {batch_number+1}/{len(indices_per_batch)}")
        # remove indices outside of dset (the last chunk may be partial)
        if dset_is_indexed(f_in, dset_name):
            max_index = len(f_in[f"{dset_name}_indices"])
        else:
            max_index = len(dset_in)
        indices = indices[indices < max_index]

        # reading has to be done with linearly increasing index
        # fancy indexing is super slow
        # so sort -> turn to slices -> read -> conc -> undo sorting
        sort_ix = np.argsort(indices)
        unsort_ix = np.argsort(sort_ix)
        fancy_indices = indices[sort_ix]
        slices = _slicify(fancy_indices)

        if dset_is_indexed(f_in, dset_name):
            # special treatment for indexed: slice based on indices dataset
            dset_name_indexed = f"{dset_name}_indices"
            slices_indices = [f_in[dset_name_indexed][slc] for slc in slices]
            data_indices = np.concatenate(slices_indices)
            # _resolve_indexed below requires strictly increasing 'index'
            if any(np.diff(data_indices["index"]) <= 0):
                raise ValueError(f"'index' in {dset_name_indexed} is not increasing for every event!")

            data = np.concatenate(
                [dset_in[slice(*_resolve_indexed(slc))] for slc in slices_indices]
            )
            # convert to 3d awkward array, then shuffle, then back to numpy
            data_ak = ak.unflatten(data, data_indices["n_items"])
            data = ak.flatten(data_ak[unsort_ix], axis=1).to_numpy()

        else:
            data = np.concatenate([dset_in[slc] for slc in slices])
            data = data[unsort_ix]

        if dset_name.endswith("_indices"):
            # recalculate 'index': cumulative item count in shuffled order,
            # offset by what previous batches already wrote
            data["index"] = running_index + np.concatenate(
                [[0], np.cumsum(data["n_items"][:-1])]
            )
            # NOTE(review): sums ALL fields of the last row; this equals
            # last 'index' + last 'n_items' only if those are the sole
            # fields of the dtype — confirm
            running_index = sum(data[-1])

        if batch_number == 0:
            # first batch: create output dataset with the input's layout,
            # then resize it to full length so later batches can be written
            out_dset = f_out.create_dataset(
                dset_name,
                data=data,
                maxshape=dset_in.shape,
                chunks=dset_in.chunks,
                compression=dset_in.compression,
                compression_opts=dset_in.compression_opts,
                shuffle=dset_in.shuffle,
            )
            out_dset.resize(len(dset_in), axis=0)
            start_idx = len(data)
        else:
            end_idx = start_idx + len(data)
            f_out[dset_name][start_idx:end_idx] = data
            start_idx = end_idx

        print("Memory peak: {0:.3f} MB".format(peak_memory_usage()))

    # sanity check: all input samples should have been written exactly once
    if start_idx != len(dset_in):
        print(f"Warning: last index was {start_idx} not {len(dset_in)}")

345 

346def _slicify(fancy_indices): 

347 """[0,1,2, 6,7,8] --> [0:3, 6:9]""" 

348 steps = np.diff(fancy_indices) != 1 

349 slice_starts = np.concatenate([fancy_indices[:1], fancy_indices[1:][steps]]) 

350 slice_ends = np.concatenate([fancy_indices[:-1][steps], fancy_indices[-1:]]) + 1 

351 return [slice(slice_starts[i], slice_ends[i]) for i in range(len(slice_starts))] 

352 

353 

354def _resolve_indexed(ind): 

355 # based on slice of x_indices, get where to slice in x 

356 return ind["index"][0], ind["index"][-1] + ind["n_items"][-1] 

357 

358 

359def _get_temp_filenames(output_file, number): 

360 path, file = os.path.split(output_file) 

361 return [os.path.join(path, f"temp_iteration_{i}_{file}") for i in range(number)] 

362 

363 

def run_parser():
    """Deprecated CLI entry point; the command moved into the orcasong CLI."""
    # TODO deprecated
    raise NotImplementedError("h5shuffle2 has been renamed to orcasong h5shuffle2")