Skip to content
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 54 additions & 28 deletions mssql_python/pybind/ddbc_bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2924,7 +2924,9 @@ SQLRETURN SQLGetData_wrap(SqlHandlePtr StatementHandle, SQLUSMALLINT colCount, p
row.append(
FetchLobColumnData(hStmt, i, SQL_C_CHAR, false, false, charEncoding));
} else {
uint64_t fetchBufferSize = columnSize + 1 /* null-termination */;
// Multiply by 4 because the driver's UTF-8 conversion can expand
// each of the x characters of a varchar(x) to up to 4 bytes.
uint64_t fetchBufferSize = 4 * columnSize + 1 /* null-termination */;
std::vector<SQLCHAR> dataBuffer(fetchBufferSize);
SQLLEN dataLen;
ret = SQLGetData_ptr(hStmt, i, SQL_C_CHAR, dataBuffer.data(), dataBuffer.size(),
Expand Down Expand Up @@ -2953,12 +2955,15 @@ SQLRETURN SQLGetData_wrap(SqlHandlePtr StatementHandle, SQLUSMALLINT colCount, p
row.append(raw_bytes);
}
} else {
// Buffer too small, fallback to streaming
LOG("SQLGetData: CHAR column %d data truncated "
"(buffer_size=%zu), using streaming LOB",
i, dataBuffer.size());
row.append(FetchLobColumnData(hStmt, i, SQL_C_CHAR, false, false,
charEncoding));
// Reaching this case indicates an error in mssql_python.
// Theoretically, we could still compensate by calling SQLGetData or
// FetchLobColumnData more often, but then we would still have to process
// the data we already got from the above call to SQLGetData.
// Better to throw an exception and fix the code than to risk returning corrupted data.
ThrowStdException(
"Internal error: SQLGetData returned data "
"larger than expected for CHAR column"
);
}
} else if (dataLen == SQL_NULL_DATA) {
LOG("SQLGetData: Column %d is NULL (CHAR)", i);
Expand Down Expand Up @@ -2995,7 +3000,7 @@ SQLRETURN SQLGetData_wrap(SqlHandlePtr StatementHandle, SQLUSMALLINT colCount, p
case SQL_WCHAR:
case SQL_WVARCHAR:
case SQL_WLONGVARCHAR: {
if (columnSize == SQL_NO_TOTAL || columnSize > 4000) {
if (columnSize == SQL_NO_TOTAL || columnSize == 0 || columnSize > 4000) {
LOG("SQLGetData: Streaming LOB for column %d (SQL_C_WCHAR) "
"- columnSize=%lu",
i, (unsigned long)columnSize);
Expand Down Expand Up @@ -3024,12 +3029,15 @@ SQLRETURN SQLGetData_wrap(SqlHandlePtr StatementHandle, SQLUSMALLINT colCount, p
"length=%lu for column %d",
(unsigned long)numCharsInData, i);
} else {
// Buffer too small, fallback to streaming
LOG("SQLGetData: NVARCHAR column %d data "
"truncated, using streaming LOB",
i);
row.append(FetchLobColumnData(hStmt, i, SQL_C_WCHAR, true, false,
"utf-16le"));
// Reaching this case indicates an error in mssql_python.
// Theoretically, we could still compensate by calling SQLGetData or
// FetchLobColumnData more often, but then we would still have to process
// the data we already got from the above call to SQLGetData.
// Better to throw an exception and fix the code than to risk returning corrupted data.
ThrowStdException(
"Internal error: SQLGetData returned data "
"larger than expected for WCHAR column"
);
}
} else if (dataLen == SQL_NULL_DATA) {
LOG("SQLGetData: Column %d is NULL (NVARCHAR)", i);
Expand Down Expand Up @@ -3291,8 +3299,15 @@ SQLRETURN SQLGetData_wrap(SqlHandlePtr StatementHandle, SQLUSMALLINT colCount, p
row.append(py::bytes(
reinterpret_cast<const char*>(dataBuffer.data()), dataLen));
} else {
row.append(
FetchLobColumnData(hStmt, i, SQL_C_BINARY, false, true, ""));
// Reaching this case indicates an error in mssql_python.
// Theoretically, we could still compensate by calling SQLGetData or
// FetchLobColumnData more often, but then we would still have to process
// the data we already got from the above call to SQLGetData.
// Better to throw an exception and fix the code than to risk returning corrupted data.
ThrowStdException(
"Internal error: SQLGetData returned data "
"larger than expected for BINARY column"
);
}
} else if (dataLen == SQL_NULL_DATA) {
row.append(py::none());
Expand Down Expand Up @@ -3434,7 +3449,9 @@ SQLRETURN SQLBindColums(SQLHSTMT hStmt, ColumnBuffers& buffers, py::list& column
// TODO: handle variable-length data correctly. This logic won't
// suffice
HandleZeroColumnSizeAtFetch(columnSize);
uint64_t fetchBufferSize = columnSize + 1 /*null-terminator*/;
// Multiply by 4 because the driver's UTF-8 conversion can expand
// each of the x characters of a varchar(x) to up to 4 bytes.
uint64_t fetchBufferSize = 4 * columnSize + 1 /*null-terminator*/;
// TODO: For LONGVARCHAR/BINARY types, columnSize is returned as
// 2GB-1 by SQLDescribeCol. So fetchBufferSize = 2GB.
// fetchSize=1 if columnSize>1GB. So we'll allocate a vector of
Expand Down Expand Up @@ -3580,8 +3597,7 @@ SQLRETURN SQLBindColums(SQLHSTMT hStmt, ColumnBuffers& buffers, py::list& column
// Fetch rows in batches
// TODO: Move to anonymous namespace, since it is not used outside this file
SQLRETURN FetchBatchData(SQLHSTMT hStmt, ColumnBuffers& buffers, py::list& columnNames,
py::list& rows, SQLUSMALLINT numCols, SQLULEN& numRowsFetched,
const std::vector<SQLUSMALLINT>& lobColumns) {
py::list& rows, SQLUSMALLINT numCols, SQLULEN& numRowsFetched) {
LOG("FetchBatchData: Fetching data in batches");
SQLRETURN ret = SQLFetchScroll_ptr(hStmt, SQL_FETCH_NEXT, 0);
if (ret == SQL_NO_DATA) {
Expand All @@ -3600,19 +3616,28 @@ SQLRETURN FetchBatchData(SQLHSTMT hStmt, ColumnBuffers& buffers, py::list& colum
SQLULEN columnSize;
SQLULEN processedColumnSize;
uint64_t fetchBufferSize;
bool isLob;
};
std::vector<ColumnInfo> columnInfos(numCols);
for (SQLUSMALLINT col = 0; col < numCols; col++) {
const auto& columnMeta = columnNames[col].cast<py::dict>();
columnInfos[col].dataType = columnMeta["DataType"].cast<SQLSMALLINT>();
columnInfos[col].columnSize = columnMeta["ColumnSize"].cast<SQLULEN>();
columnInfos[col].isLob =
std::find(lobColumns.begin(), lobColumns.end(), col + 1) != lobColumns.end();
columnInfos[col].processedColumnSize = columnInfos[col].columnSize;
HandleZeroColumnSizeAtFetch(columnInfos[col].processedColumnSize);
columnInfos[col].fetchBufferSize =
columnInfos[col].processedColumnSize + 1; // +1 for null terminator
switch (columnInfos[col].dataType) {
case SQL_CHAR:
case SQL_VARCHAR:
case SQL_LONGVARCHAR:
// Multiply by 4 because the driver's UTF-8 conversion can expand
// each of the x characters of a varchar(x) to up to 4 bytes.
columnInfos[col].fetchBufferSize =
4 * columnInfos[col].processedColumnSize + 1; // +1 for null terminator
break;
default:
columnInfos[col].fetchBufferSize =
columnInfos[col].processedColumnSize + 1; // +1 for null terminator
break;
}
}

std::string decimalSeparator = GetDecimalSeparator(); // Cache decimal separator
Expand All @@ -3630,7 +3655,6 @@ SQLRETURN FetchBatchData(SQLHSTMT hStmt, ColumnBuffers& buffers, py::list& colum
columnInfosExt[col].columnSize = columnInfos[col].columnSize;
columnInfosExt[col].processedColumnSize = columnInfos[col].processedColumnSize;
columnInfosExt[col].fetchBufferSize = columnInfos[col].fetchBufferSize;
columnInfosExt[col].isLob = columnInfos[col].isLob;

// Map data type to processor function (switch executed once per column,
// not per cell)
Expand Down Expand Up @@ -3916,7 +3940,9 @@ size_t calculateRowSize(py::list& columnNames, SQLUSMALLINT numCols) {
case SQL_CHAR:
case SQL_VARCHAR:
case SQL_LONGVARCHAR:
rowSize += columnSize;
// Multiply by 4 because the driver's UTF-8 conversion can expand
// each of the x characters of a varchar(x) to up to 4 bytes.
rowSize += 4 * columnSize;
break;
case SQL_SS_XML:
case SQL_WCHAR:
Expand Down Expand Up @@ -4068,7 +4094,7 @@ SQLRETURN FetchMany_wrap(SqlHandlePtr StatementHandle, py::list& rows, int fetch
SQLSetStmtAttr_ptr(hStmt, SQL_ATTR_ROW_ARRAY_SIZE, (SQLPOINTER)(intptr_t)fetchSize, 0);
SQLSetStmtAttr_ptr(hStmt, SQL_ATTR_ROWS_FETCHED_PTR, &numRowsFetched, 0);

ret = FetchBatchData(hStmt, buffers, columnNames, rows, numCols, numRowsFetched, lobColumns);
ret = FetchBatchData(hStmt, buffers, columnNames, rows, numCols, numRowsFetched);
if (!SQL_SUCCEEDED(ret) && ret != SQL_NO_DATA) {
LOG("FetchMany_wrap: Error when fetching data - SQLRETURN=%d", ret);
return ret;
Expand Down Expand Up @@ -4201,7 +4227,7 @@ SQLRETURN FetchAll_wrap(SqlHandlePtr StatementHandle, py::list& rows,

while (ret != SQL_NO_DATA) {
ret =
FetchBatchData(hStmt, buffers, columnNames, rows, numCols, numRowsFetched, lobColumns);
FetchBatchData(hStmt, buffers, columnNames, rows, numCols, numRowsFetched);
if (!SQL_SUCCEEDED(ret) && ret != SQL_NO_DATA) {
LOG("FetchAll_wrap: Error when fetching data - SQLRETURN=%d", ret);
return ret;
Expand Down
40 changes: 21 additions & 19 deletions mssql_python/pybind/ddbc_bindings.h
Original file line number Diff line number Diff line change
Expand Up @@ -651,14 +651,8 @@ struct ColumnInfoExt {
SQLULEN columnSize;
SQLULEN processedColumnSize;
uint64_t fetchBufferSize;
bool isLob;
};

// Forward declare FetchLobColumnData (defined in ddbc_bindings.cpp) - MUST be
// outside namespace
py::object FetchLobColumnData(SQLHSTMT hStmt, SQLUSMALLINT col, SQLSMALLINT cType, bool isWideChar,
bool isBinary, const std::string& charEncoding = "utf-8");

// Specialized column processors for each data type (eliminates switch in hot
// loop)
namespace ColumnProcessors {
Expand Down Expand Up @@ -795,7 +789,7 @@ inline void ProcessChar(PyObject* row, ColumnBuffers& buffers, const void* colIn
// Fast path: Data fits in buffer (not LOB or truncated)
// fetchBufferSize includes null-terminator, numCharsInData doesn't. Hence
// '<'
if (!colInfo->isLob && numCharsInData < colInfo->fetchBufferSize) {
if (numCharsInData < colInfo->fetchBufferSize) {
// Performance: Direct Python C API call - create string from buffer
PyObject* pyStr = PyUnicode_FromStringAndSize(
reinterpret_cast<char*>(
Expand All @@ -808,9 +802,12 @@ inline void ProcessChar(PyObject* row, ColumnBuffers& buffers, const void* colIn
PyList_SET_ITEM(row, col - 1, pyStr);
}
} else {
// Slow path: LOB data requires separate fetch call
PyList_SET_ITEM(row, col - 1,
FetchLobColumnData(hStmt, col, SQL_C_CHAR, false, false).release().ptr());
// Reaching this case indicates an error in mssql_python.
// This function is only called on columns bound by SQLBindCol.
// For such columns, the ODBC Driver does not allow us to compensate by
// fetching the remaining data using SQLGetData / FetchLobColumnData.
ThrowStdException(
"Internal error: CHAR/VARCHAR column data exceeds buffer size.");
}
}

Expand Down Expand Up @@ -838,7 +835,7 @@ inline void ProcessWChar(PyObject* row, ColumnBuffers& buffers, const void* colI
// Fast path: Data fits in buffer (not LOB or truncated)
// fetchBufferSize includes null-terminator, numCharsInData doesn't. Hence
// '<'
if (!colInfo->isLob && numCharsInData < colInfo->fetchBufferSize) {
if (numCharsInData < colInfo->fetchBufferSize) {
#if defined(__APPLE__) || defined(__linux__)
// Performance: Direct UTF-16 decode (SQLWCHAR is 2 bytes on
// Linux/macOS)
Expand Down Expand Up @@ -875,9 +872,12 @@ inline void ProcessWChar(PyObject* row, ColumnBuffers& buffers, const void* colI
}
#endif
} else {
// Slow path: LOB data requires separate fetch call
PyList_SET_ITEM(row, col - 1,
FetchLobColumnData(hStmt, col, SQL_C_WCHAR, true, false).release().ptr());
// Reaching this case indicates an error in mssql_python.
// This function is only called on columns bound by SQLBindCol.
// For such columns, the ODBC Driver does not allow us to compensate by
// fetching the remaining data using SQLGetData / FetchLobColumnData.
ThrowStdException(
"Internal error: NCHAR/NVARCHAR column data exceeds buffer size.");
}
}

Expand All @@ -902,7 +902,7 @@ inline void ProcessBinary(PyObject* row, ColumnBuffers& buffers, const void* col
}

// Fast path: Data fits in buffer (not LOB or truncated)
if (!colInfo->isLob && static_cast<size_t>(dataLen) <= colInfo->processedColumnSize) {
if (static_cast<size_t>(dataLen) <= colInfo->processedColumnSize) {
// Performance: Direct Python C API call - create bytes from buffer
PyObject* pyBytes = PyBytes_FromStringAndSize(
reinterpret_cast<const char*>(
Expand All @@ -915,10 +915,12 @@ inline void ProcessBinary(PyObject* row, ColumnBuffers& buffers, const void* col
PyList_SET_ITEM(row, col - 1, pyBytes);
}
} else {
// Slow path: LOB data requires separate fetch call
PyList_SET_ITEM(
row, col - 1,
FetchLobColumnData(hStmt, col, SQL_C_BINARY, false, true, "").release().ptr());
// Reaching this case indicates an error in mssql_python.
// This function is only called on columns bound by SQLBindCol.
// For such columns, the ODBC Driver does not allow us to compensate by
// fetching the remaining data using SQLGetData / FetchLobColumnData.
ThrowStdException(
"Internal error: BINARY/VARBINARY column data exceeds buffer size.");
}
}

Expand Down
53 changes: 53 additions & 0 deletions tests/test_004_cursor.py
Original file line number Diff line number Diff line change
Expand Up @@ -15018,3 +15018,56 @@ def test_close(db_connection):
pytest.fail(f"Cursor close test failed: {e}")
finally:
cursor = db_connection.cursor()


def test_varchar_buffersize_special_character(cursor):
    # 'ß' occupies one character of a Latin-1 varchar(2) but expands to two
    # bytes when the driver converts the value to UTF-8, so the fetched data
    # can exceed columnSize bytes. Exercise every fetch path to confirm the
    # oversized fetch buffers return the value intact.
    setup_sql = (
        "drop table if exists #t1;\n"
        "create table #t1 (a varchar(2) collate SQL_Latin1_General_CP1_CI_AS)\n"
        "insert into #t1 values (N'ßl')\n"
    )
    cursor.execute(setup_sql)

    expected = "ßl"
    # fetchall / fetchmany / fetchone each go through different binding code.
    assert cursor.execute("select * from #t1").fetchall()[0][0] == expected
    assert cursor.execute("select * from #t1").fetchmany(1)[0][0] == expected
    assert cursor.execute("select * from #t1").fetchone()[0] == expected
    # Expression results and explicit casts take the same varchar path.
    assert cursor.execute("select LEFT(a, 1) from #t1").fetchone()[0] == "ß"
    assert cursor.execute("select cast(a as varchar(3)) from #t1").fetchone()[0] == expected


def test_varchar_latin1_fetch(cursor):
    # Round-trips all 256 single-byte Latin-1 code points through a Latin-1
    # collated varchar(1) and a UTF-8 collated varchar(3) column. The driver's
    # char -> UTF-8 conversion can expand one Latin-1 character to multiple
    # bytes, which is what the enlarged fetch buffers must accommodate.
    def query():
        # Build a 256-row table: row_nr is 0..255, latin1 holds that byte
        # value, and utf8 holds the same character under a UTF-8 collation.
        cursor.execute("""
            declare @t1 as table(
                row_nr int,
                latin1 varchar(1) collate SQL_Latin1_General_CP1_CI_AS,
                utf8 varchar(3) collate Latin1_General_100_CI_AI_SC_UTF8
            )

            insert into @t1 (row_nr, latin1)
            select top 256
                row_number() over(order by (select 1)) - 1,
                cast(row_number() over(order by (select 1)) - 1 as binary(1))
            from sys.objects

            update @t1 set utf8 = latin1

            select * from @t1
        """)
        # Skip past the results of the INSERT and UPDATE statements so the
        # cursor is positioned on the final SELECT's result set.
        cursor.nextset()
        cursor.nextset()

    def validate(result):
        assert len(result) == 256
        for (row_nr, latin1, utf8) in result:
            # Each code point must survive both collations identically, except
            # for the handful of bytes Latin-1 leaves unmapped (see below).
            assert utf8 == latin1 or (
                # small difference in how sql server and msodbcsql18 handle unmapped characters
                row_nr in [129, 141, 143, 144, 157]
                and utf8 == chr(row_nr)
                and latin1 == '?'
            ), (row_nr, utf8, latin1, chr(row_nr))

    # Exercise each fetch path: fetchall, fetchmany, and repeated fetchone.
    query()
    validate(cursor.fetchall())
    query()
    validate(cursor.fetchmany(500))
    query()
    validate([cursor.fetchone() for _ in range(256)])
Loading