Coverage for source/training/training_config.py: 0%
40 statements
coverage.py v7.8.0, created at 2025-07-23 22:15 +0000
# training/training_config.py

# global imports
from typing import Optional

# local imports
from source.agent import AgentHandler, LearningStrategyHandlerBase, TestingStrategyHandlerBase
from source.environment import LabelAnnotatorBase, LabeledDataBalancer, PriceRewardValidator, \
    RewardValidatorBase, SimpleLabelAnnotator, TradingEnvironment
from source.model import BluePrintBase

class TrainingConfig():
    """
    Implements a configuration class for training agents in a trading environment. It encapsulates
    all necessary training parameters, including the number of steps, the number of episodes, the model
    blueprint, the data path, the initial budget, the maximum number of trades, the window size, and
    various reward parameters. It also provides methods for instantiating an agent handler and for
    printing the configuration.
    """

    def __init__(self, nr_of_steps: int, nr_of_episodes: int, model_blue_print: BluePrintBase,
                 data_path: str, initial_budget: float, max_amount_of_trades: int, window_size: int,
                 learning_strategy_handler: LearningStrategyHandlerBase,
                 testing_strategy_handler: TestingStrategyHandlerBase, sell_stop_loss: float = 0.8,
                 sell_take_profit: float = 1.2, buy_stop_loss: float = 0.8, buy_take_profit: float = 1.2,
                 penalty_starts: int = 0, penalty_stops: int = 10, static_reward_adjustment: float = 1,
                 repeat_test: int = 10, test_ratio: float = 0.2, validator: Optional[RewardValidatorBase] = None,
                 label_annotator: Optional[LabelAnnotatorBase] = None,
                 labeled_data_balancer: Optional[LabeledDataBalancer] = None) -> None:
29 """
30 Class constructor. Initializes the training configuration with the provided parameters.
32 Parameters:
33 nr_of_steps (int): The number of training steps to perform.
34 nr_of_episodes (int): The number of training episodes to perform.
35 model_blue_print (BluePrintBase): The blueprint for the model to be trained.
36 data_path (str): The path to the training data.
37 initial_budget (float): The initial budget for the trading agent.
38 max_amount_of_trades (int): The maximum number of trades to perform.
39 window_size (int): The size of the observation window.
40 learning_strategy_handler (LearningStrategyHandlerBase): The handler for the learning strategy.
41 testing_strategy_handler (TestingStrategyHandlerBase): The handler for the testing strategy.
42 sell_stop_loss (float): The stop loss threshold for selling.
43 sell_take_profit (float): The take profit threshold for selling.
44 buy_stop_loss (float): The stop loss threshold for buying.
45 buy_take_profit (float): The take profit threshold for buying.
46 penalty_starts (int): The step at which to start applying penalties.
47 penalty_stops (int): The step at which to stop applying penalties.
48 static_reward_adjustment (float): The static adjustment factor for rewards.
49 repeat_test (int): The number of times to repeat testing.
50 test_ratio (float): The ratio of data to use for testing.
51 validator (Optional[RewardValidatorBase]): The reward validator to use. Defaults to PriceRewardValidator.
52 label_annotator (Optional[LabelAnnotatorBase]): The label annotator to use. Defaults to SimpleLabelAnnotator.
53 labeled_data_balancer (Optional[LabeledDataBalancer]): The labeled data balancer to use. Defaults to None.
54 """

        if validator is None:
            validator = PriceRewardValidator()

        if label_annotator is None:
            label_annotator = SimpleLabelAnnotator()

        # Training config
        self.nr_of_steps: int = nr_of_steps
        self.nr_of_episodes: int = nr_of_episodes
        self.repeat_test: int = repeat_test

        # Environment config
        self.__data_path: str = data_path
        self.__test_ratio: float = test_ratio
        self.__initial_budget: float = initial_budget
        self.__max_amount_of_trades: int = max_amount_of_trades
        self.__window_size: int = window_size
        self.__sell_stop_loss: float = sell_stop_loss
        self.__sell_take_profit: float = sell_take_profit
        self.__buy_stop_loss: float = buy_stop_loss
        self.__buy_take_profit: float = buy_take_profit
        self.__penalty_starts: int = penalty_starts
        self.__penalty_stops: int = penalty_stops
        self.__static_reward_adjustment: float = static_reward_adjustment
        self.__validator: RewardValidatorBase = validator
        self.__label_annotator: LabelAnnotatorBase = label_annotator
        self.__labeled_data_balancer: Optional[LabeledDataBalancer] = labeled_data_balancer

        # Agent config
        self.__model_blue_print: BluePrintBase = model_blue_print
        self.__learning_strategy_handler: LearningStrategyHandlerBase = learning_strategy_handler
        self.__testing_strategy_handler: TestingStrategyHandlerBase = testing_strategy_handler

    def __str__(self) -> str:
        """
        Returns a string representation of the configuration.

        Creates a formatted multi-line string containing all configuration
        parameters and their values for easy logging.

        Returns:
            str: Formatted string representation of the configuration.
        """

        if self.__labeled_data_balancer is not None:
            labeled_data_balancer_info = \
                f"\tlabeled_data_balancer: {self.__labeled_data_balancer.__class__.__name__}\n" \
                f"\t\t{vars(self.__labeled_data_balancer)}\n"
        else:
            labeled_data_balancer_info = "\tlabeled_data_balancer: None\n"

        return f"Training config:\n" \
               f"\tnr_of_steps: {self.nr_of_steps}\n" \
               f"\tnr_of_episodes: {self.nr_of_episodes}\n" \
               f"\trepeat_test: {self.repeat_test}\n" \
               f"\ttest_ratio: {self.__test_ratio}\n" \
               f"\tinitial_budget: {self.__initial_budget}\n" \
               f"\tmax_amount_of_trades: {self.__max_amount_of_trades}\n" \
               f"\twindow_size: {self.__window_size}\n" \
               f"\tsell_stop_loss: {self.__sell_stop_loss}\n" \
               f"\tsell_take_profit: {self.__sell_take_profit}\n" \
               f"\tbuy_stop_loss: {self.__buy_stop_loss}\n" \
               f"\tbuy_take_profit: {self.__buy_take_profit}\n" \
               f"\tpenalty_starts: {self.__penalty_starts}\n" \
               f"\tpenalty_stops: {self.__penalty_stops}\n" \
               f"\tstatic_reward_adjustment: {self.__static_reward_adjustment}\n" \
               f"\tvalidator: {self.__validator.__class__.__name__}\n" \
               f"\t\t{vars(self.__validator)}\n" \
               f"\tlabel_annotator: {self.__label_annotator.__class__.__name__}\n" \
               f"\t\t{vars(self.__label_annotator)}\n" \
               f"{labeled_data_balancer_info}" \
               f"\tmodel_blue_print: {self.__model_blue_print.__class__.__name__}\n" \
               f"\t\t{vars(self.__model_blue_print)}\n" \
               f"\tlearning_strategy_handler: {self.__learning_strategy_handler.__class__.__name__}\n" \
               f"\t\t{vars(self.__learning_strategy_handler)}\n" \
               f"\ttesting_strategy_handler: {self.__testing_strategy_handler.__class__.__name__}\n" \
               f"\t\t{vars(self.__testing_strategy_handler)}\n"

    def instantiate_agent_handler(self) -> AgentHandler:
        """
        Instantiates the agent handler with the configured environment and strategies.

        Returns:
            (AgentHandler): An instance of the agent handler configured with the model blueprint,
                trading environment, learning strategy handler and testing strategy handler.
        """

        environment = TradingEnvironment(self.__data_path, self.__initial_budget, self.__max_amount_of_trades,
                                         self.__window_size, self.__validator, self.__label_annotator,
                                         self.__sell_stop_loss, self.__sell_take_profit, self.__buy_stop_loss,
                                         self.__buy_take_profit, self.__test_ratio, self.__penalty_starts,
                                         self.__penalty_stops, self.__static_reward_adjustment,
                                         self.__labeled_data_balancer)

        return AgentHandler(self.__model_blue_print, environment, self.__learning_strategy_handler,
                            self.__testing_strategy_handler)
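
# ---------------------------------------------------------------------------
# Usage sketch (not part of the original module): a minimal illustration of how
# TrainingConfig might be assembled and turned into an AgentHandler. The concrete
# classes used below (MyBluePrint, MyLearningStrategyHandler,
# MyTestingStrategyHandler) and the data path are hypothetical placeholders for
# project-specific subclasses of BluePrintBase, LearningStrategyHandlerBase and
# TestingStrategyHandlerBase; substitute whatever implementations the project
# actually provides. Kept commented out so the module remains importable as-is.
#
# if __name__ == "__main__":
#     config = TrainingConfig(nr_of_steps=10_000,
#                             nr_of_episodes=50,
#                             model_blue_print=MyBluePrint(),                            # hypothetical
#                             data_path="data/prices.csv",                               # hypothetical
#                             initial_budget=10_000.0,
#                             max_amount_of_trades=10,
#                             window_size=32,
#                             learning_strategy_handler=MyLearningStrategyHandler(),     # hypothetical
#                             testing_strategy_handler=MyTestingStrategyHandler())       # hypothetical
#     print(config)                                       # dump the full configuration via __str__
#     agent_handler = config.instantiate_agent_handler()  # builds the TradingEnvironment internally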