Coverage for source/agent/agent_handler.py: 40%

48 statements  

« prev     ^ index     » next       coverage.py v7.8.0, created at 2025-06-06 12:00 +0000

1# agent/agent_handler.py 

2 

3# global imports 

4import io 

5import logging 

6import random 

7from contextlib import redirect_stdout 

8from tensorflow.keras.callbacks import Callback 

9from typing import Any, Callable, Optional 

10 

11# local imports 

12from source.agent import AgentBase, LearningStrategyHandlerBase, TestingStrategyHandlerBase 

13from source.environment import TradingEnvironment 

14from source.model import BluePrintBase 

15 

class AgentHandler:
    """Create a trading agent and drive its training and testing.

    The agent is built by the learning strategy handler. Training is delegated to
    that handler's ``fit`` and testing to the testing strategy handler's
    ``evaluate``. Before each phase the trading environment is switched into the
    matching mode (``TRAIN_MODE`` / ``TEST_MODE``).
    """

    def __init__(self, model_blue_print: BluePrintBase,
                 trading_environment: TradingEnvironment,
                 learning_strategy_handler: LearningStrategyHandlerBase,
                 testing_strategy_handler: TestingStrategyHandlerBase) -> None:
        """Store the collaborators and create the agent.

        Args:
            model_blue_print: Forwarded, together with ``trading_environment``, to
                ``learning_strategy_handler.create_agent`` to build the agent.
            trading_environment: Environment used for both training and testing.
            learning_strategy_handler: Creates the agent and runs training (``fit``).
            testing_strategy_handler: Runs evaluation of the agent (``evaluate``).
        """

        # Set to True only after a training run completes; gates test_agent().
        self.__trained: bool = False
        self.__learning_strategy_handler: LearningStrategyHandlerBase = learning_strategy_handler
        self.__testing_strategy_handler: TestingStrategyHandlerBase = testing_strategy_handler
        self.__trading_environment: TradingEnvironment = trading_environment
        self.__agent: AgentBase = learning_strategy_handler.create_agent(model_blue_print,
                                                                         trading_environment)

    def train_agent(self, nr_of_steps: int, nr_of_episodes: int,
                    callbacks: Optional[list[Callback]] = None,
                    model_load_path: Optional[str] = None,
                    model_save_path: Optional[str] = None) -> tuple[list[str], list[dict]]:
        """Train the agent with the learning strategy handler.

        The environment is switched to ``TRAIN_MODE``. If a load path is given,
        the saved model is loaded before training. Everything ``fit`` prints to
        stdout is captured and re-emitted through ``logging.info``, one non-blank
        line at a time. This also happens when ``fit`` raises, so the output
        leading up to a failure is not lost.

        Args:
            nr_of_steps: Forwarded to ``fit``.
            nr_of_episodes: Forwarded to ``fit``.
            callbacks: Keras callbacks forwarded to ``fit``; ``None`` means an empty list.
            model_load_path: If not ``None``, passed to the agent's ``load_model``
                before training.
            model_save_path: If not ``None``, passed to the agent's ``save_model``
                after training succeeds.

        Returns:
            The ``(key, report_data)`` pair returned by the learning strategy
            handler's ``fit``.

        Raises:
            Whatever ``load_model``, ``fit`` or ``save_model`` raise. If ``fit``
            raises, the agent is not marked as trained.
        """

        if callbacks is None:
            callbacks = []

        self.__trading_environment.set_mode(TradingEnvironment.TRAIN_MODE)

        if model_load_path is not None:
            self.__agent.load_model(model_load_path)

        captured_output = io.StringIO()
        try:
            # TODO: Replace stdout capturing with a dedicated logging callback.
            with redirect_stdout(captured_output):
                key, report_data = self.__learning_strategy_handler.fit(
                    self.__agent, self.__trading_environment,
                    nr_of_steps, nr_of_episodes, callbacks)
        finally:
            # Runs after stdout has been restored. Flushing here (rather than
            # after the try) keeps the captured output even when fit() raises.
            for line in captured_output.getvalue().split('\n'):
                if line.strip():
                    logging.info(line)

        self.__trained = True

        if model_save_path is not None:
            self.__agent.save_model(model_save_path)

        return key, report_data

    def test_agent(self, repeat: int = 1) -> tuple[dict[int, list[str]], dict[int, list[dict[str, Any]]]]:
        """Evaluate the trained agent ``repeat`` times from random start positions.

        The environment is switched to ``TEST_MODE``. For each repetition it is
        reset to a random start index drawn uniformly from
        ``[WINDOW_SIZE, int(environment_length / 2)]`` (both ends inclusive). The
        testing strategy handler's ``evaluate`` is then run.

        Args:
            repeat: Number of evaluation runs.

        Returns:
            Two dicts keyed by repetition index (``0 .. repeat - 1``), holding the
            first and second values returned by ``evaluate``. If the agent has not
            been trained yet, an error is logged and ``({}, {})`` is returned.

        Raises:
            ValueError: If half the environment length is smaller than
                ``WINDOW_SIZE``, so no valid start index exists.
        """

        if not self.__trained:
            logging.error('Agent is not trained yet! Train the agent before testing.')
            return {}, {}

        self.__trading_environment.set_mode(TradingEnvironment.TEST_MODE)

        report_data = {}
        key = {}
        for i in range(repeat):
            env_length = self.__trading_environment.get_environment_length()
            window_size = self.__trading_environment.get_trading_consts().WINDOW_SIZE
            last_start = int(env_length / 2)
            if last_start < window_size:
                # random.randint would fail here with an opaque "empty range" error.
                raise ValueError(
                    f"Environment too short for testing: half of its length ({last_start}) "
                    f"is smaller than WINDOW_SIZE ({window_size}).")
            current_iteration = random.randint(window_size, last_start)
            self.__trading_environment.reset(current_iteration)
            key[i], report_data[i] = self.__testing_strategy_handler.evaluate(self.__agent,
                                                                              self.__trading_environment)

        return key, report_data

    def print_model_summary(self, print_function: Optional[Callable] = print) -> None:
        """Print a summary of the agent's model.

        Args:
            print_function: Forwarded to the agent's ``print_summary`` as
                ``print_function``; defaults to the builtin ``print``.
        """

        self.__agent.print_summary(print_function=print_function)