Skip to content

Commit

Permalink
target/smtp: Check-in accidentally reverted attempt_starttls changes
Browse files Browse the repository at this point in the history
  • Loading branch information
foxcpp committed Jan 25, 2025
1 parent cff6cfa commit be0ec6b
Show file tree
Hide file tree
Showing 5 changed files with 70 additions and 116 deletions.
21 changes: 15 additions & 6 deletions framework/config/map.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ package config

import (
"errors"
"fmt"
"reflect"
"strconv"
"strings"
Expand Down Expand Up @@ -305,6 +306,16 @@ func (m *Map) DataSize(name string, inheritGlobal, required bool, defaultVal int
}, store)
}

func ParseBool(s string) (bool, error) {
switch strings.ToLower(s) {
case "1", "true", "on", "yes":
return true, nil
case "0", "false", "off", "no":
return false, nil
}
return false, fmt.Errorf("bool argument should be 'yes' or 'no'")
}

// Bool maps presence of some configuration directive to a boolean variable.
// Additionally, 'name yes' and 'name no' are mapped to true and false
// correspondingly.
Expand All @@ -327,13 +338,11 @@ func (m *Map) Bool(name string, inheritGlobal, defaultVal bool, store *bool) {
return nil, NodeErr(node, "expected exactly 1 argument")
}

switch strings.ToLower(node.Args[0]) {
case "1", "true", "on", "yes":
return true, nil
case "0", "false", "off", "no":
return false, nil
b, err := ParseBool(node.Args[0])
if err != nil {
return nil, NodeErr(node, "bool argument should be 'yes' or 'no'")
}
return nil, NodeErr(node, "bool argument should be 'yes' or 'no'")
return b, nil
}, store)
}

Expand Down
9 changes: 6 additions & 3 deletions internal/smtpconn/smtpconn.go
Original file line number Diff line number Diff line change
Expand Up @@ -259,12 +259,15 @@ func (c *C) attemptConnect(ctx context.Context, lmtp bool, endp config.Endpoint,
return false, nil, nil, err
}

if endp.IsTLS() || !starttls {
return endp.IsTLS(), cl, conn, nil
if !starttls {
return false, cl, conn, nil
}

if ok, _ := cl.Extension("STARTTLS"); !ok {
return false, cl, conn, nil
if err := cl.Quit(); err != nil {
cl.Close()
}
return false, nil, nil, fmt.Errorf("TLS required but unsupported by downstream")
}

cfg := tlsConfig.Clone()
Expand Down
59 changes: 38 additions & 21 deletions internal/target/smtp/smtp_downstream.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@ package smtp_downstream
import (
"context"
"crypto/tls"
"errors"
"fmt"
"net"
"runtime/trace"
Expand All @@ -54,12 +53,11 @@ type Downstream struct {
lmtp bool
targetsArg []string

requireTLS bool
attemptStartTLS bool
hostname string
endpoints []config.Endpoint
saslFactory saslClientFactory
tlsConfig tls.Config
starttls bool
hostname string
endpoints []config.Endpoint
saslFactory saslClientFactory
tlsConfig tls.Config

connectTimeout time.Duration
commandTimeout time.Duration
Expand Down Expand Up @@ -89,10 +87,34 @@ func NewDownstream(modName, instName string, _, inlineArgs []string) (module.Mod
}

func (u *Downstream) Init(cfg *config.Map) error {
var attemptTLS *bool

var targetsArg []string
cfg.Bool("debug", true, false, &u.log.Debug)
cfg.Bool("require_tls", false, false, &u.requireTLS)
cfg.Bool("attempt_starttls", false, !u.lmtp, &u.attemptStartTLS)
cfg.Callback("require_tls", func(m *config.Map, node config.Node) error {
u.log.Msg("require_tls directive is deprecated and ignored")
return nil
})
cfg.Callback("attempt_starttls", func(m *config.Map, node config.Node) error {
u.log.Msg("attempt_starttls directive is deprecated and equivalent to starttls")

if len(node.Args) == 0 {
trueVal := true
attemptTLS = &trueVal
return nil
}
if len(node.Args) != 1 {
return config.NodeErr(node, "expected exactly 1 argument")
}

b, err := config.ParseBool(node.Args[0])
if err != nil {
return err
}
attemptTLS = &b
return nil
})
cfg.Bool("starttls", false, !u.lmtp, &u.starttls)
cfg.String("hostname", true, true, "", &u.hostname)
cfg.StringList("targets", false, false, nil, &targetsArg)
cfg.Custom("auth", false, false, func() (interface{}, error) {
Expand All @@ -109,6 +131,10 @@ func (u *Downstream) Init(cfg *config.Map) error {
return err
}

if attemptTLS != nil {
u.starttls = *attemptTLS
}

// INTERNATIONALIZATION: See RFC 6531 Section 3.7.1.
var err error
u.hostname, err = idna.ToASCII(u.hostname)
Expand Down Expand Up @@ -201,14 +227,11 @@ func (d *delivery) connect(ctx context.Context) error {
}

for _, endp := range d.u.endpoints {
var (
didTLS bool
err error
)
var err error
if d.u.lmtp {
didTLS, err = conn.ConnectLMTP(ctx, endp, d.u.attemptStartTLS, &d.u.tlsConfig)
_, err = conn.ConnectLMTP(ctx, endp, d.u.starttls, &d.u.tlsConfig)
} else {
didTLS, err = conn.Connect(ctx, endp, d.u.attemptStartTLS, &d.u.tlsConfig)
_, err = conn.Connect(ctx, endp, d.u.starttls, &d.u.tlsConfig)
}
if err != nil {
if len(d.u.endpoints) != 1 {
Expand All @@ -220,12 +243,6 @@ func (d *delivery) connect(ctx context.Context) error {

d.log.DebugMsg("connected", "downstream_server", conn.ServerName())

if !didTLS && d.u.requireTLS {
conn.Close()
lastErr = errors.New("TLS is required, but unsupported by downstream")
continue
}

lastErr = nil
break
}
Expand Down
93 changes: 7 additions & 86 deletions internal/target/smtp/smtp_downstream_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,7 @@ func TestDownstreamDelivery_MAILErr(t *testing.T) {
testutils.CheckSMTPErr(t, err, 550, exterrors.EnhancedCode{5, 1, 2}, "Hey")
}

func TestDownstreamDelivery_AttemptTLS(t *testing.T) {
func TestDownstreamDelivery_StartTLS(t *testing.T) {
clientCfg, be, srv := testutils.SMTPServerSTARTTLS(t, "127.0.0.1:"+testPort)
defer srv.Close()
defer testutils.CheckSMTPConnLeak(t, srv)
Expand All @@ -221,9 +221,9 @@ func TestDownstreamDelivery_AttemptTLS(t *testing.T) {
Port: testPort,
},
},
tlsConfig: *clientCfg.Clone(),
attemptStartTLS: true,
log: testutils.Logger(t, "target.smtp"),
tlsConfig: *clientCfg.Clone(),
starttls: true,
log: testutils.Logger(t, "target.smtp"),
}

testutils.DoTestDelivery(t, mod, "[email protected]", []string{"[email protected]"})
Expand All @@ -235,85 +235,7 @@ func TestDownstreamDelivery_AttemptTLS(t *testing.T) {
}
}

func TestDownstreamDelivery_AttemptTLS_Fallback(t *testing.T) {
be, srv := testutils.SMTPServer(t, "127.0.0.1:"+testPort)
defer srv.Close()
defer testutils.CheckSMTPConnLeak(t, srv)

mod := &Downstream{
hostname: "mx.example.invalid",
endpoints: []config.Endpoint{
{
Scheme: "tcp",
Host: "127.0.0.1",
Port: testPort,
},
},
attemptStartTLS: true,
log: testutils.Logger(t, "target.smtp"),
}

testutils.DoTestDelivery(t, mod, "[email protected]", []string{"[email protected]"})
be.CheckMsg(t, 0, "[email protected]", []string{"[email protected]"})
}

func TestDownstreamDelivery_RequireTLS(t *testing.T) {
clientCfg, be, srv := testutils.SMTPServerSTARTTLS(t, "127.0.0.1:"+testPort)
defer srv.Close()
defer testutils.CheckSMTPConnLeak(t, srv)

mod := &Downstream{
hostname: "mx.example.invalid",
endpoints: []config.Endpoint{
{
Scheme: "tcp",
Host: "127.0.0.1",
Port: testPort,
},
},
tlsConfig: *clientCfg.Clone(),
attemptStartTLS: true,
requireTLS: true,
log: testutils.Logger(t, "target.smtp"),
}

testutils.DoTestDelivery(t, mod, "[email protected]", []string{"[email protected]"})
be.CheckMsg(t, 0, "[email protected]", []string{"[email protected]"})
tlsState, ok := be.Messages[0].Conn.TLSConnectionState()
if !ok || !tlsState.HandshakeComplete {
t.Fatal("Message was not delivered over TLS")
}
}

func TestDownstreamDelivery_RequireTLS_Implicit(t *testing.T) {
clientCfg, be, srv := testutils.SMTPServerTLS(t, "127.0.0.1:"+testPort)
defer srv.Close()
defer testutils.CheckSMTPConnLeak(t, srv)

mod := &Downstream{
hostname: "mx.example.invalid",
endpoints: []config.Endpoint{
{
Scheme: "tls",
Host: "127.0.0.1",
Port: testPort,
},
},
tlsConfig: *clientCfg.Clone(),
attemptStartTLS: true,
requireTLS: true,
log: testutils.Logger(t, "target.smtp"),
}

testutils.DoTestDelivery(t, mod, "[email protected]", []string{"[email protected]"})
be.CheckMsg(t, 0, "[email protected]", []string{"[email protected]"})
tlsState, ok := be.Messages[0].Conn.TLSConnectionState()
if !ok || !tlsState.HandshakeComplete {
t.Fatal("Message was not delivered over TLS")
}
}

func TestDownstreamDelivery_RequireTLS_Fail(t *testing.T) {
func TestDownstreamDelivery_StartTLS_NoFallback(t *testing.T) {
_, srv := testutils.SMTPServer(t, "127.0.0.1:"+testPort)
defer srv.Close()
defer testutils.CheckSMTPConnLeak(t, srv)
Expand All @@ -327,9 +249,8 @@ func TestDownstreamDelivery_RequireTLS_Fail(t *testing.T) {
Port: testPort,
},
},
attemptStartTLS: true,
requireTLS: true,
log: testutils.Logger(t, "target.smtp"),
starttls: true,
log: testutils.Logger(t, "target.smtp"),
}

_, err := testutils.DoTestDeliveryErr(t, mod, "[email protected]", []string{"[email protected]"})
Expand Down
4 changes: 4 additions & 0 deletions internal/target/smtp/smtputf8_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,10 @@ func TestDownstreamDelivery_EHLO_ALabel(t *testing.T) {
Name: "hostname",
Args: []string{"тест.invalid"},
},
{
Name: "starttls",
Args: []string{"no"},
},
},
})); err != nil {
t.Fatal(err)
Expand Down

0 comments on commit be0ec6b

Please sign in to comment.