Skip to content

Commit

Permalink
Merge pull request #302 from wader/decode-cleanup-trymust
Browse files Browse the repository at this point in the history
decode: Cleanup Try<f>/<f> pairs
  • Loading branch information
wader authored Jun 30, 2022
2 parents 53ca5f2 + a6a9713 commit f96637f
Show file tree
Hide file tree
Showing 23 changed files with 56 additions and 56 deletions.
6 changes: 3 additions & 3 deletions format/avro/avro_ocf.go
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ func decodeBlockCodec(d *decode.D, dataSize int64, codec string) *bytes.Buffer {
bb := &bytes.Buffer{}
if codec == "deflate" {
br := d.FieldRawLen("compressed", dataSize*8)
d.MustCopy(bb, flate.NewReader(bitio.NewIOReader(br)))
d.Copy(bb, flate.NewReader(bitio.NewIOReader(br)))
} else if codec == "snappy" {
// Everything but last 4 bytes which are the checksum
n := dataSize - 4
Expand All @@ -110,11 +110,11 @@ func decodeBlockCodec(d *decode.D, dataSize int64, codec string) *bytes.Buffer {
if err != nil {
d.Fatalf("failed decompressing data: %v", err)
}
d.MustCopy(bb, bytes.NewReader(decompressed))
d.Copy(bb, bytes.NewReader(decompressed))

// Check the checksum
crc32W := crc32.NewIEEE()
d.MustCopy(crc32W, bytes.NewReader(bb.Bytes()))
d.Copy(crc32W, bytes.NewReader(bb.Bytes()))
d.FieldU32("crc", d.ValidateUBytes(crc32W.Sum(nil)), scalar.ActualHex)
} else {
// Unknown codec, just dump the compressed data.
Expand Down
2 changes: 1 addition & 1 deletion format/bzip2/bzip2.go
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ func bzip2Decode(d *decode.D, in any) any {
}

blockCRC32W := crc32.NewIEEE()
d.MustCopy(blockCRC32W, bitFlipReader{bitio.NewIOReader(uncompressedBR)})
d.Copy(blockCRC32W, bitFlipReader{bitio.NewIOReader(uncompressedBR)})
blockCRC32N := bits.Reverse32(binary.BigEndian.Uint32(blockCRC32W.Sum(nil)))
_ = blockCRCValue.TryScalarFn(d.ValidateU(uint64(blockCRC32N)))
streamCRCN = blockCRC32N ^ ((streamCRCN << 1) | (streamCRCN >> 31))
Expand Down
2 changes: 1 addition & 1 deletion format/cbor/cbor.go
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ func decodeCBORValue(d *decode.D) any {
return nil
}

buf := d.MustReadAllBits(d.FieldRawLen("value", int64(count)*8))
buf := d.ReadAllBits(d.FieldRawLen("value", int64(count)*8))

return buf
}},
Expand Down
2 changes: 1 addition & 1 deletion format/flac/flac.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ func flacDecode(d *decode.D, in any) any {
frameStreamSamplesBuf := ffo.SamplesBuf[0 : samplesInFrame*uint64(ffo.Channels*ffo.BitsPerSample/8)]
framesNDecodedSamples += ffo.Samples

d.MustCopy(md5Samples, bytes.NewReader(frameStreamSamplesBuf))
d.Copy(md5Samples, bytes.NewReader(frameStreamSamplesBuf))
streamDecodedSamples += ffo.Samples

// reuse buffer if possible
Expand Down
4 changes: 2 additions & 2 deletions format/flac/flac_frame.go
Original file line number Diff line number Diff line change
Expand Up @@ -319,7 +319,7 @@ func frameDecode(d *decode.D, in any) any {
})

headerCRC := &checksum.CRC{Bits: 8, Table: checksum.ATM8Table}
d.MustCopyBits(headerCRC, d.BitBufRange(frameStart, d.Pos()-frameStart))
d.CopyBits(headerCRC, d.BitBufRange(frameStart, d.Pos()-frameStart))
d.FieldU8("crc", d.ValidateUBytes(headerCRC.Sum(nil)), scalar.ActualHex)
})

Expand Down Expand Up @@ -565,7 +565,7 @@ func frameDecode(d *decode.D, in any) any {
d.FieldU("byte_align", d.ByteAlignBits(), d.AssertU(0))
// <16> CRC-16 (polynomial = x^16 + x^15 + x^2 + x^0, initialized with 0) of everything before the crc, back to and including the frame header sync code
footerCRC := &checksum.CRC{Bits: 16, Table: checksum.ANSI16Table}
d.MustCopyBits(footerCRC, d.BitBufRange(frameStart, d.Pos()-frameStart))
d.CopyBits(footerCRC, d.BitBufRange(frameStart, d.Pos()-frameStart))
d.FieldRawLen("footer_crc", 16, d.ValidateBitBuf(footerCRC.Sum(nil)), scalar.RawHex)

streamSamples := len(channelSamples[0])
Expand Down
2 changes: 1 addition & 1 deletion format/flac/flac_streaminfo.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ func streaminfoDecode(d *decode.D, in any) any {
bitsPerSample := d.FieldU5("bits_per_sample", scalar.ActualUAdd(1))
totalSamplesInStream := d.FieldU("total_samples_in_stream", 36)
md5BR := d.FieldRawLen("md5", 16*8, scalar.RawHex)
md5b := d.MustReadAllBits(md5BR)
md5b := d.ReadAllBits(md5BR)

return format.FlacStreaminfoOut{
StreamInfo: format.FlacStreamInfo{
Expand Down
2 changes: 1 addition & 1 deletion format/gif/gif.go
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ func gifDecode(d *decode.D, in any) any {
d.FieldU8("terminator")
seenTerminator = true
}
d.MustCopyBits(dataBytes, d.MustCloneReadSeeker(b))
d.CopyBits(dataBytes, d.CloneReadSeeker(b))
})
}
})
Expand Down
2 changes: 1 addition & 1 deletion format/gzip/gzip.go
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ func gzDecode(d *decode.D, in any) any {
d.FieldRawLen("compressed", readCompressedSize)
crc32W := crc32.NewIEEE()
// TODO: cleanup clone
d.MustCopyBits(crc32W, d.MustCloneReadSeeker(uncompressedBR))
d.CopyBits(crc32W, d.CloneReadSeeker(uncompressedBR))
d.FieldU32("crc32", d.ValidateUBytes(crc32W.Sum(nil)), scalar.ActualHex)
d.FieldU32("isize")
}
Expand Down
2 changes: 1 addition & 1 deletion format/id3/id3v2.go
Original file line number Diff line number Diff line change
Expand Up @@ -544,7 +544,7 @@ func decodeFrame(d *decode.D, version int) uint64 {
if unsyncFlag {
// TODO: DecodeFn
// TODO: unknown after frame decode
unsyncedBR := d.MustNewBitBufFromReader(unsyncReader{Reader: bitio.NewIOReader(d.BitBufRange(d.Pos(), int64(dataSize)*8))})
unsyncedBR := d.NewBitBufFromReader(unsyncReader{Reader: bitio.NewIOReader(d.BitBufRange(d.Pos(), int64(dataSize)*8))})
d.FieldFormatBitBuf("unsync", unsyncedBR, decode.FormatFn(func(d *decode.D, in any) any {
if fn, ok := frames[idNormalized]; ok {
fn(d)
Expand Down
4 changes: 2 additions & 2 deletions format/inet/ipv4_packet.go
Original file line number Diff line number Diff line change
Expand Up @@ -94,8 +94,8 @@ func decodeIPv4(d *decode.D, in any) any {
headerEnd := d.Pos()

ipv4Checksum := &checksum.IPv4{}
d.MustCopy(ipv4Checksum, bitio.NewIOReader(d.BitBufRange(0, checksumStart)))
d.MustCopy(ipv4Checksum, bitio.NewIOReader(d.BitBufRange(checksumEnd, headerEnd-checksumEnd)))
d.Copy(ipv4Checksum, bitio.NewIOReader(d.BitBufRange(0, checksumStart)))
d.Copy(ipv4Checksum, bitio.NewIOReader(d.BitBufRange(checksumEnd, headerEnd-checksumEnd)))
_ = d.FieldMustGet("header_checksum").TryScalarFn(d.ValidateUBytes(ipv4Checksum.Sum(nil)), scalar.ActualHex)

dataLen := int64(totalLength-(ihl*4)) * 8
Expand Down
2 changes: 1 addition & 1 deletion format/jpeg/jpeg.go
Original file line number Diff line number Diff line change
Expand Up @@ -306,7 +306,7 @@ func jpegDecode(d *decode.D, in any) any {
// TODO: FieldBitsLen? concat bitbuf?
chunk := d.FieldRawLen("data", d.BitsLeft())
// TODO: redo this? multi reader?
chunkBytes := d.MustReadAllBits(chunk)
chunkBytes := d.ReadAllBits(chunk)

if extendedXMP == nil {
extendedXMP = make([]byte, fullLength)
Expand Down
2 changes: 1 addition & 1 deletion format/mp4/boxes.go
Original file line number Diff line number Diff line change
Expand Up @@ -1086,7 +1086,7 @@ func init() {
d.FieldU24("flags")
systemIDBR := d.FieldRawLen("system_id", 16*8, systemIDNames)
// TODO: make nicer
systemID := d.MustReadAllBits(systemIDBR)
systemID := d.ReadAllBits(systemIDBR)
switch version {
case 0:
case 1:
Expand Down
2 changes: 1 addition & 1 deletion format/mpeg/avc_nalu.go
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ func avcNALUDecode(d *decode.D, in any) any {
d.FieldBool("forbidden_zero_bit")
d.FieldU2("nal_ref_idc")
nalType := d.FieldU5("nal_unit_type", avcNALNames)
unescapedBR := d.MustNewBitBufFromReader(decode.NALUnescapeReader{Reader: bitio.NewIOReader(d.BitBufRange(d.Pos(), d.BitsLeft()))})
unescapedBR := d.NewBitBufFromReader(decode.NALUnescapeReader{Reader: bitio.NewIOReader(d.BitBufRange(d.Pos(), d.BitsLeft()))})

switch nalType {
case avcNALCodedSliceNonIDR,
Expand Down
2 changes: 1 addition & 1 deletion format/mpeg/hevc_nalu.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ func hevcNALUDecode(d *decode.D, in any) any {
nalType := d.FieldU6("nal_unit_type", hevcNALNames)
d.FieldU6("nuh_layer_id")
d.FieldU3("nuh_temporal_id_plus1")
unescapedBR := d.MustNewBitBufFromReader(decode.NALUnescapeReader{Reader: bitio.NewIOReader(d.BitBufRange(d.Pos(), d.BitsLeft()))})
unescapedBR := d.NewBitBufFromReader(decode.NALUnescapeReader{Reader: bitio.NewIOReader(d.BitBufRange(d.Pos(), d.BitsLeft()))})

switch nalType {
case hevcNALNUTVPS:
Expand Down
4 changes: 2 additions & 2 deletions format/mpeg/mp3_frame.go
Original file line number Diff line number Diff line change
Expand Up @@ -381,8 +381,8 @@ func frameDecode(d *decode.D, in any) any {

crcHash := &checksum.CRC{Bits: 16, Current: 0xffff, Table: checksum.ANSI16Table}
// 2 bytes after sync and some other fields + all of side info
d.MustCopyBits(crcHash, d.BitBufRange(2*8, 2*8))
d.MustCopyBits(crcHash, d.BitBufRange(6*8, sideInfoBytes*8))
d.CopyBits(crcHash, d.BitBufRange(2*8, 2*8))
d.CopyBits(crcHash, d.BitBufRange(6*8, sideInfoBytes*8))

if crcValue != nil {
_ = crcValue.TryScalarFn(d.ValidateUBytes(crcHash.Sum(nil)))
Expand Down
2 changes: 1 addition & 1 deletion format/mpeg/mpeg_pes_packet.go
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,7 @@ func pesPacketDecode(d *decode.D, in any) any {

v = subStreamPacket{
number: int(substreamNumber),
buf: d.MustReadAllBits(substreamBR),
buf: d.ReadAllBits(substreamBR),
}
})
})
Expand Down
8 changes: 4 additions & 4 deletions format/ogg/ogg_page.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,17 +44,17 @@ func pageDecode(d *decode.D, in any) any {
})
d.FieldArray("segments", func(d *decode.D) {
for _, ss := range segmentTable {
bs := d.MustReadAllBits(d.FieldRawLen("segment", int64(ss)*8))
bs := d.ReadAllBits(d.FieldRawLen("segment", int64(ss)*8))
p.Segments = append(p.Segments, bs)
}
})
endPos := d.Pos()

pageChecksumValue := d.FieldGet("crc")
pageCRC := &checksum.CRC{Bits: 32, Table: checksum.Poly04c11db7Table}
d.MustCopy(pageCRC, bitio.NewIOReader(d.BitBufRange(startPos, pageChecksumValue.Range.Start-startPos))) // header before checksum
d.MustCopy(pageCRC, bytes.NewReader([]byte{0, 0, 0, 0})) // zero checksum bits
d.MustCopy(pageCRC, bitio.NewIOReader(d.BitBufRange(pageChecksumValue.Range.Stop(), endPos-pageChecksumValue.Range.Stop()))) // rest of page
d.Copy(pageCRC, bitio.NewIOReader(d.BitBufRange(startPos, pageChecksumValue.Range.Start-startPos))) // header before checksum
d.Copy(pageCRC, bytes.NewReader([]byte{0, 0, 0, 0})) // zero checksum bits
d.Copy(pageCRC, bitio.NewIOReader(d.BitBufRange(pageChecksumValue.Range.Stop(), endPos-pageChecksumValue.Range.Stop()))) // rest of page
_ = pageChecksumValue.TryScalarFn(d.ValidateUBytes(pageCRC.Sum(nil)))

return p
Expand Down
2 changes: 1 addition & 1 deletion format/pcap/pcap.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ func decodePcap(d *decode.D, in any) any {
d.Errorf("incl_len %d > orig_len %d", inclLen, origLen)
}

bs := d.MustReadAllBits(d.BitBufRange(d.Pos(), int64(inclLen)*8))
bs := d.ReadAllBits(d.BitBufRange(d.Pos(), int64(inclLen)*8))

if fn, ok := linkToDecodeFn[linkType]; ok {
// TODO: report decode errors
Expand Down
2 changes: 1 addition & 1 deletion format/pcap/pcapng.go
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,7 @@ var blockFns = map[uint64]func(d *decode.D, dc *decodeContext){
capturedLength := d.FieldU32("capture_packet_length")
d.FieldU32("original_packet_length")

bs := d.MustReadAllBits(d.BitBufRange(d.Pos(), int64(capturedLength)*8))
bs := d.ReadAllBits(d.BitBufRange(d.Pos(), int64(capturedLength)*8))

linkType := dc.interfaceTypes[int(interfaceID)]

Expand Down
2 changes: 1 addition & 1 deletion format/png/png.go
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,7 @@ func pngDecode(d *decode.D, in any) any {
})

chunkCRC := crc32.NewIEEE()
d.MustCopy(chunkCRC, bitio.NewIOReader(d.BitBufRange(crcStartPos, d.Pos()-crcStartPos)))
d.Copy(chunkCRC, bitio.NewIOReader(d.BitBufRange(crcStartPos, d.Pos()-crcStartPos)))
d.FieldU32("crc", d.ValidateUBytes(chunkCRC.Sum(nil)), scalar.ActualHex)
})

Expand Down
2 changes: 1 addition & 1 deletion format/rtmp/rtmp.go
Original file line number Diff line number Diff line change
Expand Up @@ -459,7 +459,7 @@ func rtmpDecode(d *decode.D, in any) any {
}

if payloadLength > 0 {
d.MustCopyBits(&m.b, d.FieldRawLen("data", payloadLength))
d.CopyBits(&m.b, d.FieldRawLen("data", payloadLength))
}

if m.l == uint64(m.b.Len()) {
Expand Down
42 changes: 21 additions & 21 deletions pkg/decode/decode.go
Original file line number Diff line number Diff line change
Expand Up @@ -210,49 +210,49 @@ func (d *D) FieldDecoder(name string, bitBuf bitio.ReaderAtSeeker, v any) *D {
}
}

func (d *D) CopyBits(w io.Writer, r bitio.Reader) (int64, error) {
func (d *D) TryCopyBits(w io.Writer, r bitio.Reader) (int64, error) {
// TODO: what size? now same as io.Copy
buf := d.SharedReadBuf(32 * 1024)
return bitioextra.CopyBitsBuffer(w, r, buf)
}

func (d *D) MustCopyBits(w io.Writer, r bitio.Reader) int64 {
n, err := d.CopyBits(w, r)
func (d *D) CopyBits(w io.Writer, r bitio.Reader) int64 {
n, err := d.TryCopyBits(w, r)
if err != nil {
d.IOPanic(err, "MustCopy: Copy")
d.IOPanic(err, "CopyBits: Copy")
}
return n
}

func (d *D) Copy(w io.Writer, r io.Reader) (int64, error) {
func (d *D) TryCopy(w io.Writer, r io.Reader) (int64, error) {
// TODO: what size? now same as io.Copy
buf := d.SharedReadBuf(32 * 1024)
return io.CopyBuffer(w, r, buf)
}

func (d *D) MustCopy(w io.Writer, r io.Reader) int64 {
n, err := d.Copy(w, r)
func (d *D) Copy(w io.Writer, r io.Reader) int64 {
n, err := d.TryCopy(w, r)
if err != nil {
d.IOPanic(err, "MustCopy: Copy")
d.IOPanic(err, "Copy")
}
return n
}

func (d *D) MustCloneReadSeeker(br bitio.ReadSeeker) bitio.ReadSeeker {
func (d *D) CloneReadSeeker(br bitio.ReadSeeker) bitio.ReadSeeker {
br, err := bitio.CloneReadSeeker(br)
if err != nil {
d.IOPanic(err, "MustClone")
d.IOPanic(err, "CloneReadSeeker")
}
return br
}

func (d *D) MustNewBitBufFromReader(r io.Reader) bitio.ReaderAtSeeker {
func (d *D) NewBitBufFromReader(r io.Reader) bitio.ReaderAtSeeker {
b := &bytes.Buffer{}
d.MustCopy(b, r)
d.Copy(b, r)
return bitio.NewBitReader(b.Bytes(), -1)
}

func (d *D) ReadAllBits(r bitio.Reader) ([]byte, error) {
func (d *D) TryReadAllBits(r bitio.Reader) ([]byte, error) {
bb := &bytes.Buffer{}
buf := d.SharedReadBuf(32 * 1024)
if _, err := bitioextra.CopyBitsBuffer(bb, r, buf); err != nil {
Expand All @@ -261,8 +261,8 @@ func (d *D) ReadAllBits(r bitio.Reader) ([]byte, error) {
return bb.Bytes(), nil
}

func (d *D) MustReadAllBits(r bitio.Reader) []byte {
buf, err := d.ReadAllBits(r)
func (d *D) ReadAllBits(r bitio.Reader) []byte {
buf, err := d.TryReadAllBits(r)
if err != nil {
d.IOPanic(err, "Bytes ReadAllBytes")
}
Expand Down Expand Up @@ -340,7 +340,7 @@ func (d *D) IOPanic(err error, op string) {
}

// Bits reads nBits bits from buffer
func (d *D) bits(nBits int) (uint64, error) {
func (d *D) TryBits(nBits int) (uint64, error) {
if nBits < 0 || nBits > 64 {
return 0, fmt.Errorf("nBits must be 0-64 (%d)", nBits)
}
Expand All @@ -355,12 +355,12 @@ func (d *D) bits(nBits int) (uint64, error) {
}

// Bits reads nBits bits from buffer
func (d *D) Bits(nBits int) (uint64, error) {
n, err := d.bits(nBits)
func (d *D) Bits(nBits int) uint64 {
n, err := d.TryBits(nBits)
if err != nil {
return 0, err
panic(IOError{Err: err, Op: "Bits", ReadSize: int64(nBits), Pos: d.Pos()})
}
return n, nil
return n
}

func (d *D) PeekBits(nBits int) uint64 {
Expand Down Expand Up @@ -419,7 +419,7 @@ func (d *D) TryPeekBits(nBits int) (uint64, error) {
if err != nil {
return 0, err
}
n, err := d.bits(nBits)
n, err := d.TryBits(nBits)
if _, err := d.bitBuf.SeekBits(start, io.SeekStart); err != nil {
return 0, err
}
Expand Down
12 changes: 6 additions & 6 deletions pkg/decode/read.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ func (d *D) tryUEndian(nBits int, endian Endian) (uint64, error) {
if nBits < 0 {
return 0, fmt.Errorf("tryUEndian nBits must be >= 0 (%d)", nBits)
}
n, err := d.bits(nBits)
n, err := d.TryBits(nBits)
if err != nil {
return 0, err
}
Expand Down Expand Up @@ -88,7 +88,7 @@ func (d *D) tryFEndian(nBits int, endian Endian) (float64, error) {
if nBits < 0 {
return 0, fmt.Errorf("tryFEndian nBits must be >= 0 (%d)", nBits)
}
n, err := d.bits(nBits)
n, err := d.TryBits(nBits)
if err != nil {
return 0, err
}
Expand All @@ -111,7 +111,7 @@ func (d *D) tryFPEndian(nBits int, fBits int, endian Endian) (float64, error) {
if nBits < 0 {
return 0, fmt.Errorf("tryFPEndian nBits must be >= 0 (%d)", nBits)
}
n, err := d.bits(nBits)
n, err := d.TryBits(nBits)
if err != nil {
return 0, err
}
Expand Down Expand Up @@ -158,7 +158,7 @@ func (d *D) tryTextLenPrefixed(lenBits int, fixedBytes int, e encoding.Encoding)
}

p := d.Pos()
l, err := d.bits(lenBits)
l, err := d.TryBits(lenBits)
if err != nil {
return "", err
}
Expand Down Expand Up @@ -226,7 +226,7 @@ func (d *D) tryUnary(ov uint64) (uint64, error) {
p := d.Pos()
var n uint64
for {
b, err := d.bits(1)
b, err := d.TryBits(1)
if err != nil {
d.SeekAbs(p)
return 0, err
Expand All @@ -240,7 +240,7 @@ func (d *D) tryUnary(ov uint64) (uint64, error) {
}

func (d *D) tryBool() (bool, error) {
n, err := d.bits(1)
n, err := d.TryBits(1)
if err != nil {
return false, err
}
Expand Down

0 comments on commit f96637f

Please sign in to comment.