Coverage for source/agent/agent_handler.py: 75%

51 statements  

« prev     ^ index     » next       coverage.py v7.8.0, created at 2025-08-04 20:03 +0000

1# agent/agent_handler.py 

2 

3# global imports 

4import logging 

5import numpy as np 

6from tensorflow.keras.callbacks import Callback 

7from typing import Any, Callable, Optional 

8 

9# local imports 

10from source.agent import AgentBase, LearningStrategyHandlerBase, TestingStrategyHandlerBase 

11from source.environment import TradingEnvironment 

12from source.model import BluePrintBase 

13from source.utils import redirect_stdout_to_logging 

14 

class AgentHandler:
    """
    Implements agent handler that is responsible for training and testing
    the agent in the given trading environment using the specified learning
    and testing strategies. It is used as a wrapper around the agent
    to provide a convenient interface for TrainingHandler.
    """

    def __init__(self, model_blue_print: BluePrintBase,
                 trading_environment: TradingEnvironment,
                 learning_strategy_handler: LearningStrategyHandlerBase,
                 testing_strategy_handlers: list[TestingStrategyHandlerBase]) -> None:
        """
        Class constructor. Stores the given collaborators and asks the learning
        strategy handler to create the agent for the given model blueprint and
        trading environment.

        Parameters:
            model_blue_print (BluePrintBase): The model blueprint to be used for the agent.
            trading_environment (TradingEnvironment): The trading environment in which the agent will operate.
            learning_strategy_handler (LearningStrategyHandlerBase): The learning strategy handler to be used
                for training. It is also the factory of the agent.
            testing_strategy_handlers (list[TestingStrategyHandlerBase]): The testing strategy handlers to be
                used for evaluation.
        """

        # Flipped to True only after a training run completed, see train_agent.
        self.__trained: bool = False
        self.__learning_strategy_handler: LearningStrategyHandlerBase = learning_strategy_handler
        self.__testing_strategy_handlers: list[TestingStrategyHandlerBase] = testing_strategy_handlers
        self.__trading_environment: TradingEnvironment = trading_environment
        self.__agent: AgentBase = learning_strategy_handler.create_agent(model_blue_print, trading_environment)

    def train_agent(self, nr_of_steps: int, nr_of_episodes: int, callbacks: Optional[list[Callback]] = None,
                    model_load_path: Optional[str] = None,
                    model_save_path: Optional[str] = None) -> tuple[list[str], list[dict]]:
        """
        Trains the agent using the specified number of steps and episodes.

        The trading environment is switched to the training mode, an optional
        pre-trained model is loaded, the learning strategy handler fits the agent
        and finally the model is optionally saved.

        Parameters:
            nr_of_steps (int): The number of steps to train the agent.
            nr_of_episodes (int): The number of episodes to train the agent.
            callbacks (Optional[list[Callback]]): A list of callbacks to be used during training.
                None is treated as an empty list.
            model_load_path (Optional[str]): Path to load the pre-trained model from.
            model_save_path (Optional[str]): Path to save the trained model to.

        Returns:
            (tuple[list[str], list[dict]]): A tuple containing the keys and report data from
                the training process.
        """

        if callbacks is None:
            callbacks = []

        self.__trading_environment.set_mode(TradingEnvironment.TRAIN_MODE)

        if model_load_path is not None:
            self.__agent.load_model(model_load_path)

        # The fit call runs inside the project's stdout-to-logging redirection context.
        with redirect_stdout_to_logging():
            keys, report_data = self.__learning_strategy_handler.fit(
                self.__agent,
                self.__trading_environment,
                nr_of_steps,
                nr_of_episodes,
                callbacks
            )

        # Reached only when fit returned without raising.
        self.__trained = True
        if model_save_path is not None:
            self.__agent.save_model(model_save_path)

        return keys, report_data

    @staticmethod
    def _evaluation_ranges(window_size: int, max_env_length: int, repeat: int) -> list[tuple[int, int]]:
        """
        Builds the (start, end) environment index ranges used for testing.

        Parameters:
            window_size (int): The first index of the evaluated environment range.
            max_env_length (int): The last index of the evaluated environment range.
            repeat (int): The requested number of repetitions.

        Returns:
            (list[tuple[int, int]]): The full range at index 0. When repeat is above 1 it is
                followed by 'repeat' consecutive ranges that split the full range.
                Any repeat value of 1 or less yields the full range only.
        """

        ranges = [(window_size, max_env_length)]
        if repeat > 1:
            # 'repeat' splits need 'repeat + 1' boundaries. Converting to a list
            # gives plain Python ints, the same type the full range is built from.
            boundaries = np.linspace(window_size, max_env_length, repeat + 1, dtype = int).tolist()
            ranges.extend(zip(boundaries[:-1], boundaries[1:]))

        return ranges

    def test_agent(self, repeat: int = 1) -> tuple[dict[int, list[str]], dict[int, list[dict[str, Any]]]]:
        """
        Tests the agent using the specified number of repetitions.

        Every testing strategy handler evaluates the agent on the full environment
        range first. When repeat is above 1, the full range is additionally split
        into 'repeat' consecutive parts and each part is evaluated separately.

        Parameters:
            repeat (int): The number of times to repeat the testing process. Values of 1
                or less result in the full evaluation only.

        Returns:
            (tuple[dict[int, list[str]], dict[int, list[dict[str, Any]]]]): A tuple containing the keys and
                report data from the testing process. Index 0 holds the results of the full evaluation,
                indices 1..repeat hold the results of the consecutive splits. Two empty dictionaries
                are returned (and an error is logged) when the agent has not been trained yet.
        """

        if not self.__trained:
            logging.error('Agent is not trained yet! Train the agent before testing.')
            return {}, {}

        self.__trading_environment.set_mode(TradingEnvironment.TEST_MODE)

        window_size = self.__trading_environment.get_trading_consts().WINDOW_SIZE
        max_env_length = self.__trading_environment.get_environment_length() - 1
        env_ranges = self._evaluation_ranges(window_size, max_env_length, repeat)

        report_data = {x: [] for x in range(len(env_ranges))}
        keys = {x: [] for x in range(len(env_ranges))}
        for strategy_handler in self.__testing_strategy_handlers:
            # Index 0 is the full evaluation, the following ones are the splits.
            for i, env_range in enumerate(env_ranges):
                eval_key, eval_data = \
                    strategy_handler.evaluate(self.__agent, self.__trading_environment, env_range)
                keys[i] += eval_key
                report_data[i] += eval_data

        return keys, report_data

    def print_model_summary(self, print_function: Optional[Callable] = print) -> None:
        """
        Prints a summary of the model by delegating to the agent.

        Parameters:
            print_function (Optional[Callable]): A function to print the summary. Defaults to print.
        """

        self.__agent.print_summary(print_function = print_function)