Coverage for source/model/model_building_blocks/se_block.py: 36%
14 statements
« prev ^ index » next coverage.py v7.8.0, created at 2025-06-06 12:00 +0000
1# model/model_building_blocks/se_block.py
3import tensorflow as tf
4from tensorflow.keras.layers import GlobalAveragePooling2D, Reshape, Dense, Multiply
class SEBlock:
    """
    Class implementing a Squeeze-and-Excitation (SE) block compatible with the TensorFlow API.
    This block applies global average pooling followed by a squeeze-and-excitation operation,
    as described in the SE-Net architecture.

    The channel count is read from the last axis of the input tensor, so the block assumes
    channels-last 4-D input, i.e. (batch, height, width, channels) — TODO confirm for all callers.

    Diagram:

    ::

        ITldl - Input Tensor last dimension length
        Rr    - Reduction rate

        Input Tensor
            |
            +-------------------------------+
            |                               |
            v                               |
        +----------------------------+      |
        | GlobalAveragePooling2D     |      |
        +----------------------------+      |
            |                               |
            v                               |
        +----------------------------+      |
        | Reshape (1, 1, ITldl)      |      |
        +----------------------------+      |
            |                               |
            v                               |
        +----------------------------+      |
        | Dense, ReLU, no bias       |      |
        | units: max(1, ITldl // Rr) |      |
        +----------------------------+      |
            |                               |
            v                               |
        +----------------------------+      |
        | Dense, sigmoid, no bias    |      |
        | units: ITldl               |      |
        +----------------------------+      |
            |                               |
            v                               v
        +---------------------------------------+
        | Multiply (input x channel weights)    |
        +---------------------------------------+
            |
            v
        Output Tensor
    """

    def __init__(self, reduction_ratio: int = 16) -> None:
        """
        Class constructor.

        Parameters:
            reduction_ratio (int): Reduction ratio used to control the size of the squeeze
                operation. The bottleneck Dense layer gets ``channels // reduction_ratio``
                units, clamped to at least one unit.

        Raises:
            ValueError: If ``reduction_ratio`` is not positive.
        """
        # Fail fast: a zero ratio would otherwise surface later as a ZeroDivisionError
        # inside __call__, and a negative one as a negative Dense unit count.
        if reduction_ratio <= 0:
            raise ValueError(
                f"reduction_ratio must be a positive number, got {reduction_ratio!r}"
            )

        self.__reduction_ratio: int = reduction_ratio

    def __call__(self, input_tensor: tf.Tensor) -> tf.Tensor:
        """
        Applies squeeze-and-excitation operation to the input tensor.

        Parameters:
            input_tensor (tf.Tensor): Input tensor to which the SE operation should be applied.
                Its last dimension (channels) must be statically known.

        Returns:
            tf.Tensor: Output tensor after the SE operation has been applied, i.e. the input
            multiplied by the per-channel weights produced by the excitation branch.

        Raises:
            ValueError: If the last dimension of ``input_tensor`` is unknown (None).
        """
        filters = input_tensor.shape[-1]
        if filters is None:
            raise ValueError(
                "SEBlock requires a statically known last (channel) dimension; "
                f"got input shape {input_tensor.shape}."
            )

        x_shape = (1, 1, filters)

        # Clamp the bottleneck width: when filters < reduction_ratio, integer division
        # yields 0, which Keras Dense rejects as an invalid unit count.
        bottleneck_units = max(1, filters // self.__reduction_ratio)

        # Squeeze: pool each channel to a single value, then restore singleton spatial
        # axes so the weights can broadcast against the input in Multiply.
        x = GlobalAveragePooling2D()(input_tensor)
        x = Reshape(x_shape)(x)

        # Excitation: bottleneck MLP producing one sigmoid weight per channel.
        x = Dense(bottleneck_units, activation='relu', use_bias=False)(x)
        x = Dense(filters, activation='sigmoid', use_bias=False)(x)

        # Rescale the original input channel-wise by the learned weights.
        output_tensor = Multiply()([input_tensor, x])

        return output_tensor