Coverage for source/plotting/classification_training_plot_responsibility_chain.py: 92%

84 statements  

« prev     ^ index     » next       coverage.py v7.8.0, created at 2025-08-01 20:51 +0000

1# plotting/classification_training_plot_responsibility_chain.py 

2 

3# global imports 

4import logging 

5import matplotlib.pyplot as plt 

6from enum import Enum 

7from matplotlib.gridspec import GridSpec 

8 

9# local imports 

10from source.agent import ClassificationLearningStrategyHandler 

11from source.model import SklearnModelAdapter, TFModelAdapter 

12from source.plotting import PlotResponsibilityChainBase 

13 

class ClassificationTrainingPlotResponsibilityChain(PlotResponsibilityChainBase):
    """
    Implements a plotting responsibility chain for classification training results.
    It implements the _can_plot and _plot methods to visualize training loss and
    accuracy (TensorFlow models) or learning curves (Scikit-learn models).
    """

    class PlottingMode(Enum):
        """
        Enumeration for the different plotting modes.
        It defines the modes for TensorFlow and Scikit-learn models; the member
        values reuse the adapters' TAG constants so a mode can be parsed directly
        from a result key.
        """

        UNSUPPORTED = "unsupported"
        TF = TFModelAdapter.TAG
        SKLEARN = SklearnModelAdapter.TAG

    def __init__(self) -> None:
        """
        Class constructor. Initializes the plotting responsibility chain with an
        unsupported mode.
        """

        # NOTE(review): super().__init__() is not called here — confirm that
        # PlotResponsibilityChainBase requires no initialization of its own.
        self.__mode: ClassificationTrainingPlotResponsibilityChain.PlottingMode = ClassificationTrainingPlotResponsibilityChain.PlottingMode.UNSUPPORTED

    def _can_plot(self, key: str) -> bool:
        """
        Checks if the plot can be generated for the given key.
        As a side effect, stores the plotting mode parsed from the key so that a
        subsequent _plot call dispatches to the matching backend.

        Parameters:
            key (str): The key to check.

        Returns:
            (bool): True if the plot can be generated, False otherwise.
        """

        if TFModelAdapter.TAG in key or SklearnModelAdapter.TAG in key:
            # The adapter tag is assumed to be the last "_"-separated token of
            # the key; PlottingMode(...) raises ValueError otherwise — TODO
            # confirm the key format against the producers of these keys.
            self.__mode = ClassificationTrainingPlotResponsibilityChain.PlottingMode(key.split("_")[-1])

            return ClassificationLearningStrategyHandler.PLOTTING_KEYS[1] in key
        else:
            self.__mode = ClassificationTrainingPlotResponsibilityChain.PlottingMode.UNSUPPORTED

            return False

    def _plot(self, plot_data: dict) -> plt.Axes:
        """
        Generates the plot based on the current mode and provided data.

        Parameters:
            plot_data (dict): The data to be plotted.

        Returns:
            (plt.Axes): The axes object containing the plot.

        Raises:
            ValueError: If no supported mode was selected by a prior _can_plot call.
        """

        if self.__mode == ClassificationTrainingPlotResponsibilityChain.PlottingMode.TF:
            return self.__plot_tf(plot_data)
        elif self.__mode == ClassificationTrainingPlotResponsibilityChain.PlottingMode.SKLEARN:
            return self.__plot_sklearn(plot_data)
        else:
            raise ValueError(f"Unsupported plotting mode: {self.__mode}. Expected 'tensorflow' or 'sklearn'.")

    def __plot_insufficient_data(self) -> plt.Axes:
        """
        Logs a warning and renders a placeholder plot when required data is missing.
        Shared fallback for both the TensorFlow and the Scikit-learn plotting paths.

        Returns:
            (plt.Axes): The axes object containing the placeholder text.
        """

        # Lazy %-style arguments avoid formatting the message when the warning
        # is filtered out by the logging configuration.
        logging.warning("Insufficient data for plotting results under key: %s.",
                        ClassificationLearningStrategyHandler.PLOTTING_KEYS[1])
        plt.text(0.5, 0.5, "Insufficient data for plotting",
                 ha = 'center', va = 'center', fontsize = 12)

        return plt.gca()

    def __plot_tf(self, plot_data: dict) -> plt.Axes:
        """
        Generates the TensorFlow training plot: training loss/accuracy on top,
        validation loss/accuracy below.

        Parameters:
            plot_data (dict): The data to be plotted. Expected keys: "loss",
                "accuracy", "val_loss", "val_accuracy" (per-epoch sequences).

        Returns:
            (plt.Axes): The axes object containing the plot.
        """

        loss = plot_data.get("loss", None)
        accuracy = plot_data.get("accuracy", None)
        val_loss = plot_data.get("val_loss", None)
        val_accuracy = plot_data.get("val_accuracy", None)

        # All four series are required; fall back to the placeholder otherwise.
        if loss is None or accuracy is None or val_loss is None or val_accuracy is None:
            return self.__plot_insufficient_data()

        fig = plt.figure(figsize = self._EXPECTED_FIGURE_SIZE)
        gs = GridSpec(2, 1, figure = fig)

        # Plot 1: Training loss and accuracy
        ax1 = plt.subplot(gs[0, 0])
        ax1.set_title("Training loss and accuracy")
        ax1.plot(loss, label = 'Loss')
        ax1.plot(accuracy, label = 'Accuracy')
        ax1.set_xlabel('Epoch')
        ax1.set_ylabel('Value')
        ax1.legend(loc = 'upper left')

        # Plot 2: Validation loss and accuracy
        ax2 = plt.subplot(gs[1, 0])
        ax2.set_title("Validation loss and accuracy")
        ax2.plot(val_loss, label = 'Validation Loss')
        ax2.plot(val_accuracy, label = 'Validation Accuracy')
        ax2.set_xlabel('Epoch')
        ax2.set_ylabel('Value')
        ax2.legend(loc = 'upper left')
        plt.tight_layout()

        return plt.gca()

    def __plot_sklearn(self, plot_data: dict) -> plt.Axes:
        """
        Generates the Scikit-learn training plot: training learning curve on top,
        cross-validation learning curve (with the held-out validation score in the
        title) below.

        Parameters:
            plot_data (dict): The data to be plotted. Expected keys:
                "learning_curve_data_train_sizes",
                "learning_curve_data_train_scores" (2-D, folds on axis 1),
                "learning_curve_data_valid_scores" (2-D, folds on axis 1),
                "validation_data_score" (scalar).

        Returns:
            (plt.Axes): The axes object containing the plot.
        """

        learning_curve_data_train_sizes = plot_data.get("learning_curve_data_train_sizes", None)
        learning_curve_data_train_scores = plot_data.get("learning_curve_data_train_scores", None)
        learning_curve_data_valid_scores = plot_data.get("learning_curve_data_valid_scores", None)
        validation_data_score = plot_data.get("validation_data_score", None)

        # All four entries are required; fall back to the placeholder otherwise.
        if learning_curve_data_train_sizes is None or learning_curve_data_train_scores is None \
                or learning_curve_data_valid_scores is None or validation_data_score is None:
            return self.__plot_insufficient_data()

        fig = plt.figure(figsize = self._EXPECTED_FIGURE_SIZE)
        gs = GridSpec(2, 1, figure = fig)

        # Mean +/- one standard deviation across cross-validation folds (axis 1).
        train_scores_mean = learning_curve_data_train_scores.mean(axis = 1)
        train_scores_std = learning_curve_data_train_scores.std(axis = 1)
        valid_scores_mean = learning_curve_data_valid_scores.mean(axis = 1)
        valid_scores_std = learning_curve_data_valid_scores.std(axis = 1)

        # Plot 1: Training learning curve
        ax1 = plt.subplot(gs[0, 0])
        ax1.set_title("Training learning curve")
        ax1.fill_between(learning_curve_data_train_sizes, train_scores_mean - train_scores_std,
                         train_scores_mean + train_scores_std, alpha = 0.1, color = "r")
        ax1.plot(learning_curve_data_train_sizes, train_scores_mean, 'o-', color = "r", label = "Training score")
        ax1.set_xlabel('Training examples')
        ax1.set_ylabel('Accuracy')
        ax1.legend(loc = 'upper left')

        # Plot 2: Validation learning curve
        ax2 = plt.subplot(gs[1, 0])
        ax2.set_title(f"Validation learning curve, score: {validation_data_score:.2f}")
        ax2.fill_between(learning_curve_data_train_sizes, valid_scores_mean - valid_scores_std,
                         valid_scores_mean + valid_scores_std, alpha = 0.1, color = "g")
        ax2.plot(learning_curve_data_train_sizes, valid_scores_mean, 'o-', color = "g", label = "Cross-validation score")
        ax2.set_xlabel('Validation examples')
        ax2.set_ylabel('Accuracy')
        ax2.legend(loc = 'upper left')
        plt.tight_layout()

        return plt.gca()