Coverage for source/agent/agents/classification_learning_agent.py: 91%
22 statements
« prev ^ index » next coverage.py v7.8.0, created at 2025-08-01 20:51 +0000
1# agent/agents/classification_learning_agent.py
3# global imports
4import numpy as np
5from tensorflow.keras.callbacks import Callback
6from typing import Any
8# local imports
9from source.agent import AgentBase, ClassificationTestable
class ClassificationLearningAgent(AgentBase, ClassificationTestable):
    """
    Agent that trains and queries a classification model through its model
    adapter (``self._model_adapter``, presumably set up by ``AgentBase`` -
    confirm). It exposes fitting on labelled data, prediction on new samples
    and lookup of the adapter's tag, and is intended to be trained and tested
    in a trading environment.
    """

    def classification_fit(self, input_data: np.ndarray, output_data: np.ndarray,
                           validation_data: tuple[np.ndarray, np.ndarray], batch_size: int,
                           epochs: int, callbacks: list[Callback]) -> dict[str, Any]:
        """
        Trains the underlying model, forwarding only the arguments its adapter asks for.

        The adapter's ``report_parameters_needed_for_fitting()`` supplies the names of
        the keyword arguments to pass to its ``fit``. Each requested name is looked up
        among the arguments of this method. A requested name with no matching argument
        is forwarded as ``None`` instead of raising here.

        Parameters:
            input_data (np.ndarray): Training inputs.
            output_data (np.ndarray): Training targets.
            validation_data (tuple[np.ndarray, np.ndarray]): Validation (inputs, targets) pair.
            batch_size (int): Number of samples per training batch.
            epochs (int): Number of training epochs.
            callbacks (list[Callback]): Keras callbacks to run during training.

        Returns:
            (dict[str, Any]): The value returned by the adapter's ``fit`` call
                (expected to carry the training history and related information).
        """

        # Everything this method can offer to the adapter, keyed by the name
        # the adapter uses to request it.
        available_arguments = {
            'input_data': input_data,
            'output_data': output_data,
            'validation_data': validation_data,
            'batch_size': batch_size,
            'epochs': epochs,
            'callbacks': callbacks,
        }

        requested_names = self._model_adapter.report_parameters_needed_for_fitting()
        # Keep only what the adapter asked for; unknown names map to None.
        fit_arguments = {name: available_arguments.get(name) for name in requested_names}

        return self._model_adapter.fit(**fit_arguments)

    def classify(self, data: np.ndarray) -> list[list[float]]:
        """
        Runs the model adapter's ``predict`` on the given samples.

        Parameters:
            data (np.ndarray): Samples to classify.

        Returns:
            (list[list[float]]): The adapter's ``predict`` result, passed through
                unchanged; annotated as per-sample class probabilities.
        """

        predictions = self._model_adapter.predict(data)
        return predictions

    def get_model_adapter_tag(self) -> str:
        """
        Reports which kind of model adapter this agent is using.

        Returns:
            (str): The ``TAG`` attribute looked up on the adapter's class.
        """

        # type() (not __class__) so the lookup hits the adapter's real class.
        adapter_class = type(self._model_adapter)
        return adapter_class.TAG