From cfc75536ca9221796594404750c34bf039761891 Mon Sep 17 00:00:00 2001 From: Nickolai Novik Date: Fri, 8 Jan 2016 18:42:29 +0200 Subject: [PATCH 1/3] proper fix fo multiple results issure proted from pymysql --- aiomysql/connection.py | 12 +++++++----- tests/test_connection.py | 15 +++++++++++++-- 2 files changed, 20 insertions(+), 7 deletions(-) diff --git a/aiomysql/connection.py b/aiomysql/connection.py index ee17eb9b..4cd9cced 100644 --- a/aiomysql/connection.py +++ b/aiomysql/connection.py @@ -372,9 +372,6 @@ def cursor(self, cursor=None): @asyncio.coroutine def query(self, sql, unbuffered=False): # logger.debug("DEBUG: sending query: %s", _convert_to_str(sql)) - if self._result is not None and self._result.has_next: - raise ProgrammingError("Previous results have not been fetched. " - "You may not close previous cursor.") if isinstance(sql, str): sql = sql.encode(self.encoding, 'surrogateescape') yield from self._execute_command(COM_QUERY, sql) @@ -550,8 +547,13 @@ def _execute_command(self, command, sql): # If the last query was unbuffered, make sure it finishes before # sending new commands - if self._result is not None and self._result.unbuffered_active: - yield from self._result._finish_unbuffered_query() + if self._result is not None: + if self._result.unbuffered_active: + warnings.warn("Previous unbuffered result was left incomplete") + self._result._finish_unbuffered_query() + while self._result.has_next: + yield from self.next_result() + self._result = None if isinstance(sql, str): sql = sql.encode(self._encoding) diff --git a/tests/test_connection.py b/tests/test_connection.py index d5065252..9734e112 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -250,5 +250,16 @@ def test_previous_cursor_not_closed(self): cur1 = yield from conn.cursor() yield from cur1.execute("SELECT 1; SELECT 2") cur2 = yield from conn.cursor() - with self.assertRaises(aiomysql.ProgrammingError): - yield from cur2.execute("SELECT 3") + yield from cur2.execute("SELECT 3;") + resp = yield from cur2.fetchone() + self.assertEqual(resp[0], 3) + + @run_until_complete + def test_commit_during_multi_result(self): + conn = yield from self.connect() + cur = yield from conn.cursor() + yield from cur.execute("SELECT 1; SELECT 2;") + yield from conn.commit() + yield from cur.execute("SELECT 3;") + resp = yield from cur.fetchone() + self.assertEqual(resp[0], 3) From d155d7232e71c7cfb2c43c1b8adfb1f85ab93b20 Mon Sep 17 00:00:00 2001 From: Nickolai Novik Date: Fri, 8 Jan 2016 20:30:48 +0200 Subject: [PATCH 2/3] do not do bulk import --- aiomysql/connection.py | 39 ++++++++++++++++++++------------------- 1 file changed, 20 insertions(+), 19 deletions(-) diff --git a/aiomysql/connection.py b/aiomysql/connection.py index 4cd9cced..438c663e 100644 --- a/aiomysql/connection.py +++ b/aiomysql/connection.py @@ -14,8 +14,8 @@ from pymysql.charset import charset_by_name, charset_by_id from pymysql.constants import SERVER_STATUS -from pymysql.constants.CLIENT import * # noqa -from pymysql.constants.COMMAND import * # noqa +from pymysql.constants import CLIENT +from pymysql.constants import COMMAND from pymysql.util import byte2int, int2byte from pymysql.converters import escape_item, encoders, decoders, escape_string from pymysql.err import (Warning, Error, @@ -177,12 +177,12 @@ def __init__(self, host="localhost", user=None, password="", self._encoding = charset_by_name(self._charset).encoding if local_infile: - client_flag |= LOCAL_FILES + client_flag |= CLIENT.LOCAL_FILES - client_flag |= CAPABILITIES - client_flag |= MULTI_STATEMENTS + client_flag |= CLIENT.CAPABILITIES + client_flag |= CLIENT.MULTI_STATEMENTS if self._db: - client_flag |= CONNECT_WITH_DB + client_flag |= CLIENT.CONNECT_WITH_DB self.client_flag = client_flag self.cursorclass = cursorclass @@ -268,7 +268,7 @@ def ensure_closed(self): if self._writer is None: # connection has been closed return - send_data = struct.pack('= 5: - self.client_flag |= MULTI_RESULTS + self.client_flag |= CLIENT.MULTI_RESULTS if self._user is None: raise ValueError("Did not specify a username") @@ -782,7 +782,8 @@ def _read_load_local_packet(self, first_packet): @asyncio.coroutine def _print_warnings(self): - yield from self.connection._execute_command(COM_QUERY, 'SHOW WARNINGS') + yield from self.connection._execute_command( + COMMAND.COM_QUERY, 'SHOW WARNINGS') yield from self.read() if self.rows: message = "\n" From 3b91d38d0e050c03671309ca6b9465b3d197ab66 Mon Sep 17 00:00:00 2001 From: Nickolai Novik Date: Fri, 8 Jan 2016 21:23:51 +0200 Subject: [PATCH 3/3] fix test --- tests/test_sscursor.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/test_sscursor.py b/tests/test_sscursor.py index 74624150..91d705ff 100644 --- a/tests/test_sscursor.py +++ b/tests/test_sscursor.py @@ -87,7 +87,9 @@ def test_sscursor_fetchmany(self): self.assertEqual(len(fetched_data), 2, 'fetchmany failed. Number of rows does not match') + yield from cursor.close() # test default fetchmany size + cursor = yield from conn.cursor(SSCursor) yield from cursor.execute('SELECT * FROM tz_data;') fetched_data = yield from cursor.fetchmany() self.assertEqual(len(fetched_data), 1)