diff --git a/Snowflake.Data.Tests/IntegrationTests/SFConnectionIT.cs b/Snowflake.Data.Tests/IntegrationTests/SFConnectionIT.cs index 554d0c2a9..e6ffcd7c9 100644 --- a/Snowflake.Data.Tests/IntegrationTests/SFConnectionIT.cs +++ b/Snowflake.Data.Tests/IntegrationTests/SFConnectionIT.cs @@ -2,25 +2,25 @@ * Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. */ +using System; +using System.Data; using System.Data.Common; +using System.Diagnostics; using System.Net; +using System.Runtime.InteropServices; +using System.Threading; +using System.Threading.Tasks; +using NUnit.Framework; +using Snowflake.Data.Client; +using Snowflake.Data.Core; using Snowflake.Data.Core.Session; +using Snowflake.Data.Core.Tools; +using Snowflake.Data.Log; +using Snowflake.Data.Tests.Mock; using Snowflake.Data.Tests.Util; namespace Snowflake.Data.Tests.IntegrationTests { - using NUnit.Framework; - using Snowflake.Data.Client; - using System.Data; - using System; - using Snowflake.Data.Core; - using System.Threading.Tasks; - using System.Threading; - using Snowflake.Data.Log; - using System.Diagnostics; - using Snowflake.Data.Tests.Mock; - using System.Runtime.InteropServices; - using System.Net.Http; [TestFixture] class SFConnectionIT : SFBaseTest @@ -2272,6 +2272,52 @@ public void TestUseMultiplePoolsConnectionPoolByDefault() Assert.AreEqual(ConnectionPoolType.MultipleConnectionPool, poolVersion); } + [Test] + // [Ignore("This test requires manual interaction and therefore cannot be run in CI")] + public void TestMFATokenCachingWithPasscodeFromConnectionString() + { + // Use a connection with MFA enabled and set passcode property for mfa authentication. e.g. ConnectionString + ";authenticator=username_password_mfa;passcode=(set proper passcode)" + // ACCOUNT PARAMETER ALLOW_CLIENT_MFA_CACHING should be set to true in the account. + // On Mac/Linux OS default credential manager is in memory so please uncomment following line to use file based credential manager + // SnowflakeCredentialManagerFactory.UseFileCredentialManager(); + using (SnowflakeDbConnection conn = new SnowflakeDbConnection()) + { + conn.ConnectionString + = ConnectionString + + ";authenticator=username_password_mfa;application=DuoTest;minPoolSize=0;passcode=(set proper passcode)"; + + + // Authenticate to retrieve and store the token if doesn't exist or invalid + Task connectTask = conn.OpenAsync(CancellationToken.None); + connectTask.Wait(); + Assert.AreEqual(ConnectionState.Open, conn.State); + } + } + + [Test] + [Ignore("Requires manual steps and environment with mfa authentication enrolled")] // to enroll to mfa authentication edit your user profile + public void TestMfaWithPasswordConnectionUsingPasscodeWithSecureString() + { + // Use a connection with MFA enabled and Passcode property on connection instance. + // ACCOUNT PARAMETER ALLOW_CLIENT_MFA_CACHING should be set to true in the account. + // On Mac/Linux OS default credential manager is in memory so please uncomment following line to use file based credential manager + // SnowflakeCredentialManagerFactory.UseFileCredentialManager(); + // arrange + using (SnowflakeDbConnection conn = new SnowflakeDbConnection()) + { + conn.Passcode = SecureStringHelper.Encode("$(set proper passcode)"); + // manual action: stop here in breakpoint to provide proper passcode by: conn.Passcode = SecureStringHelper.Encode("..."); + conn.ConnectionString = ConnectionString + "minPoolSize=2;application=DuoTest;"; + + // act + Task connectTask = conn.OpenAsync(CancellationToken.None); + connectTask.Wait(); + + // assert + Assert.AreEqual(ConnectionState.Open, conn.State); + } + } + [Test] [TestCase("connection_timeout=5;")] [TestCase("")] diff --git a/Snowflake.Data.Tests/Mock/MockLoginMFATokenCacheRestRequester.cs b/Snowflake.Data.Tests/Mock/MockLoginMFATokenCacheRestRequester.cs new file mode 100644 index 000000000..163124b7d --- /dev/null +++ b/Snowflake.Data.Tests/Mock/MockLoginMFATokenCacheRestRequester.cs @@ -0,0 +1,89 @@ +using System; +using System.Collections.Generic; +using System.Net.Http; +using System.Threading; +using System.Threading.Tasks; +using Snowflake.Data.Core; + +namespace Snowflake.Data.Tests.Mock +{ + using Microsoft.IdentityModel.Tokens; + + class MockLoginMFATokenCacheRestRequester: IMockRestRequester + { + internal Queue LoginRequests { get; } = new(); + + internal Queue LoginResponses { get; } = new(); + + public T Get(IRestRequest request) + { + return Task.Run(async () => await (GetAsync(request, CancellationToken.None)).ConfigureAwait(false)).Result; + } + + public Task GetAsync(IRestRequest request, CancellationToken cancellationToken) + { + return Task.FromResult((T)(object)null); + } + + public Task GetAsync(IRestRequest request, CancellationToken cancellationToken) + { + return Task.FromResult(null); + } + + public HttpResponseMessage Get(IRestRequest request) + { + return null; + } + + public T Post(IRestRequest postRequest) + { + return Task.Run(async () => await (PostAsync(postRequest, CancellationToken.None)).ConfigureAwait(false)).Result; + } + + public Task PostAsync(IRestRequest postRequest, CancellationToken cancellationToken) + { + SFRestRequest sfRequest = (SFRestRequest)postRequest; + if (sfRequest.jsonBody is LoginRequest) + { + LoginRequests.Enqueue((LoginRequest) sfRequest.jsonBody); + var responseData = this.LoginResponses.IsNullOrEmpty() ? new LoginResponseData() + { + token = "session_token", + masterToken = "master_token", + authResponseSessionInfo = new SessionInfo(), + nameValueParameter = new List() + } : this.LoginResponses.Dequeue(); + var authnResponse = new LoginResponse + { + data = responseData, + success = true + }; + + // login request return success + return Task.FromResult((T)(object)authnResponse); + } + else if (sfRequest.jsonBody is CloseResponse) + { + var authnResponse = new CloseResponse() + { + success = true + }; + + // login request return success + return Task.FromResult((T)(object)authnResponse); + } + throw new NotImplementedException(); + } + + public void setHttpClient(HttpClient httpClient) + { + // Nothing to do + } + + public void Reset() + { + LoginRequests.Clear(); + LoginResponses.Clear(); + } + } +} diff --git a/Snowflake.Data.Tests/Mock/MockSnowflakeDbConnection.cs b/Snowflake.Data.Tests/Mock/MockSnowflakeDbConnection.cs index c6d8f0698..2f7d0efc0 100644 --- a/Snowflake.Data.Tests/Mock/MockSnowflakeDbConnection.cs +++ b/Snowflake.Data.Tests/Mock/MockSnowflakeDbConnection.cs @@ -78,10 +78,10 @@ public override Task OpenAsync(CancellationToken cancellationToken) cancellationToken); } - + private void SetMockSession() { - SfSession = new SFSession(ConnectionString, Password, _restRequester); + SfSession = new SFSession(ConnectionString, Password, Passcode, EasyLoggingStarter.Instance, _restRequester); _connectionTimeout = (int)SfSession.connectionTimeout.TotalSeconds; @@ -92,7 +92,7 @@ private void OnSessionEstablished() { _connectionState = ConnectionState.Open; } - + protected override bool CanReuseSession(TransactionRollbackStatus transactionRollbackStatus) { return false; diff --git a/Snowflake.Data.Tests/UnitTests/ArrowResultSetTest.cs b/Snowflake.Data.Tests/UnitTests/ArrowResultSetTest.cs index 0405c7009..8c385ad95 100755 --- a/Snowflake.Data.Tests/UnitTests/ArrowResultSetTest.cs +++ b/Snowflake.Data.Tests/UnitTests/ArrowResultSetTest.cs @@ -33,7 +33,7 @@ public void BeforeTest() // by default generate Int32 values from 1 to RowCount PrepareTestCase(SFDataType.FIXED, 0, Enumerable.Range(1, RowCount).ToArray()); } - + [Test] public void TestResultFormatIsArrow() { @@ -140,7 +140,7 @@ public void TestGetValueReturnsNull() var arrowResultSet = new ArrowResultSet(responseData, sfStatement, new CancellationToken()); arrowResultSet.Next(); - + Assert.AreEqual(true, arrowResultSet.IsDBNull(0)); Assert.AreEqual(DBNull.Value, arrowResultSet.GetValue(0)); } @@ -152,7 +152,7 @@ public void TestGetDecimal() TestGetNumber(testValues); } - + [Test] public void TestGetNumber64() { @@ -165,7 +165,7 @@ public void TestGetNumber64() public void TestGetNumber32() { var testValues = new int[] { 0, 100, -100, Int32.MaxValue, Int32.MinValue }; - + TestGetNumber(testValues); } @@ -176,7 +176,7 @@ public void TestGetNumber16() TestGetNumber(testValues); } - + [Test] public void TestGetNumber8() { @@ -200,7 +200,7 @@ private void TestGetNumber(IEnumerable testValues) Assert.AreEqual(expectedValue, _arrowResultSet.GetDecimal(ColumnIndex)); Assert.AreEqual(expectedValue, _arrowResultSet.GetDouble(ColumnIndex)); Assert.AreEqual(expectedValue, _arrowResultSet.GetFloat(ColumnIndex)); - + if (expectedValue >= Int64.MinValue && expectedValue <= Int64.MaxValue) { // get integer value @@ -230,7 +230,7 @@ public void TestGetBoolean() var testValues = new bool[] { true, false }; PrepareTestCase(SFDataType.BOOLEAN, 0, testValues); - + foreach (var testValue in testValues) { _arrowResultSet.Next(); @@ -245,7 +245,7 @@ public void TestGetReal() var testValues = new double[] { 0, Double.MinValue, Double.MaxValue }; PrepareTestCase(SFDataType.REAL, 0, testValues); - + foreach (var testValue in testValues) { _arrowResultSet.Next(); @@ -253,7 +253,7 @@ public void TestGetReal() Assert.AreEqual(testValue, _arrowResultSet.GetDouble(ColumnIndex)); } } - + [Test] public void TestGetText() { @@ -264,7 +264,7 @@ public void TestGetText() }; PrepareTestCase(SFDataType.TEXT, 0, testValues); - + foreach (var testValue in testValues) { _arrowResultSet.Next(); @@ -272,7 +272,7 @@ public void TestGetText() Assert.AreEqual(testValue, _arrowResultSet.GetString(ColumnIndex)); } } - + [Test] public void TestGetTextWithOneChar() { @@ -290,14 +290,14 @@ public void TestGetTextWithOneChar() #endif PrepareTestCase(SFDataType.TEXT, 0, testValues); - + foreach (var testValue in testValues) { _arrowResultSet.Next(); Assert.AreEqual(testValue, _arrowResultSet.GetChar(ColumnIndex)); } } - + [Test] public void TestGetArray() { @@ -308,7 +308,7 @@ public void TestGetArray() }; PrepareTestCase(SFDataType.ARRAY, 0, testValues); - + foreach (var testValue in testValues) { _arrowResultSet.Next(); @@ -320,7 +320,7 @@ public void TestGetArray() Assert.AreEqual(testValue.Length, str.Length); } } - + [Test] public void TestGetBinary() { @@ -342,7 +342,7 @@ public void TestGetBinary() Assert.AreEqual(testValue[j], buffer[j], "position " + j); } } - + [Test] public void TestGetDate() { @@ -354,7 +354,7 @@ public void TestGetDate() }; PrepareTestCase(SFDataType.DATE, 0, testValues); - + foreach (var testValue in testValues) { _arrowResultSet.Next(); @@ -362,7 +362,7 @@ public void TestGetDate() Assert.AreEqual(testValue, _arrowResultSet.GetDateTime(ColumnIndex)); } } - + [Test] public void TestGetTime() { @@ -384,7 +384,7 @@ public void TestGetTime() Assert.AreEqual(testValue, _arrowResultSet.GetValue(ColumnIndex)); Assert.AreEqual(testValue, _arrowResultSet.GetDateTime(ColumnIndex)); } - } + } } [Test] @@ -513,10 +513,10 @@ private QueryExecResponseData PrepareResponseData(RecordBatch recordBatch, SFDat return new QueryExecResponseData { rowType = recordBatch.Schema.FieldsList - .Select(col => + .Select(col => new ExecResponseRowType { - name = col.Name, + name = col.Name, type = sfType.ToString(), scale = scale }).ToList(), @@ -531,7 +531,7 @@ private string ConvertToBase64String(RecordBatch recordBatch) { if (recordBatch == null) return ""; - + using (var stream = new MemoryStream()) { using (var writer = new ArrowStreamWriter(stream, recordBatch.Schema)) @@ -542,12 +542,12 @@ private string ConvertToBase64String(RecordBatch recordBatch) return Convert.ToBase64String(stream.ToArray()); } } - + private SFStatement PrepareStatement() { SFSession session = new SFSession("user=user;password=password;account=account;", null); return new SFStatement(session); } - + } } diff --git a/Snowflake.Data.Tests/UnitTests/ConnectionPoolManagerMFATest.cs b/Snowflake.Data.Tests/UnitTests/ConnectionPoolManagerMFATest.cs new file mode 100644 index 000000000..a739e759e --- /dev/null +++ b/Snowflake.Data.Tests/UnitTests/ConnectionPoolManagerMFATest.cs @@ -0,0 +1,136 @@ +/* + * Copyright (c) 2024 Snowflake Computing Inc. All rights reserved. + */ + + + +namespace Snowflake.Data.Tests.UnitTests +{ + using System; + using System.Linq; + using System.Security; + using Mock; + using NUnit.Framework; + using Snowflake.Data.Core; + using Snowflake.Data.Core.Session; + using Snowflake.Data.Client; + using Snowflake.Data.Core.Tools; + using Snowflake.Data.Tests.Util; + + [TestFixture, NonParallelizable] + class ConnectionPoolManagerMFATest + { + private readonly ConnectionPoolManager _connectionPoolManager = new ConnectionPoolManager(); + private const string ConnectionStringMFACache = "db=D1;warehouse=W1;account=A1;user=U1;password=P1;role=R1;minPoolSize=2;passcode=12345;authenticator=username_password_mfa"; + private static PoolConfig s_poolConfig; + private static MockLoginMFATokenCacheRestRequester s_restRequester; + + [OneTimeSetUp] + public static void BeforeAllTests() + { + s_poolConfig = new PoolConfig(); + s_restRequester = new MockLoginMFATokenCacheRestRequester(); + SnowflakeDbConnectionPool.ForceConnectionPoolVersion(ConnectionPoolType.MultipleConnectionPool); + SessionPool.SessionFactory = new MockSessionFactoryMFA(s_restRequester); + } + + [OneTimeTearDown] + public static void AfterAllTests() + { + s_poolConfig.Reset(); + SessionPool.SessionFactory = new SessionFactory(); + } + + [SetUp] + public void BeforeEach() + { + _connectionPoolManager.ClearAllPools(); + s_restRequester.Reset(); + } + + [Test] + public void TestPoolManagerReturnsSessionPoolForGivenConnectionStringUsingMFA() + { + // Arrange + var testToken = "testToken1234"; + s_restRequester.LoginResponses.Enqueue(new LoginResponseData() + { + mfaToken = testToken, + authResponseSessionInfo = new SessionInfo() + }); + s_restRequester.LoginResponses.Enqueue(new LoginResponseData() + { + mfaToken = testToken, + authResponseSessionInfo = new SessionInfo() + }); + // Act + var session = _connectionPoolManager.GetSession(ConnectionStringMFACache, null, null); + + // Assert + Awaiter.WaitUntilConditionOrTimeout(() => s_restRequester.LoginRequests.Count == 2, TimeSpan.FromSeconds(15)); + Assert.AreEqual(2, s_restRequester.LoginRequests.Count); + var loginRequest1 = s_restRequester.LoginRequests.Dequeue(); + Assert.AreEqual(string.Empty, loginRequest1.data.Token); + Assert.AreEqual(testToken, SecureStringHelper.Decode(session._mfaToken)); + Assert.IsTrue(loginRequest1.data.SessionParameters.TryGetValue(SFSessionParameter.CLIENT_REQUEST_MFA_TOKEN, out var value) && (bool)value); + Assert.AreEqual("passcode", loginRequest1.data.extAuthnDuoMethod); + var loginRequest2 = s_restRequester.LoginRequests.Dequeue(); + Assert.AreEqual(testToken, loginRequest2.data.Token); + Assert.IsTrue(loginRequest2.data.SessionParameters.TryGetValue(SFSessionParameter.CLIENT_REQUEST_MFA_TOKEN, out var value1) && (bool)value1); + Assert.AreEqual("passcode", loginRequest2.data.extAuthnDuoMethod); + } + + [Test] + public void TestPoolManagerShouldThrowExceptionIfForcePoolingWithPasscodeNotUsingMFATokenCacheAuthenticator() + { + // Arrange + var connectionString = "db=D1;warehouse=W1;account=A1;user=U1;password=P1;role=R1;minPoolSize=2;passcode=12345;POOLINGENABLED=true"; + // Act and assert + var thrown = Assert.Throws(() =>_connectionPoolManager.GetSession(connectionString, null,null)); + Assert.That(thrown.Message, Does.Contain("Passcode with MinPoolSize feature of connection pool allowed only for username_password_mfa authentication")); + } + + [Test] + public void TestPoolManagerShouldThrowExceptionIfForcePoolingWithPasscodeAsSecureStringNotUsingMFATokenCacheAuthenticator() + { + // Arrange + var connectionString = "db=D1;warehouse=W1;account=A1;user=U1;password=P1;role=R1;minPoolSize=2;POOLINGENABLED=true"; + // Act and assert + var thrown = Assert.Throws(() =>_connectionPoolManager.GetSession(connectionString, null,SecureStringHelper.Encode("12345"))); + Assert.That(thrown.Message, Does.Contain("Passcode with MinPoolSize feature of connection pool allowed only for username_password_mfa authentication")); + } + + [Test] + public void TestPoolManagerShouldNotThrowExceptionIfForcePoolingWithPasscodeNotUsingMFATokenCacheAuthenticator() + { + // Arrange + var connectionString = "db=D1;warehouse=W1;account=A1;user=U1;password=P1;role=R1;minPoolSize=2;passcode=12345;POOLINGENABLED=false"; + // Act and assert + Assert.DoesNotThrow(() =>_connectionPoolManager.GetSession(connectionString, null, null)); + } + + [Test] + public void TestPoolManagerShouldNotThrowExceptionIfMinPoolSizeZeroNotUsingMFATokenCacheAuthenticator() + { + // Arrange + var connectionString = "db=D1;warehouse=W1;account=A1;user=U1;password=P1;role=R1;minPoolSize=0;passcode=12345;POOLINGENABLED=true"; + // Act and assert + Assert.DoesNotThrow(() =>_connectionPoolManager.GetSession(connectionString, null, null)); + } + } + + class MockSessionFactoryMFA : ISessionFactory + { + private readonly IMockRestRequester restRequester; + + public MockSessionFactoryMFA(IMockRestRequester restRequester) + { + this.restRequester = restRequester; + } + + public SFSession NewSession(string connectionString, SecureString password, SecureString passcode) + { + return new SFSession(connectionString, password, passcode, EasyLoggingStarter.Instance, restRequester); + } + } +} diff --git a/Snowflake.Data.Tests/UnitTests/ConnectionPoolManagerTest.cs b/Snowflake.Data.Tests/UnitTests/ConnectionPoolManagerTest.cs index b53487d60..0293d6571 100644 --- a/Snowflake.Data.Tests/UnitTests/ConnectionPoolManagerTest.cs +++ b/Snowflake.Data.Tests/UnitTests/ConnectionPoolManagerTest.cs @@ -111,7 +111,7 @@ public void TestDifferentPoolsAreReturnedForDifferentConnectionStrings() public void TestGetSessionWorksForSpecifiedConnectionString() { // Act - var sfSession = _connectionPoolManager.GetSession(ConnectionString1, null); + var sfSession = _connectionPoolManager.GetSession(ConnectionString1, null, null); // Assert Assert.AreEqual(ConnectionString1, sfSession.ConnectionString); @@ -122,7 +122,7 @@ public void TestGetSessionWorksForSpecifiedConnectionString() public async Task TestGetSessionAsyncWorksForSpecifiedConnectionString() { // Act - var sfSession = await _connectionPoolManager.GetSessionAsync(ConnectionString1, null, CancellationToken.None); + var sfSession = await _connectionPoolManager.GetSessionAsync(ConnectionString1, null, null, CancellationToken.None); // Assert Assert.AreEqual(ConnectionString1, sfSession.ConnectionString); @@ -133,7 +133,7 @@ public async Task TestGetSessionAsyncWorksForSpecifiedConnectionString() public void TestCountingOfSessionProvidedByPool() { // Act - _connectionPoolManager.GetSession(ConnectionString1, null); + _connectionPoolManager.GetSession(ConnectionString1, null, null); // Assert var sessionPool = _connectionPoolManager.GetPool(ConnectionString1, null); @@ -144,7 +144,7 @@ public void TestCountingOfSessionProvidedByPool() public void TestCountingOfSessionReturnedBackToPool() { // Arrange - var sfSession = _connectionPoolManager.GetSession(ConnectionString1, null); + var sfSession = _connectionPoolManager.GetSession(ConnectionString1, null, null); // Act _connectionPoolManager.AddSession(sfSession); @@ -285,8 +285,8 @@ public void TestGetMaxPoolSizeOnManagerLevelWhenAllPoolsEqual() public void TestGetCurrentPoolSizeReturnsSumOfPoolSizes() { // Arrange - EnsurePoolSize(ConnectionString1, null, 2); - EnsurePoolSize(ConnectionString2, null, 3); + EnsurePoolSize(ConnectionString1, null, null,2); + EnsurePoolSize(ConnectionString2, null, null, 3); // act var poolSize = _connectionPoolManager.GetCurrentPoolSize(); @@ -300,7 +300,7 @@ public void TestReturnPoolForSecurePassword() { // arrange const string AnotherPassword = "anotherPassword"; - EnsurePoolSize(ConnectionStringWithoutPassword, _password3, 1); + EnsurePoolSize(ConnectionStringWithoutPassword, _password3, null, 1); // act var pool = _connectionPoolManager.GetPool(ConnectionStringWithoutPassword, SecureStringHelper.Encode(AnotherPassword)); // a new pool has been created because the password is different @@ -315,9 +315,9 @@ public void TestReturnDifferentPoolWhenPasswordProvidedInDifferentWay() { // arrange var connectionStringWithPassword = $"{ConnectionStringWithoutPassword}password={SecureStringHelper.Decode(_password3)}"; - EnsurePoolSize(ConnectionStringWithoutPassword, _password3, 2); - EnsurePoolSize(connectionStringWithPassword, null, 5); - EnsurePoolSize(connectionStringWithPassword, _password3, 8); + EnsurePoolSize(ConnectionStringWithoutPassword, _password3, null, 2); + EnsurePoolSize(connectionStringWithPassword, null, null, 5); + EnsurePoolSize(connectionStringWithPassword, _password3, null, 8); // act var pool1 = _connectionPoolManager.GetPool(ConnectionStringWithoutPassword, _password3); @@ -360,13 +360,13 @@ public void TestPoolDoesNotSerializePassword() Assert.IsFalse(serializedPool.Contains(password)); } - private void EnsurePoolSize(string connectionString, SecureString password, int requiredCurrentSize) + private void EnsurePoolSize(string connectionString, SecureString password, SecureString passcode, int requiredCurrentSize) { var sessionPool = _connectionPoolManager.GetPool(connectionString, password); sessionPool.SetMaxPoolSize(requiredCurrentSize); for (var i = 0; i < requiredCurrentSize; i++) { - _connectionPoolManager.GetSession(connectionString, password); + _connectionPoolManager.GetSession(connectionString, password, passcode); } Assert.AreEqual(requiredCurrentSize, sessionPool.GetCurrentPoolSize()); } @@ -374,9 +374,9 @@ private void EnsurePoolSize(string connectionString, SecureString password, int class MockSessionFactory : ISessionFactory { - public SFSession NewSession(string connectionString, SecureString password) + public SFSession NewSession(string connectionString, SecureString password, SecureString passcode) { - var mockSfSession = new Mock(connectionString, password); + var mockSfSession = new Mock(connectionString, password, passcode, EasyLoggingStarter.Instance); mockSfSession.Setup(x => x.Open()).Verifiable(); mockSfSession.Setup(x => x.OpenAsync(default)).Returns(Task.FromResult(this)); mockSfSession.Setup(x => x.IsNotOpen()).Returns(false); diff --git a/Snowflake.Data.Tests/UnitTests/CredentialManager/SFCredentialManagerTest.cs b/Snowflake.Data.Tests/UnitTests/CredentialManager/SFCredentialManagerTest.cs new file mode 100644 index 000000000..2b5e7fa4f --- /dev/null +++ b/Snowflake.Data.Tests/UnitTests/CredentialManager/SFCredentialManagerTest.cs @@ -0,0 +1,297 @@ +/* + * Copyright (c) 2024 Snowflake Computing Inc. All rights reserved. + */ + +using System; +using System.IO; +using System.Runtime.InteropServices; +using Mono.Unix; +using Mono.Unix.Native; +using Moq; +using NUnit.Framework; +using Snowflake.Data.Client; +using Snowflake.Data.Core.CredentialManager.Infrastructure; +using Snowflake.Data.Core.Tools; + +namespace Snowflake.Data.Tests.UnitTests.CredentialManager +{ + public abstract class SFBaseCredentialManagerTest + { + protected ISnowflakeCredentialManager _credentialManager; + + [Test] + public void TestSavingAndRemovingCredentials() + { + // arrange + var key = "mockKey"; + var expectedToken = "token"; + + // act + _credentialManager.SaveCredentials(key, expectedToken); + + // assert + Assert.AreEqual(expectedToken, _credentialManager.GetCredentials(key)); + + // act + _credentialManager.RemoveCredentials(key); + + // assert + Assert.IsTrue(string.IsNullOrEmpty(_credentialManager.GetCredentials(key))); + } + + [Test] + public void TestSavingCredentialsForAnExistingKey() + { + // arrange + var key = "mockKey"; + var firstExpectedToken = "mockToken1"; + var secondExpectedToken = "mockToken2"; + + // act + _credentialManager.SaveCredentials(key, firstExpectedToken); + + // assert + Assert.AreEqual(firstExpectedToken, _credentialManager.GetCredentials(key)); + + // act + _credentialManager.SaveCredentials(key, secondExpectedToken); + + // assert + Assert.AreEqual(secondExpectedToken, _credentialManager.GetCredentials(key)); + + // act + _credentialManager.RemoveCredentials(key); + + // assert + Assert.IsTrue(string.IsNullOrEmpty(_credentialManager.GetCredentials(key))); + + } + + [Test] + public void TestRemovingCredentialsForKeyThatDoesNotExist() + { + // arrange + var key = "mockKey"; + + // act + _credentialManager.RemoveCredentials(key); + + // assert + Assert.IsTrue(string.IsNullOrEmpty(_credentialManager.GetCredentials(key))); + } + } + + [TestFixture] + [Platform("Win")] + public class SFNativeCredentialManagerTest : SFBaseCredentialManagerTest + { + [SetUp] + public void SetUp() + { + _credentialManager = SFCredentialManagerWindowsNativeImpl.Instance; + } + } + + [TestFixture] + public class SFInMemoryCredentialManagerTest : SFBaseCredentialManagerTest + { + [SetUp] + public void SetUp() + { + _credentialManager = SFCredentialManagerInMemoryImpl.Instance; + } + } + + [TestFixture] + public class SFFileCredentialManagerTest : SFBaseCredentialManagerTest + { + [SetUp] + public void SetUp() + { + _credentialManager = SFCredentialManagerFileImpl.Instance; + } + } + + [TestFixture, NonParallelizable] + class SFCredentialManagerTest + { + ISnowflakeCredentialManager _credentialManager; + + [ThreadStatic] + private static Mock t_fileOperations; + + [ThreadStatic] + private static Mock t_directoryOperations; + + [ThreadStatic] + private static Mock t_unixOperations; + + [ThreadStatic] + private static Mock t_environmentOperations; + + private const string CustomJsonDir = "testdirectory"; + + private static readonly string s_customJsonPath = Path.Combine(CustomJsonDir, SFCredentialManagerFileImpl.CredentialCacheFileName); + + [SetUp] public void SetUp() + { + t_fileOperations = new Mock(); + t_directoryOperations = new Mock(); + t_unixOperations = new Mock(); + t_environmentOperations = new Mock(); + SnowflakeCredentialManagerFactory.UseInMemoryCredentialManager(); + } + + [TearDown] public void TearDown() + { + SnowflakeCredentialManagerFactory.UseDefaultCredentialManager(); + } + + [Test] + public void TestUsingDefaultCredentialManager() + { + // arrange + SnowflakeCredentialManagerFactory.UseDefaultCredentialManager(); + + // act + _credentialManager = SnowflakeCredentialManagerFactory.GetCredentialManager(); + + // assert + if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows)) + { + Assert.IsInstanceOf(_credentialManager); + } + else + { + Assert.IsInstanceOf(_credentialManager); + } + } + + [Test] + public void TestSettingCustomCredentialManager() + { + // arrange + SnowflakeCredentialManagerFactory.SetCredentialManager(SFCredentialManagerFileImpl.Instance); + + // act + _credentialManager = SnowflakeCredentialManagerFactory.GetCredentialManager(); + + // assert + Assert.IsInstanceOf(_credentialManager); + } + + [Test] + public void TestUseFileImplCredentialManager() + { + // arrange + SnowflakeCredentialManagerFactory.UseFileCredentialManager(); + + // act + _credentialManager = SnowflakeCredentialManagerFactory.GetCredentialManager(); + + // assert + Assert.IsInstanceOf(_credentialManager); + } + + [Test] + public void TestThatThrowsErrorWhenTryingToSetCredentialManagerToNull() + { + // act and assert + var exception = Assert.Throws(() => SnowflakeCredentialManagerFactory.SetCredentialManager(null)); + Assert.IsTrue(exception.Message.Contains("Credential manager cannot be null. If you want to use the default credential manager, please call the UseDefaultCredentialManager method.")); + + } + + [Test] + public void TestThatThrowsErrorWhenCacheFileIsNotCreated() + { + if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows)) + { + Assert.Ignore("skip test on Windows"); + } + + // arrange + t_directoryOperations + .Setup(d => d.Exists(s_customJsonPath)) + .Returns(false); + t_unixOperations + .Setup(u => u.CreateFileWithPermissions(s_customJsonPath, + FilePermissions.S_IRUSR | FilePermissions.S_IWUSR | FilePermissions.S_IXUSR)) + .Returns(-1); + t_environmentOperations + .Setup(e => e.GetEnvironmentVariable(SFCredentialManagerFileImpl.CredentialCacheDirectoryEnvironmentName)) + .Returns(CustomJsonDir); + SnowflakeCredentialManagerFactory.SetCredentialManager(new SFCredentialManagerFileImpl(t_fileOperations.Object, t_directoryOperations.Object, t_unixOperations.Object, t_environmentOperations.Object)); + _credentialManager = SnowflakeCredentialManagerFactory.GetCredentialManager(); + + // act + var thrown = Assert.Throws(() => _credentialManager.SaveCredentials("key", "token")); + + // assert + Assert.That(thrown.Message, Does.Contain("Failed to create the JSON token cache file")); + } + + [Test] + public void TestThatThrowsErrorWhenCacheFileCanBeAccessedByOthers() + { + if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows)) + { + Assert.Ignore("skip test on Windows"); + } + + // arrange + t_unixOperations + .Setup(u => u.CreateFileWithPermissions(s_customJsonPath, + FilePermissions.S_IRUSR | FilePermissions.S_IWUSR | FilePermissions.S_IXUSR)) + .Returns(0); + t_unixOperations + .Setup(u => u.GetFilePermissions(s_customJsonPath)) + .Returns(FileAccessPermissions.AllPermissions); + t_environmentOperations + .Setup(e => e.GetEnvironmentVariable(SFCredentialManagerFileImpl.CredentialCacheDirectoryEnvironmentName)) + .Returns(CustomJsonDir); + SnowflakeCredentialManagerFactory.SetCredentialManager(new SFCredentialManagerFileImpl(t_fileOperations.Object, t_directoryOperations.Object, t_unixOperations.Object, t_environmentOperations.Object)); + _credentialManager = SnowflakeCredentialManagerFactory.GetCredentialManager(); + + // act + var thrown = Assert.Throws(() => _credentialManager.SaveCredentials("key", "token")); + + // assert + Assert.That(thrown.Message, Does.Contain("Permission for the JSON token cache file should contain only the owner access")); + } + + [Test] + public void TestThatJsonFileIsCheckedIfAlreadyExists() + { + if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows)) + { + Assert.Ignore("skip test on Windows"); + } + + // arrange + t_unixOperations + .Setup(u => u.CreateFileWithPermissions(s_customJsonPath, + FilePermissions.S_IRUSR | FilePermissions.S_IWUSR | FilePermissions.S_IXUSR)) + .Returns(0); + t_unixOperations + .Setup(u => u.GetFilePermissions(s_customJsonPath)) + .Returns(FileAccessPermissions.UserReadWriteExecute); + t_environmentOperations + .Setup(e => e.GetEnvironmentVariable(SFCredentialManagerFileImpl.CredentialCacheDirectoryEnvironmentName)) + .Returns(CustomJsonDir); + t_fileOperations + .SetupSequence(f => f.Exists(s_customJsonPath)) + .Returns(false) + .Returns(true); + + SnowflakeCredentialManagerFactory.SetCredentialManager(new SFCredentialManagerFileImpl(t_fileOperations.Object, t_directoryOperations.Object, t_unixOperations.Object, t_environmentOperations.Object)); + _credentialManager = SnowflakeCredentialManagerFactory.GetCredentialManager(); + + // act + _credentialManager.SaveCredentials("key", "token"); + + // assert + t_fileOperations.Verify(f => f.Exists(s_customJsonPath), Times.Exactly(2)); + } + } +} diff --git a/Snowflake.Data.Tests/UnitTests/SFSessionPropertyTest.cs b/Snowflake.Data.Tests/UnitTests/SFSessionPropertyTest.cs index a57a9fb74..044ac5ddc 100644 --- a/Snowflake.Data.Tests/UnitTests/SFSessionPropertyTest.cs +++ b/Snowflake.Data.Tests/UnitTests/SFSessionPropertyTest.cs @@ -25,7 +25,7 @@ public void TestThatPropertiesAreParsed(TestCase testcase) testcase.SecurePassword); // assert - CollectionAssert.AreEquivalent(testcase.ExpectedProperties, properties); + CollectionAssert.IsSubsetOf(testcase.ExpectedProperties, properties); } [Test] @@ -104,6 +104,76 @@ public void TestFailWhenNoPasswordProvided(string connectionString, string passw Assert.That(exception.Message, Does.Contain("Required property PASSWORD is not provided")); } + [Test] + public void TestParsePasscode() + { + // arrange + var expectedPasscode = "abc"; + var connectionString = $"ACCOUNT=testaccount;USER=testuser;PASSWORD=testpassword;PASSCODE={expectedPasscode}"; + + // act + var properties = SFSessionProperties.ParseConnectionString(connectionString, null); + + // assert + Assert.AreEqual(expectedPasscode, properties[SFSessionProperty.PASSCODE]); + } + + [Test] + public void TestUsePasscodeFromSecureString() + { + // arrange + var expectedPasscode = "abc"; + var connectionString = $"ACCOUNT=testaccount;USER=testuser;PASSWORD=testpassword"; + var securePasscode = SecureStringHelper.Encode(expectedPasscode); + + // act + var properties = SFSessionProperties.ParseConnectionString(connectionString, null, securePasscode); + + // assert + Assert.AreEqual(expectedPasscode, properties[SFSessionProperty.PASSCODE]); + } + + [Test] + [TestCase("ACCOUNT=testaccount;USER=testuser;PASSWORD=testpassword;")] + [TestCase("ACCOUNT=testaccount;USER=testuser;PASSWORD=testpassword;PASSCODE=")] + public void TestDoNotParsePasscodeWhenNotProvided(string connectionString) + { + // act + var properties = SFSessionProperties.ParseConnectionString(connectionString, null); + + // assert + Assert.False(properties.TryGetValue(SFSessionProperty.PASSCODE, out _)); + } + + [Test] + [TestCase("ACCOUNT=testaccount;USER=testuser;PASSWORD=testpassword;", "false")] + [TestCase("ACCOUNT=testaccount;USER=testuser;PASSWORD=testpassword;passcodeInPassword=", "false")] + [TestCase("ACCOUNT=testaccount;USER=testuser;PASSWORD=testpassword;passcodeInPassword=true", "true")] + [TestCase("ACCOUNT=testaccount;USER=testuser;PASSWORD=testpassword;passcodeInPassword=TRUE", "TRUE")] + [TestCase("ACCOUNT=testaccount;USER=testuser;PASSWORD=testpassword;passcodeInPassword=false", "false")] + [TestCase("ACCOUNT=testaccount;USER=testuser;PASSWORD=testpassword;passcodeInPassword=FALSE", "FALSE")] + public void TestParsePasscodeInPassword(string connectionString, string expectedPasscodeInPassword) + { + // act + var properties = SFSessionProperties.ParseConnectionString(connectionString, null); + + // assert + Assert.IsTrue(properties.TryGetValue(SFSessionProperty.PASSCODEINPASSWORD, out var passcodeInPassword)); + Assert.AreEqual(expectedPasscodeInPassword, passcodeInPassword); + } + + [Test] + public void TestFailWhenInvalidPasscodeInPassword() + { + // arrange + var invalidConnectionString = "ACCOUNT=testaccount;USER=testuser;PASSWORD=testpassword;passcodeInPassword=abc"; + + // act + var thrown = Assert.Throws(() => SFSessionProperties.ParseConnectionString(invalidConnectionString, null)); + + Assert.That(thrown.Message, Does.Contain("Invalid parameter value for PASSCODEINPASSWORD")); + } + [Test] [TestCase("DB", SFSessionProperty.DB, "\"testdb\"")] [TestCase("SCHEMA", SFSessionProperty.SCHEMA, "\"quotedSchema\"")] @@ -222,7 +292,8 @@ public static IEnumerable ConnectionStringTestCases() { SFSessionProperty.WAITINGFORIDLESESSIONTIMEOUT, DefaultValue(SFSessionProperty.WAITINGFORIDLESESSIONTIMEOUT) }, { SFSessionProperty.EXPIRATIONTIMEOUT, DefaultValue(SFSessionProperty.EXPIRATIONTIMEOUT) }, { SFSessionProperty.POOLINGENABLED, DefaultValue(SFSessionProperty.POOLINGENABLED) }, - { SFSessionProperty.DISABLE_SAML_URL_CHECK, DefaultValue(SFSessionProperty.DISABLE_SAML_URL_CHECK) } + { SFSessionProperty.DISABLE_SAML_URL_CHECK, DefaultValue(SFSessionProperty.DISABLE_SAML_URL_CHECK) }, + { SFSessionProperty.PASSCODEINPASSWORD, DefaultValue(SFSessionProperty.PASSCODEINPASSWORD) } } }; @@ -258,7 +329,8 @@ public static IEnumerable ConnectionStringTestCases() { SFSessionProperty.WAITINGFORIDLESESSIONTIMEOUT, DefaultValue(SFSessionProperty.WAITINGFORIDLESESSIONTIMEOUT) }, { SFSessionProperty.EXPIRATIONTIMEOUT, DefaultValue(SFSessionProperty.EXPIRATIONTIMEOUT) }, { SFSessionProperty.POOLINGENABLED, DefaultValue(SFSessionProperty.POOLINGENABLED) }, - { SFSessionProperty.DISABLE_SAML_URL_CHECK, DefaultValue(SFSessionProperty.DISABLE_SAML_URL_CHECK) } + { SFSessionProperty.DISABLE_SAML_URL_CHECK, DefaultValue(SFSessionProperty.DISABLE_SAML_URL_CHECK) }, + { SFSessionProperty.PASSCODEINPASSWORD, DefaultValue(SFSessionProperty.PASSCODEINPASSWORD) } } }; var testCaseWithProxySettings = new TestCase() @@ -296,7 +368,8 @@ public static IEnumerable ConnectionStringTestCases() { SFSessionProperty.WAITINGFORIDLESESSIONTIMEOUT, DefaultValue(SFSessionProperty.WAITINGFORIDLESESSIONTIMEOUT) }, { SFSessionProperty.EXPIRATIONTIMEOUT, DefaultValue(SFSessionProperty.EXPIRATIONTIMEOUT) }, { SFSessionProperty.POOLINGENABLED, DefaultValue(SFSessionProperty.POOLINGENABLED) }, - { SFSessionProperty.DISABLE_SAML_URL_CHECK, DefaultValue(SFSessionProperty.DISABLE_SAML_URL_CHECK) } + { SFSessionProperty.DISABLE_SAML_URL_CHECK, DefaultValue(SFSessionProperty.DISABLE_SAML_URL_CHECK) }, + { SFSessionProperty.PASSCODEINPASSWORD, DefaultValue(SFSessionProperty.PASSCODEINPASSWORD) } }, ConnectionString = $"ACCOUNT={defAccount};USER={defUser};PASSWORD={defPassword};useProxy=true;proxyHost=proxy.com;proxyPort=1234;nonProxyHosts=localhost" @@ -336,7 +409,8 @@ public static IEnumerable ConnectionStringTestCases() { SFSessionProperty.WAITINGFORIDLESESSIONTIMEOUT, DefaultValue(SFSessionProperty.WAITINGFORIDLESESSIONTIMEOUT) }, { SFSessionProperty.EXPIRATIONTIMEOUT, DefaultValue(SFSessionProperty.EXPIRATIONTIMEOUT) }, { SFSessionProperty.POOLINGENABLED, DefaultValue(SFSessionProperty.POOLINGENABLED) }, - { SFSessionProperty.DISABLE_SAML_URL_CHECK, DefaultValue(SFSessionProperty.DISABLE_SAML_URL_CHECK) } + { SFSessionProperty.DISABLE_SAML_URL_CHECK, DefaultValue(SFSessionProperty.DISABLE_SAML_URL_CHECK) }, + { SFSessionProperty.PASSCODEINPASSWORD, DefaultValue(SFSessionProperty.PASSCODEINPASSWORD) } }, ConnectionString = $"ACCOUNT={defAccount};USER={defUser};PASSWORD={defPassword};proxyHost=proxy.com;proxyPort=1234;nonProxyHosts=localhost" @@ -375,7 +449,8 @@ public static IEnumerable ConnectionStringTestCases() { SFSessionProperty.WAITINGFORIDLESESSIONTIMEOUT, DefaultValue(SFSessionProperty.WAITINGFORIDLESESSIONTIMEOUT) }, { SFSessionProperty.EXPIRATIONTIMEOUT, DefaultValue(SFSessionProperty.EXPIRATIONTIMEOUT) }, { SFSessionProperty.POOLINGENABLED, DefaultValue(SFSessionProperty.POOLINGENABLED) }, - { SFSessionProperty.DISABLE_SAML_URL_CHECK, DefaultValue(SFSessionProperty.DISABLE_SAML_URL_CHECK) } + { SFSessionProperty.DISABLE_SAML_URL_CHECK, DefaultValue(SFSessionProperty.DISABLE_SAML_URL_CHECK) }, + { SFSessionProperty.PASSCODEINPASSWORD, DefaultValue(SFSessionProperty.PASSCODEINPASSWORD) } } }; var testCaseWithIncludeRetryReason = new TestCase() @@ -411,7 +486,8 @@ public static IEnumerable ConnectionStringTestCases() { SFSessionProperty.WAITINGFORIDLESESSIONTIMEOUT, DefaultValue(SFSessionProperty.WAITINGFORIDLESESSIONTIMEOUT) }, { SFSessionProperty.EXPIRATIONTIMEOUT, DefaultValue(SFSessionProperty.EXPIRATIONTIMEOUT) }, { SFSessionProperty.POOLINGENABLED, DefaultValue(SFSessionProperty.POOLINGENABLED) }, - { SFSessionProperty.DISABLE_SAML_URL_CHECK, DefaultValue(SFSessionProperty.DISABLE_SAML_URL_CHECK) } + { SFSessionProperty.DISABLE_SAML_URL_CHECK, DefaultValue(SFSessionProperty.DISABLE_SAML_URL_CHECK) }, + { SFSessionProperty.PASSCODEINPASSWORD, DefaultValue(SFSessionProperty.PASSCODEINPASSWORD) } } }; var testCaseWithDisableQueryContextCache = new TestCase() @@ -446,7 +522,8 @@ public static IEnumerable ConnectionStringTestCases() { SFSessionProperty.WAITINGFORIDLESESSIONTIMEOUT, DefaultValue(SFSessionProperty.WAITINGFORIDLESESSIONTIMEOUT) }, { SFSessionProperty.EXPIRATIONTIMEOUT, DefaultValue(SFSessionProperty.EXPIRATIONTIMEOUT) }, { SFSessionProperty.POOLINGENABLED, DefaultValue(SFSessionProperty.POOLINGENABLED) }, - { SFSessionProperty.DISABLE_SAML_URL_CHECK, DefaultValue(SFSessionProperty.DISABLE_SAML_URL_CHECK) } + { SFSessionProperty.DISABLE_SAML_URL_CHECK, DefaultValue(SFSessionProperty.DISABLE_SAML_URL_CHECK) }, + { SFSessionProperty.PASSCODEINPASSWORD, DefaultValue(SFSessionProperty.PASSCODEINPASSWORD) } }, ConnectionString = $"ACCOUNT={defAccount};USER={defUser};PASSWORD={defPassword};DISABLEQUERYCONTEXTCACHE=true" @@ -483,7 +560,8 @@ public static IEnumerable ConnectionStringTestCases() { SFSessionProperty.WAITINGFORIDLESESSIONTIMEOUT, DefaultValue(SFSessionProperty.WAITINGFORIDLESESSIONTIMEOUT) }, { SFSessionProperty.EXPIRATIONTIMEOUT, DefaultValue(SFSessionProperty.EXPIRATIONTIMEOUT) }, { SFSessionProperty.POOLINGENABLED, DefaultValue(SFSessionProperty.POOLINGENABLED) }, - { SFSessionProperty.DISABLE_SAML_URL_CHECK, DefaultValue(SFSessionProperty.DISABLE_SAML_URL_CHECK) } + { SFSessionProperty.DISABLE_SAML_URL_CHECK, DefaultValue(SFSessionProperty.DISABLE_SAML_URL_CHECK) }, + { SFSessionProperty.PASSCODEINPASSWORD, DefaultValue(SFSessionProperty.PASSCODEINPASSWORD) } }, ConnectionString = $"ACCOUNT={defAccount};USER={defUser};PASSWORD={defPassword};DISABLE_CONSOLE_LOGIN=false" @@ -522,7 +600,8 @@ public static IEnumerable ConnectionStringTestCases() { SFSessionProperty.WAITINGFORIDLESESSIONTIMEOUT, DefaultValue(SFSessionProperty.WAITINGFORIDLESESSIONTIMEOUT) }, { SFSessionProperty.EXPIRATIONTIMEOUT, DefaultValue(SFSessionProperty.EXPIRATIONTIMEOUT) }, { SFSessionProperty.POOLINGENABLED, DefaultValue(SFSessionProperty.POOLINGENABLED) }, - { SFSessionProperty.DISABLE_SAML_URL_CHECK, DefaultValue(SFSessionProperty.DISABLE_SAML_URL_CHECK) } + { SFSessionProperty.DISABLE_SAML_URL_CHECK, DefaultValue(SFSessionProperty.DISABLE_SAML_URL_CHECK) }, + { SFSessionProperty.PASSCODEINPASSWORD, DefaultValue(SFSessionProperty.PASSCODEINPASSWORD) } } }; var testCaseUnderscoredAccountName = new TestCase() @@ -558,7 +637,8 @@ public static IEnumerable ConnectionStringTestCases() { SFSessionProperty.WAITINGFORIDLESESSIONTIMEOUT, DefaultValue(SFSessionProperty.WAITINGFORIDLESESSIONTIMEOUT) }, { SFSessionProperty.EXPIRATIONTIMEOUT, DefaultValue(SFSessionProperty.EXPIRATIONTIMEOUT) }, { SFSessionProperty.POOLINGENABLED, DefaultValue(SFSessionProperty.POOLINGENABLED) }, - { SFSessionProperty.DISABLE_SAML_URL_CHECK, DefaultValue(SFSessionProperty.DISABLE_SAML_URL_CHECK) } + { SFSessionProperty.DISABLE_SAML_URL_CHECK, DefaultValue(SFSessionProperty.DISABLE_SAML_URL_CHECK) }, + { SFSessionProperty.PASSCODEINPASSWORD, DefaultValue(SFSessionProperty.PASSCODEINPASSWORD) } } }; var testCaseUnderscoredAccountNameWithEnabledAllowUnderscores = new TestCase() @@ -594,9 +674,11 @@ public static IEnumerable ConnectionStringTestCases() { SFSessionProperty.WAITINGFORIDLESESSIONTIMEOUT, DefaultValue(SFSessionProperty.WAITINGFORIDLESESSIONTIMEOUT) }, { SFSessionProperty.EXPIRATIONTIMEOUT, DefaultValue(SFSessionProperty.EXPIRATIONTIMEOUT) }, { SFSessionProperty.POOLINGENABLED, DefaultValue(SFSessionProperty.POOLINGENABLED) }, - { SFSessionProperty.DISABLE_SAML_URL_CHECK, DefaultValue(SFSessionProperty.DISABLE_SAML_URL_CHECK) } + { SFSessionProperty.DISABLE_SAML_URL_CHECK, DefaultValue(SFSessionProperty.DISABLE_SAML_URL_CHECK) }, + { SFSessionProperty.PASSCODEINPASSWORD, DefaultValue(SFSessionProperty.PASSCODEINPASSWORD) } } }; + var testQueryTag = "Test QUERY_TAG 12345"; var testCaseQueryTag = new TestCase() { @@ -632,7 +714,8 @@ public static IEnumerable ConnectionStringTestCases() { SFSessionProperty.WAITINGFORIDLESESSIONTIMEOUT, DefaultValue(SFSessionProperty.WAITINGFORIDLESESSIONTIMEOUT) }, { SFSessionProperty.EXPIRATIONTIMEOUT, DefaultValue(SFSessionProperty.EXPIRATIONTIMEOUT) }, { SFSessionProperty.POOLINGENABLED, DefaultValue(SFSessionProperty.POOLINGENABLED) }, - { SFSessionProperty.DISABLE_SAML_URL_CHECK, DefaultValue(SFSessionProperty.DISABLE_SAML_URL_CHECK) } + { SFSessionProperty.DISABLE_SAML_URL_CHECK, DefaultValue(SFSessionProperty.DISABLE_SAML_URL_CHECK) }, + { SFSessionProperty.PASSCODEINPASSWORD, DefaultValue(SFSessionProperty.PASSCODEINPASSWORD) } } }; diff --git a/Snowflake.Data.Tests/UnitTests/SFSessionTest.cs b/Snowflake.Data.Tests/UnitTests/SFSessionTest.cs index 262122b2d..969e5cadf 100644 --- a/Snowflake.Data.Tests/UnitTests/SFSessionTest.cs +++ b/Snowflake.Data.Tests/UnitTests/SFSessionTest.cs @@ -5,6 +5,7 @@ using Newtonsoft.Json; using Snowflake.Data.Core; using NUnit.Framework; +using Snowflake.Data.Core.Tools; using Snowflake.Data.Tests.Mock; namespace Snowflake.Data.Tests.UnitTests @@ -99,7 +100,7 @@ public void TestThatConfiguresEasyLogging(string configPath) : $"{simpleConnectionString}client_config_file={configPath};"; // act - new SFSession(connectionString, null, easyLoggingStarter.Object); + new SFSession(connectionString, null, null, easyLoggingStarter.Object); // assert easyLoggingStarter.Verify(starter => starter.Init(configPath)); @@ -157,5 +158,166 @@ public void TestHandlePasswordWithQuotations() // assert Assert.AreEqual(loginRequest.data.password, deserializedLoginRequest.data.password); } + + [Test] + public void TestHandlePasscodeParameter() + { + // arrange + var passcode = "123456"; + MockLoginStoringRestRequester restRequester = new MockLoginStoringRestRequester(); + SFSession sfSession = new SFSession($"account=test;user=test;password=test;passcode={passcode}", null, restRequester); + + // act + sfSession.Open(); + + // assert + Assert.AreEqual(1, restRequester.LoginRequests.Count); + var loginRequest = restRequester.LoginRequests[0]; + Assert.AreEqual(passcode, loginRequest.data.passcode); + Assert.AreEqual("passcode", loginRequest.data.extAuthnDuoMethod); + } + + [Test] + public void TestHandlePasscodeAsSecureString() + { + // arrange + var passcode = "123456"; + MockLoginStoringRestRequester restRequester = new MockLoginStoringRestRequester(); + SFSession sfSession = new SFSession($"account=test;user=test;password=test;", null, SecureStringHelper.Encode(passcode), EasyLoggingStarter.Instance, restRequester); + + // act + sfSession.Open(); + + // assert + Assert.AreEqual(1, restRequester.LoginRequests.Count); + var loginRequest = restRequester.LoginRequests[0]; + Assert.AreEqual(passcode, loginRequest.data.passcode); + Assert.AreEqual("passcode", loginRequest.data.extAuthnDuoMethod); + } + + [Test] + public void TestHandlePasscodeInPasswordParameter() + { + // arrange + var passcode = "123456"; + MockLoginStoringRestRequester restRequester = new MockLoginStoringRestRequester(); + SFSession sfSession = new SFSession($"account=test;user=test;password=test{passcode};passcodeInPassword=true;", null, restRequester); + + // act + sfSession.Open(); + + // assert + Assert.AreEqual(1, restRequester.LoginRequests.Count); + var loginRequest = restRequester.LoginRequests[0]; + Assert.IsNull(loginRequest.data.passcode); + Assert.AreEqual("passcode", loginRequest.data.extAuthnDuoMethod); + } + + [Test] + public void TestPushWhenNoPasscodeAndPasscodeInPasswordIsFalse() + { + // arrange + var passcode = "123456"; + MockLoginStoringRestRequester restRequester = new MockLoginStoringRestRequester(); + SFSession sfSession = new SFSession($"account=test;user=test;password=test;passcodeInPassword=false;", null, restRequester); + + // act + sfSession.Open(); + + // assert + Assert.AreEqual(1, restRequester.LoginRequests.Count); + var loginRequest = restRequester.LoginRequests[0]; + Assert.IsNull(loginRequest.data.passcode); + Assert.AreEqual("push", loginRequest.data.extAuthnDuoMethod); + } + + [Test] + public void TestPushAsDefaultSecondaryAuthentication() + { + // arrange + MockLoginStoringRestRequester restRequester = new MockLoginStoringRestRequester(); + SFSession sfSession = new SFSession($"account=test;user=test;password=test", null, restRequester); + + // act + sfSession.Open(); + + // assert + Assert.AreEqual(1, restRequester.LoginRequests.Count); + var loginRequest = restRequester.LoginRequests[0]; + Assert.IsNull(loginRequest.data.passcode); + Assert.AreEqual("push", loginRequest.data.extAuthnDuoMethod); + } + + [Test] + public void TestPushMFAWithAuthenticationCacheMFAToken() + { + // arrange + var restRequester = new MockLoginMFATokenCacheRestRequester(); + var sfSession = new SFSession($"account=test;user=test;password=test;authenticator=username_password_mfa", null, restRequester); + + // act + sfSession.Open(); + + // assert + Assert.AreEqual(1, restRequester.LoginRequests.Count); + var loginRequest = restRequester.LoginRequests.Dequeue(); + Assert.IsNull(loginRequest.data.passcode); + Assert.IsTrue(loginRequest.data.SessionParameters.TryGetValue(SFSessionParameter.CLIENT_REQUEST_MFA_TOKEN, out var value) && (bool)value); + Assert.AreEqual("push", loginRequest.data.extAuthnDuoMethod); + } + + [Test] + public void TestMFATokenCacheReturnedToSession() + { + // arrange + var testToken = "testToken1234"; + var restRequester = new MockLoginMFATokenCacheRestRequester(); + var sfSession = new SFSession($"account=test;user=test;password=test;authenticator=username_password_mfa", null, restRequester); + restRequester.LoginResponses.Enqueue(new LoginResponseData() + { + mfaToken = testToken, + authResponseSessionInfo = new SessionInfo() + }); + // act + sfSession.Open(); + + // assert + Assert.AreEqual(1, restRequester.LoginRequests.Count); + var loginRequest = restRequester.LoginRequests.Dequeue(); + Assert.AreEqual(SecureStringHelper.Decode(sfSession._mfaToken), testToken); + Assert.IsNull(loginRequest.data.passcode); + Assert.IsTrue(loginRequest.data.SessionParameters.TryGetValue(SFSessionParameter.CLIENT_REQUEST_MFA_TOKEN, out var value) && (bool)value); + Assert.AreEqual("push", loginRequest.data.extAuthnDuoMethod); + } + + [Test] + public void TestMFATokenCacheUsedInNewConnection() + { + // arrange + var testToken = "testToken1234"; + var restRequester = new MockLoginMFATokenCacheRestRequester(); + var connectionString = $"account=test;user=test;password=test;authenticator=username_password_mfa"; + var sfSession = new SFSession(connectionString, null, restRequester); + restRequester.LoginResponses.Enqueue(new LoginResponseData() + { + mfaToken = testToken, + authResponseSessionInfo = new SessionInfo() + }); + sfSession.Open(); + var sfSessionWithCachedToken = new SFSession(connectionString, null, restRequester); + // act + sfSessionWithCachedToken.Open(); + + // assert + Assert.AreEqual(2, restRequester.LoginRequests.Count); + var firstLoginRequest = restRequester.LoginRequests.Dequeue(); + Assert.AreEqual(SecureStringHelper.Decode(sfSession._mfaToken), testToken); + Assert.IsNull(firstLoginRequest.data.passcode); + Assert.IsTrue(firstLoginRequest.data.SessionParameters.TryGetValue(SFSessionParameter.CLIENT_REQUEST_MFA_TOKEN, out var value) && (bool)value); + Assert.AreEqual("push", firstLoginRequest.data.extAuthnDuoMethod); + + var secondLoginRequest = restRequester.LoginRequests.Dequeue(); + Assert.AreEqual(secondLoginRequest.data.Token, testToken); + } } } diff --git a/Snowflake.Data.Tests/UnitTests/SecretDetectorTest.cs b/Snowflake.Data.Tests/UnitTests/SecretDetectorTest.cs index 82c59a63c..a25b263f9 100644 --- a/Snowflake.Data.Tests/UnitTests/SecretDetectorTest.cs +++ b/Snowflake.Data.Tests/UnitTests/SecretDetectorTest.cs @@ -273,6 +273,14 @@ public void TestPasswordProperty() BasicMasking(@"somethingBefore=cccc;private_key_pwd=", @"somethingBefore=cccc;private_key_pwd=****"); BasicMasking(@"somethingBefore=cccc;private_key_pwd =aa;somethingNext=bbbb", @"somethingBefore=cccc;private_key_pwd =****"); BasicMasking(@"somethingBefore=cccc;private_key_pwd="" 'aa", @"somethingBefore=cccc;private_key_pwd=****"); + + BasicMasking(@"somethingBefore=cccc;passcode=aa", @"somethingBefore=cccc;passcode=****"); + BasicMasking(@"somethingBefore=cccc;passcode=aa;somethingNext=bbbb", @"somethingBefore=cccc;passcode=****"); + BasicMasking(@"somethingBefore=cccc;passcode=""aa"";somethingNext=bbbb", @"somethingBefore=cccc;passcode=****"); + BasicMasking(@"somethingBefore=cccc;passcode=;somethingNext=bbbb", @"somethingBefore=cccc;passcode=****"); + BasicMasking(@"somethingBefore=cccc;passcode=", @"somethingBefore=cccc;passcode=****"); + BasicMasking(@"somethingBefore=cccc;passcode =aa;somethingNext=bbbb", @"somethingBefore=cccc;passcode =****"); + BasicMasking(@"somethingBefore=cccc;passcode="" 'aa", @"somethingBefore=cccc;passcode=****"); } [Test] diff --git a/Snowflake.Data.Tests/UnitTests/Session/SessionOrCreationTokensTest.cs b/Snowflake.Data.Tests/UnitTests/Session/SessionOrCreationTokensTest.cs index 7d2b1a603..da5863475 100644 --- a/Snowflake.Data.Tests/UnitTests/Session/SessionOrCreationTokensTest.cs +++ b/Snowflake.Data.Tests/UnitTests/Session/SessionOrCreationTokensTest.cs @@ -10,13 +10,13 @@ namespace Snowflake.Data.Tests.UnitTests.Session public class SessionOrCreationTokensTest { private SFSession _session = new SFSession("account=test;user=test;password=test", null); - + [Test] public void TestNoBackgroundSessionsToCreateWhenInitialisedWithSession() { // arrange var sessionOrTokens = new SessionOrCreationTokens(_session); - + // act var backgroundCreationTokens = sessionOrTokens.BackgroundSessionCreationTokens(); @@ -32,14 +32,14 @@ public void TestReturnFirstCreationToken() .Select(_ => sessionCreationTokenCounter.NewToken()) .ToList(); var sessionOrTokens = new SessionOrCreationTokens(tokens); - + // act var token = sessionOrTokens.SessionCreationToken(); - + // assert Assert.AreSame(tokens[0], token); } - + [Test] public void TestReturnCreationTokensFromTheSecondOneForBackgroundExecution() { @@ -49,10 +49,10 @@ public void TestReturnCreationTokensFromTheSecondOneForBackgroundExecution() .Select(_ => sessionCreationTokenCounter.NewToken()) .ToList(); var sessionOrTokens = new SessionOrCreationTokens(tokens); - + // act var backgroundTokens = sessionOrTokens.BackgroundSessionCreationTokens(); - + // assert Assert.AreEqual(2, backgroundTokens.Count); Assert.AreSame(tokens[1], backgroundTokens[0]); diff --git a/Snowflake.Data.Tests/UnitTests/Session/SessionPoolTest.cs b/Snowflake.Data.Tests/UnitTests/Session/SessionPoolTest.cs index fca8f7de1..14115824e 100644 --- a/Snowflake.Data.Tests/UnitTests/Session/SessionPoolTest.cs +++ b/Snowflake.Data.Tests/UnitTests/Session/SessionPoolTest.cs @@ -71,17 +71,17 @@ public void TestOverrideSetPooling() [Test] [TestCase("account=someAccount;db=someDb;host=someHost;user=SomeUser;port=443", "somePassword", " [pool: account=someAccount;db=someDb;host=someHost;user=SomeUser;port=443;]")] - [TestCase("account=someAccount;db=someDb;host=someHost;password=somePassword;user=SomeUser;port=443", null, " [pool: account=someAccount;db=someDb;host=someHost;user=SomeUser;port=443;]")] - [TestCase("account=someAccount;db=someDb;host=someHost;password=somePassword;user=SomeUser;private_key=SomePrivateKey;port=443", null, " [pool: account=someAccount;db=someDb;host=someHost;user=SomeUser;port=443;]")] - [TestCase("account=someAccount;db=someDb;host=someHost;password=somePassword;user=SomeUser;token=someToken;port=443", null, " [pool: account=someAccount;db=someDb;host=someHost;user=SomeUser;port=443;]")] - [TestCase("account=someAccount;db=someDb;host=someHost;password=somePassword;user=SomeUser;private_key_pwd=somePrivateKeyPwd;port=443", null, " [pool: account=someAccount;db=someDb;host=someHost;user=SomeUser;port=443;]")] - [TestCase("account=someAccount;db=someDb;host=someHost;password=somePassword;user=SomeUser;proxyPassword=someProxyPassword;port=443", null, " [pool: account=someAccount;db=someDb;host=someHost;user=SomeUser;port=443;]")] - [TestCase("ACCOUNT=someAccount;DB=someDb;HOST=someHost;PASSWORD=somePassword;USER=SomeUser;PORT=443", null, " [pool: account=someAccount;db=someDb;host=someHost;user=SomeUser;port=443;]")] - [TestCase("ACCOUNT=\"someAccount\";DB=\"someDb\";HOST=\"someHost\";PASSWORD=\"somePassword\";USER=\"SomeUser\";PORT=\"443\"", null, " [pool: account=someAccount;db=someDb;host=someHost;user=SomeUser;port=443;]")] + [TestCase("account=someAccount;db=someDb;host=someHost;password=somePassword;passcode=123;user=SomeUser;port=443", null, " [pool: account=someAccount;db=someDb;host=someHost;user=SomeUser;port=443;]")] + [TestCase("account=someAccount;db=someDb;host=someHost;password=somePassword;passcode=123;user=SomeUser;private_key=SomePrivateKey;port=443", null, " [pool: account=someAccount;db=someDb;host=someHost;user=SomeUser;port=443;]")] + [TestCase("account=someAccount;db=someDb;host=someHost;password=somePassword;passcode=123;user=SomeUser;token=someToken;port=443", null, " [pool: account=someAccount;db=someDb;host=someHost;user=SomeUser;port=443;]")] + [TestCase("account=someAccount;db=someDb;host=someHost;password=somePassword;passcode=123;user=SomeUser;private_key_pwd=somePrivateKeyPwd;port=443", null, " [pool: account=someAccount;db=someDb;host=someHost;user=SomeUser;port=443;]")] + [TestCase("account=someAccount;db=someDb;host=someHost;password=somePassword;passcode=123;user=SomeUser;proxyPassword=someProxyPassword;port=443", null, " [pool: account=someAccount;db=someDb;host=someHost;user=SomeUser;port=443;]")] + [TestCase("ACCOUNT=someAccount;DB=someDb;HOST=someHost;PASSWORD=somePassword;passcode=123;USER=SomeUser;PORT=443", null, " [pool: account=someAccount;db=someDb;host=someHost;user=SomeUser;port=443;]")] + [TestCase("ACCOUNT=\"someAccount\";DB=\"someDb\";HOST=\"someHost\";PASSWORD=\"somePassword\";PASSCODE=\"123\";USER=\"SomeUser\";PORT=\"443\"", null, " [pool: account=someAccount;db=someDb;host=someHost;user=SomeUser;port=443;]")] public void TestPoolIdentificationBasedOnConnectionString(string connectionString, string password, string expectedPoolIdentification) { // arrange - var securePassword = password == null ? null : new NetworkCredential("", password).SecurePassword; + var securePassword = password == null ? null : SecureStringHelper.Encode(password); var pool = SessionPool.CreateSessionPool(connectionString, securePassword); // act diff --git a/Snowflake.Data/Client/ISnowflakeCredentialManager.cs b/Snowflake.Data/Client/ISnowflakeCredentialManager.cs new file mode 100644 index 000000000..802d8fe21 --- /dev/null +++ b/Snowflake.Data/Client/ISnowflakeCredentialManager.cs @@ -0,0 +1,15 @@ +/* + * Copyright (c) 2024 Snowflake Computing Inc. All rights reserved. + */ + +namespace Snowflake.Data.Client +{ + public interface ISnowflakeCredentialManager + { + string GetCredentials(string key); + + void RemoveCredentials(string key); + + void SaveCredentials(string key, string token); + } +} diff --git a/Snowflake.Data/Client/SnowflakeCredentialManagerFactory.cs b/Snowflake.Data/Client/SnowflakeCredentialManagerFactory.cs new file mode 100644 index 000000000..f006ff607 --- /dev/null +++ b/Snowflake.Data/Client/SnowflakeCredentialManagerFactory.cs @@ -0,0 +1,95 @@ +/* + * Copyright (c) 2024 Snowflake Computing Inc. All rights reserved. + */ + +using System.Runtime.InteropServices; +using Snowflake.Data.Core; +using Snowflake.Data.Core.CredentialManager; +using Snowflake.Data.Core.CredentialManager.Infrastructure; +using Snowflake.Data.Log; + +namespace Snowflake.Data.Client +{ + public class SnowflakeCredentialManagerFactory + { + private static readonly SFLogger s_logger = SFLoggerFactory.GetLogger(); + + private static readonly object s_credentialManagerLock = new object(); + private static readonly ISnowflakeCredentialManager s_defaultCredentialManager = GetDefaultCredentialManager(); + + private static ISnowflakeCredentialManager s_credentialManager; + + internal static string BuildCredentialKey(string host, string user, TokenType tokenType, string authenticator = null) + { + return $"{host.ToUpper()}:{user.ToUpper()}:{SFEnvironment.DriverName}:{tokenType.ToString().ToUpper()}:{authenticator?.ToUpper() ?? string.Empty}"; + } + + public static void UseDefaultCredentialManager() + { + SetCredentialManager(GetDefaultCredentialManager()); + } + + public static void UseInMemoryCredentialManager() + { + SetCredentialManager(SFCredentialManagerInMemoryImpl.Instance); + } + + public static void UseFileCredentialManager() + { + SetCredentialManager(SFCredentialManagerFileImpl.Instance); + } + + public static void UseWindowsCredentialManager() + { + SetCredentialManager(SFCredentialManagerWindowsNativeImpl.Instance); + } + + public static void SetCredentialManager(ISnowflakeCredentialManager customCredentialManager) + { + lock (s_credentialManagerLock) + { + if (customCredentialManager == null) + { + throw new SnowflakeDbException(SFError.INTERNAL_ERROR, + "Credential manager cannot be null. If you want to use the default credential manager, please call the UseDefaultCredentialManager method."); + } + + if (customCredentialManager == s_credentialManager) + { + s_logger.Info($"Credential manager is already set to: {customCredentialManager.GetType().Name}"); + return; + } + + s_logger.Info($"Setting the credential manager: {customCredentialManager.GetType().Name}"); + s_credentialManager = customCredentialManager; + } + } + + public static ISnowflakeCredentialManager GetCredentialManager() + { + if (s_credentialManager == null) + { + lock (s_credentialManagerLock) + { + if (s_credentialManager == null) + { + s_credentialManager = s_defaultCredentialManager; + } + } + } + + var credentialManager = s_credentialManager; + var typeCredentialText = credentialManager == s_defaultCredentialManager ? "default" : "custom"; + s_logger.Info($"Using {typeCredentialText} credential manager: {credentialManager?.GetType().Name}"); + return credentialManager; + } + + private static ISnowflakeCredentialManager GetDefaultCredentialManager() + { + return RuntimeInformation.IsOSPlatform(OSPlatform.Windows) + ? (ISnowflakeCredentialManager) + SFCredentialManagerWindowsNativeImpl.Instance + : SFCredentialManagerInMemoryImpl.Instance; + } + } +} diff --git a/Snowflake.Data/Client/SnowflakeDbConnection.cs b/Snowflake.Data/Client/SnowflakeDbConnection.cs index 9acb24f06..bd9cee33e 100755 --- a/Snowflake.Data/Client/SnowflakeDbConnection.cs +++ b/Snowflake.Data/Client/SnowflakeDbConnection.cs @@ -69,6 +69,8 @@ public SecureString Password get; set; } + public SecureString Passcode { get; set; } + public bool IsOpen() { return _connectionState == ConnectionState.Open && SfSession != null; @@ -269,7 +271,7 @@ public override void Open() try { OnSessionConnecting(); - SfSession = SnowflakeDbConnectionPool.GetSession(ConnectionString, Password); + SfSession = SnowflakeDbConnectionPool.GetSession(ConnectionString, Password, Passcode); if (SfSession == null) throw new SnowflakeDbException(SFError.INTERNAL_ERROR, "Could not open session"); logger.Debug($"Connection open with pooled session: {SfSession.sessionId}"); @@ -303,7 +305,7 @@ public override Task OpenAsync(CancellationToken cancellationToken) registerConnectionCancellationCallback(cancellationToken); OnSessionConnecting(); return SnowflakeDbConnectionPool - .GetSessionAsync(ConnectionString, Password, cancellationToken) + .GetSessionAsync(ConnectionString, Password, Passcode, cancellationToken) .ContinueWith(previousTask => { if (previousTask.IsFaulted) diff --git a/Snowflake.Data/Client/SnowflakeDbConnectionPool.cs b/Snowflake.Data/Client/SnowflakeDbConnectionPool.cs index fcee66e1a..fd10eadd8 100644 --- a/Snowflake.Data/Client/SnowflakeDbConnectionPool.cs +++ b/Snowflake.Data/Client/SnowflakeDbConnectionPool.cs @@ -31,16 +31,16 @@ private static IConnectionManager ConnectionManager } } - internal static SFSession GetSession(string connectionString, SecureString password) + internal static SFSession GetSession(string connectionString, SecureString password, SecureString passcode) { s_logger.Debug($"SnowflakeDbConnectionPool::GetSession"); - return ConnectionManager.GetSession(connectionString, password); + return ConnectionManager.GetSession(connectionString, password, passcode); } - internal static Task GetSessionAsync(string connectionString, SecureString password, CancellationToken cancellationToken) + internal static Task GetSessionAsync(string connectionString, SecureString password, SecureString passcode, CancellationToken cancellationToken) { s_logger.Debug($"SnowflakeDbConnectionPool::GetSessionAsync"); - return ConnectionManager.GetSessionAsync(connectionString, password, cancellationToken); + return ConnectionManager.GetSessionAsync(connectionString, password, passcode, cancellationToken); } public static SnowflakeDbSessionPool GetPool(string connectionString, SecureString password) diff --git a/Snowflake.Data/Core/Authenticator/BasicAuthenticator.cs b/Snowflake.Data/Core/Authenticator/BasicAuthenticator.cs index a26d542d3..2dba66594 100644 --- a/Snowflake.Data/Core/Authenticator/BasicAuthenticator.cs +++ b/Snowflake.Data/Core/Authenticator/BasicAuthenticator.cs @@ -34,6 +34,7 @@ protected override void SetSpecializedAuthenticatorData(ref LoginRequestData dat { // Only need to add the password to Data for basic authentication data.password = session.properties[SFSessionProperty.PASSWORD]; + SetSecondaryAuthenticationData(ref data); } } diff --git a/Snowflake.Data/Core/Authenticator/ExternalBrowserAuthenticator.cs b/Snowflake.Data/Core/Authenticator/ExternalBrowserAuthenticator.cs index e39ec18f8..baba5f8a5 100644 --- a/Snowflake.Data/Core/Authenticator/ExternalBrowserAuthenticator.cs +++ b/Snowflake.Data/Core/Authenticator/ExternalBrowserAuthenticator.cs @@ -260,6 +260,7 @@ protected override void SetSpecializedAuthenticatorData(ref LoginRequestData dat // Add the token and proof key to the Data data.Token = _samlResponseToken; data.ProofKey = _proofKey; + SetSpecializedAuthenticatorData(ref data); } private string GetLoginUrl(string proofKey, int localPort) diff --git a/Snowflake.Data/Core/Authenticator/IAuthenticator.cs b/Snowflake.Data/Core/Authenticator/IAuthenticator.cs index 7a41a8335..f5f02782c 100644 --- a/Snowflake.Data/Core/Authenticator/IAuthenticator.cs +++ b/Snowflake.Data/Core/Authenticator/IAuthenticator.cs @@ -101,6 +101,24 @@ protected void Login() /// The login request data to update. protected abstract void SetSpecializedAuthenticatorData(ref LoginRequestData data); + protected void SetSecondaryAuthenticationData(ref LoginRequestData data) + { + if (session.properties.TryGetValue(SFSessionProperty.PASSCODEINPASSWORD, out var passcodeInPasswordString) + && bool.TryParse(passcodeInPasswordString, out var passcodeInPassword) + && passcodeInPassword) + { + data.extAuthnDuoMethod = "passcode"; + } else if (session.properties.TryGetValue(SFSessionProperty.PASSCODE, out var passcode) && !string.IsNullOrEmpty(passcode)) + { + data.extAuthnDuoMethod = "passcode"; + data.passcode = passcode; + } + else + { + data.extAuthnDuoMethod = "push"; + } + } + /// /// Builds a simple login request. Each authenticator will fill the Data part with their /// specialized information. The common Data attributes are already filled (clientAppId, @@ -116,16 +134,20 @@ private SFRestRequest BuildLoginRequest() { loginName = session.properties[SFSessionProperty.USER], accountName = session.properties[SFSessionProperty.ACCOUNT], + // TODO LOCAL TEST MFA temp change should be removed before merge + // clientAppId = "JDBC",//SFEnvironment.DriverName, + // clientAppVersion = "3.12.16", // SFEnvironment.DriverVersion, clientAppId = SFEnvironment.DriverName, clientAppVersion = SFEnvironment.DriverVersion, clientEnv = ClientEnv, SessionParameters = session.ParameterMap, Authenticator = authName, }; - SetSpecializedAuthenticatorData(ref data); - return session.BuildTimeoutRestRequest(loginUrl, new LoginRequest() { data = data }); + return data.HttpTimeout.HasValue ? + session.BuildTimeoutRestRequest(loginUrl, new LoginRequest() { data = data }, data.HttpTimeout.Value) : + session.BuildTimeoutRestRequest(loginUrl, new LoginRequest() { data = data }); } } @@ -187,6 +209,10 @@ internal static IAuthenticator GetAuthenticator(SFSession session) return new OAuthAuthenticator(session); } + else if (type.Equals(MFACacheAuthenticator.AUTH_NAME, StringComparison.InvariantCultureIgnoreCase)) + { + return new MFACacheAuthenticator(session); + } // Okta would provide a url of form: https://xxxxxx.okta.com or https://xxxxxx.oktapreview.com or https://vanity.url/snowflake/okta else if (type.Contains("okta") && type.StartsWith("https://")) { diff --git a/Snowflake.Data/Core/Authenticator/KeyPairAuthenticator.cs b/Snowflake.Data/Core/Authenticator/KeyPairAuthenticator.cs index 7d86d02c9..44b9b8bec 100644 --- a/Snowflake.Data/Core/Authenticator/KeyPairAuthenticator.cs +++ b/Snowflake.Data/Core/Authenticator/KeyPairAuthenticator.cs @@ -75,6 +75,7 @@ protected override void SetSpecializedAuthenticatorData(ref LoginRequestData dat { // Add the token to the Data attribute data.Token = jwtToken; + SetSpecializedAuthenticatorData(ref data); } /// diff --git a/Snowflake.Data/Core/Authenticator/MFACacheAuthenticator.cs b/Snowflake.Data/Core/Authenticator/MFACacheAuthenticator.cs new file mode 100644 index 000000000..2d398352d --- /dev/null +++ b/Snowflake.Data/Core/Authenticator/MFACacheAuthenticator.cs @@ -0,0 +1,48 @@ +/* + * Copyright (c) 2024 Snowflake Computing Inc. All rights reserved. + */ + +using System; +using System.Threading; +using System.Threading.Tasks; +using Snowflake.Data.Core.Tools; + +namespace Snowflake.Data.Core.Authenticator +{ + class MFACacheAuthenticator : BaseAuthenticator, IAuthenticator + { + public const string AUTH_NAME = "username_password_mfa"; + private const int _MFA_LOGIN_HTTP_TIMEOUT = 60; + + internal MFACacheAuthenticator(SFSession session) : base(session, AUTH_NAME) + { + } + + /// + async Task IAuthenticator.AuthenticateAsync(CancellationToken cancellationToken) + { + await base.LoginAsync(cancellationToken); + } + + /// + void IAuthenticator.Authenticate() + { + base.Login(); + } + + /// + protected override void SetSpecializedAuthenticatorData(ref LoginRequestData data) + { + // Only need to add the password to Data for basic authentication + data.password = session.properties[SFSessionProperty.PASSWORD]; + data.SessionParameters[SFSessionParameter.CLIENT_REQUEST_MFA_TOKEN] = true; + data.HttpTimeout = TimeSpan.FromSeconds(_MFA_LOGIN_HTTP_TIMEOUT); + if (!string.IsNullOrEmpty(session._mfaToken?.ToString())) + { + data.Token = SecureStringHelper.Decode(session._mfaToken); + } + SetSecondaryAuthenticationData(ref data); + } + } + +} diff --git a/Snowflake.Data/Core/Authenticator/OAuthAuthenticator.cs b/Snowflake.Data/Core/Authenticator/OAuthAuthenticator.cs index f36d0353e..85599266e 100644 --- a/Snowflake.Data/Core/Authenticator/OAuthAuthenticator.cs +++ b/Snowflake.Data/Core/Authenticator/OAuthAuthenticator.cs @@ -1,7 +1,4 @@ using Snowflake.Data.Log; -using System; -using System.Collections.Generic; -using System.Text; using System.Threading; using System.Threading.Tasks; @@ -48,6 +45,7 @@ protected override void SetSpecializedAuthenticatorData(ref LoginRequestData dat data.Token = session.properties[SFSessionProperty.TOKEN]; // Remove the login name for an OAuth session data.loginName = ""; + SetSecondaryAuthenticationData(ref data); } } } diff --git a/Snowflake.Data/Core/Authenticator/OktaAuthenticator.cs b/Snowflake.Data/Core/Authenticator/OktaAuthenticator.cs index 7c364d3c5..164949864 100644 --- a/Snowflake.Data/Core/Authenticator/OktaAuthenticator.cs +++ b/Snowflake.Data/Core/Authenticator/OktaAuthenticator.cs @@ -248,6 +248,7 @@ private SamlRestRequest BuildSamlRestRequest(Uri ssoUrl, string onetimeToken) protected override void SetSpecializedAuthenticatorData(ref LoginRequestData data) { data.RawSamlResponse = _rawSamlTokenHtmlString; + SetSecondaryAuthenticationData(ref data); } private void VerifyUrls(Uri tokenOrSsoUrl, Uri sessionUrl) diff --git a/Snowflake.Data/Core/CredentialManager/Infrastructure/SFCredentialManagerFileImpl.cs b/Snowflake.Data/Core/CredentialManager/Infrastructure/SFCredentialManagerFileImpl.cs new file mode 100644 index 000000000..0f57aaebd --- /dev/null +++ b/Snowflake.Data/Core/CredentialManager/Infrastructure/SFCredentialManagerFileImpl.cs @@ -0,0 +1,150 @@ +/* + * Copyright (c) 2024 Snowflake Computing Inc. All rights reserved. + */ + +using Mono.Unix; +using Mono.Unix.Native; +using Newtonsoft.Json; +using Snowflake.Data.Client; +using Snowflake.Data.Core.Tools; +using Snowflake.Data.Log; +using System; +using System.IO; +using System.Runtime.InteropServices; +using KeyTokenDict = System.Collections.Generic.Dictionary; + +namespace Snowflake.Data.Core.CredentialManager.Infrastructure +{ + internal class SFCredentialManagerFileImpl : ISnowflakeCredentialManager + { + internal const string CredentialCacheDirectoryEnvironmentName = "SF_TEMPORARY_CREDENTIAL_CACHE_DIR"; + + internal const string CredentialCacheDirName = ".snowflake"; + + internal const string CredentialCacheFileName = "temporary_credential.json"; + + private static readonly SFLogger s_logger = SFLoggerFactory.GetLogger(); + + private readonly string _jsonCacheDirectory; + + private readonly string _jsonCacheFilePath; + + private readonly FileOperations _fileOperations; + + private readonly DirectoryOperations _directoryOperations; + + private readonly UnixOperations _unixOperations; + + private readonly EnvironmentOperations _environmentOperations; + + public static readonly SFCredentialManagerFileImpl Instance = new SFCredentialManagerFileImpl(FileOperations.Instance, DirectoryOperations.Instance, UnixOperations.Instance, EnvironmentOperations.Instance); + + internal SFCredentialManagerFileImpl(FileOperations fileOperations, DirectoryOperations directoryOperations, UnixOperations unixOperations, EnvironmentOperations environmentOperations) + { + _fileOperations = fileOperations; + _directoryOperations = directoryOperations; + _unixOperations = unixOperations; + _environmentOperations = environmentOperations; + SetCredentialCachePath(ref _jsonCacheDirectory, ref _jsonCacheFilePath); + } + + private void SetCredentialCachePath(ref string _jsonCacheDirectory, ref string _jsonCacheFilePath) + { + var customDirectory = _environmentOperations.GetEnvironmentVariable(CredentialCacheDirectoryEnvironmentName); + _jsonCacheDirectory = string.IsNullOrEmpty(customDirectory) ? Path.Combine(HomeDirectoryProvider.HomeDirectory(_environmentOperations), CredentialCacheDirName) : customDirectory; + if (!_directoryOperations.Exists(_jsonCacheDirectory)) + { + _directoryOperations.CreateDirectory(_jsonCacheDirectory); + } + _jsonCacheFilePath = Path.Combine(_jsonCacheDirectory, CredentialCacheFileName); + s_logger.Info($"Setting the json credential cache path to {_jsonCacheFilePath}"); + } + + internal void WriteToJsonFile(string content) + { + s_logger.Debug($"Writing credentials to json file in {_jsonCacheFilePath}"); + if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows)) + { + _fileOperations.Write(_jsonCacheFilePath, content); + } + else + { + if (!_directoryOperations.Exists(_jsonCacheDirectory)) + { + _directoryOperations.CreateDirectory(_jsonCacheDirectory); + } + s_logger.Info($"Creating the json file for credential cache in {_jsonCacheFilePath}"); + if (_fileOperations.Exists(_jsonCacheFilePath)) + { + s_logger.Info($"The existing json file for credential cache in {_jsonCacheFilePath} will be overwritten"); + } + var createFileResult = _unixOperations.CreateFileWithPermissions(_jsonCacheFilePath, + FilePermissions.S_IRUSR | FilePermissions.S_IWUSR | FilePermissions.S_IXUSR); + if (createFileResult == -1) + { + var errorMessage = "Failed to create the JSON token cache file"; + s_logger.Error(errorMessage); + throw new Exception(errorMessage); + } + else + { + _fileOperations.Write(_jsonCacheFilePath, content); + } + + var jsonPermissions = _unixOperations.GetFilePermissions(_jsonCacheFilePath); + if (jsonPermissions != FileAccessPermissions.UserReadWriteExecute) + { + var errorMessage = "Permission for the JSON token cache file should contain only the owner access"; + s_logger.Error(errorMessage); + throw new Exception(errorMessage); + } + } + } + + internal KeyTokenDict ReadJsonFile() + { + var contentFile = RuntimeInformation.IsOSPlatform(OSPlatform.Windows) ? File.ReadAllText(_jsonCacheFilePath) : _unixOperations.ReadAllText(_jsonCacheFilePath); + return JsonConvert.DeserializeObject(contentFile); + } + + public string GetCredentials(string key) + { + s_logger.Debug($"Getting credentials from json file in {_jsonCacheFilePath} for key: {key}"); + if (_fileOperations.Exists(_jsonCacheFilePath)) + { + var keyTokenPairs = ReadJsonFile(); + var hashKey = key.ToSha256Hash(); + if (keyTokenPairs.TryGetValue(hashKey, out string token)) + { + return token; + } + } + + s_logger.Info("Unable to get credentials for the specified key"); + return ""; + } + + public void RemoveCredentials(string key) + { + s_logger.Debug($"Removing credentials from json file in {_jsonCacheFilePath} for key: {key}"); + if (_fileOperations.Exists(_jsonCacheFilePath)) + { + var keyTokenPairs = ReadJsonFile(); + var hashKey = key.ToSha256Hash(); + keyTokenPairs.Remove(hashKey); + WriteToJsonFile(JsonConvert.SerializeObject(keyTokenPairs)); + } + } + + public void SaveCredentials(string key, string token) + { + s_logger.Debug($"Saving credentials to json file in {_jsonCacheFilePath} for key: {key}"); + var hashKey = key.ToSha256Hash(); + KeyTokenDict keyTokenPairs = _fileOperations.Exists(_jsonCacheFilePath) ? ReadJsonFile() : new KeyTokenDict(); + keyTokenPairs[hashKey] = token; + + string jsonString = JsonConvert.SerializeObject(keyTokenPairs); + WriteToJsonFile(jsonString); + } + } +} diff --git a/Snowflake.Data/Core/CredentialManager/Infrastructure/SFCredentialManagerInMemoryImpl.cs b/Snowflake.Data/Core/CredentialManager/Infrastructure/SFCredentialManagerInMemoryImpl.cs new file mode 100644 index 000000000..8ea1e86cc --- /dev/null +++ b/Snowflake.Data/Core/CredentialManager/Infrastructure/SFCredentialManagerInMemoryImpl.cs @@ -0,0 +1,51 @@ +/* + * Copyright (c) 2024 Snowflake Computing Inc. All rights reserved. + */ + + +using System.Collections.Generic; +using System.Security; +using Snowflake.Data.Client; +using Snowflake.Data.Core.Tools; +using Snowflake.Data.Log; + +namespace Snowflake.Data.Core.CredentialManager.Infrastructure +{ + internal class SFCredentialManagerInMemoryImpl : ISnowflakeCredentialManager + { + private static readonly SFLogger s_logger = SFLoggerFactory.GetLogger(); + + private Dictionary s_credentials = new Dictionary(); + + public static readonly SFCredentialManagerInMemoryImpl Instance = new SFCredentialManagerInMemoryImpl(); + + public string GetCredentials(string key) + { + s_logger.Debug($"Getting credentials from memory for key: {key}"); + var hashKey = key.ToSha256Hash(); + if (s_credentials.TryGetValue(hashKey, out var secureToken)) + { + return SecureStringHelper.Decode(secureToken); + } + else + { + s_logger.Info("Unable to get credentials for the specified key"); + return ""; + } + } + + public void RemoveCredentials(string key) + { + var hashKey = key.ToSha256Hash(); + s_logger.Debug($"Removing credentials from memory for key: {key}"); + s_credentials.Remove(hashKey); + } + + public void SaveCredentials(string key, string token) + { + var hashKey = key.ToSha256Hash(); + s_logger.Debug($"Saving credentials into memory for key: {hashKey}"); + s_credentials[hashKey] = SecureStringHelper.Encode(token); + } + } +} diff --git a/Snowflake.Data/Core/CredentialManager/Infrastructure/SFCredentialManagerWindowsNativeImpl.cs b/Snowflake.Data/Core/CredentialManager/Infrastructure/SFCredentialManagerWindowsNativeImpl.cs new file mode 100644 index 000000000..264091ad9 --- /dev/null +++ b/Snowflake.Data/Core/CredentialManager/Infrastructure/SFCredentialManagerWindowsNativeImpl.cs @@ -0,0 +1,129 @@ +/* + * Copyright (c) 2024 Snowflake Computing Inc. All rights reserved. + */ + +using Microsoft.Win32.SafeHandles; +using System; +using System.Runtime.InteropServices; +using System.Text; +using Snowflake.Data.Client; +using Snowflake.Data.Core.Tools; +using Snowflake.Data.Log; + +namespace Snowflake.Data.Core.CredentialManager.Infrastructure +{ + + internal class SFCredentialManagerWindowsNativeImpl : ISnowflakeCredentialManager + { + private static readonly SFLogger s_logger = SFLoggerFactory.GetLogger(); + + public static readonly SFCredentialManagerWindowsNativeImpl Instance = new SFCredentialManagerWindowsNativeImpl(); + + public string GetCredentials(string key) + { + s_logger.Debug($"Getting the credentials for key: {key}"); + var hashKey = key.ToSha256Hash(); + IntPtr nCredPtr; + if (!CredRead(hashKey, 1 /* Generic */, 0, out nCredPtr)) + { + s_logger.Info($"Unable to get credentials for key: {key}"); + return ""; + } + + using (var critCred = new CriticalCredentialHandle(nCredPtr)) + { + var cred = critCred.GetCredential(); + return cred.CredentialBlob; + } + } + + public void RemoveCredentials(string key) + { + s_logger.Debug($"Removing the credentials for key: {key}"); + + var hashKey = key.ToSha256Hash(); + if (!CredDelete(hashKey, 1 /* Generic */, 0)) + { + s_logger.Info($"Unable to remove credentials because the specified key did not exist: {key}"); + } + } + + public void SaveCredentials(string key, string token) + { + s_logger.Debug($"Saving the credentials for key: {key}"); + var hashKey = key.ToSha256Hash(); + byte[] byteArray = Encoding.Unicode.GetBytes(token); + Credential credential = new Credential(); + credential.AttributeCount = 0; + credential.Attributes = IntPtr.Zero; + credential.Comment = IntPtr.Zero; + credential.TargetAlias = IntPtr.Zero; + credential.Type = 1; // Generic + credential.Persist = 2; // Local Machine + credential.CredentialBlobSize = (uint)(byteArray == null ? 0 : byteArray.Length); + credential.TargetName = hashKey; + credential.CredentialBlob = token; + credential.UserName = Environment.UserName; + + CredWrite(ref credential, 0); + } + + [StructLayout(LayoutKind.Sequential, CharSet = CharSet.Unicode)] + private struct Credential + { + public uint Flags; + public uint Type; + [MarshalAs(UnmanagedType.LPWStr)] + public string TargetName; + public IntPtr Comment; + public System.Runtime.InteropServices.ComTypes.FILETIME LastWritten; + public uint CredentialBlobSize; + [MarshalAs(UnmanagedType.LPWStr)] + public string CredentialBlob; + public uint Persist; + public uint AttributeCount; + public IntPtr Attributes; + public IntPtr TargetAlias; + [MarshalAs(UnmanagedType.LPWStr)] + public string UserName; + } + + sealed class CriticalCredentialHandle : CriticalHandleZeroOrMinusOneIsInvalid + { + public CriticalCredentialHandle(IntPtr handle) + { + SetHandle(handle); + } + + public Credential GetCredential() + { + var credential = (Credential)Marshal.PtrToStructure(handle, typeof(Credential)); + return credential; + } + + protected override bool ReleaseHandle() + { + if (IsInvalid) + { + return false; + } + + CredFree(handle); + SetHandleAsInvalid(); + return true; + } + } + + [DllImport("Advapi32.dll", EntryPoint = "CredDeleteW", CharSet = CharSet.Unicode, SetLastError = true)] + internal static extern bool CredDelete(string target, uint type, int reservedFlag); + + [DllImport("Advapi32.dll", EntryPoint = "CredReadW", CharSet = CharSet.Unicode, SetLastError = true)] + static extern bool CredRead(string target, uint type, int reservedFlag, out IntPtr credentialPtr); + + [DllImport("Advapi32.dll", EntryPoint = "CredWriteW", CharSet = CharSet.Unicode, SetLastError = true)] + static extern bool CredWrite([In] ref Credential userCredential, [In] uint flags); + + [DllImport("Advapi32.dll", EntryPoint = "CredFree", SetLastError = true)] + static extern bool CredFree([In] IntPtr cred); + } +} diff --git a/Snowflake.Data/Core/CredentialManager/TokenType.cs b/Snowflake.Data/Core/CredentialManager/TokenType.cs new file mode 100644 index 000000000..cdeb063d2 --- /dev/null +++ b/Snowflake.Data/Core/CredentialManager/TokenType.cs @@ -0,0 +1,14 @@ +/* + * Copyright (c) 2024 Snowflake Computing Inc. All rights reserved. + */ + +namespace Snowflake.Data.Core.CredentialManager +{ + internal enum TokenType + { + [StringAttr(value = "ID_TOKEN")] + IdToken, + [StringAttr(value = "MFA_TOKEN")] + MFAToken + } +} diff --git a/Snowflake.Data/Core/ErrorMessages.resx b/Snowflake.Data/Core/ErrorMessages.resx index 3532f3394..664122e11 100755 --- a/Snowflake.Data/Core/ErrorMessages.resx +++ b/Snowflake.Data/Core/ErrorMessages.resx @@ -180,6 +180,9 @@ Snowflake type {0} is not supported for parameters. + + Invalid browser url "{0}" cannot be used for authentication. + Browser response timed out after {0} seconds. diff --git a/Snowflake.Data/Core/RestRequest.cs b/Snowflake.Data/Core/RestRequest.cs index 112743f77..de988895b 100644 --- a/Snowflake.Data/Core/RestRequest.cs +++ b/Snowflake.Data/Core/RestRequest.cs @@ -27,7 +27,7 @@ internal abstract class BaseRestRequest : IRestRequest internal static string REST_REQUEST_TIMEOUT_KEY = "TIMEOUT_PER_REST_REQUEST"; - // The default Rest timeout. Set to 120 seconds. + // The default Rest timeout. Set to 120 seconds. public static int DEFAULT_REST_RETRY_SECONDS_TIMEOUT = 120; internal Uri Url { get; set; } @@ -133,7 +133,7 @@ internal SFRestRequest() : base() public override string ToString() { - return String.Format("SFRestRequest {{url: {0}, request body: {1} }}", Url.ToString(), + return String.Format("SFRestRequest {{url: {0}, request body: {1} }}", Url.ToString(), jsonBody.ToString()); } @@ -259,12 +259,21 @@ class LoginRequestData [JsonProperty(PropertyName = "PROOF_KEY", NullValueHandling = NullValueHandling.Ignore)] internal string ProofKey { get; set; } + [JsonProperty(PropertyName = "EXT_AUTHN_DUO_METHOD", NullValueHandling = NullValueHandling.Ignore)] + internal string extAuthnDuoMethod { get; set; } + + [JsonProperty(PropertyName = "PASSCODE", NullValueHandling = NullValueHandling.Ignore)] + internal string passcode; + [JsonProperty(PropertyName = "SESSION_PARAMETERS", NullValueHandling = NullValueHandling.Ignore)] internal Dictionary SessionParameters { get; set; } + [JsonIgnore] + internal TimeSpan? HttpTimeout { get; set; } + public override string ToString() { - return String.Format("LoginRequestData {{ClientAppVersion: {0},\n AccountName: {1},\n loginName: {2},\n ClientEnv: {3},\n authenticator: {4} }}", + return String.Format("LoginRequestData {{ClientAppVersion: {0},\n AccountName: {1},\n loginName: {2},\n ClientEnv: {3},\n authenticator: {4} }}", clientAppVersion, accountName, loginName, clientEnv.ToString(), Authenticator); } } @@ -291,7 +300,7 @@ class LoginRequestClientEnv public override string ToString() { - return String.Format("{{ APPLICATION: {0}, OS_VERSION: {1}, NET_RUNTIME: {2}, NET_VERSION: {3}, INSECURE_MODE: {4} }}", + return String.Format("{{ APPLICATION: {0}, OS_VERSION: {1}, NET_RUNTIME: {2}, NET_VERSION: {3}, INSECURE_MODE: {4} }}", application, osVersion, netRuntime, netVersion, insecureMode); } } diff --git a/Snowflake.Data/Core/RestResponse.cs b/Snowflake.Data/Core/RestResponse.cs index 64275fa42..fcdc68683 100755 --- a/Snowflake.Data/Core/RestResponse.cs +++ b/Snowflake.Data/Core/RestResponse.cs @@ -16,9 +16,11 @@ abstract class BaseRestResponse [JsonProperty(PropertyName = "message")] internal String message { get; set; } + [JsonProperty(PropertyName = "code", NullValueHandling = NullValueHandling.Ignore)] internal int code { get; set; } + [JsonProperty(PropertyName = "success")] internal bool success { get; set; } @@ -91,6 +93,9 @@ internal class LoginResponseData [JsonProperty(PropertyName = "masterValidityInSeconds", NullValueHandling = NullValueHandling.Ignore)] internal int masterValidityInSeconds { get; set; } + + [JsonProperty(PropertyName = "mfaToken", NullValueHandling = NullValueHandling.Ignore)] + internal string mfaToken { get; set; } } internal class AuthenticatorResponseData diff --git a/Snowflake.Data/Core/SFError.cs b/Snowflake.Data/Core/SFError.cs old mode 100755 new mode 100644 index 44de969a1..b87dcd97f --- a/Snowflake.Data/Core/SFError.cs +++ b/Snowflake.Data/Core/SFError.cs @@ -3,6 +3,8 @@ */ using System; +using System.Collections.Generic; +using System.Linq; namespace Snowflake.Data.Core { @@ -92,7 +94,39 @@ public enum SFError STRUCTURED_TYPE_READ_ERROR, [SFErrorAttr(errorCode = 270062)] - STRUCTURED_TYPE_READ_DETAILED_ERROR + STRUCTURED_TYPE_READ_DETAILED_ERROR, + + [SFErrorAttr(errorCode = 390120)] + EXT_AUTHN_DENIED, + + [SFErrorAttr(errorCode = 390123)] + EXT_AUTHN_LOCKED, + + [SFErrorAttr(errorCode = 390126)] + EXT_AUTHN_TIMEOUT, + + [SFErrorAttr(errorCode = 390127)] + EXT_AUTHN_INVALID, + + [SFErrorAttr(errorCode = 390129)] + EXT_AUTHN_EXCEPTION, + } + + class SFMFATokenErrors + { + private static List InvalidMFATokenErrors = new List + { + SFError.EXT_AUTHN_DENIED, + SFError.EXT_AUTHN_LOCKED, + SFError.EXT_AUTHN_TIMEOUT, + SFError.EXT_AUTHN_INVALID, + SFError.EXT_AUTHN_EXCEPTION + }; + + public static bool IsInvalidMFATokenContinueError(int error) + { + return InvalidMFATokenErrors.Any(e => e.GetAttribute().errorCode == error); + } } class SFErrorAttr : Attribute diff --git a/Snowflake.Data/Core/Session/ConnectionCacheManager.cs b/Snowflake.Data/Core/Session/ConnectionCacheManager.cs index febecbbce..538221b09 100644 --- a/Snowflake.Data/Core/Session/ConnectionCacheManager.cs +++ b/Snowflake.Data/Core/Session/ConnectionCacheManager.cs @@ -11,9 +11,9 @@ namespace Snowflake.Data.Core.Session internal sealed class ConnectionCacheManager : IConnectionManager { private readonly SessionPool _sessionPool = SessionPool.CreateSessionCache(); - public SFSession GetSession(string connectionString, SecureString password) => _sessionPool.GetSession(connectionString, password); - public Task GetSessionAsync(string connectionString, SecureString password, CancellationToken cancellationToken) - => _sessionPool.GetSessionAsync(connectionString, password, cancellationToken); + public SFSession GetSession(string connectionString, SecureString password, SecureString passcode) => _sessionPool.GetSession(connectionString, password, passcode); + public Task GetSessionAsync(string connectionString, SecureString password, SecureString passcode, CancellationToken cancellationToken) + => _sessionPool.GetSessionAsync(connectionString, password, passcode, cancellationToken); public bool AddSession(SFSession session) => _sessionPool.AddSession(session, false); public void ReleaseBusySession(SFSession session) => _sessionPool.ReleaseBusySession(session); public void ClearAllPools() => _sessionPool.ClearSessions(); diff --git a/Snowflake.Data/Core/Session/ConnectionPoolManager.cs b/Snowflake.Data/Core/Session/ConnectionPoolManager.cs index 09bfa5821..6a0013bb0 100644 --- a/Snowflake.Data/Core/Session/ConnectionPoolManager.cs +++ b/Snowflake.Data/Core/Session/ConnectionPoolManager.cs @@ -29,16 +29,16 @@ internal ConnectionPoolManager() } } - public SFSession GetSession(string connectionString, SecureString password) + public SFSession GetSession(string connectionString, SecureString password, SecureString passcode) { s_logger.Debug($"ConnectionPoolManager::GetSession"); - return GetPool(connectionString, password).GetSession(); + return GetPool(connectionString, password).GetSession(passcode); } - public Task GetSessionAsync(string connectionString, SecureString password, CancellationToken cancellationToken) + public Task GetSessionAsync(string connectionString, SecureString password, SecureString passcode, CancellationToken cancellationToken) { s_logger.Debug($"ConnectionPoolManager::GetSessionAsync"); - return GetPool(connectionString, password).GetSessionAsync(cancellationToken); + return GetPool(connectionString, password).GetSessionAsync(passcode, cancellationToken); } public bool AddSession(SFSession session) diff --git a/Snowflake.Data/Core/Session/IConnectionManager.cs b/Snowflake.Data/Core/Session/IConnectionManager.cs index 01cfa3e8c..5d3885de4 100644 --- a/Snowflake.Data/Core/Session/IConnectionManager.cs +++ b/Snowflake.Data/Core/Session/IConnectionManager.cs @@ -10,8 +10,8 @@ namespace Snowflake.Data.Core.Session { internal interface IConnectionManager { - SFSession GetSession(string connectionString, SecureString password); - Task GetSessionAsync(string connectionString, SecureString password, CancellationToken cancellationToken); + SFSession GetSession(string connectionString, SecureString password, SecureString passcode = null); + Task GetSessionAsync(string connectionString, SecureString password, SecureString passcode, CancellationToken cancellationToken); bool AddSession(SFSession session); void ReleaseBusySession(SFSession session); void ClearAllPools(); diff --git a/Snowflake.Data/Core/Session/ISessionFactory.cs b/Snowflake.Data/Core/Session/ISessionFactory.cs index f9416de8d..fbc896fda 100644 --- a/Snowflake.Data/Core/Session/ISessionFactory.cs +++ b/Snowflake.Data/Core/Session/ISessionFactory.cs @@ -4,6 +4,6 @@ namespace Snowflake.Data.Core.Session { internal interface ISessionFactory { - SFSession NewSession(string connectionString, SecureString password); + SFSession NewSession(string connectionString, SecureString password, SecureString passcode); } } diff --git a/Snowflake.Data/Core/Session/SFSession.cs b/Snowflake.Data/Core/Session/SFSession.cs old mode 100755 new mode 100644 index b6a0ebf79..f09e6cd2f --- a/Snowflake.Data/Core/Session/SFSession.cs +++ b/Snowflake.Data/Core/Session/SFSession.cs @@ -1,4 +1,4 @@ -/* +/* * Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved. */ @@ -14,6 +14,7 @@ using System.Threading.Tasks; using System.Net.Http; using System.Text.RegularExpressions; +using Snowflake.Data.Core.CredentialManager; using Snowflake.Data.Core.Session; using Snowflake.Data.Core.Tools; @@ -73,6 +74,8 @@ public class SFSession internal string ConnectionString { get; } internal SecureString Password { get; } + internal SecureString Passcode { get; } + private QueryContextCache _queryContextCache = new QueryContextCache(_defaultQueryContextCacheSize); private int _queryContextCacheSize = _defaultQueryContextCacheSize; @@ -98,6 +101,8 @@ public void SetPooling(bool isEnabled) internal String _queryTag; + internal SecureString _mfaToken; + internal void ProcessLoginResponse(LoginResponse authnResponse) { if (authnResponse.success) @@ -116,6 +121,12 @@ internal void ProcessLoginResponse(LoginResponse authnResponse) { logger.Debug("Query context cache disabled."); } + if (!string.IsNullOrEmpty(authnResponse.data.mfaToken)) + { + _mfaToken = SecureStringHelper.Encode(authnResponse.data.mfaToken); + var key = SnowflakeCredentialManagerFactory.BuildCredentialKey(properties[SFSessionProperty.HOST], properties[SFSessionProperty.USER], TokenType.MFAToken, properties[SFSessionProperty.AUTHENTICATOR]); + SnowflakeCredentialManagerFactory.GetCredentialManager().SaveCredentials(key, authnResponse.data.mfaToken); + } logger.Debug($"Session opened: {sessionId}"); _startTime = DateTimeOffset.UtcNow.ToUnixTimeMilliseconds(); } @@ -128,6 +139,14 @@ internal void ProcessLoginResponse(LoginResponse authnResponse) ""); logger.Error("Authentication failed", e); + if (SFMFATokenErrors.IsInvalidMFATokenContinueError(e.ErrorCode)) + { + logger.Info($"Unable to use cached MFA token is expired or invalid. Fails with the {e.Message}. ", e); + _mfaToken = null; + var mfaKey = SnowflakeCredentialManagerFactory.BuildCredentialKey(properties[SFSessionProperty.HOST], properties[SFSessionProperty.USER], TokenType.MFAToken, properties[SFSessionProperty.AUTHENTICATOR]); + SnowflakeCredentialManagerFactory.GetCredentialManager().RemoveCredentials(mfaKey); + } + throw e; } } @@ -158,19 +177,22 @@ internal Uri BuildLoginUrl() /// A string in the form of "key1=value1;key2=value2" internal SFSession( String connectionString, - SecureString password) : this(connectionString, password, EasyLoggingStarter.Instance) + SecureString password, + SecureString passcode = null) : this(connectionString, password, passcode, EasyLoggingStarter.Instance) { } internal SFSession( String connectionString, SecureString password, + SecureString passcode, EasyLoggingStarter easyLoggingStarter) { _easyLoggingStarter = easyLoggingStarter; ConnectionString = connectionString; Password = password; - properties = SFSessionProperties.ParseConnectionString(ConnectionString, Password); + Passcode = passcode; + properties = SFSessionProperties.ParseConnectionString(ConnectionString, Password, Passcode); _disableQueryContextCache = bool.Parse(properties[SFSessionProperty.DISABLEQUERYCONTEXTCACHE]); _disableConsoleLogin = bool.Parse(properties[SFSessionProperty.DISABLE_CONSOLE_LOGIN]); properties.TryGetValue(SFSessionProperty.USER, out _user); @@ -190,6 +212,12 @@ internal SFSession( _maxRetryCount = extractedProperties.maxHttpRetries; _maxRetryTimeout = extractedProperties.retryTimeout; _disableSamlUrlCheck = extractedProperties._disableSamlUrlCheck; + + if (properties.TryGetValue(SFSessionProperty.AUTHENTICATOR, out var _authenticatorType) && _authenticatorType == "username_password_mfa") + { + var mfaKey = SnowflakeCredentialManagerFactory.BuildCredentialKey(properties[SFSessionProperty.HOST], properties[SFSessionProperty.USER], TokenType.MFAToken, _authenticatorType); + _mfaToken = SecureStringHelper.Encode(SnowflakeCredentialManagerFactory.GetCredentialManager().GetCredentials(mfaKey)); + } } catch (SnowflakeDbException e) { @@ -221,7 +249,11 @@ private void ValidateApplicationName(SFSessionProperties properties) } } - internal SFSession(String connectionString, SecureString password, IMockRestRequester restRequester) : this(connectionString, password) + internal SFSession(String connectionString, SecureString password, IMockRestRequester restRequester) : this(connectionString, password, null, EasyLoggingStarter.Instance, restRequester) + { + } + + internal SFSession(String connectionString, SecureString password, SecureString passcode, EasyLoggingStarter easyLoggingStarter, IMockRestRequester restRequester) : this(connectionString, password, passcode, easyLoggingStarter) { // Inject the HttpClient to use with the Mock requester restRequester.setHttpClient(_HttpClient); @@ -428,6 +460,19 @@ internal SFRestRequest BuildTimeoutRestRequest(Uri uri, Object body) }; } + internal SFRestRequest BuildTimeoutRestRequest(Uri uri, Object body, TimeSpan httpTimeout) + { + return new SFRestRequest() + { + jsonBody = body, + Url = uri, + authorizationToken = SF_AUTHORIZATION_BASIC, + RestTimeout = connectionTimeout, + HttpTimeout = httpTimeout, + _isLogin = true + }; + } + internal void UpdateSessionParameterMap(List parameterList) { logger.Debug("Update parameter map"); diff --git a/Snowflake.Data/Core/Session/SFSessionParameter.cs b/Snowflake.Data/Core/Session/SFSessionParameter.cs index 97fdcec23..7d25c6e01 100755 --- a/Snowflake.Data/Core/Session/SFSessionParameter.cs +++ b/Snowflake.Data/Core/Session/SFSessionParameter.cs @@ -14,5 +14,6 @@ internal enum SFSessionParameter QUERY_CONTEXT_CACHE_SIZE, DATE_OUTPUT_FORMAT, TIME_OUTPUT_FORMAT, + CLIENT_REQUEST_MFA_TOKEN, } } diff --git a/Snowflake.Data/Core/Session/SFSessionProperty.cs b/Snowflake.Data/Core/Session/SFSessionProperty.cs index 07896ae14..a9663961d 100644 --- a/Snowflake.Data/Core/Session/SFSessionProperty.cs +++ b/Snowflake.Data/Core/Session/SFSessionProperty.cs @@ -1,4 +1,4 @@ -/* +/* * Copyright (c) 2012-2021 Snowflake Computing Inc. All rights reserved. */ @@ -112,7 +112,11 @@ internal enum SFSessionProperty [SFSessionPropertyAttr(required = false, defaultValue = "true")] POOLINGENABLED, [SFSessionPropertyAttr(required = false, defaultValue = "false")] - DISABLE_SAML_URL_CHECK + DISABLE_SAML_URL_CHECK, + [SFSessionPropertyAttr(required = false, IsSecret = true)] + PASSCODE, + [SFSessionPropertyAttr(required = false, defaultValue = "false")] + PASSCODEINPASSWORD } class SFSessionPropertyAttr : Attribute @@ -181,7 +185,7 @@ public override int GetHashCode() return base.GetHashCode(); } - internal static SFSessionProperties ParseConnectionString(string connectionString, SecureString password) + internal static SFSessionProperties ParseConnectionString(string connectionString, SecureString password, SecureString passcode = null) { logger.Info("Start parsing connection string."); var builder = new DbConnectionStringBuilder(); @@ -257,7 +261,13 @@ internal static SFSessionProperties ParseConnectionString(string connectionStrin properties[SFSessionProperty.PASSWORD] = SecureStringHelper.Decode(password); } + if (passcode != null && passcode.Length > 0) + { + properties[SFSessionProperty.PASSCODE] = SecureStringHelper.Decode(passcode); + } + ValidateAuthenticator(properties); + ValidatePasscodeInPassword(properties); properties.IsPoolingEnabledValueProvided = properties.IsNonEmptyValueProvided(SFSessionProperty.POOLINGENABLED); CheckSessionProperties(properties); ValidateFileTransferMaxBytesInMemoryProperty(properties); @@ -303,7 +313,8 @@ private static void ValidateAuthenticator(SFSessionProperties properties) OktaAuthenticator.AUTH_NAME, OAuthAuthenticator.AUTH_NAME, KeyPairAuthenticator.AUTH_NAME, - ExternalBrowserAuthenticator.AUTH_NAME + ExternalBrowserAuthenticator.AUTH_NAME, + MFACacheAuthenticator.AUTH_NAME }; if (properties.TryGetValue(SFSessionProperty.AUTHENTICATOR, out var authenticator)) @@ -318,6 +329,23 @@ private static void ValidateAuthenticator(SFSessionProperties properties) } } + private static void ValidatePasscodeInPassword(SFSessionProperties properties) + { + if (properties.TryGetValue(SFSessionProperty.PASSCODEINPASSWORD, out var passCodeInPassword)) + { + if (!bool.TryParse(passCodeInPassword, out _)) + { + var errorMessage = $"Invalid value of {SFSessionProperty.PASSCODEINPASSWORD.ToString()} parameter"; + logger.Error(errorMessage); + throw new SnowflakeDbException( + new Exception(errorMessage), + SFError.INVALID_CONNECTION_PARAMETER_VALUE, + "", + SFSessionProperty.PASSCODEINPASSWORD.ToString()); + } + } + } + internal bool IsNonEmptyValueProvided(SFSessionProperty property) => TryGetValue(property, out var propertyValueStr) && !string.IsNullOrEmpty(propertyValueStr); diff --git a/Snowflake.Data/Core/Session/SessionFactory.cs b/Snowflake.Data/Core/Session/SessionFactory.cs index 2eb0ba6df..a1795ba10 100644 --- a/Snowflake.Data/Core/Session/SessionFactory.cs +++ b/Snowflake.Data/Core/Session/SessionFactory.cs @@ -4,9 +4,9 @@ namespace Snowflake.Data.Core.Session { internal class SessionFactory : ISessionFactory { - public SFSession NewSession(string connectionString, SecureString password) + public SFSession NewSession(string connectionString, SecureString password, SecureString passcode) { - return new SFSession(connectionString, password); + return new SFSession(connectionString, password, passcode, EasyLoggingStarter.Instance); } } } diff --git a/Snowflake.Data/Core/Session/SessionPool.cs b/Snowflake.Data/Core/Session/SessionPool.cs index de66c2240..d58c06223 100644 --- a/Snowflake.Data/Core/Session/SessionPool.cs +++ b/Snowflake.Data/Core/Session/SessionPool.cs @@ -9,11 +9,13 @@ using System.Threading; using System.Threading.Tasks; using Snowflake.Data.Client; +using Snowflake.Data.Core.Authenticator; using Snowflake.Data.Core.Tools; using Snowflake.Data.Log; namespace Snowflake.Data.Core.Session { + sealed class SessionPool : IDisposable { private static readonly SFLogger s_logger = SFLoggerFactory.GetLogger(); @@ -133,34 +135,71 @@ internal void ValidateSecurePassword(SecureString password) private string ExtractPassword(SecureString password) => password == null ? string.Empty : SecureStringHelper.Decode(password); - internal SFSession GetSession(string connStr, SecureString password) + internal SFSession GetSession(string connStr, SecureString password, SecureString passcode) { s_logger.Debug("SessionPool::GetSession" + PoolIdentification()); + var sessionProperties = SFSessionProperties.ParseConnectionString(connStr, password); + ValidateMinPoolSizeWithPasscode(sessionProperties, passcode); if (!GetPooling()) - return NewNonPoolingSession(connStr, password); - var sessionOrCreateTokens = GetIdleSession(connStr); + return NewNonPoolingSession(connStr, password, passcode); + var isMfaAuthentication = sessionProperties.TryGetValue(SFSessionProperty.AUTHENTICATOR, out var authenticator) && authenticator == MFACacheAuthenticator.AUTH_NAME; + var sessionOrCreateTokens = GetIdleSession(connStr, isMfaAuthentication ? 1 : int.MaxValue); if (sessionOrCreateTokens.Session != null) { _sessionPoolEventHandler.OnSessionProvided(this); } ScheduleNewIdleSessions(connStr, password, sessionOrCreateTokens.BackgroundSessionCreationTokens()); WarnAboutOverridenConfig(); - return sessionOrCreateTokens.Session ?? NewSession(connStr, password, sessionOrCreateTokens.SessionCreationToken()); + var session = sessionOrCreateTokens.Session ?? NewSession(connStr, password, passcode, sessionOrCreateTokens.SessionCreationToken()); + if (isMfaAuthentication) + { + ScheduleNewIdleSessions(connStr, password, RegisterSessionCreationsToEnsureMinPoolSize()); + } + return session; } - internal async Task GetSessionAsync(string connStr, SecureString password, CancellationToken cancellationToken) + private void ValidateMinPoolSizeWithPasscode(SFSessionProperties sessionProperties, SecureString passcode) + { + if (!GetPooling() || !IsMultiplePoolsVersion() || _poolConfig.MinPoolSize == 0) return; + var isUsingPasscode = (passcode != null && passcode.Length > 0) || (sessionProperties.IsNonEmptyValueProvided(SFSessionProperty.PASSCODE) || + (sessionProperties.TryGetValue(SFSessionProperty.PASSCODEINPASSWORD, out var passcodeInPasswordValue) && + bool.TryParse(passcodeInPasswordValue, out var isPasscodeinPassword) && isPasscodeinPassword)); + var isMfaAuthenticator = sessionProperties.TryGetValue(SFSessionProperty.AUTHENTICATOR, out var authenticator) && + authenticator == MFACacheAuthenticator.AUTH_NAME; + if(isUsingPasscode && !isMfaAuthenticator) + { + const string ErrorMessage = "Passcode with MinPoolSize feature of connection pool allowed only for username_password_mfa authentication"; + s_logger.Error(ErrorMessage + PoolIdentification()); + throw new SnowflakeDbException(SFError.INVALID_CONNECTION_STRING, ErrorMessage); + } + } + + internal async Task GetSessionAsync(string connStr, SecureString password, SecureString passcode, CancellationToken cancellationToken) { s_logger.Debug("SessionPool::GetSessionAsync" + PoolIdentification()); + var sessionProperties = SFSessionProperties.ParseConnectionString(connStr, password); + ValidateMinPoolSizeWithPasscode(sessionProperties, passcode); if (!GetPooling()) - return await NewNonPoolingSessionAsync(connStr, password, cancellationToken).ConfigureAwait(false); - var sessionOrCreateTokens = GetIdleSession(connStr); + return await NewNonPoolingSessionAsync(connStr, password, passcode, cancellationToken).ConfigureAwait(false); + var isMfaAuthentication = sessionProperties.TryGetValue(SFSessionProperty.AUTHENTICATOR, out var authenticator) && authenticator == MFACacheAuthenticator.AUTH_NAME; + var sessionOrCreateTokens = GetIdleSession(connStr, isMfaAuthentication ? 1 : int.MaxValue); + WarnAboutOverridenConfig(); + if (sessionOrCreateTokens.Session != null) { _sessionPoolEventHandler.OnSessionProvided(this); } ScheduleNewIdleSessions(connStr, password, sessionOrCreateTokens.BackgroundSessionCreationTokens()); WarnAboutOverridenConfig(); - return sessionOrCreateTokens.Session ?? await NewSessionAsync(connStr, password, sessionOrCreateTokens.SessionCreationToken(), cancellationToken).ConfigureAwait(false); + var session = sessionOrCreateTokens.Session ?? + await NewSessionAsync(connStr, password, passcode, sessionOrCreateTokens.SessionCreationToken(), cancellationToken) + .ConfigureAwait(false); + if (isMfaAuthentication) + { + ScheduleNewIdleSessions(connStr, password, RegisterSessionCreationsToEnsureMinPoolSize()); + } + return session; + } private void ScheduleNewIdleSessions(string connStr, SecureString password, List tokens) @@ -172,7 +211,7 @@ private void ScheduleNewIdleSession(string connStr, SecureString password, Sessi { Task.Run(() => { - var session = NewSession(connStr, password, token); + var session = NewSession(connStr, password, null, token); AddSession(session, false); // we don't want to ensure min pool size here because we could get into infinite recursion if expirationTimeout would be very low }); } @@ -187,17 +226,17 @@ private void WarnAboutOverridenConfig() internal bool IsConfigOverridden() => _configOverriden; - internal SFSession GetSession() => GetSession(ConnectionString, Password); + internal SFSession GetSession(SecureString passcode) => GetSession(ConnectionString, Password, passcode); - internal Task GetSessionAsync(CancellationToken cancellationToken) => - GetSessionAsync(ConnectionString, Password, cancellationToken); + internal Task GetSessionAsync(SecureString passcode, CancellationToken cancellationToken) => + GetSessionAsync(ConnectionString, Password, passcode, cancellationToken); internal void SetSessionPoolEventHandler(ISessionPoolEventHandler sessionPoolEventHandler) { _sessionPoolEventHandler = sessionPoolEventHandler; } - private SessionOrCreationTokens GetIdleSession(string connStr) + private SessionOrCreationTokens GetIdleSession(string connStr, int maxSessions) { s_logger.Debug("SessionPool::GetIdleSession" + PoolIdentification()); lock (_sessionPoolLock) @@ -215,7 +254,7 @@ private SessionOrCreationTokens GetIdleSession(string connStr) return new SessionOrCreationTokens(session); } s_logger.Debug("SessionPool::GetIdleSession - no thread was waiting for a session, but could not find any idle session available in the pool" + PoolIdentification()); - var sessionsCount = AllowedNumberOfNewSessionCreations(1); + var sessionsCount = AllowedNumberOfNewSessionCreations(1, maxSessions); if (sessionsCount > 0) { // there is no need to wait for a session since we can create new ones @@ -226,7 +265,7 @@ private SessionOrCreationTokens GetIdleSession(string connStr) return new SessionOrCreationTokens(WaitForSession(connStr)); } - private List RegisterSessionCreationsWhenReturningSessionToPool() + private List RegisterSessionCreationsToEnsureMinPoolSize() { var count = AllowedNumberOfNewSessionCreations(0); return RegisterSessionCreations(count); @@ -237,7 +276,7 @@ private List RegisterSessionCreations(int sessionsCount) = .Select(_ => _sessionCreationTokenCounter.NewToken()) .ToList(); - private int AllowedNumberOfNewSessionCreations(int atLeastCount) + private int AllowedNumberOfNewSessionCreations(int atLeastCount, int maxSessionsLimit = int.MaxValue) { // we are expecting to create atLeast 1 session in case of opening a connection (atLeastCount = 1) // but we have no expectations when closing a connection (atLeastCount = 0) @@ -252,7 +291,7 @@ private int AllowedNumberOfNewSessionCreations(int atLeastCount) { var maxSessionsToCreate = _poolConfig.MaxPoolSize - currentSize; var sessionsNeeded = Math.Max(_poolConfig.MinPoolSize - currentSize, atLeastCount); - var sessionsToCreate = Math.Min(sessionsNeeded, maxSessionsToCreate); + var sessionsToCreate = Math.Min(maxSessionsLimit, Math.Min(sessionsNeeded, maxSessionsToCreate)); s_logger.Debug($"SessionPool - allowed to create {sessionsToCreate} sessions, current pool size is {currentSize} out of {_poolConfig.MaxPoolSize}" + PoolIdentification()); return sessionsToCreate; } @@ -326,15 +365,15 @@ private SFSession ExtractIdleSession(string connStr) return null; } - private SFSession NewNonPoolingSession(String connectionString, SecureString password) => - NewSession(connectionString, password, _noPoolingSessionCreationTokenCounter.NewToken()); + private SFSession NewNonPoolingSession(String connectionString, SecureString password, SecureString passcode) => + NewSession(connectionString, password, passcode, _noPoolingSessionCreationTokenCounter.NewToken()); - private SFSession NewSession(String connectionString, SecureString password, SessionCreationToken sessionCreationToken) + private SFSession NewSession(String connectionString, SecureString password, SecureString passcode, SessionCreationToken sessionCreationToken) { s_logger.Debug("SessionPool::NewSession" + PoolIdentification()); try { - var session = s_sessionFactory.NewSession(connectionString, password); + var session = s_sessionFactory.NewSession(connectionString, password, passcode); session.Open(); s_logger.Debug("SessionPool::NewSession - opened" + PoolIdentification()); if (GetPooling() && !_underDestruction) @@ -374,13 +413,14 @@ private SFSession NewSession(String connectionString, SecureString password, Ses private Task NewNonPoolingSessionAsync( String connectionString, SecureString password, + SecureString passcode, CancellationToken cancellationToken) => - NewSessionAsync(connectionString, password, _noPoolingSessionCreationTokenCounter.NewToken(), cancellationToken); + NewSessionAsync(connectionString, password, passcode, _noPoolingSessionCreationTokenCounter.NewToken(), cancellationToken); - private Task NewSessionAsync(String connectionString, SecureString password, SessionCreationToken sessionCreationToken, CancellationToken cancellationToken) + private Task NewSessionAsync(String connectionString, SecureString password, SecureString passcode, SessionCreationToken sessionCreationToken, CancellationToken cancellationToken) { s_logger.Debug("SessionPool::NewSessionAsync" + PoolIdentification()); - var session = s_sessionFactory.NewSession(connectionString, password); + var session = s_sessionFactory.NewSession(connectionString, password, passcode); return session .OpenAsync(cancellationToken) .ContinueWith(previousTask => @@ -457,7 +497,7 @@ internal bool AddSession(SFSession session, bool ensureMinPoolSize) ReleaseBusySession(session); if (ensureMinPoolSize) { - ScheduleNewIdleSessions(ConnectionString, Password, RegisterSessionCreationsWhenReturningSessionToPool()); + ScheduleNewIdleSessions(ConnectionString, Password, RegisterSessionCreationsToEnsureMinPoolSize()); } return false; } @@ -478,7 +518,7 @@ private Tuple> ReturnSessionToPool(SFSession se { _busySessionsCounter.Decrease(); var sessionCreationTokens = ensureMinPoolSize - ? RegisterSessionCreationsWhenReturningSessionToPool() + ? RegisterSessionCreationsToEnsureMinPoolSize() : SessionOrCreationTokens.s_emptySessionCreationTokenList; var poolState = GetCurrentState(); s_logger.Debug($"Could not return session to pool {poolState}" + PoolIdentification()); @@ -493,7 +533,7 @@ private Tuple> ReturnSessionToPool(SFSession se if (session.IsExpired(_poolConfig.ExpirationTimeout, DateTimeOffset.UtcNow.ToUnixTimeMilliseconds())) // checking again because we could have spent some time waiting for a lock { var sessionCreationTokens = ensureMinPoolSize - ? RegisterSessionCreationsWhenReturningSessionToPool() + ? RegisterSessionCreationsToEnsureMinPoolSize() : SessionOrCreationTokens.s_emptySessionCreationTokenList; var poolState = GetCurrentState(); s_logger.Debug($"Could not return session to pool {poolState}" + PoolIdentification()); @@ -508,7 +548,7 @@ private Tuple> ReturnSessionToPool(SFSession se _idleSessions.Add(session); _waitingForIdleSessionQueue.OnResourceIncrease(); var sessionCreationTokensAfterReturningToPool = ensureMinPoolSize - ? RegisterSessionCreationsWhenReturningSessionToPool() + ? RegisterSessionCreationsToEnsureMinPoolSize() : SessionOrCreationTokens.s_emptySessionCreationTokenList; var poolStateAfterReturningToPool = GetCurrentState(); s_logger.Debug($"returned session with sid {session.sessionId} to pool {poolStateAfterReturningToPool}" + PoolIdentification()); diff --git a/Snowflake.Data/Core/Tools/FileOperations.cs b/Snowflake.Data/Core/Tools/FileOperations.cs index 9efe481bd..656c51257 100644 --- a/Snowflake.Data/Core/Tools/FileOperations.cs +++ b/Snowflake.Data/Core/Tools/FileOperations.cs @@ -14,5 +14,7 @@ public virtual bool Exists(string path) { return File.Exists(path); } + + public virtual void Write(string path, string content) => File.WriteAllText(path, content); } } diff --git a/Snowflake.Data/Core/Tools/StringUtils.cs b/Snowflake.Data/Core/Tools/StringUtils.cs new file mode 100644 index 000000000..3e5c45767 --- /dev/null +++ b/Snowflake.Data/Core/Tools/StringUtils.cs @@ -0,0 +1,24 @@ +/* + * Copyright (c) 2024 Snowflake Computing Inc. All rights reserved. + */ + +using System; +using System.Security.Cryptography; + +namespace Snowflake.Data.Core.Tools +{ + public static class StringUtils + { + internal static string ToSha256Hash(this string text) + { + if (string.IsNullOrEmpty(text)) + return string.Empty; + + using (var sha256Encoder = SHA256.Create()) + { + var sha256Hash = sha256Encoder.ComputeHash(System.Text.Encoding.UTF8.GetBytes(text)); + return BitConverter.ToString(sha256Hash).Replace("-", string.Empty); + } + } + } +} diff --git a/Snowflake.Data/Core/Tools/UnixOperations.cs b/Snowflake.Data/Core/Tools/UnixOperations.cs index cb44099b7..1b369da86 100644 --- a/Snowflake.Data/Core/Tools/UnixOperations.cs +++ b/Snowflake.Data/Core/Tools/UnixOperations.cs @@ -2,20 +2,36 @@ * Copyright (c) 2024 Snowflake Computing Inc. All rights reserved. */ + +using System.IO; +using System.Security; +using System.Text; using Mono.Unix; using Mono.Unix.Native; namespace Snowflake.Data.Core.Tools { + internal class UnixOperations { public static readonly UnixOperations Instance = new UnixOperations(); + public virtual int CreateFileWithPermissions(string path, FilePermissions permissions) + { + return Syscall.creat(path, permissions); + } + public virtual int CreateDirectoryWithPermissions(string path, FilePermissions permissions) { return Syscall.mkdir(path, permissions); } + public virtual FileAccessPermissions GetFilePermissions(string path) + { + var fileInfo = new UnixFileInfo(path); + return fileInfo.FileAccessPermissions; + } + public virtual FileAccessPermissions GetDirPermissions(string path) { var dirInfo = new UnixDirectoryInfo(path); @@ -27,5 +43,24 @@ public virtual bool CheckFileHasAnyOfPermissions(string path, FileAccessPermissi var fileInfo = new UnixFileInfo(path); return (permissions & fileInfo.FileAccessPermissions) != 0; } + + public string ReadAllText(string path, FileAccessPermissions forbiddenPermissions = FileAccessPermissions.OtherReadWriteExecute) + { + var fileInfo = new UnixFileInfo(path: path); + + using (var handle = fileInfo.OpenRead()) + { + if (handle.OwnerUser.UserId != Syscall.geteuid()) + throw new SecurityException("Attempting to read a file not owned by the effective user of the current process"); + if (handle.OwnerGroup.GroupId != Syscall.getegid()) + throw new SecurityException("Attempting to read a file not owned by the effective group of the current process"); + if ((handle.FileAccessPermissions & forbiddenPermissions) != 0) + throw new SecurityException("Attempting to read a file with too broad permissions assigned"); + using (var streamReader = new StreamReader(handle, Encoding.Default)) + { + return streamReader.ReadToEnd(); + } + } + } } } diff --git a/Snowflake.Data/Logger/SecretDetector.cs b/Snowflake.Data/Logger/SecretDetector.cs index 59cd810d6..09c5981cf 100644 --- a/Snowflake.Data/Logger/SecretDetector.cs +++ b/Snowflake.Data/Logger/SecretDetector.cs @@ -92,7 +92,7 @@ private static string MaskCustomPatterns(string text) private const string ConnectionTokenPattern = @"(token|assertion content)(['""\s:=]+)([a-z0-9=/_\-+:]{8,})"; private const string TokenPropertyPattern = @"(token)(\s*=)(.*)"; private const string PasswordPattern = @"(password|passcode|pwd|proxypassword|private_key_pwd)(['""\s:=]+)([a-z0-9!""#$%&'\()*+,-./:;<=>?@\[\]\^_`{|}~]{6,})"; - private const string PasswordPropertyPattern = @"(password|proxypassword|private_key_pwd)(\s*=)(.*)"; + private const string PasswordPropertyPattern = @"(password|passcode|proxypassword|private_key_pwd)(\s*=)(.*)"; private static readonly Func[] s_maskFunctions = { MaskAWSServerSide, diff --git a/Snowflake.Data/Snowflake.Data.csproj b/Snowflake.Data/Snowflake.Data.csproj index a0b09fade..286d75771 100644 --- a/Snowflake.Data/Snowflake.Data.csproj +++ b/Snowflake.Data/Snowflake.Data.csproj @@ -1,4 +1,4 @@ - + netstandard2.0 Snowflake.Data diff --git a/doc/Connecting.md b/doc/Connecting.md index 576120f79..b794fdbaf 100644 --- a/doc/Connecting.md +++ b/doc/Connecting.md @@ -50,6 +50,8 @@ The following table lists all valid connection properties: | EXPIRATIONTIMEOUT | No | Timeout for using each connection. Connections which last more than specified timeout are considered to be expired and are being removed from the pool. The default is 1 hour. Usage of units possible and allowed are: e. g. `360000ms` (milliseconds), `3600s` (seconds), `60m` (minutes) where seconds are default for a skipped postfix. Special values: `0` - immediate expiration of the connection just after its creation. Expiration timeout cannot be set to infinity. | | POOLINGENABLED | No | Boolean flag indicating if the connection should be a part of a pool. The default value is `true`. | | DISABLE_SAML_URL_CHECK | No | Specifies whether to check if the saml postback url matches the host url from the connection string. The default value is `false`. | +| PASSCODE | No | Passcode from your Duo application to be used in Multi Factor Authentication. | +| PASSCODEINPASSWORD | No | Boolean flag indicating if MFA passcode is added to the password. |