Coverage for source/training/training_config.py: 98% (43 statements)
# training/training_config.py

# global imports
import pandas as pd
from typing import Any, Optional

# local imports
from source.agent import AgentHandler, LearningStrategyHandlerBase, TestingStrategyHandlerBase
from source.environment import LabelAnnotatorBase, LabeledDataBalancer, PriceRewardValidator, \
    RewardValidatorBase, SimpleLabelAnnotator, TradingEnvironment
from source.model import BluePrintBase

class TrainingConfig:
    """
    Implements a configuration class for training agents in a trading environment. It encapsulates
    all parameters needed for training, including the number of steps, the number of episodes, the
    model blueprint, the training data, the initial budget, the maximum number of trades, the
    window size, and the various reward parameters. It also provides methods for instantiating an
    agent handler and printing the configuration.
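
    Example (illustrative sketch only: ``blue_print``, ``df``, ``learning_handler`` and
    ``testing_handler`` are hypothetical placeholders for project-specific subclasses of
    BluePrintBase, LearningStrategyHandlerBase and TestingStrategyHandlerBase)::

        config = TrainingConfig(nr_of_steps=10_000, nr_of_episodes=100,
                                model_blue_print=blue_print, data=df,
                                initial_budget=1_000.0, max_amount_of_trades=10,
                                window_size=20,
                                learning_strategy_handler=learning_handler,
                                testing_strategy_handlers=[testing_handler])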
19 """

    def __init__(self, nr_of_steps: int, nr_of_episodes: int, model_blue_print: BluePrintBase,
                 data: pd.DataFrame, initial_budget: float, max_amount_of_trades: int, window_size: int,
                 learning_strategy_handler: LearningStrategyHandlerBase,
                 testing_strategy_handlers: list[TestingStrategyHandlerBase], sell_stop_loss: float = 0.8,
                 sell_take_profit: float = 1.2, buy_stop_loss: float = 0.8, buy_take_profit: float = 1.2,
                 penalty_starts: int = 0, penalty_stops: int = 10, static_reward_adjustment: float = 1,
                 repeat_test: int = 10, test_ratio: float = 0.2, validator: Optional[RewardValidatorBase] = None,
                 label_annotator: Optional[LabelAnnotatorBase] = None,
                 labeled_data_balancer: Optional[LabeledDataBalancer] = None,
                 meta_data: Optional[dict[str, Any]] = None,
                 trading_mode: Optional[TradingEnvironment.TradingMode] = None) -> None:
32 """
33 Class constructor. Initializes the training configuration with the provided parameters.
35 Parameters:
36 nr_of_steps (int): The number of training steps to perform.
37 nr_of_episodes (int): The number of training episodes to perform.
38 model_blue_print (BluePrintBase): The blueprint for the model to be trained.
39 data (pd.DataFrame): The training data.
40 initial_budget (float): The initial budget for the trading agent.
41 max_amount_of_trades (int): The maximum number of trades to perform.
42 window_size (int): The size of the observation window.
43 learning_strategy_handler (LearningStrategyHandlerBase): The handler for the learning strategy.
44 testing_strategy_handler (TestingStrategyHandlerBase): The handler for the testing strategy.
45 sell_stop_loss (float): The stop loss threshold for selling.
46 sell_take_profit (float): The take profit threshold for selling.
47 buy_stop_loss (float): The stop loss threshold for buying.
48 buy_take_profit (float): The take profit threshold for buying.
49 penalty_starts (int): The step at which to start applying penalties.
50 penalty_stops (int): The step at which to stop applying penalties.
51 static_reward_adjustment (float): The static adjustment factor for rewards.
52 repeat_test (int): The number of times to repeat testing.
53 test_ratio (float): The ratio of data to use for testing.
54 validator (Optional[RewardValidatorBase]): The reward validator to use. Defaults to PriceRewardValidator.
55 label_annotator (Optional[LabelAnnotatorBase]): The label annotator to use. Defaults to SimpleLabelAnnotator.
56 labeled_data_balancer (Optional[LabeledDataBalancer]): The labeled data balancer to use. Defaults to None.
57 meta_data (Optional[dict[str, Any]]): Optional metadata for the training configuration.
58 trading_mode (Optional[TradingEnvironment.TradingMode]): The trading mode to use. Defaults to None.
59 """

        if validator is None:
            validator = PriceRewardValidator()

        if label_annotator is None:
            label_annotator = SimpleLabelAnnotator()

        # Training config
        self.nr_of_steps: int = nr_of_steps
        self.nr_of_episodes: int = nr_of_episodes
        self.repeat_test: int = repeat_test

        # Environment config
        self.__data: pd.DataFrame = data
        self.__meta_data: Optional[dict[str, Any]] = meta_data
        self.__trading_mode: Optional[TradingEnvironment.TradingMode] = trading_mode
        self.__test_ratio: float = test_ratio
        self.__initial_budget: float = initial_budget
        self.__max_amount_of_trades: int = max_amount_of_trades
        self.__window_size: int = window_size
        self.__sell_stop_loss: float = sell_stop_loss
        self.__sell_take_profit: float = sell_take_profit
        self.__buy_stop_loss: float = buy_stop_loss
        self.__buy_take_profit: float = buy_take_profit
        self.__penalty_starts: int = penalty_starts
        self.__penalty_stops: int = penalty_stops
        self.__static_reward_adjustment: float = static_reward_adjustment
        self.__validator: RewardValidatorBase = validator
        self.__label_annotator: LabelAnnotatorBase = label_annotator
        self.__labeled_data_balancer: Optional[LabeledDataBalancer] = labeled_data_balancer

        # Agent config
        self.__model_blue_print: BluePrintBase = model_blue_print
        self.__learning_strategy_handler: LearningStrategyHandlerBase = learning_strategy_handler
        self.__testing_strategy_handlers: list[TestingStrategyHandlerBase] = testing_strategy_handlers

    def __str__(self) -> str:
        """
        Returns a string representation of the configuration.

        Creates a formatted multi-line string containing all configuration
        parameters and their values for easy logging.

        Returns:
            str: Formatted string representation of the configuration.
        """

        if self.__labeled_data_balancer is not None:
            labeled_data_balancer_info = \
                f"\tlabeled_data_balancer: {self.__labeled_data_balancer.__class__.__name__}\n" \
                f"\t\t{vars(self.__labeled_data_balancer)}\n"
        else:
            labeled_data_balancer_info = "\tlabeled_data_balancer: None\n"
115 return f"Training config:\n" \
116 f"\tnr_of_steps: {self.nr_of_steps}\n" \
117 f"\tnr_of_episodes: {self.nr_of_episodes}\n" \
118 f"\trepeat_test: {self.repeat_test}\n" \
119 f"\ttest_ratio: {self.__test_ratio}\n" \
120 f"\ttrading_mode: {self.__trading_mode}\n" \
121 f"\tinitial_budget: {self.__initial_budget}\n" \
122 f"\tmax_amount_of_trades: {self.__max_amount_of_trades}\n" \
123 f"\twindow_size: {self.__window_size}\n" \
124 f"\tsell_stop_loss: {self.__sell_stop_loss}\n" \
125 f"\tsell_take_profit: {self.__sell_take_profit}\n" \
126 f"\tbuy_stop_loss: {self.__buy_stop_loss}\n" \
127 f"\tbuy_take_profit: {self.__buy_take_profit}\n" \
128 f"\tpenalty_starts: {self.__penalty_starts}\n" \
129 f"\tpenalty_stops: {self.__penalty_stops}\n" \
130 f"\tstatic_reward_adjustment: {self.__static_reward_adjustment}\n" \
131 f"\tvalidator: {self.__validator.__class__.__name__}\n" \
132 f"\t\t{vars(self.__validator)}\n" \
133 f"\tlabel_annotator: {self.__label_annotator.__class__.__name__}\n" \
134 f"\t\t{vars(self.__label_annotator)}\n" \
135 f"{labeled_data_balancer_info}" \
136 f"\tmodel_blue_print: {self.__model_blue_print.__class__.__name__}\n" \
137 f"\t\t{vars(self.__model_blue_print)}\n" \
138 f"\tlearning_strategy_handler: {self.__learning_strategy_handler.__class__.__name__}\n" \
139 f"\t\t{vars(self.__learning_strategy_handler)}\n" \
140 f"\ttesting_strategy_handlers: {[handler.__class__.__name__ for handler in self.__testing_strategy_handlers]}\n" \
141 f"\t\t{[vars(handler) for handler in self.__testing_strategy_handlers]}\n"

    def instantiate_agent_handler(self) -> AgentHandler:
        """
        Instantiates the agent handler with the configured environment and strategies.

        Returns:
            (AgentHandler): An instance of the agent handler configured with the model blueprint,
                trading environment, learning strategy handler, and testing strategy handlers.
        """

        environment = TradingEnvironment(self.__data, self.__initial_budget, self.__max_amount_of_trades,
                                         self.__window_size, self.__validator, self.__label_annotator,
                                         self.__sell_stop_loss, self.__sell_take_profit, self.__buy_stop_loss,
                                         self.__buy_take_profit, self.__test_ratio, self.__penalty_starts,
                                         self.__penalty_stops, self.__static_reward_adjustment,
                                         self.__labeled_data_balancer, self.__meta_data,
                                         self.__trading_mode)

        return AgentHandler(self.__model_blue_print, environment, self.__learning_strategy_handler,
                            self.__testing_strategy_handlers)
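
# Illustrative follow-up (sketch, not part of the module): given a ``config`` built as in
# the class docstring example, the intended flow is to log the configuration and then build
# the agent handler, which constructs the TradingEnvironment internally.
#
#     print(config)
#     agent_handler = config.instantiate_agent_handler()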