Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make MultistreamMuxer and Client APIs generic #95

Merged
merged 3 commits into from
Jan 22, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ import (
// "/cats" and "/dogs" and exposes it on a localhost:8765. It then opens connections
// to that port, selects the protocols and tests that the handlers are working.
func main() {
mux := ms.NewMultistreamMuxer()
mux := ms.NewMultistreamMuxer[string]()
mux.AddHandler("/cats", func(proto string, rwc io.ReadWriteCloser) error {
fmt.Fprintln(rwc, proto, ": HELLO I LIKE CATS")
return rwc.Close()
Expand Down
54 changes: 27 additions & 27 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ const (
// to inform the muxer of the protocol that will be used to communicate
// on this ReadWriteCloser. It returns an error if, for example,
// the muxer does not know how to handle this protocol.
func SelectProtoOrFail(proto string, rwc io.ReadWriteCloser) (err error) {
func SelectProtoOrFail[T StringLike](proto T, rwc io.ReadWriteCloser) (err error) {
defer func() {
if rerr := recover(); rerr != nil {
fmt.Fprintf(os.Stderr, "caught panic: %s\n%s\n", rerr, debug.Stack())
Expand Down Expand Up @@ -66,7 +66,7 @@ func SelectProtoOrFail(proto string, rwc io.ReadWriteCloser) (err error) {

// SelectOneOf will perform handshakes with the protocols on the given slice
// until it finds one which is supported by the muxer.
func SelectOneOf(protos []string, rwc io.ReadWriteCloser) (proto string, err error) {
func SelectOneOf[T StringLike](protos []T, rwc io.ReadWriteCloser) (proto T, err error) {
defer func() {
if rerr := recover(); rerr != nil {
fmt.Fprintf(os.Stderr, "caught panic: %s\n%s\n", rerr, debug.Stack())
Expand Down Expand Up @@ -97,7 +97,7 @@ const simOpenProtocol = "/libp2p/simultaneous-connect"

// SelectWithSimopenOrFail performs protocol negotiation with the simultaneous open extension.
// The returned boolean indicator will be true if we should act as a server.
func SelectWithSimopenOrFail(protos []string, rwc io.ReadWriteCloser) (proto string, isServer bool, err error) {
func SelectWithSimopenOrFail[T StringLike](protos []T, rwc io.ReadWriteCloser) (proto T, isServer bool, err error) {
defer func() {
if rerr := recover(); rerr != nil {
fmt.Fprintf(os.Stderr, "caught panic: %s\n%s\n", rerr, debug.Stack())
Expand Down Expand Up @@ -125,7 +125,7 @@ func SelectWithSimopenOrFail(protos []string, rwc io.ReadWriteCloser) (proto str
return "", false, err
}

tok, err := ReadNextToken(rwc)
tok, err := ReadNextToken[T](rwc)
if err != nil {
return "", false, err
}
Expand All @@ -146,13 +146,13 @@ func SelectWithSimopenOrFail(protos []string, rwc io.ReadWriteCloser) (proto str
}
return proto, false, nil
default:
return "", false, errors.New("unexpected response: " + tok)
return "", false, fmt.Errorf("unexpected response: %s", tok)
}
}

func clientOpen(protos []string, rwc io.ReadWriteCloser) (string, error) {
func clientOpen[T StringLike](protos []T, rwc io.ReadWriteCloser) (T, error) {
// check to see if we selected the pipelined protocol
tok, err := ReadNextToken(rwc)
tok, err := ReadNextToken[T](rwc)
if err != nil {
return "", err
}
Expand All @@ -163,11 +163,11 @@ func clientOpen(protos []string, rwc io.ReadWriteCloser) (string, error) {
case "na":
return selectProtosOrFail(protos[1:], rwc)
default:
return "", errors.New("unexpected response: " + tok)
return "", fmt.Errorf("unexpected response: %s", tok)
}
}

func selectProtosOrFail(protos []string, rwc io.ReadWriteCloser) (string, error) {
func selectProtosOrFail[T StringLike](protos []T, rwc io.ReadWriteCloser) (T, error) {
for _, p := range protos {
err := trySelect(p, rwc)
switch err {
Expand All @@ -181,7 +181,7 @@ func selectProtosOrFail(protos []string, rwc io.ReadWriteCloser) (string, error)
return "", ErrNotSupported
}

func simOpen(protos []string, rwc io.ReadWriteCloser) (string, bool, error) {
func simOpen[T StringLike](protos []T, rwc io.ReadWriteCloser) (T, bool, error) {
randBytes := make([]byte, 8)
_, err := rand.Read(randBytes)
if err != nil {
Expand All @@ -198,25 +198,25 @@ func simOpen(protos []string, rwc io.ReadWriteCloser) (string, bool, error) {

// skip exactly one protocol
// see https://github.com/multiformats/go-multistream/pull/42#discussion_r558757135
_, err = ReadNextToken(rwc)
_, err = ReadNextToken[T](rwc)
if err != nil {
return "", false, err
}

// read the tie breaker nonce
tok, err := ReadNextToken(rwc)
tok, err := ReadNextToken[T](rwc)
if err != nil {
return "", false, err
}
if !strings.HasPrefix(tok, tieBreakerPrefix) {
if !strings.HasPrefix(string(tok), tieBreakerPrefix) {
return "", false, errors.New("tie breaker nonce not sent with the correct prefix")
}

if err = <-werrCh; err != nil {
return "", false, err
}

peerNonce, err := strconv.ParseUint(tok[len(tieBreakerPrefix):], 10, 64)
peerNonce, err := strconv.ParseUint(string(tok[len(tieBreakerPrefix):]), 10, 64)
if err != nil {
return "", false, err
}
Expand All @@ -228,7 +228,7 @@ func simOpen(protos []string, rwc io.ReadWriteCloser) (string, bool, error) {
}
iamserver = peerNonce > myNonce

var proto string
var proto T
if iamserver {
proto, err = simOpenSelectServer(protos, rwc)
} else {
Expand All @@ -238,26 +238,26 @@ func simOpen(protos []string, rwc io.ReadWriteCloser) (string, bool, error) {
return proto, iamserver, err
}

func simOpenSelectServer(protos []string, rwc io.ReadWriteCloser) (string, error) {
func simOpenSelectServer[T StringLike](protos []T, rwc io.ReadWriteCloser) (T, error) {
werrCh := make(chan error, 1)
go func() {
err := delimWriteBuffered(rwc, []byte(responder))
werrCh <- err
}()

tok, err := ReadNextToken(rwc)
tok, err := ReadNextToken[T](rwc)
if err != nil {
return "", err
}
if tok != initiator {
return "", errors.New("unexpected response: " + tok)
return "", fmt.Errorf("unexpected response: %s", tok)
}
if err = <-werrCh; err != nil {
return "", err
}

for {
tok, err = ReadNextToken(rwc)
tok, err = ReadNextToken[T](rwc)

if err == io.EOF {
return "", ErrNotSupported
Expand Down Expand Up @@ -286,19 +286,19 @@ func simOpenSelectServer(protos []string, rwc io.ReadWriteCloser) (string, error

}

func simOpenSelectClient(protos []string, rwc io.ReadWriteCloser) (string, error) {
func simOpenSelectClient[T StringLike](protos []T, rwc io.ReadWriteCloser) (T, error) {
werrCh := make(chan error, 1)
go func() {
err := delimWriteBuffered(rwc, []byte(initiator))
werrCh <- err
}()

tok, err := ReadNextToken(rwc)
tok, err := ReadNextToken[T](rwc)
if err != nil {
return "", err
}
if tok != responder {
return "", errors.New("unexpected response: " + tok)
return "", fmt.Errorf("unexpected response: %s", tok)
}
if err = <-werrCh; err != nil {
return "", err
Expand All @@ -308,7 +308,7 @@ func simOpenSelectClient(protos []string, rwc io.ReadWriteCloser) (string, error
}

func readMultistreamHeader(r io.Reader) error {
tok, err := ReadNextToken(r)
tok, err := ReadNextToken[string](r)
if err != nil {
return err
}
Expand All @@ -319,16 +319,16 @@ func readMultistreamHeader(r io.Reader) error {
return nil
}

func trySelect(proto string, rwc io.ReadWriteCloser) error {
func trySelect[T StringLike](proto T, rwc io.ReadWriteCloser) error {
err := delimWriteBuffered(rwc, []byte(proto))
if err != nil {
return err
}
return readProto(proto, rwc)
}

func readProto(proto string, r io.Reader) error {
tok, err := ReadNextToken(r)
func readProto[T StringLike](proto T, r io.Reader) error {
tok, err := ReadNextToken[T](r)
if err != nil {
return err
}
Expand All @@ -339,6 +339,6 @@ func readProto(proto string, r io.Reader) error {
case "na":
return ErrNotSupported
default:
return errors.New("unrecognized response: " + tok)
return fmt.Errorf("unrecognized response: %s", tok)
}
}
32 changes: 16 additions & 16 deletions lazyClient.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,19 +8,19 @@ import (

// NewMSSelect returns a new Multistream which is able to perform
// protocol selection with a MultistreamMuxer.
func NewMSSelect(c io.ReadWriteCloser, proto string) LazyConn {
return &lazyClientConn{
protos: []string{ProtocolID, proto},
func NewMSSelect[T StringLike](c io.ReadWriteCloser, proto T) LazyConn {
return &lazyClientConn[T]{
protos: []T{ProtocolID, proto},
con: c,
}
}

// NewMultistream returns a multistream for the given protocol. This will not
// perform any protocol selection. If you are using a MultistreamMuxer, use
// NewMSSelect.
func NewMultistream(c io.ReadWriteCloser, proto string) LazyConn {
return &lazyClientConn{
protos: []string{proto},
func NewMultistream[T StringLike](c io.ReadWriteCloser, proto T) LazyConn {
return &lazyClientConn[T]{
protos: []T{proto},
con: c,
}
}
Expand All @@ -31,7 +31,7 @@ func NewMultistream(c io.ReadWriteCloser, proto string) LazyConn {
// It *does not* block writes waiting for the other end to respond. Instead, it
// simply assumes the negotiation went successfully and starts writing data.
// See: https://github.com/multiformats/go-multistream/issues/20
type lazyClientConn struct {
type lazyClientConn[T StringLike] struct {
// Used to ensure we only trigger the write half of the handshake once.
rhandshakeOnce sync.Once
rerr error
Expand All @@ -41,7 +41,7 @@ type lazyClientConn struct {
werr error

// The sequence of protocols to negotiate.
protos []string
protos []T

// The inner connection.
con io.ReadWriteCloser
Expand All @@ -53,7 +53,7 @@ type lazyClientConn struct {
// half of the handshake and then waits for the read half to complete.
//
// It returns an error if the read half of the handshake fails.
func (l *lazyClientConn) Read(b []byte) (int, error) {
func (l *lazyClientConn[T]) Read(b []byte) (int, error) {
l.rhandshakeOnce.Do(func() {
go l.whandshakeOnce.Do(l.doWriteHandshake)
l.doReadHandshake()
Expand All @@ -68,10 +68,10 @@ func (l *lazyClientConn) Read(b []byte) (int, error) {
return l.con.Read(b)
}

func (l *lazyClientConn) doReadHandshake() {
func (l *lazyClientConn[T]) doReadHandshake() {
for _, proto := range l.protos {
// read protocol
tok, err := ReadNextToken(l.con)
tok, err := ReadNextToken[T](l.con)
if err != nil {
l.rerr = err
return
Expand All @@ -88,12 +88,12 @@ func (l *lazyClientConn) doReadHandshake() {
}
}

func (l *lazyClientConn) doWriteHandshake() {
func (l *lazyClientConn[T]) doWriteHandshake() {
l.doWriteHandshakeWithData(nil)
}

// Perform the write handshake but *also* write some extra data.
func (l *lazyClientConn) doWriteHandshakeWithData(extra []byte) int {
func (l *lazyClientConn[T]) doWriteHandshakeWithData(extra []byte) int {
buf := getWriter(l.con)
defer putWriter(buf)

Expand Down Expand Up @@ -122,7 +122,7 @@ func (l *lazyClientConn) doWriteHandshakeWithData(extra []byte) int {
//
// Write *also* ignores errors from the read half of the handshake (in case the
// stream is actually write only).
func (l *lazyClientConn) Write(b []byte) (int, error) {
func (l *lazyClientConn[T]) Write(b []byte) (int, error) {
n := 0
l.whandshakeOnce.Do(func() {
go l.rhandshakeOnce.Do(l.doReadHandshake)
Expand All @@ -137,7 +137,7 @@ func (l *lazyClientConn) Write(b []byte) (int, error) {
// Close closes the underlying io.ReadWriteCloser
//
// This does not flush anything.
func (l *lazyClientConn) Close() error {
func (l *lazyClientConn[T]) Close() error {
// As the client, we flush the handshake on close to cover an
// interesting edge-case where the server only speaks a single protocol
// and responds eagerly with that protocol before waiting for out
Expand All @@ -151,7 +151,7 @@ func (l *lazyClientConn) Close() error {
}

// Flush sends the handshake.
func (l *lazyClientConn) Flush() error {
func (l *lazyClientConn[T]) Flush() error {
l.whandshakeOnce.Do(func() {
go l.rhandshakeOnce.Do(l.doReadHandshake)
l.doWriteHandshake()
Expand Down
Loading