Coverage for orcanet/in_out.py: 94%
322 statements
« prev ^ index » next coverage.py v7.2.7, created at 2024-03-28 14:22 +0000
« prev ^ index » next coverage.py v7.2.7, created at 2024-03-28 14:22 +0000
1#!/usr/bin/env python
2# -*- coding: utf-8 -*-
3"""
4Utility code regarding user input.
5"""
7import os
8import shutil
9import h5py
10import numpy as np
11from inspect import signature
13# moved into IOHandler.get_batch for speed up; tensorflow import is slow!
14# from orcanet.h5_generator import Hdf5BatchGenerator
def get_subfolder(main_folder, name=None, create=False):
    """
    Get the path to one or all subfolders of the main folder.

    Paths are built by plain string concatenation with "/" as separator.

    Parameters
    ----------
    main_folder : str
        The main folder. A trailing "/" is added if it is missing. An
        empty string is left untouched, so the returned paths are then
        relative to the current working directory.
    name : str or None
        The name of the subfolder. One of "train_log", "saved_models",
        "plots", "activations", "predictions" or "inference".
    create : bool
        If the subfolder should be created if it does not exist.

    Returns
    -------
    subfolder : str or list
        The path of the subfolder. If name is None, the paths of all
        subfolders will be returned as a list.

    Raises
    ------
    KeyError
        If name is not one of the known subfolder names.

    """
    # Guard against the empty string: indexing main_folder[-1] would fail,
    # and prepending "/" would silently point to the file system root.
    if main_folder and not main_folder.endswith("/"):
        main_folder += "/"

    subfolders = {
        "train_log": main_folder + "train_log",
        "saved_models": main_folder + "saved_models",
        "plots": main_folder + "plots",
        "activations": main_folder + "plots/activations",
        "predictions": main_folder + "predictions",
        "inference": main_folder + "predictions/inference",
    }

    def get(fdr):
        # look up one subfolder, creating it on demand
        subfdr = subfolders[fdr]
        if create and not os.path.exists(subfdr):
            print("Creating directory: " + subfdr)
            os.makedirs(subfdr)
        return subfdr

    if name is None:
        return [get(fdr) for fdr in subfolders]
    return get(name)
def get_inputs(model):
    """
    Map every input name of the model to the layer of the same name.

    Parameters
    ----------
    model
        Object providing ``input_names`` (an iterable of names) and
        ``get_layer(name)``.

    Returns
    -------
    inputs : dict
        Keys: the entries of ``model.input_names``.
        Values: the result of ``model.get_layer`` for that name.

    """
    inputs = {}
    for input_name in model.input_names:
        inputs[input_name] = model.get_layer(input_name)
    return inputs
class IOHandler(object):
    """
    Access info indirectly contained in the cfg object.

    The handler reads the following from ``cfg``: ``output_folder``,
    ``use_scratch_ssd``, ``key_x_values``, ``key_y_values``, ``batchsize``,
    ``sample_modifier``, ``label_modifier``, ``learning_rate`` and the
    methods ``get_files(which)`` and ``get_list_file()``.
    """

    def __init__(self, cfg):
        self.cfg = cfg

        # copies of files on local tmpdir, filled lazily by get_local_files
        self._tmpdir_files_dict = {
            "train": None,
            "val": None,
            "inference": None,
        }

    def get_latest_epoch(self):
        """
        Return the highest epoch/fileno pair of any saved model.

        Returns
        -------
        latest_epoch : tuple or None
            The highest epoch, file_no pair. None if the folder is
            empty or does not exist yet.

        """
        epochs = self.get_all_epochs()
        return epochs[-1] if epochs else None

    def get_all_epochs(self):
        """
        Get a sorted list of the epoch/fileno pairs of all saved models.

        Only files named ``model_epoch_XX_file_YY.h5`` in the saved_models
        folder are taken into account.

        Returns
        -------
        epochs : List
            The (epoch, fileno) tuples. List is empty if none can be found.

        """
        # os.path.join also works if output_folder has no trailing slash
        saved_models_folder = os.path.join(self.cfg.output_folder, "saved_models")
        epochs = []

        if os.path.exists(saved_models_folder):
            for file in os.listdir(saved_models_folder):
                if not (file.startswith("model_epoch_") and file.endswith(".h5")):
                    continue
                # model_epoch_XX_file_YY
                file_base = os.path.splitext(file)[0]
                f_epoch, file_no = file_base.split("model_epoch_")[-1].split("_file_")
                epochs.append((int(f_epoch), int(file_no)))
            epochs.sort()

        return epochs

    def get_next_epoch(self, epoch):
        """
        Return the next epoch / fileno tuple.

        It depends on how many train files there are.

        Parameters
        ----------
        epoch : tuple or None
            Current epoch and file number.

        Returns
        -------
        next_epoch : tuple
            Next epoch and file number. (1, 1) if epoch is None.

        """
        if epoch is None:
            next_epoch = (1, 1)
        elif epoch[1] == self.get_no_of_files("train"):
            # last file of this epoch: start the next epoch
            next_epoch = (epoch[0] + 1, 1)
        else:
            next_epoch = (epoch[0], epoch[1] + 1)
        return next_epoch

    def get_previous_epoch(self, epoch):
        """
        Return the previous epoch / fileno tuple.

        Raises
        ------
        ValueError
            If epoch is (1, 1), i.e. there is no previous one.

        """
        if epoch[1] == 1:
            if epoch[0] == 1:
                raise ValueError(
                    "Can not get previous epoch of epoch {} file {}".format(*epoch)
                )
            # first file of an epoch: go back to the last file of the one before
            n_train_files = self.get_no_of_files("train")
            prev_epoch = (epoch[0] - 1, n_train_files)
        else:
            prev_epoch = (epoch[0], epoch[1] - 1)

        return prev_epoch

    def get_subfolder(self, name=None, create=False):
        """
        Get the path to one or all subfolders of the main folder.

        Delegates to the module-level get_subfolder function, using
        cfg.output_folder as the main folder.

        Parameters
        ----------
        name : str or None
            The name of the subfolder.
        create : bool
            If the subfolder should be created if it does not exist.

        Returns
        -------
        subfolder
            The path of the subfolder, or the paths of all subfolders
            if name is None.

        """
        return get_subfolder(self.cfg.output_folder, name, create)

    def get_model_path(self, epoch, fileno, local=False):
        """
        Get the path to a model (which might not exist yet).

        Parameters
        ----------
        epoch : int
            Its epoch.
        fileno : int
            Its file number.
        local : bool
            If True, will only return the path inside the output_folder,
            i.e. saved_models/model_epoch_XX_file_YY.h5.

        Returns
        -------
        model_path : str
            The path to the model.

        Raises
        ------
        ValueError
            If epoch/fileno are not both > 0 (or both -1), or if the
            latest model is requested via (-1, -1) but there is no
            saved model.

        """
        if epoch == -1 and fileno == -1:
            latest_epoch = self.get_latest_epoch()
            if latest_epoch is None:
                # unpacking None would only give an opaque TypeError
                raise ValueError(
                    "Can not get the path of the latest model: "
                    "No saved models found"
                )
            epoch, fileno = latest_epoch
        if epoch < 1 or fileno < 1:
            raise ValueError(
                "Invalid epoch/file number {}, {}: Must be "
                "either (-1, -1) or both >0".format(epoch, fileno)
            )

        subfolder = self.get_subfolder("saved_models")
        if local:
            # keep only the last path component
            subfolder = subfolder.split("/")[-1]
        file_name = "model_epoch_{}_file_{}.h5".format(epoch, fileno)

        model_path = subfolder + "/" + file_name
        return model_path

    def get_latest_prediction_file_no(self, epoch, fileno):
        """
        Returns the file number of the latest currently predicted val file.

        The predictions folder is created if it does not exist yet.

        Parameters
        ----------
        epoch : int
            Epoch of the model that has predicted.
        fileno : int
            Fileno of the model that has predicted.

        Returns
        -------
        latest_val_file_no : int or None
            File number of the prediction file with the highest val index.
            STARTS FROM 1, so this is whats in the file name.
            None if there is none.

        """
        prediction_folder = self.get_subfolder("predictions", create=True)

        val_file_nos = []
        for file in os.listdir(prediction_folder):
            # name e.g.: pred_model_epoch_6_file_1_on_list_val_file_1.h5
            if not (file.endswith(".h5") and file.startswith("pred_model")):
                continue

            f_epoch, f_fileno, val_file_no = split_name_of_predfile(file)
            if f_epoch == epoch and f_fileno == fileno:
                val_file_nos.append(val_file_no)

        return max(val_file_nos) if val_file_nos else None

    def get_pred_path(self, epoch, fileno, pred_file_no):
        """
        Gets the path of a prediction file. The ints all start from 1.

        Parameters
        ----------
        epoch : int
            Epoch of an already trained nn model.
        fileno : int
            File number train step of an already trained nn model.
        pred_file_no : int
            Val file no of the prediction files that are found in the
            prediction folder.

        Returns
        -------
        pred_filepath : str
            The path.

        Raises
        ------
        ValueError
            If cfg.get_list_file() returns None.

        """
        list_file = self.cfg.get_list_file()
        if list_file is None:
            raise ValueError(
                "No toml list file specified. Can not look up " "saved prediction"
            )
        # the list file name without directory and extension is part
        # of the prediction file name
        list_name = os.path.splitext(os.path.basename(list_file))[0]

        pred_filepath = self.get_subfolder(
            "predictions"
        ) + "/pred_model_epoch_{}_file_{}_on_{}_val_file_{}.h5".format(
            epoch, fileno, list_name, pred_file_no
        )

        return pred_filepath

    def get_pred_files_list(self, epoch=None, fileno=None):
        """
        Returns a sorted list with all pred .h5 files in the prediction folder.
        Does not include the inference files.

        Parameters
        ----------
        epoch : int, optional
            Specific model epoch to look pred files up for.
        fileno : int, optional
            Specific model file number to look pred files up for.

        Returns
        -------
        pred_files_list : List
            List with the full filepaths of all prediction results files,
            sorted numerically by model epoch, model file number and val
            file number.

        """
        prediction_folder = self.get_subfolder("predictions")

        keyed_pred_files = []
        for file in os.listdir(prediction_folder):
            if not (file.startswith("pred_model_epoch") and file.endswith(".h5")):
                continue
            # Parse the bare file name: the folder part of the path may
            # itself contain the separators the parser splits on.
            p_epoch, p_file_no, p_val_file_no = split_name_of_predfile(file)
            if epoch is not None and epoch != p_epoch:
                continue
            if fileno is not None and fileno != p_file_no:
                continue
            pred_file = os.path.join(prediction_folder, file)
            keyed_pred_files.append(((p_epoch, p_file_no, p_val_file_no), pred_file))

        # sort predicted val files from 1 ... n; a plain string sort
        # would put val file 10 in front of val file 2
        keyed_pred_files.sort()
        return [pred_file for _, pred_file in keyed_pred_files]

    def get_local_files(self, which):
        """
        Get the training or validation file paths for each list input set.

        If cfg.use_scratch_ssd is set, returns the path to the copy of the
        file on the local tmpdir, which it will generate if called for the
        first time.

        Parameters
        ----------
        which : str
            Either "train", "val", or "inference".

        Returns
        -------
        dict
            A dict containing the paths to the training or validation files on
            which the model will be trained on. Example for the format for
            two input sets with two files each:
                    {
                     "input_A" : ('path/to/set_A_file_1.h5', 'path/to/set_A_file_2.h5'),
                     "input_B" : ('path/to/set_B_file_1.h5', 'path/to/set_B_file_2.h5'),
                    }

        Raises
        ------
        NameError
            If which is not a known fileset name.

        """
        if which not in self._tmpdir_files_dict:
            raise NameError("Unknown fileset name {}".format(which))

        files = self.cfg.get_files(which)
        if not self.cfg.use_scratch_ssd:
            return files

        # copy only once, then reuse the cached paths
        if self._tmpdir_files_dict[which] is None:
            self._tmpdir_files_dict[which] = use_local_tmpdir(files)
        return self._tmpdir_files_dict[which]

    def get_n_bins(self):
        """
        Get the number of bins from the training files.

        Only the first files are looked up, the others should be identical.

        Returns
        -------
        n_bins : dict
            Toml-list input names as keys, shape of the x values dataset
            without its first axis as values.

        """
        # TODO check if bins are equal in all files?
        train_files = self.get_local_files("train")
        n_bins = {}
        for input_key in train_files:
            with h5py.File(train_files[input_key][0], "r") as f:
                n_bins[input_key] = f[self.cfg.key_x_values].shape[1:]
        return n_bins

    def get_file_sizes(self, which):
        """
        Get the number of samples in each training or validation input file.

        The number of samples is read from the cfg.key_y_values dataset.

        Parameters
        ----------
        which : str
            Either train or val.

        Returns
        -------
        file_sizes : List
            Its length is equal to the number of files in each input set.

        Raises
        ------
        ValueError
            If there is a different number of samples in any of the
            files of all inputs.

        """
        file_sizes_full, error_file_sizes, file_sizes = {}, [], []
        for n, file_no_set in enumerate(self.yield_files(which)):
            # the number of samples in the n-th file of all inputs
            sizes = [
                h5_get_number_of_rows(file, datasets=[self.cfg.key_y_values])
                for file in file_no_set.values()
            ]
            file_sizes_full[n] = sizes
            if sizes.count(sizes[0]) != len(sizes):
                error_file_sizes.append(n)
            else:
                file_sizes.append(sizes[0])

        if error_file_sizes:
            err_msg = (
                "The files you gave for the different inputs of the model "
                "do not all have the same number of samples!\n"
            )
            for n in error_file_sizes:
                err_msg += (
                    "File no {} in {} has the following files sizes "
                    "for the different inputs: {}\n".format(
                        n, which, file_sizes_full[n]
                    )
                )
            raise ValueError(err_msg)

        return file_sizes

    def get_no_of_files(self, which):
        """
        Return the number of training or validation files.

        Only looks up the no of files of one list input (the first one in
        the files dict), as equal length is checked during read in.

        Parameters
        ----------
        which : str
            Either train or val.

        Returns
        -------
        no_of_files : int
            The number of files.

        """
        files = self.get_local_files(which)
        no_of_files = len(list(files.values())[0])
        return no_of_files

    def yield_files(self, which):
        """
        Yield a training or validation filepaths for every input.

        They will be yielded in the same order as they are given in the
        toml file.

        Parameters
        ----------
        which : str
            Either train or val.

        Yields
        ------
        files_dict : dict
            Keys: The name of every toml list input.
            Values: One of the filepaths.

        """
        files = self.get_local_files(which)
        for file_no in range(self.get_no_of_files(which)):
            files_dict = {key: files[key][file_no] for key in files}
            yield files_dict

    def get_file(self, which, file_no):
        """
        Get a dict with the n-th files.

        Parameters
        ----------
        which : str
            Name of the fileset, e.g. train or val.
        file_no : int
            Number of the file; the first file has the number 1.

        Returns
        -------
        files_dict : dict
            Keys: The name of every toml list input.
            Values: The path of file number file_no of that input.

        """
        files = self.get_local_files(which)
        # file numbers start from 1, list indices from 0
        files_dict = {key: files[key][file_no - 1] for key in files}
        return files_dict

    def check_connections(self, model):
        """
        Check if the names and shapes of the samples and labels in the
        given input files work with the model.

        Also takes into account the possibly present sample or label modifiers.
        The result of the check is printed.

        Parameters
        ----------
        model : ks.model
            A keras model.

        Raises
        ------
        ValueError
            If they dont work together.

        """
        print("\nInput check\n-----------")
        # Get a batch of data to investigate the given modifier functions
        info_blob = self.get_batch()
        y_values = info_blob["y_values"]
        layer_inputs = get_inputs(model)
        # keys: name of layers, values: shape of input
        layer_inp_shapes = {
            key: layer_inputs[key].input_shape[0][1:] for key in layer_inputs
        }
        list_inp_shapes = self.get_n_bins()

        print(
            "The data in the files of the toml list have the following "
            "names and shapes:"
        )
        for list_key in list_inp_shapes:
            print("\t{}\t{}".format(list_key, list_inp_shapes[list_key]))

        if self.cfg.sample_modifier is None:
            print("\nYou did not specify a sample modifier.")
            info_blob["xs"] = info_blob["x_values"]
        else:
            modified_xs = self.cfg.sample_modifier(info_blob)
            modified_shapes = {
                modi_key: tuple(modified_xs[modi_key].shape)[1:]
                for modi_key in modified_xs
            }
            print(
                "\nAfter applying your sample modifier, they have the "
                "following names and shapes:"
            )
            for list_key in modified_shapes:
                print("\t{}\t{}".format(list_key, modified_shapes[list_key]))
            # from here on, compare the model against the modified data
            list_inp_shapes = modified_shapes
            info_blob["xs"] = modified_xs

        print("\nYour model requires the following input names and shapes:")
        for layer_key in layer_inp_shapes:
            print("\t{}\t{}".format(layer_key, layer_inp_shapes[layer_key]))

        # Both inputs are dicts with name: shape of input/output layers/data
        err_inp_names, err_inp_shapes = [], []
        for layer_name in layer_inp_shapes:
            if layer_name not in list_inp_shapes:
                # no matching name
                err_inp_names.append(layer_name)
            elif list_inp_shapes[layer_name] != layer_inp_shapes[layer_name]:
                # no matching shape
                err_inp_shapes.append(layer_name)

        err_msg_inp = ""
        if not err_inp_names and not err_inp_shapes:
            print("\nInput check passed.")
        else:
            print("\nInput check failed!")
            if err_inp_names:
                err_msg_inp += (
                    "No matching input name from the input files "
                    "for input layer(s): "
                    + (", ".join(str(e) for e in err_inp_names) + "\n")
                )
            if err_inp_shapes:
                err_msg_inp += (
                    "Shapes of layers and labels do not match for "
                    "the following input layer(s): "
                    + (", ".join(str(e) for e in err_inp_shapes) + "\n")
                )
            print("Error:", err_msg_inp)

        # ----------------------------------
        print("\nOutput check\n------------")
        # tuple of strings
        mc_names = y_values.dtype.names
        print(
            "The following {} label names are in the first file of the "
            "toml list:".format(len(mc_names))
        )
        print("\t" + ", ".join(str(name) for name in mc_names), end="\n\n")

        if self.cfg.label_modifier is not None:
            label_names = tuple(self.cfg.label_modifier(info_blob).keys())
            print(
                "The following {} labels get produced from them by your "
                "label_modifier:".format(len(label_names))
            )
            print("\t" + ", ".join(str(name) for name in label_names), end="\n\n")
        else:
            label_names = mc_names
            print(
                "You did not specify a label_modifier. The output layers "
                "will be provided with labels that match their name from "
                "the above.\n\n"
            )

        # tuple of strings
        loss_names = tuple(model.output_names)
        print("Your model has the following {} output layers:".format(len(loss_names)))
        print("\t" + ", ".join(str(name) for name in loss_names), end="\n\n")

        # every output layer needs a label of the same name
        err_out_names = [
            loss_name for loss_name in loss_names if loss_name not in label_names
        ]

        err_msg_out = ""
        if not err_out_names:
            print("Output check passed.\n")
        else:
            print("Output check failed!")
            err_msg_out += (
                "No matching label name from the input files "
                "for output layer(s): "
                + (", ".join(str(e) for e in err_out_names) + "\n")
            )
            print("Error:", err_msg_out)

        err_msg = err_msg_inp + err_msg_out
        if err_msg != "":
            raise ValueError(err_msg)

    def get_batch(self):
        """
        For testing purposes, return a batch of x_values and y_values.

        This will always be the first batchsize samples and y_values from
        the first file, before any modifiers have been applied.

        Returns
        -------
        info_blob : dict
            X- and y-values from the files. Has the following entries:
            x_values : dict
                Keys: Names of the input datasets from the list toml file.
                Values: ndarray, a batch of samples.
            y_values : ndarray
                From the y_values datagroup of the input files.

        """
        # this will import tf; move inside here for speed up
        from orcanet.h5_generator import Hdf5BatchGenerator

        gen = Hdf5BatchGenerator(
            next(self.yield_files("train")),
            batchsize=self.cfg.batchsize,
            key_x_values=self.cfg.key_x_values,
            key_y_values=self.cfg.key_y_values,
            keras_mode=False,
        )
        info_blob = gen[0]
        # drop the entries "xs" and "ys" of the generator output
        info_blob.pop("xs")
        info_blob.pop("ys")
        return info_blob

    def get_input_shapes(self):
        """
        Get the input names and shapes of the data after the modifier has
        been applied.

        Returns
        -------
        input_shapes : dict
            Keys: Name of the inputs of the model.
            Values: Their shape without the batchsize.

        """
        if self.cfg.sample_modifier is None:
            input_shapes = self.get_n_bins()
        else:
            info_blob = self.get_batch()
            xs_mod = self.cfg.sample_modifier(info_blob)
            input_shapes = {
                input_name: tuple(input_xs.shape)[1:]
                for input_name, input_xs in xs_mod.items()
            }
        return input_shapes

    def print_log(self, lines, logging=True):
        """
        Print and also log to the full log file.

        Parameters
        ----------
        lines : str or list
            The line or lines to print.
        logging : bool
            If True, the lines are also appended to the file log.txt in
            the output folder.

        """
        if isinstance(lines, str):
            lines = [lines]

        if not logging:
            for line in lines:
                print(line)
            return

        # os.path.join also works if output_folder has no trailing slash
        full_log_file = os.path.join(self.cfg.output_folder, "log.txt")
        with open(full_log_file, "a+") as f_out:
            for line in lines:
                f_out.write(line + "\n")
                print(line)

    def get_epoch_float(self, epoch, fileno):
        """
        Make a float value out of epoch/fileno.

        Parameters
        ----------
        epoch : int
            The epoch, starting from 1.
        fileno : int
            The file number, starting from 1.

        Returns
        -------
        epoch_float
            epoch - 1, plus the fraction of all train samples contained
            in the files up to and including fileno.

        """
        # calculate the fraction of samples per file compared to all files,
        # e.g. [100, 50, 50] --> [0.5, 0.75, 1]
        file_sizes = self.get_file_sizes("train")
        file_sizes_rltv = np.cumsum(file_sizes) / np.sum(file_sizes)

        epoch_float = epoch - 1 + file_sizes_rltv[fileno - 1]
        return epoch_float

    def get_learning_rate(self, epoch):
        """
        Get the learning rate for a given epoch and file number.

        The user learning rate (cfg.learning_rate) can be a str, a float,
        a tuple, or a function:

        - str: Name of a file in the output folder with the three columns
          epoch, fileno, lr. The lr of the last line which is not after
          the given epoch is used.
        - float: Constant learning rate.
        - tuple of two floats (lr_init, lr_decay): The learning rate is
          lr_init * (1 - lr_decay)**n, with n the number of train files
          processed before the given epoch/fileno.
        - function: Called with the epoch and the file number.

        Parameters
        ----------
        epoch : tuple
            Epoch and file number. Both start at 1, i.e. the start of the
            training is (1, 1), the next file is (1, 2), ...
            This is also in the filename of the saved models.

        Returns
        -------
        lr : float
            The learning rate that will be used for the given epoch/fileno.

        Raises
        ------
        ValueError
            If the learning rate file is malformed or has no entry for
            the given epoch.
        LookupError
            If the learning rate is a sequence which is not of length 2.
        TypeError
            If the learning rate is of an unsupported type, or a function
            which does not take exactly two parameters.

        """
        error_msg = (
            "The learning rate must be either a float, a tuple of "
            "two floats or a function."
        )
        user_lr = self.cfg.learning_rate

        if isinstance(user_lr, str):
            # read lr from a csv file in the main folder, which must have
            # 3 columns (Epoch, fileno, lr)
            lr_file = os.path.join(self.cfg.output_folder, user_lr)
            lr_table = np.genfromtxt(lr_file)

            # a file with a single line is read as a 1D array
            if len(lr_table.shape) == 1:
                lr_table = lr_table.reshape((1,) + lr_table.shape)

            if len(lr_table.shape) != 2 or lr_table.shape[1] != 3:
                raise ValueError("Invalid lr.csv format")
            lr_table = [[tuple(lrt[0:2]), lrt[2]] for lrt in lr_table]
            lr_table.sort()

            lr = None
            # get lr from the table, one line before where the table is bigger
            for table_epoch in lr_table:
                if table_epoch[0] > tuple(epoch):
                    break
                lr = table_epoch[1]
            if lr is None:
                raise ValueError(
                    "csv learning rate not specified for epoch {}".format(epoch)
                )
            return lr

        try:
            # Float => Constant LR
            return float(user_lr)
        except (ValueError, TypeError):
            pass

        try:
            # List => Exponentially decaying LR
            length = len(user_lr)
            # check the length before indexing, so that a too short
            # sequence gets the descriptive error instead of an IndexError
            if length != 2:
                raise LookupError(
                    "{} (Your tuple has length {})".format(error_msg, length)
                )
            lr_init = float(user_lr[0])
            lr_decay = float(user_lr[1])
        except (ValueError, TypeError):
            pass
        else:
            # the number of train files is only needed for the decay
            no_train_files = self.get_no_of_files("train")
            return lr_init * (1 - lr_decay) ** (
                (epoch[1] - 1) + (epoch[0] - 1) * no_train_files
            )

        # Callable => User defined function
        # Only the signature lookup is guarded, so that neither the error
        # about the number of parameters nor errors raised inside the
        # user function get replaced by the generic message.
        try:
            n_params = len(signature(user_lr).parameters)
        except (ValueError, TypeError):
            raise TypeError(
                "{} (You gave {} of type {}) ".format(error_msg, user_lr, type(user_lr))
            )
        if n_params != 2:
            raise TypeError(
                "A custom learning rate function must have two "
                "input parameters: The epoch and the file number. "
                "(yours has {})".format(n_params)
            )
        return user_lr(epoch[0], epoch[1])
def split_name_of_predfile(file):
    """
    Get epoch, fileno, val fileno from the name of a predfile.

    Parameters
    ----------
    file : str
        Name of or path to the file. The file name looks like this:
        pred_model_epoch_XX_file_YY_on_USERLIST_val_file_ZZ.h5

    Returns
    -------
    epoch , file_no, val_file_no : tuple(int)
        As integers.

    Raises
    ------
    ValueError
        If the file name does not follow the scheme above.

    """
    # Only look at the file name: directories in a path could contain
    # the separators which are used for splitting below.
    file_base = os.path.splitext(os.path.basename(file))[0]
    # split at the last occurrence, as USERLIST may contain "_val_file_"
    rest, val_file_no = file_base.rsplit("_val_file_", 1)
    # the model part is everything before the first "_on_"
    rest, file_no = rest.split("_on_")[0].split("_file_")
    epoch = rest.split("_epoch_")[-1]

    return int(epoch), int(file_no), int(val_file_no)
def h5_get_number_of_rows(h5_filepath, datasets=None):
    """
    Get the number of rows (length of axis 0) of the datasets of a .h5 file.

    All inspected datasets must have the same number of rows.

    Parameters
    ----------
    h5_filepath : str
        Path to the .h5 file.
    datasets : list
        Optional, names of the datasets in the file to inspect. If None,
        all top-level keys of the file are used.

    Returns
    -------
    number_of_rows : int
        The number of rows of the first inspected dataset.

    Raises
    ------
    AssertionError
        If the inspected datasets do not have the same number of rows.

    """
    with h5py.File(h5_filepath, "r") as h5_file:
        if datasets is None:
            datasets = list(h5_file.keys())
        row_counts = [h5_file[name].shape[0] for name in datasets]

    first_count = row_counts[0]
    if row_counts.count(first_count) != len(row_counts):
        # list every dataset with its size, to make the mismatch easy to find
        err_msg = "Datasets do not have the same number of samples in file " + h5_filepath
        for name, count in zip(datasets, row_counts):
            err_msg += "\nDataset: {}\tSamples: {}".format(name, count)
        raise AssertionError(err_msg)
    return first_count
def use_local_tmpdir(files):
    """
    Copy the given files to the folder named by the TMPDIR environment variable.

    Parameters
    ----------
    files : dict
        Maps each input name to an iterable of file paths.

    Returns
    -------
    files_ssd : dict
        Same keys as files. The values are tuples with the paths of
        the copies inside the TMPDIR folder.

    Raises
    ------
    KeyError
        If the environment variable TMPDIR is not set.

    """
    scratch_dir = os.environ["TMPDIR"]
    files_ssd = {}

    for input_key, source_paths in files.items():
        copied_paths = []
        for source in source_paths:
            # the copy keeps the file name of the source
            target = os.path.join(scratch_dir, os.path.basename(source))
            print("Copying", source, "\nto", target)
            shutil.copy2(source, scratch_dir)
            copied_paths.append(target)
        files_ssd[input_key] = tuple(copied_paths)

    print("Finished copying to local tmpdir folder.")
    return files_ssd