Coverage for source/agent/strategies/reinforcement_learning_strategy_handler.py: 56%
25 statements
« prev ^ index » next coverage.py v7.8.0, created at 2025-06-06 12:00 +0000
« prev ^ index » next coverage.py v7.8.0, created at 2025-06-06 12:00 +0000
1# agent/strategies/reinforcement_learning_strategy_handler.py
3# global imports
4from typing import Any
5from tensorflow.keras.optimizers import Optimizer, Adam
6from tensorflow.keras.callbacks import Callback
7from rl.policy import Policy, BoltzmannQPolicy
9# local imports
10from source.agent import LearningStrategyHandlerBase
11from source.agent import AgentBase
12from source.agent import ReinforcementLearningAgent
13from source.environment import TradingEnvironment
14from source.model import BluePrintBase
class ReinforcementLearningStrategyHandler(LearningStrategyHandlerBase):
    """Learning strategy that builds and trains a ReinforcementLearningAgent.

    The handler keeps one policy and one optimizer, passes them to every
    agent it creates, and delegates training to the agent's
    ``reinforcement_learning_fit`` method.
    """

    # Key returned by ``fit`` next to the training result (presumably used to
    # label plots of this strategy - confirm against the consumer of ``fit``).
    PLOTTING_KEY: str = 'reinforcement_learning'

    # The union annotations are quoted so they are never evaluated at runtime
    # and therefore do not require the ``X | Y`` syntax of Python 3.10+.
    def __init__(self, policy: "Policy | None" = None,
                 optimizer: "Optimizer | None" = None) -> None:
        """Store the policy and optimizer handed to created agents.

        Args:
            policy: Policy passed to every created agent. When omitted, a
                new ``BoltzmannQPolicy()`` is created for this handler.
            optimizer: Optimizer passed to every created agent. When
                omitted, a new ``Adam(learning_rate=0.001)`` is created for
                this handler.
        """
        # Default arguments are evaluated once, at function definition. Had
        # the defaults been built in the signature, every handler created
        # without arguments would share the very same policy and optimizer
        # objects. Building them here gives each handler its own instances.
        self.__policy: Policy = BoltzmannQPolicy() if policy is None else policy
        self.__optimizer: Optimizer = (
            Adam(learning_rate=0.001) if optimizer is None else optimizer
        )

    def create_agent(self, model_blue_print: BluePrintBase,
                     trading_environment: TradingEnvironment) -> AgentBase:
        """Instantiate a model from the blueprint and wrap it in an agent.

        Args:
            model_blue_print: Blueprint whose ``instantiate_model`` is called
                with the environment's observation space shape, number of
                actions and spatial data dimension.
            trading_environment: Environment those three values are read
                from.

        Returns:
            A ``ReinforcementLearningAgent`` built from the instantiated
            model and this handler's policy and optimizer.
        """
        observation_space_shape = trading_environment.observation_space.shape
        nr_of_actions = trading_environment.action_space.n
        spatial_data_shape = trading_environment.get_environment_spatial_data_dimension()

        model_adapter = model_blue_print.instantiate_model(
            observation_space_shape,
            nr_of_actions,
            spatial_data_shape,
        )
        return ReinforcementLearningAgent(
            model_adapter.get_model(), self.__policy, self.__optimizer
        )

    def fit(self, agent: ReinforcementLearningAgent, trading_environment: TradingEnvironment,
            nr_of_steps: int, nr_of_episodes: int,
            callbacks: list[Callback]) -> tuple[list[str], list[dict[str, Any]]]:
        """Train the agent on the environment and return the keyed result.

        Args:
            agent: Agent to train; must be a ``ReinforcementLearningAgent``.
            trading_environment: Environment the agent is trained on.
            nr_of_steps: Total number of training steps.
            nr_of_episodes: Number of episodes the steps are divided into;
                must be positive.
            callbacks: Callbacks forwarded to the agent's fit call.

        Returns:
            A pair ``([PLOTTING_KEY], [result])`` where ``result`` is whatever
            ``agent.reinforcement_learning_fit`` returns (presumably a
            ``dict[str, Any]``, as the previous annotation suggested -
            confirm against the agent implementation).

        Raises:
            TypeError: If ``agent`` is not a ``ReinforcementLearningAgent``.
            ValueError: If ``nr_of_episodes`` is not positive.
        """
        if not isinstance(agent, ReinforcementLearningAgent):
            raise TypeError("Agent must be an instance of ReinforcementLearningAgent.")

        # Reject a non-positive episode count up front: zero would otherwise
        # surface as a bare ZeroDivisionError, and a negative count would
        # silently produce a negative number of steps per episode.
        if nr_of_episodes <= 0:
            raise ValueError(
                f"nr_of_episodes must be a positive integer, got {nr_of_episodes!r}."
            )

        # Floor division: any remainder of steps is not spread over episodes.
        steps_per_episode = nr_of_steps // nr_of_episodes
        fit_result = agent.reinforcement_learning_fit(
            trading_environment, nr_of_steps, steps_per_episode, callbacks
        )
        return [ReinforcementLearningStrategyHandler.PLOTTING_KEY], [fit_result]