diff --git a/etl/mongo/source.go b/etl/mongo/source.go index 55af45e..256df82 100644 --- a/etl/mongo/source.go +++ b/etl/mongo/source.go @@ -3,6 +3,7 @@ package mongo import ( "context" + "github.com/meshtrade/mesh-etl/etl/pipeline" "go.mongodb.org/mongo-driver/bson" "go.mongodb.org/mongo-driver/mongo" "go.mongodb.org/mongo-driver/mongo/options" @@ -21,7 +22,7 @@ func NewMongoCollector[T any](collection mongo.Collection, query bson.D, opts .. } } -func (m *MongoCollector[T]) Collect(ctx context.Context) ([]T, error) { +func (m *MongoCollector[T]) Collect(ctx context.Context, pipelineState *pipeline.PipelineState) (chan T, error) { cursor, err := m.collection.Find( ctx, m.query, @@ -36,5 +37,11 @@ func (m *MongoCollector[T]) Collect(ctx context.Context) ([]T, error) { return nil, err } - return records, nil + outputChannel := make(chan T, len(records)) + for _, record := range records { + outputChannel <- record + } + close(outputChannel) + + return outputChannel, nil } diff --git a/etl/parquet/serialiser.go b/etl/parquet/stage.go similarity index 92% rename from etl/parquet/serialiser.go rename to etl/parquet/stage.go index 843e4fd..2c7692d 100644 --- a/etl/parquet/serialiser.go +++ b/etl/parquet/stage.go @@ -12,6 +12,7 @@ import ( "github.com/apache/arrow/go/v17/arrow/memory" "github.com/apache/arrow/go/v17/parquet" "github.com/apache/arrow/go/v17/parquet/pqarrow" + "github.com/meshtrade/mesh-etl/etl/pipeline" "github.com/rs/zerolog/log" ) @@ -97,7 +98,13 @@ func buildArrowFieldsAndBuilders(pool memory.Allocator, elemType reflect.Type) ( return arrowFields, fieldBuilders, nil } -func (s *ParquetSerialiser[T]) Marshal(ctx context.Context, inputStruct []T) ([]byte, error) { +func (s *ParquetSerialiser[T]) Serialise(ctx context.Context, p *pipeline.PipelineState, inChannel chan T) (chan []byte, error) { + // collect values from channel + inputStruct := []T{} + for inValue := range inChannel { + inputStruct = append(inputStruct, inValue) + } + // get the reflection value of the input slice timeType := reflect.TypeOf(time.Time{}) @@ -166,7 +173,12 @@ func (s *ParquetSerialiser[T]) Marshal(ctx context.Context, inputStruct []T) ([] // NOTE: NEVER call close in defer function! pw.Close() - return dataBuffer.Bytes(), nil + // load value into output channel + outputChannel := make(chan []byte, 1) + outputChannel <- dataBuffer.Bytes() + close(outputChannel) + + return outputChannel, nil } func (s *ParquetSerialiser[T]) appendStructValues(builder *array.StructBuilder, structVal reflect.Value) error { diff --git a/etl/pipeline/pipeline.go b/etl/pipeline/pipeline.go index bdcc9d6..1690afb 100644 --- a/etl/pipeline/pipeline.go +++ b/etl/pipeline/pipeline.go @@ -6,6 +6,7 @@ type Pipeline[T, V any] struct { source source[T] stage stage[T, V] sink sink[V] + state *PipelineState } func NewPipeline[T, V any]( @@ -17,19 +18,24 @@ func NewPipeline[T, V any]( source: source, stage: stage, sink: sink, + state: NewPipelineState(), } } func (p *Pipeline[T, V]) Execute(ctx context.Context) error { - sourceChannel, err := p.source(ctx) + sourceChannel, err := p.source(ctx, p.state) if err != nil { return err } - stageChannel, err := p.stage(ctx, sourceChannel) + stageChannel, err := p.stage(ctx, p.state, sourceChannel) if err != nil { return err } - return p.sink(ctx, stageChannel) + if err := p.sink(ctx, p.state, stageChannel); err != nil { + return err + } + + return p.state.RunAfterEffects(ctx) } diff --git a/etl/pipeline/pipelineState.go b/etl/pipeline/pipelineState.go new file mode 100644 index 0000000..9586282 --- /dev/null +++ b/etl/pipeline/pipelineState.go @@ -0,0 +1,41 @@ +package pipeline + +import ( + "context" + + "golang.org/x/sync/errgroup" +) + +type AfterEffect func(ctx context.Context) error + +type PipelineState struct { + afterEffects []AfterEffect +} + +func NewPipelineState() *PipelineState { + return &PipelineState{ + afterEffects: []AfterEffect{}, + } +} + +func (p *PipelineState) RegisterAfterEffect(afterEffect AfterEffect) { + p.afterEffects = append(p.afterEffects, afterEffect) +} + +func (p *PipelineState) RunAfterEffects(ctx context.Context) error { + errGroup := new(errgroup.Group) + + for _, afterEffect := range p.afterEffects { + errGroup.Go(func() error { + return afterEffect(ctx) + }) + } + + if err := errGroup.Wait(); err != nil { + return err + } + + p.afterEffects = []AfterEffect{} + + return nil +} diff --git a/etl/pipeline/sink.go b/etl/pipeline/sink.go index 0496f20..34a30a0 100644 --- a/etl/pipeline/sink.go +++ b/etl/pipeline/sink.go @@ -1,14 +1,18 @@ package pipeline -import "context" +import ( + "context" -type sink[T any] func(context.Context, chan T) error + "golang.org/x/sync/errgroup" +) -func Emit[T any](emitFunc func(context.Context, T) error) sink[T] { - return func(ctx context.Context, inChannel chan T) error { +type sink[T any] func(context.Context, *PipelineState, chan T) error + +func Emit[T any](emitFunc func(context.Context, *PipelineState, T) error) sink[T] { + return func(ctx context.Context, p *PipelineState, inChannel chan T) error { // collect values from channel for inValue := range inChannel { - if err := emitFunc(ctx, inValue); err != nil { + if err := emitFunc(ctx, p, inValue); err != nil { return err } } @@ -18,10 +22,13 @@ func Emit[T any](emitFunc func(context.Context, T) error) sink[T] { } func Spread[T any](sinks ...sink[T]) sink[T] { - // create list of channels for sinks + // allocate list of channels for sinks sinkChannels := make([]chan T, len(sinks)) - return func(ctx context.Context, inChannel chan T) error { + // prepare err group + var errGroup = new(errgroup.Group) + + return func(ctx context.Context, p *PipelineState, inChannel chan T) error { // collect input values from channel inValues := make([]T, len(inChannel)) idx := 0 @@ -32,19 +39,34 @@ func Spread[T any](sinks ...sink[T]) sink[T] { // construct sink channel for each sink and load input values for range sinks { + // allocate sink channel to hold values sinkChannel := make(chan T, len(inChannel)) + + // load input values into new sink channel for _, inValue := range inValues { sinkChannel <- inValue } + + // close channel to indicate no more data will be sent close(sinkChannel) + + // add channel to list of channels sinkChannels = append(sinkChannels, sinkChannel) } - // execute sinks + // execute sinks concurrently for idx, sink := range sinks { - if err := sink(ctx, sinkChannels[idx]); err != nil { - return err - } + errGroup.Go(func() error { + if err := sink(ctx, p, sinkChannels[idx]); err != nil { + return err + } + return nil + }) + } + + // wait for sinks + if err := errGroup.Wait(); err != nil { + return err } return nil @@ -54,9 +76,9 @@ func Spread[T any](sinks ...sink[T]) sink[T] { func SequenceSink[T any]( sinks ...sink[T], ) sink[T] { - return func(ctx context.Context, c chan T) error { + return func(ctx context.Context, p *PipelineState, c chan T) error { for _, sink := range sinks { - if err := sink(ctx, c); err != nil { + if err := sink(ctx, p, c); err != nil { return err } } diff --git a/etl/pipeline/source.go b/etl/pipeline/source.go index 8fb9d9c..b107fcb 100644 --- a/etl/pipeline/source.go +++ b/etl/pipeline/source.go @@ -4,14 +4,14 @@ import ( "context" ) -type source[T any] func(ctx context.Context) (chan T, error) -type chainedSource[T, V any] func(context.Context, chan T) (chan V, error) +type source[T any] func(context.Context, *PipelineState) (chan T, error) +type chainedSource[T, V any] func(context.Context, *PipelineState, chan T) (chan V, error) -func SourceBatch[T any](sourceFunc func(context.Context) ([]T, error)) source[T] { +func SourceBatch[T any](sourceFunc func(context.Context, *PipelineState) ([]T, error)) source[T] { // return function to be executed during pipeline execution - return func(ctx context.Context) (chan T, error) { + return func(ctx context.Context, p *PipelineState) (chan T, error) { // execute source - batch, err := sourceFunc(ctx) + batch, err := sourceFunc(ctx, p) if err != nil { return nil, err } @@ -29,8 +29,8 @@ func SourceBatch[T any](sourceFunc func(context.Context) ([]T, error)) source[T] } } -func ChainedSourceBatch[T, V any](sourceFunc func(context.Context, []T) ([]V, error)) chainedSource[T, V] { - return func(ctx context.Context, inChannel chan T) (chan V, error) { +func ChainedSourceBatch[T, V any](sourceFunc func(context.Context, *PipelineState, []T) ([]V, error)) chainedSource[T, V] { + return func(ctx context.Context, p *PipelineState, inChannel chan T) (chan V, error) { inValues := make([]T, len(inChannel)) idx := 0 for inValue := range inChannel { @@ -38,7 +38,7 @@ func ChainedSourceBatch[T, V any](sourceFunc func(context.Context, []T) ([]V, er idx++ } - outValues, err := sourceFunc(ctx, inValues) + outValues, err := sourceFunc(ctx, p, inValues) if err != nil { return nil, err } @@ -53,14 +53,14 @@ func ChainedSourceBatch[T, V any](sourceFunc func(context.Context, []T) ([]V, er } } -func SourceScalar[T any](source func(context.Context) (T, error)) source[T] { +func SourceScalar[T any](source func(context.Context, *PipelineState) (T, error)) source[T] { // return function to be executed during pipeline execution - return func(ctx context.Context) (chan T, error) { + return func(ctx context.Context, p *PipelineState) (chan T, error) { // create synchronous channel scalarChan := make(chan T, 1) // execute source - scalar, err := source(ctx) + scalar, err := source(ctx, p) if err != nil { return nil, err } @@ -73,8 +73,8 @@ func SourceScalar[T any](source func(context.Context) (T, error)) source[T] { } } -func ChainedSourceScalar[T, V any](sourceFunc func(context.Context, T) ([]V, error)) chainedSource[T, V] { - return func(ctx context.Context, inChannel chan T) (chan V, error) { +func ChainedSourceScalar[T, V any](sourceFunc func(context.Context, *PipelineState, T) ([]V, error)) chainedSource[T, V] { + return func(ctx context.Context, p *PipelineState, inChannel chan T) (chan V, error) { // reading all values to prevent leak inValues := make([]T, len(inChannel)) idx := 0 @@ -84,7 +84,7 @@ func ChainedSourceScalar[T, V any](sourceFunc func(context.Context, T) ([]V, err } // call source with first inValue - outValues, err := sourceFunc(ctx, inValues[0]) + outValues, err := sourceFunc(ctx, p, inValues[0]) if err != nil { return nil, err } @@ -101,15 +101,15 @@ func ChainedSourceScalar[T, V any](sourceFunc func(context.Context, T) ([]V, err } func SequenceSource[T, V any](source1 source[T], source2 chainedSource[T, V]) source[V] { - return func(ctx context.Context) (chan V, error) { + return func(ctx context.Context, p *PipelineState) (chan V, error) { // execute source1 to obtain handle to channel - source1Chan, err := source1(ctx) + source1Chan, err := source1(ctx, p) if err != nil { return nil, err } // execute source2 given source1 - source2Chan, err := source2(ctx, source1Chan) + source2Chan, err := source2(ctx, p, source1Chan) // load source2 data into channel chainChannel := make(chan V, len(source2Chan)) diff --git a/etl/pipeline/stage.go b/etl/pipeline/stage.go index b4e09f3..3f241f9 100644 --- a/etl/pipeline/stage.go +++ b/etl/pipeline/stage.go @@ -2,18 +2,19 @@ package pipeline import ( "context" + "math/rand" ) -type stage[T any, V any] func(context.Context, chan T) (chan V, error) +type stage[T any, V any] func(context.Context, *PipelineState, chan T) (chan V, error) -func Map[T, V any](mapFunc func(context.Context, T) V) stage[T, V] { - return func(ctx context.Context, inChannel chan T) (chan V, error) { +func Map[T, V any](mapFunc func(context.Context, *PipelineState, T) V) stage[T, V] { + return func(ctx context.Context, p *PipelineState, inChannel chan T) (chan V, error) { // create channel with buffer for each element in inChannel mapChan := make(chan V, len(inChannel)) // map values from inChannel for inValue := range inChannel { - mapChan <- mapFunc(ctx, inValue) + mapChan <- mapFunc(ctx, p, inValue) } close(mapChan) @@ -21,14 +22,14 @@ func Map[T, V any](mapFunc func(context.Context, T) V) stage[T, V] { } } -func Filter[T any](filterFunc func(context.Context, T) bool) stage[T, T] { - return func(ctx context.Context, inChannel chan T) (chan T, error) { +func Filter[T any](filterFunc func(context.Context, *PipelineState, T) bool) stage[T, T] { + return func(ctx context.Context, p *PipelineState, inChannel chan T) (chan T, error) { // optimistically allocate buffer for all elements in input channel filterChan := make(chan T, len(inChannel)) // filter values from inChannel for inValue := range inChannel { - if filterFunc(ctx, inValue) { + if filterFunc(ctx, p, inValue) { filterChan <- inValue } } @@ -37,24 +38,121 @@ func Filter[T any](filterFunc func(context.Context, T) bool) stage[T, T] { } } +func Shuffle[T any]() stage[T, T] { + return func(ctx context.Context, p *PipelineState, input chan T) (chan T, error) { + // load values from input into array + inputValues := make([]T, len(input)) + idx := 0 + for inputValue := range input { + inputValues[idx] = inputValue + idx++ + } + + // shuffle the data (NOTE: Assuming Go 1.20 which automatically sets seed) + rand.Shuffle(len(inputValues), func(i, j int) { + inputValues[i], inputValues[j] = inputValues[j], inputValues[i] + }) + + // load data into return channel + outChannel := make(chan T, len(inputValues)) + for _, inputValue := range inputValues { + outChannel <- inputValue + } + close(outChannel) + + return outChannel, nil + } +} + +func Count[T comparable]() stage[T, map[T]int] { + return func(ctx context.Context, p *PipelineState, inChannel chan T) (chan map[T]int, error) { + countMap := make(map[T]int) + + for inValue := range inChannel { + _, found := countMap[inValue] + if found { + countMap[inValue] += 1 + } else { + countMap[inValue] = 1 + } + } + + outputChan := make(chan map[T]int, 1) + outputChan <- countMap + close(outputChan) + + return outputChan, nil + } +} + func SequenceStage[T, V, K any]( stage1 stage[T, V], stage2 stage[V, K], ) stage[T, K] { - return func(ctx context.Context, inChannel chan T) (chan K, error) { - + return func(ctx context.Context, p *PipelineState, inChannel chan T) (chan K, error) { // execute stage 1 - stage1Channel, err := stage1(ctx, inChannel) + stage1Channel, err := stage1(ctx, p, inChannel) if err != nil { return nil, err } // execute stage 2 - stage2Channel, err := stage2(ctx, stage1Channel) + stage2Channel, err := stage2(ctx, p, stage1Channel) + if err != nil { + return nil, err + } + + return stage2Channel, nil + } +} + +func SequenceStage3[T, V, K, M any]( + stage1 stage[T, V], + stage2 stage[V, K], + stage3 stage[K, M], +) stage[T, M] { + return func(ctx context.Context, p *PipelineState, inChannel chan T) (chan M, error) { + // construct sequenced for stages 1 + 2 + sequencedStage1 := SequenceStage(stage1, stage2) + + // execute stages 1 + 2 + sequencedChan, err := sequencedStage1(ctx, p, inChannel) + if err != nil { + return nil, err + } + + // execute sequenced stage 3 + stage3Chan, err := stage3(ctx, p, sequencedChan) + if err != nil { + return nil, err + } + + return stage3Chan, nil + } +} + +func SequenceStage4[T, V, K, M, L any]( + stage1 stage[T, V], + stage2 stage[V, K], + stage3 stage[K, M], + stage4 stage[M, L], +) stage[T, L] { + return func(ctx context.Context, p *PipelineState, inChannel chan T) (chan L, error) { + // construct sequenced for stages 1 + 2 + sequencedStage1 := SequenceStage3(stage1, stage2, stage3) + + // execute stages 1 + 2 + 3 + sequencedChan, err := sequencedStage1(ctx, p, inChannel) + if err != nil { + return nil, err + } + + // execute sequenced stage 4 + stage4Chan, err := stage4(ctx, p, sequencedChan) if err != nil { return nil, err } - return stage2Channel, err + return stage4Chan, nil } } diff --git a/examples/collector.go b/examples/collector.go deleted file mode 100644 index 0457967..0000000 --- a/examples/collector.go +++ /dev/null @@ -1,7 +0,0 @@ -package main - -import "context" - -type BatchedDataSource[T any] interface { - Collect(context.Context, string) ([]T, string, error) -} diff --git a/examples/collectorImpl.go b/examples/collectorImpl.go deleted file mode 100644 index 2057715..0000000 --- a/examples/collectorImpl.go +++ /dev/null @@ -1,26 +0,0 @@ -package main - -import ( - "context" -) - -// ensure slice data collector implements interface -var _ BatchedDataSource[string] = &SliceDataCollector{} - -// SliceDataCollector is used to collect data from an internal slice of string -type SliceDataCollector struct { - data []string - batchSize int -} - -func NewSliceDataCollector(_data []string, _batchSize int) *SliceDataCollector { - return &SliceDataCollector{ - data: _data, - batchSize: _batchSize, - } -} - -// Collect implements etl.DataCollector. -func (d *SliceDataCollector) Collect(ctx context.Context, pagingToken string) ([]string, string, error) { - return d.data, "", nil -} diff --git a/examples/emitter.go b/examples/emitter.go deleted file mode 100644 index 6d34d4a..0000000 --- a/examples/emitter.go +++ /dev/null @@ -1,7 +0,0 @@ -package main - -import "context" - -type DataEmitter[T any] interface { - Emit(ctx context.Context, value T) error -} diff --git a/examples/main.go b/examples/main.go index f97577a..0ff94e3 100644 --- a/examples/main.go +++ b/examples/main.go @@ -2,6 +2,7 @@ package main import ( "context" + "fmt" "log" "github.com/meshtrade/mesh-etl/etl/pipeline" @@ -14,27 +15,35 @@ type Model struct { func main() { pipeline := pipeline.NewPipeline( pipeline.SequenceSource( - pipeline.SourceScalar(func(ctx context.Context) (int, error) { - return 0, nil + pipeline.SourceScalar(func(ctx context.Context, ps *pipeline.PipelineState) (int, error) { + ps.RegisterAfterEffect(func(ctx context.Context) error { + fmt.Println("Got return value: ", 1) + return nil + }) + return 1, nil }), - pipeline.ChainedSourceScalar(func(ctx context.Context, in int) ([]int, error) { - return []int{1, 2, 3, 4}, nil + pipeline.ChainedSourceScalar(func(ctx context.Context, p *pipeline.PipelineState, input int) ([]int, error) { + return []int{}, nil }), ), - pipeline.SequenceStage( - pipeline.Map(func(ctx context.Context, inValue int) int { - return inValue * inValue - }), - pipeline.Map(func(ctx context.Context, inValue int) Model { - return Model{ - Value: inValue, - } - }), - ), - pipeline.Emit(NewSTDOutEmitter[Model]().Emit), + pipeline.Map(func(ctx context.Context, p *pipeline.PipelineState, input int) int { + return 0 + }), + pipeline.Emit(func(ctx context.Context, p *pipeline.PipelineState, input int) error { + return nil + }), ) if err := pipeline.Execute(context.Background()); err != nil { log.Fatal(err) } + + if err := pipeline.Execute(context.Background()); err != nil { + log.Fatal(err) + } + + if err := pipeline.Execute(context.Background()); err != nil { + log.Fatal(err) + } + } diff --git a/examples/stdoutEmitter.go b/examples/stdoutEmitter.go index 249dc5f..02e7dd0 100644 --- a/examples/stdoutEmitter.go +++ b/examples/stdoutEmitter.go @@ -3,6 +3,8 @@ package main import ( "context" "fmt" + + "github.com/meshtrade/mesh-etl/etl/pipeline" ) type StdOutEmitter[T any] struct { @@ -12,7 +14,7 @@ func NewSTDOutEmitter[T any]() *StdOutEmitter[T] { return &StdOutEmitter[T]{} } -func (e *StdOutEmitter[T]) Emit(ctx context.Context, value T) error { +func (e *StdOutEmitter[T]) Emit(ctx context.Context, pipelineState *pipeline.PipelineState, value T) error { fmt.Printf("Data: %v\n", value) return nil }