Coverage for source/agent/strategies/reinforcement_learning_strategy_handler.py: 95%

37 statements  

« prev     ^ index     » next       coverage.py v7.8.0, created at 2025-07-30 20:59 +0000

1# agent/strategies/reinforcement_learning_strategy_handler.py 

2 

3# global imports 

4from rl.policy import BoltzmannQPolicy, Policy 

5from tensorflow.keras.callbacks import Callback 

6from tensorflow.keras.optimizers import Adam, Optimizer 

7from typing import Any 

8 

9# local imports 

10from source.agent import AgentBase, LearningStrategyHandlerBase, ReinforcementLearningAgent 

11from source.environment import TradingEnvironment 

12from source.model import BluePrintBase, TFModelAdapter 

13 

class ReinforcementLearningStrategyHandler(LearningStrategyHandlerBase):
    """
    Implements a reinforcement learning strategy handler. It provides methods for creating
    reinforcement learning agents, fitting them to the trading environment, and providing
    parameters needed for model instantiation coherent with reinforcement learning.
    """

    # global class constants
    PLOTTING_KEY: str = 'reinforcement_learning'
    # Learning rate of the Adam optimizer that is built when no optimizer is supplied.
    DEFAULT_LEARNING_RATE: float = 0.001

    def __init__(self, policy: 'Policy | None' = None,
                 optimizer: 'Optimizer | None' = None) -> None:
        """
        Class constructor. Initializes the policy and optimizer for the reinforcement learning agent.

        The defaults are created here, per handler instance, instead of in the signature:
        default-argument expressions are evaluated only once (when the function is defined),
        so instantiating BoltzmannQPolicy()/Adam(...) in the signature would create them at
        import time and silently share the same objects between every handler constructed
        without explicit arguments.

        Parameters:
            policy (Policy | None): The policy to be used by the reinforcement learning agent.
                When None, a new BoltzmannQPolicy() is created.
            optimizer (Optimizer | None): The optimizer to be used for training the reinforcement
                learning agent. When None, a new Adam optimizer with learning rate
                DEFAULT_LEARNING_RATE is created.
        """

        super().__init__()
        if policy is None:
            policy = BoltzmannQPolicy()
        if optimizer is None:
            optimizer = Adam(learning_rate = self.DEFAULT_LEARNING_RATE)
        self.__policy: Policy = policy
        self.__optimizer: Optimizer = optimizer

    def create_agent(self, model_blue_print: BluePrintBase,
                     trading_environment: TradingEnvironment) -> AgentBase:
        """
        Creates a reinforcement learning agent. It dynamically determines the parameters
        needed for instantiation based on the model blueprint.

        Every parameter name reported by the blueprint is resolved through the inherited
        _provide_required_parameter() using the trading environment, and the resulting
        keyword arguments are passed to the blueprint's instantiate_model().

        Parameters:
            model_blue_print (BluePrintBase): The model blueprint to be used for agent creation.
            trading_environment (TradingEnvironment): The trading environment in which the agent will operate.

        Raises:
            TypeError: If the model adapter is not an instance of TFModelAdapter.

        Returns:
            (AgentBase): A ReinforcementLearningAgent wrapping the instantiated Keras model,
                together with this handler's policy and optimizer.
        """

        parameters_needed_for_instantiation = model_blue_print.report_parameters_needed_for_instantiation()
        kwargs = {parameter: self._provide_required_parameter(parameter, trading_environment)
                  for parameter in parameters_needed_for_instantiation}
        model_adapter = model_blue_print.instantiate_model(**kwargs)

        # Reinforcement learning needs the underlying TF/Keras model, so other adapters are rejected.
        if not isinstance(model_adapter, TFModelAdapter):
            raise TypeError("Model adapter must be an instance of TFModelAdapter for reinforcement learning.")

        # NOTE(review): every agent created by the same handler receives the same policy and
        # optimizer objects; confirm that reusing one Keras optimizer across models is intended.
        return ReinforcementLearningAgent(model_adapter.get_model(), self.__policy, self.__optimizer)

    def fit(self, agent: ReinforcementLearningAgent, trading_environment: TradingEnvironment,
            nr_of_steps: int, nr_of_episodes: int, callbacks: list[Callback]) -> tuple[list[str], list[Any]]:
        """
        Fits the reinforcement learning agent to the trading environment.

        The base-class fit is run first; its keys and data are then extended with
        PLOTTING_KEY and the history returned by the agent's reinforcement_learning_fit().

        Parameters:
            agent (ReinforcementLearningAgent): The reinforcement learning agent to fit.
            trading_environment (TradingEnvironment): The trading environment to use.
            nr_of_steps (int): The number of training steps to perform.
            nr_of_episodes (int): The number of training episodes to perform. Must be positive.
            callbacks (list[Callback]): List of callbacks to use during training.

        Raises:
            TypeError: If the agent is not an instance of ReinforcementLearningAgent.
            ValueError: If nr_of_episodes is not positive.

        Returns:
            (list[str], list[Any]): A tuple containing the keys and data collected during training;
                the last key is PLOTTING_KEY and the last data entry is the reinforcement learning history.
        """

        if not isinstance(agent, ReinforcementLearningAgent):
            raise TypeError("Agent must be an instance of ReinforcementLearningAgent.")

        # Validate before any training starts: the episode-length division below would
        # otherwise raise ZeroDivisionError (or yield a negative length) only after the
        # base-class fit has already done its work.
        if nr_of_episodes <= 0:
            raise ValueError(f"nr_of_episodes must be a positive integer, got {nr_of_episodes}.")

        keys, data = super().fit(agent, trading_environment, nr_of_steps, nr_of_episodes, callbacks)

        # Floor division: the remainder of nr_of_steps / nr_of_episodes is dropped from the episode length.
        # NOTE(review): when nr_of_steps < nr_of_episodes this is 0 — confirm how
        # reinforcement_learning_fit treats a zero episode length.
        steps_per_episode = nr_of_steps // nr_of_episodes
        reinforcement_learning_history = agent.reinforcement_learning_fit(trading_environment, nr_of_steps,
                                                                          steps_per_episode, callbacks)

        # Append key and data together so both sequences are extended consistently.
        keys.append(self.PLOTTING_KEY)
        data.append(reinforcement_learning_history)

        return keys, data

    def _provide_input_shape(self, trading_environment: TradingEnvironment) -> tuple[int, int]:
        """
        Provides the input shape for the model based on the trading environment.

        Parameters:
            trading_environment (TradingEnvironment): The trading environment to use.

        Returns:
            (tuple[int, int]): The shape of the environment's observation space, used as the
                model input shape.
        """

        return trading_environment.observation_space.shape

    def _provide_output_length(self, trading_environment: TradingEnvironment) -> int:
        """
        Provides the output length for the model based on the trading environment.

        Parameters:
            trading_environment (TradingEnvironment): The trading environment to use.

        Returns:
            (int): The size of the environment's action space (`action_space.n`), used as the
                number of model outputs.
        """

        return trading_environment.action_space.n

    def _provide_spatial_data_shape(self, trading_environment: TradingEnvironment) -> tuple[int, int]:
        """
        Provides the spatial data shape for the model based on the trading environment.

        Parameters:
            trading_environment (TradingEnvironment): The trading environment to use.

        Returns:
            (tuple[int, int]): The value reported by the environment's
                get_environment_spatial_data_dimension().
        """

        return trading_environment.get_environment_spatial_data_dimension()