Coverage for source/agent/strategies/classification_learning_strategy_handler.py: 98%
49 statements
« prev ^ index » next coverage.py v7.8.0, created at 2025-08-01 20:51 +0000
1# agent/strategies/classification_learning_strategy_handler.py
3# global imports
4import logging
5from tensorflow.keras.callbacks import Callback
6from typing import Any
8# local imports
9from source.agent import AgentBase, ClassificationLearningAgent, LearningStrategyHandlerBase, \
10 PerformanceTestableClassificationLearningAgent
11from source.environment import TradingEnvironment
12from source.model import BluePrintBase
class ClassificationLearningStrategyHandler(LearningStrategyHandlerBase):
    """
    Implements a learning strategy handler for classification tasks. It provides
    functionalities for creating agents and fitting models.
    """

    # global class constants
    # Plot keys appended by fit(): index 0 is used verbatim, index 1 is suffixed with
    # "_" + the agent's model adapter tag.
    PLOTTING_KEYS: list[str] = ['price_movement_trend_class_summary', 'classification_learning']

    def create_agent(self, model_blue_print: BluePrintBase,
                     trading_environment: TradingEnvironment) -> AgentBase:
        """
        Creates a classification learning agent. It dynamically determines the parameters
        needed for instantiation based on the model blueprint.

        Parameters:
            model_blue_print (BluePrintBase): The model blueprint to be used for agent creation.
            trading_environment (TradingEnvironment): The trading environment in which the agent will operate.

        Returns:
            (AgentBase): A PerformanceTestableClassificationLearningAgent wrapping the model
                instantiated from the blueprint with parameters derived from the trading environment.
        """

        # Ask the blueprint which constructor parameters it needs and resolve each one
        # through the base-class parameter provider.
        kwargs = {
            parameter: self._provide_required_parameter(parameter, trading_environment)
            for parameter in model_blue_print.report_parameters_needed_for_instantiation()
        }

        return PerformanceTestableClassificationLearningAgent(model_blue_print.instantiate_model(**kwargs))

    def fit(self, agent: ClassificationLearningAgent, trading_environment: TradingEnvironment,
            nr_of_steps: int, nr_of_episodes: int, callbacks: list[Callback]) -> tuple[list[str], list[dict[str, Any]]]:
        """
        Fits the classification learning agent to the trading environment.

        The base-class fit is run first; its keys/data are then extended with the price
        movement trend class summary plot data and the result of agent.classification_fit.

        Parameters:
            agent (ClassificationLearningAgent): The classification learning agent to fit.
            trading_environment (TradingEnvironment): The trading environment to use.
            nr_of_steps (int): The number of training steps to perform.
            nr_of_episodes (int): The number of training episodes to perform (used as epochs).
            callbacks (list[Callback]): List of callbacks to use during training.

        Raises:
            TypeError: If the agent is not an instance of ClassificationLearningAgent.
            ValueError: If nr_of_episodes is not positive.

        Returns:
            (list[str], list[dict[str, Any]]): A tuple containing the keys and data collected during training.
        """

        if not isinstance(agent, ClassificationLearningAgent):
            raise TypeError("Agent must be an instance of ClassificationLearningAgent.")
        # nr_of_episodes is both a divisor and the epoch count below; reject it up front
        # instead of failing with ZeroDivisionError after the base fit has already run.
        if nr_of_episodes <= 0:
            raise ValueError("Number of episodes must be positive.")

        keys, data = super().fit(agent, trading_environment, nr_of_steps, nr_of_episodes, callbacks)
        keys.append(self.PLOTTING_KEYS[0])
        keys.append(self.PLOTTING_KEYS[1] + "_" + agent.get_model_adapter_tag())

        input_data, output_data, input_data_test, output_data_test = trading_environment.get_labeled_data()

        # With fewer steps than episodes the integer division yields 0, which would make the
        # batch size computation below divide by zero; fall back to a single step per epoch.
        steps_per_epoch = nr_of_steps // nr_of_episodes
        if steps_per_epoch <= 0:
            logging.warning("Steps per epoch is zero or negative, using value of 1 instead.")
            steps_per_epoch = 1

        batch_size = len(input_data) // steps_per_epoch
        if batch_size <= 0:
            logging.warning("Batch size is zero or negative, using value of 1 instead.")
            batch_size = 1

        validation_tuple = (input_data_test, output_data_test)

        data.append(self.__prepare_price_movement_trend_class_summary_plot_data(trading_environment))
        data.append(agent.classification_fit(input_data, output_data, validation_data = validation_tuple,
                                             batch_size = batch_size, epochs = nr_of_episodes, callbacks = callbacks))

        return keys, data

    def _provide_input_shape(self, trading_environment: TradingEnvironment) -> tuple[int]:
        """
        Provides the input shape for the model based on the trading environment.

        The shape is one-dimensional: the second dimension of the environment's spatial
        data multiplied by the trading constant WINDOW_SIZE.

        Parameters:
            trading_environment (TradingEnvironment): The trading environment to use.

        Returns:
            (tuple[int]): The (flattened) input shape for the model.
        """

        window_size = trading_environment.get_trading_consts().WINDOW_SIZE
        spatial_data_shape = trading_environment.get_environment_spatial_data_dimension()
        return (spatial_data_shape[1] * window_size, )

    def _provide_output_length(self, trading_environment: TradingEnvironment) -> int:
        """
        Provides the output length for the model based on the trading environment.

        Parameters:
            trading_environment (TradingEnvironment): The trading environment to use.

        Returns:
            (int): The number of output classes defined in the trading constants.
        """

        return len(trading_environment.get_trading_consts().OUTPUT_CLASSES)

    def _provide_spatial_data_shape(self, trading_environment: TradingEnvironment) -> tuple[int, int]:
        """
        Provides the spatial data shape for the model based on the trading environment.

        Parameters:
            trading_environment (TradingEnvironment): The trading environment to use.

        Returns:
            (tuple[int, int]): The spatial data shape reported by the environment.
        """

        return trading_environment.get_environment_spatial_data_dimension()

    def __prepare_price_movement_trend_class_summary_plot_data(self, trading_environment: TradingEnvironment) -> \
            dict[str, Any]:
        """
        Prepares the data for the price movement trend classification summary plot.

        For the test part and then the train part of the environment, collects the 'close'
        price series and the unsplit, unbalanced labels (cast with astype('int')).

        Side effect: switches the environment's mode and leaves it in TRAIN_MODE on return.
        NOTE(review): the mode active before the call is not restored — confirm callers
        expect TRAIN_MODE afterwards.

        Parameters:
            trading_environment (TradingEnvironment): The trading environment to use.

        Returns:
            (dict[str, Any]): The prepared data for the plot, keyed by
                'test_part_price_movement', 'test_part_labels',
                'train_part_price_movement' and 'train_part_labels'.
        """

        data = {}
        # Order matters: TRAIN_MODE is processed last so the environment ends in train mode.
        mode_key_pairs = (
            (TradingEnvironment.TEST_MODE, 'test_part_price_movement', 'test_part_labels'),
            (TradingEnvironment.TRAIN_MODE, 'train_part_price_movement', 'train_part_labels'),
        )
        for mode, price_movement_key, labels_key in mode_key_pairs:
            trading_environment.set_mode(mode)
            data[price_movement_key] = trading_environment.get_data_for_iteration(['close'])
            _, part_labels, _, _ = trading_environment.get_labeled_data(should_split = False,
                                                                       should_balance = False,
                                                                       verbose = False)
            data[labels_key] = part_labels.astype('int')

        return data