Skip to content

Commit

Permalink
chore: Add a context manager to DflyInstance so we don't forget to cl…
Browse files Browse the repository at this point in the history
…ose them. (#1873)

* chore: Add a context manager to DflyInstance so we don't forget to close
them.

* Update tests/dragonfly/config_test.py

Co-authored-by: Roman Gershman <[email protected]>
Signed-off-by: Roy Jacobson <[email protected]>

---------

Signed-off-by: Roy Jacobson <[email protected]>
Co-authored-by: Roman Gershman <[email protected]>
  • Loading branch information
royjacobson and romange authored Sep 18, 2023
1 parent 74d7826 commit 1e61ec8
Show file tree
Hide file tree
Showing 7 changed files with 98 additions and 124 deletions.
7 changes: 7 additions & 0 deletions tests/dragonfly/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,13 @@ def __del__(self):
def client(self, *args, **kwargs) -> RedisClient:
return RedisClient(port=self.port, *args, **kwargs)

def __enter__(self):
self.start()
return self

def __exit__(self, exc_type, exc_value, exc_traceback):
self.stop()

def start(self):
if self.params.existing_port:
return
Expand Down
28 changes: 12 additions & 16 deletions tests/dragonfly/config_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,23 +7,19 @@

async def test_maxclients(df_factory):
# Needs some authentication
server = df_factory.create(port=1111, maxclients=1, admin_port=1112)
server.start()
with df_factory.create(port=1111, maxclients=1, admin_port=1112) as server:
async with server.client() as client1:
assert [b"maxclients", b"1"] == await client1.execute_command("CONFIG GET maxclients")

async with server.client() as client1:
assert [b"maxclients", b"1"] == await client1.execute_command("CONFIG GET maxclients")
with pytest.raises(redis.exceptions.ConnectionError):
async with server.client() as client2:
await client2.get("test")

with pytest.raises(redis.exceptions.ConnectionError):
# Check that admin connections are not limited.
async with RedisClient(port=server.admin_port) as admin_client:
await admin_client.get("test")

await client1.execute_command("CONFIG SET maxclients 3")
assert [b"maxclients", b"3"] == await client1.execute_command("CONFIG GET maxclients")
async with server.client() as client2:
await client2.get("test")

# Check that admin connections are not limited.
async with RedisClient(port=server.admin_port) as admin_client:
await admin_client.get("test")

await client1.execute_command("CONFIG SET maxclients 3")
assert [b"maxclients", b"3"] == await client1.execute_command("CONFIG GET maxclients")
async with server.client() as client2:
await client2.get("test")

server.stop()
54 changes: 21 additions & 33 deletions tests/dragonfly/generic_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,29 +27,23 @@ def export_dfly_password() -> str:


async def test_password(df_local_factory, export_dfly_password):
dfly = df_local_factory.create()
dfly.start()

# Expect password form environment variable
with pytest.raises(redis.exceptions.AuthenticationError):
async with aioredis.Redis(port=dfly.port) as client:
with df_local_factory.create() as dfly:
# Expect password form environment variable
with pytest.raises(redis.exceptions.AuthenticationError):
async with aioredis.Redis(port=dfly.port) as client:
await client.ping()
async with aioredis.Redis(password=export_dfly_password, port=dfly.port) as client:
await client.ping()
async with aioredis.Redis(password=export_dfly_password, port=dfly.port) as client:
await client.ping()
dfly.stop()

# --requirepass should take precedence over environment variable
requirepass = "requirepass"
dfly = df_local_factory.create(requirepass=requirepass)
dfly.start()

# Expect password form flag
with pytest.raises(redis.exceptions.AuthenticationError):
async with aioredis.Redis(port=dfly.port, password=export_dfly_password) as client:
with df_local_factory.create(requirepass=requirepass) as dfly:
# Expect password form flag
with pytest.raises(redis.exceptions.AuthenticationError):
async with aioredis.Redis(port=dfly.port, password=export_dfly_password) as client:
await client.ping()
async with aioredis.Redis(password=requirepass, port=dfly.port) as client:
await client.ping()
async with aioredis.Redis(password=requirepass, port=dfly.port) as client:
await client.ping()
dfly.stop()


"""
Expand Down Expand Up @@ -84,30 +78,24 @@ async def task2(k, n):
)


@dfly_args({"port": 6377})
async def test_arg_from_environ_overwritten_by_cli(df_local_factory):
with EnvironCntx(DFLY_port="6378"):
dfly = df_local_factory.create()
dfly.start()
client = aioredis.Redis(port="6377")
await client.ping()
dfly.stop()
with df_local_factory.create(port=6377):
client = aioredis.Redis(port=6377)
await client.ping()


async def test_arg_from_environ(df_local_factory):
with EnvironCntx(DFLY_requirepass="pass"):
dfly = df_local_factory.create()
dfly.start()
with df_local_factory.create() as dfly:
# Expect password from environment variable
with pytest.raises(redis.exceptions.AuthenticationError):
client = aioredis.Redis(port=dfly.port)
await client.ping()

# Expect password from environment variable
with pytest.raises(redis.exceptions.AuthenticationError):
client = aioredis.Redis(port=dfly.port)
client = aioredis.Redis(password="pass", port=dfly.port)
await client.ping()

client = aioredis.Redis(password="pass", port=dfly.port)
await client.ping()
dfly.stop()


async def test_unknown_dfly_env(df_local_factory, export_dfly_password):
with EnvironCntx(DFLY_abcdef="xyz"):
Expand Down
57 changes: 25 additions & 32 deletions tests/dragonfly/http_conf_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,52 +3,45 @@

async def test_password(df_factory):
# Needs a private key and certificate.
server = df_factory.create(port=1112, requirepass="XXX")
server.start()

async with aiohttp.ClientSession() as session:
resp = await session.get(f"http://localhost:{server.port}/")
assert resp.status == 401
async with aiohttp.ClientSession(auth=aiohttp.BasicAuth("user", "wrongpassword")) as session:
resp = await session.get(f"http://localhost:{server.port}/")
assert resp.status == 401
async with aiohttp.ClientSession(auth=aiohttp.BasicAuth("user", "XXX")) as session:
resp = await session.get(f"http://localhost:{server.port}/")
assert resp.status == 200
server.stop()
with df_factory.create(port=1112, requirepass="XXX") as server:
async with aiohttp.ClientSession() as session:
resp = await session.get(f"http://localhost:{server.port}/")
assert resp.status == 401
async with aiohttp.ClientSession(
auth=aiohttp.BasicAuth("user", "wrongpassword")
) as session:
resp = await session.get(f"http://localhost:{server.port}/")
assert resp.status == 401
async with aiohttp.ClientSession(auth=aiohttp.BasicAuth("user", "XXX")) as session:
resp = await session.get(f"http://localhost:{server.port}/")
assert resp.status == 200


async def test_no_password_on_admin(df_factory):
# Needs a private key and certificate.
server = df_factory.create(
with df_factory.create(
port=1112,
admin_port=1113,
requirepass="XXX",
noprimary_port_http_enabled=None,
admin_nopass=None,
)
server.start()

async with aiohttp.ClientSession(auth=aiohttp.BasicAuth("user", "XXX")) as session:
resp = await session.get(f"http://localhost:{server.admin_port}/")
assert resp.status == 200
server.stop()
) as server:
async with aiohttp.ClientSession(auth=aiohttp.BasicAuth("user", "XXX")) as session:
resp = await session.get(f"http://localhost:{server.admin_port}/")
assert resp.status == 200


async def test_password_on_admin(df_factory):
# Needs a private key and certificate.
server = df_factory.create(
with df_factory.create(
port=1112,
admin_port=1113,
requirepass="XXX",
admin_nopass=None,
)
server.start()

async with aiohttp.ClientSession(auth=aiohttp.BasicAuth("user", "badpass")) as session:
resp = await session.get(f"http://localhost:{server.port}/")
assert resp.status == 401
async with aiohttp.ClientSession(auth=aiohttp.BasicAuth("user", "XXX")) as session:
resp = await session.get(f"http://localhost:{server.port}/")
assert resp.status == 200
server.stop()
) as server:
async with aiohttp.ClientSession(auth=aiohttp.BasicAuth("user", "badpass")) as session:
resp = await session.get(f"http://localhost:{server.port}/")
assert resp.status == 401
async with aiohttp.ClientSession(auth=aiohttp.BasicAuth("user", "XXX")) as session:
resp = await session.get(f"http://localhost:{server.port}/")
assert resp.status == 200
22 changes: 10 additions & 12 deletions tests/dragonfly/server_family_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,18 +73,16 @@ async def test_get_databases(async_client: aioredis.Redis):


async def test_client_list(df_factory):
instance = df_factory.create(port=1111, admin_port=1112)
instance.start()
client = aioredis.Redis(port=instance.port)
admin_client = aioredis.Redis(port=instance.admin_port)

await client.ping()
await admin_client.ping()
assert len(await client.execute_command("CLIENT LIST")) == 2
assert len(await admin_client.execute_command("CLIENT LIST")) == 2

instance.stop()
await disconnect_clients(client, admin_client)
with df_factory.create(port=1111, admin_port=1112) as instance:
client = aioredis.Redis(port=instance.port)
admin_client = aioredis.Redis(port=instance.admin_port)

await client.ping()
await admin_client.ping()
assert len(await client.execute_command("CLIENT LIST")) == 2
assert len(await admin_client.execute_command("CLIENT LIST")) == 2

await disconnect_clients(client, admin_client)


async def test_scan(async_client: aioredis.Redis):
Expand Down
22 changes: 9 additions & 13 deletions tests/dragonfly/snapshot_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,19 +136,15 @@ async def test_snapshot(self, df_local_factory, save_type, dbfilename):
df_args = {"dbfilename": dbfilename, **BASIC_ARGS, "port": 1111}
if save_type == "rdb":
df_args["nodf_snapshot_format"] = None
df_server = df_local_factory.create(**df_args)
df_server.start()

client = aioredis.Redis(port=df_server.port)
await client.set("TEST", hash(dbfilename))
await client.execute_command("SAVE " + save_type)
df_server.stop()

df_server2 = df_local_factory.create(**df_args)
df_server2.start()
client = aioredis.Redis(port=df_server.port)
response = await client.get("TEST")
assert response.decode("utf-8") == str(hash(dbfilename))
with df_local_factory.create(**df_args) as df_server:
async with df_server.client() as client:
await client.set("TEST", hash(dbfilename))
await client.execute_command("SAVE " + save_type)

with df_local_factory.create(**df_args) as df_server:
async with df_server.client() as client:
response = await client.get("TEST")
assert response.decode("utf-8") == str(hash(dbfilename))


@dfly_args({**BASIC_ARGS, "dbfilename": "test-periodic", "save_schedule": "*:*"})
Expand Down
32 changes: 14 additions & 18 deletions tests/dragonfly/tls_conf_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,23 +19,21 @@ async def test_tls_no_key(df_factory):


async def test_tls_password(df_factory, with_tls_server_args, gen_ca_cert):
server = df_factory.create(requirepass="XXX", **with_tls_server_args)
server.start()
async with server.client(
ssl=True, password="XXX", ssl_ca_certs=gen_ca_cert["ca_cert"]
) as client:
await client.ping()
server.stop()
with df_factory.create(requirepass="XXX", **with_tls_server_args) as server:
async with server.client(
ssl=True, password="XXX", ssl_ca_certs=gen_ca_cert["ca_cert"]
) as client:
await client.ping()


async def test_tls_client_certs(
df_factory, with_ca_tls_server_args, with_tls_client_args, gen_ca_cert
):
server = df_factory.create(**with_ca_tls_server_args)
server.start()
async with server.client(**with_tls_client_args, ssl_ca_certs=gen_ca_cert["ca_cert"]) as client:
await client.ping()
server.stop()
with df_factory.create(**with_ca_tls_server_args) as server:
async with server.client(
**with_tls_client_args, ssl_ca_certs=gen_ca_cert["ca_cert"]
) as client:
await client.ping()


async def test_client_tls_no_auth(df_factory):
Expand All @@ -45,14 +43,12 @@ async def test_client_tls_no_auth(df_factory):


async def test_client_tls_password(df_factory):
server = df_factory.create(tls_replication=None, masterauth="XXX")
server.start()
server.stop()
with df_factory.create(tls_replication=None, masterauth="XXX"):
pass


async def test_client_tls_cert(df_factory, with_tls_server_args):
key_args = with_tls_server_args.copy()
key_args.pop("tls")
server = df_factory.create(tls_replication=None, **key_args)
server.start()
server.stop()
with df_factory.create(tls_replication=None, **key_args):
pass

0 comments on commit 1e61ec8

Please sign in to comment.