Coverage for source/plotting/classification_training_plot_responsibility_chain.py: 32%
25 statements
« prev ^ index » next coverage.py v7.8.0, created at 2025-06-06 12:00 +0000
« prev ^ index » next coverage.py v7.8.0, created at 2025-06-06 12:00 +0000
1# plotting/classification_training_plot_responsibility_chain.py
3# global imports
4import logging
5import matplotlib.pyplot as plt
6from matplotlib.gridspec import GridSpec
8# local imports
9from source.agent import ClassificationLearningStrategyHandler
10from source.plotting import PlotResponsibilityChainBase
class ClassificationTrainingPlotResponsibilityChain(PlotResponsibilityChainBase):
    """Plot handler for classification training results.

    Accepts plot data tagged with
    ``ClassificationLearningStrategyHandler.PLOTTING_KEY`` and draws the
    ``loss`` and ``accuracy`` series on a single axes. Presumably the base
    class uses ``_can_plot`` to decide whether this link handles a given
    key and then calls ``_plot`` -- confirm against PlotResponsibilityChainBase.
    """

    def _can_plot(self, key: str) -> bool:
        """Return whether this handler is responsible for ``key``.

        Args:
            key: Plotting key attached to the data being dispatched.

        Returns:
            True exactly when ``key`` equals
            ``ClassificationLearningStrategyHandler.PLOTTING_KEY``.
        """

        return key == ClassificationLearningStrategyHandler.PLOTTING_KEY

    @staticmethod
    def _has_values(values) -> bool:
        """Return False for a missing (``None``) or empty sized series."""

        if values is None:
            return False
        try:
            return len(values) > 0
        except TypeError:
            # Unsized but present (e.g. a scalar): let matplotlib decide.
            return True

    def _plot(self, plot_data: dict) -> plt.Axes:
        """Plot training loss and accuracy on a freshly created figure.

        Args:
            plot_data: Mapping read for the ``"loss"`` and ``"accuracy"``
                entries. Each is plotted against its zero-based index, which
                is labelled "Epoch" (presumably one value per epoch --
                confirm with the producer of this data).

        Returns:
            The axes that was drawn on. If either series is missing or
            empty, a warning is logged, and the returned axes shows only an
            "Insufficient data for plotting" message.
        """

        loss = plot_data.get("loss", None)
        accuracy = plot_data.get("accuracy", None)

        if not (self._has_values(loss) and self._has_values(accuracy)):
            logging.warning("Insufficient data for plotting classification results.")
            # Use a fresh figure, so the message never lands on a previously
            # drawn plot. Axes-relative coordinates centre it regardless of
            # any data limits.
            fig = plt.figure(figsize = self._EXPECTED_FIGURE_SIZE)
            ax = fig.add_subplot(1, 1, 1)
            ax.text(0.5, 0.5, "Insufficient data for plotting",
                    ha = 'center', va = 'center', fontsize = 12,
                    transform = ax.transAxes)
            return ax

        fig = plt.figure(figsize = self._EXPECTED_FIGURE_SIZE)
        # NOTE(review): the grid has two rows, but only the top one is used,
        # so the lower half of the figure stays blank. Confirm this layout is
        # intentional.
        gs = GridSpec(2, 1, figure = fig)

        # Plot 1: Training loss and accuracy
        ax = fig.add_subplot(gs[0, 0])
        ax.set_title("Training loss and accuracy")
        ax.plot(loss, label = 'Loss')
        ax.plot(accuracy, label = 'Accuracy')
        ax.set_xlabel('Epoch')
        ax.set_ylabel('Value')
        ax.legend(loc = 'upper left')

        return ax