diff --git a/sdk/textanalytics/azure-ai-textanalytics/CHANGELOG.md b/sdk/textanalytics/azure-ai-textanalytics/CHANGELOG.md index fa4e485f5fce..f8d6c7dd0f8c 100644 --- a/sdk/textanalytics/azure-ai-textanalytics/CHANGELOG.md +++ b/sdk/textanalytics/azure-ai-textanalytics/CHANGELOG.md @@ -9,6 +9,9 @@ - Removed `AnalyzeActionsResult` - Removed `AnalyzeActionsError` +**New Features** +- Added `catagories_filter` to `RecognizePiiEntitiesAction` + ## 5.1.0b7 (2021-05-18) **Breaking Changes** diff --git a/sdk/textanalytics/azure-ai-textanalytics/azure/ai/textanalytics/_models.py b/sdk/textanalytics/azure-ai-textanalytics/azure/ai/textanalytics/_models.py index 16caace73c78..3e5955d32f40 100644 --- a/sdk/textanalytics/azure-ai-textanalytics/azure/ai/textanalytics/_models.py +++ b/sdk/textanalytics/azure-ai-textanalytics/azure/ai/textanalytics/_models.py @@ -11,7 +11,7 @@ ) from ._generated.v3_0 import models as _v3_0_models -from ._generated.v3_1 import models as _latest_preview_models +from ._generated.v3_1 import models as _v3_1_models def _get_indices(relation): return [int(s) for s in re.findall(r"\d+", relation)] @@ -1410,8 +1410,8 @@ def __repr__(self, **kwargs): .format(self.model_version, self.string_index_type, self.disable_service_logs)[:1024] def to_generated(self): - return _latest_preview_models.EntitiesTask( - parameters=_latest_preview_models.EntitiesTaskParameters( + return _v3_1_models.EntitiesTask( + parameters=_v3_1_models.EntitiesTaskParameters( model_version=self.model_version, string_index_type=self.string_index_type, logging_opt_out=self.disable_service_logs, @@ -1481,8 +1481,8 @@ def __repr__(self, **kwargs): )[:1024] def to_generated(self): - return _latest_preview_models.SentimentAnalysisTask( - parameters=_latest_preview_models.SentimentAnalysisTaskParameters( + return _v3_1_models.SentimentAnalysisTask( + parameters=_v3_1_models.SentimentAnalysisTaskParameters( model_version=self.model_version, opinion_mining=self.show_opinion_mining, string_index_type=self.string_index_type, @@ -1502,6 +1502,11 @@ class RecognizePiiEntitiesAction(DictMixin): :keyword str model_version: The model version to use for the analysis. :keyword str domain_filter: An optional string to set the PII domain to include only a subset of the PII entity categories. Possible values include 'phi' or None. + :keyword categories_filter: Instead of filtering over all PII entity categories, you can pass in a list of + the specific PII entity categories you want to filter out. For example, if you only want to filter out + U.S. social security numbers in a document, you can pass in + `[PiiEntityCategoryType.US_SOCIAL_SECURITY_NUMBER]` for this kwarg. + :paramtype categories_filter: list[~azure.ai.textanalytics.PiiEntityCategoryType] :keyword str string_index_type: Specifies the method used to interpret string offsets. `UnicodeCodePoint`, the Python encoding, is the default. To override the Python default, you can also pass in `Utf16CodePoint` or TextElement_v8`. For additional information @@ -1517,6 +1522,11 @@ class RecognizePiiEntitiesAction(DictMixin): :ivar str model_version: The model version to use for the analysis. :ivar str domain_filter: An optional string to set the PII domain to include only a subset of the PII entity categories. Possible values include 'phi' or None. + :ivar categories_filter: Instead of filtering over all PII entity categories, you can pass in a list of + the specific PII entity categories you want to filter out. For example, if you only want to filter out + U.S. social security numbers in a document, you can pass in + `[PiiEntityCategoryType.US_SOCIAL_SECURITY_NUMBER]` for this kwarg. + :vartype categories_filter: list[~azure.ai.textanalytics.PiiEntityCategoryType] :ivar str string_index_type: Specifies the method used to interpret string offsets. `UnicodeCodePoint`, the Python encoding, is the default. To override the Python default, you can also pass in `Utf16CodePoint` or TextElement_v8`. For additional information @@ -1534,23 +1544,26 @@ class RecognizePiiEntitiesAction(DictMixin): def __init__(self, **kwargs): self.model_version = kwargs.get("model_version", "latest") self.domain_filter = kwargs.get("domain_filter", None) + self.categories_filter = kwargs.get("categories_filter", None) self.string_index_type = kwargs.get("string_index_type", "UnicodeCodePoint") self.disable_service_logs = kwargs.get("disable_service_logs", False) def __repr__(self, **kwargs): - return "RecognizePiiEntitiesAction(model_version={}, domain_filter={}, string_index_type={}, "\ - "disable_service_logs={}".format( + return "RecognizePiiEntitiesAction(model_version={}, domain_filter={}, categories_filter={}, "\ + "string_index_type={}, disable_service_logs={}".format( self.model_version, self.domain_filter, + self.categories_filter, self.string_index_type, self.disable_service_logs, )[:1024] def to_generated(self): - return _latest_preview_models.PiiTask( - parameters=_latest_preview_models.PiiTaskParameters( + return _v3_1_models.PiiTask( + parameters=_v3_1_models.PiiTaskParameters( model_version=self.model_version, domain=self.domain_filter, + pii_categories=self.categories_filter, string_index_type=self.string_index_type, logging_opt_out=self.disable_service_logs ) @@ -1594,8 +1607,8 @@ def __repr__(self, **kwargs): .format(self.model_version, self.disable_service_logs)[:1024] def to_generated(self): - return _latest_preview_models.KeyPhrasesTask( - parameters=_latest_preview_models.KeyPhrasesTaskParameters( + return _v3_1_models.KeyPhrasesTask( + parameters=_v3_1_models.KeyPhrasesTaskParameters( model_version=self.model_version, logging_opt_out=self.disable_service_logs, ) @@ -1650,8 +1663,8 @@ def __repr__(self, **kwargs): )[:1024] def to_generated(self): - return _latest_preview_models.EntityLinkingTask( - parameters=_latest_preview_models.EntityLinkingTaskParameters( + return _v3_1_models.EntityLinkingTask( + parameters=_v3_1_models.EntityLinkingTaskParameters( model_version=self.model_version, string_index_type=self.string_index_type, logging_opt_out=self.disable_service_logs, diff --git a/sdk/textanalytics/azure-ai-textanalytics/tests/test_analyze.py b/sdk/textanalytics/azure-ai-textanalytics/tests/test_analyze.py index b36d16dbf390..b6f0a0071cf8 100644 --- a/sdk/textanalytics/azure-ai-textanalytics/tests/test_analyze.py +++ b/sdk/textanalytics/azure-ai-textanalytics/tests/test_analyze.py @@ -33,6 +33,7 @@ RecognizeLinkedEntitiesResult, RecognizeEntitiesResult, RecognizePiiEntitiesResult, + PiiEntityCategoryType ) # pre-apply the client_cls positional argument so it needn't be explicitly passed below @@ -679,3 +680,31 @@ def callback(resp): polling_interval=self._interval(), raw_response_hook=callback, ).result() + + @GlobalTextAnalyticsAccountPreparer() + @TextAnalyticsClientPreparer() + def test_pii_action_categories_filter(self, client): + + docs = [{"id": "1", "text": "My SSN is 859-98-0987."}, + {"id": "2", + "text": "Your ABA number - 111000025 - is the first 9 digits in the lower left hand corner of your personal check."}, + {"id": "3", "text": "Is 998.214.865-68 your Brazilian CPF number?"}] + + actions = [ + RecognizePiiEntitiesAction( + categories_filter=[ + PiiEntityCategoryType.US_SOCIAL_SECURITY_NUMBER, + PiiEntityCategoryType.ABA_ROUTING_NUMBER, + ] + ), + ] + + result = client.begin_analyze_actions(documents=docs, actions=actions, polling_interval=self._interval()).result() + action_results = list(result) + assert len(action_results) == 3 + + assert action_results[0][0].entities[0].text == "859-98-0987" + assert action_results[0][0].entities[0].category == PiiEntityCategoryType.US_SOCIAL_SECURITY_NUMBER + assert action_results[1][0].entities[0].text == "111000025" + assert action_results[1][0].entities[0].category == PiiEntityCategoryType.ABA_ROUTING_NUMBER + assert action_results[2][0].entities == [] # No Brazilian CPF since not in categories_filter diff --git a/sdk/textanalytics/azure-ai-textanalytics/tests/test_analyze_async.py b/sdk/textanalytics/azure-ai-textanalytics/tests/test_analyze_async.py index 4bb86432e0f0..eba9d69ea505 100644 --- a/sdk/textanalytics/azure-ai-textanalytics/tests/test_analyze_async.py +++ b/sdk/textanalytics/azure-ai-textanalytics/tests/test_analyze_async.py @@ -34,6 +34,7 @@ RecognizeLinkedEntitiesResult, AnalyzeSentimentResult, ExtractKeyPhrasesResult, + PiiEntityCategoryType ) # pre-apply the client_cls positional argument so it needn't be explicitly passed below @@ -727,3 +728,34 @@ async def test_disable_service_logs(self, client): actions=actions, polling_interval=self._interval(), )).result() + + @GlobalTextAnalyticsAccountPreparer() + @TextAnalyticsClientPreparer() + async def test_pii_action_categories_filter(self, client): + + docs = [{"id": "1", "text": "My SSN is 859-98-0987."}, + {"id": "2", + "text": "Your ABA number - 111000025 - is the first 9 digits in the lower left hand corner of your personal check."}, + {"id": "3", "text": "Is 998.214.865-68 your Brazilian CPF number?"}] + + actions = [ + RecognizePiiEntitiesAction( + categories_filter=[ + PiiEntityCategoryType.US_SOCIAL_SECURITY_NUMBER, + PiiEntityCategoryType.ABA_ROUTING_NUMBER + ] + ), + ] + async with client: + result = await (await client.begin_analyze_actions(documents=docs, actions=actions, polling_interval=self._interval())).result() + action_results = [] + async for p in result: + action_results.append(p) + + assert len(action_results) == 3 + + assert action_results[0][0].entities[0].text == "859-98-0987" + assert action_results[0][0].entities[0].category == PiiEntityCategoryType.US_SOCIAL_SECURITY_NUMBER + assert action_results[1][0].entities[0].text == "111000025" + assert action_results[1][0].entities[0].category == PiiEntityCategoryType.ABA_ROUTING_NUMBER + assert action_results[2][0].entities == [] # No Brazilian CPF since not in categories_filter