Coverage for source/environment/label_annotator_base.py: 59%
17 statements
« prev ^ index » next coverage.py v7.8.0, created at 2025-06-06 12:00 +0000
« prev ^ index » next coverage.py v7.8.0, created at 2025-06-06 12:00 +0000
# source/environment/label_annotator_base.py
3# global imports
4from abc import ABC, abstractmethod
5import pandas as pd
6from types import SimpleNamespace
8# local imports
class LabelAnnotatorBase(ABC):
    """Abstract base for annotators that label price rows by the move to the next row.

    Subclasses implement ``_classify_trend``, which turns one relative
    close-price change into an integer class label; ``annotate`` applies it
    to every row of a price frame. ``_output_classes`` is created empty here;
    presumably concrete annotators fill it with their label definitions
    (NOTE(review): confirm against the subclasses).
    """

    # Double underscore triggers name mangling: the attribute is stored as
    # ``_LabelAnnotatorBase__CLOSE_PRICE_COLUMN_NAME``, so a subclass cannot
    # change the column by redefining ``__CLOSE_PRICE_COLUMN_NAME``.
    __CLOSE_PRICE_COLUMN_NAME: str = "close"

    def __init__(self) -> None:
        """Create the annotator with an empty namespace of output classes."""
        self._output_classes: SimpleNamespace = SimpleNamespace()

    @abstractmethod
    def _classify_trend(self, price_diff: float) -> int:
        """Map a single relative price change to an integer class label.

        Args:
            price_diff: ``(next_close - close) / close`` for one row. It is
                NaN for the final row (there is no following price) and is
                infinite or NaN when ``close`` is zero, so implementations
                must decide how to label those values.

        Returns:
            The class label for the row.
        """

    def annotate(self, data: pd.DataFrame) -> pd.Series:
        """Label every row of ``data`` by the relative change to the next row's close.

        Args:
            data: Price frame containing a ``"close"`` column. Rows are
                assumed to be in chronological order, one step apart
                (presumably one trading day — confirm with callers).

        Returns:
            A Series aligned with ``data``'s index holding
            ``_classify_trend((close[i + 1] - close[i]) / close[i])`` for each
            row ``i``. The last row is classified from NaN, because it has no
            successor.
        """
        closes = data[self.__CLOSE_PRICE_COLUMN_NAME]
        # shift(-1) moves each following row's close onto the current row;
        # the final row has nothing to pull in and therefore becomes NaN.
        relative_change = (closes.shift(-1) - closes) / closes
        return relative_change.apply(self._classify_trend)

    def get_output_classes(self) -> SimpleNamespace:
        """Return the namespace describing this annotator's output classes.

        The object stored on the instance is returned itself (not a copy),
        so changes made by the caller are visible to the annotator.
        """
        return self._output_classes