Coverage for source/agent/strategies/reinforcement_learning_strategy_handler.py: 96%

45 statements  

« prev     ^ index     » next       coverage.py v7.8.0, created at 2025-08-19 09:30 +0000

1# agent/strategies/reinforcement_learning_strategy_handler.py 

2 

# global imports
from typing import Any, Optional

import tensorflow as tf
from rl.policy import BoltzmannQPolicy, Policy
from tensorflow.keras.callbacks import Callback
from tensorflow.keras.optimizers import Adam, Optimizer

# local imports
from source.agent import AgentBase, LearningStrategyHandlerBase, ReinforcementLearningAgent
from source.environment import TradingEnvironment
from source.model import BluePrintBase, TFModelAdapter

14 

class ReinforcementLearningStrategyHandler(LearningStrategyHandlerBase):
    """
    Implements a reinforcement learning strategy handler. It provides methods for creating
    reinforcement learning agents, fitting them to the trading environment, and providing
    parameters needed for model instantiation coherent with reinforcement learning.
    """

    # global class constants
    PLOTTING_KEY: str = 'reinforcement_learning'
    # Learning rate of the Adam optimizer created when no optimizer is supplied.
    DEFAULT_LEARNING_RATE: float = 0.001

    def __init__(self, policy: Optional[Policy] = None,
                 optimizer: Optional[Optimizer] = None) -> None:
        """
        Class constructor. Initializes the policy and optimizer for the reinforcement learning agent.

        Defaults are created here, per instance, rather than in the signature. A default
        expression in the signature is evaluated only once, when the class is defined. That
        would make every handler built without arguments share the same policy and the same
        (stateful) optimizer object, and would construct the optimizer at import time.

        The same policy and optimizer objects are handed to every agent created by this
        handler instance (see create_agent).

        Parameters:
            policy (Optional[Policy]): The policy to be used by the reinforcement learning agent.
                When None, a new BoltzmannQPolicy() is created.
            optimizer (Optional[Optimizer]): The optimizer to be used for training the reinforcement
                learning agent. When None, a new Adam optimizer with learning rate
                DEFAULT_LEARNING_RATE is created.
        """

        super().__init__()
        self.__policy: Policy = policy if policy is not None else BoltzmannQPolicy()
        self.__optimizer: Optimizer = (optimizer if optimizer is not None
                                       else Adam(learning_rate = self.DEFAULT_LEARNING_RATE))

    def create_agent(self, model_blue_print: BluePrintBase,
                     trading_environment: TradingEnvironment) -> AgentBase:
        """
        Creates a reinforcement learning agent. It dynamically determines the parameters
        needed for instantiation based on the model blueprint.

        Side effect: reconfigures TensorFlow globally (eager execution is disabled and a new
        Keras backend session is installed) before the model is instantiated.

        Parameters:
            model_blue_print (BluePrintBase): The model blueprint to be used for agent creation.
            trading_environment (TradingEnvironment): The trading environment in which the agent will operate.

        Raises:
            TypeError: If the model adapter is not an instance of TFModelAdapter.

        Returns:
            (AgentBase): An instance of the agent created using the model blueprint and trading environment.
        """

        # The TF configuration must happen before the model is built, so that the model
        # is created in graph mode inside the configured session.
        self._configure_tensorflow_for_rl()

        # Each parameter name requested by the blueprint is resolved against the environment.
        kwargs = {
            parameter: self._provide_required_parameter(parameter, trading_environment)
            for parameter in model_blue_print.report_parameters_needed_for_instantiation()
        }
        model_adapter = model_blue_print.instantiate_model(**kwargs)

        if not isinstance(model_adapter, TFModelAdapter):
            raise TypeError("Model adapter must be an instance of TFModelAdapter for reinforcement learning.")

        return ReinforcementLearningAgent(model_adapter.get_model(), self.__policy, self.__optimizer)

    @staticmethod
    def _configure_tensorflow_for_rl() -> None:
        """
        Configures TensorFlow settings to be compatible with the rl library.

        Disables eager execution, enables output_all_intermediates, and installs a new
        tf.compat.v1 Session as the Keras backend session. The session has GPU memory
        growth and soft device placement enabled. Does nothing if tf exposes no `compat`
        module.
        """

        if not hasattr(tf, 'compat'):
            return

        tf.compat.v1.disable_eager_execution()
        tf.compat.v1.experimental.output_all_intermediates(True)
        config = tf.compat.v1.ConfigProto()
        config.gpu_options.allow_growth = True
        config.allow_soft_placement = True
        # NOTE: a new session is created and installed on every call; the previously
        # installed session is not closed here.
        tf.compat.v1.keras.backend.set_session(tf.compat.v1.Session(config = config))

    def fit(self, agent: ReinforcementLearningAgent, trading_environment: TradingEnvironment,
            nr_of_steps: int, nr_of_episodes: int, callbacks: list[Callback]) -> tuple[list[str], list[Any]]:
        """
        Fits the reinforcement learning agent to the trading environment.

        The keys and data returned by the base class fit are extended with PLOTTING_KEY and
        with the history returned by agent.reinforcement_learning_fit, respectively.

        Parameters:
            agent (ReinforcementLearningAgent): The reinforcement learning agent to fit.
            trading_environment (TradingEnvironment): The trading environment to use.
            nr_of_steps (int): The number of training steps to perform.
            nr_of_episodes (int): The number of training episodes to perform. Must be positive.
            callbacks (list[Callback]): List of callbacks to use during training.

        Raises:
            TypeError: If the agent is not an instance of ReinforcementLearningAgent.
            ValueError: If nr_of_episodes is not positive.

        Returns:
            (list[str], list[Any]): A tuple containing the keys and data collected during training.
        """

        if not isinstance(agent, ReinforcementLearningAgent):
            raise TypeError("Agent must be an instance of ReinforcementLearningAgent.")
        # Validated before any training work so that a bad value fails fast, instead of
        # raising ZeroDivisionError (or producing a negative step count) further down.
        if nr_of_episodes <= 0:
            raise ValueError("Number of episodes must be a positive integer.")

        keys, data = super().fit(agent, trading_environment, nr_of_steps, nr_of_episodes, callbacks)
        keys.append(self.PLOTTING_KEY)

        # Floor division, e.g. 1000 steps over 3 episodes gives 333 steps per episode.
        # NOTE(review): this is 0 when nr_of_episodes > nr_of_steps; how
        # reinforcement_learning_fit treats 0 is not visible here - confirm.
        steps_per_episode = nr_of_steps // nr_of_episodes
        reinforcement_learning_history = agent.reinforcement_learning_fit(
            trading_environment, nr_of_steps, steps_per_episode, callbacks)
        data.append(reinforcement_learning_history)

        return keys, data

    def _provide_input_shape(self, trading_environment: TradingEnvironment) -> tuple[int, ...]:
        """
        Provides the input shape for the model based on the trading environment.

        Parameters:
            trading_environment (TradingEnvironment): The trading environment to use.

        Returns:
            (tuple[int, ...]): The shape of the environment's observation space.
        """

        return trading_environment.observation_space.shape

    def _provide_output_length(self, trading_environment: TradingEnvironment) -> int:
        """
        Provides the output length for the model based on the trading environment.

        Assumes the environment's action space exposes `n` (the number of discrete actions).

        Parameters:
            trading_environment (TradingEnvironment): The trading environment to use.

        Returns:
            (int): The output length for the model, i.e. action_space.n.
        """

        return trading_environment.action_space.n

    def _provide_spatial_data_shape(self, trading_environment: TradingEnvironment) -> tuple[int, int]:
        """
        Provides the spatial data shape for the model based on the trading environment.

        Parameters:
            trading_environment (TradingEnvironment): The trading environment to use.

        Returns:
            (tuple[int, int]): The spatial data shape, as reported by
                trading_environment.get_environment_spatial_data_dimension().
        """

        return trading_environment.get_environment_spatial_data_dimension()