Skip to content

Commit

Permalink
#562】fix context Keys map concurrent issue (#561)
Browse files Browse the repository at this point in the history
* 1、Fix context Keys map concurrent issue
2、add RemoteIP func to get client ip

* context mutex default to zero-value

Co-authored-by: John Sun <[email protected]>
  • Loading branch information
UnderTreeTech and John Sun authored Apr 30, 2020
1 parent d99f595 commit d29dfdf
Show file tree
Hide file tree
Showing 3 changed files with 113 additions and 12 deletions.
99 changes: 99 additions & 0 deletions pkg/net/http/blademaster/context.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,12 @@ import (
"math"
"net/http"
"strconv"
"strings"
"sync"
"text/template"

"github.com/go-kratos/kratos/pkg/net/metadata"

"github.com/go-kratos/kratos/pkg/ecode"
"github.com/go-kratos/kratos/pkg/net/http/blademaster/binding"
"github.com/go-kratos/kratos/pkg/net/http/blademaster/render"
Expand Down Expand Up @@ -40,6 +44,8 @@ type Context struct {

// Keys is a key/value pair exclusively for the context of each request.
Keys map[string]interface{}
// This mutex protect Keys map
keysMutex sync.RWMutex

Error error

Expand All @@ -51,6 +57,20 @@ type Context struct {
Params Params
}

/************************************/
/********** CONTEXT CREATION ********/
/************************************/
func (c *Context) reset() {
c.Context = nil
c.index = -1
c.handlers = nil
c.Keys = nil
c.Error = nil
c.method = ""
c.RoutePath = ""
c.Params = c.Params[0:0]
}

/************************************/
/*********** FLOW CONTROL ***********/
/************************************/
Expand Down Expand Up @@ -93,16 +113,76 @@ func (c *Context) IsAborted() bool {
// Set is used to store a new key/value pair exclusively for this context.
// It also lazy initializes c.Keys if it was not used previously.
func (c *Context) Set(key string, value interface{}) {
c.keysMutex.Lock()
if c.Keys == nil {
c.Keys = make(map[string]interface{})
}
c.Keys[key] = value
c.keysMutex.Unlock()
}

// Get returns the value for the given key, ie: (value, true).
// If the value does not exists it returns (nil, false)
func (c *Context) Get(key string) (value interface{}, exists bool) {
c.keysMutex.RLock()
value, exists = c.Keys[key]
c.keysMutex.RUnlock()
return
}

// GetString returns the value associated with the key as a string.
func (c *Context) GetString(key string) (s string) {
if val, ok := c.Get(key); ok && val != nil {
s, _ = val.(string)
}
return
}

// GetBool returns the value associated with the key as a boolean.
func (c *Context) GetBool(key string) (b bool) {
if val, ok := c.Get(key); ok && val != nil {
b, _ = val.(bool)
}
return
}

// GetInt returns the value associated with the key as an integer.
func (c *Context) GetInt(key string) (i int) {
if val, ok := c.Get(key); ok && val != nil {
i, _ = val.(int)
}
return
}

// GetUint returns the value associated with the key as an unsigned integer.
func (c *Context) GetUint(key string) (ui uint) {
if val, ok := c.Get(key); ok && val != nil {
ui, _ = val.(uint)
}
return
}

// GetInt64 returns the value associated with the key as an integer.
func (c *Context) GetInt64(key string) (i64 int64) {
if val, ok := c.Get(key); ok && val != nil {
i64, _ = val.(int64)
}
return
}

// GetUint64 returns the value associated with the key as an unsigned integer.
func (c *Context) GetUint64(key string) (ui64 uint64) {
if val, ok := c.Get(key); ok && val != nil {
ui64, _ = val.(uint64)
}
return
}

// GetFloat64 returns the value associated with the key as a float64.
func (c *Context) GetFloat64(key string) (f64 float64) {
if val, ok := c.Get(key); ok && val != nil {
f64, _ = val.(float64)
}
return
}

Expand Down Expand Up @@ -307,3 +387,22 @@ func writeStatusCode(w http.ResponseWriter, ecode int) {
header := w.Header()
header.Set("kratos-status-code", strconv.FormatInt(int64(ecode), 10))
}

// RemoteIP implements a best effort algorithm to return the real client IP, it parses
// X-Real-IP and X-Forwarded-For in order to work properly with reverse-proxies such us: nginx or haproxy.
// Use X-Forwarded-For before X-Real-Ip as nginx uses X-Real-Ip with the proxy's IP.
// Notice: metadata.RemoteIP take precedence over X-Forwarded-For and X-Real-Ip
func (c *Context) RemoteIP() (remoteIP string) {
remoteIP = metadata.String(c, metadata.RemoteIP)
if remoteIP != "" {
return
}

remoteIP = c.Request.Header.Get("X-Forwarded-For")
remoteIP = strings.TrimSpace(strings.Split(remoteIP, ",")[0])
if remoteIP == "" {
remoteIP = strings.TrimSpace(c.Request.Header.Get("X-Real-Ip"))
}

return
}
3 changes: 1 addition & 2 deletions pkg/net/http/blademaster/logger.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@ func Logger() HandlerFunc {
const noUser = "no_user"
return func(c *Context) {
now := time.Now()
ip := metadata.String(c, metadata.RemoteIP)
req := c.Request
path := req.URL.Path
params := req.Form
Expand Down Expand Up @@ -55,7 +54,7 @@ func Logger() HandlerFunc {
}
lf(c,
log.KVString("method", req.Method),
log.KVString("ip", ip),
log.KVString("ip", c.RemoteIP()),
log.KVString("user", caller),
log.KVString("path", path),
log.KVString("params", params.Encode()),
Expand Down
23 changes: 13 additions & 10 deletions pkg/net/http/blademaster/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,8 @@ type Engine struct {
allNoMethod []HandlerFunc
noRoute []HandlerFunc
noMethod []HandlerFunc

pool sync.Pool
}

type injection struct {
Expand Down Expand Up @@ -182,6 +184,9 @@ func NewServer(conf *ServerConfig) *Engine {
if err := engine.SetConfig(conf); err != nil {
panic(err)
}
engine.pool.New = func() interface{} {
return engine.newContext()
}
engine.RouterGroup.engine = engine
// NOTE add prometheus monitor location
engine.addRoute("GET", "/metrics", monitor())
Expand Down Expand Up @@ -477,20 +482,18 @@ func (engine *Engine) Inject(pattern string, handlers ...HandlerFunc) {

// ServeHTTP conforms to the http.Handler interface.
func (engine *Engine) ServeHTTP(w http.ResponseWriter, req *http.Request) {
c := &Context{
Context: nil,
engine: engine,
index: -1,
handlers: nil,
Keys: nil,
method: "",
Error: nil,
}

c := engine.pool.Get().(*Context)
c.Request = req
c.Writer = w
c.reset()

engine.handleContext(c)
engine.pool.Put(c)
}

//newContext for sync.pool
func (engine *Engine) newContext() *Context {
return &Context{engine: engine}
}

// NoRoute adds handlers for NoRoute. It return a 404 code by default.
Expand Down

0 comments on commit d29dfdf

Please sign in to comment.