Skip to content

Commit ff7dd92

Browse files
committed
Add curve preferences, pinned public key SHA256 and mTLS for TLS options
1 parent 44ff114 commit ff7dd92

File tree

7 files changed

+613
-126
lines changed

7 files changed

+613
-126
lines changed

common/tls/reality_server.go

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,10 @@ func NewRealityServer(ctx context.Context, logger log.ContextLogger, options opt
6868
return nil, E.New("unknown cipher_suite: ", cipherSuite)
6969
}
7070
}
71-
if len(options.Certificate) > 0 || options.CertificatePath != "" {
71+
if len(options.CurvePreferences) > 0 {
72+
return nil, E.New("curve preferences is unavailable in reality")
73+
}
74+
if len(options.Certificate) > 0 || options.CertificatePath != "" || len(options.ClientCertificatePublicKeySHA256) > 0 {
7275
return nil, E.New("certificate is unavailable in reality")
7376
}
7477
if len(options.Key) > 0 || options.KeyPath != "" {

common/tls/std_client.go

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,12 @@
11
package tls
22

33
import (
4+
"bytes"
45
"context"
6+
"crypto/sha256"
57
"crypto/tls"
68
"crypto/x509"
9+
"encoding/base64"
710
"net"
811
"os"
912
"strings"
@@ -13,6 +16,7 @@ import (
1316
"github.com/sagernet/sing-box/common/tlsfragment"
1417
C "github.com/sagernet/sing-box/constant"
1518
"github.com/sagernet/sing-box/option"
19+
"github.com/sagernet/sing/common"
1620
E "github.com/sagernet/sing/common/exceptions"
1721
"github.com/sagernet/sing/common/logger"
1822
"github.com/sagernet/sing/common/ntp"
@@ -108,6 +112,15 @@ func NewSTDClient(ctx context.Context, logger logger.ContextLogger, serverAddres
108112
return err
109113
}
110114
}
115+
if len(options.CertificatePublicKeySHA256) > 0 {
116+
if len(options.Certificate) > 0 || options.CertificatePath != "" {
117+
return nil, E.New("certificate_public_key_sha256 is conflict with certificate or certificate_path")
118+
}
119+
tlsConfig.InsecureSkipVerify = true
120+
tlsConfig.VerifyPeerCertificate = func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error {
121+
return verifyPublicKeySHA256(options.CertificatePublicKeySHA256, rawCerts, tlsConfig.Time)
122+
}
123+
}
111124
if len(options.ALPN) > 0 {
112125
tlsConfig.NextProtos = options.ALPN
113126
}
@@ -137,6 +150,14 @@ func NewSTDClient(ctx context.Context, logger logger.ContextLogger, serverAddres
137150
return nil, E.New("unknown cipher_suite: ", cipherSuite)
138151
}
139152
}
153+
if len(options.CurvePreferences) > 0 {
154+
for _, curve := range options.CurvePreferences {
155+
tlsConfig.CurvePreferences = append(tlsConfig.CurvePreferences, tls.CurveID(curve))
156+
}
157+
} else {
158+
// DisableX25519MLKEM768 by default
159+
tlsConfig.CurvePreferences = []tls.CurveID{tls.X25519, tls.CurveP256, tls.CurveP384, tls.CurveP521}
160+
}
140161
var certificate []byte
141162
if len(options.Certificate) > 0 {
142163
certificate = []byte(strings.Join(options.Certificate, "\n"))
@@ -175,3 +196,64 @@ func NewSTDClient(ctx context.Context, logger logger.ContextLogger, serverAddres
175196
}
176197
return config, nil
177198
}
199+
200+
func verifyPublicKeySHA256(knownHashValues [][]byte, rawCerts [][]byte, timeFunc func() time.Time) error {
201+
for i, rawCert := range rawCerts {
202+
certificate, err := x509.ParseCertificate(rawCert)
203+
if err != nil {
204+
continue
205+
}
206+
207+
// Extract public key and hash it
208+
pubKeyBytes, err := x509.MarshalPKIXPublicKey(certificate.PublicKey)
209+
if err != nil {
210+
continue
211+
}
212+
hash := sha256.Sum256(pubKeyBytes)
213+
214+
var matched bool
215+
for _, value := range knownHashValues {
216+
if bytes.Equal(value, hash[:]) {
217+
matched = true
218+
break
219+
}
220+
}
221+
if !matched {
222+
continue
223+
}
224+
if i == 0 {
225+
return nil
226+
}
227+
certificates := make([]*x509.Certificate, i+1)
228+
for j := range certificates {
229+
certificate, err := x509.ParseCertificate(rawCerts[j])
230+
if err != nil {
231+
return err
232+
}
233+
certificates[j] = certificate
234+
}
235+
verifyOptions := x509.VerifyOptions{
236+
Roots: x509.NewCertPool(),
237+
Intermediates: x509.NewCertPool(),
238+
}
239+
if timeFunc != nil {
240+
verifyOptions.CurrentTime = timeFunc()
241+
}
242+
verifyOptions.Roots.AddCert(certificates[i])
243+
for _, certificate := range certificates[1:] {
244+
verifyOptions.Intermediates.AddCert(certificate)
245+
}
246+
return common.Error(certificates[0].Verify(verifyOptions))
247+
}
248+
249+
// Generate error message with first certificate's public key hash
250+
if len(rawCerts) > 0 {
251+
if certificate, err := x509.ParseCertificate(rawCerts[0]); err == nil {
252+
if pubKeyBytes, err := x509.MarshalPKIXPublicKey(certificate.PublicKey); err == nil {
253+
hash := sha256.Sum256(pubKeyBytes)
254+
return E.New("unrecognized public key: ", base64.StdEncoding.EncodeToString(hash[:]))
255+
}
256+
}
257+
}
258+
return E.New("unrecognized certificate")
259+
}

common/tls/std_server.go

Lines changed: 93 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package tls
33
import (
44
"context"
55
"crypto/tls"
6+
"crypto/x509"
67
"net"
78
"os"
89
"strings"
@@ -22,16 +23,17 @@ import (
2223
var errInsecureUnused = E.New("tls: insecure unused")
2324

2425
type STDServerConfig struct {
25-
access sync.RWMutex
26-
config *tls.Config
27-
logger log.Logger
28-
acmeService adapter.SimpleLifecycle
29-
certificate []byte
30-
key []byte
31-
certificatePath string
32-
keyPath string
33-
echKeyPath string
34-
watcher *fswatch.Watcher
26+
access sync.RWMutex
27+
config *tls.Config
28+
logger log.Logger
29+
acmeService adapter.SimpleLifecycle
30+
certificate []byte
31+
key []byte
32+
certificatePath string
33+
keyPath string
34+
clientCertificatePath []string
35+
echKeyPath string
36+
watcher *fswatch.Watcher
3537
}
3638

3739
func (c *STDServerConfig) ServerName() string {
@@ -111,6 +113,9 @@ func (c *STDServerConfig) startWatcher() error {
111113
if c.echKeyPath != "" {
112114
watchPath = append(watchPath, c.echKeyPath)
113115
}
116+
if len(c.clientCertificatePath) > 0 {
117+
watchPath = append(watchPath, c.clientCertificatePath...)
118+
}
114119
if len(watchPath) == 0 {
115120
return nil
116121
}
@@ -159,6 +164,30 @@ func (c *STDServerConfig) certificateUpdated(path string) error {
159164
c.config = config
160165
c.access.Unlock()
161166
c.logger.Info("reloaded TLS certificate")
167+
} else if common.Contains(c.clientCertificatePath, path) {
168+
clientCertificateCA := x509.NewCertPool()
169+
var reloaded bool
170+
for _, certPath := range c.clientCertificatePath {
171+
content, err := os.ReadFile(certPath)
172+
if err != nil {
173+
c.logger.Error(E.Cause(err, "reload certificate from ", c.clientCertificatePath))
174+
continue
175+
}
176+
if !clientCertificateCA.AppendCertsFromPEM(content) {
177+
c.logger.Error(E.New("invalid client certificate file: ", certPath))
178+
continue
179+
}
180+
reloaded = true
181+
}
182+
if !reloaded {
183+
return E.New("client certificates is empty")
184+
}
185+
c.access.Lock()
186+
config := c.config.Clone()
187+
config.ClientCAs = clientCertificateCA
188+
c.config = config
189+
c.access.Unlock()
190+
c.logger.Info("reloaded client certificates")
162191
} else if path == c.echKeyPath {
163192
echKey, err := os.ReadFile(c.echKeyPath)
164193
if err != nil {
@@ -235,8 +264,14 @@ func NewSTDServer(ctx context.Context, logger log.ContextLogger, options option.
235264
return nil, E.New("unknown cipher_suite: ", cipherSuite)
236265
}
237266
}
238-
var certificate []byte
239-
var key []byte
267+
for _, curveID := range options.CurvePreferences {
268+
tlsConfig.CurvePreferences = append(tlsConfig.CurvePreferences, tls.CurveID(curveID))
269+
}
270+
tlsConfig.ClientAuth = tls.ClientAuthType(options.ClientAuthentication)
271+
var (
272+
certificate []byte
273+
key []byte
274+
)
240275
if acmeService == nil {
241276
if len(options.Certificate) > 0 {
242277
certificate = []byte(strings.Join(options.Certificate, "\n"))
@@ -278,6 +313,43 @@ func NewSTDServer(ctx context.Context, logger log.ContextLogger, options option.
278313
tlsConfig.Certificates = []tls.Certificate{keyPair}
279314
}
280315
}
316+
if len(options.ClientCertificate) > 0 || len(options.ClientCertificatePath) > 0 {
317+
if tlsConfig.ClientAuth == tls.NoClientCert {
318+
tlsConfig.ClientAuth = tls.RequireAndVerifyClientCert
319+
}
320+
}
321+
if tlsConfig.ClientAuth == tls.VerifyClientCertIfGiven || tlsConfig.ClientAuth == tls.RequireAndVerifyClientCert {
322+
if len(options.ClientCertificate) > 0 {
323+
clientCertificateCA := x509.NewCertPool()
324+
if !clientCertificateCA.AppendCertsFromPEM([]byte(strings.Join(options.ClientCertificate, "\n"))) {
325+
return nil, E.New("invalid client certificate strings")
326+
}
327+
tlsConfig.ClientCAs = clientCertificateCA
328+
} else if len(options.ClientCertificatePath) > 0 {
329+
clientCertificateCA := x509.NewCertPool()
330+
for _, path := range options.ClientCertificatePath {
331+
content, err := os.ReadFile(path)
332+
if err != nil {
333+
return nil, E.Cause(err, "read client certificate from ", path)
334+
}
335+
if !clientCertificateCA.AppendCertsFromPEM(content) {
336+
return nil, E.New("invalid client certificate file: ", path)
337+
}
338+
}
339+
tlsConfig.ClientCAs = clientCertificateCA
340+
} else if len(options.ClientCertificatePublicKeySHA256) > 0 {
341+
if tlsConfig.ClientAuth == tls.RequireAndVerifyClientCert {
342+
tlsConfig.ClientAuth = tls.RequireAnyClientCert
343+
} else if tlsConfig.ClientAuth == tls.VerifyClientCertIfGiven {
344+
tlsConfig.ClientAuth = tls.RequestClientCert
345+
}
346+
tlsConfig.VerifyPeerCertificate = func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error {
347+
return verifyPublicKeySHA256(options.ClientCertificatePublicKeySHA256, rawCerts, tlsConfig.Time)
348+
}
349+
} else {
350+
return nil, E.New("missing client_certificate, client_certificate_path or client_certificate_public_key_sha256 for client authentication")
351+
}
352+
}
281353
var echKeyPath string
282354
if options.ECH != nil && options.ECH.Enabled {
283355
err = parseECHServerConfig(ctx, options, tlsConfig, &echKeyPath)
@@ -286,14 +358,15 @@ func NewSTDServer(ctx context.Context, logger log.ContextLogger, options option.
286358
}
287359
}
288360
serverConfig := &STDServerConfig{
289-
config: tlsConfig,
290-
logger: logger,
291-
acmeService: acmeService,
292-
certificate: certificate,
293-
key: key,
294-
certificatePath: options.CertificatePath,
295-
keyPath: options.KeyPath,
296-
echKeyPath: echKeyPath,
361+
config: tlsConfig,
362+
logger: logger,
363+
acmeService: acmeService,
364+
certificate: certificate,
365+
key: key,
366+
certificatePath: options.CertificatePath,
367+
clientCertificatePath: options.ClientCertificatePath,
368+
keyPath: options.KeyPath,
369+
echKeyPath: echKeyPath,
297370
}
298371
serverConfig.config.GetConfigForClient = func(info *tls.ClientHelloInfo) (*tls.Config, error) {
299372
serverConfig.access.Lock()

common/tls/utls_client.go

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -167,6 +167,15 @@ func NewUTLSClient(ctx context.Context, logger logger.ContextLogger, serverAddre
167167
}
168168
tlsConfig.InsecureServerNameToVerify = serverName
169169
}
170+
if len(options.CertificatePublicKeySHA256) > 0 {
171+
if len(options.Certificate) > 0 || options.CertificatePath != "" {
172+
return nil, E.New("certificate_public_key_sha256 is conflict with certificate or certificate_path")
173+
}
174+
tlsConfig.InsecureSkipVerify = true
175+
tlsConfig.VerifyPeerCertificate = func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error {
176+
return verifyPublicKeySHA256(options.CertificatePublicKeySHA256, rawCerts, tlsConfig.Time)
177+
}
178+
}
170179
if len(options.ALPN) > 0 {
171180
tlsConfig.NextProtos = options.ALPN
172181
}

0 commit comments

Comments
 (0)