diff --git a/client.go b/client.go index 21d265cf..02bfc3df 100644 --- a/client.go +++ b/client.go @@ -5,6 +5,7 @@ package smtp import ( + "bytes" "crypto/tls" "encoding/base64" "errors" @@ -150,30 +151,22 @@ func NewClientLMTP(conn net.Conn) *Client { // setConn sets the underlying network connection for the client. func (c *Client) setConn(conn net.Conn) { - c.conn = conn - - var r io.Reader = conn - var w io.Writer = conn - - r = &lineLimitReader{ + r := &lineLimitReader{ R: conn, // Doubled maximum line length per RFC 5321 (Section 4.5.3.1.6) LineLimit: 2000, } - r = io.TeeReader(r, clientDebugWriter{c}) - w = io.MultiWriter(w, clientDebugWriter{c}) - - rwc := struct { + c.conn = conn + c.text = textproto.NewConn(struct { io.Reader io.Writer io.Closer }{ - Reader: r, - Writer: w, + Reader: io.TeeReader(r, clientDebugWriter{[]byte("SERVER "), c}), + Writer: io.MultiWriter(conn, clientDebugWriter{[]byte("CLIENT "), c}), Closer: conn, - } - c.text = textproto.NewConn(rwc) + }) } // Close closes the connection. @@ -942,14 +935,37 @@ func toSMTPErr(protoErr *textproto.Error) *SMTPError { } type clientDebugWriter struct { - c *Client + prefix []byte + c *Client } func (cdw clientDebugWriter) Write(b []byte) (int, error) { if cdw.c.DebugWriter == nil { return len(b), nil } - return cdw.c.DebugWriter.Write(b) + + // Prefix every line with the prefix. + var n int + for { + i := bytes.Index(b, []byte("\r\n")) + if i == -1 { + i = len(b) - 1 + } + _, err := cdw.c.DebugWriter.Write(cdw.prefix) + if err != nil { + return n, err + } + nn, err := cdw.c.DebugWriter.Write(b[:i+2]) + if err != nil { + return n + nn, err + } + n += nn + b = b[i+2:] + if len(b) == 0 { + break + } + } + return n, nil } // validateLine checks to see if a line has CR or LF. diff --git a/client_test.go b/client_test.go index 6f35a7d6..5f586127 100644 --- a/client_test.go +++ b/client_test.go @@ -1148,3 +1148,44 @@ func TestClientMTPRIORITY(t *testing.T) { t.Errorf("wrote %q; want %q", actualcmds, client) } } + +func TestClientDebugWriter(t *testing.T) { + server := strings.Join(strings.Split(mtPriorityServer, "\n"), "\r\n") + client := strings.Join(strings.Split(mtPriorityClient, "\n"), "\r\n") + + var ( + wrote bytes.Buffer + dbg = new(bytes.Buffer) + fake faker + ) + fake.ReadWriter = struct { + io.Reader + io.Writer + }{ + strings.NewReader(server), + &wrote, + } + + c := NewClient(fake) + c.DebugWriter = dbg + c.didHello = true + c.ext = map[string]string{"MT-PRIORITY": ""} + priority := 6 + c.Rcpt("root@nsa.gov", &RcptOptions{ + MTPriority: &priority, + }) + c.Close() + if actualcmds := wrote.String(); client != actualcmds { + t.Errorf("wrote %q; want %q", actualcmds, client) + } + + tr := strings.NewReplacer("\t", "", "\n", "\r\n") + want := tr.Replace(` + CLIENT RCPT TO: MT-PRIORITY=6 + SERVER 220 hello world + SERVER 250 ok + `[1:]) + if dbg.String() != want { + t.Errorf("debug wrong\nhave: %q\nwant: %q", dbg, want) + } +}