Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add StreamListObject to LB #2203

Merged
merged 6 commits into from
Oct 11, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions internal/errors/grpc.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,4 +70,14 @@ var (
ErrInvalidProtoMessageType = func(v interface{}) error {
return Errorf("failed to marshal/unmarshal proto message, message type is %T (missing vtprotobuf/protobuf helpers)", v)
}

// ErrServerStreamClientRecv represents a function to generate an error that the gRPC client couldn't receive from stream.
ErrServerStreamClientRecv = func(err error) error {
return Wrap(err, "gRPC client failed to receive from stream")
}

// ErrServerStreamClientSend represents a function to generate an error that the gRPC server couldn't send to stream.
ErrServerStreamServerSend = func(err error) error {
return Wrap(err, "gRPC server failed to send to stream")
}
)
90 changes: 89 additions & 1 deletion pkg/gateway/lb/handler/grpc/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ package grpc
import (
"context"
"fmt"
"io"
"slices"
"strconv"
"sync/atomic"
Expand Down Expand Up @@ -2907,7 +2908,7 @@ func (s *server) getObject(ctx context.Context, uuid string) (vec *payload.Objec
ech <- s.gateway.BroadCast(ctx, func(ctx context.Context, target string, vc vald.Client, copts ...grpc.CallOption) error {
sctx, sspan := trace.StartSpan(grpc.WrapGRPCMethod(ctx, "BroadCast/"+target), apiName+"/getObject/BroadCast/"+target)
defer func() {
if span != nil {
if sspan != nil {
sspan.End()
}
}()
Expand Down Expand Up @@ -3134,3 +3135,90 @@ func (s *server) StreamGetObject(stream vald.Object_StreamGetObjectServer) (err
}
return nil
}

func (s *server) StreamListObject(req *payload.Object_List_Request, stream vald.Object_StreamListObjectServer) error {
ctx, span := trace.StartSpan(grpc.WithGRPCMethod(stream.Context(), vald.PackageName+"."+vald.ObjectRPCServiceName+"/"+vald.StreamListObjectRPCName), apiName+"/"+vald.StreamListObjectRPCName)
defer func() {
if span != nil {
span.End()
}
}()

ctx, cancel := context.WithCancel(ctx)
defer cancel()

var rmu, smu sync.Mutex
err := s.gateway.BroadCast(ctx, func(ctx context.Context, target string, vc vald.Client, copts ...grpc.CallOption) error {
ctx, sspan := trace.StartSpan(grpc.WrapGRPCMethod(ctx, "BroadCast/"+target), apiName+"/"+vald.StreamListObjectRPCName+"/"+target)
defer func() {
if sspan != nil {
sspan.End()
}
}()

client, err := vc.StreamListObject(ctx, req, copts...)
if err != nil {
log.Errorf("failed to get StreamListObject client for agent(%s): %v", target, err)
return err
}

eg, ctx := errgroup.WithContext(ctx)
ectx, ecancel := context.WithCancel(ctx)
defer ecancel()
eg.SetLimit(s.streamConcurrency)

for {
select {
case <-ectx.Done():
var err error
if !errors.Is(ctx.Err(), context.Canceled) {
err = errors.Join(err, ctx.Err())
}
if egerr := eg.Wait(); err != nil {
err = errors.Join(err, egerr)
}
return err
default:
eg.Go(safety.RecoverFunc(func() error {
rmu.Lock()
res, err := client.Recv()
rmu.Unlock()
if err != nil {
if errors.Is(err, io.EOF) {
ecancel()
return nil
}
return errors.ErrServerStreamClientRecv(err)
}

vec := res.GetVector()
if vec == nil {
st := res.GetStatus()
log.Warnf("received empty vector: code %v: details %v: message %v",
st.GetCode(),
st.GetDetails(),
st.GetMessage(),
)
return nil
}

smu.Lock()
err = stream.Send(res)
smu.Unlock()
if err != nil {
if sspan != nil {
st, msg, err := status.ParseError(err, codes.Internal, "failed to parse StreamListObject send gRPC error response")
sspan.RecordError(err)
sspan.SetAttributes(trace.FromGRPCStatus(st.Code(), msg)...)
sspan.SetStatus(trace.StatusError, err.Error())
}
return errors.ErrServerStreamServerSend(err)
}

return nil
}))
}
}
})
return err
}
4 changes: 2 additions & 2 deletions pkg/gateway/lb/service/gateway.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,9 @@ type Gateway interface {
GetAgentCount(ctx context.Context) int
Addrs(ctx context.Context) []string
DoMulti(ctx context.Context, num int,
f func(ctx context.Context, tgt string, ac vald.Client, copts ...grpc.CallOption) error) error
f func(ctx context.Context, target string, ac vald.Client, copts ...grpc.CallOption) error) error
BroadCast(ctx context.Context,
f func(ctx context.Context, tgt string, ac vald.Client, copts ...grpc.CallOption) error) error
f func(ctx context.Context, target string, ac vald.Client, copts ...grpc.CallOption) error) error
}

type gateway struct {
Expand Down
7 changes: 7 additions & 0 deletions tests/e2e/crud/crud_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -375,6 +375,13 @@ func TestE2EStandardCRUD(t *testing.T) {
t.Fatalf("an error occurred: %s", err)
}

err = op.StreamListObject(t, ctx, operation.Dataset{
Train: ds.Train[insertFrom : insertFrom+insertNum],
})
if err != nil {
t.Fatalf("an error occurred: %s", err)
}

err = op.Update(t, ctx, operation.Dataset{
Train: ds.Train[updateFrom : updateFrom+updateNum],
})
Expand Down
1 change: 1 addition & 0 deletions tests/e2e/operation/operation.go
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,7 @@ type Client interface {
MultiUpsert(t *testing.T, ctx context.Context, ds Dataset) error
MultiRemove(t *testing.T, ctx context.Context, ds Dataset) error
GetObject(t *testing.T, ctx context.Context, ds Dataset) error
StreamListObject(t *testing.T, ctx context.Context, ds Dataset) error
Exists(t *testing.T, ctx context.Context, id string) error
CreateIndex(t *testing.T, ctx context.Context) error
SaveIndex(t *testing.T, ctx context.Context) error
Expand Down
59 changes: 59 additions & 0 deletions tests/e2e/operation/stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ package operation

import (
"context"
"fmt"
"reflect"
"strconv"
"testing"
Expand Down Expand Up @@ -1167,3 +1168,61 @@ func (c *client) GetObject(

return rerr
}

func (c *client) StreamListObject(
t *testing.T,
ctx context.Context,
ds Dataset,
) error {
t.Log("StreamListObject operation started")

client, err := c.getClient(ctx)
if err != nil {
return err
}

sc, err := client.StreamListObject(ctx, &payload.Object_List_Request{})
if err != nil {
return err
}

// kv : [indexId]count
indexCnt := make(map[string]int)
exit_loop:
for {
select {
case <-ctx.Done():
return ctx.Err()
default:
res, err := sc.Recv()
if err != nil {
if errors.Is(err, io.EOF) {
break exit_loop
}
return err
}
vec := res.GetVector()
if vec == nil {
st := res.GetStatus()
return fmt.Errorf("returned vector is empty: code: %v, msg: %v, details: %v", st.GetCode(), st.GetMessage(), st.GetDetails())
}
indexCnt[vec.GetId()]++
}
}

if len(indexCnt) != len(ds.Train) {
return fmt.Errorf("the number of vectors returned is different: got %v, want %v", len(indexCnt), len(ds.Train))
}

replica := -1
for k, v := range indexCnt {
if replica == -1 {
replica = v
continue
}
if v != replica {
return fmt.Errorf("the number of vectors returned is different at index id %v: got %v, want %v", k, v, replica)
}
}
return nil
}