Coverage for source/environment/simple_label_annotator.py: 64%
14 statements
« prev ^ index » next coverage.py v7.8.0, created at 2025-07-27 17:11 +0000
1# environment/simple_label_annotator.py
3# global imports
5# local imports
6from source.environment import LabelAnnotatorBase
class SimpleLabelAnnotator(LabelAnnotatorBase):
    """
    Implements a simple label annotator that classifies price movements into three classes:
    - Up trend
    - Down trend
    - No trend

    A price difference strictly greater than ``threshold`` is labelled as an up trend,
    one strictly smaller than ``-threshold`` as a down trend, and anything within
    ``[-threshold, threshold]`` (boundaries included) as no trend.
    """

    def __init__(self, threshold: float = 0.01) -> None:
        """
        Class constructor. Initializes the SimpleLabelAnnotator with a specified threshold for trend classification.

        Parameters:
            threshold (float): Non-negative half-width of the "no trend" band around zero.
                Price differences whose magnitude does not exceed this value are labelled
                as no trend.

        Raises:
            ValueError: If threshold is negative or NaN.
        """

        # Validate before doing any base-class initialization. A negative threshold would
        # invert the band (the up/down ranges would overlap and the up check would silently
        # win), and a NaN threshold would make every comparison False so every movement is
        # labelled as no trend. `not threshold >= 0` rejects both cases in one test.
        if not threshold >= 0:
            raise ValueError(
                f"threshold must be a non-negative number, got {threshold!r}"
            )

        super().__init__()
        # NOTE(review): _output_classes is presumably created by LabelAnnotatorBase.__init__
        # as an attribute namespace; confirm against the base class.
        self._output_classes.UP_TREND = 0
        self._output_classes.DOWN_TREND = 1
        self._output_classes.NO_TREND = 2
        self.__threshold = threshold

    def _classify_trend(self, price_diff: float) -> int:
        """
        Classifies the price movement trend based on the price difference.

        Parameters:
            price_diff (float): The price difference to classify.

        Returns:
            (int): The class label for the price movement trend: UP_TREND if
                price_diff > threshold, DOWN_TREND if price_diff < -threshold,
                otherwise NO_TREND.
        """

        if price_diff > self.__threshold:
            return self._output_classes.UP_TREND
        elif price_diff < -self.__threshold:
            return self._output_classes.DOWN_TREND
        else:
            # Covers |price_diff| <= threshold; note a NaN price_diff also lands here.
            return self._output_classes.NO_TREND