Skip to content

Commit

Permalink
string escaping
Browse files Browse the repository at this point in the history
  • Loading branch information
mgravell committed Nov 18, 2023
1 parent 1f0a408 commit 251d9e0
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 35 deletions.
50 changes: 26 additions & 24 deletions src/Dapper.AOT.Analyzers/Internal/GeneralSqlParser.cs
Original file line number Diff line number Diff line change
@@ -1,18 +1,12 @@
using System;
using Dapper.SqlAnalysis;
using System;
using System.Buffers;
using System.Collections.Generic;
using System.Collections.Immutable;
using System.Linq;

namespace Dapper.Internal.SqlParsing;

/// <summary>
/// Selects which token, if any, the parser treats as a separator between batches.
/// </summary>
public enum Batch
{
/// <summary>Neither of the separators below is selected.</summary>
None,
/// <summary>A <c>;</c> ends the current batch.</summary>
Semicolon, // used by Postgresql
/// <summary>A two-letter <c>GO</c> element (either case) ends the current batch.</summary>
Go, // used by TSQL
}

public readonly struct CommandVariable : IEquatable<CommandVariable>
{
public CommandVariable(string name, int index)
Expand Down Expand Up @@ -78,7 +72,7 @@ private enum ParseState
/// Tokenize a sql fragment into batches, extracting the variables/locals in use
/// </summary>
/// <remarks>This is a basic parse only; no syntax processing - just literals, identifiers, etc</remarks>
public static List<CommandBatch> Parse(string sql, Batch batch, bool strip = false)
public static List<CommandBatch> Parse(string sql, SqlSyntax syntax, bool strip = false)
{
int bIndex = 0;
char[] buffer = ArrayPool<char>.Shared.Rent(sql.Length + 1);
Expand All @@ -89,6 +83,8 @@ public static List<CommandBatch> Parse(string sql, Batch batch, bool strip = fal
ImmutableArray<CommandVariable>.Builder? variables = null;
var result = new List<CommandBatch>();

// True when the requested syntax is PostgreSql; the parser consults this to decide
// whether a ';' ends the current batch (and whether a final ';' is spoofed at the end).
bool BatchSemicolon() => syntax == SqlSyntax.PostgreSql;

char LookAhead(int delta = 1)
{
var ci = i + delta;
Expand Down Expand Up @@ -148,7 +144,7 @@ void FlushBatch()
if (IsGo()) bIndex -= 2; // don't retain the GO

bool removedSemicolon = false;
if ((strip || batch == Batch.Semicolon) && Last(0) == ';')
if ((strip || BatchSemicolon()) && Last(0) == ';')
{
Discard();
removedSemicolon = true;
Expand Down Expand Up @@ -190,12 +186,15 @@ bool IsWhitespace()
}
// Reports whether the element just read is exactly the two letters "GO" (either case)
// while GO-separated batching is in effect. FlushBatch uses this to drop the GO itself
// from the emitted batch text.
bool IsGo()
{
return batch == Batch.Go && ElementLength() == 2
return syntax == SqlSyntax.SqlServer && ElementLength() == 2
// NOTE(review): Last(1)/Last(0) presumably read the second-to-last and last buffered characters - confirm
&& Last(1) is 'g' or 'G' && Last(0) is 'o' or 'O';
}

// True while the parser is inside a string-like token whose recorded type is 'c'.
// NOTE(review): stringType presumably holds the character that opened the token
// (e.g. '"', '[' or a prefix letter such as 'E') - its assignment is not visible here; confirm.
bool IsString(char c) => state == ParseState.String && stringType == c;

// True while inside a string token that is terminated by a single quote: either a plain
// '...' literal, or one whose recorded type is a letter.
// NOTE(review): the letter case presumably covers prefixed literals such as E'...' - confirm.
bool IsSingleQuoteString() => state == ParseState.String && (stringType == '\'' || char.IsLetter(stringType));
// Consumes the NEXT source character: pre-increments i, copies sql[i] into the output
// buffer and bumps bIndex. There is no bounds check here; the escape-handling cases that
// call it test LookAhead() first.
void Advance() => buffer[bIndex++] = sql[++i];

for (; i < sql.Length; i++)
{
var c = i == sql.Length ? ';' : sql[i]; // spoof a ; at the end to simplify end-of-block handling
Expand All @@ -214,7 +213,7 @@ bool IsGo()
}

// store by default, we'll backtrack in the rare scenarios that we don't want it
buffer[bIndex++] = c;
buffer[bIndex++] = sql[i];

switch (state)
{
Expand All @@ -231,16 +230,19 @@ bool IsGo()
case ParseState.BlockComment or ParseState.LineComment: // keep ignoring line comment
if (strip) Discard();
continue;
case ParseState.String when c == '\'' && IsString('E') && LookBehind() == '\\': // E'...\'...'
continue;
case ParseState.String when c == '\'' && (!IsString('"')):
if (LookBehind() != '\'') // [E]'...''...'
{
state = ParseState.None; // '.....'
}
// string-escape characters
case ParseState.String when c == '\'' && IsSingleQuoteString() && LookAhead() == '\'': // [?]'...''...'
case ParseState.String when c == '"' && IsString('"') && LookAhead() == '\"': // "...""..."
case ParseState.String when c == '\\' && IsString('E') && LookAhead() != '\0': // E'...\*...'
case ParseState.String when c == ']' && IsString('[') && LookAhead() == ']': // [...]]...]
// escaped or double-quote; move forwards immediately
Advance();
continue;
// end string
case ParseState.String when c == '"' && IsString('"'): // "....."
state = ParseState.None;
case ParseState.String when c == ']' && IsString('['): // [.....]
case ParseState.String when c == '\'' && IsSingleQuoteString(): // [?]'....'
state = ParseState.None;
continue;
case ParseState.String:
// ongoing string content
Expand Down Expand Up @@ -275,7 +277,7 @@ bool IsGo()

if (c == ';')
{
if (batch == Batch.Semicolon)
if (BatchSemicolon())
{
FlushBatch();
continue;
Expand All @@ -300,7 +302,7 @@ bool IsGo()

elementStartbIndex = bIndex;

if (c is '"' or '\'')
if (c is '"' or '\'' or '[')
{
// start a new string
state = ParseState.String;
Expand All @@ -309,7 +311,7 @@ bool IsGo()
}

if (SqlTools.ParameterPrefixCharacters.IndexOf(c) >= 0
&& IsToken(LookAhead())) // avoid altgt alt
&& IsToken(LookAhead()) && LookBehind() != c) // avoid @>, @@IDENTITY etc
{
// start a new variable
state = ParseState.Variable;
Expand All @@ -326,7 +328,7 @@ bool IsGo()
// other arbitrary syntax - operators etc
}

if (batch == Batch.Semicolon)
if (BatchSemicolon())
{
// spoof a final ;
buffer[bIndex++] = ';';
Expand Down
30 changes: 19 additions & 11 deletions test/Dapper.AOT.Test/GeneralSqlParseTests.cs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
using Dapper.Internal;

using Dapper.Internal.SqlParsing;
using Xunit;

using static global::Dapper.SqlAnalysis.SqlSyntax;
namespace Dapper.AOT.Test;

public class GeneralSqlParseTests
Expand All @@ -23,7 +23,7 @@ public void BatchifyNonStrippedPostgresql() => Assert.Equal(
;
;
more
""", Batch.Semicolon, strip: false));
""", PostgreSql, strip: false));

[Fact]
public void BatchifyStrippedPostgresql() => Assert.Equal(
Expand All @@ -38,7 +38,7 @@ public void BatchifyStrippedPostgresql() => Assert.Equal(
;
;
more
""", Batch.Semicolon, strip: true));
""", PostgreSql, strip: true));

[Fact]
public void BatchifyNonStrippedSqlServer() => Assert.Equal(
Expand All @@ -58,7 +58,7 @@ public void BatchifyNonStrippedSqlServer() => Assert.Equal(
;
;
more
""", Batch.Go, strip: false));
""", SqlServer, strip: false));

[Fact]
public void BatchifyStrippedSqlServer() => Assert.Equal(
Expand All @@ -73,7 +73,7 @@ public void BatchifyStrippedSqlServer() => Assert.Equal(
;
;
more
""", Batch.Go, strip: true));
""", SqlServer, strip: true));

[Fact]
public void BatchifyNonStrippedSqlServer_Go() => Assert.Equal(
Expand All @@ -84,7 +84,7 @@ public void BatchifyNonStrippedSqlServer_Go() => Assert.Equal(
something
GO
something ' GO ' else;
""", Batch.Go, strip: true));
""", SqlServer, strip: true));

[Fact]
public void DetectArgs() => Assert.Equal(
Expand All @@ -95,18 +95,26 @@ public void DetectArgs() => Assert.Equal(
], GeneralSqlParser.Parse("""
select * from SomeTable
where Id = @foo and Name = '@bar'
""", Batch.Semicolon, strip: true));
""", PostgreSql, strip: true));

// Checks batching and variable detection together with strip: true. Each expected entry
// pairs a batch's text with the variables found in it and their offsets. Names that
// appear only inside a quoted literal ('@bar'), inside comments (-- $4, /* @abc */), or
// as @@IDENTITY are not listed as variables.
[Fact]
public void DetectArgsAndBatchify() => Assert.Equal(
[
new("select * from SomeTable where Id = @foo and Name = '@bar';", new("@foo", 35)),
new("insert Bar (Id) values ($1);", new("$1", 24)),
new("insert Bar (Id, X) values ($1, @@IDENTITY);", new("$1", 27)),
new("insert Blap (Id) values ($1, @foo);", new("$1", 25), new("@foo", 29)),
], GeneralSqlParser.Parse("""
select * from SomeTable where Id = @foo and Name = '@bar' -- $4
;
insert Bar (Id) /* @abc */ values ($1);
insert Bar (Id, X) /* @abc */ values ($1, @@IDENTITY);
insert Blap (Id) values ($1, @foo)
""", Batch.Semicolon, strip: true));
""", PostgreSql, strip: true));

// Checks escape handling with SqlServer syntax: a doubled quote ('') inside a '...'
// literal and a doubled bracket (]]) inside a [...] identifier do not end the token, so
// the @a/@b/@c/@d inside them are not reported as variables. The single expected batch
// carries the text only.
[Fact]
public void StringEscapingSqlServer() => Assert.Equal(
[
new("select ' @a '' @b ' as [ @c ]] @d ];") // no vars
], GeneralSqlParser.Parse("""
select ' @a '' @b ' as [ @c ]] @d ];
""", SqlServer, strip: true));
}

0 comments on commit 251d9e0

Please sign in to comment.