Coverage for source/agent/agent_handler.py: 95%
43 statements
« prev ^ index » next coverage.py v7.8.0, created at 2025-08-01 20:51 +0000
« prev ^ index » next coverage.py v7.8.0, created at 2025-08-01 20:51 +0000
1# agent/agent_handler.py
3# global imports
4import logging
5import random
6from tensorflow.keras.callbacks import Callback
7from typing import Any, Callable, Optional
9# local imports
10from source.agent import AgentBase, LearningStrategyHandlerBase, TestingStrategyHandlerBase
11from source.environment import TradingEnvironment
12from source.model import BluePrintBase
13from source.utils import redirect_stdout_to_logging
class AgentHandler:
    """
    Implements agent handler that is responsible for training and testing
    the agent in the given trading environment using the specified learning
    and testing strategies. It is used as a wrapper around the agent
    to provide a convenient interface for TrainingHandler.
    """

    def __init__(self, model_blue_print: BluePrintBase,
                 trading_environment: TradingEnvironment,
                 learning_strategy_handler: LearningStrategyHandlerBase,
                 testing_strategy_handler: TestingStrategyHandlerBase) -> None:
        """
        Class constructor. Stores the environment and both strategy handlers and
        asks the learning strategy handler to create the agent from the given
        model blueprint and trading environment.

        Parameters:
            model_blue_print (BluePrintBase): The model blueprint to be used for the agent.
            trading_environment (TradingEnvironment): The trading environment in which the agent will operate.
            learning_strategy_handler (LearningStrategyHandlerBase): The learning strategy handler to be used
                for training. It is also the object that creates the agent.
            testing_strategy_handler (TestingStrategyHandlerBase): The testing strategy handler to be used
                for evaluation.
        """

        # Flipped to True only after a successful train_agent() call; test_agent() refuses to run before that.
        self.__trained: bool = False
        self.__learning_strategy_handler: LearningStrategyHandlerBase = learning_strategy_handler
        self.__testing_strategy_handler: TestingStrategyHandlerBase = testing_strategy_handler
        self.__trading_environment: TradingEnvironment = trading_environment
        self.__agent: AgentBase = learning_strategy_handler.create_agent(model_blue_print, trading_environment)

    def train_agent(self, nr_of_steps: int, nr_of_episodes: int, callbacks: Optional[list[Callback]] = None,
                    model_load_path: Optional[str] = None,
                    model_save_path: Optional[str] = None) -> tuple[list[str], list[dict]]:
        """
        Trains the agent using the specified number of steps and episodes.

        The environment is switched to training mode, an optional pre-trained model
        is loaded, the learning strategy handler fits the agent (with stdout
        redirected to logging), and the resulting model is optionally saved.

        Parameters:
            nr_of_steps (int): The number of steps to train the agent.
            nr_of_episodes (int): The number of episodes to train the agent.
            callbacks (Optional[list[Callback]]): A list of callbacks to be used during training.
                When None, an empty list is passed to the learning strategy handler.
            model_load_path (Optional[str]): Path to load the pre-trained model from.
                Nothing is loaded when None.
            model_save_path (Optional[str]): Path to save the trained model to.
                Nothing is saved when None.

        Returns:
            (tuple[list[str], list[dict]]): A tuple containing the keys and report data from
                the training process, exactly as returned by the learning strategy handler.
        """

        if callbacks is None:
            callbacks = []

        self.__trading_environment.set_mode(TradingEnvironment.TRAIN_MODE)

        if model_load_path is not None:
            self.__agent.load_model(model_load_path)

        with redirect_stdout_to_logging():
            keys, report_data = self.__learning_strategy_handler.fit(
                self.__agent,
                self.__trading_environment,
                nr_of_steps,
                nr_of_episodes,
                callbacks
            )

        # Only reached when fit() did not raise, so a failed training run leaves the handler untrained.
        self.__trained = True
        if model_save_path is not None:
            self.__agent.save_model(model_save_path)

        return keys, report_data

    @staticmethod
    def _select_start_index(window_size: int, env_length: int) -> int:
        """
        Picks the iteration at which a single test run starts.

        The start is drawn uniformly from [window_size, int(env_length / 2)]. When that
        range is empty (half of the environment is shorter than the window), the
        earliest candidate, window_size, is returned instead of letting
        random.randint raise ValueError.

        Parameters:
            window_size (int): Lower bound of the start index.
            env_length (int): Length of the trading environment.

        Returns:
            (int): The selected start iteration.
        """

        upper_bound = int(env_length / 2)
        if upper_bound < window_size:
            # NOTE(review): falling back to window_size assumes the environment accepts it
            # as a reset point even when it is short - confirm against TradingEnvironment.reset.
            logging.warning('Environment length %s is too short for a random test start '
                            'with window size %s; starting at %s.',
                            env_length, window_size, window_size)
            return window_size

        return random.randint(window_size, upper_bound)

    def test_agent(self, repeat: int = 1) -> tuple[dict[int, list[str]], dict[int, list[dict[str, Any]]]]:
        """
        Tests the agent using the specified number of repetitions.

        Each repetition resets the environment (in test mode) to a randomly chosen
        start iteration and evaluates the agent with the testing strategy handler.
        If the agent has not been trained through train_agent(), an error is logged
        and empty results are returned.

        Parameters:
            repeat (int): The number of times to repeat the testing process.

        Returns:
            (tuple[dict[int, list[str]], dict[int, list[dict[str, Any]]]]): A tuple containing the keys and
                report data from the testing process, both indexed by repetition number.
        """

        if not self.__trained:
            logging.error('Agent is not trained yet! Train the agent before testing.')
            return {}, {}

        self.__trading_environment.set_mode(TradingEnvironment.TEST_MODE)

        report_data = {}
        keys = {}
        for i in range(repeat):
            # Queried on every repetition on purpose: it is not visible here whether
            # reset() can change what the environment reports.
            env_length = self.__trading_environment.get_environment_length()
            window_size = self.__trading_environment.get_trading_consts().WINDOW_SIZE
            current_iteration = self._select_start_index(window_size, env_length)
            self.__trading_environment.reset(current_iteration)
            keys[i], report_data[i] = self.__testing_strategy_handler.evaluate(self.__agent,
                                                                               self.__trading_environment)

        return keys, report_data

    def print_model_summary(self, print_function: Optional[Callable] = print) -> None:
        """
        Prints a summary of the model by delegating to the agent's print_summary.

        Parameters:
            print_function (Optional[Callable]): A function to print the summary. Defaults to print.
        """

        self.__agent.print_summary(print_function = print_function)