Coverage for source/agent/agent_handler.py: 97%
79 statements
coverage.py v7.8.0, created at 2025-05-30 15:13 +0000
# agent/agent_handler.py

import random
import numpy as np
import logging
import io
from rl.agents import DQNAgent
from rl.policy import Policy
from rl.memory import SequentialMemory
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Optimizer
from tensorflow.keras.callbacks import Callback
from typing import Callable, Optional
from contextlib import redirect_stdout

from source.environment import TradingEnvironment

class AgentHandler:
    """
    Encapsulates a DQNAgent along with its associated training and testing procedures.
    This class provides a simplified interface for managing deep reinforcement learning
    agent operations in the trading environment context.
    """

    def __init__(self, model: Model, policy: Policy, nr_of_actions: int, optimizer: Optimizer) -> None:
        """
        Initializes the AgentHandler with the given model, policy and action space parameters.

        Parameters:
            model (Model): Keras model used by the agent to learn from the environment.
            policy (Policy): Policy that determines the action selection strategy.
            nr_of_actions (int): Number of possible actions the agent can take.
            optimizer (Optimizer): Keras optimizer used for model training.
        """

        self.__trained: bool = False
        self.__agent: DQNAgent = DQNAgent(model, policy, memory = SequentialMemory(limit = 100000, window_length = 1),
                                          nb_actions = nr_of_actions, target_model_update = 1e-2)
        self.__agent.compile(optimizer)

    def train_agent(self, environment: TradingEnvironment, nr_of_steps: int, steps_per_episode: int,
                    callbacks: Optional[list[Callback]] = None, weights_load_path: Optional[str] = None,
                    weights_save_path: Optional[str] = None) -> dict:
        """
        Trains the agent on the provided environment.

        Parameters:
            environment (TradingEnvironment): Trading environment to train on.
            nr_of_steps (int): Total number of training steps.
            steps_per_episode (int): Maximum number of steps per episode.
            callbacks (list[Callback], optional): List of Keras callbacks for training. Defaults to None.
            weights_load_path (str, optional): Path to load pre-trained weights from.
            weights_save_path (str, optional): Path to save the weights to after training.

        Returns:
            dict: Dictionary containing training history metrics.
        """

        # Avoid a shared mutable default argument for the callback list.
        if callbacks is None:
            callbacks = []

        if weights_load_path is not None:
            self.__agent.load_weights(weights_load_path)

        # keras-rl prints its progress to stdout; capture it and forward it to logging.
        captured_output = io.StringIO()
        with redirect_stdout(captured_output): #TODO: Create a callback logger
            history = self.__agent.fit(environment, nr_of_steps, callbacks = callbacks,
                                       log_interval = steps_per_episode, nb_max_episode_steps = steps_per_episode)

        for line in captured_output.getvalue().split('\n'):
            if line.strip():
                logging.info(line)
        self.__trained = True

        if weights_save_path is not None:
            self.__agent.save_weights(weights_save_path)

        return history.history
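
    # Example call (illustrative only; the step counts and weights path below are hypothetical
    # values, not taken from this project, and the 'episode_reward' key is the one keras-rl
    # normally reports per episode):
    #   history = handler.train_agent(environment, nr_of_steps = 50000, steps_per_episode = 500,
    #                                 weights_save_path = 'weights/dqn_weights.h5f')
    #   logging.info('Last episode rewards: %s', history.get('episode_reward', [])[-5:])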

    def test_agent(self, environment: TradingEnvironment, repeat: int = 1) -> dict:
        """
        Tests the trained agent on the provided environment.

        Testing runs the agent on the environment from random starting points
        and records performance metrics such as asset value changes and rewards.

        Parameters:
            environment (TradingEnvironment): Trading environment to test on.
            repeat (int, optional): Number of test episodes to run. Defaults to 1.

        Returns:
            dict: Dictionary containing test metrics including asset values, rewards,
                and trading performance statistics for each test episode.
                Returns an empty dict if the agent is not trained.
        """

        if not self.__trained:
            logging.error('Agent is not trained yet! Train the agent before testing.')
            return {}

        test_history = {}
        env_length = environment.get_environment_length()
        for i in range(repeat):
            test_history[i] = {}
            assets_values = []
            reward_values = []
            infos = []
            iterations = []
            done = False

            # Start each test episode from a random point in the first half of the data,
            # leaving at least one observation window before it and room to trade after it.
            window_size = environment.get_trading_consts().WINDOW_SIZE
            current_iteration = random.randint(window_size, env_length // 2)
            environment.reset(current_iteration)
            state = environment.state
            trading_data = environment.get_trading_data()
            current_assets = trading_data.current_budget + trading_data.currently_invested
            iterations.append(current_iteration)
            assets_values.append(current_assets)
            reward_values.append(0)
            infos.append({})

            while not done:
                next_action = self.__agent.forward(state)
                state, reward, done, info = environment.step(next_action)

                # Record a data point only when the total assets change or the episode ends.
                if current_assets != info['current_budget'] + info['currently_invested'] or done:
                    current_iteration = environment.current_iteration
                    current_assets = info['current_budget'] + info['currently_invested']
                    iterations.append(current_iteration)
                    assets_values.append(current_assets)
                    reward_values.append(reward)
                    infos.append(info)

            # Average change in total assets per iteration; the asset and price series are
            # then normalised to their initial values so they can be compared directly.
            solvency_coefficient = (assets_values[-1] - assets_values[0]) / (iterations[-1] - iterations[0])
            assets_values = (np.array(assets_values) / assets_values[0]).tolist()
            currency_prices = environment.get_data_for_iteration(['close'], iterations[0], iterations[-1])
            currency_prices = (np.array(currency_prices) / currency_prices[0]).tolist()

            test_history[i]['assets_values'] = assets_values
            test_history[i]['reward_values'] = reward_values
            test_history[i]['currency_prices'] = currency_prices
            test_history[i]['infos'] = infos
            test_history[i]['iterations'] = iterations
            test_history[i]['solvency_coefficient'] = solvency_coefficient

        return test_history
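
    # Example of consuming the returned test history (illustrative; the key names follow the
    # dictionary built above, while the run count is hypothetical):
    #   results = handler.test_agent(environment, repeat = 3)
    #   for run, metrics in results.items():
    #       # Both series are normalised to 1.0 at the start of the run, so the final entries
    #       # compare the agent's performance against simply holding the asset.
    #       print(run, metrics['assets_values'][-1], metrics['currency_prices'][-1])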

    def print_model_summary(self, print_function: Callable[..., None] = print) -> None:
        """
        Prints the model summary using the provided print function.

        Parameters:
            print_function (Callable, optional): Function to use for printing.
                Defaults to the built-in print.
        """

        self.__agent.model.summary(print_fn = print_function)
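

# Minimal construction sketch (illustrative only): the network size, observation size and
# hyperparameters below are assumptions, not values taken from this project. keras-rl expects
# the model input shape to be (window_length,) + observation_shape, and window_length is 1 in
# AgentHandler's SequentialMemory, hence the (1, OBSERVATION_SIZE) input below.
if __name__ == '__main__':
    from rl.policy import BoltzmannQPolicy
    from tensorflow.keras.layers import Dense, Flatten
    from tensorflow.keras.models import Sequential
    from tensorflow.keras.optimizers import Adam

    OBSERVATION_SIZE = 10   # hypothetical number of features per observation
    NR_OF_ACTIONS = 3       # e.g. buy / sell / hold

    model = Sequential([
        Flatten(input_shape = (1, OBSERVATION_SIZE)),
        Dense(32, activation = 'relu'),
        Dense(32, activation = 'relu'),
        Dense(NR_OF_ACTIONS, activation = 'linear'),
    ])

    handler = AgentHandler(model, BoltzmannQPolicy(), NR_OF_ACTIONS, Adam(learning_rate = 1e-3))
    handler.print_model_summary(logging.info)
    # Training and testing additionally require a TradingEnvironment instance, whose
    # construction is outside the scope of this module.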