Coverage for orcanet/core.py: 80% (390 statements)
coverage.py v7.2.7, created at 2024-03-28 14:22 +0000

#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
Core scripts for the OrcaNet package.
"""

import os
import toml
import warnings
import time
from datetime import timedelta
import tensorflow as tf

import orcanet.backend as backend
from orcanet.utilities.visualization import update_summary_plot
from orcanet.in_out import IOHandler
from orcanet.history import HistoryHandler
from orcanet.utilities.nn_utilities import load_zero_center_data
import orcanet.lib as lib
import orcanet.logging as logging
import orcanet.misc
import medgeconv


class Organizer:
    """
    Core class for working with networks in OrcaNet.

    Attributes
    ----------
    cfg : orcanet.core.Configuration
        Contains all configurable options.
    io : orcanet.in_out.IOHandler
        Utility functions for accessing the info in cfg.
    history : orcanet.history.HistoryHandler
        For reading and plotting data from the log files created
        during training.

    """

    def __init__(
        self,
        output_folder,
        list_file=None,
        config_file=None,
        tf_log_level=None,
        discover_tomls=True,
    ):
        """
        Set the attributes of the Configuration object.

        Instead of using a config_file, the attributes of orga.cfg can
        also be changed directly, e.g. by setting orga.cfg.batchsize.

        Parameters
        ----------
        output_folder : str
            Name of the folder of this model in which everything will be saved,
            e.g., the summary.txt log file is located in here.
            Will be used to load saved files or to save new ones.
        list_file : str, optional
            Path to a toml list file with paths to all the h5 files that should
            be used for training and validation.
            Will be used to extract samples and labels.
            Default: Look for a file called 'list.toml' in the given output_folder.
        config_file : str, optional
            Path to a toml config file with settings that are used instead of
            the default ones.
            Default: Look for a file called 'config.toml' in the given output_folder.
        tf_log_level : int or str, optional
            Sets the TensorFlow CPP_MIN_LOG_LEVEL environment variable.
            0 = all messages are logged (default behavior).
            1 = INFO messages are not printed.
            2 = INFO and WARNING messages are not printed.
            3 = INFO, WARNING, and ERROR messages are not printed.
        discover_tomls : bool
            If False, do not try to look for toml files in the given
            output_folder if list_file or config_file is None [Default: True].

        """
        if tf_log_level is not None:
            os.environ["TF_CPP_MIN_LOG_LEVEL"] = str(tf_log_level)

        if discover_tomls and list_file is None:
            list_file = orcanet.misc.find_file(output_folder, "list.toml")
        if discover_tomls and config_file is None:
            config_file = orcanet.misc.find_file(output_folder, "config.toml")

        self.cfg = Configuration(output_folder, list_file, config_file)
        self.io = IOHandler(self.cfg)
        self.history = HistoryHandler(output_folder)

        self.xs_mean = None
        self._auto_label_modifier = None
        self._stored_model = None
        self._strategy = None
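
    # Example usage (hedged sketch; the folder name is illustrative):
    #
    #   orga = Organizer("output/my_model/")  # discovers list.toml / config.toml
    #   orga.cfg.batchsize = 32               # settings can also be set directly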

    def train_and_validate(self, model=None, epochs=None, to_epoch=None):
        """
        Train a model and validate according to schedule.

        The various settings of this process can be controlled with the
        attributes of orga.cfg.
        The model will be trained on the given data, saved and validated.
        Logfiles of the training are saved in the output folder.
        Plots showing the training and validation history, as well as
        the weights and activations of the network are generated in
        the plots subfolder after every validation.
        The training can be resumed by executing this function again.

        Parameters
        ----------
        model : ks.models.Model or str, optional
            Compiled keras model to use for training. Required for the first
            epoch (the start of training).
            Can also be the path to a saved keras model, which will be loaded.
            If model is None, the most recent saved model will be
            loaded automatically to continue the training.
        epochs : int, optional
            How many epochs should be trained by running this function.
            None for infinite. This includes the current epoch in case it
            is not finished yet, i.e. 1 means complete the epoch if there
            are files left, otherwise do the next epoch.
        to_epoch : int, optional
            Train up to and including this epoch. Can not be used together
            with epochs.

        Returns
        -------
        model : ks.models.Model
            The trained keras model.

        """
        latest_epoch = self.io.get_latest_epoch()

        model = self._get_model(model, logging=False)
        self._stored_model = model

        # check if the validation is missing for the latest fileno
        if latest_epoch is not None:
            state = self.history.get_state()[-1]
            if state["is_validated"] is False and self.val_is_due(latest_epoch):
                self.validate()

        next_epoch = self.io.get_next_epoch(latest_epoch)
        n_train_files = self.io.get_no_of_files("train")

        if to_epoch is None:
            epochs_left = epochs
        else:
            if epochs is not None:
                raise ValueError("Can not give both 'epochs' and 'to_epoch'")
            if latest_epoch is None:
                epochs_left = to_epoch
            else:
                epochs_left = max(
                    0, to_epoch - self.io.get_next_epoch(latest_epoch)[0] + 1
                )

        trained_epochs = 0
        while epochs_left is None or trained_epochs < epochs_left:
            # Train on remaining files
            for file_no in range(next_epoch[1], n_train_files + 1):
                curr_epoch = (next_epoch[0], file_no)
                self.train(model)
                if self.val_is_due(curr_epoch):
                    self.validate()

            next_epoch = (next_epoch[0] + 1, 1)
            trained_epochs += 1

        self._stored_model = None
        return model
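
    # Example usage (hedged sketch; `my_model` stands for a compiled keras
    # model and is not defined in this module):
    #
    #   orga.train_and_validate(model=my_model, epochs=2)  # two (more) epochs
    #   orga.train_and_validate(to_epoch=10)               # resume up to epoch 10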

    def train(self, model=None):
        """
        Trains a model on the next file.

        The progress of the training is also logged and plotted.

        Parameters
        ----------
        model : ks.models.Model or str, optional
            Compiled keras model to use for training. Required for the first
            epoch (the start of training).
            Can also be the path to a saved keras model, which will be loaded.
            If model is None, the most recent saved model will be
            loaded automatically to continue the training.

        Returns
        -------
        history : dict
            The history of the training on this file. A record of training
            loss values and metrics values.

        """
        # Create folder structure
        self.io.get_subfolder(create=True)
        latest_epoch = self.io.get_latest_epoch()

        model = self._get_model(model, logging=True)

        self._set_up(model, logging=True)

        # epoch about to be trained
        next_epoch = self.io.get_next_epoch(latest_epoch)
        next_epoch_float = self.io.get_epoch_float(*next_epoch)

        if latest_epoch is None:
            self.io.check_connections(model)
            logging.log_start_training(self)

        model_path = self.io.get_model_path(*next_epoch)
        model_path_local = self.io.get_model_path(*next_epoch, local=True)
        if os.path.isfile(model_path):
            raise FileExistsError(
                "Can not train model in epoch {} file {}, this model has "
                "already been saved!".format(*next_epoch)
            )

        smry_logger = logging.SummaryLogger(self, model)

        if self.cfg.learning_rate is not None:
            tf.keras.backend.set_value(
                model.optimizer.lr, self.io.get_learning_rate(next_epoch)
            )

        files_dict = self.io.get_file("train", next_epoch[1])

        line = "Training in epoch {} on file {}/{}".format(
            next_epoch[0], next_epoch[1], self.io.get_no_of_files("train")
        )
        self.io.print_log(line)
        self.io.print_log("-" * len(line))
        self.io.print_log(
            "Learning rate is at {}".format(
                tf.keras.backend.get_value(model.optimizer.lr)
            )
        )
        self.io.print_log("Inputs and files:")
        for input_name, input_file in files_dict.items():
            self.io.print_log(
                " {}: \t{}".format(input_name, os.path.basename(input_file))
            )

        start_time = time.time()
        history = backend.train_model(self, model, next_epoch, batch_logger=True)
        elapsed_s = int(time.time() - start_time)

        model.save(model_path)
        smry_logger.write_line(
            next_epoch_float,
            tf.keras.backend.get_value(model.optimizer.lr),
            history_train=history,
        )

        self.io.print_log("Training results:")
        for metric_name, loss in history.items():
            self.io.print_log(f" {metric_name}: \t{loss}")
        self.io.print_log(f"Elapsed time: {timedelta(seconds=elapsed_s)}")
        self.io.print_log(f"Saved model to: {model_path_local}\n")

        update_summary_plot(self)
        if self.cfg.cleanup_models:
            self.cleanup_models()

        return history

    def validate(self):
        """
        Validate the most recent saved model on all validation files.

        Will also log the progress, as well as update the summary plot and
        plot weights and activations of the model.

        Returns
        -------
        history : dict
            The history of the validation on all files. A record of validation
            loss values and metrics values.

        """
        latest_epoch = self.io.get_latest_epoch()
        if latest_epoch is None:
            raise ValueError("Can not validate: No saved model found")
        if self.history.get_state()[-1]["is_validated"] is True:
            raise ValueError(
                "Can not validate in epoch {} file {}: "
                "Has already been validated".format(*latest_epoch)
            )

        if self._stored_model is None:
            model = self.load_saved_model(*latest_epoch)
        else:
            model = self._stored_model

        self._set_up(model, logging=True)

        epoch_float = self.io.get_epoch_float(*latest_epoch)
        smry_logger = logging.SummaryLogger(self, model)

        logging.log_start_validation(self)

        start_time = time.time()
        history = backend.validate_model(self, model)
        elapsed_s = int(time.time() - start_time)

        self.io.print_log("Validation results:")
        for metric_name, loss in history.items():
            self.io.print_log(f" {metric_name}: \t{loss}")
        self.io.print_log(f"Elapsed time: {timedelta(seconds=elapsed_s)}\n")
        smry_logger.write_line(epoch_float, "n/a", history_val=history)

        update_summary_plot(self)

        if self.cfg.cleanup_models:
            self.cleanup_models()

        return history
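
    # Hedged sketch: the schedule in train_and_validate corresponds roughly to
    # this manual loop over one epoch (assuming a saved model already exists,
    # so that model=None resumes it):
    #
    #   for _ in range(orga.io.get_no_of_files("train")):
    #       orga.train()
    #   orga.validate()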

    def predict(self, epoch=None, fileno=None, samples=None):
        """
        Make a prediction if it does not exist yet, and return the filepaths.

        Load the model with the lowest validation loss, let it predict on
        all samples of the validation set
        in the toml list, and save this prediction together with all the
        y_values as h5 file(s) in the predictions subfolder.

        Parameters
        ----------
        epoch : int, optional
            Epoch of a model to load. Default: lowest val loss.
        fileno : int, optional
            File number of a model to load. Default: lowest val loss.
        samples : int, optional
            Don't use the full validation files, but just the given number
            of samples.

        Returns
        -------
        pred_filepaths : list
            List of the paths of all the prediction files.

        """
        epoch, fileno = self._get_auto_epoch(epoch, fileno)
        if self._check_if_pred_already_done(epoch, fileno):
            print("Prediction has already been done.")
            pred_filepaths = self.io.get_pred_files_list(epoch, fileno)

        else:
            if self._stored_model is None:
                model = self.load_saved_model(epoch, fileno, logging=False)
            else:
                model = self._stored_model
            self._set_up(model)

            start_time = time.time()
            backend.make_model_prediction(self, model, epoch, fileno, samples=samples)
            elapsed_s = int(time.time() - start_time)
            print("Finished predicting on all validation files.")
            print("Elapsed time: {}\n".format(timedelta(seconds=elapsed_s)))

            pred_filepaths = self.io.get_pred_files_list(epoch, fileno)

        return pred_filepaths
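
    # Example usage (hedged sketch): epoch/fileno of the model with the
    # lowest validation loss are picked automatically.
    #
    #   pred_files = orga.predict()
    #   quick_look = orga.predict(samples=1000)  # restrict to 1000 samples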

    def inference(self, epoch=None, fileno=None, as_generator=False):
        """
        Make an inference and return the filepaths.

        Load the model with the lowest validation loss, let
        it predict on all samples of all inference files
        in the toml list, and save these predictions as h5 files in the
        predictions subfolder. y values will only be added if they are in
        the input file, so this can be used on un-labeled data as well.

        Parameters
        ----------
        epoch : int, optional
            Epoch of a model to load. Default: lowest val loss.
        fileno : int, optional
            File number of a model to load. Default: lowest val loss.
        as_generator : bool
            If true, return a generator, which yields the output filename
            after the inference of each file.
            If false (default), do all files back to back.

        Returns
        -------
        filenames : list
            List of the paths of all created output files.

        """
        gen = self._inference(epoch=epoch, fileno=fileno)
        if as_generator:
            return gen
        else:
            return list(gen)
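
    # Example usage (hedged sketch): process each output file as soon as the
    # inference on it has finished.
    #
    #   for output_path in orga.inference(as_generator=True):
    #       print("Done:", output_path)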

    def _inference(self, epoch=None, fileno=None):
        """Generator that does the inference, file by file."""
        epoch, fileno = self._get_auto_epoch(epoch, fileno)
        if self._stored_model is None:
            model = self.load_saved_model(epoch, fileno, logging=False)
        else:
            model = self._stored_model
        self._set_up(model)

        filenames = []
        for files_dict in self.io.yield_files("inference"):
            # output filename is based on name of file in first input
            first_filename = os.path.basename(list(files_dict.values())[0])
            output_filename = "model_epoch_{}_file_{}_on_{}".format(
                epoch, fileno, first_filename
            )

            output_path = os.path.join(
                self.io.get_subfolder("inference"), output_filename
            )
            filenames.append(output_path)
            if os.path.exists(output_path):
                print("File {} exists already, skipping...".format(output_filename))
                continue

            print(f"Working on file {first_filename}")
            start_time = time.time()
            backend.h5_inference(
                self, model, files_dict, output_path, use_def_label=False
            )
            elapsed_s = int(time.time() - start_time)
            print(f"Finished on file {first_filename} in {elapsed_s/60} min")
            yield output_path

    def inference_on_file(
        self, input_file, output_file=None, saved_model=None, epoch=None, fileno=None
    ):
        """
        Save the model prediction for each sample of the given input file.

        Useful for sharing a saved model, since the usual training folder
        structure is not necessarily required.

        Parameters
        ----------
        input_file : str or dict
            Path to a DL file on which the inference should be done.
            Can also be a dict mapping input names to files.
        output_file : str, optional
            Save output to an h5 file with this name. Default: auto-generate
            name and save in same directory as the input file.
        saved_model : str, optional
            Optional path to a saved model, which will be used instead
            of loading the one with the given epoch/fileno.
        epoch : int, optional
            Epoch of a model to load from the directory. Only relevant if
            saved_model is None. Default: lowest val loss.
        fileno : int, optional
            File number of a model to load from the directory. Only relevant
            if saved_model is None. Default: lowest val loss.

        Returns
        -------
        str
            Name of the output file.

        """
        if saved_model is None:
            epoch, fileno = self._get_auto_epoch(epoch, fileno)
            model = self.load_saved_model(epoch, fileno, logging=False)
        else:
            model = self._load_model(saved_model)
        self._set_up(model)

        if isinstance(input_file, str):
            input_file = {model.input_names[0]: input_file}
        if output_file is None:
            out_path, first_filename = os.path.split(list(input_file.values())[0])
            output_file = os.path.join(out_path, "dl_pred_{}".format(first_filename))
        start_time = time.time()
        backend.h5_inference(
            orga=self,
            model=model,
            files_dict=input_file,
            output_path=output_file,
        )
        elapsed_s = int(time.time() - start_time)
        print(f"Finished inference in {elapsed_s / 60} min")
        return output_file
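
    # Example usage (hedged sketch; paths are illustrative): apply a shared
    # saved model to a single file, without the training folder structure.
    #
    #   orga.inference_on_file(
    #       "data/run_42.h5", saved_model="shared/model_epoch_3.h5")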

    def cleanup_models(self):
        """
        Delete all models except for the most recent one (to continue
        training), and the ones with the highest and lowest loss/metrics.

        """
        all_epochs = self.io.get_all_epochs()
        epochs_to_keep = {
            self.io.get_latest_epoch(),
        }
        try:
            for metric in self.history.get_metrics():
                epochs_to_keep.add(
                    self.history.get_best_epoch_fileno(
                        metric=f"val_{metric}", mini=True
                    )
                )
                epochs_to_keep.add(
                    self.history.get_best_epoch_fileno(
                        metric=f"val_{metric}", mini=False
                    )
                )
        except ValueError:
            # no best epoch exists
            pass

        for epoch in epochs_to_keep:
            if epoch not in all_epochs:
                warnings.warn(
                    f"ERROR: epoch to keep {epoch} not in available epochs "
                    f"{all_epochs}, skipping clean-up of models!"
                )
                return

        print("\nClean-up saved models:")
        for epoch in all_epochs:
            model_path = self.io.get_model_path(epoch[0], epoch[1])
            model_name = os.path.basename(model_path)
            if epoch in epochs_to_keep:
                print("Keeping model {}".format(model_name))
            else:
                print("Deleting model {}".format(model_name))
                os.remove(model_path)

    def _check_if_pred_already_done(self, epoch, fileno):
        """
        Checks if the prediction has already been done before.
        (-> predicted on all validation files)

        Returns
        -------
        bool
            True if the prediction has already been fully done on all
            validation files.

        """
        latest_pred_file_no = self.io.get_latest_prediction_file_no(epoch, fileno)
        total_no_of_val_files = self.io.get_no_of_files("val")

        if latest_pred_file_no is None:
            return False
        return latest_pred_file_no == total_no_of_val_files

    def _get_auto_epoch(self, epoch, fileno):
        """Automatically retrieve best epoch/fileno if they are none."""
        if fileno is None and epoch is None:
            epoch, fileno = self.history.get_best_epoch_fileno()
            print("Automatically set epoch to epoch {} file {}.".format(epoch, fileno))
        elif fileno is None or epoch is None:
            raise ValueError("Either both or none of epoch and fileno must be None")
        return epoch, fileno

    def get_xs_mean(self, logging=False):
        """
        Set and return the zero center image for each list input.

        Requires the cfg.zero_center_folder to be set. If no existing
        image for the given input files is found in the folder, it will
        be calculated and saved by averaging over all samples in the
        train dataset.

        Parameters
        ----------
        logging : bool
            If true, the execution of this function will be logged into the
            full summary in the output folder if called for the first time.

        Returns
        -------
        dict
            Dict of numpy arrays that contains the mean_image of the x dataset
            (1 array per list input).
            Example format:
            { "input_A" : ndarray, "input_B" : ndarray }

        """
        if self.xs_mean is None:
            if self.cfg.zero_center_folder is None:
                raise ValueError(
                    "Can not calculate zero center: No zero center folder given"
                )
            self.xs_mean = load_zero_center_data(self, logging=logging)
        return self.xs_mean
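
    # Example usage (hedged sketch; the path is illustrative): zero center
    # images are computed once from the train set and reused on later calls.
    #
    #   orga.cfg.zero_center_folder = "output/zero_center/"
    #   xs_mean = orga.get_xs_mean()  # dict: input name -> mean image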

    def load_saved_model(self, epoch, fileno, logging=False):
        """
        Load a saved model.

        Parameters
        ----------
        epoch : int
            Epoch of the saved model. If both this and fileno are -1,
            load the most recent model.
        fileno : int
            Fileno of the saved model.
        logging : bool
            If True, will log this function call into the log.txt file.

        Returns
        -------
        model : keras model

        """
        path_of_model = self.io.get_model_path(epoch, fileno)
        path_loc = self.io.get_model_path(epoch, fileno, local=True)
        self.io.print_log("Loading saved model: " + path_loc, logging=logging)
        return self._load_model(path_of_model)

    def _get_model(self, model, logging=False):
        """Load most recent saved model or use user model."""
        latest_epoch = self.io.get_latest_epoch()

        if latest_epoch is None:
            # new training, log info about model
            if model is None:
                raise ValueError(
                    "You need to provide a compiled keras model "
                    "for the start of the training! (You gave None)"
                )

            elif isinstance(model, str):
                # path to a saved model
                self.io.print_log("Loading model from " + model, logging=logging)
                model = self._load_model(model)

            if logging:
                self._save_as_json(model)
                model.summary(print_fn=self.io.print_log)

                try:
                    plots_folder = self.io.get_subfolder("plots", create=True)
                    tf.keras.utils.plot_model(
                        model, plots_folder + "/model_plot.png", show_shapes=True
                    )
                except (ImportError, AttributeError) as e:
                    # TODO remove AttributeError once https://github.com/tensorflow/tensorflow/issues/38988 is fixed
                    warnings.warn("Can not plot model: " + str(e))

        else:
            # resuming training, load model if it is not given
            if model is None:
                model = self.load_saved_model(*latest_epoch, logging=logging)

            elif isinstance(model, str):
                # path to a saved model
                self.io.print_log("Loading model from " + model, logging=logging)
                model = self._load_model(model)

        return model

    def _load_model(self, filepath):
        """Load from path, with custom objects and parallelized."""
        with self.get_strategy().scope():
            model = tf.keras.models.load_model(
                filepath, custom_objects=self.cfg.get_custom_objects()
            )
        return model

    def _save_as_json(self, model):
        """Save the architecture of a model as json to fixed path."""
        json_filename = "model_arch.json"

        json_string = model.to_json(indent=1)
        model_folder = self.io.get_subfolder("saved_models", create=True)
        with open(os.path.join(model_folder, json_filename), "w") as f:
            f.write(json_string)

    def _set_up(self, model, logging=False):
        """Necessary setup for training, validating and predicting."""
        if self.cfg.label_modifier is None:
            self._setup_auto_lmod(model)

        if self.cfg.zero_center_folder is not None:
            self.get_xs_mean(logging)

    def _setup_auto_lmod(self, model):
        """Set up the auto label modifier for the given model."""
        self._auto_label_modifier = lib.label_modifiers.ColumnLabels(model)

    def val_is_due(self, epoch=None):
        """
        True if validation is due on given epoch according to schedule.
        Does not check if it has been done already.

        """
        if epoch is None:
            epoch = self.io.get_latest_epoch()
        n_train_files = self.io.get_no_of_files("train")
        val_sched = (epoch[1] == n_train_files) or (
            self.cfg.validate_interval is not None
            and epoch[1] % self.cfg.validate_interval == 0
        )
        return val_sched

    def get_strategy(self):
        """Get the strategy for distributed training."""
        if self._strategy is None:
            if self.cfg.multi_gpu and len(tf.config.list_physical_devices("GPU")) > 1:
                self._strategy = tf.distribute.MirroredStrategy()
                print(f"Number of GPUs: {self._strategy.num_replicas_in_sync}")
            else:
                self._strategy = tf.distribute.get_strategy()
        return self._strategy


class Configuration(object):
    """
    Contains all the configurable options in the OrcaNet scripts.

    All of these public attributes (the ones without a
    leading underscore) can be changed either directly or with a
    .toml config file via the method update_config().

    Parameters
    ----------
    output_folder : str
        Name of the folder of this model in which everything will be saved,
        e.g., the summary.txt log file is located in here.
    list_file : str or None
        Path to a toml list file with paths to all the h5 files that should
        be used for training and validation.
    config_file : str or None
        Path to a toml config file with attributes that are used instead of
        the default ones.
    kwargs
        Overwrites the values given in the config file.

    Attributes
    ----------
    batchsize : int
        Batchsize that will be used for the training, validation and
        inference of the network.
        During training and validation, the last batch in each file may
        be smaller than the batchsize; see fixed_batchsize for how it
        is handled.
    callback_train : keras callback or list or None
        Callback or list of callbacks to use during training.
    class_weight : dict or None
        Optional dictionary mapping class indices (integers) to a weight
        (float) value, used for weighting the loss function (during
        training only). This can be useful to tell the model to
        "pay more attention" to samples from an under-represented class.
    cleanup_models : bool
        If true, will only keep the best (in terms of val loss) and the most
        recent from all saved models in order to save disk space.
    custom_objects : dict, optional
        Optional dictionary mapping names (strings) to custom classes or
        functions to be considered by keras during deserialization of models.
    dataset_modifier : function or None
        For orga.predict: Function that determines which datasets get created
        in the resulting h5 file. Default: save as array, i.e. every output layer
        will get one dataset each for both the label and the prediction,
        and one dataset containing the y_values from the validation files.
    fixed_batchsize : bool
        The last batch in the file might be smaller than the batchsize.
        Usually, this is no problem, but set to True to skip this batch
        [default: False].
    key_x_values : str
        The name of the datagroup in the h5 input files which contains
        the samples for the network.
    key_y_values : str
        The name of the datagroup in the h5 input files which contains
        the info for the labels.
    label_modifier : function or None
        Operation to be performed on batches of y_values read from the input
        files before they are fed into the model as labels. If None is given,
        all y_values with the same name as the output layers will be passed
        to the model as a dict, with the keys being the dtype names.
    learning_rate : float, tuple, function, str (optional)
        The learning rate for the training (a commented example follows
        this docstring).
        If None is given, don't change the learning rate at all.
        If it is a float: The learning rate will be constantly this value.
        If it is a tuple of two floats: The first float gives the learning rate
        in epoch 1 file 1, and the second float gives the decrease of the
        learning rate per file (e.g. 0.1 for 10% decrease per file).
        If it is a function: Takes as an input the epoch and the
        file number (in this order), and returns the learning rate.
        Both epoch and fileno start at 1, i.e. 1, 1 is the start of the
        training.
        If it is a str: Path to a csv file inside the main folder, containing
        3 columns with the epoch, fileno, and the value the lr will be set
        to when reaching this epoch/fileno.
    max_queue_size : int
        max_queue_size option of the keras training and evaluation generator
        methods. How many batches get preloaded from the generator.
    multi_gpu : bool
        Use all available GPUs (distributed training if there is more than one).
    n_events : None or int
        For testing purposes. If not the whole .h5 file should be used for
        training, define the number of samples.
    sample_modifier : function or None
        Operation to be performed on batches of x_values read from the input
        files before they are fed into the model as samples.
    shuffle_train : bool
        If true, the order in which batches are read out from the files during
        training is randomized each time they are read out.
    train_logger_display : int
        How many batches should be averaged for one line in the training log files.
    train_logger_flush : int
        After how many lines the training log file should be flushed (updated on
        the disk). -1 for flush at the end of the file only.
    use_scratch_ssd : bool
        Declares if the input files should be copied to a local temp dir,
        i.e. the path defined in the 'TMPDIR' environment variable.
    validate_interval : int or None
        Validate the model after this many training files have been trained on
        in an epoch. There will always be a validation at the end of an epoch.
        None for only validate at the end of an epoch.
        Example: validate_interval=3 --> Validate after file 3, 6, 9, ...
    verbose_train : int
        verbose option of keras.model.fit_generator.
        0 = silent, 1 = progress bar, 2 = one line per epoch.
    verbose_val : int
        verbose option of evaluate_generator.
        0 = silent, 1 = progress bar.
    y_field_names : tuple or list or str, optional
        During train and val, read out only these fields from the y dataset.
        --> Speed up, especially if there are many fields.
    zero_center_folder : None or str
        Path to a folder in which zero centering images are stored.
        If this path is set, zero centering images for the given dataset will
        either be calculated and saved automatically at the start of the
        training, or loaded if they have been saved before.

    """
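
    # Example (hedged sketch; the decay factor is illustrative): a learning
    # rate given as a function of epoch and file number, both starting at 1.
    #
    #   def my_lr(epoch, fileno):
    #       return 0.002 * 0.98 ** (fileno - 1)
    #
    #   orga.cfg.learning_rate = my_lr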

    # TODO add a clobber script that properly deletes models + logfiles
    def __init__(self, output_folder, list_file=None, config_file=None, **kwargs):
        self.batchsize = 64
        self.callback_train = []
        self.class_weight = None
        self.cleanup_models = False
        self.custom_objects = {}
        self.dataset_modifier = None
        self.fixed_batchsize = False
        self.key_x_values = "x"
        self.key_y_values = "y"
        self.label_modifier = None
        self.learning_rate = None
        self.make_weight_plots = False  # Removed in v0.11.1
        self.max_queue_size = 10
        self.multi_gpu = True
        self.n_events = None
        self.sample_modifier = None
        self.shuffle_train = False
        self.train_logger_display = 100
        self.train_logger_flush = -1
        self.use_scratch_ssd = False
        self.validate_interval = None
        self.verbose_train = 1
        self.verbose_val = 0
        self.y_field_names = None
        self.zero_center_folder = None

        self._default_values = dict(self.__dict__)

        # Main folder:
        if output_folder[-1] == "/":
            self.output_folder = output_folder
        else:
            self.output_folder = output_folder + "/"

        # Private attributes:
        self._files_dict = {
            "train": None,
            "val": None,
            "inference": None,
        }
        self._list_file = None

        # Load the optionally given list and config files.
        if list_file is not None:
            self.import_list_file(list_file)
        if config_file is not None:
            self.update_config(config_file)

        # set given kwargs:
        for key, val in kwargs.items():
            if hasattr(self, key):
                setattr(self, key, val)
            else:
                raise AttributeError("Unknown attribute {}".format(key))

        # deprecation warning TODO remove in the future
        if self.make_weight_plots:
            warnings.warn("make_weight_plots was removed in version v0.11.1")

    def import_list_file(self, list_file):
        """
        Import the filepaths of the h5 files from a toml list file.

        Parameters
        ----------
        list_file : str
            Path to the toml list file.

        """
        if self._list_file is not None:
            raise ValueError(
                "Can not load list file: Has already been loaded! "
                "({})".format(self._list_file)
            )

        file_content = toml.load(list_file)

        name_mapping = {
            "train_files": "train",
            "validation_files": "val",
            "inference_files": "inference",
        }

        for toml_name, files_dict_name in name_mapping.items():
            files = _extract_filepaths(file_content, toml_name)
            self._files_dict[files_dict_name] = files or None

        self._list_file = list_file
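
    # Example (hedged sketch; input names and paths are illustrative) of a
    # toml list file as parsed above. A directory can be given instead of a
    # file, in which case all .h5 files inside it are used.
    #
    #   [input_A]
    #   train_files = ["data/train_1.h5", "data/train_2.h5"]
    #   validation_files = ["data/val_1.h5"]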

    def update_config(self, config_file):
        """
        Update the default cfg parameters with values from a toml config file.

        Parameters
        ----------
        config_file : str
            Path to a toml config file.

        """
        user_values = toml.load(config_file)["config"]
        for key, value in user_values.items():
            if hasattr(self, key):
                if key == "sample_modifier":
                    value = orcanet.misc.from_register(
                        toml_entry=value, register=lib.sample_modifiers.smods
                    )
                elif key == "dataset_modifier":
                    value = orcanet.misc.from_register(
                        toml_entry=value, register=lib.dataset_modifiers.dmods
                    )
                elif key == "label_modifier":
                    value = orcanet.misc.from_register(
                        toml_entry=value, register=lib.label_modifiers.lmods
                    )
                setattr(self, key, value)
            else:
                raise AttributeError(f"Unknown attribute {key} in config file")
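
    # Example (hedged sketch; the values are illustrative) of a config.toml
    # as read by update_config. Every key must be an existing attribute.
    #
    #   [config]
    #   batchsize = 32
    #   learning_rate = 0.002
    #   validate_interval = 2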

    def get_list_file(self):
        """
        Returns the path to the list file that was used to set the training
        and validation files. None if no list file has been used.

        """
        return self._list_file

    def get_files(self, which):
        """
        Get the training or validation file paths for each list input set.

        Parameters
        ----------
        which : str
            Either "train", "val" or "inference".

        Returns
        -------
        dict
            A dict containing the paths to the training or validation files
            on which the model will be trained. Example of the format for
            two input sets with two files each:
            {
             "input_A" : ('path/to/set_A_file_1.h5', 'path/to/set_A_file_2.h5'),
             "input_B" : ('path/to/set_B_file_1.h5', 'path/to/set_B_file_2.h5'),
            }

        """
        if which not in self._files_dict.keys():
            raise NameError("Unknown fileset name ", which)
        if self._files_dict[which] is None:
            raise AttributeError("No {} files have been specified!".format(which))
        return self._files_dict[which]

    def get_custom_objects(self):
        """Get user custom objects + orcanet internal ones."""
        orcanet_co = medgeconv.custom_objects
        orcanet_loss_functions = lib.losses.loss_functions
        return {**orcanet_co, **orcanet_loss_functions, **self.custom_objects}


def _get_h5_files(folder):
    h5files = []
    for f in os.listdir(folder):
        if f.endswith(".h5"):
            h5files.append(os.path.join(folder, f))
    h5files.sort()
    if not h5files:
        warnings.warn(f"No .h5 files in dir {folder}!")
    return h5files


def _extract_filepaths(file_content, which):
    """
    Get train, val or inf filepaths of all inputs from a toml readout.
    Makes sure that all inputs have the same number of files.

    """
    # alternative names to write in the toml file
    aliases = {
        "train_files": ("training_files", "train", "training"),
        "validation_files": ("val_files", "val", "validation"),
        "inference_files": ("inf_files", "inf", "inference"),
    }
    assert which in aliases.keys(), f"{which} not in {list(aliases.keys())}"

    def get_alias(ident):
        for k, v in aliases.items():
            if ident == k or ident in v:
                return k
        else:
            raise NameError(
                f"Unknown argument '{ident}' in toml file: "
                f"Must be either of {list(aliases.keys())}"
            )

    files = {}
    n_files = []
    for input_name, input_files in file_content.items():
        for filetype, filetyp_files in input_files.items():
            if get_alias(filetype) != which:
                continue
            # if a dir is given as a filepath, use all h5 files in that dir instead
            expanded_files = []
            for path in filetyp_files:
                if os.path.isdir(path):
                    expanded_files.extend(_get_h5_files(path))
                else:
                    expanded_files.append(path)
            files[input_name] = tuple(expanded_files)
            # store number of files for this input
            n_files.append(len(expanded_files))

    if n_files and n_files.count(n_files[0]) != len(n_files):
        raise ValueError(
            "Inputs have different numbers of {} in the toml list".format(which)
        )

    return files