Skip to content

Commit

Permalink
implement middleware compose
Browse files Browse the repository at this point in the history
  • Loading branch information
yusing committed Oct 1, 2024
1 parent f5a36f9 commit 44cfd65
Show file tree
Hide file tree
Showing 17 changed files with 391 additions and 151 deletions.
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@

A lightweight, easy-to-use, and [performant](docs/benchmark_result.md) reverse proxy with a web UI.

_Join our [Discord](https://discord.gg/umReR62nRd) for help and discussions_

## Table of content

<!-- TOC -->
Expand Down
4 changes: 3 additions & 1 deletion internal/error/builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package error

import (
"fmt"
"strings"
"sync"
)

Expand All @@ -24,7 +25,6 @@ func NewBuilder(format string, args ...any) Builder {
func (b Builder) Add(err NestedError) Builder {
if err != nil {
b.Lock()
// TODO: if err severity is higher than b.severity, update b.severity
b.errors = append(b.errors, err)
b.Unlock()
}
Expand All @@ -49,6 +49,8 @@ func (b Builder) Addf(format string, args ...any) Builder {
func (b Builder) Build() NestedError {
if len(b.errors) == 0 {
return nil
} else if len(b.errors) == 1 && !strings.ContainsRune(b.message, ' ') {
return b.errors[0].Subject(b.message)
}
return Join(b.message, b.errors...)
}
Expand Down
2 changes: 2 additions & 0 deletions internal/error/error.go
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,8 @@ func (ne NestedError) Subject(s any) NestedError {
}
if ne.subject == "" {
ne.subject = subject
} else if !strings.ContainsRune(subject, ' ') || strings.ContainsRune(ne.subject, '.') {
ne.subject = fmt.Sprintf("%s.%s", subject, ne.subject)
} else {
ne.subject = fmt.Sprintf("%s > %s", subject, ne.subject)
}
Expand Down
31 changes: 20 additions & 11 deletions internal/error/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,16 @@ import (
)

var (
ErrFailure = stderrors.New("failed")
ErrInvalid = stderrors.New("invalid")
ErrUnsupported = stderrors.New("unsupported")
ErrUnexpected = stderrors.New("unexpected")
ErrNotExists = stderrors.New("does not exist")
ErrMissing = stderrors.New("missing")
ErrDuplicated = stderrors.New("duplicated")
ErrOutOfRange = stderrors.New("out of range")
ErrTypeError = stderrors.New("type error")
ErrFailure = stderrors.New("failed")
ErrInvalid = stderrors.New("invalid")
ErrUnsupported = stderrors.New("unsupported")
ErrUnexpected = stderrors.New("unexpected")
ErrNotExists = stderrors.New("does not exist")
ErrMissing = stderrors.New("missing")
ErrDuplicated = stderrors.New("duplicated")
ErrOutOfRange = stderrors.New("out of range")
ErrTypeError = stderrors.New("type error")
ErrTypeMismatch = stderrors.New("type mismatch")
)

const fmtSubjectWhat = "%w %v: %q"
Expand Down Expand Up @@ -63,6 +64,14 @@ func OutOfRange(subject any, value any) NestedError {
return errorf("%v %w: %v", subject, ErrOutOfRange, value)
}

func TypeError(subject any, from, to reflect.Value) NestedError {
return errorf("%v %w: %T -> %T", subject, ErrTypeError, from.Interface(), to.Interface())
func TypeError(subject any, from, to reflect.Type) NestedError {
return errorf("%v %w: %s -> %s\n", subject, ErrTypeError, from, to)
}

func TypeError2(subject any, from, to reflect.Value) NestedError {
return TypeError(subject, from.Type(), to.Type())
}

func TypeMismatch[Expect any](value any) NestedError {
return errorf("%w: expect %s got %T", ErrTypeMismatch, reflect.TypeFor[Expect](), value)
}
11 changes: 6 additions & 5 deletions internal/net/http/middleware/cloudflare_real_ip.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (
"github.com/sirupsen/logrus"
"github.com/yusing/go-proxy/internal/common"
E "github.com/yusing/go-proxy/internal/error"
"github.com/yusing/go-proxy/internal/types"
)

const (
Expand Down Expand Up @@ -53,7 +54,7 @@ func NewCloudflareRealIP(_ OptionsRaw) (*Middleware, E.NestedError) {
return cri.m, nil
}

func tryFetchCFCIDR() (cfCIDRs []*net.IPNet) {
func tryFetchCFCIDR() (cfCIDRs []*types.CIDR) {
if time.Since(cfCIDRsLastUpdate) < cfCIDRsUpdateInterval {
return
}
Expand All @@ -66,14 +67,14 @@ func tryFetchCFCIDR() (cfCIDRs []*net.IPNet) {
}

if common.IsTest {
cfCIDRs = []*net.IPNet{
cfCIDRs = []*types.CIDR{
{IP: net.IPv4(127, 0, 0, 1), Mask: net.IPv4Mask(255, 0, 0, 0)},
{IP: net.IPv4(10, 0, 0, 0), Mask: net.IPv4Mask(255, 0, 0, 0)},
{IP: net.IPv4(172, 16, 0, 0), Mask: net.IPv4Mask(255, 255, 0, 0)},
{IP: net.IPv4(192, 168, 0, 0), Mask: net.IPv4Mask(255, 255, 255, 0)},
}
} else {
cfCIDRs = make([]*net.IPNet, 0, 30)
cfCIDRs = make([]*types.CIDR, 0, 30)
err := errors.Join(
fetchUpdateCFIPRange(cfIPv4CIDRsEndpoint, cfCIDRs),
fetchUpdateCFIPRange(cfIPv6CIDRsEndpoint, cfCIDRs),
Expand All @@ -90,7 +91,7 @@ func tryFetchCFCIDR() (cfCIDRs []*net.IPNet) {
return
}

func fetchUpdateCFIPRange(endpoint string, cfCIDRs []*net.IPNet) error {
func fetchUpdateCFIPRange(endpoint string, cfCIDRs []*types.CIDR) error {
resp, err := http.Get(endpoint)
if err != nil {
return err
Expand All @@ -110,7 +111,7 @@ func fetchUpdateCFIPRange(endpoint string, cfCIDRs []*net.IPNet) error {
if err != nil {
return fmt.Errorf("cloudflare responeded an invalid CIDR: %s", line)
} else {
cfCIDRs = append(cfCIDRs, cidr)
cfCIDRs = append(cfCIDRs, (*types.CIDR)(cidr))
}
}

Expand Down
12 changes: 10 additions & 2 deletions internal/net/http/middleware/middleware.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package middleware

import (
"encoding/json"
"net/http"

D "github.com/yusing/go-proxy/internal/docker"
Expand Down Expand Up @@ -53,7 +54,14 @@ func (m *Middleware) String() string {
return m.name
}

func (m *Middleware) WithOptionsClone(optsRaw OptionsRaw, rp *ReverseProxy) (*Middleware, E.NestedError) {
func (m *Middleware) MarshalJSON() ([]byte, error) {
return json.MarshalIndent(map[string]any{
"name": m.name,
"options": m.impl,
}, "", " ")
}

func (m *Middleware) WithOptionsClone(optsRaw OptionsRaw) (*Middleware, E.NestedError) {
if len(optsRaw) != 0 && m.withOptions != nil {
if mWithOpt, err := m.withOptions(optsRaw); err != nil {
return nil, err
Expand Down Expand Up @@ -87,7 +95,7 @@ func PatchReverseProxy(rp *ReverseProxy, middlewares map[string]OptionsRaw) (res
continue
}

m, err := m.WithOptionsClone(opts, rp)
m, err := m.WithOptionsClone(opts)
if err != nil {
invalidOpts.Add(err.Subject(name))
continue
Expand Down
27 changes: 15 additions & 12 deletions internal/net/http/middleware/middleware_builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,24 +8,27 @@ import (
"gopkg.in/yaml.v3"
)

func BuildMiddlewaresFromYAML(filePath string) (middlewares map[string]*Middleware, outErr E.NestedError) {
b := E.NewBuilder("middlewares compile errors")
defer b.To(&outErr)

var data map[string][]map[string]any
func BuildMiddlewaresFromComposeFile(filePath string) (map[string]*Middleware, E.NestedError) {
fileContent, err := os.ReadFile(filePath)
if err != nil {
b.Add(E.FailWith("read file", err))
return
return nil, E.FailWith("read middleware compose file", err)
}
err = yaml.Unmarshal(fileContent, &data)
return BuildMiddlewaresFromYAML(fileContent)
}

func BuildMiddlewaresFromYAML(data []byte) (middlewares map[string]*Middleware, outErr E.NestedError) {
b := E.NewBuilder("middlewares compile errors")
defer b.To(&outErr)

var rawMap map[string][]map[string]any
err := yaml.Unmarshal(data, &rawMap)
if err != nil {
b.Add(E.FailWith("toml unmarshal", err))
return
}
middlewares = make(map[string]*Middleware)
for name, defs := range data {
chainErr := E.NewBuilder("errors in middleware chain %s", name)
for name, defs := range rawMap {
chainErr := E.NewBuilder(name)
chain := make([]*Middleware, 0, len(defs))
for i, def := range defs {
if def["use"] == nil || def["use"].(string) == "" {
Expand All @@ -39,9 +42,9 @@ func BuildMiddlewaresFromYAML(filePath string) (middlewares map[string]*Middlewa
continue
}
delete(def, "use")
m, err := base.withOptions(def)
m, err := base.WithOptionsClone(def)
if err != nil {
chainErr.Add(err.Subjectf("%s.%d", name, i))
chainErr.Add(err.Subjectf("item%d", i))
continue
}
chain = append(chain, m)
Expand Down
14 changes: 13 additions & 1 deletion internal/net/http/middleware/middleware_builder_test.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,21 @@
package middleware

import (
_ "embed"
"encoding/json"
"testing"

E "github.com/yusing/go-proxy/internal/error"
. "github.com/yusing/go-proxy/internal/utils/testing"
)

func TestBuild(t *testing.T) {
//go:embed test_data/middleware_compose.yml
var testMiddlewareCompose []byte

func TestBuild(t *testing.T) {
// middlewares, err := BuildMiddlewaresFromYAML(testMiddlewareCompose)
// ExpectNoError(t, err.Error())
data, err := E.Check(json.MarshalIndent(middlewares, "", " "))
ExpectNoError(t, err.Error())
t.Log(string(data))
}
28 changes: 15 additions & 13 deletions internal/net/http/middleware/middlewares.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package middleware

import (
"fmt"
"net/http"
"path"
"strings"

Expand All @@ -15,27 +16,27 @@ import (
var middlewares map[string]*Middleware

func Get(name string) (middleware *Middleware, ok bool) {
middleware, ok = middlewares[name]
middleware, ok = middlewares[strings.ToLower(name)]
return
}

// initialize middleware names and label parsers
func init() {
middlewares = map[string]*Middleware{
"set_x_forwarded": SetXForwarded,
"hide_x_forwarded": HideXForwarded,
"redirect_http": RedirectHTTP,
"forward_auth": ForwardAuth.m,
"modify_response": ModifyResponse.m,
"modify_request": ModifyRequest.m,
"error_page": CustomErrorPage,
"custom_error_page": CustomErrorPage,
"real_ip": RealIP.m,
"cloudflare_real_ip": CloudflareRealIP.m,
"setxforwarded": SetXForwarded,
"hidexforwarded": HideXForwarded,
"redirecthttp": RedirectHTTP,
"forwardauth": ForwardAuth.m,
"modifyresponse": ModifyResponse.m,
"modifyrequest": ModifyRequest.m,
"errorpage": CustomErrorPage,
"customerrorpage": CustomErrorPage,
"realip": RealIP.m,
"cloudflarerealip": CloudflareRealIP.m,
}
names := make(map[*Middleware][]string)
for name, m := range middlewares {
names[m] = append(names[m], name)
names[m] = append(names[m], http.CanonicalHeaderKey(name))
// register middleware name to docker label parsr
// in order to parse middleware_name.option=value into correct type
if m.labelParserMap != nil {
Expand All @@ -49,6 +50,7 @@ func init() {
m.name = names[0]
}
}

// TODO: seperate from init()
b := E.NewBuilder("failed to load middlewares")
middlewareDefs, err := U.ListFiles(common.MiddlewareDefsBasePath, 0)
Expand All @@ -57,7 +59,7 @@ func init() {
return
}
for _, defFile := range middlewareDefs {
mws, err := BuildMiddlewaresFromYAML(defFile)
mws, err := BuildMiddlewaresFromComposeFile(defFile)
for name, m := range mws {
if _, ok := middlewares[name]; ok {
b.Add(E.Duplicated("middleware", name))
Expand Down
2 changes: 1 addition & 1 deletion internal/net/http/middleware/modify_request_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ func TestSetModifyRequest(t *testing.T) {
}

t.Run("set_options", func(t *testing.T) {
mr, err := ModifyRequest.m.WithOptionsClone(opts, nil)
mr, err := ModifyRequest.m.WithOptionsClone(opts)
ExpectNoError(t, err.Error())
ExpectDeepEqual(t, mr.impl.(*modifyRequest).SetHeaders, opts["set_headers"].(map[string]string))
ExpectDeepEqual(t, mr.impl.(*modifyRequest).AddHeaders, opts["add_headers"].(map[string]string))
Expand Down
2 changes: 1 addition & 1 deletion internal/net/http/middleware/modify_response_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ func TestSetModifyResponse(t *testing.T) {
}

t.Run("set_options", func(t *testing.T) {
mr, err := ModifyResponse.m.WithOptionsClone(opts, nil)
mr, err := ModifyResponse.m.WithOptionsClone(opts)
ExpectNoError(t, err.Error())
ExpectDeepEqual(t, mr.impl.(*modifyResponse).SetHeaders, opts["set_headers"].(map[string]string))
ExpectDeepEqual(t, mr.impl.(*modifyResponse).AddHeaders, opts["add_headers"].(map[string]string))
Expand Down
Loading

0 comments on commit 44cfd65

Please sign in to comment.