Skip to content

Commit

Permalink
Implement vector merge operator (#5574)
Browse files Browse the repository at this point in the history
  • Loading branch information
nwt authored Jan 15, 2025
1 parent 876a5a3 commit 7b5a68b
Show file tree
Hide file tree
Showing 5 changed files with 223 additions and 9 deletions.
14 changes: 7 additions & 7 deletions compiler/kernel/vop.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,13 +28,13 @@ func (b *Builder) compileVam(o dag.Op, parents []vector.Puller) ([]vector.Puller
case *dag.Join:
// see sam version for ref
case *dag.Merge:
//e, err := b.compileVamExpr(o.Expr)
//if err != nil {
// return nil, err
//}
//XXX this needs to be native
//cmp := vamexpr.NewComparator(true, o.Order == order.Desc, e).WithMissingAsNull()
//return []vector.Puller{vamop.NewMerge(b.rctx, parents, cmp.Compare)}, nil
b.resetResetters()
e, err := b.compileExpr(o.Expr)
if err != nil {
return nil, err
}
cmp := expr.NewComparator(true, expr.NewSortEvaluator(e, o.Order)).WithMissingAsNull()
return []vector.Puller{vamop.NewMerge(b.rctx, parents, cmp.Compare)}, nil
case *dag.Scatter:
return b.compileVamScatter(o, parents)
case *dag.Scope:
Expand Down
3 changes: 1 addition & 2 deletions compiler/optimizer/parallelize.go
Original file line number Diff line number Diff line change
Expand Up @@ -84,8 +84,7 @@ func (o *Optimizer) parallelizeFileScan(seq dag.Seq, replicas int) (dag.Seq, err
}
if n < len(seq) {
switch seq[n].(type) {
// TODO: Add dag.Sort when the vector runtime implements dag.Merge.
case *dag.Summarize:
case *dag.Sort, *dag.Summarize:
return parallelizeHead(seq, n, outputKeys, replicas), nil
}
}
Expand Down
211 changes: 211 additions & 0 deletions runtime/vam/op/merge.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,211 @@
package op

import (
"container/heap"
"context"
"sync"

"github.com/brimdata/super"
samexpr "github.com/brimdata/super/runtime/sam/expr"
"github.com/brimdata/super/vector"
"github.com/brimdata/super/zcode"
)

// Merge is a vector operator that combines the output of several parent
// pullers into a single stream ordered by cmp.
type Merge struct {
	// ctx is watched for cancellation while signaling parents.
	ctx context.Context
	// cmp orders the current values of the parents held in heap.
	cmp samexpr.CompareFn

	// heap holds the parents that currently have a vector loaded. It is
	// maintained with container/heap through the methods on Merge.
	heap []*mergeParent
	// index caches the identity mapping 0, 1, 2, ... from which
	// createViews slices the index of each view.
	index []uint32
	// once guards starting the parent goroutines on the first Pull.
	once sync.Once
	// parents holds every parent, positioned by tag.
	parents []*mergeParent
}

// NewMerge returns a Merge that pulls from parents and orders their values
// with cmp. Each parent is wrapped in a mergeParent whose tag is the
// parent's position in parents.
func NewMerge(ctx context.Context, parents []vector.Puller, cmp samexpr.CompareFn) *Merge {
	m := &Merge{ctx: ctx, cmp: cmp}
	for tag, puller := range parents {
		wrapped := &mergeParent{
			ctx:      ctx,
			parent:   puller,
			resultCh: make(chan result),
			doneCh:   make(chan struct{}),
			tag:      uint32(tag),
		}
		m.parents = append(m.parents, wrapped)
	}
	return m
}

// Pull returns the next batch of merged values as a dynamic vector whose
// tags identify the parent each value came from. It returns a nil vector
// when done is true or when no parent has a vector loaded; in both cases
// the parents are restarted before returning.
func (m *Merge) Pull(done bool) (vector.Any, error) {
	var startErr error
	m.once.Do(func() {
		// First call: launch one goroutine per parent and load the heap.
		for _, parent := range m.parents {
			go parent.run()
		}
		startErr = m.start()
	})
	if startErr != nil {
		return nil, startErr
	}
	if done || len(m.heap) == 0 {
		// Signal done to every parent still in the heap.
		for _, active := range m.heap {
			select {
			case <-m.ctx.Done():
				return nil, m.ctx.Err()
			case active.doneCh <- struct{}{}:
			}
		}
		// Restart parents and return EOS.
		return nil, m.start()
	}
	const batchSize = 2048
	tags := make([]uint32, 0, batchSize)
	for {
		tag, vecExhausted := m.nextTag()
		tags = append(tags, tag)
		// Views must be captured before updateHeap, which may replace the
		// exhausted parent's vector and reset its offsets.
		var views []vector.Any
		if vecExhausted || len(tags) == batchSize {
			views = m.createViews()
		}
		if err := m.updateHeap(); err != nil {
			return nil, err
		}
		if len(views) > 0 {
			return vector.NewDynamic(tags, views), nil
		}
	}
}

// start empties the heap and then tries to load one vector from every
// parent, pushing each parent that produced a vector onto the heap. It
// stops at the first error.
func (m *Merge) start() error {
	m.heap = m.heap[:0]
	for _, p := range m.parents {
		switch ok, err := p.replenish(); {
		case err != nil:
			return err
		case ok:
			heap.Push(m, p)
		}
	}
	return nil
}

// nextTag consumes one value from the parent at the top of the heap. It
// returns that parent's tag and reports whether the parent's current
// vector has been fully consumed.
func (m *Merge) nextTag() (tag uint32, endOfVector bool) {
	top := m.heap[0]
	top.off++
	tag = top.tag
	endOfVector = top.off >= top.vec.Len()
	return tag, endOfVector
}

// updateHeap restores the heap order after nextTag has advanced the parent
// at the top of the heap.
//
// If that parent still has values in its current vector, its comparison
// value is refreshed and its position fixed. Otherwise the parent is
// replenished: on success its position is fixed using the first value of
// the new vector, and when it has no more vectors it is removed from the
// heap. An error from replenish is returned as is.
func (m *Merge) updateHeap() error {
	top := m.heap[0]
	if top.off < top.vec.Len() {
		top.updateVal()
		heap.Fix(m, 0)
		return nil
	}
	ok, err := top.replenish()
	if err != nil {
		return err
	}
	if !ok {
		// heap.Pop already re-establishes the heap invariant, so no
		// further Fix is needed (and the heap may now be empty).
		heap.Pop(m)
		return nil
	}
	heap.Fix(m, 0)
	return nil
}

// createViews returns one slot per parent, positioned by parent order. The
// slot holds a view of the values consumed from that parent's current
// vector since the last call (offsets lastOff up to off), or stays nil when
// the parent has no vector or nothing new was consumed. Each contributing
// parent's lastOff is advanced to off.
func (m *Merge) createViews() []vector.Any {
	views := make([]vector.Any, len(m.parents))
	for i, p := range m.parents {
		if p.vec == nil || p.off == p.lastOff {
			continue
		}
		// Grow the cached identity index only when it is too short to
		// cover [lastOff, off). An index of exactly p.off entries already
		// suffices, so it must not be reallocated in that case.
		if int(p.off) > len(m.index) {
			m.index = make([]uint32, p.off)
			for j := range m.index {
				m.index[j] = uint32(j)
			}
		}
		views[i] = vector.NewView(p.vec, m.index[p.lastOff:p.off])
		p.lastOff = p.off
	}
	return views
}

// Len, Less, Swap, and Push (together with Pop) let container/heap operate
// on m.heap.

// Len returns the number of parents in the heap.
func (m *Merge) Len() int {
	return len(m.heap)
}

// Less reports whether the current value of heap entry i sorts before that
// of entry j according to cmp.
func (m *Merge) Less(i, j int) bool {
	left, right := m.heap[i], m.heap[j]
	return m.cmp(left.val, right.val) < 0
}

// Swap exchanges heap entries i and j.
func (m *Merge) Swap(i, j int) {
	m.heap[j], m.heap[i] = m.heap[i], m.heap[j]
}

// Push appends x, which must be a *mergeParent, to the heap slice.
func (m *Merge) Push(x any) {
	parent := x.(*mergeParent)
	m.heap = append(m.heap, parent)
}

// Pop removes and returns the last entry of the heap slice.
func (m *Merge) Pop() any {
	last := len(m.heap) - 1
	removed := m.heap[last]
	m.heap = m.heap[:last]
	return removed
}

// mergeParent tracks one parent of a Merge: the goroutine plumbing used to
// pull from it and the read position within its current vector.
type mergeParent struct {
	// ctx is watched for cancellation by run and replenish.
	ctx context.Context
	// parent is the upstream puller.
	parent vector.Puller
	// resultCh carries each pulled result from run to replenish.
	resultCh chan result
	// doneCh tells run to call Pull(true) on parent.
	doneCh chan struct{}
	// tag is this parent's position among the Merge's parents.
	tag uint32

	// vec is the current vector, or nil after EOS or an error.
	vec vector.Any
	// off is the offset in vec of the next value to merge.
	off uint32
	// lastOff is the offset in vec where the next view begins.
	lastOff uint32
	// builder is reused to serialize the value at off.
	builder zcode.Builder
	// val is the value of vec at off, used for heap comparisons.
	val super.Value
}

// run loops forever, pulling from the parent and handing each result to
// replenish over resultCh. If a done signal arrives while a result is
// pending, the result is dropped and Pull(true) is called on the parent; an
// error from that call is then offered downstream in place of the dropped
// result. run returns only when the context is canceled.
func (m *mergeParent) run() {
	for {
		vec, err := m.parent.Pull(false)
		for settled := false; !settled; {
			select {
			case m.resultCh <- result{vec, err}:
				settled = true
			case <-m.doneCh:
				vec, err = m.parent.Pull(true)
				// On error, go around again to send err downstream.
				settled = err == nil
			case <-m.ctx.Done():
				return
			}
		}
	}
}

// replenish waits for the next result from the run goroutine. On a non-nil
// vector with no error it installs the vector, resets both offsets, loads
// the first value, and returns true. On EOS (nil vector) or an error it
// clears the current vector and returns false together with the error, if
// any. If the context is canceled first it returns the context's error.
func (m *mergeParent) replenish() (bool, error) {
	select {
	case <-m.ctx.Done():
		return false, m.ctx.Err()
	case r := <-m.resultCh:
		if r.err != nil || r.vector == nil {
			m.vec = nil
			return false, r.err
		}
		m.vec, m.off, m.lastOff = r.vector, 0, 0
		m.updateVal()
		return true, nil
	}
}

// updateVal sets val to the value of the current vector at off. The type
// comes from TypeOf(off) for a dynamic vector and from Type() otherwise;
// the bytes are serialized into the reused builder.
func (m *mergeParent) updateVal() {
	var typ super.Type
	switch vec := m.vec.(type) {
	case *vector.Dynamic:
		typ = vec.TypeOf(m.off)
	default:
		typ = vec.Type()
	}
	m.builder.Truncate()
	m.vec.Serialize(&m.builder, m.off)
	body := m.builder.Bytes().Body()
	m.val = super.NewValue(typ, body)
}
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ zed: |
=> x > 2
) |> merge x
vector: true

input: |
{x:1}
{x:3}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
zed: fork (=>pass =>pass) |> merge this

vector: true

input: 1 2

output: |
Expand Down

0 comments on commit 7b5a68b

Please sign in to comment.