diff --git a/.evergreen/config.yml b/.evergreen/config.yml index 3634c97faa..7dc8a92c45 100644 --- a/.evergreen/config.yml +++ b/.evergreen/config.yml @@ -411,6 +411,11 @@ functions: if [ -n "${SETDEFAULTENCODING}" ]; then export SETDEFAULTENCODING="${SETDEFAULTENCODING}" fi + if [ -n "${test_loadbalancer}" ]; then + export TEST_LOADBALANCER=1 + export SINGLE_MONGOS_LB_URI="${SINGLE_MONGOS_LB_URI}" + export MULTI_MONGOS_LB_URI="${MULTI_MONGOS_LB_URI}" + fi PYTHON_BINARY=${PYTHON_BINARY} \ GREEN_FRAMEWORK=${GREEN_FRAMEWORK} \ @@ -788,6 +793,22 @@ functions: -v \ --fault revoked + "run load-balancer": + - command: shell.exec + params: + script: | + DRIVERS_TOOLS=${DRIVERS_TOOLS} MONGODB_URI=${MONGODB_URI} bash ${DRIVERS_TOOLS}/.evergreen/run-load-balancer.sh start + - command: expansions.update + params: + file: lb-expansion.yml + + "stop load-balancer": + - command: shell.exec + params: + script: | + cd ${DRIVERS_TOOLS}/.evergreen + DRIVERS_TOOLS=${DRIVERS_TOOLS} bash ${DRIVERS_TOOLS}/.evergreen/run-load-balancer.sh stop + "teardown_docker": - command: shell.exec params: @@ -1537,6 +1558,13 @@ tasks: - func: "run aws auth test with aws EC2 credentials" - func: "run aws ECS auth test" + - name: load-balancer-test + commands: + - func: "bootstrap mongo-orchestration" + vars: + TOPOLOGY: "sharded_cluster" + - func: "run load-balancer" + - func: "run tests" # }}} - name: "coverage-report" tags: ["coverage"] @@ -1941,6 +1969,16 @@ axes: variables: ORCHESTRATION_FILE: "versioned-api-testing.json" + # Run load balancer tests? + - id: loadbalancer + display_name: "Load Balancer" + values: + - id: "enabled" + display_name: "Load Balancer" + variables: + test_loadbalancer: true + batchtime: 10080 # 7 days + buildvariants: - matrix_name: "tests-all" matrix_spec: @@ -2463,6 +2501,17 @@ buildvariants: - name: "aws-auth-test-4.4" - name: "aws-auth-test-latest" +- matrix_name: "load-balancer" + matrix_spec: + platform: ubuntu-18.04 + mongodb-version: ["latest"] + auth-ssl: "*" + python-version: ["3.6", "3.9"] + loadbalancer: "*" + display_name: "Load Balancer ${platform} ${python-version} ${mongodb-version} ${auth-ssl}" + tasks: + - name: "load-balancer-test" + - matrix_name: "Release" matrix_spec: platform: [ubuntu-20.04, windows-64-vsMulti-small, macos-1014] diff --git a/.evergreen/run-tests.sh b/.evergreen/run-tests.sh index 7a78401264..9848b91877 100755 --- a/.evergreen/run-tests.sh +++ b/.evergreen/run-tests.sh @@ -51,6 +51,11 @@ fi if [ "$SSL" != "nossl" ]; then export CLIENT_PEM="$DRIVERS_TOOLS/.evergreen/x509gen/client.pem" export CA_PEM="$DRIVERS_TOOLS/.evergreen/x509gen/ca.pem" + + if [ -n "$TEST_LOADBALANCER" ]; then + export SINGLE_MONGOS_LB_URI="${SINGLE_MONGOS_LB_URI}&tls=true" + export MULTI_MONGOS_LB_URI="${MULTI_MONGOS_LB_URI}&tls=true" + fi fi # For createvirtualenv. @@ -191,7 +196,12 @@ if [ -z "$GREEN_FRAMEWORK" ]; then # causing this script to exit. $PYTHON -c "from bson import _cbson; from pymongo import _cmessage" fi - $PYTHON $COVERAGE_ARGS setup.py $C_EXTENSIONS test $TEST_ARGS $OUTPUT + + if [ -n "$TEST_LOADBALANCER" ]; then + $PYTHON -m xmlrunner discover -s test/load_balancer -v --locals -o $XUNIT_DIR + else + $PYTHON $COVERAGE_ARGS setup.py $C_EXTENSIONS test $TEST_ARGS $OUTPUT + fi else # --no_ext has to come before "test" so there is no way to toggle extensions here. $PYTHON green_framework_test.py $GREEN_FRAMEWORK $OUTPUT diff --git a/pymongo/client_session.py b/pymongo/client_session.py index 950fe0dc13..5b6ff7524d 100644 --- a/pymongo/client_session.py +++ b/pymongo/client_session.py @@ -660,7 +660,7 @@ def abort_transaction(self): pass finally: self._transaction.state = _TxnState.ABORTED - self._unpin_mongos() + self._unpin() def _finish_transaction_with_retry(self, command_name): """Run commit or abort with one retry after any retryable error. @@ -779,13 +779,13 @@ def _pinned_address(self): return self._transaction.pinned_address return None - def _pin_mongos(self, server): - """Pin this session to the given mongos Server.""" + def _pin(self, server): + """Pin this session to the given Server.""" self._transaction.sharded = True self._transaction.pinned_address = server.description.address - def _unpin_mongos(self): - """Unpin this session from any pinned mongos address.""" + def _unpin(self): + """Unpin this session from any pinned Server.""" self._transaction.pinned_address = None def _txn_read_preference(self): @@ -906,9 +906,11 @@ def get_server_session(self, session_timeout_minutes): return _ServerSession(self.generation) def return_server_session(self, server_session, session_timeout_minutes): - self._clear_stale(session_timeout_minutes) - if not server_session.timed_out(session_timeout_minutes): - self.return_server_session_no_lock(server_session) + if session_timeout_minutes is not None: + self._clear_stale(session_timeout_minutes) + if server_session.timed_out(session_timeout_minutes): + return + self.return_server_session_no_lock(server_session) def return_server_session_no_lock(self, server_session): # Discard sessions from an old pool to avoid duplicate sessions in the diff --git a/pymongo/mongo_client.py b/pymongo/mongo_client.py index ebd11970a2..728c9a1670 100644 --- a/pymongo/mongo_client.py +++ b/pymongo/mongo_client.py @@ -1197,15 +1197,16 @@ def _select_server(self, server_selector, session, address=None): server = topology.select_server(server_selector) # Pin this session to the selected server if it's performing a # sharded transaction. - if server.description.mongos and (session and - session.in_transaction): - session._pin_mongos(server) + if (server.description.server_type in ( + SERVER_TYPE.Mongos, SERVER_TYPE.LoadBalancer) + and session and session.in_transaction): + session._pin(server) return server except PyMongoError as exc: # Server selection errors in a transaction are transient. if session and session.in_transaction: exc._add_error_label("TransientTransactionError") - session._unpin_mongos() + session._unpin() raise def _socket_for_writes(self, session): @@ -1350,7 +1351,7 @@ def is_retrying(): _add_retryable_write_error(exc, max_wire_version) retryable_error = exc.has_error_label("RetryableWriteError") if retryable_error: - session._unpin_mongos() + session._unpin() if is_retrying() or not retryable_error: raise if bulk: @@ -1965,7 +1966,7 @@ def _add_retryable_write_error(exc, max_wire_version): class _MongoClientErrorHandler(object): """Handle errors raised when executing an operation.""" __slots__ = ('client', 'server_address', 'session', 'max_wire_version', - 'sock_generation', 'completed_handshake') + 'sock_generation', 'completed_handshake', 'service_id') def __init__(self, client, server, session): self.client = client @@ -1978,11 +1979,13 @@ def __init__(self, client, server, session): # of the pool at the time the connection attempt was started." self.sock_generation = server.pool.generation self.completed_handshake = False + self.service_id = None def contribute_socket(self, sock_info): """Provide socket information to the error handler.""" self.max_wire_version = sock_info.max_wire_version self.sock_generation = sock_info.generation + self.service_id = sock_info.service_id self.completed_handshake = True def __enter__(self): @@ -2001,9 +2004,9 @@ def __exit__(self, exc_type, exc_val, exc_tb): if issubclass(exc_type, PyMongoError): if (exc_val.has_error_label("TransientTransactionError") or exc_val.has_error_label("RetryableWriteError")): - self.session._unpin_mongos() + self.session._unpin() err_ctx = _ErrorContext( exc_val, self.max_wire_version, self.sock_generation, - self.completed_handshake) + self.completed_handshake, self.service_id) self.client._topology.handle_error(self.server_address, err_ctx) diff --git a/pymongo/monitoring.py b/pymongo/monitoring.py index 7537765637..b53629d12b 100644 --- a/pymongo/monitoring.py +++ b/pymongo/monitoring.py @@ -512,13 +512,16 @@ def register(listener): class _CommandEvent(object): """Base class for command events.""" - __slots__ = ("__cmd_name", "__rqst_id", "__conn_id", "__op_id") + __slots__ = ("__cmd_name", "__rqst_id", "__conn_id", "__op_id", + "__service_id") - def __init__(self, command_name, request_id, connection_id, operation_id): + def __init__(self, command_name, request_id, connection_id, operation_id, + service_id=None): self.__cmd_name = command_name self.__rqst_id = request_id self.__conn_id = connection_id self.__op_id = operation_id + self.__service_id = service_id @property def command_name(self): @@ -535,6 +538,14 @@ def connection_id(self): """The address (host, port) of the server this command was sent to.""" return self.__conn_id + @property + def service_id(self): + """The service_id this command was sent to, or ``None``. + + .. versionadded:: 3.12 + """ + return self.__service_id + @property def operation_id(self): """An id for this series of events or None.""" @@ -551,15 +562,17 @@ class CommandStartedEvent(_CommandEvent): - `connection_id`: The address (host, port) of the server this command was sent to. - `operation_id`: An optional identifier for a series of related events. + - `service_id`: The service_id this command was sent to, or ``None``. """ __slots__ = ("__cmd", "__db") - def __init__(self, command, database_name, *args): + def __init__(self, command, database_name, *args, service_id=None): if not command: raise ValueError("%r is not a valid command" % (command,)) # Command name must be first key. command_name = next(iter(command)) - super(CommandStartedEvent, self).__init__(command_name, *args) + super(CommandStartedEvent, self).__init__( + command_name, *args, service_id=service_id) if command_name.lower() in _SENSITIVE_COMMANDS: self.__cmd = {} else: @@ -577,9 +590,12 @@ def database_name(self): return self.__db def __repr__(self): - return "<%s %s db: %r, command: %r, operation_id: %s>" % ( - self.__class__.__name__, self.connection_id, self.database_name, - self.command_name, self.operation_id) + return ( + "<%s %s db: %r, command: %r, operation_id: %s, " + "service_id: %s>") % ( + self.__class__.__name__, self.connection_id, + self.database_name, self.command_name, self.operation_id, + self.service_id) class CommandSucceededEvent(_CommandEvent): @@ -593,13 +609,15 @@ class CommandSucceededEvent(_CommandEvent): - `connection_id`: The address (host, port) of the server this command was sent to. - `operation_id`: An optional identifier for a series of related events. + - `service_id`: The service_id this command was sent to, or ``None``. """ __slots__ = ("__duration_micros", "__reply") def __init__(self, duration, reply, command_name, - request_id, connection_id, operation_id): + request_id, connection_id, operation_id, service_id=None): super(CommandSucceededEvent, self).__init__( - command_name, request_id, connection_id, operation_id) + command_name, request_id, connection_id, operation_id, + service_id=service_id) self.__duration_micros = _to_micros(duration) if command_name.lower() in _SENSITIVE_COMMANDS: self.__reply = {} @@ -617,9 +635,12 @@ def reply(self): return self.__reply def __repr__(self): - return "<%s %s command: %r, operation_id: %s, duration_micros: %s>" % ( - self.__class__.__name__, self.connection_id, - self.command_name, self.operation_id, self.duration_micros) + return ( + "<%s %s command: %r, operation_id: %s, duration_micros: %s, " + "service_id: %s>") % ( + self.__class__.__name__, self.connection_id, + self.command_name, self.operation_id, self.duration_micros, + self.service_id) class CommandFailedEvent(_CommandEvent): @@ -633,11 +654,12 @@ class CommandFailedEvent(_CommandEvent): - `connection_id`: The address (host, port) of the server this command was sent to. - `operation_id`: An optional identifier for a series of related events. + - `service_id`: The service_id this command was sent to, or ``None``. """ __slots__ = ("__duration_micros", "__failure") - def __init__(self, duration, failure, *args): - super(CommandFailedEvent, self).__init__(*args) + def __init__(self, duration, failure, *args, service_id=None): + super(CommandFailedEvent, self).__init__(*args, service_id=service_id) self.__duration_micros = _to_micros(duration) self.__failure = failure @@ -654,9 +676,10 @@ def failure(self): def __repr__(self): return ( "<%s %s command: %r, operation_id: %s, duration_micros: %s, " - "failure: %r>" % ( + "failure: %r, service_id: %s>") % ( self.__class__.__name__, self.connection_id, self.command_name, - self.operation_id, self.duration_micros, self.failure)) + self.operation_id, self.duration_micros, self.failure, + self.service_id) class _PoolEvent(object): @@ -721,10 +744,29 @@ class PoolClearedEvent(_PoolEvent): :Parameters: - `address`: The address (host, port) pair of the server this Pool is attempting to connect to. + - `service_id`: The service_id this command was sent to, or ``None``. .. versionadded:: 3.9 """ - __slots__ = () + __slots__ = ("__service_id",) + + def __init__(self, address, service_id=None): + super(PoolClearedEvent, self).__init__(address) + self.__service_id = service_id + + @property + def service_id(self): + """Connections with this service_id are cleared. + + When service_id is ``None``, all connections in the pool are cleared. + + .. versionadded:: 3.12 + """ + return self.__service_id + + def __repr__(self): + return '%s(%r, %r)' % ( + self.__class__.__name__, self.address, self.__service_id) class PoolClosedEvent(_PoolEvent): @@ -1508,10 +1550,10 @@ def publish_pool_ready(self, address): except Exception: _handle_exception() - def publish_pool_cleared(self, address): + def publish_pool_cleared(self, address, service_id): """Publish a :class:`PoolClearedEvent` to all pool listeners. """ - event = PoolClearedEvent(address) + event = PoolClearedEvent(address, service_id) for subscriber in self.__cmap_listeners: try: subscriber.pool_cleared(event) diff --git a/pymongo/pool.py b/pymongo/pool.py index 728fec0f60..23ccdcab67 100644 --- a/pymongo/pool.py +++ b/pymongo/pool.py @@ -1129,7 +1129,7 @@ def ready(self): def closed(self): return self.state == PoolState.CLOSED - def _reset(self, close, pause=True): + def _reset(self, close, pause=True, service_id=None): old_state = self.state with self.size_cond: if self.closed: @@ -1161,7 +1161,8 @@ def _reset(self, close, pause=True): listeners.publish_pool_closed(self.address) else: if old_state != PoolState.PAUSED and self.enabled_for_cmap: - listeners.publish_pool_cleared(self.address) + listeners.publish_pool_cleared(self.address, + service_id=service_id) for sock_info in sockets: sock_info.close_socket(ConnectionClosedReason.STALE) @@ -1174,8 +1175,8 @@ def update_is_writable(self, is_writable): for socket in self.sockets: socket.update_is_writable(self.is_writable) - def reset(self): - self._reset(close=False) + def reset(self, service_id=None): + self._reset(close=False, service_id=service_id) def reset_without_pause(self): self._reset(close=False, pause=False) diff --git a/pymongo/server.py b/pymongo/server.py index fbfddae2e2..e9e29f49ea 100644 --- a/pymongo/server.py +++ b/pymongo/server.py @@ -49,9 +49,9 @@ def open(self): if not self._pool.opts.load_balanced: self._monitor.open() - def reset(self): + def reset(self, service_id=None): """Clear the connection pool.""" - self.pool.reset() + self.pool.reset(service_id) def close(self): """Clear the connection pool and stop the monitor. diff --git a/pymongo/server_description.py b/pymongo/server_description.py index 5dc8222fef..19cc349c78 100644 --- a/pymongo/server_description.py +++ b/pymongo/server_description.py @@ -204,10 +204,10 @@ def is_server_type_known(self): @property def retryable_writes_supported(self): """Checks if this server supports retryable writes.""" - return ( + return (( self._ls_timeout_minutes is not None and - self._server_type in (SERVER_TYPE.Mongos, SERVER_TYPE.RSPrimary, - SERVER_TYPE.LoadBalancer)) + self._server_type in (SERVER_TYPE.Mongos, SERVER_TYPE.RSPrimary)) + or self._server_type == SERVER_TYPE.LoadBalancer) @property def retryable_reads_supported(self): diff --git a/pymongo/topology.py b/pymongo/topology.py index 446bb9353d..18d5c4c8f4 100644 --- a/pymongo/topology.py +++ b/pymongo/topology.py @@ -453,7 +453,7 @@ def update_pool(self, all_credentials): try: server.pool.remove_stale_sockets(generation, all_credentials) except PyMongoError as exc: - ctx = _ErrorContext(exc, 0, generation, False) + ctx = _ErrorContext(exc, 0, generation, False, None) self.handle_error(server.description.address, ctx) raise @@ -528,11 +528,9 @@ def get_server_session(self): def return_server_session(self, server_session, lock): if lock: with self._lock: - session_timeout = \ - self._description.logical_session_timeout_minutes - if session_timeout is not None: - self._session_pool.return_server_session(server_session, - session_timeout) + self._session_pool.return_server_session( + server_session, + self._description.logical_session_timeout_minutes) else: # Called from a __del__ method, can't use a lock. self._session_pool.return_server_session_no_lock(server_session) @@ -566,7 +564,8 @@ def _ensure_opened(self): # Emit initial SDAM events for load balancer mode. self._process_change(ServerDescription( self._seed_addresses[0], - IsMaster({'ok': 1, 'serviceId': self._topology_id}))) + IsMaster({'ok': 1, 'serviceId': self._topology_id, + 'maxWireVersion': 13}))) # Ensure that the monitors are open. for server in self._servers.values(): @@ -599,6 +598,7 @@ def _handle_error(self, address, err_ctx): server = self._servers[address] error = err_ctx.error exc_type = type(error) + service_id = err_ctx.service_id if (issubclass(exc_type, NetworkTimeout) and err_ctx.completed_handshake): # The socket has been closed. Don't reset the server. @@ -629,21 +629,21 @@ def _handle_error(self, address, err_ctx): self._process_change(ServerDescription(address, error=error)) if is_shutting_down or (err_ctx.max_wire_version <= 7): # Clear the pool. - server.reset() + server.reset(service_id) server.request_check() elif not err_ctx.completed_handshake: # Unknown command error during the connection handshake. if not self._settings.load_balanced: self._process_change(ServerDescription(address, error=error)) # Clear the pool. - server.reset() + server.reset(service_id) elif issubclass(exc_type, ConnectionFailure): # "Client MUST replace the server's description with type Unknown # ... MUST NOT request an immediate check of the server." if not self._settings.load_balanced: self._process_change(ServerDescription(address, error=error)) # Clear the pool. - server.reset() + server.reset(service_id) # "When a client marks a server Unknown from `Network error when # reading or writing`_, clients MUST cancel the isMaster check on # that server and close the current monitoring connection." @@ -795,11 +795,12 @@ def __repr__(self): class _ErrorContext(object): """An error with context for SDAM error handling.""" def __init__(self, error, max_wire_version, sock_generation, - completed_handshake): + completed_handshake, service_id): self.error = error self.max_wire_version = max_wire_version self.sock_generation = sock_generation self.completed_handshake = completed_handshake + self.service_id = service_id def _is_stale_error_topology_version(current_tv, error_tv): diff --git a/test/__init__.py b/test/__init__.py index 83dac398e4..9e76b28f50 100644 --- a/test/__init__.py +++ b/test/__init__.py @@ -50,6 +50,7 @@ from pymongo.common import partition_node from pymongo.server_api import ServerApi from pymongo.ssl_support import HAVE_SSL, validate_cert_reqs +from pymongo.uri_parser import parse_uri from test.version import Version if HAVE_SSL: @@ -92,6 +93,14 @@ COMPRESSORS = os.environ.get("COMPRESSORS") MONGODB_API_VERSION = os.environ.get("MONGODB_API_VERSION") +TEST_LOADBALANCER = bool(os.environ.get("TEST_LOADBALANCER")) +SINGLE_MONGOS_LB_URI = os.environ.get("SINGLE_MONGOS_LB_URI") +MULTI_MONGOS_LB_URI = os.environ.get("MULTI_MONGOS_LB_URI") +if TEST_LOADBALANCER: + res = parse_uri(SINGLE_MONGOS_LB_URI) + host, port = res['nodelist'][0] + db_user = res['username'] or db_user + db_pwd = res['password'] or db_pwd def is_server_resolvable(): @@ -190,6 +199,7 @@ def _all_users(db): class ClientContext(object): + MULTI_MONGOS_LB_URI = MULTI_MONGOS_LB_URI def __init__(self): """Create a client and grab essential information from the server.""" @@ -216,7 +226,9 @@ def __init__(self): self.client = None self.conn_lock = threading.Lock() self.is_data_lake = False - self.load_balancer = False + self.load_balancer = TEST_LOADBALANCER + if self.load_balancer: + self.default_client_options["loadBalanced"] = True if COMPRESSORS: self.default_client_options["compressors"] = COMPRESSORS if MONGODB_API_VERSION: @@ -632,8 +644,10 @@ def check_auth_with_sharding(self, func): func=func) def is_topology_type(self, topologies): - if 'load-balanced' in topologies and self.load_balancer: - return True + if self.load_balancer: + if 'load-balanced' in topologies: + return True + return False if 'single' in topologies and not (self.is_mongos or self.is_rs): return True if 'replicaset' in topologies and self.is_rs: diff --git a/test/load_balancer/test_crud_unified.py b/test/load_balancer/test_crud_unified.py new file mode 100644 index 0000000000..dfe0935bba --- /dev/null +++ b/test/load_balancer/test_crud_unified.py @@ -0,0 +1,23 @@ +# Copyright 2021-present MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import sys +import unittest + +sys.path[0:0] = [""] + +from test.test_crud_unified import * + +if __name__ == '__main__': + unittest.main() diff --git a/test/load_balancer/test_dns.py b/test/load_balancer/test_dns.py new file mode 100644 index 0000000000..047b98b121 --- /dev/null +++ b/test/load_balancer/test_dns.py @@ -0,0 +1,23 @@ +# Copyright 2021-present MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import sys +import unittest + +sys.path[0:0] = [""] + +from test.test_dns import * + +if __name__ == '__main__': + unittest.main() diff --git a/test/load_balancer/test_load_balancer.py b/test/load_balancer/test_load_balancer.py new file mode 100644 index 0000000000..c31ff58ef1 --- /dev/null +++ b/test/load_balancer/test_load_balancer.py @@ -0,0 +1,34 @@ +# Copyright 2021-present MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Test the Load Balancer unified spec tests.""" + +import os +import sys + +sys.path[0:0] = [""] + +from test import unittest + +from test.unified_format import generate_test_classes + +# Location of JSON test specifications. +TEST_PATH = os.path.join( + os.path.dirname(os.path.realpath(__file__)), 'unified') + +# Generate unified tests. +globals().update(generate_test_classes(TEST_PATH, module=__name__)) + +if __name__ == "__main__": + unittest.main() diff --git a/test/load_balancer/test_retryable_change_stream.py b/test/load_balancer/test_retryable_change_stream.py new file mode 100644 index 0000000000..b7c902dd30 --- /dev/null +++ b/test/load_balancer/test_retryable_change_stream.py @@ -0,0 +1,23 @@ +# Copyright 2021-present MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import sys +import unittest + +sys.path[0:0] = [""] + +from test.test_change_stream import * + +if __name__ == '__main__': + unittest.main() diff --git a/test/load_balancer/test_retryable_reads.py b/test/load_balancer/test_retryable_reads.py new file mode 100644 index 0000000000..c5de3c9078 --- /dev/null +++ b/test/load_balancer/test_retryable_reads.py @@ -0,0 +1,23 @@ +# Copyright 2021-present MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import sys +import unittest + +sys.path[0:0] = [""] + +from test.test_retryable_reads import * + +if __name__ == '__main__': + unittest.main() diff --git a/test/load_balancer/test_retryable_writes.py b/test/load_balancer/test_retryable_writes.py new file mode 100644 index 0000000000..3800641b08 --- /dev/null +++ b/test/load_balancer/test_retryable_writes.py @@ -0,0 +1,23 @@ +# Copyright 2021-present MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import sys +import unittest + +sys.path[0:0] = [""] + +from test.test_retryable_writes import * + +if __name__ == '__main__': + unittest.main() diff --git a/test/load_balancer/test_transactions_unified.py b/test/load_balancer/test_transactions_unified.py new file mode 100644 index 0000000000..2572028046 --- /dev/null +++ b/test/load_balancer/test_transactions_unified.py @@ -0,0 +1,23 @@ +# Copyright 2021-present MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import sys +import unittest + +sys.path[0:0] = [""] + +from test.test_transactions_unified import * + +if __name__ == '__main__': + unittest.main() diff --git a/test/load_balancer/test_uri_options.py b/test/load_balancer/test_uri_options.py new file mode 100644 index 0000000000..b644d7d334 --- /dev/null +++ b/test/load_balancer/test_uri_options.py @@ -0,0 +1,23 @@ +# Copyright 2021-present MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import sys +import unittest + +sys.path[0:0] = [""] + +from test.test_uri_spec import * + +if __name__ == '__main__': + unittest.main() diff --git a/test/load_balancer/test_versioned_api.py b/test/load_balancer/test_versioned_api.py new file mode 100644 index 0000000000..7e801968cb --- /dev/null +++ b/test/load_balancer/test_versioned_api.py @@ -0,0 +1,23 @@ +# Copyright 2021-present MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import sys +import unittest + +sys.path[0:0] = [""] + +from test.test_versioned_api import * + +if __name__ == '__main__': + unittest.main() diff --git a/test/load_balancer/unified/event-monitoring.json b/test/load_balancer/unified/event-monitoring.json new file mode 100644 index 0000000000..938c70bf38 --- /dev/null +++ b/test/load_balancer/unified/event-monitoring.json @@ -0,0 +1,184 @@ +{ + "description": "monitoring events include correct fields", + "schemaVersion": "1.3", + "runOnRequirements": [ + { + "topologies": [ + "load-balanced" + ] + } + ], + "createEntities": [ + { + "client": { + "id": "client0", + "useMultipleMongoses": true, + "uriOptions": { + "retryReads": false + }, + "observeEvents": [ + "commandStartedEvent", + "commandSucceededEvent", + "commandFailedEvent", + "poolClearedEvent" + ] + } + }, + { + "database": { + "id": "database0", + "client": "client0", + "databaseName": "database0" + } + }, + { + "collection": { + "id": "collection0", + "database": "database0", + "collectionName": "coll0" + } + } + ], + "initialData": [ + { + "databaseName": "database0", + "collectionName": "coll0", + "documents": [] + } + ], + "tests": [ + { + "description": "command started and succeeded events include serviceId", + "operations": [ + { + "name": "insertOne", + "object": "collection0", + "arguments": { + "document": { + "x": 1 + } + } + } + ], + "expectEvents": [ + { + "client": "client0", + "events": [ + { + "commandStartedEvent": { + "commandName": "insert", + "hasServiceId": true + } + }, + { + "commandSucceededEvent": { + "commandName": "insert", + "hasServiceId": true + } + } + ] + } + ] + }, + { + "description": "command failed events include serviceId", + "operations": [ + { + "name": "find", + "object": "collection0", + "arguments": { + "filter": { + "$or": true + } + }, + "expectError": { + "isError": true + } + } + ], + "expectEvents": [ + { + "client": "client0", + "events": [ + { + "commandStartedEvent": { + "commandName": "find", + "hasServiceId": true + } + }, + { + "commandFailedEvent": { + "commandName": "find", + "hasServiceId": true + } + } + ] + } + ] + }, + { + "description": "poolClearedEvent events include serviceId", + "operations": [ + { + "name": "failPoint", + "object": "testRunner", + "arguments": { + "client": "client0", + "failPoint": { + "configureFailPoint": "failCommand", + "mode": { + "times": 1 + }, + "data": { + "failCommands": [ + "find" + ], + "closeConnection": true + } + } + } + }, + { + "name": "find", + "object": "collection0", + "arguments": { + "filter": {} + }, + "expectError": { + "isClientError": true + } + } + ], + "expectEvents": [ + { + "client": "client0", + "events": [ + { + "commandStartedEvent": { + "commandName": "find", + "hasServiceId": true + } + }, + { + "commandFailedEvent": { + "commandName": "find", + "hasServiceId": true + } + } + ] + }, + { + "client": "client0", + "eventType": "cmap", + "events": [ + { + "poolClearedEvent": { + "hasServiceId": true + } + } + ] + } + ] + } + ] +} diff --git a/test/load_balancer/unified/lb-connection-establishment.json b/test/load_balancer/unified/lb-connection-establishment.json new file mode 100644 index 0000000000..0eaadf30c2 --- /dev/null +++ b/test/load_balancer/unified/lb-connection-establishment.json @@ -0,0 +1,58 @@ +{ + "description": "connection establishment for load-balanced clusters", + "schemaVersion": "1.3", + "runOnRequirements": [ + { + "topologies": [ + "load-balanced" + ] + } + ], + "createEntities": [ + { + "client": { + "id": "client0", + "uriOptions": { + "loadBalanced": false + }, + "observeEvents": [ + "commandStartedEvent" + ] + } + }, + { + "database": { + "id": "database0", + "client": "client0", + "databaseName": "database0" + } + } + ], + "tests": [ + { + "description": "operations against load balancers fail if URI contains loadBalanced=false", + "skipReason": "servers have not implemented LB support yet so they will not fail the connection handshake in this case", + "operations": [ + { + "name": "runCommand", + "object": "database0", + "arguments": { + "commandName": "ping", + "command": { + "ping": 1 + } + }, + "expectError": { + "isClientError": false + } + } + ], + "expectEvents": [ + { + "client": "client0", + "events": [] + } + ] + } + ] +} diff --git a/test/load_balancer/unified/non-lb-connection-establishment.json b/test/load_balancer/unified/non-lb-connection-establishment.json new file mode 100644 index 0000000000..6aaa7bdf98 --- /dev/null +++ b/test/load_balancer/unified/non-lb-connection-establishment.json @@ -0,0 +1,92 @@ +{ + "description": "connection establishment if loadBalanced is specified for non-load balanced clusters", + "schemaVersion": "1.3", + "runOnRequirements": [ + { + "topologies": [ + "single", + "sharded" + ] + } + ], + "createEntities": [ + { + "client": { + "id": "lbTrueClient", + "useMultipleMongoses": false, + "uriOptions": { + "loadBalanced": true + } + } + }, + { + "database": { + "id": "lbTrueDatabase", + "client": "lbTrueClient", + "databaseName": "lbTrueDb" + } + }, + { + "client": { + "id": "lbFalseClient", + "uriOptions": { + "loadBalanced": false + } + } + }, + { + "database": { + "id": "lbFalseDatabase", + "client": "lbFalseClient", + "databaseName": "lbFalseDb" + } + } + ], + "_yamlAnchors": { + "runCommandArguments": [ + { + "arguments": { + "commandName": "ping", + "command": { + "ping": 1 + } + } + } + ] + }, + "tests": [ + { + "description": "operations against non-load balanced clusters fail if URI contains loadBalanced=true", + "operations": [ + { + "name": "runCommand", + "object": "lbTrueDatabase", + "arguments": { + "commandName": "ping", + "command": { + "ping": 1 + } + }, + "expectError": { + "errorContains": "Driver attempted to initialize in load balancing mode, but the server does not support this mode" + } + } + ] + }, + { + "description": "operations against non-load balanced clusters succeed if URI contains loadBalanced=false", + "operations": [ + { + "name": "runCommand", + "object": "lbFalseDatabase", + "arguments": { + "commandName": "ping", + "command": { + "ping": 1 + } + } + } + ] + } + ] +} diff --git a/test/load_balancer/unified/server-selection.json b/test/load_balancer/unified/server-selection.json new file mode 100644 index 0000000000..00c7e4c95b --- /dev/null +++ b/test/load_balancer/unified/server-selection.json @@ -0,0 +1,82 @@ +{ + "description": "server selection for load-balanced clusters", + "schemaVersion": "1.3", + "runOnRequirements": [ + { + "topologies": [ + "load-balanced" + ] + } + ], + "createEntities": [ + { + "client": { + "id": "client0", + "useMultipleMongoses": true, + "observeEvents": [ + "commandStartedEvent" + ] + } + }, + { + "database": { + "id": "database0", + "client": "client0", + "databaseName": "database0Name" + } + }, + { + "collection": { + "id": "collection0", + "database": "database0", + "collectionName": "coll0", + "collectionOptions": { + "readPreference": { + "mode": "secondaryPreferred" + } + } + } + } + ], + "initialData": [ + { + "collectionName": "coll0", + "databaseName": "database0Name", + "documents": [] + } + ], + "tests": [ + { + "description": "$readPreference is sent for load-balanced clusters", + "operations": [ + { + "name": "find", + "object": "collection0", + "arguments": { + "filter": {} + } + } + ], + "expectEvents": [ + { + "client": "client0", + "events": [ + { + "commandStartedEvent": { + "command": { + "find": "coll0", + "filter": {}, + "$readPreference": { + "mode": "secondaryPreferred" + } + }, + "commandName": "find", + "databaseName": "database0Name" + } + } + ] + } + ] + } + ] +} diff --git a/test/test_client.py b/test/test_client.py index 3754cb0ac3..5d57c32c1c 100644 --- a/test/test_client.py +++ b/test/test_client.py @@ -1483,7 +1483,7 @@ def stop(self): def run(self): while self.running: exc = AutoReconnect('mock pool error') - ctx = _ErrorContext(exc, 0, pool.generation, False) + ctx = _ErrorContext(exc, 0, pool.generation, False, None) client._topology.handle_error(pool.address, ctx) time.sleep(0.001) diff --git a/test/test_cmap.py b/test/test_cmap.py index 7a9ab51804..053f27ba73 100644 --- a/test/test_cmap.py +++ b/test/test_cmap.py @@ -22,6 +22,7 @@ sys.path[0:0] = [""] from bson.son import SON +from bson.objectid import ObjectId from pymongo.errors import (ConnectionFailure, OperationFailure, @@ -422,6 +423,7 @@ def test_events_repr(self): self.assertRepr(ConnectionCheckOutStartedEvent(host)) self.assertRepr(PoolCreatedEvent(host, {})) self.assertRepr(PoolClearedEvent(host)) + self.assertRepr(PoolClearedEvent(host, service_id=ObjectId())) self.assertRepr(PoolClosedEvent(host)) def test_close_leaves_pool_unpaused(self): diff --git a/test/test_discovery_and_monitoring.py b/test/test_discovery_and_monitoring.py index ddd017bb9a..7fa96b5e18 100644 --- a/test/test_discovery_and_monitoring.py +++ b/test/test_discovery_and_monitoring.py @@ -113,7 +113,7 @@ def got_app_error(topology, app_error): topology.handle_error( server_address, _ErrorContext(e, max_wire_version, generation, - completed_handshake)) + completed_handshake, None)) def get_type(topology, hostname): diff --git a/test/test_monitoring.py b/test/test_monitoring.py index 46cfe87c4a..31e8282828 100644 --- a/test/test_monitoring.py +++ b/test/test_monitoring.py @@ -1183,7 +1183,7 @@ def test_command_event_repr(self): self.assertEqual( repr(event), "") + "command: 'isMaster', operation_id: 2, service_id: None>") delta = datetime.timedelta(milliseconds=100) event = monitoring.CommandSucceededEvent( delta, {'ok': 1}, 'isMaster', request_id, connection_id, @@ -1191,7 +1191,8 @@ def test_command_event_repr(self): self.assertEqual( repr(event), "") + "command: 'isMaster', operation_id: 2, duration_micros: 100000, " + "service_id: None>") event = monitoring.CommandFailedEvent( delta, {'ok': 0}, 'isMaster', request_id, connection_id, operation_id) @@ -1199,7 +1200,7 @@ def test_command_event_repr(self): repr(event), "") + "failure: {'ok': 0}, service_id: None>") def test_server_heartbeat_event_repr(self): connection_id = ('localhost', 27017) diff --git a/test/test_topology.py b/test/test_topology.py index 2abcab47b1..5e2f683f70 100644 --- a/test/test_topology.py +++ b/test/test_topology.py @@ -401,7 +401,7 @@ def test_handle_error(self): 'setName': 'rs', 'hosts': ['a', 'b']}) - errctx = _ErrorContext(AutoReconnect('mock'), 0, 0, True) + errctx = _ErrorContext(AutoReconnect('mock'), 0, 0, True, None) t.handle_error(('a', 27017), errctx) self.assertEqual(SERVER_TYPE.Unknown, get_type(t, 'a')) self.assertEqual(SERVER_TYPE.RSSecondary, get_type(t, 'b')) @@ -430,7 +430,7 @@ def test_handle_error_removed_server(self): t = create_mock_topology(replica_set_name='rs') # No error resetting a server not in the TopologyDescription. - errctx = _ErrorContext(AutoReconnect('mock'), 0, 0, True) + errctx = _ErrorContext(AutoReconnect('mock'), 0, 0, True, None) t.handle_error(('b', 27017), errctx) # Server was *not* added as type Unknown. diff --git a/test/unified_format.py b/test/unified_format.py index ad1e73a519..7fa7f55136 100644 --- a/test/unified_format.py +++ b/test/unified_format.py @@ -39,10 +39,16 @@ from pymongo.change_stream import ChangeStream from pymongo.collection import Collection from pymongo.database import Database -from pymongo.errors import BulkWriteError, InvalidOperation, PyMongoError +from pymongo.errors import ( + BulkWriteError, ConnectionFailure, InvalidOperation, NotMasterError, + PyMongoError) from pymongo.monitoring import ( CommandFailedEvent, CommandListener, CommandStartedEvent, - CommandSucceededEvent, _SENSITIVE_COMMANDS) + CommandSucceededEvent, _SENSITIVE_COMMANDS, PoolCreatedEvent, + PoolReadyEvent, PoolClearedEvent, PoolClosedEvent, ConnectionCreatedEvent, + ConnectionReadyEvent, ConnectionClosedEvent, + ConnectionCheckOutStartedEvent, ConnectionCheckOutFailedEvent, + ConnectionCheckedOutEvent, ConnectionCheckedInEvent) from pymongo.read_concern import ReadConcern from pymongo.read_preferences import ReadPreference from pymongo.results import BulkWriteResult @@ -51,7 +57,8 @@ from test import client_context, unittest, IntegrationTest from test.utils import ( - camel_to_snake, rs_or_single_client, single_client, snake_to_camel) + camel_to_snake, get_pool, rs_or_single_client, single_client, + snake_to_camel, CMAPListener) from test.version import Version from test.utils import ( @@ -142,28 +149,52 @@ def parse_bulk_write_error_result(error): return parse_bulk_write_result(write_result) -class EventListenerUtil(CommandListener): +class NonLazyCursor(object): + """A find cursor proxy that creates the remote cursor when initialized.""" + def __init__(self, find_cursor): + self.find_cursor = find_cursor + # Create the server side cursor. + self.first_result = next(find_cursor, None) + + def __next__(self): + if self.first_result is not None: + first = self.first_result + self.first_result = None + return first + return next(self.find_cursor) + + def close(self): + self.find_cursor.close() + + +class EventListenerUtil(CMAPListener, CommandListener): def __init__(self, observe_events, ignore_commands): - self._event_types = set(observe_events) + self._event_types = set(name.lower() for name in observe_events) self._ignore_commands = _SENSITIVE_COMMANDS | set(ignore_commands) self._ignore_commands.add('configurefailpoint') - self.results = [] + super(EventListenerUtil, self).__init__() + + def get_events(self, event_type): + if event_type == 'command': + return [e for e in self.events if 'Command' in type(e).__name__] + return [e for e in self.events if 'Command' not in type(e).__name__] + + def add_event(self, event): + if type(event).__name__.lower() in self._event_types: + super(EventListenerUtil, self).add_event(event) - def _observe_event(self, event): + def _command_event(self, event): if event.command_name.lower() not in self._ignore_commands: - self.results.append(event) + self.add_event(event) def started(self, event): - if 'commandStartedEvent' in self._event_types: - self._observe_event(event) + self._command_event(event) def succeeded(self, event): - if 'commandSucceededEvent' in self._event_types: - self._observe_event(event) + self._command_event(event) def failed(self, event): - if 'commandFailedEvent' in self._event_types: - self._observe_event(event) + self._command_event(event) class EntityMapUtil(object): @@ -173,28 +204,28 @@ def __init__(self, test_class): self._entities = {} self._listeners = {} self._session_lsids = {} - self._test_class = test_class + self.test = test_class def __getitem__(self, item): try: return self._entities[item] except KeyError: - self._test_class.fail('Could not find entity named %s in map' % ( + self.test.fail('Could not find entity named %s in map' % ( item,)) def __setitem__(self, key, value): if not isinstance(key, str): - self._test_class.fail( + self.test.fail( 'Expected entity name of type str, got %s' % (type(key))) if key in self._entities: - self._test_class.fail('Entity named %s already in map' % (key,)) + self.test.fail('Entity named %s already in map' % (key,)) self._entities[key] = value def _create_entity(self, entity_spec): if len(entity_spec) != 1: - self._test_class.fail( + self.test.fail( "Entity spec %s did not contain exactly one top-level key" % ( entity_spec,)) @@ -203,13 +234,17 @@ def _create_entity(self, entity_spec): kwargs = {} observe_events = spec.get('observeEvents', []) ignore_commands = spec.get('ignoreCommandMonitoringEvents', []) + # TODO: SUPPORT storeEventsAsEntities if len(observe_events) or len(ignore_commands): ignore_commands = [cmd.lower() for cmd in ignore_commands] listener = EventListenerUtil(observe_events, ignore_commands) self._listeners[spec['id']] = listener kwargs['event_listeners'] = [listener] - if client_context.is_mongos and spec.get('useMultipleMongoses'): - kwargs['h'] = client_context.mongos_seeds() + if spec.get('useMultipleMongoses'): + if client_context.load_balancer: + kwargs['h'] = client_context.MULTI_MONGOS_LB_URI + elif client_context.is_mongos: + kwargs['h'] = client_context.mongos_seeds() kwargs.update(spec.get('uriOptions', {})) server_api = spec.get('serverApi') if server_api: @@ -218,12 +253,12 @@ def _create_entity(self, entity_spec): deprecation_errors=server_api.get('deprecationErrors')) client = rs_or_single_client(**kwargs) self[spec['id']] = client - self._test_class.addCleanup(client.close) + self.test.addCleanup(client.close) return elif entity_type == 'database': client = self[spec['client']] if not isinstance(client, MongoClient): - self._test_class.fail( + self.test.fail( 'Expected entity %s to be of type MongoClient, got %s' % ( spec['client'], type(client))) options = parse_collection_or_database_options( @@ -234,7 +269,7 @@ def _create_entity(self, entity_spec): elif entity_type == 'collection': database = self[spec['database']] if not isinstance(database, Database): - self._test_class.fail( + self.test.fail( 'Expected entity %s to be of type Database, got %s' % ( spec['database'], type(database))) options = parse_collection_or_database_options( @@ -245,7 +280,7 @@ def _create_entity(self, entity_spec): elif entity_type == 'session': client = self[spec['client']] if not isinstance(client, MongoClient): - self._test_class.fail( + self.test.fail( 'Expected entity %s to be of type MongoClient, got %s' % ( spec['client'], type(client))) opts = camel_to_snake_args(spec.get('sessionOptions', {})) @@ -258,13 +293,13 @@ def _create_entity(self, entity_spec): session = client.start_session(**dict(opts)) self[spec['id']] = session self._session_lsids[spec['id']] = copy.deepcopy(session.session_id) - self._test_class.addCleanup(session.end_session) + self.test.addCleanup(session.end_session) return elif entity_type == 'bucket': # TODO: implement the 'bucket' entity type - self._test_class.skipTest( + self.test.skipTest( 'GridFS is not currently supported (PYTHON-2459)') - self._test_class.fail( + self.test.fail( 'Unable to create entity of unknown type %s' % (entity_type,)) def create_entities_from_spec(self, entity_spec): @@ -274,13 +309,13 @@ def create_entities_from_spec(self, entity_spec): def get_listener_for_client(self, client_name): client = self[client_name] if not isinstance(client, MongoClient): - self._test_class.fail( + self.test.fail( 'Expected entity %s to be of type MongoClient, got %s' % ( client_name, type(client))) listener = self._listeners.get(client_name) if not listener: - self._test_class.fail( + self.test.fail( 'No listeners configured for client %s' % (client_name,)) return listener @@ -288,7 +323,7 @@ def get_listener_for_client(self, client_name): def get_lsid_for_session(self, session_name): session = self[session_name] if not isinstance(session, ClientSession): - self._test_class.fail( + self.test.fail( 'Expected entity %s to be of type ClientSession, got %s' % ( session_name, type(session))) @@ -334,21 +369,21 @@ class MatchEvaluatorUtil(object): """Utility class that implements methods for evaluating matches as per the unified test format specification.""" def __init__(self, test_class): - self._test_class = test_class + self.test = test_class def _operation_exists(self, spec, actual, key_to_compare): if spec is True: - self._test_class.assertIn(key_to_compare, actual) + self.test.assertIn(key_to_compare, actual) elif spec is False: - self._test_class.assertNotIn(key_to_compare, actual) + self.test.assertNotIn(key_to_compare, actual) else: - self._test_class.fail( + self.test.fail( 'Expected boolean value for $$exists operator, got %s' % ( spec,)) def __type_alias_to_type(self, alias): if alias not in BSON_TYPE_ALIAS_MAP: - self._test_class.fail('Unrecognized BSON type alias %s' % (alias,)) + self.test.fail('Unrecognized BSON type alias %s' % (alias,)) return BSON_TYPE_ALIAS_MAP[alias] def _operation_type(self, spec, actual, key_to_compare): @@ -357,13 +392,13 @@ def _operation_type(self, spec, actual, key_to_compare): t for alias in spec for t in self.__type_alias_to_type(alias)]) else: permissible_types = self.__type_alias_to_type(spec) - self._test_class.assertIsInstance( + self.test.assertIsInstance( actual[key_to_compare], permissible_types) def _operation_matchesEntity(self, spec, actual, key_to_compare): - expected_entity = self._test_class.entity_map[spec] - self._test_class.assertIsInstance(expected_entity, abc.Mapping) - self._test_class.assertEqual(expected_entity, actual[key_to_compare]) + expected_entity = self.test.entity_map[spec] + self.test.assertIsInstance(expected_entity, abc.Mapping) + self.test.assertEqual(expected_entity, actual[key_to_compare]) def _operation_matchesHexBytes(self, spec, actual, key_to_compare): raise NotImplementedError @@ -380,8 +415,8 @@ def _operation_unsetOrMatches(self, spec, actual, key_to_compare): self.match_result(spec, actual[key_to_compare], in_recursive_call=True) def _operation_sessionLsid(self, spec, actual, key_to_compare): - expected_lsid = self._test_class.entity_map.get_lsid_for_session(spec) - self._test_class.assertEqual(expected_lsid, actual[key_to_compare]) + expected_lsid = self.test.entity_map.get_lsid_for_session(spec) + self.test.assertEqual(expected_lsid, actual[key_to_compare]) def _evaluate_special_operation(self, opname, spec, actual, key_to_compare): @@ -389,7 +424,7 @@ def _evaluate_special_operation(self, opname, spec, actual, try: method = getattr(self, method_name) except AttributeError: - self._test_class.fail( + self.test.fail( 'Unsupported special matching operator %s' % (opname,)) else: method(spec, actual, key_to_compare) @@ -440,16 +475,16 @@ def _match_document(self, expectation, actual, is_root): if self._evaluate_if_special_operation(expectation, actual): return - self._test_class.assertIsInstance(actual, abc.Mapping) + self.test.assertIsInstance(actual, abc.Mapping) for key, value in expectation.items(): if self._evaluate_if_special_operation(expectation, actual, key): continue - self._test_class.assertIn(key, actual) + self.test.assertIn(key, actual) self.match_result(value, actual[key], in_recursive_call=True) if not is_root: - self._test_class.assertEqual( + self.test.assertEqual( set(expectation.keys()), set(actual.keys())) def match_result(self, expectation, actual, @@ -459,7 +494,7 @@ def match_result(self, expectation, actual, expectation, actual, is_root=not in_recursive_call) if isinstance(expectation, abc.MutableSequence): - self._test_class.assertIsInstance(actual, abc.MutableSequence) + self.test.assertIsInstance(actual, abc.MutableSequence) for e, a in zip(expectation, actual): if isinstance(e, abc.Mapping): self._match_document( @@ -471,21 +506,22 @@ def match_result(self, expectation, actual, # account for flexible numerics in element-wise comparison if (isinstance(expectation, int) or isinstance(expectation, float)): - self._test_class.assertEqual(expectation, actual) + self.test.assertEqual(expectation, actual) else: - self._test_class.assertIsInstance(actual, type(expectation)) - self._test_class.assertEqual(expectation, actual) + self.test.assertIsInstance(actual, type(expectation)) + self.test.assertEqual(expectation, actual) - def match_event(self, expectation, actual): - event_type, spec = next(iter(expectation.items())) + def match_event(self, event_type, expectation, actual): + name, spec = next(iter(expectation.items())) - # every event type has the commandName field - command_name = spec.get('commandName') - if command_name: - self._test_class.assertEqual(command_name, actual.command_name) + # every command event has the commandName field + if event_type == 'command': + command_name = spec.get('commandName') + if command_name: + self.test.assertEqual(command_name, actual.command_name) - if event_type == 'commandStartedEvent': - self._test_class.assertIsInstance(actual, CommandStartedEvent) + if name == 'commandStartedEvent': + self.test.assertIsInstance(actual, CommandStartedEvent) command = spec.get('command') database_name = spec.get('databaseName') if command: @@ -497,18 +533,47 @@ def match_event(self, expectation, actual): update.setdefault('multi', False) self.match_result(command, actual.command) if database_name: - self._test_class.assertEqual( + self.test.assertEqual( database_name, actual.database_name) - elif event_type == 'commandSucceededEvent': - self._test_class.assertIsInstance(actual, CommandSucceededEvent) + elif name == 'commandSucceededEvent': + self.test.assertIsInstance(actual, CommandSucceededEvent) reply = spec.get('reply') if reply: self.match_result(reply, actual.reply) - elif event_type == 'commandFailedEvent': - self._test_class.assertIsInstance(actual, CommandFailedEvent) + elif name == 'commandFailedEvent': + self.test.assertIsInstance(actual, CommandFailedEvent) + elif name == 'poolCreatedEvent': + self.test.assertIsInstance(actual, PoolCreatedEvent) + elif name == 'poolReadyEvent': + self.test.assertIsInstance(actual, PoolReadyEvent) + elif name == 'poolClearedEvent': + self.test.assertIsInstance(actual, PoolClearedEvent) + if spec.get('hasServiceId'): + self.test.assertIsNotNone(actual.service_id) + self.test.assertIsInstance(actual.service_id, ObjectId) + else: + self.test.assertIsNone(actual.service_id) + elif name == 'poolClosedEvent': + self.test.assertIsInstance(actual, PoolClosedEvent) + elif name == 'connectionCreatedEvent': + self.test.assertIsInstance(actual, ConnectionCreatedEvent) + elif name == 'connectionReadyEvent': + self.test.assertIsInstance(actual, ConnectionReadyEvent) + elif name == 'connectionClosedEvent': + self.test.assertIsInstance(actual, ConnectionClosedEvent) + self.test.assertEqual(actual.reason, spec['reason']) + elif name == 'connectionCheckOutStartedEvent': + self.test.assertIsInstance(actual, ConnectionCheckOutStartedEvent) + elif name == 'connectionCheckOutFailedEvent': + self.test.assertIsInstance(actual, ConnectionCheckOutFailedEvent) + self.test.assertEqual(actual.reason, spec['reason']) + elif name == 'connectionCheckedOutEvent': + self.test.assertIsInstance(actual, ConnectionCheckedOutEvent) + elif name == 'connectionCheckedInEvent': + self.test.assertIsInstance(actual, ConnectionCheckedInEvent) else: - self._test_class.fail( - 'Unsupported event type %s' % (event_type,)) + self.test.fail( + 'Unsupported event type %s' % (name,)) def coerce_result(opname, result): @@ -623,7 +688,11 @@ def process_error(self, exception, spec): pass if is_client_error: - self.assertNotIsInstance(exception, PyMongoError) + # Connection errors are considered client errors. + if isinstance(exception, ConnectionFailure): + self.assertNotIsInstance(exception, NotMasterError) + else: + self.assertNotIsInstance(exception, PyMongoError) if error_contains: if isinstance(exception, BulkWriteError): @@ -692,6 +761,12 @@ def _databaseOperation_runCommand(self, target, **kwargs): kwargs['command'] = ordered_command return target.command(**kwargs) + def _databaseOperation_listCollections(self, target, *args, **kwargs): + if 'batch_size' in kwargs: + kwargs['cursor'] = {'batchSize': kwargs.pop('batch_size')} + cursor = target.list_collections(*args, **kwargs) + return list(cursor) + def __entityOperation_aggregate(self, target, *args, **kwargs): self.__raise_if_unsupported('aggregate', target, Database, Collection) return list(target.aggregate(*args, **kwargs)) @@ -707,6 +782,16 @@ def _collectionOperation_find(self, target, *args, **kwargs): find_cursor = target.find(*args, **kwargs) return list(find_cursor) + def _collectionOperation_createFindCursor(self, target, *args, **kwargs): + self.__raise_if_unsupported('find', target, Collection) + return NonLazyCursor(target.find(*args, **kwargs)) + + def _collectionOperation_listIndexes(self, target, *args, **kwargs): + if 'batch_size' in kwargs: + self.skipTest('PyMongo does not support batch_size for ' + 'list_indexes') + return target.list_indexes(*args, **kwargs) + def _sessionOperation_withTransaction(self, target, *args, **kwargs): if client_context.storage_engine == 'mmapv1': self.skipTest('MMAPv1 does not support document-level locking') @@ -725,11 +810,27 @@ def _changeStreamOperation_iterateUntilDocumentOrError(self, target, 'iterateUntilDocumentOrError', target, ChangeStream) return next(target) + def _cursor_iterateUntilDocumentOrError(self, target, *args, **kwargs): + self.__raise_if_unsupported( + 'iterateUntilDocumentOrError', target, NonLazyCursor) + return next(target) + + def _cursor_close(self, target, *args, **kwargs): + self.__raise_if_unsupported('close', target, NonLazyCursor) + return target.close() + def run_entity_operation(self, spec): target = self.entity_map[spec['object']] opname = spec['name'] opargs = spec.get('arguments') expect_error = spec.get('expectError') + save_as_entity = spec.get('saveResultAsEntity') + expect_result = spec.get('expectResult') + ignore = spec.get('ignoreResultAndError') + if ignore and (expect_error or save_as_entity or expect_result): + raise ValueError( + 'ignoreResultAndError is incompatible with saveResultAsEntity' + ', expectError, and expectResult') if opargs: arguments = parse_spec_options(copy.deepcopy(opargs)) prepare_spec_arguments(spec, arguments, camel_to_snake(opname), @@ -745,6 +846,8 @@ def run_entity_operation(self, spec): method_name = '_collectionOperation_%s' % (opname,) elif isinstance(target, ChangeStream): method_name = '_changeStreamOperation_%s' % (opname,) + elif isinstance(target, NonLazyCursor): + method_name = '_cursor_%s' % (opname,) elif isinstance(target, ClientSession): method_name = '_sessionOperation_%s' % (opname,) elif isinstance(target, GridFSBucket): @@ -766,15 +869,16 @@ def run_entity_operation(self, spec): try: result = cmd(**dict(arguments)) except Exception as exc: + if ignore: + return if expect_error: return self.process_error(exc, expect_error) raise - if 'expectResult' in spec: + if expect_result: actual = coerce_result(opname, result) - self.match_evaluator.match_result(spec['expectResult'], actual) + self.match_evaluator.match_result(expect_result, actual) - save_as_entity = spec.get('saveResultAsEntity') if save_as_entity: self.entity_map[save_as_entity] = result @@ -821,7 +925,7 @@ def _testOperation_assertSessionUnpinned(self, spec): def __get_last_two_command_lsids(self, listener): cmd_started_events = [] - for event in reversed(listener.results): + for event in reversed(listener.events): if isinstance(event, CommandStartedEvent): cmd_started_events.append(event) if len(cmd_started_events) < 2: @@ -869,6 +973,11 @@ def _testOperation_assertIndexNotExists(self, spec): for index in collection.list_indexes(): self.assertNotEqual(spec['indexName'], index['name']) + def _testOperation_assertNumberConnectionsCheckedOut(self, spec): + client = self.entity_map[spec['client']] + pool = get_pool(client) + self.assertEqual(spec['connections'], pool.active_sockets) + def run_special_operation(self, spec): opname = spec['name'] method_name = '_testOperation_%s' % (opname,) @@ -891,19 +1000,23 @@ def check_events(self, spec): for event_spec in spec: client_name = event_spec['client'] events = event_spec['events'] - listener = self.entity_map.get_listener_for_client(client_name) + # Valid types: 'command', 'cmap' + event_type = event_spec.get('eventType', 'command') + assert event_type in ('command', 'cmap') + listener = self.entity_map.get_listener_for_client(client_name) + actual_events = listener.get_events(event_type) if len(events) == 0: - self.assertEqual(listener.results, []) + self.assertEqual(actual_events, []) continue - if len(events) > len(listener.results): + if len(events) > len(actual_events): self.fail('Expected to see %s events, got %s' % ( - len(events), len(listener.results))) + len(events), len(actual_events))) for idx, expected_event in enumerate(events): self.match_evaluator.match_event( - expected_event, listener.results[idx]) + event_type, expected_event, actual_events[idx]) def verify_outcome(self, spec): for collection_data in spec: diff --git a/test/utils.py b/test/utils.py index f3d7dbe2aa..682782a432 100644 --- a/test/utils.py +++ b/test/utils.py @@ -277,7 +277,7 @@ def _reset(self): def ready(self): pass - def reset(self): + def reset(self, service_id=None): self._reset() def reset_without_pause(self): diff --git a/test/utils_spec_runner.py b/test/utils_spec_runner.py index 17175884da..5f79789ec8 100644 --- a/test/utils_spec_runner.py +++ b/test/utils_spec_runner.py @@ -504,15 +504,16 @@ def run_scenario(self, scenario_def, test): client_context.storage_engine == 'mmapv1'): self.skipTest("MMAPv1 does not support retryWrites=True") use_multi_mongos = test['useMultipleMongoses'] - if client_context.is_mongos and use_multi_mongos: - client = rs_client( - client_context.mongos_seeds(), - event_listeners=[listener, pool_listener, server_listener], - **client_options) - else: - client = rs_client( - event_listeners=[listener, pool_listener, server_listener], - **client_options) + host = None + if use_multi_mongos: + if client_context.load_balancer: + host = client_context.MULTI_MONGOS_LB_URI + elif client_context.is_mongos: + host = client_context.mongos_seeds() + client = rs_client( + h=host, + event_listeners=[listener, pool_listener, server_listener], + **client_options) self.scenario_client = client self.listener = listener self.pool_listener = pool_listener