Skip to content

Commit

Permalink
Remove hack in openStream RPC
Browse files Browse the repository at this point in the history
Problem: RPC protocol requires response to stream open RPC to arrive
before any message from the stream. This was implemented with use of an
ugly hack.

Solution: remove hack, introduce notion of after-write handler in
to be executed after the rpc response is written to output.
  • Loading branch information
georgeee committed Nov 16, 2023
1 parent 5ec0334 commit 036464d
Show file tree
Hide file tree
Showing 14 changed files with 78 additions and 64 deletions.
2 changes: 1 addition & 1 deletion src/app/libp2p_helper/src/libp2p_helper/app.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ import (
)

func newApp() *app {
outChan := make(chan *capnp.Message, 1<<12) // 4kb
outChan := make(chan *capnp.Message, 1<<12) // 4096 messages stacked
ctx := context.Background()
return &app{
P2p: nil,
Expand Down
2 changes: 1 addition & 1 deletion src/app/libp2p_helper/src/libp2p_helper/bandwidth_msg.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ func fromBandwidthInfoReq(req ipcRpcRequest) (rpcRequest, error) {
return BandwidthInfoReq(i), err
}

func (msg BandwidthInfoReq) handle(app *app, seqno uint64) *capnp.Message {
func (msg BandwidthInfoReq) handle(app *app, seqno uint64) (*capnp.Message, func()) {
if app.P2p == nil {
return mkRpcRespError(seqno, needsConfigure())
}
Expand Down
4 changes: 2 additions & 2 deletions src/app/libp2p_helper/src/libp2p_helper/bitswap_msg.go
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ func fromTestDecodeBitswapBlocksReq(req ipcRpcRequest) (rpcRequest, error) {
return TestDecodeBitswapBlocksReq(i), err
}

func (m TestDecodeBitswapBlocksReq) handle(app *app, seqno uint64) *capnp.Message {
func (m TestDecodeBitswapBlocksReq) handle(app *app, seqno uint64) (*capnp.Message, func()) {
blocks, err := TestDecodeBitswapBlocksReqT(m).Blocks()
if err != nil {
return mkRpcRespError(seqno, badRPC(err))
Expand Down Expand Up @@ -156,7 +156,7 @@ func fromTestEncodeBitswapBlocksReq(req ipcRpcRequest) (rpcRequest, error) {
return TestEncodeBitswapBlocksReq(i), err
}

func (m TestEncodeBitswapBlocksReq) handle(app *app, seqno uint64) *capnp.Message {
func (m TestEncodeBitswapBlocksReq) handle(app *app, seqno uint64) (*capnp.Message, func()) {
mr := TestEncodeBitswapBlocksReqT(m)

data, err := mr.Data()
Expand Down
14 changes: 7 additions & 7 deletions src/app/libp2p_helper/src/libp2p_helper/config_msg.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ func fromBeginAdvertisingReq(req ipcRpcRequest) (rpcRequest, error) {
i, err := req.BeginAdvertising()
return BeginAdvertisingReq(i), err
}
func (msg BeginAdvertisingReq) handle(app *app, seqno uint64) *capnp.Message {
func (msg BeginAdvertisingReq) handle(app *app, seqno uint64) (*capnp.Message, func()) {
if app.P2p == nil {
return mkRpcRespError(seqno, needsConfigure())
}
Expand Down Expand Up @@ -293,7 +293,7 @@ func fromConfigureReq(req ipcRpcRequest) (rpcRequest, error) {
i, err := req.Configure()
return ConfigureReq(i), err
}
func (msg ConfigureReq) handle(app *app, seqno uint64) *capnp.Message {
func (msg ConfigureReq) handle(app *app, seqno uint64) (*capnp.Message, func()) {
m, err := ConfigureReqT(msg).Config()
if err != nil {
return mkRpcRespError(seqno, badRPC(err))
Expand Down Expand Up @@ -487,7 +487,7 @@ func fromGetListeningAddrsReq(req ipcRpcRequest) (rpcRequest, error) {
i, err := req.GetListeningAddrs()
return GetListeningAddrsReq(i), err
}
func (msg GetListeningAddrsReq) handle(app *app, seqno uint64) *capnp.Message {
func (msg GetListeningAddrsReq) handle(app *app, seqno uint64) (*capnp.Message, func()) {
if app.P2p == nil {
return mkRpcRespError(seqno, needsConfigure())
}
Expand All @@ -508,7 +508,7 @@ func fromGenerateKeypairReq(req ipcRpcRequest) (rpcRequest, error) {
i, err := req.GenerateKeypair()
return GenerateKeypairReq(i), err
}
func (msg GenerateKeypairReq) handle(app *app, seqno uint64) *capnp.Message {
func (msg GenerateKeypairReq) handle(app *app, seqno uint64) (*capnp.Message, func()) {
privk, pubk, err := crypto.GenerateEd25519Key(cryptorand.Reader)
if err != nil {
return mkRpcRespError(seqno, badp2p(err))
Expand Down Expand Up @@ -548,7 +548,7 @@ func fromListenReq(req ipcRpcRequest) (rpcRequest, error) {
i, err := req.Listen()
return ListenReq(i), err
}
func (m ListenReq) handle(app *app, seqno uint64) *capnp.Message {
func (m ListenReq) handle(app *app, seqno uint64) (*capnp.Message, func()) {
if app.P2p == nil {
return mkRpcRespError(seqno, needsConfigure())
}
Expand Down Expand Up @@ -586,7 +586,7 @@ func fromSetGatingConfigReq(req ipcRpcRequest) (rpcRequest, error) {
i, err := req.SetGatingConfig()
return SetGatingConfigReq(i), err
}
func (m SetGatingConfigReq) handle(app *app, seqno uint64) *capnp.Message {
func (m SetGatingConfigReq) handle(app *app, seqno uint64) (*capnp.Message, func()) {
if app.P2p == nil {
return mkRpcRespError(seqno, needsConfigure())
}
Expand Down Expand Up @@ -616,7 +616,7 @@ func fromSetNodeStatusReq(req ipcRpcRequest) (rpcRequest, error) {
i, err := req.SetNodeStatus()
return SetNodeStatusReq(i), err
}
func (m SetNodeStatusReq) handle(app *app, seqno uint64) *capnp.Message {
func (m SetNodeStatusReq) handle(app *app, seqno uint64) (*capnp.Message, func()) {
status, err := SetNodeStatusReqT(m).Status()
if err != nil {
return mkRpcRespError(seqno, badRPC(err))
Expand Down
12 changes: 6 additions & 6 deletions src/app/libp2p_helper/src/libp2p_helper/config_msg_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,7 @@ func TestConfigure(t *testing.T) {
require.NoError(t, err)
gc.SetIsolate(false)

resMsg := ConfigureReq(m).handle(testApp, 239)
resMsg, _ := ConfigureReq(m).handle(testApp, 239)
require.NoError(t, err)
seqno, respSuccess := checkRpcResponseSuccess(t, resMsg, "configure")
require.Equal(t, seqno, uint64(239))
Expand All @@ -206,7 +206,7 @@ func TestGenerateKeypair(t *testing.T) {
require.NoError(t, err)

testApp, _ := newTestApp(t, nil, true)
resMsg := GenerateKeypairReq(m).handle(testApp, 7839)
resMsg, _ := GenerateKeypairReq(m).handle(testApp, 7839)
require.NoError(t, err)
seqno, respSuccess := checkRpcResponseSuccess(t, resMsg, "generateKeypair")
require.Equal(t, seqno, uint64(7839))
Expand Down Expand Up @@ -239,7 +239,7 @@ func TestGetListeningAddrs(t *testing.T) {
m, err := ipc.NewRootLibp2pHelperInterface_GetListeningAddrs_Request(seg)
require.NoError(t, err)
var mRpcSeqno uint64 = 1024
resMsg := GetListeningAddrsReq(m).handle(testApp, mRpcSeqno)
resMsg, _ := GetListeningAddrsReq(m).handle(testApp, mRpcSeqno)
seqno, respSuccess := checkRpcResponseSuccess(t, resMsg, "getListeningAddrs")
require.Equal(t, seqno, mRpcSeqno)
require.True(t, respSuccess.HasGetListeningAddrs())
Expand All @@ -265,7 +265,7 @@ func TestListen(t *testing.T) {
require.NoError(t, iface.SetRepresentation(addrStr))
require.NoError(t, err)

resMsg := ListenReq(m).handle(testApp, 1239)
resMsg, _ := ListenReq(m).handle(testApp, 1239)
require.NoError(t, err)
seqno, respSuccess := checkRpcResponseSuccess(t, resMsg, "listen")
require.Equal(t, seqno, uint64(1239))
Expand Down Expand Up @@ -316,7 +316,7 @@ func setGatingConfigImpl(t *testing.T, app *app, allowedIps, allowedIds, bannedI
gc.SetIsolate(false)

var mRpcSeqno uint64 = 2003
resMsg := SetGatingConfigReq(m).handle(app, mRpcSeqno)
resMsg, _ := SetGatingConfigReq(m).handle(app, mRpcSeqno)
seqno, respSuccess := checkRpcResponseSuccess(t, resMsg, "setGatingConfig")
require.Equal(t, seqno, mRpcSeqno)
require.True(t, respSuccess.HasSetGatingConfig())
Expand Down Expand Up @@ -369,7 +369,7 @@ func TestSetNodeStatus(t *testing.T) {
testStatus := []byte("test_node_status")
require.NoError(t, m.SetStatus(testStatus))

resMsg := SetNodeStatusReq(m).handle(testApp, 11239)
resMsg, _ := SetNodeStatusReq(m).handle(testApp, 11239)
require.NoError(t, err)
seqno, respSuccess := checkRpcResponseSuccess(t, resMsg, "setNodeStatus")
require.Equal(t, seqno, uint64(11239))
Expand Down
18 changes: 11 additions & 7 deletions src/app/libp2p_helper/src/libp2p_helper/incoming_msg.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,32 +43,36 @@ var pushMesssageExtractors = map[ipc.Libp2pHelperInterface_PushMessage_Which]ext
// Handles messages coming from the OCaml process
func (app *app) handleIncomingMsg(msg *ipc.Libp2pHelperInterface_Message) {
if msg.HasRpcRequest() {
resp, err := func() (*capnp.Message, error) {
resp, afterWriteHandler, err := func() (*capnp.Message, func(), error) {
req, err := msg.RpcRequest()
if err != nil {
return nil, err
return nil, nil, err
}
h, err := req.Header()
if err != nil {
return nil, err
return nil, nil, err
}
seqnoO, err := h.SequenceNumber()
if err != nil {
return nil, err
return nil, nil, err
}
seqno := seqnoO.Seqno()
extractor, foundHandler := rpcRequestExtractors[req.Which()]
if !foundHandler {
return nil, errors.New("Received rpc message of an unknown type")
return nil, nil, errors.New("Received rpc message of an unknown type")
}
req2, err := extractor(req)
if err != nil {
return nil, err
return nil, nil, err
}
return req2.handle(app, seqno), nil
resp, afterWriteHandler := req2.handle(app, seqno)
return resp, afterWriteHandler, nil
}()
if err == nil {
app.writeMsg(resp)
if afterWriteHandler != nil {
afterWriteHandler()
}
} else {
app.P2p.Logger.Errorf("Failed to process rpc message: %s", err)
}
Expand Down
19 changes: 16 additions & 3 deletions src/app/libp2p_helper/src/libp2p_helper/msg.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,12 @@ type extractPushMessage = func(ipcPushMessage) (pushMessage, error)

type ipcRpcRequest = ipc.Libp2pHelperInterface_RpcRequest
type rpcRequest interface {
handle(app *app, seqno uint64) *capnp.Message
// Handles rpc request and returns response and a function to be called
// immediately after writing response to the output stream
//
// Callback is needed in some cases to make sure response is written
// before some other messages might get written to the output stream
handle(app *app, seqno uint64) (*capnp.Message, func())
}
type extractRequest = func(ipcRpcRequest) (rpcRequest, error)

Expand Down Expand Up @@ -207,7 +212,7 @@ func setNanoTime(ns *ipc.UnixNano, t time.Time) {
ns.SetNanoSec(t.UnixNano())
}

func mkRpcRespError(seqno uint64, rpcRespErr error) *capnp.Message {
func mkRpcRespErrorNoFunc(seqno uint64, rpcRespErr error) *capnp.Message {
if rpcRespErr == nil {
panic("mkRpcRespError: nil error")
}
Expand All @@ -228,7 +233,11 @@ func mkRpcRespError(seqno uint64, rpcRespErr error) *capnp.Message {
})
}

func mkRpcRespSuccess(seqno uint64, f func(*ipc.Libp2pHelperInterface_RpcResponseSuccess)) *capnp.Message {
func mkRpcRespError(seqno uint64, rpcRespErr error) (*capnp.Message, func()) {
return mkRpcRespErrorNoFunc(seqno, rpcRespErr), nil
}

func mkRpcRespSuccessNoFunc(seqno uint64, f func(*ipc.Libp2pHelperInterface_RpcResponseSuccess)) *capnp.Message {
return mkMsg(func(seg *capnp.Segment) {
m, err := ipc.NewRootDaemonInterface_Message(seg)
panicOnErr(err)
Expand All @@ -248,6 +257,10 @@ func mkRpcRespSuccess(seqno uint64, f func(*ipc.Libp2pHelperInterface_RpcRespons
})
}

func mkRpcRespSuccess(seqno uint64, f func(*ipc.Libp2pHelperInterface_RpcResponseSuccess)) (*capnp.Message, func()) {
return mkRpcRespSuccessNoFunc(seqno, f), nil
}

func mkPushMsg(f func(ipc.DaemonInterface_PushMessage)) *capnp.Message {
return mkMsg(func(seg *capnp.Segment) {
m, err := ipc.NewRootDaemonInterface_Message(seg)
Expand Down
6 changes: 3 additions & 3 deletions src/app/libp2p_helper/src/libp2p_helper/peer_msg.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ func fromAddPeerReq(req ipcRpcRequest) (rpcRequest, error) {
i, err := req.AddPeer()
return AddPeerReq(i), err
}
func (m AddPeerReq) handle(app *app, seqno uint64) *capnp.Message {
func (m AddPeerReq) handle(app *app, seqno uint64) (*capnp.Message, func()) {
if app.P2p == nil {
return mkRpcRespError(seqno, needsConfigure())
}
Expand Down Expand Up @@ -71,7 +71,7 @@ func fromGetPeerNodeStatusReq(req ipcRpcRequest) (rpcRequest, error) {
i, err := req.GetPeerNodeStatus()
return GetPeerNodeStatusReq(i), err
}
func (m GetPeerNodeStatusReq) handle(app *app, seqno uint64) *capnp.Message {
func (m GetPeerNodeStatusReq) handle(app *app, seqno uint64) (*capnp.Message, func()) {
ctx, cancel := context.WithTimeout(app.Ctx, codanet.NodeStatusTimeout)
defer cancel()
pma, err := GetPeerNodeStatusReqT(m).Peer()
Expand Down Expand Up @@ -147,7 +147,7 @@ func fromListPeersReq(req ipcRpcRequest) (rpcRequest, error) {
i, err := req.ListPeers()
return ListPeersReq(i), err
}
func (msg ListPeersReq) handle(app *app, seqno uint64) *capnp.Message {
func (msg ListPeersReq) handle(app *app, seqno uint64) (*capnp.Message, func()) {
if app.P2p == nil {
return mkRpcRespError(seqno, needsConfigure())
}
Expand Down
6 changes: 3 additions & 3 deletions src/app/libp2p_helper/src/libp2p_helper/peer_msg_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ func testAddPeerImplDo(t *testing.T, node *app, peerAddr peer.AddrInfo, isSeed b
m.SetIsSeed(isSeed)

var mRpcSeqno uint64 = 2000
resMsg := AddPeerReq(m).handle(node, mRpcSeqno)
resMsg, _ := AddPeerReq(m).handle(node, mRpcSeqno)
seqno, respSuccess := checkRpcResponseSuccess(t, resMsg, "addPeer")
require.Equal(t, seqno, mRpcSeqno)
require.True(t, respSuccess.HasAddPeer())
Expand Down Expand Up @@ -88,7 +88,7 @@ func TestGetPeerNodeStatus(t *testing.T) {
require.NoError(t, ma.SetRepresentation(addr))

var mRpcSeqno uint64 = 18900
resMsg := GetPeerNodeStatusReq(m).handle(appB, mRpcSeqno)
resMsg, _ := GetPeerNodeStatusReq(m).handle(appB, mRpcSeqno)
seqno, respSuccess := checkRpcResponseSuccess(t, resMsg, "getPeerNodeStatus")
require.Equal(t, seqno, mRpcSeqno)
require.True(t, respSuccess.HasGetPeerNodeStatus())
Expand All @@ -108,7 +108,7 @@ func TestListPeers(t *testing.T) {
require.NoError(t, err)

var mRpcSeqno uint64 = 2002
resMsg := ListPeersReq(m).handle(appB, mRpcSeqno)
resMsg, _ := ListPeersReq(m).handle(appB, mRpcSeqno)
seqno, respSuccess := checkRpcResponseSuccess(t, resMsg, "listPeers")
require.Equal(t, seqno, mRpcSeqno)
require.True(t, respSuccess.HasListPeers())
Expand Down
6 changes: 3 additions & 3 deletions src/app/libp2p_helper/src/libp2p_helper/pubsub_msg.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ func fromPublishReq(req ipcRpcRequest) (rpcRequest, error) {
i, err := req.Publish()
return PublishReq(i), err
}
func (m PublishReq) handle(app *app, seqno uint64) *capnp.Message {
func (m PublishReq) handle(app *app, seqno uint64) (*capnp.Message, func()) {
if app.P2p == nil {
return mkRpcRespError(seqno, needsConfigure())
}
Expand Down Expand Up @@ -111,7 +111,7 @@ func fromSubscribeReq(req ipcRpcRequest) (rpcRequest, error) {
i, err := req.Subscribe()
return SubscribeReq(i), err
}
func (m SubscribeReq) handle(app *app, seqno uint64) *capnp.Message {
func (m SubscribeReq) handle(app *app, seqno uint64) (*capnp.Message, func()) {
if app.P2p == nil {
return mkRpcRespError(seqno, needsConfigure())
}
Expand Down Expand Up @@ -244,7 +244,7 @@ func fromUnsubscribeReq(req ipcRpcRequest) (rpcRequest, error) {
i, err := req.Unsubscribe()
return UnsubscribeReq(i), err
}
func (m UnsubscribeReq) handle(app *app, seqno uint64) *capnp.Message {
func (m UnsubscribeReq) handle(app *app, seqno uint64) (*capnp.Message, func()) {
if app.P2p == nil {
return mkRpcRespError(seqno, needsConfigure())
}
Expand Down
6 changes: 3 additions & 3 deletions src/app/libp2p_helper/src/libp2p_helper/pubsub_msg_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ func testPublishDo(t *testing.T, app *app, topic string, data []byte, rpcSeqno u
require.NoError(t, m.SetTopic(topic))
require.NoError(t, m.SetData(data))

resMsg := PublishReq(m).handle(app, rpcSeqno)
resMsg, _ := PublishReq(m).handle(app, rpcSeqno)
require.NoError(t, err)
seqno, respSuccess := checkRpcResponseSuccess(t, resMsg, "publish")
require.Equal(t, seqno, rpcSeqno)
Expand Down Expand Up @@ -47,7 +47,7 @@ func testSubscribeDo(t *testing.T, app *app, topic string, subId uint64, rpcSeqn
require.NoError(t, err)
sid.SetId(subId)

resMsg := SubscribeReq(m).handle(app, rpcSeqno)
resMsg, _ := SubscribeReq(m).handle(app, rpcSeqno)
require.NoError(t, err)
seqno, respSuccess := checkRpcResponseSuccess(t, resMsg, "subscribe")
require.Equal(t, seqno, rpcSeqno)
Expand Down Expand Up @@ -89,7 +89,7 @@ func TestUnsubscribe(t *testing.T) {
require.NoError(t, err)
sid.SetId(idx)

resMsg := UnsubscribeReq(m).handle(testApp, 7739)
resMsg, _ := UnsubscribeReq(m).handle(testApp, 7739)
require.NoError(t, err)
seqno, respSuccess := checkRpcResponseSuccess(t, resMsg, "unsubscribe")
require.Equal(t, seqno, uint64(7739))
Expand Down
Loading

0 comments on commit 036464d

Please sign in to comment.