diff --git a/bigtable/bttest/inmem.go b/bigtable/bttest/inmem.go index bca700b92094..96edb64ebacf 100644 --- a/bigtable/bttest/inmem.go +++ b/bigtable/bttest/inmem.go @@ -334,20 +334,43 @@ func (s *server) ModifyColumnFamilies(ctx context.Context, req *btapb.ModifyColu }) } else if modify := mod.GetUpdate(); modify != nil { newcf := newColumnFamily(req.Name+"/columnFamilies/"+mod.Id, 0, modify) + updateMask := mod.GetUpdateMask() + paths := updateMask.GetPaths() + cf, ok := tbl.families[mod.Id] if !ok { return nil, fmt.Errorf("no such family %q", mod.Id) } - if cf.valueType != nil { - _, isOldAggregateType := cf.valueType.Kind.(*btapb.Type_AggregateType) - if isOldAggregateType && cf.valueType != newcf.valueType { - return nil, status.Errorf(codes.InvalidArgument, "Immutable fields 'value_type.aggregate_type' cannot be updated") - } + + var utr *btapb.ColumnFamily + if len(paths) > 0 && + !updateMask.IsValid(utr) { + return nil, status.Errorf(codes.InvalidArgument, + "incorrect path in UpdateMask; got: %v", + updateMask) } - // assume that we ALWAYS want to replace by the new setting - // we may need partial update through - tbl.families[mod.Id] = newcf + if len(paths) == 0 { + // Assume that the update is only modifying the GC policy. + cf.gcRule = newcf.gcRule + } + + for _, path := range paths { + switch path { + case "value_type": + if cf.valueType != nil && + cf.valueType.GetAggregateType() != nil { + // The existing column family is an aggregate type, and the update + // is attempting to modify its immutable type. + return nil, status.Errorf(codes.InvalidArgument, + "Immutable fields 'value_type.aggregate_type' cannot be updated") + } + + cf.valueType = newcf.valueType + case "gc_rule": + cf.gcRule = newcf.gcRule + } + } } } diff --git a/bigtable/bttest/inmem_test.go b/bigtable/bttest/inmem_test.go index 04b15ef87666..b807e940174a 100644 --- a/bigtable/bttest/inmem_test.go +++ b/bigtable/bttest/inmem_test.go @@ -42,6 +42,7 @@ import ( "google.golang.org/protobuf/proto" "google.golang.org/protobuf/testing/protocmp" "google.golang.org/protobuf/types/known/durationpb" + "google.golang.org/protobuf/types/known/fieldmaskpb" "google.golang.org/protobuf/types/known/wrapperspb" ) @@ -2711,3 +2712,313 @@ func TestAuthorizedViewApis(t *testing.T) { t.Fatalf("Failed to error %s", err) } } + +func TestUpdateGCPolicyOnAggregateColumnFamily(t *testing.T) { + ctx := context.Background() + + s := &server{ + tables: make(map[string]*table), + } + + tblInfo, err := s.CreateTable(ctx, &btapb.CreateTableRequest{ + Parent: "cluster", + TableId: "t", + Table: &btapb.Table{ + ColumnFamilies: map[string]*btapb.ColumnFamily{ + "sum": { + ValueType: &btapb.Type{ + Kind: &btapb.Type_AggregateType{ + AggregateType: &btapb.Type_Aggregate{ + InputType: &btapb.Type{ + Kind: &btapb.Type_Int64Type{}, + }, + Aggregator: &btapb.Type_Aggregate_Sum_{ + Sum: &btapb.Type_Aggregate_Sum{}, + }, + }, + }, + }, + }, + }, + }, + }) + if err != nil { + t.Fatal(err) + } + + if tblInfo.ColumnFamilies["sum"]. + GetValueType(). + GetAggregateType(). + GetSum() == nil { + t.Fatal("Unexpected aggregate column family type at start of test") + } + + if tblInfo.ColumnFamilies["sum"]. + GetGcRule(). + GetMaxNumVersions() == 1 { + t.Fatal("Unexpected GC policy state at start of test") + } + + tblInfo, err = s.ModifyColumnFamilies(ctx, &btapb.ModifyColumnFamiliesRequest{ + Name: tblInfo.Name, + Modifications: []*btapb.ModifyColumnFamiliesRequest_Modification{ + { + Id: "sum", + // UpdateMask intentionally left empty, which the server will + // implicitly interpret as a gc_rule update. + Mod: &btapb.ModifyColumnFamiliesRequest_Modification_Update{ + Update: &btapb.ColumnFamily{ + GcRule: &btapb.GcRule{ + Rule: &btapb.GcRule_MaxNumVersions{ + MaxNumVersions: 1, + }, + }, + // HACK: Intentionally include an invalid type + // update, which should be ignored since it isn't + // present in the UpdateMask. + ValueType: &btapb.Type{ + Kind: &btapb.Type_AggregateType{ + AggregateType: &btapb.Type_Aggregate{ + InputType: &btapb.Type{ + Kind: &btapb.Type_Int64Type{}, + }, + Aggregator: &btapb.Type_Aggregate_Max_{ + Max: &btapb.Type_Aggregate_Max{}, + }, + }, + }, + }, + }, + }, + }, + }, + }) + if err != nil { + t.Fatal(err) + } + + if tblInfo.ColumnFamilies["sum"]. + GetValueType(). + GetAggregateType(). + GetSum() == nil { + t.Fatal("Aggregate type was updated when it should not have been") + } + + if tblInfo.ColumnFamilies["sum"]. + GetGcRule(). + GetMaxNumVersions() != 1 { + t.Fatal("GC policy was not updated when it should have been") + } + + tblInfo, err = s.ModifyColumnFamilies(ctx, &btapb.ModifyColumnFamiliesRequest{ + Name: tblInfo.Name, + Modifications: []*btapb.ModifyColumnFamiliesRequest_Modification{ + { + Id: "sum", + // Including UpdateMask in the request this time. + UpdateMask: &fieldmaskpb.FieldMask{ + Paths: []string{"gc_rule"}, + }, + Mod: &btapb.ModifyColumnFamiliesRequest_Modification_Update{ + Update: &btapb.ColumnFamily{ + GcRule: &btapb.GcRule{ + Rule: &btapb.GcRule_MaxNumVersions{ + MaxNumVersions: 2, + }, + }, + // HACK: Intentionally including an invalid type + // update, which should be ignored since it isn't + // present in the UpdateMask. + ValueType: &btapb.Type{ + Kind: &btapb.Type_AggregateType{ + AggregateType: &btapb.Type_Aggregate{ + InputType: &btapb.Type{ + Kind: &btapb.Type_Int64Type{}, + }, + Aggregator: &btapb.Type_Aggregate_Max_{ + Max: &btapb.Type_Aggregate_Max{}, + }, + }, + }, + }, + }, + }, + }, + }, + }) + if err != nil { + t.Fatal(err) + } + + if tblInfo.ColumnFamilies["sum"]. + GetValueType(). + GetAggregateType(). + GetSum() == nil { + t.Fatal("Aggregate type was updated when it should not have been") + } + + if tblInfo.ColumnFamilies["sum"]. + GetGcRule(). + GetMaxNumVersions() != 2 { + t.Fatal("GC policy was not updated when it should have been") + } +} + +func TestCannotUpdateTypeOfAggregateColumnFamily(t *testing.T) { + ctx := context.Background() + + s := &server{ + tables: make(map[string]*table), + } + + tblInfo, err := s.CreateTable(ctx, &btapb.CreateTableRequest{ + Parent: "cluster", + TableId: "t", + Table: &btapb.Table{ + ColumnFamilies: map[string]*btapb.ColumnFamily{ + "sum": { + ValueType: &btapb.Type{ + Kind: &btapb.Type_AggregateType{ + AggregateType: &btapb.Type_Aggregate{ + InputType: &btapb.Type{ + Kind: &btapb.Type_Int64Type{}, + }, + Aggregator: &btapb.Type_Aggregate_Sum_{ + Sum: &btapb.Type_Aggregate_Sum{}, + }, + }, + }, + }, + }, + }, + }, + }) + if err != nil { + t.Fatal(err) + } + + if tblInfo.ColumnFamilies["sum"]. + GetValueType(). + GetAggregateType(). + GetSum() == nil { + t.Fatal("Unexpected aggregate column family type at start of test") + } + + _, err = s.ModifyColumnFamilies(ctx, &btapb.ModifyColumnFamiliesRequest{ + Name: tblInfo.Name, + Modifications: []*btapb.ModifyColumnFamiliesRequest_Modification{ + { + Id: "sum", + UpdateMask: &fieldmaskpb.FieldMask{ + Paths: []string{"value_type"}, + }, + Mod: &btapb.ModifyColumnFamiliesRequest_Modification_Update{ + Update: &btapb.ColumnFamily{ + ValueType: &btapb.Type{ + Kind: &btapb.Type_AggregateType{ + AggregateType: &btapb.Type_Aggregate{ + InputType: &btapb.Type{ + Kind: &btapb.Type_Int64Type{}, + }, + Aggregator: &btapb.Type_Aggregate_Max_{ + Max: &btapb.Type_Aggregate_Max{}, + }, + }, + }, + }, + }, + }, + }, + }, + }) + if err == nil { + t.Fatal("ModifyColumnFamilies was supposed to return an error, but it did not") + } + + tblInfo, err = s.GetTable(ctx, &btapb.GetTableRequest{Name: tblInfo.Name}) + if err != nil { + t.Fatal(err) + } + + if tblInfo.ColumnFamilies["sum"]. + GetValueType(). + GetAggregateType(). + GetSum() == nil { + t.Fatal("Aggregate type was updated when it should not have been") + } +} + +func TestInvalidUpdateMaskInColumnFamilyUpdate(t *testing.T) { + ctx := context.Background() + + s := &server{ + tables: make(map[string]*table), + } + + tblInfo, err := s.CreateTable(ctx, &btapb.CreateTableRequest{ + Parent: "cluster", + TableId: "t", + Table: &btapb.Table{ + ColumnFamilies: map[string]*btapb.ColumnFamily{ + "sum": { + ValueType: &btapb.Type{ + Kind: &btapb.Type_AggregateType{ + AggregateType: &btapb.Type_Aggregate{ + InputType: &btapb.Type{ + Kind: &btapb.Type_Int64Type{}, + }, + Aggregator: &btapb.Type_Aggregate_Sum_{ + Sum: &btapb.Type_Aggregate_Sum{}, + }, + }, + }, + }, + }, + }, + }, + }) + if err != nil { + t.Fatal(err) + } + + if tblInfo.ColumnFamilies["sum"]. + GetGcRule(). + GetMaxNumVersions() == 1 { + t.Fatal("Unexpected GC policy state at start of test") + } + + _, err = s.ModifyColumnFamilies(ctx, &btapb.ModifyColumnFamiliesRequest{ + Name: tblInfo.Name, + Modifications: []*btapb.ModifyColumnFamiliesRequest_Modification{ + { + Id: "sum", + UpdateMask: &fieldmaskpb.FieldMask{ + Paths: []string{"bad", "gc_rule"}, + }, + Mod: &btapb.ModifyColumnFamiliesRequest_Modification_Update{ + Update: &btapb.ColumnFamily{ + GcRule: &btapb.GcRule{ + Rule: &btapb.GcRule_MaxNumVersions{ + MaxNumVersions: 1, + }, + }, + }, + }, + }, + }, + }) + if err == nil { + t.Fatal("ModifyColumnFamilies was supposed to return an error, but it did not") + } + + tblInfo, err = s.GetTable(ctx, &btapb.GetTableRequest{Name: tblInfo.Name}) + if err != nil { + t.Fatal(err) + } + + if tblInfo.ColumnFamilies["sum"]. + GetGcRule(). + GetMaxNumVersions() == 1 { + t.Fatal("GC policy was updated when it should not have been") + } +}