diff --git a/gslice/gslice.go b/gslice/gslice.go
index 5a2a3a8..3eca1ca 100644
--- a/gslice/gslice.go
+++ b/gslice/gslice.go
@@ -875,8 +875,9 @@ func Union[S ~[]T, T comparable](ss ...S) S {
 	if len(ss) == 1 {
 		return Uniq(ss[0])
 	}
-	members := set.New[T]()
-	ret := S{} // TODO: Guess a cap.
+	size := Sum(Map(ss, func(s S) int { return len(s) }))
+	members := set.NewWithCap[T](size)
+	ret := make(S, 0, size/2)
 	for _, s := range ss {
 		for _, v := range s {
 			if members.Add(v) {
diff --git a/gslice/gslice_bench_test.go b/gslice/gslice_bench_test.go
index b6c3165..f7ab6e0 100644
--- a/gslice/gslice_bench_test.go
+++ b/gslice/gslice_bench_test.go
@@ -19,6 +19,7 @@ import (
 	"strconv"
 	"testing"
 
+	"github.com/bytedance/gg/collection/set"
 	"github.com/bytedance/gg/internal/iter"
 )
 
@@ -107,3 +108,85 @@ func BenchmarkShuffle_Parallel(b *testing.B) {
 		})
 	})
 }
+
+// unionResult receives every benchmarked result so the compiler cannot
+// discard the Union/oldUnion calls as dead code.
+var unionResult []int
+
+// oldUnion is a frozen copy of Union as it was before this change added
+// capacity pre-sizing. It exists only as the baseline for BenchmarkUnion
+// and must not be optimized.
+func oldUnion[S ~[]T, T comparable](ss ...S) S {
+	if len(ss) == 0 {
+		return S{}
+	}
+	if len(ss) == 1 {
+		return Uniq(ss[0])
+	}
+	members := set.New[T]()
+	ret := S{} // Deliberately no capacity hint: this is the baseline.
+	for _, s := range ss {
+		for _, v := range s {
+			if members.Add(v) {
+				ret = append(ret, v)
+			}
+		}
+	}
+	return ret
+}
+
+// BenchmarkUnion compares Union against oldUnion on three input shapes:
+// all elements different, all elements the same, and half of the elements
+// different. Sub-benchmarks are named "<new|old>-union-<shape>-<slices>-<len>",
+// and each input runs the new implementation first, then the old one.
+func BenchmarkUnion(b *testing.B) {
+	cases := []struct {
+		name string
+		ss   [][]int
+	}{
+		// 1. all different
+		{"diff-2-10", [][]int{
+			Range(0, 10),
+			Range(10, 20),
+		}},
+		{"diff-5-100", [][]int{
+			Range(0, 100),
+			Range(100, 200),
+			Range(200, 300),
+			Range(300, 400),
+			Range(400, 500),
+		}},
+
+		// 2. all same
+		{"same-2-10", [][]int{
+			Repeat(0, 10),
+			Repeat(0, 10),
+		}},
+		{"same-5-100", [][]int{
+			Repeat(0, 100),
+			Repeat(0, 100),
+			Repeat(0, 100),
+			Repeat(0, 100),
+			Repeat(0, 100),
+		}},
+
+		// 3. half different
+		{"half-2-100", [][]int{
+			Range(0, 100),
+			Range(50, 150),
+		}},
+	}
+	for _, c := range cases {
+		b.Run("new-union-"+c.name, func(b *testing.B) {
+			// The change under test is about allocations, so always report them.
+			b.ReportAllocs()
+			for i := 0; i < b.N; i++ {
+				unionResult = Union(c.ss...)
+			}
+		})
+		b.Run("old-union-"+c.name, func(b *testing.B) {
+			b.ReportAllocs()
+			for i := 0; i < b.N; i++ {
+				unionResult = oldUnion(c.ss...)
+			}
+		})
+	}
+}