Coverage for source/training/training_handler.py: 96%
136 statements
« prev ^ index » next coverage.py v7.8.0, created at 2025-07-27 17:11 +0000
« prev ^ index » next coverage.py v7.8.0, created at 2025-07-27 17:11 +0000
1# training/training_handler.py
3# global imports
4import io
5import logging
6import matplotlib.pyplot as plt
7from reportlab.lib.pagesizes import letter
8from reportlab.lib.units import inch
9from reportlab.lib.utils import ImageReader
10from reportlab.pdfgen import canvas
11from tensorflow.keras.callbacks import Callback
12from typing import Optional
14# local imports
15from source.agent import AgentHandler
16from source.plotting import AssetPriceMovementSummaryPlotResponsibilityChain, \
17 ClassificationTestingPlotResponsibilityChain, ClassificationTrainingPlotResponsibilityChain, \
18 PerformanceTestingPlotResponsibilityChain, PlotResponsibilityChainBase, \
19 PriceMovementTrendClassSummaryPlotResponsibilityChain, ReinforcementTrainingPlotResponsibilityChain
20from source.training import TrainingConfig
class TrainingHandler():
    """
    Responsible for orchestrating the training process and report generation.

    This class manages the complete training workflow, from initializing the
    environment and agent, running training and testing sessions, to generating
    PDF reports with performance visualizations and logs. It serves as the main
    entry point for executing and documenting trading agent training.
    """

    # local constants
    __HEADING_SPACING = 20
    __CAPTION_FONT_SIZE = 14
    __TEXT_FONT_SIZE = 8
    __FONT_NAME = 'Courier'
    __MARGINS = {
        'left': 30,
        'right': 30,
        'top': 30,
        'bottom': 30
    }
    # Keys that every margins dictionary must provide (all are read when drawing pages).
    __REQUIRED_MARGIN_KEYS = ('left', 'right', 'top', 'bottom')
    __EXCLUDE_FROM_LOGS = ['ETA']

    def __init__(self, config: TrainingConfig, page_width: int = letter[0], page_height: int = letter[1],
                 heading_spacing: int = __HEADING_SPACING, caption_font_size: int = __CAPTION_FONT_SIZE,
                 text_font_size: int = __TEXT_FONT_SIZE, font_name: str = __FONT_NAME,
                 margins: dict[str, int] = __MARGINS, exclude_from_logs: list[str] = __EXCLUDE_FROM_LOGS) -> None:
        """
        Initializes the training handler with configuration parameters.

        Parameters:
            config (TrainingConfig): Configuration containing environment and agent settings.
            page_width (int): Width of PDF report pages in points.
            page_height (int): Height of PDF report pages in points.
            heading_spacing (int): Spacing between headings and content in points.
            caption_font_size (int): Font size for section captions.
            text_font_size (int): Font size for main text content.
            font_name (str): Font family to use for text in reports.
            margins (dict[str, int]): Dictionary with 'left', 'right', 'top', and 'bottom' margins.
                A copy is stored, so later changes to the passed dictionary do not affect this handler.
            exclude_from_logs (list[str]): Terms that should be excluded from logs in reports.
                A copy is stored.

        Raises:
            ValueError: If margins dictionary doesn't contain all of the required keys.
        """

        # Validate first so that an invalid layout fails fast, before the
        # (potentially expensive) agent handler is instantiated.
        if any(key not in margins for key in self.__REQUIRED_MARGIN_KEYS):
            raise ValueError("Margins should contain 'left', 'right', 'top' and 'bottom' keys!")

        # Training related configuration
        self.__agent: AgentHandler = config.instantiate_agent_handler()
        self.__nr_of_steps: int = config.nr_of_steps
        self.__repeat_test: int = config.repeat_test
        self.__nr_of_episodes: int = config.nr_of_episodes

        # Report related configuration
        self.__config_summary = str(config)
        self.__plotting_chain: PlotResponsibilityChainBase = AssetPriceMovementSummaryPlotResponsibilityChain()
        self.__plotting_chain.add_next_chain_link(ClassificationTestingPlotResponsibilityChain())
        self.__plotting_chain.add_next_chain_link(ClassificationTrainingPlotResponsibilityChain())
        self.__plotting_chain.add_next_chain_link(PerformanceTestingPlotResponsibilityChain())
        self.__plotting_chain.add_next_chain_link(PriceMovementTrendClassSummaryPlotResponsibilityChain())
        self.__plotting_chain.add_next_chain_link(ReinforcementTrainingPlotResponsibilityChain())
        # 'train' maps plot key -> data; 'test' maps iteration -> {plot key -> data}.
        self.__generated_data: dict = {'train': {}, 'test': {}}
        self.__logs: io.StringIO = io.StringIO()
        self.__page_width = page_width
        self.__page_height = page_height
        self.__heading_spacing = heading_spacing
        self.__caption_font_size = caption_font_size
        self.__text_font_size = text_font_size
        self.__font_name = font_name
        # Copy the (class-level, shared) defaults so instances never alias each other's settings.
        self.__margins = dict(margins)
        self.__exclude_from_logs = list(exclude_from_logs)

    def run_training(self, callbacks: Optional[list[Callback]] = None, weights_load_path: Optional[str] = None,
                     weights_save_path: Optional[str] = None) -> None:
        """
        Executes the training and testing process for the trading agent.

        This method orchestrates the complete training workflow, capturing logs,
        training the agent, and testing its performance. It populates internal
        data structures with results that can later be used for report generation.

        Any exception raised during training or testing is logged (with its
        traceback) into the captured logs and NOT re-raised, so a report can
        still be generated afterwards.

        Parameters:
            callbacks (list[Callback], optional): Keras callbacks to use during training.
                Defaults to no callbacks.
            weights_load_path (str, optional): Path to load pre-trained weights from.
            weights_save_path (str, optional): Path to save trained weights to.
        """

        if callbacks is None:
            callbacks = []

        # Capture everything logged through the root logger into the in-memory
        # buffer that is later rendered on the report's log pages.
        # NOTE(review): records only reach this handler if the root logger's own
        # level allows INFO (the root logger defaults to WARNING); presumably the
        # application configures this elsewhere — confirm.
        log_streamer = logging.StreamHandler(self.__logs)
        log_streamer.setFormatter(logging.Formatter('%(message)s'))
        log_streamer.setLevel(logging.INFO)
        root_logger = logging.getLogger()
        root_logger.addHandler(log_streamer)

        try:
            logging.info("Training started!")
            logging.info(self.__config_summary.replace('\t', ' '))
            logging.info("Printing models architecture...")
            self.__agent.print_model_summary(print_function = lambda x: logging.info(x))

            train_keys, train_data = self.__agent.train_agent(self.__nr_of_steps, self.__nr_of_episodes,
                                                              callbacks, weights_load_path, weights_save_path)
            for key, data in zip(train_keys, train_data):
                self.__generated_data['train'][key] = data

            test_keys, test_data = self.__agent.test_agent(self.__repeat_test)
            # Pairs the two dictionaries positionally; assumes test_agent returns
            # both with the same iteration order — TODO confirm.
            for (iteration, key_list), (_, data_list) in zip(test_keys.items(), test_data.items()):
                self.__generated_data['test'][iteration] = {}
                for key, data in zip(key_list, data_list):
                    self.__generated_data['test'][iteration][key] = data

            logging.info("Training finished!")
        except Exception as e:
            # Deliberately best-effort: record the failure (with traceback) in the
            # captured logs instead of propagating, so the report still shows it.
            logging.error(f"Training failed! Original error: {e}", exc_info = True)
        finally:
            root_logger.removeHandler(log_streamer)
            log_streamer.close()

    def __handle_plot_generation(self, data: dict) -> Optional[ImageReader]:
        """
        Generates a plot based on provided data using the responsibility chain.

        Parameters:
            data (dict): Dictionary containing 'key' identifying the plot type
                and 'plot_data' containing the actual data to be plotted.

        Returns:
            Optional[ImageReader]: ReportLab ImageReader object if plot was generated,
                None if no handler could process the request.
        """

        axes = self.__plotting_chain.plot(data)
        if axes is None:
            logging.warning(f'Did not managed to generate plot for {data["key"]}!')
            return None

        figure = axes.figure
        try:
            buffer = io.BytesIO()
            figure.savefig(buffer, format = 'png')
            buffer.seek(0)
            return ImageReader(buffer)
        finally:
            # Release the figure even if rendering fails, so repeated report
            # generation does not accumulate open matplotlib figures.
            plt.close(figure)

    def __calculate_max_dimensions(self, pdf: canvas.Canvas) -> tuple[int, int]:
        """
        Calculates maximum text dimensions that can fit on a PDF page.

        The width of a single space is used as the per-character width, which
        assumes a monospaced font (the default font is 'Courier').

        Parameters:
            pdf (canvas.Canvas): The PDF canvas to calculate dimensions for.

        Returns:
            tuple[int, int]: Maximum number of characters per line and number
                of lines per page that can fit within margins (each at least 1).
        """

        pdf.setFont(self.__font_name, self.__text_font_size)
        char_width = pdf.stringWidth(' ', self.__font_name, self.__text_font_size)
        available_width = self.__page_width - self.__margins['left'] - self.__margins['right']
        # Clamp to 1 so that an overly tight layout never yields a zero step
        # when the logs are later split into lines and pages.
        max_width = max(1, int(available_width / char_width))

        # 1.2 x font size approximates the line leading used for the text block.
        line_height = self.__text_font_size * 1.2
        available_height = self.__page_height - self.__margins['top'] - self.__margins['bottom'] \
            - self.__heading_spacing
        max_height = max(1, int(available_height / line_height))

        return max_width, max_height

    def __handle_logs_preprocessing(self, raw_logs_buffer: list[str], max_log_length: int,
                                    max_lines_per_page: int) -> list[list[str]]:
        """
        Processes raw logs to fit them within PDF page constraints.

        Lines containing any of the excluded terms (substring match) are dropped,
        lines longer than the allowed width are split into fixed-width pieces,
        and the result is chunked into page-sized portions.

        Parameters:
            raw_logs_buffer (list[str]): Raw log lines to process.
            max_log_length (int): Maximum characters per line.
            max_lines_per_page (int): Maximum lines per page.

        Returns:
            list[list[str]]: List of pages, where each page is a list of log lines.
        """

        preprocessed_logs = []
        for log in raw_logs_buffer:
            if any(exclude_term in log for exclude_term in self.__exclude_from_logs):
                continue
            if len(log) > max_log_length:
                preprocessed_logs.extend(log[i:i + max_log_length]
                                         for i in range(0, len(log), max_log_length))
            else:
                preprocessed_logs.append(log)

        return [preprocessed_logs[i:i + max_lines_per_page]
                for i in range(0, len(preprocessed_logs), max_lines_per_page)]

    def __draw_caption(self, pdf: canvas.Canvas, text: str) -> None:
        """
        Draws a section caption with separating line on the PDF.

        Parameters:
            pdf (canvas.Canvas): PDF canvas to draw on.
            text (str): Caption text to draw.
        """

        caption_y = self.__page_height - self.__margins['top']
        # The separator sits halfway between the caption and the body text.
        separator_y = caption_y - self.__heading_spacing / 2

        pdf.setFont(self.__font_name, self.__caption_font_size)
        pdf.drawString(self.__margins['left'], caption_y, text)
        pdf.setLineWidth(2)
        pdf.setStrokeColorRGB(0, 0, 0)
        pdf.line(self.__margins['left'], separator_y,
                 self.__page_width - self.__margins['right'], separator_y)

    def __draw_text_body(self, pdf: canvas.Canvas, text_body: list[str]) -> None:
        """
        Draws a block of text lines on the PDF.

        Parameters:
            pdf (canvas.Canvas): PDF canvas to draw on.
            text_body (list[str]): List of text lines to draw.
        """

        text_block = pdf.beginText(self.__margins['left'],
                                   self.__page_height - self.__margins['top'] - self.__heading_spacing)
        text_block.setFont(self.__font_name, self.__text_font_size)
        for line in text_body:
            text_block.textLine(line)
        pdf.drawText(text_block)

    def __draw_plot_page(self, pdf: canvas.Canvas, caption: str, key, data) -> None:
        """
        Renders one plot on its own captioned page; does nothing if no plot could be generated.

        Parameters:
            pdf (canvas.Canvas): PDF canvas to draw on.
            caption (str): Caption drawn above the plot.
            key: Identifier of the plot type, passed to the plotting chain.
            data: Data to be plotted, passed to the plotting chain.
        """

        image = self.__handle_plot_generation({
            'key': key,
            'plot_data': data
        })
        if image is None:
            return

        self.__draw_caption(pdf, caption)
        # Half-inch side margins and one-inch top/bottom margins, sized to the configured page.
        pdf.drawImage(image, 0.5 * inch, 1 * inch, width = self.__page_width - 1 * inch,
                      height = self.__page_height - 2 * inch)
        pdf.showPage()

    def generate_report(self, path_to_pdf: str) -> None:
        """
        Generates a comprehensive PDF report of training and testing results.

        Creates a multi-page report with logs, training history plots, and test
        results visualizations based on the data collected during training.
        Pages use the page size configured in the constructor.

        Parameters:
            path_to_pdf (str): File path where the PDF report should be saved.
        """

        logging.info("Generating report...")
        pdf = canvas.Canvas(path_to_pdf, pagesize = (self.__page_width, self.__page_height))
        pdf.setTitle("Report")

        # Report logs
        raw_logs = self.__logs.getvalue().split('\n')
        max_log_length, max_lines_per_page = self.__calculate_max_dimensions(pdf)
        preprocessed_logs = self.__handle_logs_preprocessing(raw_logs, max_log_length, max_lines_per_page)
        for preprocessed_logs_chunk in preprocessed_logs:
            self.__draw_caption(pdf, "Log output")
            self.__draw_text_body(pdf, preprocessed_logs_chunk)
            pdf.showPage()

        # Draw training plots
        for key, data in self.__generated_data['train'].items():
            self.__draw_plot_page(pdf, "Training performance", key, data)

        # Draw testing plots
        for iteration, key_data_pair in self.__generated_data['test'].items():
            for key, data in key_data_pair.items():
                self.__draw_plot_page(pdf, f"Testing outcome, trial: {iteration + 1}", key, data)

        pdf.save()
        logging.info("Report generated!")