diff --git a/Snowflake.Data.Tests/IntegrationTests/SFConnectionIT.cs b/Snowflake.Data.Tests/IntegrationTests/SFConnectionIT.cs index efbe85568..76490e099 100644 --- a/Snowflake.Data.Tests/IntegrationTests/SFConnectionIT.cs +++ b/Snowflake.Data.Tests/IntegrationTests/SFConnectionIT.cs @@ -2276,36 +2276,26 @@ public void TestUseMultiplePoolsConnectionPoolByDefault() } [Test] - //[Ignore("This test requires manual interaction and therefore cannot be run in CI")] - public void TestMFATokenCaching() + [Ignore("This test requires manual interaction and therefore cannot be run in CI")] + public void TestMFATokenCachingWithPasscodeFromConnectionString() { using (SnowflakeDbConnection conn = new SnowflakeDbConnection()) { - //conn.Passcode = SecureStringHelper.Encode("123456"); conn.ConnectionString = ConnectionString - + ";authenticator=username_password_mfa;minPoolSize=2;application=DuoTest;authenticator=username_password_mfa;"; + + ";authenticator=username_password_mfa;application=DuoTest;Passcode=123456;"; // Authenticate to retrieve and store the token if doesn't exist or invalid Task connectTask = conn.OpenAsync(CancellationToken.None); connectTask.Wait(); Assert.AreEqual(ConnectionState.Open, conn.State); - - // Authenticate using the MFA token cache - connectTask = conn.OpenAsync(CancellationToken.None); - connectTask.Wait(); - Assert.AreEqual(ConnectionState.Open, conn.State); - - connectTask = conn.CloseAsync(CancellationToken.None); - connectTask.Wait(); - Assert.AreEqual(ConnectionState.Closed, conn.State); } } [Test] - //[Ignore("Requires manual steps and environment with mfa authentication enrolled")] // to enroll to mfa authentication edit your user profile - public void TestMfaWithPasswordConnection() + [Ignore("Requires manual steps and environment with mfa authentication enrolled")] // to enroll to mfa authentication edit your user profile + public void TestMfaWithPasswordConnectionUsingPasscodeWithSecureString() { // arrange using (SnowflakeDbConnection conn = new SnowflakeDbConnection()) @@ -2315,15 +2305,11 @@ public void TestMfaWithPasswordConnection() conn.ConnectionString = ConnectionString + "minPoolSize=2;application=DuoTest;"; // act - conn.Open(); - Thread.Sleep(3000); - conn.Close(); - - conn.Open(); + Task connectTask = conn.OpenAsync(CancellationToken.None); + connectTask.Wait(); // assert Assert.AreEqual(ConnectionState.Open, conn.State); - // manual action: verify that you have received no push request for given connection } } diff --git a/Snowflake.Data.Tests/UnitTests/ConnectionPoolManagerMFATest.cs b/Snowflake.Data.Tests/UnitTests/ConnectionPoolManagerMFATest.cs index 6e5578614..c7970fe6e 100644 --- a/Snowflake.Data.Tests/UnitTests/ConnectionPoolManagerMFATest.cs +++ b/Snowflake.Data.Tests/UnitTests/ConnectionPoolManagerMFATest.cs @@ -9,7 +9,6 @@ namespace Snowflake.Data.Tests.UnitTests using System; using System.Linq; using System.Security; - using System.Threading; using Mock; using NUnit.Framework; using Snowflake.Data.Core; @@ -66,21 +65,20 @@ public void TestPoolManagerReturnsSessionPoolForGivenConnectionStringUsingMFA() authResponseSessionInfo = new SessionInfo() }); // Act - var session = _connectionPoolManager.GetSession(ConnectionStringMFACache, null); - Thread.Sleep(3000); + var session = _connectionPoolManager.GetSession(ConnectionStringMFACache, null, null); // Assert - + Awaiter.WaitUntilConditionOrTimeout(() => s_restRequester.LoginRequests.Count == 2, TimeSpan.FromSeconds(15)); Assert.AreEqual(2, s_restRequester.LoginRequests.Count); var loginRequest1 = s_restRequester.LoginRequests.Dequeue(); - Assert.AreEqual(loginRequest1.data.Token, string.Empty); - Assert.AreEqual(SecureStringHelper.Decode(session._mfaToken), testToken); + Assert.AreEqual(string.Empty, loginRequest1.data.Token); + Assert.AreEqual(testToken, SecureStringHelper.Decode(session._mfaToken)); Assert.IsTrue(loginRequest1.data.SessionParameters.TryGetValue(SFSessionParameter.CLIENT_REQUEST_MFA_TOKEN, out var value) && (bool)value); Assert.AreEqual("passcode", loginRequest1.data.extAuthnDuoMethod); var loginRequest2 = s_restRequester.LoginRequests.Dequeue(); - Assert.AreEqual(loginRequest2.data.Token, testToken); + Assert.AreEqual(testToken, loginRequest2.data.Token); Assert.IsTrue(loginRequest2.data.SessionParameters.TryGetValue(SFSessionParameter.CLIENT_REQUEST_MFA_TOKEN, out var value1) && (bool)value1); - Assert.AreEqual("passcode", loginRequest1.data.extAuthnDuoMethod); + Assert.AreEqual("passcode", loginRequest2.data.extAuthnDuoMethod); } [Test] @@ -103,9 +101,9 @@ public void TestPoolManagerShouldOnlyUsePasscodeAsArgumentForFirstSessionWhenNot // Act _connectionPoolManager.GetSession(ConnectionStringMFABasicWithoutPasscode, null, SecureStringHelper.Encode(TestPasscode)); - Thread.Sleep(10000); // Assert + Awaiter.WaitUntilConditionOrTimeout(() => s_restRequester.LoginRequests.Count == 3, TimeSpan.FromSeconds(15)); Assert.AreEqual(3, s_restRequester.LoginRequests.Count); var request = s_restRequester.LoginRequests.ToList(); Assert.AreEqual(1, request.Count(r => r.data.extAuthnDuoMethod == "passcode" && r.data.passcode == TestPasscode)); @@ -118,7 +116,7 @@ public void TestPoolManagerShouldThrowExceptionIfForcePoolingWithPasscodeNotUsin // Arrange var connectionString = "db=D1;warehouse=W1;account=A1;user=U1;password=P1;role=R1;minPoolSize=2;passcode=12345;POOLINGENABLED=true"; // Act and assert - var thrown = Assert.Throws(() =>_connectionPoolManager.GetSession(connectionString, null)); + var thrown = Assert.Throws(() =>_connectionPoolManager.GetSession(connectionString, null,null)); Assert.That(thrown.Message, Does.Contain("Passcode with MinPoolSize feature of connection pool allowed only for username_password_mfa authentication")); } @@ -128,7 +126,7 @@ public void TestPoolManagerShouldNotThrowExceptionIfForcePoolingWithPasscodeNotU // Arrange var connectionString = "db=D1;warehouse=W1;account=A1;user=U1;password=P1;role=R1;minPoolSize=2;passcode=12345;POOLINGENABLED=false"; // Act and assert - Assert.DoesNotThrow(() =>_connectionPoolManager.GetSession(connectionString, null)); + Assert.DoesNotThrow(() =>_connectionPoolManager.GetSession(connectionString, null, null)); } [Test] @@ -137,7 +135,7 @@ public void TestPoolManagerShouldNotThrowExceptionIfMinPoolSizeZeroNotUsingMFATo // Arrange var connectionString = "db=D1;warehouse=W1;account=A1;user=U1;password=P1;role=R1;minPoolSize=0;passcode=12345;POOLINGENABLED=true"; // Act and assert - Assert.DoesNotThrow(() =>_connectionPoolManager.GetSession(connectionString, null)); + Assert.DoesNotThrow(() =>_connectionPoolManager.GetSession(connectionString, null, null)); } } diff --git a/Snowflake.Data.Tests/UnitTests/ConnectionPoolManagerTest.cs b/Snowflake.Data.Tests/UnitTests/ConnectionPoolManagerTest.cs index 3b280e6da..0293d6571 100644 --- a/Snowflake.Data.Tests/UnitTests/ConnectionPoolManagerTest.cs +++ b/Snowflake.Data.Tests/UnitTests/ConnectionPoolManagerTest.cs @@ -111,7 +111,7 @@ public void TestDifferentPoolsAreReturnedForDifferentConnectionStrings() public void TestGetSessionWorksForSpecifiedConnectionString() { // Act - var sfSession = _connectionPoolManager.GetSession(ConnectionString1, null); + var sfSession = _connectionPoolManager.GetSession(ConnectionString1, null, null); // Assert Assert.AreEqual(ConnectionString1, sfSession.ConnectionString); @@ -133,7 +133,7 @@ public async Task TestGetSessionAsyncWorksForSpecifiedConnectionString() public void TestCountingOfSessionProvidedByPool() { // Act - _connectionPoolManager.GetSession(ConnectionString1, null); + _connectionPoolManager.GetSession(ConnectionString1, null, null); // Assert var sessionPool = _connectionPoolManager.GetPool(ConnectionString1, null); @@ -144,7 +144,7 @@ public void TestCountingOfSessionProvidedByPool() public void TestCountingOfSessionReturnedBackToPool() { // Arrange - var sfSession = _connectionPoolManager.GetSession(ConnectionString1, null); + var sfSession = _connectionPoolManager.GetSession(ConnectionString1, null, null); // Act _connectionPoolManager.AddSession(sfSession); diff --git a/Snowflake.Data.Tests/UnitTests/CredentialManager/SFCredentialManagerTest.cs b/Snowflake.Data.Tests/UnitTests/CredentialManager/SFCredentialManagerTest.cs index 8dbeec6c0..8526ce978 100644 --- a/Snowflake.Data.Tests/UnitTests/CredentialManager/SFCredentialManagerTest.cs +++ b/Snowflake.Data.Tests/UnitTests/CredentialManager/SFCredentialManagerTest.cs @@ -2,6 +2,8 @@ * Copyright (c) 2024 Snowflake Computing Inc. All rights reserved. */ +using Snowflake.Data.Core; + namespace Snowflake.Data.Tests.UnitTests.CredentialManager { using Mono.Unix; @@ -9,7 +11,6 @@ namespace Snowflake.Data.Tests.UnitTests.CredentialManager using Moq; using NUnit.Framework; using Snowflake.Data.Client; - using Snowflake.Data.Core.CredentialManager; using Snowflake.Data.Core.CredentialManager.Infrastructure; using Snowflake.Data.Core.Tools; using System; @@ -48,31 +49,24 @@ public void TestSavingCredentialsForAnExistingKey() var firstExpectedToken = "mockToken1"; var secondExpectedToken = "mockToken2"; - try - { - // act - _credentialManager.SaveCredentials(key, firstExpectedToken); + // act + _credentialManager.SaveCredentials(key, firstExpectedToken); - // assert - Assert.AreEqual(firstExpectedToken, _credentialManager.GetCredentials(key)); + // assert + Assert.AreEqual(firstExpectedToken, _credentialManager.GetCredentials(key)); - // act - _credentialManager.SaveCredentials(key, secondExpectedToken); + // act + _credentialManager.SaveCredentials(key, secondExpectedToken); - // assert - Assert.AreEqual(secondExpectedToken, _credentialManager.GetCredentials(key)); + // assert + Assert.AreEqual(secondExpectedToken, _credentialManager.GetCredentials(key)); - // act - _credentialManager.RemoveCredentials(key); + // act + _credentialManager.RemoveCredentials(key); + + // assert + Assert.IsTrue(string.IsNullOrEmpty(_credentialManager.GetCredentials(key))); - // assert - Assert.IsTrue(string.IsNullOrEmpty(_credentialManager.GetCredentials(key))); - } - catch (Exception ex) - { - // assert - Assert.Fail("Should not throw an exception: " + ex.Message); - } } [Test] @@ -81,19 +75,11 @@ public void TestRemovingCredentialsForKeyThatDoesNotExist() // arrange var key = "mockKey"; - try - { - // act - _credentialManager.RemoveCredentials(key); + // act + _credentialManager.RemoveCredentials(key); - // assert - Assert.IsTrue(string.IsNullOrEmpty(_credentialManager.GetCredentials(key))); - } - catch (Exception ex) - { - // assert - Assert.Fail("Should not throw an exception: " + ex.Message); - } + // assert + Assert.IsTrue(string.IsNullOrEmpty(_credentialManager.GetCredentials(key))); } } @@ -124,11 +110,11 @@ public class SFFileCredentialManagerTest : SFBaseCredentialManagerTest [SetUp] public void SetUp() { - _credentialManager = SFCredentialManagerFileImpl.Instance; + _credentialManager = SnowflakeCredentialManagerFileImpl.Instance; } } - [TestFixture] + [TestFixture, NonParallelizable] class SFCredentialManagerTest { ISnowflakeCredentialManager _credentialManager; @@ -147,7 +133,7 @@ class SFCredentialManagerTest private const string CustomJsonDir = "testdirectory"; - private static readonly string s_customJsonPath = Path.Combine(CustomJsonDir, SFCredentialManagerFileImpl.CredentialCacheFileName); + private static readonly string s_customJsonPath = Path.Combine(CustomJsonDir, SnowflakeCredentialManagerFileImpl.CredentialCacheFileName); [SetUp] public void SetUp() { @@ -155,22 +141,22 @@ [SetUp] public void SetUp() t_directoryOperations = new Mock(); t_unixOperations = new Mock(); t_environmentOperations = new Mock(); - SFCredentialManagerFactory.SetCredentialManager(SFCredentialManagerInMemoryImpl.Instance); + SnowflakeCredentialManagerFactory.SetCredentialManager(SFCredentialManagerInMemoryImpl.Instance); } [TearDown] public void TearDown() { - SFCredentialManagerFactory.UseDefaultCredentialManager(); + SnowflakeCredentialManagerFactory.UseDefaultCredentialManager(); } [Test] public void TestUsingDefaultCredentialManager() { // arrange - SFCredentialManagerFactory.UseDefaultCredentialManager(); + SnowflakeCredentialManagerFactory.UseDefaultCredentialManager(); // act - _credentialManager = SFCredentialManagerFactory.GetCredentialManager(); + _credentialManager = SnowflakeCredentialManagerFactory.GetCredentialManager(); // assert if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows)) @@ -187,13 +173,13 @@ public void TestUsingDefaultCredentialManager() public void TestSettingCustomCredentialManager() { // arrange - SFCredentialManagerFactory.SetCredentialManager(SFCredentialManagerFileImpl.Instance); + SnowflakeCredentialManagerFactory.SetCredentialManager(SnowflakeCredentialManagerFileImpl.Instance); // act - _credentialManager = SFCredentialManagerFactory.GetCredentialManager(); + _credentialManager = SnowflakeCredentialManagerFactory.GetCredentialManager(); // assert - Assert.IsInstanceOf(_credentialManager); + Assert.IsInstanceOf(_credentialManager); } [Test] @@ -213,10 +199,10 @@ public void TestThatThrowsErrorWhenCacheFileIsNotCreated() FilePermissions.S_IRUSR | FilePermissions.S_IWUSR | FilePermissions.S_IXUSR)) .Returns(-1); t_environmentOperations - .Setup(e => e.GetEnvironmentVariable(SFCredentialManagerFileImpl.CredentialCacheDirectoryEnvironmentName)) + .Setup(e => e.GetEnvironmentVariable(SnowflakeCredentialManagerFileImpl.CredentialCacheDirectoryEnvironmentName)) .Returns(CustomJsonDir); - SFCredentialManagerFactory.SetCredentialManager(new SFCredentialManagerFileImpl(t_fileOperations.Object, t_directoryOperations.Object, t_unixOperations.Object, t_environmentOperations.Object)); - _credentialManager = SFCredentialManagerFactory.GetCredentialManager(); + SnowflakeCredentialManagerFactory.SetCredentialManager(new SnowflakeCredentialManagerFileImpl(t_fileOperations.Object, t_directoryOperations.Object, t_unixOperations.Object, t_environmentOperations.Object)); + _credentialManager = SnowflakeCredentialManagerFactory.GetCredentialManager(); // act var thrown = Assert.Throws(() => _credentialManager.SaveCredentials("key", "token")); @@ -242,10 +228,10 @@ public void TestThatThrowsErrorWhenCacheFileCanBeAccessedByOthers() .Setup(u => u.GetFilePermissions(s_customJsonPath)) .Returns(FileAccessPermissions.AllPermissions); t_environmentOperations - .Setup(e => e.GetEnvironmentVariable(SFCredentialManagerFileImpl.CredentialCacheDirectoryEnvironmentName)) + .Setup(e => e.GetEnvironmentVariable(SnowflakeCredentialManagerFileImpl.CredentialCacheDirectoryEnvironmentName)) .Returns(CustomJsonDir); - SFCredentialManagerFactory.SetCredentialManager(new SFCredentialManagerFileImpl(t_fileOperations.Object, t_directoryOperations.Object, t_unixOperations.Object, t_environmentOperations.Object)); - _credentialManager = SFCredentialManagerFactory.GetCredentialManager(); + SnowflakeCredentialManagerFactory.SetCredentialManager(new SnowflakeCredentialManagerFileImpl(t_fileOperations.Object, t_directoryOperations.Object, t_unixOperations.Object, t_environmentOperations.Object)); + _credentialManager = SnowflakeCredentialManagerFactory.GetCredentialManager(); // act var thrown = Assert.Throws(() => _credentialManager.SaveCredentials("key", "token")); @@ -271,15 +257,15 @@ public void TestThatJsonFileIsCheckedIfAlreadyExists() .Setup(u => u.GetFilePermissions(s_customJsonPath)) .Returns(FileAccessPermissions.UserReadWriteExecute); t_environmentOperations - .Setup(e => e.GetEnvironmentVariable(SFCredentialManagerFileImpl.CredentialCacheDirectoryEnvironmentName)) + .Setup(e => e.GetEnvironmentVariable(SnowflakeCredentialManagerFileImpl.CredentialCacheDirectoryEnvironmentName)) .Returns(CustomJsonDir); t_fileOperations .SetupSequence(f => f.Exists(s_customJsonPath)) .Returns(false) .Returns(true); - SFCredentialManagerFactory.SetCredentialManager(new SFCredentialManagerFileImpl(t_fileOperations.Object, t_directoryOperations.Object, t_unixOperations.Object, t_environmentOperations.Object)); - _credentialManager = SFCredentialManagerFactory.GetCredentialManager(); + SnowflakeCredentialManagerFactory.SetCredentialManager(new SnowflakeCredentialManagerFileImpl(t_fileOperations.Object, t_directoryOperations.Object, t_unixOperations.Object, t_environmentOperations.Object)); + _credentialManager = SnowflakeCredentialManagerFactory.GetCredentialManager(); // act _credentialManager.SaveCredentials("key", "token"); diff --git a/Snowflake.Data.Tests/UnitTests/SFSessionTest.cs b/Snowflake.Data.Tests/UnitTests/SFSessionTest.cs index 65a6cacfe..969e5cadf 100644 --- a/Snowflake.Data.Tests/UnitTests/SFSessionTest.cs +++ b/Snowflake.Data.Tests/UnitTests/SFSessionTest.cs @@ -7,7 +7,6 @@ using NUnit.Framework; using Snowflake.Data.Core.Tools; using Snowflake.Data.Tests.Mock; -using System; namespace Snowflake.Data.Tests.UnitTests { diff --git a/Snowflake.Data/Client/SFCredentialManagerFactory.cs b/Snowflake.Data/Client/SFCredentialManagerFactory.cs deleted file mode 100644 index fd98bb5a8..000000000 --- a/Snowflake.Data/Client/SFCredentialManagerFactory.cs +++ /dev/null @@ -1,50 +0,0 @@ -/* - * Copyright (c) 2024 Snowflake Computing Inc. All rights reserved. - */ - -namespace Snowflake.Data.Client -{ - using System; - using Snowflake.Data.Core; - using Snowflake.Data.Core.CredentialManager; - using Snowflake.Data.Core.CredentialManager.Infrastructure; - using Snowflake.Data.Log; - using System.Runtime.InteropServices; - - public class SFCredentialManagerFactory - { - private static readonly SFLogger s_logger = SFLoggerFactory.GetLogger(); - - private static ISnowflakeCredentialManager s_customCredentialManager = null; - - internal static string BuildCredentialKey(string host, string user, TokenType tokenType, string authenticator = null) - { - return $"{host.ToUpper()}:{user.ToUpper()}:{SFEnvironment.DriverName}:{tokenType.ToString().ToUpper()}:{authenticator?.ToUpper() ?? string.Empty}"; - } - - public static void UseDefaultCredentialManager() - { - s_logger.Info("Clearing the custom credential manager"); - s_customCredentialManager = null; - } - - public static void SetCredentialManager(ISnowflakeCredentialManager customCredentialManager) - { - s_logger.Info($"Setting the custom credential manager: {customCredentialManager.GetType().Name}"); - s_customCredentialManager = customCredentialManager; - } - - public static ISnowflakeCredentialManager GetCredentialManager() - { - if (s_customCredentialManager == null) - { - var defaultCredentialManager = RuntimeInformation.IsOSPlatform(OSPlatform.Windows) ? (ISnowflakeCredentialManager) - SFCredentialManagerWindowsNativeImpl.Instance : SFCredentialManagerInMemoryImpl.Instance; - s_logger.Info($"Using the default credential manager: {defaultCredentialManager.GetType().Name}"); - return defaultCredentialManager; - } - s_logger.Info($"Using a custom credential manager: {s_customCredentialManager.GetType().Name}"); - return s_customCredentialManager; - } - } -} diff --git a/Snowflake.Data/Client/SnowflakeCredentialManagerFactory.cs b/Snowflake.Data/Client/SnowflakeCredentialManagerFactory.cs new file mode 100644 index 000000000..072ab3e05 --- /dev/null +++ b/Snowflake.Data/Client/SnowflakeCredentialManagerFactory.cs @@ -0,0 +1,64 @@ +/* + * Copyright (c) 2024 Snowflake Computing Inc. All rights reserved. + */ + +using System.Runtime.InteropServices; +using Snowflake.Data.Core; +using Snowflake.Data.Core.CredentialManager; +using Snowflake.Data.Core.CredentialManager.Infrastructure; +using Snowflake.Data.Log; + +namespace Snowflake.Data.Client +{ + public class SnowflakeCredentialManagerFactory + { + private static readonly SFLogger s_logger = SFLoggerFactory.GetLogger(); + + private static readonly object credentialManagerLock = new object(); + + private static ISnowflakeCredentialManager s_customCredentialManager = null; + + internal static string BuildCredentialKey(string host, string user, TokenType tokenType, string authenticator = null) + { + return $"{host.ToUpper()}:{user.ToUpper()}:{SFEnvironment.DriverName}:{tokenType.ToString().ToUpper()}:{authenticator?.ToUpper() ?? string.Empty}"; + } + + public static void UseDefaultCredentialManager() + { + lock (credentialManagerLock) + { + s_logger.Info("Clearing the custom credential manager"); + s_customCredentialManager = null; + } + } + + public static void SetCredentialManager(ISnowflakeCredentialManager customCredentialManager) + { + lock (credentialManagerLock) + { + s_logger.Info($"Setting the custom credential manager: {customCredentialManager.GetType().Name}"); + s_customCredentialManager = customCredentialManager; + } + } + + public static ISnowflakeCredentialManager GetCredentialManager() + { + + if (s_customCredentialManager == null) + { + lock (credentialManagerLock) + { + if (s_customCredentialManager == null) + { + var defaultCredentialManager = RuntimeInformation.IsOSPlatform(OSPlatform.Windows) ? (ISnowflakeCredentialManager) + SFCredentialManagerWindowsNativeImpl.Instance : SFCredentialManagerInMemoryImpl.Instance; + s_logger.Info($"Using the default credential manager: {defaultCredentialManager.GetType().Name}"); + return defaultCredentialManager; + } + } + } + s_logger.Info($"Using a custom credential manager: {s_customCredentialManager.GetType().Name}"); + return s_customCredentialManager; + } + } +} diff --git a/Snowflake.Data/Core/CredentialManager/Infrastructure/SFCredentialManagerFileImpl.cs b/Snowflake.Data/Client/SnowflakeCredentialManagerFileImpl.cs similarity index 86% rename from Snowflake.Data/Core/CredentialManager/Infrastructure/SFCredentialManagerFileImpl.cs rename to Snowflake.Data/Client/SnowflakeCredentialManagerFileImpl.cs index 5fa072d4f..5a30a4559 100644 --- a/Snowflake.Data/Core/CredentialManager/Infrastructure/SFCredentialManagerFileImpl.cs +++ b/Snowflake.Data/Client/SnowflakeCredentialManagerFileImpl.cs @@ -11,11 +11,11 @@ using System; using System.IO; using System.Runtime.InteropServices; -using KeyToken = System.Collections.Generic.Dictionary; +using KeyTokenDict = System.Collections.Generic.Dictionary; -namespace Snowflake.Data.Core.CredentialManager.Infrastructure +namespace Snowflake.Data.Core { - internal class SFCredentialManagerFileImpl : ISnowflakeCredentialManager + public class SnowflakeCredentialManagerFileImpl : ISnowflakeCredentialManager { internal const string CredentialCacheDirectoryEnvironmentName = "SF_TEMPORARY_CREDENTIAL_CACHE_DIR"; @@ -23,7 +23,7 @@ internal class SFCredentialManagerFileImpl : ISnowflakeCredentialManager internal const string CredentialCacheFileName = "temporary_credential.json"; - private static readonly SFLogger s_logger = SFLoggerFactory.GetLogger(); + private static readonly SFLogger s_logger = SFLoggerFactory.GetLogger(); private readonly string _jsonCacheDirectory; @@ -37,9 +37,9 @@ internal class SFCredentialManagerFileImpl : ISnowflakeCredentialManager private readonly EnvironmentOperations _environmentOperations; - public static readonly SFCredentialManagerFileImpl Instance = new SFCredentialManagerFileImpl(FileOperations.Instance, DirectoryOperations.Instance, UnixOperations.Instance, EnvironmentOperations.Instance); + public static readonly SnowflakeCredentialManagerFileImpl Instance = new SnowflakeCredentialManagerFileImpl(FileOperations.Instance, DirectoryOperations.Instance, UnixOperations.Instance, EnvironmentOperations.Instance); - internal SFCredentialManagerFileImpl(FileOperations fileOperations, DirectoryOperations directoryOperations, UnixOperations unixOperations, EnvironmentOperations environmentOperations) + internal SnowflakeCredentialManagerFileImpl(FileOperations fileOperations, DirectoryOperations directoryOperations, UnixOperations unixOperations, EnvironmentOperations environmentOperations) { _fileOperations = fileOperations; _directoryOperations = directoryOperations; @@ -101,10 +101,10 @@ internal void WriteToJsonFile(string content) } } - internal KeyToken ReadJsonFile() + internal KeyTokenDict ReadJsonFile() { var contentFile = RuntimeInformation.IsOSPlatform(OSPlatform.Windows) ? File.ReadAllText(_jsonCacheFilePath) : _unixOperations.ReadAllText(_jsonCacheFilePath); - return JsonConvert.DeserializeObject(contentFile); + return JsonConvert.DeserializeObject(contentFile); } public string GetCredentials(string key) @@ -140,7 +140,7 @@ public void SaveCredentials(string key, string token) { s_logger.Debug($"Saving credentials to json file in {_jsonCacheFilePath} for key: {key}"); var hashKey = key.ToSha256Hash(); - KeyToken keyTokenPairs = _fileOperations.Exists(_jsonCacheFilePath) ? ReadJsonFile() : new KeyToken(); + KeyTokenDict keyTokenPairs = _fileOperations.Exists(_jsonCacheFilePath) ? ReadJsonFile() : new KeyTokenDict(); keyTokenPairs[hashKey] = token; string jsonString = JsonConvert.SerializeObject(keyTokenPairs); diff --git a/Snowflake.Data/Core/Authenticator/MFACacheAuthenticator.cs b/Snowflake.Data/Core/Authenticator/MFACacheAuthenticator.cs index cae9eb55d..d4b679632 100644 --- a/Snowflake.Data/Core/Authenticator/MFACacheAuthenticator.cs +++ b/Snowflake.Data/Core/Authenticator/MFACacheAuthenticator.cs @@ -1,19 +1,16 @@ /* - * Copyright (c) 2012-2019 Snowflake Computing Inc. All rights reserved. + * Copyright (c) 2024 Snowflake Computing Inc. All rights reserved. */ -using Snowflake.Data.Log; using System.Threading; using System.Threading.Tasks; +using Snowflake.Data.Core.Tools; namespace Snowflake.Data.Core.Authenticator { - using Tools; - class MFACacheAuthenticator : BaseAuthenticator, IAuthenticator { public const string AUTH_NAME = "username_password_mfa"; - private static readonly SFLogger logger = SFLoggerFactory.GetLogger(); internal MFACacheAuthenticator(SFSession session) : base(session, AUTH_NAME) { diff --git a/Snowflake.Data/Core/CredentialManager/Infrastructure/SFCredentialManagerWindowsNativeImpl.cs b/Snowflake.Data/Core/CredentialManager/Infrastructure/SFCredentialManagerWindowsNativeImpl.cs index aa52bda1f..264091ad9 100644 --- a/Snowflake.Data/Core/CredentialManager/Infrastructure/SFCredentialManagerWindowsNativeImpl.cs +++ b/Snowflake.Data/Core/CredentialManager/Infrastructure/SFCredentialManagerWindowsNativeImpl.cs @@ -30,9 +30,11 @@ public string GetCredentials(string key) return ""; } - var critCred = new CriticalCredentialHandle(nCredPtr); - Credential cred = critCred.GetCredential(); - return cred.CredentialBlob; + using (var critCred = new CriticalCredentialHandle(nCredPtr)) + { + var cred = critCred.GetCredential(); + return cred.CredentialBlob; + } } public void RemoveCredentials(string key) diff --git a/Snowflake.Data/Core/Session/ConnectionCacheManager.cs b/Snowflake.Data/Core/Session/ConnectionCacheManager.cs index 6f6ed2862..538221b09 100644 --- a/Snowflake.Data/Core/Session/ConnectionCacheManager.cs +++ b/Snowflake.Data/Core/Session/ConnectionCacheManager.cs @@ -11,7 +11,7 @@ namespace Snowflake.Data.Core.Session internal sealed class ConnectionCacheManager : IConnectionManager { private readonly SessionPool _sessionPool = SessionPool.CreateSessionCache(); - public SFSession GetSession(string connectionString, SecureString password, SecureString passcode = null) => _sessionPool.GetSession(connectionString, password, passcode); + public SFSession GetSession(string connectionString, SecureString password, SecureString passcode) => _sessionPool.GetSession(connectionString, password, passcode); public Task GetSessionAsync(string connectionString, SecureString password, SecureString passcode, CancellationToken cancellationToken) => _sessionPool.GetSessionAsync(connectionString, password, passcode, cancellationToken); public bool AddSession(SFSession session) => _sessionPool.AddSession(session, false); diff --git a/Snowflake.Data/Core/Session/ConnectionPoolManager.cs b/Snowflake.Data/Core/Session/ConnectionPoolManager.cs index 8c147b97d..6a0013bb0 100644 --- a/Snowflake.Data/Core/Session/ConnectionPoolManager.cs +++ b/Snowflake.Data/Core/Session/ConnectionPoolManager.cs @@ -29,7 +29,7 @@ internal ConnectionPoolManager() } } - public SFSession GetSession(string connectionString, SecureString password, SecureString passcode = null) + public SFSession GetSession(string connectionString, SecureString password, SecureString passcode) { s_logger.Debug($"ConnectionPoolManager::GetSession"); return GetPool(connectionString, password).GetSession(passcode); diff --git a/Snowflake.Data/Core/Session/SFSession.cs b/Snowflake.Data/Core/Session/SFSession.cs index c80257127..eaacc19ee 100644 --- a/Snowflake.Data/Core/Session/SFSession.cs +++ b/Snowflake.Data/Core/Session/SFSession.cs @@ -101,8 +101,6 @@ public void SetPooling(bool isEnabled) internal String _queryTag; - private readonly ISnowflakeCredentialManager _credManager = SFCredentialManagerFactory.GetCredentialManager(); - internal SecureString _mfaToken; internal void ProcessLoginResponse(LoginResponse authnResponse) @@ -126,8 +124,8 @@ internal void ProcessLoginResponse(LoginResponse authnResponse) if (!string.IsNullOrEmpty(authnResponse.data.mfaToken)) { _mfaToken = SecureStringHelper.Encode(authnResponse.data.mfaToken); - var key = SFCredentialManagerFactory.BuildCredentialKey(properties[SFSessionProperty.HOST], properties[SFSessionProperty.USER], TokenType.MFAToken, properties[SFSessionProperty.AUTHENTICATOR]); - _credManager.SaveCredentials(key, authnResponse.data.mfaToken); + var key = SnowflakeCredentialManagerFactory.BuildCredentialKey(properties[SFSessionProperty.HOST], properties[SFSessionProperty.USER], TokenType.MFAToken, properties[SFSessionProperty.AUTHENTICATOR]); + SnowflakeCredentialManagerFactory.GetCredentialManager().SaveCredentials(key, authnResponse.data.mfaToken); } logger.Debug($"Session opened: {sessionId}"); _startTime = DateTimeOffset.UtcNow.ToUnixTimeMilliseconds(); @@ -145,8 +143,8 @@ internal void ProcessLoginResponse(LoginResponse authnResponse) { logger.Info("MFA Token has expired or not valid.", e); _mfaToken = null; - var mfaKey = SFCredentialManagerFactory.BuildCredentialKey(properties[SFSessionProperty.HOST], properties[SFSessionProperty.USER], TokenType.MFAToken, properties[SFSessionProperty.AUTHENTICATOR]); - _credManager.RemoveCredentials(mfaKey); + var mfaKey = SnowflakeCredentialManagerFactory.BuildCredentialKey(properties[SFSessionProperty.HOST], properties[SFSessionProperty.USER], TokenType.MFAToken, properties[SFSessionProperty.AUTHENTICATOR]); + SnowflakeCredentialManagerFactory.GetCredentialManager().RemoveCredentials(mfaKey); } throw e; @@ -217,8 +215,8 @@ internal SFSession( if (properties.TryGetValue(SFSessionProperty.AUTHENTICATOR, out var _authenticatorType) && _authenticatorType == "username_password_mfa") { - var mfaKey = SFCredentialManagerFactory.BuildCredentialKey(properties[SFSessionProperty.HOST], properties[SFSessionProperty.USER], TokenType.MFAToken, _authenticatorType); - _mfaToken = SecureStringHelper.Encode(_credManager.GetCredentials(mfaKey)); + var mfaKey = SnowflakeCredentialManagerFactory.BuildCredentialKey(properties[SFSessionProperty.HOST], properties[SFSessionProperty.USER], TokenType.MFAToken, _authenticatorType); + _mfaToken = SecureStringHelper.Encode(SnowflakeCredentialManagerFactory.GetCredentialManager().GetCredentials(mfaKey)); } } catch (SnowflakeDbException e) diff --git a/Snowflake.Data/Core/Session/SFSessionProperty.cs b/Snowflake.Data/Core/Session/SFSessionProperty.cs index c472a6e55..a75492d47 100644 --- a/Snowflake.Data/Core/Session/SFSessionProperty.cs +++ b/Snowflake.Data/Core/Session/SFSessionProperty.cs @@ -113,7 +113,7 @@ internal enum SFSessionProperty POOLINGENABLED, [SFSessionPropertyAttr(required = false, defaultValue = "false")] DISABLE_SAML_URL_CHECK, - [SFSessionPropertyAttr(required = false, defaultValue = "false", IsSecret = true)] + [SFSessionPropertyAttr(required = false, defaultValue = "false")] CLIENT_REQUEST_MFA_TOKEN, [SFSessionPropertyAttr(required = false, IsSecret = true)] PASSCODE, diff --git a/Snowflake.Data/Core/Session/SessionPool.cs b/Snowflake.Data/Core/Session/SessionPool.cs index b94376213..8164c4999 100644 --- a/Snowflake.Data/Core/Session/SessionPool.cs +++ b/Snowflake.Data/Core/Session/SessionPool.cs @@ -257,6 +257,7 @@ private SessionOrCreationTokens GetIdleSession(string connStr, int maxSessions) var sessionsCount = Math.Min(maxSessions, AllowedNumberOfNewSessionCreations(1)); if (sessionsCount > 0) { + s_logger.Debug($"SessionPool::GetIdleSession - register creation of {sessionsCount} sessions" + PoolIdentification()); // there is no need to wait for a session since we can create new ones return new SessionOrCreationTokens(RegisterSessionCreations(sessionsCount)); } @@ -497,7 +498,7 @@ internal bool AddSession(SFSession session, bool ensureMinPoolSize) ReleaseBusySession(session); if (ensureMinPoolSize) { - ScheduleNewIdleSessions(ConnectionString, Password, RegisterSessionCreationsToEnsureMinPoolSize()); // passcode is probably not fresh - it could be improved + ScheduleNewIdleSessions(ConnectionString, Password, RegisterSessionCreationsToEnsureMinPoolSize()); } return false; } diff --git a/Snowflake.Data/Core/Tools/StringUtils.cs b/Snowflake.Data/Core/Tools/StringUtils.cs index 70bebe872..329a3bf27 100644 --- a/Snowflake.Data/Core/Tools/StringUtils.cs +++ b/Snowflake.Data/Core/Tools/StringUtils.cs @@ -1,6 +1,6 @@ -// -// Copyright (c) 2019-2024 Snowflake Inc. All rights reserved. -// +/* + * Copyright (c) 2024 Snowflake Computing Inc. All rights reserved. + */ namespace Snowflake.Data.Core.Tools { diff --git a/Snowflake.Data/Core/Tools/UnixOperations.cs b/Snowflake.Data/Core/Tools/UnixOperations.cs index 2b61e225e..7757a5681 100644 --- a/Snowflake.Data/Core/Tools/UnixOperations.cs +++ b/Snowflake.Data/Core/Tools/UnixOperations.cs @@ -44,16 +44,6 @@ public virtual bool CheckFileHasAnyOfPermissions(string path, FileAccessPermissi return (permissions & fileInfo.FileAccessPermissions) != 0; } - - /// - /// Reads all text from a file at the specified path, ensuring the file is owned by the effective user and group of the current process, - /// and does not have broader permissions than specified. - /// - /// The path to the file. - /// Permissions that are not allowed for the file. Defaults to OtherReadWriteExecute. - /// The content of the file as a string. - /// Thrown if the file is not owned by the effective user or group, or if it has forbidden permissions. - public string ReadAllText(string path, FileAccessPermissions forbiddenPermissions = FileAccessPermissions.OtherReadWriteExecute) { var fileInfo = new UnixFileInfo(path: path);