Coverage for orcanet/model_builder.py: 58%


#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
Scripts for making specific models.
"""

import warnings
import toml
from datetime import datetime
import tensorflow as tf
import tensorflow.keras as ks
import tensorflow.keras.layers as layers

from orcanet.builder_util.builders import BlockBuilder


class ModelBuilder:
    """
    Build and compile a keras model from a toml file, using OrcaNet
    building blocks.

    The input of the model is adapted to the dimensions of the input
    data given to the Organizer, taking the sample modifier into
    account.

    Attributes
    ----------
    configs : list
        List with keywords for building each layer block in the model.
    defaults : dict
        Default values for the layer blocks in the model.
    optimizer : str or Optimizer
        Optimizer for training the model. Can be a string like "adam"
        (or "keras:adam" for the default keras variant), or an object
        derived from ks.optimizers.Optimizer.
    compile_opt : dict
        Keys: Names of the output layers of the model.
        Values: Loss function, optionally weight and optionally metrics
        of each output layer.
        Format: { layer_name : { function:, weight:, metrics: } }
        The function is a string or a loss function, the weight is a
        float, and metrics is a list of functions/strings.
    optimizer_args : dict, optional
        Kwargs for the optimizer. Not used when an optimizer object is
        given.
    input_opts : dict
        Options for the input of the model.

    Methods
    -------
    build
        Build the network using an instance of Organizer.
    build_with_input
        Build the network without an Organizer, just from given input
        shapes.
    compile_model
        Compile a model with the optimizer settings given in the
        model file.

    """

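    # A minimal sketch of the toml layout this class reads. The block
    # types and parameters below are illustrative placeholders, not a
    # list of the blocks OrcaNet actually provides:
    #
    #   [model]
    #   blocks = [
    #       {type = "ConvBlock", filters = 16},
    #       {type = "OutputReg", output_name = "energy"},
    #   ]
    #
    #   [compile]
    #   optimizer = "adam"
    #   [compile.losses]
    #   energy = {function = "mse", weight = 1.0, metrics = ["mae"]}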
    def __init__(self, model_file, **custom_blocks):
        """
        Read out parameters for creating models with OrcaNet from a toml file.

        Parameters
        ----------
        model_file : str
            Path to the model toml file.
        custom_blocks
            For building models with custom blocks in the toml:
            Custom block names as kwargs ('toml name'='block').

        """
        file_content = toml.load(model_file)
        self.custom_blocks = custom_blocks

        try:
            if "model" in file_content:
                model_args = file_content["model"]
                self.configs = model_args.pop("blocks")
                self.input_opts = model_args.pop("input_opts", {})
                self.defaults = model_args

            elif "body" in file_content:
                # legacy
                self._compat_init(file_content)

            self.optimizer = None
            self.compile_opt = None
            self.optimizer_args = {}
            if "compile" in file_content:
                compile_sect = file_content["compile"]
                self.optimizer = compile_sect.pop("optimizer", None)
                self.compile_opt = compile_sect.pop("losses", None)
                self.optimizer_args = compile_sect

        except KeyError as e:
            if len(e.args) == 1:
                option = e.args[0]
            else:
                option = e.args
            raise KeyError(
                "Missing parameter in toml model file: " + str(option)
            ) from None

    def _compat_init(self, file_content):
        warnings.warn(
            "The format of this model toml file is deprecated, consider "
            "updating it to the new format (see the online docs)."
        )
        # legacy
        body = file_content["body"]
        if "architecture" in body:
            arch = body.pop("architecture")
            if arch != "single":
                raise ValueError("architecture keyword is deprecated")
        self.configs = body.pop("blocks")
        self.defaults = body

        if "head" in file_content:
            head = file_content["head"]
            head_arch = head.pop("architecture")
            head_arch_args = head.pop("architecture_args")
            head_args = head

            head_block_config = head_arch_args
            head_block_config["type"] = head_arch
            self.configs.append({**head_block_config, **head_args})

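    # For reference, a sketch of the deprecated layout that _compat_init
    # above translates into the new format (block type and args are
    # illustrative placeholders):
    #
    #   [body]
    #   architecture = "single"
    #   blocks = [{type = "ConvBlock", filters = 16}]
    #
    #   [head]
    #   architecture = "OutputReg"
    #   architecture_args = {output_name = "energy"}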
    def build(self, orga, log_comp_opts=False, verbose=False):
        """
        Build the network using an instance of Organizer.

        Input layers will be adapted to the input files in the organizer.
        Can also add the matching modifiers and custom objects to the orga.

        Parameters
        ----------
        orga : orcanet.core.Organizer
            Contains all the configurable options in the OrcaNet scripts.
        log_comp_opts : bool
            Whether the info used for the compilation of the model
            should be logged to log.txt.
        verbose : bool
            Print info about the building process?

        Returns
        -------
        model : keras model
            The network.

        """

        if orga.cfg.fixed_batchsize:
            if (
                "batchsize" in self.input_opts
                and self.input_opts["batchsize"] != orga.cfg.batchsize
            ):
                raise ValueError(
                    f"Batchsize in input_opts is {self.input_opts['batchsize']}, "
                    f"but in cfg it is {orga.cfg.batchsize}"
                )
            self.input_opts["batchsize"] = orga.cfg.batchsize

        with orga.get_strategy().scope():
            model = self.build_with_input(
                orga.io.get_input_shapes(),
                compile_model=True,
                custom_objects=orga.cfg.get_custom_objects(),
                verbose=verbose,
            )

        if log_comp_opts:
            self.log_model_properties(orga)
        model.summary()
        return model

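    # Typical use of build() above, as a sketch; the Organizer arguments
    # here are placeholders, see orcanet.core.Organizer for the actual
    # signature:
    #
    #   orga = Organizer("output_folder/", "list.toml")
    #   model = ModelBuilder("model.toml").build(orga, log_comp_opts=True)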
    def build_with_input(
        self, input_shapes, compile_model=True, custom_objects=None, verbose=False
    ):
        """
        Build the network with the given input shapes.

        Parameters
        ----------
        input_shapes : dict
            Keys: Names of the inputs of the model.
            Values: Their shapes without the batchsize.
        compile_model : bool
            Compile the model?
        custom_objects : dict, optional
            Custom objects to use during compiling.
        verbose : bool
            Print info about the building process?

        Returns
        -------
        model : ks.Model
            The network.

        """

        builder = BlockBuilder(
            self.defaults,
            verbose=verbose,
            input_opts=self.input_opts,
            **self.custom_blocks,
        )
        model = builder.build(input_shapes, self.configs)

        if compile_model:
            self.compile_model(model, custom_objects=custom_objects)

        return model

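    # Standalone use of build_with_input() above, without an Organizer
    # (the input name and shape are illustrative):
    #
    #   mb = ModelBuilder("model.toml")
    #   model = mb.build_with_input({"input_A": (11, 13, 18, 1)})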
    # def merge_models(self, model_list, trainable=False, stateful=True,
    #                  no_drop=True):
    #     """
    #     Concatenate two or more single-input CNNs into a big one.
    #
    #     It will explicitly look for a Flatten layer and cut after it,
    #     concatenate all models, and then add the head layers.
    #
    #     Parameters
    #     ----------
    #     model_list : list
    #         List of keras models to stitch together.
    #     trainable : bool
    #         Whether the layers of the loaded models will be trainable.
    #     stateful : bool
    #         Whether the batchnorms of the loaded models will be stateful.
    #     no_drop : bool
    #         If true, the rate of dropout layers from the loaded models
    #         will be set to zero.
    #
    #     Returns
    #     -------
    #     model : keras model
    #         The uncompiled merged keras model.
    #
    #     """
    #     # Get the input and Flatten layers in each of the given models
    #     input_layers, flattens = [], []
    #     for i, model in enumerate(model_list):
    #         if len(model.inputs) != 1:
    #             raise ValueError(
    #                 "model input is not length 1 {}".format(model.inputs))
    #         input_layers.append(model.input)
    #         flatten_found = 0
    #         for layer in model.layers:
    #             layer.trainable = trainable
    #             layer.name = layer.name + '_net_' + str(i)
    #             if isinstance(layer, layers.BatchNormalization):
    #                 layer.stateful = stateful
    #             elif isinstance(layer, layers.Flatten):
    #                 flattens.append(layer.output)
    #                 flatten_found += 1
    #         if flatten_found != 1:
    #             raise TypeError(
    #                 "Expected 1 Flatten layer but got " + str(flatten_found))
    #
    #     # attach new head
    #     x = layers.Concatenate()(flattens)
    #     builder = BlockBuilder(body_defaults=None,
    #                            head_defaults=self.head_args)
    #     output_layer = builder.attach_output_layers(x, self.head_arch,
    #                                                 flatten=False,
    #                                                 **self.head_arch_args)
    #
    #     model = ks.models.Model(input_layers, output_layer)
    #     if no_drop:
    #         model = _change_dropout_rate(model, before_concat=0.)
    #
    #     return model

    def compile_model(self, model, custom_objects=None):
        """
        Compile a model with the optimizer settings given in the attributes.

        Parameters
        ----------
        model : ks.Model
            A keras model.
        custom_objects : dict or None
            Maps names (strings) to custom loss functions.

        Returns
        -------
        model : keras model
            The compiled (or recompiled) keras model.

        """

        if self.optimizer is None or self.compile_opt is None:
            raise ValueError("Cannot compile, need optimizer name and losses")

        loss_functions, loss_weights, loss_metrics = {}, {}, {}
        for layer_name, layer_info in self.compile_opt.items():
            # Replace the function name with the actual function if it is custom
            loss_function = layer_info["function"]
            if custom_objects is not None and loss_function in custom_objects:
                loss_function = custom_objects[loss_function]
            loss_functions[layer_name] = loss_function

            # Use the given weight, else default to 1
            loss_weights[layer_name] = layer_info.get("weight", 1.0)

            # Use the given metrics, else no metrics
            metrics = layer_info.get("metrics", [])
            if custom_objects is not None:
                for i, metric in enumerate(metrics):
                    if metric in custom_objects:
                        metrics[i] = custom_objects[metric]
            loss_metrics[layer_name] = metrics

        optimizer = self._get_optimizer()
        model.compile(
            loss=loss_functions,
            optimizer=optimizer,
            metrics=loss_metrics,
            loss_weights=loss_weights,
        )
        return model

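    # Sketch of resolving a custom loss in compile_model() above;
    # my_custom_loss is a placeholder for any (y_true, y_pred) -> loss
    # function whose name appears under [compile.losses] in the toml:
    #
    #   mb.compile_model(model, custom_objects={"my_custom_loss": my_custom_loss})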
    def log_model_properties(self, orga):
        """
        Writes the compile_opt config to the full log file.
        """
        lines = []
        lines.append("-" * 60)
        time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        lines.append("-" * 19 + " {} ".format(time) + "-" * 19)
        lines.append(
            "A model has been built using the model builder with the "
            "following configurations:\n"
        )
        lines.append("Loss functions: ")
        for key in self.compile_opt:
            lines.append(key + ": " + str(self.compile_opt[key]))
        lines.append("\n")
        orga.io.print_log(lines)

    def _get_optimizer(self):
        if not isinstance(self.optimizer, str):
            if self.optimizer_args:
                warnings.warn(
                    "Custom optimizer object used, optimizer_args are ignored: "
                    + str(self.optimizer_args)
                )
            return self.optimizer
        if self.optimizer == "adam":
            optimizer = get_adam(**self.optimizer_args)
        elif self.optimizer == "sgd":
            optimizer = get_sgd(**self.optimizer_args)
        elif self.optimizer.startswith("keras:"):
            optimizer = getattr(ks.optimizers, self.optimizer.split("keras:")[-1])(
                **self.optimizer_args
            )
        else:
            raise NameError("Unknown optimizer name ({})".format(self.optimizer))
        return optimizer


def get_adam(beta_1=0.9, beta_2=0.999, epsilon=0.1, decay=0.0, **kwargs):
    # epsilon=1 for deep networks
    return ks.optimizers.Adam(
        beta_1=beta_1, beta_2=beta_2, epsilon=epsilon, decay=decay, **kwargs
    )


def get_sgd(momentum=0.9, decay=0, nesterov=True, **kwargs):
    return ks.optimizers.SGD(
        momentum=momentum, decay=decay, nesterov=nesterov, **kwargs
    )


def _change_dropout_rate(model, before_concat, after_concat=None):
    """
    Change the dropout rate in a model.

    # TODO untested for tf 2.x!

    Only for models with a Concatenate layer, i.e. multiple
    single-input models that were merged together.

    Parameters
    ----------
    model : keras model
        The model to alter.
    before_concat : float
        New dropout rate before the Concatenate layer in the model.
    after_concat : float or None
        New dropout rate after the Concatenate layer. None will leave
        the dropout rate there as it was.

    Returns
    -------
    clone : keras model
        A clone of the model with the changed dropout rates.

    """
    ch_bef, ch_aft, concat_found = 0, 0, 0

    for layer in model.layers:
        if isinstance(layer, layers.Dropout):
            if concat_found == 0:
                layer.rate = before_concat
                ch_bef += 1
            else:
                layer.rate = after_concat
                ch_aft += 1

        elif isinstance(layer, layers.Concatenate):
            concat_found += 1
            if after_concat is None:
                break

    if concat_found != 1:
        raise TypeError("Expected 1 Concatenate layer but got " + str(concat_found))
    clone = ks.models.clone_model(model)
    clone.set_weights(model.get_weights())
    print(
        "Changed dropout rates of {} layers before and {} layers after "
        "Concatenate.".format(ch_bef, ch_aft)
    )
    return clone
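
# Example use of _change_dropout_rate (a sketch; note the function is
# flagged as untested for tf 2.x in its docstring):
#
#   no_drop_model = _change_dropout_rate(merged_model, before_concat=0.0)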