Coverage for source/training/training_config.py: 23%
35 statements
# training/training_config.py

# global imports
from typing import Optional

# local imports
from source.agent import AgentHandler, LearningStrategyHandlerBase, TestingStrategyHandlerBase
from source.environment import LabelAnnotatorBase, RewardValidatorBase, SimpleLabelAnnotator, \
    PriceRewardValidator, TradingEnvironment
from source.model import BluePrintBase

class TrainingConfig:
    """Bundles the training, environment, and agent parameters of a run and
    builds the corresponding AgentHandler on demand."""

    def __init__(self, nr_of_steps: int, nr_of_episodes: int, model_blue_print: BluePrintBase,
                 data_path: str, initial_budget: float, max_amount_of_trades: int, window_size: int,
                 learning_strategy_handler: LearningStrategyHandlerBase,
                 testing_strategy_handler: TestingStrategyHandlerBase, sell_stop_loss: float = 0.8,
                 sell_take_profit: float = 1.2, buy_stop_loss: float = 0.8, buy_take_profit: float = 1.2,
                 penalty_starts: int = 0, penalty_stops: int = 10, static_reward_adjustment: float = 1.0,
                 repeat_test: int = 10, test_ratio: float = 0.2, validator: Optional[RewardValidatorBase] = None,
                 label_annotator: Optional[LabelAnnotatorBase] = None) -> None:
23 """"""
25 if validator is None:
26 validator = PriceRewardValidator()
28 if label_annotator is None:
29 label_annotator = SimpleLabelAnnotator()
        # Training config
        self.nr_of_steps: int = nr_of_steps
        self.nr_of_episodes: int = nr_of_episodes
        self.repeat_test: int = repeat_test

        # Environment config
        self.__data_path: str = data_path
        self.__test_ratio: float = test_ratio
        self.__initial_budget: float = initial_budget
        self.__max_amount_of_trades: int = max_amount_of_trades
        self.__window_size: int = window_size
        self.__sell_stop_loss: float = sell_stop_loss
        self.__sell_take_profit: float = sell_take_profit
        self.__buy_stop_loss: float = buy_stop_loss
        self.__buy_take_profit: float = buy_take_profit
        self.__penalty_starts: int = penalty_starts
        self.__penalty_stops: int = penalty_stops
        self.__static_reward_adjustment: float = static_reward_adjustment
        self.__validator: RewardValidatorBase = validator
        self.__label_annotator: LabelAnnotatorBase = label_annotator

        # Agent config
        self.__model_blue_print: BluePrintBase = model_blue_print
        self.__learning_strategy_handler: LearningStrategyHandlerBase = learning_strategy_handler
        self.__testing_strategy_handler: TestingStrategyHandlerBase = testing_strategy_handler
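
    # Illustrative usage (not part of the original file): MyBluePrint,
    # my_learning_handler, and my_testing_handler are hypothetical stand-ins
    # for concrete BluePrintBase, LearningStrategyHandlerBase, and
    # TestingStrategyHandlerBase implementations; all other arguments keep
    # their defaults.
    #
    #   config = TrainingConfig(nr_of_steps=1000, nr_of_episodes=50,
    #                           model_blue_print=MyBluePrint(),
    #                           data_path="data/prices.csv",
    #                           initial_budget=10000.0,
    #                           max_amount_of_trades=5, window_size=30,
    #                           learning_strategy_handler=my_learning_handler,
    #                           testing_strategy_handler=my_testing_handler)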

    def __str__(self) -> str:
        """
        Returns a string representation of the configuration.

        Creates a formatted multi-line string containing all configuration
        parameters and their values for easy logging.

        Returns:
            str: Formatted string representation of the configuration.
        """
68 return f"Training config:\n" \
69 f"\tnr_of_steps: {self.nr_of_steps}\n" \
70 f"\tnr_of_episodes: {self.nr_of_episodes}\n" \
71 f"\trepeat_test: {self.repeat_test}\n" \
72 f"\ttest_ratio: {self.__test_ratio}\n" \
73 f"\tinitial_budget: {self.__initial_budget}\n" \
74 f"\tmax_amount_of_trades: {self.__max_amount_of_trades}\n" \
75 f"\twindow_size: {self.__window_size}\n" \
76 f"\tsell_stop_loss: {self.__sell_stop_loss}\n" \
77 f"\tsell_take_profit: {self.__sell_take_profit}\n" \
78 f"\tbuy_stop_loss: {self.__buy_stop_loss}\n" \
79 f"\tbuy_take_profit: {self.__buy_take_profit}\n" \
80 f"\tpenalty_starts: {self.__penalty_starts}\n" \
81 f"\tpenalty_stops: {self.__penalty_stops}\n" \
82 f"\tstatic_reward_adjustment: {self.__static_reward_adjustment}\n" \
83 f"\tvalidator: {self.__validator.__class__.__name__}\n" \
84 f"\t\t{vars(self.__validator)}\n" \
85 f"\tmodel_blue_print: {self.__model_blue_print.__class__.__name__}\n" \
86 f"\t\t{vars(self.__model_blue_print)}\n" \
87 f"\tlearning_strategy_handler: {self.__learning_strategy_handler.__class__.__name__}\n" \
88 f"\t\t{vars(self.__learning_strategy_handler)}\n" \
89 f"\ttesting_strategy_handler: {self.__testing_strategy_handler.__class__.__name__}\n" \
90 f"\t\t{self.__testing_strategy_handler}\n"

    def instantiate_agent_handler(self) -> AgentHandler:
        """Builds a TradingEnvironment from the stored environment settings and
        wraps it, together with the model blueprint and the learning and testing
        strategy handlers, in a new AgentHandler."""
        environment = TradingEnvironment(self.__data_path, self.__initial_budget, self.__max_amount_of_trades,
                                         self.__window_size, self.__validator, self.__label_annotator,
                                         self.__sell_stop_loss, self.__sell_take_profit, self.__buy_stop_loss,
                                         self.__buy_take_profit, self.__test_ratio, self.__penalty_starts,
                                         self.__penalty_stops, self.__static_reward_adjustment)

        return AgentHandler(self.__model_blue_print, environment, self.__learning_strategy_handler,
                            self.__testing_strategy_handler)
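
# Illustrative usage (not part of the original file): with a config built as in
# the sketch above, the handler is created from the stored settings; each call
# constructs a fresh TradingEnvironment and AgentHandler.
#
#   agent_handler = config.instantiate_agent_handler()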