Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore: Add a context manager to DflyInstance so we don't forget to close them. #1873

Merged
merged 2 commits into from
Sep 18, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions tests/dragonfly/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,13 @@ def __del__(self):
def client(self, *args, **kwargs) -> RedisClient:
return RedisClient(port=self.port, *args, **kwargs)

def __enter__(self):
self.start()
return self

def __exit__(self, exc_type, exc_value, exc_traceback):
self.stop()

def start(self):
if self.params.existing_port:
return
Expand Down
29 changes: 13 additions & 16 deletions tests/dragonfly/config_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,23 +7,20 @@

async def test_maxclients(df_factory):
# Needs some authentication
server = df_factory.create(port=1111, maxclients=1, admin_port=1112)
server.start()
with df_factory.create(port=1111, maxclients=1, admin_port=1112) as server:
print(server)
royjacobson marked this conversation as resolved.
Show resolved Hide resolved
async with server.client() as client1:
assert [b"maxclients", b"1"] == await client1.execute_command("CONFIG GET maxclients")

async with server.client() as client1:
assert [b"maxclients", b"1"] == await client1.execute_command("CONFIG GET maxclients")
with pytest.raises(redis.exceptions.ConnectionError):
async with server.client() as client2:
await client2.get("test")

with pytest.raises(redis.exceptions.ConnectionError):
# Check that admin connections are not limited.
async with RedisClient(port=server.admin_port) as admin_client:
await admin_client.get("test")

await client1.execute_command("CONFIG SET maxclients 3")
assert [b"maxclients", b"3"] == await client1.execute_command("CONFIG GET maxclients")
async with server.client() as client2:
await client2.get("test")

# Check that admin connections are not limited.
async with RedisClient(port=server.admin_port) as admin_client:
await admin_client.get("test")

await client1.execute_command("CONFIG SET maxclients 3")
assert [b"maxclients", b"3"] == await client1.execute_command("CONFIG GET maxclients")
async with server.client() as client2:
await client2.get("test")

server.stop()
54 changes: 21 additions & 33 deletions tests/dragonfly/generic_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,29 +27,23 @@ def export_dfly_password() -> str:


async def test_password(df_local_factory, export_dfly_password):
dfly = df_local_factory.create()
dfly.start()

# Expect password form environment variable
with pytest.raises(redis.exceptions.AuthenticationError):
async with aioredis.Redis(port=dfly.port) as client:
with df_local_factory.create() as dfly:
# Expect password form environment variable
with pytest.raises(redis.exceptions.AuthenticationError):
async with aioredis.Redis(port=dfly.port) as client:
await client.ping()
async with aioredis.Redis(password=export_dfly_password, port=dfly.port) as client:
await client.ping()
async with aioredis.Redis(password=export_dfly_password, port=dfly.port) as client:
await client.ping()
dfly.stop()

# --requirepass should take precedence over environment variable
requirepass = "requirepass"
dfly = df_local_factory.create(requirepass=requirepass)
dfly.start()

# Expect password form flag
with pytest.raises(redis.exceptions.AuthenticationError):
async with aioredis.Redis(port=dfly.port, password=export_dfly_password) as client:
with df_local_factory.create(requirepass=requirepass) as dfly:
# Expect password form flag
with pytest.raises(redis.exceptions.AuthenticationError):
async with aioredis.Redis(port=dfly.port, password=export_dfly_password) as client:
await client.ping()
async with aioredis.Redis(password=requirepass, port=dfly.port) as client:
await client.ping()
async with aioredis.Redis(password=requirepass, port=dfly.port) as client:
await client.ping()
dfly.stop()


"""
Expand Down Expand Up @@ -84,30 +78,24 @@ async def task2(k, n):
)


@dfly_args({"port": 6377})
async def test_arg_from_environ_overwritten_by_cli(df_local_factory):
with EnvironCntx(DFLY_port="6378"):
dfly = df_local_factory.create()
dfly.start()
client = aioredis.Redis(port="6377")
await client.ping()
dfly.stop()
with df_local_factory.create(port=6377):
client = aioredis.Redis(port=6377)
await client.ping()


async def test_arg_from_environ(df_local_factory):
with EnvironCntx(DFLY_requirepass="pass"):
dfly = df_local_factory.create()
dfly.start()
with df_local_factory.create() as dfly:
# Expect password from environment variable
with pytest.raises(redis.exceptions.AuthenticationError):
client = aioredis.Redis(port=dfly.port)
await client.ping()

# Expect password from environment variable
with pytest.raises(redis.exceptions.AuthenticationError):
client = aioredis.Redis(port=dfly.port)
client = aioredis.Redis(password="pass", port=dfly.port)
await client.ping()

client = aioredis.Redis(password="pass", port=dfly.port)
await client.ping()
dfly.stop()


async def test_unknown_dfly_env(df_local_factory, export_dfly_password):
with EnvironCntx(DFLY_abcdef="xyz"):
Expand Down
57 changes: 25 additions & 32 deletions tests/dragonfly/http_conf_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,52 +3,45 @@

async def test_password(df_factory):
# Needs a private key and certificate.
server = df_factory.create(port=1112, requirepass="XXX")
server.start()

async with aiohttp.ClientSession() as session:
resp = await session.get(f"http://localhost:{server.port}/")
assert resp.status == 401
async with aiohttp.ClientSession(auth=aiohttp.BasicAuth("user", "wrongpassword")) as session:
resp = await session.get(f"http://localhost:{server.port}/")
assert resp.status == 401
async with aiohttp.ClientSession(auth=aiohttp.BasicAuth("user", "XXX")) as session:
resp = await session.get(f"http://localhost:{server.port}/")
assert resp.status == 200
server.stop()
with df_factory.create(port=1112, requirepass="XXX") as server:
async with aiohttp.ClientSession() as session:
resp = await session.get(f"http://localhost:{server.port}/")
assert resp.status == 401
async with aiohttp.ClientSession(
auth=aiohttp.BasicAuth("user", "wrongpassword")
) as session:
resp = await session.get(f"http://localhost:{server.port}/")
assert resp.status == 401
async with aiohttp.ClientSession(auth=aiohttp.BasicAuth("user", "XXX")) as session:
resp = await session.get(f"http://localhost:{server.port}/")
assert resp.status == 200


async def test_no_password_on_admin(df_factory):
# Needs a private key and certificate.
server = df_factory.create(
with df_factory.create(
port=1112,
admin_port=1113,
requirepass="XXX",
noprimary_port_http_enabled=None,
admin_nopass=None,
)
server.start()

async with aiohttp.ClientSession(auth=aiohttp.BasicAuth("user", "XXX")) as session:
resp = await session.get(f"http://localhost:{server.admin_port}/")
assert resp.status == 200
server.stop()
) as server:
async with aiohttp.ClientSession(auth=aiohttp.BasicAuth("user", "XXX")) as session:
resp = await session.get(f"http://localhost:{server.admin_port}/")
assert resp.status == 200


async def test_password_on_admin(df_factory):
# Needs a private key and certificate.
server = df_factory.create(
with df_factory.create(
port=1112,
admin_port=1113,
requirepass="XXX",
admin_nopass=None,
)
server.start()

async with aiohttp.ClientSession(auth=aiohttp.BasicAuth("user", "badpass")) as session:
resp = await session.get(f"http://localhost:{server.port}/")
assert resp.status == 401
async with aiohttp.ClientSession(auth=aiohttp.BasicAuth("user", "XXX")) as session:
resp = await session.get(f"http://localhost:{server.port}/")
assert resp.status == 200
server.stop()
) as server:
async with aiohttp.ClientSession(auth=aiohttp.BasicAuth("user", "badpass")) as session:
resp = await session.get(f"http://localhost:{server.port}/")
assert resp.status == 401
async with aiohttp.ClientSession(auth=aiohttp.BasicAuth("user", "XXX")) as session:
resp = await session.get(f"http://localhost:{server.port}/")
assert resp.status == 200
22 changes: 10 additions & 12 deletions tests/dragonfly/server_family_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,18 +73,16 @@ async def test_get_databases(async_client: aioredis.Redis):


async def test_client_list(df_factory):
instance = df_factory.create(port=1111, admin_port=1112)
instance.start()
client = aioredis.Redis(port=instance.port)
admin_client = aioredis.Redis(port=instance.admin_port)

await client.ping()
await admin_client.ping()
assert len(await client.execute_command("CLIENT LIST")) == 2
assert len(await admin_client.execute_command("CLIENT LIST")) == 2

instance.stop()
await disconnect_clients(client, admin_client)
with df_factory.create(port=1111, admin_port=1112) as instance:
client = aioredis.Redis(port=instance.port)
admin_client = aioredis.Redis(port=instance.admin_port)

await client.ping()
await admin_client.ping()
assert len(await client.execute_command("CLIENT LIST")) == 2
assert len(await admin_client.execute_command("CLIENT LIST")) == 2

await disconnect_clients(client, admin_client)


async def test_scan(async_client: aioredis.Redis):
Expand Down
22 changes: 9 additions & 13 deletions tests/dragonfly/snapshot_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,19 +136,15 @@ async def test_snapshot(self, df_local_factory, save_type, dbfilename):
df_args = {"dbfilename": dbfilename, **BASIC_ARGS, "port": 1111}
if save_type == "rdb":
df_args["nodf_snapshot_format"] = None
df_server = df_local_factory.create(**df_args)
df_server.start()

client = aioredis.Redis(port=df_server.port)
await client.set("TEST", hash(dbfilename))
await client.execute_command("SAVE " + save_type)
df_server.stop()

df_server2 = df_local_factory.create(**df_args)
df_server2.start()
client = aioredis.Redis(port=df_server.port)
response = await client.get("TEST")
assert response.decode("utf-8") == str(hash(dbfilename))
with df_local_factory.create(**df_args) as df_server:
async with df_server.client() as client:
await client.set("TEST", hash(dbfilename))
await client.execute_command("SAVE " + save_type)

with df_local_factory.create(**df_args) as df_server:
async with df_server.client() as client:
response = await client.get("TEST")
assert response.decode("utf-8") == str(hash(dbfilename))


@dfly_args({**BASIC_ARGS, "dbfilename": "test-periodic", "save_schedule": "*:*"})
Expand Down
32 changes: 14 additions & 18 deletions tests/dragonfly/tls_conf_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,23 +19,21 @@ async def test_tls_no_key(df_factory):


async def test_tls_password(df_factory, with_tls_server_args, gen_ca_cert):
server = df_factory.create(requirepass="XXX", **with_tls_server_args)
server.start()
async with server.client(
ssl=True, password="XXX", ssl_ca_certs=gen_ca_cert["ca_cert"]
) as client:
await client.ping()
server.stop()
with df_factory.create(requirepass="XXX", **with_tls_server_args) as server:
async with server.client(
ssl=True, password="XXX", ssl_ca_certs=gen_ca_cert["ca_cert"]
) as client:
await client.ping()


async def test_tls_client_certs(
df_factory, with_ca_tls_server_args, with_tls_client_args, gen_ca_cert
):
server = df_factory.create(**with_ca_tls_server_args)
server.start()
async with server.client(**with_tls_client_args, ssl_ca_certs=gen_ca_cert["ca_cert"]) as client:
await client.ping()
server.stop()
with df_factory.create(**with_ca_tls_server_args) as server:
async with server.client(
**with_tls_client_args, ssl_ca_certs=gen_ca_cert["ca_cert"]
) as client:
await client.ping()


async def test_client_tls_no_auth(df_factory):
Expand All @@ -45,14 +43,12 @@ async def test_client_tls_no_auth(df_factory):


async def test_client_tls_password(df_factory):
server = df_factory.create(tls_replication=None, masterauth="XXX")
server.start()
server.stop()
with df_factory.create(tls_replication=None, masterauth="XXX"):
pass


async def test_client_tls_cert(df_factory, with_tls_server_args):
key_args = with_tls_server_args.copy()
key_args.pop("tls")
server = df_factory.create(tls_replication=None, **key_args)
server.start()
server.stop()
with df_factory.create(tls_replication=None, **key_args):
pass