Coverage for source/model/model_adapters/sklearn_model_adapter.py: 91%

54 statements  

« prev     ^ index     » next       coverage.py v7.8.0, created at 2025-09-29 20:04 +0000

1# model/model_adapters/sklearn_model_adapter.py 

2 

3# global imports 

4import joblib 

5import numpy as np 

6import os 

7from sklearn.base import BaseEstimator 

8from sklearn.calibration import CalibratedClassifierCV 

9from sklearn.model_selection import learning_curve, StratifiedKFold 

10from typing import Any, Callable 

11 

12# local imports 

13from source.model import ModelAdapterBase 

14 

class SklearnModelAdapter(ModelAdapterBase):
    """
    Implements a model adapter for scikit-learn models. It provides methods for loading,
    saving, printing summaries, fitting, predicting, and retrieving the model.

    When fitted, a model that lacks a `predict_proba` method is wrapped in a
    `CalibratedClassifierCV`, so that `predict` can always return class probabilities.
    """

    # global class constants
    TAG: str = "sklearn"

    # local constants
    __MODEL_FILE_EXTENSION: str = ".pkl"
    # number of stratified folds used for both the learning curve and calibration
    __CV_SPLITS: int = 5
    # fixed seed so the shuffled fold assignment is reproducible between runs
    __CV_RANDOM_STATE: int = 42
    # number of training-set sizes (evenly spaced from 10% to 100%) on the learning curve
    __LEARNING_CURVE_POINTS: int = 5
    # scoring metric passed to sklearn's learning_curve
    __LEARNING_CURVE_SCORING: str = "accuracy"

    def __init__(self, model: BaseEstimator, should_compute_learning_curve: bool = True) -> None:
        """
        Initializes the SklearnModelAdapter with a scikit-learn model.

        Parameters:
            model (BaseEstimator): The scikit-learn model to adapt.
            should_compute_learning_curve (bool): Flag indicating whether `fit` should also
                compute a cross-validated learning curve. Defaults to True.
        """

        self.__model: BaseEstimator = model
        self.__should_compute_learning_curve: bool = should_compute_learning_curve

    def load_model(self, path: str) -> None:
        """
        Loads a scikit-learn model from a file, replacing the currently held model.

        Parameters:
            path (str): The path to the model file.

        Raises:
            ValueError: If the path does not end with the expected model file extension.
            FileNotFoundError: If no regular file exists at the specified path.
        """

        if not path.endswith(self.__MODEL_FILE_EXTENSION):
            raise ValueError(f"Model path must end with '{self.__MODEL_FILE_EXTENSION}'.")

        # isfile (not exists) so a directory named "*.pkl" is reported as missing
        # instead of failing later inside joblib.load.
        if not os.path.isfile(path):
            raise FileNotFoundError(f"Model file not found: {path}")

        # SECURITY: joblib.load unpickles the file, which can execute arbitrary code.
        # Only load model files that come from a trusted source.
        self.__model = joblib.load(path)

    def save_model(self, path: str) -> None:
        """
        Saves the current scikit-learn model to a file, creating parent directories as needed.

        Parameters:
            path (str): The path to the model file.

        Raises:
            ValueError: If the path does not end with the expected model file extension.
        """

        if not path.endswith(self.__MODEL_FILE_EXTENSION):
            raise ValueError(f"Model path must end with '{self.__MODEL_FILE_EXTENSION}'.")

        os.makedirs(os.path.dirname(os.path.abspath(path)), exist_ok = True)
        joblib.dump(self.__model, path)

    def print_summary(self, print_function: Callable = print) -> None:
        """
        Prints a summary of the model's class name and its parameters (sorted by name).
        Uses the provided print function to output the summary.

        Parameters:
            print_function (Callable): The function to use for printing the summary.
                Defaults to the builtin `print`.
        """

        print_function(f"{'-'*80}")
        print_function(f"Model Summary: {type(self.__model).__name__}")
        print_function(f"{'-'*80}")

        print_function("Model Parameters:")
        params = self.__model.get_params()
        for key, value in sorted(params.items()):
            print_function(f" {key}: {value}")
        print_function(f"{'-'*80}")

    def fit(self, input_data: Any, output_data: Any, validation_data: Any, **kwargs) -> dict:
        """
        Fits the model to the provided input and output data.

        If the model has no `predict_proba` method, it is first wrapped in a
        `CalibratedClassifierCV`, and the wrapper replaces the adapted model.

        Parameters:
            input_data (Any): The input data for fitting the model.
            output_data (Any): The output (target) data for fitting the model.
            validation_data (Any): A pair unpacked into the model's `score` method
                (presumably `(X_val, y_val)`). When None, validation scoring is skipped.
            (**kwargs): Additional keyword arguments forwarded to the model's `fit`.

        Returns:
            (dict): The results of the fitting process. When the learning curve is enabled,
                it contains "learning_curve_data_train_sizes",
                "learning_curve_data_train_scores" and "learning_curve_data_valid_scores".
                When validation data is given, it contains "validation_data_score".
        """

        cv = StratifiedKFold(
            n_splits = self.__CV_SPLITS, shuffle = True, random_state = self.__CV_RANDOM_STATE
        )

        summary_data = {}
        if self.__should_compute_learning_curve:
            summary_data.update(self.__compute_learning_curve(input_data, output_data, cv))

        # predict() relies on predict_proba, so calibrate models that do not provide it.
        if not hasattr(self.__model, "predict_proba"):
            self.__model = CalibratedClassifierCV(self.__model, cv = cv)
        self.__model.fit(input_data, output_data, **kwargs)

        if validation_data is not None:
            summary_data["validation_data_score"] = self.__model.score(*validation_data)

        return summary_data

    def __compute_learning_curve(self, input_data: Any, output_data: Any, cv: Any) -> dict:
        """
        Computes a cross-validated learning curve for the current model.

        The model's `verbose` parameter, if it has one, is forced off for the duration of
        the computation and always restored afterwards, even if the computation raises.

        Parameters:
            input_data (Any): The input data.
            output_data (Any): The output (target) data.
            cv (Any): The cross-validation splitter passed to `learning_curve`.

        Returns:
            (dict): The absolute training sizes, training scores and validation scores
                under the "learning_curve_data_*" keys.
        """

        # Only touch `verbose` when the estimator actually exposes it: set_params raises
        # ValueError for unknown parameters (e.g. KNeighborsClassifier, or a model that a
        # previous fit call already wrapped in CalibratedClassifierCV).
        own_params = self.__model.get_params(deep = False)
        has_verbose = "verbose" in own_params
        if has_verbose:
            self.__model.set_params(verbose = False)

        try:
            train_sizes_abs, train_scores, valid_scores = learning_curve(
                self.__model, input_data, output_data,
                train_sizes = np.linspace(0.1, 1.0, self.__LEARNING_CURVE_POINTS),
                cv = cv,
                scoring = self.__LEARNING_CURVE_SCORING,
                n_jobs = -1
            )
        finally:
            if has_verbose:
                self.__model.set_params(verbose = own_params["verbose"])

        return {
            "learning_curve_data_train_sizes": train_sizes_abs,
            "learning_curve_data_train_scores": train_scores,
            "learning_curve_data_valid_scores": valid_scores
        }

    def predict(self, data: Any) -> np.ndarray:
        """
        Predicts class probabilities for the given input data.

        Parameters:
            data (Any): The input data for prediction.

        Returns:
            (np.ndarray): The result of the model's `predict_proba` on `data`.
        """

        return self.__model.predict_proba(data)

    def get_model(self) -> BaseEstimator:
        """
        Retrieves the underlying model.

        After `fit`, this may be a `CalibratedClassifierCV` wrapping the original model.

        Returns:
            (BaseEstimator): The underlying model instance.
        """

        return self.__model