Coverage for source/agent/strategies/learning_strategy_handler_base.py: 86%
36 statements
« prev ^ index » next coverage.py v7.8.0, created at 2025-07-30 20:59 +0000
« prev ^ index » next coverage.py v7.8.0, created at 2025-07-30 20:59 +0000
1# agent/strategies/learning_strategy_handler_base.py
3# global imports
4from abc import ABC, abstractmethod
5from tensorflow.keras.callbacks import Callback
6from typing import Any
8# local imports
9from source.agent import AgentBase
10from source.environment import TradingEnvironment
11from source.model import BluePrintBase
class LearningStrategyHandlerBase(ABC):
    """
    Implements a base class for learning strategy handlers. It provides an interface
    for creating agents, fitting them to the environment, and providing parameters
    needed for the model blueprint instantiation.
    """

    # global class constants
    # key under which the price movement summary collected by fit() is returned
    PLOTTING_KEY: str = 'asset_price_movement_summary'

    def __init__(self) -> None:
        """
        Class constructor. Initializes the parameter provision callbacks.
        """

        # Maps a parameter name to the bound provider method that computes it.
        # The attribute is name-mangled, so it is only reachable from methods
        # defined in this class (see _provide_required_parameter).
        self.__parameter_provision_callbacks = {
            'input_shape': self._provide_input_shape,
            'output_length': self._provide_output_length,
            'spatial_data_shape': self._provide_spatial_data_shape
        }

    @abstractmethod
    def create_agent(self, model_blue_print: BluePrintBase,
                     trading_environment: TradingEnvironment) -> AgentBase:
        """
        Creates an agent based on the provided model blueprint and trading environment.

        Parameters:
            model_blue_print (BluePrintBase): The model blueprint to be used for agent creation.
            trading_environment (TradingEnvironment): The trading environment in which the agent will operate.

        Returns:
            (AgentBase): An instance of the agent created using the model blueprint and trading environment.
        """

    @abstractmethod
    def fit(self, agent: AgentBase, environment: TradingEnvironment, nr_of_steps: int, nr_of_episodes: int,
            callbacks: list[Callback]) -> tuple[list[str], list[dict[str, Any]]]:
        """
        Fits the agent to the environment using the specified number of steps and episodes. It also collects data
        for plotting summary statistics.

        The base implementation performs no training: it only gathers the 'close' and 'volatility' data of the
        test and the train part of the environment. Although the method is abstract, this body can be reused by
        subclasses through super().fit(...). The arguments agent, nr_of_steps, nr_of_episodes and callbacks
        are not used here.

        Note that the environment is switched to test mode and then to train mode, so it is left in train mode
        when this method returns.

        Parameters:
            agent (AgentBase): The agent to be fitted to the environment.
            environment (TradingEnvironment): The trading environment in which the agent will be trained.
            nr_of_steps (int): The number of steps to be taken during training.
            nr_of_episodes (int): The number of episodes to be run during training.
            callbacks (list[Callback]): A list of Keras callbacks to be used during training.

        Returns:
            (tuple[list[str], list[dict[str, Any]]]): A tuple containing a list of keys and a list of dictionaries
                with the data collected during training.
        """

        data = {}

        # test part is collected first ...
        environment.set_mode(TradingEnvironment.TEST_MODE)
        data['test_part_price_movement'] = environment.get_data_for_iteration(['close'])
        data['test_part_volatility'] = environment.get_data_for_iteration(['volatility'])

        # ... so that the environment ends up in train mode afterwards
        environment.set_mode(TradingEnvironment.TRAIN_MODE)
        data['train_part_price_movement'] = environment.get_data_for_iteration(['close'])
        data['train_part_volatility'] = environment.get_data_for_iteration(['volatility'])

        return [LearningStrategyHandlerBase.PLOTTING_KEY], [data]

    def _provide_required_parameter(self, parameter: str, environment: TradingEnvironment) -> Any:
        """
        Provides the required parameter for the given environment.

        Parameters:
            parameter (str): The name of the parameter to be provided.
            environment (TradingEnvironment): The trading environment from which the parameter is to be provided.

        Raises:
            ValueError: If the parameter is not supported by this environment configuration.

        Returns:
            (Any): The value of the requested parameter.
        """

        # Single lookup; the registered values are bound methods, so None
        # unambiguously means "no provider registered for this name".
        provider = self.__parameter_provision_callbacks.get(parameter)
        if provider is None:
            # NOTE(review): the wording of this message looks garbled ("supported to provided");
            # it is kept verbatim because callers or tests may match on it.
            raise ValueError(f"Parameter '{parameter}' is not supported to provided by this environment configuration.")

        return provider(environment)

    @abstractmethod
    def _provide_input_shape(self, environment: TradingEnvironment) -> tuple[int, int]:
        """
        Provides the input shape for the given environment.

        Parameters:
            environment (TradingEnvironment): The trading environment for which the input shape is to be provided.

        Returns:
            (tuple[int, int]): The input shape as a tuple.
        """

    @abstractmethod
    def _provide_output_length(self, environment: TradingEnvironment) -> int:
        """
        Provides the output length for the given environment.

        Parameters:
            environment (TradingEnvironment): The trading environment for which the output length is to be provided.

        Returns:
            (int): The output length.
        """

    @abstractmethod
    def _provide_spatial_data_shape(self, environment: TradingEnvironment) -> tuple[int, int]:
        """
        Provides the spatial data shape for the given environment.

        Parameters:
            environment (TradingEnvironment): The trading environment for which the spatial data shape is to be provided.

        Returns:
            (tuple[int, int]): The spatial data shape as a tuple.
        """