import into: redact sensitive info & check active job before create (#…
D3Hunter authored Jun 15, 2023
1 parent 1077c72 commit beccd05
Showing 20 changed files with 218 additions and 76 deletions.
2 changes: 1 addition & 1 deletion br/pkg/storage/BUILD.bazel
@@ -76,7 +76,7 @@ go_test(
],
embed = [":storage"],
flaky = True,
shard_count = 46,
shard_count = 45,
deps = [
"//br/pkg/mock",
"@com_github_aws_aws_sdk_go//aws",
38 changes: 0 additions & 38 deletions br/pkg/storage/parse_test.go
@@ -261,44 +261,6 @@ func TestParseRawURL(t *testing.T) {
}
}

func TestRedactURL(t *testing.T) {
type args struct {
str string
}
tests := []struct {
args args
want string
wantErr bool
}{
{args{""}, "", false},
{args{":"}, "", true},
{args{"~/file"}, "~/file", false},
{args{"gs://bucket/file"}, "gs://bucket/file", false},
// gs don't have access-key/secret-access-key, so it will NOT be redacted
{args{"gs://bucket/file?access-key=123"}, "gs://bucket/file?access-key=123", false},
{args{"gs://bucket/file?secret-access-key=123"}, "gs://bucket/file?secret-access-key=123", false},
{args{"s3://bucket/file"}, "s3://bucket/file", false},
{args{"s3://bucket/file?other-key=123"}, "s3://bucket/file?other-key=123", false},
{args{"s3://bucket/file?access-key=123"}, "s3://bucket/file?access-key=redacted", false},
{args{"s3://bucket/file?secret-access-key=123"}, "s3://bucket/file?secret-access-key=redacted", false},
// underline
{args{"s3://bucket/file?access_key=123"}, "s3://bucket/file?access_key=redacted", false},
{args{"s3://bucket/file?secret_access_key=123"}, "s3://bucket/file?secret_access_key=redacted", false},
}
for _, tt := range tests {
t.Run(tt.args.str, func(t *testing.T) {
got, err := RedactURL(tt.args.str)
if (err != nil) != tt.wantErr {
t.Errorf("RedactURL() error = %v, wantErr %v", err, tt.wantErr)
return
}
if got != tt.want {
t.Errorf("RedactURL() got = %v, want %v", got, tt.want)
}
})
}
}

func TestIsLocal(t *testing.T) {
type args struct {
path string
2 changes: 1 addition & 1 deletion disttask/framework/handle/handle.go
@@ -70,7 +70,7 @@ func WaitGlobalTask(ctx context.Context, globalTask *proto.Task) error {
for {
select {
case <-ctx.Done():
return nil
return ctx.Err()
case <-ticker.C:
found, err := globalTaskManager.GetGlobalTaskByID(globalTask.ID)
if err != nil {
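
For context, a minimal standalone sketch (not TiDB code) of why the loop above now returns ctx.Err() instead of nil: callers of the wait can then tell an aborted wait apart from a task that actually finished.

```go
package main

import (
	"context"
	"fmt"
	"time"
)

// waitDone polls checkDone until it reports true or ctx is cancelled,
// mirroring the polling shape of WaitGlobalTask.
func waitDone(ctx context.Context, checkDone func() bool) error {
	ticker := time.NewTicker(100 * time.Millisecond)
	defer ticker.Stop()
	for {
		select {
		case <-ctx.Done():
			// Surface the cancellation reason instead of returning nil.
			return ctx.Err()
		case <-ticker.C:
			if checkDone() {
				return nil
			}
		}
	}
}

func main() {
	ctx, cancel := context.WithTimeout(context.Background(), 300*time.Millisecond)
	defer cancel()
	err := waitDone(ctx, func() bool { return false })
	fmt.Println(err) // prints "context deadline exceeded"
}
```
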
1 change: 1 addition & 0 deletions disttask/importinto/BUILD.bazel
@@ -33,6 +33,7 @@ go_library(
"//parser/mysql",
"//sessionctx",
"//table/tables",
"//util/dbterror/exeerrors",
"//util/etcd",
"//util/logutil",
"//util/sqlexec",
13 changes: 13 additions & 0 deletions disttask/importinto/job.go
@@ -29,6 +29,7 @@ import (
"github.com/pingcap/tidb/executor/importer"
"github.com/pingcap/tidb/parser/mysql"
"github.com/pingcap/tidb/sessionctx"
"github.com/pingcap/tidb/util/dbterror/exeerrors"
"github.com/pingcap/tidb/util/logutil"
"github.com/pingcap/tidb/util/sqlexec"
"go.uber.org/zap"
@@ -158,6 +159,18 @@ func (ti *DistImporter) SubmitTask(ctx context.Context) (int64, *proto.Task, err
if err = globalTaskManager.WithNewTxn(func(se sessionctx.Context) error {
var err2 error
exec := se.(sqlexec.SQLExecutor)
// If 2 clients try to execute IMPORT INTO concurrently, there's a chance that both of them will pass the check.
// We could enforce that ONLY one import job runs by:
// - using LOCK TABLES, but it requires enable-table-lock=true, which is not enabled by default.
// - adding a key to PD as a distributed lock, but it's a little complex, and we might support job queuing later.
// So we only add this simple soft check here and document it.
activeJobCnt, err2 := importer.GetActiveJobCnt(ctx, exec)
if err2 != nil {
return err2
}
if activeJobCnt > 0 {
return exeerrors.ErrLoadDataPreCheckFailed.FastGenByArgs("there's pending or running jobs")
}
jobID, err2 = importer.CreateJob(ctx, exec, plan.DBName, plan.TableInfo.Name.L, plan.TableInfo.ID,
plan.User, plan.Parameters, ti.sourceFileSize)
if err2 != nil {
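
A hedged sketch of the soft check above, written against database/sql rather than TiDB's internal session types. The table name and status filter mirror the diff (mysql.tidb_import_jobs, pending/running); the status literals and helper names are assumptions for illustration. It also shows why this is only a best-effort guard: two concurrent submitters can both observe a zero count before either inserts its row.

```go
package importjob

import (
	"context"
	"database/sql"
	"errors"
)

// errActiveJobExists plays the role of ErrLoadDataPreCheckFailed in the diff.
var errActiveJobExists = errors.New("there are pending or running IMPORT INTO jobs")

// submitJob refuses to create a new import job while another one is still
// pending or running. The count and the insert run in the same transaction,
// but nothing serializes concurrent submitters, so this remains a soft check.
func submitJob(ctx context.Context, tx *sql.Tx,
	insertJob func(context.Context, *sql.Tx) (int64, error)) (int64, error) {
	var active int64
	err := tx.QueryRowContext(ctx,
		`SELECT COUNT(1) FROM mysql.tidb_import_jobs WHERE status IN (?, ?)`,
		"pending", "running",
	).Scan(&active)
	if err != nil {
		return 0, err
	}
	if active > 0 {
		return 0, errActiveJobExists
	}
	return insertJob(ctx, tx)
}
```
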
22 changes: 15 additions & 7 deletions executor/adapter.go
@@ -538,13 +538,7 @@ func (a *ExecStmt) Exec(ctx context.Context) (_ sqlexec.RecordSet, err error) {
var pi processinfoSetter
if raw, ok := sctx.(processinfoSetter); ok {
pi = raw
sql := a.OriginText()
if simple, ok := a.Plan.(*plannercore.Simple); ok && simple.Statement != nil {
if ss, ok := simple.Statement.(ast.SensitiveStmtNode); ok {
// Use SecureText to avoid leak password information.
sql = ss.SecureText()
}
}
sql := a.getSQLForProcessInfo()
maxExecutionTime := getMaxExecutionTime(sctx)
// Update processinfo, ShowProcess() will use it.
if a.Ctx.GetSessionVars().StmtCtx.StmtType == "" {
@@ -585,6 +579,20 @@ func (a *ExecStmt) Exec(ctx context.Context) (_ sqlexec.RecordSet, err error) {
}, nil
}

func (a *ExecStmt) getSQLForProcessInfo() string {
sql := a.OriginText()
if simple, ok := a.Plan.(*plannercore.Simple); ok && simple.Statement != nil {
if ss, ok := simple.Statement.(ast.SensitiveStmtNode); ok {
// Use SecureText to avoid leaking password information.
sql = ss.SecureText()
}
} else if sn, ok2 := a.StmtNode.(ast.SensitiveStmtNode); ok2 {
// e.g. the IMPORT INTO statement
sql = sn.SecureText()
}
return sql
}

func (a *ExecStmt) handleStmtForeignKeyTrigger(ctx context.Context, e Executor) error {
stmtCtx := a.Ctx.GetSessionVars().StmtCtx
if stmtCtx.ForeignKeyTriggerCtx.HasFKCascades {
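
A standalone sketch of the pattern getSQLForProcessInfo now factors out: any statement that can carry secrets exposes a SecureText() method, and the process list prefers that redacted form. The interface and function names below are illustrative stand-ins for ast.StmtNode and ast.SensitiveStmtNode, not TiDB's actual types.

```go
package processinfo

// stmtNode is a stand-in for ast.StmtNode.
type stmtNode interface {
	OriginalText() string
}

// sensitiveStmtNode is a stand-in for ast.SensitiveStmtNode: a statement that
// can render itself with secrets masked.
type sensitiveStmtNode interface {
	stmtNode
	SecureText() string
}

// sqlForProcessInfo returns the text to publish in SHOW PROCESSLIST,
// preferring the redacted form whenever the statement provides one.
func sqlForProcessInfo(n stmtNode) string {
	if s, ok := n.(sensitiveStmtNode); ok {
		return s.SecureText()
	}
	return n.OriginalText()
}
```
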
11 changes: 8 additions & 3 deletions executor/import_into.go
@@ -93,14 +93,19 @@ func (e *ImportIntoExec) Next(ctx context.Context, req *chunk.Chunk) (err error)
// need to return an empty req to indicate all results have been written
return nil
}

// TODO: we don't need to do this here; remove it.
if err2 := e.controller.InitDataFiles(ctx); err2 != nil {
return err2
}

sqlExec := e.userSctx.(sqlexec.SQLExecutor)
if err2 := e.controller.CheckRequirements(ctx, sqlExec); err2 != nil {
// must use a new session for the pre-check, otherwise the statement shown in SHOW PROCESSLIST would be changed.
newSCtx, err2 := CreateSession(e.userSctx)
if err2 != nil {
return err2
}
defer CloseSession(newSCtx)
sqlExec := newSCtx.(sqlexec.SQLExecutor)
if err2 = e.controller.CheckRequirements(ctx, sqlExec); err2 != nil {
return err2
}

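
A rough sketch of the "pre-check on a throwaway session" idea from the comment above, so the user's own session, and its entry in SHOW PROCESSLIST, is left untouched. The session interface and constructor below are made up for illustration; TiDB itself uses sessionctx.Context with CreateSession/CloseSession.

```go
package precheck

import "context"

// session is an illustrative stand-in for a TiDB internal session.
type session interface {
	ExecuteInternal(ctx context.Context, sql string, args ...any) error
	Close()
}

// preCheck borrows a fresh session for the requirement checks and always
// releases it, leaving the caller's session untouched.
func preCheck(ctx context.Context, newSession func() (session, error),
	check func(context.Context, session) error) error {
	se, err := newSession()
	if err != nil {
		return err
	}
	defer se.Close()
	return check(ctx, se)
}
```
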
5 changes: 1 addition & 4 deletions executor/importer/import.go
@@ -636,10 +636,7 @@ func (p *Plan) adjustOptions() {
}

func (p *Plan) initParameters(plan *plannercore.ImportInto) error {
redactURL, err := storage.RedactURL(p.Path)
if err != nil {
return exeerrors.ErrLoadDataInvalidURI.GenWithStackByArgs(err.Error())
}
redactURL := ast.RedactURL(p.Path)
var columnsAndVars, setClause string
var sb strings.Builder
formatCtx := pformat.NewRestoreCtx(pformat.DefaultRestoreFlags, &sb)
19 changes: 19 additions & 0 deletions executor/importer/job.go
@@ -145,6 +145,25 @@ func GetJob(ctx context.Context, conn sqlexec.SQLExecutor, jobID int64, user str
return info, nil
}

// GetActiveJobCnt returns the count of active import jobs.
// Active import jobs include pending and running jobs.
func GetActiveJobCnt(ctx context.Context, conn sqlexec.SQLExecutor) (int64, error) {
ctx = util.WithInternalSourceType(ctx, kv.InternalImportInto)

sql := `select count(1) from mysql.tidb_import_jobs where status in (%?, %?)`
rs, err := conn.ExecuteInternal(ctx, sql, jobStatusPending, JobStatusRunning)
if err != nil {
return 0, err
}
defer terror.Call(rs.Close)
rows, err := sqlexec.DrainRecordSet(ctx, rs, 1)
if err != nil {
return 0, err
}
cnt := rows[0].GetInt64(0)
return cnt, nil
}

// CreateJob creates an import-into job by inserting a record into the system table.
// The AUTO_INCREMENT value will be returned as jobID.
func CreateJob(
19 changes: 19 additions & 0 deletions executor/importer/job_test.go
@@ -94,6 +94,9 @@ func TestJobHappyPath(t *testing.T) {
require.True(t, gotJobInfo.StartTime.IsZero())
require.True(t, gotJobInfo.EndTime.IsZero())
jobInfoEqual(t, jobInfo, gotJobInfo)
cnt, err := importer.GetActiveJobCnt(ctx, conn)
require.NoError(t, err)
require.Equal(t, int64(1), cnt)

// action before start, no effect
c.action(jobID)
@@ -111,9 +114,15 @@ func TestJobHappyPath(t *testing.T) {
jobInfo.Status = "running"
jobInfo.Step = importer.JobStepImporting
jobInfoEqual(t, jobInfo, gotJobInfo)
cnt, err = importer.GetActiveJobCnt(ctx, conn)
require.NoError(t, err)
require.Equal(t, int64(1), cnt)

// change job step
require.NoError(t, importer.Job2Step(ctx, conn, jobID, importer.JobStepValidating))
cnt, err = importer.GetActiveJobCnt(ctx, conn)
require.NoError(t, err)
require.Equal(t, int64(1), cnt)

// do action
c.action(jobID)
@@ -127,6 +136,10 @@ func TestJobHappyPath(t *testing.T) {
jobInfo.Summary = c.expectedSummary
jobInfo.ErrorMessage = c.expectedErrMsg
jobInfoEqual(t, jobInfo, gotJobInfo)
cnt, err = importer.GetActiveJobCnt(ctx, conn)
require.NoError(t, err)
require.Equal(t, int64(0), cnt)

// do action again, no effect
endTime := gotJobInfo.EndTime
c.action(jobID)
@@ -170,6 +183,9 @@ func TestGetAndCancelJob(t *testing.T) {
require.True(t, gotJobInfo.StartTime.IsZero())
require.True(t, gotJobInfo.EndTime.IsZero())
jobInfoEqual(t, jobInfo, gotJobInfo)
cnt, err := importer.GetActiveJobCnt(ctx, conn)
require.NoError(t, err)
require.Equal(t, int64(1), cnt)

// cancel job
require.NoError(t, importer.CancelJob(ctx, conn, jobID1))
@@ -182,6 +198,9 @@ func TestGetAndCancelJob(t *testing.T) {
jobInfo.Status = "cancelled"
jobInfo.ErrorMessage = "cancelled by user"
jobInfoEqual(t, jobInfo, gotJobInfo)
cnt, err = importer.GetActiveJobCnt(ctx, conn)
require.NoError(t, err)
require.Equal(t, int64(0), cnt)

// calling cancel twice is ok; the caller should check the job status before cancelling.
require.NoError(t, importer.CancelJob(ctx, conn, jobID1))
1 change: 1 addition & 0 deletions parser/ast/BUILD.bazel
@@ -29,6 +29,7 @@ go_library(
"//parser/tidb",
"//parser/types",
"@com_github_pingcap_errors//:errors",
"@com_github_pingcap_failpoint//:failpoint",
],
)

12 changes: 12 additions & 0 deletions parser/ast/dml.go
@@ -14,6 +14,8 @@
package ast

import (
"strings"

"github.com/pingcap/errors"
"github.com/pingcap/tidb/parser/auth"
"github.com/pingcap/tidb/parser/format"
@@ -2076,6 +2078,8 @@ type ImportIntoStmt struct {
Options []*LoadDataOpt
}

var _ SensitiveStmtNode = &ImportIntoStmt{}

// Restore implements Node interface.
func (n *ImportIntoStmt) Restore(ctx *format.RestoreCtx) error {
ctx.WriteKeyWord("IMPORT INTO ")
@@ -2161,6 +2165,14 @@ func (n *ImportIntoStmt) Accept(v Visitor) (Node, bool) {
return v.Leave(n)
}

func (n *ImportIntoStmt) SecureText() string {
redactedStmt := *n
redactedStmt.Path = RedactURL(n.Path)
var sb strings.Builder
_ = redactedStmt.Restore(format.NewRestoreCtx(format.DefaultRestoreFlags, &sb))
return sb.String()
}

// CallStmt represents a call procedure query node.
// See https://dev.mysql.com/doc/refman/5.7/en/call.html
type CallStmt struct {
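
A hedged usage sketch of the new SecureText implementation, using the canonical upstream import paths (the diff renders them through a mirror domain). The test_driver import is an assumption needed to run the parser standalone; the exact masked output depends on url.Values.Encode, but credential values come back as xxxxxx, as the test in dml_test.go below also asserts.

```go
package main

import (
	"fmt"

	"github.com/pingcap/tidb/parser"
	"github.com/pingcap/tidb/parser/ast"
	// Registers the value-expression driver the parser needs when used
	// outside TiDB itself (assumption for a standalone program).
	_ "github.com/pingcap/tidb/parser/test_driver"
)

func main() {
	p := parser.New()
	stmt, err := p.ParseOneStmt(
		"import into t from 's3://bucket/prefix?access-key=aaaaa&secret-access-key=bbbbb'", "", "")
	if err != nil {
		panic(err)
	}
	// ImportIntoStmt now implements ast.SensitiveStmtNode, so logs and the
	// process list can use the redacted text.
	if s, ok := stmt.(ast.SensitiveStmtNode); ok {
		fmt.Println(s.SecureText())
		// e.g. IMPORT INTO `t` FROM 's3://bucket/prefix?access-key=xxxxxx&secret-access-key=xxxxxx'
	}
}
```
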
28 changes: 28 additions & 0 deletions parser/ast/dml_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,10 @@
package ast_test

import (
"fmt"
"testing"

"github.com/pingcap/tidb/parser"
. "github.com/pingcap/tidb/parser/ast"
"github.com/pingcap/tidb/parser/format"
"github.com/stretchr/testify/require"
@@ -605,3 +607,29 @@ func TestFulltextSearchModifier(t *testing.T) {
require.True(t, FulltextSearchModifier(FulltextSearchModifierNaturalLanguageMode).IsNaturalLanguageMode())
require.False(t, FulltextSearchModifier(FulltextSearchModifierNaturalLanguageMode).WithQueryExpansion())
}

func TestImportIntoSecureText(t *testing.T) {
testCases := []struct {
input string
secured string
}{
{
input: "import into t from 's3://bucket/prefix?access-key=aaaaa&secret-access-key=bbbbb'",
secured: `^IMPORT INTO .t. FROM \Q's3://bucket/prefix?\E((access-key=xxxxxx|secret-access-key=xxxxxx)(&|'$)){2}`,
},
{
input: "import into t from 'gcs://bucket/prefix?access-key=aaaaa&secret-access-key=bbbbb'",
secured: "\\QIMPORT INTO `t` FROM 'gcs://bucket/prefix?access-key=aaaaa&secret-access-key=bbbbb'\\E",
},
}

p := parser.New()
for _, tc := range testCases {
comment := fmt.Sprintf("input = %s", tc.input)
node, err := p.ParseOneStmt(tc.input, "", "")
require.NoError(t, err, comment)
n, ok := node.(SensitiveStmtNode)
require.True(t, ok, comment)
require.Regexp(t, tc.secured, n.SecureText(), comment)
}
}
41 changes: 26 additions & 15 deletions parser/ast/misc.go
@@ -21,6 +21,7 @@ import (
"strings"

"github.com/pingcap/errors"
"github.com/pingcap/failpoint"
"github.com/pingcap/tidb/parser/auth"
"github.com/pingcap/tidb/parser/format"
"github.com/pingcap/tidb/parser/model"
@@ -3422,30 +3423,40 @@ func (n *BRIEStmt) Restore(ctx *format.RestoreCtx) error {
return nil
}

// SecureText implements SensitiveStmtNode
func (n *BRIEStmt) SecureText() string {
// RedactURL redacts the secret tokens in the URL. Only S3 URLs need redaction for now.
// If the string is not a valid URL, the original string is returned.
func RedactURL(str string) string {
// FIXME: this solution is not scalable, and duplicates some logic from BR.
redactedStorage := n.Storage
u, err := url.Parse(n.Storage)
if err == nil {
if u.Scheme == "s3" {
query := u.Query()
for key := range query {
switch strings.ToLower(strings.ReplaceAll(key, "_", "-")) {
case "access-key", "secret-access-key":
query[key] = []string{"xxxxxx"}
}
u, err := url.Parse(str)
if err != nil {
return str
}
scheme := u.Scheme
failpoint.Inject("forceRedactURL", func() {
scheme = "s3"
})
if strings.ToLower(scheme) == "s3" {
values := u.Query()
for k := range values {
// see the link below for why we normalize the key
// https://github.com/pingcap/tidb/blob/a7c0d95f16ea2582bb569278c3f829403e6c3a7e/br/pkg/storage/parse.go#L163
normalizedKey := strings.ToLower(strings.ReplaceAll(k, "_", "-"))
if normalizedKey == "access-key" || normalizedKey == "secret-access-key" {
values[k] = []string{"xxxxxx"}
}
u.RawQuery = query.Encode()
redactedStorage = u.String()
}
u.RawQuery = values.Encode()
}
return u.String()
}

// SecureText implements SensitiveStmtNode
func (n *BRIEStmt) SecureText() string {
redactedStmt := &BRIEStmt{
Kind: n.Kind,
Schemas: n.Schemas,
Tables: n.Tables,
Storage: redactedStorage,
Storage: RedactURL(n.Storage),
Options: n.Options,
}

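
A small usage sketch of the relocated RedactURL helper (canonical upstream import path assumed). Per the implementation above, only s3 URLs have their access-key / secret-access-key values masked; other schemes and strings that fail to parse come back unchanged.

```go
package main

import (
	"fmt"

	"github.com/pingcap/tidb/parser/ast"
)

func main() {
	// s3: credential query parameters are replaced with xxxxxx.
	fmt.Println(ast.RedactURL("s3://bucket/file?access-key=123&secret-access-key=456"))
	// Underscore spellings are normalized and masked as well.
	fmt.Println(ast.RedactURL("s3://bucket/file?access_key=123"))
	// Non-s3 schemes are not redacted.
	fmt.Println(ast.RedactURL("gs://bucket/file?access-key=123"))
	// Strings that do not parse as URLs are returned as-is.
	fmt.Println(ast.RedactURL(":"))
}
```
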
(diffs for the remaining changed files are not shown)
