Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 39 additions & 0 deletions timeout/timeout.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
// Package timeout provides a utility function to execute a function with a
// timeout. It helps in scenarios where an operation needs to be bound by a time
// limit, preventing indefinite blocking.
package timeout

import (
"context"
"time"
)

// DoWithTimeout executes the given function f within the specified timeout
// duration. It takes a parent context, a timeout duration, and the function to
// execute. The function f is of type func() error.
//
// DoWithTimeout returns nil if f completes successfully within the timeout. If
// f returns an error, DoWithTimeout returns that error. If the timeout duration
// is reached before f completes, DoWithTimeout returns
// [context.DeadlineExceeded]. If the parent context is canceled before f
// completes, DoWithTimeout returns [context.Canceled].
func DoWithTimeout(ctx context.Context, timeout time.Duration, f func() error) error {
ctxWithTimeout, cancel := context.WithTimeout(ctx, timeout)
defer cancel()

// Buffer of 1 to prevent sender from blocking if receiver is not ready
done := make(chan error, 1)

go func() {
// Close the channel when the goroutine exits
defer close(done)
done <- f()
}()

select {
case <-ctxWithTimeout.Done():
return ctxWithTimeout.Err()
case err := <-done:
return err
}
}
120 changes: 120 additions & 0 deletions timeout/timeout_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
package timeout

import (
"context"
"errors"
"fmt"
"testing"
"time"
)

func TestDoWithTimeout_Success(t *testing.T) {
ctx := context.Background()
timeout := 100 * time.Millisecond
f := func() error {
time.Sleep(10 * time.Millisecond)
return nil
}

err := DoWithTimeout(ctx, timeout, f)
if err != nil {
t.Errorf("Expected no error, got %v", err)
}
}

func TestDoWithTimeout_Timeout(t *testing.T) {
ctx := context.Background()
timeout := 10 * time.Millisecond
f := func() error {
time.Sleep(100 * time.Millisecond)
return nil
}

err := DoWithTimeout(ctx, timeout, f)
if !errors.Is(err, context.DeadlineExceeded) {
t.Errorf("Expected context.DeadlineExceeded, got %v", err)
}
}

func TestDoWithTimeout_ContextCanceled(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
timeout := 100 * time.Millisecond
f := func() error {
time.Sleep(10 * time.Millisecond)
return nil
}

// Cancel context before calling DoWithTimeout
cancel()
err := DoWithTimeout(ctx, timeout, f)
if !errors.Is(err, context.Canceled) {
t.Errorf("Expected context.Canceled, got %v", err)
}
}

func ExampleDoWithTimeout() {
// Scenario 1: Function completes successfully within timeout
ctx1 := context.Background()
timeout1 := 100 * time.Millisecond
f1 := func() error {
time.Sleep(10 * time.Millisecond)
fmt.Println("Function 1 completed")
return nil
}
err1 := DoWithTimeout(ctx1, timeout1, f1)
if err1 != nil {
fmt.Printf("Function 1 error: %v\n", err1)
}

// Scenario 2: Function times out
ctx2 := context.Background()
timeout2 := 10 * time.Millisecond
f2 := func() error {
time.Sleep(100 * time.Millisecond)
fmt.Println("Function 2 completed (this should not print if timeout works)")
return nil
}
err2 := DoWithTimeout(ctx2, timeout2, f2)
if err2 != nil {
fmt.Printf("Function 2 error: %v\n", err2)
}

// Output:
// Function 1 completed
// Function 2 error: context deadline exceeded
}

func TestDoWithTimeout_FunctionError(t *testing.T) {
ctx := context.Background()
timeout := 100 * time.Millisecond
expectedErr := errors.New("function error")
f := func() error {
return expectedErr
}

err := DoWithTimeout(ctx, timeout, f)
if !errors.Is(err, expectedErr) {
t.Errorf("Expected error %v, got %v", expectedErr, err)
}
}

func TestDoWithTimeout_ContextCanceledDuringExecution(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
timeout := 100 * time.Millisecond
f := func() error {
// Sleep long enough for cancellation to occur
time.Sleep(50 * time.Millisecond)
return nil
}

go func() {
// Wait a bit then cancel
time.Sleep(10 * time.Millisecond)
cancel()
}()

err := DoWithTimeout(ctx, timeout, f)
if !errors.Is(err, context.Canceled) {
t.Errorf("Expected context.Canceled, got %v", err)
}
}
Loading