Skip to content

Commit

Permalink
Fix max query parameters to 100
Browse files Browse the repository at this point in the history
  • Loading branch information
G4brym committed Oct 12, 2024
1 parent 4a71b0f commit c7ba34c
Show file tree
Hide file tree
Showing 3 changed files with 72 additions and 24 deletions.
1 change: 1 addition & 0 deletions django_d1/features.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ class DatabaseFeatures(SQLiteDatabaseFeatures):
can_defer_constraint_checks = False
supports_pragma_foreign_key_check = False
can_alter_table_rename_column = False
max_query_params = 100
can_clone_databases = False
can_rollback_ddl = False
supports_atomic_references_rename = False
Expand Down
93 changes: 70 additions & 23 deletions django_dbms/base.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import datetime
import hashlib
import json

import sqlparse
import websocket # Using websocket-client library for synchronous operations
from django.conf import settings
from django.db import IntegrityError, DatabaseError
from django.db.backends.sqlite3.base import DatabaseWrapper as SQLiteDatabaseWrapper
from django.db.backends.sqlite3.client import DatabaseClient as SQLiteDatabaseClient
Expand All @@ -14,11 +16,13 @@
from django.utils import timezone
from sqlparse.sql import IdentifierList, Identifier
from sqlparse.tokens import DML
from django.core.cache import cache


class DatabaseFeatures(SQLiteDatabaseFeatures):
supports_transactions = True
supports_savepoints = False
max_query_params = 100


class DatabaseOperations(SQLiteDatabaseOperations):
Expand Down Expand Up @@ -140,6 +144,13 @@ def _convert_results(self, results):
converted_results.append(tuple(converted_row))
return converted_results

def generate_cache_key(self, sql_query: str) -> str:
# Create a SHA256 hash of the combined string
hash_object = hashlib.sha256(sql_query.encode())

# Return the hexadecimal representation of the hash
return f"workers_dbms_{hash_object.hexdigest()}"

def raw_query(self, websocket, query, params=None):
if params == None:
if query.strip() == 'PRAGMA foreign_keys = OFF':
Expand Down Expand Up @@ -167,7 +178,7 @@ def raw_query(self, websocket, query, params=None):
if params:
sql, params = self._format_params(sql, params)

websocket.send(json.dumps({
socket_input = json.dumps({
"type": "request",
"request": {
"type": "execute",
Expand All @@ -176,28 +187,55 @@ def raw_query(self, websocket, query, params=None):
"query": sql
}
}
}))
})

if self.connection.debug is True:
print(sql)
print(params)
should_cache = False
response = None
if self.connection.cache is True:
upper_query = query.upper()
user_model = settings.AUTH_USER_MODEL.replace('.', '_').upper()
try:
if 'FROM "DJANGO_SESSION"' in upper_query or f'FROM "{user_model}"' in upper_query:
if "UPDATE" not in upper_query and "DELETE" not in upper_query and "INSERT" not in upper_query:
cache_key = self.generate_cache_key(socket_input)
response = cache.get(cache_key)

response = websocket.recv()
parsed_response = json.loads(response)
if not response:
should_cache = True
else:
pass # TODO: clear cache
# cache.clear(prefix="workers_dbms_")
except TypeError:
pass

if self.connection.debug is True:
print(parsed_response)
print('---')
if not response:
websocket.send(socket_input)

if self.connection.debug is True:
print(sql)
print(params)

response = websocket.recv()

parsed_response = json.loads(response)

if parsed_response["type"] == "response_error":
if "unique constraint failed" in parsed_response["error"].lower():
raise IntegrityError(parsed_response["error"])

raise DatabaseError(parsed_response["error"] + "\n" + sql)

if self.connection.cache is True and should_cache is True:
cache_key = self.generate_cache_key(socket_input)
cache.set(cache_key, response, 300)

results = self._convert_results(list(tuple(row) for row in parsed_response["result"]["results"]))
meta = parsed_response["result"].get("meta")

if self.connection.debug is True:
print(results)
print('---')

return results, meta

def quote_name(self, name):
Expand Down Expand Up @@ -259,6 +297,7 @@ def bulk_insert_sql(self, fields, placeholder_rows):
class DatabaseWrapper(SQLiteDatabaseWrapper):
vendor = 'websocket'
debug = False
cache = False

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
Expand All @@ -277,11 +316,14 @@ def get_connection_params(self):
'endpoint_url': settings_dict['WORKERS_DBMS_ENDPOINT'],
'access_id': settings_dict.get('WORKERS_DBMS_ACCESS_ID'),
'access_secret': settings_dict.get('WORKERS_DBMS_ACCESS_SECRET'),
'cache': settings_dict.get('WORKERS_DBMS_CACHE', True),
'debug': settings_dict.get('WORKERS_DBMS_DEBUG'),
}

def get_new_connection(self, conn_params):
headers = []
if conn_params['cache']:
self.cache = conn_params['cache']
if conn_params['debug']:
self.debug = conn_params['debug'] is True

Expand Down Expand Up @@ -347,19 +389,23 @@ def close(self):
# self.websocket.close()

def execute(self, sql, params=None):
result, meta = self.ops.raw_query(self.websocket, sql, params)
try:
result, meta = self.ops.raw_query(self.websocket, sql, params)

self.results = result
self.results = result

# Update rowcount based on the operation type
if meta:
if "INSERT" in sql.upper():
self.rowcount = meta.get("rows_written", 0)
# self.connection.ops.last_insert_id = meta.get("last_insert_id") # TODO: implement last insert id
elif "UPDATE" in sql.upper() or "DELETE" in sql.upper():
self.rowcount = meta.get("rows_written", 0)
else:
self.rowcount = meta.get("rows_read", 0)
# Update rowcount based on the operation type
if meta:
if "INSERT" in sql.upper():
self.rowcount = meta.get("rows_written", 0)
# self.connection.ops.last_insert_id = meta.get("last_insert_id") # TODO: implement last insert id
elif "UPDATE" in sql.upper() or "DELETE" in sql.upper():
self.rowcount = meta.get("rows_written", 0)
else:
self.rowcount = meta.get("rows_read", 0)
except Exception as e:
self.results = []
raise DatabaseError(str(e))

return self

Expand All @@ -380,12 +426,13 @@ def fetchmany(self, size=None):
def fetchall(self):
if self.results:
results = self.results
self.results = None
self.results = [] # Clear the results after fetching
return results
return []

def __iter__(self):
return iter(self.fetchall())
while self.results:
yield self.fetchone()

@property
def rowcount(self):
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "django-cf"
version = "0.0.14"
version = "0.0.16"
authors = [
{ name="Gabriel Massadas" },
]
Expand Down

0 comments on commit c7ba34c

Please sign in to comment.