diff --git a/napalm/base/base.py b/napalm/base/base.py index 3e4f15e60..7a34e6fc1 100644 --- a/napalm/base/base.py +++ b/napalm/base/base.py @@ -16,14 +16,16 @@ from __future__ import print_function from __future__ import unicode_literals +import sys + +from netmiko import ConnectHandler, NetMikoTimeoutException + # local modules import napalm.base.exceptions -from napalm.base.exceptions import ConnectionException import napalm.base.helpers from napalm.base import constants as c from napalm.base import validate - -from netmiko import ConnectHandler, NetMikoTimeoutException +from napalm.base.exceptions import ConnectionException class NetworkDriver(object): @@ -44,8 +46,15 @@ def __init__(self, hostname, username, password, timeout=60, optional_args=None) raise NotImplementedError def __enter__(self): - self.open() - return self + try: + self.open() + return self + except: + # Swallow exception if __exit__ returns a True value + if self.__exit__(*sys.exc_info()): + pass + else: + raise def __exit__(self, exc_type, exc_value, exc_traceback): self.close() @@ -95,8 +104,9 @@ def _netmiko_open(self, device_type, netmiko_optional_args=None): def _netmiko_close(self): """Standardized method of closing a Netmiko connection.""" - self.device.disconnect() - self._netmiko_device = None + if getattr(self, "_netmiko_device", None): + self._netmiko_device.disconnect() + self._netmiko_device = None self.device = None def open(self):