Add an optional S3FileSystemFactory to registerS3FileSystem().

When a factory has been registered, fileSystemGenerator() calls it to build
the FileSystem instance for a bucket; otherwise it constructs the default
S3 file system as before. A null result is rejected with
VELOX_CHECK_NOT_NULL before the instance is cached. The factory is stored in
the same branch that stores the cache-key function, i.e. only while the
instance map is empty.

The new test asserts with ASSERT_NE so a failure is reported as a gtest
assertion, matching the ASSERT_EQ used by the neighboring cacheKey test.

NOTE(review): this patch text was recovered from a copy in which line breaks,
indentation and every "<...>" span (template arguments, #include targets) had
been lost. Those parts are reconstructed from context. The two added #include
targets (<utility>, <string_view>), the three pre-existing header includes and
the "const config::ConfigBase" argument are assumptions; the "index" blob ids
are carried over unchanged. Confirm against the upstream tree before applying.

diff --git a/velox/connectors/hive/storage_adapters/s3fs/RegisterS3FileSystem.cpp b/velox/connectors/hive/storage_adapters/s3fs/RegisterS3FileSystem.cpp
index 630bb411327..cc569992265 100644
--- a/velox/connectors/hive/storage_adapters/s3fs/RegisterS3FileSystem.cpp
+++ b/velox/connectors/hive/storage_adapters/s3fs/RegisterS3FileSystem.cpp
@@ -16,6 +16,8 @@
 
 #include "velox/connectors/hive/storage_adapters/s3fs/RegisterS3FileSystem.h" // @manual
 
+#include <utility>
+
 #ifdef VELOX_ENABLE_S3
 #include "velox/common/base/StatsReporter.h"
 #include "velox/connectors/hive/storage_adapters/s3fs/S3Config.h" // @manual
@@ -47,6 +49,7 @@ FileSystemMap& fileSystems() {
 }
 
 CacheKeyFn cacheKeyFunc;
+S3FileSystemFactory fileSystemFactory;
 
 std::shared_ptr<FileSystem> fileSystemGenerator(
     std::shared_ptr<const config::ConfigBase> properties,
@@ -86,7 +89,10 @@ std::shared_ptr<FileSystem> fileSystemGenerator(
             static_cast<std::optional<std::string>>(
                 properties->get<std::string>(S3Config::kS3LogLocation));
         initializeS3(logLevel, logLocation);
-        auto fs = std::make_shared<S3FileSystem>(bucketName, properties);
+        auto fs = fileSystemFactory
+            ? fileSystemFactory(std::move(bucketName), properties)
+            : std::make_shared<S3FileSystem>(bucketName, properties);
+        VELOX_CHECK_NOT_NULL(fs, "S3 file system factory returned nullptr");
         instanceMap.insert({cacheKey, fs});
         return fs;
       });
@@ -109,11 +115,14 @@ std::unique_ptr<dwio::common::FileSink> s3WriteFileSinkGenerator(
 }
 #endif
 
-void registerS3FileSystem(CacheKeyFn identityFunction) {
+void registerS3FileSystem(
+    CacheKeyFn identityFunction,
+    S3FileSystemFactory fileSystemFactoryInput) {
 #ifdef VELOX_ENABLE_S3
   fileSystems().withWLock([&](auto& instanceMap) {
     if (instanceMap.empty()) {
       cacheKeyFunc = identityFunction;
+      fileSystemFactory = fileSystemFactoryInput;
       registerFileSystem(isS3File, std::function(fileSystemGenerator));
       dwio::common::FileSink::registerFactory(
           std::function(s3WriteFileSinkGenerator));
diff --git a/velox/connectors/hive/storage_adapters/s3fs/RegisterS3FileSystem.h b/velox/connectors/hive/storage_adapters/s3fs/RegisterS3FileSystem.h
index 346379b72b7..ba30bbb4eda 100644
--- a/velox/connectors/hive/storage_adapters/s3fs/RegisterS3FileSystem.h
+++ b/velox/connectors/hive/storage_adapters/s3fs/RegisterS3FileSystem.h
@@ -19,6 +19,7 @@
 #include <functional>
 #include <memory>
 #include <string>
+#include <string_view>
 
 namespace Aws::Auth {
 // Forward-declare the AWSCredentialsProvider class from the AWS SDK.
@@ -31,11 +32,21 @@ class ConfigBase;
 
 namespace facebook::velox::filesystems {
 
+class FileSystem;
+
 using CacheKeyFn = std::function<
     std::string(std::shared_ptr<const config::ConfigBase>, std::string_view)>;
 
+// Factory for substituting the FileSystem instance created for an S3 bucket.
+// This customizes the filesystem object, not S3FileSystem::Impl.
+using S3FileSystemFactory = std::function<std::shared_ptr<FileSystem>(
+    std::string bucketName,
+    std::shared_ptr<const config::ConfigBase> config)>;
+
 // Register the S3 filesystem.
-void registerS3FileSystem(CacheKeyFn cacheKeyFunc = nullptr);
+void registerS3FileSystem(
+    CacheKeyFn cacheKeyFunc = nullptr,
+    S3FileSystemFactory fileSystemFactory = nullptr);
 
 void registerS3Metrics();
 
diff --git a/velox/connectors/hive/storage_adapters/s3fs/tests/S3FileSystemRegistrationTest.cpp b/velox/connectors/hive/storage_adapters/s3fs/tests/S3FileSystemRegistrationTest.cpp
index 256fe6e481e..6b98a1d290a 100644
--- a/velox/connectors/hive/storage_adapters/s3fs/tests/S3FileSystemRegistrationTest.cpp
+++ b/velox/connectors/hive/storage_adapters/s3fs/tests/S3FileSystemRegistrationTest.cpp
@@ -26,11 +26,25 @@ std::string cacheKeyFunc(
   return config->get<std::string>("hive.s3.endpoint").value();
 }
 
+class CustomS3FileSystem : public S3FileSystem {
+ public:
+  CustomS3FileSystem(
+      std::string_view bucketName,
+      std::shared_ptr<const config::ConfigBase> config)
+      : S3FileSystem(bucketName, config) {}
+};
+
+std::shared_ptr<FileSystem> s3FileSystemFactory(
+    std::string bucketName,
+    std::shared_ptr<const config::ConfigBase> config) {
+  return std::make_shared<CustomS3FileSystem>(bucketName, config);
+}
+
 class S3FileSystemRegistrationTest : public S3Test {
  protected:
   static void SetUpTestCase() {
     memory::MemoryManager::testingSetInstance(memory::MemoryManager::Options{});
-    filesystems::registerS3FileSystem(cacheKeyFunc);
+    filesystems::registerS3FileSystem(cacheKeyFunc, s3FileSystemFactory);
   }
 
   static void TearDownTestCase() {
@@ -85,6 +99,13 @@ TEST_F(S3FileSystemRegistrationTest, cacheKey) {
   ASSERT_EQ(s3fs, s3fs_new);
 }
 
+TEST_F(S3FileSystemRegistrationTest, customFileSystemFactory) {
+  auto hiveConfig = minioServer_->hiveConfig();
+  auto s3fs = filesystems::getFileSystem(kDummyPath, hiveConfig);
+  auto customS3fs = std::dynamic_pointer_cast<CustomS3FileSystem>(s3fs);
+  ASSERT_NE(customS3fs, nullptr);
+}
+
 TEST_F(S3FileSystemRegistrationTest, finalize) {
   auto hiveConfig = minioServer_->hiveConfig();
   auto s3fs = filesystems::getFileSystem(kDummyPath, hiveConfig);