Coverage for source/model/model_adapters/model_adapter_base.py: 68%
25 statements
« prev ^ index » next coverage.py v7.8.0, created at 2025-07-30 20:59 +0000
« prev ^ index » next coverage.py v7.8.0, created at 2025-07-30 20:59 +0000
1# model/model_adapters/model_adapter_base.py
3# global imports
4import inspect
5from abc import ABC, abstractmethod
6from typing import Any, Callable
8# local imports
class ModelAdapterBase(ABC):
    """
    Abstract interface shared by all model adapters.

    Concrete adapters wrap an underlying model object and must implement
    persistence (load/save), summary printing, training (fit), inference
    (predict) and access to the wrapped model. The base class also offers a
    helper that introspects the concrete ``fit`` signature to list the
    arguments a caller must always supply.
    """

    @abstractmethod
    def load_model(self, path: str) -> None:
        """
        Restore the wrapped model from storage.

        Parameters:
            path (str): Location of the model file to read.
        """

    @abstractmethod
    def save_model(self, path: str) -> None:
        """
        Persist the wrapped model to storage.

        Parameters:
            path (str): Location of the model file to write.
        """

    @abstractmethod
    def print_summary(self, print_function: Callable = print) -> None:
        """
        Emit a description of the model's architecture and parameters.

        Parameters:
            print_function (Callable): Callable that receives the summary
                output; the built-in ``print`` is used when omitted.
        """

    @abstractmethod
    def fit(self, input_data: Any, output_data: Any, **kwargs) -> Any:
        """
        Train the wrapped model on the given data.

        Parameters:
            input_data (Any): Training inputs.
            output_data (Any): Training targets.
            (**kwargs): Extra adapter-specific training options.

        Returns:
            (Any): Whatever the concrete adapter's training routine produces.
        """

    @abstractmethod
    def predict(self, data: Any) -> Any:
        """
        Run inference with the wrapped model.

        Parameters:
            data (Any): Inputs to predict on.

        Returns:
            (Any): The model's predictions.
        """

    @abstractmethod
    def get_model(self) -> Any:
        """
        Expose the wrapped model object.

        Returns:
            (Any): The underlying model instance held by this adapter.
        """

    def report_parameters_needed_for_fitting(self) -> list[str]:
        """
        List the arguments of the concrete ``fit`` that have no default value.

        The signature is taken from the bound method, so ``self`` never
        appears. Variadic ``*args`` / ``**kwargs`` parameters are skipped
        because they are never mandatory. Keyword-only parameters without a
        default are included.

        Returns:
            (list[str]): Names of the required ``fit`` parameters, in
                declaration order.
        """
        variadic_kinds = (
            inspect.Parameter.VAR_POSITIONAL,
            inspect.Parameter.VAR_KEYWORD,
        )

        required: list[str] = []
        for parameter in inspect.signature(self.fit).parameters.values():
            if parameter.kind in variadic_kinds:
                continue
            if parameter.default is not inspect.Parameter.empty:
                continue
            required.append(parameter.name)
        return required