-
Notifications
You must be signed in to change notification settings - Fork 4
/
db_connection.py
117 lines (97 loc) · 3.83 KB
/
db_connection.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
import logging
import mysql.connector
import yaml
import traceback
import time
from pymysql.converters import escape_string
class DBConnector(object):
    """Open MySQL connections using credentials read from a YAML file.

    The YAML file must provide the keys ``database_url``, ``database_user``,
    ``database_password``, ``database_name`` and ``database_port``; they are
    read by :meth:`create_connection`.

    Also usable as a context manager: ``with DBConnector() as conn: ...``
    opens a new connection on entry and closes it on exit.
    """

    # Default location of the credentials file, relative to the working directory.
    CONFIG_PATH = './vm/vm_passwords.yml'

    def __init__(self, config_path=None):
        """Load the connection settings.

        :param config_path: Path to the YAML credentials file; defaults to
            ``CONFIG_PATH``.
        """
        # Connection opened by __enter__; None while no connection is held.
        self.dbconn = None
        self.db_config = self.read_db_config(config_path)

    def read_db_config(self, config_path=None):
        """Parse and return the YAML credentials file.

        :param config_path: File to read; defaults to ``CONFIG_PATH``.
        :return: The value produced by ``yaml.safe_load`` for the file
            (``None`` if the file is empty).
        """
        path = self.CONFIG_PATH if config_path is None else config_path
        # Explicit encoding so decoding does not depend on the platform locale.
        with open(path, 'r', encoding='utf-8') as file:
            return yaml.safe_load(file)

    def create_connection(self):
        """Open and return a new MySQL connection using ``self.db_config``."""
        return mysql.connector.connect(
            host=self.db_config['database_url'],
            user=self.db_config['database_user'],
            password=self.db_config['database_password'],
            database=self.db_config['database_name'],
            port=self.db_config['database_port'],
        )

    def __enter__(self):
        """Open a new connection, store it on ``self.dbconn`` and return it."""
        self.dbconn = self.create_connection()
        return self.dbconn

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Close the connection opened by ``__enter__``, if any.

        No commit is issued here. Exceptions raised in the ``with`` body are
        not suppressed (this method returns None).
        """
        # Clear the attribute first so the object never holds a closed handle.
        conn, self.dbconn = self.dbconn, None
        if conn is not None:
            conn.close()
class DBConnection(object):
    """Shared MySQL connection plus a query helper with deadlock retries.

    ``get_connection`` and ``execute_query`` share the single connection
    stored in the class attribute ``connection``, created lazily.
    """

    # Shared connection; created on first use by get_connection().
    connection = None

    # Debugging aid: when True, every query executed through execute_query()
    # is appended, with its call stack, to ./sql.log (meant for use with
    # ./vm/sql_log_scan.sh). Enable only while investigating issues.
    SQL_LOG_ENABLED = False

    # Deadlock retry policy used by execute_query().
    MAX_DEADLOCK_RETRIES = 3
    DEADLOCK_RETRY_DELAY = 60  # seconds

    # MySQL error code for a deadlock (ER_LOCK_DEADLOCK).
    _DEADLOCK_ERRNO = 1213

    @classmethod
    def get_connection(cls, new=False):
        """Return the shared connection, creating it if needed.

        :param new: If true, always open a fresh connection and make it the
            shared one. The previous connection is not closed.
        :return: The shared MySQL connection.
        """
        if new or not cls.connection:
            cls.connection = DBConnector().create_connection()
        return cls.connection

    @classmethod
    def log_sql(cls, query, stack_trace):
        """Append a query and the stack that issued it to ./sql.log.

        Does nothing unless ``SQL_LOG_ENABLED`` is true.

        :param query: Query text to record.
        :param stack_trace: Sequence of strings, e.g. ``traceback.format_stack()``.
        """
        if not cls.SQL_LOG_ENABLED:
            return
        formatted_stack_trace = ''.join(stack_trace)
        log_string = (
            f"SQL Query: {query}\n"
            "Stack Trace:\n"
            f"{formatted_stack_trace}\n"
            f"{'-' * 40}\n"
        )
        with open('./sql.log', 'a', encoding='utf-8') as log_file:
            log_file.write(log_string)

    @staticmethod
    def _format_query_for_log(query, args):
        """Return *query* with positional ``%s`` args inlined (for logging only)."""
        if args is None:
            return query
        for arg in args:
            query = query.replace('%s', f"'{escape_string(str(arg))}'", 1)
        return query

    @staticmethod
    def _close_cursor_quietly(cursor):
        """Close *cursor* if one was created, ignoring close errors so the
        exception that triggered the cleanup is the one that propagates."""
        if cursor is None:
            return
        try:
            cursor.close()
        except Exception:
            logging.debug("Ignoring error while closing cursor", exc_info=True)

    @classmethod
    def execute_query(cls, query, args=None):
        """
        Execute a SQL query on the shared connection, retrying on deadlock.

        Statements that produce a result set (``cursor.with_rows``) have their
        rows fetched and returned. All other statements are committed.

        NOTE(review): result-producing statements are never committed. If
        autocommit is off on this connection, the read transaction stays
        open until the next write; confirm this is intended.

        :param query: The SQL query, with driver-style parameter placeholders.
        :param args: Arguments for the query, or None.
        :return: ``cursor.fetchall()`` for result-producing statements,
            otherwise None.
        :raises mysql.connector.errors.InternalError: for non-deadlock
            internal errors.
        :raises Exception: any other error (re-raised after logging), or a
            generic ``Exception`` chained to the last deadlock error once all
            retries are exhausted.
        """
        max_retries = cls.MAX_DEADLOCK_RETRIES
        retry_delay = cls.DEADLOCK_RETRY_DELAY
        last_error = None
        for attempt in range(max_retries):
            connection = cls.get_connection()
            # Reset every attempt so cleanup never touches a stale cursor
            # or an unbound name when connection.cursor() itself fails.
            cursor = None
            try:
                cursor = connection.cursor()
                if args is None:
                    cursor.execute(query)
                else:
                    cursor.execute(query, args)
                # Building the log entry (and the stack) is costly; only do it when enabled.
                if cls.SQL_LOG_ENABLED:
                    cls.log_sql(cls._format_query_for_log(query, args),
                                traceback.format_stack())
                # Decide by what the statement produced rather than by its text,
                # so CTEs, SHOW, etc. are fetched instead of left unread.
                if cursor.with_rows:
                    result = cursor.fetchall()
                else:
                    connection.commit()
                    result = None
                cursor.close()
                return result
            except mysql.connector.errors.InternalError as e:
                cls._close_cursor_quietly(cursor)
                if e.errno != cls._DEADLOCK_ERRNO:
                    raise
                last_error = e
                if attempt + 1 >= max_retries:
                    # No point sleeping when no retry follows.
                    logging.warning(
                        f"Deadlock detected, attempt {attempt + 1} of {max_retries}. Giving up.")
                    break
                logging.warning(
                    f"Deadlock detected, attempt {attempt + 1} of {max_retries}. Retrying in {retry_delay} seconds.")
                time.sleep(retry_delay)
            except Exception as e:
                logging.critical(f"Bad SQL: {e}:\n{query}")
                print(traceback.format_exc())
                cls._close_cursor_quietly(cursor)
                raise
        # All retries failed on deadlock; keep the last driver error as the cause.
        raise Exception(
            f"Failed to execute query after {max_retries} attempts due to deadlock."
        ) from last_error