Coverage for source/training/training_handler.py: 20%
133 statements
« prev ^ index » next coverage.py v7.8.0, created at 2025-06-06 12:00 +0000
« prev ^ index » next coverage.py v7.8.0, created at 2025-06-06 12:00 +0000
1# training/training_handler.py
3# global imports
4import logging
5import io
6import matplotlib.pyplot as plt
7from tensorflow.keras.callbacks import Callback
8from typing import Optional
9from reportlab.pdfgen import canvas
10from reportlab.lib.pagesizes import letter
11from reportlab.lib.units import inch
12from reportlab.lib.utils import ImageReader
14# local imports
15from source.agent import AgentHandler
16from source.plotting import ClassificationTestingPlotResponsibilityChain, \
17 ClassificationTrainingPlotResponsibilityChain, PlotResponsibilityChainBase, \
18 PlotTestingHistoryResponsibilityChain, PlotTrainingHistoryResponsibilityChain
19from source.training import TrainingConfig
class TrainingHandler:
    """
    Responsible for orchestrating the training process and report generation.

    This class manages the complete training workflow, from initializing the
    environment and agent, running training and testing sessions, to generating
    PDF reports with performance visualizations and logs. It serves as the main
    entry point for executing and documenting trading agent training.
    """

    # Constants used locally
    HEADING_SPACING = 20
    CAPTION_FONT_SIZE = 14
    TEXT_FONT_SIZE = 8
    FONT_NAME = 'Courier'
    MARGINS = {
        'left': 30,
        'right': 30,
        'top': 30,
        'bottom': 30
    }
    EXCLUDE_FROM_LOGS = ['ETA']

    # Keys that every margins dictionary has to provide.
    _REQUIRED_MARGIN_KEYS = ('left', 'right', 'top', 'bottom')

    def __init__(self, config: TrainingConfig, page_width: int = letter[0], page_height: int = letter[1],
                 heading_spacing: int = HEADING_SPACING, caption_font_size: int = CAPTION_FONT_SIZE,
                 text_font_size: int = TEXT_FONT_SIZE, font_name: str = FONT_NAME,
                 margins: dict[str, int] = MARGINS, exclude_from_logs: list[str] = EXCLUDE_FROM_LOGS) -> None:
        """
        Initializes the training handler with configuration parameters.

        Parameters:
            config (TrainingConfig): Configuration containing environment and agent settings.
            page_width (int): Width of PDF report pages in points.
            page_height (int): Height of PDF report pages in points.
            heading_spacing (int): Spacing between headings and content in points.
            caption_font_size (int): Font size for section captions.
            text_font_size (int): Font size for main text content.
            font_name (str): Font family to use for text in reports.
            margins (dict[str, int]): Dictionary with 'left', 'right', 'top', and 'bottom' margins.
            exclude_from_logs (list[str]): Terms that should be excluded from logs in reports.

        Raises:
            ValueError: If margins dictionary doesn't contain all required keys.
        """

        # Validate first, so that a bad layout configuration fails before the
        # agent handler is instantiated. Every required key must be present.
        if not all(key in margins for key in self._REQUIRED_MARGIN_KEYS):
            raise ValueError("Margins should contain 'left', 'right', 'top' and 'bottom' keys!")

        # Training related configuration
        self.__agent: AgentHandler = config.instantiate_agent_handler()
        self.__nr_of_steps: int = config.nr_of_steps
        self.__repeat_test: int = config.repeat_test
        self.__nr_of_episodes: int = config.nr_of_episodes

        # Report related configuration
        self.__config_summary = str(config)

        # Plotting chain, linked in this order: testing history -> training history
        # -> classification training -> classification testing.
        classification_training_plot_chain = ClassificationTrainingPlotResponsibilityChain()
        classification_training_plot_chain.add_next_chain_link(ClassificationTestingPlotResponsibilityChain())
        plot_training_history_chain = PlotTrainingHistoryResponsibilityChain()
        plot_training_history_chain.add_next_chain_link(classification_training_plot_chain)
        self.__plotting_chain: PlotResponsibilityChainBase = PlotTestingHistoryResponsibilityChain()
        self.__plotting_chain.add_next_chain_link(plot_training_history_chain)

        self.__generated_data: dict = {'train': {}, 'test': {}}
        self.__logs: io.StringIO = io.StringIO()
        self.__page_width = page_width
        self.__page_height = page_height
        self.__heading_spacing = heading_spacing
        self.__caption_font_size = caption_font_size
        self.__text_font_size = text_font_size
        self.__font_name = font_name
        # Copy the containers so the shared class-level defaults (and the
        # caller's objects) are never aliased by this instance.
        self.__margins = dict(margins)
        self.__exclude_from_logs = list(exclude_from_logs)

    def run_training(self, callbacks: Optional[list[Callback]] = None, weights_load_path: Optional[str] = None,
                     weights_save_path: Optional[str] = None) -> None:
        """
        Executes the training and testing process for the trading agent.

        This method orchestrates the complete training workflow, capturing logs,
        training the agent, and testing its performance. It populates internal
        data structures with results that can later be used for report generation.
        Exceptions raised by the agent propagate to the caller; the temporary
        log handler is detached from the root logger in every case.

        Parameters:
            callbacks (list[Callback], optional): Keras callbacks to use during training.
                Defaults to an empty list.
            weights_load_path (str, optional): Path to load pre-trained weights from.
            weights_save_path (str, optional): Path to save trained weights to.
        """

        # A fresh list per call avoids sharing one mutable default between calls.
        if callbacks is None:
            callbacks = []

        # Mirror everything logged during the run into the in-memory buffer
        # that is later rendered into the report.
        # NOTE(review): records reach this handler only if the root logger's
        # level permits INFO - confirm the level is configured by the caller.
        log_streamer = logging.StreamHandler(self.__logs)
        log_streamer.setFormatter(logging.Formatter('%(message)s'))
        log_streamer.setLevel(logging.INFO)
        root_logger = logging.getLogger()
        root_logger.addHandler(log_streamer)

        try:
            logging.info("Training started!")
            logging.info(self.__config_summary.replace('\t', ' '))
            logging.info("Printing models architecture...")
            self.__agent.print_model_summary(print_function=logging.info)

            train_keys, train_data = self.__agent.train_agent(self.__nr_of_steps, self.__nr_of_episodes,
                                                              callbacks, weights_load_path, weights_save_path)
            self.__generated_data['train'].update(zip(train_keys, train_data))

            test_keys, test_data = self.__agent.test_agent(self.__repeat_test)
            # Both mappings are walked in lockstep; the iteration label is taken
            # from the keys mapping.
            for (iteration, key_list), (_, data_list) in zip(test_keys.items(), test_data.items()):
                self.__generated_data['test'][iteration] = dict(zip(key_list, data_list))

            logging.info("Training finished!")
        finally:
            # Always detach the handler, otherwise a failed run would leave it
            # attached to the root logger and keep capturing unrelated logs.
            root_logger.removeHandler(log_streamer)
            log_streamer.close()

    def __handle_plot_generation(self, data: dict) -> Optional[ImageReader]:
        """
        Generates a plot based on provided data using the responsibility chain.

        Parameters:
            data (dict): Dictionary containing 'key' identifying the plot type
                and 'plot_data' containing the actual data to be plotted.

        Returns:
            Optional[ImageReader]: ReportLab ImageReader object if plot was generated,
                None if no handler could process the request.
        """

        axes = self.__plotting_chain.plot(data)
        if axes is None:
            logging.warning(f'Did not managed to generate plot for {data["key"]}!')
            return None

        buffer = io.BytesIO()
        try:
            axes.figure.savefig(buffer, format='png')
        finally:
            # Release the figure even when rendering fails, so figures do not
            # accumulate in pyplot's global state.
            plt.close(axes.figure)

        buffer.seek(0)
        return ImageReader(buffer)

    def __calculate_max_dimensions(self, pdf: canvas.Canvas) -> tuple[int, int]:
        """
        Calculates maximum text dimensions that can fit on a PDF page.

        The character width is measured on a single space, so the result is
        exact only for monospaced fonts. Both values are clamped to at least 1
        so they are always usable as chunk sizes.

        Parameters:
            pdf (canvas.Canvas): The PDF canvas to calculate dimensions for.
                Its current font is set to the report text font as a side effect.

        Returns:
            tuple[int, int]: Maximum number of characters per line and number
                of lines per page that can fit within margins.
        """

        pdf.setFont(self.__font_name, self.__text_font_size)
        text_width = pdf.stringWidth(' ', self.__font_name, self.__text_font_size)
        available_width = self.__page_width - self.__margins['left'] - self.__margins['right']
        max_width = max(1, int(available_width / text_width))

        # Line height is 1.2 times the font size.
        text_height = self.__text_font_size * 1.2
        available_height = self.__page_height - self.__margins['top'] - self.__margins['bottom'] \
            - self.__heading_spacing
        max_height = max(1, int(available_height / text_height))

        return max_width, max_height

    def __handle_logs_preprocessing(self, raw_logs_bufffer: list[str], max_log_length: int,
                                    max_lines_per_page: int) -> list[list[str]]:
        """
        Processes raw logs to fit them within PDF page constraints.

        This method filters out excluded terms, handles line wrapping for long lines,
        and chunks the logs into page-sized portions.

        Parameters:
            raw_logs_bufffer (list[str]): Raw log lines to process.
            max_log_length (int): Maximum characters per line.
            max_lines_per_page (int): Maximum lines per page.

        Returns:
            list[list[str]]: List of pages, where each page is a list of log lines.
        """

        preprocessed_logs = []
        for log in raw_logs_bufffer:
            # Drop every line that mentions one of the excluded terms.
            if any(exclude_term in log for exclude_term in self.__exclude_from_logs):
                continue

            if len(log) > max_log_length:
                # Hard-wrap long lines into fixed-width slices.
                preprocessed_logs.extend(log[i:i + max_log_length]
                                         for i in range(0, len(log), max_log_length))
            else:
                preprocessed_logs.append(log)

        return [preprocessed_logs[i:i + max_lines_per_page]
                for i in range(0, len(preprocessed_logs), max_lines_per_page)]

    def __draw_caption(self, pdf: canvas.Canvas, text: str) -> None:
        """
        Draws a section caption with separating line on the PDF.

        Parameters:
            pdf (canvas.Canvas): PDF canvas to draw on.
            text (str): Caption text to draw.
        """

        caption_y = self.__page_height - self.__margins['top']
        # The separator sits halfway between the caption and the page content.
        separator_y = caption_y - self.__heading_spacing / 2

        pdf.setFont(self.__font_name, self.__caption_font_size)
        pdf.drawString(self.__margins['left'], caption_y, text)
        pdf.setLineWidth(2)
        pdf.setStrokeColorRGB(0, 0, 0)
        pdf.line(self.__margins['left'], separator_y,
                 self.__page_width - self.__margins['right'], separator_y)

    def __draw_text_body(self, pdf: canvas.Canvas, text_body: list[str]) -> None:
        """
        Draws a block of text lines on the PDF.

        Parameters:
            pdf (canvas.Canvas): PDF canvas to draw on.
            text_body (list[str]): List of text lines to draw.
        """

        text_block = pdf.beginText(self.__margins['left'],
                                   self.__page_height - self.__margins['top'] - self.__heading_spacing)
        text_block.setFont(self.__font_name, self.__text_font_size)
        for line in text_body:
            text_block.textLine(line)
        pdf.drawText(text_block)

    def generate_report(self, path_to_pdf: str) -> None:
        """
        Generates a comprehensive PDF report of training and testing results.

        Creates a multi-page report with logs, training history plots, and test
        results visualizations based on the data collected during training.
        Pages use the page width and height passed to the constructor.

        Parameters:
            path_to_pdf (str): File path where the PDF report should be saved.
        """

        logging.info("Generating report...")
        # Use the configured page size so the canvas agrees with the layout
        # calculations made from the same values.
        pdf = canvas.Canvas(path_to_pdf, pagesize=(self.__page_width, self.__page_height))
        pdf.setTitle("Report")

        # Report logs
        raw_logs = self.__logs.getvalue().split('\n')
        max_log_length, max_lines_per_page = self.__calculate_max_dimensions(pdf)
        preprocessed_logs = self.__handle_logs_preprocessing(raw_logs, max_log_length, max_lines_per_page)
        for preprocessed_logs_chunk in preprocessed_logs:
            self.__draw_caption(pdf, "Log output")
            self.__draw_text_body(pdf, preprocessed_logs_chunk)
            pdf.showPage()

        # Draw training plot
        for key, data in self.__generated_data['train'].items():
            plot_buffer = self.__handle_plot_generation({
                'key': key,
                'plot_data': data
            })
            if plot_buffer is not None:
                self.__draw_caption(pdf, "Training performance")
                pdf.drawImage(plot_buffer, inch, self.__page_height - 7.5 * inch, width=6 * inch,
                              preserveAspectRatio=True)
                pdf.showPage()

        # Draw testing plots
        for iteration, key_data_pair in self.__generated_data['test'].items():
            for key, data in key_data_pair.items():
                plot_buffer = self.__handle_plot_generation({
                    'key': key,
                    'plot_data': data
                })
                if plot_buffer is not None:
                    self.__draw_caption(pdf, f"Testing outcome, trial: {iteration}")
                    # Half an inch of horizontal and one inch of vertical
                    # padding around the image on the configured page.
                    pdf.drawImage(plot_buffer, 0.5 * inch, 1 * inch, width=self.__page_width - 1 * inch,
                                  height=self.__page_height - 2 * inch)
                    pdf.showPage()

        pdf.save()
        logging.info("Report generated!")