Coverage for source/environment/trading_environment.py: 88%
178 statements
# environment/trading_environment.py
# global imports
import copy
import logging
import math
import numpy as np
import pandas as pd
import random
from gym import Env
from gym.spaces import Box, Discrete
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
from tensorflow.keras.utils import to_categorical
from types import SimpleNamespace
from typing import Any, Optional

# local imports
from source.environment import Broker, LabelAnnotatorBase, LabeledDataBalancer, RewardValidatorBase

class TradingEnvironment(Env):
    """
    Implements a stock market environment in which an actor can perform actions
    (place orders). It is used to train various models using various approaches
    and can be configured to award points and impose penalties in several ways.
    """

    # global class constants
    TRAIN_MODE = 'train'
    TEST_MODE = 'test'

    def __init__(self, data: pd.DataFrame, initial_budget: float, max_amount_of_trades: int,
                 window_size: int, validator: RewardValidatorBase, label_annotator: LabelAnnotatorBase,
                 sell_stop_loss: float, sell_take_profit: float, buy_stop_loss: float, buy_take_profit: float,
                 test_ratio: float = 0.2, penalty_starts: int = 0, penalty_stops: int = 10,
                 static_reward_adjustment: float = 1, labeled_data_balancer: Optional[LabeledDataBalancer] = None,
                 meta_data: Optional[dict[str, Any]] = None) -> None:
38 """
39 Class constructor. Allows to define all crucial constans, reward validation methods,
40 environmental penalty policies, etc.
42 Parameters:
43 data (pd.DataFrame): DataFrame containing historical market data.
44 initial_budget (float): Initial budget constant for trader to start from.
45 max_amount_of_trades (int): Max amount of trades that can be ongoing at the same time.
46 Seting this constant prevents traders from placing orders randomly and defines
47 amount of money that can be assigned to a single order at certain iteration.
48 window_size (int): Constant defining how far in the past trader will be able to look
49 into at certain iteration.
50 validator (RewardValidatorBase): Validator implementing policy used to award points
51 for closed trades.
52 label_annotator (LabelAnnotatorBase): Annotator implementing policy used to label
53 data with target values. It is used to provide supervised agents with information
54 about what is the target class value for certain iteration.
55 sell_stop_loss (float): Constant used to define losing boundary at which sell order
56 (short) is closed.
57 sell_take_profit (float): Constant used to define winning boundary at which sell order
58 (short) is closed.
59 buy_stop_loss (float): Constant used to define losing boundary at which buy order
60 (long) is closed.
61 buy_take_profit (float): Constant used to define winning boundary at which buy order
62 (long) is closed.
63 test_ratio (float): Ratio of data that should be used for testing purposes.
64 penalty_starts (int): Constant defining how many trading periods can trader go without placing
65 an order until penalty is imposed. Penalty at range between start and stop constant
66 is calculated as percentile of positive reward, and subtracted from the actual reward.
67 penalty_stops (int): Constant defining at which trading period penalty will no longer be increased.
68 Reward for trading periods exceeding penalty stop constant will equal minus static reward adjustment.
69 static_reward_adjustment (float): Constant use to penalize trader for bad choices or
70 reward it for good one.
71 labeled_data_balancer (Optional[LabeledDataBalancer]): Balancer used to balance
72 labeled data. If None, no balancing will be performed.
73 meta_data (dict[str, Any]): Dictionary containing metadata about the dataset.
74 """

        if test_ratio < 0.0 or test_ratio >= 1.0:
            raise ValueError(f"Invalid test_ratio: {test_ratio}. It should be in range [0, 1).")

        self.__data: dict[str, pd.DataFrame] = self.__split_data(data, test_ratio)
        self.__meta_data: Optional[dict[str, Any]] = meta_data
        self.__mode = TradingEnvironment.TRAIN_MODE
        self.__broker: Broker = Broker()
        self.__validator: RewardValidatorBase = validator
        self.__label_annotator: LabelAnnotatorBase = label_annotator
        self.__labeled_data_balancer: Optional[LabeledDataBalancer] = labeled_data_balancer

        self.__trading_data: SimpleNamespace = SimpleNamespace()
        self.__trading_data.current_budget: float = initial_budget
        self.__trading_data.currently_invested: float = 0
        self.__trading_data.no_trades_placed_for: int = 0
        self.__trading_data.currently_placed_trades: int = 0

        self.__trading_consts = SimpleNamespace()
        self.__trading_consts.INITIAL_BUDGET: float = initial_budget
        self.__trading_consts.MAX_AMOUNT_OF_TRADES: int = max_amount_of_trades
        self.__trading_consts.WINDOW_SIZE: int = window_size
        self.__trading_consts.SELL_STOP_LOSS: float = sell_stop_loss
        self.__trading_consts.SELL_TAKE_PROFIT: float = sell_take_profit
        self.__trading_consts.BUY_STOP_LOSS: float = buy_stop_loss
        self.__trading_consts.BUY_TAKE_PROFIT: float = buy_take_profit
        self.__trading_consts.STATIC_REWARD_ADJUSTMENT: float = static_reward_adjustment
        self.__trading_consts.PENALTY_STARTS: int = penalty_starts
        self.__trading_consts.PENALTY_STOPS: int = penalty_stops
        self.__trading_consts.PROFITABILITY_FUNCTION = lambda x: -1.0 * math.exp(-x + 1) + 1
        self.__trading_consts.PENALTY_FUNCTION = lambda x: \
            min(1, 1 - math.tanh(-3.0 * (x - penalty_stops) / (penalty_stops - penalty_starts)))
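        # Worked values (illustrative, with the default penalty_starts = 0 and
        # penalty_stops = 10): PROFITABILITY_FUNCTION maps the normalized budget
        # x = current_budget / INITIAL_BUDGET to 1 - e^(1 - x), giving 0.0 at
        # x = 1.0, ~0.63 at x = 2.0 and ~-0.65 at x = 0.5. PENALTY_FUNCTION
        # ramps from ~0.005 at 0 idle periods through ~0.10 at 5 up to its cap
        # of 1.0 at penalty_stops idle periods and beyond.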
        self.__trading_consts.OUTPUT_CLASSES: int = vars(self.__label_annotator.get_output_classes())

        self.current_iteration: int = self.__trading_consts.WINDOW_SIZE
        self.state: list[float] = self.__prepare_state_data()
        self.action_space: Discrete = Discrete(3)
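        # Observation bounds of +/-3 are a hedged choice: market features are
        # z-scored with StandardScaler in __prepare_state_data, so almost all
        # values fall within three standard deviations of the mean.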
        self.observation_space: Box = Box(low = np.ones(len(self.state)) * -3,
                                          high = np.ones(len(self.state)) * 3,
                                          dtype=np.float64)

    def __split_data(self, data: pd.DataFrame, test_size: float) -> dict[str, pd.DataFrame]:
        """
        Splits the given DataFrame into training and testing sets based on the specified test size ratio.

        Parameters:
            data (pd.DataFrame): DataFrame containing the stock market data.
            test_size (float): Ratio of the data to be used for testing.

        Returns:
            (dict[str, pd.DataFrame]): Dictionary containing training and testing data frames.
        """

        dividing_index = int(len(data) * (1 - test_size))

        return {
            TradingEnvironment.TRAIN_MODE: data.iloc[:dividing_index].reset_index(drop=True),
            TradingEnvironment.TEST_MODE: data.iloc[dividing_index:].reset_index(drop=True)
        }
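
    # Example (illustrative): with 1000 rows and test_ratio = 0.2 the split
    # index is int(1000 * 0.8) = 800, so rows 0..799 become the TRAIN_MODE
    # frame and rows 800..999 the TEST_MODE frame, each reindexed from 0.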

    def __prepare_labeled_data(self) -> tuple[pd.DataFrame, pd.Series]:
        """
        Prepares labeled data for training the model with a classification approach.
        It extracts the relevant features and labels from the environment's data.

        Returns:
            (tuple[pd.DataFrame, pd.Series]): A DataFrame containing the features and a
                Series containing the labels for training.
        """

        new_rows = []
        for i in range(self.current_iteration, self.get_environment_length() - 1):
            data_row = self.__prepare_state_data(slice(i - self.__trading_consts.WINDOW_SIZE, i), include_trading_data = False)
            new_rows.append(data_row)

        new_data = pd.DataFrame(new_rows, columns=[f"feature_{i}" for i in range(len(new_rows[0]))])
        labels = self.__label_annotator.annotate(self.__data[self.__mode].copy()).shift(-self.current_iteration)

        return new_data, labels.dropna()

    def __prepare_state_data(self, index: Optional[slice] = None, include_trading_data: bool = True) -> list[float]:
        """
        Calculates state data as a list of floats representing the current iteration's observation.
        Observations contain all input data refined to the window size and a couple of coefficients
        giving an insight into the current budget and orders situation.

        Parameters:
            index (Optional[slice]): Slice of data rows to build the observation from.
                Defaults to the window of WINDOW_SIZE rows preceding the current iteration.
            include_trading_data (bool): Whether to append the budget and order coefficients
                to the observation. Defaults to True.

        Returns:
            (list[float]): List with current observations for environment.
        """

        if index is None:
            index = slice(self.current_iteration - self.__trading_consts.WINDOW_SIZE, self.current_iteration)

        current_market_data = self.__data[self.__mode].iloc[index]
        current_market_data_no_index = current_market_data.select_dtypes(include = [np.number])

        if self.__meta_data is not None and \
           self.__meta_data.get('normalization_groups', None) is not None:
            grouped_columns_names = self.__meta_data['normalization_groups']
            preprocessed_data_pieces = []
            left_over_columns_names = set(current_market_data_no_index.columns)
            for columns_names_to_normalize in grouped_columns_names:
                left_over_columns_names -= set(columns_names_to_normalize)
                data_frame_piece_to_normalize = current_market_data_no_index[columns_names_to_normalize]
                normalized_data_frame_piece = StandardScaler().fit_transform(data_frame_piece_to_normalize.values.reshape(-1, 1))
                preprocessed_data_pieces.append(normalized_data_frame_piece.reshape(*data_frame_piece_to_normalize.shape))
            for column in left_over_columns_names:
                preprocessed_data_pieces.append(current_market_data_no_index[column].values.reshape(-1, 1))
            normalized_current_market_data_values = np.hstack(preprocessed_data_pieces)
        else:
            normalized_current_market_data_values = StandardScaler().fit_transform(current_market_data_no_index)
        current_market_data_list = normalized_current_market_data_values.ravel().tolist()

        if include_trading_data:
            current_normalized_budget = 1.0 * self.__trading_data.current_budget / self.__trading_consts.INITIAL_BUDGET
            current_profitability_coeff = self.__trading_consts.PROFITABILITY_FUNCTION(current_normalized_budget)
            current_trades_occupancy_coeff = 1.0 * self.__trading_data.currently_placed_trades / self.__trading_consts.MAX_AMOUNT_OF_TRADES
            current_no_trades_penalty_coeff = self.__trading_consts.PENALTY_FUNCTION(self.__trading_data.no_trades_placed_for)
            current_inner_state_list = [current_profitability_coeff, current_trades_occupancy_coeff, current_no_trades_penalty_coeff]
            current_market_data_list += current_inner_state_list
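
        # At this point the observation holds WINDOW_SIZE * <numeric column count>
        # z-scored market values, optionally extended with the three coefficients above.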
        return current_market_data_list

    def set_mode(self, mode: str) -> None:
        """
        Sets the mode of the environment to either TRAIN_MODE or TEST_MODE.

        Parameters:
            mode (str): Mode to set for the environment.

        Raises:
            ValueError: If the provided mode is not valid.
        """

        if mode not in [TradingEnvironment.TRAIN_MODE, TradingEnvironment.TEST_MODE]:
            raise ValueError(f"Invalid mode: {mode}. Use TradingEnvironment.TRAIN_MODE or TradingEnvironment.TEST_MODE.")
        self.__mode = mode

    def get_mode(self) -> str:
        """
        Mode getter.

        Returns:
            (str): Current mode of the environment.
        """

        return copy.copy(self.__mode)

    def get_trading_data(self) -> SimpleNamespace:
        """
        Trading data getter.

        Returns:
            (SimpleNamespace): Copy of the namespace with all trading data.
        """

        return copy.copy(self.__trading_data)

    def get_trading_consts(self) -> SimpleNamespace:
        """
        Trading constants getter.

        Returns:
            (SimpleNamespace): Copy of the namespace with all trading constants.
        """

        return copy.copy(self.__trading_consts)

    def get_broker(self) -> Broker:
        """
        Broker getter.

        Returns:
            (Broker): Copy of the broker used by environment.
        """

        return copy.copy(self.__broker)

    def get_environment_length(self) -> int:
        """
        Environment length getter.

        Returns:
            (int): Length of the environment data for the current mode.
        """

        return len(self.__data[self.__mode])

    def get_environment_spatial_data_dimension(self) -> tuple[int, int]:
        """
        Environment spatial data dimensionality getter.

        Returns:
            (tuple[int, int]): Dimension of spatial data in the environment.
        """

        return (self.__trading_consts.WINDOW_SIZE, self.__data[self.__mode].shape[1] - 1)

    def get_labeled_data(self, should_split: bool = True, should_balance: bool = True,
                         verbose: bool = True) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
        """
        Prepares labeled data for training or testing the model.
        It extracts the relevant features and labels from the environment's data.

        Parameters:
            should_split (bool): Whether to split the data into training and testing sets.
                Defaults to True. If set to False, testing data will be empty.
            should_balance (bool): Whether to balance the labeled data. Defaults to True.
                Will be ignored if labeled_data_balancer is None.
            verbose (bool): Whether to log the class cardinality before and after balancing.
                Defaults to True.

        Returns:
            (tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]): A tuple containing the
                input data, output data, test input data, and test output data.
        """

        input_data, output_data = self.__prepare_labeled_data()
        input_data_test, output_data_test = [], []
        if verbose:
            logging.info(f"Original class cardinality: {np.array(to_categorical(output_data)).sum(axis = 0)}")

        if self.__mode == TradingEnvironment.TRAIN_MODE:
            if should_split:
                input_data, input_data_test, output_data, output_data_test = \
                    train_test_split(input_data, output_data, test_size = 0.1, random_state = 42,
                                     stratify = output_data)
            if self.__labeled_data_balancer is not None and should_balance:
                input_data, output_data = self.__labeled_data_balancer.balance(input_data, output_data)
                if verbose:
                    logging.info(f"Balanced class cardinality: {np.array(to_categorical(output_data)).sum(axis = 0)}")

        return copy.copy((np.array(input_data), np.array(output_data),
                          np.array(input_data_test), np.array(output_data_test)))
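
    # Example use (illustrative sketch): fetch stratified, optionally balanced
    # training arrays for a supervised classifier, assuming `env` is a
    # configured TradingEnvironment instance:
    #
    #     x_train, y_train, x_test, y_test = env.get_labeled_data()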

    def get_data_for_iteration(self, columns: list[str], start: int = 0, stop: Optional[int] = None,
                               step: int = 1) -> list[float]:
        """
        Data getter for certain iterations.

        Parameters:
            columns (list[str]): List of column names to extract from data.
            start (int): Start iteration index. Defaults to 0.
            stop (Optional[int]): Stop iteration index. Defaults to environment length minus one.
            step (int): Step between iterations. Defaults to 1.

        Returns:
            (list[float]): Copy of part of data with specified columns
                over specified iterations.
        """

        if stop is None:
            stop = self.get_environment_length() - 1

        return copy.copy(self.__data[self.__mode].loc[start:stop:step, columns].values.ravel().tolist())
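
    # Example (illustrative): close prices for the first 100 iterations of the
    # current data split, assuming the input data has a 'close' column (note
    # that .loc slicing includes the stop index):
    #
    #     closes = env.get_data_for_iteration(['close'], start = 0, stop = 99)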

    def step(self, action: int) -> tuple[list[float], float, bool, dict]:
        """
        Performs the specified action on the environment. It results in generation of the new
        observation. This function causes trades to be handled, reward to be calculated and
        the environment to be updated.

        Parameters:
            action (int): Number specifying the action. Possible values are 0 for the buy action,
                1 for the wait action and 2 for the sell action.

        Returns:
            (tuple[list[float], float, bool, dict]): Tuple containing next observation
                state, reward, finish indication and additional info dictionary.
        """

        self.current_iteration += 1
        self.state = self.__prepare_state_data()

        close_changes = self.__data[self.__mode].iloc[self.current_iteration - 2 : self.current_iteration]['close'].values
        stock_change_coeff = 1 + (close_changes[1] - close_changes[0]) / close_changes[0]
        closed_orders = self.__broker.update_orders(stock_change_coeff)

        reward = self.__validator.validate_orders(closed_orders)
        self.__trading_data.currently_placed_trades -= len(closed_orders)
        self.__trading_data.current_budget += np.sum([trade.current_value for trade in closed_orders])
        self.__trading_data.currently_invested -= np.sum([trade.initial_value for trade in closed_orders])

        number_of_possible_trades = self.__trading_consts.MAX_AMOUNT_OF_TRADES - self.__trading_data.currently_placed_trades
        money_to_trade = 0
        if number_of_possible_trades > 0:
            money_to_trade = 1.0 / number_of_possible_trades * self.__trading_data.current_budget
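
        # money_to_trade splits the budget evenly across the remaining free trade
        # slots: e.g. with MAX_AMOUNT_OF_TRADES = 5 and 2 open trades, a new order
        # is funded with a third of the current budget.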

        if action == 0:
            is_buy_order = True
            stop_loss = self.__trading_consts.BUY_STOP_LOSS
            take_profit = self.__trading_consts.BUY_TAKE_PROFIT
        elif action == 2:
            is_buy_order = False
            stop_loss = self.__trading_consts.SELL_STOP_LOSS
            take_profit = self.__trading_consts.SELL_TAKE_PROFIT

        if action != 1:
            if number_of_possible_trades > 0:
                self.__trading_data.current_budget -= money_to_trade
                self.__trading_data.currently_invested += money_to_trade
                self.__broker.place_order(money_to_trade, is_buy_order, stop_loss, take_profit)
                self.__trading_data.currently_placed_trades += 1
                self.__trading_data.no_trades_placed_for = 0
                reward += self.__trading_consts.STATIC_REWARD_ADJUSTMENT
            else:
                self.__trading_data.no_trades_placed_for += 1
                reward -= self.__trading_consts.STATIC_REWARD_ADJUSTMENT
        else:
            self.__trading_data.no_trades_placed_for += 1
            if number_of_possible_trades == 0:
                reward += self.__trading_consts.STATIC_REWARD_ADJUSTMENT
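
        # Reward shaping: while free trade slots remain, a positive reward is
        # scaled down by the idleness penalty; once the idle streak exceeds
        # PENALTY_STOPS, a flat static adjustment is subtracted as well.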
        if number_of_possible_trades > 0:
            reward *= (1 - self.__trading_consts.PENALTY_FUNCTION(self.__trading_data.no_trades_placed_for)) \
                if reward > 0 else 1
            if self.__trading_consts.PENALTY_STOPS < self.__trading_data.no_trades_placed_for:
                reward -= self.__trading_consts.STATIC_REWARD_ADJUSTMENT

        if (self.current_iteration >= self.get_environment_length() - 1 or
                self.__trading_data.current_budget > 10 * self.__trading_consts.INITIAL_BUDGET or
                (self.__trading_data.current_budget + self.__trading_data.currently_invested) / self.__trading_consts.INITIAL_BUDGET < 0.8):
            done = True
        else:
            done = False

        info = {'coeff': stock_change_coeff,
                'iteration': self.current_iteration,
                'number_of_closed_orders': len(closed_orders),
                'money_to_trade': money_to_trade,
                'action': action,
                'current_budget': self.__trading_data.current_budget,
                'currently_invested': self.__trading_data.currently_invested,
                'no_trades_placed_for': self.__trading_data.no_trades_placed_for,
                'currently_placed_trades': self.__trading_data.currently_placed_trades}

        return self.state, reward, done, info
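
    # Example episode loop (illustrative sketch, assuming `env` is a configured
    # TradingEnvironment instance):
    #
    #     state = env.reset()
    #     done = False
    #     while not done:
    #         action = env.action_space.sample()  # 0 = buy, 1 = wait, 2 = sell
    #         state, reward, done, info = env.step(action)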

    def render(self) -> None:
        """
        Renders environment visualization. Will be implemented later.
        """

        # TODO: Visualization to be implemented
        pass

    def reset(self, randkey: Optional[int] = None) -> list[float]:
        """
        Resets the environment. Typically used when the environment is finished,
        i.e. when there are no more steps to be taken within the environment
        or the finish conditions are fulfilled.

        Parameters:
            randkey (Optional[int]): Value indicating which iteration
                should be treated as the starting point after reset.

        Returns:
            (list[float]): Current iteration observation state.
        """

        if randkey is None:
            randkey = random.randint(self.__trading_consts.WINDOW_SIZE, self.get_environment_length() - 1)
        self.__trading_data.current_budget = self.__trading_consts.INITIAL_BUDGET
        self.__trading_data.currently_invested = 0
        self.__trading_data.no_trades_placed_for = 0
        self.__trading_data.currently_placed_trades = 0
        self.__broker.reset()
        self.current_iteration = randkey
        self.state = self.__prepare_state_data()

        return self.state
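
# Minimal construction sketch (illustrative only): `MyValidator` and `MyAnnotator`
# are hypothetical stand-ins for concrete RewardValidatorBase / LabelAnnotatorBase
# subclasses, which this module does not define.
#
#     data = pd.read_csv('market_data.csv')  # must include a numeric 'close' column
#     env = TradingEnvironment(data = data, initial_budget = 10000.0,
#                              max_amount_of_trades = 5, window_size = 30,
#                              validator = MyValidator(), label_annotator = MyAnnotator(),
#                              sell_stop_loss = 0.95, sell_take_profit = 1.05,
#                              buy_stop_loss = 0.95, buy_take_profit = 1.05)
#     state = env.reset()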