Skip to content

Commit

Permalink
refactor: imagerunner global context (#992)
Browse files Browse the repository at this point in the history
  • Loading branch information
alexplischke authored Dec 12, 2024
1 parent 65e3c90 commit 67274d0
Showing 1 changed file with 17 additions and 50 deletions.
67 changes: 17 additions & 50 deletions internal/saucecloud/imagerunner.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ import (
"fmt"
"io"
"os"
"os/signal"
"path/filepath"
"reflect"
"strings"
Expand Down Expand Up @@ -54,9 +53,6 @@ type ImgRunner struct {

Async bool
AsyncEventManager imagerunner.AsyncEventManager

ctx context.Context
cancel context.CancelFunc
}

func NewImgRunner(project imagerunner.Project, runnerService ImageRunner, tunnelService tunnel.Service,
Expand Down Expand Up @@ -101,14 +97,7 @@ func (r *ImgRunner) RunProject(ctx context.Context) (int, error) {
return 0, nil
}

ctx, cancel := context.WithCancel(context.Background())
r.ctx = ctx
r.cancel = cancel

sigChan := r.registerInterruptOnSignal()
defer unregisterSignalCapture(sigChan)

suites, results := r.createWorkerPool(r.Project.Sauce.Concurrency, 0)
suites, results := r.createWorkerPool(ctx, r.Project.Sauce.Concurrency, 0)

// Submit suites to work on.
go func() {
Expand All @@ -117,26 +106,26 @@ func (r *ImgRunner) RunProject(ctx context.Context) (int, error) {
}
}()

if passed := r.collectResults(results, len(r.Project.Suites)); !passed {
if passed := r.collectResults(ctx, results, len(r.Project.Suites)); !passed {
return 1, nil
}

return 0, nil
}

func (r *ImgRunner) createWorkerPool(ccy int, maxRetries int) (chan imagerunner.Suite, chan execResult) {
func (r *ImgRunner) createWorkerPool(ctx context.Context, ccy int, maxRetries int) (chan imagerunner.Suite, chan execResult) {
suites := make(chan imagerunner.Suite, maxRetries+1)
results := make(chan execResult, ccy)

log.Info().Int("concurrency", ccy).Msg("Launching workers.")
for i := 0; i < ccy; i++ {
go r.runSuites(suites, results)
go r.runSuites(ctx, suites, results)
}

return suites, results
}

func (r *ImgRunner) runSuites(suites chan imagerunner.Suite, results chan<- execResult) {
func (r *ImgRunner) runSuites(ctx context.Context, suites chan imagerunner.Suite, results chan<- execResult) {
for suite := range suites {
// Apply defaults.
defaults := r.Project.Defaults
Expand All @@ -160,7 +149,7 @@ func (r *ImgRunner) runSuites(suites chan imagerunner.Suite, results chan<- exec

startTime := time.Now()

if r.ctx.Err() != nil {
if ctx.Err() != nil {
results <- execResult{
name: suite.Name,
startTime: startTime,
Expand All @@ -172,7 +161,7 @@ func (r *ImgRunner) runSuites(suites chan imagerunner.Suite, results chan<- exec
continue
}

run, err := r.runSuite(suite)
run, err := r.runSuite(ctx, suite)

endTime := time.Now()
duration := time.Since(startTime)
Expand Down Expand Up @@ -226,7 +215,7 @@ func (r *ImgRunner) buildService(serviceIn imagerunner.SuiteService, suiteName s
return serviceOut, nil
}

func (r *ImgRunner) runSuite(suite imagerunner.Suite) (imagerunner.Runner, error) {
func (r *ImgRunner) runSuite(ctx context.Context, suite imagerunner.Suite) (imagerunner.Runner, error) {
files, err := mapFiles(suite.Files)
if err != nil {
log.Err(err).Str("suite", suite.Name).Msg("Unable to read source files")
Expand All @@ -242,7 +231,7 @@ func (r *ImgRunner) runSuite(suite imagerunner.Suite) (imagerunner.Runner, error
suite.Timeout = 24 * time.Hour
}

ctx, cancel := context.WithTimeout(r.ctx, suite.Timeout)
ctx, cancel := context.WithTimeout(ctx, suite.Timeout)
defer cancel()

var auth *imagerunner.Auth
Expand Down Expand Up @@ -359,11 +348,11 @@ func (r *ImgRunner) getTunnel() *imagerunner.Tunnel {
}
}

func (r *ImgRunner) collectResults(results chan execResult, expected int) bool {
func (r *ImgRunner) collectResults(ctx context.Context, results chan execResult, expected int) bool {
inProgress := expected
passed := true

stopProgress := r.startProgressTicker(r.ctx, &inProgress)
stopProgress := r.startProgressTicker(ctx, &inProgress)
for i := 0; i < expected; i++ {
res := <-results
inProgress--
Expand All @@ -380,9 +369,9 @@ func (r *ImgRunner) collectResults(results chan execResult, expected int) bool {
} else {
if !r.Project.LiveLogs {
// only print logs if live logs are disabled
r.PrintLogs(res.runID, res.name)
r.PrintLogs(ctx, res.runID, res.name)
}
files := r.DownloadArtifacts(res.runID, res.name, res.status, res.err != nil)
files := r.DownloadArtifacts(ctx, res.runID, res.name, res.status, res.err != nil)
for _, f := range files {
artifacts = append(artifacts, report.Artifact{FilePath: f})
}
Expand Down Expand Up @@ -417,28 +406,6 @@ func (r *ImgRunner) collectResults(results chan execResult, expected int) bool {
return passed
}

func (r *ImgRunner) registerInterruptOnSignal() chan os.Signal {
sigChan := make(chan os.Signal, 1)
signal.Notify(sigChan, os.Interrupt)

go func() {
for {
sig := <-sigChan
if sig == nil {
return
}
if r.ctx.Err() == nil {
r.cancel()
println("\nStopping run. Cancelling all suites in progress... (press Ctrl-c again to exit without waiting)\n")
} else {
os.Exit(1)
}
}
}()

return sigChan
}

func (r *ImgRunner) PollRun(ctx context.Context, id string, lastStatus string) (imagerunner.Runner, error) {
ticker := time.NewTicker(15 * time.Second)
defer ticker.Stop()
Expand All @@ -465,7 +432,7 @@ func (r *ImgRunner) PollRun(ctx context.Context, id string, lastStatus string) (

// DownloadArtifacts downloads a zipped archive of artifacts
// and extracts the required files.
func (r *ImgRunner) DownloadArtifacts(runnerID, suiteName, status string, passed bool) []string {
func (r *ImgRunner) DownloadArtifacts(ctx context.Context, runnerID, suiteName, status string, passed bool) []string {
if r.Async ||
runnerID == "" ||
status == imagerunner.StateCancelled ||
Expand All @@ -480,7 +447,7 @@ func (r *ImgRunner) DownloadArtifacts(runnerID, suiteName, status string, passed
}

log.Info().Msg("Downloading artifacts archive")
reader, err := r.RunnerService.DownloadArtifacts(r.ctx, runnerID)
reader, err := r.RunnerService.DownloadArtifacts(ctx, runnerID)
if err != nil {
log.Err(err).Str("suite", suiteName).Msg("Failed to fetch artifacts.")
return nil
Expand Down Expand Up @@ -534,13 +501,13 @@ func (r *ImgRunner) PrintResult(res execResult) {
logEvent.Msg("Suite finished.")
}

func (r *ImgRunner) PrintLogs(runID, suiteName string) {
func (r *ImgRunner) PrintLogs(ctx context.Context, runID, suiteName string) {
if r.Async || runID == "" {
return
}

// Need a poll timeout, because artifacts may never exist.
ctx, cancel := context.WithTimeout(r.ctx, 3*time.Minute)
ctx, cancel := context.WithTimeout(ctx, 3*time.Minute)
defer cancel()

logs, err := r.PollLogs(ctx, runID)
Expand Down

0 comments on commit 67274d0

Please sign in to comment.