Coverage for source/environment/trading_environment.py: 81%

220 statements  

coverage.py v7.8.0, created at 2025-09-14 17:40 +0000

# environment/trading_environment.py

# global imports
import copy
import logging
import math
import numpy as np
import pandas as pd
import random
from enum import Enum
from gym import Env
from gym.spaces import Box, Discrete
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
from tensorflow.keras.utils import to_categorical
from types import SimpleNamespace
from typing import Any, Optional

# local imports
from source.environment import Broker, LabelAnnotatorBase, LabeledDataBalancer, RewardValidatorBase

class TradingEnvironment(Env):
    """
    Implements a stock market environment in which an actor can perform actions (place orders).
    It is used to train various models using various approaches and can be
    configured to award points and impose penalties in several ways.
    """

    class TradingMode(Enum):
        """
        Enumeration for the different trading modes.
        """

        IMPLICIT_ORDER_CLOSING = 0
        EXPLICIT_ORDER_CLOSING = 1

    # global class constants
    TRAIN_MODE = 'train'
    TEST_MODE = 'test'

    def __init__(self, data: pd.DataFrame, initial_budget: float, max_amount_of_trades: int,
                 window_size: int, validator: RewardValidatorBase, label_annotator: LabelAnnotatorBase,
                 sell_stop_loss: float, sell_take_profit: float, buy_stop_loss: float, buy_take_profit: float,
                 test_ratio: float = 0.2, penalty_starts: int = 0, penalty_stops: int = 10,
                 static_reward_adjustment: float = 1, labeled_data_balancer: Optional[LabeledDataBalancer] = None,
                 meta_data: Optional[dict[str, Any]] = None, trading_mode: Optional[TradingMode] = None,
                 should_prefetch: bool = True) -> None:
        """
        Class constructor. Allows defining all crucial constants, reward validation methods,
        environmental penalty policies, etc.

        Parameters:
            data (pd.DataFrame): DataFrame containing historical market data.
            initial_budget (float): Initial budget constant for the trader to start from.
            max_amount_of_trades (int): Max amount of trades that can be ongoing at the same time.
                Setting this constant prevents traders from placing orders randomly and defines
                the amount of money that can be assigned to a single order at a certain iteration.
            window_size (int): Constant defining how far into the past the trader is able to look
                at a certain iteration.
            validator (RewardValidatorBase): Validator implementing the policy used to award points
                for closed trades.
            label_annotator (LabelAnnotatorBase): Annotator implementing the policy used to label
                data with target values. It is used to provide supervised agents with information
                about what the target class value is for a certain iteration.
            sell_stop_loss (float): Constant used to define the losing boundary at which a sell order
                (short) is closed.
            sell_take_profit (float): Constant used to define the winning boundary at which a sell order
                (short) is closed.
            buy_stop_loss (float): Constant used to define the losing boundary at which a buy order
                (long) is closed.
            buy_take_profit (float): Constant used to define the winning boundary at which a buy order
                (long) is closed.
            test_ratio (float): Ratio of data that should be used for testing purposes.
            penalty_starts (int): Constant defining how many trading periods the trader can go without
                placing an order until a penalty is imposed. The penalty in the range between the start
                and stop constants is calculated as a fraction of the positive reward and subtracted
                from the actual reward.
            penalty_stops (int): Constant defining the trading period at which the penalty no longer
                increases. The reward for trading periods exceeding the penalty stop constant equals
                minus the static reward adjustment.
            static_reward_adjustment (float): Constant used to penalize the trader for bad choices or
                reward it for good ones.
            labeled_data_balancer (Optional[LabeledDataBalancer]): Balancer used to balance
                labeled data. If None, no balancing will be performed.
            meta_data (Optional[dict[str, Any]]): Dictionary containing metadata about the dataset.
            trading_mode (Optional[TradingMode]): Mode of the environment, either IMPLICIT_ORDER_CLOSING
                or EXPLICIT_ORDER_CLOSING. Defaults to IMPLICIT_ORDER_CLOSING.
            should_prefetch (bool): If True, data will be pre-fetched to speed up training.
        """

        if test_ratio < 0.0 or test_ratio >= 1.0:
            raise ValueError(f"Invalid test_ratio: {test_ratio}. It should be in range [0, 1).")

        if trading_mode is None:
            trading_mode = TradingEnvironment.TradingMode.IMPLICIT_ORDER_CLOSING

        # Initializing the environment
        self.__data: dict[str, pd.DataFrame] = self.__split_data(data, test_ratio)
        self.__meta_data: Optional[dict[str, Any]] = meta_data
        self.__mode = TradingEnvironment.TRAIN_MODE
        self.__trading_mode: TradingEnvironment.TradingMode = trading_mode
        self.__should_prefetch: bool = should_prefetch
        self.__broker: Broker = Broker()
        self.__validator: RewardValidatorBase = validator
        self.__label_annotator: LabelAnnotatorBase = label_annotator
        self.__labeled_data_balancer: Optional[LabeledDataBalancer] = labeled_data_balancer

        # Setting up trading data
        self.__trading_data: SimpleNamespace = SimpleNamespace()
        self.__trading_data.current_budget: float = initial_budget
        self.__trading_data.currently_invested: float = 0
        self.__trading_data.no_trades_placed_for: int = 0
        self.__trading_data.currently_placed_trades: int = 0

        # Setting up trading constants
        self.__trading_consts = SimpleNamespace()
        self.__trading_consts.INITIAL_BUDGET: float = initial_budget
        self.__trading_consts.MAX_AMOUNT_OF_TRADES: int = max_amount_of_trades
        self.__trading_consts.WINDOW_SIZE: int = window_size
        self.__trading_consts.SELL_STOP_LOSS: float = sell_stop_loss
        self.__trading_consts.SELL_TAKE_PROFIT: float = sell_take_profit
        self.__trading_consts.BUY_STOP_LOSS: float = buy_stop_loss
        self.__trading_consts.BUY_TAKE_PROFIT: float = buy_take_profit
        self.__trading_consts.STATIC_REWARD_ADJUSTMENT: float = static_reward_adjustment
        self.__trading_consts.PENALTY_STARTS: int = penalty_starts
        self.__trading_consts.PENALTY_STOPS: int = penalty_stops
        self.__trading_consts.PROFITABILITY_FUNCTION = lambda x: -1.0 * math.exp(-x + 1) + 1
        self.__trading_consts.PENALTY_FUNCTION = lambda x: \
            min(1, 1 - math.tanh(-3.0 * (x - penalty_stops) / (penalty_stops - penalty_starts)))
        self.__trading_consts.OUTPUT_CLASSES: int = vars(self.__label_annotator.get_output_classes())
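        # Note on the two shaping functions above: PROFITABILITY_FUNCTION maps the
        # normalized budget (current_budget / INITIAL_BUDGET) to roughly (-1.72, 1),
        # crossing 0 at break-even (x = 1) and approaching 1 as the budget grows.
        # PENALTY_FUNCTION ramps from ~0 at no_trades_placed_for == PENALTY_STARTS
        # (1 - tanh(3) ~= 0.005) up to 1 at PENALTY_STOPS, and is capped at 1 beyond that.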

        # Prefetching data if needed
        if self.__should_prefetch:
            self.__prefetched_data = { TradingEnvironment.TRAIN_MODE: None,
                                       TradingEnvironment.TEST_MODE: None }
            self.__mode = TradingEnvironment.TEST_MODE
            self.__prefetched_data[self.__mode] = self.__prefetch_state_data(env_length_range = (self.__trading_consts.WINDOW_SIZE,
                                                                                                  self.get_environment_length()))
            self.__mode = TradingEnvironment.TRAIN_MODE
            self.__prefetched_data[self.__mode] = self.__prefetch_state_data(env_length_range = (self.__trading_consts.WINDOW_SIZE,
                                                                                                  self.get_environment_length()))
        else:
            self.__prefetched_data = None
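        # Prefetching builds one observation row per iteration for each data split,
        # so __get_current_state_data() becomes a cheap row lookup. Note that the
        # prefetched rows include the three trading coefficients computed at
        # construction time (initial budget, no open trades), so with prefetching
        # enabled those coefficients do not change between steps.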

        # Initializing the environment state
        self.current_iteration: int = self.__trading_consts.WINDOW_SIZE
        self.state: list[float] = self.__get_current_state_data()
        self.action_space: Discrete = Discrete(3)
        self.observation_space: Box = Box(low = np.ones(len(self.state)) * -3,
                                          high = np.ones(len(self.state)) * 3,
                                          dtype = np.float64)
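        # The +/-3 observation bounds assume the standardized market features and the
        # bounded trading coefficients rarely leave three standard deviations; values
        # outside this box are not clipped anywhere in this class.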

    def __split_data(self, data: pd.DataFrame, test_size: float) -> dict[str, pd.DataFrame]:
        """
        Splits the given DataFrame into training and testing sets based on the specified test size ratio.

        Parameters:
            data (pd.DataFrame): DataFrame containing the stock market data.
            test_size (float): Ratio of the data to be used for testing.

        Returns:
            (dict[str, pd.DataFrame]): Dictionary containing training and testing data frames.
        """

        dividing_index = int(len(data) * (1 - test_size))

        return {
            TradingEnvironment.TRAIN_MODE: data.iloc[:dividing_index].reset_index(drop=True),
            TradingEnvironment.TEST_MODE: data.iloc[dividing_index:].reset_index(drop=True)
        }
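    # Note: the train/test split above is chronological (no shuffling), so, assuming the
    # input data is time-ordered, the evaluation period lies strictly after the training period.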

    def __standard_scale(self, data: np.ndarray) -> np.ndarray:
        """
        Standardizes the given data by removing the mean and scaling to unit variance.

        Parameters:
            data (np.ndarray): The data to be standardized.

        Returns:
            (np.ndarray): The standardized data.
        """

        mean = np.mean(data, axis = 0, keepdims = True)
        std = np.std(data, axis = 0, keepdims = True)
        std[std == 0] = 1

        return (data - mean) / std
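    # Example of the per-column z-scoring above (with the zero-variance guard): for a
    # column [10, 20, 30] the mean is 20 and the population std is ~8.165, giving
    # [-1.22, 0.0, 1.22]; a constant column keeps std = 1 and scales to all zeros.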

    def __prepare_state_data(self, slice_to_get: slice, include_trading_data: bool = True) -> list[float]:
        """
        Calculates state data as a list of floats representing the current iteration's observation.
        Each observation contains all input data refined to the window size and a couple of coefficients
        giving insight into the current budget and orders situation.

        Parameters:
            slice_to_get (slice): Slice to get the data from.
            include_trading_data (bool): If True, includes trading data in the observation.

        Returns:
            (list[float]): List with current observations for environment.
        """

        current_market_data = self.__data[self.__mode].iloc[slice_to_get].copy()
        current_market_data_no_index = current_market_data.select_dtypes(include = [np.number])

        if self.__meta_data is not None and \
           self.__meta_data.get('normalization_groups', None) is not None:
            grouped_columns_names = self.__meta_data['normalization_groups']
            preprocessed_data_pieces = []
            left_over_columns_names = set(current_market_data_no_index.columns)
            for columns_names_to_normalize in grouped_columns_names:
                left_over_columns_names -= set(columns_names_to_normalize)
                data_frame_piece_to_normalize = current_market_data_no_index[columns_names_to_normalize]
                normalized_data_frame_piece = self.__standard_scale(data_frame_piece_to_normalize.values.reshape(-1, 1))
                preprocessed_data_pieces.append(normalized_data_frame_piece.reshape(*data_frame_piece_to_normalize.shape))
            for column in left_over_columns_names:
                preprocessed_data_pieces.append(current_market_data_no_index[column].values.reshape(-1, 1))
            normalized_current_market_data_values = np.hstack(preprocessed_data_pieces)
        else:
            normalized_current_market_data_values = self.__standard_scale(current_market_data_no_index.values)
        current_market_data_list = normalized_current_market_data_values.ravel().tolist()

        if include_trading_data:
            current_normalized_budget = 1.0 * self.__trading_data.current_budget / self.__trading_consts.INITIAL_BUDGET
            current_profitability_coeff = self.__trading_consts.PROFITABILITY_FUNCTION(current_normalized_budget)
            current_trades_occupancy_coeff = 1.0 * self.__trading_data.currently_placed_trades / self.__trading_consts.MAX_AMOUNT_OF_TRADES
            current_no_trades_penalty_coeff = self.__trading_consts.PENALTY_FUNCTION(self.__trading_data.no_trades_placed_for)
            current_inner_state_list = [current_profitability_coeff, current_trades_occupancy_coeff, current_no_trades_penalty_coeff]
            current_market_data_list += current_inner_state_list

        return current_market_data_list
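    # Layout of the returned vector: the window of numeric columns flattened row by
    # row (window_size * number_of_numeric_columns values). When normalization groups
    # are provided, each group is scaled together and columns outside any group are
    # passed through unscaled; otherwise every column is standardized independently.
    # The three trading coefficients are appended at the end when requested.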

    def __prefetch_state_data(self, env_length_range: tuple[int, int], include_trading_data: bool = True) -> pd.DataFrame:
        """
        Prefetches state data for the given environment length range.

        Parameters:
            env_length_range (tuple[int, int]): Range to limit the length of the environment.
            include_trading_data (bool): If True, includes trading data in the observation.

        Returns:
            (pd.DataFrame): DataFrame containing the pre-fetched state data.
        """

        new_rows = []
        for i in range(env_length_range[0], env_length_range[1]):
            data_row = self.__prepare_state_data(slice(i - self.__trading_consts.WINDOW_SIZE, i), include_trading_data = include_trading_data)
            new_rows.append(data_row)

        return pd.DataFrame(new_rows, columns = [f"feature_{i}" for i in range(len(new_rows[0]))])

    def __prepare_labeled_data(self, env_length_range: tuple[int, int]) -> tuple[pd.DataFrame, pd.Series]:
        """
        Prepares labeled data for training the model with a classification approach.
        It extracts the relevant features and labels from the environment's data.

        Parameters:
            env_length_range (tuple[int, int]): Range to limit the length of the environment.

        Returns:
            (tuple[pd.DataFrame, pd.Series]): A tuple containing the input data and output labels.
        """

        prefetched_data = self.__prefetch_state_data(env_length_range, include_trading_data = False)
        labels = self.__label_annotator.annotate(self.__data[self.__mode]. \
                                                 iloc[:env_length_range[1]].copy()).shift(-env_length_range[0]).dropna()

        return prefetched_data, labels
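    # Assuming the annotator returns labels indexed like its input, the shift by
    # -env_length_range[0] aligns row i of the prefetched features (the window of
    # WINDOW_SIZE rows preceding index env_length_range[0] + i) with the label
    # produced for index env_length_range[0] + i.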

    def __get_current_state_data(self) -> list[float]:
        """
        Retrieves the current state data from the environment.

        Returns:
            (list[float]): List with current observations for environment.
        """

        if self.__should_prefetch:
            return self.__prefetched_data[self.__mode].iloc[self.current_iteration - self.__trading_consts.WINDOW_SIZE].values.ravel().tolist()

        return self.__prepare_state_data(slice_to_get = slice(self.current_iteration - self.__trading_consts.WINDOW_SIZE, self.current_iteration))

    def set_mode(self, mode: str) -> None:
        """
        Sets the mode of the environment to either TRAIN_MODE or TEST_MODE.

        Parameters:
            mode (str): Mode to set for the environment.

        Raises:
            ValueError: If the provided mode is not valid.
        """

        if mode not in [TradingEnvironment.TRAIN_MODE, TradingEnvironment.TEST_MODE]:
            raise ValueError(f"Invalid mode: {mode}. Use TradingEnvironment.TRAIN_MODE or TradingEnvironment.TEST_MODE.")
        self.__mode = mode

    def get_mode(self) -> str:
        """
        Mode getter.

        Returns:
            (str): Current mode of the environment.
        """

        return copy.copy(self.__mode)

    def get_trading_data(self) -> SimpleNamespace:
        """
        Trading data getter.

        Returns:
            (SimpleNamespace): Copy of the namespace with all trading data.
        """

        return copy.copy(self.__trading_data)

    def get_number_of_trading_points_per_year(self) -> int:
        """
        Returns the number of trading points per year.

        Returns:
            (int): Number of trading points per year.
        """

        temp_data = {"time": pd.to_datetime(self.__data[self.TRAIN_MODE]['time'])}
        temp_df = pd.DataFrame(temp_data)
        temp_df['year'] = temp_df['time'].dt.year

        trading_points_per_year = temp_df.groupby('year').size()
        if len(trading_points_per_year) > 3:
            # If there are more than three years, return the mode
            # of the counts for the central years
            return trading_points_per_year.iloc[1:-1].mode()[0]
        elif len(trading_points_per_year) > 2:
            # If there are exactly three years, return the count for the middle year
            return trading_points_per_year.values[-2]
        else:
            # If there are only one or two years, return the maximum count
            return max(trading_points_per_year.values)
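    # The first and last calendar years of a dataset are typically partial, which is
    # presumably why the central years are preferred whenever enough years are available.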

    def get_trading_consts(self) -> SimpleNamespace:
        """
        Trading constants getter.

        Returns:
            (SimpleNamespace): Copy of the namespace with all trading constants.
        """

        return copy.copy(self.__trading_consts)

    def get_broker(self) -> Broker:
        """
        Broker getter.

        Returns:
            (Broker): Copy of the broker used by environment.
        """

        return copy.copy(self.__broker)

    def get_environment_length(self) -> int:
        """
        Environment length getter.

        Returns:
            (int): Length of the environment (number of rows in the current mode's data).
        """

        return len(self.__data[self.__mode])

    def get_environment_spatial_data_dimension(self) -> tuple[int, int]:
        """
        Environment spatial data dimensionality getter.

        Returns:
            (tuple[int, int]): Dimension of spatial data in the environment.
        """

        return (self.__trading_consts.WINDOW_SIZE, self.__data[self.__mode].shape[1] - 1)
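    # The "- 1" above assumes exactly one non-numeric column (e.g. the 'time' column)
    # is excluded from the observation features.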

    def get_labeled_data(self, should_split: bool = True, should_balance: bool = True,
                         verbose: bool = True, env_length_range: Optional[tuple[int, int]] = None) \
            -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
        """
        Prepares labeled data for training or testing the model.
        It extracts the relevant features and labels from the environment's data.

        Parameters:
            should_split (bool): Whether to split the data into training and testing sets.
                Defaults to True. If set to False, testing data will be empty.
            should_balance (bool): Whether to balance the labeled data. Defaults to True.
                Will be ignored if labeled_data_balancer is None.
            verbose (bool): Whether to log the class cardinality before and after balancing.
                Defaults to True.
            env_length_range (Optional[tuple[int, int]]): Optional range limiting the part of the
                environment to use.

        Returns:
            (tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]): A tuple containing the
                input data, output data, test input data, and test output data.
        """

        if env_length_range is None:
            env_length_range = (self.__trading_consts.WINDOW_SIZE, self.get_environment_length() - 1)

        input_data, output_data = self.__prepare_labeled_data(env_length_range)
        input_data_test, output_data_test = [], []
        if verbose:
            logging.info(f"Original class cardinality: {np.array(to_categorical(output_data)).sum(axis = 0)}")

        if self.__mode == TradingEnvironment.TRAIN_MODE:
            if should_split:
                input_data, input_data_test, output_data, output_data_test = \
                    train_test_split(input_data, output_data, test_size = 0.1, random_state = 42,
                                     stratify = output_data)

            if self.__labeled_data_balancer is not None and should_balance:
                input_data, output_data = self.__labeled_data_balancer.balance(input_data, output_data)
                if verbose:
                    logging.info(f"Balanced class cardinality: {np.array(to_categorical(output_data)).sum(axis = 0)}")

        return copy.copy((np.array(input_data), np.array(output_data),
                          np.array(input_data_test), np.array(output_data_test)))
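    # Hypothetical usage for a supervised classifier (the model object is a
    # placeholder, not part of this module):
    #     x_train, y_train, x_val, y_val = env.get_labeled_data()
    #     model.fit(x_train, to_categorical(y_train),
    #               validation_data = (x_val, to_categorical(y_val)))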

    def get_data_for_iteration(self, columns: list[str], start: int = 0, stop: Optional[int] = None,
                               step: int = 1) -> list[float]:
        """
        Data getter for certain iterations.

        Parameters:
            columns (list[str]): List of column names to extract from data.
            start (int): Start iteration index. Defaults to 0.
            stop (Optional[int]): Stop iteration index. Defaults to environment length minus one.
            step (int): Step between iterations. Defaults to 1.

        Returns:
            (list[float]): Copy of part of the data with specified columns
                over specified iterations.
        """

        if stop is None:
            stop = self.get_environment_length() - 1

        return copy.copy(self.__data[self.__mode].loc[start:stop:step, columns].values.ravel().tolist())

    def step(self, action: int) -> tuple[list[float], float, bool, dict]:
        """
        Performs the specified action on the environment, resulting in generation of the new
        observation. This function causes trades to be handled, the reward to be calculated and
        the environment to be updated.

        Parameters:
            action (int): Number specifying the action. Possible values are 0 for the buy action,
                1 for the wait action and 2 for the sell action.

        Returns:
            (tuple[list[float], float, bool, dict]): Tuple containing the next observation
                state, reward, finish indication and additional info dictionary.
        """

        self.current_iteration += 1
        self.state = self.__get_current_state_data()

        close_changes = self.__data[self.__mode].iloc[self.current_iteration - 2 : self.current_iteration]['close'].values
        stock_change_coeff = 1 + (close_changes[1] - close_changes[0]) / close_changes[0]
        closed_orders = self.__broker.update_orders(stock_change_coeff)
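        # stock_change_coeff is the ratio of the two closes preceding the new current
        # iteration (close[t-1] / close[t-2]) and is passed to the broker to update
        # the open orders.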

        if self.__trading_mode == TradingEnvironment.TradingMode.EXPLICIT_ORDER_CLOSING:
            current_orders = self.__broker.get_current_orders()
            if len(current_orders) > 0:
                was_last_order_placed_as_buy = current_orders[-1].is_buy_order
                if (action == 0 and not was_last_order_placed_as_buy) or \
                   (action == 2 and was_last_order_placed_as_buy):
                    closed_orders += self.__broker.force_close_orders()
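        # In EXPLICIT_ORDER_CLOSING mode, an action opposite to the most recently
        # placed order (buy while the last order was a sell, or sell while it was a
        # buy) additionally force-closes orders through the broker.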

        reward = self.__validator.validate_orders(closed_orders)
        self.__trading_data.currently_placed_trades -= len(closed_orders)
        self.__trading_data.current_budget += np.sum([trade.current_value for trade in closed_orders])
        self.__trading_data.currently_invested -= np.sum([trade.initial_value for trade in closed_orders])

        number_of_possible_trades = self.__trading_consts.MAX_AMOUNT_OF_TRADES - self.__trading_data.currently_placed_trades
        money_to_trade = 0
        if number_of_possible_trades > 0:
            money_to_trade = 1.0 / number_of_possible_trades * self.__trading_data.current_budget

        if action == 0:
            is_buy_order = True
            stop_loss = self.__trading_consts.BUY_STOP_LOSS
            take_profit = self.__trading_consts.BUY_TAKE_PROFIT
        elif action == 2:
            is_buy_order = False
            stop_loss = self.__trading_consts.SELL_STOP_LOSS
            take_profit = self.__trading_consts.SELL_TAKE_PROFIT

        if action != 1:
            if number_of_possible_trades > 0:
                self.__trading_data.current_budget -= money_to_trade
                self.__trading_data.currently_invested += money_to_trade
                self.__broker.place_order(money_to_trade, is_buy_order, stop_loss, take_profit)
                self.__trading_data.currently_placed_trades += 1
                self.__trading_data.no_trades_placed_for = 0
                reward += self.__trading_consts.STATIC_REWARD_ADJUSTMENT
            else:
                self.__trading_data.no_trades_placed_for += 1
                reward -= self.__trading_consts.STATIC_REWARD_ADJUSTMENT
        else:
            self.__trading_data.no_trades_placed_for += 1
            if number_of_possible_trades == 0:
                reward += self.__trading_consts.STATIC_REWARD_ADJUSTMENT
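        # Static reward shaping: placing an order earns the adjustment, trying to trade
        # with no free slots loses it, and waiting is rewarded only when every trade
        # slot is already occupied.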

        if number_of_possible_trades > 0:
            reward *= (1 - self.__trading_consts.PENALTY_FUNCTION(self.__trading_data.no_trades_placed_for)) \
                if reward > 0 else 1
            if self.__trading_consts.PENALTY_STOPS < self.__trading_data.no_trades_placed_for:
                reward -= self.__trading_consts.STATIC_REWARD_ADJUSTMENT
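        # While trade slots are free, a positive reward is scaled down by the idle
        # penalty (reaching zero once the idle streak hits PENALTY_STOPS), and past
        # PENALTY_STOPS a flat STATIC_REWARD_ADJUSTMENT is subtracted on top.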

        if (self.current_iteration >= self.get_environment_length() - 1 or
                self.__trading_data.current_budget > 10 * self.__trading_consts.INITIAL_BUDGET or
                (self.__trading_data.current_budget + self.__trading_data.currently_invested) / self.__trading_consts.INITIAL_BUDGET < 0.8):
            done = True
        else:
            done = False

        info = {'coeff': stock_change_coeff,
                'iteration': self.current_iteration,
                'number_of_closed_orders': len(closed_orders),
                'money_to_trade': money_to_trade,
                'action': action,
                'current_budget': self.__trading_data.current_budget,
                'currently_invested': self.__trading_data.currently_invested,
                'no_trades_placed_for': self.__trading_data.no_trades_placed_for,
                'currently_placed_trades': self.__trading_data.currently_placed_trades}

        return self.state, reward, done, info

    def render(self) -> None:
        """
        Renders environment visualization. Will be implemented later.
        """

        # TODO: Visualization to be implemented
        pass

    def reset(self, randkey: Optional[int] = None) -> list[float]:
        """
        Resets the environment. Used typically when the environment is finished,
        i.e. when there are no more steps to be taken within the environment
        or the finish conditions are fulfilled.

        Parameters:
            randkey (Optional[int]): Value indicating which iteration
                should be treated as the starting point after reset.

        Returns:
            (list[float]): Current iteration observation state.
        """

        if randkey is None:
            randkey = random.randint(self.__trading_consts.WINDOW_SIZE, self.get_environment_length() - 1)
        self.__trading_data.current_budget = self.__trading_consts.INITIAL_BUDGET
        self.__trading_data.currently_invested = 0
        self.__trading_data.no_trades_placed_for = 0
        self.__trading_data.currently_placed_trades = 0
        self.__broker.reset()
        self.current_iteration = randkey
        self.state = self.__get_current_state_data()

        return self.state
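# A minimal, hypothetical interaction loop (assumes `env` is a constructed
# TradingEnvironment with concrete validator/annotator/balancer instances,
# which live outside this module):
#     state = env.reset()
#     done = False
#     while not done:
#         action = env.action_space.sample()   # stand-in for a trained policy
#         state, reward, done, info = env.step(action)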