Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
110 changes: 29 additions & 81 deletions snapshots/devbox/devbox.go
Original file line number Diff line number Diff line change
Expand Up @@ -366,13 +366,17 @@ func (o *Snapshotter) Commit(ctx context.Context, name, key string, opts ...snap
}

func (o *Snapshotter) RemoveDir(ctx context.Context, dir string) {
isMounted, err := isMountPoint(dir)
// Hold lock to ensure atomicity between check and use (prevent TOCTOU race condition)
lvm.LockLV()
defer lvm.UnlockLV()

isMounted, err := lvm.IsMountPointInternal(dir)
if err != nil {
log.G(ctx).WithError(err).WithField("path", dir).Warn("failed to check if path is a mount point")
return
}
if isMounted {
if err1 := o.unmountLvm(ctx, dir); err1 != nil {
if err1 := lvm.UnmountVolumeInternal(dir); err1 != nil {
log.G(ctx).WithError(err1).WithField("path", dir).Warn("failed to unmount directory")
return
}
Expand All @@ -395,13 +399,19 @@ func (o *Snapshotter) Remove(ctx context.Context, key string) (err error) {
var (
removals []string
removedLvNames []string
mountPath string
)

log.G(ctx).Infof("Remove called with key: %s", key)
// Remove directories after the transaction is closed, failures must not
// return error since the transaction is committed with the removal
// key no longer available.
defer func() {
if mountPath != "" {
if err = o.unmountLvm(ctx, mountPath); err != nil {
log.G(ctx).WithError(err).WithField("path", mountPath).Warn("failed to unmount directory")
}
}
if err == nil {
for _, dir := range removals {
o.RemoveDir(ctx, dir)
Expand All @@ -419,17 +429,12 @@ func (o *Snapshotter) Remove(ctx context.Context, key string) (err error) {

return o.ms.WithTransaction(ctx, true, func(ctx context.Context) error {
// modified by sealos
var mountPath string
mountPath, err = storage.RemoveDevbox(ctx, key)
log.G(ctx).Infof("Removed devbox content for key: %s, mount path: %s", key, mountPath)
if err != nil && err != errdefs.ErrNotFound {
return fmt.Errorf("failed to remove devbox content for snapshot %s: %w", key, err)
}
if mountPath != "" {
if err = o.unmountLvm(ctx, mountPath); err != nil {
log.G(ctx).WithError(err).WithField("path", mountPath).Warn("failed to unmount directory")
}
}

_, _, err = storage.Remove(ctx, key)
if err != nil {
return fmt.Errorf("failed to remove snapshot %s: %w", key, err)
Expand Down Expand Up @@ -511,7 +516,7 @@ func (o *Snapshotter) cleanupDirectories(ctx context.Context) (_ []string, _ []s
// Unmount any mounted LVs
for _, lvName := range removedLvNames {
devicePath := fmt.Sprintf("/dev/%s/%s", o.lvmVgName, lvName)
mountPoints, err := findMountPointByDevice(devicePath)
mountPoints, err := lvm.FindMountPointByDevice(devicePath)
if err != nil {
log.G(ctx).WithError(err).WithField("lvName", lvName).WithField("devicePath", devicePath).
Warn("Cleanup: failed to find mount point for LV, continuing")
Expand Down Expand Up @@ -594,7 +599,6 @@ func (o *Snapshotter) getCleanupLvNames(ctx context.Context) ([]string, error) {
}

func (o *Snapshotter) resizeLVMVolume(ctx context.Context, lvName, useLimit string) error {

capacity, err := parseUseLimit(useLimit)
if err != nil {
return fmt.Errorf("failed to parse use limit %s: %w", useLimit, err)
Expand Down Expand Up @@ -640,56 +644,11 @@ func readProcMounts() ([][]string, error) {
return mounts, nil
}

// findMountPointByDevice finds the mount point for a given device path by reading /proc/mounts
// Returns the mount point path if found, empty string if not mounted, and error on failure
func findMountPointByDevice(devicePath string) ([]string, error) {
mounts, err := readProcMounts()
if err != nil {
return nil, err
}

var mountPoints []string
for _, fields := range mounts {
if len(fields) < 2 {
continue
}

mountDevice := fields[0]
mountPoint := fields[1]

// Check if the device matches (handle both direct path and symlink resolution)
if mountDevice == devicePath {
mountPoints = append(mountPoints, mountPoint)
continue
}

// Resolve both paths and compare
resolvedDevicePath, err1 := filepath.EvalSymlinks(devicePath)
resolvedMountDevice, err2 := filepath.EvalSymlinks(mountDevice)

// If both resolve successfully, compare resolved paths
if err1 == nil && err2 == nil {
if resolvedDevicePath == resolvedMountDevice {
mountPoints = append(mountPoints, mountPoint)
continue
}
}

// Also check if one resolves to the other
if err1 == nil && resolvedDevicePath == mountDevice {
mountPoints = append(mountPoints, mountPoint)
continue
}
if err2 == nil && resolvedMountDevice == devicePath {
mountPoints = append(mountPoints, mountPoint)
continue
}
}

return mountPoints, nil
}

func isMountPoint(dir string) (bool, error) {
// Acquire global LVM lock to protect reading /proc/mounts
lvm.RLockLV()
defer lvm.RUnlockLV()

mounts, err := readProcMounts()
if err != nil {
return false, err
Expand All @@ -707,6 +666,10 @@ func isMountPoint(dir string) (bool, error) {
}

func (o *Snapshotter) mkfs(lvName string) error {
// Acquire global LVM lock to protect filesystem operations
lvm.LockLV()
defer lvm.UnlockLV()

devicePath := fmt.Sprintf("/dev/%s/%s", o.lvmVgName, lvName)
// Check if the device exists
if _, err := os.Stat(devicePath); os.IsNotExist(err) {
Expand All @@ -722,34 +685,19 @@ func (o *Snapshotter) mkfs(lvName string) error {
}

func (o *Snapshotter) mountLvm(ctx context.Context, lvName string, path string) error {
_, err := os.Stat(path)
if os.IsNotExist(err) {
if err := os.MkdirAll(path, 0755); err != nil {
return fmt.Errorf("failed to create directory %s: %w", path, err)
}
} else if err != nil {
return fmt.Errorf("failed to stat path %s: %w", path, err)
}
devicePath := fmt.Sprintf("/dev/%s/%s", o.lvmVgName, lvName)
err = syscall.Mount(devicePath, path, "ext4", 0, "")
if err != nil {
if err := lvm.MountVolume(devicePath, path, "ext4", 0, ""); err != nil {
log.G(ctx).WithError(err).WithField("devicePath", devicePath).WithField("path", path).Warn("failed to mount LVM logical volume")
return fmt.Errorf("failed to mount LVM logical volume %s to %s: %w", devicePath, path, err)
}
return nil
}

// unmountLvm unmounts the LVM logical volume
func (o *Snapshotter) unmountLvm(ctx context.Context, path string) error {
isMounted, err := isMountPoint(path)
if err != nil {
return fmt.Errorf("failed to check if path %s is a mount point: %w", path, err)
}
if !isMounted {
log.G(ctx).Infof("Path %s is not mounted, skipping unmount", path)
return nil
}
err = syscall.Unmount(path, 0)
if err != nil {
return fmt.Errorf("failed to unmount path %s: %w", path, err)
if err := lvm.UnmountVolume(path); err != nil {
log.G(ctx).WithError(err).WithField("path", path).Warn("failed to unmount LVM logical volume")
return fmt.Errorf("failed to unmount LVM logical volume %s: %w", path, err)
}
return nil
}
Expand Down Expand Up @@ -816,7 +764,7 @@ func (o *Snapshotter) createSnapshot(ctx context.Context, kind snapshots.Kind, k
// mount point for the snapshot
log.G(ctx).Debug("LVM logical volume name found for content ID:", contentID, "is", lvName)
var isMounted bool
if isMounted, err = isMountPoint(npath); err != nil {
if isMounted, err = lvm.IsMountPoint(npath); err != nil {
return fmt.Errorf("failed to check if path is a mount point: %w", err)
} else if isMounted {
log.G(ctx).Infof("Path %s is already mounted, skipping mount", npath)
Expand Down
28 changes: 14 additions & 14 deletions snapshots/devbox/devbox_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ func TestFindMountPointAndUnmount(t *testing.T) {

// Step 5: Test findMountPointByDevice
t.Logf("Step 5: Testing findMountPointByDevice for %s", devicePath)
foundMountPoints, err := findMountPointByDevice(devicePath)
foundMountPoints, err := lvm.FindMountPointByDevice(devicePath)
if err != nil {
t.Fatalf("findMountPointByDevice failed: %v", err)
}
Expand All @@ -150,7 +150,7 @@ func TestFindMountPointAndUnmount(t *testing.T) {

// Step 7: Verify the mount point is no longer mounted
t.Logf("Step 7: Verifying mount point is no longer mounted")
foundMountPoints, err = findMountPointByDevice(devicePath)
foundMountPoints, err = lvm.FindMountPointByDevice(devicePath)
if err != nil {
t.Fatalf("findMountPointByDevice failed after unmount: %v", err)
}
Expand Down Expand Up @@ -194,7 +194,7 @@ func TestFindMountPointByDevice_UnmountedDevice(t *testing.T) {
devicePath := fmt.Sprintf("/dev/%s/%s", testVGName, lvName)

// Test findMountPointByDevice on an unmounted device
mountPoints, err := findMountPointByDevice(devicePath)
mountPoints, err := lvm.FindMountPointByDevice(devicePath)
if err != nil {
t.Fatalf("findMountPointByDevice failed: %v", err)
}
Expand Down Expand Up @@ -307,7 +307,7 @@ func TestFindMountPointAndUnmount_Concurrent(t *testing.T) {
}()

// Step 5: Test findMountPointByDevice (concurrent access)
foundMountPoints, err := findMountPointByDevice(devicePath)
foundMountPoints, err := lvm.FindMountPointByDevice(devicePath)
if err != nil {
errorChan <- fmt.Errorf("goroutine %d: findMountPointByDevice failed: %w", index, err)
return
Expand All @@ -333,7 +333,7 @@ func TestFindMountPointAndUnmount_Concurrent(t *testing.T) {
mounted = false // Mark as unmounted so defer doesn't try again

// Step 7: Verify the mount point is no longer mounted
foundMountPoints, err = findMountPointByDevice(devicePath)
foundMountPoints, err = lvm.FindMountPointByDevice(devicePath)
if err != nil {
errorChan <- fmt.Errorf("goroutine %d: findMountPointByDevice failed after unmount: %w", index, err)
return
Expand Down Expand Up @@ -661,7 +661,7 @@ func TestFindMountPointByDevice(t *testing.T) {

// Step 2: Test findMountPointByDevice on unmounted device (should return empty)
t.Logf("Step 2: Testing findMountPointByDevice on unmounted device")
mountPoints, err := findMountPointByDevice(devicePath)
mountPoints, err := lvm.FindMountPointByDevice(devicePath)
if err != nil {
t.Fatalf("findMountPointByDevice failed: %v", err)
}
Expand Down Expand Up @@ -704,7 +704,7 @@ func TestFindMountPointByDevice(t *testing.T) {

// Step 6: Test findMountPointByDevice on mounted device
t.Logf("Step 6: Testing findMountPointByDevice on mounted device")
mountPoints, err = findMountPointByDevice(devicePath)
mountPoints, err = lvm.FindMountPointByDevice(devicePath)
if err != nil {
t.Fatalf("findMountPointByDevice failed: %v", err)
}
Expand All @@ -724,7 +724,7 @@ func TestFindMountPointByDevice(t *testing.T) {
t.Fatalf("Failed to unmount %s: %v", expectedMountPoint, err)
}

mountPoints, err = findMountPointByDevice(devicePath)
mountPoints, err = lvm.FindMountPointByDevice(devicePath)
if err != nil {
t.Fatalf("findMountPointByDevice failed after unmount: %v", err)
}
Expand Down Expand Up @@ -813,7 +813,7 @@ func TestFindMountPointByDevice_MultipleMountPoints(t *testing.T) {
mountedPoints = append(mountedPoints, mountPoint)

// Verify it's mounted after each mount
mountPoints, err := findMountPointByDevice(devicePath)
mountPoints, err := lvm.FindMountPointByDevice(devicePath)
if err != nil {
t.Fatalf("findMountPointByDevice failed after mounting to %s: %v", mountPoint, err)
}
Expand All @@ -833,7 +833,7 @@ func TestFindMountPointByDevice_MultipleMountPoints(t *testing.T) {

// Step 6: Test findMountPointByDevice finds all mount points
t.Logf("Step 6: Testing findMountPointByDevice finds all %d mount points", numMountPoints)
foundMountPoints, err := findMountPointByDevice(devicePath)
foundMountPoints, err := lvm.FindMountPointByDevice(devicePath)
if err != nil {
t.Fatalf("findMountPointByDevice failed: %v", err)
}
Expand Down Expand Up @@ -951,7 +951,7 @@ func TestCleanupUnmountAllMountPoints(t *testing.T) {

// Step 5: Verify all mount points are mounted
t.Logf("Step 5: Verifying all %d mount points are mounted", numMountPoints)
foundMountPoints, err := findMountPointByDevice(devicePath)
foundMountPoints, err := lvm.FindMountPointByDevice(devicePath)
if err != nil {
t.Fatalf("findMountPointByDevice failed: %v", err)
}
Expand All @@ -967,7 +967,7 @@ func TestCleanupUnmountAllMountPoints(t *testing.T) {
// Simulate the cleanup logic from cleanupDirectories
for _, lvNameToCleanup := range removedLvNames {
devicePathToCheck := fmt.Sprintf("/dev/%s/%s", snapshotter.lvmVgName, lvNameToCleanup)
mountPointsToUnmount, err := findMountPointByDevice(devicePathToCheck)
mountPointsToUnmount, err := lvm.FindMountPointByDevice(devicePathToCheck)
if err != nil {
t.Fatalf("Failed to find mount points for LV %s: %v", lvNameToCleanup, err)
}
Expand All @@ -992,7 +992,7 @@ func TestCleanupUnmountAllMountPoints(t *testing.T) {

// Step 7: Verify all mount points are unmounted
t.Logf("Step 7: Verifying all mount points are unmounted")
foundMountPoints, err = findMountPointByDevice(devicePath)
foundMountPoints, err = lvm.FindMountPointByDevice(devicePath)
if err != nil {
t.Fatalf("findMountPointByDevice failed after unmount: %v", err)
}
Expand All @@ -1004,7 +1004,7 @@ func TestCleanupUnmountAllMountPoints(t *testing.T) {

// Step 8: Verify device can be removed (no mount points should allow removal)
t.Logf("Step 8: Verifying device has no mount points (can be removed)")
remainingMountPoints, err := findMountPointByDevice(devicePath)
remainingMountPoints, err := lvm.FindMountPointByDevice(devicePath)
if err != nil {
t.Fatalf("findMountPointByDevice failed: %v", err)
}
Expand Down
Loading