Skip to content
This repository has been archived by the owner on Mar 27, 2024. It is now read-only.

feat: allow mediator GetConnections APIs to filter by didcomm version. #3320

Merged
merged 1 commit into from
Aug 9, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions pkg/client/mediator/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ type protocolService interface {
Unregister(connID string) error

// GetConnections returns router`s connections.
GetConnections() ([]string, error)
GetConnections(...mediator.ConnectionOption) ([]string, error)

// Config returns the router's configuration.
Config(connID string) (*mediator.Config, error)
Expand Down Expand Up @@ -91,8 +91,8 @@ func (c *Client) Unregister(connID string) error {
}

// GetConnections returns router`s connections.
func (c *Client) GetConnections() ([]string, error) {
connections, err := c.routeSvc.GetConnections()
func (c *Client) GetConnections(options ...ConnectionOption) ([]string, error) {
connections, err := c.routeSvc.GetConnections(options...)
if err != nil {
return nil, fmt.Errorf("get router connections: %w", err)
}
Expand Down
3 changes: 3 additions & 0 deletions pkg/client/mediator/models.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@ const (
// Request is the route-request message of this protocol.
type Request = mediator.Request

// ConnectionOption option for Client.GetConnections.
type ConnectionOption = mediator.ConnectionOption

// NewRequest creates a new request.
func NewRequest() *Request {
return &Request{
Expand Down
36 changes: 35 additions & 1 deletion pkg/controller/command/mediator/command.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import (
"github.com/hyperledger/aries-framework-go/pkg/controller/command"
"github.com/hyperledger/aries-framework-go/pkg/controller/internal/cmdutil"
"github.com/hyperledger/aries-framework-go/pkg/didcomm/common/service"
mediatorSvc "github.com/hyperledger/aries-framework-go/pkg/didcomm/protocol/mediator"
"github.com/hyperledger/aries-framework-go/pkg/internal/logutil"
"github.com/hyperledger/aries-framework-go/pkg/kms"
)
Expand Down Expand Up @@ -215,7 +216,40 @@ func (o *Command) Unregister(rw io.Writer, req io.Reader) command.Error {

// Connections returns the connections of the router.
func (o *Command) Connections(rw io.Writer, req io.Reader) command.Error {
connections, err := o.routeClient.GetConnections()
var request ConnectionsRequest

if req != nil {
reqData, err := io.ReadAll(req)
if err != nil {
logutil.LogInfo(logger, CommandName, GetConnectionsCommandMethod, err.Error())
return command.NewValidationError(GetConnectionsErrorCode, fmt.Errorf("read request : %w", err))
}

if len(reqData) > 0 {
err = json.Unmarshal(reqData, &request)
if err != nil {
logutil.LogInfo(logger, CommandName, GetConnectionsCommandMethod, err.Error())
return command.NewValidationError(InvalidRequestErrorCode, fmt.Errorf("decode request : %w", err))
}
}
}

opts := []mediator.ConnectionOption{}

switch {
case request.DIDCommV1Only && request.DIDCommV2Only:
errMsg := "can't request didcomm v1 only at the same time as didcomm v2 only"

logutil.LogError(logger, CommandName, GetConnectionsCommandMethod, errMsg)

return command.NewValidationError(GetConnectionsErrorCode, fmt.Errorf("%s", errMsg))
case request.DIDCommV2Only:
opts = append(opts, mediatorSvc.ConnectionByVersion(service.V2))
case request.DIDCommV1Only:
opts = append(opts, mediatorSvc.ConnectionByVersion(service.V1))
}

connections, err := o.routeClient.GetConnections(opts...)
if err != nil {
logutil.LogError(logger, CommandName, GetConnectionsCommandMethod, err.Error())
return command.NewExecuteError(GetConnectionsErrorCode, err)
Expand Down
100 changes: 95 additions & 5 deletions pkg/controller/command/mediator/command_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -207,14 +207,96 @@ func TestCommand_Connections(t *testing.T) {
require.NoError(t, err)
require.NotNil(t, cmd)

testcases := []struct {
name string
input string
}{
{
name: "no filters",
input: `{}`,
},
{
name: "didcomm v1 only",
input: `{"didcomm_v1": true}`,
},
{
name: "didcomm v2 only",
input: `{"didcomm_v2": true}`,
},
}

for _, tc := range testcases {
t.Run(tc.name, func(t *testing.T) {
var b bytes.Buffer
err = cmd.Connections(&b, bytes.NewBufferString(tc.input))
require.NoError(t, err)

response := ConnectionsResponse{}
err = json.NewDecoder(&b).Decode(&response)
require.NoError(t, err)
require.Equal(t, routerConnectionID, response.Connections[0])
})
}
})

t.Run("test get connection - read request error", func(t *testing.T) {
cmd, err := New(
&mockprovider.Provider{
ServiceMap: map[string]interface{}{
messagepickupSvc.MessagePickup: &messagepickup.MockMessagePickupSvc{},
mediator.Coordination: &mockroute.MockMediatorSvc{},
oobsvc.Name: &mockoob.MockOobService{},
},
},
false,
)
require.NoError(t, err)
require.NotNil(t, cmd)

var b bytes.Buffer
err = cmd.Connections(&b, nil)
err = cmd.Connections(&b, &errReader{err: fmt.Errorf("expected error")})
require.Error(t, err)
require.Contains(t, err.Error(), "read request")
})

t.Run("test get connection - decode request error", func(t *testing.T) {
cmd, err := New(
&mockprovider.Provider{
ServiceMap: map[string]interface{}{
messagepickupSvc.MessagePickup: &messagepickup.MockMessagePickupSvc{},
mediator.Coordination: &mockroute.MockMediatorSvc{},
oobsvc.Name: &mockoob.MockOobService{},
},
},
false,
)
require.NoError(t, err)
require.NotNil(t, cmd)

response := ConnectionsResponse{}
err = json.NewDecoder(&b).Decode(&response)
var b bytes.Buffer
err = cmd.Connections(&b, bytes.NewBufferString("{"))
require.Error(t, err)
require.Contains(t, err.Error(), "decode request")
})

t.Run("test get connection - invalid filter options error", func(t *testing.T) {
cmd, err := New(
&mockprovider.Provider{
ServiceMap: map[string]interface{}{
messagepickupSvc.MessagePickup: &messagepickup.MockMessagePickupSvc{},
mediator.Coordination: &mockroute.MockMediatorSvc{},
oobsvc.Name: &mockoob.MockOobService{},
},
},
false,
)
require.NoError(t, err)
require.Equal(t, routerConnectionID, response.Connections[0])
require.NotNil(t, cmd)

var b bytes.Buffer
err = cmd.Connections(&b, bytes.NewBufferString(`{"didcomm_v1": true, "didcomm_v2": true}`))
require.Error(t, err)
require.Contains(t, err.Error(), "at the same time")
})

t.Run("test get connection - error", func(t *testing.T) {
Expand All @@ -234,7 +316,7 @@ func TestCommand_Connections(t *testing.T) {
require.NotNil(t, cmd)

var b bytes.Buffer
err = cmd.Connections(&b, nil)
err = cmd.Connections(&b, bytes.NewBufferString("{}"))
require.Error(t, err)
require.Contains(t, err.Error(), "get router connections")
})
Expand Down Expand Up @@ -533,3 +615,11 @@ func newMockProvider(serviceMap map[string]interface{}) *mockprovider.Provider {
ProtocolStateStorageProviderValue: mockstorage.NewMockStoreProvider(),
}
}

type errReader struct {
err error
}

func (e *errReader) Read([]byte) (int, error) {
return 0, e.err
}
6 changes: 6 additions & 0 deletions pkg/controller/command/mediator/models.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,12 @@ type RegisterRoute struct {
ConnectionID string `json:"connectionID"`
}

// ConnectionsRequest contains parameters for filtering when requesting router connections.
type ConnectionsRequest struct {
DIDCommV1Only bool `json:"didcomm_v1"`
DIDCommV2Only bool `json:"didcomm_v2"`
}

// ConnectionsResponse is response for router`s connections.
type ConnectionsResponse struct {
Connections []string `json:"connections"`
Expand Down
2 changes: 1 addition & 1 deletion pkg/didcomm/protocol/mediator/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,5 +15,5 @@ type ProtocolService interface {
Config(connID string) (*Config, error)

// GetConnections returns all router connections
GetConnections() ([]string, error)
GetConnections(options ...ConnectionOption) ([]string, error)
}
54 changes: 49 additions & 5 deletions pkg/didcomm/protocol/mediator/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,11 @@ type callback struct {
err error
}

type routerConnectionEntry struct {
ConnectionID string `json:"connectionID"`
DIDCommVersion service.Version `json:"didcomm_version,omitempty"`
}

type connections interface {
GetConnectionIDByDIDs(string, string) (string, error)
GetConnectionRecord(string) (*connection.Record, error)
Expand Down Expand Up @@ -646,7 +651,7 @@ func (s *Service) doRegistration(record *connection.Record, req *Request, timeou
logger.Debugf("saved router config from inbound grant: %+v", grant)

// save the connectionID of the router
return s.saveRouterConnectionID(record.ConnectionID)
return s.saveRouterConnectionID(record.ConnectionID, record.DIDCommVersion)
}

func (s *Service) getGrant(id string, timeout time.Duration) (*Grant, error) {
Expand Down Expand Up @@ -700,7 +705,13 @@ func (s *Service) Unregister(connID string) error {
}

// GetConnections returns the connections of the router.
func (s *Service) GetConnections() ([]string, error) {
func (s *Service) GetConnections(options ...ConnectionOption) ([]string, error) {
opts := &getConnectionOpts{}

for _, option := range options {
option(opts)
}

records, err := s.routeStore.Query(routeConnIDDataKey)
if err != nil {
return nil, fmt.Errorf("failed to query route store: %w", err)
Expand All @@ -721,7 +732,16 @@ func (s *Service) GetConnections() ([]string, error) {
return nil, fmt.Errorf("failed to get value from records: %w", err)
}

conns = append(conns, string(value))
data := &routerConnectionEntry{}

err = json.Unmarshal(value, data)
if err != nil {
return nil, fmt.Errorf("failed to unmarshal router connection entry: %w", err)
}

if opts.version == "" || opts.version == data.DIDCommVersion {
conns = append(conns, data.ConnectionID)
}

more, err = records.Next()
if err != nil {
Expand Down Expand Up @@ -838,8 +858,18 @@ func (s *Service) deleteRouterConnectionID(connID string) error {
return s.routeStore.Delete(fmt.Sprintf(routeConnIDDataKey, connID))
}

func (s *Service) saveRouterConnectionID(connID string) error {
return s.routeStore.Put(fmt.Sprintf(routeConnIDDataKey, connID), []byte(connID), storage.Tag{Name: routeConnIDDataKey})
func (s *Service) saveRouterConnectionID(connID string, didcommVersion service.Version) error {
data := &routerConnectionEntry{
ConnectionID: connID,
DIDCommVersion: didcommVersion,
}

dataBytes, err := json.Marshal(data)
if err != nil {
return fmt.Errorf("marshalling router connection ID data: %w", err)
}

return s.routeStore.Put(fmt.Sprintf(routeConnIDDataKey, connID), dataBytes, storage.Tag{Name: routeConnIDDataKey})
}

type config struct {
Expand Down Expand Up @@ -918,3 +948,17 @@ func parseClientOpts(options ...ClientOption) *ClientOptions {

return opts
}

type getConnectionOpts struct {
version service.Version
}

// ConnectionOption option for Service.GetConnections.
type ConnectionOption func(opts *getConnectionOpts)

// ConnectionByVersion filter for mediator connections of the given DIDComm version.
func ConnectionByVersion(v service.Version) ConnectionOption {
return func(opts *getConnectionOpts) {
opts.version = v
}
}
Loading