Skip to content

Commit

Permalink
close connections in case of write errors (#613)
Browse files Browse the repository at this point in the history
  • Loading branch information
aler9 committed Dec 14, 2024
1 parent a2df9d8 commit 8172f2c
Show file tree
Hide file tree
Showing 11 changed files with 241 additions and 343 deletions.
37 changes: 24 additions & 13 deletions async_processor.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,46 +4,57 @@ import (
"github.com/bluenviron/gortsplib/v4/pkg/ringbuffer"
)

// this struct contains a queue that allows to detach the routine that is reading a stream
// this is an asynchronous queue processor
// that allows to detach the routine that is reading a stream
// from the routine that is writing a stream.
type asyncProcessor struct {
bufferSize int

running bool
buffer *ringbuffer.RingBuffer

done chan struct{}
chError chan error
}

func (w *asyncProcessor) allocateBuffer(size int) {
w.buffer, _ = ringbuffer.New(uint64(size))
func (w *asyncProcessor) initialize() {
w.buffer, _ = ringbuffer.New(uint64(w.bufferSize))
}

func (w *asyncProcessor) start() {
w.running = true
w.done = make(chan struct{})
w.chError = make(chan error)
go w.run()
}

func (w *asyncProcessor) stop() {
if w.running {
w.buffer.Close()
<-w.done
w.running = false
if !w.running {
panic("should not happen")
}
w.buffer.Close()
<-w.chError
w.running = false
}

func (w *asyncProcessor) run() {
defer close(w.done)
err := w.runInner()
w.chError <- err
close(w.chError)
}

func (w *asyncProcessor) runInner() error {
for {
tmp, ok := w.buffer.Pull()
if !ok {
return
return nil
}

tmp.(func())()
err := tmp.(func() error)()
if err != nil {
return err
}
}
}

func (w *asyncProcessor) push(cb func()) bool {
func (w *asyncProcessor) push(cb func() error) bool {
return w.buffer.Push(cb)
}
127 changes: 70 additions & 57 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -335,22 +335,19 @@ type Client struct {
keepalivePeriod time.Duration
keepaliveTimer *time.Timer
closeError error
writer asyncProcessor
writer *asyncProcessor
reader *clientReader
timeDecoder *rtptime.GlobalDecoder2
mustClose bool

// in
chOptions chan optionsReq
chDescribe chan describeReq
chAnnounce chan announceReq
chSetup chan setupReq
chPlay chan playReq
chRecord chan recordReq
chPause chan pauseReq
chReadError chan error
chReadResponse chan *base.Response
chReadRequest chan *base.Request
chOptions chan optionsReq
chDescribe chan describeReq
chAnnounce chan announceReq
chSetup chan setupReq
chPlay chan playReq
chRecord chan recordReq
chPause chan pauseReq

// out
done chan struct{}
Expand Down Expand Up @@ -462,9 +459,6 @@ func (c *Client) Start(scheme string, host string) error {
c.chPlay = make(chan playReq)
c.chRecord = make(chan recordReq)
c.chPause = make(chan pauseReq)
c.chReadError = make(chan error)
c.chReadResponse = make(chan *base.Response)
c.chReadRequest = make(chan *base.Request)
c.done = make(chan struct{})

go c.run()
Expand Down Expand Up @@ -530,6 +524,34 @@ func (c *Client) run() {

func (c *Client) runInner() error {
for {
chReaderResponse := func() chan *base.Response {
if c.reader != nil {
return c.reader.chResponse
}
return nil
}()

chReaderRequest := func() chan *base.Request {
if c.reader != nil {
return c.reader.chRequest
}
return nil
}()

chReaderError := func() chan error {
if c.reader != nil {
return c.reader.chError
}
return nil
}()

chWriterError := func() chan error {
if c.writer != nil {
return c.writer.chError
}
return nil
}()

select {
case req := <-c.chOptions:
res, err := c.doOptions(req.url)
Expand Down Expand Up @@ -601,15 +623,18 @@ func (c *Client) runInner() error {
}
c.keepaliveTimer = time.NewTimer(c.keepalivePeriod)

case err := <-c.chReadError:
case err := <-chWriterError:
return err

case err := <-chReaderError:
c.reader = nil
return err

case res := <-c.chReadResponse:
case res := <-chReaderResponse:
c.OnResponse(res)
// these are responses to keepalives, ignore them.

case req := <-c.chReadRequest:
case req := <-chReaderRequest:
err := c.handleServerRequest(req)
if err != nil {
return err
Expand All @@ -630,19 +655,19 @@ func (c *Client) waitResponse(requestCseqStr string) (*base.Response, error) {
case <-t.C:
return nil, liberrors.ErrClientRequestTimedOut{}

case err := <-c.chReadError:
case err := <-c.reader.chError:
c.reader = nil
return nil, err

case res := <-c.chReadResponse:
case res := <-c.reader.chResponse:
c.OnResponse(res)

// accept response if CSeq equals request CSeq, or if CSeq is not present
if cseq, ok := res.Header["CSeq"]; !ok || len(cseq) != 1 || strings.TrimSpace(cseq[0]) == requestCseqStr {
return res, nil
}

case req := <-c.chReadRequest:
case req := <-c.reader.chRequest:
err := c.handleServerRequest(req)
if err != nil {
return nil, err
Expand Down Expand Up @@ -682,8 +707,8 @@ func (c *Client) handleServerRequest(req *base.Request) error {

func (c *Client) doClose() {
if c.state == clientStatePlay || c.state == clientStateRecord {
c.stopWriter()
c.stopReadRoutines()
c.writer.stop()
c.stopTransportRoutines()
}

if c.nconn != nil && c.baseURL != nil {
Expand Down Expand Up @@ -808,15 +833,21 @@ func (c *Client) trySwitchingProtocol2(medi *description.Media, baseURL *base.UR
return c.doSetup(baseURL, medi, 0, 0)
}

func (c *Client) startReadRoutines() {
func (c *Client) startTransportRoutines() {
// allocate writer here because it's needed by RTCP receiver / sender
if c.state == clientStateRecord || c.backChannelSetupped {
c.writer.allocateBuffer(c.WriteQueueSize)
c.writer = &asyncProcessor{
bufferSize: c.WriteQueueSize,
}
c.writer.initialize()
} else {
// when reading, buffer is only used to send RTCP receiver reports,
// that are much smaller than RTP packets and are sent at a fixed interval.
// decrease RAM consumption by allocating less buffers.
c.writer.allocateBuffer(8)
c.writer = &asyncProcessor{
bufferSize: 8,
}
c.writer.initialize()
}

c.timeDecoder = rtptime.NewGlobalDecoder2()
Expand Down Expand Up @@ -848,7 +879,7 @@ func (c *Client) startReadRoutines() {
}
}

func (c *Client) stopReadRoutines() {
func (c *Client) stopTransportRoutines() {
if c.reader != nil {
c.reader.setAllowInterleavedFrames(false)
}
Expand All @@ -861,14 +892,8 @@ func (c *Client) stopReadRoutines() {
}

c.timeDecoder = nil
}

func (c *Client) startWriter() {
c.writer.start()
}

func (c *Client) stopWriter() {
c.writer.stop()
c.writer = nil
}

func (c *Client) connOpen() error {
Expand Down Expand Up @@ -1637,7 +1662,7 @@ func (c *Client) doPlay(ra *headers.Range) (*base.Response, error) {
}

c.state = clientStatePlay
c.startReadRoutines()
c.startTransportRoutines()

// Range is mandatory in Parrot Streaming Server
if ra == nil {
Expand All @@ -1662,13 +1687,13 @@ func (c *Client) doPlay(ra *headers.Range) (*base.Response, error) {
Header: header,
}, false)
if err != nil {
c.stopReadRoutines()
c.stopTransportRoutines()
c.state = clientStatePrePlay
return nil, err
}

if res.StatusCode != base.StatusOK {
c.stopReadRoutines()
c.stopTransportRoutines()
c.state = clientStatePrePlay
return nil, liberrors.ErrClientBadStatusCode{
Code: res.StatusCode, Message: res.StatusMessage,
Expand All @@ -1689,7 +1714,7 @@ func (c *Client) doPlay(ra *headers.Range) (*base.Response, error) {
}
}

c.startWriter()
c.writer.start()
c.lastRange = ra

return res, nil
Expand Down Expand Up @@ -1718,27 +1743,27 @@ func (c *Client) doRecord() (*base.Response, error) {
}

c.state = clientStateRecord
c.startReadRoutines()
c.startTransportRoutines()

res, err := c.do(&base.Request{
Method: base.Record,
URL: c.baseURL,
}, false)
if err != nil {
c.stopReadRoutines()
c.stopTransportRoutines()
c.state = clientStatePreRecord
return nil, err
}

if res.StatusCode != base.StatusOK {
c.stopReadRoutines()
c.stopTransportRoutines()
c.state = clientStatePreRecord
return nil, liberrors.ErrClientBadStatusCode{
Code: res.StatusCode, Message: res.StatusMessage,
}
}

c.startWriter()
c.writer.start()

return nil, nil
}
Expand Down Expand Up @@ -1766,25 +1791,25 @@ func (c *Client) doPause() (*base.Response, error) {
return nil, err
}

c.stopWriter()
c.writer.stop()

res, err := c.do(&base.Request{
Method: base.Pause,
URL: c.baseURL,
}, false)
if err != nil {
c.startWriter()
c.writer.start()
return nil, err
}

if res.StatusCode != base.StatusOK {
c.startWriter()
c.writer.start()
return nil, liberrors.ErrClientBadStatusCode{
Code: res.StatusCode, Message: res.StatusMessage,
}
}

c.stopReadRoutines()
c.stopTransportRoutines()

switch c.state {
case clientStatePlay:
Expand Down Expand Up @@ -1929,15 +1954,3 @@ func (c *Client) PacketNTP(medi *description.Media, pkt *rtp.Packet) (time.Time,
ct := cm.formats[pkt.PayloadType]
return ct.rtcpReceiver.PacketNTP(pkt.Timestamp)
}

func (c *Client) readResponse(res *base.Response) {
c.chReadResponse <- res
}

func (c *Client) readRequest(req *base.Request) {
c.chReadRequest <- req
}

func (c *Client) readError(err error) {
c.chReadError <- err
}
4 changes: 2 additions & 2 deletions client_format.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,8 +74,8 @@ func (cf *clientFormat) stop() {
func (cf *clientFormat) writePacketRTP(byts []byte, pkt *rtp.Packet, ntp time.Time) error {
cf.rtcpSender.ProcessPacket(pkt, ntp, cf.format.PTSEqualsDTS(pkt))

ok := cf.cm.c.writer.push(func() {
cf.cm.writePacketRTPInQueue(byts)
ok := cf.cm.c.writer.push(func() error {
return cf.cm.writePacketRTPInQueue(byts)
})
if !ok {
return liberrors.ErrClientWriteQueueFull{}
Expand Down
Loading

0 comments on commit 8172f2c

Please sign in to comment.