Coverage for source/environment/trading_environment.py: 81%
232 statements
coverage.py v7.8.0, created at 2025-09-29 18:17 +0000
# environment/trading_environment.py

# global imports
import copy
import logging
import math
import numpy as np
import pandas as pd
import random
from enum import Enum
from gym import Env
from gym.spaces import Box, Discrete
from sklearn.model_selection import train_test_split
from tensorflow.keras.utils import to_categorical
from types import SimpleNamespace
from typing import Any, Optional

# local imports
from source.environment import Broker, LabelAnnotatorBase, LabeledDataBalancer, RewardValidatorBase

class TradingEnvironment(Env):
    """
    Implements a stock market environment in which an actor can perform actions (place orders).
    It is used to train various models using various approaches. It can be
    configured to award points and impose penalties in several ways.
    """

    class TradingMode(Enum):
        """
        Enumeration for the different trading modes.
        """

        IMPLICIT_ORDER_CLOSING = 0
        EXPLICIT_ORDER_CLOSING = 1

    # global class constants
    TRAIN_MODE = 'train'
    TEST_MODE = 'test'

    def __init__(self, data: pd.DataFrame, initial_budget: float, max_amount_of_trades: int,
                 window_size: int, validator: RewardValidatorBase, label_annotator: LabelAnnotatorBase,
                 sell_stop_loss: float, sell_take_profit: float, buy_stop_loss: float, buy_take_profit: float,
                 test_ratio: float = 0.2, penalty_starts: int = 0, penalty_stops: int = 10,
                 static_reward_adjustment: float = 1, labeled_data_balancer: Optional[LabeledDataBalancer] = None,
                 meta_data: Optional[dict[str, Any]] = None, trading_mode: Optional[TradingMode] = None,
                 should_prefetch: bool = True) -> None:
        """
        Class constructor. Allows defining all crucial constants, reward validation methods,
        environmental penalty policies, etc.

        Parameters:
            data (pd.DataFrame): DataFrame containing historical market data.
            initial_budget (float): Initial budget the trader starts from.
            max_amount_of_trades (int): Maximum number of trades that can be ongoing at the same time.
                Setting this constant prevents traders from placing orders randomly and defines
                the amount of money that can be assigned to a single order at a given iteration.
            window_size (int): Constant defining how far into the past the trader can look
                at a given iteration.
            validator (RewardValidatorBase): Validator implementing the policy used to award points
                for closed trades.
            label_annotator (LabelAnnotatorBase): Annotator implementing the policy used to label
                data with target values. It provides supervised agents with the target class
                value for a given iteration.
            sell_stop_loss (float): Constant defining the losing boundary at which a sell order
                (short) is closed.
            sell_take_profit (float): Constant defining the winning boundary at which a sell order
                (short) is closed.
            buy_stop_loss (float): Constant defining the losing boundary at which a buy order
                (long) is closed.
            buy_take_profit (float): Constant defining the winning boundary at which a buy order
                (long) is closed.
            test_ratio (float): Ratio of the data that should be used for testing purposes.
            penalty_starts (int): Constant defining how many trading periods the trader can go without
                placing an order before a penalty is imposed. Between the start and stop constants the
                penalty is calculated as a fraction of the positive reward and subtracted from the actual reward.
            penalty_stops (int): Constant defining the trading period at which the penalty no longer increases.
                Reward for trading periods exceeding this constant equals minus the static reward adjustment.
            static_reward_adjustment (float): Constant used to penalize the trader for bad choices or
                reward it for good ones.
            labeled_data_balancer (Optional[LabeledDataBalancer]): Balancer used to balance
                labeled data. If None, no balancing will be performed.
            meta_data (Optional[dict[str, Any]]): Dictionary containing metadata about the dataset.
            trading_mode (Optional[TradingMode]): Mode of the environment, either IMPLICIT_ORDER_CLOSING
                or EXPLICIT_ORDER_CLOSING. Defaults to IMPLICIT_ORDER_CLOSING.
            should_prefetch (bool): If True, data will be pre-fetched to speed up training.
        """

        if test_ratio < 0.0 or test_ratio >= 1.0:
            raise ValueError(f"Invalid test_ratio: {test_ratio}. It should be in range [0, 1).")

        if trading_mode is None:
            trading_mode = TradingEnvironment.TradingMode.IMPLICIT_ORDER_CLOSING
        elif isinstance(trading_mode, int):
            trading_mode = TradingEnvironment.TradingMode(trading_mode)

        if not isinstance(trading_mode, TradingEnvironment.TradingMode):
            raise ValueError(f"Invalid trading_mode: {trading_mode}. It should be of type TradingEnvironment.TradingMode or int.")

        # Initializing the environment
        self.__data: dict[str, pd.DataFrame] = self.__split_data(data, test_ratio)
        self.__meta_data: Optional[dict[str, Any]] = meta_data
        self.__mode = TradingEnvironment.TRAIN_MODE
        self.__trading_mode: TradingEnvironment.TradingMode = trading_mode
        self.__should_prefetch: bool = should_prefetch
        self.__broker: Broker = Broker()
        self.__validator: RewardValidatorBase = validator
        self.__label_annotator: LabelAnnotatorBase = label_annotator
        self.__labeled_data_balancer: Optional[LabeledDataBalancer] = labeled_data_balancer

        # Setting up trading data
        self.__trading_data: SimpleNamespace = SimpleNamespace()
        self.__trading_data.current_budget: float = initial_budget
        self.__trading_data.currently_invested: float = 0
        self.__trading_data.no_trades_placed_for: int = 0
        self.__trading_data.currently_placed_long_trades: int = 0
        self.__trading_data.currently_placed_short_trades: int = 0

        # Setting up trading constants
        self.__trading_consts = SimpleNamespace()
        self.__trading_consts.INITIAL_BUDGET: float = initial_budget
        self.__trading_consts.MAX_AMOUNT_OF_TRADES: int = max_amount_of_trades
        self.__trading_consts.WINDOW_SIZE: int = window_size
        self.__trading_consts.SELL_STOP_LOSS: float = sell_stop_loss
        self.__trading_consts.SELL_TAKE_PROFIT: float = sell_take_profit
        self.__trading_consts.BUY_STOP_LOSS: float = buy_stop_loss
        self.__trading_consts.BUY_TAKE_PROFIT: float = buy_take_profit
        self.__trading_consts.STATIC_REWARD_ADJUSTMENT: float = static_reward_adjustment
        self.__trading_consts.PENALTY_STARTS: int = penalty_starts
        self.__trading_consts.PENALTY_STOPS: int = penalty_stops
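        # PROFITABILITY_FUNCTION maps the normalized budget x to 1 - e^(1 - x): it equals 0 when the
        # budget matches the initial budget, approaches 1 as the budget grows and turns negative
        # below it. PENALTY_FUNCTION ramps via a tanh curve from roughly 0 at penalty_starts up to 1
        # at penalty_stops and is capped at 1 beyond that.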
        self.__trading_consts.PROFITABILITY_FUNCTION = lambda x: -1.0 * math.exp(-x + 1) + 1
        self.__trading_consts.PENALTY_FUNCTION = lambda x: \
            min(1, 1 - math.tanh(-3.0 * (x - penalty_stops) / (penalty_stops - penalty_starts)))
        self.__trading_consts.OUTPUT_CLASSES: int = vars(self.__label_annotator.get_output_classes())

        # Prefetching data if needed
        if self.__should_prefetch:
            self.__prefetched_data = { TradingEnvironment.TRAIN_MODE: None,
                                       TradingEnvironment.TEST_MODE: None }
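            # Observations for both splits are prefetched by temporarily switching the mode, so later
            # calls to __get_current_state_data() only index a cached DataFrame.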
            self.__mode = TradingEnvironment.TEST_MODE
            self.__prefetched_data[self.__mode] = self.__prefetch_state_data(env_length_range = (self.__trading_consts.WINDOW_SIZE,
                                                                                                 self.get_environment_length()))
            self.__mode = TradingEnvironment.TRAIN_MODE
            self.__prefetched_data[self.__mode] = self.__prefetch_state_data(env_length_range = (self.__trading_consts.WINDOW_SIZE,
                                                                                                 self.get_environment_length()))
        else:
            self.__prefetched_data = None

        # Initializing the environment state
        self.current_iteration: int = self.__trading_consts.WINDOW_SIZE
        self.state: list[float] = self.__get_current_state_data()
        self.action_space: Discrete = Discrete(3)
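        # Observation features are standard-scaled, so values are expected to stay within a few
        # standard deviations of zero; the bounds below reflect that assumption.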
        self.observation_space: Box = Box(low = np.ones(len(self.state)) * -3,
                                          high = np.ones(len(self.state)) * 3,
                                          dtype = np.float64)

    def __split_data(self, data: pd.DataFrame, test_size: float) -> dict[str, pd.DataFrame]:
        """
        Splits the given DataFrame into training and testing sets based on the specified test size ratio.

        Parameters:
            data (pd.DataFrame): DataFrame containing the stock market data.
            test_size (float): Ratio of the data to be used for testing.

        Returns:
            (dict[str, pd.DataFrame]): Dictionary containing training and testing data frames.
        """

        dividing_index = int(len(data) * (1 - test_size))

        return {
            TradingEnvironment.TRAIN_MODE: data.iloc[:dividing_index].reset_index(drop=True),
            TradingEnvironment.TEST_MODE: data.iloc[dividing_index:].reset_index(drop=True)
        }

    def __standard_scale(self, data: np.ndarray) -> np.ndarray:
        """
        Standardizes the given data by removing the mean and scaling to unit variance.

        Parameters:
            data (np.ndarray): The data to be standardized.

        Returns:
            (np.ndarray): The standardized data.
        """

        mean = np.mean(data, axis = 0, keepdims = True)
        std = np.std(data, axis = 0, keepdims = True)
        std[std == 0] = 1

        return (data - mean) / std

    def __prepare_state_data(self, slice_to_get: slice, include_trading_data: bool = True) -> list[float]:
        """
        Calculates state data as a list of floats representing the current iteration's observation.
        The observation contains all input data refined to the window size and a couple of coefficients
        giving insight into the current budget and orders situation.

        Parameters:
            slice_to_get (slice): Slice to get the data from.
            include_trading_data (bool): If True, includes trading data in the observation.

        Returns:
            (list[float]): List with current observations for the environment.
        """

        current_market_data = self.__data[self.__mode].iloc[slice_to_get].copy()
        current_market_data_no_index = current_market_data.select_dtypes(include = [np.number])

        if self.__meta_data is not None and \
           self.__meta_data.get('normalization_groups', None) is not None:
            grouped_columns_names = self.__meta_data['normalization_groups']
            preprocessed_data_pieces = []
            left_over_columns_names = set(current_market_data_no_index.columns)
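            # Columns listed in the same normalization group are flattened and scaled together with one
            # shared mean and standard deviation, which preserves their relative magnitudes (e.g. price
            # columns such as open/high/low/close); columns outside any group are passed through as-is.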
            for columns_names_to_normalize in grouped_columns_names:
                left_over_columns_names -= set(columns_names_to_normalize)
                data_frame_piece_to_normalize = current_market_data_no_index[columns_names_to_normalize]
                normalized_data_frame_piece = self.__standard_scale(data_frame_piece_to_normalize.values.reshape(-1, 1))
                preprocessed_data_pieces.append(normalized_data_frame_piece.reshape(*data_frame_piece_to_normalize.shape))
            for column in left_over_columns_names:
                preprocessed_data_pieces.append(current_market_data_no_index[column].values.reshape(-1, 1))
            normalized_current_market_data_values = np.hstack(preprocessed_data_pieces)
        else:
            normalized_current_market_data_values = self.__standard_scale(current_market_data_no_index.values)
        current_marked_data_list = normalized_current_market_data_values.ravel().tolist()

        if include_trading_data:
            current_normalized_budget = 1.0 * self.__trading_data.current_budget / self.__trading_consts.INITIAL_BUDGET
            current_profitability_coeff = self.__trading_consts.PROFITABILITY_FUNCTION(current_normalized_budget)
            current_no_trades_penalty_coeff = self.__trading_consts.PENALTY_FUNCTION(self.__trading_data.no_trades_placed_for)
            current_long_trades_occupancy_coeff = 1.0 * self.__trading_data.currently_placed_long_trades / self.__trading_consts.MAX_AMOUNT_OF_TRADES
            current_short_trades_occupancy_coeff = 1.0 * self.__trading_data.currently_placed_short_trades / self.__trading_consts.MAX_AMOUNT_OF_TRADES
            current_inner_state_list = [current_profitability_coeff, current_no_trades_penalty_coeff,
                                        current_long_trades_occupancy_coeff, current_short_trades_occupancy_coeff]
            current_marked_data_list += current_inner_state_list

        return current_marked_data_list

    def __prefetch_state_data(self, env_length_range: tuple[int, int], include_trading_data: bool = True) -> pd.DataFrame:
        """
        Prefetches state data for the given environment length range.

        Parameters:
            env_length_range (tuple[int, int]): Range to limit the length of the environment.
            include_trading_data (bool): If True, includes trading data in the observation.

        Returns:
            (pd.DataFrame): DataFrame containing the pre-fetched state data.
        """

        new_rows = []
        for i in range(env_length_range[0], env_length_range[1]):
            data_row = self.__prepare_state_data(slice(i - self.__trading_consts.WINDOW_SIZE, i), include_trading_data = include_trading_data)
            new_rows.append(data_row)

        return pd.DataFrame(new_rows, columns = [f"feature_{i}" for i in range(len(new_rows[0]))])

    def __prepare_labeled_data(self, env_length_range: tuple[int, int]) -> tuple[pd.DataFrame, pd.Series]:
        """
        Prepares labeled data for training the model with a classification approach.
        It extracts the relevant features and labels from the environment's data.

        Parameters:
            env_length_range (tuple[int, int]): Range to limit the length of the environment.

        Returns:
            (tuple[pd.DataFrame, pd.Series]): A tuple containing the input data and output labels.
        """

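        # Feature rows are windows ending at iterations env_length_range[0]..env_length_range[1], so the
        # annotated labels are shifted back by env_length_range[0] rows to line up with them; dropna()
        # trims the trailing rows left without a label by the shift.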
        prefetched_data = self.__prefetch_state_data(env_length_range, include_trading_data = False)
        labels = self.__label_annotator.annotate(self.__data[self.__mode].iloc[:env_length_range[1]].copy()) \
            .shift(-env_length_range[0]).dropna()

        return prefetched_data, labels

    def __get_current_state_data(self) -> list[float]:
        """
        Retrieves the current state data from the environment.

        Returns:
            (list[float]): List with current observations for the environment.
        """

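        # Prefetched rows are indexed from 0 for the iteration equal to WINDOW_SIZE, hence the offset below.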
        if self.__should_prefetch:
            return self.__prefetched_data[self.__mode].iloc[self.current_iteration - self.__trading_consts.WINDOW_SIZE].values.ravel().tolist()

        return self.__prepare_state_data(slice_to_get = slice(self.current_iteration - self.__trading_consts.WINDOW_SIZE, self.current_iteration))

    def set_mode(self, mode: str) -> None:
        """
        Sets the mode of the environment to either TRAIN_MODE or TEST_MODE.

        Parameters:
            mode (str): Mode to set for the environment.

        Raises:
            ValueError: If the provided mode is not valid.
        """

        if mode not in [TradingEnvironment.TRAIN_MODE, TradingEnvironment.TEST_MODE]:
            raise ValueError(f"Invalid mode: {mode}. Use TradingEnvironment.TRAIN_MODE or TradingEnvironment.TEST_MODE.")
        self.__mode = mode

    def get_mode(self) -> str:
        """
        Mode getter.

        Returns:
            (str): Current mode of the environment.
        """

        return copy.copy(self.__mode)

    def get_trading_data(self) -> SimpleNamespace:
        """
        Trading data getter.

        Returns:
            (SimpleNamespace): Copy of the namespace with all trading data.
        """

        return copy.copy(self.__trading_data)

    def get_number_of_trading_points_per_year(self) -> int:
        """
        Returns the number of trading points per year.

        Returns:
            (int): Number of trading points per year.
        """

        temp_data = {"time": pd.to_datetime(self.__data[self.TRAIN_MODE]['time'])}
        temp_df = pd.DataFrame(temp_data)
        temp_df['year'] = temp_df['time'].dt.year

        trading_points_per_year = temp_df.groupby('year').size()
        if len(trading_points_per_year) > 3:
            # If there are more than three years, return the mode
            # of the central years
            return trading_points_per_year.iloc[1:-1].mode()[0]
        elif len(trading_points_per_year) > 2:
            # If there are exactly three years, return the middle year
            return trading_points_per_year.values[-2]
        else:
            # If there are at most two years, return the maximum
            return max(trading_points_per_year.values)

    def get_trading_consts(self) -> SimpleNamespace:
        """
        Trading constants getter.

        Returns:
            (SimpleNamespace): Copy of the namespace with all trading constants.
        """

        return copy.copy(self.__trading_consts)

    def get_broker(self) -> Broker:
        """
        Broker getter.

        Returns:
            (Broker): Copy of the broker used by the environment.
        """

        return copy.copy(self.__broker)

    def get_environment_length(self) -> int:
        """
        Environment length getter.

        Returns:
            (int): Length of the environment.
        """

        return len(self.__data[self.__mode])

    def get_environment_spatial_data_dimension(self) -> tuple[int, int]:
        """
        Environment spatial data dimensionality getter.

        Returns:
            (tuple[int, int]): Dimension of the spatial data in the environment.
        """

        return (self.__trading_consts.WINDOW_SIZE, self.__data[self.__mode].shape[1] - 1)

    def get_labeled_data(self, should_split: bool = True, should_balance: bool = True,
                         verbose: bool = True, env_length_range: Optional[tuple[int, int]] = None) \
                         -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
        """
        Prepares labeled data for training or testing the model.
        It extracts the relevant features and labels from the environment's data.

        Parameters:
            should_split (bool): Whether to split the data into training and testing sets.
                Defaults to True. If set to False, testing data will be empty.
            should_balance (bool): Whether to balance the labeled data. Defaults to True.
                Will be ignored if labeled_data_balancer is None.
            verbose (bool): Whether to log the class cardinality before and after balancing.
                Defaults to True.
            env_length_range (Optional[tuple[int, int]]): Optional range limiting the portion
                of the environment data used.

        Returns:
            (tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]): A tuple containing the
                input data, output data, test input data, and test output data.
        """

        if env_length_range is None:
            env_length_range = (self.__trading_consts.WINDOW_SIZE, self.get_environment_length() - 1)

        input_data, output_data = self.__prepare_labeled_data(env_length_range)
        input_data_test, output_data_test = [], []
        if verbose:
            logging.info(f"Original class cardinality: {np.array(to_categorical(output_data)).sum(axis = 0)}")

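        # Splitting into train/test folds and class balancing are only applied in train mode;
        # in test mode the labeled data is returned untouched.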
        if self.__mode == TradingEnvironment.TRAIN_MODE:
            if should_split:
                input_data, input_data_test, output_data, output_data_test = \
                    train_test_split(input_data, output_data, test_size = 0.1, random_state = 42,
                                     stratify = output_data)

            if self.__labeled_data_balancer is not None and should_balance:
                input_data, output_data = self.__labeled_data_balancer.balance(input_data, output_data)
                if verbose:
                    logging.info(f"Balanced class cardinality: {np.array(to_categorical(output_data)).sum(axis = 0)}")

        return copy.copy((np.array(input_data), np.array(output_data),
                          np.array(input_data_test), np.array(output_data_test)))

    def get_data_for_iteration(self, columns: list[str], start: int = 0, stop: Optional[int] = None,
                               step: int = 1) -> list[float]:
        """
        Data getter for specified iterations.

        Parameters:
            columns (list[str]): List of column names to extract from the data.
            start (int): Start iteration index. Defaults to 0.
            stop (Optional[int]): Stop iteration index. Defaults to the environment length minus one.
            step (int): Step between iterations. Defaults to 1.

        Returns:
            (list[float]): Copy of the part of the data with the specified columns
                over the specified iterations.
        """

        if stop is None:
            stop = self.get_environment_length() - 1

        return copy.copy(self.__data[self.__mode].loc[start:stop:step, columns].values.ravel().tolist())

    def step(self, action: int) -> tuple[list[float], float, bool, dict]:
        """
        Performs the specified action on the environment, which results in the generation of a new
        observation. This function causes trades to be handled, the reward to be calculated and
        the environment to be updated.

        Parameters:
            action (int): Number specifying the action. Possible values are 0 for the buy action,
                1 for the wait action and 2 for the sell action.

        Returns:
            (tuple[list[float], float, bool, dict]): Tuple containing the next observation
                state, the reward, the finish indication and an additional info dictionary.
        """

        self.current_iteration += 1
        self.state = self.__get_current_state_data()

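        # Relative close-price change between the previous and the current candle; it is passed to
        # the broker, which updates the open orders and returns those that got closed.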
        close_changes = self.__data[self.__mode].iloc[self.current_iteration - 2 : self.current_iteration]['close'].values
        stock_change_coeff = 1 + (close_changes[1] - close_changes[0]) / close_changes[0]
        closed_orders = self.__broker.update_orders(stock_change_coeff)

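        # In EXPLICIT_ORDER_CLOSING mode an action opposite to the direction of the most recently
        # placed order force-closes the currently open orders.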
        if self.__trading_mode == TradingEnvironment.TradingMode.EXPLICIT_ORDER_CLOSING:
            current_orders = self.__broker.get_current_orders()
            if len(current_orders) > 0:
                was_last_order_placed_as_buy = current_orders[-1].is_buy_order
                if (action == 0 and not was_last_order_placed_as_buy) or \
                   (action == 2 and was_last_order_placed_as_buy):
                    closed_orders += self.__broker.force_close_orders()

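        # The reward for this step starts from the validator's assessment of the closed orders; the
        # trade counters, the budget and the invested amount are then updated accordingly.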
        reward = self.__validator.validate_orders(closed_orders)
        number_of_closed_long_trades = len([order for order in closed_orders if order.is_buy_order])
        number_of_closed_short_trades = len(closed_orders) - number_of_closed_long_trades
        self.__trading_data.currently_placed_long_trades -= number_of_closed_long_trades
        self.__trading_data.currently_placed_short_trades -= number_of_closed_short_trades
        self.__trading_data.current_budget += np.sum([order.current_value for order in closed_orders])
        self.__trading_data.currently_invested -= np.sum([order.initial_value for order in closed_orders])

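        # The free budget is split evenly across the remaining free trade slots to size a new order.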
        number_of_possible_trades = self.__trading_consts.MAX_AMOUNT_OF_TRADES \
            - self.__trading_data.currently_placed_long_trades - self.__trading_data.currently_placed_short_trades
        money_to_trade = 0
        if number_of_possible_trades > 0:
            money_to_trade = 1.0 / number_of_possible_trades * self.__trading_data.current_budget

        if action == 0:
            is_buy_order = True
            stop_loss = self.__trading_consts.BUY_STOP_LOSS
            take_profit = self.__trading_consts.BUY_TAKE_PROFIT
        elif action == 2:
            is_buy_order = False
            stop_loss = self.__trading_consts.SELL_STOP_LOSS
            take_profit = self.__trading_consts.SELL_TAKE_PROFIT

        if action != 1:
            if number_of_possible_trades > 0:
                self.__trading_data.current_budget -= money_to_trade
                self.__trading_data.currently_invested += money_to_trade
                self.__broker.place_order(money_to_trade, is_buy_order, stop_loss, take_profit)
                self.__trading_data.no_trades_placed_for = 0
                reward += self.__trading_consts.STATIC_REWARD_ADJUSTMENT
                if action == 0:
                    self.__trading_data.currently_placed_long_trades += 1
                elif action == 2:
                    self.__trading_data.currently_placed_short_trades += 1
            else:
                self.__trading_data.no_trades_placed_for += 1
                reward -= self.__trading_consts.STATIC_REWARD_ADJUSTMENT
        else:
            self.__trading_data.no_trades_placed_for += 1
            if number_of_possible_trades == 0:
                reward += self.__trading_consts.STATIC_REWARD_ADJUSTMENT

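        # While free trade slots remain, a positive reward is scaled down the longer no order has
        # been placed, and an additional flat penalty applies once PENALTY_STOPS is exceeded.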
        if number_of_possible_trades > 0:
            reward *= (1 - self.__trading_consts.PENALTY_FUNCTION(self.__trading_data.no_trades_placed_for)) \
                if reward > 0 else 1
            if self.__trading_consts.PENALTY_STOPS < self.__trading_data.no_trades_placed_for:
                reward -= self.__trading_consts.STATIC_REWARD_ADJUSTMENT

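        # The episode ends when the data runs out, the budget grows tenfold, or the total equity
        # (budget plus invested money) drops below 60% of the initial budget.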
        if (self.current_iteration >= self.get_environment_length() - 1 or
                self.__trading_data.current_budget > 10 * self.__trading_consts.INITIAL_BUDGET or
                (self.__trading_data.current_budget + self.__trading_data.currently_invested) / self.__trading_consts.INITIAL_BUDGET < 0.6):
            done = True
        else:
            done = False

        info = {'coeff': stock_change_coeff,
                'iteration': self.current_iteration,
                'number_of_closed_orders': len(closed_orders),
                'money_to_trade': money_to_trade,
                'action': action,
                'current_budget': self.__trading_data.current_budget,
                'currently_invested': self.__trading_data.currently_invested,
                'no_trades_placed_for': self.__trading_data.no_trades_placed_for,
                'currently_placed_long_trades': self.__trading_data.currently_placed_long_trades,
                'currently_placed_short_trades': self.__trading_data.currently_placed_short_trades}

        return self.state, reward, done, info

    def render(self) -> None:
        """
        Renders environment visualization. Will be implemented later.
        """

        # TODO: Visualization to be implemented
        pass

    def reset(self, randkey: Optional[int] = None) -> list[float]:
        """
        Resets the environment. Typically used when the environment is finished,
        i.e. when there are no more steps to be taken within the environment
        or the finish conditions are fulfilled.

        Parameters:
            randkey (Optional[int]): Value indicating which iteration
                should be treated as the starting point after the reset.

        Returns:
            (list[float]): Current iteration observation state.
        """

        if randkey is None:
            randkey = random.randint(self.__trading_consts.WINDOW_SIZE, self.get_environment_length() - 1)
        self.__trading_data.current_budget = self.__trading_consts.INITIAL_BUDGET
        self.__trading_data.currently_invested = 0
        self.__trading_data.no_trades_placed_for = 0
        self.__trading_data.currently_placed_long_trades = 0
        self.__trading_data.currently_placed_short_trades = 0
        self.__broker.reset()
        self.current_iteration = randkey
        self.state = self.__get_current_state_data()

        return self.state