Coverage for source/model/model_adapters/tf_model_adapter.py: 71%
35 statements
« prev ^ index » next coverage.py v7.8.0, created at 2025-06-06 12:00 +0000
« prev ^ index » next coverage.py v7.8.0, created at 2025-06-06 12:00 +0000
1# model/model_adapters/tf_model_adapter.py
3# global imports
4from tensorflow.keras.models import Model
5from typing import Callable
6from tensorflow.keras.optimizers import Optimizer
7from typing import Any, Optional
9# local imports
10from source.model import ModelAdapterBase
class TFModelAdapter(ModelAdapterBase):
    """Adapter exposing a Keras ``Model`` through the ``ModelAdapterBase`` API.

    The wrapped model is compiled once, at construction time, using either
    the arguments given to ``__init__`` or the class-level defaults below.
    Persistence methods work on *weights only* (``load_weights`` /
    ``save_weights``) and accept only paths ending in
    ``WEIGHTS_FILE_EXTENSION``.
    """

    # File suffix required by load_model() / save_model().
    WEIGHTS_FILE_EXTENSION: str = ".h5"
    # Compile-time defaults used when the matching __init__ argument is None.
    OPTIMIZER = "adam"
    LOSS = "categorical_crossentropy"
    METRICS = ["accuracy"]
    # Default ``validation_split`` passed to ``Model.fit`` by fit().
    VALIDATION_SPLIT = 0.1

    def __init__(self, model: Model, optimizer: Optional[Optimizer] = None,
                 loss: Optional[str] = None,
                 metrics: Optional[list[str]] = None) -> None:
        """Store ``model`` and compile it.

        Args:
            model: The Keras model to wrap. It is compiled in place.
            optimizer: Optimizer passed to ``compile``; defaults to
                ``TFModelAdapter.OPTIMIZER`` when ``None``.
            loss: Loss passed to ``compile``; defaults to
                ``TFModelAdapter.LOSS`` when ``None``.
            metrics: Metrics passed to ``compile``; defaults to a copy of
                ``TFModelAdapter.METRICS`` when ``None``.
        """
        if optimizer is None:
            optimizer = TFModelAdapter.OPTIMIZER
        if loss is None:
            loss = TFModelAdapter.LOSS
        if metrics is None:
            # Copy so the shared class-level list is never handed to the
            # model, where it could be mutated for every other instance.
            metrics = list(TFModelAdapter.METRICS)

        self.__model: Model = model
        self.__model.compile(optimizer=optimizer, loss=loss, metrics=metrics)

    @staticmethod
    def _check_weights_path(path: str) -> None:
        """Raise ``ValueError`` unless ``path`` ends with the weights suffix."""
        # ``endswith`` rather than a substring test: the suffix appearing
        # somewhere in the middle of the path (e.g. "w.h5.tmp") must not pass.
        if not path.endswith(TFModelAdapter.WEIGHTS_FILE_EXTENSION):
            raise ValueError(f"Model path must end with '{TFModelAdapter.WEIGHTS_FILE_EXTENSION}'.")

    def load_model(self, path: str) -> None:
        """Load weights from ``path`` into the wrapped model.

        Args:
            path: Weights file path; must end with ``WEIGHTS_FILE_EXTENSION``.

        Raises:
            ValueError: If ``path`` does not end with the required suffix.
        """
        self._check_weights_path(path)
        self.__model.load_weights(path)

    def save_model(self, path: str) -> None:
        """Save the wrapped model's weights to ``path``.

        Args:
            path: Weights file path; must end with ``WEIGHTS_FILE_EXTENSION``.

        Raises:
            ValueError: If ``path`` does not end with the required suffix.
        """
        self._check_weights_path(path)
        self.__model.save_weights(path)

    def print_summary(self, print_function: Callable = print) -> None:
        """Emit the model summary through ``print_function``.

        Args:
            print_function: Callable forwarded to ``Model.summary`` as
                ``print_fn``; defaults to the builtin ``print``.
        """
        self.__model.summary(print_fn=print_function)

    def fit(self, input_data: Any, output_data: Any, **kwargs) -> dict:
        """Train the wrapped model and return the result of ``Model.fit``.

        Args:
            input_data: Passed to ``Model.fit`` as the first positional
                argument.
            output_data: Passed to ``Model.fit`` as the second positional
                argument.
            **kwargs: Forwarded to ``Model.fit``. ``validation_split``
                defaults to ``VALIDATION_SPLIT`` but may be overridden here.

        Returns:
            Whatever ``Model.fit`` returns.
            NOTE(review): annotated ``dict``, but Keras documents a
            ``History`` object -- confirm against ``ModelAdapterBase``.
        """
        # setdefault keeps the default while letting callers supply their own
        # value; passing it explicitly alongside **kwargs would raise
        # "got multiple values for keyword argument 'validation_split'".
        kwargs.setdefault("validation_split", TFModelAdapter.VALIDATION_SPLIT)
        return self.__model.fit(input_data, output_data, **kwargs)

    def predict(self, data: Any) -> dict:
        """Run inference on ``data`` and return the result of ``Model.predict``.

        NOTE(review): annotated ``dict``; Keras documents NumPy array(s) as
        the return value -- confirm against ``ModelAdapterBase``.
        """
        return self.__model.predict(data)

    def get_model(self) -> Model:
        """Return the wrapped Keras model instance."""
        return self.__model