21
21
22
22
import logging
23
23
from datetime import datetime , date
24
+ from types import ModuleType
24
25
25
26
from sqlalchemy import types as sqltypes
26
27
from sqlalchemy .engine import default , reflection
@@ -202,6 +203,12 @@ def initialize(self, connection):
202
203
self .default_schema_name = \
203
204
self ._get_default_schema_name (connection )
204
205
206
+ def set_isolation_level (self , dbapi_connection , level ):
207
+ """
208
+ For CrateDB, this is implemented as a noop.
209
+ """
210
+ pass
211
+
205
212
def do_rollback (self , connection ):
206
213
# if any exception is raised by the dbapi, sqlalchemy by default
207
214
# attempts to do a rollback crate doesn't support rollbacks.
@@ -220,7 +227,21 @@ def connect(self, host=None, port=None, *args, **kwargs):
220
227
use_ssl = asbool (kwargs .pop ("ssl" , False ))
221
228
if use_ssl :
222
229
servers = ["https://" + server for server in servers ]
223
- return self .dbapi .connect (servers = servers , ** kwargs )
230
+
231
+ is_module = isinstance (self .dbapi , ModuleType )
232
+ if is_module :
233
+ driver_name = self .dbapi .__name__
234
+ else :
235
+ driver_name = self .dbapi .__class__ .__name__
236
+ if driver_name == "crate.client" :
237
+ if "database" in kwargs :
238
+ del kwargs ["database" ]
239
+ return self .dbapi .connect (servers = servers , ** kwargs )
240
+ elif driver_name in ["psycopg" , "PsycopgAdaptDBAPI" , "AsyncAdapt_asyncpg_dbapi" ]:
241
+ return self .dbapi .connect (host = host , port = port , ** kwargs )
242
+ else :
243
+ raise ValueError (f"Unknown driver variant: { driver_name } " )
244
+
224
245
return self .dbapi .connect (** kwargs )
225
246
226
247
def _get_default_schema_name (self , connection ):
@@ -266,11 +287,11 @@ def get_schema_names(self, connection, **kw):
266
287
def get_table_names (self , connection , schema = None , ** kw ):
267
288
if schema is None :
268
289
schema = self ._get_effective_schema_name (connection )
269
- cursor = connection .exec_driver_sql (
290
+ cursor = connection .exec_driver_sql (self . _format_query (
270
291
"SELECT table_name FROM information_schema.tables "
271
292
"WHERE {0} = ? "
272
293
"AND table_type = 'BASE TABLE' "
273
- "ORDER BY table_name ASC, {0} ASC" .format (self .schema_column ),
294
+ "ORDER BY table_name ASC, {0} ASC" ) .format (self .schema_column ),
274
295
(schema or self .default_schema_name , )
275
296
)
276
297
return [row [0 ] for row in cursor .fetchall ()]
@@ -292,7 +313,7 @@ def get_columns(self, connection, table_name, schema=None, **kw):
292
313
"AND column_name !~ ?" \
293
314
.format (self .schema_column )
294
315
cursor = connection .exec_driver_sql (
295
- query ,
316
+ self . _format_query ( query ) ,
296
317
(table_name ,
297
318
schema or self .default_schema_name ,
298
319
r"(.*)\[\'(.*)\'\]" ) # regex to filter subscript
@@ -331,7 +352,7 @@ def result_fun(result):
331
352
return set (rows [0 ] if rows else [])
332
353
333
354
pk_result = engine .exec_driver_sql (
334
- query ,
355
+ self . _format_query ( query ) ,
335
356
(table_name , schema or self .default_schema_name )
336
357
)
337
358
pks = result_fun (pk_result )
@@ -372,6 +393,17 @@ def has_ilike_operator(self):
372
393
server_version_info = self .server_version_info
373
394
return server_version_info is not None and server_version_info >= (4 , 1 , 0 )
374
395
396
+ def _format_query (self , query ):
397
+ """
398
+ When using the PostgreSQL protocol with drivers `psycopg` or `asyncpg`,
399
+ the paramstyle is not `qmark`, but `pyformat`.
400
+
401
+ TODO: Review: Is it legit and sane? Are there alternatives?
402
+ """
403
+ if self .paramstyle == "pyformat" :
404
+ query = query .replace ("= ?" , "= %s" ).replace ("!~ ?" , "!~ %s" )
405
+ return query
406
+
375
407
376
408
class DateTrunc (functions .GenericFunction ):
377
409
name = "date_trunc"
0 commit comments