Coverage for source/model/model_adapters/tf_model_adapter.py: 95%

42 statements  

« prev     ^ index     » next       coverage.py v7.8.0, created at 2025-08-01 20:51 +0000

1# model/model_adapters/tf_model_adapter.py 

2 

3# global imports 

4import numpy as np 

5from tensorflow.keras.callbacks import Callback 

6from tensorflow.keras.models import Model 

7from tensorflow.keras.optimizers import Optimizer 

8from tensorflow.keras.utils import to_categorical 

9from typing import Any, Callable, Optional 

10 

11# local imports 

12from source.model import ModelAdapterBase 

13 

class TFModelAdapter(ModelAdapterBase):
    """
    Implements a model adapter for TensorFlow models. It provides methods for loading,
    saving, printing summaries, fitting, predicting, and retrieving the model.

    Before data reaches the wrapped model, inputs get a new axis inserted at position 1.
    Labels are one-hot encoded with as many classes as the model's output dimension 1.
    """

    # global class constants
    TAG: str = "tensorflow"

    # local constants
    __WEIGHTS_FILE_EXTENSION: str = ".h5"
    __OPTIMIZER: str = "adam"
    __LOSS: str = "categorical_crossentropy"
    __METRICS: list[str] = ["accuracy"]

    def __init__(self, model: Model, optimizer: Optional[Optimizer] = None,
                 loss: Optional[str] = None, metrics: Optional[list[str]] = None) -> None:
        """
        Class constructor. Wraps a TensorFlow model and compiles it.

        Parameters:
            model (Model): The TensorFlow model to adapt. It is compiled in place.
            optimizer (Optional[Optimizer]): The optimizer to use for training the model.
                Defaults to the string "adam" if not provided.
            loss (Optional[str]): The loss function to use for training the model.
                Defaults to "categorical_crossentropy" if not provided.
            metrics (Optional[list[str]]): The metrics to evaluate during training.
                Defaults to ["accuracy"] if not provided.
        """

        if optimizer is None:
            optimizer = self.__OPTIMIZER
        if loss is None:
            loss = self.__LOSS
        if metrics is None:
            # Hand Keras a fresh list so the class-level default is never shared or aliased.
            metrics = list(self.__METRICS)

        self.__model: Model = model
        self.__model.compile(optimizer = optimizer, loss = loss, metrics = metrics)

    def __adjust_data(self, x: Any, y: Any = None) -> tuple[np.ndarray, Optional[np.ndarray]]:
        """
        Converts raw data into the layout fed to the wrapped model.

        Parameters:
            x (Any): Input features. They are converted with np.array and given a new
                axis at position 1 (e.g. shape (n, f) becomes (n, 1, f)).
            y (Any): Optional class labels. When provided, they are converted with np.array
                and one-hot encoded.

        Returns:
            (tuple): The adjusted inputs and the one-hot labels (None when y is None).
        """

        inputs = np.expand_dims(np.array(x), axis = 1)
        if y is None:
            return inputs, None

        # NOTE(review): assumes the model output is (batch, num_classes) and that y holds
        # integer class ids in [0, num_classes) — confirm for multi-output models.
        labels = to_categorical(np.array(y), num_classes = self.__model.output_shape[1])
        return inputs, labels

    def __check_weights_path(self, path: str) -> None:
        """
        Validates that a weights file path carries the expected extension.

        Parameters:
            path (str): The path to validate (path-like objects are accepted via str()).

        Raises:
            ValueError: If the path does not end with the expected weights file extension.
        """

        # Use a suffix test, not a substring test: "model.h5.bak" must be rejected.
        if not str(path).endswith(self.__WEIGHTS_FILE_EXTENSION):
            raise ValueError(f"Model path must end with '{self.__WEIGHTS_FILE_EXTENSION}'.")

    def load_model(self, path: str) -> None:
        """
        Loads model weights from a file into the wrapped TensorFlow model.

        Parameters:
            path (str): The path to the weights file.

        Raises:
            ValueError: If the path does not end with the expected weights file extension.
        """

        self.__check_weights_path(path)
        self.__model.load_weights(path)

    def save_model(self, path: str) -> None:
        """
        Saves the wrapped TensorFlow model's weights to a file.

        Parameters:
            path (str): The path to the weights file.

        Raises:
            ValueError: If the path does not end with the expected weights file extension.
        """

        self.__check_weights_path(path)
        self.__model.save_weights(path)

    def print_summary(self, print_function: Callable = print) -> None:
        """
        Prints a summary of the model's architecture.

        Parameters:
            print_function (Callable): The function to use for printing the summary.
                It is forwarded to Model.summary as print_fn.
        """

        self.__model.summary(print_fn = print_function)

    def fit(self, input_data: Any, output_data: Any, validation_data: Any,
            epochs: int, batch_size: int, callbacks: list[Callback], **kwargs) -> dict:
        """
        Fits the model to the provided input and output data.

        Parameters:
            input_data (Any): The input data for fitting the model.
            output_data (Any): The class labels for fitting the model (one-hot encoded here).
            validation_data (Any): An (inputs, labels) pair for evaluating the model, or None
                to train without validation. Only the first two items are used.
            epochs (int): The number of epochs to train the model.
            batch_size (int): The batch size to use for training.
            callbacks (list[Callback]): The list of callbacks to use during training.
            (**kwargs): Additional keyword arguments forwarded to Model.fit.

        Returns:
            (dict): The `history` dictionary of the History object returned by Model.fit.
        """

        if validation_data is not None:
            validation_data = self.__adjust_data(validation_data[0], validation_data[1])
        input_data, output_data = self.__adjust_data(input_data, output_data)

        history = self.__model.fit(input_data, output_data, epochs = epochs,
                                   validation_data = validation_data, batch_size = batch_size,
                                   callbacks = callbacks, **kwargs)
        return history.history

    def predict(self, data: Any) -> np.ndarray:
        """
        Predicts probabilities for the output for the given input data.

        Parameters:
            data (Any): The input data for prediction. It gets the same reshaping as in fit.

        Returns:
            (np.ndarray): The output of Model.predict for the adjusted input.
        """

        inputs, _ = self.__adjust_data(data)
        return self.__model.predict(inputs)

    def get_model(self) -> Model:
        """
        Retrieves the underlying model.

        Returns:
            (Model): The underlying (compiled) model instance.
        """

        return self.__model