Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

expression: check max_allowed_packet constraint for function insert #7502

Merged
merged 6 commits into from
Aug 29, 2018
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 26 additions & 4 deletions expression/builtin_string.go
Original file line number Diff line number Diff line change
Expand Up @@ -3170,21 +3170,30 @@ func (c *insertFunctionClass) getFunction(ctx sessionctx.Context, args []Express
bf.tp.Flen = mysql.MaxBlobWidth
SetBinFlagOrBinStr(args[0].GetType(), bf.tp)
SetBinFlagOrBinStr(args[3].GetType(), bf.tp)

valStr, _ := ctx.GetSessionVars().GetSystemVar(variable.MaxAllowedPacket)
maxAllowedPacket, err := strconv.ParseUint(valStr, 10, 64)
if err != nil {
return nil, errors.Trace(err)
}

if types.IsBinaryStr(args[0].GetType()) {
sig = &builtinInsertBinarySig{bf}
sig = &builtinInsertBinarySig{bf, maxAllowedPacket}
} else {
sig = &builtinInsertSig{bf}
sig = &builtinInsertSig{bf, maxAllowedPacket}
}
return sig, nil
}

type builtinInsertBinarySig struct {
baseBuiltinFunc
maxAllowedPacket uint64
}

func (b *builtinInsertBinarySig) Clone() builtinFunc {
newSig := &builtinInsertBinarySig{}
newSig.cloneFrom(&b.baseBuiltinFunc)
newSig.maxAllowedPacket = b.maxAllowedPacket
return newSig
}

Expand Down Expand Up @@ -3216,18 +3225,26 @@ func (b *builtinInsertBinarySig) evalString(row chunk.Row) (string, bool, error)
}

if length > strLength-pos+1 || length < 0 {
return str[0:pos-1] + newstr, false, nil
length = strLength - pos + 1
}

if uint64(strLength-length+int64(len(newstr))) > b.maxAllowedPacket {
b.ctx.GetSessionVars().StmtCtx.AppendWarning(errWarnAllowedPacketOverflowed.GenByArgs("insert", b.maxAllowedPacket))
return "", true, nil
}

return str[0:pos-1] + newstr + str[pos+length-1:], false, nil
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should we check whether len(str[0:pos-1] + newstr + str[pos+length-1:]) > b. maxAllowedPacket ?

}

type builtinInsertSig struct {
baseBuiltinFunc
maxAllowedPacket uint64
}

func (b *builtinInsertSig) Clone() builtinFunc {
newSig := &builtinInsertSig{}
newSig.cloneFrom(&b.baseBuiltinFunc)
newSig.maxAllowedPacket = b.maxAllowedPacket
return newSig
}

Expand Down Expand Up @@ -3260,7 +3277,12 @@ func (b *builtinInsertSig) evalString(row chunk.Row) (string, bool, error) {
}

if length > runeLength-pos+1 || length < 0 {
return string(runes[0:pos-1]) + newstr, false, nil
length = runeLength - pos + 1
}

if uint64(runeLength-length)*uint64(mysql.MaxBytesOfCharacter)+uint64(len(newstr)) > b.maxAllowedPacket {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What if the number of bytes of some character is less than MaxBytesOfCharacter? Will this raise warning while it should not?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, this false-negative case exists, but to compute the exact bytes of these characters, we may need the charset info of the string, and then compute the length of specific characters for this charset, which incurs too much overhead, I guess that is why we use MaxBytesOfCharacter in other functions such as builtinLpadSig::evalString. @zz-jason should we compute the exact bytes for the string?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe we can build the string first, and use len() to compute the consumed bytes, but is it possible that panic is raised in building the string because it is too large? I think the purpose of max_allowed_packet check here is to prevent this kind of panic to some extent.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

in MySQL, the byte count for lpad is calculated by:

byte_count = count * collation.collation->mbmaxlen;

as for the insert function, seems MySQL calculates the exact bytes for the string:

  /*
    There is one exception not handled (intentionaly) by the character set
    aggregation code. If one string is strong side and is binary, and
    another one is weak side and is a multi-byte character string,
    then we need to operate on the second string in terms on bytes when
    calling ::numchars() and ::charpos(), rather than in terms of characters.
    Lets substitute its character set to binary.
  */
  if (collation.collation == &my_charset_bin) {
    res->set_charset(&my_charset_bin);
    res2->set_charset(&my_charset_bin);
  }

  /* start and length are now sufficiently valid to pass to charpos function */
  start = res->charpos((int)start);
  length = res->charpos((int)length, (uint32)start);

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why first build it? It just constitutes of three parts, so sum them would be enough.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

seems there isn't a convenient and efficient way to calculate the byte count of a []rune. @tiancaiamao any idea?

Copy link
Member

@zz-jason zz-jason Aug 28, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

another method is to write a utf-8 code point iterator like this: https://gist.github.com/zz-jason/078110974bb931b7f8e3432775ecfd05, we can iterate on the origin []byte, find the code point located at pos and pos+length, and the count the bytes for each []rune prefix.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

oh you are right... updated

b.ctx.GetSessionVars().StmtCtx.AppendWarning(errWarnAllowedPacketOverflowed.GenByArgs("insert", b.maxAllowedPacket))
return "", true, nil
}
return string(runes[0:pos-1]) + newstr + string(runes[pos+length-1:]), false, nil
}
Expand Down
45 changes: 45 additions & 0 deletions expression/builtin_string_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1425,6 +1425,51 @@ func (s *testEvaluatorSuite) TestRpadSig(c *C) {
c.Assert(terror.ErrorEqual(errWarnAllowedPacketOverflowed, lastWarn.Err), IsTrue, Commentf("err %v", lastWarn.Err))
}

func (s *testEvaluatorSuite) TestInsertBinarySig(c *C) {
colTypes := []*types.FieldType{
{Tp: mysql.TypeVarchar},
{Tp: mysql.TypeLonglong},
{Tp: mysql.TypeLonglong},
{Tp: mysql.TypeVarchar},
}
resultType := &types.FieldType{Tp: mysql.TypeVarchar, Flen: 3}

args := []Expression{
&Column{Index: 0, RetType: colTypes[0]},
&Column{Index: 1, RetType: colTypes[1]},
&Column{Index: 2, RetType: colTypes[2]},
&Column{Index: 3, RetType: colTypes[3]},
}

base := baseBuiltinFunc{args: args, ctx: s.ctx, tp: resultType}
insert := &builtinInsertBinarySig{base, 3}

input := chunk.NewChunkWithCapacity(colTypes, 2)
input.AppendString(0, "abc")
input.AppendString(0, "abc")
input.AppendInt64(1, 3)
input.AppendInt64(1, 3)
input.AppendInt64(2, -1)
input.AppendInt64(2, -1)
input.AppendString(3, "d")
input.AppendString(3, "de")

res, isNull, err := insert.evalString(input.GetRow(0))
c.Assert(res, Equals, "abd")
c.Assert(isNull, IsFalse)
c.Assert(err, IsNil)

res, isNull, err = insert.evalString(input.GetRow(1))
c.Assert(res, Equals, "")
c.Assert(isNull, IsTrue)
c.Assert(err, IsNil)

warnings := s.ctx.GetSessionVars().StmtCtx.GetWarnings()
c.Assert(len(warnings), Equals, 1)
lastWarn := warnings[len(warnings)-1]
c.Assert(terror.ErrorEqual(errWarnAllowedPacketOverflowed, lastWarn.Err), IsTrue, Commentf("err %v", lastWarn.Err))
}

func (s *testEvaluatorSuite) TestInstr(c *C) {
defer testleak.AfterTest(c)()
tbl := []struct {
Expand Down