Coverage for source/agent/agent_handler.py: 40% (48 statements)
coverage.py v7.8.0, created at 2025-06-06 12:00 +0000
# agent/agent_handler.py

# global imports
import io
import logging
import random
from contextlib import redirect_stdout
from tensorflow.keras.callbacks import Callback
from typing import Any, Callable, Optional

# local imports
from source.agent import AgentBase, LearningStrategyHandlerBase, TestingStrategyHandlerBase
from source.environment import TradingEnvironment
from source.model import BluePrintBase

class AgentHandler():
    """Coordinates training and testing of an agent inside a trading environment."""

    def __init__(self, model_blue_print: BluePrintBase,
                       trading_environment: TradingEnvironment,
                       learning_strategy_handler: LearningStrategyHandlerBase,
                       testing_strategy_handler: TestingStrategyHandlerBase) -> None:
        """Create the agent from the blue print and store the environment and strategy handlers."""

        self.__trained: bool = False
        self.__learning_strategy_handler: LearningStrategyHandlerBase = learning_strategy_handler
        self.__testing_strategy_handler: TestingStrategyHandlerBase = testing_strategy_handler
        self.__trading_environment: TradingEnvironment = trading_environment
        self.__agent: AgentBase = learning_strategy_handler.create_agent(model_blue_print, trading_environment)

    def train_agent(self, nr_of_steps: int, nr_of_episodes: int, callbacks: Optional[list[Callback]] = None,
                          model_load_path: Optional[str] = None,
                          model_save_path: Optional[str] = None) -> tuple[list[str], list[dict]]:
        """Train the agent and return the report key and data produced by the learning strategy handler."""

        if callbacks is None:
            callbacks = []

        self.__trading_environment.set_mode(TradingEnvironment.TRAIN_MODE)

        if model_load_path is not None:
            self.__agent.load_model(model_load_path)

        captured_output = io.StringIO()
        with redirect_stdout(captured_output):  # TODO: Create a callback logger
            key, report_data = self.__learning_strategy_handler.fit(self.__agent, self.__trading_environment,
                                                                    nr_of_steps, nr_of_episodes, callbacks)

        for line in captured_output.getvalue().split('\n'):
            if line.strip():
                logging.info(line)

        self.__trained = True

        if model_save_path is not None:
            self.__agent.save_model(model_save_path)

        return key, report_data

    def test_agent(self, repeat: int = 1) -> tuple[dict[int, list[str]], dict[int, list[dict[str, Any]]]]:
        """Evaluate the trained agent, repeating the test from random starting points in the environment."""

        if not self.__trained:
            logging.error('Agent is not trained yet! Train the agent before testing.')
            return {}, {}

        self.__trading_environment.set_mode(TradingEnvironment.TEST_MODE)

        report_data = {}
        key = {}
        for i in range(repeat):
            env_length = self.__trading_environment.get_environment_length()
            window_size = self.__trading_environment.get_trading_consts().WINDOW_SIZE
            current_iteration = random.randint(window_size, int(env_length / 2))
            self.__trading_environment.reset(current_iteration)
            key[i], report_data[i] = self.__testing_strategy_handler.evaluate(self.__agent,
                                                                              self.__trading_environment)

        return key, report_data

    def print_model_summary(self, print_function: Optional[Callable] = print) -> None:
        """Print the summary of the agent's underlying model using the given print function."""

        self.__agent.print_summary(print_function=print_function)
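
# A minimal usage sketch, kept as a comment because the concrete blue print, environment
# and strategy handler objects are defined elsewhere in the code base; the variable names
# below are illustrative assumptions only:
#
#     handler = AgentHandler(model_blue_print=blue_print,
#                            trading_environment=environment,
#                            learning_strategy_handler=learning_handler,
#                            testing_strategy_handler=testing_handler)
#     keys, reports = handler.train_agent(nr_of_steps=10_000, nr_of_episodes=10)
#     test_keys, test_reports = handler.test_agent(repeat=3)
#     handler.print_model_summary()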