Skip to content

Commit

Permalink
implement a shortcut for determining secure connections, now supporti…
Browse files Browse the repository at this point in the history
…ng unix sockets

ports PyMySQL/PyMySQL#696
  • Loading branch information
Nothing4You committed Jan 28, 2022
1 parent 2955052 commit 7519b18
Showing 1 changed file with 9 additions and 5 deletions.
14 changes: 9 additions & 5 deletions aiomysql/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,6 +229,7 @@ def __init__(self, host="localhost", user=None, password="",
self._client_auth_plugin = auth_plugin
self._server_auth_plugin = ""
self._auth_plugin_used = ""
self._secure = False
self.server_public_key = server_public_key
self.salt = None

Expand Down Expand Up @@ -526,14 +527,15 @@ async def _connect(self):
# raise OperationalError(CR.CR_SERVER_GONE_ERROR,
# "MySQL server has gone away (%r)" % (e,))
try:
if self._unix_socket and self._host in ('localhost', '127.0.0.1'):
if self._unix_socket:
self._reader, self._writer = await \
asyncio.wait_for(
_open_unix_connection(
self._unix_socket),
timeout=self.connect_timeout)
self.host_info = "Localhost via UNIX socket: " + \
self._unix_socket
self._secure = True
else:
self._reader, self._writer = await \
asyncio.wait_for(
Expand Down Expand Up @@ -743,7 +745,7 @@ async def _request_authentication(self):
if self.user is None:
raise ValueError("Did not specify a username")

if self._ssl_context:
if self._ssl_context and self.server_capabilities & CLIENT.SSL:
# capablities, max packet, charset
data = struct.pack('<IIB', self.client_flag, 16777216, 33)
data += b'\x00' * (32 - len(data))
Expand All @@ -770,6 +772,8 @@ async def _request_authentication(self):
server_hostname=self._host
)

self._secure = True

charset_id = charset_by_name(self.charset).id
if isinstance(self.user, str):
_user = self.user.encode(self.encoding)
Expand Down Expand Up @@ -798,7 +802,7 @@ async def _request_authentication(self):
)
# Else: empty password
elif auth_plugin == 'sha256_password':
if self._ssl_context and self.server_capabilities & CLIENT.SSL:
if self._secure:
authresp = self._password.encode('latin1') + b'\0'
elif self._password:
authresp = b'\1' # request public key
Expand Down Expand Up @@ -960,7 +964,7 @@ async def caching_sha2_password_auth(self, pkt):

logger.debug("caching sha2: Trying full auth...")

if self._ssl_context:
if self._secure:
logger.debug("caching sha2: Sending plain "
"password via secure connection")
self.write_packet(self._password.encode('latin1') + b'\0')
Expand Down Expand Up @@ -991,7 +995,7 @@ async def caching_sha2_password_auth(self, pkt):
pkt.check_error()

async def sha256_password_auth(self, pkt):
if self._ssl_context:
if self._secure:
logger.debug("sha256: Sending plain password")
data = self._password.encode('latin1') + b'\0'
self.write_packet(data)
Expand Down

0 comments on commit 7519b18

Please sign in to comment.