Coverage for source/model/model_adapters/sklearn_model_adapter.py: 91%
54 statements
« prev ^ index » next coverage.py v7.8.0, created at 2025-08-23 15:31 +0000
« prev ^ index » next coverage.py v7.8.0, created at 2025-08-23 15:31 +0000
1# model/model_adapters/sklearn_model_adapter.py
3# global imports
4import joblib
5import numpy as np
6import os
7from sklearn.base import BaseEstimator
8from sklearn.calibration import CalibratedClassifierCV
9from sklearn.model_selection import learning_curve, StratifiedKFold
10from typing import Any, Callable
12# local imports
13from source.model import ModelAdapterBase
class SklearnModelAdapter(ModelAdapterBase):
    """
    Implements a model adapter for scikit-learn models. It provides methods for loading,
    saving, printing summaries, fitting, predicting, and retrieving the model.

    Models that do not expose ``predict_proba`` are wrapped in a ``CalibratedClassifierCV``
    during ``fit`` so that ``predict`` can always return class probabilities.
    """

    # global class constants
    TAG: str = "sklearn"

    # local constants
    __MODEL_FILE_EXTENSION: str = ".pkl"

    def __init__(self, model: BaseEstimator, should_compute_learning_curve: bool = True) -> None:
        """
        Initializes the SklearnModelAdapter with a scikit-learn model.

        Parameters:
            model (BaseEstimator): The scikit-learn model to adapt.
            should_compute_learning_curve (bool): Flag indicating whether to compute the
                learning curve during ``fit``. Defaults to True.
        """

        self.__model: BaseEstimator = model
        self.__should_compute_learning_curve: bool = should_compute_learning_curve

    def load_model(self, path: str) -> None:
        """
        Loads a scikit-learn model from a file, replacing the currently held model.

        Parameters:
            path (str): The path to the model file.

        Raises:
            ValueError: If the path does not end with the expected model file extension.
            FileNotFoundError: If no model file exists at the specified path.

        Note:
            The file is deserialized with ``joblib.load``, which unpickles its contents and
            can therefore execute arbitrary code. Only load files from trusted sources.
        """

        if not path.endswith(self.__MODEL_FILE_EXTENSION):
            raise ValueError(f"Model path must end with '{self.__MODEL_FILE_EXTENSION}'.")

        # isfile (not exists) so that a directory named "*.pkl" is reported as a missing
        # model file instead of surfacing later as an IsADirectoryError from joblib.
        if not os.path.isfile(path):
            raise FileNotFoundError(f"Model file not found: {path}")

        # SECURITY: joblib.load unpickles the file; never call this on untrusted input.
        self.__model = joblib.load(path)

    def save_model(self, path: str) -> None:
        """
        Saves the current scikit-learn model to a file, creating parent directories as needed.

        Parameters:
            path (str): The path to the model file.

        Raises:
            ValueError: If the path does not end with the expected model file extension.
        """

        if not path.endswith(self.__MODEL_FILE_EXTENSION):
            raise ValueError(f"Model path must end with '{self.__MODEL_FILE_EXTENSION}'.")

        # abspath guarantees a non-empty dirname even for bare filenames like "model.pkl".
        os.makedirs(os.path.dirname(os.path.abspath(path)), exist_ok = True)
        joblib.dump(self.__model, path)

    def print_summary(self, print_function: Callable = print) -> None:
        """
        Prints a summary of the model's class name and its parameters (sorted by name).

        Parameters:
            print_function (Callable): The function used to emit each summary line.
                Defaults to the built-in ``print``.
        """

        print_function(f"{'-'*80}")
        print_function(f"Model Summary: {type(self.__model).__name__}")
        print_function(f"{'-'*80}")

        print_function("Model Parameters:")
        params = self.__model.get_params()
        for key, value in sorted(params.items()):
            print_function(f"  {key}: {value}")
        print_function(f"{'-'*80}")

    def fit(self, input_data: Any, output_data: Any, validation_data: Any, **kwargs) -> dict:
        """
        Fits the model to the provided input and output data.

        If learning-curve computation is enabled, a learning curve is computed first using
        5-fold stratified cross-validation and accuracy scoring. If the model has no
        ``predict_proba`` method, it is wrapped in a ``CalibratedClassifierCV`` (using the
        same folds) before the final fit.

        Parameters:
            input_data (Any): The input data for fitting the model.
            output_data (Any): The output data for fitting the model.
            validation_data (Any): A sequence unpacked into the fitted model's ``score``
                method (typically ``(X, y)``), or ``None`` to skip validation scoring.
            (**kwargs): Additional keyword arguments forwarded to the model's ``fit``.

        Returns:
            (dict): The ``learning_curve_data_*`` entries (when learning-curve computation
                is enabled) and ``validation_data_score`` (when validation data is given).
        """

        summary_data = {}
        cv = StratifiedKFold(n_splits = 5, shuffle = True, random_state = 42)

        if self.__should_compute_learning_curve:
            summary_data.update(self.__compute_learning_curve(input_data, output_data, cv))

        # Guarantee predict_proba is available for predict().
        if not hasattr(self.__model, "predict_proba"):
            self.__model = CalibratedClassifierCV(self.__model, cv = cv)
        self.__model.fit(input_data, output_data, **kwargs)

        if validation_data is not None:
            summary_data["validation_data_score"] = self.__model.score(*validation_data)

        return summary_data

    def __compute_learning_curve(self, input_data: Any, output_data: Any, cv: Any) -> dict:
        """
        Computes learning-curve data for the current model with its verbosity suppressed.

        Parameters:
            input_data (Any): The input data for the learning curve.
            output_data (Any): The output data for the learning curve.
            cv (Any): The cross-validation splitter passed to ``learning_curve``.

        Returns:
            (dict): The absolute train sizes, train scores, and validation scores under the
                ``learning_curve_data_*`` keys.
        """

        params = self.__model.get_params()
        # Not every estimator exposes a `verbose` parameter (e.g. DecisionTreeClassifier, or
        # a model already wrapped in CalibratedClassifierCV by a previous fit); calling
        # set_params(verbose=...) on those raises ValueError, so only toggle it if present.
        has_verbose = "verbose" in params
        original_verbose = params.get("verbose")
        if has_verbose:
            self.__model.set_params(verbose = False)

        try:
            train_sizes_abs, train_scores, valid_scores = learning_curve(
                self.__model, input_data, output_data,
                train_sizes = np.linspace(0.1, 1.0, 5), cv = cv,
                scoring = 'accuracy', n_jobs = -1
            )
        finally:
            # Restore the caller's verbosity even if learning_curve fails.
            if has_verbose:
                self.__model.set_params(verbose = original_verbose)

        return {
            "learning_curve_data_train_sizes": train_sizes_abs,
            "learning_curve_data_train_scores": train_scores,
            "learning_curve_data_valid_scores": valid_scores
        }

    def predict(self, data: Any) -> Any:
        """
        Predicts class probabilities for the given input data.

        Parameters:
            data (Any): The input data for prediction.

        Returns:
            (Any): The result of the model's ``predict_proba``. For scikit-learn classifiers
                this is an array of shape (n_samples, n_classes), not a dict.
        """

        return self.__model.predict_proba(data)

    def get_model(self) -> BaseEstimator:
        """
        Retrieves the underlying model.

        Returns:
            (BaseEstimator): The underlying model instance. After ``fit``, this may be a
                ``CalibratedClassifierCV`` wrapping the original model.
        """

        return self.__model