Coverage for source/model/model_building_blocks/se_block.py: 36%

14 statements  

« prev     ^ index     » next       coverage.py v7.8.0, created at 2025-06-06 12:00 +0000

1# model/model_building_blocks/se_block.py 

2 

3import tensorflow as tf 

4from tensorflow.keras.layers import GlobalAveragePooling2D, Reshape, Dense, Multiply 

5 

class SEBlock:
    """
    Class implementing a Squeeze-and-Excitation (SE) block compatible with the TensorFlow API.
    This block applies global average pooling followed by a squeeze-and-excitation operation,
    as described in the SE-Net architecture.

    The block is a callable graph builder: every call instantiates fresh Keras layers
    (and therefore fresh weights) and wires them onto the given tensor.

    Diagram:

    ::

        Input Tensor last dimension length - ITldl
        Reduction rate - Rr

        Input Tensor --+--> GlobalAvgPool --> Reshape (1, 1, ITldl)
                       |        --> Dense (max(1, ITldl // Rr) nodes, ReLU)
                       |        --> Dense (ITldl nodes, sigmoid) --+
                       |                                           v
                       +--------------------------------------> Multiply --> Output Tensor
    """

    def __init__(self, reduction_ratio: int = 16) -> None:
        """
        Class constructor.

        Parameters:
            reduction_ratio (int): Reduction ratio used to control the size of the squeeze operation.
                Must be positive.

        Raises:
            ValueError: If reduction_ratio is zero or negative.
        """

        # Fail fast here: a zero ratio would otherwise surface later as a
        # ZeroDivisionError inside __call__, and a negative one as an opaque
        # Keras error about the number of Dense units.
        if reduction_ratio <= 0:
            raise ValueError(
                f"reduction_ratio must be a positive integer, got {reduction_ratio!r}"
            )

        self.__reduction_ratio: int = reduction_ratio

    def _bottleneck_units(self, filters: int) -> int:
        """
        Computes the width of the bottleneck (first) Dense layer.

        Parameters:
            filters (int): Length of the last dimension of the input tensor.

        Returns:
            int: filters // reduction_ratio, but never less than 1.
        """

        # When filters < reduction_ratio the floor division yields 0, which is
        # not a usable layer width, so clamp to a single unit.
        return max(1, filters // self.__reduction_ratio)

    def __call__(self, input_tensor: "tf.Tensor") -> "tf.Tensor":
        """
        Applies squeeze-and-excitation operation to the input tensor.

        Parameters:
            input_tensor (tf.Tensor): Input tensor to which the SE operation should be applied.
                It is passed to GlobalAveragePooling2D and its last dimension is used as the
                channel count (presumably a channels-last feature map - confirm with callers).

        Returns:
            tf.Tensor: Output tensor after the SE operation has been applied.

        Raises:
            ValueError: If the last dimension of input_tensor is not statically known.
        """

        filters = input_tensor.shape[-1]
        if filters is None:
            raise ValueError(
                "SEBlock requires the last dimension of input_tensor to be defined"
            )
        x_shape = (1, 1, filters)

        # Squeeze: collapse the spatial dimensions into one descriptor per channel.
        x = GlobalAveragePooling2D()(input_tensor)
        x = Reshape(x_shape)(x)

        # Excitation: bottleneck Dense followed by a sigmoid gate per channel.
        x = Dense(self._bottleneck_units(filters), activation='relu', use_bias=False)(x)
        x = Dense(filters, activation='sigmoid', use_bias=False)(x)

        # Rescale the input with the (1, 1, filters) gates.
        output_tensor = Multiply()([input_tensor, x])

        return output_tensor