Coverage for source/environment/simple_label_annotator.py: 67%
18 statements
« prev ^ index » next coverage.py v7.8.0, created at 2025-08-01 20:51 +0000
« prev ^ index » next coverage.py v7.8.0, created at 2025-08-01 20:51 +0000
1# environment/simple_label_annotator.py
3# global imports
4import pandas as pd
5from types import SimpleNamespace
7# local imports
8from source.environment import LabelAnnotatorBase
class SimpleLabelAnnotator(LabelAnnotatorBase):
    """
    Implements a simple label annotator that classifies price movements into three classes:
        - Up trend
        - Down trend
        - No trend

    A row is labelled by comparing its close-price-change value against the
    symmetric band [-threshold, threshold]: above the band is an up trend,
    below it is a down trend, and inside it (boundaries included) is no trend.
    """

    def __init__(self, threshold: float = 0.01) -> None:
        """
        Class constructor. Initializes the SimpleLabelAnnotator with a specified threshold for trend classification.

        Parameters:
            threshold (float): Non-negative half-width of the "no trend" band. Price changes
                strictly greater than `threshold` are labelled UP_TREND, strictly less than
                `-threshold` DOWN_TREND, and anything in between (inclusive) NO_TREND.

        Raises:
            ValueError: If `threshold` is negative or NaN.
        """

        # `not threshold >= 0` also rejects NaN, because every comparison with NaN is False.
        # A negative threshold would let the up-trend test cover the whole "no trend" band,
        # so NO_TREND could never be produced; a NaN threshold would label everything NO_TREND.
        if not threshold >= 0:
            raise ValueError(f"threshold must be a non-negative number, got {threshold!r}")

        super().__init__()

        # Integer label assigned to each trend class.
        self._output_classes = SimpleNamespace()
        self._output_classes.UP_TREND = 0
        self._output_classes.DOWN_TREND = 1
        self._output_classes.NO_TREND = 2

        # Name-mangled (double underscore) to keep the threshold private to this class.
        self.__threshold = threshold

    def _classify_trend(self, row: pd.Series) -> int:
        """
        Classifies the price movement trend based on the price difference.

        Parameters:
            row (pd.Series): The row containing the data to classify. It must hold a value under
                `self._CLOSE_PRICE_CHANGE_COLUMN_NAME` (presumably defined by LabelAnnotatorBase).

        Returns:
            (int): The class label for the price movement trend. A NaN price change fails both
                comparisons and is therefore labelled NO_TREND.
        """

        price_diff = row[self._CLOSE_PRICE_CHANGE_COLUMN_NAME]
        if price_diff > self.__threshold:
            return self._output_classes.UP_TREND
        elif price_diff < -self.__threshold:
            return self._output_classes.DOWN_TREND
        else:
            return self._output_classes.NO_TREND