Coverage for source/training/training_handler.py: 96%

136 statements  

« prev     ^ index     » next       coverage.py v7.8.0, created at 2025-07-27 17:11 +0000

1# training/training_handler.py 

2 

3# global imports 

4import io 

5import logging 

6import matplotlib.pyplot as plt 

7from reportlab.lib.pagesizes import letter 

8from reportlab.lib.units import inch 

9from reportlab.lib.utils import ImageReader 

10from reportlab.pdfgen import canvas 

11from tensorflow.keras.callbacks import Callback 

12from typing import Optional 

13 

14# local imports 

15from source.agent import AgentHandler 

16from source.plotting import AssetPriceMovementSummaryPlotResponsibilityChain, \ 

17 ClassificationTestingPlotResponsibilityChain, ClassificationTrainingPlotResponsibilityChain, \ 

18 PerformanceTestingPlotResponsibilityChain, PlotResponsibilityChainBase, \ 

19 PriceMovementTrendClassSummaryPlotResponsibilityChain, ReinforcementTrainingPlotResponsibilityChain 

20from source.training import TrainingConfig 

21 

class TrainingHandler:
    """
    Responsible for orchestrating the training process and report generation.

    The handler obtains an agent handler from the given configuration, runs
    training and testing sessions through it while capturing log output, and
    renders the captured logs together with the generated plots into a
    multi-page PDF report.
    """

    # local constants
    __HEADING_SPACING = 20
    __CAPTION_FONT_SIZE = 14
    __TEXT_FONT_SIZE = 8
    __FONT_NAME = 'Courier'
    __MARGINS = {
        'left': 30,
        'right': 30,
        'top': 30,
        'bottom': 30
    }
    __EXCLUDE_FROM_LOGS = ['ETA']

    def __init__(self, config: TrainingConfig, page_width: int = letter[0], page_height: int = letter[1],
                 heading_spacing: int = __HEADING_SPACING, caption_font_size: int = __CAPTION_FONT_SIZE,
                 text_font_size: int = __TEXT_FONT_SIZE, font_name: str = __FONT_NAME,
                 margins: dict[str, int] = __MARGINS, exclude_from_logs: list[str] = __EXCLUDE_FROM_LOGS) -> None:
        """
        Initializes the training handler with configuration parameters.

        Parameters:
            config (TrainingConfig): Configuration containing environment and agent settings.
            page_width (int): Width of PDF report pages in points.
            page_height (int): Height of PDF report pages in points.
            heading_spacing (int): Spacing between headings and content in points.
            caption_font_size (int): Font size for section captions.
            text_font_size (int): Font size for main text content.
            font_name (str): Font family to use for text in reports.
            margins (dict[str, int]): Dictionary with 'left', 'right', 'top', and 'bottom' margins.
            exclude_from_logs (list[str]): Terms that should be excluded from logs in reports.

        Raises:
            ValueError: If margins dictionary doesn't contain all required keys.
        """

        # Validate before doing any other work: every one of the four keys is
        # read later when drawing, so a partially filled dictionary is invalid.
        required_margin_keys = ('left', 'right', 'top', 'bottom')
        if not all(term in margins for term in required_margin_keys):
            raise ValueError("Margins should contain 'left', 'right', 'top' and 'bottom' keys!")

        # Training related configuration
        self.__agent: AgentHandler = config.instantiate_agent_handler()
        self.__nr_of_steps: int = config.nr_of_steps
        self.__repeat_test: int = config.repeat_test
        self.__nr_of_episodes: int = config.nr_of_episodes

        # Report related configuration
        self.__config_summary = str(config)
        self.__plotting_chain: PlotResponsibilityChainBase = AssetPriceMovementSummaryPlotResponsibilityChain()
        self.__plotting_chain.add_next_chain_link(ClassificationTestingPlotResponsibilityChain())
        self.__plotting_chain.add_next_chain_link(ClassificationTrainingPlotResponsibilityChain())
        self.__plotting_chain.add_next_chain_link(PerformanceTestingPlotResponsibilityChain())
        self.__plotting_chain.add_next_chain_link(PriceMovementTrendClassSummaryPlotResponsibilityChain())
        self.__plotting_chain.add_next_chain_link(ReinforcementTrainingPlotResponsibilityChain())
        self.__generated_data: dict = {'train': {}, 'test': {}}
        self.__logs: io.StringIO = io.StringIO()
        self.__page_width = page_width
        self.__page_height = page_height
        self.__heading_spacing = heading_spacing
        self.__caption_font_size = caption_font_size
        self.__text_font_size = text_font_size
        self.__font_name = font_name
        # Copy the containers so that the shared class-level defaults (or the
        # caller's objects) cannot be mutated through this instance.
        self.__margins = dict(margins)
        self.__exclude_from_logs = list(exclude_from_logs)

    def run_training(self, callbacks: Optional[list[Callback]] = None, weights_load_path: Optional[str] = None,
                     weights_save_path: Optional[str] = None) -> None:
        """
        Executes the training and testing process for the trading agent.

        While the process runs, a stream handler attached to the root logger
        collects INFO-and-above records into an internal buffer. Training and
        testing results are stored in internal data structures that are later
        used for report generation.

        Parameters:
            callbacks (list[Callback], optional): Keras callbacks to use during training.
                Defaults to no callbacks.
            weights_load_path (str, optional): Path to load pre-trained weights from.
            weights_save_path (str, optional): Path to save trained weights to.
        """

        # A fresh list per call avoids sharing one mutable default between calls.
        callbacks = [] if callbacks is None else callbacks

        log_streamer = logging.StreamHandler(self.__logs)
        log_streamer.setFormatter(logging.Formatter('%(message)s'))
        log_streamer.setLevel(logging.INFO)
        root_logger = logging.getLogger()
        root_logger.addHandler(log_streamer)

        try:
            logging.info("Training started!")
            logging.info(self.__config_summary.replace('\t', '  '))
            logging.info("Printing models architecture...")
            self.__agent.print_model_summary(print_function = lambda x: logging.info(x))

            train_keys, train_data = self.__agent.train_agent(self.__nr_of_steps, self.__nr_of_episodes,
                                                              callbacks, weights_load_path, weights_save_path)
            self.__generated_data['train'].update(zip(train_keys, train_data))

            test_keys, test_data = self.__agent.test_agent(self.__repeat_test)
            # Both mappings are walked in lockstep; the iteration label is
            # taken from test_keys.
            for (iteration, key_list), (_, data_list) in zip(test_keys.items(), test_data.items()):
                self.__generated_data['test'][iteration] = dict(zip(key_list, data_list))

            logging.info("Training finished!")
        except Exception as e:
            # Failures are logged instead of propagated - presumably so that a
            # report containing the captured logs can still be generated.
            logging.error(f"Training failed! Original error: {e}")
        finally:
            # Always detach the handler so later logging is not captured.
            root_logger.removeHandler(log_streamer)
            log_streamer.close()

    def __handle_plot_generation(self, data: dict) -> Optional[ImageReader]:
        """
        Generates a plot based on provided data using the responsibility chain.

        Parameters:
            data (dict): Dictionary containing 'key' identifying the plot type
                and 'plot_data' containing the actual data to be plotted.

        Returns:
            Optional[ImageReader]: ReportLab ImageReader object if plot was generated,
                None if the chain returned no axes for the request.
        """

        axes = self.__plotting_chain.plot(data)
        if axes is None:
            logging.warning(f'Did not managed to generate plot for {data["key"]}!')
            return None

        # Render the figure into an in-memory PNG and release the figure.
        buffer = io.BytesIO()
        axes.figure.savefig(buffer, format='png')
        buffer.seek(0)
        image_reader = ImageReader(buffer)
        plt.close(axes.figure)

        return image_reader

    def __calculate_max_dimensions(self, pdf: canvas.Canvas) -> tuple[int, int]:
        """
        Calculates maximum text dimensions that can fit on a PDF page.

        Parameters:
            pdf (canvas.Canvas): The PDF canvas to calculate dimensions for.

        Returns:
            tuple[int, int]: Maximum number of characters per line and number
                of lines per page that can fit within margins.
        """

        pdf.setFont(self.__font_name, self.__text_font_size)

        # Width of a single space is used as the per-character width.
        text_width = pdf.stringWidth(' ', self.__font_name, self.__text_font_size)
        available_width = self.__page_width - self.__margins['left'] - self.__margins['right']
        max_width = int(available_width / text_width)

        # Line height is taken as 1.2 times the font size.
        text_height = self.__text_font_size * 1.2
        available_height = self.__page_height - self.__margins['top'] - self.__margins['bottom'] \
            - self.__heading_spacing
        max_height = int(available_height / text_height)

        return max_width, max_height

    def __handle_logs_preprocessing(self, raw_logs_bufffer: list[str], max_log_length: int,
                                    max_lines_per_page: int) -> list[list[str]]:
        """
        Processes raw logs to fit them within PDF page constraints.

        Lines containing any excluded term are dropped, lines longer than the
        limit are split into consecutive fixed-size pieces, and the resulting
        lines are chunked into page-sized portions.

        Parameters:
            raw_logs_bufffer (list[str]): Raw log lines to process.
            max_log_length (int): Maximum characters per line.
            max_lines_per_page (int): Maximum lines per page.

        Returns:
            list[list[str]]: List of pages, where each page is a list of log lines.

        Raises:
            ValueError: If max_log_length or max_lines_per_page is smaller than 1.
        """

        # Non-positive limits would either break range() or silently drop logs.
        if max_log_length < 1 or max_lines_per_page < 1:
            raise ValueError("Page layout leaves no room for log text, check page size, margins and font size!")

        preprocessed_logs: list[str] = []
        for log in raw_logs_bufffer:
            if any(exclude_term in log for exclude_term in self.__exclude_from_logs):
                continue
            if len(log) > max_log_length:
                preprocessed_logs.extend(log[i:i + max_log_length]
                                         for i in range(0, len(log), max_log_length))
            else:
                preprocessed_logs.append(log)

        return [preprocessed_logs[i:i + max_lines_per_page]
                for i in range(0, len(preprocessed_logs), max_lines_per_page)]

    def __draw_caption(self, pdf: canvas.Canvas, text: str) -> None:
        """
        Draws a section caption with separating line on the PDF.

        Parameters:
            pdf (canvas.Canvas): PDF canvas to draw on.
            text (str): Caption text to draw.
        """

        caption_y = self.__page_height - self.__margins['top']
        # The separating line sits half of the heading spacing below the caption.
        line_y = caption_y - self.__heading_spacing / 2

        pdf.setFont(self.__font_name, self.__caption_font_size)
        pdf.drawString(self.__margins['left'], caption_y, text)
        pdf.setLineWidth(2)
        pdf.setStrokeColorRGB(0, 0, 0)
        pdf.line(self.__margins['left'], line_y, self.__page_width - self.__margins['right'], line_y)

    def __draw_text_body(self, pdf: canvas.Canvas, text_body: list[str]) -> None:
        """
        Draws a block of text lines on the PDF.

        Parameters:
            pdf (canvas.Canvas): PDF canvas to draw on.
            text_body (list[str]): List of text lines to draw.
        """

        text_block = pdf.beginText(self.__margins['left'],
                                   self.__page_height - self.__margins['top'] - self.__heading_spacing)
        text_block.setFont(self.__font_name, self.__text_font_size)
        for line in text_body:
            text_block.textLine(line)
        pdf.drawText(text_block)

    def __draw_plot_page(self, pdf: canvas.Canvas, caption: str, key: str, data) -> None:
        """
        Draws a single captioned plot page, if a plot can be generated.

        Nothing is drawn and no page is emitted when no plot was generated
        for the given key.

        Parameters:
            pdf (canvas.Canvas): PDF canvas to draw on.
            caption (str): Caption placed at the top of the page.
            key (str): Identifier of the plot type, passed to the plotting chain.
            data: Data to be plotted, passed to the plotting chain as is.
        """

        plot_buffer = self.__handle_plot_generation({
            'key': key,
            'plot_data': data
        })
        if plot_buffer is None:
            return

        self.__draw_caption(pdf, caption)
        # Image dimensions follow the configured page size, leaving half an
        # inch on the sides and one inch at the top and bottom.
        pdf.drawImage(plot_buffer, 0.5 * inch, 1 * inch, width=self.__page_width - 1 * inch,
                      height=self.__page_height - 2 * inch)
        pdf.showPage()

    def generate_report(self, path_to_pdf: str) -> None:
        """
        Generates a comprehensive PDF report of training and testing results.

        Creates a multi-page report with logs, training history plots, and test
        results visualizations based on the data collected during training.
        Pages are created with the page size given at construction.

        Parameters:
            path_to_pdf (str): File path where the PDF report should be saved.
        """

        logging.info("Generating report...")
        pdf = canvas.Canvas(path_to_pdf, pagesize=(self.__page_width, self.__page_height))
        pdf.setTitle("Report")

        # Report logs
        raw_logs = self.__logs.getvalue().split('\n')
        max_log_length, max_lines_per_page = self.__calculate_max_dimensions(pdf)
        preprocessed_logs = self.__handle_logs_preprocessing(raw_logs, max_log_length, max_lines_per_page)
        for preprocessed_logs_chunk in preprocessed_logs:
            self.__draw_caption(pdf, "Log output")
            self.__draw_text_body(pdf, preprocessed_logs_chunk)
            pdf.showPage()

        # Draw training plots
        for key, data in self.__generated_data['train'].items():
            self.__draw_plot_page(pdf, "Training performance", key, data)

        # Draw testing plots
        for iteration, key_data_pair in self.__generated_data['test'].items():
            for key, data in key_data_pair.items():
                self.__draw_plot_page(pdf, f"Testing outcome, trial: {iteration + 1}", key, data)

        pdf.save()
        logging.info("Report generated!")