Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SNOW-1490901 Passcode support for mfa authentication #979

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions Snowflake.Data.Tests/IntegrationTests/SFConnectionIT.cs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
using System.Data.Common;
using System.Net;
using Snowflake.Data.Core.Session;
using Snowflake.Data.Core.Tools;
using Snowflake.Data.Tests.Util;

namespace Snowflake.Data.Tests.IntegrationTests
Expand Down Expand Up @@ -2271,6 +2272,26 @@ public void TestUseMultiplePoolsConnectionPoolByDefault()
// assert
Assert.AreEqual(ConnectionPoolType.MultipleConnectionPool, poolVersion);
}

[Test]
[Ignore("Requires manual steps and environment with mfa authentication enrolled")] // to enroll to mfa authentication edit your user profile
public void TestMfaWithPasswordConnection()
{
// arrange
using (SnowflakeDbConnection conn = new SnowflakeDbConnection())
{
conn.Passcode = SecureStringHelper.Encode("123456");
// manual action: stop here in breakpoint to provide proper passcode by: conn.Passcode = SecureStringHelper.Encode("...");
conn.ConnectionString = ConnectionString + "minPoolSize=0;application=DuoTest";

// act
conn.Open();

// assert
Assert.AreEqual(ConnectionState.Open, conn.State);
// manual action: verify that you have received no push request for given connection
}
}
}
}

Expand Down
6 changes: 3 additions & 3 deletions Snowflake.Data.Tests/Mock/MockSnowflakeDbConnection.cs
Original file line number Diff line number Diff line change
Expand Up @@ -78,10 +78,10 @@ public override Task OpenAsync(CancellationToken cancellationToken)
cancellationToken);

}

private void SetMockSession()
{
SfSession = new SFSession(ConnectionString, Password, _restRequester);
SfSession = new SFSession(ConnectionString, Password, Passcode, _restRequester);

_connectionTimeout = (int)SfSession.connectionTimeout.TotalSeconds;

Expand All @@ -92,7 +92,7 @@ private void OnSessionEstablished()
{
_connectionState = ConnectionState.Open;
}

protected override bool CanReuseSession(TransactionRollbackStatus transactionRollbackStatus)
{
return false;
Expand Down
50 changes: 25 additions & 25 deletions Snowflake.Data.Tests/UnitTests/ArrowResultSetTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ public void BeforeTest()
// by default generate Int32 values from 1 to RowCount
PrepareTestCase(SFDataType.FIXED, 0, Enumerable.Range(1, RowCount).ToArray());
}

[Test]
public void TestResultFormatIsArrow()
{
Expand Down Expand Up @@ -139,7 +139,7 @@ public void TestGetValueReturnsNull()
var arrowResultSet = new ArrowResultSet(responseData, sfStatement, new CancellationToken());

arrowResultSet.Next();

Assert.AreEqual(true, arrowResultSet.IsDBNull(0));
Assert.AreEqual(DBNull.Value, arrowResultSet.GetValue(0));
}
Expand All @@ -151,7 +151,7 @@ public void TestGetDecimal()

TestGetNumber(testValues);
}

[Test]
public void TestGetNumber64()
{
Expand All @@ -164,7 +164,7 @@ public void TestGetNumber64()
public void TestGetNumber32()
{
var testValues = new int[] { 0, 100, -100, Int32.MaxValue, Int32.MinValue };

TestGetNumber(testValues);
}

Expand All @@ -175,7 +175,7 @@ public void TestGetNumber16()

TestGetNumber(testValues);
}

[Test]
public void TestGetNumber8()
{
Expand All @@ -199,7 +199,7 @@ private void TestGetNumber(IEnumerable testValues)
Assert.AreEqual(expectedValue, _arrowResultSet.GetDecimal(ColumnIndex));
Assert.AreEqual(expectedValue, _arrowResultSet.GetDouble(ColumnIndex));
Assert.AreEqual(expectedValue, _arrowResultSet.GetFloat(ColumnIndex));

if (expectedValue >= Int64.MinValue && expectedValue <= Int64.MaxValue)
{
// get integer value
Expand Down Expand Up @@ -229,7 +229,7 @@ public void TestGetBoolean()
var testValues = new bool[] { true, false };

PrepareTestCase(SFDataType.BOOLEAN, 0, testValues);

foreach (var testValue in testValues)
{
_arrowResultSet.Next();
Expand All @@ -244,15 +244,15 @@ public void TestGetReal()
var testValues = new double[] { 0, Double.MinValue, Double.MaxValue };

PrepareTestCase(SFDataType.REAL, 0, testValues);

foreach (var testValue in testValues)
{
_arrowResultSet.Next();
Assert.AreEqual(testValue, _arrowResultSet.GetValue(ColumnIndex));
Assert.AreEqual(testValue, _arrowResultSet.GetDouble(ColumnIndex));
}
}

[Test]
public void TestGetText()
{
Expand All @@ -263,15 +263,15 @@ public void TestGetText()
};

PrepareTestCase(SFDataType.TEXT, 0, testValues);

foreach (var testValue in testValues)
{
_arrowResultSet.Next();
Assert.AreEqual(testValue, _arrowResultSet.GetValue(ColumnIndex));
Assert.AreEqual(testValue, _arrowResultSet.GetString(ColumnIndex));
}
}

[Test]
public void TestGetTextWithOneChar()
{
Expand All @@ -289,14 +289,14 @@ public void TestGetTextWithOneChar()
#endif

PrepareTestCase(SFDataType.TEXT, 0, testValues);

foreach (var testValue in testValues)
{
_arrowResultSet.Next();
Assert.AreEqual(testValue, _arrowResultSet.GetChar(ColumnIndex));
}
}

[Test]
public void TestGetArray()
{
Expand All @@ -307,7 +307,7 @@ public void TestGetArray()
};

PrepareTestCase(SFDataType.ARRAY, 0, testValues);

foreach (var testValue in testValues)
{
_arrowResultSet.Next();
Expand All @@ -319,7 +319,7 @@ public void TestGetArray()
Assert.AreEqual(testValue.Length, str.Length);
}
}

[Test]
public void TestGetBinary()
{
Expand All @@ -341,7 +341,7 @@ public void TestGetBinary()
Assert.AreEqual(testValue[j], buffer[j], "position " + j);
}
}

[Test]
public void TestGetDate()
{
Expand All @@ -353,15 +353,15 @@ public void TestGetDate()
};

PrepareTestCase(SFDataType.DATE, 0, testValues);

foreach (var testValue in testValues)
{
_arrowResultSet.Next();
Assert.AreEqual(testValue, _arrowResultSet.GetValue(ColumnIndex));
Assert.AreEqual(testValue, _arrowResultSet.GetDateTime(ColumnIndex));
}
}

[Test]
public void TestGetTime()
{
Expand All @@ -383,7 +383,7 @@ public void TestGetTime()
Assert.AreEqual(testValue, _arrowResultSet.GetValue(ColumnIndex));
Assert.AreEqual(testValue, _arrowResultSet.GetDateTime(ColumnIndex));
}
}
}
}

[Test]
Expand Down Expand Up @@ -473,10 +473,10 @@ private QueryExecResponseData PrepareResponseData(RecordBatch recordBatch, SFDat
return new QueryExecResponseData
{
rowType = recordBatch.Schema.FieldsList
.Select(col =>
.Select(col =>
new ExecResponseRowType
{
name = col.Name,
name = col.Name,
type = sfType.ToString(),
scale = scale
}).ToList(),
Expand All @@ -491,7 +491,7 @@ private string ConvertToBase64String(RecordBatch recordBatch)
{
if (recordBatch == null)
return "";

using (var stream = new MemoryStream())
{
using (var writer = new ArrowStreamWriter(stream, recordBatch.Schema))
Expand All @@ -502,12 +502,12 @@ private string ConvertToBase64String(RecordBatch recordBatch)
return Convert.ToBase64String(stream.ToArray());
}
}

private SFStatement PrepareStatement()
{
SFSession session = new SFSession("user=user;password=password;account=account;", null);
SFSession session = new SFSession("user=user;password=password;account=account;", null, null);
return new SFStatement(session);
}

}
}
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ public void TestAuthPropertiesValid(string connectionString, string password)
var securePassword = string.IsNullOrEmpty(password) ? null : new NetworkCredential(string.Empty, password).SecurePassword;

// Act/Assert
Assert.DoesNotThrow(() => SFSessionProperties.ParseConnectionString(_necessaryNonAuthProperties + connectionString, securePassword));
Assert.DoesNotThrow(() => SFSessionProperties.ParseConnectionString(_necessaryNonAuthProperties + connectionString, securePassword, null));
}

[TestCase("authenticator=snowflake;", null, SFError.MISSING_CONNECTION_PROPERTY, "Error: Required property PASSWORD is not provided.")]
Expand All @@ -54,7 +54,7 @@ public void TestAuthPropertiesInvalid(string connectionString, string password,
var securePassword = string.IsNullOrEmpty(password) ? null : new NetworkCredential(string.Empty, password).SecurePassword;

// Act
var exception = Assert.Throws<SnowflakeDbException>(() => SFSessionProperties.ParseConnectionString(_necessaryNonAuthProperties + connectionString, securePassword));
var exception = Assert.Throws<SnowflakeDbException>(() => SFSessionProperties.ParseConnectionString(_necessaryNonAuthProperties + connectionString, securePassword, null));

// Assert
SnowflakeDbExceptionAssert.HasErrorCode(exception, expectedError);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ private QueryExecResponseData mockQueryRequestData()
private SFResultSet mockSFResultSet(QueryExecResponseData responseData, CancellationToken token)
{
string connectionString = "user=user;password=password;account=account;";
SFSession session = new SFSession(connectionString, null);
SFSession session = new SFSession(connectionString, null , null);
List<NameValueParameter> list = new List<NameValueParameter>
{
new NameValueParameter { name = "CLIENT_PREFETCH_THREADS", value = "3" }
Expand Down
28 changes: 14 additions & 14 deletions Snowflake.Data.Tests/UnitTests/ConnectionPoolManagerTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ public void TestDifferentPoolsAreReturnedForDifferentConnectionStrings()
public void TestGetSessionWorksForSpecifiedConnectionString()
{
// Act
var sfSession = _connectionPoolManager.GetSession(ConnectionString1, null);
var sfSession = _connectionPoolManager.GetSession(ConnectionString1, null, null);

// Assert
Assert.AreEqual(ConnectionString1, sfSession.ConnectionString);
Expand All @@ -122,7 +122,7 @@ public void TestGetSessionWorksForSpecifiedConnectionString()
public async Task TestGetSessionAsyncWorksForSpecifiedConnectionString()
{
// Act
var sfSession = await _connectionPoolManager.GetSessionAsync(ConnectionString1, null, CancellationToken.None);
var sfSession = await _connectionPoolManager.GetSessionAsync(ConnectionString1, null, null, CancellationToken.None);

// Assert
Assert.AreEqual(ConnectionString1, sfSession.ConnectionString);
Expand All @@ -133,7 +133,7 @@ public async Task TestGetSessionAsyncWorksForSpecifiedConnectionString()
public void TestCountingOfSessionProvidedByPool()
{
// Act
_connectionPoolManager.GetSession(ConnectionString1, null);
_connectionPoolManager.GetSession(ConnectionString1, null, null);

// Assert
var sessionPool = _connectionPoolManager.GetPool(ConnectionString1, null);
Expand All @@ -144,7 +144,7 @@ public void TestCountingOfSessionProvidedByPool()
public void TestCountingOfSessionReturnedBackToPool()
{
// Arrange
var sfSession = _connectionPoolManager.GetSession(ConnectionString1, null);
var sfSession = _connectionPoolManager.GetSession(ConnectionString1, null, null);

// Act
_connectionPoolManager.AddSession(sfSession);
Expand Down Expand Up @@ -285,8 +285,8 @@ public void TestGetMaxPoolSizeOnManagerLevelWhenAllPoolsEqual()
public void TestGetCurrentPoolSizeReturnsSumOfPoolSizes()
{
// Arrange
EnsurePoolSize(ConnectionString1, null, 2);
EnsurePoolSize(ConnectionString2, null, 3);
EnsurePoolSize(ConnectionString1, null, null,2);
EnsurePoolSize(ConnectionString2, null, null, 3);

// act
var poolSize = _connectionPoolManager.GetCurrentPoolSize();
Expand All @@ -300,7 +300,7 @@ public void TestReturnPoolForSecurePassword()
{
// arrange
const string AnotherPassword = "anotherPassword";
EnsurePoolSize(ConnectionStringWithoutPassword, _password3, 1);
EnsurePoolSize(ConnectionStringWithoutPassword, _password3, null, 1);

// act
var pool = _connectionPoolManager.GetPool(ConnectionStringWithoutPassword, SecureStringHelper.Encode(AnotherPassword)); // a new pool has been created because the password is different
Expand All @@ -315,9 +315,9 @@ public void TestReturnDifferentPoolWhenPasswordProvidedInDifferentWay()
{
// arrange
var connectionStringWithPassword = $"{ConnectionStringWithoutPassword}password={SecureStringHelper.Decode(_password3)}";
EnsurePoolSize(ConnectionStringWithoutPassword, _password3, 2);
EnsurePoolSize(connectionStringWithPassword, null, 5);
EnsurePoolSize(connectionStringWithPassword, _password3, 8);
EnsurePoolSize(ConnectionStringWithoutPassword, _password3, null, 2);
EnsurePoolSize(connectionStringWithPassword, null, null, 5);
EnsurePoolSize(connectionStringWithPassword, _password3, null, 8);

// act
var pool1 = _connectionPoolManager.GetPool(ConnectionStringWithoutPassword, _password3);
Expand Down Expand Up @@ -360,23 +360,23 @@ public void TestPoolDoesNotSerializePassword()
Assert.IsFalse(serializedPool.Contains(password));
}

private void EnsurePoolSize(string connectionString, SecureString password, int requiredCurrentSize)
private void EnsurePoolSize(string connectionString, SecureString password, SecureString passcode, int requiredCurrentSize)
{
var sessionPool = _connectionPoolManager.GetPool(connectionString, password);
sessionPool.SetMaxPoolSize(requiredCurrentSize);
for (var i = 0; i < requiredCurrentSize; i++)
{
_connectionPoolManager.GetSession(connectionString, password);
_connectionPoolManager.GetSession(connectionString, password, passcode);
}
Assert.AreEqual(requiredCurrentSize, sessionPool.GetCurrentPoolSize());
}
}

class MockSessionFactory : ISessionFactory
{
public SFSession NewSession(string connectionString, SecureString password)
public SFSession NewSession(string connectionString, SecureString password, SecureString passcode)
{
var mockSfSession = new Mock<SFSession>(connectionString, password);
var mockSfSession = new Mock<SFSession>(connectionString, password, passcode);
mockSfSession.Setup(x => x.Open()).Verifiable();
mockSfSession.Setup(x => x.OpenAsync(default)).Returns(Task.FromResult(this));
mockSfSession.Setup(x => x.IsNotOpen()).Returns(false);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ class SFAuthenticatorFactoryTest
private IAuthenticator GetAuthenticator(string authenticatorName, string extraParams = "")
{
string connectionString = $"account=test;user=test;password=test;authenticator={authenticatorName};{extraParams}";
SFSession session = new SFSession(connectionString, null);
SFSession session = new SFSession(connectionString, null, null);

return AuthenticatorFactory.GetAuthenticator(session);
}
Expand Down
Loading
Loading