Coverage for source/plotting/classification_testing_plot_responsibility_chain.py: 97%

95 statements  

coverage.py v7.8.0, created at 2025-07-30 20:59 +0000

# plotting/classification_testing_plot_responsibility_chain.py

# global imports
import logging
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.gridspec import GridSpec
from sklearn.metrics import RocCurveDisplay

# local imports
from source.agent import ClassificationTestingStrategyHandler
from source.plotting import PlotResponsibilityChainBase

class ClassificationTestingPlotResponsibilityChain(PlotResponsibilityChainBase):
    """
    Implements a plotting responsibility chain for classification testing results.
    It implements the _can_plot and _plot methods to visualize confusion matrices,
    classification reports, and ROC curves.
    """

    # local constants
    __ADDITIONAL_REPORT_LABELS = ["accuracy", "macro avg", "weighted avg"]

    def _can_plot(self, key: str) -> bool:
        """
        Checks if the plot can be generated for the given key.

        Parameters:
            key (str): The key to check.

        Returns:
            (bool): True if the plot can be generated, False otherwise.
        """

        return key == ClassificationTestingStrategyHandler.PLOTTING_KEY

    def _plot(self, plot_data: dict) -> plt.Axes:
        """
        Generates the classification testing plot based on the provided data.

        Parameters:
            plot_data (dict): The data to be plotted.

        Returns:
            (plt.Axes): The axes object containing the plot.
        """

        conf_matrix = plot_data.get("confusion_matrix", None)
        class_report = plot_data.get("classification_report", None)
        prediction_probabilities = plot_data.get("prediction_probabilities", None)
        true_labels = plot_data.get("true_labels", None)

        if conf_matrix is None or class_report is None or prediction_probabilities is None or true_labels is None:
            logging.warning(f"Insufficient data for plotting results under key: {ClassificationTestingStrategyHandler.PLOTTING_KEY}.")
            plt.text(0.5, 0.5, "Insufficient data for plotting",
                     ha = 'center', va = 'center', fontsize = 12)
            return plt.gca()

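        # Pull the aggregate entries ("accuracy", "macro avg", "weighted avg") out of the report, leaving only per-class rows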
        additional_report = {}
        for additional_label in self.__ADDITIONAL_REPORT_LABELS:
            if additional_label in class_report:
                additional_report[additional_label] = class_report.pop(additional_label)

        fig = plt.figure(figsize = self._EXPECTED_FIGURE_SIZE)
        gs = GridSpec(2, 2, figure = fig)
        classes = list(class_report.keys())
        shortened_classes_names = [class_name[:3] for class_name in classes]

        # Plot 1: Confusion Matrix as a heatmap
        ax1 = plt.subplot(gs[0, 0])
        ax1.imshow(conf_matrix, interpolation = 'nearest', cmap = plt.cm.YlOrRd)
        ax1.set_title(f"Confusion Matrix (Accuracy: {additional_report['accuracy']:.2%})")

        tick_marks = np.arange(len(classes))
        ax1.set_xticks(tick_marks)
        ax1.set_yticks(tick_marks)
        ax1.set_xticklabels(shortened_classes_names)
        ax1.set_yticklabels(shortened_classes_names)
        ax1.set_xlabel('Predicted label')
        ax1.set_ylabel('True label')

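        # Annotate each cell with its count, switching to white text on dark cells for readability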
        thresh = conf_matrix.max() / 2.0
        for i in range(conf_matrix.shape[0]):
            for j in range(conf_matrix.shape[1]):
                ax1.text(j, i, format(conf_matrix[i, j], 'd'),
                         ha="center", va="center",
                         color="white" if conf_matrix[i, j] > thresh else "black")

        # Plot 2: Precision, Recall, F1 Score Bar Chart
        ax2 = plt.subplot(gs[1, 0])
        precision_scores = []
        recall_scores = []
        f1_scores = []

        for metrics_dict in class_report.values():
            precision_scores.append(metrics_dict["precision"])
            recall_scores.append(metrics_dict["recall"])
            f1_scores.append(metrics_dict["f1-score"])

        shift = 0.2
        ax2.bar(tick_marks - shift, precision_scores, shift, label = 'Precision')
        ax2.bar(tick_marks, recall_scores, shift, label = 'Recall')
        ax2.bar(tick_marks + shift, f1_scores, shift, label = 'F1-score')

        ax2.set_title('Classification metrics by class')
        ax2.set_xticks(tick_marks)
        ax2.set_xticklabels(shortened_classes_names)
        ax2.set_xlabel('Classes')
        ax2.set_ylabel('Score')
        ax2.set_ylim([0, 1])
        ax2.legend()

        # Plot 3: OvR-ROC curves
        ax3 = plt.subplot(gs[0, 1])

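        # Binarize the labels one class at a time (one-vs-rest); the chance-level line is drawn only with the last curve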
        for i, class_name in enumerate(classes):
            y_true_class_binary = (true_labels == i).astype(int)
            y_score = prediction_probabilities[:, i]
            RocCurveDisplay.from_predictions(y_true_class_binary, y_score, name = f"{class_name}",
                                             ax = ax3, plot_chance_level = (i == len(classes) - 1))

        ax3.set_title('One-vs-Rest ROC curves')
        ax3.set_xlabel('False positive rate')
        ax3.set_ylabel('True positive rate')
        ax3.grid(alpha = 0.3)
        ax3.legend(loc = "lower right", fontsize = 'small')
        plt.tight_layout()

        # Plot 4: Macro avg and weighted avg
        ax4 = plt.subplot(gs[1, 1])

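        # Skip "accuracy" (a scalar already shown in the confusion-matrix title); only the dict-valued averages are charted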
        additional_labels = list(additional_report.keys())[1:]
        precision_scores = []
        recall_scores = []
        f1_scores = []

        for metrics in additional_report.values():
            if isinstance(metrics, dict):
                precision_scores.append(metrics['precision'])
                recall_scores.append(metrics['recall'])
                f1_scores.append(metrics['f1-score'])

        x = np.arange(len(additional_labels))
        ax4.bar(x - shift, precision_scores, shift, label = 'Precision')
        ax4.bar(x, recall_scores, shift, label = 'Recall')
        ax4.bar(x + shift, f1_scores, shift, label = 'F1-score')

        ax4.set_title('Macro avg and weighted avg')
        ax4.set_xticks(x)
        ax4.set_xticklabels(additional_labels)
        ax4.set_xlabel('Metrics')
        ax4.set_ylabel('Score')
        ax4.set_ylim([0, 1])
        ax4.legend()
        plt.tight_layout()

        return plt.gca()
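Usage sketch (not part of the covered module): the snippet below shows one way this handler could be driven with scikit-learn artifacts. It assumes ClassificationTestingPlotResponsibilityChain is exported from source.plotting alongside PlotResponsibilityChainBase and that its constructor takes no arguments, and it calls the protected _can_plot/_plot methods directly purely for illustration; in normal use the base class's chain dispatch (not shown in this file) would route the key and plot_data to the handler.

# Minimal sketch, assuming the project package is importable and the handler
# can be constructed without arguments; protected methods are called directly
# only to illustrate the expected plot_data contract.
import matplotlib.pyplot as plt
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import classification_report, confusion_matrix
from sklearn.model_selection import train_test_split

from source.agent import ClassificationTestingStrategyHandler
from source.plotting import ClassificationTestingPlotResponsibilityChain

# Fit a small multi-class model to obtain realistic testing artifacts.
features, labels = make_classification(n_samples = 600, n_classes = 3,
                                        n_informative = 6, random_state = 0)
train_x, test_x, train_y, test_y = train_test_split(features, labels, random_state = 0)
model = LogisticRegression(max_iter = 1000).fit(train_x, train_y)
predictions = model.predict(test_x)

# The four keys below are exactly the ones _plot reads from plot_data.
plot_data = {
    "confusion_matrix": confusion_matrix(test_y, predictions),
    "classification_report": classification_report(test_y, predictions, output_dict = True),
    "prediction_probabilities": model.predict_proba(test_x),
    "true_labels": test_y,
}

handler = ClassificationTestingPlotResponsibilityChain()
if handler._can_plot(ClassificationTestingStrategyHandler.PLOTTING_KEY):
    handler._plot(plot_data)
    plt.show()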