Skip to content

Commit cde8801

Browse files
committed
feat(iterator): unoptimized Tee
1 parent bc97567 commit cde8801

File tree

2 files changed

+105
-0
lines changed

2 files changed

+105
-0
lines changed

iterator/iter.go

+63
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@ package iterator
22

33
import (
44
"iter"
5+
"slices"
6+
"sync"
57
)
68

79
func Map[T, V any](it func(func(T) bool), fn func(T) V) func(func(V) bool) {
@@ -220,3 +222,64 @@ func Flatten[T any, S ~[]T](it func(func(S) bool)) func(func(T) bool) {
220222
}
221223
}
222224
}
225+
226+
type teeState[T any] struct {
227+
next func() (T, bool)
228+
mu sync.Mutex
229+
buf []T
230+
positions []int
231+
}
232+
233+
func (s *teeState[T]) advanceOne(i int) (T, bool) {
234+
s.mu.Lock()
235+
defer s.mu.Unlock()
236+
237+
if s.positions[i] == len(s.buf) {
238+
t, ok := s.next()
239+
if !ok {
240+
return t, ok
241+
}
242+
s.buf = append(s.buf, t)
243+
minPos := slices.Min(s.positions)
244+
if minPos > 0 {
245+
s.buf = s.buf[minPos:]
246+
for j := range s.positions {
247+
s.positions[j] -= minPos
248+
}
249+
}
250+
}
251+
pos := s.positions[i]
252+
s.positions[i]++
253+
return s.buf[pos], true
254+
}
255+
256+
func Tee[T any](it func(func(T) bool), n int) []func(func(T) bool) {
257+
next, stop := iter.Pull(it)
258+
259+
state := &teeState[T]{
260+
next: next,
261+
positions: make([]int, n),
262+
}
263+
stopped := 0
264+
265+
outs := make([]func(func(T) bool), n)
266+
for i := range n {
267+
i := i
268+
outs[i] = func(yield func(T) bool) {
269+
for {
270+
t, ok := state.advanceOne(i)
271+
if !ok {
272+
return
273+
}
274+
if !yield(t) {
275+
stopped++
276+
break
277+
}
278+
}
279+
if stopped == n {
280+
stop()
281+
}
282+
}
283+
}
284+
return outs
285+
}

iterator/iter_test.go

+42
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
11
package iterator
22

33
import (
4+
"go-exp/functions/partials"
45
"gotest.tools/v3/assert"
56
"maps"
67
"slices"
78
"strconv"
89
"strings"
10+
"sync"
911
"testing"
1012
)
1113

@@ -198,3 +200,43 @@ func TestFlatten(t *testing.T) {
198200

199201
assert.DeepEqual(t, result, expected)
200202
}
203+
204+
func TestTee(t *testing.T) {
205+
it := slices.Values([]int{4, 3, 2, 1})
206+
its := Tee(it, 2)
207+
208+
result1 := slices.Collect(its[0])
209+
result2 := slices.Collect(its[1])
210+
211+
expected := []int{4, 3, 2, 1}
212+
assert.DeepEqual(t, result1, expected)
213+
assert.DeepEqual(t, result2, expected)
214+
}
215+
216+
func BenchmarkTee(b *testing.B) {
217+
if b.N > 10000000 {
218+
b.Skipf("N too large: %d", b.N)
219+
}
220+
221+
it := TakeWhile(Generate(b.N, func(x int) int { return x - 1 }), partials.Gt(0))
222+
its := Tee(it, 10)
223+
224+
results := make([][]int, 10)
225+
226+
var wg sync.WaitGroup
227+
228+
for i := range 10 {
229+
wg.Add(1)
230+
go func() {
231+
results[i] = slices.Collect(its[i])
232+
wg.Done()
233+
}()
234+
}
235+
236+
expected := slices.Collect(TakeWhile(Generate(b.N, func(x int) int { return x - 1 }), partials.Gt(0)))
237+
238+
wg.Wait()
239+
for _, result := range results {
240+
assert.DeepEqual(b, result, expected)
241+
}
242+
}

0 commit comments

Comments
 (0)