Coverage for orcanet/builder_util/layer_blocks.py: 95% (269 statements)


import tensorflow as tf
import tensorflow.keras.backend as K
import tensorflow.keras as ks
import tensorflow.keras.layers as layers
import medgeconv

from orcanet.misc import get_register

# for loading via toml and orcanet custom objects
blocks, register = get_register()
# edge conv blocks
register(medgeconv.DisjointEdgeConvBlock)

@register
class ConvBlock:
    """
    1D/2D/3D convolutional block followed by BatchNorm, Activation,
    MaxPooling and/or Dropout.

    Parameters
    ----------
    conv_dim : int
        Specifies the dimension of the convolutional block, 1D/2D/3D.
    filters : int
        Number of filters used for the convolutional layer.
    strides : int or tuple
        The stride length of the convolution.
    padding : str or int or list
        If str: Padding of the conv block.
        If int or list: Padding argument of a ZeroPaddingND layer that
        gets added before the convolution.
    kernel_size : int or tuple
        Kernel size which is used for all dimensions.
    pool_size : None or int or tuple
        Specifies the pool size of the pooling layer, e.g. (1, 1, 2)
        for a 3D conv block. If None, no pooling will be added,
        except when global average pooling is used.
    pool_type : str, optional
        The type of pooling layer to add. Ignored if pool_size is None.
        Can be max_pooling (default), average_pooling, or
        global_average_pooling.
    pool_padding : str
        Padding option of the pooling layer.
    dropout : float or None
        Adds a dropout layer if the value is not None.
        Cannot be used together with sdropout.
        Hint: 0 will add a dropout layer, but with a rate of 0 (= no dropout).
    sdropout : float or None
        Adds a spatial dropout layer if the value is not None.
        Cannot be used together with dropout.
    activation : str or None
        Type of activation function that should be used. E.g. 'linear',
        'relu', 'elu', 'selu'.
    kernel_l2_reg : float, optional
        Regularization factor of the l2 regularizer for the weights.
    batchnorm : bool
        Adds a batch normalization layer.
    kernel_initializer : str
        Initializer for the kernel weights.
    time_distributed : bool
        If True, apply the TimeDistributed wrapper around all layers.
    dilation_rate : int
        An integer or tuple/list of a single integer, specifying the
        dilation rate to use for dilated convolution. Currently,
        specifying any dilation_rate value != 1 is incompatible
        with specifying any strides value != 1.

    """

    def __init__(
        self,
        conv_dim,
        filters,
        kernel_size=3,
        strides=1,
        padding="same",
        pool_type="max_pooling",
        pool_size=None,
        pool_padding="valid",
        dropout=None,
        sdropout=None,
        activation="relu",
        kernel_l2_reg=None,
        batchnorm=False,
        kernel_initializer="he_normal",
        time_distributed=False,
        dilation_rate=1,
    ):
        self.conv_dim = conv_dim
        self.filters = filters
        self.kernel_size = kernel_size
        self.strides = strides
        self.padding = padding
        self.pool_type = pool_type
        self.pool_size = pool_size
        self.pool_padding = pool_padding
        self.dropout = dropout
        self.sdropout = sdropout
        self.activation = activation
        self.kernel_l2_reg = kernel_l2_reg
        self.batchnorm = batchnorm
        self.kernel_initializer = kernel_initializer
        self.time_distributed = time_distributed
        self.dilation_rate = dilation_rate

    def __call__(self, inputs):
        if self.dropout is not None and self.sdropout is not None:
            raise ValueError(
                "Can only use either dropout or spatial dropout, not both"
            )

        dim_layers = _get_dimensional_layers(self.conv_dim)
        convolution_nd = dim_layers["convolution"]
        s_dropout_nd = dim_layers["s_dropout"]

        if self.kernel_l2_reg is not None:
            kernel_reg = ks.regularizers.l2(self.kernel_l2_reg)
        else:
            kernel_reg = None

        if self.batchnorm:
            use_bias = False
        else:
            use_bias = True

        block_layers = list()

        if isinstance(self.padding, str):
            padding = self.padding
        else:
            block_layers.append(dim_layers["zero_padding"](self.padding))
            padding = "valid"

        block_layers.append(
            convolution_nd(
                filters=self.filters,
                kernel_size=self.kernel_size,
                strides=self.strides,
                padding=padding,
                kernel_initializer=self.kernel_initializer,
                use_bias=use_bias,
                kernel_regularizer=kernel_reg,
                dilation_rate=self.dilation_rate,
            )
        )
        if self.batchnorm:
            channel_axis = (
                1 if ks.backend.image_data_format() == "channels_first" else -1
            )
            block_layers.append(layers.BatchNormalization(axis=channel_axis))
        if self.activation is not None:
            block_layers.append(layers.Activation(self.activation))

        if self.pool_type == "global_average_pooling":
            pooling_nd = dim_layers[self.pool_type]
            block_layers.append(pooling_nd())
        elif self.pool_size is not None:
            pooling_nd = dim_layers[self.pool_type]
            block_layers.append(
                pooling_nd(pool_size=self.pool_size, padding=self.pool_padding)
            )

        if self.dropout is not None:
            block_layers.append(layers.Dropout(self.dropout))
        elif self.sdropout is not None:
            block_layers.append(s_dropout_nd(self.sdropout))

        x = inputs
        for block_layer in block_layers:
            if self.time_distributed:
                x = layers.TimeDistributed(block_layer)(x)
            else:
                x = block_layer(x)
        return x
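# Usage sketch (added for illustration, not part of the original module):
# a single 3D ConvBlock applied to a dummy input tensor. The helper name,
# input shape and hyperparameters below are assumptions chosen only to show
# the call pattern of ConvBlock.
def _example_conv_block():
    inp = ks.Input(shape=(11, 13, 18, 1))
    x = ConvBlock(
        conv_dim=3,
        filters=32,
        pool_size=(2, 2, 2),
        dropout=0.1,
        batchnorm=True,
    )(inp)
    return ks.Model(inp, x)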

@register
class DenseBlock:
    """
    Dense layer followed by BatchNorm, Activation and/or Dropout.

    Parameters
    ----------
    units : int
        Number of neurons of the dense layer.
    dropout : float or None
        Adds a dropout layer if the value is not None.
    activation : str or None
        Type of activation function that should be used. E.g. 'linear',
        'relu', 'elu', 'selu'.
    kernel_l2_reg : float, optional
        Regularization factor of the l2 regularizer for the weights.
    batchnorm : bool
        Adds a batch normalization layer.
    kernel_initializer : str
        Initializer for the kernel weights.

    """

    def __init__(
        self,
        units,
        dropout=None,
        activation="relu",
        kernel_l2_reg=None,
        batchnorm=False,
        kernel_initializer="he_normal",
    ):
        self.units = units
        self.dropout = dropout
        self.activation = activation
        self.kernel_l2_reg = kernel_l2_reg
        self.batchnorm = batchnorm
        self.kernel_initializer = kernel_initializer

    def __call__(self, inputs):
        if self.kernel_l2_reg is not None:
            kernel_reg = ks.regularizers.l2(self.kernel_l2_reg)
        else:
            kernel_reg = None

        if self.batchnorm:
            use_bias = False
        else:
            use_bias = True

        x = layers.Dense(
            units=self.units,
            use_bias=use_bias,
            kernel_initializer=self.kernel_initializer,
            kernel_regularizer=kernel_reg,
        )(inputs)

        if self.batchnorm:
            channel_axis = (
                1 if ks.backend.image_data_format() == "channels_first" else -1
            )
            x = layers.BatchNormalization(axis=channel_axis)(x)
        if self.activation is not None:
            x = layers.Activation(self.activation)(x)
        if self.dropout is not None:
            x = layers.Dropout(self.dropout)(x)
        return x
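# Usage sketch (illustration only, not part of the original module): a
# DenseBlock on a flat feature vector. The input shape and parameter values
# are assumptions chosen only to show the call pattern.
def _example_dense_block():
    inp = ks.Input(shape=(128,))
    x = DenseBlock(units=64, dropout=0.2, batchnorm=True)(inp)
    return ks.Model(inp, x)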

@register
class ResnetBlock:
    """
    A residual building block for ResNets: two convolutional layers
    with a shortcut connection.
    https://arxiv.org/pdf/1605.07146.pdf

    Parameters
    ----------
    conv_dim : int
        Specifies the dimension of the convolutional block, 2D/3D.
    filters : int
        Number of filters used for the convolutional layers.
    strides : int or tuple
        The stride length of the convolution. If strides is 1, this is
        the identity block. If not, it has a conv block
        at the shortcut.
    kernel_size : int or tuple
        Kernel size which is used for all dimensions.
    activation : str or None
        Type of activation function that should be used. E.g. 'linear',
        'relu', 'elu', 'selu'.
    batchnorm : bool
        Adds a batch normalization layer.
    kernel_initializer : str
        Initializer for the kernel weights.
    time_distributed : bool
        If True, apply the TimeDistributed wrapper around all layers.

    """

    def __init__(
        self,
        conv_dim,
        filters,
        strides=1,
        kernel_size=3,
        activation="relu",
        batchnorm=False,
        kernel_initializer="he_normal",
        time_distributed=False,
    ):
        self.conv_dim = conv_dim
        self.filters = filters
        self.kernel_size = kernel_size
        self.strides = strides
        self.activation = activation
        self.batchnorm = batchnorm
        self.kernel_initializer = kernel_initializer
        self.time_distributed = time_distributed

    def __call__(self, inputs):
        x = ConvBlock(
            conv_dim=self.conv_dim,
            filters=self.filters,
            kernel_size=self.kernel_size,
            strides=self.strides,
            kernel_initializer=self.kernel_initializer,
            batchnorm=self.batchnorm,
            activation=self.activation,
            time_distributed=self.time_distributed,
        )(inputs)
        x = ConvBlock(
            conv_dim=self.conv_dim,
            filters=self.filters,
            kernel_size=self.kernel_size,
            kernel_initializer=self.kernel_initializer,
            batchnorm=self.batchnorm,
            activation=None,
            time_distributed=self.time_distributed,
        )(x)

        if self.strides != 1:
            shortcut = ConvBlock(
                conv_dim=self.conv_dim,
                filters=self.filters,
                kernel_size=1,
                strides=self.strides,
                kernel_initializer=self.kernel_initializer,
                activation=None,
                batchnorm=self.batchnorm,
                time_distributed=self.time_distributed,
            )(inputs)
        else:
            shortcut = inputs

        x = layers.add([x, shortcut])
        acti_layer = layers.Activation(self.activation)
        if self.time_distributed:
            return layers.TimeDistributed(acti_layer)(x)
        else:
            return acti_layer(x)
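# Usage sketch (illustration only, not part of the original module): stacking
# ResnetBlocks. With strides=1 and unchanged filters the shortcut is the
# identity; with strides != 1 a 1x1 conv shortcut is added automatically.
# Input shape and filter counts are assumptions.
def _example_resnet_stack():
    inp = ks.Input(shape=(16, 16, 1))
    x = ConvBlock(conv_dim=2, filters=16)(inp)
    x = ResnetBlock(conv_dim=2, filters=16, batchnorm=True)(x)
    x = ResnetBlock(conv_dim=2, filters=32, strides=2, batchnorm=True)(x)
    return ks.Model(inp, x)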

@register
class ResnetBnetBlock:
    """
    A residual bottleneck building block for ResNets.
    https://arxiv.org/pdf/1605.07146.pdf

    Parameters
    ----------
    conv_dim : int
        Specifies the dimension of the convolutional block, 2D/3D.
    filters : List
        Number of filters used for the convolutional layers.
        Has to be of length 3. The first and third entries are for
        the 1x1 convolutions.
    strides : int or tuple
        The stride length of the convolution. If strides is 1, this is
        the identity block. If not, it has a conv block
        at the shortcut.
    kernel_size : int or tuple
        Kernel size which is used for all dimensions.
    activation : str or None
        Type of activation function that should be used. E.g. 'linear',
        'relu', 'elu', 'selu'.
    batchnorm : bool
        Adds a batch normalization layer.
    kernel_initializer : str
        Initializer for the kernel weights.

    """

    def __init__(
        self,
        conv_dim,
        filters,
        strides=1,
        kernel_size=3,
        activation="relu",
        batchnorm=False,
        kernel_initializer="he_normal",
    ):
        self.conv_dim = conv_dim
        self.filters = filters
        self.kernel_size = kernel_size
        self.strides = strides
        self.activation = activation
        self.batchnorm = batchnorm
        self.kernel_initializer = kernel_initializer

    def __call__(self, inputs):
        filters1, filters2, filters3 = self.filters

        x = ConvBlock(
            conv_dim=self.conv_dim,
            filters=filters1,
            kernel_size=1,
            strides=self.strides,
            kernel_initializer=self.kernel_initializer,
            batchnorm=self.batchnorm,
            activation=self.activation,
        )(inputs)
        x = ConvBlock(
            conv_dim=self.conv_dim,
            filters=filters2,
            kernel_size=self.kernel_size,
            kernel_initializer=self.kernel_initializer,
            batchnorm=self.batchnorm,
            activation=self.activation,
        )(x)
        x = ConvBlock(
            conv_dim=self.conv_dim,
            filters=filters3,
            kernel_size=1,
            kernel_initializer=self.kernel_initializer,
            batchnorm=self.batchnorm,
            activation=None,
        )(x)

        if self.strides != 1:
            shortcut = ConvBlock(
                conv_dim=self.conv_dim,
                filters=filters3,
                kernel_size=1,
                strides=self.strides,
                kernel_initializer=self.kernel_initializer,
                activation=None,
                batchnorm=self.batchnorm,
            )(inputs)
        else:
            shortcut = inputs

        x = layers.add([x, shortcut])
        x = layers.Activation(self.activation)(x)
        return x
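# Usage sketch (illustration only, not part of the original module): a
# bottleneck block with strides=1, so the number of input channels has to
# match the third filter entry for the identity shortcut to work.
# Input shape and filters are assumptions.
def _example_bottleneck_block():
    inp = ks.Input(shape=(16, 16, 64))
    x = ResnetBnetBlock(conv_dim=2, filters=[16, 16, 64], batchnorm=True)(inp)
    return ks.Model(inp, x)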

@register
class InceptionBlockV2:
    """
    A GoogLeNet Inception block (v2).
    https://arxiv.org/pdf/1512.00567v3.pdf, see fig. 5.
    Keras implementation, e.g.:
    https://github.com/keras-team/keras-applications/blob/master/keras_applications/inception_resnet_v2.py

    Parameters
    ----------
    conv_dim : int
        Specifies the dimension of the convolutional block, 1D/2D/3D.
    filters_1x1 : int or None
        No. of filters for the 1x1 convolutional branch.
        If None, don't make this branch.
    filters_pool : int or None
        No. of filters for the pooling branch.
        If None, don't make this branch.
    filters_3x3 : tuple or None
        No. of filters for the 3x3 convolutional branch. The first int
        is for the 1x1 conv, the second int for the 3x3 conv.
        The first should be chosen smaller for computational efficiency.
        If None, don't make this branch.
    filters_3x3dbl : tuple or None
        No. of filters for the double 3x3 convolutional branch. The first
        int is for the 1x1 conv, the second int for the two 3x3 convs.
        The first should be chosen smaller for computational efficiency.
        If None, don't make this branch.
    strides : int or tuple
        Stride length of this block.
        As in the Keras implementation, 1x1 convolutions with
        stride > 1 are not used; those branches are skipped instead.

    """

    def __init__(
        self,
        conv_dim,
        filters_1x1,
        filters_pool,
        filters_3x3,
        filters_3x3dbl,
        strides=1,
        activation="relu",
        batchnorm=False,
        dropout=None,
    ):
        self.filters_1x1 = filters_1x1  # 64
        self.filters_pool = filters_pool  # 64
        self.filters_3x3 = filters_3x3  # 48, 64
        self.filters_3x3dbl = filters_3x3dbl  # 64, 96
        self.strides = strides
        self.conv_options = {
            "conv_dim": conv_dim,
            "dropout": dropout,
            "batchnorm": batchnorm,
            "activation": activation,
        }

    def __call__(self, inputs):
        branches = []
        # 1x1 convolution
        if self.filters_1x1 and self.strides == 1:
            branch1x1 = ConvBlock(
                filters=self.filters_1x1,
                kernel_size=1,
                strides=self.strides,
                **self.conv_options,
            )(inputs)
            branches.append(branch1x1)

        # pooling
        if self.filters_pool:
            max_pooling_nd = _get_dimensional_layers(self.conv_options["conv_dim"])[
                "max_pooling"
            ]
            branch_pool = max_pooling_nd(
                pool_size=3, strides=self.strides, padding="same"
            )(inputs)
            if self.strides == 1:
                branch_pool = ConvBlock(
                    filters=self.filters_pool, kernel_size=1, **self.conv_options
                )(branch_pool)
            branches.append(branch_pool)

        # 3x3 convolution
        if self.filters_3x3:
            branch3x3 = ConvBlock(
                filters=self.filters_3x3[0], kernel_size=1, **self.conv_options
            )(inputs)
            branch3x3 = ConvBlock(
                filters=self.filters_3x3[1],
                kernel_size=3,
                strides=self.strides,
                **self.conv_options,
            )(branch3x3)
            branches.append(branch3x3)

        # double 3x3 convolution: 1x1 conv followed by two 3x3 convs
        if self.filters_3x3dbl:
            branch3x3dbl = ConvBlock(
                filters=self.filters_3x3dbl[0], kernel_size=1, **self.conv_options
            )(inputs)
            branch3x3dbl = ConvBlock(
                filters=self.filters_3x3dbl[1], kernel_size=3, **self.conv_options
            )(branch3x3dbl)
            branch3x3dbl = ConvBlock(
                filters=self.filters_3x3dbl[1],
                kernel_size=3,
                strides=self.strides,
                **self.conv_options,
            )(branch3x3dbl)
            branches.append(branch3x3dbl)

        # concatenate all branches
        channel_axis = 1 if ks.backend.image_data_format() == "channels_first" else -1
        x = layers.concatenate(branches, axis=channel_axis)
        return x
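# Usage sketch (illustration only, not part of the original module): an
# inception block with all four branches; the resulting channel count is the
# sum of the branch filters (64 + 64 + 64 + 96 here). Filter numbers and
# input shape are assumptions.
def _example_inception_block():
    inp = ks.Input(shape=(16, 16, 3))
    x = InceptionBlockV2(
        conv_dim=2,
        filters_1x1=64,
        filters_pool=64,
        filters_3x3=(48, 64),
        filters_3x3dbl=(64, 96),
        batchnorm=True,
    )(inp)
    return ks.Model(inp, x)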

@register
class OutputReg:
    """
    Dense layer(s) for regression.

    Parameters
    ----------
    output_neurons : int
        Number of neurons in the last layer.
    output_name : str or None
        Name that will be given to the output layer of the network.
    unit_list : List, optional
        A list of ints. Adds additional Dense layers after the transition
        layer with this many units in them. E.g., [64, 32] would add
        two Dense layers, the first with 64 neurons, the second with
        32 neurons.
    transition : str or None
        Name of a layer that will be used as the first layer of this block.
        Example: 'keras:GlobalAveragePooling2D', 'keras:Flatten'
    kwargs
        Keywords for the dense blocks that get added if unit_list is
        not None.

    """

    def __init__(
        self,
        output_neurons,
        output_name,
        unit_list=None,
        transition="keras:Flatten",
        **kwargs,
    ):
        self.output_neurons = output_neurons
        self.output_name = output_name
        if isinstance(unit_list, int):
            unit_list = (unit_list,)
        self.unit_list = unit_list
        self.transition = transition
        self.kwargs = kwargs

    def __call__(self, layer):
        if self.transition:
            x = getattr(layers, self.transition.split("keras:")[-1])()(layer)
        else:
            x = layer

        if self.unit_list is not None:
            for units in self.unit_list:
                x = DenseBlock(units=units, **self.kwargs)(x)

        out = layers.Dense(
            units=self.output_neurons, activation=None, name=self.output_name
        )(x)

        return out
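# Usage sketch (illustration only, not part of the original module): a
# regression head on top of a feature map, using the default transition
# 'keras:Flatten'. Output name, units and input shape are assumptions.
def _example_regression_head():
    inp = ks.Input(shape=(4, 4, 8))
    out = OutputReg(output_neurons=2, output_name="dir", unit_list=[64, 32])(inp)
    return ks.Model(inp, out)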

@register
class OutputRegNormal:
    """
    Output block for regression using a normal distribution as output.

    The output tensor will have shape (?, 2, output_neurons),
    with [:, 0] being the mu and [:, 1] being the sigma.

    Parameters
    ----------
    mu_activation : str, optional
        Activation function for the mu neurons.
    sigma_activation : str, optional
        Activation function for the sigma neurons.

    See OutputReg for other parameters.

    """

    def __init__(
        self,
        output_neurons,
        output_name,
        unit_list=None,
        mu_activation=None,
        sigma_activation="softplus",
        transition=None,
        **kwargs,
    ):
        self.output_neurons = output_neurons
        self.output_name = output_name
        if isinstance(unit_list, int):
            unit_list = (unit_list,)
        self.unit_list = unit_list
        self.mu_activation = mu_activation
        self.sigma_activation = sigma_activation
        self.transition = transition
        self.kwargs = kwargs

    def __call__(self, layer):
        if self.transition:
            x = getattr(layers, self.transition.split("keras:")[-1])()(layer)
        else:
            x = layer

        if self.unit_list is not None:
            for units in self.unit_list:
                x = DenseBlock(units=units, **self.kwargs)(x)

        mu = layers.Dense(
            units=self.output_neurons,
            activation=self.mu_activation,
            name=f"{self.output_name}_mu",
        )(x)
        sigma = layers.Dense(
            units=self.output_neurons,
            activation=self.sigma_activation,
            name=f"{self.output_name}_sigma",
        )(x)

        return layers.Concatenate(name=self.output_name, axis=-2)(
            [tf.expand_dims(tsr, -2) for tsr in [mu, sigma]]
        )
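# Usage sketch (illustration only, not part of the original module): the
# resulting output tensor has shape (batch, 2, output_neurons), with mu at
# index 0 and sigma at index 1. Output name and sizes are assumptions.
def _example_normal_head():
    inp = ks.Input(shape=(128,))
    out = OutputRegNormal(output_neurons=3, output_name="pos", unit_list=[64])(inp)
    return ks.Model(inp, out)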

@register
class OutputRegNormalSplit(OutputRegNormal):
    """
    Output block for regression using a normal distribution as output.

    The sigma will be produced by its own tower of dense layers that
    is separated from the rest of the network via a gradient stop.

    The output is a list with two tensors:
    - The first is the mu with shape (?, output_neurons) and name output_name.
    - The second is mu + sigma with shape (?, 2, output_neurons),
      with [:, 0] being the mu and [:, 1] being the sigma.
      Its name is output_name + '_err'.

    Parameters
    ----------
    sigma_unit_list : List, optional
        A list of ints. Neurons in the Dense layers for the tower that
        outputs the sigma. E.g., [64, 32] would add
        two Dense layers, the first with 64 neurons, the second with
        32 neurons.
        Default: Same as unit_list.

    See OutputRegNormal for other parameters.

    """

    def __init__(self, *args, sigma_unit_list=None, **kwargs):
        super().__init__(*args, **kwargs)
        if sigma_unit_list is None:
            sigma_unit_list = self.unit_list
        self.sigma_unit_list = sigma_unit_list

    def __call__(self, layer):
        if self.transition:
            x_base = getattr(layers, self.transition.split("keras:")[-1])()(layer)
        else:
            x_base = layer

        x = x_base
        if self.unit_list is not None:
            for units in self.unit_list:
                x = DenseBlock(units=units, **self.kwargs)(x)
        mu = layers.Dense(
            units=self.output_neurons,
            activation=self.mu_activation,
            name=self.output_name,
        )(x)

        # Network for the errors of the labels
        x = ks.backend.stop_gradient(x_base)
        if self.sigma_unit_list is not None:
            for units in self.sigma_unit_list:
                x = DenseBlock(units=units, **self.kwargs)(x)
        sigma = layers.Dense(
            units=self.output_neurons,
            activation=self.sigma_activation,
            name=f"{self.output_name}_sigma",
        )(x)

        mu_stopped = ks.backend.stop_gradient(mu)
        err_output = layers.Concatenate(name=f"{self.output_name}_err", axis=-2)(
            [tf.expand_dims(tsr, -2) for tsr in [mu_stopped, sigma]]
        )

        return [mu, err_output]
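# Usage sketch (illustration only, not part of the original module): the
# block returns two tensors, the mu prediction and the stacked (mu, sigma)
# tensor used for the error loss. Output name and sizes are assumptions.
def _example_split_normal_head():
    inp = ks.Input(shape=(128,))
    mu, mu_sigma = OutputRegNormalSplit(
        output_neurons=1,
        output_name="energy",
        unit_list=[32],
        sigma_unit_list=[32],
    )(inp)
    return ks.Model(inp, [mu, mu_sigma])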

@register
class OutputCateg:
    """
    Dense layer(s) for categorization.

    Parameters
    ----------
    categories : int
        Number of categories (= neurons in the last layer).
    output_name : str
        Name that will be given to the output layer of the network.
    unit_list : List, optional
        A list of ints. Adds additional Dense layers after the transition
        layer with this many units in them. E.g., [64, 32] would add
        two Dense layers, the first with 64 neurons, the second with
        32 neurons.
    transition : str or None
        Name of a layer that will be used as the first layer of this block.
        Example: 'keras:GlobalAveragePooling2D', 'keras:Flatten'
    kwargs
        Keywords for the dense blocks that get added if unit_list is
        not None.

    """

    def __init__(
        self,
        categories,
        output_name,
        unit_list=None,
        transition="keras:Flatten",
        **kwargs,
    ):
        self.categories = categories
        self.output_name = output_name
        self.unit_list = unit_list
        self.transition = transition
        self.kwargs = kwargs

    def __call__(self, layer):
        if self.transition:
            x = getattr(layers, self.transition.split("keras:")[-1])()(layer)
        else:
            x = layer

        if self.unit_list is not None:
            for units in self.unit_list:
                x = DenseBlock(units=units, **self.kwargs)(x)

        out = layers.Dense(
            units=self.categories,
            activation="softmax",
            kernel_initializer="he_normal",
            name=self.output_name,
        )(x)

        return out
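# Usage sketch (illustration only, not part of the original module): a
# softmax classification head with two categories. Output name, units and
# input shape are assumptions.
def _example_classification_head():
    inp = ks.Input(shape=(4, 4, 8))
    out = OutputCateg(categories=2, output_name="ts_output", unit_list=[32])(inp)
    return ks.Model(inp, out)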

@register
class OutputRegErr:
    """
    Double network for regression + error estimation.

    It consists of two dense layer blocks, followed by one dense layer
    for each output_name, plus a separate tower of three dense layer
    blocks, followed by one dense layer for the respective error of
    each output_name.

    Parameters
    ----------
    output_names : List
        List of strs, the output names, each with one neuron + one err neuron.
    flatten : bool
        If True, start with a flatten layer.
    kwargs
        Keywords for the dense blocks.

    """

    # TODO deprecated, only here for historical reasons
    def __init__(self, output_names, flatten=True, **kwargs):
        self.flatten = flatten
        self.output_names = output_names
        self.kwargs = kwargs

    def __call__(self, layer):
        if self.flatten:
            flatten = layers.Flatten()(layer)
        else:
            flatten = layer
        outputs = []

        x = DenseBlock(units=128, **self.kwargs)(flatten)
        x = DenseBlock(units=32, **self.kwargs)(x)

        for name in self.output_names:
            output_label = layers.Dense(units=1, name=name)(x)
            outputs.append(output_label)

        # Network for the errors of the labels
        x_err = layers.Lambda(lambda a: K.stop_gradient(a))(flatten)

        x_err = DenseBlock(units=128, **self.kwargs)(x_err)
        x_err = DenseBlock(units=64, **self.kwargs)(x_err)
        x_err = DenseBlock(units=32, **self.kwargs)(x_err)

        for i, name in enumerate(self.output_names):
            output_label_error = layers.Dense(
                units=1, activation="linear", name=name + "_err_temp"
            )(x_err)
            # Predicted label gets concatenated with its error (needed for the loss function)
            output_label_merged = layers.Concatenate(name=name + "_err")(
                [outputs[i], output_label_error]
            )
            outputs.append(output_label_merged)
        return outputs

def _get_dimensional_layers(dim):
    if dim not in (1, 2, 3):
        raise ValueError(f"Dimension must be 1, 2 or 3, not {dim}")
    dim_layers = {
        "convolution": {
            1: layers.Convolution1D,
            2: layers.Convolution2D,
            3: layers.Convolution3D,
        },
        "max_pooling": {
            1: layers.MaxPooling1D,
            2: layers.MaxPooling2D,
            3: layers.MaxPooling3D,
        },
        "average_pooling": {
            1: layers.AveragePooling1D,
            2: layers.AveragePooling2D,
            3: layers.AveragePooling3D,
        },
        "global_average_pooling": {
            1: layers.GlobalAveragePooling1D,
            2: layers.GlobalAveragePooling2D,
            3: layers.GlobalAveragePooling3D,
        },
        "s_dropout": {
            1: layers.SpatialDropout1D,
            2: layers.SpatialDropout2D,
            3: layers.SpatialDropout3D,
        },
        "zero_padding": {
            1: layers.ZeroPadding1D,
            2: layers.ZeroPadding2D,
            3: layers.ZeroPadding3D,
        },
    }
    return {layer_type: dim_layers[layer_type][dim] for layer_type in dim_layers.keys()}