From c3c97798b3fba0d2ac8d3f1094060c024be4a4cf Mon Sep 17 00:00:00 2001 From: Neville Li Date: Wed, 11 Dec 2024 20:20:55 -0500 Subject: [PATCH] Ensure free ports are unique --- internal/tests/cog_test.go | 34 ++++++++++++++++++++++++---------- 1 file changed, 24 insertions(+), 10 deletions(-) diff --git a/internal/tests/cog_test.go b/internal/tests/cog_test.go index fafd239..eb578fc 100644 --- a/internal/tests/cog_test.go +++ b/internal/tests/cog_test.go @@ -29,11 +29,32 @@ import ( "github.com/replicate/cog-runtime/internal/server" ) +type PortFinder struct { + ports map[int]bool + mu sync.Mutex +} + +func (f *PortFinder) Get() int { + f.mu.Lock() + defer f.mu.Unlock() + for { + a := must.Get(net.ResolveTCPAddr("tcp", "localhost:0")) + l := must.Get(net.ListenTCP("tcp", a)) + p := l.Addr().(*net.TCPAddr).Port + if _, ok := f.ports[p]; !ok { + f.ports[p] = true + l.Close() + return p + } + } +} + var ( _, b, _, _ = runtime.Caller(0) basePath = path.Dir(path.Dir(path.Dir(b))) logger = logging.New("cog-test") legacyCog = flag.Bool("legacy-cog", false, "Test with legacy Cog") + portFinder = PortFinder{ports: make(map[int]bool)} ) type WebhookRequest struct { @@ -128,7 +149,7 @@ func (ct *CogTest) Start() error { func (ct *CogTest) runtimeCmd() *exec.Cmd { pathEnv := path.Join(basePath, "python", ".venv", "bin") pythonPathEnv := path.Join(basePath, "python") - ct.serverPort = getFreePort() + ct.serverPort = portFinder.Get() args := []string{ "run", path.Join(basePath, "cmd", "cog-server", "main.go"), "--module-name", fmt.Sprintf("tests.runners.%s", ct.module), @@ -153,7 +174,7 @@ func (ct *CogTest) legacyCmd() *exec.Cmd { must.Do(os.Symlink(path.Join(runnersPath, "cog.yaml"), path.Join(tmpDir, "cog.yaml"))) must.Do(os.Symlink(path.Join(runnersPath, module), path.Join(tmpDir, "predict.py"))) pythonBin := path.Join(basePath, "python", ".venv-legacy", "bin", "python3") - ct.serverPort = getFreePort() + ct.serverPort = portFinder.Get() args := []string{ "-m", "cog.server.http", } @@ -175,7 +196,7 @@ func (ct *CogTest) Cleanup() error { func (ct *CogTest) StartWebhook() { log := logger.Sugar() - ct.webhookPort = getFreePort() + ct.webhookPort = portFinder.Get() ct.webhookServer = &http.Server{ Addr: fmt.Sprintf(":%d", ct.webhookPort), Handler: &WebhookHandler{}, @@ -324,10 +345,3 @@ func (ct *CogTest) AssertResponse( assert.Equal(ct.t, output, response.Output) assert.Equal(ct.t, logs, response.Logs) } - -func getFreePort() int { - a := must.Get(net.ResolveTCPAddr("tcp", "localhost:0")) - l := must.Get(net.ListenTCP("tcp", a)) - defer l.Close() - return l.Addr().(*net.TCPAddr).Port -}