Coverage for source/training/training_handler.py: 95%
128 statements
« prev ^ index » next coverage.py v7.8.0, created at 2025-05-30 15:13 +0000
« prev ^ index » next coverage.py v7.8.0, created at 2025-05-30 15:13 +0000
1# training/training_handler.py
3import logging
4import io
5import matplotlib.pyplot as plt
6from tensorflow.keras.callbacks import Callback
7from typing import Optional
8from reportlab.pdfgen import canvas
9from reportlab.lib.pagesizes import letter
10from reportlab.lib.units import inch
11from reportlab.lib.utils import ImageReader
13from .training_config import TrainingConfig
14from ..environment.trading_environment import TradingEnvironment
15from ..agent.agent_handler import AgentHandler
16from ..plotting.plot_responsibility_chain_base import PlotResponsibilityChainBase
17from ..plotting.plot_testing_history_responsibility_chain import PlotTestingHistoryResponsibilityChain
18from ..plotting.plot_training_history_responsibility_chain import PlotTrainingHistoryResponsibilityChain
class TrainingHandler:
    """
    Responsible for orchestrating the training process and report generation.

    This class manages the complete training workflow, from initializing the
    environment and agent, running training and testing sessions, to generating
    PDF reports with performance visualizations and logs. It serves as the main
    entry point for executing and documenting trading agent training.
    """

    # PDF report related constants. They are evaluated once, in the class body,
    # as the default values of __init__'s parameters; __init__ copies the mutable
    # ones so instances never alias these shared class-level objects.
    HEADING_SPACING = 20
    CAPTION_FONT_SIZE = 14
    TEXT_FONT_SIZE = 8
    FONT_NAME = 'Courier'
    MARGINS = {
        'left': 30,
        'right': 30,
        'top': 30,
        'bottom': 30
    }
    EXCLUDE_FROM_LOGS = ['ETA']
    # Every margins mapping must define all of these keys: the drawing code
    # indexes each of them directly.
    _REQUIRED_MARGIN_KEYS = ('left', 'right', 'top', 'bottom')

    def __init__(self, config: TrainingConfig, page_width: float = letter[0], page_height: float = letter[1],
                 heading_spacing: int = HEADING_SPACING, caption_font_size: int = CAPTION_FONT_SIZE,
                 text_font_size: int = TEXT_FONT_SIZE, font_name: str = FONT_NAME,
                 margins: dict[str, int] = MARGINS, exclude_from_logs: list[str] = EXCLUDE_FROM_LOGS) -> None:
        """
        Initializes the training handler with configuration parameters.

        Parameters:
            config (TrainingConfig): Configuration containing environment and agent settings.
            page_width (float): Width of PDF report pages in points.
            page_height (float): Height of PDF report pages in points.
            heading_spacing (int): Spacing between headings and content in points.
            caption_font_size (int): Font size for section captions.
            text_font_size (int): Font size for main text content.
            font_name (str): Font family to use for text in reports.
            margins (dict[str, int]): Dictionary with 'left', 'right', 'top', and 'bottom' margins.
            exclude_from_logs (list[str]): Log lines containing any of these terms are
                omitted from the report.

        Raises:
            ValueError: If the margins dictionary doesn't contain all required keys.
        """

        # Validate first so a bad margins dict fails before the (potentially
        # expensive) environment and agent are instantiated.
        if not all(key in margins for key in self._REQUIRED_MARGIN_KEYS):
            raise ValueError("Margins should contain 'left', 'right', 'top' and 'bottom' keys!")

        # Training related configuration
        self.__environment: TradingEnvironment = config.instantiate_environment()
        self.__agent: AgentHandler = config.instantiate_agent()
        self.__nr_of_steps: int = config.nr_of_steps
        self.__repeat_test: int = config.repeat_test
        # Floor division avoids a float round-trip; int() keeps the result an
        # int even if the configured values are floats.
        self.__steps_per_episode: int = int(config.nr_of_steps // config.nr_of_episodes)

        # Report related configuration
        self.__config_summary = str(config)
        # Testing-history plotter first, training-history plotter as the next link.
        self.__plotting_chain: PlotResponsibilityChainBase = PlotTestingHistoryResponsibilityChain()
        self.__plotting_chain.add_next_chain_link(PlotTrainingHistoryResponsibilityChain())
        # Filled by run_training under the 'train' and 'test' keys.
        self.__generated_data: dict = {}
        # Captures log messages emitted during run_training for the report.
        self.__logs: io.StringIO = io.StringIO()
        self.__page_width = page_width
        self.__page_height = page_height
        self.__heading_spacing = heading_spacing
        self.__caption_font_size = caption_font_size
        self.__text_font_size = text_font_size
        self.__font_name = font_name
        # Defensive copies: the defaults are shared class-level objects.
        self.__margins = dict(margins)
        self.__exclude_from_logs = list(exclude_from_logs)

    def run_training(self, callbacks: Optional[list[Callback]] = None, weights_load_path: Optional[str] = None,
                     weights_save_path: Optional[str] = None) -> None:
        """
        Executes the training and testing process for the trading agent.

        While this method runs, records handled by the root logger are also
        written into an internal buffer that generate_report later renders.
        Training results are stored under the 'train' key and testing results
        under the 'test' key of the internal data store.

        Errors raised during training or testing are logged together with their
        traceback instead of being propagated. In that case the corresponding
        results are absent and generate_report skips those plots.

        Parameters:
            callbacks (list[Callback], optional): Keras callbacks to use during training.
                Defaults to a fresh empty list on every call.
            weights_load_path (str, optional): Path to load pre-trained weights from.
            weights_save_path (str, optional): Path to save trained weights to.
        """

        # A new list per call instead of a shared mutable default.
        if callbacks is None:
            callbacks = []

        log_streamer = logging.StreamHandler(self.__logs)
        log_streamer.setFormatter(logging.Formatter('%(message)s'))
        log_streamer.setLevel(logging.INFO)
        # NOTE(review): records must also pass the root logger's own level; if it
        # is left at logging's default (WARNING), the info messages below are not
        # captured. Confirm the caller configures logging accordingly.
        root_logger = logging.getLogger()
        root_logger.addHandler(log_streamer)

        try:
            logging.info("Training started!")
            logging.info(self.__config_summary.replace('\t', ' '))
            logging.info("Printing models architecture...")
            self.__agent.print_model_summary(print_function = lambda x: logging.info(x))

            self.__generated_data['train'] = self.__agent.train_agent(self.__environment,
                                                                     self.__nr_of_steps,
                                                                     self.__steps_per_episode,
                                                                     callbacks,
                                                                     weights_load_path,
                                                                     weights_save_path)
            self.__generated_data['test'] = self.__agent.test_agent(self.__environment,
                                                                   self.__repeat_test)

            logging.info("Training finished!")
        except Exception as e:
            # Deliberately best-effort: the failure is recorded (with traceback,
            # which the formatter appends after the message) so it ends up in
            # the report instead of being lost.
            logging.exception(f"Training failed! Original error: {e}")
        finally:
            root_logger.removeHandler(log_streamer)
            log_streamer.close()

    def __handle_plot_generation(self, data: dict) -> Optional[ImageReader]:
        """
        Generates a plot based on provided data using the responsibility chain.

        Parameters:
            data (dict): Dictionary containing 'key' identifying the plot type
                and 'plot_data' containing the actual data to be plotted.

        Returns:
            Optional[ImageReader]: ReportLab ImageReader over a PNG rendering of
                the plot, or None if no chain link could process the request.
        """

        # Presumably a matplotlib Axes (its .figure is saved and closed below).
        axes = self.__plotting_chain.plot(data)
        if axes is None:
            logging.warning(f'Did not managed to generate plot for {data["key"]}!')
            return None

        buffer = io.BytesIO()
        try:
            axes.figure.savefig(buffer, format = 'png')
        finally:
            # Release the figure even if saving fails, so figures don't accumulate.
            plt.close(axes.figure)
        buffer.seek(0)
        return ImageReader(buffer)

    def __calculate_max_dimensions(self, pdf: canvas.Canvas) -> tuple[int, int]:
        """
        Calculates maximum text dimensions that can fit on a PDF page.

        Parameters:
            pdf (canvas.Canvas): The PDF canvas to calculate dimensions for.

        Returns:
            tuple[int, int]: Maximum number of characters per line and number
                of lines per page that fit within the margins (each at least 1).
        """

        pdf.setFont(self.__font_name, self.__text_font_size)
        # Width of one space; this equals every glyph's width only for a
        # monospaced font such as the default 'Courier'.
        char_width = pdf.stringWidth(' ', self.__font_name, self.__text_font_size)
        available_width = self.__page_width - self.__margins['left'] - self.__margins['right']
        max_width = int(available_width / char_width)

        # Presumably matches the default text leading used by __draw_text_body.
        line_height = self.__text_font_size * 1.2
        available_height = (self.__page_height - self.__margins['top'] - self.__margins['bottom']
                            - self.__heading_spacing)
        max_height = int(available_height / line_height)

        # Clamp so degenerate layouts never produce a zero step when wrapping
        # or paginating the logs.
        return max(1, max_width), max(1, max_height)

    def __handle_logs_preprocessing(self, raw_logs_buffer: list[str], max_log_length: int,
                                    max_lines_per_page: int) -> list[list[str]]:
        """
        Processes raw logs to fit them within PDF page constraints.

        This method drops lines containing an excluded term, hard-wraps lines
        longer than max_log_length, and chunks the result into page-sized portions.

        Parameters:
            raw_logs_buffer (list[str]): Raw log lines to process.
            max_log_length (int): Maximum characters per line.
            max_lines_per_page (int): Maximum lines per page.

        Returns:
            list[list[str]]: List of pages, where each page is a list of log lines.
        """

        preprocessed_logs: list[str] = []
        for log in raw_logs_buffer:
            if any(exclude_term in log for exclude_term in self.__exclude_from_logs):
                continue
            if len(log) <= max_log_length:
                # Includes empty lines, which are kept to preserve spacing.
                preprocessed_logs.append(log)
            else:
                preprocessed_logs.extend(log[i:i + max_log_length]
                                         for i in range(0, len(log), max_log_length))

        return [preprocessed_logs[i:i + max_lines_per_page]
                for i in range(0, len(preprocessed_logs), max_lines_per_page)]

    def __draw_caption(self, pdf: canvas.Canvas, text: str) -> None:
        """
        Draws a section caption at the top margin with a separating line below it.

        Parameters:
            pdf (canvas.Canvas): PDF canvas to draw on.
            text (str): Caption text to draw.
        """

        caption_y = self.__page_height - self.__margins['top']
        # The separator sits halfway between the caption baseline and the body.
        separator_y = caption_y - self.__heading_spacing / 2

        pdf.setFont(self.__font_name, self.__caption_font_size)
        pdf.drawString(self.__margins['left'], caption_y, text)
        pdf.setLineWidth(2)
        pdf.setStrokeColorRGB(0, 0, 0)
        pdf.line(self.__margins['left'], separator_y,
                 self.__page_width - self.__margins['right'], separator_y)

    def __draw_text_body(self, pdf: canvas.Canvas, text_body: list[str]) -> None:
        """
        Draws a block of text lines below the caption area of the PDF page.

        Parameters:
            pdf (canvas.Canvas): PDF canvas to draw on.
            text_body (list[str]): List of text lines to draw.
        """

        text_block = pdf.beginText(self.__margins['left'],
                                   self.__page_height - self.__margins['top'] - self.__heading_spacing)
        text_block.setFont(self.__font_name, self.__text_font_size)
        for line in text_body:
            text_block.textLine(line)
        pdf.drawText(text_block)

    def __draw_plot_page(self, pdf: canvas.Canvas, key: str, plot_data, caption: str) -> None:
        """
        Renders one plot through the plotting chain and, if successful, draws it
        with a caption on its own page.

        Parameters:
            pdf (canvas.Canvas): PDF canvas to draw on.
            key (str): Plot type identifier passed to the plotting chain.
            plot_data: Data to be plotted, as produced by the agent.
            caption (str): Caption shown above the plot.
        """

        plot_buffer = self.__handle_plot_generation({'key': key, 'plot_data': plot_data})
        if plot_buffer is None:
            return
        self.__draw_caption(pdf, caption)
        pdf.drawImage(plot_buffer, inch, self.__page_height - 7.5 * inch, width = 6 * inch,
                      preserveAspectRatio = True)
        pdf.showPage()

    def generate_report(self, path_to_pdf: str) -> None:
        """
        Generates a comprehensive PDF report of training and testing results.

        Creates a multi-page report with logs, training history plots, and test
        results visualizations based on the data collected during training.
        Plot sections whose data is missing (e.g. because training failed) are
        skipped with a warning.

        Parameters:
            path_to_pdf (str): File path where the PDF report should be saved.
        """

        logging.info("Generating report...")
        # Size the canvas from the configured page dimensions so it agrees with
        # all layout computations, which use them.
        pdf = canvas.Canvas(path_to_pdf, pagesize = (self.__page_width, self.__page_height))
        pdf.setTitle("Report")

        # Report logs. The handler terminates every record with '\n'; dropping the
        # final one avoids a spurious trailing empty line (and possibly a blank page).
        raw_logs = self.__logs.getvalue().removesuffix('\n').split('\n')
        max_log_length, max_lines_per_page = self.__calculate_max_dimensions(pdf)
        preprocessed_logs = self.__handle_logs_preprocessing(raw_logs, max_log_length, max_lines_per_page)
        for preprocessed_logs_chunk in preprocessed_logs:
            self.__draw_caption(pdf, "Log output")
            self.__draw_text_body(pdf, preprocessed_logs_chunk)
            pdf.showPage()

        # Draw training plot
        training_data = self.__generated_data.get('train')
        if training_data is None:
            logging.warning("No training data available, skipping training plot!")
        else:
            self.__draw_plot_page(pdf, 'training_history', training_data, "Training performance")

        # Draw testing plots; keys are presumably 0-based trial indices.
        testing_results = self.__generated_data.get('test')
        if testing_results is None:
            logging.warning("No testing data available, skipping testing plots!")
        else:
            for index, testing_data in testing_results.items():
                self.__draw_plot_page(pdf, 'testing_history', testing_data,
                                      f"Testing outcome, trial: {index + 1}")

        pdf.save()
        logging.info("Report generated!")