Coverage for source/training/training_config.py: 98%

42 statements  

« prev     ^ index     » next       coverage.py v7.8.0, created at 2025-07-30 20:59 +0000

1# training/training_config.py 

2 

3# global imports 

4import pandas as pd 

5from typing import Any, Optional 

6 

7# local imports 

8from source.agent import AgentHandler, LearningStrategyHandlerBase, TestingStrategyHandlerBase 

9from source.environment import LabelAnnotatorBase, LabeledDataBalancer, PriceRewardValidator, \ 

10 RewardValidatorBase, SimpleLabelAnnotator, TradingEnvironment 

11from source.model import BluePrintBase 

12 

class TrainingConfig:
    """
    Bundles every setting required to train an agent in a trading environment.

    The step/episode/test-repetition counters are kept as public attributes, while
    the environment settings (data, budget, trade limit, window size, stop-loss and
    take-profit thresholds, penalty range, reward adjustment, validator, annotator,
    balancer, metadata) and the agent settings (model blueprint, learning and testing
    strategy handlers) are kept private. The class can render itself as text and can
    build an agent handler out of the stored settings.
    """

    def __init__(self, nr_of_steps: int, nr_of_episodes: int, model_blue_print: BluePrintBase,
                 data: pd.DataFrame, initial_budget: float, max_amount_of_trades: int, window_size: int,
                 learning_strategy_handler: LearningStrategyHandlerBase,
                 testing_strategy_handler: TestingStrategyHandlerBase, sell_stop_loss: float = 0.8,
                 sell_take_profit: float = 1.2, buy_stop_loss: float = 0.8, buy_take_profit: float = 1.2,
                 penalty_starts: int = 0, penalty_stops: int = 10, static_reward_adjustment: float = 1,
                 repeat_test: int = 10, test_ratio: float = 0.2, validator: Optional[RewardValidatorBase] = None,
                 label_annotator: Optional[LabelAnnotatorBase] = None,
                 labeled_data_balancer: Optional[LabeledDataBalancer] = None,
                 meta_data: Optional[dict[str, Any]] = None) -> None:
        """
        Class constructor. Stores the given settings on the instance.

        Parameters:
            nr_of_steps (int): Number of training steps.
            nr_of_episodes (int): Number of training episodes.
            model_blue_print (BluePrintBase): Blueprint of the model to train.
            data (pd.DataFrame): Data handed over to the trading environment.
            initial_budget (float): Budget the trading agent starts with.
            max_amount_of_trades (int): Upper limit for the number of trades.
            window_size (int): Size of the observation window.
            learning_strategy_handler (LearningStrategyHandlerBase): Handler of the learning strategy.
            testing_strategy_handler (TestingStrategyHandlerBase): Handler of the testing strategy.
            sell_stop_loss (float): Stop loss threshold for selling.
            sell_take_profit (float): Take profit threshold for selling.
            buy_stop_loss (float): Stop loss threshold for buying.
            buy_take_profit (float): Take profit threshold for buying.
            penalty_starts (int): Step at which penalties start to apply.
            penalty_stops (int): Step at which penalties stop to apply.
            static_reward_adjustment (float): Static adjustment factor for rewards.
            repeat_test (int): How many times testing is repeated.
            test_ratio (float): Share of the data used for testing.
            validator (Optional[RewardValidatorBase]): Reward validator. When None is given,
                a PriceRewardValidator is created.
            label_annotator (Optional[LabelAnnotatorBase]): Label annotator. When None is given,
                a SimpleLabelAnnotator is created.
            labeled_data_balancer (Optional[LabeledDataBalancer]): Labeled data balancer. Stays None
                when not given.
            meta_data (Optional[dict[str, Any]]): Optional metadata, stored as given.
        """

        # Fall back to the default collaborators only when the caller supplied none.
        reward_validator = validator if validator is not None else PriceRewardValidator()
        annotator = label_annotator if label_annotator is not None else SimpleLabelAnnotator()

        # Training config
        self.nr_of_steps: int = nr_of_steps
        self.nr_of_episodes: int = nr_of_episodes
        self.repeat_test: int = repeat_test

        # Environment config
        self.__data: pd.DataFrame = data
        self.__meta_data: Optional[dict[str, Any]] = meta_data
        self.__test_ratio: float = test_ratio
        self.__initial_budget: float = initial_budget
        self.__max_amount_of_trades: int = max_amount_of_trades
        self.__window_size: int = window_size
        self.__sell_stop_loss: float = sell_stop_loss
        self.__sell_take_profit: float = sell_take_profit
        self.__buy_stop_loss: float = buy_stop_loss
        self.__buy_take_profit: float = buy_take_profit
        self.__penalty_starts: int = penalty_starts
        self.__penalty_stops: int = penalty_stops
        self.__static_reward_adjustment: float = static_reward_adjustment
        self.__validator: RewardValidatorBase = reward_validator
        self.__label_annotator: LabelAnnotatorBase = annotator
        self.__labeled_data_balancer: Optional[LabeledDataBalancer] = labeled_data_balancer

        # Agent config
        self.__model_blue_print: BluePrintBase = model_blue_print
        self.__learning_strategy_handler: LearningStrategyHandlerBase = learning_strategy_handler
        self.__testing_strategy_handler: TestingStrategyHandlerBase = testing_strategy_handler

    def __str__(self) -> str:
        """
        Renders the configuration as a multi-line, tab-indented text.

        Scalar settings are written as one "name: value" line each. Every collaborator
        (validator, label annotator, balancer, model blueprint, strategy handlers) is
        written as its class name followed by a line holding its instance attributes.
        The data frame and the metadata are not part of the output.

        Returns:
            str: Text representation of the configuration, terminated by a newline.
        """

        # The balancer is the only optional collaborator, so its section is prepared first.
        balancer = self.__labeled_data_balancer
        if balancer is None:
            balancer_section = "\tlabeled_data_balancer: None\n"
        else:
            balancer_section = "".join([
                f"\tlabeled_data_balancer: {self.__labeled_data_balancer.__class__.__name__}\n",
                f"\t\t{vars(self.__labeled_data_balancer)}\n",
            ])

        sections = [
            f"Training config:\n",
            f"\tnr_of_steps: {self.nr_of_steps}\n",
            f"\tnr_of_episodes: {self.nr_of_episodes}\n",
            f"\trepeat_test: {self.repeat_test}\n",
            f"\ttest_ratio: {self.__test_ratio}\n",
            f"\tinitial_budget: {self.__initial_budget}\n",
            f"\tmax_amount_of_trades: {self.__max_amount_of_trades}\n",
            f"\twindow_size: {self.__window_size}\n",
            f"\tsell_stop_loss: {self.__sell_stop_loss}\n",
            f"\tsell_take_profit: {self.__sell_take_profit}\n",
            f"\tbuy_stop_loss: {self.__buy_stop_loss}\n",
            f"\tbuy_take_profit: {self.__buy_take_profit}\n",
            f"\tpenalty_starts: {self.__penalty_starts}\n",
            f"\tpenalty_stops: {self.__penalty_stops}\n",
            f"\tstatic_reward_adjustment: {self.__static_reward_adjustment}\n",
            f"\tvalidator: {self.__validator.__class__.__name__}\n",
            f"\t\t{vars(self.__validator)}\n",
            f"\tlabel_annotator: {self.__label_annotator.__class__.__name__}\n",
            f"\t\t{vars(self.__label_annotator)}\n",
            f"{balancer_section}",
            f"\tmodel_blue_print: {self.__model_blue_print.__class__.__name__}\n",
            f"\t\t{vars(self.__model_blue_print)}\n",
            f"\tlearning_strategy_handler: {self.__learning_strategy_handler.__class__.__name__}\n",
            f"\t\t{vars(self.__learning_strategy_handler)}\n",
            f"\ttesting_strategy_handler: {self.__testing_strategy_handler.__class__.__name__}\n",
            f"\t\t{vars(self.__testing_strategy_handler)}\n",
        ]

        return "".join(sections)

    def instantiate_agent_handler(self) -> AgentHandler:
        """
        Builds an agent handler from the stored settings.

        A new TradingEnvironment is created from the environment settings on every call
        and is then passed, together with the model blueprint and both strategy handlers,
        to the AgentHandler constructor.

        Returns:
            (AgentHandler): Agent handler created from the model blueprint, the freshly built
                trading environment, the learning strategy handler and the testing strategy handler.
        """

        # All arguments are passed positionally; their order has to match the
        # TradingEnvironment constructor.
        trading_environment = TradingEnvironment(
            self.__data,
            self.__initial_budget,
            self.__max_amount_of_trades,
            self.__window_size,
            self.__validator,
            self.__label_annotator,
            self.__sell_stop_loss,
            self.__sell_take_profit,
            self.__buy_stop_loss,
            self.__buy_take_profit,
            self.__test_ratio,
            self.__penalty_starts,
            self.__penalty_stops,
            self.__static_reward_adjustment,
            self.__labeled_data_balancer,
            self.__meta_data,
        )

        agent_handler = AgentHandler(
            self.__model_blue_print,
            trading_environment,
            self.__learning_strategy_handler,
            self.__testing_strategy_handler,
        )
        return agent_handler