Coverage for source/plotting/classification_training_plot_responsibility_chain.py: 92%
84 statements
« prev ^ index » next coverage.py v7.8.0, created at 2025-08-01 20:51 +0000
1# plotting/classification_training_plot_responsibility_chain.py
3# global imports
4import logging
5import matplotlib.pyplot as plt
6from enum import Enum
7from matplotlib.gridspec import GridSpec
9# local imports
10from source.agent import ClassificationLearningStrategyHandler
11from source.model import SklearnModelAdapter, TFModelAdapter
12from source.plotting import PlotResponsibilityChainBase
class ClassificationTrainingPlotResponsibilityChain(PlotResponsibilityChainBase):
    """
    Implements a plotting responsibility chain for classification training results.
    It implements the _can_plot and _plot methods to visualize training loss and accuracy
    for both TensorFlow and Scikit-learn models.
    """

    class PlottingMode(Enum):
        """
        Enumeration for the different plotting modes.
        It defines the modes for TensorFlow and Scikit-learn models, plus an explicit
        fallback for keys that match neither adapter.
        """

        UNSUPPORTED = "unsupported"
        TF = TFModelAdapter.TAG
        SKLEARN = SklearnModelAdapter.TAG

    def __init__(self) -> None:
        """
        Class constructor. Initializes the plotting responsibility chain with an unsupported mode.
        """

        self.__mode: ClassificationTrainingPlotResponsibilityChain.PlottingMode = ClassificationTrainingPlotResponsibilityChain.PlottingMode.UNSUPPORTED

    def _can_plot(self, key: str) -> bool:
        """
        Checks if the plot can be generated for the given key.
        As a side effect, remembers the plotting mode (TF/Sklearn/unsupported) derived
        from the key, which _plot later dispatches on.

        Parameters:
            key (str): The key to check.

        Returns:
            (bool): True if the plot can be generated, False otherwise.
        """

        mode_type = ClassificationTrainingPlotResponsibilityChain.PlottingMode

        if TFModelAdapter.TAG in key or SklearnModelAdapter.TAG in key:
            # NOTE(review): assumes the adapter tag is the last "_"-separated segment of
            # the key and contains no underscores itself — verify against key producers.
            self.__mode = mode_type(key.split("_")[-1])

            return ClassificationLearningStrategyHandler.PLOTTING_KEYS[1] in key
        else:
            self.__mode = mode_type.UNSUPPORTED

            return False

    def _plot(self, plot_data: dict) -> plt.Axes:
        """
        Generates the plot based on the current mode and provided data.

        Parameters:
            plot_data (dict): The data to be plotted.

        Returns:
            (plt.Axes): The axes object containing the plot.

        Raises:
            ValueError: If the mode set by _can_plot is not TF or SKLEARN.
        """

        mode_type = ClassificationTrainingPlotResponsibilityChain.PlottingMode

        if self.__mode == mode_type.TF:
            return self.__plot_tf(plot_data)
        elif self.__mode == mode_type.SKLEARN:
            return self.__plot_sklearn(plot_data)
        else:
            raise ValueError(f"Unsupported plotting mode: {self.__mode}. Expected 'tensorflow' or 'sklearn'.")

    @staticmethod
    def __plot_insufficient_data() -> plt.Axes:
        """
        Logs a warning and renders a placeholder message when required plot data is missing.

        Returns:
            (plt.Axes): The axes object containing the placeholder text.
        """

        # Lazy %-style args avoid eager string interpolation when the warning is filtered out.
        logging.warning("Insufficient data for plotting results under key: %s.",
                        ClassificationLearningStrategyHandler.PLOTTING_KEYS[1])
        plt.text(0.5, 0.5, "Insufficient data for plotting",
                 ha = 'center', va = 'center', fontsize = 12)

        return plt.gca()

    def __plot_tf(self, plot_data: dict) -> plt.Axes:
        """
        Generates the TensorFlow training plot: training loss/accuracy on top,
        validation loss/accuracy below.

        Parameters:
            plot_data (dict): The data to be plotted. Expected keys: "loss",
                "accuracy", "val_loss", "val_accuracy" (per-epoch sequences).

        Returns:
            (plt.Axes): The axes object containing the plot.
        """

        loss = plot_data.get("loss")
        accuracy = plot_data.get("accuracy")
        val_loss = plot_data.get("val_loss")
        val_accuracy = plot_data.get("val_accuracy")

        if any(series is None for series in (loss, accuracy, val_loss, val_accuracy)):
            return self.__plot_insufficient_data()

        fig = plt.figure(figsize = self._EXPECTED_FIGURE_SIZE)
        gs = GridSpec(2, 1, figure = fig)

        # Plot 1: Training loss and accuracy
        ax1 = plt.subplot(gs[0, 0])
        ax1.set_title("Training loss and accuracy")
        ax1.plot(loss, label = 'Loss')
        ax1.plot(accuracy, label = 'Accuracy')
        ax1.set_xlabel('Epoch')
        ax1.set_ylabel('Value')
        ax1.legend(loc = 'upper left')

        # Plot 2: Validation loss and accuracy
        ax2 = plt.subplot(gs[1, 0])
        ax2.set_title("Validation loss and accuracy")
        ax2.plot(val_loss, label = 'Validation Loss')
        ax2.plot(val_accuracy, label = 'Validation Accuracy')
        ax2.set_xlabel('Epoch')
        ax2.set_ylabel('Value')
        ax2.legend(loc = 'upper left')
        plt.tight_layout()

        return plt.gca()

    def __plot_sklearn(self, plot_data: dict) -> plt.Axes:
        """
        Generates the Scikit-learn training plot: training learning curve on top,
        cross-validation learning curve (with the held-out validation score in the
        title) below. Mean +/- one std bands are drawn across CV folds.

        Parameters:
            plot_data (dict): The data to be plotted. Expected keys:
                "learning_curve_data_train_sizes", "learning_curve_data_train_scores",
                "learning_curve_data_valid_scores" (2D arrays, folds on axis 1 —
                presumably sklearn.model_selection.learning_curve output; verify),
                and "validation_data_score" (float).

        Returns:
            (plt.Axes): The axes object containing the plot.
        """

        learning_curve_data_train_sizes = plot_data.get("learning_curve_data_train_sizes")
        learning_curve_data_train_scores = plot_data.get("learning_curve_data_train_scores")
        learning_curve_data_valid_scores = plot_data.get("learning_curve_data_valid_scores")
        validation_data_score = plot_data.get("validation_data_score")

        if any(value is None for value in (learning_curve_data_train_sizes, learning_curve_data_train_scores,
                                           learning_curve_data_valid_scores, validation_data_score)):
            return self.__plot_insufficient_data()

        fig = plt.figure(figsize = self._EXPECTED_FIGURE_SIZE)
        gs = GridSpec(2, 1, figure = fig)

        # Aggregate fold scores (axis 1) into mean and std per training size.
        train_scores_mean = learning_curve_data_train_scores.mean(axis = 1)
        train_scores_std = learning_curve_data_train_scores.std(axis = 1)
        valid_scores_mean = learning_curve_data_valid_scores.mean(axis = 1)
        valid_scores_std = learning_curve_data_valid_scores.std(axis = 1)

        # Plot 1: Training learning curve
        ax1 = plt.subplot(gs[0, 0])
        ax1.set_title("Training learning curve")
        ax1.fill_between(learning_curve_data_train_sizes, train_scores_mean - train_scores_std,
                         train_scores_mean + train_scores_std, alpha = 0.1, color = "r")
        ax1.plot(learning_curve_data_train_sizes, train_scores_mean, 'o-', color = "r", label = "Training score")
        ax1.set_xlabel('Training examples')
        ax1.set_ylabel('Accuracy')
        ax1.legend(loc = 'upper left')

        # Plot 2: Validation learning curve
        ax2 = plt.subplot(gs[1, 0])
        ax2.set_title(f"Validation learning curve, score: {validation_data_score:.2f}")
        ax2.fill_between(learning_curve_data_train_sizes, valid_scores_mean - valid_scores_std,
                         valid_scores_mean + valid_scores_std, alpha = 0.1, color = "g")
        ax2.plot(learning_curve_data_train_sizes, valid_scores_mean, 'o-', color = "g", label = "Cross-validation score")
        ax2.set_xlabel('Validation examples')
        ax2.set_ylabel('Accuracy')
        ax2.legend(loc = 'upper left')
        plt.tight_layout()

        return plt.gca()