Skip to content

Commit 8b1a82f

Browse files
committed
Merge pull request #6 from pin/misc-stuff
Misc stuff from George
2 parents ab5b5e0 + 4056ba4 commit 8b1a82f

File tree

6 files changed

+78
-24
lines changed

6 files changed

+78
-24
lines changed

client.go

+2-2
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ func (c Client) Put(filename string, mode string, handler func(w *io.PipeWriter)
8282
handler(writer)
8383
wg.Done()
8484
}()
85-
s.Run(false)
85+
s.run(false)
8686
wg.Wait()
8787
return nil
8888
}
@@ -105,7 +105,7 @@ func (c Client) Get(filename string, mode string, handler func(r *io.PipeReader)
105105
handler(reader)
106106
wg.Done()
107107
}()
108-
r.Run(false)
108+
r.run(false)
109109
wg.Wait()
110110
return fmt.Errorf("Send timeout")
111111
}

packet.go

+3-4
Original file line numberDiff line numberDiff line change
@@ -153,7 +153,7 @@ func (p *ERROR) Pack() []byte {
153153
return buffer.Bytes()
154154
}
155155

156-
func ParsePacket(data []byte) (*Packet, error) {
156+
func ParsePacket(data []byte) (Packet, error) {
157157
var p Packet
158158
opcode := binary.BigEndian.Uint16(data)
159159
switch opcode {
@@ -168,8 +168,7 @@ func ParsePacket(data []byte) (*Packet, error) {
168168
case OP_ERROR:
169169
p = &ERROR{}
170170
default:
171-
return nil, fmt.Errorf("Unknown packet type: %d", opcode)
171+
return nil, fmt.Errorf("unknown opcode: %d", opcode)
172172
}
173-
pp := Packet(p)
174-
return &pp, pp.Unpack(data)
173+
return p, p.Unpack(data)
175174
}

receiver.go

+8-5
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package tftp
22

33
import (
4+
"errors"
45
"fmt"
56
"io"
67
"log"
@@ -17,14 +18,16 @@ type receiver struct {
1718
log *log.Logger
1819
}
1920

20-
func (r *receiver) Run(isServerMode bool) error {
21+
var ErrReceiveTimeout = errors.New("receive timeout")
22+
23+
func (r *receiver) run(serverMode bool) error {
2124
var blockNumber uint16
2225
blockNumber = 1
2326
var buffer []byte
2427
buffer = make([]byte, MAX_DATAGRAM_SIZE)
2528
firstBlock := true
2629
for {
27-
last, e := r.receiveBlock(buffer, blockNumber, firstBlock && !isServerMode)
30+
last, e := r.receiveBlock(buffer, blockNumber, firstBlock && !serverMode)
2831
if e != nil {
2932
if r.log != nil {
3033
r.log.Printf("Error receiving block %d: %v", blockNumber, e)
@@ -69,7 +72,7 @@ func (r *receiver) receiveBlock(b []byte, n uint16, firstBlockOnClient bool) (la
6972
if e != nil {
7073
continue
7174
}
72-
switch p := Packet(*packet).(type) {
75+
switch p := packet.(type) {
7376
case *DATA:
7477
r.log.Printf("got DATA #%d (%d bytes)", p.BlockNumber, len(p.Data))
7578
if n == p.BlockNumber {
@@ -90,7 +93,7 @@ func (r *receiver) receiveBlock(b []byte, n uint16, firstBlockOnClient bool) (la
9093
}
9194
}
9295
}
93-
return false, fmt.Errorf("Receive timeout")
96+
return false, ErrReceiveTimeout
9497
}
9598

9699
func (r *receiver) terminate(b []byte, n uint16, dallying bool) (e error) {
@@ -117,7 +120,7 @@ func (r *receiver) terminate(b []byte, n uint16, dallying bool) (e error) {
117120
if e != nil {
118121
continue
119122
}
120-
switch p := Packet(*packet).(type) {
123+
switch p := packet.(type) {
121124
case *DATA:
122125
r.log.Printf("got DATA #%d (%d bytes)", p.BlockNumber, len(p.Data))
123126
if n == p.BlockNumber {

sender.go

+13-10
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package tftp
22

33
import (
4+
"errors"
45
"fmt"
56
"io"
67
"log"
@@ -17,15 +18,17 @@ type sender struct {
1718
log *log.Logger
1819
}
1920

20-
func (s *sender) Run(isServerMode bool) {
21+
var ErrSendTimeout = errors.New("send timeout")
22+
23+
func (s *sender) run(serverMode bool) {
2124
var buffer, tmp []byte
2225
buffer = make([]byte, BLOCK_SIZE)
2326
tmp = make([]byte, MAX_DATAGRAM_SIZE)
24-
if !isServerMode {
25-
e := s.sendRequest(tmp)
26-
if e != nil {
27-
s.log.Printf("Error starting transmission: %v", e)
28-
s.reader.CloseWithError(e)
27+
if !serverMode {
28+
err := s.sendRequest(tmp)
29+
if err != nil {
30+
s.log.Printf("Error starting transmission: %v", err)
31+
s.reader.CloseWithError(err)
2932
return
3033
}
3134
}
@@ -93,7 +96,7 @@ func (s *sender) sendRequest(tmp []byte) (e error) {
9396
if e != nil {
9497
continue
9598
}
96-
switch p := Packet(*packet).(type) {
99+
switch p := packet.(type) {
97100
case *ACK:
98101
if p.BlockNumber == 0 {
99102
s.log.Printf("got ACK #0")
@@ -105,7 +108,7 @@ func (s *sender) sendRequest(tmp []byte) (e error) {
105108
}
106109
}
107110
}
108-
return fmt.Errorf("Send timeout")
111+
return ErrSendTimeout
109112
}
110113

111114
func (s *sender) sendBlock(b []byte, c int, n uint16, tmp []byte) (e error) {
@@ -128,7 +131,7 @@ func (s *sender) sendBlock(b []byte, c int, n uint16, tmp []byte) (e error) {
128131
if e != nil {
129132
continue
130133
}
131-
switch p := Packet(*packet).(type) {
134+
switch p := packet.(type) {
132135
case *ACK:
133136
s.log.Printf("got ACK #%d", p.BlockNumber)
134137
if n == p.BlockNumber {
@@ -139,5 +142,5 @@ func (s *sender) sendBlock(b []byte, c int, n uint16, tmp []byte) (e error) {
139142
}
140143
}
141144
}
142-
return fmt.Errorf("Send timeout")
145+
return ErrSendTimeout
143146
}

server.go

+3-3
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ func (s Server) processRequest(conn *net.UDPConn) error {
8282
if e != nil {
8383
return nil
8484
}
85-
switch p := Packet(*p).(type) {
85+
switch p := p.(type) {
8686
case *WRQ:
8787
s.Log.Printf("got WRQ (filename=%s, mode=%s)", p.Filename, p.Mode)
8888
trasnmissionConn, e := s.transmissionConn()
@@ -102,7 +102,7 @@ func (s Server) processRequest(conn *net.UDPConn) error {
102102
s.Log.Printf("sent ERROR (code=%d): %s", 1, e.Error())
103103
return e
104104
}
105-
go r.Run(true)
105+
go r.run(true)
106106
case *RRQ:
107107
s.Log.Printf("got RRQ (filename=%s, mode=%s)", p.Filename, p.Mode)
108108
trasnmissionConn, e := s.transmissionConn()
@@ -112,7 +112,7 @@ func (s Server) processRequest(conn *net.UDPConn) error {
112112
reader, writer := io.Pipe()
113113
r := &sender{remoteAddr, trasnmissionConn, reader, p.Filename, p.Mode, s.Log}
114114
go s.WriteHandler(p.Filename, writer)
115-
go r.Run(true)
115+
go r.run(true)
116116
}
117117
return nil
118118
}

tftp_test.go

+49
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,55 @@ func TestPutGet(t *testing.T) {
5757
}
5858
}
5959

60+
func TestTimeout(t *testing.T) {
61+
addr, _ := net.ResolveUDPAddr("udp", "localhost:12322")
62+
63+
log := log.New(os.Stderr, "", log.Ldate|log.Ltime)
64+
65+
writeHandler := func(filename string, r *io.PipeReader) {
66+
buf := make([]byte, 64)
67+
for i := 0; i < 5; i++ {
68+
_, err := r.Read(buf)
69+
if err != nil {
70+
panic(err)
71+
}
72+
}
73+
// server "fail" during receive
74+
}
75+
76+
readHandler := func(filename string, w *io.PipeWriter) {
77+
for i := 0; i < 5; i++ {
78+
_, err := w.Write(randomByteArray(64))
79+
if err != nil {
80+
panic(err)
81+
}
82+
}
83+
// server "fail" during send
84+
}
85+
86+
s = &Server{addr, writeHandler, readHandler, log}
87+
go s.Serve()
88+
89+
c = &Client{addr, log}
90+
91+
var err error
92+
c.Put("test", "octet", func(writer *io.PipeWriter) {
93+
_, err = writer.Write(randomByteArray(5000))
94+
writer.Close()
95+
})
96+
if err != ErrSendTimeout {
97+
t.Fatalf("Send timeout expected, got %v", err)
98+
}
99+
100+
buf := new(bytes.Buffer)
101+
c.Get("test", "octet", func(reader *io.PipeReader) {
102+
_, err = buf.ReadFrom(reader)
103+
})
104+
if err != ErrReceiveTimeout {
105+
t.Fatalf("Receive timeout expected, got %v", err)
106+
}
107+
}
108+
60109
func randomByteArray(n int) []byte {
61110
bs := make([]byte, n)
62111
for i := 0; i < n; i++ {

0 commit comments

Comments
 (0)