diff --git a/.gitignore b/.gitignore index 75051d8f..8140f1ae 100644 --- a/.gitignore +++ b/.gitignore @@ -101,6 +101,9 @@ ENV/ .vscode settings.json +# PyCharm +.idea + # mypy .mypy_cache/ diff --git a/bulk_index.py b/bulk_index.py index 92d55bcc..3be1f394 100644 --- a/bulk_index.py +++ b/bulk_index.py @@ -56,7 +56,7 @@ def populate(print_indexable: bool, paper_id: str, id_list: str, if load_cache: try: this_meta = from_cache(cache_dir, paper_id) - except RuntimeError as e: # No document. + except RuntimeError: # No document. pass if this_meta: @@ -67,7 +67,7 @@ def populate(print_indexable: bool, paper_id: str, id_list: str, if len(chunk) == retrieve_chunk_size or i == last: try: new_meta = metadata.bulk_retrieve(chunk) - except metadata.ConnectionFailed as e: # Try again. + except metadata.ConnectionFailed: # Try again. new_meta = metadata.bulk_retrieve(chunk) # Add metadata to the cache. key = lambda dm: dm.paper_id @@ -93,8 +93,8 @@ def populate(print_indexable: bool, paper_id: str, id_list: str, meta = [] index_bar.update(i) - except Exception as e: - raise RuntimeError('Populate failed: %s' % str(e)) from e + except Exception as ex: + raise RuntimeError('Populate failed: %s' % str(ex)) from ex finally: click.echo(f"Indexed {index_count} documents in total") @@ -162,8 +162,8 @@ def to_cache(cache_dir: str, arxiv_id: str, docmeta: List[DocMeta]) -> None: try: with open(cache_path, 'w') as f: json.dump([asdict(dm) for dm in docmeta], f) - except Exception as e: - raise RuntimeError(str(e)) from e + except Exception as ex: + raise RuntimeError(str(ex)) from ex def load_id_list(path: str) -> List[str]: diff --git a/search/agent/consumer.py b/search/agent/consumer.py index ee00172c..1a42dbba 100644 --- a/search/agent/consumer.py +++ b/search/agent/consumer.py @@ -67,22 +67,22 @@ def _get_metadata(self, arxiv_id: str) -> DocMeta: docmeta: DocMeta = retry_call(metadata.retrieve, (arxiv_id,), exceptions=metadata.ConnectionFailed, tries=2) - except metadata.ConnectionFailed as e: + except metadata.ConnectionFailed as ex: # Things really are looking bad. There is no need to keep # trying with subsequent records, so let's abort entirely. logger.error('%s: second attempt failed, giving up', arxiv_id) raise IndexingFailed( 'Indexing failed; metadata endpoint could not be reached.' - ) from e - except metadata.RequestFailed as e: + ) from ex + except metadata.RequestFailed as ex: logger.error(f'{arxiv_id}: request failed') - raise DocumentFailed('Request to metadata service failed') from e - except metadata.BadResponse as e: + raise DocumentFailed('Request to metadata service failed') from ex + except metadata.BadResponse as ex: logger.error(f'{arxiv_id}: bad response from metadata service') - raise DocumentFailed('Bad response from metadata service') from e - except Exception as e: - logger.error(f'{arxiv_id}: unhandled error, metadata service: {e}') - raise IndexingFailed('Unhandled exception') from e + raise DocumentFailed('Bad response from metadata service') from ex + except Exception as ex: + logger.error(f'{arxiv_id}: unhandled error, metadata service: {ex}') + raise IndexingFailed('Unhandled exception') from ex return docmeta def _get_bulk_metadata(self, arxiv_ids: List[str]) -> List[DocMeta]: @@ -116,20 +116,20 @@ def _get_bulk_metadata(self, arxiv_ids: List[str]) -> List[DocMeta]: meta = retry_call(metadata.bulk_retrieve, (arxiv_ids,), exceptions=metadata.ConnectionFailed, tries=2) - except metadata.ConnectionFailed as e: + except metadata.ConnectionFailed as ex: # Things really are looking bad. There is no need to keep # trying with subsequent records, so let's abort entirely. logger.error('%s: second attempt failed, giving up', arxiv_ids) - raise IndexingFailed('Metadata endpoint not available') from e - except metadata.RequestFailed as e: + raise IndexingFailed('Metadata endpoint not available') from ex + except metadata.RequestFailed as ex: logger.error('%s: request failed', arxiv_ids) - raise DocumentFailed('Request to metadata service failed') from e - except metadata.BadResponse as e: + raise DocumentFailed('Request to metadata service failed') from ex + except metadata.BadResponse as ex: logger.error('%s: bad response from metadata service', arxiv_ids) - raise DocumentFailed('Bad response from metadata service') from e - except Exception as e: - logger.error('%s: unhandled error, metadata svc: %s', arxiv_ids, e) - raise IndexingFailed('Unhandled exception') from e + raise DocumentFailed('Bad response from metadata service') from ex + except Exception as ex: + logger.error('%s: unhandled error, metadata svc: %s', arxiv_ids, ex) + raise IndexingFailed('Unhandled exception') from ex return meta @staticmethod @@ -156,10 +156,10 @@ def _transform_to_document(docmeta: DocMeta) -> Document: """ try: document = transform.to_search_document(docmeta) - except Exception as e: + except Exception as ex: # At the moment we don't have any special exceptions. - logger.error('unhandled exception during transform: %s', e) - raise DocumentFailed('Could not transform document') from e + logger.error('unhandled exception during transform: %s', ex) + raise DocumentFailed('Could not transform document') from ex return document @@ -182,11 +182,11 @@ def _add_to_index(document: Document) -> None: try: retry_call(index.SearchSession.add_document, (document,), exceptions=index.IndexConnectionError, tries=2) - except index.IndexConnectionError as e: - raise IndexingFailed('Could not index document') from e - except Exception as e: - logger.error(f'Unhandled exception from index service: {e}') - raise IndexingFailed('Unhandled exception') from e + except index.IndexConnectionError as ex: + raise IndexingFailed('Could not index document') from ex + except Exception as ex: + logger.error(f'Unhandled exception from index service: {ex}') + raise IndexingFailed('Unhandled exception') from ex @staticmethod def _bulk_add_to_index(documents: List[Document]) -> None: @@ -207,11 +207,11 @@ def _bulk_add_to_index(documents: List[Document]) -> None: try: retry_call(index.SearchSession.bulk_add_documents, (documents,), exceptions=index.IndexConnectionError, tries=2) - except index.IndexConnectionError as e: - raise IndexingFailed('Could not bulk index documents') from e - except Exception as e: - logger.error(f'Unhandled exception from index service: {e}') - raise IndexingFailed('Unhandled exception') from e + except index.IndexConnectionError as ex: + raise IndexingFailed('Could not bulk index documents') from ex + except Exception as ex: + logger.error(f'Unhandled exception from index service: {ex}') + raise IndexingFailed('Unhandled exception') from ex def index_paper(self, arxiv_id: str) -> None: """ @@ -254,10 +254,10 @@ def index_papers(self, arxiv_ids: List[str]) -> None: documents.append(document) logger.debug('add to index in bulk') MetadataRecordProcessor._bulk_add_to_index(documents) - except (DocumentFailed, IndexingFailed) as e: + except (DocumentFailed, IndexingFailed) as ex: # We just pass these along so that process_record() can keep track. - logger.debug(f'{arxiv_ids}: Document failed: {e}') - raise e + logger.debug(f'{arxiv_ids}: Document failed: {ex}') + raise ex def process_record(self, record: dict) -> None: """ @@ -285,8 +285,8 @@ def process_record(self, record: dict) -> None: try: deserialized = json.loads(record['Data'].decode('utf-8')) - except json.decoder.JSONDecodeError as e: - logger.error("Error while deserializing data %s", e) + except json.decoder.JSONDecodeError as ex: + logger.error("Error while deserializing data %s", ex) logger.error("Data payload: %s", record['Data']) raise DocumentFailed('Could not deserialize record data') # return # Don't bring down the whole batch. @@ -294,9 +294,9 @@ def process_record(self, record: dict) -> None: try: arxiv_id: str = deserialized.get('document_id') self.index_paper(arxiv_id) - except DocumentFailed as e: - logger.debug('%s: failed to index document: %s', arxiv_id, e) + except DocumentFailed as ex: + logger.debug('%s: failed to index document: %s', arxiv_id, ex) self._error_count += 1 - except IndexingFailed as e: - logger.error('Indexing failed: %s', e) + except IndexingFailed as ex: + logger.error('Indexing failed: %s', ex) raise diff --git a/search/agent/tests/test_record_processor.py b/search/agent/tests/test_record_processor.py index e01ae644..c4ddbec9 100644 --- a/search/agent/tests/test_record_processor.py +++ b/search/agent/tests/test_record_processor.py @@ -122,8 +122,8 @@ def test_add_document_succeeds(self, mock_index, mock_client_factory): processor = consumer.MetadataRecordProcessor(*self.args) try: processor._add_to_index(Document()) - except Exception as e: - self.fail(e) + except Exception as ex: + self.fail(ex) mock_index.add_document.assert_called_once() @mock.patch('boto3.client') @@ -178,8 +178,8 @@ def test_bulk_add_documents_succeeds(self, mock_index, processor = consumer.MetadataRecordProcessor(*self.args) try: processor._bulk_add_to_index([Document()]) - except Exception as e: - self.fail(e) + except Exception as ex: + self.fail(ex) mock_index.bulk_add_documents.assert_called_once() @mock.patch('boto3.client') diff --git a/search/controllers/advanced/__init__.py b/search/controllers/advanced/__init__.py index 53ab4f12..87ce33f0 100644 --- a/search/controllers/advanced/__init__.py +++ b/search/controllers/advanced/__init__.py @@ -116,30 +116,30 @@ def search(request_params: MultiDict) -> Response: # template rendering, so they get added directly to the # response content. asdict( response_data.update(SearchSession.search(q)) # type: ignore - except index.IndexConnectionError as e: + except index.IndexConnectionError as ex: # There was a (hopefully transient) connection problem. Either # this will clear up relatively quickly (next request), or # there is a more serious outage. - logger.error('IndexConnectionError: %s', e) + logger.error('IndexConnectionError: %s', ex) raise InternalServerError( "There was a problem connecting to the search index. This " "is quite likely a transient issue, so please try your " "search again. If this problem persists, please report it " "to help@arxiv.org." - ) from e - except index.QueryError as e: + ) from ex + except index.QueryError as ex: # Base exception routers should pick this up and show bug page. - logger.error('QueryError: %s', e) + logger.error('QueryError: %s', ex) raise InternalServerError( "There was a problem executing your query. Please try " "your search again. If this problem persists, please " "report it to help@arxiv.org." - ) from e - except index.OutsideAllowedRange as e: + ) from ex + except index.OutsideAllowedRange as ex: raise BadRequest( "Hello clever friend. You can't get results in that range" " right now." - ) from e + ) from ex response_data['query'] = q else: logger.debug('form is invalid: %s', str(form.errors)) diff --git a/search/controllers/advanced/tests.py b/search/controllers/advanced/tests.py index 0bbb7e90..0ec57e10 100644 --- a/search/controllers/advanced/tests.py +++ b/search/controllers/advanced/tests.py @@ -161,8 +161,8 @@ def _raiseQueryError(*args, **kwargs): with self.assertRaises(InternalServerError): try: response_data, code, headers = advanced.search(request_data) - except QueryError as e: - self.fail("QueryError should be handled (caught %s)" % e) + except QueryError as ex: + self.fail("QueryError should be handled (caught %s)" % ex) self.assertEqual(mock_index.search.call_count, 1, "A search should be attempted") diff --git a/search/controllers/api/__init__.py b/search/controllers/api/__init__.py index 61aa304b..2bde008b 100644 --- a/search/controllers/api/__init__.py +++ b/search/controllers/api/__init__.py @@ -1,30 +1,25 @@ """Controller for search API requests.""" from collections import defaultdict -from datetime import date, datetime -import re - from typing import Tuple, Dict, Any, Optional, List, Union from mypy_extensions import TypedDict -from dateutil.relativedelta import relativedelta import dateutil.parser import pytz from pytz import timezone -from werkzeug.datastructures import MultiDict, ImmutableMultiDict -from werkzeug.exceptions import InternalServerError, BadRequest, NotFound -from flask import url_for +from werkzeug.datastructures import MultiDict +from werkzeug.exceptions import BadRequest, NotFound from arxiv import status, taxonomy from arxiv.base import logging -from search.services import index, fulltext, metadata +from search.services import index from search.controllers.util import paginate -from ...domain import Query, APIQuery, FieldedSearchList, FieldedSearchTerm, \ - DateRange, ClassificationList, Classification, asdict, DocumentSet, \ - Document, ClassicAPIQuery -from ...domain.api import Phrase, Term, Operator, Field -from .classic_parser import parse_classic_query +from search.domain import ( + Query, APIQuery, FieldedSearchList, FieldedSearchTerm, DateRange, + Classification, DocumentSet, ClassicAPIQuery +) + logger = logging.getLogger(__name__) EASTERN = timezone('US/Eastern') @@ -106,91 +101,6 @@ def search(params: MultiDict) -> Tuple[Dict[str, Any], int, Dict[str, Any]]: return {'results': document_set, 'query': q}, status.HTTP_200_OK, {} -def classic_query(params: MultiDict) \ - -> Tuple[Dict[str, Any], int, Dict[str, Any]]: - """ - Handle a search request from the Clasic API. - - First, the method maps old request parameters to new parameters: - - search_query -> query - - start -> start - - max_results -> size - - Then the request is passed to :method:`search()` and returned. - - If ``id_list`` is specified in the parameters and ``search_query`` is - NOT specified, then each request is passed to :method:`paper()` and - results are aggregated. - - If ``id_list`` is specified AND ``search_query`` is also specified, - then the results from :method:`search()` are filtered by ``id_list``. - - Parameters - ---------- - params : :class:`MultiDict` - GET query parameters from the request. - - Returns - ------- - dict - Response data (to serialize). - int - HTTP status code. - dict - Extra headers for the response. - - Raises - ------ - :class:`BadRequest` - Raised when the search_query and id_list are not specified. - """ - params = params.copy() - - # Parse classic search query. - raw_query = params.get('search_query') - if raw_query: - phrase: Optional[Phrase] = parse_classic_query(raw_query) - else: - phrase = None - - # Parse id_list. - id_list = params.get('id_list', '') - if id_list: - id_list = id_list.split(',') - else: - id_list = None - - # Parse result size. - try: - size = int(params.get('max_results', 50)) - except ValueError: - # Ignore size errors. - size = 50 - - # Parse result start point. - try: - page_start = int(params.get('start', 0)) - except ValueError: - # Start at beginning by default. - page_start = 0 - - try: - query = ClassicAPIQuery(phrase=phrase, id_list=id_list, size=size, - page_start=page_start) - except ValueError: - raise BadRequest("Either a search_query or id_list must be specified" - " for the classic API.") - - # pass to search indexer, which will handle parsing - document_set: DocumentSet = index.SearchSession.current_session().search(query) - data: SearchResponseData = {'results': document_set, 'query': query} - logger.debug('Got document set with %i results', - len(document_set['results'])) - - # bad mypy inference on TypedDict and the status code - return data, status.HTTP_200_OK, {} # type:ignore - - def paper(paper_id: str) -> Tuple[Dict[str, Any], int, Dict[str, Any]]: """ Handle a request for paper metadata from the API. @@ -217,9 +127,9 @@ def paper(paper_id: str) -> Tuple[Dict[str, Any], int, Dict[str, Any]]: """ try: document = index.SearchSession.current_session().get_document(paper_id) # type: ignore - except index.DocumentNotFound as e: + except index.DocumentNotFound as ex: logger.error('Document not found') - raise NotFound('No such document') from e + raise NotFound('No such document') from ex return {'results': document}, status.HTTP_200_OK, {} @@ -310,16 +220,17 @@ def _get_classification(value: str, field: str, query_terms: List) \ query_terms.append({'parameter': field, 'value': value}) return clsns + SEARCH_QUERY_FIELDS = { - 'ti' : 'title', - 'au' : 'author', - 'abs' : 'abstract', - 'co' : 'comments', - 'jr' : 'journal_ref', - 'cat' : 'primary_classification', - 'rn' : 'report_number', - 'id' : 'paper_id', - 'all' : 'all' + 'ti': 'title', + 'au': 'author', + 'abs': 'abstract', + 'co': 'comments', + 'jr': 'journal_ref', + 'cat': 'primary_classification', + 'rn': 'report_number', + 'id': 'paper_id', + 'all': 'all' } diff --git a/search/controllers/api/tests/tests.py b/search/controllers/api/tests/tests_api_search.py similarity index 77% rename from search/controllers/api/tests/tests.py rename to search/controllers/api/tests/tests_api_search.py index b038b779..175ab331 100644 --- a/search/controllers/api/tests/tests.py +++ b/search/controllers/api/tests/tests_api_search.py @@ -1,18 +1,14 @@ """Tests for advanced search controller, :mod:`search.controllers.advanced`.""" from unittest import TestCase, mock -from datetime import date, datetime -from dateutil.relativedelta import relativedelta from werkzeug import MultiDict -from werkzeug.exceptions import InternalServerError, BadRequest +from werkzeug.exceptions import BadRequest from arxiv import status -from search.domain import Query, DateRange, FieldedSearchTerm, Classification,\ - AdvancedQuery, DocumentSet +from search.domain import DateRange, Classification from search.controllers import api from search.domain import api as api_domain -from search.services.index import IndexConnectionError, QueryError class TestAPISearch(TestCase): @@ -159,46 +155,6 @@ def test_with_end_dates_and_type(self, mock_index): DateRange.ANNOUNCED) -class TestClassicAPISearch(TestCase): - """Tests for :func:`.api.classic_query`.""" - - @mock.patch(f'{api.__name__}.index.SearchSession') - def test_no_params(self, mock_index): - """Request with no parameters.""" - params = MultiDict({}) - with self.assertRaises(BadRequest): - data, code, headers = api.classic_query(params) - - @mock.patch(f'{api.__name__}.index.SearchSession') - def test_classic_search_query(self, mock_index): - """Request with search_query.""" - params = MultiDict({'search_query' : 'au:Copernicus'}) - - data, code, headers = api.classic_query(params) - self.assertEqual(code, status.HTTP_200_OK, "Returns 200 OK") - self.assertIn("results", data, "Results are returned") - self.assertIn("query", data, "Query object is returned") - - @mock.patch(f'{api.__name__}.index.SearchSession') - def test_classic_search_query_with_quotes(self, mock_index): - """Request with search_query that includes a quoted phrase.""" - params = MultiDict({'search_query' : 'ti:"dark matter"'}) - - data, code, headers = api.classic_query(params) - self.assertEqual(code, status.HTTP_200_OK, "Returns 200 OK") - self.assertIn("results", data, "Results are returned") - self.assertIn("query", data, "Query object is returned") - - @mock.patch(f'{api.__name__}.index.SearchSession') - def test_classic_id_list(self, mock_index): - """Request with multi-element id_list with versioned and unversioned ids.""" - params = MultiDict({'id_list' : '1234.56789,1234.56789v3'}) - - data, code, headers = api.classic_query(params) - self.assertEqual(code, status.HTTP_200_OK, "Returns 200 OK") - self.assertIn("results", data, "Results are returned") - self.assertIn("query", data, "Query object is returned") - class TestPaper(TestCase): """Tests for :func:`.api.paper`.""" diff --git a/search/controllers/classic_api/__init__.py b/search/controllers/classic_api/__init__.py new file mode 100644 index 00000000..da04b430 --- /dev/null +++ b/search/controllers/classic_api/__init__.py @@ -0,0 +1,174 @@ +"""Controller for classic arXiv API requests.""" + +from typing import Tuple, Dict, Any, Optional, Union +from mypy_extensions import TypedDict +from pytz import timezone +from werkzeug.datastructures import MultiDict +from werkzeug.exceptions import BadRequest, NotFound + +from arxiv import status +from arxiv.base import logging +from arxiv.identifier import parse_arxiv_id + +from search.services import index +from search.domain.api import Phrase +from search.errors import ValidationError +from search.domain import Query, DocumentSet, ClassicAPIQuery +from search.controllers.classic_api.query_parser import parse_classic_query + + +logger = logging.getLogger(__name__) +EASTERN = timezone("US/Eastern") + +SearchResponseData = TypedDict( + "SearchResponseData", + {"results": DocumentSet, "query": Union[Query, ClassicAPIQuery]}, +) + + +def query(params: MultiDict) -> Tuple[Dict[str, Any], int, Dict[str, Any]]: + """ + Handle a search request from the Clasic API. + + First, the method maps old request parameters to new parameters: + - search_query -> query + - start -> start + - max_results -> size + + Then the request is passed to :method:`search()` and returned. + + If ``id_list`` is specified in the parameters and ``search_query`` is + NOT specified, then each request is passed to :method:`paper()` and + results are aggregated. + + If ``id_list`` is specified AND ``search_query`` is also specified, + then the results from :method:`search()` are filtered by ``id_list``. + + Parameters + ---------- + params : :class:`MultiDict` + GET query parameters from the request. + + Returns + ------- + dict + Response data (to serialize). + int + HTTP status code. + dict + Extra headers for the response. + + Raises + ------ + :class:`BadRequest` + Raised when the search_query and id_list are not specified. + """ + params = params.copy() + + # Parse classic search query. + raw_query = params.get("search_query") + if raw_query: + phrase: Optional[Phrase] = parse_classic_query(raw_query) + else: + phrase = None + + # Parse id_list. + id_list = params.get("id_list", "") + if id_list: + id_list = id_list.split(",") + # Check arxiv id validity + for arxiv_id in id_list: + try: + parse_arxiv_id(arxiv_id) + except ValueError: + raise ValidationError( + message="incorrect id format for {}".format(arxiv_id), + link=("http://arxiv.org/api/errors#" + "incorrect_id_format_for_{}").format(arxiv_id) + ) + else: + id_list = None + + # Parse result size. + try: + max_results = int(params.get("max_results", 50)) + except ValueError: + raise ValidationError( + message="max_results must be an integer", + link="http://arxiv.org/api/errors#max_results_must_be_an_integer", + ) + if max_results < 0: + raise ValidationError( + message="max_results must be non-negative", + link="http://arxiv.org/api/errors#max_results_must_be_non-negative" + ) + + # Parse result start point. + try: + start = int(params.get("start", 0)) + except ValueError: + raise ValidationError( + message="start must be an integer", + link="http://arxiv.org/api/errors#start_must_be_an_integer" + ) + if start < 0: + raise ValidationError( + message="start must be non-negative", + link="http://arxiv.org/api/errors#start_must_be_non-negative" + ) + + try: + query = ClassicAPIQuery( + phrase=phrase, id_list=id_list, size=max_results, page_start=start + ) + except ValueError: + raise BadRequest( + "Either a search_query or id_list must be specified" + " for the classic API." + ) + + # pass to search indexer, which will handle parsing + document_set: DocumentSet = index.SearchSession.current_session().search( + query + ) + data: SearchResponseData = {"results": document_set, "query": query} + logger.debug( + "Got document set with %i results", len(document_set["results"]) + ) + + # bad mypy inference on TypedDict and the status code + return data, status.HTTP_200_OK, {} # type:ignore + + +def paper(paper_id: str) -> Tuple[Dict[str, Any], int, Dict[str, Any]]: + """ + Handle a request for paper metadata from the API. + + Parameters + ---------- + paper_id : str + arXiv paper ID for the requested paper. + + Returns + ------- + dict + Response data (to serialize). + int + HTTP status code. + dict + Extra headers for the response. + + Raises + ------ + :class:`NotFound` + Raised when there is no document with the provided paper ID. + + """ + try: + document = index.SearchSession.current_session().get_document( + paper_id + ) # type: ignore + except index.DocumentNotFound as ex: + logger.error("Document not found") + raise NotFound("No such document") from ex + return {"results": document}, status.HTTP_200_OK, {} diff --git a/search/controllers/api/classic_parser.py b/search/controllers/classic_api/query_parser.py similarity index 96% rename from search/controllers/api/classic_parser.py rename to search/controllers/classic_api/query_parser.py index 9d93034f..1cc4560a 100644 --- a/search/controllers/api/classic_parser.py +++ b/search/controllers/classic_api/query_parser.py @@ -12,7 +12,7 @@ >>> parse_classic_query("au:del_maestro AND ti:checkerboard") ((Field.Author, 'del_maestro'), Operator.AND, (Field.Title, 'checkerboard')) -See :module:`tests.test_classic_parser` for more examples. +See :module:`tests.test_query_parser` for more examples. """ @@ -21,7 +21,7 @@ from werkzeug.exceptions import BadRequest -from ...domain.api import Phrase, Operator, Field, Term +from search.domain.api import Phrase, Operator, Field, Term def parse_classic_query(query: str) -> Phrase: """ @@ -143,8 +143,8 @@ def _group_tokens(classed_tokens: List[Union[Operator, Term, Phrase]]) -> Phrase def _parse_operator(characters: str) -> Operator: try: return Operator(characters.strip()) - except ValueError as e: - raise BadRequest(f'Cannot parse fragment: {characters}') from e + except ValueError as ex: + raise BadRequest(f'Cannot parse fragment: {characters}') from ex def _parse_field_query(field_part: str) -> Term: @@ -153,8 +153,8 @@ def _parse_field_query(field_part: str) -> Term: # Cast field to Field enum. try: field = Field(field_name) - except ValueError as e: - raise BadRequest(f'Invalid field: {field_name}') from e + except ValueError as ex: + raise BadRequest(f'Invalid field: {field_name}') from ex # Process leading and trailing whitespace and quotes, if present. value = value.strip() diff --git a/search/controllers/classic_api/tests/__init__.py b/search/controllers/classic_api/tests/__init__.py new file mode 100644 index 00000000..e8e03362 --- /dev/null +++ b/search/controllers/classic_api/tests/__init__.py @@ -0,0 +1 @@ +"""Tests for arXiv classic API controllers.""" diff --git a/search/controllers/api/tests/test_classic_parser.py b/search/controllers/classic_api/tests/test_classic_parser.py similarity index 97% rename from search/controllers/api/tests/test_classic_parser.py rename to search/controllers/classic_api/tests/test_classic_parser.py index b68535ac..7a90eaaa 100644 --- a/search/controllers/api/tests/test_classic_parser.py +++ b/search/controllers/classic_api/tests/test_classic_parser.py @@ -1,8 +1,8 @@ # type: ignore """Test cases for the classic parser.""" -from ...api.classic_parser import parse_classic_query, phrase_to_query_string +from search.controllers.classic_api.query_parser import parse_classic_query, phrase_to_query_string -from ....domain.api import Phrase, Term, ClassicAPIQuery, Field, Operator +from search.domain.api import Phrase, Field, Operator from werkzeug.exceptions import BadRequest from unittest import TestCase diff --git a/search/controllers/classic_api/tests/test_classic_search.py b/search/controllers/classic_api/tests/test_classic_search.py new file mode 100644 index 00000000..a87f1193 --- /dev/null +++ b/search/controllers/classic_api/tests/test_classic_search.py @@ -0,0 +1,46 @@ +from unittest import TestCase, mock +from werkzeug import MultiDict +from werkzeug.exceptions import BadRequest +from arxiv import status +from search.controllers import classic_api + + +class TestClassicAPISearch(TestCase): + """Tests for :func:`.classic_api.query`.""" + + @mock.patch(f'{classic_api.__name__}.index.SearchSession') + def test_no_params(self, mock_index): + """Request with no parameters.""" + params = MultiDict({}) + with self.assertRaises(BadRequest): + data, code, headers = classic_api.query(params) + + @mock.patch(f'{classic_api.__name__}.index.SearchSession') + def test_classic_query(self, mock_index): + """Request with search_query.""" + params = MultiDict({'search_query': 'au:Copernicus'}) + + data, code, headers = classic_api.query(params) + self.assertEqual(code, status.HTTP_200_OK, "Returns 200 OK") + self.assertIn("results", data, "Results are returned") + self.assertIn("query", data, "Query object is returned") + + @mock.patch(f'{classic_api.__name__}.index.SearchSession') + def test_classic_query_with_quotes(self, mock_index): + """Request with search_query that includes a quoted phrase.""" + params = MultiDict({'search_query': 'ti:"dark matter"'}) + + data, code, headers = classic_api.query(params) + self.assertEqual(code, status.HTTP_200_OK, "Returns 200 OK") + self.assertIn("results", data, "Results are returned") + self.assertIn("query", data, "Query object is returned") + + @mock.patch(f'{classic_api.__name__}.index.SearchSession') + def test_classic_id_list(self, mock_index): + """Request with multi-element id_list with versioned and unversioned ids.""" + params = MultiDict({'id_list': '1234.56789,1234.56789v3'}) + + data, code, headers = classic_api.query(params) + self.assertEqual(code, status.HTTP_200_OK, "Returns 200 OK") + self.assertIn("results", data, "Results are returned") + self.assertIn("query", data, "Query object is returned") diff --git a/search/controllers/simple/__init__.py b/search/controllers/simple/__init__.py index e930adc4..fc6f682f 100644 --- a/search/controllers/simple/__init__.py +++ b/search/controllers/simple/__init__.py @@ -82,7 +82,7 @@ def search(request_params: MultiDict, ) # If so, redirect. logger.debug(f"got arXiv ID: {arxiv_id}") - except ValueError as e: + except ValueError: logger.debug('No arXiv ID detected; fall back to form') arxiv_id = None else: @@ -134,33 +134,33 @@ def search(request_params: MultiDict, # template rendering, so they get added directly to the # response content.asdict response_data.update(SearchSession.search(q)) # type: ignore - except index.IndexConnectionError as e: + except index.IndexConnectionError as ex: # There was a (hopefully transient) connection problem. Either # this will clear up relatively quickly (next request), or # there is a more serious outage. - logger.error('IndexConnectionError: %s', e) + logger.error('IndexConnectionError: %s', ex) raise InternalServerError( "There was a problem connecting to the search index. This is " "quite likely a transient issue, so please try your search " "again. If this problem persists, please report it to " "help@arxiv.org." - ) from e - except index.QueryError as e: + ) from ex + except index.QueryError as ex: # Base exception routers should pick this up and show bug page. - logger.error('QueryError: %s', e) + logger.error('QueryError: %s', ex) raise InternalServerError( "There was a problem executing your query. Please try your " "search again. If this problem persists, please report it to " "help@arxiv.org." - ) from e - except index.OutsideAllowedRange as e: + ) from ex + except index.OutsideAllowedRange as ex: raise BadRequest( "Hello clever friend. You can't get results in that range" " right now." - ) from e + ) from ex - except Exception as e: - logger.error('Unhandled exception: %s', str(e)) + except Exception as ex: + logger.error('Unhandled exception: %s', str(ex)) raise else: logger.debug('form is invalid: %s', str(form.errors)) @@ -207,28 +207,28 @@ def retrieve_document(document_id: str) -> Response: """ try: result = SearchSession.get_document(document_id) # type: ignore - except index.IndexConnectionError as e: + except index.IndexConnectionError as ex: # There was a (hopefully transient) connection problem. Either # this will clear up relatively quickly (next request), or # there is a more serious outage. - logger.error('IndexConnectionError: %s', e) + logger.error('IndexConnectionError: %s', ex) raise InternalServerError( "There was a problem connecting to the search index. This is " "quite likely a transient issue, so please try your search " "again. If this problem persists, please report it to " "help@arxiv.org." - ) from e - except index.QueryError as e: + ) from ex + except index.QueryError as ex: # Base exception routers should pick this up and show bug page. - logger.error('QueryError: %s', e) + logger.error('QueryError: %s', ex) raise InternalServerError( "There was a problem executing your query. Please try your " "search again. If this problem persists, please report it to " "help@arxiv.org." - ) from e - except index.DocumentNotFound as e: - logger.error('DocumentNotFound: %s', e) - raise NotFound(f"Could not find a paper with id {document_id}") from e + ) from ex + except index.DocumentNotFound as ex: + logger.error('DocumentNotFound: %s', ex) + raise NotFound(f"Could not find a paper with id {document_id}") from ex return {'document': result}, status.HTTP_200_OK, {} diff --git a/search/controllers/simple/tests.py b/search/controllers/simple/tests.py index 3b66b67a..8012bcee 100644 --- a/search/controllers/simple/tests.py +++ b/search/controllers/simple/tests.py @@ -31,8 +31,8 @@ def _raiseQueryError(*args, **kwargs): with self.assertRaises(InternalServerError): try: response_data, code, headers = simple.retrieve_document(1) - except QueryError as e: - self.fail("QueryError should be handled (caught %s)" % e) + except QueryError as ex: + self.fail("QueryError should be handled (caught %s)" % ex) self.assertEqual(mock_index.get_document.call_count, 1, "A search should be attempted") @@ -62,8 +62,8 @@ def _raiseDocumentNotFound(*args, **kwargs): with self.assertRaises(NotFound): try: response_data, code, headers = simple.retrieve_document(1) - except DocumentNotFound as e: - self.fail("DocumentNotFound should be handled (caught %s)" % e) + except DocumentNotFound as ex: + self.fail("DocumentNotFound should be handled (caught %s)" % ex) self.assertEqual(mock_index.get_document.call_count, 1, "A search should be attempted") @@ -164,8 +164,8 @@ def _raiseQueryError(*args, **kwargs): with self.assertRaises(InternalServerError): try: response_data, code, headers = simple.search(request_data) - except QueryError as e: - self.fail("QueryError should be handled (caught %s)" % e) + except QueryError as ex: + self.fail("QueryError should be handled (caught %s)" % ex) self.assertEqual(mock_index.search.call_count, 1, "A search should be attempted") diff --git a/search/controllers/tests.py b/search/controllers/tests.py index 083e2964..988044f5 100644 --- a/search/controllers/tests.py +++ b/search/controllers/tests.py @@ -71,5 +71,5 @@ def test_nonsense_input(self): """Garbage input is passed.""" try: catch_underscore_syntax("") - except Exception as e: - self.fail(e) + except Exception as ex: + self.fail(ex) diff --git a/search/domain/__init__.py b/search/domain/__init__.py index f3b0b9ab..5780d91c 100644 --- a/search/domain/__init__.py +++ b/search/domain/__init__.py @@ -12,4 +12,4 @@ from .base import * from .advanced import * from .api import * -from .documents import * +from .documents import * # type: ignore diff --git a/search/domain/documents.py b/search/domain/documents.py index 7acf8387..dc503515 100644 --- a/search/domain/documents.py +++ b/search/domain/documents.py @@ -1,18 +1,33 @@ """Data structs for search documents.""" -from datetime import datetime, date +from datetime import datetime, date, timezone from typing import Optional, List, Dict, Any -from dataclasses import field +from dataclasses import dataclass, field from mypy_extensions import TypedDict -from .base import Classification, ClassificationList +from search.domain.base import Classification, ClassificationList # The class keyword ``total=False`` allows instances that do not contain all of # the typed keys. See https://github.com/python/mypy/issues/2632 for # background. +def utcnow() -> datetime: + """Return timezone aware current timestamp.""" + return datetime.utcnow().astimezone(timezone.utc) + + +@dataclass +class Error: + """Represents an error that happened in the system.""" + + id: str + error: str + link: str + author: str = "arXiv api core" + created: datetime = field(default_factory=utcnow) + class Person(TypedDict, total=False): """Represents an author, owner, or other person in metadata.""" @@ -109,7 +124,7 @@ class DocumentSet(TypedDict): def document_set_from_documents(documents: List[Document]) -> DocumentSet: - """Utility for generating a DocumentSet with only a list of Documents. + """Generate a DocumentSet with only a list of Documents. Generates the metadata automatically, which is an advantage over calling DocumentSet(results=documents, metadata=dict()). @@ -121,7 +136,7 @@ def document_set_from_documents(documents: List[Document]) -> DocumentSet: def metadata_from_documents(documents: List[Document]) -> DocumentSetMetadata: - """Utility for generating DocumentSet metadata from a list of documents.""" + """Generate DocumentSet metadata from a list of documents.""" metadata: DocumentSetMetadata = {} metadata['size'] = len(documents) metadata['end'] = len(documents) diff --git a/search/errors.py b/search/errors.py new file mode 100644 index 00000000..b243a31c --- /dev/null +++ b/search/errors.py @@ -0,0 +1,31 @@ +"""Search error classes.""" + + +class SearchError(Exception): + """Generic search error.""" + + def __init__(self, message: str): + """Initialize the error message.""" + self.message = message + + @property + def name(self) -> str: + """Error name.""" + return self.__class__.__name__ + + def __str__(self) -> str: + """Represent error as a string.""" + return f"{self.name}({self.message})" + + __repr__ = __str__ + + +class ValidationError(SearchError): + """Validation error.""" + + def __init__( + self, message: str, link: str = "http://arxiv.org/api/errors" + ): + """Initialize the validation error.""" + super().__init__(message=message) + self.link = link diff --git a/search/factory.py b/search/factory.py index 65e37810..e08e92b2 100644 --- a/search/factory.py +++ b/search/factory.py @@ -4,12 +4,11 @@ from flask import Flask from flask_s3 import FlaskS3 -from werkzeug.contrib.profiler import ProfilerMiddleware from arxiv.base import Base from arxiv.base.middleware import wrap, request_logs from arxiv.users import auth -from search.routes import ui, api +from search.routes import ui, api, classic_api from search.services import index from search.converters import ArchiveConverter from search.encode import ISO8601JSONEncoder @@ -86,12 +85,12 @@ def create_classic_api_web_app() -> Flask: Base(app) auth.Auth(app) - app.register_blueprint(api.classic.blueprint) + app.register_blueprint(classic_api.blueprint) wrap(app, [request_logs.ClassicLogsMiddleware, auth.middleware.AuthMiddleware]) - for error, handler in api.exceptions.get_handlers(): + for error, handler in classic_api.exceptions.get_handlers(): app.errorhandler(error)(handler) return app diff --git a/search/routes/api/__init__.py b/search/routes/api/__init__.py index c2d80b6f..c85af473 100644 --- a/search/routes/api/__init__.py +++ b/search/routes/api/__init__.py @@ -1,22 +1,13 @@ """Provides routing blueprint from the search API.""" -import json -from typing import Dict, Callable, Union, Any, Optional, List -from functools import wraps -from urllib.parse import urljoin, urlparse, parse_qs, urlencode, urlunparse - -from flask.json import jsonify -from flask import Blueprint, make_response, render_template, redirect, \ - request, Response, url_for -from werkzeug.urls import Href, url_encode, url_parse, url_unparse, url_encode -from werkzeug.datastructures import MultiDict, ImmutableMultiDict - -from arxiv import status +from flask import Blueprint, make_response, request, Response + from arxiv.base import logging -from werkzeug.exceptions import InternalServerError +from search import serialize from search.controllers import api -from . import serialize, exceptions, classic +from search.routes.consts import JSON +from search.routes.api import exceptions from arxiv.users.auth.decorators import scoped from arxiv.users.auth import scopes @@ -25,9 +16,6 @@ blueprint = Blueprint('api', __name__, url_prefix='/') -ATOM_XML = "application/atom+xml; charset=utf-8" -JSON = "application/json; charset=utf-8" - @blueprint.route('/', methods=['GET']) @scoped(required=scopes.READ_PUBLIC) diff --git a/search/routes/api/exceptions.py b/search/routes/api/exceptions.py index e3cf397a..4539e23d 100644 --- a/search/routes/api/exceptions.py +++ b/search/routes/api/exceptions.py @@ -6,14 +6,22 @@ """ from typing import Callable, List, Tuple - -from werkzeug.exceptions import NotFound, Forbidden, Unauthorized, \ - MethodNotAllowed, RequestEntityTooLarge, BadRequest, InternalServerError, \ - HTTPException +from http import HTTPStatus + +from werkzeug.exceptions import ( + NotFound, + Forbidden, + Unauthorized, + MethodNotAllowed, + RequestEntityTooLarge, + BadRequest, + InternalServerError, + HTTPException, +) from flask import make_response, Response, jsonify -from arxiv import status from arxiv.base import logging +from search.routes.consts import JSON logger = logging.getLogger(__name__) @@ -26,6 +34,7 @@ def deco(func: Callable) -> Callable: """Register a function as an exception handler.""" _handlers.append((exception, func)) return func + return deco @@ -37,73 +46,60 @@ def get_handlers() -> List[Tuple[type, Callable]]: ------- list List of (:class:`.HTTPException`, callable) tuples. + """ return _handlers +def respond(error: HTTPException, status: HTTPStatus) -> Response: + """Generate a JSON response.""" + return make_response( # type: ignore + jsonify({"code": error.code, "error": error.description}), + status, + {"Content-type": JSON}, + ) + + @handler(NotFound) def handle_not_found(error: NotFound) -> Response: """Render the base 404 error page.""" - rendered = jsonify({'code': error.code, 'error': error.description}) - response: Response = make_response(rendered) - response.status_code = status.HTTP_404_NOT_FOUND - return response + return respond(error, HTTPStatus.NOT_FOUND) @handler(Forbidden) def handle_forbidden(error: Forbidden) -> Response: """Render the base 403 error page.""" - rendered = jsonify({'code': error.code, 'error': error.description}) - response: Response = make_response(rendered) - response.status_code = status.HTTP_403_FORBIDDEN - return response + return respond(error, HTTPStatus.FORBIDDEN) @handler(Unauthorized) def handle_unauthorized(error: Unauthorized) -> Response: """Render the base 401 error page.""" - rendered = jsonify({'code': error.code, 'error': error.description}) - response: Response = make_response(rendered) - response.status_code = status.HTTP_401_UNAUTHORIZED - return response + return respond(error, HTTPStatus.UNAUTHORIZED) @handler(MethodNotAllowed) def handle_method_not_allowed(error: MethodNotAllowed) -> Response: """Render the base 405 error page.""" - rendered = jsonify({'code': error.code, 'error': error.description}) - response: Response = make_response(rendered) - response.status_code = status.HTTP_405_METHOD_NOT_ALLOWED - return response + return respond(error, HTTPStatus.METHOD_NOT_ALLOWED) @handler(RequestEntityTooLarge) def handle_request_entity_too_large(error: RequestEntityTooLarge) -> Response: """Render the base 413 error page.""" - rendered = jsonify({'code': error.code, 'error': error.description}) - response: Response = make_response(rendered) - response.status_code = status.HTTP_413_REQUEST_ENTITY_TOO_LARGE - return response + return respond(error, HTTPStatus.REQUEST_ENTITY_TOO_LARGE) @handler(BadRequest) def handle_bad_request(error: BadRequest) -> Response: """Render the base 400 error page.""" - rendered = jsonify({'code': error.code, 'error': error.description}) - response: Response = make_response(rendered) - response.status_code = status.HTTP_400_BAD_REQUEST - return response + return respond(error, HTTPStatus.BAD_REQUEST) @handler(InternalServerError) def handle_internal_server_error(error: InternalServerError) -> Response: """Render the base 500 error page.""" - if isinstance(error, HTTPException): - rendered = jsonify({'code': error.code, 'error': error.description}) - else: - logger.error('Caught unhandled exception: %s', error) - rendered = jsonify({'code': status.HTTP_500_INTERNAL_SERVER_ERROR, - 'error': 'Unexpected error'}) - response: Response = make_response(rendered) - response.status_code = status.HTTP_500_INTERNAL_SERVER_ERROR - return response + if not isinstance(error, HTTPException): + logger.error("Caught unhandled exception: %s", error) + error.code = HTTPStatus.INTERNAL_SERVER_ERROR + return respond(error, HTTPStatus.INTERNAL_SERVER_ERROR) diff --git a/search/routes/api/serialize.py b/search/routes/api/serialize.py deleted file mode 100644 index dace5f99..00000000 --- a/search/routes/api/serialize.py +++ /dev/null @@ -1,249 +0,0 @@ -"""Serializers for API responses.""" - -from typing import Union, Optional, Dict, Any -from datetime import datetime -from xml.etree import ElementTree as etree -from flask import jsonify, url_for, Response - -from feedgen.feed import FeedGenerator -from pytz import utc - -from arxiv import status -from search.domain import DocumentSet, Document, Classification, Person, \ - APIQuery, ClassicAPIQuery, document_set_from_documents -from .atom_extensions import ArXivExtension, ArXivEntryExtension, \ - OpenSearchExtension, ARXIV_NS -from ...controllers.api.classic_parser import phrase_to_query_string - -class BaseSerializer(object): - """Base class for API serializers.""" - - -class JSONSerializer(BaseSerializer): - """Serializes a :class:`DocumentSet` as JSON.""" - - @classmethod - def _transform_classification(cls, clsn: Classification) -> Optional[dict]: - category = clsn.get('category') - if category is None: - return None - return {'group': clsn.get('group'), - 'archive': clsn.get('archive'), - 'category': category} - - @classmethod - def _transform_format(cls, fmt: str, paper_id: str, version: int) -> dict: - return {"format": fmt, - "href": url_for(fmt, paper_id=paper_id, version=version)} - - @classmethod - def _transform_latest(cls, document: Document) -> Optional[dict]: - latest = document.get('latest') - if latest is None: - return None - return { - "paper_id": latest, - "href": url_for("api.paper", paper_id=document['paper_id'], - version=document.get('latest_version'), - _external=True), - "canonical": url_for("abs", paper_id=document['paper_id'], - version=document.get('latest_version')), - "version": document.get('latest_version') - } - - @classmethod - def _transform_license(cls, license: dict) -> Optional[dict]: - uri = license.get('uri') - if uri is None: - return None - return {'label': license.get('label', ''), 'href': uri} - - @classmethod - def transform_document(cls, doc: Document, - query: Optional[APIQuery] = None) -> dict: - """Select a subset of :class:`Document` properties for public API.""" - # Only return fields that have been explicitly requested. - data = {key: value for key, value in doc.items() - if query is None or key in query.include_fields} - paper_id = doc['paper_id'] - version = doc['version'] - if 'submitted_date_first' in data: - data['submitted_date_first'] = \ - doc['submitted_date_first'].isoformat() - if 'announced_date_first' in data: - data['announced_date_first'] = \ - doc['announced_date_first'].isoformat() - if 'formats' in data: - data['formats'] = [cls._transform_format(fmt, paper_id, version) - for fmt in doc['formats']] - if 'license' in data: - data['license'] = cls._transform_license(doc['license']) - if 'latest' in data: - data['latest'] = cls._transform_latest(doc) - - data['href'] = url_for("api.paper", paper_id=paper_id, - version=version, _external=True) - data['canonical'] = url_for("abs", paper_id=paper_id, - version=version) - return data - - @classmethod - def serialize(cls, document_set: DocumentSet, - query: Optional[APIQuery] = None) -> Response: - """Generate JSON for a :class:`DocumentSet`.""" - total_results = int(document_set['metadata'].get('total_results', 0)) - serialized: Response = jsonify({ - 'results': [cls.transform_document(doc, query=query) - for doc in document_set['results']], - 'metadata': { - 'start': document_set['metadata'].get('start', ''), - 'end': document_set['metadata'].get('end', ''), - 'size': document_set['metadata'].get('size', ''), - 'total_results': total_results, - 'query': document_set['metadata'].get('query', []) - }, - }) - return serialized - - @classmethod - def serialize_document(cls, document: Document, - query: Optional[APIQuery] = None) -> Response: - """Generate JSON for a single :class:`Document`.""" - serialized: Response = jsonify( - cls.transform_document(document, query=query) - ) - return serialized - - -def as_json(document_or_set: Union[DocumentSet, Document], - query: Optional[APIQuery] = None) -> Response: - """Serialize a :class:`DocumentSet` as JSON.""" - if 'paper_id' in document_or_set: - return JSONSerializer.serialize_document(document_or_set, query=query) # type: ignore - return JSONSerializer.serialize(document_or_set, query=query) # type: ignore - - - -class AtomXMLSerializer(BaseSerializer): - """Atom XML serializer for paper metadata.""" - - @classmethod - def transform_document(cls, fg: FeedGenerator, doc: Document, - query: Optional[ClassicAPIQuery] = None) -> None: - """Select a subset of :class:`Document` properties for public API.""" - entry = fg.add_entry() - entry.id(url_for("abs", paper_id=doc['paper_id'], - version=doc['version'], _external=True)) - entry.title(doc['title']) - entry.summary(doc['abstract']) - entry.published(doc['submitted_date']) - entry.updated(doc['updated_date']) - entry.link({'href': url_for("abs", paper_id=doc['paper_id'], - version=doc['version'], _external=True), - "type": "text/html"}) - - entry.link({'href': url_for("pdf", paper_id=doc['paper_id'], - version=doc['version'], _external=True), - "type": "application/pdf", 'rel': 'related'}) - - if doc.get('comments'): - entry.arxiv.comment(doc['comments']) - - if doc.get('journal_ref'): - entry.arxiv.journal_ref(doc['journal_ref']) - - if doc.get('doi'): - entry.arxiv.doi(doc['doi']) - - if doc['primary_classification']['category'] is not None: - entry.arxiv.primary_category( - doc['primary_classification']['category']['id'] - ) - entry.category( - term=doc['primary_classification']['category']['id'], - scheme=ARXIV_NS - ) - - for category in doc['secondary_classification']: - entry.category( - term=category['category']['id'], - scheme=ARXIV_NS - ) - - for author in doc['authors']: - author_data: Dict[str, Any] = { - "name": author['full_name'] - } - if author.get('affiliation'): - author_data['affiliation'] = author['affiliation'] - entry.arxiv.author(author_data) - - @classmethod - def serialize(cls, document_set: DocumentSet, - query: Optional[ClassicAPIQuery] = None) -> str: - """Generate Atom response for a :class:`DocumentSet`.""" - fg = FeedGenerator() - fg.register_extension('opensearch', OpenSearchExtension) - fg.register_extension("arxiv", ArXivExtension, ArXivEntryExtension, - rss=False) - - if query: - if query.phrase is not None: - query_string = phrase_to_query_string(query.phrase) - else: - query_string = '' - - if query.id_list: - id_list = ','.join(query.id_list) - else: - id_list = '' - - fg.title( - f'arXiv Query: search_query={query_string}' - f'&start={query.page_start}&max_results={query.size}' - f'&id_list={id_list}') - fg.id(url_for('classic.query', search_query=query_string, - start=query.page_start, max_results=query.size, - id_list=id_list)) - fg.link({ - "href" : url_for('classic.query', search_query=query_string, - start=query.page_start, max_results=query.size, - id_list=id_list), - "type": 'application/atom+xml'}) - else: - # TODO: Discuss better defaults - fg.title("arXiv Search Results") - fg.id("https://arxiv.org/") - - fg.updated(datetime.utcnow().replace(tzinfo=utc)) - - # pylint struggles with the opensearch extensions, so we ignore no-member here. - # pylint: disable=no-member - fg.opensearch.totalResults( - document_set['metadata'].get('total_results') - ) - fg.opensearch.itemsPerPage(document_set['metadata'].get('size')) - fg.opensearch.startIndex(document_set['metadata'].get('start')) - - for doc in document_set['results']: - cls.transform_document(fg, doc, query=query) - - serialized: str = fg.atom_str(pretty=True) - return serialized - - @classmethod - def serialize_document(cls, document: Document, - query: Optional[ClassicAPIQuery] = None) -> str: - """Generate Atom feed for a single :class:`Document`.""" - # Wrap the single document in a DocumentSet wrapper. - document_set = document_set_from_documents([document]) - - return cls.serialize(document_set, query=query) - - -def as_atom(document_or_set: Union[DocumentSet, Document], - query: Optional[APIQuery] = None) -> str: - """Serialize a :class:`DocumentSet` as Atom.""" - if 'paper_id' in document_or_set: - return AtomXMLSerializer.serialize_document(document_or_set, query=query) # type: ignore - return AtomXMLSerializer.serialize(document_or_set, query=query) # type: ignore diff --git a/search/routes/api/classic.py b/search/routes/classic_api/__init__.py similarity index 71% rename from search/routes/api/classic.py rename to search/routes/classic_api/__init__.py index c89a7ad8..2af9bd80 100644 --- a/search/routes/api/classic.py +++ b/search/routes/classic_api/__init__.py @@ -1,23 +1,18 @@ """Provides the classic search API.""" -from flask import Blueprint, make_response, render_template, redirect, \ - request, Response, url_for +from flask import Blueprint, make_response, request, Response from arxiv.base import logging - -from arxiv.users.auth.decorators import scoped from arxiv.users.auth import scopes - -from search.controllers import api -from . import serialize, exceptions - +from arxiv.users.auth.decorators import scoped +from search import serialize +from search.controllers import classic_api +from search.routes.consts import ATOM_XML +from search.routes.classic_api import exceptions logger = logging.getLogger(__name__) -blueprint = Blueprint('classic', __name__, url_prefix='/classic') - -ATOM_XML = "application/atom+xml; charset=utf-8" -JSON = "application/json; charset=utf-8" +blueprint = Blueprint('classic_api', __name__, url_prefix='/classic_api') @blueprint.route('/query', methods=['GET']) @@ -25,7 +20,7 @@ def query() -> Response: """Main query endpoint.""" logger.debug('Got query: %s', request.args) - data, status_code, headers = api.classic_query(request.args) + data, status_code, headers = classic_api.query(request.args) # requested = request.accept_mimetypes.best_match([JSON, ATOM_XML]) # if requested == ATOM_XML: # return serialize.as_atom(data), status, headers @@ -39,7 +34,7 @@ def query() -> Response: @scoped(required=scopes.READ_PUBLIC) def paper(paper_id: str, version: str) -> Response: """Document metadata endpoint.""" - data, status_code, headers = api.paper(f'{paper_id}v{version}') + data, status_code, headers = classic_api.paper(f'{paper_id}v{version}') response_data = serialize.as_atom(data['results']) headers.update({'Content-type': ATOM_XML}) response: Response = make_response(response_data, status_code, headers) diff --git a/search/routes/classic_api/exceptions.py b/search/routes/classic_api/exceptions.py new file mode 100644 index 00000000..81d6f5e9 --- /dev/null +++ b/search/routes/classic_api/exceptions.py @@ -0,0 +1,119 @@ +""" +Exception handlers for classic arXiv API endpoints. + +.. todo:: This module belongs in :mod:`arxiv.base`. + +""" +from http import HTTPStatus +from typing import Callable, List, Tuple +from werkzeug.exceptions import ( + NotFound, + Forbidden, + Unauthorized, + MethodNotAllowed, + RequestEntityTooLarge, + BadRequest, + InternalServerError, + HTTPException, +) +from flask import make_response, Response + +from arxiv.base import logging +from search.serialize import as_atom +from search.domain import Error +from search.routes.consts import ATOM_XML +from search.errors import ValidationError + + +logger = logging.getLogger(__name__) + +_handlers = [] + + +def handler(exception: type) -> Callable: + """Generate a decorator to register a handler for an exception.""" + def deco(func: Callable) -> Callable: + """Register a function as an exception handler.""" + _handlers.append((exception, func)) + return func + + return deco + + +def get_handlers() -> List[Tuple[type, Callable]]: + """Get a list of registered exception handlers. + + Returns + ------- + list + List of (:class:`.HTTPException`, callable) tuples. + + """ + return _handlers + + +def respond( + error_msg: str, + link: str = "http://arxiv.org/api/errors", + status: HTTPStatus = HTTPStatus.INTERNAL_SERVER_ERROR, +) -> Response: + """Generate an Atom response.""" + return make_response( # type: ignore + as_atom(Error(id=link, error=error_msg, link=link)), + status, + {"Content-type": ATOM_XML}, + ) + + +@handler(NotFound) +def handle_not_found(error: NotFound) -> Response: + """Render the base 404 error page.""" + return respond(error.description, status=HTTPStatus.NOT_FOUND) + + +@handler(Forbidden) +def handle_forbidden(error: Forbidden) -> Response: + """Render the base 403 error page.""" + return respond(error.description, status=HTTPStatus.FORBIDDEN) + + +@handler(Unauthorized) +def handle_unauthorized(error: Unauthorized) -> Response: + """Render the base 401 error page.""" + return respond(error.description, status=HTTPStatus.UNAUTHORIZED) + + +@handler(MethodNotAllowed) +def handle_method_not_allowed(error: MethodNotAllowed) -> Response: + """Render the base 405 error page.""" + return respond(error.description, status=HTTPStatus.METHOD_NOT_ALLOWED) + + +@handler(RequestEntityTooLarge) +def handle_request_entity_too_large(error: RequestEntityTooLarge) -> Response: + """Render the base 413 error page.""" + return respond( + error.description, status=HTTPStatus.REQUEST_ENTITY_TOO_LARGE + ) + + +@handler(BadRequest) +def handle_bad_request(error: BadRequest) -> Response: + """Render the base 400 error page.""" + return respond(error.description, status=HTTPStatus.BAD_REQUEST) + + +@handler(InternalServerError) +def handle_internal_server_error(error: InternalServerError) -> Response: + """Render the base 500 error page.""" + if not isinstance(error, HTTPException): + logger.error("Caught unhandled exception: %s", error) + return respond(error.description, status=HTTPStatus.INTERNAL_SERVER_ERROR) + + +@handler(ValidationError) +def handle_validation_error(error: ValidationError) -> Response: + """Render the base 400 error page.""" + return respond( + error_msg=error.message, link=error.link, status=HTTPStatus.BAD_REQUEST + ) diff --git a/search/routes/classic_api/tests/__init__.py b/search/routes/classic_api/tests/__init__.py new file mode 100644 index 00000000..4c9d94f7 --- /dev/null +++ b/search/routes/classic_api/tests/__init__.py @@ -0,0 +1 @@ +"""Tests for classic arXiv API routes.""" diff --git a/search/routes/api/tests/test_classic.py b/search/routes/classic_api/tests/test_classic.py similarity index 56% rename from search/routes/api/tests/test_classic.py rename to search/routes/classic_api/tests/test_classic.py index 0dcffb48..d00ca903 100644 --- a/search/routes/api/tests/test_classic.py +++ b/search/routes/classic_api/tests/test_classic.py @@ -1,16 +1,15 @@ """Tests for API routes.""" import os -import json +from http import HTTPStatus from datetime import datetime from unittest import TestCase, mock +from xml.etree import ElementTree -import jsonschema import pytz from arxiv.users import helpers, auth from arxiv.users.domain import Scope -from arxiv import status from search import factory from search import domain @@ -26,22 +25,28 @@ def setUp(self): self.app = factory.create_classic_api_web_app() self.app.config['JWT_SECRET'] = jwt_secret self.client = self.app.test_client() + self.auth_header = { + 'Authorization': helpers.generate_token( + '1234', 'foo@bar.com', 'foouser', + scope=[auth.scopes.READ_PUBLIC] + ) + } def test_request_without_token(self): """No auth token is provided on the request.""" - response = self.client.get('/classic/query?search_query=au:copernicus') - self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED) + response = self.client.get('/classic_api/query?search_query=au:copernicus') + self.assertEqual(response.status_code, HTTPStatus.UNAUTHORIZED) def test_with_token_lacking_scope(self): """Client auth token lacks required public read scope.""" token = helpers.generate_token('1234', 'foo@bar.com', 'foouser', scope=[Scope('something', 'read')]) response = self.client.get( - '/classic/query?search_query=au:copernicus', + '/classic_api/query?search_query=au:copernicus', headers={'Authorization': token}) - self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) + self.assertEqual(response.status_code, HTTPStatus.FORBIDDEN) - @mock.patch(f'{factory.__name__}.api.classic.api') + @mock.patch(f'{factory.__name__}.classic_api.classic_api') def test_with_valid_token(self, mock_controller): """Client auth token has required public read scope.""" document = dict( @@ -99,16 +104,13 @@ def test_with_valid_token(self, mock_controller): metadata={'start': 0, 'end': 1, 'size': 50, 'total': 1} ) r_data = {'results': docs, 'query': domain.ClassicAPIQuery(id_list=['1234.5678'])} - mock_controller.classic_query.return_value = r_data, status.HTTP_200_OK, {} - token = helpers.generate_token('1234', 'foo@bar.com', 'foouser', - scope=[auth.scopes.READ_PUBLIC]) + mock_controller.query.return_value = r_data, HTTPStatus.OK, {} response = self.client.get( - '/classic/query?search_query=au:copernicus', - headers={'Authorization': token}) - self.assertEqual(response.status_code, status.HTTP_200_OK) + '/classic_api/query?search_query=au:copernicus', + headers=self.auth_header) + self.assertEqual(response.status_code, HTTPStatus.OK) - - @mock.patch(f'{factory.__name__}.api.classic.api') + @mock.patch(f'{factory.__name__}.classic_api.classic_api') def test_paper_retrieval(self, mock_controller): """Test single-paper retrieval.""" document = dict( @@ -166,8 +168,87 @@ def test_paper_retrieval(self, mock_controller): metadata={'start': 0, 'end': 1, 'size': 50, 'total': 1} ) r_data = {'results': docs, 'query': domain.APIQuery()} - mock_controller.paper.return_value = r_data, status.HTTP_200_OK, {} - token = helpers.generate_token('1234', 'foo@bar.com', 'foouser', - scope=[auth.scopes.READ_PUBLIC]) - response = self.client.get('/classic/1234.56789v6', headers={'Authorization': token}) - self.assertEqual(response.status_code, status.HTTP_200_OK) + mock_controller.paper.return_value = r_data, HTTPStatus.OK, {} + response = self.client.get( + '/classic_api/1234.56789v6', headers=self.auth_header + ) + self.assertEqual(response.status_code, HTTPStatus.OK) + + # Validation errors + def _fix_path(self, path): + return "/".join([ + "{{http://www.w3.org/2005/Atom}}{}".format(p) + for p in path.split("/") + ]) + + def _node(self, et: ElementTree, path: str): + """Return the node.""" + return et.find(self._fix_path(path)) + + def _text(self, et: ElementTree, path: str): + """Return the text content of the node""" + return et.findtext(self._fix_path(path)) + + def check_validation_error(self, response, error, link): + et = ElementTree.fromstring(response.get_data(as_text=True)) + self.assertEqual(self._text(et, "entry/id"), link) + self.assertEqual(self._text(et, "entry/title"), "Error") + self.assertEqual(self._text(et, "entry/summary"), error) + link_attrib = self._node(et, "entry/link").attrib + self.assertEqual(link_attrib["href"], link) + + def test_start_not_a_number(self): + response = self.client.get( + '/classic_api/query?search_query=au:copernicus&start=non_number', + headers=self.auth_header) + self.assertEqual(response.status_code, HTTPStatus.BAD_REQUEST) + self.check_validation_error( + response, + "start must be an integer", + "http://arxiv.org/api/errors#start_must_be_an_integer" + ) + + def test_start_negative(self): + response = self.client.get( + '/classic_api/query?search_query=au:copernicus&start=-1', + headers=self.auth_header) + self.assertEqual(response.status_code, HTTPStatus.BAD_REQUEST) + self.check_validation_error( + response, + "start must be non-negative", + "http://arxiv.org/api/errors#start_must_be_non-negative" + ) + + def test_max_results_not_a_number(self): + response = self.client.get( + '/classic_api/query?search_query=au:copernicus&max_results=non_number', + headers=self.auth_header) + self.assertEqual(response.status_code, HTTPStatus.BAD_REQUEST) + self.check_validation_error( + response, + "max_results must be an integer", + "http://arxiv.org/api/errors#max_results_must_be_an_integer" + ) + + def test_max_results_negative(self): + response = self.client.get( + '/classic_api/query?search_query=au:copernicus&max_results=-1', + headers=self.auth_header) + self.assertEqual(response.status_code, HTTPStatus.BAD_REQUEST) + self.check_validation_error( + response, + "max_results must be non-negative", + "http://arxiv.org/api/errors#max_results_must_be_non-negative" + ) + + def test_invalid_arxiv_id(self): + response = self.client.get( + '/classic_api/query?id_list=cond—mat/0709123', + headers=self.auth_header) + self.assertEqual(response.status_code, HTTPStatus.BAD_REQUEST) + self.check_validation_error( + response, + "incorrect id format for cond—mat/0709123", + "http://arxiv.org/api/errors#" + "incorrect_id_format_for_cond—mat/0709123" + ) diff --git a/search/routes/consts.py b/search/routes/consts.py new file mode 100644 index 00000000..92862409 --- /dev/null +++ b/search/routes/consts.py @@ -0,0 +1,4 @@ +"""Serialization MIME type and charset constants.""" + +ATOM_XML = "application/atom+xml; charset=utf-8" +JSON = "application/json; charset=utf-8" diff --git a/search/serialize/__init__.py b/search/serialize/__init__.py new file mode 100644 index 00000000..95609a83 --- /dev/null +++ b/search/serialize/__init__.py @@ -0,0 +1,5 @@ +"""Provides serialization functions for API responses.""" +__all__ = ["JSONSerializer", "as_json", "AtomXMLSerializer", "as_atom"] + +from search.serialize.json import JSONSerializer, as_json +from search.serialize.atom import AtomXMLSerializer, as_atom diff --git a/search/serialize/atom.py b/search/serialize/atom.py new file mode 100644 index 00000000..cdbcfbf9 --- /dev/null +++ b/search/serialize/atom.py @@ -0,0 +1,177 @@ +"""Atom serialization for classic arXiv API.""" + +from typing import Union, Optional, Dict, Any +from datetime import datetime + +from pytz import utc +from flask import url_for +from feedgen.feed import FeedGenerator + +from search.domain import ( + Error, DocumentSet, Document, ClassicAPIQuery, document_set_from_documents +) +from search.serialize.atom_extensions import ( + ArXivExtension, ArXivEntryExtension, OpenSearchExtension, ARXIV_NS +) +from search.controllers.classic_api.query_parser import phrase_to_query_string +from search.serialize.base import BaseSerializer + + +class AtomXMLSerializer(BaseSerializer): + """Atom XML serializer for paper metadata.""" + + @classmethod + def transform_document(cls, fg: FeedGenerator, doc: Document, + query: Optional[ClassicAPIQuery] = None) -> None: + """Select a subset of :class:`Document` properties for public API.""" + entry = fg.add_entry() + entry.id(url_for("abs", paper_id=doc['paper_id'], + version=doc['version'], _external=True)) + entry.title(doc['title']) + entry.summary(doc['abstract']) + entry.published(doc['submitted_date']) + entry.updated(doc['updated_date']) + entry.link({'href': url_for("abs", paper_id=doc['paper_id'], + version=doc['version'], _external=True), + "type": "text/html"}) + + entry.link({'href': url_for("pdf", paper_id=doc['paper_id'], + version=doc['version'], _external=True), + "type": "application/pdf", 'rel': 'related'}) + + if doc.get('comments'): + entry.arxiv.comment(doc['comments']) + + if doc.get('journal_ref'): + entry.arxiv.journal_ref(doc['journal_ref']) + + if doc.get('doi'): + entry.arxiv.doi(doc['doi']) + + if doc['primary_classification']['category'] is not None: + entry.arxiv.primary_category( + doc['primary_classification']['category']['id'] + ) + entry.category( + term=doc['primary_classification']['category']['id'], + scheme=ARXIV_NS + ) + + for category in doc['secondary_classification']: + entry.category( + term=category['category']['id'], + scheme=ARXIV_NS + ) + + for author in doc['authors']: + author_data: Dict[str, Any] = { + "name": author['full_name'] + } + if author.get('affiliation'): + author_data['affiliation'] = author['affiliation'] + entry.arxiv.author(author_data) + + @staticmethod + def _get_feed(query: Optional[ClassicAPIQuery] = None) -> FeedGenerator: + fg = FeedGenerator() + fg.register_extension('opensearch', OpenSearchExtension) + fg.register_extension("arxiv", ArXivExtension, ArXivEntryExtension, + rss=False) + + if query: + if query.phrase is not None: + query_string = phrase_to_query_string(query.phrase) + else: + query_string = '' + + if query.id_list: + id_list = ','.join(query.id_list) + else: + id_list = '' + + fg.title( + f'arXiv Query: search_query={query_string}' + f'&start={query.page_start}&max_results={query.size}' + f'&id_list={id_list}') + fg.id(url_for('classic_api.query', search_query=query_string, + start=query.page_start, max_results=query.size, + id_list=id_list)) + fg.link({ + "href": url_for('classic_api.query', search_query=query_string, + start=query.page_start, max_results=query.size, + id_list=id_list), + "type": 'application/atom+xml'}) + else: + # TODO: Discuss better defaults + fg.title("arXiv Search Results") + fg.id("https://arxiv.org/") + + fg.updated(datetime.utcnow().replace(tzinfo=utc)) + return fg + + @classmethod + def serialize(cls, document_set: DocumentSet, + query: Optional[ClassicAPIQuery] = None) -> str: + """Generate Atom response for a :class:`DocumentSet`.""" + fg = cls._get_feed(query) + + # pylint struggles with the opensearch extensions, so we ignore + # no-member here. + # pylint: disable=no-member + fg.opensearch.totalResults( + document_set['metadata'].get('total_results') + ) + fg.opensearch.itemsPerPage(document_set['metadata'].get('size')) + fg.opensearch.startIndex(document_set['metadata'].get('start')) + + for doc in document_set['results']: + cls.transform_document(fg, doc, query=query) + + return fg.atom_str(pretty=True) # type: ignore + + @classmethod + def serialize_error(cls, error: Error, + query: Optional[ClassicAPIQuery] = None) -> str: + """Generate Atom error response.""" + fg = cls._get_feed(query) + + # pylint struggles with the opensearch extensions, so we ignore + # no-member here. + # pylint: disable=no-member + fg.opensearch.totalResults(1) + fg.opensearch.itemsPerPage(1) + fg.opensearch.startIndex(0) + + entry = fg.add_entry() + entry.id(error.id) + entry.title("Error") + entry.summary(error.error) + entry.updated(error.created) + entry.link({ + 'href': error.link, + "rel": "alternate", + "type": "text/html" + }) + entry.arxiv.author({"name": error.author}) + + return fg.atom_str(pretty=True) # type: ignore + + @classmethod + def serialize_document(cls, document: Document, + query: Optional[ClassicAPIQuery] = None) -> str: + """Generate Atom feed for a single :class:`Document`.""" + # Wrap the single document in a DocumentSet wrapper. + document_set = document_set_from_documents([document]) + + return cls.serialize(document_set, query=query) + + +def as_atom(document_or_set: Union[Error, DocumentSet, Document], + query: Optional[ClassicAPIQuery] = None) -> str: + """Serialize a :class:`DocumentSet` as Atom.""" + if isinstance(document_or_set, Error): + return AtomXMLSerializer.serialize_error(document_or_set, query=query) # type: ignore + # type: ignore + elif 'paper_id' in document_or_set: + return AtomXMLSerializer.serialize_document(document_or_set, query=query) # type: ignore + return AtomXMLSerializer.serialize(document_or_set, query=query) # type: ignore diff --git a/search/routes/api/atom_extensions.py b/search/serialize/atom_extensions.py similarity index 80% rename from search/routes/api/atom_extensions.py rename to search/serialize/atom_extensions.py index 2d0be2b2..37905958 100644 --- a/search/routes/api/atom_extensions.py +++ b/search/serialize/atom_extensions.py @@ -1,17 +1,21 @@ """Feedgen extensions to implement serialization of the arXiv legacy API atom feed. -Throughout module, pylint: disable=arguments-differ due to inconsistencies in feedgen library. +Throughout module, pylint: disable=arguments-differ due to inconsistencies in +feedgen library. """ # pylint: disable=arguments-differ from typing import Any, Dict, List + +from lxml import etree from feedgen.ext.base import BaseEntryExtension, BaseExtension from feedgen.entry import FeedEntry from feedgen.feed import FeedGenerator -from lxml import etree -ARXIV_NS = 'http://arxiv.org/schemas/atom' -OPENSEARCH_NS = 'http://a9.com/-/spec/opensearch/1.1/' + +ARXIV_NS = "http://arxiv.org/schemas/atom" +OPENSEARCH_NS = "http://a9.com/-/spec/opensearch/1.1/" + class OpenSearchExtension(BaseExtension): """Extension of the Feedgen base class to put OpenSearch metadata.""" @@ -25,7 +29,9 @@ def __init__(self: BaseExtension) -> None: self.__opensearch_startIndex = None self.__opensearch_itemsPerPage = None - def extend_atom(self: BaseExtension, atom_feed: FeedGenerator) -> FeedGenerator: + def extend_atom( + self: BaseExtension, atom_feed: FeedGenerator + ) -> FeedGenerator: """ Assign the Atom feed generator to the extension. @@ -41,15 +47,19 @@ def extend_atom(self: BaseExtension, atom_feed: FeedGenerator) -> FeedGenerator: """ if self.__opensearch_itemsPerPage is not None: - elt = etree.SubElement(atom_feed, f'{{{OPENSEARCH_NS}}}itemsPerPage') + elt = etree.SubElement( + atom_feed, f"{{{OPENSEARCH_NS}}}itemsPerPage" + ) elt.text = self.__opensearch_itemsPerPage if self.__opensearch_totalResults is not None: - elt = etree.SubElement(atom_feed, f'{{{OPENSEARCH_NS}}}totalResults') + elt = etree.SubElement( + atom_feed, f"{{{OPENSEARCH_NS}}}totalResults" + ) elt.text = self.__opensearch_totalResults if self.__opensearch_startIndex is not None: - elt = etree.SubElement(atom_feed, f'{{{OPENSEARCH_NS}}}startIndex') + elt = etree.SubElement(atom_feed, f"{{{OPENSEARCH_NS}}}startIndex") elt.text = self.__opensearch_startIndex return atom_feed @@ -83,7 +93,7 @@ def extend_ns() -> Dict[str, str]: The definition string for the "arxiv" namespace. """ - return {'opensearch': OPENSEARCH_NS} + return {"opensearch": OPENSEARCH_NS} def totalResults(self: BaseExtension, text: str) -> None: """Set the totalResults parameter.""" @@ -151,7 +161,7 @@ def extend_ns() -> Dict[str, str]: The definition string for the "arxiv" namespace. """ - return {'arxiv': ARXIV_NS} + return {"arxiv": ARXIV_NS} class ArXivEntryExtension(BaseEntryExtension): @@ -181,37 +191,42 @@ def extend_atom(self: BaseEntryExtension, entry: FeedEntry) -> FeedEntry: """ if self.__arxiv_comment: - comment_element = etree.SubElement(entry, f'{{{ARXIV_NS}}}comment') + comment_element = etree.SubElement(entry, f"{{{ARXIV_NS}}}comment") comment_element.text = self.__arxiv_comment if self.__arxiv_primary_category: - primary_category_element = etree.SubElement(entry, f'{{{ARXIV_NS}}}primary_category') - primary_category_element.attrib['term'] = self.__arxiv_primary_category + primary_category_element = etree.SubElement( + entry, f"{{{ARXIV_NS}}}primary_category" + ) + primary_category_element.attrib[ + "term" + ] = self.__arxiv_primary_category if self.__arxiv_journal_ref: - journal_ref_element = \ - etree.SubElement(entry, f'{{{ARXIV_NS}}}journal_ref') + journal_ref_element = etree.SubElement( + entry, f"{{{ARXIV_NS}}}journal_ref" + ) journal_ref_element.text = self.__arxiv_journal_ref if self.__arxiv_authors: for author in self.__arxiv_authors: - author_element = etree.SubElement(entry, 'author') - name_element = etree.SubElement(author_element, 'name') - name_element.text = author['name'] - for affiliation in author.get('affiliation', []): - affiliation_element = \ - etree.SubElement(author_element, - '{%s}affiliation' % ARXIV_NS) + author_element = etree.SubElement(entry, "author") + name_element = etree.SubElement(author_element, "name") + name_element.text = author["name"] + for affiliation in author.get("affiliation", []): + affiliation_element = etree.SubElement( + author_element, "{%s}affiliation" % ARXIV_NS + ) affiliation_element.text = affiliation if self.__arxiv_doi: for doi in self.__arxiv_doi: - doi_element = etree.SubElement(entry, f'{{{ARXIV_NS}}}doi') + doi_element = etree.SubElement(entry, f"{{{ARXIV_NS}}}doi") doi_element.text = doi - doi_link_element = etree.SubElement(entry, 'link') - doi_link_element.set('rel', 'related') - doi_link_element.set('href', f'https://doi.org/{doi}') + doi_link_element = etree.SubElement(entry, "link") + doi_link_element.set("rel", "related") + doi_link_element.set("href", f"https://doi.org/{doi}") return entry diff --git a/search/serialize/base.py b/search/serialize/base.py new file mode 100644 index 00000000..cacfb913 --- /dev/null +++ b/search/serialize/base.py @@ -0,0 +1,5 @@ +"""Base class for API serializers.""" + + +class BaseSerializer(object): + """Base class for API serializers.""" diff --git a/search/serialize/json.py b/search/serialize/json.py new file mode 100644 index 00000000..f210630e --- /dev/null +++ b/search/serialize/json.py @@ -0,0 +1,114 @@ +"""Serializers for API responses.""" + +from typing import Union, Optional +from flask import jsonify, url_for, Response + +from search.serialize.base import BaseSerializer +from search.domain import DocumentSet, Document, Classification, APIQuery + + +class JSONSerializer(BaseSerializer): + """Serializes a :class:`DocumentSet` as JSON.""" + + @classmethod + def _transform_classification(cls, clsn: Classification) -> Optional[dict]: + category = clsn.get('category') + if category is None: + return None + return {'group': clsn.get('group'), + 'archive': clsn.get('archive'), + 'category': category} + + @classmethod + def _transform_format(cls, fmt: str, paper_id: str, version: int) -> dict: + return {"format": fmt, + "href": url_for(fmt, paper_id=paper_id, version=version)} + + @classmethod + def _transform_latest(cls, document: Document) -> Optional[dict]: + latest = document.get('latest') + if latest is None: + return None + return { + "paper_id": latest, + "href": url_for("api.paper", paper_id=document['paper_id'], + version=document.get('latest_version'), + _external=True), + "canonical": url_for("abs", paper_id=document['paper_id'], + version=document.get('latest_version')), + "version": document.get('latest_version') + } + + @classmethod + def _transform_license(cls, license: dict) -> Optional[dict]: + uri = license.get('uri') + if uri is None: + return None + return {'label': license.get('label', ''), 'href': uri} + + @classmethod + def transform_document(cls, doc: Document, + query: Optional[APIQuery] = None) -> dict: + """Select a subset of :class:`Document` properties for public API.""" + # Only return fields that have been explicitly requested. + data = {key: value for key, value in doc.items() + if query is None or key in query.include_fields} + paper_id = doc['paper_id'] + version = doc['version'] + if 'submitted_date_first' in data: + data['submitted_date_first'] = \ + doc['submitted_date_first'].isoformat() + if 'announced_date_first' in data: + data['announced_date_first'] = \ + doc['announced_date_first'].isoformat() + if 'formats' in data: + data['formats'] = [cls._transform_format(fmt, paper_id, version) + for fmt in doc['formats']] + if 'license' in data: + data['license'] = cls._transform_license(doc['license']) + if 'latest' in data: + data['latest'] = cls._transform_latest(doc) + + data['href'] = url_for("api.paper", paper_id=paper_id, + version=version, _external=True) + data['canonical'] = url_for("abs", paper_id=paper_id, + version=version) + return data + + @classmethod + def serialize(cls, document_set: DocumentSet, + query: Optional[APIQuery] = None) -> Response: + """Generate JSON for a :class:`DocumentSet`.""" + total_results = int(document_set['metadata'].get('total_results', 0)) + serialized: Response = jsonify({ + 'results': [cls.transform_document(doc, query=query) + for doc in document_set['results']], + 'metadata': { + 'start': document_set['metadata'].get('start', ''), + 'end': document_set['metadata'].get('end', ''), + 'size': document_set['metadata'].get('size', ''), + 'total_results': total_results, + 'query': document_set['metadata'].get('query', []) + }, + }) + return serialized + + @classmethod + def serialize_document(cls, document: Document, + query: Optional[APIQuery] = None) -> Response: + """Generate JSON for a single :class:`Document`.""" + serialized: Response = jsonify( + cls.transform_document(document, query=query) + ) + return serialized + + +def as_json(document_or_set: Union[DocumentSet, Document], + query: Optional[APIQuery] = None) -> Response: + """Serialize a :class:`DocumentSet` as JSON.""" + if 'paper_id' in document_or_set: + return JSONSerializer.serialize_document(document_or_set, query=query) # type: ignore + return JSONSerializer.serialize(document_or_set, query=query) # type: ignore + + + diff --git a/search/serialize/tests/__init__.py b/search/serialize/tests/__init__.py new file mode 100644 index 00000000..f708b96f --- /dev/null +++ b/search/serialize/tests/__init__.py @@ -0,0 +1 @@ +"""Serialization tests.""" diff --git a/search/routes/api/tests/test_serialize.py b/search/serialize/tests/test_serialize.py similarity index 93% rename from search/routes/api/tests/test_serialize.py rename to search/serialize/tests/test_serialize.py index 0677e49f..73803908 100644 --- a/search/routes/api/tests/test_serialize.py +++ b/search/serialize/tests/test_serialize.py @@ -6,8 +6,8 @@ import pytz import json import jsonschema -from .... import domain, encode -from .. import serialize +from search import encode +from search import serialize def mock_jsonify(o): @@ -23,8 +23,9 @@ def setUp(self): with open(self.SCHEMA_PATH) as f: self.schema = json.load(f) - @mock.patch(f'{serialize.__name__}.url_for', lambda *a, **k: 'http://f/12') - @mock.patch(f'{serialize.__name__}.jsonify', mock_jsonify) + @mock.patch(f'search.serialize.json.url_for', + lambda *a, **k: 'http://f/12') + @mock.patch(f'search.serialize.json.jsonify', mock_jsonify) def test_to_json(self): """Just your run-of-the-mill arXiv document generates valid JSON.""" document = dict( @@ -94,8 +95,9 @@ def setUp(self): with open(self.SCHEMA_PATH) as f: self.schema = json.load(f) - @mock.patch(f'{serialize.__name__}.url_for', lambda *a, **k: 'http://f/12') - @mock.patch(f'{serialize.__name__}.jsonify', mock_jsonify) + @mock.patch(f'search.serialize.json.url_for', + lambda *a, **k: 'http://f/12') + @mock.patch(f'search.serialize.json.jsonify', mock_jsonify) def test_to_json(self): """Just your run-of-the-mill arXiv document generates valid JSON.""" document = dict( @@ -159,10 +161,11 @@ def test_to_json(self): jsonschema.validate(json.loads(srlzd), self.schema, resolver=res) ) + class TestSerializeAtomDocument(TestCase): """Serialize a single :class:`domain.Document` as Atom.""" - @mock.patch(f'{serialize.__name__}.url_for', lambda *a, **k: 'http://f/12') - @mock.patch(f'{serialize.__name__}.jsonify', mock_jsonify) + @mock.patch(f'search.serialize.atom.url_for', + lambda *a, **k: 'http://f/12') def test_to_atom(self): """Just your run-of-the-mill arXiv document generates valid Atom.""" document = dict( diff --git a/search/services/fulltext.py b/search/services/fulltext.py index a00bd9b2..a235ec2d 100644 --- a/search/services/fulltext.py +++ b/search/services/fulltext.py @@ -54,17 +54,17 @@ def retrieve(self, document_id: str) -> Fulltext: try: response = requests.get(urljoin(self.endpoint, document_id)) - except requests.exceptions.SSLError as e: - raise IOError('SSL failed: %s' % e) + except requests.exceptions.SSLError as ex: + raise IOError('SSL failed: %s' % ex) if response.status_code != status.HTTP_200_OK: raise IOError('%s: could not retrieve fulltext: %i' % (document_id, response.status_code)) try: data = response.json() - except json.decoder.JSONDecodeError as e: + except json.decoder.JSONDecodeError as ex: raise IOError('%s: could not decode response: %s' % - (document_id, e)) from e + (document_id, ex)) from ex return Fulltext(**data) # type: ignore # See https://github.com/python/mypy/issues/3937 diff --git a/search/services/index/__init__.py b/search/services/index/__init__.py index 7adb0c0e..86283976 100644 --- a/search/services/index/__init__.py +++ b/search/services/index/__init__.py @@ -62,35 +62,35 @@ def handle_es_exceptions() -> Generator: """Handle common ElasticSearch-related exceptions.""" try: yield - except TransportError as e: - if e.error == 'resource_already_exists_exception': + except TransportError as ex: + if ex.error == 'resource_already_exists_exception': logger.debug('Index already exists; move along') return - elif e.error == 'mapper_parsing_exception': - logger.error('ES mapper_parsing_exception: %s', e.info) - logger.debug(str(e.info)) - raise MappingError('Invalid mapping: %s' % str(e.info)) from e - elif e.error == 'index_not_found_exception': - logger.error('ES index_not_found_exception: %s', e.info) + elif ex.error == 'mapper_parsing_exception': + logger.error('ES mapper_parsing_exception: %s', ex.info) + logger.debug(str(ex.info)) + raise MappingError('Invalid mapping: %s' % str(ex.info)) from ex + elif ex.error == 'index_not_found_exception': + logger.error('ES index_not_found_exception: %s', ex.info) SearchSession.current_session().create_index() - elif e.error == 'parsing_exception': - logger.error('ES parsing_exception: %s', e.info) - raise QueryError(e.info) from e - elif e.status_code == 404: - logger.error('Caught NotFoundError: %s', e) + elif ex.error == 'parsing_exception': + logger.error('ES parsing_exception: %s', ex.info) + raise QueryError(ex.info) from ex + elif ex.status_code == 404: + logger.error('Caught NotFoundError: %s', ex) raise DocumentNotFound('No such document') - logger.error('Problem communicating with ES: %s' % e.error) + logger.error('Problem communicating with ES: %s' % ex.error) raise IndexConnectionError( - 'Problem communicating with ES: %s' % e.error - ) from e - except SerializationError as e: - logger.error("SerializationError: %s", e) - raise IndexingError('Problem serializing document: %s' % e) from e - except BulkIndexError as e: - logger.error("BulkIndexError: %s", e) - raise IndexingError('Problem with bulk indexing: %s' % e) from e - except Exception as e: - logger.error('Unhandled exception: %s') + 'Problem communicating with ES: %s' % ex.error + ) from ex + except SerializationError as ex: + logger.error("SerializationError: %s", ex) + raise IndexingError('Problem serializing document: %s' % ex) from ex + except BulkIndexError as ex: + logger.error("BulkIndexError: %s", ex) + raise IndexingError('Problem with bulk indexing: %s' % ex) from ex + except Exception as ex: + logger.error('Unhandled exception: %s' % ex) raise @@ -146,11 +146,11 @@ def new_connection(self) -> Elasticsearch: [self.conn_params], connection_class=Urllib3HttpConnection, **self.conn_extra) - except ElasticsearchException as e: - logger.error('ElasticsearchException: %s', e) + except ElasticsearchException as ex: + logger.error('ElasticsearchException: %s', ex) raise IndexConnectionError( - 'Could not initialize ES session: %s' % e - ) from e + 'Could not initialize ES session: %s' % ex + ) from ex return es def _base_search(self) -> Search: @@ -193,11 +193,11 @@ def cluster_available(self) -> bool: try: self.es.cluster.health(wait_for_status='yellow', request_timeout=1) return True - except urllib3.exceptions.HTTPError as e: - logger.debug('Health check failed: %s', str(e)) + except urllib3.exceptions.HTTPError as ex: + logger.debug('Health check failed: %s', str(ex)) return False - except Exception as e: - logger.debug('Health check failed: %s', str(e)) + except Exception as ex: + logger.debug('Health check failed: %s', str(ex)) return False def create_index(self) -> None: @@ -424,8 +424,8 @@ def search(self, query: Query, highlight: bool = True) -> DocumentSet: current_search = api_search(current_search, query) elif isinstance(query, ClassicAPIQuery): current_search = classic_search(current_search, query) - except TypeError as e: - raise e + except TypeError as ex: + raise ex # logger.error('Malformed query: %s', str(e)) # raise QueryError('Malformed query') from e diff --git a/search/services/metadata.py b/search/services/metadata.py index 4de6bfb9..3790d7d7 100644 --- a/search/services/metadata.py +++ b/search/services/metadata.py @@ -116,14 +116,14 @@ def retrieve(self, document_id: str) -> DocMeta: ) response = requests.get(target, verify=self._verify_cert, headers={'User-Agent': 'arXiv/system'}) - except requests.exceptions.SSLError as e: - logger.error('SSLError: %s', e) - raise SecurityException('SSL failed: %s' % e) from e - except requests.exceptions.ConnectionError as e: - logger.error('ConnectionError: %s', e) + except requests.exceptions.SSLError as ex: + logger.error('SSLError: %s', ex) + raise SecurityException('SSL failed: %s' % ex) from ex + except requests.exceptions.ConnectionError as ex: + logger.error('ConnectionError: %s', ex) raise ConnectionFailed( - 'Could not connect to metadata service: %s' % e - ) from e + 'Could not connect to metadata service: %s' % ex + ) from ex if response.status_code not in \ [status.HTTP_200_OK, status.HTTP_206_PARTIAL_CONTENT]: @@ -137,11 +137,11 @@ def retrieve(self, document_id: str) -> DocMeta: try: data = DocMeta(**response.json()) # type: ignore # See https://github.com/python/mypy/issues/3937 - except json.decoder.JSONDecodeError as e: - logger.error('JSONDecodeError: %s', e) + except json.decoder.JSONDecodeError as ex: + logger.error('JSONDecodeError: %s', ex) raise BadResponse( - '%s: could not decode response: %s' % (document_id, e) - ) from e + '%s: could not decode response: %s' % (document_id, ex) + ) from ex logger.debug(f'{document_id}: response decoded; done!') return data @@ -176,14 +176,14 @@ def bulk_retrieve(self, document_ids: List[str]) -> List[DocMeta]: f' verify {self._verify_cert}' ) response = self._session.get(target, verify=self._verify_cert) - except requests.exceptions.SSLError as e: - logger.error('SSLError: %s', e) - raise SecurityException('SSL failed: %s' % e) from e - except requests.exceptions.ConnectionError as e: - logger.error('ConnectionError: %s', e) + except requests.exceptions.SSLError as ex: + logger.error('SSLError: %s', ex) + raise SecurityException('SSL failed: %s' % ex) from ex + except requests.exceptions.ConnectionError as ex: + logger.error('ConnectionError: %s', ex) raise ConnectionFailed( - 'Could not connect to metadata service: %s' % e - ) from e + 'Could not connect to metadata service: %s' % ex + ) from ex if response.status_code not in \ [status.HTTP_200_OK, status.HTTP_206_PARTIAL_CONTENT]: @@ -198,11 +198,11 @@ def bulk_retrieve(self, document_ids: List[str]) -> List[DocMeta]: resp = response.json() # A list with metadata for each paper. data: List[DocMeta] data = [DocMeta(**value) for value in resp] # type: ignore - except json.decoder.JSONDecodeError as e: - logger.error('JSONDecodeError: %s', e) + except json.decoder.JSONDecodeError as ex: + logger.error('JSONDecodeError: %s', ex) raise BadResponse( - '%s: could not decode response: %s' % (document_ids, e) - ) from e + '%s: could not decode response: %s' % (document_ids, ex) + ) from ex logger.debug(f'{document_ids}: response decoded; done!') return data diff --git a/search/services/tests/test_fulltext.py b/search/services/tests/test_fulltext.py index 93ba9711..b66c5f6d 100644 --- a/search/services/tests/test_fulltext.py +++ b/search/services/tests/test_fulltext.py @@ -26,8 +26,8 @@ def test_calls_fulltext_endpoint(self, mock_get): try: fulltext_session.retrieve('1234.5678v3') - except Exception as e: - self.fail('Choked on valid response: %s' % e) + except Exception as ex: + self.fail('Choked on valid response: %s' % ex) args, _ = mock_get.call_args self.assertTrue(args[0].startswith(base)) @@ -63,8 +63,8 @@ def test_raise_ioerror_on_sslerror(self, mock_get): with self.assertRaises(IOError): try: fulltext.retrieve('1234.5678v3') - except Exception as e: - if type(e) is SSLError: + except Exception as ex: + if type(ex) is SSLError: self.fail('Should not return dependency exception') raise diff --git a/search/services/tests/test_metadata.py b/search/services/tests/test_metadata.py index 2e56d874..1ff31a61 100644 --- a/search/services/tests/test_metadata.py +++ b/search/services/tests/test_metadata.py @@ -34,12 +34,12 @@ def test_calls_metadata_endpoint(self, mock_get): try: docmeta_session.retrieve('1602.00123') - except Exception as e: - self.fail('Choked on valid response: %s' % e) + except Exception as ex: + self.fail('Choked on valid response: %s' % ex) try: args, _ = mock_get.call_args - except Exception as e: - self.fail('Did not call requests.get as expected: %s' % e) + except Exception as ex: + self.fail('Did not call requests.get as expected: %s' % ex) self.assertTrue(args[0].startswith(base)) @@ -64,24 +64,24 @@ def test_calls_metadata_endpoint_roundrobin(self, mock_get): try: docmeta_session.retrieve('1602.00123') - except Exception as e: - self.fail('Choked on valid response: %s' % e) + except Exception as ex: + self.fail('Choked on valid response: %s' % ex) try: args, _ = mock_get.call_args - except Exception as e: - self.fail('Did not call requests.get as expected: %s' % e) + except Exception as ex: + self.fail('Did not call requests.get as expected: %s' % ex) self.assertTrue( args[0].startswith(base[0]), "Expected call to %s" % base[0] ) try: docmeta_session.retrieve('1602.00124') - except Exception as e: - self.fail('Choked on valid response: %s' % e) + except Exception as ex: + self.fail('Choked on valid response: %s' % ex) try: args, _ = mock_get.call_args - except Exception as e: - self.fail('Did not call requests.get as expected: %s' % e) + except Exception as ex: + self.fail('Did not call requests.get as expected: %s' % ex) self.assertTrue( args[0].startswith(base[1]), "Expected call to %s" % base[1] ) @@ -118,8 +118,8 @@ def test_raise_ioerror_on_sslerror(self, mock_get): with self.assertRaises(IOError): try: metadata.retrieve('1234.5678v3') - except Exception as e: - if type(e) is SSLError: + except Exception as ex: + if type(ex) is SSLError: self.fail('Should not return dependency exception') raise