Coverage for source/model/model_building_blocks/se_block.py: 36%

14 statements  

« prev     ^ index     » next       coverage.py v7.8.0, created at 2025-07-30 20:59 +0000

1# model/model_building_blocks/se_block.py 

2 

3# global imports 

4import tensorflow as tf 

5from tensorflow.keras.layers import Dense, GlobalAveragePooling2D, Multiply, Reshape 

6 

7# local imports 

8 

class SEBlock:
    """
    Class implementing a Squeeze-and-Excitation (SE) block compatible with the TensorFlow API.
    This block applies global average pooling followed by a squeeze-and-excitation operation,
    as described in the SE-Net architecture.

    The channel count is taken from the last dimension of the input tensor, so the block
    assumes a channels-last layout such as (batch, H, W, C). TODO confirm callers never
    pass channels-first tensors.

    Diagram:

    ::

        C  - number of channels (last dimension of the input tensor)
        Rr - reduction ratio

        Input Tensor (batch, H, W, C) --------------------------+
              |                                                 |
              v                                                 |
        GlobalAveragePooling2D         -> (batch, C)            |
              |                                                 |
              v                                                 |
        Reshape (1, 1, C)              -> (batch, 1, 1, C)      |
              |                                                 |
              v                                                 |
        Dense(max(1, C // Rr), relu)   -> (batch, 1, 1, C//Rr)  |
              |                                                 |
              v                                                 |
        Dense(C, sigmoid)              -> (batch, 1, 1, C)      |
              |                                                 |
              v                                                 |
        Multiply <----------------------------------------------+
              |
              v
        Output Tensor (batch, H, W, C)
    """

    def __init__(self, reduction_ratio: int = 16) -> None:
        """
        Class constructor.

        Parameters:
            reduction_ratio (int): Factor by which the channel count is divided to size the
                bottleneck Dense layer of the excitation step. Must be positive.

        Raises:
            ValueError: If ``reduction_ratio`` is not positive.
        """

        # Fail fast at construction time. Without this check, a zero ratio surfaces as a
        # ZeroDivisionError, and a negative one as a Keras error, only on the first call.
        if reduction_ratio <= 0:
            raise ValueError(f"reduction_ratio must be positive, got {reduction_ratio}")

        self.__reduction_ratio: int = reduction_ratio

    def __call__(self, input_tensor: tf.Tensor) -> tf.Tensor:
        """
        Applies squeeze-and-excitation operation to the input tensor.

        Parameters:
            input_tensor (tf.Tensor): Input tensor to which the SE operation should be applied.
                Its last dimension is treated as the channel dimension.

        Returns:
            tf.Tensor: Output tensor after the SE operation has been applied. It has the same
                shape as ``input_tensor``, with each channel scaled by a learned gate.

        Raises:
            ValueError: If the last (channel) dimension of ``input_tensor`` is not statically known.
        """

        filters = input_tensor.shape[-1]
        if filters is None:
            # Both Dense layers and the Reshape need a concrete channel count.
            raise ValueError(
                "SEBlock requires a statically known channel (last) dimension, "
                f"got input shape {input_tensor.shape}"
            )

        # Clamp to at least one unit. With fewer channels than the reduction ratio, integer
        # division would give 0 units, which Keras either rejects or turns into a zero-width
        # layer. A zero-width layer makes the bias-free sigmoid layer below emit a constant
        # 0.5 gate regardless of the input.
        bottleneck_units = max(1, filters // self.__reduction_ratio)

        # Squeeze: take the global spatial average, giving one value per channel. Reshape it
        # to (1, 1, C) so that Multiply can broadcast it against the input's spatial axes.
        x = GlobalAveragePooling2D()(input_tensor)
        x = Reshape((1, 1, filters))(x)

        # Excitation: a bias-free bottleneck MLP producing per-channel gates in (0, 1).
        x = Dense(bottleneck_units, activation = 'relu', use_bias = False)(x)
        x = Dense(filters, activation = 'sigmoid', use_bias = False)(x)

        # Scale: reweight each input channel by its gate.
        output_tensor = Multiply()([input_tensor, x])

        return output_tensor