Coverage for source/environment/label_annotator_base.py: 52%
29 statements
« prev ^ index » next coverage.py v7.8.0, created at 2025-07-27 20:13 +0000
« prev ^ index » next coverage.py v7.8.0, created at 2025-07-27 20:13 +0000
1# environment/label_annotator_base.py
3# global imports
4import pandas as pd
5from abc import ABC, abstractmethod
6from types import SimpleNamespace
7from typing import Optional
9# local imports
class LabelAnnotatorBase(ABC):
    """
    Implements a base class for label annotators. It provides an interface for annotating
    data with labels based on price movements.

    Derived classes implement ``__init__`` (assigning ``self._output_classes`` and,
    optionally, ``self._requested_columns``) and ``_classify_trend``.
    """

    # derived constants
    _CLOSE_PRICE_COLUMN_NAME: str = "close"
    _CLOSE_PRICE_CHANGE_COLUMN_NAME: str = "future_normalized_diff"

    @abstractmethod
    def __init__(self) -> None:
        """
        Class constructor. Resets the classification state to "not configured".

        Both the output classes and the requested columns start as None. Derived
        classes must assign ``self._output_classes`` before ``annotate`` is called
        (it raises ValueError otherwise). If ``self._requested_columns`` is left as
        None, ``annotate`` falls back to the future normalized price change column.
        """

        self._output_classes: Optional[SimpleNamespace] = None
        self._requested_columns: Optional[list[str]] = None

    @abstractmethod
    def _classify_trend(self, row: pd.Series) -> int:
        """
        Classifies the price movement trend of a single row.

        Parameters:
            row (pd.Series): One row of the data, restricted to the requested columns.

        Returns:
            (int): The class label for the price movement trend.
        """

        pass

    def annotate(self, data: pd.DataFrame) -> pd.Series:
        """
        Annotates the provided data with labels based on price movements.

        When the future normalized price change column is requested (the default),
        it is computed from the close prices as (close[t + 1] - close[t]) / close[t]
        and written into ``data`` in place. Each row of the requested columns is then
        classified with ``_classify_trend``. The last row is always dropped, since
        its next close price is unknown.

        Parameters:
            data (pd.DataFrame): The data to annotate. It must contain a 'close' column
                when the price change column is requested, plus every other
                requested column.

        Raises:
            ValueError: If the output classes are not initialized before annotating
                data, or if the data is missing required columns.

        Returns:
            (pd.Series): A series of labels corresponding to the price movement trends,
                one per input row except the last (empty for empty input).
        """

        if self._output_classes is None:
            raise ValueError("Output classes must be initialized in derived classes " \
                "before annotating data.")

        if self._requested_columns is None:
            self._requested_columns = [self._CLOSE_PRICE_CHANGE_COLUMN_NAME]

        derives_price_change = self._CLOSE_PRICE_CHANGE_COLUMN_NAME in self._requested_columns

        # Validate everything up front. Otherwise a missing 'close' column would surface
        # as a raw KeyError, and data could be partially mutated before the error.
        required_inputs = set(self._requested_columns) - {self._CLOSE_PRICE_CHANGE_COLUMN_NAME}
        if derives_price_change:
            required_inputs.add(self._CLOSE_PRICE_COLUMN_NAME)
        if missing_columns := required_inputs - set(data.columns):
            raise ValueError(f"Data is missing required columns: {missing_columns}")

        if derives_price_change:
            current_prices = data[self._CLOSE_PRICE_COLUMN_NAME]
            next_day_prices = current_prices.shift(-1)
            future_normalized_diff = (next_day_prices - current_prices) / current_prices
            # Intentional in-place write: the column stays on the caller's frame.
            data[self._CLOSE_PRICE_CHANGE_COLUMN_NAME] = future_normalized_diff

        features = data[self._requested_columns]

        # DataFrame.apply(axis=1) on an empty frame returns a DataFrame, not a Series.
        if len(features.index) == 0:
            return pd.Series(dtype="int64", index=features.index)

        labels = features.apply(self._classify_trend, axis=1)

        # Positional drop of the last row; a bare [:-1] is label-based on float indexes.
        return labels.iloc[:-1]

    def get_output_classes(self) -> Optional[SimpleNamespace]:
        """
        Returns the output classes for the label annotator.

        Returns:
            (Optional[SimpleNamespace]): The SimpleNamespace containing the labels for
                classification, or None if a derived class has not assigned it yet.
        """

        return self._output_classes