From ff17d6ccc0bb52d806a0ce1a501fa2c63171d390 Mon Sep 17 00:00:00 2001 From: KyleSmith19091 Date: Wed, 13 Nov 2024 08:56:09 +0200 Subject: [PATCH] refactor arrow stage to always flatten given input struct --- etl/parquet/stage.go | 98 +++++++++++++++----------------------------- examples/main.go | 9 ---- 2 files changed, 34 insertions(+), 73 deletions(-) diff --git a/etl/parquet/stage.go b/etl/parquet/stage.go index be63b74..5151d97 100644 --- a/etl/parquet/stage.go +++ b/etl/parquet/stage.go @@ -23,23 +23,25 @@ type ParquetSerialiser[T any] struct { } func NewParquetSerialiser[T any]() *ParquetSerialiser[T] { - var t T - + // create new go allocator pool := memory.NewGoAllocator() // check element type + var t T elemType := reflect.TypeOf(t) if elemType.Kind() != reflect.Struct { log.Fatal().Msg("expected type for serialiser to be struct") } // dynamically build the Arrow schema based on the struct fields - arrowFields, fieldBuilders, err := buildArrowFieldsAndBuilders(pool, elemType) + arrowFields := []arrow.Field{} + fieldBuilders := []array.Builder{} + arrowFields, fieldBuilders, err := buildArrowFieldsAndBuilders(arrowFields, fieldBuilders, pool, elemType) if err != nil { log.Fatal().Err(err).Msg("error building arrow fields and builders") } - // build schema + // build schema from fields of type T schema := arrow.NewSchema(arrowFields, nil) return &ParquetSerialiser[T]{ @@ -49,10 +51,8 @@ func NewParquetSerialiser[T any]() *ParquetSerialiser[T] { } } -func buildArrowFieldsAndBuilders(pool memory.Allocator, elemType reflect.Type) ([]arrow.Field, []array.Builder, error) { - var arrowFields []arrow.Field - var fieldBuilders []array.Builder - +// buildArrowFieldsAndBuilders will build arrow fields and arrow field builders for a given reflect.Type +func buildArrowFieldsAndBuilders(arrowFields []arrow.Field, fieldBuilders []array.Builder, pool memory.Allocator, elemType reflect.Type) ([]arrow.Field, []array.Builder, error) { for i := 0; i < elemType.NumField(); i++ { field := elemType.Field(i) @@ -77,19 +77,12 @@ func buildArrowFieldsAndBuilders(pool memory.Allocator, elemType reflect.Type) ( } // recursively build schema for the inner struct - innerFields, _, err := buildArrowFieldsAndBuilders(pool, field.Type) + var err error + arrowFields, fieldBuilders, err = buildArrowFieldsAndBuilders(arrowFields, fieldBuilders, pool, field.Type) if err != nil { return nil, nil, err } - // add the inner struct schema - innerStructSchema := arrow.StructOf(innerFields...) - arrowFields = append(arrowFields, arrow.Field{Name: field.Name, Type: innerStructSchema, Nullable: false}) - - // add the inner struct builder and its fields - innerStructBuilder := array.NewStructBuilder(pool, innerStructSchema) - fieldBuilders = append(fieldBuilders, innerStructBuilder) - default: return nil, nil, fmt.Errorf("unsupported field type for field %s: %s", field.Name, field.Type.Kind()) } @@ -113,41 +106,13 @@ func (s *ParquetSerialiser[T]) Serialise(ctx context.Context, p *pipeline.Pipeli return outChannel, nil } - // get the reflection value of the input slice - timeType := reflect.TypeOf(time.Time{}) - // iterate through the slice and append values to builders for i := 0; i < len(inputStruct); i++ { structVal := reflect.ValueOf(inputStruct[i]) - for j := 0; j < structVal.NumField(); j++ { - fieldVal := structVal.Field(j) - - switch fieldVal.Kind() { - case reflect.String: - s.fieldBuilders[j].(*array.StringBuilder).Append(fieldVal.String()) - - case reflect.Int32: - s.fieldBuilders[j].(*array.Int32Builder).Append(int32(fieldVal.Int())) - - case reflect.Float64: - s.fieldBuilders[j].(*array.Float64Builder).Append(fieldVal.Float()) - - case reflect.Struct: - if fieldVal.Type() == timeType { - timeVal := fieldVal.Interface().(time.Time) - s.fieldBuilders[j].(*array.Date64Builder).Append(arrow.Date64FromTime(timeVal)) - continue - } - structBuilder := s.fieldBuilders[j].(*array.StructBuilder) - if err := s.appendStructValues(structBuilder, fieldVal); err != nil { - return nil, err - } - structBuilder.Append(true) - default: - return nil, fmt.Errorf("unsupported field type: %s", fieldVal.Kind()) - } - } + // add struct values for given struct + builderIdx := 0 + s.appendStructValues(builderIdx, structVal) } // create arrow arrays from builders @@ -189,40 +154,45 @@ func (s *ParquetSerialiser[T]) Serialise(ctx context.Context, p *pipeline.Pipeli return outputChannel, nil } -func (s *ParquetSerialiser[T]) appendStructValues(builder *array.StructBuilder, structVal reflect.Value) error { +func (s *ParquetSerialiser[T]) appendStructValues(builderIdx int, structVal reflect.Value) (int, error) { timeType := reflect.TypeOf(time.Time{}) - for i := 0; i < structVal.NumField(); i++ { - fieldVal := structVal.Field(i) - fieldBuilder := builder.FieldBuilder(i) + for j := 0; j < structVal.NumField(); j++ { + // get a handle to the value of the struct field + fieldVal := structVal.Field(j) switch fieldVal.Kind() { case reflect.String: - fieldBuilder.(*array.StringBuilder).Append(fieldVal.String()) + // add string value + s.fieldBuilders[builderIdx].(*array.StringBuilder).Append(fieldVal.String()) case reflect.Int32: - fieldBuilder.(*array.Int32Builder).Append(int32(fieldVal.Int())) + // add int value + s.fieldBuilders[builderIdx].(*array.Int32Builder).Append(int32(fieldVal.Int())) case reflect.Float64: - fieldBuilder.(*array.Float64Builder).Append(fieldVal.Float()) + // add float value + s.fieldBuilders[builderIdx].(*array.Float64Builder).Append(fieldVal.Float()) case reflect.Struct: if fieldVal.Type() == timeType { + // add date time value timeVal := fieldVal.Interface().(time.Time) - fieldBuilder.(*array.Date64Builder).Append(arrow.Date64FromTime(timeVal)) + s.fieldBuilders[builderIdx].(*array.Date64Builder).Append(arrow.Date64FromTime(timeVal)) + continue } else { - // Recursively handle nested structs - nestedStructBuilder := fieldBuilder.(*array.StructBuilder) - if err := s.appendStructValues(nestedStructBuilder, fieldVal); err != nil { - return err + // recursively add struct values + var err error + builderIdx, err = s.appendStructValues(builderIdx, fieldVal) + if err != nil { + return -1, err } - nestedStructBuilder.Append(true) } - default: - return fmt.Errorf("unsupported field type: %s", fieldVal.Kind()) + return -1, fmt.Errorf("unsupported field type: %s", fieldVal.Kind()) } + builderIdx++ } - return nil + return builderIdx, nil } diff --git a/examples/main.go b/examples/main.go index 0ff94e3..35a87a9 100644 --- a/examples/main.go +++ b/examples/main.go @@ -37,13 +37,4 @@ func main() { if err := pipeline.Execute(context.Background()); err != nil { log.Fatal(err) } - - if err := pipeline.Execute(context.Background()); err != nil { - log.Fatal(err) - } - - if err := pipeline.Execute(context.Background()); err != nil { - log.Fatal(err) - } - }