Coverage for source/model/model_adapters/model_adapter_base.py: 68%

25 statements  

« prev     ^ index     » next       coverage.py v7.8.0, created at 2025-08-01 20:51 +0000

1# model/model_adapters/model_adapter_base.py 

2 

3# global imports 

4import inspect 

5from abc import ABC, abstractmethod 

6from typing import Any, Callable 

7 

8# local imports 

9 

class ModelAdapterBase(ABC):
    """
    Abstract interface shared by all model adapters.

    A concrete adapter must override every abstract method below. These cover
    persistence (load/save), printing a summary, fitting, predicting, and
    exposing the wrapped model object. Until all of them are overridden the
    subclass cannot be instantiated.
    """

    @abstractmethod
    def load_model(self, path: str) -> None:
        """
        Restore the adapter's model from a file.

        Parameters:
            path (str): Location of the model file to read.
        """

    @abstractmethod
    def save_model(self, path: str) -> None:
        """
        Persist the adapter's model to a file.

        Parameters:
            path (str): Location of the model file to write.
        """

    @abstractmethod
    def print_summary(self, print_function: Callable = print) -> None:
        """
        Emit a human-readable overview of the model's architecture and parameters.

        Parameters:
            print_function (Callable): Sink that receives the summary output;
                the built-in ``print`` is used when none is given.
        """

    @abstractmethod
    def fit(self, input_data: Any, output_data: Any, **kwargs) -> Any:
        """
        Train the model on paired input/output data.

        Parameters:
            input_data (Any): Training inputs.
            output_data (Any): Training targets matching ``input_data``.
            (**kwargs): Extra, adapter-specific fitting options.

        Returns:
            (Any): Whatever the concrete adapter reports from fitting; the
                type is implementation-defined.
        """

    @abstractmethod
    def predict(self, data: Any) -> Any:
        """
        Run the model on new inputs.

        Parameters:
            data (Any): Inputs to predict on.

        Returns:
            (Any): The model's predictions for ``data``.
        """

    @abstractmethod
    def get_model(self) -> Any:
        """
        Expose the model object wrapped by this adapter.

        Returns:
            (Any): The underlying model instance.
        """

    def report_parameters_needed_for_fitting(self) -> list[str]:
        """
        List the names of the arguments a call to ``fit`` must supply.

        The signature of the bound ``fit`` method is inspected, so ``self`` is
        not included. An argument counts as required when it has no default
        value and is not a variadic catch-all (``*args`` / ``**kwargs``).
        Positional-only and keyword-only arguments without defaults are
        therefore reported too.

        Returns:
            (list[str]): Required argument names, in declaration order.
        """
        # Variadic parameters never need to be supplied, so they are skipped.
        variadic_kinds = (
            inspect.Parameter.VAR_POSITIONAL,
            inspect.Parameter.VAR_KEYWORD,
        )
        required = []
        for parameter in inspect.signature(self.fit).parameters.values():
            if parameter.kind in variadic_kinds:
                continue
            if parameter.default is not inspect.Parameter.empty:
                continue
            required.append(parameter.name)
        return required