#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
Scripts for making specific models.
"""

import warnings
import toml
from datetime import datetime
import tensorflow as tf
import tensorflow.keras as ks
import tensorflow.keras.layers as layers

from orcanet.builder_util.builders import BlockBuilder


class ModelBuilder:
    """
    Build and compile a keras model from a toml file, using OrcaNet building blocks.

    The input of the model can be matched to the dimensions of the input
    data given to the Organizer, taking the sample modifier into account.

    Attributes
    ----------
    configs : list
        List with keywords for building each layer block in the model.
    defaults : dict
        Default values for the layer blocks in the model.
    optimizer : str or Optimizer
        Optimizer for training the model. Can be a string like "adam" (or
        "keras:adam" for the default keras variant), or an object derived
        from ks.optimizers.Optimizer.
    compile_opt : dict
        Keys: Names of the output layers of the model.
        Values: Loss function, optionally weight and optionally metrics of
        each output layer.
        Format: { layer_name : { function:, weight:, metrics: } }
        'function' is a loss function or the name of one as a string,
        'weight' is a float, and 'metrics' is a list of functions/strings.
    optimizer_args : dict, optional
        Kwargs for the optimizer. Not used when an optimizer object is given.
    input_opts : dict
        Options for the input layers of the model.

    Methods
    -------
    build
        Build the network using an instance of Organizer.
    build_with_input
        Build the network without an Organizer, just using given input shapes.
    compile_model
        Compile a model with the optimizer settings given in the model file.

    """
    def __init__(self, model_file, **custom_blocks):
        """
        Read out parameters for creating models with OrcaNet from a toml file.

        Parameters
        ----------
        model_file : str
            Path to the model toml file.
        custom_blocks
            For building models with custom blocks in the toml:
            Custom block names as kwargs ('toml name'='block').

        """
        file_content = toml.load(model_file)
        self.custom_blocks = custom_blocks

        try:
            if "model" in file_content:
                model_args = file_content["model"]
                self.configs = model_args.pop("blocks")
                self.input_opts = model_args.pop("input_opts", {})
                self.defaults = model_args

            elif "body" in file_content:
                # legacy format
                self._compat_init(file_content)

            self.optimizer = None
            self.compile_opt = None
            self.optimizer_args = {}
            if "compile" in file_content:
                compile_sect = file_content["compile"]
                self.optimizer = compile_sect.pop("optimizer", None)
                self.compile_opt = compile_sect.pop("losses", None)
                self.optimizer_args = compile_sect

        except KeyError as e:
            if len(e.args) == 1:
                option = e.args[0]
            else:
                option = e.args
            raise KeyError(
                "Missing parameter in toml model file: " + str(option)
            ) from None

    def _compat_init(self, file_content):
        warnings.warn(
            "The format of this model toml file is deprecated, consider "
            "updating it to the new format (see the online documentation)."
        )
        # legacy format: [body] section with blocks and defaults
        body = file_content["body"]
        if "architecture" in body:
            arch = body.pop("architecture")
            if arch != "single":
                raise ValueError("architecture keyword is deprecated")
        self.configs = body.pop("blocks")
        self.defaults = body

        if "head" in file_content:
            head = file_content["head"]
            head_arch = head.pop("architecture")
            head_arch_args = head.pop("architecture_args")
            head_args = head

            head_block_config = head_arch_args
            head_block_config["type"] = head_arch
            self.configs.append({**head_block_config, **head_args})
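
    # Example of the legacy-to-new conversion above (keys are illustrative):
    # a deprecated head section like
    #
    #   [head]
    #   architecture = "categorical"
    #   architecture_args = {categories = 2}
    #
    # gets appended to self.configs as one more block config:
    #   {"type": "categorical", "categories": 2}
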
    def build(self, orga, log_comp_opts=False, verbose=False):
        """
        Build the network using an instance of Organizer.

        Input layers will be adapted to the input files in the organizer.
        Can also add the matching modifiers and custom objects to the orga.

        Parameters
        ----------
        orga : orcanet.core.Organizer
            Contains all the configurable options in the OrcaNet scripts.
        log_comp_opts : bool
            Whether to write the compilation options of the model to
            the log.txt.
        verbose : bool
            Print info about the building process?

        Returns
        -------
        model : ks.Model
            The network.

        """
        if orga.cfg.fixed_batchsize:
            if (
                "batchsize" in self.input_opts
                and self.input_opts["batchsize"] != orga.cfg.batchsize
            ):
                raise ValueError(
                    f"Batchsize in input_opts is {self.input_opts['batchsize']}, "
                    f"but in cfg it's {orga.cfg.batchsize}"
                )
            self.input_opts["batchsize"] = orga.cfg.batchsize

        with orga.get_strategy().scope():
            model = self.build_with_input(
                orga.io.get_input_shapes(),
                compile_model=True,
                custom_objects=orga.cfg.get_custom_objects(),
                verbose=verbose,
            )

        if log_comp_opts:
            self.log_model_properties(orga)
        model.summary()
        return model
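
    # Usage sketch (paths are hypothetical; assumes an Organizer whose list
    # file points to existing training data):
    #
    #   from orcanet.core import Organizer
    #   orga = Organizer("output_folder", list_file="list.toml")
    #   builder = ModelBuilder("model.toml")
    #   model = builder.build(orga, log_comp_opts=True)
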
    def build_with_input(
        self, input_shapes, compile_model=True, custom_objects=None, verbose=False
    ):
        """
        Build the network with given input shapes.

        Parameters
        ----------
        input_shapes : dict
            Keys: Name of the inputs of the model.
            Values: Their shape without the batchsize.
        compile_model : bool
            Compile the model?
        custom_objects : dict, optional
            Custom objects to use during compiling.
        verbose : bool
            Print info about the building process?

        Returns
        -------
        model : ks.Model
            The network.

        """
        builder = BlockBuilder(
            self.defaults,
            verbose=verbose,
            input_opts=self.input_opts,
            **self.custom_blocks,
        )
        model = builder.build(input_shapes, self.configs)

        if compile_model:
            self.compile_model(model, custom_objects=custom_objects)

        return model
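
    # Sketch: building without an Organizer, from named input shapes
    # (the input name and shape below are made up for illustration):
    #
    #   builder = ModelBuilder("model.toml")
    #   model = builder.build_with_input(
    #       {"input_A": (11, 13, 18, 1)}, compile_model=False
    #   )
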
    # def merge_models(self, model_list, trainable=False, stateful=True,
    #                  no_drop=True):
    #     """
    #     Concatenate two or more single input cnns to a big one.
    #
    #     It will explicitly look for a Flatten layer and cut after it,
    #     Concatenate all models, and then add the head layers.
    #
    #     Parameters
    #     ----------
    #     model_list : list
    #         List of keras models to stitch together.
    #     trainable : bool
    #         Whether the layers of the loaded models will be trainable.
    #     stateful : bool
    #         Whether the batchnorms of the loaded models will be stateful.
    #     no_drop : bool
    #         If true, rate of dropout layers from loaded models will
    #         be set to zero.
    #
    #     Returns
    #     -------
    #     model : keras model
    #         The uncompiled merged keras model.
    #
    #     """
    #     # Get the input and Flatten layers in each of the given models
    #     input_layers, flattens = [], []
    #     for i, model in enumerate(model_list):
    #         if len(model.inputs) != 1:
    #             raise ValueError(
    #                 "model input is not length 1 {}".format(model.inputs))
    #         input_layers.append(model.input)
    #         flatten_found = 0
    #         for layer in model.layers:
    #             layer.trainable = trainable
    #             layer.name = layer.name + '_net_' + str(i)
    #             if isinstance(layer, layers.BatchNormalization):
    #                 layer.stateful = stateful
    #             elif isinstance(layer, layers.Flatten):
    #                 flattens.append(layer.output)
    #                 flatten_found += 1
    #         if flatten_found != 1:
    #             raise TypeError(
    #                 "Expected 1 Flatten layer but got " + str(flatten_found))
    #
    #     # attach new head
    #     x = layers.Concatenate()(flattens)
    #     builder = BlockBuilder(body_defaults=None,
    #                            head_defaults=self.head_args)
    #     output_layer = builder.attach_output_layers(x, self.head_arch,
    #                                                 flatten=False,
    #                                                 **self.head_arch_args)
    #
    #     model = ks.models.Model(input_layers, output_layer)
    #     if no_drop:
    #         model = _change_dropout_rate(model, before_concat=0.)
    #
    #     return model

    def compile_model(self, model, custom_objects=None):
        """
        Compile a model with the optimizer and loss settings given in
        the attributes.

        Parameters
        ----------
        model : ks.Model
            A keras model.
        custom_objects : dict, optional
            Maps names (strings) to custom loss functions or metrics.

        Returns
        -------
        model : ks.Model
            The compiled (or recompiled) keras model.

        """
        if any((self.optimizer is None, self.compile_opt is None)):
            raise ValueError("Can not compile, need optimizer name and losses")

        loss_functions, loss_weights, loss_metrics = {}, {}, {}
        for layer_name, layer_info in self.compile_opt.items():
            # Replace the str function name with the actual function if it is custom
            loss_function = layer_info["function"]
            if custom_objects is not None and loss_function in custom_objects:
                loss_function = custom_objects[loss_function]
            loss_functions[layer_name] = loss_function

            # Use given weight, else use default weight of 1
            if "weight" in layer_info:
                weight = layer_info["weight"]
            else:
                weight = 1.0
            loss_weights[layer_name] = weight

            # Use given metrics, else use no metrics
            if "metrics" in layer_info:
                metrics = layer_info["metrics"]
            else:
                metrics = []
            if custom_objects is not None:
                for i, metric in enumerate(metrics):
                    if metric in custom_objects:
                        metrics[i] = custom_objects[metric]

            loss_metrics[layer_name] = metrics

        optimizer = self._get_optimizer()
        model.compile(
            loss=loss_functions,
            optimizer=optimizer,
            metrics=loss_metrics,
            loss_weights=loss_weights,
        )
        return model
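
    # Sketch of a compile_opt dict as consumed above (layer names and
    # losses are hypothetical):
    #
    #   builder.compile_opt = {
    #       "direction": {"function": "mse", "weight": 1.0, "metrics": ["mae"]},
    #       "energy": {"function": "mae"},  # weight -> 1.0, metrics -> []
    #   }
    #   builder.optimizer = "adam"
    #   model = builder.compile_model(model)
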
    def log_model_properties(self, orga):
        """
        Writes the compile_opt config to the full log file.
        """
        lines = list()
        lines.append("-" * 60)
        time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        lines.append("-" * 19 + " {} ".format(time) + "-" * 19)
        lines.append(
            "A model has been built using the model builder with the "
            "following configurations:\n"
        )
        lines.append("Loss functions: ")
        for key in self.compile_opt:
            lines.append(key + ": " + str(self.compile_opt[key]))
        lines.append("\n")
        orga.io.print_log(lines)

    def _get_optimizer(self):
        if not isinstance(self.optimizer, str):
            if self.optimizer_args:
                warnings.warn(
                    "Custom optimizer object used, optimizer_args are ignored: "
                    + str(self.optimizer_args)
                )
            return self.optimizer
        if self.optimizer == "adam":
            optimizer = get_adam(**self.optimizer_args)
        elif self.optimizer == "sgd":
            optimizer = get_sgd(**self.optimizer_args)
        elif self.optimizer.startswith("keras:"):
            optimizer = getattr(ks.optimizers, self.optimizer.split("keras:")[-1])(
                **self.optimizer_args
            )
        else:
            raise NameError("Unknown optimizer name ({})".format(self.optimizer))
        return optimizer
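

# Optimizer resolution in ModelBuilder._get_optimizer (illustrative):
# "adam" and "sgd" use the tuned wrappers below, while e.g.
# optimizer = "keras:Adamax" with optimizer_args = {"learning_rate": 1e-3}
# resolves to ks.optimizers.Adamax(learning_rate=1e-3). The part after
# "keras:" must match a class name in ks.optimizers.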
def get_adam(beta_1=0.9, beta_2=0.999, epsilon=0.1, decay=0.0, **kwargs):
    # epsilon=1 for deep networks
    return ks.optimizers.Adam(
        beta_1=beta_1, beta_2=beta_2, epsilon=epsilon, decay=decay, **kwargs
    )


def get_sgd(momentum=0.9, decay=0, nesterov=True, **kwargs):
    return ks.optimizers.SGD(
        momentum=momentum, decay=decay, nesterov=nesterov, **kwargs
    )


def _change_dropout_rate(model, before_concat, after_concat=None):
    """
    Change the dropout rate in a model.

    # TODO untested for tf 2.x!

    Only for models with a concatenate layer, i.e. multiple
    single input models that were merged together.

    Parameters
    ----------
    model : keras model
        The model to change the dropout rates of.
    before_concat : float
        New dropout rate before the concatenate layer in the model.
    after_concat : float or None
        New dropout rate after the concatenate layer. None will leave the
        dropout rate there as it was.

    """
    ch_bef, ch_aft, concat_found = 0, 0, 0

    for layer in model.layers:
        if isinstance(layer, layers.Dropout):
            if concat_found == 0:
                layer.rate = before_concat
                ch_bef += 1
            else:
                layer.rate = after_concat
                ch_aft += 1

        elif isinstance(layer, layers.Concatenate):
            concat_found += 1
            if after_concat is None:
                break

    if concat_found != 1:
        raise TypeError("Expected 1 Concatenate layer but got " + str(concat_found))
    clone = ks.models.clone_model(model)
    clone.set_weights(model.get_weights())
    print(
        "Changed dropout rates of {} layers before and {} layers after "
        "Concatenate.".format(ch_bef, ch_aft)
    )
    return clone
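

# Usage sketch for the helper above (untested for tf 2.x, see TODO):
# zero out dropout before the Concatenate layer of a merged model and
# leave the rates after it unchanged:
#
#   adjusted = _change_dropout_rate(merged_model, before_concat=0.)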