Coverage for source/environment/label_annotator_base.py: 43%
28 statements
« prev ^ index » next coverage.py v7.8.0, created at 2025-08-04 20:03 +0000
« prev ^ index » next coverage.py v7.8.0, created at 2025-08-04 20:03 +0000
1# environment/label_annotator_base.py
3# global imports
4import pandas as pd
5from abc import ABC, abstractmethod
6from types import SimpleNamespace
7from typing import Optional
9# local imports
class LabelAnnotatorBase(ABC):
    """
    Implements a base class for label annotators. It provides an interface for annotating
    data with labels based on price movements.

    Derived classes must set ``self._output_classes`` before ``annotate`` is called, may set
    ``self._requested_columns`` to choose which columns are handed to ``_classify_trend``,
    and must implement ``_classify_trend``.
    """

    # derived constants
    _CLOSE_PRICE_COLUMN_NAME: str = "close"
    _CLOSE_PRICE_CHANGE_COLUMN_NAME: str = "future_normalized_diff"

    def __init__(self) -> None:
        """
        Class constructor. Initializes the output classes for classification.
        """

        # Left unset here; annotate() refuses to run while this is still None.
        self._output_classes: Optional[SimpleNamespace] = None
        # Columns passed to _classify_trend; annotate() falls back to the derived
        # change column when this is None.
        self._requested_columns: Optional[list[str]] = None

    @abstractmethod
    def _classify_trend(self, row: pd.Series) -> int:
        """
        Classifies the price movement trend based on the price difference.

        Parameters:
            row (pd.Series): A single row of the data, restricted to the requested columns.

        Returns:
            (int): The class label for the price movement trend.
        """

        pass

    def annotate(self, data: pd.DataFrame) -> pd.Series:
        """
        Annotates the provided data with labels based on price movements.

        When the derived change column ('future_normalized_diff') is requested, it is
        computed as the relative change from each row's close price to the next row's
        close price and written into ``data`` in place (overwriting any existing column of
        that name). The last row has no successor, so its value is NaN.

        Parameters:
            data (pd.DataFrame): The data to annotate. Must contain a 'close' column when
                the change column is requested, plus every other requested column.

        Raises:
            ValueError: If the output classes are not initialized before annotating data,
                or if any required column is missing from ``data``.

        Returns:
            (pd.Series): One label per row of ``data`` as produced by ``_classify_trend``
                (an empty Series for input without rows).
        """

        if self._output_classes is None:
            raise ValueError("Output classes must be initialized in derived classes "
                             "before annotating data.")

        if self._requested_columns is None:
            self._requested_columns = [self._CLOSE_PRICE_CHANGE_COLUMN_NAME]

        change_column = self._CLOSE_PRICE_CHANGE_COLUMN_NAME
        computes_change = change_column in self._requested_columns

        # Validate before touching the data so that a missing close column is reported
        # like any other missing column and the input is left unmodified on failure.
        # The change column itself is produced below, so it is not required up front.
        required_columns = set(self._requested_columns) - {change_column}
        if computes_change:
            required_columns.add(self._CLOSE_PRICE_COLUMN_NAME)
        if missing_columns := required_columns - set(data.columns):
            raise ValueError(f"Data is missing required columns: {missing_columns}")

        if computes_change:
            current_prices = data[self._CLOSE_PRICE_COLUMN_NAME]
            # Next row's close; NaN for the final row.
            next_prices = current_prices.shift(-1)
            data[change_column] = (next_prices - current_prices) / current_prices

        # result_type="reduce" keeps the return value a Series even for input without
        # rows, where pandas would otherwise fall back to returning an empty DataFrame
        # if probing _classify_trend with an all-NaN row raises.
        return data[self._requested_columns].apply(
            self._classify_trend, axis=1, result_type="reduce"
        )

    def get_output_classes(self) -> Optional[SimpleNamespace]:
        """
        Returns the output classes for the label annotator.

        Returns:
            (Optional[SimpleNamespace]): The SimpleNamespace containing the labels for
                classification, or None if a derived class has not initialized it yet.
        """

        return self._output_classes