Coverage for source/plotting/classification_testing_plot_responsibility_chain.py: 11% (97 statements)
# plotting/classification_testing_plot_responsibility_chain.py

# global imports
import logging
import matplotlib.pyplot as plt
import numpy as np
from sklearn.metrics import RocCurveDisplay
from matplotlib.gridspec import GridSpec

# local imports
from source.agent import ClassificationTestingStrategyHandler
from source.plotting import PlotResponsibilityChainBase

class ClassificationTestingPlotResponsibilityChain(PlotResponsibilityChainBase):
    """Responsibility-chain link that renders classification test results as a
    2x2 figure: confusion matrix, per-class metric bars, one-vs-rest ROC
    curves, and macro/weighted averages."""

    __ADDITIONAL_REPORT_LABELS = ["accuracy", "macro avg", "weighted avg"]

    def _can_plot(self, key: str) -> bool:
        """Handle only plot requests tagged with the classification testing key."""

        return key == ClassificationTestingStrategyHandler.PLOTTING_KEY

    def _plot(self, plot_data: dict) -> plt.Axes:
        """Assemble the 2x2 classification-results figure and return its axes."""

        conf_matrix = plot_data.get("confusion_matrix")
        class_report = plot_data.get("classification_report")
        accuracy = plot_data.get("accuracy")
        prediction_probabilities = plot_data.get("prediction_probabilities")
        output_data = plot_data.get("output_data")

        if conf_matrix is None or class_report is None or accuracy is None \
                or prediction_probabilities is None or output_data is None:
            logging.warning("Insufficient data for plotting classification results.")
            plt.text(0.5, 0.5, "Insufficient data for plotting",
                     ha='center', va='center', fontsize=12)
            return plt.gca()
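
        # sklearn's classification_report mixes summary rows ("accuracy",
        # "macro avg", "weighted avg") in with the per-class rows; pop them
        # out so the per-class plots below iterate over class names only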
        additional_report = {}
        for additional_label in self.__ADDITIONAL_REPORT_LABELS:
            if additional_label in class_report:
                additional_report[additional_label] = class_report.pop(additional_label)

        fig = plt.figure(figsize=self._EXPECTED_FIGURE_SIZE)
        gs = GridSpec(2, 2, figure=fig)
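        # The remaining report keys are the class names; abbreviate them to
        # three characters so the tick labels stay compact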
        classes = list(class_report.keys())
        shortened_classes_names = [class_name[:3] for class_name in classes]

        # Plot 1: Confusion Matrix as a heatmap
        ax1 = plt.subplot(gs[0, 0])
        ax1.imshow(conf_matrix, interpolation='nearest', cmap=plt.cm.YlOrRd)
        ax1.set_title(f"Confusion Matrix (Accuracy: {accuracy:.2%})")

        # Add class tick labels on both axes
        tick_marks = np.arange(len(classes))
        ax1.set_xticks(tick_marks)
        ax1.set_yticks(tick_marks)
        ax1.set_xticklabels(shortened_classes_names)
        ax1.set_yticklabels(shortened_classes_names)
        ax1.set_xlabel('Predicted label')
        ax1.set_ylabel('True label')

        # Add text annotations to show the values
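        # (half the maximum count serves as the contrast threshold: darker
        # cells get white text, lighter cells black)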
        thresh = conf_matrix.max() / 2.0
        for i in range(conf_matrix.shape[0]):
            for j in range(conf_matrix.shape[1]):
                ax1.text(j, i, format(conf_matrix[i, j], 'd'),
                         ha="center", va="center",
                         color="white" if conf_matrix[i, j] > thresh else "black")

        # Plot 2: Precision, Recall, F1 Score Bar Chart
        ax2 = plt.subplot(gs[1, 0])
        precision_scores = []
        recall_scores = []
        f1_scores = []

        for metrics_dict in class_report.values():
            precision_scores.append(metrics_dict["precision"])
            recall_scores.append(metrics_dict["recall"])
            f1_scores.append(metrics_dict["f1-score"])
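
        # Offset each metric's bars by one bar width so the three bars sit
        # side by side within each class group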
        shift = 0.2
        ax2.bar(tick_marks - shift, precision_scores, shift, label='Precision')
        ax2.bar(tick_marks, recall_scores, shift, label='Recall')
        ax2.bar(tick_marks + shift, f1_scores, shift, label='F1-score')

        ax2.set_title('Classification metrics by class')
        ax2.set_xticks(tick_marks)
        ax2.set_xticklabels(shortened_classes_names)
        ax2.set_xlabel('Classes')
        ax2.set_ylabel('Score')
        ax2.set_ylim([0, 1])
        ax2.legend()

        # Plot 3: OvR-ROC curves
        ax3 = plt.subplot(gs[0, 1])
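        # output_data is assumed to be one-hot encoded; argmax recovers each
        # sample's integer class index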
        y_true_class = np.argmax(output_data, axis=1)
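
        # Binarise the labels one class at a time (one-vs-rest) and draw a
        # ROC curve per class; the chance-level diagonal is plotted only
        # once, together with the last curve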
        for i, class_name in enumerate(classes):
            y_true_class_binary = (y_true_class == i).astype(int)
            y_score = prediction_probabilities[:, i]
            RocCurveDisplay.from_predictions(y_true_class_binary, y_score, name=class_name,
                                             ax=ax3, plot_chance_level=(i == len(classes) - 1))

        ax3.set_title('One-vs-Rest ROC curves')
        ax3.set_xlabel('False positive rate')
        ax3.set_ylabel('True positive rate')
        ax3.grid(alpha=0.3)
        ax3.legend(loc="lower right", fontsize='small')

        # Plot 4: Macro avg and weighted avg
        ax4 = plt.subplot(gs[1, 1])
        # "accuracy" maps to a scalar, not a metrics dict, so keep only the
        # dict-valued summary rows ("macro avg", "weighted avg") as labels
        additional_labels = [label for label, metrics in additional_report.items()
                             if isinstance(metrics, dict)]
        precision_scores = []
        recall_scores = []
        f1_scores = []

        for metrics in additional_report.values():
            if isinstance(metrics, dict):
                precision_scores.append(metrics['precision'])
                recall_scores.append(metrics['recall'])
                f1_scores.append(metrics['f1-score'])

        x = np.arange(len(additional_labels))
        ax4.bar(x - shift, precision_scores, shift, label='Precision')
        ax4.bar(x, recall_scores, shift, label='Recall')
        ax4.bar(x + shift, f1_scores, shift, label='F1-score')

        ax4.set_title('Macro avg and weighted avg')
        ax4.set_xticks(x)
        ax4.set_xticklabels(additional_labels)
        ax4.set_xlabel('Metrics')
        ax4.set_ylabel('Score')
        ax4.set_ylim([0, 1])
        ax4.legend()
        plt.tight_layout()

        return plt.gca()