Coverage for source/model/model_building_blocks/se_block.py: 36% (14 statements)
# model/model_building_blocks/se_block.py

# global imports
import tensorflow as tf
from tensorflow.keras.layers import Dense, GlobalAveragePooling2D, Multiply, Reshape

# local imports

class SEBlock:
    """
    Class implementing a Squeeze-and-Excitation (SE) block compatible with the TensorFlow API.
    The block squeezes the input with global average pooling, computes per-channel weights with a
    two-layer excitation network, and rescales the input channels, as described in the SE-Net
    architecture.

    Diagram:

    ::

        Input Tensor last dimension length - ITldl
        Reduction ratio - Rr

        Input Tensor --> +---------------+   +---------------+   +--------------------+   +--------------+
                     |   | GlobalAvgPool |-->| Reshape       |   | Dense              |   | Dense        |   +----------+
                     |   +---------------+   | Shape: ITldl  |-->| Nodes: ITldl // Rr |-->| Nodes: ITldl |-->| Multiply |
                     |                       |               |   |                    |   |              |   |          |
                     |                       +---------------+   +--------------------+   +--------------+   |          |
                     |                                                                                       |          |
                     +--------------------------------------------------------------------------------------->|          |
                                                                                                              +----------+ --> Output Tensor
    """

    def __init__(self, reduction_ratio: int = 16) -> None:
        """
        Class constructor.

        Parameters:
            reduction_ratio (int): Reduction ratio controlling the width of the bottleneck in the excitation step.
        """

        self.__reduction_ratio: int = reduction_ratio

    def __call__(self, input_tensor: tf.Tensor) -> tf.Tensor:
        """
        Applies the squeeze-and-excitation operation to the input tensor.

        Parameters:
            input_tensor (tf.Tensor): Input tensor to which the SE operation should be applied.

        Returns:
            tf.Tensor: Output tensor after the SE operation has been applied.
        """

        filters = input_tensor.shape[-1]
        x_shape = (1, 1, filters)

        # Squeeze: pool the spatial dimensions into a per-channel descriptor
        x = GlobalAveragePooling2D()(input_tensor)
        x = Reshape(x_shape)(x)

        # Excitation: bottleneck, then per-channel gating weights in [0, 1]
        x = Dense(filters // self.__reduction_ratio, activation='relu', use_bias=False)(x)
        x = Dense(filters, activation='sigmoid', use_bias=False)(x)

        # Scale the input channels by the learned gating weights
        output_tensor = Multiply()([input_tensor, x])

        return output_tensor
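

# Minimal usage sketch: wraps the block in a small functional model to show that the output
# keeps the input shape. The 32x32x64 feature map and reduction ratio of 8 are assumed example
# values, not prescribed by the block.
if __name__ == "__main__":
    inputs = tf.keras.Input(shape=(32, 32, 64))
    outputs = SEBlock(reduction_ratio=8)(inputs)
    model = tf.keras.Model(inputs, outputs)
    model.summary()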