Coverage for source/agent/agents/classification_learning_agent.py: 91%
23 statements
« prev ^ index » next coverage.py v7.8.0, created at 2025-07-30 20:59 +0000
« prev ^ index » next coverage.py v7.8.0, created at 2025-07-30 20:59 +0000
1# agent/agents/classification_learning_agent.py
3# global imports
4import numpy as np
5import pandas as pd
6from tensorflow.keras.callbacks import Callback
7from typing import Any
9# local imports
10from source.agent import AgentBase, ClassificationTestable
class ClassificationLearningAgent(AgentBase, ClassificationTestable):
    """
    Classification learning agent intended for training and testing in a
    trading environment. Fitting and prediction are both delegated to the
    model adapter stored in `_model_adapter` (not assigned in this class;
    presumably provided by a base class - confirm against AgentBase).
    """

    def classification_fit(self, input_data: np.ndarray, output_data: np.ndarray,
                           validation_data: tuple[np.ndarray, np.ndarray], batch_size: int,
                           epochs: int, callbacks: list[Callback]) -> dict[str, Any]:
        """
        Trains the underlying model on classification data. Only the arguments
        that the model adapter reports as needed are forwarded to its fit call.

        Parameters:
            input_data (np.ndarray): Training inputs.
            output_data (np.ndarray): Training targets.
            validation_data (tuple[np.ndarray, np.ndarray]): Validation pair (input, output).
            batch_size (int): Batch size used while training.
            epochs (int): Number of training epochs.
            callbacks (list[Callback]): Callbacks to run during training.

        Returns:
            (dict[str, Any]): The value returned by the model adapter's fit call.
        """

        # Everything this method is able to supply, keyed by parameter name.
        available_arguments = {
            'input_data': input_data,
            'output_data': output_data,
            'validation_data': validation_data,
            'batch_size': batch_size,
            'epochs': epochs,
            'callbacks': callbacks,
        }

        requested_names = self._model_adapter.report_parameters_needed_for_fitting()

        # A name requested by the adapter but not supplied here is passed as None.
        fit_arguments = {name: available_arguments.get(name) for name in requested_names}

        return self._model_adapter.fit(**fit_arguments)

    def classify(self, data: pd.DataFrame) -> list[list[float]]:
        """
        Runs the model adapter's prediction on the given data.

        Parameters:
            data (pd.DataFrame): The samples to classify.

        Returns:
            (list[list[float]]): The adapter's prediction output, annotated as
                per-sample class probabilities.
        """

        predictions = self._model_adapter.predict(data)
        return predictions

    def get_model_adapter_tag(self) -> str:
        """
        Looks up the TAG attribute defined on the model adapter's class.

        Returns:
            (str): The TAG of the model adapter's class.
        """

        adapter_class = type(self._model_adapter)
        return adapter_class.TAG