Coverage for orcanet/builder_util/builders.py: 77%
97 statements
« prev ^ index » next coverage.py v7.2.7, created at 2024-03-28 14:22 +0000
« prev ^ index » next coverage.py v7.2.7, created at 2024-03-28 14:22 +0000
1import inspect
2import warnings
3import tensorflow.keras as ks
4import tensorflow.keras.layers as layers
6import orcanet.builder_util.layer_blocks as layer_blocks
class BlockBuilder:
    """
    Builds single-input block-wise sequential neural network.

    Parameters
    ----------
    defaults : dict or None
        Default values for all blocks in the model.
    verbose : bool
        Print info about the building process?
    input_opts : dict, optional
        Keyword arguments that are passed on to get_input_block when the
        input layer is built (e.g. ``batchsize`` for a fixed batchsize,
        or ``names``).
    **kwargs
        Custom blocks (key = toml name, value = block). They are added
        to the known blocks and override built-in ones of the same name.

    """

    # Block types starting with this prefix are looked up in ks.layers.
    _KERAS_PREFIX = "keras:"

    def __init__(self, defaults=None, verbose=False, input_opts=None, **kwargs):
        """
        Set dict with default values for the layers of the model.
        Can also define custom block names as kwargs (key = toml name,
        value = block).
        """
        # dict with toml keyword vs block for all custom blocks
        self.all_blocks = {
            **layer_blocks.blocks,
            # legacy names:
            "conv_block": layer_blocks.ConvBlock,
            "dense_block": layer_blocks.DenseBlock,
            "resnet_block": layer_blocks.ResnetBlock,
            "resnet_bneck_block": layer_blocks.ResnetBnetBlock,
            "categorical": _attach_output_cat,
            "gpool": _attach_output_gpool_categ,
            "gpool_categ": _attach_output_gpool_categ,
            "gpool_reg": layer_blocks.OutputReg,
            "regression_error": layer_blocks.OutputRegErr,
            # user-defined blocks come last so they take precedence
            **kwargs,
        }

        # needs self.all_blocks, so it has to run after the dict is set up
        self._check_arguments(defaults)
        self.defaults = defaults
        self.verbose = verbose
        self.input_opts = {} if input_opts is None else input_opts

    def build(self, input_shape, configs):
        """
        Build the whole model, using the default values when arguments
        are missing in the layer_configs.

        Parameters
        ----------
        input_shape : dict
            Name and shape of the input layer.
        configs : list
            List of configurations for the blocks in the model.
            Each element in the list is a dict and will result in one block
            connected to the previous one. The dict has to contain the type
            of the block, as well as any arguments required by that
            specific block type.

        Returns
        -------
        model : keras model

        """
        input_layer = get_input_block(input_shape, **self.input_opts)

        # chain the blocks one after another, starting at the input
        x = input_layer
        for layer_config in configs:
            x = self.attach_block(x, layer_config)

        return ks.models.Model(inputs=input_layer, outputs=x)

    def attach_block(self, layer, layer_config):
        """
        Attach a block to the given layer based on the layer config.

        Will use the default values given during initialization if they are not
        present in the layer config.

        Parameters
        ----------
        layer : keras layer
            Layer to attach the block to.
        layer_config : dict
            Configuration of the block to attach. The dict has to contain
            the type of the block, as well as any arguments required by that
            specific block.

        Returns
        -------
        keras layer

        """
        # _with_defaults returns a copy, so the pop below does not alter
        # the layer_config of the caller
        filled = self._with_defaults(layer_config, self.defaults)
        if self.verbose:
            print(f"Attaching layer {filled} to tensor {layer}")
        block = self._get_blocks(filled.pop("type"))
        return block(**filled)(layer)

    def _with_defaults(self, config, defaults):
        """Make a copy of a layer config and complete it with default values
        for its block, if they are missing in the layer config.

        Parameters
        ----------
        config : dict or None
            The layer config. None is treated like an empty config.
        defaults : dict or None
            The default values. Only entries that are arguments of the
            block are used.

        Returns
        -------
        dict
            The completed copy of the config. Always contains "type".

        Raises
        ------
        KeyError
            If neither config nor defaults specify the block type.
        NameError
            If the block type is unknown (see _get_blocks).

        """
        # copy, so that the given config is left untouched
        conf = {} if config is None else dict(config)

        if "type" in conf:
            block_name = conf["type"]
        elif defaults is not None and "type" in defaults:
            block_name = defaults["type"]
            conf["type"] = block_name
        else:
            raise KeyError("No layer block type specified")

        block = self._get_blocks(block_name)
        accepted = self._block_arg_names(block)

        if defaults is not None:
            for key, val in defaults.items():
                # values in the layer config win over the defaults
                if key in accepted and key not in conf:
                    conf[key] = val

        return conf

    def _get_blocks(self, name=None):
        """Get the block class/function depending on the name.

        Returns the dict of all known blocks if name is None, the
        attribute of ks.layers for names starting with 'keras:', and the
        registered block otherwise. Raises NameError for unknown names.
        """
        if name is None:
            return self.all_blocks
        if name.startswith(self._KERAS_PREFIX):
            return getattr(ks.layers, name[len(self._KERAS_PREFIX):])
        if name in self.all_blocks:
            return self.all_blocks[name]
        raise NameError(
            f"Unknown block type: {name}, must either start with "
            f"'keras:', or be one of {list(self.all_blocks.keys())}"
        )

    @staticmethod
    def _block_arg_names(block):
        """Get the set of argument names of the block's __init__.

        'self' is left out: it is bound automatically and must never be
        filled in from the defaults.
        """
        params = inspect.signature(block.__init__).parameters
        return {name for name in params if name != "self"}

    def _check_arguments(self, defaults):
        """Check if given defaults appear in at least one block.

        Emits a warning for every key in defaults that is neither 'type'
        nor an argument of any known block.
        """
        if defaults is None:
            return
        # possible arguments for all blocks
        psb_args = {"type"}
        for block in self._get_blocks().values():
            psb_args.update(self._block_arg_names(block))
        # the catch-all of a block is not a real argument
        psb_args.discard("kwargs")

        for t_def in defaults:
            if t_def not in psb_args:
                warnings.warn(
                    f"Unknown default argument: {t_def} (has to appear in a block)"
                )
def get_input_block(input_shapes, batchsize=None, names=None):
    """
    Build input layers according to a dict mapping the layer names to shapes.
    If None appears in a shape, that input is ragged.

    Parameters
    ----------
    input_shapes : dict
        Keys: Input layer names.
        Values: Their shapes.
    batchsize : int, optional
        Specify fixed batchsize.
    names : tuple, optional
        Make sure the inputs are these names and return them in this order.

    Returns
    -------
    inputs : keras input or tuple
        A tuple of named keras input layers (as returned by layers.Input),
        or the single input itself if there is only one.

    Raises
    ------
    ValueError
        If names is given and does not consist of exactly the keys of
        input_shapes, each appearing once.

    """
    if names is None:
        input_names = list(input_shapes.keys())
    else:
        input_names = list(names)
        if set(input_names) != set(input_shapes.keys()):
            raise ValueError(
                f"Invalid input names: Expected {names} "
                f"got {list(input_shapes.keys())}"
            )
        # The sets are equal here, so differing lengths can only be caused
        # by a repeated name, which would create the same input twice.
        if len(input_names) != len(input_shapes):
            raise ValueError(
                f"Invalid input names: {names} contains duplicates"
            )

    inputs = tuple(
        layers.Input(
            shape=input_shapes[input_name],
            name=input_name,
            dtype=ks.backend.floatx(),
            batch_size=batchsize,
            ragged=None in input_shapes[input_name],
        )
        for input_name in input_names
    )

    # a single input is handed back directly instead of a 1-tuple
    return inputs[0] if len(inputs) == 1 else inputs
class _attach_output_cat:
    """Legacy block: categorical output built with layer_blocks.OutputCateg.

    Parameters
    ----------
    categories
        Passed on to OutputCateg as ``categories``.
    output_name
        Passed on to OutputCateg as ``output_name``.
    flatten : bool
        If True, OutputCateg is given "keras:Flatten" as its transition,
        otherwise None.

    """

    def __init__(self, categories, output_name, flatten=True):
        self.flatten = flatten
        self.output_name = output_name
        self.categories = categories

    def __call__(self, layer):
        """Attach the output block to the given layer and return the result."""
        output_block = layer_blocks.OutputCateg(
            categories=self.categories,
            output_name=self.output_name,
            transition="keras:Flatten" if self.flatten else None,
            # fixed for this legacy block
            unit_list=(128, 32),
        )
        return output_block(layer)
class _attach_output_gpool_categ:
    """Legacy block: global average pooling, then a categorical output.

    Parameters
    ----------
    categories
        Passed on to layer_blocks.OutputCateg as ``categories``.
    output_name
        Passed on to layer_blocks.OutputCateg as ``output_name``.
    dropout : optional
        If not None, a Dropout layer with this value is inserted after
        the pooling.

    """

    def __init__(self, categories, output_name, dropout=None):
        self.dropout = dropout
        self.output_name = output_name
        self.categories = categories

    def __call__(self, layer):
        """Attach pooling, optional dropout and the output block to layer."""
        pooled = layers.GlobalAveragePooling2D()(layer)
        if self.dropout is not None:
            pooled = layers.Dropout(self.dropout)(pooled)
        output_block = layer_blocks.OutputCateg(
            categories=self.categories,
            output_name=self.output_name,
            transition=None,
        )
        return output_block(pooled)