Skip to content

Commit

Permalink
[Storage] Modified StorageRetryPolicy to skip AzureSigningError f…
Browse files Browse the repository at this point in the history
…rom bad storage account key (#36431)
  • Loading branch information
weirongw23-msft authored Jul 25, 2024
1 parent 6f0da99 commit 14fe734
Show file tree
Hide file tree
Showing 14 changed files with 100 additions and 13 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
SansIOHTTPPolicy
)

from .authentication import StorageHttpChallenge
from .authentication import AzureSigningError, StorageHttpChallenge
from .constants import DEFAULT_OAUTH_SCOPE
from .models import LocationMode

Expand Down Expand Up @@ -542,6 +542,8 @@ def send(self, request):
continue
break
except AzureError as err:
if isinstance(err, AzureSigningError):
raise
retries_remaining = self.increment(
retry_settings, request=request.http_request, error=err)
if retries_remaining:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from azure.core.exceptions import AzureError
from azure.core.pipeline.policies import AsyncBearerTokenCredentialPolicy, AsyncHTTPPolicy

from .authentication import StorageHttpChallenge
from .authentication import AzureSigningError, StorageHttpChallenge
from .constants import DEFAULT_OAUTH_SCOPE
from .policies import is_retry, StorageRetryPolicy

Expand Down Expand Up @@ -127,6 +127,8 @@ async def send(self, request):
continue
break
except AzureError as err:
if isinstance(err, AzureSigningError):
raise
retries_remaining = self.increment(
retry_settings, request=request.http_request, error=err)
if retries_remaining:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
)
from azure.core.pipeline.policies import ContentDecodePolicy

from .authentication import AzureSigningError
from .models import get_enum_value, StorageErrorCode, UserDelegationKey
from .parser import _to_utc_datetime

Expand Down Expand Up @@ -81,9 +82,12 @@ def return_raw_deserialized(response, *_):
return response.http_response.location_mode, response.context[ContentDecodePolicy.CONTEXT_NAME]


def process_storage_error(storage_error) -> NoReturn: # type: ignore [misc] # pylint:disable=too-many-statements
def process_storage_error(storage_error) -> NoReturn: # type: ignore [misc] # pylint:disable=too-many-statements, too-many-branches
raise_error = HttpResponseError
serialized = False
if isinstance(storage_error, AzureSigningError):
storage_error.message = storage_error.message + \
'. This is likely due to an invalid shared key. Please check your shared key and try again.'
if not storage_error.response or storage_error.response.status_code in [200, 204]:
raise storage_error
# If it is one of those three then it has been serialized prior by the generated layer.
Expand Down
28 changes: 28 additions & 0 deletions sdk/storage/azure-storage-blob/tests/test_retry.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,9 @@
ServiceResponseError
)
from azure.core.pipeline.transport import RequestsTransport
from azure.storage.blob._shared.authentication import AzureSigningError
from azure.storage.blob import (
BlobClient,
BlobServiceClient,
ExponentialRetry,
LinearRetry,
Expand Down Expand Up @@ -550,4 +552,30 @@ def test_streaming_retry(self, **kwargs):
blob.download_blob()
assert iterator_mock.__next__.call_count == count[0] == 3

@BlobPreparer()
def test_invalid_storage_account_key(self, **kwargs):
storage_account_name = kwargs.pop("storage_account_name")
storage_account_key = "a"

# Arrange
blob_client = self._create_storage_service(
BlobClient,
storage_account_name,
storage_account_key,
container_name="foo",
blob_name="bar"
)

retry_counter = RetryCounter()
retry_callback = retry_counter.simple_count

# Act
with pytest.raises(AzureSigningError) as e:
blob_client.get_blob_properties(retry_hook=retry_callback)

# Assert
assert ("This is likely due to an invalid shared key. Please check your shared key and try again." in
e.value.message)
assert retry_counter.count == 0

# ------------------------------------------------------------------------------
29 changes: 28 additions & 1 deletion sdk/storage/azure-storage-blob/tests/test_retry_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,9 @@
)
from azure.core.pipeline.transport import AioHttpTransport
from azure.storage.blob import LocationMode
from azure.storage.blob._shared.authentication import AzureSigningError
from azure.storage.blob._shared.policies_async import ExponentialRetry, LinearRetry
from azure.storage.blob.aio import BlobServiceClient
from azure.storage.blob.aio import BlobClient, BlobServiceClient

from devtools_testutils import ResponseCallback, RetryCounter
from devtools_testutils.aio import recorded_by_proxy_async
Expand Down Expand Up @@ -529,4 +530,30 @@ async def test_streaming_retry(self, **kwargs):
await blob.download_blob()
assert stream_reader_read_mock.call_count == count[0] == 4

@BlobPreparer()
async def test_invalid_storage_account_key(self, **kwargs):
storage_account_name = kwargs.pop("storage_account_name")
storage_account_key = "a"

# Arrange
blob_client = self._create_storage_service(
BlobClient,
storage_account_name,
storage_account_key,
container_name="foo",
blob_name="bar"
)

retry_counter = RetryCounter()
retry_callback = retry_counter.simple_count

# Act
with pytest.raises(AzureSigningError) as e:
await blob_client.get_blob_properties(retry_hook=retry_callback)

# Assert
assert ("This is likely due to an invalid shared key. Please check your shared key and try again." in
e.value.message)
assert retry_counter.count == 0

# ------------------------------------------------------------------------------
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
SansIOHTTPPolicy
)

from .authentication import StorageHttpChallenge
from .authentication import AzureSigningError, StorageHttpChallenge
from .constants import DEFAULT_OAUTH_SCOPE
from .models import LocationMode

Expand Down Expand Up @@ -544,6 +544,8 @@ def send(self, request):
continue
break
except AzureError as err:
if isinstance(err, AzureSigningError):
raise
retries_remaining = self.increment(
retry_settings, request=request.http_request, error=err)
if retries_remaining:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from azure.core.exceptions import AzureError
from azure.core.pipeline.policies import AsyncBearerTokenCredentialPolicy, AsyncHTTPPolicy

from .authentication import StorageHttpChallenge
from .authentication import AzureSigningError, StorageHttpChallenge
from .constants import DEFAULT_OAUTH_SCOPE
from .policies import is_retry, StorageRetryPolicy

Expand Down Expand Up @@ -127,6 +127,8 @@ async def send(self, request):
continue
break
except AzureError as err:
if isinstance(err, AzureSigningError):
raise
retries_remaining = self.increment(
retry_settings, request=request.http_request, error=err)
if retries_remaining:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
)
from azure.core.pipeline.policies import ContentDecodePolicy

from .authentication import AzureSigningError
from .models import get_enum_value, StorageErrorCode, UserDelegationKey
from .parser import _to_utc_datetime

Expand Down Expand Up @@ -81,9 +82,12 @@ def return_raw_deserialized(response, *_):
return response.http_response.location_mode, response.context[ContentDecodePolicy.CONTEXT_NAME]


def process_storage_error(storage_error) -> NoReturn: # type: ignore [misc] # pylint:disable=too-many-statements
def process_storage_error(storage_error) -> NoReturn: # type: ignore [misc] # pylint:disable=too-many-statements, too-many-branches
raise_error = HttpResponseError
serialized = False
if isinstance(storage_error, AzureSigningError):
storage_error.message = storage_error.message + \
'. This is likely due to an invalid shared key. Please check your shared key and try again.'
if not storage_error.response or storage_error.response.status_code in [200, 204]:
raise storage_error
# If it is one of those three then it has been serialized prior by the generated layer.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
SansIOHTTPPolicy
)

from .authentication import StorageHttpChallenge
from .authentication import AzureSigningError, StorageHttpChallenge
from .constants import DEFAULT_OAUTH_SCOPE
from .models import LocationMode

Expand Down Expand Up @@ -541,6 +541,8 @@ def send(self, request):
continue
break
except AzureError as err:
if isinstance(err, AzureSigningError):
raise
retries_remaining = self.increment(
retry_settings, request=request.http_request, error=err)
if retries_remaining:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from azure.core.exceptions import AzureError
from azure.core.pipeline.policies import AsyncBearerTokenCredentialPolicy, AsyncHTTPPolicy

from .authentication import StorageHttpChallenge
from .authentication import AzureSigningError, StorageHttpChallenge
from .constants import DEFAULT_OAUTH_SCOPE
from .policies import is_retry, StorageRetryPolicy

Expand Down Expand Up @@ -127,6 +127,8 @@ async def send(self, request):
continue
break
except AzureError as err:
if isinstance(err, AzureSigningError):
raise
retries_remaining = self.increment(
retry_settings, request=request.http_request, error=err)
if retries_remaining:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
)
from azure.core.pipeline.policies import ContentDecodePolicy

from .authentication import AzureSigningError
from .models import get_enum_value, StorageErrorCode, UserDelegationKey
from .parser import _to_utc_datetime

Expand Down Expand Up @@ -81,9 +82,12 @@ def return_raw_deserialized(response, *_):
return response.http_response.location_mode, response.context[ContentDecodePolicy.CONTEXT_NAME]


def process_storage_error(storage_error) -> NoReturn: # type: ignore [misc] # pylint:disable=too-many-statements
def process_storage_error(storage_error) -> NoReturn: # type: ignore [misc] # pylint:disable=too-many-statements, too-many-branches
raise_error = HttpResponseError
serialized = False
if isinstance(storage_error, AzureSigningError):
storage_error.message = storage_error.message + \
'. This is likely due to an invalid shared key. Please check your shared key and try again.'
if not storage_error.response or storage_error.response.status_code in [200, 204]:
raise storage_error
# If it is one of those three then it has been serialized prior by the generated layer.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
SansIOHTTPPolicy
)

from .authentication import StorageHttpChallenge
from .authentication import AzureSigningError, StorageHttpChallenge
from .constants import DEFAULT_OAUTH_SCOPE
from .models import LocationMode

Expand Down Expand Up @@ -547,6 +547,8 @@ def send(self, request):
continue
break
except AzureError as err:
if isinstance(err, AzureSigningError):
raise
retries_remaining = self.increment(
retry_settings, request=request.http_request, error=err)
if retries_remaining:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from azure.core.exceptions import AzureError
from azure.core.pipeline.policies import AsyncBearerTokenCredentialPolicy, AsyncHTTPPolicy

from .authentication import StorageHttpChallenge
from .authentication import AzureSigningError, StorageHttpChallenge
from .constants import DEFAULT_OAUTH_SCOPE
from .policies import is_retry, StorageRetryPolicy

Expand Down Expand Up @@ -127,6 +127,8 @@ async def send(self, request):
continue
break
except AzureError as err:
if isinstance(err, AzureSigningError):
raise
retries_remaining = self.increment(
retry_settings, request=request.http_request, error=err)
if retries_remaining:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
)
from azure.core.pipeline.policies import ContentDecodePolicy

from .authentication import AzureSigningError
from .models import get_enum_value, StorageErrorCode, UserDelegationKey
from .parser import _to_utc_datetime

Expand Down Expand Up @@ -81,9 +82,12 @@ def return_raw_deserialized(response, *_):
return response.http_response.location_mode, response.context[ContentDecodePolicy.CONTEXT_NAME]


def process_storage_error(storage_error) -> NoReturn: # type: ignore [misc] # pylint:disable=too-many-statements
def process_storage_error(storage_error) -> NoReturn: # type: ignore [misc] # pylint:disable=too-many-statements, too-many-branches
raise_error = HttpResponseError
serialized = False
if isinstance(storage_error, AzureSigningError):
storage_error.message = storage_error.message + \
'. This is likely due to an invalid shared key. Please check your shared key and try again.'
if not storage_error.response or storage_error.response.status_code in [200, 204]:
raise storage_error
# If it is one of those three then it has been serialized prior by the generated layer.
Expand Down

0 comments on commit 14fe734

Please sign in to comment.