diff --git a/options_filter.go b/options_filter.go index c8be5ea2a4..c6b8175d8e 100644 --- a/options_filter.go +++ b/options_filter.go @@ -23,12 +23,12 @@ func (f *filtersConnectionGater) InterceptPeerDial(p peer.ID) (allow bool) { return true } -func (f *filtersConnectionGater) InterceptAccept(_ network.ConnMultiaddrs) (allow bool) { - return true +func (f *filtersConnectionGater) InterceptAccept(connAddr network.ConnMultiaddrs) (allow bool) { + return !(*ma.Filters)(f).AddrBlocked(connAddr.RemoteMultiaddr()) } -func (f *filtersConnectionGater) InterceptSecured(_ network.Direction, _ peer.ID, _ network.ConnMultiaddrs) (allow bool) { - return true +func (f *filtersConnectionGater) InterceptSecured(_ network.Direction, _ peer.ID, connAddr network.ConnMultiaddrs) (allow bool) { + return !(*ma.Filters)(f).AddrBlocked(connAddr.RemoteMultiaddr()) } func (f *filtersConnectionGater) InterceptUpgraded(_ network.Conn) (allow bool, reason control.DisconnectReason) { diff --git a/options_test.go b/options_test.go index 4c899aa081..ba8c0cbfba 100644 --- a/options_test.go +++ b/options_test.go @@ -10,19 +10,20 @@ import ( "github.com/libp2p/go-libp2p-core/test" ma "github.com/multiformats/go-multiaddr" + manet "github.com/multiformats/go-multiaddr/net" "github.com/stretchr/testify/require" ) -func TestDeprecatedFiltersOptions(t *testing.T) { +func TestDeprecatedFiltersOptionsOutbound(t *testing.T) { require := require.New(t) f := ma.NewFilters() _, ipnet, _ := net.ParseCIDR("127.0.0.0/24") f.AddFilter(*ipnet, ma.ActionDeny) - host, err := New(context.TODO(), Filters(f)) + host0, err := New(context.TODO(), Filters(f)) require.NoError(err) - require.NotNil(host) + require.NotNil(host0) ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) defer cancel() @@ -31,11 +32,54 @@ func TestDeprecatedFiltersOptions(t *testing.T) { addr, _ := ma.NewMultiaddr("/ip4/127.0.0.1/tcp/0/p2p/" + id.Pretty()) ai, _ := peer.AddrInfoFromP2pAddr(addr) - err = host.Connect(ctx, *ai) + err = host0.Connect(ctx, *ai) require.Error(err) require.Contains(err.Error(), "no good addresses") } +var ( + ip4FullMask = net.IPMask{255, 255, 255, 255} + ip6FullMask = net.IPMask{255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255} +) + +func TestDeprecatedFiltersOptionsInbound(t *testing.T) { + require := require.New(t) + + host0, err := New(context.TODO()) + require.NoError(err) + require.NotNil(host0) + + f := ma.NewFilters() + for _, addr := range host0.Addrs() { + ip, err := manet.ToIP(addr) + require.NoError(err) + require.NotNil(t, ip) + + var mask net.IPMask + if ip.To4() != nil { + mask = ip4FullMask + } else { + mask = ip6FullMask + } + + ipnet := net.IPNet{IP: ip, Mask: mask} + f.AddFilter(ipnet, ma.ActionDeny) + } + host1, err := New(context.TODO(), Filters(f)) + require.NoError(err) + require.NotNil(host1) + + ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) + defer cancel() + + peerInfo := peer.AddrInfo{ + ID: host1.ID(), + Addrs: host1.Addrs(), + } + err = host0.Connect(ctx, peerInfo) + require.Error(err) +} + func TestDeprecatedFiltersAndAddressesOptions(t *testing.T) { require := require.New(t)