From 4be6c907f3c9b5803cc165f494e5dbbbf67a347a Mon Sep 17 00:00:00 2001 From: Jan Safranek Date: Mon, 13 May 2024 16:01:08 +0200 Subject: [PATCH] Use correct context - Set timeout for each CSI call separately. - Pass ctx to csi-lib-utils function as needed. --- cmd/csi-snapshotter/main.go | 16 +++++++++++----- cmd/csi-snapshotter/main_test.go | 2 +- pkg/snapshotter/snapshotter_test.go | 2 +- 3 files changed, 13 insertions(+), 7 deletions(-) diff --git a/cmd/csi-snapshotter/main.go b/cmd/csi-snapshotter/main.go index a501b7015..0d09d15bc 100644 --- a/cmd/csi-snapshotter/main.go +++ b/cmd/csi-snapshotter/main.go @@ -165,7 +165,9 @@ func main() { // Connect to CSI. metricsManager := metrics.NewCSIMetricsManager("" /* driverName */) + ctx := context.Background() csiConn, err := connection.Connect( + ctx, *csiAddress, metricsManager, connection.OnConnectionLoss(connection.ExitOnConnectionLoss())) @@ -175,11 +177,11 @@ func main() { } // Pass a context with a timeout - ctx, cancel := context.WithTimeout(context.Background(), *csiTimeout) + tctx, cancel := context.WithTimeout(ctx, *csiTimeout) defer cancel() // Find driver name - driverName, err := csirpc.GetDriverName(ctx, csiConn) + driverName, err := csirpc.GetDriverName(tctx, csiConn) if err != nil { klog.Errorf("error getting CSI driver name: %v", err) os.Exit(1) @@ -202,13 +204,15 @@ func main() { } // Check it's ready - if err = csirpc.ProbeForever(csiConn, *csiTimeout); err != nil { + if err = csirpc.ProbeForever(ctx, csiConn, *csiTimeout); err != nil { klog.Errorf("error waiting for CSI driver to be ready: %v", err) os.Exit(1) } // Find out if the driver supports create/delete snapshot. - supportsCreateSnapshot, err := supportsControllerCreateSnapshot(ctx, csiConn) + tctx, cancel = context.WithTimeout(ctx, *csiTimeout) + defer cancel() + supportsCreateSnapshot, err := supportsControllerCreateSnapshot(tctx, csiConn) if err != nil { klog.Errorf("error determining if driver supports create/delete snapshot operations: %v", err) os.Exit(1) @@ -228,7 +232,9 @@ func main() { snapShotter := snapshotter.NewSnapshotter(csiConn) var groupSnapshotter group_snapshotter.GroupSnapshotter if *enableVolumeGroupSnapshots { - supportsCreateVolumeGroupSnapshot, err := supportsGroupControllerCreateVolumeGroupSnapshot(ctx, csiConn) + tctx, cancel = context.WithTimeout(ctx, *csiTimeout) + defer cancel() + supportsCreateVolumeGroupSnapshot, err := supportsGroupControllerCreateVolumeGroupSnapshot(tctx, csiConn) if err != nil { klog.Errorf("error determining if driver supports create/delete group snapshot operations: %v", err) } else if !supportsCreateVolumeGroupSnapshot { diff --git a/cmd/csi-snapshotter/main_test.go b/cmd/csi-snapshotter/main_test.go index 052d8873a..c3ecf9d58 100644 --- a/cmd/csi-snapshotter/main_test.go +++ b/cmd/csi-snapshotter/main_test.go @@ -154,7 +154,7 @@ func createMockServer(t *testing.T) (*gomock.Controller, *driver.MockCSIDriver, // Create a client connection to it addr := drv.Address() - csiConn, err := connection.Connect(addr, metricsManager) + csiConn, err := connection.Connect(context.Background(), addr, metricsManager) if err != nil { return nil, nil, nil, nil, nil, err } diff --git a/pkg/snapshotter/snapshotter_test.go b/pkg/snapshotter/snapshotter_test.go index 6f225a287..148848fff 100644 --- a/pkg/snapshotter/snapshotter_test.go +++ b/pkg/snapshotter/snapshotter_test.go @@ -57,7 +57,7 @@ func createMockServer(t *testing.T) (*gomock.Controller, *driver.MockCSIDriver, // Create a client connection to it addr := drv.Address() - csiConn, err := connection.Connect(addr, metricsManager) + csiConn, err := connection.Connect(context.Background(), addr, metricsManager) if err != nil { return nil, nil, nil, nil, nil, err }