Skip to content

Commit

Permalink
feat(engine): Improve missing secrets visibility in AuthSandbox (#217)
Browse files Browse the repository at this point in the history
  • Loading branch information
daryllimyt authored Jun 29, 2024
1 parent e2ba8b1 commit f2d9acf
Show file tree
Hide file tree
Showing 3 changed files with 71 additions and 25 deletions.
12 changes: 9 additions & 3 deletions tests/unit/test_sandbox.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,15 @@ async def test_auth_sandbox_with_secrets(mocker: pytest_mock.MockFixture, auth_s
mock_secret = Secret(name="my_secret", owner_id=role.user_id)
mock_secret.keys = mock_secret_keys

# Mock httpx.Response
mock_response = mocker.Mock(spec=httpx.Response)
mock_response.raise_for_status = mocker.Mock()
mock_response.content = mock_secret.model_dump_json().encode()

mock_client = mocker.AsyncMock()
mock_client.get.return_value = httpx.Response(
200, content=mock_secret.model_dump_json().encode()
)
mock_client.__aenter__.return_value = mock_client
mock_client.__aexit__.return_value = None
mock_client.get.return_value = mock_response

# Patch the AuthenticatedAPIClient to return the mock client
mocker.patch(
Expand All @@ -36,6 +39,9 @@ async def test_auth_sandbox_with_secrets(mocker: pytest_mock.MockFixture, auth_s
# Assert that the secrets API was called with the correct parameters
mock_client.get.assert_called_once_with("/secrets/my_secret")

# Assert that raise_for_status was called
mock_response.raise_for_status.assert_called_once()


@pytest.mark.asyncio
async def test_auth_sandbox_without_secrets(auth_sandbox):
Expand Down
67 changes: 45 additions & 22 deletions tracecat/auth/sandbox.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,13 @@
from loguru import logger

from tracecat.auth.clients import AuthenticatedAPIClient
from tracecat.concurrency import GatheringTaskGroup
from tracecat.contexts import ctx_role
from tracecat.types.auth import Role
from tracecat.types.exceptions import TracecatCredentialsError

if TYPE_CHECKING:
from tracecat.db.schemas import Secret
from tracecat.types.auth import Role


class AuthSandbox:
Expand All @@ -32,7 +34,7 @@ def __init__(
target: Literal["env", "context"] = "env",
):
self._role = role or ctx_role.get()
self._secret_paths: list[str] = secrets
self._secret_paths: list[str] = secrets or []
self._secret_objs: list[Secret] = []
self._target = target
self._context = {}
Expand Down Expand Up @@ -62,6 +64,12 @@ def secrets(self) -> dict[str, str]:
"""Return secret names mapped to their secret key value pairs."""
return self._context

def _iter_secrets(self):
"""Iterate over the secrets."""
for secret in self._secret_objs:
for kv in secret.keys or []:
yield secret.name, kv

def _set_secrets(self):
"""Set secrets in the target."""
if self._target == "context":
Expand All @@ -70,26 +78,22 @@ def _set_secrets(self):
paths=self._secret_paths,
objs=self._secret_objs,
)
for secret in self._secret_objs:
self._context[secret.name] = {
kv.key: kv.value.get_secret_value() for kv in secret.keys
}
for name, kv in self._iter_secrets():
self._context[name] = {kv.key: kv.value.get_secret_value()}
else:
logger.info("Setting secrets in the environment", paths=self._secret_paths)
for secret in self._secret_objs:
for kv in secret.keys:
os.environ[kv.key] = kv.value.get_secret_value()
for _, kv in self._iter_secrets():
os.environ[kv.key] = kv.value.get_secret_value()

def _unset_secrets(self):
if self._target == "context":
for secret in self._secret_objs:
if secret.name in self._context:
del self._context[secret.name]
else:
for secret in self._secret_objs:
for kv in secret.keys:
if kv.key in os.environ:
del os.environ[kv.key]
for _, kv in self._iter_secrets():
if kv.key in os.environ:
del os.environ[kv.key]

async def _get_secrets(self) -> list[Secret]:
"""Retrieve secrets from the secrets API."""
Expand All @@ -104,12 +108,31 @@ async def _get_secrets(self) -> list[Secret]:
)
secret_names = (path.split(".")[0] for path in self._secret_paths)

async with AuthenticatedAPIClient(role=self._role) as client:
# NOTE(perf): This is not really batched - room for improvement
secret_responses = await asyncio.gather(
*[client.get(f"/secrets/{secret_name}") for secret_name in secret_names]
)
return [
Secret.model_validate_json(secret_bytes.content)
for secret_bytes in secret_responses
]
try:
async with (
AuthenticatedAPIClient(role=self._role) as client,
GatheringTaskGroup() as tg,
):
for secret_name in secret_names:

async def fetcher(name: str):
try:
res = await client.get(f"/secrets/{name}")
res.raise_for_status() # Raise an exception for HTTP error codes
return res
except Exception as e:
msg = (
f"Failed to retrieve secret {name!r}."
f" Please ensure you have set all required secrets: {self._secret_paths}"
)
logger.error(msg)
raise TracecatCredentialsError(msg) from e

tg.create_task(fetcher(secret_name))
except* TracecatCredentialsError as eg:
raise eg

return [
Secret.model_validate_json(secret_bytes.content)
for secret_bytes in tg.results()
]
17 changes: 17 additions & 0 deletions tracecat/concurrency.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
import asyncio


class GatheringTaskGroup[T](asyncio.TaskGroup):
"""Convenience class to gather results from tasks in a task group."""

def __init__(self):
super().__init__()
self.__tasks: list[asyncio.Task[T]] = []

def create_task(self, coro, *, name=None, context=None) -> asyncio.Task[T]:
task = super().create_task(coro, name=name, context=context)
self.__tasks.append(task)
return task

def results(self) -> list[T]:
return [task.result() for task in self.__tasks]

0 comments on commit f2d9acf

Please sign in to comment.