diff --git a/async_processor.go b/async_processor.go index d63230f8..41da1553 100644 --- a/async_processor.go +++ b/async_processor.go @@ -4,46 +4,57 @@ import ( "github.com/bluenviron/gortsplib/v4/pkg/ringbuffer" ) -// this struct contains a queue that allows to detach the routine that is reading a stream +// this is an asynchronous queue processor +// that allows to detach the routine that is reading a stream // from the routine that is writing a stream. type asyncProcessor struct { + bufferSize int + running bool buffer *ringbuffer.RingBuffer - done chan struct{} + chError chan error } -func (w *asyncProcessor) allocateBuffer(size int) { - w.buffer, _ = ringbuffer.New(uint64(size)) +func (w *asyncProcessor) initialize() { + w.buffer, _ = ringbuffer.New(uint64(w.bufferSize)) } func (w *asyncProcessor) start() { w.running = true - w.done = make(chan struct{}) + w.chError = make(chan error) go w.run() } func (w *asyncProcessor) stop() { - if w.running { - w.buffer.Close() - <-w.done - w.running = false + if !w.running { + panic("should not happen") } + w.buffer.Close() + <-w.chError + w.running = false } func (w *asyncProcessor) run() { - defer close(w.done) + err := w.runInner() + w.chError <- err + close(w.chError) +} +func (w *asyncProcessor) runInner() error { for { tmp, ok := w.buffer.Pull() if !ok { - return + return nil } - tmp.(func())() + err := tmp.(func() error)() + if err != nil { + return err + } } } -func (w *asyncProcessor) push(cb func()) bool { +func (w *asyncProcessor) push(cb func() error) bool { return w.buffer.Push(cb) } diff --git a/client.go b/client.go index 85cb0e8c..424da01e 100644 --- a/client.go +++ b/client.go @@ -335,22 +335,19 @@ type Client struct { keepalivePeriod time.Duration keepaliveTimer *time.Timer closeError error - writer asyncProcessor + writer *asyncProcessor reader *clientReader timeDecoder *rtptime.GlobalDecoder2 mustClose bool // in - chOptions chan optionsReq - chDescribe chan describeReq - chAnnounce chan announceReq - chSetup chan setupReq - chPlay chan playReq - chRecord chan recordReq - chPause chan pauseReq - chReadError chan error - chReadResponse chan *base.Response - chReadRequest chan *base.Request + chOptions chan optionsReq + chDescribe chan describeReq + chAnnounce chan announceReq + chSetup chan setupReq + chPlay chan playReq + chRecord chan recordReq + chPause chan pauseReq // out done chan struct{} @@ -462,9 +459,6 @@ func (c *Client) Start(scheme string, host string) error { c.chPlay = make(chan playReq) c.chRecord = make(chan recordReq) c.chPause = make(chan pauseReq) - c.chReadError = make(chan error) - c.chReadResponse = make(chan *base.Response) - c.chReadRequest = make(chan *base.Request) c.done = make(chan struct{}) go c.run() @@ -530,6 +524,34 @@ func (c *Client) run() { func (c *Client) runInner() error { for { + chReaderResponse := func() chan *base.Response { + if c.reader != nil { + return c.reader.chResponse + } + return nil + }() + + chReaderRequest := func() chan *base.Request { + if c.reader != nil { + return c.reader.chRequest + } + return nil + }() + + chReaderError := func() chan error { + if c.reader != nil { + return c.reader.chError + } + return nil + }() + + chWriterError := func() chan error { + if c.writer != nil { + return c.writer.chError + } + return nil + }() + select { case req := <-c.chOptions: res, err := c.doOptions(req.url) @@ -601,15 +623,18 @@ func (c *Client) runInner() error { } c.keepaliveTimer = time.NewTimer(c.keepalivePeriod) - case err := <-c.chReadError: + case err := <-chWriterError: + return err + + case err := <-chReaderError: c.reader = nil return err - case res := <-c.chReadResponse: + case res := <-chReaderResponse: c.OnResponse(res) // these are responses to keepalives, ignore them. - case req := <-c.chReadRequest: + case req := <-chReaderRequest: err := c.handleServerRequest(req) if err != nil { return err @@ -630,11 +655,11 @@ func (c *Client) waitResponse(requestCseqStr string) (*base.Response, error) { case <-t.C: return nil, liberrors.ErrClientRequestTimedOut{} - case err := <-c.chReadError: + case err := <-c.reader.chError: c.reader = nil return nil, err - case res := <-c.chReadResponse: + case res := <-c.reader.chResponse: c.OnResponse(res) // accept response if CSeq equals request CSeq, or if CSeq is not present @@ -642,7 +667,7 @@ func (c *Client) waitResponse(requestCseqStr string) (*base.Response, error) { return res, nil } - case req := <-c.chReadRequest: + case req := <-c.reader.chRequest: err := c.handleServerRequest(req) if err != nil { return nil, err @@ -682,8 +707,8 @@ func (c *Client) handleServerRequest(req *base.Request) error { func (c *Client) doClose() { if c.state == clientStatePlay || c.state == clientStateRecord { - c.stopWriter() - c.stopReadRoutines() + c.writer.stop() + c.stopTransportRoutines() } if c.nconn != nil && c.baseURL != nil { @@ -808,15 +833,21 @@ func (c *Client) trySwitchingProtocol2(medi *description.Media, baseURL *base.UR return c.doSetup(baseURL, medi, 0, 0) } -func (c *Client) startReadRoutines() { +func (c *Client) startTransportRoutines() { // allocate writer here because it's needed by RTCP receiver / sender if c.state == clientStateRecord || c.backChannelSetupped { - c.writer.allocateBuffer(c.WriteQueueSize) + c.writer = &asyncProcessor{ + bufferSize: c.WriteQueueSize, + } + c.writer.initialize() } else { // when reading, buffer is only used to send RTCP receiver reports, // that are much smaller than RTP packets and are sent at a fixed interval. // decrease RAM consumption by allocating less buffers. - c.writer.allocateBuffer(8) + c.writer = &asyncProcessor{ + bufferSize: 8, + } + c.writer.initialize() } c.timeDecoder = rtptime.NewGlobalDecoder2() @@ -848,7 +879,7 @@ func (c *Client) startReadRoutines() { } } -func (c *Client) stopReadRoutines() { +func (c *Client) stopTransportRoutines() { if c.reader != nil { c.reader.setAllowInterleavedFrames(false) } @@ -861,14 +892,8 @@ func (c *Client) stopReadRoutines() { } c.timeDecoder = nil -} - -func (c *Client) startWriter() { - c.writer.start() -} -func (c *Client) stopWriter() { - c.writer.stop() + c.writer = nil } func (c *Client) connOpen() error { @@ -1637,7 +1662,7 @@ func (c *Client) doPlay(ra *headers.Range) (*base.Response, error) { } c.state = clientStatePlay - c.startReadRoutines() + c.startTransportRoutines() // Range is mandatory in Parrot Streaming Server if ra == nil { @@ -1662,13 +1687,13 @@ func (c *Client) doPlay(ra *headers.Range) (*base.Response, error) { Header: header, }, false) if err != nil { - c.stopReadRoutines() + c.stopTransportRoutines() c.state = clientStatePrePlay return nil, err } if res.StatusCode != base.StatusOK { - c.stopReadRoutines() + c.stopTransportRoutines() c.state = clientStatePrePlay return nil, liberrors.ErrClientBadStatusCode{ Code: res.StatusCode, Message: res.StatusMessage, @@ -1689,7 +1714,7 @@ func (c *Client) doPlay(ra *headers.Range) (*base.Response, error) { } } - c.startWriter() + c.writer.start() c.lastRange = ra return res, nil @@ -1718,27 +1743,27 @@ func (c *Client) doRecord() (*base.Response, error) { } c.state = clientStateRecord - c.startReadRoutines() + c.startTransportRoutines() res, err := c.do(&base.Request{ Method: base.Record, URL: c.baseURL, }, false) if err != nil { - c.stopReadRoutines() + c.stopTransportRoutines() c.state = clientStatePreRecord return nil, err } if res.StatusCode != base.StatusOK { - c.stopReadRoutines() + c.stopTransportRoutines() c.state = clientStatePreRecord return nil, liberrors.ErrClientBadStatusCode{ Code: res.StatusCode, Message: res.StatusMessage, } } - c.startWriter() + c.writer.start() return nil, nil } @@ -1766,25 +1791,25 @@ func (c *Client) doPause() (*base.Response, error) { return nil, err } - c.stopWriter() + c.writer.stop() res, err := c.do(&base.Request{ Method: base.Pause, URL: c.baseURL, }, false) if err != nil { - c.startWriter() + c.writer.start() return nil, err } if res.StatusCode != base.StatusOK { - c.startWriter() + c.writer.start() return nil, liberrors.ErrClientBadStatusCode{ Code: res.StatusCode, Message: res.StatusMessage, } } - c.stopReadRoutines() + c.stopTransportRoutines() switch c.state { case clientStatePlay: @@ -1929,15 +1954,3 @@ func (c *Client) PacketNTP(medi *description.Media, pkt *rtp.Packet) (time.Time, ct := cm.formats[pkt.PayloadType] return ct.rtcpReceiver.PacketNTP(pkt.Timestamp) } - -func (c *Client) readResponse(res *base.Response) { - c.chReadResponse <- res -} - -func (c *Client) readRequest(req *base.Request) { - c.chReadRequest <- req -} - -func (c *Client) readError(err error) { - c.chReadError <- err -} diff --git a/client_format.go b/client_format.go index b5eeb9ac..8695b2c2 100644 --- a/client_format.go +++ b/client_format.go @@ -74,8 +74,8 @@ func (cf *clientFormat) stop() { func (cf *clientFormat) writePacketRTP(byts []byte, pkt *rtp.Packet, ntp time.Time) error { cf.rtcpSender.ProcessPacket(pkt, ntp, cf.format.PTSEqualsDTS(pkt)) - ok := cf.cm.c.writer.push(func() { - cf.cm.writePacketRTPInQueue(byts) + ok := cf.cm.c.writer.push(func() error { + return cf.cm.writePacketRTPInQueue(byts) }) if !ok { return liberrors.ErrClientWriteQueueFull{} diff --git a/client_media.go b/client_media.go index dca30f95..47f279b5 100644 --- a/client_media.go +++ b/client_media.go @@ -25,8 +25,8 @@ type clientMedia struct { tcpRTPFrame *base.InterleavedFrame tcpRTCPFrame *base.InterleavedFrame tcpBuffer []byte - writePacketRTPInQueue func([]byte) - writePacketRTCPInQueue func([]byte) + writePacketRTPInQueue func([]byte) error + writePacketRTCPInQueue func([]byte) error } func (cm *clientMedia) close() { @@ -152,29 +152,29 @@ func (cm *clientMedia) findFormatWithSSRC(ssrc uint32) *clientFormat { return nil } -func (cm *clientMedia) writePacketRTPInQueueUDP(payload []byte) { - cm.udpRTPListener.write(payload) //nolint:errcheck +func (cm *clientMedia) writePacketRTPInQueueUDP(payload []byte) error { + return cm.udpRTPListener.write(payload) } -func (cm *clientMedia) writePacketRTCPInQueueUDP(payload []byte) { - cm.udpRTCPListener.write(payload) //nolint:errcheck +func (cm *clientMedia) writePacketRTCPInQueueUDP(payload []byte) error { + return cm.udpRTCPListener.write(payload) } -func (cm *clientMedia) writePacketRTPInQueueTCP(payload []byte) { +func (cm *clientMedia) writePacketRTPInQueueTCP(payload []byte) error { cm.tcpRTPFrame.Payload = payload cm.c.nconn.SetWriteDeadline(time.Now().Add(cm.c.WriteTimeout)) - cm.c.conn.WriteInterleavedFrame(cm.tcpRTPFrame, cm.tcpBuffer) //nolint:errcheck + return cm.c.conn.WriteInterleavedFrame(cm.tcpRTPFrame, cm.tcpBuffer) } -func (cm *clientMedia) writePacketRTCPInQueueTCP(payload []byte) { +func (cm *clientMedia) writePacketRTCPInQueueTCP(payload []byte) error { cm.tcpRTCPFrame.Payload = payload cm.c.nconn.SetWriteDeadline(time.Now().Add(cm.c.WriteTimeout)) - cm.c.conn.WriteInterleavedFrame(cm.tcpRTCPFrame, cm.tcpBuffer) //nolint:errcheck + return cm.c.conn.WriteInterleavedFrame(cm.tcpRTCPFrame, cm.tcpBuffer) } func (cm *clientMedia) writePacketRTCP(byts []byte) error { - ok := cm.c.writer.push(func() { - cm.writePacketRTCPInQueue(byts) + ok := cm.c.writer.push(func() error { + return cm.writePacketRTCPInQueue(byts) }) if !ok { return liberrors.ErrClientWriteQueueFull{} diff --git a/client_reader.go b/client_reader.go index ca65d519..f64d115c 100644 --- a/client_reader.go +++ b/client_reader.go @@ -12,9 +12,17 @@ type clientReader struct { mutex sync.Mutex allowInterleavedFrames bool + + chResponse chan *base.Response + chRequest chan *base.Request + chError chan error } func (r *clientReader) start() { + r.chResponse = make(chan *base.Response) + r.chRequest = make(chan *base.Request) + r.chError = make(chan error) + go r.run() } @@ -27,18 +35,17 @@ func (r *clientReader) setAllowInterleavedFrames(v bool) { func (r *clientReader) wait() { for { select { - case <-r.c.chReadError: + case <-r.chError: return - case <-r.c.chReadResponse: - case <-r.c.chReadRequest: + case <-r.chResponse: + case <-r.chRequest: } } } func (r *clientReader) run() { - err := r.runInner() - r.c.readError(err) + r.chError <- r.runInner() } func (r *clientReader) runInner() error { @@ -50,10 +57,10 @@ func (r *clientReader) runInner() error { switch what := what.(type) { case *base.Response: - r.c.readResponse(what) + r.chResponse <- what case *base.Request: - r.c.readRequest(what) + r.chRequest <- what case *base.InterleavedFrame: r.mutex.Lock() diff --git a/client_record_test.go b/client_record_test.go index df3e5276..90979712 100644 --- a/client_record_test.go +++ b/client_record_test.go @@ -126,7 +126,7 @@ func readRequestIgnoreFrames(c *conn.Conn) (*base.Request, error) { } } -func TestClientRecordSerial(t *testing.T) { +func TestClientRecord(t *testing.T) { for _, transport := range []string{ "udp", "tcp", @@ -350,7 +350,7 @@ func TestClientRecordSerial(t *testing.T) { } } -func TestClientRecordParallel(t *testing.T) { +func TestClientRecordSocketError(t *testing.T) { for _, transport := range []string{ "udp", "tcp", @@ -446,15 +446,6 @@ func TestClientRecordParallel(t *testing.T) { StatusCode: base.StatusOK, }) require.NoError(t, err2) - - req, err2 = readRequestIgnoreFrames(conn) - require.NoError(t, err2) - require.Equal(t, base.Teardown, req.Method) - - err2 = conn.WriteResponse(&base.Response{ - StatusCode: base.StatusOK, - }) - require.NoError(t, err2) }() c := Client{ @@ -471,9 +462,6 @@ func TestClientRecordParallel(t *testing.T) { }(), } - writerDone := make(chan struct{}) - defer func() { <-writerDone }() - medi := testH264Media medias := []*description.Media{medi} @@ -481,21 +469,15 @@ func TestClientRecordParallel(t *testing.T) { require.NoError(t, err) defer c.Close() - go func() { - defer close(writerDone) - - t := time.NewTicker(50 * time.Millisecond) - defer t.Stop() + ti := time.NewTicker(50 * time.Millisecond) + defer ti.Stop() - for range t.C { - err := c.WritePacketRTP(medi, &testRTPPacket) - if err != nil { - return - } + for range ti.C { + err := c.WritePacketRTP(medi, &testRTPPacket) + if err != nil { + break } - }() - - time.Sleep(1 * time.Second) + } }) } } @@ -645,143 +627,6 @@ func TestClientRecordPauseSerial(t *testing.T) { } } -func TestClientRecordPauseParallel(t *testing.T) { - for _, transport := range []string{ - "udp", - "tcp", - } { - t.Run(transport, func(t *testing.T) { - l, err := net.Listen("tcp", "localhost:8554") - require.NoError(t, err) - defer l.Close() - - serverDone := make(chan struct{}) - defer func() { <-serverDone }() - go func() { - defer close(serverDone) - - nconn, err2 := l.Accept() - require.NoError(t, err2) - defer nconn.Close() - conn := conn.NewConn(nconn) - - req, err2 := conn.ReadRequest() - require.NoError(t, err2) - require.Equal(t, base.Options, req.Method) - - err2 = conn.WriteResponse(&base.Response{ - StatusCode: base.StatusOK, - Header: base.Header{ - "Public": base.HeaderValue{strings.Join([]string{ - string(base.Announce), - string(base.Setup), - string(base.Record), - string(base.Pause), - }, ", ")}, - }, - }) - require.NoError(t, err2) - - req, err2 = conn.ReadRequest() - require.NoError(t, err2) - require.Equal(t, base.Announce, req.Method) - - err2 = conn.WriteResponse(&base.Response{ - StatusCode: base.StatusOK, - }) - require.NoError(t, err2) - - req, err2 = conn.ReadRequest() - require.NoError(t, err2) - require.Equal(t, base.Setup, req.Method) - - var inTH headers.Transport - err2 = inTH.Unmarshal(req.Header["Transport"]) - require.NoError(t, err2) - - th := headers.Transport{ - Delivery: deliveryPtr(headers.TransportDeliveryUnicast), - } - - if transport == "udp" { - th.Protocol = headers.TransportProtocolUDP - th.ServerPorts = &[2]int{34556, 34557} - th.ClientPorts = inTH.ClientPorts - } else { - th.Protocol = headers.TransportProtocolTCP - th.InterleavedIDs = inTH.InterleavedIDs - } - - err2 = conn.WriteResponse(&base.Response{ - StatusCode: base.StatusOK, - Header: base.Header{ - "Transport": th.Marshal(), - }, - }) - require.NoError(t, err2) - - req, err2 = conn.ReadRequest() - require.NoError(t, err2) - require.Equal(t, base.Record, req.Method) - - err2 = conn.WriteResponse(&base.Response{ - StatusCode: base.StatusOK, - }) - require.NoError(t, err2) - - req, err2 = readRequestIgnoreFrames(conn) - require.NoError(t, err2) - require.Equal(t, base.Pause, req.Method) - - err2 = conn.WriteResponse(&base.Response{ - StatusCode: base.StatusOK, - }) - require.NoError(t, err2) - }() - - c := Client{ - Transport: func() *Transport { - if transport == "udp" { - v := TransportUDP - return &v - } - v := TransportTCP - return &v - }(), - } - - medi := testH264Media - medias := []*description.Media{medi} - - err = record(&c, "rtsp://localhost:8554/teststream", medias, nil) - require.NoError(t, err) - - writerDone := make(chan struct{}) - go func() { - defer close(writerDone) - - t := time.NewTicker(50 * time.Millisecond) - defer t.Stop() - - for range t.C { - err2 := c.WritePacketRTP(medi, &testRTPPacket) - if err2 != nil { - return - } - } - }() - - time.Sleep(1 * time.Second) - - _, err = c.Pause() - require.NoError(t, err) - - c.Close() - <-writerDone - }) - } -} - func TestClientRecordAutomaticProtocol(t *testing.T) { l, err := net.Listen("tcp", "localhost:8554") require.NoError(t, err) diff --git a/server_conn.go b/server_conn.go index 5a3ed7ad..1a1cbd66 100644 --- a/server_conn.go +++ b/server_conn.go @@ -63,10 +63,9 @@ type ServerConn struct { bc *bytecounter.ByteCounter conn *conn.Conn session *ServerSession + reader *serverConnReader // in - chReadRequest chan readReq - chReadError chan error chRemoveSession chan *ServerSession // out @@ -84,8 +83,6 @@ func (sc *ServerConn) initialize() { sc.ctx = ctx sc.ctxCancel = ctxCancel sc.remoteAddr = sc.nconn.RemoteAddr().(*net.TCPAddr) - sc.chReadRequest = make(chan readReq) - sc.chReadError = make(chan error) sc.chRemoveSession = make(chan *ServerSession) sc.done = make(chan struct{}) @@ -142,10 +139,10 @@ func (sc *ServerConn) run() { } sc.conn = conn.NewConn(sc.bc) - cr := &serverConnReader{ + sc.reader = &serverConnReader{ sc: sc, } - cr.initialize() + sc.reader.initialize() err := sc.runInner() @@ -153,7 +150,9 @@ func (sc *ServerConn) run() { sc.nconn.Close() - cr.wait() + if sc.reader != nil { + sc.reader.wait() + } if sc.session != nil { sc.session.removeConn(sc) @@ -172,10 +171,11 @@ func (sc *ServerConn) run() { func (sc *ServerConn) runInner() error { for { select { - case req := <-sc.chReadRequest: + case req := <-sc.reader.chRequest: req.res <- sc.handleRequestOuter(req.req) - case err := <-sc.chReadError: + case err := <-sc.reader.chError: + sc.reader = nil return err case ss := <-sc.chRemoveSession: @@ -446,20 +446,3 @@ func (sc *ServerConn) removeSession(ss *ServerSession) { case <-sc.ctx.Done(): } } - -func (sc *ServerConn) readRequest(req readReq) error { - select { - case sc.chReadRequest <- req: - return <-req.res - - case <-sc.ctx.Done(): - return liberrors.ErrServerTerminated{} - } -} - -func (sc *ServerConn) readError(err error) { - select { - case sc.chReadError <- err: - case <-sc.ctx.Done(): - } -} diff --git a/server_conn_reader.go b/server_conn_reader.go index b5e89160..d8deb691 100644 --- a/server_conn_reader.go +++ b/server_conn_reader.go @@ -2,6 +2,7 @@ package gortsplib import ( "errors" + "fmt" "sync/atomic" "time" @@ -25,26 +26,35 @@ func isSwitchReadFuncError(err error) bool { type serverConnReader struct { sc *ServerConn - chReadDone chan struct{} + chRequest chan readReq + chError chan error } func (cr *serverConnReader) initialize() { - cr.chReadDone = make(chan struct{}) + cr.chRequest = make(chan readReq) + cr.chError = make(chan error) go cr.run() } func (cr *serverConnReader) wait() { - <-cr.chReadDone + for { + select { + case <-cr.chError: + return + + case req := <-cr.chRequest: + req.res <- fmt.Errorf("terminated") + } + } } func (cr *serverConnReader) run() { - defer close(cr.chReadDone) - readFunc := cr.readFuncStandard for { err := readFunc() + var eerr switchReadFuncError if errors.As(err, &eerr) { if eerr.tcp { @@ -55,7 +65,7 @@ func (cr *serverConnReader) run() { continue } - cr.sc.readError(err) + cr.chError <- err break } } @@ -74,7 +84,9 @@ func (cr *serverConnReader) readFuncStandard() error { case *base.Request: cres := make(chan error) req := readReq{req: what, res: cres} - err := cr.sc.readRequest(req) + cr.chRequest <- req + + err := <-cres if err != nil { return err } @@ -108,7 +120,9 @@ func (cr *serverConnReader) readFuncTCP() error { case *base.Request: cres := make(chan error) req := readReq{req: what, res: cres} - err := cr.sc.readRequest(req) + cr.chRequest <- req + + err := <-cres if err != nil { return err } diff --git a/server_multicast_writer.go b/server_multicast_writer.go index 74f19f37..26806ee0 100644 --- a/server_multicast_writer.go +++ b/server_multicast_writer.go @@ -11,7 +11,7 @@ type serverMulticastWriter struct { rtpl *serverUDPListener rtcpl *serverUDPListener - writer asyncProcessor + writer *asyncProcessor rtpAddr *net.UDPAddr rtcpAddr *net.UDPAddr } @@ -48,7 +48,10 @@ func (h *serverMulticastWriter) initialize() error { h.rtpAddr = rtpAddr h.rtcpAddr = rtcpAddr - h.writer.allocateBuffer(h.s.WriteQueueSize) + h.writer = &asyncProcessor{ + bufferSize: h.s.WriteQueueSize, + } + h.writer.initialize() h.writer.start() return nil @@ -65,8 +68,8 @@ func (h *serverMulticastWriter) ip() net.IP { } func (h *serverMulticastWriter) writePacketRTP(payload []byte) error { - ok := h.writer.push(func() { - h.rtpl.write(payload, h.rtpAddr) //nolint:errcheck + ok := h.writer.push(func() error { + return h.rtpl.write(payload, h.rtpAddr) }) if !ok { return liberrors.ErrServerWriteQueueFull{} @@ -76,8 +79,8 @@ func (h *serverMulticastWriter) writePacketRTP(payload []byte) error { } func (h *serverMulticastWriter) writePacketRTCP(payload []byte) error { - ok := h.writer.push(func() { - h.rtcpl.write(payload, h.rtcpAddr) //nolint:errcheck + ok := h.writer.push(func() error { + return h.rtcpl.write(payload, h.rtcpAddr) }) if !ok { return liberrors.ErrServerWriteQueueFull{} diff --git a/server_play_test.go b/server_play_test.go index 2bd5072a..8dd1a984 100644 --- a/server_play_test.go +++ b/server_play_test.go @@ -765,7 +765,7 @@ func TestServerPlay(t *testing.T) { var l1 net.PacketConn var l2 net.PacketConn - switch transport { + switch transport { //nolint:dupl case "udp": require.Equal(t, headers.TransportProtocolUDP, th.Protocol) require.Equal(t, headers.TransportDeliveryUnicast, *th.Delivery) @@ -942,6 +942,186 @@ func TestServerPlay(t *testing.T) { } } +func TestServerPlaySocketError(t *testing.T) { + for _, transport := range []string{ + "udp", + "multicast", + "tcp", + "tls", + } { + t.Run(transport, func(t *testing.T) { + var stream *ServerStream + connClosed := make(chan struct{}) + writeDone := make(chan struct{}) + listenIP := multicastCapableIP(t) + + s := &Server{ + Handler: &testServerHandler{ + onConnClose: func(_ *ServerHandlerOnConnCloseCtx) { + close(connClosed) + }, + onDescribe: func(_ *ServerHandlerOnDescribeCtx) (*base.Response, *ServerStream, error) { + return &base.Response{ + StatusCode: base.StatusOK, + }, stream, nil + }, + onSetup: func(_ *ServerHandlerOnSetupCtx) (*base.Response, *ServerStream, error) { + return &base.Response{ + StatusCode: base.StatusOK, + }, stream, nil + }, + onPlay: func(_ *ServerHandlerOnPlayCtx) (*base.Response, error) { + go func() { + defer close(writeDone) + + t := time.NewTicker(50 * time.Millisecond) + defer t.Stop() + + for range t.C { + err := stream.WritePacketRTP(stream.Description().Medias[0], &testRTPPacket) + if err != nil { + return + } + } + }() + + return &base.Response{ + StatusCode: base.StatusOK, + }, nil + }, + }, + RTSPAddress: listenIP + ":8554", + } + + switch transport { + case "udp": + s.UDPRTPAddress = "127.0.0.1:8000" + s.UDPRTCPAddress = "127.0.0.1:8001" + + case "multicast": + s.MulticastIPRange = "224.1.0.0/16" + s.MulticastRTPPort = 8000 + s.MulticastRTCPPort = 8001 + + case "tls": + cert, err := tls.X509KeyPair(serverCert, serverKey) + require.NoError(t, err) + s.TLSConfig = &tls.Config{Certificates: []tls.Certificate{cert}} + } + + err := s.Start() + require.NoError(t, err) + defer s.Close() + + stream = NewServerStream(s, &description.Session{Medias: []*description.Media{testH264Media}}) + + func() { + nconn, err := net.Dial("tcp", listenIP+":8554") + require.NoError(t, err) + defer nconn.Close() + + nconn = func() net.Conn { + if transport == "tls" { + return tls.Client(nconn, &tls.Config{InsecureSkipVerify: true}) + } + return nconn + }() + conn := conn.NewConn(nconn) + + desc := doDescribe(t, conn) + + inTH := &headers.Transport{ + Mode: transportModePtr(headers.TransportModePlay), + } + + switch transport { + case "udp": + v := headers.TransportDeliveryUnicast + inTH.Delivery = &v + inTH.Protocol = headers.TransportProtocolUDP + inTH.ClientPorts = &[2]int{35466, 35467} + + case "multicast": + v := headers.TransportDeliveryMulticast + inTH.Delivery = &v + inTH.Protocol = headers.TransportProtocolUDP + + default: + v := headers.TransportDeliveryUnicast + inTH.Delivery = &v + inTH.Protocol = headers.TransportProtocolTCP + inTH.InterleavedIDs = &[2]int{5, 6} // odd value + } + + res, th := doSetup(t, conn, mediaURL(t, desc.BaseURL, desc.Medias[0]).String(), inTH, "") + + var l1 net.PacketConn + var l2 net.PacketConn + + switch transport { //nolint:dupl + case "udp": + require.Equal(t, headers.TransportProtocolUDP, th.Protocol) + require.Equal(t, headers.TransportDeliveryUnicast, *th.Delivery) + + l1, err = net.ListenPacket("udp", listenIP+":35466") + require.NoError(t, err) + defer l1.Close() + + l2, err = net.ListenPacket("udp", listenIP+":35467") + require.NoError(t, err) + defer l2.Close() + + case "multicast": + require.Equal(t, headers.TransportProtocolUDP, th.Protocol) + require.Equal(t, headers.TransportDeliveryMulticast, *th.Delivery) + + l1, err = net.ListenPacket("udp", "224.0.0.0:"+strconv.FormatInt(int64(th.Ports[0]), 10)) + require.NoError(t, err) + defer l1.Close() + + p := ipv4.NewPacketConn(l1) + + var intfs []net.Interface + intfs, err = net.Interfaces() + require.NoError(t, err) + + for _, intf := range intfs { + err = p.JoinGroup(&intf, &net.UDPAddr{IP: *th.Destination}) + require.NoError(t, err) + } + + l2, err = net.ListenPacket("udp", "224.0.0.0:"+strconv.FormatInt(int64(th.Ports[1]), 10)) + require.NoError(t, err) + defer l2.Close() + + p = ipv4.NewPacketConn(l2) + + intfs, err = net.Interfaces() + require.NoError(t, err) + + for _, intf := range intfs { + err = p.JoinGroup(&intf, &net.UDPAddr{IP: *th.Destination}) + require.NoError(t, err) + } + + default: + require.Equal(t, headers.TransportProtocolTCP, th.Protocol) + require.Equal(t, headers.TransportDeliveryUnicast, *th.Delivery) + } + + session := readSession(t, res) + + doPlay(t, conn, "rtsp://"+listenIP+":8554/teststream", session) + }() + + <-connClosed + + stream.Close() + <-writeDone + }) + } +} + func TestServerPlayDecodeErrors(t *testing.T) { for _, ca := range []struct { proto string diff --git a/server_session.go b/server_session.go index 1ca8807d..5fb49503 100644 --- a/server_session.go +++ b/server_session.go @@ -252,7 +252,7 @@ type ServerSession struct { announcedDesc *description.Session // publish udpLastPacketTime *int64 // publish udpCheckStreamTimer *time.Timer - writer asyncProcessor + writer *asyncProcessor timeDecoder *rtptime.GlobalDecoder2 // in @@ -425,12 +425,14 @@ func (ss *ServerSession) run() { ss.setuppedStream.readerRemove(ss) } - ss.writer.stop() - for _, sm := range ss.setuppedMedias { sm.stop() } + if ss.writer != nil { + ss.writer.stop() + } + ss.s.closeSession(ss) if h, ok := ss.s.Handler.(ServerHandlerOnSessionClose); ok { @@ -443,6 +445,13 @@ func (ss *ServerSession) run() { func (ss *ServerSession) runInner() error { for { + chWriterError := func() chan error { + if ss.writer != nil { + return ss.writer.chError + } + return nil + }() + select { case req := <-ss.chHandleRequest: ss.lastRequestTime = ss.s.timeNow() @@ -539,6 +548,9 @@ func (ss *ServerSession) runInner() error { ss.udpCheckStreamTimer = time.NewTimer(ss.s.checkStreamPeriod) + case err := <-chWriterError: + return err + case <-ss.ctx.Done(): return liberrors.ErrServerTerminated{} } @@ -930,7 +942,10 @@ func (ss *ServerSession) handleRequestInner(sc *ServerConn, req *base.Request) ( // inside the callback. if ss.state != ServerSessionStatePlay && *ss.setuppedTransport != TransportUDPMulticast { - ss.writer.allocateBuffer(ss.s.WriteQueueSize) + ss.writer = &asyncProcessor{ + bufferSize: ss.s.WriteQueueSize, + } + ss.writer.initialize() } res, err := sc.s.Handler.(ServerHandlerOnPlay).OnPlay(&ServerHandlerOnPlayCtx{ @@ -1023,7 +1038,10 @@ func (ss *ServerSession) handleRequestInner(sc *ServerConn, req *base.Request) ( // when recording, writeBuffer is only used to send RTCP receiver reports, // that are much smaller than RTP packets and are sent at a fixed interval. // decrease RAM consumption by allocating less buffers. - ss.writer.allocateBuffer(8) + ss.writer = &asyncProcessor{ + bufferSize: 8, + } + ss.writer.initialize() res, err := ss.s.Handler.(ServerHandlerOnRecord).OnRecord(&ServerHandlerOnRecordCtx{ Session: ss, @@ -1087,45 +1105,48 @@ func (ss *ServerSession) handleRequestInner(sc *ServerConn, req *base.Request) ( return res, err } - if ss.setuppedStream != nil { - ss.setuppedStream.readerSetInactive(ss) - } + if ss.state == ServerSessionStatePlay || ss.state == ServerSessionStateRecord { + if ss.setuppedStream != nil { + ss.setuppedStream.readerSetInactive(ss) + } - ss.writer.stop() + for _, sm := range ss.setuppedMedias { + sm.stop() + } - for _, sm := range ss.setuppedMedias { - sm.stop() - } + ss.writer.stop() + ss.writer = nil - ss.timeDecoder = nil + ss.timeDecoder = nil - switch ss.state { - case ServerSessionStatePlay: - ss.state = ServerSessionStatePrePlay + switch ss.state { + case ServerSessionStatePlay: + ss.state = ServerSessionStatePrePlay - switch *ss.setuppedTransport { - case TransportUDP: - ss.udpCheckStreamTimer = emptyTimer() + switch *ss.setuppedTransport { + case TransportUDP: + ss.udpCheckStreamTimer = emptyTimer() - case TransportUDPMulticast: - ss.udpCheckStreamTimer = emptyTimer() + case TransportUDPMulticast: + ss.udpCheckStreamTimer = emptyTimer() - default: // TCP - err = switchReadFuncError{false} - ss.tcpConn = nil - } + default: // TCP + err = switchReadFuncError{false} + ss.tcpConn = nil + } - case ServerSessionStateRecord: - switch *ss.setuppedTransport { - case TransportUDP: - ss.udpCheckStreamTimer = emptyTimer() + case ServerSessionStateRecord: + switch *ss.setuppedTransport { + case TransportUDP: + ss.udpCheckStreamTimer = emptyTimer() - default: // TCP - err = switchReadFuncError{false} - ss.tcpConn = nil - } + default: // TCP + err = switchReadFuncError{false} + ss.tcpConn = nil + } - ss.state = ServerSessionStatePreRecord + ss.state = ServerSessionStatePreRecord + } } return res, err diff --git a/server_session_media.go b/server_session_media.go index 2ee6369f..5e74ee9d 100644 --- a/server_session_media.go +++ b/server_session_media.go @@ -27,8 +27,8 @@ type serverSessionMedia struct { tcpRTCPFrame *base.InterleavedFrame tcpBuffer []byte formats map[uint8]*serverSessionFormat // record only - writePacketRTPInQueue func([]byte) - writePacketRTCPInQueue func([]byte) + writePacketRTPInQueue func([]byte) error + writePacketRTCPInQueue func([]byte) error } func (sm *serverSessionMedia) initialize() { @@ -115,33 +115,33 @@ func (sm *serverSessionMedia) findFormatWithSSRC(ssrc uint32) *serverSessionForm return nil } -func (sm *serverSessionMedia) writePacketRTPInQueueUDP(payload []byte) { +func (sm *serverSessionMedia) writePacketRTPInQueueUDP(payload []byte) error { atomic.AddUint64(sm.ss.bytesSent, uint64(len(payload))) - sm.ss.s.udpRTPListener.write(payload, sm.udpRTPWriteAddr) //nolint:errcheck + return sm.ss.s.udpRTPListener.write(payload, sm.udpRTPWriteAddr) } -func (sm *serverSessionMedia) writePacketRTCPInQueueUDP(payload []byte) { +func (sm *serverSessionMedia) writePacketRTCPInQueueUDP(payload []byte) error { atomic.AddUint64(sm.ss.bytesSent, uint64(len(payload))) - sm.ss.s.udpRTCPListener.write(payload, sm.udpRTCPWriteAddr) //nolint:errcheck + return sm.ss.s.udpRTCPListener.write(payload, sm.udpRTCPWriteAddr) } -func (sm *serverSessionMedia) writePacketRTPInQueueTCP(payload []byte) { +func (sm *serverSessionMedia) writePacketRTPInQueueTCP(payload []byte) error { atomic.AddUint64(sm.ss.bytesSent, uint64(len(payload))) sm.tcpRTPFrame.Payload = payload sm.ss.tcpConn.nconn.SetWriteDeadline(time.Now().Add(sm.ss.s.WriteTimeout)) - sm.ss.tcpConn.conn.WriteInterleavedFrame(sm.tcpRTPFrame, sm.tcpBuffer) //nolint:errcheck + return sm.ss.tcpConn.conn.WriteInterleavedFrame(sm.tcpRTPFrame, sm.tcpBuffer) } -func (sm *serverSessionMedia) writePacketRTCPInQueueTCP(payload []byte) { +func (sm *serverSessionMedia) writePacketRTCPInQueueTCP(payload []byte) error { atomic.AddUint64(sm.ss.bytesSent, uint64(len(payload))) sm.tcpRTCPFrame.Payload = payload sm.ss.tcpConn.nconn.SetWriteDeadline(time.Now().Add(sm.ss.s.WriteTimeout)) - sm.ss.tcpConn.conn.WriteInterleavedFrame(sm.tcpRTCPFrame, sm.tcpBuffer) //nolint:errcheck + return sm.ss.tcpConn.conn.WriteInterleavedFrame(sm.tcpRTCPFrame, sm.tcpBuffer) } func (sm *serverSessionMedia) writePacketRTP(payload []byte) error { - ok := sm.ss.writer.push(func() { - sm.writePacketRTPInQueue(payload) + ok := sm.ss.writer.push(func() error { + return sm.writePacketRTPInQueue(payload) }) if !ok { return liberrors.ErrServerWriteQueueFull{} @@ -151,8 +151,8 @@ func (sm *serverSessionMedia) writePacketRTP(payload []byte) error { } func (sm *serverSessionMedia) writePacketRTCP(payload []byte) error { - ok := sm.ss.writer.push(func() { - sm.writePacketRTCPInQueue(payload) + ok := sm.ss.writer.push(func() error { + return sm.writePacketRTCPInQueue(payload) }) if !ok { return liberrors.ErrServerWriteQueueFull{}