Coverage for source/agent/strategies/classification_testing_strategy_handler.py: 50%

16 statements  

« prev     ^ index     » next       coverage.py v7.8.0, created at 2025-08-04 21:16 +0000

1# agent/strategies/classification_testing_strategy_handler.py 

2 

3# global imports 

4import numpy as np 

5from sklearn.metrics import classification_report, confusion_matrix 

6from typing import Any 

7 

8# local imports 

9from source.agent import ClassificationTestable, TestingStrategyHandlerBase 

10from source.environment import TradingEnvironment 

11 

class ClassificationTestingStrategyHandler(TestingStrategyHandlerBase):
    """
    Implements a testing strategy handler for classification tasks.

    The handler runs the agent's classifier over the environment's labeled data and
    packages the true labels, raw prediction probabilities, confusion matrix and
    per-class classification report into a single summary entry keyed by PLOTTING_KEY.
    """

    # global class constants
    PLOTTING_KEY: str = 'classification_testing'

    def evaluate(self, testable_agent: ClassificationTestable, environment: TradingEnvironment,
                 env_length_range: tuple[int, int]) -> tuple[list[str], list[dict[str, Any]]]:
        """
        Evaluates the classification model using the given testable agent and trading environment.

        Predicted labels are the argmax over axis 1 of the probabilities returned by
        testable_agent.classify, i.e. integer class indices. The confusion matrix and the
        classification report are both computed over the full, fixed set of class indices
        0 .. len(classes) - 1. As a result, a class that never occurs (neither as a true
        label nor as a prediction) in the selected range still gets its own all-zero
        row/column and report entry. Without this, the matrix would shrink and
        classification_report would raise a ValueError about target_names.

        Parameters:
            testable_agent (ClassificationTestable): The agent to be tested.
            environment (TradingEnvironment): The trading environment containing the test data.
            env_length_range (tuple[int, int]): A tuple specifying the range of environment lengths to consider.

        Returns:
            (tuple[list[str], list[dict[str, Any]]]): A one-element list containing PLOTTING_KEY and a
                one-element list containing the summary dict with the keys "true_labels",
                "prediction_probabilities", "confusion_matrix" and "classification_report".
        """

        classes = list(environment.get_trading_consts().OUTPUT_CLASSES.keys())
        input_data, output_data, _, _ = environment.get_labeled_data(env_length_range = env_length_range)
        prediction_probabilities = testable_agent.classify(input_data)
        y_pred = np.argmax(prediction_probabilities, axis = 1)

        # Pin the label set to every class index instead of letting sklearn infer it from the
        # labels that happen to appear in this range. This assumes the OUTPUT_CLASSES key order
        # matches the integer label / output-column order, which is the same assumption
        # target_names and argmax already rely on.
        labels = list(range(len(classes)))

        conf_matrix = confusion_matrix(output_data, y_pred, labels = labels)
        class_report = classification_report(output_data, y_pred, labels = labels,
                                             target_names = classes,
                                             output_dict = True, zero_division = 0)

        summary = {
            "true_labels": output_data,
            "prediction_probabilities": prediction_probabilities,
            "confusion_matrix": conf_matrix,
            "classification_report": class_report
        }

        return [self.PLOTTING_KEY], [summary]