Skip to content

Commit f0e34dd

Browse files
committed
Remove cors from target URL to avoid double headers
1 parent ae050cc commit f0e34dd

File tree

2 files changed

+72
-64
lines changed

2 files changed

+72
-64
lines changed

api/proxy/proxy.go

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,12 @@ package proxy
22

33
import (
44
"context"
5+
"fmt"
56
"net"
67
"net/http"
78
"net/http/httputil"
89
"net/url"
10+
"strings"
911
"time"
1012

1113
"github.com/grpc-ecosystem/grpc-gateway/v2/runtime"
@@ -66,19 +68,30 @@ func NewServer(
6668
MaxAge: 300,
6769
})
6870
handler = c.Handler(mux)
71+
72+
proxy.ModifyResponse = func(resp *http.Response) error {
73+
// Remove CORS headers from the target response
74+
for k := range resp.Header {
75+
if strings.HasPrefix(k, "Access-Control-") {
76+
delete(resp.Header, k)
77+
}
78+
}
79+
return nil
80+
}
6981
}
7082

7183
// Register GRPC services handled locally
7284
grpcMux := runtime.NewServeMux()
7385
for _, svc := range local {
74-
svc.RegisterHandlerService(grpcMux)
86+
if err := svc.RegisterHandlerService(grpcMux); err != nil {
87+
return nil, fmt.Errorf("registering local service %s: %w", svc.Path(), err)
88+
}
7589
mux.Handle(svc.Path(), grpcMux)
7690
}
7791

7892
// The rest is proxied.
7993
// HTTP handler to forward requests
8094
mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
81-
r.Host = targetURL.Host
8295
proxy.ServeHTTP(w, r)
8396
})
8497

node/node_test.go

Lines changed: 57 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,6 @@ import (
3535
"google.golang.org/grpc/credentials/insecure"
3636
"google.golang.org/grpc/status"
3737
"google.golang.org/protobuf/encoding/protojson"
38-
"google.golang.org/protobuf/proto"
3938

4039
"github.com/spacemeshos/go-spacemesh/activation"
4140
"github.com/spacemeshos/go-spacemesh/api/grpcserver"
@@ -192,34 +191,6 @@ func TestSpacemeshApp_Cmd(t *testing.T) {
192191
r.Equal(expected2, str)
193192
}
194193

195-
func marshalProto(tb testing.TB, msg proto.Message) []byte {
196-
buf, err := protojson.Marshal(msg)
197-
require.NoError(tb, err)
198-
return buf
199-
}
200-
201-
func callEndpointPost(tb testing.TB, url string, payload []byte) ([]byte, int) {
202-
resp, err := http.Post(url, "application/json", bytes.NewReader(payload))
203-
require.NoError(tb, err)
204-
require.Equal(tb, "application/json", resp.Header.Get("Content-Type"))
205-
buf, err := io.ReadAll(resp.Body)
206-
require.NoError(tb, err)
207-
require.NoError(tb, resp.Body.Close())
208-
209-
return buf, resp.StatusCode
210-
}
211-
212-
func callEndpointGet(tb testing.TB, url string) ([]byte, int) {
213-
resp, err := http.Get(url)
214-
require.NoError(tb, err)
215-
require.Equal(tb, "application/json", resp.Header.Get("Content-Type"))
216-
buf, err := io.ReadAll(resp.Body)
217-
require.NoError(tb, err)
218-
require.NoError(tb, resp.Body.Close())
219-
220-
return buf, resp.StatusCode
221-
}
222-
223194
func TestSpacemeshApp_GrpcService(t *testing.T) {
224195
// Use a unique port
225196
listener := "127.0.0.1:1242"
@@ -300,15 +271,15 @@ func TestSpacemeshApp_JsonService(t *testing.T) {
300271
r := require.New(t)
301272

302273
const message = "你好世界"
303-
payload := marshalProto(t, &pb.EchoRequest{Msg: &pb.SimpleString{Value: message}})
274+
payload, err := protojson.Marshal(&pb.EchoRequest{Msg: &pb.SimpleString{Value: message}})
275+
require.NoError(t, err)
304276
listener := "127.0.0.1:0"
305277

306278
cfg := getTestDefaultConfig(t)
307279
cfg.API.JSONListener = listener
308280
cfg.API.PrivateServices = nil
309281
app := New(WithConfig(cfg), WithLog(logtest.New(t)))
310282

311-
var err error
312283
app.clock, err = timesync.NewClock(
313284
timesync.WithLayerDuration(cfg.LayerDuration),
314285
timesync.WithTickInterval(1*time.Second),
@@ -329,63 +300,80 @@ func TestSpacemeshApp_JsonService(t *testing.T) {
329300
r.NoError(err)
330301
defer app.stopServices(context.Background())
331302

332-
var (
333-
respBody []byte
334-
respStatus int
335-
)
336303
endpoint := fmt.Sprintf("http://%s/v1/node/echo", app.jsonAPIServer.BoundAddress)
304+
var resp *http.Response
337305
require.Eventually(t, func() bool {
338-
respBody, respStatus = callEndpointPost(t, endpoint, payload)
339-
return respStatus == http.StatusOK
306+
resp, err = http.Post(endpoint, "application/json", bytes.NewReader(payload))
307+
return err == nil && resp.StatusCode == http.StatusOK
340308
}, 2*time.Second, 100*time.Millisecond)
309+
310+
require.Equal(t, "application/json", resp.Header.Get("Content-Type"))
311+
respBody, err := io.ReadAll(resp.Body)
312+
require.NoError(t, err)
313+
require.NoError(t, resp.Body.Close())
314+
341315
var msg pb.EchoResponse
342316
require.NoError(t, protojson.Unmarshal(respBody, &msg))
343317
require.Equal(t, message, msg.Msg.Value)
344-
require.Equal(t, http.StatusOK, respStatus)
345318
require.NoError(t, protojson.Unmarshal(respBody, &msg))
346319
require.Equal(t, message, msg.Msg.Value)
347320
}
348321

349322
func TestProxyingJsonService(t *testing.T) {
350-
cfg := config.Config{
323+
serverCfg := config.Config{
351324
API: grpcserver.Config{
352-
JSONListener: "127.0.0.1:0",
353-
PublicListener: "127.0.0.1:0",
354-
PublicServices: []grpcserver.Service{grpcserver.Node},
355-
PrivateServices: nil,
356-
PostServices: nil,
357-
TLSServices: nil,
325+
JSONListener: "127.0.0.1:0",
326+
PublicListener: "127.0.0.1:0",
327+
JSONCorsEverywhere: true,
328+
PublicServices: []grpcserver.Service{grpcserver.Node},
358329
},
359330
}
360331

361332
// Start server
362333
logger := logtest.New(t)
363334
db := localsql.InMemoryTest(t)
364-
serverApp := New(WithConfig(&cfg), WithLog(logger.Named("server")))
335+
serverApp := New(WithConfig(&serverCfg), WithLog(logger.Named("server")))
365336
err := serverApp.startAPIServices(context.Background())
366337
require.NoError(t, err)
367338
defer serverApp.stopServices(context.Background())
368339

369340
// Start client proxying to the server
370-
cfg.API.ProxyApiV2Address = fmt.Sprintf("http://%s", serverApp.jsonAPIServer.BoundAddress)
371-
cfg.API.NonProxiedServices = []grpcserver.Service{grpcserver.SmeshingIdentitiesV2Beta1}
372-
clientApp := New(WithConfig(&cfg), WithLog(logger.Named("client")))
341+
clientCfg := config.Config{
342+
API: grpcserver.Config{
343+
JSONListener: "127.0.0.1:0",
344+
PublicListener: "127.0.0.1:0",
345+
JSONCorsEverywhere: true,
346+
ProxyApiV2Address: fmt.Sprintf("http://%s", serverApp.jsonAPIServer.BoundAddress),
347+
NonProxiedServices: []grpcserver.Service{grpcserver.SmeshingIdentitiesV2Beta1},
348+
PublicServices: []grpcserver.Service{grpcserver.Node},
349+
},
350+
}
351+
clientApp := New(WithConfig(&clientCfg), WithLog(logger.Named("client")))
373352
clientApp.idStates = identity.NewIdentityStateStorage(db, logger.Named("idStates").Zap())
374353

375354
require.NoError(t, clientApp.startAPIServices(context.Background()))
376355
defer clientApp.stopServices(context.Background())
377356

378-
var (
379-
respBody []byte
380-
respStatus int
381-
)
382357
const message = "hello world"
383358
endpoint := fmt.Sprintf("http://%s/v1/node/echo", clientApp.apiProxy.BoundAddress)
384-
payload := marshalProto(t, &pb.EchoRequest{Msg: &pb.SimpleString{Value: message}})
359+
payload, err := protojson.Marshal(&pb.EchoRequest{Msg: &pb.SimpleString{Value: message}})
360+
require.NoError(t, err)
361+
req, err := http.NewRequest(http.MethodPost, endpoint, bytes.NewBuffer(payload))
362+
require.NoError(t, err)
363+
req.Header.Set("Content-Type", "application/json")
364+
req.Header.Set("Origin", "http://localhost")
365+
var resp *http.Response
385366
require.Eventually(t, func() bool {
386-
respBody, respStatus = callEndpointPost(t, endpoint, payload)
387-
return respStatus == http.StatusOK
367+
resp, err = http.DefaultClient.Do(req)
368+
return err == nil && resp.StatusCode == http.StatusOK
388369
}, 2*time.Second, 100*time.Millisecond)
370+
371+
require.Equal(t, "*", resp.Header.Get("Access-Control-Allow-Origin"))
372+
require.Equal(t, "application/json", resp.Header.Get("Content-Type"))
373+
respBody, err := io.ReadAll(resp.Body)
374+
require.NoError(t, err)
375+
require.NoError(t, resp.Body.Close())
376+
389377
var msg pb.EchoResponse
390378
require.NoError(t, protojson.Unmarshal(respBody, &msg))
391379
require.Equal(t, message, msg.Msg.Value)
@@ -398,12 +386,19 @@ func TestProxyingJsonService(t *testing.T) {
398386

399387
nodeID := types.RandomNodeID()
400388
clientApp.idStates.Set(nodeID, &identity.ATXBroadcasted{AtxId: types.RandomATXID()})
401-
respBody, status := callEndpointGet(t, endpoint)
402-
require.Equal(t, http.StatusOK, status)
403-
var resp pbV2.IdentityStatesResponse
404-
require.NoError(t, protojson.Unmarshal(respBody, &resp))
405-
require.Len(t, resp.States, 1)
406-
require.Equal(t, nodeID.Bytes(), resp.States[0].Smesher)
389+
resp, err = http.Get(endpoint)
390+
require.NoError(t, err)
391+
require.Equal(t, http.StatusOK, resp.StatusCode)
392+
require.Equal(t, "application/json", resp.Header.Get("Content-Type"))
393+
394+
respBody, err = io.ReadAll(resp.Body)
395+
require.NoError(t, err)
396+
require.NoError(t, resp.Body.Close())
397+
398+
var msg2 pbV2.IdentityStatesResponse
399+
require.NoError(t, protojson.Unmarshal(respBody, &msg2))
400+
require.Len(t, msg2.States, 1)
401+
require.Equal(t, nodeID.Bytes(), msg2.States[0].Smesher)
407402
}
408403

409404
type noopHook struct{}

0 commit comments

Comments
 (0)