Skip to content

Commit

Permalink
Ensure free ports are unique
Browse files Browse the repository at this point in the history
  • Loading branch information
nevillelyh committed Dec 12, 2024
1 parent 48272da commit c3c9779
Showing 1 changed file with 24 additions and 10 deletions.
34 changes: 24 additions & 10 deletions internal/tests/cog_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,32 @@ import (
"github.com/replicate/cog-runtime/internal/server"
)

type PortFinder struct {
ports map[int]bool
mu sync.Mutex
}

func (f *PortFinder) Get() int {
f.mu.Lock()
defer f.mu.Unlock()
for {
a := must.Get(net.ResolveTCPAddr("tcp", "localhost:0"))
l := must.Get(net.ListenTCP("tcp", a))
p := l.Addr().(*net.TCPAddr).Port
if _, ok := f.ports[p]; !ok {
f.ports[p] = true
l.Close()
return p
}
}
}

var (
_, b, _, _ = runtime.Caller(0)
basePath = path.Dir(path.Dir(path.Dir(b)))
logger = logging.New("cog-test")
legacyCog = flag.Bool("legacy-cog", false, "Test with legacy Cog")
portFinder = PortFinder{ports: make(map[int]bool)}
)

type WebhookRequest struct {
Expand Down Expand Up @@ -128,7 +149,7 @@ func (ct *CogTest) Start() error {
func (ct *CogTest) runtimeCmd() *exec.Cmd {
pathEnv := path.Join(basePath, "python", ".venv", "bin")
pythonPathEnv := path.Join(basePath, "python")
ct.serverPort = getFreePort()
ct.serverPort = portFinder.Get()
args := []string{
"run", path.Join(basePath, "cmd", "cog-server", "main.go"),
"--module-name", fmt.Sprintf("tests.runners.%s", ct.module),
Expand All @@ -153,7 +174,7 @@ func (ct *CogTest) legacyCmd() *exec.Cmd {
must.Do(os.Symlink(path.Join(runnersPath, "cog.yaml"), path.Join(tmpDir, "cog.yaml")))
must.Do(os.Symlink(path.Join(runnersPath, module), path.Join(tmpDir, "predict.py")))
pythonBin := path.Join(basePath, "python", ".venv-legacy", "bin", "python3")
ct.serverPort = getFreePort()
ct.serverPort = portFinder.Get()
args := []string{
"-m", "cog.server.http",
}
Expand All @@ -175,7 +196,7 @@ func (ct *CogTest) Cleanup() error {

func (ct *CogTest) StartWebhook() {
log := logger.Sugar()
ct.webhookPort = getFreePort()
ct.webhookPort = portFinder.Get()
ct.webhookServer = &http.Server{
Addr: fmt.Sprintf(":%d", ct.webhookPort),
Handler: &WebhookHandler{},
Expand Down Expand Up @@ -324,10 +345,3 @@ func (ct *CogTest) AssertResponse(
assert.Equal(ct.t, output, response.Output)
assert.Equal(ct.t, logs, response.Logs)
}

func getFreePort() int {
a := must.Get(net.ResolveTCPAddr("tcp", "localhost:0"))
l := must.Get(net.ListenTCP("tcp", a))
defer l.Close()
return l.Addr().(*net.TCPAddr).Port
}

0 comments on commit c3c9779

Please sign in to comment.