
Commit 73367e2: add test

Parent: a7c9bce

File tree: 6 files changed (+658, -38 lines)

pkg/engine/internal/worker/thread.go

Lines changed: 46 additions & 33 deletions
@@ -162,7 +162,6 @@ func (t *thread) runJob(ctx context.Context, job *threadJob) {
 
 	ctx, capture := xcap.NewCapture(ctx, nil)
 	pipeline := executor.Run(ctx, cfg, job.Task.Fragment, logger)
-	defer pipeline.Close()
 
 	err := job.Scheduler.SendMessageAsync(ctx, wire.TaskStatusMessage{
 		ID: job.Task.ULID,
@@ -175,21 +174,57 @@ func (t *thread) runJob(ctx context.Context, job *threadJob) {
 	}
 
 	var totalRows int
+	totalRows, err = t.drainPipeline(ctx, pipeline, job, logger)
+	if err != nil {
+		level.Warn(logger).Log("msg", "task failed", "err", err)
+		_ = job.Scheduler.SendMessageAsync(ctx, wire.TaskStatusMessage{
+			ID: job.Task.ULID,
+			Status: workflow.TaskStatus{
+				State: workflow.TaskStateFailed,
+				Error: err,
+			},
+		})
+
+		pipeline.Close()
+		return
+	}
+
+	// Close before ending capture to ensure all observations are recorded.
+	pipeline.Close()
+	capture.End()
+
+	// Finally, close all sinks.
+	for _, sink := range job.Sinks {
+		err := sink.Close(ctx)
+		if err != nil {
+			level.Warn(logger).Log("msg", "failed to close sink", "err", err)
+		}
+	}
+
+	// TODO(rfratto): We should find a way to expose queue time here.
+	result := statsCtx.Result(time.Since(startTime), 0, totalRows)
+	level.Info(logger).Log("msg", "task completed", "duration", time.Since(startTime))
+
+	// Wait for the scheduler to confirm the task has completed before
+	// requesting a new one. This allows the scheduler to update its bookkeeping
+	// for how many threads have capacity for requesting tasks.
+	err = job.Scheduler.SendMessage(ctx, wire.TaskStatusMessage{
+		ID: job.Task.ULID,
+		Status: workflow.TaskStatus{State: workflow.TaskStateCompleted, Statistics: &result, Capture: capture},
+	})
+	if err != nil {
+		level.Warn(logger).Log("msg", "failed to inform scheduler of task status", "err", err)
+	}
+}
 
+func (t *thread) drainPipeline(ctx context.Context, pipeline executor.Pipeline, job *threadJob, logger log.Logger) (int, error) {
+	var totalRows int
 	for {
 		rec, err := pipeline.Read(ctx)
 		if err != nil && errors.Is(err, executor.EOF) {
 			break
 		} else if err != nil {
-			level.Warn(logger).Log("msg", "task failed", "err", err)
-			_ = job.Scheduler.SendMessageAsync(ctx, wire.TaskStatusMessage{
-				ID: job.Task.ULID,
-				Status: workflow.TaskStatus{
-					State: workflow.TaskStateFailed,
-					Error: err,
-				},
-			})
-			return
+			return totalRows, err
 		}
 
 		totalRows += int(rec.NumRows())
@@ -211,28 +246,6 @@ func (t *thread) runJob(ctx context.Context, job *threadJob) {
 			}
 		}
 	}
-	capture.End()
-
-	// Finally, close all sinks.
-	for _, sink := range job.Sinks {
-		err := sink.Close(ctx)
-		if err != nil {
-			level.Warn(logger).Log("msg", "failed to close sink", "err", err)
-		}
-	}
-
-	// TODO(rfratto): We should find a way to expose queue time here.
-	result := statsCtx.Result(time.Since(startTime), 0, totalRows)
-	level.Info(logger).Log("msg", "task completed", "duration", time.Since(startTime))
 
-	// Wait for the scheduler to confirm the task has completed before
-	// requesting a new one. This allows the scheduler to update its bookkeeping
-	// for how many threads have capacity for requesting tasks.
-	err = job.Scheduler.SendMessage(ctx, wire.TaskStatusMessage{
-		ID: job.Task.ULID,
-		Status: workflow.TaskStatus{State: workflow.TaskStateCompleted, Statistics: &result, Capture: capture},
-	})
-	if err != nil {
-		level.Warn(logger).Log("msg", "failed to inform scheduler of task status", "err", err)
-	}
+	return totalRows, nil
 }
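
The behavioral change worth noting above: pipeline.Close() is no longer deferred. A deferred close would only run when runJob returns, which is after capture.End(), so any observations the pipeline records while closing would be lost; the commit instead closes explicitly on both the failure and success paths before ending the capture. Below is a minimal standalone sketch of that ordering, using fmt.Println stand-ins rather than the real executor and xcap types:

package main

import "fmt"

// deferredOrder mimics the old code: the close is deferred, so it runs
// only after the capture has already ended.
func deferredOrder() {
	defer fmt.Println("pipeline closed (too late to be captured)")
	fmt.Println("capture ended")
}

// explicitOrder mimics the new code: close the pipeline first, then end
// the capture, so close-time observations are still recorded.
func explicitOrder() {
	fmt.Println("pipeline closed")
	fmt.Println("capture ended")
}

func main() {
	deferredOrder()
	explicitOrder()
}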

pkg/xcap/aggregation_test.go

Lines changed: 153 additions & 0 deletions
@@ -0,0 +1,153 @@
+package xcap
+
+import (
+	"testing"
+
+	"github.com/stretchr/testify/require"
+)
+
+func TestAggregatedObservation_Record(t *testing.T) {
+	t.Run("sum aggregation int64", func(t *testing.T) {
+		stat := NewStatisticInt64("value", AggregationTypeSum)
+		agg := &AggregatedObservation{
+			Statistic: stat,
+			Value:     int64(0),
+			Count:     0,
+		}
+
+		agg.Record(stat.Observe(10))
+		agg.Record(stat.Observe(20))
+		agg.Record(stat.Observe(30))
+
+		require.Equal(t, 3, agg.Count)
+		require.Equal(t, int64(60), agg.Value.(int64))
+	})
+
+	t.Run("sum aggregation float64", func(t *testing.T) {
+		stat := NewStatisticFloat64("value", AggregationTypeSum)
+		agg := &AggregatedObservation{
+			Statistic: stat,
+			Value:     float64(0),
+			Count:     0,
+		}
+
+		agg.Record(stat.Observe(10.5))
+		agg.Record(stat.Observe(20.3))
+		agg.Record(stat.Observe(30.2))
+
+		require.Equal(t, 3, agg.Count)
+		require.InDelta(t, 61.0, agg.Value.(float64), 0.001)
+	})
+
+	t.Run("min aggregation int64", func(t *testing.T) {
+		stat := NewStatisticInt64("value", AggregationTypeMin)
+		agg := &AggregatedObservation{
+			Statistic: stat,
+			Value:     int64(30),
+			Count:     0,
+		}
+
+		agg.Record(stat.Observe(30))
+		agg.Record(stat.Observe(10))
+		agg.Record(stat.Observe(20))
+
+		require.Equal(t, 3, agg.Count)
+		require.Equal(t, int64(10), agg.Value.(int64))
+	})
+
+	t.Run("min aggregation float64", func(t *testing.T) {
+		stat := NewStatisticFloat64("value", AggregationTypeMin)
+		agg := &AggregatedObservation{
+			Statistic: stat,
+			Value:     float64(30.5),
+			Count:     0,
+		}
+
+		agg.Record(stat.Observe(30.5))
+		agg.Record(stat.Observe(10.2))
+		agg.Record(stat.Observe(20.8))
+
+		require.Equal(t, 3, agg.Count)
+		require.Equal(t, float64(10.2), agg.Value.(float64))
+	})
+
+	t.Run("max aggregation int64", func(t *testing.T) {
+		stat := NewStatisticInt64("value", AggregationTypeMax)
+		agg := &AggregatedObservation{
+			Statistic: stat,
+			Value:     int64(10),
+			Count:     0,
+		}
+
+		agg.Record(stat.Observe(10))
+		agg.Record(stat.Observe(30))
+		agg.Record(stat.Observe(20))
+
+		require.Equal(t, 3, agg.Count)
+		require.Equal(t, int64(30), agg.Value.(int64))
+	})
+
+	t.Run("max aggregation float64", func(t *testing.T) {
+		stat := NewStatisticFloat64("value", AggregationTypeMax)
+		agg := &AggregatedObservation{
+			Statistic: stat,
+			Value:     float64(10.0),
+			Count:     0,
+		}
+
+		agg.Record(stat.Observe(10.0))
+		agg.Record(stat.Observe(30.5))
+		agg.Record(stat.Observe(20.8))
+
+		require.Equal(t, 3, agg.Count)
+		require.Equal(t, float64(30.5), agg.Value.(float64))
+	})
+
+	t.Run("max aggregation bool flag", func(t *testing.T) {
+		stat := NewStatisticFlag("success")
+		agg := &AggregatedObservation{
+			Statistic: stat,
+			Value:     false,
+			Count:     0,
+		}
+
+		agg.Record(stat.Observe(false))
+		agg.Record(stat.Observe(false))
+		agg.Record(stat.Observe(true))
+
+		require.Equal(t, 3, agg.Count)
+		require.Equal(t, true, agg.Value.(bool))
+	})
+
+	t.Run("last aggregation int64", func(t *testing.T) {
+		stat := NewStatisticInt64("value", AggregationTypeLast)
+		agg := &AggregatedObservation{
+			Statistic: stat,
+			Value:     int64(0),
+			Count:     0,
+		}
+
+		agg.Record(stat.Observe(10))
+		agg.Record(stat.Observe(20))
+		agg.Record(stat.Observe(30))
+
+		require.Equal(t, 3, agg.Count)
+		require.Equal(t, int64(30), agg.Value.(int64))
+	})
+
+	t.Run("first aggregation int64", func(t *testing.T) {
+		stat := NewStatisticInt64("value", AggregationTypeFirst)
+		agg := &AggregatedObservation{
+			Statistic: stat,
+			Value:     nil,
+			Count:     0,
+		}
+
+		agg.Record(stat.Observe(10))
+		agg.Record(stat.Observe(20))
+		agg.Record(stat.Observe(30))
+
+		require.Equal(t, 3, agg.Count)
+		require.Equal(t, int64(10), agg.Value.(int64))
+	})
+}
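
The tests above pin down per-type semantics for AggregatedObservation.Record: sum accumulates values, min and max keep the extremes (max over a bool flag behaves as a logical OR), last keeps the newest observation, and first keeps the earliest. Record's implementation is not part of this commit view, so the following is only a sketch of those asserted semantics, with a simplified hypothetical record function in place of the real xcap types:

package main

import "fmt"

// aggKind is a hypothetical stand-in for the aggregation types the tests
// exercise; the constant names mirror, but are not, the xcap constants.
type aggKind int

const (
	aggSum aggKind = iota
	aggMin
	aggMax
	aggLast
	aggFirst
)

// record folds one int64 observation into the running value according to
// the aggregation kind, matching what the tests assert. count is the
// number of observations recorded so far, including this one on return.
func record(kind aggKind, value, obs int64, count int) (int64, int) {
	count++
	switch kind {
	case aggSum:
		value += obs // running total
	case aggMin:
		if count == 1 || obs < value {
			value = obs // keep the smallest observation
		}
	case aggMax:
		if count == 1 || obs > value {
			value = obs // keep the largest observation
		}
	case aggLast:
		value = obs // always overwrite with the newest observation
	case aggFirst:
		if count == 1 {
			value = obs // keep only the earliest observation
		}
	}
	return value, count
}

func main() {
	var v int64
	var n int
	for _, obs := range []int64{10, 20, 30} {
		v, n = record(aggSum, v, obs, n)
	}
	fmt.Println(v, n) // prints "60 3", matching the sum int64 test
}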

pkg/xcap/capture.go

Lines changed: 2 additions & 5 deletions
@@ -103,12 +103,9 @@ func (c *Capture) getAllStatistics() []Statistic {
 	return result
 }
 
-// MergeCaptures returns a new [Capture] by merging all regions from the input captures.
-//
-// If linkByAttribute is non-empty, MergeCaptures will establish parent-child relationships
-// between regions that don't already have a parent. For each region without a parent:
+// LinkRegions links root regions based on the provided link attribute and resolveParent function.
 // - It extracts the value of the linkByAttribute (must be a string attribute)
-// - It calls resolveParent() with that value to determine the parent's attribute value
+// - Calls resolveParent() with that value to determine the parent's attribute value
 // - It finds the region with the matching attribute value and sets it as the parent
 //
 // Use a linkByAttribute that is unique for each region.
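
The reworded comment describes LinkRegions as a pass over parentless (root) regions: read each region's linkByAttribute value, feed it through resolveParent to get the parent's attribute value, and attach the region to whichever region carries that value. The function body is not shown in this diff, so the following is only a sketch of the documented algorithm, with a hypothetical region type standing in for the real xcap structures:

package main

import "fmt"

// region is a hypothetical stand-in for an xcap region: string attributes
// plus an optional parent pointer.
type region struct {
	name   string
	attrs  map[string]string
	parent *region
}

// linkRegions links each parentless region to the region whose
// linkByAttribute value equals resolveParent(child's value), following the
// steps listed in the LinkRegions doc comment.
func linkRegions(regions []*region, linkByAttribute string, resolveParent func(string) string) {
	// Index regions by their link attribute; the doc comment asks callers
	// to use a linkByAttribute that is unique per region.
	byAttr := make(map[string]*region, len(regions))
	for _, r := range regions {
		if v, ok := r.attrs[linkByAttribute]; ok {
			byAttr[v] = r
		}
	}
	for _, r := range regions {
		if r.parent != nil {
			continue // only regions without a parent get linked
		}
		v, ok := r.attrs[linkByAttribute]
		if !ok {
			continue
		}
		if p, ok := byAttr[resolveParent(v)]; ok && p != r {
			r.parent = p
		}
	}
}

func main() {
	root := &region{name: "query", attrs: map[string]string{"task": "root"}}
	scan := &region{name: "scan", attrs: map[string]string{"task": "scan-1"}}
	linkRegions([]*region{root, scan}, "task", func(string) string { return "root" })
	fmt.Println(scan.parent.name) // prints "query"
}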
