Skip to content

Commit

Permalink
Merge pull request #9 from meshtrade/refactor-arrow-flat-serialise
Browse files: browse the repository at this point in the history
Refactor: arrow stage to always flatten given input struct
  • Loading branch information
KyleSmith19091 authored Nov 13, 2024
2 parents 9a30ff7 + ff17d6c commit 4f33afa
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 73 deletions.
98 changes: 34 additions & 64 deletions etl/parquet/stage.go
Original file line number | Diff line number | Diff line change
Expand Up @@ -23,23 +23,25 @@ type ParquetSerialiser[T any] struct {
}

func NewParquetSerialiser[T any]() *ParquetSerialiser[T] {
var t T

// create new go allocator
pool := memory.NewGoAllocator()

// check element type
var t T
elemType := reflect.TypeOf(t)
if elemType.Kind() != reflect.Struct {
log.Fatal().Msg("expected type for serialiser to be struct")
}

// dynamically build the Arrow schema based on the struct fields
arrowFields, fieldBuilders, err := buildArrowFieldsAndBuilders(pool, elemType)
arrowFields := []arrow.Field{}
fieldBuilders := []array.Builder{}
arrowFields, fieldBuilders, err := buildArrowFieldsAndBuilders(arrowFields, fieldBuilders, pool, elemType)
if err != nil {
log.Fatal().Err(err).Msg("error building arrow fields and builders")
}

// build schema
// build schema from fields of type T
schema := arrow.NewSchema(arrowFields, nil)

return &ParquetSerialiser[T]{
Expand All @@ -49,10 +51,8 @@ func NewParquetSerialiser[T any]() *ParquetSerialiser[T] {
}
}

func buildArrowFieldsAndBuilders(pool memory.Allocator, elemType reflect.Type) ([]arrow.Field, []array.Builder, error) {
var arrowFields []arrow.Field
var fieldBuilders []array.Builder

// buildArrowFieldsAndBuilders will build arrow fields and arrow field builders for a given reflect.Type
func buildArrowFieldsAndBuilders(arrowFields []arrow.Field, fieldBuilders []array.Builder, pool memory.Allocator, elemType reflect.Type) ([]arrow.Field, []array.Builder, error) {
for i := 0; i < elemType.NumField(); i++ {
field := elemType.Field(i)

Expand All @@ -77,19 +77,12 @@ func buildArrowFieldsAndBuilders(pool memory.Allocator, elemType reflect.Type) (
}

// recursively build schema for the inner struct
innerFields, _, err := buildArrowFieldsAndBuilders(pool, field.Type)
var err error
arrowFields, fieldBuilders, err = buildArrowFieldsAndBuilders(arrowFields, fieldBuilders, pool, field.Type)
if err != nil {
return nil, nil, err
}

// add the inner struct schema
innerStructSchema := arrow.StructOf(innerFields...)
arrowFields = append(arrowFields, arrow.Field{Name: field.Name, Type: innerStructSchema, Nullable: false})

// add the inner struct builder and its fields
innerStructBuilder := array.NewStructBuilder(pool, innerStructSchema)
fieldBuilders = append(fieldBuilders, innerStructBuilder)

default:
return nil, nil, fmt.Errorf("unsupported field type for field %s: %s", field.Name, field.Type.Kind())
}
Expand All @@ -113,41 +106,13 @@ func (s *ParquetSerialiser[T]) Serialise(ctx context.Context, p *pipeline.Pipeli
return outChannel, nil
}

// get the reflection value of the input slice
timeType := reflect.TypeOf(time.Time{})

// iterate through the slice and append values to builders
for i := 0; i < len(inputStruct); i++ {
structVal := reflect.ValueOf(inputStruct[i])

for j := 0; j < structVal.NumField(); j++ {
fieldVal := structVal.Field(j)

switch fieldVal.Kind() {
case reflect.String:
s.fieldBuilders[j].(*array.StringBuilder).Append(fieldVal.String())

case reflect.Int32:
s.fieldBuilders[j].(*array.Int32Builder).Append(int32(fieldVal.Int()))

case reflect.Float64:
s.fieldBuilders[j].(*array.Float64Builder).Append(fieldVal.Float())

case reflect.Struct:
if fieldVal.Type() == timeType {
timeVal := fieldVal.Interface().(time.Time)
s.fieldBuilders[j].(*array.Date64Builder).Append(arrow.Date64FromTime(timeVal))
continue
}
structBuilder := s.fieldBuilders[j].(*array.StructBuilder)
if err := s.appendStructValues(structBuilder, fieldVal); err != nil {
return nil, err
}
structBuilder.Append(true)
default:
return nil, fmt.Errorf("unsupported field type: %s", fieldVal.Kind())
}
}
// add struct values for given struct
builderIdx := 0
s.appendStructValues(builderIdx, structVal)
}

// create arrow arrays from builders
Expand Down Expand Up @@ -189,40 +154,45 @@ func (s *ParquetSerialiser[T]) Serialise(ctx context.Context, p *pipeline.Pipeli
return outputChannel, nil
}

func (s *ParquetSerialiser[T]) appendStructValues(builder *array.StructBuilder, structVal reflect.Value) error {
func (s *ParquetSerialiser[T]) appendStructValues(builderIdx int, structVal reflect.Value) (int, error) {
timeType := reflect.TypeOf(time.Time{})

for i := 0; i < structVal.NumField(); i++ {
fieldVal := structVal.Field(i)
fieldBuilder := builder.FieldBuilder(i)
for j := 0; j < structVal.NumField(); j++ {
// get a handle to the value of the struct field
fieldVal := structVal.Field(j)

switch fieldVal.Kind() {
case reflect.String:
fieldBuilder.(*array.StringBuilder).Append(fieldVal.String())
// add string value
s.fieldBuilders[builderIdx].(*array.StringBuilder).Append(fieldVal.String())

case reflect.Int32:
fieldBuilder.(*array.Int32Builder).Append(int32(fieldVal.Int()))
// add int value
s.fieldBuilders[builderIdx].(*array.Int32Builder).Append(int32(fieldVal.Int()))

case reflect.Float64:
fieldBuilder.(*array.Float64Builder).Append(fieldVal.Float())
// add float value
s.fieldBuilders[builderIdx].(*array.Float64Builder).Append(fieldVal.Float())

case reflect.Struct:
if fieldVal.Type() == timeType {
// add date time value
timeVal := fieldVal.Interface().(time.Time)
fieldBuilder.(*array.Date64Builder).Append(arrow.Date64FromTime(timeVal))
s.fieldBuilders[builderIdx].(*array.Date64Builder).Append(arrow.Date64FromTime(timeVal))
continue
} else {
// Recursively handle nested structs
nestedStructBuilder := fieldBuilder.(*array.StructBuilder)
if err := s.appendStructValues(nestedStructBuilder, fieldVal); err != nil {
return err
// recursively add struct values
var err error
builderIdx, err = s.appendStructValues(builderIdx, fieldVal)
if err != nil {
return -1, err
}
nestedStructBuilder.Append(true)
}

default:
return fmt.Errorf("unsupported field type: %s", fieldVal.Kind())
return -1, fmt.Errorf("unsupported field type: %s", fieldVal.Kind())
}
builderIdx++
}

return nil
return builderIdx, nil
}
9 changes: 0 additions & 9 deletions examples/main.go
Original file line number | Diff line number | Diff line change
Expand Up @@ -37,13 +37,4 @@ func main() {
if err := pipeline.Execute(context.Background()); err != nil {
log.Fatal(err)
}

if err := pipeline.Execute(context.Background()); err != nil {
log.Fatal(err)
}

if err := pipeline.Execute(context.Background()); err != nil {
log.Fatal(err)
}

}

0 comments on commit 4f33afa

Please sign in to comment.