Coverage for source/agent/agent_handler.py: 95%

43 statements  

« prev     ^ index     » next       coverage.py v7.8.0, created at 2025-08-01 20:51 +0000

1# agent/agent_handler.py

2 

3# global imports 

4import logging 

5import random 

6from tensorflow.keras.callbacks import Callback 

7from typing import Any, Callable, Optional 

8 

9# local imports 

10from source.agent import AgentBase, LearningStrategyHandlerBase, TestingStrategyHandlerBase 

11from source.environment import TradingEnvironment 

12from source.model import BluePrintBase 

13from source.utils import redirect_stdout_to_logging 

14 

class AgentHandler():
    """
    Implements agent handler that is responsible for training and testing
    the agent in the given trading environment using the specified learning
    and testing strategies. It is used as a wrapper around the agent
    to provide a convenient interface for TrainingHandler.
    """

    def __init__(self, model_blue_print: BluePrintBase,
                 trading_environment: TradingEnvironment,
                 learning_strategy_handler: LearningStrategyHandlerBase,
                 testing_strategy_handler: TestingStrategyHandlerBase) -> None:
        """
        Class constructor. Initializes the agent handler with the given model blueprint,
        trading environment, learning strategy handler, and testing strategy handler.
        The agent itself is created by the learning strategy handler.

        Parameters:
            model_blue_print (BluePrintBase): The model blueprint to be used for the agent.
            trading_environment (TradingEnvironment): The trading environment in which the agent will operate.
            learning_strategy_handler (LearningStrategyHandlerBase): The learning strategy handler to be used for training.
            testing_strategy_handler (TestingStrategyHandlerBase): The testing strategy handler to be used for evaluation.
        """

        # Flipped to True only after a successful fit; test_agent refuses to run before that.
        self.__trained: bool = False
        self.__learning_strategy_handler: LearningStrategyHandlerBase = learning_strategy_handler
        self.__testing_strategy_handler: TestingStrategyHandlerBase = testing_strategy_handler
        self.__trading_environment: TradingEnvironment = trading_environment
        self.__agent: AgentBase = learning_strategy_handler.create_agent(model_blue_print, trading_environment)

    def train_agent(self, nr_of_steps: int, nr_of_episodes: int, callbacks: Optional[list[Callback]] = None,
                    model_load_path: Optional[str] = None,
                    model_save_path: Optional[str] = None) -> tuple[list[str], list[dict]]:
        """
        Trains the agent using the specified number of steps and episodes.
        The environment is switched to train mode first. The agent is marked
        as trained only after the learning strategy handler's fit call returns.

        Parameters:
            nr_of_steps (int): The number of steps to train the agent.
            nr_of_episodes (int): The number of episodes to train the agent.
            callbacks (Optional[list[Callback]]): A list of callbacks to be used during training.
                None is treated as an empty list.
            model_load_path (Optional[str]): Path to load the pre-trained model from.
                Loading happens before training; it does not by itself mark the agent as trained.
            model_save_path (Optional[str]): Path to save the trained model to.

        Returns:
            (tuple[list[str], list[dict]]): A tuple containing the keys and report data from
                the training process.
        """

        if callbacks is None:
            callbacks = []

        self.__trading_environment.set_mode(TradingEnvironment.TRAIN_MODE)

        if model_load_path is not None:
            self.__agent.load_model(model_load_path)

        # Anything written to stdout during fitting is routed to logging instead.
        with redirect_stdout_to_logging():
            keys, report_data = self.__learning_strategy_handler.fit(
                self.__agent,
                self.__trading_environment,
                nr_of_steps,
                nr_of_episodes,
                callbacks
            )

        self.__trained = True
        if model_save_path is not None:
            self.__agent.save_model(model_save_path)

        return keys, report_data

    @staticmethod
    def _draw_start_index(window_size: int, env_length: int) -> int:
        """
        Draws a random starting iteration for a single evaluation run.

        The index is drawn uniformly (both ends inclusive) from
        [window_size, int(env_length / 2)]. When that range is empty, i.e. half
        of the environment length is smaller than the window size, the window
        size itself is returned and a warning is logged, instead of letting
        random.randint fail with an "empty range" ValueError.

        Parameters:
            window_size (int): The lower bound of the starting iteration.
            env_length (int): The length of the environment.

        Returns:
            (int): The iteration to start the evaluation from.
        """

        upper_bound = int(env_length / 2)
        if upper_bound < window_size:
            logging.warning('Half of the environment length (%s) is smaller than the window size (%s); '
                            'starting evaluation at the window size.', upper_bound, window_size)
            return window_size

        return random.randint(window_size, upper_bound)

    def test_agent(self, repeat: int = 1) -> tuple[dict[int, list[str]], dict[int, list[dict[str, Any]]]]:
        """
        Tests the agent using the specified number of repetitions. Each repetition
        resets the environment (in test mode) to a randomly drawn starting iteration
        and evaluates the agent with the testing strategy handler.

        Parameters:
            repeat (int): The number of times to repeat the testing process.

        Returns:
            (tuple[dict[int, list[str]], dict[int, list[dict[str, Any]]]]): A tuple containing the keys and
                report data from the testing process, both indexed by repetition number.
                Two empty dictionaries are returned (and an error is logged) when the
                agent has not been trained yet.
        """

        if not self.__trained:
            logging.error('Agent is not trained yet! Train the agent before testing.')
            return {}, {}

        self.__trading_environment.set_mode(TradingEnvironment.TEST_MODE)

        report_data = {}
        keys = {}
        for i in range(repeat):
            env_length = self.__trading_environment.get_environment_length()
            window_size = self.__trading_environment.get_trading_consts().WINDOW_SIZE
            current_iteration = self._draw_start_index(window_size, env_length)
            self.__trading_environment.reset(current_iteration)
            keys[i], report_data[i] = self.__testing_strategy_handler.evaluate(self.__agent,
                                                                               self.__trading_environment)

        return keys, report_data

    def print_model_summary(self, print_function: Optional[Callable] = print) -> None:
        """
        Prints a summary of the model architecture and parameters by delegating
        to the agent's print_summary.

        Parameters:
            print_function (Optional[Callable]): A function to print the summary. Defaults to print.
        """

        self.__agent.print_summary(print_function=print_function)