Coverage for source/environment/trading_environment.py: 88%

178 statements  

coverage.py v7.8.0, created at 2025-08-01 20:51 +0000

# environment/trading_environment.py

# global imports
import copy
import logging
import math
import numpy as np
import pandas as pd
import random
from gym import Env
from gym.spaces import Box, Discrete
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
from tensorflow.keras.utils import to_categorical
from types import SimpleNamespace
from typing import Any, Optional

# local imports
from source.environment import Broker, LabelAnnotatorBase, LabeledDataBalancer, RewardValidatorBase


class TradingEnvironment(Env):
    """
    Implements a stock market environment in which an actor can perform actions (place orders).
    It is used to train various models using various approaches and can be
    configured to award points and impose penalties in several ways.
    """


    # global class constants
    TRAIN_MODE = 'train'
    TEST_MODE = 'test'

    def __init__(self, data: pd.DataFrame, initial_budget: float, max_amount_of_trades: int,
                 window_size: int, validator: RewardValidatorBase, label_annotator: LabelAnnotatorBase,
                 sell_stop_loss: float, sell_take_profit: float, buy_stop_loss: float, buy_take_profit: float,
                 test_ratio: float = 0.2, penalty_starts: int = 0, penalty_stops: int = 10,
                 static_reward_adjustment: float = 1, labeled_data_balancer: Optional[LabeledDataBalancer] = None,
                 meta_data: Optional[dict[str, Any]] = None) -> None:

38 """ 

39 Class constructor. Allows to define all crucial constans, reward validation methods, 

40 environmental penalty policies, etc. 

41 

42 Parameters: 

43 data (pd.DataFrame): DataFrame containing historical market data. 

44 initial_budget (float): Initial budget constant for trader to start from. 

45 max_amount_of_trades (int): Max amount of trades that can be ongoing at the same time. 

46 Seting this constant prevents traders from placing orders randomly and defines 

47 amount of money that can be assigned to a single order at certain iteration. 

48 window_size (int): Constant defining how far in the past trader will be able to look 

49 into at certain iteration. 

50 validator (RewardValidatorBase): Validator implementing policy used to award points 

51 for closed trades. 

52 label_annotator (LabelAnnotatorBase): Annotator implementing policy used to label 

53 data with target values. It is used to provide supervised agents with information 

54 about what is the target class value for certain iteration. 

55 sell_stop_loss (float): Constant used to define losing boundary at which sell order 

56 (short) is closed. 

57 sell_take_profit (float): Constant used to define winning boundary at which sell order 

58 (short) is closed. 

59 buy_stop_loss (float): Constant used to define losing boundary at which buy order 

60 (long) is closed. 

61 buy_take_profit (float): Constant used to define winning boundary at which buy order 

62 (long) is closed. 

63 test_ratio (float): Ratio of data that should be used for testing purposes. 

64 penalty_starts (int): Constant defining how many trading periods can trader go without placing 

65 an order until penalty is imposed. Penalty at range between start and stop constant 

66 is calculated as percentile of positive reward, and subtracted from the actual reward. 

67 penalty_stops (int): Constant defining at which trading period penalty will no longer be increased. 

68 Reward for trading periods exceeding penalty stop constant will equal minus static reward adjustment. 

69 static_reward_adjustment (float): Constant use to penalize trader for bad choices or 

70 reward it for good one. 

71 labeled_data_balancer (Optional[LabeledDataBalancer]): Balancer used to balance 

72 labeled data. If None, no balancing will be performed. 

73 meta_data (dict[str, Any]): Dictionary containing metadata about the dataset. 

74 """ 


        if test_ratio < 0.0 or test_ratio >= 1.0:
            raise ValueError(f"Invalid test_ratio: {test_ratio}. It should be in range [0, 1).")

        self.__data: dict[str, pd.DataFrame] = self.__split_data(data, test_ratio)
        self.__meta_data: Optional[dict[str, Any]] = meta_data
        self.__mode = TradingEnvironment.TRAIN_MODE
        self.__broker: Broker = Broker()
        self.__validator: RewardValidatorBase = validator
        self.__label_annotator: LabelAnnotatorBase = label_annotator
        self.__labeled_data_balancer: Optional[LabeledDataBalancer] = labeled_data_balancer

        self.__trading_data: SimpleNamespace = SimpleNamespace()
        self.__trading_data.current_budget: float = initial_budget
        self.__trading_data.currently_invested: float = 0
        self.__trading_data.no_trades_placed_for: int = 0
        self.__trading_data.currently_placed_trades: int = 0

        self.__trading_consts = SimpleNamespace()
        self.__trading_consts.INITIAL_BUDGET: float = initial_budget
        self.__trading_consts.MAX_AMOUNT_OF_TRADES: int = max_amount_of_trades
        self.__trading_consts.WINDOW_SIZE: int = window_size
        self.__trading_consts.SELL_STOP_LOSS: float = sell_stop_loss
        self.__trading_consts.SELL_TAKE_PROFIT: float = sell_take_profit
        self.__trading_consts.BUY_STOP_LOSS: float = buy_stop_loss
        self.__trading_consts.BUY_TAKE_PROFIT: float = buy_take_profit
        self.__trading_consts.STATIC_REWARD_ADJUSTMENT: float = static_reward_adjustment
        self.__trading_consts.PENALTY_STARTS: int = penalty_starts
        self.__trading_consts.PENALTY_STOPS: int = penalty_stops
        self.__trading_consts.PROFITABILITY_FUNCTION = lambda x: -1.0 * math.exp(-x + 1) + 1
        self.__trading_consts.PENALTY_FUNCTION = lambda x: \
            min(1, 1 - math.tanh(-3.0 * (x - penalty_stops) / (penalty_stops - penalty_starts)))
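        # Shape of the helper curves above (descriptive note, with illustrative values):
        # PROFITABILITY_FUNCTION maps the normalized budget x = budget / INITIAL_BUDGET to roughly
        # 0 at x = 1 (break-even), approaches 1 as the budget grows, and becomes steeply negative
        # as the budget shrinks. PENALTY_FUNCTION maps the number of periods without a trade to a
        # coefficient that is close to 0 at penalty_starts, reaches 1 at penalty_stops, and is
        # clamped to 1 beyond that point.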

        self.__trading_consts.OUTPUT_CLASSES: dict = vars(self.__label_annotator.get_output_classes())

        self.current_iteration: int = self.__trading_consts.WINDOW_SIZE
        self.state: list[float] = self.__prepare_state_data()
        self.action_space: Discrete = Discrete(3)
        self.observation_space: Box = Box(low = np.ones(len(self.state)) * -3,
                                          high = np.ones(len(self.state)) * 3,
                                          dtype=np.float64)


    def __split_data(self, data: pd.DataFrame, test_size: float) -> dict[str, pd.DataFrame]:
        """
        Splits the given DataFrame into training and testing sets based on the specified test size ratio.

        Parameters:
            data (pd.DataFrame): DataFrame containing the stock market data.
            test_size (float): Ratio of the data to be used for testing.

        Returns:
            (dict[str, pd.DataFrame]): Dictionary containing the training and testing data frames.
        """

        dividing_index = int(len(data) * (1 - test_size))
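        # The split is chronological (no shuffling): e.g. with 1000 rows and test_size = 0.2,
        # the first 800 rows become the training set and the last 200 rows the test set.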


        return {
            TradingEnvironment.TRAIN_MODE: data.iloc[:dividing_index].reset_index(drop=True),
            TradingEnvironment.TEST_MODE: data.iloc[dividing_index:].reset_index(drop=True)
        }


    def __prepare_labeled_data(self) -> tuple[pd.DataFrame, pd.Series]:
        """
        Prepares labeled data for training the model with a classification approach.
        It extracts the relevant features and labels from the environment's data.

        Returns:
            (tuple[pd.DataFrame, pd.Series]): A DataFrame containing the features and a Series
                containing the corresponding labels.
        """

        new_rows = []
        for i in range(self.current_iteration, self.get_environment_length() - 1):
            data_row = self.__prepare_state_data(slice(i - self.__trading_consts.WINDOW_SIZE, i), include_trading_data = False)
            new_rows.append(data_row)

        new_data = pd.DataFrame(new_rows, columns=[f"feature_{i}" for i in range(len(new_rows[0]))])
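        # Labels are shifted back by current_iteration rows so that each feature window
        # (rows i - WINDOW_SIZE .. i - 1) is paired with the label of row i; dropna()
        # removes the unlabeled tail introduced by the shift.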

        labels = self.__label_annotator.annotate(self.__data[self.__mode].copy()).shift(-self.current_iteration)

        return new_data, labels.dropna()

    def __prepare_state_data(self, index: Optional[slice] = None, include_trading_data: bool = True) -> list[float]:
        """
        Calculates state data as a list of floats representing the current iteration's observation.
        The observation contains all input data refined to the window size and a couple of coefficients
        giving an insight into the current budget and orders situation.

        Parameters:
            index (Optional[slice]): Slice of rows to build the observation from. Defaults to the
                window ending at the current iteration.
            include_trading_data (bool): Whether to append the budget and order coefficients
                to the observation. Defaults to True.

        Returns:
            (list[float]): List with current observations for the environment.
        """


        if index is None:
            index = slice(self.current_iteration - self.__trading_consts.WINDOW_SIZE, self.current_iteration)

        current_market_data = self.__data[self.__mode].iloc[index]
        current_market_data_no_index = current_market_data.select_dtypes(include = [np.number])
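        # Normalization note: when meta_data provides 'normalization_groups', each listed group of
        # columns is standardized with a single shared scaler fitted over all of the group's values,
        # the remaining columns are passed through unscaled, and the pieces are stacked group-first.
        # Without such groups, every numeric column is standardized independently.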


        if self.__meta_data is not None and \
           self.__meta_data.get('normalization_groups', None) is not None:
            grouped_columns_names = self.__meta_data['normalization_groups']
            preprocessed_data_pieces = []
            left_over_columns_names = set(current_market_data_no_index.columns)
            for columns_names_to_normalize in grouped_columns_names:
                left_over_columns_names -= set(columns_names_to_normalize)
                data_frame_piece_to_normalize = current_market_data_no_index[columns_names_to_normalize]
                normalized_data_frame_piece = StandardScaler().fit_transform(data_frame_piece_to_normalize.values.reshape(-1, 1))
                preprocessed_data_pieces.append(normalized_data_frame_piece.reshape(*data_frame_piece_to_normalize.shape))
            for column in left_over_columns_names:
                preprocessed_data_pieces.append(current_market_data_no_index[column].values.reshape(-1, 1))
            normalized_current_market_data_values = np.hstack(preprocessed_data_pieces)
        else:
            normalized_current_market_data_values = StandardScaler().fit_transform(current_market_data_no_index)
        current_market_data_list = normalized_current_market_data_values.ravel().tolist()

        if include_trading_data:
            current_normalized_budget = 1.0 * self.__trading_data.current_budget / self.__trading_consts.INITIAL_BUDGET
            current_profitability_coeff = self.__trading_consts.PROFITABILITY_FUNCTION(current_normalized_budget)
            current_trades_occupancy_coeff = 1.0 * self.__trading_data.currently_placed_trades / self.__trading_consts.MAX_AMOUNT_OF_TRADES
            current_no_trades_penalty_coeff = self.__trading_consts.PENALTY_FUNCTION(self.__trading_data.no_trades_placed_for)
            current_inner_state_list = [current_profitability_coeff, current_trades_occupancy_coeff, current_no_trades_penalty_coeff]
            current_market_data_list += current_inner_state_list

        return current_market_data_list


    def set_mode(self, mode: str) -> None:
        """
        Sets the mode of the environment to either TRAIN_MODE or TEST_MODE.

        Parameters:
            mode (str): Mode to set for the environment.

        Raises:
            ValueError: If the provided mode is not valid.
        """

        if mode not in [TradingEnvironment.TRAIN_MODE, TradingEnvironment.TEST_MODE]:
            raise ValueError(f"Invalid mode: {mode}. Use TradingEnvironment.TRAIN_MODE or TradingEnvironment.TEST_MODE.")
        self.__mode = mode

    def get_mode(self) -> str:
        """
        Mode getter.

        Returns:
            (str): Current mode of the environment.
        """

        return copy.copy(self.__mode)

    def get_trading_data(self) -> SimpleNamespace:
        """
        Trading data getter.

        Returns:
            (SimpleNamespace): Copy of the namespace with all trading data.
        """

        return copy.copy(self.__trading_data)

    def get_trading_consts(self) -> SimpleNamespace:
        """
        Trading constants getter.

        Returns:
            (SimpleNamespace): Copy of the namespace with all trading constants.
        """

        return copy.copy(self.__trading_consts)

    def get_broker(self) -> Broker:
        """
        Broker getter.

        Returns:
            (Broker): Copy of the broker used by the environment.
        """

        return copy.copy(self.__broker)


    def get_environment_length(self) -> int:
        """
        Environment length getter.

        Returns:
            (int): Number of rows in the currently active data split.
        """

        return len(self.__data[self.__mode])


    def get_environment_spatial_data_dimension(self) -> tuple[int, int]:
        """
        Environment spatial data dimensionality getter.

        Returns:
            (tuple[int, int]): Dimensions of the spatial data in the environment,
                i.e. (window size, number of data columns minus one).
        """

        return (self.__trading_consts.WINDOW_SIZE, self.__data[self.__mode].shape[1] - 1)


    def get_labeled_data(self, should_split: bool = True, should_balance: bool = True,
                         verbose: bool = True) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
        """
        Prepares labeled data for training or testing the model.
        It extracts the relevant features and labels from the environment's data.

        Parameters:
            should_split (bool): Whether to split the data into training and testing sets.
                Defaults to True. If set to False, testing data will be empty.
            should_balance (bool): Whether to balance the labeled data. Defaults to True.
                Will be ignored if labeled_data_balancer is None.
            verbose (bool): Whether to log the class cardinality before and after balancing.
                Defaults to True.

        Returns:
            (tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]): A tuple containing the
                input data, output data, test input data, and test output data.
        """

        input_data, output_data = self.__prepare_labeled_data()
        input_data_test, output_data_test = [], []
        if verbose:
            logging.info(f"Original class cardinality: {np.array(to_categorical(output_data)).sum(axis = 0)}")

        if self.__mode == TradingEnvironment.TRAIN_MODE:
            if should_split:
                input_data, input_data_test, output_data, output_data_test = \
                    train_test_split(input_data, output_data, test_size = 0.1, random_state = 42,
                                     stratify = output_data)

            if self.__labeled_data_balancer is not None and should_balance:
                input_data, output_data = self.__labeled_data_balancer.balance(input_data, output_data)
                if verbose:
                    logging.info(f"Balanced class cardinality: {np.array(to_categorical(output_data)).sum(axis = 0)}")

        return copy.copy((np.array(input_data), np.array(output_data),
                          np.array(input_data_test), np.array(output_data_test)))


    def get_data_for_iteration(self, columns: list[str], start: int = 0, stop: Optional[int] = None,
                               step: int = 1) -> list[float]:
        """
        Data getter for certain iterations.

        Parameters:
            columns (list[str]): List of column names to extract from data.
            start (int): Start iteration index. Defaults to 0.
            stop (Optional[int]): Stop iteration index. Defaults to environment length minus one.
            step (int): Step between iterations. Defaults to 1.

        Returns:
            (list[float]): Copy of part of the data with the specified columns
                over the specified iterations.
        """

        if stop is None:
            stop = self.get_environment_length() - 1

        return copy.copy(self.__data[self.__mode].loc[start:stop:step, columns].values.ravel().tolist())


    def step(self, action: int) -> tuple[list[float], float, bool, dict]:
        """
        Performs the specified action on the environment. It results in the generation of a new
        observation. This function causes trades to be handled, the reward to be calculated and
        the environment to be updated.

        Parameters:
            action (int): Integer specifying the action. Possible values are 0 for the buy action,
                1 for the wait action and 2 for the sell action.

        Returns:
            (tuple[list[float], float, bool, dict]): Tuple containing the next observation
                state, reward, finish indication and additional info dictionary.
        """


        self.current_iteration += 1
        self.state = self.__prepare_state_data()

        close_changes = self.__data[self.__mode].iloc[self.current_iteration - 2 : self.current_iteration]['close'].values
        stock_change_coeff = 1 + (close_changes[1] - close_changes[0]) / close_changes[0]
        closed_orders = self.__broker.update_orders(stock_change_coeff)

        reward = self.__validator.validate_orders(closed_orders)
        self.__trading_data.currently_placed_trades -= len(closed_orders)
        self.__trading_data.current_budget += np.sum([trade.current_value for trade in closed_orders])
        self.__trading_data.currently_invested -= np.sum([trade.initial_value for trade in closed_orders])

        number_of_possible_trades = self.__trading_consts.MAX_AMOUNT_OF_TRADES - self.__trading_data.currently_placed_trades
        money_to_trade = 0
        if number_of_possible_trades > 0:
            money_to_trade = 1.0 / number_of_possible_trades * self.__trading_data.current_budget

        if action == 0:
            is_buy_order = True
            stop_loss = self.__trading_consts.BUY_STOP_LOSS
            take_profit = self.__trading_consts.BUY_TAKE_PROFIT
        elif action == 2:
            is_buy_order = False
            stop_loss = self.__trading_consts.SELL_STOP_LOSS
            take_profit = self.__trading_consts.SELL_TAKE_PROFIT


        if action != 1:
            if number_of_possible_trades > 0:
                self.__trading_data.current_budget -= money_to_trade
                self.__trading_data.currently_invested += money_to_trade
                self.__broker.place_order(money_to_trade, is_buy_order, stop_loss, take_profit)
                self.__trading_data.currently_placed_trades += 1
                self.__trading_data.no_trades_placed_for = 0
                reward += self.__trading_consts.STATIC_REWARD_ADJUSTMENT
            else:
                self.__trading_data.no_trades_placed_for += 1
                reward -= self.__trading_consts.STATIC_REWARD_ADJUSTMENT
        else:
            self.__trading_data.no_trades_placed_for += 1
            if number_of_possible_trades == 0:
                reward += self.__trading_consts.STATIC_REWARD_ADJUSTMENT

        if number_of_possible_trades > 0:
            reward *= (1 - self.__trading_consts.PENALTY_FUNCTION(self.__trading_data.no_trades_placed_for)) \
                if reward > 0 else 1
            if self.__trading_consts.PENALTY_STOPS < self.__trading_data.no_trades_placed_for:
                reward -= self.__trading_consts.STATIC_REWARD_ADJUSTMENT


        if (self.current_iteration >= self.get_environment_length() - 1 or
                self.__trading_data.current_budget > 10 * self.__trading_consts.INITIAL_BUDGET or
                (self.__trading_data.current_budget + self.__trading_data.currently_invested) / self.__trading_consts.INITIAL_BUDGET < 0.8):
            done = True
        else:
            done = False

        info = {'coeff': stock_change_coeff,
                'iteration': self.current_iteration,
                'number_of_closed_orders': len(closed_orders),
                'money_to_trade': money_to_trade,
                'action': action,
                'current_budget': self.__trading_data.current_budget,
                'currently_invested': self.__trading_data.currently_invested,
                'no_trades_placed_for': self.__trading_data.no_trades_placed_for,
                'currently_placed_trades': self.__trading_data.currently_placed_trades}

        return self.state, reward, done, info

    def render(self) -> None:
        """
        Renders environment visualization. Will be implemented later.
        """

        #TODO: Visualization to be implemented
        pass

    def reset(self, randkey: Optional[int] = None) -> list[float]:
        """
        Resets the environment. Typically used when the environment is finished,
        i.e. when there are no more steps to be taken within the environment
        or the finish conditions are fulfilled.

        Parameters:
            randkey (Optional[int]): Value indicating which iteration
                should be treated as the starting point after the reset.

        Returns:
            (list[float]): Current iteration observation state.
        """

        if randkey is None:
            randkey = random.randint(self.__trading_consts.WINDOW_SIZE, self.get_environment_length() - 1)
        self.__trading_data.current_budget = self.__trading_consts.INITIAL_BUDGET
        self.__trading_data.currently_invested = 0
        self.__trading_data.no_trades_placed_for = 0
        self.__trading_data.currently_placed_trades = 0
        self.__broker.reset()
        self.current_iteration = randkey
        self.state = self.__prepare_state_data()

        return self.state
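
# Minimal usage sketch (illustrative only, not part of the original module): the concrete
# validator, annotator and agent below are placeholders assumed to be implemented elsewhere,
# and the numeric arguments are arbitrary example values.
#
# env = TradingEnvironment(data=market_df, initial_budget=10_000, max_amount_of_trades=5,
#                          window_size=30, validator=my_validator, label_annotator=my_annotator,
#                          sell_stop_loss=0.95, sell_take_profit=1.05,
#                          buy_stop_loss=0.95, buy_take_profit=1.05)
# state = env.reset()
# done = False
# while not done:
#     action = my_agent.predict(state)  # 0 = buy, 1 = wait, 2 = sell
#     state, reward, done, info = env.step(action)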