Coverage for source/environment/label_annotator_base.py: 67%

18 statements  

« prev     ^ index     » next       coverage.py v7.8.0, created at 2025-07-27 17:11 +0000

1# environment/label_annotator_base.py 

2 

3# global imports 

4import pandas as pd 

5from abc import ABC, abstractmethod 

6from types import SimpleNamespace 

7 

8# local imports 

9 

class LabelAnnotatorBase(ABC):
    """
    Implements a base class for label annotators. It provides an interface for annotating
    data with labels based on price movements.

    Subclasses must implement ``__init__`` (calling ``super().__init__()`` so that
    ``self._output_classes`` exists) and ``_classify_trend``.
    """

    # local constants
    # NOTE: the double leading underscore triggers name mangling, so this is stored as
    # ``_LabelAnnotatorBase__CLOSE_PRICE_COLUMN_NAME``; a subclass attribute spelled
    # the same way does not override the column used by ``annotate``.
    __CLOSE_PRICE_COLUMN_NAME: str = "close"

    @abstractmethod
    def __init__(self) -> None:
        """
        Class constructor. Initializes the output classes for classification.

        Creates an empty ``SimpleNamespace`` in ``self._output_classes``; the labels it
        should hold are not defined here (presumably each subclass adds its own).
        """

        self._output_classes = SimpleNamespace()

    @abstractmethod
    def _classify_trend(self, price_diff: float) -> int:
        """
        Classifies the price movement trend based on the price difference.

        Parameters:
            price_diff (float): The relative price difference to classify, as computed
                by ``annotate``. It is always finite.

        Returns:
            (int): The class label for the price movement trend.
        """

    def annotate(self, data: pd.DataFrame) -> pd.Series:
        """
        Annotates the provided data with labels based on price movements.

        For each row ``t``, the relative change to the next row's close price,
        ``(close[t + 1] - close[t]) / close[t]``, is passed to ``_classify_trend``.

        Parameters:
            data (pd.DataFrame): The data to annotate, must contain a 'close' column.

        Returns:
            (pd.Series): A series of labels corresponding to the price movement trends,
                keeping the index of the rows of ``data`` they were computed for. Rows
                whose relative change is undefined are omitted: the last row (it has no
                next price), rows with a missing current or next close, and rows whose
                current close is zero.

        Raises:
            KeyError: If ``data`` has no 'close' column.
        """

        close_prices = data[self.__CLOSE_PRICE_COLUMN_NAME]
        next_close_prices = close_prices.shift(-1)
        price_diffs = (next_close_prices - close_prices) / close_prices

        # A zero close price yields +/-inf (or NaN for 0/0). dropna() alone would keep
        # the infinities and feed them to _classify_trend, so turn them into NaN first
        # and drop every undefined change in one step.
        price_diffs = price_diffs.replace([float("inf"), float("-inf")], float("nan"))

        return price_diffs.dropna().apply(self._classify_trend)

    def get_output_classes(self) -> SimpleNamespace:
        """
        Returns the output classes for the label annotator.

        Returns:
            (SimpleNamespace): The SimpleNamespace containing the labels for classification.
        """

        return self._output_classes