Coverage for source/training/training_config.py: 0% (40 statements)


# training/training_config.py

# global imports
from typing import Optional

# local imports
from source.agent import AgentHandler, LearningStrategyHandlerBase, TestingStrategyHandlerBase
from source.environment import LabelAnnotatorBase, LabeledDataBalancer, PriceRewardValidator, \
    RewardValidatorBase, SimpleLabelAnnotator, TradingEnvironment
from source.model import BluePrintBase

class TrainingConfig:
    """
    Configuration class for training agents in a trading environment. It encapsulates all
    parameters needed for training, including the number of steps and episodes, the model
    blueprint, the data path, the initial budget, the maximum number of trades, the window
    size, and various reward parameters. It also provides methods for instantiating an
    agent handler and printing the configuration.
    """

    def __init__(self, nr_of_steps: int, nr_of_episodes: int, model_blue_print: BluePrintBase,
                 data_path: str, initial_budget: float, max_amount_of_trades: int, window_size: int,
                 learning_strategy_handler: LearningStrategyHandlerBase,
                 testing_strategy_handler: TestingStrategyHandlerBase, sell_stop_loss: float = 0.8,
                 sell_take_profit: float = 1.2, buy_stop_loss: float = 0.8, buy_take_profit: float = 1.2,
                 penalty_starts: int = 0, penalty_stops: int = 10, static_reward_adjustment: float = 1,
                 repeat_test: int = 10, test_ratio: float = 0.2, validator: Optional[RewardValidatorBase] = None,
                 label_annotator: Optional[LabelAnnotatorBase] = None,
                 labeled_data_balancer: Optional[LabeledDataBalancer] = None) -> None:
        """
        Class constructor. Initializes the training configuration with the provided parameters.

        Parameters:
            nr_of_steps (int): The number of training steps to perform.
            nr_of_episodes (int): The number of training episodes to perform.
            model_blue_print (BluePrintBase): The blueprint for the model to be trained.
            data_path (str): The path to the training data.
            initial_budget (float): The initial budget for the trading agent.
            max_amount_of_trades (int): The maximum number of trades to perform.
            window_size (int): The size of the observation window.
            learning_strategy_handler (LearningStrategyHandlerBase): The handler for the learning strategy.
            testing_strategy_handler (TestingStrategyHandlerBase): The handler for the testing strategy.
            sell_stop_loss (float): The stop loss threshold for selling. Defaults to 0.8.
            sell_take_profit (float): The take profit threshold for selling. Defaults to 1.2.
            buy_stop_loss (float): The stop loss threshold for buying. Defaults to 0.8.
            buy_take_profit (float): The take profit threshold for buying. Defaults to 1.2.
            penalty_starts (int): The step at which to start applying penalties. Defaults to 0.
            penalty_stops (int): The step at which to stop applying penalties. Defaults to 10.
            static_reward_adjustment (float): The static adjustment factor for rewards. Defaults to 1.
            repeat_test (int): The number of times to repeat testing. Defaults to 10.
            test_ratio (float): The ratio of data to use for testing. Defaults to 0.2.
            validator (Optional[RewardValidatorBase]): The reward validator to use. Defaults to PriceRewardValidator.
            label_annotator (Optional[LabelAnnotatorBase]): The label annotator to use. Defaults to SimpleLabelAnnotator.
            labeled_data_balancer (Optional[LabeledDataBalancer]): The labeled data balancer to use. Defaults to None.
        """

        # Build the default helper objects here, rather than as default argument
        # values, so that each TrainingConfig instance gets its own copies.
        if validator is None:
            validator = PriceRewardValidator()

        if label_annotator is None:
            label_annotator = SimpleLabelAnnotator()

        # Training config
        self.nr_of_steps: int = nr_of_steps
        self.nr_of_episodes: int = nr_of_episodes
        self.repeat_test: int = repeat_test

        # Environment config
        self.__data_path: str = data_path
        self.__test_ratio: float = test_ratio
        self.__initial_budget: float = initial_budget
        self.__max_amount_of_trades: int = max_amount_of_trades
        self.__window_size: int = window_size
        self.__sell_stop_loss: float = sell_stop_loss
        self.__sell_take_profit: float = sell_take_profit
        self.__buy_stop_loss: float = buy_stop_loss
        self.__buy_take_profit: float = buy_take_profit
        self.__penalty_starts: int = penalty_starts
        self.__penalty_stops: int = penalty_stops
        self.__static_reward_adjustment: float = static_reward_adjustment
        self.__validator: RewardValidatorBase = validator
        self.__label_annotator: LabelAnnotatorBase = label_annotator
        self.__labeled_data_balancer: Optional[LabeledDataBalancer] = labeled_data_balancer

        # Agent config
        self.__model_blue_print: BluePrintBase = model_blue_print
        self.__learning_strategy_handler: LearningStrategyHandlerBase = learning_strategy_handler
        self.__testing_strategy_handler: TestingStrategyHandlerBase = testing_strategy_handler

    def __str__(self) -> str:
        """
        Returns a string representation of the configuration.

        Creates a formatted multi-line string containing all configuration
        parameters and their values for easy logging.

        Returns:
            str: Formatted string representation of the configuration.
        """

        if self.__labeled_data_balancer is not None:
            labeled_data_balancer_info = \
                f"\tlabeled_data_balancer: {self.__labeled_data_balancer.__class__.__name__}\n" \
                f"\t\t{vars(self.__labeled_data_balancer)}\n"
        else:
            labeled_data_balancer_info = "\tlabeled_data_balancer: None\n"

        return f"Training config:\n" \
               f"\tnr_of_steps: {self.nr_of_steps}\n" \
               f"\tnr_of_episodes: {self.nr_of_episodes}\n" \
               f"\trepeat_test: {self.repeat_test}\n" \
               f"\ttest_ratio: {self.__test_ratio}\n" \
               f"\tinitial_budget: {self.__initial_budget}\n" \
               f"\tmax_amount_of_trades: {self.__max_amount_of_trades}\n" \
               f"\twindow_size: {self.__window_size}\n" \
               f"\tsell_stop_loss: {self.__sell_stop_loss}\n" \
               f"\tsell_take_profit: {self.__sell_take_profit}\n" \
               f"\tbuy_stop_loss: {self.__buy_stop_loss}\n" \
               f"\tbuy_take_profit: {self.__buy_take_profit}\n" \
               f"\tpenalty_starts: {self.__penalty_starts}\n" \
               f"\tpenalty_stops: {self.__penalty_stops}\n" \
               f"\tstatic_reward_adjustment: {self.__static_reward_adjustment}\n" \
               f"\tvalidator: {self.__validator.__class__.__name__}\n" \
               f"\t\t{vars(self.__validator)}\n" \
               f"\tlabel_annotator: {self.__label_annotator.__class__.__name__}\n" \
               f"\t\t{vars(self.__label_annotator)}\n" \
               f"{labeled_data_balancer_info}" \
               f"\tmodel_blue_print: {self.__model_blue_print.__class__.__name__}\n" \
               f"\t\t{vars(self.__model_blue_print)}\n" \
               f"\tlearning_strategy_handler: {self.__learning_strategy_handler.__class__.__name__}\n" \
               f"\t\t{vars(self.__learning_strategy_handler)}\n" \
               f"\ttesting_strategy_handler: {self.__testing_strategy_handler.__class__.__name__}\n" \
               f"\t\t{vars(self.__testing_strategy_handler)}\n"
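
    # Illustrative shape of the string __str__ produces (values and the vars(...)
    # dictionaries below are placeholders; actual content depends on the
    # configured objects):
    #
    #   Training config:
    #       nr_of_steps: 10000
    #       nr_of_episodes: 50
    #       repeat_test: 10
    #       ...
    #       validator: PriceRewardValidator
    #           {...}
    #       ...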

    def instantiate_agent_handler(self) -> AgentHandler:
        """
        Instantiates the agent handler with the configured environment and strategies.

        Returns:
            AgentHandler: An instance of the agent handler configured with the model blueprint,
                trading environment, learning strategy handler and testing strategy handler.
        """

        environment = TradingEnvironment(self.__data_path, self.__initial_budget, self.__max_amount_of_trades,
                                         self.__window_size, self.__validator, self.__label_annotator,
                                         self.__sell_stop_loss, self.__sell_take_profit, self.__buy_stop_loss,
                                         self.__buy_take_profit, self.__test_ratio, self.__penalty_starts,
                                         self.__penalty_stops, self.__static_reward_adjustment,
                                         self.__labeled_data_balancer)

        return AgentHandler(self.__model_blue_print, environment, self.__learning_strategy_handler,
                            self.__testing_strategy_handler)
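
# ---------------------------------------------------------------------------
# Illustrative usage (a minimal sketch, not part of the module). MyBluePrint,
# MyLearningStrategyHandler and MyTestingStrategyHandler are hypothetical
# placeholders for project-specific subclasses of BluePrintBase,
# LearningStrategyHandlerBase and TestingStrategyHandlerBase;
# "data/training_data.csv" is likewise an assumed path.
#
#     config = TrainingConfig(nr_of_steps=10_000,
#                             nr_of_episodes=50,
#                             model_blue_print=MyBluePrint(),
#                             data_path="data/training_data.csv",
#                             initial_budget=10_000.0,
#                             max_amount_of_trades=10,
#                             window_size=32,
#                             learning_strategy_handler=MyLearningStrategyHandler(),
#                             testing_strategy_handler=MyTestingStrategyHandler())
#
#     print(config)                       # log the full configuration
#     agent_handler = config.instantiate_agent_handler()
# ---------------------------------------------------------------------------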