Coverage for source/environment/label_annotator_base.py: 52%

29 statements  

« prev     ^ index     » next       coverage.py v7.8.0, created at 2025-07-27 20:13 +0000

1# environment/label_annotator_base.py 

2 

3# global imports 

4import pandas as pd 

5from abc import ABC, abstractmethod 

6from types import SimpleNamespace 

7from typing import Optional 

8 

9# local imports 

10 

class LabelAnnotatorBase(ABC):
    """
    Implements a base class for label annotators. It provides an interface for annotating
    data with labels based on price movements.

    Derived classes must set ``_output_classes`` in their constructor (``annotate``
    refuses to run otherwise) and implement ``_classify_trend``. They may also set
    ``_requested_columns`` to choose which columns are handed to ``_classify_trend``;
    when left as None, only the computed future price change column is used.
    """

    # derived constants
    _CLOSE_PRICE_COLUMN_NAME: str = "close"
    _CLOSE_PRICE_CHANGE_COLUMN_NAME: str = "future_normalized_diff"

    @abstractmethod
    def __init__(self) -> None:
        """
        Class constructor. Initializes the output classes for classification.

        Both attributes start as None and are meant to be populated by derived classes.
        """

        # Namespace of class labels; must be set by derived classes before annotate().
        self._output_classes: Optional[SimpleNamespace] = None
        # Columns passed to _classify_trend; None means "use the default", which
        # annotate() fills in lazily.
        self._requested_columns: Optional[list[str]] = None

    @abstractmethod
    def _classify_trend(self, row: pd.Series) -> int:
        """
        Classifies the price movement trend based on the price difference.

        Parameters:
            row (pd.Series): A row holding the requested columns to classify.

        Returns:
            (int): The class label for the price movement trend.
        """

        pass

    def annotate(self, data: pd.DataFrame) -> pd.Series:
        """
        Annotates the provided data with labels based on price movements.

        When the future price change column is requested (the default), it is
        computed from the 'close' column and written into ``data`` in place, so the
        caller's DataFrame gains that column.

        Parameters:
            data (pd.DataFrame): The data to annotate, must contain a 'close' column.

        Raises:
            ValueError: If the output classes are not initialized before annotating data,
                or if any requested column is missing from the data.
            KeyError: If the future price change column is requested but the data has
                no 'close' column.

        Returns:
            (pd.Series): One label per row except the last one, whose next-row price
                is unknown. Empty when the data has fewer than two rows.
        """

        if self._output_classes is None:
            raise ValueError("Output classes must be initialized in derived classes "
                             "before annotating data.")

        if self._requested_columns is None:
            self._requested_columns = [self._CLOSE_PRICE_CHANGE_COLUMN_NAME]

        if self._CLOSE_PRICE_CHANGE_COLUMN_NAME in self._requested_columns:
            current_prices = data[self._CLOSE_PRICE_COLUMN_NAME]
            # Price of the following row; NaN on the last row (there is no next price).
            next_day_prices = current_prices.shift(-1)
            data[self._CLOSE_PRICE_CHANGE_COLUMN_NAME] = (
                (next_day_prices - current_prices) / current_prices
            )

        if missing_columns := set(self._requested_columns) - set(data.columns):
            raise ValueError(f"Data is missing required columns: {missing_columns}")

        # Drop the last row positionally *before* classifying: its future change is
        # NaN and its label is discarded anyway, so the classifier never has to see it.
        rows = data[self._requested_columns].iloc[:-1]

        # DataFrame.apply on a frame without rows may return a DataFrame rather than a
        # Series (depending on how the classifier reacts to a probe call), so return an
        # explicit empty Series instead.
        if len(rows) == 0:
            return pd.Series(index=rows.index, dtype="int64")

        return rows.apply(self._classify_trend, axis=1)

    def get_output_classes(self) -> Optional[SimpleNamespace]:
        """
        Returns the output classes for the label annotator.

        Returns:
            (Optional[SimpleNamespace]): The SimpleNamespace containing the labels for
                classification, or None if a derived class has not set it.
        """

        return self._output_classes