Coverage for source/environment/simple_label_annotator.py: 67%

18 statements  

« prev     ^ index     » next       coverage.py v7.8.0, created at 2025-07-30 20:59 +0000

1# environment/simple_label_annotator.py 

2 

3# global imports 

4import pandas as pd 

5from types import SimpleNamespace 

6 

7# local imports 

8from source.environment import LabelAnnotatorBase 

9 

class SimpleLabelAnnotator(LabelAnnotatorBase):
    """
    Implements a simple label annotator that classifies price movements into three classes:
    - Up trend   (label 0, ``UP_TREND``)
    - Down trend (label 1, ``DOWN_TREND``)
    - No trend   (label 2, ``NO_TREND``)

    A row is labelled by comparing its close-price-change value (read from the
    column named by ``_CLOSE_PRICE_CHANGE_COLUMN_NAME``, provided by the base
    class) against the symmetric band ``[-threshold, threshold]``.
    """

    def __init__(self, threshold: float = 0.01) -> None:
        """
        Class constructor. Initializes the SimpleLabelAnnotator with a specified threshold for trend classification.

        Parameters:
            threshold (float): Half-width of the "no trend" band. A change strictly
                above ``threshold`` is labelled an up trend; a change strictly below
                ``-threshold`` is labelled a down trend. Its units are those of the
                close-price-change column (NOTE(review): presumably a relative change
                given the 0.01 default — confirm against the base class).

        Raises:
            ValueError: If ``threshold`` is negative or NaN.
                - With a negative band, every value not labelled UP_TREND satisfies
                  the DOWN_TREND test, so NO_TREND could never be produced.
                - With a NaN band, every comparison is False, so only NO_TREND could
                  be produced.
        """

        # Validate first so that an annotator with a meaningless band is never built.
        if pd.isna(threshold) or threshold < 0:
            raise ValueError(f"threshold must be a non-negative number, got {threshold!r}")

        super().__init__()
        # Integer label assigned to each trend class.
        self._output_classes = SimpleNamespace()
        self._output_classes.UP_TREND = 0
        self._output_classes.DOWN_TREND = 1
        self._output_classes.NO_TREND = 2
        # Name-mangled (becomes _SimpleLabelAnnotator__threshold) to keep it private to this class.
        self.__threshold = threshold

    def _classify_trend(self, row: pd.Series) -> int:
        """
        Classifies the price movement trend of a single row based on its price change.

        Parameters:
            row (pd.Series): The row of data to classify. It must contain the column
                named by ``self._CLOSE_PRICE_CHANGE_COLUMN_NAME``.

        Returns:
            (int): The class label for the price movement trend:
                - ``UP_TREND`` if the change is strictly greater than the threshold.
                - ``DOWN_TREND`` if the change is strictly less than the negated threshold.
                - ``NO_TREND`` otherwise. This includes a change exactly equal to
                  ±threshold, and a NaN change, which compares False against both bounds.
        """

        price_diff = row[self._CLOSE_PRICE_CHANGE_COLUMN_NAME]
        if price_diff > self.__threshold:
            return self._output_classes.UP_TREND
        if price_diff < -self.__threshold:
            return self._output_classes.DOWN_TREND
        # Inside the band (inclusive of its edges) or not comparable (NaN).
        return self._output_classes.NO_TREND