From d61d880f95c2b8020f5bfcd42aa9c59cadab5a69 Mon Sep 17 00:00:00 2001 From: Brian Kanya Date: Tue, 25 Nov 2025 13:46:19 +0100 Subject: [PATCH 01/10] Re-mount volumes on a restart or update #72 --- pkg/rclone/driver.go | 26 ++++++- pkg/rclone/nodeserver.go | 158 +++++++++++++++++++++++++++++++++++++++ pkg/rclone/rclone.go | 9 ++- 3 files changed, 187 insertions(+), 6 deletions(-) diff --git a/pkg/rclone/driver.go b/pkg/rclone/driver.go index 60fd38ba..e4f771ee 100644 --- a/pkg/rclone/driver.go +++ b/pkg/rclone/driver.go @@ -78,14 +78,26 @@ func NewNodeServer(csiDriver *csicommon.CSIDriver, cacheDir string, cacheSize st } rcloneOps := NewRclone(kubeClient, rclonePort, cacheDir, cacheSize) - return &nodeServer{ + // Use kubelet plugin directory for state persistence + stateFile := "/var/lib/kubelet/plugins/csi-rclone/mounted_volumes.json" + + ns := &nodeServer{ DefaultNodeServer: csicommon.NewDefaultNodeServer(csiDriver), mounter: &mount.SafeFormatAndMount{ Interface: mount.New(""), Exec: utilexec.New(), }, - RcloneOps: rcloneOps, - }, nil + RcloneOps: rcloneOps, + mountedVolumes: make(map[string]*MountedVolume), + stateFile: stateFile, + } + + // Load persisted state on startup + if err := ns.loadState(); err != nil { + klog.Warningf("Failed to load persisted volume state: %v", err) + } + + return ns, nil } func NewControllerServer(csiDriver *csicommon.CSIDriver) *controllerServer { @@ -139,7 +151,13 @@ func (d *Driver) Run() error { ) d.server = s if d.ns != nil && d.ns.RcloneOps != nil { - return d.ns.RcloneOps.Run() + onDaemonReady := func() error { + if d.ns != nil { + return d.ns.remountTrackedVolumes(context.Background()) + } + return nil + } + return d.ns.RcloneOps.Run(onDaemonReady) } s.Wait() return nil diff --git a/pkg/rclone/nodeserver.go b/pkg/rclone/nodeserver.go index ef035016..b31def14 100644 --- a/pkg/rclone/nodeserver.go +++ b/pkg/rclone/nodeserver.go @@ -7,10 +7,13 @@ package rclone import ( "bytes" + "encoding/json" "errors" "fmt" "os" + "path/filepath" "strings" + "sync" "time" "gopkg.in/ini.v1" @@ -37,6 +40,23 @@ type nodeServer struct { *csicommon.DefaultNodeServer mounter *mount.SafeFormatAndMount RcloneOps Operations + + // Track mounted volumes for automatic remounting + mountedVolumes map[string]*MountedVolume + mutex sync.RWMutex + stateFile string +} + +type MountedVolume struct { + VolumeId string + TargetPath string + Remote string + RemotePath string + ConfigData string + ReadOnly bool + Parameters map[string]string + SecretName string + SecretNamespace string } // Mounting Volume (Preparation) @@ -141,6 +161,10 @@ func (ns *nodeServer) NodePublishVolume(ctx context.Context, req *csi.NodePublis } return nil, status.Error(codes.Internal, err.Error()) } + + // Track the mounted volume for automatic remounting + ns.trackMountedVolume(volumeId, targetPath, remote, remotePath, configData, readOnly, flags, secretName, secretNamespace) + // err = ns.WaitForMountAvailable(targetPath) // if err != nil { // return nil, status.Error(codes.Internal, err.Error()) @@ -323,6 +347,10 @@ func (ns *nodeServer) NodeUnpublishVolume(ctx context.Context, req *csi.NodeUnpu if err := ns.RcloneOps.Unmount(ctx, req.GetVolumeId(), targetPath); err != nil { klog.Warningf("Unmounting volume failed: %s", err) } + + // Remove the volume from tracking + ns.removeTrackedVolume(req.GetVolumeId()) + mount.CleanupMountPoint(req.GetTargetPath(), ns.mounter, false) return &csi.NodeUnpublishVolumeResponse{}, nil } @@ -344,6 +372,82 @@ func (*nodeServer) NodeExpandVolume(ctx context.Context, req 
*csi.NodeExpandVolu return nil, status.Errorf(codes.Unimplemented, "method NodeExpandVolume not implemented") } +// Track mounted volume for automatic remounting +func (ns *nodeServer) trackMountedVolume(volumeId, targetPath, remote, remotePath, configData string, readOnly bool, parameters map[string]string, secretName, secretNamespace string) { + ns.mutex.Lock() + defer ns.mutex.Unlock() + + ns.mountedVolumes[volumeId] = &MountedVolume{ + VolumeId: volumeId, + TargetPath: targetPath, + Remote: remote, + RemotePath: remotePath, + ConfigData: configData, + ReadOnly: readOnly, + Parameters: parameters, + SecretName: secretName, + SecretNamespace: secretNamespace, + } + klog.Infof("Tracked mounted volume %s at path %s", volumeId, targetPath) + + if err := ns.persistState(); err != nil { + klog.Errorf("Failed to persist volume state: %v", err) + } +} + +// Remove tracked volume when unmounted +func (ns *nodeServer) removeTrackedVolume(volumeId string) { + ns.mutex.Lock() + defer ns.mutex.Unlock() + + delete(ns.mountedVolumes, volumeId) + klog.Infof("Removed tracked volume %s", volumeId) + + if err := ns.persistState(); err != nil { + klog.Errorf("Failed to persist volume state: %v", err) + } +} + +// Automatically remount all tracked volumes after daemon restart +func (ns *nodeServer) remountTrackedVolumes(ctx context.Context) error { + ns.mutex.RLock() + defer ns.mutex.RUnlock() + + if len(ns.mountedVolumes) == 0 { + klog.Info("No tracked volumes to remount") + return nil + } + + klog.Infof("Remounting %d tracked volumes", len(ns.mountedVolumes)) + + for volumeId, mv := range ns.mountedVolumes { + klog.Infof("Remounting volume %s to %s", volumeId, mv.TargetPath) + + // Create the mount directory if it doesn't exist + if err := os.MkdirAll(mv.TargetPath, 0750); err != nil { + klog.Errorf("Failed to create mount directory %s: %v", mv.TargetPath, err) + continue + } + + // Remount the volume + rcloneVol := &RcloneVolume{ + ID: mv.VolumeId, + Remote: mv.Remote, + RemotePath: mv.RemotePath, + } + + err := ns.RcloneOps.Mount(ctx, rcloneVol, mv.TargetPath, mv.ConfigData, mv.ReadOnly, mv.Parameters) + if err != nil { + klog.Errorf("Failed to remount volume %s: %v", volumeId, err) + // Don't return error here - continue with other volumes + } else { + klog.Infof("Successfully remounted volume %s", volumeId) + } + } + + return nil +} + func (ns *nodeServer) WaitForMountAvailable(mountpoint string) error { for { select { @@ -357,3 +461,57 @@ func (ns *nodeServer) WaitForMountAvailable(mountpoint string) error { } } } + +// Persist volume state to disk +func (ns *nodeServer) persistState() error { + ns.mutex.RLock() + defer ns.mutex.RUnlock() + + if ns.stateFile == "" { + return nil + } + + data, err := json.Marshal(ns.mountedVolumes) + if err != nil { + return fmt.Errorf("failed to marshal volume state: %v", err) + } + + if err := os.MkdirAll(filepath.Dir(ns.stateFile), 0755); err != nil { + return fmt.Errorf("failed to create state directory: %v", err) + } + + if err := os.WriteFile(ns.stateFile, data, 0600); err != nil { + return fmt.Errorf("failed to write state file: %v", err) + } + + klog.Infof("Persisted volume state to %s", ns.stateFile) + return nil +} + +// Load volume state from disk +func (ns *nodeServer) loadState() error { + ns.mutex.Lock() + defer ns.mutex.Unlock() + + if ns.stateFile == "" { + return nil + } + + data, err := os.ReadFile(ns.stateFile) + if err != nil { + if os.IsNotExist(err) { + klog.Info("No persisted volume state found, starting fresh") + return nil + } + return 
fmt.Errorf("failed to read state file: %v", err) + } + + var volumes map[string]*MountedVolume + if err := json.Unmarshal(data, &volumes); err != nil { + return fmt.Errorf("failed to unmarshal volume state: %v", err) + } + + ns.mountedVolumes = volumes + klog.Infof("Loaded %d tracked volumes from %s", len(ns.mountedVolumes), ns.stateFile) + return nil +} diff --git a/pkg/rclone/rclone.go b/pkg/rclone/rclone.go index 9b325084..264b2719 100644 --- a/pkg/rclone/rclone.go +++ b/pkg/rclone/rclone.go @@ -34,7 +34,7 @@ type Operations interface { Unmount(ctx context.Context, volumeId string, targetPath string) error GetVolumeById(ctx context.Context, volumeId string) (*RcloneVolume, error) Cleanup() error - Run() error + Run(onDaemonReady func() error) error } type Rclone struct { @@ -472,11 +472,16 @@ func (r *Rclone) start_daemon() error { return nil } -func (r *Rclone) Run() error { +func (r *Rclone) Run(onDaemonReady func() error) error { err := r.start_daemon() if err != nil { return err } + if onDaemonReady != nil { + if err := onDaemonReady(); err != nil { + klog.Warningf("Error in onDaemonReady callback: %v", err) + } + } // blocks until the rclone daemon is stopped return r.daemonCmd.Wait() } From 2a46ac7fa7a29bb65c94ac97020345e1d550480b Mon Sep 17 00:00:00 2001 From: Lionel Sambuc Date: Mon, 1 Dec 2025 05:47:40 +0000 Subject: [PATCH 02/10] fix: address comments and apply fixes (#77) * Ensure the mutex is not copied, even when the nodeServer is copied by storing a pointer to the mutex, instead of the mutex itself. * Use Mutex instead of RWMutex, as having two readers of the variable at the same time means we are going to write the state at the same time, corrupting the state file on storage. * Mutex / RWMutex are not recursive / re-entrant in Go, so in two cases we do not call `Unlock()` through `defer` as `persistState()` also takes the lock. * As a rule of thumb, Locking a Mutex should be as close as possible to the resource requiring it, to minimize the size of the critical section / the time spent holding the lock. * Remount each volume in a goroutine, with a rate limits of the number of active routine to prevent contention, and keep under control startup times. 
--- .devcontainer/rclone/install.sh | 3 + pkg/rclone/driver.go | 17 +++- pkg/rclone/nodeserver.go | 168 +++++++++++++++++++------------- 3 files changed, 116 insertions(+), 72 deletions(-) diff --git a/.devcontainer/rclone/install.sh b/.devcontainer/rclone/install.sh index c9a86242..a972f212 100644 --- a/.devcontainer/rclone/install.sh +++ b/.devcontainer/rclone/install.sh @@ -19,3 +19,6 @@ rm -rf /tmp/rclone # Fix the $GOPATH folder chown -R "${USERNAME}:golang" /go chmod -R g+r+w /go + +# Make sure the default folders exists +mkdir -p /var/lib/kubelet/plugins/csi-rclone/ \ No newline at end of file diff --git a/pkg/rclone/driver.go b/pkg/rclone/driver.go index e4f771ee..49423a78 100644 --- a/pkg/rclone/driver.go +++ b/pkg/rclone/driver.go @@ -1,9 +1,11 @@ package rclone import ( + "context" "fmt" "net" "os" + "path/filepath" "sync" "github.com/SwissDataScienceCenter/csi-rclone/pkg/kube" @@ -87,13 +89,22 @@ func NewNodeServer(csiDriver *csicommon.CSIDriver, cacheDir string, cacheSize st Interface: mount.New(""), Exec: utilexec.New(), }, - RcloneOps: rcloneOps, - mountedVolumes: make(map[string]*MountedVolume), + RcloneOps: rcloneOps, + mountedVolumes: make(map[string]MountedVolume), + mutex: &sync.Mutex{}, stateFile: stateFile, } + // Ensure the folder exists + if err = os.MkdirAll(filepath.Dir(ns.stateFile), 0755); err != nil { + return nil, fmt.Errorf("failed to create state directory: %v", err) + } + // Load persisted state on startup - if err := ns.loadState(); err != nil { + ns.mutex.Lock() + defer ns.mutex.Unlock() + + if ns.mountedVolumes, err = readVolumeMap(ns.stateFile); err != nil { klog.Warningf("Failed to load persisted volume state: %v", err) } diff --git a/pkg/rclone/nodeserver.go b/pkg/rclone/nodeserver.go index b31def14..0d26791a 100644 --- a/pkg/rclone/nodeserver.go +++ b/pkg/rclone/nodeserver.go @@ -11,7 +11,7 @@ import ( "errors" "fmt" "os" - "path/filepath" + "runtime" "strings" "sync" "time" @@ -42,20 +42,20 @@ type nodeServer struct { RcloneOps Operations // Track mounted volumes for automatic remounting - mountedVolumes map[string]*MountedVolume - mutex sync.RWMutex + mountedVolumes map[string]MountedVolume + mutex *sync.Mutex stateFile string } type MountedVolume struct { - VolumeId string - TargetPath string - Remote string - RemotePath string - ConfigData string - ReadOnly bool - Parameters map[string]string - SecretName string + VolumeId string + TargetPath string + Remote string + RemotePath string + ConfigData string + ReadOnly bool + Parameters map[string]string + SecretName string SecretNamespace string } @@ -377,20 +377,20 @@ func (ns *nodeServer) trackMountedVolume(volumeId, targetPath, remote, remotePat ns.mutex.Lock() defer ns.mutex.Unlock() - ns.mountedVolumes[volumeId] = &MountedVolume{ - VolumeId: volumeId, - TargetPath: targetPath, - Remote: remote, - RemotePath: remotePath, - ConfigData: configData, - ReadOnly: readOnly, - Parameters: parameters, - SecretName: secretName, + ns.mountedVolumes[volumeId] = MountedVolume{ + VolumeId: volumeId, + TargetPath: targetPath, + Remote: remote, + RemotePath: remotePath, + ConfigData: configData, + ReadOnly: readOnly, + Parameters: parameters, + SecretName: secretName, SecretNamespace: secretNamespace, } klog.Infof("Tracked mounted volume %s at path %s", volumeId, targetPath) - if err := ns.persistState(); err != nil { + if err := writeVolumeMap(ns.stateFile, ns.mountedVolumes); err != nil { klog.Errorf("Failed to persist volume state: %v", err) } } @@ -403,15 +403,20 @@ func (ns *nodeServer) 
removeTrackedVolume(volumeId string) {
 	delete(ns.mountedVolumes, volumeId)
 	klog.Infof("Removed tracked volume %s", volumeId)
 
-	if err := ns.persistState(); err != nil {
+	if err := writeVolumeMap(ns.stateFile, ns.mountedVolumes); err != nil {
 		klog.Errorf("Failed to persist volume state: %v", err)
 	}
 }
 
 // Automatically remount all tracked volumes after daemon restart
 func (ns *nodeServer) remountTrackedVolumes(ctx context.Context) error {
-	ns.mutex.RLock()
-	defer ns.mutex.RUnlock()
+	type mountResult struct {
+		volumeID string
+		err      error
+	}
+
+	ns.mutex.Lock()
+	defer ns.mutex.Unlock()
 
 	if len(ns.mountedVolumes) == 0 {
 		klog.Info("No tracked volumes to remount")
@@ -420,32 +425,67 @@ func (ns *nodeServer) remountTrackedVolumes(ctx context.Context) error {
 
 	klog.Infof("Remounting %d tracked volumes", len(ns.mountedVolumes))
 
+	// Limit the number of active workers to the number of CPU threads (arbitrarily chosen)
+	limits := make(chan bool, runtime.GOMAXPROCS(0))
+	defer close(limits)
+
+	volumesCount := len(ns.mountedVolumes)
+	results := make(chan mountResult, volumesCount)
+	defer close(results)
+
+	ctxWithTimeout, cancel := context.WithTimeout(ctx, 60*time.Second)
+	defer cancel()
+
 	for volumeId, mv := range ns.mountedVolumes {
-		klog.Infof("Remounting volume %s to %s", volumeId, mv.TargetPath)
+		go func() {
+			limits <- true // block until there is a free slot in the queue
+			defer func() {
+				<-limits // free a slot in the queue when we exit
+			}()
 
-		// Create the mount directory if it doesn't exist
-		if err := os.MkdirAll(mv.TargetPath, 0750); err != nil {
-			klog.Errorf("Failed to create mount directory %s: %v", mv.TargetPath, err)
-			continue
-		}
+			ctxWithMountTimeout, cancel := context.WithTimeout(ctxWithTimeout, 30*time.Second)
+			defer cancel()
 
-		// Remount the volume
-		rcloneVol := &RcloneVolume{
-			ID:         mv.VolumeId,
-			Remote:     mv.Remote,
-			RemotePath: mv.RemotePath,
-		}
+			klog.Infof("Remounting volume %s to %s", volumeId, mv.TargetPath)
 
-		err := ns.RcloneOps.Mount(ctx, rcloneVol, mv.TargetPath, mv.ConfigData, mv.ReadOnly, mv.Parameters)
-		if err != nil {
-			klog.Errorf("Failed to remount volume %s: %v", volumeId, err)
-			// Don't return error here - continue with other volumes
-		} else {
-			klog.Infof("Successfully remounted volume %s", volumeId)
-		}
+			// Create the mount directory if it doesn't exist
+			var err error
+			if err = os.MkdirAll(mv.TargetPath, 0750); err != nil {
+				klog.Errorf("Failed to create mount directory %s: %v", mv.TargetPath, err)
+			} else {
+				// Remount the volume
+				rcloneVol := &RcloneVolume{
+					ID:         mv.VolumeId,
+					Remote:     mv.Remote,
+					RemotePath: mv.RemotePath,
+				}
+
+				err = ns.RcloneOps.Mount(ctxWithMountTimeout, rcloneVol, mv.TargetPath, mv.ConfigData, mv.ReadOnly, mv.Parameters)
+			}
+
+			results <- mountResult{volumeId, err}
+		}()
 	}
 
-	return nil
+	for {
+		select {
+		case result := <-results:
+			volumesCount--
+			if result.err != nil {
+				klog.Errorf("Failed to remount volume %s: %v", result.volumeID, result.err)
+				// Don't return an error here; continue with the other volumes so a single failed mount does not block all users.
+				delete(ns.mountedVolumes, result.volumeID)
+				// Should we keep it on disk? This will be lost on the first new mount, which will overwrite the file.
+ } else { + klog.Infof("Successfully remounted volume %s", result.volumeID) + } + if volumesCount == 0 { + return nil + } + case <-ctxWithTimeout.Done(): + return ctxWithTimeout.Err() + } + } } func (ns *nodeServer) WaitForMountAvailable(mountpoint string) error { @@ -463,55 +503,45 @@ func (ns *nodeServer) WaitForMountAvailable(mountpoint string) error { } // Persist volume state to disk -func (ns *nodeServer) persistState() error { - ns.mutex.RLock() - defer ns.mutex.RUnlock() - - if ns.stateFile == "" { +func writeVolumeMap(filename string, volumes map[string]MountedVolume) error { + if filename == "" { return nil } - data, err := json.Marshal(ns.mountedVolumes) + data, err := json.Marshal(volumes) if err != nil { return fmt.Errorf("failed to marshal volume state: %v", err) } - if err := os.MkdirAll(filepath.Dir(ns.stateFile), 0755); err != nil { - return fmt.Errorf("failed to create state directory: %v", err) - } - - if err := os.WriteFile(ns.stateFile, data, 0600); err != nil { + if err := os.WriteFile(filename, data, 0600); err != nil { return fmt.Errorf("failed to write state file: %v", err) } - klog.Infof("Persisted volume state to %s", ns.stateFile) + klog.Infof("Persisted volume state to %s", filename) return nil } // Load volume state from disk -func (ns *nodeServer) loadState() error { - ns.mutex.Lock() - defer ns.mutex.Unlock() +func readVolumeMap(filename string) (map[string]MountedVolume, error) { + volumes := make(map[string]MountedVolume) - if ns.stateFile == "" { - return nil + if filename == "" { + return volumes, nil } - data, err := os.ReadFile(ns.stateFile) + data, err := os.ReadFile(filename) if err != nil { if os.IsNotExist(err) { klog.Info("No persisted volume state found, starting fresh") - return nil + return volumes, nil } - return fmt.Errorf("failed to read state file: %v", err) + return volumes, fmt.Errorf("failed to read state file: %v", err) } - var volumes map[string]*MountedVolume if err := json.Unmarshal(data, &volumes); err != nil { - return fmt.Errorf("failed to unmarshal volume state: %v", err) + return nil, fmt.Errorf("failed to unmarshal volume state: %v", err) } - ns.mountedVolumes = volumes - klog.Infof("Loaded %d tracked volumes from %s", len(ns.mountedVolumes), ns.stateFile) - return nil + klog.Infof("Loaded %d tracked volumes from %s", len(volumes), filename) + return volumes, nil } From 9023632f1df55cb9b84b5d43eb8adc7bb1c8daae Mon Sep 17 00:00:00 2001 From: Lionel Sambuc Date: Tue, 9 Dec 2025 11:05:41 +0100 Subject: [PATCH 03/10] sambuc/feat merge restart pr 2 (#78) * Refactor * CLI code * servers configurations * ControllerServer * NodeServer * Cleanup warnings in: * controllerserver.go * main.go * Use tags to specify json field names on permanent storage * Make DriverVersion a constant * Protect activeVolume while reading for the metrics * Compute len(ns.mountedVolumes) once --- cmd/csi-rclone-plugin/main.go | 109 +++-------------- pkg/rclone/controllerserver.go | 105 ++++++++++++---- pkg/rclone/driver.go | 183 +++++----------------------- pkg/rclone/nodeserver.go | 216 ++++++++++++++++++++++++++++----- test/sanity_test.go | 27 +++-- 5 files changed, 335 insertions(+), 305 deletions(-) diff --git a/cmd/csi-rclone-plugin/main.go b/cmd/csi-rclone-plugin/main.go index 2a238a7d..73160e11 100644 --- a/cmd/csi-rclone-plugin/main.go +++ b/cmd/csi-rclone-plugin/main.go @@ -13,22 +13,21 @@ import ( "github.com/SwissDataScienceCenter/csi-rclone/pkg/rclone" "github.com/spf13/cobra" "k8s.io/klog" - mountUtils "k8s.io/mount-utils" ) -var ( - endpoint 
string - nodeID string - cacheDir string - cacheSize string - meters []metrics.Observable -) +func exitOnError(err error) { + if err != nil { + klog.Error(err.Error()) + os.Exit(1) + } +} func init() { - flag.Set("logtostderr", "true") + exitOnError(flag.Set("logtostderr", "true")) } func main() { + var meters []metrics.Observable metricsServerConfig := metrics.ServerConfig{ Host: "localhost", Port: 9090, @@ -37,6 +36,8 @@ func main() { ShutdownTimeout: 5 * time.Second, Enabled: false, } + nodeServerConfig := rclone.NodeServerConfig{} + controllerServerConfig := rclone.ControllerServerConfig{} root := &cobra.Command{ Use: "rclone", @@ -48,34 +49,10 @@ func main() { Use: "run", Short: "Start the CSI driver.", } - root.AddCommand(runCmd) + exitOnError(nodeServerConfig.CommandLineParameters(runCmd, &meters)) + exitOnError(controllerServerConfig.CommandLineParameters(runCmd, &meters)) - runNode := &cobra.Command{ - Use: "node", - Short: "Start the CSI driver node service - expected to run in a daemonset on every node.", - Run: func(cmd *cobra.Command, args []string) { - handleNode() - }, - } - runNode.PersistentFlags().StringVar(&nodeID, "nodeid", "", "node id") - runNode.MarkPersistentFlagRequired("nodeid") - runNode.PersistentFlags().StringVar(&endpoint, "endpoint", "", "CSI endpoint") - runNode.MarkPersistentFlagRequired("endpoint") - runNode.PersistentFlags().StringVar(&cacheDir, "cachedir", "", "cache dir") - runNode.PersistentFlags().StringVar(&cacheSize, "cachesize", "", "cache size") - runCmd.AddCommand(runNode) - runController := &cobra.Command{ - Use: "controller", - Short: "Start the CSI driver controller.", - Run: func(cmd *cobra.Command, args []string) { - handleController() - }, - } - runController.PersistentFlags().StringVar(&nodeID, "nodeid", "", "node id") - runController.MarkPersistentFlagRequired("nodeid") - runController.PersistentFlags().StringVar(&endpoint, "endpoint", "", "CSI endpoint") - runController.MarkPersistentFlagRequired("endpoint") - runCmd.AddCommand(runController) + root.AddCommand(runCmd) versionCmd := &cobra.Command{ Use: "version", @@ -86,7 +63,7 @@ func main() { } root.AddCommand(versionCmd) - root.ParseFlags(os.Args[1:]) + exitOnError(root.ParseFlags(os.Args[1:])) if metricsServerConfig.Enabled { // Gracefully exit the metrics background servers @@ -97,63 +74,7 @@ func main() { go metricsServer.ListenAndServe() } - if err := root.Execute(); err != nil { - fmt.Fprintf(os.Stderr, "%s", err.Error()) - os.Exit(1) - } + exitOnError(root.Execute()) os.Exit(0) } - -func handleNode() { - err := unmountOldVols() - if err != nil { - klog.Warningf("There was an error when trying to unmount old volumes: %v", err) - } - d := rclone.NewDriver(nodeID, endpoint) - ns, err := rclone.NewNodeServer(d.CSIDriver, cacheDir, cacheSize) - if err != nil { - panic(err) - } - meters = append(meters, ns.Metrics()...) - d.WithNodeServer(ns) - err = d.Run() - if err != nil { - panic(err) - } -} - -func handleController() { - d := rclone.NewDriver(nodeID, endpoint) - cs := rclone.NewControllerServer(d.CSIDriver) - meters = append(meters, cs.Metrics()...) 
- d.WithControllerServer(cs) - err := d.Run() - if err != nil { - panic(err) - } -} - -// unmountOldVols is used to unmount volumes after a restart on a node -func unmountOldVols() error { - const mountType = "fuse.rclone" - const unmountTimeout = time.Second * 5 - klog.Info("Checking for existing mounts") - mounter := mountUtils.Mounter{} - mounts, err := mounter.List() - if err != nil { - return err - } - for _, mount := range mounts { - if mount.Type != mountType { - continue - } - err := mounter.UnmountWithForce(mount.Path, unmountTimeout) - if err != nil { - klog.Warningf("Failed to unmount %s because of %v.", mount.Path, err) - continue - } - klog.Infof("Sucessfully unmounted %s", mount.Path) - } - return nil -} diff --git a/pkg/rclone/controllerserver.go b/pkg/rclone/controllerserver.go index 4e00dd79..1ae3f06a 100644 --- a/pkg/rclone/controllerserver.go +++ b/pkg/rclone/controllerserver.go @@ -3,10 +3,13 @@ package rclone import ( + "context" "sync" + "github.com/SwissDataScienceCenter/csi-rclone/pkg/metrics" "github.com/container-storage-interface/spec/lib/go/csi" - "golang.org/x/net/context" + "github.com/prometheus/client_golang/prometheus" + "github.com/spf13/cobra" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" "k8s.io/klog" @@ -16,13 +19,66 @@ import ( const secretAnnotationName = "csi-rclone.dev/secretName" -type controllerServer struct { +type ControllerServerConfig struct{ DriverConfig } + +type ControllerServer struct { *csicommon.DefaultControllerServer - active_volumes map[string]int64 - mutex sync.RWMutex + activeVolumes map[string]int64 + mutex *sync.RWMutex +} + +func NewControllerServer(csiDriver *csicommon.CSIDriver) *ControllerServer { + return &ControllerServer{ + DefaultControllerServer: csicommon.NewDefaultControllerServer(csiDriver), + activeVolumes: map[string]int64{}, + mutex: &sync.RWMutex{}, + } +} + +func (cs *ControllerServer) metrics() []metrics.Observable { + var meters []metrics.Observable + + meter := prometheus.NewGauge(prometheus.GaugeOpts{ + Name: "csi_rclone_active_volume_count", + Help: "Number of active (Mounted) volumes.", + }) + meters = append(meters, + func() { + cs.mutex.RLock() + defer cs.mutex.RUnlock() + meter.Set(float64(len(cs.activeVolumes))) + }, + ) + prometheus.MustRegister(meter) + + return meters } -func (cs *controllerServer) ValidateVolumeCapabilities(ctx context.Context, req *csi.ValidateVolumeCapabilitiesRequest) (*csi.ValidateVolumeCapabilitiesResponse, error) { +func (config *ControllerServerConfig) CommandLineParameters(runCmd *cobra.Command, meters *[]metrics.Observable) error { + runController := &cobra.Command{ + Use: "controller", + Short: "Start the CSI driver controller.", + RunE: func(cmd *cobra.Command, args []string) error { + return Run(context.Background(), + &config.DriverConfig, + func(csiDriver *csicommon.CSIDriver) (csi.ControllerServer, csi.NodeServer, error) { + cs := NewControllerServer(csiDriver) + *meters = append(*meters, cs.metrics()...) 
+ return cs, nil, nil + }, + func(_ context.Context) error { return nil }, + ) + }, + } + if err := config.DriverConfig.CommandLineParameters(runController); err != nil { + return err + } + + runCmd.AddCommand(runController) + return nil +} + +func (cs *ControllerServer) ValidateVolumeCapabilities(_ context.Context, req *csi.ValidateVolumeCapabilitiesRequest) (*csi.ValidateVolumeCapabilitiesResponse, error) { volId := req.GetVolumeId() if len(volId) == 0 { return nil, status.Error(codes.InvalidArgument, "ValidateVolumeCapabilities must be provided volume id") @@ -31,9 +87,9 @@ func (cs *controllerServer) ValidateVolumeCapabilities(ctx context.Context, req return nil, status.Error(codes.InvalidArgument, "ValidateVolumeCapabilities without capabilities") } - cs.mutex.Lock() - defer cs.mutex.Unlock() - if _, ok := cs.active_volumes[volId]; !ok { + cs.mutex.RLock() + defer cs.mutex.RUnlock() + if _, ok := cs.activeVolumes[volId]; !ok { return nil, status.Errorf(codes.NotFound, "Volume %s not found", volId) } return &csi.ValidateVolumeCapabilitiesResponse{ @@ -45,18 +101,18 @@ func (cs *controllerServer) ValidateVolumeCapabilities(ctx context.Context, req }, nil } -// Attaching Volume -func (cs *controllerServer) ControllerPublishVolume(ctx context.Context, req *csi.ControllerPublishVolumeRequest) (*csi.ControllerPublishVolumeResponse, error) { +// ControllerPublishVolume Attaching Volume +func (cs *ControllerServer) ControllerPublishVolume(_ context.Context, _ *csi.ControllerPublishVolumeRequest) (*csi.ControllerPublishVolumeResponse, error) { return nil, status.Errorf(codes.Unimplemented, "method ControllerPublishVolume not implemented") } -// Detaching Volume -func (cs *controllerServer) ControllerUnpublishVolume(ctx context.Context, req *csi.ControllerUnpublishVolumeRequest) (*csi.ControllerUnpublishVolumeResponse, error) { +// ControllerUnpublishVolume Detaching Volume +func (cs *ControllerServer) ControllerUnpublishVolume(_ context.Context, _ *csi.ControllerUnpublishVolumeRequest) (*csi.ControllerUnpublishVolumeResponse, error) { return nil, status.Errorf(codes.Unimplemented, "method ControllerUnpublishVolume not implemented") } -// Provisioning Volumes -func (cs *controllerServer) CreateVolume(ctx context.Context, req *csi.CreateVolumeRequest) (*csi.CreateVolumeResponse, error) { +// CreateVolume Provisioning Volumes +func (cs *ControllerServer) CreateVolume(ctx context.Context, req *csi.CreateVolumeRequest) (*csi.CreateVolumeResponse, error) { klog.Infof("ControllerCreateVolume: called with args %+v", *req) volumeName := req.GetName() if len(volumeName) == 0 { @@ -70,18 +126,18 @@ func (cs *controllerServer) CreateVolume(ctx context.Context, req *csi.CreateVol // we don't use the size as it makes no sense for rclone. 
but csi drivers should succeed if // called twice with the same capacity for the same volume and fail if called twice with // differing capacity, so we need to remember it - volSizeBytes := int64(req.GetCapacityRange().GetRequiredBytes()) + volSizeBytes := req.GetCapacityRange().GetRequiredBytes() cs.mutex.Lock() defer cs.mutex.Unlock() - if val, ok := cs.active_volumes[volumeName]; ok && val != volSizeBytes { + if val, ok := cs.activeVolumes[volumeName]; ok && val != volSizeBytes { return nil, status.Errorf(codes.AlreadyExists, "Volume operation already exists for volume %s", volumeName) } - cs.active_volumes[volumeName] = volSizeBytes + cs.activeVolumes[volumeName] = volSizeBytes // See https://github.com/kubernetes-csi/external-provisioner/blob/v5.1.0/pkg/controller/controller.go#L75 // on how parameters from the persistent volume are parsed // We have to pass the secret name and namespace into the context so that the node server can use them - // The external provisioner uses the secret name and namespace but it does not pass them into the request, + // The external provisioner uses the secret name and namespace, but it does not pass them into the request, // so we read the PVC here to extract them ourselves because we may need them in the node server for decoding secrets. pvcName, pvcNameFound := req.Parameters["csi.storage.k8s.io/pvc/name"] pvcNamespace, pvcNamespaceFound := req.Parameters["csi.storage.k8s.io/pvc/namespace"] @@ -114,29 +170,28 @@ func (cs *controllerServer) CreateVolume(ctx context.Context, req *csi.CreateVol } -// Delete Volume -func (cs *controllerServer) DeleteVolume(ctx context.Context, req *csi.DeleteVolumeRequest) (*csi.DeleteVolumeResponse, error) { +func (cs *ControllerServer) DeleteVolume(_ context.Context, req *csi.DeleteVolumeRequest) (*csi.DeleteVolumeResponse, error) { volId := req.GetVolumeId() if len(volId) == 0 { - return nil, status.Error(codes.InvalidArgument, "DeteleVolume must be provided volume id") + return nil, status.Error(codes.InvalidArgument, "DeleteVolume must be provided volume id") } cs.mutex.Lock() defer cs.mutex.Unlock() - delete(cs.active_volumes, volId) + delete(cs.activeVolumes, volId) return &csi.DeleteVolumeResponse{}, nil } -func (*controllerServer) ControllerExpandVolume(ctx context.Context, req *csi.ControllerExpandVolumeRequest) (*csi.ControllerExpandVolumeResponse, error) { +func (*ControllerServer) ControllerExpandVolume(_ context.Context, _ *csi.ControllerExpandVolumeRequest) (*csi.ControllerExpandVolumeResponse, error) { return nil, status.Errorf(codes.Unimplemented, "method ControllerExpandVolume not implemented") } -func (cs *controllerServer) ControllerGetVolume(ctx context.Context, req *csi.ControllerGetVolumeRequest) (*csi.ControllerGetVolumeResponse, error) { +func (cs *ControllerServer) ControllerGetVolume(_ context.Context, req *csi.ControllerGetVolumeRequest) (*csi.ControllerGetVolumeResponse, error) { return &csi.ControllerGetVolumeResponse{Volume: &csi.Volume{ VolumeId: req.VolumeId, }}, nil } -func (cs *controllerServer) ControllerModifyVolume(ctx context.Context, req *csi.ControllerModifyVolumeRequest) (*csi.ControllerModifyVolumeResponse, error) { +func (cs *ControllerServer) ControllerModifyVolume(_ context.Context, _ *csi.ControllerModifyVolumeRequest) (*csi.ControllerModifyVolumeResponse, error) { return &csi.ControllerModifyVolumeResponse{}, nil } diff --git a/pkg/rclone/driver.go b/pkg/rclone/driver.go index 49423a78..1fad5a5b 100644 --- a/pkg/rclone/driver.go +++ b/pkg/rclone/driver.go @@ -2,185 +2,68 
@@ package rclone import ( "context" - "fmt" - "net" + "errors" "os" - "path/filepath" - "sync" - "github.com/SwissDataScienceCenter/csi-rclone/pkg/kube" - "github.com/SwissDataScienceCenter/csi-rclone/pkg/metrics" "github.com/container-storage-interface/spec/lib/go/csi" csicommon "github.com/kubernetes-csi/drivers/pkg/csi-common" - "github.com/prometheus/client_golang/prometheus" + "github.com/spf13/cobra" "k8s.io/klog" - "k8s.io/utils/mount" - - utilexec "k8s.io/utils/exec" ) -type Driver struct { - CSIDriver *csicommon.CSIDriver - endpoint string +const DriverVersion = "SwissDataScienceCenter" - ns *nodeServer - cs *controllerServer - cap []*csi.VolumeCapability_AccessMode - cscap []*csi.ControllerServiceCapability - server csicommon.NonBlockingGRPCServer -} +type DriverSetup func(csiDriver *csicommon.CSIDriver) (csi.ControllerServer, csi.NodeServer, error) -var ( - DriverVersion = "SwissDataScienceCenter" -) +type DriverServe func(ctx context.Context) error -func getFreePort() (port int, err error) { - var a *net.TCPAddr - if a, err = net.ResolveTCPAddr("tcp", "localhost:0"); err == nil { - var l *net.TCPListener - if l, err = net.ListenTCP("tcp", a); err == nil { - defer l.Close() - return l.Addr().(*net.TCPAddr).Port, nil - } +type DriverConfig struct { + Endpoint string + NodeID string +} + +func (config *DriverConfig) CommandLineParameters(runCmd *cobra.Command) error { + runCmd.PersistentFlags().StringVar(&config.NodeID, "nodeid", config.NodeID, "node id") + if err := runCmd.MarkPersistentFlagRequired("nodeid"); err != nil { + return err + } + runCmd.PersistentFlags().StringVar(&config.Endpoint, "endpoint", config.Endpoint, "CSI endpoint") + if err := runCmd.MarkPersistentFlagRequired("endpoint"); err != nil { + return err } - return + return nil } -func NewDriver(nodeID, endpoint string) *Driver { +func Run(ctx context.Context, config *DriverConfig, setup DriverSetup, serve DriverServe) error { driverName := os.Getenv("DRIVER_NAME") if driverName == "" { - panic("DriverName env var not set!") + return errors.New("DRIVER_NAME env variable not set or empty") } klog.Infof("Starting new %s RcloneDriver in version %s", driverName, DriverVersion) - d := &Driver{} - d.endpoint = endpoint - - d.CSIDriver = csicommon.NewCSIDriver(driverName, DriverVersion, nodeID) - d.CSIDriver.AddVolumeCapabilityAccessModes([]csi.VolumeCapability_AccessMode_Mode{ + driver := csicommon.NewCSIDriver(driverName, DriverVersion, config.NodeID) + driver.AddVolumeCapabilityAccessModes([]csi.VolumeCapability_AccessMode_Mode{ csi.VolumeCapability_AccessMode_SINGLE_NODE_SINGLE_WRITER, }) - d.CSIDriver.AddControllerServiceCapabilities( + driver.AddControllerServiceCapabilities( []csi.ControllerServiceCapability_RPC_Type{ csi.ControllerServiceCapability_RPC_CREATE_DELETE_VOLUME, }) - return d -} - -func NewNodeServer(csiDriver *csicommon.CSIDriver, cacheDir string, cacheSize string) (*nodeServer, error) { - kubeClient, err := kube.GetK8sClient() - if err != nil { - return nil, err + is := csicommon.NewDefaultIdentityServer(driver) + cs, ns, setupErr := setup(driver) + if setupErr != nil { + return setupErr } - rclonePort, err := getFreePort() - if err != nil { - return nil, fmt.Errorf("Cannot get a free TCP port to run rclone") - } - rcloneOps := NewRclone(kubeClient, rclonePort, cacheDir, cacheSize) - - // Use kubelet plugin directory for state persistence - stateFile := "/var/lib/kubelet/plugins/csi-rclone/mounted_volumes.json" - - ns := &nodeServer{ - DefaultNodeServer: csicommon.NewDefaultNodeServer(csiDriver), 
- mounter: &mount.SafeFormatAndMount{ - Interface: mount.New(""), - Exec: utilexec.New(), - }, - RcloneOps: rcloneOps, - mountedVolumes: make(map[string]MountedVolume), - mutex: &sync.Mutex{}, - stateFile: stateFile, - } - - // Ensure the folder exists - if err = os.MkdirAll(filepath.Dir(ns.stateFile), 0755); err != nil { - return nil, fmt.Errorf("failed to create state directory: %v", err) - } - - // Load persisted state on startup - ns.mutex.Lock() - defer ns.mutex.Unlock() - - if ns.mountedVolumes, err = readVolumeMap(ns.stateFile); err != nil { - klog.Warningf("Failed to load persisted volume state: %v", err) - } - - return ns, nil -} + s := csicommon.NewNonBlockingGRPCServer() + defer s.Stop() + s.Start(config.Endpoint, is, cs, ns) -func NewControllerServer(csiDriver *csicommon.CSIDriver) *controllerServer { - return &controllerServer{ - DefaultControllerServer: csicommon.NewDefaultControllerServer(csiDriver), - active_volumes: map[string]int64{}, - mutex: sync.RWMutex{}, + if err := serve(ctx); err != nil { + return err } -} - -func (ns *nodeServer) Metrics() []metrics.Observable { - var meters []metrics.Observable - - // What should we meter? - - return meters -} - -func (cs *controllerServer) Metrics() []metrics.Observable { - var meters []metrics.Observable - - meter := prometheus.NewGauge(prometheus.GaugeOpts{ - Name: "csi_rclone_active_volume_count", - Help: "Number of active (Mounted) volumes.", - }) - meters = append(meters, - func() { meter.Set(float64(len(cs.active_volumes))) }, - ) - prometheus.MustRegister(meter) - - return meters -} -func (d *Driver) WithNodeServer(ns *nodeServer) *Driver { - d.ns = ns - return d -} - -func (d *Driver) WithControllerServer(cs *controllerServer) *Driver { - d.cs = cs - return d -} - -func (d *Driver) Run() error { - s := csicommon.NewNonBlockingGRPCServer() - s.Start( - d.endpoint, - csicommon.NewDefaultIdentityServer(d.CSIDriver), - d.cs, - d.ns, - ) - d.server = s - if d.ns != nil && d.ns.RcloneOps != nil { - onDaemonReady := func() error { - if d.ns != nil { - return d.ns.remountTrackedVolumes(context.Background()) - } - return nil - } - return d.ns.RcloneOps.Run(onDaemonReady) - } s.Wait() return nil } - -func (d *Driver) Stop() error { - var err error - if d.ns != nil && d.ns.RcloneOps != nil { - err = d.ns.RcloneOps.Cleanup() - } - if d.server != nil { - d.server.Stop() - } - return err -} diff --git a/pkg/rclone/nodeserver.go b/pkg/rclone/nodeserver.go index 0d26791a..46a6b6ad 100644 --- a/pkg/rclone/nodeserver.go +++ b/pkg/rclone/nodeserver.go @@ -10,24 +10,30 @@ import ( "encoding/json" "errors" "fmt" + "net" "os" + "path/filepath" "runtime" "strings" "sync" "time" - "gopkg.in/ini.v1" - v1 "k8s.io/api/core/v1" - apierrors "k8s.io/apimachinery/pkg/api/errors" - metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" - "k8s.io/klog" - "github.com/SwissDataScienceCenter/csi-rclone/pkg/kube" + "github.com/SwissDataScienceCenter/csi-rclone/pkg/metrics" "github.com/container-storage-interface/spec/lib/go/csi" "github.com/fernet/fernet-go" + "github.com/prometheus/client_golang/prometheus" + "github.com/spf13/cobra" "golang.org/x/net/context" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" + "gopkg.in/ini.v1" + v1 "k8s.io/api/core/v1" + apierrors "k8s.io/apimachinery/pkg/api/errors" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/klog" + mountutils "k8s.io/mount-utils" + "k8s.io/utils/exec" "k8s.io/utils/mount" csicommon "github.com/kubernetes-csi/drivers/pkg/csi-common" @@ -36,7 +42,7 @@ import ( const 
CSI_ANNOTATION_PREFIX = "csi-rclone.dev" const pvcSecretNameAnnotation = CSI_ANNOTATION_PREFIX + "/secretName" -type nodeServer struct { +type NodeServer struct { *csicommon.DefaultNodeServer mounter *mount.SafeFormatAndMount RcloneOps Operations @@ -47,29 +53,182 @@ type nodeServer struct { stateFile string } +// unmountOldVols is used to unmount volumes after a restart on a node +func unmountOldVols() error { + const mountType = "fuse.rclone" + const unmountTimeout = time.Second * 5 + klog.Info("Checking for existing mounts") + mounter := mountutils.Mounter{} + mounts, err := mounter.List() + if err != nil { + return err + } + for _, mount := range mounts { + if mount.Type != mountType { + continue + } + err := mounter.UnmountWithForce(mount.Path, unmountTimeout) + if err != nil { + klog.Warningf("Failed to unmount %s because of %v.", mount.Path, err) + continue + } + klog.Infof("Sucessfully unmounted %s", mount.Path) + } + return nil +} + +func getFreePort() (port int, err error) { + var a *net.TCPAddr + if a, err = net.ResolveTCPAddr("tcp", "localhost:0"); err == nil { + var l *net.TCPListener + if l, err = net.ListenTCP("tcp", a); err == nil { + defer l.Close() + return l.Addr().(*net.TCPAddr).Port, nil + } + } + return +} + +func NewNodeServer(csiDriver *csicommon.CSIDriver, cacheDir string, cacheSize string) (*NodeServer, error) { + err := unmountOldVols() + if err != nil { + klog.Warningf("There was an error when trying to unmount old volumes: %v", err) + return nil, err + } + + kubeClient, err := kube.GetK8sClient() + if err != nil { + return nil, err + } + + rclonePort, err := getFreePort() + if err != nil { + return nil, fmt.Errorf("Cannot get a free TCP port to run rclone") + } + + ns := &NodeServer{ + DefaultNodeServer: csicommon.NewDefaultNodeServer(csiDriver), + mounter: &mount.SafeFormatAndMount{ + Interface: mount.New(""), + Exec: exec.New(), + }, + RcloneOps: NewRclone(kubeClient, rclonePort, cacheDir, cacheSize), + mountedVolumes: make(map[string]MountedVolume), + mutex: &sync.Mutex{}, + // Use kubelet plugin directory for state persistence + stateFile: "/var/lib/kubelet/plugins/csi-rclone/mounted_volumes.json", + } + + // Ensure the folder exists + if err = os.MkdirAll(filepath.Dir(ns.stateFile), 0755); err != nil { + return nil, fmt.Errorf("failed to create state directory: %v", err) + } + + // Load persisted state on startup + ns.mutex.Lock() + defer ns.mutex.Unlock() + + if ns.mountedVolumes, err = readVolumeMap(ns.stateFile); err != nil { + klog.Warningf("Failed to load persisted volume state: %v", err) + } + + return ns, nil +} + +func (ns *NodeServer) Run(ctx context.Context) error { + defer ns.Stop() + return ns.RcloneOps.Run(func() error { + return ns.remountTrackedVolumes(ctx) + }) +} + +func (ns *NodeServer) metrics() []metrics.Observable { + var meters []metrics.Observable + + meter := prometheus.NewGauge(prometheus.GaugeOpts{ + Name: "csi_rclone_active_volume_count", + Help: "Number of active (Mounted) volumes.", + }) + meters = append(meters, + func() { meter.Set(float64(len(ns.mountedVolumes))) }, + ) + prometheus.MustRegister(meter) + + return meters +} + +func (ns *NodeServer) Stop() { + if err := ns.RcloneOps.Cleanup(); err != nil { + klog.Errorf("Failed to cleanup rclone ops: %v", err) + } +} + +type NodeServerConfig struct { + DriverConfig + CacheDir string + CacheSize string + ns *NodeServer +} + +func (config *NodeServerConfig) CommandLineParameters(runCmd *cobra.Command, meters *[]metrics.Observable) error { + runNode := &cobra.Command{ + Use: 
"node", + Short: "Start the CSI driver node service - expected to run in a daemonset on every node.", + RunE: func(cmd *cobra.Command, args []string) error { + return Run(context.Background(), &config.DriverConfig, + func(csiDriver *csicommon.CSIDriver) (csi.ControllerServer, csi.NodeServer, error) { + ns, err := NewNodeServer(csiDriver, config.CacheDir, config.CacheSize) + if err != nil { + return nil, nil, err + } + // We go through a temporary variable to ensure that config.ns is only set with a correct NodeServer. + config.ns = ns + *meters = append(*meters, config.ns.metrics()...) + return nil, config.ns, err + }, + func(ctx context.Context) error { + if config.ns == nil { + return errors.New("node server uninitialized") + } + return config.ns.Run(ctx) + }, + ) + }, + } + if err := config.DriverConfig.CommandLineParameters(runNode); err != nil { + return err + } + + runNode.PersistentFlags().StringVar(&config.CacheDir, "cachedir", config.CacheDir, "cache dir") + runNode.PersistentFlags().StringVar(&config.CacheSize, "cachesize", config.CacheSize, "cache size") + + runCmd.AddCommand(runNode) + return nil +} + type MountedVolume struct { - VolumeId string - TargetPath string - Remote string - RemotePath string - ConfigData string - ReadOnly bool - Parameters map[string]string - SecretName string - SecretNamespace string + VolumeId string `json:"volume_id"` + TargetPath string `json:"target_path"` + Remote string `json:"remote"` + RemotePath string `json:"remote_path"` + ConfigData string `json:"config_data"` + ReadOnly bool `json:"read_only"` + Parameters map[string]string `json:"parameters"` + SecretName string `json:"secret_name"` + SecretNamespace string `json:"secret_namespace"` } // Mounting Volume (Preparation) -func (ns *nodeServer) NodeStageVolume(ctx context.Context, req *csi.NodeStageVolumeRequest) (*csi.NodeStageVolumeResponse, error) { +func (ns *NodeServer) NodeStageVolume(ctx context.Context, req *csi.NodeStageVolumeRequest) (*csi.NodeStageVolumeResponse, error) { return nil, status.Errorf(codes.Unimplemented, "method NodeStageVolume not implemented") } -func (ns *nodeServer) NodeUnstageVolume(ctx context.Context, req *csi.NodeUnstageVolumeRequest) (*csi.NodeUnstageVolumeResponse, error) { +func (ns *NodeServer) NodeUnstageVolume(ctx context.Context, req *csi.NodeUnstageVolumeRequest) (*csi.NodeUnstageVolumeResponse, error) { return nil, status.Errorf(codes.Unimplemented, "method NodeUnstageVolume not implemented") } // Mounting Volume (Actual Mounting) -func (ns *nodeServer) NodePublishVolume(ctx context.Context, req *csi.NodePublishVolumeRequest) (*csi.NodePublishVolumeResponse, error) { +func (ns *NodeServer) NodePublishVolume(ctx context.Context, req *csi.NodePublishVolumeRequest) (*csi.NodePublishVolumeResponse, error) { if err := validatePublishVolumeRequest(req); err != nil { return nil, err } @@ -327,7 +486,7 @@ func extractConfigData(parameters map[string]string) (string, map[string]string) } // Unmounting Volumes -func (ns *nodeServer) NodeUnpublishVolume(ctx context.Context, req *csi.NodeUnpublishVolumeRequest) (*csi.NodeUnpublishVolumeResponse, error) { +func (ns *NodeServer) NodeUnpublishVolume(ctx context.Context, req *csi.NodeUnpublishVolumeRequest) (*csi.NodeUnpublishVolumeResponse, error) { klog.Infof("NodeUnpublishVolume called with: %s", req) if err := validateUnPublishVolumeRequest(req); err != nil { return nil, err @@ -368,12 +527,12 @@ func validateUnPublishVolumeRequest(req *csi.NodeUnpublishVolumeRequest) error { } // Resizing Volume -func 
(*nodeServer) NodeExpandVolume(ctx context.Context, req *csi.NodeExpandVolumeRequest) (*csi.NodeExpandVolumeResponse, error) { +func (*NodeServer) NodeExpandVolume(ctx context.Context, req *csi.NodeExpandVolumeRequest) (*csi.NodeExpandVolumeResponse, error) { return nil, status.Errorf(codes.Unimplemented, "method NodeExpandVolume not implemented") } // Track mounted volume for automatic remounting -func (ns *nodeServer) trackMountedVolume(volumeId, targetPath, remote, remotePath, configData string, readOnly bool, parameters map[string]string, secretName, secretNamespace string) { +func (ns *NodeServer) trackMountedVolume(volumeId, targetPath, remote, remotePath, configData string, readOnly bool, parameters map[string]string, secretName, secretNamespace string) { ns.mutex.Lock() defer ns.mutex.Unlock() @@ -396,7 +555,7 @@ func (ns *nodeServer) trackMountedVolume(volumeId, targetPath, remote, remotePat } // Remove tracked volume when unmounted -func (ns *nodeServer) removeTrackedVolume(volumeId string) { +func (ns *NodeServer) removeTrackedVolume(volumeId string) { ns.mutex.Lock() defer ns.mutex.Unlock() @@ -409,7 +568,7 @@ func (ns *nodeServer) removeTrackedVolume(volumeId string) { } // Automatically remount all tracked volumes after daemon restart -func (ns *nodeServer) remountTrackedVolumes(ctx context.Context) error { +func (ns *NodeServer) remountTrackedVolumes(ctx context.Context) error { type mountResult struct { volumeID string err error @@ -418,18 +577,19 @@ func (ns *nodeServer) remountTrackedVolumes(ctx context.Context) error { ns.mutex.Lock() defer ns.mutex.Unlock() - if len(ns.mountedVolumes) == 0 { + volumesCount := len(ns.mountedVolumes) + + if volumesCount == 0 { klog.Info("No tracked volumes to remount") return nil } - klog.Infof("Remounting %d tracked volumes", len(ns.mountedVolumes)) + klog.Infof("Remounting %d tracked volumes", volumesCount) // Limit the number of active workers to the number of CPU threads (arbitrarily chosen) limits := make(chan bool, runtime.GOMAXPROCS(0)) defer close(limits) - volumesCount := len(ns.mountedVolumes) results := make(chan mountResult, volumesCount) defer close(results) @@ -488,7 +648,7 @@ func (ns *nodeServer) remountTrackedVolumes(ctx context.Context) error { } } -func (ns *nodeServer) WaitForMountAvailable(mountpoint string) error { +func (ns *NodeServer) WaitForMountAvailable(mountpoint string) error { for { select { case <-time.After(100 * time.Millisecond): diff --git a/test/sanity_test.go b/test/sanity_test.go index 0527cd41..6e400d97 100644 --- a/test/sanity_test.go +++ b/test/sanity_test.go @@ -9,9 +9,11 @@ import ( "github.com/SwissDataScienceCenter/csi-rclone/pkg/kube" "github.com/SwissDataScienceCenter/csi-rclone/pkg/rclone" + "github.com/container-storage-interface/spec/lib/go/csi" "github.com/google/uuid" "github.com/kubernetes-csi/csi-test/v5/pkg/sanity" "github.com/kubernetes-csi/csi-test/v5/utils" + csicommon "github.com/kubernetes-csi/drivers/pkg/csi-common" . "github.com/onsi/ginkgo/v2" . 
"github.com/onsi/gomega" "google.golang.org/grpc" @@ -45,25 +47,35 @@ func createSocketDir() (string, error) { var _ = Describe("Sanity CSI checks", Ordered, func() { var err error var kubeClient *kubernetes.Clientset = &kubernetes.Clientset{} + var nodeID string var endpoint string - var driver *rclone.Driver = &rclone.Driver{} var socketDir string BeforeAll(func() { socketDir, err = createSocketDir() Expect(err).ShouldNot(HaveOccurred()) + nodeID = "hostname" endpoint = fmt.Sprintf("unix://%s/csi.sock", socketDir) kubeClient, err = kube.GetK8sClient() Expect(err).ShouldNot(HaveOccurred()) os.Setenv("DRIVER_NAME", "csi-rclone") - driver = rclone.NewDriver("hostname", endpoint) - cs := rclone.NewControllerServer(driver.CSIDriver) - ns, err := rclone.NewNodeServer(driver.CSIDriver, "", "") - Expect(err).ShouldNot(HaveOccurred()) - driver.WithControllerServer(cs).WithNodeServer(ns) go func() { defer GinkgoRecover() - err := driver.Run() + var nsDoublePointer **rclone.NodeServer + err := rclone.Run(context.Background(), &nodeID, &endpoint, + func(csiDriver *csicommon.CSIDriver) (csi.ControllerServer, csi.NodeServer, error) { + cs := rclone.NewControllerServer(csiDriver) + ns, err := rclone.NewNodeServer(csiDriver, "", "") + if err != nil { + return nil, nil, err + } + nsDoublePointer = &ns + return cs, ns, nil + }, + func(ctx context.Context) error { + return (*nsDoublePointer).Run(ctx) + }, + ) Expect(err).ShouldNot(HaveOccurred()) }() _, err = utils.Connect(endpoint, grpc.WithTransportCredentials(insecure.NewCredentials())) @@ -71,7 +83,6 @@ var _ = Describe("Sanity CSI checks", Ordered, func() { }) AfterAll(func() { - driver.Stop() os.RemoveAll(socketDir) os.Unsetenv("DRIVER_NAME") }) From 29814fb52bbf39d72479fc174a69c11b7b92b9c9 Mon Sep 17 00:00:00 2001 From: Lionel Sambuc Date: Tue, 9 Dec 2025 13:36:30 +0100 Subject: [PATCH 04/10] fix tests after introduction of DriverConfig (#80) Adapt function prototypes --- pkg/rclone/controllerserver.go | 4 ++-- pkg/rclone/driver.go | 6 +++--- pkg/rclone/nodeserver.go | 15 ++++++--------- test/sanity_test.go | 23 +++++++++-------------- 4 files changed, 20 insertions(+), 28 deletions(-) diff --git a/pkg/rclone/controllerserver.go b/pkg/rclone/controllerserver.go index 1ae3f06a..4b39a848 100644 --- a/pkg/rclone/controllerserver.go +++ b/pkg/rclone/controllerserver.go @@ -61,12 +61,12 @@ func (config *ControllerServerConfig) CommandLineParameters(runCmd *cobra.Comman RunE: func(cmd *cobra.Command, args []string) error { return Run(context.Background(), &config.DriverConfig, - func(csiDriver *csicommon.CSIDriver) (csi.ControllerServer, csi.NodeServer, error) { + func(csiDriver *csicommon.CSIDriver) (*ControllerServer, *NodeServer, error) { cs := NewControllerServer(csiDriver) *meters = append(*meters, cs.metrics()...) 
return cs, nil, nil }, - func(_ context.Context) error { return nil }, + func(_ context.Context, cs *ControllerServer, ns *NodeServer) error { return nil }, ) }, } diff --git a/pkg/rclone/driver.go b/pkg/rclone/driver.go index 1fad5a5b..e0b7e53d 100644 --- a/pkg/rclone/driver.go +++ b/pkg/rclone/driver.go @@ -13,9 +13,9 @@ import ( const DriverVersion = "SwissDataScienceCenter" -type DriverSetup func(csiDriver *csicommon.CSIDriver) (csi.ControllerServer, csi.NodeServer, error) +type DriverSetup func(csiDriver *csicommon.CSIDriver) (*ControllerServer, *NodeServer, error) -type DriverServe func(ctx context.Context) error +type DriverServe func(ctx context.Context, cs *ControllerServer, ns *NodeServer) error type DriverConfig struct { Endpoint string @@ -60,7 +60,7 @@ func Run(ctx context.Context, config *DriverConfig, setup DriverSetup, serve Dri defer s.Stop() s.Start(config.Endpoint, is, cs, ns) - if err := serve(ctx); err != nil { + if err := serve(ctx, cs, ns); err != nil { return err } diff --git a/pkg/rclone/nodeserver.go b/pkg/rclone/nodeserver.go index 46a6b6ad..620b1292 100644 --- a/pkg/rclone/nodeserver.go +++ b/pkg/rclone/nodeserver.go @@ -167,7 +167,6 @@ type NodeServerConfig struct { DriverConfig CacheDir string CacheSize string - ns *NodeServer } func (config *NodeServerConfig) CommandLineParameters(runCmd *cobra.Command, meters *[]metrics.Observable) error { @@ -176,21 +175,19 @@ func (config *NodeServerConfig) CommandLineParameters(runCmd *cobra.Command, met Short: "Start the CSI driver node service - expected to run in a daemonset on every node.", RunE: func(cmd *cobra.Command, args []string) error { return Run(context.Background(), &config.DriverConfig, - func(csiDriver *csicommon.CSIDriver) (csi.ControllerServer, csi.NodeServer, error) { + func(csiDriver *csicommon.CSIDriver) (*ControllerServer, *NodeServer, error) { ns, err := NewNodeServer(csiDriver, config.CacheDir, config.CacheSize) if err != nil { return nil, nil, err } - // We go through a temporary variable to ensure that config.ns is only set with a correct NodeServer. - config.ns = ns - *meters = append(*meters, config.ns.metrics()...) - return nil, config.ns, err + *meters = append(*meters, ns.metrics()...) 
+ return nil, ns, err }, - func(ctx context.Context) error { - if config.ns == nil { + func(ctx context.Context, cs *ControllerServer, ns *NodeServer) error { + if ns == nil { return errors.New("node server uninitialized") } - return config.ns.Run(ctx) + return ns.Run(ctx) }, ) }, diff --git a/test/sanity_test.go b/test/sanity_test.go index 6e400d97..3a1044de 100644 --- a/test/sanity_test.go +++ b/test/sanity_test.go @@ -9,7 +9,6 @@ import ( "github.com/SwissDataScienceCenter/csi-rclone/pkg/kube" "github.com/SwissDataScienceCenter/csi-rclone/pkg/rclone" - "github.com/container-storage-interface/spec/lib/go/csi" "github.com/google/uuid" "github.com/kubernetes-csi/csi-test/v5/pkg/sanity" "github.com/kubernetes-csi/csi-test/v5/utils" @@ -47,33 +46,29 @@ func createSocketDir() (string, error) { var _ = Describe("Sanity CSI checks", Ordered, func() { var err error var kubeClient *kubernetes.Clientset = &kubernetes.Clientset{} - var nodeID string var endpoint string var socketDir string BeforeAll(func() { socketDir, err = createSocketDir() Expect(err).ShouldNot(HaveOccurred()) - nodeID = "hostname" endpoint = fmt.Sprintf("unix://%s/csi.sock", socketDir) + config := rclone.NodeServerConfig{DriverConfig: rclone.DriverConfig{Endpoint: endpoint, NodeID: "hostname"}} kubeClient, err = kube.GetK8sClient() Expect(err).ShouldNot(HaveOccurred()) os.Setenv("DRIVER_NAME", "csi-rclone") go func() { defer GinkgoRecover() - var nsDoublePointer **rclone.NodeServer - err := rclone.Run(context.Background(), &nodeID, &endpoint, - func(csiDriver *csicommon.CSIDriver) (csi.ControllerServer, csi.NodeServer, error) { + err := rclone.Run(context.Background(), &config.DriverConfig, + func(csiDriver *csicommon.CSIDriver) (*rclone.ControllerServer, *rclone.NodeServer, error) { cs := rclone.NewControllerServer(csiDriver) - ns, err := rclone.NewNodeServer(csiDriver, "", "") - if err != nil { - return nil, nil, err - } - nsDoublePointer = &ns - return cs, ns, nil + ns, err := rclone.NewNodeServer(csiDriver, config.CacheDir, config.CacheSize) + Expect(err).ShouldNot(HaveOccurred()) + return cs, ns, err }, - func(ctx context.Context) error { - return (*nsDoublePointer).Run(ctx) + func(ctx context.Context, cs *rclone.ControllerServer, ns *rclone.NodeServer) error { + Expect(ns).ShouldNot(BeNil()) + return ns.Run(ctx) }, ) Expect(err).ShouldNot(HaveOccurred()) From 9292bc42a9f952aa284eecea0e4889b607c13562 Mon Sep 17 00:00:00 2001 From: Tasko Olevski Date: Tue, 9 Dec 2025 14:41:35 +0100 Subject: [PATCH 05/10] fix: propagate context more thoroughly (#79) --- cmd/csi-rclone-plugin/main.go | 4 +-- pkg/common/constants.go | 11 +++++++ pkg/rclone/controllerserver.go | 3 +- pkg/rclone/nodeserver.go | 5 ++-- pkg/rclone/rclone.go | 52 +++++++++++++++++++++++++--------- test/sanity_test.go | 2 +- 6 files changed, 58 insertions(+), 19 deletions(-) create mode 100644 pkg/common/constants.go diff --git a/cmd/csi-rclone-plugin/main.go b/cmd/csi-rclone-plugin/main.go index 73160e11..854bb964 100644 --- a/cmd/csi-rclone-plugin/main.go +++ b/cmd/csi-rclone-plugin/main.go @@ -6,9 +6,9 @@ import ( "fmt" "os" "os/signal" - "syscall" "time" + "github.com/SwissDataScienceCenter/csi-rclone/pkg/common" "github.com/SwissDataScienceCenter/csi-rclone/pkg/metrics" "github.com/SwissDataScienceCenter/csi-rclone/pkg/rclone" "github.com/spf13/cobra" @@ -67,7 +67,7 @@ func main() { if metricsServerConfig.Enabled { // Gracefully exit the metrics background servers - ctx, stop := signal.NotifyContext(context.Background(), syscall.SIGTERM, syscall.SIGINT) + ctx, 
stop := signal.NotifyContext(context.Background(), common.InterruptSignals...)
 		defer stop()
 		metricsServer := metricsServerConfig.NewServer(ctx, &meters)
diff --git a/pkg/common/constants.go b/pkg/common/constants.go
new file mode 100644
index 00000000..a591954e
--- /dev/null
+++ b/pkg/common/constants.go
@@ -0,0 +1,11 @@
+package common
+
+import (
+	"os"
+	"syscall"
+)
+
+// Signals to listen to:
+// 1. os.Interrupt -> allows devs to easily run a server locally
+// 2. syscall.SIGTERM -> sent by Kubernetes when stopping a server gracefully
+var InterruptSignals = []os.Signal{os.Interrupt, syscall.SIGTERM}
diff --git a/pkg/rclone/controllerserver.go b/pkg/rclone/controllerserver.go
index 4b39a848..152011f8 100644
--- a/pkg/rclone/controllerserver.go
+++ b/pkg/rclone/controllerserver.go
@@ -59,7 +59,8 @@ func (config *ControllerServerConfig) CommandLineParameters(runCmd *cobra.Comman
 		Use:   "controller",
 		Short: "Start the CSI driver controller.",
 		RunE: func(cmd *cobra.Command, args []string) error {
-			return Run(context.Background(),
+			ctx := cmd.Context()
+			return Run(ctx,
 				&config.DriverConfig,
 				func(csiDriver *csicommon.CSIDriver) (*ControllerServer, *NodeServer, error) {
 					cs := NewControllerServer(csiDriver)
diff --git a/pkg/rclone/nodeserver.go b/pkg/rclone/nodeserver.go
index 620b1292..615e22f0 100644
--- a/pkg/rclone/nodeserver.go
+++ b/pkg/rclone/nodeserver.go
@@ -137,7 +137,7 @@ func NewNodeServer(csiDriver *csicommon.CSIDriver, cacheDir string, cacheSize st
 
 func (ns *NodeServer) Run(ctx context.Context) error {
 	defer ns.Stop()
-	return ns.RcloneOps.Run(func() error {
+	return ns.RcloneOps.Run(ctx, func() error {
 		return ns.remountTrackedVolumes(ctx)
 	})
 }
@@ -174,7 +174,8 @@ func (config *NodeServerConfig) CommandLineParameters(runCmd *cobra.Command, met
 		Use:   "node",
 		Short: "Start the CSI driver node service - expected to run in a daemonset on every node.",
 		RunE: func(cmd *cobra.Command, args []string) error {
-			return Run(context.Background(), &config.DriverConfig,
+			ctx := cmd.Context()
+			return Run(ctx, &config.DriverConfig,
 				func(csiDriver *csicommon.CSIDriver) (*ControllerServer, *NodeServer, error) {
 					ns, err := NewNodeServer(csiDriver, config.CacheDir, config.CacheSize)
 					if err != nil {
diff --git a/pkg/rclone/rclone.go b/pkg/rclone/rclone.go
index 264b2719..12617e4c 100644
--- a/pkg/rclone/rclone.go
+++ b/pkg/rclone/rclone.go
@@ -34,7 +34,7 @@ type Operations interface {
 	Unmount(ctx context.Context, volumeId string, targetPath string) error
 	GetVolumeById(ctx context.Context, volumeId string) (*RcloneVolume, error)
 	Cleanup() error
-	Run(onDaemonReady func() error) error
+	Run(ctx context.Context, onDaemonReady func() error) error
 }
 
 type Rclone struct {
@@ -160,7 +160,11 @@ func (r *Rclone) Mount(ctx context.Context, rcloneVolume *RcloneVolume, targetPa
 		return fmt.Errorf("mounting failed: couldn't create request body: %s", err)
 	}
 	requestBody := bytes.NewBuffer(postBody)
-	resp, err := http.Post(fmt.Sprintf("http://localhost:%d/config/create", r.port), "application/json", requestBody)
+	req, err := createRcloneRequest(ctx, http.MethodPost, requestBody, "/config/create", r.port)
+	if err != nil {
+		return fmt.Errorf("mounting failed: cannot create a request for rclone config creation: %w", err)
+	}
+	resp, err := http.DefaultClient.Do(req)
 	if err != nil {
 		return fmt.Errorf("mounting failed: couldn't send HTTP request to create config: %w", err)
 	}
@@ -218,7 +222,11 @@ func (r *Rclone) Mount(ctx context.Context, rcloneVolume *RcloneVolume, targetPa
 	}
 	klog.Infof("executing mount command 
args=%s", string(postBody)) requestBody = bytes.NewBuffer(postBody) - resp, err = http.Post(fmt.Sprintf("http://localhost:%d/mount/mount", r.port), "application/json", requestBody) + req, err = createRcloneRequest(ctx, http.MethodPost, requestBody, "/mount/mount", r.port) + if err != nil { + return fmt.Errorf("mounting failed: cannot create a request for rclone mounting: %w", err) + } + resp, err = http.DefaultClient.Do(req) if err != nil { return fmt.Errorf("mounting failed: couldn't send HTTP request to create mount: %w", err) } @@ -249,7 +257,7 @@ func (r *Rclone) CreateVol(ctx context.Context, volumeName, remote, remotePath, } flags["config"] = rcloneConfigPath - return r.command("mkdir", remote, path, flags) + return r.command(ctx, "mkdir", remote, path, flags) } func (r Rclone) DeleteVol(ctx context.Context, rcloneVolume *RcloneVolume, rcloneConfigPath string, parameters map[string]string) error { @@ -258,7 +266,7 @@ func (r Rclone) DeleteVol(ctx context.Context, rcloneVolume *RcloneVolume, rclon flags[key] = value } flags["config"] = rcloneConfigPath - return r.command("purge", rcloneVolume.Remote, rcloneVolume.RemotePath, flags) + return r.command(ctx, "purge", rcloneVolume.Remote, rcloneVolume.RemotePath, flags) } func (r Rclone) Unmount(ctx context.Context, volumeId string, targetPath string) error { @@ -273,7 +281,11 @@ func (r Rclone) Unmount(ctx context.Context, volumeId string, targetPath string) return fmt.Errorf("unmounting failed: couldn't create request body: %s", err) } requestBody := bytes.NewBuffer(postBody) - resp, err := http.Post(fmt.Sprintf("http://localhost:%d/mount/unmount", r.port), "application/json", requestBody) + req, err := createRcloneRequest(ctx, http.MethodPost, requestBody, "/mount/unmount", r.port) + if err != nil { + return fmt.Errorf("unmounting failed: couldn't create a request for rclone: %w", err) + } + resp, err := http.DefaultClient.Do(req) if err != nil { return fmt.Errorf("unmounting failed: couldn't send HTTP request: %w", err) } @@ -291,7 +303,11 @@ func (r Rclone) Unmount(ctx context.Context, volumeId string, targetPath string) return fmt.Errorf("deleting config failed: couldn't create request body: %s", err) } requestBody = bytes.NewBuffer(postBody) - resp, err = http.Post(fmt.Sprintf("http://localhost:%d/config/delete", r.port), "application/json", requestBody) + req, err = createRcloneRequest(ctx, http.MethodPost, requestBody, "/config/delete", r.port) + if err != nil { + return fmt.Errorf("unmounting failed: couldn't create a request for rclone configuration deletion: %w", err) + } + resp, err = http.DefaultClient.Do(req) if err != nil { klog.Errorf("deleting config failed: couldn't send HTTP request: %v", err) return nil @@ -409,7 +425,7 @@ func checkResponse(resp *http.Response) error { return fmt.Errorf("received error from the rclone server: %s", result.String()) } -func (r *Rclone) start_daemon() error { +func (r *Rclone) start_daemon(ctx context.Context) error { f, err := os.CreateTemp("", "rclone.conf") if err != nil { return err @@ -449,7 +465,7 @@ func (r *Rclone) start_daemon() error { klog.Infof("running rclone remote control daemon cmd=%s, args=%s", rclone_cmd, rclone_args) env := os.Environ() - cmd := os_exec.Command(rclone_cmd, rclone_args...) + cmd := os_exec.CommandContext(ctx, rclone_cmd, rclone_args...) 
cmd.SysProcAttr = &syscall.SysProcAttr{Setpgid: true}
 	stdout, err := cmd.StdoutPipe()
 	cmd.Stderr = cmd.Stdout
@@ -472,8 +488,8 @@ func (r *Rclone) start_daemon() error {
 	return nil
 }
 
-func (r *Rclone) Run(onDaemonReady func() error) error {
-	err := r.start_daemon()
+func (r *Rclone) Run(ctx context.Context, onDaemonReady func() error) error {
+	err := r.start_daemon(ctx)
 	if err != nil {
 		return err
 	}
@@ -494,7 +510,7 @@ func (r *Rclone) Cleanup() error {
 	return r.daemonCmd.Process.Kill()
 }
 
-func (r *Rclone) command(cmd, remote, remotePath string, flags map[string]string) error {
+func (r *Rclone) command(ctx context.Context, cmd, remote, remotePath string, flags map[string]string) error {
 	// rclone remote:path [flag]
 	args := append(
 		[]string{},
@@ -508,7 +524,7 @@ func (r *Rclone) command(cmd, remote, remotePath string, flags map[string]string
 	}
 
 	klog.Infof("executing %s command cmd=rclone, remote=%s:%s", cmd, remote, remotePath)
-	out, err := r.execute.Command("rclone", args...).CombinedOutput()
+	out, err := r.execute.CommandContext(ctx, "rclone", args...).CombinedOutput()
 	if err != nil {
 		return fmt.Errorf("%s failed: %v cmd: 'rclone' remote: '%s' remotePath:'%s' args:'%s' output: %q",
 			cmd, err, remote, remotePath, args, string(out))
@@ -516,3 +532,13 @@ func (r *Rclone) command(cmd, remote, remotePath string, flags map[string]string
 
 	return nil
 }
+
+func createRcloneRequest(ctx context.Context, method string, body io.Reader, path string, rcloneServerPort int) (*http.Request, error) {
+	rcloneServerURL := fmt.Sprintf("http://localhost:%d/%s", rcloneServerPort, strings.TrimLeft(path, "/"))
+	req, err := http.NewRequestWithContext(ctx, method, rcloneServerURL, body)
+	if err != nil {
+		return nil, err
+	}
+	req.Header.Set("Content-Type", "application/json")
+	return req, nil
+}
diff --git a/test/sanity_test.go b/test/sanity_test.go
index 3a1044de..a74beb6a 100644
--- a/test/sanity_test.go
+++ b/test/sanity_test.go
@@ -49,7 +49,7 @@ var _ = Describe("Sanity CSI checks", Ordered, func() {
 	var endpoint string
 	var socketDir string
 
-	BeforeAll(func() {
+	BeforeAll(func(ctx SpecContext) {
 		socketDir, err = createSocketDir()
 		Expect(err).ShouldNot(HaveOccurred())
 		endpoint = fmt.Sprintf("unix://%s/csi.sock", socketDir)

From 2675531ed98efaa2f1429af94898509b9f719bbc Mon Sep 17 00:00:00 2001
From: Lionel Sambuc
Date: Fri, 12 Dec 2025 14:03:13 +0100
Subject: [PATCH 06/10] fix: the error handling surfaced issues that were
 previously ignored (#81)

---
 cmd/csi-rclone-plugin/main.go | 11 +++++++++--
 1 file changed, 9 insertions(+), 2 deletions(-)

diff --git a/cmd/csi-rclone-plugin/main.go b/cmd/csi-rclone-plugin/main.go
index 854bb964..7f2b37fb 100644
--- a/cmd/csi-rclone-plugin/main.go
+++ b/cmd/csi-rclone-plugin/main.go
@@ -2,6 +2,7 @@ package main
 
 import (
 	"context"
+	"errors"
 	"flag"
 	"fmt"
 	"os"
@@ -12,11 +13,13 @@ import (
 	"github.com/SwissDataScienceCenter/csi-rclone/pkg/metrics"
 	"github.com/SwissDataScienceCenter/csi-rclone/pkg/rclone"
 	"github.com/spf13/cobra"
+	"github.com/spf13/pflag"
 	"k8s.io/klog"
 )
 
 func exitOnError(err error) {
-	if err != nil {
+	// ParseFlags uses errors to return some status information; ignore it here.
+	if err != nil && !errors.Is(err, pflag.ErrHelp) {
 		klog.Error(err.Error())
 		os.Exit(1)
 	}
@@ -43,6 +46,10 @@ func main() {
 		Use:   "rclone",
 		Short: "CSI based rclone driver",
 	}
+	// Allow flags to be defined in subcommands; they will be reported at the Execute() step, with the help printed
+	// before exiting. 
+	root.FParseErrWhitelist.UnknownFlags = true
+
 	metricsServerConfig.CommandLineParameters(root)
 
 	runCmd := &cobra.Command{
@@ -58,7 +65,7 @@ func main() {
 		Use:   "version",
 		Short: "Prints information about this version of csi rclone plugin",
 		Run: func(cmd *cobra.Command, args []string) {
-			fmt.Printf("csi-rclone plugin Version: %s", rclone.DriverVersion)
+			fmt.Printf("csi-rclone plugin Version: %s\n", rclone.DriverVersion)
 		},
 	}
 	root.AddCommand(versionCmd)

From 6372a5df32b95c8f164c366fcca49e6310aaab4e Mon Sep 17 00:00:00 2001
From: Lionel Sambuc
Date: Mon, 15 Dec 2025 15:42:44 +0100
Subject: [PATCH 07/10] fix: Use node tmp folder for the mounts recovery state
 (#82)

- Use a folder name on the host that includes the deployment, to prevent
  conflicts in case of multiple deployments on the same host.
- Cleaned up the templates a bit to make them easier to compare.
- Use a folder under /tmp so that state is cleaned on node reboot, but
  kept between pod/container restarts.

---------

Co-authored-by: Tasko Olevski
---
 .devcontainer/rclone/install.sh              |  2 +-
 .../templates/csi-controller-rclone.yaml     |  8 +++----
 .../templates/csi-nodeplugin-rclone.yaml     | 21 ++++++++++++-------
 pkg/rclone/nodeserver.go                     |  3 +--
 4 files changed, 20 insertions(+), 14 deletions(-)

diff --git a/.devcontainer/rclone/install.sh b/.devcontainer/rclone/install.sh
index a972f212..8628ce9c 100644
--- a/.devcontainer/rclone/install.sh
+++ b/.devcontainer/rclone/install.sh
@@ -21,4 +21,4 @@ chown -R "${USERNAME}:golang" /go
 chmod -R g+r+w /go
 
 # Make sure the default folders exists
-mkdir -p /var/lib/kubelet/plugins/csi-rclone/
\ No newline at end of file
+mkdir -p /run/csi-rclone
\ No newline at end of file
diff --git a/deploy/csi-rclone/templates/csi-controller-rclone.yaml b/deploy/csi-rclone/templates/csi-controller-rclone.yaml
index dd13c437..4f65cd82 100644
--- a/deploy/csi-rclone/templates/csi-controller-rclone.yaml
+++ b/deploy/csi-rclone/templates/csi-controller-rclone.yaml
@@ -54,8 +54,8 @@ spec:
           image: {{ .Values.csiControllerRclone.csiProvisioner.image.repository }}:{{ .Values.csiControllerRclone.csiProvisioner.image.tag | default .Chart.AppVersion }}
           imagePullPolicy: {{ .Values.csiControllerRclone.csiProvisioner.imagePullPolicy }}
           volumeMounts:
-            - name: socket-dir
-              mountPath: /csi
+            - mountPath: /csi
+              name: socket-dir
         - name: rclone
           args:
            - run
@@ -85,7 +85,7 @@ spec:
                 fieldRef:
                   fieldPath: spec.nodeName
            - name: CSI_ENDPOINT
-              value: "unix://plugin/csi.sock"
+              value: "unix://csi/csi.sock"
            - name: KUBERNETES_CLUSTER_DOMAIN
              value: {{ quote .Values.kubernetesClusterDomain }}
 {{- if .Values.csiControllerRclone.rclone.goMemLimit }}
@@ -114,7 +114,7 @@ spec:
             timeoutSeconds: 3
             periodSeconds: 2
           volumeMounts:
-            - mountPath: /plugin
+            - mountPath: /csi
               name: socket-dir
         - name: liveness-probe
           imagePullPolicy: Always
diff --git a/deploy/csi-rclone/templates/csi-nodeplugin-rclone.yaml b/deploy/csi-rclone/templates/csi-nodeplugin-rclone.yaml
index 670badc5..be27eb09 100644
--- a/deploy/csi-rclone/templates/csi-nodeplugin-rclone.yaml
+++ b/deploy/csi-rclone/templates/csi-nodeplugin-rclone.yaml
@@ -22,7 +22,7 @@ spec:
         - name: node-driver-registrar
           args:
             - --v=5
-            - --csi-address=/plugin/csi.sock
+            - --csi-address=/csi/csi.sock
            - --kubelet-registration-path=/var/lib/kubelet/plugins/{{ .Values.storageClassName }}/csi.sock
           env:
             - name: KUBE_NODE_NAME
@@ -45,7 +45,7 @@ spec:
           resources:
 {{- toYaml .Values.csiNodepluginRclone.rclone.resources | nindent 12 }}
           volumeMounts:
-            - mountPath: /plugin
+            - mountPath: /csi
              name: plugin-dir
            - mountPath: 
/registration
              name: registration-dir
@@ -53,9 +53,9 @@
           imagePullPolicy: Always
           image: registry.k8s.io/sig-storage/livenessprobe:v2.15.0
           args:
-            - --csi-address=/plugin/csi.sock
+            - --csi-address=/csi/csi.sock
           volumeMounts:
-            - mountPath: /plugin
+            - mountPath: /csi
              name: plugin-dir
         - name: rclone
           args:
@@ -86,7 +86,7 @@ spec:
                 fieldRef:
                   fieldPath: spec.nodeName
            - name: CSI_ENDPOINT
-              value: "unix://plugin/csi.sock"
+              value: "unix://csi/csi.sock"
            - name: KUBERNETES_CLUSTER_DOMAIN
              value: {{ quote .Values.kubernetesClusterDomain }}
            - name: DRIVER_NAME
@@ -134,8 +134,10 @@
             timeoutSeconds: 10
             periodSeconds: 30
           volumeMounts:
-            - mountPath: /plugin
+            - mountPath: /csi
               name: plugin-dir
+            - mountPath: /run/csi-rclone
+              name: node-temp-dir
            - mountPath: /var/lib/kubelet/pods
              mountPropagation: Bidirectional
              name: pods-mount-dir
@@ -154,6 +156,11 @@
 {{ toYaml . | nindent 8 }}
 {{- end }}
       volumes:
+        - hostPath:
+            # NOTE: We mount on /tmp because we want the saved configuration to not survive a whole node restart.
+            path: /tmp/{{.Release.Namespace}}-{{.Release.Name}}-{{.Release.Revision}}
+            type: DirectoryOrCreate
+          name: node-temp-dir
         - hostPath:
             path: {{ .Values.kubeletDir }}/plugins/{{ .Values.storageClassName }}
             type: DirectoryOrCreate
@@ -167,4 +174,4 @@ spec:
             type: DirectoryOrCreate
           name: registration-dir
         - name: cache-dir
-          emptyDir:
+          emptyDir: {}
diff --git a/pkg/rclone/nodeserver.go b/pkg/rclone/nodeserver.go
index 615e22f0..830f232e 100644
--- a/pkg/rclone/nodeserver.go
+++ b/pkg/rclone/nodeserver.go
@@ -115,8 +115,7 @@ func NewNodeServer(csiDriver *csicommon.CSIDriver, cacheDir string, cacheSize st
 		RcloneOps:      NewRclone(kubeClient, rclonePort, cacheDir, cacheSize),
 		mountedVolumes: make(map[string]MountedVolume),
 		mutex:          &sync.Mutex{},
-		// Use kubelet plugin directory for state persistence
-		stateFile: "/var/lib/kubelet/plugins/csi-rclone/mounted_volumes.json",
+		stateFile: "/run/csi-rclone/mounted_volumes.json",
 	}
 
 	// Ensure the folder exists

From 06815222b677ceddbfaa1cbd1303d213133206e8 Mon Sep 17 00:00:00 2001
From: Lionel Sambuc
Date: Mon, 15 Dec 2025 11:37:51 +0100
Subject: [PATCH 08/10] fix: Wait for the daemon to be ready

---
 pkg/rclone/rclone.go | 25 +++++++++++++++++++++++++
 1 file changed, 25 insertions(+)

diff --git a/pkg/rclone/rclone.go b/pkg/rclone/rclone.go
index 12617e4c..31a35a58 100644
--- a/pkg/rclone/rclone.go
+++ b/pkg/rclone/rclone.go
@@ -11,6 +11,7 @@ import (
 	"os"
 	os_exec "os/exec"
 	"syscall"
+	"time"
 
 	"strings"
@@ -425,6 +426,25 @@ func checkResponse(resp *http.Response) error {
 	return fmt.Errorf("received error from the rclone server: %s", result.String())
 }
 
+func waitForDaemon(ctx context.Context, port int) error {
+	// Wait for the daemon to have started
+	ctxWaitRcloneStart, cancel := context.WithTimeout(ctx, 1*time.Second)
+	defer cancel()
+
+	req, err := createRcloneRequest(ctxWaitRcloneStart, http.MethodPost, nil, "/core/version", port)
+	if err != nil {
+		return err
+	}
+
+	for {
+		_, err = http.DefaultClient.Do(req)
+		// Keep trying until we retrieve a response, or we hit the deadline
+		if err == nil || errors.Is(err, context.DeadlineExceeded) {
+			return err
+		}
+	}
+}
+
 func (r *Rclone) start_daemon(ctx context.Context) error {
 	f, err := os.CreateTemp("", "rclone.conf")
 	if err != nil {
 		return err
@@ -477,6 +497,11 @@ func (r *Rclone) start_daemon(ctx context.Context) error {
 	if err := cmd.Start(); err != nil {
 		return err
 	}
+
+	if err := waitForDaemon(ctx, r.port); err != nil {
+		return err
+	}
+
 	go func() {
 		output := ""
 		for scanner.Scan() {

From 
23cd7c85b84b4f7b987e677cab0fe8e02668c90f Mon Sep 17 00:00:00 2001
From: Lionel Sambuc
Date: Mon, 15 Dec 2025 11:22:17 +0100
Subject: [PATCH 09/10] fix: Add a copy of the GRPC Server which re-uses the
 socket.

---
 pkg/rclone/MyGRPCServer.go | 101 +++++++++++++++++++++++++++++++++++++
 pkg/rclone/driver.go       |   2 +-
 2 files changed, 102 insertions(+), 1 deletion(-)
 create mode 100644 pkg/rclone/MyGRPCServer.go

diff --git a/pkg/rclone/MyGRPCServer.go b/pkg/rclone/MyGRPCServer.go
new file mode 100644
index 00000000..f9278f96
--- /dev/null
+++ b/pkg/rclone/MyGRPCServer.go
@@ -0,0 +1,101 @@
+package rclone
+
+import (
+	"context"
+	"net"
+	"sync"
+
+	"github.com/container-storage-interface/spec/lib/go/csi"
+	"github.com/golang/glog"
+	"github.com/kubernetes-csi/csi-lib-utils/protosanitizer"
+	csicommon "github.com/kubernetes-csi/drivers/pkg/csi-common"
+	"google.golang.org/grpc"
+)
+
+// Override the serve function to keep the csi socket instead of removing it before use.
+// This is basically a copy-paste of csi-common/server.go with two lines modified.
+// Done as a quick hack: the actual implementation struct is private, so I can't override
+// the `serve` function only.
+
+func NewMyGRPCServer() csicommon.NonBlockingGRPCServer {
+	return &myGRPCServer{}
+}
+
+type myGRPCServer struct {
+	wg     sync.WaitGroup
+	server *grpc.Server
+}
+
+func (s *myGRPCServer) Start(endpoint string, ids csi.IdentityServer, cs csi.ControllerServer, ns csi.NodeServer) {
+
+	s.wg.Add(1)
+
+	go s.serve(endpoint, ids, cs, ns)
+
+	return
+}
+
+func (s *myGRPCServer) Wait() {
+	s.wg.Wait()
+}
+
+func (s *myGRPCServer) Stop() {
+	s.server.GracefulStop()
+}
+
+func (s *myGRPCServer) ForceStop() {
+	s.server.Stop()
+}
+
+func logGRPC(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
+	glog.V(3).Infof("GRPC call: %s", info.FullMethod)
+	glog.V(5).Infof("GRPC request: %s", protosanitizer.StripSecrets(req))
+	resp, err := handler(ctx, req)
+	if err != nil {
+		glog.Errorf("GRPC error: %v", err)
+	} else {
+		glog.V(5).Infof("GRPC response: %s", protosanitizer.StripSecrets(resp))
+	}
+	return resp, err
+}
+
+func (s *myGRPCServer) serve(endpoint string, ids csi.IdentityServer, cs csi.ControllerServer, ns csi.NodeServer) {
+
+	proto, addr, err := csicommon.ParseEndpoint(endpoint)
+	if err != nil {
+		glog.Fatal(err.Error())
+	}
+
+	if proto == "unix" {
+		addr = "/" + addr
+		//if err := os.Remove(addr); err != nil && !os.IsNotExist(err) {
+		//	glog.Fatalf("Failed to remove %s, error: %s", addr, err.Error())
+		//}
+	}
+
+	listener, err := net.Listen(proto, addr)
+	if err != nil {
+		glog.Fatalf("Failed to listen: %v", err)
+	}
+
+	opts := []grpc.ServerOption{
+		grpc.UnaryInterceptor(logGRPC),
+	}
+	server := grpc.NewServer(opts...) 
+ s.server = server + + if ids != nil { + csi.RegisterIdentityServer(server, ids) + } + if cs != nil { + csi.RegisterControllerServer(server, cs) + } + if ns != nil { + csi.RegisterNodeServer(server, ns) + } + + glog.Infof("Listening for connections on address: %#v", listener.Addr()) + + server.Serve(listener) + +} diff --git a/pkg/rclone/driver.go b/pkg/rclone/driver.go index e0b7e53d..08d9adf4 100644 --- a/pkg/rclone/driver.go +++ b/pkg/rclone/driver.go @@ -56,7 +56,7 @@ func Run(ctx context.Context, config *DriverConfig, setup DriverSetup, serve Dri return setupErr } - s := csicommon.NewNonBlockingGRPCServer() + s := NewMyGRPCServer() defer s.Stop() s.Start(config.Endpoint, is, cs, ns) From 28f60df1e3acb6556e2cb345a12b1d970753c22e Mon Sep 17 00:00:00 2001 From: Flora Thiebaut Date: Thu, 12 Jun 2025 09:05:27 +0200 Subject: [PATCH 10/10] build: add action to build the container image (dev) --- .github/workflows/build.yaml | 71 ++++++++++++++++++++++++++++++++++++ 1 file changed, 71 insertions(+) create mode 100644 .github/workflows/build.yaml diff --git a/.github/workflows/build.yaml b/.github/workflows/build.yaml new file mode 100644 index 00000000..3fe5df72 --- /dev/null +++ b/.github/workflows/build.yaml @@ -0,0 +1,71 @@ +name: Build dev version + +on: + push: + workflow_dispatch: + +env: + REGISTRY: ghcr.io + IMAGE_NAME: ${{ github.repository }} + CHART_NAME: ${{ github.repository }}/helm-chart + +defaults: + run: + shell: bash + +concurrency: + group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }} + cancel-in-progress: true + +permissions: + contents: read + +jobs: + build-image: + runs-on: ubuntu-24.04 + outputs: + image: ${{ steps.docker_image.outputs.image }} + image_repository: ${{ steps.docker_image.outputs.image_repository }} + image_tag: ${{ steps.docker_image.outputs.image_tag }} + permissions: + contents: read + packages: write + steps: + - name: Checkout + uses: actions/checkout@v4 + - name: Docker image metadata + id: meta + uses: docker/metadata-action@v5 + with: + images: ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }} + tags: type=sha + - name: Extract Docker image name + id: docker_image + env: + IMAGE_TAGS: ${{ steps.meta.outputs.tags }} + run: | + IMAGE=$(echo "$IMAGE_TAGS" | cut -d" " -f1) + IMAGE_REPOSITORY=$(echo "$IMAGE" | cut -d":" -f1) + IMAGE_TAG=$(echo "$IMAGE" | cut -d":" -f2) + echo "image=$IMAGE" >> "$GITHUB_OUTPUT" + echo "image_repository=$IMAGE_REPOSITORY" >> "$GITHUB_OUTPUT" + echo "image_tag=$IMAGE_TAG" >> "$GITHUB_OUTPUT" + - name: Set up Docker buildx + uses: docker/setup-buildx-action@v3 + - name: Set up Docker + uses: docker/login-action@v3 + with: + registry: ${{ env.REGISTRY }} + username: ${{ github.actor }} + password: ${{ secrets.GITHUB_TOKEN }} + - name: Build and push Docker image + uses: docker/build-push-action@v6 + with: + context: . + push: true + tags: ${{ steps.meta.outputs.tags }} + labels: ${{ steps.meta.outputs.labels }} + cache-from: type=registry,ref=${{ steps.docker_image.outputs.image_repository }}:buildcache + cache-to: type=registry,ref=${{ steps.docker_image.outputs.image_repository }}:buildcache,mode=max + +# TODO: add job to build and push the helm chart if needed (manual trigger)
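One possible shape for the helm-chart job left as a TODO above — a hedged sketch only, not part of this patch series. It assumes the chart keeps living at deploy/csi-rclone (the path touched by PATCH 07), that the chart name in Chart.yaml is csi-rclone, that the chart should be pushed as an OCI artifact under the CHART_NAME path already defined in the workflow env, and that deriving a 0.0.0-prerelease chart version from the image tag produced by build-image is acceptable. The job name, gating condition, and version scheme are illustrative:

  publish-chart:
    # Manual trigger only, per the TODO above
    if: github.event_name == 'workflow_dispatch'
    needs: build-image
    runs-on: ubuntu-24.04
    permissions:
      contents: read
      packages: write
    steps:
      - name: Checkout
        uses: actions/checkout@v4
      - name: Set up Helm
        uses: azure/setup-helm@v4
      - name: Log in to the OCI registry
        run: |
          echo "${{ secrets.GITHUB_TOKEN }}" | \
            helm registry login "${{ env.REGISTRY }}" --username "${{ github.actor }}" --password-stdin
      - name: Package and push the chart
        run: |
          # Reuse the image tag computed by build-image as the chart version suffix (assumed scheme)
          helm package deploy/csi-rclone \
            --version "0.0.0-${{ needs.build-image.outputs.image_tag }}" \
            --app-version "${{ needs.build-image.outputs.image_tag }}"
          # NOTE: OCI repositories must be lowercase; CHART_NAME is derived from
          # github.repository and may need lowercasing before this push works.
          helm push csi-rclone-*.tgz "oci://${{ env.REGISTRY }}/${{ env.CHART_NAME }}"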