Coverage for source/agent/agent_handler.py: 97%

79 statements  

# agent/agent_handler.py

import io
import logging
import random
from contextlib import redirect_stdout
from typing import Callable, Optional

import numpy as np
from rl.agents import DQNAgent
from rl.memory import SequentialMemory
from rl.policy import Policy
from tensorflow.keras.callbacks import Callback
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Optimizer

from source.environment import TradingEnvironment


class AgentHandler:
    """
    Encapsulates a DQNAgent together with its associated training and testing
    procedures, providing a simplified interface for managing deep reinforcement
    learning agent operations in the trading environment context.
    """

    def __init__(self, model: Model, policy: Policy, nr_of_actions: int, optimizer: Optimizer) -> None:
        """
        Initializes the AgentHandler with the given model, policy and action space parameters.

        Parameters:
            model (Model): Keras model used by the agent to learn from the environment.
            policy (Policy): Policy that determines the action selection strategy.
            nr_of_actions (int): Number of possible actions the agent can take.
            optimizer (Optimizer): Keras optimizer used for model training.
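
        Example (illustrative sketch; the architecture, policy, learning rate, and
        observation_size below are assumptions, not this project's actual
        configuration). Because the agent uses window_length = 1, the model input
        shape is (1, observation_size), and the final layer must output exactly
        nr_of_actions values:

            model = Sequential([Flatten(input_shape = (1, observation_size)),
                                Dense(32, activation = 'relu'),
                                Dense(nr_of_actions, activation = 'linear')])
            handler = AgentHandler(model, EpsGreedyQPolicy(), nr_of_actions,
                                   Adam(learning_rate = 1e-3))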

36 """ 

37 

38 self.__trained: bool = False 

39 self.__agent: DQNAgent = rl.agents.DQNAgent(model, policy, memory = SequentialMemory(limit = 100000, window_length = 1), 

40 nb_actions = nr_of_actions, target_model_update = 1e-2) 

41 self.__agent.compile(optimizer) 

42 

    def train_agent(self, environment: TradingEnvironment, nr_of_steps: int, steps_per_episode: int,
                    callbacks: Optional[list[Callback]] = None, weights_load_path: Optional[str] = None,
                    weights_save_path: Optional[str] = None) -> dict:
        """
        Trains the agent on the provided environment.

        Parameters:
            environment (TradingEnvironment): Trading environment to train on.
            nr_of_steps (int): Total number of training steps.
            steps_per_episode (int): Maximum number of steps per episode.
            callbacks (list[Callback], optional): Keras callbacks invoked during training.
            weights_load_path (str, optional): Path to load pre-trained weights from.
            weights_save_path (str, optional): Path to save the weights to after training.

        Returns:
            dict: Dictionary containing training history metrics.
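
        Example (hypothetical step counts and save path, for illustration;
        keras-rl's fit history exposes per-episode metrics such as 'episode_reward'):

            history = handler.train_agent(environment, nr_of_steps = 50000,
                                          steps_per_episode = 500,
                                          weights_save_path = 'weights/dqn.h5f')
            logging.info(history['episode_reward'])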

59 """ 

60 

61 if weights_load_path is not None: 

62 self.__agent.load_weights(weights_load_path) 

63 

64 captured_output = io.StringIO() 

65 with redirect_stdout(captured_output): #TODO: Create an callback logger 

66 history = self.__agent.fit(environment, nr_of_steps, callbacks = callbacks, 

67 log_interval = steps_per_episode, nb_max_episode_steps = steps_per_episode) 

68 

69 for line in captured_output.getvalue().split('\n'): 

70 if line.strip(): 

71 logging.info(line) 

72 self.__trained = True 

73 

74 if weights_save_path is not None: 

75 self.__agent.save_weights(weights_save_path) 

76 

77 return history.history 

78 

    def test_agent(self, environment: TradingEnvironment, repeat: int = 1) -> dict:
        """
        Tests the trained agent on the provided environment.

        Testing runs the agent on the environment from random starting points
        and records performance metrics such as asset value changes and rewards.

        Parameters:
            environment (TradingEnvironment): Trading environment to test on.
            repeat (int, optional): Number of test episodes to run. Defaults to 1.

        Returns:
            dict: Dictionary containing test metrics for each test episode,
                including asset values, rewards, and trading performance statistics.
                Returns an empty dict if the agent is not trained.
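
        Example (hypothetical usage; the keys match the structure built below):

            results = handler.test_agent(environment, repeat = 3)
            final_relative_assets = results[0]['assets_values'][-1]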

94 """ 

95 

96 if not self.__trained: 

97 logging.error('Agent is not trained yet! Train the agent before testing.') 

98 return {} 

99 

        test_history = {}
        env_length = environment.get_environment_length()
        for i in range(repeat):
            test_history[i] = {}
            assets_values = []
            reward_values = []
            infos = []
            iterations = []
            done = False

            # Start each episode at a random point in the first half of the
            # environment, leaving room for the observation window before it.
            window_size = environment.get_trading_consts().WINDOW_SIZE
            current_iteration = random.randint(window_size, int(env_length / 2))
            environment.reset(current_iteration)
            state = environment.state
            trading_data = environment.get_trading_data()
            current_assets = trading_data.current_budget + trading_data.currently_invested
            iterations.append(current_iteration)
            assets_values.append(current_assets)
            reward_values.append(0)
            infos.append({})

            while not done:
                next_action = self.__agent.forward(state)
                state, reward, done, info = environment.step(next_action)

                # Record a data point only when total assets change or the episode ends.
                if current_assets != info['current_budget'] + info['currently_invested'] or done:
                    current_iteration = environment.current_iteration
                    current_assets = info['current_budget'] + info['currently_invested']
                    iterations.append(current_iteration)
                    assets_values.append(current_assets)
                    reward_values.append(reward)
                    infos.append(info)

            # Average change in total assets per iteration over the episode.
            solvency_coefficient = (assets_values[-1] - assets_values[0]) / (iterations[-1] - iterations[0])
            # Normalize assets and market prices to their starting values so the two series are comparable.
            assets_values = (np.array(assets_values) / assets_values[0]).tolist()
            currency_prices = environment.get_data_for_iteration(['close'], iterations[0], iterations[-1])
            currency_prices = (np.array(currency_prices) / currency_prices[0]).tolist()

            test_history[i]['assets_values'] = assets_values
            test_history[i]['reward_values'] = reward_values
            test_history[i]['currency_prices'] = currency_prices
            test_history[i]['infos'] = infos
            test_history[i]['iterations'] = iterations
            test_history[i]['solvency_coefficient'] = solvency_coefficient

        return test_history

    def print_model_summary(self, print_function: Callable[..., None] = print) -> None:
        """
        Prints the model summary using the provided print function.

        Parameters:
            print_function (Callable, optional): Function used for printing.
                Defaults to the built-in print.
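
        Example (routing the summary through the logging module):

            handler.print_model_summary(logging.info)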

154 """ 

155 

156 self.__agent.model.summary(print_fn = print_function)