Skip to content

Commit

Permalink
feat: filter support
Browse files Browse the repository at this point in the history
  • Loading branch information
shoriwe committed Jun 12, 2023
1 parent c9f30d7 commit 64d48ed
Show file tree
Hide file tree
Showing 3 changed files with 304 additions and 16 deletions.
103 changes: 103 additions & 0 deletions compose/filter.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
package compose

import (
"net"
"regexp"

"github.com/shoriwe/fullproxy/v4/filter"
"github.com/shoriwe/fullproxy/v4/utils/network"
)

type PortRange struct {
From *int `yaml:"from,omitempty" json:"from,omitempty"`
To *int `yaml:"to,omitempty" json:"to,omitempty"`
}

type Match struct {
Host *string `yaml:"host,omitempty" json:"host,omitempty"`
Port *int `yaml:"port,omitempty" json:"port,omitempty"`
Range *PortRange `yaml:"portRange,omitempty" json:"portRange,omitempty"`
}

func (m *Match) Compile() (filter.Match, error) {
var (
err error
host *regexp.Regexp
port int = -1
from int = -1
to int = -1
)
if m.Host != nil {
host, err = regexp.Compile(*m.Host)
}
if m.Port != nil {
port = *m.Port
}
if m.Range != nil {
if m.Range.From != nil {
from = *m.Range.From
}
if m.Range.To != nil {
to = *m.Range.To
}
}
f := filter.Match{
Host: host,
Port: port,
PortRange: [2]int{from, to},
}
return f, err
}

type Filter struct {
Whitelist []Match `yaml:"whitelist,omitempty" json:"whitelist,omitempty"`
Blacklist []Match `yaml:"blacklist,omitempty" json:"blacklist,omitempty"`
}

func (f *Filter) Listener(l net.Listener) (net.Listener, error) {
var whitelist, blacklist []filter.Match
for _, white := range f.Whitelist {
compiled, err := white.Compile()
if err != nil {
return nil, err
}
whitelist = append(whitelist, compiled)
}
for _, black := range f.Blacklist {
compiled, err := black.Compile()
if err != nil {
return nil, err
}
blacklist = append(blacklist, compiled)
}
ll := &filter.Listener{
Listener: l,
Whitelist: whitelist,
Blacklist: blacklist,
}
return ll, nil
}

func (f *Filter) DialFunc(dialFunc network.DialFunc) (*filter.DialFunc, error) {
var whitelist, blacklist []filter.Match
for _, white := range f.Whitelist {
compiled, err := white.Compile()
if err != nil {
return nil, err
}
whitelist = append(whitelist, compiled)
}
for _, black := range f.Blacklist {
compiled, err := black.Compile()
if err != nil {
return nil, err
}
blacklist = append(blacklist, compiled)
}
df := &filter.DialFunc{
DialFunc: dialFunc,
Whitelist: whitelist,
Blacklist: blacklist,
}
return df, nil
}
170 changes: 170 additions & 0 deletions compose/filter_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,170 @@
package compose

import (
"net"
"regexp"
"testing"

"github.com/shoriwe/fullproxy/v4/filter"
"github.com/shoriwe/fullproxy/v4/utils/network"
"github.com/stretchr/testify/assert"
)

func TestMatch_Compile(t *testing.T) {
t.Run("Valid", func(tt *testing.T) {
m := Match{
Host: new(string),
Port: new(int),
Range: &PortRange{
From: new(int),
To: new(int),
},
}
*m.Host = "127.0.0.1"
*m.Port = 80
*m.Range.From = 0
*m.Range.To = 8000
match, err := m.Compile()
assert.Nil(tt, err)
expect := filter.Match{
Host: regexp.MustCompile("127.0.0.1"),
Port: 80,
PortRange: [2]int{0, 8000},
}
assert.Equal(tt, expect, match)
})
}

func TestFilter_Listener(t *testing.T) {
t.Run("Valid", func(tt *testing.T) {
m := Match{
Host: new(string),
Port: new(int),
Range: &PortRange{
From: new(int),
To: new(int),
},
}
*m.Host = "127.0.0.1"
*m.Port = 80
*m.Range.From = 0
*m.Range.To = 8000
f := Filter{
Whitelist: []Match{m},
Blacklist: []Match{m},
}
l := network.ListenAny()
defer l.Close()
ll, err := f.Listener(l)
assert.Nil(tt, err)
defer ll.Close()
})
t.Run("Invalid Whitelist", func(tt *testing.T) {
m := Match{
Host: new(string),
Port: new(int),
Range: &PortRange{
From: new(int),
To: new(int),
},
}
*m.Host = ")"
*m.Port = 80
*m.Range.From = 0
*m.Range.To = 8000
f := Filter{
Whitelist: []Match{m},
// Blacklist: []Match{m},
}
l := network.ListenAny()
defer l.Close()
_, err := f.Listener(l)
assert.NotNil(tt, err)
})
t.Run("Invalid Blacklist", func(tt *testing.T) {
m := Match{
Host: new(string),
Port: new(int),
Range: &PortRange{
From: new(int),
To: new(int),
},
}
*m.Host = ")"
*m.Port = 80
*m.Range.From = 0
*m.Range.To = 8000
f := Filter{
// Whitelist: []Match{m},
Blacklist: []Match{m},
}
l := network.ListenAny()
defer l.Close()
_, err := f.Listener(l)
assert.NotNil(tt, err)
})
}

func TestFilter_DialFunc(t *testing.T) {
t.Run("Valid", func(tt *testing.T) {
m := Match{
Host: new(string),
Port: new(int),
Range: &PortRange{
From: new(int),
To: new(int),
},
}
*m.Host = "127.0.0.1"
*m.Port = 80
*m.Range.From = 0
*m.Range.To = 8000
f := Filter{
Whitelist: []Match{m},
Blacklist: []Match{m},
}
df, err := f.DialFunc(net.Dial)
assert.Nil(tt, err)
assert.NotNil(tt, df)
})
t.Run("Invalid Whitelist", func(tt *testing.T) {
m := Match{
Host: new(string),
Port: new(int),
Range: &PortRange{
From: new(int),
To: new(int),
},
}
*m.Host = ")"
*m.Port = 80
*m.Range.From = 0
*m.Range.To = 8000
f := Filter{
Whitelist: []Match{m},
// Blacklist: []Match{m},
}
_, err := f.DialFunc(net.Dial)
assert.NotNil(tt, err)
})
t.Run("Invalid Blacklist", func(tt *testing.T) {
m := Match{
Host: new(string),
Port: new(int),
Range: &PortRange{
From: new(int),
To: new(int),
},
}
*m.Host = ")"
*m.Port = 80
*m.Range.From = 0
*m.Range.To = 8000
f := Filter{
// Whitelist: []Match{m},
Blacklist: []Match{m},
}
_, err := f.DialFunc(net.Dial)
assert.NotNil(tt, err)
})
}
47 changes: 31 additions & 16 deletions compose/network.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"fmt"
"net"

"github.com/shoriwe/fullproxy/v4/filter"
"github.com/shoriwe/fullproxy/v4/reverse"
"github.com/shoriwe/fullproxy/v4/sshd"
"github.com/shoriwe/fullproxy/v4/utils/network"
Expand All @@ -17,17 +18,19 @@ const (
)

type Network struct {
Type string `yaml:"type" json:"type"`
Network *string `yaml:"network,omitempty" json:"network,omitempty"`
Address *string `yaml:"address,omitempty" json:"address,omitempty"`
Data *Network `yaml:"data,omitempty" json:"data,omitempty"`
Control *Network `yaml:"control,omitempty" json:"control,omitempty"`
Auth *Auth `yaml:"auth,omitempty" json:"auth,omitempty"`
Crypto *Crypto `yaml:"crypto,omitempty" json:"crypto,omitempty"`
SlaveListener *bool `yaml:"slaveListener,omitempty" json:"slaveListener,omitempty"`
master *reverse.Master
sshConn *ssh.Client
listener net.Listener
Type string `yaml:"type" json:"type"`
Network *string `yaml:"network,omitempty" json:"network,omitempty"`
Address *string `yaml:"address,omitempty" json:"address,omitempty"`
Data *Network `yaml:"data,omitempty" json:"data,omitempty"`
Control *Network `yaml:"control,omitempty" json:"control,omitempty"`
Auth *Auth `yaml:"auth,omitempty" json:"auth,omitempty"`
Crypto *Crypto `yaml:"crypto,omitempty" json:"crypto,omitempty"`
SlaveListener *bool `yaml:"slaveListener,omitempty" json:"slaveListener,omitempty"`
ListenerFilter *Filter `yaml:"listenerFilter,omitempty" json:"listenerFilter,omitempty"`
DialFilter *Filter `yaml:"dialFilter,omitempty" json:"dialFilter,omitempty"`
master *reverse.Master
sshConn *ssh.Client
listener net.Listener
}

func (n *Network) setupBasicListener(listen network.ListenFunc) (_ net.Listener, err error) {
Expand Down Expand Up @@ -152,9 +155,13 @@ func (n *Network) Listen() (ll net.Listener, err error) {
default:
err = fmt.Errorf("unknown network type %s", n.Type)
}
network.CloseOnError(&err, ll)
if err == nil && n.Crypto != nil {
ll, err = n.Crypto.WrapListener(ll)
}
if err == nil && n.ListenerFilter != nil {
ll, err = n.ListenerFilter.Listener(ll)
}
return ll, err
}

Expand All @@ -174,15 +181,23 @@ func (n *Network) setupSSHDialFunc() (dialFunc network.DialFunc, err error) {
return dialFunc, err
}

func (n *Network) DialFunc() (network.DialFunc, error) {
func (n *Network) DialFunc() (dialFunc network.DialFunc, err error) {
switch n.Type {
case NetworkBasic:
return net.Dial, nil
dialFunc, err = net.Dial, nil
case NetworkMaster:
return n.setupMasterDialFunc()
dialFunc, err = n.setupMasterDialFunc()
case NetworkSSH:
return n.setupSSHDialFunc()
dialFunc, err = n.setupSSHDialFunc()
default:
return nil, fmt.Errorf("unknown network type %s", n.Type)
dialFunc, err = nil, fmt.Errorf("unknown network type %s", n.Type)
}
if err == nil && n.DialFilter != nil {
var df *filter.DialFunc
df, err = n.DialFilter.DialFunc(dialFunc)
if err == nil {
dialFunc = df.Dial
}
}
return dialFunc, err
}

0 comments on commit 64d48ed

Please sign in to comment.