Skip to content

Commit

Permalink
fix: avoid duplication of variants in registry when function name is …
Browse files Browse the repository at this point in the history
…overloaded (#87)

This happens with count_star, where the same Substrait function name is used
for two different functions in the extension.
  • Loading branch information
scgkiran authored Dec 16, 2024
1 parent 317c209 commit 5924d58
Show file tree
Hide file tree
Showing 3 changed files with 168 additions and 6 deletions.
23 changes: 18 additions & 5 deletions functions/dialect.go
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ type mapAndSlice[V extensions.FunctionVariant] struct {
// It returns
// 1. a mapAndSlice of LocalFunctionVariants
// 2. an error if a function variant is not found for a dialect function
func makeLocalFunctionVariantMapAndSlice[T withID, V extensions.FunctionVariant](
func makeLocalFunctionVariantMapAndSlice[T withID, V localFunctionVariant](
dialectFunctionInfos map[extensions.ID]*dialectFunctionInfo, getFunctionVariants func(string) []T,
createLocalVariant func(T, *dialectFunctionInfo) V) (*mapAndSlice[V], error) {

Expand All @@ -110,17 +110,19 @@ func makeLocalFunctionVariantMapAndSlice[T withID, V extensions.FunctionVariant]

localVariantArray := make([]V, 0)
for _, f := range getFunctionVariants(dfi.Name) {
if dfi, ok := dialectFunctionInfos[f.ID()]; ok {
localVariantArray = append(localVariantArray, createLocalVariant(f, dfi))
processedFunctions[f.ID()] = true
if _, alreadyProcessed := processedFunctions[f.ID()]; !alreadyProcessed {
if dfi1, ok := dialectFunctionInfos[f.ID()]; ok {
localVariantArray = append(localVariantArray, createLocalVariant(f, dfi1))
processedFunctions[f.ID()] = true
}
}
}
if _, ok := processedFunctions[dfi.ID]; !ok {
return nil, fmt.Errorf("%w: no function variant found for '%s'", substraitgo.ErrInvalidDialect, dfi.ID)
}
if len(localVariantArray) > 0 {
addToSliceMap(variantsMap, SubstraitFunctionName(dfi.Name), localVariantArray)
addToSliceMap(variantsMap, LocalFunctionName(dfi.LocalName), localVariantArray)
addToSliceMapWithLocalKey(variantsMap, localVariantArray)
variantsSlice = append(variantsSlice, localVariantArray...)
}
}
Expand All @@ -137,6 +139,17 @@ func addToSliceMap[K FunctionName, V extensions.FunctionVariant](m map[FunctionN
m[key] = append(m[key], value...)
}

// addToSliceMapWithLocalKey appends every variant in value to the slice stored
// in m under that variant's own local name.
//
// The key is computed per variant as LocalFunctionName(v.LocalName()) instead
// of once for the whole slice, because variants that share a substrait
// function name can still carry different local names.
//
// A nil or empty value leaves m untouched. m must be non-nil whenever value is
// non-empty, since the function writes into it.
func addToSliceMapWithLocalKey[V localFunctionVariant](m map[FunctionName][]V, value []V) {
	for _, v := range value {
		key := LocalFunctionName(v.LocalName())
		// Indexing a missing key yields a nil slice and append allocates on
		// demand, so no existence check or explicit make is needed.
		m[key] = append(m[key], v)
	}
}

func (d *dialectImpl) LocalizeTypeRegistry(TypeRegistry) (LocalTypeRegistry, error) {
typeInfos := make([]typeInfo, 0, len(d.toLocalTypeMap))
for name, info := range d.toLocalTypeMap {
Expand Down
144 changes: 143 additions & 1 deletion functions/dialect_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1643,7 +1643,7 @@ scalar_functions:
}
checkCompoundNames(t, getScalarCompoundNames(fv), expectedNames)
assert.Len(t, urisFound, len(expectedUris))
for k, _ := range urisFound {
for k := range urisFound {
assert.Contains(t, expectedUris, k)
}

Expand All @@ -1670,3 +1670,145 @@ scalar_functions:
assert.Contains(t, allFunctions, f)
}
}

// TestAggregateFunctionsWithSameName loads two extension files that each
// declare two aggregate functions named "count" (one taking a single `any`
// argument, one taking no arguments) and checks that a dialect mapping both
// decimal-output variants to different local names ("count" and "count_rows")
// resolves each variant correctly by local name and by substrait name.
func TestAggregateFunctionsWithSameName(t *testing.T) {
// URIs under which the two extension YAML documents below are loaded.
const arithmeticUri = "http://localhost/functions_aggregate_generic.yaml"
const decimalUri = "http://localhost/functions_aggregate_decimal_output.yaml"
// Extension whose two "count" functions return decimal<38,0>.
const decimalYaml = `---
aggregate_functions:
- name: "count"
description: Count a set of values. Result is returned as a decimal instead of i64.
impls:
- args:
- name: x
value: any
options:
overflow:
values: [SILENT, SATURATE, ERROR]
nullability: DECLARED_OUTPUT
decomposable: MANY
intermediate: decimal<38,0>
return: decimal<38,0>
- name: "count"
description: "Count a set of records (not field referenced). Result is returned as a decimal instead of i64."
impls:
- options:
overflow:
values: [SILENT, SATURATE, ERROR]
nullability: DECLARED_OUTPUT
decomposable: MANY
intermediate: decimal<38,0>
return: decimal<38,0>
`
// Extension whose two "count" functions return i64.
const arithmeticYaml = `---
aggregate_functions:
- name: "count"
description: Count a set of values
impls:
- args:
- name: x
value: any
options:
overflow:
values: [SILENT, SATURATE, ERROR]
nullability: DECLARED_OUTPUT
decomposable: MANY
intermediate: i64
return: i64
- name: "count"
description: "Count a set of records (not field referenced)"
impls:
- options:
overflow:
values: [SILENT, SATURATE, ERROR]
nullability: DECLARED_OUTPUT
decomposable: MANY
intermediate: i64
return: i64
`

// Dialect that depends on both extensions but only maps aggdec.count: the
// `any` kernel to local name "count" and the empty kernel to "count_rows".
dialectYaml := `
name: test
type: sql
dependencies:
aggregate:
http://localhost/functions_aggregate_generic.yaml
aggdec:
http://localhost/functions_aggregate_decimal_output.yaml
supported_types:
dec:
sql_type_name: numeric
supported_as_column: true
i64:
sql_type_name: BIGINT
supported_as_column: true
aggregate_functions:
- name: aggdec.count
local_name: count
aggregate: true
supported_kernels:
- any
- name: aggdec.count
local_name: count_rows
aggregate: true
supported_kernels:
- ""
`
// Build the substrait function registry from both extensions, then derive
// the dialect-local registry from it.
var c extensions.Collection
require.NoError(t, c.Load(arithmeticUri, strings.NewReader(arithmeticYaml)))
require.NoError(t, c.Load(decimalUri, strings.NewReader(decimalYaml)))
funcRegistry := NewFunctionRegistry(&c)
localRegistry := getLocalFunctionRegistry(t, dialectYaml, funcRegistry)
allFunctions := funcRegistry.GetAllFunctions()

// One case per local name; numSubstraitFunctions is the expected number of
// substrait-registry matches for (substraitName, numArgs).
testcases := []struct {
numArgs int
localName string
substraitName string
signature string
numSubstraitFunctions int
}{
{1, "count", "count", "count:any", 2},
{0, "count_rows", "count", "count:", 2},
}
for _, tt := range testcases {
t.Run(tt.localName, func(t *testing.T) {
// Lookup by local name: the first variant must come from the decimal
// extension and the result must expose only the expected signature.
var fv []*LocalAggregateFunctionVariant
fv = localRegistry.GetAggregateFunctions(LocalFunctionName(tt.localName), tt.numArgs)

require.Greater(t, len(fv), 0)
assert.Equal(t, decimalUri, fv[0].URI())
assert.Equal(t, tt.localName, fv[0].LocalName())
assert.Equal(t, tt.substraitName, fv[0].Name())
checkCompoundNames(t, getAggregateCompoundNames(fv), []string{tt.signature})

// Lookup by substrait name with the same argument count must yield the
// same variant and signature.
fv = localRegistry.GetAggregateFunctions(SubstraitFunctionName(tt.substraitName), tt.numArgs)
require.Greater(t, len(fv), 0)
assert.Equal(t, decimalUri, fv[0].URI())
assert.Equal(t, tt.localName, fv[0].LocalName())
assert.Equal(t, tt.substraitName, fv[0].Name())
checkCompoundNames(t, getAggregateCompoundNames(fv), []string{tt.signature})

// The substrait registry itself is expected to return
// numSubstraitFunctions matches for this name and argument count.
aggregateFunctions := funcRegistry.GetAggregateFunctions(tt.substraitName, tt.numArgs)
assert.Equal(t, tt.numSubstraitFunctions, len(aggregateFunctions))
for _, f := range aggregateFunctions {
assert.Contains(t, allFunctions, f)
}
// Lookup by name alone is expected to return four variants, two from
// each extension URI.
aggregateFunctions = funcRegistry.GetAggregateFunctionsByName(tt.substraitName)
assert.Equal(t, 4, len(aggregateFunctions))
uriMap := map[string]int{
arithmeticUri: 0,
decimalUri: 0,
}
for _, f := range aggregateFunctions {
assert.Contains(t, allFunctions, f)
uriMap[f.URI()]++
}
assert.Equal(t, 2, uriMap[arithmeticUri])
assert.Equal(t, 2, uriMap[decimalUri])
})
}
}
7 changes: 7 additions & 0 deletions functions/registries.go
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,13 @@ const (
POSTFIX
)

// localFunctionVariant is the type constraint for function variants that
// extend extensions.FunctionVariant with dialect-local accessors.
type localFunctionVariant interface {
	extensions.FunctionVariant

	// LocalName returns the variant's local name as a plain string.
	LocalName() string
	// Notation returns the variant's FunctionNotation.
	Notation() FunctionNotation
	// IsOptionSupported takes an option name and value.
	// NOTE(review): presumably reports whether the variant accepts that
	// value for the option; confirm against the implementations.
	IsOptionSupported(name, value string) bool
}

type LocalFunctionVariant struct {
localName string
supportedOptions map[string]extensions.Option
Expand Down

0 comments on commit 5924d58

Please sign in to comment.