Coverage for source/agent/agents/reinforcement_learning_agent.py: 69%

26 statements  

coverage.py v7.8.0, created at 2025-06-06 12:00 +0000

# source/agent/agents/reinforcement_learning_agent.py

# global imports
from typing import Optional, Any, Callable

from tensorflow.keras.callbacks import Callback
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Optimizer

from rl.agents import DQNAgent
from rl.memory import SequentialMemory
from rl.policy import Policy

# local imports
from source.agent import AgentBase
from source.agent import PerformanceTestable
from source.environment import TradingEnvironment


class ReinforcementLearningAgent(AgentBase, PerformanceTestable):
    """Deep Q-learning trading agent built around keras-rl's DQNAgent."""

    def __init__(self, model: Model, policy: Policy, optimizer: Optimizer) -> None:
        """Wrap the given Keras model in a DQNAgent and compile it with the optimizer."""

        memory = SequentialMemory(limit=100000, window_length=1)
        self.__dqn_agent: DQNAgent = DQNAgent(model, policy, memory=memory,
                                              nb_actions=model.output_shape[-1],
                                              target_model_update=1e-2)
        self.__dqn_agent.compile(optimizer)

    def load_model(self, model_path: str) -> None:
        """Load the agent's weights from model_path."""

        self.__dqn_agent.load_weights(model_path)

    def save_model(self, model_path: str) -> None:
        """Save the agent's weights to model_path."""

        self.__dqn_agent.save_weights(model_path)

    def print_summary(self, print_function: Optional[Callable] = print) -> None:
        """Print a summary of the underlying Keras model."""

        self.__dqn_agent.model.summary(print_fn=print_function)

    def reinforcement_learning_fit(self, environment: TradingEnvironment, nr_of_steps: int,
                                   steps_per_episode: int,
                                   callbacks: list[Callback]) -> dict[str, Any]:
        """Train the agent on the environment and return the Keras training history."""

        return self.__dqn_agent.fit(environment, nr_of_steps, callbacks=callbacks,
                                    log_interval=steps_per_episode,
                                    nb_max_episode_steps=steps_per_episode).history

    def perform(self, observation: list[float]) -> int:
        """Return the action selected by the agent for the given observation."""

        return self.__dqn_agent.forward(observation)
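

# Usage sketch (illustrative only): the network layout, policy choice, and
# TradingEnvironment arguments below are assumptions for demonstration and
# are not part of the covered module.
if __name__ == "__main__":
    from tensorflow.keras.layers import Dense, Flatten
    from tensorflow.keras.models import Sequential
    from tensorflow.keras.optimizers import Adam
    from rl.policy import BoltzmannQPolicy

    # A simple feed-forward Q-network: one output unit per discrete action,
    # input shaped (window_length=1, nr_of_features) to match SequentialMemory.
    q_network = Sequential([
        Flatten(input_shape=(1, 10)),  # 10 observation features (assumed)
        Dense(32, activation="relu"),
        Dense(3),                      # e.g. buy / hold / sell (assumed)
    ])

    agent = ReinforcementLearningAgent(q_network, BoltzmannQPolicy(),
                                       Adam(learning_rate=1e-3))
    environment = TradingEnvironment(...)  # constructor arguments depend on that module

    history = agent.reinforcement_learning_fit(environment, nr_of_steps=10000,
                                               steps_per_episode=500, callbacks=[])
    agent.save_model("dqn_weights.h5f")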