diff --git a/pkg/client/mediator/client.go b/pkg/client/mediator/client.go index f17d574a24..4f3eec55bd 100644 --- a/pkg/client/mediator/client.go +++ b/pkg/client/mediator/client.go @@ -39,7 +39,7 @@ type protocolService interface { Unregister(connID string) error // GetConnections returns router`s connections. - GetConnections() ([]string, error) + GetConnections(...mediator.ConnectionOption) ([]string, error) // Config returns the router's configuration. Config(connID string) (*mediator.Config, error) @@ -91,8 +91,8 @@ func (c *Client) Unregister(connID string) error { } // GetConnections returns router`s connections. -func (c *Client) GetConnections() ([]string, error) { - connections, err := c.routeSvc.GetConnections() +func (c *Client) GetConnections(options ...ConnectionOption) ([]string, error) { + connections, err := c.routeSvc.GetConnections(options...) if err != nil { return nil, fmt.Errorf("get router connections: %w", err) } diff --git a/pkg/client/mediator/models.go b/pkg/client/mediator/models.go index b6b8dbdc23..adf23ecf10 100644 --- a/pkg/client/mediator/models.go +++ b/pkg/client/mediator/models.go @@ -22,6 +22,9 @@ const ( // Request is the route-request message of this protocol. type Request = mediator.Request +// ConnectionOption option for Client.GetConnections. +type ConnectionOption = mediator.ConnectionOption + // NewRequest creates a new request. func NewRequest() *Request { return &Request{ diff --git a/pkg/controller/command/mediator/command.go b/pkg/controller/command/mediator/command.go index cf69599dfb..11dd448c60 100644 --- a/pkg/controller/command/mediator/command.go +++ b/pkg/controller/command/mediator/command.go @@ -19,6 +19,7 @@ import ( "github.com/hyperledger/aries-framework-go/pkg/controller/command" "github.com/hyperledger/aries-framework-go/pkg/controller/internal/cmdutil" "github.com/hyperledger/aries-framework-go/pkg/didcomm/common/service" + mediatorSvc "github.com/hyperledger/aries-framework-go/pkg/didcomm/protocol/mediator" "github.com/hyperledger/aries-framework-go/pkg/internal/logutil" "github.com/hyperledger/aries-framework-go/pkg/kms" ) @@ -215,7 +216,38 @@ func (o *Command) Unregister(rw io.Writer, req io.Reader) command.Error { // Connections returns the connections of the router. func (o *Command) Connections(rw io.Writer, req io.Reader) command.Error { - connections, err := o.routeClient.GetConnections() + var request ConnectionsRequest + + if req != nil { + reqData, err := io.ReadAll(req) + if err != nil { + logutil.LogInfo(logger, CommandName, GetConnectionsCommandMethod, err.Error()) + return command.NewValidationError(GetConnectionsErrorCode, fmt.Errorf("read request : %w", err)) + } + + if len(reqData) > 0 { + err = json.Unmarshal(reqData, &request) + if err != nil { + logutil.LogInfo(logger, CommandName, GetConnectionsCommandMethod, err.Error()) + return command.NewValidationError(InvalidRequestErrorCode, fmt.Errorf("request decode : %w", err)) + } + } + } + + opts := []mediator.ConnectionOption{} + + if request.DIDCommV1Only && request.DIDCommV2Only { + errMsg := "can't request didcomm v1 only at the same time as didcomm v2 only" + + logutil.LogError(logger, CommandName, GetConnectionsCommandMethod, errMsg) + return command.NewValidationError(GetConnectionsErrorCode, fmt.Errorf("%s", errMsg)) + } else if request.DIDCommV2Only { + opts = append(opts, mediatorSvc.ConnectionByVersion(service.V2)) + } else if request.DIDCommV1Only { + opts = append(opts, mediatorSvc.ConnectionByVersion(service.V1)) + } + + connections, err := o.routeClient.GetConnections(opts...) if err != nil { logutil.LogError(logger, CommandName, GetConnectionsCommandMethod, err.Error()) return command.NewExecuteError(GetConnectionsErrorCode, err) diff --git a/pkg/controller/command/mediator/command_test.go b/pkg/controller/command/mediator/command_test.go index f6cfa15754..4a5b70779a 100644 --- a/pkg/controller/command/mediator/command_test.go +++ b/pkg/controller/command/mediator/command_test.go @@ -208,7 +208,7 @@ func TestCommand_Connections(t *testing.T) { require.NotNil(t, cmd) var b bytes.Buffer - err = cmd.Connections(&b, nil) + err = cmd.Connections(&b, bytes.NewBufferString("{}")) require.NoError(t, err) response := ConnectionsResponse{} @@ -234,7 +234,7 @@ func TestCommand_Connections(t *testing.T) { require.NotNil(t, cmd) var b bytes.Buffer - err = cmd.Connections(&b, nil) + err = cmd.Connections(&b, bytes.NewBufferString("{}")) require.Error(t, err) require.Contains(t, err.Error(), "get router connections") }) diff --git a/pkg/controller/command/mediator/models.go b/pkg/controller/command/mediator/models.go index 43256bc900..aa26679314 100644 --- a/pkg/controller/command/mediator/models.go +++ b/pkg/controller/command/mediator/models.go @@ -16,6 +16,12 @@ type RegisterRoute struct { ConnectionID string `json:"connectionID"` } +// ConnectionsRequest contains parameters for filtering when requesting router connections. +type ConnectionsRequest struct { + DIDCommV2Only bool `json:"didcomm_v2"` + DIDCommV1Only bool `json:"didcomm_v1"` +} + // ConnectionsResponse is response for router`s connections. type ConnectionsResponse struct { Connections []string `json:"connections"` diff --git a/pkg/didcomm/protocol/mediator/api.go b/pkg/didcomm/protocol/mediator/api.go index f318e6c5ec..0f2ad9663b 100644 --- a/pkg/didcomm/protocol/mediator/api.go +++ b/pkg/didcomm/protocol/mediator/api.go @@ -15,5 +15,5 @@ type ProtocolService interface { Config(connID string) (*Config, error) // GetConnections returns all router connections - GetConnections() ([]string, error) + GetConnections(options ...ConnectionOption) ([]string, error) } diff --git a/pkg/didcomm/protocol/mediator/service.go b/pkg/didcomm/protocol/mediator/service.go index 4e1a8d2358..d737d59c8d 100644 --- a/pkg/didcomm/protocol/mediator/service.go +++ b/pkg/didcomm/protocol/mediator/service.go @@ -126,6 +126,11 @@ type callback struct { err error } +type routerConnectionEntry struct { + ConnectionID string `json:"connectionID"` + DIDCommVersion service.Version `json:"didcomm_version,omitempty"` +} + type connections interface { GetConnectionIDByDIDs(string, string) (string, error) GetConnectionRecord(string) (*connection.Record, error) @@ -646,7 +651,7 @@ func (s *Service) doRegistration(record *connection.Record, req *Request, timeou logger.Debugf("saved router config from inbound grant: %+v", grant) // save the connectionID of the router - return s.saveRouterConnectionID(record.ConnectionID) + return s.saveRouterConnectionID(record.ConnectionID, record.DIDCommVersion) } func (s *Service) getGrant(id string, timeout time.Duration) (*Grant, error) { @@ -700,7 +705,13 @@ func (s *Service) Unregister(connID string) error { } // GetConnections returns the connections of the router. -func (s *Service) GetConnections() ([]string, error) { +func (s *Service) GetConnections(options ...ConnectionOption) ([]string, error) { + opts := &getConnectionOpts{} + + for _, option := range options { + option(opts) + } + records, err := s.routeStore.Query(routeConnIDDataKey) if err != nil { return nil, fmt.Errorf("failed to query route store: %w", err) @@ -721,7 +732,16 @@ func (s *Service) GetConnections() ([]string, error) { return nil, fmt.Errorf("failed to get value from records: %w", err) } - conns = append(conns, string(value)) + data := &routerConnectionEntry{} + + err = json.Unmarshal(value, data) + if err != nil { + return nil, fmt.Errorf("failed to unmarshal router connection entry: %w", err) + } + + if opts.version == "" || opts.version == data.DIDCommVersion { + conns = append(conns, data.ConnectionID) + } more, err = records.Next() if err != nil { @@ -838,8 +858,18 @@ func (s *Service) deleteRouterConnectionID(connID string) error { return s.routeStore.Delete(fmt.Sprintf(routeConnIDDataKey, connID)) } -func (s *Service) saveRouterConnectionID(connID string) error { - return s.routeStore.Put(fmt.Sprintf(routeConnIDDataKey, connID), []byte(connID), storage.Tag{Name: routeConnIDDataKey}) +func (s *Service) saveRouterConnectionID(connID string, didcommVersion service.Version) error { + data := &routerConnectionEntry{ + ConnectionID: connID, + DIDCommVersion: didcommVersion, + } + + dataBytes, err := json.Marshal(data) + if err != nil { + return fmt.Errorf("marshalling router connection ID data: %w", err) + } + + return s.routeStore.Put(fmt.Sprintf(routeConnIDDataKey, connID), dataBytes, storage.Tag{Name: routeConnIDDataKey}) } type config struct { @@ -918,3 +948,17 @@ func parseClientOpts(options ...ClientOption) *ClientOptions { return opts } + +type getConnectionOpts struct { + version service.Version +} + +// ConnectionOption option for Service.GetConnections. +type ConnectionOption func(opts *getConnectionOpts) + +// ConnectionByVersion filter for mediator connections of the given DIDComm version. +func ConnectionByVersion(v service.Version) ConnectionOption { + return func(opts *getConnectionOpts) { + opts.version = v + } +} diff --git a/pkg/didcomm/protocol/mediator/service_test.go b/pkg/didcomm/protocol/mediator/service_test.go index 9676e0471b..3d0c09f4cd 100644 --- a/pkg/didcomm/protocol/mediator/service_test.go +++ b/pkg/didcomm/protocol/mediator/service_test.go @@ -1096,7 +1096,7 @@ func TestUnregister(t *testing.T) { ) require.NoError(t, err) - s[fmt.Sprintf(routeConnIDDataKey, connID)] = mockstore.DBEntry{Value: []byte("conn-abc-xyz")} + s[fmt.Sprintf(routeConnIDDataKey, connID)] = mockstore.DBEntry{Value: []byte("{\"connectionID\":\"conn-abc-xyz\"}")} err = svc.Unregister(connID) require.NoError(t, err) @@ -1179,7 +1179,7 @@ func TestKeylistUpdate(t *testing.T) { require.NoError(t, err) // save router connID - require.NoError(t, svc.saveRouterConnectionID("conn")) + require.NoError(t, svc.saveRouterConnectionID("conn", "")) // save connections connRec := &connection.Record{ @@ -1245,7 +1245,7 @@ func TestKeylistUpdate(t *testing.T) { require.Contains(t, err.Error(), "router not registered") // save router connID - require.NoError(t, svc.saveRouterConnectionID("conn")) + require.NoError(t, svc.saveRouterConnectionID("conn", "")) // no connections saved err = svc.AddKey("conn", recKey) @@ -1298,7 +1298,7 @@ func TestKeylistUpdate(t *testing.T) { connBytes, err := json.Marshal(connRec) require.NoError(t, err) s["conn_conn2"] = mockstore.DBEntry{Value: connBytes} - require.NoError(t, svc.saveRouterConnectionID("conn2")) + require.NoError(t, svc.saveRouterConnectionID("conn2", "")) err = svc.AddKey("conn2", "recKey") require.Error(t, err) @@ -1342,7 +1342,7 @@ func TestConfig(t *testing.T) { }) require.NoError(t, err) - require.NoError(t, svc.saveRouterConnectionID("connID-123")) + require.NoError(t, svc.saveRouterConnectionID("connID-123", "")) require.NoError(t, svc.saveRouterConfig("connID-123", &config{ RouterEndpoint: ENDPOINT, RoutingKeys: routingKeys, @@ -1386,7 +1386,7 @@ func TestConfig(t *testing.T) { }) require.NoError(t, err) - require.NoError(t, svc.saveRouterConnectionID("connID-123")) + require.NoError(t, svc.saveRouterConnectionID("connID-123", "")) conf, err := svc.Config("connID-123") require.Error(t, err) @@ -1409,7 +1409,7 @@ func TestConfig(t *testing.T) { const conn = "connID-123" - require.NoError(t, svc.saveRouterConnectionID(conn)) + require.NoError(t, svc.saveRouterConnectionID(conn, "")) require.NoError(t, svc.routeStore.Put(fmt.Sprintf(routeConfigDataKey, conn), []byte("invalid data"))) conf, err := svc.Config(conn) @@ -1433,7 +1433,7 @@ func TestConfig(t *testing.T) { }) require.NoError(t, err) - require.NoError(t, svc.saveRouterConnectionID("connID-123")) + require.NoError(t, svc.saveRouterConnectionID("connID-123", "")) require.NoError(t, svc.routeStore.Put(routeConfigDataKey, []byte("invalid data"))) conf, err := svc.Config("connID-123") @@ -1458,7 +1458,7 @@ func TestGetConnections(t *testing.T) { ) require.NoError(t, err) - err = svc.saveRouterConnectionID(routerConnectionID) + err = svc.saveRouterConnectionID(routerConnectionID, "") require.NoError(t, err) connID, err := svc.GetConnections() diff --git a/pkg/didcomm/protocol/mediator/util_test.go b/pkg/didcomm/protocol/mediator/util_test.go index 4522abe689..9288759561 100644 --- a/pkg/didcomm/protocol/mediator/util_test.go +++ b/pkg/didcomm/protocol/mediator/util_test.go @@ -87,7 +87,7 @@ func (m *mockRouteSvc) AddKey(connID, recKey string) error { } // AddKey adds agents recKey to the router. -func (m *mockRouteSvc) GetConnections() ([]string, error) { +func (m *mockRouteSvc) GetConnections(...ConnectionOption) ([]string, error) { return m.Connections, m.ConnectionsErr } diff --git a/pkg/mock/didcomm/protocol/mediator/mock_mediator.go b/pkg/mock/didcomm/protocol/mediator/mock_mediator.go index 5a290d8494..d133ca4046 100644 --- a/pkg/mock/didcomm/protocol/mediator/mock_mediator.go +++ b/pkg/mock/didcomm/protocol/mediator/mock_mediator.go @@ -115,7 +115,7 @@ func (m *MockMediatorSvc) Config(connID string) (*mediator.Config, error) { } // GetConnections returns router`s connections. -func (m *MockMediatorSvc) GetConnections() ([]string, error) { +func (m *MockMediatorSvc) GetConnections(...mediator.ConnectionOption) ([]string, error) { if m.GetConnectionsErr != nil { return nil, m.GetConnectionsErr }