Coverage for source/agent/strategies/classification_learning_strategy_handler.py: 98%

49 statements  

« prev     ^ index     » next       coverage.py v7.8.0, created at 2025-08-01 20:51 +0000

1# agent/strategies/classification_learning_strategy_handler.py 

2 

3# global imports 

4import logging 

5from tensorflow.keras.callbacks import Callback 

6from typing import Any 

7 

8# local imports 

9from source.agent import AgentBase, ClassificationLearningAgent, LearningStrategyHandlerBase, \ 

10 PerformanceTestableClassificationLearningAgent 

11from source.environment import TradingEnvironment 

12from source.model import BluePrintBase 

13 

class ClassificationLearningStrategyHandler(LearningStrategyHandlerBase):
    """
    Implements a learning strategy handler for classification tasks. It provides
    functionalities for creating agents and fitting models.
    """

    # global class constants
    # Base names of the plot entries appended to the keys returned by fit(); the
    # second entry is suffixed with "_" plus the agent's model adapter tag.
    PLOTTING_KEYS: list[str] = ['price_movement_trend_class_summary', 'classification_learning']

    def create_agent(self, model_blue_print: BluePrintBase,
                     trading_environment: TradingEnvironment) -> AgentBase:
        """
        Creates a classification learning agent. It dynamically determines the parameters
        needed for instantiation based on the model blueprint.

        Parameters:
            model_blue_print (BluePrintBase): The model blueprint to be used for agent creation.
            trading_environment (TradingEnvironment): The trading environment in which the agent will operate.

        Returns:
            (AgentBase): A PerformanceTestableClassificationLearningAgent wrapping the model
                instantiated from the blueprint, with every parameter the blueprint reports
                as required resolved via _provide_required_parameter.
        """

        parameters_needed_for_instantiation = model_blue_print.report_parameters_needed_for_instantiation()
        kwargs = {parameter: self._provide_required_parameter(parameter, trading_environment)
                  for parameter in parameters_needed_for_instantiation}

        return PerformanceTestableClassificationLearningAgent(model_blue_print.instantiate_model(**kwargs))

    def fit(self, agent: ClassificationLearningAgent, trading_environment: TradingEnvironment,
            nr_of_steps: int, nr_of_episodes: int, callbacks: list[Callback]) -> tuple[list[str], list[dict[str, Any]]]:
        """
        Fits the classification learning agent to the trading environment.

        Parameters:
            agent (ClassificationLearningAgent): The classification learning agent to fit.
            trading_environment (TradingEnvironment): The trading environment to use.
            nr_of_steps (int): The number of training steps to perform; together with
                nr_of_episodes it determines the batch size.
            nr_of_episodes (int): The number of training episodes to perform; passed to
                classification_fit as the number of epochs. Must be positive.
            callbacks (list[Callback]): List of callbacks to use during training.

        Raises:
            TypeError: If the agent is not an instance of ClassificationLearningAgent.
            ValueError: If nr_of_episodes is not positive.

        Returns:
            (list[str], list[dict[str, Any]]): A tuple containing the keys and data collected
                during training: the base handler's entries followed by the price movement
                trend class summary and the classification fit result.
        """

        if not isinstance(agent, ClassificationLearningAgent):
            raise TypeError("Agent must be an instance of ClassificationLearningAgent.")
        # Validate before any work is done; a non-positive episode count would otherwise
        # surface as a ZeroDivisionError after the base fit had already run.
        if nr_of_episodes <= 0:
            raise ValueError(f"nr_of_episodes must be positive, got {nr_of_episodes}.")

        keys, data = super().fit(agent, trading_environment, nr_of_steps, nr_of_episodes, callbacks)
        keys.append(self.PLOTTING_KEYS[0])
        keys.append(f"{self.PLOTTING_KEYS[1]}_{agent.get_model_adapter_tag()}")

        input_data, output_data, input_data_test, output_data_test = trading_environment.get_labeled_data()
        # With fewer steps than episodes the integer division yields 0, which would make the
        # batch size computation below divide by zero; use at least one step per epoch.
        steps_per_epoch = max(nr_of_steps // nr_of_episodes, 1)
        batch_size = len(input_data) // steps_per_epoch
        if batch_size <= 0:
            logging.warning("Batch size is zero or negative, using value of 1 instead.")
            batch_size = 1
        validation_tuple = (input_data_test, output_data_test)

        data.append(self.__prepare_price_movement_trend_class_summary_plot_data(trading_environment))
        data.append(agent.classification_fit(input_data, output_data, validation_data = validation_tuple,
                                             batch_size = batch_size, epochs = nr_of_episodes,
                                             callbacks = callbacks))

        return keys, data

    def _provide_input_shape(self, trading_environment: TradingEnvironment) -> tuple[int]:
        """
        Provides the input shape for the model based on the trading environment.

        Parameters:
            trading_environment (TradingEnvironment): The trading environment to use.

        Returns:
            (tuple[int]): A one-element shape whose single entry is the second entry of the
                environment's spatial data dimension multiplied by the WINDOW_SIZE trading
                constant (i.e. a flattened window).
        """

        window_size = trading_environment.get_trading_consts().WINDOW_SIZE
        spatial_data_shape = trading_environment.get_environment_spatial_data_dimension()
        return (spatial_data_shape[1] * window_size, )

    def _provide_output_length(self, trading_environment: TradingEnvironment) -> int:
        """
        Provides the output length for the model based on the trading environment.

        Parameters:
            trading_environment (TradingEnvironment): The trading environment to use.

        Returns:
            (int): The number of entries in the OUTPUT_CLASSES trading constant.
        """

        return len(trading_environment.get_trading_consts().OUTPUT_CLASSES)

    def _provide_spatial_data_shape(self, trading_environment: TradingEnvironment) -> tuple[int, int]:
        """
        Provides the spatial data shape for the model based on the trading environment.

        Parameters:
            trading_environment (TradingEnvironment): The trading environment to use.

        Returns:
            (tuple[int, int]): The environment's spatial data dimension, returned unchanged.
        """

        return trading_environment.get_environment_spatial_data_dimension()

    def __prepare_price_movement_trend_class_summary_plot_data(self, trading_environment: TradingEnvironment) -> \
            dict[str, Any]:
        """
        Prepares the data for the price movement trend classification summary plot.

        For the test part and then the train part of the environment, collects the 'close'
        price series and the unsplit, unbalanced labels (cast to int).

        Note: this switches the environment's mode and leaves it in TRAIN_MODE on return.

        Parameters:
            trading_environment (TradingEnvironment): The trading environment to use.

        Returns:
            (dict[str, Any]): The prepared data for the plot, with the keys
                'test_part_price_movement', 'test_part_labels',
                'train_part_price_movement' and 'train_part_labels'.
        """

        data = {}
        # Order matters: test first, then train, so the environment ends in TRAIN_MODE.
        for mode, prefix in ((TradingEnvironment.TEST_MODE, 'test_part'),
                             (TradingEnvironment.TRAIN_MODE, 'train_part')):
            trading_environment.set_mode(mode)
            data[f'{prefix}_price_movement'] = trading_environment.get_data_for_iteration(['close'])
            _, part_labels, _, _ = trading_environment.get_labeled_data(should_split = False,
                                                                        should_balance = False,
                                                                        verbose = False)
            data[f'{prefix}_labels'] = part_labels.astype('int')

        return data