Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
…into fr-business-cards

* 'master' of https://github.com/Azure/azure-sdk-for-python:
  [form recognizer] Remove unnecessary code (#14257)
  update api version enum (#14254)
  small fixes while going through release checklist (#14250)
  small changes to documentation, updated link in samples README, updat… (#14249)
  • Loading branch information
iscai-msft committed Oct 5, 2020
2 parents 827d37e + 4cabc99 commit 55c98f4
Show file tree
Hide file tree
Showing 20 changed files with 67 additions and 188 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ class FormRecognizerApiVersion(str, Enum):
"""Form Recognizer API versions supported by this package"""

#: this is the default version
V2_1_PREVIEW_1 = "2.1-preview.1"
V2_1_PREVIEW = "2.1-preview.1"
V2_0 = "2.0"


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,11 @@ def __init__(self, endpoint, credential, **kwargs):
# type: (str, Union[AzureKeyCredential, TokenCredential], Any) -> None
self._endpoint = endpoint
self._credential = credential
self.api_version = kwargs.pop('api_version', FormRecognizerApiVersion.V2_1_PREVIEW_1)
self.api_version = kwargs.pop('api_version', FormRecognizerApiVersion.V2_1_PREVIEW)
validate_api_version(self.api_version)

authentication_policy = get_authentication_policy(credential)
polling_interval = kwargs.pop("polling_interval", POLLING_INTERVAL)
validate_api_version(self.api_version)

http_logging_policy = HttpLoggingPolicy(**kwargs)
http_logging_policy.allowed_header_names.update(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,17 +14,18 @@
TYPE_CHECKING
)
from azure.core.tracing.decorator import distributed_trace
from azure.core.polling import LROPoller
from azure.core.polling.base_polling import LROBasePolling

from ._response_handlers import (
prepare_receipt,
prepare_content_result,
prepare_form_result
)
from ._helpers import get_content_type, error_map
from ._helpers import get_content_type
from ._form_base_client import FormRecognizerClientBase
from ._polling import AnalyzePolling
if TYPE_CHECKING:
from azure.core.polling import LROPoller
from ._models import FormPage, RecognizedForm


Expand Down Expand Up @@ -104,14 +105,11 @@ def begin_recognize_receipts(self, receipt, **kwargs):
:caption: Recognize US sales receipt fields.
"""
locale = kwargs.pop("locale", None)
polling_interval = kwargs.pop("polling_interval", self._client._config.polling_interval)
continuation_token = kwargs.pop("continuation_token", None)
content_type = kwargs.pop("content_type", None)
include_field_elements = kwargs.pop("include_field_elements", False)
if content_type == "application/json":
raise TypeError("Call begin_recognize_receipts_from_url() to analyze a receipt from a URL.")
cls = kwargs.pop("cls", self._receipt_callback)
polling = LROBasePolling(timeout=polling_interval, **kwargs)
if content_type is None:
content_type = get_content_type(receipt)

Expand All @@ -123,9 +121,7 @@ def begin_recognize_receipts(self, receipt, **kwargs):
content_type=content_type,
include_text_details=include_field_elements,
cls=cls,
polling=polling,
error_map=error_map,
continuation_token=continuation_token,
polling=True,
**kwargs
)

Expand Down Expand Up @@ -161,20 +157,15 @@ def begin_recognize_receipts_from_url(self, receipt_url, **kwargs):
:caption: Recognize US sales receipt fields from a URL.
"""
locale = kwargs.pop("locale", None)
polling_interval = kwargs.pop("polling_interval", self._client._config.polling_interval)
continuation_token = kwargs.pop("continuation_token", None)
include_field_elements = kwargs.pop("include_field_elements", False)
cls = kwargs.pop("cls", self._receipt_callback)
polling = LROBasePolling(timeout=polling_interval, **kwargs)
if self.api_version == "2.1-preview.1" and locale:
kwargs.update({"locale": locale})
return self._client.begin_analyze_receipt_async( # type: ignore
file_stream={"source": receipt_url},
include_text_details=include_field_elements,
cls=cls,
polling=polling,
error_map=error_map,
continuation_token=continuation_token,
polling=True,
**kwargs
)

Expand Down Expand Up @@ -313,9 +304,6 @@ def begin_recognize_content(self, form, **kwargs):
:dedent: 8
:caption: Recognize text and content/layout information from a form.
"""

polling_interval = kwargs.pop("polling_interval", self._client._config.polling_interval)
continuation_token = kwargs.pop("continuation_token", None)
content_type = kwargs.pop("content_type", None)
if content_type == "application/json":
raise TypeError("Call begin_recognize_content_from_url() to analyze a document from a URL.")
Expand All @@ -327,9 +315,7 @@ def begin_recognize_content(self, form, **kwargs):
file_stream=form,
content_type=content_type,
cls=kwargs.pop("cls", self._content_callback),
polling=LROBasePolling(timeout=polling_interval, **kwargs),
error_map=error_map,
continuation_token=continuation_token,
polling=True,
**kwargs
)

Expand All @@ -350,15 +336,10 @@ def begin_recognize_content_from_url(self, form_url, **kwargs):
:raises ~azure.core.exceptions.HttpResponseError:
"""

polling_interval = kwargs.pop("polling_interval", self._client._config.polling_interval)
continuation_token = kwargs.pop("continuation_token", None)

return self._client.begin_analyze_layout_async( # type: ignore
file_stream={"source": form_url},
cls=kwargs.pop("cls", self._content_callback),
polling=LROBasePolling(timeout=polling_interval, **kwargs),
error_map=error_map,
continuation_token=continuation_token,
polling=True,
**kwargs
)

Expand Down Expand Up @@ -400,9 +381,8 @@ def begin_recognize_custom_forms(self, model_id, form, **kwargs):
if not model_id:
raise ValueError("model_id cannot be None or empty.")

cls = kwargs.pop("cls", None)
polling_interval = kwargs.pop("polling_interval", self._client._config.polling_interval)
continuation_token = kwargs.pop("continuation_token", None)

content_type = kwargs.pop("content_type", None)
if content_type == "application/json":
raise TypeError("Call begin_recognize_custom_forms_from_url() to analyze a document from a URL.")
Expand All @@ -415,16 +395,13 @@ def analyze_callback(raw_response, _, headers): # pylint: disable=unused-argume
analyze_result = self._deserialize(self._generated_models.AnalyzeOperationResult, raw_response)
return prepare_form_result(analyze_result, model_id)

deserialization_callback = cls if cls else analyze_callback
return self._client.begin_analyze_with_custom_model( # type: ignore
file_stream=form,
model_id=model_id,
include_text_details=include_field_elements,
content_type=content_type,
cls=deserialization_callback,
cls=kwargs.pop("cls", analyze_callback),
polling=LROBasePolling(timeout=polling_interval, lro_algorithms=[AnalyzePolling()], **kwargs),
error_map=error_map,
continuation_token=continuation_token,
**kwargs
)

Expand Down Expand Up @@ -452,24 +429,20 @@ def begin_recognize_custom_forms_from_url(self, model_id, form_url, **kwargs):
if not model_id:
raise ValueError("model_id cannot be None or empty.")

cls = kwargs.pop("cls", None)
polling_interval = kwargs.pop("polling_interval", self._client._config.polling_interval)
continuation_token = kwargs.pop("continuation_token", None)

include_field_elements = kwargs.pop("include_field_elements", False)

def analyze_callback(raw_response, _, headers): # pylint: disable=unused-argument
analyze_result = self._deserialize(self._generated_models.AnalyzeOperationResult, raw_response)
return prepare_form_result(analyze_result, model_id)

deserialization_callback = cls if cls else analyze_callback
return self._client.begin_analyze_with_custom_model( # type: ignore
file_stream={"source": form_url},
model_id=model_id,
include_text_details=include_field_elements,
cls=deserialization_callback,
cls=kwargs.pop("cls", analyze_callback),
polling=LROBasePolling(timeout=polling_interval, lro_algorithms=[AnalyzePolling()], **kwargs),
error_map=error_map,
continuation_token=continuation_token,
**kwargs
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,8 @@
CopyRequest,
CopyAuthorizationResult
)
from ._helpers import (
error_map,
TransportWrapper
)
from ._helpers import TransportWrapper

from ._models import (
CustomFormModelInfo,
AccountProperties,
Expand Down Expand Up @@ -152,7 +150,6 @@ def callback_v2_1(raw_response, _, headers): # pylint: disable=unused-argument
)
),
cls=lambda pipeline_response, _, response_headers: pipeline_response,
error_map=error_map,
**kwargs
) # type: PipelineResponseType

Expand All @@ -176,7 +173,6 @@ def callback_v2_1(raw_response, _, headers): # pylint: disable=unused-argument
cls=deserialization_callback,
continuation_token=continuation_token,
polling=LROBasePolling(timeout=polling_interval, lro_algorithms=[TrainingPolling()], **kwargs),
error_map=error_map,
**kwargs
)

Expand Down Expand Up @@ -204,11 +200,7 @@ def delete_model(self, model_id, **kwargs):
if not model_id:
raise ValueError("model_id cannot be None or empty.")

self._client.delete_custom_model(
model_id=model_id,
error_map=error_map,
**kwargs
)
self._client.delete_custom_model(model_id=model_id, **kwargs)

@distributed_trace
def list_custom_models(self, **kwargs):
Expand All @@ -231,7 +223,6 @@ def list_custom_models(self, **kwargs):
"""
return self._client.list_custom_models( # type: ignore
cls=kwargs.pop("cls", lambda objs: [CustomFormModelInfo._from_generated(x) for x in objs]),
error_map=error_map,
**kwargs
)

Expand All @@ -254,7 +245,7 @@ def get_account_properties(self, **kwargs):
:dedent: 8
:caption: Get properties for the form recognizer account.
"""
response = self._client.get_custom_models(error_map=error_map, **kwargs)
response = self._client.get_custom_models(**kwargs)
return AccountProperties._from_generated(response.summary)

@distributed_trace
Expand All @@ -281,7 +272,7 @@ def get_custom_model(self, model_id, **kwargs):
if not model_id:
raise ValueError("model_id cannot be None or empty.")

response = self._client.get_custom_model(model_id=model_id, include_keys=True, error_map=error_map, **kwargs)
response = self._client.get_custom_model(model_id=model_id, include_keys=True, **kwargs)
return CustomFormModel._from_generated(response)

@distributed_trace
Expand Down Expand Up @@ -314,7 +305,6 @@ def get_copy_authorization(self, resource_id, resource_region, **kwargs):

response = self._client.generate_model_copy_authorization( # type: ignore
cls=lambda pipeline_response, deserialized, response_headers: pipeline_response,
error_map=error_map,
**kwargs
) # type: PipelineResponse
target = json.loads(response.http_response.text())
Expand Down Expand Up @@ -359,9 +349,7 @@ def begin_copy_model(

if not model_id:
raise ValueError("model_id cannot be None or empty.")

polling_interval = kwargs.pop("polling_interval", self._client._config.polling_interval)
continuation_token = kwargs.pop("continuation_token", None)

def _copy_callback(raw_response, _, headers): # pylint: disable=unused-argument
copy_result = self._deserialize(self._generated_models.CopyOperationResult, raw_response)
Expand All @@ -380,8 +368,6 @@ def _copy_callback(raw_response, _, headers): # pylint: disable=unused-argument
),
cls=kwargs.pop("cls", _copy_callback),
polling=LROBasePolling(timeout=polling_interval, lro_algorithms=[CopyPolling()], **kwargs),
error_map=error_map,
continuation_token=continuation_token,
**kwargs
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,22 +9,10 @@
from azure.core.credentials import AzureKeyCredential
from azure.core.pipeline.policies import AzureKeyCredentialPolicy
from azure.core.pipeline.transport import HttpTransport
from azure.core.exceptions import (
ResourceNotFoundError,
ResourceExistsError,
ClientAuthenticationError
)

POLLING_INTERVAL = 5
COGNITIVE_KEY_HEADER = "Ocp-Apim-Subscription-Key"


error_map = {
404: ResourceNotFoundError,
409: ResourceExistsError,
401: ClientAuthenticationError
}

def _get_deserialize():
from ._generated.v2_1_preview_1 import FormRecognizerClient
return FormRecognizerClient("dummy", "dummy")._deserialize # pylint: disable=protected-access
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,11 @@ def __init__(
) -> None:
self._endpoint = endpoint
self._credential = credential
self.api_version = kwargs.pop('api_version', FormRecognizerApiVersion.V2_1_PREVIEW_1)
self.api_version = kwargs.pop('api_version', FormRecognizerApiVersion.V2_1_PREVIEW)
validate_api_version(self.api_version)

authentication_policy = get_authentication_policy(credential)
polling_interval = kwargs.pop("polling_interval", POLLING_INTERVAL)
validate_api_version(self.api_version)

http_logging_policy = HttpLoggingPolicy(**kwargs)
http_logging_policy.allowed_header_names.update(
Expand Down
Loading

0 comments on commit 55c98f4

Please sign in to comment.