Coverage for source/agent/agents/classification_learning_agent.py: 91%

22 statements  

« prev     ^ index     » next       coverage.py v7.8.0, created at 2025-08-01 20:51 +0000

1# agent/agents/classification_learning_agent.py 

2 

3# global imports 

4import numpy as np 

5from tensorflow.keras.callbacks import Callback 

6from typing import Any 

7 

8# local imports 

9from source.agent import AgentBase, ClassificationTestable 

10 

class ClassificationLearningAgent(AgentBase, ClassificationTestable):
    """
    Classification learning agent that can be trained and tested in a trading
    environment.

    All model-related work (fitting, prediction, tag lookup) is delegated to the
    model adapter stored on the instance as ``self._model_adapter``.
    """

    def classification_fit(self, input_data: np.ndarray, output_data: np.ndarray,
                           validation_data: tuple[np.ndarray, np.ndarray], batch_size: int,
                           epochs: int, callbacks: list[Callback]) -> dict[str, Any]:
        """
        Fit the underlying model on classification data.

        The adapter decides which arguments it needs: only the names returned by
        ``report_parameters_needed_for_fitting()`` are forwarded to ``fit``.

        Parameters:
            input_data (np.ndarray): Training inputs.
            output_data (np.ndarray): Training targets.
            validation_data (tuple[np.ndarray, np.ndarray]): Validation (input, output) pair.
            batch_size (int): Batch size used during training.
            epochs (int): Number of training epochs.
            callbacks (list[Callback]): Callbacks to use during training.

        Returns:
            (dict[str, Any]): Whatever the adapter's ``fit`` returns, presumably
                the training history and related information.
        """

        available_arguments = {
            'input_data': input_data,
            'output_data': output_data,
            'validation_data': validation_data,
            'batch_size': batch_size,
            'epochs': epochs,
            'callbacks': callbacks,
        }

        # Forward only what the adapter asks for. A requested name that is not
        # one of the arguments above is passed through as None.
        requested_names = self._model_adapter.report_parameters_needed_for_fitting()
        fit_kwargs = {name: available_arguments.get(name) for name in requested_names}

        return self._model_adapter.fit(**fit_kwargs)

    def classify(self, data: np.ndarray) -> list[list[float]]:
        """
        Run the trained model on ``data`` via the adapter's ``predict``.

        Parameters:
            data (np.ndarray): Samples to classify.

        Returns:
            (list[list[float]]): Per-sample class probabilities, as produced by the
                adapter's ``predict``. The concrete return type is whatever that
                method returns.
        """

        adapter = self._model_adapter
        return adapter.predict(data)

    def get_model_adapter_tag(self) -> str:
        """
        Look up the tag declared by the model adapter's class.

        Returns:
            (str): The ``TAG`` class attribute of the adapter's type.
        """

        adapter_class = type(self._model_adapter)
        return adapter_class.TAG