Coverage for source/agent/agent_handler.py: 75%
51 statements
coverage.py v7.8.0, created at 2025-08-04 20:03 +0000
# agent/agent_handler.py

# global imports
import logging
import numpy as np
from tensorflow.keras.callbacks import Callback
from typing import Any, Callable, Optional

# local imports
from source.agent import AgentBase, LearningStrategyHandlerBase, TestingStrategyHandlerBase
from source.environment import TradingEnvironment
from source.model import BluePrintBase
from source.utils import redirect_stdout_to_logging

class AgentHandler():
    """
    Implements an agent handler responsible for training and testing the agent
    in the given trading environment using the specified learning and testing
    strategies. It is used as a wrapper around the agent to provide a convenient
    interface for TrainingHandler.
    """

    def __init__(self, model_blue_print: BluePrintBase,
                 trading_environment: TradingEnvironment,
                 learning_strategy_handler: LearningStrategyHandlerBase,
                 testing_strategy_handlers: list[TestingStrategyHandlerBase]) -> None:
        """
        Class constructor. Initializes the agent handler with the given model blueprint,
        trading environment, learning strategy handler, and testing strategy handlers.

        Parameters:
            model_blue_print (BluePrintBase): The model blueprint to be used for the agent.
            trading_environment (TradingEnvironment): The trading environment in which the agent will operate.
            learning_strategy_handler (LearningStrategyHandlerBase): The learning strategy handler to be used for training.
            testing_strategy_handlers (list[TestingStrategyHandlerBase]): The testing strategy handlers to be used for evaluation.
        """

        self.__trained: bool = False
        self.__learning_strategy_handler: LearningStrategyHandlerBase = learning_strategy_handler
        self.__testing_strategy_handlers: list[TestingStrategyHandlerBase] = testing_strategy_handlers
        self.__trading_environment: TradingEnvironment = trading_environment
        self.__agent: AgentBase = learning_strategy_handler.create_agent(model_blue_print, trading_environment)

    def train_agent(self, nr_of_steps: int, nr_of_episodes: int, callbacks: Optional[list[Callback]] = None,
                    model_load_path: Optional[str] = None,
                    model_save_path: Optional[str] = None) -> tuple[list[str], list[dict]]:
        """
        Trains the agent using the specified number of steps and episodes.

        Parameters:
            nr_of_steps (int): The number of steps to train the agent.
            nr_of_episodes (int): The number of episodes to train the agent.
            callbacks (Optional[list[Callback]]): A list of callbacks to be used during training.
            model_load_path (Optional[str]): Path to load the pre-trained model from.
            model_save_path (Optional[str]): Path to save the trained model to.

        Returns:
            (tuple[list[str], list[dict]]): A tuple containing the keys and report data from
                the training process.
        """

        if callbacks is None:
            callbacks = []

        self.__trading_environment.set_mode(TradingEnvironment.TRAIN_MODE)

        if model_load_path is not None:
            self.__agent.load_model(model_load_path)

        with redirect_stdout_to_logging():
            keys, report_data = self.__learning_strategy_handler.fit(
                self.__agent,
                self.__trading_environment,
                nr_of_steps,
                nr_of_episodes,
                callbacks
            )

        self.__trained = True
        if model_save_path is not None:
            self.__agent.save_model(model_save_path)

        return keys, report_data

    def test_agent(self, repeat: int = 1) -> tuple[dict[int, list[str]], dict[int, list[dict[str, Any]]]]:
        """
        Tests the agent using the specified number of repetitions.

        Parameters:
            repeat (int): The number of times to repeat the testing process.

        Returns:
            (tuple[dict[int, list[str]], dict[int, list[dict[str, Any]]]]): A tuple containing the keys and
                report data from the testing process.
        """

        if not self.__trained:
            logging.error('Agent is not trained yet! Train the agent before testing.')
            return {}, {}

        self.__trading_environment.set_mode(TradingEnvironment.TEST_MODE)

        if repeat > 1:
            # Add 1 to repeat so that any value above 1 yields the full
            # evaluation plus cross-validation on a number of consecutive
            # splits equal to the requested number of repeats.
            repeat += 1

        report_data = {x: [] for x in range(repeat)}
        keys = {x: [] for x in range(repeat)}
        window_size = self.__trading_environment.get_trading_consts().WINDOW_SIZE
        max_env_length = self.__trading_environment.get_environment_length() - 1
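
        # Illustrative note (hypothetical numbers, not part of the original logic):
        # with WINDOW_SIZE == 100, a maximum environment index of 1000 and
        # repeat == 3 at this point, np.linspace(100, 1000, 3, dtype = int)
        # yields [100, 550, 1000]; key 0 then collects the full-range evaluation
        # over (100, 1000), while keys 1 and 2 collect the consecutive split
        # evaluations (100, 550) and (550, 1000).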
        for strategy_handler in self.__testing_strategy_handlers:
            last_env_length_index = window_size
            full_eval_key, full_eval_data = \
                strategy_handler.evaluate(self.__agent, self.__trading_environment,
                                          (last_env_length_index, max_env_length))
            keys[0] += full_eval_key
            report_data[0] += full_eval_data

            for i, env_length_index in enumerate(np.linspace(window_size, max_env_length, repeat, dtype = int)[1:], start = 1):
                eval_key, eval_data = \
                    strategy_handler.evaluate(self.__agent, self.__trading_environment,
                                              (last_env_length_index, env_length_index))
                keys[i] += eval_key
                report_data[i] += eval_data
                last_env_length_index = env_length_index

        return keys, report_data

    def print_model_summary(self, print_function: Optional[Callable] = print) -> None:
        """
        Prints a summary of the model architecture and parameters.

        Parameters:
            print_function (Optional[Callable]): A function to print the summary. Defaults to print.
        """

        self.__agent.print_summary(print_function = print_function)
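
For reference, a minimal usage sketch of how a caller such as TrainingHandler might drive AgentHandler. The concrete classes and the environment factory below (MyBluePrint, MyLearningStrategyHandler, MyTestingStrategyHandler, make_trading_environment) are hypothetical placeholders for project-specific implementations of the corresponding base classes; only the AgentHandler calls themselves come from the module above.

# Hypothetical wiring; every placeholder name below stands in for a real
# implementation of the corresponding base class from this project.
blue_print = MyBluePrint()                       # placeholder BluePrintBase subclass
environment = make_trading_environment()         # placeholder factory returning a TradingEnvironment
learning_handler = MyLearningStrategyHandler()   # placeholder LearningStrategyHandlerBase subclass
testing_handlers = [MyTestingStrategyHandler()]  # placeholder TestingStrategyHandlerBase subclass

handler = AgentHandler(blue_print, environment, learning_handler, testing_handlers)
handler.print_model_summary()

train_keys, train_reports = handler.train_agent(
    nr_of_steps = 10_000,
    nr_of_episodes = 10,
    model_save_path = 'models/agent_model'
)

test_keys, test_reports = handler.test_agent(repeat = 3)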