Coverage for source/agent/strategies/reinforcement_learning_strategy_handler.py: 56%

25 statements  

« prev     ^ index     » next       coverage.py v7.8.0, created at 2025-06-06 12:00 +0000

1# agent/strategies/reinforcement_learning_strategy_handler.py 

2 

3# global imports 

4from typing import Any 

5from tensorflow.keras.optimizers import Optimizer, Adam 

6from tensorflow.keras.callbacks import Callback 

7from rl.policy import Policy, BoltzmannQPolicy 

8 

9# local imports 

10from source.agent import LearningStrategyHandlerBase 

11from source.agent import AgentBase 

12from source.agent import ReinforcementLearningAgent 

13from source.environment import TradingEnvironment 

14from source.model import BluePrintBase 

15 

class ReinforcementLearningStrategyHandler(LearningStrategyHandlerBase):
    """
    Learning strategy handler for reinforcement learning agents.

    Stores an action-selection policy and an optimizer, hands both to every
    ReinforcementLearningAgent it creates, and delegates training to the
    agent's reinforcement_learning_fit method.
    """

    PLOTTING_KEY: str = 'reinforcement_learning'

    def __init__(self, policy: 'Policy | None' = None,
                 optimizer: 'Optimizer | None' = None) -> None:
        """
        Initializes the handler with a policy and an optimizer.

        Parameters:
            policy: Policy passed to created agents. When omitted (or None), a new
                BoltzmannQPolicy() is created for this handler instance.
            optimizer: Optimizer passed to created agents. When omitted (or None), a
                new Adam(learning_rate = 0.001) is created for this handler instance.
        """

        # Defaults are built here instead of in the signature. A default
        # expression in the signature is evaluated only once, when the class
        # is defined. Every handler would then share the same policy/optimizer
        # objects, and a Keras optimizer is stateful. The objects would also be
        # constructed as a side effect of importing this module.
        self.__policy: Policy = policy if policy is not None else BoltzmannQPolicy()
        self.__optimizer: Optimizer = (optimizer if optimizer is not None
                                       else Adam(learning_rate = 0.001))

    def create_agent(self, model_blue_print: BluePrintBase,
                     trading_environment: TradingEnvironment) -> AgentBase:
        """
        Builds a ReinforcementLearningAgent whose model is sized for the given environment.

        Parameters:
            model_blue_print: Blueprint whose instantiate_model() produces a model adapter.
            trading_environment: Environment providing the observation space shape, the
                action count (action_space.n, presumably a discrete action space) and
                the spatial data dimension used to size the model.

        Returns:
            A ReinforcementLearningAgent wrapping the instantiated model together with
            this handler's policy and optimizer.
        """

        observation_space_shape = trading_environment.observation_space.shape
        nr_of_actions = trading_environment.action_space.n
        spatial_data_shape = trading_environment.get_environment_spatial_data_dimension()
        model_adapter = model_blue_print.instantiate_model(observation_space_shape, nr_of_actions,
                                                           spatial_data_shape)

        # NOTE(review): every agent created by the same handler receives the same
        # policy and optimizer objects. Confirm that sharing them across agents is intended.
        return ReinforcementLearningAgent(model_adapter.get_model(), self.__policy, self.__optimizer)

    def fit(self, agent: ReinforcementLearningAgent, trading_environment: TradingEnvironment,
            nr_of_steps: int, nr_of_episodes: int,
            callbacks: list[Callback]) -> tuple[list[str], list[Any]]:
        """
        Trains a reinforcement learning agent on the given environment.

        The total step budget is split evenly across episodes using integer division.
        The whole budget and the per-episode step count are both passed to the agent.

        Parameters:
            agent: Agent to train; must be a ReinforcementLearningAgent.
            trading_environment: Environment the agent is trained on.
            nr_of_steps: Total number of training steps.
            nr_of_episodes: Number of episodes the step budget is divided into.
            callbacks: Keras callbacks forwarded to the agent's training call.

        Returns:
            Two parallel lists: the plotting keys ([PLOTTING_KEY]) and the matching
            training results ([value returned by agent.reinforcement_learning_fit]).

        Raises:
            TypeError: If agent is not a ReinforcementLearningAgent.
            ValueError: If nr_of_episodes is not positive.
        """

        if not isinstance(agent, ReinforcementLearningAgent):
            raise TypeError("Agent must be an instance of ReinforcementLearningAgent.")

        # Without this check, 0 raises an opaque ZeroDivisionError below and a
        # negative value silently yields a negative per-episode step count.
        if nr_of_episodes <= 0:
            raise ValueError(f"nr_of_episodes must be a positive integer, got {nr_of_episodes}.")

        # NOTE(review): when nr_of_steps < nr_of_episodes this is 0. Confirm how
        # reinforcement_learning_fit interprets a zero per-episode step count.
        steps_per_episode = nr_of_steps // nr_of_episodes
        training_result = agent.reinforcement_learning_fit(trading_environment, nr_of_steps,
                                                           steps_per_episode, callbacks)

        return [ReinforcementLearningStrategyHandler.PLOTTING_KEY], [training_result]