diff --git a/go/vt/vterrors/code.go b/go/vt/vterrors/code.go index 26abd85e49e..2aa9827c084 100644 --- a/go/vt/vterrors/code.go +++ b/go/vt/vterrors/code.go @@ -73,6 +73,7 @@ var ( VT09012 = errorWithoutState("VT09012", vtrpcpb.Code_FAILED_PRECONDITION, "%s statement with %s tablet not allowed", "This type of statement is not allowed on the given tablet.") VT09013 = errorWithoutState("VT09013", vtrpcpb.Code_FAILED_PRECONDITION, "semi-sync plugins are not loaded", "Durability policy wants Vitess to use semi-sync, but the MySQL instances don't have the semi-sync plugin loaded.") VT09014 = errorWithoutState("VT09014", vtrpcpb.Code_FAILED_PRECONDITION, "vindex cannot be modified", "The vindex cannot be used as table in DML statement") + VT09015 = errorWithoutState("VT09015", vtrpcpb.Code_FAILED_PRECONDITION, "schema tracking required", "This query cannot be planned without more information on the SQL schema. Please turn on schema tracking or add authoritative columns information to your VSchema.") VT10001 = errorWithoutState("VT10001", vtrpcpb.Code_ABORTED, "foreign key constraints are not allowed", "Foreign key constraints are not allowed, see https://vitess.io/blog/2021-06-15-online-ddl-why-no-fk/.") @@ -136,6 +137,7 @@ var ( VT09012, VT09013, VT09014, + VT09015, VT10001, VT12001, VT13001, diff --git a/go/vt/vtgate/engine/aggregations.go b/go/vt/vtgate/engine/aggregations.go index 9bdee013cf8..abe542c1e4c 100644 --- a/go/vt/vtgate/engine/aggregations.go +++ b/go/vt/vtgate/engine/aggregations.go @@ -20,6 +20,8 @@ import ( "fmt" "strconv" + "vitess.io/vitess/go/vt/vterrors" + "google.golang.org/protobuf/proto" "vitess.io/vitess/go/mysql/collations" @@ -41,7 +43,6 @@ type AggregateParams struct { // These are used only for distinct opcodes. 
KeyCol int WCol int - WAssigned bool CollationID collations.ID Alias string `json:",omitempty"` @@ -53,22 +54,26 @@ type AggregateParams struct { OrigOpcode AggregateOpcode } -func (ap *AggregateParams) isDistinct() bool { - return ap.Opcode == AggregateCountDistinct || ap.Opcode == AggregateSumDistinct +func NewAggregateParam(opcode AggregateOpcode, col int, alias string) *AggregateParams { + out := &AggregateParams{ + Opcode: opcode, + Col: col, + Alias: alias, + WCol: -1, + } + if opcode.NeedsComparableValues() { + out.KeyCol = col + } + return out } -func (ap *AggregateParams) preProcess() bool { - switch ap.Opcode { - case AggregateCountDistinct, AggregateSumDistinct, AggregateGtid, AggregateCount, AggregateGroupConcat: - return true - default: - return false - } +func (ap *AggregateParams) WAssigned() bool { + return ap.WCol >= 0 } func (ap *AggregateParams) String() string { keyCol := strconv.Itoa(ap.Col) - if ap.WAssigned { + if ap.WAssigned() { keyCol = fmt.Sprintf("%s|%d", keyCol, ap.WCol) } if ap.CollationID != collations.Unknown { @@ -161,7 +166,7 @@ func merge( ) ([]sqltypes.Value, []sqltypes.Value, error) { result := sqltypes.CopyRow(row1) for index, aggr := range aggregates { - if aggr.isDistinct() { + if aggr.Opcode.IsDistinct() { if row2[aggr.KeyCol].IsNull() { continue } @@ -194,8 +199,14 @@ func merge( } result[aggr.Col], err = evalengine.NullSafeAdd(value, v2, fields[aggr.Col].Type) case AggregateMin: + if aggr.WAssigned() && !row2[aggr.Col].IsComparable() { + return minMaxWeightStringError() + } result[aggr.Col], err = evalengine.Min(row1[aggr.Col], row2[aggr.Col], aggr.CollationID) case AggregateMax: + if aggr.WAssigned() && !row2[aggr.Col].IsComparable() { + return minMaxWeightStringError() + } result[aggr.Col], err = evalengine.Max(row1[aggr.Col], row2[aggr.Col], aggr.CollationID) case AggregateCountDistinct: result[aggr.Col], err = evalengine.NullSafeAdd(row1[aggr.Col], countOne, fields[aggr.Col].Type) @@ -241,6 +252,10 @@ func merge( 
return result, curDistincts, nil } +func minMaxWeightStringError() ([]sqltypes.Value, []sqltypes.Value, error) { + return nil, nil, vterrors.VT12001("min/max on types that are not comparable is not supported") +} + func convertFinal(current []sqltypes.Value, aggregates []*AggregateParams) ([]sqltypes.Value, error) { result := sqltypes.CopyRow(current) for _, aggr := range aggregates { @@ -270,17 +285,13 @@ func convertFields(fields []*querypb.Field, aggrs []*AggregateParams) []*querypb if aggr.Alias != "" { fields[aggr.Col].Name = aggr.Alias } - if aggr.isDistinct() { - // TODO: this should move to plan time - aggr.KeyCol = aggr.Col - } } return fields } func findComparableCurrentDistinct(row []sqltypes.Value, aggr *AggregateParams) sqltypes.Value { curDistinct := row[aggr.KeyCol] - if aggr.WAssigned && !curDistinct.IsComparable() { + if aggr.WAssigned() && !curDistinct.IsComparable() { aggr.KeyCol = aggr.WCol curDistinct = row[aggr.KeyCol] } diff --git a/go/vt/vtgate/engine/opcode/constants.go b/go/vt/vtgate/engine/opcode/constants.go index fa54b898ede..818a9e67db6 100644 --- a/go/vt/vtgate/engine/opcode/constants.go +++ b/go/vt/vtgate/engine/opcode/constants.go @@ -161,3 +161,21 @@ func (code AggregateOpcode) Type(typ *querypb.Type) (querypb.Type, bool) { panic(code.String()) // we have a unit test checking we never reach here } } + +func (code AggregateOpcode) NeedsComparableValues() bool { + switch code { + case AggregateCountDistinct, AggregateSumDistinct, AggregateMin, AggregateMax: + return true + default: + return false + } +} + +func (code AggregateOpcode) IsDistinct() bool { + switch code { + case AggregateCountDistinct, AggregateSumDistinct: + return true + default: + return false + } +} diff --git a/go/vt/vtgate/engine/ordered_aggregate_test.go b/go/vt/vtgate/engine/ordered_aggregate_test.go index 83f2cb4ecf4..539e81b4c59 100644 --- a/go/vt/vtgate/engine/ordered_aggregate_test.go +++ b/go/vt/vtgate/engine/ordered_aggregate_test.go @@ -45,7 +45,6 @@ func 
init() { } func TestOrderedAggregateExecute(t *testing.T) { - assert := assert.New(t) fields := sqltypes.MakeTestFields( "col|count(*)", "varbinary|decimal", @@ -62,16 +61,13 @@ func TestOrderedAggregateExecute(t *testing.T) { } oa := &OrderedAggregate{ - Aggregates: []*AggregateParams{{ - Opcode: AggregateSum, - Col: 1, - }}, + Aggregates: []*AggregateParams{NewAggregateParam(AggregateSum, 1, "")}, GroupByKeys: []*GroupByParams{{KeyCol: 0}}, Input: fp, } result, err := oa.TryExecute(context.Background(), &noopVCursor{}, nil, false) - assert.NoError(err) + assert.NoError(t, err) wantResult := sqltypes.MakeTestResult( fields, @@ -83,7 +79,6 @@ func TestOrderedAggregateExecute(t *testing.T) { } func TestOrderedAggregateExecuteTruncate(t *testing.T) { - assert := assert.New(t) fp := &fakePrimitive{ results: []*sqltypes.Result{sqltypes.MakeTestResult( sqltypes.MakeTestFields( @@ -98,19 +93,18 @@ func TestOrderedAggregateExecuteTruncate(t *testing.T) { )}, } + aggr := NewAggregateParam(AggregateSum, 1, "") + aggr.OrigOpcode = AggregateCountStar + oa := &OrderedAggregate{ - Aggregates: []*AggregateParams{{ - OrigOpcode: AggregateCountStar, - Opcode: AggregateSum, - Col: 1, - }}, + Aggregates: []*AggregateParams{aggr}, GroupByKeys: []*GroupByParams{{KeyCol: 2}}, TruncateColumnCount: 2, Input: fp, } result, err := oa.TryExecute(context.Background(), &noopVCursor{}, nil, false) - assert.NoError(err) + assert.NoError(t, err) wantResult := sqltypes.MakeTestResult( sqltypes.MakeTestFields( @@ -124,8 +118,34 @@ func TestOrderedAggregateExecuteTruncate(t *testing.T) { utils.MustMatch(t, wantResult, result) } +func TestMinMaxFailsCorrectly(t *testing.T) { + fp := &fakePrimitive{ + results: []*sqltypes.Result{sqltypes.MakeTestResult( + sqltypes.MakeTestFields( + "col|weight_string(col)", + "varchar|varbinary", + ), + "a|A", + "A|A", + "b|B", + "C|C", + "c|C", + )}, + } + + aggr := NewAggregateParam(AggregateMax, 0, "") + aggr.WCol = 1 + oa := &ScalarAggregate{ + Aggregates: 
[]*AggregateParams{aggr}, + TruncateColumnCount: 1, + Input: fp, + } + + _, err := oa.TryExecute(context.Background(), &noopVCursor{}, nil, false) + assert.ErrorContains(t, err, "min/max on types that are not comparable is not supported") +} + func TestOrderedAggregateStreamExecute(t *testing.T) { - assert := assert.New(t) fields := sqltypes.MakeTestFields( "col|count(*)", "varbinary|decimal", @@ -142,10 +162,7 @@ func TestOrderedAggregateStreamExecute(t *testing.T) { } oa := &OrderedAggregate{ - Aggregates: []*AggregateParams{{ - Opcode: AggregateSum, - Col: 1, - }}, + Aggregates: []*AggregateParams{NewAggregateParam(AggregateSum, 1, "")}, GroupByKeys: []*GroupByParams{{KeyCol: 0}}, Input: fp, } @@ -155,7 +172,7 @@ func TestOrderedAggregateStreamExecute(t *testing.T) { results = append(results, qr) return nil }) - assert.NoError(err) + assert.NoError(t, err) wantResults := sqltypes.MakeTestStreamingResults( fields, @@ -169,7 +186,6 @@ func TestOrderedAggregateStreamExecute(t *testing.T) { } func TestOrderedAggregateStreamExecuteTruncate(t *testing.T) { - assert := assert.New(t) fp := &fakePrimitive{ results: []*sqltypes.Result{sqltypes.MakeTestResult( sqltypes.MakeTestFields( @@ -185,10 +201,7 @@ func TestOrderedAggregateStreamExecuteTruncate(t *testing.T) { } oa := &OrderedAggregate{ - Aggregates: []*AggregateParams{{ - Opcode: AggregateSum, - Col: 1, - }}, + Aggregates: []*AggregateParams{NewAggregateParam(AggregateSum, 1, "")}, GroupByKeys: []*GroupByParams{{KeyCol: 2}}, TruncateColumnCount: 2, Input: fp, @@ -199,7 +212,7 @@ func TestOrderedAggregateStreamExecuteTruncate(t *testing.T) { results = append(results, qr) return nil }) - assert.NoError(err) + assert.NoError(t, err) wantResults := sqltypes.MakeTestStreamingResults( sqltypes.MakeTestFields( @@ -216,7 +229,6 @@ func TestOrderedAggregateStreamExecuteTruncate(t *testing.T) { } func TestOrderedAggregateGetFields(t *testing.T) { - assert := assert.New(t) input := sqltypes.MakeTestResult( 
sqltypes.MakeTestFields( "col|count(*)", @@ -228,34 +240,8 @@ func TestOrderedAggregateGetFields(t *testing.T) { oa := &OrderedAggregate{Input: fp} got, err := oa.GetFields(context.Background(), nil, nil) - assert.NoError(err) - assert.Equal(got, input) -} - -func TestOrderedAggregateGetFieldsTruncate(t *testing.T) { - assert := assert.New(t) - result := sqltypes.MakeTestResult( - sqltypes.MakeTestFields( - "col|count(*)|weight_string(col)", - "varchar|decimal|varbinary", - ), - ) - fp := &fakePrimitive{results: []*sqltypes.Result{result}} - - oa := &OrderedAggregate{ - TruncateColumnCount: 2, - Input: fp, - } - - got, err := oa.GetFields(context.Background(), nil, nil) - assert.NoError(err) - wantResult := sqltypes.MakeTestResult( - sqltypes.MakeTestFields( - "col|count(*)", - "varchar|decimal", - ), - ) - utils.MustMatch(t, wantResult, got) + assert.NoError(t, err) + assert.Equal(t, got, input) } func TestOrderedAggregateInputFail(t *testing.T) { @@ -280,7 +266,6 @@ func TestOrderedAggregateInputFail(t *testing.T) { } func TestOrderedAggregateExecuteCountDistinct(t *testing.T) { - assert := assert.New(t) fp := &fakePrimitive{ results: []*sqltypes.Result{sqltypes.MakeTestResult( sqltypes.MakeTestFields( @@ -319,23 +304,17 @@ func TestOrderedAggregateExecuteCountDistinct(t *testing.T) { )}, } + aggr1 := NewAggregateParam(AggregateCountDistinct, 1, "count(distinct col2)") + aggr2 := NewAggregateParam(AggregateSum, 2, "") + aggr2.OrigOpcode = AggregateCountStar oa := &OrderedAggregate{ - Aggregates: []*AggregateParams{{ - Opcode: AggregateCountDistinct, - Col: 1, - Alias: "count(distinct col2)", - }, { - // Also add a count(*) - OrigOpcode: AggregateCountStar, - Opcode: AggregateSum, - Col: 2, - }}, + Aggregates: []*AggregateParams{aggr1, aggr2}, GroupByKeys: []*GroupByParams{{KeyCol: 0}}, Input: fp, } result, err := oa.TryExecute(context.Background(), &noopVCursor{}, nil, false) - assert.NoError(err) + assert.NoError(t, err) wantResult := sqltypes.MakeTestResult( 
sqltypes.MakeTestFields( @@ -356,7 +335,6 @@ func TestOrderedAggregateExecuteCountDistinct(t *testing.T) { } func TestOrderedAggregateStreamCountDistinct(t *testing.T) { - assert := assert.New(t) fp := &fakePrimitive{ results: []*sqltypes.Result{sqltypes.MakeTestResult( sqltypes.MakeTestFields( @@ -395,16 +373,13 @@ func TestOrderedAggregateStreamCountDistinct(t *testing.T) { )}, } + aggr2 := NewAggregateParam(AggregateSum, 2, "") + aggr2.OrigOpcode = AggregateCountDistinct + oa := &OrderedAggregate{ - Aggregates: []*AggregateParams{{ - Opcode: AggregateCountDistinct, - Col: 1, - Alias: "count(distinct col2)", - }, { - Opcode: AggregateSum, - OrigOpcode: AggregateCountStar, - Col: 2, - }}, + Aggregates: []*AggregateParams{ + NewAggregateParam(AggregateCountDistinct, 1, "count(distinct col2)"), + aggr2}, GroupByKeys: []*GroupByParams{{KeyCol: 0}}, Input: fp, } @@ -414,7 +389,7 @@ func TestOrderedAggregateStreamCountDistinct(t *testing.T) { results = append(results, qr) return nil }) - assert.NoError(err) + assert.NoError(t, err) wantResults := sqltypes.MakeTestStreamingResults( sqltypes.MakeTestFields( @@ -443,7 +418,6 @@ func TestOrderedAggregateStreamCountDistinct(t *testing.T) { } func TestOrderedAggregateSumDistinctGood(t *testing.T) { - assert := assert.New(t) fp := &fakePrimitive{ results: []*sqltypes.Result{sqltypes.MakeTestResult( sqltypes.MakeTestFields( @@ -484,20 +458,16 @@ func TestOrderedAggregateSumDistinctGood(t *testing.T) { } oa := &OrderedAggregate{ - Aggregates: []*AggregateParams{{ - Opcode: AggregateSumDistinct, - Col: 1, - Alias: "sum(distinct col2)", - }, { - Opcode: AggregateSum, - Col: 2, - }}, + Aggregates: []*AggregateParams{ + NewAggregateParam(AggregateSumDistinct, 1, "sum(distinct col2)"), + NewAggregateParam(AggregateSum, 2, ""), + }, GroupByKeys: []*GroupByParams{{KeyCol: 0}}, Input: fp, } result, err := oa.TryExecute(context.Background(), &noopVCursor{}, nil, false) - assert.NoError(err) + assert.NoError(t, err) wantResult := 
sqltypes.MakeTestResult( sqltypes.MakeTestFields( @@ -516,7 +486,7 @@ func TestOrderedAggregateSumDistinctGood(t *testing.T) { ) want := fmt.Sprintf("%v", wantResult.Rows) got := fmt.Sprintf("%v", result.Rows) - assert.Equal(want, got) + assert.Equal(t, want, got) } func TestOrderedAggregateSumDistinctTolerateError(t *testing.T) { @@ -533,11 +503,7 @@ func TestOrderedAggregateSumDistinctTolerateError(t *testing.T) { } oa := &OrderedAggregate{ - Aggregates: []*AggregateParams{{ - Opcode: AggregateSumDistinct, - Col: 1, - Alias: "sum(distinct col2)", - }}, + Aggregates: []*AggregateParams{NewAggregateParam(AggregateSumDistinct, 1, "sum(distinct col2)")}, GroupByKeys: []*GroupByParams{{KeyCol: 0}}, Input: fp, } @@ -569,10 +535,7 @@ func TestOrderedAggregateKeysFail(t *testing.T) { } oa := &OrderedAggregate{ - Aggregates: []*AggregateParams{{ - Opcode: AggregateSum, - Col: 1, - }}, + Aggregates: []*AggregateParams{NewAggregateParam(AggregateSum, 1, "")}, GroupByKeys: []*GroupByParams{{KeyCol: 0}}, Input: fp, } @@ -602,10 +565,7 @@ func TestOrderedAggregateMergeFail(t *testing.T) { } oa := &OrderedAggregate{ - Aggregates: []*AggregateParams{{ - Opcode: AggregateSum, - Col: 1, - }}, + Aggregates: []*AggregateParams{NewAggregateParam(AggregateSum, 1, "")}, GroupByKeys: []*GroupByParams{{KeyCol: 0}}, Input: fp, } @@ -640,22 +600,13 @@ func TestOrderedAggregateMergeFail(t *testing.T) { } func TestMerge(t *testing.T) { - assert := assert.New(t) oa := &OrderedAggregate{ - Aggregates: []*AggregateParams{{ - Opcode: AggregateSum, - Col: 1, - }, { - Opcode: AggregateSum, - Col: 2, - }, { - Opcode: AggregateMin, - Col: 3, - }, { - Opcode: AggregateMax, - Col: 4, - }}, - } + Aggregates: []*AggregateParams{ + NewAggregateParam(AggregateSum, 1, ""), + NewAggregateParam(AggregateSum, 2, ""), + NewAggregateParam(AggregateMin, 3, ""), + NewAggregateParam(AggregateMax, 4, ""), + }} fields := sqltypes.MakeTestFields( "a|b|c|d|e", "int64|int64|decimal|in32|varbinary", @@ -666,14 +617,14 
@@ func TestMerge(t *testing.T) { ) merged, _, err := merge(fields, r.Rows[0], r.Rows[1], nil, oa.Aggregates) - assert.NoError(err) + assert.NoError(t, err) want := sqltypes.MakeTestResult(fields, "1|5|6.0|2|bc").Rows[0] - assert.Equal(want, merged) + assert.Equal(t, want, merged) // swap and retry merged, _, err = merge(fields, r.Rows[1], r.Rows[0], nil, oa.Aggregates) - assert.NoError(err) - assert.Equal(want, merged) + assert.NoError(t, err) + assert.Equal(t, want, merged) } func TestOrderedAggregateExecuteGtid(t *testing.T) { @@ -703,11 +654,7 @@ func TestOrderedAggregateExecuteGtid(t *testing.T) { } oa := &OrderedAggregate{ - Aggregates: []*AggregateParams{{ - Opcode: AggregateGtid, - Col: 1, - Alias: "vgtid", - }}, + Aggregates: []*AggregateParams{NewAggregateParam(AggregateGtid, 1, "vgtid")}, TruncateColumnCount: 2, Input: fp, } @@ -740,14 +687,10 @@ func TestCountDistinctOnVarchar(t *testing.T) { )}, } + aggr := NewAggregateParam(AggregateCountDistinct, 1, "count(distinct c2)") + aggr.WCol = 2 oa := &OrderedAggregate{ - Aggregates: []*AggregateParams{{ - Opcode: AggregateCountDistinct, - Col: 1, - WCol: 2, - WAssigned: true, - Alias: "count(distinct c2)", - }}, + Aggregates: []*AggregateParams{aggr}, GroupByKeys: []*GroupByParams{{KeyCol: 0}}, Input: fp, TruncateColumnCount: 2, @@ -804,14 +747,10 @@ func TestCountDistinctOnVarcharWithNulls(t *testing.T) { )}, } + aggr := NewAggregateParam(AggregateCountDistinct, 1, "count(distinct c2)") + aggr.WCol = 2 oa := &OrderedAggregate{ - Aggregates: []*AggregateParams{{ - Opcode: AggregateCountDistinct, - Col: 1, - WCol: 2, - WAssigned: true, - Alias: "count(distinct c2)", - }}, + Aggregates: []*AggregateParams{aggr}, GroupByKeys: []*GroupByParams{{KeyCol: 0}}, Input: fp, TruncateColumnCount: 2, @@ -870,14 +809,10 @@ func TestSumDistinctOnVarcharWithNulls(t *testing.T) { )}, } + aggr := NewAggregateParam(AggregateSumDistinct, 1, "sum(distinct c2)") + aggr.WCol = 2 oa := &OrderedAggregate{ - Aggregates: 
[]*AggregateParams{{ - Opcode: AggregateSumDistinct, - Col: 1, - WCol: 2, - WAssigned: true, - Alias: "sum(distinct c2)", - }}, + Aggregates: []*AggregateParams{aggr}, GroupByKeys: []*GroupByParams{{KeyCol: 0}}, Input: fp, TruncateColumnCount: 2, @@ -939,15 +874,10 @@ func TestMultiDistinct(t *testing.T) { } oa := &OrderedAggregate{ - Aggregates: []*AggregateParams{{ - Opcode: AggregateCountDistinct, - Col: 1, - Alias: "count(distinct c2)", - }, { - Opcode: AggregateSumDistinct, - Col: 2, - Alias: "sum(distinct c3)", - }}, + Aggregates: []*AggregateParams{ + NewAggregateParam(AggregateCountDistinct, 1, "count(distinct c2)"), + NewAggregateParam(AggregateSumDistinct, 2, "sum(distinct c3)"), + }, GroupByKeys: []*GroupByParams{{KeyCol: 0}}, Input: fp, } @@ -982,7 +912,6 @@ func TestMultiDistinct(t *testing.T) { } func TestOrderedAggregateCollate(t *testing.T) { - assert := assert.New(t) fields := sqltypes.MakeTestFields( "col|count(*)", "varchar|decimal", @@ -1004,16 +933,13 @@ func TestOrderedAggregateCollate(t *testing.T) { collationID, _ := collationEnv.LookupID("utf8mb4_0900_ai_ci") oa := &OrderedAggregate{ - Aggregates: []*AggregateParams{{ - Opcode: AggregateSum, - Col: 1, - }}, + Aggregates: []*AggregateParams{NewAggregateParam(AggregateSum, 1, "")}, GroupByKeys: []*GroupByParams{{KeyCol: 0, CollationID: collationID}}, Input: fp, } result, err := oa.TryExecute(context.Background(), &noopVCursor{}, nil, false) - assert.NoError(err) + assert.NoError(t, err) wantResult := sqltypes.MakeTestResult( fields, @@ -1026,7 +952,6 @@ func TestOrderedAggregateCollate(t *testing.T) { } func TestOrderedAggregateCollateAS(t *testing.T) { - assert := assert.New(t) fields := sqltypes.MakeTestFields( "col|count(*)", "varchar|decimal", @@ -1046,16 +971,13 @@ func TestOrderedAggregateCollateAS(t *testing.T) { collationID, _ := collationEnv.LookupID("utf8mb4_0900_as_ci") oa := &OrderedAggregate{ - Aggregates: []*AggregateParams{{ - Opcode: AggregateSum, - Col: 1, - }}, + Aggregates: 
[]*AggregateParams{NewAggregateParam(AggregateSum, 1, "")}, GroupByKeys: []*GroupByParams{{KeyCol: 0, CollationID: collationID}}, Input: fp, } result, err := oa.TryExecute(context.Background(), &noopVCursor{}, nil, false) - assert.NoError(err) + assert.NoError(t, err) wantResult := sqltypes.MakeTestResult( fields, @@ -1069,7 +991,6 @@ func TestOrderedAggregateCollateAS(t *testing.T) { } func TestOrderedAggregateCollateKS(t *testing.T) { - assert := assert.New(t) fields := sqltypes.MakeTestFields( "col|count(*)", "varchar|decimal", @@ -1090,16 +1011,13 @@ func TestOrderedAggregateCollateKS(t *testing.T) { collationID, _ := collationEnv.LookupID("utf8mb4_ja_0900_as_cs_ks") oa := &OrderedAggregate{ - Aggregates: []*AggregateParams{{ - Opcode: AggregateSum, - Col: 1, - }}, + Aggregates: []*AggregateParams{NewAggregateParam(AggregateSum, 1, "")}, GroupByKeys: []*GroupByParams{{KeyCol: 0, CollationID: collationID}}, Input: fp, } result, err := oa.TryExecute(context.Background(), &noopVCursor{}, nil, false) - assert.NoError(err) + assert.NoError(t, err) wantResult := sqltypes.MakeTestResult( fields, @@ -1175,11 +1093,7 @@ func TestGroupConcatWithAggrOnEngine(t *testing.T) { t.Run(tcase.name, func(t *testing.T) { fp := &fakePrimitive{results: []*sqltypes.Result{tcase.inputResult}} oa := &OrderedAggregate{ - Aggregates: []*AggregateParams{{ - Opcode: AggregateGroupConcat, - Col: 1, - Alias: "group_concat(c2)", - }}, + Aggregates: []*AggregateParams{NewAggregateParam(AggregateGroupConcat, 1, "group_concat(c2)")}, GroupByKeys: []*GroupByParams{{KeyCol: 0}}, Input: fp, } @@ -1258,10 +1172,7 @@ func TestGroupConcat(t *testing.T) { t.Run(tcase.name, func(t *testing.T) { fp := &fakePrimitive{results: []*sqltypes.Result{tcase.inputResult}} oa := &OrderedAggregate{ - Aggregates: []*AggregateParams{{ - Opcode: AggregateGroupConcat, - Col: 1, - }}, + Aggregates: []*AggregateParams{NewAggregateParam(AggregateGroupConcat, 1, "")}, GroupByKeys: []*GroupByParams{{KeyCol: 0}}, Input: fp, 
} diff --git a/go/vt/vtgate/engine/scalar_aggregation_test.go b/go/vt/vtgate/engine/scalar_aggregation_test.go index 273a6228c8f..033d69c9271 100644 --- a/go/vt/vtgate/engine/scalar_aggregation_test.go +++ b/go/vt/vtgate/engine/scalar_aggregation_test.go @@ -69,7 +69,6 @@ func TestEmptyRows(outer *testing.T) { for _, test := range testCases { outer.Run(test.opcode.String(), func(t *testing.T) { - assert := assert.New(t) fp := &fakePrimitive{ results: []*sqltypes.Result{sqltypes.MakeTestResult( sqltypes.MakeTestFields( @@ -91,7 +90,7 @@ func TestEmptyRows(outer *testing.T) { } result, err := oa.TryExecute(context.Background(), &noopVCursor{}, nil, false) - assert.NoError(err) + assert.NoError(t, err) wantResult := sqltypes.MakeTestResult( sqltypes.MakeTestFields( @@ -106,7 +105,6 @@ func TestEmptyRows(outer *testing.T) { } func TestScalarAggregateStreamExecute(t *testing.T) { - assert := assert.New(t) fields := sqltypes.MakeTestFields( "col|weight_string(col)", "uint64|varbinary", @@ -135,17 +133,16 @@ func TestScalarAggregateStreamExecute(t *testing.T) { results = append(results, qr) return nil }) - assert.NoError(err) + assert.NoError(t, err) // one for the fields, and one for the actual aggregation result require.EqualValues(t, 2, len(results), "number of results") got := fmt.Sprintf("%v", results[1].Rows) - assert.Equal("[[DECIMAL(4)]]", got) + assert.Equal(t, "[[DECIMAL(4)]]", got) } // TestScalarAggregateExecuteTruncate checks if truncate works func TestScalarAggregateExecuteTruncate(t *testing.T) { - assert := assert.New(t) fields := sqltypes.MakeTestFields( "col|weight_string(col)", "uint64|varbinary", @@ -169,8 +166,8 @@ func TestScalarAggregateExecuteTruncate(t *testing.T) { } qr, err := oa.TryExecute(context.Background(), &noopVCursor{}, nil, true) - assert.NoError(err) - assert.Equal("[[DECIMAL(4)]]", fmt.Sprintf("%v", qr.Rows)) + assert.NoError(t, err) + assert.Equal(t, "[[DECIMAL(4)]]", fmt.Sprintf("%v", qr.Rows)) } // 
TestScalarGroupConcatWithAggrOnEngine tests group_concat with full aggregation on engine. diff --git a/go/vt/vtgate/executor_select_test.go b/go/vt/vtgate/executor_select_test.go index dcaf55c6aa0..ed5e72fd04c 100644 --- a/go/vt/vtgate/executor_select_test.go +++ b/go/vt/vtgate/executor_select_test.go @@ -3688,15 +3688,15 @@ func TestSelectAggregationNoData(t *testing.T) { }, { sql: `select count(*) from (select col1, col2 from user limit 2) x`, - sandboxRes: sqltypes.MakeTestResult(sqltypes.MakeTestFields("col1|col2", "int64|int64")), - expSandboxQ: "select col1, col2 from `user` limit :__upper_limit", + sandboxRes: sqltypes.MakeTestResult(sqltypes.MakeTestFields("col1", "int64")), + expSandboxQ: "select 1 from (select col1, col2 from `user`) as x limit :__upper_limit", expField: `[name:"count(*)" type:INT64]`, expRow: `[[INT64(0)]]`, }, { sql: `select col2, count(*) from (select col1, col2 from user limit 2) x group by col2`, - sandboxRes: sqltypes.MakeTestResult(sqltypes.MakeTestFields("col1|col2|weight_string(col2)", "int64|int64|varbinary")), - expSandboxQ: "select col1, col2, weight_string(col2) from `user` order by col2 asc limit :__upper_limit", + sandboxRes: sqltypes.MakeTestResult(sqltypes.MakeTestFields("col2|1|weight_string(col2)", "int64|int64|varbinary")), + expSandboxQ: "select col2, 1, weight_string(col2) from (select col1, col2 from `user`) as x limit :__upper_limit", expField: `[name:"col2" type:INT64 name:"count(*)" type:INT64]`, expRow: `[]`, }, @@ -3772,71 +3772,71 @@ func TestSelectAggregationData(t *testing.T) { }, { sql: `select count(*) from (select col1, col2 from user limit 2) x`, - sandboxRes: sqltypes.MakeTestResult(sqltypes.MakeTestFields("col1|col2", "int64|int64"), "1|2", "2|1"), - expSandboxQ: "select col1, col2 from `user` limit :__upper_limit", + sandboxRes: sqltypes.MakeTestResult(sqltypes.MakeTestFields("1", "int64"), "1", "1"), + expSandboxQ: "select 1 from (select col1, col2 from `user`) as x limit :__upper_limit", expField: 
`[name:"count(*)" type:INT64]`, expRow: `[[INT64(2)]]`, }, { sql: `select col2, count(*) from (select col1, col2 from user limit 9) x group by col2`, - sandboxRes: sqltypes.MakeTestResult(sqltypes.MakeTestFields("col1|col2|weight_string(col2)", "int64|int64|varbinary"), "3|1|NULL", "2|2|NULL"), - expSandboxQ: "select col1, col2, weight_string(col2) from `user` order by col2 asc limit :__upper_limit", + sandboxRes: sqltypes.MakeTestResult(sqltypes.MakeTestFields("col2|1|weight_string(col2)", "int64|int64|varbinary"), "3|1|NULL", "2|1|NULL"), + expSandboxQ: "select col2, 1, weight_string(col2) from (select col1, col2 from `user`) as x limit :__upper_limit", expField: `[name:"col2" type:INT64 name:"count(*)" type:INT64]`, - expRow: `[[INT64(1) INT64(8)] [INT64(2) INT64(1)]]`, + expRow: `[[INT64(2) INT64(4)] [INT64(3) INT64(5)]]`, }, { sql: `select count(col1) from (select id, col1 from user limit 2) x`, - sandboxRes: sqltypes.MakeTestResult(sqltypes.MakeTestFields("id|col1", "int64|varchar"), "3|a", "2|b"), - expSandboxQ: "select id, col1 from `user` limit :__upper_limit", + sandboxRes: sqltypes.MakeTestResult(sqltypes.MakeTestFields("col1", "varchar"), "a", "b"), + expSandboxQ: "select col1 from (select id, col1 from `user`) as x limit :__upper_limit", expField: `[name:"count(col1)" type:INT64]`, expRow: `[[INT64(2)]]`, }, { sql: `select count(col1), col2 from (select col2, col1 from user limit 9) x group by col2`, - sandboxRes: sqltypes.MakeTestResult(sqltypes.MakeTestFields("col2|col1|weight_string(col2)", "int64|varchar|varbinary"), "3|a|NULL", "2|b|NULL"), - expSandboxQ: "select col2, col1, weight_string(col2) from `user` order by col2 asc limit :__upper_limit", + sandboxRes: sqltypes.MakeTestResult(sqltypes.MakeTestFields("col1|col2|weight_string(col2)", "varchar|int64|varbinary"), "a|3|NULL", "b|2|NULL"), + expSandboxQ: "select col1, col2, weight_string(col2) from (select col2, col1 from `user`) as x limit :__upper_limit", expField: `[name:"count(col1)" 
type:INT64 name:"col2" type:INT64]`, - expRow: `[[INT64(8) INT64(2)] [INT64(1) INT64(3)]]`, + expRow: `[[INT64(4) INT64(2)] [INT64(5) INT64(3)]]`, }, { sql: `select col1, count(col2) from (select col1, col2 from user limit 9) x group by col1`, sandboxRes: sqltypes.MakeTestResult(sqltypes.MakeTestFields("col1|col2|weight_string(col1)", "varchar|int64|varbinary"), "a|1|a", "b|null|b"), - expSandboxQ: "select col1, col2, weight_string(col1) from `user` order by col1 asc limit :__upper_limit", + expSandboxQ: "select col1, col2, weight_string(col1) from (select col1, col2 from `user`) as x limit :__upper_limit", expField: `[name:"col1" type:VARCHAR name:"count(col2)" type:INT64]`, - expRow: `[[VARCHAR("a") INT64(8)] [VARCHAR("b") INT64(0)]]`, + expRow: `[[VARCHAR("a") INT64(5)] [VARCHAR("b") INT64(0)]]`, }, { sql: `select col1, count(col2) from (select col1, col2 from user limit 32) x group by col1`, sandboxRes: sqltypes.MakeTestResult(sqltypes.MakeTestFields("col1|col2|weight_string(col1)", "varchar|int64|varbinary"), "null|1|null", "null|null|null", "a|1|a", "b|null|b"), - expSandboxQ: "select col1, col2, weight_string(col1) from `user` order by col1 asc limit :__upper_limit", + expSandboxQ: "select col1, col2, weight_string(col1) from (select col1, col2 from `user`) as x limit :__upper_limit", expField: `[name:"col1" type:VARCHAR name:"count(col2)" type:INT64]`, expRow: `[[NULL INT64(8)] [VARCHAR("a") INT64(8)] [VARCHAR("b") INT64(0)]]`, }, { sql: `select col1, sum(col2) from (select col1, col2 from user limit 4) x group by col1`, sandboxRes: sqltypes.MakeTestResult(sqltypes.MakeTestFields("col1|col2|weight_string(col1)", "varchar|int64|varbinary"), "a|3|a"), - expSandboxQ: "select col1, col2, weight_string(col1) from `user` order by col1 asc limit :__upper_limit", + expSandboxQ: "select col1, col2, weight_string(col1) from (select col1, col2 from `user`) as x limit :__upper_limit", expField: `[name:"col1" type:VARCHAR name:"sum(col2)" type:DECIMAL]`, expRow: 
`[[VARCHAR("a") DECIMAL(12)]]`, }, { sql: `select col1, sum(col2) from (select col1, col2 from user limit 4) x group by col1`, sandboxRes: sqltypes.MakeTestResult(sqltypes.MakeTestFields("col1|col2|weight_string(col1)", "varchar|varchar|varbinary"), "a|2|a"), - expSandboxQ: "select col1, col2, weight_string(col1) from `user` order by col1 asc limit :__upper_limit", + expSandboxQ: "select col1, col2, weight_string(col1) from (select col1, col2 from `user`) as x limit :__upper_limit", expField: `[name:"col1" type:VARCHAR name:"sum(col2)" type:FLOAT64]`, expRow: `[[VARCHAR("a") FLOAT64(8)]]`, }, { sql: `select col1, sum(col2) from (select col1, col2 from user limit 4) x group by col1`, sandboxRes: sqltypes.MakeTestResult(sqltypes.MakeTestFields("col1|col2|weight_string(col1)", "varchar|varchar|varbinary"), "a|x|a"), - expSandboxQ: "select col1, col2, weight_string(col1) from `user` order by col1 asc limit :__upper_limit", + expSandboxQ: "select col1, col2, weight_string(col1) from (select col1, col2 from `user`) as x limit :__upper_limit", expField: `[name:"col1" type:VARCHAR name:"sum(col2)" type:FLOAT64]`, expRow: `[[VARCHAR("a") FLOAT64(0)]]`, }, { sql: `select col1, sum(col2) from (select col1, col2 from user limit 4) x group by col1`, sandboxRes: sqltypes.MakeTestResult(sqltypes.MakeTestFields("col1|col2|weight_string(col1)", "varchar|varchar|varbinary"), "a|null|a"), - expSandboxQ: "select col1, col2, weight_string(col1) from `user` order by col1 asc limit :__upper_limit", + expSandboxQ: "select col1, col2, weight_string(col1) from (select col1, col2 from `user`) as x limit :__upper_limit", expField: `[name:"col1" type:VARCHAR name:"sum(col2)" type:FLOAT64]`, expRow: `[[VARCHAR("a") NULL]]`, }, diff --git a/go/vt/vtgate/planbuilder/gen4_planner.go b/go/vt/vtgate/planbuilder/gen4_planner.go index 87853dca885..de52b22b37a 100644 --- a/go/vt/vtgate/planbuilder/gen4_planner.go +++ b/go/vt/vtgate/planbuilder/gen4_planner.go @@ -592,7 +592,7 @@ func planHorizon(ctx 
*plancontext.PlanningContext, plan logicalPlan, in sqlparse } func planOrderByOnUnion(ctx *plancontext.PlanningContext, plan logicalPlan, union *sqlparser.Union) (logicalPlan, error) { - qp, err := operators.CreateQPFromUnion(ctx, union) + qp, err := operators.CreateQPFromSelectStatement(ctx, union) if err != nil { return nil, err } diff --git a/go/vt/vtgate/planbuilder/horizon_planning.go b/go/vt/vtgate/planbuilder/horizon_planning.go index 1f50b2e19bd..8f78959d72d 100644 --- a/go/vt/vtgate/planbuilder/horizon_planning.go +++ b/go/vt/vtgate/planbuilder/horizon_planning.go @@ -69,7 +69,7 @@ func (hp *horizonPlanning) planHorizon(ctx *plancontext.PlanningContext, plan lo } var err error - hp.qp, err = operators.CreateQPFromSelect(ctx, hp.sel) + hp.qp, err = operators.CreateQPFromSelectStatement(ctx, hp.sel) if err != nil { return nil, err } @@ -440,14 +440,11 @@ func generateAggregateParams(aggrs []operators.Aggr, aggrParamOffsets [][]offset } } - aggrParams[idx] = &engine.AggregateParams{ - Opcode: opcode, - Col: offset, - Alias: aggr.Alias, - Expr: aggr.Original.Expr, - Original: aggr.Original, - OrigOpcode: aggr.OpCode, - } + aggrParam := engine.NewAggregateParam(opcode, offset, aggr.Alias) + aggrParam.Expr = aggr.Original.Expr + aggrParam.Original = aggr.Original + aggrParam.OrigOpcode = aggr.OpCode + aggrParams[idx] = aggrParam } return aggrParams, nil } @@ -478,16 +475,12 @@ func addColumnsToOA( count++ a := aggregationExprs[offset] collID := ctx.SemTable.CollationForExpr(a.Func.GetArg()) - oa.aggregates = append(oa.aggregates, &engine.AggregateParams{ - Opcode: a.OpCode, - Col: o.col, - KeyCol: o.col, - WAssigned: o.wsCol >= 0, - WCol: o.wsCol, - Alias: a.Alias, - Original: a.Original, - CollationID: collID, - }) + aggr := engine.NewAggregateParam(a.OpCode, o.col, a.Alias) + aggr.KeyCol = o.col + aggr.WCol = o.wsCol + aggr.Original = a.Original + aggr.CollationID = collID + oa.aggregates = append(oa.aggregates, aggr) } lastOffset := 
distinctOffsets[len(distinctOffsets)-1] distinctIdx := 0 diff --git a/go/vt/vtgate/planbuilder/operator_transformers.go b/go/vt/vtgate/planbuilder/operator_transformers.go index cde1b756a9a..307936d3c08 100644 --- a/go/vt/vtgate/planbuilder/operator_transformers.go +++ b/go/vt/vtgate/planbuilder/operator_transformers.go @@ -51,8 +51,6 @@ func transformToLogicalPlan(ctx *plancontext.PlanningContext, op ops.Operator, i return transformSubQueryPlan(ctx, op) case *operators.CorrelatedSubQueryOp: return transformCorrelatedSubQueryPlan(ctx, op) - case *operators.Derived: - return transformDerivedPlan(ctx, op) case *operators.Filter: return transformFilter(ctx, op) case *operators.Horizon: @@ -89,16 +87,13 @@ func transformAggregator(ctx *plancontext.PlanningContext, op *operators.Aggrega if aggr.OpCode == opcode.AggregateUnassigned { return nil, vterrors.VT12001(fmt.Sprintf("in scatter query: aggregation function '%s'", sqlparser.String(aggr.Original))) } - oa.aggregates = append(oa.aggregates, &engine.AggregateParams{ - Opcode: aggr.OpCode, - Col: aggr.ColOffset, - Alias: aggr.Alias, - Expr: aggr.Func, - Original: aggr.Original, - OrigOpcode: aggr.OriginalOpCode, - WCol: aggr.WSOffset, - CollationID: aggr.GetCollation(ctx), - }) + aggrParam := engine.NewAggregateParam(aggr.OpCode, aggr.ColOffset, aggr.Alias) + aggrParam.Expr = aggr.Func + aggrParam.Original = aggr.Original + aggrParam.OrigOpcode = aggr.OriginalOpCode + aggrParam.WCol = aggr.WSOffset + aggrParam.CollationID = aggr.GetCollation(ctx) + oa.aggregates = append(oa.aggregates, aggrParam) } for _, groupBy := range op.Grouping { oa.groupByKeys = append(oa.groupByKeys, &engine.GroupByParams{ @@ -274,11 +269,14 @@ func transformFilter(ctx *plancontext.PlanningContext, op *operators.Filter) (lo } func transformHorizon(ctx *plancontext.PlanningContext, op *operators.Horizon, isRoot bool) (logicalPlan, error) { + if op.IsDerived() { + return transformDerivedPlan(ctx, op) + } source, err := transformToLogicalPlan(ctx, 
op.Source, isRoot) if err != nil { return nil, err } - switch node := op.Select.(type) { + switch node := op.Query.(type) { case *sqlparser.Select: hp := horizonPlanning{ sel: node, @@ -865,7 +863,7 @@ func getCollationsFor(ctx *plancontext.PlanningContext, n *operators.Union) []co return colls } -func transformDerivedPlan(ctx *plancontext.PlanningContext, op *operators.Derived) (logicalPlan, error) { +func transformDerivedPlan(ctx *plancontext.PlanningContext, op *operators.Horizon) (logicalPlan, error) { // transforming the inner part of the derived table into a logical plan // so that we can do horizon planning on the inner. If the logical plan // we've produced is a Route, we set its Select.From field to be an aliased diff --git a/go/vt/vtgate/planbuilder/operators/SQL_builder.go b/go/vt/vtgate/planbuilder/operators/SQL_builder.go index 07fa5fbbd9d..573f101471c 100644 --- a/go/vt/vtgate/planbuilder/operators/SQL_builder.go +++ b/go/vt/vtgate/planbuilder/operators/SQL_builder.go @@ -246,7 +246,7 @@ func (ts *tableSorter) Swap(i, j int) { } func (h *Horizon) toSQL(qb *queryBuilder) error { - err := stripDownQuery(h.Select, qb.sel) + err := stripDownQuery(h.Query, qb.sel) if err != nil { return err } @@ -318,9 +318,10 @@ func buildQuery(op ops.Operator, qb *queryBuilder) error { return buildApplyJoin(op, qb) case *Filter: return buildFilter(op, qb) - case *Derived: - return buildDerived(op, qb) case *Horizon: + if op.TableId != nil { + return buildDerived(op, qb) + } return buildHorizon(op, qb) case *Limit: return buildLimit(op, qb) @@ -457,7 +458,7 @@ func buildFilter(op *Filter, qb *queryBuilder) error { return nil } -func buildDerived(op *Derived, qb *queryBuilder) error { +func buildDerived(op *Horizon, qb *queryBuilder) error { err := buildQuery(op.Source, qb) if err != nil { return err @@ -486,7 +487,7 @@ func buildHorizon(op *Horizon, qb *queryBuilder) error { return err } - err = stripDownQuery(op.Select, qb.sel) + err = stripDownQuery(op.Query, qb.sel) if 
err != nil { return err } diff --git a/go/vt/vtgate/planbuilder/operators/aggregation_pushing.go b/go/vt/vtgate/planbuilder/operators/aggregation_pushing.go index 9a596039906..cb506aac959 100644 --- a/go/vt/vtgate/planbuilder/operators/aggregation_pushing.go +++ b/go/vt/vtgate/planbuilder/operators/aggregation_pushing.go @@ -57,23 +57,24 @@ func tryPushingDownAggregator(ctx *plancontext.PlanningContext, aggregator *Aggr } aggregator.Pushed = true - if applyResult != rewrite.SameTree && aggregator.Original { - aggregator.aggregateTheAggregates() - } return } func (a *Aggregator) aggregateTheAggregates() { - for i, aggr := range a.Aggregations { - // Handle different aggregation operations when pushing down through a sharded route. - switch aggr.OpCode { - case opcode.AggregateCount, opcode.AggregateCountStar, opcode.AggregateCountDistinct: - // All count variations turn into SUM above the Route. - // Think of it as we are SUMming together a bunch of distributed COUNTs. - aggr.OriginalOpCode, aggr.OpCode = aggr.OpCode, opcode.AggregateSum - a.Aggregations[i] = aggr - } + for i := range a.Aggregations { + aggregateTheAggregate(a, i) + } +} + +func aggregateTheAggregate(a *Aggregator, i int) { + aggr := a.Aggregations[i] + switch aggr.OpCode { + case opcode.AggregateCount, opcode.AggregateCountStar, opcode.AggregateCountDistinct: + // All count variations turn into SUM above the Route. + // Think of it as we are SUMming together a bunch of distributed COUNTs. + aggr.OriginalOpCode, aggr.OpCode = aggr.OpCode, opcode.AggregateSum + a.Aggregations[i] = aggr } } @@ -92,9 +93,13 @@ func pushDownAggregationThroughRoute( } // Create a new aggregator to be placed below the route. 
- aggrBelowRoute := aggregator.Clone([]ops.Operator{route.Source}).(*Aggregator) - aggrBelowRoute.Pushed = false - aggrBelowRoute.Original = false + aggrBelowRoute := aggregator.SplitAggregatorBelowRoute(route.Inputs()) + aggrBelowRoute.Aggregations = nil + + err := pushDownAggregations(ctx, aggregator, aggrBelowRoute) + if err != nil { + return nil, nil, err + } // Set the source of the route to the new aggregator placed below the route. route.Source = aggrBelowRoute @@ -108,18 +113,45 @@ func pushDownAggregationThroughRoute( return aggregator, rewrite.NewTree("push aggregation under route - keep original", aggregator), nil } +// pushDownAggregations splits aggregations between the original aggregator and the one we are pushing down +func pushDownAggregations(ctx *plancontext.PlanningContext, aggregator *Aggregator, aggrBelowRoute *Aggregator) error { + for i, aggregation := range aggregator.Aggregations { + if !aggregation.Distinct || exprHasUniqueVindex(ctx, aggregation.Func.GetArg()) { + aggrBelowRoute.Aggregations = append(aggrBelowRoute.Aggregations, aggregation) + aggregateTheAggregate(aggregator, i) + continue + } + innerExpr := aggregation.Func.GetArg() + + if aggregator.DistinctExpr != nil { + if ctx.SemTable.EqualsExpr(aggregator.DistinctExpr, innerExpr) { + // we can handle multiple distinct aggregations, as long as they are aggregating on the same expression + aggrBelowRoute.Columns[aggregation.ColOffset] = aeWrap(innerExpr) + continue + } + return vterrors.VT12001(fmt.Sprintf("only one DISTINCT aggregation is allowed in a SELECT: %s", sqlparser.String(aggregation.Original))) + } + + // We handle a distinct aggregation by turning it into a group by and + // doing the aggregating on the vtgate level instead + aggregator.DistinctExpr = innerExpr + aeDistinctExpr := aeWrap(aggregator.DistinctExpr) + + aggrBelowRoute.Columns[aggregation.ColOffset] = aeDistinctExpr + + groupBy := NewGroupBy(aggregator.DistinctExpr, aggregator.DistinctExpr, aeDistinctExpr) + 
groupBy.ColOffset = aggregation.ColOffset + aggrBelowRoute.Grouping = append(aggrBelowRoute.Grouping, groupBy) + } + return nil +} + func pushDownAggregationThroughFilter( ctx *plancontext.PlanningContext, aggregator *Aggregator, filter *Filter, ) (ops.Operator, *rewrite.ApplyResult, error) { - for _, predicate := range filter.Predicates { - if sqlparser.ContainsAggregation(predicate) { - return nil, nil, errHorizonNotPlanned() - } - } - columnsNeeded := collectColNamesNeeded(ctx, filter) // Create a new aggregator to be placed below the route. @@ -145,7 +177,7 @@ withNextColumn: // by splitting one and pushing under a join, we can get rid of this one return aggregator.Source, rewrite.NewTree("push aggregation under filter - remove original", aggregator), nil } - + aggregator.aggregateTheAggregates() return aggregator, rewrite.NewTree("push aggregation under filter - keep original", aggregator), nil } @@ -293,6 +325,7 @@ func pushDownAggregationThroughJoin(ctx *plancontext.PlanningContext, rootAggr * return output, rewrite.NewTree("push Aggregation under join - keep original", rootAggr), nil } + rootAggr.aggregateTheAggregates() rootAggr.Source = output return rootAggr, rewrite.NewTree("push Aggregation under join", rootAggr), nil } @@ -472,6 +505,11 @@ func (ab *aggBuilder) handleAggr(ctx *plancontext.PlanningContext, aggr Aggr) er return errAbortAggrPushing case opcode.AggregateUnassigned: return vterrors.VT12001(fmt.Sprintf("in scatter query: aggregation function '%s'", sqlparser.String(aggr.Original))) + case opcode.AggregateGtid: + // this is only used for SHOW GTID queries that will never contain joins + return vterrors.VT13001("cannot do join with vgtid") + case opcode.AggregateSumDistinct, opcode.AggregateCountDistinct: + return errAbortAggrPushing default: return errHorizonNotPlanned() } diff --git a/go/vt/vtgate/planbuilder/operators/aggregator.go b/go/vt/vtgate/planbuilder/operators/aggregator.go index 9a5b28bccfa..36602e024d4 100644 --- 
a/go/vt/vtgate/planbuilder/operators/aggregator.go +++ b/go/vt/vtgate/planbuilder/operators/aggregator.go @@ -42,6 +42,9 @@ type ( Grouping []GroupBy Aggregations []Aggr + // We support a single distinct aggregation per aggregator. It is stored here + DistinctExpr sqlparser.Expr + // Pushed will be set to true once this aggregation has been pushed deeper in the tree Pushed bool offsetPlanned bool @@ -58,17 +61,12 @@ type ( ) func (a *Aggregator) Clone(inputs []ops.Operator) ops.Operator { - return &Aggregator{ - Source: inputs[0], - Columns: slices.Clone(a.Columns), - Grouping: slices.Clone(a.Grouping), - Aggregations: slices.Clone(a.Aggregations), - Pushed: a.Pushed, - offsetPlanned: a.offsetPlanned, - Original: a.Original, - ResultColumns: a.ResultColumns, - QP: a.QP, - } + kopy := *a + kopy.Source = inputs[0] + kopy.Columns = slices.Clone(a.Columns) + kopy.Grouping = slices.Clone(a.Grouping) + kopy.Aggregations = slices.Clone(a.Aggregations) + return &kopy } func (a *Aggregator) Inputs() []ops.Operator { @@ -117,13 +115,39 @@ func (a *Aggregator) isDerived() bool { return a.TableID != nil } +func (a *Aggregator) findCol(ctx *plancontext.PlanningContext, expr sqlparser.Expr) (int, error) { + if a.isDerived() { + derivedTBL, err := ctx.SemTable.TableInfoFor(*a.TableID) + if err != nil { + return 0, err + } + expr = semantics.RewriteDerivedTableExpression(expr, derivedTBL) + } + if offset, found := canReuseColumn(ctx, a.Columns, expr, extractExpr); found { + return offset, nil + } + return -1, nil +} + func (a *Aggregator) AddColumn(ctx *plancontext.PlanningContext, expr *sqlparser.AliasedExpr, _, addToGroupBy bool) (ops.Operator, int, error) { - if addToGroupBy { - return nil, 0, vterrors.VT13001("did not expect to add group by here") + offset, err := a.findCol(ctx, expr.Expr) + if err != nil { + return nil, 0, err + } + if offset >= 0 { + return a, offset, nil + } + if a.isDerived() { + derivedTBL, err := ctx.SemTable.TableInfoFor(*a.TableID) + if err != nil { + 
return nil, 0, err + } + expr.Expr = semantics.RewriteDerivedTableExpression(expr.Expr, derivedTBL) } + // Aggregator is little special and cannot work if the input offset are not matched with the aggregation columns. // So, before pushing anything from above the aggregator offset planning needs to be completed. - err := a.planOffsets(ctx) + err = a.planOffsets(ctx) if err != nil { return nil, 0, err } @@ -138,6 +162,10 @@ func (a *Aggregator) AddColumn(ctx *plancontext.PlanningContext, expr *sqlparser } } + if addToGroupBy { + return nil, 0, vterrors.VT13001("did not expect to add group by here") + } + // If weight string function is received from above operator. Then check if we have a group on the expression used. // If it is found, then continue to push it down but with addToGroupBy true so that is the added to group by sql down in the AddColumn. // This also set the weight string column offset so that we would not need to add it later in aggregator operator planOffset. @@ -175,6 +203,10 @@ func (a *Aggregator) AddColumn(ctx *plancontext.PlanningContext, expr *sqlparser } func (a *Aggregator) GetColumns() ([]*sqlparser.AliasedExpr, error) { + if _, isSourceDerived := a.Source.(*Horizon); isSourceDerived { + return a.Columns, nil + } + // we update the incoming columns, so we know about any new columns that have been added // in the optimization phase, other operators could be pushed down resulting in additional columns for aggregator. // Aggregator should be made aware of these to truncate them in final result. 
@@ -191,10 +223,17 @@ func (a *Aggregator) GetColumns() ([]*sqlparser.AliasedExpr, error) { return a.Columns, nil } +func (a *Aggregator) GetSelectExprs() (sqlparser.SelectExprs, error) { + return transformColumnsToSelectExprs(a) +} + func (a *Aggregator) ShortDescription() string { - columnns := slices2.Map(a.Columns, func(from *sqlparser.AliasedExpr) string { + columns := slices2.Map(a.Columns, func(from *sqlparser.AliasedExpr) string { return sqlparser.String(from) }) + if a.Alias != "" { + columns = append([]string{"derived[" + a.Alias + "]"}, columns...) + } org := "" if a.Original { @@ -202,7 +241,7 @@ func (a *Aggregator) ShortDescription() string { } if len(a.Grouping) == 0 { - return fmt.Sprintf("%s%s", org, strings.Join(columnns, ", ")) + return fmt.Sprintf("%s%s", org, strings.Join(columns, ", ")) } var grouping []string @@ -210,7 +249,7 @@ func (a *Aggregator) ShortDescription() string { grouping = append(grouping, sqlparser.String(gb.SimplifiedExpr)) } - return fmt.Sprintf("%s%s group by %s", org, strings.Join(columnns, ", "), strings.Join(grouping, ",")) + return fmt.Sprintf("%s%s group by %s", org, strings.Join(columns, ", "), strings.Join(grouping, ",")) } func (a *Aggregator) GetOrdering() ([]ops.OrderBy, error) { @@ -248,7 +287,7 @@ func (a *Aggregator) planOffsets(ctx *plancontext.PlanningContext) error { } for idx, aggr := range a.Aggregations { - if !aggr.NeedWeightString(ctx) { + if !aggr.NeedsWeightString(ctx) { continue } offset, err := a.internalAddColumn(ctx, aeWrap(weightStringFor(aggr.Func.GetArg())), true) @@ -312,6 +351,9 @@ func (a *Aggregator) addIfAggregationColumn(ctx *plancontext.PlanningContext, co return 0, err } if aggr.ColOffset != offset { + if _, srcIsAlsoAggr := a.Source.(*Aggregator); srcIsAlsoAggr { + return 0, vterrors.VT12001("aggregation on top of aggregation not supported") + } return -1, vterrors.VT13001(fmt.Sprintf("aggregation column on wrong index: want: %d, got: %d", colIdx, offset)) } @@ -327,7 +369,7 @@ func (a 
*Aggregator) addIfGroupingColumn(ctx *plancontext.PlanningContext, colId continue } - newSrc, offset, err := a.Source.AddColumn(ctx, a.Columns[colIdx], false, false) + newSrc, offset, err := a.Source.AddColumn(ctx, a.Columns[colIdx], false, true) if err != nil { return -1, err } @@ -365,7 +407,7 @@ func (a *Aggregator) pushRemainingGroupingColumnsAndWeightStrings(ctx *planconte a.Grouping[idx].WSOffset = offset } for idx, aggr := range a.Aggregations { - if aggr.WSOffset != -1 || !aggr.NeedWeightString(ctx) { + if aggr.WSOffset != -1 || !aggr.NeedsWeightString(ctx) { continue } @@ -395,4 +437,23 @@ func (a *Aggregator) internalAddColumn(ctx *plancontext.PlanningContext, aliased return offset, nil } +// SplitAggregatorBelowRoute returns the aggregator that will live under the Route. +// This is used when we are splitting the aggregation so one part is done +// at the mysql level and one part at the vtgate level +func (a *Aggregator) SplitAggregatorBelowRoute(input []ops.Operator) *Aggregator { + newOp := a.Clone(input).(*Aggregator) + newOp.Pushed = false + newOp.Original = false + newOp.Alias = "" + newOp.TableID = nil + return newOp +} + +func (a *Aggregator) introducesTableID() semantics.TableSet { + if a.TableID == nil { + return semantics.EmptyTableSet() + } + return *a.TableID +} + var _ ops.Operator = (*Aggregator)(nil) diff --git a/go/vt/vtgate/planbuilder/operators/apply_join.go b/go/vt/vtgate/planbuilder/operators/apply_join.go index 816ce47a813..41080e39902 100644 --- a/go/vt/vtgate/planbuilder/operators/apply_join.go +++ b/go/vt/vtgate/planbuilder/operators/apply_join.go @@ -183,6 +183,10 @@ func (a *ApplyJoin) GetColumns() ([]*sqlparser.AliasedExpr, error) { return slices2.Map(a.ColumnsAST, joinColumnToAliasedExpr), nil } +func (a *ApplyJoin) GetSelectExprs() (sqlparser.SelectExprs, error) { + return transformColumnsToSelectExprs(a) +} + func (a *ApplyJoin) GetOrdering() ([]ops.OrderBy, error) { return a.LHS.GetOrdering() } diff --git 
a/go/vt/vtgate/planbuilder/operators/logical.go b/go/vt/vtgate/planbuilder/operators/ast2op.go similarity index 93% rename from go/vt/vtgate/planbuilder/operators/logical.go rename to go/vt/vtgate/planbuilder/operators/ast2op.go index 41396070089..59223f0e631 100644 --- a/go/vt/vtgate/planbuilder/operators/logical.go +++ b/go/vt/vtgate/planbuilder/operators/ast2op.go @@ -30,8 +30,8 @@ import ( "vitess.io/vitess/go/vt/vtgate/vindexes" ) -// createLogicalOperatorFromAST creates an operator tree that represents the input SELECT or UNION query -func createLogicalOperatorFromAST(ctx *plancontext.PlanningContext, selStmt sqlparser.Statement) (op ops.Operator, err error) { +// translateQueryToOp creates an operator tree that represents the input SELECT or UNION query +func translateQueryToOp(ctx *plancontext.PlanningContext, selStmt sqlparser.Statement) (op ops.Operator, err error) { switch node := selStmt.(type) { case *sqlparser.Select: op, err = createOperatorFromSelect(ctx, node) @@ -53,7 +53,6 @@ func createLogicalOperatorFromAST(ctx *plancontext.PlanningContext, selStmt sqlp return op, nil } -// createOperatorFromSelect creates an operator tree that represents the input SELECT query func createOperatorFromSelect(ctx *plancontext.PlanningContext, sel *sqlparser.Select) (ops.Operator, error) { subq, err := createSubqueryFromStatement(ctx, sel) if err != nil { @@ -74,21 +73,20 @@ func createOperatorFromSelect(ctx *plancontext.PlanningContext, sel *sqlparser.S addColumnEquality(ctx, expr) } } - if subq == nil { - return &Horizon{ - Source: op, - Select: sel, - }, nil + + if subq != nil { + subq.Outer = op + op = subq } - subq.Outer = op + return &Horizon{ - Source: subq, - Select: sel, + Source: op, + Query: sel, }, nil } func createOperatorFromUnion(ctx *plancontext.PlanningContext, node *sqlparser.Union) (ops.Operator, error) { - opLHS, err := createLogicalOperatorFromAST(ctx, node.Left) + opLHS, err := translateQueryToOp(ctx, node.Left) if err != nil { return nil, 
err } @@ -97,7 +95,7 @@ func createOperatorFromUnion(ctx *plancontext.PlanningContext, node *sqlparser.U if isRHSUnion { return nil, vterrors.VT12001("nesting of UNIONs on the right-hand side") } - opRHS, err := createLogicalOperatorFromAST(ctx, node.Right) + opRHS, err := translateQueryToOp(ctx, node.Right) if err != nil { return nil, err } @@ -106,7 +104,7 @@ func createOperatorFromUnion(ctx *plancontext.PlanningContext, node *sqlparser.U Distinct: node.Distinct, Sources: []ops.Operator{opLHS, opRHS}, } - return &Horizon{Source: union, Select: node}, nil + return &Horizon{Source: union, Query: node}, nil } func createOperatorFromUpdate(ctx *plancontext.PlanningContext, updStmt *sqlparser.Update) (ops.Operator, error) { @@ -530,10 +528,10 @@ func modifyForAutoinc(ins *sqlparser.Insert, vTable *vindexes.Table) (*Generate, return gen, nil } -func getOperatorFromTableExpr(ctx *plancontext.PlanningContext, tableExpr sqlparser.TableExpr) (ops.Operator, error) { +func getOperatorFromTableExpr(ctx *plancontext.PlanningContext, tableExpr sqlparser.TableExpr, onlyTable bool) (ops.Operator, error) { switch tableExpr := tableExpr.(type) { case *sqlparser.AliasedTableExpr: - return getOperatorFromAliasedTableExpr(ctx, tableExpr) + return getOperatorFromAliasedTableExpr(ctx, tableExpr, onlyTable) case *sqlparser.JoinTableExpr: return getOperatorFromJoinTableExpr(ctx, tableExpr) case *sqlparser.ParenTableExpr: @@ -544,11 +542,11 @@ func getOperatorFromTableExpr(ctx *plancontext.PlanningContext, tableExpr sqlpar } func getOperatorFromJoinTableExpr(ctx *plancontext.PlanningContext, tableExpr *sqlparser.JoinTableExpr) (ops.Operator, error) { - lhs, err := getOperatorFromTableExpr(ctx, tableExpr.LeftExpr) + lhs, err := getOperatorFromTableExpr(ctx, tableExpr.LeftExpr, false) if err != nil { return nil, err } - rhs, err := getOperatorFromTableExpr(ctx, tableExpr.RightExpr) + rhs, err := getOperatorFromTableExpr(ctx, tableExpr.RightExpr, false) if err != nil { return nil, err } @@ 
-563,7 +561,7 @@ func getOperatorFromJoinTableExpr(ctx *plancontext.PlanningContext, tableExpr *s } } -func getOperatorFromAliasedTableExpr(ctx *plancontext.PlanningContext, tableExpr *sqlparser.AliasedTableExpr) (ops.Operator, error) { +func getOperatorFromAliasedTableExpr(ctx *plancontext.PlanningContext, tableExpr *sqlparser.AliasedTableExpr, onlyTable bool) (ops.Operator, error) { tableID := ctx.SemTable.TableSetFor(tableExpr) switch tbl := tableExpr.Expr.(type) { case sqlparser.TableName: @@ -591,7 +589,7 @@ func getOperatorFromAliasedTableExpr(ctx *plancontext.PlanningContext, tableExpr qg.Tables = append(qg.Tables, qt) return qg, nil case *sqlparser.DerivedTable: - inner, err := createLogicalOperatorFromAST(ctx, tbl.Select) + inner, err := translateQueryToOp(ctx, tbl.Select) if err != nil { return nil, err } @@ -599,12 +597,22 @@ func getOperatorFromAliasedTableExpr(ctx *plancontext.PlanningContext, tableExpr inner = horizon.Source } - return &Derived{ - TableId: tableID, + stmt := sqlparser.CloneSelectStatement(tbl.Select) + if onlyTable && stmt.GetLimit() == nil { + stmt.SetOrderBy(nil) + } + qp, err := CreateQPFromSelectStatement(ctx, stmt) + if err != nil { + return nil, err + } + + return &Horizon{ + TableId: &tableID, Alias: tableExpr.As.String(), Source: inner, - Query: tbl.Select, + Query: stmt, ColumnAliases: tableExpr.Columns, + QP: qp, }, nil default: return nil, vterrors.VT13001(fmt.Sprintf("unable to use: %T", tbl)) @@ -614,7 +622,7 @@ func getOperatorFromAliasedTableExpr(ctx *plancontext.PlanningContext, tableExpr func crossJoin(ctx *plancontext.PlanningContext, exprs sqlparser.TableExprs) (ops.Operator, error) { var output ops.Operator for _, tableExpr := range exprs { - op, err := getOperatorFromTableExpr(ctx, tableExpr) + op, err := getOperatorFromTableExpr(ctx, tableExpr, len(exprs) == 1) if err != nil { return nil, err } diff --git a/go/vt/vtgate/planbuilder/operators/delete.go b/go/vt/vtgate/planbuilder/operators/delete.go index 
af1db82ad80..c24ab9f5065 100644 --- a/go/vt/vtgate/planbuilder/operators/delete.go +++ b/go/vt/vtgate/planbuilder/operators/delete.go @@ -37,7 +37,7 @@ type Delete struct { } // Introduces implements the PhysicalOperator interface -func (d *Delete) Introduces() semantics.TableSet { +func (d *Delete) introducesTableID() semantics.TableSet { return d.QTable.ID } diff --git a/go/vt/vtgate/planbuilder/operators/derived.go b/go/vt/vtgate/planbuilder/operators/derived.go deleted file mode 100644 index dc6bbd79952..00000000000 --- a/go/vt/vtgate/planbuilder/operators/derived.go +++ /dev/null @@ -1,262 +0,0 @@ -/* -Copyright 2022 The Vitess Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-*/ - -package operators - -import ( - "io" - - "golang.org/x/exp/slices" - - "vitess.io/vitess/go/vt/sqlparser" - "vitess.io/vitess/go/vt/vterrors" - "vitess.io/vitess/go/vt/vtgate/planbuilder/operators/ops" - "vitess.io/vitess/go/vt/vtgate/planbuilder/plancontext" - "vitess.io/vitess/go/vt/vtgate/semantics" -) - -type Derived struct { - Source ops.Operator - TableId semantics.TableSet - - // QP contains the QueryProjection for this op - QP *QueryProjection - - Query sqlparser.SelectStatement - Alias string - ColumnAliases sqlparser.Columns - - // Columns needed to feed other plans - Columns []*sqlparser.ColName - ColumnsOffset []int -} - -// Clone implements the Operator interface -func (d *Derived) Clone(inputs []ops.Operator) ops.Operator { - return &Derived{ - Source: inputs[0], - Query: d.Query, - Alias: d.Alias, - ColumnAliases: sqlparser.CloneColumns(d.ColumnAliases), - Columns: slices.Clone(d.Columns), - ColumnsOffset: slices.Clone(d.ColumnsOffset), - TableId: d.TableId, - } -} - -// findOutputColumn returns the index on which the given name is found in the slice of -// *sqlparser.SelectExprs of the derivedTree. The *sqlparser.SelectExpr must be of type -// *sqlparser.AliasedExpr and match the given name. -// If name is not present but the query's select expressions contain a *sqlparser.StarExpr -// the function will return no error and an index equal to -1. -// If name is not present and the query does not have a *sqlparser.StarExpr, the function -// will return an unknown column error. 
-func (d *Derived) findOutputColumn(name *sqlparser.ColName) (int, error) { - hasStar := false - for j, exp := range sqlparser.GetFirstSelect(d.Query).SelectExprs { - switch exp := exp.(type) { - case *sqlparser.AliasedExpr: - if !exp.As.IsEmpty() && exp.As.Equal(name.Name) { - return j, nil - } - if exp.As.IsEmpty() { - col, ok := exp.Expr.(*sqlparser.ColName) - if !ok { - return 0, vterrors.VT12001("complex expression needs column alias: %s", sqlparser.String(exp)) - } - if name.Name.Equal(col.Name) { - return j, nil - } - } - case *sqlparser.StarExpr: - hasStar = true - } - } - - // we have found a star but no matching *sqlparser.AliasedExpr, thus we return -1 with no error. - if hasStar { - return -1, nil - } - return 0, vterrors.VT03014(name.Name.String(), "field list") -} - -// IsMergeable is not a great name for this function. Suggestions for a better one are welcome! -// This function will return false if the derived table inside it has to run on the vtgate side, and so can't be merged with subqueries -// This logic can also be used to check if this is a derived table that can be had on the left hand side of a vtgate join. 
-// Since vtgate joins are always nested loop joins, we can't execute them on the RHS -// if they do some things, like LIMIT or GROUP BY on wrong columns -func (d *Derived) IsMergeable(ctx *plancontext.PlanningContext) bool { - return isMergeable(ctx, d.Query, d) -} - -// Inputs implements the Operator interface -func (d *Derived) Inputs() []ops.Operator { - return []ops.Operator{d.Source} -} - -// SetInputs implements the Operator interface -func (d *Derived) SetInputs(ops []ops.Operator) { - d.Source = ops[0] -} - -func (d *Derived) AddPredicate(ctx *plancontext.PlanningContext, expr sqlparser.Expr) (ops.Operator, error) { - if _, isUNion := d.Source.(*Union); isUNion { - // If we have a derived table on top of a UNION, we can let the UNION do the expression rewriting - var err error - d.Source, err = d.Source.AddPredicate(ctx, expr) - return d, err - } - tableInfo, err := ctx.SemTable.TableInfoForExpr(expr) - if err != nil { - if err == semantics.ErrNotSingleTable { - return &Filter{ - Source: d, - Predicates: []sqlparser.Expr{expr}, - }, nil - } - return nil, err - } - - newExpr := semantics.RewriteDerivedTableExpression(expr, tableInfo) - if !canBePushedDownIntoDerived(newExpr) { - // if we have an aggregation, we don't want to push it inside - return &Filter{Source: d, Predicates: []sqlparser.Expr{expr}}, nil - } - d.Source, err = d.Source.AddPredicate(ctx, newExpr) - if err != nil { - return nil, err - } - return d, nil -} - -func canBePushedDownIntoDerived(expr sqlparser.Expr) (canBePushed bool) { - canBePushed = true - _ = sqlparser.Walk(func(node sqlparser.SQLNode) (kontinue bool, err error) { - switch node.(type) { - case *sqlparser.Max, *sqlparser.Min: - // empty by default - case sqlparser.AggrFunc: - canBePushed = false - return false, io.EOF - } - return true, nil - }, expr) - return -} - -func (d *Derived) AddColumn(ctx *plancontext.PlanningContext, expr *sqlparser.AliasedExpr, _, addToGroupBy bool) (ops.Operator, int, error) { - col, ok := 
expr.Expr.(*sqlparser.ColName) - if !ok { - return nil, 0, vterrors.VT13001("cannot push non-colname expression to a derived table") - } - - identity := func(c *sqlparser.ColName) sqlparser.Expr { return c } - if offset, found := canReuseColumn(ctx, d.Columns, col, identity); found { - return d, offset, nil - } - - i, err := d.findOutputColumn(col) - if err != nil { - return nil, 0, err - } - var pos int - d.ColumnsOffset, pos = addToIntSlice(d.ColumnsOffset, i) - - d.Columns = append(d.Columns, col) - // add it to the source if we were not already passing it through - if i <= -1 { - newSrc, _, err := d.Source.AddColumn(ctx, aeWrap(sqlparser.NewColName(col.Name.String())), true, addToGroupBy) - if err != nil { - return nil, 0, err - } - d.Source = newSrc - } - return d, pos, nil -} - -// canReuseColumn is generic, so it can be used with slices of different types. -// We don't care about the actual type, as long as we know it's a sqlparser.Expr -func canReuseColumn[T any]( - ctx *plancontext.PlanningContext, - columns []T, - col sqlparser.Expr, - f func(T) sqlparser.Expr, -) (offset int, found bool) { - for offset, column := range columns { - if ctx.SemTable.EqualsExprWithDeps(col, f(column)) { - return offset, true - } - } - - return -} - -func (d *Derived) GetColumns() (exprs []*sqlparser.AliasedExpr, err error) { - for _, expr := range sqlparser.GetFirstSelect(d.Query).SelectExprs { - ae, ok := expr.(*sqlparser.AliasedExpr) - if !ok { - return nil, errHorizonNotPlanned() - } - exprs = append(exprs, ae) - } - return -} - -func (d *Derived) GetOrdering() ([]ops.OrderBy, error) { - if d.QP == nil { - return nil, vterrors.VT13001("QP should already be here") - } - return d.QP.OrderExprs, nil -} - -func addToIntSlice(columnOffset []int, valToAdd int) ([]int, int) { - for idx, val := range columnOffset { - if val == valToAdd { - return columnOffset, idx - } - } - columnOffset = append(columnOffset, valToAdd) - return columnOffset, len(columnOffset) - 1 -} - -// TODO: 
REMOVE -func (d *Derived) selectStatement() sqlparser.SelectStatement { - return d.Query -} - -func (d *Derived) src() ops.Operator { - return d.Source -} - -func (d *Derived) getQP(ctx *plancontext.PlanningContext) (*QueryProjection, error) { - if d.QP != nil { - return d.QP, nil - } - qp, err := CreateQPFromSelect(ctx, d.Query.(*sqlparser.Select)) - if err != nil { - return nil, err - } - d.QP = qp - return d.QP, nil -} - -func (d *Derived) setQP(qp *QueryProjection) { - d.QP = qp -} - -func (d *Derived) ShortDescription() string { - return d.Alias -} diff --git a/go/vt/vtgate/planbuilder/operators/distinct.go b/go/vt/vtgate/planbuilder/operators/distinct.go index c120bef8230..d5905f87adb 100644 --- a/go/vt/vtgate/planbuilder/operators/distinct.go +++ b/go/vt/vtgate/planbuilder/operators/distinct.go @@ -112,6 +112,10 @@ func (d *Distinct) GetColumns() ([]*sqlparser.AliasedExpr, error) { return d.Source.GetColumns() } +func (d *Distinct) GetSelectExprs() (sqlparser.SelectExprs, error) { + return d.Source.GetSelectExprs() +} + func (d *Distinct) ShortDescription() string { return "" } diff --git a/go/vt/vtgate/planbuilder/operators/filter.go b/go/vt/vtgate/planbuilder/operators/filter.go index d8dbd8be9de..1c0751f47ad 100644 --- a/go/vt/vtgate/planbuilder/operators/filter.go +++ b/go/vt/vtgate/planbuilder/operators/filter.go @@ -100,6 +100,10 @@ func (f *Filter) GetColumns() ([]*sqlparser.AliasedExpr, error) { return f.Source.GetColumns() } +func (f *Filter) GetSelectExprs() (sqlparser.SelectExprs, error) { + return f.Source.GetSelectExprs() +} + func (f *Filter) GetOrdering() ([]ops.OrderBy, error) { return f.Source.GetOrdering() } diff --git a/go/vt/vtgate/planbuilder/operators/helpers.go b/go/vt/vtgate/planbuilder/operators/helpers.go index 1c472acf413..21be634d7d8 100644 --- a/go/vt/vtgate/planbuilder/operators/helpers.go +++ b/go/vt/vtgate/planbuilder/operators/helpers.go @@ -67,15 +67,15 @@ func Clone(op ops.Operator) ops.Operator { return op.Clone(clones) } 
-// TableIDIntroducer is used to signal that this operator introduces data from a new source -type TableIDIntroducer interface { - Introduces() semantics.TableSet +// tableIDIntroducer is used to signal that this operator introduces data from a new source +type tableIDIntroducer interface { + introducesTableID() semantics.TableSet } func TableID(op ops.Operator) (result semantics.TableSet) { _ = rewrite.Visit(op, func(this ops.Operator) error { - if tbl, ok := this.(TableIDIntroducer); ok { - result = result.Merge(tbl.Introduces()) + if tbl, ok := this.(tableIDIntroducer); ok { + result = result.Merge(tbl.introducesTableID()) } return nil }) diff --git a/go/vt/vtgate/planbuilder/operators/horizon.go b/go/vt/vtgate/planbuilder/operators/horizon.go index d4c30313114..d95e3244cff 100644 --- a/go/vt/vtgate/planbuilder/operators/horizon.go +++ b/go/vt/vtgate/planbuilder/operators/horizon.go @@ -17,10 +17,13 @@ limitations under the License. package operators import ( + "golang.org/x/exp/slices" + "vitess.io/vitess/go/vt/sqlparser" "vitess.io/vitess/go/vt/vterrors" "vitess.io/vitess/go/vt/vtgate/planbuilder/operators/ops" "vitess.io/vitess/go/vt/vtgate/planbuilder/plancontext" + "vitess.io/vitess/go/vt/vtgate/semantics" ) // Horizon is an operator that allows us to postpone planning things like SELECT/GROUP BY/ORDER BY/LIMIT until later. @@ -32,58 +35,180 @@ import ( // and some that have to be evaluated at the vtgate level. 
type Horizon struct { Source ops.Operator - Select sqlparser.SelectStatement - QP *QueryProjection + + // If this is a derived table, the two following fields will contain the tableID and name of it + TableId *semantics.TableSet + Alias string + + // QP contains the QueryProjection for this op + QP *QueryProjection + + Query sqlparser.SelectStatement + ColumnAliases sqlparser.Columns + + // Columns needed to feed other plans + Columns []*sqlparser.ColName + ColumnsOffset []int } -func (h *Horizon) AddColumn(*plancontext.PlanningContext, *sqlparser.AliasedExpr, bool, bool) (ops.Operator, int, error) { - return nil, 0, vterrors.VT13001("the Horizon operator cannot accept new columns") +// Clone implements the Operator interface +func (h *Horizon) Clone(inputs []ops.Operator) ops.Operator { + return &Horizon{ + Source: inputs[0], + Query: h.Query, + Alias: h.Alias, + ColumnAliases: sqlparser.CloneColumns(h.ColumnAliases), + Columns: slices.Clone(h.Columns), + ColumnsOffset: slices.Clone(h.ColumnsOffset), + TableId: h.TableId, + QP: h.QP, + } } -func (h *Horizon) GetColumns() (exprs []*sqlparser.AliasedExpr, err error) { - for _, expr := range sqlparser.GetFirstSelect(h.Select).SelectExprs { - ae, ok := expr.(*sqlparser.AliasedExpr) - if !ok { - return nil, errHorizonNotPlanned() +// findOutputColumn returns the index on which the given name is found in the slice of +// *sqlparser.SelectExprs of the derivedTree. The *sqlparser.SelectExpr must be of type +// *sqlparser.AliasedExpr and match the given name. +// If name is not present but the query's select expressions contain a *sqlparser.StarExpr +// the function will return no error and an index equal to -1. +// If name is not present and the query does not have a *sqlparser.StarExpr, the function +// will return an unknown column error. 
+func (h *Horizon) findOutputColumn(name *sqlparser.ColName) (int, error) { + hasStar := false + for j, exp := range sqlparser.GetFirstSelect(h.Query).SelectExprs { + switch exp := exp.(type) { + case *sqlparser.AliasedExpr: + if !exp.As.IsEmpty() && exp.As.Equal(name.Name) { + return j, nil + } + if exp.As.IsEmpty() { + col, ok := exp.Expr.(*sqlparser.ColName) + if !ok { + return 0, vterrors.VT12001("complex expression needs column alias: %s", sqlparser.String(exp)) + } + if name.Name.Equal(col.Name) { + return j, nil + } + } + case *sqlparser.StarExpr: + hasStar = true } - exprs = append(exprs, ae) } - return + + // we have found a star but no matching *sqlparser.AliasedExpr, thus we return -1 with no error. + if hasStar { + return -1, nil + } + return 0, vterrors.VT03014(name.Name.String(), "field list") } -var _ ops.Operator = (*Horizon)(nil) +// IsMergeable is not a great name for this function. Suggestions for a better one are welcome! +// This function will return false if the derived table inside it has to run on the vtgate side, and so can't be merged with subqueries +// This logic can also be used to check if this is a derived table that can be had on the left hand side of a vtgate join. 
+// Since vtgate joins are always nested loop joins, we can't execute them on the RHS +// if they do some things, like LIMIT or GROUP BY on wrong columns +func (h *Horizon) IsMergeable(ctx *plancontext.PlanningContext) bool { + return isMergeable(ctx, h.Query, h) +} + +// Inputs implements the Operator interface +func (h *Horizon) Inputs() []ops.Operator { + return []ops.Operator{h.Source} +} + +// SetInputs implements the Operator interface +func (h *Horizon) SetInputs(ops []ops.Operator) { + h.Source = ops[0] +} func (h *Horizon) AddPredicate(ctx *plancontext.PlanningContext, expr sqlparser.Expr) (ops.Operator, error) { - newSrc, err := h.Source.AddPredicate(ctx, expr) + if _, isUNion := h.Source.(*Union); isUNion { + // If we have a derived table on top of a UNION, we can let the UNION do the expression rewriting + var err error + h.Source, err = h.Source.AddPredicate(ctx, expr) + return h, err + } + tableInfo, err := ctx.SemTable.TableInfoForExpr(expr) + if err != nil { + if err == semantics.ErrNotSingleTable { + return &Filter{ + Source: h, + Predicates: []sqlparser.Expr{expr}, + }, nil + } + return nil, err + } + + newExpr := semantics.RewriteDerivedTableExpression(expr, tableInfo) + if sqlparser.ContainsAggregation(newExpr) { + return &Filter{Source: h, Predicates: []sqlparser.Expr{expr}}, nil + } + h.Source, err = h.Source.AddPredicate(ctx, newExpr) if err != nil { return nil, err } - h.Source = newSrc return h, nil } -func (h *Horizon) Clone(inputs []ops.Operator) ops.Operator { - return &Horizon{ - Source: inputs[0], - Select: h.Select, +func (h *Horizon) AddColumn(ctx *plancontext.PlanningContext, expr *sqlparser.AliasedExpr, _, addToGroupBy bool) (ops.Operator, int, error) { + col, ok := expr.Expr.(*sqlparser.ColName) + if !ok { + return nil, 0, vterrors.VT13001("cannot push non-colname expression to a derived table") } -} -func (h *Horizon) Inputs() []ops.Operator { - return []ops.Operator{h.Source} + identity := func(c *sqlparser.ColName) 
sqlparser.Expr { return c } + if offset, found := canReuseColumn(ctx, h.Columns, col, identity); found { + return h, offset, nil + } + + i, err := h.findOutputColumn(col) + if err != nil { + return nil, 0, err + } + var pos int + h.ColumnsOffset, pos = addToIntSlice(h.ColumnsOffset, i) + + h.Columns = append(h.Columns, col) + // add it to the source if we were not already passing it through + if i <= -1 { + newSrc, _, err := h.Source.AddColumn(ctx, aeWrap(sqlparser.NewColName(col.Name.String())), true, addToGroupBy) + if err != nil { + return nil, 0, err + } + h.Source = newSrc + } + return h, pos, nil } -// SetInputs implements the Operator interface -func (h *Horizon) SetInputs(ops []ops.Operator) { - h.Source = ops[0] +// canReuseColumn is generic, so it can be used with slices of different types. +// We don't care about the actual type, as long as we know it's a sqlparser.Expr +func canReuseColumn[T any]( + ctx *plancontext.PlanningContext, + columns []T, + col sqlparser.Expr, + f func(T) sqlparser.Expr, +) (offset int, found bool) { + for offset, column := range columns { + if ctx.SemTable.EqualsExprWithDeps(col, f(column)) { + return offset, true + } + } + + return } -func (h *Horizon) selectStatement() sqlparser.SelectStatement { - return h.Select +func (h *Horizon) GetColumns() (exprs []*sqlparser.AliasedExpr, err error) { + for _, expr := range sqlparser.GetFirstSelect(h.Query).SelectExprs { + ae, ok := expr.(*sqlparser.AliasedExpr) + if !ok { + return nil, vterrors.VT09015() + } + exprs = append(exprs, ae) + } + return } -func (h *Horizon) src() ops.Operator { - return h.Source +func (h *Horizon) GetSelectExprs() (sqlparser.SelectExprs, error) { + return sqlparser.GetFirstSelect(h.Query).SelectExprs, nil } func (h *Horizon) GetOrdering() ([]ops.OrderBy, error) { @@ -93,11 +218,30 @@ func (h *Horizon) GetOrdering() ([]ops.OrderBy, error) { return h.QP.OrderExprs, nil } +func addToIntSlice(columnOffset []int, valToAdd int) ([]int, int) { + for idx, val := 
range columnOffset { + if val == valToAdd { + return columnOffset, idx + } + } + columnOffset = append(columnOffset, valToAdd) + return columnOffset, len(columnOffset) - 1 +} + +// TODO: REMOVE +func (h *Horizon) selectStatement() sqlparser.SelectStatement { + return h.Query +} + +func (h *Horizon) src() ops.Operator { + return h.Source +} + func (h *Horizon) getQP(ctx *plancontext.PlanningContext) (*QueryProjection, error) { if h.QP != nil { return h.QP, nil } - qp, err := CreateQPFromSelect(ctx, h.Select.(*sqlparser.Select)) + qp, err := CreateQPFromSelectStatement(ctx, h.Query) if err != nil { return nil, err } @@ -110,5 +254,17 @@ func (h *Horizon) setQP(qp *QueryProjection) { } func (h *Horizon) ShortDescription() string { - return "" + return h.Alias +} + +func (h *Horizon) introducesTableID() semantics.TableSet { + if h.TableId == nil { + return semantics.EmptyTableSet() + } + + return *h.TableId +} + +func (h *Horizon) IsDerived() bool { + return h.TableId != nil } diff --git a/go/vt/vtgate/planbuilder/operators/horizon_expanding.go b/go/vt/vtgate/planbuilder/operators/horizon_expanding.go index 2342d6edb27..f12015e7d7c 100644 --- a/go/vt/vtgate/planbuilder/operators/horizon_expanding.go +++ b/go/vt/vtgate/planbuilder/operators/horizon_expanding.go @@ -26,7 +26,7 @@ import ( "vitess.io/vitess/go/vt/vtgate/planbuilder/plancontext" ) -func expandHorizon(ctx *plancontext.PlanningContext, horizon horizonLike) (ops.Operator, *rewrite.ApplyResult, error) { +func expandHorizon(ctx *plancontext.PlanningContext, horizon *Horizon) (ops.Operator, *rewrite.ApplyResult, error) { sel, isSel := horizon.selectStatement().(*sqlparser.Select) if !isSel { return nil, nil, errHorizonNotPlanned() @@ -74,19 +74,7 @@ func expandHorizon(ctx *plancontext.PlanningContext, horizon horizonLike) (ops.O return op, rewrite.NewTree("expand horizon into smaller components", op), nil } -func checkInvalid(aggregations []Aggr, horizon horizonLike) error { - for _, aggregation := range 
aggregations { - if aggregation.Distinct { - return errHorizonNotPlanned() - } - } - if _, isDerived := horizon.(*Derived); isDerived { - return errHorizonNotPlanned() - } - return nil -} - -func createProjectionFromSelect(ctx *plancontext.PlanningContext, horizon horizonLike) (out ops.Operator, err error) { +func createProjectionFromSelect(ctx *plancontext.PlanningContext, horizon *Horizon) (out ops.Operator, err error) { qp, err := horizon.getQP(ctx) if err != nil { return nil, err @@ -97,42 +85,26 @@ func createProjectionFromSelect(ctx *plancontext.PlanningContext, horizon horizo if err != nil { return nil, err } - if derived, isDerived := horizon.(*Derived); isDerived { - id := derived.TableId - projX.TableID = &id - projX.Alias = derived.Alias - } + projX.TableID = horizon.TableId + projX.Alias = horizon.Alias out = projX return out, nil } - err = checkAggregationSupported(horizon) - if err != nil { - return nil, err - } - aggregations, complexAggr, err := qp.AggregationExpressions(ctx, true) if err != nil { return nil, err } - if err := checkInvalid(aggregations, horizon); err != nil { - return nil, err - } - a := &Aggregator{ Source: horizon.src(), Original: true, QP: qp, Grouping: qp.GetGrouping(), Aggregations: aggregations, - } - - if derived, isDerived := horizon.(*Derived); isDerived { - id := derived.TableId - a.TableID = &id - a.Alias = derived.Alias + TableID: horizon.TableId, + Alias: horizon.Alias, } if complexAggr { diff --git a/go/vt/vtgate/planbuilder/operators/horizon_planning.go b/go/vt/vtgate/planbuilder/operators/horizon_planning.go index 3fbff7a586e..dd6d5bc8898 100644 --- a/go/vt/vtgate/planbuilder/operators/horizon_planning.go +++ b/go/vt/vtgate/planbuilder/operators/horizon_planning.go @@ -33,14 +33,6 @@ type ( cols []ProjExpr names []*sqlparser.AliasedExpr } - - // horizonLike should be removed. 
we should use Horizon for both these cases - horizonLike interface { - ops.Operator - selectStatement() sqlparser.SelectStatement - src() ops.Operator - getQP(ctx *plancontext.PlanningContext) (*QueryProjection, error) - } ) func errHorizonNotPlanned() error { @@ -120,7 +112,7 @@ func planHorizons(ctx *plancontext.PlanningContext, root ops.Operator) (op ops.O func optimizeHorizonPlanning(ctx *plancontext.PlanningContext, root ops.Operator) (ops.Operator, error) { visitor := func(in ops.Operator, _ semantics.TableSet, isRoot bool) (ops.Operator, *rewrite.ApplyResult, error) { switch in := in.(type) { - case horizonLike: + case *Horizon: return pushOrExpandHorizon(ctx, in) case *Projection: return tryPushingDownProjection(ctx, in) @@ -151,12 +143,11 @@ func optimizeHorizonPlanning(ctx *plancontext.PlanningContext, root ops.Operator return newOp, nil } -func pushOrExpandHorizon(ctx *plancontext.PlanningContext, in horizonLike) (ops.Operator, *rewrite.ApplyResult, error) { - if derived, ok := in.(*Derived); ok { - if len(derived.ColumnAliases) > 0 { - return nil, nil, errHorizonNotPlanned() - } +func pushOrExpandHorizon(ctx *plancontext.PlanningContext, in *Horizon) (ops.Operator, *rewrite.ApplyResult, error) { + if len(in.ColumnAliases) > 0 { + return nil, nil, errHorizonNotPlanned() } + rb, isRoute := in.src().(*Route) if isRoute && rb.IsSingleShard() { return rewrite.Swap(in, rb, "push horizon into route") @@ -621,14 +612,14 @@ func tryPushingDownDistinct(in *Distinct) (ops.Operator, *rewrite.ApplyResult, e // makeSureOutputIsCorrect uses the original Horizon to make sure that the output columns line up with what the user asked for func makeSureOutputIsCorrect(ctx *plancontext.PlanningContext, oldHorizon ops.Operator, output ops.Operator) (ops.Operator, error) { - cols, err := output.GetColumns() + cols, err := output.GetSelectExprs() if err != nil { return nil, err } horizon := oldHorizon.(*Horizon) - sel := sqlparser.GetFirstSelect(horizon.Select) + sel := 
sqlparser.GetFirstSelect(horizon.Query) if len(sel.SelectExprs) == len(cols) { return output, nil diff --git a/go/vt/vtgate/planbuilder/operators/limit.go b/go/vt/vtgate/planbuilder/operators/limit.go index 35108965a52..24f6af9ec7c 100644 --- a/go/vt/vtgate/planbuilder/operators/limit.go +++ b/go/vt/vtgate/planbuilder/operators/limit.go @@ -69,6 +69,10 @@ func (l *Limit) GetColumns() ([]*sqlparser.AliasedExpr, error) { return l.Source.GetColumns() } +func (l *Limit) GetSelectExprs() (sqlparser.SelectExprs, error) { + return l.Source.GetSelectExprs() +} + func (l *Limit) GetOrdering() ([]ops.OrderBy, error) { return l.Source.GetOrdering() } diff --git a/go/vt/vtgate/planbuilder/operators/offset_planning.go b/go/vt/vtgate/planbuilder/operators/offset_planning.go index 31c36d609fb..b94252d1af6 100644 --- a/go/vt/vtgate/planbuilder/operators/offset_planning.go +++ b/go/vt/vtgate/planbuilder/operators/offset_planning.go @@ -38,7 +38,7 @@ func planOffsets(ctx *plancontext.PlanningContext, root ops.Operator) (ops.Opera visitor := func(in ops.Operator, _ semantics.TableSet, _ bool) (ops.Operator, *rewrite.ApplyResult, error) { var err error switch op := in.(type) { - case *Derived, *Horizon: + case *Horizon: return nil, nil, vterrors.VT13001(fmt.Sprintf("should not see %T here", in)) case offsettable: err = op.planOffsets(ctx) @@ -110,7 +110,12 @@ func useOffsets(ctx *plancontext.PlanningContext, expr sqlparser.Expr, op ops.Op } getColumns := func() []*sqlparser.AliasedExpr { return columns } - visitor := getVisitor(ctx, getColumns, found, notFound) + findCol := func(ctx *plancontext.PlanningContext, e sqlparser.Expr) (int, error) { + return slices.IndexFunc(getColumns(), func(expr *sqlparser.AliasedExpr) bool { + return ctx.SemTable.EqualsExprWithDeps(expr.Expr, e) + }), nil + } + visitor := getVisitor(ctx, findCol, found, notFound) // The cursor replace is not available while walking `down`, so `up` is used to do the replacement. 
up := func(cursor *sqlparser.CopyOnWriteCursor) { @@ -137,10 +142,6 @@ func addColumnsToInput(ctx *plancontext.PlanningContext, root ops.Operator) (ops return in, rewrite.SameTree, nil } - columns, err := filter.GetColumns() - if err != nil { - return nil, nil, err - } proj, areOnTopOfProj := filter.Source.(selectExpressions) if !areOnTopOfProj { // not much we can do here @@ -152,19 +153,12 @@ func addColumnsToInput(ctx *plancontext.PlanningContext, root ops.Operator) (ops _, addToGroupBy := e.(*sqlparser.ColName) proj.addColumnWithoutPushing(aeWrap(e), addToGroupBy) addedColumns = true - columns, err = proj.GetColumns() return nil } - getColumns := func() []*sqlparser.AliasedExpr { - return columns - } - visitor := getVisitor(ctx, getColumns, found, notFound) + visitor := getVisitor(ctx, proj.findCol, found, notFound) for _, expr := range filter.Predicates { - sqlparser.CopyOnRewrite(expr, visitor, nil, ctx.SemTable.CopyDependenciesOnSQLNodes) - if err != nil { - return nil, nil, err - } + _ = sqlparser.CopyOnRewrite(expr, visitor, nil, ctx.SemTable.CopyDependenciesOnSQLNodes) } if addedColumns { return in, rewrite.NewTree("added columns because filter needs it", in), nil @@ -178,7 +172,7 @@ func addColumnsToInput(ctx *plancontext.PlanningContext, root ops.Operator) (ops func getVisitor( ctx *plancontext.PlanningContext, - getColumns func() []*sqlparser.AliasedExpr, + findCol func(ctx *plancontext.PlanningContext, expr sqlparser.Expr) (int, error), found func(sqlparser.Expr, int), notFound func(sqlparser.Expr) error, ) func(node, parent sqlparser.SQLNode) bool { @@ -191,10 +185,11 @@ func getVisitor( if !ok { return true } - offset := slices.IndexFunc(getColumns(), func(expr *sqlparser.AliasedExpr) bool { - return ctx.SemTable.EqualsExprWithDeps(expr.Expr, e) - }) - + var offset int + offset, err = findCol(ctx, e) + if err != nil { + return false + } if offset >= 0 { found(e, offset) return false diff --git a/go/vt/vtgate/planbuilder/operators/operator.go 
b/go/vt/vtgate/planbuilder/operators/operator.go index d2ce5cb77d0..8fa73e882fe 100644 --- a/go/vt/vtgate/planbuilder/operators/operator.go +++ b/go/vt/vtgate/planbuilder/operators/operator.go @@ -34,6 +34,7 @@ The operators go through a few phases while planning: package operators import ( + "vitess.io/vitess/go/slices2" "vitess.io/vitess/go/vt/sqlparser" "vitess.io/vitess/go/vt/vterrors" "vitess.io/vitess/go/vt/vtgate/planbuilder/operators/ops" @@ -53,11 +54,15 @@ type ( // PlanQuery creates a query plan for a given SQL statement func PlanQuery(ctx *plancontext.PlanningContext, stmt sqlparser.Statement) (ops.Operator, error) { - op, err := createLogicalOperatorFromAST(ctx, stmt) + op, err := translateQueryToOp(ctx, stmt) if err != nil { return nil, err } + if op, err = compact(ctx, op); err != nil { + return nil, err + } + if err = checkValid(op); err != nil { return nil, err } @@ -97,11 +102,15 @@ func (noInputs) SetInputs(ops []ops.Operator) { // AddColumn implements the Operator interface func (noColumns) AddColumn(*plancontext.PlanningContext, *sqlparser.AliasedExpr, bool, bool) (ops.Operator, int, error) { - return nil, 0, vterrors.VT13001("the noColumns operator cannot accept columns") + return nil, 0, vterrors.VT13001("noColumns operators have no column") } func (noColumns) GetColumns() ([]*sqlparser.AliasedExpr, error) { - return nil, vterrors.VT13001("the noColumns operator cannot accept columns") + return nil, vterrors.VT13001("noColumns operators have no column") +} + +func (noColumns) GetSelectExprs() (sqlparser.SelectExprs, error) { + return nil, vterrors.VT13001("noColumns operators have no column") } // AddPredicate implements the Operator interface @@ -135,3 +144,14 @@ func tryTruncateColumnsAt(op ops.Operator, truncateAt int) bool { return tryTruncateColumnsAt(inputs[0], truncateAt) } + +func transformColumnsToSelectExprs(op ops.Operator) (sqlparser.SelectExprs, error) { + columns, err := op.GetColumns() + if err != nil { + return nil, err + } + 
selExprs := slices2.Map(columns, func(from *sqlparser.AliasedExpr) sqlparser.SelectExpr { + return from + }) + return selExprs, nil +} diff --git a/go/vt/vtgate/planbuilder/operators/ops/op.go b/go/vt/vtgate/planbuilder/operators/ops/op.go index 57e27879861..a0a350633a2 100644 --- a/go/vt/vtgate/planbuilder/operators/ops/op.go +++ b/go/vt/vtgate/planbuilder/operators/ops/op.go @@ -49,6 +49,8 @@ type ( GetColumns() ([]*sqlparser.AliasedExpr, error) + GetSelectExprs() (sqlparser.SelectExprs, error) + ShortDescription() string GetOrdering() ([]OrderBy, error) diff --git a/go/vt/vtgate/planbuilder/operators/ordering.go b/go/vt/vtgate/planbuilder/operators/ordering.go index 360bf87cb23..2ff0901c97b 100644 --- a/go/vt/vtgate/planbuilder/operators/ordering.go +++ b/go/vt/vtgate/planbuilder/operators/ordering.go @@ -76,6 +76,10 @@ func (o *Ordering) GetColumns() ([]*sqlparser.AliasedExpr, error) { return o.Source.GetColumns() } +func (o *Ordering) GetSelectExprs() (sqlparser.SelectExprs, error) { + return o.Source.GetSelectExprs() +} + func (o *Ordering) GetOrdering() ([]ops.OrderBy, error) { return o.Order, nil } diff --git a/go/vt/vtgate/planbuilder/operators/phases.go b/go/vt/vtgate/planbuilder/operators/phases.go index 802d1876fdb..4a8c1ca0067 100644 --- a/go/vt/vtgate/planbuilder/operators/phases.go +++ b/go/vt/vtgate/planbuilder/operators/phases.go @@ -63,38 +63,53 @@ func addOrderBysForAggregations(ctx *plancontext.PlanningContext, root ops.Opera return in, rewrite.SameTree, nil } - requireOrdering, err := needsOrdering(aggrOp, ctx) + requireOrdering, err := needsOrdering(ctx, aggrOp) if err != nil { return nil, nil, err } if !requireOrdering { return in, rewrite.SameTree, nil } + orderBys := slices2.Map(aggrOp.Grouping, func(from GroupBy) ops.OrderBy { + return from.AsOrderBy() + }) + if aggrOp.DistinctExpr != nil { + orderBys = append(orderBys, ops.OrderBy{ + Inner: &sqlparser.Order{ + Expr: aggrOp.DistinctExpr, + }, + SimplifiedExpr: aggrOp.DistinctExpr, + }) + } 
aggrOp.Source = &Ordering{ Source: aggrOp.Source, - Order: slices2.Map(aggrOp.Grouping, func(from GroupBy) ops.OrderBy { - return from.AsOrderBy() - }), + Order: orderBys, } return in, rewrite.NewTree("added ordering before aggregation", in), nil } - return rewrite.TopDown(root, TableID, visitor, stopAtRoute) + return rewrite.BottomUp(root, TableID, visitor, stopAtRoute) } -func needsOrdering(in *Aggregator, ctx *plancontext.PlanningContext) (bool, error) { - if len(in.Grouping) == 0 { +func needsOrdering(ctx *plancontext.PlanningContext, in *Aggregator) (bool, error) { + requiredOrder := slices2.Map(in.Grouping, func(from GroupBy) sqlparser.Expr { + return from.SimplifiedExpr + }) + if in.DistinctExpr != nil { + requiredOrder = append(requiredOrder, in.DistinctExpr) + } + if len(requiredOrder) == 0 { return false, nil } srcOrdering, err := in.Source.GetOrdering() if err != nil { return false, err } - if len(srcOrdering) < len(in.Grouping) { + if len(srcOrdering) < len(requiredOrder) { return true, nil } - for idx, gb := range in.Grouping { - if !ctx.SemTable.EqualsExprWithDeps(srcOrdering[idx].SimplifiedExpr, gb.SimplifiedExpr) { + for idx, gb := range requiredOrder { + if !ctx.SemTable.EqualsExprWithDeps(srcOrdering[idx].SimplifiedExpr, gb) { return true, nil } } diff --git a/go/vt/vtgate/planbuilder/operators/projection.go b/go/vt/vtgate/planbuilder/operators/projection.go index a969fdc1129..ff5a43a8c35 100644 --- a/go/vt/vtgate/planbuilder/operators/projection.go +++ b/go/vt/vtgate/planbuilder/operators/projection.go @@ -119,10 +119,29 @@ func (p *Projection) isDerived() bool { return p.TableID != nil } +func (p *Projection) findCol(ctx *plancontext.PlanningContext, expr sqlparser.Expr) (int, error) { + if p.isDerived() { + derivedTBL, err := ctx.SemTable.TableInfoFor(*p.TableID) + if err != nil { + return 0, err + } + expr = semantics.RewriteDerivedTableExpression(expr, derivedTBL) + } + if offset, found := canReuseColumn(ctx, p.Columns, expr, extractExpr); 
found { + return offset, nil + } + return -1, nil +} + func (p *Projection) AddColumn(ctx *plancontext.PlanningContext, expr *sqlparser.AliasedExpr, _, addToGroupBy bool) (ops.Operator, int, error) { - if offset, found := canReuseColumn(ctx, p.Columns, expr.Expr, extractExpr); found { + offset, err := p.findCol(ctx, expr.Expr) + if err != nil { + return nil, 0, err + } + if offset >= 0 { return p, offset, nil } + sourceOp, offset, err := p.Source.AddColumn(ctx, expr, true, addToGroupBy) if err != nil { return nil, 0, err @@ -173,6 +192,10 @@ func (p *Projection) GetColumns() ([]*sqlparser.AliasedExpr, error) { return p.Columns, nil } +func (p *Projection) GetSelectExprs() (sqlparser.SelectExprs, error) { + return transformColumnsToSelectExprs(p) +} + func (p *Projection) GetOrdering() ([]ops.OrderBy, error) { return p.Source.GetOrdering() } @@ -319,3 +342,10 @@ func (p *Projection) planOffsets(ctx *plancontext.PlanningContext) error { return nil } + +func (p *Projection) introducesTableID() semantics.TableSet { + if p.TableID == nil { + return semantics.EmptyTableSet() + } + return *p.TableID +} diff --git a/go/vt/vtgate/planbuilder/operators/querygraph.go b/go/vt/vtgate/planbuilder/operators/querygraph.go index f9b05914e77..5d06da0e9b5 100644 --- a/go/vt/vtgate/planbuilder/operators/querygraph.go +++ b/go/vt/vtgate/planbuilder/operators/querygraph.go @@ -65,8 +65,8 @@ type ( var _ ops.Operator = (*QueryGraph)(nil) -// Introduces implements the TableIDIntroducer interface -func (qg *QueryGraph) Introduces() semantics.TableSet { +// introducesTableID implements the tableIDIntroducer interface +func (qg *QueryGraph) introducesTableID() semantics.TableSet { var ts semantics.TableSet for _, table := range qg.Tables { ts = ts.Merge(table.ID) diff --git a/go/vt/vtgate/planbuilder/operators/queryprojection.go b/go/vt/vtgate/planbuilder/operators/queryprojection.go index cde1c7a8d02..8f70df5c52e 100644 --- a/go/vt/vtgate/planbuilder/operators/queryprojection.go +++
b/go/vt/vtgate/planbuilder/operators/queryprojection.go @@ -29,7 +29,6 @@ import ( "vitess.io/vitess/go/vt/vterrors" "vitess.io/vitess/go/vt/vtgate/engine/opcode" "vitess.io/vitess/go/vt/vtgate/planbuilder/operators/ops" - "vitess.io/vitess/go/vt/vtgate/planbuilder/operators/rewrite" "vitess.io/vitess/go/vt/vtgate/planbuilder/plancontext" "vitess.io/vitess/go/vt/vtgate/semantics" ) @@ -110,16 +109,8 @@ type ( } ) -func (aggr Aggr) NeedWeightString(ctx *plancontext.PlanningContext) bool { - switch aggr.OpCode { - case opcode.AggregateCountDistinct, opcode.AggregateSumDistinct: - return ctx.SemTable.NeedsWeightString(aggr.Func.GetArg()) - case opcode.AggregateMin, opcode.AggregateMax, opcode.AggregateGroupConcat: - // currently this returns false, as aggregation engine primitive does not support the usage of weight_string - // for comparison. If Min/Max column is non-comparable then it will fail at runtime. - return false - } - return false +func (aggr Aggr) NeedsWeightString(ctx *plancontext.PlanningContext) bool { + return aggr.OpCode.NeedsComparableValues() && ctx.SemTable.NeedsWeightString(aggr.Func.GetArg()) } func (aggr Aggr) GetCollation(ctx *plancontext.PlanningContext) collations.ID { @@ -208,8 +199,8 @@ func (s SelectExpr) GetAliasedExpr() (*sqlparser.AliasedExpr, error) { } } -// CreateQPFromSelect creates the QueryProjection for the input *sqlparser.Select -func CreateQPFromSelect(ctx *plancontext.PlanningContext, sel *sqlparser.Select) (*QueryProjection, error) { +// createQPFromSelect creates the QueryProjection for the input *sqlparser.Select +func createQPFromSelect(ctx *plancontext.PlanningContext, sel *sqlparser.Select) (*QueryProjection, error) { qp := &QueryProjection{ Distinct: sel.Distinct, } @@ -316,8 +307,8 @@ func (qp *QueryProjection) addSelectExpressions(sel *sqlparser.Select) error { return nil } -// CreateQPFromUnion creates the QueryProjection for the input *sqlparser.Union -func CreateQPFromUnion(ctx *plancontext.PlanningContext, union 
*sqlparser.Union) (*QueryProjection, error) { +// createQPFromUnion creates the QueryProjection for the input *sqlparser.Union +func createQPFromUnion(ctx *plancontext.PlanningContext, union *sqlparser.Union) (*QueryProjection, error) { qp := &QueryProjection{} sel := sqlparser.GetFirstSelect(union) @@ -832,19 +823,6 @@ func (qp *QueryProjection) GetColumnCount() int { return len(qp.SelectExprs) - qp.AddedColumn } -// checkAggregationSupported checks if the aggregation is supported on the given operator tree or not. -// We don't currently support planning for operators having derived tables. -func checkAggregationSupported(op ops.Operator) error { - return rewrite.Visit(op, func(operator ops.Operator) error { - _, isDerived := operator.(*Derived) - projection, isProjection := operator.(*Projection) - if isDerived || (isProjection && projection.TableID != nil) { - return errHorizonNotPlanned() - } - return nil - }) -} - func checkForInvalidGroupingExpressions(expr sqlparser.Expr) error { return sqlparser.Walk(func(node sqlparser.SQLNode) (bool, error) { if _, isAggregate := node.(sqlparser.AggrFunc); isAggregate { @@ -882,3 +860,13 @@ func CompareRefInt(a *int, b *int) bool { } return *a < *b } + +func CreateQPFromSelectStatement(ctx *plancontext.PlanningContext, stmt sqlparser.SelectStatement) (*QueryProjection, error) { + switch sel := stmt.(type) { + case *sqlparser.Select: + return createQPFromSelect(ctx, sel) + case *sqlparser.Union: + return createQPFromUnion(ctx, sel) + } + return nil, vterrors.VT13001("can only create query projection from Union and Select statements") +} diff --git a/go/vt/vtgate/planbuilder/operators/queryprojection_test.go b/go/vt/vtgate/planbuilder/operators/queryprojection_test.go index 2a89cd10716..7c92b716d7c 100644 --- a/go/vt/vtgate/planbuilder/operators/queryprojection_test.go +++ b/go/vt/vtgate/planbuilder/operators/queryprojection_test.go @@ -87,7 +87,7 @@ func TestQP(t *testing.T) { _, err = semantics.Analyze(sel, "", 
&semantics.FakeSI{}) require.NoError(t, err) - qp, err := CreateQPFromSelect(ctx, sel) + qp, err := createQPFromSelect(ctx, sel) if tcase.expErr != "" { require.Error(t, err) require.Contains(t, err.Error(), tcase.expErr) @@ -194,7 +194,7 @@ func TestQPSimplifiedExpr(t *testing.T) { _, err = semantics.Analyze(sel, "", &semantics.FakeSI{}) require.NoError(t, err) ctx := &plancontext.PlanningContext{SemTable: semantics.EmptySemTable()} - qp, err := CreateQPFromSelect(ctx, sel) + qp, err := createQPFromSelect(ctx, sel) require.NoError(t, err) require.Equal(t, tc.expected[1:], qp.toString()) }) diff --git a/go/vt/vtgate/planbuilder/operators/route.go b/go/vt/vtgate/planbuilder/operators/route.go index 31edd45fb78..ae2e2e432e1 100644 --- a/go/vt/vtgate/planbuilder/operators/route.go +++ b/go/vt/vtgate/planbuilder/operators/route.go @@ -582,6 +582,7 @@ type selectExpressions interface { ops.Operator addColumnWithoutPushing(expr *sqlparser.AliasedExpr, addToGroupBy bool) int isDerived() bool + findCol(ctx *plancontext.PlanningContext, expr sqlparser.Expr) (int, error) } func addColumnToInput(operator ops.Operator, expr *sqlparser.AliasedExpr, addToGroupBy bool) (bool, int) { @@ -609,6 +610,10 @@ func (r *Route) GetColumns() ([]*sqlparser.AliasedExpr, error) { return r.Source.GetColumns() } +func (r *Route) GetSelectExprs() (sqlparser.SelectExprs, error) { + return r.Source.GetSelectExprs() +} + func (r *Route) GetOrdering() ([]ops.OrderBy, error) { return r.Source.GetOrdering() } diff --git a/go/vt/vtgate/planbuilder/operators/route_planning.go b/go/vt/vtgate/planbuilder/operators/route_planning.go index f38dcbab3a8..ac08b1fef62 100644 --- a/go/vt/vtgate/planbuilder/operators/route_planning.go +++ b/go/vt/vtgate/planbuilder/operators/route_planning.go @@ -52,15 +52,16 @@ func transformToPhysical(ctx *plancontext.PlanningContext, in ops.Operator) (ops return optimizeQueryGraph(ctx, op) case *Join: return optimizeJoin(ctx, op) - case *Derived: - return pushDownDerived(ctx, 
op) + case *Horizon: + if op.TableId != nil { + return pushDownDerived(ctx, op) + } case *SubQuery: return optimizeSubQuery(ctx, op, ts) case *Filter: return pushDownFilter(op) - default: - return operator, rewrite.SameTree, nil } + return operator, rewrite.SameTree, nil }) if err != nil { @@ -79,7 +80,7 @@ func pushDownFilter(op *Filter) (ops.Operator, *rewrite.ApplyResult, error) { return op, rewrite.SameTree, nil } -func pushDownDerived(ctx *plancontext.PlanningContext, op *Derived) (ops.Operator, *rewrite.ApplyResult, error) { +func pushDownDerived(ctx *plancontext.PlanningContext, op *Horizon) (ops.Operator, *rewrite.ApplyResult, error) { innerRoute, ok := op.Source.(*Route) if !ok { return op, rewrite.SameTree, nil @@ -389,9 +390,9 @@ func requiresSwitchingSides(ctx *plancontext.PlanningContext, op ops.Operator) b required := false _ = rewrite.Visit(op, func(current ops.Operator) error { - derived, isDerived := current.(*Derived) + horizon, isHorizon := current.(*Horizon) - if isDerived && !derived.IsMergeable(ctx) { + if isHorizon && horizon.IsDerived() && !horizon.IsMergeable(ctx) { required = true return io.EOF } @@ -500,11 +501,11 @@ func findColumnVindex(ctx *plancontext.PlanningContext, a ops.Operator, exp sqlp deps := ctx.SemTable.RecursiveDeps(expr) _ = rewrite.Visit(a, func(rel ops.Operator) error { - to, isTableOp := rel.(TableIDIntroducer) + to, isTableOp := rel.(tableIDIntroducer) if !isTableOp { return nil } - id := to.Introduces() + id := to.introducesTableID() if deps.IsSolvedBy(id) { tableInfo, err := ctx.SemTable.TableInfoFor(id) if err != nil { diff --git a/go/vt/vtgate/planbuilder/operators/subquery.go b/go/vt/vtgate/planbuilder/operators/subquery.go index 632042bf20e..8966c30e192 100644 --- a/go/vt/vtgate/planbuilder/operators/subquery.go +++ b/go/vt/vtgate/planbuilder/operators/subquery.go @@ -111,7 +111,7 @@ func createSubqueryFromStatement(ctx *plancontext.PlanningContext, stmt sqlparse } subq := &SubQuery{} for _, sq := range 
ctx.SemTable.SubqueryMap[stmt] { - opInner, err := createLogicalOperatorFromAST(ctx, sq.Subquery.Select) + opInner, err := translateQueryToOp(ctx, sq.Subquery.Select) if err != nil { return nil, err } diff --git a/go/vt/vtgate/planbuilder/operators/table.go b/go/vt/vtgate/planbuilder/operators/table.go index a3eeea5e365..2409789d0bf 100644 --- a/go/vt/vtgate/planbuilder/operators/table.go +++ b/go/vt/vtgate/planbuilder/operators/table.go @@ -54,7 +54,7 @@ func (to *Table) Clone([]ops.Operator) ops.Operator { } // Introduces implements the PhysicalOperator interface -func (to *Table) Introduces() semantics.TableSet { +func (to *Table) introducesTableID() semantics.TableSet { return to.QTable.ID } @@ -79,6 +79,10 @@ func (to *Table) GetColumns() ([]*sqlparser.AliasedExpr, error) { return slices2.Map(to.Columns, colNameToExpr), nil } +func (to *Table) GetSelectExprs() (sqlparser.SelectExprs, error) { + return transformColumnsToSelectExprs(to) +} + func (to *Table) GetOrdering() ([]ops.OrderBy, error) { return nil, nil } diff --git a/go/vt/vtgate/planbuilder/operators/union.go b/go/vt/vtgate/planbuilder/operators/union.go index 7a68007ff63..6568521383d 100644 --- a/go/vt/vtgate/planbuilder/operators/union.go +++ b/go/vt/vtgate/planbuilder/operators/union.go @@ -144,7 +144,7 @@ func (u *Union) GetSelectFor(source int) (*sqlparser.Select, error) { for { switch op := src.(type) { case *Horizon: - return sqlparser.GetFirstSelect(op.Select), nil + return sqlparser.GetFirstSelect(op.Query), nil case *Route: src = op.Source default: diff --git a/go/vt/vtgate/planbuilder/operators/update.go b/go/vt/vtgate/planbuilder/operators/update.go index 91273062c17..0627f07734e 100644 --- a/go/vt/vtgate/planbuilder/operators/update.go +++ b/go/vt/vtgate/planbuilder/operators/update.go @@ -38,7 +38,7 @@ type Update struct { } // Introduces implements the PhysicalOperator interface -func (u *Update) Introduces() semantics.TableSet { +func (u *Update) introducesTableID() semantics.TableSet 
{ return u.QTable.ID } diff --git a/go/vt/vtgate/planbuilder/operators/vindex.go b/go/vt/vtgate/planbuilder/operators/vindex.go index 252cb7ba2b3..e39d245bb50 100644 --- a/go/vt/vtgate/planbuilder/operators/vindex.go +++ b/go/vt/vtgate/planbuilder/operators/vindex.go @@ -52,7 +52,7 @@ type ( const VindexUnsupported = "WHERE clause for vindex function must be of the form id = or id in(,...)" // Introduces implements the Operator interface -func (v *Vindex) Introduces() semantics.TableSet { +func (v *Vindex) introducesTableID() semantics.TableSet { return v.Solved } @@ -86,6 +86,10 @@ func (v *Vindex) GetColumns() ([]*sqlparser.AliasedExpr, error) { return slices2.Map(v.Columns, colNameToExpr), nil } +func (v *Vindex) GetSelectExprs() (sqlparser.SelectExprs, error) { + return transformColumnsToSelectExprs(v) +} + func (v *Vindex) GetOrdering() ([]ops.OrderBy, error) { return nil, nil } diff --git a/go/vt/vtgate/planbuilder/ordered_aggregate.go b/go/vt/vtgate/planbuilder/ordered_aggregate.go index 7e2554d8c3f..490935889f5 100644 --- a/go/vt/vtgate/planbuilder/ordered_aggregate.go +++ b/go/vt/vtgate/planbuilder/ordered_aggregate.go @@ -272,23 +272,18 @@ func (oa *orderedAggregate) pushAggr(pb *primitiveBuilder, expr *sqlparser.Alias case popcode.AggregateSum: opcode = popcode.AggregateSumDistinct } - oa.aggregates = append(oa.aggregates, &engine.AggregateParams{ - Opcode: opcode, - Col: innerCol, - Alias: expr.ColumnName(), - OrigOpcode: origOpcode, - }) + aggr := engine.NewAggregateParam(opcode, innerCol, expr.ColumnName()) + aggr.OrigOpcode = origOpcode + oa.aggregates = append(oa.aggregates, aggr) } else { newBuilder, _, innerCol, err := planProjection(pb, oa.input, expr, origin) if err != nil { return nil, 0, err } pb.plan = newBuilder - oa.aggregates = append(oa.aggregates, &engine.AggregateParams{ - Opcode: opcode, - Col: innerCol, - OrigOpcode: origOpcode, - }) + aggr := engine.NewAggregateParam(opcode, innerCol, "") + aggr.OrigOpcode = origOpcode + 
oa.aggregates = append(oa.aggregates, aggr) } // Build a new rc with oa as origin because it's semantically different diff --git a/go/vt/vtgate/planbuilder/projection_pushing.go b/go/vt/vtgate/planbuilder/projection_pushing.go index acb01308b01..8772f882c78 100644 --- a/go/vt/vtgate/planbuilder/projection_pushing.go +++ b/go/vt/vtgate/planbuilder/projection_pushing.go @@ -138,13 +138,10 @@ func pushProjectionIntoOA(ctx *plancontext.PlanningContext, expr *sqlparser.Alia if err != nil { return 0, false, err } - node.aggregates = append(node.aggregates, &engine.AggregateParams{ - Opcode: popcode.AggregateAnyValue, - Col: offset, - Alias: expr.ColumnName(), - Expr: expr.Expr, - Original: expr, - }) + aggr := engine.NewAggregateParam(popcode.AggregateAnyValue, offset, expr.ColumnName()) + aggr.Expr = expr.Expr + aggr.Original = expr + node.aggregates = append(node.aggregates, aggr) return offset, true, nil } diff --git a/go/vt/vtgate/planbuilder/show.go b/go/vt/vtgate/planbuilder/show.go index 132ca55715b..7a4c5abc14c 100644 --- a/go/vt/vtgate/planbuilder/show.go +++ b/go/vt/vtgate/planbuilder/show.go @@ -564,11 +564,7 @@ func buildShowVGtidPlan(show *sqlparser.ShowBasic, vschema plancontext.VSchema) } return &engine.OrderedAggregate{ Aggregates: []*engine.AggregateParams{ - { - Opcode: popcode.AggregateGtid, - Col: 1, - Alias: "global vgtid_executed", - }, + engine.NewAggregateParam(popcode.AggregateGtid, 1, "global vgtid_executed"), }, TruncateColumnCount: 2, Input: send, diff --git a/go/vt/vtgate/planbuilder/testdata/aggr_cases.json b/go/vt/vtgate/planbuilder/testdata/aggr_cases.json index a1c8d01ad24..aa54fa01776 100644 --- a/go/vt/vtgate/planbuilder/testdata/aggr_cases.json +++ b/go/vt/vtgate/planbuilder/testdata/aggr_cases.json @@ -362,28 +362,20 @@ "QueryType": "SELECT", "Original": "select a from (select count(*) as a from user) t", "Instructions": { - "OperatorType": "SimpleProjection", - "Columns": [ - 0 - ], + "OperatorType": "Aggregate", + "Variant": 
"Scalar", + "Aggregates": "sum_count_star(0) AS a", "Inputs": [ { - "OperatorType": "Aggregate", - "Variant": "Scalar", - "Aggregates": "sum_count_star(0) AS a", - "Inputs": [ - { - "OperatorType": "Route", - "Variant": "Scatter", - "Keyspace": { - "Name": "user", - "Sharded": true - }, - "FieldQuery": "select count(*) as a from `user` where 1 != 1", - "Query": "select count(*) as a from `user`", - "Table": "`user`" - } - ] + "OperatorType": "Route", + "Variant": "Scatter", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select count(*) as a from `user` where 1 != 1", + "Query": "select count(*) as a from `user`", + "Table": "`user`" } ] }, @@ -686,9 +678,9 @@ "Name": "user", "Sharded": true }, - "FieldQuery": "select col1, col2, weight_string(col1), weight_string(col2) from `user` where 1 != 1 group by col1, weight_string(col1), col2, weight_string(col2)", + "FieldQuery": "select col1, col2, weight_string(col1), weight_string(col2) from `user` where 1 != 1 group by col1, col2, weight_string(col1), weight_string(col2)", "OrderBy": "(0|2) ASC, (1|3) ASC", - "Query": "select col1, col2, weight_string(col1), weight_string(col2) from `user` group by col1, weight_string(col1), col2, weight_string(col2) order by col1 asc, col2 asc", + "Query": "select col1, col2, weight_string(col1), weight_string(col2) from `user` group by col1, col2, weight_string(col1), weight_string(col2) order by col1 asc, col2 asc", "Table": "`user`" } ] @@ -798,9 +790,9 @@ "Name": "user", "Sharded": true }, - "FieldQuery": "select col1, col2, weight_string(col1), weight_string(col2) from `user` where 1 != 1 group by col1, weight_string(col1), col2, weight_string(col2)", + "FieldQuery": "select col1, col2, weight_string(col1), weight_string(col2) from `user` where 1 != 1 group by col1, col2, weight_string(col1), weight_string(col2)", "OrderBy": "(0|2) ASC, (1|3) ASC", - "Query": "select col1, col2, weight_string(col1), weight_string(col2) from `user` group by col1, 
weight_string(col1), col2, weight_string(col2) order by col1 asc, col2 asc", + "Query": "select col1, col2, weight_string(col1), weight_string(col2) from `user` group by col1, col2, weight_string(col1), weight_string(col2) order by col1 asc, col2 asc", "Table": "`user`" } ] @@ -855,9 +847,9 @@ "Name": "user", "Sharded": true }, - "FieldQuery": "select col1, col2, weight_string(col1), weight_string(col2) from `user` where 1 != 1 group by col1, weight_string(col1), col2, weight_string(col2)", + "FieldQuery": "select col1, col2, weight_string(col1), weight_string(col2) from `user` where 1 != 1 group by col1, col2, weight_string(col1), weight_string(col2)", "OrderBy": "(0|2) ASC, (1|3) ASC", - "Query": "select col1, col2, weight_string(col1), weight_string(col2) from `user` group by col1, weight_string(col1), col2, weight_string(col2) order by col1 asc, col2 asc", + "Query": "select col1, col2, weight_string(col1), weight_string(col2) from `user` group by col1, col2, weight_string(col1), weight_string(col2) order by col1 asc, col2 asc", "Table": "`user`" } ] @@ -912,9 +904,9 @@ "Name": "user", "Sharded": true }, - "FieldQuery": "select col1, col2, weight_string(col1), weight_string(col2) from `user` where 1 != 1 group by col1, weight_string(col1), col2, weight_string(col2)", + "FieldQuery": "select col1, col2, weight_string(col1), weight_string(col2) from `user` where 1 != 1 group by col1, col2, weight_string(col1), weight_string(col2)", "OrderBy": "(0|2) ASC, (1|3) ASC", - "Query": "select col1, col2, weight_string(col1), weight_string(col2) from `user` group by col1, weight_string(col1), col2, weight_string(col2) order by col1 asc, col2 asc", + "Query": "select col1, col2, weight_string(col1), weight_string(col2) from `user` group by col1, col2, weight_string(col1), weight_string(col2) order by col1 asc, col2 asc", "Table": "`user`" } ] @@ -981,9 +973,9 @@ "Name": "user", "Sharded": true }, - "FieldQuery": "select col1, col2, weight_string(col1), weight_string(col2) 
from `user` where 1 != 1 group by col1, weight_string(col1), col2, weight_string(col2)", + "FieldQuery": "select col1, col2, weight_string(col1), weight_string(col2) from `user` where 1 != 1 group by col1, col2, weight_string(col1), weight_string(col2)", "OrderBy": "(0|2) ASC, (1|3) ASC", - "Query": "select col1, col2, weight_string(col1), weight_string(col2) from `user` group by col1, weight_string(col1), col2, weight_string(col2) order by col1 asc, col2 asc", + "Query": "select col1, col2, weight_string(col1), weight_string(col2) from `user` group by col1, col2, weight_string(col1), weight_string(col2) order by col1 asc, col2 asc", "Table": "`user`" } ] @@ -1939,7 +1931,7 @@ "Sharded": true }, "FieldQuery": "select count(*) from (select `user`.col, user_extra.extra from `user`, user_extra where 1 != 1) as a where 1 != 1", - "Query": "select count(*) from (select `user`.col, user_extra.extra from `user`, user_extra where `user`.id = user_extra.user_id order by user_extra.extra asc) as a", + "Query": "select count(*) from (select `user`.col, user_extra.extra from `user`, user_extra where `user`.id = user_extra.user_id) as a", "Table": "`user`, user_extra" } ] @@ -1989,7 +1981,7 @@ "Sharded": true }, "FieldQuery": "select col from (select `user`.col, user_extra.extra from `user`, user_extra where 1 != 1) as a where 1 != 1", - "Query": "select col from (select `user`.col, user_extra.extra from `user`, user_extra where `user`.id = user_extra.user_id order by user_extra.extra asc) as a", + "Query": "select col from (select `user`.col, user_extra.extra from `user`, user_extra where `user`.id = user_extra.user_id) as a", "Table": "`user`, user_extra" }, "TablesUsed": [ @@ -2020,7 +2012,7 @@ }, "FieldQuery": "select col, count(*) from (select `user`.col, user_extra.extra from `user`, user_extra where 1 != 1) as a where 1 != 1 group by col", "OrderBy": "0 ASC", - "Query": "select col, count(*) from (select `user`.col, user_extra.extra from `user`, user_extra where 
`user`.id = user_extra.user_id order by user_extra.extra asc) as a group by col order by col asc", + "Query": "select col, count(*) from (select `user`.col, user_extra.extra from `user`, user_extra where `user`.id = user_extra.user_id) as a group by col order by col asc", "Table": "`user`, user_extra" } ] @@ -2221,9 +2213,9 @@ "Name": "user", "Sharded": true }, - "FieldQuery": "select col1, col2, col2, weight_string(col1), weight_string(col2) from `user` where 1 != 1 group by col1, weight_string(col1), col2, weight_string(col2)", + "FieldQuery": "select col1, col2, col2, weight_string(col1), weight_string(col2) from `user` where 1 != 1 group by col1, col2, weight_string(col1), weight_string(col2)", "OrderBy": "(0|3) ASC, (1|4) ASC", - "Query": "select col1, col2, col2, weight_string(col1), weight_string(col2) from `user` group by col1, weight_string(col1), col2, weight_string(col2) order by col1 asc, col2 asc", + "Query": "select col1, col2, col2, weight_string(col1), weight_string(col2) from `user` group by col1, col2, weight_string(col1), weight_string(col2) order by col1 asc, col2 asc", "Table": "`user`" } ] @@ -3560,7 +3552,7 @@ "Instructions": { "OperatorType": "Aggregate", "Variant": "Ordered", - "Aggregates": "min(1) AS min(distinct id), sum_distinct(2|4) AS sum(distinct col3)", + "Aggregates": "min(1|4) AS min(distinct id), sum_distinct(2|5) AS sum(distinct col3)", "GroupBy": "(0|3)", "ResultColumns": 3, "Inputs": [ @@ -3571,9 +3563,9 @@ "Name": "user", "Sharded": true }, - "FieldQuery": "select col1, min(distinct id), col3, weight_string(col1), weight_string(col3) from `user` where 1 != 1 group by col1, weight_string(col1), col3, weight_string(col3)", - "OrderBy": "(0|3) ASC, (2|4) ASC", - "Query": "select col1, min(distinct id), col3, weight_string(col1), weight_string(col3) from `user` group by col1, weight_string(col1), col3, weight_string(col3) order by col1 asc, col3 asc", + "FieldQuery": "select col1, min(distinct id), col3, weight_string(col1), 
weight_string(id), weight_string(col3) from `user` where 1 != 1 group by col1, col3, weight_string(col1), weight_string(id), weight_string(col3)", + "OrderBy": "(0|3) ASC, (2|5) ASC", + "Query": "select col1, min(distinct id), col3, weight_string(col1), weight_string(id), weight_string(col3) from `user` group by col1, col3, weight_string(col1), weight_string(id), weight_string(col3) order by col1 asc, col3 asc", "Table": "`user`" } ] @@ -3687,9 +3679,9 @@ "Name": "user", "Sharded": true }, - "FieldQuery": "select val2, val1, count(*), weight_string(val2), weight_string(val1) from `user` where 1 != 1 group by val2, weight_string(val2), val1, weight_string(val1)", + "FieldQuery": "select val2, val1, count(*), weight_string(val2), weight_string(val1) from `user` where 1 != 1 group by val2, val1, weight_string(val2), weight_string(val1)", "OrderBy": "(0|3) ASC, (1|4) ASC", - "Query": "select val2, val1, count(*), weight_string(val2), weight_string(val1) from `user` group by val2, weight_string(val2), val1, weight_string(val1) order by val2 asc, val1 asc", + "Query": "select val2, val1, count(*), weight_string(val2), weight_string(val1) from `user` group by val2, val1, weight_string(val2), weight_string(val1) order by val2 asc, val1 asc", "Table": "`user`" } ] @@ -3777,9 +3769,9 @@ "Name": "user", "Sharded": true }, - "FieldQuery": "select tcol1, tcol2, tcol2, weight_string(tcol1), weight_string(tcol2) from `user` where 1 != 1 group by tcol1, weight_string(tcol1), tcol2, weight_string(tcol2)", + "FieldQuery": "select tcol1, tcol2, tcol2, weight_string(tcol1), weight_string(tcol2) from `user` where 1 != 1 group by tcol1, tcol2, weight_string(tcol1), weight_string(tcol2)", "OrderBy": "(0|3) ASC, (1|4) ASC", - "Query": "select tcol1, tcol2, tcol2, weight_string(tcol1), weight_string(tcol2) from `user` group by tcol1, weight_string(tcol1), tcol2, weight_string(tcol2) order by tcol1 asc, tcol2 asc", + "Query": "select tcol1, tcol2, tcol2, weight_string(tcol1), 
weight_string(tcol2) from `user` group by tcol1, tcol2, weight_string(tcol1), weight_string(tcol2) order by tcol1 asc, tcol2 asc", "Table": "`user`" } ] @@ -3799,8 +3791,8 @@ "Instructions": { "OperatorType": "Aggregate", "Variant": "Ordered", - "Aggregates": "count_distinct(0|4) AS count(distinct tcol2), sum_count_star(2) AS count(*), sum_distinct(3|4) AS sum(distinct tcol2)", - "GroupBy": "(1|5)", + "Aggregates": "count_distinct(0|5) AS count(distinct tcol2), sum_count_star(2) AS count(*), sum_distinct(3|5) AS sum(distinct tcol2)", + "GroupBy": "(1|4)", "ResultColumns": 4, "Inputs": [ { @@ -3810,9 +3802,9 @@ "Name": "user", "Sharded": true }, - "FieldQuery": "select tcol2, tcol1, count(*), tcol2, weight_string(tcol2), weight_string(tcol1) from `user` where 1 != 1 group by tcol2, weight_string(tcol2), tcol1, weight_string(tcol1)", - "OrderBy": "(1|5) ASC, (0|4) ASC", - "Query": "select tcol2, tcol1, count(*), tcol2, weight_string(tcol2), weight_string(tcol1) from `user` group by tcol2, weight_string(tcol2), tcol1, weight_string(tcol1) order by tcol1 asc, tcol2 asc", + "FieldQuery": "select tcol2, tcol1, count(*), tcol2, weight_string(tcol1), weight_string(tcol2) from `user` where 1 != 1 group by tcol1, tcol2, weight_string(tcol1), weight_string(tcol2)", + "OrderBy": "(1|4) ASC, (0|5) ASC", + "Query": "select tcol2, tcol1, count(*), tcol2, weight_string(tcol1), weight_string(tcol2) from `user` group by tcol1, tcol2, weight_string(tcol1), weight_string(tcol2) order by tcol1 asc, tcol2 asc", "Table": "`user`" } ] @@ -3837,59 +3829,34 @@ "ResultColumns": 2, "Inputs": [ { - "OperatorType": "Projection", - "Expressions": [ - "[COLUMN 0] as textcol1", - "[COLUMN 1] as val2", - "[COLUMN 2]" - ], + "OperatorType": "Join", + "Variant": "Join", + "JoinColumnIndexes": "L:0,L:1,L:2", + "JoinVars": { + "u2_val2": 3 + }, + "TableName": "`user`_`user`_music", "Inputs": [ { "OperatorType": "Join", "Variant": "Join", - "JoinColumnIndexes": "L:2,L:3,L:5", + "JoinColumnIndexes": 
"L:0,L:1,L:2,R:0", "JoinVars": { - "u2_val2": 0 + "u_val2": 1 }, - "TableName": "`user`_`user`_music", + "TableName": "`user`_`user`", "Inputs": [ { - "OperatorType": "Join", - "Variant": "Join", - "JoinColumnIndexes": "R:0,R:0,L:2,L:0,R:1,L:1", - "JoinVars": { - "u_val2": 0 + "OperatorType": "Route", + "Variant": "Scatter", + "Keyspace": { + "Name": "user", + "Sharded": true }, - "TableName": "`user`_`user`", - "Inputs": [ - { - "OperatorType": "Route", - "Variant": "Scatter", - "Keyspace": { - "Name": "user", - "Sharded": true - }, - "FieldQuery": "select u.val2, weight_string(u.val2), u.textcol1 from `user` as u where 1 != 1 group by u.val2, weight_string(u.val2), u.textcol1", - "OrderBy": "2 ASC COLLATE latin1_swedish_ci, (0|1) ASC", - "Query": "select u.val2, weight_string(u.val2), u.textcol1 from `user` as u group by u.val2, weight_string(u.val2), u.textcol1 order by u.textcol1 asc, u.val2 asc", - "Table": "`user`" - }, - { - "OperatorType": "Route", - "Variant": "EqualUnique", - "Keyspace": { - "Name": "user", - "Sharded": true - }, - "FieldQuery": "select u2.val2, weight_string(u2.val2) from `user` as u2 where 1 != 1 group by u2.val2, weight_string(u2.val2)", - "Query": "select u2.val2, weight_string(u2.val2) from `user` as u2 where u2.id = :u_val2 group by u2.val2, weight_string(u2.val2)", - "Table": "`user`", - "Values": [ - ":u_val2" - ], - "Vindex": "user_index" - } - ] + "FieldQuery": "select u.textcol1, u.val2, weight_string(u.val2) from `user` as u where 1 != 1", + "OrderBy": "0 ASC COLLATE latin1_swedish_ci", + "Query": "select u.textcol1, u.val2, weight_string(u.val2) from `user` as u order by u.textcol1 asc", + "Table": "`user`" }, { "OperatorType": "Route", @@ -3898,15 +3865,30 @@ "Name": "user", "Sharded": true }, - "FieldQuery": "select 1 from music as m where 1 != 1", - "Query": "select 1 from music as m where m.id = :u2_val2", - "Table": "music", + "FieldQuery": "select u2.val2 from `user` as u2 where 1 != 1", + "Query": "select u2.val2 from 
`user` as u2 where u2.id = :u_val2", + "Table": "`user`", "Values": [ - ":u2_val2" + ":u_val2" ], - "Vindex": "music_user_map" + "Vindex": "user_index" } ] + }, + { + "OperatorType": "Route", + "Variant": "EqualUnique", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select 1 from music as m where 1 != 1", + "Query": "select 1 from music as m where m.id = :u2_val2", + "Table": "music", + "Values": [ + ":u2_val2" + ], + "Vindex": "music_user_map" } ] } @@ -4013,13 +3995,13 @@ { "OperatorType": "Projection", "Expressions": [ - "[COLUMN 2] * COALESCE([COLUMN 3], INT64(1)) as sum(col)" + "[COLUMN 0] * [COLUMN 1] as sum(col)" ], "Inputs": [ { "OperatorType": "Join", "Variant": "Join", - "JoinColumnIndexes": "L:0,L:1,L:2,R:1", + "JoinColumnIndexes": "L:0,R:0", "TableName": "`user`_user_extra", "Inputs": [ { @@ -4029,8 +4011,8 @@ "Name": "user", "Sharded": true }, - "FieldQuery": "select `user`.col as col, 32, sum(col) from `user` where 1 != 1", - "Query": "select `user`.col as col, 32, sum(col) from `user`", + "FieldQuery": "select sum(col) from (select `user`.col as col, 32 from `user` where 1 != 1) as t where 1 != 1", + "Query": "select sum(col) from (select `user`.col as col, 32 from `user`) as t", "Table": "`user`" }, { @@ -4040,8 +4022,8 @@ "Name": "user", "Sharded": true }, - "FieldQuery": "select 1, count(*) from user_extra where 1 != 1 group by 1", - "Query": "select 1, count(*) from user_extra group by 1", + "FieldQuery": "select count(*) from user_extra where 1 != 1 group by .0", + "Query": "select count(*) from user_extra group by .0", "Table": "user_extra" } ] @@ -4736,27 +4718,19 @@ "Aggregates": "count(0) AS count(city)", "Inputs": [ { - "OperatorType": "Projection", - "Expressions": [ - "[COLUMN 2] as count(city)" - ], + "OperatorType": "Limit", + "Count": "INT64(10)", "Inputs": [ { - "OperatorType": "Limit", - "Count": "INT64(10)", - "Inputs": [ - { - "OperatorType": "Route", - "Variant": "Scatter", - "Keyspace": { - "Name": 
"user", - "Sharded": true - }, - "FieldQuery": "select phone, id, city from `user` where 1 != 1", - "Query": "select phone, id, city from `user` where id > 12 limit :__upper_limit", - "Table": "`user`" - } - ] + "OperatorType": "Route", + "Variant": "Scatter", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select city from (select phone, id, city from `user` where 1 != 1) as x where 1 != 1", + "Query": "select city from (select phone, id, city from `user` where id > 12) as x limit :__upper_limit", + "Table": "`user`" } ] } @@ -4780,27 +4754,19 @@ "Aggregates": "count_star(0) AS count(*)", "Inputs": [ { - "OperatorType": "Projection", - "Expressions": [ - "[COLUMN 0] as count(*)" - ], + "OperatorType": "Limit", + "Count": "INT64(10)", "Inputs": [ { - "OperatorType": "Limit", - "Count": "INT64(10)", - "Inputs": [ - { - "OperatorType": "Route", - "Variant": "Scatter", - "Keyspace": { - "Name": "user", - "Sharded": true - }, - "FieldQuery": "select phone, id, city from `user` where 1 != 1", - "Query": "select phone, id, city from `user` where id > 12 limit :__upper_limit", - "Table": "`user`" - } - ] + "OperatorType": "Route", + "Variant": "Scatter", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select 1 from (select phone, id, city from `user` where 1 != 1) as x where 1 != 1", + "Query": "select 1 from (select phone, id, city from `user` where id > 12) as x limit :__upper_limit", + "Table": "`user`" } ] } @@ -4823,47 +4789,39 @@ "Aggregates": "count(0) AS count(col)", "Inputs": [ { - "OperatorType": "Projection", - "Expressions": [ - "[COLUMN 0] as count(col)" - ], + "OperatorType": "Limit", + "Count": "INT64(10)", "Inputs": [ { - "OperatorType": "Limit", - "Count": "INT64(10)", + "OperatorType": "Join", + "Variant": "LeftJoin", + "JoinColumnIndexes": "R:0", + "JoinVars": { + "user_id": 0 + }, + "TableName": "`user`_user_extra", "Inputs": [ { - "OperatorType": "Join", - "Variant": "LeftJoin", - 
"JoinColumnIndexes": "R:0", - "JoinVars": { - "user_id": 0 + "OperatorType": "Route", + "Variant": "Scatter", + "Keyspace": { + "Name": "user", + "Sharded": true }, - "TableName": "`user`_user_extra", - "Inputs": [ - { - "OperatorType": "Route", - "Variant": "Scatter", - "Keyspace": { - "Name": "user", - "Sharded": true - }, - "FieldQuery": "select `user`.id from `user` where 1 != 1", - "Query": "select `user`.id from `user`", - "Table": "`user`" - }, - { - "OperatorType": "Route", - "Variant": "Scatter", - "Keyspace": { - "Name": "user", - "Sharded": true - }, - "FieldQuery": "select user_extra.col as col from user_extra where 1 != 1", - "Query": "select user_extra.col as col from user_extra where user_extra.id = :user_id", - "Table": "user_extra" - } - ] + "FieldQuery": "select `user`.id from `user` where 1 != 1", + "Query": "select `user`.id from `user`", + "Table": "`user`" + }, + { + "OperatorType": "Route", + "Variant": "Scatter", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select user_extra.col from user_extra where 1 != 1", + "Query": "select user_extra.col from user_extra where user_extra.id = :user_id", + "Table": "user_extra" } ] } @@ -4892,12 +4850,9 @@ "ResultColumns": 2, "Inputs": [ { - "OperatorType": "Projection", - "Expressions": [ - "[COLUMN 1] as val1", - "[COLUMN 0] as count(*)", - "[COLUMN 2]" - ], + "OperatorType": "Sort", + "Variant": "Memory", + "OrderBy": "(0|2) ASC", "Inputs": [ { "OperatorType": "Limit", @@ -4910,9 +4865,9 @@ "Name": "user", "Sharded": true }, - "FieldQuery": "select id, val1, weight_string(val1) from `user` where 1 != 1", - "OrderBy": "(1|2) ASC, (1|2) ASC", - "Query": "select id, val1, weight_string(val1) from `user` where val2 < 4 order by val1 asc, val1 asc limit :__upper_limit", + "FieldQuery": "select val1, 1, weight_string(val1), weight_string(val1) from (select id, val1 from `user` where 1 != 1) as x where 1 != 1", + "OrderBy": "(0|3) ASC", + "Query": "select val1, 1, 
weight_string(val1), weight_string(val1) from (select id, val1 from `user` where val2 < 4) as x order by val1 asc limit :__upper_limit", "Table": "`user`" } ] @@ -4955,7 +4910,7 @@ "Inputs": [ { "OperatorType": "Filter", - "Predicate": ":1 = 1", + "Predicate": "count(*) = 1", "Inputs": [ { "OperatorType": "Aggregate", @@ -5290,7 +5245,7 @@ "Sharded": true }, "FieldQuery": "select t1.portalId, t1.flowId from (select portalId, flowId, count(*) as `count` from user_extra where 1 != 1 group by user_id, flowId) as t1 where 1 != 1", - "Query": "select t1.portalId, t1.flowId from (select portalId, flowId, count(*) as `count` from user_extra where localDate > :v1 group by user_id, flowId order by null) as t1 where `count` >= :v2", + "Query": "select t1.portalId, t1.flowId from (select portalId, flowId, count(*) as `count` from user_extra where localDate > :v1 group by user_id, flowId) as t1 where `count` >= :v2", "Table": "user_extra" }, "TablesUsed": [ @@ -5312,22 +5267,28 @@ ], "Inputs": [ { - "OperatorType": "Aggregate", - "Variant": "Ordered", - "Aggregates": "max(1) AS bazo", - "GroupBy": "(0|2)", + "OperatorType": "Filter", + "Predicate": "bazo between 100 and 200", "Inputs": [ { - "OperatorType": "Route", - "Variant": "Scatter", - "Keyspace": { - "Name": "user", - "Sharded": true - }, - "FieldQuery": "select foo, max(baz) as bazo, weight_string(foo) from (select foo, baz from `user` where 1 != 1) as f where 1 != 1 group by foo, weight_string(foo)", - "OrderBy": "(0|2) ASC", - "Query": "select foo, max(baz) as bazo, weight_string(foo) from (select foo, baz from `user` having max(baz) between 100 and 200) as f group by foo, weight_string(foo) order by foo asc", - "Table": "`user`" + "OperatorType": "Aggregate", + "Variant": "Ordered", + "Aggregates": "max(1|3) AS bazo", + "GroupBy": "(0|2)", + "Inputs": [ + { + "OperatorType": "Route", + "Variant": "Scatter", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select foo, max(baz) as bazo, 
weight_string(foo), weight_string(baz) from (select foo, baz from `user` where 1 != 1) as f where 1 != 1 group by foo, weight_string(foo), weight_string(baz)", + "OrderBy": "(0|2) ASC", + "Query": "select foo, max(baz) as bazo, weight_string(foo), weight_string(baz) from (select foo, baz from `user`) as f group by foo, weight_string(foo), weight_string(baz) order by foo asc", + "Table": "`user`" + } + ] } ] } @@ -5348,7 +5309,7 @@ "Instructions": { "OperatorType": "SimpleProjection", "Columns": [ - 1 + 0 ], "Inputs": [ { @@ -5356,31 +5317,22 @@ "Predicate": "bazo between 100 and 200", "Inputs": [ { - "OperatorType": "SimpleProjection", - "Columns": [ - 1, - 0 - ], + "OperatorType": "Aggregate", + "Variant": "Ordered", + "Aggregates": "sum_count(1) AS bazo", + "GroupBy": "(0|2)", "Inputs": [ { - "OperatorType": "Aggregate", - "Variant": "Ordered", - "Aggregates": "sum_count(1) AS bazo", - "GroupBy": "(0|2)", - "Inputs": [ - { - "OperatorType": "Route", - "Variant": "Scatter", - "Keyspace": { - "Name": "user", - "Sharded": true - }, - "FieldQuery": "select foo, count(baz) as bazo, weight_string(foo) from (select foo, baz from `user` where 1 != 1) as f where 1 != 1 group by foo, weight_string(foo)", - "OrderBy": "(0|2) ASC", - "Query": "select foo, count(baz) as bazo, weight_string(foo) from (select foo, baz from `user`) as f group by foo, weight_string(foo) order by foo asc", - "Table": "`user`" - } - ] + "OperatorType": "Route", + "Variant": "Scatter", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select foo, count(baz) as bazo, weight_string(foo) from (select foo, baz from `user` where 1 != 1) as f where 1 != 1 group by foo, weight_string(foo)", + "OrderBy": "(0|2) ASC", + "Query": "select foo, count(baz) as bazo, weight_string(foo) from (select foo, baz from `user`) as f group by foo, weight_string(foo) order by foo asc", + "Table": "`user`" } ] } @@ -6280,7 +6232,7 @@ "Instructions": { "OperatorType": "Aggregate", "Variant": "Scalar", 
- "Aggregates": "min(0) AS min(textcol1), max(1) AS max(textcol2), sum_distinct(2 COLLATE latin1_swedish_ci) AS sum(distinct textcol1), count_distinct(3 COLLATE latin1_swedish_ci) AS count(distinct textcol1)", + "Aggregates": "min(0 COLLATE latin1_swedish_ci) AS min(textcol1), max(1 COLLATE latin1_swedish_ci) AS max(textcol2), sum_distinct(2 COLLATE latin1_swedish_ci) AS sum(distinct textcol1), count_distinct(3 COLLATE latin1_swedish_ci) AS count(distinct textcol1)", "Inputs": [ { "OperatorType": "Route", @@ -6311,7 +6263,7 @@ "Instructions": { "OperatorType": "Aggregate", "Variant": "Ordered", - "Aggregates": "min(1) AS min(textcol1), max(2) AS max(textcol2), sum_distinct(3 COLLATE latin1_swedish_ci) AS sum(distinct textcol1), count_distinct(4 COLLATE latin1_swedish_ci) AS count(distinct textcol1)", + "Aggregates": "min(1 COLLATE latin1_swedish_ci) AS min(textcol1), max(2 COLLATE latin1_swedish_ci) AS max(textcol2), sum_distinct(3 COLLATE latin1_swedish_ci) AS sum(distinct textcol1), count_distinct(4 COLLATE latin1_swedish_ci) AS count(distinct textcol1)", "GroupBy": "0", "Inputs": [ { @@ -6463,14 +6415,14 @@ "Instructions": { "OperatorType": "Aggregate", "Variant": "Ordered", - "Aggregates": "min(1) AS min(user_extra.foo), max(3) AS max(user_extra.bar)", + "Aggregates": "min(1|5) AS min(user_extra.foo), max(3|6) AS max(user_extra.bar)", "GroupBy": "0, (2|4)", "ResultColumns": 4, "Inputs": [ { "OperatorType": "Join", "Variant": "Join", - "JoinColumnIndexes": "L:0,R:0,L:1,R:1,L:2", + "JoinColumnIndexes": "L:0,R:0,L:1,R:1,L:2,R:2,R:3", "JoinVars": { "user_col": 0 }, @@ -6495,8 +6447,8 @@ "Name": "user", "Sharded": true }, - "FieldQuery": "select min(user_extra.foo), max(user_extra.bar) from user_extra where 1 != 1 group by .0", - "Query": "select min(user_extra.foo), max(user_extra.bar) from user_extra where user_extra.bar = :user_col group by .0", + "FieldQuery": "select min(user_extra.foo), max(user_extra.bar), weight_string(user_extra.foo), 
weight_string(user_extra.bar) from user_extra where 1 != 1 group by .0, weight_string(user_extra.foo), weight_string(user_extra.bar)", + "Query": "select min(user_extra.foo), max(user_extra.bar), weight_string(user_extra.foo), weight_string(user_extra.bar) from user_extra where user_extra.bar = :user_col group by .0, weight_string(user_extra.foo), weight_string(user_extra.bar)", "Table": "user_extra" } ] @@ -6519,12 +6471,13 @@ "Instructions": { "OperatorType": "Aggregate", "Variant": "Scalar", - "Aggregates": "max(0) AS max(u.foo * ue.bar)", + "Aggregates": "max(0|1) AS max(u.foo * ue.bar)", + "ResultColumns": 1, "Inputs": [ { "OperatorType": "Join", "Variant": "Join", - "JoinColumnIndexes": "R:0", + "JoinColumnIndexes": "R:0,R:1", "JoinVars": { "u_foo": 0 }, @@ -6548,8 +6501,8 @@ "Name": "user", "Sharded": true }, - "FieldQuery": "select :u_foo * ue.bar from user_extra as ue where 1 != 1", - "Query": "select :u_foo * ue.bar from user_extra as ue", + "FieldQuery": "select :u_foo * ue.bar, weight_string(:u_foo * ue.bar) from user_extra as ue where 1 != 1", + "Query": "select :u_foo * ue.bar, weight_string(:u_foo * ue.bar) from user_extra as ue", "Table": "user_extra" } ] diff --git a/go/vt/vtgate/planbuilder/testdata/from_cases.json b/go/vt/vtgate/planbuilder/testdata/from_cases.json index 8ef50577bc0..ecec6a803ea 100644 --- a/go/vt/vtgate/planbuilder/testdata/from_cases.json +++ b/go/vt/vtgate/planbuilder/testdata/from_cases.json @@ -2273,8 +2273,8 @@ "Name": "user", "Sharded": true }, - "FieldQuery": "select ref.col from ref, (select aa from `user` where 1 != 1) as `user` where 1 != 1", - "Query": "select ref.col from ref, (select aa from `user` where `user`.id = 1) as `user`", + "FieldQuery": "select ref.col from (select aa from `user` where 1 != 1) as `user`, ref where 1 != 1", + "Query": "select ref.col from (select aa from `user` where `user`.id = 1) as `user`, ref", "Table": "`user`, ref", "Values": [ "INT64(1)" @@ -2534,8 +2534,8 @@ "Name": "user", 
"Sharded": true }, - "FieldQuery": "select t.id from user_extra, (select id from `user` where 1 != 1) as t where 1 != 1", - "Query": "select t.id from user_extra, (select id from `user` where id = 5) as t where t.id = user_extra.user_id", + "FieldQuery": "select t.id from (select id from `user` where 1 != 1) as t, user_extra where 1 != 1", + "Query": "select t.id from (select id from `user` where id = 5) as t, user_extra where t.id = user_extra.user_id", "Table": "`user`, user_extra", "Values": [ "INT64(5)" @@ -4896,28 +4896,20 @@ "QueryType": "SELECT", "Original": "select a as k from (select count(*) as a from user) t", "Instructions": { - "OperatorType": "SimpleProjection", - "Columns": [ - 0 - ], + "OperatorType": "Aggregate", + "Variant": "Scalar", + "Aggregates": "sum_count_star(0) AS a", "Inputs": [ { - "OperatorType": "Aggregate", - "Variant": "Scalar", - "Aggregates": "sum_count_star(0) AS a", - "Inputs": [ - { - "OperatorType": "Route", - "Variant": "Scatter", - "Keyspace": { - "Name": "user", - "Sharded": true - }, - "FieldQuery": "select count(*) as a from `user` where 1 != 1", - "Query": "select count(*) as a from `user`", - "Table": "`user`" - } - ] + "OperatorType": "Route", + "Variant": "Scatter", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select count(*) as a from `user` where 1 != 1", + "Query": "select count(*) as a from `user`", + "Table": "`user`" } ] }, diff --git a/go/vt/vtgate/planbuilder/testdata/memory_sort_cases.json b/go/vt/vtgate/planbuilder/testdata/memory_sort_cases.json index 4dd03fa2608..944e5cd617c 100644 --- a/go/vt/vtgate/planbuilder/testdata/memory_sort_cases.json +++ b/go/vt/vtgate/planbuilder/testdata/memory_sort_cases.json @@ -501,40 +501,33 @@ "QueryType": "SELECT", "Original": "select id from (select user.id, user.col from user join user_extra) as t order by id", "Instructions": { - "OperatorType": "Sort", - "Variant": "Memory", - "OrderBy": "(0|1) ASC", - "ResultColumns": 1, + "OperatorType": 
"Join", + "Variant": "Join", + "JoinColumnIndexes": "L:0", + "TableName": "`user`_user_extra", "Inputs": [ { - "OperatorType": "Join", - "Variant": "Join", - "JoinColumnIndexes": "L:0,L:1", - "TableName": "`user`_user_extra", - "Inputs": [ - { - "OperatorType": "Route", - "Variant": "Scatter", - "Keyspace": { - "Name": "user", - "Sharded": true - }, - "FieldQuery": "select id, weight_string(id) from (select `user`.id, `user`.col from `user` where 1 != 1) as t where 1 != 1", - "Query": "select id, weight_string(id) from (select `user`.id, `user`.col from `user`) as t", - "Table": "`user`" - }, - { - "OperatorType": "Route", - "Variant": "Scatter", - "Keyspace": { - "Name": "user", - "Sharded": true - }, - "FieldQuery": "select 1 from user_extra where 1 != 1", - "Query": "select 1 from user_extra", - "Table": "user_extra" - } - ] + "OperatorType": "Route", + "Variant": "Scatter", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select id, weight_string(id) from (select `user`.id, `user`.col from `user` where 1 != 1) as t where 1 != 1", + "OrderBy": "(0|1) ASC", + "Query": "select id, weight_string(id) from (select `user`.id, `user`.col from `user`) as t order by id asc", + "Table": "`user`" + }, + { + "OperatorType": "Route", + "Variant": "Scatter", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select 1 from user_extra where 1 != 1", + "Query": "select 1 from user_extra", + "Table": "user_extra" } ] }, diff --git a/go/vt/vtgate/planbuilder/testdata/postprocess_cases.json b/go/vt/vtgate/planbuilder/testdata/postprocess_cases.json index 0991a324f6e..4908a29d134 100644 --- a/go/vt/vtgate/planbuilder/testdata/postprocess_cases.json +++ b/go/vt/vtgate/planbuilder/testdata/postprocess_cases.json @@ -1860,8 +1860,8 @@ "Name": "user", "Sharded": true }, - "FieldQuery": "select * from `user` as u, (select user_id from user_extra where 1 != 1) as eu where 1 != 1", - "Query": "select * from `user` as u, (select user_id from 
user_extra where user_id = 5) as eu where u.id = 5 and u.id = eu.user_id order by eu.user_id asc", + "FieldQuery": "select * from (select user_id from user_extra where 1 != 1) as eu, `user` as u where 1 != 1", + "Query": "select * from (select user_id from user_extra where user_id = 5) as eu, `user` as u where u.id = 5 and u.id = eu.user_id order by eu.user_id asc", "Table": "`user`, user_extra", "Values": [ "INT64(5)" @@ -3303,13 +3303,13 @@ { "OperatorType": "Aggregate", "Variant": "Ordered", - "Aggregates": "min(1) AS min(a.id)", + "Aggregates": "min(1|3) AS min(a.id)", "GroupBy": "(0|2)", "Inputs": [ { "OperatorType": "Join", "Variant": "Join", - "JoinColumnIndexes": "L:1,L:0,L:2", + "JoinColumnIndexes": "L:1,L:0,L:2,L:3", "JoinVars": { "a_tcol1": 1 }, @@ -3322,9 +3322,9 @@ "Name": "user", "Sharded": true }, - "FieldQuery": "select min(a.id), a.tcol1, weight_string(a.tcol1) from `user` as a where 1 != 1 group by a.tcol1, weight_string(a.tcol1)", + "FieldQuery": "select min(a.id), a.tcol1, weight_string(a.tcol1), weight_string(a.id) from `user` as a where 1 != 1 group by a.tcol1, weight_string(a.tcol1), weight_string(a.id)", "OrderBy": "(1|2) ASC", - "Query": "select min(a.id), a.tcol1, weight_string(a.tcol1) from `user` as a group by a.tcol1, weight_string(a.tcol1) order by a.tcol1 asc", + "Query": "select min(a.id), a.tcol1, weight_string(a.tcol1), weight_string(a.id) from `user` as a group by a.tcol1, weight_string(a.tcol1), weight_string(a.id) order by a.tcol1 asc", "Table": "`user`" }, { diff --git a/go/vt/vtgate/planbuilder/testdata/reference_cases.json b/go/vt/vtgate/planbuilder/testdata/reference_cases.json index ac5338ecd3a..289e23b0b07 100644 --- a/go/vt/vtgate/planbuilder/testdata/reference_cases.json +++ b/go/vt/vtgate/planbuilder/testdata/reference_cases.json @@ -238,8 +238,8 @@ "Name": "user", "Sharded": true }, - "FieldQuery": "select ambiguous_ref_with_source.col from ambiguous_ref_with_source, (select aa from `user` where 1 != 1) as `user` where 
1 != 1", - "Query": "select ambiguous_ref_with_source.col from ambiguous_ref_with_source, (select aa from `user` where `user`.id = 1) as `user`", + "FieldQuery": "select ambiguous_ref_with_source.col from (select aa from `user` where 1 != 1) as `user`, ambiguous_ref_with_source where 1 != 1", + "Query": "select ambiguous_ref_with_source.col from (select aa from `user` where `user`.id = 1) as `user`, ambiguous_ref_with_source", "Table": "`user`, ambiguous_ref_with_source", "Values": [ "INT64(1)" diff --git a/go/vt/vtgate/planbuilder/testdata/select_cases.json b/go/vt/vtgate/planbuilder/testdata/select_cases.json index e23bea5e0d4..08227bc5a15 100644 --- a/go/vt/vtgate/planbuilder/testdata/select_cases.json +++ b/go/vt/vtgate/planbuilder/testdata/select_cases.json @@ -7459,8 +7459,8 @@ "Name": "user", "Sharded": true }, - "FieldQuery": "select music.id from music, (select max(id) as maxt from music where 1 != 1) as other where 1 != 1", - "Query": "select music.id from music, (select max(id) as maxt from music where music.user_id = 5) as other where other.maxt = music.id", + "FieldQuery": "select music.id from (select max(id) as maxt from music where 1 != 1) as other, music where 1 != 1", + "Query": "select music.id from (select max(id) as maxt from music where music.user_id = 5) as other, music where other.maxt = music.id", "Table": "music", "Values": [ "INT64(5)" @@ -7500,8 +7500,8 @@ "Name": "user", "Sharded": true }, - "FieldQuery": "select music.id from music, (select id from music where 1 != 1) as other where 1 != 1", - "Query": "select music.id from music, (select id from music where music.user_id = 5) as other where other.id = music.id", + "FieldQuery": "select music.id from (select id from music where 1 != 1) as other, music where 1 != 1", + "Query": "select music.id from (select id from music where music.user_id = 5) as other, music where other.id = music.id", "Table": "music", "Values": [ "INT64(5)" @@ -7541,8 +7541,8 @@ "Name": "user", "Sharded": true }, - 
"FieldQuery": "select music.id from music, (select id from music where 1 != 1) as other where 1 != 1", - "Query": "select music.id from music, (select id from music where music.user_id in ::__vals) as other where other.id = music.id", + "FieldQuery": "select music.id from (select id from music where 1 != 1) as other, music where 1 != 1", + "Query": "select music.id from (select id from music where music.user_id in ::__vals) as other, music where other.id = music.id", "Table": "music", "Values": [ "(INT64(5), INT64(6), INT64(7))" @@ -7620,24 +7620,24 @@ "Instructions": { "OperatorType": "Join", "Variant": "Join", - "JoinColumnIndexes": "R:0,L:1", + "JoinColumnIndexes": "R:0,L:0", "JoinVars": { - "t_id": 0 + "t_id": 1 }, "TableName": "user_extra_`user`", "Inputs": [ { "OperatorType": "SimpleProjection", "Columns": [ - 0, - 1 + 1, + 0 ], "Inputs": [ { "OperatorType": "Aggregate", "Variant": "Ordered", "Aggregates": "sum_count_star(1) AS b", - "GroupBy": "(0|3), (2|4)", + "GroupBy": "(2|3), (0|4)", "Inputs": [ { "OperatorType": "Route", @@ -7646,9 +7646,9 @@ "Name": "user", "Sharded": true }, - "FieldQuery": "select id, count(*) as b, req, weight_string(id), weight_string(req) from user_extra where 1 != 1 group by id, weight_string(id), req, weight_string(req)", - "OrderBy": "(0|3) ASC, (2|4) ASC", - "Query": "select id, count(*) as b, req, weight_string(id), weight_string(req) from user_extra group by id, weight_string(id), req, weight_string(req) order by id asc, req asc", + "FieldQuery": "select id, count(*) as b, req, weight_string(req), weight_string(id) from user_extra where 1 != 1 group by req, id, weight_string(req), weight_string(id)", + "OrderBy": "(2|3) ASC, (0|4) ASC", + "Query": "select id, count(*) as b, req, weight_string(req), weight_string(id) from user_extra group by req, id, weight_string(req), weight_string(id) order by req asc, id asc", "Table": "user_extra" } ] diff --git a/go/vt/vtgate/planbuilder/testdata/tpcc_cases.json 
b/go/vt/vtgate/planbuilder/testdata/tpcc_cases.json index a36319cb322..51be1f7522c 100644 --- a/go/vt/vtgate/planbuilder/testdata/tpcc_cases.json +++ b/go/vt/vtgate/planbuilder/testdata/tpcc_cases.json @@ -1633,8 +1633,8 @@ "Name": "main", "Sharded": true }, - "FieldQuery": "select o.o_id, o.o_d_id from orders1 as o, (select o_c_id, o_w_id, o_d_id, count(distinct o_w_id), o_id from orders1 where 1 != 1 group by o_c_id, o_d_id, o_w_id) as t where 1 != 1", - "Query": "select o.o_id, o.o_d_id from orders1 as o, (select o_c_id, o_w_id, o_d_id, count(distinct o_w_id), o_id from orders1 where o_w_id = 1 and o_id > 2100 and o_id < 11153 group by o_c_id, o_d_id, o_w_id having count(distinct o_id) > 1 limit 1) as t where t.o_w_id = o.o_w_id and t.o_d_id = o.o_d_id and t.o_c_id = o.o_c_id limit 1", + "FieldQuery": "select o.o_id, o.o_d_id from (select o_c_id, o_w_id, o_d_id, count(distinct o_w_id), o_id from orders1 where 1 != 1 group by o_c_id, o_d_id, o_w_id) as t, orders1 as o where 1 != 1", + "Query": "select o.o_id, o.o_d_id from (select o_c_id, o_w_id, o_d_id, count(distinct o_w_id), o_id from orders1 where o_w_id = 1 and o_id > 2100 and o_id < 11153 group by o_c_id, o_d_id, o_w_id having count(distinct o_id) > 1 limit 1) as t, orders1 as o where t.o_w_id = o.o_w_id and t.o_d_id = o.o_d_id and t.o_c_id = o.o_c_id limit 1", "Table": "orders1", "Values": [ "INT64(1)" diff --git a/go/vt/vtgate/planbuilder/testdata/tpch_cases.json b/go/vt/vtgate/planbuilder/testdata/tpch_cases.json index 33152cfc5a5..eb24df0b0b8 100644 --- a/go/vt/vtgate/planbuilder/testdata/tpch_cases.json +++ b/go/vt/vtgate/planbuilder/testdata/tpch_cases.json @@ -553,131 +553,172 @@ "OperatorType": "Aggregate", "Variant": "Ordered", "Aggregates": "sum(3) AS revenue", - "GroupBy": "(0|6), (1|5), (2|4)", + "GroupBy": "(0|4), (1|5), (2|6)", "ResultColumns": 4, "Inputs": [ { "OperatorType": "Projection", "Expressions": [ - "[COLUMN 4] as supp_nation", - "[COLUMN 5] as cust_nation", - "[COLUMN 6] as l_year", 
- "(((([COLUMN 10] * COALESCE([COLUMN 11], INT64(1))) * COALESCE([COLUMN 12], INT64(1))) * COALESCE([COLUMN 13], INT64(1))) * COALESCE([COLUMN 14], INT64(1))) * COALESCE([COLUMN 15], INT64(1)) as revenue", - "[COLUMN 9]", - "[COLUMN 8]", - "[COLUMN 7]" + "[COLUMN 2] as supp_nation", + "[COLUMN 3] as cust_nation", + "[COLUMN 4] as l_year", + "[COLUMN 0] * [COLUMN 1] as revenue", + "[COLUMN 5] as weight_string(supp_nation)", + "[COLUMN 6] as weight_string(cust_nation)", + "[COLUMN 7] as weight_string(l_year)" ], "Inputs": [ { - "OperatorType": "Sort", - "Variant": "Memory", - "OrderBy": "(0|16) ASC, (1|17) ASC, (2|18) ASC", + "OperatorType": "Join", + "Variant": "Join", + "JoinColumnIndexes": "L:0,R:0,L:1,R:1,L:2,L:5,R:2,L:6", + "JoinVars": { + "n1_n_name": 4, + "o_custkey": 3 + }, + "TableName": "lineitem_orders_supplier_nation_customer_nation", "Inputs": [ { - "OperatorType": "Join", - "Variant": "Join", - "JoinColumnIndexes": "L:2,R:0,L:3,L:4,L:8,R:1,L:9,L:13,R:2,L:14,L:15,L:16,L:17,L:18,R:3,R:4,L:19,R:5,L:20", - "JoinVars": { - "n1_n_name": 1, - "o_custkey": 0 - }, - "TableName": "lineitem_orders_supplier_nation_customer_nation", + "OperatorType": "Projection", + "Expressions": [ + "[COLUMN 0] * [COLUMN 1] as revenue", + "[COLUMN 2] as supp_nation", + "[COLUMN 3] as l_year", + "[COLUMN 4] as orders.o_custkey", + "[COLUMN 5] as n1.n_name", + "[COLUMN 6] as weight_string(supp_nation)", + "[COLUMN 7] as weight_string(l_year)" + ], "Inputs": [ { "OperatorType": "Join", "Variant": "Join", - "JoinColumnIndexes": "L:0,R:0,R:1,L:2,L:3,L:5,R:2,R:3,R:4,L:6,L:8,R:5,R:6,R:7,L:9,L:10,L:11,R:8,R:9,R:10,L:12", + "JoinColumnIndexes": "L:0,R:0,R:1,L:1,L:2,L:3,R:2,L:5", "JoinVars": { - "l_suppkey": 1 + "l_suppkey": 4 }, "TableName": "lineitem_orders_supplier_nation", "Inputs": [ { - "OperatorType": "Join", - "Variant": "Join", - "JoinColumnIndexes": "R:0,L:0,L:2,L:3,L:0,R:0,L:2,L:6,R:2,L:7,L:4,R:1,L:8", - "JoinVars": { - "l_orderkey": 1 - }, - "TableName": "lineitem_orders", + 
"OperatorType": "Projection", + "Expressions": [ + "[COLUMN 0] * [COLUMN 1] as revenue", + "[COLUMN 2] as l_year", + "[COLUMN 3] as orders.o_custkey", + "[COLUMN 4] as n1.n_name", + "[COLUMN 5] as lineitem.l_suppkey", + "[COLUMN 6] as weight_string(l_year)" + ], "Inputs": [ { - "OperatorType": "Route", - "Variant": "Scatter", - "Keyspace": { - "Name": "main", - "Sharded": true - }, - "FieldQuery": "select l_suppkey, l_orderkey, extract(year from l_shipdate) as l_year, l_extendedprice * (1 - l_discount) as volume, sum(volume) as revenue, weight_string(l_orderkey), weight_string(l_suppkey), weight_string(extract(year from l_shipdate)), weight_string(extract(year from l_shipdate)) from lineitem where 1 != 1 group by l_orderkey, weight_string(l_orderkey), l_suppkey, weight_string(l_suppkey), l_year, weight_string(l_year)", - "Query": "select l_suppkey, l_orderkey, extract(year from l_shipdate) as l_year, l_extendedprice * (1 - l_discount) as volume, sum(volume) as revenue, weight_string(l_orderkey), weight_string(l_suppkey), weight_string(extract(year from l_shipdate)), weight_string(extract(year from l_shipdate)) from lineitem where l_shipdate between date('1995-01-01') and date('1996-12-31') group by l_orderkey, weight_string(l_orderkey), l_suppkey, weight_string(l_suppkey), l_year, weight_string(l_year)", - "Table": "lineitem" - }, - { - "OperatorType": "Route", - "Variant": "EqualUnique", - "Keyspace": { - "Name": "main", - "Sharded": true + "OperatorType": "Join", + "Variant": "Join", + "JoinColumnIndexes": "L:0,R:0,L:1,L:2,L:3,L:4,L:6", + "JoinVars": { + "l_orderkey": 5 }, - "FieldQuery": "select o_custkey, count(*), weight_string(o_custkey) from orders where 1 != 1 group by o_custkey, weight_string(o_custkey)", - "Query": "select o_custkey, count(*), weight_string(o_custkey) from orders where o_orderkey = :l_orderkey group by o_custkey, weight_string(o_custkey)", - "Table": "orders", - "Values": [ - ":l_orderkey" - ], - "Vindex": "hash" + "TableName": 
"lineitem_orders", + "Inputs": [ + { + "OperatorType": "Route", + "Variant": "Scatter", + "Keyspace": { + "Name": "main", + "Sharded": true + }, + "FieldQuery": "select sum(volume) as revenue, l_year, shipping.`orders.o_custkey`, shipping.`n1.n_name`, shipping.`lineitem.l_suppkey`, shipping.`lineitem.l_orderkey`, weight_string(l_year), supp_nation, weight_string(supp_nation), cust_nation, weight_string(cust_nation) from (select extract(year from l_shipdate) as l_year, l_extendedprice * (1 - l_discount) as volume, orders.o_custkey as `orders.o_custkey`, lineitem.l_suppkey as `lineitem.l_suppkey`, lineitem.l_orderkey as `lineitem.l_orderkey` from lineitem where 1 != 1) as shipping where 1 != 1 group by l_year, shipping.`orders.o_custkey`, shipping.`n1.n_name`, shipping.`lineitem.l_suppkey`, shipping.`lineitem.l_orderkey`, weight_string(l_year)", + "OrderBy": "(7|8) ASC, (9|10) ASC, (1|6) ASC", + "Query": "select sum(volume) as revenue, l_year, shipping.`orders.o_custkey`, shipping.`n1.n_name`, shipping.`lineitem.l_suppkey`, shipping.`lineitem.l_orderkey`, weight_string(l_year), supp_nation, weight_string(supp_nation), cust_nation, weight_string(cust_nation) from (select extract(year from l_shipdate) as l_year, l_extendedprice * (1 - l_discount) as volume, orders.o_custkey as `orders.o_custkey`, lineitem.l_suppkey as `lineitem.l_suppkey`, lineitem.l_orderkey as `lineitem.l_orderkey` from lineitem where l_shipdate between date('1995-01-01') and date('1996-12-31')) as shipping group by l_year, shipping.`orders.o_custkey`, shipping.`n1.n_name`, shipping.`lineitem.l_suppkey`, shipping.`lineitem.l_orderkey`, weight_string(l_year) order by supp_nation asc, cust_nation asc, l_year asc", + "Table": "lineitem" + }, + { + "OperatorType": "Route", + "Variant": "EqualUnique", + "Keyspace": { + "Name": "main", + "Sharded": true + }, + "FieldQuery": "select count(*) from orders where 1 != 1 group by .0", + "Query": "select count(*) from orders where o_orderkey = :l_orderkey group 
by .0", + "Table": "orders", + "Values": [ + ":l_orderkey" + ], + "Vindex": "hash" + } + ] } ] }, { - "OperatorType": "Join", - "Variant": "Join", - "JoinColumnIndexes": "R:0,R:1,R:0,R:0,R:1,R:3,R:3,R:4,L:1,R:2,R:5", - "JoinVars": { - "s_nationkey": 0 - }, - "TableName": "supplier_nation", + "OperatorType": "Projection", + "Expressions": [ + "[COLUMN 0] * [COLUMN 1] as count(*)", + "[COLUMN 2] as supp_nation", + "[COLUMN 3] as weight_string(supp_nation)" + ], "Inputs": [ { - "OperatorType": "Route", - "Variant": "EqualUnique", - "Keyspace": { - "Name": "main", - "Sharded": true - }, - "FieldQuery": "select s_nationkey, count(*), weight_string(s_nationkey) from supplier where 1 != 1 group by s_nationkey, weight_string(s_nationkey)", - "Query": "select s_nationkey, count(*), weight_string(s_nationkey) from supplier where s_suppkey = :l_suppkey group by s_nationkey, weight_string(s_nationkey)", - "Table": "supplier", - "Values": [ - ":l_suppkey" - ], - "Vindex": "hash" - }, - { - "OperatorType": "Route", - "Variant": "EqualUnique", - "Keyspace": { - "Name": "main", - "Sharded": true + "OperatorType": "Join", + "Variant": "Join", + "JoinColumnIndexes": "L:0,R:0,R:1,R:2", + "JoinVars": { + "s_nationkey": 1 }, - "FieldQuery": "select n1.n_name, n1.n_name as supp_nation, count(*), weight_string(n1.n_name), weight_string(n1.n_name), weight_string(n1.n_name) from nation as n1 where 1 != 1 group by n1.n_name, weight_string(n1.n_name), supp_nation, weight_string(supp_nation)", - "Query": "select n1.n_name, n1.n_name as supp_nation, count(*), weight_string(n1.n_name), weight_string(n1.n_name), weight_string(n1.n_name) from nation as n1 where n1.n_nationkey = :s_nationkey group by n1.n_name, weight_string(n1.n_name), supp_nation, weight_string(supp_nation)", - "Table": "nation", - "Values": [ - ":s_nationkey" - ], - "Vindex": "hash" + "TableName": "supplier_nation", + "Inputs": [ + { + "OperatorType": "Route", + "Variant": "EqualUnique", + "Keyspace": { + "Name": "main", + 
"Sharded": true + }, + "FieldQuery": "select count(*), shipping.`supplier.s_nationkey` from (select supplier.s_nationkey as `supplier.s_nationkey` from supplier where 1 != 1) as shipping where 1 != 1 group by shipping.`supplier.s_nationkey`", + "Query": "select count(*), shipping.`supplier.s_nationkey` from (select supplier.s_nationkey as `supplier.s_nationkey` from supplier where s_suppkey = :l_suppkey) as shipping group by shipping.`supplier.s_nationkey`", + "Table": "supplier", + "Values": [ + ":l_suppkey" + ], + "Vindex": "hash" + }, + { + "OperatorType": "Route", + "Variant": "EqualUnique", + "Keyspace": { + "Name": "main", + "Sharded": true + }, + "FieldQuery": "select count(*), supp_nation, weight_string(supp_nation) from (select n1.n_name as supp_nation from nation as n1 where 1 != 1) as shipping where 1 != 1 group by supp_nation, weight_string(supp_nation)", + "Query": "select count(*), supp_nation, weight_string(supp_nation) from (select n1.n_name as supp_nation from nation as n1 where n1.n_nationkey = :s_nationkey) as shipping group by supp_nation, weight_string(supp_nation)", + "Table": "nation", + "Values": [ + ":s_nationkey" + ], + "Vindex": "hash" + } + ] } ] } ] - }, + } + ] + }, + { + "OperatorType": "Projection", + "Expressions": [ + "[COLUMN 0] * [COLUMN 1] as count(*)", + "[COLUMN 2] as cust_nation", + "[COLUMN 3] as weight_string(cust_nation)" + ], + "Inputs": [ { "OperatorType": "Join", "Variant": "Join", - "JoinColumnIndexes": "R:0,R:0,R:2,L:1,R:1,R:3", + "JoinColumnIndexes": "L:0,R:0,R:1,R:2", "JoinVars": { - "c_nationkey": 0 + "c_nationkey": 1 }, "TableName": "customer_nation", "Inputs": [ @@ -688,8 +729,8 @@ "Name": "main", "Sharded": true }, - "FieldQuery": "select c_nationkey, count(*), weight_string(c_nationkey) from customer where 1 != 1 group by c_nationkey, weight_string(c_nationkey)", - "Query": "select c_nationkey, count(*), weight_string(c_nationkey) from customer where c_custkey = :o_custkey group by c_nationkey, 
weight_string(c_nationkey)", + "FieldQuery": "select count(*), shipping.`customer.c_nationkey` from (select customer.c_nationkey as `customer.c_nationkey` from customer where 1 != 1) as shipping where 1 != 1 group by shipping.`customer.c_nationkey`", + "Query": "select count(*), shipping.`customer.c_nationkey` from (select customer.c_nationkey as `customer.c_nationkey` from customer where c_custkey = :o_custkey) as shipping group by shipping.`customer.c_nationkey`", "Table": "customer", "Values": [ ":o_custkey" @@ -703,8 +744,8 @@ "Name": "main", "Sharded": true }, - "FieldQuery": "select n2.n_name as cust_nation, count(*), weight_string(n2.n_name), weight_string(n2.n_name) from nation as n2 where 1 != 1 group by cust_nation, weight_string(cust_nation)", - "Query": "select n2.n_name as cust_nation, count(*), weight_string(n2.n_name), weight_string(n2.n_name) from nation as n2 where (:n1_n_name = 'FRANCE' and n2.n_name = 'GERMANY' or :n1_n_name = 'GERMANY' and n2.n_name = 'FRANCE') and n2.n_nationkey = :c_nationkey group by cust_nation, weight_string(cust_nation)", + "FieldQuery": "select count(*), cust_nation, weight_string(cust_nation) from (select n2.n_name as cust_nation from nation as n2 where 1 != 1) as shipping where 1 != 1 group by cust_nation, weight_string(cust_nation)", + "Query": "select count(*), cust_nation, weight_string(cust_nation) from (select n2.n_name as cust_nation from nation as n2 where (:n1_n_name = 'FRANCE' and n2.n_name = 'GERMANY' or :n1_n_name = 'GERMANY' and n2.n_name = 'FRANCE') and n2.n_nationkey = :c_nationkey) as shipping group by cust_nation, weight_string(cust_nation)", "Table": "nation", "Values": [ ":c_nationkey" diff --git a/go/vt/vtgate/planbuilder/testdata/union_cases.json b/go/vt/vtgate/planbuilder/testdata/union_cases.json index ef76413b5c7..ed836bf207b 100644 --- a/go/vt/vtgate/planbuilder/testdata/union_cases.json +++ b/go/vt/vtgate/planbuilder/testdata/union_cases.json @@ -2361,7 +2361,7 @@ "Sharded": false }, 
"FieldQuery": "select kcu.COLUMN_NAME from (select kcu.COLUMN_NAME from information_schema.key_column_usage as kcu where 1 != 1 union select kcu.COLUMN_NAME from information_schema.key_column_usage as kcu where 1 != 1) as kcu where 1 != 1", - "Query": "select kcu.COLUMN_NAME from (select kcu.COLUMN_NAME from information_schema.key_column_usage as kcu where kcu.table_schema = :__vtschemaname /* VARCHAR */ and kcu.table_name = :kcu_table_name /* VARCHAR */ and kcu.COLUMN_NAME = 'primary' union select kcu.COLUMN_NAME from information_schema.key_column_usage as kcu where kcu.table_schema = :__vtschemaname /* VARCHAR */ and kcu.table_name = :kcu_table_name1 /* VARCHAR */ and kcu.COLUMN_NAME = 'primary') as kcu", + "Query": "select kcu.COLUMN_NAME from (select kcu.COLUMN_NAME from information_schema.key_column_usage as kcu where kcu.table_schema = :__vtschemaname /* VARCHAR */ and kcu.table_name = :kcu_table_name /* VARCHAR */ and COLUMN_NAME = 'primary' union select kcu.COLUMN_NAME from information_schema.key_column_usage as kcu where kcu.table_schema = :__vtschemaname /* VARCHAR */ and kcu.table_name = :kcu_table_name1 /* VARCHAR */ and COLUMN_NAME = 'primary') as kcu", "SysTableTableName": "[kcu_table_name1:VARCHAR(\"music\"), kcu_table_name:VARCHAR(\"user_extra\")]", "SysTableTableSchema": "[VARCHAR(\"user\"), VARCHAR(\"user\")]", "Table": "information_schema.key_column_usage" @@ -2411,7 +2411,7 @@ "Sharded": false }, "FieldQuery": "select kcu.CONSTRAINT_CATALOG, kcu.CONSTRAINT_SCHEMA, kcu.CONSTRAINT_NAME, kcu.TABLE_CATALOG, kcu.TABLE_SCHEMA, kcu.TABLE_NAME, kcu.COLUMN_NAME, kcu.ORDINAL_POSITION, kcu.POSITION_IN_UNIQUE_CONSTRAINT, kcu.REFERENCED_TABLE_SCHEMA, kcu.REFERENCED_TABLE_NAME, kcu.REFERENCED_COLUMN_NAME from (select kcu.CONSTRAINT_CATALOG, kcu.CONSTRAINT_SCHEMA, kcu.CONSTRAINT_NAME, kcu.TABLE_CATALOG, kcu.TABLE_SCHEMA, kcu.TABLE_NAME, kcu.COLUMN_NAME, kcu.ORDINAL_POSITION, kcu.POSITION_IN_UNIQUE_CONSTRAINT, kcu.REFERENCED_TABLE_SCHEMA, 
kcu.REFERENCED_TABLE_NAME, kcu.REFERENCED_COLUMN_NAME from information_schema.key_column_usage as kcu where 1 != 1 union select kcu.CONSTRAINT_CATALOG, kcu.CONSTRAINT_SCHEMA, kcu.CONSTRAINT_NAME, kcu.TABLE_CATALOG, kcu.TABLE_SCHEMA, kcu.TABLE_NAME, kcu.COLUMN_NAME, kcu.ORDINAL_POSITION, kcu.POSITION_IN_UNIQUE_CONSTRAINT, kcu.REFERENCED_TABLE_SCHEMA, kcu.REFERENCED_TABLE_NAME, kcu.REFERENCED_COLUMN_NAME from information_schema.key_column_usage as kcu where 1 != 1) as kcu where 1 != 1", - "Query": "select kcu.CONSTRAINT_CATALOG, kcu.CONSTRAINT_SCHEMA, kcu.CONSTRAINT_NAME, kcu.TABLE_CATALOG, kcu.TABLE_SCHEMA, kcu.TABLE_NAME, kcu.COLUMN_NAME, kcu.ORDINAL_POSITION, kcu.POSITION_IN_UNIQUE_CONSTRAINT, kcu.REFERENCED_TABLE_SCHEMA, kcu.REFERENCED_TABLE_NAME, kcu.REFERENCED_COLUMN_NAME from (select kcu.CONSTRAINT_CATALOG, kcu.CONSTRAINT_SCHEMA, kcu.CONSTRAINT_NAME, kcu.TABLE_CATALOG, kcu.TABLE_SCHEMA, kcu.TABLE_NAME, kcu.COLUMN_NAME, kcu.ORDINAL_POSITION, kcu.POSITION_IN_UNIQUE_CONSTRAINT, kcu.REFERENCED_TABLE_SCHEMA, kcu.REFERENCED_TABLE_NAME, kcu.REFERENCED_COLUMN_NAME from information_schema.key_column_usage as kcu where kcu.table_schema = :__vtschemaname /* VARCHAR */ and kcu.table_name = :kcu_table_name /* VARCHAR */ and kcu.CONSTRAINT_NAME = 'primary' union select kcu.CONSTRAINT_CATALOG, kcu.CONSTRAINT_SCHEMA, kcu.CONSTRAINT_NAME, kcu.TABLE_CATALOG, kcu.TABLE_SCHEMA, kcu.TABLE_NAME, kcu.COLUMN_NAME, kcu.ORDINAL_POSITION, kcu.POSITION_IN_UNIQUE_CONSTRAINT, kcu.REFERENCED_TABLE_SCHEMA, kcu.REFERENCED_TABLE_NAME, kcu.REFERENCED_COLUMN_NAME from information_schema.key_column_usage as kcu where kcu.table_schema = :__vtschemaname /* VARCHAR */ and kcu.table_name = :kcu_table_name1 /* VARCHAR */ and kcu.CONSTRAINT_NAME = 'primary') as kcu", + "Query": "select kcu.CONSTRAINT_CATALOG, kcu.CONSTRAINT_SCHEMA, kcu.CONSTRAINT_NAME, kcu.TABLE_CATALOG, kcu.TABLE_SCHEMA, kcu.TABLE_NAME, kcu.COLUMN_NAME, kcu.ORDINAL_POSITION, kcu.POSITION_IN_UNIQUE_CONSTRAINT, 
kcu.REFERENCED_TABLE_SCHEMA, kcu.REFERENCED_TABLE_NAME, kcu.REFERENCED_COLUMN_NAME from (select kcu.CONSTRAINT_CATALOG, kcu.CONSTRAINT_SCHEMA, kcu.CONSTRAINT_NAME, kcu.TABLE_CATALOG, kcu.TABLE_SCHEMA, kcu.TABLE_NAME, kcu.COLUMN_NAME, kcu.ORDINAL_POSITION, kcu.POSITION_IN_UNIQUE_CONSTRAINT, kcu.REFERENCED_TABLE_SCHEMA, kcu.REFERENCED_TABLE_NAME, kcu.REFERENCED_COLUMN_NAME from information_schema.key_column_usage as kcu where kcu.table_schema = :__vtschemaname /* VARCHAR */ and kcu.table_name = :kcu_table_name /* VARCHAR */ and CONSTRAINT_NAME = 'primary' union select kcu.CONSTRAINT_CATALOG, kcu.CONSTRAINT_SCHEMA, kcu.CONSTRAINT_NAME, kcu.TABLE_CATALOG, kcu.TABLE_SCHEMA, kcu.TABLE_NAME, kcu.COLUMN_NAME, kcu.ORDINAL_POSITION, kcu.POSITION_IN_UNIQUE_CONSTRAINT, kcu.REFERENCED_TABLE_SCHEMA, kcu.REFERENCED_TABLE_NAME, kcu.REFERENCED_COLUMN_NAME from information_schema.key_column_usage as kcu where kcu.table_schema = :__vtschemaname /* VARCHAR */ and kcu.table_name = :kcu_table_name1 /* VARCHAR */ and CONSTRAINT_NAME = 'primary') as kcu", "SysTableTableName": "[kcu_table_name1:VARCHAR(\"music\"), kcu_table_name:VARCHAR(\"user_extra\")]", "SysTableTableSchema": "[VARCHAR(\"user\"), VARCHAR(\"user\")]", "Table": "information_schema.key_column_usage" diff --git a/go/vt/vtgate/planbuilder/testdata/unsupported_cases.json b/go/vt/vtgate/planbuilder/testdata/unsupported_cases.json index 330de6c7fd3..cc020dac2af 100644 --- a/go/vt/vtgate/planbuilder/testdata/unsupported_cases.json +++ b/go/vt/vtgate/planbuilder/testdata/unsupported_cases.json @@ -459,5 +459,11 @@ "comment": "extremum on input from both sides", "query": "insert into music(user_id, id) select foo, bar from music on duplicate key update id = id+1", "plan": "VT12001: unsupported: DML cannot update vindex column" + }, + { + "comment": "aggregation on top of aggregation not supported", + "query": "select distinct count(*) from user, (select distinct count(*) from user) X", + "v3-plan": "VT12001: unsupported: 
cross-shard query with aggregates", + "gen4-plan": "VT12001: unsupported: aggregation on top of aggregation not supported" } ] diff --git a/go/vt/vttablet/tabletmanager/vdiff/table_plan.go b/go/vt/vttablet/tabletmanager/vdiff/table_plan.go index d5dc11f7691..e669dbd9a33 100644 --- a/go/vt/vttablet/tabletmanager/vdiff/table_plan.go +++ b/go/vt/vttablet/tabletmanager/vdiff/table_plan.go @@ -109,11 +109,10 @@ func (td *tableDiffer) buildTablePlan(dbClient binlogplayer.DBClient, dbName str // this won't work: "select count(*) from (select id from t limit 1)" // since vreplication only handles simple tables (no joins/derived tables) this is fine for now // but will need to be revisited when we add such support to vreplication - aggregateFuncType := "sum" - aggregates = append(aggregates, &engine.AggregateParams{ - Opcode: opcode.SupportedAggregates[aggregateFuncType], - Col: len(sourceSelect.SelectExprs) - 1, - }) + aggregates = append(aggregates, engine.NewAggregateParam( + /*opcode*/ opcode.AggregateSum, + /*offset*/ len(sourceSelect.SelectExprs)-1, + /*alias*/ "")) } } default: diff --git a/go/vt/vttablet/tabletmanager/vdiff/workflow_differ_test.go b/go/vt/vttablet/tabletmanager/vdiff/workflow_differ_test.go index 31a014a28f4..daf7360b895 100644 --- a/go/vt/vttablet/tabletmanager/vdiff/workflow_differ_test.go +++ b/go/vt/vttablet/tabletmanager/vdiff/workflow_differ_test.go @@ -437,13 +437,10 @@ func TestBuildPlanSuccess(t *testing.T) { Expr: &sqlparser.ColName{Name: sqlparser.NewIdentifierCI("c1")}, Direction: sqlparser.AscOrder, }}, - aggregates: []*engine.AggregateParams{{ - Opcode: opcode.AggregateSum, - Col: 2, - }, { - Opcode: opcode.AggregateSum, - Col: 3, - }}, + aggregates: []*engine.AggregateParams{ + engine.NewAggregateParam(opcode.AggregateSum, 2, ""), + engine.NewAggregateParam(opcode.AggregateSum, 3, ""), + }, }, }, { // date conversion on import. 
diff --git a/go/vt/wrangler/vdiff.go b/go/vt/wrangler/vdiff.go index 098f4f18ff1..761412f3726 100644 --- a/go/vt/wrangler/vdiff.go +++ b/go/vt/wrangler/vdiff.go @@ -691,11 +691,10 @@ func (df *vdiff) buildTablePlan(table *tabletmanagerdatapb.TableDefinition, quer // this won't work: "select count(*) from (select id from t limit 1)" // since vreplication only handles simple tables (no joins/derived tables) this is fine for now // but will need to be revisited when we add such support to vreplication - aggregateFuncType := "sum" - aggregates = append(aggregates, &engine.AggregateParams{ - Opcode: opcode.SupportedAggregates[aggregateFuncType], - Col: len(sourceSelect.SelectExprs) - 1, - }) + aggregates = append(aggregates, engine.NewAggregateParam( + /*opcode*/ opcode.AggregateSum, + /*offset*/ len(sourceSelect.SelectExprs)-1, + /*alias*/ "")) } } default: diff --git a/go/vt/wrangler/vdiff_test.go b/go/vt/wrangler/vdiff_test.go index ef2857e55cf..500cc9a4080 100644 --- a/go/vt/wrangler/vdiff_test.go +++ b/go/vt/wrangler/vdiff_test.go @@ -400,13 +400,10 @@ func TestVDiffPlanSuccess(t *testing.T) { pkCols: []int{0}, selectPks: []int{0}, sourcePrimitive: &engine.OrderedAggregate{ - Aggregates: []*engine.AggregateParams{{ - Opcode: opcode.AggregateSum, - Col: 2, - }, { - Opcode: opcode.AggregateSum, - Col: 3, - }}, + Aggregates: []*engine.AggregateParams{ + engine.NewAggregateParam(opcode.AggregateSum, 2, ""), + engine.NewAggregateParam(opcode.AggregateSum, 3, ""), + }, GroupByKeys: []*engine.GroupByParams{{KeyCol: 0, WeightStringCol: -1}}, Input: newMergeSorter(nil, []compareColInfo{{0, collations.Collation(nil), true}}), },