This repository was archived by the owner on Jun 14, 2021. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Adding lexer to sanitize sql string for hash key ✨
- Change import path of contexts to our own modules/lib ✏️
- Loading branch information
1 parent
5fd8896
commit 4f3026e
Showing
6 changed files
with
235 additions
and
14 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,103 @@ | ||
package hashkey | ||
|
||
import ( | ||
"strings" | ||
"sync" | ||
"unicode" | ||
|
||
lexer "github.com/bbuck/go-lexer" | ||
) | ||
|
||
type ( | ||
// RuneChecker specifies the signature requirements for the function to complement rune checker | ||
RuneChecker func(ch rune) bool | ||
) | ||
|
||
var ( | ||
stringsBuilderPool = &sync.Pool{ | ||
New: func() interface{} { | ||
return new(strings.Builder) | ||
}, | ||
} | ||
noWhitespace lexer.StateFunc | ||
insideQuote lexer.StateFunc | ||
isIgnored RuneChecker | ||
isBrackets RuneChecker | ||
isQuotes RuneChecker | ||
) | ||
|
||
// List of all tokens used in this hashkey token | ||
const ( | ||
UsualToken lexer.TokenType = iota | ||
InsideQuoteToken | ||
) | ||
|
||
func init() { | ||
isIgnored = func(ch rune) bool { | ||
return ch == ',' || ch == ';' | ||
} | ||
isBrackets = func(ch rune) bool { | ||
return ch == '(' || ch == ')' || ch == '{' || ch == '}' || ch == '[' || ch == ']' | ||
} | ||
isQuotes = func(ch rune) bool { | ||
return (ch == '\'' || ch == '"' || ch == '`') | ||
} | ||
noWhitespace = func(l *lexer.L) (fn lexer.StateFunc) { | ||
ch := l.Peek() | ||
for ch != lexer.EOFRune { | ||
if isQuotes(ch) { | ||
fn = insideQuote | ||
return | ||
} | ||
if unicode.IsControl(ch) || unicode.IsSpace(ch) || isIgnored(ch) || isBrackets(ch) { | ||
l.Next() | ||
l.Ignore() | ||
goto NEXTLOOP | ||
} | ||
l.Next() | ||
l.Emit(UsualToken) | ||
NEXTLOOP: | ||
ch = l.Peek() | ||
} | ||
return | ||
} | ||
insideQuote = func(l *lexer.L) (fn lexer.StateFunc) { | ||
startQuote := l.Next() | ||
l.Ignore() | ||
ch := l.Peek() | ||
for startQuote != ch && ch != lexer.EOFRune { | ||
l.Next() | ||
ch = l.Peek() | ||
} | ||
l.Emit(InsideQuoteToken) | ||
l.Next() | ||
l.Ignore() | ||
fn = noWhitespace | ||
return | ||
} | ||
} | ||
|
||
// Get gets the hash key for the SQL string. | ||
func Get(sqlStr string) (res string) { | ||
builder, ok := stringsBuilderPool.Get().(*strings.Builder) | ||
if !ok { | ||
builder = new(strings.Builder) | ||
} | ||
builder.Reset() | ||
defer stringsBuilderPool.Put(builder) | ||
|
||
hkLexer := lexer.New(sqlStr, noWhitespace) | ||
hkLexer.Start() | ||
for { | ||
token, ok := hkLexer.NextToken() | ||
if ok { | ||
break | ||
} | ||
switch token.Type { | ||
case UsualToken, InsideQuoteToken: | ||
builder.WriteString(token.Value) | ||
} | ||
} | ||
res = builder.String() | ||
return | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,78 @@ | ||
package hashkey | ||
|
||
import "testing" | ||
|
||
func TestGet(t *testing.T) { | ||
type args struct { | ||
sqlStr string | ||
} | ||
tests := []struct { | ||
name string | ||
args args | ||
wantRes string | ||
}{ | ||
{ | ||
name: "Remove all whitespace inside the sql string", | ||
args: args{ | ||
sqlStr: "SELECT * FROM test", | ||
}, | ||
wantRes: "SELECT*FROMtest", | ||
}, | ||
{ | ||
name: "Remove all whitespace characters inside the sql string", | ||
args: args{ | ||
sqlStr: "SELECT * FROM test WHERE id = 5", | ||
}, | ||
wantRes: "SELECT*FROMtestWHEREid=5", | ||
}, | ||
{ | ||
name: "Don't remove any characters inside quotes", | ||
args: args{ | ||
sqlStr: "SELECT * FROM `test you` WHERE id = 5 AND name = 'Hello Test'", | ||
}, | ||
wantRes: "SELECT*FROMtest youWHEREid=5ANDname=Hello Test", | ||
}, | ||
{ | ||
name: "Multi quotes in sql string for (edge cases)", | ||
args: args{ | ||
sqlStr: "SELECT * FROM ``test you`` WHERE id = 5 AND name = 'Hello Test'", | ||
}, | ||
wantRes: "SELECT*FROMtestyouWHEREid=5ANDname=Hello Test", | ||
}, | ||
{ | ||
name: "Multi quotes in sql string for appending", | ||
args: args{ | ||
sqlStr: "SELECT * FROM `test you` WHERE id = 5 AND name = 'Hello Test''Appending Purposes'", | ||
}, | ||
wantRes: "SELECT*FROMtest youWHEREid=5ANDname=Hello TestAppending Purposes", | ||
}, | ||
{ | ||
name: "Remove all ignored characters, ex colon and brackets", | ||
args: args{ | ||
sqlStr: "INSERT INTO test(field1,field2) VALUES ('hello','name') ", | ||
}, | ||
wantRes: "INSERTINTOtestfield1field2VALUEShelloname", | ||
}, | ||
{ | ||
name: "Remove all ignored characters, ex end statement", | ||
args: args{ | ||
sqlStr: "INSERT INTO test(field1,field2) VALUES ('hello','name');", | ||
}, | ||
wantRes: "INSERTINTOtestfield1field2VALUEShelloname", | ||
}, | ||
{ | ||
name: "Remove all control characters", | ||
args: args{ | ||
sqlStr: "INSERT INTO test(field1,field2) VALUES ('hello','name');\x1A", | ||
}, | ||
wantRes: "INSERTINTOtestfield1field2VALUEShelloname", | ||
}, | ||
} | ||
for _, tt := range tests { | ||
t.Run(tt.name, func(t *testing.T) { | ||
if gotRes := Get(tt.args.sqlStr); gotRes != tt.wantRes { | ||
t.Errorf("Get() = %v, want %v", gotRes, tt.wantRes) | ||
} | ||
}) | ||
} | ||
} |
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,36 +1,73 @@ | ||
package xorm | ||
|
||
import ( | ||
"strings" | ||
"sync" | ||
|
||
"github.com/cespare/xxhash" | ||
"github.com/fairyhunter13/xorm/core" | ||
"github.com/fairyhunter13/xorm/lexer/hashkey" | ||
) | ||
|
||
var ( | ||
stmtCache = make(map[uint64]*core.Stmt, 0) //key: xxhash of sanitized sqlstring | ||
mutex = new(sync.RWMutex) | ||
) | ||
func newStatementCache() *StatementCache { | ||
return &StatementCache{ | ||
mapping: make(map[uint64]map[*core.DB]*core.Stmt), | ||
mutex: new(sync.RWMutex), | ||
} | ||
} | ||
|
||
// StatementCache provides mechanism to map statement to db and query. | ||
type StatementCache struct { | ||
mapping map[uint64]map[*core.DB]*core.Stmt | ||
mutex *sync.RWMutex | ||
} | ||
|
||
func (sc *StatementCache) getDBMap(key uint64) (dbMap map[*core.DB]*core.Stmt) { | ||
var ( | ||
ok bool | ||
) | ||
sc.mutex.RLock() | ||
dbMap, ok = sc.mapping[key] | ||
sc.mutex.RUnlock() | ||
if !ok { | ||
dbMap = make(map[*core.DB]*core.Stmt) | ||
sc.mutex.Lock() | ||
sc.mapping[key] = dbMap | ||
sc.mutex.Unlock() | ||
} | ||
return | ||
} | ||
|
||
// Get return the statement based on the hash key and db. | ||
func (sc *StatementCache) Get(key uint64, db *core.DB) (stmt *core.Stmt, has bool) { | ||
dbMap := sc.getDBMap(key) | ||
sc.mutex.RLock() | ||
stmt, has = dbMap[db] | ||
sc.mutex.RUnlock() | ||
return | ||
} | ||
|
||
func getKey(sqlStr string) string { | ||
return strings.Join(strings.Fields(sqlStr), "") | ||
// Set sets the statement based on the hash key and the db. | ||
func (sc *StatementCache) Set(key uint64, db *core.DB, stmt *core.Stmt) { | ||
dbMap := sc.getDBMap(key) | ||
sc.mutex.Lock() | ||
dbMap[db] = stmt | ||
sc.mutex.Unlock() | ||
} | ||
|
||
var ( | ||
stmtCache = newStatementCache() | ||
) | ||
|
||
func (session *Session) doPrepare(db *core.DB, sqlStr string) (stmt *core.Stmt, err error) { | ||
xxh := xxhash.Sum64String(getKey(sqlStr)) | ||
xxh := xxhash.Sum64String(hashkey.Get(sqlStr)) | ||
var has bool | ||
mutex.RLock() | ||
stmt, has = stmtCache[xxh] | ||
mutex.RUnlock() | ||
stmt, has = stmtCache.Get(xxh, db) | ||
if !has { | ||
stmt, err = db.PrepareContext(session.ctx, sqlStr) | ||
if err != nil { | ||
return nil, err | ||
} | ||
mutex.Lock() | ||
stmtCache[xxh] = stmt | ||
mutex.Unlock() | ||
stmtCache.Set(xxh, db, stmt) | ||
} | ||
return | ||
} |