Skip to content

Commit

Permalink
fixed heap updates
Browse files Browse the repository at this point in the history
  • Loading branch information
travierm committed Apr 22, 2024
1 parent 1b33f2c commit 0381a7d
Show file tree
Hide file tree
Showing 3 changed files with 92 additions and 15 deletions.
31 changes: 25 additions & 6 deletions heap.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ func NewHeap[T any](tableName string, dataDir string) *Heap[T] {
pageSize: pageSize,
pages: make([]*Page, 0),
dataDir: dataDir,
wal: NewWAL(fmt.Sprintf("storage/%s.wal", tableName)),
wal: NewWAL(fmt.Sprintf("%s/%s.wal", dataDir, tableName)),
}
}

Expand Down Expand Up @@ -102,11 +102,23 @@ func (h *Heap[T]) Update(record *Record[T]) error {
Data: dataBytes,
})

// Update the record data size
binary.BigEndian.PutUint32(page.data[offset+8:offset+12], uint32(len(dataBytes)))

// Update the record data
copy(page.data[offset+12:offset+12+int(dataSize)], dataBytes)
// Check if the updated data size is different from the original size
if len(dataBytes) != int(dataSize) {
// Remove the original record
copy(page.data[offset:], page.data[offset+12+int(dataSize):])
page.data = page.data[:len(page.data)-12-int(dataSize)]

// Insert the updated record as a new record
updatedRecordBytes := make([]byte, 12+len(dataBytes))
binary.BigEndian.PutUint64(updatedRecordBytes[0:8], record.ID)
binary.BigEndian.PutUint32(updatedRecordBytes[8:12], uint32(len(dataBytes)))
copy(updatedRecordBytes[12:], dataBytes)

page.data = append(page.data, updatedRecordBytes...)
} else {
// Update the record data in-place
copy(page.data[offset+12:offset+12+int(dataSize)], dataBytes)
}

return nil
}
Expand Down Expand Up @@ -263,6 +275,8 @@ func (h *Heap[T]) Flush() error {
}
}

h.wal.Flush()

return nil
}

Expand Down Expand Up @@ -378,9 +392,14 @@ func (h *Heap[T]) Recover(walFilePath string) error {
} else {
err = h.Update(record)
}

if err != nil {
return err
}
case "DELETE":
err = h.Delete(entry.RecordID)
}

if err != nil {
return err
}
Expand Down
49 changes: 42 additions & 7 deletions heap_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,23 +18,26 @@ type TestRecord struct {
Amount float64
}

func TestHeap(t *testing.T) {
func TestHeapCrud(t *testing.T) {
heap := NewHeap[TestRecord]("products", "storage/test")

err := heap.Insert(&Record[TestRecord]{ID: 1, Data: TestRecord{Name: "Product 1", Amount: 100.0}})
err = heap.Insert(&Record[TestRecord]{ID: 2, Data: TestRecord{Name: "Product 2", Amount: 150.0}})
err = heap.Insert(&Record[TestRecord]{ID: 3, Data: TestRecord{Name: "Product 3", Amount: 150.0}})
err = heap.Insert(&Record[TestRecord]{ID: 3, Data: TestRecord{Name: "Product 3", Amount: 200.0}})
err = heap.Update(&Record[TestRecord]{ID: 2, Data: TestRecord{Name: "Updated Product", Amount: 200.0}})
heap.Flush()

newHeap := NewHeap[TestRecord]("products", "storage/test")
newHeap.Fill()
record, err := newHeap.FindByID(2)
insertedRecord, err := heap.FindByID(1)
if err != nil {
t.Error(err)
}
assert.Equal(t, "Product 1", insertedRecord.Data.Name)

assert.Equal(t, "Product 2", record.Data.Name)
assert.Equal(t, 3, len(heap.wal.entries))
updatedRecord, err := heap.FindByID(2)
if err != nil {
t.Error(err)
}
assert.Equal(t, "Updated Product", updatedRecord.Data.Name)

ClearTestFolder()
}
Expand All @@ -56,6 +59,38 @@ func TestFindByIdLargeDataset(t *testing.T) {
ClearTestFolder()
}

func TestCanRecoverFromWAL(t *testing.T) {
heap := NewHeap[TestRecord]("products", "storage/test")
err := heap.Insert(&Record[TestRecord]{ID: 1, Data: TestRecord{Name: "Product 1", Amount: 100.0}})
err = heap.Insert(&Record[TestRecord]{ID: 2, Data: TestRecord{Name: "Product 2", Amount: 200.0}})
err = heap.Insert(&Record[TestRecord]{ID: 3, Data: TestRecord{Name: "Product 3", Amount: 220.0}})
err = heap.Update(&Record[TestRecord]{ID: 2, Data: TestRecord{Name: "Updated Product", Amount: 120.0}})
err = heap.Delete(3)

if err != nil {
t.Error(err)
}

heap.wal.Flush()

recoveredHeap := NewHeap[TestRecord]("products2", "storage/test")
recoveredHeap.Recover(heap.wal.path)

firstRecord, _ := recoveredHeap.FindByID(1)
updatedRecord, _ := recoveredHeap.FindByID(2)
deletedRecord, _ := recoveredHeap.FindByID(3)

assert.Equal(t, "Product 1", firstRecord.Data.Name)
assert.Equal(t, 100.0, firstRecord.Data.Amount)

assert.Equal(t, "Updated Product", updatedRecord.Data.Name)
assert.Equal(t, 120.0, updatedRecord.Data.Amount)

assert.Nil(t, deletedRecord)

ClearTestFolder()
}

func BenchmarkFindById(b *testing.B) {
// create 100 records
heap := NewHeap[TestRecord]("products", "storage/test")
Expand Down
27 changes: 25 additions & 2 deletions utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,35 @@ package main
import (
"fmt"
"os"
"path/filepath"
)

func ClearTestFolder() {
folder := "storage/test"
err := os.RemoveAll(folder)

// Delete .bin files
binFiles, err := filepath.Glob(filepath.Join(folder, "*.bin"))
if err != nil {
fmt.Printf("Error finding .bin files in '%s': %v\n", folder, err)
return
}
for _, file := range binFiles {
err = os.Remove(file)
if err != nil {
fmt.Printf("Error deleting file '%s': %v\n", file, err)
}
}

// Delete .wal files
walFiles, err := filepath.Glob(filepath.Join(folder, "*.wal"))
if err != nil {
fmt.Printf("Error deleting folder '%s': %v\n", folder, err)
fmt.Printf("Error finding .wal files in '%s': %v\n", folder, err)
return
}
for _, file := range walFiles {
err = os.Remove(file)
if err != nil {
fmt.Printf("Error deleting file '%s': %v\n", file, err)
}
}
}

0 comments on commit 0381a7d

Please sign in to comment.