Coverage for orcasong/tools/make_data_split.py: 84%

164 statements  

coverage.py v7.2.7, created at 2024-10-03 18:23 +0000

#!/usr/bin/env python
# -*- coding: utf-8 -*-


__author__ = "Michael Moser, Daniel Guderian"

import os
import warnings
import toml
import argparse
import h5py
import random
import numpy as np


16def add_parser(subparsers): 

17 parser = subparsers.add_parser( 

18 "make_data_split", 

19 description="Create datasets based on the run_id's." 

20 "Use the config to add input folder and set the ranges." 

21 "Outputs a list in an txt file that can be used to " 

22 "concatenate the files specfied", 

23 ) 

24 parser.add_argument( 

25 "config_file", type=str, help="See example config for detailed information" 

26 ) 

27 parser.set_defaults(func=make_split) 

28 

29 
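# The code below reads top-level cfg keys like "n_files_train",
# "output_file_folder" and per-group "run_ids_<dsplit>" ranges, so an input
# .toml might look roughly like this (a minimal sketch with illustrative
# paths and values, not the official example config):
#
#   n_files_train = 1
#   n_files_validate = 1
#   output_file_folder = "/path/to/output"
#   output_file_name = "my_split"
#   print_only = false
#   make_qsub_bash_files = false
#
#   [neutrino_files]
#   dir = "/path/to/neutrino/h5/files"
#   run_ids_train = [1, 1000]
#   run_ids_validate = [1001, 1200]
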

def get_all_ip_group_keys(cfg):
    """
    Gets the keys of all input groups in the config dict.

    The input groups are defined as the dict elements whose values are themselves dicts.

    Parameters
    ----------
    cfg : dict
        Dict that contains all configuration options and additional information.

    Returns
    -------
    ip_group_keys : list
        List of the input_group keys.

    """
    ip_group_keys = []
    for key in cfg:
        if isinstance(cfg[key], dict):
            ip_group_keys.append(key)

    return ip_group_keys

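# For the config sketched above, only "[neutrino_files]" is parsed by toml
# into a dict value, so get_all_ip_group_keys(cfg) would return
# ["neutrino_files"]; the scalar options are skipped.
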

def get_h5_filepaths(dirpath):
    """
    Returns the filepaths of all .h5 files that are located in a specific directory.

    Parameters
    ----------
    dirpath : str
        Path of the directory where the .h5 files are located.

    Returns
    -------
    filepaths : list
        List with the full filepaths of all .h5 files in the dirpath folder.

    """
    filepaths = []
    for f in os.listdir(dirpath):
        if f.endswith(".h5"):
            filepaths.append(dirpath + "/" + f)

    # randomize the order, reproducibly thanks to the fixed seed
    random.Random(42).shuffle(filepaths)

    return filepaths

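# Because the shuffle uses a fixed seed, repeated calls on the same directory
# yield the same pseudo-random order, e.g. (illustrative paths):
#
#   get_h5_filepaths("/data/mc")  # -> ["/data/mc/run_3.h5", "/data/mc/run_1.h5", ...]
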

def get_number_of_evts(file, dataset_key="y"):
    """
    Returns the number of events of a file, based on the given dataset.

    Parameters
    ----------
    file : str
        Filepath of the h5 file to read the number of events from.
    dataset_key : str
        Specifies which dataset in the h5 file should be used for calculating the number of events.

    Returns
    -------
    n_evts : int
        The number of events in that file.

    """
    with h5py.File(file, "r") as f:
        n_evts = f[dataset_key].shape[0]

    return n_evts


def get_number_of_evts_and_run_ids(
    list_of_files, dataset_key="y", run_id_col_name="run_id"
):
    """
    Gets the number of events and the run_ids for all hdf5 files in the list_of_files.

    The number of events is calculated based on the dataset that is specified with the dataset_key parameter.

    Parameters
    ----------
    list_of_files : list
        List which contains filepaths to h5 files.
    dataset_key : str
        Specifies which dataset in a h5 file should be used for calculating the number of events.
    run_id_col_name : str
        The column name of the 'run_id' column.

    Returns
    -------
    total_number_of_evts : int
        The cumulative (total) number of events.
    mean_number_of_evts_per_file : float
        The mean number of events per file.
    run_ids : list
        List containing the run_ids of the files in the list_of_files.

    """
    total_number_of_evts = 0
    run_ids = []

    for fpath in list_of_files:
        with h5py.File(fpath, "r") as f:
            dset = f[dataset_key]
            total_number_of_evts += dset.shape[0]
            run_ids.append(dset[0][run_id_col_name])

    mean_number_of_evts_per_file = total_number_of_evts / len(list_of_files)

    return total_number_of_evts, mean_number_of_evts_per_file, run_ids

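# Illustrative example: for three files with 100, 120 and 80 events in their
# "y" datasets and run_ids 12, 13 and 14, this returns (300, 100.0, [12, 13, 14]).
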

def split(a, n):
    """
    Splits a list into n equally sized (if possible! if not, approximately equal) chunks.

    Parameters
    ----------
    a : list
        A list that should be split.
    n : int
        Number of chunks the input list should be split into.

    Returns
    -------
    a_split : list
        The input list a, split into n chunks.

    """
    # from https://stackoverflow.com/questions/2130016/splitting-a-list-into-n-parts-of-approximately-equal-length
    k, m = divmod(len(a), n)
    a_split = [a[i * k + min(i, m) : (i + 1) * k + min(i + 1, m)] for i in range(n)]
    return a_split

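# Worked example: split(list(range(10)), 3) returns
# [[0, 1, 2, 3], [4, 5, 6], [7, 8, 9]]; the first len(a) % n chunks
# each get one extra element.
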

def print_input_statistics(cfg, ip_group_keys):
    """
    Prints some useful information for each input_group.

    Parameters
    ----------
    cfg : dict
        Dict that contains all configuration options and additional information.
    ip_group_keys : list
        List of the input_group keys.

    """
    print("----------------------------------------------------------------------")
    print("Printing input statistics for your " + cfg["toml_filename"] + " input:")
    print("----------------------------------------------------------------------")

    print(
        "Your input .toml file has the following data input groups: "
        + str(ip_group_keys)
    )
    print("Total number of events: " + str(cfg["n_evts_total"]))

    for key in ip_group_keys:
        print("--------------------------------------------------------------------")
        print("Info for group " + key + ":")
        print("Directory: " + cfg[key]["dir"])
        print("Total number of files: " + str(cfg[key]["n_files"]))
        print("Total number of events: " + str(cfg[key]["n_evts"]))
        print(
            "Mean number of events per file: "
            + str(round(cfg[key]["n_evts_per_file_mean"], 3))
        )
        print("--------------------------------------------------------------------")


def add_fpaths_for_data_split_to_cfg(cfg, key):
    """
    Adds all the filepaths for the output files into a list, and puts them into the cfg['output_dsplit'][key] location
    for all dsplits (train, validate, rest).

    Parameters
    ----------
    cfg : dict
        Dict that contains all configuration options and additional information.
    key : str
        The key of an input_group.

    """
    fpath_lists = {"train": [], "validate": [], "rest": []}
    for i, fpath in enumerate(cfg[key]["fpaths"]):

        run_id = cfg[key]["run_ids"][i]

        for dsplit in ["train", "validate", "rest"]:
            if "run_ids_" + dsplit in cfg[key]:
                if (
                    cfg[key]["run_ids_" + dsplit][0]
                    <= run_id
                    <= cfg[key]["run_ids_" + dsplit][1]
                ):
                    fpath_lists[dsplit].append(fpath)

    for dsplit in ["train", "validate", "rest"]:
        if len(fpath_lists[dsplit]) == 0:
            continue

        n_files_dsplit = cfg["n_files_" + dsplit]
        fpath_lists[dsplit] = split(fpath_lists[dsplit], n_files_dsplit)
        if "output_" + dsplit not in cfg:
            cfg["output_" + dsplit] = dict()
        cfg["output_" + dsplit][key] = fpath_lists[dsplit]

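# Illustrative example: with run_ids_train = [1, 1000] in the config, a file
# whose run_id is 42 lands in the train split; with n_files_train = 2,
# cfg["output_train"][key] then holds two roughly equally sized lists of
# filepaths, one per output file.
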

def make_dsplit_list_files(cfg):
    """
    Writes the .txt list files of the datasplits to disk, based on the information in the cfg['output_dsplit'] dict.

    Parameters
    ----------
    cfg : dict
        Dict that contains all configuration options and additional information.

    """
    # check if the /conc_list_files folder exists, if not create it.
    if not os.path.exists(cfg["output_file_folder"] + "/conc_list_files"):
        os.makedirs(cfg["output_file_folder"] + "/conc_list_files")

    print()
    print()
    print("In a run-by-run MC, the run_ids might not be contiguous.")
    print("Here are the actual numbers in the split sets:")
    print("----------------------------------------------")

    # loop over the different specified sets
    for dsplit in ["train", "validate", "rest"]:

        if "output_" + dsplit not in cfg:
            continue

        print(dsplit, "set:")

        input_groups_dict = cfg["output_" + dsplit]
        keys = list(input_groups_dict.keys())
        n_output_files = len(input_groups_dict[keys[0]])

        # initialize the event counters for all input groups
        final_number_of_events = np.zeros(len(input_groups_dict))

        # loop over the number of output files for each set
        for i in range(n_output_files):
            fpath_output = (
                cfg["output_file_folder"]
                + "/conc_list_files/"
                + cfg["output_file_name"]
                + "_"
                + dsplit
                + "_"
                + str(i)
                + ".txt"
            )

            # save the txt list
            if "output_lists" not in cfg:
                cfg["output_lists"] = list()
            cfg["output_lists"].append(fpath_output)

            with open(fpath_output, "w") as f_out:
                for j in range(len(keys)):
                    for fpath in input_groups_dict[keys[j]][i]:
                        # also count the actual sizes here
                        final_number_of_events[j] += get_number_of_evts(fpath)
                        f_out.write(fpath + "\n")

        # and then print the actual event counts per input group
        for j in range(len(keys)):
            print(keys[j], ":", int(final_number_of_events[j]))

    print("----------------------------------------------")


def make_concatenate_and_shuffle_scripts(cfg):
    """
    Writes qsub .sh job scripts which concatenate and shuffle all files inside the list files.

    Parameters
    ----------
    cfg : dict
        Dict that contains all configuration options and additional information.

    """
    dirpath = cfg["output_file_folder"]

    # check if the /logs, /job_scripts and /data_split folders exist, if not create them.
    if not os.path.exists(dirpath + "/logs"):
        os.makedirs(dirpath + "/logs")
    if not os.path.exists(dirpath + "/job_scripts"):
        os.makedirs(dirpath + "/job_scripts")
    if not os.path.exists(dirpath + "/data_split"):
        os.makedirs(dirpath + "/data_split")

    # make the qsub .sh files for concatenating
    for listfile_fpath in cfg["output_lists"]:
        listfile_fname = os.path.basename(listfile_fpath)
        listfile_fname_wout_ext = os.path.splitext(listfile_fname)[0]
        conc_outputfile_fpath = (
            cfg["output_file_folder"] + "/data_split/" + listfile_fname_wout_ext + ".h5"
        )

        fpath_bash_script = (
            dirpath + "/job_scripts/concatenate_h5_" + listfile_fname_wout_ext + ".sh"
        )

        with open(fpath_bash_script, "w") as f:
            f.write("#!/usr/bin/env bash\n")
            f.write("\n")
            f.write("source " + cfg["venv_path"] + "activate\n")
            f.write("\n")
            f.write("# Concatenate the files in the list\n")
            f.write(
                "orcasong concatenate " + listfile_fpath + " --outfile " + conc_outputfile_fpath
            )

    # make the qsub .sh files for shuffling
    for listfile_fpath in cfg["output_lists"]:
        listfile_fname = os.path.basename(listfile_fpath)
        listfile_fname_wout_ext = os.path.splitext(listfile_fname)[0]

        # This is the input for the shuffle tool!
        conc_outputfile_fpath = (
            cfg["output_file_folder"] + "/data_split/" + listfile_fname_wout_ext + ".h5"
        )

        fpath_bash_script = (
            dirpath + "/job_scripts/shuffle_h5_" + listfile_fname_wout_ext + ".sh"
        )

        with open(fpath_bash_script, "w") as f:
            f.write("#!/usr/bin/env bash\n")
            f.write("\n")
            f.write("source " + cfg["venv_path"] + "activate\n")
            f.write("\n")
            f.write("# Shuffle the h5 file\n")
            f.write("orcasong h5shuffle2 " + conc_outputfile_fpath)

            if cfg["shuffle_delete"]:
                f.write("\n")
                f.write("rm " + conc_outputfile_fpath + "\n")

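# A generated concatenate job script would look roughly like this (paths are
# illustrative and assume venv_path ends in ".../bin/"; the exact content
# follows the f.write calls above):
#
#   #!/usr/bin/env bash
#
#   source /path/to/venv/bin/activate
#
#   # Concatenate the files in the list
#   orcasong concatenate .../my_split_train_0.txt --outfile .../my_split_train_0.h5
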

def make_split(config_file):
    # decode the config
    cfg = toml.load(config_file)
    cfg["toml_filename"] = config_file

    # read out all the input groups
    ip_group_keys = get_all_ip_group_keys(cfg)

    # and now loop over the input groups, extracting the info
    n_evts_total = 0
    for key in ip_group_keys:
        print("Collecting information from input group " + key)
        cfg[key]["fpaths"] = get_h5_filepaths(cfg[key]["dir"])
        cfg[key]["n_files"] = len(cfg[key]["fpaths"])
        (
            cfg[key]["n_evts"],
            cfg[key]["n_evts_per_file_mean"],
            cfg[key]["run_ids"],
        ) = get_number_of_evts_and_run_ids(cfg[key]["fpaths"], dataset_key="y")

        n_evts_total += cfg[key]["n_evts"]

    cfg["n_evts_total"] = n_evts_total
    # print the extracted statistics
    print_input_statistics(cfg, ip_group_keys)

    if cfg["print_only"] is True:
        from sys import exit

        exit()

    for key in ip_group_keys:
        add_fpaths_for_data_split_to_cfg(cfg, key)

    # create the list files
    make_dsplit_list_files(cfg)

    # create bash scripts that can be submitted to do the concatenation and shuffle
    if cfg["make_qsub_bash_files"] is True:
        make_concatenate_and_shuffle_scripts(cfg)
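
# A minimal usage sketch (the subcommand name matches the parser registered in
# add_parser above; the config filename is illustrative):
#
#   orcasong make_data_split my_datasplit_config.toml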