SNOW-1817982 iobound tpe limiting (#2115)
sfc-gh-mkeller authored Dec 5, 2024
1 parent 703d7f4 commit 9439854
Showing 6 changed files with 133 additions and 3 deletions.
3 changes: 3 additions & 0 deletions DESCRIPTION.md
@@ -8,6 +8,9 @@ Source code is also available at: https://github.com/snowflakedb/snowflake-conne

# Release Notes

- v3.12.5(TBD)
- Added a feature to limit the sizes of IO-bound ThreadPoolExecutors during PUT and GET commands.

- v3.12.4(December 3,2024)
- Fixed a bug where multipart uploads to Azure would be missing their MD5 hashes.
- Fixed a bug where OpenTelemetry header injection would sometimes cause Exceptions to be thrown.
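A minimal usage sketch for the release note above, assuming a standard `snowflake.connector.connect()` call; the account, credentials, and stage name are placeholders, not values taken from this commit:

```python
import snowflake.connector

# Hypothetical connection values; only iobound_tpe_limit is the new parameter.
with snowflake.connector.connect(
    account="my_account",
    user="my_user",
    password="...",
    iobound_tpe_limit=4,  # cap the IO-bound pre-/post-processing thread pools at 4 threads
) as conn:
    with conn.cursor() as cur:
        cur.execute("CREATE TEMP STAGE my_stage")
        cur.execute("PUT file:///tmp/data/* @my_stage")
```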
8 changes: 8 additions & 0 deletions src/snowflake/connector/connection.py
@@ -294,6 +294,10 @@ def _get_private_bytes_from_file(
False,
bool,
), # disable saml url check in okta authentication
"iobound_tpe_limit": (
None,
(type(None), int),
), # SNOW-1817982: limit iobound TPE sizes when executing PUT/GET
}

APPLICATION_RE = re.compile(r"[\w\d_]+")
@@ -726,6 +730,10 @@ def auth_class(self, value: AuthByPlugin) -> None:
def is_query_context_cache_disabled(self) -> bool:
return self._disable_query_context_cache

@property
def iobound_tpe_limit(self) -> int | None:
return self._iobound_tpe_limit

def connect(self, **kwargs) -> None:
"""Establishes connection to Snowflake."""
logger.debug("connect")
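For context, a sketch of how such a `(default, accepted types)` entry is commonly consumed — this is an assumption about the connector's configuration pattern, not code from this commit; it only illustrates that `iobound_tpe_limit` defaults to `None` and must otherwise be an `int`:

```python
# Illustrative only: validating a (default, accepted_types) configuration entry.
DEFAULT_CONFIGURATION = {
    "iobound_tpe_limit": (None, (type(None), int)),
}

def validate_option(name: str, value) -> None:
    _default, accepted_types = DEFAULT_CONFIGURATION[name]
    if not isinstance(value, accepted_types):
        raise TypeError(f"{name!r} must be of type {accepted_types}, got {type(value)}")
```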
1 change: 1 addition & 0 deletions src/snowflake/connector/cursor.py
@@ -1054,6 +1054,7 @@ def execute(
source_from_stream=file_stream,
multipart_threshold=data.get("threshold"),
use_s3_regional_url=self._connection.enable_stage_s3_privatelink_for_us_east_1,
iobound_tpe_limit=self._connection.iobound_tpe_limit,
)
sf_file_transfer_agent.execute()
data = sf_file_transfer_agent.result()
14 changes: 12 additions & 2 deletions src/snowflake/connector/file_transfer_agent.py
@@ -354,6 +354,7 @@ def __init__(
multipart_threshold: int | None = None,
source_from_stream: IO[bytes] | None = None,
use_s3_regional_url: bool = False,
iobound_tpe_limit: int | None = None,
) -> None:
self._cursor = cursor
self._command = command
@@ -384,6 +385,7 @@ def __init__(
self._multipart_threshold = multipart_threshold or 67108864 # Historical value
self._use_s3_regional_url = use_s3_regional_url
self._credentials: StorageCredential | None = None
self._iobound_tpe_limit = iobound_tpe_limit

def execute(self) -> None:
self._parse_command()
@@ -440,10 +442,15 @@ def execute(self) -> None:
result.result_status = result.result_status.value

def transfer(self, metas: list[SnowflakeFileMeta]) -> None:
iobound_tpe_limit = min(len(metas), os.cpu_count())
logger.debug("Decided IO-bound TPE size: %d", iobound_tpe_limit)
if self._iobound_tpe_limit is not None:
logger.debug("IO-bound TPE size is limited to: %d", self._iobound_tpe_limit)
iobound_tpe_limit = min(iobound_tpe_limit, self._iobound_tpe_limit)
max_concurrency = self._parallel
network_tpe = ThreadPoolExecutor(max_concurrency)
preprocess_tpe = ThreadPoolExecutor(min(len(metas), os.cpu_count()))
postprocess_tpe = ThreadPoolExecutor(min(len(metas), os.cpu_count()))
preprocess_tpe = ThreadPoolExecutor(iobound_tpe_limit)
postprocess_tpe = ThreadPoolExecutor(iobound_tpe_limit)
logger.debug(f"Chunk ThreadPoolExecutor size: {max_concurrency}")
cv_main_thread = threading.Condition() # to signal the main thread
cv_chunk_process = (
@@ -454,6 +461,9 @@ def transfer(self, metas: list[SnowflakeFileMeta]) -> None:
transfer_metadata = TransferMetadata() # this is protected by cv_chunk_process
is_upload = self._command_type == CMD_TYPE_UPLOAD
exception_caught_in_callback: Exception | None = None
logger.debug(
"Going to %sload %d files", "up" if is_upload else "down", len(metas)
)

def notify_file_completed() -> None:
# Increment the number of completed files, then notify the main thread.
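The sizing logic above, restated as a self-contained sketch (a standalone illustration rather than the connector's code): the IO-bound pools start at `min(number of files, CPU count)` and are then clamped by the optional user limit.

```python
import os
from concurrent.futures import ThreadPoolExecutor

def make_iobound_pools(num_files: int, iobound_tpe_limit: int | None = None):
    # Start from the smaller of the file count and the CPU count
    # (falling back to 1 here in case os.cpu_count() returns None).
    size = min(num_files, os.cpu_count() or 1)
    # Clamp by the user-provided limit, if one was set.
    if iobound_tpe_limit is not None:
        size = min(size, iobound_tpe_limit)
    # One pool each for pre- and post-processing, mirroring the diff above.
    return ThreadPoolExecutor(size), ThreadPoolExecutor(size)

# e.g. 3 files with iobound_tpe_limit=1 -> both pools sized 1,
# matching the log lines asserted in the new integration test.
```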
18 changes: 18 additions & 0 deletions test/integ/test_put_get.py
@@ -827,3 +827,21 @@ def test_put_md5(tmp_path, conn_cnx):
cur.execute(f"LS @{stage_name}").fetchall(),
)
)


@pytest.mark.skipolddriver
def test_iobound_limit(tmp_path, conn_cnx, caplog):
tmp_stage_name = random_string(5, "test_iobound_limit")
file0 = tmp_path / "file0"
file1 = tmp_path / "file1"
file0.touch()
file1.touch()
with conn_cnx(iobound_tpe_limit=1) as conn:
with conn.cursor() as cur:
cur.execute(f"create temp stage {tmp_stage_name}")
with caplog.at_level(
logging.DEBUG, "snowflake.connector.file_transfer_agent"
):
cur.execute(f"put file://{tmp_path}/* @{tmp_stage_name}")
assert "Decided IO-bound TPE size: 2" in caplog.text
assert "IO-bound TPE size is limited to: 1" in caplog.text
92 changes: 91 additions & 1 deletion test/unit/test_put_get.py
@@ -125,7 +125,6 @@ def test_percentage(tmp_path):
func_callback(1)


@pytest.mark.skipolddriver
def test_upload_file_with_azure_upload_failed_error(tmp_path):
"""Tests Upload file with expired Azure storage token."""
file1 = tmp_path / "file1"
@@ -166,3 +165,94 @@ def test_upload_file_with_azure_upload_failed_error(tmp_path):
rest_client.execute()
assert mock_update.called
assert rest_client._results[0].error_details is exc


def test_iobound_limit(tmp_path):
file1 = tmp_path / "file1"
file2 = tmp_path / "file2"
file3 = tmp_path / "file3"
file1.touch()
file2.touch()
file3.touch()
# Positive case
rest_client = SnowflakeFileTransferAgent(
mock.MagicMock(autospec=SnowflakeCursor),
"PUT some_file.txt",
{
"data": {
"command": "UPLOAD",
"src_locations": [file1, file2, file3],
"sourceCompression": "none",
"stageInfo": {
"creds": {
"AZURE_SAS_TOKEN": "sas_token",
},
"location": "some_bucket",
"region": "no_region",
"locationType": "AZURE",
"path": "remote_loc",
"endPoint": "",
"storageAccount": "storage_account",
},
},
"success": True,
},
)
with mock.patch(
"snowflake.connector.file_transfer_agent.ThreadPoolExecutor"
) as tpe:
with mock.patch("snowflake.connector.file_transfer_agent.threading.Condition"):
with mock.patch(
"snowflake.connector.file_transfer_agent.TransferMetadata",
return_value=mock.Mock(
num_files_started=0,
num_files_completed=3,
),
):
try:
rest_client.execute()
except AttributeError:
pass
# 2 IO-bound TPEs should be created, sized to 3 files, when no limit is set
rest_client = SnowflakeFileTransferAgent(
mock.MagicMock(autospec=SnowflakeCursor),
"PUT some_file.txt",
{
"data": {
"command": "UPLOAD",
"src_locations": [file1, file2, file3],
"sourceCompression": "none",
"stageInfo": {
"creds": {
"AZURE_SAS_TOKEN": "sas_token",
},
"location": "some_bucket",
"region": "no_region",
"locationType": "AZURE",
"path": "remote_loc",
"endPoint": "",
"storageAccount": "storage_account",
},
},
"success": True,
},
iobound_tpe_limit=2,
)
assert len(list(filter(lambda e: e.args == (3,), tpe.call_args_list))) == 2
with mock.patch(
"snowflake.connector.file_transfer_agent.ThreadPoolExecutor"
) as tpe:
with mock.patch("snowflake.connector.file_transfer_agent.threading.Condition"):
with mock.patch(
"snowflake.connector.file_transfer_agent.TransferMetadata",
return_value=mock.Mock(
num_files_started=0,
num_files_completed=3,
),
):
try:
rest_client.execute()
except AttributeError:
pass
# 2 IO-bound TPEs should be created for 3 files when limited to 2
assert len(list(filter(lambda e: e.args == (2,), tpe.call_args_list))) == 2
