Skip to content

Commit f65865f

Browse files
committed
Advanced rows count returning #221
For inlined INSERTs and from INSERT INTO ... SELECT
1 parent ae4b0ac commit f65865f

File tree

8 files changed

+69
-8
lines changed

8 files changed

+69
-8
lines changed

clickhouse_driver/client.py

+13-3
Original file line numberDiff line numberDiff line change
@@ -255,8 +255,6 @@ def execute(self, query, params=None, with_column_types=False,
255255
Defaults to ``False`` (row-like form).
256256
257257
:return: * number of inserted rows for INSERT queries with data.
258-
Returning rows count from INSERT FROM SELECT is not
259-
supported.
260258
* if `with_column_types=False`: `list` of `tuples` with
261259
rows/columns.
262260
* if `with_column_types=True`: `tuple` of 2 elements:
@@ -284,11 +282,23 @@ def execute(self, query, params=None, with_column_types=False,
284282
)
285283
else:
286284
rv = self.process_ordinary_query(
287-
query, params=params, with_column_types=with_column_types,
285+
query, params=params, with_column_types=True,
288286
external_tables=external_tables,
289287
query_id=query_id, types_check=types_check,
290288
columnar=columnar
291289
)
290+
rows, columns_with_types = rv
291+
# No columns in case of DDL or INSERT ... SELECT
292+
if not columns_with_types:
293+
# Dump check for INSERT.
294+
# Backwards compatibility for .execute(DDL) return [].
295+
# TODO: remove in 0.3.0. Return integer (zero) for DDL too.
296+
if query.lower().strip().startswith('insert'):
297+
rows = self.last_query.progress.rows
298+
rv = (rows, columns_with_types)
299+
if not with_column_types:
300+
rv = rv[0]
301+
292302
self.last_query.store_elapsed(time() - start_time)
293303
return rv
294304

clickhouse_driver/dbapi/cursor.py

+4
Original file line numberDiff line numberDiff line change
@@ -332,6 +332,10 @@ def _process_response(self, response, executemany=False):
332332
self._rowcount = len(rows)
333333
else:
334334
self._columns = self._types = []
335+
# TODO: return 0 for DDL in 0.3.0
336+
if not isinstance(rows, list): # number of inserted rows
337+
self._rowcount = rows
338+
rows = []
335339

336340
self._rows = rows
337341

docs/features.rst

+1-1
Original file line numberDiff line numberDiff line change
@@ -324,7 +324,7 @@ Parameters are expected in Python extended format codes, e.g.
324324
... {'limit': 3}
325325
... )
326326
>>> cursor.rowcount
327-
0
327+
5
328328
>>> cursor.execute('SELECT sum(x) FROM test')
329329
>>> cursor.fetchall()
330330
[(303,)]

docs/quickstart.rst

+1-1
Original file line numberDiff line numberDiff line change
@@ -193,7 +193,7 @@ Of course for ``INSERT ... SELECT`` queries data is not needed:
193193
... 'SELECT * FROM system.numbers LIMIT %(limit)s',
194194
... {'limit': 5}
195195
... )
196-
[]
196+
5
197197
198198
ClickHouse will execute this query like a usual ``SELECT`` query.
199199

tests/test_blocks.py

+6-1
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
from clickhouse_driver.errors import ServerException
55
from tests.testcase import BaseTestCase, file_config
6-
from tests.util import capture_logging
6+
from tests.util import capture_logging, will_fail_in
77

88

99
class BlocksTestCase(BaseTestCase):
@@ -80,6 +80,11 @@ def test_close_connection_on_keyboard_interrupt(self):
8080

8181
self.assertFalse(self.client.connection.connected)
8282

83+
@will_fail_in(0, 3)
84+
def test_ddl_return_value(self):
85+
rv = self.client.execute('DROP TABLE test IF EXISTS')
86+
self.assertEqual(rv, [])
87+
8388

8489
class ProgressTestCase(BaseTestCase):
8590
def test_select_with_progress(self):

tests/test_dbapi.py

+11-1
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
ProgrammingError, InterfaceError, OperationalError
99
)
1010
from tests.testcase import BaseTestCase
11+
from tests.util import will_fail_in
1112

1213

1314
class DBAPITestCaseBase(BaseTestCase):
@@ -149,7 +150,15 @@ def test_rowcount_insert_from_select(self):
149150
'INSERT INTO test '
150151
'SELECT number FROM system.numbers LIMIT 4'
151152
)
152-
self.assertEqual(cursor.rowcount, -1)
153+
self.assertEqual(cursor.rowcount, 4)
154+
self.assertEqual(cursor.fetchall(), [])
155+
156+
cursor.execute(
157+
'INSERT INTO test '
158+
'SELECT number FROM system.numbers LIMIT 0'
159+
)
160+
self.assertEqual(cursor.rowcount, 0)
161+
self.assertEqual(cursor.fetchall(), [])
153162

154163
def test_description(self):
155164
with self.created_cursor() as cursor:
@@ -165,6 +174,7 @@ def test_pep249_sizes(self):
165174
cursor.setinputsizes(0)
166175
cursor.setoutputsize(0)
167176

177+
@will_fail_in(0, 3)
168178
def test_ddl(self):
169179
with self.created_cursor() as cursor:
170180
cursor.execute('DROP TABLE IF EXISTS test')

tests/test_insert.py

+16-1
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from tests.testcase import BaseTestCase
44
from clickhouse_driver import errors
55
from clickhouse_driver.errors import ServerException
6+
from tests.util import require_server_version
67

78

89
class InsertTestCase(BaseTestCase):
@@ -134,7 +135,21 @@ def test_insert_from_select(self):
134135
'INSERT INTO test (a) '
135136
'SELECT number FROM system.numbers LIMIT 5'
136137
)
137-
self.assertEqual(inserted, [])
138+
self.assertEqual(inserted, 5)
139+
140+
inserted = self.client.execute(
141+
'INSERT INTO test (a) '
142+
'SELECT number FROM system.numbers LIMIT 0'
143+
)
144+
self.assertEqual(inserted, 0)
145+
146+
@require_server_version(19, 3, 3)
147+
def test_insert_inline(self):
148+
with self.create_table('a UInt64'):
149+
inserted = self.client.execute(
150+
'INSERT INTO test (a) VALUES (1), (2), (3)'
151+
)
152+
self.assertEqual(inserted, 3)
138153

139154
def test_insert_return(self):
140155
with self.create_table('a Int8'):

tests/util.py

+17
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22
import logging
33
from io import StringIO
44

5+
from clickhouse_driver import VERSION, __version__
6+
57

68
def skip_by_server_version(testcase, version_required):
79
testcase.skipTest(
@@ -29,6 +31,21 @@ def wrapper(*args, **kwargs):
2931
return check
3032

3133

34+
def will_fail_in(*version):
35+
def decorator(f):
36+
@wraps(f)
37+
def wrapper(*args, **kwargs):
38+
self = args[0]
39+
40+
if VERSION >= version:
41+
self.fail(
42+
'This test should not work in {}'.format(__version__)
43+
)
44+
45+
return wrapper
46+
return decorator
47+
48+
3249
class LoggingCapturer(object):
3350
def __init__(self, logger_name, level):
3451
self.old_stdout_handlers = []

0 commit comments

Comments
 (0)