Skip to content

Commit

Permalink
Merge pull request #169 from slimta/fix
Browse files Browse the repository at this point in the history
Various bug fixes from issues
  • Loading branch information
icgood authored Feb 14, 2021
2 parents 45ff459 + 98c2b9b commit d7008cc
Show file tree
Hide file tree
Showing 9 changed files with 87 additions and 28 deletions.
1 change: 0 additions & 1 deletion .travis.yml
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
language: python
python:
- "2.7"
- "3.5"
- "3.6"
- "3.7"
- "3.8"
Expand Down
3 changes: 1 addition & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
license = f.read()

setup(name='python-slimta',
version='4.1.1',
version='4.2.0',
author='Ian Good',
author_email='[email protected]',
description='Lightweight, asynchronous SMTP libraries.',
Expand Down Expand Up @@ -56,7 +56,6 @@
'License :: OSI Approved :: MIT License',
'Programming Language :: Python',
'Programming Language :: Python :: 2.7',
'Programming Language :: Python :: 3.5',
'Programming Language :: Python :: 3.6',
'Programming Language :: Python :: 3.7',
'Programming Language :: Python :: 3.8'])
Expand Down
5 changes: 4 additions & 1 deletion slimta/edge/smtp.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,8 @@ class SmtpValidators(object):
- ``handle_rset(reply)``: Called before replying to an RSET command.
- ``handle_tls()``: Called after a successful TLS handshake. This may be at
the beginning of the session or after a `STARTTLS` command.
- ``handle_tls2(ssl_socket)``: Identical to ``handle_tls()`` except the new
:class:`~ssl.SSLSocket` is passed in as an argument.
:param session: When sub-classes are instantiated, instances are passed
this object, stored and described in :attr:`session` below,
Expand Down Expand Up @@ -137,8 +139,9 @@ def HELO(self, reply, helo_as):
self.ehlo_as = helo_as
self.envelope = None

def TLSHANDSHAKE(self):
def TLSHANDSHAKE2(self, ssl_socket):
self._call_validator('tls')
self._call_validator('tls2', ssl_socket)
self.security = 'TLS'

def AUTH(self, reply, creds):
Expand Down
7 changes: 6 additions & 1 deletion slimta/smtp/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,12 @@ class BadReply(SmtpError):
"""

def __init__(self, data):
super(BadReply, self).__init__('Bad SMTP reply from server.')
if data:
data_str = data.decode('utf-8', 'replace')
msg = 'Bad SMTP reply from server:\r\n' + data_str
else:
msg = 'Bad SMTP reply from server.'
super(BadReply, self).__init__(msg)
self.data = data


Expand Down
1 change: 1 addition & 0 deletions slimta/smtp/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,7 @@ def _encrypt_session(self):
if not self.io.encrypt_socket_server(self.context):
return False
self._call_custom_handler('TLSHANDSHAKE')
self._call_custom_handler('TLSHANDSHAKE2', self.io.socket)
return True

def _check_close_code(self, reply):
Expand Down
52 changes: 44 additions & 8 deletions slimta/util/dns.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@

from __future__ import absolute_import

from collections import OrderedDict
from functools import partial

import pycares
Expand Down Expand Up @@ -108,24 +109,59 @@ def _result_cb(cls, result, answer, errno):
else:
result.set(answer)

@classmethod
def _distinct(cls, read_fds, write_fds):
seen = set()
for fd in read_fds:
if fd not in seen:
yield fd
seen.add(fd)
for fd in write_fds:
if fd not in seen:
yield fd
seen.add(fd)

@classmethod
def _register_fds(cls, poll, prev_fds_map):
# we must mimic the behavior of pycares sock_state_cb to maintain
# compatibility with custom DNSResolver.channel objects.
fds_map = OrderedDict()
_read_fds, _write_fds = cls._channel.getsock()
read_fds = set(_read_fds)
write_fds = set(_write_fds)
for fd in cls._distinct(_read_fds, _write_fds):
event = 0
if fd in read_fds:
event |= select.POLLIN
if fd in write_fds:
event |= select.POLLOUT
fds_map[fd] = event
prev_event = prev_fds_map.pop(fd, 0)
if event != prev_event:
poll.register(fd, event)
for fd in prev_fds_map:
poll.unregister(fd)
return fds_map

@classmethod
def _wait_channel(cls):
poll = select.poll()
fds_map = OrderedDict()
try:
while True:
read_fds, write_fds = cls._channel.getsock()
if not read_fds and not write_fds:
fds_map = cls._register_fds(poll, fds_map)
if not fds_map:
break
timeout = cls._channel.timeout()
if not timeout:
cls._channel.process_fd(pycares.ARES_SOCKET_BAD,
pycares.ARES_SOCKET_BAD)
continue
rlist, wlist, xlist = select.select(
read_fds, write_fds, [], timeout)
for fd in rlist:
cls._channel.process_fd(fd, pycares.ARES_SOCKET_BAD)
for fd in wlist:
cls._channel.process_fd(pycares.ARES_SOCKET_BAD, fd)
for fd, event in poll.poll(timeout):
if event & (select.POLLIN | select.POLLPRI):
cls._channel.process_fd(fd, pycares.ARES_SOCKET_BAD)
if event & select.POLLOUT:
cls._channel.process_fd(pycares.ARES_SOCKET_BAD, fd)
except Exception:
logging.log_exception(__name__)
cls._channel.cancel()
Expand Down
1 change: 0 additions & 1 deletion slimta/util/dnsbl.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,6 @@ def get(self, ip, timeout=None, strict=False):
if exc.errno == ARES_ENOTFOUND:
return False
logging.log_exception(__name__, query=query)
return not strict
else:
return True
return strict
Expand Down
5 changes: 4 additions & 1 deletion test/test_slimta_edge_smtp.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from mox3.mox import MoxTestBase, IsA, IgnoreArg
import gevent
from gevent.socket import create_connection
from gevent.ssl import SSLSocket

from slimta.edge.smtp import SmtpEdge, SmtpSession
from slimta.envelope import Envelope
Expand Down Expand Up @@ -47,17 +48,19 @@ def test_extended_handshake(self):
creds = self.mox.CreateMockAnything()
creds.authcid = 'testuser'
creds.authzid = 'testzid'
ssl_sock = self.mox.CreateMock(SSLSocket)
mock = self.mox.CreateMockAnything()
mock.__call__(IsA(SmtpSession)).AndReturn(mock)
mock.handle_banner(IsA(Reply), ('127.0.0.1', 0))
mock.handle_ehlo(IsA(Reply), 'there')
mock.handle_tls()
mock.handle_tls2(IsA(SSLSocket))
mock.handle_auth(IsA(Reply), creds)
self.mox.ReplayAll()
h = SmtpSession(('127.0.0.1', 0), mock, None)
h.BANNER_(Reply('220'))
h.EHLO(Reply('250'), 'there')
h.TLSHANDSHAKE()
h.TLSHANDSHAKE2(ssl_sock)
h.AUTH(Reply('235'), creds)
self.assertEqual('there', h.ehlo_as)
self.assertTrue(h.extended_smtp)
Expand Down
40 changes: 27 additions & 13 deletions test/test_slimta_util_dns.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,28 +42,42 @@ def test_query(self):

def test_wait_channel(self):
DNSResolver._channel = channel = self.mox.CreateMockAnything()
self.mox.StubOutWithMock(select, 'select')
channel.getsock().AndReturn(('read', 'write'))
poll = self.mox.CreateMockAnything()
self.mox.StubOutWithMock(select, 'poll')
select.poll().AndReturn(poll)
channel.getsock().AndReturn(([1, 2], [2, 3]))
channel.timeout().AndReturn(1.0)
select.select('read', 'write', [], 1.0).AndReturn(
([1, 2, 3], [4, 5, 6], None))
for fd in [1, 2, 3]:
channel.process_fd(fd, pycares.ARES_SOCKET_BAD)
for fd in [4, 5, 6]:
channel.process_fd(pycares.ARES_SOCKET_BAD, fd)
channel.getsock().AndReturn(('read', 'write'))
poll.register(1, select.POLLIN)
poll.register(2, select.POLLIN | select.POLLOUT)
poll.register(3, select.POLLOUT)
poll.poll(1.0).AndReturn([(1, select.POLLIN), (3, select.POLLOUT)])
channel.process_fd(1, pycares.ARES_SOCKET_BAD)
channel.process_fd(pycares.ARES_SOCKET_BAD, 3)
channel.getsock().AndReturn(([1, 3], [4]))
channel.timeout().AndReturn(1.0)
poll.register(3, select.POLLIN)
poll.register(4, select.POLLOUT)
poll.unregister(2)
poll.poll(1.0).AndReturn([])
channel.getsock().AndReturn(([1, 3], [4]))
channel.timeout().AndReturn(None)
channel.process_fd(pycares.ARES_SOCKET_BAD, pycares.ARES_SOCKET_BAD)
channel.getsock().AndReturn((None, None))
channel.getsock().AndReturn(([], []))
poll.unregister(1)
poll.unregister(3)
poll.unregister(4)
self.mox.ReplayAll()
DNSResolver._wait_channel()

def test_wait_channel_error(self):
DNSResolver._channel = channel = self.mox.CreateMockAnything()
self.mox.StubOutWithMock(select, 'select')
channel.getsock().AndReturn(('read', 'write'))
poll = self.mox.CreateMockAnything()
self.mox.StubOutWithMock(select, 'poll')
select.poll().AndReturn(poll)
channel.getsock().AndReturn(([1], []))
channel.timeout().AndReturn(1.0)
select.select('read', 'write', [], 1.0).AndRaise(ValueError(13))
poll.register(1, select.POLLIN).AndReturn(None)
poll.poll(1.0).AndRaise(ValueError(13))
channel.cancel()
self.mox.ReplayAll()
with self.assertRaises(ValueError):
Expand Down

0 comments on commit d7008cc

Please sign in to comment.