Coverage for source/agent/agents/reinforcement_learning_agent.py: 65% (34 statements)
coverage.py v7.8.0, created at 2025-08-04 20:03 +0000

# agent/agents/reinforcement_learning_agent.py

# global imports
import tensorflow as tf
from rl.agents import DQNAgent
from rl.memory import SequentialMemory
from rl.policy import Policy
from tensorflow.keras.callbacks import Callback
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Optimizer
from typing import Any, Callable, Optional

# local imports
from source.agent import AgentBase, PerformanceTestable
from source.environment import TradingEnvironment

class ReinforcementLearningAgent(AgentBase, PerformanceTestable):
    """
    Implements a reinforcement learning agent using DQN that can be trained and tested
    in a trading environment. It provides functionality for fitting the model with
    reinforcement learning data and for performing actions based on observations.
    """

    def __init__(self, model: Model, policy: Policy, optimizer: Optimizer) -> None:
        """
        Class constructor. Creates a reinforcement learning agent with the given model,
        policy, and optimizer.

        Parameters:
            model (Model): The Keras model to use for the agent.
            policy (Policy): The policy to use for the agent's actions.
            optimizer (Optimizer): The optimizer to use for training the agent.
        """

        # window_length = 1 means each state is a single observation (no frame stacking)
        memory = SequentialMemory(limit = 100000, window_length = 1)

        # nb_actions is inferred from the model's output layer; a target_model_update
        # below 1 selects soft target-network updates with that factor
        self.__DQNAgent: DQNAgent = DQNAgent(model, policy, memory = memory,
                                             nb_actions = model.output_shape[-1],
                                             target_model_update = 1e-2)
        self.__DQNAgent.compile(optimizer)
        self.__DQNAgent.optimizer = self.__DQNAgent.model.optimizer # For compatibility with callbacks
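
    # For example (hypothetical shapes): a model whose output_shape is (None, 3)
    # yields nb_actions = 3, i.e. three discrete trading actions for the policy
    # to choose from.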

    def load_model(self, model_path: str) -> None:
        """
        Loads the model weights from the specified file path.

        Parameters:
            model_path (str): The path to the model weights file.
        """

        self.__DQNAgent.load_weights(model_path)

    def save_model(self, model_path: str) -> None:
        """
        Saves the model weights to the specified file path.

        Parameters:
            model_path (str): The path to the model weights file.
        """

        self.__DQNAgent.save_weights(model_path)
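
    # Round-trip sketch (hypothetical path): keras-rl delegates to the Keras
    # save_weights / load_weights machinery, and its examples conventionally
    # use an .h5f extension:
    #
    #     agent.save_model('weights/dqn_trading.h5f')
    #     agent.load_model('weights/dqn_trading.h5f')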

    def print_summary(self, print_function: Optional[Callable] = print) -> None:
        """
        Prints a summary of the model architecture.

        Parameters:
            print_function (Optional[Callable]): A function to use for printing the summary.
                Defaults to the built-in print function.
        """

        self.__DQNAgent.model.summary(print_fn = print_function)
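
    # Usage sketch: any line-consuming callable works, e.g. routing the summary
    # through a logger instead of stdout:
    #
    #     agent.print_summary(print_function = logging.info)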

    def reinforcement_learning_fit(self, environment: TradingEnvironment, nr_of_steps: int,
                                   steps_per_episode: int, callbacks: list[Callback]) -> dict[str, Any]:
        """
        Trains the reinforcement learning agent using the specified environment and parameters.

        Parameters:
            environment (TradingEnvironment): The trading environment to use for training.
            nr_of_steps (int): The total number of steps to train the agent.
            steps_per_episode (int): The number of steps per episode.
            callbacks (list[Callback]): A list of Keras callbacks to use during training.

        Returns:
            (dict[str, Any]): The training history of the agent.
        """

        # Configure TensorFlow to be compatible with the rl library: keras-rl predates
        # eager execution, so fall back to graph mode and a v1 session
        if hasattr(tf, 'compat'):
            tf.compat.v1.disable_eager_execution()
            tf.compat.v1.experimental.output_all_intermediates(True)
            config = tf.compat.v1.ConfigProto()
            config.gpu_options.allow_growth = True
            config.allow_soft_placement = True
            tf.compat.v1.keras.backend.set_session(tf.compat.v1.Session(config = config))

        return self.__DQNAgent.fit(environment, nr_of_steps, callbacks = callbacks,
                                   log_interval = steps_per_episode,
                                   nb_max_episode_steps = steps_per_episode).history
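
    # Usage sketch (hypothetical step counts; `environment` is assumed to be a
    # TradingEnvironment instance): train for 10000 steps in episodes of 500 and
    # read the per-episode rewards keras-rl records in the history:
    #
    #     history = agent.reinforcement_learning_fit(environment, nr_of_steps = 10000,
    #                                                steps_per_episode = 500, callbacks = [])
    #     print(history['episode_reward'])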

    def perform(self, observation: list[float]) -> int:
        """
        Performs an action for the agent based on the given observation.

        Parameters:
            observation (list[float]): The observation data to use for the action.

        Returns:
            (int): The index of the selected action.
        """

        return self.__DQNAgent.forward(observation)
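
# Minimal construction sketch, assuming a hypothetical 4-feature observation and
# 3 discrete actions; any Keras model whose output size matches the action count
# works, since nb_actions is read from model.output_shape[-1] in the constructor.
if __name__ == '__main__':
    from rl.policy import BoltzmannQPolicy
    from tensorflow.keras.layers import Dense, Flatten
    from tensorflow.keras.models import Sequential
    from tensorflow.keras.optimizers import Adam

    # Input shape is (window_length, nr_of_features); the agent uses window_length = 1
    model = Sequential([Flatten(input_shape = (1, 4)),
                        Dense(16, activation = 'relu'),
                        Dense(3, activation = 'linear')])

    agent = ReinforcementLearningAgent(model, BoltzmannQPolicy(), Adam(learning_rate = 1e-3))
    agent.print_summary()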