Skip to content

Commit

Permalink
added after effects to pipeline
Browse files Browse the repository at this point in the history
  • Loading branch information
KyleSmith19091 committed Oct 26, 2024
1 parent 4450132 commit 11ec6f5
Show file tree
Hide file tree
Showing 12 changed files with 262 additions and 105 deletions.
11 changes: 9 additions & 2 deletions etl/mongo/source.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package mongo
import (
"context"

"github.com/meshtrade/mesh-etl/etl/pipeline"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/mongo"
"go.mongodb.org/mongo-driver/mongo/options"
Expand All @@ -21,7 +22,7 @@ func NewMongoCollector[T any](collection mongo.Collection, query bson.D, opts ..
}
}

func (m *MongoCollector[T]) Collect(ctx context.Context) ([]T, error) {
func (m *MongoCollector[T]) Collect(ctx context.Context, pipelineState *pipeline.PipelineState) (chan T, error) {
cursor, err := m.collection.Find(
ctx,
m.query,
Expand All @@ -36,5 +37,11 @@ func (m *MongoCollector[T]) Collect(ctx context.Context) ([]T, error) {
return nil, err
}

return records, nil
outputChannel := make(chan T, len(records))
for _, record := range records {
outputChannel <- record
}
close(outputChannel)

return outputChannel, nil
}
16 changes: 14 additions & 2 deletions etl/parquet/serialiser.go → etl/parquet/stage.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
"github.com/apache/arrow/go/v17/arrow/memory"
"github.com/apache/arrow/go/v17/parquet"
"github.com/apache/arrow/go/v17/parquet/pqarrow"
"github.com/meshtrade/mesh-etl/etl/pipeline"
"github.com/rs/zerolog/log"
)

Expand Down Expand Up @@ -97,7 +98,13 @@ func buildArrowFieldsAndBuilders(pool memory.Allocator, elemType reflect.Type) (
return arrowFields, fieldBuilders, nil
}

func (s *ParquetSerialiser[T]) Marshal(ctx context.Context, inputStruct []T) ([]byte, error) {
func (s *ParquetSerialiser[T]) Serialise(ctx context.Context, p *pipeline.PipelineState, inChannel chan T) (chan []byte, error) {
// collect values from channel
inputStruct := []T{}
for inValue := range inChannel {
inputStruct = append(inputStruct, inValue)
}

// get the reflection value of the input slice
timeType := reflect.TypeOf(time.Time{})

Expand Down Expand Up @@ -166,7 +173,12 @@ func (s *ParquetSerialiser[T]) Marshal(ctx context.Context, inputStruct []T) ([]
// NOTE: NEVER call close in defer function!
pw.Close()

return dataBuffer.Bytes(), nil
// load value into output channel
outputChannel := make(chan []byte, 1)
outputChannel <- dataBuffer.Bytes()
close(outputChannel)

return outputChannel, nil
}

func (s *ParquetSerialiser[T]) appendStructValues(builder *array.StructBuilder, structVal reflect.Value) error {
Expand Down
12 changes: 9 additions & 3 deletions etl/pipeline/pipeline.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ type Pipeline[T, V any] struct {
source source[T]
stage stage[T, V]
sink sink[V]
state *PipelineState
}

func NewPipeline[T, V any](
Expand All @@ -17,19 +18,24 @@ func NewPipeline[T, V any](
source: source,
stage: stage,
sink: sink,
state: NewPipelineState(),
}
}

func (p *Pipeline[T, V]) Execute(ctx context.Context) error {
sourceChannel, err := p.source(ctx)
sourceChannel, err := p.source(ctx, p.state)
if err != nil {
return err
}

stageChannel, err := p.stage(ctx, sourceChannel)
stageChannel, err := p.stage(ctx, p.state, sourceChannel)
if err != nil {
return err
}

return p.sink(ctx, stageChannel)
if err := p.sink(ctx, p.state, stageChannel); err != nil {
return err
}

return p.state.RunAfterEffects(ctx)
}
41 changes: 41 additions & 0 deletions etl/pipeline/pipelineState.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
package pipeline

import (
"context"

"golang.org/x/sync/errgroup"
)

// AfterEffect is an action registered on a PipelineState to be invoked later
// by RunAfterEffects. It receives the context handed to RunAfterEffects and
// reports failure through its returned error.
type AfterEffect func(ctx context.Context) error

// PipelineState holds the after effects that have been registered so far.
//
// The slice is not protected by a lock, so calls to RegisterAfterEffect must
// not overlap with each other or with RunAfterEffects.
type PipelineState struct {
	afterEffects []AfterEffect
}

// NewPipelineState returns a PipelineState with no after effects registered.
func NewPipelineState() *PipelineState {
	// The zero-value (nil) slice is ready to be appended to.
	return &PipelineState{}
}

// RegisterAfterEffect appends afterEffect to the list of registered effects.
//
// A nil afterEffect is ignored: storing it would only postpone a nil-function
// panic to the moment the effect is invoked.
func (p *PipelineState) RegisterAfterEffect(afterEffect AfterEffect) {
	if afterEffect == nil {
		return
	}
	p.afterEffects = append(p.afterEffects, afterEffect)
}

// RunAfterEffects invokes every registered after effect, each in its own
// goroutine of an errgroup.Group, passing ctx to each one, and waits for all
// of them to return.
//
// If the group reports an error, that error is returned and the registered
// effects are left in place. Otherwise the list is reset to empty and nil is
// returned.
func (p *PipelineState) RunAfterEffects(ctx context.Context) error {
	errGroup := new(errgroup.Group)

	for _, afterEffect := range p.afterEffects {
		// Give each goroutine its own copy of the effect. Before Go 1.22 the
		// range variable is shared across iterations, so without the copy the
		// closures could all end up invoking the same (last) effect.
		afterEffect := afterEffect
		errGroup.Go(func() error {
			return afterEffect(ctx)
		})
	}

	if err := errGroup.Wait(); err != nil {
		return err
	}

	// Only reached when every effect succeeded: drop them so they do not run
	// again on the next call.
	p.afterEffects = []AfterEffect{}

	return nil
}
48 changes: 35 additions & 13 deletions etl/pipeline/sink.go
Original file line number Diff line number Diff line change
@@ -1,14 +1,18 @@
package pipeline

import "context"
import (
"context"

type sink[T any] func(context.Context, chan T) error
"golang.org/x/sync/errgroup"
)

func Emit[T any](emitFunc func(context.Context, T) error) sink[T] {
return func(ctx context.Context, inChannel chan T) error {
type sink[T any] func(context.Context, *PipelineState, chan T) error

func Emit[T any](emitFunc func(context.Context, *PipelineState, T) error) sink[T] {
return func(ctx context.Context, p *PipelineState, inChannel chan T) error {
// collect values from channel
for inValue := range inChannel {
if err := emitFunc(ctx, inValue); err != nil {
if err := emitFunc(ctx, p, inValue); err != nil {
return err
}
}
Expand All @@ -18,10 +22,13 @@ func Emit[T any](emitFunc func(context.Context, T) error) sink[T] {
}

func Spread[T any](sinks ...sink[T]) sink[T] {
// create list of channels for sinks
// allocate list of channels for sinks
sinkChannels := make([]chan T, len(sinks))

return func(ctx context.Context, inChannel chan T) error {
// prepare err group
var errGroup = new(errgroup.Group)

return func(ctx context.Context, p *PipelineState, inChannel chan T) error {
// collect input values from channel
inValues := make([]T, len(inChannel))
idx := 0
Expand All @@ -32,19 +39,34 @@ func Spread[T any](sinks ...sink[T]) sink[T] {

// construct sink channel for each sink and load input values
for range sinks {
// allocate sink channel to hold values
sinkChannel := make(chan T, len(inChannel))

// load input values into new sink channel
for _, inValue := range inValues {
sinkChannel <- inValue
}

// close channel to indicate no more data will be sent
close(sinkChannel)

// add channel to list of channels
sinkChannels = append(sinkChannels, sinkChannel)
}

// execute sinks
// execute sinks concurrently
for idx, sink := range sinks {
if err := sink(ctx, sinkChannels[idx]); err != nil {
return err
}
errGroup.Go(func() error {
if err := sink(ctx, p, sinkChannels[idx]); err != nil {
return err
}
return nil
})
}

// wait for sinks
if err := errGroup.Wait(); err != nil {
return err
}

return nil
Expand All @@ -54,9 +76,9 @@ func Spread[T any](sinks ...sink[T]) sink[T] {
func SequenceSink[T any](
sinks ...sink[T],
) sink[T] {
return func(ctx context.Context, c chan T) error {
return func(ctx context.Context, p *PipelineState, c chan T) error {
for _, sink := range sinks {
if err := sink(ctx, c); err != nil {
if err := sink(ctx, p, c); err != nil {
return err
}
}
Expand Down
34 changes: 17 additions & 17 deletions etl/pipeline/source.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,14 @@ import (
"context"
)

type source[T any] func(ctx context.Context) (chan T, error)
type chainedSource[T, V any] func(context.Context, chan T) (chan V, error)
type source[T any] func(context.Context, *PipelineState) (chan T, error)
type chainedSource[T, V any] func(context.Context, *PipelineState, chan T) (chan V, error)

func SourceBatch[T any](sourceFunc func(context.Context) ([]T, error)) source[T] {
func SourceBatch[T any](sourceFunc func(context.Context, *PipelineState) ([]T, error)) source[T] {
// return function to be executed during pipeline execution
return func(ctx context.Context) (chan T, error) {
return func(ctx context.Context, p *PipelineState) (chan T, error) {
// execute source
batch, err := sourceFunc(ctx)
batch, err := sourceFunc(ctx, p)
if err != nil {
return nil, err
}
Expand All @@ -29,16 +29,16 @@ func SourceBatch[T any](sourceFunc func(context.Context) ([]T, error)) source[T]
}
}

func ChainedSourceBatch[T, V any](sourceFunc func(context.Context, []T) ([]V, error)) chainedSource[T, V] {
return func(ctx context.Context, inChannel chan T) (chan V, error) {
func ChainedSourceBatch[T, V any](sourceFunc func(context.Context, *PipelineState, []T) ([]V, error)) chainedSource[T, V] {
return func(ctx context.Context, p *PipelineState, inChannel chan T) (chan V, error) {
inValues := make([]T, len(inChannel))
idx := 0
for inValue := range inChannel {
inValues[idx] = inValue
idx++
}

outValues, err := sourceFunc(ctx, inValues)
outValues, err := sourceFunc(ctx, p, inValues)
if err != nil {
return nil, err
}
Expand All @@ -53,14 +53,14 @@ func ChainedSourceBatch[T, V any](sourceFunc func(context.Context, []T) ([]V, er
}
}

func SourceScalar[T any](source func(context.Context) (T, error)) source[T] {
func SourceScalar[T any](source func(context.Context, *PipelineState) (T, error)) source[T] {
// return function to be executed during pipeline execution
return func(ctx context.Context) (chan T, error) {
return func(ctx context.Context, p *PipelineState) (chan T, error) {
// create a buffered channel with room for the single scalar value
scalarChan := make(chan T, 1)

// execute source
scalar, err := source(ctx)
scalar, err := source(ctx, p)
if err != nil {
return nil, err
}
Expand All @@ -73,8 +73,8 @@ func SourceScalar[T any](source func(context.Context) (T, error)) source[T] {
}
}

func ChainedSourceScalar[T, V any](sourceFunc func(context.Context, T) ([]V, error)) chainedSource[T, V] {
return func(ctx context.Context, inChannel chan T) (chan V, error) {
func ChainedSourceScalar[T, V any](sourceFunc func(context.Context, *PipelineState, T) ([]V, error)) chainedSource[T, V] {
return func(ctx context.Context, p *PipelineState, inChannel chan T) (chan V, error) {
// reading all values to prevent leak
inValues := make([]T, len(inChannel))
idx := 0
Expand All @@ -84,7 +84,7 @@ func ChainedSourceScalar[T, V any](sourceFunc func(context.Context, T) ([]V, err
}

// call source with first inValue
outValues, err := sourceFunc(ctx, inValues[0])
outValues, err := sourceFunc(ctx, p, inValues[0])
if err != nil {
return nil, err
}
Expand All @@ -101,15 +101,15 @@ func ChainedSourceScalar[T, V any](sourceFunc func(context.Context, T) ([]V, err
}

func SequenceSource[T, V any](source1 source[T], source2 chainedSource[T, V]) source[V] {
return func(ctx context.Context) (chan V, error) {
return func(ctx context.Context, p *PipelineState) (chan V, error) {
// execute source1 to obtain handle to channel
source1Chan, err := source1(ctx)
source1Chan, err := source1(ctx, p)
if err != nil {
return nil, err
}

// execute source2 given source1
source2Chan, err := source2(ctx, source1Chan)
source2Chan, err := source2(ctx, p, source1Chan)

// load source2 data into channel
chainChannel := make(chan V, len(source2Chan))
Expand Down
Loading

0 comments on commit 11ec6f5

Please sign in to comment.