-
Notifications
You must be signed in to change notification settings - Fork 0
/
db_connection.py
122 lines (101 loc) · 4.09 KB
/
db_connection.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
import logging
import re
import mysql.connector
import yaml
import traceback
import time
from pymysql.converters import escape_string
class DBConnector(object):
    """Open MySQL connections from the settings stored in ``./db_connection.yml``.

    Instances can also be used as a context manager: entering opens a
    connection and returns it, leaving closes that connection.
    """

    def __init__(self):
        # No connection exists until __enter__ opens one.
        self.dbconn = None
        self.db_config = self.read_db_config()

    def read_db_config(self):
        """Parse ``./db_connection.yml`` and return its content."""
        with open('./db_connection.yml', 'r') as config_file:
            return yaml.safe_load(config_file)

    def create_connection(self):
        """Return a new ``mysql.connector`` connection built from the loaded settings."""
        settings = self.db_config
        connect_kwargs = {
            'host': settings['database_url'],
            'user': settings['database_user'],
            'password': settings['database_password'],
            'database': settings['database_name'],
            'port': settings['database_port'],
        }
        return mysql.connector.connect(**connect_kwargs)

    # Context-manager protocol, for explicitly opening a database connection.
    def __enter__(self):
        """Open a connection, remember it on the instance and hand it to the caller."""
        connection = self.create_connection()
        self.dbconn = connection
        return connection

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Close the connection opened by ``__enter__``; exceptions are not suppressed."""
        self.dbconn.close()
class DBConnection(object):
    """Class-level access point for one shared MySQL connection.

    The connection is created lazily by :meth:`get_connection` and reused by
    :meth:`execute_query`, which retries statements that fail with a MySQL
    deadlock error.
    """

    # Shared connection object; None until get_connection() first creates it.
    connection = None

    # Debug switch for log_sql(); enable if there are issues.
    # Works with ./vm/sql_log_scan.sh
    SQL_LOG_ENABLED = False

    # Statements expected to return a result set. Matching only a leading
    # SELECT is not enough because CTEs start with WITH.
    _RETURNS_RESULT_PATTERN = re.compile(r"^\s*(WITH|SELECT)", re.IGNORECASE)

    # Error number checked on InternalError to recognise a deadlock.
    _DEADLOCK_ERRNO = 1213

    @classmethod
    def get_connection(cls, new=False):
        """Return the shared database connection, creating it when needed.

        :param new: When true, replace the shared connection with a fresh one.
            The previous connection, if any, is not closed here.
        :return: The shared connection object
        """
        if new or not cls.connection:
            cls.connection = DBConnector().create_connection()
        return cls.connection

    @classmethod
    def log_sql(cls, query, stack_trace):
        """Append a query and the stack that issued it to ``./sql.log``.

        Does nothing unless ``SQL_LOG_ENABLED`` is true.

        :param query: SQL text to record
        :param stack_trace: Iterable of strings (for example the result of
            ``traceback.format_stack()``); the pieces are concatenated as-is
        """
        if not cls.SQL_LOG_ENABLED:
            return
        formatted_stack_trace = ''.join(stack_trace)
        log_string = (
            f"SQL Query: {query}\n"
            "Stack Trace:\n"
            f"{formatted_stack_trace}\n"
            f"{'-' * 40}\n"
        )
        with open('./sql.log', 'a') as log_file:
            log_file.write(log_string)

    @staticmethod
    def _format_query_for_log(query, args):
        """Return ``query`` with each ``%s`` replaced by a quoted, escaped argument.

        The result is only meant for the debug log; it is never sent to the server.
        """
        if args is None:
            return query
        formatted_query = query
        for arg in args:
            formatted_query = formatted_query.replace('%s', f"'{escape_string(str(arg))}'", 1)
        return formatted_query

    @classmethod
    def execute_query(cls, query, args=None):
        """
        Execute a SQL query with retry mechanism on deadlock.

        :param query: The SQL query
        :param args: Arguments for the query
        :return: Query result for SELECT (and WITH ... CTE) queries, or None for other types
        :raises Exception: When every attempt ended in a deadlock; the last
            deadlock error is attached as the cause. Any other error is
            re-raised unchanged.
        """
        max_retries = 3
        retry_delay = 60  # seconds
        last_deadlock = None

        for attempt in range(max_retries):
            connection = cls.get_connection()
            # Stays None if cursor creation itself fails, so cleanup below
            # never touches an unbound name and never masks the real error.
            cursor = None
            try:
                cursor = connection.cursor()
                if args is None:
                    cursor.execute(query)
                else:
                    cursor.execute(query, args)

                # Only build the log text and stack trace when logging is on;
                # both are wasted work otherwise.
                if cls.SQL_LOG_ENABLED:
                    cls.log_sql(cls._format_query_for_log(query, args), traceback.format_stack())

                if cls._RETURNS_RESULT_PATTERN.match(query):
                    return cursor.fetchall()
                connection.commit()
                return None
            except mysql.connector.errors.InternalError as e:
                if e.errno != cls._DEADLOCK_ERRNO:
                    raise
                last_deadlock = e
            except Exception as e:
                logging.critical("Bad SQL: %s:\n%s", e, query)
                print(traceback.format_exc())
                raise
            finally:
                # Release the cursor on every path: success, retry and failure.
                if cursor is not None:
                    cursor.close()

            # Only reached after a deadlock. Wait before the next attempt, but
            # do not stall for a full delay when no attempt is left.
            if attempt + 1 < max_retries:
                logging.warning(
                    "Deadlock detected, attempt %d of %d. Retrying in %d seconds.",
                    attempt + 1, max_retries, retry_delay)
                time.sleep(retry_delay)
            else:
                logging.warning(
                    "Deadlock detected, attempt %d of %d. No retries left.",
                    attempt + 1, max_retries)

        # All attempts deadlocked; keep the last deadlock error as the cause.
        raise Exception(
            f"Failed to execute query after {max_retries} attempts due to deadlock."
        ) from last_deadlock