Skip to content

Commit

Permalink
Add symlink support for cache backup and restore (#415)
Browse files Browse the repository at this point in the history
* add symlink support for cache backup and restore

Signed-off-by: Alex Goodman <[email protected]>

* workaround for go1.23+yardstick build issues

Signed-off-by: Alex Goodman <[email protected]>

---------

Signed-off-by: Alex Goodman <[email protected]>
  • Loading branch information
wagoodman authored Oct 23, 2024
1 parent 8e9852f commit 9c21aee
Show file tree
Hide file tree
Showing 13 changed files with 507 additions and 107 deletions.
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@
/go.work
/go.work.sum

# tools
.mise.toml

# default data directories
/vunnel
/bin
Expand Down
158 changes: 133 additions & 25 deletions cmd/grype-db/cli/commands/cache_restore.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
"strings"

"github.com/scylladb/go-set/strset"
"github.com/spf13/afero"
"github.com/spf13/cobra"
"github.com/spf13/pflag"
"github.com/spf13/viper"
Expand Down Expand Up @@ -200,6 +201,18 @@ func extractTarGz(reader io.Reader, selectedProviders *strset.Set) error {

tr := tar.NewReader(gr)

rootPath, err := os.Getwd()
if err != nil {
return fmt.Errorf("failed to get current working directory: %w", err)
}

rootPath, err = filepath.Abs(rootPath)
if err != nil {
return fmt.Errorf("failed to get absolute path: %w", err)
}

var restoredAny bool
fs := afero.NewOsFs()
for {
header, err := tr.Next()

Expand All @@ -217,36 +230,131 @@ func extractTarGz(reader io.Reader, selectedProviders *strset.Set) error {
continue
}

log.WithFields("path", header.Name).Trace("extracting file")
restoredAny = true

switch header.Typeflag {
case tar.TypeDir:
if err := os.Mkdir(header.Name, 0755); err != nil {
return fmt.Errorf("failed to create directory: %w", err)
}
case tar.TypeReg:
parentPath := filepath.Dir(header.Name)
if parentPath != "" {
if err := os.MkdirAll(parentPath, 0755); err != nil {
return fmt.Errorf("failed to create parent directory %q for file %q: %w", parentPath, header.Name, err)
}
}
if err := processTarHeader(fs, rootPath, header, tr); err != nil {
return err
}
}

outFile, err := os.Create(header.Name)
if err != nil {
return fmt.Errorf("failed to create file: %w", err)
}
if err := safeCopy(outFile, tr); err != nil {
return fmt.Errorf("failed to copy file: %w", err)
}
if err := outFile.Close(); err != nil {
return fmt.Errorf("failed to close file: %w", err)
}
if !restoredAny {
return fmt.Errorf("no provider data was restored")
}
return nil
}

func processTarHeader(fs afero.Fs, rootPath string, header *tar.Header, reader io.Reader) error {
// clean the path to avoid traversal (removes "..", ".", etc.)
cleanedPath := cleanPathRelativeToRoot(rootPath, header.Name)

if err := detectPathTraversal(rootPath, cleanedPath); err != nil {
return err
}

log.WithFields("path", cleanedPath).Trace("extracting file")

switch header.Typeflag {
case tar.TypeDir:
if err := fs.Mkdir(cleanedPath, 0755); err != nil {
return fmt.Errorf("failed to create directory: %w", err)
}
case tar.TypeSymlink:
if err := handleSymlink(fs, rootPath, cleanedPath, header.Linkname); err != nil {
return fmt.Errorf("failed to create symlink: %w", err)
}
case tar.TypeReg:
if err := handleFile(fs, cleanedPath, reader); err != nil {
return fmt.Errorf("failed to handle file: %w", err)
}
default:
log.WithFields("name", cleanedPath, "type", header.Typeflag).Warn("unknown file type in backup archive")
}
return nil
}

func handleFile(fs afero.Fs, cleanedPath string, reader io.Reader) error {
if cleanedPath == "" {
return fmt.Errorf("empty path")
}

parentPath := filepath.Dir(cleanedPath)
if parentPath != "" {
if err := fs.MkdirAll(parentPath, 0755); err != nil {
return fmt.Errorf("failed to create parent directory %q for file %q: %w", parentPath, cleanedPath, err)
}
}

default:
log.WithFields("name", header.Name, "type", header.Typeflag).Warn("unknown file type in backup archive")
outFile, err := fs.Create(cleanedPath)
if err != nil {
return fmt.Errorf("failed to create file: %w", err)
}
if err := safeCopy(outFile, reader); err != nil {
return fmt.Errorf("failed to copy file: %w", err)
}
if err := outFile.Close(); err != nil {
return fmt.Errorf("failed to close file: %w", err)
}
return nil
}

func handleSymlink(fs afero.Fs, rootPath, cleanedPath, linkName string) error {
if err := detectLinkTraversal(rootPath, cleanedPath, linkName); err != nil {
return err
}

linkReader, ok := fs.(afero.LinkReader)
if !ok {
return afero.ErrNoReadlink
}

// check if the symlink already exists and is pointing to the correct target
if linkTarget, err := linkReader.ReadlinkIfPossible(cleanedPath); err == nil {
if linkTarget == linkName {
return nil
}

if err := fs.Remove(cleanedPath); err != nil {
return fmt.Errorf("failed to remove existing symlink: %w", err)
}
}

linker, ok := fs.(afero.Linker)
if !ok {
return afero.ErrNoSymlink
}

if err := linker.SymlinkIfPossible(linkName, cleanedPath); err != nil {
return fmt.Errorf("failed to create symlink: %w", err)
}
return nil
}

func cleanPathRelativeToRoot(rootPath, path string) string {
return filepath.Join(rootPath, filepath.Clean(path))
}

func detectLinkTraversal(rootPath, cleanedPath, linkTarget string) error {
linkTarget = filepath.Clean(linkTarget)
if filepath.IsAbs(linkTarget) {
return detectPathTraversal(rootPath, linkTarget)
}

linkTarget = filepath.Join(filepath.Dir(cleanedPath), linkTarget)

if !strings.HasPrefix(linkTarget, rootPath) {
return fmt.Errorf("symlink points outside root: %s -> %s", cleanedPath, linkTarget)
}
return nil
}

func detectPathTraversal(rootPath, cleanedPath string) error {
if cleanedPath == "" {
return nil
}

if !strings.HasPrefix(cleanedPath, rootPath) {
return fmt.Errorf("path traversal detected: %s", cleanedPath)
}
return nil
}

Expand Down
Loading

0 comments on commit 9c21aee

Please sign in to comment.