Skip to content

Commit

Permalink
ConfigSecretStr mask lengths match value lengths (microsoft#141)
Browse files Browse the repository at this point in the history
  • Loading branch information
markwaddle authored Oct 17, 2024
1 parent 2c6f8d7 commit 4c75ba1
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 12 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,8 @@ def _config_secret_str_serialization_mode_from_context(
)


_MASKED_VALUE = "*" * 10
def _mask(value: str) -> str:
return "*" * len(value)


def _config_secret_str_json_serializer(value: str, info: SerializationInfo) -> str:
Expand All @@ -195,7 +196,7 @@ def _config_secret_str_json_serializer(value: str, info: SerializationInfo) -> s
return value

case ConfigSecretStrJsonSerializationMode.serialize_masked_value:
return _MASKED_VALUE
return _mask(value)


def replace_config_secret_str_masked_values(model_values: ModelT, original_model_values: ModelT) -> ModelT:
Expand All @@ -211,7 +212,7 @@ def replace_config_secret_str_masked_values(model_values: ModelT, original_model
continue

if field_info.annotation is ConfigSecretStr:
if getattr(updated_model_values, field_name) == _MASKED_VALUE:
if getattr(updated_model_values, field_name) == _mask(getattr(original_model_values, field_name)):
setattr(updated_model_values, field_name, getattr(original_model_values, field_name))
continue

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -378,7 +378,7 @@ class TestConfigModel(BaseModel):

response = await instance_client.get_config()
assert response == assistant_model.ConfigResponseModel(
config={"test_key": "test_value", "secret_field": "**********"},
config={"test_key": "test_value", "secret_field": len("secret_default") * "*"},
errors=[],
json_schema=TestConfigModel.model_json_schema(),
ui_schema=expected_ui_schema,
Expand All @@ -391,7 +391,7 @@ class TestConfigModel(BaseModel):
assistant_model.ConfigPutRequestModel(config={"test_key": "new_value", "secret_field": "new_secret"})
)
assert response == assistant_model.ConfigResponseModel(
config={"test_key": "new_value", "secret_field": "**********"},
config={"test_key": "new_value", "secret_field": len("new_secret") * "*"},
errors=[],
json_schema=TestConfigModel.model_json_schema(),
ui_schema=expected_ui_schema,
Expand All @@ -406,7 +406,7 @@ class TestConfigModel(BaseModel):

response = await instance_client.get_config()
assert response == assistant_model.ConfigResponseModel(
config={"test_key": "new_value", "secret_field": "**********"},
config={"test_key": "new_value", "secret_field": len("new_secret") * "*"},
errors=[],
json_schema=TestConfigModel.model_json_schema(),
ui_schema=expected_ui_schema,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,19 +26,19 @@
("dict_python", ConfigSecretStrJsonSerializationMode.serialize_value, "super-secret", "super-secret"),
("dict_python", ConfigSecretStrJsonSerializationMode.serialize_value, "", ""),
# json serialization should return the expected value based on the serialization mode
("dict_json", None, "super-secret", "**********"),
("dict_json", None, "super-secret", "************"),
("dict_json", None, "", ""),
("dict_json", ConfigSecretStrJsonSerializationMode.serialize_as_empty, "super-secret", ""),
("dict_json", ConfigSecretStrJsonSerializationMode.serialize_as_empty, "", ""),
("dict_json", ConfigSecretStrJsonSerializationMode.serialize_masked_value, "super-secret", "**********"),
("dict_json", ConfigSecretStrJsonSerializationMode.serialize_masked_value, "super-secret", "************"),
("dict_json", ConfigSecretStrJsonSerializationMode.serialize_masked_value, "", ""),
("dict_json", ConfigSecretStrJsonSerializationMode.serialize_value, "super-secret", "super-secret"),
("dict_json", ConfigSecretStrJsonSerializationMode.serialize_value, "", ""),
("str_json", None, "super-secret", "**********"),
("str_json", None, "super-secret", "************"),
("str_json", None, "", ""),
("str_json", ConfigSecretStrJsonSerializationMode.serialize_as_empty, "super-secret", ""),
("str_json", ConfigSecretStrJsonSerializationMode.serialize_as_empty, "", ""),
("str_json", ConfigSecretStrJsonSerializationMode.serialize_masked_value, "super-secret", "**********"),
("str_json", ConfigSecretStrJsonSerializationMode.serialize_masked_value, "super-secret", "************"),
("str_json", ConfigSecretStrJsonSerializationMode.serialize_masked_value, "", ""),
("str_json", ConfigSecretStrJsonSerializationMode.serialize_value, "super-secret", "super-secret"),
("str_json", ConfigSecretStrJsonSerializationMode.serialize_value, "", ""),
Expand Down Expand Up @@ -92,8 +92,8 @@ class TestModel(BaseModel):

serialized_config = model.model_dump(mode="json")

assert serialized_config["secret"] == "**********"
assert serialized_config["sub_model"]["secret"] == "**********"
assert serialized_config["secret"] == "*" * len(secret_value)
assert serialized_config["sub_model"]["secret"] == "*" * len(secret_value)

deserialized_config = TestModel.model_validate(serialized_config)

Expand Down

0 comments on commit 4c75ba1

Please sign in to comment.