Coverage for source/utils/aws_handler.py: 78%
50 statements
coverage.py v7.8.0, created at 2025-08-19 11:16 +0000
1# utils/aws_handler.py
3# global imports
4import botocore.exceptions
5import boto3
6import functools
7import io
8import os
9from typing import Any
11# local imports
12from source.utils import SingletonMeta
class AWSHandler(metaclass = SingletonMeta):
    """
    Responsible for handling communication with Amazon AWS services.

    Long-lived credentials are read from the environment and used only to
    assume an IAM role via STS; S3 operations run on a client built from the
    resulting temporary credentials. Every public S3 operation is wrapped so
    that a failed call renews that client once and retries.
    (The SingletonMeta metaclass presumably restricts this class to a single
    instance - confirm in source.utils.)
    """

    # local constants
    __DEFAULT_REGION = "eu-central-1"

    def __renew_s3_client_session(self) -> Any:
        """
        Assumes the configured IAM role and builds a fresh S3 client from the
        returned temporary credentials. Kept as a separate method so that the
        session can be renewed whenever those credentials expire (1 hour by
        default for assume_role).

        Returns:
            A boto3 S3 client bound to the configured region.

        Raises:
            Errors raised by the STS assume_role call propagate unchanged.
        """
        sts_client = self.__credential_session.client('sts')
        response = sts_client.assume_role(RoleArn=self.__role_arn,
                                          RoleSessionName='S3_bucket_user_session')
        credentials = response['Credentials']

        return boto3.client('s3',
                            aws_access_key_id=credentials['AccessKeyId'],
                            aws_secret_access_key=credentials['SecretAccessKey'],
                            aws_session_token=credentials['SessionToken'],
                            region_name=self.__region_name)

    def __init__(self, region_name: str = __DEFAULT_REGION) -> None:
        """
        Class constructor. Before calling it AWS_ACCESS_KEY_ID,
        AWS_SECRET_ACCESS_KEY, ACCOUNT_ID and ROLE_NAME must be available as
        environment variables. The role arn:aws:iam::<ACCOUNT_ID>:role/<ROLE_NAME>
        is assumed immediately to create the first S3 client.

        Parameters:
            region_name (str): Region name to connect to.

        Raises:
            RuntimeError: If any required environment variable is missing or
                empty (the message lists which ones).
            Errors raised by the initial STS assume_role call propagate
            unchanged.
        """
        required_vars = ('AWS_ACCESS_KEY_ID', 'AWS_SECRET_ACCESS_KEY',
                         'ACCOUNT_ID', 'ROLE_NAME')
        env = {name: os.getenv(name) for name in required_vars}
        missing = [name for name, value in env.items() if not value]
        if missing:
            raise RuntimeError("AWS credentials or account ID not found in environment variables! "
                               f"Missing: {', '.join(missing)}")

        self.__credential_session = boto3.Session(
            aws_access_key_id=env['AWS_ACCESS_KEY_ID'],
            aws_secret_access_key=env['AWS_SECRET_ACCESS_KEY'])
        self.__role_arn = f"arn:aws:iam::{env['ACCOUNT_ID']}:role/{env['ROLE_NAME']}"
        self.__region_name = region_name
        self.__s3_session = self.__renew_s3_client_session()

    def with_session_renewal(method):
        """
        Decorator for S3 operations that renews the S3 client session and
        retries the call once when the first attempt fails with a botocore
        ClientError, or with the boto3 S3UploadFailedError that upload_file
        raises in place of a ClientError. Any such error triggers the renewal,
        not only expired-token errors.

        Raises:
            RuntimeError: If the first attempt fails with any other exception,
                or if the renewal or the retry fails. The underlying exception
                is chained as __cause__.
        """
        @functools.wraps(method)
        def wrapper(self, *args, **kwargs):
            try:
                return method(self, *args, **kwargs)
            except (botocore.exceptions.ClientError,
                    boto3.exceptions.S3UploadFailedError):
                # Possibly expired temporary credentials: renew and retry below.
                pass
            except Exception as e:
                raise RuntimeError(f"Did not manage to perform S3 operation! Original error: {e}") from e

            try:
                # Renewal sits inside the try so that STS failures also surface
                # as RuntimeError rather than a raw botocore exception.
                self.__s3_session = self.__renew_s3_client_session()
                return method(self, *args, **kwargs)
            except Exception as e2:
                raise RuntimeError(f"Operation failed after session renewal! Original error: {e2}") from e2
        return wrapper

    @with_session_renewal
    def upload_file_to_s3(self, bucket_name: str, file_path: str, desired_name: str = "") -> None:
        """
        Attempts to upload local file specified by path to S3 Amazon bucket.

        Parameters:
            bucket_name (str): String denoting bucket name.
            file_path (str): Path of the local file that should be uploaded.
            desired_name (str): Desired object key for the uploaded file.
                If left empty, the base name of file_path is used.

        Raises:
            RuntimeError: If a problem occurs while uploading the file.
        """
        if not desired_name:
            # os.path.basename honours the local OS path separator(s).
            desired_name = os.path.basename(file_path)

        self.__s3_session.upload_file(file_path, bucket_name, desired_name)

    @with_session_renewal
    def upload_buffer_to_s3(self, bucket_name: str, buffer: io.StringIO, desired_name: str) -> None:
        """
        Attempts to upload buffer contents as a file body directly to an S3
        Amazon bucket.

        Parameters:
            bucket_name (str): String denoting bucket name.
            buffer (io.StringIO): Buffer whose getvalue() result is written as
                the object body. Only getvalue() is used, so an io.BytesIO
                works as well.
            desired_name (str): Object key to be given to the uploaded file.

        Raises:
            RuntimeError: If a problem occurs while uploading the buffer.
        """
        self.__s3_session.put_object(Bucket=bucket_name, Key=desired_name, Body=buffer.getvalue())

    @with_session_renewal
    def download_file_from_s3(self, bucket_name: str, file_name: str, desired_path: str = "") -> None:
        """
        Downloads a file from an S3 bucket to a local path.

        Parameters:
            bucket_name (str): The name of the S3 bucket.
            file_name (str): The key/path of the file in the S3 bucket.
            desired_path (str, optional): The local path where the file will be
                saved. If empty, the file is saved in the current working
                directory under the last '/'-separated segment of the key.

        Raises:
            RuntimeError: If the download operation fails, or if desired_path
                is empty and no file name can be derived from the key
                (e.g. a key ending in '/').
        """
        if not desired_path:
            # S3 keys always use '/' as separator, regardless of the local OS.
            local_name = file_name.rsplit('/', 1)[-1]
            if not local_name:
                raise ValueError(f"Cannot derive a local file name from S3 key '{file_name}'; "
                                 "specify desired_path explicitly.")
            desired_path = os.path.join(os.getcwd(), local_name)

        self.__s3_session.download_file(bucket_name, file_name, desired_path)