Coverage for source/agent/agents/reinforcement_learning_agent.py: 81%

26 statements  

« prev     ^ index     » next       coverage.py v7.8.0, created at 2025-08-01 20:51 +0000

1# agent/agents/reinforcement_learning_agent.py 

2 

3# global imports 

4import rl 

5from rl.agents import DQNAgent 

6from rl.memory import SequentialMemory 

7from rl.policy import Policy 

8from tensorflow.keras.callbacks import Callback 

9from tensorflow.keras.models import Model 

10from tensorflow.keras.optimizers import Optimizer 

11from typing import Any, Callable, Optional 

12 

13# local imports 

14from source.agent import AgentBase, PerformanceTestable 

15from source.environment import TradingEnvironment 

16 

class ReinforcementLearningAgent(AgentBase, PerformanceTestable):
    """
    Deep-Q-Network (DQN) based reinforcement learning agent for trading.

    Wraps a keras-rl DQNAgent so it can be trained inside a TradingEnvironment,
    persisted to and restored from disk, and queried for an action given a
    market observation.
    """

    def __init__(self, model: Model, policy: Policy, optimizer: Optimizer) -> None:
        """
        Class constructor. Builds the underlying DQN agent from the supplied
        Keras model, action-selection policy and training optimizer.

        Parameters:
            model (Model): The Keras model to use for the agent.
            policy (Policy): The policy to use for the agent's actions.
            optimizer (Optimizer): The optimizer to use for training the agent.
        """

        replay_buffer = SequentialMemory(limit = 100000, window_length = 1)
        # The number of discrete actions is taken from the model's final layer.
        self.__DQNAgent: DQNAgent = DQNAgent(model, policy, memory = replay_buffer,
                                             nb_actions = model.output_shape[-1],
                                             target_model_update = 1e-2)
        self.__DQNAgent.compile(optimizer)
        # Mirror the compiled optimizer on the agent itself so Keras callbacks
        # that expect an `optimizer` attribute keep working.
        self.__DQNAgent.optimizer = self.__DQNAgent.model.optimizer

    def load_model(self, model_path: str) -> None:
        """
        Loads the model weights from the specified file path.

        Parameters:
            model_path (str): The path to the model weights file.
        """

        self.__DQNAgent.load_weights(model_path)

    def save_model(self, model_path: str) -> None:
        """
        Saves the model weights to the specified file path.

        Parameters:
            model_path (str): The path to the model weights file.
        """

        self.__DQNAgent.save_weights(model_path)

    def print_summary(self, print_function: Optional[Callable] = print) -> None:
        """
        Prints a summary of the model architecture.

        Parameters:
            print_function (Optional[Callable]): A function to use for printing the summary.
                                                 Defaults to the print function.
        """

        self.__DQNAgent.model.summary(print_fn = print_function)

    def reinforcement_learning_fit(self, environment: TradingEnvironment, nr_of_steps: int,
                                   steps_per_episode: int, callbacks: list[Callback]) -> dict[str, Any]:
        """
        Trains the reinforcement learning agent using the specified environment and parameters.

        Parameters:
            environment (TradingEnvironment): The trading environment to use for training.
            nr_of_steps (int): The total number of steps to train the agent.
            steps_per_episode (int): The number of steps per episode.
            callbacks (list[Callback]): A list of Keras callbacks to use during training.

        Returns:
            (dict[str, Any]): The training history of the agent.
        """

        training_run = self.__DQNAgent.fit(environment, nr_of_steps,
                                           callbacks = callbacks,
                                           log_interval = steps_per_episode,
                                           nb_max_episode_steps = steps_per_episode)
        return training_run.history

    def perform(self, observation: list[float]) -> int:
        """
        Performs the action for the agent based on the given observation.

        Parameters:
            observation (list[float]): The observation data to use for the action.

        Returns:
            (int): The result of the action.
        """

        return self.__DQNAgent.forward(observation)