From ca4317a6b5ffc4c78c711123b52dfda400d00eb5 Mon Sep 17 00:00:00 2001 From: INADA Naoki Date: Wed, 30 Aug 2017 18:49:02 +0900 Subject: [PATCH] Fix encoding tuple argument (#155) Since Connections.encoders is broken by design, tuples and lists are escaped directly in `Connection.literal()`. Removed tuple and list from the converters mapping. Fixes #145 --- MySQLdb/connections.py | 82 +++++++++++++++++++++--------------------- MySQLdb/converters.py | 8 ++--- MySQLdb/cursors.py | 10 +++--- _mysql.c | 2 ++ 4 files changed, 51 insertions(+), 51 deletions(-) diff --git a/MySQLdb/connections.py b/MySQLdb/connections.py index 8d9f2dd5..e5d9fa6f 100644 --- a/MySQLdb/connections.py +++ b/MySQLdb/connections.py @@ -186,7 +186,7 @@ class object, used to create cursors (keyword only) use_unicode = kwargs2.pop('use_unicode', use_unicode) sql_mode = kwargs2.pop('sql_mode', '') - binary_prefix = kwargs2.pop('binary_prefix', False) + self._binary_prefix = kwargs2.pop('binary_prefix', False) client_flag = kwargs.get('client_flag', 0) client_version = tuple([ numeric_part(n) for n in _mysql.get_client_info().split('.')[:2] ]) @@ -208,38 +208,28 @@ class object, used to create cursors (keyword only) self._server_version = tuple([ numeric_part(n) for n in self.get_server_info().split('.')[:2] ]) + self.encoding = 'ascii' # overriden in set_character_set() db = proxy(self) - def _get_string_literal(): - # Note: string_literal() is called for bytes object on Python 3 (via bytes_literal) - def string_literal(obj, dummy=None): - return db.string_literal(obj) - return string_literal - - def _get_unicode_literal(): - if PY2: - # unicode_literal is called for only unicode object. - def unicode_literal(u, dummy=None): - return db.string_literal(u.encode(unicode_literal.charset)) - else: - # unicode_literal() is called for arbitrary object. 
- def unicode_literal(u, dummy=None): - return db.string_literal(str(u).encode(unicode_literal.charset)) - return unicode_literal - - def _get_bytes_literal(): - def bytes_literal(obj, dummy=None): - return b'_binary' + db.string_literal(obj) - return bytes_literal - - def _get_string_decoder(): - def string_decoder(s): - return s.decode(string_decoder.charset) - return string_decoder - - string_literal = _get_string_literal() - self.unicode_literal = unicode_literal = _get_unicode_literal() - bytes_literal = _get_bytes_literal() - self.string_decoder = string_decoder = _get_string_decoder() + + # Note: string_literal() is called for bytes object on Python 3 (via bytes_literal) + def string_literal(obj, dummy=None): + return db.string_literal(obj) + + if PY2: + # unicode_literal is called for only unicode object. + def unicode_literal(u, dummy=None): + return db.string_literal(u.encode(db.encoding)) + else: + # unicode_literal() is called for arbitrary object. + def unicode_literal(u, dummy=None): + return db.string_literal(str(u).encode(db.encoding)) + + def bytes_literal(obj, dummy=None): + return b'_binary' + db.string_literal(obj) + + def string_decoder(s): + return s.decode(db.encoding) + if not charset: charset = self.character_set_name() self.set_character_set(charset) @@ -253,12 +243,7 @@ def string_decoder(s): self.converter[FIELD_TYPE.VARCHAR].append((None, string_decoder)) self.converter[FIELD_TYPE.BLOB].append((None, string_decoder)) - if binary_prefix: - self.encoders[bytes] = string_literal if PY2 else bytes_literal - self.encoders[bytearray] = bytes_literal - else: - self.encoders[bytes] = string_literal - + self.encoders[bytes] = string_literal self.encoders[unicode] = unicode_literal self._transactional = self.server_capabilities & CLIENT.TRANSACTIONS if self._transactional: @@ -305,6 +290,16 @@ def __exit__(self, exc, value, tb): else: self.commit() + def _bytes_literal(self, bs): + assert isinstance(bs, (bytes, bytearray)) + x = 
self.string_literal(bs) # x is escaped and quoted bytes + if self._binary_prefix: + return b'_binary' + x + return x + + def _tuple_literal(self, t, d): + return "(%s)" % (','.join(map(self.literal, t))) + def literal(self, o): """If o is a single object, returns an SQL literal as a string. If o is a non-string sequence, the items of the sequence are @@ -313,7 +308,14 @@ def literal(self, o): Non-standard. For internal use; do not use this in your applications. """ - s = self.escape(o, self.encoders) + if isinstance(o, bytearray): + s = self._bytes_literal(o) + elif not PY2 and isinstance(o, bytes): + s = self._bytes_literal(o) + elif isinstance(o, (tuple, list)): + s = self._tuple_literal(o) + else: + s = self.escape(o, self.encoders) # Python 3(~3.4) doesn't support % operation for bytes object. # We should decode it before using %. # Decoding with ascii and surrogateescape allows convert arbitrary @@ -360,8 +362,6 @@ def set_character_set(self, charset): raise NotSupportedError("server is too old to set charset") self.query('SET NAMES %s' % charset) self.store_result() - self.string_decoder.charset = py_charset - self.unicode_literal.charset = py_charset self.encoding = py_charset def set_sql_mode(self, sql_mode): diff --git a/MySQLdb/converters.py b/MySQLdb/converters.py index 8a9908d4..505a4df0 100644 --- a/MySQLdb/converters.py +++ b/MySQLdb/converters.py @@ -29,10 +29,9 @@ Don't modify conversions if you can avoid it. Instead, make copies (with the copy() method), modify the copies, and then pass them to MySQL.connect(). - """ -from _mysql import string_literal, escape_sequence, escape_dict, escape, NULL +from _mysql import string_literal, escape, NULL from MySQLdb.constants import FIELD_TYPE, FLAG from MySQLdb.times import * from MySQLdb.compat import PY2, long @@ -53,6 +52,7 @@ def Str2Set(s): return set([ i for i in s.split(',') if i ]) def Set2Str(s, d): + # Only support ascii string. Not tested. 
return string_literal(','.join(s), d) def Thing2Str(s, d): @@ -97,9 +97,6 @@ def quote_tuple(t, d): long: Thing2Str, float: Float2Str, NoneType: None2NULL, - tuple: quote_tuple, - list: quote_tuple, - dict: escape_dict, ArrayType: array2Str, bool: Bool2Str, Date: Thing2Literal, @@ -107,6 +104,7 @@ def quote_tuple(t, d): DateTimeDeltaType: DateTimeDelta2literal, str: Thing2Literal, # default set: Set2Str, + FIELD_TYPE.TINY: int, FIELD_TYPE.SHORT: int, FIELD_TYPE.LONG: long, diff --git a/MySQLdb/cursors.py b/MySQLdb/cursors.py index dfbd736b..3769ab5e 100644 --- a/MySQLdb/cursors.py +++ b/MySQLdb/cursors.py @@ -225,7 +225,7 @@ def execute(self, query, args=None): # db.literal(obj) always returns str. if PY2 and isinstance(query, unicode): - query = query.encode(db.unicode_literal.charset) + query = query.encode(db.encoding) if args is not None: if isinstance(args, dict): @@ -233,14 +233,14 @@ def execute(self, query, args=None): else: args = tuple(map(db.literal, args)) if not PY2 and isinstance(query, (bytes, bytearray)): - query = query.decode(db.unicode_literal.charset) + query = query.decode(db.encoding) try: query = query % args except TypeError as m: self.errorhandler(self, ProgrammingError, str(m)) if isinstance(query, unicode): - query = query.encode(db.unicode_literal.charset, 'surrogateescape') + query = query.encode(db.encoding, 'surrogateescape') res = None try: @@ -353,7 +353,7 @@ def callproc(self, procname, args=()): q = "SET @_%s_%d=%s" % (procname, index, db.literal(arg)) if isinstance(q, unicode): - q = q.encode(db.unicode_literal.charset, 'surrogateescape') + q = q.encode(db.encoding, 'surrogateescape') self._query(q) self.nextset() @@ -361,7 +361,7 @@ def callproc(self, procname, args=()): ','.join(['@_%s_%d' % (procname, i) for i in range(len(args))])) if isinstance(q, unicode): - q = q.encode(db.unicode_literal.charset, 'surrogateescape') + q = q.encode(db.encoding, 'surrogateescape') self._query(q) self._executed = q if not 
self._defer_warnings: diff --git a/_mysql.c b/_mysql.c index b0b8fdeb..36672373 100644 --- a/_mysql.c +++ b/_mysql.c @@ -2777,12 +2777,14 @@ _mysql_methods[] = { _mysql_escape__doc__ }, { + // deprecated. "escape_sequence", (PyCFunction)_mysql_escape_sequence, METH_VARARGS, _mysql_escape_sequence__doc__ }, { + // deprecated. "escape_dict", (PyCFunction)_mysql_escape_dict, METH_VARARGS,