Coverage for source/environment/trading_environment.py: 90%
174 statements
coverage.py v7.8.0, created at 2025-07-27 17:11 +0000
# environment/trading_environment.py

# global imports
import copy
import logging
import math
import numpy as np
import pandas as pd
import random
from gym import Env
from gym.spaces import Box, Discrete
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
from tensorflow.keras.utils import to_categorical
from types import SimpleNamespace
from typing import Any, Optional

# local imports
from source.environment import Broker, LabelAnnotatorBase, LabeledDataBalancer, RewardValidatorBase

class TradingEnvironment(Env):
    """
    Implements a stock market environment in which an actor can perform actions (place orders).
    It is used to train various models using various approaches and can be configured
    to award points and impose penalties in several ways.
    """

    # global class constants
    TRAIN_MODE = 'train'
    TEST_MODE = 'test'

    def __init__(self, data: pd.DataFrame, initial_budget: float, max_amount_of_trades: int,
                 window_size: int, validator: RewardValidatorBase, label_annotator: LabelAnnotatorBase,
                 sell_stop_loss: float, sell_take_profit: float, buy_stop_loss: float, buy_take_profit: float,
                 test_ratio: float = 0.2, penalty_starts: int = 0, penalty_stops: int = 10,
                 static_reward_adjustment: float = 1, labeled_data_balancer: Optional[LabeledDataBalancer] = None,
                 meta_data: Optional[dict[str, Any]] = None) -> None:
38 """
39 Class constructor. Allows to define all crucial constans, reward validation methods,
40 environmental penalty policies, etc.
42 Parameters:
43 data (pd.DataFrame): DataFrame containing historical market data.
44 initial_budget (float): Initial budget constant for trader to start from.
45 max_amount_of_trades (int): Max amount of trades that can be ongoing at the same time.
46 Seting this constant prevents traders from placing orders randomly and defines
47 amount of money that can be assigned to a single order at certain iteration.
48 window_size (int): Constant defining how far in the past trader will be able to look
49 into at certain iteration.
50 validator (RewardValidatorBase): Validator implementing policy used to award points
51 for closed trades.
52 label_annotator (LabelAnnotatorBase): Annotator implementing policy used to label
53 data with target values. It is used to provide supervised agents with information
54 about what is the target class value for certain iteration.
55 sell_stop_loss (float): Constant used to define losing boundary at which sell order
56 (short) is closed.
57 sell_take_profit (float): Constant used to define winning boundary at which sell order
58 (short) is closed.
59 buy_stop_loss (float): Constant used to define losing boundary at which buy order
60 (long) is closed.
61 buy_take_profit (float): Constant used to define winning boundary at which buy order
62 (long) is closed.
63 test_ratio (float): Ratio of data that should be used for testing purposes.
64 penalty_starts (int): Constant defining how many trading periods can trader go without placing
65 an order until penalty is imposed. Penalty at range between start and stop constant
66 is calculated as percentile of positive reward, and subtracted from the actual reward.
67 penalty_stops (int): Constant defining at which trading period penalty will no longer be increased.
68 Reward for trading periods exceeding penalty stop constant will equal minus static reward adjustment.
69 static_reward_adjustment (float): Constant use to penalize trader for bad choices or
70 reward it for good one.
71 labeled_data_balancer (Optional[LabeledDataBalancer]): Balancer used to balance
72 labeled data. If None, no balancing will be performed.
73 meta_data (dict[str, Any]): Dictionary containing metadata about the dataset.
74 """

        if test_ratio < 0.0 or test_ratio >= 1.0:
            raise ValueError(f"Invalid test_ratio: {test_ratio}. It should be in range [0, 1).")

        self.__data: dict[str, pd.DataFrame] = self.__split_data(data, test_ratio)
        self.__meta_data: Optional[dict[str, Any]] = meta_data
        self.__mode = TradingEnvironment.TRAIN_MODE
        self.__broker: Broker = Broker()
        self.__validator: RewardValidatorBase = validator
        self.__label_annotator: LabelAnnotatorBase = label_annotator
        self.__labeled_data_balancer: Optional[LabeledDataBalancer] = labeled_data_balancer

        self.__trading_data: SimpleNamespace = SimpleNamespace()
        self.__trading_data.current_budget: float = initial_budget
        self.__trading_data.currently_invested: float = 0
        self.__trading_data.no_trades_placed_for: int = 0
        self.__trading_data.currently_placed_trades: int = 0

        self.__trading_consts = SimpleNamespace()
        self.__trading_consts.INITIAL_BUDGET: float = initial_budget
        self.__trading_consts.MAX_AMOUNT_OF_TRADES: int = max_amount_of_trades
        self.__trading_consts.WINDOW_SIZE: int = window_size
        self.__trading_consts.SELL_STOP_LOSS: float = sell_stop_loss
        self.__trading_consts.SELL_TAKE_PROFIT: float = sell_take_profit
        self.__trading_consts.BUY_STOP_LOSS: float = buy_stop_loss
        self.__trading_consts.BUY_TAKE_PROFIT: float = buy_take_profit
        self.__trading_consts.STATIC_REWARD_ADJUSTMENT: float = static_reward_adjustment
        self.__trading_consts.PENALTY_STARTS: int = penalty_starts
        self.__trading_consts.PENALTY_STOPS: int = penalty_stops
        self.__trading_consts.PROFITABILITY_FUNCTION = lambda x: -1.0 * math.exp(-x + 1) + 1
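        # PROFITABILITY_FUNCTION maps the normalized budget (current_budget / INITIAL_BUDGET) into
        # (-inf, 1): it is 0 when the budget is unchanged (x = 1), approaches 1 as the budget grows
        # (e.g. f(2) is about 0.63), and turns increasingly negative as the budget shrinks
        # (e.g. f(0.5) is about -0.65).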
        self.__trading_consts.PENALTY_FUNCTION = lambda x: \
            min(1, 1 - math.tanh(-3.0 * (x - penalty_stops) / (penalty_stops - penalty_starts)))
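        # PENALTY_FUNCTION maps the number of consecutive periods without a trade to a value in (0, 1]:
        # with the default penalty_starts = 0 and penalty_stops = 10 it is roughly 0.005 at x = 0,
        # roughly 0.1 at x = 5, and saturates at 1 for x >= 10, where positive rewards in step()
        # are fully suppressed.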
        self.__trading_consts.OUTPUT_CLASSES: int = vars(self.__label_annotator.get_output_classes())

        self.current_iteration: int = self.__trading_consts.WINDOW_SIZE
        self.state: list[float] = self.__prepare_state_data()
        self.action_space: Discrete = Discrete(3)
        self.observation_space: Box = Box(low = np.ones(len(self.state)) * -3,
                                          high = np.ones(len(self.state)) * 3,
                                          dtype=np.float64)
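        # The action space encodes 0 = buy, 1 = wait, 2 = sell (see step()). The +/-3 observation
        # bounds assume that the per-window standardized features and the inner-state coefficients
        # typically stay within three standard deviations; they are not a hard guarantee.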

    def __split_data(self, data: pd.DataFrame, test_size: float) -> dict[str, pd.DataFrame]:
        """
        Splits the given DataFrame into training and testing sets based on the specified test size ratio.

        Parameters:
            data (pd.DataFrame): DataFrame containing the stock market data.
            test_size (float): Ratio of the data to be used for testing.

        Returns:
            (dict[str, pd.DataFrame]): Dictionary mapping mode names to the training and testing data frames.
        """

        dividing_index = int(len(data) * (1 - test_size))

        return {
            TradingEnvironment.TRAIN_MODE: data.iloc[:dividing_index].reset_index(drop=True),
            TradingEnvironment.TEST_MODE: data.iloc[dividing_index:].reset_index(drop=True)
        }

    def __prepare_labeled_data(self) -> tuple[pd.DataFrame, pd.Series]:
        """
        Prepares labeled data for training the model with a classification approach.
        It extracts the relevant features and labels from the environment's data.

        Returns:
            (tuple[pd.DataFrame, pd.Series]): A DataFrame containing the features and the
                corresponding labels.
        """

        new_rows = []
        for i in range(self.current_iteration, self.get_environment_length() - 1):
            data_row = self.__prepare_state_data(slice(i - self.__trading_consts.WINDOW_SIZE, i), include_trading_data = False)
            new_rows.append(data_row)

        new_data = pd.DataFrame(new_rows, columns=[f"feature_{i}" for i in range(len(new_rows[0]))])
        labels = self.__label_annotator.annotate(self.__data[self.__mode]).shift(-self.current_iteration)

        return new_data, labels.dropna()

    def __prepare_state_data(self, index: Optional[slice] = None, include_trading_data: bool = True) -> list[float]:
        """
        Calculates state data as a list of floats representing the current iteration's observation.
        The observation contains all input data refined to the window size and a couple of coefficients
        giving insight into the current budget and orders situation.

        Parameters:
            index (Optional[slice]): Slice of the data to build the observation from. Defaults to
                the window ending at the current iteration.
            include_trading_data (bool): Whether to append the budget and order coefficients
                to the observation. Defaults to True.

        Returns:
            (list[float]): List with the current observation for the environment.
        """

        if index is None:
            index = slice(self.current_iteration - self.__trading_consts.WINDOW_SIZE, self.current_iteration)

        current_market_data = self.__data[self.__mode].iloc[index]
        current_market_data_no_index = current_market_data.select_dtypes(include = [np.number])

        if self.__meta_data is not None and \
           self.__meta_data.get('normalization_groups', None) is not None:
            normalization_groups = self.__meta_data['normalization_groups']
            normalized_data_pieces = []
            for normalization_group in normalization_groups:
                columns_to_normalize = current_market_data_no_index[normalization_group]
                normalized_columns = StandardScaler().fit_transform(columns_to_normalize.values.reshape(-1, 1))
                normalized_data_pieces.append(normalized_columns.reshape(*columns_to_normalize.shape))
            normalized_current_market_data_values = np.hstack(normalized_data_pieces)
        else:
            normalized_current_market_data_values = StandardScaler().fit_transform(current_market_data_no_index)
        current_market_data_list = normalized_current_market_data_values.ravel().tolist()

        if include_trading_data:
            current_normalized_budget = 1.0 * self.__trading_data.current_budget / self.__trading_consts.INITIAL_BUDGET
            current_profitability_coeff = self.__trading_consts.PROFITABILITY_FUNCTION(current_normalized_budget)
            current_trades_occupancy_coeff = 1.0 * self.__trading_data.currently_placed_trades / self.__trading_consts.MAX_AMOUNT_OF_TRADES
            current_no_trades_penalty_coeff = self.__trading_consts.PENALTY_FUNCTION(self.__trading_data.no_trades_placed_for)
            current_inner_state_list = [current_profitability_coeff, current_trades_occupancy_coeff, current_no_trades_penalty_coeff]
            current_market_data_list += current_inner_state_list

        return current_market_data_list
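
    # Note: the resulting observation is flat, with length WINDOW_SIZE * number_of_numeric_columns,
    # plus the three budget/occupancy/penalty coefficients when include_trading_data is True;
    # this matches the Box observation_space declared in __init__.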

    def set_mode(self, mode: str) -> None:
        """
        Sets the mode of the environment to either TRAIN_MODE or TEST_MODE.

        Parameters:
            mode (str): Mode to set for the environment.

        Raises:
            ValueError: If the provided mode is not valid.
        """

        if mode not in [TradingEnvironment.TRAIN_MODE, TradingEnvironment.TEST_MODE]:
            raise ValueError(f"Invalid mode: {mode}. Use TradingEnvironment.TRAIN_MODE or TradingEnvironment.TEST_MODE.")
        self.__mode = mode

    def get_mode(self) -> str:
        """
        Mode getter.

        Returns:
            (str): Current mode of the environment.
        """

        return copy.copy(self.__mode)

    def get_trading_data(self) -> SimpleNamespace:
        """
        Trading data getter.

        Returns:
            (SimpleNamespace): Copy of the namespace with all trading data.
        """

        return copy.copy(self.__trading_data)

    def get_trading_consts(self) -> SimpleNamespace:
        """
        Trading constants getter.

        Returns:
            (SimpleNamespace): Copy of the namespace with all trading constants.
        """

        return copy.copy(self.__trading_consts)

    def get_broker(self) -> Broker:
        """
        Broker getter.

        Returns:
            (Broker): Copy of the broker used by the environment.
        """

        return copy.copy(self.__broker)

    def get_environment_length(self) -> int:
        """
        Environment length getter.

        Returns:
            (int): Length of the environment.
        """

        return len(self.__data[self.__mode])

    def get_environment_spatial_data_dimension(self) -> tuple[int, int]:
        """
        Environment spatial data dimensionality getter.

        Returns:
            (tuple[int, int]): Dimension of the spatial data in the environment.
        """

        return (self.__trading_consts.WINDOW_SIZE, self.__data[self.__mode].shape[1] - 1)

    def get_labeled_data(self, should_split: bool = True, should_balance: bool = True,
                         verbose: bool = True) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
        """
        Prepares labeled data for training or testing the model.
        It extracts the relevant features and labels from the environment's data.

        Parameters:
            should_split (bool): Whether to split the data into training and testing sets.
                Defaults to True. If set to False, testing data will be empty.
            should_balance (bool): Whether to balance the labeled data. Defaults to True.
                Will be ignored if labeled_data_balancer is None.
            verbose (bool): Whether to log the class cardinality before and after balancing.
                Defaults to True.

        Returns:
            (tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]): A tuple containing the
                input data, output data, test input data, and test output data.
        """

        input_data, output_data = self.__prepare_labeled_data()
        input_data_test, output_data_test = [], []
        if verbose:
            logging.info(f"Original class cardinality: {np.array(to_categorical(output_data)).sum(axis = 0)}")

        if self.__mode == TradingEnvironment.TRAIN_MODE:
            if should_split:
                input_data, input_data_test, output_data, output_data_test = \
                    train_test_split(input_data, output_data, test_size = 0.1, random_state = 42,
                                     stratify = output_data)

            if self.__labeled_data_balancer is not None and should_balance:
                input_data, output_data = self.__labeled_data_balancer.balance(input_data, output_data)
                if verbose:
                    logging.info(f"Balanced class cardinality: {np.array(to_categorical(output_data)).sum(axis = 0)}")

        return copy.copy((np.array(input_data), np.array(output_data),
                          np.array(input_data_test), np.array(output_data_test)))
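
    # Illustrative use of the labeled-data path (a sketch only; `env` is a constructed
    # TradingEnvironment and `model` a compiled Keras classifier, both assumed to exist):
    #     x_train, y_train, x_test, y_test = env.get_labeled_data()
    #     model.fit(x_train, to_categorical(y_train),
    #               validation_data = (x_test, to_categorical(y_test)))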

    def get_data_for_iteration(self, columns: list[str], start: int = 0, stop: Optional[int] = None,
                               step: int = 1) -> list[float]:
        """
        Data getter for specific iterations.

        Parameters:
            columns (list[str]): List of column names to extract from the data.
            start (int): Start iteration index. Defaults to 0.
            stop (int): Stop iteration index. Defaults to the environment length minus one.
            step (int): Step between iterations. Defaults to 1.

        Returns:
            (list[float]): Copy of the part of the data with the specified columns
                over the specified iterations.
        """

        if stop is None:
            stop = self.get_environment_length() - 1

        return copy.copy(self.__data[self.__mode].loc[start:stop:step, columns].values.ravel().tolist())

    def step(self, action: int) -> tuple[list[float], float, bool, dict]:
        """
        Performs the specified action on the environment, which results in the generation of a new
        observation. This function causes trades to be handled, the reward to be calculated and
        the environment to be updated.

        Parameters:
            action (int): Number specifying the action. Possible values are 0 for the buy action,
                1 for the wait action and 2 for the sell action.

        Returns:
            (tuple[list[float], float, bool, dict]): Tuple containing the next observation
                state, reward, finish indication and an additional info dictionary.
        """

        self.current_iteration += 1
        self.state = self.__prepare_state_data()

        close_changes = self.__data[self.__mode].iloc[self.current_iteration - 2 : self.current_iteration]['close'].values
        stock_change_coeff = 1 + (close_changes[1] - close_changes[0]) / close_changes[0]
        closed_orders = self.__broker.update_orders(stock_change_coeff)

        reward = self.__validator.validate_orders(closed_orders)
        self.__trading_data.currently_placed_trades -= len(closed_orders)
        self.__trading_data.current_budget += np.sum([trade.current_value for trade in closed_orders])
        self.__trading_data.currently_invested -= np.sum([trade.initial_value for trade in closed_orders])

        number_of_possible_trades = self.__trading_consts.MAX_AMOUNT_OF_TRADES - self.__trading_data.currently_placed_trades
        money_to_trade = 0
        if number_of_possible_trades > 0:
            money_to_trade = 1.0 / number_of_possible_trades * self.__trading_data.current_budget

        if action == 0:
            is_buy_order = True
            stop_loss = self.__trading_consts.BUY_STOP_LOSS
            take_profit = self.__trading_consts.BUY_TAKE_PROFIT
        elif action == 2:
            is_buy_order = False
            stop_loss = self.__trading_consts.SELL_STOP_LOSS
            take_profit = self.__trading_consts.SELL_TAKE_PROFIT

        if action != 1:
            if number_of_possible_trades > 0:
                self.__trading_data.current_budget -= money_to_trade
                self.__trading_data.currently_invested += money_to_trade
                self.__broker.place_order(money_to_trade, is_buy_order, stop_loss, take_profit)
                self.__trading_data.currently_placed_trades += 1
                self.__trading_data.no_trades_placed_for = 0
                reward += self.__trading_consts.STATIC_REWARD_ADJUSTMENT
            else:
                self.__trading_data.no_trades_placed_for += 1
                reward -= self.__trading_consts.STATIC_REWARD_ADJUSTMENT
        else:
            self.__trading_data.no_trades_placed_for += 1
            if number_of_possible_trades == 0:
                reward += self.__trading_consts.STATIC_REWARD_ADJUSTMENT

        if number_of_possible_trades > 0:
            reward *= (1 - self.__trading_consts.PENALTY_FUNCTION(self.__trading_data.no_trades_placed_for)) \
                if reward > 0 else 1
            if self.__trading_consts.PENALTY_STOPS < self.__trading_data.no_trades_placed_for:
                reward -= self.__trading_consts.STATIC_REWARD_ADJUSTMENT

        if (self.current_iteration >= self.get_environment_length() - 1 or
                self.__trading_data.current_budget > 10 * self.__trading_consts.INITIAL_BUDGET or
                (self.__trading_data.current_budget + self.__trading_data.currently_invested) / self.__trading_consts.INITIAL_BUDGET < 0.8):
            done = True
        else:
            done = False

        info = {'coeff': stock_change_coeff,
                'iteration': self.current_iteration,
                'number_of_closed_orders': len(closed_orders),
                'money_to_trade': money_to_trade,
                'action': action,
                'current_budget': self.__trading_data.current_budget,
                'currently_invested': self.__trading_data.currently_invested,
                'no_trades_placed_for': self.__trading_data.no_trades_placed_for,
                'currently_placed_trades': self.__trading_data.currently_placed_trades}

        return self.state, reward, done, info

    def render(self) -> None:
        """
        Renders environment visualization. Will be implemented later.
        """

        # TODO: Visualization to be implemented
        pass

    def reset(self, randkey: Optional[int] = None) -> list[float]:
        """
        Resets the environment. Typically used when the environment is finished,
        i.e. when there are no more steps to be taken within the environment
        or the finish conditions are fulfilled.

        Parameters:
            randkey (Optional[int]): Value indicating which iteration
                should be treated as the starting point after the reset.

        Returns:
            (list[float]): Current iteration observation state.
        """

        if randkey is None:
            randkey = random.randint(self.__trading_consts.WINDOW_SIZE, self.get_environment_length() - 1)
        self.__trading_data.current_budget = self.__trading_consts.INITIAL_BUDGET
        self.__trading_data.currently_invested = 0
        self.__trading_data.no_trades_placed_for = 0
        self.__trading_data.currently_placed_trades = 0
        self.__broker.reset()
        self.current_iteration = randkey
        self.state = self.__prepare_state_data()

        return self.state
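
# Example of the reinforcement-learning loop this environment supports (an illustrative sketch only;
# `df` stands in for a DataFrame with numeric market columns including 'close', `validator` and
# `annotator` for concrete RewardValidatorBase / LabelAnnotatorBase implementations, and the
# threshold values below are placeholders rather than recommended settings):
#
#     env = TradingEnvironment(data = df, initial_budget = 10000, max_amount_of_trades = 5,
#                              window_size = 30, validator = validator, label_annotator = annotator,
#                              sell_stop_loss = 0.95, sell_take_profit = 1.05,
#                              buy_stop_loss = 0.95, buy_take_profit = 1.05)
#     state = env.reset()
#     done = False
#     while not done:
#         action = env.action_space.sample()    # 0 = buy, 1 = wait, 2 = sell
#         state, reward, done, info = env.step(action)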