Coverage for source/training/training_handler.py: 95%

130 statements  

« prev     ^ index     » next       coverage.py v7.8.0, created at 2025-05-31 12:26 +0000

1# training/training_handler.py 

2 

3import logging 

4import io 

5import matplotlib.pyplot as plt 

6from tensorflow.keras.callbacks import Callback 

7from typing import Optional 

8from reportlab.pdfgen import canvas 

9from reportlab.lib.pagesizes import letter 

10from reportlab.lib.units import inch 

11from reportlab.lib.utils import ImageReader 

12 

13from .training_config import TrainingConfig 

14from ..environment.trading_environment import TradingEnvironment 

15from ..agent.agent_handler import AgentHandler 

16from ..plotting.plot_responsibility_chain_base import PlotResponsibilityChainBase 

17from ..plotting.plot_testing_history_responsibility_chain import PlotTestingHistoryResponsibilityChain 

18from ..plotting.plot_training_history_responsibility_chain import PlotTrainingHistoryResponsibilityChain 

19 

class TrainingHandler:
    """
    Responsible for orchestrating the training process and report generation.

    This class manages the complete training workflow, from initializing the
    environment and agent, running training and testing sessions, to generating
    PDF reports with performance visualizations and logs. It serves as the main
    entry point for executing and documenting trading agent training.
    """

    # PDF report related constants
    HEADING_SPACING = 20
    CAPTION_FONT_SIZE = 14
    TEXT_FONT_SIZE = 8
    FONT_NAME = 'Courier'
    MARGINS = {
        'left': 30,
        'right': 30,
        'top': 30,
        'bottom': 30
    }
    EXCLUDE_FROM_LOGS = ['ETA']

    def __init__(self, config: TrainingConfig, page_width: float = letter[0], page_height: float = letter[1],
                 heading_spacing: int = HEADING_SPACING, caption_font_size: int = CAPTION_FONT_SIZE,
                 text_font_size: int = TEXT_FONT_SIZE, font_name: str = FONT_NAME,
                 margins: dict[str, int] = MARGINS, exclude_from_logs: list[str] = EXCLUDE_FROM_LOGS) -> None:
        """
        Initializes the training handler with configuration parameters.

        Parameters:
            config (TrainingConfig): Configuration containing environment and agent settings.
            page_width (float): Width of PDF report pages in points.
            page_height (float): Height of PDF report pages in points.
            heading_spacing (int): Spacing between headings and content in points.
            caption_font_size (int): Font size for section captions.
            text_font_size (int): Font size for main text content.
            font_name (str): Font family to use for text in reports.
            margins (dict[str, int]): Dictionary with 'left', 'right', 'top', and 'bottom' margins.
            exclude_from_logs (list[str]): Terms that should be excluded from logs in reports.

        Raises:
            ValueError: If margins dictionary doesn't contain all of the required keys.
        """

        # Validate before instantiating the environment and the agent, so an
        # invalid call fails without doing that work. Every key is required:
        # the drawing helpers index all four of them.
        required_keys = ('left', 'right', 'top', 'bottom')
        if not all(key in margins for key in required_keys):
            raise ValueError("Margins should contain 'left', 'right', 'top' and 'bottom' keys!")

        # Training related configuration
        self.__environment: TradingEnvironment = config.instantiate_environment()
        self.__agent: AgentHandler = config.instantiate_agent()
        self.__nr_of_steps: int = config.nr_of_steps
        self.__repeat_test: int = config.repeat_test
        self.__steps_per_episode: int = int(config.nr_of_steps / config.nr_of_episodes)

        # Report related configuration
        self.__config_summary = str(config)
        self.__plotting_chain: PlotResponsibilityChainBase = PlotTestingHistoryResponsibilityChain()
        self.__plotting_chain.add_next_chain_link(PlotTrainingHistoryResponsibilityChain())
        self.__generated_data: dict = {}
        self.__logs: io.StringIO = io.StringIO()
        self.__page_width = page_width
        self.__page_height = page_height
        self.__heading_spacing = heading_spacing
        self.__caption_font_size = caption_font_size
        self.__text_font_size = text_font_size
        self.__font_name = font_name
        # Copies: the defaults are the shared class-level MARGINS and
        # EXCLUDE_FROM_LOGS objects, which must not be aliased per instance.
        self.__margins = dict(margins)
        self.__exclude_from_logs = list(exclude_from_logs)

    def run_training(self, callbacks: Optional[list[Callback]] = None, weights_load_path: Optional[str] = None,
                     weights_save_path: Optional[str] = None) -> None:
        """
        Executes the training and testing process for the trading agent.

        While running, INFO-level (and above) records reaching the root logger
        are captured into an internal buffer that is later used by the report.
        Training and testing results are stored under the 'train' and 'test'
        keys of the internal data dictionary. Any exception raised during the
        run is logged as an error and is not re-raised.

        Parameters:
            callbacks (list[Callback], optional): Keras callbacks to use during training.
                When omitted, an empty list is passed to the agent.
            weights_load_path (str, optional): Path to load pre-trained weights from.
            weights_save_path (str, optional): Path to save trained weights to.
        """

        # A fresh list per call instead of a shared mutable default argument
        callbacks = [] if callbacks is None else callbacks

        log_streamer = logging.StreamHandler(self.__logs)
        log_streamer.setFormatter(logging.Formatter('%(message)s'))
        log_streamer.setLevel(logging.INFO)
        root_logger = logging.getLogger()
        root_logger.addHandler(log_streamer)

        try:
            logging.info("Training started!")
            logging.info(self.__config_summary.replace('\t', '  '))
            logging.info("Printing models architecture...")
            self.__agent.print_model_summary(print_function=lambda x: logging.info(x))

            self.__environment.set_mode(TradingEnvironment.TRAIN_MODE)
            self.__generated_data['train'] = self.__agent.train_agent(self.__environment,
                                                                      self.__nr_of_steps,
                                                                      self.__steps_per_episode,
                                                                      callbacks,
                                                                      weights_load_path,
                                                                      weights_save_path)

            self.__environment.set_mode(TradingEnvironment.TEST_MODE)
            self.__generated_data['test'] = self.__agent.test_agent(self.__environment,
                                                                    self.__repeat_test)

            logging.info("Training finished!")
        except Exception as e:
            # Best effort: the failure ends up in the captured logs, so the
            # report can still be generated afterwards.
            logging.error(f"Training failed! Original error: {e}")
        finally:
            # Always detach the handler so later logging is not captured
            root_logger.removeHandler(log_streamer)
            log_streamer.close()

    def __handle_plot_generation(self, data: dict) -> Optional[ImageReader]:
        """
        Generates a plot based on provided data using the responsibility chain.

        Parameters:
            data (dict): Dictionary containing 'key' identifying the plot type
                and 'plot_data' containing the actual data to be plotted.

        Returns:
            Optional[ImageReader]: ReportLab ImageReader object if plot was generated,
                None if no handler could process the request.
        """

        axes = self.__plotting_chain.plot(data)
        if axes is None:
            logging.warning(f'Did not managed to generate plot for {data["key"]}!')
            return None

        # Render the figure into an in-memory PNG and rewind the buffer so
        # that it can be read from the beginning.
        buffer = io.BytesIO()
        axes.figure.savefig(buffer, format='png')
        buffer.seek(0)
        image_reader = ImageReader(buffer)
        # Close the figure once it has been rendered into the buffer
        plt.close(axes.figure)

        return image_reader

    def __calculate_max_dimensions(self, pdf: canvas.Canvas) -> tuple[int, int]:
        """
        Calculates maximum text dimensions that can fit on a PDF page.

        The width of a single space character is used as the width of every
        character (assumes a monospaced font such as the default 'Courier' -
        TODO confirm for custom fonts). Line height is taken as 1.2 times the
        text font size. Note that this also sets the current font of the canvas.

        Parameters:
            pdf (canvas.Canvas): The PDF canvas to calculate dimensions for.

        Returns:
            tuple[int, int]: Maximum number of characters per line and number
                of lines per page that can fit within margins.
        """

        pdf.setFont(self.__font_name, self.__text_font_size)
        character_width = pdf.stringWidth(' ', self.__font_name, self.__text_font_size)
        available_width = self.__page_width - self.__margins['left'] - self.__margins['right']
        max_width = int(available_width / character_width)

        line_height = self.__text_font_size * 1.2
        # The heading area is not available for the text body
        available_height = (self.__page_height - self.__margins['top'] - self.__margins['bottom']
                            - self.__heading_spacing)
        max_height = int(available_height / line_height)

        return max_width, max_height

    def __handle_logs_preprocessing(self, raw_logs_bufffer: list[str], max_log_length: int,
                                    max_lines_per_page: int) -> list[list[str]]:
        """
        Processes raw logs to fit them within PDF page constraints.

        This method drops lines containing any of the excluded terms, wraps
        lines longer than the allowed length and chunks the result into
        page-sized portions. Non-positive limits are treated as 1, so no
        log line is dropped because of a degenerate page layout.

        Parameters:
            raw_logs_bufffer (list[str]): Raw log lines to process.
            max_log_length (int): Maximum characters per line.
            max_lines_per_page (int): Maximum lines per page.

        Returns:
            list[list[str]]: List of pages, where each page is a list of log lines.
        """

        # A zero step would make range() raise and a negative one would
        # silently yield nothing, hence the clamping.
        line_width = max(1, max_log_length)
        page_capacity = max(1, max_lines_per_page)

        preprocessed_logs = []
        for log in raw_logs_bufffer:
            if any(exclude_term in log for exclude_term in self.__exclude_from_logs):
                continue

            if len(log) > line_width:
                preprocessed_logs.extend(log[start:start + line_width]
                                         for start in range(0, len(log), line_width))
            else:
                preprocessed_logs.append(log)

        return [preprocessed_logs[start:start + page_capacity]
                for start in range(0, len(preprocessed_logs), page_capacity)]

    def __draw_caption(self, pdf: canvas.Canvas, text: str) -> None:
        """
        Draws a section caption with separating line on the PDF.

        Parameters:
            pdf (canvas.Canvas): PDF canvas to draw on.
            text (str): Caption text to draw.
        """

        caption_y = self.__page_height - self.__margins['top']
        # The separating line sits half of the heading spacing below the caption
        separator_y = caption_y - self.__heading_spacing / 2

        pdf.setFont(self.__font_name, self.__caption_font_size)
        pdf.drawString(self.__margins['left'], caption_y, text)
        pdf.setLineWidth(2)
        pdf.setStrokeColorRGB(0, 0, 0)
        pdf.line(self.__margins['left'], separator_y,
                 self.__page_width - self.__margins['right'], separator_y)

    def __draw_text_body(self, pdf: canvas.Canvas, text_body: list[str]) -> None:
        """
        Draws a block of text lines on the PDF.

        The block starts at the left margin, below the top margin and the
        heading spacing.

        Parameters:
            pdf (canvas.Canvas): PDF canvas to draw on.
            text_body (list[str]): List of text lines to draw.
        """

        text_block = pdf.beginText(self.__margins['left'],
                                   self.__page_height - self.__margins['top'] - self.__heading_spacing)
        text_block.setFont(self.__font_name, self.__text_font_size)
        for line in text_body:
            text_block.textLine(line)
        pdf.drawText(text_block)

    def __draw_plot_page(self, pdf: canvas.Canvas, caption: str, image: ImageReader) -> None:
        """
        Draws a captioned plot and finishes the current page.

        Parameters:
            pdf (canvas.Canvas): PDF canvas to draw on.
            caption (str): Caption drawn at the top of the page.
            image (ImageReader): Plot image to place on the page.
        """

        self.__draw_caption(pdf, caption)
        pdf.drawImage(image, inch, self.__page_height - 7.5 * inch, width=6 * inch,
                      preserveAspectRatio=True)
        pdf.showPage()

    def generate_report(self, path_to_pdf: str) -> None:
        """
        Generates a comprehensive PDF report of training and testing results.

        Creates a multi-page report with logs, training history plots, and test
        results visualizations based on the data collected during training.
        Sections for which no data was collected (e.g. because training failed
        or was never run) are skipped with a warning.

        Parameters:
            path_to_pdf (str): File path where the PDF report should be saved.
        """

        logging.info("Generating report...")
        # Use the configured page size, as all of the layout calculations
        # are based on it (it equals 'letter' by default).
        pdf = canvas.Canvas(path_to_pdf, pagesize=(self.__page_width, self.__page_height))
        pdf.setTitle("Report")

        # Report logs
        raw_logs = self.__logs.getvalue().split('\n')
        max_log_length, max_lines_per_page = self.__calculate_max_dimensions(pdf)
        preprocessed_logs = self.__handle_logs_preprocessing(raw_logs, max_log_length, max_lines_per_page)
        for preprocessed_logs_chunk in preprocessed_logs:
            self.__draw_caption(pdf, "Log output")
            self.__draw_text_body(pdf, preprocessed_logs_chunk)
            pdf.showPage()

        # Draw training plot
        if 'train' in self.__generated_data:
            data = {
                'key': 'training_history',
                'plot_data': self.__generated_data['train']
            }
            plot_buffer = self.__handle_plot_generation(data)
            if plot_buffer is not None:
                self.__draw_plot_page(pdf, "Training performance", plot_buffer)
        else:
            logging.warning("No training data available, training plot skipped!")

        # Draw testing plots
        if 'test' in self.__generated_data:
            for index, testing_data in self.__generated_data['test'].items():
                data = {
                    'key': 'testing_history',
                    'plot_data': testing_data
                }
                plot_buffer = self.__handle_plot_generation(data)
                if plot_buffer is not None:
                    self.__draw_plot_page(pdf, f"Testing outcome, trial: {index + 1}", plot_buffer)
        else:
            logging.warning("No testing data available, testing plots skipped!")

        pdf.save()
        logging.info("Report generated!")