diff --git a/internal/database/repository/partitions_init_test.go b/internal/database/repository/partitions_init_test.go index 66764f3..d3a1dde 100644 --- a/internal/database/repository/partitions_init_test.go +++ b/internal/database/repository/partitions_init_test.go @@ -2,7 +2,6 @@ package repository import ( "context" - "testing" "time" "github.com/G-Research/yunikorn-core/pkg/webservice/dao" @@ -21,21 +20,150 @@ type PartitionIntTest struct { } func (ps *PartitionIntTest) SetupSuite() { - ctx := context.Background() require.NotNil(ps.T(), ps.pool) repo, err := NewPostgresRepository(ps.pool) require.NoError(ps.T(), err) ps.repo = repo - - seedPartitions(ctx, ps.T(), ps.repo) } func (ps *PartitionIntTest) TearDownSuite() { ps.pool.Close() } +func (ps *PartitionIntTest) TestInsertPartition() { + ctx := context.Background() + now := time.Now() + nowNano := now.UnixNano() + tests := []struct { + name string + existingPartitions []*model.Partition + partitionToInsert *model.Partition + expectedError bool + }{ + { + name: "Insert Partition", + existingPartitions: nil, + partitionToInsert: &model.Partition{ + Metadata: model.Metadata{ + CreatedAtNano: nowNano, + }, + PartitionInfo: dao.PartitionInfo{ + ID: "300", + Name: "default", + ClusterID: "cluster1", + State: "Active", + LastStateTransitionTime: now.UnixMilli(), + }, + }, + expectedError: false, + }, + { + name: "Insert Partition with same ID", + existingPartitions: []*model.Partition{ + { + Metadata: model.Metadata{ + CreatedAtNano: nowNano, + }, + PartitionInfo: dao.PartitionInfo{ + ID: "300", + Name: "default", + ClusterID: "cluster1", + State: "Active", + LastStateTransitionTime: now.UnixMilli(), + }, + }, + }, + partitionToInsert: &model.Partition{ + Metadata: model.Metadata{ + CreatedAtNano: nowNano, + }, + PartitionInfo: dao.PartitionInfo{ + ID: "300", + Name: "default", + ClusterID: "cluster1", + State: "Active", + LastStateTransitionTime: now.UnixMilli(), + }, + }, + expectedError: true, + }, + } + + for _, tt := range tests { + ps.Run(tt.name, func() { + if tt.existingPartitions != nil { + for _, p := range tt.existingPartitions { + err := ps.repo.InsertPartition(ctx, p) + require.NoError(ps.T(), err) + } + } + err := ps.repo.InsertPartition(ctx, tt.partitionToInsert) + require.Equal(ps.T(), tt.expectedError, err != nil) + ps.clearPartitionsTable(ctx) + }) + } +} + func (ps *PartitionIntTest) TestGetAllPartitions() { ctx := context.Background() + now := time.Now() + nowNano := now.UnixNano() + + partitions := []*model.Partition{ + { + Metadata: model.Metadata{ + CreatedAtNano: nowNano, + }, + PartitionInfo: dao.PartitionInfo{ + ID: "1", + Name: "default", + ClusterID: "cluster1", + State: "Active", + LastStateTransitionTime: now.Add(-1 * time.Hour).UnixMilli(), + }, + }, + { + Metadata: model.Metadata{ + CreatedAtNano: nowNano, + }, + PartitionInfo: dao.PartitionInfo{ + ID: "2", + Name: "second", + ClusterID: "cluster1", + State: "Active", + LastStateTransitionTime: now.Add(-2 * time.Hour).UnixMilli(), + }, + }, + { + Metadata: model.Metadata{ + CreatedAtNano: nowNano, + }, + PartitionInfo: dao.PartitionInfo{ + ID: "3", + Name: "third", + ClusterID: "cluster1", + State: "Active", + LastStateTransitionTime: now.Add(-3 * time.Hour).UnixMilli(), + }, + }, + { + Metadata: model.Metadata{ + CreatedAtNano: nowNano, + }, + PartitionInfo: dao.PartitionInfo{ + ID: "4", + Name: "fourth", + ClusterID: "cluster1", + State: "FakeState", + LastStateTransitionTime: now.Add(-4 * time.Hour).UnixMilli(), + }, + }, + } + for _, p := range partitions { + err := ps.repo.InsertPartition(ctx, p) + require.NoError(ps.T(), err) + } + tests := []struct { name string filters PartitionFilters @@ -98,65 +226,219 @@ func (ps *PartitionIntTest) TestGetAllPartitions() { require.Len(ps.T(), nodes, tt.expected) }) } + ps.clearPartitionsTable(ctx) } -func seedPartitions(ctx context.Context, t *testing.T, repo *PostgresRepository) { - t.Helper() - +func (ps *PartitionIntTest) TestGetPartitionByID() { + ctx := context.Background() now := time.Now() nowNano := now.UnixNano() - partitions := []*model.Partition{ + + tests := []struct { + name string + existingPartition []*model.Partition + partitionID string + expectedError bool + }{ { - Metadata: model.Metadata{ - CreatedAtNano: nowNano, - }, - PartitionInfo: dao.PartitionInfo{ - ID: "1", - Name: "default", - ClusterID: "cluster1", - State: "Active", - LastStateTransitionTime: now.Add(-1 * time.Hour).UnixMilli(), + name: "Get Partition by ID", + existingPartition: []*model.Partition{ + { + Metadata: model.Metadata{ + CreatedAtNano: nowNano, + }, + PartitionInfo: dao.PartitionInfo{ + ID: "100", + Name: "default", + ClusterID: "cluster1", + State: "Active", + LastStateTransitionTime: now.Add(-1 * time.Hour).UnixMilli(), + }, + }, }, + partitionID: "100", + expectedError: false, }, { - Metadata: model.Metadata{ - CreatedAtNano: nowNano, - }, - PartitionInfo: dao.PartitionInfo{ - ID: "2", - Name: "second", - ClusterID: "cluster1", - State: "Active", - LastStateTransitionTime: now.Add(-2 * time.Hour).UnixMilli(), - }, + name: "Get Partition by ID when partition does not exist", + existingPartition: nil, + partitionID: "200", + expectedError: true, }, + } + + for _, tt := range tests { + ps.Run(tt.name, func() { + if tt.existingPartition != nil { + for _, p := range tt.existingPartition { + err := ps.repo.InsertPartition(ctx, p) + require.NoError(ps.T(), err) + } + } + partition, err := ps.repo.GetPartitionByID(ctx, tt.partitionID) + require.Equal(ps.T(), tt.expectedError, err != nil) + if !tt.expectedError { + require.Equal(ps.T(), tt.partitionID, partition.ID) + } + ps.clearPartitionsTable(ctx) + }) + } +} + +func (ps *PartitionIntTest) TestUpdatePartition() { + ctx := context.Background() + nowNano := time.Now().UnixNano() + tests := []struct { + name string + existingPartition []*model.Partition + partitionToUpdate *model.Partition + expectedError bool + }{ { - Metadata: model.Metadata{ - CreatedAtNano: nowNano, + name: "Update Partition when exists", + existingPartition: []*model.Partition{ + { + Metadata: model.Metadata{ + CreatedAtNano: nowNano, + }, + PartitionInfo: dao.PartitionInfo{ + ID: "100", + Name: "default", + ClusterID: "cluster1", + State: "Active", + LastStateTransitionTime: time.Now().Add(-1 * time.Hour).UnixMilli(), + }, + }, }, - PartitionInfo: dao.PartitionInfo{ - ID: "3", - Name: "third", - ClusterID: "cluster1", - State: "Active", - LastStateTransitionTime: now.Add(-3 * time.Hour).UnixMilli(), + partitionToUpdate: &model.Partition{ + Metadata: model.Metadata{ + CreatedAtNano: nowNano, + }, + PartitionInfo: dao.PartitionInfo{ + ID: "100", + Name: "updated", + ClusterID: "cluster2", + State: "Inactive", + LastStateTransitionTime: time.Now().UnixMilli(), + }, }, + expectedError: false, }, { - Metadata: model.Metadata{ - CreatedAtNano: nowNano, + name: "Update Partition when does not exist", + existingPartition: nil, + partitionToUpdate: &model.Partition{ + Metadata: model.Metadata{ + CreatedAtNano: nowNano, + }, + PartitionInfo: dao.PartitionInfo{ + ID: "200", + Name: "updated", + ClusterID: "cluster2", + State: "Inactive", + LastStateTransitionTime: time.Now().UnixMilli(), + }, }, - PartitionInfo: dao.PartitionInfo{ - ID: "4", - Name: "fourth", - ClusterID: "cluster1", - State: "FakeState", - LastStateTransitionTime: now.Add(-4 * time.Hour).UnixMilli(), + expectedError: true, + }, + } + + for _, tt := range tests { + ps.Run(tt.name, func() { + if tt.existingPartition != nil { + for _, p := range tt.existingPartition { + err := ps.repo.InsertPartition(ctx, p) + require.NoError(ps.T(), err) + } + } + + err := ps.repo.UpdatePartition(ctx, tt.partitionToUpdate) + require.Equal(ps.T(), tt.expectedError, err != nil) + + if !tt.expectedError { + partition, err := ps.repo.GetPartitionByID(ctx, tt.partitionToUpdate.ID) + require.NoError(ps.T(), err) + require.Equal(ps.T(), tt.partitionToUpdate.Name, partition.Name) + require.Equal(ps.T(), tt.partitionToUpdate.ClusterID, partition.ClusterID) + require.Equal(ps.T(), tt.partitionToUpdate.State, partition.State) + require.Equal(ps.T(), tt.partitionToUpdate.LastStateTransitionTime, partition.LastStateTransitionTime) + } + ps.clearPartitionsTable(ctx) + }) + } +} + +func (ps *PartitionIntTest) TestDeletePartitionsNotInIDs() { + ctx := context.Background() + now := time.Now() + nowNano := now.UnixNano() + + tests := []struct { + name string + existingPartitions []*model.Partition + partitionIDs []string + expectedError bool + }{ + { + name: "Delete Partitions with correct IDs", + existingPartitions: []*model.Partition{ + { + Metadata: model.Metadata{ + CreatedAtNano: nowNano, + }, + PartitionInfo: dao.PartitionInfo{ + ID: "1", + Name: "default", + ClusterID: "cluster1", + State: "Active", + LastStateTransitionTime: now.Add(-1 * time.Hour).UnixMilli(), + }, + }, + { + Metadata: model.Metadata{ + CreatedAtNano: nowNano, + }, + PartitionInfo: dao.PartitionInfo{ + ID: "2", + Name: "second", + ClusterID: "cluster1", + State: "Active", + LastStateTransitionTime: now.Add(-2 * time.Hour).UnixMilli(), + }, + }, + { + Metadata: model.Metadata{ + CreatedAtNano: nowNano, + }, + PartitionInfo: dao.PartitionInfo{ + ID: "3", + Name: "third", + ClusterID: "cluster1", + State: "Active", + LastStateTransitionTime: now.Add(-3 * time.Hour).UnixMilli(), + }, + }, }, + partitionIDs: []string{"1", "2"}, + expectedError: false, }, } - for _, p := range partitions { - err := repo.InsertPartition(ctx, p) - require.NoError(t, err) + + for _, tt := range tests { + ps.Run(tt.name, func() { + for _, p := range tt.existingPartitions { + err := ps.repo.InsertPartition(ctx, p) + require.NoError(ps.T(), err) + } + + err := ps.repo.DeletePartitionsNotInIDs(ctx, tt.partitionIDs, nowNano) + require.Equal(ps.T(), tt.expectedError, err != nil) + ps.clearPartitionsTable(ctx) + }) } } + +func (ps *PartitionIntTest) clearPartitionsTable(ctx context.Context) { + _, err := ps.pool.Exec(ctx, "DELETE FROM partitions") + require.NoError(ps.T(), err) +}