Coverage for source/plotting/classification_testing_plot_responsibility_chain.py: 97%
95 statements
# plotting/classification_testing_plot_responsibility_chain.py

# global imports
import logging
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.gridspec import GridSpec
from sklearn.metrics import RocCurveDisplay

# local imports
from source.agent import ClassificationTestingStrategyHandler
from source.plotting import PlotResponsibilityChainBase

class ClassificationTestingPlotResponsibilityChain(PlotResponsibilityChainBase):
    """
    Plotting responsibility-chain handler for classification testing results.
    It overrides _can_plot and _plot to visualize the confusion matrix, the
    classification report, and one-vs-rest ROC curves.
    """

    # local constants: summary rows emitted alongside the per-class entries
    # of a classification report
    __ADDITIONAL_REPORT_LABELS = ["accuracy", "macro avg", "weighted avg"]

    def _can_plot(self, key: str) -> bool:
        """
        Checks if the plot can be generated for the given key.

        Parameters:
            key (str): The key to check.

        Returns:
            (bool): True if the plot can be generated, False otherwise.
        """

        return key == ClassificationTestingStrategyHandler.PLOTTING_KEY

    def _plot(self, plot_data: dict) -> plt.Axes:
        """
        Generates the classification testing plot based on the provided data.

        Parameters:
            plot_data (dict): The data to be plotted.

        Returns:
            (plt.Axes): The axes object containing the plot.
        """

        conf_matrix = plot_data.get("confusion_matrix", None)
        class_report = plot_data.get("classification_report", None)
        prediction_probabilities = plot_data.get("prediction_probabilities", None)
        true_labels = plot_data.get("true_labels", None)

        if (conf_matrix is None or class_report is None
                or prediction_probabilities is None or true_labels is None):
            logging.warning(f"Insufficient data for plotting results under key: {ClassificationTestingStrategyHandler.PLOTTING_KEY}.")
            plt.text(0.5, 0.5, "Insufficient data for plotting",
                     ha = 'center', va = 'center', fontsize = 12)
            return plt.gca()
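
        # split the summary rows ("accuracy", "macro avg", "weighted avg") out of
        # the report so that class_report keeps only the per-class entries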
        additional_report = {}
        for additional_label in self.__ADDITIONAL_REPORT_LABELS:
            if additional_label in class_report:
                additional_report[additional_label] = class_report.pop(additional_label)

        fig = plt.figure(figsize = self._EXPECTED_FIGURE_SIZE)
        gs = GridSpec(2, 2, figure = fig)
        classes = list(class_report.keys())
        # abbreviate class names to three characters so the tick labels stay compact
        shortened_classes_names = [class_name[:3] for class_name in classes]

        # Plot 1: Confusion Matrix as a heatmap
        ax1 = plt.subplot(gs[0, 0])
        ax1.imshow(conf_matrix, interpolation = 'nearest', cmap = plt.cm.YlOrRd)
        ax1.set_title(f"Confusion Matrix (Accuracy: {additional_report['accuracy']:.2%})")

        tick_marks = np.arange(len(classes))
        ax1.set_xticks(tick_marks)
        ax1.set_yticks(tick_marks)
        ax1.set_xticklabels(shortened_classes_names)
        ax1.set_yticklabels(shortened_classes_names)
        ax1.set_xlabel('Predicted label')
        ax1.set_ylabel('True label')
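
        # annotate each cell; the half-maximum threshold switches the text to
        # white on dark (high-count) cells for readability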
        thresh = conf_matrix.max() / 2.0
        for i in range(conf_matrix.shape[0]):
            for j in range(conf_matrix.shape[1]):
                ax1.text(j, i, format(conf_matrix[i, j], 'd'),
                         ha = "center", va = "center",
                         color = "white" if conf_matrix[i, j] > thresh else "black")

        # Plot 2: Precision, Recall, F1 Score Bar Chart
        ax2 = plt.subplot(gs[1, 0])
        precision_scores = []
        recall_scores = []
        f1_scores = []

        for metrics_dict in class_report.values():
            precision_scores.append(metrics_dict["precision"])
            recall_scores.append(metrics_dict["recall"])
            f1_scores.append(metrics_dict["f1-score"])

        shift = 0.2
        ax2.bar(tick_marks - shift, precision_scores, shift, label = 'Precision')
        ax2.bar(tick_marks, recall_scores, shift, label = 'Recall')
        ax2.bar(tick_marks + shift, f1_scores, shift, label = 'F1-score')

        ax2.set_title('Classification metrics by class')
        ax2.set_xticks(tick_marks)
        ax2.set_xticklabels(shortened_classes_names)
        ax2.set_xlabel('Classes')
        ax2.set_ylabel('Score')
        ax2.set_ylim([0, 1])
        ax2.legend()

        # Plot 3: OvR-ROC curves
        ax3 = plt.subplot(gs[0, 1])
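
        # one-vs-rest: for each class, binarize the true labels (class i vs. the
        # rest) and score them against the i-th predicted-probability column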
        for i, class_name in enumerate(classes):
            y_true_class_binary = (true_labels == i).astype(int)
            y_score = prediction_probabilities[:, i]
            RocCurveDisplay.from_predictions(y_true_class_binary, y_score, name = f"{class_name}",
                                             ax = ax3, plot_chance_level = (i == len(classes) - 1))

        ax3.set_title('One-vs-Rest ROC curves')
        ax3.set_xlabel('False positive rate')
        ax3.set_ylabel('True positive rate')
        ax3.grid(alpha = 0.3)
        ax3.legend(loc = "lower right", fontsize = 'small')

        # Plot 4: Macro avg and weighted avg
        ax4 = plt.subplot(gs[1, 1])
        # keep only the dict-valued summary rows ("accuracy" is a plain float),
        # so the labels stay aligned with the score lists built below
        additional_labels = [label for label, metrics in additional_report.items()
                             if isinstance(metrics, dict)]
        precision_scores = []
        recall_scores = []
        f1_scores = []

        for metrics in additional_report.values():
            if isinstance(metrics, dict):
                precision_scores.append(metrics['precision'])
                recall_scores.append(metrics['recall'])
                f1_scores.append(metrics['f1-score'])

        x = np.arange(len(additional_labels))
        ax4.bar(x - shift, precision_scores, shift, label = 'Precision')
        ax4.bar(x, recall_scores, shift, label = 'Recall')
        ax4.bar(x + shift, f1_scores, shift, label = 'F1-score')

        ax4.set_title('Macro avg and weighted avg')
        ax4.set_xticks(x)
        ax4.set_xticklabels(additional_labels)
        ax4.set_xlabel('Metrics')
        ax4.set_ylabel('Score')
        ax4.set_ylim([0, 1])
        ax4.legend()
        plt.tight_layout()

        return plt.gca()
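
Usage sketch (not part of the covered module): a minimal, self-contained illustration of how the plot_data dict consumed by _plot could be assembled from scikit-learn outputs. The chain.plot(...) entry point is an assumption for illustration; the actual dispatch method is defined in PlotResponsibilityChainBase and may be named differently.

from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import classification_report, confusion_matrix
from sklearn.model_selection import train_test_split

# toy three-class problem
X, y = make_classification(n_samples = 300, n_classes = 3, n_informative = 5, random_state = 0)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state = 0)
model = LogisticRegression(max_iter = 1000).fit(X_train, y_train)
y_pred = model.predict(X_test)

# keys mirror the ones read in _plot; output_dict = True yields the nested
# report dict (per-class rows plus "accuracy", "macro avg", "weighted avg")
plot_data = {
    "confusion_matrix": confusion_matrix(y_test, y_pred),
    "classification_report": classification_report(y_test, y_pred, output_dict = True),
    "prediction_probabilities": model.predict_proba(X_test),  # shape (n_samples, n_classes)
    "true_labels": y_test,  # integer labels aligned with the probability columns
}

chain = ClassificationTestingPlotResponsibilityChain()
axes = chain.plot(ClassificationTestingStrategyHandler.PLOTTING_KEY, plot_data)  # hypothetical entry point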