Coverage for source/utils/aws_handler.py: 78%

50 statements  

« prev     ^ index     » next       coverage.py v7.8.0, created at 2025-08-21 22:14 +0000

1# utils/aws_handler.py 

2 

3# global imports 

4import botocore.exceptions 

5import boto3 

6import functools 

7import io 

8import os 

9from typing import Any 

10 

11# local imports 

12from source.utils import SingletonMeta 

13 

class AWSHandler(metaclass=SingletonMeta):
    """
    Responsible for handling communication with Amazon AWS services.

    Long-lived credentials read from the environment are used only to call
    STS AssumeRole. Every S3 call goes through a client built from the
    temporary credentials that AssumeRole returns. The public S3 methods are
    wrapped with `with_session_renewal`, which rebuilds that client and
    retries the call once when it fails with an AWS client error.
    """

    # local constants
    __DEFAULT_REGION = "eu-central-1"
    # Environment variables that must be set (and non-empty) before construction.
    __REQUIRED_ENV_VARS = ("AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY", "ACCOUNT_ID", "ROLE_NAME")

    def __renew_s3_client_session(self) -> Any:
        """
        Assumes the configured IAM role and builds a fresh S3 client from the
        temporary credentials it returns. Functionality is put into a separate
        method so the client can be rebuilt whenever the temporary credentials
        stop working (they expire after 1 hour by default).

        Returns:
            Any: A boto3 S3 client bound to the handler's region.
        """

        credentials = self.__credential_session.client('sts').assume_role(
            RoleArn=self.__role_arn,
            RoleSessionName='S3_bucket_user_session',
        )['Credentials']

        return boto3.client(
            's3',
            aws_access_key_id=credentials['AccessKeyId'],
            aws_secret_access_key=credentials['SecretAccessKey'],
            aws_session_token=credentials['SessionToken'],
            region_name=self.__region_name,
        )

    def __init__(self, region_name: str = __DEFAULT_REGION) -> None:
        """
        Class constructor. Before calling it AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY,
        ACCOUNT_ID and ROLE_NAME must be available as (non-empty) environment
        variables. The role arn:aws:iam::<ACCOUNT_ID>:role/<ROLE_NAME> is assumed
        immediately to create the first S3 client.

        Parameters:
            region_name (str): Region name the S3 client connects to.

        Raises:
            RuntimeError: If any of the required environment variables is missing
                or empty. The message lists the missing names.
        """

        env = {name: os.getenv(name) for name in self.__REQUIRED_ENV_VARS}
        missing = [name for name, value in env.items() if not value]
        if missing:
            # Keep the historical message prefix, then say exactly what is missing.
            raise RuntimeError(
                "AWS credentials or account ID not found in environment variables! "
                f"Missing: {', '.join(missing)}"
            )

        self.__credential_session = boto3.Session(
            aws_access_key_id=env['AWS_ACCESS_KEY_ID'],
            aws_secret_access_key=env['AWS_SECRET_ACCESS_KEY'],
        )
        self.__role_arn = f"arn:aws:iam::{env['ACCOUNT_ID']}:role/{env['ROLE_NAME']}"
        self.__region_name = region_name
        self.__s3_session = self.__renew_s3_client_session()

    def with_session_renewal(method):
        """
        Decorator that rebuilds the S3 client and retries the wrapped call once
        when it fails with an AWS client error (e.g. expired temporary credentials).

        Any botocore ``ClientError`` triggers renewal, not only ``ExpiredToken``.
        S3 answers HEAD requests (issued by ``download_file``) without a body,
        so the specific error code is not always available. Managed uploads
        (``upload_file``) wrap the underlying ``ClientError`` in
        ``S3UploadFailedError``, so that error is treated the same way.

        Every failure is re-raised as ``RuntimeError`` chained to the original
        exception, including a failure to renew the session itself.
        """
        # Local import: boto3.exceptions is part of boto3, already a dependency.
        from boto3.exceptions import S3UploadFailedError

        renewable_errors = (botocore.exceptions.ClientError, S3UploadFailedError)

        @functools.wraps(method)
        def wrapper(self, *args, **kwargs):
            try:
                return method(self, *args, **kwargs)
            except renewable_errors:
                # Renewal runs in its own try. An exception raised inside this
                # handler would not be caught by the sibling `except Exception`
                # below, so it would escape unwrapped.
                try:
                    self.__s3_session = self.__renew_s3_client_session()
                except Exception as renew_error:
                    raise RuntimeError(
                        f"Did not manage to renew S3 session! Original error: {renew_error}"
                    ) from renew_error
                try:
                    return method(self, *args, **kwargs)
                except Exception as retry_error:
                    raise RuntimeError(
                        f"Operation failed after session renewal! Original error: {retry_error}"
                    ) from retry_error
            except Exception as error:
                raise RuntimeError(
                    f"Did not manage to perform S3 operation! Original error: {error}"
                ) from error
        return wrapper

    @with_session_renewal
    def upload_file_to_s3(self, bucket_name: str, file_path: str, desired_name: str = "") -> None:
        """
        Attempts to upload a local file specified by path to an Amazon S3 bucket.

        Parameters:
            bucket_name (str): String denoting bucket name.
            file_path (str): Path of the local file that should be uploaded.
            desired_name (str): Object key to give the file after upload.
                If left empty (or None), the local file's base name is used.

        Raises:
            RuntimeError: If the upload fails, including after one
                session-renewal retry for AWS client errors.
        """

        if not desired_name:
            # os.path.basename also handles platform-specific separators.
            desired_name = os.path.basename(file_path)

        self.__s3_session.upload_file(file_path, bucket_name, desired_name)

    @with_session_renewal
    def upload_buffer_to_s3(self, bucket_name: str, buffer: io.StringIO, desired_name: str) -> None:
        """
        Attempts to upload a buffer as the object body directly to an Amazon S3 bucket.

        Parameters:
            bucket_name (str): String denoting bucket name.
            buffer (io.StringIO): Buffer whose entire contents (``getvalue()``)
                are written as the object body.
            desired_name (str): Object key to give the uploaded data.

        Raises:
            RuntimeError: If the upload fails, including after one
                session-renewal retry for AWS client errors.
        """

        self.__s3_session.put_object(Bucket=bucket_name, Key=desired_name, Body=buffer.getvalue())

    @with_session_renewal
    def download_file_from_s3(self, bucket_name: str, file_name: str, desired_path: str = "") -> None:
        """
        Downloads a file from an S3 bucket to a local path.

        Parameters:
            bucket_name (str): The name of the S3 bucket.
            file_name (str): The key/path of the file in the S3 bucket.
            desired_path (str, optional): The local path where the file will be saved.
                If empty (or None), the file is downloaded into the current working
                directory under the last '/'-separated segment of the key.

        Raises:
            RuntimeError: If the download fails, including after one
                session-renewal retry for AWS client errors.
        """

        if not desired_path:
            # S3 keys always use '/' as separator regardless of the local OS.
            desired_path = os.path.join(os.getcwd(), file_name.rsplit('/', 1)[-1])

        self.__s3_session.download_file(bucket_name, file_name, desired_path)