Skip to content

Commit

Permalink
tunnel: remove default port from wss server name (#274)
Browse files Browse the repository at this point in the history
  • Loading branch information
at-wat authored May 28, 2021
1 parent 4a1d830 commit a1fe6fe
Show file tree
Hide file tree
Showing 4 changed files with 157 additions and 1 deletion.
33 changes: 33 additions & 0 deletions tunnel/endpoint.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
// Copyright 2020 SEQSENSE, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package tunnel

import (
"fmt"
"strings"
)

func serverNameFromEndpoint(scheme, endpoint string) string {
var defaultPort int
switch scheme {
case "wss":
defaultPort = 443
case "ws":
defaultPort = 80
default:
return endpoint
}
return strings.TrimSuffix(endpoint, fmt.Sprintf(":%d", defaultPort))
}
74 changes: 74 additions & 0 deletions tunnel/endpoint_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
// Copyright 2020 SEQSENSE, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package tunnel

import (
"fmt"
"testing"
)

func TestServerNameFromEndpoint(t *testing.T) {
testCases := []struct {
scheme string
input string
output string
}{
{
scheme: "wss",
input: "hostname:123",
output: "hostname:123",
},
{
scheme: "wss",
input: "hostname:443",
output: "hostname",
},
{
scheme: "wss",
input: "hostname",
output: "hostname",
},
{
scheme: "ws",
input: "hostname:443",
output: "hostname:443",
},
{
scheme: "ws",
input: "hostname:80",
output: "hostname",
},
{
scheme: "ws",
input: "hostname",
output: "hostname",
},
}
for _, testCase := range testCases {
testCase := testCase
t.Run(fmt.Sprintf("%s://%s", testCase.scheme, testCase.input), func(t *testing.T) {
out := serverNameFromEndpoint(testCase.scheme, testCase.input)
if out != testCase.output {
t.Errorf(
"Expected ServerName %s for endpoint %s scheme %s, got %s",
testCase.output,
testCase.input,
testCase.scheme,
out,
)
}
})
}
}
20 changes: 19 additions & 1 deletion tunnel/proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ package tunnel

import (
"crypto/tls"
"errors"
"fmt"
"io"
"net"
Expand All @@ -34,6 +35,9 @@ const (
userAgent = "aws-iot-device-sdk-go/tunnel"
)

// ErrUnsupportedScheme indicate that the requested protocol scheme is not supported.
var ErrUnsupportedScheme = errors.New("unsupported scheme")

func endpointHost(region string) string {
return fmt.Sprintf(defaultEndpointHostFormat, region)
}
Expand Down Expand Up @@ -82,6 +86,10 @@ func openProxyConn(endpoint, mode, token string, opts ...ProxyOption) (*websocke
}
}

if err := opt.validate(); err != nil {
return nil, nil, err
}

wsc, err := websocket.NewConfig(
fmt.Sprintf("%s://%s/tunnel?local-proxy-mode=%s", opt.Scheme, endpoint, mode),
fmt.Sprintf("https://%s", endpoint),
Expand All @@ -91,7 +99,8 @@ func openProxyConn(endpoint, mode, token string, opts ...ProxyOption) (*websocke
}
if opt.Scheme == "wss" {
wsc.TlsConfig = &tls.Config{
ServerName: endpoint,
// Remove protocol default port number from the URI to avoid TLS certificate validation error.
ServerName: serverNameFromEndpoint(opt.Scheme, endpoint),
InsecureSkipVerify: opt.InsecureSkipVerify,
}
}
Expand Down Expand Up @@ -133,6 +142,15 @@ type ProxyOptions struct {
PingPeriod time.Duration
}

func (o *ProxyOptions) validate() error {
switch o.Scheme {
case "wss", "ws":
default:
return ioterr.New(ErrUnsupportedScheme, o.Scheme)
}
return nil
}

// WithErrorHandler sets a ErrorHandler.
func WithErrorHandler(h ErrorHandler) ProxyOption {
return func(opt *ProxyOptions) error {
Expand Down
31 changes: 31 additions & 0 deletions tunnel/proxy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -596,3 +596,34 @@ func (acceptFunc) Close() error {
func (acceptFunc) Addr() net.Addr {
panic("not implemented")
}

func TestProxyOption_validate(t *testing.T) {
t.Run("Valid", func(t *testing.T) {
opts := []*ProxyOptions{
{Scheme: "ws"},
{Scheme: "wss"},
}
for _, o := range opts {
o := o
t.Run(o.Scheme, func(t *testing.T) {
if err := o.validate(); err != nil {
t.Errorf("Validation failed: %v", err)
}
})
}
})
t.Run("Invalid", func(t *testing.T) {
opts := []*ProxyOptions{
{Scheme: "http"},
{Scheme: ""},
}
for _, o := range opts {
o := o
t.Run(o.Scheme, func(t *testing.T) {
if err := o.validate(); err == nil {
t.Error("Validation must fail")
}
})
}
})
}

0 comments on commit a1fe6fe

Please sign in to comment.