Coverage for source/agent/agents/reinforcement_learning_agent.py: 81% (26 statements)

# source/agent/agents/reinforcement_learning_agent.py

# global imports
from rl.agents import DQNAgent
from rl.memory import SequentialMemory
from rl.policy import Policy
from tensorflow.keras.callbacks import Callback
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Optimizer
from typing import Any, Callable, Optional

# local imports
from source.agent import AgentBase, PerformanceTestable
from source.environment import TradingEnvironment


class ReinforcementLearningAgent(AgentBase, PerformanceTestable):
    """
    Implements a reinforcement learning agent using DQN that can be trained and tested
    in a trading environment. It provides functionality for fitting the model with
    reinforcement learning data and for performing actions based on observations.
    """

    def __init__(self, model: Model, policy: Policy, optimizer: Optimizer) -> None:
        """
        Class constructor. Creates a reinforcement learning agent with the given model,
        policy, and optimizer.

        Parameters:
            model (Model): The Keras model to use for the agent.
            policy (Policy): The policy to use for the agent's actions.
            optimizer (Optimizer): The optimizer to use for training the agent.
        """

        memory = SequentialMemory(limit = 100000, window_length = 1)
        self.__dqn_agent: DQNAgent = DQNAgent(model = model, policy = policy, memory = memory,
                                              nb_actions = model.output_shape[-1],
                                              target_model_update = 1e-2)
        self.__dqn_agent.compile(optimizer)
        # Expose the compiled Keras optimizer on the agent itself so that Keras
        # callbacks expecting an `optimizer` attribute keep working.
        self.__dqn_agent.optimizer = self.__dqn_agent.model.optimizer
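
    # Note: the action count is read from the model's output layer, so the
    # supplied Keras model must end in a layer with one unit per discrete
    # action (a linear activation, as is usual for Q-value heads). A hedged
    # construction sketch is given at the bottom of this module.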

    def load_model(self, model_path: str) -> None:
        """
        Loads the model weights from the specified file path.

        Parameters:
            model_path (str): The path to the model weights file.
        """

        self.__dqn_agent.load_weights(model_path)

    def save_model(self, model_path: str) -> None:
        """
        Saves the model weights to the specified file path.

        Parameters:
            model_path (str): The path to the model weights file.
        """

        self.__dqn_agent.save_weights(model_path)
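
    # Both methods delegate to keras-rl's weight (de)serialization, which in
    # turn wraps the underlying Keras model's save_weights/load_weights, so
    # loading requires a model with the same architecture as the one saved.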

    def print_summary(self, print_function: Optional[Callable] = print) -> None:
        """
        Prints a summary of the model architecture.

        Parameters:
            print_function (Optional[Callable]): A function to use for printing the summary.
                                                 Defaults to the print function.
        """

        self.__dqn_agent.model.summary(print_fn = print_function)
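
    # Keras calls print_fn once per summary line, so the output can be
    # redirected, e.g. agent.print_summary(logging.getLogger(__name__).info).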

    def reinforcement_learning_fit(self, environment: TradingEnvironment, nr_of_steps: int,
                                   steps_per_episode: int, callbacks: list[Callback]) -> dict[str, Any]:
        """
        Trains the reinforcement learning agent using the specified environment and parameters.

        Parameters:
            environment (TradingEnvironment): The trading environment to use for training.
            nr_of_steps (int): The total number of steps to train the agent.
            steps_per_episode (int): The number of steps per episode.
            callbacks (list[Callback]): A list of Keras callbacks to use during training.

        Returns:
            (dict[str, Any]): The training history of the agent.
        """

        return self.__dqn_agent.fit(environment, nr_of_steps, callbacks = callbacks,
                                    log_interval = steps_per_episode,
                                    nb_max_episode_steps = steps_per_episode).history
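
    # Passing steps_per_episode as both log_interval and nb_max_episode_steps
    # caps every episode at a fixed length and logs once per episode. The
    # returned dict maps keras-rl metric names (typically 'episode_reward',
    # 'nb_episode_steps', and 'nb_steps') to per-episode value lists.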

    def perform(self, observation: list[float]) -> int:
        """
        Performs the action for the agent based on the given observation.

        Parameters:
            observation (list[float]): The observation data to use for the action.

        Returns:
            (int): The index of the action selected by the agent's policy.
        """

        return self.__dqn_agent.forward(observation)
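

# A minimal end-to-end sketch, assuming keras-rl2 with a compatible
# TensorFlow 2.x release. The observation size (8) and action count (3) are
# placeholder values, and the TradingEnvironment construction is omitted
# because its constructor is project-specific.
if __name__ == "__main__":
    from rl.policy import BoltzmannQPolicy
    from tensorflow.keras.layers import Dense, Flatten
    from tensorflow.keras.models import Sequential
    from tensorflow.keras.optimizers import Adam

    # window_length = 1 in the wrapper, so observations arrive shaped (1, 8).
    model = Sequential([Flatten(input_shape = (1, 8)),
                        Dense(24, activation = "relu"),
                        Dense(3, activation = "linear")])
    agent = ReinforcementLearningAgent(model, BoltzmannQPolicy(),
                                       Adam(learning_rate = 1e-3))
    agent.print_summary()
    agent.save_model("dqn_weights.h5f")
    agent.load_model("dqn_weights.h5f")

    # Training and inference need a concrete TradingEnvironment instance
    # (hypothetical variable names below):
    # history = agent.reinforcement_learning_fit(environment, nr_of_steps = 10000,
    #                                            steps_per_episode = 500, callbacks = [])
    # action = agent.perform(observation)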