Coverage for source/plotting/classification_testing_plot_responsibility_chain.py: 97%
107 statements
coverage.py v7.8.0, created at 2025-08-24 10:18 +0000

# plotting/classification_testing_plot_responsibility_chain.py

# global imports
import logging
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.colors import Normalize
from matplotlib.gridspec import GridSpec
from sklearn.metrics import RocCurveDisplay

# local imports
from source.agent import ClassificationTestingStrategyHandler
from source.plotting import PlotResponsibilityChainBase

class ClassificationTestingPlotResponsibilityChain(PlotResponsibilityChainBase):
    """
    Implements a plotting responsibility chain for classification testing results,
    overriding _can_plot and _plot to visualize the confusion matrix, the
    classification report, and one-vs-rest ROC curves.
    """

    # local constants
    __ADDITIONAL_REPORT_LABELS = ["accuracy", "macro avg", "weighted avg"]

    def _can_plot(self, key: str) -> bool:
        """
        Checks if the plot can be generated for the given key.

        Parameters:
            key (str): The key to check.

        Returns:
            (bool): True if the plot can be generated, False otherwise.
        """
        return key == ClassificationTestingStrategyHandler.PLOTTING_KEY

    def _plot(self, plot_data: dict) -> plt.Axes:
        """
        Generates the classification testing plot based on the provided data.

        Parameters:
            plot_data (dict): The data to be plotted.

        Returns:
            (plt.Axes): The axes object containing the plot.
        """
        conf_matrix = plot_data.get("confusion_matrix", None)
        class_report = plot_data.get("classification_report", None)
        prediction_probabilities = plot_data.get("prediction_probabilities", None)
        true_labels = plot_data.get("true_labels", None)
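
        # Expected contents (inferred from how the values are used below): conf_matrix is an
        # integer (n_classes, n_classes) array, class_report a dict in the format produced by
        # sklearn's classification_report(..., output_dict = True), prediction_probabilities a
        # (n_samples, n_classes) probability array, and true_labels integer-encoded class indices.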

        if conf_matrix is None or class_report is None or prediction_probabilities is None or true_labels is None:
            logging.warning(f"Insufficient data for plotting results under key: {ClassificationTestingStrategyHandler.PLOTTING_KEY}.")
            plt.text(0.5, 0.5, "Insufficient data for plotting",
                     ha = 'center', va = 'center', fontsize = 12)
            return plt.gca()
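
        # Separate the aggregate entries (accuracy, macro avg, weighted avg) from the report so
        # that class_report keeps only the per-class metric dictionaries used below.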
        additional_report = {}
        for additional_label in self.__ADDITIONAL_REPORT_LABELS:
            if additional_label in class_report:
                additional_report[additional_label] = class_report.pop(additional_label)

        fig = plt.figure(figsize = self._EXPECTED_FIGURE_SIZE)
        gs = GridSpec(2, 2, figure = fig)
        classes = list(class_report.keys())
        shortened_classes_names = [class_name[:3] for class_name in classes]

        # Plot 1: Confusion Matrix as a heatmap
        ax1 = plt.subplot(gs[0, 0])
        ax1.set_title(f"Confusion Matrix (Accuracy: {additional_report['accuracy']:.2%})")
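
        # Row-normalise the confusion matrix so each row (true class) sums to 1; empty rows
        # would yield NaNs, which are replaced with 0 before rounding.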
        normalized_conf_matrix = conf_matrix.astype('float') / conf_matrix.sum(axis = 1, keepdims = True)
        normalized_conf_matrix = np.round(np.nan_to_num(normalized_conf_matrix, nan = 0.0), 2)
        ax1.imshow(normalized_conf_matrix, interpolation = 'nearest', cmap = plt.cm.GnBu, vmin = 0, vmax = 1)

        tick_marks = np.arange(len(classes))
        ax1.set_xticks(tick_marks)
        ax1.set_yticks(tick_marks)
        ax1.set_xticklabels(shortened_classes_names)
        ax1.set_yticklabels(shortened_classes_names)
        ax1.set_xlabel('Predicted label')
        ax1.set_ylabel('True label')

        for i in range(conf_matrix.shape[0]):
            for j in range(conf_matrix.shape[1]):
                color = "white" if normalized_conf_matrix[i, j] > 0.5 else "black"
                ax1.text(j, i - 0.1, format(conf_matrix[i, j], 'd'),
                         ha = "center", va = "center", fontsize = 10, weight = 'bold', color = color)
                ax1.text(j, i + 0.15, f'{normalized_conf_matrix[i, j]:.2f}',
                         ha = "center", va = "center", fontsize = 8, color = color)

        # Plot 2: Precision, Recall, F1 Score Bar Chart
        ax2 = plt.subplot(gs[1, 0])
        precision_scores = []
        recall_scores = []
        f1_scores = []

        for metrics_dict in class_report.values():
            precision_scores.append(metrics_dict["precision"])
            recall_scores.append(metrics_dict["recall"])
            f1_scores.append(metrics_dict["f1-score"])

        shift = 0.2
        precision_bars = ax2.bar(tick_marks - shift, precision_scores, shift, label = 'Precision')
        recall_bars = ax2.bar(tick_marks, recall_scores, shift, label = 'Recall')
        f1_bars = ax2.bar(tick_marks + shift, f1_scores, shift, label = 'F1-score')
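
        # Annotate each bar with its score; the label sits just above the bar, or just below the
        # top of bars taller than 0.9 so the rotated text stays inside the axes.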
        for i, (precision_bar, recall_bar, f1_bar) in enumerate(zip(precision_bars, recall_bars, f1_bars)):
            ax2.text(precision_bar.get_x() + (precision_bar.get_width() / 2),
                     precision_bar.get_height() + 0.01 if precision_bar.get_height() < 0.9 else \
                     precision_bar.get_height() - 0.01, f'{precision_scores[i]:.3f}',
                     ha = 'center', va = 'bottom' if precision_bar.get_height() < 0.9 else 'top', rotation = 90,
                     fontsize = 8, weight = 'bold')
            ax2.text(recall_bar.get_x() + (recall_bar.get_width() / 2),
                     recall_bar.get_height() + 0.01 if recall_bar.get_height() < 0.9 else \
                     recall_bar.get_height() - 0.01, f'{recall_scores[i]:.3f}',
                     ha = 'center', va = 'bottom' if recall_bar.get_height() < 0.9 else 'top', rotation = 90,
                     fontsize = 8, weight = 'bold')
            ax2.text(f1_bar.get_x() + (f1_bar.get_width() / 2),
                     f1_bar.get_height() + 0.01 if f1_bar.get_height() < 0.9 else \
                     f1_bar.get_height() - 0.01, f'{f1_scores[i]:.3f}',
                     ha = 'center', va = 'bottom' if f1_bar.get_height() < 0.9 else 'top', rotation = 90,
                     fontsize = 8, weight = 'bold')

        ax2.set_title('Classification metrics by class')
        ax2.set_xticks(tick_marks)
        ax2.set_xticklabels(shortened_classes_names)
        ax2.set_xlabel('Classes')
        ax2.set_ylabel('Score')
        ax2.set_ylim([0, 1])
        ax2.legend(fontsize = 'x-small')

        # Plot 3: OvR-ROC curves
        ax3 = plt.subplot(gs[0, 1])
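
        # One-vs-rest ROC per class: binarise the true labels against class index i and use the
        # predicted probability of class i as the score; the chance-level diagonal is drawn only
        # once, together with the last curve.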
        for i, class_name in enumerate(classes):
            y_true_class_binary = (true_labels == i).astype(int)
            y_score = prediction_probabilities[:, i]
            RocCurveDisplay.from_predictions(y_true_class_binary, y_score, name = f"{class_name}",
                                             ax = ax3, plot_chance_level = (i == len(classes) - 1))

        ax3.set_title('One-vs-Rest ROC curves')
        ax3.set_xlabel('False positive rate')
        ax3.set_ylabel('True positive rate')
        ax3.grid(alpha = 0.3)
        ax3.legend(loc = "lower right", fontsize = 'x-small')
        plt.tight_layout()

        # Plot 4: Macro avg and weighted avg
        ax4 = plt.subplot(gs[1, 1])
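
        # The leading "accuracy" entry of additional_report is a scalar, so it is skipped in the
        # tick labels; bars are only built from the macro avg and weighted avg dictionaries.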
        additional_labels = list(additional_report.keys())[1:]
        precision_scores = []
        recall_scores = []
        f1_scores = []

        for metrics in additional_report.values():
            if isinstance(metrics, dict):
                precision_scores.append(metrics['precision'])
                recall_scores.append(metrics['recall'])
                f1_scores.append(metrics['f1-score'])

        x = np.arange(len(additional_labels))
        precision_bars = ax4.bar(x - shift, precision_scores, shift, label = 'Precision')
        recall_bars = ax4.bar(x, recall_scores, shift, label = 'Recall')
        f1_bars = ax4.bar(x + shift, f1_scores, shift, label = 'F1-score')

        for i, (precision_bar, recall_bar, f1_bar) in enumerate(zip(precision_bars, recall_bars, f1_bars)):
            ax4.text(precision_bar.get_x() + (precision_bar.get_width() / 2),
                     precision_bar.get_height() + 0.01 if precision_bar.get_height() < 0.9 else \
                     precision_bar.get_height() - 0.01, f'{precision_scores[i]:.3f}',
                     ha = 'center', va = 'bottom' if precision_bar.get_height() < 0.9 else 'top', rotation = 90,
                     fontsize = 8, weight = 'bold')
            ax4.text(recall_bar.get_x() + (recall_bar.get_width() / 2),
                     recall_bar.get_height() + 0.01 if recall_bar.get_height() < 0.9 else \
                     recall_bar.get_height() - 0.01, f'{recall_scores[i]:.3f}',
                     ha = 'center', va = 'bottom' if recall_bar.get_height() < 0.9 else 'top', rotation = 90,
                     fontsize = 8, weight = 'bold')
            ax4.text(f1_bar.get_x() + (f1_bar.get_width() / 2),
                     f1_bar.get_height() + 0.01 if f1_bar.get_height() < 0.9 else \
                     f1_bar.get_height() - 0.01, f'{f1_scores[i]:.3f}',
                     ha = 'center', va = 'bottom' if f1_bar.get_height() < 0.9 else 'top', rotation = 90,
                     fontsize = 8, weight = 'bold')

        ax4.set_title('Macro avg and weighted avg')
        ax4.set_xticks(x)
        ax4.set_xticklabels(additional_labels)
        ax4.set_xlabel('Metrics')
        ax4.set_ylabel('Score')
        ax4.set_ylim([0, 1])
        ax4.legend(fontsize = 'x-small')
        plt.tight_layout()

        return plt.gca()
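
A minimal usage sketch for context (assumptions: the chain's public entry point, called plot below, and the handler construction are not shown in this file and are illustrative only; the plot_data keys mirror those read in _plot):

# Hypothetical usage sketch -- the entry point name ("plot") is an assumption, not taken
# from this report; the plot_data keys match those consumed by _plot above.
import numpy as np
from sklearn.metrics import classification_report, confusion_matrix

from source.agent import ClassificationTestingStrategyHandler
from source.plotting import ClassificationTestingPlotResponsibilityChain

true_labels = np.array([0, 1, 2, 1, 0])          # integer-encoded class indices
predicted_labels = np.array([0, 1, 1, 1, 0])
prediction_probabilities = np.array([[0.80, 0.10, 0.10],
                                     [0.20, 0.70, 0.10],
                                     [0.10, 0.60, 0.30],
                                     [0.15, 0.80, 0.05],
                                     [0.90, 0.05, 0.05]])

plot_data = {
    "confusion_matrix": confusion_matrix(true_labels, predicted_labels),
    "classification_report": classification_report(true_labels, predicted_labels, output_dict = True),
    "prediction_probabilities": prediction_probabilities,
    "true_labels": true_labels,
}

chain = ClassificationTestingPlotResponsibilityChain()
axes = chain.plot(ClassificationTestingStrategyHandler.PLOTTING_KEY, plot_data)  # assumed entry point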