Skip to content

Commit 3a6a7fb

Browse files
client: use set instead of append for grpc headers
Signed-off-by: huanghaoyuanhhy <[email protected]>
1 parent 1020567 commit 3a6a7fb

File tree

2 files changed

+58
-17
lines changed

2 files changed

+58
-17
lines changed

internal/client/milvus/grpc.go

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -317,19 +317,31 @@ func NewGrpc(cfg *paramtable.MilvusConfig) (*GrpcClient, error) {
317317
return cli, nil
318318
}
319319

320-
func (g *GrpcClient) newCtx(ctx context.Context) context.Context {
320+
func (g *GrpcClient) newAuthMD(ctx context.Context) metadata.MD {
321+
md := metadata.MD{}
322+
if outgoingMD, ok := metadata.FromOutgoingContext(ctx); ok {
323+
md = outgoingMD.Copy()
324+
}
325+
321326
if g.auth != "" {
322-
ctx = metadata.AppendToOutgoingContext(ctx, authorizationHeader, g.auth)
327+
md.Set(authorizationHeader, g.auth)
323328
}
324329
if g.identifier != "" {
325-
ctx = metadata.AppendToOutgoingContext(ctx, identifierHeader, g.identifier)
330+
md.Set(identifierHeader, g.identifier)
326331
}
327-
return ctx
332+
333+
return md
334+
}
335+
336+
func (g *GrpcClient) newCtx(ctx context.Context) context.Context {
337+
return metadata.NewOutgoingContext(ctx, g.newAuthMD(ctx))
328338
}
329339

330340
func (g *GrpcClient) newCtxWithDB(ctx context.Context, db string) context.Context {
331-
ctx = g.newCtx(ctx)
332-
return metadata.AppendToOutgoingContext(ctx, databaseHeader, db)
341+
md := g.newAuthMD(ctx)
342+
md.Set(databaseHeader, db)
343+
344+
return metadata.NewOutgoingContext(ctx, md)
333345
}
334346

335347
func (g *GrpcClient) connect(ctx context.Context) error {

internal/client/milvus/grpc_test.go

Lines changed: 40 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -76,20 +76,49 @@ func TestIsRateLimitError(t *testing.T) {
7676
}
7777

7878
func TestGrpcClient_newCtx(t *testing.T) {
79-
cli := &GrpcClient{auth: "auth", identifier: "identifier"}
80-
ctx := cli.newCtx(context.Background())
81-
md, ok := metadata.FromOutgoingContext(ctx)
82-
assert.True(t, ok)
83-
assert.Equal(t, "auth", md.Get(authorizationHeader)[0])
84-
assert.Equal(t, "identifier", md.Get(identifierHeader)[0])
79+
t.Run("Normal", func(t *testing.T) {
80+
cli := &GrpcClient{auth: "auth", identifier: "identifier"}
81+
ctx := cli.newCtx(context.Background())
82+
md, ok := metadata.FromOutgoingContext(ctx)
83+
assert.True(t, ok)
84+
assert.Equal(t, "auth", md.Get(authorizationHeader)[0])
85+
assert.Len(t, md.Get(authorizationHeader), 1)
86+
assert.Equal(t, "identifier", md.Get(identifierHeader)[0])
87+
assert.Len(t, md.Get(identifierHeader), 1)
88+
})
89+
90+
t.Run("SetMultipleTimes", func(t *testing.T) {
91+
cli := &GrpcClient{auth: "auth", identifier: "identifier"}
92+
ctx := cli.newCtx(context.Background())
93+
ctx = cli.newCtx(ctx)
94+
md, ok := metadata.FromOutgoingContext(ctx)
95+
assert.True(t, ok)
96+
assert.Equal(t, "auth", md.Get(authorizationHeader)[0])
97+
assert.Len(t, md.Get(authorizationHeader), 1)
98+
assert.Equal(t, "identifier", md.Get(identifierHeader)[0])
99+
assert.Len(t, md.Get(identifierHeader), 1)
100+
})
101+
85102
}
86103

87104
func TestGrpcClient_newCtxWithDB(t *testing.T) {
88-
cli := &GrpcClient{}
89-
ctx := cli.newCtxWithDB(context.Background(), "db")
90-
md, ok := metadata.FromOutgoingContext(ctx)
91-
assert.True(t, ok)
92-
assert.Equal(t, "db", md.Get(databaseHeader)[0])
105+
t.Run("Normal", func(t *testing.T) {
106+
cli := &GrpcClient{}
107+
ctx := cli.newCtxWithDB(context.Background(), "db")
108+
md, ok := metadata.FromOutgoingContext(ctx)
109+
assert.True(t, ok)
110+
assert.Equal(t, "db", md.Get(databaseHeader)[0])
111+
})
112+
113+
t.Run("SetMultipleTimes", func(t *testing.T) {
114+
cli := &GrpcClient{}
115+
ctx := cli.newCtxWithDB(context.Background(), "db")
116+
ctx = cli.newCtxWithDB(ctx, "db2")
117+
md, ok := metadata.FromOutgoingContext(ctx)
118+
assert.True(t, ok)
119+
assert.Equal(t, "db2", md.Get(databaseHeader)[0])
120+
assert.Len(t, md.Get(databaseHeader), 1)
121+
})
93122
}
94123

95124
func TestGrpcClient_HasFeature(t *testing.T) {

0 commit comments

Comments
 (0)