Coverage for source/training/training_handler.py: 95%

128 statements  

« prev     ^ index     » next       coverage.py v7.8.0, created at 2025-05-30 15:13 +0000

1# training/training_handler.py 

2 

3import logging 

4import io 

5import matplotlib.pyplot as plt 

6from tensorflow.keras.callbacks import Callback 

7from typing import Optional 

8from reportlab.pdfgen import canvas 

9from reportlab.lib.pagesizes import letter 

10from reportlab.lib.units import inch 

11from reportlab.lib.utils import ImageReader 

12 

13from .training_config import TrainingConfig 

14from ..environment.trading_environment import TradingEnvironment 

15from ..agent.agent_handler import AgentHandler 

16from ..plotting.plot_responsibility_chain_base import PlotResponsibilityChainBase 

17from ..plotting.plot_testing_history_responsibility_chain import PlotTestingHistoryResponsibilityChain 

18from ..plotting.plot_training_history_responsibility_chain import PlotTrainingHistoryResponsibilityChain 

19 

class TrainingHandler():
    """
    Responsible for orchestrating the training process and report generation.

    This class manages the complete training workflow, from initializing the
    environment and agent, running training and testing sessions, to generating
    PDF reports with performance visualizations and logs. It serves as the main
    entry point for executing and documenting trading agent training.
    """

    # PDF report related constants
    HEADING_SPACING = 20
    CAPTION_FONT_SIZE = 14
    TEXT_FONT_SIZE = 8
    FONT_NAME = 'Courier'
    MARGINS = {
        'left': 30,
        'right': 30,
        'top': 30,
        'bottom': 30
    }
    EXCLUDE_FROM_LOGS = ['ETA']
    # Every margins dictionary passed to the constructor must provide all of these keys.
    _REQUIRED_MARGIN_KEYS = ('left', 'right', 'top', 'bottom')

    def __init__(self, config: TrainingConfig, page_width: float = letter[0], page_height: float = letter[1],
                 heading_spacing: int = HEADING_SPACING, caption_font_size: int = CAPTION_FONT_SIZE,
                 text_font_size: int = TEXT_FONT_SIZE, font_name: str = FONT_NAME,
                 margins: dict[str, int] = MARGINS, exclude_from_logs: list[str] = EXCLUDE_FROM_LOGS) -> None:
        """
        Initializes the training handler with configuration parameters.

        Parameters:
            config (TrainingConfig): Configuration containing environment and agent settings.
            page_width (float): Width of PDF report pages in points.
            page_height (float): Height of PDF report pages in points.
            heading_spacing (int): Spacing between headings and content in points.
            caption_font_size (int): Font size for section captions.
            text_font_size (int): Font size for main text content.
            font_name (str): Font family to use for text in reports.
            margins (dict[str, int]): Dictionary with 'left', 'right', 'top', and 'bottom' margins.
            exclude_from_logs (list[str]): Log lines containing any of these terms are
                omitted from the report.

        Raises:
            ValueError: If the margins dictionary is missing any of the required keys.
        """

        # Validate the cheap arguments first so a bad margins dict fails fast, before the
        # (potentially expensive) environment and agent are instantiated. Every key is
        # required because the drawing code indexes all four of them.
        if not all(key in margins for key in self._REQUIRED_MARGIN_KEYS):
            raise ValueError("Margins should contain 'left', 'right', 'top' and 'bottom' keys!")

        # Training related configuration
        self.__environment: TradingEnvironment = config.instantiate_environment()
        self.__agent: AgentHandler = config.instantiate_agent()
        self.__nr_of_steps: int = config.nr_of_steps
        self.__repeat_test: int = config.repeat_test
        self.__steps_per_episode: int = int(config.nr_of_steps / config.nr_of_episodes)

        # Report related configuration
        self.__config_summary = str(config)
        self.__plotting_chain: PlotResponsibilityChainBase = PlotTestingHistoryResponsibilityChain()
        self.__plotting_chain.add_next_chain_link(PlotTrainingHistoryResponsibilityChain())
        self.__generated_data: dict = {}
        self.__logs: io.StringIO = io.StringIO()
        self.__page_width = page_width
        self.__page_height = page_height
        self.__heading_spacing = heading_spacing
        self.__caption_font_size = caption_font_size
        self.__text_font_size = text_font_size
        self.__font_name = font_name
        # Copy so the shared class-level defaults can never be aliased by instances.
        self.__margins = dict(margins)
        self.__exclude_from_logs = list(exclude_from_logs)

    def run_training(self, callbacks: Optional[list[Callback]] = None, weights_load_path: Optional[str] = None,
                     weights_save_path: Optional[str] = None) -> None:
        """
        Executes the training and testing process for the trading agent.

        This method orchestrates the complete training workflow, capturing logs,
        training the agent, and testing its performance. It populates internal
        data structures with results that can later be used for report generation.
        Any exception raised while training or testing is logged (with traceback)
        into the captured logs instead of being propagated, so a report can still
        be generated afterwards.

        Parameters:
            callbacks (list[Callback], optional): Keras callbacks to use during training.
                Defaults to no callbacks.
            weights_load_path (str, optional): Path to load pre-trained weights from.
            weights_save_path (str, optional): Path to save trained weights to.
        """

        if callbacks is None:
            callbacks = []

        log_streamer = logging.StreamHandler(self.__logs)
        log_streamer.setFormatter(logging.Formatter('%(message)s'))
        log_streamer.setLevel(logging.INFO)
        root_logger = logging.getLogger()
        root_logger.addHandler(log_streamer)

        # A handler level alone is not enough: records below the root logger's own level
        # (WARNING unless configured otherwise) are discarded before reaching any handler.
        # Lower it to INFO for the duration of the run and restore it afterwards. Note that
        # other handlers attached to the root logger may also see INFO records meanwhile.
        previous_root_level = root_logger.level
        if root_logger.getEffectiveLevel() > logging.INFO:
            root_logger.setLevel(logging.INFO)

        try:
            logging.info("Training started!")
            logging.info(self.__config_summary.replace('\t', ' '))
            logging.info("Printing models architecture...")
            self.__agent.print_model_summary(print_function = lambda x: logging.info(x))

            self.__generated_data['train'] = self.__agent.train_agent(self.__environment,
                                                                      self.__nr_of_steps,
                                                                      self.__steps_per_episode,
                                                                      callbacks,
                                                                      weights_load_path,
                                                                      weights_save_path)
            self.__generated_data['test'] = self.__agent.test_agent(self.__environment,
                                                                    self.__repeat_test)

            logging.info("Training finished!")
        except Exception as e:
            # Deliberately not re-raised: the failure (including its traceback) ends up in
            # the captured logs, which generate_report() can still render.
            logging.exception(f"Training failed! Original error: {e}")
        finally:
            root_logger.setLevel(previous_root_level)
            root_logger.removeHandler(log_streamer)
            log_streamer.close()

    def __handle_plot_generation(self, data: dict) -> Optional[ImageReader]:
        """
        Generates a plot based on provided data using the responsibility chain.

        Parameters:
            data (dict): Dictionary containing 'key' identifying the plot type
                and 'plot_data' containing the actual data to be plotted.

        Returns:
            Optional[ImageReader]: ReportLab ImageReader object if plot was generated,
                None if no handler could process the request.
        """

        axes = self.__plotting_chain.plot(data)
        if axes is None:
            logging.warning(f'Did not managed to generate plot for {data["key"]}!')
            return None

        figure = axes.figure
        try:
            buffer = io.BytesIO()
            figure.savefig(buffer, format = 'png')
            buffer.seek(0)
            return ImageReader(buffer)
        finally:
            # Release the figure even if rendering fails, so repeated report
            # generation does not accumulate open matplotlib figures.
            plt.close(figure)

    def __calculate_max_dimensions(self, pdf: canvas.Canvas) -> tuple[int, int]:
        """
        Calculates maximum text dimensions that can fit on a PDF page.

        Parameters:
            pdf (canvas.Canvas): The PDF canvas to calculate dimensions for.

        Returns:
            tuple[int, int]: Maximum number of characters per line and number
                of lines per page that can fit within margins. Both values are
                at least 1, so callers can always use them as chunk sizes.
        """

        pdf.setFont(self.__font_name, self.__text_font_size)
        # The width of a single space is used as the per-character width, which is
        # exact only for monospaced fonts such as the default 'Courier'.
        char_width = pdf.stringWidth(' ', self.__font_name, self.__text_font_size)
        available_width = self.__page_width - self.__margins['left'] - self.__margins['right']
        max_width = max(1, int(available_width / char_width))

        # 1.2 x font size mirrors the default leading ReportLab text objects use
        # when setFont() is called without an explicit leading.
        line_height = self.__text_font_size * 1.2
        available_height = (self.__page_height - self.__margins['top'] - self.__margins['bottom']
                            - self.__heading_spacing)
        max_height = max(1, int(available_height / line_height))

        return max_width, max_height

    def __handle_logs_preprocessing(self, raw_logs_bufffer: list[str], max_log_length: int,
                                    max_lines_per_page: int) -> list[list[str]]:
        """
        Processes raw logs to fit them within PDF page constraints.

        This method filters out lines containing excluded terms, hard-wraps lines
        longer than max_log_length, and chunks the result into page-sized portions.

        Parameters:
            raw_logs_bufffer (list[str]): Raw log lines to process.
            max_log_length (int): Maximum characters per line (must be positive).
            max_lines_per_page (int): Maximum lines per page (must be positive).

        Returns:
            list[list[str]]: List of pages, where each page is a list of log lines.
        """

        preprocessed_logs = []
        for log in raw_logs_bufffer:
            if any(exclude_term in log for exclude_term in self.__exclude_from_logs):
                continue
            if len(log) > max_log_length:
                preprocessed_logs.extend(log[i:i + max_log_length]
                                         for i in range(0, len(log), max_log_length))
            else:
                preprocessed_logs.append(log)

        return [preprocessed_logs[i:i + max_lines_per_page]
                for i in range(0, len(preprocessed_logs), max_lines_per_page)]

    def __draw_caption(self, pdf: canvas.Canvas, text: str) -> None:
        """
        Draws a section caption with a separating line at the top of the current page.

        Parameters:
            pdf (canvas.Canvas): PDF canvas to draw on.
            text (str): Caption text to draw.
        """

        caption_y = self.__page_height - self.__margins['top']
        # The separator sits halfway between the caption baseline and the body start.
        separator_y = caption_y - self.__heading_spacing / 2

        pdf.setFont(self.__font_name, self.__caption_font_size)
        pdf.drawString(self.__margins['left'], caption_y, text)
        pdf.setLineWidth(2)
        pdf.setStrokeColorRGB(0, 0, 0)
        pdf.line(self.__margins['left'], separator_y,
                 self.__page_width - self.__margins['right'], separator_y)

    def __draw_text_body(self, pdf: canvas.Canvas, text_body: list[str]) -> None:
        """
        Draws a block of text lines below the caption area of the current page.

        Parameters:
            pdf (canvas.Canvas): PDF canvas to draw on.
            text_body (list[str]): List of text lines to draw.
        """

        text_block = pdf.beginText(self.__margins['left'],
                                   self.__page_height - self.__margins['top'] - self.__heading_spacing)
        text_block.setFont(self.__font_name, self.__text_font_size)
        for line in text_body:
            text_block.textLine(line)
        pdf.drawText(text_block)

    def __draw_plot_page(self, pdf: canvas.Canvas, caption: str, key: str, plot_data) -> None:
        """
        Renders a single plot with a caption on its own page.

        Parameters:
            pdf (canvas.Canvas): PDF canvas to draw on.
            caption (str): Caption drawn above the plot.
            key (str): Plot type identifier passed to the plotting chain.
            plot_data: Data passed to the plotting chain as 'plot_data'.

        Nothing is drawn (and no page is emitted) if no plot could be generated.
        """

        image = self.__handle_plot_generation({'key': key, 'plot_data': plot_data})
        if image is None:
            return

        self.__draw_caption(pdf, caption)
        pdf.drawImage(image, inch, self.__page_height - 7.5 * inch, width = 6 * inch,
                      preserveAspectRatio = True)
        pdf.showPage()

    def generate_report(self, path_to_pdf: str) -> None:
        """
        Generates a comprehensive PDF report of training and testing results.

        Creates a multi-page report with logs, training history plots, and test
        results visualizations based on the data collected during training.
        Sections whose data is unavailable (e.g. because training failed or was
        never run) are skipped with a warning; the captured logs are still rendered.

        Parameters:
            path_to_pdf (str): File path where the PDF report should be saved.
        """

        logging.info("Generating report...")
        # Use the configured page size so the canvas agrees with the dimensions
        # used for all layout calculations below.
        pdf = canvas.Canvas(path_to_pdf, pagesize = (self.__page_width, self.__page_height))
        pdf.setTitle("Report")

        # Report logs
        raw_logs = self.__logs.getvalue().split('\n')
        max_log_length, max_lines_per_page = self.__calculate_max_dimensions(pdf)
        preprocessed_logs = self.__handle_logs_preprocessing(raw_logs, max_log_length, max_lines_per_page)
        for preprocessed_logs_chunk in preprocessed_logs:
            self.__draw_caption(pdf, "Log output")
            self.__draw_text_body(pdf, preprocessed_logs_chunk)
            pdf.showPage()

        # Draw training plot
        if 'train' in self.__generated_data:
            self.__draw_plot_page(pdf, "Training performance", 'training_history',
                                  self.__generated_data['train'])
        else:
            logging.warning("No training data available, skipping training plot!")

        # Draw testing plots (testing results are keyed by zero-based trial index)
        testing_results = self.__generated_data.get('test')
        if testing_results is None:
            logging.warning("No testing data available, skipping testing plots!")
        else:
            for index, testing_data in testing_results.items():
                self.__draw_plot_page(pdf, f"Testing outcome, trial: {index + 1}", 'testing_history',
                                      testing_data)

        pdf.save()
        logging.info("Report generated!")

297 logging.info(f"Report generated!")