Coverage for source/environment/label_annotator_base.py: 59%

17 statements  

« prev     ^ index     » next       coverage.py v7.8.0, created at 2025-06-06 12:00 +0000

1# agent/label_annotator_base.py 

2 

3# global imports 

4from abc import ABC, abstractmethod 

5import pandas as pd 

6from types import SimpleNamespace 

7 

8# local imports 

9 

class LabelAnnotatorBase(ABC):
    """Abstract base for annotators that label each row by its next-row price move.

    Subclasses implement :meth:`_classify_trend`, which turns one relative
    price change into an integer class label. :meth:`annotate` computes that
    change from the ``"close"`` column and applies the classifier row by row.
    """

    # The double leading underscore makes Python name-mangle this to
    # ``_LabelAnnotatorBase__CLOSE_PRICE_COLUMN_NAME``. A subclass that defines
    # its own ``__CLOSE_PRICE_COLUMN_NAME`` therefore does NOT change the
    # column that ``annotate`` reads.
    __CLOSE_PRICE_COLUMN_NAME: str = "close"

    def __init__(self) -> None:
        """Create the annotator with an empty output-class namespace."""
        # Starts empty here. Presumably concrete annotators fill it with their
        # label names/values (TODO: confirm against the subclasses).
        self._output_classes = SimpleNamespace()

    @abstractmethod
    def _classify_trend(self, price_diff: float) -> int:
        """Map one relative price change to an integer class label.

        Args:
            price_diff: ``(next_close - close) / close`` for a single row.
                It can be NaN (always for the last row, or when a price is
                missing) or +/-inf (when ``close`` is zero). Implementations
                decide how those values are labelled.

        Returns:
            The class label for that row.
        """

    def annotate(self, data: pd.DataFrame) -> pd.Series:
        """Label every row of *data* by its relative move to the following row.

        Args:
            data: Frame containing a ``"close"`` column.

        Returns:
            A Series aligned to ``data.index`` that holds, for each row, the
            value ``_classify_trend`` returned for that row's relative change.

        Raises:
            KeyError: If ``data`` has no ``"close"`` column.

        Note:
            The final row has no successor, so its relative change is NaN.
            It still receives whatever label ``_classify_trend`` gives to NaN.
        """
        closes = data[self.__CLOSE_PRICE_COLUMN_NAME]
        # shift(-1) moves each row's successor up by one position (the next
        # day, presumably, if rows are daily bars). The final row becomes NaN.
        upcoming = closes.shift(-1)
        relative_change = (upcoming - closes) / closes
        return relative_change.apply(self._classify_trend)

    def get_output_classes(self) -> SimpleNamespace:
        """Return the namespace of output classes held by this annotator.

        The stored object itself is returned, not a copy, so changes the
        caller makes to it are visible to the annotator.
        """
        return self._output_classes