Coverage for source/model/model_adapters/tf_model_adapter.py: 71%

35 statements  

« prev     ^ index     » next       coverage.py v7.8.0, created at 2025-06-06 12:00 +0000

1# model/model_adapters/tf_model_adapter.py 

2 

3# global imports 

4from tensorflow.keras.models import Model 

5from typing import Callable 

6from tensorflow.keras.optimizers import Optimizer 

7from typing import Any, Optional 

8 

9# local imports 

10from source.model import ModelAdapterBase 

11 

class TFModelAdapter(ModelAdapterBase):
    """ModelAdapterBase subclass that wraps a Keras ``Model``.

    The wrapped model is compiled when the adapter is constructed. Weights
    are loaded from and saved to files whose name ends with
    ``WEIGHTS_FILE_EXTENSION``.
    """

    # Extension required on every path given to load_model / save_model.
    WEIGHTS_FILE_EXTENSION: str = ".h5"
    # Compile-time defaults, used when the constructor argument is None.
    OPTIMIZER = "adam"
    LOSS = "categorical_crossentropy"
    METRICS = ["accuracy"]

    def __init__(self, model: Model, optimizer: Optional[Optimizer] = None,
                 loss: Optional[str] = None,
                 metrics: Optional[list[str]] = None) -> None:
        """Store ``model`` and compile it.

        Args:
            model: The Keras model to wrap.
            optimizer: Optimizer passed to ``model.compile``; defaults to
                ``TFModelAdapter.OPTIMIZER`` when None.
            loss: Loss passed to ``model.compile``; defaults to
                ``TFModelAdapter.LOSS`` when None.
            metrics: Metrics passed to ``model.compile``; defaults to a copy
                of ``TFModelAdapter.METRICS`` when None.
        """
        if optimizer is None:
            optimizer = TFModelAdapter.OPTIMIZER
        if loss is None:
            loss = TFModelAdapter.LOSS
        if metrics is None:
            # Copy so the shared class-level list can never be mutated
            # through a single adapter instance.
            metrics = list(TFModelAdapter.METRICS)

        self.__model: Model = model
        self.__model.compile(optimizer=optimizer,
                             loss=loss,
                             metrics=metrics)

    @staticmethod
    def __validate_weights_path(path: str) -> None:
        """Raise ValueError unless ``path`` ends with the weights extension."""
        # endswith (not a substring test) so that e.g. "model.h5.bak" or
        # "dir.h5/model.bin" is rejected, as the error message promises.
        if not path.endswith(TFModelAdapter.WEIGHTS_FILE_EXTENSION):
            raise ValueError(f"Model path must end with '{TFModelAdapter.WEIGHTS_FILE_EXTENSION}'.")

    def load_model(self, path: str) -> None:
        """Load weights from ``path`` into the wrapped model.

        Args:
            path: Weights file; must end with ``WEIGHTS_FILE_EXTENSION``.

        Raises:
            ValueError: If ``path`` does not end with the required extension.
        """
        self.__validate_weights_path(path)
        self.__model.load_weights(path)

    def save_model(self, path: str) -> None:
        """Save the wrapped model's weights to ``path``.

        Args:
            path: Target file; must end with ``WEIGHTS_FILE_EXTENSION``.

        Raises:
            ValueError: If ``path`` does not end with the required extension.
        """
        self.__validate_weights_path(path)
        self.__model.save_weights(path)

    def print_summary(self, print_function: Callable = print) -> None:
        """Emit the model summary through ``print_function``.

        Args:
            print_function: Callable forwarded to ``model.summary`` as
                ``print_fn``; defaults to the builtin ``print``.
        """
        self.__model.summary(print_fn=print_function)

    def fit(self, input_data: Any, output_data: Any, **kwargs) -> dict:
        """Train the wrapped model and return the result of ``model.fit``.

        Args:
            input_data: Forwarded to ``model.fit`` as the first positional
                argument.
            output_data: Forwarded to ``model.fit`` as the second positional
                argument.
            **kwargs: Extra keyword arguments forwarded to ``model.fit``.
                ``validation_split`` defaults to 0.1 when not supplied.

        Returns:
            Whatever ``model.fit`` returns.
            NOTE(review): the ``dict`` annotation is kept for interface
            compatibility; Keras presumably returns a ``History`` object
            here -- confirm and adjust together with the base class.
        """
        # setdefault instead of a hard-coded keyword: a caller-supplied
        # validation_split would otherwise raise "multiple values" TypeError.
        kwargs.setdefault("validation_split", 0.1)
        return self.__model.fit(input_data, output_data, **kwargs)

    def predict(self, data: Any) -> dict:
        """Run inference on ``data`` and return the result of ``model.predict``.

        NOTE(review): the ``dict`` annotation is kept for interface
        compatibility; confirm it matches what Keras actually returns.
        """
        return self.__model.predict(data)

    def get_model(self) -> Model:
        """Return the wrapped Keras model."""
        return self.__model