Coverage for source/agent/agents/classification_learning_agent.py: 82%

11 statements  

« prev     ^ index     » next       coverage.py v7.8.0, created at 2025-06-06 12:00 +0000

1# source/agent/agents/classification_learning_agent.py 

2 

3# global imports 

4import numpy as np 

5import pandas as pd 

6from typing import Any 

7from tensorflow.keras.callbacks import Callback 

8 

9# local imports 

10from source.agent import AgentBase 

11from source.agent import ClassificationTestable 

12 

13class ClassificationLearningAgent(AgentBase, ClassificationTestable): 

14 """Agent whose fitting and classification calls are delegated to self._model_adapter.""" 

15 

16 def classification_fit(self, input_data, output_data, batch_size: int, 

17 epochs: int, callbacks: list[Callback]) -> dict[str, Any]: 

18 """Call self._model_adapter.fit with the given arguments and return the result's `history` attribute.""" 

19 

20 return self._model_adapter.fit(input_data, output_data, batch_size = batch_size, 

21 epochs = epochs, callbacks = callbacks).history 

22 

23 def classify(self, data: pd.DataFrame) -> list[list[float]]: 

24 """Return whatever self._model_adapter.predict yields for `data`.""" 

25 

26 return self._model_adapter.predict(data)