From 1e61ec8114b38de960e06d72713fab61da157986 Mon Sep 17 00:00:00 2001 From: Roy Jacobson Date: Mon, 18 Sep 2023 13:52:56 +0300 Subject: [PATCH] chore: Add a context manager to DflyInstance so we don't forget to close them. (#1873) * chore: Add a context manager to DflyInstance so we don't forget to close them. * Update tests/dragonfly/config_test.py Co-authored-by: Roman Gershman Signed-off-by: Roy Jacobson --------- Signed-off-by: Roy Jacobson Co-authored-by: Roman Gershman --- tests/dragonfly/__init__.py | 7 ++++ tests/dragonfly/config_test.py | 28 ++++++------- tests/dragonfly/generic_test.py | 54 ++++++++++--------------- tests/dragonfly/http_conf_test.py | 57 ++++++++++++--------------- tests/dragonfly/server_family_test.py | 22 +++++------ tests/dragonfly/snapshot_test.py | 22 +++++------ tests/dragonfly/tls_conf_test.py | 32 +++++++-------- 7 files changed, 98 insertions(+), 124 deletions(-) diff --git a/tests/dragonfly/__init__.py b/tests/dragonfly/__init__.py index ad07f3542689..92b1410e5926 100644 --- a/tests/dragonfly/__init__.py +++ b/tests/dragonfly/__init__.py @@ -61,6 +61,13 @@ def __del__(self): def client(self, *args, **kwargs) -> RedisClient: return RedisClient(port=self.port, *args, **kwargs) + def __enter__(self): + self.start() + return self + + def __exit__(self, exc_type, exc_value, exc_traceback): + self.stop() + def start(self): if self.params.existing_port: return diff --git a/tests/dragonfly/config_test.py b/tests/dragonfly/config_test.py index 9fae4dc92f3f..73e194a4b698 100644 --- a/tests/dragonfly/config_test.py +++ b/tests/dragonfly/config_test.py @@ -7,23 +7,19 @@ async def test_maxclients(df_factory): # Needs some authentication - server = df_factory.create(port=1111, maxclients=1, admin_port=1112) - server.start() + with df_factory.create(port=1111, maxclients=1, admin_port=1112) as server: + async with server.client() as client1: + assert [b"maxclients", b"1"] == await client1.execute_command("CONFIG GET maxclients") - async with server.client() as client1: - assert [b"maxclients", b"1"] == await client1.execute_command("CONFIG GET maxclients") + with pytest.raises(redis.exceptions.ConnectionError): + async with server.client() as client2: + await client2.get("test") - with pytest.raises(redis.exceptions.ConnectionError): + # Check that admin connections are not limited. + async with RedisClient(port=server.admin_port) as admin_client: + await admin_client.get("test") + + await client1.execute_command("CONFIG SET maxclients 3") + assert [b"maxclients", b"3"] == await client1.execute_command("CONFIG GET maxclients") async with server.client() as client2: await client2.get("test") - - # Check that admin connections are not limited. - async with RedisClient(port=server.admin_port) as admin_client: - await admin_client.get("test") - - await client1.execute_command("CONFIG SET maxclients 3") - assert [b"maxclients", b"3"] == await client1.execute_command("CONFIG GET maxclients") - async with server.client() as client2: - await client2.get("test") - - server.stop() diff --git a/tests/dragonfly/generic_test.py b/tests/dragonfly/generic_test.py index f51e49111049..e9d36d217004 100644 --- a/tests/dragonfly/generic_test.py +++ b/tests/dragonfly/generic_test.py @@ -27,29 +27,23 @@ def export_dfly_password() -> str: async def test_password(df_local_factory, export_dfly_password): - dfly = df_local_factory.create() - dfly.start() - - # Expect password form environment variable - with pytest.raises(redis.exceptions.AuthenticationError): - async with aioredis.Redis(port=dfly.port) as client: + with df_local_factory.create() as dfly: + # Expect password form environment variable + with pytest.raises(redis.exceptions.AuthenticationError): + async with aioredis.Redis(port=dfly.port) as client: + await client.ping() + async with aioredis.Redis(password=export_dfly_password, port=dfly.port) as client: await client.ping() - async with aioredis.Redis(password=export_dfly_password, port=dfly.port) as client: - await client.ping() - dfly.stop() # --requirepass should take precedence over environment variable requirepass = "requirepass" - dfly = df_local_factory.create(requirepass=requirepass) - dfly.start() - - # Expect password form flag - with pytest.raises(redis.exceptions.AuthenticationError): - async with aioredis.Redis(port=dfly.port, password=export_dfly_password) as client: + with df_local_factory.create(requirepass=requirepass) as dfly: + # Expect password form flag + with pytest.raises(redis.exceptions.AuthenticationError): + async with aioredis.Redis(port=dfly.port, password=export_dfly_password) as client: + await client.ping() + async with aioredis.Redis(password=requirepass, port=dfly.port) as client: await client.ping() - async with aioredis.Redis(password=requirepass, port=dfly.port) as client: - await client.ping() - dfly.stop() """ @@ -84,30 +78,24 @@ async def task2(k, n): ) -@dfly_args({"port": 6377}) async def test_arg_from_environ_overwritten_by_cli(df_local_factory): with EnvironCntx(DFLY_port="6378"): - dfly = df_local_factory.create() - dfly.start() - client = aioredis.Redis(port="6377") - await client.ping() - dfly.stop() + with df_local_factory.create(port=6377): + client = aioredis.Redis(port=6377) + await client.ping() async def test_arg_from_environ(df_local_factory): with EnvironCntx(DFLY_requirepass="pass"): - dfly = df_local_factory.create() - dfly.start() + with df_local_factory.create() as dfly: + # Expect password from environment variable + with pytest.raises(redis.exceptions.AuthenticationError): + client = aioredis.Redis(port=dfly.port) + await client.ping() - # Expect password from environment variable - with pytest.raises(redis.exceptions.AuthenticationError): - client = aioredis.Redis(port=dfly.port) + client = aioredis.Redis(password="pass", port=dfly.port) await client.ping() - client = aioredis.Redis(password="pass", port=dfly.port) - await client.ping() - dfly.stop() - async def test_unknown_dfly_env(df_local_factory, export_dfly_password): with EnvironCntx(DFLY_abcdef="xyz"): diff --git a/tests/dragonfly/http_conf_test.py b/tests/dragonfly/http_conf_test.py index 06dab71485c4..4ad8202255b0 100644 --- a/tests/dragonfly/http_conf_test.py +++ b/tests/dragonfly/http_conf_test.py @@ -3,52 +3,45 @@ async def test_password(df_factory): # Needs a private key and certificate. - server = df_factory.create(port=1112, requirepass="XXX") - server.start() - - async with aiohttp.ClientSession() as session: - resp = await session.get(f"http://localhost:{server.port}/") - assert resp.status == 401 - async with aiohttp.ClientSession(auth=aiohttp.BasicAuth("user", "wrongpassword")) as session: - resp = await session.get(f"http://localhost:{server.port}/") - assert resp.status == 401 - async with aiohttp.ClientSession(auth=aiohttp.BasicAuth("user", "XXX")) as session: - resp = await session.get(f"http://localhost:{server.port}/") - assert resp.status == 200 - server.stop() + with df_factory.create(port=1112, requirepass="XXX") as server: + async with aiohttp.ClientSession() as session: + resp = await session.get(f"http://localhost:{server.port}/") + assert resp.status == 401 + async with aiohttp.ClientSession( + auth=aiohttp.BasicAuth("user", "wrongpassword") + ) as session: + resp = await session.get(f"http://localhost:{server.port}/") + assert resp.status == 401 + async with aiohttp.ClientSession(auth=aiohttp.BasicAuth("user", "XXX")) as session: + resp = await session.get(f"http://localhost:{server.port}/") + assert resp.status == 200 async def test_no_password_on_admin(df_factory): # Needs a private key and certificate. - server = df_factory.create( + with df_factory.create( port=1112, admin_port=1113, requirepass="XXX", noprimary_port_http_enabled=None, admin_nopass=None, - ) - server.start() - - async with aiohttp.ClientSession(auth=aiohttp.BasicAuth("user", "XXX")) as session: - resp = await session.get(f"http://localhost:{server.admin_port}/") - assert resp.status == 200 - server.stop() + ) as server: + async with aiohttp.ClientSession(auth=aiohttp.BasicAuth("user", "XXX")) as session: + resp = await session.get(f"http://localhost:{server.admin_port}/") + assert resp.status == 200 async def test_password_on_admin(df_factory): # Needs a private key and certificate. - server = df_factory.create( + with df_factory.create( port=1112, admin_port=1113, requirepass="XXX", admin_nopass=None, - ) - server.start() - - async with aiohttp.ClientSession(auth=aiohttp.BasicAuth("user", "badpass")) as session: - resp = await session.get(f"http://localhost:{server.port}/") - assert resp.status == 401 - async with aiohttp.ClientSession(auth=aiohttp.BasicAuth("user", "XXX")) as session: - resp = await session.get(f"http://localhost:{server.port}/") - assert resp.status == 200 - server.stop() + ) as server: + async with aiohttp.ClientSession(auth=aiohttp.BasicAuth("user", "badpass")) as session: + resp = await session.get(f"http://localhost:{server.port}/") + assert resp.status == 401 + async with aiohttp.ClientSession(auth=aiohttp.BasicAuth("user", "XXX")) as session: + resp = await session.get(f"http://localhost:{server.port}/") + assert resp.status == 200 diff --git a/tests/dragonfly/server_family_test.py b/tests/dragonfly/server_family_test.py index 759257e166c2..758d9a955d1f 100644 --- a/tests/dragonfly/server_family_test.py +++ b/tests/dragonfly/server_family_test.py @@ -73,18 +73,16 @@ async def test_get_databases(async_client: aioredis.Redis): async def test_client_list(df_factory): - instance = df_factory.create(port=1111, admin_port=1112) - instance.start() - client = aioredis.Redis(port=instance.port) - admin_client = aioredis.Redis(port=instance.admin_port) - - await client.ping() - await admin_client.ping() - assert len(await client.execute_command("CLIENT LIST")) == 2 - assert len(await admin_client.execute_command("CLIENT LIST")) == 2 - - instance.stop() - await disconnect_clients(client, admin_client) + with df_factory.create(port=1111, admin_port=1112) as instance: + client = aioredis.Redis(port=instance.port) + admin_client = aioredis.Redis(port=instance.admin_port) + + await client.ping() + await admin_client.ping() + assert len(await client.execute_command("CLIENT LIST")) == 2 + assert len(await admin_client.execute_command("CLIENT LIST")) == 2 + + await disconnect_clients(client, admin_client) async def test_scan(async_client: aioredis.Redis): diff --git a/tests/dragonfly/snapshot_test.py b/tests/dragonfly/snapshot_test.py index e4a4f5ddfbfc..1ffdbce499e0 100644 --- a/tests/dragonfly/snapshot_test.py +++ b/tests/dragonfly/snapshot_test.py @@ -136,19 +136,15 @@ async def test_snapshot(self, df_local_factory, save_type, dbfilename): df_args = {"dbfilename": dbfilename, **BASIC_ARGS, "port": 1111} if save_type == "rdb": df_args["nodf_snapshot_format"] = None - df_server = df_local_factory.create(**df_args) - df_server.start() - - client = aioredis.Redis(port=df_server.port) - await client.set("TEST", hash(dbfilename)) - await client.execute_command("SAVE " + save_type) - df_server.stop() - - df_server2 = df_local_factory.create(**df_args) - df_server2.start() - client = aioredis.Redis(port=df_server.port) - response = await client.get("TEST") - assert response.decode("utf-8") == str(hash(dbfilename)) + with df_local_factory.create(**df_args) as df_server: + async with df_server.client() as client: + await client.set("TEST", hash(dbfilename)) + await client.execute_command("SAVE " + save_type) + + with df_local_factory.create(**df_args) as df_server: + async with df_server.client() as client: + response = await client.get("TEST") + assert response.decode("utf-8") == str(hash(dbfilename)) @dfly_args({**BASIC_ARGS, "dbfilename": "test-periodic", "save_schedule": "*:*"}) diff --git a/tests/dragonfly/tls_conf_test.py b/tests/dragonfly/tls_conf_test.py index 31356f654a6e..79e719b0ab4f 100644 --- a/tests/dragonfly/tls_conf_test.py +++ b/tests/dragonfly/tls_conf_test.py @@ -19,23 +19,21 @@ async def test_tls_no_key(df_factory): async def test_tls_password(df_factory, with_tls_server_args, gen_ca_cert): - server = df_factory.create(requirepass="XXX", **with_tls_server_args) - server.start() - async with server.client( - ssl=True, password="XXX", ssl_ca_certs=gen_ca_cert["ca_cert"] - ) as client: - await client.ping() - server.stop() + with df_factory.create(requirepass="XXX", **with_tls_server_args) as server: + async with server.client( + ssl=True, password="XXX", ssl_ca_certs=gen_ca_cert["ca_cert"] + ) as client: + await client.ping() async def test_tls_client_certs( df_factory, with_ca_tls_server_args, with_tls_client_args, gen_ca_cert ): - server = df_factory.create(**with_ca_tls_server_args) - server.start() - async with server.client(**with_tls_client_args, ssl_ca_certs=gen_ca_cert["ca_cert"]) as client: - await client.ping() - server.stop() + with df_factory.create(**with_ca_tls_server_args) as server: + async with server.client( + **with_tls_client_args, ssl_ca_certs=gen_ca_cert["ca_cert"] + ) as client: + await client.ping() async def test_client_tls_no_auth(df_factory): @@ -45,14 +43,12 @@ async def test_client_tls_no_auth(df_factory): async def test_client_tls_password(df_factory): - server = df_factory.create(tls_replication=None, masterauth="XXX") - server.start() - server.stop() + with df_factory.create(tls_replication=None, masterauth="XXX"): + pass async def test_client_tls_cert(df_factory, with_tls_server_args): key_args = with_tls_server_args.copy() key_args.pop("tls") - server = df_factory.create(tls_replication=None, **key_args) - server.start() - server.stop() + with df_factory.create(tls_replication=None, **key_args): + pass