Coverage for source/plotting/classification_training_plot_responsibility_chain.py: 32%

25 statements  

« prev     ^ index     » next       coverage.py v7.8.0, created at 2025-06-06 12:00 +0000

1# plotting/classification_training_plot_responsibility_chain.py 

2 

3# global imports 

4import logging 

5import matplotlib.pyplot as plt 

6from matplotlib.gridspec import GridSpec 

7 

8# local imports 

9from source.agent import ClassificationLearningStrategyHandler 

10from source.plotting import PlotResponsibilityChainBase 

11 

class ClassificationTrainingPlotResponsibilityChain(PlotResponsibilityChainBase):
    """Plot handler for classification training results.

    Handles plot requests whose key equals
    ``ClassificationLearningStrategyHandler.PLOTTING_KEY`` and draws the
    recorded loss and accuracy curves on a single set of axes.
    """

    def _can_plot(self, key: str) -> bool:
        """Return True if *key* is the classification training plotting key."""

        return key == ClassificationLearningStrategyHandler.PLOTTING_KEY

    @staticmethod
    def _is_missing(series) -> bool:
        """Return True when *series* is None or an empty sized container."""

        if series is None:
            return True
        try:
            return len(series) == 0
        except TypeError:
            # Unsized iterables cannot be checked cheaply; let plotting decide.
            return False

    def _plot(self, plot_data: dict) -> plt.Axes:
        """Draw training loss and accuracy curves.

        Args:
            plot_data: Mapping expected to contain ``"loss"`` and ``"accuracy"``
                sequences (presumably one value per epoch, since the x-axis is
                labelled 'Epoch'; confirm against the producer of this dict).

        Returns:
            The axes that were drawn on. If either series is missing or empty,
            a fresh figure containing an "Insufficient data" message is created
            and its axes are returned instead.
        """

        loss = plot_data.get("loss", None)
        accuracy = plot_data.get("accuracy", None)

        # Always draw on a figure this method owns. Drawing via plt.text /
        # plt.gca() without creating a figure would write onto whatever figure
        # happened to be current, possibly an unrelated earlier plot.
        fig = plt.figure(figsize = self._EXPECTED_FIGURE_SIZE)

        if self._is_missing(loss) or self._is_missing(accuracy):
            logging.warning("Insufficient data for plotting classification results.")
            ax = fig.add_subplot(1, 1, 1)
            # Axes-relative coordinates keep the message centred regardless of
            # the axes' data limits.
            ax.text(0.5, 0.5, "Insufficient data for plotting",
                    ha = 'center', va = 'center', fontsize = 12,
                    transform = ax.transAxes)
            return ax

        # NOTE(review): the grid has two rows but only the top one is used,
        # leaving the lower half of the figure empty. Presumably it is reserved
        # for a second panel; confirm before changing the layout.
        gs = GridSpec(2, 1, figure = fig)

        # Plot 1: Training loss and accuracy
        ax = fig.add_subplot(gs[0, 0])
        ax.set_title("Training loss and accuracy")
        ax.plot(loss, label = 'Loss')
        ax.plot(accuracy, label = 'Accuracy')
        ax.set_xlabel('Epoch')
        ax.set_ylabel('Value')
        ax.legend(loc = 'upper left')

        return ax