Coverage for source/environment/label_annotator_base.py: 50%

28 statements  

« prev     ^ index     » next       coverage.py v7.8.0, created at 2025-07-30 19:45 +0000

1# environment/label_annotator_base.py 

2 

3# global imports 

4import pandas as pd 

5from abc import ABC, abstractmethod 

6from types import SimpleNamespace 

7from typing import Optional 

8 

9# local imports 

10 

class LabelAnnotatorBase(ABC):
    """
    Implements a base class for label annotators. It provides an interface for annotating
    data with labels based on price movements.

    Derived classes are expected to assign '_output_classes' (and optionally
    '_requested_columns') and to implement '_classify_trend'.
    """

    # derived constants
    _CLOSE_PRICE_COLUMN_NAME: str = "close"
    _CLOSE_PRICE_CHANGE_COLUMN_NAME: str = "future_normalized_diff"

    def __init__(self) -> None:
        """
        Class constructor. Initializes the output classes and the requested columns
        to None, so that derived classes can fill them in.
        """

        self._output_classes: Optional[SimpleNamespace] = None
        self._requested_columns: Optional[list[str]] = None

    @abstractmethod
    def _classify_trend(self, row: pd.Series) -> int:
        """
        Classifies the price movement trend based on the price difference.

        Parameters:
            row (pd.Series): The row with the requested columns to classify.

        Returns:
            (int): The class label for the price movement trend.
        """

        pass

    def annotate(self, data: pd.DataFrame) -> pd.Series:
        """
        Annotates the provided data with labels based on price movements. The provided
        data frame is not modified; the price change column is computed on a copy.

        Parameters:
            data (pd.DataFrame): The data to annotate. Must contain a 'close' column
                whenever the 'future_normalized_diff' column is requested.

        Raises:
            ValueError: If the output classes are not initialized before annotating data,
                or if any of the requested columns is missing from the data.

        Returns:
            (pd.Series): A series of labels corresponding to the price movement trends.
                The last row of the data is always dropped from the result.
        """

        if self._output_classes is None:
            raise ValueError("Output classes must be initialized in derived classes "
                             "before annotating data.")

        # fall back to classifying on the future price change only
        if self._requested_columns is None:
            self._requested_columns = [self._CLOSE_PRICE_CHANGE_COLUMN_NAME]

        if self._CLOSE_PRICE_CHANGE_COLUMN_NAME in self._requested_columns:
            current_prices = data[self._CLOSE_PRICE_COLUMN_NAME]
            next_day_prices = current_prices.shift(-1)
            future_normalized_diff = (next_day_prices - current_prices) / current_prices
            # assign() returns a new frame, so the caller's data is left untouched
            data = data.assign(**{self._CLOSE_PRICE_CHANGE_COLUMN_NAME: future_normalized_diff})

        if missing_columns := set(self._requested_columns) - set(data.columns):
            raise ValueError(f"Data is missing required columns: {missing_columns}")

        labels = data[self._requested_columns].apply(self._classify_trend, axis = 1)

        # the last row has no following close price, so its label is discarded
        return labels.iloc[:-1]

    def get_output_classes(self) -> Optional[SimpleNamespace]:
        """
        Returns the output classes for the label annotator.

        Returns:
            (Optional[SimpleNamespace]): The SimpleNamespace containing the labels for
                classification, or None if they have not been initialized yet.
        """

        return self._output_classes