Skip to content

Commit

Permalink
Fix encoding tuple argument (#155)
Browse files Browse the repository at this point in the history
`Connections.encoders` is broken by design, so tuples and lists are now
escaped directly in `Connection.literal()`.
Tuple and list have been removed from the converters mapping.

Fixes #145
  • Loading branch information
methane authored Aug 30, 2017
1 parent e39df07 commit ca4317a
Show file tree
Hide file tree
Showing 4 changed files with 51 additions and 51 deletions.
82 changes: 41 additions & 41 deletions MySQLdb/connections.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,7 @@ class object, used to create cursors (keyword only)

use_unicode = kwargs2.pop('use_unicode', use_unicode)
sql_mode = kwargs2.pop('sql_mode', '')
binary_prefix = kwargs2.pop('binary_prefix', False)
self._binary_prefix = kwargs2.pop('binary_prefix', False)

client_flag = kwargs.get('client_flag', 0)
client_version = tuple([ numeric_part(n) for n in _mysql.get_client_info().split('.')[:2] ])
Expand All @@ -208,38 +208,28 @@ class object, used to create cursors (keyword only)

self._server_version = tuple([ numeric_part(n) for n in self.get_server_info().split('.')[:2] ])

self.encoding = 'ascii' # overriden in set_character_set()
db = proxy(self)
def _get_string_literal():
# Note: string_literal() is called for bytes object on Python 3 (via bytes_literal)
def string_literal(obj, dummy=None):
return db.string_literal(obj)
return string_literal

def _get_unicode_literal():
if PY2:
# unicode_literal is called for only unicode object.
def unicode_literal(u, dummy=None):
return db.string_literal(u.encode(unicode_literal.charset))
else:
# unicode_literal() is called for arbitrary object.
def unicode_literal(u, dummy=None):
return db.string_literal(str(u).encode(unicode_literal.charset))
return unicode_literal

def _get_bytes_literal():
def bytes_literal(obj, dummy=None):
return b'_binary' + db.string_literal(obj)
return bytes_literal

def _get_string_decoder():
def string_decoder(s):
return s.decode(string_decoder.charset)
return string_decoder

string_literal = _get_string_literal()
self.unicode_literal = unicode_literal = _get_unicode_literal()
bytes_literal = _get_bytes_literal()
self.string_decoder = string_decoder = _get_string_decoder()

# Note: string_literal() is called for bytes object on Python 3 (via bytes_literal)
def string_literal(obj, dummy=None):
return db.string_literal(obj)

if PY2:
# unicode_literal is called for only unicode object.
def unicode_literal(u, dummy=None):
return db.string_literal(u.encode(db.encoding))
else:
# unicode_literal() is called for arbitrary object.
def unicode_literal(u, dummy=None):
return db.string_literal(str(u).encode(db.encoding))

def bytes_literal(obj, dummy=None):
return b'_binary' + db.string_literal(obj)

def string_decoder(s):
return s.decode(db.encoding)

if not charset:
charset = self.character_set_name()
self.set_character_set(charset)
Expand All @@ -253,12 +243,7 @@ def string_decoder(s):
self.converter[FIELD_TYPE.VARCHAR].append((None, string_decoder))
self.converter[FIELD_TYPE.BLOB].append((None, string_decoder))

if binary_prefix:
self.encoders[bytes] = string_literal if PY2 else bytes_literal
self.encoders[bytearray] = bytes_literal
else:
self.encoders[bytes] = string_literal

self.encoders[bytes] = string_literal
self.encoders[unicode] = unicode_literal
self._transactional = self.server_capabilities & CLIENT.TRANSACTIONS
if self._transactional:
Expand Down Expand Up @@ -305,6 +290,16 @@ def __exit__(self, exc, value, tb):
else:
self.commit()

def _bytes_literal(self, bs):
assert isinstance(bs, (bytes, bytearray))
x = self.string_literal(bs) # x is escaped and quoted bytes
if self._binary_prefix:
return b'_binary' + x
return x

def _tuple_literal(self, t, d):
return "(%s)" % (','.join(map(self.literal, t)))

def literal(self, o):
"""If o is a single object, returns an SQL literal as a string.
If o is a non-string sequence, the items of the sequence are
Expand All @@ -313,7 +308,14 @@ def literal(self, o):
Non-standard. For internal use; do not use this in your
applications.
"""
s = self.escape(o, self.encoders)
if isinstance(o, bytearray):
s = self._bytes_literal(o)
elif not PY2 and isinstance(o, bytes):
s = self._bytes_literal(o)
elif isinstance(o, (tuple, list)):
s = self._tuple_literal(o)
else:
s = self.escape(o, self.encoders)
# Python 3(~3.4) doesn't support % operation for bytes object.
# We should decode it before using %.
# Decoding with ascii and surrogateescape allows convert arbitrary
Expand Down Expand Up @@ -360,8 +362,6 @@ def set_character_set(self, charset):
raise NotSupportedError("server is too old to set charset")
self.query('SET NAMES %s' % charset)
self.store_result()
self.string_decoder.charset = py_charset
self.unicode_literal.charset = py_charset
self.encoding = py_charset

def set_sql_mode(self, sql_mode):
Expand Down
8 changes: 3 additions & 5 deletions MySQLdb/converters.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,9 @@
Don't modify conversions if you can avoid it. Instead, make copies
(with the copy() method), modify the copies, and then pass them to
MySQL.connect().
"""

from _mysql import string_literal, escape_sequence, escape_dict, escape, NULL
from _mysql import string_literal, escape, NULL
from MySQLdb.constants import FIELD_TYPE, FLAG
from MySQLdb.times import *
from MySQLdb.compat import PY2, long
Expand All @@ -53,6 +52,7 @@ def Str2Set(s):
return set([ i for i in s.split(',') if i ])

def Set2Str(s, d):
    """Join the items of *s* with commas and pass the result, together
    with *d*, to ``string_literal``.

    Only ASCII strings are supported.  This converter is not tested.
    """
    joined = ','.join(s)
    return string_literal(joined, d)

def Thing2Str(s, d):
Expand Down Expand Up @@ -97,16 +97,14 @@ def quote_tuple(t, d):
long: Thing2Str,
float: Float2Str,
NoneType: None2NULL,
tuple: quote_tuple,
list: quote_tuple,
dict: escape_dict,
ArrayType: array2Str,
bool: Bool2Str,
Date: Thing2Literal,
DateTimeType: DateTime2literal,
DateTimeDeltaType: DateTimeDelta2literal,
str: Thing2Literal, # default
set: Set2Str,

FIELD_TYPE.TINY: int,
FIELD_TYPE.SHORT: int,
FIELD_TYPE.LONG: long,
Expand Down
10 changes: 5 additions & 5 deletions MySQLdb/cursors.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,22 +225,22 @@ def execute(self, query, args=None):
# db.literal(obj) always returns str.

if PY2 and isinstance(query, unicode):
query = query.encode(db.unicode_literal.charset)
query = query.encode(db.encoding)

if args is not None:
if isinstance(args, dict):
args = dict((key, db.literal(item)) for key, item in args.items())
else:
args = tuple(map(db.literal, args))
if not PY2 and isinstance(query, (bytes, bytearray)):
query = query.decode(db.unicode_literal.charset)
query = query.decode(db.encoding)
try:
query = query % args
except TypeError as m:
self.errorhandler(self, ProgrammingError, str(m))

if isinstance(query, unicode):
query = query.encode(db.unicode_literal.charset, 'surrogateescape')
query = query.encode(db.encoding, 'surrogateescape')

res = None
try:
Expand Down Expand Up @@ -353,15 +353,15 @@ def callproc(self, procname, args=()):
q = "SET @_%s_%d=%s" % (procname, index,
db.literal(arg))
if isinstance(q, unicode):
q = q.encode(db.unicode_literal.charset, 'surrogateescape')
q = q.encode(db.encoding, 'surrogateescape')
self._query(q)
self.nextset()

q = "CALL %s(%s)" % (procname,
','.join(['@_%s_%d' % (procname, i)
for i in range(len(args))]))
if isinstance(q, unicode):
q = q.encode(db.unicode_literal.charset, 'surrogateescape')
q = q.encode(db.encoding, 'surrogateescape')
self._query(q)
self._executed = q
if not self._defer_warnings:
Expand Down
2 changes: 2 additions & 0 deletions _mysql.c
Original file line number Diff line number Diff line change
Expand Up @@ -2777,12 +2777,14 @@ _mysql_methods[] = {
_mysql_escape__doc__
},
{
// deprecated.
"escape_sequence",
(PyCFunction)_mysql_escape_sequence,
METH_VARARGS,
_mysql_escape_sequence__doc__
},
{
// deprecated.
"escape_dict",
(PyCFunction)_mysql_escape_dict,
METH_VARARGS,
Expand Down

0 comments on commit ca4317a

Please sign in to comment.