Skip to content

Commit

Permalink
Refactor Packet struct to match RFC format
Browse files Browse the repository at this point in the history
Remove PayloadOffset from Header.
Don't keep unnecessary full data bytes inside Packet.
Packet Marshal and Unmarshal method API stay the same.

Fixes pion#90
  • Loading branch information
ffmiruz authored and aler9 committed May 16, 2021
1 parent 4e87540 commit b7e6b3a
Show file tree
Hide file tree
Showing 4 changed files with 63 additions and 112 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ Check out the **[contributing wiki](https://github.com/pion/webrtc/wiki/Contribu
* [Robin Raymond](https://github.com/robin-raymond)
* [debiandebiandebian](https://github.com/debiandebiandebian)
* [Juliusz Chroboczek](https://github.com/jech)
* [ffmiyo](https://github.com/ffmiyo)

### License
MIT License - see [LICENSE](LICENSE) for full text
130 changes: 61 additions & 69 deletions packet.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,11 @@ type Extension struct {
}

// Header represents an RTP packet header
// NOTE: PayloadOffset is populated by Marshal/Unmarshal and should not be modified
type Header struct {
Version uint8
Padding bool
Extension bool
Marker bool
PayloadOffset int
PayloadType uint8
SequenceNumber uint16
Timestamp uint32
Expand All @@ -30,10 +28,8 @@ type Header struct {
}

// Packet represents an RTP Packet
// NOTE: Raw is populated by Marshal/Unmarshal and should not be modified
type Packet struct {
Header
Raw []byte
Payload []byte
}

Expand Down Expand Up @@ -77,10 +73,11 @@ func (p Packet) String() string {
return out
}

// Unmarshal parses the passed byte slice and stores the result in the Header this method is called upon
func (h *Header) Unmarshal(rawPacket []byte) error { //nolint:gocognit
if len(rawPacket) < headerLength {
return fmt.Errorf("%w: %d < %d", errHeaderSizeInsufficient, len(rawPacket), headerLength)
// Unmarshal parses the passed byte slice and stores the result in the Header.
// It returns the number of bytes read n and any error.
func (h *Header) Unmarshal(buf []byte) (n int, err error) { //nolint:gocognit
if len(buf) < headerLength {
return 0, fmt.Errorf("%w: %d < %d", errHeaderSizeInsufficient, len(buf), headerLength)
}

/*
Expand All @@ -98,124 +95,122 @@ func (h *Header) Unmarshal(rawPacket []byte) error { //nolint:gocognit
* +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
*/

h.Version = rawPacket[0] >> versionShift & versionMask
h.Padding = (rawPacket[0] >> paddingShift & paddingMask) > 0
h.Extension = (rawPacket[0] >> extensionShift & extensionMask) > 0
nCSRC := int(rawPacket[0] & ccMask)
h.Version = buf[0] >> versionShift & versionMask
h.Padding = (buf[0] >> paddingShift & paddingMask) > 0
h.Extension = (buf[0] >> extensionShift & extensionMask) > 0
nCSRC := int(buf[0] & ccMask)
if cap(h.CSRC) < nCSRC || h.CSRC == nil {
h.CSRC = make([]uint32, nCSRC)
} else {
h.CSRC = h.CSRC[:nCSRC]
}

currOffset := csrcOffset + (nCSRC * csrcLength)
if len(rawPacket) < currOffset {
return fmt.Errorf("size %d < %d: %w", len(rawPacket), currOffset, errHeaderSizeInsufficient)
n = csrcOffset + (nCSRC * csrcLength)
if len(buf) < n {
return n, fmt.Errorf("size %d < %d: %w", len(buf), n,
errHeaderSizeInsufficient)
}

h.Marker = (rawPacket[1] >> markerShift & markerMask) > 0
h.PayloadType = rawPacket[1] & ptMask
h.Marker = (buf[1] >> markerShift & markerMask) > 0
h.PayloadType = buf[1] & ptMask

h.SequenceNumber = binary.BigEndian.Uint16(rawPacket[seqNumOffset : seqNumOffset+seqNumLength])
h.Timestamp = binary.BigEndian.Uint32(rawPacket[timestampOffset : timestampOffset+timestampLength])
h.SSRC = binary.BigEndian.Uint32(rawPacket[ssrcOffset : ssrcOffset+ssrcLength])
h.SequenceNumber = binary.BigEndian.Uint16(buf[seqNumOffset : seqNumOffset+seqNumLength])
h.Timestamp = binary.BigEndian.Uint32(buf[timestampOffset : timestampOffset+timestampLength])
h.SSRC = binary.BigEndian.Uint32(buf[ssrcOffset : ssrcOffset+ssrcLength])

for i := range h.CSRC {
offset := csrcOffset + (i * csrcLength)
h.CSRC[i] = binary.BigEndian.Uint32(rawPacket[offset:])
h.CSRC[i] = binary.BigEndian.Uint32(buf[offset:])
}

if h.Extensions != nil {
h.Extensions = h.Extensions[:0]
}

if h.Extension {
if expected := currOffset + 4; len(rawPacket) < expected {
return fmt.Errorf("size %d < %d: %w",
len(rawPacket), expected,
if expected := n + 4; len(buf) < expected {
return n, fmt.Errorf("size %d < %d: %w",
len(buf), expected,
errHeaderSizeInsufficientForExtension,
)
}

h.ExtensionProfile = binary.BigEndian.Uint16(rawPacket[currOffset:])
currOffset += 2
extensionLength := int(binary.BigEndian.Uint16(rawPacket[currOffset:])) * 4
currOffset += 2
h.ExtensionProfile = binary.BigEndian.Uint16(buf[n:])
n += 2
extensionLength := int(binary.BigEndian.Uint16(buf[n:])) * 4
n += 2

if expected := currOffset + extensionLength; len(rawPacket) < expected {
return fmt.Errorf("size %d < %d: %w",
len(rawPacket), expected,
if expected := n + extensionLength; len(buf) < expected {
return n, fmt.Errorf("size %d < %d: %w",
len(buf), expected,
errHeaderSizeInsufficientForExtension,
)
}

switch h.ExtensionProfile {
// RFC 8285 RTP One Byte Header Extension
case extensionProfileOneByte:
end := currOffset + extensionLength
for currOffset < end {
if rawPacket[currOffset] == 0x00 { // padding
currOffset++
end := n + extensionLength
for n < end {
if buf[n] == 0x00 { // padding
n++
continue
}

extid := rawPacket[currOffset] >> 4
len := int(rawPacket[currOffset]&^0xF0 + 1)
currOffset++
extid := buf[n] >> 4
len := int(buf[n]&^0xF0 + 1)
n++

if extid == extensionIDReserved {
break
}

extension := Extension{id: extid, payload: rawPacket[currOffset : currOffset+len]}
extension := Extension{id: extid, payload: buf[n : n+len]}
h.Extensions = append(h.Extensions, extension)
currOffset += len
n += len
}

// RFC 8285 RTP Two Byte Header Extension
case extensionProfileTwoByte:
end := currOffset + extensionLength
for currOffset < end {
if rawPacket[currOffset] == 0x00 { // padding
currOffset++
end := n + extensionLength
for n < end {
if buf[n] == 0x00 { // padding
n++
continue
}

extid := rawPacket[currOffset]
currOffset++
extid := buf[n]
n++

len := int(rawPacket[currOffset])
currOffset++
len := int(buf[n])
n++

extension := Extension{id: extid, payload: rawPacket[currOffset : currOffset+len]}
extension := Extension{id: extid, payload: buf[n : n+len]}
h.Extensions = append(h.Extensions, extension)
currOffset += len
n += len
}

default: // RFC3550 Extension
if len(rawPacket) < currOffset+extensionLength {
return fmt.Errorf("%w: %d < %d", errHeaderSizeInsufficientForExtension, len(rawPacket), currOffset+extensionLength)
if len(buf) < n+extensionLength {
return n, fmt.Errorf("%w: %d < %d",
errHeaderSizeInsufficientForExtension, len(buf), n+extensionLength)
}

extension := Extension{id: 0, payload: rawPacket[currOffset : currOffset+extensionLength]}
extension := Extension{id: 0, payload: buf[n : n+extensionLength]}
h.Extensions = append(h.Extensions, extension)
currOffset += len(h.Extensions[0].payload)
n += len(h.Extensions[0].payload)
}
}

h.PayloadOffset = currOffset

return nil
return n, nil
}

// Unmarshal parses the passed byte slice and stores the result in the Packet this method is called upon
func (p *Packet) Unmarshal(rawPacket []byte) error {
if err := p.Header.Unmarshal(rawPacket); err != nil {
// Unmarshal parses the passed byte slice and stores the result in the Packet.
func (p *Packet) Unmarshal(buf []byte) error {
n, err := p.Header.Unmarshal(buf)
if err != nil {
return err
}

p.Payload = rawPacket[p.PayloadOffset:]
p.Raw = rawPacket
p.Payload = buf[n:]
return nil
}

Expand All @@ -227,7 +222,6 @@ func (h *Header) Marshal() (buf []byte, err error) {
if err != nil {
return nil, err
}

return buf[:n], nil
}

Expand All @@ -253,7 +247,8 @@ func (h *Header) MarshalTo(buf []byte) (n int, err error) {
return 0, io.ErrShortBuffer
}

// The first byte contains the version, padding bit, extension bit, and csrc size
// The first byte contains the version, padding bit, extension bit,
// and csrc size.
buf[0] = (h.Version << versionShift) | uint8(len(h.CSRC))
if h.Padding {
buf[0] |= 1 << paddingShift
Expand Down Expand Up @@ -324,8 +319,6 @@ func (h *Header) MarshalTo(buf []byte) (n int, err error) {
}
}

h.PayloadOffset = n

return n, nil
}

Expand Down Expand Up @@ -479,7 +472,6 @@ func (p *Packet) MarshalTo(buf []byte) (n int, err error) {
}

m := copy(buf[n:], p.Payload)
p.Raw = buf[:n+m]

return n + m, nil
}
Expand Down
Loading

0 comments on commit b7e6b3a

Please sign in to comment.