Coverage for source/environment/simple_label_annotator.py: 64%

14 statements  

« prev     ^ index     » next       coverage.py v7.8.0, created at 2025-07-27 17:11 +0000

1# environment/simple_label_annotator.py 

2 

3# global imports 

4 

5# local imports 

6from source.environment import LabelAnnotatorBase 

7 

class SimpleLabelAnnotator(LabelAnnotatorBase):
    """
    Implements a simple label annotator that classifies price movements into three classes:
    - Up trend   (label 0): price_diff >  threshold
    - Down trend (label 1): price_diff < -threshold
    - No trend   (label 2): -threshold <= price_diff <= threshold

    Both comparisons are strict, so a price difference exactly equal to
    +threshold or -threshold is labelled as "no trend".
    """

    def __init__(self, threshold: float = 0.01) -> None:
        """
        Class constructor. Initializes the SimpleLabelAnnotator with a specified threshold for trend classification.

        Parameters:
            threshold (float): Non-negative half-width of the "no trend" band
                around zero. Price differences whose magnitude does not exceed
                this value are labelled as no trend.

        Raises:
            ValueError: If threshold is negative or NaN. A negative threshold
                would make the "no trend" class unreachable and label small
                negative moves as an up trend; a NaN threshold would label
                every movement as no trend.
        """

        # Validate before touching any state. `not threshold >= 0` is used
        # instead of `threshold < 0` so that NaN (for which every comparison
        # is False) is rejected as well.
        if not threshold >= 0:
            raise ValueError(
                f"threshold must be a non-negative number, got {threshold!r}"
            )

        super().__init__()
        # Integer labels emitted by _classify_trend.
        self._output_classes.UP_TREND = 0
        self._output_classes.DOWN_TREND = 1
        self._output_classes.NO_TREND = 2
        # Double underscore: name-mangled to _SimpleLabelAnnotator__threshold.
        self.__threshold = threshold

    def _classify_trend(self, price_diff: float) -> int:
        """
        Classifies the price movement trend based on the price difference.

        Parameters:
            price_diff (float): The price difference to classify.

        Returns:
            (int): UP_TREND (0) if price_diff > threshold, DOWN_TREND (1) if
                price_diff < -threshold, otherwise NO_TREND (2), which also
                covers differences exactly equal to +/- threshold.
        """

        if price_diff > self.__threshold:
            return self._output_classes.UP_TREND
        elif price_diff < -self.__threshold:
            return self._output_classes.DOWN_TREND
        else:
            return self._output_classes.NO_TREND