Coverage for source/utils/gradient_handler.py: 81%
32 statements
« prev ^ index » next coverage.py v7.8.0, created at 2025-07-30 20:59 +0000
« prev ^ index » next coverage.py v7.8.0, created at 2025-07-30 20:59 +0000
1# utils/gradient_handler.py
3# global imports
4import os
5from datetime import datetime
6from gradient import NotebooksClient
8# local imports
9from source.utils import SingletonMeta
class GradientHandler(metaclass=SingletonMeta):
    """
    Responsible for communication and management of Paperspace Gradient services.

    Wraps a ``gradient.NotebooksClient`` bound to the API key and project ID read from
    the environment. The class uses ``SingletonMeta`` as its metaclass (presumably so
    that a single shared handler exists per process — see ``source.utils``).
    """

    # local constants
    # Appended to the caller's command. The leading " & " presumably backgrounds the
    # caller's command in the notebook's shell before Jupyter Lab is launched — confirm.
    # Built from adjacent literals (instead of a backslash-continued string) so no source
    # indentation leaks into the command line.
    __DEFAULT_START_COMMAND = (
        " & PIP_DISABLE_PIP_VERSION_CHECK=1 jupyter lab --allow-root --ip=0.0.0.0 --no-browser"
        " --ServerApp.trust_xheaders=True --ServerApp.disable_check_xsrf=False"
        " --ServerApp.allow_remote_access=True"
        " --ServerApp.allow_origin='*' --ServerApp.allow_credentials=True"
    )
    # Container image every notebook is created from.
    __DEFAULT_CONTAINER_TYPE = 'paperspace/gradient-base:pt112-tf29-jax0317-py39-20230125'

    def __init__(self) -> None:
        """
        Class constructor. GRADIENT_API_KEY and GRADIENT_PROJECT_ID must be set as
        environment variables before it is called.

        Raises:
            RuntimeError: If the Paperspace Gradient API key or project ID is not defined
                (missing or empty).
        """

        api_key = os.getenv('GRADIENT_API_KEY')
        project_id = os.getenv('GRADIENT_PROJECT_ID')

        if not api_key or not project_id:
            raise RuntimeError('Paperspace Gradient api key or project ID not found in environment variables!')

        self.notebooks = NotebooksClient(api_key)
        self.project_id = project_id

    def create_notebook(self, command_to_invoke: str, github_repository_url: str = None,
                        notebook_name: str = None, machine_types: list = None,
                        timeout: int = 6, environment_dict: dict = None) -> str:
        """
        Attempts to create a notebook from a given GitHub repository with a starting command,
        and sets the needed environment parameters (e.g. variables, machine type, etc.).

        Parameters:
            command_to_invoke (str): Command to be invoked after the notebook is started. The
                Jupyter Lab start command is appended to it.
            github_repository_url (str): URL of the repository that should be downloaded to
                the notebook (passed as the workspace). Defaults to None.
            notebook_name (str): Name that should be given to the notebook. Defaults to the
                current local date and time, computed at call time.
            machine_types (list): Machine types to try, in order. Defaults to
                ``['Free-P5000']``. The first successful creation stops further attempts, so
                machine types should be ordered from the most wanted to the least.
            timeout (int): Number of hours the notebook should stay active. For free
                instances the maximal value is 6.
            environment_dict (dict): Environment variables to define in the notebook.
                Defaults to an empty dict.

        Returns:
            str: The value returned by ``NotebooksClient.create`` (the notebook ID).

        Raises:
            RuntimeError: If no machine type yielded a notebook. The message (and
                ``__cause__``) carries the last error encountered, if any.
        """

        # Defaults are resolved per call: a ``datetime.now()`` default would be frozen at
        # import time, and list/dict defaults would be shared between calls.
        if notebook_name is None:
            notebook_name = str(datetime.now())
        if machine_types is None:
            machine_types = ['Free-P5000']
        if environment_dict is None:
            environment_dict = {}

        last_error = None
        for machine_type in machine_types:
            try:
                notebook_id = self.notebooks.create(machine_type=machine_type,
                                                    container=self.__DEFAULT_CONTAINER_TYPE,
                                                    project_id=self.project_id,
                                                    shutdown_timeout=timeout,
                                                    workspace=github_repository_url,
                                                    command=command_to_invoke + self.__DEFAULT_START_COMMAND,
                                                    environment=environment_dict,
                                                    name=notebook_name)
            except Exception as error:
                # Remember the failure and fall through to the next machine type.
                last_error = error
                continue

            if notebook_id is not None:
                return notebook_id

        raise RuntimeError(f"Did not managed to create notebook! Original error: {last_error}") from last_error

    def delete_notebook(self, notebook_id: str) -> None:
        """
        Deletes the notebook with the given ID from Paperspace Gradient.

        Parameters:
            notebook_id (str): ID of the notebook to be deleted.

        Raises:
            RuntimeError: If ``NotebooksClient.delete`` raised. The original exception is
                chained as ``__cause__``.
        """

        try:
            self.notebooks.delete(notebook_id)
        except Exception as error:
            raise RuntimeError(f"Did not managed to delete notebook! Original error: {error}") from error