Coverage for orcasong/tools/make_data_split.py: 84% (164 statements) — coverage.py v7.2.7, created at 2024-10-03 18:23 +0000
1#!/usr/bin/env python
2# -*- coding: utf-8 -*-
5__author__ = "Michael Moser, Daniel Guderian"
7import os
8import warnings
9import toml
10import argparse
11import h5py
12import random
13import numpy as np
def add_parser(subparsers):
    """
    Register the 'make_data_split' subcommand on the given subparsers object.

    Parameters
    ----------
    subparsers : argparse._SubParsersAction
        The subparsers collection (from ArgumentParser.add_subparsers) to
        attach this command to.

    """
    parser = subparsers.add_parser(
        "make_data_split",
        # Typos fixed in the user-visible help text: "specfied" -> "specified",
        # "an txt" -> "a txt".
        description="Create datasets based on the run_id's."
        "Use the config to add input folder and set the ranges."
        "Outputs a list in a txt file that can be used to "
        "concatenate the files specified",
    )
    parser.add_argument(
        "config_file", type=str, help="See example config for detailed information"
    )
    # Dispatch: the CLI machinery calls args.func(config_file) later.
    parser.set_defaults(func=make_split)
def get_all_ip_group_keys(cfg):
    """
    Gets the keys of all input groups in the config dict.

    The input groups are defined as the dict elements, where the values have the type of a dict.

    Parameters
    ----------
    cfg : dict
        Dict that contains all configuration options and additional information.

    Returns
    -------
    ip_group_keys : list
        List of the input_group keys.

    """
    # isinstance instead of type(...) == dict: also matches dict subclasses,
    # which toml decoders may return for tables.
    ip_group_keys = [key for key in cfg if isinstance(cfg[key], dict)]
    return ip_group_keys
def get_h5_filepaths(dirpath):
    """
    Returns the filepaths of all .h5 files that are located in a specific directory.

    The directory listing is sorted before the seeded shuffle: os.listdir
    returns entries in arbitrary, filesystem-dependent order, so without
    sorting the fixed-seed shuffle would produce different "random" orders
    on different machines/filesystems.

    Parameters
    ----------
    dirpath: str
        Path of the directory where the .h5 files are located.

    Returns
    -------
    filepaths : list
        List with the full filepaths of all .h5 files in the dirpath folder.

    """
    filepaths = [
        dirpath + "/" + f for f in sorted(os.listdir(dirpath)) if f.endswith(".h5")
    ]

    # randomize order (reproducibly: fixed seed + deterministic input order)
    random.Random(42).shuffle(filepaths)

    return filepaths
def get_number_of_evts(file, dataset_key="y"):
    """
    Returns the number of events of a file looking at the given dataset.

    Parameters
    ----------
    file : h5 file
        File to read the number of events from.
    dataset_key : str
        String which specifies, which dataset in a h5 file should be used for calculating the number of events.

    Returns
    -------
    n_evts : int
        The number of events in that file.

    """
    # Context manager instead of manual open/close: guarantees the file
    # handle is released even if the dataset lookup raises (e.g. bad key).
    with h5py.File(file, "r") as f:
        n_evts = f[dataset_key].shape[0]

    return n_evts
def get_number_of_evts_and_run_ids(
    list_of_files, dataset_key="y", run_id_col_name="run_id"
):
    """
    Gets the number of events and the run_ids for all hdf5 files in the list_of_files.

    The number of events is calculated based on the dataset, which is specified with the dataset_key parameter.

    Parameters
    ----------
    list_of_files : list
        List which contains filepaths to h5 files.
    dataset_key : str
        String which specifies, which dataset in a h5 file should be used for calculating the number of events.
    run_id_col_name : str
        String, which specifies the column name of the 'run_id' column.

    Returns
    -------
    total_number_of_evts : int
        The cumulative (total) number of events.
    mean_number_of_evts_per_file : float
        The mean number of evts per file.
    run_ids : list
        List containing the run_ids of the files in the list_of_files.

    """
    total_number_of_evts = 0
    run_ids = []

    for fpath in list_of_files:
        # "with" ensures the handle is closed even if a key lookup raises.
        with h5py.File(fpath, "r") as f:
            dset = f[dataset_key]
            total_number_of_evts += dset.shape[0]
            # Reuse dset rather than re-indexing f[dataset_key]: avoids a
            # second dataset lookup per file. Takes the run_id of the first
            # event as representative for the whole file.
            run_ids.append(dset[0][run_id_col_name])

    # NOTE(review): raises ZeroDivisionError for an empty list_of_files,
    # same as the original behavior.
    mean_number_of_evts_per_file = total_number_of_evts / len(list_of_files)

    return total_number_of_evts, mean_number_of_evts_per_file, run_ids
def split(a, n):
    """
    Splits a list into n equal sized (if possible! if not, approximately) chunks.

    The first len(a) % n chunks receive one extra element each, so chunk
    sizes differ by at most one.

    Parameters
    ----------
    a : list
        A list that should be split.
    n : int
        Number of times the input list should be split.

    Returns
    -------
    a_split : list
        The input list a, which has been split into n chunks.

    """
    # equivalent to https://stackoverflow.com/questions/2130016/splitting-a-list-into-n-parts-of-approximately-equal-length
    base_size, leftover = divmod(len(a), n)
    a_split = []
    start = 0
    for chunk_idx in range(n):
        # the first `leftover` chunks get one extra element
        end = start + base_size + (1 if chunk_idx < leftover else 0)
        a_split.append(a[start:end])
        start = end
    return a_split
def print_input_statistics(cfg, ip_group_keys):
    """
    Prints some useful information for each input_group.

    Parameters
    ----------
    cfg : dict
        Dict that contains all configuration options and additional information.
    ip_group_keys : list
        List of the input_group keys.

    """
    header_sep = "----------------------------------------------------------------------"
    group_sep = "--------------------------------------------------------------------"

    print(header_sep)
    print("Printing input statistics for your " + cfg["toml_filename"] + " input:")
    print(header_sep)

    print(
        "Your input .toml file has the following data input groups: "
        + str(ip_group_keys)
    )
    print("Total number of events: " + str(cfg["n_evts_total"]))

    for key in ip_group_keys:
        group = cfg[key]
        print(group_sep)
        print("Info for group " + key + ":")
        print("Directory: " + group["dir"])
        print("Total number of files: " + str(group["n_files"]))
        print("Total number of events: " + str(group["n_evts"]))
        print(
            "Mean number of events per file: "
            + str(round(group["n_evts_per_file_mean"], 3))
        )
        print(group_sep)
def add_fpaths_for_data_split_to_cfg(cfg, key):
    """
    Adds all the filepaths for the output files into a list, and puts them into the cfg['output_dsplit'][key] location
    for all dsplits (train, validate, rest).

    Parameters
    ----------
    cfg : dict
        Dict that contains all configuration options and additional information.
    key : str
        The key of an input_group.

    """
    group = cfg[key]
    collected = {"train": [], "validate": [], "rest": []}

    # Sort each file into the dsplit bucket(s) whose inclusive run_id range
    # [lo, hi] contains the file's run_id. Ranges are optional per dsplit.
    for fpath, run_id in zip(group["fpaths"], group["run_ids"]):
        for dsplit, bucket in collected.items():
            id_range = group.get("run_ids_" + dsplit)
            if id_range is not None and id_range[0] <= run_id <= id_range[1]:
                bucket.append(fpath)

    # Chunk each non-empty bucket into the configured number of output files
    # and store the result under cfg['output_<dsplit>'][key].
    for dsplit, bucket in collected.items():
        if not bucket:
            continue
        chunked = split(bucket, cfg["n_files_" + dsplit])
        cfg.setdefault("output_" + dsplit, dict())[key] = chunked
def make_dsplit_list_files(cfg):
    """
    Writes .list files of the datasplits to the disk, with the information in the cfg['output_dsplit'] dict.

    One .txt file is written per output file of each dsplit; it contains one
    input filepath per line, across all input groups. The paths of the
    written .txt files are appended to cfg['output_lists'], and per-group
    event counts (read from the files) are printed per dsplit.

    Parameters
    ----------
    cfg : dict
        Dict that contains all configuration options and additional information.

    """
    # check if //conc_list_files folder exists, if not create it.
    if not os.path.exists(cfg["output_file_folder"] + "/conc_list_files"):
        os.makedirs(cfg["output_file_folder"] + "/conc_list_files")

    print()
    print()
    print("In an run-by-run MC the run_id's might not be continuous.")
    print("Here are the actual numbers in the split sets:")
    print("----------------------------------------------")

    # loop over the different specified sets
    for dsplit in ["train", "validate", "rest"]:

        if "output_" + dsplit not in cfg:
            continue

        print(dsplit, "set:")

        # All groups are split into the same number of output files, so the
        # first group's chunk count is representative.
        first_key = list(cfg["output_" + dsplit].keys())[0]
        n_output_files = len(cfg["output_" + dsplit][first_key])

        # initialize counter of events for all input groups
        imput_groups_dict = cfg["output_" + dsplit]
        final_number_of_events = np.zeros(len(imput_groups_dict))

        # loop over the number of outputfiles for each set
        for i in range(n_output_files):
            fpath_output = (
                cfg["output_file_folder"]
                + "/conc_list_files/"
                + cfg["output_file_name"]
                + "_"
                + dsplit
                + "_"
                + str(i)
                + ".txt"
            )

            # save the txt list
            if "output_lists" not in cfg:
                cfg["output_lists"] = list()
            cfg["output_lists"].append(fpath_output)

            with open(fpath_output, "w") as f_out:
                for j in range(len(imput_groups_dict)):
                    # NOTE(review): `keys` is (re)assigned inside this loop
                    # and deliberately read again after the file loop below
                    # for the summary printout — it relies on n_output_files
                    # being >= 1 for every dsplit that reaches this point.
                    keys = list(imput_groups_dict.keys())

                    for fpath in imput_groups_dict[keys[j]][i]:
                        # also count here the actual sizes
                        final_number_of_events[j] += get_number_of_evts(fpath)
                        f_out.write(fpath + "\n")

        # and then print them
        for i in range(len(imput_groups_dict)):
            print(keys[i], ":", int(final_number_of_events[i]))

        print("----------------------------------------------")
def make_concatenate_and_shuffle_scripts(cfg):
    """
    Function that writes qsub .sh files which concatenates all files inside the list files.

    For every list file in cfg['output_lists'] two bash scripts are created
    in <output_file_folder>/job_scripts: one running 'orcasong concatenate'
    into <output_file_folder>/data_split, and one running
    'orcasong h5shuffle2' on the concatenated file (optionally removing it
    afterwards if cfg['shuffle_delete'] is set).

    Parameters
    ----------
    cfg : dict
        Dict that contains all configuration options and additional information.

    """
    dirpath = cfg["output_file_folder"]

    # create the /logs, /job_scripts and /data_split folders if missing
    for subdir in ("/logs", "/job_scripts", "/data_split"):
        if not os.path.exists(dirpath + subdir):
            os.makedirs(dirpath + subdir)

    # make qsub .sh file for concatenating
    for listfile_fpath in cfg["output_lists"]:
        stem = os.path.splitext(os.path.basename(listfile_fpath))[0]
        conc_outputfile_fpath = dirpath + "/data_split/" + stem + ".h5"
        fpath_bash_script = dirpath + "/job_scripts/concatenate_h5_" + stem + ".sh"

        script_lines = [
            "#!/usr/bin/env bash\n",
            "\n",
            "source " + cfg["venv_path"] + "activate" + "\n",
            "\n",
            "# Concatenate the files in the list\n",
            "orcasong concatenate " + listfile_fpath + " --outfile " + conc_outputfile_fpath,
        ]
        with open(fpath_bash_script, "w") as f:
            f.writelines(script_lines)

    # make qsub .sh file for shuffling
    for listfile_fpath in cfg["output_lists"]:
        stem = os.path.splitext(os.path.basename(listfile_fpath))[0]

        # This is the input for the shuffle tool!
        conc_outputfile_fpath = dirpath + "/data_split/" + stem + ".h5"
        fpath_bash_script = dirpath + "/job_scripts/shuffle_h5_" + stem + ".sh"

        script_lines = [
            "#!/usr/bin/env bash\n",
            "\n",
            "source " + cfg["venv_path"] + "activate \n",
            "\n",
            "# Shuffle the h5 file \n",
            "orcasong h5shuffle2 " + conc_outputfile_fpath,
        ]
        if cfg["shuffle_delete"]:
            script_lines.append("\n")
            script_lines.append("rm " + conc_outputfile_fpath + "\n")
        with open(fpath_bash_script, "w") as f:
            f.writelines(script_lines)
def make_split(config_file):
    """
    Entry point: create the data split lists (and optionally job scripts)
    described by a .toml config file.

    Reads the config, collects file paths / event counts / run_ids for every
    input group, prints the statistics, assigns files to the train/validate/
    rest splits, writes the .txt list files, and — if requested — writes the
    concatenate/shuffle bash scripts.

    Parameters
    ----------
    config_file : str
        Path to the .toml configuration file.

    """
    # decode config
    cfg = toml.load(config_file)
    cfg["toml_filename"] = config_file

    # read out all the input groups
    ip_group_keys = get_all_ip_group_keys(cfg)

    # and now loop over input groups extracting info
    n_evts_total = 0
    for key in ip_group_keys:
        print("Collecting information from input group " + key)
        cfg[key]["fpaths"] = get_h5_filepaths(cfg[key]["dir"])
        cfg[key]["n_files"] = len(cfg[key]["fpaths"])
        (
            cfg[key]["n_evts"],
            cfg[key]["n_evts_per_file_mean"],
            cfg[key]["run_ids"],
        ) = get_number_of_evts_and_run_ids(cfg[key]["fpaths"], dataset_key="y")

        n_evts_total += cfg[key]["n_evts"]

    cfg["n_evts_total"] = n_evts_total
    # print the extracted statistics
    print_input_statistics(cfg, ip_group_keys)

    # NOTE(review): 'print_only' (and 'make_qsub_bash_files' below) are read
    # without a default — presumably mandatory keys in the config; a missing
    # key raises KeyError. TODO confirm against the example config.
    if cfg["print_only"] is True:
        from sys import exit

        exit()

    for key in ip_group_keys:
        add_fpaths_for_data_split_to_cfg(cfg, key)

    # create the list files
    make_dsplit_list_files(cfg)

    # create bash scripts that can be submitted to do the concatenation and shuffle
    if cfg["make_qsub_bash_files"] is True:
        make_concatenate_and_shuffle_scripts(cfg)