Coverage for source/agent/strategies/classification_learning_strategy_handler.py: 44%
27 statements
« prev ^ index » next coverage.py v7.8.0, created at 2025-06-06 12:00 +0000
« prev ^ index » next coverage.py v7.8.0, created at 2025-06-06 12:00 +0000
1# agent/strategies/classification_learning_strategy_handler.py
3# global imports
4from typing import Any
5from tensorflow.keras.callbacks import Callback
6import logging
8# local imports
9from source.agent import LearningStrategyHandlerBase
10from source.agent import AgentBase
11from source.agent import ClassificationLearningAgent
12from source.environment import TradingEnvironment
13from source.model import BluePrintBase
class ClassificationLearningStrategyHandler(LearningStrategyHandlerBase):
    """Strategy handler that builds and trains a ClassificationLearningAgent.

    ``create_agent`` instantiates a model from a blueprint, sized from the
    trading environment, and wraps it in a ``ClassificationLearningAgent``.
    ``fit`` trains such an agent on the labeled data provided by the
    environment.
    """

    # global constants
    # Key returned by ``fit`` alongside the result of the training call.
    PLOTTING_KEY: str = 'classification_learning'

    def create_agent(self, model_blue_print: BluePrintBase,
                     trading_environment: TradingEnvironment) -> AgentBase:
        """Create a classification agent whose model matches the environment.

        Args:
            model_blue_print: Blueprint used to instantiate the model adapter.
            trading_environment: Environment supplying the window size, the
                spatial data dimension and the output classes.

        Returns:
            A ``ClassificationLearningAgent`` wrapping the instantiated model.
        """

        trading_consts = trading_environment.get_trading_consts()
        window_size = trading_consts.WINDOW_SIZE
        spatial_data_shape = trading_environment.get_environment_spatial_data_dimension()
        # One-dimensional input shape: second spatial dimension times the window size.
        market_data_shape = (spatial_data_shape[1] * window_size, )

        number_of_classes = len(trading_consts.OUTPUT_CLASSES)
        model_adapter = model_blue_print.instantiate_model(market_data_shape, number_of_classes,
                                                           spatial_data_shape)
        return ClassificationLearningAgent(model_adapter)

    def fit(self, agent: ClassificationLearningAgent, environment: TradingEnvironment,
            nr_of_steps: int, nr_of_episodes: int, callbacks: list[Callback]) -> tuple[list[str], list[dict[str, Any]]]:
        """Train the agent on the environment's labeled data.

        The batch size is derived so that one epoch takes roughly
        ``nr_of_steps // nr_of_episodes`` steps over the labeled data; the
        number of epochs equals ``nr_of_episodes``.

        Args:
            agent: Agent to train; must be a ``ClassificationLearningAgent``.
            environment: Environment providing the labeled data.
            nr_of_steps: Total number of training steps across all episodes.
            nr_of_episodes: Number of epochs to train for; must be positive.
            callbacks: Callbacks forwarded to the agent's fit call.

        Returns:
            A tuple of a single-element list holding ``PLOTTING_KEY`` and a
            single-element list holding the result of
            ``agent.classification_fit``.

        Raises:
            TypeError: If ``agent`` is not a ``ClassificationLearningAgent``.
            ValueError: If ``nr_of_episodes`` is not positive.
        """

        if not isinstance(agent, ClassificationLearningAgent):
            raise TypeError("Agent must be an instance of ClassificationLearningAgent.")

        # A non-positive episode count would otherwise surface as an opaque
        # ZeroDivisionError (or a negative step count) in the arithmetic below.
        if nr_of_episodes <= 0:
            raise ValueError("Number of episodes must be a positive integer.")

        input_data, output_data = environment.get_labeled_data()

        steps_per_epoch = nr_of_steps // nr_of_episodes
        # Fewer total steps than episodes floors to zero, which would make the
        # batch size computation divide by zero; fall back to one step per epoch.
        if steps_per_epoch <= 0:
            logging.warning("Steps per epoch is zero or negative, using value of 1 instead.")
            steps_per_epoch = 1

        batch_size = len(input_data) // steps_per_epoch
        if batch_size <= 0:
            logging.warning("Batch size is zero or negative, using value of 1 instead.")
            batch_size = 1

        history = agent.classification_fit(input_data, output_data, batch_size=batch_size,
                                           epochs=nr_of_episodes, callbacks=callbacks)
        return [ClassificationLearningStrategyHandler.PLOTTING_KEY], [history]