Coverage for source/training/training_config.py: 86%
43 statements
coverage.py v7.8.0, created at 2025-05-30 19:46 +0000
# training/training_config.py

from rl.policy import Policy, BoltzmannQPolicy
from tensorflow.keras.optimizers import Optimizer, Adam

from ..environment.trading_environment import TradingEnvironment
from ..agent.agent_handler import AgentHandler
from ..model.model_blue_prints.base_blue_print import BaseBluePrint
from ..environment.reward_validator_base import RewardValidatorBase

class TrainingConfig():
    """
    Responsible for creating and configuring the training environment and agent.

    This class encapsulates all configuration parameters needed for training
    and testing a trading agent, including environment setup, agent creation,
    and reward validation. It provides a centralized way to manage training
    parameters and instantiate the required components.
    """

    def __init__(self, nr_of_steps: int, nr_of_episodes: int, model_blue_print: BaseBluePrint,
                 data_path: str, initial_budget: float, max_amount_of_trades: int, window_size: int,
                 validator: RewardValidatorBase, sell_stop_loss: float = 0.8, sell_take_profit: float = 1.2,
                 buy_stop_loss: float = 0.8, buy_take_profit: float = 1.2, penalty_starts: int = 0,
                 penalty_stops: int = 10, static_reward_adjustment: float = 1, policy: Policy = BoltzmannQPolicy(),
                 optimizer: Optimizer = Adam(learning_rate=1e-3), repeat_test: int = 10, test_ratio: float = 0.2) -> None:
27 """
28 Initializes the training configuration with provided parameters.
30 Parameters:
31 nr_of_steps (int): Total number of training steps.
32 nr_of_episodes (int): Number of training episodes.
33 model_blue_print (BaseBluePrint): Blueprint for creating the neural network model.
34 data_path (str): Path to the training data file.
35 initial_budget (float): Starting budget for the agent.
36 max_amount_of_trades (int): Maximum number of trades allowed to be placed in the environment.
37 window_size (int): Size of the observation window for market data.
38 validator (RewardValidatorBase): Strategy for validating and calculating rewards.
39 sell_stop_loss (float): Coefficient defining when to stop loss on sell positions.
40 sell_take_profit (float): Coefficient defining when to take profit on sell positions.
41 buy_stop_loss (float): Coefficient defining when to stop loss on buy positions.
42 buy_take_profit (float): Coefficient defining when to take profit on buy positions.
43 penalty_starts (int): Starting point (in trading periods without activity) that penalty should be applied from.
44 penalty_stops (int): Ending point (in trading periods without activity) that penalty growth should be stopped at.
45 static_reward_adjustment (float): Adjustment factor for rewards, used to penalize unwanted actions.
46 policy (Policy): Policy for action selection during training.
47 optimizer (Optimizer): Optimizer to be used for model compilation and training.
48 repeat_test (int): Number of times to repeat testing for evaluation.
49 """
        # Training config
        self.nr_of_steps = nr_of_steps
        self.nr_of_episodes = nr_of_episodes
        self.repeat_test = repeat_test

        # Environment config
        self.__data_path: str = data_path
        self.__test_ratio = test_ratio
        self.__initial_budget: float = initial_budget
        self.__max_amount_of_trades: int = max_amount_of_trades
        self.__window_size: int = window_size
        self.__sell_stop_loss: float = sell_stop_loss
        self.__sell_take_profit: float = sell_take_profit
        self.__buy_stop_loss: float = buy_stop_loss
        self.__buy_take_profit: float = buy_take_profit
        self.__penalty_starts: int = penalty_starts
        self.__penalty_stops: int = penalty_stops
        self.__static_reward_adjustment: float = static_reward_adjustment
        self.__validator: RewardValidatorBase = validator
        self.__instantiated_environment: TradingEnvironment = None

        # Agent config
        self.__model_blue_print: BaseBluePrint = model_blue_print
        self.__policy: Policy = policy
        self.__optimizer: Optimizer = optimizer
        self.__instantiated_agent: AgentHandler = None

    def __str__(self) -> str:
        """
        Returns a string representation of the configuration.

        Creates a formatted multi-line string containing all configuration
        parameters and their values for easy logging.

        Returns:
            str: Formatted string representation of the configuration.
        """

        return f"Training config:\n" \
               f"\tnr_of_steps: {self.nr_of_steps}\n" \
               f"\tnr_of_episodes: {self.nr_of_episodes}\n" \
               f"\trepeat_test: {self.repeat_test}\n" \
               f"\tinitial_budget: {self.__initial_budget}\n" \
               f"\tmax_amount_of_trades: {self.__max_amount_of_trades}\n" \
               f"\twindow_size: {self.__window_size}\n" \
               f"\tsell_stop_loss: {self.__sell_stop_loss}\n" \
               f"\tsell_take_profit: {self.__sell_take_profit}\n" \
               f"\tbuy_stop_loss: {self.__buy_stop_loss}\n" \
               f"\tbuy_take_profit: {self.__buy_take_profit}\n" \
               f"\tpenalty_starts: {self.__penalty_starts}\n" \
               f"\tpenalty_stops: {self.__penalty_stops}\n" \
               f"\tstatic_reward_adjustment: {self.__static_reward_adjustment}\n" \
               f"\tvalidator: {self.__validator.__class__.__name__}\n" \
               f"\t\t{vars(self.__validator)}\n" \
               f"\tmodel_blue_print: {self.__model_blue_print.__class__.__name__}\n" \
               f"\t\t{vars(self.__model_blue_print)}\n" \
               f"\tpolicy: {self.__policy.__class__.__name__}\n" \
               f"\t\t{vars(self.__policy)}\n" \
               f"\toptimizer: {self.__optimizer.__class__.__name__}\n" \
               f"\t\t{self.__optimizer._hyper}\n"

    def instantiate_environment(self) -> TradingEnvironment:
        """
        Creates and returns a TradingEnvironment based on the configuration.

        Instantiates a new trading environment with the parameters specified
        in this config. Stores the created environment internally for later use
        when creating the agent.

        Returns:
            TradingEnvironment: Configured trading environment ready for training.
        """

        self.__instantiated_environment = TradingEnvironment(self.__data_path,
                                                             self.__initial_budget,
                                                             self.__max_amount_of_trades,
                                                             self.__window_size,
                                                             self.__validator,
                                                             self.__sell_stop_loss,
                                                             self.__sell_take_profit,
                                                             self.__buy_stop_loss,
                                                             self.__buy_take_profit,
                                                             self.__test_ratio,
                                                             self.__penalty_starts,
                                                             self.__penalty_stops,
                                                             self.__static_reward_adjustment)

        return self.__instantiated_environment

    def instantiate_agent(self) -> AgentHandler:
        """
        Creates and returns an AgentHandler based on the configuration.

        Uses the model blueprint to create a neural network model with the correct
        input and output dimensions based on the environment's observation and action
        spaces, then wraps this model in an AgentHandler with the specified policy
        and optimizer.

        Returns:
            AgentHandler: Configured agent handler ready for training.

        Raises:
            RuntimeError: If the environment has not been instantiated first.
        """

        if self.__instantiated_environment is None:
            raise RuntimeError("Environment not instantiated yet!")

        observation_space_shape = self.__instantiated_environment.observation_space.shape
        nr_of_actions = self.__instantiated_environment.action_space.n
        spatial_data_shape = self.__instantiated_environment.get_environment_spatial_data_dimension()
        model = self.__model_blue_print.instantiate_model(observation_space_shape, nr_of_actions, spatial_data_shape)
        self.__instantiated_agent = AgentHandler(model, self.__policy, nr_of_actions, self.__optimizer)

        return self.__instantiated_agent
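
For reference, a minimal usage sketch of the class above (not part of the covered module): it assumes hypothetical DenseBluePrint and PointsRewardValidator classes standing in for concrete BaseBluePrint and RewardValidatorBase subclasses from the surrounding project, plus an illustrative data path and parameter values. The only ordering constraint taken from the code itself is that instantiate_environment() must run before instantiate_agent(), which otherwise raises RuntimeError.

# Usage sketch. Assumptions: DenseBluePrint and PointsRewardValidator are
# hypothetical stand-ins for concrete blueprint/validator implementations;
# the data path and numeric values are illustrative only.
config = TrainingConfig(nr_of_steps=50_000,
                        nr_of_episodes=100,
                        model_blue_print=DenseBluePrint(),
                        data_path="data/training_data.csv",
                        initial_budget=10_000.0,
                        max_amount_of_trades=5,
                        window_size=30,
                        validator=PointsRewardValidator())

print(config)                                   # logs every configured parameter via __str__

environment = config.instantiate_environment()  # must be called before instantiate_agent()
agent = config.instantiate_agent()              # would raise RuntimeError if called first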