diff --git a/atreugo.go b/atreugo.go index 3420af4..5744800 100644 --- a/atreugo.go +++ b/atreugo.go @@ -61,7 +61,7 @@ func New(cfg Config) *Atreugo { } server := &Atreugo{ - server: newFasthttpServer(cfg, r.router.Handler, log), + server: newFasthttpServer(cfg, log), log: log, cfg: cfg, Router: r, @@ -70,14 +70,9 @@ func New(cfg Config) *Atreugo { return server } -func newFasthttpServer(cfg Config, handler fasthttp.RequestHandler, log fasthttp.Logger) *fasthttp.Server { - if cfg.Compress { - handler = fasthttp.CompressHandler(handler) - } - +func newFasthttpServer(cfg Config, log fasthttp.Logger) *fasthttp.Server { return &fasthttp.Server{ Name: cfg.Name, - Handler: handler, HeaderReceived: cfg.HeaderReceived, Concurrency: cfg.Concurrency, DisableKeepalive: cfg.DisableKeepalive, @@ -104,6 +99,28 @@ func newFasthttpServer(cfg Config, handler fasthttp.RequestHandler, log fasthttp } } +func (s *Atreugo) handler() fasthttp.RequestHandler { + handler := s.router.Handler + + if len(s.virtualHosts) > 0 { + handler = func(ctx *fasthttp.RequestCtx) { + hostname := gotils.B2S(ctx.URI().Host()) + + if h := s.virtualHosts[hostname]; h != nil { + h(ctx) + } else { + s.router.Handler(ctx) + } + } + } + + if s.cfg.Compress { + handler = fasthttp.CompressHandler(handler) + } + + return handler +} + // SaveMatchedRoutePath if enabled, adds the matched route path onto the ctx.UserValue context // before invoking the handler. // The matched route path is only added to handlers of routes that were @@ -184,6 +201,7 @@ func (s *Atreugo) Serve(ln net.Listener) error { s.cfg.Addr = ln.Addr().String() s.cfg.Network = ln.Addr().Network() + s.server.Handler = s.handler() if gotils.StringSliceInclude(tcpNetworks, s.cfg.Network) { schema := "http" @@ -207,3 +225,29 @@ func (s *Atreugo) Serve(ln net.Listener) error { func (s *Atreugo) SetLogOutput(output io.Writer) { s.log.SetOutput(output) } + +// NewVirtualHost returns a new sub-router for running more than one web site +// (such as company1.example.com and company2.example.com) on a single atreugo instance. +// Virtual hosts can be "IP-based", meaning that you have a different IP address +// for every web site, or "name-based", meaning that you have multiple names +// running on each IP address. +// +// The fact that they are running on the same atreugo instance is not apparent to the end user. +func (s *Atreugo) NewVirtualHost(hostname string) *Router { + if s.virtualHosts == nil { + s.virtualHosts = make(map[string]fasthttp.RequestHandler) + } + + vHost := newRouter(s.log, s.cfg.ErrorView) + vHost.router.NotFound = s.router.NotFound + vHost.router.MethodNotAllowed = s.router.MethodNotAllowed + vHost.router.PanicHandler = s.router.PanicHandler + + if s.virtualHosts[hostname] != nil { + panicf("a router is already registered for virtual host '%s'", hostname) + } + + s.virtualHosts[hostname] = vHost.router.Handler + + return vHost +} diff --git a/atreugo_test.go b/atreugo_test.go index f60bd2c..0368e3c 100644 --- a/atreugo_test.go +++ b/atreugo_test.go @@ -3,6 +3,9 @@ package atreugo import ( "bytes" "errors" + "fmt" + "math/rand" + "net" "reflect" "testing" "time" @@ -160,75 +163,174 @@ func Test_New(t *testing.T) { //nolint:funlen,gocognit } func Test_newFasthttpServer(t *testing.T) { //nolint:funlen - type args struct { - compress bool + cfg := Config{ + Name: "test", + HeaderReceived: func(header *fasthttp.RequestHeader) fasthttp.RequestConfig { + return fasthttp.RequestConfig{} + }, + Concurrency: rand.Int(), // nolint:gosec + DisableKeepalive: true, + ReadBufferSize: rand.Int(), // nolint:gosec + WriteBufferSize: rand.Int(), // nolint:gosec + ReadTimeout: time.Duration(rand.Int()), // nolint:gosec + WriteTimeout: time.Duration(rand.Int()), // nolint:gosec + IdleTimeout: time.Duration(rand.Int()), // nolint:gosec + MaxConnsPerIP: rand.Int(), // nolint:gosec + MaxRequestsPerConn: rand.Int(), // nolint:gosec + MaxRequestBodySize: rand.Int(), // nolint:gosec + ReduceMemoryUsage: true, + GetOnly: true, + DisablePreParseMultipartForm: true, + LogAllErrors: true, + DisableHeaderNamesNormalizing: true, + SleepWhenConcurrencyLimitsExceeded: time.Duration(rand.Int()), // nolint:gosec + NoDefaultServerHeader: true, + NoDefaultDate: true, + NoDefaultContentType: true, + ConnState: func(net.Conn, fasthttp.ConnState) {}, + KeepHijackedConns: true, } - type want struct { - compress bool + srv := newFasthttpServer(cfg, testLog) + + if srv == nil { + t.Fatal("newFasthttpServer() == nil") + } + + fasthttpServerType := reflect.TypeOf(fasthttp.Server{}) + configType := reflect.TypeOf(Config{}) + + fasthttpServerValue := reflect.ValueOf(*srv) // nolint:govet + configValue := reflect.ValueOf(cfg) + + for i := 0; i < fasthttpServerType.NumField(); i++ { + field := fasthttpServerType.Field(i) + + if !unicode.IsUpper(rune(field.Name[0])) { // Check if the field is public + continue + } else if gotils.StringSliceInclude(notConfigFasthttpFields, field.Name) { + continue + } + + _, exist := configType.FieldByName(field.Name) + if !exist { + t.Errorf("The field '%s' does not exist in atreugo.Config", field.Name) + } + + v1 := fmt.Sprint(fasthttpServerValue.FieldByName(field.Name).Interface()) + v2 := fmt.Sprint(configValue.FieldByName(field.Name).Interface()) + + if v1 != v2 { + t.Errorf("fasthttp.Server.%s == %s, want %s", field.Name, v1, v2) + } + } + + if srv.Handler != nil { + t.Error("fasthttp.Server.Handler must be nil") + } + + if !isEqual(srv.Logger, testLog) { + t.Errorf("fasthttp.Server.Logger == %p, want %p", srv.Logger, testLog) + } +} + +func TestAtreugo_handler(t *testing.T) { // nolint:funlen,gocognit + type args struct { + cfg Config + hosts []string } tests := []struct { name string args args - want want }{ { - name: "NotCompress", + name: "Default", args: args{ - compress: false, - }, - want: want{ - compress: false, + cfg: Config{}, }, }, { name: "Compress", args: args{ - compress: true, + cfg: Config{Compress: true}, }, - want: want{ - compress: true, + }, + { + name: "MultiHost", + args: args{ + cfg: Config{}, + hosts: []string{"localhost", "example.com"}, + }, + }, + { + name: "MultiHostCompress", + args: args{ + cfg: Config{Compress: true}, + hosts: []string{"localhost", "example.com"}, }, }, } - handler := func(ctx *fasthttp.RequestCtx) {} - for _, test := range tests { tt := test t.Run(tt.name, func(t *testing.T) { - cfg := Config{ - LogLevel: "fatal", - Compress: tt.args.compress, + testView := func(ctx *RequestCtx) error { + return ctx.JSONResponse(JSON{"data": gotils.RandBytes(make([]byte, 300))}) } - srv := newFasthttpServer(cfg, handler, testLog) + testPath := "/" + + s := New(tt.args.cfg) + s.GET(testPath, testView) - if (reflect.ValueOf(handler).Pointer() == reflect.ValueOf(srv.Handler).Pointer()) == tt.want.compress { - t.Error("The handler has not been wrapped by compression handler") + for _, hostname := range tt.args.hosts { + vHost := s.NewVirtualHost(hostname) + vHost.GET(testPath, testView) } - }) - } -} -func TestAtreugo_ConfigFasthttpFields(t *testing.T) { - fasthttpServerType := reflect.TypeOf(fasthttp.Server{}) - configType := reflect.TypeOf(Config{}) + handler := s.handler() - for i := 0; i < fasthttpServerType.NumField(); i++ { - field := fasthttpServerType.Field(i) + if handler == nil { + t.Errorf("handler is nil") + } - if !unicode.IsUpper(rune(field.Name[0])) { // Check if the field is public - continue - } else if gotils.StringSliceInclude(notConfigFasthttpFields, field.Name) { - continue - } + newHostname := string(gotils.RandBytes(make([]byte, 10))) + ".com" - _, exist := configType.FieldByName(field.Name) - if !exist { - t.Errorf("The field '%s' does not exist in atreugo.Config", field.Name) - } + hosts := tt.args.hosts + hosts = append(hosts, newHostname) + + for _, hostname := range hosts { + for _, path := range []string{testPath, "/notfound"} { + ctx := new(fasthttp.RequestCtx) + ctx.Request.Header.Set(fasthttp.HeaderAcceptEncoding, "gzip") + ctx.Request.Header.Set(fasthttp.HeaderHost, hostname) + ctx.Request.URI().SetHost(hostname) + ctx.Request.SetRequestURI(path) + + handler(ctx) + + statusCode := ctx.Response.StatusCode() + wantStatusCode := fasthttp.StatusOK + + if path != testPath { + wantStatusCode = fasthttp.StatusNotFound + } + + if statusCode != wantStatusCode { + t.Errorf("Host %s - Path %s, Status code == %d, want %d", hostname, path, statusCode, wantStatusCode) + } + + if wantStatusCode == fasthttp.StatusNotFound { + continue + } + + if tt.args.cfg.Compress && len(ctx.Response.Header.Peek(fasthttp.HeaderContentEncoding)) == 0 { + t.Errorf("The header '%s' is not setted", fasthttp.HeaderContentEncoding) + } + } + } + }) } } @@ -359,6 +461,10 @@ func TestAtreugo_Serve(t *testing.T) { if s.cfg.Network != lnNetwork { t.Errorf("Atreugo.Config.Network = %s, want %s", s.cfg.Network, lnNetwork) } + + if s.server.Handler == nil { + t.Error("Atreugo.server.Handler is nil") + } } func TestAtreugo_SetLogOutput(t *testing.T) { @@ -372,3 +478,70 @@ func TestAtreugo_SetLogOutput(t *testing.T) { t.Error("SetLogOutput() log output was not changed") } } + +func TestAtreugo_NewVirtualHost(t *testing.T) { + hostname := "localhost" + s := New(testAtreugoConfig) + + if s.virtualHosts != nil { + t.Error("Atreugo.virtualHosts must be nil before register a new virtual host") + } + + vHost := s.NewVirtualHost(hostname) + if vHost == nil { + t.Fatal("Atreugo.NewVirtualHost() returned a nil router") + } + + if !isEqual(vHost.router.NotFound, s.router.NotFound) { + t.Errorf("VirtualHost router.NotFound == %p, want %p", vHost.router.NotFound, s.router.NotFound) + } + + if !isEqual(vHost.router.MethodNotAllowed, s.router.MethodNotAllowed) { + t.Errorf( + "VirtualHost router.MethodNotAllowed == %p, want %p", + vHost.router.MethodNotAllowed, + s.router.MethodNotAllowed, + ) + } + + if !isEqual(vHost.router.PanicHandler, s.router.PanicHandler) { + t.Errorf("VirtualHost router.PanicHandler == %p, want %p", vHost.router.PanicHandler, s.router.PanicHandler) + } + + if h := s.virtualHosts[hostname]; h == nil { + t.Error("The new virtual host is not registeded") + } + + defer func() { + err := recover() + if err == nil { + t.Error("Expected panic when a virtual host is duplicated") + } + + wantErrString := fmt.Sprintf("a router is already registered for virtual host '%s'", hostname) + if err != wantErrString { + t.Errorf("Error string == %s, want %s", err, wantErrString) + } + }() + + // panic when a virtual host is duplicated + s.NewVirtualHost(hostname) +} + +// Benchmarks. +func Benchmark_Handler(b *testing.B) { + s := New(testAtreugoConfig) + s.GET("/", func(ctx *RequestCtx) error { return nil }) + + ctx := new(fasthttp.RequestCtx) + ctx.Request.Header.SetMethod("GET") + ctx.Request.SetRequestURI("/") + + handler := s.handler() + + b.ResetTimer() + + for i := 0; i <= b.N; i++ { + handler(ctx) + } +} diff --git a/types.go b/types.go index e2c5e4d..da6faaf 100644 --- a/types.go +++ b/types.go @@ -21,6 +21,8 @@ type Atreugo struct { log *logger.Logger cfg Config + virtualHosts map[string]fasthttp.RequestHandler + *Router } diff --git a/utils.go b/utils.go index 130db9b..3da1a59 100644 --- a/utils.go +++ b/utils.go @@ -1,11 +1,16 @@ package atreugo import ( + "fmt" "reflect" "github.com/valyala/fasthttp" ) +func panicf(s string, args ...interface{}) { + panic(fmt.Sprintf(s, args...)) +} + func viewToHandler(view View, errorView ErrorView) fasthttp.RequestHandler { return func(ctx *fasthttp.RequestCtx) { actx := AcquireRequestCtx(ctx)