Coverage for source/agent/strategies/reinforcement_learning_strategy_handler.py: 96%
45 statements
« prev ^ index » next coverage.py v7.8.0, created at 2025-08-19 10:22 +0000
« prev ^ index » next coverage.py v7.8.0, created at 2025-08-19 10:22 +0000
1# agent/strategies/reinforcement_learning_strategy_handler.py
3# global imports
4import tensorflow as tf
5from rl.policy import BoltzmannQPolicy, Policy
6from tensorflow.keras.callbacks import Callback
7from tensorflow.keras.optimizers import Adam, Optimizer
8from typing import Any
10# local imports
11from source.agent import AgentBase, LearningStrategyHandlerBase, ReinforcementLearningAgent
12from source.environment import TradingEnvironment
13from source.model import BluePrintBase, TFModelAdapter
class ReinforcementLearningStrategyHandler(LearningStrategyHandlerBase):
    """
    Implements a reinforcement learning strategy handler. It provides methods for creating
    reinforcement learning agents, fitting them to the trading environment, and providing
    parameters needed for model instantiation coherent with reinforcement learning.
    """

    # global class constants
    PLOTTING_KEY: str = 'reinforcement_learning'
    # Learning rate of the Adam optimizer created when no optimizer is supplied.
    DEFAULT_LEARNING_RATE: float = 0.001

    def __init__(self, policy: 'Policy | None' = None,
                 optimizer: 'Optimizer | None' = None) -> None:
        """
        Class constructor. Initializes the policy and optimizer for the reinforcement learning agent.

        The defaults are built here, per instance, instead of as default-argument values:
        Python evaluates default arguments once, when the ``def`` statement executes, so a
        default like ``Adam(...)`` would be created at import time and the very same
        policy/optimizer objects would be shared by every handler and every agent built from it.

        Parameters:
            policy (Policy | None): The policy to be used by the reinforcement learning agent.
                If None, a new BoltzmannQPolicy is created for this handler.
            optimizer (Optimizer | None): The optimizer to be used for training the reinforcement
                learning agent. If None, a new Adam optimizer with learning rate
                DEFAULT_LEARNING_RATE is created for this handler.
        """

        super().__init__()
        self.__policy: Policy = policy if policy is not None else BoltzmannQPolicy()
        self.__optimizer: Optimizer = (optimizer if optimizer is not None
                                       else Adam(learning_rate = self.DEFAULT_LEARNING_RATE))

    @staticmethod
    def _configure_tensorflow_for_rl() -> None:
        """
        Configures TensorFlow settings to be compatible with the rl library.

        When ``tf.compat`` is available, this disables eager execution, enables
        ``output_all_intermediates``, and installs a newly created TF1-style session
        (GPU memory growth and soft device placement enabled) as the Keras backend session.
        A new session is constructed on every call; previously installed sessions are not
        closed here. Does nothing if the installed TensorFlow has no ``compat`` module.
        """

        if not hasattr(tf, 'compat'):
            return

        tf.compat.v1.disable_eager_execution()
        tf.compat.v1.experimental.output_all_intermediates(True)
        config = tf.compat.v1.ConfigProto()
        # Grow GPU memory usage on demand instead of reserving the whole device up front.
        config.gpu_options.allow_growth = True
        # Let TensorFlow fall back to another device (e.g. CPU) when an op cannot run on the requested one.
        config.allow_soft_placement = True
        tf.compat.v1.keras.backend.set_session(tf.compat.v1.Session(config = config))

    def create_agent(self, model_blue_print: BluePrintBase,
                     trading_environment: TradingEnvironment) -> AgentBase:
        """
        Creates a reinforcement learning agent. It dynamically determines the parameters
        needed for instantiation based on the model blueprint.

        Parameters:
            model_blue_print (BluePrintBase): The model blueprint to be used for agent creation.
            trading_environment (TradingEnvironment): The trading environment in which the agent will operate.

        Raises:
            TypeError: If the model adapter is not an instance of TFModelAdapter.

        Returns:
            (AgentBase): An instance of the agent created using the model blueprint and trading environment.
        """

        # TensorFlow must be configured before the model is built.
        self._configure_tensorflow_for_rl()

        # Ask the blueprint which constructor arguments it needs and resolve each one
        # from the trading environment via the base-class dispatcher.
        parameters_needed_for_instantiation = model_blue_print.report_parameters_needed_for_instantiation()
        kwargs = {parameter: self._provide_required_parameter(parameter, trading_environment)
                  for parameter in parameters_needed_for_instantiation}
        model_adapter = model_blue_print.instantiate_model(**kwargs)

        if not isinstance(model_adapter, TFModelAdapter):
            raise TypeError("Model adapter must be an instance of TFModelAdapter for reinforcement learning.")

        return ReinforcementLearningAgent(model_adapter.get_model(), self.__policy, self.__optimizer)

    def fit(self, agent: ReinforcementLearningAgent, trading_environment: TradingEnvironment,
            nr_of_steps: int, nr_of_episodes: int,
            callbacks: list[Callback]) -> tuple[list[str], list[Any]]:
        """
        Fits the reinforcement learning agent to the trading environment.

        Parameters:
            agent (ReinforcementLearningAgent): The reinforcement learning agent to fit.
            trading_environment (TradingEnvironment): The trading environment to use.
            nr_of_steps (int): The number of training steps to perform.
            nr_of_episodes (int): The number of training episodes to perform; must be positive.
                The steps per episode are computed as nr_of_steps // nr_of_episodes.
            callbacks (list[Callback]): List of callbacks to use during training.

        Raises:
            TypeError: If the agent is not an instance of ReinforcementLearningAgent.
            ValueError: If nr_of_episodes is not a positive number.

        Returns:
            (list[str], list[Any]): The keys and data produced by the base class, with
                PLOTTING_KEY appended to the keys and the reinforcement learning history
                appended to the data.
        """

        if not isinstance(agent, ReinforcementLearningAgent):
            raise TypeError("Agent must be an instance of ReinforcementLearningAgent.")

        # Validate before any training work starts. Without this check, zero would raise a
        # bare ZeroDivisionError below, and a negative value would yield negative steps per episode.
        if nr_of_episodes <= 0:
            raise ValueError(f"nr_of_episodes must be positive, got {nr_of_episodes}.")

        keys, data = super().fit(agent, trading_environment, nr_of_steps, nr_of_episodes, callbacks)

        steps_per_episode = nr_of_steps // nr_of_episodes
        reinforcement_learning_history = agent.reinforcement_learning_fit(trading_environment, nr_of_steps,
                                                                          steps_per_episode, callbacks)

        # Append the key and its data together so they are only added once training succeeded.
        keys.append(self.PLOTTING_KEY)
        data.append(reinforcement_learning_history)

        return keys, data

    def _provide_input_shape(self, trading_environment: TradingEnvironment) -> tuple[int, int]:
        """
        Provides the input shape for the model based on the trading environment.

        Parameters:
            trading_environment (TradingEnvironment): The trading environment to use.

        Returns:
            (tuple[int, int]): The input shape for the model, taken from the
                environment's observation space shape.
        """

        return trading_environment.observation_space.shape

    def _provide_output_length(self, trading_environment: TradingEnvironment) -> int:
        """
        Provides the output length for the model based on the trading environment.

        Parameters:
            trading_environment (TradingEnvironment): The trading environment to use.

        Returns:
            (int): The output length for the model, i.e. the ``n`` attribute of the
                environment's action space.
        """

        return trading_environment.action_space.n

    def _provide_spatial_data_shape(self, trading_environment: TradingEnvironment) -> tuple[int, int]:
        """
        Provides the spatial data shape for the model based on the trading environment.

        Parameters:
            trading_environment (TradingEnvironment): The trading environment to use.

        Returns:
            (tuple[int, int]): The spatial data shape for the model, as reported by
                the environment's get_environment_spatial_data_dimension().
        """

        return trading_environment.get_environment_spatial_data_dimension()