Coverage for source/training/training_handler.py: 95%
130 statements
« prev ^ index » next coverage.py v7.8.0, created at 2025-05-31 06:53 +0000
1# training/training_handler.py
3import logging
4import io
5import matplotlib.pyplot as plt
6from tensorflow.keras.callbacks import Callback
7from typing import Optional
8from reportlab.pdfgen import canvas
9from reportlab.lib.pagesizes import letter
10from reportlab.lib.units import inch
11from reportlab.lib.utils import ImageReader
13from .training_config import TrainingConfig
14from ..environment.trading_environment import TradingEnvironment
15from ..agent.agent_handler import AgentHandler
16from ..plotting.plot_responsibility_chain_base import PlotResponsibilityChainBase
17from ..plotting.plot_testing_history_responsibility_chain import PlotTestingHistoryResponsibilityChain
18from ..plotting.plot_training_history_responsibility_chain import PlotTrainingHistoryResponsibilityChain
class TrainingHandler():
    """
    Responsible for orchestrating the training process and report generation.

    This class manages the complete training workflow, from initializing the
    environment and agent, running training and testing sessions, to generating
    PDF reports with performance visualizations and logs. It serves as the main
    entry point for executing and documenting trading agent training.
    """

    # PDF report related constants, used as the constructor's default values.
    HEADING_SPACING = 20
    CAPTION_FONT_SIZE = 14
    TEXT_FONT_SIZE = 8
    FONT_NAME = 'Courier'
    MARGINS = {
        'left': 30,
        'right': 30,
        'top': 30,
        'bottom': 30
    }
    EXCLUDE_FROM_LOGS = ['ETA']

    def __init__(self, config: TrainingConfig, page_width: float = letter[0], page_height: float = letter[1],
                 heading_spacing: int = HEADING_SPACING, caption_font_size: int = CAPTION_FONT_SIZE,
                 text_font_size: int = TEXT_FONT_SIZE, font_name: str = FONT_NAME,
                 margins: dict[str, int] = MARGINS, exclude_from_logs: list[str] = EXCLUDE_FROM_LOGS) -> None:
        """
        Initializes the training handler with configuration parameters.

        Parameters:
            config (TrainingConfig): Configuration containing environment and agent settings.
            page_width (float): Width of PDF report pages in points (defaults to US letter).
            page_height (float): Height of PDF report pages in points (defaults to US letter).
            heading_spacing (int): Spacing between headings and content in points.
            caption_font_size (int): Font size for section captions.
            text_font_size (int): Font size for main text content.
            font_name (str): Font family to use for text in reports.
            margins (dict[str, int]): Dictionary with 'left', 'right', 'top', and 'bottom' margins.
            exclude_from_logs (list[str]): Log lines containing any of these terms are
                left out of the report.

        Raises:
            ValueError: If the margins dictionary doesn't contain every one of the
                'left', 'right', 'top' and 'bottom' keys.
        """

        # Validate before instantiating the environment and agent so a bad argument
        # fails fast. All four keys are required because the drawing helpers index
        # each one of them.
        if not all(key in margins for key in ('left', 'right', 'top', 'bottom')):
            raise ValueError("Margins should contain 'left', 'right', 'top' and 'bottom' keys!")

        # Training related configuration
        self.__environment: TradingEnvironment = config.instantiate_environment()
        self.__agent: AgentHandler = config.instantiate_agent()
        self.__nr_of_steps: int = config.nr_of_steps
        self.__repeat_test: int = config.repeat_test
        # Floor division keeps integer step counts exact (no float round-trip).
        self.__steps_per_episode: int = int(config.nr_of_steps // config.nr_of_episodes)

        # Report related configuration
        self.__config_summary = str(config)
        self.__plotting_chain: PlotResponsibilityChainBase = PlotTestingHistoryResponsibilityChain()
        self.__plotting_chain.add_next_chain_link(PlotTrainingHistoryResponsibilityChain())
        self.__generated_data: dict = {}
        self.__logs: io.StringIO = io.StringIO()
        self.__page_width = page_width
        self.__page_height = page_height
        self.__heading_spacing = heading_spacing
        self.__caption_font_size = caption_font_size
        self.__text_font_size = text_font_size
        self.__font_name = font_name
        # Copies, so the shared class-level defaults are never exposed to mutation.
        self.__margins = dict(margins)
        self.__exclude_from_logs = list(exclude_from_logs)

    def run_training(self, callbacks: Optional[list[Callback]] = None, weights_load_path: Optional[str] = None,
                     weights_save_path: Optional[str] = None) -> None:
        """
        Executes the training and testing process for the trading agent.

        This method orchestrates the complete training workflow, capturing logs,
        training the agent, and testing its performance. It populates internal
        data structures with results that can later be used for report generation.
        A failure during the run is not raised: it is logged (with its traceback)
        into the captured logs so that it appears in the generated report.

        Parameters:
            callbacks (list[Callback], optional): Keras callbacks to use during training.
                Defaults to no callbacks.
            weights_load_path (str, optional): Path to load pre-trained weights from.
            weights_save_path (str, optional): Path to save trained weights to.
        """

        # A fresh list per call instead of a mutable default shared between calls.
        if callbacks is None:
            callbacks = []

        # Capture records that reach the root logger into the in-memory buffer that
        # generate_report() renders later.
        # NOTE(review): the root logger filters by its own level before handlers see
        # a record; unless the application lowers it (WARNING by default), the INFO
        # messages below are not captured - confirm the caller configures this.
        log_streamer = logging.StreamHandler(self.__logs)
        log_streamer.setFormatter(logging.Formatter('%(message)s'))
        log_streamer.setLevel(logging.INFO)
        root_logger = logging.getLogger()
        root_logger.addHandler(log_streamer)

        try:
            logging.info("Training started!")
            logging.info(self.__config_summary.replace('\t', ' '))
            logging.info("Printing models architecture...")
            self.__agent.print_model_summary(print_function=logging.info)

            self.__environment.set_mode(TradingEnvironment.TRAIN_MODE)
            self.__generated_data['train'] = self.__agent.train_agent(self.__environment,
                                                                      self.__nr_of_steps,
                                                                      self.__steps_per_episode,
                                                                      callbacks,
                                                                      weights_load_path,
                                                                      weights_save_path)

            self.__environment.set_mode(TradingEnvironment.TEST_MODE)
            self.__generated_data['test'] = self.__agent.test_agent(self.__environment,
                                                                    self.__repeat_test)

            logging.info("Training finished!")
        except Exception as e:
            # Deliberate best-effort: record the failure (including the traceback,
            # which the formatter appends) in the captured logs instead of raising.
            logging.exception(f"Training failed! Original error: {e}")
        finally:
            root_logger.removeHandler(log_streamer)
            log_streamer.close()

    def __handle_plot_generation(self, data: dict) -> Optional[ImageReader]:
        """
        Generates a plot based on provided data using the responsibility chain.

        Parameters:
            data (dict): Dictionary containing 'key' identifying the plot type
                and 'plot_data' containing the actual data to be plotted.

        Returns:
            Optional[ImageReader]: ReportLab ImageReader over a PNG rendering of the
                plot, or None if no chain link could process the request.
        """

        axes = self.__plotting_chain.plot(data)
        if axes is None:
            logging.warning(f'Did not managed to generate plot for {data["key"]}!')
            return None

        figure = axes.figure
        buffer = io.BytesIO()
        try:
            figure.savefig(buffer, format='png')
        finally:
            # Close even if rendering fails, so pyplot does not accumulate open figures.
            plt.close(figure)
        buffer.seek(0)
        return ImageReader(buffer)

    def __calculate_max_dimensions(self, pdf: canvas.Canvas) -> tuple[int, int]:
        """
        Calculates maximum text dimensions that can fit on a PDF page.

        Parameters:
            pdf (canvas.Canvas): The PDF canvas to calculate dimensions for
                (its current font is set to the report text font as a side effect).

        Returns:
            tuple[int, int]: Maximum number of characters per line and number
                of lines per page that fit within margins; each is at least 1.
        """

        pdf.setFont(self.__font_name, self.__text_font_size)
        # Width of one character: exact for a monospaced font such as the default
        # 'Courier', only an approximation for proportional fonts.
        char_width = pdf.stringWidth(' ', self.__font_name, self.__text_font_size)
        available_width = self.__page_width - self.__margins['left'] - self.__margins['right']
        max_width = int(available_width / char_width)

        # 1.2 x font size as line height, matching ReportLab's default leading.
        line_height = self.__text_font_size * 1.2
        available_height = (self.__page_height - self.__margins['top'] - self.__margins['bottom']
                            - self.__heading_spacing)
        max_height = int(available_height / line_height)

        # Clamp so oversized margins cannot yield a zero step, which would make
        # range() raise during log preprocessing.
        return max(1, max_width), max(1, max_height)

    def __handle_logs_preprocessing(self, raw_logs_bufffer: list[str], max_log_length: int,
                                    max_lines_per_page: int) -> list[list[str]]:
        """
        Processes raw logs to fit them within PDF page constraints.

        This method filters out lines containing excluded terms, hard-wraps lines
        longer than the page width, and chunks the result into page-sized portions.

        Parameters:
            raw_logs_bufffer (list[str]): Raw log lines to process.
            max_log_length (int): Maximum characters per line (must be positive).
            max_lines_per_page (int): Maximum lines per page (must be positive).

        Returns:
            list[list[str]]: List of pages, where each page is a list of log lines.
        """

        preprocessed_logs = []
        for log in raw_logs_bufffer:
            if any(exclude_term in log for exclude_term in self.__exclude_from_logs):
                continue
            if len(log) > max_log_length:
                # Split an over-long line into consecutive max_log_length pieces.
                preprocessed_logs.extend(log[i:i + max_log_length]
                                         for i in range(0, len(log), max_log_length))
            else:
                preprocessed_logs.append(log)

        return [preprocessed_logs[i:i + max_lines_per_page]
                for i in range(0, len(preprocessed_logs), max_lines_per_page)]

    def __draw_caption(self, pdf: canvas.Canvas, text: str) -> None:
        """
        Draws a section caption with separating line on the PDF.

        Parameters:
            pdf (canvas.Canvas): PDF canvas to draw on.
            text (str): Caption text to draw.
        """

        caption_y = self.__page_height - self.__margins['top']
        # The separator sits halfway between the caption baseline and the body text.
        separator_y = caption_y - self.__heading_spacing / 2

        pdf.setFont(self.__font_name, self.__caption_font_size)
        pdf.drawString(self.__margins['left'], caption_y, text)
        pdf.setLineWidth(2)
        pdf.setStrokeColorRGB(0, 0, 0)
        pdf.line(self.__margins['left'], separator_y,
                 self.__page_width - self.__margins['right'], separator_y)

    def __draw_text_body(self, pdf: canvas.Canvas, text_body: list[str]) -> None:
        """
        Draws a block of text lines on the PDF, starting below the caption area.

        Parameters:
            pdf (canvas.Canvas): PDF canvas to draw on.
            text_body (list[str]): List of text lines to draw.
        """

        text_block = pdf.beginText(self.__margins['left'],
                                   self.__page_height - self.__margins['top'] - self.__heading_spacing)
        text_block.setFont(self.__font_name, self.__text_font_size)
        for line in text_body:
            text_block.textLine(line)
        pdf.drawText(text_block)

    def __draw_plot_page(self, pdf: canvas.Canvas, key: str, plot_data: object, caption: str) -> None:
        """
        Renders one plot with a caption on its own page; draws nothing if no
        chain link could produce the plot.

        Parameters:
            pdf (canvas.Canvas): PDF canvas to draw on.
            key (str): Plot type identifier handed to the plotting chain.
            plot_data (object): Data forwarded unchanged to the plotting chain.
            caption (str): Caption drawn above the plot.
        """

        plot_image = self.__handle_plot_generation({'key': key, 'plot_data': plot_data})
        if plot_image is None:
            return
        self.__draw_caption(pdf, caption)
        pdf.drawImage(plot_image, inch, self.__page_height - 7.5 * inch, width=6 * inch,
                      preserveAspectRatio=True)
        pdf.showPage()

    def generate_report(self, path_to_pdf: str) -> None:
        """
        Generates a comprehensive PDF report of training and testing results.

        Creates a multi-page report with logs, training history plots, and test
        results visualizations based on the data collected during training.
        Plots whose data was never produced (e.g. because training failed) are
        skipped, so the log pages are still written.

        Parameters:
            path_to_pdf (str): File path where the PDF report should be saved.
        """

        logging.info("Generating report...")
        # Use the configured page size: all drawing helpers position content
        # relative to self.__page_width / self.__page_height.
        pdf = canvas.Canvas(path_to_pdf, pagesize=(self.__page_width, self.__page_height))
        pdf.setTitle("Report")

        # Report logs, one page per chunk
        raw_logs = self.__logs.getvalue().split('\n')
        max_log_length, max_lines_per_page = self.__calculate_max_dimensions(pdf)
        preprocessed_logs = self.__handle_logs_preprocessing(raw_logs, max_log_length, max_lines_per_page)
        for preprocessed_logs_chunk in preprocessed_logs:
            self.__draw_caption(pdf, "Log output")
            self.__draw_text_body(pdf, preprocessed_logs_chunk)
            pdf.showPage()

        # Draw training plot
        if 'train' in self.__generated_data:
            self.__draw_plot_page(pdf, 'training_history', self.__generated_data['train'],
                                  "Training performance")
        else:
            logging.warning("No training data available, skipping training plot!")

        # Draw testing plots, one page per test trial
        if 'test' in self.__generated_data:
            for index, testing_data in self.__generated_data['test'].items():
                self.__draw_plot_page(pdf, 'testing_history', testing_data,
                                      f"Testing outcome, trial: {index + 1}")
        else:
            logging.warning("No testing data available, skipping testing plots!")

        pdf.save()
        logging.info("Report generated!")