diff --git a/go.mod b/go.mod index ee68b54cf..2c0f3430e 100644 --- a/go.mod +++ b/go.mod @@ -3,7 +3,7 @@ module github.com/mozilla-services/autograph go 1.22.1 require ( - github.com/DataDog/datadog-go v3.7.2+incompatible + github.com/DataDog/datadog-go v4.8.3+incompatible github.com/aws/aws-lambda-go v1.47.0 github.com/aws/aws-sdk-go-v2 v1.32.6 github.com/aws/aws-sdk-go-v2/config v1.28.6 @@ -39,6 +39,7 @@ require ( github.com/Azure/go-autorest/autorest/validation v0.2.0 // indirect github.com/Azure/go-autorest/logger v0.1.0 // indirect github.com/Azure/go-autorest/tracing v0.5.0 // indirect + github.com/Microsoft/go-winio v0.6.2 // indirect github.com/aws/aws-sdk-go v1.42.15 // indirect github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.7 // indirect github.com/aws/aws-sdk-go-v2/credentials v1.17.47 // indirect diff --git a/go.sum b/go.sum index 6592d3ee2..b7bb11e83 100644 --- a/go.sum +++ b/go.sum @@ -45,8 +45,10 @@ github.com/Azure/go-autorest/tracing v0.5.0 h1:TRn4WjSnkcSy5AEG3pnbtFSwNtwzjr4VY github.com/Azure/go-autorest/tracing v0.5.0/go.mod h1:r/s2XiOKccPW3HrqB+W0TQzfbtp2fGCgRFtBroKn4Dk= github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= github.com/BurntSushi/xgb v0.0.0-20160522181843-27f122750802/go.mod h1:IVnqGOEym/WlBOVXweHU+Q+/VP0lqqI8lqeDx9IjBqo= -github.com/DataDog/datadog-go v3.7.2+incompatible h1:o4QtYjBU/rG58VPh8Ne6F65YiMY5/v5q4WdY/HvRYMQ= -github.com/DataDog/datadog-go v3.7.2+incompatible/go.mod h1:LButxg5PwREeZtORoXG3tL4fMGNddJ+vMq1mwgfaqoQ= +github.com/DataDog/datadog-go v4.8.3+incompatible h1:fNGaYSuObuQb5nzeTQqowRAd9bpDIRRV4/gUtIBjh8Q= +github.com/DataDog/datadog-go v4.8.3+incompatible/go.mod h1:LButxg5PwREeZtORoXG3tL4fMGNddJ+vMq1mwgfaqoQ= +github.com/Microsoft/go-winio v0.6.2 h1:F2VQgta7ecxGYO8k3ZZz3RS8fVIXVxONVUPlNERoyfY= +github.com/Microsoft/go-winio v0.6.2/go.mod h1:yd8OoFMLzJbo9gZq8j5qaps8bJ9aShtEA8Ipt1oGCvU= github.com/Shopify/sarama v1.19.0/go.mod h1:FVkBWblsNy7DGZRfXLU0O9RCGt5g3g3yEuWXgklEdEo= github.com/Shopify/toxiproxy v2.1.4+incompatible/go.mod h1:OXgGpZ6Cli1/URJOF1DMxUHB2q5Ap20/P/eIdh4G0pI= github.com/alecthomas/template v0.0.0-20160405071501-a0175ee3bccc/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc= diff --git a/vendor/github.com/DataDog/datadog-go/statsd/aggregator.go b/vendor/github.com/DataDog/datadog-go/statsd/aggregator.go index f63e3239c..c4446a23b 100644 --- a/vendor/github.com/DataDog/datadog-go/statsd/aggregator.go +++ b/vendor/github.com/DataDog/datadog-go/statsd/aggregator.go @@ -3,39 +3,67 @@ package statsd import ( "strings" "sync" + "sync/atomic" "time" ) type ( - countsMap map[string]*countMetric - gaugesMap map[string]*gaugeMetric - setsMap map[string]*setMetric + countsMap map[string]*countMetric + gaugesMap map[string]*gaugeMetric + setsMap map[string]*setMetric + bufferedMetricMap map[string]*bufferedMetric ) type aggregator struct { - client *Client + nbContextGauge int32 + nbContextCount int32 + nbContextSet int32 - counts countsMap countsM sync.RWMutex - - gauges gaugesMap gaugesM sync.RWMutex + setsM sync.RWMutex - sets setsMap - setsM sync.RWMutex + gauges gaugesMap + counts countsMap + sets setsMap + histograms bufferedMetricContexts + distributions bufferedMetricContexts + timings bufferedMetricContexts closed chan struct{} - exited chan struct{} + + client *Client + + // aggregator implements ChannelMode mechanism to receive histograms, + // distributions and timings. Since they need sampling they need to + // lock for random. When using both ChannelMode and ExtendedAggregation + // we don't want goroutine to fight over the lock. + inputMetrics chan metric + stopChannelMode chan struct{} + wg sync.WaitGroup +} + +type aggregatorMetrics struct { + nbContext int32 + nbContextGauge int32 + nbContextCount int32 + nbContextSet int32 + nbContextHistogram int32 + nbContextDistribution int32 + nbContextTiming int32 } func newAggregator(c *Client) *aggregator { return &aggregator{ - client: c, - counts: countsMap{}, - gauges: gaugesMap{}, - sets: setsMap{}, - closed: make(chan struct{}), - exited: make(chan struct{}), + client: c, + counts: countsMap{}, + gauges: gaugesMap{}, + sets: setsMap{}, + histograms: newBufferedContexts(newHistogramMetric), + distributions: newBufferedContexts(newDistributionMetric), + timings: newBufferedContexts(newTimingMetric), + closed: make(chan struct{}), + stopChannelMode: make(chan struct{}), } } @@ -46,25 +74,72 @@ func (a *aggregator) start(flushInterval time.Duration) { for { select { case <-ticker.C: - a.sendMetrics() + a.flush() case <-a.closed: - close(a.exited) return } } }() } -func (a *aggregator) sendMetrics() { - for _, m := range a.flushMetrics() { - a.client.send(m) +func (a *aggregator) startReceivingMetric(bufferSize int, nbWorkers int) { + a.inputMetrics = make(chan metric, bufferSize) + for i := 0; i < nbWorkers; i++ { + a.wg.Add(1) + go a.pullMetric() } } +func (a *aggregator) stopReceivingMetric() { + close(a.stopChannelMode) + a.wg.Wait() +} + func (a *aggregator) stop() { - close(a.closed) - <-a.exited - a.sendMetrics() + a.closed <- struct{}{} +} + +func (a *aggregator) pullMetric() { + for { + select { + case m := <-a.inputMetrics: + switch m.metricType { + case histogram: + a.histogram(m.name, m.fvalue, m.tags, m.rate) + case distribution: + a.distribution(m.name, m.fvalue, m.tags, m.rate) + case timing: + a.timing(m.name, m.fvalue, m.tags, m.rate) + } + case <-a.stopChannelMode: + a.wg.Done() + return + } + } +} + +func (a *aggregator) flush() { + for _, m := range a.flushMetrics() { + a.client.sendBlocking(m) + } +} + +func (a *aggregator) flushTelemetryMetrics() *aggregatorMetrics { + if a == nil { + return nil + } + + am := &aggregatorMetrics{ + nbContextGauge: atomic.SwapInt32(&a.nbContextGauge, 0), + nbContextCount: atomic.SwapInt32(&a.nbContextCount, 0), + nbContextSet: atomic.SwapInt32(&a.nbContextSet, 0), + nbContextHistogram: a.histograms.resetAndGetNbContext(), + nbContextDistribution: a.distributions.resetAndGetNbContext(), + nbContextTiming: a.timings.resetAndGetNbContext(), + } + + am.nbContext = am.nbContextGauge + am.nbContextCount + am.nbContextSet + am.nbContextHistogram + am.nbContextDistribution + am.nbContextTiming + return am } func (a *aggregator) flushMetrics() []metric { @@ -91,23 +166,35 @@ func (a *aggregator) flushMetrics() []metric { metrics = append(metrics, g.flushUnsafe()) } - a.countsM.RLock() + a.countsM.Lock() counts := a.counts a.counts = countsMap{} - a.countsM.RUnlock() + a.countsM.Unlock() for _, c := range counts { metrics = append(metrics, c.flushUnsafe()) } + metrics = a.histograms.flush(metrics) + metrics = a.distributions.flush(metrics) + metrics = a.timings.flush(metrics) + + atomic.AddInt32(&a.nbContextCount, int32(len(counts))) + atomic.AddInt32(&a.nbContextGauge, int32(len(gauges))) + atomic.AddInt32(&a.nbContextSet, int32(len(sets))) return metrics } func getContext(name string, tags []string) string { - return name + ":" + strings.Join(tags, ",") + return name + ":" + strings.Join(tags, tagSeparatorSymbol) } -func (a *aggregator) count(name string, value int64, tags []string, rate float64) error { +func getContextAndTags(name string, tags []string) (string, string) { + stringTags := strings.Join(tags, tagSeparatorSymbol) + return name + ":" + stringTags, stringTags +} + +func (a *aggregator) count(name string, value int64, tags []string) error { context := getContext(name, tags) a.countsM.RLock() if count, found := a.counts[context]; found { @@ -118,12 +205,19 @@ func (a *aggregator) count(name string, value int64, tags []string, rate float64 a.countsM.RUnlock() a.countsM.Lock() - a.counts[context] = newCountMetric(name, value, tags, rate) + // Check if another goroutines hasn't created the value betwen the RUnlock and 'Lock' + if count, found := a.counts[context]; found { + count.sample(value) + a.countsM.Unlock() + return nil + } + + a.counts[context] = newCountMetric(name, value, tags) a.countsM.Unlock() return nil } -func (a *aggregator) gauge(name string, value float64, tags []string, rate float64) error { +func (a *aggregator) gauge(name string, value float64, tags []string) error { context := getContext(name, tags) a.gaugesM.RLock() if gauge, found := a.gauges[context]; found { @@ -133,15 +227,21 @@ func (a *aggregator) gauge(name string, value float64, tags []string, rate float } a.gaugesM.RUnlock() - gauge := newGaugeMetric(name, value, tags, rate) + gauge := newGaugeMetric(name, value, tags) a.gaugesM.Lock() + // Check if another goroutines hasn't created the value betwen the 'RUnlock' and 'Lock' + if gauge, found := a.gauges[context]; found { + gauge.sample(value) + a.gaugesM.Unlock() + return nil + } a.gauges[context] = gauge a.gaugesM.Unlock() return nil } -func (a *aggregator) set(name string, value string, tags []string, rate float64) error { +func (a *aggregator) set(name string, value string, tags []string) error { context := getContext(name, tags) a.setsM.RLock() if set, found := a.sets[context]; found { @@ -152,7 +252,32 @@ func (a *aggregator) set(name string, value string, tags []string, rate float64) a.setsM.RUnlock() a.setsM.Lock() - a.sets[context] = newSetMetric(name, value, tags, rate) + // Check if another goroutines hasn't created the value betwen the 'RUnlock' and 'Lock' + if set, found := a.sets[context]; found { + set.sample(value) + a.setsM.Unlock() + return nil + } + a.sets[context] = newSetMetric(name, value, tags) a.setsM.Unlock() return nil } + +// Only histograms, distributions and timings are sampled with a rate since we +// only pack them in on message instead of aggregating them. Discarding the +// sample rate will have impacts on the CPU and memory usage of the Agent. + +// type alias for Client.sendToAggregator +type bufferedMetricSampleFunc func(name string, value float64, tags []string, rate float64) error + +func (a *aggregator) histogram(name string, value float64, tags []string, rate float64) error { + return a.histograms.sample(name, value, tags, rate) +} + +func (a *aggregator) distribution(name string, value float64, tags []string, rate float64) error { + return a.distributions.sample(name, value, tags, rate) +} + +func (a *aggregator) timing(name string, value float64, tags []string, rate float64) error { + return a.timings.sample(name, value, tags, rate) +} diff --git a/vendor/github.com/DataDog/datadog-go/statsd/buffer.go b/vendor/github.com/DataDog/datadog-go/statsd/buffer.go index c38e229ab..34dc0d3fb 100644 --- a/vendor/github.com/DataDog/datadog-go/statsd/buffer.go +++ b/vendor/github.com/DataDog/datadog-go/statsd/buffer.go @@ -1,11 +1,21 @@ package statsd +import ( + "strconv" +) + type bufferFullError string func (e bufferFullError) Error() string { return string(e) } const errBufferFull = bufferFullError("statsd buffer is full") +type partialWriteError string + +func (e partialWriteError) Error() string { return string(e) } + +const errPartialWrite = partialWriteError("value partially written") + const metricOverhead = 512 // statsdBuffer is a buffer containing statsd messages @@ -30,8 +40,8 @@ func (b *statsdBuffer) writeGauge(namespace string, globalTags []string, name st return errBufferFull } originalBuffer := b.buffer - b.writeSeparator() b.buffer = appendGauge(b.buffer, namespace, globalTags, name, value, tags, rate) + b.writeSeparator() return b.validateNewElement(originalBuffer) } @@ -40,8 +50,8 @@ func (b *statsdBuffer) writeCount(namespace string, globalTags []string, name st return errBufferFull } originalBuffer := b.buffer - b.writeSeparator() b.buffer = appendCount(b.buffer, namespace, globalTags, name, value, tags, rate) + b.writeSeparator() return b.validateNewElement(originalBuffer) } @@ -50,18 +60,70 @@ func (b *statsdBuffer) writeHistogram(namespace string, globalTags []string, nam return errBufferFull } originalBuffer := b.buffer - b.writeSeparator() b.buffer = appendHistogram(b.buffer, namespace, globalTags, name, value, tags, rate) + b.writeSeparator() return b.validateNewElement(originalBuffer) } +// writeAggregated serialized as many values as possible in the current buffer and return the position in values where it stopped. +func (b *statsdBuffer) writeAggregated(metricSymbol []byte, namespace string, globalTags []string, name string, values []float64, tags string, tagSize int, precision int) (int, error) { + if b.elementCount >= b.maxElements { + return 0, errBufferFull + } + + originalBuffer := b.buffer + b.buffer = appendHeader(b.buffer, namespace, name) + + // buffer already full + if len(b.buffer)+tagSize > b.maxSize { + b.buffer = originalBuffer + return 0, errBufferFull + } + + // We add as many value as possible + var position int + for idx, v := range values { + previousBuffer := b.buffer + if idx != 0 { + b.buffer = append(b.buffer, ':') + } + + b.buffer = strconv.AppendFloat(b.buffer, v, 'f', precision, 64) + + // Should we stop serializing and switch to another buffer + if len(b.buffer)+tagSize > b.maxSize { + b.buffer = previousBuffer + break + } + position = idx + 1 + } + + // we could not add a single value + if position == 0 { + b.buffer = originalBuffer + return 0, errBufferFull + } + + b.buffer = append(b.buffer, '|') + b.buffer = append(b.buffer, metricSymbol...) + b.buffer = appendTagsAggregated(b.buffer, globalTags, tags) + b.writeSeparator() + b.elementCount++ + + if position != len(values) { + return position, errPartialWrite + } + return position, nil + +} + func (b *statsdBuffer) writeDistribution(namespace string, globalTags []string, name string, value float64, tags []string, rate float64) error { if b.elementCount >= b.maxElements { return errBufferFull } originalBuffer := b.buffer - b.writeSeparator() b.buffer = appendDistribution(b.buffer, namespace, globalTags, name, value, tags, rate) + b.writeSeparator() return b.validateNewElement(originalBuffer) } @@ -70,8 +132,8 @@ func (b *statsdBuffer) writeSet(namespace string, globalTags []string, name stri return errBufferFull } originalBuffer := b.buffer - b.writeSeparator() b.buffer = appendSet(b.buffer, namespace, globalTags, name, value, tags, rate) + b.writeSeparator() return b.validateNewElement(originalBuffer) } @@ -80,8 +142,8 @@ func (b *statsdBuffer) writeTiming(namespace string, globalTags []string, name s return errBufferFull } originalBuffer := b.buffer - b.writeSeparator() b.buffer = appendTiming(b.buffer, namespace, globalTags, name, value, tags, rate) + b.writeSeparator() return b.validateNewElement(originalBuffer) } @@ -90,8 +152,8 @@ func (b *statsdBuffer) writeEvent(event Event, globalTags []string) error { return errBufferFull } originalBuffer := b.buffer - b.writeSeparator() b.buffer = appendEvent(b.buffer, event, globalTags) + b.writeSeparator() return b.validateNewElement(originalBuffer) } @@ -100,8 +162,8 @@ func (b *statsdBuffer) writeServiceCheck(serviceCheck ServiceCheck, globalTags [ return errBufferFull } originalBuffer := b.buffer - b.writeSeparator() b.buffer = appendServiceCheck(b.buffer, serviceCheck, globalTags) + b.writeSeparator() return b.validateNewElement(originalBuffer) } @@ -115,9 +177,7 @@ func (b *statsdBuffer) validateNewElement(originalBuffer []byte) error { } func (b *statsdBuffer) writeSeparator() { - if b.elementCount != 0 { - b.buffer = appendSeparator(b.buffer) - } + b.buffer = append(b.buffer, '\n') } func (b *statsdBuffer) reset() { diff --git a/vendor/github.com/DataDog/datadog-go/statsd/buffered_metric_context.go b/vendor/github.com/DataDog/datadog-go/statsd/buffered_metric_context.go new file mode 100644 index 000000000..15bba2861 --- /dev/null +++ b/vendor/github.com/DataDog/datadog-go/statsd/buffered_metric_context.go @@ -0,0 +1,82 @@ +package statsd + +import ( + "math/rand" + "sync" + "sync/atomic" + "time" +) + +// bufferedMetricContexts represent the contexts for Histograms, Distributions +// and Timing. Since those 3 metric types behave the same way and are sampled +// with the same type they're represented by the same class. +type bufferedMetricContexts struct { + nbContext int32 + mutex sync.RWMutex + values bufferedMetricMap + newMetric func(string, float64, string) *bufferedMetric + + // Each bufferedMetricContexts uses its own random source and random + // lock to prevent goroutines from contending for the lock on the + // "math/rand" package-global random source (e.g. calls like + // "rand.Float64()" must acquire a shared lock to get the next + // pseudorandom number). + random *rand.Rand + randomLock sync.Mutex +} + +func newBufferedContexts(newMetric func(string, float64, string) *bufferedMetric) bufferedMetricContexts { + return bufferedMetricContexts{ + values: bufferedMetricMap{}, + newMetric: newMetric, + // Note that calling "time.Now().UnixNano()" repeatedly quickly may return + // very similar values. That's fine for seeding the worker-specific random + // source because we just need an evenly distributed stream of float values. + // Do not use this random source for cryptographic randomness. + random: rand.New(rand.NewSource(time.Now().UnixNano())), + } +} + +func (bc *bufferedMetricContexts) flush(metrics []metric) []metric { + bc.mutex.Lock() + values := bc.values + bc.values = bufferedMetricMap{} + bc.mutex.Unlock() + + for _, d := range values { + metrics = append(metrics, d.flushUnsafe()) + } + atomic.AddInt32(&bc.nbContext, int32(len(values))) + return metrics +} + +func (bc *bufferedMetricContexts) sample(name string, value float64, tags []string, rate float64) error { + if !shouldSample(rate, bc.random, &bc.randomLock) { + return nil + } + + context, stringTags := getContextAndTags(name, tags) + + bc.mutex.RLock() + if v, found := bc.values[context]; found { + v.sample(value) + bc.mutex.RUnlock() + return nil + } + bc.mutex.RUnlock() + + bc.mutex.Lock() + // Check if another goroutines hasn't created the value betwen the 'RUnlock' and 'Lock' + if v, found := bc.values[context]; found { + v.sample(value) + bc.mutex.Unlock() + return nil + } + bc.values[context] = bc.newMetric(name, value, stringTags) + bc.mutex.Unlock() + return nil +} + +func (bc *bufferedMetricContexts) resetAndGetNbContext() int32 { + return atomic.SwapInt32(&bc.nbContext, 0) +} diff --git a/vendor/github.com/DataDog/datadog-go/statsd/event.go b/vendor/github.com/DataDog/datadog-go/statsd/event.go index cf966d1a8..35c535362 100644 --- a/vendor/github.com/DataDog/datadog-go/statsd/event.go +++ b/vendor/github.com/DataDog/datadog-go/statsd/event.go @@ -71,9 +71,6 @@ func (e Event) Check() error { if len(e.Title) == 0 { return fmt.Errorf("statsd.Event title is required") } - if len(e.Text) == 0 { - return fmt.Errorf("statsd.Event text is required") - } return nil } diff --git a/vendor/github.com/DataDog/datadog-go/statsd/format.go b/vendor/github.com/DataDog/datadog-go/statsd/format.go index bd856a048..8d62aa7ba 100644 --- a/vendor/github.com/DataDog/datadog-go/statsd/format.go +++ b/vendor/github.com/DataDog/datadog-go/statsd/format.go @@ -12,6 +12,7 @@ var ( distributionSymbol = []byte("d") setSymbol = []byte("s") timingSymbol = []byte("ms") + tagSeparatorSymbol = "," ) func appendHeader(buffer []byte, namespace string, name string) []byte { @@ -54,14 +55,14 @@ func appendTags(buffer []byte, globalTags []string, tags []string) []byte { for _, tag := range globalTags { if !firstTag { - buffer = append(buffer, ',') + buffer = append(buffer, tagSeparatorSymbol...) } buffer = appendWithoutNewlines(buffer, tag) firstTag = false } for _, tag := range tags { if !firstTag { - buffer = append(buffer, ',') + buffer = append(buffer, tagSeparatorSymbol...) } buffer = appendWithoutNewlines(buffer, tag) firstTag = false @@ -69,6 +70,30 @@ func appendTags(buffer []byte, globalTags []string, tags []string) []byte { return buffer } +func appendTagsAggregated(buffer []byte, globalTags []string, tags string) []byte { + if len(globalTags) == 0 && tags == "" { + return buffer + } + + buffer = append(buffer, "|#"...) + firstTag := true + + for _, tag := range globalTags { + if !firstTag { + buffer = append(buffer, tagSeparatorSymbol...) + } + buffer = appendWithoutNewlines(buffer, tag) + firstTag = false + } + if tags != "" { + if !firstTag { + buffer = append(buffer, tagSeparatorSymbol...) + } + buffer = appendWithoutNewlines(buffer, tags) + } + return buffer +} + func appendFloatMetric(buffer []byte, typeSymbol []byte, namespace string, globalTags []string, name string, value float64, tags []string, rate float64, precision int) []byte { buffer = appendHeader(buffer, namespace, name) buffer = strconv.AppendFloat(buffer, value, 'f', precision, 64) @@ -143,7 +168,7 @@ func appendEvent(buffer []byte, event Event, globalTags []string) []byte { buffer = append(buffer, "_e{"...) buffer = strconv.AppendInt(buffer, int64(len(event.Title)), 10) - buffer = append(buffer, ',') + buffer = append(buffer, tagSeparatorSymbol...) buffer = strconv.AppendInt(buffer, int64(escapedTextLen), 10) buffer = append(buffer, "}:"...) buffer = append(buffer, event.Title...) diff --git a/vendor/github.com/DataDog/datadog-go/statsd/metrics.go b/vendor/github.com/DataDog/datadog-go/statsd/metrics.go index 4f582e978..99ed4da53 100644 --- a/vendor/github.com/DataDog/datadog-go/statsd/metrics.go +++ b/vendor/github.com/DataDog/datadog-go/statsd/metrics.go @@ -17,15 +17,13 @@ type countMetric struct { value int64 name string tags []string - rate float64 } -func newCountMetric(name string, value int64, tags []string, rate float64) *countMetric { +func newCountMetric(name string, value int64, tags []string) *countMetric { return &countMetric{ value: value, name: name, tags: tags, - rate: rate, } } @@ -38,7 +36,7 @@ func (c *countMetric) flushUnsafe() metric { metricType: count, name: c.name, tags: c.tags, - rate: c.rate, + rate: 1, ivalue: c.value, } } @@ -49,15 +47,13 @@ type gaugeMetric struct { value uint64 name string tags []string - rate float64 } -func newGaugeMetric(name string, value float64, tags []string, rate float64) *gaugeMetric { +func newGaugeMetric(name string, value float64, tags []string) *gaugeMetric { return &gaugeMetric{ value: math.Float64bits(value), name: name, tags: tags, - rate: rate, } } @@ -70,7 +66,7 @@ func (g *gaugeMetric) flushUnsafe() metric { metricType: gauge, name: g.name, tags: g.tags, - rate: g.rate, + rate: 1, fvalue: math.Float64frombits(g.value), } } @@ -81,16 +77,14 @@ type setMetric struct { data map[string]struct{} name string tags []string - rate float64 sync.Mutex } -func newSetMetric(name string, value string, tags []string, rate float64) *setMetric { +func newSetMetric(name string, value string, tags []string) *setMetric { set := &setMetric{ data: map[string]struct{}{}, name: name, tags: tags, - rate: rate, } set.data[value] = struct{}{} return set @@ -116,10 +110,72 @@ func (s *setMetric) flushUnsafe() []metric { metricType: set, name: s.name, tags: s.tags, - rate: s.rate, + rate: 1, svalue: value, } i++ } return metrics } + +// Histograms, Distributions and Timings + +type bufferedMetric struct { + sync.Mutex + + data []float64 + name string + // Histograms and Distributions store tags as one string since we need + // to compute its size multiple time when serializing. + tags string + mtype metricType +} + +func (s *bufferedMetric) sample(v float64) { + s.Lock() + defer s.Unlock() + s.data = append(s.data, v) +} + +func (s *bufferedMetric) flushUnsafe() metric { + return metric{ + metricType: s.mtype, + name: s.name, + stags: s.tags, + rate: 1, + fvalues: s.data, + } +} + +type histogramMetric = bufferedMetric + +func newHistogramMetric(name string, value float64, stringTags string) *histogramMetric { + return &histogramMetric{ + data: []float64{value}, + name: name, + tags: stringTags, + mtype: histogramAggregated, + } +} + +type distributionMetric = bufferedMetric + +func newDistributionMetric(name string, value float64, stringTags string) *distributionMetric { + return &distributionMetric{ + data: []float64{value}, + name: name, + tags: stringTags, + mtype: distributionAggregated, + } +} + +type timingMetric = bufferedMetric + +func newTimingMetric(name string, value float64, stringTags string) *timingMetric { + return &timingMetric{ + data: []float64{value}, + name: name, + tags: stringTags, + mtype: timingAggregated, + } +} diff --git a/vendor/github.com/DataDog/datadog-go/statsd/options.go b/vendor/github.com/DataDog/datadog-go/statsd/options.go index e3c4eaedf..9db27039c 100644 --- a/vendor/github.com/DataDog/datadog-go/statsd/options.go +++ b/vendor/github.com/DataDog/datadog-go/statsd/options.go @@ -1,6 +1,7 @@ package statsd import ( + "fmt" "math" "strings" "time" @@ -24,17 +25,23 @@ var ( // DefaultSenderQueueSize is the default value for the DefaultSenderQueueSize option DefaultSenderQueueSize = 0 // DefaultWriteTimeoutUDS is the default value for the WriteTimeoutUDS option - DefaultWriteTimeoutUDS = 1 * time.Millisecond + DefaultWriteTimeoutUDS = 100 * time.Millisecond // DefaultTelemetry is the default value for the Telemetry option DefaultTelemetry = true - // DefaultReceivingingMode is the default behavior when sending metrics + // DefaultReceivingMode is the default behavior when sending metrics DefaultReceivingMode = MutexMode // DefaultChannelModeBufferSize is the default size of the channel holding incoming metrics DefaultChannelModeBufferSize = 4096 // DefaultAggregationFlushInterval is the default interval for the aggregator to flush metrics. - DefaultAggregationFlushInterval = 3 * time.Second + // This should divide the Agent reporting period (default=10s) evenly to reduce "aliasing" that + // can cause values to appear irregular. + DefaultAggregationFlushInterval = 2 * time.Second // DefaultAggregation DefaultAggregation = false + // DefaultExtendedAggregation + DefaultExtendedAggregation = false + // DefaultDevMode + DefaultDevMode = false ) // Options contains the configuration options for a client. @@ -93,8 +100,18 @@ type Options struct { ChannelModeBufferSize int // AggregationFlushInterval is the interval for the aggregator to flush metrics AggregationFlushInterval time.Duration - // [beta] Aggregation enables/disables client side aggregation + // [beta] Aggregation enables/disables client side aggregation for + // Gauges, Counts and Sets (compatible with every Agent's version). Aggregation bool + // [beta] Extended aggregation enables/disables client side aggregation + // for all types. This feature is only compatible with Agent's versions + // >=7.25.0 or Agent's version >=6.25.0 && < 7.0.0. + ExtendedAggregation bool + // TelemetryAddr specify a different endpoint for telemetry metrics. + TelemetryAddr string + // DevMode enables the "dev" mode where the client sends much more + // telemetry metrics to help troubleshooting the client behavior. + DevMode bool } func resolveOptions(options []Option) (*Options, error) { @@ -113,6 +130,8 @@ func resolveOptions(options []Option) (*Options, error) { ChannelModeBufferSize: DefaultChannelModeBufferSize, AggregationFlushInterval: DefaultAggregationFlushInterval, Aggregation: DefaultAggregation, + ExtendedAggregation: DefaultExtendedAggregation, + DevMode: DefaultDevMode, } for _, option := range options { @@ -183,6 +202,9 @@ func WithBufferFlushInterval(bufferFlushInterval time.Duration) Option { // WithBufferShardCount sets the BufferShardCount option. func WithBufferShardCount(bufferShardCount int) Option { return func(o *Options) error { + if bufferShardCount < 1 { + return fmt.Errorf("BufferShardCount must be a positive integer") + } o.BufferShardCount = bufferShardCount return nil } @@ -220,7 +242,7 @@ func WithChannelMode() Option { } } -// WithMutexModeMode will use mutex to receive metrics +// WithMutexMode will use mutex to receive metrics func WithMutexMode() Option { return func(o *Options) error { o.ReceiveMode = MutexMode @@ -244,7 +266,8 @@ func WithAggregationInterval(interval time.Duration) Option { } } -// WithClientSideAggregation enables client side aggregation. Client side aggregation is a beta feature. +// WithClientSideAggregation enables client side aggregation for Gauges, Counts +// and Sets. Client side aggregation is a beta feature. func WithClientSideAggregation() Option { return func(o *Options) error { o.Aggregation = true @@ -256,6 +279,45 @@ func WithClientSideAggregation() Option { func WithoutClientSideAggregation() Option { return func(o *Options) error { o.Aggregation = false + o.ExtendedAggregation = false + return nil + } +} + +// WithExtendedClientSideAggregation enables client side aggregation for all +// types. This feature is only compatible with Agent's version >=6.25.0 && +// <7.0.0 or Agent's versions >=7.25.0. Client side aggregation is a beta +// feature. +func WithExtendedClientSideAggregation() Option { + return func(o *Options) error { + o.Aggregation = true + o.ExtendedAggregation = true + return nil + } +} + +// WithTelemetryAddr specify a different address for telemetry metrics. +func WithTelemetryAddr(addr string) Option { + return func(o *Options) error { + o.TelemetryAddr = addr + return nil + } +} + +// WithDevMode enables client "dev" mode, sending more Telemetry metrics to +// help troubleshoot client behavior. +func WithDevMode() Option { + return func(o *Options) error { + o.DevMode = true + return nil + } +} + +// WithoutDevMode disables client "dev" mode, sending more Telemetry metrics to +// help troubleshoot client behavior. +func WithoutDevMode() Option { + return func(o *Options) error { + o.DevMode = false return nil } } diff --git a/vendor/github.com/DataDog/datadog-go/statsd/pipe.go b/vendor/github.com/DataDog/datadog-go/statsd/pipe.go new file mode 100644 index 000000000..0d098a182 --- /dev/null +++ b/vendor/github.com/DataDog/datadog-go/statsd/pipe.go @@ -0,0 +1,9 @@ +// +build !windows + +package statsd + +import "errors" + +func newWindowsPipeWriter(pipepath string) (statsdWriter, error) { + return nil, errors.New("Windows Named Pipes are only supported on Windows") +} diff --git a/vendor/github.com/DataDog/datadog-go/statsd/pipe_windows.go b/vendor/github.com/DataDog/datadog-go/statsd/pipe_windows.go new file mode 100644 index 000000000..f533b0248 --- /dev/null +++ b/vendor/github.com/DataDog/datadog-go/statsd/pipe_windows.go @@ -0,0 +1,84 @@ +// +build windows + +package statsd + +import ( + "net" + "sync" + "time" + + "github.com/Microsoft/go-winio" +) + +const defaultPipeTimeout = 1 * time.Millisecond + +type pipeWriter struct { + mu sync.RWMutex + conn net.Conn + timeout time.Duration + pipepath string +} + +func (p *pipeWriter) SetWriteTimeout(d time.Duration) error { + p.mu.Lock() + p.timeout = d + p.mu.Unlock() + return nil +} + +func (p *pipeWriter) Write(data []byte) (n int, err error) { + conn, err := p.ensureConnection() + if err != nil { + return 0, err + } + + p.mu.RLock() + conn.SetWriteDeadline(time.Now().Add(p.timeout)) + p.mu.RUnlock() + + n, err = conn.Write(data) + if err != nil { + if e, ok := err.(net.Error); !ok || !e.Temporary() { + // disconnected; retry again on next attempt + p.mu.Lock() + p.conn = nil + p.mu.Unlock() + } + } + return n, err +} + +func (p *pipeWriter) ensureConnection() (net.Conn, error) { + p.mu.RLock() + conn := p.conn + p.mu.RUnlock() + if conn != nil { + return conn, nil + } + + // looks like we might need to connect - try again with write locking. + p.mu.Lock() + defer p.mu.Unlock() + if p.conn != nil { + return p.conn, nil + } + newconn, err := winio.DialPipe(p.pipepath, nil) + if err != nil { + return nil, err + } + p.conn = newconn + return newconn, nil +} + +func (p *pipeWriter) Close() error { + return p.conn.Close() +} + +func newWindowsPipeWriter(pipepath string) (*pipeWriter, error) { + // Defer connection establishment to first write + return &pipeWriter{ + conn: nil, + timeout: defaultPipeTimeout, + pipepath: pipepath, + }, nil +} diff --git a/vendor/github.com/DataDog/datadog-go/statsd/sender.go b/vendor/github.com/DataDog/datadog-go/statsd/sender.go index 4b3ac7d2d..4c8eb2696 100644 --- a/vendor/github.com/DataDog/datadog-go/statsd/sender.go +++ b/vendor/github.com/DataDog/datadog-go/statsd/sender.go @@ -28,20 +28,22 @@ type SenderMetrics struct { } type sender struct { - transport statsdWriter - pool *bufferPool - queue chan *statsdBuffer - metrics *SenderMetrics - stop chan struct{} + transport statsdWriter + pool *bufferPool + queue chan *statsdBuffer + metrics *SenderMetrics + stop chan struct{} + flushSignal chan struct{} } func newSender(transport statsdWriter, queueSize int, pool *bufferPool) *sender { sender := &sender{ - transport: transport, - pool: pool, - queue: make(chan *statsdBuffer, queueSize), - metrics: &SenderMetrics{}, - stop: make(chan struct{}), + transport: transport, + pool: pool, + queue: make(chan *statsdBuffer, queueSize), + metrics: &SenderMetrics{}, + stop: make(chan struct{}), + flushSignal: make(chan struct{}), } go sender.sendLoop() @@ -95,11 +97,17 @@ func (s *sender) sendLoop() { s.write(buffer) case <-s.stop: return + case <-s.flushSignal: + // At that point we know that the workers are paused (the statsd client + // will pause them before calling sender.flush()). + // So we can fully flush the input queue + s.flushInputQueue() + s.flushSignal <- struct{}{} } } } -func (s *sender) flush() { +func (s *sender) flushInputQueue() { for { select { case buffer := <-s.queue: @@ -109,10 +117,14 @@ func (s *sender) flush() { } } } +func (s *sender) flush() { + s.flushSignal <- struct{}{} + <-s.flushSignal +} func (s *sender) close() error { s.stop <- struct{}{} <-s.stop - s.flush() + s.flushInputQueue() return s.transport.Close() } diff --git a/vendor/github.com/DataDog/datadog-go/statsd/statsd.go b/vendor/github.com/DataDog/datadog-go/statsd/statsd.go index 31a4ee439..19b5515a5 100644 --- a/vendor/github.com/DataDog/datadog-go/statsd/statsd.go +++ b/vendor/github.com/DataDog/datadog-go/statsd/statsd.go @@ -6,24 +6,12 @@ adding tags and histograms and pushing upstream to Datadog. Refer to http://docs.datadoghq.com/guides/dogstatsd/ for information about DogStatsD. -Example Usage: - - // Create the client - c, err := statsd.New("127.0.0.1:8125") - if err != nil { - log.Fatal(err) - } - // Prefix every metric with the app name - c.Namespace = "flubber." - // Send the EC2 availability zone as a tag with every metric - c.Tags = append(c.Tags, "us-east-1a") - err = c.Gauge("request.duration", 1.2, nil, 1) - statsd is based on go-statsd-client. */ package statsd import ( + "errors" "fmt" "os" "strings" @@ -57,44 +45,38 @@ const DefaultUDSBufferPoolSize = 512 /* DefaultMaxAgentPayloadSize is the default maximum payload size the agent can receive. This can be adjusted by changing dogstatsd_buffer_size in the -agent configuration file datadog.yaml. +agent configuration file datadog.yaml. This is also used as the optimal payload size +for UDS datagrams. */ const DefaultMaxAgentPayloadSize = 8192 /* -TelemetryInterval is the interval at which telemetry will be sent by the client. -*/ -const TelemetryInterval = 10 * time.Second - -/* -clientTelemetryTag is a tag identifying this specific client. -*/ -var clientTelemetryTag = "client:go" - -/* -clientVersionTelemetryTag is a tag identifying this specific client version. +UnixAddressPrefix holds the prefix to use to enable Unix Domain Socket +traffic instead of UDP. */ -var clientVersionTelemetryTag = "client_version:3.7.2" +const UnixAddressPrefix = "unix://" /* -UnixAddressPrefix holds the prefix to use to enable Unix Domain Socket +WindowsPipeAddressPrefix holds the prefix to use to enable Windows Named Pipes traffic instead of UDP. */ -const UnixAddressPrefix = "unix://" +const WindowsPipeAddressPrefix = `\\.\pipe\` + +const ( + agentHostEnvVarName = "DD_AGENT_HOST" + agentPortEnvVarName = "DD_DOGSTATSD_PORT" + defaultUDPPort = "8125" +) /* ddEnvTagsMapping is a mapping of each "DD_" prefixed environment variable -to a specific tag name. +to a specific tag name. We use a slice to keep the order and simplify tests. */ -var ddEnvTagsMapping = map[string]string{ - // Client-side entity ID injection for container tagging. - "DD_ENTITY_ID": "dd.internal.entity_id", - // The name of the env in which the service runs. - "DD_ENV": "env", - // The name of the running service. - "DD_SERVICE": "service", - // The current version of the running service. - "DD_VERSION": "version", +var ddEnvTagsMapping = []struct{ envName, tagName string }{ + {"DD_ENTITY_ID", "dd.internal.entity_id"}, // Client-side entity ID injection for container tagging. + {"DD_ENV", "env"}, // The name of the env in which the service runs. + {"DD_SERVICE", "service"}, // The name of the running service. + {"DD_VERSION", "version"}, // The current version of the running service. } type metricType int @@ -103,9 +85,12 @@ const ( gauge metricType = iota count histogram + histogramAggregated distribution + distributionAggregated set timing + timingAggregated event serviceCheck ) @@ -117,17 +102,25 @@ const ( ChannelMode ) +const ( + WriterNameUDP string = "udp" + WriterNameUDS string = "uds" + WriterWindowsPipe string = "pipe" +) + type metric struct { metricType metricType namespace string globalTags []string name string fvalue float64 + fvalues []float64 ivalue int64 svalue string evalue *Event scvalue *ServiceCheck tags []string + stags string rate float64 } @@ -205,70 +198,95 @@ type Client struct { // Tags are global tags to be added to every statsd call Tags []string // skipErrors turns off error passing and allows UDS to emulate UDP behaviour - SkipErrors bool - flushTime time.Duration - bufferPool *bufferPool - buffer *statsdBuffer - metrics *ClientMetrics - telemetryTags []string - stop chan struct{} - wg sync.WaitGroup - bufferShards []*worker - closerLock sync.Mutex - receiveMode ReceivingMode - agg *aggregator - options []Option - addrOption string + SkipErrors bool + flushTime time.Duration + metrics *ClientMetrics + telemetry *telemetryClient + stop chan struct{} + wg sync.WaitGroup + workers []*worker + closerLock sync.Mutex + workersMode ReceivingMode + aggregatorMode ReceivingMode + agg *aggregator + aggExtended *aggregator + options []Option + addrOption string } // ClientMetrics contains metrics about the client type ClientMetrics struct { - TotalMetrics uint64 - TotalEvents uint64 - TotalServiceChecks uint64 - TotalDroppedOnReceive uint64 + TotalMetrics uint64 + TotalMetricsGauge uint64 + TotalMetricsCount uint64 + TotalMetricsHistogram uint64 + TotalMetricsDistribution uint64 + TotalMetricsSet uint64 + TotalMetricsTiming uint64 + TotalEvents uint64 + TotalServiceChecks uint64 + TotalDroppedOnReceive uint64 } // Verify that Client implements the ClientInterface. // https://golang.org/doc/faq#guarantee_satisfies_interface var _ ClientInterface = &Client{} -// New returns a pointer to a new Client given an addr in the format "hostname:port" or -// "unix:///path/to/socket". +func resolveAddr(addr string) string { + envPort := "" + if addr == "" { + addr = os.Getenv(agentHostEnvVarName) + envPort = os.Getenv(agentPortEnvVarName) + } + + if addr == "" { + return "" + } + + if !strings.HasPrefix(addr, WindowsPipeAddressPrefix) && !strings.HasPrefix(addr, UnixAddressPrefix) { + if !strings.Contains(addr, ":") { + if envPort != "" { + addr = fmt.Sprintf("%s:%s", addr, envPort) + } else { + addr = fmt.Sprintf("%s:%s", addr, defaultUDPPort) + } + } + } + return addr +} + +func createWriter(addr string) (statsdWriter, string, error) { + addr = resolveAddr(addr) + if addr == "" { + return nil, "", errors.New("No address passed and autodetection from environment failed") + } + + switch { + case strings.HasPrefix(addr, WindowsPipeAddressPrefix): + w, err := newWindowsPipeWriter(addr) + return w, WriterWindowsPipe, err + case strings.HasPrefix(addr, UnixAddressPrefix): + w, err := newUDSWriter(addr[len(UnixAddressPrefix):]) + return w, WriterNameUDS, err + default: + w, err := newUDPWriter(addr) + return w, WriterNameUDP, err + } +} + +// New returns a pointer to a new Client given an addr in the format "hostname:port" for UDP, +// "unix:///path/to/socket" for UDS or "\\.\pipe\path\to\pipe" for Windows Named Pipes. func New(addr string, options ...Option) (*Client, error) { - var w statsdWriter o, err := resolveOptions(options) if err != nil { return nil, err } - var writerType string - optimalPayloadSize := OptimalUDPPayloadSize - defaultBufferPoolSize := DefaultUDPBufferPoolSize - if !strings.HasPrefix(addr, UnixAddressPrefix) { - w, err = newUDPWriter(addr) - writerType = "udp" - } else { - // FIXME: The agent has a performance pitfall preventing us from using better defaults here. - // Once it's fixed, use `DefaultMaxAgentPayloadSize` and `DefaultUDSBufferPoolSize` instead. - optimalPayloadSize = OptimalUDPPayloadSize - defaultBufferPoolSize = DefaultUDPBufferPoolSize - w, err = newUDSWriter(addr[len(UnixAddressPrefix):]) - writerType = "uds" - } + w, writerType, err := createWriter(addr) if err != nil { return nil, err } - if o.MaxBytesPerPayload == 0 { - o.MaxBytesPerPayload = optimalPayloadSize - } - if o.BufferPoolSize == 0 { - o.BufferPoolSize = defaultBufferPoolSize - } - if o.SenderQueueSize == 0 { - o.SenderQueueSize = defaultBufferPoolSize - } client, err := newWithWriter(w, o, writerType) if err == nil { client.options = append(client.options, options...) @@ -309,55 +327,92 @@ func newWithWriter(w statsdWriter, o *Options, writerName string) (*Client, erro Tags: o.Tags, metrics: &ClientMetrics{}, } - if o.Aggregation { - c.agg = newAggregator(&c) - c.agg.start(o.AggregationFlushInterval) - } - // Inject values of DD_* environment variables as global tags. - for envName, tagName := range ddEnvTagsMapping { - if value := os.Getenv(envName); value != "" { - c.Tags = append(c.Tags, fmt.Sprintf("%s:%s", tagName, value)) + for _, mapping := range ddEnvTagsMapping { + if value := os.Getenv(mapping.envName); value != "" { + c.Tags = append(c.Tags, fmt.Sprintf("%s:%s", mapping.tagName, value)) } } - c.telemetryTags = append(c.Tags, clientTelemetryTag, clientVersionTelemetryTag, "client_transport:"+writerName) - if o.MaxBytesPerPayload == 0 { - o.MaxBytesPerPayload = OptimalUDPPayloadSize + if writerName == WriterNameUDS { + o.MaxBytesPerPayload = DefaultMaxAgentPayloadSize + } else { + o.MaxBytesPerPayload = OptimalUDPPayloadSize + } } if o.BufferPoolSize == 0 { - o.BufferPoolSize = DefaultUDPBufferPoolSize + if writerName == WriterNameUDS { + o.BufferPoolSize = DefaultUDSBufferPoolSize + } else { + o.BufferPoolSize = DefaultUDPBufferPoolSize + } } if o.SenderQueueSize == 0 { - o.SenderQueueSize = DefaultUDPBufferPoolSize + if writerName == WriterNameUDS { + o.SenderQueueSize = DefaultUDSBufferPoolSize + } else { + o.SenderQueueSize = DefaultUDPBufferPoolSize + } + } + + bufferPool := newBufferPool(o.BufferPoolSize, o.MaxBytesPerPayload, o.MaxMessagesPerPayload) + c.sender = newSender(w, o.SenderQueueSize, bufferPool) + c.aggregatorMode = o.ReceiveMode + + c.workersMode = o.ReceiveMode + // ChannelMode mode at the worker level is not enabled when + // ExtendedAggregation is since the user app will not directly + // use the worker (the aggregator sit between the app and the + // workers). + if o.ExtendedAggregation { + c.workersMode = MutexMode + } + + if o.Aggregation || o.ExtendedAggregation { + c.agg = newAggregator(&c) + c.agg.start(o.AggregationFlushInterval) + + if o.ExtendedAggregation { + c.aggExtended = c.agg + + if c.aggregatorMode == ChannelMode { + c.agg.startReceivingMetric(o.ChannelModeBufferSize, o.BufferShardCount) + } + } } - c.receiveMode = o.ReceiveMode - c.bufferPool = newBufferPool(o.BufferPoolSize, o.MaxBytesPerPayload, o.MaxMessagesPerPayload) - c.buffer = c.bufferPool.borrowBuffer() - c.sender = newSender(w, o.SenderQueueSize, c.bufferPool) for i := 0; i < o.BufferShardCount; i++ { - w := newWorker(c.bufferPool, c.sender) - c.bufferShards = append(c.bufferShards, w) - if c.receiveMode == ChannelMode { - w.startReceivingMetric(o.ChannelModeBufferSize) // TODO make it configurable + w := newWorker(bufferPool, c.sender) + c.workers = append(c.workers, w) + + if c.workersMode == ChannelMode { + w.startReceivingMetric(o.ChannelModeBufferSize) } } + c.flushTime = o.BufferFlushInterval c.stop = make(chan struct{}, 1) + c.wg.Add(1) go func() { defer c.wg.Done() c.watch() }() - c.wg.Add(1) - go func() { - defer c.wg.Done() - if o.Telemetry { - c.telemetry() + + if o.Telemetry { + if o.TelemetryAddr == "" { + c.telemetry = newTelemetryClient(&c, writerName, o.DevMode) + } else { + var err error + c.telemetry, err = newTelemetryClientWithCustomAddr(&c, writerName, o.DevMode, o.TelemetryAddr, bufferPool) + if err != nil { + return nil, err + } } - }() + c.telemetry.run(&c.wg, c.stop) + } + return &c, nil } @@ -370,7 +425,8 @@ func NewBuffered(addr string, buflen int) (*Client, error) { return New(addr, WithMaxMessagesPerPayload(buflen)) } -// SetWriteTimeout allows the user to set a custom UDS write timeout. Not supported for UDP. +// SetWriteTimeout allows the user to set a custom UDS write timeout. Not supported for UDP +// or Windows Pipes. func (c *Client) SetWriteTimeout(d time.Duration) error { if c == nil { return ErrNoClient @@ -384,7 +440,7 @@ func (c *Client) watch() { for { select { case <-ticker.C: - for _, w := range c.bufferShards { + for _, w := range c.workers { w.flush() } case <-c.stop: @@ -394,91 +450,83 @@ func (c *Client) watch() { } } -func (c *Client) telemetry() { - ticker := time.NewTicker(TelemetryInterval) - for { - select { - case <-ticker.C: - for _, m := range c.flushTelemetry() { - c.send(m) - } - case <-c.stop: - ticker.Stop() - return - } - } -} - -// flushTelemetry returns Telemetry metrics to be flushed. It's its own function to ease testing. -func (c *Client) flushTelemetry() []metric { - m := []metric{} - - // same as Count but without global namespace - telemetryCount := func(name string, value int64) { - m = append(m, metric{metricType: count, name: name, ivalue: value, tags: c.telemetryTags, rate: 1}) - } - - clientMetrics := c.FlushTelemetryMetrics() - telemetryCount("datadog.dogstatsd.client.metrics", int64(clientMetrics.TotalMetrics)) - telemetryCount("datadog.dogstatsd.client.events", int64(clientMetrics.TotalEvents)) - telemetryCount("datadog.dogstatsd.client.service_checks", int64(clientMetrics.TotalServiceChecks)) - telemetryCount("datadog.dogstatsd.client.metric_dropped_on_receive", int64(clientMetrics.TotalDroppedOnReceive)) - - senderMetrics := c.sender.flushTelemetryMetrics() - telemetryCount("datadog.dogstatsd.client.packets_sent", int64(senderMetrics.TotalSentPayloads)) - telemetryCount("datadog.dogstatsd.client.bytes_sent", int64(senderMetrics.TotalSentBytes)) - telemetryCount("datadog.dogstatsd.client.packets_dropped", int64(senderMetrics.TotalDroppedPayloads)) - telemetryCount("datadog.dogstatsd.client.bytes_dropped", int64(senderMetrics.TotalDroppedBytes)) - telemetryCount("datadog.dogstatsd.client.packets_dropped_queue", int64(senderMetrics.TotalDroppedPayloadsQueueFull)) - telemetryCount("datadog.dogstatsd.client.bytes_dropped_queue", int64(senderMetrics.TotalDroppedBytesQueueFull)) - telemetryCount("datadog.dogstatsd.client.packets_dropped_writer", int64(senderMetrics.TotalDroppedPayloadsWriter)) - telemetryCount("datadog.dogstatsd.client.bytes_dropped_writer", int64(senderMetrics.TotalDroppedBytesWriter)) - return m -} - -// Flush forces a flush of all the queued dogstatsd payloads -// This method is blocking and will not return until everything is sent -// through the network +// Flush forces a flush of all the queued dogstatsd payloads This method is +// blocking and will not return until everything is sent through the network. +// In MutexMode, this will also block sampling new data to the client while the +// workers and sender are flushed. func (c *Client) Flush() error { if c == nil { return ErrNoClient } - for _, w := range c.bufferShards { - w.flush() + if c.agg != nil { + c.agg.flush() + } + for _, w := range c.workers { + w.pause() + defer w.unpause() + w.flushUnsafe() } + // Now that the worker are pause the sender can flush the queue between + // worker and senders c.sender.flush() return nil } func (c *Client) FlushTelemetryMetrics() ClientMetrics { - return ClientMetrics{ - TotalMetrics: atomic.SwapUint64(&c.metrics.TotalMetrics, 0), - TotalEvents: atomic.SwapUint64(&c.metrics.TotalEvents, 0), - TotalServiceChecks: atomic.SwapUint64(&c.metrics.TotalServiceChecks, 0), - TotalDroppedOnReceive: atomic.SwapUint64(&c.metrics.TotalDroppedOnReceive, 0), + cm := ClientMetrics{ + TotalMetricsGauge: atomic.SwapUint64(&c.metrics.TotalMetricsGauge, 0), + TotalMetricsCount: atomic.SwapUint64(&c.metrics.TotalMetricsCount, 0), + TotalMetricsSet: atomic.SwapUint64(&c.metrics.TotalMetricsSet, 0), + TotalMetricsHistogram: atomic.SwapUint64(&c.metrics.TotalMetricsHistogram, 0), + TotalMetricsDistribution: atomic.SwapUint64(&c.metrics.TotalMetricsDistribution, 0), + TotalMetricsTiming: atomic.SwapUint64(&c.metrics.TotalMetricsTiming, 0), + TotalEvents: atomic.SwapUint64(&c.metrics.TotalEvents, 0), + TotalServiceChecks: atomic.SwapUint64(&c.metrics.TotalServiceChecks, 0), + TotalDroppedOnReceive: atomic.SwapUint64(&c.metrics.TotalDroppedOnReceive, 0), } + + cm.TotalMetrics = cm.TotalMetricsGauge + cm.TotalMetricsCount + + cm.TotalMetricsSet + cm.TotalMetricsHistogram + + cm.TotalMetricsDistribution + cm.TotalMetricsTiming + + return cm } func (c *Client) send(m metric) error { - if c == nil { - return ErrNoClient + h := hashString32(m.name) + worker := c.workers[h%uint32(len(c.workers))] + + if c.workersMode == ChannelMode { + select { + case worker.inputMetrics <- m: + default: + atomic.AddUint64(&c.metrics.TotalDroppedOnReceive, 1) + } + return nil } + return worker.processMetric(m) +} +// sendBlocking is used by the aggregator to inject aggregated metrics. +func (c *Client) sendBlocking(m metric) error { m.globalTags = c.Tags m.namespace = c.Namespace h := hashString32(m.name) - worker := c.bufferShards[h%uint32(len(c.bufferShards))] + worker := c.workers[h%uint32(len(c.workers))] + return worker.processMetric(m) +} - if c.receiveMode == ChannelMode { +func (c *Client) sendToAggregator(mType metricType, name string, value float64, tags []string, rate float64, f bufferedMetricSampleFunc) error { + if c.aggregatorMode == ChannelMode { select { - case worker.inputMetrics <- m: + case c.aggExtended.inputMetrics <- metric{metricType: mType, name: name, fvalue: value, tags: tags, rate: rate}: default: atomic.AddUint64(&c.metrics.TotalDroppedOnReceive, 1) } return nil } - return worker.processMetric(m) + return f(name, value, tags, rate) } // Gauge measures the value of a metric at a particular time. @@ -486,11 +534,11 @@ func (c *Client) Gauge(name string, value float64, tags []string, rate float64) if c == nil { return ErrNoClient } - atomic.AddUint64(&c.metrics.TotalMetrics, 1) + atomic.AddUint64(&c.metrics.TotalMetricsGauge, 1) if c.agg != nil { - return c.agg.gauge(name, value, tags, rate) + return c.agg.gauge(name, value, tags) } - return c.send(metric{metricType: gauge, name: name, fvalue: value, tags: tags, rate: rate}) + return c.send(metric{metricType: gauge, name: name, fvalue: value, tags: tags, rate: rate, globalTags: c.Tags, namespace: c.Namespace}) } // Count tracks how many times something happened per second. @@ -498,11 +546,11 @@ func (c *Client) Count(name string, value int64, tags []string, rate float64) er if c == nil { return ErrNoClient } - atomic.AddUint64(&c.metrics.TotalMetrics, 1) + atomic.AddUint64(&c.metrics.TotalMetricsCount, 1) if c.agg != nil { - return c.agg.count(name, value, tags, rate) + return c.agg.count(name, value, tags) } - return c.send(metric{metricType: count, name: name, ivalue: value, tags: tags, rate: rate}) + return c.send(metric{metricType: count, name: name, ivalue: value, tags: tags, rate: rate, globalTags: c.Tags, namespace: c.Namespace}) } // Histogram tracks the statistical distribution of a set of values on each host. @@ -510,8 +558,11 @@ func (c *Client) Histogram(name string, value float64, tags []string, rate float if c == nil { return ErrNoClient } - atomic.AddUint64(&c.metrics.TotalMetrics, 1) - return c.send(metric{metricType: histogram, name: name, fvalue: value, tags: tags, rate: rate}) + atomic.AddUint64(&c.metrics.TotalMetricsHistogram, 1) + if c.aggExtended != nil { + return c.sendToAggregator(histogram, name, value, tags, rate, c.aggExtended.histogram) + } + return c.send(metric{metricType: histogram, name: name, fvalue: value, tags: tags, rate: rate, globalTags: c.Tags, namespace: c.Namespace}) } // Distribution tracks the statistical distribution of a set of values across your infrastructure. @@ -519,8 +570,11 @@ func (c *Client) Distribution(name string, value float64, tags []string, rate fl if c == nil { return ErrNoClient } - atomic.AddUint64(&c.metrics.TotalMetrics, 1) - return c.send(metric{metricType: distribution, name: name, fvalue: value, tags: tags, rate: rate}) + atomic.AddUint64(&c.metrics.TotalMetricsDistribution, 1) + if c.aggExtended != nil { + return c.sendToAggregator(distribution, name, value, tags, rate, c.aggExtended.distribution) + } + return c.send(metric{metricType: distribution, name: name, fvalue: value, tags: tags, rate: rate, globalTags: c.Tags, namespace: c.Namespace}) } // Decr is just Count of -1 @@ -538,11 +592,11 @@ func (c *Client) Set(name string, value string, tags []string, rate float64) err if c == nil { return ErrNoClient } - atomic.AddUint64(&c.metrics.TotalMetrics, 1) + atomic.AddUint64(&c.metrics.TotalMetricsSet, 1) if c.agg != nil { - return c.agg.set(name, value, tags, rate) + return c.agg.set(name, value, tags) } - return c.send(metric{metricType: set, name: name, svalue: value, tags: tags, rate: rate}) + return c.send(metric{metricType: set, name: name, svalue: value, tags: tags, rate: rate, globalTags: c.Tags, namespace: c.Namespace}) } // Timing sends timing information, it is an alias for TimeInMilliseconds @@ -556,8 +610,11 @@ func (c *Client) TimeInMilliseconds(name string, value float64, tags []string, r if c == nil { return ErrNoClient } - atomic.AddUint64(&c.metrics.TotalMetrics, 1) - return c.send(metric{metricType: timing, name: name, fvalue: value, tags: tags, rate: rate}) + atomic.AddUint64(&c.metrics.TotalMetricsTiming, 1) + if c.aggExtended != nil { + return c.sendToAggregator(timing, name, value, tags, rate, c.aggExtended.timing) + } + return c.send(metric{metricType: timing, name: name, fvalue: value, tags: tags, rate: rate, globalTags: c.Tags, namespace: c.Namespace}) } // Event sends the provided Event. @@ -566,7 +623,7 @@ func (c *Client) Event(e *Event) error { return ErrNoClient } atomic.AddUint64(&c.metrics.TotalEvents, 1) - return c.send(metric{metricType: event, evalue: e, rate: 1}) + return c.send(metric{metricType: event, evalue: e, rate: 1, globalTags: c.Tags, namespace: c.Namespace}) } // SimpleEvent sends an event with the provided title and text. @@ -581,7 +638,7 @@ func (c *Client) ServiceCheck(sc *ServiceCheck) error { return ErrNoClient } atomic.AddUint64(&c.metrics.TotalServiceChecks, 1) - return c.send(metric{metricType: serviceCheck, scvalue: sc, rate: 1}) + return c.send(metric{metricType: serviceCheck, scvalue: sc, rate: 1, globalTags: c.Tags, namespace: c.Namespace}) } // SimpleServiceCheck sends an serviceCheck with the provided name and status. @@ -608,20 +665,23 @@ func (c *Client) Close() error { } close(c.stop) - if c.receiveMode == ChannelMode { - for _, w := range c.bufferShards { + if c.workersMode == ChannelMode { + for _, w := range c.workers { w.stopReceivingMetric() } } - // Wait for the threads to stop - c.wg.Wait() - - // Finally flush any remaining metrics that may have come in at the last moment + // flush the aggregator first if c.agg != nil { + if c.aggExtended != nil && c.aggregatorMode == ChannelMode { + c.agg.stopReceivingMetric() + } c.agg.stop() } - c.Flush() + // Wait for the threads to stop + c.wg.Wait() + + c.Flush() return c.sender.close() } diff --git a/vendor/github.com/DataDog/datadog-go/statsd/telemetry.go b/vendor/github.com/DataDog/datadog-go/statsd/telemetry.go new file mode 100644 index 000000000..eef5c2731 --- /dev/null +++ b/vendor/github.com/DataDog/datadog-go/statsd/telemetry.go @@ -0,0 +1,151 @@ +package statsd + +import ( + "fmt" + "sync" + "time" +) + +/* +TelemetryInterval is the interval at which telemetry will be sent by the client. +*/ +const TelemetryInterval = 10 * time.Second + +/* +clientTelemetryTag is a tag identifying this specific client. +*/ +var clientTelemetryTag = "client:go" + +/* +clientVersionTelemetryTag is a tag identifying this specific client version. +*/ +var clientVersionTelemetryTag = "client_version:4.8.3" + +type telemetryClient struct { + c *Client + tags []string + tagsByType map[metricType][]string + sender *sender + worker *worker + devMode bool +} + +func newTelemetryClient(c *Client, transport string, devMode bool) *telemetryClient { + t := &telemetryClient{ + c: c, + tags: append(c.Tags, clientTelemetryTag, clientVersionTelemetryTag, "client_transport:"+transport), + tagsByType: map[metricType][]string{}, + devMode: devMode, + } + + if devMode { + t.tagsByType[gauge] = append(append([]string{}, t.tags...), "metrics_type:gauge") + t.tagsByType[count] = append(append([]string{}, t.tags...), "metrics_type:count") + t.tagsByType[set] = append(append([]string{}, t.tags...), "metrics_type:set") + t.tagsByType[timing] = append(append([]string{}, t.tags...), "metrics_type:timing") + t.tagsByType[histogram] = append(append([]string{}, t.tags...), "metrics_type:histogram") + t.tagsByType[distribution] = append(append([]string{}, t.tags...), "metrics_type:distribution") + t.tagsByType[timing] = append(append([]string{}, t.tags...), "metrics_type:timing") + } + return t +} + +func newTelemetryClientWithCustomAddr(c *Client, transport string, devMode bool, telemetryAddr string, pool *bufferPool) (*telemetryClient, error) { + telemetryWriter, _, err := createWriter(telemetryAddr) + if err != nil { + return nil, fmt.Errorf("Could not resolve telemetry address: %v", err) + } + + t := newTelemetryClient(c, transport, devMode) + + // Creating a custom sender/worker with 1 worker in mutex mode for the + // telemetry that share the same bufferPool. + // FIXME due to performance pitfall, we're always using UDP defaults + // even for UDS. + t.sender = newSender(telemetryWriter, DefaultUDPBufferPoolSize, pool) + t.worker = newWorker(pool, t.sender) + return t, nil +} + +func (t *telemetryClient) run(wg *sync.WaitGroup, stop chan struct{}) { + wg.Add(1) + go func() { + defer wg.Done() + ticker := time.NewTicker(TelemetryInterval) + for { + select { + case <-ticker.C: + t.sendTelemetry() + case <-stop: + ticker.Stop() + if t.sender != nil { + t.sender.close() + } + return + } + } + }() +} + +func (t *telemetryClient) sendTelemetry() { + for _, m := range t.flush() { + if t.worker != nil { + t.worker.processMetric(m) + } else { + t.c.send(m) + } + } + + if t.worker != nil { + t.worker.flush() + } +} + +// flushTelemetry returns Telemetry metrics to be flushed. It's its own function to ease testing. +func (t *telemetryClient) flush() []metric { + m := []metric{} + + // same as Count but without global namespace + telemetryCount := func(name string, value int64, tags []string) { + m = append(m, metric{metricType: count, name: name, ivalue: value, tags: tags, rate: 1}) + } + + clientMetrics := t.c.FlushTelemetryMetrics() + telemetryCount("datadog.dogstatsd.client.metrics", int64(clientMetrics.TotalMetrics), t.tags) + if t.devMode { + telemetryCount("datadog.dogstatsd.client.metrics_by_type", int64(clientMetrics.TotalMetricsGauge), t.tagsByType[gauge]) + telemetryCount("datadog.dogstatsd.client.metrics_by_type", int64(clientMetrics.TotalMetricsCount), t.tagsByType[count]) + telemetryCount("datadog.dogstatsd.client.metrics_by_type", int64(clientMetrics.TotalMetricsHistogram), t.tagsByType[histogram]) + telemetryCount("datadog.dogstatsd.client.metrics_by_type", int64(clientMetrics.TotalMetricsDistribution), t.tagsByType[distribution]) + telemetryCount("datadog.dogstatsd.client.metrics_by_type", int64(clientMetrics.TotalMetricsSet), t.tagsByType[set]) + telemetryCount("datadog.dogstatsd.client.metrics_by_type", int64(clientMetrics.TotalMetricsTiming), t.tagsByType[timing]) + } + + telemetryCount("datadog.dogstatsd.client.events", int64(clientMetrics.TotalEvents), t.tags) + telemetryCount("datadog.dogstatsd.client.service_checks", int64(clientMetrics.TotalServiceChecks), t.tags) + telemetryCount("datadog.dogstatsd.client.metric_dropped_on_receive", int64(clientMetrics.TotalDroppedOnReceive), t.tags) + + senderMetrics := t.c.sender.flushTelemetryMetrics() + telemetryCount("datadog.dogstatsd.client.packets_sent", int64(senderMetrics.TotalSentPayloads), t.tags) + telemetryCount("datadog.dogstatsd.client.bytes_sent", int64(senderMetrics.TotalSentBytes), t.tags) + telemetryCount("datadog.dogstatsd.client.packets_dropped", int64(senderMetrics.TotalDroppedPayloads), t.tags) + telemetryCount("datadog.dogstatsd.client.bytes_dropped", int64(senderMetrics.TotalDroppedBytes), t.tags) + telemetryCount("datadog.dogstatsd.client.packets_dropped_queue", int64(senderMetrics.TotalDroppedPayloadsQueueFull), t.tags) + telemetryCount("datadog.dogstatsd.client.bytes_dropped_queue", int64(senderMetrics.TotalDroppedBytesQueueFull), t.tags) + telemetryCount("datadog.dogstatsd.client.packets_dropped_writer", int64(senderMetrics.TotalDroppedPayloadsWriter), t.tags) + telemetryCount("datadog.dogstatsd.client.bytes_dropped_writer", int64(senderMetrics.TotalDroppedBytesWriter), t.tags) + + if aggMetrics := t.c.agg.flushTelemetryMetrics(); aggMetrics != nil { + telemetryCount("datadog.dogstatsd.client.aggregated_context", int64(aggMetrics.nbContext), t.tags) + if t.devMode { + telemetryCount("datadog.dogstatsd.client.aggregated_context_by_type", int64(aggMetrics.nbContextGauge), t.tagsByType[gauge]) + telemetryCount("datadog.dogstatsd.client.aggregated_context_by_type", int64(aggMetrics.nbContextSet), t.tagsByType[set]) + telemetryCount("datadog.dogstatsd.client.aggregated_context_by_type", int64(aggMetrics.nbContextCount), t.tagsByType[count]) + telemetryCount("datadog.dogstatsd.client.aggregated_context_by_type", int64(aggMetrics.nbContextHistogram), t.tagsByType[histogram]) + telemetryCount("datadog.dogstatsd.client.aggregated_context_by_type", int64(aggMetrics.nbContextDistribution), t.tagsByType[distribution]) + telemetryCount("datadog.dogstatsd.client.aggregated_context_by_type", int64(aggMetrics.nbContextTiming), t.tagsByType[timing]) + } + } + + return m +} diff --git a/vendor/github.com/DataDog/datadog-go/statsd/udp.go b/vendor/github.com/DataDog/datadog-go/statsd/udp.go index 9ddff421c..8af522c5b 100644 --- a/vendor/github.com/DataDog/datadog-go/statsd/udp.go +++ b/vendor/github.com/DataDog/datadog-go/statsd/udp.go @@ -2,18 +2,10 @@ package statsd import ( "errors" - "fmt" "net" - "os" "time" ) -const ( - autoHostEnvName = "DD_AGENT_HOST" - autoPortEnvName = "DD_DOGSTATSD_PORT" - defaultUDPPort = "8125" -) - // udpWriter is an internal class wrapping around management of UDP connection type udpWriter struct { conn net.Conn @@ -21,13 +13,6 @@ type udpWriter struct { // New returns a pointer to a new udpWriter given an addr in the format "hostname:port". func newUDPWriter(addr string) (*udpWriter, error) { - if addr == "" { - addr = addressFromEnvironment() - } - if addr == "" { - return nil, errors.New("No address passed and autodetection from environment failed") - } - udpAddr, err := net.ResolveUDPAddr("udp", addr) if err != nil { return nil, err @@ -53,21 +38,3 @@ func (w *udpWriter) Write(data []byte) (int, error) { func (w *udpWriter) Close() error { return w.conn.Close() } - -func (w *udpWriter) remoteAddr() net.Addr { - return w.conn.RemoteAddr() -} - -func addressFromEnvironment() string { - autoHost := os.Getenv(autoHostEnvName) - if autoHost == "" { - return "" - } - - autoPort := os.Getenv(autoPortEnvName) - if autoPort == "" { - autoPort = defaultUDPPort - } - - return fmt.Sprintf("%s:%s", autoHost, autoPort) -} diff --git a/vendor/github.com/DataDog/datadog-go/statsd/uds.go b/vendor/github.com/DataDog/datadog-go/statsd/uds.go index 4d65d2317..6c52261bd 100644 --- a/vendor/github.com/DataDog/datadog-go/statsd/uds.go +++ b/vendor/github.com/DataDog/datadog-go/statsd/uds.go @@ -1,3 +1,5 @@ +// +build !windows + package statsd import ( @@ -10,7 +12,7 @@ import ( UDSTimeout holds the default timeout for UDS socket writes, as they can get blocking when the receiving buffer is full. */ -const defaultUDSTimeout = 1 * time.Millisecond +const defaultUDSTimeout = 100 * time.Millisecond // udsWriter is an internal class wrapping around management of UDS connection type udsWriter struct { diff --git a/vendor/github.com/DataDog/datadog-go/statsd/uds_windows.go b/vendor/github.com/DataDog/datadog-go/statsd/uds_windows.go new file mode 100644 index 000000000..9c97dfd4e --- /dev/null +++ b/vendor/github.com/DataDog/datadog-go/statsd/uds_windows.go @@ -0,0 +1,10 @@ +// +build windows + +package statsd + +import "fmt" + +// newUDSWriter is disable on windows as unix sockets are not available +func newUDSWriter(addr string) (statsdWriter, error) { + return nil, fmt.Errorf("unix socket is not available on windows") +} diff --git a/vendor/github.com/DataDog/datadog-go/statsd/utils.go b/vendor/github.com/DataDog/datadog-go/statsd/utils.go new file mode 100644 index 000000000..a2829d94f --- /dev/null +++ b/vendor/github.com/DataDog/datadog-go/statsd/utils.go @@ -0,0 +1,23 @@ +package statsd + +import ( + "math/rand" + "sync" +) + +func shouldSample(rate float64, r *rand.Rand, lock *sync.Mutex) bool { + if rate >= 1 { + return true + } + // sources created by rand.NewSource() (ie. w.random) are not thread safe. + // TODO: use defer once the lowest Go version we support is 1.14 (defer + // has an overhead before that). + lock.Lock() + if r.Float64() > rate { + lock.Unlock() + return false + } + lock.Unlock() + return true + +} diff --git a/vendor/github.com/DataDog/datadog-go/statsd/worker.go b/vendor/github.com/DataDog/datadog-go/statsd/worker.go index 74895dbfb..e5a3bac56 100644 --- a/vendor/github.com/DataDog/datadog-go/statsd/worker.go +++ b/vendor/github.com/DataDog/datadog-go/statsd/worker.go @@ -3,12 +3,15 @@ package statsd import ( "math/rand" "sync" + "time" ) type worker struct { - pool *bufferPool - buffer *statsdBuffer - sender *sender + pool *bufferPool + buffer *statsdBuffer + sender *sender + random *rand.Rand + randomLock sync.Mutex sync.Mutex inputMetrics chan metric @@ -16,10 +19,21 @@ type worker struct { } func newWorker(pool *bufferPool, sender *sender) *worker { + // Each worker uses its own random source and random lock to prevent + // workers in separate goroutines from contending for the lock on the + // "math/rand" package-global random source (e.g. calls like + // "rand.Float64()" must acquire a shared lock to get the next + // pseudorandom number). + // Note that calling "time.Now().UnixNano()" repeatedly quickly may return + // very similar values. That's fine for seeding the worker-specific random + // source because we just need an evenly distributed stream of float values. + // Do not use this random source for cryptographic randomness. + random := rand.New(rand.NewSource(time.Now().UnixNano())) return &worker{ pool: pool, sender: sender, buffer: pool.borrowBuffer(), + random: random, stop: make(chan struct{}), } } @@ -45,7 +59,7 @@ func (w *worker) pullMetric() { } func (w *worker) processMetric(m metric) error { - if !w.shouldSample(m.rate) { + if !shouldSample(m.rate, w.random, &w.randomLock) { return nil } w.Lock() @@ -58,11 +72,29 @@ func (w *worker) processMetric(m metric) error { return err } -func (w *worker) shouldSample(rate float64) bool { - if rate < 1 && rand.Float64() > rate { - return false +func (w *worker) writeAggregatedMetricUnsafe(m metric, metricSymbol []byte, precision int) error { + globalPos := 0 + + // first check how much data we can write to the buffer: + // +3 + len(metricSymbol) because the message will include '||#' before the tags + // +1 for the potential line break at the start of the metric + tagsSize := len(m.stags) + 4 + len(metricSymbol) + for _, t := range m.globalTags { + tagsSize += len(t) + 1 + } + + for { + pos, err := w.buffer.writeAggregated(metricSymbol, m.namespace, m.globalTags, m.name, m.fvalues[globalPos:], m.stags, tagsSize, precision) + if err == errPartialWrite { + // We successfully wrote part of the histogram metrics. + // We flush the current buffer and finish the histogram + // in a new one. + w.flushUnsafe() + globalPos += pos + } else { + return err + } } - return true } func (w *worker) writeMetricUnsafe(m metric) error { @@ -83,6 +115,12 @@ func (w *worker) writeMetricUnsafe(m metric) error { return w.buffer.writeEvent(*m.evalue, m.globalTags) case serviceCheck: return w.buffer.writeServiceCheck(*m.scvalue, m.globalTags) + case histogramAggregated: + return w.writeAggregatedMetricUnsafe(m, histogramSymbol, -1) + case distributionAggregated: + return w.writeAggregatedMetricUnsafe(m, distributionSymbol, -1) + case timingAggregated: + return w.writeAggregatedMetricUnsafe(m, timingSymbol, 6) default: return nil } @@ -94,6 +132,14 @@ func (w *worker) flush() { w.Unlock() } +func (w *worker) pause() { + w.Lock() +} + +func (w *worker) unpause() { + w.Unlock() +} + // flush the current buffer. Lock must be held by caller. // flushed buffer written to the network asynchronously. func (w *worker) flushUnsafe() { diff --git a/vendor/github.com/Microsoft/go-winio/.gitattributes b/vendor/github.com/Microsoft/go-winio/.gitattributes new file mode 100644 index 000000000..94f480de9 --- /dev/null +++ b/vendor/github.com/Microsoft/go-winio/.gitattributes @@ -0,0 +1 @@ +* text=auto eol=lf \ No newline at end of file diff --git a/vendor/github.com/Microsoft/go-winio/.gitignore b/vendor/github.com/Microsoft/go-winio/.gitignore new file mode 100644 index 000000000..815e20660 --- /dev/null +++ b/vendor/github.com/Microsoft/go-winio/.gitignore @@ -0,0 +1,10 @@ +.vscode/ + +*.exe + +# testing +testdata + +# go workspaces +go.work +go.work.sum diff --git a/vendor/github.com/Microsoft/go-winio/.golangci.yml b/vendor/github.com/Microsoft/go-winio/.golangci.yml new file mode 100644 index 000000000..faedfe937 --- /dev/null +++ b/vendor/github.com/Microsoft/go-winio/.golangci.yml @@ -0,0 +1,147 @@ +linters: + enable: + # style + - containedctx # struct contains a context + - dupl # duplicate code + - errname # erorrs are named correctly + - nolintlint # "//nolint" directives are properly explained + - revive # golint replacement + - unconvert # unnecessary conversions + - wastedassign + + # bugs, performance, unused, etc ... + - contextcheck # function uses a non-inherited context + - errorlint # errors not wrapped for 1.13 + - exhaustive # check exhaustiveness of enum switch statements + - gofmt # files are gofmt'ed + - gosec # security + - nilerr # returns nil even with non-nil error + - thelper # test helpers without t.Helper() + - unparam # unused function params + +issues: + exclude-dirs: + - pkg/etw/sample + + exclude-rules: + # err is very often shadowed in nested scopes + - linters: + - govet + text: '^shadow: declaration of "err" shadows declaration' + + # ignore long lines for skip autogen directives + - linters: + - revive + text: "^line-length-limit: " + source: "^//(go:generate|sys) " + + #TODO: remove after upgrading to go1.18 + # ignore comment spacing for nolint and sys directives + - linters: + - revive + text: "^comment-spacings: no space between comment delimiter and comment text" + source: "//(cspell:|nolint:|sys |todo)" + + # not on go 1.18 yet, so no any + - linters: + - revive + text: "^use-any: since GO 1.18 'interface{}' can be replaced by 'any'" + + # allow unjustified ignores of error checks in defer statements + - linters: + - nolintlint + text: "^directive `//nolint:errcheck` should provide explanation" + source: '^\s*defer ' + + # allow unjustified ignores of error lints for io.EOF + - linters: + - nolintlint + text: "^directive `//nolint:errorlint` should provide explanation" + source: '[=|!]= io.EOF' + + +linters-settings: + exhaustive: + default-signifies-exhaustive: true + govet: + enable-all: true + disable: + # struct order is often for Win32 compat + # also, ignore pointer bytes/GC issues for now until performance becomes an issue + - fieldalignment + nolintlint: + require-explanation: true + require-specific: true + revive: + # revive is more configurable than static check, so likely the preferred alternative to static-check + # (once the perf issue is solved: https://github.com/golangci/golangci-lint/issues/2997) + enable-all-rules: + true + # https://github.com/mgechev/revive/blob/master/RULES_DESCRIPTIONS.md + rules: + # rules with required arguments + - name: argument-limit + disabled: true + - name: banned-characters + disabled: true + - name: cognitive-complexity + disabled: true + - name: cyclomatic + disabled: true + - name: file-header + disabled: true + - name: function-length + disabled: true + - name: function-result-limit + disabled: true + - name: max-public-structs + disabled: true + # geneally annoying rules + - name: add-constant # complains about any and all strings and integers + disabled: true + - name: confusing-naming # we frequently use "Foo()" and "foo()" together + disabled: true + - name: flag-parameter # excessive, and a common idiom we use + disabled: true + - name: unhandled-error # warns over common fmt.Print* and io.Close; rely on errcheck instead + disabled: true + # general config + - name: line-length-limit + arguments: + - 140 + - name: var-naming + arguments: + - [] + - - CID + - CRI + - CTRD + - DACL + - DLL + - DOS + - ETW + - FSCTL + - GCS + - GMSA + - HCS + - HV + - IO + - LCOW + - LDAP + - LPAC + - LTSC + - MMIO + - NT + - OCI + - PMEM + - PWSH + - RX + - SACl + - SID + - SMB + - TX + - VHD + - VHDX + - VMID + - VPCI + - WCOW + - WIM diff --git a/vendor/github.com/Microsoft/go-winio/CODEOWNERS b/vendor/github.com/Microsoft/go-winio/CODEOWNERS new file mode 100644 index 000000000..ae1b4942b --- /dev/null +++ b/vendor/github.com/Microsoft/go-winio/CODEOWNERS @@ -0,0 +1 @@ + * @microsoft/containerplat diff --git a/vendor/github.com/Microsoft/go-winio/LICENSE b/vendor/github.com/Microsoft/go-winio/LICENSE new file mode 100644 index 000000000..b8b569d77 --- /dev/null +++ b/vendor/github.com/Microsoft/go-winio/LICENSE @@ -0,0 +1,22 @@ +The MIT License (MIT) + +Copyright (c) 2015 Microsoft + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. + diff --git a/vendor/github.com/Microsoft/go-winio/README.md b/vendor/github.com/Microsoft/go-winio/README.md new file mode 100644 index 000000000..7474b4f0b --- /dev/null +++ b/vendor/github.com/Microsoft/go-winio/README.md @@ -0,0 +1,89 @@ +# go-winio [![Build Status](https://github.com/microsoft/go-winio/actions/workflows/ci.yml/badge.svg)](https://github.com/microsoft/go-winio/actions/workflows/ci.yml) + +This repository contains utilities for efficiently performing Win32 IO operations in +Go. Currently, this is focused on accessing named pipes and other file handles, and +for using named pipes as a net transport. + +This code relies on IO completion ports to avoid blocking IO on system threads, allowing Go +to reuse the thread to schedule another goroutine. This limits support to Windows Vista and +newer operating systems. This is similar to the implementation of network sockets in Go's net +package. + +Please see the LICENSE file for licensing information. + +## Contributing + +This project welcomes contributions and suggestions. +Most contributions require you to agree to a Contributor License Agreement (CLA) declaring that +you have the right to, and actually do, grant us the rights to use your contribution. +For details, visit [Microsoft CLA](https://cla.microsoft.com). + +When you submit a pull request, a CLA-bot will automatically determine whether you need to +provide a CLA and decorate the PR appropriately (e.g., label, comment). +Simply follow the instructions provided by the bot. +You will only need to do this once across all repos using our CLA. + +Additionally, the pull request pipeline requires the following steps to be performed before +mergining. + +### Code Sign-Off + +We require that contributors sign their commits using [`git commit --signoff`][git-commit-s] +to certify they either authored the work themselves or otherwise have permission to use it in this project. + +A range of commits can be signed off using [`git rebase --signoff`][git-rebase-s]. + +Please see [the developer certificate](https://developercertificate.org) for more info, +as well as to make sure that you can attest to the rules listed. +Our CI uses the DCO Github app to ensure that all commits in a given PR are signed-off. + +### Linting + +Code must pass a linting stage, which uses [`golangci-lint`][lint]. +The linting settings are stored in [`.golangci.yaml`](./.golangci.yaml), and can be run +automatically with VSCode by adding the following to your workspace or folder settings: + +```json + "go.lintTool": "golangci-lint", + "go.lintOnSave": "package", +``` + +Additional editor [integrations options are also available][lint-ide]. + +Alternatively, `golangci-lint` can be [installed locally][lint-install] and run from the repo root: + +```shell +# use . or specify a path to only lint a package +# to show all lint errors, use flags "--max-issues-per-linter=0 --max-same-issues=0" +> golangci-lint run ./... +``` + +### Go Generate + +The pipeline checks that auto-generated code, via `go generate`, are up to date. + +This can be done for the entire repo: + +```shell +> go generate ./... +``` + +## Code of Conduct + +This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/). +For more information see the [Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) or +contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with any additional questions or comments. + +## Special Thanks + +Thanks to [natefinch][natefinch] for the inspiration for this library. +See [npipe](https://github.com/natefinch/npipe) for another named pipe implementation. + +[lint]: https://golangci-lint.run/ +[lint-ide]: https://golangci-lint.run/usage/integrations/#editor-integration +[lint-install]: https://golangci-lint.run/usage/install/#local-installation + +[git-commit-s]: https://git-scm.com/docs/git-commit#Documentation/git-commit.txt--s +[git-rebase-s]: https://git-scm.com/docs/git-rebase#Documentation/git-rebase.txt---signoff + +[natefinch]: https://github.com/natefinch diff --git a/vendor/github.com/Microsoft/go-winio/SECURITY.md b/vendor/github.com/Microsoft/go-winio/SECURITY.md new file mode 100644 index 000000000..869fdfe2b --- /dev/null +++ b/vendor/github.com/Microsoft/go-winio/SECURITY.md @@ -0,0 +1,41 @@ + + +## Security + +Microsoft takes the security of our software products and services seriously, which includes all source code repositories managed through our GitHub organizations, which include [Microsoft](https://github.com/Microsoft), [Azure](https://github.com/Azure), [DotNet](https://github.com/dotnet), [AspNet](https://github.com/aspnet), [Xamarin](https://github.com/xamarin), and [our GitHub organizations](https://opensource.microsoft.com/). + +If you believe you have found a security vulnerability in any Microsoft-owned repository that meets [Microsoft's definition of a security vulnerability](https://aka.ms/opensource/security/definition), please report it to us as described below. + +## Reporting Security Issues + +**Please do not report security vulnerabilities through public GitHub issues.** + +Instead, please report them to the Microsoft Security Response Center (MSRC) at [https://msrc.microsoft.com/create-report](https://aka.ms/opensource/security/create-report). + +If you prefer to submit without logging in, send email to [secure@microsoft.com](mailto:secure@microsoft.com). If possible, encrypt your message with our PGP key; please download it from the [Microsoft Security Response Center PGP Key page](https://aka.ms/opensource/security/pgpkey). + +You should receive a response within 24 hours. If for some reason you do not, please follow up via email to ensure we received your original message. Additional information can be found at [microsoft.com/msrc](https://aka.ms/opensource/security/msrc). + +Please include the requested information listed below (as much as you can provide) to help us better understand the nature and scope of the possible issue: + + * Type of issue (e.g. buffer overflow, SQL injection, cross-site scripting, etc.) + * Full paths of source file(s) related to the manifestation of the issue + * The location of the affected source code (tag/branch/commit or direct URL) + * Any special configuration required to reproduce the issue + * Step-by-step instructions to reproduce the issue + * Proof-of-concept or exploit code (if possible) + * Impact of the issue, including how an attacker might exploit the issue + +This information will help us triage your report more quickly. + +If you are reporting for a bug bounty, more complete reports can contribute to a higher bounty award. Please visit our [Microsoft Bug Bounty Program](https://aka.ms/opensource/security/bounty) page for more details about our active programs. + +## Preferred Languages + +We prefer all communications to be in English. + +## Policy + +Microsoft follows the principle of [Coordinated Vulnerability Disclosure](https://aka.ms/opensource/security/cvd). + + diff --git a/vendor/github.com/Microsoft/go-winio/backup.go b/vendor/github.com/Microsoft/go-winio/backup.go new file mode 100644 index 000000000..b54341daa --- /dev/null +++ b/vendor/github.com/Microsoft/go-winio/backup.go @@ -0,0 +1,287 @@ +//go:build windows +// +build windows + +package winio + +import ( + "encoding/binary" + "errors" + "fmt" + "io" + "os" + "runtime" + "unicode/utf16" + + "github.com/Microsoft/go-winio/internal/fs" + "golang.org/x/sys/windows" +) + +//sys backupRead(h windows.Handle, b []byte, bytesRead *uint32, abort bool, processSecurity bool, context *uintptr) (err error) = BackupRead +//sys backupWrite(h windows.Handle, b []byte, bytesWritten *uint32, abort bool, processSecurity bool, context *uintptr) (err error) = BackupWrite + +const ( + BackupData = uint32(iota + 1) + BackupEaData + BackupSecurity + BackupAlternateData + BackupLink + BackupPropertyData + BackupObjectId //revive:disable-line:var-naming ID, not Id + BackupReparseData + BackupSparseBlock + BackupTxfsData +) + +const ( + StreamSparseAttributes = uint32(8) +) + +//nolint:revive // var-naming: ALL_CAPS +const ( + WRITE_DAC = windows.WRITE_DAC + WRITE_OWNER = windows.WRITE_OWNER + ACCESS_SYSTEM_SECURITY = windows.ACCESS_SYSTEM_SECURITY +) + +// BackupHeader represents a backup stream of a file. +type BackupHeader struct { + //revive:disable-next-line:var-naming ID, not Id + Id uint32 // The backup stream ID + Attributes uint32 // Stream attributes + Size int64 // The size of the stream in bytes + Name string // The name of the stream (for BackupAlternateData only). + Offset int64 // The offset of the stream in the file (for BackupSparseBlock only). +} + +type win32StreamID struct { + StreamID uint32 + Attributes uint32 + Size uint64 + NameSize uint32 +} + +// BackupStreamReader reads from a stream produced by the BackupRead Win32 API and produces a series +// of BackupHeader values. +type BackupStreamReader struct { + r io.Reader + bytesLeft int64 +} + +// NewBackupStreamReader produces a BackupStreamReader from any io.Reader. +func NewBackupStreamReader(r io.Reader) *BackupStreamReader { + return &BackupStreamReader{r, 0} +} + +// Next returns the next backup stream and prepares for calls to Read(). It skips the remainder of the current stream if +// it was not completely read. +func (r *BackupStreamReader) Next() (*BackupHeader, error) { + if r.bytesLeft > 0 { //nolint:nestif // todo: flatten this + if s, ok := r.r.(io.Seeker); ok { + // Make sure Seek on io.SeekCurrent sometimes succeeds + // before trying the actual seek. + if _, err := s.Seek(0, io.SeekCurrent); err == nil { + if _, err = s.Seek(r.bytesLeft, io.SeekCurrent); err != nil { + return nil, err + } + r.bytesLeft = 0 + } + } + if _, err := io.Copy(io.Discard, r); err != nil { + return nil, err + } + } + var wsi win32StreamID + if err := binary.Read(r.r, binary.LittleEndian, &wsi); err != nil { + return nil, err + } + hdr := &BackupHeader{ + Id: wsi.StreamID, + Attributes: wsi.Attributes, + Size: int64(wsi.Size), + } + if wsi.NameSize != 0 { + name := make([]uint16, int(wsi.NameSize/2)) + if err := binary.Read(r.r, binary.LittleEndian, name); err != nil { + return nil, err + } + hdr.Name = windows.UTF16ToString(name) + } + if wsi.StreamID == BackupSparseBlock { + if err := binary.Read(r.r, binary.LittleEndian, &hdr.Offset); err != nil { + return nil, err + } + hdr.Size -= 8 + } + r.bytesLeft = hdr.Size + return hdr, nil +} + +// Read reads from the current backup stream. +func (r *BackupStreamReader) Read(b []byte) (int, error) { + if r.bytesLeft == 0 { + return 0, io.EOF + } + if int64(len(b)) > r.bytesLeft { + b = b[:r.bytesLeft] + } + n, err := r.r.Read(b) + r.bytesLeft -= int64(n) + if err == io.EOF { + err = io.ErrUnexpectedEOF + } else if r.bytesLeft == 0 && err == nil { + err = io.EOF + } + return n, err +} + +// BackupStreamWriter writes a stream compatible with the BackupWrite Win32 API. +type BackupStreamWriter struct { + w io.Writer + bytesLeft int64 +} + +// NewBackupStreamWriter produces a BackupStreamWriter on top of an io.Writer. +func NewBackupStreamWriter(w io.Writer) *BackupStreamWriter { + return &BackupStreamWriter{w, 0} +} + +// WriteHeader writes the next backup stream header and prepares for calls to Write(). +func (w *BackupStreamWriter) WriteHeader(hdr *BackupHeader) error { + if w.bytesLeft != 0 { + return fmt.Errorf("missing %d bytes", w.bytesLeft) + } + name := utf16.Encode([]rune(hdr.Name)) + wsi := win32StreamID{ + StreamID: hdr.Id, + Attributes: hdr.Attributes, + Size: uint64(hdr.Size), + NameSize: uint32(len(name) * 2), + } + if hdr.Id == BackupSparseBlock { + // Include space for the int64 block offset + wsi.Size += 8 + } + if err := binary.Write(w.w, binary.LittleEndian, &wsi); err != nil { + return err + } + if len(name) != 0 { + if err := binary.Write(w.w, binary.LittleEndian, name); err != nil { + return err + } + } + if hdr.Id == BackupSparseBlock { + if err := binary.Write(w.w, binary.LittleEndian, hdr.Offset); err != nil { + return err + } + } + w.bytesLeft = hdr.Size + return nil +} + +// Write writes to the current backup stream. +func (w *BackupStreamWriter) Write(b []byte) (int, error) { + if w.bytesLeft < int64(len(b)) { + return 0, fmt.Errorf("too many bytes by %d", int64(len(b))-w.bytesLeft) + } + n, err := w.w.Write(b) + w.bytesLeft -= int64(n) + return n, err +} + +// BackupFileReader provides an io.ReadCloser interface on top of the BackupRead Win32 API. +type BackupFileReader struct { + f *os.File + includeSecurity bool + ctx uintptr +} + +// NewBackupFileReader returns a new BackupFileReader from a file handle. If includeSecurity is true, +// Read will attempt to read the security descriptor of the file. +func NewBackupFileReader(f *os.File, includeSecurity bool) *BackupFileReader { + r := &BackupFileReader{f, includeSecurity, 0} + return r +} + +// Read reads a backup stream from the file by calling the Win32 API BackupRead(). +func (r *BackupFileReader) Read(b []byte) (int, error) { + var bytesRead uint32 + err := backupRead(windows.Handle(r.f.Fd()), b, &bytesRead, false, r.includeSecurity, &r.ctx) + if err != nil { + return 0, &os.PathError{Op: "BackupRead", Path: r.f.Name(), Err: err} + } + runtime.KeepAlive(r.f) + if bytesRead == 0 { + return 0, io.EOF + } + return int(bytesRead), nil +} + +// Close frees Win32 resources associated with the BackupFileReader. It does not close +// the underlying file. +func (r *BackupFileReader) Close() error { + if r.ctx != 0 { + _ = backupRead(windows.Handle(r.f.Fd()), nil, nil, true, false, &r.ctx) + runtime.KeepAlive(r.f) + r.ctx = 0 + } + return nil +} + +// BackupFileWriter provides an io.WriteCloser interface on top of the BackupWrite Win32 API. +type BackupFileWriter struct { + f *os.File + includeSecurity bool + ctx uintptr +} + +// NewBackupFileWriter returns a new BackupFileWriter from a file handle. If includeSecurity is true, +// Write() will attempt to restore the security descriptor from the stream. +func NewBackupFileWriter(f *os.File, includeSecurity bool) *BackupFileWriter { + w := &BackupFileWriter{f, includeSecurity, 0} + return w +} + +// Write restores a portion of the file using the provided backup stream. +func (w *BackupFileWriter) Write(b []byte) (int, error) { + var bytesWritten uint32 + err := backupWrite(windows.Handle(w.f.Fd()), b, &bytesWritten, false, w.includeSecurity, &w.ctx) + if err != nil { + return 0, &os.PathError{Op: "BackupWrite", Path: w.f.Name(), Err: err} + } + runtime.KeepAlive(w.f) + if int(bytesWritten) != len(b) { + return int(bytesWritten), errors.New("not all bytes could be written") + } + return len(b), nil +} + +// Close frees Win32 resources associated with the BackupFileWriter. It does not +// close the underlying file. +func (w *BackupFileWriter) Close() error { + if w.ctx != 0 { + _ = backupWrite(windows.Handle(w.f.Fd()), nil, nil, true, false, &w.ctx) + runtime.KeepAlive(w.f) + w.ctx = 0 + } + return nil +} + +// OpenForBackup opens a file or directory, potentially skipping access checks if the backup +// or restore privileges have been acquired. +// +// If the file opened was a directory, it cannot be used with Readdir(). +func OpenForBackup(path string, access uint32, share uint32, createmode uint32) (*os.File, error) { + h, err := fs.CreateFile(path, + fs.AccessMask(access), + fs.FileShareMode(share), + nil, + fs.FileCreationDisposition(createmode), + fs.FILE_FLAG_BACKUP_SEMANTICS|fs.FILE_FLAG_OPEN_REPARSE_POINT, + 0, + ) + if err != nil { + err = &os.PathError{Op: "open", Path: path, Err: err} + return nil, err + } + return os.NewFile(uintptr(h), path), nil +} diff --git a/vendor/github.com/Microsoft/go-winio/doc.go b/vendor/github.com/Microsoft/go-winio/doc.go new file mode 100644 index 000000000..1f5bfe2d5 --- /dev/null +++ b/vendor/github.com/Microsoft/go-winio/doc.go @@ -0,0 +1,22 @@ +// This package provides utilities for efficiently performing Win32 IO operations in Go. +// Currently, this package is provides support for genreal IO and management of +// - named pipes +// - files +// - [Hyper-V sockets] +// +// This code is similar to Go's [net] package, and uses IO completion ports to avoid +// blocking IO on system threads, allowing Go to reuse the thread to schedule other goroutines. +// +// This limits support to Windows Vista and newer operating systems. +// +// Additionally, this package provides support for: +// - creating and managing GUIDs +// - writing to [ETW] +// - opening and manageing VHDs +// - parsing [Windows Image files] +// - auto-generating Win32 API code +// +// [Hyper-V sockets]: https://docs.microsoft.com/en-us/virtualization/hyper-v-on-windows/user-guide/make-integration-service +// [ETW]: https://docs.microsoft.com/en-us/windows-hardware/drivers/devtest/event-tracing-for-windows--etw- +// [Windows Image files]: https://docs.microsoft.com/en-us/windows-hardware/manufacture/desktop/work-with-windows-images +package winio diff --git a/vendor/github.com/Microsoft/go-winio/ea.go b/vendor/github.com/Microsoft/go-winio/ea.go new file mode 100644 index 000000000..e104dbdfd --- /dev/null +++ b/vendor/github.com/Microsoft/go-winio/ea.go @@ -0,0 +1,137 @@ +package winio + +import ( + "bytes" + "encoding/binary" + "errors" +) + +type fileFullEaInformation struct { + NextEntryOffset uint32 + Flags uint8 + NameLength uint8 + ValueLength uint16 +} + +var ( + fileFullEaInformationSize = binary.Size(&fileFullEaInformation{}) + + errInvalidEaBuffer = errors.New("invalid extended attribute buffer") + errEaNameTooLarge = errors.New("extended attribute name too large") + errEaValueTooLarge = errors.New("extended attribute value too large") +) + +// ExtendedAttribute represents a single Windows EA. +type ExtendedAttribute struct { + Name string + Value []byte + Flags uint8 +} + +func parseEa(b []byte) (ea ExtendedAttribute, nb []byte, err error) { + var info fileFullEaInformation + err = binary.Read(bytes.NewReader(b), binary.LittleEndian, &info) + if err != nil { + err = errInvalidEaBuffer + return ea, nb, err + } + + nameOffset := fileFullEaInformationSize + nameLen := int(info.NameLength) + valueOffset := nameOffset + int(info.NameLength) + 1 + valueLen := int(info.ValueLength) + nextOffset := int(info.NextEntryOffset) + if valueLen+valueOffset > len(b) || nextOffset < 0 || nextOffset > len(b) { + err = errInvalidEaBuffer + return ea, nb, err + } + + ea.Name = string(b[nameOffset : nameOffset+nameLen]) + ea.Value = b[valueOffset : valueOffset+valueLen] + ea.Flags = info.Flags + if info.NextEntryOffset != 0 { + nb = b[info.NextEntryOffset:] + } + return ea, nb, err +} + +// DecodeExtendedAttributes decodes a list of EAs from a FILE_FULL_EA_INFORMATION +// buffer retrieved from BackupRead, ZwQueryEaFile, etc. +func DecodeExtendedAttributes(b []byte) (eas []ExtendedAttribute, err error) { + for len(b) != 0 { + ea, nb, err := parseEa(b) + if err != nil { + return nil, err + } + + eas = append(eas, ea) + b = nb + } + return eas, err +} + +func writeEa(buf *bytes.Buffer, ea *ExtendedAttribute, last bool) error { + if int(uint8(len(ea.Name))) != len(ea.Name) { + return errEaNameTooLarge + } + if int(uint16(len(ea.Value))) != len(ea.Value) { + return errEaValueTooLarge + } + entrySize := uint32(fileFullEaInformationSize + len(ea.Name) + 1 + len(ea.Value)) + withPadding := (entrySize + 3) &^ 3 + nextOffset := uint32(0) + if !last { + nextOffset = withPadding + } + info := fileFullEaInformation{ + NextEntryOffset: nextOffset, + Flags: ea.Flags, + NameLength: uint8(len(ea.Name)), + ValueLength: uint16(len(ea.Value)), + } + + err := binary.Write(buf, binary.LittleEndian, &info) + if err != nil { + return err + } + + _, err = buf.Write([]byte(ea.Name)) + if err != nil { + return err + } + + err = buf.WriteByte(0) + if err != nil { + return err + } + + _, err = buf.Write(ea.Value) + if err != nil { + return err + } + + _, err = buf.Write([]byte{0, 0, 0}[0 : withPadding-entrySize]) + if err != nil { + return err + } + + return nil +} + +// EncodeExtendedAttributes encodes a list of EAs into a FILE_FULL_EA_INFORMATION +// buffer for use with BackupWrite, ZwSetEaFile, etc. +func EncodeExtendedAttributes(eas []ExtendedAttribute) ([]byte, error) { + var buf bytes.Buffer + for i := range eas { + last := false + if i == len(eas)-1 { + last = true + } + + err := writeEa(&buf, &eas[i], last) + if err != nil { + return nil, err + } + } + return buf.Bytes(), nil +} diff --git a/vendor/github.com/Microsoft/go-winio/file.go b/vendor/github.com/Microsoft/go-winio/file.go new file mode 100644 index 000000000..fe82a180d --- /dev/null +++ b/vendor/github.com/Microsoft/go-winio/file.go @@ -0,0 +1,320 @@ +//go:build windows +// +build windows + +package winio + +import ( + "errors" + "io" + "runtime" + "sync" + "sync/atomic" + "syscall" + "time" + + "golang.org/x/sys/windows" +) + +//sys cancelIoEx(file windows.Handle, o *windows.Overlapped) (err error) = CancelIoEx +//sys createIoCompletionPort(file windows.Handle, port windows.Handle, key uintptr, threadCount uint32) (newport windows.Handle, err error) = CreateIoCompletionPort +//sys getQueuedCompletionStatus(port windows.Handle, bytes *uint32, key *uintptr, o **ioOperation, timeout uint32) (err error) = GetQueuedCompletionStatus +//sys setFileCompletionNotificationModes(h windows.Handle, flags uint8) (err error) = SetFileCompletionNotificationModes +//sys wsaGetOverlappedResult(h windows.Handle, o *windows.Overlapped, bytes *uint32, wait bool, flags *uint32) (err error) = ws2_32.WSAGetOverlappedResult + +var ( + ErrFileClosed = errors.New("file has already been closed") + ErrTimeout = &timeoutError{} +) + +type timeoutError struct{} + +func (*timeoutError) Error() string { return "i/o timeout" } +func (*timeoutError) Timeout() bool { return true } +func (*timeoutError) Temporary() bool { return true } + +type timeoutChan chan struct{} + +var ioInitOnce sync.Once +var ioCompletionPort windows.Handle + +// ioResult contains the result of an asynchronous IO operation. +type ioResult struct { + bytes uint32 + err error +} + +// ioOperation represents an outstanding asynchronous Win32 IO. +type ioOperation struct { + o windows.Overlapped + ch chan ioResult +} + +func initIO() { + h, err := createIoCompletionPort(windows.InvalidHandle, 0, 0, 0xffffffff) + if err != nil { + panic(err) + } + ioCompletionPort = h + go ioCompletionProcessor(h) +} + +// win32File implements Reader, Writer, and Closer on a Win32 handle without blocking in a syscall. +// It takes ownership of this handle and will close it if it is garbage collected. +type win32File struct { + handle windows.Handle + wg sync.WaitGroup + wgLock sync.RWMutex + closing atomic.Bool + socket bool + readDeadline deadlineHandler + writeDeadline deadlineHandler +} + +type deadlineHandler struct { + setLock sync.Mutex + channel timeoutChan + channelLock sync.RWMutex + timer *time.Timer + timedout atomic.Bool +} + +// makeWin32File makes a new win32File from an existing file handle. +func makeWin32File(h windows.Handle) (*win32File, error) { + f := &win32File{handle: h} + ioInitOnce.Do(initIO) + _, err := createIoCompletionPort(h, ioCompletionPort, 0, 0xffffffff) + if err != nil { + return nil, err + } + err = setFileCompletionNotificationModes(h, windows.FILE_SKIP_COMPLETION_PORT_ON_SUCCESS|windows.FILE_SKIP_SET_EVENT_ON_HANDLE) + if err != nil { + return nil, err + } + f.readDeadline.channel = make(timeoutChan) + f.writeDeadline.channel = make(timeoutChan) + return f, nil +} + +// Deprecated: use NewOpenFile instead. +func MakeOpenFile(h syscall.Handle) (io.ReadWriteCloser, error) { + return NewOpenFile(windows.Handle(h)) +} + +func NewOpenFile(h windows.Handle) (io.ReadWriteCloser, error) { + // If we return the result of makeWin32File directly, it can result in an + // interface-wrapped nil, rather than a nil interface value. + f, err := makeWin32File(h) + if err != nil { + return nil, err + } + return f, nil +} + +// closeHandle closes the resources associated with a Win32 handle. +func (f *win32File) closeHandle() { + f.wgLock.Lock() + // Atomically set that we are closing, releasing the resources only once. + if !f.closing.Swap(true) { + f.wgLock.Unlock() + // cancel all IO and wait for it to complete + _ = cancelIoEx(f.handle, nil) + f.wg.Wait() + // at this point, no new IO can start + windows.Close(f.handle) + f.handle = 0 + } else { + f.wgLock.Unlock() + } +} + +// Close closes a win32File. +func (f *win32File) Close() error { + f.closeHandle() + return nil +} + +// IsClosed checks if the file has been closed. +func (f *win32File) IsClosed() bool { + return f.closing.Load() +} + +// prepareIO prepares for a new IO operation. +// The caller must call f.wg.Done() when the IO is finished, prior to Close() returning. +func (f *win32File) prepareIO() (*ioOperation, error) { + f.wgLock.RLock() + if f.closing.Load() { + f.wgLock.RUnlock() + return nil, ErrFileClosed + } + f.wg.Add(1) + f.wgLock.RUnlock() + c := &ioOperation{} + c.ch = make(chan ioResult) + return c, nil +} + +// ioCompletionProcessor processes completed async IOs forever. +func ioCompletionProcessor(h windows.Handle) { + for { + var bytes uint32 + var key uintptr + var op *ioOperation + err := getQueuedCompletionStatus(h, &bytes, &key, &op, windows.INFINITE) + if op == nil { + panic(err) + } + op.ch <- ioResult{bytes, err} + } +} + +// todo: helsaawy - create an asyncIO version that takes a context + +// asyncIO processes the return value from ReadFile or WriteFile, blocking until +// the operation has actually completed. +func (f *win32File) asyncIO(c *ioOperation, d *deadlineHandler, bytes uint32, err error) (int, error) { + if err != windows.ERROR_IO_PENDING { //nolint:errorlint // err is Errno + return int(bytes), err + } + + if f.closing.Load() { + _ = cancelIoEx(f.handle, &c.o) + } + + var timeout timeoutChan + if d != nil { + d.channelLock.Lock() + timeout = d.channel + d.channelLock.Unlock() + } + + var r ioResult + select { + case r = <-c.ch: + err = r.err + if err == windows.ERROR_OPERATION_ABORTED { //nolint:errorlint // err is Errno + if f.closing.Load() { + err = ErrFileClosed + } + } else if err != nil && f.socket { + // err is from Win32. Query the overlapped structure to get the winsock error. + var bytes, flags uint32 + err = wsaGetOverlappedResult(f.handle, &c.o, &bytes, false, &flags) + } + case <-timeout: + _ = cancelIoEx(f.handle, &c.o) + r = <-c.ch + err = r.err + if err == windows.ERROR_OPERATION_ABORTED { //nolint:errorlint // err is Errno + err = ErrTimeout + } + } + + // runtime.KeepAlive is needed, as c is passed via native + // code to ioCompletionProcessor, c must remain alive + // until the channel read is complete. + // todo: (de)allocate *ioOperation via win32 heap functions, instead of needing to KeepAlive? + runtime.KeepAlive(c) + return int(r.bytes), err +} + +// Read reads from a file handle. +func (f *win32File) Read(b []byte) (int, error) { + c, err := f.prepareIO() + if err != nil { + return 0, err + } + defer f.wg.Done() + + if f.readDeadline.timedout.Load() { + return 0, ErrTimeout + } + + var bytes uint32 + err = windows.ReadFile(f.handle, b, &bytes, &c.o) + n, err := f.asyncIO(c, &f.readDeadline, bytes, err) + runtime.KeepAlive(b) + + // Handle EOF conditions. + if err == nil && n == 0 && len(b) != 0 { + return 0, io.EOF + } else if err == windows.ERROR_BROKEN_PIPE { //nolint:errorlint // err is Errno + return 0, io.EOF + } + return n, err +} + +// Write writes to a file handle. +func (f *win32File) Write(b []byte) (int, error) { + c, err := f.prepareIO() + if err != nil { + return 0, err + } + defer f.wg.Done() + + if f.writeDeadline.timedout.Load() { + return 0, ErrTimeout + } + + var bytes uint32 + err = windows.WriteFile(f.handle, b, &bytes, &c.o) + n, err := f.asyncIO(c, &f.writeDeadline, bytes, err) + runtime.KeepAlive(b) + return n, err +} + +func (f *win32File) SetReadDeadline(deadline time.Time) error { + return f.readDeadline.set(deadline) +} + +func (f *win32File) SetWriteDeadline(deadline time.Time) error { + return f.writeDeadline.set(deadline) +} + +func (f *win32File) Flush() error { + return windows.FlushFileBuffers(f.handle) +} + +func (f *win32File) Fd() uintptr { + return uintptr(f.handle) +} + +func (d *deadlineHandler) set(deadline time.Time) error { + d.setLock.Lock() + defer d.setLock.Unlock() + + if d.timer != nil { + if !d.timer.Stop() { + <-d.channel + } + d.timer = nil + } + d.timedout.Store(false) + + select { + case <-d.channel: + d.channelLock.Lock() + d.channel = make(chan struct{}) + d.channelLock.Unlock() + default: + } + + if deadline.IsZero() { + return nil + } + + timeoutIO := func() { + d.timedout.Store(true) + close(d.channel) + } + + now := time.Now() + duration := deadline.Sub(now) + if deadline.After(now) { + // Deadline is in the future, set a timer to wait + d.timer = time.AfterFunc(duration, timeoutIO) + } else { + // Deadline is in the past. Cancel all pending IO now. + timeoutIO() + } + return nil +} diff --git a/vendor/github.com/Microsoft/go-winio/fileinfo.go b/vendor/github.com/Microsoft/go-winio/fileinfo.go new file mode 100644 index 000000000..c860eb991 --- /dev/null +++ b/vendor/github.com/Microsoft/go-winio/fileinfo.go @@ -0,0 +1,106 @@ +//go:build windows +// +build windows + +package winio + +import ( + "os" + "runtime" + "unsafe" + + "golang.org/x/sys/windows" +) + +// FileBasicInfo contains file access time and file attributes information. +type FileBasicInfo struct { + CreationTime, LastAccessTime, LastWriteTime, ChangeTime windows.Filetime + FileAttributes uint32 + _ uint32 // padding +} + +// alignedFileBasicInfo is a FileBasicInfo, but aligned to uint64 by containing +// uint64 rather than windows.Filetime. Filetime contains two uint32s. uint64 +// alignment is necessary to pass this as FILE_BASIC_INFO. +type alignedFileBasicInfo struct { + CreationTime, LastAccessTime, LastWriteTime, ChangeTime uint64 + FileAttributes uint32 + _ uint32 // padding +} + +// GetFileBasicInfo retrieves times and attributes for a file. +func GetFileBasicInfo(f *os.File) (*FileBasicInfo, error) { + bi := &alignedFileBasicInfo{} + if err := windows.GetFileInformationByHandleEx( + windows.Handle(f.Fd()), + windows.FileBasicInfo, + (*byte)(unsafe.Pointer(bi)), + uint32(unsafe.Sizeof(*bi)), + ); err != nil { + return nil, &os.PathError{Op: "GetFileInformationByHandleEx", Path: f.Name(), Err: err} + } + runtime.KeepAlive(f) + // Reinterpret the alignedFileBasicInfo as a FileBasicInfo so it matches the + // public API of this module. The data may be unnecessarily aligned. + return (*FileBasicInfo)(unsafe.Pointer(bi)), nil +} + +// SetFileBasicInfo sets times and attributes for a file. +func SetFileBasicInfo(f *os.File, bi *FileBasicInfo) error { + // Create an alignedFileBasicInfo based on a FileBasicInfo. The copy is + // suitable to pass to GetFileInformationByHandleEx. + biAligned := *(*alignedFileBasicInfo)(unsafe.Pointer(bi)) + if err := windows.SetFileInformationByHandle( + windows.Handle(f.Fd()), + windows.FileBasicInfo, + (*byte)(unsafe.Pointer(&biAligned)), + uint32(unsafe.Sizeof(biAligned)), + ); err != nil { + return &os.PathError{Op: "SetFileInformationByHandle", Path: f.Name(), Err: err} + } + runtime.KeepAlive(f) + return nil +} + +// FileStandardInfo contains extended information for the file. +// FILE_STANDARD_INFO in WinBase.h +// https://docs.microsoft.com/en-us/windows/win32/api/winbase/ns-winbase-file_standard_info +type FileStandardInfo struct { + AllocationSize, EndOfFile int64 + NumberOfLinks uint32 + DeletePending, Directory bool +} + +// GetFileStandardInfo retrieves ended information for the file. +func GetFileStandardInfo(f *os.File) (*FileStandardInfo, error) { + si := &FileStandardInfo{} + if err := windows.GetFileInformationByHandleEx(windows.Handle(f.Fd()), + windows.FileStandardInfo, + (*byte)(unsafe.Pointer(si)), + uint32(unsafe.Sizeof(*si))); err != nil { + return nil, &os.PathError{Op: "GetFileInformationByHandleEx", Path: f.Name(), Err: err} + } + runtime.KeepAlive(f) + return si, nil +} + +// FileIDInfo contains the volume serial number and file ID for a file. This pair should be +// unique on a system. +type FileIDInfo struct { + VolumeSerialNumber uint64 + FileID [16]byte +} + +// GetFileID retrieves the unique (volume, file ID) pair for a file. +func GetFileID(f *os.File) (*FileIDInfo, error) { + fileID := &FileIDInfo{} + if err := windows.GetFileInformationByHandleEx( + windows.Handle(f.Fd()), + windows.FileIdInfo, + (*byte)(unsafe.Pointer(fileID)), + uint32(unsafe.Sizeof(*fileID)), + ); err != nil { + return nil, &os.PathError{Op: "GetFileInformationByHandleEx", Path: f.Name(), Err: err} + } + runtime.KeepAlive(f) + return fileID, nil +} diff --git a/vendor/github.com/Microsoft/go-winio/hvsock.go b/vendor/github.com/Microsoft/go-winio/hvsock.go new file mode 100644 index 000000000..c4fdd9d4a --- /dev/null +++ b/vendor/github.com/Microsoft/go-winio/hvsock.go @@ -0,0 +1,582 @@ +//go:build windows +// +build windows + +package winio + +import ( + "context" + "errors" + "fmt" + "io" + "net" + "os" + "time" + "unsafe" + + "golang.org/x/sys/windows" + + "github.com/Microsoft/go-winio/internal/socket" + "github.com/Microsoft/go-winio/pkg/guid" +) + +const afHVSock = 34 // AF_HYPERV + +// Well known Service and VM IDs +// https://docs.microsoft.com/en-us/virtualization/hyper-v-on-windows/user-guide/make-integration-service#vmid-wildcards + +// HvsockGUIDWildcard is the wildcard VmId for accepting connections from all partitions. +func HvsockGUIDWildcard() guid.GUID { // 00000000-0000-0000-0000-000000000000 + return guid.GUID{} +} + +// HvsockGUIDBroadcast is the wildcard VmId for broadcasting sends to all partitions. +func HvsockGUIDBroadcast() guid.GUID { // ffffffff-ffff-ffff-ffff-ffffffffffff + return guid.GUID{ + Data1: 0xffffffff, + Data2: 0xffff, + Data3: 0xffff, + Data4: [8]uint8{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff}, + } +} + +// HvsockGUIDLoopback is the Loopback VmId for accepting connections to the same partition as the connector. +func HvsockGUIDLoopback() guid.GUID { // e0e16197-dd56-4a10-9195-5ee7a155a838 + return guid.GUID{ + Data1: 0xe0e16197, + Data2: 0xdd56, + Data3: 0x4a10, + Data4: [8]uint8{0x91, 0x95, 0x5e, 0xe7, 0xa1, 0x55, 0xa8, 0x38}, + } +} + +// HvsockGUIDSiloHost is the address of a silo's host partition: +// - The silo host of a hosted silo is the utility VM. +// - The silo host of a silo on a physical host is the physical host. +func HvsockGUIDSiloHost() guid.GUID { // 36bd0c5c-7276-4223-88ba-7d03b654c568 + return guid.GUID{ + Data1: 0x36bd0c5c, + Data2: 0x7276, + Data3: 0x4223, + Data4: [8]byte{0x88, 0xba, 0x7d, 0x03, 0xb6, 0x54, 0xc5, 0x68}, + } +} + +// HvsockGUIDChildren is the wildcard VmId for accepting connections from the connector's child partitions. +func HvsockGUIDChildren() guid.GUID { // 90db8b89-0d35-4f79-8ce9-49ea0ac8b7cd + return guid.GUID{ + Data1: 0x90db8b89, + Data2: 0xd35, + Data3: 0x4f79, + Data4: [8]uint8{0x8c, 0xe9, 0x49, 0xea, 0xa, 0xc8, 0xb7, 0xcd}, + } +} + +// HvsockGUIDParent is the wildcard VmId for accepting connections from the connector's parent partition. +// Listening on this VmId accepts connection from: +// - Inside silos: silo host partition. +// - Inside hosted silo: host of the VM. +// - Inside VM: VM host. +// - Physical host: Not supported. +func HvsockGUIDParent() guid.GUID { // a42e7cda-d03f-480c-9cc2-a4de20abb878 + return guid.GUID{ + Data1: 0xa42e7cda, + Data2: 0xd03f, + Data3: 0x480c, + Data4: [8]uint8{0x9c, 0xc2, 0xa4, 0xde, 0x20, 0xab, 0xb8, 0x78}, + } +} + +// hvsockVsockServiceTemplate is the Service GUID used for the VSOCK protocol. +func hvsockVsockServiceTemplate() guid.GUID { // 00000000-facb-11e6-bd58-64006a7986d3 + return guid.GUID{ + Data2: 0xfacb, + Data3: 0x11e6, + Data4: [8]uint8{0xbd, 0x58, 0x64, 0x00, 0x6a, 0x79, 0x86, 0xd3}, + } +} + +// An HvsockAddr is an address for a AF_HYPERV socket. +type HvsockAddr struct { + VMID guid.GUID + ServiceID guid.GUID +} + +type rawHvsockAddr struct { + Family uint16 + _ uint16 + VMID guid.GUID + ServiceID guid.GUID +} + +var _ socket.RawSockaddr = &rawHvsockAddr{} + +// Network returns the address's network name, "hvsock". +func (*HvsockAddr) Network() string { + return "hvsock" +} + +func (addr *HvsockAddr) String() string { + return fmt.Sprintf("%s:%s", &addr.VMID, &addr.ServiceID) +} + +// VsockServiceID returns an hvsock service ID corresponding to the specified AF_VSOCK port. +func VsockServiceID(port uint32) guid.GUID { + g := hvsockVsockServiceTemplate() // make a copy + g.Data1 = port + return g +} + +func (addr *HvsockAddr) raw() rawHvsockAddr { + return rawHvsockAddr{ + Family: afHVSock, + VMID: addr.VMID, + ServiceID: addr.ServiceID, + } +} + +func (addr *HvsockAddr) fromRaw(raw *rawHvsockAddr) { + addr.VMID = raw.VMID + addr.ServiceID = raw.ServiceID +} + +// Sockaddr returns a pointer to and the size of this struct. +// +// Implements the [socket.RawSockaddr] interface, and allows use in +// [socket.Bind] and [socket.ConnectEx]. +func (r *rawHvsockAddr) Sockaddr() (unsafe.Pointer, int32, error) { + return unsafe.Pointer(r), int32(unsafe.Sizeof(rawHvsockAddr{})), nil +} + +// Sockaddr interface allows use with `sockets.Bind()` and `.ConnectEx()`. +func (r *rawHvsockAddr) FromBytes(b []byte) error { + n := int(unsafe.Sizeof(rawHvsockAddr{})) + + if len(b) < n { + return fmt.Errorf("got %d, want %d: %w", len(b), n, socket.ErrBufferSize) + } + + copy(unsafe.Slice((*byte)(unsafe.Pointer(r)), n), b[:n]) + if r.Family != afHVSock { + return fmt.Errorf("got %d, want %d: %w", r.Family, afHVSock, socket.ErrAddrFamily) + } + + return nil +} + +// HvsockListener is a socket listener for the AF_HYPERV address family. +type HvsockListener struct { + sock *win32File + addr HvsockAddr +} + +var _ net.Listener = &HvsockListener{} + +// HvsockConn is a connected socket of the AF_HYPERV address family. +type HvsockConn struct { + sock *win32File + local, remote HvsockAddr +} + +var _ net.Conn = &HvsockConn{} + +func newHVSocket() (*win32File, error) { + fd, err := windows.Socket(afHVSock, windows.SOCK_STREAM, 1) + if err != nil { + return nil, os.NewSyscallError("socket", err) + } + f, err := makeWin32File(fd) + if err != nil { + windows.Close(fd) + return nil, err + } + f.socket = true + return f, nil +} + +// ListenHvsock listens for connections on the specified hvsock address. +func ListenHvsock(addr *HvsockAddr) (_ *HvsockListener, err error) { + l := &HvsockListener{addr: *addr} + + var sock *win32File + sock, err = newHVSocket() + if err != nil { + return nil, l.opErr("listen", err) + } + defer func() { + if err != nil { + _ = sock.Close() + } + }() + + sa := addr.raw() + err = socket.Bind(sock.handle, &sa) + if err != nil { + return nil, l.opErr("listen", os.NewSyscallError("socket", err)) + } + err = windows.Listen(sock.handle, 16) + if err != nil { + return nil, l.opErr("listen", os.NewSyscallError("listen", err)) + } + return &HvsockListener{sock: sock, addr: *addr}, nil +} + +func (l *HvsockListener) opErr(op string, err error) error { + return &net.OpError{Op: op, Net: "hvsock", Addr: &l.addr, Err: err} +} + +// Addr returns the listener's network address. +func (l *HvsockListener) Addr() net.Addr { + return &l.addr +} + +// Accept waits for the next connection and returns it. +func (l *HvsockListener) Accept() (_ net.Conn, err error) { + sock, err := newHVSocket() + if err != nil { + return nil, l.opErr("accept", err) + } + defer func() { + if sock != nil { + sock.Close() + } + }() + c, err := l.sock.prepareIO() + if err != nil { + return nil, l.opErr("accept", err) + } + defer l.sock.wg.Done() + + // AcceptEx, per documentation, requires an extra 16 bytes per address. + // + // https://docs.microsoft.com/en-us/windows/win32/api/mswsock/nf-mswsock-acceptex + const addrlen = uint32(16 + unsafe.Sizeof(rawHvsockAddr{})) + var addrbuf [addrlen * 2]byte + + var bytes uint32 + err = windows.AcceptEx(l.sock.handle, sock.handle, &addrbuf[0], 0 /* rxdatalen */, addrlen, addrlen, &bytes, &c.o) + if _, err = l.sock.asyncIO(c, nil, bytes, err); err != nil { + return nil, l.opErr("accept", os.NewSyscallError("acceptex", err)) + } + + conn := &HvsockConn{ + sock: sock, + } + // The local address returned in the AcceptEx buffer is the same as the Listener socket's + // address. However, the service GUID reported by GetSockName is different from the Listeners + // socket, and is sometimes the same as the local address of the socket that dialed the + // address, with the service GUID.Data1 incremented, but othertimes is different. + // todo: does the local address matter? is the listener's address or the actual address appropriate? + conn.local.fromRaw((*rawHvsockAddr)(unsafe.Pointer(&addrbuf[0]))) + conn.remote.fromRaw((*rawHvsockAddr)(unsafe.Pointer(&addrbuf[addrlen]))) + + // initialize the accepted socket and update its properties with those of the listening socket + if err = windows.Setsockopt(sock.handle, + windows.SOL_SOCKET, windows.SO_UPDATE_ACCEPT_CONTEXT, + (*byte)(unsafe.Pointer(&l.sock.handle)), int32(unsafe.Sizeof(l.sock.handle))); err != nil { + return nil, conn.opErr("accept", os.NewSyscallError("setsockopt", err)) + } + + sock = nil + return conn, nil +} + +// Close closes the listener, causing any pending Accept calls to fail. +func (l *HvsockListener) Close() error { + return l.sock.Close() +} + +// HvsockDialer configures and dials a Hyper-V Socket (ie, [HvsockConn]). +type HvsockDialer struct { + // Deadline is the time the Dial operation must connect before erroring. + Deadline time.Time + + // Retries is the number of additional connects to try if the connection times out, is refused, + // or the host is unreachable + Retries uint + + // RetryWait is the time to wait after a connection error to retry + RetryWait time.Duration + + rt *time.Timer // redial wait timer +} + +// Dial the Hyper-V socket at addr. +// +// See [HvsockDialer.Dial] for more information. +func Dial(ctx context.Context, addr *HvsockAddr) (conn *HvsockConn, err error) { + return (&HvsockDialer{}).Dial(ctx, addr) +} + +// Dial attempts to connect to the Hyper-V socket at addr, and returns a connection if successful. +// Will attempt (HvsockDialer).Retries if dialing fails, waiting (HvsockDialer).RetryWait between +// retries. +// +// Dialing can be cancelled either by providing (HvsockDialer).Deadline, or cancelling ctx. +func (d *HvsockDialer) Dial(ctx context.Context, addr *HvsockAddr) (conn *HvsockConn, err error) { + op := "dial" + // create the conn early to use opErr() + conn = &HvsockConn{ + remote: *addr, + } + + if !d.Deadline.IsZero() { + var cancel context.CancelFunc + ctx, cancel = context.WithDeadline(ctx, d.Deadline) + defer cancel() + } + + // preemptive timeout/cancellation check + if err = ctx.Err(); err != nil { + return nil, conn.opErr(op, err) + } + + sock, err := newHVSocket() + if err != nil { + return nil, conn.opErr(op, err) + } + defer func() { + if sock != nil { + sock.Close() + } + }() + + sa := addr.raw() + err = socket.Bind(sock.handle, &sa) + if err != nil { + return nil, conn.opErr(op, os.NewSyscallError("bind", err)) + } + + c, err := sock.prepareIO() + if err != nil { + return nil, conn.opErr(op, err) + } + defer sock.wg.Done() + var bytes uint32 + for i := uint(0); i <= d.Retries; i++ { + err = socket.ConnectEx( + sock.handle, + &sa, + nil, // sendBuf + 0, // sendDataLen + &bytes, + (*windows.Overlapped)(unsafe.Pointer(&c.o))) + _, err = sock.asyncIO(c, nil, bytes, err) + if i < d.Retries && canRedial(err) { + if err = d.redialWait(ctx); err == nil { + continue + } + } + break + } + if err != nil { + return nil, conn.opErr(op, os.NewSyscallError("connectex", err)) + } + + // update the connection properties, so shutdown can be used + if err = windows.Setsockopt( + sock.handle, + windows.SOL_SOCKET, + windows.SO_UPDATE_CONNECT_CONTEXT, + nil, // optvalue + 0, // optlen + ); err != nil { + return nil, conn.opErr(op, os.NewSyscallError("setsockopt", err)) + } + + // get the local name + var sal rawHvsockAddr + err = socket.GetSockName(sock.handle, &sal) + if err != nil { + return nil, conn.opErr(op, os.NewSyscallError("getsockname", err)) + } + conn.local.fromRaw(&sal) + + // one last check for timeout, since asyncIO doesn't check the context + if err = ctx.Err(); err != nil { + return nil, conn.opErr(op, err) + } + + conn.sock = sock + sock = nil + + return conn, nil +} + +// redialWait waits before attempting to redial, resetting the timer as appropriate. +func (d *HvsockDialer) redialWait(ctx context.Context) (err error) { + if d.RetryWait == 0 { + return nil + } + + if d.rt == nil { + d.rt = time.NewTimer(d.RetryWait) + } else { + // should already be stopped and drained + d.rt.Reset(d.RetryWait) + } + + select { + case <-ctx.Done(): + case <-d.rt.C: + return nil + } + + // stop and drain the timer + if !d.rt.Stop() { + <-d.rt.C + } + return ctx.Err() +} + +// assumes error is a plain, unwrapped windows.Errno provided by direct syscall. +func canRedial(err error) bool { + //nolint:errorlint // guaranteed to be an Errno + switch err { + case windows.WSAECONNREFUSED, windows.WSAENETUNREACH, windows.WSAETIMEDOUT, + windows.ERROR_CONNECTION_REFUSED, windows.ERROR_CONNECTION_UNAVAIL: + return true + default: + return false + } +} + +func (conn *HvsockConn) opErr(op string, err error) error { + // translate from "file closed" to "socket closed" + if errors.Is(err, ErrFileClosed) { + err = socket.ErrSocketClosed + } + return &net.OpError{Op: op, Net: "hvsock", Source: &conn.local, Addr: &conn.remote, Err: err} +} + +func (conn *HvsockConn) Read(b []byte) (int, error) { + c, err := conn.sock.prepareIO() + if err != nil { + return 0, conn.opErr("read", err) + } + defer conn.sock.wg.Done() + buf := windows.WSABuf{Buf: &b[0], Len: uint32(len(b))} + var flags, bytes uint32 + err = windows.WSARecv(conn.sock.handle, &buf, 1, &bytes, &flags, &c.o, nil) + n, err := conn.sock.asyncIO(c, &conn.sock.readDeadline, bytes, err) + if err != nil { + var eno windows.Errno + if errors.As(err, &eno) { + err = os.NewSyscallError("wsarecv", eno) + } + return 0, conn.opErr("read", err) + } else if n == 0 { + err = io.EOF + } + return n, err +} + +func (conn *HvsockConn) Write(b []byte) (int, error) { + t := 0 + for len(b) != 0 { + n, err := conn.write(b) + if err != nil { + return t + n, err + } + t += n + b = b[n:] + } + return t, nil +} + +func (conn *HvsockConn) write(b []byte) (int, error) { + c, err := conn.sock.prepareIO() + if err != nil { + return 0, conn.opErr("write", err) + } + defer conn.sock.wg.Done() + buf := windows.WSABuf{Buf: &b[0], Len: uint32(len(b))} + var bytes uint32 + err = windows.WSASend(conn.sock.handle, &buf, 1, &bytes, 0, &c.o, nil) + n, err := conn.sock.asyncIO(c, &conn.sock.writeDeadline, bytes, err) + if err != nil { + var eno windows.Errno + if errors.As(err, &eno) { + err = os.NewSyscallError("wsasend", eno) + } + return 0, conn.opErr("write", err) + } + return n, err +} + +// Close closes the socket connection, failing any pending read or write calls. +func (conn *HvsockConn) Close() error { + return conn.sock.Close() +} + +func (conn *HvsockConn) IsClosed() bool { + return conn.sock.IsClosed() +} + +// shutdown disables sending or receiving on a socket. +func (conn *HvsockConn) shutdown(how int) error { + if conn.IsClosed() { + return socket.ErrSocketClosed + } + + err := windows.Shutdown(conn.sock.handle, how) + if err != nil { + // If the connection was closed, shutdowns fail with "not connected" + if errors.Is(err, windows.WSAENOTCONN) || + errors.Is(err, windows.WSAESHUTDOWN) { + err = socket.ErrSocketClosed + } + return os.NewSyscallError("shutdown", err) + } + return nil +} + +// CloseRead shuts down the read end of the socket, preventing future read operations. +func (conn *HvsockConn) CloseRead() error { + err := conn.shutdown(windows.SHUT_RD) + if err != nil { + return conn.opErr("closeread", err) + } + return nil +} + +// CloseWrite shuts down the write end of the socket, preventing future write operations and +// notifying the other endpoint that no more data will be written. +func (conn *HvsockConn) CloseWrite() error { + err := conn.shutdown(windows.SHUT_WR) + if err != nil { + return conn.opErr("closewrite", err) + } + return nil +} + +// LocalAddr returns the local address of the connection. +func (conn *HvsockConn) LocalAddr() net.Addr { + return &conn.local +} + +// RemoteAddr returns the remote address of the connection. +func (conn *HvsockConn) RemoteAddr() net.Addr { + return &conn.remote +} + +// SetDeadline implements the net.Conn SetDeadline method. +func (conn *HvsockConn) SetDeadline(t time.Time) error { + // todo: implement `SetDeadline` for `win32File` + if err := conn.SetReadDeadline(t); err != nil { + return fmt.Errorf("set read deadline: %w", err) + } + if err := conn.SetWriteDeadline(t); err != nil { + return fmt.Errorf("set write deadline: %w", err) + } + return nil +} + +// SetReadDeadline implements the net.Conn SetReadDeadline method. +func (conn *HvsockConn) SetReadDeadline(t time.Time) error { + return conn.sock.SetReadDeadline(t) +} + +// SetWriteDeadline implements the net.Conn SetWriteDeadline method. +func (conn *HvsockConn) SetWriteDeadline(t time.Time) error { + return conn.sock.SetWriteDeadline(t) +} diff --git a/vendor/github.com/Microsoft/go-winio/internal/fs/doc.go b/vendor/github.com/Microsoft/go-winio/internal/fs/doc.go new file mode 100644 index 000000000..1f6538817 --- /dev/null +++ b/vendor/github.com/Microsoft/go-winio/internal/fs/doc.go @@ -0,0 +1,2 @@ +// This package contains Win32 filesystem functionality. +package fs diff --git a/vendor/github.com/Microsoft/go-winio/internal/fs/fs.go b/vendor/github.com/Microsoft/go-winio/internal/fs/fs.go new file mode 100644 index 000000000..0cd9621df --- /dev/null +++ b/vendor/github.com/Microsoft/go-winio/internal/fs/fs.go @@ -0,0 +1,262 @@ +//go:build windows + +package fs + +import ( + "golang.org/x/sys/windows" + + "github.com/Microsoft/go-winio/internal/stringbuffer" +) + +//go:generate go run github.com/Microsoft/go-winio/tools/mkwinsyscall -output zsyscall_windows.go fs.go + +// https://learn.microsoft.com/en-us/windows/win32/api/fileapi/nf-fileapi-createfilew +//sys CreateFile(name string, access AccessMask, mode FileShareMode, sa *windows.SecurityAttributes, createmode FileCreationDisposition, attrs FileFlagOrAttribute, templatefile windows.Handle) (handle windows.Handle, err error) [failretval==windows.InvalidHandle] = CreateFileW + +const NullHandle windows.Handle = 0 + +// AccessMask defines standard, specific, and generic rights. +// +// Used with CreateFile and NtCreateFile (and co.). +// +// Bitmask: +// 3 3 2 2 2 2 2 2 2 2 2 2 1 1 1 1 1 1 1 1 1 1 +// 1 0 9 8 7 6 5 4 3 2 1 0 9 8 7 6 5 4 3 2 1 0 9 8 7 6 5 4 3 2 1 0 +// +---------------+---------------+-------------------------------+ +// |G|G|G|G|Resvd|A| StandardRights| SpecificRights | +// |R|W|E|A| |S| | | +// +-+-------------+---------------+-------------------------------+ +// +// GR Generic Read +// GW Generic Write +// GE Generic Exectue +// GA Generic All +// Resvd Reserved +// AS Access Security System +// +// https://learn.microsoft.com/en-us/windows/win32/secauthz/access-mask +// +// https://learn.microsoft.com/en-us/windows/win32/secauthz/generic-access-rights +// +// https://learn.microsoft.com/en-us/windows/win32/fileio/file-access-rights-constants +type AccessMask = windows.ACCESS_MASK + +//nolint:revive // SNAKE_CASE is not idiomatic in Go, but aligned with Win32 API. +const ( + // Not actually any. + // + // For CreateFile: "query certain metadata such as file, directory, or device attributes without accessing that file or device" + // https://learn.microsoft.com/en-us/windows/win32/api/fileapi/nf-fileapi-createfilew#parameters + FILE_ANY_ACCESS AccessMask = 0 + + GENERIC_READ AccessMask = 0x8000_0000 + GENERIC_WRITE AccessMask = 0x4000_0000 + GENERIC_EXECUTE AccessMask = 0x2000_0000 + GENERIC_ALL AccessMask = 0x1000_0000 + ACCESS_SYSTEM_SECURITY AccessMask = 0x0100_0000 + + // Specific Object Access + // from ntioapi.h + + FILE_READ_DATA AccessMask = (0x0001) // file & pipe + FILE_LIST_DIRECTORY AccessMask = (0x0001) // directory + + FILE_WRITE_DATA AccessMask = (0x0002) // file & pipe + FILE_ADD_FILE AccessMask = (0x0002) // directory + + FILE_APPEND_DATA AccessMask = (0x0004) // file + FILE_ADD_SUBDIRECTORY AccessMask = (0x0004) // directory + FILE_CREATE_PIPE_INSTANCE AccessMask = (0x0004) // named pipe + + FILE_READ_EA AccessMask = (0x0008) // file & directory + FILE_READ_PROPERTIES AccessMask = FILE_READ_EA + + FILE_WRITE_EA AccessMask = (0x0010) // file & directory + FILE_WRITE_PROPERTIES AccessMask = FILE_WRITE_EA + + FILE_EXECUTE AccessMask = (0x0020) // file + FILE_TRAVERSE AccessMask = (0x0020) // directory + + FILE_DELETE_CHILD AccessMask = (0x0040) // directory + + FILE_READ_ATTRIBUTES AccessMask = (0x0080) // all + + FILE_WRITE_ATTRIBUTES AccessMask = (0x0100) // all + + FILE_ALL_ACCESS AccessMask = (STANDARD_RIGHTS_REQUIRED | SYNCHRONIZE | 0x1FF) + FILE_GENERIC_READ AccessMask = (STANDARD_RIGHTS_READ | FILE_READ_DATA | FILE_READ_ATTRIBUTES | FILE_READ_EA | SYNCHRONIZE) + FILE_GENERIC_WRITE AccessMask = (STANDARD_RIGHTS_WRITE | FILE_WRITE_DATA | FILE_WRITE_ATTRIBUTES | FILE_WRITE_EA | FILE_APPEND_DATA | SYNCHRONIZE) + FILE_GENERIC_EXECUTE AccessMask = (STANDARD_RIGHTS_EXECUTE | FILE_READ_ATTRIBUTES | FILE_EXECUTE | SYNCHRONIZE) + + SPECIFIC_RIGHTS_ALL AccessMask = 0x0000FFFF + + // Standard Access + // from ntseapi.h + + DELETE AccessMask = 0x0001_0000 + READ_CONTROL AccessMask = 0x0002_0000 + WRITE_DAC AccessMask = 0x0004_0000 + WRITE_OWNER AccessMask = 0x0008_0000 + SYNCHRONIZE AccessMask = 0x0010_0000 + + STANDARD_RIGHTS_REQUIRED AccessMask = 0x000F_0000 + + STANDARD_RIGHTS_READ AccessMask = READ_CONTROL + STANDARD_RIGHTS_WRITE AccessMask = READ_CONTROL + STANDARD_RIGHTS_EXECUTE AccessMask = READ_CONTROL + + STANDARD_RIGHTS_ALL AccessMask = 0x001F_0000 +) + +type FileShareMode uint32 + +//nolint:revive // SNAKE_CASE is not idiomatic in Go, but aligned with Win32 API. +const ( + FILE_SHARE_NONE FileShareMode = 0x00 + FILE_SHARE_READ FileShareMode = 0x01 + FILE_SHARE_WRITE FileShareMode = 0x02 + FILE_SHARE_DELETE FileShareMode = 0x04 + FILE_SHARE_VALID_FLAGS FileShareMode = 0x07 +) + +type FileCreationDisposition uint32 + +//nolint:revive // SNAKE_CASE is not idiomatic in Go, but aligned with Win32 API. +const ( + // from winbase.h + + CREATE_NEW FileCreationDisposition = 0x01 + CREATE_ALWAYS FileCreationDisposition = 0x02 + OPEN_EXISTING FileCreationDisposition = 0x03 + OPEN_ALWAYS FileCreationDisposition = 0x04 + TRUNCATE_EXISTING FileCreationDisposition = 0x05 +) + +// Create disposition values for NtCreate* +type NTFileCreationDisposition uint32 + +//nolint:revive // SNAKE_CASE is not idiomatic in Go, but aligned with Win32 API. +const ( + // From ntioapi.h + + FILE_SUPERSEDE NTFileCreationDisposition = 0x00 + FILE_OPEN NTFileCreationDisposition = 0x01 + FILE_CREATE NTFileCreationDisposition = 0x02 + FILE_OPEN_IF NTFileCreationDisposition = 0x03 + FILE_OVERWRITE NTFileCreationDisposition = 0x04 + FILE_OVERWRITE_IF NTFileCreationDisposition = 0x05 + FILE_MAXIMUM_DISPOSITION NTFileCreationDisposition = 0x05 +) + +// CreateFile and co. take flags or attributes together as one parameter. +// Define alias until we can use generics to allow both +// +// https://learn.microsoft.com/en-us/windows/win32/fileio/file-attribute-constants +type FileFlagOrAttribute uint32 + +//nolint:revive // SNAKE_CASE is not idiomatic in Go, but aligned with Win32 API. +const ( + // from winnt.h + + FILE_FLAG_WRITE_THROUGH FileFlagOrAttribute = 0x8000_0000 + FILE_FLAG_OVERLAPPED FileFlagOrAttribute = 0x4000_0000 + FILE_FLAG_NO_BUFFERING FileFlagOrAttribute = 0x2000_0000 + FILE_FLAG_RANDOM_ACCESS FileFlagOrAttribute = 0x1000_0000 + FILE_FLAG_SEQUENTIAL_SCAN FileFlagOrAttribute = 0x0800_0000 + FILE_FLAG_DELETE_ON_CLOSE FileFlagOrAttribute = 0x0400_0000 + FILE_FLAG_BACKUP_SEMANTICS FileFlagOrAttribute = 0x0200_0000 + FILE_FLAG_POSIX_SEMANTICS FileFlagOrAttribute = 0x0100_0000 + FILE_FLAG_OPEN_REPARSE_POINT FileFlagOrAttribute = 0x0020_0000 + FILE_FLAG_OPEN_NO_RECALL FileFlagOrAttribute = 0x0010_0000 + FILE_FLAG_FIRST_PIPE_INSTANCE FileFlagOrAttribute = 0x0008_0000 +) + +// NtCreate* functions take a dedicated CreateOptions parameter. +// +// https://learn.microsoft.com/en-us/windows/win32/api/Winternl/nf-winternl-ntcreatefile +// +// https://learn.microsoft.com/en-us/windows/win32/devnotes/nt-create-named-pipe-file +type NTCreateOptions uint32 + +//nolint:revive // SNAKE_CASE is not idiomatic in Go, but aligned with Win32 API. +const ( + // From ntioapi.h + + FILE_DIRECTORY_FILE NTCreateOptions = 0x0000_0001 + FILE_WRITE_THROUGH NTCreateOptions = 0x0000_0002 + FILE_SEQUENTIAL_ONLY NTCreateOptions = 0x0000_0004 + FILE_NO_INTERMEDIATE_BUFFERING NTCreateOptions = 0x0000_0008 + + FILE_SYNCHRONOUS_IO_ALERT NTCreateOptions = 0x0000_0010 + FILE_SYNCHRONOUS_IO_NONALERT NTCreateOptions = 0x0000_0020 + FILE_NON_DIRECTORY_FILE NTCreateOptions = 0x0000_0040 + FILE_CREATE_TREE_CONNECTION NTCreateOptions = 0x0000_0080 + + FILE_COMPLETE_IF_OPLOCKED NTCreateOptions = 0x0000_0100 + FILE_NO_EA_KNOWLEDGE NTCreateOptions = 0x0000_0200 + FILE_DISABLE_TUNNELING NTCreateOptions = 0x0000_0400 + FILE_RANDOM_ACCESS NTCreateOptions = 0x0000_0800 + + FILE_DELETE_ON_CLOSE NTCreateOptions = 0x0000_1000 + FILE_OPEN_BY_FILE_ID NTCreateOptions = 0x0000_2000 + FILE_OPEN_FOR_BACKUP_INTENT NTCreateOptions = 0x0000_4000 + FILE_NO_COMPRESSION NTCreateOptions = 0x0000_8000 +) + +type FileSQSFlag = FileFlagOrAttribute + +//nolint:revive // SNAKE_CASE is not idiomatic in Go, but aligned with Win32 API. +const ( + // from winbase.h + + SECURITY_ANONYMOUS FileSQSFlag = FileSQSFlag(SecurityAnonymous << 16) + SECURITY_IDENTIFICATION FileSQSFlag = FileSQSFlag(SecurityIdentification << 16) + SECURITY_IMPERSONATION FileSQSFlag = FileSQSFlag(SecurityImpersonation << 16) + SECURITY_DELEGATION FileSQSFlag = FileSQSFlag(SecurityDelegation << 16) + + SECURITY_SQOS_PRESENT FileSQSFlag = 0x0010_0000 + SECURITY_VALID_SQOS_FLAGS FileSQSFlag = 0x001F_0000 +) + +// GetFinalPathNameByHandle flags +// +// https://learn.microsoft.com/en-us/windows/win32/api/fileapi/nf-fileapi-getfinalpathnamebyhandlew#parameters +type GetFinalPathFlag uint32 + +//nolint:revive // SNAKE_CASE is not idiomatic in Go, but aligned with Win32 API. +const ( + GetFinalPathDefaultFlag GetFinalPathFlag = 0x0 + + FILE_NAME_NORMALIZED GetFinalPathFlag = 0x0 + FILE_NAME_OPENED GetFinalPathFlag = 0x8 + + VOLUME_NAME_DOS GetFinalPathFlag = 0x0 + VOLUME_NAME_GUID GetFinalPathFlag = 0x1 + VOLUME_NAME_NT GetFinalPathFlag = 0x2 + VOLUME_NAME_NONE GetFinalPathFlag = 0x4 +) + +// getFinalPathNameByHandle facilitates calling the Windows API GetFinalPathNameByHandle +// with the given handle and flags. It transparently takes care of creating a buffer of the +// correct size for the call. +// +// https://learn.microsoft.com/en-us/windows/win32/api/fileapi/nf-fileapi-getfinalpathnamebyhandlew +func GetFinalPathNameByHandle(h windows.Handle, flags GetFinalPathFlag) (string, error) { + b := stringbuffer.NewWString() + //TODO: can loop infinitely if Win32 keeps returning the same (or a larger) n? + for { + n, err := windows.GetFinalPathNameByHandle(h, b.Pointer(), b.Cap(), uint32(flags)) + if err != nil { + return "", err + } + // If the buffer wasn't large enough, n will be the total size needed (including null terminator). + // Resize and try again. + if n > b.Cap() { + b.ResizeTo(n) + continue + } + // If the buffer is large enough, n will be the size not including the null terminator. + // Convert to a Go string and return. + return b.String(), nil + } +} diff --git a/vendor/github.com/Microsoft/go-winio/internal/fs/security.go b/vendor/github.com/Microsoft/go-winio/internal/fs/security.go new file mode 100644 index 000000000..81760ac67 --- /dev/null +++ b/vendor/github.com/Microsoft/go-winio/internal/fs/security.go @@ -0,0 +1,12 @@ +package fs + +// https://learn.microsoft.com/en-us/windows/win32/api/winnt/ne-winnt-security_impersonation_level +type SecurityImpersonationLevel int32 // C default enums underlying type is `int`, which is Go `int32` + +// Impersonation levels +const ( + SecurityAnonymous SecurityImpersonationLevel = 0 + SecurityIdentification SecurityImpersonationLevel = 1 + SecurityImpersonation SecurityImpersonationLevel = 2 + SecurityDelegation SecurityImpersonationLevel = 3 +) diff --git a/vendor/github.com/Microsoft/go-winio/internal/fs/zsyscall_windows.go b/vendor/github.com/Microsoft/go-winio/internal/fs/zsyscall_windows.go new file mode 100644 index 000000000..a94e234c7 --- /dev/null +++ b/vendor/github.com/Microsoft/go-winio/internal/fs/zsyscall_windows.go @@ -0,0 +1,61 @@ +//go:build windows + +// Code generated by 'go generate' using "github.com/Microsoft/go-winio/tools/mkwinsyscall"; DO NOT EDIT. + +package fs + +import ( + "syscall" + "unsafe" + + "golang.org/x/sys/windows" +) + +var _ unsafe.Pointer + +// Do the interface allocations only once for common +// Errno values. +const ( + errnoERROR_IO_PENDING = 997 +) + +var ( + errERROR_IO_PENDING error = syscall.Errno(errnoERROR_IO_PENDING) + errERROR_EINVAL error = syscall.EINVAL +) + +// errnoErr returns common boxed Errno values, to prevent +// allocations at runtime. +func errnoErr(e syscall.Errno) error { + switch e { + case 0: + return errERROR_EINVAL + case errnoERROR_IO_PENDING: + return errERROR_IO_PENDING + } + return e +} + +var ( + modkernel32 = windows.NewLazySystemDLL("kernel32.dll") + + procCreateFileW = modkernel32.NewProc("CreateFileW") +) + +func CreateFile(name string, access AccessMask, mode FileShareMode, sa *windows.SecurityAttributes, createmode FileCreationDisposition, attrs FileFlagOrAttribute, templatefile windows.Handle) (handle windows.Handle, err error) { + var _p0 *uint16 + _p0, err = syscall.UTF16PtrFromString(name) + if err != nil { + return + } + return _CreateFile(_p0, access, mode, sa, createmode, attrs, templatefile) +} + +func _CreateFile(name *uint16, access AccessMask, mode FileShareMode, sa *windows.SecurityAttributes, createmode FileCreationDisposition, attrs FileFlagOrAttribute, templatefile windows.Handle) (handle windows.Handle, err error) { + r0, _, e1 := syscall.SyscallN(procCreateFileW.Addr(), uintptr(unsafe.Pointer(name)), uintptr(access), uintptr(mode), uintptr(unsafe.Pointer(sa)), uintptr(createmode), uintptr(attrs), uintptr(templatefile)) + handle = windows.Handle(r0) + if handle == windows.InvalidHandle { + err = errnoErr(e1) + } + return +} diff --git a/vendor/github.com/Microsoft/go-winio/internal/socket/rawaddr.go b/vendor/github.com/Microsoft/go-winio/internal/socket/rawaddr.go new file mode 100644 index 000000000..7e82f9afa --- /dev/null +++ b/vendor/github.com/Microsoft/go-winio/internal/socket/rawaddr.go @@ -0,0 +1,20 @@ +package socket + +import ( + "unsafe" +) + +// RawSockaddr allows structs to be used with [Bind] and [ConnectEx]. The +// struct must meet the Win32 sockaddr requirements specified here: +// https://docs.microsoft.com/en-us/windows/win32/winsock/sockaddr-2 +// +// Specifically, the struct size must be least larger than an int16 (unsigned short) +// for the address family. +type RawSockaddr interface { + // Sockaddr returns a pointer to the RawSockaddr and its struct size, allowing + // for the RawSockaddr's data to be overwritten by syscalls (if necessary). + // + // It is the callers responsibility to validate that the values are valid; invalid + // pointers or size can cause a panic. + Sockaddr() (unsafe.Pointer, int32, error) +} diff --git a/vendor/github.com/Microsoft/go-winio/internal/socket/socket.go b/vendor/github.com/Microsoft/go-winio/internal/socket/socket.go new file mode 100644 index 000000000..88580d974 --- /dev/null +++ b/vendor/github.com/Microsoft/go-winio/internal/socket/socket.go @@ -0,0 +1,177 @@ +//go:build windows + +package socket + +import ( + "errors" + "fmt" + "net" + "sync" + "syscall" + "unsafe" + + "github.com/Microsoft/go-winio/pkg/guid" + "golang.org/x/sys/windows" +) + +//go:generate go run github.com/Microsoft/go-winio/tools/mkwinsyscall -output zsyscall_windows.go socket.go + +//sys getsockname(s windows.Handle, name unsafe.Pointer, namelen *int32) (err error) [failretval==socketError] = ws2_32.getsockname +//sys getpeername(s windows.Handle, name unsafe.Pointer, namelen *int32) (err error) [failretval==socketError] = ws2_32.getpeername +//sys bind(s windows.Handle, name unsafe.Pointer, namelen int32) (err error) [failretval==socketError] = ws2_32.bind + +const socketError = uintptr(^uint32(0)) + +var ( + // todo(helsaawy): create custom error types to store the desired vs actual size and addr family? + + ErrBufferSize = errors.New("buffer size") + ErrAddrFamily = errors.New("address family") + ErrInvalidPointer = errors.New("invalid pointer") + ErrSocketClosed = fmt.Errorf("socket closed: %w", net.ErrClosed) +) + +// todo(helsaawy): replace these with generics, ie: GetSockName[S RawSockaddr](s windows.Handle) (S, error) + +// GetSockName writes the local address of socket s to the [RawSockaddr] rsa. +// If rsa is not large enough, the [windows.WSAEFAULT] is returned. +func GetSockName(s windows.Handle, rsa RawSockaddr) error { + ptr, l, err := rsa.Sockaddr() + if err != nil { + return fmt.Errorf("could not retrieve socket pointer and size: %w", err) + } + + // although getsockname returns WSAEFAULT if the buffer is too small, it does not set + // &l to the correct size, so--apart from doubling the buffer repeatedly--there is no remedy + return getsockname(s, ptr, &l) +} + +// GetPeerName returns the remote address the socket is connected to. +// +// See [GetSockName] for more information. +func GetPeerName(s windows.Handle, rsa RawSockaddr) error { + ptr, l, err := rsa.Sockaddr() + if err != nil { + return fmt.Errorf("could not retrieve socket pointer and size: %w", err) + } + + return getpeername(s, ptr, &l) +} + +func Bind(s windows.Handle, rsa RawSockaddr) (err error) { + ptr, l, err := rsa.Sockaddr() + if err != nil { + return fmt.Errorf("could not retrieve socket pointer and size: %w", err) + } + + return bind(s, ptr, l) +} + +// "golang.org/x/sys/windows".ConnectEx and .Bind only accept internal implementations of the +// their sockaddr interface, so they cannot be used with HvsockAddr +// Replicate functionality here from +// https://cs.opensource.google/go/x/sys/+/master:windows/syscall_windows.go + +// The function pointers to `AcceptEx`, `ConnectEx` and `GetAcceptExSockaddrs` must be loaded at +// runtime via a WSAIoctl call: +// https://docs.microsoft.com/en-us/windows/win32/api/Mswsock/nc-mswsock-lpfn_connectex#remarks + +type runtimeFunc struct { + id guid.GUID + once sync.Once + addr uintptr + err error +} + +func (f *runtimeFunc) Load() error { + f.once.Do(func() { + var s windows.Handle + s, f.err = windows.Socket(windows.AF_INET, windows.SOCK_STREAM, windows.IPPROTO_TCP) + if f.err != nil { + return + } + defer windows.CloseHandle(s) //nolint:errcheck + + var n uint32 + f.err = windows.WSAIoctl(s, + windows.SIO_GET_EXTENSION_FUNCTION_POINTER, + (*byte)(unsafe.Pointer(&f.id)), + uint32(unsafe.Sizeof(f.id)), + (*byte)(unsafe.Pointer(&f.addr)), + uint32(unsafe.Sizeof(f.addr)), + &n, + nil, // overlapped + 0, // completionRoutine + ) + }) + return f.err +} + +var ( + // todo: add `AcceptEx` and `GetAcceptExSockaddrs` + WSAID_CONNECTEX = guid.GUID{ //revive:disable-line:var-naming ALL_CAPS + Data1: 0x25a207b9, + Data2: 0xddf3, + Data3: 0x4660, + Data4: [8]byte{0x8e, 0xe9, 0x76, 0xe5, 0x8c, 0x74, 0x06, 0x3e}, + } + + connectExFunc = runtimeFunc{id: WSAID_CONNECTEX} +) + +func ConnectEx( + fd windows.Handle, + rsa RawSockaddr, + sendBuf *byte, + sendDataLen uint32, + bytesSent *uint32, + overlapped *windows.Overlapped, +) error { + if err := connectExFunc.Load(); err != nil { + return fmt.Errorf("failed to load ConnectEx function pointer: %w", err) + } + ptr, n, err := rsa.Sockaddr() + if err != nil { + return err + } + return connectEx(fd, ptr, n, sendBuf, sendDataLen, bytesSent, overlapped) +} + +// BOOL LpfnConnectex( +// [in] SOCKET s, +// [in] const sockaddr *name, +// [in] int namelen, +// [in, optional] PVOID lpSendBuffer, +// [in] DWORD dwSendDataLength, +// [out] LPDWORD lpdwBytesSent, +// [in] LPOVERLAPPED lpOverlapped +// ) + +func connectEx( + s windows.Handle, + name unsafe.Pointer, + namelen int32, + sendBuf *byte, + sendDataLen uint32, + bytesSent *uint32, + overlapped *windows.Overlapped, +) (err error) { + r1, _, e1 := syscall.SyscallN(connectExFunc.addr, + uintptr(s), + uintptr(name), + uintptr(namelen), + uintptr(unsafe.Pointer(sendBuf)), + uintptr(sendDataLen), + uintptr(unsafe.Pointer(bytesSent)), + uintptr(unsafe.Pointer(overlapped)), + ) + + if r1 == 0 { + if e1 != 0 { + err = error(e1) + } else { + err = syscall.EINVAL + } + } + return err +} diff --git a/vendor/github.com/Microsoft/go-winio/internal/socket/zsyscall_windows.go b/vendor/github.com/Microsoft/go-winio/internal/socket/zsyscall_windows.go new file mode 100644 index 000000000..e1504126a --- /dev/null +++ b/vendor/github.com/Microsoft/go-winio/internal/socket/zsyscall_windows.go @@ -0,0 +1,69 @@ +//go:build windows + +// Code generated by 'go generate' using "github.com/Microsoft/go-winio/tools/mkwinsyscall"; DO NOT EDIT. + +package socket + +import ( + "syscall" + "unsafe" + + "golang.org/x/sys/windows" +) + +var _ unsafe.Pointer + +// Do the interface allocations only once for common +// Errno values. +const ( + errnoERROR_IO_PENDING = 997 +) + +var ( + errERROR_IO_PENDING error = syscall.Errno(errnoERROR_IO_PENDING) + errERROR_EINVAL error = syscall.EINVAL +) + +// errnoErr returns common boxed Errno values, to prevent +// allocations at runtime. +func errnoErr(e syscall.Errno) error { + switch e { + case 0: + return errERROR_EINVAL + case errnoERROR_IO_PENDING: + return errERROR_IO_PENDING + } + return e +} + +var ( + modws2_32 = windows.NewLazySystemDLL("ws2_32.dll") + + procbind = modws2_32.NewProc("bind") + procgetpeername = modws2_32.NewProc("getpeername") + procgetsockname = modws2_32.NewProc("getsockname") +) + +func bind(s windows.Handle, name unsafe.Pointer, namelen int32) (err error) { + r1, _, e1 := syscall.SyscallN(procbind.Addr(), uintptr(s), uintptr(name), uintptr(namelen)) + if r1 == socketError { + err = errnoErr(e1) + } + return +} + +func getpeername(s windows.Handle, name unsafe.Pointer, namelen *int32) (err error) { + r1, _, e1 := syscall.SyscallN(procgetpeername.Addr(), uintptr(s), uintptr(name), uintptr(unsafe.Pointer(namelen))) + if r1 == socketError { + err = errnoErr(e1) + } + return +} + +func getsockname(s windows.Handle, name unsafe.Pointer, namelen *int32) (err error) { + r1, _, e1 := syscall.SyscallN(procgetsockname.Addr(), uintptr(s), uintptr(name), uintptr(unsafe.Pointer(namelen))) + if r1 == socketError { + err = errnoErr(e1) + } + return +} diff --git a/vendor/github.com/Microsoft/go-winio/internal/stringbuffer/wstring.go b/vendor/github.com/Microsoft/go-winio/internal/stringbuffer/wstring.go new file mode 100644 index 000000000..42ebc019f --- /dev/null +++ b/vendor/github.com/Microsoft/go-winio/internal/stringbuffer/wstring.go @@ -0,0 +1,132 @@ +package stringbuffer + +import ( + "sync" + "unicode/utf16" +) + +// TODO: worth exporting and using in mkwinsyscall? + +// Uint16BufferSize is the buffer size in the pool, chosen somewhat arbitrarily to accommodate +// large path strings: +// MAX_PATH (260) + size of volume GUID prefix (49) + null terminator = 310. +const MinWStringCap = 310 + +// use *[]uint16 since []uint16 creates an extra allocation where the slice header +// is copied to heap and then referenced via pointer in the interface header that sync.Pool +// stores. +var pathPool = sync.Pool{ // if go1.18+ adds Pool[T], use that to store []uint16 directly + New: func() interface{} { + b := make([]uint16, MinWStringCap) + return &b + }, +} + +func newBuffer() []uint16 { return *(pathPool.Get().(*[]uint16)) } + +// freeBuffer copies the slice header data, and puts a pointer to that in the pool. +// This avoids taking a pointer to the slice header in WString, which can be set to nil. +func freeBuffer(b []uint16) { pathPool.Put(&b) } + +// WString is a wide string buffer ([]uint16) meant for storing UTF-16 encoded strings +// for interacting with Win32 APIs. +// Sizes are specified as uint32 and not int. +// +// It is not thread safe. +type WString struct { + // type-def allows casting to []uint16 directly, use struct to prevent that and allow adding fields in the future. + + // raw buffer + b []uint16 +} + +// NewWString returns a [WString] allocated from a shared pool with an +// initial capacity of at least [MinWStringCap]. +// Since the buffer may have been previously used, its contents are not guaranteed to be empty. +// +// The buffer should be freed via [WString.Free] +func NewWString() *WString { + return &WString{ + b: newBuffer(), + } +} + +func (b *WString) Free() { + if b.empty() { + return + } + freeBuffer(b.b) + b.b = nil +} + +// ResizeTo grows the buffer to at least c and returns the new capacity, freeing the +// previous buffer back into pool. +func (b *WString) ResizeTo(c uint32) uint32 { + // already sufficient (or n is 0) + if c <= b.Cap() { + return b.Cap() + } + + if c <= MinWStringCap { + c = MinWStringCap + } + // allocate at-least double buffer size, as is done in [bytes.Buffer] and other places + if c <= 2*b.Cap() { + c = 2 * b.Cap() + } + + b2 := make([]uint16, c) + if !b.empty() { + copy(b2, b.b) + freeBuffer(b.b) + } + b.b = b2 + return c +} + +// Buffer returns the underlying []uint16 buffer. +func (b *WString) Buffer() []uint16 { + if b.empty() { + return nil + } + return b.b +} + +// Pointer returns a pointer to the first uint16 in the buffer. +// If the [WString.Free] has already been called, the pointer will be nil. +func (b *WString) Pointer() *uint16 { + if b.empty() { + return nil + } + return &b.b[0] +} + +// String returns the returns the UTF-8 encoding of the UTF-16 string in the buffer. +// +// It assumes that the data is null-terminated. +func (b *WString) String() string { + // Using [windows.UTF16ToString] would require importing "golang.org/x/sys/windows" + // and would make this code Windows-only, which makes no sense. + // So copy UTF16ToString code into here. + // If other windows-specific code is added, switch to [windows.UTF16ToString] + + s := b.b + for i, v := range s { + if v == 0 { + s = s[:i] + break + } + } + return string(utf16.Decode(s)) +} + +// Cap returns the underlying buffer capacity. +func (b *WString) Cap() uint32 { + if b.empty() { + return 0 + } + return b.cap() +} + +func (b *WString) cap() uint32 { return uint32(cap(b.b)) } +func (b *WString) empty() bool { return b == nil || b.cap() == 0 } diff --git a/vendor/github.com/Microsoft/go-winio/pipe.go b/vendor/github.com/Microsoft/go-winio/pipe.go new file mode 100644 index 000000000..a2da6639d --- /dev/null +++ b/vendor/github.com/Microsoft/go-winio/pipe.go @@ -0,0 +1,586 @@ +//go:build windows +// +build windows + +package winio + +import ( + "context" + "errors" + "fmt" + "io" + "net" + "os" + "runtime" + "time" + "unsafe" + + "golang.org/x/sys/windows" + + "github.com/Microsoft/go-winio/internal/fs" +) + +//sys connectNamedPipe(pipe windows.Handle, o *windows.Overlapped) (err error) = ConnectNamedPipe +//sys createNamedPipe(name string, flags uint32, pipeMode uint32, maxInstances uint32, outSize uint32, inSize uint32, defaultTimeout uint32, sa *windows.SecurityAttributes) (handle windows.Handle, err error) [failretval==windows.InvalidHandle] = CreateNamedPipeW +//sys disconnectNamedPipe(pipe windows.Handle) (err error) = DisconnectNamedPipe +//sys getNamedPipeInfo(pipe windows.Handle, flags *uint32, outSize *uint32, inSize *uint32, maxInstances *uint32) (err error) = GetNamedPipeInfo +//sys getNamedPipeHandleState(pipe windows.Handle, state *uint32, curInstances *uint32, maxCollectionCount *uint32, collectDataTimeout *uint32, userName *uint16, maxUserNameSize uint32) (err error) = GetNamedPipeHandleStateW +//sys ntCreateNamedPipeFile(pipe *windows.Handle, access ntAccessMask, oa *objectAttributes, iosb *ioStatusBlock, share ntFileShareMode, disposition ntFileCreationDisposition, options ntFileOptions, typ uint32, readMode uint32, completionMode uint32, maxInstances uint32, inboundQuota uint32, outputQuota uint32, timeout *int64) (status ntStatus) = ntdll.NtCreateNamedPipeFile +//sys rtlNtStatusToDosError(status ntStatus) (winerr error) = ntdll.RtlNtStatusToDosErrorNoTeb +//sys rtlDosPathNameToNtPathName(name *uint16, ntName *unicodeString, filePart uintptr, reserved uintptr) (status ntStatus) = ntdll.RtlDosPathNameToNtPathName_U +//sys rtlDefaultNpAcl(dacl *uintptr) (status ntStatus) = ntdll.RtlDefaultNpAcl + +type PipeConn interface { + net.Conn + Disconnect() error + Flush() error +} + +// type aliases for mkwinsyscall code +type ( + ntAccessMask = fs.AccessMask + ntFileShareMode = fs.FileShareMode + ntFileCreationDisposition = fs.NTFileCreationDisposition + ntFileOptions = fs.NTCreateOptions +) + +type ioStatusBlock struct { + Status, Information uintptr +} + +// typedef struct _OBJECT_ATTRIBUTES { +// ULONG Length; +// HANDLE RootDirectory; +// PUNICODE_STRING ObjectName; +// ULONG Attributes; +// PVOID SecurityDescriptor; +// PVOID SecurityQualityOfService; +// } OBJECT_ATTRIBUTES; +// +// https://learn.microsoft.com/en-us/windows/win32/api/ntdef/ns-ntdef-_object_attributes +type objectAttributes struct { + Length uintptr + RootDirectory uintptr + ObjectName *unicodeString + Attributes uintptr + SecurityDescriptor *securityDescriptor + SecurityQoS uintptr +} + +type unicodeString struct { + Length uint16 + MaximumLength uint16 + Buffer uintptr +} + +// typedef struct _SECURITY_DESCRIPTOR { +// BYTE Revision; +// BYTE Sbz1; +// SECURITY_DESCRIPTOR_CONTROL Control; +// PSID Owner; +// PSID Group; +// PACL Sacl; +// PACL Dacl; +// } SECURITY_DESCRIPTOR, *PISECURITY_DESCRIPTOR; +// +// https://learn.microsoft.com/en-us/windows/win32/api/winnt/ns-winnt-security_descriptor +type securityDescriptor struct { + Revision byte + Sbz1 byte + Control uint16 + Owner uintptr + Group uintptr + Sacl uintptr //revive:disable-line:var-naming SACL, not Sacl + Dacl uintptr //revive:disable-line:var-naming DACL, not Dacl +} + +type ntStatus int32 + +func (status ntStatus) Err() error { + if status >= 0 { + return nil + } + return rtlNtStatusToDosError(status) +} + +var ( + // ErrPipeListenerClosed is returned for pipe operations on listeners that have been closed. + ErrPipeListenerClosed = net.ErrClosed + + errPipeWriteClosed = errors.New("pipe has been closed for write") +) + +type win32Pipe struct { + *win32File + path string +} + +var _ PipeConn = (*win32Pipe)(nil) + +type win32MessageBytePipe struct { + win32Pipe + writeClosed bool + readEOF bool +} + +type pipeAddress string + +func (f *win32Pipe) LocalAddr() net.Addr { + return pipeAddress(f.path) +} + +func (f *win32Pipe) RemoteAddr() net.Addr { + return pipeAddress(f.path) +} + +func (f *win32Pipe) SetDeadline(t time.Time) error { + if err := f.SetReadDeadline(t); err != nil { + return err + } + return f.SetWriteDeadline(t) +} + +func (f *win32Pipe) Disconnect() error { + return disconnectNamedPipe(f.win32File.handle) +} + +// CloseWrite closes the write side of a message pipe in byte mode. +func (f *win32MessageBytePipe) CloseWrite() error { + if f.writeClosed { + return errPipeWriteClosed + } + err := f.win32File.Flush() + if err != nil { + return err + } + _, err = f.win32File.Write(nil) + if err != nil { + return err + } + f.writeClosed = true + return nil +} + +// Write writes bytes to a message pipe in byte mode. Zero-byte writes are ignored, since +// they are used to implement CloseWrite(). +func (f *win32MessageBytePipe) Write(b []byte) (int, error) { + if f.writeClosed { + return 0, errPipeWriteClosed + } + if len(b) == 0 { + return 0, nil + } + return f.win32File.Write(b) +} + +// Read reads bytes from a message pipe in byte mode. A read of a zero-byte message on a message +// mode pipe will return io.EOF, as will all subsequent reads. +func (f *win32MessageBytePipe) Read(b []byte) (int, error) { + if f.readEOF { + return 0, io.EOF + } + n, err := f.win32File.Read(b) + if err == io.EOF { //nolint:errorlint + // If this was the result of a zero-byte read, then + // it is possible that the read was due to a zero-size + // message. Since we are simulating CloseWrite with a + // zero-byte message, ensure that all future Read() calls + // also return EOF. + f.readEOF = true + } else if err == windows.ERROR_MORE_DATA { //nolint:errorlint // err is Errno + // ERROR_MORE_DATA indicates that the pipe's read mode is message mode + // and the message still has more bytes. Treat this as a success, since + // this package presents all named pipes as byte streams. + err = nil + } + return n, err +} + +func (pipeAddress) Network() string { + return "pipe" +} + +func (s pipeAddress) String() string { + return string(s) +} + +// tryDialPipe attempts to dial the pipe at `path` until `ctx` cancellation or timeout. +func tryDialPipe(ctx context.Context, path *string, access fs.AccessMask, impLevel PipeImpLevel) (windows.Handle, error) { + for { + select { + case <-ctx.Done(): + return windows.Handle(0), ctx.Err() + default: + h, err := fs.CreateFile(*path, + access, + 0, // mode + nil, // security attributes + fs.OPEN_EXISTING, + fs.FILE_FLAG_OVERLAPPED|fs.SECURITY_SQOS_PRESENT|fs.FileSQSFlag(impLevel), + 0, // template file handle + ) + if err == nil { + return h, nil + } + if err != windows.ERROR_PIPE_BUSY { //nolint:errorlint // err is Errno + return h, &os.PathError{Err: err, Op: "open", Path: *path} + } + // Wait 10 msec and try again. This is a rather simplistic + // view, as we always try each 10 milliseconds. + time.Sleep(10 * time.Millisecond) + } + } +} + +// DialPipe connects to a named pipe by path, timing out if the connection +// takes longer than the specified duration. If timeout is nil, then we use +// a default timeout of 2 seconds. (We do not use WaitNamedPipe.) +func DialPipe(path string, timeout *time.Duration) (net.Conn, error) { + var absTimeout time.Time + if timeout != nil { + absTimeout = time.Now().Add(*timeout) + } else { + absTimeout = time.Now().Add(2 * time.Second) + } + ctx, cancel := context.WithDeadline(context.Background(), absTimeout) + defer cancel() + conn, err := DialPipeContext(ctx, path) + if errors.Is(err, context.DeadlineExceeded) { + return nil, ErrTimeout + } + return conn, err +} + +// DialPipeContext attempts to connect to a named pipe by `path` until `ctx` +// cancellation or timeout. +func DialPipeContext(ctx context.Context, path string) (net.Conn, error) { + return DialPipeAccess(ctx, path, uint32(fs.GENERIC_READ|fs.GENERIC_WRITE)) +} + +// PipeImpLevel is an enumeration of impersonation levels that may be set +// when calling DialPipeAccessImpersonation. +type PipeImpLevel uint32 + +const ( + PipeImpLevelAnonymous = PipeImpLevel(fs.SECURITY_ANONYMOUS) + PipeImpLevelIdentification = PipeImpLevel(fs.SECURITY_IDENTIFICATION) + PipeImpLevelImpersonation = PipeImpLevel(fs.SECURITY_IMPERSONATION) + PipeImpLevelDelegation = PipeImpLevel(fs.SECURITY_DELEGATION) +) + +// DialPipeAccess attempts to connect to a named pipe by `path` with `access` until `ctx` +// cancellation or timeout. +func DialPipeAccess(ctx context.Context, path string, access uint32) (net.Conn, error) { + return DialPipeAccessImpLevel(ctx, path, access, PipeImpLevelAnonymous) +} + +// DialPipeAccessImpLevel attempts to connect to a named pipe by `path` with +// `access` at `impLevel` until `ctx` cancellation or timeout. The other +// DialPipe* implementations use PipeImpLevelAnonymous. +func DialPipeAccessImpLevel(ctx context.Context, path string, access uint32, impLevel PipeImpLevel) (net.Conn, error) { + var err error + var h windows.Handle + h, err = tryDialPipe(ctx, &path, fs.AccessMask(access), impLevel) + if err != nil { + return nil, err + } + + var flags uint32 + err = getNamedPipeInfo(h, &flags, nil, nil, nil) + if err != nil { + return nil, err + } + + f, err := makeWin32File(h) + if err != nil { + windows.Close(h) + return nil, err + } + + // If the pipe is in message mode, return a message byte pipe, which + // supports CloseWrite(). + if flags&windows.PIPE_TYPE_MESSAGE != 0 { + return &win32MessageBytePipe{ + win32Pipe: win32Pipe{win32File: f, path: path}, + }, nil + } + return &win32Pipe{win32File: f, path: path}, nil +} + +type acceptResponse struct { + f *win32File + err error +} + +type win32PipeListener struct { + firstHandle windows.Handle + path string + config PipeConfig + acceptCh chan (chan acceptResponse) + closeCh chan int + doneCh chan int +} + +func makeServerPipeHandle(path string, sd []byte, c *PipeConfig, first bool) (windows.Handle, error) { + path16, err := windows.UTF16FromString(path) + if err != nil { + return 0, &os.PathError{Op: "open", Path: path, Err: err} + } + + var oa objectAttributes + oa.Length = unsafe.Sizeof(oa) + + var ntPath unicodeString + if err := rtlDosPathNameToNtPathName(&path16[0], + &ntPath, + 0, + 0, + ).Err(); err != nil { + return 0, &os.PathError{Op: "open", Path: path, Err: err} + } + defer windows.LocalFree(windows.Handle(ntPath.Buffer)) //nolint:errcheck + oa.ObjectName = &ntPath + oa.Attributes = windows.OBJ_CASE_INSENSITIVE + + // The security descriptor is only needed for the first pipe. + if first { + if sd != nil { + //todo: does `sdb` need to be allocated on the heap, or can go allocate it? + l := uint32(len(sd)) + sdb, err := windows.LocalAlloc(0, l) + if err != nil { + return 0, fmt.Errorf("LocalAlloc for security descriptor with of length %d: %w", l, err) + } + defer windows.LocalFree(windows.Handle(sdb)) //nolint:errcheck + copy((*[0xffff]byte)(unsafe.Pointer(sdb))[:], sd) + oa.SecurityDescriptor = (*securityDescriptor)(unsafe.Pointer(sdb)) + } else { + // Construct the default named pipe security descriptor. + var dacl uintptr + if err := rtlDefaultNpAcl(&dacl).Err(); err != nil { + return 0, fmt.Errorf("getting default named pipe ACL: %w", err) + } + defer windows.LocalFree(windows.Handle(dacl)) //nolint:errcheck + + sdb := &securityDescriptor{ + Revision: 1, + Control: windows.SE_DACL_PRESENT, + Dacl: dacl, + } + oa.SecurityDescriptor = sdb + } + } + + typ := uint32(windows.FILE_PIPE_REJECT_REMOTE_CLIENTS) + if c.MessageMode { + typ |= windows.FILE_PIPE_MESSAGE_TYPE + } + + disposition := fs.FILE_OPEN + access := fs.GENERIC_READ | fs.GENERIC_WRITE | fs.SYNCHRONIZE + if first { + disposition = fs.FILE_CREATE + // By not asking for read or write access, the named pipe file system + // will put this pipe into an initially disconnected state, blocking + // client connections until the next call with first == false. + access = fs.SYNCHRONIZE + } + + timeout := int64(-50 * 10000) // 50ms + + var ( + h windows.Handle + iosb ioStatusBlock + ) + err = ntCreateNamedPipeFile(&h, + access, + &oa, + &iosb, + fs.FILE_SHARE_READ|fs.FILE_SHARE_WRITE, + disposition, + 0, + typ, + 0, + 0, + 0xffffffff, + uint32(c.InputBufferSize), + uint32(c.OutputBufferSize), + &timeout).Err() + if err != nil { + return 0, &os.PathError{Op: "open", Path: path, Err: err} + } + + runtime.KeepAlive(ntPath) + return h, nil +} + +func (l *win32PipeListener) makeServerPipe() (*win32File, error) { + h, err := makeServerPipeHandle(l.path, nil, &l.config, false) + if err != nil { + return nil, err + } + f, err := makeWin32File(h) + if err != nil { + windows.Close(h) + return nil, err + } + return f, nil +} + +func (l *win32PipeListener) makeConnectedServerPipe() (*win32File, error) { + p, err := l.makeServerPipe() + if err != nil { + return nil, err + } + + // Wait for the client to connect. + ch := make(chan error) + go func(p *win32File) { + ch <- connectPipe(p) + }(p) + + select { + case err = <-ch: + if err != nil { + p.Close() + p = nil + } + case <-l.closeCh: + // Abort the connect request by closing the handle. + p.Close() + p = nil + err = <-ch + if err == nil || err == ErrFileClosed { //nolint:errorlint // err is Errno + err = ErrPipeListenerClosed + } + } + return p, err +} + +func (l *win32PipeListener) listenerRoutine() { + closed := false + for !closed { + select { + case <-l.closeCh: + closed = true + case responseCh := <-l.acceptCh: + var ( + p *win32File + err error + ) + for { + p, err = l.makeConnectedServerPipe() + // If the connection was immediately closed by the client, try + // again. + if err != windows.ERROR_NO_DATA { //nolint:errorlint // err is Errno + break + } + } + responseCh <- acceptResponse{p, err} + closed = err == ErrPipeListenerClosed //nolint:errorlint // err is Errno + } + } + windows.Close(l.firstHandle) + l.firstHandle = 0 + // Notify Close() and Accept() callers that the handle has been closed. + close(l.doneCh) +} + +// PipeConfig contain configuration for the pipe listener. +type PipeConfig struct { + // SecurityDescriptor contains a Windows security descriptor in SDDL format. + SecurityDescriptor string + + // MessageMode determines whether the pipe is in byte or message mode. In either + // case the pipe is read in byte mode by default. The only practical difference in + // this implementation is that CloseWrite() is only supported for message mode pipes; + // CloseWrite() is implemented as a zero-byte write, but zero-byte writes are only + // transferred to the reader (and returned as io.EOF in this implementation) + // when the pipe is in message mode. + MessageMode bool + + // InputBufferSize specifies the size of the input buffer, in bytes. + InputBufferSize int32 + + // OutputBufferSize specifies the size of the output buffer, in bytes. + OutputBufferSize int32 +} + +// ListenPipe creates a listener on a Windows named pipe path, e.g. \\.\pipe\mypipe. +// The pipe must not already exist. +func ListenPipe(path string, c *PipeConfig) (net.Listener, error) { + var ( + sd []byte + err error + ) + if c == nil { + c = &PipeConfig{} + } + if c.SecurityDescriptor != "" { + sd, err = SddlToSecurityDescriptor(c.SecurityDescriptor) + if err != nil { + return nil, err + } + } + h, err := makeServerPipeHandle(path, sd, c, true) + if err != nil { + return nil, err + } + l := &win32PipeListener{ + firstHandle: h, + path: path, + config: *c, + acceptCh: make(chan (chan acceptResponse)), + closeCh: make(chan int), + doneCh: make(chan int), + } + go l.listenerRoutine() + return l, nil +} + +func connectPipe(p *win32File) error { + c, err := p.prepareIO() + if err != nil { + return err + } + defer p.wg.Done() + + err = connectNamedPipe(p.handle, &c.o) + _, err = p.asyncIO(c, nil, 0, err) + if err != nil && err != windows.ERROR_PIPE_CONNECTED { //nolint:errorlint // err is Errno + return err + } + return nil +} + +func (l *win32PipeListener) Accept() (net.Conn, error) { + ch := make(chan acceptResponse) + select { + case l.acceptCh <- ch: + response := <-ch + err := response.err + if err != nil { + return nil, err + } + if l.config.MessageMode { + return &win32MessageBytePipe{ + win32Pipe: win32Pipe{win32File: response.f, path: l.path}, + }, nil + } + return &win32Pipe{win32File: response.f, path: l.path}, nil + case <-l.doneCh: + return nil, ErrPipeListenerClosed + } +} + +func (l *win32PipeListener) Close() error { + select { + case l.closeCh <- 1: + <-l.doneCh + case <-l.doneCh: + } + return nil +} + +func (l *win32PipeListener) Addr() net.Addr { + return pipeAddress(l.path) +} diff --git a/vendor/github.com/Microsoft/go-winio/pkg/guid/guid.go b/vendor/github.com/Microsoft/go-winio/pkg/guid/guid.go new file mode 100644 index 000000000..48ce4e924 --- /dev/null +++ b/vendor/github.com/Microsoft/go-winio/pkg/guid/guid.go @@ -0,0 +1,232 @@ +// Package guid provides a GUID type. The backing structure for a GUID is +// identical to that used by the golang.org/x/sys/windows GUID type. +// There are two main binary encodings used for a GUID, the big-endian encoding, +// and the Windows (mixed-endian) encoding. See here for details: +// https://en.wikipedia.org/wiki/Universally_unique_identifier#Encoding +package guid + +import ( + "crypto/rand" + "crypto/sha1" //nolint:gosec // not used for secure application + "encoding" + "encoding/binary" + "fmt" + "strconv" +) + +//go:generate go run golang.org/x/tools/cmd/stringer -type=Variant -trimprefix=Variant -linecomment + +// Variant specifies which GUID variant (or "type") of the GUID. It determines +// how the entirety of the rest of the GUID is interpreted. +type Variant uint8 + +// The variants specified by RFC 4122 section 4.1.1. +const ( + // VariantUnknown specifies a GUID variant which does not conform to one of + // the variant encodings specified in RFC 4122. + VariantUnknown Variant = iota + VariantNCS + VariantRFC4122 // RFC 4122 + VariantMicrosoft + VariantFuture +) + +// Version specifies how the bits in the GUID were generated. For instance, a +// version 4 GUID is randomly generated, and a version 5 is generated from the +// hash of an input string. +type Version uint8 + +func (v Version) String() string { + return strconv.FormatUint(uint64(v), 10) +} + +var _ = (encoding.TextMarshaler)(GUID{}) +var _ = (encoding.TextUnmarshaler)(&GUID{}) + +// NewV4 returns a new version 4 (pseudorandom) GUID, as defined by RFC 4122. +func NewV4() (GUID, error) { + var b [16]byte + if _, err := rand.Read(b[:]); err != nil { + return GUID{}, err + } + + g := FromArray(b) + g.setVersion(4) // Version 4 means randomly generated. + g.setVariant(VariantRFC4122) + + return g, nil +} + +// NewV5 returns a new version 5 (generated from a string via SHA-1 hashing) +// GUID, as defined by RFC 4122. The RFC is unclear on the encoding of the name, +// and the sample code treats it as a series of bytes, so we do the same here. +// +// Some implementations, such as those found on Windows, treat the name as a +// big-endian UTF16 stream of bytes. If that is desired, the string can be +// encoded as such before being passed to this function. +func NewV5(namespace GUID, name []byte) (GUID, error) { + b := sha1.New() //nolint:gosec // not used for secure application + namespaceBytes := namespace.ToArray() + b.Write(namespaceBytes[:]) + b.Write(name) + + a := [16]byte{} + copy(a[:], b.Sum(nil)) + + g := FromArray(a) + g.setVersion(5) // Version 5 means generated from a string. + g.setVariant(VariantRFC4122) + + return g, nil +} + +func fromArray(b [16]byte, order binary.ByteOrder) GUID { + var g GUID + g.Data1 = order.Uint32(b[0:4]) + g.Data2 = order.Uint16(b[4:6]) + g.Data3 = order.Uint16(b[6:8]) + copy(g.Data4[:], b[8:16]) + return g +} + +func (g GUID) toArray(order binary.ByteOrder) [16]byte { + b := [16]byte{} + order.PutUint32(b[0:4], g.Data1) + order.PutUint16(b[4:6], g.Data2) + order.PutUint16(b[6:8], g.Data3) + copy(b[8:16], g.Data4[:]) + return b +} + +// FromArray constructs a GUID from a big-endian encoding array of 16 bytes. +func FromArray(b [16]byte) GUID { + return fromArray(b, binary.BigEndian) +} + +// ToArray returns an array of 16 bytes representing the GUID in big-endian +// encoding. +func (g GUID) ToArray() [16]byte { + return g.toArray(binary.BigEndian) +} + +// FromWindowsArray constructs a GUID from a Windows encoding array of bytes. +func FromWindowsArray(b [16]byte) GUID { + return fromArray(b, binary.LittleEndian) +} + +// ToWindowsArray returns an array of 16 bytes representing the GUID in Windows +// encoding. +func (g GUID) ToWindowsArray() [16]byte { + return g.toArray(binary.LittleEndian) +} + +func (g GUID) String() string { + return fmt.Sprintf( + "%08x-%04x-%04x-%04x-%012x", + g.Data1, + g.Data2, + g.Data3, + g.Data4[:2], + g.Data4[2:]) +} + +// FromString parses a string containing a GUID and returns the GUID. The only +// format currently supported is the `xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx` +// format. +func FromString(s string) (GUID, error) { + if len(s) != 36 { + return GUID{}, fmt.Errorf("invalid GUID %q", s) + } + if s[8] != '-' || s[13] != '-' || s[18] != '-' || s[23] != '-' { + return GUID{}, fmt.Errorf("invalid GUID %q", s) + } + + var g GUID + + data1, err := strconv.ParseUint(s[0:8], 16, 32) + if err != nil { + return GUID{}, fmt.Errorf("invalid GUID %q", s) + } + g.Data1 = uint32(data1) + + data2, err := strconv.ParseUint(s[9:13], 16, 16) + if err != nil { + return GUID{}, fmt.Errorf("invalid GUID %q", s) + } + g.Data2 = uint16(data2) + + data3, err := strconv.ParseUint(s[14:18], 16, 16) + if err != nil { + return GUID{}, fmt.Errorf("invalid GUID %q", s) + } + g.Data3 = uint16(data3) + + for i, x := range []int{19, 21, 24, 26, 28, 30, 32, 34} { + v, err := strconv.ParseUint(s[x:x+2], 16, 8) + if err != nil { + return GUID{}, fmt.Errorf("invalid GUID %q", s) + } + g.Data4[i] = uint8(v) + } + + return g, nil +} + +func (g *GUID) setVariant(v Variant) { + d := g.Data4[0] + switch v { + case VariantNCS: + d = (d & 0x7f) + case VariantRFC4122: + d = (d & 0x3f) | 0x80 + case VariantMicrosoft: + d = (d & 0x1f) | 0xc0 + case VariantFuture: + d = (d & 0x0f) | 0xe0 + case VariantUnknown: + fallthrough + default: + panic(fmt.Sprintf("invalid variant: %d", v)) + } + g.Data4[0] = d +} + +// Variant returns the GUID variant, as defined in RFC 4122. +func (g GUID) Variant() Variant { + b := g.Data4[0] + if b&0x80 == 0 { + return VariantNCS + } else if b&0xc0 == 0x80 { + return VariantRFC4122 + } else if b&0xe0 == 0xc0 { + return VariantMicrosoft + } else if b&0xe0 == 0xe0 { + return VariantFuture + } + return VariantUnknown +} + +func (g *GUID) setVersion(v Version) { + g.Data3 = (g.Data3 & 0x0fff) | (uint16(v) << 12) +} + +// Version returns the GUID version, as defined in RFC 4122. +func (g GUID) Version() Version { + return Version((g.Data3 & 0xF000) >> 12) +} + +// MarshalText returns the textual representation of the GUID. +func (g GUID) MarshalText() ([]byte, error) { + return []byte(g.String()), nil +} + +// UnmarshalText takes the textual representation of a GUID, and unmarhals it +// into this GUID. +func (g *GUID) UnmarshalText(text []byte) error { + g2, err := FromString(string(text)) + if err != nil { + return err + } + *g = g2 + return nil +} diff --git a/vendor/github.com/Microsoft/go-winio/pkg/guid/guid_nonwindows.go b/vendor/github.com/Microsoft/go-winio/pkg/guid/guid_nonwindows.go new file mode 100644 index 000000000..805bd3548 --- /dev/null +++ b/vendor/github.com/Microsoft/go-winio/pkg/guid/guid_nonwindows.go @@ -0,0 +1,16 @@ +//go:build !windows +// +build !windows + +package guid + +// GUID represents a GUID/UUID. It has the same structure as +// golang.org/x/sys/windows.GUID so that it can be used with functions expecting +// that type. It is defined as its own type as that is only available to builds +// targeted at `windows`. The representation matches that used by native Windows +// code. +type GUID struct { + Data1 uint32 + Data2 uint16 + Data3 uint16 + Data4 [8]byte +} diff --git a/vendor/github.com/Microsoft/go-winio/pkg/guid/guid_windows.go b/vendor/github.com/Microsoft/go-winio/pkg/guid/guid_windows.go new file mode 100644 index 000000000..27e45ee5c --- /dev/null +++ b/vendor/github.com/Microsoft/go-winio/pkg/guid/guid_windows.go @@ -0,0 +1,13 @@ +//go:build windows +// +build windows + +package guid + +import "golang.org/x/sys/windows" + +// GUID represents a GUID/UUID. It has the same structure as +// golang.org/x/sys/windows.GUID so that it can be used with functions expecting +// that type. It is defined as its own type so that stringification and +// marshaling can be supported. The representation matches that used by native +// Windows code. +type GUID windows.GUID diff --git a/vendor/github.com/Microsoft/go-winio/pkg/guid/variant_string.go b/vendor/github.com/Microsoft/go-winio/pkg/guid/variant_string.go new file mode 100644 index 000000000..4076d3132 --- /dev/null +++ b/vendor/github.com/Microsoft/go-winio/pkg/guid/variant_string.go @@ -0,0 +1,27 @@ +// Code generated by "stringer -type=Variant -trimprefix=Variant -linecomment"; DO NOT EDIT. + +package guid + +import "strconv" + +func _() { + // An "invalid array index" compiler error signifies that the constant values have changed. + // Re-run the stringer command to generate them again. + var x [1]struct{} + _ = x[VariantUnknown-0] + _ = x[VariantNCS-1] + _ = x[VariantRFC4122-2] + _ = x[VariantMicrosoft-3] + _ = x[VariantFuture-4] +} + +const _Variant_name = "UnknownNCSRFC 4122MicrosoftFuture" + +var _Variant_index = [...]uint8{0, 7, 10, 18, 27, 33} + +func (i Variant) String() string { + if i >= Variant(len(_Variant_index)-1) { + return "Variant(" + strconv.FormatInt(int64(i), 10) + ")" + } + return _Variant_name[_Variant_index[i]:_Variant_index[i+1]] +} diff --git a/vendor/github.com/Microsoft/go-winio/privilege.go b/vendor/github.com/Microsoft/go-winio/privilege.go new file mode 100644 index 000000000..d9b90b6e8 --- /dev/null +++ b/vendor/github.com/Microsoft/go-winio/privilege.go @@ -0,0 +1,196 @@ +//go:build windows +// +build windows + +package winio + +import ( + "bytes" + "encoding/binary" + "fmt" + "runtime" + "sync" + "unicode/utf16" + + "golang.org/x/sys/windows" +) + +//sys adjustTokenPrivileges(token windows.Token, releaseAll bool, input *byte, outputSize uint32, output *byte, requiredSize *uint32) (success bool, err error) [true] = advapi32.AdjustTokenPrivileges +//sys impersonateSelf(level uint32) (err error) = advapi32.ImpersonateSelf +//sys revertToSelf() (err error) = advapi32.RevertToSelf +//sys openThreadToken(thread windows.Handle, accessMask uint32, openAsSelf bool, token *windows.Token) (err error) = advapi32.OpenThreadToken +//sys getCurrentThread() (h windows.Handle) = GetCurrentThread +//sys lookupPrivilegeValue(systemName string, name string, luid *uint64) (err error) = advapi32.LookupPrivilegeValueW +//sys lookupPrivilegeName(systemName string, luid *uint64, buffer *uint16, size *uint32) (err error) = advapi32.LookupPrivilegeNameW +//sys lookupPrivilegeDisplayName(systemName string, name *uint16, buffer *uint16, size *uint32, languageId *uint32) (err error) = advapi32.LookupPrivilegeDisplayNameW + +const ( + //revive:disable-next-line:var-naming ALL_CAPS + SE_PRIVILEGE_ENABLED = windows.SE_PRIVILEGE_ENABLED + + //revive:disable-next-line:var-naming ALL_CAPS + ERROR_NOT_ALL_ASSIGNED windows.Errno = windows.ERROR_NOT_ALL_ASSIGNED + + SeBackupPrivilege = "SeBackupPrivilege" + SeRestorePrivilege = "SeRestorePrivilege" + SeSecurityPrivilege = "SeSecurityPrivilege" +) + +var ( + privNames = make(map[string]uint64) + privNameMutex sync.Mutex +) + +// PrivilegeError represents an error enabling privileges. +type PrivilegeError struct { + privileges []uint64 +} + +func (e *PrivilegeError) Error() string { + s := "Could not enable privilege " + if len(e.privileges) > 1 { + s = "Could not enable privileges " + } + for i, p := range e.privileges { + if i != 0 { + s += ", " + } + s += `"` + s += getPrivilegeName(p) + s += `"` + } + return s +} + +// RunWithPrivilege enables a single privilege for a function call. +func RunWithPrivilege(name string, fn func() error) error { + return RunWithPrivileges([]string{name}, fn) +} + +// RunWithPrivileges enables privileges for a function call. +func RunWithPrivileges(names []string, fn func() error) error { + privileges, err := mapPrivileges(names) + if err != nil { + return err + } + runtime.LockOSThread() + defer runtime.UnlockOSThread() + token, err := newThreadToken() + if err != nil { + return err + } + defer releaseThreadToken(token) + err = adjustPrivileges(token, privileges, SE_PRIVILEGE_ENABLED) + if err != nil { + return err + } + return fn() +} + +func mapPrivileges(names []string) ([]uint64, error) { + privileges := make([]uint64, 0, len(names)) + privNameMutex.Lock() + defer privNameMutex.Unlock() + for _, name := range names { + p, ok := privNames[name] + if !ok { + err := lookupPrivilegeValue("", name, &p) + if err != nil { + return nil, err + } + privNames[name] = p + } + privileges = append(privileges, p) + } + return privileges, nil +} + +// EnableProcessPrivileges enables privileges globally for the process. +func EnableProcessPrivileges(names []string) error { + return enableDisableProcessPrivilege(names, SE_PRIVILEGE_ENABLED) +} + +// DisableProcessPrivileges disables privileges globally for the process. +func DisableProcessPrivileges(names []string) error { + return enableDisableProcessPrivilege(names, 0) +} + +func enableDisableProcessPrivilege(names []string, action uint32) error { + privileges, err := mapPrivileges(names) + if err != nil { + return err + } + + p := windows.CurrentProcess() + var token windows.Token + err = windows.OpenProcessToken(p, windows.TOKEN_ADJUST_PRIVILEGES|windows.TOKEN_QUERY, &token) + if err != nil { + return err + } + + defer token.Close() + return adjustPrivileges(token, privileges, action) +} + +func adjustPrivileges(token windows.Token, privileges []uint64, action uint32) error { + var b bytes.Buffer + _ = binary.Write(&b, binary.LittleEndian, uint32(len(privileges))) + for _, p := range privileges { + _ = binary.Write(&b, binary.LittleEndian, p) + _ = binary.Write(&b, binary.LittleEndian, action) + } + prevState := make([]byte, b.Len()) + reqSize := uint32(0) + success, err := adjustTokenPrivileges(token, false, &b.Bytes()[0], uint32(len(prevState)), &prevState[0], &reqSize) + if !success { + return err + } + if err == ERROR_NOT_ALL_ASSIGNED { //nolint:errorlint // err is Errno + return &PrivilegeError{privileges} + } + return nil +} + +func getPrivilegeName(luid uint64) string { + var nameBuffer [256]uint16 + bufSize := uint32(len(nameBuffer)) + err := lookupPrivilegeName("", &luid, &nameBuffer[0], &bufSize) + if err != nil { + return fmt.Sprintf("", luid) + } + + var displayNameBuffer [256]uint16 + displayBufSize := uint32(len(displayNameBuffer)) + var langID uint32 + err = lookupPrivilegeDisplayName("", &nameBuffer[0], &displayNameBuffer[0], &displayBufSize, &langID) + if err != nil { + return fmt.Sprintf("", string(utf16.Decode(nameBuffer[:bufSize]))) + } + + return string(utf16.Decode(displayNameBuffer[:displayBufSize])) +} + +func newThreadToken() (windows.Token, error) { + err := impersonateSelf(windows.SecurityImpersonation) + if err != nil { + return 0, err + } + + var token windows.Token + err = openThreadToken(getCurrentThread(), windows.TOKEN_ADJUST_PRIVILEGES|windows.TOKEN_QUERY, false, &token) + if err != nil { + rerr := revertToSelf() + if rerr != nil { + panic(rerr) + } + return 0, err + } + return token, nil +} + +func releaseThreadToken(h windows.Token) { + err := revertToSelf() + if err != nil { + panic(err) + } + h.Close() +} diff --git a/vendor/github.com/Microsoft/go-winio/reparse.go b/vendor/github.com/Microsoft/go-winio/reparse.go new file mode 100644 index 000000000..67d1a104a --- /dev/null +++ b/vendor/github.com/Microsoft/go-winio/reparse.go @@ -0,0 +1,131 @@ +//go:build windows +// +build windows + +package winio + +import ( + "bytes" + "encoding/binary" + "fmt" + "strings" + "unicode/utf16" + "unsafe" +) + +const ( + reparseTagMountPoint = 0xA0000003 + reparseTagSymlink = 0xA000000C +) + +type reparseDataBuffer struct { + ReparseTag uint32 + ReparseDataLength uint16 + Reserved uint16 + SubstituteNameOffset uint16 + SubstituteNameLength uint16 + PrintNameOffset uint16 + PrintNameLength uint16 +} + +// ReparsePoint describes a Win32 symlink or mount point. +type ReparsePoint struct { + Target string + IsMountPoint bool +} + +// UnsupportedReparsePointError is returned when trying to decode a non-symlink or +// mount point reparse point. +type UnsupportedReparsePointError struct { + Tag uint32 +} + +func (e *UnsupportedReparsePointError) Error() string { + return fmt.Sprintf("unsupported reparse point %x", e.Tag) +} + +// DecodeReparsePoint decodes a Win32 REPARSE_DATA_BUFFER structure containing either a symlink +// or a mount point. +func DecodeReparsePoint(b []byte) (*ReparsePoint, error) { + tag := binary.LittleEndian.Uint32(b[0:4]) + return DecodeReparsePointData(tag, b[8:]) +} + +func DecodeReparsePointData(tag uint32, b []byte) (*ReparsePoint, error) { + isMountPoint := false + switch tag { + case reparseTagMountPoint: + isMountPoint = true + case reparseTagSymlink: + default: + return nil, &UnsupportedReparsePointError{tag} + } + nameOffset := 8 + binary.LittleEndian.Uint16(b[4:6]) + if !isMountPoint { + nameOffset += 4 + } + nameLength := binary.LittleEndian.Uint16(b[6:8]) + name := make([]uint16, nameLength/2) + err := binary.Read(bytes.NewReader(b[nameOffset:nameOffset+nameLength]), binary.LittleEndian, &name) + if err != nil { + return nil, err + } + return &ReparsePoint{string(utf16.Decode(name)), isMountPoint}, nil +} + +func isDriveLetter(c byte) bool { + return (c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z') +} + +// EncodeReparsePoint encodes a Win32 REPARSE_DATA_BUFFER structure describing a symlink or +// mount point. +func EncodeReparsePoint(rp *ReparsePoint) []byte { + // Generate an NT path and determine if this is a relative path. + var ntTarget string + relative := false + if strings.HasPrefix(rp.Target, `\\?\`) { + ntTarget = `\??\` + rp.Target[4:] + } else if strings.HasPrefix(rp.Target, `\\`) { + ntTarget = `\??\UNC\` + rp.Target[2:] + } else if len(rp.Target) >= 2 && isDriveLetter(rp.Target[0]) && rp.Target[1] == ':' { + ntTarget = `\??\` + rp.Target + } else { + ntTarget = rp.Target + relative = true + } + + // The paths must be NUL-terminated even though they are counted strings. + target16 := utf16.Encode([]rune(rp.Target + "\x00")) + ntTarget16 := utf16.Encode([]rune(ntTarget + "\x00")) + + size := int(unsafe.Sizeof(reparseDataBuffer{})) - 8 + size += len(ntTarget16)*2 + len(target16)*2 + + tag := uint32(reparseTagMountPoint) + if !rp.IsMountPoint { + tag = reparseTagSymlink + size += 4 // Add room for symlink flags + } + + data := reparseDataBuffer{ + ReparseTag: tag, + ReparseDataLength: uint16(size), + SubstituteNameOffset: 0, + SubstituteNameLength: uint16((len(ntTarget16) - 1) * 2), + PrintNameOffset: uint16(len(ntTarget16) * 2), + PrintNameLength: uint16((len(target16) - 1) * 2), + } + + var b bytes.Buffer + _ = binary.Write(&b, binary.LittleEndian, &data) + if !rp.IsMountPoint { + flags := uint32(0) + if relative { + flags |= 1 + } + _ = binary.Write(&b, binary.LittleEndian, flags) + } + + _ = binary.Write(&b, binary.LittleEndian, ntTarget16) + _ = binary.Write(&b, binary.LittleEndian, target16) + return b.Bytes() +} diff --git a/vendor/github.com/Microsoft/go-winio/sd.go b/vendor/github.com/Microsoft/go-winio/sd.go new file mode 100644 index 000000000..c3685e98e --- /dev/null +++ b/vendor/github.com/Microsoft/go-winio/sd.go @@ -0,0 +1,133 @@ +//go:build windows +// +build windows + +package winio + +import ( + "errors" + "fmt" + "unsafe" + + "golang.org/x/sys/windows" +) + +//sys lookupAccountName(systemName *uint16, accountName string, sid *byte, sidSize *uint32, refDomain *uint16, refDomainSize *uint32, sidNameUse *uint32) (err error) = advapi32.LookupAccountNameW +//sys lookupAccountSid(systemName *uint16, sid *byte, name *uint16, nameSize *uint32, refDomain *uint16, refDomainSize *uint32, sidNameUse *uint32) (err error) = advapi32.LookupAccountSidW +//sys convertSidToStringSid(sid *byte, str **uint16) (err error) = advapi32.ConvertSidToStringSidW +//sys convertStringSidToSid(str *uint16, sid **byte) (err error) = advapi32.ConvertStringSidToSidW + +type AccountLookupError struct { + Name string + Err error +} + +func (e *AccountLookupError) Error() string { + if e.Name == "" { + return "lookup account: empty account name specified" + } + var s string + switch { + case errors.Is(e.Err, windows.ERROR_INVALID_SID): + s = "the security ID structure is invalid" + case errors.Is(e.Err, windows.ERROR_NONE_MAPPED): + s = "not found" + default: + s = e.Err.Error() + } + return "lookup account " + e.Name + ": " + s +} + +func (e *AccountLookupError) Unwrap() error { return e.Err } + +type SddlConversionError struct { + Sddl string + Err error +} + +func (e *SddlConversionError) Error() string { + return "convert " + e.Sddl + ": " + e.Err.Error() +} + +func (e *SddlConversionError) Unwrap() error { return e.Err } + +// LookupSidByName looks up the SID of an account by name +// +//revive:disable-next-line:var-naming SID, not Sid +func LookupSidByName(name string) (sid string, err error) { + if name == "" { + return "", &AccountLookupError{name, windows.ERROR_NONE_MAPPED} + } + + var sidSize, sidNameUse, refDomainSize uint32 + err = lookupAccountName(nil, name, nil, &sidSize, nil, &refDomainSize, &sidNameUse) + if err != nil && err != windows.ERROR_INSUFFICIENT_BUFFER { //nolint:errorlint // err is Errno + return "", &AccountLookupError{name, err} + } + sidBuffer := make([]byte, sidSize) + refDomainBuffer := make([]uint16, refDomainSize) + err = lookupAccountName(nil, name, &sidBuffer[0], &sidSize, &refDomainBuffer[0], &refDomainSize, &sidNameUse) + if err != nil { + return "", &AccountLookupError{name, err} + } + var strBuffer *uint16 + err = convertSidToStringSid(&sidBuffer[0], &strBuffer) + if err != nil { + return "", &AccountLookupError{name, err} + } + sid = windows.UTF16ToString((*[0xffff]uint16)(unsafe.Pointer(strBuffer))[:]) + _, _ = windows.LocalFree(windows.Handle(unsafe.Pointer(strBuffer))) + return sid, nil +} + +// LookupNameBySid looks up the name of an account by SID +// +//revive:disable-next-line:var-naming SID, not Sid +func LookupNameBySid(sid string) (name string, err error) { + if sid == "" { + return "", &AccountLookupError{sid, windows.ERROR_NONE_MAPPED} + } + + sidBuffer, err := windows.UTF16PtrFromString(sid) + if err != nil { + return "", &AccountLookupError{sid, err} + } + + var sidPtr *byte + if err = convertStringSidToSid(sidBuffer, &sidPtr); err != nil { + return "", &AccountLookupError{sid, err} + } + defer windows.LocalFree(windows.Handle(unsafe.Pointer(sidPtr))) //nolint:errcheck + + var nameSize, refDomainSize, sidNameUse uint32 + err = lookupAccountSid(nil, sidPtr, nil, &nameSize, nil, &refDomainSize, &sidNameUse) + if err != nil && err != windows.ERROR_INSUFFICIENT_BUFFER { //nolint:errorlint // err is Errno + return "", &AccountLookupError{sid, err} + } + + nameBuffer := make([]uint16, nameSize) + refDomainBuffer := make([]uint16, refDomainSize) + err = lookupAccountSid(nil, sidPtr, &nameBuffer[0], &nameSize, &refDomainBuffer[0], &refDomainSize, &sidNameUse) + if err != nil { + return "", &AccountLookupError{sid, err} + } + + name = windows.UTF16ToString(nameBuffer) + return name, nil +} + +func SddlToSecurityDescriptor(sddl string) ([]byte, error) { + sd, err := windows.SecurityDescriptorFromString(sddl) + if err != nil { + return nil, &SddlConversionError{Sddl: sddl, Err: err} + } + b := unsafe.Slice((*byte)(unsafe.Pointer(sd)), sd.Length()) + return b, nil +} + +func SecurityDescriptorToSddl(sd []byte) (string, error) { + if l := int(unsafe.Sizeof(windows.SECURITY_DESCRIPTOR{})); len(sd) < l { + return "", fmt.Errorf("SecurityDescriptor (%d) smaller than expected (%d): %w", len(sd), l, windows.ERROR_INCORRECT_SIZE) + } + s := (*windows.SECURITY_DESCRIPTOR)(unsafe.Pointer(&sd[0])) + return s.String(), nil +} diff --git a/vendor/github.com/Microsoft/go-winio/syscall.go b/vendor/github.com/Microsoft/go-winio/syscall.go new file mode 100644 index 000000000..a6ca111b3 --- /dev/null +++ b/vendor/github.com/Microsoft/go-winio/syscall.go @@ -0,0 +1,5 @@ +//go:build windows + +package winio + +//go:generate go run github.com/Microsoft/go-winio/tools/mkwinsyscall -output zsyscall_windows.go ./*.go diff --git a/vendor/github.com/Microsoft/go-winio/zsyscall_windows.go b/vendor/github.com/Microsoft/go-winio/zsyscall_windows.go new file mode 100644 index 000000000..89b66eda8 --- /dev/null +++ b/vendor/github.com/Microsoft/go-winio/zsyscall_windows.go @@ -0,0 +1,378 @@ +//go:build windows + +// Code generated by 'go generate' using "github.com/Microsoft/go-winio/tools/mkwinsyscall"; DO NOT EDIT. + +package winio + +import ( + "syscall" + "unsafe" + + "golang.org/x/sys/windows" +) + +var _ unsafe.Pointer + +// Do the interface allocations only once for common +// Errno values. +const ( + errnoERROR_IO_PENDING = 997 +) + +var ( + errERROR_IO_PENDING error = syscall.Errno(errnoERROR_IO_PENDING) + errERROR_EINVAL error = syscall.EINVAL +) + +// errnoErr returns common boxed Errno values, to prevent +// allocations at runtime. +func errnoErr(e syscall.Errno) error { + switch e { + case 0: + return errERROR_EINVAL + case errnoERROR_IO_PENDING: + return errERROR_IO_PENDING + } + return e +} + +var ( + modadvapi32 = windows.NewLazySystemDLL("advapi32.dll") + modkernel32 = windows.NewLazySystemDLL("kernel32.dll") + modntdll = windows.NewLazySystemDLL("ntdll.dll") + modws2_32 = windows.NewLazySystemDLL("ws2_32.dll") + + procAdjustTokenPrivileges = modadvapi32.NewProc("AdjustTokenPrivileges") + procConvertSidToStringSidW = modadvapi32.NewProc("ConvertSidToStringSidW") + procConvertStringSidToSidW = modadvapi32.NewProc("ConvertStringSidToSidW") + procImpersonateSelf = modadvapi32.NewProc("ImpersonateSelf") + procLookupAccountNameW = modadvapi32.NewProc("LookupAccountNameW") + procLookupAccountSidW = modadvapi32.NewProc("LookupAccountSidW") + procLookupPrivilegeDisplayNameW = modadvapi32.NewProc("LookupPrivilegeDisplayNameW") + procLookupPrivilegeNameW = modadvapi32.NewProc("LookupPrivilegeNameW") + procLookupPrivilegeValueW = modadvapi32.NewProc("LookupPrivilegeValueW") + procOpenThreadToken = modadvapi32.NewProc("OpenThreadToken") + procRevertToSelf = modadvapi32.NewProc("RevertToSelf") + procBackupRead = modkernel32.NewProc("BackupRead") + procBackupWrite = modkernel32.NewProc("BackupWrite") + procCancelIoEx = modkernel32.NewProc("CancelIoEx") + procConnectNamedPipe = modkernel32.NewProc("ConnectNamedPipe") + procCreateIoCompletionPort = modkernel32.NewProc("CreateIoCompletionPort") + procCreateNamedPipeW = modkernel32.NewProc("CreateNamedPipeW") + procDisconnectNamedPipe = modkernel32.NewProc("DisconnectNamedPipe") + procGetCurrentThread = modkernel32.NewProc("GetCurrentThread") + procGetNamedPipeHandleStateW = modkernel32.NewProc("GetNamedPipeHandleStateW") + procGetNamedPipeInfo = modkernel32.NewProc("GetNamedPipeInfo") + procGetQueuedCompletionStatus = modkernel32.NewProc("GetQueuedCompletionStatus") + procSetFileCompletionNotificationModes = modkernel32.NewProc("SetFileCompletionNotificationModes") + procNtCreateNamedPipeFile = modntdll.NewProc("NtCreateNamedPipeFile") + procRtlDefaultNpAcl = modntdll.NewProc("RtlDefaultNpAcl") + procRtlDosPathNameToNtPathName_U = modntdll.NewProc("RtlDosPathNameToNtPathName_U") + procRtlNtStatusToDosErrorNoTeb = modntdll.NewProc("RtlNtStatusToDosErrorNoTeb") + procWSAGetOverlappedResult = modws2_32.NewProc("WSAGetOverlappedResult") +) + +func adjustTokenPrivileges(token windows.Token, releaseAll bool, input *byte, outputSize uint32, output *byte, requiredSize *uint32) (success bool, err error) { + var _p0 uint32 + if releaseAll { + _p0 = 1 + } + r0, _, e1 := syscall.SyscallN(procAdjustTokenPrivileges.Addr(), uintptr(token), uintptr(_p0), uintptr(unsafe.Pointer(input)), uintptr(outputSize), uintptr(unsafe.Pointer(output)), uintptr(unsafe.Pointer(requiredSize))) + success = r0 != 0 + if true { + err = errnoErr(e1) + } + return +} + +func convertSidToStringSid(sid *byte, str **uint16) (err error) { + r1, _, e1 := syscall.SyscallN(procConvertSidToStringSidW.Addr(), uintptr(unsafe.Pointer(sid)), uintptr(unsafe.Pointer(str))) + if r1 == 0 { + err = errnoErr(e1) + } + return +} + +func convertStringSidToSid(str *uint16, sid **byte) (err error) { + r1, _, e1 := syscall.SyscallN(procConvertStringSidToSidW.Addr(), uintptr(unsafe.Pointer(str)), uintptr(unsafe.Pointer(sid))) + if r1 == 0 { + err = errnoErr(e1) + } + return +} + +func impersonateSelf(level uint32) (err error) { + r1, _, e1 := syscall.SyscallN(procImpersonateSelf.Addr(), uintptr(level)) + if r1 == 0 { + err = errnoErr(e1) + } + return +} + +func lookupAccountName(systemName *uint16, accountName string, sid *byte, sidSize *uint32, refDomain *uint16, refDomainSize *uint32, sidNameUse *uint32) (err error) { + var _p0 *uint16 + _p0, err = syscall.UTF16PtrFromString(accountName) + if err != nil { + return + } + return _lookupAccountName(systemName, _p0, sid, sidSize, refDomain, refDomainSize, sidNameUse) +} + +func _lookupAccountName(systemName *uint16, accountName *uint16, sid *byte, sidSize *uint32, refDomain *uint16, refDomainSize *uint32, sidNameUse *uint32) (err error) { + r1, _, e1 := syscall.SyscallN(procLookupAccountNameW.Addr(), uintptr(unsafe.Pointer(systemName)), uintptr(unsafe.Pointer(accountName)), uintptr(unsafe.Pointer(sid)), uintptr(unsafe.Pointer(sidSize)), uintptr(unsafe.Pointer(refDomain)), uintptr(unsafe.Pointer(refDomainSize)), uintptr(unsafe.Pointer(sidNameUse))) + if r1 == 0 { + err = errnoErr(e1) + } + return +} + +func lookupAccountSid(systemName *uint16, sid *byte, name *uint16, nameSize *uint32, refDomain *uint16, refDomainSize *uint32, sidNameUse *uint32) (err error) { + r1, _, e1 := syscall.SyscallN(procLookupAccountSidW.Addr(), uintptr(unsafe.Pointer(systemName)), uintptr(unsafe.Pointer(sid)), uintptr(unsafe.Pointer(name)), uintptr(unsafe.Pointer(nameSize)), uintptr(unsafe.Pointer(refDomain)), uintptr(unsafe.Pointer(refDomainSize)), uintptr(unsafe.Pointer(sidNameUse))) + if r1 == 0 { + err = errnoErr(e1) + } + return +} + +func lookupPrivilegeDisplayName(systemName string, name *uint16, buffer *uint16, size *uint32, languageId *uint32) (err error) { + var _p0 *uint16 + _p0, err = syscall.UTF16PtrFromString(systemName) + if err != nil { + return + } + return _lookupPrivilegeDisplayName(_p0, name, buffer, size, languageId) +} + +func _lookupPrivilegeDisplayName(systemName *uint16, name *uint16, buffer *uint16, size *uint32, languageId *uint32) (err error) { + r1, _, e1 := syscall.SyscallN(procLookupPrivilegeDisplayNameW.Addr(), uintptr(unsafe.Pointer(systemName)), uintptr(unsafe.Pointer(name)), uintptr(unsafe.Pointer(buffer)), uintptr(unsafe.Pointer(size)), uintptr(unsafe.Pointer(languageId))) + if r1 == 0 { + err = errnoErr(e1) + } + return +} + +func lookupPrivilegeName(systemName string, luid *uint64, buffer *uint16, size *uint32) (err error) { + var _p0 *uint16 + _p0, err = syscall.UTF16PtrFromString(systemName) + if err != nil { + return + } + return _lookupPrivilegeName(_p0, luid, buffer, size) +} + +func _lookupPrivilegeName(systemName *uint16, luid *uint64, buffer *uint16, size *uint32) (err error) { + r1, _, e1 := syscall.SyscallN(procLookupPrivilegeNameW.Addr(), uintptr(unsafe.Pointer(systemName)), uintptr(unsafe.Pointer(luid)), uintptr(unsafe.Pointer(buffer)), uintptr(unsafe.Pointer(size))) + if r1 == 0 { + err = errnoErr(e1) + } + return +} + +func lookupPrivilegeValue(systemName string, name string, luid *uint64) (err error) { + var _p0 *uint16 + _p0, err = syscall.UTF16PtrFromString(systemName) + if err != nil { + return + } + var _p1 *uint16 + _p1, err = syscall.UTF16PtrFromString(name) + if err != nil { + return + } + return _lookupPrivilegeValue(_p0, _p1, luid) +} + +func _lookupPrivilegeValue(systemName *uint16, name *uint16, luid *uint64) (err error) { + r1, _, e1 := syscall.SyscallN(procLookupPrivilegeValueW.Addr(), uintptr(unsafe.Pointer(systemName)), uintptr(unsafe.Pointer(name)), uintptr(unsafe.Pointer(luid))) + if r1 == 0 { + err = errnoErr(e1) + } + return +} + +func openThreadToken(thread windows.Handle, accessMask uint32, openAsSelf bool, token *windows.Token) (err error) { + var _p0 uint32 + if openAsSelf { + _p0 = 1 + } + r1, _, e1 := syscall.SyscallN(procOpenThreadToken.Addr(), uintptr(thread), uintptr(accessMask), uintptr(_p0), uintptr(unsafe.Pointer(token))) + if r1 == 0 { + err = errnoErr(e1) + } + return +} + +func revertToSelf() (err error) { + r1, _, e1 := syscall.SyscallN(procRevertToSelf.Addr()) + if r1 == 0 { + err = errnoErr(e1) + } + return +} + +func backupRead(h windows.Handle, b []byte, bytesRead *uint32, abort bool, processSecurity bool, context *uintptr) (err error) { + var _p0 *byte + if len(b) > 0 { + _p0 = &b[0] + } + var _p1 uint32 + if abort { + _p1 = 1 + } + var _p2 uint32 + if processSecurity { + _p2 = 1 + } + r1, _, e1 := syscall.SyscallN(procBackupRead.Addr(), uintptr(h), uintptr(unsafe.Pointer(_p0)), uintptr(len(b)), uintptr(unsafe.Pointer(bytesRead)), uintptr(_p1), uintptr(_p2), uintptr(unsafe.Pointer(context))) + if r1 == 0 { + err = errnoErr(e1) + } + return +} + +func backupWrite(h windows.Handle, b []byte, bytesWritten *uint32, abort bool, processSecurity bool, context *uintptr) (err error) { + var _p0 *byte + if len(b) > 0 { + _p0 = &b[0] + } + var _p1 uint32 + if abort { + _p1 = 1 + } + var _p2 uint32 + if processSecurity { + _p2 = 1 + } + r1, _, e1 := syscall.SyscallN(procBackupWrite.Addr(), uintptr(h), uintptr(unsafe.Pointer(_p0)), uintptr(len(b)), uintptr(unsafe.Pointer(bytesWritten)), uintptr(_p1), uintptr(_p2), uintptr(unsafe.Pointer(context))) + if r1 == 0 { + err = errnoErr(e1) + } + return +} + +func cancelIoEx(file windows.Handle, o *windows.Overlapped) (err error) { + r1, _, e1 := syscall.SyscallN(procCancelIoEx.Addr(), uintptr(file), uintptr(unsafe.Pointer(o))) + if r1 == 0 { + err = errnoErr(e1) + } + return +} + +func connectNamedPipe(pipe windows.Handle, o *windows.Overlapped) (err error) { + r1, _, e1 := syscall.SyscallN(procConnectNamedPipe.Addr(), uintptr(pipe), uintptr(unsafe.Pointer(o))) + if r1 == 0 { + err = errnoErr(e1) + } + return +} + +func createIoCompletionPort(file windows.Handle, port windows.Handle, key uintptr, threadCount uint32) (newport windows.Handle, err error) { + r0, _, e1 := syscall.SyscallN(procCreateIoCompletionPort.Addr(), uintptr(file), uintptr(port), uintptr(key), uintptr(threadCount)) + newport = windows.Handle(r0) + if newport == 0 { + err = errnoErr(e1) + } + return +} + +func createNamedPipe(name string, flags uint32, pipeMode uint32, maxInstances uint32, outSize uint32, inSize uint32, defaultTimeout uint32, sa *windows.SecurityAttributes) (handle windows.Handle, err error) { + var _p0 *uint16 + _p0, err = syscall.UTF16PtrFromString(name) + if err != nil { + return + } + return _createNamedPipe(_p0, flags, pipeMode, maxInstances, outSize, inSize, defaultTimeout, sa) +} + +func _createNamedPipe(name *uint16, flags uint32, pipeMode uint32, maxInstances uint32, outSize uint32, inSize uint32, defaultTimeout uint32, sa *windows.SecurityAttributes) (handle windows.Handle, err error) { + r0, _, e1 := syscall.SyscallN(procCreateNamedPipeW.Addr(), uintptr(unsafe.Pointer(name)), uintptr(flags), uintptr(pipeMode), uintptr(maxInstances), uintptr(outSize), uintptr(inSize), uintptr(defaultTimeout), uintptr(unsafe.Pointer(sa))) + handle = windows.Handle(r0) + if handle == windows.InvalidHandle { + err = errnoErr(e1) + } + return +} + +func disconnectNamedPipe(pipe windows.Handle) (err error) { + r1, _, e1 := syscall.SyscallN(procDisconnectNamedPipe.Addr(), uintptr(pipe)) + if r1 == 0 { + err = errnoErr(e1) + } + return +} + +func getCurrentThread() (h windows.Handle) { + r0, _, _ := syscall.SyscallN(procGetCurrentThread.Addr()) + h = windows.Handle(r0) + return +} + +func getNamedPipeHandleState(pipe windows.Handle, state *uint32, curInstances *uint32, maxCollectionCount *uint32, collectDataTimeout *uint32, userName *uint16, maxUserNameSize uint32) (err error) { + r1, _, e1 := syscall.SyscallN(procGetNamedPipeHandleStateW.Addr(), uintptr(pipe), uintptr(unsafe.Pointer(state)), uintptr(unsafe.Pointer(curInstances)), uintptr(unsafe.Pointer(maxCollectionCount)), uintptr(unsafe.Pointer(collectDataTimeout)), uintptr(unsafe.Pointer(userName)), uintptr(maxUserNameSize)) + if r1 == 0 { + err = errnoErr(e1) + } + return +} + +func getNamedPipeInfo(pipe windows.Handle, flags *uint32, outSize *uint32, inSize *uint32, maxInstances *uint32) (err error) { + r1, _, e1 := syscall.SyscallN(procGetNamedPipeInfo.Addr(), uintptr(pipe), uintptr(unsafe.Pointer(flags)), uintptr(unsafe.Pointer(outSize)), uintptr(unsafe.Pointer(inSize)), uintptr(unsafe.Pointer(maxInstances))) + if r1 == 0 { + err = errnoErr(e1) + } + return +} + +func getQueuedCompletionStatus(port windows.Handle, bytes *uint32, key *uintptr, o **ioOperation, timeout uint32) (err error) { + r1, _, e1 := syscall.SyscallN(procGetQueuedCompletionStatus.Addr(), uintptr(port), uintptr(unsafe.Pointer(bytes)), uintptr(unsafe.Pointer(key)), uintptr(unsafe.Pointer(o)), uintptr(timeout)) + if r1 == 0 { + err = errnoErr(e1) + } + return +} + +func setFileCompletionNotificationModes(h windows.Handle, flags uint8) (err error) { + r1, _, e1 := syscall.SyscallN(procSetFileCompletionNotificationModes.Addr(), uintptr(h), uintptr(flags)) + if r1 == 0 { + err = errnoErr(e1) + } + return +} + +func ntCreateNamedPipeFile(pipe *windows.Handle, access ntAccessMask, oa *objectAttributes, iosb *ioStatusBlock, share ntFileShareMode, disposition ntFileCreationDisposition, options ntFileOptions, typ uint32, readMode uint32, completionMode uint32, maxInstances uint32, inboundQuota uint32, outputQuota uint32, timeout *int64) (status ntStatus) { + r0, _, _ := syscall.SyscallN(procNtCreateNamedPipeFile.Addr(), uintptr(unsafe.Pointer(pipe)), uintptr(access), uintptr(unsafe.Pointer(oa)), uintptr(unsafe.Pointer(iosb)), uintptr(share), uintptr(disposition), uintptr(options), uintptr(typ), uintptr(readMode), uintptr(completionMode), uintptr(maxInstances), uintptr(inboundQuota), uintptr(outputQuota), uintptr(unsafe.Pointer(timeout))) + status = ntStatus(r0) + return +} + +func rtlDefaultNpAcl(dacl *uintptr) (status ntStatus) { + r0, _, _ := syscall.SyscallN(procRtlDefaultNpAcl.Addr(), uintptr(unsafe.Pointer(dacl))) + status = ntStatus(r0) + return +} + +func rtlDosPathNameToNtPathName(name *uint16, ntName *unicodeString, filePart uintptr, reserved uintptr) (status ntStatus) { + r0, _, _ := syscall.SyscallN(procRtlDosPathNameToNtPathName_U.Addr(), uintptr(unsafe.Pointer(name)), uintptr(unsafe.Pointer(ntName)), uintptr(filePart), uintptr(reserved)) + status = ntStatus(r0) + return +} + +func rtlNtStatusToDosError(status ntStatus) (winerr error) { + r0, _, _ := syscall.SyscallN(procRtlNtStatusToDosErrorNoTeb.Addr(), uintptr(status)) + if r0 != 0 { + winerr = syscall.Errno(r0) + } + return +} + +func wsaGetOverlappedResult(h windows.Handle, o *windows.Overlapped, bytes *uint32, wait bool, flags *uint32) (err error) { + var _p0 uint32 + if wait { + _p0 = 1 + } + r1, _, e1 := syscall.SyscallN(procWSAGetOverlappedResult.Addr(), uintptr(h), uintptr(unsafe.Pointer(o)), uintptr(unsafe.Pointer(bytes)), uintptr(_p0), uintptr(unsafe.Pointer(flags))) + if r1 == 0 { + err = errnoErr(e1) + } + return +} diff --git a/vendor/modules.txt b/vendor/modules.txt index 5ebb4bde1..890e0b984 100644 --- a/vendor/modules.txt +++ b/vendor/modules.txt @@ -33,9 +33,16 @@ github.com/Azure/go-autorest/logger # github.com/Azure/go-autorest/tracing v0.5.0 ## explicit; go 1.12 github.com/Azure/go-autorest/tracing -# github.com/DataDog/datadog-go v3.7.2+incompatible +# github.com/DataDog/datadog-go v4.8.3+incompatible ## explicit github.com/DataDog/datadog-go/statsd +# github.com/Microsoft/go-winio v0.6.2 +## explicit; go 1.21 +github.com/Microsoft/go-winio +github.com/Microsoft/go-winio/internal/fs +github.com/Microsoft/go-winio/internal/socket +github.com/Microsoft/go-winio/internal/stringbuffer +github.com/Microsoft/go-winio/pkg/guid # github.com/aws/aws-lambda-go v1.47.0 ## explicit; go 1.18 github.com/aws/aws-lambda-go/lambda