Coverage for source/agent/strategies/learning_strategy_handler_base.py: 86%

36 statements  

« prev     ^ index     » next       coverage.py v7.8.0, created at 2025-08-01 20:51 +0000

1# agent/strategies/learning_strategy_handler_base.py 

2 

3# global imports 

4from abc import ABC, abstractmethod 

5from tensorflow.keras.callbacks import Callback 

6from typing import Any 

7 

8# local imports 

9from source.agent import AgentBase 

10from source.environment import TradingEnvironment 

11from source.model import BluePrintBase 

12 

class LearningStrategyHandlerBase(ABC):
    """
    Implements a base class for learning strategy handlers. It provides an interface
    for creating agents, fitting them to the environment, and providing parameters
    needed for the model blueprint instantiation.

    Concrete subclasses implement `create_agent`, `fit` and the three `_provide_*`
    hooks; `_provide_required_parameter` dispatches parameter names to those hooks.
    """

    # global class constants
    # Key under which `fit` reports the price/volatility data it collects for plotting.
    PLOTTING_KEY: str = 'asset_price_movement_summary'

    def __init__(self) -> None:
        """
        Class constructor. Initializes the parameter provision callbacks.

        Builds the table that maps each supported parameter name to the hook computing
        its value from a trading environment. The hooks are bound methods, so overrides
        defined in subclasses are the ones that get called.
        """

        # Double-underscore name is mangled to this class, so subclasses cannot clash with it.
        self.__parameter_provision_callbacks = {
            'input_shape': self._provide_input_shape,
            'output_length': self._provide_output_length,
            'spatial_data_shape': self._provide_spatial_data_shape,
        }

    @abstractmethod
    def create_agent(self, model_blue_print: BluePrintBase,
                     trading_environment: TradingEnvironment) -> AgentBase:
        """
        Creates an agent based on the provided model blueprint and trading environment.

        Parameters:
            model_blue_print (BluePrintBase): The model blueprint to be used for agent creation.
            trading_environment (TradingEnvironment): The trading environment in which the agent will operate.

        Returns:
            (AgentBase): An instance of the agent created using the model blueprint and trading environment.
        """

        pass

    @abstractmethod
    def fit(self, agent: AgentBase, environment: TradingEnvironment, nr_of_steps: int, nr_of_episodes: int,
            callbacks: list[Callback]) -> tuple[list[str], list[dict[str, Any]]]:
        """
        Fits the agent to the environment using the specified number of steps and episodes. It also collects data
        for plotting summary statistics.

        Although declared abstract, this method has a body. Concrete subclasses can call it via
        ``super().fit(...)`` to obtain the plotting data. The body itself performs no training:
        `agent`, `nr_of_steps`, `nr_of_episodes` and `callbacks` are not used here.

        Side effects:
            Switches `environment` to TEST_MODE and then to TRAIN_MODE, so the environment is left
            in TRAIN_MODE when this method returns.

        Parameters:
            agent (AgentBase): The agent to be fitted to the environment.
            environment (TradingEnvironment): The trading environment in which the agent will be trained.
            nr_of_steps (int): The number of steps to be taken during training.
            nr_of_episodes (int): The number of episodes to be run during training.
            callbacks (list[Callback]): A list of Keras callbacks to be used during training.

        Returns:
            (tuple[list[str], list[dict[str, Any]]]): A single-element list holding `PLOTTING_KEY`,
                and a single-element list holding a dict with the keys 'test_part_price_movement',
                'test_part_volatility', 'train_part_price_movement' and 'train_part_volatility'.
                Each value is whatever `environment.get_data_for_iteration` returns for the 'close'
                or 'volatility' column in the respective mode.
        """

        data: dict[str, Any] = {}

        # Read the test split first.
        environment.set_mode(TradingEnvironment.TEST_MODE)
        data['test_part_price_movement'] = environment.get_data_for_iteration(['close'])
        data['test_part_volatility'] = environment.get_data_for_iteration(['volatility'])

        # The train split is read last, which leaves the environment in TRAIN_MODE.
        # NOTE(review): presumably intentional so that training proceeds on the train split - confirm.
        environment.set_mode(TradingEnvironment.TRAIN_MODE)
        data['train_part_price_movement'] = environment.get_data_for_iteration(['close'])
        data['train_part_volatility'] = environment.get_data_for_iteration(['volatility'])

        return [LearningStrategyHandlerBase.PLOTTING_KEY], [data]

    def _provide_required_parameter(self, parameter: str, environment: TradingEnvironment) -> Any:
        """
        Provides the required parameter for the given environment.

        Looks up the provision hook registered for `parameter` in the constructor and calls it
        with `environment`.

        Parameters:
            parameter (str): The name of the parameter to be provided. Supported names are
                'input_shape', 'output_length' and 'spatial_data_shape'.
            environment (TradingEnvironment): The trading environment from which the parameter is to be provided.

        Raises:
            ValueError: If the parameter is not supported by this environment configuration.

        Returns:
            (Any): The value of the requested parameter.
        """

        try:
            provider = self.__parameter_provision_callbacks[parameter]
        except KeyError:
            supported = ', '.join(self.__parameter_provision_callbacks)
            raise ValueError(f"Parameter '{parameter}' is not supported to be provided by this environment "
                             f"configuration. Supported parameters: {supported}.") from None

        # Only the lookup is guarded above, so an error raised by the provider itself
        # propagates unchanged instead of being misreported as an unsupported parameter.
        return provider(environment)

    @abstractmethod
    def _provide_input_shape(self, environment: TradingEnvironment) -> tuple[int, int]:
        """
        Provides the input shape for the given environment.

        Registered under the parameter name 'input_shape'.

        Parameters:
            environment (TradingEnvironment): The trading environment for which the input shape is to be provided.

        Returns:
            (tuple[int, int]): The input shape as a tuple.
        """

        pass

    @abstractmethod
    def _provide_output_length(self, environment: TradingEnvironment) -> int:
        """
        Provides the output length for the given environment.

        Registered under the parameter name 'output_length'.

        Parameters:
            environment (TradingEnvironment): The trading environment for which the output length is to be provided.

        Returns:
            (int): The output length.
        """

        pass

    @abstractmethod
    def _provide_spatial_data_shape(self, environment: TradingEnvironment) -> tuple[int, int]:
        """
        Provides the spatial data shape for the given environment.

        Registered under the parameter name 'spatial_data_shape'.

        Parameters:
            environment (TradingEnvironment): The trading environment for which the spatial data shape is
                to be provided.

        Returns:
            (tuple[int, int]): The spatial data shape as a tuple.
        """

        pass