Coverage for source/agent/strategies/classification_learning_strategy_handler.py: 44%

27 statements  

« prev     ^ index     » next       coverage.py v7.8.0, created at 2025-06-06 12:00 +0000

1# agent/strategies/classification_learning_strategy_handler.py 

2 

3# global imports 

4from typing import Any 

5from tensorflow.keras.callbacks import Callback 

6import logging 

7 

8# local imports 

9from source.agent import LearningStrategyHandlerBase 

10from source.agent import AgentBase 

11from source.agent import ClassificationLearningAgent 

12from source.environment import TradingEnvironment 

13from source.model import BluePrintBase 

14 

class ClassificationLearningStrategyHandler(LearningStrategyHandlerBase):
    """
    Learning-strategy handler that builds and trains a ClassificationLearningAgent
    on labeled data obtained from a TradingEnvironment.
    """

    # global constants
    PLOTTING_KEY: str = 'classification_learning'

    def create_agent(self, model_blue_print: BluePrintBase,
                     trading_environment: TradingEnvironment) -> AgentBase:
        """
        Build a ClassificationLearningAgent whose model is instantiated from a blueprint.

        Parameters:
            model_blue_print: Blueprint whose instantiate_model() produces the model adapter.
            trading_environment: Environment supplying the trading constants
                (WINDOW_SIZE, OUTPUT_CLASSES) and the spatial data dimension.

        Returns:
            A ClassificationLearningAgent wrapping the instantiated model adapter.
        """

        trading_consts = trading_environment.get_trading_consts()
        window_size = trading_consts.WINDOW_SIZE
        spatial_data_shape = trading_environment.get_environment_spatial_data_dimension()

        # Flattened market-data input of spatial_data_shape[1] * window_size values.
        # NOTE(review): presumably spatial_data_shape is (window, features) — confirm.
        market_data_shape = (spatial_data_shape[1] * window_size, )

        number_of_classes = len(trading_consts.OUTPUT_CLASSES)
        model_adapter = model_blue_print.instantiate_model(market_data_shape, number_of_classes,
                                                           spatial_data_shape)
        return ClassificationLearningAgent(model_adapter)

    def fit(self, agent: ClassificationLearningAgent, environment: TradingEnvironment,
            nr_of_steps: int, nr_of_episodes: int,
            callbacks: list[Callback]) -> tuple[list[str], list[dict[str, Any]]]:
        """
        Train the agent on the environment's labeled data.

        The total step budget is split evenly across episodes (used as epochs).
        The batch size is chosen so that one epoch over the labeled data takes
        roughly that many steps.

        Parameters:
            agent: Agent to train; must be a ClassificationLearningAgent.
            environment: Environment whose get_labeled_data() returns (inputs, labels).
            nr_of_steps: Total number of training steps across all episodes.
            nr_of_episodes: Number of training epochs; must be positive.
            callbacks: Keras callbacks forwarded to agent.classification_fit.

        Returns:
            A pair (keys, results). keys is [PLOTTING_KEY]. results is a
            one-element list holding the value returned by agent.classification_fit.

        Raises:
            TypeError: If agent is not a ClassificationLearningAgent.
            ValueError: If nr_of_episodes is not positive.
        """

        if not isinstance(agent, ClassificationLearningAgent):
            raise TypeError("Agent must be an instance of ClassificationLearningAgent.")

        if nr_of_episodes <= 0:
            raise ValueError("Number of episodes must be a positive integer.")

        input_data, output_data = environment.get_labeled_data()

        # With fewer steps than episodes the integer division yields 0. Without this
        # guard, the batch-size division below would raise ZeroDivisionError.
        steps_per_epoch = nr_of_steps // nr_of_episodes
        if steps_per_epoch <= 0:
            logging.warning("Steps per epoch is zero or negative, using value of 1 instead.")
            steps_per_epoch = 1

        batch_size = len(input_data) // steps_per_epoch
        if batch_size <= 0:
            # Happens when there are fewer samples than steps per epoch (or no samples).
            logging.warning("Batch size is zero or negative, using value of 1 instead.")
            batch_size = 1

        history = agent.classification_fit(input_data, output_data, batch_size = batch_size,
                                           epochs = nr_of_episodes, callbacks = callbacks)
        return [ClassificationLearningStrategyHandler.PLOTTING_KEY], [history]