Coverage for source/agent/agents/classification_learning_agent.py: 91%

23 statements  

« prev     ^ index     » next       coverage.py v7.8.0, created at 2025-07-30 20:59 +0000

1# agent/agents/classification_learning_agent.py 

2 

3# global imports 

4import numpy as np 

5import pandas as pd 

6from tensorflow.keras.callbacks import Callback 

7from typing import Any 

8 

9# local imports 

10from source.agent import AgentBase, ClassificationTestable 

11 

class ClassificationLearningAgent(AgentBase, ClassificationTestable):
    """
    Implements a classification learning agent that can be trained and tested
    in a trading environment. Fitting and prediction are both delegated to the
    agent's model adapter (``self._model_adapter``).

    NOTE(review): ``self._model_adapter`` is not assigned in this class; it is
    presumably set up by ``AgentBase`` -- confirm against the base class.
    """

    def classification_fit(self, input_data: np.ndarray, output_data: np.ndarray,
                           validation_data: tuple[np.ndarray, np.ndarray], batch_size: int,
                           epochs: int, callbacks: list[Callback]) -> dict[str, Any]:
        """
        Fits the model to the classification data.

        The model adapter is asked which parameters its ``fit`` method needs
        (``report_parameters_needed_for_fitting``), and only those are forwarded,
        as keyword arguments, to ``self._model_adapter.fit``.

        Parameters:
            input_data (np.ndarray): The input data for training.
            output_data (np.ndarray): The output data for training.
            validation_data (tuple[np.ndarray, np.ndarray]): The validation data (input, output).
            batch_size (int): The batch size to be used during training.
            epochs (int): The number of epochs to train the model.
            callbacks (list[Callback]): A list of callbacks to be used during training.

        Returns:
            (dict[str, Any]): Whatever the model adapter's ``fit`` returns -- expected
                to be the training history and other relevant information.
        """

        # Every value this agent is able to supply, keyed by the parameter name
        # the model adapter would request it under.
        provided_input_params = {
            'input_data': input_data,
            'output_data': output_data,
            'validation_data': validation_data,
            'batch_size': batch_size,
            'epochs': epochs,
            'callbacks': callbacks,
        }

        # Forward only the parameters the adapter asks for. A requested name that
        # this agent does not provide is passed as None instead of raising, so the
        # adapter itself decides how to handle the missing value.
        kwargs = {
            parameter: provided_input_params.get(parameter)
            for parameter in self._model_adapter.report_parameters_needed_for_fitting()
        }

        return self._model_adapter.fit(**kwargs)

    def classify(self, data: pd.DataFrame) -> list[list[float]]:
        """
        Classifies the input data using the trained model.

        This is a direct pass-through to ``self._model_adapter.predict``.

        Parameters:
            data (pd.DataFrame): The input data to be classified.

        Returns:
            (list[list[float]]): The predicted class probabilities for each input
                sample, as returned by the model adapter.
        """

        return self._model_adapter.predict(data)

    def get_model_adapter_tag(self) -> str:
        """
        Returns the tag of the model adapter.

        The tag is read from the ``TAG`` attribute of the adapter's class, not
        from the adapter instance.

        Returns:
            (str): The tag of the model adapter.
        """

        return type(self._model_adapter).TAG