Coverage for source/plotting/classification_testing_plot_responsibility_chain.py: 97%

107 statements  

coverage.py v7.8.0, created at 2025-09-14 18:08 +0000

# plotting/classification_testing_plot_responsibility_chain.py

# global imports
import logging
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.colors import Normalize
from matplotlib.gridspec import GridSpec
from sklearn.metrics import RocCurveDisplay

# local imports
from source.agent import ClassificationTestingStrategyHandler
from source.plotting import PlotResponsibilityChainBase

class ClassificationTestingPlotResponsibilityChain(PlotResponsibilityChainBase):
    """
    Implements a plotting responsibility chain for classification testing results.
    It implements the _can_plot and _plot methods to visualize confusion matrices,
    classification reports, and ROC curves.
    """

    # local constants
    __ADDITIONAL_REPORT_LABELS = ["accuracy", "macro avg", "weighted avg"]

    def _can_plot(self, key: str) -> bool:
        """
        Checks if the plot can be generated for the given key.

        Parameters:
            key (str): The key to check.

        Returns:
            (bool): True if the plot can be generated, False otherwise.
        """

        return key == ClassificationTestingStrategyHandler.PLOTTING_KEY

    def _plot(self, plot_data: dict) -> plt.Axes:
        """
        Generates the classification testing plot based on the provided data.

        Parameters:
            plot_data (dict): The data to be plotted.

        Returns:
            (plt.Axes): The axes object containing the plot.
        """

        conf_matrix = plot_data.get("confusion_matrix", None)
        class_report = plot_data.get("classification_report", None)
        prediction_probabilities = plot_data.get("prediction_probabilities", None)
        true_labels = plot_data.get("true_labels", None)

        if conf_matrix is None or class_report is None or prediction_probabilities is None or true_labels is None:
            logging.warning(f"Insufficient data for plotting results under key: {ClassificationTestingStrategyHandler.PLOTTING_KEY}.")
            plt.text(0.5, 0.5, "Insufficient data for plotting",
                     ha = 'center', va = 'center', fontsize = 12)
            return plt.gca()

        additional_report = {}
        for additional_label in self.__ADDITIONAL_REPORT_LABELS:
            if additional_label in class_report:
                additional_report[additional_label] = class_report.pop(additional_label)

        fig = plt.figure(figsize = self._EXPECTED_FIGURE_SIZE)
        gs = GridSpec(2, 2, figure = fig)
        classes = list(class_report.keys())
        shortened_classes_names = [class_name[:3] for class_name in classes]

        # Plot 1: Confusion Matrix as a heatmap
        ax1 = plt.subplot(gs[0, 0])
        ax1.set_title(f"Confusion Matrix (Accuracy: {additional_report['accuracy']:.2%})")

        normalized_conf_matrix = conf_matrix.astype('float') / conf_matrix.sum(axis = 1, keepdims = True)
        normalized_conf_matrix = np.round(np.nan_to_num(normalized_conf_matrix, nan = 0.0), 2)
        ax1.imshow(normalized_conf_matrix, interpolation = 'nearest', cmap = plt.cm.GnBu, vmin = 0, vmax = 1)

        tick_marks = np.arange(len(classes))
        ax1.set_xticks(tick_marks)
        ax1.set_yticks(tick_marks)
        ax1.set_xticklabels(shortened_classes_names)
        ax1.set_yticklabels(shortened_classes_names)
        ax1.set_xlabel('Predicted label')
        ax1.set_ylabel('True label')

        for i in range(conf_matrix.shape[0]):
            for j in range(conf_matrix.shape[1]):
                color = "white" if normalized_conf_matrix[i, j] > 0.5 else "black"
                ax1.text(j, i - 0.1, format(conf_matrix[i, j], 'd'),
                         ha = "center", va = "center", fontsize = 10, weight = 'bold', color = color)
                ax1.text(j, i + 0.15, f'{normalized_conf_matrix[i, j]:.2f}',
                         ha = "center", va = "center", fontsize = 8, color = color)

        # Plot 2: Precision, Recall, F1 Score Bar Chart
        ax2 = plt.subplot(gs[1, 0])
        precision_scores = []
        recall_scores = []
        f1_scores = []

        for metrics_dict in class_report.values():
            precision_scores.append(metrics_dict["precision"])
            recall_scores.append(metrics_dict["recall"])
            f1_scores.append(metrics_dict["f1-score"])

        shift = 0.2
        precision_bars = ax2.bar(tick_marks - shift, precision_scores, shift, label = 'Precision')
        recall_bars = ax2.bar(tick_marks, recall_scores, shift, label = 'Recall')
        f1_bars = ax2.bar(tick_marks + shift, f1_scores, shift, label = 'F1-score')

        for i, (precision_bar, recall_bar, f1_bar) in enumerate(zip(precision_bars, recall_bars, f1_bars)):
            ax2.text(precision_bar.get_x() + (precision_bar.get_width() / 2),
                     precision_bar.get_height() + 0.01 if precision_bar.get_height() < 0.9 else \
                     precision_bar.get_height() - 0.01, f'{precision_scores[i]:.3f}',
                     ha = 'center', va = 'bottom' if precision_bar.get_height() < 0.9 else 'top', rotation = 90,
                     fontsize = 8, weight = 'bold')
            ax2.text(recall_bar.get_x() + (recall_bar.get_width() / 2),
                     recall_bar.get_height() + 0.01 if recall_bar.get_height() < 0.9 else \
                     recall_bar.get_height() - 0.01, f'{recall_scores[i]:.3f}',
                     ha = 'center', va = 'bottom' if recall_bar.get_height() < 0.9 else 'top', rotation = 90,
                     fontsize = 8, weight = 'bold')
            ax2.text(f1_bar.get_x() + (f1_bar.get_width() / 2),
                     f1_bar.get_height() + 0.01 if f1_bar.get_height() < 0.9 else \
                     f1_bar.get_height() - 0.01, f'{f1_scores[i]:.3f}',
                     ha = 'center', va = 'bottom' if f1_bar.get_height() < 0.9 else 'top', rotation = 90,
                     fontsize = 8, weight = 'bold')

        ax2.set_title('Classification metrics by class')
        ax2.set_xticks(tick_marks)
        ax2.set_xticklabels(shortened_classes_names)
        ax2.set_xlabel('Classes')
        ax2.set_ylabel('Score')
        ax2.set_ylim([0, 1])
        ax2.legend(fontsize = 'x-small')

        # Plot 3: OvR-ROC curves
        ax3 = plt.subplot(gs[0, 1])

        for i, class_name in enumerate(classes):
            y_true_class_binary = (true_labels == i).astype(int)
            y_score = prediction_probabilities[:, i]
            RocCurveDisplay.from_predictions(y_true_class_binary, y_score, name = f"{class_name}",
                                             ax = ax3, plot_chance_level = (i == len(classes) - 1))

        ax3.set_title('One-vs-Rest ROC curves')
        ax3.set_xlabel('False positive rate')
        ax3.set_ylabel('True positive rate')
        ax3.grid(alpha = 0.3)
        ax3.legend(loc = "lower right", fontsize = 'x-small')
        plt.tight_layout()

        # Plot 4: Macro avg and weighted avg
        ax4 = plt.subplot(gs[1, 1])
        additional_labels = list(additional_report.keys())[1:]
        precision_scores = []
        recall_scores = []
        f1_scores = []

        for metrics in additional_report.values():
            if isinstance(metrics, dict):
                precision_scores.append(metrics['precision'])
                recall_scores.append(metrics['recall'])
                f1_scores.append(metrics['f1-score'])

        x = np.arange(len(additional_labels))
        precision_bars = ax4.bar(x - shift, precision_scores, shift, label = 'Precision')
        recall_bars = ax4.bar(x, recall_scores, shift, label = 'Recall')
        f1_bars = ax4.bar(x + shift, f1_scores, shift, label = 'F1-score')

        for i, (precision_bar, recall_bar, f1_bar) in enumerate(zip(precision_bars, recall_bars, f1_bars)):
            ax4.text(precision_bar.get_x() + (precision_bar.get_width() / 2),
                     precision_bar.get_height() + 0.01 if precision_bar.get_height() < 0.9 else \
                     precision_bar.get_height() - 0.01, f'{precision_scores[i]:.3f}',
                     ha = 'center', va = 'bottom' if precision_bar.get_height() < 0.9 else 'top', rotation = 90,
                     fontsize = 8, weight = 'bold')
            ax4.text(recall_bar.get_x() + (recall_bar.get_width() / 2),
                     recall_bar.get_height() + 0.01 if recall_bar.get_height() < 0.9 else \
                     recall_bar.get_height() - 0.01, f'{recall_scores[i]:.3f}',
                     ha = 'center', va = 'bottom' if recall_bar.get_height() < 0.9 else 'top', rotation = 90,
                     fontsize = 8, weight = 'bold')
            ax4.text(f1_bar.get_x() + (f1_bar.get_width() / 2),
                     f1_bar.get_height() + 0.01 if f1_bar.get_height() < 0.9 else \
                     f1_bar.get_height() - 0.01, f'{f1_scores[i]:.3f}',
                     ha = 'center', va = 'bottom' if f1_bar.get_height() < 0.9 else 'top', rotation = 90,
                     fontsize = 8, weight = 'bold')

        ax4.set_title('Macro avg and weighted avg')
        ax4.set_xticks(x)
        ax4.set_xticklabels(additional_labels)
        ax4.set_xlabel('Metrics')
        ax4.set_ylabel('Score')
        ax4.set_ylim([0, 1])
        ax4.legend(fontsize = 'x-small')
        plt.tight_layout()

        return plt.gca()
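
The sketch below (not part of the covered module) shows one way the plot_data dict consumed by _plot could be assembled. It only relies on the keys the method reads ("confusion_matrix", "classification_report", "prediction_probabilities", "true_labels"); the Iris dataset, LogisticRegression model, and scikit-learn helpers are illustrative assumptions, and how the chain ultimately receives this dict (via PlotResponsibilityChainBase or ClassificationTestingStrategyHandler) is not shown in this file.

# Hypothetical data-assembly sketch, assuming a scikit-learn workflow.
import numpy as np
from sklearn.datasets import load_iris
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import classification_report, confusion_matrix
from sklearn.model_selection import train_test_split

iris = load_iris()
X_train, X_test, y_train, y_test = train_test_split(iris.data, iris.target, random_state = 0)

model = LogisticRegression(max_iter = 1000).fit(X_train, y_train)
predictions = model.predict(X_test)

# Keys mirror what _plot expects; classification_report is produced as a dict so
# that per-class entries plus "accuracy", "macro avg", and "weighted avg" are present.
plot_data = {
    "confusion_matrix": confusion_matrix(y_test, predictions),
    "classification_report": classification_report(y_test, predictions,
                                                   target_names = list(iris.target_names),
                                                   output_dict = True),
    "prediction_probabilities": model.predict_proba(X_test),
    "true_labels": y_test,
}

# plot_data would then be dispatched under ClassificationTestingStrategyHandler.PLOTTING_KEY;
# the exact entry point on the chain is defined in PlotResponsibilityChainBase, not here.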