Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[textanalytics] use language API for analyze text operations #23814

Merged
  •  
  •  
  •  
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ class TextAnalyticsApiVersion(str, Enum, metaclass=CaseInsensitiveEnumMeta):
"""Text Analytics API versions supported by this package"""

#: this is the default version
V2022_03_01_PREVIEW = "2022-03-01-preview"
V3_2_PREVIEW = "v3.2-preview.2"
V3_1 = "v3.1"
V3_0 = "v3.0"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ def on_response(self, request, response):
if self._is_lro and (not data or data.get("status") not in _FINISHED):
return
if data:
data = data.get("results", data) # language API compat
statistics = data.get("statistics", None)
model_version = data.get("modelVersion", None)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,13 @@
)


def is_language_api(api_version):
    """Return whether ``api_version`` identifies a Language API version.

    Language API versions are date-based (for example
    ``"2022-03-01-preview"``), while other versions use a ``v``-prefixed
    form (for example ``"v3.1"``). A version is treated as date-based when
    it contains a ``YYYY-MM-DD`` shaped run of digits anywhere in the string.

    :param str api_version: The service API version string to inspect.
    :return: True if the version contains a ``YYYY-MM-DD`` date, else False.
    :rtype: bool
    """
    # Imported locally, as in the original, so the helper does not depend on
    # the module's top-level import block.
    import re

    # Return a real boolean rather than leaking the re.Match object (or None)
    # to callers; truthiness for existing ``if is_language_api(...)`` call
    # sites is unchanged.
    return re.search(r'\d{4}-\d{2}-\d{2}', api_version) is not None


def _validate_input(documents, hint, whole_input_hint):
"""Validate that batch input has either all string docs
or dict/DetectLanguageInput/TextDocumentInput, not a mix of both.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,10 @@ def order_results(response, combined):
:param combined: A combined list of the results | errors
:return: In order list of results | errors (if any)
"""
request = json.loads(response.http_response.request.body)["documents"]
try:
request = json.loads(response.http_response.request.body)["documents"]
except KeyError: # language API compat
request = json.loads(response.http_response.request.body)["analysisInput"]["documents"]
mapping = {item.id: item for item in combined}
ordered_response = [mapping[item["id"]] for item in request]
return ordered_response
Expand Down Expand Up @@ -97,6 +100,9 @@ def choose_wrapper(*args, **kwargs):
def wrapper(
response, obj, response_headers, ordering_function
): # pylint: disable=unused-argument
if hasattr(obj, "results"):
obj = obj.results # language API compat

if obj.errors:
combined = obj.documents + obj.errors
results = ordering_function(response, combined)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
_validate_input,
_determine_action_type,
_check_string_index_type_arg,
is_language_api
)
from ._version import DEFAULT_API_VERSION
from ._response_handlers import (
Expand Down Expand Up @@ -198,13 +199,29 @@ def detect_language(
model_version = kwargs.pop("model_version", None)
show_stats = kwargs.pop("show_stats", None)
disable_service_logs = kwargs.pop("disable_service_logs", None)
if disable_service_logs is not None:
kwargs["logging_opt_out"] = disable_service_logs

try:
if is_language_api(self._api_version):
models = self._client.models(api_version=self._api_version)
return self._client.analyze_text(
body=models.AnalyzeTextLanguageDetectionInput(
analysis_input={"documents": docs},
parameters=models.LanguageDetectionTaskParameters(
logging_opt_out=disable_service_logs,
model_version=model_version
)
),
show_stats=show_stats,
cls=kwargs.pop("cls", language_result),
**kwargs
)

# api_versions 3.0, 3.1
return self._client.languages(
documents=docs,
model_version=model_version,
show_stats=show_stats,
logging_opt_out=disable_service_logs,
cls=kwargs.pop("cls", language_result),
**kwargs
)
Expand Down Expand Up @@ -283,19 +300,34 @@ def recognize_entities(
self._api_version,
string_index_type_default=self._string_index_type_default,
)
if string_index_type:
kwargs.update({"string_index_type": string_index_type})
disable_service_logs = kwargs.pop("disable_service_logs", None)
if disable_service_logs is not None:
kwargs["logging_opt_out"] = disable_service_logs

try:
if is_language_api(self._api_version):
models = self._client.models(api_version=self._api_version)
return self._client.analyze_text(
body=models.AnalyzeTextEntityRecognitionInput(
analysis_input={"documents": docs},
parameters=models.EntitiesTaskParameters(
logging_opt_out=disable_service_logs,
model_version=model_version,
string_index_type=string_index_type
)
),
show_stats=show_stats,
cls=kwargs.pop("cls", entities_result),
**kwargs
)

# api_versions 3.0, 3.1
return self._client.entities_recognition_general(
documents=docs,
model_version=model_version,
show_stats=show_stats,
string_index_type=string_index_type,
logging_opt_out=disable_service_logs,
cls=kwargs.pop("cls", entities_result),
**kwargs
**kwargs,
)
except HttpResponseError as error:
return process_http_response_error(error)
Expand Down Expand Up @@ -377,25 +409,41 @@ def recognize_pii_entities(
show_stats = kwargs.pop("show_stats", None)
domain_filter = kwargs.pop("domain_filter", None)
categories_filter = kwargs.pop("categories_filter", None)

string_index_type = _check_string_index_type_arg(
kwargs.pop("string_index_type", None),
self._api_version,
string_index_type_default=self._string_index_type_default,
)
if string_index_type:
kwargs.update({"string_index_type": string_index_type})
disable_service_logs = kwargs.pop("disable_service_logs", None)
if disable_service_logs is not None:
kwargs["logging_opt_out"] = disable_service_logs

try:
if is_language_api(self._api_version):
models = self._client.models(api_version=self._api_version)
return self._client.analyze_text(
body=models.AnalyzeTextPiiEntitiesRecognitionInput(
analysis_input={"documents": docs},
parameters=models.PiiTaskParameters(
logging_opt_out=disable_service_logs,
model_version=model_version,
domain=domain_filter,
pii_categories=categories_filter,
string_index_type=string_index_type
)
),
show_stats=show_stats,
cls=kwargs.pop("cls", pii_entities_result),
**kwargs
)

# api_versions 3.0, 3.1
return self._client.entities_recognition_pii(
documents=docs,
model_version=model_version,
show_stats=show_stats,
domain=domain_filter,
pii_categories=categories_filter,
logging_opt_out=disable_service_logs,
string_index_type=string_index_type,
cls=kwargs.pop("cls", pii_entities_result),
**kwargs
)
Expand Down Expand Up @@ -480,21 +528,35 @@ def recognize_linked_entities(
model_version = kwargs.pop("model_version", None)
show_stats = kwargs.pop("show_stats", None)
disable_service_logs = kwargs.pop("disable_service_logs", None)
if disable_service_logs is not None:
kwargs["logging_opt_out"] = disable_service_logs

string_index_type = _check_string_index_type_arg(
kwargs.pop("string_index_type", None),
self._api_version,
string_index_type_default=self._string_index_type_default,
)
if string_index_type:
kwargs.update({"string_index_type": string_index_type})

try:
if is_language_api(self._api_version):
models = self._client.models(api_version=self._api_version)
return self._client.analyze_text(
body=models.AnalyzeTextEntityLinkingInput(
analysis_input={"documents": docs},
parameters=models.EntityLinkingTaskParameters(
logging_opt_out=disable_service_logs,
model_version=model_version,
string_index_type=string_index_type
)
),
show_stats=show_stats,
cls=kwargs.pop("cls", linked_entities_result),
**kwargs
)

# api_versions 3.0, 3.1
return self._client.entities_linking(
documents=docs,
logging_opt_out=disable_service_logs,
model_version=model_version,
string_index_type=string_index_type,
show_stats=show_stats,
cls=kwargs.pop("cls", linked_entities_result),
**kwargs
Expand Down Expand Up @@ -724,14 +786,29 @@ def extract_key_phrases(
model_version = kwargs.pop("model_version", None)
show_stats = kwargs.pop("show_stats", None)
disable_service_logs = kwargs.pop("disable_service_logs", None)
if disable_service_logs is not None:
kwargs["logging_opt_out"] = disable_service_logs

try:
if is_language_api(self._api_version):
models = self._client.models(api_version=self._api_version)
return self._client.analyze_text(
body=models.AnalyzeTextKeyPhraseExtractionInput(
analysis_input={"documents": docs},
parameters=models.KeyPhraseTaskParameters(
logging_opt_out=disable_service_logs,
model_version=model_version,
)
),
show_stats=show_stats,
cls=kwargs.pop("cls", key_phrases_result),
**kwargs
)

# api_versions 3.0, 3.1
return self._client.key_phrases(
documents=docs,
model_version=model_version,
show_stats=show_stats,
logging_opt_out=disable_service_logs,
cls=kwargs.pop("cls", key_phrases_result),
**kwargs
)
Expand Down Expand Up @@ -813,17 +890,11 @@ def analyze_sentiment(
show_stats = kwargs.pop("show_stats", None)
show_opinion_mining = kwargs.pop("show_opinion_mining", None)
disable_service_logs = kwargs.pop("disable_service_logs", None)
if disable_service_logs is not None:
kwargs["logging_opt_out"] = disable_service_logs

string_index_type = _check_string_index_type_arg(
kwargs.pop("string_index_type", None),
self._api_version,
string_index_type_default=self._string_index_type_default,
)
if string_index_type:
kwargs.update({"string_index_type": string_index_type})

if show_opinion_mining is not None:
if (
self._api_version == TextAnalyticsApiVersion.V3_0
Expand All @@ -832,12 +903,32 @@ def analyze_sentiment(
raise ValueError(
"'show_opinion_mining' is only available for API version v3.1 and up"
)
kwargs.update({"opinion_mining": show_opinion_mining})

try:
if is_language_api(self._api_version):
models = self._client.models(api_version=self._api_version)
return self._client.analyze_text(
body=models.AnalyzeTextSentimentAnalysisInput(
analysis_input={"documents": docs},
parameters=models.SentimentAnalysisTaskParameters(
logging_opt_out=disable_service_logs,
model_version=model_version,
string_index_type=string_index_type,
opinion_mining=show_opinion_mining,
)
),
show_stats=show_stats,
cls=kwargs.pop("cls", sentiment_result),
**kwargs
)

# api_versions 3.0, 3.1
return self._client.sentiment(
documents=docs,
logging_opt_out=disable_service_logs,
model_version=model_version,
string_index_type=string_index_type,
opinion_mining=show_opinion_mining,
show_stats=show_stats,
cls=kwargs.pop("cls", sentiment_result),
**kwargs
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,4 @@
# ------------------------------------

VERSION = "5.2.0b4"
DEFAULT_API_VERSION = "v3.2-preview.2"
DEFAULT_API_VERSION = "2022-03-01-preview"
Loading