Skip to content

Commit

Permalink
add socket's tls methods to rpc
Browse files Browse the repository at this point in the history
  • Loading branch information
hslam committed Jan 10, 2022
1 parent 512adcc commit 7486fa4
Show file tree
Hide file tree
Showing 8 changed files with 89 additions and 16 deletions.
3 changes: 1 addition & 2 deletions benchmarks/tls/client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ import (
"fmt"
"github.com/hslam/rpc"
"github.com/hslam/rpc/benchmarks/tls/service"
"github.com/hslam/socket"
"github.com/hslam/stats"
"log"
"math/rand"
Expand Down Expand Up @@ -38,7 +37,7 @@ func main() {
}
var wrkClients []stats.Client
for i := 0; i < clients; i++ {
if conn, err := rpc.DialTLS(network, addr, codec, socket.SkipVerifyTLSConfig()); err != nil {
if conn, err := rpc.DialTLS(network, addr, codec, rpc.SkipVerifyTLSConfig()); err != nil {
log.Fatalln("dailing error: ", err)
} else {
wrkClients = append(wrkClients, &WrkClient{conn})
Expand Down
3 changes: 1 addition & 2 deletions benchmarks/tls/server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ import (
"flag"
"github.com/hslam/rpc"
"github.com/hslam/rpc/benchmarks/tls/service"
"github.com/hslam/socket"
)

var network string
Expand All @@ -20,5 +19,5 @@ func init() {

func main() {
rpc.Register(new(service.Arith))
rpc.ListenTLS(network, addr, codec, socket.DefalutTLSConfig())
rpc.ListenTLS(network, addr, codec, rpc.DefalutTLSConfig())
}
5 changes: 2 additions & 3 deletions dialer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ package rpc

import (
"github.com/hslam/rpc/examples/codec/json/service"
"github.com/hslam/socket"
"sync"
"testing"
"time"
Expand All @@ -28,11 +27,11 @@ func TestDialTLS(t *testing.T) {
addr := ":8880"
network := ""
codec := ""
if _, err := DialTLS(network, addr, codec, socket.SkipVerifyTLSConfig()); err == nil {
if _, err := DialTLS(network, addr, codec, SkipVerifyTLSConfig()); err == nil {
t.Error("The err should not be nil")
}
network = "tcp"
if _, err := DialTLS(network, addr, codec, socket.SkipVerifyTLSConfig()); err == nil {
if _, err := DialTLS(network, addr, codec, SkipVerifyTLSConfig()); err == nil {
t.Error("The err should not be nil")
}
}
Expand Down
3 changes: 1 addition & 2 deletions examples/tls/client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,10 @@ import (
"fmt"
"github.com/hslam/rpc"
"github.com/hslam/rpc/examples/tls/service"
"github.com/hslam/socket"
)

func main() {
conn, err := rpc.DialTLS("tcp", ":9999", "pb", socket.SkipVerifyTLSConfig())
conn, err := rpc.DialTLS("tcp", ":9999", "pb", rpc.SkipVerifyTLSConfig())
if err != nil {
panic(err)
}
Expand Down
3 changes: 1 addition & 2 deletions examples/tls/server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,9 @@ package main
import (
"github.com/hslam/rpc"
"github.com/hslam/rpc/examples/tls/service"
"github.com/hslam/socket"
)

func main() {
rpc.Register(new(service.Arith))
rpc.ListenTLS("tcp", ":9999", "pb", socket.DefalutTLSConfig())
rpc.ListenTLS("tcp", ":9999", "pb", rpc.DefalutTLSConfig())
}
9 changes: 4 additions & 5 deletions listener_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ package rpc

import (
"github.com/hslam/rpc/examples/codec/json/service"
"github.com/hslam/socket"
"sync"
"testing"
"time"
Expand Down Expand Up @@ -73,10 +72,10 @@ func TestListenTLS(t *testing.T) {
wg.Add(1)
go func() {
defer wg.Done()
ListenTLS(network, addr, codec, socket.DefalutTLSConfig())
ListenTLS(network, addr, codec, DefalutTLSConfig())
}()
time.Sleep(time.Millisecond * 10)
conn, err := DialTLS(network, addr, codec, socket.SkipVerifyTLSConfig())
conn, err := DialTLS(network, addr, codec, SkipVerifyTLSConfig())
if err != nil {
t.Error(err)
}
Expand Down Expand Up @@ -155,11 +154,11 @@ func TestServerListenTLS(t *testing.T) {
addr := ":8880"
network := ""
codec := ""
if err := DefaultServer.ListenTLS(network, addr, codec, socket.SkipVerifyTLSConfig()); err == nil {
if err := DefaultServer.ListenTLS(network, addr, codec, SkipVerifyTLSConfig()); err == nil {
t.Error("The err should not be nil")
}
network = "tcp"
if err := DefaultServer.ListenTLS(network, addr, codec, socket.SkipVerifyTLSConfig()); err == nil {
if err := DefaultServer.ListenTLS(network, addr, codec, SkipVerifyTLSConfig()); err == nil {
t.Error("The err should not be nil")
}
}
Expand Down
35 changes: 35 additions & 0 deletions tls.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
// Copyright (c) 2022 Meng Huang ([email protected])
// This package is licensed under a MIT license that can be found in the LICENSE file.

package rpc

import (
"crypto/tls"
"github.com/hslam/socket"
)

// LoadTLSConfig returns a TLS config by loading the certificate file and the key file.
func LoadTLSConfig(certFile, keyFile string) (*tls.Config, error) {
return socket.LoadTLSConfig(certFile, keyFile)
}

// TLSConfig returns a TLS config by the certificate data and the key data.
func TLSConfig(certPEM []byte, keyPEM []byte) *tls.Config {
return socket.TLSConfig(certPEM, keyPEM)
}

// DefalutTLSConfig returns a default TLS config.
func DefalutTLSConfig() *tls.Config {
return socket.DefalutTLSConfig()
}

// SkipVerifyTLSConfig returns a insecure skip verify TLS config.
func SkipVerifyTLSConfig() *tls.Config {
return socket.SkipVerifyTLSConfig()
}

// DefaultKeyPEM represents the default private key data.
var DefaultKeyPEM = socket.DefaultKeyPEM

// DefaultCertPEM represents the default certificate data.
var DefaultCertPEM = socket.DefaultCertPEM
44 changes: 44 additions & 0 deletions tls_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
// Copyright (c) 2022 Meng Huang ([email protected])
// This package is licensed under a MIT license that can be found in the LICENSE file.

package rpc

import (
"os"
"testing"
)

func TestLoadTLSConfig(t *testing.T) {
var certFileName = "tmpTestCertFile"
var keyFileName = "tmpTestKeyFile"
var err error
_, err = LoadTLSConfig("", "")
if err == nil {
t.Error("should be no such file or directory")
}
certFile, _ := os.Create(certFileName)
certFile.Write(DefaultCertPEM)
certFile.Close()
defer os.Remove(certFileName)
_, err = LoadTLSConfig(certFileName, "")
if err == nil {
t.Error("should be no such file or directory")
}
keyFile, _ := os.Create(keyFileName)
keyFile.Write(DefaultKeyPEM)
keyFile.Close()
defer os.Remove(keyFileName)
_, err = LoadTLSConfig(certFileName, keyFileName)
if err != nil {
t.Error(err)
}
}

func TestTLSConfig(t *testing.T) {
defer func() {
if err := recover(); err == nil {
t.Error("should panic")
}
}()
TLSConfig(DefaultCertPEM, []byte{})
}

0 comments on commit 7486fa4

Please sign in to comment.