Coverage for source/training/training_handler.py: 20%

133 statements  

« prev     ^ index     » next       coverage.py v7.8.0, created at 2025-06-06 12:00 +0000

1# training/training_handler.py 

2 

3# global imports 

4import logging 

5import io 

6import matplotlib.pyplot as plt 

7from tensorflow.keras.callbacks import Callback 

8from typing import Optional 

9from reportlab.pdfgen import canvas 

10from reportlab.lib.pagesizes import letter 

11from reportlab.lib.units import inch 

12from reportlab.lib.utils import ImageReader 

13 

14# local imports 

15from source.agent import AgentHandler 

16from source.plotting import ClassificationTestingPlotResponsibilityChain, \ 

17 ClassificationTrainingPlotResponsibilityChain, PlotResponsibilityChainBase, \ 

18 PlotTestingHistoryResponsibilityChain, PlotTrainingHistoryResponsibilityChain 

19from source.training import TrainingConfig 

20 

class TrainingHandler():
    """
    Responsible for orchestrating the training process and report generation.

    This class manages the complete training workflow: it instantiates the agent
    handler from the given configuration, runs training and testing sessions while
    capturing log output, and generates a PDF report with the captured logs and the
    plots produced by the plotting responsibility chain.
    """

    # Constants used locally
    HEADING_SPACING = 20
    CAPTION_FONT_SIZE = 14
    TEXT_FONT_SIZE = 8
    FONT_NAME = 'Courier'
    MARGINS = {
        'left': 30,
        'right': 30,
        'top': 30,
        'bottom': 30
    }
    EXCLUDE_FROM_LOGS = ['ETA']

    # Keys that every margins dictionary has to provide
    _REQUIRED_MARGIN_KEYS = ('left', 'right', 'top', 'bottom')

    def __init__(self, config: TrainingConfig, page_width: int = letter[0], page_height: int = letter[1],
                 heading_spacing: int = HEADING_SPACING, caption_font_size: int = CAPTION_FONT_SIZE,
                 text_font_size: int = TEXT_FONT_SIZE, font_name: str = FONT_NAME,
                 margins: dict[str, int] = MARGINS, exclude_from_logs: list[str] = EXCLUDE_FROM_LOGS) -> None:
        """
        Initializes the training handler with configuration parameters.

        Parameters:
            config (TrainingConfig): Configuration containing environment and agent settings.
            page_width (int): Width of PDF report pages in points.
            page_height (int): Height of PDF report pages in points.
            heading_spacing (int): Spacing between headings and content in points.
            caption_font_size (int): Font size for section captions.
            text_font_size (int): Font size for main text content.
            font_name (str): Font family to use for text in reports.
            margins (dict[str, int]): Dictionary with 'left', 'right', 'top', and 'bottom' margins.
            exclude_from_logs (list[str]): Terms that should be excluded from logs in reports.

        Raises:
            ValueError: If margins dictionary doesn't contain all of the required keys.
        """

        # Validate up front, before the (potentially expensive) agent is instantiated.
        # Every key is required - a partially filled dictionary would fail later
        # with a KeyError while drawing the report.
        if not all(term in margins for term in self._REQUIRED_MARGIN_KEYS):
            raise ValueError("Margins should contain 'left', 'right', 'top' and 'bottom' keys!")

        # Training related configuration
        self.__agent: AgentHandler = config.instantiate_agent_handler()
        self.__nr_of_steps: int = config.nr_of_steps
        self.__repeat_test: int = config.repeat_test
        self.__nr_of_episodes: int = config.nr_of_episodes

        # Report related configuration
        self.__config_summary = str(config)

        # Plotting chain, linked as: testing history -> training history ->
        # classification training -> classification testing. Links are nested
        # explicitly (each link receives its own successor) instead of being
        # added one after another to the head of the chain.
        classification_training_plot_chain = ClassificationTrainingPlotResponsibilityChain()
        classification_training_plot_chain.add_next_chain_link(ClassificationTestingPlotResponsibilityChain())
        plot_training_history_chain = PlotTrainingHistoryResponsibilityChain()
        plot_training_history_chain.add_next_chain_link(classification_training_plot_chain)
        self.__plotting_chain: PlotResponsibilityChainBase = PlotTestingHistoryResponsibilityChain()
        self.__plotting_chain.add_next_chain_link(plot_training_history_chain)

        self.__generated_data: dict = {'train': {}, 'test': {}}
        self.__logs: io.StringIO = io.StringIO()
        self.__page_width = page_width
        self.__page_height = page_height
        self.__heading_spacing = heading_spacing
        self.__caption_font_size = caption_font_size
        self.__text_font_size = text_font_size
        self.__font_name = font_name
        # Copies keep the shared class-level defaults (and the caller's objects)
        # isolated from this instance.
        self.__margins = dict(margins)
        self.__exclude_from_logs = list(exclude_from_logs)

    def run_training(self, callbacks: Optional[list[Callback]] = None, weights_load_path: Optional[str] = None,
                     weights_save_path: Optional[str] = None) -> None:
        """
        Executes the training and testing process for the trading agent.

        This method captures log output into an internal buffer, trains the agent,
        and tests its performance. It populates internal data structures with
        results that can later be used for report generation. Any exception raised
        by the agent is propagated to the caller; the temporary log handler is
        detached from the root logger in either case.

        Parameters:
            callbacks (list[Callback], optional): Keras callbacks to use during training.
                Defaults to no callbacks.
            weights_load_path (str, optional): Path to load pre-trained weights from.
            weights_save_path (str, optional): Path to save trained weights to.
        """

        # A fresh list per call avoids sharing one mutable default between calls.
        callbacks = [] if callbacks is None else callbacks

        log_streamer = logging.StreamHandler(self.__logs)
        log_streamer.setFormatter(logging.Formatter('%(message)s'))
        log_streamer.setLevel(logging.INFO)
        root_logger = logging.getLogger()
        root_logger.addHandler(log_streamer)

        try:
            logging.info("Training started!")
            logging.info(self.__config_summary.replace('\t', ' '))
            logging.info("Printing models architecture...")
            self.__agent.print_model_summary(print_function = lambda x: logging.info(x))

            train_keys, train_data = self.__agent.train_agent(self.__nr_of_steps, self.__nr_of_episodes,
                                                              callbacks, weights_load_path, weights_save_path)
            for key, data in zip(train_keys, train_data):
                self.__generated_data['train'][key] = data

            # Test results are dictionaries keyed by iteration; both are walked in
            # lockstep, so they are expected to share the same iteration order.
            test_keys, test_data = self.__agent.test_agent(self.__repeat_test)
            for (iteration, key_list), (_, data_list) in zip(test_keys.items(), test_data.items()):
                self.__generated_data['test'][iteration] = dict(zip(key_list, data_list))

            logging.info("Training finished!")
        finally:
            # Always detach the handler, otherwise a failed run would leave it on
            # the root logger and keep feeding unrelated records into the buffer.
            root_logger.removeHandler(log_streamer)
            log_streamer.close()

    def __handle_plot_generation(self, data: dict) -> Optional[ImageReader]:
        """
        Generates a plot based on provided data using the responsibility chain.

        Parameters:
            data (dict): Dictionary containing 'key' identifying the plot type
                and 'plot_data' containing the actual data to be plotted.

        Returns:
            Optional[ImageReader]: ReportLab ImageReader object if plot was generated,
                None if no handler could process the request.
        """

        axes = self.__plotting_chain.plot(data)
        if axes is None:
            logging.warning(f'Did not managed to generate plot for {data["key"]}!')
            return None

        buffer = io.BytesIO()
        try:
            axes.figure.savefig(buffer, format = 'png')
            buffer.seek(0)
            return ImageReader(buffer)
        finally:
            # Release the figure even if rendering fails, so that figures do not
            # pile up in matplotlib's global state.
            plt.close(axes.figure)

    def __calculate_max_dimensions(self, pdf: canvas.Canvas) -> tuple[int, int]:
        """
        Calculates maximum text dimensions that can fit on a PDF page.

        The width is measured with a single space character, the line height is
        taken as 1.2 times the text font size.

        Parameters:
            pdf (canvas.Canvas): The PDF canvas to calculate dimensions for.

        Returns:
            tuple[int, int]: Maximum number of characters per line and number
                of lines per page that can fit within margins. Both values
                are at least 1.
        """

        pdf.setFont(self.__font_name, self.__text_font_size)
        text_width = pdf.stringWidth(' ', self.__font_name, self.__text_font_size)
        available_width = self.__page_width - self.__margins['left'] - self.__margins['right']

        text_height = self.__text_font_size * 1.2
        available_height = self.__page_height - self.__margins['top'] - self.__margins['bottom'] \
            - self.__heading_spacing

        # Clamp to 1: a zero or negative value (margins larger than the page) would
        # be used as a slicing step later and either raise or silently drop logs.
        max_width = max(1, int(available_width / text_width))
        max_height = max(1, int(available_height / text_height))

        return max_width, max_height

    def __handle_logs_preprocessing(self, raw_logs_bufffer: list[str], max_log_length: int,
                                    max_lines_per_page: int) -> list[list[str]]:
        """
        Processes raw logs to fit them within PDF page constraints.

        This method drops lines containing any of the excluded terms, splits lines
        longer than the limit into consecutive fixed-width pieces, and chunks the
        result into page-sized portions.

        Parameters:
            raw_logs_bufffer (list[str]): Raw log lines to process.
            max_log_length (int): Maximum characters per line.
            max_lines_per_page (int): Maximum lines per page.

        Returns:
            list[list[str]]: List of pages, where each page is a list of log lines.
        """

        preprocessed_logs = []
        for log in raw_logs_bufffer:
            if any(exclude_term in log for exclude_term in self.__exclude_from_logs):
                continue

            if len(log) > max_log_length:
                preprocessed_logs.extend(log[i:i + max_log_length]
                                         for i in range(0, len(log), max_log_length))
            else:
                preprocessed_logs.append(log)

        return [preprocessed_logs[i:i + max_lines_per_page]
                for i in range(0, len(preprocessed_logs), max_lines_per_page)]

    def __draw_caption(self, pdf: canvas.Canvas, text: str) -> None:
        """
        Draws a section caption with separating line on the PDF.

        The caption is placed at the top-left margin and the line is drawn half
        of the heading spacing below it, across the page between the side margins.

        Parameters:
            pdf (canvas.Canvas): PDF canvas to draw on.
            text (str): Caption text to draw.
        """

        caption_y = self.__page_height - self.__margins['top']
        line_y = caption_y - self.__heading_spacing / 2

        pdf.setFont(self.__font_name, self.__caption_font_size)
        pdf.drawString(self.__margins['left'], caption_y, text)
        pdf.setLineWidth(2)
        pdf.setStrokeColorRGB(0, 0, 0)
        pdf.line(self.__margins['left'], line_y, self.__page_width - self.__margins['right'], line_y)

    def __draw_text_body(self, pdf: canvas.Canvas, text_body: list[str]) -> None:
        """
        Draws a block of text lines on the PDF.

        The block starts at the left margin, one heading spacing below the top margin.

        Parameters:
            pdf (canvas.Canvas): PDF canvas to draw on.
            text_body (list[str]): List of text lines to draw.
        """

        text_block = pdf.beginText(self.__margins['left'], self.__page_height - \
            self.__margins['top'] - self.__heading_spacing)
        text_block.setFont(self.__font_name, self.__text_font_size)
        for line in text_body:
            text_block.textLine(line)
        pdf.drawText(text_block)

    def generate_report(self, path_to_pdf: str) -> None:
        """
        Generates a comprehensive PDF report of training and testing results.

        Creates a multi-page report with logs, training history plots, and test
        results visualizations based on the data collected during training. Data
        entries for which no plot could be generated are skipped.

        Parameters:
            path_to_pdf (str): File path where the PDF report should be saved.
        """

        logging.info("Generating report...")
        # Use the configured page size so that the canvas agrees with the
        # dimensions used for all layout calculations (letter by default).
        pdf = canvas.Canvas(path_to_pdf, pagesize = (self.__page_width, self.__page_height))
        pdf.setTitle("Report")

        # Report logs
        raw_logs = self.__logs.getvalue().split('\n')
        max_log_length, max_lines_per_page = self.__calculate_max_dimensions(pdf)
        preprocessed_logs = self.__handle_logs_preprocessing(raw_logs, max_log_length, max_lines_per_page)
        for preprocessed_logs_chunk in preprocessed_logs:
            self.__draw_caption(pdf, "Log output")
            self.__draw_text_body(pdf, preprocessed_logs_chunk)
            pdf.showPage()

        # Draw training plot
        for key, data in self.__generated_data['train'].items():
            plot_buffer = self.__handle_plot_generation({
                'key': key,
                'plot_data': data
            })
            if plot_buffer is not None:
                self.__draw_caption(pdf, "Training performance")
                pdf.drawImage(plot_buffer, inch, self.__page_height - 7.5 * inch, width = 6 * inch,
                              preserveAspectRatio = True)
                pdf.showPage()

        # Draw testing plots
        for iteration, key_data_pair in self.__generated_data['test'].items():
            for key, data in key_data_pair.items():
                plot_buffer = self.__handle_plot_generation({
                    'key': key,
                    'plot_data': data
                })
                if plot_buffer is not None:
                    self.__draw_caption(pdf, f"Testing outcome, trial: {iteration}")
                    # Half an inch of side margin and one inch of top/bottom margin
                    pdf.drawImage(plot_buffer, 0.5 * inch, 1 * inch, width = self.__page_width - 1 * inch,
                                  height = self.__page_height - 2 * inch)
                    pdf.showPage()

        pdf.save()
        logging.info("Report generated!")