Coverage for source/agent/agents/agent_base.py: 73%

11 statements  

« prev     ^ index     » next       coverage.py v7.8.0, created at 2025-07-30 20:59 +0000

1# agent/agents/agent_base.py 

2 

3# global imports 

4from typing import Callable, Optional 

5 

6# local imports 

7from source.model import ModelAdapterBase 

8 

class AgentBase:
    """
    Implements base class for agents that can be trained and tested in a trading environment.

    It provides basic functionalities such as loading and saving models, and printing
    model summaries. Every operation is delegated to the wrapped model adapter; the
    agent itself only holds a reference to that adapter.
    """

    def __init__(self, model_adapter: ModelAdapterBase) -> None:
        """
        Class constructor. Initializes the agent with the given model adapter.

        Parameters:
            model_adapter (ModelAdapterBase): The model adapter to be used for the agent.
                It is stored by reference (not copied or validated) and used by every
                other method of this class.
        """

        self._model_adapter: ModelAdapterBase = model_adapter

    def load_model(self, model_path: str) -> None:
        """
        Loads the model from the specified path.

        Delegates directly to the adapter's ``load_model``; any exception raised by the
        adapter propagates to the caller unchanged.

        Parameters:
            model_path (str): The path to the model file.
        """

        self._model_adapter.load_model(model_path)

    def save_model(self, model_path: str) -> None:
        """
        Saves the model to the specified path.

        Delegates directly to the adapter's ``save_model``; any exception raised by the
        adapter propagates to the caller unchanged.

        Parameters:
            model_path (str): The path to the model file.
        """

        self._model_adapter.save_model(model_path)

    def print_summary(self, print_function: Optional[Callable] = print) -> None:
        """
        Prints a summary of the model architecture and parameters.

        Delegates directly to the adapter's ``print_summary``, passing ``print_function``
        through unchanged.

        Parameters:
            print_function (Optional[Callable]): A function to print the summary.
                Defaults to the built-in ``print``.
        """

        # NOTE(review): the annotation allows None, and None is forwarded to the adapter
        # as-is (no fallback to ``print`` here) — confirm the adapter handles None.
        self._model_adapter.print_summary(print_function)