Coverage for source/training/training_config.py: 86%

43 statements  

coverage.py v7.8.0, created at 2025-05-31 12:26 +0000

# training/training_config.py

from rl.policy import Policy, BoltzmannQPolicy
from tensorflow.keras.optimizers import Optimizer, Adam

from ..environment.trading_environment import TradingEnvironment
from ..agent.agent_handler import AgentHandler
from ..model.model_blue_prints.base_blue_print import BaseBluePrint
from ..environment.reward_validator_base import RewardValidatorBase

class TrainingConfig():
12 """ 

13 Responsible for creating and configuring training environment and agent. 

14 

15 This class encapsulates all configuration parameters needed for training 

16 and testing a trading agent, including environment setup, agent creation, 

17 and reward validation. It provides a centralized way to manage training 

18 parameters and instantiate required components. 

19 """ 

20 

    def __init__(self, nr_of_steps: int, nr_of_episodes: int, model_blue_print: BaseBluePrint,
                 data_path: str, initial_budget: float, max_amount_of_trades: int, window_size: int,
                 validator: RewardValidatorBase, sell_stop_loss: float = 0.8, sell_take_profit: float = 1.2,
                 buy_stop_loss: float = 0.8, buy_take_profit: float = 1.2, penalty_starts: int = 0,
                 penalty_stops: int = 10, static_reward_adjustment: float = 1, policy: Policy = BoltzmannQPolicy(),
                 optimizer: Optimizer = Adam(learning_rate=1e-3), repeat_test: int = 10, test_ratio: float = 0.2) -> None:
        """
        Initializes the training configuration with the provided parameters.

        Parameters:
            nr_of_steps (int): Total number of training steps.
            nr_of_episodes (int): Number of training episodes.
            model_blue_print (BaseBluePrint): Blueprint for creating the neural network model.
            data_path (str): Path to the training data file.
            initial_budget (float): Starting budget for the agent.
            max_amount_of_trades (int): Maximum number of trades that can be placed in the environment.
            window_size (int): Size of the observation window for market data.
            validator (RewardValidatorBase): Strategy for validating and calculating rewards.
            sell_stop_loss (float): Coefficient defining when to stop loss on sell positions.
            sell_take_profit (float): Coefficient defining when to take profit on sell positions.
            buy_stop_loss (float): Coefficient defining when to stop loss on buy positions.
            buy_take_profit (float): Coefficient defining when to take profit on buy positions.
            penalty_starts (int): Starting point (in trading periods without activity) from which the penalty is applied.
            penalty_stops (int): Ending point (in trading periods without activity) at which penalty growth stops.
            static_reward_adjustment (float): Adjustment factor for rewards, used to penalize unwanted actions.
            policy (Policy): Policy for action selection during training.
            optimizer (Optimizer): Optimizer used for model compilation and training.
            repeat_test (int): Number of times to repeat testing for evaluation.
            test_ratio (float): Fraction of the data reserved for testing; the remainder is used for training.
        """

        # Training config
        self.nr_of_steps = nr_of_steps
        self.nr_of_episodes = nr_of_episodes
        self.repeat_test = repeat_test

        # Environment config
        self.__data_path: str = data_path
        self.__test_ratio = test_ratio
        self.__initial_budget: float = initial_budget
        self.__max_amount_of_trades: int = max_amount_of_trades
        self.__window_size: int = window_size
        self.__sell_stop_loss: float = sell_stop_loss
        self.__sell_take_profit: float = sell_take_profit
        self.__buy_stop_loss: float = buy_stop_loss
        self.__buy_take_profit: float = buy_take_profit
        self.__penalty_starts: int = penalty_starts
        self.__penalty_stops: int = penalty_stops
        self.__static_reward_adjustment: float = static_reward_adjustment
        self.__validator: RewardValidatorBase = validator
        self.__instantiated_environment: TradingEnvironment = None

        # Agent config
        self.__model_blue_print: BaseBluePrint = model_blue_print
        self.__policy: Policy = policy
        self.__optimizer: Optimizer = optimizer
        self.__instantiated_agent: AgentHandler = None

    def __str__(self) -> str:
        """
        Returns a string representation of the configuration.

        Creates a formatted multi-line string containing all configuration
        parameters and their values for easy logging.

        Returns:
            str: Formatted string representation of the configuration.
        """

        return f"Training config:\n" \
               f"\tnr_of_steps: {self.nr_of_steps}\n" \
               f"\tnr_of_episodes: {self.nr_of_episodes}\n" \
               f"\trepeat_test: {self.repeat_test}\n" \
               f"\ttest_ratio: {self.__test_ratio}\n" \
               f"\tinitial_budget: {self.__initial_budget}\n" \
               f"\tmax_amount_of_trades: {self.__max_amount_of_trades}\n" \
               f"\twindow_size: {self.__window_size}\n" \
               f"\tsell_stop_loss: {self.__sell_stop_loss}\n" \
               f"\tsell_take_profit: {self.__sell_take_profit}\n" \
               f"\tbuy_stop_loss: {self.__buy_stop_loss}\n" \
               f"\tbuy_take_profit: {self.__buy_take_profit}\n" \
               f"\tpenalty_starts: {self.__penalty_starts}\n" \
               f"\tpenalty_stops: {self.__penalty_stops}\n" \
               f"\tstatic_reward_adjustment: {self.__static_reward_adjustment}\n" \
               f"\tvalidator: {self.__validator.__class__.__name__}\n" \
               f"\t\t{vars(self.__validator)}\n" \
               f"\tmodel_blue_print: {self.__model_blue_print.__class__.__name__}\n" \
               f"\t\t{vars(self.__model_blue_print)}\n" \
               f"\tpolicy: {self.__policy.__class__.__name__}\n" \
               f"\t\t{vars(self.__policy)}\n" \
               f"\toptimizer: {self.__optimizer.__class__.__name__}\n" \
               f"\t\t{self.__optimizer._hyper}\n"

    def instantiate_environment(self) -> TradingEnvironment:
        """
        Creates and returns a TradingEnvironment based on the configuration.

        Instantiates a new trading environment with the parameters specified
        in this config. Stores the created environment internally for later use
        when creating the agent.

        Returns:
            TradingEnvironment: Configured trading environment ready for training.
        """

        self.__instantiated_environment = TradingEnvironment(self.__data_path,
                                                              self.__initial_budget,
                                                              self.__max_amount_of_trades,
                                                              self.__window_size,
                                                              self.__validator,
                                                              self.__sell_stop_loss,
                                                              self.__sell_take_profit,
                                                              self.__buy_stop_loss,
                                                              self.__buy_take_profit,
                                                              self.__test_ratio,
                                                              self.__penalty_starts,
                                                              self.__penalty_stops,
                                                              self.__static_reward_adjustment)

        return self.__instantiated_environment

    def instantiate_agent(self) -> AgentHandler:
        """
        Creates and returns an AgentHandler based on the configuration.

        Uses the model blueprint to create a neural network model with the correct
        input and output dimensions based on the environment's observation and action
        spaces, then wraps this model in an AgentHandler with the specified policy
        and optimizer.

        Returns:
            AgentHandler: Configured agent handler ready for training.

        Raises:
            RuntimeError: If the environment has not been instantiated first.
        """

        if self.__instantiated_environment is None:
            raise RuntimeError("Environment not instantiated yet!")

        observation_space_shape = self.__instantiated_environment.observation_space.shape
        nr_of_actions = self.__instantiated_environment.action_space.n
        spatial_data_shape = self.__instantiated_environment.get_environment_spatial_data_dimension()
        model = self.__model_blue_print.instantiate_model(observation_space_shape, nr_of_actions, spatial_data_shape)
        self.__instantiated_agent = AgentHandler(model, self.__policy, nr_of_actions, self.__optimizer)

        return self.__instantiated_agent
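
For context, here is a minimal usage sketch of the class shown above. The blueprint and validator classes (ExampleBluePrint, ExampleRewardValidator), the data path, the parameter values, and the top-level import paths are hypothetical placeholders standing in for the project's actual BaseBluePrint and RewardValidatorBase implementations; only the TrainingConfig parameter names and the environment-before-agent ordering come from the listed code.

# Hypothetical imports -- adjust to the project's real package layout under source/.
from source.training.training_config import TrainingConfig
from source.model.model_blue_prints.example_blue_print import ExampleBluePrint    # placeholder blueprint
from source.environment.example_reward_validator import ExampleRewardValidator    # placeholder validator

config = TrainingConfig(nr_of_steps=100_000,
                        nr_of_episodes=100,
                        model_blue_print=ExampleBluePrint(),
                        data_path="data/prices.csv",          # placeholder data file
                        initial_budget=10_000.0,
                        max_amount_of_trades=5,
                        window_size=24,
                        validator=ExampleRewardValidator())

print(config)                                    # __str__ logs every configuration parameter

environment = config.instantiate_environment()   # must be called before instantiate_agent()
agent = config.instantiate_agent()               # raises RuntimeError if no environment exists yet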