Coverage for source/training/training_config.py: 98%

43 statements  

« prev     ^ index     » next       coverage.py v7.8.0, created at 2025-08-04 21:16 +0000

1# training/training_config.py 

2 

3# global imports 

4import pandas as pd 

5from typing import Any, Optional 

6 

7# local imports 

8from source.agent import AgentHandler, LearningStrategyHandlerBase, TestingStrategyHandlerBase 

9from source.environment import LabelAnnotatorBase, LabeledDataBalancer, PriceRewardValidator, \ 

10 RewardValidatorBase, SimpleLabelAnnotator, TradingEnvironment 

11from source.model import BluePrintBase 

12 

class TrainingConfig:
    """
    Configuration container for training agents in a trading environment.

    Bundles everything needed to build a TradingEnvironment and an AgentHandler:
    training loop sizes (steps, episodes, test repetitions), the training data as a
    pandas DataFrame, the model blueprint, the learning and testing strategy handlers,
    and the environment parameters (initial budget, maximum amount of trades, window
    size, stop-loss / take-profit thresholds, penalty window, static reward adjustment,
    reward validator, label annotator, optional labeled data balancer, metadata and
    trading mode).

    Use instantiate_agent_handler() to build an AgentHandler from this configuration,
    and str(config) to obtain a multi-line, log-friendly dump of all parameters.
    """

    def __init__(self, nr_of_steps: int, nr_of_episodes: int, model_blue_print: BluePrintBase,
                 data: pd.DataFrame, initial_budget: float, max_amount_of_trades: int, window_size: int,
                 learning_strategy_handler: LearningStrategyHandlerBase,
                 testing_strategy_handlers: list[TestingStrategyHandlerBase], sell_stop_loss: float = 0.8,
                 sell_take_profit: float = 1.2, buy_stop_loss: float = 0.8, buy_take_profit: float = 1.2,
                 penalty_starts: int = 0, penalty_stops: int = 10, static_reward_adjustment: float = 1,
                 repeat_test: int = 10, test_ratio: float = 0.2, validator: Optional[RewardValidatorBase] = None,
                 label_annotator: Optional[LabelAnnotatorBase] = None,
                 labeled_data_balancer: Optional[LabeledDataBalancer] = None,
                 meta_data: Optional[dict[str, Any]] = None,
                 trading_mode: Optional[TradingEnvironment.TradingMode] = None) -> None:
        """
        Class constructor. Initializes the training configuration with the provided parameters.

        Parameters:
            nr_of_steps (int): The number of training steps to perform.
            nr_of_episodes (int): The number of training episodes to perform.
            model_blue_print (BluePrintBase): The blueprint for the model to be trained.
            data (pd.DataFrame): The training data.
            initial_budget (float): The initial budget for the trading agent.
            max_amount_of_trades (int): The maximum number of trades to perform.
            window_size (int): The size of the observation window.
            learning_strategy_handler (LearningStrategyHandlerBase): The handler for the learning strategy.
            testing_strategy_handlers (list[TestingStrategyHandlerBase]): The handlers for the testing
                strategies. The list is stored as given (not copied) and passed on to the AgentHandler.
            sell_stop_loss (float): The stop loss threshold for selling. Defaults to 0.8.
            sell_take_profit (float): The take profit threshold for selling. Defaults to 1.2.
            buy_stop_loss (float): The stop loss threshold for buying. Defaults to 0.8.
            buy_take_profit (float): The take profit threshold for buying. Defaults to 1.2.
            penalty_starts (int): The step at which to start applying penalties. Defaults to 0.
            penalty_stops (int): The step at which to stop applying penalties. Defaults to 10.
            static_reward_adjustment (float): The static adjustment factor for rewards. Defaults to 1.
            repeat_test (int): The number of times to repeat testing. Defaults to 10.
            test_ratio (float): The ratio of data to use for testing. Defaults to 0.2.
            validator (Optional[RewardValidatorBase]): The reward validator to use.
                A new PriceRewardValidator() is created when None.
            label_annotator (Optional[LabelAnnotatorBase]): The label annotator to use.
                A new SimpleLabelAnnotator() is created when None.
            labeled_data_balancer (Optional[LabeledDataBalancer]): The labeled data balancer to use.
                Defaults to None.
            meta_data (Optional[dict[str, Any]]): Optional metadata for the training configuration.
            trading_mode (Optional[TradingEnvironment.TradingMode]): The trading mode to use. Defaults to None.
        """

        # Defaults are created per instance (not as default-argument values) so that
        # separate configurations never share a validator / annotator object.
        if validator is None:
            validator = PriceRewardValidator()

        if label_annotator is None:
            label_annotator = SimpleLabelAnnotator()

        # Training config
        self.nr_of_steps: int = nr_of_steps
        self.nr_of_episodes: int = nr_of_episodes
        self.repeat_test: int = repeat_test

        # Environment config
        self.__data: pd.DataFrame = data
        self.__meta_data: Optional[dict[str, Any]] = meta_data
        self.__trading_mode: Optional[TradingEnvironment.TradingMode] = trading_mode
        self.__test_ratio: float = test_ratio
        self.__initial_budget: float = initial_budget
        self.__max_amount_of_trades: int = max_amount_of_trades
        self.__window_size: int = window_size
        self.__sell_stop_loss: float = sell_stop_loss
        self.__sell_take_profit: float = sell_take_profit
        self.__buy_stop_loss: float = buy_stop_loss
        self.__buy_take_profit: float = buy_take_profit
        self.__penalty_starts: int = penalty_starts
        self.__penalty_stops: int = penalty_stops
        self.__static_reward_adjustment: float = static_reward_adjustment
        self.__validator: RewardValidatorBase = validator
        self.__label_annotator: LabelAnnotatorBase = label_annotator
        self.__labeled_data_balancer: Optional[LabeledDataBalancer] = labeled_data_balancer

        # Agent config
        self.__model_blue_print: BluePrintBase = model_blue_print
        self.__learning_strategy_handler: LearningStrategyHandlerBase = learning_strategy_handler
        self.__testing_strategy_handlers: list[TestingStrategyHandlerBase] = testing_strategy_handlers

    def __str__(self) -> str:
        """
        Returns a string representation of the configuration.

        Creates a formatted multi-line string containing all configuration
        parameters and their values for easy logging. Component objects
        (validator, annotator, balancer, blueprint, strategy handlers) are
        shown by class name followed by their instance attributes (vars()).

        Returns:
            str: Formatted string representation of the configuration.
        """

        # The balancer is optional, so it gets its own section; the other components are
        # always set (validator / annotator are defaulted in __init__).
        if self.__labeled_data_balancer is not None:
            labeled_data_balancer_info = (
                f"\tlabeled_data_balancer: {self.__labeled_data_balancer.__class__.__name__}\n"
                f"\t\t{vars(self.__labeled_data_balancer)}\n"
            )
        else:
            labeled_data_balancer_info = "\tlabeled_data_balancer: None\n"

        return (
            "Training config:\n"
            f"\tnr_of_steps: {self.nr_of_steps}\n"
            f"\tnr_of_episodes: {self.nr_of_episodes}\n"
            f"\trepeat_test: {self.repeat_test}\n"
            f"\ttest_ratio: {self.__test_ratio}\n"
            f"\ttrading_mode: {self.__trading_mode}\n"
            f"\tinitial_budget: {self.__initial_budget}\n"
            f"\tmax_amount_of_trades: {self.__max_amount_of_trades}\n"
            f"\twindow_size: {self.__window_size}\n"
            f"\tsell_stop_loss: {self.__sell_stop_loss}\n"
            f"\tsell_take_profit: {self.__sell_take_profit}\n"
            f"\tbuy_stop_loss: {self.__buy_stop_loss}\n"
            f"\tbuy_take_profit: {self.__buy_take_profit}\n"
            f"\tpenalty_starts: {self.__penalty_starts}\n"
            f"\tpenalty_stops: {self.__penalty_stops}\n"
            f"\tstatic_reward_adjustment: {self.__static_reward_adjustment}\n"
            f"\tvalidator: {self.__validator.__class__.__name__}\n"
            f"\t\t{vars(self.__validator)}\n"
            f"\tlabel_annotator: {self.__label_annotator.__class__.__name__}\n"
            f"\t\t{vars(self.__label_annotator)}\n"
            f"{labeled_data_balancer_info}"
            f"\tmodel_blue_print: {self.__model_blue_print.__class__.__name__}\n"
            f"\t\t{vars(self.__model_blue_print)}\n"
            f"\tlearning_strategy_handler: {self.__learning_strategy_handler.__class__.__name__}\n"
            f"\t\t{vars(self.__learning_strategy_handler)}\n"
            f"\ttesting_strategy_handlers: {[handler.__class__.__name__ for handler in self.__testing_strategy_handlers]}\n"
            f"\t\t{[vars(handler) for handler in self.__testing_strategy_handlers]}\n"
        )

    def instantiate_agent_handler(self) -> AgentHandler:
        """
        Instantiates the agent handler with the configured environment and strategies.

        A new TradingEnvironment is constructed on every call, but the component objects
        stored in this config (data, validator, label annotator, balancer, blueprint,
        strategy handlers) are passed by reference, not copied.

        Returns:
            (AgentHandler): An instance of the agent handler configured with the model blueprint,
                trading environment, learning strategy handler and testing strategy handlers.
        """

        # NOTE: arguments are passed positionally, so this order must stay in sync with
        # TradingEnvironment's constructor signature.
        environment = TradingEnvironment(self.__data, self.__initial_budget, self.__max_amount_of_trades,
                                         self.__window_size, self.__validator, self.__label_annotator,
                                         self.__sell_stop_loss, self.__sell_take_profit, self.__buy_stop_loss,
                                         self.__buy_take_profit, self.__test_ratio, self.__penalty_starts,
                                         self.__penalty_stops, self.__static_reward_adjustment,
                                         self.__labeled_data_balancer, self.__meta_data,
                                         self.__trading_mode)

        return AgentHandler(self.__model_blue_print, environment, self.__learning_strategy_handler,
                            self.__testing_strategy_handlers)