Skip to content

[S3] Validate parts in multi-part upload #2981

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Jan 15, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
/*
* Copyright 2024 LinkedIn Corp. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*
*/
package com.github.ambry.frontend.s3;

public class S3Constants {
public static final int MIN_PART_NUM = 1;
public static final int MAX_PART_NUM = 10000;
public static final int MAX_LIST_SIZE = 10000;

// Error Messages
public static final String ERR_INVALID_PART_NUMBER =
"Invalid part number: %d. Part number must be an integer between %d and %d.";
public static final String ERR_DUPLICATE_PART_NUMBER = "Duplicate part number found: %d.";
public static final String ERR_DUPLICATE_ETAG = "Duplicate eTag found: %s.";
public static final String ERR_EMPTY_REQUEST_BODY = "Xml request body cannot be empty.";
public static final String ERR_PART_LIST_TOO_LONG = String.format("Parts list size cannot exceed %d.", MAX_LIST_SIZE);
}
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@
import java.util.Comparator;
import java.util.EnumSet;
import java.util.GregorianCalendar;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.concurrent.Executors;
Expand All @@ -82,6 +83,9 @@
public class S3MultipartCompleteUploadHandler<R> {
private static final Logger LOGGER = LoggerFactory.getLogger(S3MultipartCompleteUploadHandler.class);
private static final ObjectMapper objectMapper = new XmlMapper();
private static final int MIN_PART_NUM = 1;
private static final int MAX_PART_NUM = 10000;
private static final int MAX_LIST_SIZE = 10000;
private final SecurityService securityService;
private final FrontendMetrics frontendMetrics;
private final AccountAndContainerInjector accountAndContainerInjector;
Expand Down Expand Up @@ -384,10 +388,11 @@ List<ChunkInfo> getChunksToStitch(CompleteMultipartUpload completeMultipartUploa
List<ChunkInfo> chunkInfos = new ArrayList<>();
try {
// sort the list in order
List<Part> sortedParts = Arrays.asList(completeMultipartUpload.getPart());
Collections.sort(sortedParts, Comparator.comparingInt(Part::getPartNumber));
List<Part> parts = Arrays.asList(completeMultipartUpload.getPart());
validatePartsOrThrow(parts);
Collections.sort(parts, Comparator.comparingInt(Part::getPartNumber));
String reservedMetadataId = null;
for (Part part : sortedParts) {
for (Part part : parts) {
S3MultipartETag eTag = S3MultipartETag.deserialize(part.geteTag());
// TODO [S3]: decide the life cycle of S3.
long expirationTimeInMs = -1;
Expand Down Expand Up @@ -415,4 +420,48 @@ List<ChunkInfo> getChunksToStitch(CompleteMultipartUpload completeMultipartUploa
return chunkInfos;
}
}

/**
* Check the list size and part number before processing request
* 1. Disallow duplicate part numbers
* 2. Disallow duplicate etags
* 3. Check for list size 10000
* 4. Check for part numbers integer 1-10000
* @param parts sorted parts list
* @return the bad request error
*/
private static void validatePartsOrThrow(List<Part> parts) throws RestServiceException {
if (parts == null || parts.isEmpty()) {
throw new RestServiceException(S3Constants.ERR_EMPTY_REQUEST_BODY, RestServiceErrorCode.BadRequest);
}

if (parts.size() > S3Constants.MAX_LIST_SIZE) {
String error = S3Constants.ERR_PART_LIST_TOO_LONG;
throw new RestServiceException(error, RestServiceErrorCode.BadRequest);
}

Set<Integer> partNumbers = new HashSet<>();
Set<String> etags = new HashSet<>();

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Remove Whitespace

for (Part part : parts) {
int partNumber = part.getPartNumber();
String etag = part.geteTag();

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Rm white space

if (partNumber < S3Constants.MIN_PART_NUM || partNumber > S3Constants.MAX_PART_NUM) {
String error = String.format(S3Constants.ERR_INVALID_PART_NUMBER, partNumber, S3Constants.MIN_PART_NUM,
S3Constants.MAX_PART_NUM);
throw new RestServiceException(error, RestServiceErrorCode.BadRequest);
}

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Rm white space

if (!partNumbers.add(partNumber)) {
String error = String.format(S3Constants.ERR_DUPLICATE_PART_NUMBER, partNumber);
throw new RestServiceException(error, RestServiceErrorCode.BadRequest);
}

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Rm white space

if (!etags.add(etag)) {
String error = String.format(S3Constants.ERR_DUPLICATE_ETAG, etag);
throw new RestServiceException(error, RestServiceErrorCode.BadRequest);
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
import com.github.ambry.commons.CommonTestUtils;
import com.github.ambry.config.FrontendConfig;
import com.github.ambry.config.VerifiableProperties;
import com.github.ambry.frontend.s3.S3Constants;
import com.github.ambry.frontend.s3.S3DeleteHandler;
import com.github.ambry.frontend.s3.S3MultipartAbortUploadHandler;
import com.github.ambry.frontend.s3.S3MultipartUploadHandler;
Expand All @@ -45,6 +46,7 @@
import com.github.ambry.rest.RestMethod;
import com.github.ambry.rest.RestRequest;
import com.github.ambry.rest.RestResponseChannel;
import com.github.ambry.rest.RestServiceException;
import com.github.ambry.router.ByteBufferRSC;
import com.github.ambry.router.FutureResult;
import com.github.ambry.router.InMemoryRouter;
Expand All @@ -53,11 +55,15 @@
import com.github.ambry.utils.TestUtils;
import com.github.ambry.utils.Utils;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.net.URISyntaxException;
import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.LinkedList;
import java.util.List;
import java.util.Properties;
import org.json.JSONObject;
import org.junit.Test;
Expand Down Expand Up @@ -245,6 +251,93 @@ public void multiPartUploadTest() throws Exception {
assertEquals("Mismatch on status", ResponseStatus.NoContent, restResponseChannel.getStatus());
}

@Test
public void testDuplicatePartNumbers() throws Exception {
Part part1 = new Part("1", "etag1");
Part part2 = new Part("1", "etag2");
Part[] parts = {part2, part1};
String expectedMessage = String.format(S3Constants.ERR_DUPLICATE_PART_NUMBER, 1);
testMultipartUploadWithInvalidParts(parts, expectedMessage);
}

@Test
public void testDuplicateEtags() throws Exception {
Part part1 = new Part("1", "etag1");
Part part2 = new Part("2", "etag1");
Part[] parts = {part2, part1};
String expectedMessage = String.format(S3Constants.ERR_DUPLICATE_ETAG, "etag1");
testMultipartUploadWithInvalidParts(parts, expectedMessage);
}

@Test
public void testInvalidPartNumLessThanMin() throws Exception {
Part part1 = new Part("0", "etag1");
Part part2 = new Part("1", "etag2");
Part[] parts = {part2, part1};
String expectedMessage = String.format(S3Constants.ERR_INVALID_PART_NUMBER, 0, S3Constants.MIN_PART_NUM, S3Constants.MAX_PART_NUM);
testMultipartUploadWithInvalidParts(parts, expectedMessage);
}

@Test
public void testPartNumberInvalidExceedsMax() throws Exception {
int invalidPartNumber = S3Constants.MAX_PART_NUM + 1;
Part part1 = new Part("2", "etag1");
Part part2 = new Part(String.valueOf(invalidPartNumber), "etag2");
Part[] parts = {part2, part1};
String expectedMessage = String.format(S3Constants.ERR_INVALID_PART_NUMBER, invalidPartNumber, S3Constants.MIN_PART_NUM, S3Constants.MAX_PART_NUM);
testMultipartUploadWithInvalidParts(parts, expectedMessage);
}

@Test
public void testExceedMaxParts() throws Exception {
Part[] parts = new Part[S3Constants.MAX_LIST_SIZE + 1];
for (int i = 1; i <= S3Constants.MAX_LIST_SIZE + 1; i++) {
parts[i - 1] = new Part(String.valueOf(i), "eTag" + i);
}
String expectedMessage = S3Constants.ERR_PART_LIST_TOO_LONG;
testMultipartUploadWithInvalidParts(parts, expectedMessage);
}

@Test
public void testEmptyPartList() throws Exception {
Part[] parts = {};
String expectedMessage = S3Constants.ERR_EMPTY_REQUEST_BODY;
testMultipartUploadWithInvalidParts(parts, expectedMessage);
}

private void testMultipartUploadWithInvalidParts(Part[] parts, String expectedErrorMessage) throws Exception {
String accountName = account.getName();
String containerName = container.getName();
String blobName = "MyDirectory/MyKey";
String uploadId = "uploadId";
String uri = S3_PREFIX + SLASH + accountName + SLASH + containerName + SLASH + blobName + "?uploadId=" + uploadId;
JSONObject headers = new JSONObject();

CompleteMultipartUpload completeMultipartUpload = new CompleteMultipartUpload(parts);
XmlMapper xmlMapper = new XmlMapper();
ByteArrayOutputStream byteArrayOutputStream = new ByteArrayOutputStream();
xmlMapper.writeValue(byteArrayOutputStream, completeMultipartUpload);
String completeMultipartStr = byteArrayOutputStream.toString();
byte[] content = completeMultipartStr.getBytes(StandardCharsets.UTF_8);
int size = content.length;

headers.put(Headers.CONTENT_TYPE, OCTET_STREAM_CONTENT_TYPE);
headers.put(Headers.CONTENT_LENGTH, size);

RestRequest request = FrontendRestRequestServiceTest.createRestRequest(RestMethod.POST, uri, headers,
new LinkedList<>(Arrays.asList(ByteBuffer.wrap(content), null)));
request.setArg(InternalKeys.REQUEST_PATH,
RequestPath.parse(request, frontendConfig.pathPrefixesToRemove, CLUSTER_NAME));

RestResponseChannel restResponseChannel = new MockRestResponseChannel();
s3PostHandler.handle(request, restResponseChannel, (r, e) -> {
assertNotNull("Expected an exception, but none was thrown.", e);
assertTrue("Unexpected error message: " + e.getMessage(), e.getMessage().contains(expectedErrorMessage));
});
}



/**
* Initiates a {@link S3PutHandler}
*/
Expand Down
Loading