Skip to content

Commit

Permalink
Merge pull request #213 from actiontech/issue_541
Browse files Browse the repository at this point in the history
add sql audit before query
  • Loading branch information
sjjian authored May 24, 2022
2 parents 91c0f48 + 56c0abb commit 20b4eab
Show file tree
Hide file tree
Showing 2 changed files with 81 additions and 7 deletions.
28 changes: 21 additions & 7 deletions sqle/api/controller/v1/sql_query_ee.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ import (
"strconv"
"time"

sqlQuery "github.com/actiontech/sqle/sqle/server/sql_query"

"github.com/actiontech/sqle/sqle/errors"

"github.com/sirupsen/logrus"
Expand Down Expand Up @@ -73,6 +75,22 @@ func prepareSQLQuery(c echo.Context) error {
return controller.JSONBaseErrorReq(c, err)
}

if len(nodes) == 0 {
return controller.JSONBaseErrorReq(c, errSqlQueryNoSql)
}

// audit
if instance.SqlQueryConfig.AuditEnabled {
singleSqls := make([]string, len(nodes))
for i, node := range nodes {
singleSqls[i] = node.Text
}
err = sqlQuery.Audit(singleSqls, instance)
if err != nil {
return controller.JSONBaseErrorReq(c, err)
}
}

rawSQL := &model.SqlQueryHistory{
CreateUserId: user.ID,
InstanceId: instance.ID,
Expand All @@ -90,10 +108,6 @@ func prepareSQLQuery(c echo.Context) error {
return controller.JSONBaseErrorReq(c, err)
}

if len(nodes) == 0 {
return controller.JSONBaseErrorReq(c, errSqlQueryNoSql)
}

for _, node := range nodes {
// validate SQL
validateResult, err := queryDriver.QueryPrepare(context.TODO(), node.Text, &driver.QueryPrepareConf{
Expand Down Expand Up @@ -237,9 +251,9 @@ func getSQLResult(c echo.Context) error {
singleSql.ExecEndAt = &endAt

l.WithFields(logrus.Fields{
"exec_start_time": startTime,
"exec_sql": rewriteRes.NewSQL,
"elapsed_time": endAt.Sub(startTime) / time.Millisecond,
"exec_start_time": startTime,
"exec_sql": rewriteRes.NewSQL,
"elapsed_time": endAt.Sub(startTime) / time.Millisecond,
}).Errorln("SQL Query error")

if err := s.Save(singleSql); err != nil {
Expand Down
60 changes: 60 additions & 0 deletions sqle/server/sql_query/audit_ee.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
package sqlQuery

import (
"fmt"

"github.com/actiontech/sqle/sqle/driver"
"github.com/actiontech/sqle/sqle/errors"
"github.com/actiontech/sqle/sqle/log"
"github.com/actiontech/sqle/sqle/model"
"github.com/actiontech/sqle/sqle/server"
"github.com/sirupsen/logrus"
)

var ErrSqlQueryAuditLevelIsNotAllowed = errors.New(errors.DataExist, fmt.Errorf("the audit level is not allowed to perform sql query"))

func Audit(sqls []string, instance *model.Instance) error {
task := &model.Task{
DBType: instance.DbType,
Instance: instance,
}
for i, sql := range sqls {
task.ExecuteSQLs = append(task.ExecuteSQLs, &model.ExecuteSQL{
BaseSQL: model.BaseSQL{
Number: uint(i),
Content: sql,
},
})
}

logger := log.NewEntry().WithField("type", "sql_query")
err := server.Audit(logger, task)
if err != nil {
return err
}

allowQueryWhenLessThanAuditLevel := driver.RuleLevel(instance.SqlQueryConfig.AllowQueryWhenLessThanAuditLevel)
if allowQueryWhenLessThanAuditLevel.LessOrEqual(driver.RuleLevel(task.AuditLevel)) {
auditResults := make(map[string]struct{})
for _, sql := range task.ExecuteSQLs {
if allowQueryWhenLessThanAuditLevel.LessOrEqual(driver.RuleLevel(sql.AuditLevel)) {
auditResults[sql.AuditResult] = struct{}{}
logger.WithFields(logrus.Fields{
"sql": sql.Content,
"audit_result": sql.AuditResult,
"audit_level": sql.AuditLevel,
"allow_query_when_less_than_audit_level": instance.SqlQueryConfig.AllowQueryWhenLessThanAuditLevel,
}).Errorln(ErrSqlQueryAuditLevelIsNotAllowed.Error())
}
}

var auditResultsMsg string
for res := range auditResults {
auditResultsMsg = fmt.Sprintf("%v\n%v", auditResultsMsg, res)
}

return fmt.Errorf("%v: %v", ErrSqlQueryAuditLevelIsNotAllowed, auditResultsMsg)
}

return nil
}

0 comments on commit 20b4eab

Please sign in to comment.