Coverage for source/model/model_adapters/tf_model_adapter.py: 95%
42 statements
« prev ^ index » next coverage.py v7.8.0, created at 2025-07-30 20:59 +0000
« prev ^ index » next coverage.py v7.8.0, created at 2025-07-30 20:59 +0000
1# model/model_adapters/tf_model_adapter.py
3# global imports
4import numpy as np
5from tensorflow.keras.callbacks import Callback
6from tensorflow.keras.models import Model
7from tensorflow.keras.optimizers import Optimizer
8from tensorflow.keras.utils import to_categorical
9from typing import Any, Callable, Optional
11# local imports
12from source.model import ModelAdapterBase
class TFModelAdapter(ModelAdapterBase):
    """
    Implements a model adapter for TensorFlow models. It provides methods for loading,
    saving, printing summaries, fitting, predicting, and retrieving the model.
    """

    # global class constants
    TAG: str = "tensorflow"

    # local constants
    __WEIGHTS_FILE_EXTENSION: str = ".h5"
    __OPTIMIZER: str = "adam"
    __LOSS: str = "categorical_crossentropy"
    __METRICS: list[str] = ["accuracy"]

    def __init__(self, model: Model, optimizer: Optional[Optimizer] = None,
                 loss: Optional[str] = None, metrics: Optional[list[str]] = None) -> None:
        """
        Class constructor. Initializes the model adapter with a TensorFlow model and optional
        parameters, and compiles the model with them.

        Parameters:
            model (Model): The TensorFlow model to adapt.
            optimizer (Optional[Optimizer]): The optimizer to use for training the model.
                Defaults to "adam" if not provided.
            loss (Optional[str]): The loss function to use for training the model.
                Defaults to "categorical_crossentropy" if not provided.
            metrics (Optional[list[str]]): The metrics to evaluate during training.
                Defaults to ["accuracy"] if not provided.
        """

        if optimizer is None:
            optimizer = self.__OPTIMIZER
        if loss is None:
            loss = self.__LOSS
        if metrics is None:
            # hand out a copy so the shared class-level default list can never be mutated
            metrics = list(self.__METRICS)

        self.__model: Model = model
        self.__model.compile(optimizer = optimizer, loss = loss, metrics = metrics)

    def __adjust_data_func(self, x: Any, y: Any = None) -> tuple:
        """
        Converts raw inputs (and optional labels) into the arrays handed to the model.

        Parameters:
            x (Any): The input samples. They are converted to a NumPy array and get a new
                axis inserted at position 1.
            y (Any): The labels, or None. When given, they are one-hot encoded with as many
                classes as the second entry of the model's output shape.

        Returns:
            (tuple): The adjusted inputs and the adjusted labels (None if no labels were given).
        """

        inputs = np.expand_dims(np.array(x), axis = 1)
        if y is None:
            return inputs, None

        num_classes = self.__model.output_shape[1]
        return inputs, to_categorical(np.array(y), num_classes = num_classes)

    def __validate_weights_path(self, path: str) -> None:
        """
        Checks that a path ends with the expected weights file extension.

        Parameters:
            path (str): The path to validate.

        Raises:
            ValueError: If the path does not end with the expected weights file extension.
        """

        # a suffix check, not a substring check: "weights.h5.bak" must be rejected
        if not path.endswith(self.__WEIGHTS_FILE_EXTENSION):
            raise ValueError(f"Model path must end with '{self.__WEIGHTS_FILE_EXTENSION}'.")

    def load_model(self, path: str) -> None:
        """
        Loads the weights of the TensorFlow model from a file.

        Parameters:
            path (str): The path to the weights file.

        Raises:
            ValueError: If the path does not end with the expected weights file extension.
        """

        self.__validate_weights_path(path)
        self.__model.load_weights(path)

    def save_model(self, path: str) -> None:
        """
        Saves the weights of the TensorFlow model to a file.

        Parameters:
            path (str): The path to the weights file.

        Raises:
            ValueError: If the path does not end with the expected weights file extension.
        """

        self.__validate_weights_path(path)
        self.__model.save_weights(path)

    def print_summary(self, print_function: Callable = print) -> None:
        """
        Prints a summary of the model's architecture.

        Parameters:
            print_function (Callable): The function to use for printing the summary.
        """

        self.__model.summary(print_fn = print_function)

    def fit(self, input_data: Any, output_data: Any, validation_data: Any,
            epochs: int, batch_size: int, callbacks: list[Callback], **kwargs) -> dict:
        """
        Fits the model to the provided input and output data.

        Parameters:
            input_data (Any): The input data for fitting the model.
            output_data (Any): The output data for fitting the model.
            validation_data (Any): The validation data for evaluating the model, given as an
                (inputs, outputs) pair, or None to train without validation data.
            epochs (int): The number of epochs to train the model.
            batch_size (int): The batch size to use for training.
            callbacks (list[Callback]): The list of callbacks to use during training.
            (**kwargs): Additional keyword arguments for fitting the model.

        Returns:
            (dict): The `history` attribute of the object returned by the model's fit call.
        """

        # only unpack the validation pair when one was actually supplied
        if validation_data is not None:
            validation_data = self.__adjust_data_func(validation_data[0], validation_data[1])
        input_data, output_data = self.__adjust_data_func(input_data, output_data)

        history = self.__model.fit(input_data, output_data, epochs = epochs,
                                   validation_data = validation_data, batch_size = batch_size,
                                   callbacks = callbacks, **kwargs)
        return history.history

    def predict(self, data: Any) -> dict:
        """
        Predicts probabilities for the output for the given input data.

        Parameters:
            data (Any): The input data for prediction.

        Returns:
            The value returned by the model's predict call for the adjusted input data.
            NOTE(review): the annotation says dict, but Keras predict presumably returns
            NumPy array(s) - confirm against ModelAdapterBase before changing the annotation.
        """

        data, _ = self.__adjust_data_func(data)
        return self.__model.predict(data)

    def get_model(self) -> Model:
        """
        Retrieves the underlying model.

        Returns:
            (Model): The underlying model instance.
        """

        return self.__model