Coverage for source/model/model_adapters/sklearn_model_adapter.py: 90%

51 statements  

« prev     ^ index     » next       coverage.py v7.8.0, created at 2025-08-01 20:51 +0000

1# model/model_adapters/sklearn_model_adapter.py 

2 

3# global imports 

4import joblib 

5import numpy as np 

6import os 

7from sklearn.base import BaseEstimator 

8from sklearn.calibration import CalibratedClassifierCV 

9from sklearn.model_selection import learning_curve, StratifiedKFold 

10from typing import Any, Callable 

11 

12# local imports 

13from source.model import ModelAdapterBase 

14 

class SklearnModelAdapter(ModelAdapterBase):
    """
    Implements a model adapter for scikit-learn models. It provides methods for loading,
    saving, printing summaries, fitting, predicting, and retrieving the model.
    """

    # global class constants
    TAG: str = "sklearn"

    # local constants
    __MODEL_FILE_EXTENSION: str = ".pkl"

    def __init__(self, model: BaseEstimator) -> None:
        """
        Initializes the SklearnModelAdapter with a scikit-learn model.

        Parameters:
            model (BaseEstimator): The scikit-learn model to adapt.
        """

        self.__model: BaseEstimator = model

    def load_model(self, path: str) -> None:
        """
        Loads a scikit-learn model from a file and replaces the adapted model with it.

        Parameters:
            path (str): The path to the model file.

        Raises:
            ValueError: If the path does not end with the expected model file extension.
            FileNotFoundError: If the model file does not exist at the specified path.
        """

        if not path.endswith(self.__MODEL_FILE_EXTENSION):
            raise ValueError(f"Model path must end with '{self.__MODEL_FILE_EXTENSION}'.")

        if not os.path.exists(path):
            raise FileNotFoundError(f"Model file not found: {path}")

        # SECURITY: joblib.load unpickles the file, which can execute arbitrary code.
        # Only load model files that come from a trusted source.
        self.__model = joblib.load(path)

    def save_model(self, path: str) -> None:
        """
        Saves a scikit-learn model to a file. Missing parent directories are created.

        Parameters:
            path (str): The path to the model file.

        Raises:
            ValueError: If the path does not end with the expected model file extension.
        """

        if not path.endswith(self.__MODEL_FILE_EXTENSION):
            raise ValueError(f"Model path must end with '{self.__MODEL_FILE_EXTENSION}'.")

        # abspath guarantees a non-empty dirname even for a bare file name
        os.makedirs(os.path.dirname(os.path.abspath(path)), exist_ok = True)
        joblib.dump(self.__model, path)

    def print_summary(self, print_function: Callable = print) -> None:
        """
        Prints a summary of the model's name and parameters.
        Uses the provided print function to output the summary.

        Parameters:
            print_function (Callable): The function to use for printing the summary.
        """

        separator = "-" * 80

        print_function(separator)
        print_function(f"Model Summary: {type(self.__model).__name__}")
        print_function(separator)

        print_function("Model Parameters:")
        params = self.__model.get_params()
        for key, value in sorted(params.items()):
            print_function(f" {key}: {value}")
        print_function(separator)

    def fit(self, input_data: Any, output_data: Any, validation_data: Any, **kwargs) -> dict:
        """
        Fits the model to the provided input and output data.

        Before the final fit, a learning curve is computed with a 5-fold stratified
        cross validation using accuracy as the score. If the model does not offer
        'predict_proba', it is replaced by a CalibratedClassifierCV wrapping it, so
        that the adapter holds (and get_model returns) the wrapper afterwards.

        Parameters:
            input_data (Any): The input data for fitting the model.
            output_data (Any): The output data for fitting the model.
            validation_data (Any): The validation data for evaluating the model. It is
                unpacked into the positional arguments of the model's 'score' method.
            (**kwargs): Additional keyword arguments passed on to the model's 'fit' method.

        Returns:
            (dict): A dictionary containing the results of the fitting process:
                the learning curve train sizes, train scores, validation scores
                and the score on the validation data.
        """

        # Not every estimator declares a 'verbose' parameter and set_params raises a
        # ValueError for undeclared parameters, so only silence estimators that have it.
        params = self.__model.get_params()
        has_verbose = "verbose" in params
        is_verbose = params.get('verbose', False)
        if has_verbose:
            self.__model.set_params(verbose = False)

        cv = StratifiedKFold(n_splits = 5, shuffle = True, random_state = 42)
        train_sizes = np.linspace(0.1, 1.0, 5)
        try:
            train_sizes_abs, train_scores, valid_scores = learning_curve(
                self.__model, input_data, output_data,
                train_sizes = train_sizes, cv = cv,
                scoring = 'accuracy', n_jobs = -1
            )
        finally:
            # restore the caller's verbosity even if the learning curve fails
            if has_verbose:
                self.__model.set_params(verbose = is_verbose)

        summary_data = {
            "learning_curve_data_train_sizes": train_sizes_abs,
            "learning_curve_data_train_scores": train_scores,
            "learning_curve_data_valid_scores": valid_scores
        }

        # predict() relies on 'predict_proba', so wrap models that do not provide it
        if not hasattr(self.__model, "predict_proba"):
            self.__model = CalibratedClassifierCV(self.__model, cv = cv)
        self.__model.fit(input_data, output_data, **kwargs)
        summary_data["validation_data_score"] = self.__model.score(*validation_data)

        return summary_data

    def predict(self, data: Any) -> dict:
        """
        Predicts probabilities for the output for the given input data.

        Parameters:
            data (Any): The input data for prediction.

        Returns:
            The result of the model's 'predict_proba' call for the given data.
            NOTE(review): the return annotation says dict and is kept for interface
            compatibility; scikit-learn estimators presumably return an array of
            class probabilities here - confirm against ModelAdapterBase.
        """

        return self.__model.predict_proba(data)

    def get_model(self) -> BaseEstimator:
        """
        Retrieves the underlying model.

        Returns:
            (BaseEstimator): The underlying model instance.
        """

        return self.__model