diff --git a/client.go b/client.go index e8d4a91..3d0acf3 100644 --- a/client.go +++ b/client.go @@ -82,7 +82,7 @@ func (c Client) Put(filename string, mode string, handler func(w *io.PipeWriter) handler(writer) wg.Done() }() - s.Run(false) + s.run(false) wg.Wait() return nil } @@ -105,7 +105,7 @@ func (c Client) Get(filename string, mode string, handler func(r *io.PipeReader) handler(reader) wg.Done() }() - r.Run(false) + r.run(false) wg.Wait() return fmt.Errorf("Send timeout") } diff --git a/packet.go b/packet.go index 478019b..29d1fa2 100644 --- a/packet.go +++ b/packet.go @@ -153,7 +153,7 @@ func (p *ERROR) Pack() []byte { return buffer.Bytes() } -func ParsePacket(data []byte) (*Packet, error) { +func ParsePacket(data []byte) (Packet, error) { var p Packet opcode := binary.BigEndian.Uint16(data) switch opcode { @@ -168,8 +168,7 @@ func ParsePacket(data []byte) (*Packet, error) { case OP_ERROR: p = &ERROR{} default: - return nil, fmt.Errorf("Unknown packet type: %d", opcode) + return nil, fmt.Errorf("unknown opcode: %d", opcode) } - pp := Packet(p) - return &pp, pp.Unpack(data) + return p, p.Unpack(data) } diff --git a/receiver.go b/receiver.go index adb83b3..821b8a8 100644 --- a/receiver.go +++ b/receiver.go @@ -1,6 +1,7 @@ package tftp import ( + "errors" "fmt" "io" "log" @@ -17,14 +18,16 @@ type receiver struct { log *log.Logger } -func (r *receiver) Run(isServerMode bool) error { +var ErrReceiveTimeout = errors.New("receive timeout") + +func (r *receiver) run(serverMode bool) error { var blockNumber uint16 blockNumber = 1 var buffer []byte buffer = make([]byte, MAX_DATAGRAM_SIZE) firstBlock := true for { - last, e := r.receiveBlock(buffer, blockNumber, firstBlock && !isServerMode) + last, e := r.receiveBlock(buffer, blockNumber, firstBlock && !serverMode) if e != nil { if r.log != nil { r.log.Printf("Error receiving block %d: %v", blockNumber, e) @@ -69,7 +72,7 @@ func (r *receiver) receiveBlock(b []byte, n uint16, firstBlockOnClient bool) (la if e != nil { continue } - switch p := Packet(*packet).(type) { + switch p := packet.(type) { case *DATA: r.log.Printf("got DATA #%d (%d bytes)", p.BlockNumber, len(p.Data)) if n == p.BlockNumber { @@ -90,7 +93,7 @@ func (r *receiver) receiveBlock(b []byte, n uint16, firstBlockOnClient bool) (la } } } - return false, fmt.Errorf("Receive timeout") + return false, ErrReceiveTimeout } func (r *receiver) terminate(b []byte, n uint16, dallying bool) (e error) { @@ -117,7 +120,7 @@ func (r *receiver) terminate(b []byte, n uint16, dallying bool) (e error) { if e != nil { continue } - switch p := Packet(*packet).(type) { + switch p := packet.(type) { case *DATA: r.log.Printf("got DATA #%d (%d bytes)", p.BlockNumber, len(p.Data)) if n == p.BlockNumber { diff --git a/sender.go b/sender.go index 92bb056..850daa7 100644 --- a/sender.go +++ b/sender.go @@ -1,6 +1,7 @@ package tftp import ( + "errors" "fmt" "io" "log" @@ -17,15 +18,17 @@ type sender struct { log *log.Logger } -func (s *sender) Run(isServerMode bool) { +var ErrSendTimeout = errors.New("send timeout") + +func (s *sender) run(serverMode bool) { var buffer, tmp []byte buffer = make([]byte, BLOCK_SIZE) tmp = make([]byte, MAX_DATAGRAM_SIZE) - if !isServerMode { - e := s.sendRequest(tmp) - if e != nil { - s.log.Printf("Error starting transmission: %v", e) - s.reader.CloseWithError(e) + if !serverMode { + err := s.sendRequest(tmp) + if err != nil { + s.log.Printf("Error starting transmission: %v", err) + s.reader.CloseWithError(err) return } } @@ -93,7 +96,7 @@ func (s *sender) sendRequest(tmp []byte) (e error) { if e != nil { continue } - switch p := Packet(*packet).(type) { + switch p := packet.(type) { case *ACK: if p.BlockNumber == 0 { s.log.Printf("got ACK #0") @@ -105,7 +108,7 @@ func (s *sender) sendRequest(tmp []byte) (e error) { } } } - return fmt.Errorf("Send timeout") + return ErrSendTimeout } func (s *sender) sendBlock(b []byte, c int, n uint16, tmp []byte) (e error) { @@ -128,7 +131,7 @@ func (s *sender) sendBlock(b []byte, c int, n uint16, tmp []byte) (e error) { if e != nil { continue } - switch p := Packet(*packet).(type) { + switch p := packet.(type) { case *ACK: s.log.Printf("got ACK #%d", p.BlockNumber) if n == p.BlockNumber { @@ -139,5 +142,5 @@ func (s *sender) sendBlock(b []byte, c int, n uint16, tmp []byte) (e error) { } } } - return fmt.Errorf("Send timeout") + return ErrSendTimeout } diff --git a/server.go b/server.go index d78e102..bc9805c 100644 --- a/server.go +++ b/server.go @@ -82,7 +82,7 @@ func (s Server) processRequest(conn *net.UDPConn) error { if e != nil { return nil } - switch p := Packet(*p).(type) { + switch p := p.(type) { case *WRQ: s.Log.Printf("got WRQ (filename=%s, mode=%s)", p.Filename, p.Mode) trasnmissionConn, e := s.transmissionConn() @@ -102,7 +102,7 @@ func (s Server) processRequest(conn *net.UDPConn) error { s.Log.Printf("sent ERROR (code=%d): %s", 1, e.Error()) return e } - go r.Run(true) + go r.run(true) case *RRQ: s.Log.Printf("got RRQ (filename=%s, mode=%s)", p.Filename, p.Mode) trasnmissionConn, e := s.transmissionConn() @@ -112,7 +112,7 @@ func (s Server) processRequest(conn *net.UDPConn) error { reader, writer := io.Pipe() r := &sender{remoteAddr, trasnmissionConn, reader, p.Filename, p.Mode, s.Log} go s.WriteHandler(p.Filename, writer) - go r.Run(true) + go r.run(true) } return nil } diff --git a/tftp_test.go b/tftp_test.go index 839c11f..a266c99 100644 --- a/tftp_test.go +++ b/tftp_test.go @@ -57,6 +57,55 @@ func TestPutGet(t *testing.T) { } } +func TestTimeout(t *testing.T) { + addr, _ := net.ResolveUDPAddr("udp", "localhost:12322") + + log := log.New(os.Stderr, "", log.Ldate|log.Ltime) + + writeHandler := func(filename string, r *io.PipeReader) { + buf := make([]byte, 64) + for i := 0; i < 5; i++ { + _, err := r.Read(buf) + if err != nil { + panic(err) + } + } + // server "fail" during receive + } + + readHandler := func(filename string, w *io.PipeWriter) { + for i := 0; i < 5; i++ { + _, err := w.Write(randomByteArray(64)) + if err != nil { + panic(err) + } + } + // server "fail" during send + } + + s = &Server{addr, writeHandler, readHandler, log} + go s.Serve() + + c = &Client{addr, log} + + var err error + c.Put("test", "octet", func(writer *io.PipeWriter) { + _, err = writer.Write(randomByteArray(5000)) + writer.Close() + }) + if err != ErrSendTimeout { + t.Fatalf("Send timeout expected, got %v", err) + } + + buf := new(bytes.Buffer) + c.Get("test", "octet", func(reader *io.PipeReader) { + _, err = buf.ReadFrom(reader) + }) + if err != ErrReceiveTimeout { + t.Fatalf("Receive timeout expected, got %v", err) + } +} + func randomByteArray(n int) []byte { bs := make([]byte, n) for i := 0; i < n; i++ {