diff --git a/Snowflake.Data.Tests/IntegrationTests/SFMultiStatementsIT.cs b/Snowflake.Data.Tests/IntegrationTests/SFMultiStatementsIT.cs index f34b6d915..a9c48cb82 100644 --- a/Snowflake.Data.Tests/IntegrationTests/SFMultiStatementsIT.cs +++ b/Snowflake.Data.Tests/IntegrationTests/SFMultiStatementsIT.cs @@ -281,6 +281,116 @@ public void TestMixedQueryTypeWithBinding() } } + [Test] + public void TestMixedQueryBindingWithMultiStatementCountZero() + { + using (DbConnection conn = new SnowflakeDbConnection()) + { + conn.ConnectionString = ConnectionString; + conn.Open(); + + using (DbCommand cmd = conn.CreateCommand()) + { + cmd.CommandText = $"use schema {testConfig.schema};"+ + $"use schema {testConfig.schema};"+ + $"create or replace table {TableName}(cola integer, colb string);" + + $"insert into {TableName} values (?, ?);" + + $"insert into {TableName} values (?, ?), (?, ?);" + + $"select * from {TableName};" + + $"drop table if exists {TableName}"; + + // Set statement count + var stmtCountParam = cmd.CreateParameter(); + stmtCountParam.ParameterName = "MULTI_STATEMENT_COUNT"; + stmtCountParam.DbType = DbType.Int16; + stmtCountParam.Value = 0; + cmd.Parameters.Add(stmtCountParam); + + // set parameter bindings + var p1 = cmd.CreateParameter(); + p1.ParameterName = "1"; + p1.DbType = DbType.Int16; + p1.Value = 1; + cmd.Parameters.Add(p1); + + var p2 = cmd.CreateParameter(); + p2.ParameterName = "2"; + p2.DbType = DbType.String; + p2.Value ="str1"; + cmd.Parameters.Add(p2); + + var p3 = cmd.CreateParameter(); + p3.ParameterName = "3"; + p3.DbType = DbType.Int16; + p3.Value = 2; + cmd.Parameters.Add(p3); + + var p4 = cmd.CreateParameter(); + p4.ParameterName = "4"; + p4.DbType = DbType.String; + p4.Value = "str2"; + cmd.Parameters.Add(p4); + + var p5 = cmd.CreateParameter(); + p5.ParameterName = "5"; + p5.DbType = DbType.Int16; + p5.Value = 3; + cmd.Parameters.Add(p5); + + var p6 = cmd.CreateParameter(); + p6.ParameterName = "6"; + p6.DbType = DbType.String; + p6.Value = "str3"; + cmd.Parameters.Add(p6); + + DbDataReader reader = cmd.ExecuteReader(); + + //skip use statement + Assert.IsTrue(reader.NextResult()); + Assert.IsTrue(reader.NextResult()); + + // result of create + Assert.IsFalse(reader.HasRows); + Assert.AreEqual(0, reader.RecordsAffected); + + // result of insert #1 + Assert.IsTrue(reader.NextResult()); + Assert.IsFalse(reader.HasRows); + Assert.AreEqual(1, reader.RecordsAffected); + + // result of insert #2 + Assert.IsTrue(reader.NextResult()); + Assert.IsFalse(reader.HasRows); + Assert.AreEqual(2, reader.RecordsAffected); + + // result of select + Assert.IsTrue(reader.NextResult()); + Assert.IsTrue(reader.HasRows); + Assert.AreEqual(-1, reader.RecordsAffected); + Assert.IsTrue(reader.Read()); + Assert.AreEqual(1, reader.GetInt32(0)); + Assert.AreEqual("str1", reader.GetString(1)); + Assert.IsTrue(reader.Read()); + Assert.AreEqual(2, reader.GetInt32(0)); + Assert.AreEqual("str2", reader.GetString(1)); + Assert.IsTrue(reader.Read()); + Assert.AreEqual(3, reader.GetInt32(0)); + Assert.AreEqual("str3", reader.GetString(1)); + Assert.IsFalse(reader.Read()); + + // result of drop + Assert.IsTrue(reader.NextResult()); + Assert.IsFalse(reader.HasRows); + Assert.AreEqual(0, reader.RecordsAffected); + + Assert.IsFalse(reader.NextResult()); + reader.Close(); + } + + conn.Close(); + } + } + [Test] public void TestWithExecuteNonQuery() {