diff --git a/s3transfer/__init__.py b/s3transfer/__init__.py index e8ff66f0..c3e0eac2 100644 --- a/s3transfer/__init__.py +++ b/s3transfer/__init__.py @@ -543,6 +543,7 @@ def download_file( key, filename, object_size, + extra_args, callback, ) parts_future = controller.submit(download_parts_handler) @@ -563,7 +564,7 @@ def _process_future_results(self, futures): future.result() def _download_file_as_future( - self, bucket, key, filename, object_size, callback + self, bucket, key, filename, object_size, extra_args, callback ): part_size = self._config.multipart_chunksize num_parts = int(math.ceil(object_size / float(part_size))) @@ -575,6 +576,7 @@ def _download_file_as_future( filename, part_size, num_parts, + extra_args, callback, ) try: @@ -593,7 +595,15 @@ def _calculate_range_param(self, part_size, part_index, num_parts): return range_param def _download_range( - self, bucket, key, filename, part_size, num_parts, callback, part_index + self, + bucket, + key, + filename, + part_size, + num_parts, + extra_args, + callback, + part_index, ): try: range_param = self._calculate_range_param( @@ -606,7 +616,7 @@ def _download_range( try: logger.debug("Making get_object call.") response = self._client.get_object( - Bucket=bucket, Key=key, Range=range_param + Bucket=bucket, Key=key, Range=range_param, **extra_args ) streaming_body = StreamReaderProgress( response['Body'], callback diff --git a/tests/unit/test_s3transfer.py b/tests/unit/test_s3transfer.py index a2f46a13..84fe917f 100644 --- a/tests/unit/test_s3transfer.py +++ b/tests/unit/test_s3transfer.py @@ -670,6 +670,30 @@ def test_download_file_fowards_extra_args(self): SSECustomerKey='foo', ) + def test_mutlipart_download_file_fowards_extra_args(self): + extra_args = { + 'SSECustomerKey': 'foo', + 'SSECustomerAlgorithm': 'AES256', + } + osutil = InMemoryOSLayer({}) + over_multipart_threshold = 100 * 1024 * 1024 + transfer = S3Transfer(self.client, osutil=osutil) + self.client.head_object.return_value = { + 'ContentLength': over_multipart_threshold + } + self.client.get_object.return_value = {'Body': BytesIO(b'foobar')} + + transfer.download_file( + 'bucket', 'key', 'filename', extra_args=extra_args + ) + + self.client.get_object.assert_called_with( + Bucket='bucket', + Key='key', + Range=self.client.get_object.call_args.kwargs['Range'], + **extra_args + ) + def test_get_object_stream_is_retried_and_succeeds(self): below_threshold = 20 osutil = InMemoryOSLayer({'smallfile': b'hello world'})