Coverage for source/agent/strategies/reinforcement_learning_strategy_handler.py: 95%
37 statements
« prev ^ index » next coverage.py v7.8.0, created at 2025-08-01 20:51 +0000
« prev ^ index » next coverage.py v7.8.0, created at 2025-08-01 20:51 +0000
1# agent/strategies/reinforcement_learning_strategy_handler.py
3# global imports
4from rl.policy import BoltzmannQPolicy, Policy
5from tensorflow.keras.callbacks import Callback
6from tensorflow.keras.optimizers import Adam, Optimizer
7from typing import Any
9# local imports
10from source.agent import AgentBase, LearningStrategyHandlerBase, ReinforcementLearningAgent
11from source.environment import TradingEnvironment
12from source.model import BluePrintBase, TFModelAdapter
class ReinforcementLearningStrategyHandler(LearningStrategyHandlerBase):
    """
    Implements a reinforcement learning strategy handler. It provides methods for creating
    reinforcement learning agents, fitting them to the trading environment, and providing
    parameters needed for model instantiation coherent with reinforcement learning.
    """

    # global class constants
    PLOTTING_KEY: str = 'reinforcement_learning'

    def __init__(self, policy: 'Policy | None' = None,
                 optimizer: 'Optimizer | None' = None) -> None:
        """
        Class constructor. Initializes the policy and optimizer for the reinforcement learning agent.

        When an argument is omitted, a new default object is created for this handler:
        a BoltzmannQPolicy for the policy and Adam(learning_rate = 0.001) for the optimizer.

        Parameters:
            policy (Policy | None): The policy to be used by the reinforcement learning agent.
                Defaults to a new BoltzmannQPolicy instance.
            optimizer (Optimizer | None): The optimizer to be used for training the reinforcement
                learning agent. Defaults to a new Adam instance with learning rate 0.001.
        """

        super().__init__()
        # The defaults are built here rather than in the signature. Default-argument
        # expressions are evaluated once, when the function is defined (at import time),
        # so a default written in the signature would be one object shared by every
        # handler and every agent created from it.
        self.__policy: Policy = policy if policy is not None else BoltzmannQPolicy()
        self.__optimizer: Optimizer = optimizer if optimizer is not None else Adam(learning_rate = 0.001)

    def create_agent(self, model_blue_print: BluePrintBase,
                     trading_environment: TradingEnvironment) -> AgentBase:
        """
        Creates a reinforcement learning agent. It dynamically determines the parameters
        needed for instantiation based on the model blueprint.

        Parameters:
            model_blue_print (BluePrintBase): The model blueprint to be used for agent creation.
            trading_environment (TradingEnvironment): The trading environment in which the agent will operate.

        Raises:
            TypeError: If the model adapter is not an instance of TFModelAdapter.

        Returns:
            (AgentBase): An instance of the agent created using the model blueprint and trading environment.
        """

        required_parameters = model_blue_print.report_parameters_needed_for_instantiation()
        # Resolve each parameter the blueprint asks for, in the order it reports them.
        kwargs = {parameter: self._provide_required_parameter(parameter, trading_environment)
                  for parameter in required_parameters}
        model_adapter = model_blue_print.instantiate_model(**kwargs)

        if not isinstance(model_adapter, TFModelAdapter):
            raise TypeError("Model adapter must be an instance of TFModelAdapter for reinforcement learning.")

        # NOTE(review): every agent created by this handler receives the same policy and
        # optimizer objects. If one handler is used to create several agents, confirm that
        # ReinforcementLearningAgent does not rely on them being exclusive.
        return ReinforcementLearningAgent(model_adapter.get_model(), self.__policy, self.__optimizer)

    def fit(self, agent: ReinforcementLearningAgent, trading_environment: TradingEnvironment,
            nr_of_steps: int, nr_of_episodes: int, callbacks: list[Callback]) -> tuple[list[str], list[Any]]:
        """
        Fits the reinforcement learning agent to the trading environment.

        Parameters:
            agent (ReinforcementLearningAgent): The reinforcement learning agent to fit.
            trading_environment (TradingEnvironment): The trading environment to use.
            nr_of_steps (int): The number of training steps to perform.
            nr_of_episodes (int): The number of training episodes to perform. Must be positive.
            callbacks (list[Callback]): List of callbacks to use during training.

        Raises:
            TypeError: If the agent is not an instance of ReinforcementLearningAgent.
            ValueError: If nr_of_episodes is not positive.

        Returns:
            (list[str], list[Any]): Parallel lists of plotting keys and the data collected during
                training. This handler appends PLOTTING_KEY to the keys and the reinforcement
                learning history to the data.
        """

        if not isinstance(agent, ReinforcementLearningAgent):
            raise TypeError("Agent must be an instance of ReinforcementLearningAgent.")
        # Validate before any training work runs. Without this check,
        # nr_of_steps // nr_of_episodes would raise ZeroDivisionError, and only after
        # super().fit() had already executed.
        if nr_of_episodes <= 0:
            raise ValueError(f"nr_of_episodes must be a positive integer, got {nr_of_episodes}.")

        keys, data = super().fit(agent, trading_environment, nr_of_steps, nr_of_episodes, callbacks)
        keys.append(self.PLOTTING_KEY)

        # NOTE(review): when nr_of_steps < nr_of_episodes this evaluates to 0. Confirm how
        # reinforcement_learning_fit interprets a per-episode step limit of 0.
        steps_per_episode = nr_of_steps // nr_of_episodes
        reinforcement_learning_history = agent.reinforcement_learning_fit(
            trading_environment, nr_of_steps, steps_per_episode, callbacks)
        data.append(reinforcement_learning_history)

        return keys, data

    def _provide_input_shape(self, trading_environment: TradingEnvironment) -> tuple[int, int]:
        """
        Provides the input shape for the model based on the trading environment.

        Parameters:
            trading_environment (TradingEnvironment): The trading environment to use.

        Returns:
            (tuple[int, int]): The input shape for the model, taken directly from the
                environment's observation space.
        """

        # NOTE(review): observation_space.shape is returned as-is; the two-element tuple in the
        # annotation is assumed rather than enforced here.
        return trading_environment.observation_space.shape

    def _provide_output_length(self, trading_environment: TradingEnvironment) -> int:
        """
        Provides the output length for the model based on the trading environment.

        Parameters:
            trading_environment (TradingEnvironment): The trading environment to use.

        Returns:
            (int): The output length for the model, i.e. the action space's `n` attribute.
        """

        return trading_environment.action_space.n

    def _provide_spatial_data_shape(self, trading_environment: TradingEnvironment) -> tuple[int, int]:
        """
        Provides the spatial data shape for the model based on the trading environment.

        Parameters:
            trading_environment (TradingEnvironment): The trading environment to use.

        Returns:
            (tuple[int, int]): The spatial data shape for the model, as reported by the
                environment's get_environment_spatial_data_dimension().
        """

        return trading_environment.get_environment_spatial_data_dimension()