Coverage for source/agent/strategies/classification_testing_strategy_handler.py: 50%
18 statements
« prev ^ index » next coverage.py v7.8.0, created at 2025-06-06 12:00 +0000
« prev ^ index » next coverage.py v7.8.0, created at 2025-06-06 12:00 +0000
1# agent/strategies/classification_testing_strategy_handler.py
3# global imports
4from typing import Any
5from sklearn.metrics import confusion_matrix, classification_report
6import numpy as np
8# local imports
9from source.environment import TradingEnvironment
10from source.agent import TestingStrategyHandlerBase
11from source.agent import ClassificationTestable
class ClassificationTestingStrategyHandler(TestingStrategyHandlerBase):
    """
    Testing strategy handler that evaluates a classification agent on the
    environment's labeled data and packages the resulting metrics under a
    single plotting key.
    """

    # Key returned next to the summary from evaluate(); presumably used by the
    # plotting layer to pick the matching plot handler — verify against caller.
    PLOTTING_KEY: str = 'classification_testing'

    def evaluate(self, testable_agent: ClassificationTestable, environment: TradingEnvironment) -> \
        tuple[list[str], list[dict[str, Any]]]:
        """
        Evaluates the agent's classification performance on the environment's labeled data.

        The true and predicted class of each sample are taken as the argmax along
        axis 1 of the labels and of the agent's prediction scores respectively.

        Parameters:
            testable_agent (ClassificationTestable): Agent whose ``classify`` method is
                called with the environment's input data.
            environment (TradingEnvironment): Environment providing the labeled data via
                ``get_labeled_data()`` and the class names via the keys of
                ``get_trading_consts().OUTPUT_CLASSES``.

        Returns:
            (tuple[list[str], list[dict[str, Any]]]): A one-element list holding
                PLOTTING_KEY, and a one-element list holding the summary dict with keys
                "output_data", "prediction_probabilities", "confusion_matrix",
                "classification_report" and "accuracy".
        """

        classes = list(environment.get_trading_consts().OUTPUT_CLASSES.keys())
        # NOTE(review): assumes the i-th key of OUTPUT_CLASSES names the class encoded
        # at column i of output_data / prediction_probabilities — TODO confirm.
        class_indices = list(range(len(classes)))

        input_data, output_data = environment.get_labeled_data()
        prediction_probabilities = testable_agent.classify(input_data)

        # Both arrays are reduced along axis 1, so they are assumed to be
        # (samples, classes) — TODO confirm the shapes produced upstream.
        y_true = np.argmax(output_data, axis = 1)
        y_pred = np.argmax(prediction_probabilities, axis = 1)

        # Explicit labels keep every class in the output, even one absent from both
        # y_true and y_pred. Without them sklearn infers labels from the data; the
        # matrix then shrinks (rows no longer line up with `classes`) and
        # classification_report raises ValueError because the label count no longer
        # matches target_names.
        conf_matrix = confusion_matrix(y_true, y_pred, labels = class_indices)
        class_report = classification_report(y_true, y_pred, labels = class_indices,
                                             target_names = classes, output_dict = True,
                                             zero_division = 0)

        summary = {
            "output_data": output_data,
            "prediction_probabilities": prediction_probabilities,
            "confusion_matrix": conf_matrix,
            "classification_report": class_report,
            # Fraction of samples whose predicted class equals the true class.
            "accuracy": (y_true == y_pred).mean()
        }

        return [ClassificationTestingStrategyHandler.PLOTTING_KEY], [summary]