diff --git a/ee/stargate/internal/server/grpc/api_test.go b/ee/stargate/internal/server/grpc/api_test.go index bf6606000e..7cd2635e3e 100644 --- a/ee/stargate/internal/server/grpc/api_test.go +++ b/ee/stargate/internal/server/grpc/api_test.go @@ -104,24 +104,34 @@ func TestStream(t *testing.T) { case <-goCtx.Done(): return case ev := <-incomingMessageChan: - testCaseMsg := testCase.ev.Event.(*api.StargateServerMessage_ApiCall) - serverMsg := ev.Event.(*api.StargateServerMessage_ApiCall) - assert.Equal(t, testCaseMsg.ApiCall.Method, serverMsg.ApiCall.Method) - assert.Equal(t, testCaseMsg.ApiCall.Path, serverMsg.ApiCall.Path) - for k, v := range testCaseMsg.ApiCall.Query { - assert.Equal(t, v.Values, serverMsg.ApiCall.Query[k].Values) - } - for k, v := range testCaseMsg.ApiCall.Headers { - assert.Equal(t, v.Values, serverMsg.ApiCall.Headers[k].Values) - } - assert.Equal(t, testCaseMsg.ApiCall.Body, serverMsg.ApiCall.Body) + switch serverMsg := ev.Event.(type) { + case *api.StargateServerMessage_ApiCall: + testCaseMsg := testCase.ev.Event.(*api.StargateServerMessage_ApiCall) + assert.Equal(t, testCaseMsg.ApiCall.Method, serverMsg.ApiCall.Method) + assert.Equal(t, testCaseMsg.ApiCall.Path, serverMsg.ApiCall.Path) + for k, v := range testCaseMsg.ApiCall.Query { + assert.Equal(t, v.Values, serverMsg.ApiCall.Query[k].Values) + } + for k, v := range testCaseMsg.ApiCall.Headers { + assert.Equal(t, v.Values, serverMsg.ApiCall.Headers[k].Values) + } + assert.Equal(t, testCaseMsg.ApiCall.Body, serverMsg.ApiCall.Body) - testCase.response.CorrelationId = ev.CorrelationId - select { - case <-goCtx.Done(): - return - case responseChan <- testCase.response: + testCase.response.CorrelationId = ev.CorrelationId + select { + case <-goCtx.Done(): + return + case responseChan <- testCase.response: + } + case *api.StargateServerMessage_Ping_: + responseChan <- &api.StargateClientMessage{ + CorrelationId: ev.CorrelationId, + Event: &api.StargateClientMessage_Pong_{ + Pong: &api.StargateClientMessage_Pong{}, + }, + } } + } }()