Coverage for source/agent/strategies/classification_learning_strategy_handler.py: 98%

49 statements  

« prev     ^ index     » next       coverage.py v7.8.0, created at 2025-07-30 20:59 +0000

1# agent/strategies/classification_learning_strategy_handler.py 

2 

3# global imports 

4import logging 

5from tensorflow.keras.callbacks import Callback 

6from typing import Any 

7 

8# local imports 

9from source.agent import AgentBase, ClassificationLearningAgent, LearningStrategyHandlerBase 

10from source.environment import TradingEnvironment 

11from source.model import BluePrintBase 

12 

class ClassificationLearningStrategyHandler(LearningStrategyHandlerBase):
    """
    Implements a learning strategy handler for classification tasks. It provides
    functionalities for creating agents and fitting models.
    """

    # global class constants
    # Keys appended by fit(): index 0 is used as-is for the price movement trend class
    # summary plot, index 1 is suffixed with "_" + the agent's model adapter tag.
    PLOTTING_KEYS: list[str] = ['price_movement_trend_class_summary', 'classification_learning']

    def create_agent(self, model_blue_print: BluePrintBase,
                     trading_environment: TradingEnvironment) -> AgentBase:
        """
        Creates a classification learning agent. It dynamically determines the parameters
        needed for instantiation based on the model blueprint.

        Parameters:
            model_blue_print (BluePrintBase): The model blueprint to be used for agent creation.
            trading_environment (TradingEnvironment): The trading environment in which the agent will operate.

        Returns:
            (AgentBase): A ClassificationLearningAgent wrapping the model instantiated from the
                blueprint with the parameters it reported as required.
        """

        # Every parameter name reported by the blueprint is resolved through
        # _provide_required_parameter (not defined in this class; presumably inherited from
        # LearningStrategyHandlerBase and dispatching to the _provide_* methods below).
        kwargs = {
            parameter: self._provide_required_parameter(parameter, trading_environment)
            for parameter in model_blue_print.report_parameters_needed_for_instantiation()
        }

        return ClassificationLearningAgent(model_blue_print.instantiate_model(**kwargs))

    def fit(self, agent: ClassificationLearningAgent, trading_environment: TradingEnvironment,
            nr_of_steps: int, nr_of_episodes: int, callbacks: list[Callback]) -> tuple[list[str], list[dict[str, Any]]]:
        """
        Fits the classification learning agent to the trading environment.

        Parameters:
            agent (ClassificationLearningAgent): The classification learning agent to fit.
            trading_environment (TradingEnvironment): The trading environment to use.
            nr_of_steps (int): The total number of training steps to perform; divided evenly
                across episodes to derive the batch size.
            nr_of_episodes (int): The number of training episodes (epochs) to perform. Must be positive.
            callbacks (list[Callback]): List of callbacks to use during training.

        Raises:
            TypeError: If the agent is not an instance of ClassificationLearningAgent.
            ValueError: If nr_of_episodes is not positive.

        Returns:
            (list[str], list[dict[str, Any]]): A tuple containing the keys and data collected during
                training: the base handler's entries followed by the class summary plot data and
                the result of agent.classification_fit.
        """

        if not isinstance(agent, ClassificationLearningAgent):
            raise TypeError("Agent must be an instance of ClassificationLearningAgent.")
        # Validate before super().fit runs, so a bad value cannot fail halfway through
        # (previously this surfaced later as a ZeroDivisionError).
        if nr_of_episodes <= 0:
            raise ValueError("Number of episodes must be a positive integer.")

        keys, data = super().fit(agent, trading_environment, nr_of_steps, nr_of_episodes, callbacks)
        keys.append(self.PLOTTING_KEYS[0])
        keys.append(self.PLOTTING_KEYS[1] + "_" + agent.get_model_adapter_tag())

        input_data, output_data, input_data_test, output_data_test = trading_environment.get_labeled_data()
        validation_tuple = (input_data_test, output_data_test)

        # Pick the batch size so that roughly steps_per_epoch batches cover the training data.
        # Clamp steps_per_epoch to at least 1: when nr_of_steps < nr_of_episodes the floor
        # division yields 0, and dividing by it would raise before the fallback below applies.
        # In that case the whole training set becomes a single batch.
        steps_per_epoch = max(nr_of_steps // nr_of_episodes, 1)
        batch_size = len(input_data) // steps_per_epoch
        if batch_size <= 0:
            # Happens when there are fewer samples than requested steps per epoch (or none at all).
            logging.warning("Batch size is zero or negative, using value of 1 instead.")
            batch_size = 1

        data.append(self.__prepare_price_movement_trend_class_summary_plot_data(trading_environment))
        data.append(agent.classification_fit(input_data, output_data, validation_data = validation_tuple,
                                             batch_size = batch_size, epochs = nr_of_episodes,
                                             callbacks = callbacks))

        return keys, data

    def _provide_input_shape(self, trading_environment: TradingEnvironment) -> tuple[int]:
        """
        Provides the input shape for the model based on the trading environment.

        Parameters:
            trading_environment (TradingEnvironment): The trading environment to use.

        Returns:
            (tuple[int]): A flat, one-element shape: WINDOW_SIZE multiplied by the second entry
                of the environment's spatial data dimension.
        """

        window_size = trading_environment.get_trading_consts().WINDOW_SIZE
        spatial_data_shape = trading_environment.get_environment_spatial_data_dimension()
        # NOTE(review): index 1 is presumably the per-row feature count — confirm against
        # get_environment_spatial_data_dimension.
        return (spatial_data_shape[1] * window_size, )

    def _provide_output_length(self, trading_environment: TradingEnvironment) -> int:
        """
        Provides the output length for the model based on the trading environment.

        Parameters:
            trading_environment (TradingEnvironment): The trading environment to use.

        Returns:
            (int): The number of entries in the trading constants' OUTPUT_CLASSES.
        """

        return len(trading_environment.get_trading_consts().OUTPUT_CLASSES)

    def _provide_spatial_data_shape(self, trading_environment: TradingEnvironment) -> tuple[int, int]:
        """
        Provides the spatial data shape for the model based on the trading environment.

        Parameters:
            trading_environment (TradingEnvironment): The trading environment to use.

        Returns:
            (tuple[int, int]): The environment's spatial data dimension, returned unchanged.
        """

        return trading_environment.get_environment_spatial_data_dimension()

    def __prepare_price_movement_trend_class_summary_plot_data(
            self, trading_environment: TradingEnvironment) -> dict[str, Any]:
        """
        Prepares the data for the price movement trend classification summary plot.

        Switches the environment to TEST_MODE to read the test part, then to TRAIN_MODE to read
        the train part. The environment is left in TRAIN_MODE, even if reading the test part fails.

        Parameters:
            trading_environment (TradingEnvironment): The trading environment to use.

        Returns:
            (dict[str, Any]): Close prices and integer labels for the test and train parts, under
                'test_part_price_movement', 'test_part_labels', 'train_part_price_movement'
                and 'train_part_labels'.
        """

        data = {}
        trading_environment.set_mode(TradingEnvironment.TEST_MODE)
        try:
            data['test_part_price_movement'] = trading_environment.get_data_for_iteration(['close'])
            data['test_part_labels'] = self.__get_unsplit_labels(trading_environment)
        finally:
            # Always switch back so a failure here does not leave the environment in TEST_MODE.
            trading_environment.set_mode(TradingEnvironment.TRAIN_MODE)

        data['train_part_price_movement'] = trading_environment.get_data_for_iteration(['close'])
        data['train_part_labels'] = self.__get_unsplit_labels(trading_environment)

        return data

    @staticmethod
    def __get_unsplit_labels(trading_environment: TradingEnvironment) -> Any:
        """
        Fetches labels via get_labeled_data without splitting, balancing or verbose output, and
        casts them to int. Called after set_mode has selected the part of interest.

        Parameters:
            trading_environment (TradingEnvironment): The trading environment to read from.

        Returns:
            (Any): The labels returned by get_labeled_data, converted with astype('int').
        """

        _, labels, _, _ = trading_environment.get_labeled_data(should_split = False,
                                                               should_balance = False,
                                                               verbose = False)
        return labels.astype('int')