-
Notifications
You must be signed in to change notification settings - Fork 1
/
ctxtrace.go
188 lines (163 loc) · 5.56 KB
/
ctxtrace.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
package ctxtrace
import (
	"context"
	"encoding/binary"
	"net/http"

	grpc_middleware "github.com/grpc-ecosystem/go-grpc-middleware"
	grpc_ctxtags "github.com/grpc-ecosystem/go-grpc-middleware/tags"
	"github.com/openzipkin/zipkin-go/model"
	"github.com/openzipkin/zipkin-go/propagation/b3"
	"go.opentelemetry.io/otel/trace"
	"go.uber.org/zap"
	"google.golang.org/grpc"
	"google.golang.org/grpc/metadata"
)
// headerRequestID is the HTTP header / gRPC metadata key that carries the
// caller's request ID.
const headerRequestID = "x-request-id"
// TraceData holds the request-scoped tracing values propagated by this
// package: the caller's request ID and the B3 span context.
type TraceData struct {
	// RequestID is the value of the "x-request-id" header/metadata key;
	// empty when the caller did not send one.
	RequestID string
	// TraceSpan is the B3 span context extracted from the caller; nil when
	// extraction failed or was never attempted.
	TraceSpan *model.SpanContext
}
// traceCtxMarker is the unexported context key under which TraceData is
// stored; a private type keeps it from colliding with other packages' keys.
type traceCtxMarker struct{}
// UnaryServerInterceptor returns a gRPC unary server interceptor that reads
// the caller's trace data (B3 span context and request ID) from the incoming
// metadata and stores it in the context handed to the handler.
func UnaryServerInterceptor() grpc.UnaryServerInterceptor {
	interceptor := func(ctx context.Context, req interface{}, _ *grpc.UnaryServerInfo, next grpc.UnaryHandler) (interface{}, error) {
		enriched := extractMetadataToContext(ctx)
		return next(enriched, req)
	}
	return interceptor
}
// StreamServerInterceptor returns a gRPC stream server interceptor that reads
// the caller's trace data from the incoming metadata and exposes it through
// the stream's context. The metadata is read once, when the stream is opened,
// not per message.
func StreamServerInterceptor() grpc.StreamServerInterceptor {
	return func(srv interface{}, ss grpc.ServerStream, _ *grpc.StreamServerInfo, next grpc.StreamHandler) error {
		streamCtx := ss.Context()
		ws := grpc_middleware.WrapServerStream(ss)
		ws.WrappedContext = extractMetadataToContext(streamCtx)
		return next(srv, ws)
	}
}
// UnaryClientInterceptor returns a gRPC unary client interceptor that copies
// the TraceData stored in the call's context into the outgoing metadata.
func UnaryClientInterceptor() grpc.UnaryClientInterceptor {
	return func(ctx context.Context, method string, req, reply interface{},
		cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
		return invoker(NewOutgoingContextWithData(ctx), method, req, reply, cc, opts...)
	}
}
// StreamClientInterceptor returns a gRPC stream client interceptor that
// copies the TraceData stored in the call's context into the outgoing
// metadata before the stream is opened.
func StreamClientInterceptor() grpc.StreamClientInterceptor {
	interceptor := func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn,
		method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) {
		outCtx := NewOutgoingContextWithData(ctx)
		return streamer(outCtx, desc, cc, method, opts...)
	}
	return interceptor
}
// Extract returns the TraceData stored in ctx by this package, or the zero
// TraceData when ctx carries none.
func Extract(ctx context.Context) TraceData {
	if stored, found := ctx.Value(traceCtxMarker{}).(TraceData); found {
		return stored
	}
	return TraceData{}
}
// ExtractHTTP reads trace data from a plain HTTP request: the "x-request-id"
// header and the B3 span context headers.
//
// On a B3 parse error the returned TraceData still carries the request ID
// (TraceSpan is left nil) together with the error, so callers can keep the
// partial result.
func ExtractHTTP(r *http.Request) (TraceData, error) {
	var result TraceData
	result.RequestID = r.Header.Get(headerRequestID)

	extract := b3.ExtractHTTP(r)
	sc, err := extract()
	if err == nil {
		result.TraceSpan = sc
	}
	return result, err
}
// ExtractHTTPToContext reads trace data from an HTTP request (see ExtractHTTP)
// and returns a copy of ctx carrying it.
func ExtractHTTPToContext(ctx context.Context, r *http.Request) context.Context {
	// Best effort: a malformed B3 header should not stop the request. On
	// error ExtractHTTP still returns the request ID, so the error is
	// deliberately dropped and whatever was recovered is stored.
	td, _ := ExtractHTTP(r)
	return WithValue(ctx, td)
}
// addOtelSpanContextToContext converts the B3 span context in traceData into
// an OpenTelemetry remote span context and attaches it to ctx, so that
// OpenTelemetry spans started from the returned context continue the caller's
// trace.
//
// ctx is returned unchanged when traceData has no span context, or when the
// converted trace/span IDs do not form a valid span context (e.g. all zero).
func addOtelSpanContextToContext(ctx context.Context, traceData TraceData) context.Context {
	sc := traceData.TraceSpan
	if sc == nil {
		return ctx
	}

	// model.TraceID stores B3 trace IDs as High/Low halves (High == 0 for
	// 64-bit IDs). Writing both halves big-endian gives the same bytes as the
	// 32-char hex form for 128-bit IDs, and left-pads 64-bit IDs with zeros
	// to the 128 bits OpenTelemetry requires. Going through TraceID.String()
	// yields only 16 hex chars for 64-bit IDs, which trace.TraceIDFromHex
	// rejects.
	var traceID trace.TraceID
	binary.BigEndian.PutUint64(traceID[:8], sc.TraceID.High)
	binary.BigEndian.PutUint64(traceID[8:], sc.TraceID.Low)

	var spanID trace.SpanID
	binary.BigEndian.PutUint64(spanID[:], uint64(sc.ID))

	// Sampled is a *bool and is nil when the caller sent no sampling decision
	// or sent the debug flag, so it must not be dereferenced unconditionally.
	// Per the B3 spec, debug implies an accept (sampled) decision.
	var traceFlags trace.TraceFlags
	if sc.Debug || (sc.Sampled != nil && *sc.Sampled) {
		traceFlags = trace.FlagsSampled
	}

	// TODO: propagate tracestate.
	spanContext := trace.NewSpanContext(trace.SpanContextConfig{
		TraceID:    traceID,
		SpanID:     spanID,
		TraceFlags: traceFlags,
		Remote:     true,
	})
	if !spanContext.IsValid() {
		return ctx
	}
	return trace.ContextWithRemoteSpanContext(ctx, spanContext)
}
// extractMetadataToContext reads the caller's trace data from the incoming
// gRPC metadata in ctx and returns a context that carries it.
//
// The returned context:
//   - holds the TraceData (retrievable with Extract);
//   - carries an OpenTelemetry remote span context when the B3 span context
//     parsed successfully;
//   - has the "request_id" grpc_ctxtags tag set when a request ID was sent.
//
// B3 parse failures are logged as warnings and otherwise ignored. If ctx has
// no incoming metadata it is returned unchanged.
func extractMetadataToContext(ctx context.Context) context.Context {
	md, present := metadata.FromIncomingContext(ctx)
	if !present {
		return ctx
	}

	var td TraceData
	if sc, err := b3.ExtractGRPC(&md)(); err == nil {
		td.TraceSpan = sc
		ctx = addOtelSpanContextToContext(ctx, td)
	} else {
		zap.L().Warn("b3 extract failed", zap.Error(err))
	}

	if ids := md[headerRequestID]; len(ids) > 0 {
		first := ids[0]
		td.RequestID = first
		grpc_ctxtags.Extract(ctx).Set("request_id", first)
	}

	return WithValue(ctx, td)
}
// NewOutgoingContextWithData creates a new context with the metadata added
func NewOutgoingContextWithData(ctx context.Context) context.Context {
md := InjectDataIntoOutMetadata(ctx, Extract(ctx))
return metadata.NewOutgoingContext(ctx, md)
}
// InjectDataIntoOutMetadata returns outgoing gRPC metadata with data (its B3
// span context and request ID) written into it.
//
// ctx supplies only the base metadata: its existing outgoing metadata is
// copied (or an empty MD is used when there is none). The trace values come
// from data, not from any TraceData stored in ctx.
func InjectDataIntoOutMetadata(ctx context.Context, data TraceData) metadata.MD {
	md, ok := metadata.FromOutgoingContext(ctx)
	if !ok {
		md = metadata.New(nil)
	}
	// Inject the caller-supplied data. The previous implementation ignored
	// this parameter and re-read Extract(ctx), silently discarding the data
	// passed in.
	packCallerMetadata(&md, data)
	return md
}
// packCallerMetadata writes the caller-specific values in data into *m so
// they are propagated with an outgoing gRPC request:
//   - the B3 span context, when TraceSpan is non-nil;
//   - the "x-request-id" key, when RequestID is non-empty.
//
// B3 injection errors are logged as warnings and otherwise ignored.
func packCallerMetadata(m *metadata.MD, data TraceData) {
	if m == nil {
		// A nil pointer is a programming error in this package. Log it with a
		// stack trace, but do not terminate the host process: zap's Fatal
		// calls os.Exit, which library code must not do.
		zap.L().Error("metadata is nil", zap.Stack("stack"))
		return
	}
	if *m == nil {
		// Writing into a nil map panics; start from an empty MD instead.
		*m = metadata.MD{}
	}

	if data.TraceSpan != nil {
		if err := b3.InjectGRPC(m)(*data.TraceSpan); err != nil {
			zap.L().Warn("b3 injection failed", zap.Error(err))
		}
	}
	if data.RequestID != "" {
		m.Set(headerRequestID, data.RequestID)
	}
}
// WithValue returns a copy of ctx that carries traceData; Extract on the
// returned context yields traceData.
func WithValue(ctx context.Context, traceData TraceData) context.Context {
	var key traceCtxMarker
	return context.WithValue(ctx, key, traceData)
}