Coverage for source/agent/strategies/classification_testing_strategy_handler.py: 50%

16 statements  

« prev     ^ index     » next       coverage.py v7.8.0, created at 2025-08-01 20:51 +0000

1# agent/strategies/classification_testing_strategy_handler.py 

2 

3# global imports 

4import numpy as np 

5from sklearn.metrics import classification_report, confusion_matrix 

6from typing import Any 

7 

8# local imports 

9from source.agent import ClassificationTestable, TestingStrategyHandlerBase 

10from source.environment import TradingEnvironment 

11 

class ClassificationTestingStrategyHandler(TestingStrategyHandlerBase):
    """
    Implements a testing strategy handler for classification tasks.

    The handler runs the agent's classifier over the environment's labeled
    data. It collects the true labels, the predicted probabilities, a
    confusion matrix and a per-class classification report under a single
    plotting key.
    """

    # global class constants
    PLOTTING_KEY: str = 'classification_testing'

    def evaluate(self, testable_agent: ClassificationTestable, environment: TradingEnvironment) -> \
        tuple[list[str], list[dict[str, Any]]]:
        """
        Evaluates the classification model using the given testable agent and trading environment.

        Parameters:
            testable_agent (ClassificationTestable): The agent to be tested; its `classify`
                method is called on the environment's input data.
            environment (TradingEnvironment): The trading environment containing the test data
                and the trading constants (`OUTPUT_CLASSES`) that name the classes.

        Returns:
            (tuple[list[str], list[dict[str, Any]]]): A tuple of `[PLOTTING_KEY]` and a
                one-element list holding the summary dict with keys "true_labels",
                "prediction_probabilities", "confusion_matrix" and "classification_report".
                The confusion matrix and the report always cover every class in
                `OUTPUT_CLASSES`, even classes that are absent from the test data.
        """

        classes = list(environment.get_trading_consts().OUTPUT_CLASSES.keys())

        # Predictions below are argmax column indices. The true labels are therefore
        # assumed to be integer indices 0..len(classes)-1, ordered like OUTPUT_CLASSES.
        # NOTE(review): this mirrors the original target_names mapping — confirm encoding.
        label_indices = list(range(len(classes)))

        input_data, output_data, _, _ = environment.get_labeled_data()
        prediction_probabilities = testable_agent.classify(input_data)
        y_pred = np.argmax(prediction_probabilities, axis = 1)

        # Pass the full label set explicitly. Without it, sklearn infers labels from
        # the classes actually present in y_true/y_pred. A class missing from both
        # would then shrink the confusion matrix, and classification_report would
        # raise ValueError because the label count no longer matches target_names.
        conf_matrix = confusion_matrix(output_data, y_pred, labels = label_indices)
        class_report = classification_report(output_data, y_pred, labels = label_indices,
                                             target_names = classes, output_dict = True,
                                             zero_division = 0)

        summary = {
            "true_labels": output_data,
            "prediction_probabilities": prediction_probabilities,
            "confusion_matrix": conf_matrix,
            "classification_report": class_report
        }

        return [self.PLOTTING_KEY], [summary]