From 536f3fd6145067f9933d144f8b1176416d625cb8 Mon Sep 17 00:00:00 2001 From: Erik Bernhardsson Date: Tue, 13 Feb 2024 08:56:24 -0500 Subject: [PATCH] Support --force for overwriting existing secrets --- modal/cli/secret.py | 5 +++-- modal/secret.py | 7 ++++++- modal_proto/api.proto | 1 + test/cli_test.py | 6 ++++++ test/conftest.py | 14 +++++++++++--- 5 files changed, 27 insertions(+), 6 deletions(-) diff --git a/modal/cli/secret.py b/modal/cli/secret.py index 2385f2b9f..246142df4 100644 --- a/modal/cli/secret.py +++ b/modal/cli/secret.py @@ -43,12 +43,13 @@ async def list(env: Optional[str] = ENV_OPTION, json: Optional[bool] = False): display_table(column_names, rows, json, title=f"Secrets{env_part}") -@secret_cli.command("create", help="Create a new secret") +@secret_cli.command("create", help="Create a new secret. Use `--force` to overwrite any existing one.") @synchronizer.create_blocking async def create( secret_name, keyvalues: List[str] = typer.Argument(..., help="Space-separated KEY=VALUE items"), env: Optional[str] = ENV_OPTION, + force: bool = typer.Option(False, "--force"), ): env = ensure_env(env) env_dict = {} @@ -72,7 +73,7 @@ async def create( raise click.UsageError("You need to specify at least one key for your secret") # Create secret - await _Secret.create_deployed(secret_name, env_dict) + await _Secret.create_deployed(secret_name, env_dict, overwrite=force) # Print code sample console = Console() diff --git a/modal/secret.py b/modal/secret.py index c72ba1ed4..7aebc6600 100644 --- a/modal/secret.py +++ b/modal/secret.py @@ -190,14 +190,19 @@ async def create_deployed( namespace=api_pb2.DEPLOYMENT_NAMESPACE_WORKSPACE, client: Optional[_Client] = None, environment_name: Optional[str] = None, + overwrite: bool = False, ) -> str: if client is None: client = await _Client.from_env() + if overwrite: + object_creation_type = api_pb2.OBJECT_CREATION_TYPE_CREATE_OVERWRITE_IF_EXISTS + else: + object_creation_type = api_pb2.OBJECT_CREATION_TYPE_CREATE_FAIL_IF_EXISTS request = api_pb2.SecretGetOrCreateRequest( deployment_name=deployment_name, namespace=namespace, environment_name=_get_environment_name(environment_name), - object_creation_type=api_pb2.OBJECT_CREATION_TYPE_CREATE_FAIL_IF_EXISTS, + object_creation_type=object_creation_type, env_dict=env_dict, ) resp = await retry_transient_errors(client.stub.SecretGetOrCreate, request) diff --git a/modal_proto/api.proto b/modal_proto/api.proto index 6d356d6d0..899ca8f66 100644 --- a/modal_proto/api.proto +++ b/modal_proto/api.proto @@ -142,6 +142,7 @@ enum ObjectCreationType { OBJECT_CREATION_TYPE_UNSPECIFIED = 0; // just lookup OBJECT_CREATION_TYPE_CREATE_IF_MISSING = 1; OBJECT_CREATION_TYPE_CREATE_FAIL_IF_EXISTS = 2; + OBJECT_CREATION_TYPE_CREATE_OVERWRITE_IF_EXISTS = 3; } enum ProgressType { diff --git a/test/cli_test.py b/test/cli_test.py index 5f4411737..7d3e76b27 100644 --- a/test/cli_test.py +++ b/test/cli_test.py @@ -96,6 +96,12 @@ def test_secret_create(servicer, set_env_client): _run(["secret", "create", "foo", "bar=baz"]) assert len(servicer.secrets) == 1 + # Creating the same one again should fail + _run(["secret", "create", "foo", "bar=baz"], expected_exit_code=1) + + # But it should succeed with --force + _run(["secret", "create", "foo", "bar=baz", "--force"]) + def test_secret_list(servicer, set_env_client): res = _run(["secret", "list"]) diff --git a/test/conftest.py b/test/conftest.py index c6d73c70f..f159ff783 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -826,15 +826,23 @@ async def SecretGetOrCreate(self, stream): request: api_pb2.SecretGetOrCreateRequest = await stream.recv_message() k = (request.deployment_name, request.namespace, request.environment_name) if request.object_creation_type == api_pb2.OBJECT_CREATION_TYPE_CREATE_FAIL_IF_EXISTS: - secret_id = "st-" + str(len(self.secrets)) - self.secrets[secret_id] = request.env_dict - self.deployed_secrets[k] = secret_id + if k in self.deployed_secrets: + raise GRPCError(Status.ALREADY_EXISTS, "Already exists") + secret_id = None + elif request.object_creation_type == api_pb2.OBJECT_CREATION_TYPE_CREATE_OVERWRITE_IF_EXISTS: + secret_id = None elif request.object_creation_type == api_pb2.OBJECT_CREATION_TYPE_UNSPECIFIED: if k not in self.deployed_secrets: raise GRPCError(Status.NOT_FOUND, "No such secret") secret_id = self.deployed_secrets[k] else: raise Exception("unsupported creation type") + + if secret_id is None: # Create one + secret_id = "st-" + str(len(self.secrets)) + self.secrets[secret_id] = request.env_dict + self.deployed_secrets[k] = secret_id + await stream.send_message(api_pb2.SecretGetOrCreateResponse(secret_id=secret_id)) async def SecretList(self, stream):