diff --git a/README.md b/README.md
index 3ca50ef5..23de080d 100644
--- a/README.md
+++ b/README.md
@@ -401,6 +401,23 @@ prompts:
     {reminders}
 ```

+### Customize the moderation levels
+
+A moderation level is defined for each category of Content Safety. The score is the maximum severity accepted for a category, from 0 to 7: 1 is the strictest, 7 the most permissive, and 0 disables the check for that category.
+
+Moderation is applied to all bot data, including the web page and the conversation.
+
+```yaml
+# config.yaml
+[...]
+
+content_safety:
+  category_hate_score: 0
+  category_self_harm_score: 0
+  category_sexual_score: 5
+  category_violence_score: 0
+```
+
 ### Customize the claim data schema

 Customization of the data schema is not supported yet through the configuration file. However, you can customize the data schema by modifying the application source code.
diff --git a/helpers/config_models/content_safety.py b/helpers/config_models/content_safety.py
index b90d626d..fd6489d2 100644
--- a/helpers/config_models/content_safety.py
+++ b/helpers/config_models/content_safety.py
@@ -1,9 +1,13 @@
 from typing import List
-from pydantic import SecretStr
+from pydantic import SecretStr, Field
 from pydantic_settings import BaseSettings


 class ContentSafetyModel(BaseSettings):
     access_key: SecretStr
     blocklists: List[str]
+    category_hate_score: int = Field(default=0, ge=0, le=7)
+    category_self_harm_score: int = Field(default=1, ge=0, le=7)
+    category_sexual_score: int = Field(default=2, ge=0, le=7)
+    category_violence_score: int = Field(default=0, ge=0, le=7)
     endpoint: str
diff --git a/helpers/llm.py b/helpers/llm.py
index b2e9f1eb..20b070bc 100644
--- a/helpers/llm.py
+++ b/helpers/llm.py
@@ -63,7 +63,14 @@


 class SafetyCheckError(Exception):
-    pass
+    message: str
+
+    def __init__(self, message: str) -> None:
+        self.message = message
+        super().__init__(message)
+
+    def __str__(self) -> str:
+        return self.message


 @retry(
@@ -139,8 +146,7 @@ async def completion_sync(
     content = res.choices[0].message.content

     if not json_output:
-        if not await safety_check(content):
-            raise SafetyCheckError()
+        await safety_check(content)

     return content
@@ -165,23 +171,23 @@ async def completion_model_sync(
     return model.model_validate_json(res)


-async def safety_check(text: str) -> bool:
+async def safety_check(text: str) -> None:
     """
-    Returns `True` if the text is safe, `False` otherwise.
+    Raise `SafetyCheckError` if the text is unsafe, return nothing otherwise.

     Text can be returned both safe and censored, before containing unsafe content.
""" if not text: - return True + return try: res = await _contentsafety_analysis(text) except HttpResponseError as e: _logger.error(f"Failed to run safety check: {e.message}") - return True # Assume safe + return # Assume safe if not res: _logger.error("Failed to run safety check: No result") - return True # Assume safe + return # Assume safe for match in res.blocklists_match or []: _logger.debug(f"Matched blocklist item: {match.blocklist_item_text}") @@ -190,22 +196,33 @@ async def safety_check(text: str) -> bool: ) hate_result = _contentsafety_category_test( - res.categories_analysis, TextCategory.HATE + res.categories_analysis, + TextCategory.HATE, + CONFIG.content_safety.category_hate_score, ) self_harm_result = _contentsafety_category_test( - res.categories_analysis, TextCategory.SELF_HARM + res.categories_analysis, + TextCategory.SELF_HARM, + CONFIG.content_safety.category_self_harm_score, ) sexual_result = _contentsafety_category_test( - res.categories_analysis, TextCategory.SEXUAL + res.categories_analysis, + TextCategory.SEXUAL, + CONFIG.content_safety.category_sexual_score, ) violence_result = _contentsafety_category_test( - res.categories_analysis, TextCategory.VIOLENCE + res.categories_analysis, + TextCategory.VIOLENCE, + CONFIG.content_safety.category_violence_score, ) safety = hate_result and self_harm_result and sexual_result and violence_result _logger.debug(f'Text safety "{safety}" for text: {text}') - return safety + if not safety: + raise SafetyCheckError( + f"Unsafe content detected, hate={hate_result}, self_harm={self_harm_result}, sexual={sexual_result}, violence={violence_result}" + ) async def close() -> None: @@ -224,21 +241,28 @@ async def close() -> None: async def _contentsafety_analysis(text: str) -> AnalyzeTextResult: return await _contentsafety.analyze_text( AnalyzeTextOptions( - text=text, blocklist_names=CONFIG.content_safety.blocklists, halt_on_blocklist_hit=False, + output_type="EightSeverityLevels", + text=text, ) ) def _contentsafety_category_test( - res: List[TextCategoriesAnalysis], category: TextCategory + res: List[TextCategoriesAnalysis], + category: TextCategory, + score: int, ) -> bool: """ Returns `True` if the category is safe or the severity is low, `False` otherwise, meaning the category is unsafe. """ + if score == 0: + return True # No need to check severity + detection = next(item for item in res if item.category == category) - if detection and detection.severity and detection.severity > 2: + + if detection and detection.severity and detection.severity > score: _logger.debug(f"Matched {category} with severity {detection.severity}") return False return True diff --git a/main.py b/main.py index 5c32f08b..bcccca6f 100644 --- a/main.py +++ b/main.py @@ -440,8 +440,10 @@ async def intelligence( async def tts_callback(text: str, style: MessageStyle) -> None: nonlocal has_started - if not await safety_check(text): - _logger.warn(f"Unsafe text detected, not playing ({call.call_id})") + try: + await safety_check(text) + except SafetyCheckError as e: + _logger.warn(f"Unsafe text detected, not playing ({call.call_id}): {e}") return has_started = True @@ -626,8 +628,8 @@ async def llm_completion(system: Optional[str], call: CallModel) -> Optional[str ) except APIError: _logger.warn(f"OpenAI API call error", exc_info=True) - except SafetyCheckError: - _logger.warn(f"Safety check error", exc_info=True) + except SafetyCheckError as e: + _logger.warn(f"OpenAI safety check error: {e}") return content