Coverage for source/environment/trading_environment.py: 81%

232 statements  

coverage.py v7.8.0, created at 2025-09-29 20:04 +0000

# environment/trading_environment.py

# global imports
import copy
import logging
import math
import numpy as np
import pandas as pd
import random
from enum import Enum
from gym import Env
from gym.spaces import Box, Discrete
from sklearn.model_selection import train_test_split
from tensorflow.keras.utils import to_categorical
from types import SimpleNamespace
from typing import Any, Optional

# local imports
from source.environment import Broker, LabelAnnotatorBase, LabeledDataBalancer, RewardValidatorBase

class TradingEnvironment(Env):
    """
    Implements a stock market environment in which an actor can perform actions (place orders).
    It is used to train various models using various approaches. It can be
    configured to award points and impose penalties in several ways.
    """

    class TradingMode(Enum):
        """
        Enumeration of the different trading modes.
        """

        IMPLICIT_ORDER_CLOSING = 0
        EXPLICIT_ORDER_CLOSING = 1

    # global class constants
    TRAIN_MODE = 'train'
    TEST_MODE = 'test'

    def __init__(self, data: pd.DataFrame, initial_budget: float, max_amount_of_trades: int,
                 window_size: int, validator: RewardValidatorBase, label_annotator: LabelAnnotatorBase,
                 sell_stop_loss: float, sell_take_profit: float, buy_stop_loss: float, buy_take_profit: float,
                 test_ratio: float = 0.2, penalty_starts: int = 0, penalty_stops: int = 10,
                 static_reward_adjustment: float = 1, labeled_data_balancer: Optional[LabeledDataBalancer] = None,
                 meta_data: Optional[dict[str, Any]] = None, trading_mode: Optional[TradingMode] = None,
                 should_prefetch: bool = True) -> None:
        """
        Class constructor. Allows defining all crucial constants, reward validation methods,
        environmental penalty policies, etc.

        Parameters:
            data (pd.DataFrame): DataFrame containing historical market data.
            initial_budget (float): Initial budget for the trader to start from.
            max_amount_of_trades (int): Maximum number of trades that can be ongoing at the
                same time. Setting this constant prevents traders from placing orders randomly
                and defines the amount of money that can be assigned to a single order at a
                given iteration.
            window_size (int): Constant defining how far into the past the trader can look
                at a given iteration.
            validator (RewardValidatorBase): Validator implementing the policy used to award
                points for closed trades.
            label_annotator (LabelAnnotatorBase): Annotator implementing the policy used to
                label data with target values. It provides supervised agents with information
                about the target class value at a given iteration.
            sell_stop_loss (float): Constant defining the losing boundary at which a sell order
                (short) is closed.
            sell_take_profit (float): Constant defining the winning boundary at which a sell order
                (short) is closed.
            buy_stop_loss (float): Constant defining the losing boundary at which a buy order
                (long) is closed.
            buy_take_profit (float): Constant defining the winning boundary at which a buy order
                (long) is closed.
            test_ratio (float): Ratio of the data that should be used for testing purposes.
            penalty_starts (int): Constant defining how many trading periods the trader can go
                without placing an order before a penalty is imposed. Between the start and stop
                constants, the penalty is calculated as a fraction of the positive reward and
                subtracted from the actual reward.
            penalty_stops (int): Constant defining the trading period at which the penalty no
                longer increases. The reward for trading periods exceeding this constant will
                equal minus the static reward adjustment.
            static_reward_adjustment (float): Constant used to penalize the trader for bad
                choices or reward it for good ones.
            labeled_data_balancer (Optional[LabeledDataBalancer]): Balancer used to balance
                labeled data. If None, no balancing is performed.
            meta_data (Optional[dict[str, Any]]): Dictionary containing metadata about the dataset.
            trading_mode (Optional[TradingMode]): Mode of the environment, either
                IMPLICIT_ORDER_CLOSING or EXPLICIT_ORDER_CLOSING. Defaults to IMPLICIT_ORDER_CLOSING.
            should_prefetch (bool): If True, data will be pre-fetched to speed up training.
        """

        if test_ratio < 0.0 or test_ratio >= 1.0:
            raise ValueError(f"Invalid test_ratio: {test_ratio}. It should be in range [0, 1).")

        if trading_mode is None:
            trading_mode = TradingEnvironment.TradingMode.IMPLICIT_ORDER_CLOSING
        elif isinstance(trading_mode, int):
            trading_mode = TradingEnvironment.TradingMode(trading_mode)

        if not isinstance(trading_mode, TradingEnvironment.TradingMode):
            raise ValueError(f"Invalid trading_mode: {trading_mode}. It should be of type TradingEnvironment.TradingMode or int.")

        # Initializing the environment
        self.__data: dict[str, pd.DataFrame] = self.__split_data(data, test_ratio)
        self.__meta_data: Optional[dict[str, Any]] = meta_data
        self.__mode = TradingEnvironment.TRAIN_MODE
        self.__trading_mode: TradingEnvironment.TradingMode = trading_mode
        self.__should_prefetch: bool = should_prefetch
        self.__broker: Broker = Broker()
        self.__validator: RewardValidatorBase = validator
        self.__label_annotator: LabelAnnotatorBase = label_annotator
        self.__labeled_data_balancer: Optional[LabeledDataBalancer] = labeled_data_balancer

        # Setting up trading data
        self.__trading_data: SimpleNamespace = SimpleNamespace()
        self.__trading_data.current_budget: float = initial_budget
        self.__trading_data.currently_invested: float = 0
        self.__trading_data.no_trades_placed_for: int = 0
        self.__trading_data.currently_placed_long_trades: int = 0
        self.__trading_data.currently_placed_short_trades: int = 0

        # Setting up trading constants
        self.__trading_consts = SimpleNamespace()
        self.__trading_consts.INITIAL_BUDGET: float = initial_budget
        self.__trading_consts.MAX_AMOUNT_OF_TRADES: int = max_amount_of_trades
        self.__trading_consts.WINDOW_SIZE: int = window_size
        self.__trading_consts.SELL_STOP_LOSS: float = sell_stop_loss
        self.__trading_consts.SELL_TAKE_PROFIT: float = sell_take_profit
        self.__trading_consts.BUY_STOP_LOSS: float = buy_stop_loss
        self.__trading_consts.BUY_TAKE_PROFIT: float = buy_take_profit
        self.__trading_consts.STATIC_REWARD_ADJUSTMENT: float = static_reward_adjustment
        self.__trading_consts.PENALTY_STARTS: int = penalty_starts
        self.__trading_consts.PENALTY_STOPS: int = penalty_stops
        self.__trading_consts.PROFITABILITY_FUNCTION = lambda x: -1.0 * math.exp(-x + 1) + 1
        self.__trading_consts.PENALTY_FUNCTION = lambda x: \
            min(1, 1 - math.tanh(-3.0 * (x - penalty_stops) / (penalty_stops - penalty_starts)))
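        # Behavior of the two shaping functions above, derived from the formulas:
        # PROFITABILITY_FUNCTION maps the normalized budget x to a coefficient that
        # is 0 at x = 1 (budget equals the initial budget), negative below it, and
        # approaches 1 as the budget grows. PENALTY_FUNCTION maps the number of
        # consecutive periods without a trade to a factor close to 0 at
        # x = penalty_starts that grows along a tanh curve and saturates (via the
        # min) at 1 for x >= penalty_stops; e.g. with penalty_starts = 0 and
        # penalty_stops = 10: f(0) ~= 0.005, f(5) = 1 - tanh(1.5) ~= 0.095, f(10) = 1.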

        self.__trading_consts.OUTPUT_CLASSES: dict[str, Any] = vars(self.__label_annotator.get_output_classes())

        # Prefetching data if needed
        if self.__should_prefetch:
            self.__prefetched_data = { TradingEnvironment.TRAIN_MODE: None,
                                       TradingEnvironment.TEST_MODE: None }
            self.__mode = TradingEnvironment.TEST_MODE
            self.__prefetched_data[self.__mode] = self.__prefetch_state_data(
                env_length_range = (self.__trading_consts.WINDOW_SIZE, self.get_environment_length()))
            self.__mode = TradingEnvironment.TRAIN_MODE
            self.__prefetched_data[self.__mode] = self.__prefetch_state_data(
                env_length_range = (self.__trading_consts.WINDOW_SIZE, self.get_environment_length()))
        else:
            self.__prefetched_data = None

        # Initializing the environment state
        self.current_iteration: int = self.__trading_consts.WINDOW_SIZE
        self.state: list[float] = self.__get_current_state_data()
        self.action_space: Discrete = Discrete(3)
        self.observation_space: Box = Box(low = np.ones(len(self.state)) * -3,
                                          high = np.ones(len(self.state)) * 3,
                                          dtype = np.float64)

    def __split_data(self, data: pd.DataFrame, test_size: float) -> dict[str, pd.DataFrame]:
        """
        Splits the given DataFrame into training and testing sets based on the specified test size ratio.

        Parameters:
            data (pd.DataFrame): DataFrame containing the stock market data.
            test_size (float): Ratio of the data to be used for testing.

        Returns:
            (dict[str, pd.DataFrame]): Dictionary containing the training and testing data frames.
        """

        dividing_index = int(len(data) * (1 - test_size))

        return {
            TradingEnvironment.TRAIN_MODE: data.iloc[:dividing_index].reset_index(drop=True),
            TradingEnvironment.TEST_MODE: data.iloc[dividing_index:].reset_index(drop=True)
        }

    def __standard_scale(self, data: np.ndarray) -> np.ndarray:
        """
        Standardizes the given data by removing the mean and scaling to unit variance.

        Parameters:
            data (np.ndarray): The data to be standardized.

        Returns:
            (np.ndarray): The standardized data.
        """

        mean = np.mean(data, axis = 0, keepdims = True)
        std = np.std(data, axis = 0, keepdims = True)
        std[std == 0] = 1

        return (data - mean) / std

    def __prepare_state_data(self, slice_to_get: slice, include_trading_data: bool = True) -> list[float]:
        """
        Calculates state data as a list of floats representing the current iteration's observation.
        The observation contains all input data restricted to the window size, plus a few
        coefficients giving insight into the current budget and order situation.

        Parameters:
            slice_to_get (slice): Slice to get the data from.
            include_trading_data (bool): If True, includes trading data in the observation.

        Returns:
            (list[float]): List with the current observations for the environment.
        """

        current_market_data = self.__data[self.__mode].iloc[slice_to_get].copy()
        current_market_data_no_index = current_market_data.select_dtypes(include = [np.number])

        if self.__meta_data is not None and \
           self.__meta_data.get('normalization_groups', None) is not None:
            grouped_columns_names = self.__meta_data['normalization_groups']
            preprocessed_data_pieces = []
            left_over_columns_names = set(current_market_data_no_index.columns)
            for columns_names_to_normalize in grouped_columns_names:
                left_over_columns_names -= set(columns_names_to_normalize)
                data_frame_piece_to_normalize = current_market_data_no_index[columns_names_to_normalize]
                normalized_data_frame_piece = self.__standard_scale(data_frame_piece_to_normalize.values.reshape(-1, 1))
                preprocessed_data_pieces.append(normalized_data_frame_piece.reshape(*data_frame_piece_to_normalize.shape))
            for column in left_over_columns_names:
                preprocessed_data_pieces.append(current_market_data_no_index[column].values.reshape(-1, 1))
            normalized_current_market_data_values = np.hstack(preprocessed_data_pieces)
        else:
            normalized_current_market_data_values = self.__standard_scale(current_market_data_no_index.values)
        current_marked_data_list = normalized_current_market_data_values.ravel().tolist()

        if include_trading_data:
            current_normalized_budget = 1.0 * self.__trading_data.current_budget / self.__trading_consts.INITIAL_BUDGET
            current_profitability_coeff = self.__trading_consts.PROFITABILITY_FUNCTION(current_normalized_budget)
            current_no_trades_penalty_coeff = self.__trading_consts.PENALTY_FUNCTION(self.__trading_data.no_trades_placed_for)
            current_long_trades_occupancy_coeff = 1.0 * self.__trading_data.currently_placed_long_trades / self.__trading_consts.MAX_AMOUNT_OF_TRADES
            current_short_trades_occupancy_coeff = 1.0 * self.__trading_data.currently_placed_short_trades / self.__trading_consts.MAX_AMOUNT_OF_TRADES
            current_inner_state_list = [current_profitability_coeff, current_no_trades_penalty_coeff,
                                        current_long_trades_occupancy_coeff, current_short_trades_occupancy_coeff]
            current_marked_data_list += current_inner_state_list

        return current_marked_data_list
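    # Layout note for the observation produced above: it is a flat vector of
    # WINDOW_SIZE * <number of numeric columns> standard-scaled market values,
    # followed (when include_trading_data is True) by the four coefficients:
    # profitability, no-trades penalty, long occupancy and short occupancy.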

    def __prefetch_state_data(self, env_length_range: tuple[int, int], include_trading_data: bool = True) -> pd.DataFrame:
        """
        Prefetches state data for the given environment length range.

        Parameters:
            env_length_range (tuple[int, int]): Range to limit the length of the environment.
            include_trading_data (bool): If True, includes trading data in the observation.

        Returns:
            (pd.DataFrame): DataFrame containing the pre-fetched state data.
        """

        new_rows = []
        for i in range(env_length_range[0], env_length_range[1]):
            data_row = self.__prepare_state_data(slice(i - self.__trading_consts.WINDOW_SIZE, i), include_trading_data = include_trading_data)
            new_rows.append(data_row)

        return pd.DataFrame(new_rows, columns = [f"feature_{i}" for i in range(len(new_rows[0]))])

    def __prepare_labeled_data(self, env_length_range: tuple[int, int]) -> tuple[pd.DataFrame, pd.Series]:
        """
        Prepares labeled data for training the model with a classification approach.
        It extracts the relevant features and labels from the environment's data.

        Parameters:
            env_length_range (tuple[int, int]): Range to limit the length of the environment.

        Returns:
            (tuple[pd.DataFrame, pd.Series]): A tuple containing the input data and output labels.
        """

        prefetched_data = self.__prefetch_state_data(env_length_range, include_trading_data = False)
        labels = self.__label_annotator.annotate(
            self.__data[self.__mode].iloc[:env_length_range[1]].copy()
        ).shift(-env_length_range[0]).dropna()

        return prefetched_data, labels
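    # Alignment note: annotate() produces labels for the full data prefix, and
    # shift(-env_length_range[0]) moves them back so that label row i lines up
    # with prefetched observation row i (whose feature window ends just before
    # absolute index env_length_range[0] + i).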

    def __get_current_state_data(self) -> list[float]:
        """
        Retrieves the current state data from the environment.

        Returns:
            (list[float]): List with the current observations for the environment.
        """

        if self.__should_prefetch:
            return self.__prefetched_data[self.__mode].iloc[self.current_iteration - self.__trading_consts.WINDOW_SIZE].values.ravel().tolist()

        return self.__prepare_state_data(slice_to_get = slice(self.current_iteration - self.__trading_consts.WINDOW_SIZE, self.current_iteration))

    def set_mode(self, mode: str) -> None:
        """
        Sets the mode of the environment to either TRAIN_MODE or TEST_MODE.

        Parameters:
            mode (str): Mode to set for the environment.

        Raises:
            ValueError: If the provided mode is not valid.
        """

        if mode not in [TradingEnvironment.TRAIN_MODE, TradingEnvironment.TEST_MODE]:
            raise ValueError(f"Invalid mode: {mode}. Use TradingEnvironment.TRAIN_MODE or TradingEnvironment.TEST_MODE.")
        self.__mode = mode

    def get_mode(self) -> str:
        """
        Mode getter.

        Returns:
            (str): Current mode of the environment.
        """

        return copy.copy(self.__mode)

    def get_trading_data(self) -> SimpleNamespace:
        """
        Trading data getter.

        Returns:
            (SimpleNamespace): Copy of the namespace with all trading data.
        """

        return copy.copy(self.__trading_data)

    def get_number_of_trading_points_per_year(self) -> int:
        """
        Returns the number of trading points per year.

        Returns:
            (int): Number of trading points per year.
        """

        temp_data = {"time": pd.to_datetime(self.__data[self.TRAIN_MODE]['time'])}
        temp_df = pd.DataFrame(temp_data)
        temp_df['year'] = temp_df['time'].dt.year

        trading_points_per_year = temp_df.groupby('year').size()
        if len(trading_points_per_year) > 3:
            # If there are more than three years, return the mode
            # of the central years
            return trading_points_per_year.iloc[1:-1].mode()[0]
        elif len(trading_points_per_year) > 2:
            # If there are exactly three years, return the middle year's count
            return trading_points_per_year.values[-2]
        else:
            # If there are two years or fewer, return the maximum
            return max(trading_points_per_year.values)

    def get_trading_consts(self) -> SimpleNamespace:
        """
        Trading constants getter.

        Returns:
            (SimpleNamespace): Copy of the namespace with all trading constants.
        """

        return copy.copy(self.__trading_consts)

    def get_broker(self) -> Broker:
        """
        Broker getter.

        Returns:
            (Broker): Copy of the broker used by the environment.
        """

        return copy.copy(self.__broker)

    def get_environment_length(self) -> int:
        """
        Environment length getter.

        Returns:
            (int): Length of the environment.
        """

        return len(self.__data[self.__mode])

    def get_environment_spatial_data_dimension(self) -> tuple[int, int]:
        """
        Environment spatial data dimensionality getter.

        Returns:
            (tuple[int, int]): Dimensions of the spatial data in the environment.
        """

        return (self.__trading_consts.WINDOW_SIZE, self.__data[self.__mode].shape[1] - 1)

    def get_labeled_data(self, should_split: bool = True, should_balance: bool = True,
                         verbose: bool = True, env_length_range: Optional[tuple[int, int]] = None) \
            -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
        """
        Prepares labeled data for training or testing the model.
        It extracts the relevant features and labels from the environment's data.

        Parameters:
            should_split (bool): Whether to split the data into training and testing sets.
                Defaults to True. If set to False, testing data will be empty.
            should_balance (bool): Whether to balance the labeled data. Defaults to True.
                Will be ignored if labeled_data_balancer is None.
            verbose (bool): Whether to log the class cardinality before and after balancing.
                Defaults to True.
            env_length_range (Optional[tuple[int, int]]): Optional range limiting the part
                of the environment to use.

        Returns:
            (tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]): A tuple containing the
                input data, output data, test input data, and test output data.
        """

        if env_length_range is None:
            env_length_range = (self.__trading_consts.WINDOW_SIZE, self.get_environment_length() - 1)

        input_data, output_data = self.__prepare_labeled_data(env_length_range)
        input_data_test, output_data_test = [], []
        if verbose:
            logging.info(f"Original class cardinality: {np.array(to_categorical(output_data)).sum(axis = 0)}")

        if self.__mode == TradingEnvironment.TRAIN_MODE:
            if should_split:
                input_data, input_data_test, output_data, output_data_test = \
                    train_test_split(input_data, output_data, test_size = 0.1, random_state = 42,
                                     stratify = output_data)

            if self.__labeled_data_balancer is not None and should_balance:
                input_data, output_data = self.__labeled_data_balancer.balance(input_data, output_data)
                if verbose:
                    logging.info(f"Balanced class cardinality: {np.array(to_categorical(output_data)).sum(axis = 0)}")

        return copy.copy((np.array(input_data), np.array(output_data),
                          np.array(input_data_test), np.array(output_data_test)))

    def get_data_for_iteration(self, columns: list[str], start: int = 0, stop: Optional[int] = None,
                               step: int = 1) -> list[float]:
        """
        Data getter for a given range of iterations.

        Parameters:
            columns (list[str]): List of column names to extract from the data.
            start (int): Start iteration index. Defaults to 0.
            stop (Optional[int]): Stop iteration index. Defaults to the environment length minus one.
            step (int): Step between iterations. Defaults to 1.

        Returns:
            (list[float]): Copy of the part of the data with the specified columns
                over the specified iterations.
        """

        if stop is None:
            stop = self.get_environment_length() - 1

        return copy.copy(self.__data[self.__mode].loc[start:stop:step, columns].values.ravel().tolist())

    def step(self, action: int) -> tuple[list[float], float, bool, dict]:
        """
        Performs the specified action on the environment, which results in the generation of a
        new observation. This function causes trades to be handled, the reward to be calculated,
        and the environment to be updated.

        Parameters:
            action (int): Number specifying the action. Possible values are 0 for the buy action,
                1 for the wait action, and 2 for the sell action.

        Returns:
            (tuple[list[float], float, bool, dict]): Tuple containing the next observation
                state, reward, finish indication, and an additional info dictionary.
        """

        self.current_iteration += 1
        self.state = self.__get_current_state_data()

        close_changes = self.__data[self.__mode].iloc[self.current_iteration - 2 : self.current_iteration]['close'].values
        stock_change_coeff = 1 + (close_changes[1] - close_changes[0]) / close_changes[0]
        closed_orders = self.__broker.update_orders(stock_change_coeff)

        if self.__trading_mode == TradingEnvironment.TradingMode.EXPLICIT_ORDER_CLOSING:
            current_orders = self.__broker.get_current_orders()
            if len(current_orders) > 0:
                was_last_order_placed_as_buy = current_orders[-1].is_buy_order
                if (action == 0 and not was_last_order_placed_as_buy) or \
                   (action == 2 and was_last_order_placed_as_buy):
                    closed_orders += self.__broker.force_close_orders()

        reward = self.__validator.validate_orders(closed_orders)
        number_of_closed_long_trades = len([order for order in closed_orders if order.is_buy_order])
        number_of_closed_short_trades = len(closed_orders) - number_of_closed_long_trades
        self.__trading_data.currently_placed_long_trades -= number_of_closed_long_trades
        self.__trading_data.currently_placed_short_trades -= number_of_closed_short_trades
        self.__trading_data.current_budget += np.sum([order.current_value for order in closed_orders])
        self.__trading_data.currently_invested -= np.sum([order.initial_value for order in closed_orders])

        number_of_possible_trades = self.__trading_consts.MAX_AMOUNT_OF_TRADES \
            - self.__trading_data.currently_placed_long_trades - self.__trading_data.currently_placed_short_trades
        money_to_trade = 0
        if number_of_possible_trades > 0:
            money_to_trade = 1.0 / number_of_possible_trades * self.__trading_data.current_budget

        if action == 0:
            is_buy_order = True
            stop_loss = self.__trading_consts.BUY_STOP_LOSS
            take_profit = self.__trading_consts.BUY_TAKE_PROFIT
        elif action == 2:
            is_buy_order = False
            stop_loss = self.__trading_consts.SELL_STOP_LOSS
            take_profit = self.__trading_consts.SELL_TAKE_PROFIT

        if action != 1:
            if number_of_possible_trades > 0:
                self.__trading_data.current_budget -= money_to_trade
                self.__trading_data.currently_invested += money_to_trade
                self.__broker.place_order(money_to_trade, is_buy_order, stop_loss, take_profit)
                self.__trading_data.no_trades_placed_for = 0
                reward += self.__trading_consts.STATIC_REWARD_ADJUSTMENT
                if action == 0:
                    self.__trading_data.currently_placed_long_trades += 1
                elif action == 2:
                    self.__trading_data.currently_placed_short_trades += 1
            else:
                self.__trading_data.no_trades_placed_for += 1
                reward -= self.__trading_consts.STATIC_REWARD_ADJUSTMENT
        else:
            self.__trading_data.no_trades_placed_for += 1
            if number_of_possible_trades == 0:
                reward += self.__trading_consts.STATIC_REWARD_ADJUSTMENT

        if number_of_possible_trades > 0:
            reward *= (1 - self.__trading_consts.PENALTY_FUNCTION(self.__trading_data.no_trades_placed_for)) \
                if reward > 0 else 1
            if self.__trading_consts.PENALTY_STOPS < self.__trading_data.no_trades_placed_for:
                reward -= self.__trading_consts.STATIC_REWARD_ADJUSTMENT

        if (self.current_iteration >= self.get_environment_length() - 1 or
                self.__trading_data.current_budget > 10 * self.__trading_consts.INITIAL_BUDGET or
                (self.__trading_data.current_budget + self.__trading_data.currently_invested) / self.__trading_consts.INITIAL_BUDGET < 0.6):
            done = True
        else:
            done = False

        info = {'coeff': stock_change_coeff,
                'iteration': self.current_iteration,
                'number_of_closed_orders': len(closed_orders),
                'money_to_trade': money_to_trade,
                'action': action,
                'current_budget': self.__trading_data.current_budget,
                'currently_invested': self.__trading_data.currently_invested,
                'no_trades_placed_for': self.__trading_data.no_trades_placed_for,
                'currently_placed_long_trades': self.__trading_data.currently_placed_long_trades,
                'currently_placed_short_trades': self.__trading_data.currently_placed_short_trades}

        return self.state, reward, done, info

    def render(self) -> None:
        """
        Renders the environment visualization. Will be implemented later.
        """

        # TODO: Visualization to be implemented
        pass

    def reset(self, randkey: Optional[int] = None) -> list[float]:
        """
        Resets the environment. Typically used when the environment is finished,
        i.e. when there are no more steps to be taken within the environment
        or the finish conditions are fulfilled.

        Parameters:
            randkey (Optional[int]): Value indicating which iteration
                should be treated as the starting point after the reset.

        Returns:
            (list[float]): Current iteration's observation state.
        """

        if randkey is None:
            randkey = random.randint(self.__trading_consts.WINDOW_SIZE, self.get_environment_length() - 1)
        self.__trading_data.current_budget = self.__trading_consts.INITIAL_BUDGET
        self.__trading_data.currently_invested = 0
        self.__trading_data.no_trades_placed_for = 0
        self.__trading_data.currently_placed_long_trades = 0
        self.__trading_data.currently_placed_short_trades = 0
        self.__broker.reset()
        self.current_iteration = randkey
        self.state = self.__get_current_state_data()

        return self.state
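
# A minimal usage sketch, assuming concrete RewardValidatorBase and
# LabelAnnotatorBase implementations from source.environment and a market
# DataFrame with at least 'time' and 'close' columns. MyValidator, MyAnnotator
# and 'market_data.csv' below are placeholder names, not project artifacts:
#
#     if __name__ == '__main__':
#         data = pd.read_csv('market_data.csv')
#         env = TradingEnvironment(data = data, initial_budget = 10000.0,
#                                  max_amount_of_trades = 5, window_size = 30,
#                                  validator = MyValidator(),
#                                  label_annotator = MyAnnotator(),
#                                  sell_stop_loss = 0.9, sell_take_profit = 1.2,
#                                  buy_stop_loss = 0.9, buy_take_profit = 1.2)
#         state = env.reset()
#         done = False
#         while not done:
#             action = env.action_space.sample()    # 0 = buy, 1 = wait, 2 = sell
#             state, reward, done, info = env.step(action)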