Coverage for source/environment/label_annotator_base.py: 67%
18 statements
« prev ^ index » next coverage.py v7.8.0, created at 2025-07-27 17:11 +0000
1# environment/label_annotator_base.py
3# global imports
4import pandas as pd
5from abc import ABC, abstractmethod
6from types import SimpleNamespace
8# local imports
class LabelAnnotatorBase(ABC):
    """
    Implements a base class for label annotators. It provides an interface for annotating
    data with labels based on price movements.

    Subclasses must implement ``__init__`` (calling ``super().__init__()`` so that
    ``self._output_classes`` exists) and ``_classify_trend``, which maps a single
    relative price change to an integer class label.
    """

    # local constants
    # Column read by annotate(). The double underscore triggers name mangling, so
    # subclasses see it as _LabelAnnotatorBase__CLOSE_PRICE_COLUMN_NAME.
    __CLOSE_PRICE_COLUMN_NAME: str = "close"

    @abstractmethod
    def __init__(self) -> None:
        """
        Class constructor. Initializes the output classes for classification.

        Declared abstract so this base class cannot be instantiated directly.
        Subclasses should call ``super().__init__()`` to create the (initially empty)
        ``self._output_classes`` namespace returned by ``get_output_classes``.
        """

        self._output_classes = SimpleNamespace()

    @abstractmethod
    def _classify_trend(self, price_diff: float) -> int:
        """
        Classifies the price movement trend based on the price difference.

        Parameters:
            price_diff (float): The relative price change to classify, as computed by
                ``annotate``. It may be +/-inf when the current price is zero.

        Returns:
            (int): The class label for the price movement trend.
        """

        pass

    def annotate(self, data: pd.DataFrame) -> pd.Series:
        """
        Annotates the provided data with labels based on price movements.

        For each row, the relative change to the next row's close price is computed
        as ``(next - current) / |current|`` and passed to ``_classify_trend``.
        Using the absolute value of the current price keeps the sign of the change
        aligned with the direction of the move, even when prices are negative.

        Rows whose change cannot be computed are dropped, so the result keeps the
        index labels of the remaining rows of ``data``. These are the final row,
        which has no successor, and any row where the current or next close is NaN.

        Parameters:
            data (pd.DataFrame): The data to annotate, must contain a 'close' column.

        Returns:
            (pd.Series): A series of labels corresponding to the price movement trends.

        Raises:
            KeyError: If ``data`` has no 'close' column.
        """

        current_prices = data[self.__CLOSE_PRICE_COLUMN_NAME]
        # Reuse the already-selected column instead of indexing the frame again.
        next_day_prices = current_prices.shift(-1)
        # Divide by |current| rather than current. With a negative base, the plain
        # ratio flips sign and would label an upward move as a downward one. For
        # positive prices the two are identical. A zero current price yields +/-inf,
        # or NaN for 0 -> 0, which dropna() removes.
        price_diffs = (next_day_prices - current_prices) / current_prices.abs()

        return price_diffs.dropna().apply(self._classify_trend)

    def get_output_classes(self) -> SimpleNamespace:
        """
        Returns the output classes for the label annotator.

        Returns:
            (SimpleNamespace): The SimpleNamespace containing the labels for classification.
        """

        return self._output_classes