-
Notifications
You must be signed in to change notification settings - Fork 1
/
hedged.go
180 lines (150 loc) · 4.06 KB
/
hedged.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
package synx
import (
"context"
"fmt"
"strings"
"time"
)
// infiniteTimeout is the wait applied once every attempt has been started.
// At 30 days it is long enough to act as "no timeout" for this package.
const infiniteTimeout = 30 * 24 * time.Hour // domain specific infinite
// HedgedWorker is a unit of work that can be run with hedging.
//
// Execute receives a context and an arbitrary input and returns either a
// result or an error. The hedging wrapper in this file may call Execute
// several times with the same input and cancels the contexts of attempts that
// did not win, so implementations should presumably be safe to run
// concurrently and should honor ctx cancellation — confirm per implementation.
type HedgedWorker interface {
Execute(ctx context.Context, input any) (result any, err error)
}
// NewHedger wraps worker in a HedgedWorker that applies the hedged-requests
// pattern: a further attempt is started after timeout has passed since the
// previous one, and no more than upto attempts are started per call.
//
// A zero timeout is replaced by one nanosecond, the smallest possible delay.
//
// NewHedger panics if timeout is negative, if upto is less than 1, or if
// worker is nil. The arguments are validated in that order.
//
// The returned worker owns a pool created with NewWorkerPool(10, time.Minute).
// NOTE(review): the meaning of those two arguments is defined by NewWorkerPool,
// which is not visible here — confirm the pool is adequate when upto > 10.
func NewHedger(timeout time.Duration, upto int, worker HedgedWorker) HedgedWorker {
	if timeout < 0 {
		panic("synx: timeout cannot be negative")
	}
	if upto < 1 {
		panic("synx: upto must be greater than 0")
	}
	if worker == nil {
		panic("synx: worker cannot be nil")
	}

	// Zero means "not set"; fall back to the smallest positive duration.
	delay := timeout
	if delay == 0 {
		delay = time.Nanosecond
	}

	return &hedgedWorker{
		worker:  worker,
		timeout: delay,
		upto:    upto,
		wp:      NewWorkerPool(10, time.Minute),
	}
}
// hedgedWorker is the HedgedWorker implementation returned by NewHedger. It
// wraps another worker together with the hedging parameters.
type hedgedWorker struct {
// worker is the underlying worker each attempt is delegated to.
worker HedgedWorker
// timeout is the delay between consecutive attempts as passed to NewHedger
// (never zero: NewHedger substitutes one nanosecond).
timeout time.Duration
// upto is the maximum number of attempts started per Execute call.
upto int
// wp is the pool whose Do method is used to schedule attempts and cleanup.
wp *WorkerPool
}
// Execute runs the wrapped worker using the hedged-requests pattern.
//
// The first attempt is started right away. Each time the configured delay
// (ht.timeout) passes, or an attempt reports an error, without a successful
// result being available, one more attempt is started, up to ht.upto attempts
// in total. Once all attempts have been started the call simply waits for the
// outcome.
//
// Every attempt receives its own context derived from ctx. When Execute
// returns, the contexts of all attempts except the winning one are cancelled
// via the worker pool.
//
// Execute returns:
//   - the first successful result and a nil error (a nil result with a nil
//     error counts as success);
//   - ctx.Err() if ctx is done before any attempt succeeds;
//   - a *MultiError holding every attempt's error if all ht.upto attempts fail.
func (ht *hedgedWorker) Execute(ctx context.Context, input any) (any, error) {
	// hedgedValue boxes a worker's result so that a successful nil result is
	// still distinguishable from the zero indexedResult that waitResult
	// returns on a timeout or an error.
	type hedgedValue struct{ value any }

	// Start with the configured hedging delay; it is swapped for
	// infiniteTimeout once every attempt has been launched.
	timeout := ht.timeout
	errOverall := &MultiError{}

	// Both channels are buffered for ht.upto items so that attempts finishing
	// after Execute has returned never block on send.
	resultCh := make(chan indexedResult, ht.upto)
	errorCh := make(chan error, ht.upto)

	resultIdx := -1 // index of the winning attempt, -1 if none
	cancels := make([]func(), ht.upto)

	// On return, cancel every started attempt except the winner. The winner's
	// context is left uncancelled here.
	// NOTE(review): presumably so the returned result may keep relying on its
	// context; that context is then only released when ctx itself is done —
	// confirm this is intended.
	defer ht.wp.Do(func() {
		for i, cancel := range cancels {
			if i != resultIdx && cancel != nil {
				cancel()
			}
		}
	})

	for sent := 0; len(errOverall.Errors) < ht.upto; sent++ {
		if sent < ht.upto {
			idx := sent
			subCtx, cancel := context.WithCancel(ctx)
			cancels[idx] = cancel

			// NOTE(review): assumes wp.Do runs the function asynchronously;
			// hedging cannot work otherwise — confirm against WorkerPool.
			ht.wp.Do(func() {
				result, err := ht.worker.Execute(subCtx, input)
				if err != nil {
					errorCh <- err
					return
				}
				resultCh <- indexedResult{idx, hedgedValue{result}}
			})
		}

		// all request sent - effectively disabling timeout between requests
		if sent == ht.upto {
			timeout = infiniteTimeout
		}

		res, err := waitResult(ctx, resultCh, errorCh, timeout)

		switch {
		case res.Result != nil:
			// Only real results carry a (boxed, hence non-nil) Result.
			resultIdx = res.Index
			return res.Result.(hedgedValue).value, nil
		case ctx.Err() != nil:
			// Checked before err so a cancelled caller always gets the
			// context error.
			return nil, ctx.Err()
		case err != nil:
			errOverall.Errors = append(errOverall.Errors, err)
		}
		// Otherwise the hedging delay elapsed: loop to start the next attempt.
	}

	// all request have returned errors
	return nil, errOverall
}
// waitResult waits for the next event of a hedged execution and reports it.
//
// A result that is already available on resultCh is always preferred and is
// returned without looking at the other channels. Otherwise the function
// blocks until the first of the following happens:
//   - a result arrives on resultCh: it is returned with a nil error;
//   - an error arrives on errorCh: a zero indexedResult and that error;
//   - ctx is done: a zero indexedResult and ctx.Err();
//   - timeout elapses: a zero indexedResult and a nil error. This is not a
//     failure; it only means the delay between consecutive attempts has passed.
func waitResult(ctx context.Context, resultCh <-chan indexedResult, errorCh <-chan error, timeout time.Duration) (indexedResult, error) {
	// Fast path: take a ready result before competing with the other channels.
	select {
	case ready := <-resultCh:
		return ready, nil
	default:
	}

	var none indexedResult

	gap := time.NewTimer(timeout)
	defer gap.Stop()

	select {
	case got := <-resultCh:
		return got, nil
	case failure := <-errorCh:
		return none, failure
	case <-ctx.Done():
		return none, ctx.Err()
	case <-gap.C:
		// The pause between attempts is over; there is nothing to report.
		return none, nil
	}
}
// indexedResult pairs a worker result with the index of the attempt that
// produced it, so the caller can tell which attempt won.
type indexedResult struct {
// Index is the zero-based position of the attempt in launch order.
Index int
// Result is the value returned by the attempt.
Result any
}
// MultiError is an error type that tracks multiple errors. It is used to
// accumulate errors and return them as a single "error".
// Inspired by https://github.com/hashicorp/go-multierror
type MultiError struct {
// Errors holds the accumulated errors.
Errors []error
// ErrorFormatFn, when non-nil, is used by Error to render Errors;
// otherwise the default list format is used.
ErrorFormatFn ErrorFormatFunc
}
// Error implements the error interface. The accumulated errors are rendered
// with ErrorFormatFn when one is set and with listFormatFunc otherwise.
func (e *MultiError) Error() string {
	if e.ErrorFormatFn != nil {
		return e.ErrorFormatFn(e.Errors)
	}
	return listFormatFunc(e.Errors)
}
// String returns a debugging representation of the error: an asterisk
// followed by the Go-syntax (%#v) form of the accumulated errors slice.
func (e *MultiError) String() string {
	goSyntax := fmt.Sprintf("%#v", e.Errors)
	return "*" + goSyntax
}
// ErrorOrNil returns the receiver as an error when it holds at least one
// error. For a nil receiver or an empty error list it returns a plain nil, so
// the result can be compared against nil safely.
func (e *MultiError) ErrorOrNil() error {
	if e != nil && len(e.Errors) > 0 {
		return e
	}
	return nil
}
// ErrorFormatFunc is called by MultiError to render its list of errors as a
// single string. It receives the accumulated errors and returns the message.
type ErrorFormatFunc func([]error) string
// listFormatFunc is the default ErrorFormatFunc. It renders the errors as a
// bulleted list under a header stating how many errors occurred; each bullet
// sits on its own tab-indented line and the message ends with a blank line.
// A single error gets the singular header "1 error occurred:".
func listFormatFunc(es []error) string {
	count := len(es)
	if count == 1 {
		return fmt.Sprintf("1 error occurred:\n\t* %s\n\n", es[0])
	}

	var b strings.Builder
	fmt.Fprintf(&b, "%d errors occurred:\n\t", count)
	for i, e := range es {
		if i > 0 {
			// Bullets are separated by a newline plus a tab.
			b.WriteString("\n\t")
		}
		fmt.Fprintf(&b, "* %s", e)
	}
	b.WriteString("\n\n")
	return b.String()
}