Coverage for source/agent/strategies/classification_testing_strategy_handler.py: 50%

18 statements  

« prev     ^ index     » next       coverage.py v7.8.0, created at 2025-06-06 12:00 +0000

1# agent/strategies/classification_testing_strategy_handler.py 

2 

3# global imports 

4from typing import Any 

5from sklearn.metrics import confusion_matrix, classification_report 

6import numpy as np 

7 

8# local imports 

9from source.environment import TradingEnvironment 

10from source.agent import TestingStrategyHandlerBase 

11from source.agent import ClassificationTestable 

12 

class ClassificationTestingStrategyHandler(TestingStrategyHandlerBase):
    """
    Testing strategy that scores an agent purely as a classifier.

    The agent's class probabilities for the environment's labeled data are
    compared against the expected outputs and summarised as a confusion
    matrix, a per-class classification report and an overall accuracy.
    """

    # Key returned next to the summary by evaluate().
    PLOTTING_KEY: str = 'classification_testing'

    def evaluate(self, testable_agent: ClassificationTestable, environment: TradingEnvironment) -> \
            tuple[list[str], list[dict[str, Any]]]:
        """
        Classify the environment's labeled data and summarise the result.

        Parameters:
            testable_agent (ClassificationTestable): Agent whose classify()
                output is evaluated.
            environment (TradingEnvironment): Source of the labeled data
                (get_labeled_data()) and of the class names
                (get_trading_consts().OUTPUT_CLASSES).

        Returns:
            (tuple[list[str], list[dict[str, Any]]]): One-element list holding
                PLOTTING_KEY and one-element list holding the summary dict
                with the keys 'output_data', 'prediction_probabilities',
                'confusion_matrix', 'classification_report' and 'accuracy'.
        """

        class_names = list(environment.get_trading_consts().OUTPUT_CLASSES.keys())
        input_data, output_data = environment.get_labeled_data()
        prediction_probabilities = testable_agent.classify(input_data)

        # Both arrays are reduced to class indices along axis 1.
        # NOTE(review): presumably output_data is one-hot encoded with columns
        # ordered like OUTPUT_CLASSES - confirm against the environment.
        y_true = np.argmax(output_data, axis = 1)
        y_pred = np.argmax(prediction_probabilities, axis = 1)

        # Pin the label set explicitly. Without it scikit-learn infers the
        # labels from the data, so a class missing from both y_true and y_pred
        # shrinks the confusion matrix and makes classification_report raise
        # ValueError because the label count no longer matches target_names.
        labels = list(range(len(class_names)))

        conf_matrix = confusion_matrix(y_true, y_pred, labels = labels)
        class_report = classification_report(y_true, y_pred, labels = labels,
                                             target_names = class_names,
                                             output_dict = True, zero_division = 0)

        summary = {
            "output_data": output_data,
            "prediction_probabilities": prediction_probabilities,
            "confusion_matrix": conf_matrix,
            "classification_report": class_report,
            "accuracy": (y_true == y_pred).mean()
        }

        return [ClassificationTestingStrategyHandler.PLOTTING_KEY], [summary]