Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Proper fix fo multiple results issure ported from pymysql #52

Merged
merged 3 commits into from
Jan 8, 2016
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 27 additions & 24 deletions aiomysql/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@

from pymysql.charset import charset_by_name, charset_by_id
from pymysql.constants import SERVER_STATUS
from pymysql.constants.CLIENT import * # noqa
from pymysql.constants.COMMAND import * # noqa
from pymysql.constants import CLIENT
from pymysql.constants import COMMAND
from pymysql.util import byte2int, int2byte
from pymysql.converters import escape_item, encoders, decoders, escape_string
from pymysql.err import (Warning, Error,
Expand Down Expand Up @@ -177,12 +177,12 @@ def __init__(self, host="localhost", user=None, password="",
self._encoding = charset_by_name(self._charset).encoding

if local_infile:
client_flag |= LOCAL_FILES
client_flag |= CLIENT.LOCAL_FILES

client_flag |= CAPABILITIES
client_flag |= MULTI_STATEMENTS
client_flag |= CLIENT.CAPABILITIES
client_flag |= CLIENT.MULTI_STATEMENTS
if self._db:
client_flag |= CONNECT_WITH_DB
client_flag |= CLIENT.CONNECT_WITH_DB
self.client_flag = client_flag

self.cursorclass = cursorclass
Expand Down Expand Up @@ -268,7 +268,7 @@ def ensure_closed(self):
if self._writer is None:
# connection has been closed
return
send_data = struct.pack('<i', 1) + int2byte(COM_QUIT)
send_data = struct.pack('<i', 1) + int2byte(COMMAND.COM_QUIT)
self._writer.write(send_data)
yield from self._writer.drain()
self.close()
Expand Down Expand Up @@ -305,32 +305,32 @@ def _read_ok_packet(self):
def _send_autocommit_mode(self):
"""Set whether or not to commit after every execute() """
yield from self._execute_command(
COM_QUERY,
COMMAND.COM_QUERY,
"SET AUTOCOMMIT = %s" % self.escape(self.autocommit_mode))
yield from self._read_ok_packet()

@asyncio.coroutine
def begin(self):
"""Begin transaction."""
yield from self._execute_command(COM_QUERY, "BEGIN")
yield from self._execute_command(COMMAND.COM_QUERY, "BEGIN")
yield from self._read_ok_packet()

@asyncio.coroutine
def commit(self):
"""Commit changes to stable storage."""
yield from self._execute_command(COM_QUERY, "COMMIT")
yield from self._execute_command(COMMAND.COM_QUERY, "COMMIT")
yield from self._read_ok_packet()

@asyncio.coroutine
def rollback(self):
"""Roll back the current transaction."""
yield from self._execute_command(COM_QUERY, "ROLLBACK")
yield from self._execute_command(COMMAND.COM_QUERY, "ROLLBACK")
yield from self._read_ok_packet()

@asyncio.coroutine
def select_db(self, db):
"""Set current db"""
yield from self._execute_command(COM_INIT_DB, db)
yield from self._execute_command(COMMAND.COM_INIT_DB, db)
yield from self._read_ok_packet()

def escape(self, obj):
Expand Down Expand Up @@ -372,12 +372,9 @@ def cursor(self, cursor=None):
@asyncio.coroutine
def query(self, sql, unbuffered=False):
# logger.debug("DEBUG: sending query: %s", _convert_to_str(sql))
if self._result is not None and self._result.has_next:
raise ProgrammingError("Previous results have not been fetched. "
"You may not close previous cursor.")
if isinstance(sql, str):
sql = sql.encode(self.encoding, 'surrogateescape')
yield from self._execute_command(COM_QUERY, sql)
yield from self._execute_command(COMMAND.COM_QUERY, sql)
yield from self._read_query_result(unbuffered=unbuffered)
return self._affected_rows

Expand All @@ -392,7 +389,7 @@ def affected_rows(self):
@asyncio.coroutine
def kill(self, thread_id):
arg = struct.pack('<I', thread_id)
yield from self._execute_command(COM_PROCESS_KILL, arg)
yield from self._execute_command(COMMAND.COM_PROCESS_KILL, arg)
yield from self._read_ok_packet()

@asyncio.coroutine
Expand All @@ -405,7 +402,7 @@ def ping(self, reconnect=True):
else:
raise Error("Already closed")
try:
yield from self._execute_command(COM_PING, "")
yield from self._execute_command(COMMAND.COM_PING, "")
yield from self._read_ok_packet()
except Exception:
if reconnect:
Expand All @@ -419,7 +416,7 @@ def set_charset(self, charset):
"""Sets the character set for the current connection"""
# Make sure charset is supported.
encoding = charset_by_name(charset).encoding
yield from self._execute_command(COM_QUERY, "SET NAMES %s"
yield from self._execute_command(COMMAND.COM_QUERY, "SET NAMES %s"
% self.escape(charset))
yield from self._read_packet()
self._charset = charset
Expand Down Expand Up @@ -550,8 +547,13 @@ def _execute_command(self, command, sql):

# If the last query was unbuffered, make sure it finishes before
# sending new commands
if self._result is not None and self._result.unbuffered_active:
yield from self._result._finish_unbuffered_query()
if self._result is not None:
if self._result.unbuffered_active:
warnings.warn("Previous unbuffered result was left incomplete")
self._result._finish_unbuffered_query()
while self._result.has_next:
yield from self.next_result()
self._result = None

if isinstance(sql, str):
sql = sql.encode(self._encoding)
Expand Down Expand Up @@ -579,9 +581,9 @@ def _execute_command(self, command, sql):

@asyncio.coroutine
def _request_authentication(self):
self.client_flag |= CAPABILITIES
self.client_flag |= CLIENT.CAPABILITIES
if int(self.server_version.split('.', 1)[0]) >= 5:
self.client_flag |= MULTI_RESULTS
self.client_flag |= CLIENT.MULTI_RESULTS

if self._user is None:
raise ValueError("Did not specify a username")
Expand Down Expand Up @@ -780,7 +782,8 @@ def _read_load_local_packet(self, first_packet):

@asyncio.coroutine
def _print_warnings(self):
yield from self.connection._execute_command(COM_QUERY, 'SHOW WARNINGS')
yield from self.connection._execute_command(
COMMAND.COM_QUERY, 'SHOW WARNINGS')
yield from self.read()
if self.rows:
message = "\n"
Expand Down
15 changes: 13 additions & 2 deletions tests/test_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,5 +250,16 @@ def test_previous_cursor_not_closed(self):
cur1 = yield from conn.cursor()
yield from cur1.execute("SELECT 1; SELECT 2")
cur2 = yield from conn.cursor()
with self.assertRaises(aiomysql.ProgrammingError):
yield from cur2.execute("SELECT 3")
yield from cur2.execute("SELECT 3;")
resp = yield from cur2.fetchone()
self.assertEqual(resp[0], 3)

@run_until_complete
def test_commit_during_multi_result(self):
conn = yield from self.connect()
cur = yield from conn.cursor()
yield from cur.execute("SELECT 1; SELECT 2;")
yield from conn.commit()
yield from cur.execute("SELECT 3;")
resp = yield from cur.fetchone()
self.assertEqual(resp[0], 3)
2 changes: 2 additions & 0 deletions tests/test_sscursor.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,9 @@ def test_sscursor_fetchmany(self):
self.assertEqual(len(fetched_data), 2,
'fetchmany failed. Number of rows does not match')

yield from cursor.close()
# test default fetchmany size
cursor = yield from conn.cursor(SSCursor)
yield from cursor.execute('SELECT * FROM tz_data;')
fetched_data = yield from cursor.fetchmany()
self.assertEqual(len(fetched_data), 1)
Expand Down