2121
2222import logging
2323from datetime import datetime , date
24+ from types import ModuleType
2425
2526from sqlalchemy import types as sqltypes
2627from sqlalchemy .engine import default , reflection
@@ -205,6 +206,12 @@ def initialize(self, connection):
205206 self .default_schema_name = \
206207 self ._get_default_schema_name (connection )
207208
209+ def set_isolation_level (self , dbapi_connection , level ):
210+ """
211+ For CrateDB, this is implemented as a noop.
212+ """
213+ pass
214+
208215 def do_rollback (self , connection ):
209216 # if any exception is raised by the dbapi, sqlalchemy by default
210217 # attempts to do a rollback crate doesn't support rollbacks.
@@ -223,7 +230,21 @@ def connect(self, host=None, port=None, *args, **kwargs):
223230 use_ssl = asbool (kwargs .pop ("ssl" , False ))
224231 if use_ssl :
225232 servers = ["https://" + server for server in servers ]
226- return self .dbapi .connect (servers = servers , ** kwargs )
233+
234+ is_module = isinstance (self .dbapi , ModuleType )
235+ if is_module :
236+ driver_name = self .dbapi .__name__
237+ else :
238+ driver_name = self .dbapi .__class__ .__name__
239+ if driver_name == "crate.client" :
240+ if "database" in kwargs :
241+ del kwargs ["database" ]
242+ return self .dbapi .connect (servers = servers , ** kwargs )
243+ elif driver_name in ["psycopg" , "PsycopgAdaptDBAPI" , "AsyncAdapt_asyncpg_dbapi" ]:
244+ return self .dbapi .connect (host = host , port = port , ** kwargs )
245+ else :
246+ raise ValueError (f"Unknown driver variant: { driver_name } " )
247+
227248 return self .dbapi .connect (** kwargs )
228249
229250 def _get_default_schema_name (self , connection ):
@@ -269,11 +290,11 @@ def get_schema_names(self, connection, **kw):
269290 def get_table_names (self , connection , schema = None , ** kw ):
270291 if schema is None :
271292 schema = self ._get_effective_schema_name (connection )
272- cursor = connection .exec_driver_sql (
293+ cursor = connection .exec_driver_sql (self . _format_query (
273294 "SELECT table_name FROM information_schema.tables "
274295 "WHERE {0} = ? "
275296 "AND table_type = 'BASE TABLE' "
276- "ORDER BY table_name ASC, {0} ASC" .format (self .schema_column ),
297+ "ORDER BY table_name ASC, {0} ASC" ) .format (self .schema_column ),
277298 (schema or self .default_schema_name , )
278299 )
279300 return [row [0 ] for row in cursor .fetchall ()]
@@ -295,7 +316,7 @@ def get_columns(self, connection, table_name, schema=None, **kw):
295316 "AND column_name !~ ?" \
296317 .format (self .schema_column )
297318 cursor = connection .exec_driver_sql (
298- query ,
319+ self . _format_query ( query ) ,
299320 (table_name ,
300321 schema or self .default_schema_name ,
301322 r"(.*)\[\'(.*)\'\]" ) # regex to filter subscript
@@ -334,7 +355,7 @@ def result_fun(result):
334355 return set (rows [0 ] if rows else [])
335356
336357 pk_result = engine .exec_driver_sql (
337- query ,
358+ self . _format_query ( query ) ,
338359 (table_name , schema or self .default_schema_name )
339360 )
340361 pks = result_fun (pk_result )
@@ -375,6 +396,17 @@ def has_ilike_operator(self):
375396 server_version_info = self .server_version_info
376397 return server_version_info is not None and server_version_info >= (4 , 1 , 0 )
377398
399+ def _format_query (self , query ):
400+ """
401+ When using the PostgreSQL protocol with drivers `psycopg` or `asyncpg`,
402+ the paramstyle is not `qmark`, but `pyformat`.
403+
404+ TODO: Review: Is it legit and sane? Are there alternatives?
405+ """
406+ if self .paramstyle == "pyformat" :
407+ query = query .replace ("= ?" , "= %s" ).replace ("!~ ?" , "!~ %s" )
408+ return query
409+
378410
379411class DateTrunc (functions .GenericFunction ):
380412 name = "date_trunc"
0 commit comments