Coverage for source/model/model_adapters/sklearn_model_adapter.py: 90%
51 statements
coverage.py v7.8.0, created at 2025-07-30 20:59 +0000
1# model/model_adapters/sklearn_model_adapter.py
3# global imports
4import joblib
5import numpy as np
6import os
7from sklearn.base import BaseEstimator
8from sklearn.calibration import CalibratedClassifierCV
9from sklearn.model_selection import learning_curve, StratifiedKFold
10from typing import Any, Callable
12# local imports
13from source.model import ModelAdapterBase
class SklearnModelAdapter(ModelAdapterBase):
    """
    Implements a model adapter for scikit-learn models. It provides methods for loading,
    saving, printing summaries, fitting, predicting, and retrieving the model.

    Note: ``fit`` replaces the wrapped estimator with a ``CalibratedClassifierCV``
    around it when the estimator has no ``predict_proba`` method, so that
    ``predict`` can always return probabilities.
    """

    # global class constants
    TAG: str = "sklearn"

    # local constants
    __MODEL_FILE_EXTENSION: str = ".pkl"

    def __init__(self, model: BaseEstimator) -> None:
        """
        Initializes the SklearnModelAdapter with a scikit-learn model.

        Parameters:
            model (BaseEstimator): The scikit-learn model to adapt.
        """

        self.__model: BaseEstimator = model

    def load_model(self, path: str) -> None:
        """
        Loads a scikit-learn model from a file, replacing the current model.

        Parameters:
            path (str): The path to the model file.

        Raises:
            ValueError: If the path does not end with the expected model file extension.
            FileNotFoundError: If the model file does not exist at the specified path.
        """

        if not path.endswith(self.__MODEL_FILE_EXTENSION):
            raise ValueError(f"Model path must end with '{self.__MODEL_FILE_EXTENSION}'.")

        if not os.path.exists(path):
            raise FileNotFoundError(f"Model file not found: {path}")

        # SECURITY: joblib.load unpickles the file, which can execute arbitrary
        # code. Only load model files that come from a trusted source.
        self.__model = joblib.load(path)

    def save_model(self, path: str) -> None:
        """
        Saves the current scikit-learn model to a file, creating any missing
        parent directories.

        Parameters:
            path (str): The path to the model file.

        Raises:
            ValueError: If the path does not end with the expected model file extension.
        """

        if not path.endswith(self.__MODEL_FILE_EXTENSION):
            raise ValueError(f"Model path must end with '{self.__MODEL_FILE_EXTENSION}'.")

        # abspath() guarantees a non-empty dirname even for bare file names.
        os.makedirs(os.path.dirname(os.path.abspath(path)), exist_ok = True)
        joblib.dump(self.__model, path)

    def print_summary(self, print_function: Callable = print) -> None:
        """
        Prints a summary of the model's class name and its parameters (sorted by name).
        Uses the provided print function to output the summary.

        Parameters:
            print_function (Callable): The function to use for printing the summary.
        """

        print_function(f"{'-'*80}")
        print_function(f"Model Summary: {type(self.__model).__name__}")
        print_function(f"{'-'*80}")

        print_function("Model Parameters:")
        params = self.__model.get_params()
        for key, value in sorted(params.items()):
            print_function(f"  {key}: {value}")
        print_function(f"{'-'*80}")

    def fit(self, input_data: Any, output_data: Any, validation_data: Any, **kwargs) -> dict:
        """
        Fits the model to the provided input and output data.

        First computes an accuracy learning curve with 5-fold stratified
        cross-validation (model verbosity silenced). Then fits the model on the
        full data. A model without ``predict_proba`` is wrapped in a
        ``CalibratedClassifierCV`` before that final fit.

        Parameters:
            input_data (Any): The input data for fitting the model.
            output_data (Any): The output data for fitting the model.
            validation_data (Any): A tuple unpacked into the model's ``score`` method,
                e.g. ``(X_val, y_val)``. If None, no validation score is computed.
            (**kwargs): Additional keyword arguments forwarded to the model's ``fit``.

        Returns:
            (dict): The learning-curve arrays under "learning_curve_data_train_sizes",
                "learning_curve_data_train_scores" and "learning_curve_data_valid_scores",
                plus "validation_data_score" when validation data was provided.
        """

        cv = StratifiedKFold(n_splits = 5, shuffle = True, random_state = 42)
        train_sizes = np.linspace(0.1, 1.0, 5)

        # Silence the model while the learning curve performs its many fits.
        # Not every estimator has a top-level ``verbose`` parameter (including a
        # CalibratedClassifierCV wrapper applied by an earlier fit), and
        # set_params raises ValueError for unknown keys, so only toggle it when present.
        params = self.__model.get_params()
        has_verbose = "verbose" in params
        if has_verbose:
            self.__model.set_params(verbose = False)
        try:
            train_sizes_abs, train_scores, valid_scores = learning_curve(
                self.__model, input_data, output_data,
                train_sizes = train_sizes, cv = cv,
                scoring = 'accuracy', n_jobs = -1
            )
        finally:
            # Restore the caller's verbosity even if the learning curve fails.
            if has_verbose:
                self.__model.set_params(verbose = params["verbose"])

        summary_data = {
            "learning_curve_data_train_sizes": train_sizes_abs,
            "learning_curve_data_train_scores": train_scores,
            "learning_curve_data_valid_scores": valid_scores
        }

        # predict() relies on predict_proba, so calibrate models that lack it.
        if not hasattr(self.__model, "predict_proba"):
            self.__model = CalibratedClassifierCV(self.__model, cv = cv)
        self.__model.fit(input_data, output_data, **kwargs)

        if validation_data is not None:
            summary_data["validation_data_score"] = self.__model.score(*validation_data)

        return summary_data

    def predict(self, data: Any) -> np.ndarray:
        """
        Predicts class probabilities for the given input data.

        Parameters:
            data (Any): The input data for prediction.

        Returns:
            (np.ndarray): The result of the model's ``predict_proba`` for ``data``.
        """

        return self.__model.predict_proba(data)

    def get_model(self) -> BaseEstimator:
        """
        Retrieves the underlying model. After ``fit`` this may be a
        ``CalibratedClassifierCV`` wrapping the original estimator.

        Returns:
            (BaseEstimator): The underlying model instance.
        """

        return self.__model