diff --git a/esti/multipart_test.go b/esti/multipart_test.go index 62a375f5975..e96d5af06d2 100644 --- a/esti/multipart_test.go +++ b/esti/multipart_test.go @@ -52,7 +52,7 @@ func TestMultipartUpload(t *testing.T) { partsConcat = append(partsConcat, parts[i]...) } - completedParts := uploadMultipartParts(t, ctx, logger, resp, parts, 0) + completedParts := uploadMultipartParts(t, ctx, svc, logger, resp, parts, 0) if isBlockstoreType(block.BlockstoreTypeS3) == nil { // Object should have Last-Modified time at around time of MPU creation. Ensure @@ -166,7 +166,7 @@ func reverse(s string) string { return string(runes) } -func uploadMultipartParts(t *testing.T, ctx context.Context, logger logging.Logger, resp *s3.CreateMultipartUploadOutput, parts [][]byte, firstIndex int) []types.CompletedPart { +func uploadMultipartParts(t *testing.T, ctx context.Context, client *s3.Client, logger logging.Logger, resp *s3.CreateMultipartUploadOutput, parts [][]byte, firstIndex int) []types.CompletedPart { count := len(parts) completedParts := make([]types.CompletedPart, count) errs := make([]error, count) @@ -176,7 +176,7 @@ func uploadMultipartParts(t *testing.T, ctx context.Context, logger logging.Logg go func(i int) { defer wg.Done() partNumber := firstIndex + i + 1 - completedParts[i], errs[i] = uploadMultipartPart(ctx, logger, svc, resp, parts[i], partNumber) + completedParts[i], errs[i] = uploadMultipartPart(ctx, logger, client, resp, parts[i], partNumber) }(i) } wg.Wait() diff --git a/esti/s3_gateway_test.go b/esti/s3_gateway_test.go index 3397e089a30..c1037d16d52 100644 --- a/esti/s3_gateway_test.go +++ b/esti/s3_gateway_test.go @@ -4,9 +4,6 @@ import ( "bytes" "context" "fmt" - "github.com/aws/aws-sdk-go-v2/aws" - "github.com/aws/aws-sdk-go-v2/service/s3" - "github.com/go-openapi/swag" "io" "math/rand" "net/http" @@ -16,6 +13,14 @@ import ( "testing" "time" + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/s3" + 
"github.com/aws/aws-sdk-go-v2/service/s3/types" + "github.com/aws/smithy-go/middleware" + smithyhttp "github.com/aws/smithy-go/transport/http" + "github.com/go-openapi/swag" + "github.com/thanhpk/randstr" + "github.com/minio/minio-go/v7" "github.com/minio/minio-go/v7/pkg/credentials" "github.com/minio/minio-go/v7/pkg/tags" @@ -31,10 +36,12 @@ import ( type GetCredentials = func(id, secret, token string) *credentials.Credentials const ( - numUploads = 100 - randomDataPathLength = 1020 - branch = "main" - gatewayTestPrefix = branch + "/data/" + numUploads = 100 + randomDataPathLength = 1020 + branch = "main" + gatewayTestPrefix = branch + "/data/" + errorPreconditionFailed = "At least one of the pre-conditions you specified did not hold" + errorNotImplemented = "A header you provided implies functionality that is not implemented" ) func newMinioClient(t *testing.T, getCredentials GetCredentials) *minio.Client { @@ -181,6 +188,106 @@ func TestS3UploadAndDownload(t *testing.T) { }) } } +func TestMultipartUploadIfNoneMatch(t *testing.T) { + ctx, logger, repo := setupTest(t) + defer tearDownTest(repo) + s3Endpoint := viper.GetString("s3_endpoint") + s3Client := createS3Client(s3Endpoint, t) + multipartNumberOfParts := 7 + multipartPartSize := 5 * 1024 * 1024 + type TestCase struct { + Path string + IfNoneMatch string + ExpectedError string + } + + testCases := []TestCase{ + {Path: "main/object1"}, + {Path: "main/object1", IfNoneMatch: "*", ExpectedError: errorPreconditionFailed}, + {Path: "main/object2", IfNoneMatch: "*"}, + } + for _, tc := range testCases { + input := &s3.CreateMultipartUploadInput{ + Bucket: aws.String(repo), + Key: aws.String(tc.Path), + } + + resp, err := s3Client.CreateMultipartUpload(ctx, input) + require.NoError(t, err, "failed to create multipart upload") + + parts := make([][]byte, multipartNumberOfParts) + for i := 0; i < multipartNumberOfParts; i++ { + parts[i] = randstr.Bytes(multipartPartSize + i) + } + + completedParts := 
uploadMultipartParts(t, ctx, s3Client, logger, resp, parts, 0)
+
+		completeInput := &s3.CompleteMultipartUploadInput{
+			Bucket:   resp.Bucket,
+			Key:      resp.Key,
+			UploadId: resp.UploadId,
+			MultipartUpload: &types.CompletedMultipartUpload{
+				Parts: completedParts,
+			},
+		}
+		_, err = s3Client.CompleteMultipartUpload(ctx, completeInput, s3.WithAPIOptions(setHTTPHeaders(tc.IfNoneMatch)))
+		if tc.ExpectedError != "" {
+			require.ErrorContains(t, err, tc.ExpectedError)
+		} else {
+			require.NoError(t, err, "expected no error but got: %v", err)
+		}
+	}
+}
+
+func setHTTPHeaders(ifNoneMatch string) func(*middleware.Stack) error {
+	return func(stack *middleware.Stack) error {
+		return stack.Build.Add(middleware.BuildMiddlewareFunc("AddIfNoneMatchHeader", func(
+			ctx context.Context, in middleware.BuildInput, next middleware.BuildHandler,
+		) (
+			middleware.BuildOutput, middleware.Metadata, error,
+		) {
+			if req, ok := in.Request.(*smithyhttp.Request); ok {
+				// Add the If-None-Match header
+				req.Header.Set("If-None-Match", ifNoneMatch)
+			}
+			return next.HandleBuild(ctx, in)
+		}), middleware.Before)
+	}
+}
+func TestS3IfNoneMatch(t *testing.T) {
+
+	ctx, _, repo := setupTest(t)
+	defer tearDownTest(repo)
+
+	s3Endpoint := viper.GetString("s3_endpoint")
+	s3Client := createS3Client(s3Endpoint, t)
+
+	type TestCase struct {
+		Path          string
+		IfNoneMatch   string
+		ExpectedError string
+	}
+
+	testCases := []TestCase{
+		{Path: "main/object1"},
+		{Path: "main/object1", IfNoneMatch: "*", ExpectedError: errorPreconditionFailed},
+		{Path: "main/object2", IfNoneMatch: "*"},
+		{Path: "main/object2"},
+		{Path: "main/object3", IfNoneMatch: "unsupported string", ExpectedError: errorNotImplemented},
+	}
+	for _, tc := range testCases {
+		input := &s3.PutObjectInput{
+			Bucket: aws.String(repo),
+			Key:    aws.String(tc.Path),
+		}
+		_, err := s3Client.PutObject(ctx, input, s3.WithAPIOptions(setHTTPHeaders(tc.IfNoneMatch)))
+		if tc.ExpectedError != "" {
+			require.ErrorContains(t, err, tc.ExpectedError)
+		}
else {
+			require.NoError(t, err, "expected no error but got: %v", err)
+		}
+	}
+}

 func verifyObjectInfo(t *testing.T, got minio.ObjectInfo, expectedSize int) {
 	if got.Err != nil {
diff --git a/pkg/gateway/operations/operation_utils.go b/pkg/gateway/operations/operation_utils.go
index c9ae38c3fc9..93d8e700db0 100644
--- a/pkg/gateway/operations/operation_utils.go
+++ b/pkg/gateway/operations/operation_utils.go
@@ -6,6 +6,7 @@ import (
 	"time"
 
 	"github.com/treeverse/lakefs/pkg/catalog"
+	"github.com/treeverse/lakefs/pkg/graveler"
 	"github.com/treeverse/lakefs/pkg/logging"
 )
 
@@ -40,7 +41,7 @@ func shouldReplaceMetadata(req *http.Request) bool {
 	return req.Header.Get(amzMetadataDirectiveHeaderPrefix) == "REPLACE"
 }
 
-func (o *PathOperation) finishUpload(req *http.Request, mTime *time.Time, checksum, physicalAddress string, size int64, relative bool, metadata map[string]string, contentType string) error {
+func (o *PathOperation) finishUpload(req *http.Request, mTime *time.Time, checksum, physicalAddress string, size int64, relative bool, metadata map[string]string, contentType string, allowOverwrite bool) error {
 	var writeTime time.Time
 	if mTime == nil {
 		writeTime = time.Now()
@@ -59,7 +60,7 @@ func (o *PathOperation) finishUpload(req *http.Request, mTime *time.Time, checks
 		ContentType(contentType).
Build() - err := o.Catalog.CreateEntry(req.Context(), o.Repository.Name, o.Reference, entry) + err := o.Catalog.CreateEntry(req.Context(), o.Repository.Name, o.Reference, entry, graveler.WithIfAbsent(!allowOverwrite)) if err != nil { o.Log(req).WithError(err).Error("could not update metadata") return err diff --git a/pkg/gateway/operations/postobject.go b/pkg/gateway/operations/postobject.go index 984512ca551..d3cd11ab0f3 100644 --- a/pkg/gateway/operations/postobject.go +++ b/pkg/gateway/operations/postobject.go @@ -10,6 +10,7 @@ import ( "time" "github.com/treeverse/lakefs/pkg/block" + "github.com/treeverse/lakefs/pkg/catalog" gatewayErrors "github.com/treeverse/lakefs/pkg/gateway/errors" "github.com/treeverse/lakefs/pkg/gateway/multipart" "github.com/treeverse/lakefs/pkg/gateway/path" @@ -94,6 +95,23 @@ func (controller *PostObject) HandleCompleteMultipartUpload(w http.ResponseWrite _ = o.EncodeError(w, req, err, gatewayErrors.Codes.ToAPIErr(gatewayErrors.ErrInternalError)) return } + // check and validate whether if-none-match header provided + allowOverwrite, err := o.checkIfAbsent(req) + if err != nil { + _ = o.EncodeError(w, req, err, gatewayErrors.Codes.ToAPIErr(gatewayErrors.ErrNotImplemented)) + return + } + // before writing body, ensure preconditions - this means we essentially check for object existence twice: + // once here, before uploading the body to save resources and time, + // and then graveler will check again when passed a SetOptions. 
+	if !allowOverwrite {
+		_, err := o.Catalog.GetEntry(req.Context(), o.Repository.Name, o.Reference, o.Path, catalog.GetEntryParams{})
+		if err == nil {
+			// GetEntry returned no error - the object already exists, so the "If-None-Match: *" precondition fails
+			_ = o.EncodeError(w, req, err, gatewayErrors.Codes.ToAPIErr(gatewayErrors.ErrPreconditionFailed))
+			return
+		}
+	}
 	objName := multiPart.PhysicalAddress
 	req = req.WithContext(logging.AddFields(req.Context(), logging.Fields{logging.PhysicalAddressFieldKey: objName}))
 	xmlMultipartComplete, err := io.ReadAll(req.Body)
@@ -124,7 +142,11 @@ func (controller *PostObject) HandleCompleteMultipartUpload(w http.ResponseWrite
 		return
 	}
 	checksum := strings.Split(resp.ETag, "-")[0]
-	err = o.finishUpload(req, resp.MTime, checksum, objName, resp.ContentLength, true, multiPart.Metadata, multiPart.ContentType)
+	err = o.finishUpload(req, resp.MTime, checksum, objName, resp.ContentLength, true, multiPart.Metadata, multiPart.ContentType, allowOverwrite)
+	if errors.Is(err, graveler.ErrPreconditionFailed) {
+		_ = o.EncodeError(w, req, err, gatewayErrors.Codes.ToAPIErr(gatewayErrors.ErrPreconditionFailed))
+		return
+	}
 	if errors.Is(err, graveler.ErrWriteToProtectedBranch) {
 		_ = o.EncodeError(w, req, err, gatewayErrors.Codes.ToAPIErr(gatewayErrors.ErrWriteToProtectedBranch))
 		return
diff --git a/pkg/gateway/operations/putobject.go b/pkg/gateway/operations/putobject.go
index 28c31ce7652..aeacec0442b 100644
--- a/pkg/gateway/operations/putobject.go
+++ b/pkg/gateway/operations/putobject.go
@@ -20,6 +20,7 @@ import (
 )
 
 const (
+	IfNoneMatchHeader     = "If-None-Match"
 	CopySourceHeader      = "x-amz-copy-source"
 	CopySourceRangeHeader = "x-amz-copy-source-range"
 	QueryParamUploadID    = "uploadId"
@@ -30,7 +31,6 @@ type PutObject struct{}
 
 func (controller *PutObject) RequiredPermissions(req *http.Request, repoID, _, destPath string) (permissions.Node, error) {
 	copySource := req.Header.Get(CopySourceHeader)
-
 	if len(copySource) == 0 {
 		return permissions.Node{
 			Permission: permissions.Permission{
@@
-298,6 +298,23 @@ func handlePut(w http.ResponseWriter, req *http.Request, o *PathOperation) {
 	o.Incr("put_object", o.Principal, o.Repository.Name, o.Reference)
 	storageClass := StorageClassFromHeader(req.Header)
 	opts := block.PutOpts{StorageClass: storageClass}
+	// check and validate whether if-none-match header provided
+	allowOverwrite, err := o.checkIfAbsent(req)
+	if err != nil {
+		_ = o.EncodeError(w, req, err, gatewayErrors.Codes.ToAPIErr(gatewayErrors.ErrNotImplemented))
+		return
+	}
+	// before writing body, ensure preconditions - this means we essentially check for object existence twice:
+	// once here, before uploading the body to save resources and time,
+	// and then graveler will check again when passed a SetOptions.
+	if !allowOverwrite {
+		_, err := o.Catalog.GetEntry(req.Context(), o.Repository.Name, o.Reference, o.Path, catalog.GetEntryParams{})
+		if err == nil {
+			// GetEntry returned no error - the object already exists, so the "If-None-Match: *" precondition fails
+			_ = o.EncodeError(w, req, err, gatewayErrors.Codes.ToAPIErr(gatewayErrors.ErrPreconditionFailed))
+			return
+		}
+	}
 	address := o.PathProvider.NewPath()
 	blob, err := upload.WriteBlob(req.Context(), o.BlockStore, o.Repository.StorageNamespace, address, req.Body, req.ContentLength, opts)
 	if err != nil {
@@ -309,7 +326,11 @@ func handlePut(w http.ResponseWriter, req *http.Request, o *PathOperation) {
 	// write metadata
 	metadata := amzMetaAsMetadata(req)
 	contentType := req.Header.Get("Content-Type")
-	err = o.finishUpload(req, &blob.CreationDate, blob.Checksum, blob.PhysicalAddress, blob.Size, true, metadata, contentType)
+	err = o.finishUpload(req, &blob.CreationDate, blob.Checksum, blob.PhysicalAddress, blob.Size, true, metadata, contentType, allowOverwrite)
+	if errors.Is(err, graveler.ErrPreconditionFailed) {
+		_ = o.EncodeError(w, req, err, gatewayErrors.Codes.ToAPIErr(gatewayErrors.ErrPreconditionFailed))
+		return
+	}
 	if errors.Is(err, graveler.ErrWriteToProtectedBranch) {
 		_ = o.EncodeError(w, req, err,
gatewayErrors.Codes.ToAPIErr(gatewayErrors.ErrWriteToProtectedBranch)) return @@ -325,3 +346,14 @@ func handlePut(w http.ResponseWriter, req *http.Request, o *PathOperation) { o.SetHeader(w, "ETag", httputil.ETag(blob.Checksum)) w.WriteHeader(http.StatusOK) } + +func (o *PathOperation) checkIfAbsent(req *http.Request) (bool, error) { + headerValue := req.Header.Get(IfNoneMatchHeader) + if headerValue == "" { + return true, nil + } + if headerValue == "*" { + return false, nil + } + return false, gatewayErrors.ErrNotImplemented +}