From 7b4f3952bf2d604f661d1dd5c62ad795eb25cf97 Mon Sep 17 00:00:00 2001 From: Ahmet Alp Balkan Date: Wed, 2 Oct 2024 23:45:33 -0700 Subject: [PATCH] Proposal: expose MultiError (#8) One of the great things about semgroup over x/sync/errgroup is that it actually waits all tasks to complete and returns an error. However, the returned error type doesn't let the caller to iterate over the errors (e.g. my use case requires me to truncate the list of errors). Therefore, introducing a backwards-compatible change that exports `MultiError` type, and guarantees that `Wait()` method returns an error of this type. --- semgroup.go | 17 ++++++++++------- semgroup_test.go | 12 ++++++++++++ 2 files changed, 22 insertions(+), 7 deletions(-) diff --git a/semgroup.go b/semgroup.go index 7582d0a..c05e5ec 100644 --- a/semgroup.go +++ b/semgroup.go @@ -25,7 +25,7 @@ type Group struct { wg sync.WaitGroup ctx context.Context - errs multiError + errs MultiError mu sync.Mutex // protects errs } @@ -70,15 +70,17 @@ func (g *Group) Go(f func() error) { } // Wait blocks until all function calls from the Go method have returned, then -// returns all accumulated non-nil error (if any) from them. +// returns all accumulated non-nil errors (if any) from them. +// +// If a non-nil error is returned, it will be of type [MultiError]. func (g *Group) Wait() error { g.wg.Wait() return g.errs.ErrorOrNil() } -type multiError []error +type MultiError []error -func (e multiError) Error() string { +func (e MultiError) Error() string { var b strings.Builder fmt.Fprintf(&b, "%d error(s) occurred:\n", len(e)) @@ -92,7 +94,8 @@ func (e multiError) Error() string { return b.String() } -func (e multiError) ErrorOrNil() error { +// ErrorOrNil returns nil if there are no errors, otherwise returns itself. +func (e MultiError) ErrorOrNil() error { if len(e) == 0 { return nil } @@ -100,7 +103,7 @@ func (e multiError) ErrorOrNil() error { return e } -func (e multiError) Is(target error) bool { +func (e MultiError) Is(target error) bool { for _, err := range e { if errors.Is(err, target) { return true @@ -109,7 +112,7 @@ func (e multiError) Is(target error) bool { return false } -func (e multiError) As(target interface{}) bool { +func (e MultiError) As(target interface{}) bool { for _, err := range e { if errors.As(err, target) { return true diff --git a/semgroup_test.go b/semgroup_test.go index 255206d..45e56a2 100644 --- a/semgroup_test.go +++ b/semgroup_test.go @@ -62,6 +62,9 @@ func TestGroup_multiple_tasks_errors(t *testing.T) { if err == nil { t.Fatalf("g.Wait() should return an error") } + if !errors.As(err, &MultiError{}) { + t.Fatalf("the error should be of type MultiError") + } wantErr := `2 error(s) occurred: * foo @@ -124,6 +127,15 @@ func TestGroup_multiple_tasks_errors_Is(t *testing.T) { if errors.Is(err, bazErr) { t.Errorf("error should not be contained %v\n", bazErr) } + + var gotMultiErr MultiError + if !errors.As(err, &gotMultiErr) { + t.Fatalf("error should be matched MultiError") + } + expectedErr := (MultiError{fooErr, barErr}).Error() + if gotMultiErr.Error() != expectedErr { + t.Errorf("error should be %q, got %q", expectedErr, gotMultiErr.Error()) + } } type foobarErr struct{ str string }