Skip to content

Commit

Permalink
Change validation process for session pool, if using passcode in conn…
Browse files Browse the repository at this point in the history
…ection string without username_password_authentication an exception will be thrown to indicate the user

that the passcode should not be used if pooling is enabled or with a minimum pool size greater than 0.
Additionally, if the passcode is provided by an argument and not part of the connection string, it will not be used for the session created by the session pool, and the push MFA mechanism will be triggered.
  • Loading branch information
sfc-gh-jmartinezramirez committed Jul 11, 2024
1 parent 37a391d commit a8dff4b
Show file tree
Hide file tree
Showing 4 changed files with 85 additions and 55 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -79,5 +79,11 @@ public void setHttpClient(HttpClient httpClient)
{
// Nothing to do
}

public void Reset()
{
LoginRequests.Clear();
LoginResponses.Clear();
}
}
}
78 changes: 56 additions & 22 deletions Snowflake.Data.Tests/UnitTests/ConnectionPoolManagerMFATest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,25 +2,28 @@
* Copyright (c) 2024 Snowflake Computing Inc. All rights reserved.
*/

using System.Security;
using System.Threading;
using NUnit.Framework;
using Snowflake.Data.Core;
using Snowflake.Data.Core.Session;
using Snowflake.Data.Client;
using Snowflake.Data.Core.Tools;
using Snowflake.Data.Tests.Util;


namespace Snowflake.Data.Tests.UnitTests
{
using System;
using System.Linq;
using System.Security;
using System.Threading;
using Mock;

[TestFixture, NonParallelizable]
using NUnit.Framework;
using Snowflake.Data.Core;
using Snowflake.Data.Core.Session;
using Snowflake.Data.Client;
using Snowflake.Data.Core.Tools;
using Snowflake.Data.Tests.Util;

[TestFixture]
class ConnectionPoolManagerMFATest
{
private readonly ConnectionPoolManager _connectionPoolManager = new ConnectionPoolManager();
private const string ConnectionStringMFACache = "db=D1;warehouse=W1;account=A1;user=U1;password=P1;role=R1;minPoolSize=2;passcode=12345;authenticator=username_password_mfa";
private const string ConnectionStringMFABasicWithoutPasscode = "db=D2;warehouse=W2;account=A2;user=U2;password=P2;role=R2;minPoolSize=3;";
private static PoolConfig s_poolConfig;
private static MockLoginMFATokenCacheRestRequester s_restRequester;

Expand All @@ -44,6 +47,7 @@ public static void AfterAllTests()
public void BeforeEach()
{
_connectionPoolManager.ClearAllPools();
s_restRequester.Reset();
}

[Test]
Expand Down Expand Up @@ -79,31 +83,61 @@ public void TestPoolManagerReturnsSessionPoolForGivenConnectionStringUsingMFA()
Assert.AreEqual("passcode", loginRequest1.data.extAuthnDuoMethod);
}

[Test]
public void TestPoolManagerShouldOnlyUsePasscodeAsArgumentForFirstSessionWhenNotUsingMFAAuthenticator()
{
// Arrange
const string TestPasscode = "123456";
s_restRequester.LoginResponses.Enqueue(new LoginResponseData()
{
authResponseSessionInfo = new SessionInfo()
});
s_restRequester.LoginResponses.Enqueue(new LoginResponseData()
{
authResponseSessionInfo = new SessionInfo()
});
s_restRequester.LoginResponses.Enqueue(new LoginResponseData()
{
authResponseSessionInfo = new SessionInfo()
});
// Act
var session = _connectionPoolManager.GetSession(ConnectionStringMFABasicWithoutPasscode, null, SecureStringHelper.Encode(TestPasscode));
Thread.Sleep(3000);

// Assert

Assert.AreEqual(3, s_restRequester.LoginRequests.Count);
var request = s_restRequester.LoginRequests.ToList();
Assert.AreEqual(1, request.Count(r => r.data.extAuthnDuoMethod == "passcode" && r.data.passcode == TestPasscode));
Assert.AreEqual(2, request.Count(r => r.data.extAuthnDuoMethod == "push" && r.data.passcode == null));
}

[Test]
public void TestPoolManagerShouldThrowExceptionIfForcePoolingWithPasscodeNotUsingMFATokenCacheAuthenticator()
{
// Arrange
var connectionString = "db=D1;warehouse=W1;account=A1;user=U1;password=P1;role=R1;minPoolSize=2;passcode=12345;POOLINGENABLED=true";
// Act and assert
var thrown = Assert.Throws<Exception>(() =>_connectionPoolManager.GetSession(connectionString, null));
Assert.That(thrown.Message, Does.Contain("Could not get a pool because passcode was provided using a different authenticator than username_password_mfa"));
Assert.That(thrown.Message, Does.Contain("Could not use connection pool because passcode was provided using a different authenticator than username_password_mfa"));
}

[Test]
public void TestPoolManagerShouldDisablePoolingWhenPassingPasscodeNotUsingMFATokenCacheAuthenticator()
public void TestPoolManagerShouldNotThrowExceptionIfForcePoolingWithPasscodeNotUsingMFATokenCacheAuthenticator()
{
// Arrange
var connectionString = "db=D1;warehouse=W1;account=A1;user=U1;password=P1;role=R1;minPoolSize=2;passcode=12345;";
var pool = _connectionPoolManager.GetPool(connectionString);
// Act
var session = _connectionPoolManager.GetSession(connectionString, null);

// Asssert
// TODO: Review pool config is not the same for session and session pool
// Assert.IsFalse(session.GetPooling());
Assert.AreEqual(0, pool.GetCurrentPoolSize());
Assert.IsFalse(pool.GetPooling());
var connectionString = "db=D1;warehouse=W1;account=A1;user=U1;password=P1;role=R1;minPoolSize=2;passcode=12345;POOLINGENABLED=false";
// Act and assert
Assert.DoesNotThrow(() =>_connectionPoolManager.GetSession(connectionString, null));
}

[Test]
public void TestPoolManagerShouldNotThrowExceptionIfMinPoolSizeZeroNotUsingMFATokenCacheAuthenticator()
{
// Arrange
var connectionString = "db=D1;warehouse=W1;account=A1;user=U1;password=P1;role=R1;minPoolSize=0;passcode=12345;POOLINGENABLED=true";
// Act and assert
Assert.DoesNotThrow(() =>_connectionPoolManager.GetSession(connectionString, null));
}
}

Expand Down
4 changes: 0 additions & 4 deletions Snowflake.Data/Core/Session/SFSessionHttpClientProperties.cs
Original file line number Diff line number Diff line change
Expand Up @@ -60,10 +60,6 @@ public void DisablePoolingDefaultIfSecretsProvidedExternally(SFSessionProperties
&& !properties.IsNonEmptyValueProvided(SFSessionProperty.PRIVATE_KEY_PWD))
{
DisablePoolingIfNotExplicitlyEnabled(properties, "key pair with private key in a file");
} else if (!MFACacheAuthenticator.AUTH_NAME.Equals(authenticator)
&& properties.IsNonEmptyValueProvided(SFSessionProperty.PASSCODE))
{
DisablePoolingIfNotExplicitlyEnabled(properties, "mfa authentication without token cache");
}
}

Expand Down
52 changes: 23 additions & 29 deletions Snowflake.Data/Core/Session/SessionPool.cs
Original file line number Diff line number Diff line change
Expand Up @@ -145,8 +145,8 @@ internal SFSession GetSession(string connStr, SecureString password, SecureStrin
{
s_logger.Debug("SessionPool::GetSession" + PoolIdentification());
SFSession session = null;
var sessionProperties = SFSessionProperties.ParseConnectionString(connStr, password, passcode);
ValidatePoolingIfPasscodeProvided(passcode, sessionProperties);
var sessionProperties = SFSessionProperties.ParseConnectionString(connStr, password);
ValidatePoolingIfPasscodeProvided(sessionProperties);
if (!GetPooling())
return NewNonPoolingSession(connStr, password, passcode);
var sessionOrCreateTokens = GetIdleSession(connStr);
Expand All @@ -156,42 +156,37 @@ internal SFSession GetSession(string connStr, SecureString password, SecureStrin
{
_sessionPoolEventHandler.OnSessionProvided(this);
}
ScheduleNewIdleSessions(connStr, password, passcode, sessionOrCreateTokens.BackgroundSessionCreationTokens());
ScheduleNewIdleSessions(connStr, password, RegisterSessionCreationsToEnsureMinPoolSize());
WarnAboutOverridenConfig();
return session ?? sessionOrCreateTokens.Session ?? NewSession(connStr, password, passcode, sessionOrCreateTokens.SessionCreationToken());
}

private void ValidatePoolingIfPasscodeProvided(SecureString passcode, SFSessionProperties sessionProperties)
private void ValidatePoolingIfPasscodeProvided(SFSessionProperties sessionProperties)
{
if (!GetPooling()) return;
var isUsingPasscode = ((passcode != null && !SecureStringHelper.Decode(passcode).IsNullOrEmpty()) ||
sessionProperties.IsNonEmptyValueProvided(SFSessionProperty.PASSCODE) ||
if (!GetPooling() || _poolConfig.MinPoolSize == 0) return;
var isUsingPasscode = (sessionProperties.IsNonEmptyValueProvided(SFSessionProperty.PASSCODE) ||
(sessionProperties.TryGetValue(SFSessionProperty.PASSCODEINPASSWORD, out var passcodeInPasswordValue) &&
bool.TryParse(passcodeInPasswordValue, out var isPasscodeinPassword) && isPasscodeinPassword));
if(!isUsingPasscode) return;
var isMfaAuthenticator = sessionProperties.TryGetValue(SFSessionProperty.AUTHENTICATOR, out var authenticator) &&
authenticator == MFACacheAuthenticator.AUTH_NAME;

if (isMfaAuthenticator) return;
if (sessionProperties.IsPoolingEnabledValueProvided)
if(isUsingPasscode && !isMfaAuthenticator)
{
const string ErrorMessage = "Could not get a pool because passcode was provided using a different authenticator than username_password_mfa";
const string ErrorMessage = "Could not use connection pool because passcode was provided using a different authenticator than username_password_mfa";
s_logger.Error(ErrorMessage + PoolIdentification());
throw new Exception(ErrorMessage);
}
s_logger.Warn("Pooling is disabled because passcode was provided using a different authenticator than username_password_mfa" + PoolIdentification());
_poolConfig.PoolingEnabled = false;
}

internal async Task<SFSession> GetSessionAsync(string connStr, SecureString password, SecureString passcode, CancellationToken cancellationToken)
{
s_logger.Debug("SessionPool::GetSessionAsync" + PoolIdentification());
SFSession session = null;
var sessionProperties = SFSessionProperties.ParseConnectionString(connStr, password, passcode);
ValidatePoolingIfPasscodeProvided(passcode, sessionProperties);
var sessionProperties = SFSessionProperties.ParseConnectionString(connStr, password);
ValidatePoolingIfPasscodeProvided(sessionProperties);
if (!GetPooling())
return await NewNonPoolingSessionAsync(connStr, password, passcode, cancellationToken).ConfigureAwait(false);
var sessionOrCreateTokens = GetIdleSession(connStr);
WarnAboutOverridenConfig();
if (sessionProperties.TryGetValue(SFSessionProperty.AUTHENTICATOR, out var authenticator) &&
authenticator == MFACacheAuthenticator.AUTH_NAME)
session = sessionOrCreateTokens.Session ??
Expand All @@ -201,21 +196,20 @@ await NewSessionAsync(connStr, password, passcode, sessionOrCreateTokens.Session
{
_sessionPoolEventHandler.OnSessionProvided(this);
}
ScheduleNewIdleSessions(connStr, password, passcode, sessionOrCreateTokens.BackgroundSessionCreationTokens());
WarnAboutOverridenConfig();
ScheduleNewIdleSessions(connStr, password, RegisterSessionCreationsToEnsureMinPoolSize());
return session ?? sessionOrCreateTokens.Session ?? await NewSessionAsync(connStr, password, passcode, sessionOrCreateTokens.SessionCreationToken(), cancellationToken).ConfigureAwait(false);
}

private void ScheduleNewIdleSessions(string connStr, SecureString password, SecureString passcode, List<SessionCreationToken> tokens)
private void ScheduleNewIdleSessions(string connStr, SecureString password, List<SessionCreationToken> tokens)
{
tokens.ForEach(token => ScheduleNewIdleSession(connStr, password, passcode, token));
tokens.ForEach(token => ScheduleNewIdleSession(connStr, password, token));
}

private void ScheduleNewIdleSession(string connStr, SecureString password, SecureString passcode, SessionCreationToken token)
private void ScheduleNewIdleSession(string connStr, SecureString password, SessionCreationToken token)
{
Task.Run(() =>
{
var session = NewSession(connStr, password, passcode, token);
var session = NewSession(connStr, password, null, token);
AddSession(session, false); // we don't want to ensure min pool size here because we could get into infinite recursion if expirationTimeout would be very low
});
}
Expand Down Expand Up @@ -258,7 +252,7 @@ private SessionOrCreationTokens GetIdleSession(string connStr)
return new SessionOrCreationTokens(session);
}
s_logger.Debug("SessionPool::GetIdleSession - no thread was waiting for a session, but could not find any idle session available in the pool" + PoolIdentification());
var sessionsCount = AllowedNumberOfNewSessionCreations(1);
var sessionsCount = Math.Min(1, AllowedNumberOfNewSessionCreations(1));
if (sessionsCount > 0)
{
// there is no need to wait for a session since we can create new ones
Expand All @@ -269,7 +263,7 @@ private SessionOrCreationTokens GetIdleSession(string connStr)
return new SessionOrCreationTokens(WaitForSession(connStr));
}

private List<SessionCreationToken> RegisterSessionCreationsWhenReturningSessionToPool()
private List<SessionCreationToken> RegisterSessionCreationsToEnsureMinPoolSize()
{
var count = AllowedNumberOfNewSessionCreations(0);
return RegisterSessionCreations(count);
Expand Down Expand Up @@ -501,15 +495,15 @@ internal bool AddSession(SFSession session, bool ensureMinPoolSize)
ReleaseBusySession(session);
if (ensureMinPoolSize)
{
ScheduleNewIdleSessions(ConnectionString, Password, session.Passcode, RegisterSessionCreationsWhenReturningSessionToPool()); // passcode is probably not fresh - it could be improved
ScheduleNewIdleSessions(ConnectionString, Password, RegisterSessionCreationsToEnsureMinPoolSize()); // passcode is probably not fresh - it could be improved
}
return false;
}

var result = ReturnSessionToPool(session, ensureMinPoolSize);
var wasSessionReturnedToPool = result.Item1;
var sessionCreationTokens = result.Item2;
ScheduleNewIdleSessions(ConnectionString, Password, session.Passcode, sessionCreationTokens); // passcode is probably not fresh - it could be improved
ScheduleNewIdleSessions(ConnectionString, Password, sessionCreationTokens);
return wasSessionReturnedToPool;
}

Expand All @@ -522,7 +516,7 @@ private Tuple<bool, List<SessionCreationToken>> ReturnSessionToPool(SFSession se
{
_busySessionsCounter.Decrease();
var sessionCreationTokens = ensureMinPoolSize
? RegisterSessionCreationsWhenReturningSessionToPool()
? RegisterSessionCreationsToEnsureMinPoolSize()
: SessionOrCreationTokens.s_emptySessionCreationTokenList;
var poolState = GetCurrentState();
s_logger.Debug($"Could not return session to pool {poolState}" + PoolIdentification());
Expand All @@ -537,7 +531,7 @@ private Tuple<bool, List<SessionCreationToken>> ReturnSessionToPool(SFSession se
if (session.IsExpired(_poolConfig.ExpirationTimeout, DateTimeOffset.UtcNow.ToUnixTimeMilliseconds())) // checking again because we could have spent some time waiting for a lock
{
var sessionCreationTokens = ensureMinPoolSize
? RegisterSessionCreationsWhenReturningSessionToPool()
? RegisterSessionCreationsToEnsureMinPoolSize()
: SessionOrCreationTokens.s_emptySessionCreationTokenList;
var poolState = GetCurrentState();
s_logger.Debug($"Could not return session to pool {poolState}" + PoolIdentification());
Expand All @@ -552,7 +546,7 @@ private Tuple<bool, List<SessionCreationToken>> ReturnSessionToPool(SFSession se
_idleSessions.Add(session);
_waitingForIdleSessionQueue.OnResourceIncrease();
var sessionCreationTokensAfterReturningToPool = ensureMinPoolSize
? RegisterSessionCreationsWhenReturningSessionToPool()
? RegisterSessionCreationsToEnsureMinPoolSize()
: SessionOrCreationTokens.s_emptySessionCreationTokenList;
var poolStateAfterReturningToPool = GetCurrentState();
s_logger.Debug($"returned session with sid {session.sessionId} to pool {poolStateAfterReturningToPool}" + PoolIdentification());
Expand Down

0 comments on commit a8dff4b

Please sign in to comment.