Skip to content

Commit

Permalink
[textanalytics] add categories filter to RecognizePiiEntitiesAction (#…
Browse files Browse the repository at this point in the history
…19223)

* add categories_filter to RecognizePiiEntitiesAction

* add tests for categories_filter

* fix tests and docstring

* update changelog

* updating categories filter tests to new analyze design
  • Loading branch information
kristapratico authored Jun 15, 2021
1 parent 6fcb85c commit ec064b8
Show file tree
Hide file tree
Showing 4 changed files with 90 additions and 13 deletions.
3 changes: 3 additions & 0 deletions sdk/textanalytics/azure-ai-textanalytics/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@
- Removed `AnalyzeActionsResult`
- Removed `AnalyzeActionsError`

**New Features**
- Added `categories_filter` to `RecognizePiiEntitiesAction`

## 5.1.0b7 (2021-05-18)

**Breaking Changes**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
)

from ._generated.v3_0 import models as _v3_0_models
from ._generated.v3_1 import models as _latest_preview_models
from ._generated.v3_1 import models as _v3_1_models

# Return every run of decimal digits found in `relation` as a list of ints,
# in the order the digit runs appear in the string.
def _get_indices(relation):
return [int(s) for s in re.findall(r"\d+", relation)]
Expand Down Expand Up @@ -1410,8 +1410,8 @@ def __repr__(self, **kwargs):
.format(self.model_version, self.string_index_type, self.disable_service_logs)[:1024]

def to_generated(self):
return _latest_preview_models.EntitiesTask(
parameters=_latest_preview_models.EntitiesTaskParameters(
return _v3_1_models.EntitiesTask(
parameters=_v3_1_models.EntitiesTaskParameters(
model_version=self.model_version,
string_index_type=self.string_index_type,
logging_opt_out=self.disable_service_logs,
Expand Down Expand Up @@ -1481,8 +1481,8 @@ def __repr__(self, **kwargs):
)[:1024]

def to_generated(self):
return _latest_preview_models.SentimentAnalysisTask(
parameters=_latest_preview_models.SentimentAnalysisTaskParameters(
return _v3_1_models.SentimentAnalysisTask(
parameters=_v3_1_models.SentimentAnalysisTaskParameters(
model_version=self.model_version,
opinion_mining=self.show_opinion_mining,
string_index_type=self.string_index_type,
Expand All @@ -1502,6 +1502,11 @@ class RecognizePiiEntitiesAction(DictMixin):
:keyword str model_version: The model version to use for the analysis.
:keyword str domain_filter: An optional string to set the PII domain to include only a
subset of the PII entity categories. Possible values include 'phi' or None.
:keyword categories_filter: Instead of recognizing all PII entity categories, you can pass in a list of
the specific PII entity categories you want returned. For example, if you only want
U.S. social security numbers to be recognized in a document, you can pass in
`[PiiEntityCategoryType.US_SOCIAL_SECURITY_NUMBER]` for this kwarg.
:paramtype categories_filter: list[~azure.ai.textanalytics.PiiEntityCategoryType]
:keyword str string_index_type: Specifies the method used to interpret string offsets.
`UnicodeCodePoint`, the Python encoding, is the default. To override the Python default,
you can also pass in `Utf16CodePoint` or `TextElement_v8`. For additional information
Expand All @@ -1517,6 +1522,11 @@ class RecognizePiiEntitiesAction(DictMixin):
:ivar str model_version: The model version to use for the analysis.
:ivar str domain_filter: An optional string to set the PII domain to include only a
subset of the PII entity categories. Possible values include 'phi' or None.
:ivar categories_filter: Instead of recognizing all PII entity categories, you can pass in a list of
the specific PII entity categories you want returned. For example, if you only want
U.S. social security numbers to be recognized in a document, you can pass in
`[PiiEntityCategoryType.US_SOCIAL_SECURITY_NUMBER]` for this kwarg.
:vartype categories_filter: list[~azure.ai.textanalytics.PiiEntityCategoryType]
:ivar str string_index_type: Specifies the method used to interpret string offsets.
`UnicodeCodePoint`, the Python encoding, is the default. To override the Python default,
you can also pass in `Utf16CodePoint` or `TextElement_v8`. For additional information
Expand All @@ -1534,23 +1544,26 @@ class RecognizePiiEntitiesAction(DictMixin):
# Build the action from optional keyword arguments; any option the caller
# omits falls back to the default given in the matching kwargs.get call.
def __init__(self, **kwargs):
self.model_version = kwargs.get("model_version", "latest")
self.domain_filter = kwargs.get("domain_filter", None)
# Optional list of PII entity categories; stays None when not supplied.
self.categories_filter = kwargs.get("categories_filter", None)
self.string_index_type = kwargs.get("string_index_type", "UnicodeCodePoint")
self.disable_service_logs = kwargs.get("disable_service_logs", False)

def __repr__(self, **kwargs):
return "RecognizePiiEntitiesAction(model_version={}, domain_filter={}, string_index_type={}, "\
"disable_service_logs={}".format(
return "RecognizePiiEntitiesAction(model_version={}, domain_filter={}, categories_filter={}, "\
"string_index_type={}, disable_service_logs={}".format(
self.model_version,
self.domain_filter,
self.categories_filter,
self.string_index_type,
self.disable_service_logs,
)[:1024]

def to_generated(self):
return _latest_preview_models.PiiTask(
parameters=_latest_preview_models.PiiTaskParameters(
return _v3_1_models.PiiTask(
parameters=_v3_1_models.PiiTaskParameters(
model_version=self.model_version,
domain=self.domain_filter,
pii_categories=self.categories_filter,
string_index_type=self.string_index_type,
logging_opt_out=self.disable_service_logs
)
Expand Down Expand Up @@ -1594,8 +1607,8 @@ def __repr__(self, **kwargs):
.format(self.model_version, self.disable_service_logs)[:1024]

def to_generated(self):
return _latest_preview_models.KeyPhrasesTask(
parameters=_latest_preview_models.KeyPhrasesTaskParameters(
return _v3_1_models.KeyPhrasesTask(
parameters=_v3_1_models.KeyPhrasesTaskParameters(
model_version=self.model_version,
logging_opt_out=self.disable_service_logs,
)
Expand Down Expand Up @@ -1650,8 +1663,8 @@ def __repr__(self, **kwargs):
)[:1024]

def to_generated(self):
return _latest_preview_models.EntityLinkingTask(
parameters=_latest_preview_models.EntityLinkingTaskParameters(
return _v3_1_models.EntityLinkingTask(
parameters=_v3_1_models.EntityLinkingTaskParameters(
model_version=self.model_version,
string_index_type=self.string_index_type,
logging_opt_out=self.disable_service_logs,
Expand Down
29 changes: 29 additions & 0 deletions sdk/textanalytics/azure-ai-textanalytics/tests/test_analyze.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
RecognizeLinkedEntitiesResult,
RecognizeEntitiesResult,
RecognizePiiEntitiesResult,
PiiEntityCategoryType
)

# pre-apply the client_cls positional argument so it needn't be explicitly passed below
Expand Down Expand Up @@ -679,3 +680,31 @@ def callback(resp):
polling_interval=self._interval(),
raw_response_hook=callback,
).result()

# Checks that a RecognizePiiEntitiesAction created with categories_filter only
# returns entities whose category is in that list.
@GlobalTextAnalyticsAccountPreparer()
@TextAnalyticsClientPreparer()
def test_pii_action_categories_filter(self, client):

# Three documents, each containing a different kind of identifier:
# a U.S. SSN, an ABA routing number, and a Brazilian CPF number.
docs = [{"id": "1", "text": "My SSN is 859-98-0987."},
{"id": "2",
"text": "Your ABA number - 111000025 - is the first 9 digits in the lower left hand corner of your personal check."},
{"id": "3", "text": "Is 998.214.865-68 your Brazilian CPF number?"}]

# Only SSN and ABA routing number are requested; Brazilian CPF is deliberately left out.
actions = [
RecognizePiiEntitiesAction(
categories_filter=[
PiiEntityCategoryType.US_SOCIAL_SECURITY_NUMBER,
PiiEntityCategoryType.ABA_ROUTING_NUMBER,
]
),
]

result = client.begin_analyze_actions(documents=docs, actions=actions, polling_interval=self._interval()).result()
action_results = list(result)
# Presumably one entry per input document, each indexed by action position
# (index 0 is the single PII action) - confirm against begin_analyze_actions.
assert len(action_results) == 3

assert action_results[0][0].entities[0].text == "859-98-0987"
assert action_results[0][0].entities[0].category == PiiEntityCategoryType.US_SOCIAL_SECURITY_NUMBER
assert action_results[1][0].entities[0].text == "111000025"
assert action_results[1][0].entities[0].category == PiiEntityCategoryType.ABA_ROUTING_NUMBER
assert action_results[2][0].entities == [] # No Brazilian CPF since not in categories_filter
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
RecognizeLinkedEntitiesResult,
AnalyzeSentimentResult,
ExtractKeyPhrasesResult,
PiiEntityCategoryType
)

# pre-apply the client_cls positional argument so it needn't be explicitly passed below
Expand Down Expand Up @@ -727,3 +728,34 @@ async def test_disable_service_logs(self, client):
actions=actions,
polling_interval=self._interval(),
)).result()

# Async variant: checks that a RecognizePiiEntitiesAction created with
# categories_filter only returns entities whose category is in that list.
@GlobalTextAnalyticsAccountPreparer()
@TextAnalyticsClientPreparer()
async def test_pii_action_categories_filter(self, client):

# Three documents, each containing a different kind of identifier:
# a U.S. SSN, an ABA routing number, and a Brazilian CPF number.
docs = [{"id": "1", "text": "My SSN is 859-98-0987."},
{"id": "2",
"text": "Your ABA number - 111000025 - is the first 9 digits in the lower left hand corner of your personal check."},
{"id": "3", "text": "Is 998.214.865-68 your Brazilian CPF number?"}]

# Only SSN and ABA routing number are requested; Brazilian CPF is deliberately left out.
actions = [
RecognizePiiEntitiesAction(
categories_filter=[
PiiEntityCategoryType.US_SOCIAL_SECURITY_NUMBER,
PiiEntityCategoryType.ABA_ROUTING_NUMBER
]
),
]
async with client:
# The first await starts the operation and yields the poller; the second awaits its final result.
result = await (await client.begin_analyze_actions(documents=docs, actions=actions, polling_interval=self._interval())).result()
# Drain the async iterable into a list so it can be indexed below.
action_results = []
async for p in result:
action_results.append(p)

# Presumably one entry per input document, each indexed by action position
# (index 0 is the single PII action) - confirm against begin_analyze_actions.
assert len(action_results) == 3

assert action_results[0][0].entities[0].text == "859-98-0987"
assert action_results[0][0].entities[0].category == PiiEntityCategoryType.US_SOCIAL_SECURITY_NUMBER
assert action_results[1][0].entities[0].text == "111000025"
assert action_results[1][0].entities[0].category == PiiEntityCategoryType.ABA_ROUTING_NUMBER
assert action_results[2][0].entities == [] # No Brazilian CPF since not in categories_filter

0 comments on commit ec064b8

Please sign in to comment.