Coverage for source/plotting/classification_testing_plot_responsibility_chain.py: 11% (97 statements)

coverage.py v7.8.0, created at 2025-06-06 12:00 +0000

# plotting/classification_testing_plot_responsibility_chain.py

# global imports
import logging
import matplotlib.pyplot as plt
import numpy as np
from sklearn.metrics import RocCurveDisplay
from matplotlib.gridspec import GridSpec

# local imports
from source.agent import ClassificationTestingStrategyHandler
from source.plotting import PlotResponsibilityChainBase

class ClassificationTestingPlotResponsibilityChain(PlotResponsibilityChainBase):
    """Responsibility chain link that plots classification testing results."""

    __ADDITIONAL_REPORT_LABELS = ["accuracy", "macro avg", "weighted avg"]

    def _can_plot(self, key: str) -> bool:
        """Check whether this link handles the given plotting key."""

        return key == ClassificationTestingStrategyHandler.PLOTTING_KEY

    def _plot(self, plot_data: dict) -> plt.Axes:
        """Plot the confusion matrix, per-class metrics, OvR-ROC curves and averaged metrics."""

        conf_matrix = plot_data.get("confusion_matrix", None)
        class_report = plot_data.get("classification_report", None)
        accuracy = plot_data.get("accuracy", None)
        prediction_probabilities = plot_data.get("prediction_probabilities", None)
        output_data = plot_data.get("output_data", None)

        if conf_matrix is None or class_report is None or accuracy is None \
           or prediction_probabilities is None or output_data is None:
            logging.warning("Insufficient data for plotting classification results.")
            plt.text(0.5, 0.5, "Insufficient data for plotting",
                     ha = 'center', va = 'center', fontsize = 12)
            return plt.gca()

        # Move the summary rows into a separate dict so class_report keeps only per-class entries
        additional_report = {}
        for additional_label in self.__ADDITIONAL_REPORT_LABELS:
            if additional_label in class_report:
                additional_report[additional_label] = class_report.pop(additional_label)

        fig = plt.figure(figsize = self._EXPECTED_FIGURE_SIZE)
        gs = GridSpec(2, 2, figure = fig)
        classes = list(class_report.keys())
        shortened_classes_names = [class_name[:3] for class_name in classes]

        # Plot 1: Confusion matrix as a heatmap
        ax1 = plt.subplot(gs[0, 0])
        ax1.imshow(conf_matrix, interpolation = 'nearest', cmap = plt.cm.YlOrRd)
        ax1.set_title(f"Confusion Matrix (Accuracy: {accuracy:.2%})")

        # Add axis tick labels
        tick_marks = np.arange(len(classes))
        ax1.set_xticks(tick_marks)
        ax1.set_yticks(tick_marks)
        ax1.set_xticklabels(shortened_classes_names)
        ax1.set_yticklabels(shortened_classes_names)
        ax1.set_xlabel('Predicted label')
        ax1.set_ylabel('True label')

        # Add text annotations to show the values
        thresh = conf_matrix.max() / 2.0
        for i in range(conf_matrix.shape[0]):
            for j in range(conf_matrix.shape[1]):
                ax1.text(j, i, format(conf_matrix[i, j], 'd'),
                         ha="center", va="center",
                         color="white" if conf_matrix[i, j] > thresh else "black")

        # Plot 2: Precision, Recall, F1 Score Bar Chart
        ax2 = plt.subplot(gs[1, 0])
        precision_scores = []
        recall_scores = []
        f1_scores = []

        for metrics_dict in class_report.values():
            precision_scores.append(metrics_dict["precision"])
            recall_scores.append(metrics_dict["recall"])
            f1_scores.append(metrics_dict["f1-score"])

        shift = 0.2
        ax2.bar(tick_marks - shift, precision_scores, shift, label = 'Precision')
        ax2.bar(tick_marks, recall_scores, shift, label = 'Recall')
        ax2.bar(tick_marks + shift, f1_scores, shift, label = 'F1-score')

        ax2.set_title('Classification metrics by class')
        ax2.set_xticks(tick_marks)
        ax2.set_xticklabels(shortened_classes_names)
        ax2.set_xlabel('Classes')
        ax2.set_ylabel('Score')
        ax2.set_ylim([0, 1])
        ax2.legend()

        # Plot 3: OvR-ROC curves
        ax3 = plt.subplot(gs[0, 1])
        y_true_class = np.argmax(output_data, axis = 1)

        # Binarize the one-hot targets per class (one-vs-rest) and draw one ROC
        # curve per class; the chance level is drawn only with the last curve
        for i, class_name in enumerate(classes):
            y_true_class_binary = (y_true_class == i).astype(int)
            y_score = prediction_probabilities[:, i]
            RocCurveDisplay.from_predictions(y_true_class_binary, y_score, name = f"{class_name}",
                                             ax = ax3, plot_chance_level = (i == len(classes) - 1))

        ax3.set_title('One-vs-Rest ROC curves')
        ax3.set_xlabel('False positive rate')
        ax3.set_ylabel('True positive rate')
        ax3.grid(alpha = 0.3)
        ax3.legend(loc = "lower right", fontsize = 'small')
        plt.tight_layout()

        # Plot 4: Macro avg and weighted avg
        ax4 = plt.subplot(gs[1, 1])
        # Keep only the summary rows that are dicts (this skips the scalar "accuracy"
        # entry), so the labels always line up with the score lists built below
        additional_labels = [label for label, metrics in additional_report.items()
                             if isinstance(metrics, dict)]
        precision_scores = []
        recall_scores = []
        f1_scores = []

        for metrics in additional_report.values():
            if isinstance(metrics, dict):
                precision_scores.append(metrics['precision'])
                recall_scores.append(metrics['recall'])
                f1_scores.append(metrics['f1-score'])

        x = np.arange(len(additional_labels))
        ax4.bar(x - shift, precision_scores, shift, label = 'Precision')
        ax4.bar(x, recall_scores, shift, label = 'Recall')
        ax4.bar(x + shift, f1_scores, shift, label = 'F1-score')

        ax4.set_title('Macro avg and weighted avg')
        ax4.set_xticks(x)
        ax4.set_xticklabels(additional_labels)
        ax4.set_xlabel('Metrics')
        ax4.set_ylabel('Score')
        ax4.set_ylim([0, 1])
        ax4.legend()
        plt.tight_layout()

        return plt.gca()
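
For orientation, the sketch below shows how the keys read by _plot could be assembled with scikit-learn and handed to this link. It is a minimal, hypothetical example rather than project code: it assumes the handler needs no constructor arguments, that _EXPECTED_FIGURE_SIZE is supplied by PlotResponsibilityChainBase, and that the class is importable from source.plotting like its base class; the targets and probabilities are synthetic stand-ins. In the real chain, the key check in _can_plot against ClassificationTestingStrategyHandler.PLOTTING_KEY decides whether this link runs; the protected _plot is called directly here only for illustration.

# Hypothetical usage sketch -- not part of the measured module. Assumes the
# handler can be built without arguments, that _EXPECTED_FIGURE_SIZE comes from
# PlotResponsibilityChainBase, and that the class is importable from
# source.plotting like its base class; all data below is synthetic.
import matplotlib.pyplot as plt
import numpy as np
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix

from source.plotting import ClassificationTestingPlotResponsibilityChain

rng = np.random.default_rng(0)
class_names = ["ham", "spam", "other"]                        # stand-in label names
probabilities = rng.dirichlet(np.ones(3), size = 60)          # fake predicted class probabilities
output_data = np.eye(3, dtype = int)[rng.integers(0, 3, 60)]  # fake one-hot test targets

y_true = np.argmax(output_data, axis = 1)
y_pred = np.argmax(probabilities, axis = 1)

# classification_report(..., output_dict = True) yields the per-class entries plus
# the "accuracy", "macro avg" and "weighted avg" rows that _plot expects
plot_data = {
    "confusion_matrix": confusion_matrix(y_true, y_pred),
    "classification_report": classification_report(y_true, y_pred,
                                                    target_names = class_names,
                                                    output_dict = True),
    "accuracy": accuracy_score(y_true, y_pred),
    "prediction_probabilities": probabilities,
    "output_data": output_data,
}

handler = ClassificationTestingPlotResponsibilityChain()
handler._plot(plot_data)   # the chain would normally dispatch here via _can_plot
plt.show()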