diff --git a/timeout/timeout.go b/timeout/timeout.go new file mode 100644 index 0000000..f2c13dd --- /dev/null +++ b/timeout/timeout.go @@ -0,0 +1,39 @@ +// Package timeout provides a utility function to execute a function with a +// timeout. It helps in scenarios where an operation needs to be bound by a time +// limit, preventing indefinite blocking. +package timeout + +import ( + "context" + "time" +) + +// DoWithTimeout executes the given function f within the specified timeout +// duration. It takes a parent context, a timeout duration, and the function to +// execute. The function f is of type func() error. +// +// DoWithTimeout returns nil if f completes successfully within the timeout. If +// f returns an error, DoWithTimeout returns that error. If the timeout duration +// is reached before f completes, DoWithTimeout returns +// [context.DeadlineExceeded]. If the parent context is canceled before f +// completes, DoWithTimeout returns [context.Canceled]. +func DoWithTimeout(ctx context.Context, timeout time.Duration, f func() error) error { + ctxWithTimeout, cancel := context.WithTimeout(ctx, timeout) + defer cancel() + + // Buffer of 1 to prevent sender from blocking if receiver is not ready + done := make(chan error, 1) + + go func() { + // Close the channel when the goroutine exits + defer close(done) + done <- f() + }() + + select { + case <-ctxWithTimeout.Done(): + return ctxWithTimeout.Err() + case err := <-done: + return err + } +} diff --git a/timeout/timeout_test.go b/timeout/timeout_test.go new file mode 100644 index 0000000..f682d28 --- /dev/null +++ b/timeout/timeout_test.go @@ -0,0 +1,120 @@ +package timeout + +import ( + "context" + "errors" + "fmt" + "testing" + "time" +) + +func TestDoWithTimeout_Success(t *testing.T) { + ctx := context.Background() + timeout := 100 * time.Millisecond + f := func() error { + time.Sleep(10 * time.Millisecond) + return nil + } + + err := DoWithTimeout(ctx, timeout, f) + if err != nil { + t.Errorf("Expected no error, got %v", err) + } +} + +func TestDoWithTimeout_Timeout(t *testing.T) { + ctx := context.Background() + timeout := 10 * time.Millisecond + f := func() error { + time.Sleep(100 * time.Millisecond) + return nil + } + + err := DoWithTimeout(ctx, timeout, f) + if !errors.Is(err, context.DeadlineExceeded) { + t.Errorf("Expected context.DeadlineExceeded, got %v", err) + } +} + +func TestDoWithTimeout_ContextCanceled(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + timeout := 100 * time.Millisecond + f := func() error { + time.Sleep(10 * time.Millisecond) + return nil + } + + // Cancel context before calling DoWithTimeout + cancel() + err := DoWithTimeout(ctx, timeout, f) + if !errors.Is(err, context.Canceled) { + t.Errorf("Expected context.Canceled, got %v", err) + } +} + +func ExampleDoWithTimeout() { + // Scenario 1: Function completes successfully within timeout + ctx1 := context.Background() + timeout1 := 100 * time.Millisecond + f1 := func() error { + time.Sleep(10 * time.Millisecond) + fmt.Println("Function 1 completed") + return nil + } + err1 := DoWithTimeout(ctx1, timeout1, f1) + if err1 != nil { + fmt.Printf("Function 1 error: %v\n", err1) + } + + // Scenario 2: Function times out + ctx2 := context.Background() + timeout2 := 10 * time.Millisecond + f2 := func() error { + time.Sleep(100 * time.Millisecond) + fmt.Println("Function 2 completed (this should not print if timeout works)") + return nil + } + err2 := DoWithTimeout(ctx2, timeout2, f2) + if err2 != nil { + fmt.Printf("Function 2 error: %v\n", err2) + } + + // Output: + // Function 1 completed + // Function 2 error: context deadline exceeded +} + +func TestDoWithTimeout_FunctionError(t *testing.T) { + ctx := context.Background() + timeout := 100 * time.Millisecond + expectedErr := errors.New("function error") + f := func() error { + return expectedErr + } + + err := DoWithTimeout(ctx, timeout, f) + if !errors.Is(err, expectedErr) { + t.Errorf("Expected error %v, got %v", expectedErr, err) + } +} + +func TestDoWithTimeout_ContextCanceledDuringExecution(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + timeout := 100 * time.Millisecond + f := func() error { + // Sleep long enough for cancellation to occur + time.Sleep(50 * time.Millisecond) + return nil + } + + go func() { + // Wait a bit then cancel + time.Sleep(10 * time.Millisecond) + cancel() + }() + + err := DoWithTimeout(ctx, timeout, f) + if !errors.Is(err, context.Canceled) { + t.Errorf("Expected context.Canceled, got %v", err) + } +}