Coverage for source/agent/strategies/classification_testing_strategy_handler.py: 50%
16 statements
« prev ^ index » next coverage.py v7.8.0, created at 2025-08-01 20:51 +0000
1# agent/strategies/classification_testing_strategy_handler.py
3# global imports
4import numpy as np
5from sklearn.metrics import classification_report, confusion_matrix
6from typing import Any
8# local imports
9from source.agent import ClassificationTestable, TestingStrategyHandlerBase
10from source.environment import TradingEnvironment
class ClassificationTestingStrategyHandler(TestingStrategyHandlerBase):
    """
    Implements a testing strategy handler for classification tasks.

    Runs the agent's classifier over the environment's labeled data and bundles
    the true labels, the raw per-class prediction scores, a confusion matrix and
    a per-class classification report into a single summary under PLOTTING_KEY.
    """

    # global class constants
    PLOTTING_KEY: str = 'classification_testing'

    def evaluate(self, testable_agent: ClassificationTestable, environment: TradingEnvironment) -> \
        tuple[list[str], list[dict[str, Any]]]:
        """
        Evaluates the classification model using the given testable agent and trading environment.

        Parameters:
            testable_agent (ClassificationTestable): The agent to be tested.
            environment (TradingEnvironment): The trading environment containing the test data.

        Returns:
            (tuple[list[str], list[dict[str, Any]]]): A tuple containing the keys and data collected
                during evaluation. The key list is [PLOTTING_KEY] and the data list holds one dict
                with the entries "true_labels", "prediction_probabilities", "confusion_matrix"
                (shape len(classes) x len(classes)) and "classification_report" (sklearn dict form).
        """

        classes = list(environment.get_trading_consts().OUTPUT_CLASSES.keys())

        # Pin the label set to every class index. Without this, sklearn infers labels from the
        # data, so a class that never occurs in the true or predicted labels shrinks the
        # confusion matrix and makes classification_report raise because target_names no
        # longer matches the inferred label count.
        # NOTE(review): assumes output_data holds integer class indices in the same order as
        # OUTPUT_CLASSES keys (the same index space argmax produces below) — confirm.
        labels = np.arange(len(classes))

        input_data, output_data, _, _ = environment.get_labeled_data()
        prediction_probabilities = testable_agent.classify(input_data)
        # Pick the highest-scoring class per row; assumes scores are (n_samples, n_classes).
        y_pred = np.argmax(prediction_probabilities, axis = 1)

        conf_matrix = confusion_matrix(output_data, y_pred, labels = labels)
        class_report = classification_report(output_data, y_pred, labels = labels,
                                             target_names = classes, output_dict = True,
                                             zero_division = 0)

        summary = {
            "true_labels": output_data,
            "prediction_probabilities": prediction_probabilities,
            "confusion_matrix": conf_matrix,
            "classification_report": class_report
        }

        return [self.PLOTTING_KEY], [summary]