Skip to content

Commit

Permalink
Adding a timeout to fix infinite looping on corrupt bdb database files.
Browse files Browse the repository at this point in the history
  • Loading branch information
erikvarga committed Feb 8, 2024
1 parent a8af76a commit 9d196a3
Show file tree
Hide file tree
Showing 7 changed files with 45 additions and 7 deletions.
4 changes: 3 additions & 1 deletion pkg/bdb/bdb.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package bdb

import (
"context"
"io"
"os"

Expand Down Expand Up @@ -61,7 +62,7 @@ func (db *BerkeleyDB) Close() error {
return db.file.Close()
}

func (db *BerkeleyDB) Read() <-chan dbi.Entry {
func (db *BerkeleyDB) Read(ctx context.Context) <-chan dbi.Entry {
entries := make(chan dbi.Entry)

go func() {
Expand Down Expand Up @@ -118,6 +119,7 @@ func (db *BerkeleyDB) Read() <-chan dbi.Entry {

// Traverse the page to concatenate the data that may span multiple pages.
valueContent, err := HashPageValueContent(
ctx,
db.file,
pageData,
hashPageIndex,
Expand Down
9 changes: 8 additions & 1 deletion pkg/bdb/hash_page.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package bdb

import (
"bytes"
"context"
"encoding/binary"
"io"
"os"
Expand Down Expand Up @@ -31,7 +32,7 @@ func ParseHashPage(data []byte, swapped bool) (*HashPage, error) {
return &hashPage, nil
}

func HashPageValueContent(db *os.File, pageData []byte, hashPageIndex uint16, pageSize uint32, swapped bool) ([]byte, error) {
func HashPageValueContent(ctx context.Context, db *os.File, pageData []byte, hashPageIndex uint16, pageSize uint32, swapped bool) ([]byte, error) {
// the first byte is the page type, so we can peek at it first before parsing further...
valuePageType := pageData[hashPageIndex]

Expand All @@ -50,6 +51,12 @@ func HashPageValueContent(db *os.File, pageData []byte, hashPageIndex uint16, pa
var hashValue []byte

for currentPageNo := entry.PageNo; currentPageNo != 0; {
select {
case <-ctx.Done():
return nil, xerrors.Errorf("timed out parsing hash page")
default:
}

pageStart := pageSize * currentPageNo

_, err := db.Seek(int64(pageStart), io.SeekStart)
Expand Down
4 changes: 3 additions & 1 deletion pkg/db/rpmdbinterface.go
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
package dbi

import "context"

type Entry struct {
Value []byte
Err error
}

type RpmDBInterface interface {
Read() <-chan Entry
Read(ctx context.Context) <-chan Entry
Close() error
}
3 changes: 2 additions & 1 deletion pkg/ndb/ndb.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ SOFTWARE.
package ndb

import (
"context"
"encoding/binary"
"fmt"
"io"
Expand Down Expand Up @@ -136,7 +137,7 @@ func (db *RpmNDB) Close() error {
return db.file.Close()
}

func (db *RpmNDB) Read() <-chan dbi.Entry {
func (db *RpmNDB) Read(ctx context.Context) <-chan dbi.Entry {
entries := make(chan dbi.Entry)

go func() {
Expand Down
14 changes: 12 additions & 2 deletions pkg/rpmdb.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package rpmdb

import (
"context"

"github.com/knqyf263/go-rpmdb/pkg/bdb"
dbi "github.com/knqyf263/go-rpmdb/pkg/db"
"github.com/knqyf263/go-rpmdb/pkg/ndb"
Expand Down Expand Up @@ -47,7 +49,11 @@ func (d *RpmDB) Close() error {
}

func (d *RpmDB) Package(name string) (*PackageInfo, error) {
pkgs, err := d.ListPackages()
return d.PackageWithContext(context.TODO(), name)
}

func (d *RpmDB) PackageWithContext(ctx context.Context, name string) (*PackageInfo, error) {
pkgs, err := d.ListPackagesWithContext(ctx)
if err != nil {
return nil, xerrors.Errorf("unable to list packages: %w", err)
}
Expand All @@ -61,9 +67,13 @@ func (d *RpmDB) Package(name string) (*PackageInfo, error) {
}

func (d *RpmDB) ListPackages() ([]*PackageInfo, error) {
return d.ListPackagesWithContext(context.TODO())
}

func (d *RpmDB) ListPackagesWithContext(ctx context.Context) ([]*PackageInfo, error) {
var pkgList []*PackageInfo

for entry := range d.db.Read() {
for entry := range d.db.Read(ctx) {
if entry.Err != nil {
return nil, entry.Err
}
Expand Down
15 changes: 15 additions & 0 deletions pkg/rpmdb_test.go
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
package rpmdb

import (
"context"
"os"
"testing"
"time"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
Expand Down Expand Up @@ -824,3 +826,16 @@ func TestNevra(t *testing.T) {
_, err = pkg.InstalledFiles()
require.Error(t, err)
}

func TestTimeoutPackages(t *testing.T) {
db, err := Open("testdata/centos7-many/Packages")
require.NoError(t, err)
ctxTimesOut, cancelFunc := context.WithTimeout(context.Background(), 1*time.Microsecond)
defer cancelFunc()
_, err = db.ListPackagesWithContext(ctxTimesOut)
if err == nil {
t.Errorf("Expected timeout parsing hash page")
} else {
assert.Equal(t, "timed out parsing hash page", err.Error())
}
}
3 changes: 2 additions & 1 deletion pkg/sqlite3/sqlite3.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package sqlite3

import (
"bytes"
"context"
"database/sql"
"encoding/binary"
"os"
Expand Down Expand Up @@ -44,7 +45,7 @@ func Open(path string) (*SQLite3, error) {
return &SQLite3{db}, nil
}

func (db *SQLite3) Read() <-chan dbi.Entry {
func (db *SQLite3) Read(ctx context.Context) <-chan dbi.Entry {
entries := make(chan dbi.Entry)

go func() {
Expand Down

0 comments on commit 9d196a3

Please sign in to comment.