Coverage for source/agent/agents/reinforcement_learning_agent.py: 69%
26 statements
# agent/strategies/reinforcement_learning_agent.py

# global imports
from typing import Any, Callable, Optional

from rl.agents import DQNAgent
from rl.memory import SequentialMemory
from rl.policy import Policy
from tensorflow.keras.callbacks import Callback
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Optimizer

# local imports
from source.agent import AgentBase, PerformanceTestable
from source.environment import TradingEnvironment
class ReinforcementLearningAgent(AgentBase, PerformanceTestable):
    """Trading agent that wraps a keras-rl DQNAgent around a Keras model."""

    def __init__(self, model: Model, policy: Policy, optimizer: Optimizer) -> None:
        """Build and compile the underlying DQN agent.

        The number of actions is read from the model's output layer, and a
        sequential replay memory with a window length of 1 is used.
        """
        memory = SequentialMemory(limit=100000, window_length=1)
        self.__DQNAgent: DQNAgent = DQNAgent(model, policy, memory=memory,
                                             nb_actions=model.output_shape[-1],
                                             target_model_update=1e-2)
        self.__DQNAgent.compile(optimizer)

    def load_model(self, model_path: str) -> None:
        """Load the agent's network weights from model_path."""
        self.__DQNAgent.load_weights(model_path)

    def save_model(self, model_path: str) -> None:
        """Save the agent's network weights to model_path."""
        self.__DQNAgent.save_weights(model_path)

    def print_summary(self, print_function: Optional[Callable] = print) -> None:
        """Print a summary of the underlying Keras model via print_function."""
        self.__DQNAgent.model.summary(print_fn=print_function)

    def reinforcement_learning_fit(self, environment: TradingEnvironment, nr_of_steps: int,
                                   steps_per_episode: int, callbacks: list[Callback]) -> dict[str, Any]:
        """Train the agent on the environment for nr_of_steps steps and return the training history."""
        return self.__DQNAgent.fit(environment, nr_of_steps, callbacks=callbacks,
                                   log_interval=steps_per_episode,
                                   nb_max_episode_steps=steps_per_episode).history

    def perform(self, observation: list[float]) -> int:
        """Return the action selected by the agent for the given observation."""
        return self.__DQNAgent.forward(observation)
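
A minimal usage sketch (not part of the covered module), assuming a Gym-style TradingEnvironment with a flat float observation space and a discrete action space; the network shape, the feature/action counts, and the BoltzmannQPolicy choice are illustrative assumptions, not taken from this file:

from rl.policy import BoltzmannQPolicy
from tensorflow.keras.layers import Dense, Flatten
from tensorflow.keras.models import Sequential
from tensorflow.keras.optimizers import Adam

from source.environment import TradingEnvironment

# Hypothetical dimensions: 10 observation features, 3 actions (e.g. buy/hold/sell).
NB_FEATURES, NB_ACTIONS = 10, 3

# With window_length=1, keras-rl feeds observations with an extra window axis,
# hence the (1, NB_FEATURES) input shape followed by Flatten.
model = Sequential([
    Flatten(input_shape=(1, NB_FEATURES)),
    Dense(32, activation="relu"),
    Dense(NB_ACTIONS, activation="linear"),
])

agent = ReinforcementLearningAgent(model, BoltzmannQPolicy(), Adam(learning_rate=1e-3))
agent.print_summary()

environment = TradingEnvironment(...)  # constructor arguments depend on the local class
history = agent.reinforcement_learning_fit(environment, nr_of_steps=10000,
                                           steps_per_episode=500, callbacks=[])
action = agent.perform(observation=[0.0] * NB_FEATURES)
agent.save_model("dqn_weights.h5f")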