diff --git a/.devcontainer/rclone/install.sh b/.devcontainer/rclone/install.sh
index c9a8624..8628ce9 100644
--- a/.devcontainer/rclone/install.sh
+++ b/.devcontainer/rclone/install.sh
@@ -19,3 +19,6 @@ rm -rf /tmp/rclone
 # Fix the $GOPATH folder
 chown -R "${USERNAME}:golang" /go
 chmod -R g+r+w /go
+
+# Make sure the default folders exist
+mkdir -p /run/csi-rclone
\ No newline at end of file
diff --git a/.github/workflows/build.yaml b/.github/workflows/build.yaml
new file mode 100644
index 0000000..3fe5df7
--- /dev/null
+++ b/.github/workflows/build.yaml
@@ -0,0 +1,71 @@
+name: Build dev version
+
+on:
+  push:
+  workflow_dispatch:
+
+env:
+  REGISTRY: ghcr.io
+  IMAGE_NAME: ${{ github.repository }}
+  CHART_NAME: ${{ github.repository }}/helm-chart
+
+defaults:
+  run:
+    shell: bash
+
+concurrency:
+  group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}
+  cancel-in-progress: true
+
+permissions:
+  contents: read
+
+jobs:
+  build-image:
+    runs-on: ubuntu-24.04
+    outputs:
+      image: ${{ steps.docker_image.outputs.image }}
+      image_repository: ${{ steps.docker_image.outputs.image_repository }}
+      image_tag: ${{ steps.docker_image.outputs.image_tag }}
+    permissions:
+      contents: read
+      packages: write
+    steps:
+      - name: Checkout
+        uses: actions/checkout@v4
+      - name: Docker image metadata
+        id: meta
+        uses: docker/metadata-action@v5
+        with:
+          images: ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}
+          tags: type=sha
+      - name: Extract Docker image name
+        id: docker_image
+        env:
+          IMAGE_TAGS: ${{ steps.meta.outputs.tags }}
+        run: |
+          IMAGE=$(echo "$IMAGE_TAGS" | cut -d" " -f1)
+          IMAGE_REPOSITORY=$(echo "$IMAGE" | cut -d":" -f1)
+          IMAGE_TAG=$(echo "$IMAGE" | cut -d":" -f2)
+          echo "image=$IMAGE" >> "$GITHUB_OUTPUT"
+          echo "image_repository=$IMAGE_REPOSITORY" >> "$GITHUB_OUTPUT"
+          echo "image_tag=$IMAGE_TAG" >> "$GITHUB_OUTPUT"
+      - name: Set up Docker buildx
+        uses: docker/setup-buildx-action@v3
+      - name: Set up Docker
+        uses: docker/login-action@v3
+        with:
+          registry: ${{ env.REGISTRY }}
+          username: ${{ github.actor }}
+          password: ${{ secrets.GITHUB_TOKEN }}
+      - name: Build and push Docker image
+        uses: docker/build-push-action@v6
+        with:
+          context: .
+          push: true
+          tags: ${{ steps.meta.outputs.tags }}
+          labels: ${{ steps.meta.outputs.labels }}
+          cache-from: type=registry,ref=${{ steps.docker_image.outputs.image_repository }}:buildcache
+          cache-to: type=registry,ref=${{ steps.docker_image.outputs.image_repository }}:buildcache,mode=max
+
+# TODO: add job to build and push the helm chart if needed (manual trigger)
diff --git a/cmd/csi-rclone-plugin/main.go b/cmd/csi-rclone-plugin/main.go
index 2a238a7..7f2b37f 100644
--- a/cmd/csi-rclone-plugin/main.go
+++ b/cmd/csi-rclone-plugin/main.go
@@ -2,33 +2,35 @@ package main
 
 import (
 	"context"
+	"errors"
 	"flag"
 	"fmt"
 	"os"
 	"os/signal"
-	"syscall"
 	"time"
 
+	"github.com/SwissDataScienceCenter/csi-rclone/pkg/common"
 	"github.com/SwissDataScienceCenter/csi-rclone/pkg/metrics"
 	"github.com/SwissDataScienceCenter/csi-rclone/pkg/rclone"
 	"github.com/spf13/cobra"
+	"github.com/spf13/pflag"
 	"k8s.io/klog"
-	mountUtils "k8s.io/mount-utils"
 )
 
-var (
-	endpoint  string
-	nodeID    string
-	cacheDir  string
-	cacheSize string
-	meters    []metrics.Observable
-)
+func exitOnError(err error) {
+	// ParseFlags uses errors to return some status information; ignore it here.
+ if err != nil && !errors.Is(err, pflag.ErrHelp) { + klog.Error(err.Error()) + os.Exit(1) + } +} func init() { - flag.Set("logtostderr", "true") + exitOnError(flag.Set("logtostderr", "true")) } func main() { + var meters []metrics.Observable metricsServerConfig := metrics.ServerConfig{ Host: "localhost", Port: 9090, @@ -37,123 +39,49 @@ func main() { ShutdownTimeout: 5 * time.Second, Enabled: false, } + nodeServerConfig := rclone.NodeServerConfig{} + controllerServerConfig := rclone.ControllerServerConfig{} root := &cobra.Command{ Use: "rclone", Short: "CSI based rclone driver", } + // Allow flags to be defined in subcommands, they will be reported at the Execute() step, with the help printed + // before exiting. + root.FParseErrWhitelist.UnknownFlags = true + metricsServerConfig.CommandLineParameters(root) runCmd := &cobra.Command{ Use: "run", Short: "Start the CSI driver.", } - root.AddCommand(runCmd) + exitOnError(nodeServerConfig.CommandLineParameters(runCmd, &meters)) + exitOnError(controllerServerConfig.CommandLineParameters(runCmd, &meters)) - runNode := &cobra.Command{ - Use: "node", - Short: "Start the CSI driver node service - expected to run in a daemonset on every node.", - Run: func(cmd *cobra.Command, args []string) { - handleNode() - }, - } - runNode.PersistentFlags().StringVar(&nodeID, "nodeid", "", "node id") - runNode.MarkPersistentFlagRequired("nodeid") - runNode.PersistentFlags().StringVar(&endpoint, "endpoint", "", "CSI endpoint") - runNode.MarkPersistentFlagRequired("endpoint") - runNode.PersistentFlags().StringVar(&cacheDir, "cachedir", "", "cache dir") - runNode.PersistentFlags().StringVar(&cacheSize, "cachesize", "", "cache size") - runCmd.AddCommand(runNode) - runController := &cobra.Command{ - Use: "controller", - Short: "Start the CSI driver controller.", - Run: func(cmd *cobra.Command, args []string) { - handleController() - }, - } - runController.PersistentFlags().StringVar(&nodeID, "nodeid", "", "node id") - runController.MarkPersistentFlagRequired("nodeid") - runController.PersistentFlags().StringVar(&endpoint, "endpoint", "", "CSI endpoint") - runController.MarkPersistentFlagRequired("endpoint") - runCmd.AddCommand(runController) + root.AddCommand(runCmd) versionCmd := &cobra.Command{ Use: "version", Short: "Prints information about this version of csi rclone plugin", Run: func(cmd *cobra.Command, args []string) { - fmt.Printf("csi-rclone plugin Version: %s", rclone.DriverVersion) + fmt.Printf("csi-rclone plugin Version: %s\n", rclone.DriverVersion) }, } root.AddCommand(versionCmd) - root.ParseFlags(os.Args[1:]) + exitOnError(root.ParseFlags(os.Args[1:])) if metricsServerConfig.Enabled { // Gracefully exit the metrics background servers - ctx, stop := signal.NotifyContext(context.Background(), syscall.SIGTERM, syscall.SIGINT) + ctx, stop := signal.NotifyContext(context.Background(), common.InterruptSignals...) defer stop() metricsServer := metricsServerConfig.NewServer(ctx, &meters) go metricsServer.ListenAndServe() } - if err := root.Execute(); err != nil { - fmt.Fprintf(os.Stderr, "%s", err.Error()) - os.Exit(1) - } + exitOnError(root.Execute()) os.Exit(0) } - -func handleNode() { - err := unmountOldVols() - if err != nil { - klog.Warningf("There was an error when trying to unmount old volumes: %v", err) - } - d := rclone.NewDriver(nodeID, endpoint) - ns, err := rclone.NewNodeServer(d.CSIDriver, cacheDir, cacheSize) - if err != nil { - panic(err) - } - meters = append(meters, ns.Metrics()...) 
- d.WithNodeServer(ns) - err = d.Run() - if err != nil { - panic(err) - } -} - -func handleController() { - d := rclone.NewDriver(nodeID, endpoint) - cs := rclone.NewControllerServer(d.CSIDriver) - meters = append(meters, cs.Metrics()...) - d.WithControllerServer(cs) - err := d.Run() - if err != nil { - panic(err) - } -} - -// unmountOldVols is used to unmount volumes after a restart on a node -func unmountOldVols() error { - const mountType = "fuse.rclone" - const unmountTimeout = time.Second * 5 - klog.Info("Checking for existing mounts") - mounter := mountUtils.Mounter{} - mounts, err := mounter.List() - if err != nil { - return err - } - for _, mount := range mounts { - if mount.Type != mountType { - continue - } - err := mounter.UnmountWithForce(mount.Path, unmountTimeout) - if err != nil { - klog.Warningf("Failed to unmount %s because of %v.", mount.Path, err) - continue - } - klog.Infof("Sucessfully unmounted %s", mount.Path) - } - return nil -} diff --git a/deploy/csi-rclone/templates/csi-controller-rclone.yaml b/deploy/csi-rclone/templates/csi-controller-rclone.yaml index dd13c43..4f65cd8 100644 --- a/deploy/csi-rclone/templates/csi-controller-rclone.yaml +++ b/deploy/csi-rclone/templates/csi-controller-rclone.yaml @@ -54,8 +54,8 @@ spec: image: {{ .Values.csiControllerRclone.csiProvisioner.image.repository }}:{{ .Values.csiControllerRclone.csiProvisioner.image.tag | default .Chart.AppVersion }} imagePullPolicy: {{ .Values.csiControllerRclone.csiProvisioner.imagePullPolicy }} volumeMounts: - - name: socket-dir - mountPath: /csi + - mountPath: /csi + name: socket-dir - name: rclone args: - run @@ -85,7 +85,7 @@ spec: fieldRef: fieldPath: spec.nodeName - name: CSI_ENDPOINT - value: "unix://plugin/csi.sock" + value: "unix://csi/csi.sock" - name: KUBERNETES_CLUSTER_DOMAIN value: {{ quote .Values.kubernetesClusterDomain }} {{- if .Values.csiControllerRclone.rclone.goMemLimit }} @@ -114,7 +114,7 @@ spec: timeoutSeconds: 3 periodSeconds: 2 volumeMounts: - - mountPath: /plugin + - mountPath: /csi name: socket-dir - name: liveness-probe imagePullPolicy: Always diff --git a/deploy/csi-rclone/templates/csi-nodeplugin-rclone.yaml b/deploy/csi-rclone/templates/csi-nodeplugin-rclone.yaml index 670badc..be27eb0 100644 --- a/deploy/csi-rclone/templates/csi-nodeplugin-rclone.yaml +++ b/deploy/csi-rclone/templates/csi-nodeplugin-rclone.yaml @@ -22,7 +22,7 @@ spec: - name: node-driver-registrar args: - --v=5 - - --csi-address=/plugin/csi.sock + - --csi-address=/csi/csi.sock - --kubelet-registration-path=/var/lib/kubelet/plugins/{{ .Values.storageClassName }}/csi.sock env: - name: KUBE_NODE_NAME @@ -45,7 +45,7 @@ spec: resources: {{- toYaml .Values.csiNodepluginRclone.rclone.resources | nindent 12 }} volumeMounts: - - mountPath: /plugin + - mountPath: /csi name: plugin-dir - mountPath: /registration name: registration-dir @@ -53,9 +53,9 @@ spec: imagePullPolicy: Always image: registry.k8s.io/sig-storage/livenessprobe:v2.15.0 args: - - --csi-address=/plugin/csi.sock + - --csi-address=/csi/csi.sock volumeMounts: - - mountPath: /plugin + - mountPath: /csi name: plugin-dir - name: rclone args: @@ -86,7 +86,7 @@ spec: fieldRef: fieldPath: spec.nodeName - name: CSI_ENDPOINT - value: "unix://plugin/csi.sock" + value: "unix://csi/csi.sock" - name: KUBERNETES_CLUSTER_DOMAIN value: {{ quote .Values.kubernetesClusterDomain }} - name: DRIVER_NAME @@ -134,8 +134,10 @@ spec: timeoutSeconds: 10 periodSeconds: 30 volumeMounts: - - mountPath: /plugin + - mountPath: /csi name: plugin-dir + - mountPath: /run/csi-rclone 
+              name: node-temp-dir
             - mountPath: /var/lib/kubelet/pods
               mountPropagation: Bidirectional
               name: pods-mount-dir
@@ -154,6 +156,11 @@ spec:
         {{ toYaml . | nindent 8 }}
         {{- end }}
       volumes:
+        - hostPath:
+            # NOTE: We mount on /tmp because we want the saved configuration to not survive a whole node restart.
+            path: /tmp/{{.Release.Namespace}}-{{.Release.Name}}-{{.Release.Revision}}
+            type: DirectoryOrCreate
+          name: node-temp-dir
         - hostPath:
             path: {{ .Values.kubeletDir }}/plugins/{{ .Values.storageClassName }}
             type: DirectoryOrCreate
           name: plugin-dir
@@ -167,4 +174,4 @@ spec:
            type: DirectoryOrCreate
          name: registration-dir
        - name: cache-dir
-          emptyDir:
+          emptyDir: {}
diff --git a/pkg/common/constants.go b/pkg/common/constants.go
new file mode 100644
index 0000000..a591954
--- /dev/null
+++ b/pkg/common/constants.go
@@ -0,0 +1,11 @@
+package common
+
+import (
+	"os"
+	"syscall"
+)
+
+// Signals to listen to:
+// 1. os.Interrupt -> allows devs to easily run a server locally
+// 2. syscall.SIGTERM -> sent by Kubernetes when stopping a server gracefully
+var InterruptSignals = []os.Signal{os.Interrupt, syscall.SIGTERM}
diff --git a/pkg/rclone/MyGRPCServer.go b/pkg/rclone/MyGRPCServer.go
new file mode 100644
index 0000000..f9278f9
--- /dev/null
+++ b/pkg/rclone/MyGRPCServer.go
@@ -0,0 +1,101 @@
+package rclone
+
+import (
+	"context"
+	"net"
+	"sync"
+
+	"github.com/container-storage-interface/spec/lib/go/csi"
+	"github.com/golang/glog"
+	"github.com/kubernetes-csi/csi-lib-utils/protosanitizer"
+	csicommon "github.com/kubernetes-csi/drivers/pkg/csi-common"
+	"google.golang.org/grpc"
+)
+
+// Override the serve function to keep the CSI socket instead of removing it before use.
+// This is basically a copy-paste of csi-common/server.go with two lines modified.
+// Done as a quick hack: the actual implementation struct is private, so the `serve` function
+// cannot be overridden on its own.
+
+func NewMyGRPCServer() csicommon.NonBlockingGRPCServer {
+	return &myGRPCServer{}
+}
+
+type myGRPCServer struct {
+	wg     sync.WaitGroup
+	server *grpc.Server
+}
+
+func (s *myGRPCServer) Start(endpoint string, ids csi.IdentityServer, cs csi.ControllerServer, ns csi.NodeServer) {
+
+	s.wg.Add(1)
+
+	go s.serve(endpoint, ids, cs, ns)
+
+	return
+}
+
+func (s *myGRPCServer) Wait() {
+	s.wg.Wait()
+}
+
+func (s *myGRPCServer) Stop() {
+	s.server.GracefulStop()
+}
+
+func (s *myGRPCServer) ForceStop() {
+	s.server.Stop()
+}
+
+func logGRPC(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
+	glog.V(3).Infof("GRPC call: %s", info.FullMethod)
+	glog.V(5).Infof("GRPC request: %s", protosanitizer.StripSecrets(req))
+	resp, err := handler(ctx, req)
+	if err != nil {
+		glog.Errorf("GRPC error: %v", err)
+	} else {
+		glog.V(5).Infof("GRPC response: %s", protosanitizer.StripSecrets(resp))
+	}
+	return resp, err
+}
+
+func (s *myGRPCServer) serve(endpoint string, ids csi.IdentityServer, cs csi.ControllerServer, ns csi.NodeServer) {
+
+	proto, addr, err := csicommon.ParseEndpoint(endpoint)
+	if err != nil {
+		glog.Fatal(err.Error())
+	}
+
+	if proto == "unix" {
+		addr = "/" + addr
+		//if err := os.Remove(addr); err != nil && !os.IsNotExist(err) {
+		//	glog.Fatalf("Failed to remove %s, error: %s", addr, err.Error())
+		//}
+	}
+
+	listener, err := net.Listen(proto, addr)
+	if err != nil {
+		glog.Fatalf("Failed to listen: %v", err)
+	}
+
+	opts := []grpc.ServerOption{
+		grpc.UnaryInterceptor(logGRPC),
+	}
+	server := grpc.NewServer(opts...)
+ s.server = server + + if ids != nil { + csi.RegisterIdentityServer(server, ids) + } + if cs != nil { + csi.RegisterControllerServer(server, cs) + } + if ns != nil { + csi.RegisterNodeServer(server, ns) + } + + glog.Infof("Listening for connections on address: %#v", listener.Addr()) + + server.Serve(listener) + +} diff --git a/pkg/rclone/controllerserver.go b/pkg/rclone/controllerserver.go index 4e00dd7..152011f 100644 --- a/pkg/rclone/controllerserver.go +++ b/pkg/rclone/controllerserver.go @@ -3,10 +3,13 @@ package rclone import ( + "context" "sync" + "github.com/SwissDataScienceCenter/csi-rclone/pkg/metrics" "github.com/container-storage-interface/spec/lib/go/csi" - "golang.org/x/net/context" + "github.com/prometheus/client_golang/prometheus" + "github.com/spf13/cobra" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" "k8s.io/klog" @@ -16,13 +19,67 @@ import ( const secretAnnotationName = "csi-rclone.dev/secretName" -type controllerServer struct { +type ControllerServerConfig struct{ DriverConfig } + +type ControllerServer struct { *csicommon.DefaultControllerServer - active_volumes map[string]int64 - mutex sync.RWMutex + activeVolumes map[string]int64 + mutex *sync.RWMutex +} + +func NewControllerServer(csiDriver *csicommon.CSIDriver) *ControllerServer { + return &ControllerServer{ + DefaultControllerServer: csicommon.NewDefaultControllerServer(csiDriver), + activeVolumes: map[string]int64{}, + mutex: &sync.RWMutex{}, + } +} + +func (cs *ControllerServer) metrics() []metrics.Observable { + var meters []metrics.Observable + + meter := prometheus.NewGauge(prometheus.GaugeOpts{ + Name: "csi_rclone_active_volume_count", + Help: "Number of active (Mounted) volumes.", + }) + meters = append(meters, + func() { + cs.mutex.RLock() + defer cs.mutex.RUnlock() + meter.Set(float64(len(cs.activeVolumes))) + }, + ) + prometheus.MustRegister(meter) + + return meters } -func (cs *controllerServer) ValidateVolumeCapabilities(ctx context.Context, req *csi.ValidateVolumeCapabilitiesRequest) (*csi.ValidateVolumeCapabilitiesResponse, error) { +func (config *ControllerServerConfig) CommandLineParameters(runCmd *cobra.Command, meters *[]metrics.Observable) error { + runController := &cobra.Command{ + Use: "controller", + Short: "Start the CSI driver controller.", + RunE: func(cmd *cobra.Command, args []string) error { + ctx := cmd.Context() + return Run(ctx, + &config.DriverConfig, + func(csiDriver *csicommon.CSIDriver) (*ControllerServer, *NodeServer, error) { + cs := NewControllerServer(csiDriver) + *meters = append(*meters, cs.metrics()...) 
+ return cs, nil, nil + }, + func(_ context.Context, cs *ControllerServer, ns *NodeServer) error { return nil }, + ) + }, + } + if err := config.DriverConfig.CommandLineParameters(runController); err != nil { + return err + } + + runCmd.AddCommand(runController) + return nil +} + +func (cs *ControllerServer) ValidateVolumeCapabilities(_ context.Context, req *csi.ValidateVolumeCapabilitiesRequest) (*csi.ValidateVolumeCapabilitiesResponse, error) { volId := req.GetVolumeId() if len(volId) == 0 { return nil, status.Error(codes.InvalidArgument, "ValidateVolumeCapabilities must be provided volume id") @@ -31,9 +88,9 @@ func (cs *controllerServer) ValidateVolumeCapabilities(ctx context.Context, req return nil, status.Error(codes.InvalidArgument, "ValidateVolumeCapabilities without capabilities") } - cs.mutex.Lock() - defer cs.mutex.Unlock() - if _, ok := cs.active_volumes[volId]; !ok { + cs.mutex.RLock() + defer cs.mutex.RUnlock() + if _, ok := cs.activeVolumes[volId]; !ok { return nil, status.Errorf(codes.NotFound, "Volume %s not found", volId) } return &csi.ValidateVolumeCapabilitiesResponse{ @@ -45,18 +102,18 @@ func (cs *controllerServer) ValidateVolumeCapabilities(ctx context.Context, req }, nil } -// Attaching Volume -func (cs *controllerServer) ControllerPublishVolume(ctx context.Context, req *csi.ControllerPublishVolumeRequest) (*csi.ControllerPublishVolumeResponse, error) { +// ControllerPublishVolume Attaching Volume +func (cs *ControllerServer) ControllerPublishVolume(_ context.Context, _ *csi.ControllerPublishVolumeRequest) (*csi.ControllerPublishVolumeResponse, error) { return nil, status.Errorf(codes.Unimplemented, "method ControllerPublishVolume not implemented") } -// Detaching Volume -func (cs *controllerServer) ControllerUnpublishVolume(ctx context.Context, req *csi.ControllerUnpublishVolumeRequest) (*csi.ControllerUnpublishVolumeResponse, error) { +// ControllerUnpublishVolume Detaching Volume +func (cs *ControllerServer) ControllerUnpublishVolume(_ context.Context, _ *csi.ControllerUnpublishVolumeRequest) (*csi.ControllerUnpublishVolumeResponse, error) { return nil, status.Errorf(codes.Unimplemented, "method ControllerUnpublishVolume not implemented") } -// Provisioning Volumes -func (cs *controllerServer) CreateVolume(ctx context.Context, req *csi.CreateVolumeRequest) (*csi.CreateVolumeResponse, error) { +// CreateVolume Provisioning Volumes +func (cs *ControllerServer) CreateVolume(ctx context.Context, req *csi.CreateVolumeRequest) (*csi.CreateVolumeResponse, error) { klog.Infof("ControllerCreateVolume: called with args %+v", *req) volumeName := req.GetName() if len(volumeName) == 0 { @@ -70,18 +127,18 @@ func (cs *controllerServer) CreateVolume(ctx context.Context, req *csi.CreateVol // we don't use the size as it makes no sense for rclone. 
but csi drivers should succeed if // called twice with the same capacity for the same volume and fail if called twice with // differing capacity, so we need to remember it - volSizeBytes := int64(req.GetCapacityRange().GetRequiredBytes()) + volSizeBytes := req.GetCapacityRange().GetRequiredBytes() cs.mutex.Lock() defer cs.mutex.Unlock() - if val, ok := cs.active_volumes[volumeName]; ok && val != volSizeBytes { + if val, ok := cs.activeVolumes[volumeName]; ok && val != volSizeBytes { return nil, status.Errorf(codes.AlreadyExists, "Volume operation already exists for volume %s", volumeName) } - cs.active_volumes[volumeName] = volSizeBytes + cs.activeVolumes[volumeName] = volSizeBytes // See https://github.com/kubernetes-csi/external-provisioner/blob/v5.1.0/pkg/controller/controller.go#L75 // on how parameters from the persistent volume are parsed // We have to pass the secret name and namespace into the context so that the node server can use them - // The external provisioner uses the secret name and namespace but it does not pass them into the request, + // The external provisioner uses the secret name and namespace, but it does not pass them into the request, // so we read the PVC here to extract them ourselves because we may need them in the node server for decoding secrets. pvcName, pvcNameFound := req.Parameters["csi.storage.k8s.io/pvc/name"] pvcNamespace, pvcNamespaceFound := req.Parameters["csi.storage.k8s.io/pvc/namespace"] @@ -114,29 +171,28 @@ func (cs *controllerServer) CreateVolume(ctx context.Context, req *csi.CreateVol } -// Delete Volume -func (cs *controllerServer) DeleteVolume(ctx context.Context, req *csi.DeleteVolumeRequest) (*csi.DeleteVolumeResponse, error) { +func (cs *ControllerServer) DeleteVolume(_ context.Context, req *csi.DeleteVolumeRequest) (*csi.DeleteVolumeResponse, error) { volId := req.GetVolumeId() if len(volId) == 0 { - return nil, status.Error(codes.InvalidArgument, "DeteleVolume must be provided volume id") + return nil, status.Error(codes.InvalidArgument, "DeleteVolume must be provided volume id") } cs.mutex.Lock() defer cs.mutex.Unlock() - delete(cs.active_volumes, volId) + delete(cs.activeVolumes, volId) return &csi.DeleteVolumeResponse{}, nil } -func (*controllerServer) ControllerExpandVolume(ctx context.Context, req *csi.ControllerExpandVolumeRequest) (*csi.ControllerExpandVolumeResponse, error) { +func (*ControllerServer) ControllerExpandVolume(_ context.Context, _ *csi.ControllerExpandVolumeRequest) (*csi.ControllerExpandVolumeResponse, error) { return nil, status.Errorf(codes.Unimplemented, "method ControllerExpandVolume not implemented") } -func (cs *controllerServer) ControllerGetVolume(ctx context.Context, req *csi.ControllerGetVolumeRequest) (*csi.ControllerGetVolumeResponse, error) { +func (cs *ControllerServer) ControllerGetVolume(_ context.Context, req *csi.ControllerGetVolumeRequest) (*csi.ControllerGetVolumeResponse, error) { return &csi.ControllerGetVolumeResponse{Volume: &csi.Volume{ VolumeId: req.VolumeId, }}, nil } -func (cs *controllerServer) ControllerModifyVolume(ctx context.Context, req *csi.ControllerModifyVolumeRequest) (*csi.ControllerModifyVolumeResponse, error) { +func (cs *ControllerServer) ControllerModifyVolume(_ context.Context, _ *csi.ControllerModifyVolumeRequest) (*csi.ControllerModifyVolumeResponse, error) { return &csi.ControllerModifyVolumeResponse{}, nil } diff --git a/pkg/rclone/driver.go b/pkg/rclone/driver.go index 60fd38b..08d9adf 100644 --- a/pkg/rclone/driver.go +++ b/pkg/rclone/driver.go @@ -1,157 +1,69 @@ 
package rclone import ( - "fmt" - "net" + "context" + "errors" "os" - "sync" - "github.com/SwissDataScienceCenter/csi-rclone/pkg/kube" - "github.com/SwissDataScienceCenter/csi-rclone/pkg/metrics" "github.com/container-storage-interface/spec/lib/go/csi" csicommon "github.com/kubernetes-csi/drivers/pkg/csi-common" - "github.com/prometheus/client_golang/prometheus" + "github.com/spf13/cobra" "k8s.io/klog" - "k8s.io/utils/mount" - - utilexec "k8s.io/utils/exec" ) -type Driver struct { - CSIDriver *csicommon.CSIDriver - endpoint string +const DriverVersion = "SwissDataScienceCenter" - ns *nodeServer - cs *controllerServer - cap []*csi.VolumeCapability_AccessMode - cscap []*csi.ControllerServiceCapability - server csicommon.NonBlockingGRPCServer -} +type DriverSetup func(csiDriver *csicommon.CSIDriver) (*ControllerServer, *NodeServer, error) -var ( - DriverVersion = "SwissDataScienceCenter" -) +type DriverServe func(ctx context.Context, cs *ControllerServer, ns *NodeServer) error + +type DriverConfig struct { + Endpoint string + NodeID string +} -func getFreePort() (port int, err error) { - var a *net.TCPAddr - if a, err = net.ResolveTCPAddr("tcp", "localhost:0"); err == nil { - var l *net.TCPListener - if l, err = net.ListenTCP("tcp", a); err == nil { - defer l.Close() - return l.Addr().(*net.TCPAddr).Port, nil - } +func (config *DriverConfig) CommandLineParameters(runCmd *cobra.Command) error { + runCmd.PersistentFlags().StringVar(&config.NodeID, "nodeid", config.NodeID, "node id") + if err := runCmd.MarkPersistentFlagRequired("nodeid"); err != nil { + return err + } + runCmd.PersistentFlags().StringVar(&config.Endpoint, "endpoint", config.Endpoint, "CSI endpoint") + if err := runCmd.MarkPersistentFlagRequired("endpoint"); err != nil { + return err } - return + return nil } -func NewDriver(nodeID, endpoint string) *Driver { +func Run(ctx context.Context, config *DriverConfig, setup DriverSetup, serve DriverServe) error { driverName := os.Getenv("DRIVER_NAME") if driverName == "" { - panic("DriverName env var not set!") + return errors.New("DRIVER_NAME env variable not set or empty") } klog.Infof("Starting new %s RcloneDriver in version %s", driverName, DriverVersion) - d := &Driver{} - d.endpoint = endpoint - - d.CSIDriver = csicommon.NewCSIDriver(driverName, DriverVersion, nodeID) - d.CSIDriver.AddVolumeCapabilityAccessModes([]csi.VolumeCapability_AccessMode_Mode{ + driver := csicommon.NewCSIDriver(driverName, DriverVersion, config.NodeID) + driver.AddVolumeCapabilityAccessModes([]csi.VolumeCapability_AccessMode_Mode{ csi.VolumeCapability_AccessMode_SINGLE_NODE_SINGLE_WRITER, }) - d.CSIDriver.AddControllerServiceCapabilities( + driver.AddControllerServiceCapabilities( []csi.ControllerServiceCapability_RPC_Type{ csi.ControllerServiceCapability_RPC_CREATE_DELETE_VOLUME, }) - return d -} - -func NewNodeServer(csiDriver *csicommon.CSIDriver, cacheDir string, cacheSize string) (*nodeServer, error) { - kubeClient, err := kube.GetK8sClient() - if err != nil { - return nil, err - } - - rclonePort, err := getFreePort() - if err != nil { - return nil, fmt.Errorf("Cannot get a free TCP port to run rclone") + is := csicommon.NewDefaultIdentityServer(driver) + cs, ns, setupErr := setup(driver) + if setupErr != nil { + return setupErr } - rcloneOps := NewRclone(kubeClient, rclonePort, cacheDir, cacheSize) - return &nodeServer{ - DefaultNodeServer: csicommon.NewDefaultNodeServer(csiDriver), - mounter: &mount.SafeFormatAndMount{ - Interface: mount.New(""), - Exec: utilexec.New(), - }, - RcloneOps: rcloneOps, 
- }, nil -} + s := NewMyGRPCServer() + defer s.Stop() + s.Start(config.Endpoint, is, cs, ns) -func NewControllerServer(csiDriver *csicommon.CSIDriver) *controllerServer { - return &controllerServer{ - DefaultControllerServer: csicommon.NewDefaultControllerServer(csiDriver), - active_volumes: map[string]int64{}, - mutex: sync.RWMutex{}, + if err := serve(ctx, cs, ns); err != nil { + return err } -} - -func (ns *nodeServer) Metrics() []metrics.Observable { - var meters []metrics.Observable - - // What should we meter? - return meters -} - -func (cs *controllerServer) Metrics() []metrics.Observable { - var meters []metrics.Observable - - meter := prometheus.NewGauge(prometheus.GaugeOpts{ - Name: "csi_rclone_active_volume_count", - Help: "Number of active (Mounted) volumes.", - }) - meters = append(meters, - func() { meter.Set(float64(len(cs.active_volumes))) }, - ) - prometheus.MustRegister(meter) - - return meters -} - -func (d *Driver) WithNodeServer(ns *nodeServer) *Driver { - d.ns = ns - return d -} - -func (d *Driver) WithControllerServer(cs *controllerServer) *Driver { - d.cs = cs - return d -} - -func (d *Driver) Run() error { - s := csicommon.NewNonBlockingGRPCServer() - s.Start( - d.endpoint, - csicommon.NewDefaultIdentityServer(d.CSIDriver), - d.cs, - d.ns, - ) - d.server = s - if d.ns != nil && d.ns.RcloneOps != nil { - return d.ns.RcloneOps.Run() - } s.Wait() return nil } - -func (d *Driver) Stop() error { - var err error - if d.ns != nil && d.ns.RcloneOps != nil { - err = d.ns.RcloneOps.Cleanup() - } - if d.server != nil { - d.server.Stop() - } - return err -} diff --git a/pkg/rclone/nodeserver.go b/pkg/rclone/nodeserver.go index ef03501..830f232 100644 --- a/pkg/rclone/nodeserver.go +++ b/pkg/rclone/nodeserver.go @@ -7,24 +7,33 @@ package rclone import ( "bytes" + "encoding/json" "errors" "fmt" + "net" "os" + "path/filepath" + "runtime" "strings" + "sync" "time" - "gopkg.in/ini.v1" - v1 "k8s.io/api/core/v1" - apierrors "k8s.io/apimachinery/pkg/api/errors" - metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" - "k8s.io/klog" - "github.com/SwissDataScienceCenter/csi-rclone/pkg/kube" + "github.com/SwissDataScienceCenter/csi-rclone/pkg/metrics" "github.com/container-storage-interface/spec/lib/go/csi" "github.com/fernet/fernet-go" + "github.com/prometheus/client_golang/prometheus" + "github.com/spf13/cobra" "golang.org/x/net/context" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" + "gopkg.in/ini.v1" + v1 "k8s.io/api/core/v1" + apierrors "k8s.io/apimachinery/pkg/api/errors" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/klog" + mountutils "k8s.io/mount-utils" + "k8s.io/utils/exec" "k8s.io/utils/mount" csicommon "github.com/kubernetes-csi/drivers/pkg/csi-common" @@ -33,23 +42,190 @@ import ( const CSI_ANNOTATION_PREFIX = "csi-rclone.dev" const pvcSecretNameAnnotation = CSI_ANNOTATION_PREFIX + "/secretName" -type nodeServer struct { +type NodeServer struct { *csicommon.DefaultNodeServer mounter *mount.SafeFormatAndMount RcloneOps Operations + + // Track mounted volumes for automatic remounting + mountedVolumes map[string]MountedVolume + mutex *sync.Mutex + stateFile string +} + +// unmountOldVols is used to unmount volumes after a restart on a node +func unmountOldVols() error { + const mountType = "fuse.rclone" + const unmountTimeout = time.Second * 5 + klog.Info("Checking for existing mounts") + mounter := mountutils.Mounter{} + mounts, err := mounter.List() + if err != nil { + return err + } + for _, mount := range mounts { + if mount.Type != mountType { + 
+			continue
+		}
+		err := mounter.UnmountWithForce(mount.Path, unmountTimeout)
+		if err != nil {
+			klog.Warningf("Failed to unmount %s because of %v.", mount.Path, err)
+			continue
+		}
+		klog.Infof("Successfully unmounted %s", mount.Path)
+	}
+	return nil
+}
+
+func getFreePort() (port int, err error) {
+	var a *net.TCPAddr
+	if a, err = net.ResolveTCPAddr("tcp", "localhost:0"); err == nil {
+		var l *net.TCPListener
+		if l, err = net.ListenTCP("tcp", a); err == nil {
+			defer l.Close()
+			return l.Addr().(*net.TCPAddr).Port, nil
+		}
+	}
+	return
+}
+
+func NewNodeServer(csiDriver *csicommon.CSIDriver, cacheDir string, cacheSize string) (*NodeServer, error) {
+	err := unmountOldVols()
+	if err != nil {
+		klog.Warningf("There was an error when trying to unmount old volumes: %v", err)
+		return nil, err
+	}
+
+	kubeClient, err := kube.GetK8sClient()
+	if err != nil {
+		return nil, err
+	}
+
+	rclonePort, err := getFreePort()
+	if err != nil {
+		return nil, fmt.Errorf("Cannot get a free TCP port to run rclone")
+	}
+
+	ns := &NodeServer{
+		DefaultNodeServer: csicommon.NewDefaultNodeServer(csiDriver),
+		mounter: &mount.SafeFormatAndMount{
+			Interface: mount.New(""),
+			Exec:      exec.New(),
+		},
+		RcloneOps:      NewRclone(kubeClient, rclonePort, cacheDir, cacheSize),
+		mountedVolumes: make(map[string]MountedVolume),
+		mutex:          &sync.Mutex{},
+		stateFile:      "/run/csi-rclone/mounted_volumes.json",
+	}
+
+	// Ensure the folder exists
+	if err = os.MkdirAll(filepath.Dir(ns.stateFile), 0755); err != nil {
+		return nil, fmt.Errorf("failed to create state directory: %v", err)
+	}
+
+	// Load persisted state on startup
+	ns.mutex.Lock()
+	defer ns.mutex.Unlock()
+
+	if ns.mountedVolumes, err = readVolumeMap(ns.stateFile); err != nil {
+		klog.Warningf("Failed to load persisted volume state: %v", err)
+	}
+
+	return ns, nil
+}
+
+func (ns *NodeServer) Run(ctx context.Context) error {
+	defer ns.Stop()
+	return ns.RcloneOps.Run(ctx, func() error {
+		return ns.remountTrackedVolumes(ctx)
+	})
+}
+
+func (ns *NodeServer) metrics() []metrics.Observable {
+	var meters []metrics.Observable
+
+	meter := prometheus.NewGauge(prometheus.GaugeOpts{
+		Name: "csi_rclone_active_volume_count",
+		Help: "Number of active (Mounted) volumes.",
+	})
+	meters = append(meters,
+		func() { meter.Set(float64(len(ns.mountedVolumes))) },
+	)
+	prometheus.MustRegister(meter)
+
+	return meters
+}
+
+func (ns *NodeServer) Stop() {
+	if err := ns.RcloneOps.Cleanup(); err != nil {
+		klog.Errorf("Failed to cleanup rclone ops: %v", err)
+	}
+}
+
+type NodeServerConfig struct {
+	DriverConfig
+	CacheDir  string
+	CacheSize string
+}
+
+func (config *NodeServerConfig) CommandLineParameters(runCmd *cobra.Command, meters *[]metrics.Observable) error {
+	runNode := &cobra.Command{
+		Use:   "node",
+		Short: "Start the CSI driver node service - expected to run in a daemonset on every node.",
+		RunE: func(cmd *cobra.Command, args []string) error {
+			ctx := cmd.Context()
+			return Run(ctx, &config.DriverConfig,
+				func(csiDriver *csicommon.CSIDriver) (*ControllerServer, *NodeServer, error) {
+					ns, err := NewNodeServer(csiDriver, config.CacheDir, config.CacheSize)
+					if err != nil {
+						return nil, nil, err
+					}
+					*meters = append(*meters, ns.metrics()...)
+ return nil, ns, err + }, + func(ctx context.Context, cs *ControllerServer, ns *NodeServer) error { + if ns == nil { + return errors.New("node server uninitialized") + } + return ns.Run(ctx) + }, + ) + }, + } + if err := config.DriverConfig.CommandLineParameters(runNode); err != nil { + return err + } + + runNode.PersistentFlags().StringVar(&config.CacheDir, "cachedir", config.CacheDir, "cache dir") + runNode.PersistentFlags().StringVar(&config.CacheSize, "cachesize", config.CacheSize, "cache size") + + runCmd.AddCommand(runNode) + return nil +} + +type MountedVolume struct { + VolumeId string `json:"volume_id"` + TargetPath string `json:"target_path"` + Remote string `json:"remote"` + RemotePath string `json:"remote_path"` + ConfigData string `json:"config_data"` + ReadOnly bool `json:"read_only"` + Parameters map[string]string `json:"parameters"` + SecretName string `json:"secret_name"` + SecretNamespace string `json:"secret_namespace"` } // Mounting Volume (Preparation) -func (ns *nodeServer) NodeStageVolume(ctx context.Context, req *csi.NodeStageVolumeRequest) (*csi.NodeStageVolumeResponse, error) { +func (ns *NodeServer) NodeStageVolume(ctx context.Context, req *csi.NodeStageVolumeRequest) (*csi.NodeStageVolumeResponse, error) { return nil, status.Errorf(codes.Unimplemented, "method NodeStageVolume not implemented") } -func (ns *nodeServer) NodeUnstageVolume(ctx context.Context, req *csi.NodeUnstageVolumeRequest) (*csi.NodeUnstageVolumeResponse, error) { +func (ns *NodeServer) NodeUnstageVolume(ctx context.Context, req *csi.NodeUnstageVolumeRequest) (*csi.NodeUnstageVolumeResponse, error) { return nil, status.Errorf(codes.Unimplemented, "method NodeUnstageVolume not implemented") } // Mounting Volume (Actual Mounting) -func (ns *nodeServer) NodePublishVolume(ctx context.Context, req *csi.NodePublishVolumeRequest) (*csi.NodePublishVolumeResponse, error) { +func (ns *NodeServer) NodePublishVolume(ctx context.Context, req *csi.NodePublishVolumeRequest) (*csi.NodePublishVolumeResponse, error) { if err := validatePublishVolumeRequest(req); err != nil { return nil, err } @@ -141,6 +317,10 @@ func (ns *nodeServer) NodePublishVolume(ctx context.Context, req *csi.NodePublis } return nil, status.Error(codes.Internal, err.Error()) } + + // Track the mounted volume for automatic remounting + ns.trackMountedVolume(volumeId, targetPath, remote, remotePath, configData, readOnly, flags, secretName, secretNamespace) + // err = ns.WaitForMountAvailable(targetPath) // if err != nil { // return nil, status.Error(codes.Internal, err.Error()) @@ -303,7 +483,7 @@ func extractConfigData(parameters map[string]string) (string, map[string]string) } // Unmounting Volumes -func (ns *nodeServer) NodeUnpublishVolume(ctx context.Context, req *csi.NodeUnpublishVolumeRequest) (*csi.NodeUnpublishVolumeResponse, error) { +func (ns *NodeServer) NodeUnpublishVolume(ctx context.Context, req *csi.NodeUnpublishVolumeRequest) (*csi.NodeUnpublishVolumeResponse, error) { klog.Infof("NodeUnpublishVolume called with: %s", req) if err := validateUnPublishVolumeRequest(req); err != nil { return nil, err @@ -323,6 +503,10 @@ func (ns *nodeServer) NodeUnpublishVolume(ctx context.Context, req *csi.NodeUnpu if err := ns.RcloneOps.Unmount(ctx, req.GetVolumeId(), targetPath); err != nil { klog.Warningf("Unmounting volume failed: %s", err) } + + // Remove the volume from tracking + ns.removeTrackedVolume(req.GetVolumeId()) + mount.CleanupMountPoint(req.GetTargetPath(), ns.mounter, false) return &csi.NodeUnpublishVolumeResponse{}, nil } 
@@ -340,11 +524,128 @@ func validateUnPublishVolumeRequest(req *csi.NodeUnpublishVolumeRequest) error { } // Resizing Volume -func (*nodeServer) NodeExpandVolume(ctx context.Context, req *csi.NodeExpandVolumeRequest) (*csi.NodeExpandVolumeResponse, error) { +func (*NodeServer) NodeExpandVolume(ctx context.Context, req *csi.NodeExpandVolumeRequest) (*csi.NodeExpandVolumeResponse, error) { return nil, status.Errorf(codes.Unimplemented, "method NodeExpandVolume not implemented") } -func (ns *nodeServer) WaitForMountAvailable(mountpoint string) error { +// Track mounted volume for automatic remounting +func (ns *NodeServer) trackMountedVolume(volumeId, targetPath, remote, remotePath, configData string, readOnly bool, parameters map[string]string, secretName, secretNamespace string) { + ns.mutex.Lock() + defer ns.mutex.Unlock() + + ns.mountedVolumes[volumeId] = MountedVolume{ + VolumeId: volumeId, + TargetPath: targetPath, + Remote: remote, + RemotePath: remotePath, + ConfigData: configData, + ReadOnly: readOnly, + Parameters: parameters, + SecretName: secretName, + SecretNamespace: secretNamespace, + } + klog.Infof("Tracked mounted volume %s at path %s", volumeId, targetPath) + + if err := writeVolumeMap(ns.stateFile, ns.mountedVolumes); err != nil { + klog.Errorf("Failed to persist volume state: %v", err) + } +} + +// Remove tracked volume when unmounted +func (ns *NodeServer) removeTrackedVolume(volumeId string) { + ns.mutex.Lock() + defer ns.mutex.Unlock() + + delete(ns.mountedVolumes, volumeId) + klog.Infof("Removed tracked volume %s", volumeId) + + if err := writeVolumeMap(ns.stateFile, ns.mountedVolumes); err != nil { + klog.Errorf("Failed to persist volume state: %v", err) + } +} + +// Automatically remount all tracked volumes after daemon restart +func (ns *NodeServer) remountTrackedVolumes(ctx context.Context) error { + type mountResult struct { + volumeID string + err error + } + + ns.mutex.Lock() + defer ns.mutex.Unlock() + + volumesCount := len(ns.mountedVolumes) + + if volumesCount == 0 { + klog.Info("No tracked volumes to remount") + return nil + } + + klog.Infof("Remounting %d tracked volumes", volumesCount) + + // Limit the number of active workers to the number of CPU threads (arbitrarily chosen) + limits := make(chan bool, runtime.GOMAXPROCS(0)) + defer close(limits) + + results := make(chan mountResult, volumesCount) + defer close(results) + + ctxWithTimeout, cancel := context.WithTimeout(ctx, 60*time.Second) + defer cancel() + + for volumeId, mv := range ns.mountedVolumes { + go func() { + limits <- true // block until there is a free slot in the queue + defer func() { + <-limits // free a slot in the queue when we exit + }() + + ctxWithMountTimeout, cancel := context.WithTimeout(ctxWithTimeout, 30*time.Second) + defer cancel() + + klog.Infof("Remounting volume %s to %s", volumeId, mv.TargetPath) + + // Create the mount directory if it doesn't exist + var err error + if err = os.MkdirAll(mv.TargetPath, 0750); err != nil { + klog.Errorf("Failed to create mount directory %s: %v", mv.TargetPath, err) + } else { + // Remount the volume + rcloneVol := &RcloneVolume{ + ID: mv.VolumeId, + Remote: mv.Remote, + RemotePath: mv.RemotePath, + } + + err = ns.RcloneOps.Mount(ctxWithMountTimeout, rcloneVol, mv.TargetPath, mv.ConfigData, mv.ReadOnly, mv.Parameters) + } + + results <- mountResult{volumeId, err} + }() + } + + for { + select { + case result := <-results: + volumesCount-- + if result.err != nil { + klog.Errorf("Failed to remount volume %s: %v", result.volumeID, result.err) + 
// Don't return error here, continue with other volumes not to block all users because of a failed mount. + delete(ns.mountedVolumes, result.volumeID) + // Should we keep it on disk? This will be lost on the first new mount which will override the file. + } else { + klog.Infof("Successfully remounted volume %s", result.volumeID) + } + if volumesCount == 0 { + return nil + } + case <-ctxWithTimeout.Done(): + return ctxWithTimeout.Err() + } + } +} + +func (ns *NodeServer) WaitForMountAvailable(mountpoint string) error { for { select { case <-time.After(100 * time.Millisecond): @@ -357,3 +658,47 @@ func (ns *nodeServer) WaitForMountAvailable(mountpoint string) error { } } } + +// Persist volume state to disk +func writeVolumeMap(filename string, volumes map[string]MountedVolume) error { + if filename == "" { + return nil + } + + data, err := json.Marshal(volumes) + if err != nil { + return fmt.Errorf("failed to marshal volume state: %v", err) + } + + if err := os.WriteFile(filename, data, 0600); err != nil { + return fmt.Errorf("failed to write state file: %v", err) + } + + klog.Infof("Persisted volume state to %s", filename) + return nil +} + +// Load volume state from disk +func readVolumeMap(filename string) (map[string]MountedVolume, error) { + volumes := make(map[string]MountedVolume) + + if filename == "" { + return volumes, nil + } + + data, err := os.ReadFile(filename) + if err != nil { + if os.IsNotExist(err) { + klog.Info("No persisted volume state found, starting fresh") + return volumes, nil + } + return volumes, fmt.Errorf("failed to read state file: %v", err) + } + + if err := json.Unmarshal(data, &volumes); err != nil { + return nil, fmt.Errorf("failed to unmarshal volume state: %v", err) + } + + klog.Infof("Loaded %d tracked volumes from %s", len(volumes), filename) + return volumes, nil +} diff --git a/pkg/rclone/rclone.go b/pkg/rclone/rclone.go index 9b32508..31a35a5 100644 --- a/pkg/rclone/rclone.go +++ b/pkg/rclone/rclone.go @@ -11,6 +11,7 @@ import ( "os" os_exec "os/exec" "syscall" + "time" "strings" @@ -34,7 +35,7 @@ type Operations interface { Unmount(ctx context.Context, volumeId string, targetPath string) error GetVolumeById(ctx context.Context, volumeId string) (*RcloneVolume, error) Cleanup() error - Run() error + Run(ctx context.Context, onDaemonReady func() error) error } type Rclone struct { @@ -160,7 +161,11 @@ func (r *Rclone) Mount(ctx context.Context, rcloneVolume *RcloneVolume, targetPa return fmt.Errorf("mounting failed: couldn't create request body: %s", err) } requestBody := bytes.NewBuffer(postBody) - resp, err := http.Post(fmt.Sprintf("http://localhost:%d/config/create", r.port), "application/json", requestBody) + req, err := createRcloneRequest(ctx, http.MethodPost, requestBody, "/config/create", r.port) + if err != nil { + return fmt.Errorf("mounting failed: cannot create a request for rclone config creation: %w", err) + } + resp, err := http.DefaultClient.Do(req) if err != nil { return fmt.Errorf("mounting failed: couldn't send HTTP request to create config: %w", err) } @@ -218,7 +223,11 @@ func (r *Rclone) Mount(ctx context.Context, rcloneVolume *RcloneVolume, targetPa } klog.Infof("executing mount command args=%s", string(postBody)) requestBody = bytes.NewBuffer(postBody) - resp, err = http.Post(fmt.Sprintf("http://localhost:%d/mount/mount", r.port), "application/json", requestBody) + req, err = createRcloneRequest(ctx, http.MethodPost, requestBody, "/mount/mount", r.port) + if err != nil { + return fmt.Errorf("mounting failed: cannot create a 
request for rclone mounting: %w", err) + } + resp, err = http.DefaultClient.Do(req) if err != nil { return fmt.Errorf("mounting failed: couldn't send HTTP request to create mount: %w", err) } @@ -249,7 +258,7 @@ func (r *Rclone) CreateVol(ctx context.Context, volumeName, remote, remotePath, } flags["config"] = rcloneConfigPath - return r.command("mkdir", remote, path, flags) + return r.command(ctx, "mkdir", remote, path, flags) } func (r Rclone) DeleteVol(ctx context.Context, rcloneVolume *RcloneVolume, rcloneConfigPath string, parameters map[string]string) error { @@ -258,7 +267,7 @@ func (r Rclone) DeleteVol(ctx context.Context, rcloneVolume *RcloneVolume, rclon flags[key] = value } flags["config"] = rcloneConfigPath - return r.command("purge", rcloneVolume.Remote, rcloneVolume.RemotePath, flags) + return r.command(ctx, "purge", rcloneVolume.Remote, rcloneVolume.RemotePath, flags) } func (r Rclone) Unmount(ctx context.Context, volumeId string, targetPath string) error { @@ -273,7 +282,11 @@ func (r Rclone) Unmount(ctx context.Context, volumeId string, targetPath string) return fmt.Errorf("unmounting failed: couldn't create request body: %s", err) } requestBody := bytes.NewBuffer(postBody) - resp, err := http.Post(fmt.Sprintf("http://localhost:%d/mount/unmount", r.port), "application/json", requestBody) + req, err := createRcloneRequest(ctx, http.MethodPost, requestBody, "/mount/unmount", r.port) + if err != nil { + return fmt.Errorf("unmounting failed: couldn't create a request for rclone: %w", err) + } + resp, err := http.DefaultClient.Do(req) if err != nil { return fmt.Errorf("unmounting failed: couldn't send HTTP request: %w", err) } @@ -291,7 +304,11 @@ func (r Rclone) Unmount(ctx context.Context, volumeId string, targetPath string) return fmt.Errorf("deleting config failed: couldn't create request body: %s", err) } requestBody = bytes.NewBuffer(postBody) - resp, err = http.Post(fmt.Sprintf("http://localhost:%d/config/delete", r.port), "application/json", requestBody) + req, err = createRcloneRequest(ctx, http.MethodPost, requestBody, "/config/delete", r.port) + if err != nil { + return fmt.Errorf("unmounting failed: couldn't create a request for rclone configuration deletion: %w", err) + } + resp, err = http.DefaultClient.Do(req) if err != nil { klog.Errorf("deleting config failed: couldn't send HTTP request: %v", err) return nil @@ -409,7 +426,26 @@ func checkResponse(resp *http.Response) error { return fmt.Errorf("received error from the rclone server: %s", result.String()) } -func (r *Rclone) start_daemon() error { +func waitForDaemon(ctx context.Context, port int) error { + // Wait for the daemon to have started + ctxWaitRcloneStart, cancel := context.WithTimeout(ctx, 1*time.Second) + defer cancel() + + req, err := createRcloneRequest(ctxWaitRcloneStart, http.MethodPost, nil, "/core/version", port) + if err != nil { + return err + } + + for { + _, err = http.DefaultClient.Do(req) + // Keep trying until we retrieve a response, or we hit the deadline + if err == nil || errors.Is(err, context.DeadlineExceeded) { + return err + } + } +} + +func (r *Rclone) start_daemon(ctx context.Context) error { f, err := os.CreateTemp("", "rclone.conf") if err != nil { return err @@ -449,7 +485,7 @@ func (r *Rclone) start_daemon() error { klog.Infof("running rclone remote control daemon cmd=%s, args=%s", rclone_cmd, rclone_args) env := os.Environ() - cmd := os_exec.Command(rclone_cmd, rclone_args...) + cmd := os_exec.CommandContext(ctx, rclone_cmd, rclone_args...) 
cmd.SysProcAttr = &syscall.SysProcAttr{Setpgid: true} stdout, err := cmd.StdoutPipe() cmd.Stderr = cmd.Stdout @@ -461,6 +497,11 @@ func (r *Rclone) start_daemon() error { if err := cmd.Start(); err != nil { return err } + + if err := waitForDaemon(ctx, r.port); err != nil { + return err + } + go func() { output := "" for scanner.Scan() { @@ -472,11 +513,16 @@ func (r *Rclone) start_daemon() error { return nil } -func (r *Rclone) Run() error { - err := r.start_daemon() +func (r *Rclone) Run(ctx context.Context, onDaemonReady func() error) error { + err := r.start_daemon(ctx) if err != nil { return err } + if onDaemonReady != nil { + if err := onDaemonReady(); err != nil { + klog.Warningf("Error in onDaemonReady callback: %v", err) + } + } // blocks until the rclone daemon is stopped return r.daemonCmd.Wait() } @@ -489,7 +535,7 @@ func (r *Rclone) Cleanup() error { return r.daemonCmd.Process.Kill() } -func (r *Rclone) command(cmd, remote, remotePath string, flags map[string]string) error { +func (r *Rclone) command(ctx context.Context, cmd, remote, remotePath string, flags map[string]string) error { // rclone remote:path [flag] args := append( []string{}, @@ -503,7 +549,7 @@ func (r *Rclone) command(cmd, remote, remotePath string, flags map[string]string } klog.Infof("executing %s command cmd=rclone, remote=%s:%s", cmd, remote, remotePath) - out, err := r.execute.Command("rclone", args...).CombinedOutput() + out, err := r.execute.CommandContext(ctx, "rclone", args...).CombinedOutput() if err != nil { return fmt.Errorf("%s failed: %v cmd: 'rclone' remote: '%s' remotePath:'%s' args:'%s' output: %q", cmd, err, remote, remotePath, args, string(out)) @@ -511,3 +557,13 @@ func (r *Rclone) command(cmd, remote, remotePath string, flags map[string]string return nil } + +func createRcloneRequest(ctx context.Context, method string, body io.Reader, path string, rcloneServerPort int) (*http.Request, error) { + rcloneServerURL := fmt.Sprintf("http://localhost:%d/%s", rcloneServerPort, strings.TrimLeft(path, "/")) + req, err := http.NewRequestWithContext(ctx, method, rcloneServerURL, body) + if err != nil { + return nil, err + } + req.Header.Set("Content-Type", "application/json") + return req, nil +} diff --git a/test/sanity_test.go b/test/sanity_test.go index 0527cd4..a74beb6 100644 --- a/test/sanity_test.go +++ b/test/sanity_test.go @@ -12,6 +12,7 @@ import ( "github.com/google/uuid" "github.com/kubernetes-csi/csi-test/v5/pkg/sanity" "github.com/kubernetes-csi/csi-test/v5/utils" + csicommon "github.com/kubernetes-csi/drivers/pkg/csi-common" . "github.com/onsi/ginkgo/v2" . 
"github.com/onsi/gomega" "google.golang.org/grpc" @@ -46,24 +47,30 @@ var _ = Describe("Sanity CSI checks", Ordered, func() { var err error var kubeClient *kubernetes.Clientset = &kubernetes.Clientset{} var endpoint string - var driver *rclone.Driver = &rclone.Driver{} var socketDir string - BeforeAll(func() { + BeforeAll(func(ctx SpecContext) { socketDir, err = createSocketDir() Expect(err).ShouldNot(HaveOccurred()) endpoint = fmt.Sprintf("unix://%s/csi.sock", socketDir) + config := rclone.NodeServerConfig{DriverConfig: rclone.DriverConfig{Endpoint: endpoint, NodeID: "hostname"}} kubeClient, err = kube.GetK8sClient() Expect(err).ShouldNot(HaveOccurred()) os.Setenv("DRIVER_NAME", "csi-rclone") - driver = rclone.NewDriver("hostname", endpoint) - cs := rclone.NewControllerServer(driver.CSIDriver) - ns, err := rclone.NewNodeServer(driver.CSIDriver, "", "") - Expect(err).ShouldNot(HaveOccurred()) - driver.WithControllerServer(cs).WithNodeServer(ns) go func() { defer GinkgoRecover() - err := driver.Run() + err := rclone.Run(context.Background(), &config.DriverConfig, + func(csiDriver *csicommon.CSIDriver) (*rclone.ControllerServer, *rclone.NodeServer, error) { + cs := rclone.NewControllerServer(csiDriver) + ns, err := rclone.NewNodeServer(csiDriver, config.CacheDir, config.CacheSize) + Expect(err).ShouldNot(HaveOccurred()) + return cs, ns, err + }, + func(ctx context.Context, cs *rclone.ControllerServer, ns *rclone.NodeServer) error { + Expect(ns).ShouldNot(BeNil()) + return ns.Run(ctx) + }, + ) Expect(err).ShouldNot(HaveOccurred()) }() _, err = utils.Connect(endpoint, grpc.WithTransportCredentials(insecure.NewCredentials())) @@ -71,7 +78,6 @@ var _ = Describe("Sanity CSI checks", Ordered, func() { }) AfterAll(func() { - driver.Stop() os.RemoveAll(socketDir) os.Unsetenv("DRIVER_NAME") })