Coverage for orcanet/in_out.py: 94%

322 statements  

coverage.py v7.2.7, created at 2024-03-28 14:22 +0000

#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
Utility code regarding user input.
"""

import os
import shutil
import h5py
import numpy as np
from inspect import signature

# moved into IOHandler.get_batch for speed up; tensorflow import is slow!
# from orcanet.h5_generator import Hdf5BatchGenerator


def get_subfolder(main_folder, name=None, create=False):
    """
    Get the path to one or all subfolders of the main folder.

    Parameters
    ----------
    main_folder : str
        The main folder.
    name : str or None
        The name of the subfolder.
    create : bool
        If the subfolder should be created if it does not exist.

    Returns
    -------
    subfolder : str or list
        The path of the subfolder. If name is None, the paths of all
        subfolders will be returned as a list.

    """
    if not main_folder.endswith("/"):
        main_folder += "/"

    subfolders = {
        "train_log": main_folder + "train_log",
        "saved_models": main_folder + "saved_models",
        "plots": main_folder + "plots",
        "activations": main_folder + "plots/activations",
        "predictions": main_folder + "predictions",
        "inference": main_folder + "predictions/inference",
    }

    def get(fdr):
        subfdr = subfolders[fdr]
        if create and not os.path.exists(subfdr):
            print("Creating directory: " + subfdr)
            os.makedirs(subfdr)
        return subfdr

    if name is None:
        subfolder = [get(key) for key in subfolders]
    else:
        subfolder = get(name)
    return subfolder
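
# Usage sketch (hypothetical folder; assumes "output/plots" does not
# exist yet):
#   >>> get_subfolder("output", "plots", create=True)
#   Creating directory: output/plots
#   'output/plots'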



def get_inputs(model):
    """Get names and keras layers of the inputs of the model, as a dict."""
    return {name: model.get_layer(name) for name in model.input_names}


class IOHandler(object):
    """
    Access info indirectly contained in the cfg object.
    """

    def __init__(self, cfg):
        self.cfg = cfg

        # copies of files on local tmpdir
        self._tmpdir_files_dict = {
            "train": None,
            "val": None,
            "inference": None,
        }

    def get_latest_epoch(self):
        """
        Return the highest epoch/fileno pair of any saved model.

        Returns
        -------
        latest_epoch : tuple or None
            The highest (epoch, fileno) pair. None if the folder is
            empty or does not exist yet.

        """
        epochs = self.get_all_epochs()
        if len(epochs) == 0:
            latest_epoch = None
        else:
            latest_epoch = epochs[-1]

        return latest_epoch

    def get_all_epochs(self):
        """
        Get a sorted list of the epoch/fileno pairs of all saved models.

        Returns
        -------
        epochs : List
            The (epoch, fileno) tuples. List is empty if none can be found.
        """
        saved_models_folder = self.cfg.output_folder + "saved_models"
        epochs = []

        if os.path.exists(saved_models_folder):
            files = []
            for file in os.listdir(saved_models_folder):
                if file.startswith("model_epoch_") and file.endswith(".h5"):
                    files.append(file)

            for file in files:
                # file name format: model_epoch_XX_file_YY.h5
                file_base = os.path.splitext(file)[0]
                f_epoch, file_no = file_base.split("model_epoch_")[-1].split("_file_")
                epochs.append((int(f_epoch), int(file_no)))
            epochs.sort()

        return epochs
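
    # Usage sketch (hypothetical folder contents): if saved_models/ holds
    # model_epoch_1_file_1.h5, model_epoch_1_file_2.h5 and
    # model_epoch_2_file_1.h5, this returns [(1, 1), (1, 2), (2, 1)].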


    def get_next_epoch(self, epoch):
        """
        Return the next epoch / fileno tuple.

        It depends on how many train files there are.

        Parameters
        ----------
        epoch : tuple or None
            Current epoch and file number.

        Returns
        -------
        next_epoch : tuple
            Next epoch and file number.

        """
        if epoch is None:
            next_epoch = (1, 1)
        elif epoch[1] == self.get_no_of_files("train"):
            next_epoch = (epoch[0] + 1, 1)
        else:
            next_epoch = (epoch[0], epoch[1] + 1)
        return next_epoch
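
    # Progression sketch (hypothetical; assumes two train files):
    #   None -> (1, 1) -> (1, 2) -> (2, 1) -> (2, 2) -> (3, 1) -> ...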


    def get_previous_epoch(self, epoch):
        """Return the previous epoch / fileno tuple."""
        if epoch[1] == 1:
            if epoch[0] == 1:
                raise ValueError(
                    "Can not get previous epoch of epoch {} file {}".format(*epoch)
                )
            n_train_files = self.get_no_of_files("train")
            prev_epoch = (epoch[0] - 1, n_train_files)
        else:
            prev_epoch = (epoch[0], epoch[1] - 1)

        return prev_epoch

    def get_subfolder(self, name=None, create=False):
        """
        Get the path to one or all subfolders of the main folder.

        Parameters
        ----------
        name : str or None
            The name of the subfolder.
        create : bool
            If the subfolder should be created if it does not exist.

        Returns
        -------
        subfolder : str or list
            The path of the subfolder. If name is None, the paths of all
            subfolders will be returned as a list.

        """
        subfolder = get_subfolder(self.cfg.output_folder, name, create)
        return subfolder

    def get_model_path(self, epoch, fileno, local=False):
        """
        Get the path to a model (which might not exist yet).

        Parameters
        ----------
        epoch : int
            Its epoch.
        fileno : int
            Its file number.
        local : bool
            If True, will only return the path inside the output_folder,
            i.e. saved_models/model_epoch_XX_file_YY.h5.

        Returns
        -------
        model_path : str
            The path to the model.
        """
        if epoch == -1 and fileno == -1:
            epoch, fileno = self.get_latest_epoch()
        if epoch < 1 or fileno < 1:
            raise ValueError(
                "Invalid epoch/file number {}, {}: Must be "
                "either (-1, -1) or both >0".format(epoch, fileno)
            )

        subfolder = self.get_subfolder("saved_models")
        if local:
            subfolder = subfolder.split("/")[-1]
        file_name = "model_epoch_{}_file_{}.h5".format(epoch, fileno)

        model_path = subfolder + "/" + file_name
        return model_path
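
    # Usage sketch (hypothetical cfg.output_folder = "output/"):
    #   >>> io.get_model_path(2, 1)
    #   'output/saved_models/model_epoch_2_file_1.h5'
    #   >>> io.get_model_path(2, 1, local=True)
    #   'saved_models/model_epoch_2_file_1.h5'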


    def get_latest_prediction_file_no(self, epoch, fileno):
        """
        Return the file number of the latest currently predicted val file.

        Parameters
        ----------
        epoch : int
            Epoch of the model that has predicted.
        fileno : int
            Fileno of the model that has predicted.

        Returns
        -------
        latest_val_file_no : int or None
            File number of the prediction file with the highest val index.
            Starts from 1, i.e. this is what appears in the file name.
            None if there is no prediction file.

        """
        prediction_folder = self.get_subfolder("predictions", create=True)

        val_file_nos = []

        for file in os.listdir(prediction_folder):
            # name e.g.: pred_model_epoch_6_file_1_on_list_val_file_1.h5
            if not (file.endswith(".h5") and file.startswith("pred_model")):
                continue

            f_epoch, f_fileno, val_file_no = split_name_of_predfile(file)
            if f_epoch == epoch and f_fileno == fileno:
                val_file_nos.append(val_file_no)

        if len(val_file_nos) == 0:
            latest_val_file_no = None
        else:
            latest_val_file_no = max(val_file_nos)

        return latest_val_file_no

    def get_pred_path(self, epoch, fileno, pred_file_no):
        """
        Get the path of a prediction file. The ints all start from 1.

        Parameters
        ----------
        epoch : int
            Epoch of an already trained nn model.
        fileno : int
            File number (train step) of an already trained nn model.
        pred_file_no : int
            Val file no of the prediction files that are found in the
            prediction folder.

        Returns
        -------
        pred_filepath : str
            The path.

        """
        list_file = self.cfg.get_list_file()
        if list_file is None:
            raise ValueError(
                "No toml list file specified. Can not look up saved prediction"
            )
        list_name = os.path.splitext(os.path.basename(list_file))[0]

        pred_filepath = self.get_subfolder(
            "predictions"
        ) + "/pred_model_epoch_{}_file_{}_on_{}_val_file_{}.h5".format(
            epoch, fileno, list_name, pred_file_no
        )

        return pred_filepath
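
    # Usage sketch (hypothetical output folder and toml list "list.toml"):
    #   >>> io.get_pred_path(6, 1, 1)
    #   'output/predictions/pred_model_epoch_6_file_1_on_list_val_file_1.h5'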


    def get_pred_files_list(self, epoch=None, fileno=None):
        """
        Return a sorted list with all pred .h5 files in the prediction folder.
        Does not include the inference files.

        Parameters
        ----------
        epoch : int, optional
            Specific model epoch to look pred files up for.
        fileno : int, optional
            Specific model fileno to look pred files up for.

        Returns
        -------
        pred_files_list : List
            List with the full filepaths of all prediction results files.

        """
        prediction_folder = self.get_subfolder("predictions")

        pred_files_list = []
        for file in os.listdir(prediction_folder):
            if not (file.startswith("pred_model_epoch") and file.endswith(".h5")):
                continue
            pred_file = os.path.join(prediction_folder, file)
            p_epoch, p_file_no, p_val_file_no = split_name_of_predfile(pred_file)
            if epoch is not None and epoch != p_epoch:
                continue
            if fileno is not None and fileno != p_file_no:
                continue
            pred_files_list.append(pred_file)

        pred_files_list.sort()  # sort predicted val files from 1 ... n
        return pred_files_list

    def get_local_files(self, which):
        """
        Get the training or validation file paths for each list input set.

        If cfg.use_scratch_ssd is set, the files are copied to the local
        tmpdir on the first call, and the paths of these copies are
        returned.

        Parameters
        ----------
        which : str
            Either "train", "val", or "inference".

        Returns
        -------
        dict
            A dict containing the paths to the training or validation files
            on which the model will be trained. Example of the format for
            two input sets with two files each:
            {
                "input_A" : ('path/to/set_A_file_1.h5', 'path/to/set_A_file_2.h5'),
                "input_B" : ('path/to/set_B_file_1.h5', 'path/to/set_B_file_2.h5'),
            }

        """
        if which not in self._tmpdir_files_dict.keys():
            raise NameError("Unknown fileset name ", which)

        files = self.cfg.get_files(which)
        if self.cfg.use_scratch_ssd:
            if self._tmpdir_files_dict[which] is None:
                self._tmpdir_files_dict[which] = use_local_tmpdir(files)
            return self._tmpdir_files_dict[which]
        else:
            return files

    def get_n_bins(self):
        """
        Get the number of bins from the training files.

        Only the first files are looked up, the others should be identical.

        Returns
        -------
        n_bins : dict
            Toml-list input names as keys, list of the bins as values.

        """
        # TODO check if bins are equal in all files?
        train_files = self.get_local_files("train")
        n_bins = {}
        for input_key in train_files:
            with h5py.File(train_files[input_key][0], "r") as f:
                n_bins[input_key] = f[self.cfg.key_x_values].shape[1:]
        return n_bins
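
    # Usage sketch (hypothetical input name and a file whose x dataset has
    # shape (n_samples, 11, 13, 18)):
    #   >>> io.get_n_bins()
    #   {'input_A': (11, 13, 18)}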


    def get_file_sizes(self, which):
        """
        Get the number of samples in each training or validation input file.

        Parameters
        ----------
        which : str
            Either train or val.

        Returns
        -------
        file_sizes : List
            Its length is equal to the number of files in each input set.

        Raises
        ------
        ValueError
            If there is a different number of samples in any of the
            files of all inputs.

        """
        file_sizes_full, error_file_sizes, file_sizes = {}, [], []
        for n, file_no_set in enumerate(self.yield_files(which)):
            # the number of samples in the n-th file of all inputs
            file_sizes_full[n] = [
                h5_get_number_of_rows(file, datasets=[self.cfg.key_y_values])
                for file in file_no_set.values()
            ]
            if not file_sizes_full[n].count(file_sizes_full[n][0]) == len(
                file_sizes_full[n]
            ):
                error_file_sizes.append(n)
            else:
                file_sizes.append(file_sizes_full[n][0])

        if len(error_file_sizes) != 0:
            err_msg = (
                "The files you gave for the different inputs of the model "
                "do not all have the same number of samples!\n"
            )
            for n in error_file_sizes:
                err_msg += (
                    "File no {} in {} has the following file sizes "
                    "for the different inputs: {}\n".format(
                        n, which, file_sizes_full[n]
                    )
                )
            raise ValueError(err_msg)

        return file_sizes

    def get_no_of_files(self, which):
        """
        Return the number of training or validation files.

        Only looks up the number of files of one (random) list input, as
        equal length is checked during read-in.

        Parameters
        ----------
        which : str
            Either train or val.

        Returns
        -------
        no_of_files : int
            The number of files.

        """
        files = self.get_local_files(which)
        no_of_files = len(list(files.values())[0])
        return no_of_files

    def yield_files(self, which):
        """
        Yield the training or validation filepaths for every input.

        They will be yielded in the same order as they are given in the
        toml file.

        Parameters
        ----------
        which : str
            Either train or val.

        Yields
        ------
        files_dict : dict
            Keys: The name of every toml list input.
            Values: One of the filepaths.

        """
        files = self.get_local_files(which)
        for file_no in range(self.get_no_of_files(which)):
            files_dict = {key: files[key][file_no] for key in files}
            yield files_dict

    def get_file(self, which, file_no):
        """Get a dict with the n-th files."""
        files = self.get_local_files(which)
        files_dict = {key: files[key][file_no - 1] for key in files}
        return files_dict

    def check_connections(self, model):
        """
        Check if the names and shapes of the samples and labels in the
        given input files work with the model.

        Also takes into account the possibly present sample or label modifiers.

        Parameters
        ----------
        model : ks.model
            A keras model.

        Raises
        ------
        ValueError
            If they don't work together.

        """
        print("\nInput check\n-----------")
        # Get a batch of data to investigate the given modifier functions
        info_blob = self.get_batch()
        y_values = info_blob["y_values"]
        layer_inputs = get_inputs(model)
        # keys: name of layers, values: shape of input
        layer_inp_shapes = {
            key: layer_inputs[key].input_shape[0][1:] for key in layer_inputs
        }
        list_inp_shapes = self.get_n_bins()

        print(
            "The data in the files of the toml list have the following "
            "names and shapes:"
        )
        for list_key in list_inp_shapes:
            print("\t{}\t{}".format(list_key, list_inp_shapes[list_key]))

        if self.cfg.sample_modifier is None:
            print("\nYou did not specify a sample modifier.")
            info_blob["xs"] = info_blob["x_values"]
        else:
            modified_xs = self.cfg.sample_modifier(info_blob)
            modified_shapes = {
                modi_key: tuple(modified_xs[modi_key].shape)[1:]
                for modi_key in modified_xs
            }
            print(
                "\nAfter applying your sample modifier, they have the "
                "following names and shapes:"
            )
            for list_key in modified_shapes:
                print("\t{}\t{}".format(list_key, modified_shapes[list_key]))
            list_inp_shapes = modified_shapes
            info_blob["xs"] = modified_xs

        print("\nYour model requires the following input names and shapes:")
        for layer_key in layer_inp_shapes:
            print("\t{}\t{}".format(layer_key, layer_inp_shapes[layer_key]))

        # Both inputs are dicts with name: shape of input/output layers/data
        err_inp_names, err_inp_shapes = [], []
        for layer_name in layer_inp_shapes:
            if layer_name not in list_inp_shapes.keys():
                # no matching name
                err_inp_names.append(layer_name)
            elif list_inp_shapes[layer_name] != layer_inp_shapes[layer_name]:
                # no matching shape
                err_inp_shapes.append(layer_name)

        err_msg_inp = ""
        if len(err_inp_names) == 0 and len(err_inp_shapes) == 0:
            print("\nInput check passed.")
        else:
            print("\nInput check failed!")
            if len(err_inp_names) != 0:
                err_msg_inp += (
                    "No matching input name from the input files "
                    "for input layer(s): "
                    + (", ".join(str(e) for e in err_inp_names) + "\n")
                )
            if len(err_inp_shapes) != 0:
                err_msg_inp += (
                    "Shapes of layers and labels do not match for "
                    "the following input layer(s): "
                    + (", ".join(str(e) for e in err_inp_shapes) + "\n")
                )
            print("Error:", err_msg_inp)

        # ----------------------------------
        print("\nOutput check\n------------")
        # tuple of strings
        mc_names = y_values.dtype.names
        print(
            "The following {} label names are in the first file of the "
            "toml list:".format(len(mc_names))
        )
        print("\t" + ", ".join(str(name) for name in mc_names), end="\n\n")

        if self.cfg.label_modifier is not None:
            label_names = tuple(self.cfg.label_modifier(info_blob).keys())
            print(
                "The following {} labels get produced from them by your "
                "label_modifier:".format(len(label_names))
            )
            print("\t" + ", ".join(str(name) for name in label_names), end="\n\n")
        else:
            label_names = mc_names
            print(
                "You did not specify a label_modifier. The output layers "
                "will be provided with labels that match their name from "
                "the above.\n\n"
            )

        # tuple of strings
        loss_names = tuple(model.output_names)
        print("Your model has the following {} output layers:".format(len(loss_names)))
        print("\t" + ", ".join(str(name) for name in loss_names), end="\n\n")

        err_out_names = []
        for loss_name in loss_names:
            if loss_name not in label_names:
                err_out_names.append(loss_name)

        err_msg_out = ""
        if len(err_out_names) == 0:
            print("Output check passed.\n")
        else:
            print("Output check failed!")
            if len(err_out_names) != 0:
                err_msg_out += (
                    "No matching label name from the input files "
                    "for output layer(s): "
                    + (", ".join(str(e) for e in err_out_names) + "\n")
                )
            print("Error:", err_msg_out)

        err_msg = err_msg_inp + err_msg_out
        if err_msg != "":
            raise ValueError(err_msg)
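
    # Minimal sketch of a sample modifier as assumed by the check above
    # (hypothetical input name "input_A"; not part of this module):
    #   def sample_modifier(info_blob):
    #       xs = info_blob["x_values"]
    #       # e.g. swap the last two axes of input "input_A"
    #       return {"input_A": np.swapaxes(xs["input_A"], -1, -2)}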


    def get_batch(self):
        """
        For testing purposes, return a batch of x_values and y_values.

        This will always be the first batchsize samples and y_values from
        the first file, before any modifiers have been applied.

        Returns
        -------
        info_blob : dict
            X- and y-values from the files. Has the following entries:
            x_values : dict
                Keys: Names of the input datasets from the list toml file.
                Values: ndarray, a batch of samples.
            y_values : ndarray
                From the y_values datagroup of the input files.

        """
        # this will import tf; moved inside here for speed up
        from orcanet.h5_generator import Hdf5BatchGenerator

        gen = Hdf5BatchGenerator(
            next(self.yield_files("train")),
            batchsize=self.cfg.batchsize,
            key_x_values=self.cfg.key_x_values,
            key_y_values=self.cfg.key_y_values,
            keras_mode=False,
        )
        info_blob = gen[0]
        info_blob.pop("xs")
        info_blob.pop("ys")
        return info_blob

    def get_input_shapes(self):
        """
        Get the input names and shapes of the data after the modifier has
        been applied.

        Returns
        -------
        input_shapes : dict
            Keys: Name of the inputs of the model.
            Values: Their shape without the batchsize.

        """
        if self.cfg.sample_modifier is None:
            input_shapes = self.get_n_bins()
        else:
            info_blob = self.get_batch()
            xs_mod = self.cfg.sample_modifier(info_blob)
            input_shapes = {
                input_name: tuple(input_xs.shape)[1:]
                for input_name, input_xs in xs_mod.items()
            }
        return input_shapes

    def print_log(self, lines, logging=True):
        """Print and also log to the full log file."""
        if isinstance(lines, str):
            lines = [lines]

        if not logging:
            for line in lines:
                print(line)
        else:
            full_log_file = self.cfg.output_folder + "log.txt"
            with open(full_log_file, "a+") as f_out:
                for line in lines:
                    f_out.write(line + "\n")
                    print(line)

    def get_epoch_float(self, epoch, fileno):
        """Make a float value out of epoch/fileno."""
        # calculate the fraction of samples per file compared to all files,
        # e.g. [100, 50, 50] --> [0.5, 0.75, 1]
        file_sizes = self.get_file_sizes("train")
        file_sizes_rltv = np.cumsum(file_sizes) / np.sum(file_sizes)

        epoch_float = epoch - 1 + file_sizes_rltv[fileno - 1]
        return epoch_float
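
    # Worked example (hypothetical file sizes): with train file sizes
    # [100, 50, 50] the cumulative fractions are [0.5, 0.75, 1.0], so
    # epoch=2, fileno=1 gives 2 - 1 + 0.5 = 1.5.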


    def get_learning_rate(self, epoch):
        """
        Get the learning rate for a given epoch and file number.

        The user learning rate (cfg.learning_rate) can be None, a float,
        a tuple, a function, or a str (the name of a csv file in the
        output folder).

        Parameters
        ----------
        epoch : tuple
            Epoch and file number. Both start at 1, i.e. the start of the
            training is (1, 1), the next file is (1, 2), ...
            This is also in the filename of the saved models.

        Returns
        -------
        lr : float
            The learning rate that will be used for the given epoch/fileno.

        """
        error_msg = (
            "The learning rate must be either a float, a tuple of "
            "two floats or a function."
        )
        no_train_files = self.get_no_of_files("train")
        user_lr = self.cfg.learning_rate

        if isinstance(user_lr, str):
            # read lr from a csv file in the main folder, which must have
            # 3 columns (epoch, fileno, lr)
            lr_file = os.path.join(self.cfg.output_folder, user_lr)
            lr_table = np.genfromtxt(lr_file)

            if len(lr_table.shape) == 1:
                lr_table = lr_table.reshape((1,) + lr_table.shape)

            if len(lr_table.shape) != 2 or lr_table.shape[1] != 3:
                raise ValueError("Invalid lr.csv format")
            lr_table = [[tuple(lrt[0:2]), lrt[2]] for lrt in lr_table]
            lr_table.sort()

            lr = None
            # take the lr of the last table entry that is not bigger than
            # the given epoch
            for table_epoch in lr_table:
                if table_epoch[0] > tuple(epoch):
                    break
                else:
                    lr = table_epoch[1]
            if lr is None:
                raise ValueError(
                    "csv learning rate not specified for epoch {}".format(epoch)
                )
            return lr

        try:
            # Float => Constant LR
            lr = float(user_lr)
            return lr
        except (ValueError, TypeError):
            pass

        try:
            # Tuple => Exponentially decaying LR
            length = len(user_lr)
            if length != 2:
                raise LookupError(
                    "{} (Your tuple has length {})".format(error_msg, length)
                )
            lr_init = float(user_lr[0])
            lr_decay = float(user_lr[1])

            lr = lr_init * (1 - lr_decay) ** (
                (epoch[1] - 1) + (epoch[0] - 1) * no_train_files
            )
            return lr
        except (ValueError, TypeError):
            pass

        try:
            # Callable => User defined function
            n_params = len(signature(user_lr).parameters)
            if n_params != 2:
                raise TypeError(
                    "A custom learning rate function must have two "
                    "input parameters: The epoch and the file number. "
                    "(yours has {})".format(n_params)
                )
            lr = user_lr(epoch[0], epoch[1])
            return lr
        except (ValueError, TypeError):
            raise TypeError(
                "{} (You gave {} of type {}) ".format(error_msg, user_lr, type(user_lr))
            )
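
    # Accepted cfg.learning_rate forms (hypothetical values):
    #   0.002                -> constant LR of 0.002
    #   (0.002, 0.1)         -> 0.002, decayed by 10% per train file
    #   "lr.csv"             -> table lookup from <output_folder>/lr.csv
    #   lambda ep, fno: ...  -> custom function of epoch and file number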



def split_name_of_predfile(file):
    """
    Get epoch, fileno, val fileno from the name of a predfile.

    Parameters
    ----------
    file : str
        Like this: model_epoch_XX_file_YY_on_USERLIST_val_file_ZZ.h5

    Returns
    -------
    epoch, file_no, val_file_no : tuple(int)
        As integers.

    """
    file_base = os.path.splitext(file)[0]
    rest, val_file_no = file_base.split("_val_file_")
    rest, file_no = rest.split("_on_")[0].split("_file_")
    epoch = rest.split("_epoch_")[-1]

    epoch, file_no, val_file_no = map(int, [epoch, file_no, val_file_no])

    return epoch, file_no, val_file_no
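
# Parsing sketch:
#   >>> split_name_of_predfile("pred_model_epoch_6_file_1_on_list_val_file_1.h5")
#   (6, 1, 1)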



def h5_get_number_of_rows(h5_filepath, datasets=None):
    """
    Get the total number of rows of a .h5 file.

    Multiple dataset names can be given as a list to check if they all
    have the same number of rows (axis 0).

    Parameters
    ----------
    h5_filepath : str
        Filepath of the .h5 file.
    datasets : list, optional
        The names of datasets in the file to check.

    Returns
    -------
    number_of_rows : int
        Number of rows of the .h5 file in the first dataset.

    Raises
    ------
    AssertionError
        If the given datasets do not have the same number of rows.

    """
    with h5py.File(h5_filepath, "r") as f:
        if datasets is None:
            datasets = list(f.keys())

        number_of_rows = [f[dataset].shape[0] for dataset in datasets]
        if not number_of_rows.count(number_of_rows[0]) == len(number_of_rows):
            err_msg = (
                "Datasets do not have the same number of samples "
                "in file " + h5_filepath
            )
            for i, dataset in enumerate(datasets):
                err_msg += "\nDataset: {}\tSamples: {}".format(
                    dataset, number_of_rows[i]
                )
            raise AssertionError(err_msg)
        return number_of_rows[0]



def use_local_tmpdir(files):
    """
    Copy given files to the local temp folder.

    Parameters
    ----------
    files : dict
        Dict containing the file paths.

    Returns
    -------
    files_ssd : dict
        Dict with updated SSD/scratch filepaths.

    """
    local_scratch_path = os.environ["TMPDIR"]
    files_ssd = {}

    for input_key in files:
        old_paths = files[input_key]
        new_paths = []
        for f_path in old_paths:
            # copy to /scratch node-local SSD
            f_path_ssd = os.path.join(local_scratch_path, os.path.basename(f_path))
            print("Copying", f_path, "\nto", f_path_ssd)
            shutil.copy2(f_path, local_scratch_path)
            new_paths.append(f_path_ssd)
        files_ssd[input_key] = tuple(new_paths)

    print("Finished copying to local tmpdir folder.")
    return files_ssd
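
# Usage sketch (assumes the TMPDIR environment variable is set, e.g. by
# the batch system):
#   files_ssd = use_local_tmpdir({"input_A": ("path/to/set_A_file_1.h5",)})
#   # -> {"input_A": ("$TMPDIR/set_A_file_1.h5",)}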