Coverage for source/agent/strategies/classification_learning_strategy_handler.py: 98%
49 statements
« prev ^ index » next coverage.py v7.8.0, created at 2025-07-30 20:59 +0000
« prev ^ index » next coverage.py v7.8.0, created at 2025-07-30 20:59 +0000
1# agent/strategies/classification_learning_strategy_handler.py
3# global imports
4import logging
5from tensorflow.keras.callbacks import Callback
6from typing import Any
8# local imports
9from source.agent import AgentBase, ClassificationLearningAgent, LearningStrategyHandlerBase
10from source.environment import TradingEnvironment
11from source.model import BluePrintBase
class ClassificationLearningStrategyHandler(LearningStrategyHandlerBase):
    """
    Implements a learning strategy handler for classification tasks. It provides
    functionalities for creating agents and fitting models.
    """

    # global class constants
    # [0]: key of the price movement trend class summary plot.
    # [1]: prefix of the classification learning plot key; fit() appends
    #      "_" + the agent's model adapter tag to it.
    PLOTTING_KEYS: list[str] = ['price_movement_trend_class_summary', 'classification_learning']

    def create_agent(self, model_blue_print: BluePrintBase,
                     trading_environment: TradingEnvironment) -> AgentBase:
        """
        Creates a classification learning agent. It dynamically determines the parameters
        needed for instantiation based on the model blueprint.

        Parameters:
            model_blue_print (BluePrintBase): The model blueprint to be used for agent creation.
            trading_environment (TradingEnvironment): The trading environment in which the agent will operate.

        Returns:
            (AgentBase): An instance of ClassificationLearningAgent wrapping the model
                instantiated from the blueprint.
        """

        parameters_needed_for_instantiation = model_blue_print.report_parameters_needed_for_instantiation()
        # Each requested parameter name is resolved by the base class helper, which
        # dispatches to the _provide_* methods below.
        kwargs = {parameter: self._provide_required_parameter(parameter, trading_environment)
                  for parameter in parameters_needed_for_instantiation}

        return ClassificationLearningAgent(model_blue_print.instantiate_model(**kwargs))

    def fit(self, agent: ClassificationLearningAgent, trading_environment: TradingEnvironment,
            nr_of_steps: int, nr_of_episodes: int, callbacks: list[Callback]) -> tuple[list[str], list[dict[str, Any]]]:
        """
        Fits the classification learning agent to the trading environment.

        The requested steps are spread evenly over the episodes (used as training epochs).
        The batch size is chosen so that one pass over the labeled training data takes
        roughly nr_of_steps // nr_of_episodes steps.

        Parameters:
            agent (ClassificationLearningAgent): The classification learning agent to fit.
            trading_environment (TradingEnvironment): The trading environment to use.
            nr_of_steps (int): The number of training steps to perform.
            nr_of_episodes (int): The number of training episodes (epochs) to perform. Must be positive.
            callbacks (list[Callback]): List of callbacks to use during training.

        Raises:
            TypeError: If the agent is not an instance of ClassificationLearningAgent.
            ValueError: If nr_of_episodes is not positive.

        Returns:
            (list[str], list[dict[str, Any]]): A tuple containing the keys and data collected during training.
        """

        if not isinstance(agent, ClassificationLearningAgent):
            raise TypeError("Agent must be an instance of ClassificationLearningAgent.")

        # nr_of_episodes is used as a divisor below; reject it up front instead of
        # letting it surface later as an opaque ZeroDivisionError.
        if nr_of_episodes <= 0:
            raise ValueError("Number of episodes must be positive.")

        keys, data = super().fit(agent, trading_environment, nr_of_steps, nr_of_episodes, callbacks)
        keys.append(self.PLOTTING_KEYS[0])
        keys.append(self.PLOTTING_KEYS[1] + "_" + agent.get_model_adapter_tag())

        input_data, output_data, input_data_test, output_data_test = trading_environment.get_labeled_data()

        # Integer division yields 0 when fewer steps than episodes are requested,
        # which would make the batch size computation divide by zero.
        steps_per_epoch = nr_of_steps // nr_of_episodes
        if steps_per_epoch <= 0:
            logging.warning("Steps per epoch is zero or negative, using value of 1 instead.")
            steps_per_epoch = 1

        batch_size = len(input_data) // steps_per_epoch
        if batch_size <= 0:
            # Happens when there are fewer training samples than steps per epoch.
            logging.warning("Batch size is zero or negative, using value of 1 instead.")
            batch_size = 1

        validation_tuple = (input_data_test, output_data_test)

        data.append(self.__prepare_price_movement_trend_class_summary_plot_data(trading_environment))
        data.append(agent.classification_fit(input_data, output_data, validation_data = validation_tuple,
                                             batch_size = batch_size, epochs = nr_of_episodes,
                                             callbacks = callbacks))

        return keys, data

    def _provide_input_shape(self, trading_environment: TradingEnvironment) -> tuple[int]:
        """
        Provides the input shape for the model based on the trading environment.

        The spatial data window is flattened into a single vector, so the shape is
        one-dimensional.

        Parameters:
            trading_environment (TradingEnvironment): The trading environment to use.

        Returns:
            (tuple[int]): The one-element input shape for the model.
        """

        window_size = trading_environment.get_trading_consts().WINDOW_SIZE
        spatial_data_shape = trading_environment.get_environment_spatial_data_dimension()
        # NOTE(review): assumes index 1 of the spatial data dimension is the number of
        # features per time step — TODO confirm against TradingEnvironment.
        return (spatial_data_shape[1] * window_size, )

    def _provide_output_length(self, trading_environment: TradingEnvironment) -> int:
        """
        Provides the output length for the model based on the trading environment.

        Parameters:
            trading_environment (TradingEnvironment): The trading environment to use.

        Returns:
            (int): The output length for the model, i.e. the number of output classes.
        """

        return len(trading_environment.get_trading_consts().OUTPUT_CLASSES)

    def _provide_spatial_data_shape(self, trading_environment: TradingEnvironment) -> tuple[int, int]:
        """
        Provides the spatial data shape for the model based on the trading environment.

        Parameters:
            trading_environment (TradingEnvironment): The trading environment to use.

        Returns:
            (tuple[int, int]): The spatial data shape for the model.
        """

        return trading_environment.get_environment_spatial_data_dimension()

    def __prepare_price_movement_trend_class_summary_plot_data(self, trading_environment: TradingEnvironment) -> \
            dict[str, Any]:
        """
        Prepares the data for the price movement trend classification summary plot.

        Collects close prices and unsplit, unbalanced labels for the test part and then
        the train part of the data. Side effect: the environment is left in TRAIN_MODE.

        Parameters:
            trading_environment (TradingEnvironment): The trading environment to use.

        Returns:
            (dict[str, Any]): The prepared data for the plot, with keys
                'test_part_price_movement', 'test_part_labels',
                'train_part_price_movement' and 'train_part_labels'.
        """

        data = {}

        trading_environment.set_mode(TradingEnvironment.TEST_MODE)
        data['test_part_price_movement'] = trading_environment.get_data_for_iteration(['close'])
        _, test_part_labels, _, _ = trading_environment.get_labeled_data(should_split = False,
                                                                         should_balance = False,
                                                                         verbose = False)
        # presumably a numpy array of class labels — cast for plotting
        data['test_part_labels'] = test_part_labels.astype('int')

        # Train mode is set last, so the environment ends up in TRAIN_MODE.
        trading_environment.set_mode(TradingEnvironment.TRAIN_MODE)
        data['train_part_price_movement'] = trading_environment.get_data_for_iteration(['close'])
        _, train_part_labels, _, _ = trading_environment.get_labeled_data(should_split = False,
                                                                          should_balance = False,
                                                                          verbose = False)
        data['train_part_labels'] = train_part_labels.astype('int')

        return data