Skip to content

Commit

Permalink
Merge pull request #18 from linyows/refactoring
Browse files Browse the repository at this point in the history
Refactoring
  • Loading branch information
linyows authored Aug 30, 2023
2 parents 7469711 + bd7ebda commit 9ac13d1
Show file tree
Hide file tree
Showing 2 changed files with 154 additions and 97 deletions.
145 changes: 91 additions & 54 deletions pipe.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ type Pipe struct {
afterConnHook func()
}

type Mediator func([]byte, int) ([]byte, int)
type Mediator func([]byte, int) ([]byte, int, bool)
type Flow int
type Data []byte
type Direction string
Expand All @@ -52,7 +52,7 @@ const (
srcToPxy Direction = ">|"
pxyToDst Direction = "|>"
dstToPxy Direction = "|<"
//pxyToSrc Direction = "<|"
pxyToSrc Direction = "<|"
srcToDst Direction = "->"
dstToSrc Direction = "<-"
onPxy Direction = "--"
Expand All @@ -63,7 +63,7 @@ const (
// SMTP response codes
codeServiceReady int = 220
codeStartingMailInput int = 354
//codeActionCompleted int = 250
codeActionCompleted int = 250
)

var (
Expand All @@ -78,6 +78,76 @@ func (e Elapse) String() string {
return fmt.Sprintf("%d msec", e)
}

func (p *Pipe) mediateOnUpstream(b []byte, i int) ([]byte, int, bool) {
data := b[0:i]

if !p.tls || p.rMailAddr == nil {
p.setSenderMailAddress(data)
p.setSenderServerName(data)
p.setReceiverMailAddressAndServerName(data)
}

if !p.tls && p.readytls {
p.locked = true
er := p.starttls()
if er != nil {
go p.afterCommHook([]byte(fmt.Sprintf("starttls error: %s", er.Error())), pxyToDst)
}
p.readytls = false
go p.afterCommHook(data, srcToPxy)
}

if p.locked {
p.waitForTLSConn(b, i)
go p.afterCommHook(data, pxyToDst)
} else {
go p.afterCommHook(p.removeMailBody(data), srcToDst)
}

return b, i, false
}

func (p *Pipe) mediateOnDownstream(b []byte, i int) ([]byte, int, bool) {
data := b[0:i]

if p.isResponseOfEHLOWithStartTLS(b) {
go p.afterCommHook(data, dstToPxy)
b, i = p.removeStartTLSCommand(b, i)
} else if p.isResponseOfReadyToStartTLS(b) {
go p.afterCommHook(data, dstToPxy)
er := p.connectTLS()
if er != nil {
go p.afterCommHook([]byte(fmt.Sprintf("TLS connection error: %s", er.Error())), dstToPxy)
}
}

// time before email input
p.setTimeAtDataStarting(b)

// remove buffering ready response
if p.tls && !p.readytls && p.locked {
// continue
return b, i, true
}

if p.isResponseOfEHLOWithoutStartTLS(b) {
go p.afterCommHook(data, pxyToSrc)
} else {
go p.afterCommHook(data, dstToSrc)
}

return b, i, false
}

func (p *Pipe) setTimeAtDataStarting(b []byte) {
list := bytes.Split(b, []byte(crlf))
for _, v := range list {
if len(v) >= 3 && string(v[:3]) == fmt.Sprint(codeStartingMailInput) {
p.timeAtDataStarting = time.Now()
}
}
}

func (p *Pipe) Do() {
p.timeAtConnected = time.Now()
go p.afterCommHook([]byte(fmt.Sprintf("connected to %s", p.rAddr)), onPxy)
Expand All @@ -87,21 +157,7 @@ func (p *Pipe) Do() {

// Sender --- packet --> Proxy
go func() {
_, err := p.copy(upstream, func(b []byte, i int) ([]byte, int) {
if !p.tls || p.rMailAddr == nil {
p.pairing(b[0:i])
}
if !p.tls && p.readytls {
p.locked = true
er := p.starttls()
if er != nil {
go p.afterCommHook([]byte(fmt.Sprintf("starttls error: %s", er.Error())), pxyToDst)
}
p.readytls = false
go p.afterCommHook(b[0:i], srcToPxy)
}
return b, i
})
_, err := p.copy(upstream, p.mediateOnUpstream)
if err != nil {
go p.afterCommHook([]byte(fmt.Sprintf("io copy error: %s", err.Error())), pxyToDst)
}
Expand All @@ -110,36 +166,30 @@ func (p *Pipe) Do() {

// Proxy <--- packet -- Receiver
go func() {
_, err := p.copy(downstream, func(b []byte, i int) ([]byte, int) {
if p.isResponseOfEHLOWithStartTLS(b) {
go p.afterCommHook(b[0:i], dstToPxy)
b, i = p.removeStartTLSCommand(b, i)
} else if p.isResponseOfReadyToStartTLS(b) {
go p.afterCommHook(b[0:i], dstToPxy)
er := p.connectTLS()
if er != nil {
go p.afterCommHook([]byte(fmt.Sprintf("TLS connection error: %s", er.Error())), dstToPxy)
}
}
return b, i
})
_, err := p.copy(downstream, p.mediateOnDownstream)
if err != nil {
go p.afterCommHook([]byte(fmt.Sprintf("io copy error: %s", err.Error())), dstToPxy)
}
once.Do(p.close())
}()
}

func (p *Pipe) pairing(b []byte) {
func (p *Pipe) setSenderServerName(b []byte) {
if bytes.Contains(b, []byte("HELO")) {
p.sServerName = bytes.TrimSpace(bytes.Replace(b, []byte("HELO"), []byte(""), 1))
}
if bytes.Contains(b, []byte("EHLO")) {
p.sServerName = bytes.TrimSpace(bytes.Replace(b, []byte("EHLO"), []byte(""), 1))
}
}

func (p *Pipe) setSenderMailAddress(b []byte) {
if bytes.Contains(b, []byte(mailFromPrefix)) {
p.sMailAddr = bytes.Replace(mailFromRegex.Find(b), []byte(mailFromPrefix), []byte(""), 1)
}
}

func (p *Pipe) setReceiverMailAddressAndServerName(b []byte) {
if bytes.Contains(b, []byte(rcptToPrefix)) {
p.rMailAddr = bytes.Replace(mailToRegex.Find(b), []byte(rcptToPrefix), []byte(""), 1)
p.rServerName = bytes.Split(p.rMailAddr, []byte("@"))[1]
Expand Down Expand Up @@ -177,35 +227,18 @@ func (p *Pipe) copy(dr Flow, fn Mediator) (written int64, err error) {
buf := make([]byte, bufferSize)

for {
var isContinue bool
if p.locked {
continue
}

nr, er := p.src(dr).Read(buf)
if nr > 0 {
buf, nr = fn(buf, nr)
if dr == upstream && p.locked {
p.waitForTLSConn(buf, nr)
}
if nr == 0 {
// Run the Mediator!
buf, nr, isContinue = fn(buf, nr)
if nr == 0 || isContinue {
continue
}
if dr == upstream {
go p.afterCommHook(p.removeMailBody(buf[0:nr]), srcToDst)
} else {
// time before email input
list := bytes.Split(buf, []byte(crlf))
for _, v := range list {
if len(v) >= 3 && string(v[:3]) == fmt.Sprint(codeStartingMailInput) {
p.timeAtDataStarting = time.Now()
}
}
// remove buffering ready response
if bytes.Contains(buf, []byte("Ready to start TLS")) || bytes.Contains(buf, []byte("SMTP server ready")) || bytes.Contains(buf, []byte("Start TLS")) {
continue
}
go p.afterCommHook(buf[0:nr], dstToSrc)
}
nw, ew := p.dst(dr).Write(buf[0:nr])
if nw > 0 {
written += int64(nw)
Expand Down Expand Up @@ -301,7 +334,11 @@ func (p *Pipe) close() func() {
}

func (p *Pipe) isResponseOfEHLOWithStartTLS(b []byte) bool {
return !p.tls && !p.locked && bytes.Contains(b, []byte("STARTTLS"))
return !p.tls && !p.locked && bytes.Contains(b, []byte(fmt.Sprint(codeActionCompleted))) && bytes.Contains(b, []byte("STARTTLS"))
}

func (p *Pipe) isResponseOfEHLOWithoutStartTLS(b []byte) bool {
return !p.tls && !p.locked && bytes.Contains(b, []byte(fmt.Sprint(codeActionCompleted))) && !bytes.Contains(b, []byte("STARTTLS"))
}

func (p *Pipe) isResponseOfReadyToStartTLS(b []byte) bool {
Expand Down
106 changes: 63 additions & 43 deletions pipe_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,73 +5,83 @@ import (
"time"
)

func TestPairing(t *testing.T) {
func TestSetSenderServerName(t *testing.T) {
var tests = []struct {
arg []byte
expectSenderServer []byte
expectSenderAddr []byte
expectReceiverServer []byte
expectReceiverAddr []byte
arg []byte
expectSenderServer []byte
}{
{
arg: []byte("EHLO mx.example.local\r\n"),
expectSenderServer: []byte("mx.example.local"),
},
{
arg: []byte("HELO mx.example.local\r\n"),
expectSenderServer: []byte("mx.example.local"),
},
}
for _, v := range tests {
pipe := &Pipe{afterCommHook: func(b Data, to Direction) {}}
pipe.setSenderServerName(v.arg)
if string(v.expectSenderServer) != string(pipe.sServerName) {
t.Errorf("sender server name expected %s, but got %s", v.expectSenderServer, pipe.sServerName)
}
}
}

func TestSetSenderMailAddress(t *testing.T) {
var tests = []struct {
arg []byte
expectSenderAddr []byte
}{
{
arg: []byte("EHLO mx.example.local\r\n"),
expectSenderServer: []byte("mx.example.local"),
expectSenderAddr: nil,
expectReceiverServer: nil,
expectReceiverAddr: nil,
arg: []byte("MAIL FROM:<[email protected]> SIZE=4095\r\n"),
expectSenderAddr: []byte("[email protected]"),
},
{
arg: []byte("HELO mx.example.local\r\n"),
expectSenderServer: []byte("mx.example.local"),
expectSenderAddr: nil,
expectReceiverServer: nil,
expectReceiverAddr: nil,
// Sender Rewriting Scheme
arg: []byte("MAIL FROM:<SRS0=x/[email protected]> SIZE=4095\r\n"),
expectSenderAddr: []byte("SRS0=x/[email protected]"),
},
{
arg: []byte("MAIL FROM:<[email protected]> SIZE=4095\r\n"),
expectSenderServer: nil,
expectSenderAddr: []byte("[email protected]"),
expectReceiverServer: nil,
expectReceiverAddr: nil,
// Pipelining
arg: []byte("MAIL FROM:<[email protected]> SIZE=4095\r\nRCPT TO:<[email protected]> ORCPT=rfc822;[email protected]\r\nDATA\r\n"),
expectSenderAddr: []byte("[email protected]"),
},
}
for _, v := range tests {
pipe := &Pipe{afterCommHook: func(b Data, to Direction) {}}
pipe.setSenderMailAddress(v.arg)
if string(v.expectSenderAddr) != string(pipe.sMailAddr) {
t.Errorf("sender email address expected %s, but got %s", v.expectSenderAddr, pipe.sMailAddr)
}
}
}

func TestSetReceiverMailAddressAndServerName(t *testing.T) {
var tests = []struct {
arg []byte
expectReceiverServer []byte
expectReceiverAddr []byte
}{
{
arg: []byte("RCPT TO:<[email protected]>\r\n"),
expectSenderServer: nil,
expectSenderAddr: nil,
expectReceiverServer: []byte("example.com"),
expectReceiverAddr: []byte("[email protected]"),
},
{
// Sender Rewriting Scheme
arg: []byte("MAIL FROM:<SRS0=x/[email protected]> SIZE=4095\r\n"),
expectSenderServer: nil,
expectSenderAddr: []byte("SRS0=x/[email protected]"),
expectReceiverServer: nil,
expectReceiverAddr: nil,
},
{
// Pipelining
arg: []byte("MAIL FROM:<[email protected]> SIZE=4095\r\nRCPT TO:<[email protected]> ORCPT=rfc822;[email protected]\r\nDATA\r\n"),
expectSenderServer: nil,
expectSenderAddr: []byte("[email protected]"),
expectReceiverServer: []byte("example.com"),
expectReceiverAddr: []byte("[email protected]"),
},
}
for _, v := range tests {
pipe := &Pipe{afterCommHook: func(b Data, to Direction) {}}
pipe.pairing(v.arg)

if v.expectSenderServer != nil && string(v.expectSenderServer) != string(pipe.sServerName) {
t.Errorf("sender server name expected %s, but got %s", v.expectSenderServer, pipe.sServerName)
}
if v.expectSenderAddr != nil && string(v.expectSenderAddr) != string(pipe.sMailAddr) {
t.Errorf("sender email address expected %s, but got %s", v.expectSenderAddr, pipe.sMailAddr)
}
if v.expectReceiverServer != nil && string(v.expectReceiverServer) != string(pipe.rServerName) {
pipe.setReceiverMailAddressAndServerName(v.arg)
if string(v.expectReceiverServer) != string(pipe.rServerName) {
t.Errorf("receiver server name expected %s, but got %s", v.expectReceiverServer, pipe.rServerName)
}
if v.expectReceiverAddr != nil && string(v.expectReceiverAddr) != string(pipe.rMailAddr) {
if string(v.expectReceiverAddr) != string(pipe.rMailAddr) {
t.Errorf("receiver email address expected %s, but got %s", v.expectReceiverAddr, pipe.rMailAddr)
}
}
Expand All @@ -87,6 +97,16 @@ func TestIsResponseOfEHLOWithStartTLS(t *testing.T) {
}
}

func TestIsResponseOfEHLOWithoutStartTLS(t *testing.T) {
pipe := &Pipe{
tls: false,
locked: false,
}
if !pipe.isResponseOfEHLOWithoutStartTLS([]byte("250-example.test\r\n250-PIPELINING\r\n250-8BITMIME\r\n250 SIZE 41943040\r\n")) {
t.Errorf("expected true, but got false")
}
}

func TestIsResponseOfReadyToStartTLS(t *testing.T) {
pipe := &Pipe{
tls: false,
Expand Down

0 comments on commit 9ac13d1

Please sign in to comment.