Coverage for source/environment/trading_environment.py: 80%

223 statements  

coverage.py v7.8.0, created at 2025-09-21 11:29 +0000

# environment/trading_environment.py

# global imports
import copy
import logging
import math
import numpy as np
import pandas as pd
import random
from enum import Enum
from gym import Env
from gym.spaces import Box, Discrete
from sklearn.model_selection import train_test_split
from tensorflow.keras.utils import to_categorical
from types import SimpleNamespace
from typing import Any, Optional

# local imports
from source.environment import Broker, LabelAnnotatorBase, LabeledDataBalancer, RewardValidatorBase

class TradingEnvironment(Env):
    """
    Implements a stock market environment in which an actor can perform actions (place orders).
    It is used to train various models using various approaches. It can be configured to award
    points and impose penalties in several ways.
    """

    class TradingMode(Enum):
        """
        Enumeration for the different trading modes.
        """

        IMPLICIT_ORDER_CLOSING = 0
        EXPLICIT_ORDER_CLOSING = 1

    # global class constants
    TRAIN_MODE = 'train'
    TEST_MODE = 'test'

    def __init__(self, data: pd.DataFrame, initial_budget: float, max_amount_of_trades: int,
                 window_size: int, validator: RewardValidatorBase, label_annotator: LabelAnnotatorBase,
                 sell_stop_loss: float, sell_take_profit: float, buy_stop_loss: float, buy_take_profit: float,
                 test_ratio: float = 0.2, penalty_starts: int = 0, penalty_stops: int = 10,
                 static_reward_adjustment: float = 1, labeled_data_balancer: Optional[LabeledDataBalancer] = None,
                 meta_data: Optional[dict[str, Any]] = None, trading_mode: Optional[TradingMode] = None,
                 should_prefetch: bool = True) -> None:

        """
        Class constructor. Allows defining all crucial constants, reward validation methods,
        environmental penalty policies, etc.

        Parameters:
            data (pd.DataFrame): DataFrame containing historical market data.
            initial_budget (float): Initial budget the trader starts from.
            max_amount_of_trades (int): Maximum number of trades that can be ongoing at the same time.
                Setting this constant prevents traders from placing orders randomly and defines
                the amount of money that can be assigned to a single order at a given iteration.
            window_size (int): Constant defining how far into the past the trader can look
                at a given iteration.
            validator (RewardValidatorBase): Validator implementing the policy used to award points
                for closed trades.
            label_annotator (LabelAnnotatorBase): Annotator implementing the policy used to label
                data with target values. It provides supervised agents with the target class value
                for a given iteration.
            sell_stop_loss (float): Losing boundary at which a sell order (short) is closed.
            sell_take_profit (float): Winning boundary at which a sell order (short) is closed.
            buy_stop_loss (float): Losing boundary at which a buy order (long) is closed.
            buy_take_profit (float): Winning boundary at which a buy order (long) is closed.
            test_ratio (float): Ratio of data that should be used for testing purposes.
            penalty_starts (int): Number of trading periods the trader can go without placing
                an order before a penalty is imposed. Between the start and stop constants the penalty
                is calculated as a percentile of the positive reward and subtracted from the actual reward.
            penalty_stops (int): Trading period after which the penalty is no longer increased.
                The reward for trading periods exceeding this constant equals minus the static reward adjustment.
            static_reward_adjustment (float): Constant used to penalize the trader for bad choices or
                reward it for good ones.
            labeled_data_balancer (Optional[LabeledDataBalancer]): Balancer used to balance
                labeled data. If None, no balancing is performed.
            meta_data (Optional[dict[str, Any]]): Dictionary containing metadata about the dataset.
            trading_mode (Optional[TradingMode]): Mode of the environment, either IMPLICIT_ORDER_CLOSING
                or EXPLICIT_ORDER_CLOSING. Defaults to IMPLICIT_ORDER_CLOSING.
            should_prefetch (bool): If True, data is pre-fetched to speed up training.
        """

        if test_ratio < 0.0 or test_ratio >= 1.0:
            raise ValueError(f"Invalid test_ratio: {test_ratio}. It should be in range [0, 1).")

        if trading_mode is None:
            trading_mode = TradingEnvironment.TradingMode.IMPLICIT_ORDER_CLOSING
        elif isinstance(trading_mode, int):
            trading_mode = TradingEnvironment.TradingMode(trading_mode)

        if not isinstance(trading_mode, TradingEnvironment.TradingMode):
            raise ValueError(f"Invalid trading_mode: {trading_mode}. It should be of type TradingEnvironment.TradingMode or int.")

        # Initializing the environment
        self.__data: dict[str, pd.DataFrame] = self.__split_data(data, test_ratio)
        self.__meta_data: Optional[dict[str, Any]] = meta_data
        self.__mode = TradingEnvironment.TRAIN_MODE
        self.__trading_mode: TradingEnvironment.TradingMode = trading_mode
        self.__should_prefetch: bool = should_prefetch
        self.__broker: Broker = Broker()
        self.__validator: RewardValidatorBase = validator
        self.__label_annotator: LabelAnnotatorBase = label_annotator
        self.__labeled_data_balancer: Optional[LabeledDataBalancer] = labeled_data_balancer

        # Setting up trading data
        self.__trading_data: SimpleNamespace = SimpleNamespace()
        self.__trading_data.current_budget: float = initial_budget
        self.__trading_data.currently_invested: float = 0
        self.__trading_data.no_trades_placed_for: int = 0
        self.__trading_data.currently_placed_trades: int = 0

        # Setting up trading constants
        self.__trading_consts = SimpleNamespace()
        self.__trading_consts.INITIAL_BUDGET: float = initial_budget
        self.__trading_consts.MAX_AMOUNT_OF_TRADES: int = max_amount_of_trades
        self.__trading_consts.WINDOW_SIZE: int = window_size
        self.__trading_consts.SELL_STOP_LOSS: float = sell_stop_loss
        self.__trading_consts.SELL_TAKE_PROFIT: float = sell_take_profit
        self.__trading_consts.BUY_STOP_LOSS: float = buy_stop_loss
        self.__trading_consts.BUY_TAKE_PROFIT: float = buy_take_profit
        self.__trading_consts.STATIC_REWARD_ADJUSTMENT: float = static_reward_adjustment
        self.__trading_consts.PENALTY_STARTS: int = penalty_starts
        self.__trading_consts.PENALTY_STOPS: int = penalty_stops
        self.__trading_consts.PROFITABILITY_FUNCTION = lambda x: -1.0 * math.exp(-x + 1) + 1
        self.__trading_consts.PENALTY_FUNCTION = lambda x: \
            min(1, 1 - math.tanh(-3.0 * (x - penalty_stops) / (penalty_stops - penalty_starts)))
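        # With the default penalty_starts = 0 and penalty_stops = 10, PENALTY_FUNCTION rises from
        # roughly 0.005 at x = 0 to exactly 1.0 at x = penalty_stops and stays capped at 1.0 beyond
        # it, while PROFITABILITY_FUNCTION maps a normalized budget of 1.0 to 0.0 and approaches
        # 1.0 as the budget grows.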

        self.__trading_consts.OUTPUT_CLASSES: int = vars(self.__label_annotator.get_output_classes())

        # Prefetching data if needed
        if self.__should_prefetch:
            self.__prefetched_data = { TradingEnvironment.TRAIN_MODE: None,
                                       TradingEnvironment.TEST_MODE: None }
            self.__mode = TradingEnvironment.TEST_MODE
            self.__prefetched_data[self.__mode] = self.__prefetch_state_data(env_length_range = (self.__trading_consts.WINDOW_SIZE,
                                                                                                  self.get_environment_length()))
            self.__mode = TradingEnvironment.TRAIN_MODE
            self.__prefetched_data[self.__mode] = self.__prefetch_state_data(env_length_range = (self.__trading_consts.WINDOW_SIZE,
                                                                                                  self.get_environment_length()))
        else:
            self.__prefetched_data = None

        # Initializing the environment state
        self.current_iteration: int = self.__trading_consts.WINDOW_SIZE
        self.state: list[float] = self.__get_current_state_data()
        self.action_space: Discrete = Discrete(3)
        self.observation_space: Box = Box(low = np.ones(len(self.state)) * -3,
                                          high = np.ones(len(self.state)) * 3,
                                          dtype = np.float64)
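        # Discrete(3) matches the actions handled in step(): 0 = buy, 1 = wait, 2 = sell. The
        # observation bounds of +/-3 assume standard-scaled (z-score) market features plus the
        # three bounded trading coefficients.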

    def __split_data(self, data: pd.DataFrame, test_size: float) -> dict[str, pd.DataFrame]:
        """
        Splits the given DataFrame into training and testing sets based on the specified test size ratio.

        Parameters:
            data (pd.DataFrame): DataFrame containing the stock market data.
            test_size (float): Ratio of the data to be used for testing.

        Returns:
            (dict[str, pd.DataFrame]): Dictionary containing training and testing data frames.
        """

        dividing_index = int(len(data) * (1 - test_size))

        return {
            TradingEnvironment.TRAIN_MODE: data.iloc[:dividing_index].reset_index(drop=True),
            TradingEnvironment.TEST_MODE: data.iloc[dividing_index:].reset_index(drop=True)
        }
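    # Note: the split is chronological rather than shuffled; e.g. with test_size = 0.2 the first
    # 80% of rows become the training set and the last 20% the test set.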

    def __standard_scale(self, data: np.ndarray) -> np.ndarray:
        """
        Standardizes the given data by removing the mean and scaling to unit variance.

        Parameters:
            data (np.ndarray): The data to be standardized.

        Returns:
            (np.ndarray): The standardized data.
        """

        mean = np.mean(data, axis = 0, keepdims = True)
        std = np.std(data, axis = 0, keepdims = True)
        std[std == 0] = 1

        return (data - mean) / std
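    # Example: __standard_scale(np.array([[1.0], [2.0], [3.0]])) yields approximately
    # [[-1.2247], [0.0], [1.2247]], i.e. zero mean and unit variance per column.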

    def __prepare_state_data(self, slice_to_get: slice, include_trading_data: bool = True) -> list[float]:
        """
        Calculates state data as a list of floats representing the current iteration's observation.
        The observation contains all input data refined to the window size and a few coefficients
        giving insight into the current budget and orders situation.

        Parameters:
            slice_to_get (slice): Slice to get the data from.
            include_trading_data (bool): If True, includes trading data in the observation.

        Returns:
            (list[float]): List with current observations for the environment.
        """

        current_market_data = self.__data[self.__mode].iloc[slice_to_get].copy()
        current_market_data_no_index = current_market_data.select_dtypes(include = [np.number])

        if self.__meta_data is not None and \
           self.__meta_data.get('normalization_groups', None) is not None:
            grouped_columns_names = self.__meta_data['normalization_groups']
            preprocessed_data_pieces = []
            left_over_columns_names = set(current_market_data_no_index.columns)
            for columns_names_to_normalize in grouped_columns_names:
                left_over_columns_names -= set(columns_names_to_normalize)
                data_frame_piece_to_normalize = current_market_data_no_index[columns_names_to_normalize]
                normalized_data_frame_piece = self.__standard_scale(data_frame_piece_to_normalize.values.reshape(-1, 1))
                preprocessed_data_pieces.append(normalized_data_frame_piece.reshape(*data_frame_piece_to_normalize.shape))
            for column in left_over_columns_names:
                preprocessed_data_pieces.append(current_market_data_no_index[column].values.reshape(-1, 1))
            normalized_current_market_data_values = np.hstack(preprocessed_data_pieces)
        else:
            normalized_current_market_data_values = self.__standard_scale(current_market_data_no_index.values)
        current_market_data_list = normalized_current_market_data_values.ravel().tolist()

        if include_trading_data:
            current_normalized_budget = 1.0 * self.__trading_data.current_budget / self.__trading_consts.INITIAL_BUDGET
            current_profitability_coeff = self.__trading_consts.PROFITABILITY_FUNCTION(current_normalized_budget)
            current_trades_occupancy_coeff = 1.0 * self.__trading_data.currently_placed_trades / self.__trading_consts.MAX_AMOUNT_OF_TRADES
            current_no_trades_penalty_coeff = self.__trading_consts.PENALTY_FUNCTION(self.__trading_data.no_trades_placed_for)
            current_inner_state_list = [current_profitability_coeff, current_trades_occupancy_coeff, current_no_trades_penalty_coeff]
            current_market_data_list += current_inner_state_list

        return current_market_data_list
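    # The resulting observation holds WINDOW_SIZE * <number of numeric columns> market values and,
    # when include_trading_data is True, three extra coefficients: profitability, trade occupancy
    # and the no-trades penalty.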

    def __prefetch_state_data(self, env_length_range: tuple[int, int], include_trading_data: bool = True) -> pd.DataFrame:
        """
        Prefetches state data for the given environment length range.

        Parameters:
            env_length_range (tuple[int, int]): Range to limit the length of the environment.
            include_trading_data (bool): If True, includes trading data in the observation.

        Returns:
            (pd.DataFrame): DataFrame containing the pre-fetched state data.
        """

        new_rows = []
        for i in range(env_length_range[0], env_length_range[1]):
            data_row = self.__prepare_state_data(slice(i - self.__trading_consts.WINDOW_SIZE, i), include_trading_data = include_trading_data)
            new_rows.append(data_row)

        return pd.DataFrame(new_rows, columns = [f"feature_{i}" for i in range(len(new_rows[0]))])

    def __prepare_labeled_data(self, env_length_range: tuple[int, int]) -> tuple[pd.DataFrame, pd.Series]:
        """
        Prepares labeled data for training the model with a classification approach.
        It extracts the relevant features and labels from the environment's data.

        Parameters:
            env_length_range (tuple[int, int]): Range to limit the length of the environment.

        Returns:
            (tuple[pd.DataFrame, pd.Series]): A tuple containing the input data and output labels.
        """

        prefetched_data = self.__prefetch_state_data(env_length_range, include_trading_data = False)
        labels = self.__label_annotator.annotate(self.__data[self.__mode]. \
            iloc[:env_length_range[1]].copy()).shift(-env_length_range[0]).dropna()

        return prefetched_data, labels
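    # The shift(-env_length_range[0]) followed by dropna() aligns the labels with the prefetched
    # windows: feature row i (the window ending at iteration env_length_range[0] + i) is paired
    # with the annotation produced for that same iteration.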

    def __get_current_state_data(self) -> list[float]:
        """
        Retrieves the current state data from the environment.

        Returns:
            (list[float]): List with current observations for environment.
        """

        if self.__should_prefetch:
            return self.__prefetched_data[self.__mode].iloc[self.current_iteration - self.__trading_consts.WINDOW_SIZE].values.ravel().tolist()

        return self.__prepare_state_data(slice_to_get = slice(self.current_iteration - self.__trading_consts.WINDOW_SIZE, self.current_iteration))

    def set_mode(self, mode: str) -> None:
        """
        Sets the mode of the environment to either TRAIN_MODE or TEST_MODE.

        Parameters:
            mode (str): Mode to set for the environment.

        Raises:
            ValueError: If the provided mode is not valid.
        """

        if mode not in [TradingEnvironment.TRAIN_MODE, TradingEnvironment.TEST_MODE]:
            raise ValueError(f"Invalid mode: {mode}. Use TradingEnvironment.TRAIN_MODE or TradingEnvironment.TEST_MODE.")
        self.__mode = mode

    def get_mode(self) -> str:
        """
        Mode getter.

        Returns:
            (str): Current mode of the environment.
        """

        return copy.copy(self.__mode)

    def get_trading_data(self) -> SimpleNamespace:
        """
        Trading data getter.

        Returns:
            (SimpleNamespace): Copy of the namespace with all trading data.
        """

        return copy.copy(self.__trading_data)

    def get_number_of_trading_points_per_year(self) -> int:
        """
        Returns the number of trading points per year.

        Returns:
            (int): Number of trading points per year.
        """

        temp_data = {"time": pd.to_datetime(self.__data[self.TRAIN_MODE]['time'])}
        temp_df = pd.DataFrame(temp_data)
        temp_df['year'] = temp_df['time'].dt.year

        trading_points_per_year = temp_df.groupby('year').size()
        if len(trading_points_per_year) > 3:
            # If there are more than three years, return the mode
            # of the central years
            return trading_points_per_year.iloc[1:-1].mode()[0]
        elif len(trading_points_per_year) > 2:
            # If there are exactly three years, return the middle year's count
            return trading_points_per_year.values[-2]
        else:
            # If there are only one or two years, return the maximum
            return max(trading_points_per_year.values)
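    # Example: for daily data covering 2018-2022 the first and last years are usually partial, so
    # the mode of the 2019-2021 counts (around 252 trading days per year) is returned.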

    def get_trading_consts(self) -> SimpleNamespace:
        """
        Trading constants getter.

        Returns:
            (SimpleNamespace): Copy of the namespace with all trading constants.
        """

        return copy.copy(self.__trading_consts)

    def get_broker(self) -> Broker:
        """
        Broker getter.

        Returns:
            (Broker): Copy of the broker used by environment.
        """

        return copy.copy(self.__broker)

    def get_environment_length(self) -> int:
        """
        Environment length getter.

        Returns:
            (int): Length of the environment.
        """

        return len(self.__data[self.__mode])

    def get_environment_spatial_data_dimension(self) -> tuple[int, int]:
        """
        Environment spatial data dimensionality getter.

        Returns:
            (tuple[int, int]): Dimensions of spatial data in the environment.
        """

        return (self.__trading_consts.WINDOW_SIZE, self.__data[self.__mode].shape[1] - 1)

    def get_labeled_data(self, should_split: bool = True, should_balance: bool = True,
                         verbose: bool = True, env_length_range: Optional[tuple[int, int]] = None) \
                         -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
        """
        Prepares labeled data for training or testing the model.
        It extracts the relevant features and labels from the environment's data.

        Parameters:
            should_split (bool): Whether to split the data into training and testing sets.
                Defaults to True. If set to False, testing data will be empty.
            should_balance (bool): Whether to balance the labeled data. Defaults to True.
                Will be ignored if labeled_data_balancer is None.
            verbose (bool): Whether to log the class cardinality before and after balancing.
                Defaults to True.
            env_length_range (Optional[tuple[int, int]]): Optional range limiting the length of the environment.

        Returns:
            (tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]): A tuple containing the
                input data, output data, test input data, and test output data.
        """

        if env_length_range is None:
            env_length_range = (self.__trading_consts.WINDOW_SIZE, self.get_environment_length() - 1)

        input_data, output_data = self.__prepare_labeled_data(env_length_range)
        input_data_test, output_data_test = [], []
        if verbose:
            logging.info(f"Original class cardinality: {np.array(to_categorical(output_data)).sum(axis = 0)}")

        if self.__mode == TradingEnvironment.TRAIN_MODE:
            if should_split:
                input_data, input_data_test, output_data, output_data_test = \
                    train_test_split(input_data, output_data, test_size = 0.1, random_state = 42,
                                     stratify = output_data)

            if self.__labeled_data_balancer is not None and should_balance:
                input_data, output_data = self.__labeled_data_balancer.balance(input_data, output_data)
                if verbose:
                    logging.info(f"Balanced class cardinality: {np.array(to_categorical(output_data)).sum(axis = 0)}")

        return copy.copy((np.array(input_data), np.array(output_data),
                          np.array(input_data_test), np.array(output_data_test)))
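    # When splitting is enabled, a stratified 90/10 split (random_state = 42) is applied before the
    # optional class balancing, so the held-out portion keeps the original class distribution.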

    def get_data_for_iteration(self, columns: list[str], start: int = 0, stop: Optional[int] = None,
                               step: int = 1) -> list[float]:
        """
        Data getter for certain iterations.

        Parameters:
            columns (list[str]): List of column names to extract from data.
            start (int): Start iteration index. Defaults to 0.
            stop (int): Stop iteration index. Defaults to environment length minus one.
            step (int): Step between iterations. Defaults to 1.

        Returns:
            (list[float]): Copy of part of data with specified columns
                over specified iterations.
        """

        if stop is None:
            stop = self.get_environment_length() - 1

        return copy.copy(self.__data[self.__mode].loc[start:stop:step, columns].values.ravel().tolist())

    def step(self, action: int) -> tuple[list[float], float, bool, dict]:
        """
        Performs the specified action on the environment, resulting in the generation of a new
        observation. This function causes trades to be handled, the reward to be calculated and
        the environment to be updated.

        Parameters:
            action (int): Number specifying the action. Possible values are 0 for the buy action,
                1 for the wait action and 2 for the sell action.

        Returns:
            (tuple[list[float], float, bool, dict]): Tuple containing the next observation
                state, reward, finish indication and an additional info dictionary.
        """
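        # Order of operations: advance the iteration, let the broker update open orders with the
        # latest close-to-close change, score any closed orders through the validator, then apply
        # the chosen action and the no-trade penalty.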

        self.current_iteration += 1
        self.state = self.__get_current_state_data()

        close_changes = self.__data[self.__mode].iloc[self.current_iteration - 2 : self.current_iteration]['close'].values
        stock_change_coeff = 1 + (close_changes[1] - close_changes[0]) / close_changes[0]
        closed_orders = self.__broker.update_orders(stock_change_coeff)

        if self.__trading_mode == TradingEnvironment.TradingMode.EXPLICIT_ORDER_CLOSING:
            current_orders = self.__broker.get_current_orders()
            if len(current_orders) > 0:
                was_last_order_placed_as_buy = current_orders[-1].is_buy_order
                if (action == 0 and not was_last_order_placed_as_buy) or \
                   (action == 2 and was_last_order_placed_as_buy):
                    closed_orders += self.__broker.force_close_orders()

        reward = self.__validator.validate_orders(closed_orders)
        self.__trading_data.currently_placed_trades -= len(closed_orders)
        self.__trading_data.current_budget += np.sum([trade.current_value for trade in closed_orders])
        self.__trading_data.currently_invested -= np.sum([trade.initial_value for trade in closed_orders])

        number_of_possible_trades = self.__trading_consts.MAX_AMOUNT_OF_TRADES - self.__trading_data.currently_placed_trades
        money_to_trade = 0
        if number_of_possible_trades > 0:
            money_to_trade = 1.0 / number_of_possible_trades * self.__trading_data.current_budget

        if action == 0:
            is_buy_order = True
            stop_loss = self.__trading_consts.SELL_STOP_LOSS
            take_profit = self.__trading_consts.SELL_TAKE_PROFIT
        elif action == 2:
            is_buy_order = False
            stop_loss = self.__trading_consts.BUY_STOP_LOSS
            take_profit = self.__trading_consts.BUY_TAKE_PROFIT

        if action != 1:
            if number_of_possible_trades > 0:
                self.__trading_data.current_budget -= money_to_trade
                self.__trading_data.currently_invested += money_to_trade
                self.__broker.place_order(money_to_trade, is_buy_order, stop_loss, take_profit)
                self.__trading_data.currently_placed_trades += 1
                self.__trading_data.no_trades_placed_for = 0
                reward += self.__trading_consts.STATIC_REWARD_ADJUSTMENT
            else:
                self.__trading_data.no_trades_placed_for += 1
                reward -= self.__trading_consts.STATIC_REWARD_ADJUSTMENT
        else:
            self.__trading_data.no_trades_placed_for += 1
            if number_of_possible_trades == 0:
                reward += self.__trading_consts.STATIC_REWARD_ADJUSTMENT

        if number_of_possible_trades > 0:
            reward *= (1 - self.__trading_consts.PENALTY_FUNCTION(self.__trading_data.no_trades_placed_for)) \
                if reward > 0 else 1
            if self.__trading_consts.PENALTY_STOPS < self.__trading_data.no_trades_placed_for:
                reward -= self.__trading_consts.STATIC_REWARD_ADJUSTMENT
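        # The episode terminates when the data is exhausted, the budget grows beyond 10x the
        # initial budget, or total equity (budget + invested) falls below 60% of the initial budget.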

        if (self.current_iteration >= self.get_environment_length() - 1 or
                self.__trading_data.current_budget > 10 * self.__trading_consts.INITIAL_BUDGET or
                (self.__trading_data.current_budget + self.__trading_data.currently_invested) / self.__trading_consts.INITIAL_BUDGET < 0.6):
            done = True
        else:
            done = False

        info = {'coeff': stock_change_coeff,
                'iteration': self.current_iteration,
                'number_of_closed_orders': len(closed_orders),
                'money_to_trade': money_to_trade,
                'action': action,
                'current_budget': self.__trading_data.current_budget,
                'currently_invested': self.__trading_data.currently_invested,
                'no_trades_placed_for': self.__trading_data.no_trades_placed_for,
                'currently_placed_trades': self.__trading_data.currently_placed_trades}

        return self.state, reward, done, info

    def render(self) -> None:
        """
        Renders environment visualization. Will be implemented later.
        """

        # TODO: Visualization to be implemented
        pass

    def reset(self, randkey: Optional[int] = None) -> list[float]:
        """
        Resets the environment. Typically used when the environment is finished,
        i.e. when there are no more steps to be taken within the environment
        or the finish conditions are fulfilled.

        Parameters:
            randkey (Optional[int]): Value indicating which iteration
                should be treated as the starting point after the reset.

        Returns:
            (list[float]): Current iteration observation state.
        """

        if randkey is None:
            randkey = random.randint(self.__trading_consts.WINDOW_SIZE, self.get_environment_length() - 1)
        self.__trading_data.current_budget = self.__trading_consts.INITIAL_BUDGET
        self.__trading_data.currently_invested = 0
        self.__trading_data.no_trades_placed_for = 0
        self.__trading_data.currently_placed_trades = 0
        self.__broker.reset()
        self.current_iteration = randkey
        self.state = self.__get_current_state_data()

        return self.state
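A minimal usage sketch, assuming a price DataFrame with a 'close' column (plus whatever feature columns the annotator expects) and hypothetical ConcreteValidator / ConcreteAnnotator subclasses of RewardValidatorBase and LabelAnnotatorBase; the constructor values are illustrative only:

env = TradingEnvironment(data = price_df, initial_budget = 10_000.0, max_amount_of_trades = 5,
                         window_size = 30, validator = ConcreteValidator(),
                         label_annotator = ConcreteAnnotator(),
                         sell_stop_loss = 0.95, sell_take_profit = 1.05,
                         buy_stop_loss = 0.95, buy_take_profit = 1.05)

state = env.reset()
done = False
while not done:
    action = env.action_space.sample()  # 0 = buy, 1 = wait, 2 = sell
    state, reward, done, info = env.step(action)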