diff --git a/README.md b/README.md index b42b1ba0..a8f9aafb 100755 --- a/README.md +++ b/README.md @@ -11,6 +11,8 @@ A lightweight, easy-to-use, and [performant](docs/benchmark_result.md) reverse proxy with a web UI. +_Join our [Discord](https://discord.gg/umReR62nRd) for help and discussions_ + ## Table of content diff --git a/internal/error/builder.go b/internal/error/builder.go index 813b90c4..2950923b 100644 --- a/internal/error/builder.go +++ b/internal/error/builder.go @@ -2,6 +2,7 @@ package error import ( "fmt" + "strings" "sync" ) @@ -24,7 +25,6 @@ func NewBuilder(format string, args ...any) Builder { func (b Builder) Add(err NestedError) Builder { if err != nil { b.Lock() - // TODO: if err severity is higher than b.severity, update b.severity b.errors = append(b.errors, err) b.Unlock() } @@ -49,6 +49,8 @@ func (b Builder) Addf(format string, args ...any) Builder { func (b Builder) Build() NestedError { if len(b.errors) == 0 { return nil + } else if len(b.errors) == 1 && !strings.ContainsRune(b.message, ' ') { + return b.errors[0].Subject(b.message) } return Join(b.message, b.errors...) } diff --git a/internal/error/error.go b/internal/error/error.go index c4fa150a..33627cb0 100644 --- a/internal/error/error.go +++ b/internal/error/error.go @@ -166,6 +166,8 @@ func (ne NestedError) Subject(s any) NestedError { } if ne.subject == "" { ne.subject = subject + } else if !strings.ContainsRune(subject, ' ') || strings.ContainsRune(ne.subject, '.') { + ne.subject = fmt.Sprintf("%s.%s", subject, ne.subject) } else { ne.subject = fmt.Sprintf("%s > %s", subject, ne.subject) } diff --git a/internal/error/errors.go b/internal/error/errors.go index 1c437bbb..8728c732 100644 --- a/internal/error/errors.go +++ b/internal/error/errors.go @@ -6,15 +6,16 @@ import ( ) var ( - ErrFailure = stderrors.New("failed") - ErrInvalid = stderrors.New("invalid") - ErrUnsupported = stderrors.New("unsupported") - ErrUnexpected = stderrors.New("unexpected") - ErrNotExists = stderrors.New("does not exist") - ErrMissing = stderrors.New("missing") - ErrDuplicated = stderrors.New("duplicated") - ErrOutOfRange = stderrors.New("out of range") - ErrTypeError = stderrors.New("type error") + ErrFailure = stderrors.New("failed") + ErrInvalid = stderrors.New("invalid") + ErrUnsupported = stderrors.New("unsupported") + ErrUnexpected = stderrors.New("unexpected") + ErrNotExists = stderrors.New("does not exist") + ErrMissing = stderrors.New("missing") + ErrDuplicated = stderrors.New("duplicated") + ErrOutOfRange = stderrors.New("out of range") + ErrTypeError = stderrors.New("type error") + ErrTypeMismatch = stderrors.New("type mismatch") ) const fmtSubjectWhat = "%w %v: %q" @@ -63,6 +64,14 @@ func OutOfRange(subject any, value any) NestedError { return errorf("%v %w: %v", subject, ErrOutOfRange, value) } -func TypeError(subject any, from, to reflect.Value) NestedError { - return errorf("%v %w: %T -> %T", subject, ErrTypeError, from.Interface(), to.Interface()) +func TypeError(subject any, from, to reflect.Type) NestedError { + return errorf("%v %w: %s -> %s\n", subject, ErrTypeError, from, to) +} + +func TypeError2(subject any, from, to reflect.Value) NestedError { + return TypeError(subject, from.Type(), to.Type()) +} + +func TypeMismatch[Expect any](value any) NestedError { + return errorf("%w: expect %s got %T", ErrTypeMismatch, reflect.TypeFor[Expect](), value) } diff --git a/internal/net/http/middleware/cloudflare_real_ip.go b/internal/net/http/middleware/cloudflare_real_ip.go index 1bd83d9d..05dbe448 100644 --- a/internal/net/http/middleware/cloudflare_real_ip.go +++ b/internal/net/http/middleware/cloudflare_real_ip.go @@ -13,6 +13,7 @@ import ( "github.com/sirupsen/logrus" "github.com/yusing/go-proxy/internal/common" E "github.com/yusing/go-proxy/internal/error" + "github.com/yusing/go-proxy/internal/types" ) const ( @@ -53,7 +54,7 @@ func NewCloudflareRealIP(_ OptionsRaw) (*Middleware, E.NestedError) { return cri.m, nil } -func tryFetchCFCIDR() (cfCIDRs []*net.IPNet) { +func tryFetchCFCIDR() (cfCIDRs []*types.CIDR) { if time.Since(cfCIDRsLastUpdate) < cfCIDRsUpdateInterval { return } @@ -66,14 +67,14 @@ func tryFetchCFCIDR() (cfCIDRs []*net.IPNet) { } if common.IsTest { - cfCIDRs = []*net.IPNet{ + cfCIDRs = []*types.CIDR{ {IP: net.IPv4(127, 0, 0, 1), Mask: net.IPv4Mask(255, 0, 0, 0)}, {IP: net.IPv4(10, 0, 0, 0), Mask: net.IPv4Mask(255, 0, 0, 0)}, {IP: net.IPv4(172, 16, 0, 0), Mask: net.IPv4Mask(255, 255, 0, 0)}, {IP: net.IPv4(192, 168, 0, 0), Mask: net.IPv4Mask(255, 255, 255, 0)}, } } else { - cfCIDRs = make([]*net.IPNet, 0, 30) + cfCIDRs = make([]*types.CIDR, 0, 30) err := errors.Join( fetchUpdateCFIPRange(cfIPv4CIDRsEndpoint, cfCIDRs), fetchUpdateCFIPRange(cfIPv6CIDRsEndpoint, cfCIDRs), @@ -90,7 +91,7 @@ func tryFetchCFCIDR() (cfCIDRs []*net.IPNet) { return } -func fetchUpdateCFIPRange(endpoint string, cfCIDRs []*net.IPNet) error { +func fetchUpdateCFIPRange(endpoint string, cfCIDRs []*types.CIDR) error { resp, err := http.Get(endpoint) if err != nil { return err @@ -110,7 +111,7 @@ func fetchUpdateCFIPRange(endpoint string, cfCIDRs []*net.IPNet) error { if err != nil { return fmt.Errorf("cloudflare responeded an invalid CIDR: %s", line) } else { - cfCIDRs = append(cfCIDRs, cidr) + cfCIDRs = append(cfCIDRs, (*types.CIDR)(cidr)) } } diff --git a/internal/net/http/middleware/middleware.go b/internal/net/http/middleware/middleware.go index 249d0792..ec18fb38 100644 --- a/internal/net/http/middleware/middleware.go +++ b/internal/net/http/middleware/middleware.go @@ -1,6 +1,7 @@ package middleware import ( + "encoding/json" "net/http" D "github.com/yusing/go-proxy/internal/docker" @@ -53,7 +54,14 @@ func (m *Middleware) String() string { return m.name } -func (m *Middleware) WithOptionsClone(optsRaw OptionsRaw, rp *ReverseProxy) (*Middleware, E.NestedError) { +func (m *Middleware) MarshalJSON() ([]byte, error) { + return json.MarshalIndent(map[string]any{ + "name": m.name, + "options": m.impl, + }, "", " ") +} + +func (m *Middleware) WithOptionsClone(optsRaw OptionsRaw) (*Middleware, E.NestedError) { if len(optsRaw) != 0 && m.withOptions != nil { if mWithOpt, err := m.withOptions(optsRaw); err != nil { return nil, err @@ -87,7 +95,7 @@ func PatchReverseProxy(rp *ReverseProxy, middlewares map[string]OptionsRaw) (res continue } - m, err := m.WithOptionsClone(opts, rp) + m, err := m.WithOptionsClone(opts) if err != nil { invalidOpts.Add(err.Subject(name)) continue diff --git a/internal/net/http/middleware/middleware_builder.go b/internal/net/http/middleware/middleware_builder.go index 0ad89b8e..78e7c967 100644 --- a/internal/net/http/middleware/middleware_builder.go +++ b/internal/net/http/middleware/middleware_builder.go @@ -8,24 +8,27 @@ import ( "gopkg.in/yaml.v3" ) -func BuildMiddlewaresFromYAML(filePath string) (middlewares map[string]*Middleware, outErr E.NestedError) { - b := E.NewBuilder("middlewares compile errors") - defer b.To(&outErr) - - var data map[string][]map[string]any +func BuildMiddlewaresFromComposeFile(filePath string) (map[string]*Middleware, E.NestedError) { fileContent, err := os.ReadFile(filePath) if err != nil { - b.Add(E.FailWith("read file", err)) - return + return nil, E.FailWith("read middleware compose file", err) } - err = yaml.Unmarshal(fileContent, &data) + return BuildMiddlewaresFromYAML(fileContent) +} + +func BuildMiddlewaresFromYAML(data []byte) (middlewares map[string]*Middleware, outErr E.NestedError) { + b := E.NewBuilder("middlewares compile errors") + defer b.To(&outErr) + + var rawMap map[string][]map[string]any + err := yaml.Unmarshal(data, &rawMap) if err != nil { b.Add(E.FailWith("toml unmarshal", err)) return } middlewares = make(map[string]*Middleware) - for name, defs := range data { - chainErr := E.NewBuilder("errors in middleware chain %s", name) + for name, defs := range rawMap { + chainErr := E.NewBuilder(name) chain := make([]*Middleware, 0, len(defs)) for i, def := range defs { if def["use"] == nil || def["use"].(string) == "" { @@ -39,9 +42,9 @@ func BuildMiddlewaresFromYAML(filePath string) (middlewares map[string]*Middlewa continue } delete(def, "use") - m, err := base.withOptions(def) + m, err := base.WithOptionsClone(def) if err != nil { - chainErr.Add(err.Subjectf("%s.%d", name, i)) + chainErr.Add(err.Subjectf("item%d", i)) continue } chain = append(chain, m) diff --git a/internal/net/http/middleware/middleware_builder_test.go b/internal/net/http/middleware/middleware_builder_test.go index 31dc697d..997d41ce 100644 --- a/internal/net/http/middleware/middleware_builder_test.go +++ b/internal/net/http/middleware/middleware_builder_test.go @@ -1,9 +1,21 @@ package middleware import ( + _ "embed" + "encoding/json" "testing" + + E "github.com/yusing/go-proxy/internal/error" + . "github.com/yusing/go-proxy/internal/utils/testing" ) -func TestBuild(t *testing.T) { +//go:embed test_data/middleware_compose.yml +var testMiddlewareCompose []byte +func TestBuild(t *testing.T) { + // middlewares, err := BuildMiddlewaresFromYAML(testMiddlewareCompose) + // ExpectNoError(t, err.Error()) + data, err := E.Check(json.MarshalIndent(middlewares, "", " ")) + ExpectNoError(t, err.Error()) + t.Log(string(data)) } diff --git a/internal/net/http/middleware/middlewares.go b/internal/net/http/middleware/middlewares.go index 7ef542ae..9f71279f 100644 --- a/internal/net/http/middleware/middlewares.go +++ b/internal/net/http/middleware/middlewares.go @@ -2,6 +2,7 @@ package middleware import ( "fmt" + "net/http" "path" "strings" @@ -15,27 +16,27 @@ import ( var middlewares map[string]*Middleware func Get(name string) (middleware *Middleware, ok bool) { - middleware, ok = middlewares[name] + middleware, ok = middlewares[strings.ToLower(name)] return } // initialize middleware names and label parsers func init() { middlewares = map[string]*Middleware{ - "set_x_forwarded": SetXForwarded, - "hide_x_forwarded": HideXForwarded, - "redirect_http": RedirectHTTP, - "forward_auth": ForwardAuth.m, - "modify_response": ModifyResponse.m, - "modify_request": ModifyRequest.m, - "error_page": CustomErrorPage, - "custom_error_page": CustomErrorPage, - "real_ip": RealIP.m, - "cloudflare_real_ip": CloudflareRealIP.m, + "setxforwarded": SetXForwarded, + "hidexforwarded": HideXForwarded, + "redirecthttp": RedirectHTTP, + "forwardauth": ForwardAuth.m, + "modifyresponse": ModifyResponse.m, + "modifyrequest": ModifyRequest.m, + "errorpage": CustomErrorPage, + "customerrorpage": CustomErrorPage, + "realip": RealIP.m, + "cloudflarerealip": CloudflareRealIP.m, } names := make(map[*Middleware][]string) for name, m := range middlewares { - names[m] = append(names[m], name) + names[m] = append(names[m], http.CanonicalHeaderKey(name)) // register middleware name to docker label parsr // in order to parse middleware_name.option=value into correct type if m.labelParserMap != nil { @@ -49,6 +50,7 @@ func init() { m.name = names[0] } } + // TODO: seperate from init() b := E.NewBuilder("failed to load middlewares") middlewareDefs, err := U.ListFiles(common.MiddlewareDefsBasePath, 0) @@ -57,7 +59,7 @@ func init() { return } for _, defFile := range middlewareDefs { - mws, err := BuildMiddlewaresFromYAML(defFile) + mws, err := BuildMiddlewaresFromComposeFile(defFile) for name, m := range mws { if _, ok := middlewares[name]; ok { b.Add(E.Duplicated("middleware", name)) diff --git a/internal/net/http/middleware/modify_request_test.go b/internal/net/http/middleware/modify_request_test.go index 1ed4c5c0..1590e11d 100644 --- a/internal/net/http/middleware/modify_request_test.go +++ b/internal/net/http/middleware/modify_request_test.go @@ -15,7 +15,7 @@ func TestSetModifyRequest(t *testing.T) { } t.Run("set_options", func(t *testing.T) { - mr, err := ModifyRequest.m.WithOptionsClone(opts, nil) + mr, err := ModifyRequest.m.WithOptionsClone(opts) ExpectNoError(t, err.Error()) ExpectDeepEqual(t, mr.impl.(*modifyRequest).SetHeaders, opts["set_headers"].(map[string]string)) ExpectDeepEqual(t, mr.impl.(*modifyRequest).AddHeaders, opts["add_headers"].(map[string]string)) diff --git a/internal/net/http/middleware/modify_response_test.go b/internal/net/http/middleware/modify_response_test.go index 9c7a7a87..65e98d63 100644 --- a/internal/net/http/middleware/modify_response_test.go +++ b/internal/net/http/middleware/modify_response_test.go @@ -15,7 +15,7 @@ func TestSetModifyResponse(t *testing.T) { } t.Run("set_options", func(t *testing.T) { - mr, err := ModifyResponse.m.WithOptionsClone(opts, nil) + mr, err := ModifyResponse.m.WithOptionsClone(opts) ExpectNoError(t, err.Error()) ExpectDeepEqual(t, mr.impl.(*modifyResponse).SetHeaders, opts["set_headers"].(map[string]string)) ExpectDeepEqual(t, mr.impl.(*modifyResponse).AddHeaders, opts["add_headers"].(map[string]string)) diff --git a/internal/net/http/middleware/real_ip.go b/internal/net/http/middleware/real_ip.go index 4e2260ff..87094765 100644 --- a/internal/net/http/middleware/real_ip.go +++ b/internal/net/http/middleware/real_ip.go @@ -2,11 +2,11 @@ package middleware import ( "net" - "strings" "github.com/sirupsen/logrus" D "github.com/yusing/go-proxy/internal/docker" E "github.com/yusing/go-proxy/internal/error" + "github.com/yusing/go-proxy/internal/types" ) // https://nginx.org/en/docs/http/ngx_http_realip_module.html @@ -20,7 +20,7 @@ type realIPOpts struct { // Header is the name of the header to use for the real client IP Header string // From is a list of Address / CIDRs to trust - From []*net.IPNet + From []*types.CIDR /* If recursive search is disabled, the original client address that matches one of the trusted addresses is replaced by @@ -35,7 +35,7 @@ type realIPOpts struct { var RealIP = &realIP{ m: &Middleware{ labelParserMap: D.ValueParserMap{ - "from": CIDRListParser, + "from": D.YamlStringListParser, "recursive": D.BoolParser, }, withOptions: NewRealIP, @@ -45,14 +45,7 @@ var RealIP = &realIP{ var realIPOptsDefault = func() *realIPOpts { return &realIPOpts{ Header: "X-Real-IP", - From: []*net.IPNet{ - {IP: net.IPv4(127, 0, 0, 1), Mask: net.CIDRMask(8, 32)}, - {IP: net.IPv4(10, 0, 0, 0), Mask: net.CIDRMask(8, 32)}, - {IP: net.IPv4(172, 16, 0, 0), Mask: net.CIDRMask(12, 32)}, - {IP: net.IPv4(192, 168, 0, 0), Mask: net.CIDRMask(16, 32)}, - {IP: net.ParseIP("fc00::"), Mask: net.CIDRMask(7, 128)}, - {IP: net.ParseIP("fe80::"), Mask: net.CIDRMask(10, 128)}, - }, + From: []*types.CIDR{}, } } @@ -72,31 +65,6 @@ func NewRealIP(opts OptionsRaw) (*Middleware, E.NestedError) { return riWithOpts.m, nil } -func CIDRListParser(s string) (any, E.NestedError) { - sl, err := D.YamlStringListParser(s) - if err != nil { - return nil, err - } - - b := E.NewBuilder("invalid CIDR(s)") - - CIDRs := sl.([]string) - res := make([]*net.IPNet, 0, len(CIDRs)) - - for _, cidr := range CIDRs { - if !strings.Contains(cidr, "/") { - cidr += "/32" // single IP - } - _, ipnet, err := net.ParseCIDR(cidr) - if err != nil { - b.Add(E.Invalid("CIDR", cidr)) - continue - } - res = append(res, ipnet) - } - return res, b.Build() -} - func (ri *realIP) isInCIDRList(ip net.IP) bool { for _, CIDR := range ri.From { if CIDR.Contains(ip) { diff --git a/internal/net/http/middleware/real_ip_test.go b/internal/net/http/middleware/real_ip_test.go new file mode 100644 index 00000000..c4b941c8 --- /dev/null +++ b/internal/net/http/middleware/real_ip_test.go @@ -0,0 +1,58 @@ +package middleware + +import ( + "net" + "testing" + + "github.com/yusing/go-proxy/internal/types" + . "github.com/yusing/go-proxy/internal/utils/testing" +) + +func TestSetRealIP(t *testing.T) { + opts := OptionsRaw{ + "header": "X-Real-IP", + "from": []string{ + "127.0.0.0/8", + "192.168.0.0/16", + "172.16.0.0/12", + }, + "recursive": true, + } + optExpected := &realIPOpts{ + Header: "X-Real-IP", + From: []*types.CIDR{ + { + IP: net.ParseIP("127.0.0.0"), + Mask: net.IPv4Mask(255, 0, 0, 0), + }, + { + IP: net.ParseIP("192.168.0.0"), + Mask: net.IPv4Mask(255, 255, 0, 0), + }, + { + IP: net.ParseIP("172.16.0.0"), + Mask: net.IPv4Mask(255, 240, 0, 0), + }, + }, + Recursive: true, + } + + t.Run("set_options", func(t *testing.T) { + ri, err := RealIP.m.WithOptionsClone(opts) + ExpectNoError(t, err.Error()) + // ExpectEqual(t, ri.impl.(*realIP).Header, optExpected.Header) + // ExpectDeepEqual(t, ri.impl.(*realIP).From, optExpected.From) + // ExpectEqual(t, ri.impl.(*realIP).Recursive, optExpected.Recursive) + ExpectDeepEqual(t, ri.impl.(*realIP).realIPOpts, optExpected) + }) + + // t.Run("request_headers", func(t *testing.T) { + // result, err := newMiddlewareTest(ModifyRequest.m, &testArgs{ + // middlewareOpt: opts, + // }) + // ExpectNoError(t, err.Error()) + // ExpectEqual(t, result.RequestHeaders.Get("User-Agent"), "go-proxy/v0.5.0") + // ExpectTrue(t, slices.Contains(result.RequestHeaders.Values("Accept-Encoding"), "test-value")) + // ExpectEqual(t, result.RequestHeaders.Get("Accept"), "") + // }) +} diff --git a/internal/net/http/middleware/test_data/middleware_compose.yml b/internal/net/http/middleware/test_data/middleware_compose.yml new file mode 100644 index 00000000..4ec30401 --- /dev/null +++ b/internal/net/http/middleware/test_data/middleware_compose.yml @@ -0,0 +1,41 @@ +theGreatPretender: + - use: HideXForwarded + - use: ModifyRequest + setHeaders: + X-Real-IP: 6.6.6.6 + - use: ModifyResponse + hideHeaders: + - X-Test3 + - X-Test4 + +notAuthenticAuthentik: + - use: RedirectHTTP + - use: ForwardAuth + address: https://authentik.company + trustForwardHeader: true + addAuthCookiesToResponse: + - session_id + - user_id + authResponseHeaders: + - X-Auth-SessionID + - X-Auth-UserID + - use: CustomErrorPage + +realIPAuthentik: + - use: RedirectHTTP + - use: RealIP + header: X-Real-IP + from: + - "127.0.0.0/8" + - "192.168.0.0/16" + - "172.16.0.0/12" + recursive: true + - use: ForwardAuth + address: https://authentik.company + trustForwardHeader: true + +testFakeRealIP: + - use: ModifyRequest + setHeaders: + CF-Connecting-IP: 127.0.0.1 + - use: CloudflareRealIP diff --git a/internal/types/cidr.go b/internal/types/cidr.go new file mode 100644 index 00000000..8d3c4826 --- /dev/null +++ b/internal/types/cidr.go @@ -0,0 +1,34 @@ +package types + +import ( + "net" + "strings" + + E "github.com/yusing/go-proxy/internal/error" +) + +type CIDR net.IPNet + +func (*CIDR) ConvertFrom(val any) (any, E.NestedError) { + cidr, ok := val.(string) + if !ok { + return nil, E.TypeMismatch[string](val) + } + + if !strings.Contains(cidr, "/") { + cidr += "/32" // single IP + } + _, ipnet, err := net.ParseCIDR(cidr) + if err != nil { + return nil, E.Invalid("CIDR", cidr) + } + return (*CIDR)(ipnet), nil +} + +func (cidr *CIDR) Contains(ip net.IP) bool { + return (*net.IPNet)(cidr).Contains(ip) +} + +func (cidr *CIDR) String() string { + return (*net.IPNet)(cidr).String() +} diff --git a/internal/utils/serialization.go b/internal/utils/serialization.go index 9e525833..71a48cdb 100644 --- a/internal/utils/serialization.go +++ b/internal/utils/serialization.go @@ -12,6 +12,11 @@ import ( "gopkg.in/yaml.v3" ) +type SerializedObject = map[string]any +type Convertor interface { + ConvertFrom(value any) (any, E.NestedError) +} + func ValidateYaml(schema *jsonschema.Schema, data []byte) E.NestedError { var i any @@ -89,7 +94,7 @@ func Serialize(data any) (SerializedObject, E.NestedError) { } else if field.Anonymous { // If the field is an embedded struct, add its fields to the result fieldMap, err := Serialize(value.Field(i).Interface()) - if err.HasError() { + if err != nil { return nil, err } for k, v := range fieldMap { @@ -106,90 +111,138 @@ func Serialize(data any) (SerializedObject, E.NestedError) { return result, nil } -func Deserialize(src SerializedObject, target any) E.NestedError { - if src == nil || target == nil { +// Deserialize takes a SerializedObject and a target value, and assigns the values in the SerializedObject to the target value. +// Deserialize ignores case differences between the field names in the SerializedObject and the target. +// +// The target value must be a struct or a map[string]any. +// If the target value is a struct, the SerializedObject will be deserialized into the struct fields. +// If the target value is a map[string]any, the SerializedObject will be deserialized into the map. +// +// The function returns an error if the target value is not a struct or a map[string]any, or if there is an error during deserialization. +func Deserialize(src SerializedObject, dst any) E.NestedError { + if src == nil || dst == nil { return nil } - tValue := reflect.ValueOf(target) - mapping := make(map[string]string) + dstV := reflect.ValueOf(dst) + dstT := dstV.Type() - if tValue.Kind() == reflect.Ptr { - tValue = tValue.Elem() + if dstV.Kind() == reflect.Ptr { + dstV = dstV.Elem() + dstT = dstV.Type() } // convert data fields to lower no-snake // convert target fields to lower no-snake // then check if the field of data is in the target - if tValue.Kind() == reflect.Struct { - t := reflect.TypeOf(target).Elem() - for i := 0; i < t.NumField(); i++ { - field := t.Field(i) - snakeCaseField := ToLowerNoSnake(field.Name) - mapping[snakeCaseField] = field.Name + // TODO: use E.Builder to collect errors from all fields + + if dstV.Kind() == reflect.Struct { + mapping := make(map[string]reflect.Value) + for i := 0; i < dstV.NumField(); i++ { + field := dstT.Field(i) + mapping[ToLowerNoSnake(field.Name)] = dstV.Field(i) + } + for k, v := range src { + if field, ok := mapping[ToLowerNoSnake(k)]; ok { + err := Convert(reflect.ValueOf(v), field) + if err != nil { + return err.Subject(k) + } + } else { + return E.Unexpected("field", k) + } } - } else if tValue.Kind() == reflect.Map && tValue.Type().Key().Kind() == reflect.String { - if tValue.IsNil() { - tValue.Set(reflect.MakeMap(tValue.Type())) + } else if dstV.Kind() == reflect.Map && dstT.Key().Kind() == reflect.String { + if dstV.IsNil() { + dstV.Set(reflect.MakeMap(dstT)) } for k := range src { - // TODO: type check - tValue.SetMapIndex(reflect.ValueOf(ToLowerNoSnake(k)), reflect.ValueOf(src[k])) + tmp := reflect.New(dstT.Elem()).Elem() + err := Convert(reflect.ValueOf(src[k]), tmp) + if err != nil { + return err.Subject(k) + } + dstV.SetMapIndex(reflect.ValueOf(ToLowerNoSnake(k)), tmp) } return nil } else { - return E.Unsupported("target type", fmt.Sprintf("%T", target)) + return E.Unsupported("target type", fmt.Sprintf("%T", dst)) } - for k, v := range src { - kCleaned := ToLowerNoSnake(k) - if fieldName, ok := mapping[kCleaned]; ok { - prop := tValue.FieldByName(fieldName) - propType := prop.Type() - isPtr := prop.Kind() == reflect.Ptr - if prop.CanSet() { - val := reflect.ValueOf(v) - vType := val.Type() - switch { - case isPtr && vType.ConvertibleTo(propType.Elem()): - ptr := reflect.New(propType.Elem()) - ptr.Elem().Set(val.Convert(propType.Elem())) - prop.Set(ptr) - case vType.ConvertibleTo(propType): - prop.Set(val.Convert(propType)) - case isPtr: - var vSerialized SerializedObject - vSerialized, ok = v.(SerializedObject) - if !ok { - if vType.ConvertibleTo(reflect.TypeFor[SerializedObject]()) { - vSerialized = val.Convert(reflect.TypeFor[SerializedObject]()).Interface().(SerializedObject) - } else { - return E.Failure(fmt.Sprintf("convert %s (%T) to %s", k, v, reflect.TypeFor[SerializedObject]())) - } - } - propNew := reflect.New(propType.Elem()) - err := Deserialize(vSerialized, propNew.Interface()) - if err.HasError() { - return E.Failure("set field").With(err).Subject(k) - } - prop.Set(propNew) - default: - obj, ok := val.Interface().(SerializedObject) - if !ok { - return E.Invalid("conversion", k).Extraf("from %s to %s", vType, propType) - } - err := Deserialize(obj, prop.Addr().Interface()) - if err.HasError() { - return E.Failure("set field").With(err).Subject(k) - } - } - } else { - return E.Unsupported("field", k).Extraf("type %s is not settable", propType) + return nil +} + +// Convert attempts to convert the src to dst. +// +// If src is a map, it is deserialized into dst. +// If src is a slice, each of its elements are converted and stored in dst. +// For any other type, it is converted using the reflect.Value.Convert function (if possible). +// +// If dst is not settable, an error is returned. +// If src cannot be converted to dst, an error is returned. +// If any error occurs during conversion (e.g. deserialization), it is returned. +// +// Returns: +// - error: the error occurred during conversion, or nil if no error occurred. +func Convert(src reflect.Value, dst reflect.Value) E.NestedError { + srcT := src.Type() + dstVT := dst.Type() + + if src.Kind() == reflect.Interface { + src = src.Elem() + srcT = src.Type() + } + + if !dst.CanSet() { + return E.From(fmt.Errorf("%w type %T is unsettable", E.ErrUnsupported, dst.Interface())) + } + + switch { + case srcT.AssignableTo(dstVT): + dst.Set(src) + case srcT.ConvertibleTo(dstVT): + dst.Set(src.Convert(dstVT)) + case srcT.Kind() == reflect.Map: + if dstVT.Kind() != reflect.Map { + return E.TypeError("map", srcT, dstVT) + } + obj, ok := src.Interface().(SerializedObject) + if !ok { + return E.TypeError("map", srcT, dstVT) + } + err := Deserialize(obj, dst.Addr().Interface()) + if err != nil { + return err + } + case srcT.Kind() == reflect.Slice: + if dstVT.Kind() != reflect.Slice { + return E.TypeError("slice", srcT, dstVT) + } + newSlice := reflect.MakeSlice(dstVT, 0, src.Len()) + i := 0 + for _, v := range src.Seq2() { + tmp := reflect.New(dstVT.Elem()).Elem() + err := Convert(v, tmp) + if err != nil { + return err.Subjectf("[%d]", i) + } + newSlice = reflect.Append(newSlice, tmp) + i++ + } + dst.Set(newSlice) + default: + // check if Convertor is implemented + if converter, ok := dst.Interface().(Convertor); ok { + converted, err := converter.ConvertFrom(src.Interface()) + if err != nil { + return err } - } else { - return E.Unexpected("field", k) + dst.Set(reflect.ValueOf(converted)) + return nil } + return E.TypeError("conversion", srcT, dstVT) } return nil @@ -197,7 +250,7 @@ func Deserialize(src SerializedObject, target any) E.NestedError { func DeserializeJson(j map[string]string, target any) E.NestedError { data, err := E.Check(json.Marshal(j)) - if err.HasError() { + if err != nil { return err } return E.From(json.Unmarshal(data, target)) @@ -206,5 +259,3 @@ func DeserializeJson(j map[string]string, target any) E.NestedError { func ToLowerNoSnake(s string) string { return strings.ToLower(strings.ReplaceAll(s, "_", "")) } - -type SerializedObject = map[string]any diff --git a/internal/utils/serialization_test.go b/internal/utils/serialization_test.go new file mode 100644 index 00000000..387e8b5d --- /dev/null +++ b/internal/utils/serialization_test.go @@ -0,0 +1,47 @@ +package utils + +import ( + "testing" + + . "github.com/yusing/go-proxy/internal/utils/testing" +) + +type S = struct { + I int + S string + IS []int + SS []string + MSI map[string]int + MIS map[int]string +} + +var testStruct = S{ + I: 1, + S: "hello", + IS: []int{1, 2, 3}, + SS: []string{"a", "b", "c"}, + MSI: map[string]int{"a": 1, "b": 2, "c": 3}, + MIS: map[int]string{1: "a", 2: "b", 3: "c"}, +} + +var testStructSerialized = map[string]any{ + "I": 1, + "S": "hello", + "IS": []int{1, 2, 3}, + "SS": []string{"a", "b", "c"}, + "MSI": map[string]int{"a": 1, "b": 2, "c": 3}, + "MIS": map[int]string{1: "a", 2: "b", 3: "c"}, +} + +func TestSerialize(t *testing.T) { + s, err := Serialize(testStruct) + ExpectNoError(t, err.Error()) + ExpectDeepEqual(t, s, testStructSerialized) +} + +func TestDeserialize(t *testing.T) { + var s S + err := Deserialize(testStructSerialized, &s) + ExpectNoError(t, err.Error()) + ExpectDeepEqual(t, s, testStruct) +}