Coverage for source/training/training_config.py: 23%

35 statements  

« prev     ^ index     » next       coverage.py v7.8.0, created at 2025-06-06 12:00 +0000

1# training/training_config.py 

2 

3# global imports 

4from typing import Optional 

5 

6# local imports 

7from source.agent import AgentHandler, LearningStrategyHandlerBase, TestingStrategyHandlerBase 

8from source.environment import LabelAnnotatorBase, RewardValidatorBase, SimpleLabelAnnotator, \ 

9 PriceRewardValidator, TradingEnvironment 

10from source.model import BluePrintBase 

11 

class TrainingConfig:
    """
    Bundles the settings needed to build a trading environment and an agent.

    Training-loop settings (``nr_of_steps``, ``nr_of_episodes``, ``repeat_test``)
    are exposed as public attributes. Environment and agent settings are kept
    in name-mangled private attributes and are only consumed by
    ``instantiate_agent_handler`` and ``__str__``.
    """

    def __init__(self, nr_of_steps: int, nr_of_episodes: int, model_blue_print: BluePrintBase,
                 data_path: str, initial_budget: float, max_amount_of_trades: int, window_size: int,
                 learning_strategy_handler: LearningStrategyHandlerBase,
                 testing_strategy_handler: TestingStrategyHandlerBase, sell_stop_loss: float = 0.8,
                 sell_take_profit: float = 1.2, buy_stop_loss: float = 0.8, buy_take_profit: float = 1.2,
                 penalty_starts: int = 0, penalty_stops: int = 10, static_reward_adjustment: float = 1,
                 repeat_test: int = 10, test_ratio: float = 0.2, validator: Optional[RewardValidatorBase] = None,
                 label_annotator: Optional[LabelAnnotatorBase] = None) -> None:
        """
        Stores the configuration. No environment or agent is created here.

        Args:
            nr_of_steps (int): Stored as a public attribute. Not used by this class itself.
            nr_of_episodes (int): Stored as a public attribute. Not used by this class itself.
            model_blue_print (BluePrintBase): Blueprint passed to ``AgentHandler``.
            data_path (str): Data source path passed to ``TradingEnvironment``.
            initial_budget (float): Passed to ``TradingEnvironment``.
            max_amount_of_trades (int): Passed to ``TradingEnvironment``.
            window_size (int): Passed to ``TradingEnvironment``.
            learning_strategy_handler (LearningStrategyHandlerBase): Passed to ``AgentHandler``.
            testing_strategy_handler (TestingStrategyHandlerBase): Passed to ``AgentHandler``.
            sell_stop_loss (float): Passed to ``TradingEnvironment``. Defaults to 0.8.
            sell_take_profit (float): Passed to ``TradingEnvironment``. Defaults to 1.2.
            buy_stop_loss (float): Passed to ``TradingEnvironment``. Defaults to 0.8.
            buy_take_profit (float): Passed to ``TradingEnvironment``. Defaults to 1.2.
            penalty_starts (int): Passed to ``TradingEnvironment``. Defaults to 0.
            penalty_stops (int): Passed to ``TradingEnvironment``. Defaults to 10.
            static_reward_adjustment (float): Passed to ``TradingEnvironment``. Defaults to 1.
            repeat_test (int): Stored as a public attribute. Not used by this class itself.
                Defaults to 10.
            test_ratio (float): Passed to ``TradingEnvironment``. Defaults to 0.2.
            validator (Optional[RewardValidatorBase]): Reward validator for the environment.
                A new ``PriceRewardValidator`` is created when omitted.
            label_annotator (Optional[LabelAnnotatorBase]): Label annotator for the environment.
                A new ``SimpleLabelAnnotator`` is created when omitted.
        """

        # None sentinels give each config its own default instance instead of
        # one default object shared by every TrainingConfig.
        if validator is None:
            validator = PriceRewardValidator()

        if label_annotator is None:
            label_annotator = SimpleLabelAnnotator()

        # Training config
        self.nr_of_steps: int = nr_of_steps
        self.nr_of_episodes: int = nr_of_episodes
        self.repeat_test: int = repeat_test

        # Environment config
        self.__data_path: str = data_path
        self.__test_ratio: float = test_ratio
        self.__initial_budget: float = initial_budget
        self.__max_amount_of_trades: int = max_amount_of_trades
        self.__window_size: int = window_size
        self.__sell_stop_loss: float = sell_stop_loss
        self.__sell_take_profit: float = sell_take_profit
        self.__buy_stop_loss: float = buy_stop_loss
        self.__buy_take_profit: float = buy_take_profit
        self.__penalty_starts: int = penalty_starts
        self.__penalty_stops: int = penalty_stops
        self.__static_reward_adjustment: float = static_reward_adjustment
        self.__validator: RewardValidatorBase = validator
        self.__label_annotator: LabelAnnotatorBase = label_annotator

        # Agent config
        self.__model_blue_print: BluePrintBase = model_blue_print
        self.__learning_strategy_handler: LearningStrategyHandlerBase = learning_strategy_handler
        self.__testing_strategy_handler: TestingStrategyHandlerBase = testing_strategy_handler

    @staticmethod
    def _describe_attributes(obj: object) -> str:
        """
        Returns ``str(vars(obj))``, or ``repr(obj)`` when ``obj`` has no ``__dict__``.

        ``vars`` raises ``TypeError`` for objects without a ``__dict__`` (for
        example classes using ``__slots__``). Falling back to ``repr`` keeps
        ``__str__`` from raising while a configuration is being logged.

        Args:
            obj (object): Object whose attributes should be rendered.

        Returns:
            str: Attribute dictionary as a string, or the object's repr.
        """

        try:
            return str(vars(obj))
        except TypeError:
            return repr(obj)

    def __str__(self) -> str:
        """
        Returns a string representation of the configuration.

        Creates a formatted multi-line string containing all configuration
        parameters and their values for easy logging. Each collaborator
        (validator, label annotator, blueprint, strategy handlers) is shown by
        its class name, followed by a line describing its contents.

        Returns:
            str: Formatted string representation of the configuration, one
                tab-indented ``name: value`` entry per line, ending in a newline.
        """

        describe = self._describe_attributes
        lines = [
            "Training config:",
            f"\tnr_of_steps: {self.nr_of_steps}",
            f"\tnr_of_episodes: {self.nr_of_episodes}",
            f"\trepeat_test: {self.repeat_test}",
            f"\tdata_path: {self.__data_path}",
            f"\ttest_ratio: {self.__test_ratio}",
            f"\tinitial_budget: {self.__initial_budget}",
            f"\tmax_amount_of_trades: {self.__max_amount_of_trades}",
            f"\twindow_size: {self.__window_size}",
            f"\tsell_stop_loss: {self.__sell_stop_loss}",
            f"\tsell_take_profit: {self.__sell_take_profit}",
            f"\tbuy_stop_loss: {self.__buy_stop_loss}",
            f"\tbuy_take_profit: {self.__buy_take_profit}",
            f"\tpenalty_starts: {self.__penalty_starts}",
            f"\tpenalty_stops: {self.__penalty_stops}",
            f"\tstatic_reward_adjustment: {self.__static_reward_adjustment}",
            f"\tvalidator: {self.__validator.__class__.__name__}",
            f"\t\t{describe(self.__validator)}",
            f"\tlabel_annotator: {self.__label_annotator.__class__.__name__}",
            f"\t\t{describe(self.__label_annotator)}",
            f"\tmodel_blue_print: {self.__model_blue_print.__class__.__name__}",
            f"\t\t{describe(self.__model_blue_print)}",
            f"\tlearning_strategy_handler: {self.__learning_strategy_handler.__class__.__name__}",
            f"\t\t{describe(self.__learning_strategy_handler)}",
            f"\ttesting_strategy_handler: {self.__testing_strategy_handler.__class__.__name__}",
            # NOTE(review): unlike the other collaborators this line uses str()
            # rather than the attribute dict; presumably the testing handler
            # provides its own __str__ — confirm.
            f"\t\t{self.__testing_strategy_handler}",
        ]
        return "\n".join(lines) + "\n"

    def instantiate_agent_handler(self) -> AgentHandler:
        """
        Builds a new trading environment and wraps it in an agent handler.

        Every call constructs a fresh ``TradingEnvironment``. The validator,
        label annotator, model blueprint and strategy handler objects are the
        same instances on every call, so they are shared between handlers
        created from this config.

        Returns:
            AgentHandler: Handler built from the model blueprint, the new
                environment and the learning/testing strategy handlers.
        """

        # Arguments are passed positionally, so this order must match
        # TradingEnvironment's constructor signature (defined elsewhere).
        environment = TradingEnvironment(self.__data_path, self.__initial_budget, self.__max_amount_of_trades,
                                         self.__window_size, self.__validator, self.__label_annotator,
                                         self.__sell_stop_loss, self.__sell_take_profit, self.__buy_stop_loss,
                                         self.__buy_take_profit, self.__test_ratio, self.__penalty_starts,
                                         self.__penalty_stops, self.__static_reward_adjustment)

        return AgentHandler(self.__model_blue_print, environment, self.__learning_strategy_handler,
                            self.__testing_strategy_handler)