@@ -10,6 +10,7 @@ package testing
1010
1111import  (
1212	"bytes" 
13+ 	"context" 
1314	"errors" 
1415	"flag" 
1516	"fmt" 
@@ -78,6 +79,9 @@ type common struct {
7879	tempDir     string 
7980	tempDirErr  error 
8081	tempDirSeq  int32 
82+ 
83+ 	ctx        context.Context 
84+ 	cancelCtx  context.CancelFunc 
8185}
8286
8387type  logger  struct  {
@@ -152,6 +156,7 @@ func fmtDuration(d time.Duration) string {
152156// TB is the interface common to T and B. 
153157type  TB  interface  {
154158	Cleanup (func ())
159+ 	Context () context.Context 
155160	Error (args  ... interface {})
156161	Errorf (format  string , args  ... interface {})
157162	Fail ()
@@ -307,6 +312,15 @@ func (c *common) Cleanup(f func()) {
307312	c .cleanups  =  append (c .cleanups , f )
308313}
309314
315+ // Context returns a context that is canceled just before 
316+ // Cleanup-registered functions are called. 
317+ // 
318+ // Cleanup functions can wait for any resources 
319+ // that shut down on [context.Context.Done] before the test or benchmark completes. 
320+ func  (c  * common ) Context () context.Context  {
321+ 	return  c .ctx 
322+ }
323+ 
310324// TempDir returns a temporary directory for the test to use. 
311325// The directory is automatically removed by Cleanup when the test and 
312326// all its subtests complete. 
@@ -447,6 +461,9 @@ func (c *common) runCleanup() {
447461		if  cleanup  ==  nil  {
448462			return 
449463		}
464+ 		if  c .cancelCtx  !=  nil  {
465+ 			c .cancelCtx ()
466+ 		}
450467		cleanup ()
451468	}
452469}
@@ -488,12 +505,15 @@ func (t *T) Run(name string, f func(t *T)) bool {
488505	}
489506
490507	// Create a subtest. 
508+ 	ctx , cancelCtx  :=  context .WithCancel (context .Background ())
491509	sub  :=  T {
492510		common : common {
493- 			output : & logger {logToStdout : flagVerbose },
494- 			name :   testName ,
495- 			parent : & t .common ,
496- 			level :  t .level  +  1 ,
511+ 			output :    & logger {logToStdout : flagVerbose },
512+ 			name :      testName ,
513+ 			parent :    & t .common ,
514+ 			level :     t .level  +  1 ,
515+ 			ctx :       ctx ,
516+ 			cancelCtx : cancelCtx ,
497517		},
498518		context : t .context ,
499519	}
@@ -606,9 +626,12 @@ func runTests(matchString func(pat, str string) (bool, error), tests []InternalT
606626	ok  =  true 
607627
608628	ctx  :=  newTestContext (newMatcher (matchString , flagRunRegexp , "-test.run" , flagSkipRegexp ))
629+ 	runCtx , cancelCtx  :=  context .WithCancel (context .Background ())
609630	t  :=  & T {
610631		common : common {
611- 			output : & logger {logToStdout : flagVerbose },
632+ 			output :    & logger {logToStdout : flagVerbose },
633+ 			ctx :       runCtx ,
634+ 			cancelCtx : cancelCtx ,
612635		},
613636		context : ctx ,
614637	}
0 commit comments