diff --git a/.coveragerc b/.coveragerc index dd0c164d..57687ab6 100644 --- a/.coveragerc +++ b/.coveragerc @@ -1,13 +1,20 @@ [run] omit = + api.py app.py + audit.py + classic_api.py setup.py docs/* - test* + *test* wsgi.py + wsgi-api.py + wsgi-classic-api.py + wsgi-app.py populate_test_metadata.py upload_static_assets.py create_index.py + reindex.py bulk_index.py shard_ids_for_index.py search/config.py diff --git a/.gitignore b/.gitignore index 8f5e8d4d..8140f1ae 100644 --- a/.gitignore +++ b/.gitignore @@ -101,6 +101,9 @@ ENV/ .vscode settings.json +# PyCharm +.idea + # mypy .mypy_cache/ @@ -110,3 +113,4 @@ temp/ to_index/ +.pytest_cache/ diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 00000000..6244f1d5 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,35 @@ +- repo: https://github.com/ambv/black + rev: stable + hooks: + - id: black + name: Format Python Code + language: python + entry: black + args: + - --safe + - --line-length=79 + - --target-version=py37 + - . + +- repo: https://github.com/PyCQA/flake8 + rev: 3.7.9 + hooks: + - id: flake8 + name: Flake8 Check + language: python + entry: flake8 + args: + - search + - tests + +- repo: https://github.com/pycqa/pydocstyle + rev: master + hooks: + - id: pydocstyle + name: Python Documentation Style Check + language: python + entry: pydocstyle + args: + - search + - tests + - --add-ignore=D401,D202 diff --git a/.pylintrc b/.pylintrc index 1ac11602..2a12e69b 100644 --- a/.pylintrc +++ b/.pylintrc @@ -3,7 +3,8 @@ # A comma-separated list of package or module names from where C extensions may # be loaded. Extensions are loading into the active Python interpreter and may # run arbitrary code -extension-pkg-whitelist= +extension-pkg-whitelist=lxml.etree +# lxml.etree is unfortunately a dependency of feedgen, and isn't great with pylint # Add files or directories to the blacklist. They should be base names, not # paths. 
diff --git a/.travis.yml b/.travis.yml index 75c2a015..57287a8b 100644 --- a/.travis.yml +++ b/.travis.yml @@ -11,13 +11,12 @@ python: - "3.6" script: - pip install pipenv - - pipenv install - - pipenv install --dev --skip-lock + - pipenv sync --dev + - pipenv run nose2 -vvv tests.base_app_tests - pipenv run nose2 -vvv --with-coverage + - "./lintstats.sh" after_success: - coveralls - - pipenv install pylint pydocstyle mypy - - ./lintstats.sh - docker login -u "$DOCKERHUB_USERNAME" -p "$DOCKERHUB_PASSWORD" - docker build . -t arxiv/search:${TRAVIS_COMMIT}; docker push arxiv/search:${TRAVIS_COMMIT} diff --git a/Dockerfile b/Dockerfile index fe54b384..61560695 100644 --- a/Dockerfile +++ b/Dockerfile @@ -37,7 +37,7 @@ ENV FLASK_APP /opt/arxiv/app.py ENV ELASTICSEARCH_SERVICE_HOST 127.0.0.1 ENV ELASTICSEARCH_SERVICE_PORT 9200 -ENV ELASTICSEARCH_PORT_9200_PROTO http +ENV ELASTICSEARCH_SERVICE_PORT_9200_PROTO http ENV ELASTICSEARCH_PASSWORD changeme ENV METADATA_ENDPOINT https://arxiv.org/docmeta_bulk/ diff --git a/Dockerfile-agent b/Dockerfile-agent index ebea1c2b..ffe3f74e 100644 --- a/Dockerfile-agent +++ b/Dockerfile-agent @@ -4,13 +4,13 @@ # article metadata becomes available. Subscribes to a Kinesis stream for # notifications about new metadata. 
-FROM arxiv/search:0.5.1 +FROM arxiv/search:0.5.6 WORKDIR /opt/arxiv ENV ELASTICSEARCH_SERVICE_HOST 127.0.0.1 ENV ELASTICSEARCH_SERVICE_PORT 9200 -ENV ELASTICSEARCH_PORT_9200_PROTO http +ENV ELASTICSEARCH_SERVICE_PORT_9200_PROTO http ENV ELASTICSEARCH_INDEX arxiv ENV ELASTICSEARCH_USER elastic ENV ELASTICSEARCH_PASSWORD changeme diff --git a/Dockerfile-api b/Dockerfile-api index 2f7e38ac..c99470a5 100644 --- a/Dockerfile-api +++ b/Dockerfile-api @@ -32,7 +32,7 @@ ENV FLASK_APP /opt/arxiv/app.py ENV ELASTICSEARCH_SERVICE_HOST 127.0.0.1 ENV ELASTICSEARCH_SERVICE_PORT 9200 -ENV ELASTICSEARCH_PORT_9200_PROTO http +ENV ELASTICSEARCH_SERVICE_PORT_9200_PROTO http ENV ELASTICSEARCH_INDEX arxiv ENV ELASTICSEARCH_USER elastic ENV ELASTICSEARCH_PASSWORD changeme diff --git a/Dockerfile-index b/Dockerfile-index index 41e9d4b8..3cc00f71 100644 --- a/Dockerfile-index +++ b/Dockerfile-index @@ -11,7 +11,7 @@ # $ cp arxiv_id_dump.txt /tmp/to_index # $ docker run -it --network=arxivsearch_es_stack \ # > -v /tmp/to_index:/to_index \ -# > -e ELASTICSEARCH_HOST=elasticsearch \ +# > -e ELASTICSEARCH_SERVICE_HOST=elasticsearch \ # > arxiv/search-index /to_index/arxiv_id_dump.txt # # See also ELASTICSEARCH_* and METADATA_ENDPOINT parameters, below. 
@@ -31,7 +31,7 @@ ENV FLASK_APP /opt/arxiv/app.py ENV ELASTICSEARCH_SERVICE_HOST 127.0.0.1 ENV ELASTICSEARCH_SERVICE_PORT 9200 -ENV ELASTICSEARCH_PORT_9200_PROTO http +ENV ELASTICSEARCH_SERVICE_PORT_9200_PROTO http ENV ELASTICSEARCH_USER elastic ENV ELASTICSEARCH_PASSWORD changeme ENV METADATA_ENDPOINT https://arxiv.org/docmeta_bulk/ diff --git a/Pipfile b/Pipfile index 2dec3fac..1b709059 100644 --- a/Pipfile +++ b/Pipfile @@ -5,7 +5,7 @@ name = "pypi" [packages] arxiv-auth = "==0.2.7" -arxiv-base = "==0.16.4" +arxiv-base = "==0.16.6" boto = "==2.48.0" "boto3" = "==1.6.6" botocore = "==1.9.6" @@ -15,9 +15,10 @@ click = "==6.7" coverage = "==4.4.2" dataclasses = "==0.4" docutils = "==0.14" -elasticsearch = ">=6.0.0,<7.0.0" -elasticsearch-dsl = ">=6.0.0,<7.0.0" -flask = "*" +elasticsearch = "==6.3.0" +elasticsearch-dsl = "==6.4.0" +feedgen = "==0.9.0" +flask = "==1.0.2" "flask-s3" = "==0.3.3" idna = "==2.6" ipaddress = "==1.0.19" @@ -25,10 +26,11 @@ itsdangerous = "==0.24" "jinja2" = ">=2.10.1" jmespath = "==0.9.3" jsonschema = "==2.6.0" +lark-parser = "==0.8.1" markupsafe = "==1.0" mccabe = "==0.6.1" mock = "==2.0.0" -mypy = "==0.670" +mypy = "==0.720" "nose2" = "==0.7.3" pbr = "==3.1.1" psutil = "==5.4.3" @@ -44,7 +46,7 @@ snowballstemmer = "==1.2.1" thrift = "==0.11.0" thrift-connector = "==0.23" "urllib3" = ">=1.23" -werkzeug = "*" +werkzeug = "~=0.14" wtforms = "==2.1" bleach = "*" lxml = "*" @@ -57,3 +59,10 @@ sphinx = "*" sphinxcontrib-websupport = "*" sphinx-autodoc-typehints = "*" pylint = "*" +pytest = "*" +nose = "*" +mypy = "==0.720" +pre-commit = "==2.0.1" + +[requires] +python_version = "3.6" diff --git a/Pipfile.lock b/Pipfile.lock index 4b422c85..70998a31 100644 --- a/Pipfile.lock +++ b/Pipfile.lock @@ -1,10 +1,12 @@ { "_meta": { "hash": { - "sha256": "3f518d857ecb1f7629ccc6987971a17bd0a1d18f63d96534ace30c664ac0e94c" + "sha256": "af2f18be213daa09b57f4fa5068be2eac4eefaa7b5dcbb04b211b068afd66f51" }, "pipfile-spec": 6, - "requires": {}, + "requires": { + 
"python_version": "3.6" + }, "sources": [ { "name": "pypi", @@ -23,10 +25,10 @@ }, "arxiv-base": { "hashes": [ - "sha256:863481f5c4ca2df38846be3560f4ce98e97c4b205f9029120500907fbb52b2e6" + "sha256:f7a5ff38c91bc3c8f40db500179394cbca1a3dd00d9c89c8e9e2ad63a6829664" ], "index": "pypi", - "version": "==0.16.4" + "version": "==0.16.6" }, "backports-datetime-fromisoformat": { "hashes": [ @@ -93,17 +95,13 @@ "coverage": { "hashes": [ "sha256:007eeef7e23f9473622f7d94a3e029a45d55a92a1f083f0f3512f5ab9a669b05", - "sha256:0388c12539372bb92d6dde68b4627f0300d948965bbb7fc104924d715fdc0965", "sha256:079248312838c4c8f3494934ab7382a42d42d5f365f0cf7516f938dbb3f53f3f", "sha256:17307429935f96c986a1b1674f78079528833410750321d22b5fb35d1883828e", - "sha256:1afccd7e27cac1b9617be8c769f6d8a6d363699c9b86820f40c74cfb3328921c", "sha256:2ad357d12971e77360034c1596011a03f50c0f9e1ecd12e081342b8d1aee2236", - "sha256:2b4d7f03a8a6632598cbc5df15bbca9f778c43db7cf1a838f4fa2c8599a8691a", "sha256:2e1a5c6adebb93c3b175103c2f855eda957283c10cf937d791d81bef8872d6ca", "sha256:309d91bd7a35063ec7a0e4d75645488bfab3f0b66373e7722f23da7f5b0f34cc", "sha256:358d635b1fc22a425444d52f26287ae5aea9e96e254ff3c59c407426f44574f4", "sha256:3f4d0b3403d3e110d2588c275540649b1841725f5a11a7162620224155d00ba2", - "sha256:43a155eb76025c61fc20c3d03b89ca28efa6f5be572ab6110b2fb68eda96bfea", "sha256:493082f104b5ca920e97a485913de254cbe351900deed72d4264571c73464cd0", "sha256:4c4f368ffe1c2e7602359c2c50233269f3abe1c48ca6b288dcd0fb1d1c679733", "sha256:5ff16548492e8a12e65ff3d55857ccd818584ed587a6c2898a9ebbe09a880674", @@ -114,11 +112,8 @@ "sha256:845fddf89dca1e94abe168760a38271abfc2e31863fbb4ada7f9a99337d7c3dc", "sha256:87d942863fe74b1c3be83a045996addf1639218c2cb89c5da18c06c0fe3917ea", "sha256:9721f1b7275d3112dc7ccf63f0553c769f09b5c25a26ee45872c7f5c09edf6c1", - "sha256:a4497faa4f1c0fc365ba05eaecfb6b5d24e3c8c72e95938f9524e29dadb15e76", "sha256:a7cfaebd8f24c2b537fa6a271229b051cdac9c1734bb6f939ccfc7c055689baa", - 
"sha256:ab3508df9a92c1d3362343d235420d08e2662969b83134f8a97dc1451cbe5e84", "sha256:b0059630ca5c6b297690a6bf57bf2fdac1395c24b7935fd73ee64190276b743b", - "sha256:b6cebae1502ce5b87d7c6f532fa90ab345cfbda62b95aeea4e431e164d498a3d", "sha256:bd4800e32b4c8d99c3a2c943f1ac430cbf80658d884123d19639bcde90dad44a", "sha256:cdd92dd9471e624cd1d8c1a2703d25f114b59b736b0f1f659a98414e535ffb3d", "sha256:d00e29b78ff610d300b2c37049a41234d48ea4f2d2581759ebcf67caaf731c31", @@ -130,8 +125,7 @@ "sha256:f29841e865590af72c4b90d7b5b8e93fd560f5dea436c1d5ee8053788f9285de", "sha256:f3a5c6d054c531536a83521c00e5d4004f1e126e2e2556ce399bef4180fbe540", "sha256:f87f522bde5540d8a4b11df80058281ac38c44b13ce29ced1e294963dd51a8f8", - "sha256:f8c55dd0f56d3d618dfacf129e010cbe5d5f94b6951c1b2f13ab1a2f79c284da", - "sha256:f98b461cb59f117887aa634a66022c0bd394278245ed51189f63a036516e32de" + "sha256:f8c55dd0f56d3d618dfacf129e010cbe5d5f94b6951c1b2f13ab1a2f79c284da" ], "index": "pypi", "version": "==4.4.2" @@ -162,11 +156,11 @@ }, "elasticsearch": { "hashes": [ - "sha256:1f0f633e3b500d5042424f75a505badf8c4b9962c1b4734cdfb3087fb67920be", - "sha256:fb5ab15ee283f104b5a7a5695c7e879cb2927e4eb5aed9c530811590b41259ad" + "sha256:24c93ba3bb078328c70137c31d9bfcfa152f61c3df64823b99b25307123611df", + "sha256:80ff7a1a56920535a9987da333c7e385b2ded27595b6de33860707dab758efbe" ], "index": "pypi", - "version": "==6.4.0" + "version": "==6.3.0" }, "elasticsearch-dsl": { "hashes": [ @@ -176,6 +170,13 @@ "index": "pypi", "version": "==6.4.0" }, + "feedgen": { + "hashes": [ + "sha256:8e811bdbbed6570034950db23a4388453628a70e689a6e8303ccec430f5a804a" + ], + "index": "pypi", + "version": "==0.9.0" + }, "flask": { "hashes": [ "sha256:13f9f196f330c7c2c5d7a5cf91af894110ca0215ac051b5844701f2bfd934d52", @@ -187,8 +188,7 @@ "flask-s3": { "hashes": [ "sha256:1d49061d4b78759df763358a901f4ed32bb43f672c9f8e1ec7226793f6ae0fd2", - "sha256:23cbbb1db4c29c313455dbe16f25be078d6318f0a11abcbb610f99e116945b62", - 
"sha256:d6e1fc3834f0be74c17e26bb8d0f506f711eb888775ab6af9164c0abb6f4c97c" + "sha256:23cbbb1db4c29c313455dbe16f25be078d6318f0a11abcbb610f99e116945b62" ], "index": "pypi", "version": "==0.3.3" @@ -217,11 +217,11 @@ }, "jinja2": { "hashes": [ - "sha256:74320bb91f31270f9551d46522e33af46a80c3d619f4a4bf42b3164d30b5911f", - "sha256:9fe95f19286cfefaa917656583d020be14e7859c6b0252588391e47db34527de" + "sha256:93187ffbc7808079673ef52771baa950426fd664d3aad1d0fa3e95644360e250", + "sha256:b0eaf100007721b5c16c1fc1eecb87409464edc10469ddc9a22a27a99123be49" ], "index": "pypi", - "version": "==2.10.3" + "version": "==2.11.1" }, "jmespath": { "hashes": [ @@ -239,37 +239,45 @@ "index": "pypi", "version": "==2.6.0" }, + "lark-parser": { + "hashes": [ + "sha256:b9f4e4711b0837d682051009d232e6df6b37a1278db4568a8221124bb8c1dc8a" + ], + "index": "pypi", + "version": "==0.8.1" + }, "lxml": { "hashes": [ - "sha256:00ac0d64949fef6b3693813fe636a2d56d97a5a49b5bbb86e4cc4cc50ebc9ea2", - "sha256:0571e607558665ed42e450d7bf0e2941d542c18e117b1ebbf0ba72f287ad841c", - "sha256:0e3f04a7615fdac0be5e18b2406529521d6dbdb0167d2a690ee328bef7807487", - "sha256:13cf89be53348d1c17b453867da68704802966c433b2bb4fa1f970daadd2ef70", - "sha256:217262fcf6a4c2e1c7cb1efa08bd9ebc432502abc6c255c4abab611e8be0d14d", - "sha256:223e544828f1955daaf4cefbb4853bc416b2ec3fd56d4f4204a8b17007c21250", - "sha256:277cb61fede2f95b9c61912fefb3d43fbd5f18bf18a14fae4911b67984486f5d", - "sha256:3213f753e8ae86c396e0e066866e64c6b04618e85c723b32ecb0909885211f74", - "sha256:4690984a4dee1033da0af6df0b7a6bde83f74e1c0c870623797cec77964de34d", - "sha256:4fcc472ef87f45c429d3b923b925704aa581f875d65bac80f8ab0c3296a63f78", - "sha256:61409bd745a265a742f2693e4600e4dbd45cc1daebe1d5fad6fcb22912d44145", - "sha256:678f1963f755c5d9f5f6968dded7b245dd1ece8cf53c1aa9d80e6734a8c7f41d", - "sha256:6c6d03549d4e2734133badb9ab1c05d9f0ef4bcd31d83e5d2b4747c85cfa21da", - "sha256:6e74d5f4d6ecd6942375c52ffcd35f4318a61a02328f6f1bd79fcb4ffedf969e", - 
"sha256:7b4fc7b1ecc987ca7aaf3f4f0e71bbfbd81aaabf87002558f5bc95da3a865bcd", - "sha256:7ed386a40e172ddf44c061ad74881d8622f791d9af0b6f5be20023029129bc85", - "sha256:8f54f0924d12c47a382c600c880770b5ebfc96c9fd94cf6f6bdc21caf6163ea7", - "sha256:ad9b81351fdc236bda538efa6879315448411a81186c836d4b80d6ca8217cdb9", - "sha256:bbd00e21ea17f7bcc58dccd13869d68441b32899e89cf6cfa90d624a9198ce85", - "sha256:c3c289762cc09735e2a8f8a49571d0e8b4f57ea831ea11558247b5bdea0ac4db", - "sha256:cf4650942de5e5685ad308e22bcafbccfe37c54aa7c0e30cd620c2ee5c93d336", - "sha256:cfcbc33c9c59c93776aa41ab02e55c288a042211708b72fdb518221cc803abc8", - "sha256:e301055deadfedbd80cf94f2f65ff23126b232b0d1fea28f332ce58137bcdb18", - "sha256:ebbfe24df7f7b5c6c7620702496b6419f6a9aa2fd7f005eb731cc80d7b4692b9", - "sha256:eff69ddbf3ad86375c344339371168640951c302450c5d3e9936e98d6459db06", - "sha256:f6ed60a62c5f1c44e789d2cf14009423cb1646b44a43e40a9cf6a21f077678a1" + "sha256:06d4e0bbb1d62e38ae6118406d7cdb4693a3fa34ee3762238bcb96c9e36a93cd", + "sha256:0701f7965903a1c3f6f09328c1278ac0eee8f56f244e66af79cb224b7ef3801c", + "sha256:1f2c4ec372bf1c4a2c7e4bb20845e8bcf8050365189d86806bad1e3ae473d081", + "sha256:4235bc124fdcf611d02047d7034164897ade13046bda967768836629bc62784f", + "sha256:5828c7f3e615f3975d48f40d4fe66e8a7b25f16b5e5705ffe1d22e43fb1f6261", + "sha256:585c0869f75577ac7a8ff38d08f7aac9033da2c41c11352ebf86a04652758b7a", + "sha256:5d467ce9c5d35b3bcc7172c06320dddb275fea6ac2037f72f0a4d7472035cea9", + "sha256:63dbc21efd7e822c11d5ddbedbbb08cd11a41e0032e382a0fd59b0b08e405a3a", + "sha256:7bc1b221e7867f2e7ff1933165c0cec7153dce93d0cdba6554b42a8beb687bdb", + "sha256:8620ce80f50d023d414183bf90cc2576c2837b88e00bea3f33ad2630133bbb60", + "sha256:8a0ebda56ebca1a83eb2d1ac266649b80af8dd4b4a3502b2c1e09ac2f88fe128", + "sha256:90ed0e36455a81b25b7034038e40880189169c308a3df360861ad74da7b68c1a", + "sha256:95e67224815ef86924fbc2b71a9dbd1f7262384bca4bc4793645794ac4200717", + "sha256:afdb34b715daf814d1abea0317b6d672476b498472f1e5aacbadc34ebbc26e89", 
+ "sha256:b4b2c63cc7963aedd08a5f5a454c9f67251b1ac9e22fd9d72836206c42dc2a72", + "sha256:d068f55bda3c2c3fcaec24bd083d9e2eede32c583faf084d6e4b9daaea77dde8", + "sha256:d5b3c4b7edd2e770375a01139be11307f04341ec709cf724e0f26ebb1eef12c3", + "sha256:deadf4df349d1dcd7b2853a2c8796593cc346600726eff680ed8ed11812382a7", + "sha256:df533af6f88080419c5a604d0d63b2c33b1c0c4409aba7d0cb6de305147ea8c8", + "sha256:e4aa948eb15018a657702fee0b9db47e908491c64d36b4a90f59a64741516e77", + "sha256:e5d842c73e4ef6ed8c1bd77806bf84a7cb535f9c0cf9b2c74d02ebda310070e1", + "sha256:ebec08091a22c2be870890913bdadd86fcd8e9f0f22bcb398abd3af914690c15", + "sha256:edc15fcfd77395e24543be48871c251f38132bb834d9fdfdad756adb6ea37679", + "sha256:f2b74784ed7e0bc2d02bd53e48ad6ba523c9b36c194260b7a5045071abbb1012", + "sha256:fa071559f14bd1e92077b1b5f6c22cf09756c6de7139370249eb372854ce51e6", + "sha256:fd52e796fee7171c4361d441796b64df1acfceb51f29e545e812f16d023c4bbc", + "sha256:fe976a0f1ef09b3638778024ab9fb8cde3118f203364212c198f71341c0715ca" ], "index": "pypi", - "version": "==4.4.2" + "version": "==4.5.0" }, "markupsafe": { "hashes": [ @@ -296,11 +304,20 @@ }, "mypy": { "hashes": [ - "sha256:308c274eb8482fbf16006f549137ddc0d69e5a589465e37b99c4564414363ca7", - "sha256:e80fd6af34614a0e898a57f14296d0dacb584648f0339c2e000ddbf0f4cc2f8d" + "sha256:0107bff4f46a289f0e4081d59b77cef1c48ea43da5a0dbf0005d54748b26df2a", + "sha256:07957f5471b3bb768c61f08690c96d8a09be0912185a27a68700f3ede99184e4", + "sha256:10af62f87b6921eac50271e667cc234162a194e742d8e02fc4ddc121e129a5b0", + "sha256:11fd60d2f69f0cefbe53ce551acf5b1cec1a89e7ce2d47b4e95a84eefb2899ae", + "sha256:15e43d3b1546813669bd1a6ec7e6a11d2888db938e0607f7b5eef6b976671339", + "sha256:352c24ba054a89bb9a35dd064ee95ab9b12903b56c72a8d3863d882e2632dc76", + "sha256:437020a39417e85e22ea8edcb709612903a9924209e10b3ec6d8c9f05b79f498", + "sha256:49925f9da7cee47eebf3420d7c0e00ec662ec6abb2780eb0a16260a7ba25f9c4", + "sha256:6724fcd5777aa6cebfa7e644c526888c9d639bd22edd26b2a8038c674a7c34bd", + 
"sha256:7a17613f7ea374ab64f39f03257f22b5755335b73251d0d253687a69029701ba", + "sha256:cdc1151ced496ca1496272da7fc356580e95f2682be1d32377c22ddebdf73c91" ], "index": "pypi", - "version": "==0.670" + "version": "==0.720" }, "mypy-extensions": { "hashes": [ @@ -351,10 +368,10 @@ }, "py": { "hashes": [ - "sha256:64f65755aee5b381cea27766a3a147c3f15b9b6b9ac88676de66ba2ae36793fa", - "sha256:dc639b046a6e2cff5bbe40194ad65936d6ba360b52b3c3fe1d08a82dd50b5e53" + "sha256:5e27081401262157467ad6e7f851b7aa402c5852dbcb3dae06768434de5752aa", + "sha256:c20fdd83a5dbc0af9efd622bee9a5564e278f6380fffcacc43ba6f43db2813b0" ], - "version": "==1.8.0" + "version": "==1.8.1" }, "pycodestyle": { "hashes": [ @@ -412,15 +429,8 @@ }, "pytz": { "hashes": [ - "sha256:59707844a9825589878236ff2f4e0dc9958511b7ffaae94dc615da07d4a68d33", - "sha256:699d18a2a56f19ee5698ab1123bbcc1d269d061996aeb1eda6d89248d3542b82", - "sha256:80af0f3008046b9975242012a985f04c5df1f01eed4ec1633d56cc47a75a6a48", - "sha256:8cc90340159b5d7ced6f2ba77694d946fc975b09f1a51d93f3ce3bb399396f94", "sha256:c41c62827ce9cafacd6f2f7018e4f83a6f1986e87bfd000b8cfbd4ab5da95f1a", - "sha256:d0ef5ef55ed3d37854320d4926b04a4cb42a2e88f71da9ddfdacfde8e364f027", - "sha256:dd2e4ca6ce3785c8dd342d1853dd9052b19290d5bf66060846e5dc6b8d6667f7", - "sha256:fae4cffc040921b8a2d60c6cf0b5d662c1190fe54d718271db4eb17d44a185b7", - "sha256:feb2365914948b8620347784b6b6da356f31c9d03560259070b2f30cff3d469d" + "sha256:fae4cffc040921b8a2d60c6cf0b5d662c1190fe54d718271db4eb17d44a185b7" ], "index": "pypi", "version": "==2017.3" @@ -440,11 +450,11 @@ }, "requests": { "hashes": [ - "sha256:11e007a8a2aa0323f5a921e9e6a2d7e4e67d9877e85773fba9ba6419025cbeb4", - "sha256:9cf5292fcd0f598c671cfc1e0d7d1a7f13bb8085e9a590f48c010551dc6c4b31" + "sha256:43999036bfa82904b6af1d99e4882b560e5e2c68e5c4b0aa03b655f3d7d73fee", + "sha256:b3f43d496c6daba4493e7c431722aeb7dbc6288f52a6e04e7b6023b0247817e6" ], "index": "pypi", - "version": "==2.22.0" + "version": "==2.23.0" }, "retry": { "hashes": [ @@ 
-464,10 +474,10 @@ }, "six": { "hashes": [ - "sha256:1f1b7d42e254082a9db6279deae68afb421ceba6158efa6131de7b3003ee93fd", - "sha256:30f610279e8b2578cab6db20741130331735c781b56053c59c4076da27f06b66" + "sha256:236bdbdce46e6e6a3d61a337c0f8b763ca1e8717c03b369e87a7ec7ce1319c0a", + "sha256:8f3cd2e254d8f793e7f3d6d9df77b92252b52637291d0f0da013c76ea2724b6c" ], - "version": "==1.13.0" + "version": "==1.14.0" }, "snowballstemmer": { "hashes": [ @@ -479,9 +489,9 @@ }, "sqlalchemy": { "hashes": [ - "sha256:bfb8f464a5000b567ac1d350b9090cf081180ec1ab4aa87e7bca12dab25320ec" + "sha256:64a7b71846db6423807e96820993fa12a03b89127d278290ca25c0b11ed7b4fb" ], - "version": "==1.3.12" + "version": "==1.3.13" }, "thrift": { "hashes": [ @@ -500,27 +510,29 @@ }, "typed-ast": { "hashes": [ - "sha256:132eae51d6ef3ff4a8c47c393a4ef5ebf0d1aecc96880eb5d6c8ceab7017cc9b", - "sha256:18141c1484ab8784006c839be8b985cfc82a2e9725837b0ecfa0203f71c4e39d", - "sha256:2baf617f5bbbfe73fd8846463f5aeafc912b5ee247f410700245d68525ec584a", - "sha256:3d90063f2cbbe39177e9b4d888e45777012652d6110156845b828908c51ae462", - "sha256:4304b2218b842d610aa1a1d87e1dc9559597969acc62ce717ee4dfeaa44d7eee", - "sha256:4983ede548ffc3541bae49a82675996497348e55bafd1554dc4e4a5d6eda541a", - "sha256:5315f4509c1476718a4825f45a203b82d7fdf2a6f5f0c8f166435975b1c9f7d4", - "sha256:6cdfb1b49d5345f7c2b90d638822d16ba62dc82f7616e9b4caa10b72f3f16649", - "sha256:7b325f12635598c604690efd7a0197d0b94b7d7778498e76e0710cd582fd1c7a", - "sha256:8d3b0e3b8626615826f9a626548057c5275a9733512b137984a68ba1598d3d2f", - "sha256:8f8631160c79f53081bd23446525db0bc4c5616f78d04021e6e434b286493fd7", - "sha256:912de10965f3dc89da23936f1cc4ed60764f712e5fa603a09dd904f88c996760", - "sha256:b010c07b975fe853c65d7bbe9d4ac62f1c69086750a574f6292597763781ba18", - "sha256:c908c10505904c48081a5415a1e295d8403e353e0c14c42b6d67f8f97fae6616", - "sha256:c94dd3807c0c0610f7c76f078119f4ea48235a953512752b9175f9f98f5ae2bd", - 
"sha256:ce65dee7594a84c466e79d7fb7d3303e7295d16a83c22c7c4037071b059e2c21", - "sha256:eaa9cfcb221a8a4c2889be6f93da141ac777eb8819f077e1d09fb12d00a09a93", - "sha256:f3376bc31bad66d46d44b4e6522c5c21976bf9bca4ef5987bb2bf727f4506cbb", - "sha256:f9202fa138544e13a4ec1a6792c35834250a85958fde1251b6a22e07d1260ae7" - ], - "version": "==1.3.5" + "sha256:0666aa36131496aed8f7be0410ff974562ab7eeac11ef351def9ea6fa28f6355", + "sha256:0c2c07682d61a629b68433afb159376e24e5b2fd4641d35424e462169c0a7919", + "sha256:249862707802d40f7f29f6e1aad8d84b5aa9e44552d2cc17384b209f091276aa", + "sha256:24995c843eb0ad11a4527b026b4dde3da70e1f2d8806c99b7b4a7cf491612652", + "sha256:269151951236b0f9a6f04015a9004084a5ab0d5f19b57de779f908621e7d8b75", + "sha256:4083861b0aa07990b619bd7ddc365eb7fa4b817e99cf5f8d9cf21a42780f6e01", + "sha256:498b0f36cc7054c1fead3d7fc59d2150f4d5c6c56ba7fb150c013fbc683a8d2d", + "sha256:4e3e5da80ccbebfff202a67bf900d081906c358ccc3d5e3c8aea42fdfdfd51c1", + "sha256:6daac9731f172c2a22ade6ed0c00197ee7cc1221aa84cfdf9c31defeb059a907", + "sha256:715ff2f2df46121071622063fc7543d9b1fd19ebfc4f5c8895af64a77a8c852c", + "sha256:73d785a950fc82dd2a25897d525d003f6378d1cb23ab305578394694202a58c3", + "sha256:8c8aaad94455178e3187ab22c8b01a3837f8ee50e09cf31f1ba129eb293ec30b", + "sha256:8ce678dbaf790dbdb3eba24056d5364fb45944f33553dd5869b7580cdbb83614", + "sha256:aaee9905aee35ba5905cfb3c62f3e83b3bec7b39413f0a7f19be4e547ea01ebb", + "sha256:bcd3b13b56ea479b3650b82cabd6b5343a625b0ced5429e4ccad28a8973f301b", + "sha256:c9e348e02e4d2b4a8b2eedb48210430658df6951fa484e59de33ff773fbd4b41", + "sha256:d205b1b46085271b4e15f670058ce182bd1199e56b317bf2ec004b6a44f911f6", + "sha256:d43943ef777f9a1c42bf4e552ba23ac77a6351de620aa9acf64ad54933ad4d34", + "sha256:d5d33e9e7af3b34a40dc05f498939f0ebf187f07c385fd58d591c533ad8562fe", + "sha256:fc0fea399acb12edbf8a628ba8d2312f583bdbdb3335635db062fa98cf71fca4", + "sha256:fe460b922ec15dd205595c9b5b99e2f056fd98ae8f9f56b888e7a17dc2b757e7" + ], + "version": "==1.4.1" }, 
"typing-extensions": { "hashes": [ @@ -532,11 +544,11 @@ }, "urllib3": { "hashes": [ - "sha256:a8a318824cc77d1fd4b2bec2ded92646630d7fe8619497b142c84a9e6f5a7293", - "sha256:f3c5fd51747d450d4dcf6f923c81f78f811aab8205fda64b0aba34a4e48b0745" + "sha256:2f3db8b19923a873b3e5256dc9c2dedfa883e33d87c690d9c7913e1f40673cdc", + "sha256:87716c2d2a7121198ebcb7ce7cccf6ce5e9ba539041cfbaeecfb641dc0bf6acc" ], "index": "pypi", - "version": "==1.25.7" + "version": "==1.25.8" }, "uwsgi": { "hashes": [ @@ -553,11 +565,11 @@ }, "werkzeug": { "hashes": [ - "sha256:7280924747b5733b246fe23972186c6b348f9ae29724135a6dfc1e53cea433e7", - "sha256:e5f4a1f98b52b18a93da705a7458e55afb26f32bff83ff5d19189f92462d65c4" + "sha256:1e0dedc2acb1f46827daa2e399c1485c8fa17c0d8e70b6b875b4e7f54bf408d2", + "sha256:b353856d37dec59d6511359f97f6a4b2468442e454bd1c98298ddce53cac1f04" ], "index": "pypi", - "version": "==0.16.0" + "version": "==0.16.1" }, "wtforms": { "hashes": [ @@ -575,6 +587,20 @@ ], "version": "==0.7.12" }, + "appdirs": { + "hashes": [ + "sha256:9e5896d1372858f8dd3344faf4e5014d21849c756c8d5701f78f8a103b372d92", + "sha256:d8b24664561d0d34ddfaec54636d502d7cea6e29c3eaf68f3df6180863e2166e" + ], + "version": "==1.4.3" + }, + "aspy.yaml": { + "hashes": [ + "sha256:463372c043f70160a9ec950c3f1e4c3a82db5fca01d334b6bc89c7164d744bdc", + "sha256:e7c742382eff2caed61f87a39d13f99109088e5e93f04d76eb8d4b28aa143f45" + ], + "version": "==1.3.0" + }, "astroid": { "hashes": [ "sha256:71ea07f44df9568a75d0f354c49143a4575d90645e9fead6dfb52c26a85ed13a", @@ -582,12 +608,19 @@ ], "version": "==2.3.3" }, + "attrs": { + "hashes": [ + "sha256:08a96c641c3a74e44eb59afb61a24f2cb9f4d7188748e76ba4bb5edfa3cb7d1c", + "sha256:f7b7ce16570fe9965acd6d30101a28f62fb4a7f9e926b3bbc9b61f8b04247e72" + ], + "version": "==19.3.0" + }, "babel": { "hashes": [ - "sha256:af92e6106cb7c55286b25b38ad7695f8b4efb36a90ba483d7f7a6628c46158ab", - "sha256:e86135ae101e31e2c8ec20a4e0c5220f4eed12487d5cf3f78be7e98d3a57fc28" + 
"sha256:1aac2ae2d0d8ea368fa90906567f5c08463d98ade155c0c4bfedd6a0f7160e38", + "sha256:d670ea0b10f8b723672d3a6abeb87b565b244da220d76b4dba1b66269ec152d4" ], - "version": "==2.7.0" + "version": "==2.8.0" }, "certifi": { "hashes": [ @@ -597,6 +630,13 @@ "index": "pypi", "version": "==2017.7.27.1" }, + "cfgv": { + "hashes": [ + "sha256:04b093b14ddf9fd4d17c53ebfd55582d27b76ed30050193c14e560770c5360eb", + "sha256:f22b426ed59cd2ab2b54ff96608d846c33dfb8766a67f0b4a6ce130ce244414f" + ], + "version": "==3.0.0" + }, "chardet": { "hashes": [ "sha256:84ab92ed1c4d4f16916e05906b6b75a6c0fb5db821cc65e70cbd64a3e2a5eaae", @@ -608,17 +648,13 @@ "coverage": { "hashes": [ "sha256:007eeef7e23f9473622f7d94a3e029a45d55a92a1f083f0f3512f5ab9a669b05", - "sha256:0388c12539372bb92d6dde68b4627f0300d948965bbb7fc104924d715fdc0965", "sha256:079248312838c4c8f3494934ab7382a42d42d5f365f0cf7516f938dbb3f53f3f", "sha256:17307429935f96c986a1b1674f78079528833410750321d22b5fb35d1883828e", - "sha256:1afccd7e27cac1b9617be8c769f6d8a6d363699c9b86820f40c74cfb3328921c", "sha256:2ad357d12971e77360034c1596011a03f50c0f9e1ecd12e081342b8d1aee2236", - "sha256:2b4d7f03a8a6632598cbc5df15bbca9f778c43db7cf1a838f4fa2c8599a8691a", "sha256:2e1a5c6adebb93c3b175103c2f855eda957283c10cf937d791d81bef8872d6ca", "sha256:309d91bd7a35063ec7a0e4d75645488bfab3f0b66373e7722f23da7f5b0f34cc", "sha256:358d635b1fc22a425444d52f26287ae5aea9e96e254ff3c59c407426f44574f4", "sha256:3f4d0b3403d3e110d2588c275540649b1841725f5a11a7162620224155d00ba2", - "sha256:43a155eb76025c61fc20c3d03b89ca28efa6f5be572ab6110b2fb68eda96bfea", "sha256:493082f104b5ca920e97a485913de254cbe351900deed72d4264571c73464cd0", "sha256:4c4f368ffe1c2e7602359c2c50233269f3abe1c48ca6b288dcd0fb1d1c679733", "sha256:5ff16548492e8a12e65ff3d55857ccd818584ed587a6c2898a9ebbe09a880674", @@ -629,11 +665,8 @@ "sha256:845fddf89dca1e94abe168760a38271abfc2e31863fbb4ada7f9a99337d7c3dc", "sha256:87d942863fe74b1c3be83a045996addf1639218c2cb89c5da18c06c0fe3917ea", 
"sha256:9721f1b7275d3112dc7ccf63f0553c769f09b5c25a26ee45872c7f5c09edf6c1", - "sha256:a4497faa4f1c0fc365ba05eaecfb6b5d24e3c8c72e95938f9524e29dadb15e76", "sha256:a7cfaebd8f24c2b537fa6a271229b051cdac9c1734bb6f939ccfc7c055689baa", - "sha256:ab3508df9a92c1d3362343d235420d08e2662969b83134f8a97dc1451cbe5e84", "sha256:b0059630ca5c6b297690a6bf57bf2fdac1395c24b7935fd73ee64190276b743b", - "sha256:b6cebae1502ce5b87d7c6f532fa90ab345cfbda62b95aeea4e431e164d498a3d", "sha256:bd4800e32b4c8d99c3a2c943f1ac430cbf80658d884123d19639bcde90dad44a", "sha256:cdd92dd9471e624cd1d8c1a2703d25f114b59b736b0f1f659a98414e535ffb3d", "sha256:d00e29b78ff610d300b2c37049a41234d48ea4f2d2581759ebcf67caaf731c31", @@ -645,19 +678,24 @@ "sha256:f29841e865590af72c4b90d7b5b8e93fd560f5dea436c1d5ee8053788f9285de", "sha256:f3a5c6d054c531536a83521c00e5d4004f1e126e2e2556ce399bef4180fbe540", "sha256:f87f522bde5540d8a4b11df80058281ac38c44b13ce29ced1e294963dd51a8f8", - "sha256:f8c55dd0f56d3d618dfacf129e010cbe5d5f94b6951c1b2f13ab1a2f79c284da", - "sha256:f98b461cb59f117887aa634a66022c0bd394278245ed51189f63a036516e32de" + "sha256:f8c55dd0f56d3d618dfacf129e010cbe5d5f94b6951c1b2f13ab1a2f79c284da" ], "index": "pypi", "version": "==4.4.2" }, "coveralls": { "hashes": [ - "sha256:25522a50cdf720d956601ca6ef480786e655ae2f0c94270c77e1a23d742de558", - "sha256:8e3315e8620bb6b3c6f3179a75f498e7179c93b3ddc440352404f941b1f70524" + "sha256:4b6bfc2a2a77b890f556bc631e35ba1ac21193c356393b66c84465c06218e135", + "sha256:67188c7ec630c5f708c31552f2bcdac4580e172219897c4136504f14b823132f" ], "index": "pypi", - "version": "==1.9.2" + "version": "==1.11.1" + }, + "distlib": { + "hashes": [ + "sha256:2e166e231a26b36d6dfe35a48c4464346620f8645ed0ace01ee31822b288de21" + ], + "version": "==0.3.0" }, "docopt": { "hashes": [ @@ -674,6 +712,20 @@ "index": "pypi", "version": "==0.14" }, + "filelock": { + "hashes": [ + "sha256:18d82244ee114f543149c66a6e0c14e9c4f8a1044b5cdaadd0f82159d6a6ff59", + 
"sha256:929b7d63ec5b7d6b71b0fa5ac14e030b3f70b75747cef1b10da9b879fef15836" + ], + "version": "==3.0.12" + }, + "identify": { + "hashes": [ + "sha256:1222b648251bdcb8deb240b294f450fbf704c7984e08baa92507e4ea10b436d5", + "sha256:d824ebe21f38325c771c41b08a95a761db1982f1fc0eee37c6c97df3f1636b96" + ], + "version": "==1.4.11" + }, "idna": { "hashes": [ "sha256:2c6a5de3089009e3da7c5dde64a141dbc8551d5b7f6cf4ed7c2568d0cc520a8f", @@ -684,10 +736,26 @@ }, "imagesize": { "hashes": [ - "sha256:3f349de3eb99145973fefb7dbe38554414e5c30abd0c8e4b970a7c9d09f3a1d8", - "sha256:f3832918bc3c66617f92e35f5d70729187676313caa60c187eb0f28b8fe5e3b5" + "sha256:6965f19a6a2039c7d48bca7dba2473069ff854c36ae6f19d2cde309d998228a1", + "sha256:b1f6b5a4eab1f73479a50fb79fcf729514a900c341d8503d62a62dbc4127a2b1" + ], + "version": "==1.2.0" + }, + "importlib-metadata": { + "hashes": [ + "sha256:06f5b3a99029c7134207dd882428a66992a9de2bef7c2b699b5641f9886c3302", + "sha256:b97607a1a18a5100839aec1dc26a1ea17ee0d93b20b0f008d80a5a050afb200b" + ], + "markers": "python_version < '3.8'", + "version": "==1.5.0" + }, + "importlib-resources": { + "hashes": [ + "sha256:6e2783b2538bd5a14678284a3962b0660c715e5a0f10243fd5e00a4b5974f50b", + "sha256:d3279fd0f6f847cced9f7acc19bd3e5df54d34f93a2e7bb5f238f81545787078" ], - "version": "==1.1.0" + "markers": "python_version < '3.7'", + "version": "==1.0.2" }, "isort": { "hashes": [ @@ -698,11 +766,11 @@ }, "jinja2": { "hashes": [ - "sha256:74320bb91f31270f9551d46522e33af46a80c3d619f4a4bf42b3164d30b5911f", - "sha256:9fe95f19286cfefaa917656583d020be14e7859c6b0252588391e47db34527de" + "sha256:93187ffbc7808079673ef52771baa950426fd664d3aad1d0fa3e95644360e250", + "sha256:b0eaf100007721b5c16c1fc1eecb87409464edc10469ddc9a22a27a99123be49" ], "index": "pypi", - "version": "==2.10.3" + "version": "==2.11.1" }, "lazy-object-proxy": { "hashes": [ @@ -745,12 +813,81 @@ "index": "pypi", "version": "==0.6.1" }, + "more-itertools": { + "hashes": [ + 
"sha256:5dd8bcf33e5f9513ffa06d5ad33d78f31e1931ac9a18f33d37e77a180d393a7c", + "sha256:b1ddb932186d8a6ac451e1d95844b382f55e12686d51ca0c68b6f61f2ab7a507" + ], + "version": "==8.2.0" + }, + "mypy": { + "hashes": [ + "sha256:0107bff4f46a289f0e4081d59b77cef1c48ea43da5a0dbf0005d54748b26df2a", + "sha256:07957f5471b3bb768c61f08690c96d8a09be0912185a27a68700f3ede99184e4", + "sha256:10af62f87b6921eac50271e667cc234162a194e742d8e02fc4ddc121e129a5b0", + "sha256:11fd60d2f69f0cefbe53ce551acf5b1cec1a89e7ce2d47b4e95a84eefb2899ae", + "sha256:15e43d3b1546813669bd1a6ec7e6a11d2888db938e0607f7b5eef6b976671339", + "sha256:352c24ba054a89bb9a35dd064ee95ab9b12903b56c72a8d3863d882e2632dc76", + "sha256:437020a39417e85e22ea8edcb709612903a9924209e10b3ec6d8c9f05b79f498", + "sha256:49925f9da7cee47eebf3420d7c0e00ec662ec6abb2780eb0a16260a7ba25f9c4", + "sha256:6724fcd5777aa6cebfa7e644c526888c9d639bd22edd26b2a8038c674a7c34bd", + "sha256:7a17613f7ea374ab64f39f03257f22b5755335b73251d0d253687a69029701ba", + "sha256:cdc1151ced496ca1496272da7fc356580e95f2682be1d32377c22ddebdf73c91" + ], + "index": "pypi", + "version": "==0.720" + }, + "mypy-extensions": { + "hashes": [ + "sha256:090fedd75945a69ae91ce1303b5824f428daf5a028d2f6ab8a299250a846f15d", + "sha256:2d82818f5bb3e369420cb3c4060a7970edba416647068eb4c5343488a6c604a8" + ], + "index": "pypi", + "version": "==0.4.3" + }, + "nodeenv": { + "hashes": [ + "sha256:5b2438f2e42af54ca968dd1b374d14a1194848955187b0e5e4be1f73813a5212" + ], + "version": "==1.3.5" + }, + "nose": { + "hashes": [ + "sha256:9ff7c6cc443f8c51994b34a667bbcf45afd6d945be7477b52e97516fd17c53ac", + "sha256:dadcddc0aefbf99eea214e0f1232b94f2fa9bd98fa8353711dacb112bfcbbb2a", + "sha256:f1bffef9cbc82628f6e7d7b40d7e255aefaa1adb6a1b1d26c69a8b79e6208a98" + ], + "index": "pypi", + "version": "==1.3.7" + }, "packaging": { "hashes": [ - "sha256:28b924174df7a2fa32c1953825ff29c61e2f5e082343165438812f00d3a7fc47", - "sha256:d9551545c6d761f3def1677baf08ab2a3ca17c56879e70fecba2fc4dde4ed108" + 
"sha256:170748228214b70b672c581a3dd610ee51f733018650740e98c7df862a583f73", + "sha256:e665345f9eef0c621aa0bf2f8d78cf6d21904eef16a93f020240b704a57f1334" + ], + "version": "==20.1" + }, + "pluggy": { + "hashes": [ + "sha256:15b2acde666561e1298d71b523007ed7364de07029219b604cf808bfa1c765b0", + "sha256:966c145cd83c96502c3c3868f50408687b38434af77734af1e9ca461a4081d2d" ], - "version": "==19.2" + "version": "==0.13.1" + }, + "pre-commit": { + "hashes": [ + "sha256:0385479a0fe0765b1d32241f6b5358668cb4b6496a09aaf9c79acc6530489dbb", + "sha256:bf80d9dd58bea4f45d5d71845456fdcb78c1027eda9ed562db6fa2bd7a680c3a" + ], + "index": "pypi", + "version": "==2.0.1" + }, + "py": { + "hashes": [ + "sha256:5e27081401262157467ad6e7f851b7aa402c5852dbcb3dae06768434de5752aa", + "sha256:c20fdd83a5dbc0af9efd622bee9a5564e278f6380fffcacc43ba6f43db2813b0" + ], + "version": "==1.8.1" }, "pygments": { "hashes": [ @@ -769,40 +906,57 @@ }, "pyparsing": { "hashes": [ - "sha256:20f995ecd72f2a1f4bf6b072b63b22e2eb457836601e76d6e5dfcd75436acc1f", - "sha256:4ca62001be367f01bd3e92ecbb79070272a9d4964dce6a48a82ff0b8bc7e683a" + "sha256:4c830582a84fb022400b85429791bc551f1f4871c33f23e44f353119e92f969f", + "sha256:c342dccb5250c08d45fd6f8b4a559613ca603b57498511740e65cd11a2e7dcec" + ], + "version": "==2.4.6" + }, + "pytest": { + "hashes": [ + "sha256:0d5fe9189a148acc3c3eb2ac8e1ac0742cb7618c084f3d228baaec0c254b318d", + "sha256:ff615c761e25eb25df19edddc0b970302d2a9091fbce0e7213298d85fb61fef6" ], - "version": "==2.4.5" + "index": "pypi", + "version": "==5.3.5" }, "pytz": { "hashes": [ - "sha256:59707844a9825589878236ff2f4e0dc9958511b7ffaae94dc615da07d4a68d33", - "sha256:699d18a2a56f19ee5698ab1123bbcc1d269d061996aeb1eda6d89248d3542b82", - "sha256:80af0f3008046b9975242012a985f04c5df1f01eed4ec1633d56cc47a75a6a48", - "sha256:8cc90340159b5d7ced6f2ba77694d946fc975b09f1a51d93f3ce3bb399396f94", "sha256:c41c62827ce9cafacd6f2f7018e4f83a6f1986e87bfd000b8cfbd4ab5da95f1a", - 
"sha256:d0ef5ef55ed3d37854320d4926b04a4cb42a2e88f71da9ddfdacfde8e364f027", - "sha256:dd2e4ca6ce3785c8dd342d1853dd9052b19290d5bf66060846e5dc6b8d6667f7", - "sha256:fae4cffc040921b8a2d60c6cf0b5d662c1190fe54d718271db4eb17d44a185b7", - "sha256:feb2365914948b8620347784b6b6da356f31c9d03560259070b2f30cff3d469d" + "sha256:fae4cffc040921b8a2d60c6cf0b5d662c1190fe54d718271db4eb17d44a185b7" ], "index": "pypi", "version": "==2017.3" }, + "pyyaml": { + "hashes": [ + "sha256:059b2ee3194d718896c0ad077dd8c043e5e909d9180f387ce42012662a4946d6", + "sha256:1cf708e2ac57f3aabc87405f04b86354f66799c8e62c28c5fc5f88b5521b2dbf", + "sha256:24521fa2890642614558b492b473bee0ac1f8057a7263156b02e8b14c88ce6f5", + "sha256:4fee71aa5bc6ed9d5f116327c04273e25ae31a3020386916905767ec4fc5317e", + "sha256:70024e02197337533eef7b85b068212420f950319cc8c580261963aefc75f811", + "sha256:74782fbd4d4f87ff04159e986886931456a1894c61229be9eaf4de6f6e44b99e", + "sha256:940532b111b1952befd7db542c370887a8611660d2b9becff75d39355303d82d", + "sha256:cb1f2f5e426dc9f07a7681419fe39cee823bb74f723f36f70399123f439e9b20", + "sha256:dbbb2379c19ed6042e8f11f2a2c66d39cceb8aeace421bfc29d085d93eda3689", + "sha256:e3a057b7a64f1222b56e47bcff5e4b94c4f61faac04c7c4ecb1985e18caa3994", + "sha256:e9f45bd5b92c7974e59bcd2dcc8631a6b6cc380a904725fce7bc08872e691615" + ], + "version": "==5.3" + }, "requests": { "hashes": [ - "sha256:11e007a8a2aa0323f5a921e9e6a2d7e4e67d9877e85773fba9ba6419025cbeb4", - "sha256:9cf5292fcd0f598c671cfc1e0d7d1a7f13bb8085e9a590f48c010551dc6c4b31" + "sha256:43999036bfa82904b6af1d99e4882b560e5e2c68e5c4b0aa03b655f3d7d73fee", + "sha256:b3f43d496c6daba4493e7c431722aeb7dbc6288f52a6e04e7b6023b0247817e6" ], "index": "pypi", - "version": "==2.22.0" + "version": "==2.23.0" }, "six": { "hashes": [ - "sha256:1f1b7d42e254082a9db6279deae68afb421ceba6158efa6131de7b3003ee93fd", - "sha256:30f610279e8b2578cab6db20741130331735c781b56053c59c4076da27f06b66" + "sha256:236bdbdce46e6e6a3d61a337c0f8b763ca1e8717c03b369e87a7ec7ce1319c0a", + 
"sha256:8f3cd2e254d8f793e7f3d6d9df77b92252b52637291d0f0da013c76ea2724b6c" ], - "version": "==1.13.0" + "version": "==1.14.0" }, "snowballstemmer": { "hashes": [ @@ -814,11 +968,11 @@ }, "sphinx": { "hashes": [ - "sha256:0a11e2fd31fe5c7e64b4fc53c2c022946512f021d603eb41ac6ae51d5fcbb574", - "sha256:138e39aa10f28d52aa5759fc6d1cba2be6a4b750010974047fa7d0e31addcf63" + "sha256:525527074f2e0c2585f68f73c99b4dc257c34bbe308b27f5f8c7a6e20642742f", + "sha256:543d39db5f82d83a5c1aa0c10c88f2b6cff2da3e711aa849b2c627b4b403bbd9" ], "index": "pypi", - "version": "==2.3.0" + "version": "==2.4.2" }, "sphinx-autodoc-typehints": { "hashes": [ @@ -872,49 +1026,87 @@ }, "sphinxcontrib-websupport": { "hashes": [ - "sha256:1501befb0fdf1d1c29a800fdbf4ef5dc5369377300ddbdd16d2cd40e54c6eefc", - "sha256:e02f717baf02d0b6c3dd62cf81232ffca4c9d5c331e03766982e3ff9f1d2bc3f" + "sha256:50fb98fcb8ff2a8869af2afa6b8ee51b3baeb0b17dacd72505105bf15d506ead", + "sha256:bad3fbd312bc36a31841e06e7617471587ef642bdacdbdddaa8cc30cf251b5ea" ], "index": "pypi", - "version": "==1.1.2" + "version": "==1.2.0" + }, + "toml": { + "hashes": [ + "sha256:229f81c57791a41d65e399fc06bf0848bab550a9dfd5ed66df18ce5f05e73d5c", + "sha256:235682dd292d5899d361a811df37e04a8828a5b1da3115886b73cf81ebc9100e" + ], + "version": "==0.10.0" }, "typed-ast": { "hashes": [ - "sha256:132eae51d6ef3ff4a8c47c393a4ef5ebf0d1aecc96880eb5d6c8ceab7017cc9b", - "sha256:18141c1484ab8784006c839be8b985cfc82a2e9725837b0ecfa0203f71c4e39d", - "sha256:2baf617f5bbbfe73fd8846463f5aeafc912b5ee247f410700245d68525ec584a", - "sha256:3d90063f2cbbe39177e9b4d888e45777012652d6110156845b828908c51ae462", - "sha256:4304b2218b842d610aa1a1d87e1dc9559597969acc62ce717ee4dfeaa44d7eee", - "sha256:4983ede548ffc3541bae49a82675996497348e55bafd1554dc4e4a5d6eda541a", - "sha256:5315f4509c1476718a4825f45a203b82d7fdf2a6f5f0c8f166435975b1c9f7d4", - "sha256:6cdfb1b49d5345f7c2b90d638822d16ba62dc82f7616e9b4caa10b72f3f16649", - 
"sha256:7b325f12635598c604690efd7a0197d0b94b7d7778498e76e0710cd582fd1c7a", - "sha256:8d3b0e3b8626615826f9a626548057c5275a9733512b137984a68ba1598d3d2f", - "sha256:8f8631160c79f53081bd23446525db0bc4c5616f78d04021e6e434b286493fd7", - "sha256:912de10965f3dc89da23936f1cc4ed60764f712e5fa603a09dd904f88c996760", - "sha256:b010c07b975fe853c65d7bbe9d4ac62f1c69086750a574f6292597763781ba18", - "sha256:c908c10505904c48081a5415a1e295d8403e353e0c14c42b6d67f8f97fae6616", - "sha256:c94dd3807c0c0610f7c76f078119f4ea48235a953512752b9175f9f98f5ae2bd", - "sha256:ce65dee7594a84c466e79d7fb7d3303e7295d16a83c22c7c4037071b059e2c21", - "sha256:eaa9cfcb221a8a4c2889be6f93da141ac777eb8819f077e1d09fb12d00a09a93", - "sha256:f3376bc31bad66d46d44b4e6522c5c21976bf9bca4ef5987bb2bf727f4506cbb", - "sha256:f9202fa138544e13a4ec1a6792c35834250a85958fde1251b6a22e07d1260ae7" + "sha256:0666aa36131496aed8f7be0410ff974562ab7eeac11ef351def9ea6fa28f6355", + "sha256:0c2c07682d61a629b68433afb159376e24e5b2fd4641d35424e462169c0a7919", + "sha256:249862707802d40f7f29f6e1aad8d84b5aa9e44552d2cc17384b209f091276aa", + "sha256:24995c843eb0ad11a4527b026b4dde3da70e1f2d8806c99b7b4a7cf491612652", + "sha256:269151951236b0f9a6f04015a9004084a5ab0d5f19b57de779f908621e7d8b75", + "sha256:4083861b0aa07990b619bd7ddc365eb7fa4b817e99cf5f8d9cf21a42780f6e01", + "sha256:498b0f36cc7054c1fead3d7fc59d2150f4d5c6c56ba7fb150c013fbc683a8d2d", + "sha256:4e3e5da80ccbebfff202a67bf900d081906c358ccc3d5e3c8aea42fdfdfd51c1", + "sha256:6daac9731f172c2a22ade6ed0c00197ee7cc1221aa84cfdf9c31defeb059a907", + "sha256:715ff2f2df46121071622063fc7543d9b1fd19ebfc4f5c8895af64a77a8c852c", + "sha256:73d785a950fc82dd2a25897d525d003f6378d1cb23ab305578394694202a58c3", + "sha256:8c8aaad94455178e3187ab22c8b01a3837f8ee50e09cf31f1ba129eb293ec30b", + "sha256:8ce678dbaf790dbdb3eba24056d5364fb45944f33553dd5869b7580cdbb83614", + "sha256:aaee9905aee35ba5905cfb3c62f3e83b3bec7b39413f0a7f19be4e547ea01ebb", + "sha256:bcd3b13b56ea479b3650b82cabd6b5343a625b0ced5429e4ccad28a8973f301b", 
+ "sha256:c9e348e02e4d2b4a8b2eedb48210430658df6951fa484e59de33ff773fbd4b41", + "sha256:d205b1b46085271b4e15f670058ce182bd1199e56b317bf2ec004b6a44f911f6", + "sha256:d43943ef777f9a1c42bf4e552ba23ac77a6351de620aa9acf64ad54933ad4d34", + "sha256:d5d33e9e7af3b34a40dc05f498939f0ebf187f07c385fd58d591c533ad8562fe", + "sha256:fc0fea399acb12edbf8a628ba8d2312f583bdbdb3335635db062fa98cf71fca4", + "sha256:fe460b922ec15dd205595c9b5b99e2f056fd98ae8f9f56b888e7a17dc2b757e7" + ], + "version": "==1.4.1" + }, + "typing-extensions": { + "hashes": [ + "sha256:091ecc894d5e908ac75209f10d5b4f118fbdb2eb1ede6a63544054bb1edb41f2", + "sha256:910f4656f54de5993ad9304959ce9bb903f90aadc7c67a0bef07e678014e892d", + "sha256:cf8b63fedea4d89bab840ecbb93e75578af28f76f66c35889bd7065f5af88575" ], - "version": "==1.3.5" + "version": "==3.7.4.1" }, "urllib3": { "hashes": [ - "sha256:a8a318824cc77d1fd4b2bec2ded92646630d7fe8619497b142c84a9e6f5a7293", - "sha256:f3c5fd51747d450d4dcf6f923c81f78f811aab8205fda64b0aba34a4e48b0745" + "sha256:2f3db8b19923a873b3e5256dc9c2dedfa883e33d87c690d9c7913e1f40673cdc", + "sha256:87716c2d2a7121198ebcb7ce7cccf6ce5e9ba539041cfbaeecfb641dc0bf6acc" ], "index": "pypi", - "version": "==1.25.7" + "version": "==1.25.8" + }, + "virtualenv": { + "hashes": [ + "sha256:08f3623597ce73b85d6854fb26608a6f39ee9d055c81178dc6583803797f8994", + "sha256:de2cbdd5926c48d7b84e0300dea9e8f276f61d186e8e49223d71d91250fbaebd" + ], + "version": "==20.0.4" + }, + "wcwidth": { + "hashes": [ + "sha256:8fd29383f539be45b20bd4df0dc29c20ba48654a41e661925e612311e9f3c603", + "sha256:f28b3e8a6483e5d49e7f8949ac1a78314e740333ae305b4ba5defd3e74fb37a8" + ], + "version": "==0.1.8" }, "wrapt": { "hashes": [ "sha256:565a021fd19419476b9362b05eeaa094178de64f8361e44468f9e9d7843901e1" ], "version": "==1.11.2" + }, + "zipp": { + "hashes": [ + "sha256:12248a63bbdf7548f89cb4c7cda4681e537031eda29c02ea29674bc6854460c2", + "sha256:7c0f8e91abc0dc07a5068f315c52cb30c66bfbc581e5b50704c8a2f6ebae794a" + ], + "version": "==3.0.0" } } } diff 
--git a/README.md b/README.md index d297f528..0cc6cbc1 100644 --- a/README.md +++ b/README.md @@ -46,8 +46,8 @@ is only accessible from the CUL network. ```bash pipenv install -FLASK_APP=app.py FLASK_DEBUG=1 ELASTICSEARCH_HOST=127.0.0.1 pipenv run python create_index.py -FLASK_APP=app.py FLASK_DEBUG=1 ELASTICSEARCH_HOST=127.0.0.1 pipenv run python bulk_index.py +FLASK_APP=app.py FLASK_DEBUG=1 ELASTICSEARCH_SERVICE_HOST=127.0.0.1 pipenv run python create_index.py +FLASK_APP=app.py FLASK_DEBUG=1 ELASTICSEARCH_SERVICE_HOST=127.0.0.1 pipenv run python bulk_index.py ``` ``bulk_index.py`` without parameters populate the index with the @@ -58,7 +58,7 @@ parameter. To check for missing records, use ``audit.py``: ```bash -ELASTICSEARCH_HOST=127.0.0.1 ELASTICSEARCH_INDEX=arxiv pipenv run python audit.py -l list_of_papers.txt -o missing.txt +ELASTICSEARCH_SERVICE_HOST=127.0.0.1 ELASTICSEARCH_INDEX=arxiv pipenv run python audit.py -l list_of_papers.txt -o missing.txt ``` ### Reindexing @@ -70,7 +70,7 @@ processed. If the destination index does not already exist, it will be created using the current configured mapping. ```bash -FLASK_APP=app.py ELASTICSEARCH_HOST=127.0.0.1 pipenv run python reindex.py OLD_INDEX NEW_INDEX +FLASK_APP=app.py ELASTICSEARCH_SERVICE_HOST=127.0.0.1 pipenv run python reindex.py OLD_INDEX NEW_INDEX ``` @@ -79,7 +79,7 @@ FLASK_APP=app.py ELASTICSEARCH_HOST=127.0.0.1 pipenv run python reindex.py OLD_I You can spin up the search app directly. ```bash -FLASK_APP=app.py FLASK_DEBUG=1 ELASTICSEARCH_HOST=127.0.0.1 pipenv run flask run +FLASK_APP=app.py FLASK_DEBUG=1 ELASTICSEARCH_SERVICE_HOST=127.0.0.1 pipenv run flask run ``` This will monitor any of the Python bits for changes and restart the server. Unfortunately static files and templates are not monitored, so you'll have to @@ -87,13 +87,18 @@ manually restart to see those changes take effect. If all goes well... http://127.0.0.1:5000/ should render the basic search page. 
-You can run the API in dev mode by changing `FLASK_APP` to point to ``api.py``, +You can run the new metadata API in dev mode by changing `FLASK_APP` to point to ``wsgi-api.py``, i.e.: ```bash -FLASK_APP=api.py FLASK_DEBUG=1 ELASTICSEARCH_HOST=127.0.0.1 pipenv run flask run +JWT_SECRET=foosecret FLASK_APP=wsgi-api.py FLASK_DEBUG=1 ELASTICSEARCH_SERVICE_HOST=127.0.0.1 pipenv run flask run ``` +To run the classic API in dev mode, use ``wsgi-classic-api.py``: + +```bash +FLASK_APP=wsgi-classic-api.py FLASK_DEBUG=1 ELASTICSEARCH_SERVICE_HOST=127.0.0.1 pipenv run flask run +``` ## Running the indexing agent. @@ -275,3 +280,27 @@ make [format] where [format] can be ``html``, ``latexpdf``. See the ``Sphinx documentation `_. + + +## Pre commit hooks + +To run pre commit hooks install the dev dependencies: + +```bash +pipenv install --dev +``` + +After that you'll need to install the pre commit hooks: + +```bash +pipenv run pre-commit install +``` + +Git will run all the pre-commit hooks on all changed files before you are +allowed to commit. You will be allowed to commit only if all checks pass. + +You can also run the pre commit hooks manually with: + +```bash +pipenv run pre-commit run +``` diff --git a/audit.py b/audit.py index b6f6b4fc..3a3729fa 100644 --- a/audit.py +++ b/audit.py @@ -32,20 +32,27 @@ def exists(chunk: List[str]) -> List[Tuple[str, bool]]: """ with app.app_context(): from search.services import index + status = [] for ident in chunk: - time.sleep(0.05) # TODO: make this configurable? + time.sleep(0.05) # TODO: make this configurable? 
status.append((ident, index.SearchSession.exists(ident))) return status @app.cli.command() -@click.option('--id_list', '-l', - help="Index paper IDs in a file (one ID per line)") -@click.option('--batch-size', '-b', type=int, default=1_600, - help="Number of records to process each iteration") -@click.option('--n-workers', '-n', type=int, default=8, help="Num of workers") -@click.option('--output', '-o', help="File in which missing IDs are stored") +@click.option( + "--id_list", "-l", help="Index paper IDs in a file (one ID per line)" +) +@click.option( + "--batch-size", + "-b", + type=int, + default=1_600, + help="Number of records to process each iteration", +) +@click.option("--n-workers", "-n", type=int, default=8, help="Num of workers") +@click.option("--output", "-o", help="File in which missing IDs are stored") def audit(id_list: str, batch_size: int, n_workers: int, output: str): """ Check the index for missing papers. @@ -74,7 +81,7 @@ def audit(id_list: str, batch_size: int, n_workers: int, output: str): raise click.ClickException( "batch size must be divisible by the number of workers" ) - chunk_size = int(round(batch_size/n_workers)) + chunk_size = int(round(batch_size / n_workers)) if not os.path.exists(id_list): raise click.ClickException("no such file") @@ -84,32 +91,34 @@ def audit(id_list: str, batch_size: int, n_workers: int, output: str): data = [row[0] for row in csv.reader(f)] # Create the output file. - with open(output, 'w') as f: - f.write('') + with open(output, "w") as f: + f.write("") N_results = 0 N_total = len(data) - with click.progressbar(length=N_total, label='Papers checked') as bar: + with click.progressbar(length=N_total, label="Papers checked") as bar: # We do this in batches, so that we can track and save as we go. 
for i in range(0, len(data), batch_size): - batch = data[i:i + batch_size] - chunks = [batch[c:c + chunk_size] - for c in range(0, batch_size, chunk_size)] + batch = data[i : i + batch_size] + chunks = [ + batch[c : c + chunk_size] + for c in range(0, batch_size, chunk_size) + ] with Pool(n_workers) as p: results = reduce(concat, p.map(exists, chunks)) # Write one missing paper ID per line. - with open(output, 'a') as f: # Append to output file. + with open(output, "a") as f: # Append to output file. for ident, status in results: if status: continue - f.write(f'{ident}\n') + f.write(f"{ident}\n") N_results += len(results) bar.update(N_results) -if __name__ == '__main__': +if __name__ == "__main__": audit() diff --git a/bin/start_agent.py b/bin/start_agent.py index dfd41937..221ed8a2 100644 --- a/bin/start_agent.py +++ b/bin/start_agent.py @@ -10,5 +10,5 @@ def start_agent() -> None: process_stream() -if __name__ == '__main__': +if __name__ == "__main__": start_agent() diff --git a/bulk_index.py b/bulk_index.py index 92d55bcc..ca26579d 100644 --- a/bulk_index.py +++ b/bulk_index.py @@ -1,42 +1,56 @@ """Use this to populate a search index for testing.""" -import json import os +import json import tempfile -import click -from itertools import islice, groupby +import operator from typing import List -import re -from search.factory import create_ui_web_app -from search.agent import MetadataRecordProcessor, DocumentFailed, \ - IndexingFailed -from search.domain import asdict, DocMeta, Document -from search.services import metadata, index +from itertools import groupby + +import click + from search.process import transform +from search.domain import asdict, DocMeta +from search.services import metadata, index +from search.factory import create_ui_web_app + app = create_ui_web_app() @app.cli.command() -@click.option('--print_indexable', '-i', is_flag=True, - help='Print the indexable JSON to stdout.') -@click.option('--paper_id', '-p', - help='Index specified paper id') 
-@click.option('--id_list', '-l', - help="Index paper IDs in a file (one ID per line)") -@click.option('--load-cache', '-d', is_flag=True, - help="Install papers from a cache on disk. Note: this will" - " preempt checking for new versions of papers that are" - " in the cache.") -@click.option('--cache-dir', '-c', help="Specify the cache directory.") -def populate(print_indexable: bool, paper_id: str, id_list: str, - load_cache: bool, cache_dir: str) -> None: +@click.option( + "--print_indexable", + "-i", + is_flag=True, + help="Print the indexable JSON to stdout.", +) +@click.option("--paper_id", "-p", help="Index specified paper id") +@click.option( + "--id_list", "-l", help="Index paper IDs in a file (one ID per line)" +) +@click.option( + "--load-cache", + "-d", + is_flag=True, + help="Install papers from a cache on disk. Note: this will" + " preempt checking for new versions of papers that are" + " in the cache.", +) +@click.option("--cache-dir", "-c", help="Specify the cache directory.") +def populate( + print_indexable: bool, + paper_id: str, + id_list: str, + load_cache: bool, + cache_dir: str, +) -> None: """Populate the search index with some test data.""" cache_dir = init_cache(cache_dir) index_count = 0 - if paper_id: # Index a single paper. + if paper_id: # Index a single paper. TO_INDEX = [paper_id] - elif id_list: # Index a list of papers. + elif id_list: # Index a list of papers. TO_INDEX = load_id_list(id_list) else: TO_INDEX = load_id_sample() @@ -48,15 +62,16 @@ def populate(print_indexable: bool, paper_id: str, id_list: str, meta: List[DocMeta] = [] index.SearchSession.create_index() try: - with click.progressbar(length=approx_size, - label='Papers indexed') as index_bar: + with click.progressbar( + length=approx_size, label="Papers indexed" + ) as index_bar: last = len(TO_INDEX) - 1 for i, paper_id in enumerate(TO_INDEX): this_meta = [] if load_cache: try: this_meta = from_cache(cache_dir, paper_id) - except RuntimeError as e: # No document. 
+ except RuntimeError: # No document. pass if this_meta: @@ -67,10 +82,10 @@ def populate(print_indexable: bool, paper_id: str, id_list: str, if len(chunk) == retrieve_chunk_size or i == last: try: new_meta = metadata.bulk_retrieve(chunk) - except metadata.ConnectionFailed as e: # Try again. + except metadata.ConnectionFailed: # Try again. new_meta = metadata.bulk_retrieve(chunk) # Add metadata to the cache. - key = lambda dm: dm.paper_id + key = operator.attrgetter("paper_id") new_meta_srt = sorted(new_meta, key=key) for paper_id, grp in groupby(new_meta_srt, key): to_cache(cache_dir, paper_id, [dm for dm in grp]) @@ -93,20 +108,25 @@ def populate(print_indexable: bool, paper_id: str, id_list: str, meta = [] index_bar.update(i) - except Exception as e: - raise RuntimeError('Populate failed: %s' % str(e)) from e + except Exception as ex: + raise RuntimeError("Populate failed: %s" % str(ex)) from ex finally: click.echo(f"Indexed {index_count} documents in total") - click.echo(f"Cache path: {cache_dir}; use `-c {cache_dir}` to reuse in" - f" subsequent calls") + click.echo( + f"Cache path: {cache_dir}; use `-c {cache_dir}` to reuse in" + f" subsequent calls" + ) def init_cache(cache_dir: str) -> None: """Configure the processor to use a local cache for docmeta.""" # Create cache directory if it doesn't exist - if not (cache_dir and os.path.exists(cache_dir) - and os.access(cache_dir, os.W_OK)): + if not ( + cache_dir + and os.path.exists(cache_dir) + and os.access(cache_dir, os.W_OK) + ): cache_dir = tempfile.mkdtemp() return cache_dir @@ -130,10 +150,10 @@ def from_cache(cache_dir: str, arxiv_id: str) -> List[DocMeta]: be found in the cache. 
""" - fname = '%s.json' % arxiv_id.replace('/', '_') + fname = "%s.json" % arxiv_id.replace("/", "_") cache_path = os.path.join(cache_dir, fname) if not os.path.exists(cache_path): - raise RuntimeError('No cached document') + raise RuntimeError("No cached document") with open(cache_path) as f: data: dict = json.load(f) @@ -157,19 +177,19 @@ def to_cache(cache_dir: str, arxiv_id: str, docmeta: List[DocMeta]) -> None: be added to the cache. """ - fname = '%s.json' % arxiv_id.replace('/', '_') + fname = "%s.json" % arxiv_id.replace("/", "_") cache_path = os.path.join(cache_dir, fname) try: - with open(cache_path, 'w') as f: + with open(cache_path, "w") as f: json.dump([asdict(dm) for dm in docmeta], f) - except Exception as e: - raise RuntimeError(str(e)) from e + except Exception as ex: + raise RuntimeError(str(ex)) from ex def load_id_list(path: str) -> List[str]: """Load a list of paper IDs from ``path``.""" if not os.path.exists(path): - raise RuntimeError('Path does not exist: %s' % path) + raise RuntimeError("Path does not exist: %s" % path) return with open(path) as f: # Stream from the file, in case it's large. 
@@ -178,9 +198,9 @@ def load_id_list(path: str) -> List[str]: def load_id_sample() -> List[str]: """Load a list of IDs from the testing sample.""" - with open('tests/data/sample.json') as f: - return [datum['id'] for datum in json.load(f).get('sample')] + with open("tests/data/sample.json") as f: + return [datum["id"] for datum in json.load(f).get("sample")] -if __name__ == '__main__': +if __name__ == "__main__": populate() diff --git a/create_index.py b/create_index.py index 3d3ad10b..2ee90b26 100644 --- a/create_index.py +++ b/create_index.py @@ -1,7 +1,6 @@ """Use this to initialize the search index for testing.""" -import json -import click + from search.factory import create_ui_web_app from search.services import index @@ -12,8 +11,8 @@ @app.cli.command() def create_index(): """Initialize the search index.""" - index.SearchSession().create_index() + index.SearchSession.create_index() -if __name__ == '__main__': +if __name__ == "__main__": create_index() diff --git a/docker-compose.yml b/docker-compose.yml index e146c3ae..11a4b6ef 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -35,7 +35,7 @@ services: AWS_SECRET_ACCESS_KEY: "bar" ELASTICSEARCH_SERVICE_HOST: "elasticsearch" ELASTICSEARCH_SERVICE_PORT: "9200" - ELASTICSEARCH_PORT_9200_PROTO: "http" + ELASTICSEARCH_SERVICE_PORT_9200_PROTO: "http" ELASTICSEARCH_USER: "elastic" ELASTICSEARCH_PASSWORD: "changeme" ELASTICSEARCH_VERIFY: "false" diff --git a/docs/source/conf.py b/docs/source/conf.py index b94ed978..99409c9c 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -19,8 +19,9 @@ # import os import sys -sys.path.insert(0, os.path.abspath('.')) -sys.path.append(os.path.abspath('../..')) + +sys.path.insert(0, os.path.abspath(".")) +sys.path.append(os.path.abspath("../..")) # -- General configuration ------------------------------------------------ @@ -32,45 +33,45 @@ # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom # ones. 
extensions = [ - 'sphinx.ext.autodoc', - 'sphinx_autodoc_typehints', - 'sphinx.ext.autosummary', - 'sphinx.ext.napoleon', - 'sphinx.ext.intersphinx', - 'sphinx.ext.graphviz', - 'sphinx.ext.todo', - 'sphinx.ext.coverage', - 'sphinx.ext.mathjax', - 'sphinx.ext.ifconfig', - 'sphinx.ext.viewcode', - 'sphinx.ext.githubpages' + "sphinx.ext.autodoc", + "sphinx_autodoc_typehints", + "sphinx.ext.autosummary", + "sphinx.ext.napoleon", + "sphinx.ext.intersphinx", + "sphinx.ext.graphviz", + "sphinx.ext.todo", + "sphinx.ext.coverage", + "sphinx.ext.mathjax", + "sphinx.ext.ifconfig", + "sphinx.ext.viewcode", + "sphinx.ext.githubpages", ] # Add any paths that contain templates here, relative to this directory. -templates_path = ['_templates'] +templates_path = ["_templates"] # The suffix(es) of source filenames. # You can specify multiple suffix as a list of string: # # source_suffix = ['.rst', '.md'] -source_suffix = '.rst' +source_suffix = ".rst" # The master toctree document. -master_doc = 'index' +master_doc = "index" # General information about the project. -project = 'arXiv search' -copyright = '2017, arXiv Team' -author = 'arXiv Team' +project = "arXiv search" +copyright = "2017, arXiv Team" +author = "arXiv Team" # The version info for the project you're documenting, acts as replacement for # |version| and |release|, also used in various other places throughout the # built documents. # # The short X.Y version. -version = '0.2' +version = "0.2" # The full version, including alpha/beta/rc tags. -release = '0.2' +release = "0.2" # The language for content autogenerated by Sphinx. Refer to documentation # for a list of supported languages. @@ -85,7 +86,7 @@ exclude_patterns = [] # The name of the Pygments (syntax highlighting) style to use. -pygments_style = 'sphinx' +pygments_style = "sphinx" # If true, `todo` and `todoList` produce output, else they produce nothing. todo_include_todos = True @@ -96,7 +97,7 @@ # The theme to use for HTML and HTML Help pages. 
See the documentation for # a list of builtin themes. # -html_theme = 'alabaster' +html_theme = "alabaster" # Theme options are theme-specific and customize the look and feel of a theme # further. For a list of options available for each theme, see the @@ -107,7 +108,7 @@ # Add any paths that contain custom static files (such as style sheets) here, # relative to this directory. They are copied after the builtin static files, # so a file named "default.css" will overwrite the builtin "default.css". -html_static_path = ['_static'] +html_static_path = ["_static"] # Custom sidebar templates, must be a dictionary that maps document names # to template names. @@ -115,11 +116,11 @@ # This is required for the alabaster theme # refs: http://alabaster.readthedocs.io/en/latest/installation.html#sidebars html_sidebars = { - '**': [ - 'about.html', - 'navigation.html', - 'relations.html', # needs 'show_related': True theme option to display - 'searchbox.html', + "**": [ + "about.html", + "navigation.html", + "relations.html", # needs 'show_related': True theme option to display + "searchbox.html", ] } @@ -127,7 +128,7 @@ # -- Options for HTMLHelp output ------------------------------------------ # Output file base name for HTML help builder. -htmlhelp_basename = 'arXivsearchdoc' +htmlhelp_basename = "arXivsearchdoc" # -- Options for LaTeX output --------------------------------------------- @@ -136,15 +137,12 @@ # The paper size ('letterpaper' or 'a4paper'). # # 'papersize': 'letterpaper', - # The font size ('10pt', '11pt' or '12pt'). # # 'pointsize': '10pt', - # Additional stuff for the LaTeX preamble. # # 'preamble': '', - # Latex figure (float) alignment # # 'figure_align': 'htbp', @@ -154,8 +152,13 @@ # (source start file, target name, title, # author, documentclass [howto, manual, or own class]). 
latex_documents = [ - (master_doc, 'arXivsearch.tex', 'arXiv search Documentation', - 'Jane Bloggs', 'manual'), + ( + master_doc, + "arXivsearch.tex", + "arXiv search Documentation", + "Jane Bloggs", + "manual", + ) ] @@ -164,8 +167,7 @@ # One entry per manual page. List of tuples # (source start file, name, description, authors, manual section). man_pages = [ - (master_doc, 'arxivsearch', 'arXiv search Documentation', - [author], 1) + (master_doc, "arxivsearch", "arXiv search Documentation", [author], 1) ] @@ -175,18 +177,24 @@ # (source start file, target name, title, author, # dir menu entry, description, category) texinfo_documents = [ - (master_doc, 'arXiv-search', 'arXiv Search Documentation', - author, 'arXiv-Search', 'One line description of project.', - 'Miscellaneous'), + ( + master_doc, + "arXiv-search", + "arXiv Search Documentation", + author, + "arXiv-Search", + "One line description of project.", + "Miscellaneous", + ) ] intersphinx_mapping = { - 'python': ('https://docs.python.org/3.6', None), - 'arxitecture': ('https://arXiv.github.io/arxiv-arxitecture/', None), - 'arxiv.taxonomy': ('https://arXiv.github.io/arxiv-base', None), - 'arxiv.base': ('https://arXiv.github.io/arxiv-base', None), - 'browse': ('https://arXiv.github.io/arxiv-browse/', None), - 'search': ('https://arXiv.github.io/arxiv-search/', None), - 'zero': ('https://arXiv.github.io/arxiv-zero/', None), + "python": ("https://docs.python.org/3.6", None), + "arxitecture": ("https://arXiv.github.io/arxiv-arxitecture/", None), + "arxiv.taxonomy": ("https://arXiv.github.io/arxiv-base", None), + "arxiv.base": ("https://arXiv.github.io/arxiv-base", None), + "browse": ("https://arXiv.github.io/arxiv-browse/", None), + "search": ("https://arXiv.github.io/arxiv-search/", None), + "zero": ("https://arXiv.github.io/arxiv-zero/", None), } diff --git a/lintstats.sh b/lintstats.sh index 81a6379e..dd96ee38 100755 --- a/lintstats.sh +++ b/lintstats.sh @@ -6,11 +6,9 @@ PYLINT_PASS=$(echo 
$PYLINT_SCORE">="$MIN_SCORE | bc -l) if [ "$TRAVIS_PULL_REQUEST_SHA" = "" ]; then SHA=$TRAVIS_COMMIT; else SHA=$TRAVIS_PULL_REQUEST_SHA; fi if [ "$PYLINT_PASS" ]; then PYLINT_STATE="success" && echo "pylint passed with score "$PYLINT_SCORE" for sha "$SHA; else PYLINT_STATE="failure" && echo "pylint failed with score "$PYLINT_SCORE" for sha "$SHA; fi -curl -u $USERNAME:$GITHUB_TOKEN \ - -d '{"state": "'$PYLINT_STATE'", "target_url": "https://travis-ci.org/'$TRAVIS_REPO_SLUG'/builds/'$TRAVIS_BUILD_ID'", "description": "'$PYLINT_SCORE'/10", "context": "code-quality/pylint"}' \ - -XPOST https://api.github.com/repos/$TRAVIS_REPO_SLUG/statuses/$SHA \ - > /dev/null 2>&1 - +curl -u $USER:$GITHUB_TOKEN \ + -d '{"state": "'$PYLINT_STATE'", "target_url": "https://travis-ci.com/'$TRAVIS_REPO_SLUG'/builds/'$TRAVIS_BUILD_ID'", "description": "'$PYLINT_SCORE'/10", "context": "code-quality/pylint"}' \ + -XPOST https://api.github.com/repos/$TRAVIS_REPO_SLUG/statuses/$SHA # Check mypy integration @@ -25,7 +23,7 @@ curl -u $USERNAME:$GITHUB_TOKEN \ # Check pydocstyle integration -pipenv run pydocstyle --convention=numpy --add-ignore=D401 search +pipenv run pydocstyle --convention=numpy --add-ignore=D401,D202 search PYDOCSTYLE_STATUS=$? if [ $PYDOCSTYLE_STATUS -ne 0 ]; then PYDOCSTYLE_STATE="failure" && echo "pydocstyle failed"; else PYDOCSTYLE_STATE="success" && echo "pydocstyle passed"; fi diff --git a/reindex.py b/reindex.py index ab1e0d74..98b1c952 100644 --- a/reindex.py +++ b/reindex.py @@ -1,7 +1,5 @@ """Helper script to reindex all arXiv papers.""" -import os -import tempfile import click import time @@ -12,8 +10,8 @@ @app.cli.command() -@click.argument('old_index', nargs=1) -@click.argument('new_index', nargs=1) +@click.argument("old_index", nargs=1) +@click.argument("new_index", nargs=1) def reindex(old_index: str, new_index: str): """ Reindex the documents in `old_index` to `new_index`. 
@@ -30,24 +28,24 @@ def reindex(old_index: str, new_index: str): raise click.ClickException("Failed to get or create new index") click.echo(f"Started reindexing task") - task_id = r['task'] - with click.progressbar(length=100, label='percent complete') as progress: + task_id = r["task"] + with click.progressbar(length=100, label="percent complete") as progress: while True: status = index.SearchSession.get_task_status(task_id) - total = float(status['task']['status']['total']) - if status['completed'] or total == 0: + total = float(status["task"]["status"]["total"]) + if status["completed"] or total == 0: progress.update(100) break - updated = status['task']['status']['updated'] - created = status['task']['status']['created'] - deleted = status['task']['status']['deleted'] - complete = (updated + created + deleted)/total + updated = status["task"]["status"]["updated"] + created = status["task"]["status"]["created"] + deleted = status["task"]["status"]["deleted"] + complete = (updated + created + deleted) / total progress.update(complete * 100) if complete == 1: break time.sleep(2) -if __name__ == '__main__': +if __name__ == "__main__": reindex() diff --git a/search/agent/__init__.py b/search/agent/__init__.py index 663cb3f0..7b2a331d 100644 --- a/search/agent/__init__.py +++ b/search/agent/__init__.py @@ -13,7 +13,7 @@ from flask import current_app as app from arxiv.base import agent -from .consumer import MetadataRecordProcessor, DocumentFailed, IndexingFailed +from search.agent.consumer import MetadataRecordProcessor def process_stream(duration: Optional[int] = None) -> None: @@ -29,6 +29,6 @@ def process_stream(duration: Optional[int] = None) -> None: """ # We use the Flask application instance for configuration, and to manage # integrations with metadata service, search index. 
- agent.process_stream(MetadataRecordProcessor, app.config, - duration=duration) - + agent.process_stream( + MetadataRecordProcessor, app.config, duration=duration + ) diff --git a/search/agent/consumer.py b/search/agent/consumer.py index ee00172c..0c8b4e8a 100644 --- a/search/agent/consumer.py +++ b/search/agent/consumer.py @@ -1,16 +1,17 @@ """Provides a record processor for MetadataIsAvailable notifications.""" import json -import os import time -from typing import List, Any, Optional, Dict +from typing import List, Dict, Any + +from retry.api import retry_call + from arxiv.base import logging +from arxiv.base.agent import BaseConsumer from search.services import metadata, index from search.process import transform -from search.domain import DocMeta, Document, asdict -from arxiv.base.agent import BaseConsumer +from search.domain import DocMeta, Document -from retry.api import retry_call logger = logging.getLogger(__name__) logger.propagate = False @@ -32,8 +33,10 @@ class MetadataRecordProcessor(BaseConsumer): def __init__(self, *args: Any, **kwargs: Any) -> None: """Initialize exception counter.""" - self.sleep: float = kwargs.pop('sleep', 0.1) - super(MetadataRecordProcessor, self).__init__(*args, **kwargs) # type: ignore + self.sleep: float = kwargs.pop("sleep", 0.1) + super(MetadataRecordProcessor, self).__init__( + *args, **kwargs + ) # type: ignore self._error_count = 0 # TODO: bring McCabe index down. @@ -61,28 +64,33 @@ def _get_metadata(self, arxiv_id: str) -> DocMeta: is unlikely for subsequent papers. """ - logger.debug('%s: get metadata', arxiv_id) + logger.debug("%s: get metadata", arxiv_id) try: - docmeta: DocMeta = retry_call(metadata.retrieve, (arxiv_id,), - exceptions=metadata.ConnectionFailed, - tries=2) - except metadata.ConnectionFailed as e: + docmeta: DocMeta = retry_call( + metadata.retrieve, + (arxiv_id,), + exceptions=metadata.ConnectionFailed, + tries=2, + ) + except metadata.ConnectionFailed as ex: # Things really are looking bad. 
There is no need to keep # trying with subsequent records, so let's abort entirely. - logger.error('%s: second attempt failed, giving up', arxiv_id) + logger.error("%s: second attempt failed, giving up", arxiv_id) raise IndexingFailed( - 'Indexing failed; metadata endpoint could not be reached.' - ) from e - except metadata.RequestFailed as e: - logger.error(f'{arxiv_id}: request failed') - raise DocumentFailed('Request to metadata service failed') from e - except metadata.BadResponse as e: - logger.error(f'{arxiv_id}: bad response from metadata service') - raise DocumentFailed('Bad response from metadata service') from e - except Exception as e: - logger.error(f'{arxiv_id}: unhandled error, metadata service: {e}') - raise IndexingFailed('Unhandled exception') from e + "Indexing failed; metadata endpoint could not be reached." + ) from ex + except metadata.RequestFailed as ex: + logger.error(f"{arxiv_id}: request failed") + raise DocumentFailed("Request to metadata service failed") from ex + except metadata.BadResponse as ex: + logger.error(f"{arxiv_id}: bad response from metadata service") + raise DocumentFailed("Bad response from metadata service") from ex + except Exception as ex: + logger.error( + f"{arxiv_id}: unhandled error, metadata service: {ex}" + ) + raise IndexingFailed("Unhandled exception") from ex return docmeta def _get_bulk_metadata(self, arxiv_ids: List[str]) -> List[DocMeta]: @@ -110,26 +118,31 @@ def _get_bulk_metadata(self, arxiv_ids: List[str]) -> List[DocMeta]: is unlikely for subsequent papers. 
""" - logger.debug('%s: get bulk metadata', arxiv_ids) + logger.debug("%s: get bulk metadata", arxiv_ids) meta: List[DocMeta] try: - meta = retry_call(metadata.bulk_retrieve, (arxiv_ids,), - exceptions=metadata.ConnectionFailed, - tries=2) - except metadata.ConnectionFailed as e: + meta = retry_call( + metadata.bulk_retrieve, + (arxiv_ids,), + exceptions=metadata.ConnectionFailed, + tries=2, + ) + except metadata.ConnectionFailed as ex: # Things really are looking bad. There is no need to keep # trying with subsequent records, so let's abort entirely. - logger.error('%s: second attempt failed, giving up', arxiv_ids) - raise IndexingFailed('Metadata endpoint not available') from e - except metadata.RequestFailed as e: - logger.error('%s: request failed', arxiv_ids) - raise DocumentFailed('Request to metadata service failed') from e - except metadata.BadResponse as e: - logger.error('%s: bad response from metadata service', arxiv_ids) - raise DocumentFailed('Bad response from metadata service') from e - except Exception as e: - logger.error('%s: unhandled error, metadata svc: %s', arxiv_ids, e) - raise IndexingFailed('Unhandled exception') from e + logger.error("%s: second attempt failed, giving up", arxiv_ids) + raise IndexingFailed("Metadata endpoint not available") from ex + except metadata.RequestFailed as ex: + logger.error("%s: request failed", arxiv_ids) + raise DocumentFailed("Request to metadata service failed") from ex + except metadata.BadResponse as ex: + logger.error("%s: bad response from metadata service", arxiv_ids) + raise DocumentFailed("Bad response from metadata service") from ex + except Exception as ex: + logger.error( + "%s: unhandled error, metadata svc: %s", arxiv_ids, ex + ) + raise IndexingFailed("Unhandled exception") from ex return meta @staticmethod @@ -156,10 +169,10 @@ def _transform_to_document(docmeta: DocMeta) -> Document: """ try: document = transform.to_search_document(docmeta) - except Exception as e: + except Exception as ex: # 
At the moment we don't have any special exceptions. - logger.error('unhandled exception during transform: %s', e) - raise DocumentFailed('Could not transform document') from e + logger.error("unhandled exception during transform: %s", ex) + raise DocumentFailed("Could not transform document") from ex return document @@ -180,13 +193,17 @@ def _add_to_index(document: Document) -> None: """ try: - retry_call(index.SearchSession.add_document, (document,), - exceptions=index.IndexConnectionError, tries=2) - except index.IndexConnectionError as e: - raise IndexingFailed('Could not index document') from e - except Exception as e: - logger.error(f'Unhandled exception from index service: {e}') - raise IndexingFailed('Unhandled exception') from e + retry_call( + index.SearchSession.add_document, + (document,), + exceptions=index.IndexConnectionError, + tries=2, + ) + except index.IndexConnectionError as ex: + raise IndexingFailed("Could not index document") from ex + except Exception as ex: + logger.error(f"Unhandled exception from index service: {ex}") + raise IndexingFailed("Unhandled exception") from ex @staticmethod def _bulk_add_to_index(documents: List[Document]) -> None: @@ -205,13 +222,17 @@ def _bulk_add_to_index(documents: List[Document]) -> None: """ try: - retry_call(index.SearchSession.bulk_add_documents, (documents,), - exceptions=index.IndexConnectionError, tries=2) - except index.IndexConnectionError as e: - raise IndexingFailed('Could not bulk index documents') from e - except Exception as e: - logger.error(f'Unhandled exception from index service: {e}') - raise IndexingFailed('Unhandled exception') from e + retry_call( + index.SearchSession.bulk_add_documents, + (documents,), + exceptions=index.IndexConnectionError, + tries=2, + ) + except index.IndexConnectionError as ex: + raise IndexingFailed("Could not bulk index documents") from ex + except Exception as ex: + logger.error(f"Unhandled exception from index service: {ex}") + raise 
IndexingFailed("Unhandled exception") from ex def index_paper(self, arxiv_id: str) -> None: """ @@ -247,19 +268,20 @@ def index_papers(self, arxiv_ids: List[str]) -> None: try: documents = [] for docmeta in self._get_bulk_metadata(arxiv_ids): - logger.debug('%s: transform to Document', docmeta.paper_id) + logger.debug("%s: transform to Document", docmeta.paper_id) document = MetadataRecordProcessor._transform_to_document( docmeta ) documents.append(document) - logger.debug('add to index in bulk') + logger.debug("add to index in bulk") MetadataRecordProcessor._bulk_add_to_index(documents) - except (DocumentFailed, IndexingFailed) as e: + except (DocumentFailed, IndexingFailed) as ex: # We just pass these along so that process_record() can keep track. - logger.debug(f'{arxiv_ids}: Document failed: {e}') - raise e + logger.debug(f"{arxiv_ids}: Document failed: {ex}") + raise ex - def process_record(self, record: dict) -> None: + # FIXME: Argument type. + def process_record(self, record: Dict[Any, Any]) -> None: """ Call for each record that is passed to process_records. @@ -281,22 +303,22 @@ def process_record(self, record: dict) -> None: time.sleep(self.sleep) logger.info(f'Processing record {record["SequenceNumber"]}') if self._error_count > self.MAX_ERRORS: - raise IndexingFailed('Too many errors') + raise IndexingFailed("Too many errors") try: - deserialized = json.loads(record['Data'].decode('utf-8')) - except json.decoder.JSONDecodeError as e: - logger.error("Error while deserializing data %s", e) - logger.error("Data payload: %s", record['Data']) - raise DocumentFailed('Could not deserialize record data') + deserialized = json.loads(record["Data"].decode("utf-8")) + except json.decoder.JSONDecodeError as ex: + logger.error("Error while deserializing data %s", ex) + logger.error("Data payload: %s", record["Data"]) + raise DocumentFailed("Could not deserialize record data") # return # Don't bring down the whole batch. 
try: - arxiv_id: str = deserialized.get('document_id') + arxiv_id: str = deserialized.get("document_id") self.index_paper(arxiv_id) - except DocumentFailed as e: - logger.debug('%s: failed to index document: %s', arxiv_id, e) + except DocumentFailed as ex: + logger.debug("%s: failed to index document: %s", arxiv_id, ex) self._error_count += 1 - except IndexingFailed as e: - logger.error('Indexing failed: %s', e) + except IndexingFailed as ex: + logger.error("Indexing failed: %s", ex) raise diff --git a/search/agent/tests/test_integration.py b/search/agent/tests/test_integration.py index 9fed6ca3..39a82a85 100644 --- a/search/agent/tests/test_integration.py +++ b/search/agent/tests/test_integration.py @@ -2,118 +2,125 @@ from unittest import TestCase, mock import os +import json import time -import subprocess -import tempfile import boto3 -import json -import threading +import tempfile +import subprocess -from search.agent import process_stream from arxiv.base.agent import StopProcessing -from search.services import metadata from search.domain import DocMeta +from search.services import metadata +from search.agent import process_stream from search.factory import create_ui_web_app -BASE_PATH = os.path.join(os.path.split(os.path.abspath(__file__))[0], - '../../../tests/data/examples') + +BASE_PATH = os.path.join( + os.path.split(os.path.abspath(__file__))[0], "../../../tests/data/examples" +) class TestKinesisIntegration(TestCase): """Test :class:`.MetadataRecordProcessor` with a live Kinesis stream.""" - __test__ = int(bool(os.environ.get('WITH_INTEGRATION', False))) + __test__ = int(bool(os.environ.get("WITH_INTEGRATION", False))) @classmethod def setUpClass(cls): """Spin up ES and index documents.""" - os.environ['ELASTICSEARCH_SERVICE_HOST'] = 'localhost' - os.environ['ELASTICSEARCH_SERVICE_PORT'] = "9201" - os.environ['ELASTICSEARCH_PORT_9201_PROTO'] = "http" - os.environ['ELASTICSEARCH_VERIFY'] = 'false' - - os.environ['KINESIS_STREAM'] = 'MetadataIsAvailable' - 
os.environ['KINESIS_SHARD_ID'] = '0' - os.environ['KINESIS_CHECKPOINT_VOLUME'] = tempfile.mkdtemp() - os.environ['KINESIS_ENDPOINT'] = 'http://127.0.0.1:6568' - os.environ['KINESIS_VERIFY'] = 'false' - os.environ['KINESIS_START_TYPE'] = 'TRIM_HORIZON' - - print('pulling localstack image') - pull_localstack = subprocess.run( + os.environ["ELASTICSEARCH_SERVICE_HOST"] = "localhost" + os.environ["ELASTICSEARCH_SERVICE_PORT"] = "9201" + os.environ["ELASTICSEARCH_SERVICE_PORT_9201_PROTO"] = "http" + os.environ["ELASTICSEARCH_VERIFY"] = "false" + + os.environ["KINESIS_STREAM"] = "MetadataIsAvailable" + os.environ["KINESIS_SHARD_ID"] = "0" + os.environ["KINESIS_CHECKPOINT_VOLUME"] = tempfile.mkdtemp() + os.environ["KINESIS_ENDPOINT"] = "http://127.0.0.1:6568" + os.environ["KINESIS_VERIFY"] = "false" + os.environ["KINESIS_START_TYPE"] = "TRIM_HORIZON" + + print("pulling localstack image") + _ = subprocess.run( "docker pull atlassianlabs/localstack", - stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=True + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + shell=True, ) - print('starting localstack') + print("starting localstack") start_localstack = subprocess.run( "docker run -d -p 6568:4568 --name ltest atlassianlabs/localstack", - stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=True) + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + shell=True, + ) if start_localstack.returncode != 0: raise RuntimeError( - f'Could not start localstack: {start_localstack.stdout}.' - f' Is one already running? Is port 6568 available?' + f"Could not start localstack: {start_localstack.stdout}." + f" Is one already running? Is port 6568 available?" 
) - cls.ls_container = start_localstack.stdout.decode('ascii').strip() - print(f'localstack started as {cls.ls_container}') + cls.ls_container = start_localstack.stdout.decode("ascii").strip() + print(f"localstack started as {cls.ls_container}") cls.client = boto3.client( - 'kinesis', - region_name='us-east-1', + "kinesis", + region_name="us-east-1", endpoint_url="http://localhost:6568", - aws_access_key_id='foo', - aws_secret_access_key='bar', - verify=False + aws_access_key_id="foo", + aws_secret_access_key="bar", + verify=False, ) - print('creating stream ahead of time, to populate with records') + print("creating stream ahead of time, to populate with records") cls.client.create_stream( - StreamName='MetadataIsAvailable', - ShardCount=1 + StreamName="MetadataIsAvailable", ShardCount=1 ) time.sleep(5) - print('created stream, ready to test') + print("created stream, ready to test") cls.app = create_ui_web_app() @classmethod def tearDownClass(cls): """Tear down Elasticsearch once all tests have run.""" - stop_es = subprocess.run(f"docker rm -f {cls.ls_container}", - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, - shell=True) + _ = subprocess.run( + f"docker rm -f {cls.ls_container}", + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + shell=True, + ) - @mock.patch('search.agent.consumer.index') - @mock.patch('search.agent.consumer.metadata') + @mock.patch("search.agent.consumer.index") + @mock.patch("search.agent.consumer.metadata") def test_process(self, mock_metadata, mock_index): """Add some records to the stream, and run processing loop for 5s.""" to_index = [ - "1712.04442", # flux capacitor - "1511.07473", # flux capacitor - "1604.04228", # flux capacitor - "1403.6219", # λ - "1404.3450", # $z_1$ - "1703.09067", # $\lambda$ - "1408.6682", # $\lambda$ - "1607.05107", # Schröder - "1509.08727", # Schroder - "1710.01597", # Schroeder - "1708.07156", # w w - "1401.1012", # Wonmin Son + "1712.04442", # flux capacitor + "1511.07473", # flux capacitor + 
"1604.04228", # flux capacitor + "1403.6219", # λ + "1404.3450", # $z_1$ + "1703.09067", # $\lambda$ + "1408.6682", # $\lambda$ + "1607.05107", # Schröder + "1509.08727", # Schroder + "1710.01597", # Schroeder + "1708.07156", # w w + "1401.1012", # Wonmin Son ] for document_id in to_index: data = bytes( - json.dumps({'document_id': document_id}), - encoding='utf-8' + json.dumps({"document_id": document_id}), encoding="utf-8" ) self.client.put_record( - StreamName='MetadataIsAvailable', Data=data, PartitionKey='0' + StreamName="MetadataIsAvailable", Data=data, PartitionKey="0" ) def retrieve(document_id): - with open(os.path.join(BASE_PATH, f'{document_id}.json')) as f: + with open(os.path.join(BASE_PATH, f"{document_id}.json")) as f: return DocMeta(**json.load(f)) + mock_metadata.retrieve.side_effect = retrieve # Preserve exceptions diff --git a/search/agent/tests/test_record_processor.py b/search/agent/tests/test_record_processor.py index e01ae644..56adf97e 100644 --- a/search/agent/tests/test_record_processor.py +++ b/search/agent/tests/test_record_processor.py @@ -15,15 +15,22 @@ class TestIndexPaper(TestCase): def setUp(self): """Initialize a :class:`.MetadataRecordProcessor`.""" self.checkpointer = mock.MagicMock() - self.args = ('foo', '1', 'a1b2c3d4', 'qwertyuiop', 'us-east-1', - self.checkpointer) - - @mock.patch('boto3.client') - @mock.patch('search.agent.consumer.index.SearchSession') - @mock.patch('search.agent.consumer.transform') - @mock.patch('search.agent.consumer.metadata') - def test_paper_has_one_version(self, mock_meta, mock_tx, mock_idx, - mock_client_factory): + self.args = ( + "foo", + "1", + "a1b2c3d4", + "qwertyuiop", + "us-east-1", + self.checkpointer, + ) + + @mock.patch("boto3.client") + @mock.patch("search.agent.consumer.index.SearchSession") + @mock.patch("search.agent.consumer.transform") + @mock.patch("search.agent.consumer.metadata") + def test_paper_has_one_version( + self, mock_meta, mock_tx, mock_idx, mock_client_factory + ): 
"""The arXiv paper has only one version.""" mock_client = mock.MagicMock() mock_waiter = mock.MagicMock() @@ -31,25 +38,34 @@ def test_paper_has_one_version(self, mock_meta, mock_tx, mock_idx, mock_client_factory.return_value = mock_client processor = consumer.MetadataRecordProcessor(*self.args) - mock_docmeta = DocMeta(version=1, paper_id='1234.56789', title='foo', - submitted_date='2001-03-02T03:04:05-400') + mock_docmeta = DocMeta( + version=1, + paper_id="1234.56789", + title="foo", + submitted_date="2001-03-02T03:04:05-400", + ) mock_meta.retrieve.return_value = mock_docmeta mock_meta.bulk_retrieve.return_value = [mock_docmeta] - mock_doc = Document(version=1, paper_id='1234.56789', title='foo', - submitted_date=['2001-03-02T03:04:05-400']) + mock_doc = Document( + version=1, + paper_id="1234.56789", + title="foo", + submitted_date=["2001-03-02T03:04:05-400"], + ) mock_tx.to_search_document.return_value = mock_doc - processor.index_paper('1234.56789') + processor.index_paper("1234.56789") mock_idx.bulk_add_documents.assert_called_once_with([mock_doc]) - @mock.patch('boto3.client') - @mock.patch('search.agent.consumer.index.SearchSession') - @mock.patch('search.agent.consumer.transform') - @mock.patch('search.agent.consumer.metadata') - def test_paper_has_three_versions(self, mock_meta, mock_tx, mock_idx, - mock_client_factory): + @mock.patch("boto3.client") + @mock.patch("search.agent.consumer.index.SearchSession") + @mock.patch("search.agent.consumer.transform") + @mock.patch("search.agent.consumer.metadata") + def test_paper_has_three_versions( + self, mock_meta, mock_tx, mock_idx, mock_client_factory + ): """The arXiv paper has three versions.""" mock_client = mock.MagicMock() mock_waiter = mock.MagicMock() @@ -57,49 +73,83 @@ def test_paper_has_three_versions(self, mock_meta, mock_tx, mock_idx, mock_client_factory.return_value = mock_client processor = consumer.MetadataRecordProcessor(*self.args) - mock_dm_1 = DocMeta(version=1, paper_id='1234.56789', 
title='foo', - submitted_date='2001-03-02T03:04:05-400') - mock_dm_2 = DocMeta(version=2, paper_id='1234.56789', title='foo', - submitted_date='2001-03-03T03:04:05-400') - mock_dm_3 = DocMeta(version=3, paper_id='1234.56789', title='foo', - submitted_date='2001-03-04T03:04:05-400') + mock_dm_1 = DocMeta( + version=1, + paper_id="1234.56789", + title="foo", + submitted_date="2001-03-02T03:04:05-400", + ) + mock_dm_2 = DocMeta( + version=2, + paper_id="1234.56789", + title="foo", + submitted_date="2001-03-03T03:04:05-400", + ) + mock_dm_3 = DocMeta( + version=3, + paper_id="1234.56789", + title="foo", + submitted_date="2001-03-04T03:04:05-400", + ) mock_meta.retrieve.side_effect = [mock_dm_3, mock_dm_1, mock_dm_2] mock_meta.bulk_retrieve.return_value = [ - mock_dm_3, mock_dm_1, mock_dm_2, mock_dm_3 + mock_dm_3, + mock_dm_1, + mock_dm_2, + mock_dm_3, ] - mock_doc_1 = Document(version=1, paper_id='1234.56789', title='foo', - submitted_date=['2001-03-02T03:04:05-400'], - submitted_date_all=[ - '2001-03-02T03:04:05-400', - ]) - mock_doc_2 = Document(version=2, paper_id='1234.56789', title='foo', - submitted_date=['2001-03-03T03:04:05-400'], - submitted_date_all=[ - '2001-03-02T03:04:05-400', - '2001-03-03T03:04:05-400', - ]) - mock_doc_3 = Document(version=3, paper_id='1234.56789', title='foo', - submitted_date=['2001-03-04T03:04:05-400'], - submitted_date_all=[ - '2001-03-02T03:04:05-400', - '2001-03-03T03:04:05-400', - '2001-03-04T03:04:05-400' - ]) + mock_doc_1 = Document( + version=1, + paper_id="1234.56789", + title="foo", + submitted_date=["2001-03-02T03:04:05-400"], + submitted_date_all=["2001-03-02T03:04:05-400"], + ) + mock_doc_2 = Document( + version=2, + paper_id="1234.56789", + title="foo", + submitted_date=["2001-03-03T03:04:05-400"], + submitted_date_all=[ + "2001-03-02T03:04:05-400", + "2001-03-03T03:04:05-400", + ], + ) + mock_doc_3 = Document( + version=3, + paper_id="1234.56789", + title="foo", + submitted_date=["2001-03-04T03:04:05-400"], + 
submitted_date_all=[ + "2001-03-02T03:04:05-400", + "2001-03-03T03:04:05-400", + "2001-03-04T03:04:05-400", + ], + ) mock_tx.to_search_document.side_effect = [ - mock_doc_3, mock_doc_1, mock_doc_2, mock_doc_3 + mock_doc_3, + mock_doc_1, + mock_doc_2, + mock_doc_3, ] - processor.index_paper('1234.56789') - self.assertEqual(mock_meta.bulk_retrieve.call_count, 1, - "Metadata should be retrieved for current version" - " with bulk_retrieve") - self.assertEqual(mock_meta.retrieve.call_count, 0, - "Metadata should be retrieved for each non-current" - " version") + processor.index_paper("1234.56789") + self.assertEqual( + mock_meta.bulk_retrieve.call_count, + 1, + "Metadata should be retrieved for current version" + " with bulk_retrieve", + ) + self.assertEqual( + mock_meta.retrieve.call_count, + 0, + "Metadata should be retrieved for each non-current" " version", + ) mock_idx.bulk_add_documents.assert_called_once_with( - [mock_doc_3, mock_doc_1, mock_doc_2, mock_doc_3]) + [mock_doc_3, mock_doc_1, mock_doc_2, mock_doc_3] + ) class TestAddToIndex(TestCase): @@ -108,11 +158,17 @@ class TestAddToIndex(TestCase): def setUp(self): """Initialize a :class:`.MetadataRecordProcessor`.""" self.checkpointer = mock.MagicMock() - self.args = ('foo', '1', 'a1b2c3d4', 'qwertyuiop', 'us-east-1', - self.checkpointer) - - @mock.patch('boto3.client') - @mock.patch('search.agent.consumer.index.SearchSession') + self.args = ( + "foo", + "1", + "a1b2c3d4", + "qwertyuiop", + "us-east-1", + self.checkpointer, + ) + + @mock.patch("boto3.client") + @mock.patch("search.agent.consumer.index.SearchSession") def test_add_document_succeeds(self, mock_index, mock_client_factory): """The search document is added successfully.""" mock_client = mock.MagicMock() @@ -122,14 +178,15 @@ def test_add_document_succeeds(self, mock_index, mock_client_factory): processor = consumer.MetadataRecordProcessor(*self.args) try: processor._add_to_index(Document()) - except Exception as e: - self.fail(e) + except Exception 
as ex: + self.fail(ex) mock_index.add_document.assert_called_once() - @mock.patch('boto3.client') - @mock.patch('search.agent.consumer.index.SearchSession') - def test_index_raises_index_connection_error(self, mock_index, - mock_client_factory): + @mock.patch("boto3.client") + @mock.patch("search.agent.consumer.index.SearchSession") + def test_index_raises_index_connection_error( + self, mock_index, mock_client_factory + ): """The index raises :class:`.index.IndexConnectionError`.""" mock_client = mock.MagicMock() mock_waiter = mock.MagicMock() @@ -141,10 +198,11 @@ def test_index_raises_index_connection_error(self, mock_index, with self.assertRaises(consumer.IndexingFailed): processor._add_to_index(Document()) - @mock.patch('boto3.client') - @mock.patch('search.agent.consumer.index.SearchSession') - def test_index_raises_unhandled_error(self, mock_index, - mock_client_factory): + @mock.patch("boto3.client") + @mock.patch("search.agent.consumer.index.SearchSession") + def test_index_raises_unhandled_error( + self, mock_index, mock_client_factory + ): """The index raises an unhandled exception.""" mock_client = mock.MagicMock() mock_waiter = mock.MagicMock() @@ -163,13 +221,20 @@ class TestBulkAddToIndex(TestCase): def setUp(self): """Initialize a :class:`.MetadataRecordProcessor`.""" self.checkpointer = mock.MagicMock() - self.args = ('foo', '1', 'a1b2c3d4', 'qwertyuiop', 'us-east-1', - self.checkpointer) - - @mock.patch('boto3.client') - @mock.patch('search.agent.consumer.index.SearchSession') - def test_bulk_add_documents_succeeds(self, mock_index, - mock_client_factory): + self.args = ( + "foo", + "1", + "a1b2c3d4", + "qwertyuiop", + "us-east-1", + self.checkpointer, + ) + + @mock.patch("boto3.client") + @mock.patch("search.agent.consumer.index.SearchSession") + def test_bulk_add_documents_succeeds( + self, mock_index, mock_client_factory + ): """The search document is added successfully.""" mock_client = mock.MagicMock() mock_waiter = mock.MagicMock() @@ 
-178,14 +243,15 @@ def test_bulk_add_documents_succeeds(self, mock_index, processor = consumer.MetadataRecordProcessor(*self.args) try: processor._bulk_add_to_index([Document()]) - except Exception as e: - self.fail(e) + except Exception as ex: + self.fail(ex) mock_index.bulk_add_documents.assert_called_once() - @mock.patch('boto3.client') - @mock.patch('search.agent.consumer.index.SearchSession') - def test_index_raises_index_connection_error(self, mock_index, - mock_client_factory): + @mock.patch("boto3.client") + @mock.patch("search.agent.consumer.index.SearchSession") + def test_index_raises_index_connection_error( + self, mock_index, mock_client_factory + ): """The index raises :class:`.index.IndexConnectionError`.""" mock_client = mock.MagicMock() mock_waiter = mock.MagicMock() @@ -197,10 +263,11 @@ def test_index_raises_index_connection_error(self, mock_index, with self.assertRaises(consumer.IndexingFailed): processor._bulk_add_to_index([Document()]) - @mock.patch('boto3.client') - @mock.patch('search.agent.consumer.index.SearchSession') - def test_index_raises_unhandled_error(self, mock_index, - mock_client_factory): + @mock.patch("boto3.client") + @mock.patch("search.agent.consumer.index.SearchSession") + def test_index_raises_unhandled_error( + self, mock_index, mock_client_factory + ): """The index raises an unhandled exception.""" mock_client = mock.MagicMock() mock_waiter = mock.MagicMock() @@ -219,13 +286,20 @@ class TestTransformToDocument(TestCase): def setUp(self): """Initialize a :class:`.MetadataRecordProcessor`.""" self.checkpointer = mock.MagicMock() - self.args = ('foo', '1', 'a1b2c3d4', 'qwertyuiop', 'us-east-1', - self.checkpointer) - - @mock.patch('boto3.client') - @mock.patch('search.agent.consumer.transform') - def test_transform_raises_exception(self, mock_transform, - mock_client_factory): + self.args = ( + "foo", + "1", + "a1b2c3d4", + "qwertyuiop", + "us-east-1", + self.checkpointer, + ) + + @mock.patch("boto3.client") + 
@mock.patch("search.agent.consumer.transform") + def test_transform_raises_exception( + self, mock_transform, mock_client_factory + ): """The transform module raises an exception.""" mock_client = mock.MagicMock() mock_waiter = mock.MagicMock() @@ -244,13 +318,20 @@ class TestGetMetadata(TestCase): def setUp(self): """Initialize a :class:`.MetadataRecordProcessor`.""" self.checkpointer = mock.MagicMock() - self.args = ('foo', '1', 'a1b2c3d4', 'qwertyuiop', 'us-east-1', - self.checkpointer) - - @mock.patch('boto3.client') - @mock.patch('search.agent.consumer.metadata') - def test_metadata_service_returns_metadata(self, mock_metadata, - mock_client_factory): + self.args = ( + "foo", + "1", + "a1b2c3d4", + "qwertyuiop", + "us-east-1", + self.checkpointer, + ) + + @mock.patch("boto3.client") + @mock.patch("search.agent.consumer.metadata") + def test_metadata_service_returns_metadata( + self, mock_metadata, mock_client_factory + ): """The metadata service returns valid metadata.""" mock_client = mock.MagicMock() mock_waiter = mock.MagicMock() @@ -260,13 +341,17 @@ def test_metadata_service_returns_metadata(self, mock_metadata, docmeta = DocMeta() mock_metadata.retrieve.return_value = docmeta - self.assertEqual(docmeta, processor._get_metadata('1234.5678'), - "The metadata is returned.") - - @mock.patch('boto3.client') - @mock.patch('search.agent.consumer.metadata') - def test_metadata_service_raises_connection_error(self, mock_metadata, - mock_client_factory): + self.assertEqual( + docmeta, + processor._get_metadata("1234.5678"), + "The metadata is returned.", + ) + + @mock.patch("boto3.client") + @mock.patch("search.agent.consumer.metadata") + def test_metadata_service_raises_connection_error( + self, mock_metadata, mock_client_factory + ): """The metadata service raises :class:`.metadata.ConnectionFailed`.""" mock_metadata.RequestFailed = metadata.RequestFailed mock_metadata.ConnectionFailed = metadata.ConnectionFailed @@ -279,12 +364,13 @@ def 
test_metadata_service_raises_connection_error(self, mock_metadata, mock_metadata.retrieve.side_effect = metadata.ConnectionFailed with self.assertRaises(consumer.IndexingFailed): - processor._get_metadata('1234.5678') + processor._get_metadata("1234.5678") - @mock.patch('boto3.client') - @mock.patch('search.agent.consumer.metadata') - def test_metadata_service_raises_request_error(self, mock_metadata, - mock_client_factory): + @mock.patch("boto3.client") + @mock.patch("search.agent.consumer.metadata") + def test_metadata_service_raises_request_error( + self, mock_metadata, mock_client_factory + ): """The metadata service raises :class:`.metadata.RequestFailed`.""" mock_metadata.RequestFailed = metadata.RequestFailed mock_metadata.ConnectionFailed = metadata.ConnectionFailed @@ -298,12 +384,13 @@ def test_metadata_service_raises_request_error(self, mock_metadata, mock_metadata.retrieve.side_effect = metadata.RequestFailed with self.assertRaises(consumer.DocumentFailed): - processor._get_metadata('1234.5678') + processor._get_metadata("1234.5678") - @mock.patch('boto3.client') - @mock.patch('search.agent.consumer.metadata') - def test_metadata_service_raises_bad_response(self, mock_metadata, - mock_client_factory): + @mock.patch("boto3.client") + @mock.patch("search.agent.consumer.metadata") + def test_metadata_service_raises_bad_response( + self, mock_metadata, mock_client_factory + ): """The metadata service raises :class:`.metadata.BadResponse`.""" mock_metadata.RequestFailed = metadata.RequestFailed mock_metadata.ConnectionFailed = metadata.ConnectionFailed @@ -317,4 +404,4 @@ def test_metadata_service_raises_bad_response(self, mock_metadata, mock_metadata.retrieve.side_effect = metadata.BadResponse with self.assertRaises(consumer.DocumentFailed): - processor._get_metadata('1234.5678') + processor._get_metadata("1234.5678") diff --git a/search/config.py b/search/config.py index 3fa3337c..b14b9daa 100644 --- a/search/config.py +++ b/search/config.py @@ -6,28 +6,30 
@@ """ import os -APP_VERSION = '0.5.5' +APP_VERSION = "0.5.6" """The application version """ -ON = 'yes' -OFF = 'no' +ON = "yes" +OFF = "no" -DEBUG = os.environ.get('DEBUG') == ON +DEBUG = os.environ.get("DEBUG") == ON """enable/disable debug mode""" -TESTING = os.environ.get('TESTING') == ON +TESTING = os.environ.get("TESTING") == ON """enable/disable testing mode""" -PROPAGATE_EXCEPTIONS = \ - True if os.environ.get('PROPAGATE_EXCEPTIONS') == ON else None +PROPAGATE_EXCEPTIONS = ( + True if os.environ.get("PROPAGATE_EXCEPTIONS") == ON else None +) """ explicitly enable or disable the propagation of exceptions. If not set or explicitly set to None this is implicitly true if either TESTING or DEBUG is true. """ -PRESERVE_CONTEXT_ON_EXCEPTION = \ - True if os.environ.get('PRESERVE_CONTEXT_ON_EXCEPTION') == ON else None +PRESERVE_CONTEXT_ON_EXCEPTION = ( + True if os.environ.get("PRESERVE_CONTEXT_ON_EXCEPTION") == ON else None +) """ By default if the application is in debug mode the request context is not popped on exceptions to enable debuggers to introspect the data. This can be @@ -37,13 +39,13 @@ """ -USE_X_SENDFILE = os.environ.get('USE_X_SENDFILE') == ON +USE_X_SENDFILE = os.environ.get("USE_X_SENDFILE") == ON """Enable/disable x-sendfile""" -LOGGER_NAME = os.environ.get('LOGGER_NAME', 'search') +LOGGER_NAME = os.environ.get("LOGGER_NAME", "search") """The name of the logger.""" -LOGGER_HANDLER_POLICY = os.environ.get('LOGGER_HANDLER_POLICY', 'debug') +LOGGER_HANDLER_POLICY = os.environ.get("LOGGER_HANDLER_POLICY", "debug") """ the policy of the default logging handler. The default is 'always' which means that the default logging handler is always active. 'debug' will only activate @@ -51,7 +53,7 @@ disables it entirely. """ -SERVER_NAME = os.environ.get('SEARCH_SERVER_NAME', None) +SERVER_NAME = os.environ.get("SEARCH_SERVER_NAME", None) """ the name and port number of the server. 
Required for subdomain support (e.g.: 'myapp.dev:5000') Note that localhost does not support subdomains so @@ -60,21 +62,22 @@ application context. """ -APPLICATION_ROOT = os.environ.get('APPLICATION_ROOT', '/') +APPLICATION_ROOT = os.environ.get("APPLICATION_ROOT", "/") """ If the application does not occupy a whole domain or subdomain this can be set to the path where the application is configured to live. This is for session cookie as path value. If domains are used, this should be None. """ -MAX_CONTENT_LENGTH = os.environ.get('MAX_CONTENT_LENGTH', None) +MAX_CONTENT_LENGTH = os.environ.get("MAX_CONTENT_LENGTH", None) """ If set to a value in bytes, Flask will reject incoming requests with a content length greater than this by returning a 413 status code. """ -SEND_FILE_MAX_AGE_DEFAULT = int(os.environ.get('SEND_FILE_MAX_AGE_DEFAULT', - 43200)) +SEND_FILE_MAX_AGE_DEFAULT = int( + os.environ.get("SEND_FILE_MAX_AGE_DEFAULT", 43200) +) """ Default cache control max age to use with send_static_file() (the default static file handler) and send_file(), as datetime.timedelta or as seconds. @@ -82,7 +85,7 @@ on Flask or Blueprint, respectively. Defaults to 43200 (12 hours). """ -TRAP_HTTP_EXCEPTIONS = os.environ.get('TRAP_HTTP_EXCEPTIONS') == ON +TRAP_HTTP_EXCEPTIONS = os.environ.get("TRAP_HTTP_EXCEPTIONS") == ON """ If this is set to True Flask will not execute the error handlers of HTTP exceptions but instead treat the exception like any other and bubble it through @@ -90,7 +93,7 @@ have to find out where an HTTP exception is coming from. """ -TRAP_BAD_REQUEST_ERRORS = os.environ.get('TRAP_BAD_REQUEST_ERRORS') == ON +TRAP_BAD_REQUEST_ERRORS = os.environ.get("TRAP_BAD_REQUEST_ERRORS") == ON """ Werkzeug's internal data structures that deal with request specific data will raise special key errors that are also bad request exceptions. Likewise many @@ -100,13 +103,13 @@ regular traceback instead. 
""" -PREFERRED_URL_SCHEME = os.environ.get('PREFERRED_URL_SCHEME', 'http') +PREFERRED_URL_SCHEME = os.environ.get("PREFERRED_URL_SCHEME", "http") """ The URL scheme that should be used for URL generation if no URL scheme is available. This defaults to http. """ -JSON_AS_ASCII = os.environ.get('JSON_AS_ASCII') == ON +JSON_AS_ASCII = os.environ.get("JSON_AS_ASCII") == ON """ By default Flask serialize object to ascii-encoded JSON. If this is set to False Flask will not encode to ASCII and output strings as-is and return @@ -114,36 +117,36 @@ transport for instance. """ -JSON_SORT_KEYS = os.environ.get('JSON_AS_ASCII') != OFF +JSON_SORT_KEYS = os.environ.get("JSON_AS_ASCII") != OFF """ -By default Flask will serialize JSON objects in a way that the keys are ordered. -This is done in order to ensure that independent of the hash seed of the -dictionary the return value will be consistent to not trash external HTTP +By default Flask will serialize JSON objects in a way that the keys are +ordered. This is done in order to ensure that independent of the hash seed of +the dictionary the return value will be consistent to not trash external HTTP caches. You can override the default behavior by changing this variable. This is not recommended but might give you a performance improvement on the cost of cacheability. """ -JSONIFY_PRETTYPRINT_REGULAR = os.environ.get('JSON_AS_ASCII') != OFF +JSONIFY_PRETTYPRINT_REGULAR = os.environ.get("JSON_AS_ASCII") != OFF """ If this is set to True (the default) jsonify responses will be pretty printed if they are not requested by an XMLHttpRequest object (controlled by the X-Requested-With header). """ -JSONIFY_MIMETYPE = os.environ.get('JSONIFY_MIMETYPE', 'application/json') +JSONIFY_MIMETYPE = os.environ.get("JSONIFY_MIMETYPE", "application/json") """ MIME type used for jsonify responses. 
""" -TEMPLATES_AUTO_RELOAD = os.environ.get('TEMPLATES_AUTO_RELOAD') == ON +TEMPLATES_AUTO_RELOAD = os.environ.get("TEMPLATES_AUTO_RELOAD") == ON """ Whether to check for modifications of the template source and reload it automatically. By default the value is None which means that Flask checks original file only in debug mode. """ -EXPLAIN_TEMPLATE_LOADING = os.environ.get('EXPLAIN_TEMPLATE_LOADING') == ON +EXPLAIN_TEMPLATE_LOADING = os.environ.get("EXPLAIN_TEMPLATE_LOADING") == ON """ If this is enabled then every attempt to load a template will write an info message to the logger explaining the attempts to locate the template. This can @@ -152,66 +155,69 @@ """ # AWS credentials. -AWS_ACCESS_KEY_ID = os.environ.get('AWS_ACCESS_KEY_ID', 'nope') -AWS_SECRET_ACCESS_KEY = os.environ.get('AWS_SECRET_ACCESS_KEY', 'nope') -AWS_REGION = os.environ.get('AWS_REGION', 'us-east-1') +AWS_ACCESS_KEY_ID = os.environ.get("AWS_ACCESS_KEY_ID", "nope") +AWS_SECRET_ACCESS_KEY = os.environ.get("AWS_SECRET_ACCESS_KEY", "nope") +AWS_REGION = os.environ.get("AWS_REGION", "us-east-1") -LOGFILE = os.environ.get('LOGFILE') -LOGLEVEL = os.environ.get('LOGLEVEL', 40) +LOGFILE = os.environ.get("LOGFILE") +LOGLEVEL = os.environ.get("LOGLEVEL", 40) """ Log level for search service. See ``_ . 
""" -ELASTICSEARCH_HOST = os.environ.get('ELASTICSEARCH_SERVICE_HOST', 'localhost') -ELASTICSEARCH_PORT = os.environ.get('ELASTICSEARCH_SERVICE_PORT', '9200') -ELASTICSEARCH_SCHEME = os.environ.get( - 'ELASTICSEARCH_PORT_%s_PROTO' % ELASTICSEARCH_PORT, 'http' +ELASTICSEARCH_SERVICE_HOST = os.environ.get( + "ELASTICSEARCH_SERVICE_HOST", "localhost" ) -ELASTICSEARCH_INDEX = os.environ.get('ELASTICSEARCH_INDEX', 'arxiv') -ELASTICSEARCH_USER = os.environ.get('ELASTICSEARCH_USER', None) -ELASTICSEARCH_PASSWORD = os.environ.get('ELASTICSEARCH_PASSWORD', None) -ELASTICSEARCH_VERIFY = os.environ.get('ELASTICSEARCH_VERIFY', 'true') +ELASTICSEARCH_SERVICE_PORT = os.environ.get( + "ELASTICSEARCH_SERVICE_PORT", "9200" +) +_proto_key = "ELASTICSEARCH_SERVICE_PORT_%s_PROTO" % ELASTICSEARCH_SERVICE_PORT +locals()[_proto_key] = os.environ.get(_proto_key, "http") + +ELASTICSEARCH_INDEX = os.environ.get("ELASTICSEARCH_INDEX", "arxiv") +ELASTICSEARCH_USER = os.environ.get("ELASTICSEARCH_USER", None) +ELASTICSEARCH_PASSWORD = os.environ.get("ELASTICSEARCH_PASSWORD", None) +ELASTICSEARCH_VERIFY = os.environ.get("ELASTICSEARCH_VERIFY", "true") """Indicates whether SSL certificate verification for ES should be enforced.""" -METADATA_ENDPOINT = os.environ.get('METADATA_ENDPOINT', - 'https://arxiv.org/') +METADATA_ENDPOINT = os.environ.get("METADATA_ENDPOINT", "https://arxiv.org/") """ Location of endpoint(s) for metadata retrieval. Multiple endpoints may be provided with comma delimitation. 
""" -METADATA_CACHE_DIR = os.environ.get('METADATA_CACHE_DIR') +METADATA_CACHE_DIR = os.environ.get("METADATA_CACHE_DIR") """Cache directory for metadata documents.""" -METADATA_VERIFY_CERT = os.environ.get('METADATA_VERIFY_CERT', 'True') +METADATA_VERIFY_CERT = os.environ.get("METADATA_VERIFY_CERT", "True") """If ``False``, SSL certificate verification will be disabled.""" -FULLTEXT_ENDPOINT = os.environ.get('FULLTEXT_ENDPOINT', - 'https://fulltext.arxiv.org/fulltext/') +FULLTEXT_ENDPOINT = os.environ.get( + "FULLTEXT_ENDPOINT", "https://fulltext.arxiv.org/fulltext/" +) # Settings for the indexing agent. -KINESIS_ENDPOINT = os.environ.get('KINESIS_ENDPOINT') +KINESIS_ENDPOINT = os.environ.get("KINESIS_ENDPOINT") """Can be used to set an alternate endpoint, e.g. for testing.""" -KINESIS_VERIFY = os.environ.get('KINESIS_VERIFY', "true") +KINESIS_VERIFY = os.environ.get("KINESIS_VERIFY", "true") """Indicates whether SSL certificate verification should be enforced.""" -KINESIS_STREAM = os.environ.get('KINESIS_STREAM', 'MetadataIsAvailable') +KINESIS_STREAM = os.environ.get("KINESIS_STREAM", "MetadataIsAvailable") """Name of the stream to which the indexing agent subscribes.""" -KINESIS_SHARD_ID = os.environ.get('KINESIS_SHARD_ID', '0') +KINESIS_SHARD_ID = os.environ.get("KINESIS_SHARD_ID", "0") -KINESIS_CHECKPOINT_VOLUME = os.environ.get('KINESIS_CHECKPOINT_VOLUME', - '/tmp') +KINESIS_CHECKPOINT_VOLUME = os.environ.get("KINESIS_CHECKPOINT_VOLUME", "/tmp") -KINESIS_START_TYPE = os.environ.get('KINESIS_START_TYPE', 'AT_TIMESTAMP') -KINESIS_START_AT = os.environ.get('KINESIS_START_AT') +KINESIS_START_TYPE = os.environ.get("KINESIS_START_TYPE", "AT_TIMESTAMP") +KINESIS_START_AT = os.environ.get("KINESIS_START_AT") -KINESIS_SLEEP = os.environ.get('KINESIS_SLEEP', '0.1') +KINESIS_SLEEP = os.environ.get("KINESIS_SLEEP", "0.1") """Amount of time to wait before moving on to the next record.""" @@ -220,19 +226,19 @@ See ``_. 
""" -FLASKS3_BUCKET_NAME = os.environ.get('FLASKS3_BUCKET_NAME', 'some_bucket') -FLASKS3_CDN_DOMAIN = os.environ.get('FLASKS3_CDN_DOMAIN', 'static.arxiv.org') -FLASKS3_USE_HTTPS = os.environ.get('FLASKS3_USE_HTTPS', 1) -FLASKS3_FORCE_MIMETYPE = os.environ.get('FLASKS3_FORCE_MIMETYPE', 1) -FLASKS3_ACTIVE = os.environ.get('FLASKS3_ACTIVE', 0) +FLASKS3_BUCKET_NAME = os.environ.get("FLASKS3_BUCKET_NAME", "some_bucket") +FLASKS3_CDN_DOMAIN = os.environ.get("FLASKS3_CDN_DOMAIN", "static.arxiv.org") +FLASKS3_USE_HTTPS = os.environ.get("FLASKS3_USE_HTTPS", 1) +FLASKS3_FORCE_MIMETYPE = os.environ.get("FLASKS3_FORCE_MIMETYPE", 1) +FLASKS3_ACTIVE = os.environ.get("FLASKS3_ACTIVE", 0) # Settings for display of release information -RELEASE_NOTES_URL = 'https://confluence.cornell.edu/x/giazFQ' -RELEASE_NOTES_TEXT = 'Search v0.5 released 2018-12-20' +RELEASE_NOTES_URL = "https://github.com/arXiv/arxiv-search/releases" +RELEASE_NOTES_TEXT = "Search v0.5.6 released 2020-02-24" -EXTERNAL_URL_SCHEME = os.environ.get('EXTERNAL_URL_SCHEME', 'https') -BASE_SERVER = os.environ.get('BASE_SERVER', 'arxiv.org') +EXTERNAL_URL_SCHEME = os.environ.get("EXTERNAL_URL_SCHEME", "https") +BASE_SERVER = os.environ.get("BASE_SERVER", "arxiv.org") URLS = [ ("pdf", "/pdf/v", BASE_SERVER), @@ -253,7 +259,7 @@ ("other_by_id", "/format/", BASE_SERVER), ] -JWT_SECRET = os.environ.get('JWT_SECRET', 'foosecret') +JWT_SECRET = os.environ.get("JWT_SECRET", "foosecret") # TODO: one place to set the version, update release notes text, JIRA issue # collector, etc. 
diff --git a/search/consts.py b/search/consts.py new file mode 100644 index 00000000..e8ae800d --- /dev/null +++ b/search/consts.py @@ -0,0 +1,14 @@ +"""Constants.""" +from pytz import timezone + +# Sorting + +DEFAULT_SORT_ORDER = [ + {"announced_date_first": {"order": "desc"}}, + {"_doc": {"order": "asc"}}, +] + + +# Timezones + +EASTERN = timezone("US/Eastern") diff --git a/search/context.py b/search/context.py index cf7b5169..095df160 100644 --- a/search/context.py +++ b/search/context.py @@ -7,7 +7,9 @@ import werkzeug -def get_application_config(app: Optional[Union[Flask, object]] = None) -> Union[dict, os._Environ]: +def get_application_config( + app: Optional[Union[Flask, object]] = None +) -> Union[dict, os._Environ]: """ Get a configuration from the current app, or fall back to env. @@ -24,9 +26,9 @@ def get_application_config(app: Optional[Union[Flask, object]] = None) -> Union[ # pylint: disable=protected-access if app is not None: if isinstance(app, Flask): - return app.config # type: ignore - if flask_app: # Proxy object; falsey if there is no application context. - return flask_app.config # type: ignore + return app.config # type: ignore + if flask_app: # Proxy object; falsey if there is no application context. + return flask_app.config # type: ignore return os.environ @@ -39,5 +41,5 @@ def get_application_global() -> Optional[werkzeug.local.LocalProxy]: proxy or None """ if g: - return g # type: ignore + return g # type: ignore return None diff --git a/search/controllers/__init__.py b/search/controllers/__init__.py index 95a0c3a0..d5c36596 100644 --- a/search/controllers/__init__.py +++ b/search/controllers/__init__.py @@ -7,8 +7,9 @@ of response data (``dict``), status code (``int``), and extra response headers (``dict``). 
""" +from http import HTTPStatus from typing import Tuple, Dict, Any -from arxiv import status + from search.services import index from search.domain import SimpleQuery @@ -29,10 +30,10 @@ def health_check() -> Tuple[str, int, Dict[str, Any]]: """ try: document_set = index.SearchSession.search( # type: ignore - SimpleQuery(search_field='all', value='theory') + SimpleQuery(search_field="all", value="theory") ) except Exception: - return 'DOWN', status.HTTP_500_INTERNAL_SERVER_ERROR, {} - if document_set['results']: - return 'OK', status.HTTP_200_OK, {} - return 'DOWN', status.HTTP_500_INTERNAL_SERVER_ERROR, {} + return "DOWN", HTTPStatus.INTERNAL_SERVER_ERROR, {} + if document_set["results"]: + return "OK", HTTPStatus.OK, {} + return "DOWN", HTTPStatus.INTERNAL_SERVER_ERROR, {} diff --git a/search/controllers/advanced/__init__.py b/search/controllers/advanced/__init__.py index 53ab4f12..c1ef08aa 100644 --- a/search/controllers/advanced/__init__.py +++ b/search/controllers/advanced/__init__.py @@ -7,33 +7,41 @@ parameters, and produce informative error messages for the user. 
""" -from typing import Tuple, Dict, Any, Optional import re +from http import HTTPStatus +from typing import Tuple, List, Dict, Any, Optional from datetime import date, datetime from dateutil.relativedelta import relativedelta -from pytz import timezone - +from flask import url_for from werkzeug.datastructures import MultiDict, ImmutableMultiDict from werkzeug.exceptions import InternalServerError, BadRequest, NotFound -from flask import url_for -from arxiv import status, taxonomy -from search.services import index, SearchSession, fulltext, metadata -from search.domain import AdvancedQuery, FieldedSearchTerm, DateRange, \ - Classification, FieldedSearchList, ClassificationList, Query, asdict +from arxiv import taxonomy from arxiv.base import logging +from search.services import index, SearchSession +from search.domain import ( + AdvancedQuery, + FieldedSearchTerm, + DateRange, + Classification, + FieldedSearchList, + ClassificationList, + Query, +) +from search import consts +from search.controllers.advanced import forms from search.controllers.util import paginate, catch_underscore_syntax -from . 
import forms logger = logging.getLogger(__name__) + Response = Tuple[Dict[str, Any], int, Dict[str, Any]] -EASTERN = timezone('US/Eastern') -TERM_FIELD_PTN = re.compile(r'terms-([0-9])+-term') + +TERM_FIELD_PTN = re.compile(r"terms-([0-9])+-term") def search(request_params: MultiDict) -> Response: @@ -68,10 +76,10 @@ def search(request_params: MultiDict) -> Response: if isinstance(request_params, ImmutableMultiDict): request_params = MultiDict(request_params.items(multi=True)) - logger.debug('search request from advanced form') + logger.debug("search request from advanced form") response_data: Dict[str, Any] = {} - response_data['show_form'] = ('advanced' not in request_params) - logger.debug('show_form: %s', str(response_data['show_form'])) + response_data["show_form"] = "advanced" not in request_params + logger.debug("show_form: %s", str(response_data["show_form"])) # Here we intervene on the user's query to look for holdouts from # the classic search system's author indexing syntax (surname_f). We @@ -86,26 +94,26 @@ def search(request_params: MultiDict) -> Response: continue value = str(value) i = match.group(1) - field = request_params.get(f'terms-{i}-field') + field = request_params.get(f"terms-{i}-field") # We are only looking for this syntax in the author search, or # in an all-fields search. - if field not in ['all', 'author']: + if field not in ["all", "author"]: continue value, _has_classic = catch_underscore_syntax(value) has_classic = _has_classic if not has_classic else has_classic request_params.setlist(key, [value]) - response_data['has_classic_format'] = has_classic + response_data["has_classic_format"] = has_classic form = forms.AdvancedSearchForm(request_params) q: Optional[Query] # We want to avoid attempting to validate if no query has been entered. # If a query was actually submitted via the form, 'advanced' will be # present in the request parameters. 
- if 'advanced' in request_params: + if "advanced" in request_params: if form.validate(): - logger.debug('form is valid') + logger.debug("form is valid") q = _query_from_form(form) # Pagination is handled outside of the form. @@ -116,53 +124,54 @@ def search(request_params: MultiDict) -> Response: # template rendering, so they get added directly to the # response content. asdict( response_data.update(SearchSession.search(q)) # type: ignore - except index.IndexConnectionError as e: + except index.IndexConnectionError as ex: # There was a (hopefully transient) connection problem. Either # this will clear up relatively quickly (next request), or # there is a more serious outage. - logger.error('IndexConnectionError: %s', e) + logger.error("IndexConnectionError: %s", ex) raise InternalServerError( "There was a problem connecting to the search index. This " "is quite likely a transient issue, so please try your " "search again. If this problem persists, please report it " "to help@arxiv.org." - ) from e - except index.QueryError as e: + ) from ex + except index.QueryError as ex: # Base exception routers should pick this up and show bug page. - logger.error('QueryError: %s', e) + logger.error("QueryError: %s", ex) raise InternalServerError( "There was a problem executing your query. Please try " "your search again. If this problem persists, please " "report it to help@arxiv.org." - ) from e - except index.OutsideAllowedRange as e: + ) from ex + except index.OutsideAllowedRange as ex: raise BadRequest( "Hello clever friend. You can't get results in that range" " right now." 
- ) from e - response_data['query'] = q + ) from ex + response_data["query"] = q else: - logger.debug('form is invalid: %s', str(form.errors)) - if 'order' in form.errors or 'size' in form.errors: + logger.debug("form is invalid: %s", str(form.errors)) + if "order" in form.errors or "size" in form.errors: # It's likely that the user tried to set these parameters # manually, or that the search originated from somewhere else # (and was configured incorrectly). - advanced_url = url_for('ui.advanced_search') + advanced_url = url_for("ui.advanced_search") raise BadRequest( f"It looks like there's something odd about your search" f" request. Please try starting" - f" over.") + f" over." + ) # Force the form to be displayed, so that we can render errors. # This has most likely occurred due to someone manually crafting # a GET response, but it could be something else. - response_data['show_form'] = True + response_data["show_form"] = True # We want the form handy even when it is not shown to the user. For # example, we can generate new form-friendly requests to update sort # order and page size by embedding the form (hidden). 
- response_data['form'] = form - return response_data, status.HTTP_200_OK, {} + response_data["form"] = form + return response_data, HTTPStatus.OK, {} def _query_from_form(form: forms.AdvancedSearchForm) -> AdvancedQuery: @@ -183,97 +192,130 @@ def _query_from_form(form: forms.AdvancedSearchForm) -> AdvancedQuery: q = _update_query_with_dates(q, form.date.data) q = _update_query_with_terms(q, form.terms.data) q = _update_query_with_classification(q, form.classification.data) - q.include_cross_list = form.classification.include_cross_list.data \ + q.include_cross_list = ( + form.classification.include_cross_list.data == form.classification.INCLUDE_CROSS_LIST + ) if form.include_older_versions.data: q.include_older_versions = True order = form.order.data - if order and order != 'None': + if order and order != "None": q.order = order q.hide_abstracts = form.abstracts.data == form.HIDE_ABSTRACTS return q -def _update_query_with_classification(q: AdvancedQuery, data: MultiDict) \ - -> AdvancedQuery: +def _update_query_with_classification( + q: AdvancedQuery, data: MultiDict +) -> AdvancedQuery: q.classification = ClassificationList() archives = [ - ('computer_science', 'cs'), ('economics', 'econ'), ('eess', 'eess'), - ('mathematics', 'math'), ('q_biology', 'q-bio'), - ('q_finance', 'q-fin'), ('statistics', 'stat') + ("computer_science", "cs"), + ("economics", "econ"), + ("eess", "eess"), + ("mathematics", "math"), + ("q_biology", "q-bio"), + ("q_finance", "q-fin"), + ("statistics", "stat"), ] for field, archive in archives: if data.get(field): # Fix for these typing issues is coming soon! 
# See: https://github.com/python/mypy/pull/4397 q.classification.append( - Classification(archive={'id': archive}) # type: ignore + Classification(archive={"id": archive}) # type: ignore ) - if data.get('physics') and 'physics_archives' in data: - if 'all' in data['physics_archives']: + if data.get("physics") and "physics_archives" in data: + if "all" in data["physics_archives"]: q.classification.append( - Classification(group={'id': 'grp_physics'}) # type: ignore + Classification(group={"id": "grp_physics"}) # type: ignore ) else: q.classification.append( - Classification( # type: ignore - group={'id': 'grp_physics'}, - archive={'id': data['physics_archives']} + Classification( # type: ignore + group={"id": "grp_physics"}, + archive={"id": data["physics_archives"]}, ) ) return q -def _update_query_with_terms(q: AdvancedQuery, terms_data: list) \ - -> AdvancedQuery: - q.terms = FieldedSearchList([ - FieldedSearchTerm(**term) # type: ignore - for term in terms_data if term['term'] - ]) +# FIXME: Argument type. +def _update_query_with_terms( + q: AdvancedQuery, terms_data: List[Any] +) -> AdvancedQuery: + q.terms = FieldedSearchList( + [ + FieldedSearchTerm(**term) # type: ignore + for term in terms_data + if term["term"] + ] + ) return q -def _update_query_with_dates(q: AdvancedQuery, date_data: MultiDict) \ - -> AdvancedQuery: - filter_by = date_data['filter_by'] - if filter_by == 'all_dates': # Nothing to do; all dates by default. +def _update_query_with_dates( + q: AdvancedQuery, date_data: MultiDict +) -> AdvancedQuery: + filter_by = date_data["filter_by"] + if filter_by == "all_dates": # Nothing to do; all dates by default. return q - elif filter_by == 'past_12': + elif filter_by == "past_12": one_year_ago = date.today() - relativedelta(months=12) # Fix for these typing issues is coming soon! 
# See: https://github.com/python/mypy/pull/4397 - q.date_range = DateRange( # type: ignore - start_date=datetime(year=one_year_ago.year, - month=one_year_ago.month, - day=1, hour=0, minute=0, second=0, - tzinfo=EASTERN) + q.date_range = DateRange( # type: ignore + start_date=datetime( + year=one_year_ago.year, + month=one_year_ago.month, + day=1, + hour=0, + minute=0, + second=0, + tzinfo=consts.EASTERN, + ) ) - elif filter_by == 'specific_year': - q.date_range = DateRange( # type: ignore - start_date=datetime(year=date_data['year'].year, month=1, day=1, - hour=0, minute=0, second=0, tzinfo=EASTERN), - end_date=datetime(year=date_data['year'].year + 1, month=1, day=1, - hour=0, minute=0, second=0, tzinfo=EASTERN), + elif filter_by == "specific_year": + q.date_range = DateRange( # type: ignore + start_date=datetime( + year=date_data["year"].year, + month=1, + day=1, + hour=0, + minute=0, + second=0, + tzinfo=consts.EASTERN, + ), + end_date=datetime( + year=date_data["year"].year + 1, + month=1, + day=1, + hour=0, + minute=0, + second=0, + tzinfo=consts.EASTERN, + ), ) - elif filter_by == 'date_range': - if date_data['from_date']: - date_data['from_date'] = datetime.combine( # type: ignore - date_data['from_date'], + elif filter_by == "date_range": + if date_data["from_date"]: + date_data["from_date"] = datetime.combine( # type: ignore + date_data["from_date"], datetime.min.time(), - tzinfo=EASTERN) - if date_data['to_date']: - date_data['to_date'] = datetime.combine( # type: ignore - date_data['to_date'], + tzinfo=consts.EASTERN, + ) + if date_data["to_date"]: + date_data["to_date"] = datetime.combine( # type: ignore + date_data["to_date"], datetime.min.time(), - tzinfo=EASTERN) + tzinfo=consts.EASTERN, + ) - q.date_range = DateRange( # type: ignore - start_date=date_data['from_date'], - end_date=date_data['to_date'], + q.date_range = DateRange( # type: ignore + start_date=date_data["from_date"], end_date=date_data["to_date"] ) if q.date_range: - 
q.date_range.date_type = date_data['date_type'] + q.date_range.date_type = date_data["date_type"] return q @@ -285,37 +327,37 @@ def group_search(args: MultiDict, groups_or_archives: str) -> Response: Note that this only supports options supported in the advanced search interface. Anything else will result in a 404. """ - logger.debug('Group search for %s', groups_or_archives) + logger.debug("Group search for %s", groups_or_archives) valid_archives = [] - for archive in groups_or_archives.split(','): + for archive in groups_or_archives.split(","): if archive not in taxonomy.ARCHIVES: - logger.debug('archive %s not found in taxonomy', archive) + logger.debug("archive %s not found in taxonomy", archive) continue # Support old archives. if archive in taxonomy.ARCHIVES_SUBSUMED: category = taxonomy.CATEGORIES[taxonomy.ARCHIVES_SUBSUMED[archive]] - archive = category['in_archive'] + archive = category["in_archive"] valid_archives.append(archive) if len(valid_archives) == 0: - logger.debug('No valid archives in request') - raise NotFound('No such archive.') + logger.debug("No valid archives in request") + raise NotFound("No such archive.") - logger.debug('Request for %i valid archives', len(valid_archives)) + logger.debug("Request for %i valid archives", len(valid_archives)) args = args.copy() for archive in valid_archives: fld = dict(forms.ClassificationForm.ARCHIVES).get(archive) - if fld is not None: # Try a top-level archive first. - args[f'classification-{fld}'] = True + if fld is not None: # Try a top-level archive first. + args[f"classification-{fld}"] = True else: # Might be a physics archive; if so, also select the physics # group on the form. 
fld = dict(forms.ClassificationForm.PHYSICS_ARCHIVES).get(archive) if fld is None: - logger.warn(f'Invalid archive shortcut: {fld}') + logger.warn(f"Invalid archive shortcut: {fld}") continue - args['classification-physics'] = True + args["classification-physics"] = True # If there is more than one physics archives, only the last one # will be preserved. - args['classification-physics_archives'] = fld + args["classification-physics_archives"] = fld return search(args) diff --git a/search/controllers/advanced/forms.py b/search/controllers/advanced/forms.py index a5035e7f..3e039991 100644 --- a/search/controllers/advanced/forms.py +++ b/search/controllers/advanced/forms.py @@ -1,65 +1,79 @@ """Provides form rendering and validation for the advanced search feature.""" -import calendar import re +import calendar from datetime import date, datetime from typing import Callable, Optional, List, Any -from wtforms import Form, BooleanField, StringField, SelectField, validators, \ - FormField, SelectMultipleField, DateField, ValidationError, FieldList, \ - RadioField - +from wtforms import ( + Form, + BooleanField, + StringField, + SelectField, + validators, + FormField, + DateField, + ValidationError, + FieldList, + RadioField, +) from wtforms.fields import HiddenField -from wtforms import widgets from arxiv import taxonomy from search.domain import DateRange, AdvancedQuery -from search.controllers.util import does_not_start_with_wildcard, \ - strip_white_space, has_balanced_quotes +from search.controllers.util import ( + does_not_start_with_wildcard, + strip_white_space, + has_balanced_quotes, +) class MultiFormatDateField(DateField): """Extends :class:`.DateField` to support multiple date formats.""" - def __init__(self, label: Optional[str] = None, - validators: Optional[List[Callable]] = None, - formats: List[str] = ['%Y-%m-%d %H:%M:%S'], - default_upper_bound: bool = False, - **kwargs: Any) -> None: + def __init__( + self, + label: Optional[str] = None, + validators: 
Optional[List[Callable]] = None, + formats: Optional[List[str]] = None, + default_upper_bound: bool = False, + **kwargs: Any, + ) -> None: """Override to change ``format: str`` to ``formats: List[str]``.""" super(DateField, self).__init__(label, validators, **kwargs) - self.formats = formats + self.formats = formats or ["%Y-%m-%d %H:%M:%S"] self.default_upper_bound = default_upper_bound def _value(self) -> str: if self.raw_data: - return ' '.join(self.raw_data) + return " ".join(self.raw_data) else: - return self.data and self.data.strftime(self.formats[0]) or '' + return self.data and self.data.strftime(self.formats[0]) or "" def process_formdata(self, valuelist: List[str]) -> None: """Try date formats until one sticks, or raise ValueError.""" if valuelist: - date_str = ' '.join(valuelist) + date_str = " ".join(valuelist) self.data: Optional[date] for fmt in self.formats: try: adj_date = datetime.strptime(date_str, fmt).date() if self.default_upper_bound: - if not re.search(r'%[Bbm]', fmt): + if not re.search(r"%[Bbm]", fmt): # when month does not appear in matching format adj_date = adj_date.replace(month=12, day=31) - elif not re.search('%d', fmt): + elif not re.search("%d", fmt): # when day does not appear in matching format - last_day = calendar.monthrange(adj_date.year, - adj_date.month)[1] + last_day = calendar.monthrange( + adj_date.year, adj_date.month + )[1] adj_date = adj_date.replace(day=last_day) self.data = adj_date return except ValueError: continue self.data = None - raise ValueError(self.gettext('Not a valid date value')) + raise ValueError(self.gettext("Not a valid date value")) class FieldForm(Form): @@ -67,12 +81,16 @@ class FieldForm(Form): # pylint: disable=too-few-public-methods - term = StringField("Search term...", filters=[strip_white_space], - validators=[does_not_start_with_wildcard, - has_balanced_quotes]) - operator = SelectField("Operator", choices=[ - ('AND', 'AND'), ('OR', 'OR'), ('NOT', 'NOT') - ], default='AND') + term = 
StringField( + "Search term...", + filters=[strip_white_space], + validators=[does_not_start_with_wildcard, has_balanced_quotes], + ) + operator = SelectField( + "Operator", + choices=[("AND", "AND"), ("OR", "OR"), ("NOT", "NOT")], + default="AND", + ) field = SelectField("Field", choices=AdvancedQuery.SUPPORTED_FIELDS) @@ -88,37 +106,42 @@ class ClassificationForm(Form): # until we replace the classic-style advanced interface with faceted # search. ARCHIVES = [ - ('cs', 'computer_science'), - ('econ', 'economics'), - ('eess', 'eess'), - ('math', 'mathematics'), - ('physics', 'physics'), - ('q-bio', 'q_biology'), - ('q-fin', 'q_finance'), - ('stat', 'statistics') + ("cs", "computer_science"), + ("econ", "economics"), + ("eess", "eess"), + ("math", "mathematics"), + ("physics", "physics"), + ("q-bio", "q_biology"), + ("q-fin", "q_finance"), + ("stat", "statistics"), ] - PHYSICS_ARCHIVES = [('all', 'all')] + \ - [(archive, archive) for archive, description - in taxonomy.ARCHIVES_ACTIVE.items() - if description['in_group'] == 'grp_physics'] - - INCLUDE_CROSS_LIST = 'include' - EXCLUDE_CROSS_LIST = 'exclude' - - computer_science = BooleanField('Computer Science (cs)') - economics = BooleanField('Economics (econ)') - eess = BooleanField('Electrical Engineering and Systems Science (eess)') - mathematics = BooleanField('Mathematics (math)') - physics = BooleanField('Physics') - physics_archives = SelectField(choices=PHYSICS_ARCHIVES, default='all') - q_biology = BooleanField('Quantitative Biology (q-bio)') - q_finance = BooleanField('Quantitative Finance (q-fin)') - statistics = BooleanField('Statistics (stat)') - - include_cross_list = RadioField('Include cross-list', choices=[ - (INCLUDE_CROSS_LIST, 'Include cross-listed papers'), - (EXCLUDE_CROSS_LIST, 'Exclude cross-listed papers') - ], default=INCLUDE_CROSS_LIST) + PHYSICS_ARCHIVES = [("all", "all")] + [ + (archive, archive) + for archive, description in taxonomy.ARCHIVES_ACTIVE.items() + if description["in_group"] 
== "grp_physics" + ] + + INCLUDE_CROSS_LIST = "include" + EXCLUDE_CROSS_LIST = "exclude" + + computer_science = BooleanField("Computer Science (cs)") + economics = BooleanField("Economics (econ)") + eess = BooleanField("Electrical Engineering and Systems Science (eess)") + mathematics = BooleanField("Mathematics (math)") + physics = BooleanField("Physics") + physics_archives = SelectField(choices=PHYSICS_ARCHIVES, default="all") + q_biology = BooleanField("Quantitative Biology (q-bio)") + q_finance = BooleanField("Quantitative Finance (q-fin)") + statistics = BooleanField("Statistics (stat)") + + include_cross_list = RadioField( + "Include cross-list", + choices=[ + (INCLUDE_CROSS_LIST, "Include cross-listed papers"), + (EXCLUDE_CROSS_LIST, "Exclude cross-listed papers"), + ], + default=INCLUDE_CROSS_LIST, + ) def yearInBounds(form: Form, field: DateField) -> None: @@ -129,65 +152,66 @@ def yearInBounds(form: Form, field: DateField) -> None: start_of_time = date(year=1991, month=1, day=1) upper_limit = date.today().replace(year=date.today().year + 1) if field.data < start_of_time or field.data > upper_limit: - raise ValidationError('Not a valid publication year') + raise ValidationError("Not a valid publication year") class DateForm(Form): """Subform with options for limiting results by publication date.""" filter_by = RadioField( - 'Filter by', choices=[ - ('all_dates', 'All dates'), - ('past_12', 'Past 12 months'), - ('specific_year', 'Specific year'), - ('date_range', 'Date range') + "Filter by", + choices=[ + ("all_dates", "All dates"), + ("past_12", "Past 12 months"), + ("specific_year", "Specific year"), + ("date_range", "Date range"), ], - default='all_dates' + default="all_dates", ) year = DateField( - 'Year', - format='%Y', - validators=[validators.Optional(), yearInBounds] + "Year", format="%Y", validators=[validators.Optional(), yearInBounds] ) from_date = MultiFormatDateField( - 'From', + "From", validators=[validators.Optional(), yearInBounds], - 
formats=['%Y-%m-%d', '%Y-%m', '%Y'] - + formats=["%Y-%m-%d", "%Y-%m", "%Y"], ) to_date = MultiFormatDateField( - 'to', + "to", validators=[validators.Optional(), yearInBounds], - formats=['%Y-%m-%d', '%Y-%m', '%Y'], - default_upper_bound=True + formats=["%Y-%m-%d", "%Y-%m", "%Y"], + default_upper_bound=True, ) SUBMITTED_ORIGINAL = DateRange.SUBMITTED_ORIGINAL SUBMITTED_CURRENT = DateRange.SUBMITTED_CURRENT ANNOUNCED = DateRange.ANNOUNCED DATE_TYPE_CHOICES = [ - (SUBMITTED_CURRENT, 'Submission date (most recent)'), - (SUBMITTED_ORIGINAL, 'Submission date (original)'), - (ANNOUNCED, 'Announcement date'), + (SUBMITTED_CURRENT, "Submission date (most recent)"), + (SUBMITTED_ORIGINAL, "Submission date (original)"), + (ANNOUNCED, "Announcement date"), ] - date_type = RadioField('Apply to', choices=DATE_TYPE_CHOICES, - default=SUBMITTED_CURRENT, - description="You may filter on either submission" - " date or announcement date. Note that announcement" - " date supports only year and month granularity.") + date_type = RadioField( + "Apply to", + choices=DATE_TYPE_CHOICES, + default=SUBMITTED_CURRENT, + description="You may filter on either submission" + " date or announcement date. 
Note that announcement" + " date supports only year and month granularity.", + ) def validate_filter_by(self, field: RadioField) -> None: """Ensure that related fields are filled.""" - if field.data == 'specific_year' and not self.data.get('year'): - raise ValidationError('Please select a year') - elif field.data == 'date_range': - if not self.data.get('from_date') and not self.data.get('to_date'): - raise ValidationError('Must select start and/or end date(s)') - if self.data.get('from_date') and self.data.get('to_date'): - if self.data.get('from_date') >= self.data.get('to_date'): + if field.data == "specific_year" and not self.data.get("year"): + raise ValidationError("Please select a year") + elif field.data == "date_range": + if not self.data.get("from_date") and not self.data.get("to_date"): + raise ValidationError("Must select start and/or end date(s)") + if self.data.get("from_date") and self.data.get("to_date"): + if self.data.get("from_date") >= self.data.get("to_date"): raise ValidationError( - 'End date must be later than start date' + "End date must be later than start date" ) @@ -196,31 +220,39 @@ class AdvancedSearchForm(Form): # pylint: disable=too-few-public-methods - advanced = HiddenField('Advanced', default=1) + advanced = HiddenField("Advanced", default=1) """Used to indicate whether the form should be shown.""" terms = FieldList(FormField(FieldForm), min_entries=1) classification = FormField(ClassificationForm) date = FormField(DateForm) - size = SelectField('results per page', default=50, choices=[ - ('25', '25'), - ('50', '50'), - ('100', '100'), - ('200', '200') - ]) - order = SelectField('Sort results by', choices=[ - ('-announced_date_first', 'Announcement date (newest first)'), - ('announced_date_first', 'Announcement date (oldest first)'), - ('-submitted_date', 'Submission date (newest first)'), - ('submitted_date', 'Submission date (oldest first)'), - ('', 'Relevance') - ], validators=[validators.Optional()], 
default='-announced_date_first') - include_older_versions = BooleanField('Include older versions of papers') - - HIDE_ABSTRACTS = 'hide' - SHOW_ABSTRACTS = 'show' - - abstracts = RadioField('Abstracts', choices=[ - (SHOW_ABSTRACTS, 'Show abstracts'), - (HIDE_ABSTRACTS, 'Hide abstracts') - ], default=SHOW_ABSTRACTS) + size = SelectField( + "results per page", + default=50, + choices=[("25", "25"), ("50", "50"), ("100", "100"), ("200", "200")], + ) + order = SelectField( + "Sort results by", + choices=[ + ("-announced_date_first", "Announcement date (newest first)"), + ("announced_date_first", "Announcement date (oldest first)"), + ("-submitted_date", "Submission date (newest first)"), + ("submitted_date", "Submission date (oldest first)"), + ("", "Relevance"), + ], + validators=[validators.Optional()], + default="-announced_date_first", + ) + include_older_versions = BooleanField("Include older versions of papers") + + HIDE_ABSTRACTS = "hide" + SHOW_ABSTRACTS = "show" + + abstracts = RadioField( + "Abstracts", + choices=[ + (SHOW_ABSTRACTS, "Show abstracts"), + (HIDE_ABSTRACTS, "Hide abstracts"), + ], + default=SHOW_ABSTRACTS, + ) diff --git a/search/controllers/advanced/tests.py b/search/controllers/advanced/tests.py index 0bbb7e90..2e6618c0 100644 --- a/search/controllers/advanced/tests.py +++ b/search/controllers/advanced/tests.py @@ -1,15 +1,15 @@ """Tests for advanced search controller, :mod:`search.controllers.advanced`.""" +from http import HTTPStatus from unittest import TestCase, mock from datetime import date, datetime from dateutil.relativedelta import relativedelta -from werkzeug import MultiDict + +from werkzeug.datastructures import MultiDict from werkzeug.exceptions import InternalServerError, BadRequest -from arxiv import status -from search.domain import Query, DateRange, FieldedSearchTerm, Classification,\ - AdvancedQuery, DocumentSet +from search.domain import Query, DateRange, FieldedSearchTerm, AdvancedQuery from search.controllers import 
advanced from search.controllers.advanced.forms import MultiFormatDateField from search.controllers.advanced.forms import AdvancedSearchForm @@ -22,48 +22,47 @@ class TestMultiFormatDateField(TestCase): def test_value_with_one_format(self): """One date format is specified.""" - fmt = '%Y-%m-%d %H:%M:%S' + fmt = "%Y-%m-%d %H:%M:%S" value = datetime.now() field = MultiFormatDateField( - formats=[fmt], - _form=mock.MagicMock(), - _name='test' + formats=[fmt], _form=mock.MagicMock(), _name="test" ) field.data = value - self.assertEqual(field._value(), value.strftime(fmt), - "Should use the first (only) format to render value") + self.assertEqual( + field._value(), + value.strftime(fmt), + "Should use the first (only) format to render value", + ) def test_process_with_one_format(self): """One date format is specified.""" - fmt = '%Y-%m-%d %H:%M:%S' + fmt = "%Y-%m-%d %H:%M:%S" field = MultiFormatDateField( - formats=[fmt], - _form=mock.MagicMock(), - _name='test' + formats=[fmt], _form=mock.MagicMock(), _name="test" ) - field.process_formdata(['2012-01-02 05:55:02']) + field.process_formdata(["2012-01-02 05:55:02"]) self.assertIsInstance(field.data, date, "Should parse successfully") def test_process_with_several_formats(self): """Several date formats are specified.""" field = MultiFormatDateField( - formats=['%Y-%m-%d', '%Y-%m', '%Y'], + formats=["%Y-%m-%d", "%Y-%m", "%Y"], _form=mock.MagicMock(), - _name='test' + _name="test", ) - field.process_formdata(['2012-03-02']) + field.process_formdata(["2012-03-02"]) self.assertIsInstance(field.data, date, "Should parse successfully") self.assertEqual(field.data.day, 2) self.assertEqual(field.data.month, 3) self.assertEqual(field.data.year, 2012) - field.process_formdata(['2014-05']) + field.process_formdata(["2014-05"]) self.assertIsInstance(field.data, date, "Should parse successfully") self.assertEqual(field.data.day, 1) self.assertEqual(field.data.month, 5) self.assertEqual(field.data.year, 2014) - 
field.process_formdata(['2011']) + field.process_formdata(["2011"]) self.assertIsInstance(field.data, date, "Should parse successfully") self.assertEqual(field.data.day, 1) self.assertEqual(field.data.month, 1) @@ -73,99 +72,120 @@ def test_process_with_several_formats(self): class TestSearchController(TestCase): """Tests for :func:`.advanced.search`.""" - @mock.patch('search.controllers.advanced.SearchSession') + @mock.patch("search.controllers.advanced.SearchSession") def test_no_form_data(self, mock_index): """No form data has been submitted.""" request_data = MultiDict() response_data, code, headers = advanced.search(request_data) - self.assertEqual(code, status.HTTP_200_OK, "Response should be OK.") + self.assertEqual(code, HTTPStatus.OK, "Response should be OK.") - self.assertIn('form', response_data, "Response should include form.") + self.assertIn("form", response_data, "Response should include form.") - self.assertEqual(mock_index.search.call_count, 0, - "No search should be attempted") + self.assertEqual( + mock_index.search.call_count, 0, "No search should be attempted" + ) - @mock.patch('search.controllers.advanced.SearchSession') + @mock.patch("search.controllers.advanced.SearchSession") def test_single_field_term(self, mock_index): """Form data and ``advanced`` param are present.""" - mock_index.search.return_value = dict(metadata={}, results=[]) - - request_data = MultiDict({ - 'advanced': True, - 'terms-0-operator': 'AND', - 'terms-0-field': 'title', - 'terms-0-term': 'foo' - }) + mock_index.search.return_value = {"metadata": {}, "results": []} + + request_data = MultiDict( + { + "advanced": True, + "terms-0-operator": "AND", + "terms-0-field": "title", + "terms-0-term": "foo", + } + ) response_data, code, headers = advanced.search(request_data) - self.assertEqual(mock_index.search.call_count, 1, - "A search should be attempted") + self.assertEqual( + mock_index.search.call_count, 1, "A search should be attempted" + ) call_args, call_kwargs = 
mock_index.search.call_args - self.assertIsInstance(call_args[0], AdvancedQuery, - "An AdvancedQuery is passed to the search index") - self.assertEqual(code, status.HTTP_200_OK, "Response should be OK.") + self.assertIsInstance( + call_args[0], + AdvancedQuery, + "An AdvancedQuery is passed to the search index", + ) + self.assertEqual(code, HTTPStatus.OK, "Response should be OK.") - @mock.patch('search.controllers.advanced.SearchSession') + @mock.patch("search.controllers.advanced.SearchSession") def test_invalid_data(self, mock_index): """Form data are invalid.""" - request_data = MultiDict({ - 'advanced': True, - 'date-past_12': True, - 'date-specific_year': True, - 'date-year': '2012' - }) + request_data = MultiDict( + { + "advanced": True, + "date-past_12": True, + "date-specific_year": True, + "date-year": "2012", + } + ) response_data, code, headers = advanced.search(request_data) - self.assertEqual(code, status.HTTP_200_OK, "Response should be OK.") + self.assertEqual(code, HTTPStatus.OK, "Response should be OK.") - self.assertIn('form', response_data, "Response should include form.") + self.assertIn("form", response_data, "Response should include form.") - self.assertEqual(mock_index.search.call_count, 0, - "No search should be attempted") + self.assertEqual( + mock_index.search.call_count, 0, "No search should be attempted" + ) - @mock.patch('search.controllers.advanced.SearchSession') + @mock.patch("search.controllers.advanced.SearchSession") def test_index_raises_connection_exception(self, mock_index): """Index service raises a IndexConnectionError.""" + def _raiseIndexConnectionError(*args, **kwargs): - raise IndexConnectionError('What now') + raise IndexConnectionError("What now") mock_index.search.side_effect = _raiseIndexConnectionError - request_data = MultiDict({ - 'advanced': True, - 'terms-0-operator': 'AND', - 'terms-0-field': 'title', - 'terms-0-term': 'foo' - }) + request_data = MultiDict( + { + "advanced": True, + "terms-0-operator": "AND", + 
"terms-0-field": "title", + "terms-0-term": "foo", + } + ) with self.assertRaises(InternalServerError): response_data, code, headers = advanced.search(request_data) - self.assertEqual(mock_index.search.call_count, 1, - "A search should be attempted") + self.assertEqual( + mock_index.search.call_count, 1, "A search should be attempted" + ) call_args, call_kwargs = mock_index.search.call_args - self.assertIsInstance(call_args[0], AdvancedQuery, - "An AdvancedQuery is passed to the search index") + self.assertIsInstance( + call_args[0], + AdvancedQuery, + "An AdvancedQuery is passed to the search index", + ) - @mock.patch('search.controllers.advanced.SearchSession') + @mock.patch("search.controllers.advanced.SearchSession") def test_index_raises_query_error(self, mock_index): """Index service raises a QueryError.""" + def _raiseQueryError(*args, **kwargs): - raise QueryError('What now') + raise QueryError("What now") mock_index.search.side_effect = _raiseQueryError - request_data = MultiDict({ - 'advanced': True, - 'terms-0-operator': 'AND', - 'terms-0-field': 'title', - 'terms-0-term': 'foo' - }) + request_data = MultiDict( + { + "advanced": True, + "terms-0-operator": "AND", + "terms-0-field": "title", + "terms-0-term": "foo", + } + ) with self.assertRaises(InternalServerError): try: response_data, code, headers = advanced.search(request_data) - except QueryError as e: - self.fail("QueryError should be handled (caught %s)" % e) + except QueryError as ex: + self.fail("QueryError should be handled (caught %s)" % ex) - self.assertEqual(mock_index.search.call_count, 1, - "A search should be attempted") + self.assertEqual( + mock_index.search.call_count, 1, "A search should be attempted" + ) class TestAdvancedSearchForm(TestCase): @@ -173,202 +193,237 @@ class TestAdvancedSearchForm(TestCase): def test_single_field_term(self): """User has entered a single term for a field-based search.""" - data = MultiDict({ - 'terms-0-operator': 'AND', - 'terms-0-field': 'title', - 
'terms-0-term': 'foo' - }) + data = MultiDict( + { + "terms-0-operator": "AND", + "terms-0-field": "title", + "terms-0-term": "foo", + } + ) form = AdvancedSearchForm(data) self.assertTrue(form.validate()) def test_term_starts_with_wildcard(self): """User has entered a string that starts with a wildcard.""" - data = MultiDict({ - 'terms-0-operator': 'AND', - 'terms-0-field': 'title', - 'terms-0-term': '*foo' - }) + data = MultiDict( + { + "terms-0-operator": "AND", + "terms-0-field": "title", + "terms-0-term": "*foo", + } + ) form = AdvancedSearchForm(data) self.assertFalse(form.validate(), "Form should be invalid") def test_specific_year_must_be_specified(self): """If the user selects specific year, they must indicate a year.""" - data = MultiDict({ - 'terms-0-operator': 'AND', - 'terms-0-field': 'title', - 'terms-0-term': 'foo', - 'date-filter_by': 'specific_year', - }) + data = MultiDict( + { + "terms-0-operator": "AND", + "terms-0-field": "title", + "terms-0-term": "foo", + "date-filter_by": "specific_year", + } + ) form = AdvancedSearchForm(data) self.assertFalse(form.validate()) self.assertEqual(len(form.errors), 1) - data = MultiDict({ - 'terms-0-operator': 'AND', - 'terms-0-field': 'title', - 'terms-0-term': 'foo', - 'date-filter_by': 'specific_year', - 'date-year': '2012' - }) + data = MultiDict( + { + "terms-0-operator": "AND", + "terms-0-field": "title", + "terms-0-term": "foo", + "date-filter_by": "specific_year", + "date-year": "2012", + } + ) form = AdvancedSearchForm(data) self.assertTrue(form.validate()) # ARXIVNG-382 def test_date_range_supports_variable_precision(self): """Date range in advanced search should support variable precision.""" - data = MultiDict({ - 'terms-0-operator': 'AND', - 'terms-0-field': 'title', - 'terms-0-term': 'foo', - 'date-filter_by': 'date_range', - 'date-to_date': '2012-02-05' - }) + data = MultiDict( + { + "terms-0-operator": "AND", + "terms-0-field": "title", + "terms-0-term": "foo", + "date-filter_by": "date_range", 
+ "date-to_date": "2012-02-05", + } + ) form = AdvancedSearchForm(data) self.assertTrue(form.validate()) - data = MultiDict({ - 'terms-0-operator': 'AND', - 'terms-0-field': 'title', - 'terms-0-term': 'foo', - 'date-filter_by': 'date_range', - 'date-to_date': '2012-02' - }) + data = MultiDict( + { + "terms-0-operator": "AND", + "terms-0-field": "title", + "terms-0-term": "foo", + "date-filter_by": "date_range", + "date-to_date": "2012-02", + } + ) form = AdvancedSearchForm(data) self.assertTrue(form.validate()) - data = MultiDict({ - 'terms-0-operator': 'AND', - 'terms-0-field': 'title', - 'terms-0-term': 'foo', - 'date-filter_by': 'date_range', - 'date-to_date': '2013', - 'date-from_date': '2012-03' - }) + data = MultiDict( + { + "terms-0-operator": "AND", + "terms-0-field": "title", + "terms-0-term": "foo", + "date-filter_by": "date_range", + "date-to_date": "2013", + "date-from_date": "2012-03", + } + ) form = AdvancedSearchForm(data) self.assertTrue(form.validate()) def test_date_range_must_be_specified(self): """If the user selects date range, they must indicate start or end.""" - data = MultiDict({ - 'terms-0-operator': 'AND', - 'terms-0-field': 'title', - 'terms-0-term': 'foo', - 'date-filter_by': 'date_range', - }) + data = MultiDict( + { + "terms-0-operator": "AND", + "terms-0-field": "title", + "terms-0-term": "foo", + "date-filter_by": "date_range", + } + ) form = AdvancedSearchForm(data) self.assertFalse(form.validate()) self.assertEqual(len(form.errors), 1) - data = MultiDict({ - 'terms-0-operator': 'AND', - 'terms-0-field': 'title', - 'terms-0-term': 'foo', - 'date-filter_by': 'date_range', - 'date-from_date': '2012-02-05' - }) + data = MultiDict( + { + "terms-0-operator": "AND", + "terms-0-field": "title", + "terms-0-term": "foo", + "date-filter_by": "date_range", + "date-from_date": "2012-02-05", + } + ) form = AdvancedSearchForm(data) self.assertTrue(form.validate()) - data = MultiDict({ - 'terms-0-operator': 'AND', - 'terms-0-field': 'title', - 
'terms-0-term': 'foo', - 'date-filter_by': 'date_range', - 'date-to_date': '2012-02-05' - }) + data = MultiDict( + { + "terms-0-operator": "AND", + "terms-0-field": "title", + "terms-0-term": "foo", + "date-filter_by": "date_range", + "date-to_date": "2012-02-05", + } + ) form = AdvancedSearchForm(data) self.assertTrue(form.validate()) # ARXIVNG-997 def test_end_date_bounding(self): """If a user selects an end date, it must be bounded correctly.""" - data = MultiDict({ - 'terms-0-operator': 'AND', - 'terms-0-field': 'title', - 'terms-0-term': 'foo', - 'date-filter_by': 'date_range', - 'date-to_date': '2012' - }) + data = MultiDict( + { + "terms-0-operator": "AND", + "terms-0-field": "title", + "terms-0-term": "foo", + "date-filter_by": "date_range", + "date-to_date": "2012", + } + ) form = AdvancedSearchForm(data) self.assertTrue(form.validate()) - self.assertEqual(form.date.to_date.data, - date(year=2012, month=12, day=31)) + self.assertEqual( + form.date.to_date.data, date(year=2012, month=12, day=31) + ) - data['date-to_date'] = '2012-02' + data["date-to_date"] = "2012-02" form = AdvancedSearchForm(data) self.assertTrue(form.validate()) - self.assertEqual(form.date.to_date.data, - date(year=2012, month=2, day=29)) + self.assertEqual( + form.date.to_date.data, date(year=2012, month=2, day=29) + ) - data['date-to_date'] = '2016-06' + data["date-to_date"] = "2016-06" form = AdvancedSearchForm(data) self.assertTrue(form.validate()) - self.assertEqual(form.date.to_date.data, - date(year=2016, month=6, day=30)) + self.assertEqual( + form.date.to_date.data, date(year=2016, month=6, day=30) + ) - data['date-to_date'] = '2016-06-30' + data["date-to_date"] = "2016-06-30" form = AdvancedSearchForm(data) self.assertTrue(form.validate()) - self.assertEqual(form.date.to_date.data, - date(year=2016, month=6, day=30)) + self.assertEqual( + form.date.to_date.data, date(year=2016, month=6, day=30) + ) - data['date-to_date'] = '2100-02' + data["date-to_date"] = "2100-02" form = 
AdvancedSearchForm(data) self.assertFalse(form.validate()) def test_year_must_be_after_1990(self): """If the user selects a specific year, it must be after 1990.""" - data = MultiDict({ - 'terms-0-operator': 'AND', - 'terms-0-field': 'title', - 'terms-0-term': 'foo', - 'date-filter_by': 'specific_year', - 'date-year': '1990' - }) + data = MultiDict( + { + "terms-0-operator": "AND", + "terms-0-field": "title", + "terms-0-term": "foo", + "date-filter_by": "specific_year", + "date-year": "1990", + } + ) form = AdvancedSearchForm(data) self.assertFalse(form.validate()) - data = MultiDict({ - 'terms-0-operator': 'AND', - 'terms-0-field': 'title', - 'terms-0-term': 'foo', - 'date-filter_by': 'specific_year', - 'date-year': '1991' - }) + data = MultiDict( + { + "terms-0-operator": "AND", + "terms-0-field": "title", + "terms-0-term": "foo", + "date-filter_by": "specific_year", + "date-year": "1991", + } + ) form = AdvancedSearchForm(data) self.assertTrue(form.validate()) def test_input_whitespace_is_stripped(self): """If query has padding whitespace, it should be removed.""" - data = MultiDict({ - 'terms-0-operator': 'AND', - 'terms-0-field': 'title', - 'terms-0-term': ' foo ' - }) + data = MultiDict( + { + "terms-0-operator": "AND", + "terms-0-field": "title", + "terms-0-term": " foo ", + } + ) form = AdvancedSearchForm(data) self.assertTrue(form.validate(), "Form should be valid.") - self.assertEqual(form.terms[0].term.data, 'foo', - "Whitespace should be stripped.") + self.assertEqual( + form.terms[0].term.data, "foo", "Whitespace should be stripped." 
+ ) def test_querystring_has_unbalanced_quotes(self): """Querystring has an odd number of quote characters.""" - data = MultiDict({ - 'terms-0-operator': 'AND', - 'terms-0-field': 'title', - 'terms-0-term': '"rhubarb' - }) + data = MultiDict( + { + "terms-0-operator": "AND", + "terms-0-field": "title", + "terms-0-term": '"rhubarb', + } + ) form = AdvancedSearchForm(data) self.assertFalse(form.validate(), "Form should be invalid") - data['terms-0-term'] = '"rhubarb"' + data["terms-0-term"] = '"rhubarb"' form = AdvancedSearchForm(data) self.assertTrue(form.validate(), "Form should be valid") - data['terms-0-term'] = '"rhubarb" "pie' + data["terms-0-term"] = '"rhubarb" "pie' form = AdvancedSearchForm(data) self.assertFalse(form.validate(), "Form should be invalid") - data['terms-0-term'] = '"rhubarb" "pie"' + data["terms-0-term"] = '"rhubarb" "pie"' form = AdvancedSearchForm(data) self.assertTrue(form.validate(), "Form should be valid") @@ -378,48 +433,53 @@ class TestUpdatequeryWithClassification(TestCase): def test_classification_is_selected(self): """Selected classifications are added to the query.""" - class_data = {'computer_science': True} + class_data = {"computer_science": True} q = advanced._update_query_with_classification(Query(), class_data) - self.assertEqual(q.classification, [{'archive': {'id': 'cs'}}]) + self.assertEqual(q.classification, [{"archive": {"id": "cs"}}]) def test_multiple_classifications_are_selected(self): """Selected classifications are added to the query.""" - class_data = {'computer_science': True, 'eess': True} + class_data = {"computer_science": True, "eess": True} q = advanced._update_query_with_classification(Query(), class_data) - self.assertEqual(q.classification, - [{'archive': {'id': 'cs'}}, - {'archive': {'id': 'eess'}}]) + self.assertEqual( + q.classification, + [{"archive": {"id": "cs"}}, {"archive": {"id": "eess"}}], + ) def test_physics_is_selected_all_archives(self): """The physics group is added to the query.""" - 
class_data = {'physics': True, 'physics_archives': 'all'} + class_data = {"physics": True, "physics_archives": "all"} q = advanced._update_query_with_classification(Query(), class_data) - self.assertEqual(q.classification, [{'group': {'id': 'grp_physics'}}]) + self.assertEqual(q.classification, [{"group": {"id": "grp_physics"}}]) def test_physics_is_selected_specific_archive(self): """The physic group and specified archive are added to the query.""" - class_data = {'physics': True, 'physics_archives': 'hep-ex'} + class_data = {"physics": True, "physics_archives": "hep-ex"} q = advanced._update_query_with_classification(Query(), class_data) - self.assertEqual(q.classification, - [{'archive': {'id': 'hep-ex'}, - 'group': {'id': 'grp_physics'}}]) + self.assertEqual( + q.classification, + [{"archive": {"id": "hep-ex"}, "group": {"id": "grp_physics"}}], + ) def test_physics_is_selected_specific_archive_plus_other_groups(self): """The physics group and specified archive are added to the query.""" class_data = { - 'mathematics': True, - 'physics': True, - 'physics_archives': 'hep-ex' + "mathematics": True, + "physics": True, + "physics_archives": "hep-ex", } q = advanced._update_query_with_classification(Query(), class_data) self.assertIsInstance(q, Query) self.assertIsInstance(q.classification, list) self.assertEqual(len(q.classification), 2) - self.assertEqual(q.classification, - [{'archive': {'id': 'math'}}, - {'group': {'id': 'grp_physics'}, - 'archive': {'id': 'hep-ex'}}]) + self.assertEqual( + q.classification, + [ + {"archive": {"id": "math"}}, + {"group": {"id": "grp_physics"}, "archive": {"id": "hep-ex"}}, + ], + ) class TestUpdateQueryWithFieldedTerms(TestCase): @@ -427,45 +487,45 @@ class TestUpdateQueryWithFieldedTerms(TestCase): def test_terms_are_provided(self): """Selected terms are added to the query.""" - terms_data = [{'term': 'muon', 'operator': 'AND', 'field': 'title'}] + terms_data = [{"term": "muon", "operator": "AND", "field": "title"}] q = 
advanced._update_query_with_terms(Query(), terms_data) self.assertIsInstance(q, Query) self.assertIsInstance(q.terms, list) self.assertEqual(len(q.terms), 1) self.assertIsInstance(q.terms[0], FieldedSearchTerm) - self.assertEqual(q.terms[0].term, 'muon') - self.assertEqual(q.terms[0].operator, 'AND') - self.assertEqual(q.terms[0].field, 'title') + self.assertEqual(q.terms[0].term, "muon") + self.assertEqual(q.terms[0].operator, "AND") + self.assertEqual(q.terms[0].field, "title") def test_multiple_terms_are_provided(self): """Selected terms are added to the query.""" terms_data = [ - {'term': 'muon', 'operator': 'AND', 'field': 'title'}, - {'term': 'boson', 'operator': 'OR', 'field': 'title'} + {"term": "muon", "operator": "AND", "field": "title"}, + {"term": "boson", "operator": "OR", "field": "title"}, ] q = advanced._update_query_with_terms(Query(), terms_data) self.assertIsInstance(q, Query) self.assertIsInstance(q.terms, list) self.assertEqual(len(q.terms), 2) self.assertIsInstance(q.terms[0], FieldedSearchTerm) - self.assertEqual(q.terms[1].term, 'boson') - self.assertEqual(q.terms[1].operator, 'OR') - self.assertEqual(q.terms[1].field, 'title') + self.assertEqual(q.terms[1].term, "boson") + self.assertEqual(q.terms[1].operator, "OR") + self.assertEqual(q.terms[1].field, "title") def test_multiple_terms_are_provided_with_all_field(self): """Selected terms are added to the query.""" terms_data = [ - {'term': 'switch', 'operator': 'AND', 'field': 'all'}, - {'term': 'disk', 'operator': 'OR', 'field': 'all'} + {"term": "switch", "operator": "AND", "field": "all"}, + {"term": "disk", "operator": "OR", "field": "all"}, ] q = advanced._update_query_with_terms(Query(), terms_data) self.assertIsInstance(q, Query) self.assertIsInstance(q.terms, list) self.assertEqual(len(q.terms), 2) self.assertIsInstance(q.terms[0], FieldedSearchTerm) - self.assertEqual(q.terms[1].term, 'disk') - self.assertEqual(q.terms[1].operator, 'OR') - self.assertEqual(q.terms[1].field, 'all') + 
self.assertEqual(q.terms[1].term, "disk") + self.assertEqual(q.terms[1].operator, "OR") + self.assertEqual(q.terms[1].field, "all") class TestUpdateQueryWithDates(TestCase): @@ -473,7 +533,7 @@ class TestUpdateQueryWithDates(TestCase): def test_past_12_is_selected(self): """Query selects the past twelve months.""" - date_data = {'filter_by': 'past_12', 'date_type': 'submitted_date'} + date_data = {"filter_by": "past_12", "date_type": "submitted_date"} q = advanced._update_query_with_dates(Query(), date_data) self.assertIsInstance(q, Query) self.assertIsInstance(q.date_range, DateRange) @@ -481,12 +541,12 @@ def test_past_12_is_selected(self): self.assertEqual( q.date_range.start_date.date(), date.today() - twelve_months, - "Start date is the first day of the month twelve prior to today." + "Start date is the first day of the month twelve prior to today.", ) def test_all_dates_is_selected(self): """Query does not select on date.""" - date_data = {'filter_by': 'all_dates', 'date_type': 'submitted_date'} + date_data = {"filter_by": "all_dates", "date_type": "submitted_date"} q = advanced._update_query_with_dates(AdvancedQuery(), date_data) self.assertIsInstance(q, AdvancedQuery) self.assertIsNone(q.date_range) @@ -494,26 +554,28 @@ def test_all_dates_is_selected(self): def test_specific_year_is_selected(self): """Start and end dates are set, one year apart.""" date_data = { - 'filter_by': 'specific_year', - 'year': date(year=1999, month=1, day=1), - 'date_type': 'submitted_date' + "filter_by": "specific_year", + "year": date(year=1999, month=1, day=1), + "date_type": "submitted_date", } q = advanced._update_query_with_dates(AdvancedQuery(), date_data) self.assertIsInstance(q, AdvancedQuery) - self.assertEqual(q.date_range.end_date.date(), - date(year=2000, month=1, day=1)) - self.assertEqual(q.date_range.start_date.date(), - date(year=1999, month=1, day=1)) + self.assertEqual( + q.date_range.end_date.date(), date(year=2000, month=1, day=1) + ) + self.assertEqual( + 
q.date_range.start_date.date(), date(year=1999, month=1, day=1) + ) def test_date_range_is_selected(self): """Start and end dates are set based on selection.""" from_date = date(year=1999, month=7, day=3) to_date = date(year=1999, month=8, day=5) date_data = { - 'filter_by': 'date_range', - 'from_date': from_date, - 'to_date': to_date, - 'date_type': 'submitted_date' + "filter_by": "date_range", + "from_date": from_date, + "to_date": to_date, + "date_type": "submitted_date", } q = advanced._update_query_with_dates(AdvancedQuery(), date_data) self.assertIsInstance(q, AdvancedQuery) @@ -532,31 +594,35 @@ class TestPaginationParametersAreFunky(TestCase): search form. """ - @mock.patch('search.controllers.advanced.url_for') + @mock.patch("search.controllers.advanced.url_for") def test_order_is_invalid(self, mock_url_for): """The order parameter on the request is invalid.""" - request_data = MultiDict({ - 'advanced': True, - 'terms-0-operator': 'AND', - 'terms-0-field': 'title', - 'terms-0-term': 'foo', - 'size': 50, # Valid. - 'order': 'foo' # Invalid - }) + request_data = MultiDict( + { + "advanced": True, + "terms-0-operator": "AND", + "terms-0-field": "title", + "terms-0-term": "foo", + "size": 50, # Valid. 
+ "order": "foo", # Invalid + } + ) with self.assertRaises(BadRequest): advanced.search(request_data) - @mock.patch('search.controllers.advanced.url_for') + @mock.patch("search.controllers.advanced.url_for") def test_size_is_invalid(self, mock_url_for): """The order parameter on the request is invalid.""" - request_data = MultiDict({ - 'advanced': True, - 'terms-0-operator': 'AND', - 'terms-0-field': 'title', - 'terms-0-term': 'foo', - 'size': 51, # Invalid - 'order': '' # Valid - }) + request_data = MultiDict( + { + "advanced": True, + "terms-0-operator": "AND", + "terms-0-field": "title", + "terms-0-term": "foo", + "size": 51, # Invalid + "order": "", # Valid + } + ) with self.assertRaises(BadRequest): advanced.search(request_data) @@ -571,85 +637,111 @@ class TestClassicAuthorSyntaxIsIntercepted(TestCase): about the syntax change. """ - @mock.patch('search.controllers.advanced.SearchSession') + @mock.patch("search.controllers.advanced.SearchSession") def test_all_fields_search_contains_classic_syntax(self, mock_index): """User has entered a `surname_f` query in an all-fields term.""" - request_data = MultiDict({ - 'advanced': True, - 'terms-0-operator': 'AND', - 'terms-0-field': 'all', - 'terms-0-term': 'franklin_r', - 'size': 50, - 'order': '' - }) - mock_index.search.return_value = dict(metadata={}, results=[]) - + request_data = MultiDict( + { + "advanced": True, + "terms-0-operator": "AND", + "terms-0-field": "all", + "terms-0-term": "franklin_r", + "size": 50, + "order": "", + } + ) + mock_index.search.return_value = {"metadata": {}, "results": []} data, code, headers = advanced.search(request_data) - self.assertEqual(data['query'].terms[0].term, "franklin, r", - "The query should be rewritten.") - self.assertTrue(data['has_classic_format'], - "A flag denoting the syntax interception should be set" - " in the response context, so that a message may be" - " rendered in the template.") - - @mock.patch('search.controllers.advanced.SearchSession') + 
self.assertEqual( + data["query"].terms[0].term, + "franklin, r", + "The query should be rewritten.", + ) + self.assertTrue( + data["has_classic_format"], + "A flag denoting the syntax interception should be set" + " in the response context, so that a message may be" + " rendered in the template.", + ) + + @mock.patch("search.controllers.advanced.SearchSession") def test_author_search_contains_classic_syntax(self, mock_index): """User has entered a `surname_f` query in an author search.""" - request_data = MultiDict({ - 'advanced': True, - 'terms-0-operator': 'AND', - 'terms-0-field': 'author', - 'terms-0-term': 'franklin_r', - 'size': 50, - 'order': '' - }) - mock_index.search.return_value = dict(metadata={}, results=[]) + request_data = MultiDict( + { + "advanced": True, + "terms-0-operator": "AND", + "terms-0-field": "author", + "terms-0-term": "franklin_r", + "size": 50, + "order": "", + } + ) + mock_index.search.return_value = {"metadata": {}, "results": []} data, code, headers = advanced.search(request_data) - self.assertEqual(data['query'].terms[0].term, "franklin, r", - "The query should be rewritten.") - self.assertTrue(data['has_classic_format'], - "A flag denoting the syntax interception should be set" - " in the response context, so that a message may be" - " rendered in the template.") - - @mock.patch('search.controllers.advanced.SearchSession') + self.assertEqual( + data["query"].terms[0].term, + "franklin, r", + "The query should be rewritten.", + ) + self.assertTrue( + data["has_classic_format"], + "A flag denoting the syntax interception should be set" + " in the response context, so that a message may be" + " rendered in the template.", + ) + + @mock.patch("search.controllers.advanced.SearchSession") def test_all_fields_search_multiple_classic_syntax(self, mock_index): """User has entered a classic query with multiple authors.""" - request_data = MultiDict({ - 'advanced': True, - 'terms-0-operator': 'AND', - 'terms-0-field': 'all', - 
'terms-0-term': 'j franklin_r hawking_s', - 'size': 50, - 'order': '' - }) - mock_index.search.return_value = dict(metadata={}, results=[]) + request_data = MultiDict( + { + "advanced": True, + "terms-0-operator": "AND", + "terms-0-field": "all", + "terms-0-term": "j franklin_r hawking_s", + "size": 50, + "order": "", + } + ) + mock_index.search.return_value = {"metadata": {}, "results": []} data, code, headers = advanced.search(request_data) - self.assertEqual(data['query'].terms[0].term, - "j franklin, r; hawking, s", - "The query should be rewritten.") - self.assertTrue(data['has_classic_format'], - "A flag denoting the syntax interception should be set" - " in the response context, so that a message may be" - " rendered in the template.") - - @mock.patch('search.controllers.advanced.SearchSession') + self.assertEqual( + data["query"].terms[0].term, + "j franklin, r; hawking, s", + "The query should be rewritten.", + ) + self.assertTrue( + data["has_classic_format"], + "A flag denoting the syntax interception should be set" + " in the response context, so that a message may be" + " rendered in the template.", + ) + + @mock.patch("search.controllers.advanced.SearchSession") def test_title_search_contains_classic_syntax(self, mock_index): """User has entered a `surname_f` query in a title search.""" - request_data = MultiDict({ - 'advanced': True, - 'terms-0-operator': 'AND', - 'terms-0-field': 'title', - 'terms-0-term': 'franklin_r', - 'size': 50, - 'order': '' - }) - mock_index.search.return_value = dict(metadata={}, results=[]) + request_data = MultiDict( + { + "advanced": True, + "terms-0-operator": "AND", + "terms-0-field": "title", + "terms-0-term": "franklin_r", + "size": 50, + "order": "", + } + ) + mock_index.search.return_value = {"metadata": {}, "results": []} data, code, headers = advanced.search(request_data) - self.assertEqual(data['query'].terms[0].term, "franklin_r", - "The query should not be rewritten.") - 
self.assertFalse(data['has_classic_format'], - "Flag should not be set, as no rewrite has occurred.") + self.assertEqual( + data["query"].terms[0].term, + "franklin_r", + "The query should not be rewritten.", + ) + self.assertFalse( + data["has_classic_format"], + "Flag should not be set, as no rewrite has occurred.", + ) diff --git a/search/controllers/api/__init__.py b/search/controllers/api/__init__.py index e42545ae..eec93ca3 100644 --- a/search/controllers/api/__init__.py +++ b/search/controllers/api/__init__.py @@ -1,28 +1,40 @@ """Controller for search API requests.""" -from typing import Tuple, Dict, Any, Optional, List -import re -from datetime import date, datetime -from dateutil.relativedelta import relativedelta -import dateutil.parser -from pytz import timezone import pytz +from http import HTTPStatus +from collections import defaultdict +from typing import Tuple, Dict, Any, Optional, List, Union +import dateutil.parser +from mypy_extensions import TypedDict +from werkzeug.datastructures import MultiDict +from werkzeug.exceptions import BadRequest, NotFound -from werkzeug.datastructures import MultiDict, ImmutableMultiDict -from werkzeug.exceptions import InternalServerError, BadRequest, NotFound -from flask import url_for - -from arxiv import status, taxonomy +from arxiv import taxonomy from arxiv.base import logging -from search.services import index, fulltext, metadata +from search import consts +from search.services import index from search.controllers.util import paginate -from ...domain import Query, APIQuery, FieldedSearchList, FieldedSearchTerm, \ - DateRange, ClassificationList, Classification, Document +from search.domain import ( + Query, + APIQuery, + FieldedSearchList, + FieldedSearchTerm, + DateRange, + Classification, + DocumentSet, + ClassicAPIQuery, +) + logger = logging.getLogger(__name__) -EASTERN = timezone('US/Eastern') + + +SearchResponseData = TypedDict( + "SearchResponseData", + {"results": DocumentSet, "query": Union[Query, 
ClassicAPIQuery]}, +) def search(params: MultiDict) -> Tuple[Dict[str, Any], int, Dict[str, Any]]: @@ -45,35 +57,46 @@ def search(params: MultiDict) -> Tuple[Dict[str, Any], int, Dict[str, Any]]: """ q = APIQuery() - # parse advanced classic-style queries + # Parse NG queries utilizing the Classic API syntax. + # This implementation parses the `query` parameter as if it were + # using the Classic endpoint's `search_query` parameter. It is meant + # as a migration pathway so that the URL and query structure aren't + # both changed at the same time by end users. + # TODO: Implement the NG API using the Classic API domain. + parsed_operators = ( + None # Default in the event that there is not a Classic query. + ) try: - parsed_terms = _parse_search_query(params.get('query', '')) + parsed_operators, parsed_terms = _parse_search_query( + params.get("query", "") + ) params = params.copy() for field, term in parsed_terms.items(): - params[field] = term + params.add(field, term) except ValueError: raise BadRequest(f"Improper syntax in query: {params.get('query')}") - # process fielded terms + # process fielded terms, using the operators above query_terms: List[Dict[str, Any]] = [] - terms = _get_fielded_terms(params, query_terms) + terms = _get_fielded_terms(params, query_terms, parsed_operators) + if terms is not None: q.terms = terms date_range = _get_date_params(params, query_terms) if date_range is not None: q.date_range = date_range - primary = params.get('primary_classification') + primary = params.get("primary_classification") if primary: - primary_classification = _get_classification(primary, - 'primary_classification', - query_terms) + primary_classification = _get_classification( + primary, "primary_classification", query_terms + ) q.primary_classification = primary_classification - secondaries = params.getlist('secondary_classification') + secondaries = params.getlist("secondary_classification") if secondaries: q.secondary_classification = [ - 
_get_classification(sec, 'secondary_classification', query_terms) + _get_classification(sec, "secondary_classification", query_terms) for sec in secondaries ] @@ -81,98 +104,15 @@ def search(params: MultiDict) -> Tuple[Dict[str, Any], int, Dict[str, Any]]: if include_fields: q.include_fields += include_fields - q = paginate(q, params) # type: ignore - document_set = index.SearchSession.search(q, highlight=False) # type: ignore - document_set['metadata']['query'] = query_terms - logger.debug('Got document set with %i results', - len(document_set['results'])) - return {'results': document_set, 'query': q}, status.HTTP_200_OK, {} - - -def classic_query(params: MultiDict) \ - -> Tuple[Dict[str, Any], int, Dict[str, Any]]: - """ - Handle a search request from the Clasic API. - - First, the method maps old request parameters to new parameters: - - search_query -> query - - start -> start - - max_results -> size - - Then the request is passed to :method:`search()` and returned. - - If ``id_list`` is specified in the parameters and ``search_query`` is - NOT specified, then each request is passed to :method:`paper()` and - results are aggregated. - - If ``id_list`` is specified AND ``search_query`` is also specified, - then the results from :method:`search()` are filtered by ``id_list``. - - Parameters - ---------- - params : :class:`MultiDict` - GET query parameters from the request. - - Returns - ------- - dict - Response data (to serialize). - int - HTTP status code. - dict - Extra headers for the response. - - Raises - ------ - :class:`BadRequest` - Raised when the search_query and id_list are not specified. - """ - params = params.copy() - raw_query = params.get('search_query') - - # parse id_list - id_list = params.get('id_list', '') - if id_list: - id_list = id_list.split(',') - else: - id_list = [] - - # error if neither search_query nor id_list are specified. 
- if not id_list and not raw_query: - raise BadRequest("Either a search_query or id_list must be specified" - " for the classic API.") - - if raw_query: - # migrate search_query -> query variable - params['query'] = raw_query - del params['search_query'] - - # pass to normal search, which will handle parsing - data, _, _ = search(params) - - if id_list and not raw_query: - # Process only id_lists. - # Note lack of error handling to implicitly propogate any errors. - # Classic API also errors if even one ID is malformed. - papers = [paper(paper_id) for paper_id in id_list] - - data, _, _ = zip(*papers) - results = [paper['results'] for paper in data] # type: ignore - data = { - 'results' : dict(results=results, metadata=dict()), # TODO: Aggregate search metadata - 'query' : APIQuery() # TODO: Specify query - } - - elif id_list and raw_query: - # Filter results based on id_list - results = [paper for paper in data['results']['results'] - if paper.paper_id in id_list or paper.paper_id_v in id_list] - data = { - 'results' : dict(results=results, metadata=dict()), # TODO: Aggregate search metadata - 'query' : APIQuery() # TODO: Specify query - } - - return data, status.HTTP_200_OK, {} + q = paginate(q, params) # type: ignore + document_set = index.SearchSession.search( # type: ignore + q, highlight=False + ) + document_set["metadata"]["query"] = query_terms + logger.debug( + "Got document set with %i results", len(document_set["results"]) + ) + return {"results": document_set, "query": q}, HTTPStatus.OK, {} def paper(paper_id: str) -> Tuple[Dict[str, Any], int, Dict[str, Any]]: @@ -200,43 +140,51 @@ def paper(paper_id: str) -> Tuple[Dict[str, Any], int, Dict[str, Any]]: """ try: - document = index.SearchSession.get_document(paper_id) # type: ignore - except index.DocumentNotFound as e: - logger.error('Document not found') - raise NotFound('No such document') from e - return {'results': document}, status.HTTP_200_OK, {} + document = 
index.SearchSession.current_session().get_document( + paper_id + ) # type: ignore + except index.DocumentNotFound as ex: + logger.error("Document not found") + raise NotFound("No such document") from ex + return {"results": document}, HTTPStatus.OK, {} def _get_include_fields(params: MultiDict, query_terms: List) -> List[str]: - include_fields: List[str] = params.getlist('include') + include_fields: List[str] = params.getlist("include") if include_fields: for field in include_fields: - query_terms.append({'parameter': 'include', 'value': field}) + query_terms.append({"parameter": "include", "value": field}) return include_fields return [] -def _get_fielded_terms(params: MultiDict, query_terms: List) \ - -> Optional[FieldedSearchList]: +def _get_fielded_terms( + params: MultiDict, + query_terms: List, + operators: Optional[Dict[str, Any]] = None, +) -> Optional[FieldedSearchList]: + if operators is None: + operators = defaultdict(default_factory=lambda: "AND") terms = FieldedSearchList() for field, _ in Query.SUPPORTED_FIELDS: values = params.getlist(field) for value in values: - query_terms.append({'parameter': field, 'value': value}) - terms.append(FieldedSearchTerm( # type: ignore - operator='AND', - field=field, - term=value - )) + query_terms.append({"parameter": field, "value": value}) + terms.append( + FieldedSearchTerm( # type: ignore + operator=operators[field], field=field, term=value + ) + ) if not terms: return None return terms -def _get_date_params(params: MultiDict, query_terms: List) \ - -> Optional[DateRange]: +def _get_date_params( + params: MultiDict, query_terms: List +) -> Optional[DateRange]: date_params = {} - for field in ['start_date', 'end_date']: + for field in ["start_date", "end_date"]: value = params.getlist(field) if not value: continue @@ -244,94 +192,111 @@ def _get_date_params(params: MultiDict, query_terms: List) \ dt = dateutil.parser.parse(value[0]) if not dt.tzinfo: dt = pytz.utc.localize(dt) - dt = dt.replace(tzinfo=EASTERN) + 
dt = dt.replace(tzinfo=consts.EASTERN) except ValueError: - raise BadRequest(f'Invalid datetime in {field}') + raise BadRequest(f"Invalid datetime in {field}") date_params[field] = dt - query_terms.append({'parameter': field, 'value': dt}) - if 'date_type' in params: - date_params['date_type'] = params.get('date_type') # type: ignore - query_terms.append({'parameter': 'date_type', - 'value': date_params['date_type']}) + query_terms.append({"parameter": field, "value": dt}) + if "date_type" in params: + date_params["date_type"] = params.get("date_type") # type: ignore + query_terms.append( + {"parameter": "date_type", "value": date_params["date_type"]} + ) if date_params: return DateRange(**date_params) # type: ignore return None -def _to_classification(value: str, query_terms: List) \ - -> Tuple[Classification, ...]: +def _to_classification(value: str) -> Tuple[Classification, ...]: clsns = [] if value in taxonomy.definitions.GROUPS: klass = taxonomy.Group - field = 'group' + field = "group" elif value in taxonomy.definitions.ARCHIVES: klass = taxonomy.Archive - field = 'archive' + field = "archive" elif value in taxonomy.definitions.CATEGORIES: klass = taxonomy.Category - field = 'category' + field = "category" else: - raise ValueError('not a valid classification') + raise ValueError("not a valid classification") cast_value = klass(value) - clsns.append(Classification(**{field: {'id': value}})) # type: ignore + clsns.append(Classification(**{field: {"id": value}})) # type: ignore if cast_value.unalias() != cast_value: - clsns.append(Classification(**{field: {'id': cast_value.unalias()}})) # type: ignore - if cast_value.canonical != cast_value \ - and cast_value.canonical != cast_value.unalias(): - clsns.append(Classification(**{field: {'id': cast_value.canonical}})) # type: ignore + clsns.append( + Classification( # type: ignore # noqa: E501 # fmt: off + **{field: {"id": cast_value.unalias()}} + ) + ) + if ( + cast_value.canonical != cast_value + and 
cast_value.canonical != cast_value.unalias() + ): + clsns.append( + Classification( # type: ignore # noqa: E501 # fmt: off + **{field: {"id": cast_value.canonical}} + ) + ) return tuple(clsns) -def _get_classification(value: str, field: str, query_terms: List) \ - -> Tuple[Classification, ...]: +def _get_classification( + value: str, field: str, query_terms: List +) -> Tuple[Classification, ...]: try: - clsns = _to_classification(value, query_terms) + clsns = _to_classification(value) except ValueError: - raise BadRequest(f'Not a valid classification term: {field}={value}') - query_terms.append({'parameter': field, 'value': value}) + raise BadRequest(f"Not a valid classification term: {field}={value}") + query_terms.append({"parameter": field, "value": value}) return clsns + SEARCH_QUERY_FIELDS = { - 'ti' : 'title', - 'au' : 'author', - 'abs' : 'abstract', - 'co' : 'comments', - 'jr' : 'journal_ref', - 'cat' : 'primary_classification', - 'rn' : 'report_number', - 'id' : 'paper_id', - 'all' : 'all' + "ti": "title", + "au": "author", + "abs": "abstract", + "co": "comments", + "jr": "journal_ref", + "cat": "primary_classification", + "rn": "report_number", + "id": "paper_id", + "all": "all", } -def _parse_search_query(query: str) -> Dict[str, Any]: - # TODO: Add support for booleans. +def _parse_search_query(query: str) -> Tuple[Dict[str, Any], Dict[str, Any]]: + """Parses a query into tuple of operators and parameters.""" new_query_params = {} + new_query_operators: Dict[str, str] = defaultdict( + default_factory=lambda: "AND" + ) terms = query.split() - expect_new = True - """expect_new handles quotation state.""" + expect_new = True # expect_new handles quotation state. + next_operator = "AND" # next_operator handles the operator state. 
for term in terms: - if expect_new and term in ["AND", "OR", "ANDNOT"]: - # TODO: Process booleans - pass + if expect_new and term in ["AND", "OR", "ANDNOT", "NOT"]: + if term == "ANDNOT": + term = "NOT" # Translate to NG representation. + next_operator = term elif expect_new: - field, term = term.split(':') + field, term = term.split(":") - # quotation handling + # Quotation handling. if term.startswith('"') and not term.endswith('"'): expect_new = False - term = term.replace('"', '') + term = term.replace('"', "") new_query_params[SEARCH_QUERY_FIELDS[field]] = term + new_query_operators[SEARCH_QUERY_FIELDS[field]] = next_operator else: - # quotation handling, expecting more terms + # If the term ends in a quote, we close the term and look for the + # next one. if term.endswith('"'): expect_new = True - term = term.replace('"', '') + term = term.replace('"', "") new_query_params[SEARCH_QUERY_FIELDS[field]] += " " + term - - return new_query_params + return new_query_operators, new_query_params diff --git a/search/controllers/api/tests.py b/search/controllers/api/tests.py deleted file mode 100644 index e157ddb2..00000000 --- a/search/controllers/api/tests.py +++ /dev/null @@ -1,144 +0,0 @@ -"""Tests for advanced search controller, :mod:`search.controllers.advanced`.""" - -from unittest import TestCase, mock -from datetime import date, datetime -from dateutil.relativedelta import relativedelta -from werkzeug import MultiDict -from werkzeug.exceptions import InternalServerError, BadRequest - -from arxiv import status - -from search.domain import Query, DateRange, FieldedSearchTerm, Classification,\ - AdvancedQuery, DocumentSet -from search.controllers import api -from search.domain import api as api_domain -from search.services.index import IndexConnectionError, QueryError - - -class TestAPISearch(TestCase): - """Tests for :func:`.api.search`.""" - - @mock.patch(f'{api.__name__}.index.SearchSession') - def test_no_params(self, mock_index): - """Request with no 
parameters.""" - params = MultiDict({}) - data, code, headers = api.search(params) - - self.assertEqual(code, status.HTTP_200_OK, "Returns 200 OK") - self.assertIn("results", data, "Results are returned") - self.assertIn("query", data, "Query object is returned") - expected_fields = api_domain.get_required_fields() \ - + api_domain.get_default_extra_fields() - self.assertEqual(set(data["query"].include_fields), - set(expected_fields), - "Default set of fields is included") - - @mock.patch(f'{api.__name__}.index.SearchSession') - def test_include_fields(self, mock_index): - """Request with specific fields included.""" - extra_fields = ['title', 'abstract', 'authors'] - params = MultiDict({'include': extra_fields}) - data, code, headers = api.search(params) - - self.assertEqual(code, status.HTTP_200_OK, "Returns 200 OK") - self.assertIn("results", data, "Results are returned") - self.assertIn("query", data, "Query object is returned") - expected_fields = api_domain.get_required_fields() + extra_fields - self.assertEqual(set(data["query"].include_fields), - set(expected_fields), - "Requested fields are included") - - @mock.patch(f'{api.__name__}.index.SearchSession') - def test_group_primary_classification(self, mock_index): - """Request with a group as primary classification.""" - group = 'grp_physics' - params = MultiDict({'primary_classification': group}) - data, code, headers = api.search(params) - - self.assertEqual(code, status.HTTP_200_OK, "Returns 200 OK") - query = mock_index.search.call_args[0][0] - self.assertEqual(len(query.primary_classification), 1) - self.assertEqual(query.primary_classification[0], - Classification(group={'id': group})) - - @mock.patch(f'{api.__name__}.index.SearchSession') - def test_archive_primary_classification(self, mock_index): - """Request with an archive as primary classification.""" - archive = 'physics' - params = MultiDict({'primary_classification': archive}) - data, code, headers = api.search(params) - - 
self.assertEqual(code, status.HTTP_200_OK, "Returns 200 OK") - query = mock_index.search.call_args[0][0] - self.assertEqual(len(query.primary_classification), 1) - self.assertEqual(query.primary_classification[0], - Classification(archive={'id': archive})) - - @mock.patch(f'{api.__name__}.index.SearchSession') - def test_archive_subsumed_classification(self, mock_index): - """Request with a subsumed archive as primary classification.""" - archive = 'chao-dyn' - params = MultiDict({'primary_classification': archive}) - data, code, headers = api.search(params) - - self.assertEqual(code, status.HTTP_200_OK, "Returns 200 OK") - query = mock_index.search.call_args[0][0] - self.assertEqual(len(query.primary_classification), 2) - self.assertEqual(query.primary_classification[0], - Classification(archive={'id': archive})) - self.assertEqual(query.primary_classification[1], - Classification(archive={'id': 'nlin.CD'}), - "The canonical archive is used instead") - - @mock.patch(f'{api.__name__}.index.SearchSession') - def test_category_primary_classification(self, mock_index): - """Request with a category as primary classification.""" - category = 'cs.DL' - params = MultiDict({'primary_classification': category}) - data, code, headers = api.search(params) - - self.assertEqual(code, status.HTTP_200_OK, "Returns 200 OK") - query = mock_index.search.call_args[0][0] - self.assertEqual(len(query.primary_classification), 1) - self.assertEqual(query.primary_classification[0], - Classification(category={'id': category})) - - @mock.patch(f'{api.__name__}.index.SearchSession') - def test_bad_classification(self, mock_index): - """Request with nonsense as primary classification.""" - params = MultiDict({'primary_classification': 'nonsense'}) - with self.assertRaises(BadRequest): - api.search(params) - - @mock.patch(f'{api.__name__}.index.SearchSession') - def test_with_start_date(self, mock_index): - """Request with dates specified.""" - params = MultiDict({'start_date': '1999-01-02'}) 
- data, code, headers = api.search(params) - - self.assertEqual(code, status.HTTP_200_OK, "Returns 200 OK") - query = mock_index.search.call_args[0][0] - self.assertIsNotNone(query.date_range) - self.assertEqual(query.date_range.start_date.year, 1999) - self.assertEqual(query.date_range.start_date.month, 1) - self.assertEqual(query.date_range.start_date.day, 2) - self.assertEqual(query.date_range.date_type, - DateRange.SUBMITTED_CURRENT, - "Submitted date of current version is the default") - - @mock.patch(f'{api.__name__}.index.SearchSession') - def test_with_end_dates_and_type(self, mock_index): - """Request with end date and date type specified.""" - params = MultiDict({'end_date': '1999-01-02', - 'date_type': 'announced_date_first'}) - data, code, headers = api.search(params) - - self.assertEqual(code, status.HTTP_200_OK, "Returns 200 OK") - query = mock_index.search.call_args[0][0] - self.assertIsNotNone(query.date_range) - self.assertEqual(query.date_range.end_date.year, 1999) - self.assertEqual(query.date_range.end_date.month, 1) - self.assertEqual(query.date_range.end_date.day, 2) - - self.assertEqual(query.date_range.date_type, - DateRange.ANNOUNCED) diff --git a/search/controllers/api/tests/__init__.py b/search/controllers/api/tests/__init__.py new file mode 100644 index 00000000..a7531d00 --- /dev/null +++ b/search/controllers/api/tests/__init__.py @@ -0,0 +1 @@ +"""Tests for :mod:`.controllers.api`.""" diff --git a/search/controllers/api/tests/tests_api_search.py b/search/controllers/api/tests/tests_api_search.py new file mode 100644 index 00000000..2e3c0116 --- /dev/null +++ b/search/controllers/api/tests/tests_api_search.py @@ -0,0 +1,185 @@ +"""Tests for advanced search controller, :mod:`search.controllers.advanced`.""" + +from http import HTTPStatus +from unittest import TestCase, mock +from werkzeug.datastructures import MultiDict +from werkzeug.exceptions import BadRequest + +from search.controllers import api +from search.domain import api as 
api_domain +from search.domain import DateRange, Classification + + +class TestAPISearch(TestCase): + """Tests for :func:`.api.search`.""" + + @mock.patch(f"{api.__name__}.index.SearchSession") + def test_no_params(self, mock_index): + """Request with no parameters.""" + params = MultiDict({}) + data, code, headers = api.search(params) + + self.assertEqual(code, HTTPStatus.OK, "Returns 200 OK") + self.assertIn("results", data, "Results are returned") + self.assertIn("query", data, "Query object is returned") + expected_fields = ( + api_domain.get_required_fields() + + api_domain.get_default_extra_fields() + ) + self.assertEqual( + set(data["query"].include_fields), + set(expected_fields), + "Default set of fields is included", + ) + + @mock.patch(f"{api.__name__}.index.SearchSession") + def test_query_param(self, mock_index): + """Request with a query string. Tests conjuncts and quoted phrases.""" + params = MultiDict({"query": 'au:copernicus AND ti:"dark matter"'}) + data, code, headers = api.search(params) + + self.assertEqual(code, HTTPStatus.OK, "Returns 200 OK") + self.assertIn("results", data, "Results are returned") + self.assertIn("query", data, "Query object is returned") + expected_fields = ( + api_domain.get_required_fields() + + api_domain.get_default_extra_fields() + ) + self.assertEqual( + set(data["query"].include_fields), + set(expected_fields), + "Default set of fields is included", + ) + + @mock.patch(f"{api.__name__}.index.SearchSession") + def test_include_fields(self, mock_index): + """Request with specific fields included.""" + extra_fields = ["title", "abstract", "authors"] + params = MultiDict({"include": extra_fields}) + data, code, headers = api.search(params) + + self.assertEqual(code, HTTPStatus.OK, "Returns 200 OK") + self.assertIn("results", data, "Results are returned") + self.assertIn("query", data, "Query object is returned") + expected_fields = api_domain.get_required_fields() + extra_fields + self.assertEqual( + 
set(data["query"].include_fields), + set(expected_fields), + "Requested fields are included", + ) + + @mock.patch(f"{api.__name__}.index.SearchSession") + def test_group_primary_classification(self, mock_index): + """Request with a group as primary classification.""" + group = "grp_physics" + params = MultiDict({"primary_classification": group}) + data, code, headers = api.search(params) + + self.assertEqual(code, HTTPStatus.OK, "Returns 200 OK") + query = mock_index.search.call_args[0][0] + self.assertEqual(len(query.primary_classification), 1) + self.assertEqual( + query.primary_classification[0], + Classification(group={"id": group}), + ) + + @mock.patch(f"{api.__name__}.index.SearchSession") + def test_archive_primary_classification(self, mock_index): + """Request with an archive as primary classification.""" + archive = "physics" + params = MultiDict({"primary_classification": archive}) + data, code, headers = api.search(params) + + self.assertEqual(code, HTTPStatus.OK, "Returns 200 OK") + query = mock_index.search.call_args[0][0] + self.assertEqual(len(query.primary_classification), 1) + self.assertEqual( + query.primary_classification[0], + Classification(archive={"id": archive}), + ) + + @mock.patch(f"{api.__name__}.index.SearchSession") + def test_archive_subsumed_classification(self, mock_index): + """Request with a subsumed archive as primary classification.""" + archive = "chao-dyn" + params = MultiDict({"primary_classification": archive}) + data, code, headers = api.search(params) + + self.assertEqual(code, HTTPStatus.OK, "Returns 200 OK") + query = mock_index.search.call_args[0][0] + self.assertEqual(len(query.primary_classification), 2) + self.assertEqual( + query.primary_classification[0], + Classification(archive={"id": archive}), + ) + self.assertEqual( + query.primary_classification[1], + Classification(archive={"id": "nlin.CD"}), + "The canonical archive is used instead", + ) + + @mock.patch(f"{api.__name__}.index.SearchSession") + def 
test_category_primary_classification(self, mock_index): + """Request with a category as primary classification.""" + category = "cs.DL" + params = MultiDict({"primary_classification": category}) + data, code, headers = api.search(params) + + self.assertEqual(code, HTTPStatus.OK, "Returns 200 OK") + query = mock_index.search.call_args[0][0] + self.assertEqual(len(query.primary_classification), 1) + self.assertEqual( + query.primary_classification[0], + Classification(category={"id": category}), + ) + + @mock.patch(f"{api.__name__}.index.SearchSession") + def test_bad_classification(self, mock_index): + """Request with nonsense as primary classification.""" + params = MultiDict({"primary_classification": "nonsense"}) + with self.assertRaises(BadRequest): + api.search(params) + + @mock.patch(f"{api.__name__}.index.SearchSession") + def test_with_start_date(self, mock_index): + """Request with dates specified.""" + params = MultiDict({"start_date": "1999-01-02"}) + data, code, headers = api.search(params) + + self.assertEqual(code, HTTPStatus.OK, "Returns 200 OK") + query = mock_index.search.call_args[0][0] + self.assertIsNotNone(query.date_range) + self.assertEqual(query.date_range.start_date.year, 1999) + self.assertEqual(query.date_range.start_date.month, 1) + self.assertEqual(query.date_range.start_date.day, 2) + self.assertEqual( + query.date_range.date_type, + DateRange.SUBMITTED_CURRENT, + "Submitted date of current version is the default", + ) + + @mock.patch(f"{api.__name__}.index.SearchSession") + def test_with_end_dates_and_type(self, mock_index): + """Request with end date and date type specified.""" + params = MultiDict( + {"end_date": "1999-01-02", "date_type": "announced_date_first"} + ) + data, code, headers = api.search(params) + + self.assertEqual(code, HTTPStatus.OK, "Returns 200 OK") + query = mock_index.search.call_args[0][0] + self.assertIsNotNone(query.date_range) + self.assertEqual(query.date_range.end_date.year, 1999) + 
self.assertEqual(query.date_range.end_date.month, 1) + self.assertEqual(query.date_range.end_date.day, 2) + + self.assertEqual(query.date_range.date_type, DateRange.ANNOUNCED) + + +class TestPaper(TestCase): + """Tests for :func:`.api.paper`.""" + + @mock.patch(f"{api.__name__}.index.SearchSession") + def test_paper(self, mock_index): + """Request with single parameter paper.""" + _, _, _ = api.paper("1234.56789") diff --git a/search/controllers/classic_api/__init__.py b/search/controllers/classic_api/__init__.py new file mode 100644 index 00000000..d3dafeaf --- /dev/null +++ b/search/controllers/classic_api/__init__.py @@ -0,0 +1,203 @@ +"""Controller for classic arXiv API requests.""" + +from http import HTTPStatus +from typing import Tuple, Dict, Any + +from werkzeug.datastructures import MultiDict +from werkzeug.exceptions import BadRequest, NotFound + +from arxiv.base import logging +from arxiv.identifier import parse_arxiv_id + +from search.services import index +from search.errors import ValidationError +from search.domain import ( + SortDirection, + SortBy, + SortOrder, + DocumentSet, + ClassicAPIQuery, + ClassicSearchResponseData, +) + +logger = logging.getLogger(__name__) + + +def query( + params: MultiDict, +) -> Tuple[ClassicSearchResponseData, HTTPStatus, Dict[str, Any]]: + """ + Handle a search request from the Clasic API. + + First, the method maps old request parameters to new parameters: + - search_query -> query + - start -> start + - max_results -> size + + Then the request is passed to :method:`search()` and returned. + + If ``id_list`` is specified in the parameters and ``search_query`` is + NOT specified, then each request is passed to :method:`paper()` and + results are aggregated. + + If ``id_list`` is specified AND ``search_query`` is also specified, + then the results from :method:`search()` are filtered by ``id_list``. + + Parameters + ---------- + params : :class:`MultiDict` + GET query parameters from the request. 
+ + Returns + ------- + SearchResponseData + Response data (to serialize). + int + HTTP status code. + dict + Extra headers for the response. + + Raises + ------ + :class:`BadRequest` + Raised when the search_query and id_list are not specified. + + """ + params = params.copy() + + # Parse classic search query. + search_query = params.get("search_query", None) + + # Parse id_list. + id_list = params.get("id_list", "") + if id_list: + id_list = id_list.split(",") + # Check arxiv id validity + for arxiv_id in id_list: + try: + parse_arxiv_id(arxiv_id) + except ValueError: + raise ValidationError( + message="incorrect id format for {}".format(arxiv_id), + link=( + "http://arxiv.org/api/errors#" + "incorrect_id_format_for_{}" + ).format(arxiv_id), + ) + else: + id_list = None + + # Parse result size. + try: + max_results = int(params.get("max_results", 10)) + except ValueError: + raise ValidationError( + message="max_results must be an integer", + link="http://arxiv.org/api/errors#max_results_must_be_an_integer", + ) + if max_results < 0: + raise ValidationError( + message="max_results must be non-negative", + link="http://arxiv.org/api/errors#max_results_must_be_" + "non-negative", + ) + + # Parse result start point. 
+ try: + start = int(params.get("start", 0)) + except ValueError: + raise ValidationError( + message="start must be an integer", + link="http://arxiv.org/api/errors#start_must_be_an_integer", + ) + if start < 0: + raise ValidationError( + message="start must be non-negative", + link="http://arxiv.org/api/errors#start_must_be_non-negative", + ) + + # sort by and sort order + value = params.get("sortBy", SortBy.relevance) + try: + sort_by = SortBy(value) + except ValueError: + raise ValidationError( + message=f"sortBy must be in: {', '.join(SortBy)}", + link="https://arxiv.org/help/api/user-manual#sort", + ) + value = params.get("sortOrder", SortDirection.descending) + try: + sort_direction = SortDirection(value) + except ValueError: + raise ValidationError( + message=f"sortOrder must be in: {', '.join(SortDirection)}", + link="https://arxiv.org/help/api/user-manual#sort", + ) + + try: + classic_query = ClassicAPIQuery( + order=SortOrder(by=sort_by, direction=sort_direction), + search_query=search_query, + id_list=id_list, + size=max_results, + page_start=start, + ) + except ValueError: + raise BadRequest( + "Either a search_query or id_list must be specified" + " for the classic API." + ) + + # pass to search indexer, which will handle parsing + document_set: DocumentSet = index.SearchSession.current_session().search( + classic_query + ) + logger.debug( + "Got document set with %i results", len(document_set["results"]) + ) + + return ( + ClassicSearchResponseData(results=document_set, query=classic_query), + HTTPStatus.OK, + {}, + ) + + +def paper( + paper_id: str, +) -> Tuple[ClassicSearchResponseData, HTTPStatus, Dict[str, Any]]: + """ + Handle a request for paper metadata from the API. + + Parameters + ---------- + paper_id : str + arXiv paper ID for the requested paper. + + Returns + ------- + dict + Response data (to serialize). + int + HTTP status code. + dict + Extra headers for the response. 
+ + Raises + ------ + :class:`NotFound` + Raised when there is no document with the provided paper ID. + + """ + try: + document = index.SearchSession.current_session().get_document( + paper_id + ) # type: ignore + except index.DocumentNotFound as ex: + logger.error("Document not found") + raise NotFound("No such document") from ex + return ( + ClassicSearchResponseData(results=document), # type: ignore + HTTPStatus.OK, + {}, + ) # type: ignore diff --git a/search/controllers/classic_api/tests/__init__.py b/search/controllers/classic_api/tests/__init__.py new file mode 100644 index 00000000..e8e03362 --- /dev/null +++ b/search/controllers/classic_api/tests/__init__.py @@ -0,0 +1 @@ +"""Tests for arXiv classic API controllers.""" diff --git a/search/controllers/classic_api/tests/test_classic_api_search.py b/search/controllers/classic_api/tests/test_classic_api_search.py new file mode 100644 index 00000000..bf211e30 --- /dev/null +++ b/search/controllers/classic_api/tests/test_classic_api_search.py @@ -0,0 +1,134 @@ +"""Tests for classic API search.""" +from http import HTTPStatus +from unittest import TestCase, mock +from werkzeug.datastructures import MultiDict +from werkzeug.exceptions import BadRequest + +from search import domain +from search.errors import ValidationError +from search.controllers import classic_api + + +class TestClassicAPISearch(TestCase): + """Tests for :func:`.classic_api.query`.""" + + @mock.patch(f"{classic_api.__name__}.index.SearchSession") + def test_no_params(self, mock_index): + """Request with no parameters.""" + params = MultiDict({}) + with self.assertRaises(BadRequest): + _, _, _ = classic_api.query(params) + + @mock.patch(f"{classic_api.__name__}.index.SearchSession") + def test_classic_query(self, mock_index): + """Request with search_query.""" + params = MultiDict({"search_query": "au:Copernicus"}) + + data, code, headers = classic_api.query(params) + self.assertEqual(code, HTTPStatus.OK, "Returns 200 OK") + 
self.assertIsNotNone(data.results, "Results are returned") + self.assertIsNotNone(data.query, "Query object is returned") + + @mock.patch(f"{classic_api.__name__}.index.SearchSession") + def test_classic_query_with_quotes(self, mock_index): + """Request with search_query that includes a quoted phrase.""" + params = MultiDict({"search_query": 'ti:"dark matter"'}) + + data, code, headers = classic_api.query(params) + self.assertEqual(code, HTTPStatus.OK, "Returns 200 OK") + self.assertIsNotNone(data.results, "Results are returned") + self.assertIsNotNone(data.query, "Query object is returned") + + @mock.patch(f"{classic_api.__name__}.index.SearchSession") + def test_classic_id_list(self, mock_index): + """Request with multi-element id_list with (un)versioned ids.""" + params = MultiDict({"id_list": "1234.56789,1234.56789v3"}) + + data, code, headers = classic_api.query(params) + self.assertEqual(code, HTTPStatus.OK, "Returns 200 OK") + self.assertIsNotNone(data.results, "Results are returned") + self.assertIsNotNone(data.query, "Query object is returned") + + @mock.patch(f"{classic_api.__name__}.index.SearchSession") + def test_classic_start(self, mock_index): + # Default value + params = MultiDict({"search_query": "au:Copernicus"}) + data, _, _ = classic_api.query(params) + self.assertEqual(data.query.page_start, 0) + # Valid value + params = MultiDict({"search_query": "au:Copernicus", "start": "50"}) + data, _, _ = classic_api.query(params) + self.assertEqual(data.query.page_start, 50) + # Invalid value + params = MultiDict({"search_query": "au:Copernicus", "start": "-1"}) + with self.assertRaises(ValidationError): + data, _, _ = classic_api.query(params) + params = MultiDict({"search_query": "au:Copernicus", "start": "foo"}) + with self.assertRaises(ValidationError): + data, _, _ = classic_api.query(params) + + @mock.patch(f"{classic_api.__name__}.index.SearchSession") + def test_classic_max_result(self, mock_index): + # Default value + params = 
MultiDict({"search_query": "au:Copernicus"}) + data, _, _ = classic_api.query(params) + self.assertEqual(data.query.size, 10) + # Valid value + params = MultiDict( + {"search_query": "au:Copernicus", "max_results": "50"} + ) + data, _, _ = classic_api.query(params) + self.assertEqual(data.query.size, 50) + # Invalid value + params = MultiDict( + {"search_query": "au:Copernicus", "max_results": "-1"} + ) + with self.assertRaises(ValidationError): + _, _, _ = classic_api.query(params) + params = MultiDict( + {"search_query": "au:Copernicus", "max_results": "foo"} + ) + with self.assertRaises(ValidationError): + _, _, _ = classic_api.query(params) + + @mock.patch(f"{classic_api.__name__}.index.SearchSession") + def test_classic_sort_by(self, mock_index): + # Default value + params = MultiDict({"search_query": "au:Copernicus"}) + data, _, _ = classic_api.query(params) + self.assertEqual(data.query.order.by, domain.SortBy.relevance) + # Valid value + for value in domain.SortBy: + params = MultiDict( + {"search_query": "au:Copernicus", "sortBy": f"{value}"} + ) + data, _, _ = classic_api.query(params) + self.assertEqual(data.query.order.by, value) + + # Invalid value + params = MultiDict({"search_query": "au:Copernicus", "sortBy": "foo"}) + with self.assertRaises(ValidationError): + data, _, _ = classic_api.query(params) + + @mock.patch(f"{classic_api.__name__}.index.SearchSession") + def test_classic_sort_order(self, mock_index): + # Default value + params = MultiDict({"search_query": "au:Copernicus"}) + data, _, _ = classic_api.query(params) + self.assertEqual( + data.query.order.direction, domain.SortDirection.descending + ) + # Valid value + for value in domain.SortDirection: + params = MultiDict( + {"search_query": "au:Copernicus", "sortOrder": f"{value}"} + ) + data, _, _ = classic_api.query(params) + self.assertEqual(data.query.order.direction, value) + + # Invalid value + params = MultiDict( + {"search_query": "au:Copernicus", "sortOrder": "foo"} + ) + with 
self.assertRaises(ValidationError): + data, _, _ = classic_api.query(params) diff --git a/search/controllers/simple/__init__.py b/search/controllers/simple/__init__.py index e930adc4..8647b5ae 100644 --- a/search/controllers/simple/__init__.py +++ b/search/controllers/simple/__init__.py @@ -7,30 +7,34 @@ error messages for the user. """ +from http import HTTPStatus from typing import Tuple, Dict, Any, Optional, List -from werkzeug.exceptions import InternalServerError, NotFound, BadRequest -from werkzeug import MultiDict, ImmutableMultiDict from flask import url_for +from werkzeug.exceptions import InternalServerError, NotFound, BadRequest +from werkzeug.datastructures import MultiDict, ImmutableMultiDict -from arxiv import status, identifier, taxonomy - +from arxiv import identifier from arxiv.base import logging -from search.services import index, fulltext, metadata, SearchSession -from search.domain import Query, SimpleQuery, asdict, Classification, \ - ClassificationList +from search.services import index, SearchSession +from search.controllers.simple.forms import SimpleSearchForm from search.controllers.util import paginate, catch_underscore_syntax +from search.domain import ( + Query, + SimpleQuery, + Classification, + ClassificationList, +) -from .forms import SimpleSearchForm -# from search.routes.ui import external_url_builder logger = logging.getLogger(__name__) Response = Tuple[Dict[str, Any], int, Dict[str, Any]] -def search(request_params: MultiDict, - archives: Optional[List[str]] = None) -> Response: +def search( + request_params: MultiDict, archives: Optional[List[str]] = None +) -> Response: """ Perform a simple search. @@ -63,64 +67,72 @@ def search(request_params: MultiDict, """ if archives is not None and len(archives) == 0: - raise NotFound('No such archive') + raise NotFound("No such archive") # We may need to intervene on the request parameters, so we'll # reinstantiate as a mutable MultiDict. 
if isinstance(request_params, ImmutableMultiDict): request_params = MultiDict(request_params.items(multi=True)) - logger.debug('simple search form') + logger.debug("simple search form") response_data = {} # type: Dict[str, Any] - logger.debug('simple search request') - if 'query' in request_params: + logger.debug("simple search request") + if "query" in request_params: try: # first check if the URL includes an arXiv ID arxiv_id: Optional[str] = identifier.parse_arxiv_id( - request_params['query'] + request_params["query"] ) # If so, redirect. logger.debug(f"got arXiv ID: {arxiv_id}") - except ValueError as e: - logger.debug('No arXiv ID detected; fall back to form') + except ValueError: + logger.debug("No arXiv ID detected; fall back to form") arxiv_id = None else: arxiv_id = None if arxiv_id: - headers = {'Location': url_for('abs_by_id', paper_id=arxiv_id)} - return {}, status.HTTP_301_MOVED_PERMANENTLY, headers + headers = {"Location": url_for("abs_by_id", paper_id=arxiv_id)} + return {}, HTTPStatus.MOVED_PERMANENTLY, headers # Here we intervene on the user's query to look for holdouts from the # classic search system's author indexing syntax (surname_f). We # rewrite with a comma, and show a warning to the user about the # change. - response_data['has_classic_format'] = False - if 'searchtype' in request_params and 'query' in request_params: - if request_params['searchtype'] in ['author', 'all']: - _query, _classic = catch_underscore_syntax(request_params['query']) - response_data['has_classic_format'] = _classic - request_params['query'] = _query + response_data["has_classic_format"] = False + if "searchtype" in request_params and "query" in request_params: + if request_params["searchtype"] in ["author", "all"]: + _query, _classic = catch_underscore_syntax(request_params["query"]) + response_data["has_classic_format"] = _classic + request_params["query"] = _query # Fall back to form-based search. 
form = SimpleSearchForm(request_params) if form.query.data: # Temporary workaround to support classic help search - if form.searchtype.data == 'help': - return {}, status.HTTP_301_MOVED_PERMANENTLY,\ - {'Location': f'/help/search?q={form.query.data}'} + if form.searchtype.data == "help": + return ( + {}, + HTTPStatus.MOVED_PERMANENTLY, + {"Location": f"/help/search?q={form.query.data}"}, + ) # Support classic "expeirmental" search - elif form.searchtype.data == 'full_text': - return {}, status.HTTP_301_MOVED_PERMANENTLY,\ - {'Location': 'http://search.arxiv.org:8081/' - f'?in=&query={form.query.data}'} + elif form.searchtype.data == "full_text": + return ( + {}, + HTTPStatus.MOVED_PERMANENTLY, + { + "Location": "http://search.arxiv.org:8081/" + f"?in=&query={form.query.data}" + }, + ) q: Optional[Query] if form.validate(): - logger.debug('form is valid') + logger.debug("form is valid") q = _query_from_form(form) if archives is not None: @@ -134,49 +146,50 @@ def search(request_params: MultiDict, # template rendering, so they get added directly to the # response content.asdict response_data.update(SearchSession.search(q)) # type: ignore - except index.IndexConnectionError as e: + except index.IndexConnectionError as ex: # There was a (hopefully transient) connection problem. Either # this will clear up relatively quickly (next request), or # there is a more serious outage. - logger.error('IndexConnectionError: %s', e) + logger.error("IndexConnectionError: %s", ex) raise InternalServerError( "There was a problem connecting to the search index. This is " "quite likely a transient issue, so please try your search " "again. If this problem persists, please report it to " "help@arxiv.org." - ) from e - except index.QueryError as e: + ) from ex + except index.QueryError as ex: # Base exception routers should pick this up and show bug page. 
- logger.error('QueryError: %s', e) + logger.error("QueryError: %s", ex) raise InternalServerError( "There was a problem executing your query. Please try your " "search again. If this problem persists, please report it to " "help@arxiv.org." - ) from e - except index.OutsideAllowedRange as e: + ) from ex + except index.OutsideAllowedRange as ex: raise BadRequest( "Hello clever friend. You can't get results in that range" " right now." - ) from e + ) from ex - except Exception as e: - logger.error('Unhandled exception: %s', str(e)) + except Exception as ex: + logger.error("Unhandled exception: %s", str(ex)) raise else: - logger.debug('form is invalid: %s', str(form.errors)) - if 'order' in form.errors or 'size' in form.errors: + logger.debug("form is invalid: %s", str(form.errors)) + if "order" in form.errors or "size" in form.errors: # It's likely that the user tried to set these parameters manually, # or that the search originated from somewhere else (and was # configured incorrectly). - simple_url = url_for('ui.search') + simple_url = url_for("ui.search") raise BadRequest( f"It looks like there's something odd about your search" f" request. Please try starting" - f" over.") + f" over." + ) q = None - response_data['query'] = q - response_data['form'] = form - return response_data, status.HTTP_200_OK, {} + response_data["query"] = q + response_data["form"] = form + return response_data, HTTPStatus.OK, {} def retrieve_document(document_id: str) -> Response: @@ -207,29 +220,29 @@ def retrieve_document(document_id: str) -> Response: """ try: result = SearchSession.get_document(document_id) # type: ignore - except index.IndexConnectionError as e: + except index.IndexConnectionError as ex: # There was a (hopefully transient) connection problem. Either # this will clear up relatively quickly (next request), or # there is a more serious outage. 
- logger.error('IndexConnectionError: %s', e) + logger.error("IndexConnectionError: %s", ex) raise InternalServerError( "There was a problem connecting to the search index. This is " "quite likely a transient issue, so please try your search " "again. If this problem persists, please report it to " "help@arxiv.org." - ) from e - except index.QueryError as e: + ) from ex + except index.QueryError as ex: # Base exception routers should pick this up and show bug page. - logger.error('QueryError: %s', e) + logger.error("QueryError: %s", ex) raise InternalServerError( "There was a problem executing your query. Please try your " "search again. If this problem persists, please report it to " "help@arxiv.org." - ) from e - except index.DocumentNotFound as e: - logger.error('DocumentNotFound: %s', e) - raise NotFound(f"Could not find a paper with id {document_id}") from e - return {'document': result}, status.HTTP_200_OK, {} + ) from ex + except index.DocumentNotFound as ex: + logger.error("DocumentNotFound: %s", ex) + raise NotFound(f"Could not find a paper with id {document_id}") from ex + return {"document": result}, HTTPStatus.OK, {} def _update_with_archives(q: SimpleQuery, archives: List[str]) -> SimpleQuery: @@ -245,11 +258,13 @@ def _update_with_archives(q: SimpleQuery, archives: List[str]) -> SimpleQuery: ------- :class:`SimpleQuery` """ - logger.debug('Search within %s', archives) - q.classification = ClassificationList([ - Classification(archive={'id': archive}) # type: ignore - for archive in archives - ]) + logger.debug("Search within %s", archives) + q.classification = ClassificationList( + [ + Classification(archive={"id": archive}) # type: ignore + for archive in archives + ] + ) return q @@ -272,6 +287,6 @@ def _query_from_form(form: SimpleSearchForm) -> SimpleQuery: q.value = form.query.data q.hide_abstracts = form.abstracts.data == form.HIDE_ABSTRACTS order = form.order.data - if order and order != 'None': + if order and order != "None": q.order = order 
return q diff --git a/search/controllers/simple/forms.py b/search/controllers/simple/forms.py index 2d5b6049..deaf8f9c 100644 --- a/search/controllers/simple/forms.py +++ b/search/controllers/simple/forms.py @@ -1,52 +1,59 @@ """Provides form rendering and validation for the simple search feature.""" -from datetime import date +from wtforms import Form, StringField, SelectField, validators, RadioField -from wtforms import Form, BooleanField, StringField, SelectField, validators, \ - FormField, SelectMultipleField, DateField, ValidationError, FieldList, \ - widgets, RadioField -from wtforms.fields import HiddenField - -from search.controllers.util import does_not_start_with_wildcard, \ - has_balanced_quotes, strip_white_space -from ...domain import Query +from search.controllers.util import ( + does_not_start_with_wildcard, + has_balanced_quotes, + strip_white_space, +) +from search.domain import Query class SimpleSearchForm(Form): """Provides a simple field-query search form.""" searchtype = SelectField("Field", choices=Query.SUPPORTED_FIELDS) - query = StringField('Search or Article ID', - filters=[strip_white_space], - validators=[does_not_start_with_wildcard, - has_balanced_quotes]) - size = SelectField('results per page', default=50, choices=[ - ('25', '25'), - ('50', '50'), - ('100', '100'), - ('200', '200') - ]) - order = SelectField('Sort results by', choices=[ - ('-announced_date_first', 'Announcement date (newest first)'), - ('announced_date_first', 'Announcement date (oldest first)'), - ('-submitted_date', 'Submission date (newest first)'), - ('submitted_date', 'Submission date (oldest first)'), - ('', 'Relevance') - ], validators=[validators.Optional()], default='-announced_date_first') - - HIDE_ABSTRACTS = 'hide' - SHOW_ABSTRACTS = 'show' - - abstracts = RadioField('Abstracts', choices=[ - (SHOW_ABSTRACTS, 'Show abstracts'), - (HIDE_ABSTRACTS, 'Hide abstracts') - ], default=SHOW_ABSTRACTS) + query = StringField( + "Search or Article ID", + 
filters=[strip_white_space], + validators=[does_not_start_with_wildcard, has_balanced_quotes], + ) + size = SelectField( + "results per page", + default=50, + choices=[("25", "25"), ("50", "50"), ("100", "100"), ("200", "200")], + ) + order = SelectField( + "Sort results by", + choices=[ + ("-announced_date_first", "Announcement date (newest first)"), + ("announced_date_first", "Announcement date (oldest first)"), + ("-submitted_date", "Submission date (newest first)"), + ("submitted_date", "Submission date (oldest first)"), + ("", "Relevance"), + ], + validators=[validators.Optional()], + default="-announced_date_first", + ) + + HIDE_ABSTRACTS = "hide" + SHOW_ABSTRACTS = "show" + + abstracts = RadioField( + "Abstracts", + choices=[ + (SHOW_ABSTRACTS, "Show abstracts"), + (HIDE_ABSTRACTS, "Hide abstracts"), + ], + default=SHOW_ABSTRACTS, + ) def validate_query(form: Form, field: StringField) -> None: """Validate the length of the querystring, if searchtype is set.""" - if form.searchtype.data is None or form.searchtype.data == 'None': + if form.searchtype.data is None or form.searchtype.data == "None": return if not form.query.data or len(form.query.data) < 1: raise validators.ValidationError( - 'Field must be at least 1 character long.' + "Field must be at least 1 character long." 
) diff --git a/search/controllers/simple/tests.py b/search/controllers/simple/tests.py index 3b66b67a..4ae1184a 100644 --- a/search/controllers/simple/tests.py +++ b/search/controllers/simple/tests.py @@ -1,174 +1,194 @@ """Tests for simple search controller, :mod:`search.controllers.simple`.""" +from http import HTTPStatus from unittest import TestCase, mock -from datetime import date, datetime -from dateutil.relativedelta import relativedelta -from werkzeug import MultiDict -from werkzeug.exceptions import InternalServerError, NotFound, BadRequest - -from arxiv import status -from search.domain import Query, DateRange, SimpleQuery, DocumentSet +from werkzeug.datastructures import MultiDict +from werkzeug.exceptions import InternalServerError, NotFound, BadRequest +from search.domain import SimpleQuery from search.controllers import simple from search.controllers.simple.forms import SimpleSearchForm - -from search.services.index import IndexConnectionError, QueryError, \ - DocumentNotFound +from search.services.index import ( + IndexConnectionError, + QueryError, + DocumentNotFound, +) class TestRetrieveDocument(TestCase): """Tests for :func:`.simple.retrieve_document`.""" - @mock.patch('search.controllers.simple.SearchSession') + @mock.patch("search.controllers.simple.SearchSession") def test_encounters_queryerror(self, mock_index): """There is a bug in the index or query.""" + def _raiseQueryError(*args, **kwargs): - raise QueryError('What now') + raise QueryError("What now") mock_index.get_document.side_effect = _raiseQueryError with self.assertRaises(InternalServerError): try: response_data, code, headers = simple.retrieve_document(1) - except QueryError as e: - self.fail("QueryError should be handled (caught %s)" % e) + except QueryError as ex: + self.fail("QueryError should be handled (caught %s)" % ex) - self.assertEqual(mock_index.get_document.call_count, 1, - "A search should be attempted") + self.assertEqual( + mock_index.get_document.call_count, + 1, + 
"A search should be attempted", + ) - @mock.patch('search.controllers.simple.SearchSession') + @mock.patch("search.controllers.simple.SearchSession") def test_index_raises_connection_exception(self, mock_index): """Index service raises a IndexConnectionError.""" mock_index.get_document.side_effect = IndexConnectionError with self.assertRaises(InternalServerError): - response_data, code, headers = simple.retrieve_document('124.5678') - self.assertEqual(mock_index.get_document.call_count, 1, - "A search should be attempted") + response_data, code, headers = simple.retrieve_document("124.5678") + self.assertEqual( + mock_index.get_document.call_count, + 1, + "A search should be attempted", + ) call_args, call_kwargs = mock_index.get_document.call_args self.assertIsInstance(call_args[0], str, "arXiv ID is passed") # self.assertEqual(code, status.HTTP_500_INTERNAL_SERVER_ERROR) - @mock.patch('search.controllers.simple.SearchSession') + @mock.patch("search.controllers.simple.SearchSession") def test_document_not_found(self, mock_index): """The document is not found.""" + def _raiseDocumentNotFound(*args, **kwargs): - raise DocumentNotFound('What now') + raise DocumentNotFound("What now") mock_index.get_document.side_effect = _raiseDocumentNotFound with self.assertRaises(NotFound): try: response_data, code, headers = simple.retrieve_document(1) - except DocumentNotFound as e: - self.fail("DocumentNotFound should be handled (caught %s)" % e) + except DocumentNotFound as ex: + self.fail( + "DocumentNotFound should be handled (caught %s)" % ex + ) - self.assertEqual(mock_index.get_document.call_count, 1, - "A search should be attempted") + self.assertEqual( + mock_index.get_document.call_count, + 1, + "A search should be attempted", + ) class TestSearchController(TestCase): """Tests for :func:`.simple.search`.""" - @mock.patch('search.controllers.simple.url_for', - lambda *a, **k: f'https://arxiv.org/{k["paper_id"]}') - @mock.patch('search.controllers.simple.SearchSession') 
+ @mock.patch( + "search.controllers.simple.url_for", + lambda *a, **k: f'https://arxiv.org/{k["paper_id"]}', + ) + @mock.patch("search.controllers.simple.SearchSession") def test_arxiv_id(self, mock_index): """Query parameter contains an arXiv ID.""" - request_data = MultiDict({'query': '1702.00123'}) + request_data = MultiDict({"query": "1702.00123"}) response_data, code, headers = simple.search(request_data) - self.assertEqual(code, status.HTTP_301_MOVED_PERMANENTLY, - "Response should be a 301 redirect.") - self.assertIn('Location', headers, "Location header should be set") - - self.assertEqual(mock_index.search.call_count, 0, - "No search should be attempted") - - @mock.patch('search.controllers.simple.SearchSession') + self.assertEqual( + code, + HTTPStatus.MOVED_PERMANENTLY, + "Response should be a 301 redirect.", + ) + self.assertIn("Location", headers, "Location header should be set") + + self.assertEqual( + mock_index.search.call_count, 0, "No search should be attempted" + ) + + @mock.patch("search.controllers.simple.SearchSession") def test_no_form_data(self, mock_index): """No form data has been submitted.""" request_data = MultiDict() response_data, code, headers = simple.search(request_data) - self.assertEqual(code, status.HTTP_200_OK, "Response should be OK.") + self.assertEqual(code, HTTPStatus.OK, "Response should be OK.") - self.assertIn('form', response_data, "Response should include form.") + self.assertIn("form", response_data, "Response should include form.") - self.assertEqual(mock_index.search.call_count, 0, - "No search should be attempted") + self.assertEqual( + mock_index.search.call_count, 0, "No search should be attempted" + ) - @mock.patch('search.controllers.simple.SearchSession') + @mock.patch("search.controllers.simple.SearchSession") def test_single_field_term(self, mock_index): """Form data are present.""" - mock_index.search.return_value = dict(metadata={}, results=[]) - request_data = MultiDict({ - 'searchtype': 'title', - 
'query': 'foo title' - }) + mock_index.search.return_value = {"metadata": {}, "results": []} + request_data = MultiDict({"searchtype": "title", "query": "foo title"}) response_data, code, headers = simple.search(request_data) - self.assertEqual(mock_index.search.call_count, 1, - "A search should be attempted") + self.assertEqual( + mock_index.search.call_count, 1, "A search should be attempted" + ) call_args, call_kwargs = mock_index.search.call_args - self.assertIsInstance(call_args[0], SimpleQuery, - "An SimpleQuery is passed to the search index") - self.assertEqual(code, status.HTTP_200_OK, "Response should be OK.") - - @mock.patch('search.controllers.simple.SearchSession') + self.assertIsInstance( + call_args[0], + SimpleQuery, + "An SimpleQuery is passed to the search index", + ) + self.assertEqual(code, HTTPStatus.OK, "Response should be OK.") + + @mock.patch("search.controllers.simple.SearchSession") def test_invalid_data(self, mock_index): """Form data are invalid.""" - request_data = MultiDict({ - 'searchtype': 'title' - }) + request_data = MultiDict({"searchtype": "title"}) response_data, code, headers = simple.search(request_data) - self.assertEqual(code, status.HTTP_200_OK, "Response should be OK.") + self.assertEqual(code, HTTPStatus.OK, "Response should be OK.") - self.assertIn('form', response_data, "Response should include form.") + self.assertIn("form", response_data, "Response should include form.") - self.assertEqual(mock_index.search.call_count, 0, - "No search should be attempted") + self.assertEqual( + mock_index.search.call_count, 0, "No search should be attempted" + ) - @mock.patch('search.controllers.simple.SearchSession') + @mock.patch("search.controllers.simple.SearchSession") def test_index_raises_connection_exception(self, mock_index): """Index service raises a IndexConnectionError.""" + def _raiseIndexConnectionError(*args, **kwargs): - raise IndexConnectionError('What now') + raise IndexConnectionError("What now") 
mock_index.search.side_effect = _raiseIndexConnectionError - request_data = MultiDict({ - 'searchtype': 'title', - 'query': 'foo title' - }) + request_data = MultiDict({"searchtype": "title", "query": "foo title"}) with self.assertRaises(InternalServerError): - response_data, code, headers = simple.search(request_data) + _, _, _ = simple.search(request_data) - self.assertEqual(mock_index.search.call_count, 1, - "A search should be attempted") + self.assertEqual( + mock_index.search.call_count, 1, "A search should be attempted" + ) call_args, call_kwargs = mock_index.search.call_args - self.assertIsInstance(call_args[0], SimpleQuery, - "An SimpleQuery is passed to the search index") + self.assertIsInstance( + call_args[0], + SimpleQuery, + "An SimpleQuery is passed to the search index", + ) - @mock.patch('search.controllers.simple.SearchSession') + @mock.patch("search.controllers.simple.SearchSession") def test_index_raises_query_error(self, mock_index): """Index service raises a QueryError.""" + def _raiseQueryError(*args, **kwargs): - raise QueryError('What now') + raise QueryError("What now") mock_index.search.side_effect = _raiseQueryError - request_data = MultiDict({ - 'searchtype': 'title', - 'query': 'foo title' - }) + request_data = MultiDict({"searchtype": "title", "query": "foo title"}) with self.assertRaises(InternalServerError): try: response_data, code, headers = simple.search(request_data) - except QueryError as e: - self.fail("QueryError should be handled (caught %s)" % e) + except QueryError as ex: + self.fail("QueryError should be handled (caught %s)" % ex) - self.assertEqual(mock_index.search.call_count, 1, - "A search should be attempted") + self.assertEqual( + mock_index.search.call_count, 1, "A search should be attempted" + ) class TestSimpleSearchForm(TestCase): @@ -176,26 +196,19 @@ class TestSimpleSearchForm(TestCase): def test_searchtype_only(self): """User has entered only a searchtype (field).""" - data = MultiDict({ - 'searchtype': 
'title' - }) + data = MultiDict({"searchtype": "title"}) form = SimpleSearchForm(data) self.assertFalse(form.validate(), "Form should be invalid") def test_query_only(self): """User has entered only a query (value); this should never happen.""" - data = MultiDict({ - 'query': 'someone monkeyed with the request' - }) + data = MultiDict({"query": "someone monkeyed with the request"}) form = SimpleSearchForm(data) self.assertFalse(form.validate(), "Form should be invalid") def test_query_and_searchtype(self): """User has entered a searchtype (field) and query (value).""" - data = MultiDict({ - 'searchtype': 'title', - 'query': 'foo title' - }) + data = MultiDict({"searchtype": "title", "query": "foo title"}) form = SimpleSearchForm(data) self.assertTrue(form.validate(), "Form should be valid") @@ -205,78 +218,69 @@ class TestQueryFromForm(TestCase): def test_multiple_simple(self): """Form data has three simple.""" - data = MultiDict({ - 'searchtype': 'title', - 'query': 'foo title' - }) + data = MultiDict({"searchtype": "title", "query": "foo title"}) form = SimpleSearchForm(data) query = simple._query_from_form(form) - self.assertIsInstance(query, SimpleQuery, - "Should return an instance of SimpleQuery") + self.assertIsInstance( + query, SimpleQuery, "Should return an instance of SimpleQuery" + ) def test_form_data_has_order(self): """Form data includes sort order.""" - data = MultiDict({ - 'searchtype': 'title', - 'query': 'foo title', - 'order': 'submitted_date' - }) + data = MultiDict( + { + "searchtype": "title", + "query": "foo title", + "order": "submitted_date", + } + ) form = SimpleSearchForm(data) query = simple._query_from_form(form) - self.assertIsInstance(query, SimpleQuery, - "Should return an instance of SimpleQuery") - self.assertEqual(query.order, 'submitted_date') + self.assertIsInstance( + query, SimpleQuery, "Should return an instance of SimpleQuery" + ) + self.assertEqual(query.order, "submitted_date") def test_form_data_has_no_order(self): 
"""Form data includes sort order parameter, but it is 'None'.""" - data = MultiDict({ - 'searchtype': 'title', - 'query': 'foo title', - 'order': 'None' # - }) + data = MultiDict( + {"searchtype": "title", "query": "foo title", "order": "None"} # + ) form = SimpleSearchForm(data) query = simple._query_from_form(form) - self.assertIsInstance(query, SimpleQuery, - "Should return an instance of SimpleQuery") + self.assertIsInstance( + query, SimpleQuery, "Should return an instance of SimpleQuery" + ) self.assertIsNone(query.order, "Order should be None") def test_querystring_has_wildcard_at_start(self): """Querystring starts with a wildcard.""" - data = MultiDict({ - 'searchtype': 'title', - 'query': '*foo title' - }) + data = MultiDict({"searchtype": "title", "query": "*foo title"}) form = SimpleSearchForm(data) self.assertFalse(form.validate(), "Form should be invalid") def test_input_whitespace_is_stripped(self): """If query has padding whitespace, it should be removed.""" - data = MultiDict({ - 'searchtype': 'title', - 'query': ' foo title ' - }) + data = MultiDict({"searchtype": "title", "query": " foo title "}) form = SimpleSearchForm(data) self.assertTrue(form.validate(), "Form should be valid.") - self.assertEqual(form.query.data, 'foo title') + self.assertEqual(form.query.data, "foo title") def test_querystring_has_unbalanced_quotes(self): """Querystring has an odd number of quote characters.""" - data = MultiDict({ - 'searchtype': 'title', - 'query': '"rhubarb' - }) + data = MultiDict({"searchtype": "title", "query": '"rhubarb'}) form = SimpleSearchForm(data) self.assertFalse(form.validate(), "Form should be invalid") - data['query'] = '"rhubarb"' + data["query"] = '"rhubarb"' form = SimpleSearchForm(data) self.assertTrue(form.validate(), "Form should be valid") - data['query'] = '"rhubarb" "pie' + data["query"] = '"rhubarb" "pie' form = SimpleSearchForm(data) self.assertFalse(form.validate(), "Form should be invalid") - data['query'] = '"rhubarb" "pie"' + 
data["query"] = '"rhubarb" "pie"' form = SimpleSearchForm(data) self.assertTrue(form.validate(), "Form should be valid") @@ -292,27 +296,31 @@ class TestPaginationParametersAreFunky(TestCase): search form. """ - @mock.patch('search.controllers.simple.url_for') + @mock.patch("search.controllers.simple.url_for") def test_order_is_invalid(self, mock_url_for): """The order parameter on the request is invalid.""" - request_data = MultiDict({ - 'searchtype': 'title', - 'query': 'foo title', - 'size': 50, # Valid. - 'order': 'foo' # Invalid - }) + request_data = MultiDict( + { + "searchtype": "title", + "query": "foo title", + "size": 50, # Valid. + "order": "foo", # Invalid + } + ) with self.assertRaises(BadRequest): simple.search(request_data) - @mock.patch('search.controllers.simple.url_for') + @mock.patch("search.controllers.simple.url_for") def test_size_is_invalid(self, mock_url_for): """The order parameter on the request is invalid.""" - request_data = MultiDict({ - 'searchtype': 'title', - 'query': 'foo title', - 'size': 51, # Invalid - 'order': '' # Valid - }) + request_data = MultiDict( + { + "searchtype": "title", + "query": "foo title", + "size": 51, # Invalid + "order": "", # Valid + } + ) with self.assertRaises(BadRequest): simple.search(request_data) @@ -327,76 +335,104 @@ class TestClassicAuthorSyntaxIsIntercepted(TestCase): about the syntax change. 
""" - @mock.patch('search.controllers.simple.SearchSession') + @mock.patch("search.controllers.simple.SearchSession") def test_all_fields_search_contains_classic_syntax(self, mock_index): """User has entered a `surname_f` query in an all-fields search.""" - request_data = MultiDict({ - 'searchtype': 'all', - 'query': 'franklin_r', - 'size': 50, - 'order': '' - }) - mock_index.search.return_value = dict(metadata={}, results=[]) + request_data = MultiDict( + { + "searchtype": "all", + "query": "franklin_r", + "size": 50, + "order": "", + } + ) + mock_index.search.return_value = {"metadata": {}, "results": []} data, code, headers = simple.search(request_data) - self.assertEqual(data['query'].value, "franklin, r", - "The query should be rewritten.") - self.assertTrue(data['has_classic_format'], - "A flag denoting the syntax interception should be set" - " in the response context, so that a message may be" - " rendered in the template.") - - @mock.patch('search.controllers.simple.SearchSession') + self.assertEqual( + data["query"].value, + "franklin, r", + "The query should be rewritten.", + ) + self.assertTrue( + data["has_classic_format"], + "A flag denoting the syntax interception should be set" + " in the response context, so that a message may be" + " rendered in the template.", + ) + + @mock.patch("search.controllers.simple.SearchSession") def test_author_search_contains_classic_syntax(self, mock_index): """User has entered a `surname_f` query in an author search.""" - request_data = MultiDict({ - 'searchtype': 'author', - 'query': 'franklin_r', - 'size': 50, - 'order': '' - }) - mock_index.search.return_value = dict(metadata={}, results=[]) + request_data = MultiDict( + { + "searchtype": "author", + "query": "franklin_r", + "size": 50, + "order": "", + } + ) + mock_index.search.return_value = {"metadata": {}, "results": []} data, code, headers = simple.search(request_data) - self.assertEqual(data['query'].value, "franklin, r", - "The query should be rewritten.") 
- self.assertTrue(data['has_classic_format'], - "A flag denoting the syntax interception should be set" - " in the response context, so that a message may be" - " rendered in the template.") - - @mock.patch('search.controllers.simple.SearchSession') + self.assertEqual( + data["query"].value, + "franklin, r", + "The query should be rewritten.", + ) + self.assertTrue( + data["has_classic_format"], + "A flag denoting the syntax interception should be set" + " in the response context, so that a message may be" + " rendered in the template.", + ) + + @mock.patch("search.controllers.simple.SearchSession") def test_all_fields_search_multiple_classic_syntax(self, mock_index): """User has entered a classic query with multiple authors.""" - request_data = MultiDict({ - 'searchtype': 'all', - 'query': 'j franklin_r hawking_s', - 'size': 50, - 'order': '' - }) - mock_index.search.return_value = dict(metadata={}, results=[]) + request_data = MultiDict( + { + "searchtype": "all", + "query": "j franklin_r hawking_s", + "size": 50, + "order": "", + } + ) + mock_index.search.return_value = {"metadata": {}, "results": []} data, code, headers = simple.search(request_data) - self.assertEqual(data['query'].value, "j franklin, r; hawking, s", - "The query should be rewritten.") - self.assertTrue(data['has_classic_format'], - "A flag denoting the syntax interception should be set" - " in the response context, so that a message may be" - " rendered in the template.") - - @mock.patch('search.controllers.simple.SearchSession') + self.assertEqual( + data["query"].value, + "j franklin, r; hawking, s", + "The query should be rewritten.", + ) + self.assertTrue( + data["has_classic_format"], + "A flag denoting the syntax interception should be set" + " in the response context, so that a message may be" + " rendered in the template.", + ) + + @mock.patch("search.controllers.simple.SearchSession") def test_title_search_contains_classic_syntax(self, mock_index): """User has entered a `surname_f` 
query in a title search.""" - request_data = MultiDict({ - 'searchtype': 'title', - 'query': 'franklin_r', - 'size': 50, - 'order': '' - }) - mock_index.search.return_value = dict(metadata={}, results=[]) + request_data = MultiDict( + { + "searchtype": "title", + "query": "franklin_r", + "size": 50, + "order": "", + } + ) + mock_index.search.return_value = {"metadata": {}, "results": []} data, code, headers = simple.search(request_data) - self.assertEqual(data['query'].value, "franklin_r", - "The query should not be rewritten.") - self.assertFalse(data['has_classic_format'], - "Flag should not be set, as no rewrite has occurred.") + self.assertEqual( + data["query"].value, + "franklin_r", + "The query should not be rewritten.", + ) + self.assertFalse( + data["has_classic_format"], + "Flag should not be set, as no rewrite has occurred.", + ) diff --git a/search/controllers/tests.py b/search/controllers/tests.py index 083e2964..cb9caaf1 100644 --- a/search/controllers/tests.py +++ b/search/controllers/tests.py @@ -1,43 +1,48 @@ """Tests for :mod:`search.controllers`.""" +from http import HTTPStatus from unittest import TestCase, mock -from datetime import date -from arxiv import status -from search.domain import DocumentSet, Document from search.controllers import health_check -from .util import catch_underscore_syntax +from search.controllers.util import catch_underscore_syntax class TestHealthCheck(TestCase): """Tests for :func:`.health_check`.""" - @mock.patch('search.controllers.index.SearchSession') + @mock.patch("search.controllers.index.SearchSession") def test_index_is_down(self, mock_index): """Test returns 'DOWN' + status 500 when index raises an exception.""" mock_index.search.side_effect = RuntimeError response, status_code, _ = health_check() - self.assertEqual(response, 'DOWN', "Response content should be DOWN") - self.assertEqual(status_code, status.HTTP_500_INTERNAL_SERVER_ERROR, - "Should return 500 status code.") + self.assertEqual(response, "DOWN", 
"Response content should be DOWN") + self.assertEqual( + status_code, + HTTPStatus.INTERNAL_SERVER_ERROR, + "Should return 500 status code.", + ) - @mock.patch('search.controllers.index.SearchSession') + @mock.patch("search.controllers.index.SearchSession") def test_index_returns_no_result(self, mock_index): """Test returns 'DOWN' + status 500 when index returns no results.""" - mock_index.search.return_value = dict(metadata={}, results=[]) + mock_index.search.return_value = {"metadata": {}, "results": []} response, status_code, _ = health_check() - self.assertEqual(response, 'DOWN', "Response content should be DOWN") - self.assertEqual(status_code, status.HTTP_500_INTERNAL_SERVER_ERROR, - "Should return 500 status code.") + self.assertEqual(response, "DOWN", "Response content should be DOWN") + self.assertEqual( + status_code, + HTTPStatus.INTERNAL_SERVER_ERROR, + "Should return 500 status code.", + ) - @mock.patch('search.controllers.index.SearchSession') + @mock.patch("search.controllers.index.SearchSession") def test_index_returns_result(self, mock_index): """Test returns 'OK' + status 200 when index returns results.""" - mock_index.search.return_value = dict(metadata={}, results=[dict()]) + mock_index.search.return_value = {"metadata": {}, "results": [{}]} response, status_code, _ = health_check() - self.assertEqual(response, 'OK', "Response content should be OK") - self.assertEqual(status_code, status.HTTP_200_OK, - "Should return 200 status code.") + self.assertEqual(response, "OK", "Response content should be OK") + self.assertEqual( + status_code, HTTPStatus.OK, "Should return 200 status code." 
+ ) class TestUnderscoreHandling(TestCase): @@ -47,8 +52,11 @@ def test_underscore_is_rewritten(self): """User searches for an author name with `surname_f` format.""" query = "franklin_r" after, classic_name = catch_underscore_syntax(query) - self.assertEqual(after, "franklin, r", - "The underscore should be replaced with `, `.") + self.assertEqual( + after, + "franklin, r", + "The underscore should be replaced with `, `.", + ) self.assertTrue(classic_name, "Should be identified as classic") def test_false_positive(self): @@ -63,13 +71,16 @@ def test_multiple_authors(self): # E-gads. query = "franklin_r dole_b" after, classic_name = catch_underscore_syntax(query) - self.assertEqual(after, "franklin, r; dole, b", - "The underscore should be replaced with `, `.") + self.assertEqual( + after, + "franklin, r; dole, b", + "The underscore should be replaced with `, `.", + ) self.assertTrue(classic_name, "Should be identified as classic") def test_nonsense_input(self): """Garbage input is passed.""" try: catch_underscore_syntax("") - except Exception as e: - self.fail(e) + except Exception as ex: + self.fail(ex) diff --git a/search/controllers/util.py b/search/controllers/util.py index a506f090..9f1c12c3 100644 --- a/search/controllers/util.py +++ b/search/controllers/util.py @@ -1,26 +1,30 @@ """Controller helpers.""" import re -from typing import Tuple +from typing import Tuple, Dict, Any from wtforms import Form, StringField, validators from search.domain import Query -CLASSIC_AUTHOR = r'([A-Za-z]+)_([a-zA-Z])(?=$|\s)' +CLASSIC_AUTHOR = r"([A-Za-z]+)_([a-zA-Z])(?=$|\s)" def does_not_start_with_wildcard(form: Form, field: StringField) -> None: """Check that ``value`` does not start with a wildcard character.""" if not field.data: return - if field.data.startswith('?') or field.data.startswith('*'): + if field.data.startswith("?") or field.data.startswith("*"): raise validators.ValidationError( - 'Search cannot start with a wildcard (? 
*).') - if any([part.startswith('?') or part.startswith('*') - for part in field.data.split()]): - raise validators.ValidationError('Search terms cannot start with a' - ' wildcard (? *).') + "Search cannot start with a wildcard (? *)." + ) + if any( + part.startswith("?") or part.startswith("*") + for part in field.data.split() + ): + raise validators.ValidationError( + "Search terms cannot start with a" " wildcard (? *)." + ) def has_balanced_quotes(form: Form, field: StringField) -> None: @@ -38,7 +42,8 @@ def strip_white_space(value: str) -> str: return value.strip() -def paginate(query: Query, data: dict) -> Query: +# FIXME: Argument type. +def paginate(query: Query, data: Dict[Any, Any]) -> Query: """ Update pagination parameters on a :class:`.Query` from request parameters. @@ -52,8 +57,8 @@ def paginate(query: Query, data: dict) -> Query: :class:`.Query` """ - query.page_start = max(int(data.get('start', 0)), 0) - query.size = min(int(data.get('size', 50)), Query.MAXIMUM_size) + query.page_start = max(int(data.get("start", 0)), 0) + query.size = min(int(data.get("size", 50)), Query.MAXIMUM_size) return query @@ -62,4 +67,4 @@ def catch_underscore_syntax(term: str) -> Tuple[str, bool]: match = re.search(CLASSIC_AUTHOR, term) if not match: return term, False - return re.sub(CLASSIC_AUTHOR, r'\g<1>, \g<2>;', term).rstrip(';'), True + return re.sub(CLASSIC_AUTHOR, r"\g<1>, \g<2>;", term).rstrip(";"), True diff --git a/search/converters.py b/search/converters.py index b96a40d0..7742e42d 100644 --- a/search/converters.py +++ b/search/converters.py @@ -1,6 +1,5 @@ """URL conversion for paths containing arXiv groups or archives.""" -import re from typing import List, Optional from arxiv import taxonomy from werkzeug.routing import BaseConverter, ValidationError @@ -12,13 +11,13 @@ class ArchiveConverter(BaseConverter): def to_python(self, value: str) -> Optional[List[str]]: """Parse URL path part to Python rep (str).""" valid_archives = [] - for archive in 
value.split(','): + for archive in value.split(","): if archive not in taxonomy.ARCHIVES: continue # Support old archives. if archive in taxonomy.ARCHIVES_SUBSUMED: cat = taxonomy.CATEGORIES[taxonomy.ARCHIVES_SUBSUMED[archive]] - archive = cat['in_archive'] + archive = cat["in_archive"] valid_archives.append(archive) if not valid_archives: raise ValidationError() diff --git a/search/domain/__init__.py b/search/domain/__init__.py index f3b0b9ab..9d2ebe65 100644 --- a/search/domain/__init__.py +++ b/search/domain/__init__.py @@ -8,8 +8,71 @@ intelligibility of the codebase. """ +__all__ = [ + # base + "asdict", + "DocMeta", + "Fulltext", + "DateRange", + "Classification", + "ClassificationList", + "Operator", + "Field", + "Term", + "Phrase", + "Phrase", + "SortDirection", + "SortBy", + "SortOrder", + "Query", + "SimpleQuery", + # advanced + "FieldedSearchTerm", + "FieldedSearchList", + "AdvancedQuery", + # api + "APIQuery", + # classic api + "ClassicAPIQuery", + "ClassicSearchResponseData", + # documenhts + "Error", + "Document", + "DocumentSet", + "document_set_from_documents", +] + # pylint: disable=wildcard-import -from .base import * -from .advanced import * -from .api import * -from .documents import * +from search.domain.base import ( + asdict, + DocMeta, + Fulltext, + DateRange, + Classification, + ClassificationList, + Operator, + Field, + Term, + Phrase, + SortDirection, + SortBy, + SortOrder, + Query, + SimpleQuery, +) +from search.domain.advanced import ( + FieldedSearchTerm, + FieldedSearchList, + AdvancedQuery, +) +from search.domain.api import APIQuery +from search.domain.classic_api import ( + ClassicAPIQuery, + ClassicSearchResponseData, +) +from search.domain.documents import ( + Error, + Document, + DocumentSet, + document_set_from_documents, +) diff --git a/search/domain/advanced.py b/search/domain/advanced.py index c15b4db5..19e17dd9 100644 --- a/search/domain/advanced.py +++ b/search/domain/advanced.py @@ -1,22 +1,22 @@ """Represents fielded 
search terms, with multiple operators.""" -from .base import DateRange, Query, ClassificationList - +from typing import Optional from dataclasses import dataclass, field -from typing import NamedTuple, Optional + +from search.domain.base import DateRange, Query, ClassificationList @dataclass class FieldedSearchTerm: """Represents a fielded search term.""" - operator: str + operator: Optional[str] field: str term: str def __str__(self) -> str: """Build a string representation, for use in rendering.""" - return f'{self.operator} {self.field}={self.term}' + return f"{self.operator} {self.field}={self.term}" class FieldedSearchList(list): @@ -24,7 +24,7 @@ class FieldedSearchList(list): def __str__(self) -> str: """Build a string representation, for use in rendering.""" - return '; '.join([str(item) for item in self]) + return "; ".join([str(item) for item in self]) @dataclass @@ -36,20 +36,20 @@ class AdvancedQuery(Query): """ SUPPORTED_FIELDS = [ - ('title', 'Title'), - ('author', 'Author(s)'), - ('abstract', 'Abstract'), - ('comments', 'Comments'), - ('journal_ref', 'Journal reference'), - ('acm_class', 'ACM classification'), - ('msc_class', 'MSC classification'), - ('report_num', 'Report number'), - ('paper_id', 'arXiv identifier'), - ('cross_list_category', 'Cross-list category'), - ('doi', 'DOI'), - ('orcid', 'ORCID'), - ('author_id', 'arXiv author ID'), - ('all', 'All fields') + ("title", "Title"), + ("author", "Author(s)"), + ("abstract", "Abstract"), + ("comments", "Comments"), + ("journal_ref", "Journal reference"), + ("acm_class", "ACM classification"), + ("msc_class", "MSC classification"), + ("report_num", "Report number"), + ("paper_id", "arXiv identifier"), + ("cross_list_category", "Cross-list category"), + ("doi", "DOI"), + ("orcid", "ORCID"), + ("author_id", "arXiv author ID"), + ("all", "All fields"), ] date_range: Optional[DateRange] = None diff --git a/search/domain/api.py b/search/domain/api.py index 1d7604c9..035a20e5 100644 --- 
a/search/domain/api.py +++ b/search/domain/api.py @@ -1,20 +1,20 @@ """API-specific domain classes.""" -from .base import DateRange, Query, ClassificationList, Classification, List -from .advanced import FieldedSearchList, FieldedSearchTerm - from dataclasses import dataclass, field -from typing import NamedTuple, Optional, Tuple +from typing import Optional, Tuple + +from search.domain.advanced import FieldedSearchList +from search.domain.base import DateRange, Query, Classification, List def get_default_extra_fields() -> List[str]: """These are the default extra fields.""" - return ['title'] + return ["title"] def get_required_fields() -> List[str]: """These fields should always be included.""" - return ['paper_id', 'paper_id_v', 'version', 'href', 'canonical'] + return ["paper_id", "paper_id_v", "version", "href", "canonical"] @dataclass @@ -26,8 +26,9 @@ class APIQuery(Query): """ date_range: Optional[DateRange] = None - primary_classification: Tuple[Classification, ...] = \ - field(default_factory=tuple) + primary_classification: Tuple[Classification, ...] 
= field( + default_factory=tuple + ) """Limit results to a specific primary classification.""" secondary_classification: List[Tuple[Classification, ...]] = field( default_factory=list diff --git a/search/domain/base.py b/search/domain/base.py index d2feec98..147cd98f 100644 --- a/search/domain/base.py +++ b/search/domain/base.py @@ -1,21 +1,17 @@ """Base domain classes for search service.""" -from typing import Any, Optional, List, Dict -from datetime import datetime, date -from operator import attrgetter -from pytz import timezone -import re -from mypy_extensions import TypedDict - -from arxiv import taxonomy +from enum import Enum +from datetime import datetime +from typing import Any, Optional, List, Dict, Union, Tuple +from dataclasses import dataclass, field, asdict as _asdict -from dataclasses import dataclass, field -from dataclasses import asdict as _asdict +from mypy_extensions import TypedDict -EASTERN = timezone('US/Eastern') +from search import consts -def asdict(obj: Any) -> dict: +# FIXME: Return type. 
+def asdict(obj: Any) -> Dict[Any, Any]: """Coerce a dataclass object to a dict.""" return {key: value for key, value in _asdict(obj).items()} @@ -40,8 +36,9 @@ class DocMeta: is_withdrawn: bool = field(default=False) license: Dict[str, str] = field(default_factory=dict) primary_classification: Dict[str, str] = field(default_factory=dict) - secondary_classification: List[Dict[str, str]] = \ - field(default_factory=list) + secondary_classification: List[Dict[str, str]] = field( + default_factory=list + ) title: str = field(default_factory=str) title_utf8: str = field(default_factory=str) source: Dict[str, Any] = field(default_factory=dict) @@ -77,27 +74,27 @@ class Fulltext: class DateRange: """Represents an open or closed date range, for use in :class:`.Query`.""" - start_date: datetime = datetime(1990, 1, 1, tzinfo=EASTERN) + start_date: datetime = datetime(1990, 1, 1, tzinfo=consts.EASTERN) """The day/time on which the range begins.""" - end_date: datetime = datetime.now(tz=EASTERN) + end_date: datetime = datetime.now(tz=consts.EASTERN) """The day/time at (just before) which the range ends.""" - SUBMITTED_ORIGINAL = 'submitted_date_first' - SUBMITTED_CURRENT = 'submitted_date' - ANNOUNCED = 'announced_date_first' + SUBMITTED_ORIGINAL = "submitted_date_first" + SUBMITTED_CURRENT = "submitted_date" + ANNOUNCED = "announced_date_first" date_type: str = field(default=SUBMITTED_CURRENT) """The date associated with the paper that should be queried.""" def __str__(self) -> str: """Build a string representation, for use in rendering.""" - _str = '' + _str = "" if self.start_date: - start_date = self.start_date.strftime('%Y-%m-%d') - _str += f'from {start_date} ' + start_date = self.start_date.strftime("%Y-%m-%d") + _str += f"from {start_date} " if self.end_date: - end_date = self.end_date.strftime('%Y-%m-%d') - _str += f'to {end_date}' + end_date = self.end_date.strftime("%Y-%m-%d") + _str += f"to {end_date}" return _str @@ -123,7 +120,154 @@ class 
ClassificationList(list): def __str__(self) -> str: """Build a string representation, for use in rendering.""" - return ', '.join([str(item) for item in self]) + return ", ".join([str(item) for item in self]) + + +class Operator(str, Enum): + """Supported boolean operators.""" + + AND = "AND" + OR = "OR" + ANDNOT = "ANDNOT" + + @classmethod + def is_valid_value(cls, value: str) -> bool: + """ + Determine whether or not ``value`` is a valid value of a member. + + Parameters + ---------- + value : str + + Returns + ------- + bool + + """ + try: + cls(value) + except ValueError: + return False + return True + + +class Field(str, Enum): + """Supported fields in the classic API.""" + + Title = "ti" + Author = "au" + Abstract = "abs" + Comment = "co" + JournalReference = "jr" + SubjectCategory = "cat" + ReportNumber = "rn" + Identifier = "id" + All = "all" + + +@dataclass +class Term: + """Class representing a Field and search term. + + Examples + -------- + .. code-block:: python + + term = Term(Field.Title, 'dark matter') + + """ + + field: Field + value: str = "" + + @property + def is_empty(self) -> bool: + return self.value.strip() == "" + + +# mypy doesn't yet support recursive type definitions. These ignores suppress +# the cyclic definition error, and forward-references to ``Phrase`` are +# are replaced with ``Any``. +Phrase = Union[ # type: ignore + Term, # type: ignore + Tuple[Operator, "Phrase"], # type: ignore + Tuple[Operator, "Phrase", "Phrase"], # type: ignore +] +""" +Recursive representation of a search query. + +Examples +-------- + +.. code-block:: python + + # Simple query without grouping/nesting. + phrase: Phrase = Term(Field.Author, 'copernicus') + + # Simple query with a unary operator without grouping/nesting. + phrase: Phrase = (Operator.ANDNOT, Term(Field.Author, 'copernicus')) + + # Simple conjunct query. 
+ phrase: Phrase = ( + Operator.AND, + Term(Field.Author, "del_maestro"), + Term(Field.Title, "checkerboard") + ) + + # Disjunct query with an unary not. + phrase = ( + Operator.OR, + Term(Field.Author, "del_maestro"), + ( + Operator.ANDNOT, + Term(Field.Title, "checkerboard") + ) + ) + + # Conjunct query with nested disjunct query. + phrase = ( + Operator.ANDNOT, + Term(Field.Author, "del_maestro"), + ( + Operator.OR, + Term(Field.Title, "checkerboard"), + Term(Field.Title, "Pyrochlore") + ) + ) +""" + + +class SortDirection(str, Enum): + ascending = "ascending" + descending = "descending" + + def to_es(self) -> Dict[str, str]: + return {"order": "asc" if self == SortDirection.ascending else "desc"} + + +class SortBy(str, Enum): + relevance = "relevance" + last_updated_date = "lastUpdatedDate" + submitted_date = "submittedDate" + + def to_es(self) -> str: + return { + SortBy.relevance: "_score", + SortBy.last_updated_date: "updated_date", + SortBy.submitted_date: "submitted_date", + }[self] + + +@dataclass +class SortOrder: + by: Optional[SortBy] = None + direction: SortDirection = SortDirection.descending + + def to_es(self) -> List[Dict[str, Dict[str, str]]]: + if self.by is None: + return consts.DEFAULT_SORT_ORDER + else: + return [{self.by.to_es(): self.direction.to_es()}] @dataclass @@ -134,25 +278,25 @@ class Query: """The maximum number of records that can be retrieved.""" SUPPORTED_FIELDS = [ - ('all', 'All fields'), - ('title', 'Title'), - ('author', 'Author(s)'), - ('abstract', 'Abstract'), - ('comments', 'Comments'), - ('journal_ref', 'Journal reference'), - ('acm_class', 'ACM classification'), - ('msc_class', 'MSC classification'), - ('report_num', 'Report number'), - ('paper_id', 'arXiv identifier'), - ('doi', 'DOI'), - ('orcid', 'ORCID'), - ('license', 'License (URI)'), - ('author_id', 'arXiv author ID'), - ('help', 'Help pages'), - ('full_text', 'Full text') + ("all", "All fields"), + ("title", "Title"), + ("author", "Author(s)"), + ("abstract", 
"Abstract"), + ("comments", "Comments"), + ("journal_ref", "Journal reference"), + ("acm_class", "ACM classification"), + ("msc_class", "MSC classification"), + ("report_num", "Report number"), + ("paper_id", "arXiv identifier"), + ("doi", "DOI"), + ("orcid", "ORCID"), + ("license", "License (URI)"), + ("author_id", "arXiv author ID"), + ("help", "Help pages"), + ("full_text", "Full text"), ] - order: Optional[str] = field(default=None) + order: Union[SortOrder, Optional[str]] = field(default=None) size: int = field(default=50) page_start: int = field(default=0) include_older_versions: bool = field(default=False) @@ -166,7 +310,7 @@ def page_end(self) -> int: @property def page(self) -> int: """Get the approximate page number.""" - return 1 + int(round(self.page_start/self.size)) + return 1 + int(round(self.page_start / self.size)) @dataclass diff --git a/search/domain/classic_api/__init__.py b/search/domain/classic_api/__init__.py new file mode 100644 index 00000000..8bd420cb --- /dev/null +++ b/search/domain/classic_api/__init__.py @@ -0,0 +1,8 @@ +"""Classic API Query object.""" + +__all__ = ["ClassicAPIQuery", "ClassicSearchResponseData"] + +from search.domain.classic_api.classic_query import ( + ClassicAPIQuery, + ClassicSearchResponseData, +) diff --git a/search/domain/classic_api/classic_query.py b/search/domain/classic_api/classic_query.py new file mode 100644 index 00000000..126d8db0 --- /dev/null +++ b/search/domain/classic_api/classic_query.py @@ -0,0 +1,45 @@ +"""Classic API Query object.""" + +from typing import Optional, List +from dataclasses import dataclass, field + +from search.domain.base import Query, Phrase +from search.domain.documents import DocumentSet +from search.domain.classic_api.query_parser import parse_classic_query + + +@dataclass +class ClassicAPIQuery(Query): + """Query supported by the classic arXiv API.""" + + search_query: Optional[str] = field(default=None) + phrase: Optional[Phrase] = field(default=None) + id_list: 
Optional[List[str]] = field(default=None) + size: int = field(default=10) + + def __post_init__(self) -> None: + """Ensure that either a phrase or id_list is set.""" + if self.search_query is not None: + self.phrase = parse_classic_query(self.search_query) + + if self.phrase is None and self.id_list is None: + raise ValueError( + "ClassicAPIQuery requires either a phrase, id_list, or both" + ) + + def to_query_string(self) -> str: + """Return a string representation of the API query.""" + return ( + f"search_query={self.search_query or ''}&" + f"id_list={','.join(self.id_list) if self.id_list else ''}&" + f"start={self.page_start}&" + f"max_results={self.size}" + ) + + +@dataclass +class ClassicSearchResponseData: + """Classic API search response data.""" + + results: Optional[DocumentSet] = None + query: Optional[ClassicAPIQuery] = None diff --git a/search/domain/classic_api/query_parser.py b/search/domain/classic_api/query_parser.py new file mode 100644 index 00000000..060d6785 --- /dev/null +++ b/search/domain/classic_api/query_parser.py @@ -0,0 +1,167 @@ +""" +Utility module for Classic API Query parsing. + +Uses lark-parser (EBNF parser) [1]. +[1]: https://github.com/lark-parser/lark/blob/master/README.md + + +The final, parsed query is a :class:`domain.api.Phrase`, which is a nested +set of Tuples:: + + >>> parse_classic_query("au:del_maestro AND ti:checkerboard") + ( + Operator.AND, + Term(field=Field.Author, value='del_maestro'), + Term(field=Field.Title, value='checkerboard') + ) + +See :module:`tests.test_query_parser` for more examples. +""" +import re +from typing import Tuple, List, Optional + +from lark import Lark, Transformer, Token +from werkzeug.exceptions import BadRequest + +from search.domain.base import Operator, Field, Term, Phrase + + +class QueryTransformer(Transformer): + """AST builder class. + + This class will be used to traverse the AST generated by the LARK parser + and transform it's tokens to our AST representation. 
+ + Classic query phrase can be either a:: + + - Term - just a single term. E.e: Term(Field.All, "electron") + - (Operator, Phrase) - Unary operation (only ANDNOT is allowed and it + represent unary negation. I.e: + (Operator.ANDNOT Term(Field.All, "electron")) + - (Operator, Phrase, Phrase) - Binary operation (AND, OR, ANDNOT). I.e: + ( + Operator.AND, + Term(Field.All, "electron"), + Term(Field.Author, "john") + ) + + And also any recursive representation of the following structure. + + """ + + def field(self, tokens: List[Token]) -> Field: + """Transform `all`, `au`...field identifiers to `Field` enum values.""" + print(tokens) + (f,) = tokens + return Field(str(f)) + + def search_string(self, tokens: List[Token]) -> str: + """Un-quote a search string and strips it of whitespace. + + This is the actual search string entered after the Field qualifier. + """ + (s,) = tokens + if s.startswith('"') and s.endswith('"'): + s = s[1:-1] + return s.strip() or "" + + def term(self, tokens: List[Token]) -> Term: + """Construct a Term combining a field and search string.""" + return Term(*tokens) + + def unary_operator(self, tokens: List[Token]) -> Operator: + """Transform unary operator string to Operator enum value.""" + (u,) = tokens + return Operator(str(u)) + + def unary_expression(self, tokens: List[Token]) -> Tuple[Operator, Phrase]: + """Create a unary operation tuple.""" + return tokens[0], tokens[1] + + def binary_operator(self, tokens: List[Token]) -> Operator: + """Transform binary operator string to Operator enum value.""" + (b,) = tokens + return Operator(str(b)) + + def binary_expression( + self, tokens: List[Token] + ) -> Tuple[Operator, Phrase, Phrase]: + """Create a binary operation tuple.""" + return tokens[1], tokens[0], tokens[2] + + def expression(self, tokens: List[Token]) -> Phrase: + """Do nothing, expression is already a singular value.""" + return tokens[0] # type:ignore + + def empty(self, tokens: List[Token]) -> Term: + """Return empty term for 
an empty string.""" + return Term(Field.All) + + def query(self, tokens: List[Token]) -> Phrase: + """Query is just an expression which is a singular value.""" + return tokens[0] # type:ignore + + +QUERY_PARSER = Lark( + fr""" + query : expression + | empty + + empty : // + + expression : term + | "(" expression ")" + | unary_expression + | binary_expression + + term : field ":" search_string + field : /{"|".join(Field)}/ + search_string : /[^\s\"\(\)]+/ | ESCAPED_STRING + + unary_operator : /ANDNOT/ + unary_expression : unary_operator expression + + binary_operator : /(ANDNOT|AND|OR)/ + binary_expression : expression binary_operator expression + + %import common.ESCAPED_STRING + %import common.WS + %ignore WS + + """, + start="query", + parser="lalr", + transformer=QueryTransformer(), +) + + +def parse_classic_query(query: str) -> Optional[Phrase]: + """Parse the classic query.""" + try: + return QUERY_PARSER.parse(query) # type:ignore + except Exception: + raise BadRequest(f"Invalid query string: '{query}'") + return + + +def phrase_to_query_string(phrase: Phrase, depth: int = 0) -> Optional[str]: + """Convert a Phrase to a query string.""" + if isinstance(phrase, Term): + return ( + f"{phrase.field}:{phrase.value}" + if re.search(r"\s", phrase.value) is None + else f'{phrase.field}:"{phrase.value}"' + ) + elif len(phrase) == 2: + unary_op, exp = phrase[:2] + value = f"{unary_op.value} {phrase_to_query_string(exp, depth+1)}" + return f"({value})" if depth != 0 else value + elif len(phrase) == 3: + binary_op, exp1, exp2 = phrase[:3] # type:ignore + value = ( + f"{phrase_to_query_string(exp1, depth+1)} " + f"{binary_op.value} " + f"{phrase_to_query_string(exp2, depth+1)}" + ) + return f"({value})" if depth != 0 else value + return None diff --git a/search/domain/classic_api/tests/__init__.py b/search/domain/classic_api/tests/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/search/domain/classic_api/tests/test_classic_parser.py 
b/search/domain/classic_api/tests/test_classic_parser.py new file mode 100644 index 00000000..c451439e --- /dev/null +++ b/search/domain/classic_api/tests/test_classic_parser.py @@ -0,0 +1,249 @@ +# type: ignore +"""Test cases for the classic parser.""" +from typing import List +from dataclasses import dataclass + +from search.domain import Phrase, Field, Operator, Term +from search.domain.classic_api.query_parser import ( + parse_classic_query, + phrase_to_query_string, +) + +from werkzeug.exceptions import BadRequest +from unittest import TestCase + + +@dataclass +class Case: + message: str + query: str + phrase: Phrase = None + + +TEST_PARSE_OK_CASES: List[Case] = [ + Case(message="Empty query.", query="", phrase=Term(Field.All, ""),), + Case( + message="Empty query full of spaces.", + query='au:" "', + phrase=Term(Field.Author, ""), + ), + Case( + message="Empty query in conjunct.", + query='all:electron AND au:""', + phrase=( + Operator.AND, + Term(Field.All, "electron"), + Term(Field.Author, ""), + ), + ), + Case( + message="Simple query without grouping/nesting.", + query="au:copernicus", + phrase=Term(Field.Author, "copernicus"), + ), + Case( + message="Simple query with quotations.", + query='ti:"dark matter"', + phrase=Term(Field.Title, "dark matter"), + ), + Case( + message="Simple query with quotations and extra spacing.", + query='ti:" dark matter "', + phrase=Term(Field.Title, "dark matter"), + ), + Case( + message="Simple conjunct query.", + query="au:del_maestro AND ti:checkerboard", + phrase=( + Operator.AND, + Term(Field.Author, "del_maestro"), + Term(Field.Title, "checkerboard"), + ), + ), + Case( + message="Simple conjunct query with quoted field.", + query='au:del_maestro AND ti:"dark matter"', + phrase=( + Operator.AND, + Term(Field.Author, "del_maestro"), + Term(Field.Title, "dark matter"), + ), + ), + Case( + message="Simple conjunct query with quoted field and spacing.", + query='au:del_maestro AND ti:" dark matter "', + phrase=( + 
Operator.AND, + Term(Field.Author, "del_maestro"), + Term(Field.Title, "dark matter"), + ), + ), + Case( + message="Disjunct query with an unary not.", + query="au:del_maestro OR (ANDNOT ti:checkerboard)", + phrase=( + Operator.OR, + Term(Field.Author, "del_maestro"), + (Operator.ANDNOT, Term(Field.Title, "checkerboard")), + ), + ), + Case( + message="Conjunct query with nested disjunct query.", + query="au:del_maestro ANDNOT (ti:checkerboard OR ti:Pyrochlore)", + phrase=( + Operator.ANDNOT, + Term(Field.Author, "del_maestro"), + ( + Operator.OR, + Term(Field.Title, "checkerboard"), + Term(Field.Title, "Pyrochlore"), + ), + ), + ), + Case( + message="Conjunct query with nested disjunct query.", + query=( + "((au:del_maestro OR au:bob) " + "ANDNOT (ti:checkerboard OR ti:Pyrochlore))" + ), + phrase=( + Operator.ANDNOT, + ( + Operator.OR, + Term(Field.Author, "del_maestro"), + Term(Field.Author, "bob"), + ), + ( + Operator.OR, + Term(Field.Title, "checkerboard"), + Term(Field.Title, "Pyrochlore"), + ), + ), + ), + Case( + message="Conjunct ANDNOT query with nested disjunct query.", + query="(ti:checkerboard OR ti:Pyrochlore) ANDNOT au:del_maestro", + phrase=( + Operator.ANDNOT, + ( + Operator.OR, + Term(Field.Title, "checkerboard"), + Term(Field.Title, "Pyrochlore"), + ), + Term(Field.Author, "del_maestro"), + ), + ), + Case( + message="Conjunct AND query with nested disjunct query.", + query=( + "(ti:checkerboard OR ti:Pyrochlore) AND " + "(au:del_maestro OR au:hawking)" + ), + phrase=( + Operator.AND, + ( + Operator.OR, + Term(Field.Title, "checkerboard"), + Term(Field.Title, "Pyrochlore"), + ), + ( + Operator.OR, + Term(Field.Author, "del_maestro"), + Term(Field.Author, "hawking"), + ), + ), + ), +] + +TEST_PARSE_ERROR_CASES: List[Case] = [ + Case( + message="Error case with two consecutive operators.", + query="ti:a or and ti:b", + ), + Case( + message="Error case with two consecutive terms.", + query="ti:a and ti:b ti:c", + ), + Case( + message="Error case with a 
trailing operator.", + query="ti:a and ti:b and", + ), + Case( + message="Error case with a leading operator.", query="or ti:a and ti:b" + ), + Case(message="Testing unclosed quote.", query='ti:a and ti:"b'), + Case( + message="Testing query string with many problems.", + query='or ti:a and and ti:"b', + ), +] + +TEST_SERIALIZE_CASES: List[Case] = [ + Case( + message="Simple query serialization.", + query="au:copernicus", + phrase=Term(Field.Author, "copernicus"), + ), + Case( + message="Simple query with quotations.", + query='ti:"dark matter"', + phrase=Term(Field.Title, "dark matter"), + ), + Case( + message="Simple conjunct query.", + query="au:del_maestro AND ti:checkerboard", + phrase=( + Operator.AND, + Term(Field.Author, "del_maestro"), + Term(Field.Title, "checkerboard"), + ), + ), + Case( + message="Conjunct query with nested disjunct query.", + query=( + "(ti:checkerboard OR ti:Pyrochlore) AND " + "(au:del_maestro OR au:hawking)" + ), + phrase=( + Operator.AND, + ( + Operator.OR, + Term(Field.Title, "checkerboard"), + Term(Field.Title, "Pyrochlore"), + ), + ( + Operator.OR, + Term(Field.Author, "del_maestro"), + Term(Field.Author, "hawking"), + ), + ), + ), +] + + +class TestParsing(TestCase): + """Testing the classic parser.""" + + def test_all_valid_field_values(self): + for field in Field: + result = parse_classic_query(f"{field}:some_text") + self.assertEqual(result, Term(field, "some_text")) + + def test_parse_ok_test_cases(self): + for case in TEST_PARSE_OK_CASES: + self.assertEqual( + parse_classic_query(case.query), case.phrase, msg=case.message + ) + + def test_parse_error_test_cases(self): + for case in TEST_PARSE_ERROR_CASES: + with self.assertRaises(BadRequest, msg=case.message): + parse_classic_query(case.query) + + def test_serialize_cases(self): + for case in TEST_SERIALIZE_CASES: + self.assertEqual( + phrase_to_query_string(case.phrase), + case.query, + msg=case.message, + ) diff --git a/search/domain/documents.py 
b/search/domain/documents.py index 4ea7e990..1630b619 100644 --- a/search/domain/documents.py +++ b/search/domain/documents.py @@ -1,11 +1,36 @@ """Data structs for search documents.""" -from datetime import datetime, date +from datetime import datetime, date, timezone from typing import Optional, List, Dict, Any + +from dataclasses import dataclass, field from mypy_extensions import TypedDict +from search.domain.base import Classification, ClassificationList + + +# The class keyword ``total=False`` allows instances that do not contain all of +# the typed keys. See https://github.com/python/mypy/issues/2632 for +# background. + + +def utcnow() -> datetime: + """Return timezone aware current timestamp.""" + return datetime.utcnow().astimezone(timezone.utc) + + +@dataclass +class Error: + """Represents an error that happened in the system.""" + + id: str + error: str + link: str + author: str = "arXiv api core" + created: datetime = field(default_factory=utcnow) + -class Person(TypedDict): +class Person(TypedDict, total=False): """Represents an author, owner, or other person in metadata.""" full_name: str @@ -23,7 +48,7 @@ class Person(TypedDict): """Legacy arXiv author identifier.""" -class Document(TypedDict): +class Document(TypedDict, total=False): """A search document, representing an arXiv paper.""" submitted_date: datetime @@ -61,26 +86,68 @@ class Document(TypedDict): comments: str abs_categories: str formats: List[str] - primary_classification: Dict[str, str] - secondary_classification: List[Dict[str, str]] + primary_classification: Classification + secondary_classification: ClassificationList score: float - highlight: dict + # FIXME: Type. + highlight: Dict[Any, Any] """Contains highlighted versions of field values.""" - preview: dict + # FIXME: Type. + preview: Dict[Any, Any] """Contains truncations of field values for preview/snippet display.""" - match: dict + # FIXME: Type. 
+ match: Dict[Any, Any] """Contains fields that matched but lack highlighting.""" - truncated: dict + # FIXME: Type. + truncated: Dict[Any, Any] """Contains fields for which the preview is truncated.""" +class DocumentSetMetadata(TypedDict, total=False): + """Metadata for search results.""" + + current_page: int + end: int + max_pages: int + size: int + start: int + total_results: int + total_pages: int + query: List[Dict[str, Any]] + + class DocumentSet(TypedDict): """A set of search results retrieved from the search index.""" - metadata: Dict[str, Any] + metadata: DocumentSetMetadata results: List[Document] + + +def document_set_from_documents(documents: List[Document]) -> DocumentSet: + """Generate a DocumentSet with only a list of Documents. + + Generates the metadata automatically, which is an advantage over calling + DocumentSet(results=documents, metadata=dict()). + """ + return DocumentSet( + results=documents, metadata=metadata_from_documents(documents) + ) + + +def metadata_from_documents(documents: List[Document]) -> DocumentSetMetadata: + """Generate DocumentSet metadata from a list of documents.""" + metadata: DocumentSetMetadata = {} + metadata["size"] = len(documents) + metadata["end"] = len(documents) + metadata["total_results"] = len(documents) + metadata["start"] = 0 + metadata["max_pages"] = 1 + metadata["current_page"] = 1 + metadata["total_pages"] = 1 + + return metadata diff --git a/search/encode.py b/search/encode.py index 6a17ee3e..333f73df 100644 --- a/search/encode.py +++ b/search/encode.py @@ -20,4 +20,4 @@ def default(self, obj: Any) -> Union[str, List[Any]]: pass else: return list(iterable) - return JSONEncoder.default(self, obj) #type: ignore + return JSONEncoder.default(self, obj) # type: ignore diff --git a/search/errors.py b/search/errors.py new file mode 100644 index 00000000..b243a31c --- /dev/null +++ b/search/errors.py @@ -0,0 +1,31 @@ +"""Search error classes.""" + + +class SearchError(Exception): + """Generic search error.""" 
+ + def __init__(self, message: str): + """Initialize the error message.""" + self.message = message + + @property + def name(self) -> str: + """Error name.""" + return self.__class__.__name__ + + def __str__(self) -> str: + """Represent error as a string.""" + return f"{self.name}({self.message})" + + __repr__ = __str__ + + +class ValidationError(SearchError): + """Validation error.""" + + def __init__( + self, message: str, link: str = "http://arxiv.org/api/errors" + ): + """Initialize the validation error.""" + super().__init__(message=message) + self.link = link diff --git a/search/factory.py b/search/factory.py index f268e05e..f7021496 100644 --- a/search/factory.py +++ b/search/factory.py @@ -4,12 +4,11 @@ from flask import Flask from flask_s3 import FlaskS3 -from werkzeug.contrib.profiler import ProfilerMiddleware from arxiv.base import Base from arxiv.base.middleware import wrap, request_logs from arxiv.users import auth -from search.routes import ui, api +from search.routes import ui, api, classic_api from search.services import index from search.converters import ArchiveConverter from search.encode import ISO8601JSONEncoder @@ -21,13 +20,13 @@ def create_ui_web_app() -> Flask: """Initialize an instance of the search frontend UI web application.""" - logging.getLogger('boto').setLevel(logging.ERROR) - logging.getLogger('boto3').setLevel(logging.ERROR) - logging.getLogger('botocore').setLevel(logging.ERROR) + logging.getLogger("boto").setLevel(logging.ERROR) + logging.getLogger("boto3").setLevel(logging.ERROR) + logging.getLogger("botocore").setLevel(logging.ERROR) - app = Flask('search') - app.config.from_pyfile('config.py') # type: ignore - app.url_map.converters['archive'] = ArchiveConverter + app = Flask("search") + app.config.from_pyfile("config.py") # type: ignore + app.url_map.converters["archive"] = ArchiveConverter index.SearchSession.init_app(app) @@ -39,7 +38,9 @@ def create_ui_web_app() -> Flask: wrap(app, [request_logs.ClassicLogsMiddleware]) # 
app.config['PROFILE'] = True # app.config['DEBUG'] = True - # app.wsgi_app = ProfilerMiddleware(app.wsgi_app, restrictions=[100], sort_by=('cumtime', )) + # app.wsgi_app = ProfilerMiddleware( + # app.wsgi_app, restrictions=[100], sort_by=('cumtime', ) + # ) for filter_name, template_filter in filters.filters: app.template_filter(filter_name)(template_filter) @@ -49,13 +50,13 @@ def create_ui_web_app() -> Flask: def create_api_web_app() -> Flask: """Initialize an instance of the search frontend UI web application.""" - logging.getLogger('boto').setLevel(logging.ERROR) - logging.getLogger('boto3').setLevel(logging.ERROR) - logging.getLogger('botocore').setLevel(logging.ERROR) + logging.getLogger("boto").setLevel(logging.ERROR) + logging.getLogger("boto3").setLevel(logging.ERROR) + logging.getLogger("botocore").setLevel(logging.ERROR) - app = Flask('search') + app = Flask("search") app.json_encoder = ISO8601JSONEncoder - app.config.from_pyfile('config.py') # type: ignore + app.config.from_pyfile("config.py") # type: ignore index.SearchSession.init_app(app) @@ -63,8 +64,10 @@ def create_api_web_app() -> Flask: auth.Auth(app) app.register_blueprint(api.blueprint) - wrap(app, [request_logs.ClassicLogsMiddleware, - auth.middleware.AuthMiddleware]) + wrap( + app, + [request_logs.ClassicLogsMiddleware, auth.middleware.AuthMiddleware], + ) for error, handler in api.exceptions.get_handlers(): app.errorhandler(error)(handler) @@ -74,25 +77,26 @@ def create_api_web_app() -> Flask: def create_classic_api_web_app() -> Flask: """Initialize an instance of the search frontend UI web application.""" - logging.getLogger('boto').setLevel(logging.ERROR) - logging.getLogger('boto3').setLevel(logging.ERROR) - logging.getLogger('botocore').setLevel(logging.ERROR) + logging.getLogger("boto").setLevel(logging.ERROR) + logging.getLogger("boto3").setLevel(logging.ERROR) + logging.getLogger("botocore").setLevel(logging.ERROR) - app = Flask('search') + app = Flask("search") app.json_encoder = 
ISO8601JSONEncoder - app.config.from_pyfile('config.py') # type: ignore + app.config.from_pyfile("config.py") # type: ignore index.SearchSession.init_app(app) Base(app) auth.Auth(app) - app.register_blueprint(api.classic.blueprint) + app.register_blueprint(classic_api.blueprint) - wrap(app, [request_logs.ClassicLogsMiddleware, - auth.middleware.AuthMiddleware]) + wrap( + app, + [request_logs.ClassicLogsMiddleware, auth.middleware.AuthMiddleware], + ) - for error, handler in api.exceptions.get_handlers(): + for error, handler in classic_api.exceptions.get_handlers(): app.errorhandler(error)(handler) return app - diff --git a/search/filters.py b/search/filters.py index 27c51205..74975e61 100644 --- a/search/filters.py +++ b/search/filters.py @@ -1,55 +1,55 @@ """Template filters for :mod:`search`.""" -from typing import Dict, Callable from operator import attrgetter from arxiv import taxonomy -from .domain import Classification, Query +from search.domain import Classification, Query def display_classification(classification: Classification) -> str: """Generate a display-friendly label for a classification.""" - group = classification.get('group') - category = classification.get('category') - archive = classification.get('archive') + group = classification.get("group") + category = classification.get("category") + archive = classification.get("archive") parts = [] if group is not None: parts.append( - group.get('name', taxonomy.get_group_display(group["id"]))) + group.get("name", taxonomy.get_group_display(group["id"])) + ) if archive is not None: parts.append( - archive.get('name', taxonomy.get_archive_display(archive["id"]))) + archive.get("name", taxonomy.get_archive_display(archive["id"])) + ) if category is not None: parts.append( - category.get('name', - taxonomy.get_category_display(category["id"]))) - return '::'.join(parts) + category.get("name", taxonomy.get_category_display(category["id"])) + ) + return "::".join(parts) def category_name(classification: 
Classification) -> str: """Get the category display name for a classification.""" - category = classification.get('category') + category = classification.get("category") if not category: - raise ValueError('No category') - return category.get('name', - taxonomy.get_category_display(category["id"])) + raise ValueError("No category") + return category.get("name", taxonomy.get_category_display(category["id"])) def display_query(query: Query) -> str: """Build a display representation of a :class:`.Query`.""" _parts = [] - for attr in type(query).__dataclass_fields__.keys(): # type: ignore + for attr in type(query).__dataclass_fields__.keys(): # type: ignore value = attrgetter(attr)(query) if not value: continue - if attr == 'classification': - value = ', '.join([display_classification(v) for v in value]) - _parts.append('%s: %s' % (attr, value)) - return '; '.join(_parts) + if attr == "classification": + value = ", ".join([display_classification(v) for v in value]) + _parts.append("%s: %s" % (attr, value)) + return "; ".join(_parts) filters = [ - ('display_classification', display_classification), - ('category_name', category_name), - ('display_query', display_query) + ("display_classification", display_classification), + ("category_name", category_name), + ("display_query", display_query), ] diff --git a/search/process/tests.py b/search/process/tests.py index 46b703aa..270836a1 100644 --- a/search/process/tests.py +++ b/search/process/tests.py @@ -1,11 +1,10 @@ """Tests for :mod:`search.transform`.""" -from unittest import TestCase import json -import jsonschema -from datetime import datetime, date +from unittest import TestCase + +from search.domain import DocMeta from search.process import transform -from search.domain import Document, DocMeta class TestTransformMetdata(TestCase): @@ -13,321 +12,313 @@ class TestTransformMetdata(TestCase): def test_id(self): """Field ``id`` is populated from ``paper_id``.""" - meta = DocMeta(**{'paper_id': '1234.56789'}) + meta = 
DocMeta(**{"paper_id": "1234.56789"}) doc = transform.to_search_document(meta) - self.assertEqual(doc['id'], '1234.56789v1') + self.assertEqual(doc["id"], "1234.56789v1") def test_abstract(self): """Field ``abstract`` is populated from ``abstract_utf8``.""" - meta = DocMeta(**{ - 'paper_id': '1234.56789', - 'abstract_utf8': 'abstract!' - }) + meta = DocMeta( + **{"paper_id": "1234.56789", "abstract_utf8": "abstract!"} + ) doc = transform.to_search_document(meta) - self.assertEqual(doc['abstract'], 'abstract!') + self.assertEqual(doc["abstract"], "abstract!") def test_authors(self): """Field ``authors`` is populated from ``authors_parsed``.""" - meta = DocMeta(**{ - 'paper_id': '1234.56789', - 'authors_parsed': [ - { - 'first_name': 'B. Ivan', - 'last_name': 'Dole' - } - ] - }) + meta = DocMeta( + **{ + "paper_id": "1234.56789", + "authors_parsed": [ + {"first_name": "B. Ivan", "last_name": "Dole"} + ], + } + ) doc = transform.to_search_document(meta) - self.assertEqual(doc['authors'][0]['first_name'], 'B. Ivan') - self.assertEqual(doc['authors'][0]['last_name'], 'Dole') - self.assertEqual(doc['authors'][0]['full_name'], 'B. Ivan Dole', - "full_name should be generated from first_name and" - " last_name") - self.assertEqual(doc['authors'][0]['initials'], "B I", - "initials should be generated from first name") + self.assertEqual(doc["authors"][0]["first_name"], "B. Ivan") + self.assertEqual(doc["authors"][0]["last_name"], "Dole") + self.assertEqual( + doc["authors"][0]["full_name"], + "B. Ivan Dole", + "full_name should be generated from first_name and" " last_name", + ) + self.assertEqual( + doc["authors"][0]["initials"], + "B I", + "initials should be generated from first name", + ) def test_authors_freeform(self): """Field ``authors_freeform`` is populated from ``authors_utf8``.""" - meta = DocMeta(**{ - 'paper_id': '1234.56789', - 'authors_utf8': 'authors!' 
- }) + meta = DocMeta( + **{"paper_id": "1234.56789", "authors_utf8": "authors!"} + ) doc = transform.to_search_document(meta) - self.assertEqual(doc['authors_freeform'], 'authors!') + self.assertEqual(doc["authors_freeform"], "authors!") def test_owners(self): """Field ``owners`` is populated from ``author_owners``.""" - meta = DocMeta(**{ - 'paper_id': '1234.56789', - 'author_owners': [ - { - 'first_name': 'B. Ivan', - 'last_name': 'Dole' - } - ] - }) + meta = DocMeta( + **{ + "paper_id": "1234.56789", + "author_owners": [ + {"first_name": "B. Ivan", "last_name": "Dole"} + ], + } + ) doc = transform.to_search_document(meta) - self.assertEqual(doc['owners'][0]['first_name'], 'B. Ivan') - self.assertEqual(doc['owners'][0]['last_name'], 'Dole') - self.assertEqual(doc['owners'][0]['full_name'], 'B. Ivan Dole', - "full_name should be generated from first_name and" - " last_name") - self.assertEqual(doc['owners'][0]['initials'], "B I", - "initials should be generated from first name") + self.assertEqual(doc["owners"][0]["first_name"], "B. Ivan") + self.assertEqual(doc["owners"][0]["last_name"], "Dole") + self.assertEqual( + doc["owners"][0]["full_name"], + "B. 
Ivan Dole", + "full_name should be generated from first_name and" " last_name", + ) + self.assertEqual( + doc["owners"][0]["initials"], + "B I", + "initials should be generated from first name", + ) def test_submitted_date(self): """Field ``submitted_date`` is populated from ``submitted_date``.""" - meta = DocMeta(**{ - 'paper_id': '1234.56789', - 'submitted_date': '2007-04-25T16:06:50-0400' - }) + meta = DocMeta( + **{ + "paper_id": "1234.56789", + "submitted_date": "2007-04-25T16:06:50-0400", + } + ) doc = transform.to_search_document(meta) - self.assertEqual(doc['submitted_date'], '2007-04-25T16:06:50-0400') + self.assertEqual(doc["submitted_date"], "2007-04-25T16:06:50-0400") def test_submitted_date_all(self): """``submitted_date_all`` is populated from ``submitted_date_all``.""" - meta = DocMeta(**{ - 'paper_id': '1234.56789', - "submitted_date_all": [ - "2007-04-25T15:58:28-0400", "2007-04-25T16:06:50-0400" - ], - 'is_current': True - }) + meta = DocMeta( + **{ + "paper_id": "1234.56789", + "submitted_date_all": [ + "2007-04-25T15:58:28-0400", + "2007-04-25T16:06:50-0400", + ], + "is_current": True, + } + ) doc = transform.to_search_document(meta) - self.assertEqual(doc['submitted_date_all'][0], '2007-04-25T15:58:28-0400') - self.assertEqual(doc['submitted_date_all'][1], '2007-04-25T16:06:50-0400') - self.assertEqual(doc['submitted_date_first'], '2007-04-25T15:58:28-0400', - "Should be populated from submitted_date_all") - self.assertEqual(doc['submitted_date_latest'], "2007-04-25T16:06:50-0400", - "Should be populated from submitted_date_all") + self.assertEqual( + doc["submitted_date_all"][0], "2007-04-25T15:58:28-0400" + ) + self.assertEqual( + doc["submitted_date_all"][1], "2007-04-25T16:06:50-0400" + ) + self.assertEqual( + doc["submitted_date_first"], + "2007-04-25T15:58:28-0400", + "Should be populated from submitted_date_all", + ) + self.assertEqual( + doc["submitted_date_latest"], + "2007-04-25T16:06:50-0400", + "Should be populated from 
submitted_date_all", + ) def test_modified_date(self): """Field ``modified_date`` is populated from ``modified_date``.""" - meta = DocMeta(**{ - 'paper_id': '1234.56789', - 'modified_date': '2007-04-25T16:06:50-0400' - }) + meta = DocMeta( + **{ + "paper_id": "1234.56789", + "modified_date": "2007-04-25T16:06:50-0400", + } + ) doc = transform.to_search_document(meta) - self.assertEqual(doc['modified_date'], '2007-04-25T16:06:50-0400') + self.assertEqual(doc["modified_date"], "2007-04-25T16:06:50-0400") def test_updated_date(self): """Field ``updated_date`` is populated from ``updated_date``.""" - meta = DocMeta(**{ - 'paper_id': '1234.56789', - 'updated_date': '2007-04-25T16:06:50-0400' - }) + meta = DocMeta( + **{ + "paper_id": "1234.56789", + "updated_date": "2007-04-25T16:06:50-0400", + } + ) doc = transform.to_search_document(meta) - self.assertEqual(doc['updated_date'], '2007-04-25T16:06:50-0400') + self.assertEqual(doc["updated_date"], "2007-04-25T16:06:50-0400") def test_announced_date_first(self): """``announced_date_first`` populated from ``announced_date_first``.""" - meta = DocMeta(**{ - 'paper_id': '1234.56789', - 'announced_date_first': '2007-04' - }) + meta = DocMeta( + **{"paper_id": "1234.56789", "announced_date_first": "2007-04"} + ) doc = transform.to_search_document(meta) - self.assertEqual(doc['announced_date_first'], '2007-04') + self.assertEqual(doc["announced_date_first"], "2007-04") def test_is_withdrawn(self): """Field ``is_withdrawn`` is populated from ``is_withdrawn``.""" - meta = DocMeta(**{ - 'paper_id': '1234.56789', - 'is_withdrawn': False - }) + meta = DocMeta(**{"paper_id": "1234.56789", "is_withdrawn": False}) doc = transform.to_search_document(meta) - self.assertFalse(doc['is_withdrawn']) + self.assertFalse(doc["is_withdrawn"]) def test_license(self): """Field ``license`` is populated from ``license``.""" _license = { "label": "arXiv.org perpetual, non-exclusive license to" - " distribute this article", - "uri": 
"http://arxiv.org/licenses/nonexclusive-distrib/1.0/" + " distribute this article", + "uri": "http://arxiv.org/licenses/nonexclusive-distrib/1.0/", } - meta = DocMeta(**{ - 'paper_id': '1234.56789', - 'license': _license - }) + meta = DocMeta(**{"paper_id": "1234.56789", "license": _license}) doc = transform.to_search_document(meta) - self.assertEqual(doc['license']['uri'], _license['uri']) - self.assertEqual(doc['license']['label'], _license['label']) + self.assertEqual(doc["license"]["uri"], _license["uri"]) + self.assertEqual(doc["license"]["label"], _license["label"]) - meta = DocMeta(**{ - 'paper_id': '1234.56789', - 'license': {'uri': None, 'label': None} - }) + meta = DocMeta( + **{ + "paper_id": "1234.56789", + "license": {"uri": None, "label": None}, + } + ) doc = transform.to_search_document(meta) - self.assertEqual(doc['license']['uri'], transform.DEFAULT_LICENSE['uri'], - "The default license should be used") - self.assertEqual(doc['license']['label'], - transform.DEFAULT_LICENSE['label'], - "The default license should be used") + self.assertEqual( + doc["license"]["uri"], + transform.DEFAULT_LICENSE["uri"], + "The default license should be used", + ) + self.assertEqual( + doc["license"]["label"], + transform.DEFAULT_LICENSE["label"], + "The default license should be used", + ) def test_paper_version(self): """Field ``paper_id_v`` is populated from ``paper_id``.""" - meta = DocMeta(**{'paper_id': '1234.56789', 'version': 4}) + meta = DocMeta(**{"paper_id": "1234.56789", "version": 4}) doc = transform.to_search_document(meta) - self.assertEqual(doc['paper_id_v'], '1234.56789v4') + self.assertEqual(doc["paper_id_v"], "1234.56789v4") def test_primary_classification(self): """``primary_classification`` set from ``primary_classification``.""" - meta = DocMeta(**{ - 'paper_id': '1234.56789', - 'primary_classification': { - "group": { - "name": "Physics", - "id": "physics" - }, - "archive": { - "name": "Astrophysics", - "id": "astro-ph" + meta = DocMeta( + **{ 
+ "paper_id": "1234.56789", + "primary_classification": { + "group": {"name": "Physics", "id": "physics"}, + "archive": {"name": "Astrophysics", "id": "astro-ph"}, + "category": {"name": "Astrophysics", "id": "astro-ph"}, }, - "category": { - "name": "Astrophysics", - "id": "astro-ph" - } } - }) + ) doc = transform.to_search_document(meta) - self.assertEqual(doc['primary_classification'], - meta.primary_classification) + self.assertEqual( + doc["primary_classification"], meta.primary_classification + ) def test_secondary_classification(self): """``secondary_classification`` from ``secondary_classification``.""" - meta = DocMeta(**{ - 'paper_id': '1234.56789', - 'secondary_classification': [{ - "group": { - "name": "Physics", - "id": "physics" - }, - "archive": { - "name": "Astrophysics", - "id": "astro-ph" - }, - "category": { - "name": "Astrophysics", - "id": "astro-ph" - } - }] - }) + meta = DocMeta( + **{ + "paper_id": "1234.56789", + "secondary_classification": [ + { + "group": {"name": "Physics", "id": "physics"}, + "archive": {"name": "Astrophysics", "id": "astro-ph"}, + "category": {"name": "Astrophysics", "id": "astro-ph"}, + } + ], + } + ) doc = transform.to_search_document(meta) - self.assertEqual(doc['secondary_classification'], - meta.secondary_classification) + self.assertEqual( + doc["secondary_classification"], meta.secondary_classification + ) def test_title(self): """Field ``title`` is populated from ``title_utf8``.""" - meta = DocMeta(**{ - 'paper_id': '1234.56789', - 'title_utf8': 'foo title' - }) + meta = DocMeta(**{"paper_id": "1234.56789", "title_utf8": "foo title"}) doc = transform.to_search_document(meta) - self.assertEqual(doc['title'], 'foo title') + self.assertEqual(doc["title"], "foo title") def test_title_utf8(self): """Field ``title`` is populated from ``title_utf8``.""" - meta = DocMeta(**{ - 'paper_id': '1234.56789', - 'title_utf8': 'foö title' - }) + meta = DocMeta(**{"paper_id": "1234.56789", "title_utf8": "foö title"}) doc = 
transform.to_search_document(meta) - self.assertEqual(doc['title'], 'foö title') + self.assertEqual(doc["title"], "foö title") def test_source(self): """Field ``source`` is populated from ``source``.""" _source = {"flags": "1", "format": "pdf", "size_bytes": 1230119} - meta = DocMeta(**{ - 'paper_id': '1234.56789', - 'source': _source - }) + meta = DocMeta(**{"paper_id": "1234.56789", "source": _source}) doc = transform.to_search_document(meta) - self.assertEqual(doc['source'], _source) + self.assertEqual(doc["source"], _source) def test_version(self): """Field ``version`` is populated from ``version``.""" - meta = DocMeta(**{ - 'paper_id': '1234.56789', - 'version': 25 - }) + meta = DocMeta(**{"paper_id": "1234.56789", "version": 25}) doc = transform.to_search_document(meta) - self.assertEqual(doc['version'], 25) + self.assertEqual(doc["version"], 25) def test_submitter(self): """Field ``submitter`` is populated from ``submitter``.""" _submitter = { "email": "s.mitter@cornell.edu", "name": "Sub Mitter", - "name_utf8": "Süb Mitter" + "name_utf8": "Süb Mitter", } - meta = DocMeta(**{ - 'paper_id': '1234.56789', - 'submitter': _submitter - }) + meta = DocMeta(**{"paper_id": "1234.56789", "submitter": _submitter}) doc = transform.to_search_document(meta) - self.assertEqual(doc['submitter'], _submitter) + self.assertEqual(doc["submitter"], _submitter) def test_report_num(self): """Field ``report_num`` is populated from ``report_num``.""" - meta = DocMeta(**{ - 'paper_id': '1234.56789', - 'report_num': "Physica A, 245 (1997) 181" - }) + meta = DocMeta( + **{ + "paper_id": "1234.56789", + "report_num": "Physica A, 245 (1997) 181", + } + ) doc = transform.to_search_document(meta) - self.assertEqual(doc['report_num'], "Physica A, 245 (1997) 181") + self.assertEqual(doc["report_num"], "Physica A, 245 (1997) 181") def test_proxy(self): """Field ``proxy`` is populated from ``proxy``.""" - meta = DocMeta(**{ - 'paper_id': '1234.56789', - 'proxy': True - }) - doc = 
transform.to_search_document(meta) - self.assertTrue(doc['proxy']) - - def test_metadata_id(self): - """Field ``metadata_id`` is populated from ``metadata_id``.""" - meta = DocMeta(**{ - 'paper_id': '1234.56789', - 'metadata_id': '690776' - }) + meta = DocMeta(**{"paper_id": "1234.56789", "proxy": True}) doc = transform.to_search_document(meta) - self.assertEqual(doc['metadata_id'], '690776') + self.assertTrue(doc["proxy"]) def test_msc_class(self): """Field ``msc_class`` is populated from ``msc_class``.""" - meta = DocMeta(**{ - 'paper_id': '1234.56789', - 'msc_class': "03B70,68Q60" - }) + meta = DocMeta( + **{"paper_id": "1234.56789", "msc_class": "03B70,68Q60"} + ) doc = transform.to_search_document(meta) - self.assertEqual(doc['msc_class'], ["03B70", "68Q60"]) + self.assertEqual(doc["msc_class"], ["03B70", "68Q60"]) def test_acm_class(self): """Field ``acm_class`` is populated from ``acm_class``.""" - meta = DocMeta(**{ - 'paper_id': '1234.56789', - 'acm_class': "F.4.1; D.2.4" - }) + meta = DocMeta( + **{"paper_id": "1234.56789", "acm_class": "F.4.1; D.2.4"} + ) doc = transform.to_search_document(meta) - self.assertEqual(doc['acm_class'], ["F.4.1", "D.2.4"]) + self.assertEqual(doc["acm_class"], ["F.4.1", "D.2.4"]) def test_doi(self): """Field ``doi`` is populated from ``doi``.""" - meta = DocMeta(**{ - 'paper_id': '1234.56789', - 'doi': '10.1103/PhysRevD.76.104043' - }) + meta = DocMeta( + **{"paper_id": "1234.56789", "doi": "10.1103/PhysRevD.76.104043"} + ) + doc = transform.to_search_document(meta) + self.assertEqual(doc["doi"], ["10.1103/PhysRevD.76.104043"]) + + def test_metadata_id1(self): + """Field ``metadata_id`` is populated from ``metadata_id``.""" + meta = DocMeta(**{"paper_id": "1234.56789", "metadata_id": "690776"}) doc = transform.to_search_document(meta) - self.assertEqual(doc['doi'], ['10.1103/PhysRevD.76.104043']) + self.assertEqual(doc["metadata_id"], "690776") - def test_metadata_id(self): + def test_metadata_id2(self): """Field ``comments`` 
is populated from ``comments_utf8``.""" - meta = DocMeta(**{ - 'paper_id': '1234.56789', - 'comments_utf8': 'comments!' - }) + meta = DocMeta( + **{"paper_id": "1234.56789", "comments_utf8": "comments!"} + ) doc = transform.to_search_document(meta) - self.assertEqual(doc['comments'], 'comments!') + self.assertEqual(doc["comments"], "comments!") class TestTransformBulkDocmeta(TestCase): @@ -335,29 +326,29 @@ class TestTransformBulkDocmeta(TestCase): def test_transform(self): """All of the paper ID and version fields should be set correctly.""" - with open('tests/data/docmeta_bulk.json') as f: + with open("tests/data/docmeta_bulk.json") as f: data = json.load(f) docmeta = [DocMeta(**datum) for datum in data] documents = [transform.to_search_document(meta) for meta in docmeta] for doc in documents: - self.assertIsNotNone(doc['id']) - self.assertGreater(len(doc['id']), 0) - self.assertIsNotNone(doc['paper_id']) - self.assertGreater(len(doc['paper_id']), 0) - self.assertNotIn('v', doc['paper_id']) - self.assertIsNotNone(doc['paper_id_v']) - self.assertGreater(len(doc['paper_id_v']), 0) - self.assertIn('v', doc['paper_id_v']) - self.assertIsNotNone(doc['version']) - self.assertGreater(doc['version'], 0) - - if doc['version'] == 2: - self.assertEqual(doc['latest'], f"{doc['paper_id']}v2") - self.assertTrue(doc['is_current']) - self.assertEqual(doc['id'], doc['paper_id_v']) + self.assertIsNotNone(doc["id"]) + self.assertGreater(len(doc["id"]), 0) + self.assertIsNotNone(doc["paper_id"]) + self.assertGreater(len(doc["paper_id"]), 0) + self.assertNotIn("v", doc["paper_id"]) + self.assertIsNotNone(doc["paper_id_v"]) + self.assertGreater(len(doc["paper_id_v"]), 0) + self.assertIn("v", doc["paper_id_v"]) + self.assertIsNotNone(doc["version"]) + self.assertGreater(doc["version"], 0) + + if doc["version"] == 2: + self.assertEqual(doc["latest"], f"{doc['paper_id']}v2") + self.assertTrue(doc["is_current"]) + self.assertEqual(doc["id"], doc["paper_id_v"]) else: - 
self.assertFalse(doc['is_current']) - self.assertEqual(doc['id'], doc['paper_id_v']) - self.assertEqual(doc['latest_version'], 2) + self.assertFalse(doc["is_current"]) + self.assertEqual(doc["id"], doc["paper_id_v"]) + self.assertEqual(doc["latest_version"], 2) diff --git a/search/process/transform.py b/search/process/transform.py index c56ddc6e..e3a0fe50 100644 --- a/search/process/transform.py +++ b/search/process/transform.py @@ -6,49 +6,56 @@ from search.domain import Document, DocMeta, Fulltext DEFAULT_LICENSE = { - 'uri': 'http://arxiv.org/licenses/assumed-1991-2003/', - 'label': "Assumed arXiv.org perpetual, non-exclusive license to distribute" - " this article for submissions made before January 2004" + "uri": "http://arxiv.org/licenses/assumed-1991-2003/", + "label": "Assumed arXiv.org perpetual, non-exclusive license to distribute" + " this article for submissions made before January 2004", } -def _constructLicense(meta: DocMeta) -> dict: +def _constructLicense(meta: DocMeta) -> Dict[str, str]: """Get the document license, or set the default.""" - if not meta.license or not meta.license['uri']: + if not meta.license or not meta.license["uri"]: return DEFAULT_LICENSE return meta.license def _strip_punctuation(s: str) -> str: - return ''.join([c for c in s if c not in punctuation]) + return "".join([c for c in s if c not in punctuation]) def _constructPaperVersion(meta: DocMeta) -> str: """Generate a version-qualified paper ID.""" - return '%sv%i' % (meta.paper_id, meta.version) + return "%sv%i" % (meta.paper_id, meta.version) def _constructMSCClass(meta: DocMeta) -> Optional[list]: """Extract ``msc_class`` field as an array.""" if not meta.msc_class: return None - return [obj.strip() for obj in meta.msc_class.split(',')] + return [obj.strip() for obj in meta.msc_class.split(",")] def _constructACMClass(meta: DocMeta) -> Optional[list]: """Extract ``acm_class`` field as an array.""" if not meta.acm_class: return None - return [obj.strip() for obj in 
meta.acm_class.split(';')] + return [obj.strip() for obj in meta.acm_class.split(";")] -def _transformAuthor(author: dict) -> Optional[Dict]: - if (not author['last_name']) and (not author['first_name']): +def _transformAuthor(author: Dict[str, str]) -> Optional[Dict[str, str]]: + if (not author["last_name"]) and (not author["first_name"]): return None - author['full_name'] = re.sub(r'\s+', ' ', f"{author['first_name']} {author['last_name']}") - author['initials'] = " ".join([pt[0] for pt in author['first_name'].split() if pt]) - name_parts = author['first_name'].split() + author['last_name'].split() - author['full_name_initialized'] = ' '.join([part[0] for part in name_parts[:-1]] + [name_parts[-1]]) + author["full_name"] = re.sub( + r"\s+", " ", f"{author['first_name']} {author['last_name']}" + ) + author["initials"] = " ".join( + [pt[0] for pt in author["first_name"].split() if pt] + ) + name_parts = author["first_name"].split() + author["last_name"].split() + author["full_name_initialized"] = " ".join( + [part[0] for part in name_parts[:-1]] + [name_parts[-1]] + ) + return author @@ -96,8 +103,11 @@ def _constructDOI(meta: DocMeta) -> List[str]: ("authors_freeform", "authors_utf8", False), ("owners", _constructAuthorOwners, False), ("submitted_date", "submitted_date", True), - ("submitted_date_all", - lambda meta: meta.submitted_date_all if meta.is_current else None, True), + ( + "submitted_date_all", + lambda meta: meta.submitted_date_all if meta.is_current else None, + True, + ), ("submitted_date_first", _getFirstSubDate, True), ("submitted_date_latest", _getLastSubDate, True), ("modified_date", "modified_date", True), @@ -126,12 +136,13 @@ def _constructDOI(meta: DocMeta) -> List[str]: ("abs_categories", "abs_categories", False), ("formats", "formats", True), ("latest_version", "latest_version", True), - ("latest", "latest", True) + ("latest", "latest", True), ] -def to_search_document(metadata: DocMeta, - fulltext: Optional[Fulltext] = None) -> Document: 
+def to_search_document( + metadata: DocMeta, fulltext: Optional[Fulltext] = None +) -> Document: """ Transform metadata (and fulltext) into a valid search document. @@ -150,10 +161,11 @@ def to_search_document(metadata: DocMeta, """ data: Document = {} + for key, source, is_required in _transformations: if isinstance(source, str): value = getattr(metadata, source, None) - elif hasattr(source, '__call__'): + elif callable(source): value = source(metadata) if value is None and not is_required: continue @@ -161,4 +173,3 @@ def to_search_document(metadata: DocMeta, # if fulltext: # data['fulltext'] = fulltext.content return data - # See https://github.com/python/mypy/issues/3937 diff --git a/search/routes/api/__init__.py b/search/routes/api/__init__.py index 5fdc1158..1f929165 100644 --- a/search/routes/api/__init__.py +++ b/search/routes/api/__init__.py @@ -1,50 +1,44 @@ """Provides routing blueprint from the search API.""" -import json -from typing import Dict, Callable, Union, Any, Optional, List -from functools import wraps -from urllib.parse import urljoin, urlparse, parse_qs, urlencode, urlunparse - -from flask.json import jsonify -from flask import Blueprint, render_template, redirect, request, Response, \ - url_for -from werkzeug.urls import Href, url_encode, url_parse, url_unparse, url_encode -from werkzeug.datastructures import MultiDict, ImmutableMultiDict - -from arxiv import status -from arxiv.base import logging -from werkzeug.exceptions import InternalServerError -from search.controllers import api +__all__ = ["blueprint", "exceptions"] -from . 
import serialize, exceptions, classic +from flask import Blueprint, make_response, request, Response -from arxiv.users.auth.decorators import scoped +from arxiv.base import logging from arxiv.users.auth import scopes +from arxiv.users.auth.decorators import scoped +from search import serialize +from search.controllers import api +from search.routes.consts import JSON +from search.routes.api import exceptions logger = logging.getLogger(__name__) -blueprint = Blueprint('api', __name__, url_prefix='/') +blueprint = Blueprint("api", __name__, url_prefix="/") -ATOM_XML = "application/atom+xml" -JSON = "application/json" - -@blueprint.route('/', methods=['GET']) +@blueprint.route("/", methods=["GET"]) @scoped(required=scopes.READ_PUBLIC) def search() -> Response: """Main query endpoint.""" - logger.debug('Got query: %s', request.args) + logger.debug("Got query: %s", request.args) data, status_code, headers = api.search(request.args) # requested = request.accept_mimetypes.best_match([JSON, ATOM_XML]) # if requested == ATOM_XML: # return serialize.as_atom(data), status, headers - response_data = serialize.as_json(data['results'], query=data['query']) - return response_data, status_code, headers # type: ignore + response_data = serialize.as_json(data["results"], query=data["query"]) + + headers.update({"Content-type": JSON}) + response: Response = make_response(response_data, status_code, headers) + return response -@blueprint.route('v', methods=['GET']) +@blueprint.route("/v", methods=["GET"]) @scoped(required=scopes.READ_PUBLIC) def paper(paper_id: str, version: str) -> Response: """Document metadata endpoint.""" - data, status_code, headers = api.paper(f'{paper_id}v{version}') - return serialize.as_json(data['results']), status_code, headers # type: ignore + data, status_code, headers = api.paper(f"{paper_id}v{version}") + response_data = serialize.as_json(data["results"]) + headers.update({"Content-type": JSON}) + response: Response = make_response(response_data, 
status_code, headers) + return response diff --git a/search/routes/api/classic.py b/search/routes/api/classic.py deleted file mode 100644 index 69fccd4a..00000000 --- a/search/routes/api/classic.py +++ /dev/null @@ -1,45 +0,0 @@ -"""Provides the classic search API.""" - -from flask import Blueprint, render_template, redirect, request, Response, \ - url_for - -from arxiv.base import logging -from search.controllers import api - -from . import serialize, exceptions - -from arxiv.users.auth.decorators import scoped -from arxiv.users.auth import scopes - -logger = logging.getLogger(__name__) - -blueprint = Blueprint('api', __name__, url_prefix='/') - -ATOM_XML = "application/atom+xml" -JSON = "application/json" - - -@blueprint.route('/query', methods=['GET']) -@scoped(required=scopes.READ_PUBLIC) -def query() -> Response: - """Main query endpoint.""" - logger.debug('Got query: %s', request.args) - data, status_code, headers = api.classic_query(request.args) - # requested = request.accept_mimetypes.best_match([JSON, ATOM_XML]) - # if requested == ATOM_XML: - # return serialize.as_atom(data), status, headers - response = serialize.as_json(data['results'], query=data['query']) - response.status_code = status_code - response.headers.extend(headers) - return response - - -@blueprint.route('v', methods=['GET']) -@scoped(required=scopes.READ_PUBLIC) -def paper(paper_id: str, version: str) -> Response: - """Document metadata endpoint.""" - data, status_code, headers = api.paper(f'{paper_id}v{version}') - response = serialize.as_json(data['results']) - response.status_code = status_code - response.headers.extend(headers) - return response diff --git a/search/routes/api/exceptions.py b/search/routes/api/exceptions.py index e3cf397a..147afd54 100644 --- a/search/routes/api/exceptions.py +++ b/search/routes/api/exceptions.py @@ -6,14 +6,22 @@ """ from typing import Callable, List, Tuple - -from werkzeug.exceptions import NotFound, Forbidden, Unauthorized, \ - MethodNotAllowed, 
RequestEntityTooLarge, BadRequest, InternalServerError, \ - HTTPException +from http import HTTPStatus + +from werkzeug.exceptions import ( + NotFound, + Forbidden, + Unauthorized, + MethodNotAllowed, + RequestEntityTooLarge, + BadRequest, + InternalServerError, + HTTPException, +) from flask import make_response, Response, jsonify -from arxiv import status from arxiv.base import logging +from search.routes.consts import JSON logger = logging.getLogger(__name__) @@ -22,10 +30,12 @@ def handler(exception: type) -> Callable: """Generate a decorator to register a handler for an exception.""" + def deco(func: Callable) -> Callable: """Register a function as an exception handler.""" _handlers.append((exception, func)) return func + return deco @@ -37,73 +47,60 @@ def get_handlers() -> List[Tuple[type, Callable]]: ------- list List of (:class:`.HTTPException`, callable) tuples. + """ return _handlers +def respond(error: HTTPException, status: HTTPStatus) -> Response: + """Generate a JSON response.""" + return make_response( # type: ignore + jsonify({"code": error.code, "error": error.description}), + status, + {"Content-type": JSON}, + ) + + @handler(NotFound) def handle_not_found(error: NotFound) -> Response: """Render the base 404 error page.""" - rendered = jsonify({'code': error.code, 'error': error.description}) - response: Response = make_response(rendered) - response.status_code = status.HTTP_404_NOT_FOUND - return response + return respond(error, HTTPStatus.NOT_FOUND) @handler(Forbidden) def handle_forbidden(error: Forbidden) -> Response: """Render the base 403 error page.""" - rendered = jsonify({'code': error.code, 'error': error.description}) - response: Response = make_response(rendered) - response.status_code = status.HTTP_403_FORBIDDEN - return response + return respond(error, HTTPStatus.FORBIDDEN) @handler(Unauthorized) def handle_unauthorized(error: Unauthorized) -> Response: """Render the base 401 error page.""" - rendered = jsonify({'code': error.code, 
'error': error.description}) - response: Response = make_response(rendered) - response.status_code = status.HTTP_401_UNAUTHORIZED - return response + return respond(error, HTTPStatus.UNAUTHORIZED) @handler(MethodNotAllowed) def handle_method_not_allowed(error: MethodNotAllowed) -> Response: """Render the base 405 error page.""" - rendered = jsonify({'code': error.code, 'error': error.description}) - response: Response = make_response(rendered) - response.status_code = status.HTTP_405_METHOD_NOT_ALLOWED - return response + return respond(error, HTTPStatus.METHOD_NOT_ALLOWED) @handler(RequestEntityTooLarge) def handle_request_entity_too_large(error: RequestEntityTooLarge) -> Response: """Render the base 413 error page.""" - rendered = jsonify({'code': error.code, 'error': error.description}) - response: Response = make_response(rendered) - response.status_code = status.HTTP_413_REQUEST_ENTITY_TOO_LARGE - return response + return respond(error, HTTPStatus.REQUEST_ENTITY_TOO_LARGE) @handler(BadRequest) def handle_bad_request(error: BadRequest) -> Response: """Render the base 400 error page.""" - rendered = jsonify({'code': error.code, 'error': error.description}) - response: Response = make_response(rendered) - response.status_code = status.HTTP_400_BAD_REQUEST - return response + return respond(error, HTTPStatus.BAD_REQUEST) @handler(InternalServerError) def handle_internal_server_error(error: InternalServerError) -> Response: """Render the base 500 error page.""" - if isinstance(error, HTTPException): - rendered = jsonify({'code': error.code, 'error': error.description}) - else: - logger.error('Caught unhandled exception: %s', error) - rendered = jsonify({'code': status.HTTP_500_INTERNAL_SERVER_ERROR, - 'error': 'Unexpected error'}) - response: Response = make_response(rendered) - response.status_code = status.HTTP_500_INTERNAL_SERVER_ERROR - return response + if not isinstance(error, HTTPException): + logger.error("Caught unhandled exception: %s", error) + 
error.code = HTTPStatus.INTERNAL_SERVER_ERROR + return respond(error, HTTPStatus.INTERNAL_SERVER_ERROR) diff --git a/search/routes/api/serialize.py b/search/routes/api/serialize.py deleted file mode 100644 index be845865..00000000 --- a/search/routes/api/serialize.py +++ /dev/null @@ -1,152 +0,0 @@ -"""Serializers for API responses.""" - -from typing import Union, Optional -from datetime import datetime -from xml.etree import ElementTree as etree -from flask import jsonify, url_for, Response - -from arxiv import status -from search.domain import DocumentSet, Document, Classification, Person, \ - APIQuery - - -class BaseSerializer(object): - """Base class for API serializers.""" - - -class JSONSerializer(BaseSerializer): - """Serializes a :class:`DocumentSet` as JSON.""" - - @classmethod - def _transform_classification(cls, clsn: Classification) -> Optional[dict]: - category = clsn.get('category') - if category is None: - return None - return {'group': clsn.get('group'), - 'archive': clsn.get('archive'), - 'category': category} - - @classmethod - def _transform_format(cls, fmt: str, paper_id: str, version: int) -> dict: - return {"format": fmt, - "href": url_for(fmt, paper_id=paper_id, version=version)} - - @classmethod - def _transform_latest(cls, document: Document) -> Optional[dict]: - latest = document.get('latest') - if latest is None: - return None - return { - "paper_id": latest, - "href": url_for("api.paper", paper_id=document['paper_id'], - version=document.get('latest_version'), - _external=True), - "canonical": url_for("abs", paper_id=document['paper_id'], - version=document.get('latest_version')), - "version": document.get('latest_version') - } - - @classmethod - def _transform_license(cls, license: dict) -> Optional[dict]: - uri = license.get('uri') - if uri is None: - return None - return {'label': license.get('label', ''), 'href': uri} - - @classmethod - def transform_document(cls, doc: Document, - query: Optional[APIQuery] = None) -> dict: - 
"""Select a subset of :class:`Document` properties for public API.""" - # Only return fields that have been explicitly requested. - data = {key: value for key, value in doc.items() - if query is None or key in query.include_fields} - paper_id = doc['paper_id'] - version = doc['version'] - if 'submitted_date_first' in data: - data['submitted_date_first'] = \ - doc['submitted_date_first'].isoformat() - if 'announced_date_first' in data: - data['announced_date_first'] = \ - doc['announced_date_first'].isoformat() - if 'formats' in data: - data['formats'] = [cls._transform_format(fmt, paper_id, version) - for fmt in doc['formats']] - if 'license' in data: - data['license'] = cls._transform_license(doc['license']) - if 'latest' in data: - data['latest'] = cls._transform_latest(doc) - - data['href'] = url_for("api.paper", paper_id=paper_id, - version=version, _external=True) - data['canonical'] = url_for("abs", paper_id=paper_id, - version=version) - return data - - @classmethod - def serialize(cls, document_set: DocumentSet, - query: Optional[APIQuery] = None) -> Response: - """Generate JSON for a :class:`DocumentSet`.""" - serialized: Response = jsonify({ - 'results': [cls.transform_document(doc, query=query) - for doc in document_set['results']], - 'metadata': { - 'start': document_set['metadata'].get('start', ''), - 'end': document_set['metadata'].get('end', ''), - 'size': document_set['metadata'].get('size', ''), - 'total': document_set['metadata'].get('total', ''), - 'query': document_set['metadata'].get('query', []) - }, - }) - return serialized - - @classmethod - def serialize_document(cls, document: Document, - query: Optional[APIQuery] = None) -> Response: - """Generate JSON for a single :class:`Document`.""" - serialized: Response = jsonify( - cls.transform_document(document, query=query) - ) - return serialized - - -def as_json(document_or_set: Union[DocumentSet, Document], - query: Optional[APIQuery] = None) -> Response: - """Serialize a :class:`DocumentSet` 
as JSON.""" - if 'paper_id' in document_or_set: - return JSONSerializer.serialize_document(document_or_set, query=query) # type: ignore - return JSONSerializer.serialize(document_or_set, query=query) # type: ignore - - - -# TODO: implement me! -class AtomXMLSerializer(BaseSerializer): - """Atom XML serializer for paper metadata.""" - - ATOM = "http://www.w3.org/2005/Atom" - OPENSEARCH = "http://a9.com/-/spec/opensearch/1.1/" - ARXIV = "http://arxiv.org/schemas/atom" - NSMAP = { - None: ATOM, - "opensearch": OPENSEARCH, - "arxiv": ARXIV - } -# fields = { -# 'title': '{%s}title' % ATOM, -# 'id': '{%s}id' % ATOM, -# 'submitted_date': '{%s}published' % ATOM, -# 'modified_date': '{%s}updated' % ATOM, -# 'abstract': '{%s}summary' % ATOM, -# '' -# } -# -# def __init__(cls, *args, **kwargs) -> None: -# super(AtomXMLSerializer, cls).__init__(*args, **kwargs) -# cls._root = etree.Element('feed', nsmap=cls.NSMAP) -# -# def transform(cls): -# for document in cls.iter_documents(): -# -# -# -# def __repr__(cls) -> str: -# return etree.tostring(cls._root, pretty_print=True) diff --git a/search/routes/api/tests/test_api.py b/search/routes/api/tests/test_api.py index b62a8520..98338d5f 100644 --- a/search/routes/api/tests/test_api.py +++ b/search/routes/api/tests/test_api.py @@ -2,30 +2,30 @@ import os import json -from datetime import datetime +from http import HTTPStatus from unittest import TestCase, mock import jsonschema from arxiv.users import helpers, auth from arxiv.users.domain import Scope -from arxiv import status from search import factory -from search import domain +from search.tests import mocks +from search.domain.api import APIQuery, get_required_fields class TestAPISearchRequests(TestCase): """Requests against the main search API.""" - SCHEMA_PATH = os.path.abspath('schema/resources/DocumentSet.json') + SCHEMA_PATH = os.path.abspath("schema/resources/DocumentSet.json") def setUp(self): """Instantiate and configure an API app.""" - jwt_secret = 'foosecret' - 
os.environ['JWT_SECRET'] = jwt_secret + jwt_secret = "foosecret" + os.environ["JWT_SECRET"] = jwt_secret self.app = factory.create_api_web_app() - self.app.config['JWT_SECRET'] = jwt_secret + self.app.config["JWT_SECRET"] = jwt_secret self.client = self.app.test_client() with open(self.SCHEMA_PATH) as f: @@ -33,167 +33,109 @@ def setUp(self): def test_request_without_token(self): """No auth token is provided on the request.""" - response = self.client.get('/') - self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED) + response = self.client.get("/") + self.assertEqual(response.status_code, HTTPStatus.UNAUTHORIZED) def test_with_token_lacking_scope(self): """Client auth token lacks required public read scope.""" - token = helpers.generate_token('1234', 'foo@bar.com', 'foouser', - scope=[Scope('something', 'read')]) - response = self.client.get('/', headers={'Authorization': token}) - self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) + token = helpers.generate_token( + "1234", + "foo@bar.com", + "foouser", + scope=[Scope("something", "read")], + ) + response = self.client.get("/", headers={"Authorization": token}) + self.assertEqual(response.status_code, HTTPStatus.FORBIDDEN) - @mock.patch(f'{factory.__name__}.api.api') + @mock.patch(f"{factory.__name__}.api.api") def test_with_valid_token(self, mock_controller): """Client auth token has required public read scope.""" - document = dict( - submitted_date=datetime.now(), - submitted_date_first=datetime.now(), - announced_date_first=datetime.now(), - id='1234.5678', - abstract='very abstract', - authors=[ - dict(full_name='F. Bar', orcid='1234-5678-9012-3456') - ], - submitter=dict(full_name='S. 
Ubmitter', author_id='su_1'), - modified_date=datetime.now(), - updated_date=datetime.now(), - is_current=True, - is_withdrawn=False, - license={ - 'uri': 'http://foo.license/1', - 'label': 'Notalicense 5.4' - }, - paper_id='1234.5678', - paper_id_v='1234.5678v6', - title='tiiiitle', - source={ - 'flags': 'A', - 'format': 'pdftotex', - 'size_bytes': 2 - }, - version=6, - latest='1234.5678v6', - latest_version=6, - report_num='somenum1', - msc_class=['c1'], - acm_class=['z2'], - journal_ref='somejournal (1991): 2-34', - doi='10.123456/7890', - comments='very science', - abs_categories='astro-ph.CO foo.BR', - formats=['pdf', 'other'], - primary_classification=dict( - group={'id': 'foo', 'name': 'Foo Group'}, - archive={'id': 'foo', 'name': 'Foo Archive'}, - category={'id': 'foo.BR', 'name': 'Foo Category'}, - ), - secondary_classification=[ - dict( - group={'id': 'foo', 'name': 'Foo Group'}, - archive={'id': 'foo', 'name': 'Foo Archive'}, - category={'id': 'foo.BZ', 'name': 'Baz Category'}, - ) - ] - ) - docs = dict( - results=[document], - metadata={'start': 0, 'end': 1, 'size': 50, 'total': 1} + document = mocks.document() + docs = { + "results": [document], + "metadata": {"start": 0, "end": 1, "size": 50, "total": 1}, + } + r_data = {"results": docs, "query": APIQuery()} + mock_controller.search.return_value = r_data, HTTPStatus.OK, {} + token = helpers.generate_token( + "1234", "foo@bar.com", "foouser", scope=[auth.scopes.READ_PUBLIC] ) - r_data = {'results': docs, 'query': domain.APIQuery()} - mock_controller.search.return_value = r_data, status.HTTP_200_OK, {} - token = helpers.generate_token('1234', 'foo@bar.com', 'foouser', - scope=[auth.scopes.READ_PUBLIC]) - response = self.client.get('/', headers={'Authorization': token}) - self.assertEqual(response.status_code, status.HTTP_200_OK) + response = self.client.get("/", headers={"Authorization": token}) + self.assertEqual(response.status_code, HTTPStatus.OK) data = json.loads(response.data) res = 
jsonschema.RefResolver( - 'file://%s/' % os.path.abspath(os.path.dirname(self.SCHEMA_PATH)), - None + "file://%s/" % os.path.abspath(os.path.dirname(self.SCHEMA_PATH)), + None, + ) + self.assertIsNone( + jsonschema.validate(data, self.schema, resolver=res), + "Response content is valid per schema", ) - self.assertIsNone(jsonschema.validate(data, self.schema, resolver=res), - 'Response content is valid per schema') - for field in domain.api.get_required_fields(): - self.assertIn(field, data['results'][0]) + for field in get_required_fields(): + self.assertIn(field, data["results"][0]) - @mock.patch(f'{factory.__name__}.api.api') + @mock.patch(f"{factory.__name__}.api.api") def test_with_valid_token_limit_fields(self, mock_controller): """Client auth token has required public read scope.""" - document = dict( - submitted_date=datetime.now(), - submitted_date_first=datetime.now(), - announced_date_first=datetime.now(), - id='1234.5678', - abstract='very abstract', - authors=[ - dict(full_name='F. Bar', orcid='1234-5678-9012-3456') - ], - submitter=dict(full_name='S. 
Ubmitter', author_id='su_1'), - modified_date=datetime.now(), - updated_date=datetime.now(), - is_current=True, - is_withdrawn=False, - license={ - 'uri': 'http://foo.license/1', - 'label': 'Notalicense 5.4' - }, - paper_id='1234.5678', - paper_id_v='1234.5678v6', - title='tiiiitle', - source={ - 'flags': 'A', - 'format': 'pdftotex', - 'size_bytes': 2 - }, - version=6, - latest='1234.5678v6', - latest_version=6, - report_num='somenum1', - msc_class=['c1'], - acm_class=['z2'], - journal_ref='somejournal (1991): 2-34', - doi='10.123456/7890', - comments='very science', - abs_categories='astro-ph.CO foo.BR', - formats=['pdf', 'other'], - primary_classification=dict( - group={'id': 'foo', 'name': 'Foo Group'}, - archive={'id': 'foo', 'name': 'Foo Archive'}, - category={'id': 'foo.BR', 'name': 'Foo Category'}, - ), - secondary_classification=[ - dict( - group={'id': 'foo', 'name': 'Foo Group'}, - archive={'id': 'foo', 'name': 'Foo Archive'}, - category={'id': 'foo.BZ', 'name': 'Baz Category'}, - ) - ] - ) - docs = dict( - results=[document], - metadata={'start': 0, 'end': 1, 'size': 50, 'total': 1} + document = mocks.document() + docs = { + "results": [document], + "metadata": {"start": 0, "end": 1, "size": 50, "total": 1}, + } + + query = APIQuery(include_fields=["abstract", "license"]) + r_data = {"results": docs, "query": query} + mock_controller.search.return_value = r_data, HTTPStatus.OK, {} + token = helpers.generate_token( + "1234", "foo@bar.com", "foouser", scope=[auth.scopes.READ_PUBLIC] ) - - query = domain.APIQuery(include_fields=['abstract', 'license']) - r_data = {'results': docs, 'query': query} - mock_controller.search.return_value = r_data, status.HTTP_200_OK, {} - token = helpers.generate_token('1234', 'foo@bar.com', 'foouser', - scope=[auth.scopes.READ_PUBLIC]) - response = self.client.get('/', headers={'Authorization': token}) - self.assertEqual(response.status_code, status.HTTP_200_OK) + response = self.client.get("/", headers={"Authorization": 
token}) + self.assertEqual(response.status_code, HTTPStatus.OK) data = json.loads(response.data) res = jsonschema.RefResolver( - 'file://%s/' % os.path.abspath(os.path.dirname(self.SCHEMA_PATH)), - None + "file://%s/" % os.path.abspath(os.path.dirname(self.SCHEMA_PATH)), + None, + ) + self.assertIsNone( + jsonschema.validate(data, self.schema, resolver=res), + "Response content is valid per schema", ) - self.assertIsNone(jsonschema.validate(data, self.schema, resolver=res), - 'Response content is valid per schema') # for field in domain.api.get_required_fields(): self.assertEqual( - set(data['results'][0].keys()), - set(query.include_fields) + set(data["results"][0].keys()), set(query.include_fields) + ) + + @mock.patch(f"{factory.__name__}.api.api") + def test_paper_retrieval(self, mock_controller): + """Test single-paper retrieval.""" + document = mocks.document() + docs = { + "results": [document], + "metadata": {"start": 0, "end": 1, "size": 50, "total": 1}, + } + r_data = {"results": docs, "query": APIQuery()} + mock_controller.paper.return_value = r_data, HTTPStatus.OK, {} + token = helpers.generate_token( + "1234", "foo@bar.com", "foouser", scope=[auth.scopes.READ_PUBLIC] ) + response = self.client.get( + "/1234.56789v6", headers={"Authorization": token} + ) + self.assertEqual(response.status_code, HTTPStatus.OK) + + data = json.loads(response.data) + res = jsonschema.RefResolver( + "file://%s/" % os.path.abspath(os.path.dirname(self.SCHEMA_PATH)), + None, + ) + self.assertIsNone( + jsonschema.validate(data, self.schema, resolver=res), + "Response content is valid per schema", + ) + + for field in get_required_fields(): + self.assertIn(field, data["results"][0]) diff --git a/search/routes/api/tests/test_serialize.py b/search/routes/api/tests/test_serialize.py deleted file mode 100644 index c79a6a5d..00000000 --- a/search/routes/api/tests/test_serialize.py +++ /dev/null @@ -1,159 +0,0 @@ -"""Tests for serializers.""" - -import os -from unittest import 
TestCase, mock -from datetime import datetime -import json -import jsonschema -from .... import domain, encode -from .. import serialize - - -def mock_jsonify(o): - return json.dumps(o, cls=encode.ISO8601JSONEncoder) - - -class TestSerializeJSONDocument(TestCase): - """Serialize a single :class:`domain.Document` as JSON.""" - - SCHEMA_PATH = os.path.abspath('schema/resources/Document.json') - - def setUp(self): - with open(self.SCHEMA_PATH) as f: - self.schema = json.load(f) - - @mock.patch(f'{serialize.__name__}.url_for', lambda *a, **k: 'http://f/12') - @mock.patch(f'{serialize.__name__}.jsonify', mock_jsonify) - def test_to_json(self): - """Just your run-of-the-mill arXiv document generates valid JSON.""" - document = dict( - submitted_date=datetime.now(), - submitted_date_first=datetime.now(), - announced_date_first=datetime.now(), - id='1234.5678', - abstract='very abstract', - authors=[dict(full_name='F. Bar', orcid='1234-5678-9012-3456')], - submitter=dict(full_name='S. Ubmitter', author_id='su_1'), - modified_date=datetime.now(), - updated_date=datetime.now(), - is_current=True, - is_withdrawn=False, - license={ - 'uri': 'http://foo.license/1', - 'label': 'Notalicense 5.4' - }, - paper_id='1234.5678', - paper_id_v='1234.5678v6', - title='tiiiitle', - source={ - 'flags': 'A', - 'format': 'pdftotex', - 'size_bytes': 2 - }, - version=6, - latest='1234.5678v6', - latest_version=6, - report_num='somenum1', - msc_class=['c1'], - acm_class=['z2'], - journal_ref='somejournal (1991): 2-34', - doi='10.123456/7890', - comments='very science', - abs_categories='astro-ph.CO foo.BR', - formats=['pdf', 'other'], - primary_classification=dict( - group={'id': 'foo', 'name': 'Foo Group'}, - archive={'id': 'foo', 'name': 'Foo Archive'}, - category={'id': 'foo.BR', 'name': 'Foo Category'}, - ), - secondary_classification=[ - dict( - group={'id': 'foo', 'name': 'Foo Group'}, - archive={'id': 'foo', 'name': 'Foo Archive'}, - category={'id': 'foo.BZ', 'name': 'Baz Category'}, - ) 
- ] - ) - srlzd = serialize.as_json(document) - res = jsonschema.RefResolver( - 'file://%s/' % os.path.abspath(os.path.dirname(self.SCHEMA_PATH)), - None - ) - self.assertIsNone( - jsonschema.validate(json.loads(srlzd), self.schema, resolver=res) - ) - - -class TestSerializeJSONDocumentSet(TestCase): - """Serialize a :class:`domain.DocumentSet` as JSON.""" - - SCHEMA_PATH = os.path.abspath('schema/resources/DocumentSet.json') - - def setUp(self): - with open(self.SCHEMA_PATH) as f: - self.schema = json.load(f) - - @mock.patch(f'{serialize.__name__}.url_for', lambda *a, **k: 'http://f/12') - @mock.patch(f'{serialize.__name__}.jsonify', mock_jsonify) - def test_to_json(self): - """Just your run-of-the-mill arXiv document generates valid JSON.""" - document = dict( - submitted_date=datetime.now(), - submitted_date_first=datetime.now(), - announced_date_first=datetime.now(), - id='1234.5678', - abstract='very abstract', - authors=[ - dict(full_name='F. Bar', orcid='1234-5678-9012-3456') - ], - submitter=dict(full_name='S. 
Ubmitter', author_id='su_1'), - modified_date=datetime.now(), - updated_date=datetime.now(), - is_current=True, - is_withdrawn=False, - license={ - 'uri': 'http://foo.license/1', - 'label': 'Notalicense 5.4' - }, - paper_id='1234.5678', - paper_id_v='1234.5678v6', - title='tiiiitle', - source={ - 'flags': 'A', - 'format': 'pdftotex', - 'size_bytes': 2 - }, - version=6, - latest='1234.5678v6', - latest_version=6, - report_num='somenum1', - msc_class=['c1'], - acm_class=['z2'], - journal_ref='somejournal (1991): 2-34', - doi='10.123456/7890', - comments='very science', - abs_categories='astro-ph.CO foo.BR', - formats=['pdf', 'other'], - primary_classification=dict( - group={'id': 'foo', 'name': 'Foo Group'}, - archive={'id': 'foo', 'name': 'Foo Archive'}, - category={'id': 'foo.BR', 'name': 'Foo Category'}, - ), - secondary_classification=[ - dict( - group={'id': 'foo', 'name': 'Foo Group'}, - archive={'id': 'foo', 'name': 'Foo Archive'}, - category={'id': 'foo.BZ', 'name': 'Baz Category'}, - ) - ] - ) - meta = {'start': 0, 'size': 50, 'end': 50, 'total': 500202} - document_set = dict(results=[document], metadata=meta) - srlzd = serialize.as_json(document_set) - res = jsonschema.RefResolver( - 'file://%s/' % os.path.abspath(os.path.dirname(self.SCHEMA_PATH)), - None - ) - self.assertIsNone( - jsonschema.validate(json.loads(srlzd), self.schema, resolver=res) - ) diff --git a/search/routes/classic_api/__init__.py b/search/routes/classic_api/__init__.py new file mode 100644 index 00000000..ddb6d269 --- /dev/null +++ b/search/routes/classic_api/__init__.py @@ -0,0 +1,42 @@ +"""Provides the classic search API.""" + +__all__ = ["blueprint", "exceptions"] + +from flask import Blueprint, make_response, request, Response + +from arxiv.base import logging +from arxiv.users.auth import scopes +from arxiv.users.auth.decorators import scoped +from search import serialize +from search.controllers import classic_api +from search.routes.consts import ATOM_XML +from 
search.routes.classic_api import exceptions + +logger = logging.getLogger(__name__) + +blueprint = Blueprint("classic_api", __name__, url_prefix="/classic_api") + + +@blueprint.route("/query", methods=["GET"]) +@scoped(required=scopes.READ_PUBLIC) +def query() -> Response: + """Main query endpoint.""" + logger.debug("Got query: %s", request.args) + data, status_code, headers = classic_api.query(request.args) + response_data = serialize.as_atom( # type: ignore + data.results, query=data.query + ) # type: ignore + headers.update({"Content-type": ATOM_XML}) + response: Response = make_response(response_data, status_code, headers) + return response + + +@blueprint.route("/v", methods=["GET"]) +@scoped(required=scopes.READ_PUBLIC) +def paper(paper_id: str, version: str) -> Response: + """Document metadata endpoint.""" + data, status_code, headers = classic_api.paper(f"{paper_id}v{version}") + response_data = serialize.as_atom(data.results) # type:ignore + headers.update({"Content-type": ATOM_XML}) + response: Response = make_response(response_data, status_code, headers) + return response diff --git a/search/routes/classic_api/exceptions.py b/search/routes/classic_api/exceptions.py new file mode 100644 index 00000000..602e58db --- /dev/null +++ b/search/routes/classic_api/exceptions.py @@ -0,0 +1,120 @@ +""" +Exception handlers for classic arXiv API endpoints. + +.. todo:: This module belongs in :mod:`arxiv.base`. 
+ +""" +from http import HTTPStatus +from typing import Callable, List, Tuple +from werkzeug.exceptions import ( + NotFound, + Forbidden, + Unauthorized, + MethodNotAllowed, + RequestEntityTooLarge, + BadRequest, + InternalServerError, + HTTPException, +) +from flask import make_response, Response + +from arxiv.base import logging +from search.serialize import as_atom +from search.domain import Error +from search.routes.consts import ATOM_XML +from search.errors import ValidationError + + +logger = logging.getLogger(__name__) + +_handlers = [] + + +def handler(exception: type) -> Callable: + """Generate a decorator to register a handler for an exception.""" + + def deco(func: Callable) -> Callable: + """Register a function as an exception handler.""" + _handlers.append((exception, func)) + return func + + return deco + + +def get_handlers() -> List[Tuple[type, Callable]]: + """Get a list of registered exception handlers. + + Returns + ------- + list + List of (:class:`.HTTPException`, callable) tuples. 
+ + """ + return _handlers + + +def respond( + error_msg: str, + link: str = "http://arxiv.org/api/errors", + status: HTTPStatus = HTTPStatus.INTERNAL_SERVER_ERROR, +) -> Response: + """Generate an Atom response.""" + return make_response( # type: ignore + as_atom(Error(id=link, error=error_msg, link=link)), + status, + {"Content-type": ATOM_XML}, + ) + + +@handler(NotFound) +def handle_not_found(error: NotFound) -> Response: + """Render the base 404 error page.""" + return respond(error.description, status=HTTPStatus.NOT_FOUND) + + +@handler(Forbidden) +def handle_forbidden(error: Forbidden) -> Response: + """Render the base 403 error page.""" + return respond(error.description, status=HTTPStatus.FORBIDDEN) + + +@handler(Unauthorized) +def handle_unauthorized(error: Unauthorized) -> Response: + """Render the base 401 error page.""" + return respond(error.description, status=HTTPStatus.UNAUTHORIZED) + + +@handler(MethodNotAllowed) +def handle_method_not_allowed(error: MethodNotAllowed) -> Response: + """Render the base 405 error page.""" + return respond(error.description, status=HTTPStatus.METHOD_NOT_ALLOWED) + + +@handler(RequestEntityTooLarge) +def handle_request_entity_too_large(error: RequestEntityTooLarge) -> Response: + """Render the base 413 error page.""" + return respond( + error.description, status=HTTPStatus.REQUEST_ENTITY_TOO_LARGE + ) + + +@handler(BadRequest) +def handle_bad_request(error: BadRequest) -> Response: + """Render the base 400 error page.""" + return respond(error.description, status=HTTPStatus.BAD_REQUEST) + + +@handler(InternalServerError) +def handle_internal_server_error(error: InternalServerError) -> Response: + """Render the base 500 error page.""" + if not isinstance(error, HTTPException): + logger.error("Caught unhandled exception: %s", error) + return respond(error.description, status=HTTPStatus.INTERNAL_SERVER_ERROR) + + +@handler(ValidationError) +def handle_validation_error(error: ValidationError) -> Response: + """Render the 
base 400 error page.""" + return respond( + error_msg=error.message, link=error.link, status=HTTPStatus.BAD_REQUEST + ) diff --git a/search/routes/classic_api/tests/__init__.py b/search/routes/classic_api/tests/__init__.py new file mode 100644 index 00000000..4c9d94f7 --- /dev/null +++ b/search/routes/classic_api/tests/__init__.py @@ -0,0 +1 @@ +"""Tests for classic arXiv API routes.""" diff --git a/search/routes/classic_api/tests/test_classic.py b/search/routes/classic_api/tests/test_classic.py new file mode 100644 index 00000000..11c5f334 --- /dev/null +++ b/search/routes/classic_api/tests/test_classic.py @@ -0,0 +1,262 @@ +"""Tests for API routes.""" + +import os +from http import HTTPStatus +from xml.etree import ElementTree +from unittest import TestCase, mock + +from arxiv.users import helpers, auth +from arxiv.users.domain import Scope +from search import consts +from search import factory +from search import domain +from search.tests import mocks + + +class TestClassicAPISearchRequests(TestCase): + """Requests against the classic search API.""" + + def setUp(self): + """Instantiate and configure an API app.""" + jwt_secret = "foosecret" + os.environ["JWT_SECRET"] = jwt_secret + self.app = factory.create_classic_api_web_app() + self.app.config["JWT_SECRET"] = jwt_secret + self.client = self.app.test_client() + self.auth_header = { + "Authorization": helpers.generate_token( + "1234", + "foo@bar.com", + "foouser", + scope=[auth.scopes.READ_PUBLIC], + ) + } + + @staticmethod + def mock_classic_controller(controller, method="query", **kwargs): + docs: domain.DocumentSet = { + "results": [mocks.document()], + "metadata": {"start": 0, "end": 1, "size": 50, "total": 1}, + } + r_data = domain.ClassicSearchResponseData( + results=docs, + query=domain.ClassicAPIQuery( + **(kwargs or {"search_query": "all:electron"}) + ), + ) + getattr(controller, method).return_value = r_data, HTTPStatus.OK, {} + + def test_request_without_token(self): + """No auth token is provided 
on the request.""" + response = self.client.get( + "/classic_api/query?search_query=au:copernicus" + ) + self.assertEqual(response.status_code, HTTPStatus.UNAUTHORIZED) + + def test_with_token_lacking_scope(self): + """Client auth token lacks required public read scope.""" + token = helpers.generate_token( + "1234", + "foo@bar.com", + "foouser", + scope=[Scope("something", "read")], + ) + response = self.client.get( + "/classic_api/query?search_query=au:copernicus", + headers={"Authorization": token}, + ) + self.assertEqual(response.status_code, HTTPStatus.FORBIDDEN) + + @mock.patch(f"{factory.__name__}.classic_api.classic_api") + def test_with_valid_token(self, mock_controller): + """Client auth token has required public read scope.""" + self.mock_classic_controller(mock_controller, id_list=["1234.5678"]) + response = self.client.get( + "/classic_api/query?search_query=au:copernicus", + headers=self.auth_header, + ) + self.assertEqual(response.status_code, HTTPStatus.OK) + + @mock.patch(f"{factory.__name__}.classic_api.classic_api") + def test_paper_retrieval(self, mock_controller): + """Test single-paper retrieval.""" + self.mock_classic_controller(mock_controller, method="paper") + response = self.client.get( + "/classic_api/1234.56789v6", headers=self.auth_header + ) + self.assertEqual(response.status_code, HTTPStatus.OK) + + # Validation errors + def _fix_path(self, path): + return "/".join( + [ + "{{http://www.w3.org/2005/Atom}}{}".format(p) + for p in path.split("/") + ] + ) + + def _node(self, et: ElementTree, path: str): + """Return the node.""" + return et.find(self._fix_path(path)) + + def _text(self, et: ElementTree, path: str): + """Return the text content of the node.""" + return et.findtext(self._fix_path(path)) + + def check_validation_error(self, response, error, link): + et = ElementTree.fromstring(response.get_data(as_text=True)) + self.assertEqual(self._text(et, "entry/id"), link) + self.assertEqual(self._text(et, "entry/title"), "Error") + 
self.assertEqual(self._text(et, "entry/summary"), error) + link_attrib = self._node(et, "entry/link").attrib + self.assertEqual(link_attrib["href"], link) + + def test_start_not_a_number(self): + response = self.client.get( + "/classic_api/query?search_query=au:copernicus&start=non_number", + headers=self.auth_header, + ) + self.assertEqual(response.status_code, HTTPStatus.BAD_REQUEST) + self.check_validation_error( + response, + "start must be an integer", + "http://arxiv.org/api/errors#start_must_be_an_integer", + ) + + def test_start_negative(self): + response = self.client.get( + "/classic_api/query?search_query=au:copernicus&start=-1", + headers=self.auth_header, + ) + self.assertEqual(response.status_code, HTTPStatus.BAD_REQUEST) + self.check_validation_error( + response, + "start must be non-negative", + "http://arxiv.org/api/errors#start_must_be_non-negative", + ) + + def test_max_results_not_a_number(self): + response = self.client.get( + "/classic_api/query?search_query=au:copernicus&" + "max_results=non_number", + headers=self.auth_header, + ) + self.assertEqual(response.status_code, HTTPStatus.BAD_REQUEST) + self.check_validation_error( + response, + "max_results must be an integer", + "http://arxiv.org/api/errors#max_results_must_be_an_integer", + ) + + def test_max_results_negative(self): + response = self.client.get( + "/classic_api/query?search_query=au:copernicus&max_results=-1", + headers=self.auth_header, + ) + self.assertEqual(response.status_code, HTTPStatus.BAD_REQUEST) + self.check_validation_error( + response, + "max_results must be non-negative", + "http://arxiv.org/api/errors#max_results_must_be_non-negative", + ) + + @mock.patch(f"{factory.__name__}.classic_api.classic_api") + def test_sort_by_valid_values(self, mock_controller): + self.mock_classic_controller(mock_controller) + + for value in domain.SortBy: + response = self.client.get( + f"/classic_api/query?search_query=au:copernicus&" + f"sortBy={value}", + headers=self.auth_header, + 
) + self.assertEqual(response.status_code, HTTPStatus.OK) + + def test_sort_by_invalid_values(self): + response = self.client.get( + "/classic_api/query?search_query=au:copernicus&sortBy=foo", + headers=self.auth_header, + ) + self.assertEqual(response.status_code, HTTPStatus.BAD_REQUEST) + self.check_validation_error( + response, + f"sortBy must be in: {', '.join(domain.SortBy)}", + "https://arxiv.org/help/api/user-manual#sort", + ) + + @mock.patch(f"{factory.__name__}.classic_api.classic_api") + def test_sort_direction_valid_values(self, mock_controller): + self.mock_classic_controller(mock_controller) + + for value in domain.SortDirection: + response = self.client.get( + f"/classic_api/query?search_query=au:copernicus&" + f"sortOrder={value}", + headers=self.auth_header, + ) + self.assertEqual(response.status_code, HTTPStatus.OK) + + def test_sort_direction_invalid_values(self): + response = self.client.get( + "/classic_api/query?search_query=au:copernicus&sortOrder=foo", + headers=self.auth_header, + ) + self.assertEqual(response.status_code, HTTPStatus.BAD_REQUEST) + self.check_validation_error( + response, + f"sortOrder must be in: {', '.join(domain.SortDirection)}", + "https://arxiv.org/help/api/user-manual#sort", + ) + + def test_sort_order(self): + # Default + sort_order = domain.SortOrder(by=None) + self.assertEqual(sort_order.to_es(), consts.DEFAULT_SORT_ORDER) + + # Relevance/Score + sort_order = domain.SortOrder(by=domain.SortBy.relevance) + self.assertEqual(sort_order.to_es(), [{"_score": {"order": "desc"}}]) + sort_order = domain.SortOrder( + by=domain.SortBy.relevance, + direction=domain.SortDirection.ascending, + ) + self.assertEqual(sort_order.to_es(), [{"_score": {"order": "asc"}}]) + + # Submitted date/Publication date + sort_order = domain.SortOrder(by=domain.SortBy.submitted_date) + self.assertEqual( + sort_order.to_es(), [{"submitted_date": {"order": "desc"}}] + ) + sort_order = domain.SortOrder( + by=domain.SortBy.submitted_date, + 
direction=domain.SortDirection.ascending, + ) + self.assertEqual( + sort_order.to_es(), [{"submitted_date": {"order": "asc"}}] + ) + + # Last update date/Update date + sort_order = domain.SortOrder(by=domain.SortBy.last_updated_date) + self.assertEqual( + sort_order.to_es(), [{"updated_date": {"order": "desc"}}] + ) + sort_order = domain.SortOrder( + by=domain.SortBy.last_updated_date, + direction=domain.SortDirection.ascending, + ) + self.assertEqual( + sort_order.to_es(), [{"updated_date": {"order": "asc"}}] + ) + + def test_invalid_arxiv_id(self): + response = self.client.get( + "/classic_api/query?id_list=cond—mat/0709123", + headers=self.auth_header, + ) + self.assertEqual(response.status_code, HTTPStatus.BAD_REQUEST) + self.check_validation_error( + response, + "incorrect id format for cond—mat/0709123", + "http://arxiv.org/api/errors#" + "incorrect_id_format_for_cond—mat/0709123", + ) diff --git a/search/routes/consts.py b/search/routes/consts.py new file mode 100644 index 00000000..92862409 --- /dev/null +++ b/search/routes/consts.py @@ -0,0 +1,4 @@ +"""Serialization MIME type and charset constants.""" + +ATOM_XML = "application/atom+xml; charset=utf-8" +JSON = "application/json; charset=utf-8" diff --git a/search/routes/context_processors.py b/search/routes/context_processors.py index 60130434..6b4fef12 100644 --- a/search/routes/context_processors.py +++ b/search/routes/context_processors.py @@ -1,44 +1,45 @@ """Context processors for use in :mod:`.routes.ui`.""" -from typing import Dict, Callable, Optional, Any, List -from urllib.parse import urljoin, urlparse, parse_qs, urlunparse, \ - urlencode, ParseResult +from typing import Dict, Callable, List +from urllib.parse import urlparse, urlunparse, urlencode, ParseResult from flask import request, url_for -from arxiv import taxonomy -from ..domain import Classification - def url_for_page_builder() -> Dict[str, Callable]: """Add a page URL builder function to the template context.""" + def 
url_for_page(page: int, size: int) -> str: """Build an URL to for a search result page.""" rule = request.url_rule - parts = urlparse(url_for(rule.endpoint)) # type: ignore + parts = urlparse(url_for(rule.endpoint)) # type: ignore args = request.args.copy() - args['start'] = (page - 1) * size + args["start"] = (page - 1) * size parts = parts._replace(query=urlencode(list(args.items(multi=True)))) url: str = urlunparse(parts) return url - return dict(url_for_page=url_for_page) + + return {"url_for_page": url_for_page} def current_url_params_builder() -> Dict[str, Callable]: """Add a function that gets the GET params from the current URL.""" + def current_url_params() -> str: """Get the GET params from the current URL.""" params: str = urlencode(list(request.args.items(multi=True))) return params - return dict(current_url_params=current_url_params) + + return {"current_url_params": current_url_params} def current_url_sans_parameters_builder() -> Dict[str, Callable]: """Add a function to strip GET parameters from the current URL.""" + def current_url_sans_parameters(*params_to_remove: str) -> str: """Get the current URL with ``param`` removed from GET parameters.""" if request.url_rule is None: - raise ValueError('No matching URL rule for this request (oddly)') + raise ValueError("No matching URL rule for this request (oddly)") rule = request.url_rule parts = urlparse(url_for(rule.endpoint)) args = request.args.copy() @@ -47,67 +48,75 @@ def current_url_sans_parameters(*params_to_remove: str) -> str: parts = parts._replace(query=urlencode(list(args.items(multi=True)))) url: str = urlunparse(parts) return url - return dict(current_url_sans_parameters=current_url_sans_parameters) + + return {"current_url_sans_parameters": current_url_sans_parameters} def url_for_author_search_builder() -> Dict[str, Callable]: """Inject a function to build author name query URLs.""" - search_url = urlparse(url_for('ui.search')) + search_url = urlparse(url_for("ui.search")) archives_urls: 
Dict[str, ParseResult] = {} def get_archives_url(archives: List[str]) -> ParseResult: key = ",".join(archives) if key not in archives_urls: - archives_urls[key] = urlparse(url_for('ui.search', - archives=archives)) + archives_urls[key] = urlparse( + url_for("ui.search", archives=archives) + ) return archives_urls[key] def url_for_author_search(forename: str, surname: str) -> str: # If we are in an archive-specific context, we want to preserve that # when generating URLs for author queries in search results. - archives = request.view_args.get('archives') + archives = request.view_args.get("archives") parts = get_archives_url(archives) if archives else search_url if forename: fparts = [part[0] for part in forename.split()] - forename_part = ' '.join(fparts) - name = f'{surname}, {forename_part}' + forename_part = " ".join(fparts) + name = f"{surname}, {forename_part}" else: name = surname - parts = parts._replace(query=urlencode({'searchtype': 'author', - 'query': name})) + parts = parts._replace( + query=urlencode({"searchtype": "author", "query": name}) + ) url: str = urlunparse(parts) return url - return dict(url_for_author_search=url_for_author_search) + + return {"url_for_author_search": url_for_author_search} def url_with_params_builder() -> Dict[str, Callable]: """Inject a URL builder that handles GET parameters.""" + def url_with_params(name: str, values: dict, params: dict) -> str: """Build a URL for ``name`` with path ``values`` and GET ``params``.""" parts = urlparse(url_for(name, **values)) parts = parts._replace(query=urlencode(params)) url: str = urlunparse(parts) return url - return dict(url_with_params=url_with_params) + + return {"url_with_params": url_with_params} def is_current_builder() -> Dict[str, Callable]: """Inject a function to evaluate whether or not a result is current.""" + def is_current(result: dict) -> bool: """Determine whether the result is the current version.""" - if result['submitted_date_all'] is None: - return 
bool(result['is_current']) + if result["submitted_date_all"] is None: + return bool(result["is_current"]) try: return bool( - result['is_current'] - and result['version'] == len(result['submitted_date_all']) + result["is_current"] + and result["version"] == len(result["submitted_date_all"]) ) except Exception: return True return False - return dict(is_current=is_current) + + return {"is_current": is_current} context_processors: List[Callable[[], Dict[str, Callable]]] = [ @@ -116,5 +125,5 @@ def is_current(result: dict) -> bool: current_url_sans_parameters_builder, url_for_author_search_builder, url_with_params_builder, - is_current_builder + is_current_builder, ] diff --git a/search/routes/ui.py b/search/routes/ui.py index cdc651b6..ea558991 100644 --- a/search/routes/ui.py +++ b/search/routes/ui.py @@ -1,32 +1,36 @@ """Provides the main search user interfaces.""" import json -from typing import Dict, Callable, Union, Any, Optional, List -from functools import wraps -from functools import lru_cache as memoize - - -from flask.json import jsonify -from flask import Blueprint, render_template, redirect, request, Response, \ - url_for -from werkzeug.urls import Href -from werkzeug.datastructures import MultiDict, ImmutableMultiDict +from http import HTTPStatus +from typing import Union, Optional, List + +from flask import ( + Blueprint, + render_template, + redirect, + request, + Response, + make_response, +) +from werkzeug.datastructures import MultiDict +from werkzeug.exceptions import InternalServerError +from werkzeug.wrappers import Response as WerkzeugResponse -from arxiv import status from arxiv.base import logging -from werkzeug.exceptions import InternalServerError +from search.routes import context_processors from search.controllers import simple, advanced, health_check -from . 
import context_processors + +_Response = Union[Response, WerkzeugResponse] logger = logging.getLogger(__name__) -blueprint = Blueprint('ui', __name__, url_prefix='/') +blueprint = Blueprint("ui", __name__, url_prefix="/") -PARAMS_TO_PERSIST = ['order', 'size', 'abstracts', 'date-date_type'] +PARAMS_TO_PERSIST = ["order", "size", "abstracts", "date-date_type"] """These parameters should be persisted in a cookie.""" -PARAMS_COOKIE_NAME = 'arxiv-search-parameters' +PARAMS_COOKIE_NAME = "arxiv-search-parameters" """The name of the cookie to use to persist search parameters.""" @@ -43,7 +47,7 @@ def get_parameters_from_cookie() -> None: return # We need the request args to be mutable. - request.args = MultiDict(request.args.items(multi=True)) # type: ignore + request.args = MultiDict(request.args.items(multi=True)) # type: ignore data = json.loads(request.cookies[PARAMS_COOKIE_NAME]) for param in PARAMS_TO_PERSIST: # Don't clobber the user's explicit request. @@ -55,9 +59,12 @@ def get_parameters_from_cookie() -> None: @blueprint.after_request def set_parameters_in_cookie(response: Response) -> Response: """Set request parameters in the cookie, to use as future defaults.""" - if response.status_code == status.HTTP_200_OK: - data = {param: request.args[param] for param in PARAMS_TO_PERSIST - if param in request.args} + if response.status_code == HTTPStatus.OK: + data = { + param: request.args[param] + for param in PARAMS_TO_PERSIST + if param in request.args + } response.set_cookie(PARAMS_COOKIE_NAME, json.dumps(data)) return response @@ -71,56 +78,71 @@ def apply_response_headers(response: Response) -> Response: return response -@blueprint.route('', methods=['GET']) -@blueprint.route('/', methods=['GET']) -def search(archives: Optional[List[str]] = None) -> Union[str, Response]: +@blueprint.route("", methods=["GET"]) +@blueprint.route("/", methods=["GET"]) +def search(archives: Optional[List[str]] = None) -> _Response: """Simple search interface.""" - response, code, 
headers = simple.search(request.args, archives) + data, code, hdrs = simple.search(request.args, archives) logger.debug(f"controller returned code: {code}") - if code == status.HTTP_200_OK: - return render_template("search/search.html", pagetitle="Search", # type: ignore - archives=archives, **response) - elif (code == status.HTTP_301_MOVED_PERMANENTLY - or code == status.HTTP_303_SEE_OTHER): - return redirect(headers['Location'], code=code) # type: ignore - raise InternalServerError('Unexpected error') - - -@blueprint.route('advanced', methods=['GET']) -def advanced_search() -> Union[str, Response]: + if code == HTTPStatus.OK: + content = render_template( + "search/search.html", pagetitle="Search", archives=archives, **data + ) + response: Response = make_response(content) + for key, value in hdrs.items(): + response.headers[key] = value + return response + elif code == HTTPStatus.MOVED_PERMANENTLY or code == HTTPStatus.SEE_OTHER: + return redirect(hdrs["Location"], code=code) + raise InternalServerError("Unexpected error") + + +@blueprint.route("advanced", methods=["GET"]) +def advanced_search() -> _Response: """Advanced search interface.""" - response, code, headers = advanced.search(request.args) - return render_template( # type: ignore - "search/advanced_search.html", - pagetitle="Advanced Search", - **response + data, code, hdrs = advanced.search(request.args) + content = render_template( + "search/advanced_search.html", pagetitle="Advanced Search", **data ) + response: Response = make_response(content) + response.status_code = code + for key, value in hdrs.items(): + response.headers[key] = value + return response -@blueprint.route('advanced/', methods=['GET']) -def group_search(groups_or_archives: str) -> Union[str, Response]: +@blueprint.route("advanced/", methods=["GET"]) +def group_search(groups_or_archives: str) -> _Response: """ Short-cut for advanced search with group or archive pre-selected. 
Note that this only supports options supported in the advanced search interface. Anything else will result in a 404. """ - response, code, _ = advanced.group_search(request.args, groups_or_archives) - return render_template( # type: ignore - "search/advanced_search.html", - pagetitle="Advanced Search", - **response + data, code, hdrs = advanced.group_search(request.args, groups_or_archives) + content = render_template( + "search/advanced_search.html", pagetitle="Advanced Search", **data ) + response: Response = make_response(content) + response.status_code = code + for key, value in hdrs.items(): + response.headers[key] = value + return response -@blueprint.route('status', methods=['GET', 'HEAD']) -def service_status() -> Union[str, Response]: +@blueprint.route("status", methods=["GET", "HEAD"]) +def service_status() -> _Response: """ Health check endpoint for search. Exercises the search index connection with a real query. """ - return health_check() # type: ignore + content, code, hdrs = health_check() + response: Response = make_response(content) + response.status_code = code + for key, value in hdrs.items(): + response.headers[key] = value + return response # Register context processors. 
diff --git a/search/serialize/__init__.py b/search/serialize/__init__.py new file mode 100644 index 00000000..95609a83 --- /dev/null +++ b/search/serialize/__init__.py @@ -0,0 +1,5 @@ +"""Provides serialization functions for API responses.""" +__all__ = ["JSONSerializer", "as_json", "AtomXMLSerializer", "as_atom"] + +from search.serialize.json import JSONSerializer, as_json +from search.serialize.atom import AtomXMLSerializer, as_atom diff --git a/search/serialize/atom.py b/search/serialize/atom.py new file mode 100644 index 00000000..6c61d77c --- /dev/null +++ b/search/serialize/atom.py @@ -0,0 +1,264 @@ +"""Atom serialization for classic arXiv API.""" + +import base64 +import hashlib +from datetime import datetime, timezone +import dateutil +from typing import Union, Optional, Dict, Any + +from flask import url_for +from feedgen.feed import FeedGenerator + +from search.domain import ( + Error, + DocumentSet, + Document, + ClassicAPIQuery, + document_set_from_documents, +) +from search.serialize.atom_extensions import ( + ArXivExtension, + ArXivEntryExtension, + OpenSearchExtension, + ARXIV_NS, +) +from search.domain.classic_api.query_parser import phrase_to_query_string +from search.serialize.base import BaseSerializer + + +class DateTime(datetime): + """DateTime is a hack wrapper around datetime. + + Feedgen doesn't have custom timestamp formatting. It uses isoformat, so + we use a custom class that overrides the isoformat class. 
+ """ + + def isoformat(self, sep: str = "T", timespec: str = "auto") -> str: + """Return formatted datetime.""" + return self.strftime("%Y-%m-%dT%H:%M:%SZ") + + @property + def tzinfo(self) -> timezone: + """Return the objects timezone.""" + return timezone.utc + + +def to_utc(dt: Union[datetime, str]) -> Optional[datetime]: + """Localize datetime objects to UTC timezone.""" + if dt is None: + return None + if isinstance(dt, str): + try: + parsed_dt = dateutil.parser.parse(dt) + return DateTime.fromtimestamp(parsed_dt.timestamp()) + except Exception: + return None + return DateTime.fromtimestamp(dt.astimezone(timezone.utc).timestamp()) + + +class AtomXMLSerializer(BaseSerializer): + """Atom XML serializer for paper metadata.""" + + @classmethod + def transform_document( + cls, + fg: FeedGenerator, + doc: Document, + query: Optional[ClassicAPIQuery] = None, + ) -> None: + """Select a subset of :class:`Document` properties for public API.""" + entry = fg.add_entry() + entry.id( + url_for( + "abs", + paper_id=doc["paper_id"], + version=doc["version"], + _external=True, + ) + ) + entry.title(doc["title"]) + entry.summary(doc["abstract"]) + entry.published(to_utc(doc["submitted_date"])) + entry.updated(to_utc(doc["updated_date"])) + entry.link( + { + "href": url_for( + "abs", + paper_id=doc["paper_id"], + version=doc["version"], + _external=True, + ), + "type": "text/html", + } + ) + + entry.link( + { + "href": url_for( + "pdf", + paper_id=doc["paper_id"], + version=doc["version"], + _external=True, + ), + "type": "application/pdf", + "rel": "related", + "title": "pdf", + } + ) + + if doc.get("comments"): + entry.arxiv.comment(doc["comments"]) + + if doc.get("journal_ref"): + entry.arxiv.journal_ref(doc["journal_ref"]) + + if doc.get("doi"): + entry.arxiv.doi(doc["doi"]) + + if doc["primary_classification"]["category"] is not None: + entry.arxiv.primary_category( + doc["primary_classification"]["category"]["id"] + ) + entry.category( + 
term=doc["primary_classification"]["category"]["id"], + scheme=ARXIV_NS, + ) + + for category in doc["secondary_classification"]: + entry.category(term=category["category"]["id"], scheme=ARXIV_NS) + + for author in doc["authors"]: + author_data: Dict[str, Any] = {"name": author["full_name"]} + if author.get("affiliation"): + author_data["affiliation"] = author["affiliation"] + entry.arxiv.author(author_data) + + @staticmethod + def _get_feed(query: Optional[ClassicAPIQuery] = None) -> FeedGenerator: + fg = FeedGenerator() + fg.generator("") + fg.register_extension("opensearch", OpenSearchExtension) + fg.register_extension( + "arxiv", ArXivExtension, ArXivEntryExtension, rss=False + ) + + if query: + if query.phrase is not None: + query_string = phrase_to_query_string(query.phrase) + else: + query_string = "" + + if query.id_list: + id_list = ",".join(query.id_list) + else: + id_list = "" + + fg.title(f"arXiv Query: {query.to_query_string()}") + + # From perl documentation of the old site: + # search_id is calculated by taking SHA-1 digest of the query + # string. Digest is in bytes form and it's 20 bytes long. Then it's + # base64 encoded, but perls version returns only 27 characters - + # it omits the `=` sign at the end. 
+ search_id = base64.b64encode( + hashlib.sha1(query.to_query_string().encode("utf-8")).digest() + ).decode("utf-8")[:-1] + fg.id( + url_for("classic_api.query").replace("/query", f"/{search_id}") + ) + + fg.link( + { + "href": url_for( + "classic_api.query", + search_query=query_string, + start=query.page_start, + max_results=query.size, + id_list=id_list, + ), + "type": "application/atom+xml", + } + ) + else: + # TODO: Discuss better defaults + fg.title("arXiv Search Results") + fg.id("https://arxiv.org/") + + fg.updated(to_utc(datetime.utcnow())) + return fg + + @classmethod + def serialize( + cls, document_set: DocumentSet, query: Optional[ClassicAPIQuery] = None + ) -> str: + """Generate Atom response for a :class:`DocumentSet`.""" + fg = cls._get_feed(query) + + # pylint struggles with the opensearch extensions, so we ignore + # no-member here. + # pylint: disable=no-member + fg.opensearch.totalResults( + document_set["metadata"].get("total_results") + ) + fg.opensearch.itemsPerPage(document_set["metadata"].get("size")) + fg.opensearch.startIndex(document_set["metadata"].get("start")) + + for doc in reversed(document_set["results"]): + cls.transform_document(fg, doc, query=query) + + return fg.atom_str(pretty=True) # type: ignore + + @classmethod + def serialize_error( + cls, error: Error, query: Optional[ClassicAPIQuery] = None + ) -> str: + """Generate Atom error response.""" + fg = cls._get_feed(query) + + # pylint struggles with the opensearch extensions, so we ignore + # no-member here. 
+ # pylint: disable=no-member + fg.opensearch.totalResults(1) + fg.opensearch.itemsPerPage(1) + fg.opensearch.startIndex(0) + + entry = fg.add_entry() + entry.id(error.id) + entry.title("Error") + entry.summary(error.error) + entry.updated(to_utc(error.created)) + entry.link( + {"href": error.link, "rel": "alternate", "type": "text/html"} + ) + entry.arxiv.author({"name": error.author}) + + return fg.atom_str(pretty=True) # type: ignore + + @classmethod + def serialize_document( + cls, document: Document, query: Optional[ClassicAPIQuery] = None + ) -> str: + """Generate Atom feed for a single :class:`Document`.""" + # Wrap the single document in a DocumentSet wrapper. + document_set = document_set_from_documents([document]) + + return cls.serialize(document_set, query=query) + + +def as_atom( + document_or_set: Union[Error, DocumentSet, Document], + query: Optional[ClassicAPIQuery] = None, +) -> str: + """Serialize a :class:`DocumentSet` as Atom.""" + if isinstance(document_or_set, Error): + return AtomXMLSerializer.serialize_error( + document_or_set, query=query + ) # type: ignore + # type: ignore + elif "paper_id" in document_or_set: + return AtomXMLSerializer.serialize_document( # type: ignore + document_or_set, query=query + ) + return AtomXMLSerializer.serialize( # type: ignore + document_or_set, query=query + ) # type: ignore diff --git a/search/serialize/atom_extensions.py b/search/serialize/atom_extensions.py new file mode 100644 index 00000000..aa1b9bb4 --- /dev/null +++ b/search/serialize/atom_extensions.py @@ -0,0 +1,308 @@ +"""FeedGen extensions to implement serialization of the arXiv legacy API. + +Throughout module, pylint: disable=arguments-differ due to inconsistencies in +feedgen library. 
+""" +# pylint: disable=arguments-differ + +from typing import Any, Dict, List + +from lxml import etree +from feedgen.ext.base import BaseEntryExtension, BaseExtension +from feedgen.entry import FeedEntry +from feedgen.feed import FeedGenerator + + +ARXIV_NS = "http://arxiv.org/schemas/atom" +OPENSEARCH_NS = "http://a9.com/-/spec/opensearch/1.1/" + + +class OpenSearchExtension(BaseExtension): + """Extension of the Feedgen base class to put OpenSearch metadata.""" + + # pylint: disable=invalid-name + + def __init__(self: BaseExtension) -> None: + """Initialize extension parameters.""" + # __ syntax follows convention of :module:`feedgen.ext` + self.__opensearch_totalResults = None + self.__opensearch_startIndex = None + self.__opensearch_itemsPerPage = None + + def extend_atom( + self: BaseExtension, atom_feed: FeedGenerator + ) -> FeedGenerator: + """ + Assign the Atom feed generator to the extension. + + Parameters + ---------- + atom_feed : :class:`.FeedGenerator` + The FeedGenerator to use for Atom results. + + Returns + ------- + FeedGenerator + The provided feed generator. + + """ + if self.__opensearch_itemsPerPage is not None: + elt = etree.SubElement( + atom_feed, f"{{{OPENSEARCH_NS}}}itemsPerPage" + ) + elt.text = self.__opensearch_itemsPerPage + + if self.__opensearch_totalResults is not None: + elt = etree.SubElement( + atom_feed, f"{{{OPENSEARCH_NS}}}totalResults" + ) + elt.text = self.__opensearch_totalResults + + if self.__opensearch_startIndex is not None: + elt = etree.SubElement(atom_feed, f"{{{OPENSEARCH_NS}}}startIndex") + elt.text = self.__opensearch_startIndex + + return atom_feed + + @staticmethod + def extend_rss(rss_feed: FeedGenerator) -> FeedGenerator: + """ + Assign the RSS feed generator to the extension. + + Parameters + ---------- + rss_feed + The FeedGenerator to use for RSS results. + + Returns + ------- + FeedGenerator + The provided feed generator. 
+ + """ + return rss_feed + + @staticmethod + def extend_ns() -> Dict[str, str]: + """ + Assign the feed's namespace string. + + Returns + ------- + str + The definition string for the "arxiv" namespace. + + """ + return {"opensearch": OPENSEARCH_NS} + + def totalResults(self: BaseExtension, text: str) -> None: + """Set the totalResults parameter.""" + self.__opensearch_totalResults = str(text) + + def startIndex(self: BaseExtension, text: str) -> None: + """Set the startIndex parameter.""" + self.__opensearch_startIndex = str(text) + + def itemsPerPage(self: BaseExtension, text: str) -> None: + """Set the itemsPerPage parameter.""" + self.__opensearch_itemsPerPage = str(text) + + +class ArXivExtension(BaseExtension): + """Extension of the Feedgen base class to allow us to define namespaces.""" + + def __init__(self: BaseExtension) -> None: + """Noop initialization.""" + + @staticmethod + def extend_atom(atom_feed: FeedGenerator) -> FeedGenerator: + """ + Assign the Atom feed generator to the extension. + + Parameters + ---------- + atom_feed + The FeedGenerator to use for Atom results. + + Returns + ------- + FeedGenerator + The provided feed generator. + + """ + return atom_feed + + @staticmethod + def extend_rss(rss_feed: FeedGenerator) -> FeedGenerator: + """ + Assign the RSS feed generator to the extension. + + Parameters + ---------- + rss_feed + The FeedGenerator to use for RSS results. + + Returns + ------- + FeedGenerator + The provided feed generator. + + """ + return rss_feed + + @staticmethod + def extend_ns() -> Dict[str, str]: + """ + Assign the feed's namespace string. + + Returns + ------- + str + The definition string for the "arxiv" namespace. 
+ + """ + return {"arxiv": ARXIV_NS} + + +class ArXivEntryExtension(BaseEntryExtension): + """Extension of the Feedgen base class to allow us to add elements.""" + + def __init__(self: BaseEntryExtension): + """Initialize the member values to all be empty.""" + self.__arxiv_comment = None + self.__arxiv_primary_category = None + self.__arxiv_doi = None + self.__arxiv_journal_ref = None + self.__arxiv_authors: List[Dict] = [] + + def extend_atom(self: BaseEntryExtension, entry: FeedEntry) -> FeedEntry: + """ + Add this extension's new elements to the Atom feed entry. + + Parameters + ---------- + entry + The FeedEntry to modify. + + Returns + ------- + FeedEntry + The modified entry. + + """ + if self.__arxiv_comment: + comment_element = etree.SubElement(entry, f"{{{ARXIV_NS}}}comment") + comment_element.text = self.__arxiv_comment + + if self.__arxiv_primary_category: + primary_category_element = etree.SubElement( + entry, f"{{{ARXIV_NS}}}primary_category" + ) + primary_category_element.attrib[ + "term" + ] = self.__arxiv_primary_category + + if self.__arxiv_journal_ref: + journal_ref_element = etree.SubElement( + entry, f"{{{ARXIV_NS}}}journal_ref" + ) + journal_ref_element.text = self.__arxiv_journal_ref + + if self.__arxiv_authors: + for author in self.__arxiv_authors: + author_element = etree.SubElement(entry, "author") + name_element = etree.SubElement(author_element, "name") + name_element.text = author["name"] + for affiliation in author.get("affiliation", []): + affiliation_element = etree.SubElement( + author_element, "{%s}affiliation" % ARXIV_NS + ) + affiliation_element.text = affiliation + + if self.__arxiv_doi: + for doi in self.__arxiv_doi: + doi_element = etree.SubElement(entry, f"{{{ARXIV_NS}}}doi") + doi_element.text = doi + + doi_link_element = etree.SubElement(entry, "link") + doi_link_element.set("rel", "related") + doi_link_element.set("href", f"https://doi.org/{doi}") + + return entry + + @staticmethod + def extend_rss(entry: FeedEntry) -> 
FeedEntry: + """ + Add this extension's new elements to the RSS feed entry. + + Parameters + ---------- + entry + The FeedEntry to modify. + + Returns + ------- + FeedEntry + The modfied entry. + + """ + return entry + + def comment(self: BaseEntryExtension, text: str) -> None: + """ + Assign the comment value to this entry. + + Parameters + ---------- + text + The new comment text. + + """ + self.__arxiv_comment = text + + def primary_category(self: BaseEntryExtension, text: str) -> None: + """ + Assign the primary_category value to this entry. + + Parameters + ---------- + text + The new primary_category name. + + """ + self.__arxiv_primary_category = text + + def journal_ref(self: BaseEntryExtension, text: str) -> None: + """ + Assign the journal_ref value to this entry. + + Parameters + ---------- + text + The new journal_ref value. + + """ + self.__arxiv_journal_ref = text + + def doi(self: BaseEntryExtension, dois: Dict[str, str]) -> None: + """ + Assign the doi value to this entry. + + Parameters + ---------- + list + The new list of DOI assignments. + + """ + self.__arxiv_doi = dois + + def author(self: BaseEntryExtension, data: Dict[str, Any]) -> None: + """ + Add an author to this entry. + + Parameters + ---------- + data + A dictionary consisting of the author name and affiliation data. 
+ """ + self.__arxiv_authors.append(data) diff --git a/search/serialize/base.py b/search/serialize/base.py new file mode 100644 index 00000000..cacfb913 --- /dev/null +++ b/search/serialize/base.py @@ -0,0 +1,5 @@ +"""Base class for API serializers.""" + + +class BaseSerializer(object): + """Base class for API serializers.""" diff --git a/search/serialize/json.py b/search/serialize/json.py new file mode 100644 index 00000000..74c188f8 --- /dev/null +++ b/search/serialize/json.py @@ -0,0 +1,152 @@ +"""Serializers for API responses.""" + +from typing import Union, Optional, Dict, Any +from flask import jsonify, url_for, Response + +from search.serialize.base import BaseSerializer +from search.domain import DocumentSet, Document, Classification, APIQuery + + +class JSONSerializer(BaseSerializer): + """Serializes a :class:`DocumentSet` as JSON.""" + + # FIXME: Return type. + @classmethod + def _transform_classification( + cls, clsn: Classification + ) -> Optional[Dict[Any, Any]]: + category = clsn.get("category") + if category is None: + return None + return { + "group": clsn.get("group"), + "archive": clsn.get("archive"), + "category": category, + } + + # FIXME: Return type. + @classmethod + def _transform_format( + cls, fmt: str, paper_id: str, version: int + ) -> Dict[Any, Any]: + return { + "format": fmt, + "href": url_for(fmt, paper_id=paper_id, version=version), + } + + # FIXME: Return type. + @classmethod + def _transform_latest(cls, document: Document) -> Optional[Dict[Any, Any]]: + latest = document.get("latest") + if latest is None: + return None + return { + "paper_id": latest, + "href": url_for( + "api.paper", + paper_id=document["paper_id"], + version=document.get("latest_version"), + _external=True, + ), + "canonical": url_for( + "abs", + paper_id=document["paper_id"], + version=document.get("latest_version"), + ), + "version": document.get("latest_version"), + } + + # FIXME: Types. 
+ @classmethod + def _transform_license( + cls, license: Dict[Any, Any] + ) -> Optional[Dict[Any, Any]]: + uri = license.get("uri") + if uri is None: + return None + return {"label": license.get("label", ""), "href": uri} + + # FIXME: Return type. + @classmethod + def transform_document( + cls, doc: Document, query: Optional[APIQuery] = None + ) -> Dict[Any, Any]: + """Select a subset of :class:`Document` properties for public API.""" + # Only return fields that have been explicitly requested. + data = { + key: value + for key, value in doc.items() + if query is None or key in query.include_fields + } + paper_id = doc["paper_id"] + version = doc["version"] + if "submitted_date_first" in data: + data["submitted_date_first"] = doc[ + "submitted_date_first" + ].isoformat() + if "announced_date_first" in data: + data["announced_date_first"] = doc[ + "announced_date_first" + ].isoformat() + if "formats" in data: + data["formats"] = [ + cls._transform_format(fmt, paper_id, version) + for fmt in doc["formats"] + ] + if "license" in data: + data["license"] = cls._transform_license(doc["license"]) + if "latest" in data: + data["latest"] = cls._transform_latest(doc) + + data["href"] = url_for( + "api.paper", paper_id=paper_id, version=version, _external=True + ) + data["canonical"] = url_for("abs", paper_id=paper_id, version=version) + return data + + @classmethod + def serialize( + cls, document_set: DocumentSet, query: Optional[APIQuery] = None + ) -> Response: + """Generate JSON for a :class:`DocumentSet`.""" + total_results = int(document_set["metadata"].get("total_results", 0)) + serialized: Response = jsonify( + { + "results": [ + cls.transform_document(doc, query=query) + for doc in document_set["results"] + ], + "metadata": { + "start": document_set["metadata"].get("start", ""), + "end": document_set["metadata"].get("end", ""), + "size": document_set["metadata"].get("size", ""), + "total_results": total_results, + "query": document_set["metadata"].get("query", []), + 
}, + } + ) + return serialized + + @classmethod + def serialize_document( + cls, document: Document, query: Optional[APIQuery] = None, + ) -> Response: + """Generate JSON for a single :class:`Document`.""" + serialized: Response = jsonify( + cls.transform_document(document, query=query) + ) + return serialized + + +def as_json( + document_or_set: Union[DocumentSet, Document], + query: Optional[APIQuery] = None, +) -> Response: + """Serialize a :class:`DocumentSet` as JSON.""" + if "paper_id" in document_or_set: + return JSONSerializer.serialize_document( # type:ignore + document_or_set, query=query + ) # type: ignore + return JSONSerializer.serialize( # type:ignore + document_or_set, query=query + ) diff --git a/search/serialize/tests/__init__.py b/search/serialize/tests/__init__.py new file mode 100644 index 00000000..f708b96f --- /dev/null +++ b/search/serialize/tests/__init__.py @@ -0,0 +1 @@ +"""Serialization tests.""" diff --git a/search/serialize/tests/test_serialize.py b/search/serialize/tests/test_serialize.py new file mode 100644 index 00000000..33fd85a6 --- /dev/null +++ b/search/serialize/tests/test_serialize.py @@ -0,0 +1,83 @@ +"""Tests for serializers.""" + +import os +import json +from unittest import TestCase, mock + +import jsonschema + +from search import encode +from search import serialize +from search.tests import mocks + + +def mock_jsonify(o): + return json.dumps(o, cls=encode.ISO8601JSONEncoder) + + +class TestSerializeJSONDocument(TestCase): + """Serialize a single :class:`domain.Document` as JSON.""" + + SCHEMA_PATH = os.path.abspath("schema/resources/Document.json") + + def setUp(self): + with open(self.SCHEMA_PATH) as f: + self.schema = json.load(f) + + @mock.patch( + f"search.serialize.json.url_for", lambda *a, **k: "http://f/12" + ) + @mock.patch(f"search.serialize.json.jsonify", mock_jsonify) + def test_to_json(self): + """Just your run-of-the-mill arXiv document generates valid JSON.""" + document = mocks.document() + srlzd = 
serialize.as_json(document) + res = jsonschema.RefResolver( + "file://%s/" % os.path.abspath(os.path.dirname(self.SCHEMA_PATH)), + None, + ) + self.assertIsNone( + jsonschema.validate(json.loads(srlzd), self.schema, resolver=res) + ) + + +class TestSerializeJSONDocumentSet(TestCase): + """Serialize a :class:`domain.DocumentSet` as JSON.""" + + SCHEMA_PATH = os.path.abspath("schema/resources/DocumentSet.json") + + def setUp(self): + with open(self.SCHEMA_PATH) as f: + self.schema = json.load(f) + + @mock.patch( + f"search.serialize.json.url_for", lambda *a, **k: "http://f/12" + ) + @mock.patch(f"search.serialize.json.jsonify", mock_jsonify) + def test_to_json(self): + """Just your run-of-the-mill arXiv document generates valid JSON.""" + document = mocks.document() + meta = {"start": 0, "size": 50, "end": 50, "total": 500202} + document_set = {"results": [document], "metadata": meta} + srlzd = serialize.as_json(document_set) + res = jsonschema.RefResolver( + "file://%s/" % os.path.abspath(os.path.dirname(self.SCHEMA_PATH)), + None, + ) + self.assertIsNone( + jsonschema.validate(json.loads(srlzd), self.schema, resolver=res) + ) + + +class TestSerializeAtomDocument(TestCase): + """Serialize a single :class:`domain.Document` as Atom.""" + + @mock.patch( + f"search.serialize.atom.url_for", lambda *a, **k: "http://f/12" + ) + def test_to_atom(self): + """Just your run-of-the-mill arXiv document generates valid Atom.""" + document = mocks.document() + _ = serialize.as_atom(document) + + # TODO: Verify valid AtomXML diff --git a/search/services/__init__.py b/search/services/__init__.py index bf21553c..12c67095 100644 --- a/search/services/__init__.py +++ b/search/services/__init__.py @@ -1,3 +1,5 @@ """Provides service integration modules for use by controllers.""" -from .index import SearchSession +__all__ = ["SearchSession"] + +from search.services.index import SearchSession diff --git a/search/services/fulltext.py b/search/services/fulltext.py index a00bd9b2..fc4ef7a3 
100644 --- a/search/services/fulltext.py +++ b/search/services/fulltext.py @@ -1,15 +1,14 @@ """Provides access to fulltext content for arXiv papers.""" +import json +from http import HTTPStatus from functools import wraps -import os from urllib.parse import urljoin -import json import requests -from arxiv import status -from search.context import get_application_config, get_application_global from search.domain import Fulltext +from search.context import get_application_config, get_application_global class FulltextSession(object): @@ -19,10 +18,10 @@ def __init__(self, endpoint: str) -> None: """Initialize an HTTP session.""" self._session = requests.Session() self._adapter = requests.adapters.HTTPAdapter(max_retries=2) - self._session.mount('https://', self._adapter) + self._session.mount("https://", self._adapter) - if not endpoint[-1] == '/': - endpoint += '/' + if not endpoint[-1] == "/": + endpoint += "/" self.endpoint = endpoint def retrieve(self, document_id: str) -> Fulltext: @@ -49,38 +48,43 @@ def retrieve(self, document_id: str) -> Fulltext: IOError Raised when unable to retrieve fulltext content. """ - if not document_id: # This could use further elaboration. - raise ValueError('Invalid value for document_id') + if not document_id: # This could use further elaboration. 
+ raise ValueError("Invalid value for document_id") try: response = requests.get(urljoin(self.endpoint, document_id)) - except requests.exceptions.SSLError as e: - raise IOError('SSL failed: %s' % e) - - if response.status_code != status.HTTP_200_OK: - raise IOError('%s: could not retrieve fulltext: %i' % - (document_id, response.status_code)) + except requests.exceptions.SSLError as ex: + raise IOError("SSL failed: %s" % ex) + + if response.status_code != HTTPStatus.OK: + raise IOError( + "%s: could not retrieve fulltext: %i" + % (document_id, response.status_code) + ) try: data = response.json() - except json.decoder.JSONDecodeError as e: - raise IOError('%s: could not decode response: %s' % - (document_id, e)) from e - return Fulltext(**data) # type: ignore + except json.decoder.JSONDecodeError as ex: + raise IOError( + "%s: could not decode response: %s" % (document_id, ex) + ) from ex + return Fulltext(**data) # type: ignore # See https://github.com/python/mypy/issues/3937 def init_app(app: object = None) -> None: """Set default configuration parameters for an application instance.""" config = get_application_config(app) - config.setdefault('FULLTEXT_ENDPOINT', - 'https://fulltext.arxiv.org/fulltext/') + config.setdefault( + "FULLTEXT_ENDPOINT", "https://fulltext.arxiv.org/fulltext/" + ) def get_session(app: object = None) -> FulltextSession: """Get a new session with the fulltext endpoint.""" config = get_application_config(app) - endpoint = config.get('FULLTEXT_ENDPOINT', - 'https://fulltext.arxiv.org/fulltext/') + endpoint = config.get( + "FULLTEXT_ENDPOINT", "https://fulltext.arxiv.org/fulltext/" + ) return FulltextSession(endpoint) @@ -89,9 +93,9 @@ def current_session() -> FulltextSession: g = get_application_global() if not g: return get_session() - if 'fulltext' not in g: - g.fulltext = get_session() # type: ignore - return g.fulltext # type: ignore + if "fulltext" not in g: + g.fulltext = get_session() # type: ignore + return g.fulltext # type: ignore 
@wraps(FulltextSession.retrieve) diff --git a/search/services/index/__init__.py b/search/services/index/__init__.py index 53b6aa6c..5b9d8a72 100644 --- a/search/services/index/__init__.py +++ b/search/services/index/__init__.py @@ -11,48 +11,77 @@ :mod:`search.agent.consumer.MetadataRecordProcessor`). """ +__all__ = ["Q", "SearchSession"] + import json -import urllib3 +import warnings from contextlib import contextmanager -from typing import Any, Optional, Tuple, Union, List, Generator, Mapping -from functools import reduce, wraps -from operator import ior +from typing import Any, Optional, List, Generator, Dict +import urllib3 from flask import current_app -from elasticsearch import Elasticsearch, ElasticsearchException, \ - SerializationError, TransportError, NotFoundError, \ - helpers +from elasticsearch import ( + Elasticsearch, + ElasticsearchException, + SerializationError, + TransportError, + helpers, +) from elasticsearch.connection import Urllib3HttpConnection from elasticsearch.helpers import BulkIndexError - from elasticsearch_dsl import Search, Q from arxiv.base import logging from arxiv.integration.meta import MetaIntegration from search.context import get_application_config, get_application_global -from search.domain import Document, DocumentSet, Query, AdvancedQuery, \ - SimpleQuery, asdict, APIQuery - -from .exceptions import QueryError, IndexConnectionError, DocumentNotFound, \ - IndexingError, OutsideAllowedRange, MappingError -from .util import MAX_RESULTS -from .advanced import advanced_search -from .simple import simple_search -from .api import api_search -from . import highlighting -from . 
import results +from search.domain import ( + Document, + DocumentSet, + Query, + AdvancedQuery, + SimpleQuery, + APIQuery, + ClassicAPIQuery, +) + +from search.services.index.exceptions import ( + QueryError, + IndexConnectionError, + DocumentNotFound, + IndexingError, + OutsideAllowedRange, + MappingError, +) +from search.services.index.util import MAX_RESULTS +from search.services.index.advanced import advanced_search +from search.services.index.simple import simple_search +from search.services.index.api import api_search +from search.services.index.classic_api import classic_search +from search.services.index import highlighting +from search.services.index import results logger = logging.getLogger(__name__) # Disable the Elasticsearch logger. When enabled, the Elasticsearch logger # dumps entire Tracebacks prior to propagating exceptions. Thus we end up with # tracebacks in the logs even for handled exceptions. -logging.getLogger('elasticsearch').disabled = True +logging.getLogger("elasticsearch").disabled = True -ALL_SEARCH_FIELDS = ['author', 'title', 'abstract', 'comments', 'journal_ref', - 'acm_class', 'msc_class', 'report_num', 'paper_id', 'doi', - 'orcid', 'author_id'] +ALL_SEARCH_FIELDS = [ + "author", + "title", + "abstract", + "comments", + "journal_ref", + "acm_class", + "msc_class", + "report_num", + "paper_id", + "doi", + "orcid", + "author_id", +] @contextmanager @@ -60,45 +89,53 @@ def handle_es_exceptions() -> Generator: """Handle common ElasticSearch-related exceptions.""" try: yield - except TransportError as e: - if e.error == 'resource_already_exists_exception': - logger.debug('Index already exists; move along') + except TransportError as ex: + if ex.error == "resource_already_exists_exception": + logger.debug("Index already exists; move along") return - elif e.error == 'mapper_parsing_exception': - logger.error('ES mapper_parsing_exception: %s', e.info) - logger.debug(str(e.info)) - raise MappingError('Invalid mapping: %s' % str(e.info)) 
from e - elif e.error == 'index_not_found_exception': - logger.error('ES index_not_found_exception: %s', e.info) + elif ex.error == "mapper_parsing_exception": + logger.error("ES mapper_parsing_exception: %s", ex.info) + logger.debug(str(ex.info)) + raise MappingError("Invalid mapping: %s" % str(ex.info)) from ex + elif ex.error == "index_not_found_exception": + logger.error("ES index_not_found_exception: %s", ex.info) SearchSession.current_session().create_index() - elif e.error == 'parsing_exception': - logger.error('ES parsing_exception: %s', e.info) - raise QueryError(e.info) from e - elif e.status_code == 404: - logger.error('Caught NotFoundError: %s', e) - raise DocumentNotFound('No such document') - logger.error('Problem communicating with ES: %s' % e.error) + elif ex.error == "parsing_exception": + logger.error("ES parsing_exception: %s", ex.info) + raise QueryError(ex.info) from ex + elif ex.status_code == 404: + logger.error("Caught NotFoundError: %s", ex) + raise DocumentNotFound("No such document") + logger.error("Problem communicating with ES: %s" % ex.error) raise IndexConnectionError( - 'Problem communicating with ES: %s' % e.error - ) from e - except SerializationError as e: - logger.error("SerializationError: %s", e) - raise IndexingError('Problem serializing document: %s' % e) from e - except BulkIndexError as e: - logger.error("BulkIndexError: %s", e) - raise IndexingError('Problem with bulk indexing: %s' % e) from e - except Exception as e: - logger.error('Unhandled exception: %s') + "Problem communicating with ES: %s" % ex.error + ) from ex + except SerializationError as ex: + logger.error("SerializationError: %s", ex) + raise IndexingError("Problem serializing document: %s" % ex) from ex + except BulkIndexError as ex: + logger.error("BulkIndexError: %s", ex) + raise IndexingError("Problem with bulk indexing: %s" % ex) from ex + except Exception as ex: + logger.error("Unhandled exception: %s" % ex) raise class 
SearchSession(metaclass=MetaIntegration): """Encapsulates session with Elasticsearch host.""" - def __init__(self, host: str, index: str, port: int = 9200, - scheme: str = 'http', user: Optional[str] = None, - password: Optional[str] = None, mapping: Optional[str] = None, - verify: bool = True, **extra: Any) -> None: + def __init__( + self, + host: str, + index: str, + port: int = 9200, + scheme: str = "http", + user: Optional[str] = None, + password: Optional[str] = None, + mapping: Optional[str] = None, + verify: bool = True, + **extra: Any, + ) -> None: """ Initialize the connection to Elasticsearch. @@ -124,35 +161,46 @@ def __init__(self, host: str, index: str, port: int = 9200, """ self.index = index self.mapping = mapping - self.doc_type = 'document' - use_ssl = True if scheme == 'https' else False - http_auth = '%s:%s' % (user, password) if user else None - - self.conn_params = {'host': host, 'port': port, 'use_ssl': use_ssl, - 'http_auth': http_auth, 'verify_certs': verify} + self.doc_type = "document" + use_ssl = True if scheme == "https" else False + http_auth = "%s:%s" % (user, password) if user else None + + self.conn_params = { + "host": host, + "port": port, + "use_ssl": use_ssl, + "http_auth": http_auth, + "verify_certs": verify, + } self.conn_extra = extra + if not use_ssl: + warnings.warn(f"TLS is disabled, using port {port}") + if host == "localhost": + warnings.warn(f"Using ES at {host}:{port}; not OK for production") def new_connection(self) -> Elasticsearch: """Create a new :class:`.Elasticsearch` connection.""" - logger.debug('init ES session with %s', self.conn_params) + logger.debug("init ES session with %s", self.conn_params) try: es = Elasticsearch( [self.conn_params], connection_class=Urllib3HttpConnection, - **self.conn_extra) - except ElasticsearchException as e: - logger.error('ElasticsearchException: %s', e) + **self.conn_extra, + ) + except ElasticsearchException as ex: + logger.error("ElasticsearchException: %s", ex) raise 
IndexConnectionError( - 'Could not initialize ES session: %s' % e - ) from e + "Could not initialize ES session: %s" % ex + ) from ex return es def _base_search(self) -> Search: return Search(using=self.es, index=self.index) - def _load_mapping(self) -> dict: - if not self.mapping or type(self.mapping) is not str: - raise IndexingError('Mapping not set') + # FIXME: Return type. + def _load_mapping(self) -> Dict[Any, Any]: + if not self.mapping or not isinstance(self.mapping, str): + raise IndexingError("Mapping not set") with open(self.mapping) as f: mappings: dict = json.load(f) return mappings @@ -170,9 +218,9 @@ def es(self) -> Elasticsearch: connection. """ if current_app: - if 'elasticsearch' not in current_app.extensions: - current_app.extensions['elasticsearch'] = self.new_connection() - return current_app.extensions['elasticsearch'] + if "elasticsearch" not in current_app.extensions: + current_app.extensions["elasticsearch"] = self.new_connection() + return current_app.extensions["elasticsearch"] return self.new_connection() def cluster_available(self) -> bool: @@ -185,13 +233,13 @@ def cluster_available(self) -> bool: """ try: - self.es.cluster.health(wait_for_status='yellow', request_timeout=1) + self.es.cluster.health(wait_for_status="yellow", request_timeout=1) return True - except urllib3.exceptions.HTTPError as e: - logger.debug('Health check failed: %s', str(e)) + except urllib3.exceptions.HTTPError as ex: + logger.debug("Health check failed: %s", str(ex)) return False - except Exception as e: - logger.debug('Health check failed: %s', str(e)) + except Exception as ex: + logger.debug("Health check failed: %s", str(ex)) return False def create_index(self) -> None: @@ -226,8 +274,10 @@ def index_exists(self, index_name: str) -> bool: _exists: bool = self.es.indices.exists(index_name) return _exists - def reindex(self, old_index: str, new_index: str, - wait_for_completion: bool = False) -> dict: + # FIXME: Return type. 
+ def reindex( + self, old_index: str, new_index: str, wait_for_completion: bool = False + ) -> Dict[Any, Any]: """ Create a new index and reindex with the current mappings. @@ -256,13 +306,14 @@ def reindex(self, old_index: str, new_index: str, with handle_es_exceptions(): self.es.indices.create(new_index, self._load_mapping()) - response: dict = self.es.reindex({ - "source": {"index": old_index}, - "dest": {"index": new_index} - }, wait_for_completion=wait_for_completion) + response: dict = self.es.reindex( + {"source": {"index": old_index}, "dest": {"index": new_index}}, + wait_for_completion=wait_for_completion, + ) return response - def get_task_status(self, task: str) -> dict: + # FIXME: Return type. + def get_task_status(self, task: str) -> Dict[Any, Any]: """ Get the status of a running task in ES (e.g. reindex). @@ -307,13 +358,18 @@ def add_document(self, document: Document) -> None: self.create_index() with handle_es_exceptions(): - ident = document['id'] if document['id'] else document['paper_id'] - logger.debug(f'{ident}: index document') - self.es.index(index=self.index, doc_type=self.doc_type, - id=ident, body=document) + ident = document["id"] if document["id"] else document["paper_id"] + logger.debug(f"{ident}: index document") + self.es.index( + index=self.index, + doc_type=self.doc_type, + id=ident, + body=document, + ) - def bulk_add_documents(self, documents: List[Document], - docs_per_chunk: int = 500) -> None: + def bulk_add_documents( + self, documents: List[Document], docs_per_chunk: int = 500 + ) -> None: """ Add documents to the search index using the bulk API. 
@@ -333,21 +389,25 @@ def bulk_add_documents(self, documents: List[Document], """ if not self.es.indices.exists(index=self.index): - logger.debug('index does not exist') + logger.debug("index does not exist") self.create_index() - logger.debug('created index') + logger.debug("created index") with handle_es_exceptions(): - actions = ({ - '_index': self.index, - '_type': self.doc_type, - '_id': document['id'], - '_source': document - } for document in documents) + actions = ( + { + "_index": self.index, + "_type": self.doc_type, + "_id": document["id"], + "_source": document, + } + for document in documents + ) - helpers.bulk(client=self.es, actions=actions, - chunk_size=docs_per_chunk) - logger.debug('added %i documents to index', len(documents)) + helpers.bulk( + client=self.es, actions=actions, chunk_size=docs_per_chunk + ) + logger.debug("added %i documents to index", len(documents)) def get_document(self, document_id: str) -> Document: """ @@ -370,13 +430,14 @@ def get_document(self, document_id: str) -> Document: """ with handle_es_exceptions(): - record = self.es.get(index=self.index, doc_type=self.doc_type, - id=document_id) + record = self.es.get( + index=self.index, doc_type=self.doc_type, id=document_id + ) if not record: logger.error("No such document: %s", document_id) - raise DocumentNotFound('No such document') - return results.to_document(record['_source'], highlight=False) + raise DocumentNotFound("No such document") + return results.to_document(record["_source"], highlight=False) # See https://github.com/python/mypy/issues/3937 def search(self, query: Query, highlight: bool = True) -> DocumentSet: @@ -400,14 +461,14 @@ def search(self, query: Query, highlight: bool = True) -> DocumentSet: """ # Make sure that the user is not requesting a nonexistant page. 
- max_pages = int(MAX_RESULTS/query.size) + max_pages = int(MAX_RESULTS / query.size) if query.page > max_pages: - _message = f'Requested page {query.page}, but max is {max_pages}' + _message = f"Requested page {query.page}, but max is {max_pages}" logger.error(_message) raise OutsideAllowedRange(_message) # Perform the search. - logger.debug('got current search request %s', str(query)) + logger.debug("got current search request %s", str(query)) current_search = self._base_search() try: if isinstance(query, AdvancedQuery): @@ -416,8 +477,10 @@ def search(self, query: Query, highlight: bool = True) -> DocumentSet: current_search = simple_search(current_search, query) elif isinstance(query, APIQuery): current_search = api_search(current_search, query) - except TypeError as e: - raise e + elif isinstance(query, ClassicAPIQuery): + current_search = classic_search(current_search, query) + except TypeError as ex: + raise ex # logger.error('Malformed query: %s', str(e)) # raise QueryError('Malformed query') from e @@ -428,12 +491,12 @@ def search(self, query: Query, highlight: bool = True) -> DocumentSet: if isinstance(query, APIQuery): current_search = current_search.extra( - _source={'include': query.include_fields} + _source={"include": query.include_fields} ) with handle_es_exceptions(): # Slicing the search adds pagination parameters to the request. - resp = current_search[query.page_start:query.page_end].execute() + resp = current_search[query.page_start : query.page_end].execute() # Perform post-processing on the search results. 
return results.to_documentset(query, resp, highlight=highlight) @@ -441,54 +504,58 @@ def search(self, query: Query, highlight: bool = True) -> DocumentSet: def exists(self, paper_id_v: str) -> bool: """Determine whether a paper exists in the index.""" with handle_es_exceptions(): - ex: bool = self.es.exists(self.index, self.doc_type, - paper_id_v) + ex: bool = self.es.exists(self.index, self.doc_type, paper_id_v) return ex @classmethod def init_app(cls, app: object = None) -> None: """Set default configuration parameters for an application instance.""" config = get_application_config(app) - config.setdefault('ELASTICSEARCH_HOST', 'localhost') - config.setdefault('ELASTICSEARCH_PORT', '9200') - config.setdefault('ELASTICSEARCH_INDEX', 'arxiv') - config.setdefault('ELASTICSEARCH_USER', None) - config.setdefault('ELASTICSEARCH_PASSWORD', None) - config.setdefault('ELASTICSEARCH_MAPPING', - 'mappings/DocumentMapping.json') - config.setdefault('ELASTICSEARCH_VERIFY', 'true') + config.setdefault("ELASTICSEARCH_SERVICE_HOST", "localhost") + config.setdefault("ELASTICSEARCH_SERVICE_PORT", "9200") + config.setdefault("ELASTICSEARCH_INDEX", "arxiv") + config.setdefault("ELASTICSEARCH_USER", None) + config.setdefault("ELASTICSEARCH_PASSWORD", None) + config.setdefault( + "ELASTICSEARCH_MAPPING", "mappings/DocumentMapping.json" + ) + config.setdefault("ELASTICSEARCH_VERIFY", "true") @classmethod - def get_session(cls, app: object = None) -> 'SearchSession': + def get_session(cls, app: object = None) -> "SearchSession": """Get a new session with the search index.""" config = get_application_config(app) - host = config.get('ELASTICSEARCH_HOST', 'localhost') - port = config.get('ELASTICSEARCH_PORT', '9200') - scheme = config.get('ELASTICSEARCH_SCHEME', 'http') - index = config.get('ELASTICSEARCH_INDEX', 'arxiv') - verify = config.get('ELASTICSEARCH_VERIFY', 'true') == 'true' - user = config.get('ELASTICSEARCH_USER', None) - password = config.get('ELASTICSEARCH_PASSWORD', None) - 
mapping = config.get('ELASTICSEARCH_MAPPING', - 'mappings/DocumentMapping.json') - return cls(host, index, port, scheme, user, password, mapping, - verify=verify) + host = config.get("ELASTICSEARCH_SERVICE_HOST", "localhost") + port = config.get("ELASTICSEARCH_SERVICE_PORT", "9200") + scheme = config.get( + "ELASTICSEARCH_SERVICE_PORT_%s_PROTO" % port, "http" + ) + index = config.get("ELASTICSEARCH_INDEX", "arxiv") + verify = config.get("ELASTICSEARCH_VERIFY", "true") == "true" + user = config.get("ELASTICSEARCH_USER", None) + password = config.get("ELASTICSEARCH_PASSWORD", None) + mapping = config.get( + "ELASTICSEARCH_MAPPING", "mappings/DocumentMapping.json" + ) + return cls( + host, index, port, scheme, user, password, mapping, verify=verify + ) @classmethod - def current_session(cls) -> 'SearchSession': + def current_session(cls) -> "SearchSession": """Get/create :class:`.SearchSession` for this context.""" g = get_application_global() if not g: return cls.get_session() - if 'search' not in g: - g.search = cls.get_session() # type: ignore - return g.search # type: ignore + if "search" not in g: + g.search = cls.get_session() # type: ignore + return g.search # type: ignore def ok() -> bool: """Health check.""" try: SearchSession.current_session().cluster_available() - except Exception: # TODO: be more specific. + except Exception: # TODO: be more specific. 
return False return True diff --git a/search/services/index/advanced.py b/search/services/index/advanced.py index d33fa3d1..b2015355 100644 --- a/search/services/index/advanced.py +++ b/search/services/index/advanced.py @@ -1,17 +1,16 @@ """Supports the advanced search feature.""" -from typing import Any, Union - -from functools import reduce, wraps -from operator import ior, iand +from typing import Any, Tuple from elasticsearch_dsl import Search, Q, SF -from elasticsearch_dsl.query import Range, Match, Bool - -from search.domain import AdvancedQuery, Classification +from elasticsearch_dsl.query import Range, Match -from .prepare import SEARCH_FIELDS, limit_by_classification -from .util import sort +from search.domain import AdvancedQuery +from search.services.index.util import sort +from search.services.index.prepare import ( + SEARCH_FIELDS, + limit_by_classification, +) def advanced_search(search: Search, query: AdvancedQuery) -> Search: @@ -38,20 +37,22 @@ def advanced_search(search: Search, query: AdvancedQuery) -> Search: search = search.filter("term", is_current=True) _q_clsn = limit_by_classification(query.classification) if query.include_cross_list: - _q_clsn |= limit_by_classification(query.classification, - "secondary_classification") - q = ( - _fielded_terms_to_q(query) - & _date_range(query) - & _q_clsn - ) - if query.order is None or query.order == 'relevance': + _q_clsn |= limit_by_classification( + query.classification, "secondary_classification" + ) + q = _fielded_terms_to_q(query) & _date_range(query) & _q_clsn + if query.order is None or query.order == "relevance": # Boost the current version heavily when sorting by relevance. 
- q = Q('function_score', query=q, boost=5, boost_mode="multiply", - score_mode="max", - functions=[ - SF({'weight': 5, 'filter': Q('term', is_current=True)}) - ]) + q = Q( + "function_score", + query=q, + boost=5, + boost_mode="multiply", + score_mode="max", + functions=[ + SF({"weight": 5, "filter": Q("term", is_current=True)}) + ], + ) search = sort(query, search) search = search.query(q) return search @@ -63,17 +64,18 @@ def _date_range(q: AdvancedQuery) -> Range: return Q() params = {} if q.date_range.date_type == q.date_range.ANNOUNCED: - fmt = '%Y-%m' + fmt = "%Y-%m" else: - fmt = '%Y-%m-%dT%H:%M:%S%z' + fmt = "%Y-%m-%dT%H:%M:%S%z" if q.date_range.start_date: params["gte"] = q.date_range.start_date.strftime(fmt) if q.date_range.end_date: params["lt"] = q.date_range.end_date.strftime(fmt) - return Q('range', **{q.date_range.date_type: params}) + return Q("range", **{q.date_range.date_type: params}) -def _grouped_terms_to_q(term_pair: tuple) -> Q: +# FIXME: Argument type. +def _grouped_terms_to_q(term_pair: Tuple[Any, Any, Any]) -> Q: """Generate a :class:`.Q` from grouped terms.""" term_a_raw, operator, term_b_raw = term_pair @@ -87,11 +89,11 @@ def _grouped_terms_to_q(term_pair: tuple) -> Q: else: term_b = SEARCH_FIELDS[term_b_raw.field](term_b_raw.term) - if operator == 'OR': + if operator == "OR": return term_a | term_b - elif operator == 'AND': + elif operator == "AND": return term_a & term_b - elif operator == 'NOT': + elif operator == "NOT": return term_a & ~term_b else: # TODO: Confirm proper exception. @@ -101,27 +103,28 @@ def _grouped_terms_to_q(term_pair: tuple) -> Q: def _get_operator(obj: Any) -> str: if type(obj) is tuple: return _get_operator(obj[0]) - return obj.operator # type: ignore + return obj.operator # type: ignore -def _group_terms(query: AdvancedQuery) -> tuple: +# FIXME: Return type. 
+def _group_terms(query: AdvancedQuery) -> Tuple[Any, ...]: """Group fielded search terms into a set of nested tuples.""" terms = query.terms[:] - for operator in ['NOT', 'AND', 'OR']: + for operator in ["NOT", "AND", "OR"]: i = 0 while i < len(terms) - 1: - if _get_operator(terms[i+1]) == operator: - terms[i] = (terms[i], operator, terms[i+1]) - terms.pop(i+1) + if _get_operator(terms[i + 1]) == operator: + terms[i] = (terms[i], operator, terms[i + 1]) + terms.pop(i + 1) i -= 1 i += 1 assert len(terms) == 1 - return terms[0] # type: ignore + return terms[0] # type: ignore def _fielded_terms_to_q(query: AdvancedQuery) -> Match: if len(query.terms) == 1: return SEARCH_FIELDS[query.terms[0].field](query.terms[0].term) elif len(query.terms) > 1: - return _grouped_terms_to_q(_group_terms(query)) - return Q('match_all') + return _grouped_terms_to_q(_group_terms(query)) # type:ignore + return Q("match_all") diff --git a/search/services/index/api.py b/search/services/index/api.py index 1d76c122..d16352df 100644 --- a/search/services/index/api.py +++ b/search/services/index/api.py @@ -1,17 +1,19 @@ """Supports the advanced search feature.""" -from typing import Any, Union - -from functools import reduce, wraps -from operator import ior, iand +from operator import ior +from functools import reduce +from typing import Any, Tuple from elasticsearch_dsl import Search, Q, SF -from elasticsearch_dsl.query import Range, Match, Bool - -from search.domain import Classification, APIQuery +from elasticsearch_dsl.query import Range, Match -from .prepare import SEARCH_FIELDS, query_primary_exact, query_secondary_exact -from .util import sort +from search.domain import APIQuery +from search.services.index.util import sort +from search.services.index.prepare import ( + SEARCH_FIELDS, + query_primary_exact, + query_secondary_exact, +) def api_search(search: Search, query: APIQuery) -> Search: @@ -39,24 +41,27 @@ def api_search(search: Search, query: APIQuery) -> Search: _q_clsn = Q() if 
query.primary_classification: - _q_clsn &= reduce(ior, map(query_primary_exact, - list(query.primary_classification))) + _q_clsn &= reduce( + ior, map(query_primary_exact, list(query.primary_classification)) + ) if query.secondary_classification: for classification in query.secondary_classification: - _q_clsn &= reduce(ior, map(query_secondary_exact, - list(classification))) - q = ( - _fielded_terms_to_q(query) - & _date_range(query) - & _q_clsn - ) - if query.order is None or query.order == 'relevance': + _q_clsn &= reduce( + ior, map(query_secondary_exact, list(classification)) + ) + q = _fielded_terms_to_q(query) & _date_range(query) & _q_clsn + if query.order is None or query.order == "relevance": # Boost the current version heavily when sorting by relevance. - q = Q('function_score', query=q, boost=5, boost_mode="multiply", - score_mode="max", - functions=[ - SF({'weight': 5, 'filter': Q('term', is_current=True)}) - ]) + q = Q( + "function_score", + query=q, + boost=5, + boost_mode="multiply", + score_mode="max", + functions=[ + SF({"weight": 5, "filter": Q("term", is_current=True)}) + ], + ) search = sort(query, search) search = search.query(q) return search @@ -68,17 +73,17 @@ def _date_range(q: APIQuery) -> Range: return Q() params = {} if q.date_range.date_type == q.date_range.ANNOUNCED: - fmt = '%Y-%m' + fmt = "%Y-%m" else: - fmt = '%Y-%m-%dT%H:%M:%S%z' + fmt = "%Y-%m-%dT%H:%M:%S%z" if q.date_range.start_date: params["gte"] = q.date_range.start_date.strftime(fmt) if q.date_range.end_date: params["lt"] = q.date_range.end_date.strftime(fmt) - return Q('range', **{q.date_range.date_type: params}) + return Q("range", **{q.date_range.date_type: params}) -def _grouped_terms_to_q(term_pair: tuple) -> Q: +def _grouped_terms_to_q(term_pair: Tuple[Any, Any, Any]) -> Q: """Generate a :class:`.Q` from grouped terms.""" term_a_raw, operator, term_b_raw = term_pair @@ -92,11 +97,11 @@ def _grouped_terms_to_q(term_pair: tuple) -> Q: else: term_b = 
SEARCH_FIELDS[term_b_raw.field](term_b_raw.term) - if operator == 'OR': + if operator == "OR": return term_a | term_b - elif operator == 'AND': + elif operator == "AND": return term_a & term_b - elif operator == 'NOT': + elif operator == "NOT": return term_a & ~term_b else: # TODO: Confirm proper exception. @@ -106,27 +111,28 @@ def _grouped_terms_to_q(term_pair: tuple) -> Q: def _get_operator(obj: Any) -> str: if type(obj) is tuple: return _get_operator(obj[0]) - return obj.operator # type: ignore + return obj.operator # type: ignore -def _group_terms(query: APIQuery) -> tuple: +# FIXME: Return type. +def _group_terms(query: APIQuery) -> Tuple[Any, ...]: """Group fielded search terms into a set of nested tuples.""" terms = query.terms[:] - for operator in ['NOT', 'AND', 'OR']: + for operator in ["NOT", "AND", "OR"]: i = 0 while i < len(terms) - 1: - if _get_operator(terms[i+1]) == operator: - terms[i] = (terms[i], operator, terms[i+1]) - terms.pop(i+1) + if _get_operator(terms[i + 1]) == operator: + terms[i] = (terms[i], operator, terms[i + 1]) + terms.pop(i + 1) i -= 1 i += 1 assert len(terms) == 1 - return terms[0] # type: ignore + return terms[0] # type: ignore def _fielded_terms_to_q(query: APIQuery) -> Match: if len(query.terms) == 1: return SEARCH_FIELDS[query.terms[0].field](query.terms[0].term) elif len(query.terms) > 1: - return _grouped_terms_to_q(_group_terms(query)) - return Q('match_all') + return _grouped_terms_to_q(_group_terms(query)) # type:ignore + return Q("match_all") diff --git a/search/services/index/authors.py b/search/services/index/authors.py index 3f6cef18..b6e76212 100644 --- a/search/services/index/authors.py +++ b/search/services/index/authors.py @@ -1,16 +1,13 @@ """Query-builders and helpers for searching by author name.""" -from typing import Tuple, Optional, List import re -from functools import reduce, wraps +from functools import reduce from operator import ior, iand -from elasticsearch_dsl import Search, Q, SF +from 
elasticsearch_dsl import Q from arxiv.base import logging - -from .util import wildcard_escape, escape, STRING_LITERAL, \ - remove_single_characters, has_wildcard +from search.services.index.util import escape, STRING_LITERAL, has_wildcard logger = logging.getLogger(__name__) logger.propagate = False @@ -25,9 +22,12 @@ def _remove_stopwords(term: str) -> str: """Remove common stopwords, except in literal queries.""" parts = re.split(STRING_LITERAL, term) for stopword in STOP: - parts = [re.sub(f"(^|\s+){stopword}(\s+|$)", " ", part) - if not part.startswith('"') and not part.startswith("'") - else part for part in parts] + parts = [ + re.sub(fr"(^|\s+){stopword}(\s+|$)", " ", part) + if not part.startswith('"') and not part.startswith("'") + else part + for part in parts + ] return "".join(parts) @@ -61,10 +61,10 @@ def part_query(term: str, path: str = "authors") -> Q: AUTHOR_QUERY_FIELDS = [ f"{path}.full_name", f"{path}.last_name", - f"{path}.full_name_initialized" + f"{path}.full_name_initialized", ] term = term.strip() - logger.debug(f'{path} part_query for {term}') + logger.debug(f"{path} part_query for {term}") # Commas are used to distinguish surname and forename. forename_is_individuated = "," in term @@ -77,38 +77,46 @@ def part_query(term: str, path: str = "authors") -> Q: forename = " ".join(name_parts[1:]).strip() # Doing a query string so that wildcards and literals are just handled. - q_surname = Q("query_string", fields=[f"{path}.last_name"], - query=escape(surname), - default_operator='AND', - allow_leading_wildcard=False) + q_surname = Q( + "query_string", + fields=[f"{path}.last_name"], + query=escape(surname), + default_operator="AND", + allow_leading_wildcard=False, + ) if forename: # If a wildcard is provided in the forename, we treat it as a # query string query. This has the disadvantage of losing term # order, but the advantage of handling wildcards as expected. 
- logger.debug(f'Forename: {forename}') + logger.debug(f"Forename: {forename}") if has_wildcard(forename): - q_forename = Q("query_string", fields=[f"{path}.first_name"], - query=escape(forename), - auto_generate_phrase_queries=True, - default_operator='AND', - allow_leading_wildcard=False) + q_forename = Q( + "query_string", + fields=[f"{path}.first_name"], + query=escape(forename), + auto_generate_phrase_queries=True, + default_operator="AND", + allow_leading_wildcard=False, + ) # Otherwise, we expect the forename to match as a phrase. The # _prefix bit means that the last word can match as a prefix of the # corresponding term. else: - q_forename = Q("match_phrase_prefix", - **{f"{path}__first_name": forename}) + q_forename = Q( + "match_phrase_prefix", **{f"{path}__first_name": forename} + ) # It may be the case that the forename consists of initials or some # other prefix/partial forename. For a match of this kind, each # part of the forename part must be a prefix of a term in the # forename. - if path == 'authors' and forename: - logger.debug('Consider initials: %s', forename) - q_forename |= Q("match_phrase_prefix", - **{f"{path}__initials": forename}) + if path == "authors" and forename: + logger.debug("Consider initials: %s", forename) + q_forename |= Q( + "match_phrase_prefix", **{f"{path}__initials": forename} + ) # We will treat this as a search for a single author; surname and # forename parts must match in the same (nested) author. @@ -119,22 +127,31 @@ def part_query(term: str, path: str = "authors") -> Q: # Match across all fields within a single author. We don't know which # bits of the query match which bits of the author name. This will # handle wildcards, literals, etc. 
- q = Q("query_string", - fields=AUTHOR_QUERY_FIELDS, default_operator='AND', - allow_leading_wildcard=False, - type="cross_fields", query=escape(term)) - return Q("nested", path=path, query=q, score_mode='sum') + q = Q( + "query_string", + fields=AUTHOR_QUERY_FIELDS, + default_operator="AND", + allow_leading_wildcard=False, + type="cross_fields", + query=escape(term), + ) + return Q("nested", path=path, query=q, score_mode="sum") -def string_query(term: str, path: str = 'authors', operator: str = 'AND') -> Q: +def string_query(term: str, path: str = "authors", operator: str = "AND") -> Q: """Build a query that handles query strings within a single author.""" - q = Q("query_string", fields=[f"{path}.full_name"], - default_operator=operator, allow_leading_wildcard=False, - type="cross_fields", query=escape(term)) - return Q('nested', path=path, query=q, score_mode='sum') + q = Q( + "query_string", + fields=[f"{path}.full_name"], + default_operator=operator, + allow_leading_wildcard=False, + type="cross_fields", + query=escape(term), + ) + return Q("nested", path=path, query=q, score_mode="sum") -def author_query(term: str, operator: str = 'and') -> Q: +def author_query(term: str, operator: str = "and") -> Q: """ Construct a query based on author (and owner) names. @@ -174,23 +191,33 @@ def author_query(term: str, operator: str = 'and') -> Q: logger.debug(f"Contains literal: {term}") # Apply literal parts of the query separately. - return reduce(iand if operator.upper() == 'AND' else ior, [ - (string_query(part, operator=operator) - | string_query(part, path="owners", operator=operator)) - for part in re.split(STRING_LITERAL, term) if part.strip() - ]) + return reduce( + iand if operator.upper() == "AND" else ior, + [ + ( + string_query(part, operator=operator) + | string_query(part, path="owners", operator=operator) + ) + for part in re.split(STRING_LITERAL, term) + if part.strip() + ], + ) - term = term.replace('"', '') # Just ignore unbalanced quotes. 
+ term = term.replace('"', "") # Just ignore unbalanced quotes. - if ";" in term: # Authors are individuated. + if ";" in term: # Authors are individuated. logger.debug(f"Authors are individuated: {term}") logger.debug(f"Operator: {operator}") - return reduce(iand if operator.upper() == "AND" else ior, [ - (part_query(author_part) | part_query(author_part, "owners")) - for author_part in term.split(";") if author_part - ]) + return reduce( + iand if operator.upper() == "AND" else ior, + [ + (part_query(author_part) | part_query(author_part, "owners")) + for author_part in term.split(";") + if author_part + ], + ) - if "," in term: # Forename is individuated. + if "," in term: # Forename is individuated. logger.debug(f"Forename is individuated: {term}") return part_query(term) | part_query(term, "owners") @@ -201,51 +228,85 @@ def author_query(term: str, operator: str = 'and') -> Q: # # A query_string query on the combined field will yield matches among # authors. - q = Q('query_string', fields=['authors_combined'], - query=escape(term, quotes=True), - default_operator='and') + q = Q( + "query_string", + fields=["authors_combined"], + query=escape(term, quotes=True), + default_operator="and", + ) # A nested query_string query on full name will match within individual # authors. 
- q |= ( - Q('nested', path='authors', score_mode='sum', - query=Q("query_string", fields=['authors.full_name'], - default_operator=operator, allow_leading_wildcard=False, - query=escape(term, quotes=True))) - | Q('nested', path='owners', score_mode='sum', - query=Q("query_string", fields=['owners.full_name'], - default_operator=operator, allow_leading_wildcard=False, - query=escape(term, quotes=True))) + q |= Q( + "nested", + path="authors", + score_mode="sum", + query=Q( + "query_string", + fields=["authors.full_name"], + default_operator=operator, + allow_leading_wildcard=False, + query=escape(term, quotes=True), + ), + ) | Q( + "nested", + path="owners", + score_mode="sum", + query=Q( + "query_string", + fields=["owners.full_name"], + default_operator=operator, + allow_leading_wildcard=False, + query=escape(term, quotes=True), + ), ) return q -def author_id_query(term: str, operator: str = 'and') -> Q: +def author_id_query(term: str, operator: str = "and") -> Q: """Generate a query part for Author ID using the ES DSL.""" - term = term.lower() # Just in case. - if operator == 'or': - return ( - Q("nested", path="owners", - query=Q("terms", **{"owners__author_id": term.split()})) - | Q("terms", **{"submitter__author_id": term.split()}) - ) - return reduce(iand, [( - Q("nested", path="owners", - query=Q("term", **{"owners__author_id": part})) - | Q("term", **{"submitter__author_id": part}) - ) for part in term.split()]) + term = term.lower() # Just in case. 
+ if operator == "or": + return Q( + "nested", + path="owners", + query=Q("terms", **{"owners__author_id": term.split()}), + ) | Q("terms", **{"submitter__author_id": term.split()}) + return reduce( + iand, + [ + ( + Q( + "nested", + path="owners", + query=Q("term", **{"owners__author_id": part}), + ) + | Q("term", **{"submitter__author_id": part}) + ) + for part in term.split() + ], + ) -def orcid_query(term: str, operator: str = 'and') -> Q: +def orcid_query(term: str, operator: str = "and") -> Q: """Generate a query part for ORCID ID using the ES DSL.""" - if operator == 'or': - return ( - Q("nested", path="owners", - query=Q("terms", **{"owners__orcid": term.split()})) - | Q("terms", **{"submitter__orcid": term.split()}) - ) - return reduce(iand, [( - Q("nested", path="owners", - query=Q("term", **{"owners__orcid": part})) - | Q("term", **{"submitter__orcid": part}) - ) for part in term.split()]) + if operator == "or": + return Q( + "nested", + path="owners", + query=Q("terms", **{"owners__orcid": term.split()}), + ) | Q("terms", **{"submitter__orcid": term.split()}) + return reduce( + iand, + [ + ( + Q( + "nested", + path="owners", + query=Q("term", **{"owners__orcid": part}), + ) + | Q("term", **{"submitter__orcid": part}) + ) + for part in term.split() + ], + ) diff --git a/search/services/index/classic_api/__init__.py b/search/services/index/classic_api/__init__.py new file mode 100644 index 00000000..82fa9b83 --- /dev/null +++ b/search/services/index/classic_api/__init__.py @@ -0,0 +1,3 @@ +__all__ = ["classic_search"] + +from search.services.index.classic_api.classic_search import classic_search diff --git a/search/services/index/classic_api/classic_search.py b/search/services/index/classic_api/classic_search.py new file mode 100644 index 00000000..a6a1d1f9 --- /dev/null +++ b/search/services/index/classic_api/classic_search.py @@ -0,0 +1,62 @@ +"""Translate classic API `Phrase` objects to Elasticsearch DSL.""" +import re + +from elasticsearch_dsl import 
Q, Search + +from search.domain import ClassicAPIQuery, SortOrder +from search.services.index.classic_api.query_builder import query_builder + +# FIXME: Use arxiv identifier parsing from arxiv.base when it's ready. +# Also this allows version to start with 0 to mimic the old API. +ENDS_WITH_VERSION = re.compile(r".*v\d+$") + + +def classic_search(search: Search, query: ClassicAPIQuery) -> Search: + """ + Prepare a :class:`.Search` from a :class:`.ClassicAPIQuery`. + + Parameters + ---------- + search : :class:`.Search` + An Elasticsearch search in preparation. + query : :class:`.ClassicAPIQuery` + An query originating from the Classic API. + + Returns + ------- + :class:`.Search` + The passed ES search object, updated with specific query parameters + that implement the advanced query. + + """ + # Initialize query. + if query.phrase: + dsl_query = query_builder(query.phrase) + else: + dsl_query = Q() + + # Filter id_list if necessary. + if query.id_list: + # Separate versioned and unversioned papers. + + paper_ids = [] + paper_ids_vs = [] + for paper_id in query.id_list: + if ENDS_WITH_VERSION.match(paper_id): + paper_ids_vs.append(paper_id) + else: + paper_ids.append(paper_id) + + # Filter by most recent unversioned paper or any versioned paper. + id_query = ( + Q("terms", paper_id=paper_ids) & Q("term", is_current=True) + ) | Q("terms", paper_id_v=paper_ids_vs) + + search = search.filter(id_query) + else: + # If no id_list, only display current results. 
+ search = search.filter("term", is_current=True) + + if not isinstance(query, SortOrder): + return search.query(dsl_query) + return search.query(dsl_query).sort(*query.order.to_es()) # type: ignore diff --git a/search/services/index/classic_api/query_builder.py b/search/services/index/classic_api/query_builder.py new file mode 100644 index 00000000..ad715c31 --- /dev/null +++ b/search/services/index/classic_api/query_builder.py @@ -0,0 +1,55 @@ +from typing import Dict, Callable + +from elasticsearch_dsl import Q + +from search.domain import Phrase, Term, Field, Operator +from search.services.index.prepare import ( + SEARCH_FIELDS, + query_any_subject_exact_raw, +) + +FIELD_TERM_MAPPING: Dict[Field, Callable[[str], Q]] = { + Field.Author: SEARCH_FIELDS["author"], + Field.Comment: SEARCH_FIELDS["comments"], + Field.Identifier: SEARCH_FIELDS["paper_id"], + Field.JournalReference: SEARCH_FIELDS["journal_ref"], + Field.ReportNumber: SEARCH_FIELDS["report_num"], + # Expects to match on primary or secondary category. + Field.SubjectCategory: query_any_subject_exact_raw, + Field.Title: SEARCH_FIELDS["title"], + Field.All: SEARCH_FIELDS["all"], +} + + +def term_to_query(term: Term) -> Q: + """ + Parses a fielded term using transfromations from the current API. + + See Also + -------- + :module:`.api` + """ + + return Q() if term.is_empty else FIELD_TERM_MAPPING[term.field](term.value) + + +def query_builder(phrase: Phrase) -> Q: + """Parses a Phrase of a Classic API request into an ES Q object.""" + if isinstance(phrase, Term): + return term_to_query(phrase) + elif len(phrase) == 2: + # This is unary ANDNOT which is just NOT + return ~term_to_query(phrase[1]) + elif len(phrase) == 3: + binary_op, exp1, exp2 = phrase[:3] # type:ignore + q1 = query_builder(exp1) + q2 = query_builder(exp2) + if binary_op is Operator.AND: + return q1 & q2 + elif binary_op is Operator.OR: + return q1 | q2 + elif binary_op is Operator.ANDNOT: + return q1 & (~q2) + else: + # Error? 
+ return Q() diff --git a/search/services/index/classic_api/tests/__init__.py b/search/services/index/classic_api/tests/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/search/services/index/classic_api/tests/test_query_builder.py b/search/services/index/classic_api/tests/test_query_builder.py new file mode 100644 index 00000000..f4b5e373 --- /dev/null +++ b/search/services/index/classic_api/tests/test_query_builder.py @@ -0,0 +1,192 @@ +from typing import List +from unittest import TestCase +from dataclasses import dataclass + +from elasticsearch_dsl import Q + +from search.domain import Field, Operator, Phrase, Term +from search.services.index.classic_api.query_builder import ( + query_builder, + FIELD_TERM_MAPPING as FTM, +) + + +@dataclass +class Case: + message: str + phrase: Phrase + query: Q + + +TEST_CASES: List[Case] = [ + Case(message="Empty query", phrase=Term(Field.All, ""), query=Q()), + Case( + message="Empty query in conjunction.", + phrase=( + Operator.AND, + Term(Field.All, ""), + Term(Field.All, "electron"), + ), + query=FTM[Field.All]("electron"), + ), + Case( + message="Double empty query in conjunction.", + phrase=(Operator.AND, Term(Field.All, ""), Term(Field.All, ""),), + query=Q(), + ), + Case( + message="Simple query without grouping/nesting.", + phrase=Term(Field.Author, "copernicus"), + query=FTM[Field.Author]("copernicus"), + ), + Case( + message="Simple query with quotations.", + phrase=Term(Field.Title, "dark matter"), + query=FTM[Field.Title]("dark matter"), + ), + Case( + message="Simple conjunct AND query.", + phrase=( + Operator.AND, + Term(Field.Author, "del_maestro"), + Term(Field.Title, "checkerboard"), + ), + query=( + FTM[Field.Author]("del_maestro") & FTM[Field.Title]("checkerboard") + ), + ), + Case( + message="Simple conjunct OR query.", + phrase=( + Operator.OR, + Term(Field.Author, "del_maestro"), + Term(Field.Title, "checkerboard"), + ), + query=( + FTM[Field.Author]("del_maestro") | 
FTM[Field.Title]("checkerboard") + ), + ), + Case( + message="Simple conjunct ANDNOT query.", + phrase=( + Operator.ANDNOT, + Term(Field.Author, "del_maestro"), + Term(Field.Title, "checkerboard"), + ), + query=( + FTM[Field.Author]("del_maestro") + & (~FTM[Field.Title]("checkerboard")) + ), + ), + Case( + message="Simple conjunct query with quoted field.", + phrase=( + Operator.AND, + Term(Field.Author, "del_maestro"), + Term(Field.Title, "dark matter"), + ), + query=( + FTM[Field.Author]("del_maestro") & FTM[Field.Title]("dark matter") + ), + ), + Case( + message="Disjunct query with an unary not.", + phrase=( + Operator.OR, + Term(Field.Author, "del_maestro"), + (Operator.ANDNOT, Term(Field.Title, "checkerboard")), + ), + query=( + FTM[Field.Author]("del_maestro") + | (~FTM[Field.Title]("checkerboard")) + ), + ), + Case( + message="Conjunct query with nested disjunct query.", + phrase=( + Operator.ANDNOT, + Term(Field.Author, "del_maestro"), + ( + Operator.OR, + Term(Field.Title, "checkerboard"), + Term(Field.Title, "Pyrochlore"), + ), + ), + query=FTM[Field.Author]("del_maestro") + & ( + ~( + FTM[Field.Title]("checkerboard") + | FTM[Field.Title]("Pyrochlore") + ) + ), + ), + Case( + message="Conjunct query with nested disjunct query.", + phrase=( + Operator.ANDNOT, + ( + Operator.OR, + Term(Field.Author, "del_maestro"), + Term(Field.Author, "bob"), + ), + ( + Operator.OR, + Term(Field.Title, "checkerboard"), + Term(Field.Title, "Pyrochlore"), + ), + ), + query=( + (FTM[Field.Author]("del_maestro") | FTM[Field.Author]("bob")) + & ( + ~( + FTM[Field.Title]("checkerboard") + | FTM[Field.Title]("Pyrochlore") + ) + ) + ), + ), + Case( + message="Conjunct ANDNOT query with nested disjunct query.", + phrase=( + Operator.ANDNOT, + ( + Operator.OR, + Term(Field.Title, "checkerboard"), + Term(Field.Title, "Pyrochlore"), + ), + Term(Field.Author, "del_maestro"), + ), + query=( + (FTM[Field.Title]("checkerboard") | FTM[Field.Title]("Pyrochlore")) + & 
(~FTM[Field.Author]("del_maestro")) + ), + ), + Case( + message="Conjunct AND query with nested disjunct query.", + phrase=( + Operator.AND, + ( + Operator.OR, + Term(Field.Title, "checkerboard"), + Term(Field.Title, "Pyrochlore"), + ), + ( + Operator.OR, + Term(Field.Author, "del_maestro"), + Term(Field.Author, "hawking"), + ), + ), + query=( + (FTM[Field.Title]("checkerboard") | FTM[Field.Title]("Pyrochlore")) + & (FTM[Field.Author]("del_maestro") | FTM[Field.Author]("hawking")) + ), + ), +] + + +class TestQueryBuilder(TestCase): + def test_query_builder(self): + for case in TEST_CASES: + self.assertEqual( + query_builder(case.phrase), case.query, msg=case.message + ) diff --git a/search/services/index/exceptions.py b/search/services/index/exceptions.py index b437c381..776b47d9 100644 --- a/search/services/index/exceptions.py +++ b/search/services/index/exceptions.py @@ -1,7 +1,13 @@ """Exceptions raised by the search index service.""" -__all__ = ('MappingError', 'IndexConnectionError', 'IndexingError', - 'QueryError', 'DocumentNotFound', 'OutsideAllowedRange') +__all__ = ( + "MappingError", + "IndexConnectionError", + "IndexingError", + "QueryError", + "DocumentNotFound", + "OutsideAllowedRange", +) class MappingError(ValueError): diff --git a/search/services/index/highlighting.py b/search/services/index/highlighting.py index a295b2e7..6918c60f 100644 --- a/search/services/index/highlighting.py +++ b/search/services/index/highlighting.py @@ -11,19 +11,20 @@ import re from typing import Any, Union -from elasticsearch_dsl import Search, Q, SF -from elasticsearch_dsl.response import Response, Hit -import bleach from flask import escape -from arxiv.base import logging +from elasticsearch_dsl import Search +from elasticsearch_dsl.response import Response, Hit +from arxiv.base import logging from search.domain import Document -from .util import TEXISM +from search.services.index.util import TEXISM + logger = logging.getLogger(__name__) + HIGHLIGHT_TAG_OPEN = '' 
-HIGHLIGHT_TAG_CLOSE = '' +HIGHLIGHT_TAG_CLOSE = "" def highlight(search: Search) -> Search: @@ -43,38 +44,41 @@ def highlight(search: Search) -> Search: """ # Highlight class .search-hit defined in search.sass search = search.highlight_options( - pre_tags=[HIGHLIGHT_TAG_OPEN], - post_tags=[HIGHLIGHT_TAG_CLOSE] + pre_tags=[HIGHLIGHT_TAG_OPEN], post_tags=[HIGHLIGHT_TAG_CLOSE] ) - search = search.highlight('title', type='plain', number_of_fragments=0) - search = search.highlight('title.english', type='plain', - number_of_fragments=0) - search = search.highlight('title.tex', type='plain', - number_of_fragments=0) + search = search.highlight("title", type="plain", number_of_fragments=0) + search = search.highlight( + "title.english", type="plain", number_of_fragments=0 + ) + search = search.highlight("title.tex", type="plain", number_of_fragments=0) - search = search.highlight('comments', number_of_fragments=0) + search = search.highlight("comments", number_of_fragments=0) # Highlight any field the name of which begins with "author". 
- search = search.highlight('author*') - search = search.highlight('owner*') - search = search.highlight('announced_date_first') - search = search.highlight('submitter*') - search = search.highlight('journal_ref', type='plain') - search = search.highlight('acm_class', number_of_fragments=0) - search = search.highlight('msc_class', number_of_fragments=0) - search = search.highlight('doi', type='plain') - search = search.highlight('report_num', type='plain') + search = search.highlight("author*") + search = search.highlight("owner*") + search = search.highlight("announced_date_first") + search = search.highlight("submitter*") + search = search.highlight("journal_ref", type="plain") + search = search.highlight("acm_class", number_of_fragments=0) + search = search.highlight("msc_class", number_of_fragments=0) + search = search.highlight("doi", type="plain") + search = search.highlight("report_num", type="plain") # Setting number_of_fragments to 0 tells ES to highlight the entire field. - search = search.highlight('abstract', number_of_fragments=0) - search = search.highlight('abstract.tex', type='plain', - number_of_fragments=0) - search = search.highlight('abstract.english', number_of_fragments=0) + search = search.highlight("abstract", number_of_fragments=0) + search = search.highlight( + "abstract.tex", type="plain", number_of_fragments=0 + ) + search = search.highlight("abstract.english", number_of_fragments=0) return search -def preview(value: str, fragment_size: int = 400, - start_tag: str = HIGHLIGHT_TAG_OPEN, - end_tag: str = HIGHLIGHT_TAG_CLOSE) -> str: +def preview( + value: str, + fragment_size: int = 400, + start_tag: str = HIGHLIGHT_TAG_OPEN, + end_tag: str = HIGHLIGHT_TAG_CLOSE, +) -> str: """ Generate a snippet preview that doesn't breaking TeXisms or highlighting. 
@@ -109,13 +113,13 @@ def preview(value: str, fragment_size: int = 400, c = value[start - 1] s = start while start - s < start_frag_size and s > 0: - if c in '$>': # This may or may not be an actual HTML tag or TeX. - break # But it doesn't hurt to play it safe. + if c in "$>": # This may or may not be an actual HTML tag or TeX. + break # But it doesn't hurt to play it safe. s -= 1 c = value[s - 1] start = s # Move the start forward slightly, to find a word boundary. - while c not in '.,!? \t\n$<' and start > 0: + while c not in ".,!? \t\n$<" and start > 0: start += 1 c = value[start - 1] else: @@ -127,8 +131,9 @@ def preview(value: str, fragment_size: int = 400, # Jump the end forward until we consume (as much as possible of) the # rest of the target fragment size. remaining = max(0, fragment_size - (end - start)) - end += _end_safely(value[end:], remaining, start_tag=start_tag, - end_tag=end_tag) + end += _end_safely( + value[end:], remaining, start_tag=start_tag, end_tag=end_tag + ) snippet = value[start:end].strip() last_open = snippet.rfind(HIGHLIGHT_TAG_OPEN) last_close = snippet.rfind(HIGHLIGHT_TAG_CLOSE) @@ -136,9 +141,9 @@ def preview(value: str, fragment_size: int = 400, if last_open > last_close and last_open >= 0: snippet += HIGHLIGHT_TAG_CLOSE snippet = ( - ('…' if start > 0 else '') + ("…" if start > 0 else "") + snippet - + ('…' if end < len(value) else '') + + ("…" if end < len(value) else "") ) return snippet @@ -162,33 +167,32 @@ def add_highlighting(result: Document, raw: Union[Response, Hit]) -> Document: """ # There may or may not be highlighting in the result set. - highlighted_fields = getattr(raw.meta, 'highlight', None) + highlighted_fields = getattr(raw.meta, "highlight", None) # ``meta.matched_queries`` contains a list of query ``_name``s that # matched. This is nice for non-string fields. 
- matched_fields = getattr(raw.meta, 'matched_queries', []) + matched_fields = getattr(raw.meta, "matched_queries", []) # These are from hits within child documents, e.g. # secondary_classification. - inner_hits = getattr(raw.meta, 'inner_hits', None) + inner_hits = getattr(raw.meta, "inner_hits", None) # The values here will (almost) always be list-like. So we need to stitch # them together. Note that dir(None) won't return anything, so this block # is skipped if there are no highlights from ES. for field in dir(highlighted_fields): - if field.startswith('_'): + if field.startswith("_"): continue value = getattr(highlighted_fields, field) - if hasattr(value, '__iter__'): - value = '…'.join(value) + if hasattr(value, "__iter__"): + value = "…".join(value) # Non-TeX searches may hit inside of TeXisms. Highlighting those # fragments (i.e. inserting HTML) will break MathJax rendering. # To guard against this while preserving highlighting, we move # any highlighting tags from within TeXisms to encapsulate the # entire TeXism. - if field in ['title', 'title.english', - 'abstract', 'abstract.english']: + if field in ["title", "title.english", "abstract", "abstract.english"]: value = _highlight_whole_texism(value) value = _escape(value) @@ -197,46 +201,49 @@ def add_highlighting(result: Document, raw: Union[Response, Hit]) -> Document: # truncated. So instead of highlighting author names themselves, we # set a 'flag' that can get picked up in the template and highlight # the entire author field. 
- if field.startswith('author') or field.startswith('owner') \ - or field.startswith('submitter'): - result['match']['author'] = True + if ( + field.startswith("author") + or field.startswith("owner") + or field.startswith("submitter") + ): + result["match"]["author"] = True continue - result['highlight'][field] = value + result["highlight"][field] = value for field in matched_fields: - if field not in result['highlight']: - result['match'][field] = True + if field not in result["highlight"]: + result["match"][field] = True # We're using inner_hits to see which category in particular responded to # the query. - if hasattr(inner_hits, 'secondary_classification'): - result['match']['secondary_classification'] = [ + if hasattr(inner_hits, "secondary_classification"): + result["match"]["secondary_classification"] = [ ih.category.id for ih in inner_hits.secondary_classification ] # We just want to know whether there was a hit on the announcement date. - result['match']['announced_date_first'] = ( - bool('announced_date_first' in matched_fields) + result["match"]["announced_date_first"] = bool( + "announced_date_first" in matched_fields ) # If there is a hit in a TeX field, we prefer highlighting on that # field, since other tokenizers will clobber the TeX. 
- for field in ['abstract', 'title']: - if f'{field}.tex' in result['highlight']: - result['highlight'][field] = result['highlight'][f'{field}.tex'] - del result['highlight'][f'{field}.tex'] - - for field in ['abstract.tex', 'abstract.english', 'abstract']: - if field in result['highlight']: - value = result['highlight'][field] + for field in ["abstract", "title"]: + if f"{field}.tex" in result["highlight"]: + result["highlight"][field] = result["highlight"][f"{field}.tex"] + del result["highlight"][f"{field}.tex"] + + for field in ["abstract.tex", "abstract.english", "abstract"]: + if field in result["highlight"]: + value = result["highlight"][field] abstract_snippet = preview(value) - result['preview']['abstract'] = abstract_snippet - result['highlight']['abstract'] = value + result["preview"]["abstract"] = abstract_snippet + result["highlight"]["abstract"] = value break - for field in ['title.english', 'title']: - if field in result['highlight']: - result['highlight']['title'] = result['highlight'][field] + for field in ["title.english", "title"]: + if field in result["highlight"]: + result["highlight"]["title"] = result["highlight"][field] break return result @@ -250,7 +257,7 @@ def _strip_highlight_and_enclose(match: Any) -> str: value = value.replace(HIGHLIGHT_TAG_CLOSE, "") # If HTML was removed, we will assume that it was highlighting HTML. 
# if len(new_value) < len(value): - value = f'{HIGHLIGHT_TAG_OPEN}{value}{HIGHLIGHT_TAG_CLOSE}' + value = f"{HIGHLIGHT_TAG_OPEN}{value}{HIGHLIGHT_TAG_CLOSE}" return value @@ -281,70 +288,82 @@ def _escape(value: str) -> str: break if i_o is not None and i_c is not None: if i_o < i_c: - _sub = str(escape(value[i:i + i_o])) + tag_o + _sub = str(escape(value[i : i + i_o])) + tag_o i += i_o + len(tag_o) elif i_c < i_o: - _sub = str(escape(value[i:i + i_c])) + tag_c + _sub = str(escape(value[i : i + i_c])) + tag_c i += i_c + len(tag_c) elif i_o is not None and i_c is None: - _sub = str(escape(value[i:i + i_o])) + tag_o + _sub = str(escape(value[i : i + i_o])) + tag_o i += i_o + len(tag_o) elif i_c is not None and i_o is None: - _sub = str(escape(value[i:i + i_c])) + tag_c + _sub = str(escape(value[i : i + i_c])) + tag_c i += i_c + len(tag_c) _new += _sub return _new -def _start_safely(value: str, start: int, end: int, fragment_size: int, - tolerance: int = 0, start_tag: str = HIGHLIGHT_TAG_OPEN, - end_tag: str = HIGHLIGHT_TAG_CLOSE) -> int: +def _start_safely( + value: str, + start: int, + end: int, + fragment_size: int, + tolerance: int = 0, + start_tag: str = HIGHLIGHT_TAG_OPEN, + end_tag: str = HIGHLIGHT_TAG_CLOSE, +) -> int: # Try to maximize the length of the fragment up to the fragment_size, but # avoid starting in the middle of a tag or a TeXism. space_remaining = (fragment_size + tolerance) - (end - start) - remainder = value[start - fragment_size:start] - acceptable = value[start - fragment_size - tolerance:start] + remainder = value[start - fragment_size : start] + acceptable = value[start - fragment_size - tolerance : start] if end_tag in remainder: # Relative index of the first end tag. 
- first_end_tag = value[start - space_remaining:start].index(end_tag) - if start_tag in value[start - space_remaining:first_end_tag]: - target_area = value[start - space_remaining:first_end_tag] + first_end_tag = value[start - space_remaining : start].index(end_tag) + if start_tag in value[start - space_remaining : first_end_tag]: + target_area = value[start - space_remaining : first_end_tag] first_start_tag = target_area.index(start_tag) return (start - space_remaining) + first_start_tag - elif '$' in remainder: + elif "$" in remainder: m = TEXISM.search(acceptable) - if m is None: # Can't get to opening - return start - remainder[::-1].index('$') + 1 + if m is None: # Can't get to opening + return start - remainder[::-1].index("$") + 1 return (start - fragment_size - tolerance) + m.start() # Ideally, we hit the fragment size without entering a tag or TeXism. return start - fragment_size -def _end_safely(value: str, remaining: int, - start_tag: str = HIGHLIGHT_TAG_OPEN, - end_tag: str = HIGHLIGHT_TAG_CLOSE) -> int: +def _end_safely( + value: str, + remaining: int, + start_tag: str = HIGHLIGHT_TAG_OPEN, + end_tag: str = HIGHLIGHT_TAG_CLOSE, +) -> int: """Find a fragment end that doesn't break TeXisms or HTML.""" # Should match on either a TeXism or a TeXism enclosed in highlight tags. # TeXisms may be enclosed in pairs of $$ or $. - ptn = r'|'.join([ - r'([\$]{2}[^\$]+[\$]{2})', - r'([\$]{1}[^\$]+[\$]{1})', - r'(%s[\$]{2}[^\$]+[\$]{2}%s)' % (start_tag, end_tag), - r'(%s[\$]{1}[^\$]+[\$]{1}%s)' % (start_tag, end_tag), - r'(%s[^\$]+%s)' % (start_tag, end_tag) - ]) + ptn = r"|".join( + [ + r"([\$]{2}[^\$]+[\$]{2})", + r"([\$]{1}[^\$]+[\$]{1})", + r"(%s[\$]{2}[^\$]+[\$]{2}%s)" % (start_tag, end_tag), + r"(%s[\$]{1}[^\$]+[\$]{1}%s)" % (start_tag, end_tag), + r"(%s[^\$]+%s)" % (start_tag, end_tag), + ] + ) m = re.search(ptn, value) - if m is None: # Nothing to worry about; the coast is clear. + if m is None: # Nothing to worry about; the coast is clear. 
return remaining ptn_start = m.start() ptn_end = m.end() if remaining <= ptn_start: # The ideal end falls before the next TeX/tag. return remaining - elif ptn_end < remaining: # The ideal end falls after the next TeX/tag. - return ptn_end + _end_safely(value[ptn_end:], remaining - ptn_end, - start_tag, end_tag) + elif ptn_end < remaining: # The ideal end falls after the next TeX/tag. + return ptn_end + _end_safely( + value[ptn_end:], remaining - ptn_end, start_tag, end_tag + ) # We can't make it past the end of the next TeX/tag without exceeding the # target fragment size, so we will end at the beginning of the match. diff --git a/search/services/index/prepare.py b/search/services/index/prepare.py index cda9ff7a..d387b428 100644 --- a/search/services/index/prepare.py +++ b/search/services/index/prepare.py @@ -7,25 +7,34 @@ See :func:`._query_all_fields` for information on how results are scored. """ -from typing import Any, List, Tuple, Callable, Dict, Optional -from functools import reduce, wraps -from operator import ior, iand import re +from functools import reduce from datetime import datetime -from string import punctuation +from operator import ior, iand +from typing import List, Callable, Dict, Optional -from elasticsearch_dsl import Search, Q, SF +from elasticsearch_dsl import Q, SF from arxiv.base import logging -from search.domain import SimpleQuery, Query, AdvancedQuery, Classification, \ - ClassificationList -from .util import strip_tex, Q_, is_tex_query, is_literal_query, escape, \ - wildcard_escape, remove_single_characters, has_wildcard, is_old_papernum, \ - parse_date, parse_date_partial - -from .highlighting import HIGHLIGHT_TAG_OPEN, HIGHLIGHT_TAG_CLOSE -from .authors import author_query, author_id_query, orcid_query +from search.domain import Classification, ClassificationList +from search.services.index.util import ( + Q_, + is_tex_query, + is_literal_query, + escape, + wildcard_escape, + has_wildcard, + is_old_papernum, + parse_date, + 
parse_date_partial, +) + +from search.services.index.authors import ( + author_query, + author_id_query, + orcid_query, +) logger = logging.getLogger(__name__) @@ -33,63 +42,88 @@ END_YEAR = datetime.now().year -def _query_title(term: str, default_operator: str = 'AND') -> Q: +def _query_title(term: str, default_operator: str = "AND") -> Q: if is_tex_query(term): - return Q("match", **{f'title.tex': {'query': term}}) - fields = ['title.english'] + return Q("match", **{f"title.tex": {"query": term}}) + fields = ["title.english"] if is_literal_query(term): - fields += ['title'] - return Q("query_string", fields=fields, default_operator=default_operator, - allow_leading_wildcard=False, query=escape(term)) + fields += ["title"] + return Q( + "query_string", + fields=fields, + default_operator=default_operator, + allow_leading_wildcard=False, + query=escape(term), + ) -def _query_abstract(term: str, default_operator: str = 'AND') -> Q: +def _query_abstract(term: str, default_operator: str = "AND") -> Q: fields = ["abstract.english"] if is_literal_query(term): fields += ["abstract"] - return Q("query_string", fields=fields, default_operator=default_operator, - allow_leading_wildcard=False, query=escape(term), - _name="abstract") + return Q( + "query_string", + fields=fields, + default_operator=default_operator, + allow_leading_wildcard=False, + query=escape(term), + _name="abstract", + ) -def _query_comments(term: str, default_operator: str = 'AND') -> Q: - return Q("query_string", fields=["comments"], - default_operator=default_operator, - allow_leading_wildcard=False, query=escape(term)) +def _query_comments(term: str, default_operator: str = "AND") -> Q: + return Q( + "query_string", + fields=["comments"], + default_operator=default_operator, + allow_leading_wildcard=False, + query=escape(term), + ) -def _tex_query(field: str, term: str, operator: str = 'and') -> Q: - return Q("match", - **{f'{field}.tex': {'query': term, 'operator': operator}}) +def _tex_query(field: 
str, term: str, operator: str = "and") -> Q: + return Q( + "match", **{f"{field}.tex": {"query": term, "operator": operator}} + ) -def _query_journal_ref(term: str, boost: int = 1, operator: str = 'and') -> Q: - return Q("query_string", fields=["journal_ref"], default_operator=operator, - allow_leading_wildcard=False, query=escape(term)) +def _query_journal_ref(term: str, boost: int = 1, operator: str = "and") -> Q: + return Q( + "query_string", + fields=["journal_ref"], + default_operator=operator, + allow_leading_wildcard=False, + query=escape(term), + ) -def _query_report_num(term: str, boost: int = 1, operator: str = 'and') -> Q: - return Q("query_string", fields=["report_num"], default_operator=operator, - allow_leading_wildcard=False, query=escape(term)) +def _query_report_num(term: str, boost: int = 1, operator: str = "and") -> Q: + return Q( + "query_string", + fields=["report_num"], + default_operator=operator, + allow_leading_wildcard=False, + query=escape(term), + ) -def _query_acm_class(term: str, operator: str = 'and') -> Q: +def _query_acm_class(term: str, operator: str = "and") -> Q: if has_wildcard(term): return Q("wildcard", acm_class=term) return Q("match", acm_class={"query": term, "operator": operator}) -def _query_msc_class(term: str, operator: str = 'and') -> Q: +def _query_msc_class(term: str, operator: str = "and") -> Q: if has_wildcard(term): return Q("wildcard", msc_class=term) return Q("match", msc_class={"query": term, "operator": operator}) -def _query_doi(term: str, operator: str = 'and') -> Q: +def _query_doi(term: str, operator: str = "and") -> Q: value, wildcard = wildcard_escape(term) if wildcard: - return Q('wildcard', doi={'value': term.lower()}) - return Q('match', doi={'query': term, 'operator': operator}) + return Q("wildcard", doi={"value": term.lower()}) + return Q("match", doi={"query": term, "operator": operator}) def _query_announcement_date(term: str) -> Optional[Q]: @@ -99,81 +133,121 @@ def 
_query_announcement_date(term: str) -> Optional[Q]: If ``term`` looks like a year, will use a range search for all months in that year. If it looks like a year-month combo, will match. """ - year_match = re.match(r'^([0-9]{4})$', term) # Looks like a year. + year_match = re.match(r"^([0-9]{4})$", term) # Looks like a year. if year_match and END_YEAR >= int(year_match.group(1)) >= START_YEAR: - _range = {'gte': f'{term}-01', 'lte': f'{term}-12'} - return Q('range', announced_date_first=_range) + _range = {"gte": f"{term}-01", "lte": f"{term}-12"} + return Q("range", announced_date_first=_range) - month_match = re.match(r'^([0-9]{4})-([0-9]{2})$', term) # yyyy-MM. + month_match = re.match(r"^([0-9]{4})-([0-9]{2})$", term) # yyyy-MM. if month_match and END_YEAR >= int(month_match.group(1)) >= START_YEAR: - return Q('match', announced_date_first=term) + return Q("match", announced_date_first=term) return None -def _query_primary(term: str, operator: str = 'and') -> Q: +def _query_primary(term: str, operator: str = "and") -> Q: # This now uses the "primary_classification.combined" field, which is # isomorphic to the document-level "combined" field. So we get # straightforward hit highlighting and a more consistent behavior. 
- return Q("match", **{ - "primary_classification__combined": { - "query": term, - "operator": operator, - "_name": "primary_classification" - } - }) + return Q( + "match", + **{ + "primary_classification__combined": { + "query": term, + "operator": operator, + "_name": "primary_classification", + } + }, + ) def query_primary_exact(classification: Classification) -> Q: """Generate a :class:`Q` for primary classification by ID.""" - return reduce(iand, [ - Q("match", **{f"primary_classification__{field}__id": - getattr(classification, field)['id']}) - for field in ['group', 'archive', 'category'] - if getattr(classification, field, None) is not None - ]) + return reduce( + iand, + [ + Q( + "match", + **{ + f"primary_classification__{field}__id": getattr( + classification, field + )["id"] + }, + ) + for field in ["group", "archive", "category"] + if getattr(classification, field, None) is not None + ], + ) def query_secondary_exact(classification: Classification) -> Q: """Generate a :class:`Q` for secondary classification by ID.""" - return Q("nested", path="secondary_classification", - query=reduce(iand, [ - Q("match", **{f"secondary_classification__{field}__id": - getattr(classification, field)['id']}) - for field in ['group', 'archive', 'category'] - if getattr(classification, field, None) is not None - ])) + return Q( + "nested", + path="secondary_classification", + query=reduce( + iand, + [ + Q( + "match", + **{ + f"secondary_classification__{field}__id": getattr( + classification, field + )["id"] + }, + ) + for field in ["group", "archive", "category"] + if getattr(classification, field, None) is not None + ], + ), + ) + +def query_any_subject_exact_raw(term: str) -> Q: + """ + Generate a :class:`Q` for classification subject by ID with a raw value. + + This will match any e-print that has a primary or secondary classification + with category identifier equal to ``term``. 
+ """ + return Q("match", primary_classification__category__id=term) | Q( + "nested", + path="secondary_classification", + query=Q("match", secondary_classification__category__id=term), + ) -def _query_secondary(term: str, operator: str = 'and') -> Q: + +def _query_secondary(term: str, operator: str = "and") -> Q: return Q( "nested", path="secondary_classification", query=Q( - "match", **{ + "match", + **{ "secondary_classification.combined": { "query": term, - "operator": operator + "operator": operator, } - } + }, ), _name="secondary_classification", - inner_hits={} # This gets us the specific category that matched. + inner_hits={}, # This gets us the specific category that matched. ) -def _query_paper_id(term: str, operator: str = 'and') -> Q: +def _query_paper_id(term: str, operator: str = "and") -> Q: operator = operator.lower() - logger.debug(f'query paper ID with: {term}') - q = (Q_('match', 'paper_id', escape(term), operator=operator) - | Q_('match', 'paper_id_v', escape(term), operator=operator)) + logger.debug(f"query paper ID with: {term}") + q = Q_("match", "paper_id", escape(term), operator=operator) | Q_( + "match", "paper_id_v", escape(term), operator=operator + ) if is_old_papernum(term): - q |= Q('wildcard', paper_id=f'*/{term}') + q |= Q("wildcard", paper_id=f"*/{term}") return q -def _license_query(term: str, operator: str = 'and') -> Q: +def _license_query(term: str, operator: str = "and") -> Q: """Search by license, using its URI (exact).""" - return Q('term', **{'license__uri': term}) + return Q("term", **{"license__uri": term}) def _query_combined(term: str) -> Q: @@ -181,8 +255,13 @@ def _query_combined(term: str) -> Q: wildcard_escaped, has_wildcard = wildcard_escape(term) query_term = (wildcard_escaped if has_wildcard else escape(term)).lower() # All terms must match in the combined field. 
- return Q("query_string", fields=['combined'], default_operator='AND', - allow_leading_wildcard=False, query=query_term) + return Q( + "query_string", + fields=["combined"], + default_operator="AND", + allow_leading_wildcard=False, + query=query_term, + ) def _query_all_fields(term: str) -> Q: @@ -229,47 +308,50 @@ def _query_all_fields(term: str) -> Q: """ # We only perform TeX queries on title and abstract. if is_tex_query(term): - return _tex_query('title', term) | _tex_query('abstract', term) + return _tex_query("title", term) | _tex_query("abstract", term) match_all_fields = _query_combined(term) # We include matches of any term in any field, so that we can highlight # and score appropriately. queries = [ - _query_paper_id(term, operator='or'), - author_query(term, operator='or'), - _query_title(term, default_operator='or'), - _query_abstract(term, default_operator='or'), - _query_comments(term, default_operator='or'), - orcid_query(term, operator='or'), - author_id_query(term, operator='or'), - _query_doi(term, operator='or'), - _query_journal_ref(term, operator='or'), - _query_report_num(term, operator='or'), - _query_acm_class(term, operator='or'), - _query_msc_class(term, operator='or'), - _query_primary(term, operator='or'), - _query_secondary(term, operator='or'), + _query_paper_id(term, operator="or"), + author_query(term, operator="or"), + _query_title(term, default_operator="or"), + _query_abstract(term, default_operator="or"), + _query_comments(term, default_operator="or"), + orcid_query(term, operator="or"), + author_id_query(term, operator="or"), + _query_doi(term, operator="or"), + _query_journal_ref(term, operator="or"), + _query_report_num(term, operator="or"), + _query_acm_class(term, operator="or"), + _query_msc_class(term, operator="or"), + _query_primary(term, operator="or"), + _query_secondary(term, operator="or"), ] # If the whole query matches on a specific field, we should consider that # responsive even if the query on the combined 
field does not respond. - match_individual_field = reduce(ior, [ - _query_paper_id(term, operator='AND'), - author_query(term, operator='AND'), - _query_title(term, default_operator='and'), - _query_abstract(term, default_operator='and'), - _query_comments(term, default_operator='and'), - orcid_query(term, operator='and'), - author_id_query(term, operator='and'), - _query_doi(term, operator='and'), - _query_journal_ref(term, operator='and'), - _query_report_num(term, operator='and'), - _query_acm_class(term, operator='and'), - _query_msc_class(term, operator='and'), - _query_primary(term, operator='and'), - _query_secondary(term, operator='and') - ]) + match_individual_field = reduce( + ior, + [ + _query_paper_id(term, operator="AND"), + author_query(term, operator="AND"), + _query_title(term, default_operator="and"), + _query_abstract(term, default_operator="and"), + _query_comments(term, default_operator="and"), + orcid_query(term, operator="and"), + author_id_query(term, operator="and"), + _query_doi(term, operator="and"), + _query_journal_ref(term, operator="and"), + _query_report_num(term, operator="and"), + _query_acm_class(term, operator="and"), + _query_msc_class(term, operator="and"), + _query_primary(term, operator="and"), + _query_secondary(term, operator="and"), + ], + ) # It is possible that the query includes a date-related term, which we # interpret as an announcement date of v1 of the paper. We currently @@ -290,16 +372,16 @@ def _query_all_fields(term: str) -> Q: pass if date_fragment: - logger.debug('date: %s; remainder: %s', date_fragment, remainder) + logger.debug("date: %s; remainder: %s", date_fragment, remainder) match_date: Optional[Q] = None match_date_partial: Optional[Q] = None match_date_announced: Optional[Q] = None match_dates: List[Q] = [] - logger.debug('date_fragment: %s', date_fragment) + logger.debug("date_fragment: %s", date_fragment) # Try to query using legacy yyMM date partial format. 
date_partial = parse_date_partial(date_fragment) - logger.debug('date_partial: %s', date_partial) + logger.debug("date_partial: %s", date_partial) if date_partial is not None: match_date_partial = Q("term", announced_date_first=date_partial) match_dates.append(match_date_partial) @@ -316,96 +398,125 @@ def _query_all_fields(term: str) -> Q: # the announcement date is to wrap this in a top-level query and # give it a ``_name``. This causes the ``_name`` to show up # in the ``.meta.matched_queries`` property on the search result. - match_date = Q("bool", should=match_dates, minimum_should_match=1, - _name="announced_date_first") - logger.debug('match date: %s', match_date) + match_date = Q( + "bool", + should=match_dates, + minimum_should_match=1, + _name="announced_date_first", + ) + logger.debug("match date: %s", match_date) queries.insert(0, match_date) # Now join the announcement date query with the all-fields queries. if match_date is not None: if remainder: match_remainder = _query_combined(remainder) - match_all_fields |= (match_remainder & match_date) - - match_sans_date = reduce(ior, [ - _query_paper_id(remainder, operator='AND'), - author_query(remainder, operator='AND'), - _query_title(remainder, default_operator='and'), - _query_abstract(remainder, default_operator='and'), - _query_comments(remainder, default_operator='and'), - orcid_query(remainder, operator='and'), - author_id_query(remainder, operator='and'), - _query_doi(remainder, operator='and'), - _query_journal_ref(remainder, operator='and'), - _query_report_num(remainder, operator='and'), - _query_acm_class(remainder, operator='and'), - _query_msc_class(remainder, operator='and'), - _query_primary(remainder, operator='and'), - _query_secondary(remainder, operator='and') - ]) - match_individual_field |= (match_sans_date & match_date) + match_all_fields |= match_remainder & match_date + + match_sans_date = reduce( + ior, + [ + _query_paper_id(remainder, operator="AND"), + author_query(remainder, 
operator="AND"), + _query_title(remainder, default_operator="and"), + _query_abstract(remainder, default_operator="and"), + _query_comments(remainder, default_operator="and"), + orcid_query(remainder, operator="and"), + author_id_query(remainder, operator="and"), + _query_doi(remainder, operator="and"), + _query_journal_ref(remainder, operator="and"), + _query_report_num(remainder, operator="and"), + _query_acm_class(remainder, operator="and"), + _query_msc_class(remainder, operator="and"), + _query_primary(remainder, operator="and"), + _query_secondary(remainder, operator="and"), + ], + ) + match_individual_field |= match_sans_date & match_date else: match_all_fields |= match_date - query = (match_all_fields | match_individual_field) + query = match_all_fields | match_individual_field query &= Q("bool", should=queries) # Partial matches across fields. - scores = [SF({'weight': i + 1, 'filter': q}) - for i, q in enumerate(queries[::-1])] - return Q('function_score', query=query, score_mode="sum", functions=scores, - boost_mode='multiply') + scores = [ + SF({"weight": i + 1, "filter": q}) for i, q in enumerate(queries[::-1]) + ] + return Q( + "function_score", + query=query, + score_mode="sum", + functions=scores, + boost_mode="multiply", + ) -def limit_by_classification(classifications: ClassificationList, - field: str = 'primary_classification') -> Q: +def limit_by_classification( + classifications: ClassificationList, field: str = "primary_classification" +) -> Q: """Generate a :class:`Q` to limit a query by by classification.""" if len(classifications) == 0: return Q() def _to_q(classification: Classification) -> Q: _parts = [] - if 'group' in classification and classification['group'] is not None: + if "group" in classification and classification["group"] is not None: _parts.append( - Q('match', **{ - f'{field}__group__id': classification['group']['id'] - }) + Q( + "match", + **{f"{field}__group__id": classification["group"]["id"]}, + ) ) - if 'archive' in 
classification \ - and classification['archive'] is not None: + if ( + "archive" in classification + and classification["archive"] is not None + ): _parts.append( - Q('match', **{ - f'{field}__archive__id': classification['archive']['id'] - }) + Q( + "match", + **{ + f"{field}__archive__id": classification["archive"][ + "id" + ] + }, + ) ) - if 'category' in classification \ - and classification['category'] is not None: + if ( + "category" in classification + and classification["category"] is not None + ): _parts.append( - Q('match', **{ - f'{field}__category__id': classification['category']['id'] - }) + Q( + "match", + **{ + f"{field}__category__id": classification["category"][ + "id" + ] + }, + ) ) return reduce(iand, _parts) _q = reduce(ior, map(_to_q, classifications)) - if field == 'secondary_classification': + if field == "secondary_classification": _q = Q("nested", path="secondary_classification", query=_q) return _q -SEARCH_FIELDS: Dict[str, Callable[[str], Q]] = dict([ - ('author', author_query), - ('title', _query_title), - ('abstract', _query_abstract), - ('comments', _query_comments), - ('journal_ref', _query_journal_ref), - ('report_num', _query_report_num), - ('acm_class', _query_acm_class), - ('msc_class', _query_msc_class), - ('cross_list_category', _query_secondary), - ('doi', _query_doi), - ('paper_id', _query_paper_id), - ('orcid', orcid_query), - ('author_id', author_id_query), - ('license', _license_query), - ('all', _query_all_fields) -]) +SEARCH_FIELDS: Dict[str, Callable[[str], Q]] = { + "author": author_query, + "title": _query_title, + "abstract": _query_abstract, + "comments": _query_comments, + "journal_ref": _query_journal_ref, + "report_num": _query_report_num, + "acm_class": _query_acm_class, + "msc_class": _query_msc_class, + "cross_list_category": _query_secondary, + "doi": _query_doi, + "paper_id": _query_paper_id, + "orcid": orcid_query, + "author_id": author_id_query, + "license": _license_query, + "all": _query_all_fields, +} 
diff --git a/search/services/index/results.py b/search/services/index/results.py index fd262074..1ea5fc9a 100644 --- a/search/services/index/results.py +++ b/search/services/index/results.py @@ -4,18 +4,16 @@ The primary public function in this module is :func:`.to_documentset`. """ -import re -from datetime import datetime from math import floor -from typing import Any, Dict, Union +from typing import Union +from datetime import datetime from elasticsearch_dsl.response import Response, Hit -from elasticsearch_dsl.utils import AttrList, AttrDict -from search.domain import Document, Query, DocumentSet, Classification, Person -from arxiv.base import logging -from .util import MAX_RESULTS, TEXISM -from .highlighting import add_highlighting, preview +from arxiv.base import logging +from search.domain import Document, Query, DocumentSet +from search.services.index.util import MAX_RESULTS +from search.services.index.highlighting import add_highlighting, preview logger = logging.getLogger(__name__) logger.propagate = False @@ -26,53 +24,55 @@ def to_document(raw: Union[Hit, dict], highlight: bool = True) -> Document: # typing: ignore result: Document = {} - result['match'] = {} # Hit on field, but no highlighting. - result['truncated'] = {} # Preview is truncated. + result["match"] = {} # Hit on field, but no highlighting. + result["truncated"] = {} # Preview is truncated. - result.update(raw.__dict__['_d_']) + result.update(raw.__dict__["_d_"]) # Parse dates to date/datetime objects. 
- if 'announced_date_first' in result: - result['announced_date_first'] = \ - datetime.strptime(raw['announced_date_first'], '%Y-%m').date() - for key in ['', '_first', '_latest']: - key = f'submitted_date{key}' + if "announced_date_first" in result: + result["announced_date_first"] = datetime.strptime( + raw["announced_date_first"], "%Y-%m" + ).date() + for key in ["", "_first", "_latest"]: + key = f"submitted_date{key}" if key not in result: continue try: - result[key] = datetime.strptime(raw[key], '%Y-%m-%dT%H:%M:%S%z') + result[key] = datetime.strptime(raw[key], "%Y-%m-%dT%H:%M:%S%z") except (ValueError, TypeError): - logger.warning(f'Could not parse {key} as datetime') + logger.warning(f"Could not parse {key} as datetime") pass - for key in ['acm_class', 'msc_class']: + for key in ["acm_class", "msc_class"]: if key in result and result[key]: - result[key] = '; '.join(result[key]) + result[key] = "; ".join(result[key]) try: - result['score'] = raw.meta.score # type: ignore + result["score"] = raw.meta.score # type: ignore except AttributeError: pass - if highlight: # type(result.get('abstract')) is str and - result['highlight'] = {} - logger.debug('%s: add highlighting to result', result['paper_id']) + if highlight: # type(result.get('abstract')) is str and + result["highlight"] = {} + logger.debug("%s: add highlighting to result", result["paper_id"]) - if 'preview' not in result: - result['preview'] = {} + if "preview" not in result: + result["preview"] = {} - if 'abstract' in result: - result['preview']['abstract'] = preview(result['abstract']) - if result['preview']['abstract'].endswith('…'): - result['truncated']['abstract'] = True + if "abstract" in result: + result["preview"]["abstract"] = preview(result["abstract"]) + if result["preview"]["abstract"].endswith("…"): + result["truncated"]["abstract"] = True result = add_highlighting(result, raw) return result -def to_documentset(query: Query, response: Response, highlight: bool = True) \ - -> DocumentSet: 
+def to_documentset( + query: Query, response: Response, highlight: bool = True +) -> DocumentSet: """ Transform a response from ES to a :class:`.DocumentSet`. @@ -90,21 +90,21 @@ def to_documentset(query: Query, response: Response, highlight: bool = True) \ page, along with pagination metadata. """ - max_pages = int(MAX_RESULTS/query.size) - N_pages_raw = response['hits']['total']/query.size - N_pages = int(floor(N_pages_raw)) + \ - int(N_pages_raw % query.size > 0) - logger.debug('got %i results', response['hits']['total']) + max_pages = int(MAX_RESULTS / query.size) + n_pages_raw = response["hits"]["total"] / query.size + n_pages = int(floor(n_pages_raw)) + int(n_pages_raw % query.size > 0) + logger.debug("got %i results", response["hits"]["total"]) return { - 'metadata': { - 'start': query.page_start, - 'end': min(query.page_start + query.size, - response['hits']['total']), - 'total': response['hits']['total'], - 'current_page': query.page, - 'total_pages': N_pages, - 'size': query.size, - 'max_pages': max_pages + "metadata": { + "start": query.page_start, + "end": min( + query.page_start + query.size, response["hits"]["total"] + ), + "total_results": response["hits"]["total"], + "current_page": query.page, + "total_pages": n_pages, + "size": query.size, + "max_pages": max_pages, }, - 'results': [to_document(raw, highlight=highlight) for raw in response] + "results": [to_document(raw, highlight=highlight) for raw in response], } diff --git a/search/services/index/simple.py b/search/services/index/simple.py index 1406f0c1..78cfc179 100644 --- a/search/services/index/simple.py +++ b/search/services/index/simple.py @@ -31,8 +31,9 @@ def simple_search(search: Search, query: SimpleQuery) -> Search: if query.classification: _q = limit_by_classification(query.classification) if query.include_cross_list: - _q |= limit_by_classification(query.classification, - "secondary_classification") + _q |= limit_by_classification( + query.classification, "secondary_classification" 
+ ) q &= _q search = search.query(q) search = sort(query, search) diff --git a/search/services/index/tests/test_reindex.py b/search/services/index/tests/test_reindex.py index a5340c35..54c67646 100644 --- a/search/services/index/tests/test_reindex.py +++ b/search/services/index/tests/test_reindex.py @@ -7,62 +7,90 @@ def raise_index_exists(*args, **kwargs): """Raise a resource_already_exists_exception TransportError.""" - raise index.TransportError(400, 'resource_already_exists_exception', {}) + raise index.TransportError(400, "resource_already_exists_exception", {}) class TestReindexing(TestCase): """Tests for :func:`.index.reindex`.""" - @mock.patch('search.services.index.Elasticsearch') + @mock.patch("search.services.index.Elasticsearch") def test_reindex_from_scratch(self, mock_Elasticsearch): """Reindex to an index that does not exist.""" mock_es = mock.MagicMock() mock_Elasticsearch.return_value = mock_es - index.SearchSession.reindex('barindex', 'bazindex') - self.assertEqual(mock_es.indices.create.call_count, 1, - "Should attempt to create the new index") - self.assertEqual(mock_es.indices.create.call_args[0][0], "bazindex", - "Should attempt to create the new index") + index.SearchSession.reindex("barindex", "bazindex") + self.assertEqual( + mock_es.indices.create.call_count, + 1, + "Should attempt to create the new index", + ) + self.assertEqual( + mock_es.indices.create.call_args[0][0], + "bazindex", + "Should attempt to create the new index", + ) - self.assertEqual(mock_es.reindex.call_count, 1, - "Should proceed to request reindexing") - self.assertEqual(mock_es.reindex.call_args[0][0]['source']['index'], - 'barindex') - self.assertEqual(mock_es.reindex.call_args[0][0]['dest']['index'], - 'bazindex') + self.assertEqual( + mock_es.reindex.call_count, + 1, + "Should proceed to request reindexing", + ) + self.assertEqual( + mock_es.reindex.call_args[0][0]["source"]["index"], "barindex" + ) + self.assertEqual( + 
mock_es.reindex.call_args[0][0]["dest"]["index"], "bazindex" + ) - @mock.patch('search.services.index.Elasticsearch') + @mock.patch("search.services.index.Elasticsearch") def test_reindex_already_exists(self, mock_Elasticsearch): """Reindex to an index that already exists.""" mock_es = mock.MagicMock() mock_Elasticsearch.return_value = mock_es mock_es.indices.create.side_effect = raise_index_exists - index.SearchSession.reindex('barindex', 'bazindex') - self.assertEqual(mock_es.indices.create.call_count, 1, - "Should attempt to create the new index") - self.assertEqual(mock_es.indices.create.call_args[0][0], "bazindex", - "Should attempt to create the new index") + index.SearchSession.reindex("barindex", "bazindex") + self.assertEqual( + mock_es.indices.create.call_count, + 1, + "Should attempt to create the new index", + ) + self.assertEqual( + mock_es.indices.create.call_args[0][0], + "bazindex", + "Should attempt to create the new index", + ) - self.assertEqual(mock_es.reindex.call_count, 1, - "Should proceed to request reindexing") - self.assertEqual(mock_es.reindex.call_args[0][0]['source']['index'], - 'barindex') - self.assertEqual(mock_es.reindex.call_args[0][0]['dest']['index'], - 'bazindex') + self.assertEqual( + mock_es.reindex.call_count, + 1, + "Should proceed to request reindexing", + ) + self.assertEqual( + mock_es.reindex.call_args[0][0]["source"]["index"], "barindex" + ) + self.assertEqual( + mock_es.reindex.call_args[0][0]["dest"]["index"], "bazindex" + ) class TestTaskStatus(TestCase): """Tests for :func:`.index.get_task_status`.""" - @mock.patch('search.services.index.Elasticsearch') + @mock.patch("search.services.index.Elasticsearch") def test_get_task_status(self, mock_Elasticsearch): """Get task status via the ES API.""" mock_es = mock.MagicMock() mock_Elasticsearch.return_value = mock_es - task_id = 'foonode:bartask' + task_id = "foonode:bartask" index.SearchSession.get_task_status(task_id) - self.assertEqual(mock_es.tasks.get.call_count, 1, 
- "Should call the task status endpoint") - self.assertEqual(mock_es.tasks.get.call_args[0][0], task_id, - "Should call the task status endpoint with task ID") + self.assertEqual( + mock_es.tasks.get.call_count, + 1, + "Should call the task status endpoint", + ) + self.assertEqual( + mock_es.tasks.get.call_args[0][0], + task_id, + "Should call the task status endpoint with task ID", + ) diff --git a/search/services/index/tests/test_results.py b/search/services/index/tests/test_results.py index 3d21bc60..f67cf8fc 100644 --- a/search/services/index/tests/test_results.py +++ b/search/services/index/tests/test_results.py @@ -1,6 +1,6 @@ """Tests for :mod:`search.services.index`.""" -from unittest import TestCase, mock +from unittest import TestCase from search.services.index import highlighting @@ -19,35 +19,39 @@ def setUp(self): " $\\mathrm{Z}(\\mu\\mu)\\mathrm{H}$, and" " $\\mathrm{Z}(\\mathrm{e}\\mathrm{e})\\mathrm{H}$." ) - self.start_tag = '' - self.end_tag = '' + self.start_tag = "" + self.end_tag = "" def test_preview(self): """Generate a preview that is smaller than/equal to fragment size.""" - preview = highlighting.preview(self.value, fragment_size=350, - start_tag=self.start_tag, - end_tag=self.end_tag) + preview = highlighting.preview( + self.value, + fragment_size=350, + start_tag=self.start_tag, + end_tag=self.end_tag, + ) self.assertGreaterEqual(338, len(preview)) def test_preview_with_close_highlights(self): """Two highlights in the abstract are close together.""" value = ( "We investigate self-averaging properties in the transport of" - " particles through random media. We show" + ' particles through random media. We show' " rigorously that in the subdiffusive anomalous regime transport" " coefficients are not self--averaging quantities. These" " quantities are exactly calculated in the case of directed" - " " + ' ' "random walks. 
In the case of general symmetric random" + ' class="has-text-success has-text-weight-bold mathjax">random' " walks a perturbative analysis around the Effective Medium" " Approximation (EMA) is performed." ) - start_tag = "" + start_tag = ( + '' + ) end_tag = "" - preview = highlighting.preview(value, start_tag=start_tag, - end_tag=end_tag) + _ = highlighting.preview(value, start_tag=start_tag, end_tag=end_tag) class TestResultsEndSafely(TestCase): @@ -65,35 +69,35 @@ def setUp(self): " $\\mathrm{Z}(\\mu\\mu)\\mathrm{H}$, and" " $\\mathrm{Z}(\\mathrm{e}\\mathrm{e})\\mathrm{H}$." ) - self.start_tag = '' - self.end_tag = '' + self.start_tag = "" + self.end_tag = "" def test_end_safely_from_start(self): """No TeXisms/HTML are found within the desired fragment size.""" - end = highlighting._end_safely(self.value, 45, - start_tag=self.start_tag, - end_tag=self.end_tag) + end = highlighting._end_safely( + self.value, 45, start_tag=self.start_tag, end_tag=self.end_tag + ) self.assertEqual(end, 45, "Should end at the desired fragment length.") def test_end_safely_before_texism(self): """End before TeXism when desired fragment size would truncate.""" - end = highlighting._end_safely(self.value, 55, - start_tag=self.start_tag, - end_tag=self.end_tag) + end = highlighting._end_safely( + self.value, 55, start_tag=self.start_tag, end_tag=self.end_tag + ) # print(self.value[:end]) self.assertEqual(end, 50, "Should end before the start of the TeXism.") def test_end_safely_before_html(self): """End before HTML when desired fragment size would truncate.""" - end = highlighting._end_safely(self.value, 215, - start_tag=self.start_tag, - end_tag=self.end_tag) + end = highlighting._end_safely( + self.value, 215, start_tag=self.start_tag, end_tag=self.end_tag + ) # print(self.value[:end]) self.assertEqual(end, 213, "Should end before the start of the tag.") def test_end_safely_after_html_with_tolerance(self): """End before HTML when desired fragment size would truncate.""" - end = 
highlighting._end_safely(self.value, 275, - start_tag=self.start_tag, - end_tag=self.end_tag) + end = highlighting._end_safely( + self.value, 275, start_tag=self.start_tag, end_tag=self.end_tag + ) self.assertEqual(end, 275, "Should end after the closing tag.") diff --git a/search/services/index/tests/test_util.py b/search/services/index/tests/test_util.py index cb74a081..0cb46f73 100644 --- a/search/services/index/tests/test_util.py +++ b/search/services/index/tests/test_util.py @@ -10,34 +10,34 @@ class TestMatchDatePartial(TestCase): def test_date_partial_only(self): """Term includes only a four-digit date partial.""" - term, rmd = util.parse_date('1902') + term, rmd = util.parse_date("1902") ym = util.parse_date_partial(term) - self.assertEqual(ym, '2019-02') - self.assertEqual(rmd, '', "Should have no remainder") + self.assertEqual(ym, "2019-02") + self.assertEqual(rmd, "", "Should have no remainder") def test_in_word(self): """A false positive in a word.""" with self.assertRaises(ValueError): - term, rmd = util.parse_date('notasearch1902foradatepartial') + term, rmd = util.parse_date("notasearch1902foradatepartial") def test_near_words(self): """Term includes date partial plus other terms.""" - term, rmd = util.parse_date('foo 1902 bar') + term, rmd = util.parse_date("foo 1902 bar") ym = util.parse_date_partial(term) - self.assertEqual(ym, '2019-02') + self.assertEqual(ym, "2019-02") self.assertEqual(rmd, "foo bar", "Should have remainder") def test_out_of_range(self): """Term looks like a date partial, but is not a valid date.""" - term, rmd = util.parse_date('0699') + term, rmd = util.parse_date("0699") self.assertIsNone(util.parse_date_partial(term)) def test_last_millenium(self): """Term is for a pre-2000 paper.""" - term, rmd = util.parse_date('old paper 9505') + term, rmd = util.parse_date("old paper 9505") ym = util.parse_date_partial(term) - self.assertEqual(ym, '1995-05') - self.assertEqual(rmd, 'old paper', 'Should have a remainder') + 
self.assertEqual(ym, "1995-05") + self.assertEqual(rmd, "old paper", "Should have a remainder") class TestOldPapernumDetection(TestCase): @@ -45,9 +45,9 @@ class TestOldPapernumDetection(TestCase): def test_is_old_papernum(self): """User enters a 7-digit number that looks like an old papernum.""" - self.assertFalse(util.is_old_papernum('9106001')) - self.assertTrue(util.is_old_papernum('9107001')) - self.assertFalse(util.is_old_papernum('9200001')) - self.assertTrue(util.is_old_papernum('9201001')) - self.assertTrue(util.is_old_papernum('0703999')) - self.assertFalse(util.is_old_papernum('0704001')) + self.assertFalse(util.is_old_papernum("9106001")) + self.assertTrue(util.is_old_papernum("9107001")) + self.assertFalse(util.is_old_papernum("9200001")) + self.assertTrue(util.is_old_papernum("9201001")) + self.assertTrue(util.is_old_papernum("0703999")) + self.assertFalse(util.is_old_papernum("0704001")) diff --git a/search/services/index/tests/tests.py b/search/services/index/tests/tests.py index 480a158e..f3e23885 100644 --- a/search/services/index/tests/tests.py +++ b/search/services/index/tests/tests.py @@ -1,34 +1,113 @@ """Tests for :mod:`search.services.index`.""" from unittest import TestCase, mock -from datetime import date, datetime, timedelta -from pytz import timezone -from elasticsearch_dsl import Search, Q -from elasticsearch_dsl.query import Range, Match, Bool, Nested +from datetime import datetime, timedelta from search.services import index from search.services.index import advanced from search.services.index.util import wildcard_escape, Q_ -from search.domain import Query, FieldedSearchTerm, DateRange, Classification,\ - AdvancedQuery, FieldedSearchList, ClassificationList, SimpleQuery, \ - DocumentSet +from search.domain import ( + SortBy, + SortOrder, + FieldedSearchTerm, + DateRange, + Classification, + AdvancedQuery, + FieldedSearchList, + ClassificationList, + SimpleQuery, + Field, + Term, + ClassicAPIQuery, + Operator, +) + + +class 
TestClassicApiQuery(TestCase): + def test_classis_query_creation(self): + self.assertRaises(ValueError, lambda: ClassicAPIQuery()) + # There is no assert not raises + self.assertIsNotNone(ClassicAPIQuery(search_query="")) + self.assertIsNotNone(ClassicAPIQuery(id_list=[])) + + def test_to_query_string(self): + self.assertEqual( + ClassicAPIQuery(id_list=[]).to_query_string(), + "search_query=&id_list=&start=0&max_results=10", + ) + self.assertEqual( + ClassicAPIQuery( + search_query="all:electron", id_list=[] + ).to_query_string(), + "search_query=all:electron&id_list=&start=0&max_results=10", + ) + self.assertEqual( + ClassicAPIQuery( + search_query="all:electron", + id_list=["1705.09169v3", "1705.09129v3"], + ).to_query_string(), + "search_query=all:electron&id_list=1705.09169v3,1705.09129v3" + "&start=0&max_results=10", + ) + self.assertEqual( + ClassicAPIQuery( + search_query="all:electron", + id_list=["1705.09169v3", "1705.09129v3"], + page_start=3, + ).to_query_string(), + "search_query=all:electron&id_list=1705.09169v3,1705.09129v3" + "&start=3&max_results=10", + ) + self.assertEqual( + ClassicAPIQuery( + search_query="all:electron", + id_list=["1705.09169v3", "1705.09129v3"], + page_start=3, + size=50, + ).to_query_string(), + "search_query=all:electron&id_list=1705.09169v3,1705.09129v3" + "&start=3&max_results=50", + ) + self.assertEqual( + ClassicAPIQuery( + search_query="all:electron", page_start=3, size=50 + ).to_query_string(), + "search_query=all:electron&id_list=&start=3&max_results=50", + ) + self.assertEqual( + ClassicAPIQuery( + id_list=["1705.09169v3", "1705.09129v3"], page_start=3, size=50 + ).to_query_string(), + "search_query=&id_list=1705.09169v3,1705.09129v3" + "&start=3&max_results=50", + ) + self.assertEqual( + ClassicAPIQuery( + search_query="all:electron", size=50 + ).to_query_string(), + "search_query=all:electron&id_list=&start=0&max_results=50", + ) + -EASTERN = timezone('US/Eastern') +def mock_rdata(): + return { + "authors": 
[{"full_name": "N. Ame"}], + "owners": [{"full_name": "N. Ame"}], + "submitter": {"full_name": "N. Ame"}, + "paper_id": "1234.56789", + } class TestSearch(TestCase): """Tests for :func:`.index.search`.""" - @mock.patch('search.services.index.Search') - @mock.patch('search.services.index.Elasticsearch') + @mock.patch("search.services.index.Search") + @mock.patch("search.services.index.Elasticsearch") def test_advanced_query(self, mock_Elasticsearch, mock_Search): """:class:`.index.search` supports :class:`AdvancedQuery`.""" mock_results = mock.MagicMock() - mock_results.__getitem__.return_value = {'total': 53} - rdata = dict(authors=[{'full_name': 'N. Ame'}], - owners=[{'full_name': 'N. Ame'}], - submitter={'full_name': 'N. Ame'}, - paper_id='1234.56789') + mock_results.__getitem__.return_value = {"total": 53} + rdata = mock_rdata() mock_result = mock.MagicMock(_d_=rdata, **rdata) mock_result.meta.score = 1 mock_results.__iter__.return_value = [mock_result] @@ -44,59 +123,79 @@ def test_advanced_query(self, mock_Elasticsearch, mock_Search): mock_Search.__getitem__.return_value = mock_Search query = AdvancedQuery( - order='relevance', + order="relevance", size=10, date_range=DateRange( start_date=datetime.now() - timedelta(days=5), - end_date=datetime.now() + end_date=datetime.now(), + ), + classification=ClassificationList( + [ + Classification( + group={"id": "physics"}, + archive={"id": "physics"}, + category={"id": "hep-th"}, + ) + ] + ), + terms=FieldedSearchList( + [ + FieldedSearchTerm( + operator="AND", field="title", term="foo" + ), + FieldedSearchTerm( + operator="AND", field="author", term="joe" + ), + FieldedSearchTerm( + operator="OR", field="abstract", term="hmm" + ), + FieldedSearchTerm( + operator="NOT", field="comments", term="eh" + ), + FieldedSearchTerm( + operator="AND", + field="journal_ref", + term="jref (1999) 1:2-3", + ), + FieldedSearchTerm( + operator="AND", field="acm_class", term="abc123" + ), + FieldedSearchTerm( + operator="AND", 
field="msc_class", term="abc123" + ), + FieldedSearchTerm( + operator="OR", field="report_num", term="abc123" + ), + FieldedSearchTerm( + operator="OR", field="doi", term="10.01234/56789" + ), + FieldedSearchTerm( + operator="OR", + field="orcid", + term="0000-0000-0000-0000", + ), + FieldedSearchTerm( + operator="OR", field="author_id", term="Bloggs_J" + ), + ] ), - classification=ClassificationList([ - Classification( - group={'id': 'physics'}, - archive={'id': 'physics'}, - category={'id': 'hep-th'} - ) - ]), - terms=FieldedSearchList([ - FieldedSearchTerm(operator='AND', field='title', term='foo'), - FieldedSearchTerm(operator='AND', field='author', term='joe'), - FieldedSearchTerm(operator='OR', field='abstract', term='hmm'), - FieldedSearchTerm(operator='NOT', field='comments', term='eh'), - FieldedSearchTerm(operator='AND', field='journal_ref', - term='jref (1999) 1:2-3'), - FieldedSearchTerm(operator='AND', field='acm_class', - term='abc123'), - FieldedSearchTerm(operator='AND', field='msc_class', - term='abc123'), - FieldedSearchTerm(operator='OR', field='report_num', - term='abc123'), - FieldedSearchTerm(operator='OR', field='doi', - term='10.01234/56789'), - FieldedSearchTerm(operator='OR', field='orcid', - term='0000-0000-0000-0000'), - FieldedSearchTerm(operator='OR', field='author_id', - term='Bloggs_J'), - ]) ) document_set = index.SearchSession.search(query) # self.assertIsInstance(document_set, DocumentSet) - self.assertEqual(document_set['metadata']['start'], 0) - self.assertEqual(document_set['metadata']['total'], 53) - self.assertEqual(document_set['metadata']['current_page'], 1) - self.assertEqual(document_set['metadata']['total_pages'], 6) - self.assertEqual(document_set['metadata']['size'], 10) - self.assertEqual(len(document_set['results']), 1) - - @mock.patch('search.services.index.Search') - @mock.patch('search.services.index.Elasticsearch') + self.assertEqual(document_set["metadata"]["start"], 0) + 
self.assertEqual(document_set["metadata"]["total_results"], 53) + self.assertEqual(document_set["metadata"]["current_page"], 1) + self.assertEqual(document_set["metadata"]["total_pages"], 6) + self.assertEqual(document_set["metadata"]["size"], 10) + self.assertEqual(len(document_set["results"]), 1) + + @mock.patch("search.services.index.Search") + @mock.patch("search.services.index.Elasticsearch") def test_simple_query(self, mock_Elasticsearch, mock_Search): """:class:`.index.search` supports :class:`SimpleQuery`.""" mock_results = mock.MagicMock() - mock_results.__getitem__.return_value = {'total': 53} - rdata = dict(authors=[{'full_name': 'N. Ame'}], - owners=[{'full_name': 'N. Ame'}], - submitter={'full_name': 'N. Ame'}, - paper_id='1234.56789') + mock_results.__getitem__.return_value = {"total": 53} + rdata = mock_rdata() mock_result = mock.MagicMock(_d_=rdata, **rdata) mock_result.meta.score = 1 mock_results.__iter__.return_value = [mock_result] @@ -112,19 +211,168 @@ def test_simple_query(self, mock_Elasticsearch, mock_Search): mock_Search.__getitem__.return_value = mock_Search query = SimpleQuery( - order='relevance', + order="relevance", size=10, search_field="title", value="foo title" + ) + document_set = index.SearchSession.search(query) + # self.assertIsInstance(document_set, DocumentSet) + self.assertEqual(document_set["metadata"]["start"], 0) + self.assertEqual(document_set["metadata"]["total_results"], 53) + self.assertEqual(document_set["metadata"]["current_page"], 1) + self.assertEqual(document_set["metadata"]["total_pages"], 6) + self.assertEqual(document_set["metadata"]["size"], 10) + self.assertEqual(len(document_set["results"]), 1) + + @mock.patch("search.services.index.Search") + @mock.patch("search.services.index.Elasticsearch") + def test_classic_query(self, mock_Elasticsearch, mock_Search): + """:class:`.index.search` supports :class:`ClassicAPIQuery`.""" + mock_results = mock.MagicMock() + mock_results.__getitem__.return_value = {"total": 53} + 
rdata = mock_rdata() + mock_result = mock.MagicMock(_d_=rdata, **rdata) + mock_result.meta.score = 1 + mock_results.__iter__.return_value = [mock_result] + mock_Search.execute.return_value = mock_results + + # Support the chaining API for py-ES. + mock_Search.return_value = mock_Search + mock_Search.filter.return_value = mock_Search + mock_Search.highlight.return_value = mock_Search + mock_Search.highlight_options.return_value = mock_Search + mock_Search.query.return_value = mock_Search + mock_Search.sort.return_value = mock_Search + mock_Search.__getitem__.return_value = mock_Search + + query = ClassicAPIQuery( + phrase=Term(Field.Author, "copernicus"), + order=SortOrder(by=SortBy.relevance), size=10, - search_field='title', - value='foo title' ) + document_set = index.SearchSession.search(query) # self.assertIsInstance(document_set, DocumentSet) - self.assertEqual(document_set['metadata']['start'], 0) - self.assertEqual(document_set['metadata']['total'], 53) - self.assertEqual(document_set['metadata']['current_page'], 1) - self.assertEqual(document_set['metadata']['total_pages'], 6) - self.assertEqual(document_set['metadata']['size'], 10) - self.assertEqual(len(document_set['results']), 1) + self.assertEqual(document_set["metadata"]["start"], 0) + self.assertEqual(document_set["metadata"]["total_results"], 53) + self.assertEqual(document_set["metadata"]["current_page"], 1) + self.assertEqual(document_set["metadata"]["total_pages"], 6) + self.assertEqual(document_set["metadata"]["size"], 10) + self.assertEqual(len(document_set["results"]), 1) + + @mock.patch("search.services.index.Search") + @mock.patch("search.services.index.Elasticsearch") + def test_classic_query_complex(self, mock_Elasticsearch, mock_Search): + """:class:`.index.search` supports :class:`ClassicAPIQuery`.""" + mock_results = mock.MagicMock() + mock_results.__getitem__.return_value = {"total": 53} + rdata = mock_rdata() + mock_result = mock.MagicMock(_d_=rdata, **rdata) + mock_result.meta.score = 1 
+ mock_results.__iter__.return_value = [mock_result] + mock_Search.execute.return_value = mock_results + + # Support the chaining API for py-ES. + mock_Search.return_value = mock_Search + mock_Search.filter.return_value = mock_Search + mock_Search.highlight.return_value = mock_Search + mock_Search.highlight_options.return_value = mock_Search + mock_Search.query.return_value = mock_Search + mock_Search.sort.return_value = mock_Search + mock_Search.__getitem__.return_value = mock_Search + + query = ClassicAPIQuery( + phrase=( + Operator.OR, + Term(Field.Author, "copernicus"), + (Operator.ANDNOT, Term(Field.Title, "dark matter")), + ), + order=SortOrder(by=SortBy.relevance), + size=10, + ) + + document_set = index.SearchSession.search(query) + # self.assertIsInstance(document_set, DocumentSet) + self.assertEqual(document_set["metadata"]["start"], 0) + self.assertEqual(document_set["metadata"]["total_results"], 53) + self.assertEqual(document_set["metadata"]["current_page"], 1) + self.assertEqual(document_set["metadata"]["total_pages"], 6) + self.assertEqual(document_set["metadata"]["size"], 10) + self.assertEqual(len(document_set["results"]), 1) + + @mock.patch("search.services.index.Search") + @mock.patch("search.services.index.Elasticsearch") + def test_classic_query_id_list(self, mock_Elasticsearch, mock_Search): + """:class:`.index.search` supports :class:`ClassicAPIQuery`.""" + mock_results = mock.MagicMock() + mock_results.__getitem__.return_value = {"total": 53} + rdata = mock_rdata() + mock_result = mock.MagicMock(_d_=rdata, **rdata) + mock_result.meta.score = 1 + mock_results.__iter__.return_value = [mock_result] + mock_Search.execute.return_value = mock_results + + # Support the chaining API for py-ES. 
+ mock_Search.return_value = mock_Search + mock_Search.filter.return_value = mock_Search + mock_Search.highlight.return_value = mock_Search + mock_Search.highlight_options.return_value = mock_Search + mock_Search.query.return_value = mock_Search + mock_Search.sort.return_value = mock_Search + mock_Search.__getitem__.return_value = mock_Search + + query = ClassicAPIQuery( + id_list=["1234.56789"], + order=SortOrder(by=SortBy.relevance), + size=10, + ) + + document_set = index.SearchSession.search(query) + # self.assertIsInstance(document_set, DocumentSet) + self.assertEqual(document_set["metadata"]["start"], 0) + self.assertEqual(document_set["metadata"]["total_results"], 53) + self.assertEqual(document_set["metadata"]["current_page"], 1) + self.assertEqual(document_set["metadata"]["total_pages"], 6) + self.assertEqual(document_set["metadata"]["size"], 10) + self.assertEqual(len(document_set["results"]), 1) + + @mock.patch("search.services.index.Search") + @mock.patch("search.services.index.Elasticsearch") + def test_classic_query_phrases(self, mock_Elasticsearch, mock_Search): + """:class:`.index.search` supports :class:`ClassicAPIQuery`.""" + mock_results = mock.MagicMock() + mock_results.__getitem__.return_value = {"total": 53} + rdata = mock_rdata() + mock_result = mock.MagicMock(_d_=rdata, **rdata) + mock_result.meta.score = 1 + mock_results.__iter__.return_value = [mock_result] + mock_Search.execute.return_value = mock_results + + # Support the chaining API for py-ES. 
+ mock_Search.return_value = mock_Search + mock_Search.filter.return_value = mock_Search + mock_Search.highlight.return_value = mock_Search + mock_Search.highlight_options.return_value = mock_Search + mock_Search.query.return_value = mock_Search + mock_Search.sort.return_value = mock_Search + mock_Search.__getitem__.return_value = mock_Search + + query = ClassicAPIQuery( + phrase=( + Operator.AND, + Term(Field.Author, "copernicus"), + Term(Field.Title, "philosophy"), + ), + order=SortOrder(by=SortBy.relevance), + size=10, + ) + + document_set = index.SearchSession.search(query) + # self.assertIsInstance(document_set, DocumentSet) + self.assertEqual(document_set["metadata"]["start"], 0) + self.assertEqual(document_set["metadata"]["total_results"], 53) + self.assertEqual(document_set["metadata"]["current_page"], 1) + self.assertEqual(document_set["metadata"]["total_pages"], 6) + self.assertEqual(document_set["metadata"]["size"], 10) + self.assertEqual(len(document_set["results"]), 1) class TestWildcardSearch(TestCase): @@ -138,9 +386,9 @@ def test_match_any_wildcard_is_present(self): self.assertTrue(wildcard, "Wildcard should be detected") self.assertEqual(qs, qs_escaped, "The querystring should be unchanged") self.assertIsInstance( - Q_('match', 'title', qs), - type(index.Q('wildcard', title=qs)), - "Wildcard Q object should be generated" + Q_("match", "title", qs), + type(index.Q("wildcard", title=qs)), + "Wildcard Q object should be generated", ) def test_match_any_wildcard_in_literal(self): @@ -148,12 +396,14 @@ def test_match_any_wildcard_in_literal(self): qs = '"Foo t*"' qs_escaped, wildcard = wildcard_escape(qs) - self.assertEqual(qs_escaped, '"Foo t\*"', "Wildcard should be escaped") + self.assertEqual( + qs_escaped, r'"Foo t\*"', "Wildcard should be escaped" + ) self.assertFalse(wildcard, "Wildcard should not be detected") self.assertIsInstance( - Q_('match', 'title', qs), - type(index.Q('match', title='"Foo t\*"')), - "Wildcard Q object should not be 
generated" + Q_("match", "title", qs), + type(index.Q("match", title=r'"Foo t\*"')), + "Wildcard Q object should not be generated", ) def test_multiple_match_any_wildcard_in_literal(self): @@ -161,13 +411,14 @@ def test_multiple_match_any_wildcard_in_literal(self): qs = '"Fo*o t*"' qs_escaped, wildcard = wildcard_escape(qs) - self.assertEqual(qs_escaped, '"Fo\*o t\*"', - "Both wildcards should be escaped") + self.assertEqual( + qs_escaped, r'"Fo\*o t\*"', "Both wildcards should be escaped" + ) self.assertFalse(wildcard, "Wildcard should not be detected") self.assertIsInstance( - Q_('match', 'title', qs), - type(index.Q('match', title='"Fo\*o t\*"')), - "Wildcard Q object should not be generated" + Q_("match", "title", qs), + type(index.Q("match", title=r'"Fo\*o t\*"')), + "Wildcard Q object should not be generated", ) def test_mixed_wildcards_in_literal(self): @@ -175,13 +426,14 @@ def test_mixed_wildcards_in_literal(self): qs = '"Fo? t*"' qs_escaped, wildcard = wildcard_escape(qs) - self.assertEqual(qs_escaped, '"Fo\? t\*"', - "Both wildcards should be escaped") + self.assertEqual( + qs_escaped, r'"Fo\? t\*"', "Both wildcards should be escaped" + ) self.assertFalse(wildcard, "Wildcard should not be detected") self.assertIsInstance( - Q_('match', 'title', qs), - type(index.Q('match', title='"Fo\? t\*"')), - "Wildcard Q object should not be generated" + Q_("match", "title", qs), + type(index.Q("match", title=r'"Fo\? t\*"')), + "Wildcard Q object should not be generated", ) def test_wildcards_both_inside_and_outside_literal(self): @@ -189,13 +441,16 @@ def test_wildcards_both_inside_and_outside_literal(self): qs = '"Fo? t*" said the *' qs_escaped, wildcard = wildcard_escape(qs) - self.assertEqual(qs_escaped, '"Fo\? t\*" said the *', - "Wildcards in literal should be escaped") + self.assertEqual( + qs_escaped, + r'"Fo\? 
t\*" said the *', + "Wildcards in literal should be escaped", + ) self.assertTrue(wildcard, "Wildcard should be detected") self.assertIsInstance( - Q_('match', 'title', qs), - type(index.Q('wildcard', title='"Fo\? t\*" said the *')), - "Wildcard Q object should be generated" + Q_("match", "title", qs), + type(index.Q("wildcard", title=r'"Fo\? t\*" said the *')), + "Wildcard Q object should be generated", ) def test_wildcards_inside_outside_multiple_literals(self): @@ -203,14 +458,17 @@ def test_wildcards_inside_outside_multiple_literals(self): qs = '"Fo?" s* "yes*" o?' qs_escaped, wildcard = wildcard_escape(qs) - self.assertEqual(qs_escaped, '"Fo\?" s* "yes\*" o?', - "Wildcards in literal should be escaped") + self.assertEqual( + qs_escaped, + r'"Fo\?" s* "yes\*" o?', + "Wildcards in literal should be escaped", + ) self.assertTrue(wildcard, "Wildcard should be detected") self.assertIsInstance( - Q_('match', 'title', qs), - type(index.Q('wildcard', title='"Fo\?" s* "yes\*" o?')), - "Wildcard Q object should be generated" + Q_("match", "title", qs), + type(index.Q("wildcard", title=r'"Fo\?" 
s* "yes\*" o?')), + "Wildcard Q object should be generated", ) def test_wildcard_at_opening_of_string(self): @@ -219,7 +477,7 @@ def test_wildcard_at_opening_of_string(self): wildcard_escape("*nope") with self.assertRaises(index.QueryError): - Q_('match', 'title', '*nope') + Q_("match", "title", "*nope") class TestPrepare(TestCase): @@ -227,49 +485,75 @@ class TestPrepare(TestCase): def test_group_terms(self): """:meth:`._group_terms` groups terms using logical precedence.""" - query = AdvancedQuery(terms=FieldedSearchList([ - FieldedSearchTerm(operator=None, field='title', term='muon'), - FieldedSearchTerm(operator='OR', field='title', term='gluon'), - FieldedSearchTerm(operator='NOT', field='title', term='foo'), - FieldedSearchTerm(operator='AND', field='title', term='boson'), - ])) + query = AdvancedQuery( + terms=FieldedSearchList( + [ + FieldedSearchTerm( + operator=None, field="title", term="muon" + ), + FieldedSearchTerm( + operator="OR", field="title", term="gluon" + ), + FieldedSearchTerm( + operator="NOT", field="title", term="foo" + ), + FieldedSearchTerm( + operator="AND", field="title", term="boson" + ), + ] + ) + ) expected = ( - FieldedSearchTerm(operator=None, field='title', term='muon'), - 'OR', + FieldedSearchTerm(operator=None, field="title", term="muon"), + "OR", ( - ( - FieldedSearchTerm(operator='OR', field='title', term='gluon'), - 'NOT', - FieldedSearchTerm(operator='NOT', field='title', term='foo') - ), - 'AND', - FieldedSearchTerm(operator='AND', field='title', term='boson') - ) + ( + FieldedSearchTerm( + operator="OR", field="title", term="gluon" + ), + "NOT", + FieldedSearchTerm( + operator="NOT", field="title", term="foo" + ), + ), + "AND", + FieldedSearchTerm(operator="AND", field="title", term="boson"), + ), ) try: terms = advanced._group_terms(query) except AssertionError: - self.fail('Should result in a single group') + self.fail("Should result in a single group") self.assertEqual(expected, terms) def test_group_terms_all_and(self): 
""":meth:`._group_terms` groups terms using logical precedence.""" - query = AdvancedQuery(terms=FieldedSearchList([ - FieldedSearchTerm(operator=None, field='title', term='muon'), - FieldedSearchTerm(operator='AND', field='title', term='gluon'), - FieldedSearchTerm(operator='AND', field='title', term='foo'), - ])) + query = AdvancedQuery( + terms=FieldedSearchList( + [ + FieldedSearchTerm( + operator=None, field="title", term="muon" + ), + FieldedSearchTerm( + operator="AND", field="title", term="gluon" + ), + FieldedSearchTerm( + operator="AND", field="title", term="foo" + ), + ] + ) + ) expected = ( ( - FieldedSearchTerm(operator=None, field='title', term='muon'), - 'AND', - FieldedSearchTerm(operator='AND', field='title', term='gluon') + FieldedSearchTerm(operator=None, field="title", term="muon"), + "AND", + FieldedSearchTerm(operator="AND", field="title", term="gluon"), ), - 'AND', - FieldedSearchTerm(operator='AND', field='title', term='foo') + "AND", + FieldedSearchTerm(operator="AND", field="title", term="foo"), ) try: terms = advanced._group_terms(query) except AssertionError: - self.fail('Should result in a single group') + self.fail("Should result in a single group") self.assertEqual(expected, terms) diff --git a/search/services/index/util.py b/search/services/index/util.py index f9e845ef..2cb1d9ce 100644 --- a/search/services/index/util.py +++ b/search/services/index/util.py @@ -1,35 +1,56 @@ """Helpers for building ES queries.""" import re -from typing import Any, Optional, Tuple, Union, List from string import punctuation +from typing import Optional, Tuple -from elasticsearch_dsl import Search, Q, SF +from elasticsearch_dsl import Search, Q + +from search import consts from search.domain import Query -from .exceptions import QueryError +from search.services.index.exceptions import QueryError # We'll compile this ahead of time, since it gets called quite a lot. 
STRING_LITERAL = re.compile(r"([\"][^\"]*[\"])") """Pattern for string literals (quoted) in search queries.""" -TEXISM = re.compile(r'(([\$]{2}[^\$]+[\$]{2})|([\$]{1}[^\$]+[\$]{1}))') +TEXISM = re.compile(r"(([\$]{2}[^\$]+[\$]{2})|([\$]{1}[^\$]+[\$]{1}))") # TODO: make this configurable. MAX_RESULTS = 10_000 """This is the maximum result offset for pagination.""" -SPECIAL_CHARACTERS = ['+', '=', '&&', '||', '>', '<', '!', '(', ')', '{', - '}', '[', ']', '^', '~', ':', '\\', '/', '-'] -DEFAULT_SORT = ['-announced_date_first', '_doc'] +SPECIAL_CHARACTERS = [ + "+", + "=", + "&&", + "||", + ">", + "<", + "!", + "(", + ")", + "{", + "}", + "[", + "]", + "^", + "~", + ":", + "\\", + "/", + "-", +] DATE_PARTIAL = r"(?:^|[\s])(\d{2})((?:0[1-9]{1})|(?:1[0-2]{1}))(?:$|[\s])" """Used to match parts of paper IDs that encode the announcement date.""" -OLD_ID_NUMBER = \ - r'(910[7-9]|911[0-2]|9[2-9](0[1-9]|1[0-2])|0[0-6](0[1-9]|1[0-2])|070[1-3])'\ - r'(00[1-9]|0[1-9][0-9]|[1-9][0-9][0-9])' +OLD_ID_NUMBER = ( + r"(910[7-9]|911[0-2]|9[2-9](0[1-9]|1[0-2])|0[0-6](0[1-9]|1[0-2])|070[1-3])" + r"(00[1-9]|0[1-9][0-9]|[1-9][0-9][0-9])" +) """ The number part of the old arXiv identifier looks like YYMMNNN. @@ -56,26 +77,30 @@ def wildcard_escape(querystring: str) -> Tuple[str, bool]: """ # This should get caught by the controller (form validation), but just # in case we should check for it here. - if querystring.startswith('?') or querystring.startswith('*'): - raise QueryError('Query cannot start with a wildcard') + if querystring.startswith("?") or querystring.startswith("*"): + raise QueryError("Query cannot start with a wildcard") # Escape wildcard characters within string literals. # re.sub() can't handle the complexity, sadly... 
parts = re.split(STRING_LITERAL, querystring) - parts = [part.replace('*', r'\*').replace('?', r'\?') - if part.startswith('"') or part.startswith("'") else part - for part in parts] + parts = [ + part.replace("*", r"\*").replace("?", r"\?") + if part.startswith('"') or part.startswith("'") + else part + for part in parts + ] querystring = "".join(parts) # Only unescaped wildcard characters should remain. - wildcard = re.search(r'(? bool: """Determine whether or not ``term`` contains a wildcard.""" - return (('*' in term or '?' in term) and not - (term.startswith('*') or term.startswith('?'))) + return ("*" in term or "?" in term) and not ( + term.startswith("*") or term.startswith("?") + ) def is_literal_query(term: str) -> bool: @@ -96,15 +121,15 @@ def is_old_papernum(term: str) -> bool: def strip_tex(term: str) -> str: """Remove TeX-isms from a term.""" - return re.sub(TEXISM, '', term).strip() + return re.sub(TEXISM, "", term).strip() -def Q_(qtype: str, field: str, value: str, operator: str = 'or') -> Q: +def Q_(qtype: str, field: str, value: str, operator: str = "or") -> Q: """Construct a :class:`.Q`, but handle wildcards first.""" value, wildcard = wildcard_escape(value) if wildcard: - return Q('wildcard', **{field: {'value': value.lower()}}) - if 'match' in qtype: + return Q("wildcard", **{field: {"value": value.lower()}}) + if "match" in qtype: return Q(qtype, **{field: value}) return Q(qtype, **{field: value}, operator=operator) @@ -112,7 +137,7 @@ def Q_(qtype: str, field: str, value: str, operator: str = 'or') -> Q: def escape(term: str, quotes: bool = False) -> str: """Escape special characters.""" escaped = [] - for i, char in enumerate(term): + for char in term: if char in SPECIAL_CHARACTERS or quotes and char == '"': escaped.append("\\") escaped.append(char) @@ -121,22 +146,27 @@ def escape(term: str, quotes: bool = False) -> str: def strip_punctuation(s: str) -> str: """Remove all punctuation characters from a string.""" - return ''.join([c for c 
in s if c not in punctuation]) + return "".join([c for c in s if c not in punctuation]) def remove_single_characters(term: str) -> str: """Remove any single characters in the search string.""" - return ' '.join([part for part in term.split() - if len(strip_punctuation(part)) > 1]) + return " ".join( + [part for part in term.split() if len(strip_punctuation(part)) > 1] + ) def sort(query: Query, search: Search) -> Search: """Apply sorting to a :class:`.Search`.""" if not query.order: - sort_params = DEFAULT_SORT + sort_params = consts.DEFAULT_SORT_ORDER else: - direction = '-' if query.order.startswith('-') else '' - sort_params = [query.order, f'{direction}paper_id_v'] + direction = ( + "-" + if isinstance(query.order, str) and query.order.startswith("-") + else "" + ) + sort_params = [query.order, f"{direction}paper_id_v"] # type:ignore if sort_params is not None: search = search.sort(*sort_params) return search @@ -163,16 +193,16 @@ def parse_date(term: str) -> Tuple[str, str]: Raised if no date-related information is found in `term`. 
""" - match = re.search(r'(?:^|[\s]+)([0-9]{4}-[0-9]{2})(?:$|[\s]+)', term) + match = re.search(r"(?:^|[\s]+)([0-9]{4}-[0-9]{2})(?:$|[\s]+)", term) if match: - remainder = term[:match.start()] + " " + term[match.end():] + remainder = term[: match.start()] + " " + term[match.end() :] return match.group(1), remainder.strip() - match = re.search(r'(?:^|[\s]+)([0-9]{4})(?:$|[\s]+)', term) - if match: # Looks like a year: - remainder = term[:match.start()] + " " + term[match.end():] + match = re.search(r"(?:^|[\s]+)([0-9]{4})(?:$|[\s]+)", term) + if match: # Looks like a year: + remainder = term[: match.start()] + " " + term[match.end() :] return match.group(1), remainder.strip() - raise ValueError('No date info detected') + raise ValueError("No date info detected") def parse_date_partial(term: str) -> Optional[str]: @@ -197,6 +227,6 @@ def parse_date_partial(term: str) -> Optional[str]: year, month = match.groups() # This should be fine until 2091. century = 19 if int(year) >= 91 else 20 - date_partial = f"{century}{year}-{month}" # year_month format in ES. + date_partial = f"{century}{year}-{month}" # year_month format in ES. return date_partial return None diff --git a/search/services/metadata.py b/search/services/metadata.py index 4de6bfb9..276033a3 100644 --- a/search/services/metadata.py +++ b/search/services/metadata.py @@ -10,21 +10,20 @@ depending on the context of the request. 
""" -from typing import Dict, List - -import os -from urllib.parse import urljoin +import ast import json +from typing import List +from http import HTTPStatus from itertools import cycle from functools import wraps +from urllib.parse import urljoin import requests from requests.packages.urllib3.util.retry import Retry -from arxiv import status -from search.context import get_application_config, get_application_global from arxiv.base import logging from search.domain import DocMeta +from search.context import get_application_config, get_application_global logger = logging.getLogger(__name__) @@ -66,25 +65,21 @@ def __init__(self, *endpoints: str, verify_cert: bool = True) -> None: self._session = requests.Session() self._verify_cert = verify_cert self._retry = Retry( # type: ignore - total=10, - read=10, - connect=10, - status=10, - backoff_factor=0.5 + total=10, read=10, connect=10, status=10, backoff_factor=0.5 ) self._adapter = requests.adapters.HTTPAdapter(max_retries=self._retry) - self._session.mount('https://', self._adapter) + self._session.mount("https://", self._adapter) for endpoint in endpoints: - if not endpoint[-1] == '/': - endpoint += '/' - logger.debug(f'New DocMeta session with endpoints {endpoints}') + if not endpoint[-1] == "/": + endpoint += "/" + logger.debug(f"New DocMeta session with endpoints {endpoints}") self._endpoints = cycle(endpoints) @property def endpoint(self) -> str: """Get a metadata endpoint.""" - logger.debug('get next endpoint') + logger.debug("get next endpoint") return self._endpoints.__next__() def retrieve(self, document_id: str) -> DocMeta: @@ -104,45 +99,49 @@ def retrieve(self, document_id: str) -> DocMeta: IOError ValueError """ - if not document_id: # This could use further elaboration. - raise ValueError('Invalid value for document_id') + if not document_id: # This could use further elaboration. 
+ raise ValueError("Invalid value for document_id") try: - target = urljoin(self.endpoint, '/docmeta/') + target = urljoin(self.endpoint, "/docmeta/") target = urljoin(target, document_id) logger.debug( - f'{document_id}: retrieve metadata from {target} with SSL' - f' verify {self._verify_cert}' + f"{document_id}: retrieve metadata from {target} with SSL" + f" verify {self._verify_cert}" + ) + response = requests.get( + target, + verify=self._verify_cert, + headers={"User-Agent": "arXiv/system"}, ) - response = requests.get(target, verify=self._verify_cert, - headers={'User-Agent': 'arXiv/system'}) - except requests.exceptions.SSLError as e: - logger.error('SSLError: %s', e) - raise SecurityException('SSL failed: %s' % e) from e - except requests.exceptions.ConnectionError as e: - logger.error('ConnectionError: %s', e) + except requests.exceptions.SSLError as ex: + logger.error("SSLError: %s", ex) + raise SecurityException("SSL failed: %s" % ex) from ex + except requests.exceptions.ConnectionError as ex: + logger.error("ConnectionError: %s", ex) raise ConnectionFailed( - 'Could not connect to metadata service: %s' % e - ) from e - - if response.status_code not in \ - [status.HTTP_200_OK, status.HTTP_206_PARTIAL_CONTENT]: - logger.error('Request failed: %s', response.content) + "Could not connect to metadata service: %s" % ex + ) from ex + + if response.status_code not in [ + HTTPStatus.OK, + HTTPStatus.PARTIAL_CONTENT, + ]: + logger.error("Request failed: %s", response.content) raise RequestFailed( - '%s: failed with %i: %s' % ( - document_id, response.status_code, response.content - ) + "%s: failed with %i: %s" + % (document_id, response.status_code, response.content) ) - logger.debug(f'{document_id}: response OK') + logger.debug(f"{document_id}: response OK") try: - data = DocMeta(**response.json()) # type: ignore + data = DocMeta(**response.json()) # type: ignore # See https://github.com/python/mypy/issues/3937 - except json.decoder.JSONDecodeError as e: - 
logger.error('JSONDecodeError: %s', e) + except json.decoder.JSONDecodeError as ex: + logger.error("JSONDecodeError: %s", ex) raise BadResponse( - '%s: could not decode response: %s' % (document_id, e) - ) from e - logger.debug(f'{document_id}: response decoded; done!') + "%s: could not decode response: %s" % (document_id, ex) + ) from ex + logger.debug(f"{document_id}: response decoded; done!") return data def bulk_retrieve(self, document_ids: List[str]) -> List[DocMeta]: @@ -162,65 +161,68 @@ def bulk_retrieve(self, document_ids: List[str]) -> List[DocMeta]: IOError ValueError """ - if not document_ids: # This could use further elaboration. - raise ValueError('Invalid value for document_ids') + if not document_ids: # This could use further elaboration. + raise ValueError("Invalid value for document_ids") - query_string = '/docmeta_bulk?' + '&'.join( - f'id={document_id}' for document_id in document_ids + query_string = "/docmeta_bulk?" + "&".join( + f"id={document_id}" for document_id in document_ids ) try: target = urljoin(self.endpoint, query_string) logger.debug( - f'{document_ids}: retrieve metadata from {target} with SSL' - f' verify {self._verify_cert}' + f"{document_ids}: retrieve metadata from {target} with SSL" + f" verify {self._verify_cert}" ) response = self._session.get(target, verify=self._verify_cert) - except requests.exceptions.SSLError as e: - logger.error('SSLError: %s', e) - raise SecurityException('SSL failed: %s' % e) from e - except requests.exceptions.ConnectionError as e: - logger.error('ConnectionError: %s', e) + except requests.exceptions.SSLError as ex: + logger.error("SSLError: %s", ex) + raise SecurityException("SSL failed: %s" % ex) from ex + except requests.exceptions.ConnectionError as ex: + logger.error("ConnectionError: %s", ex) raise ConnectionFailed( - 'Could not connect to metadata service: %s' % e - ) from e - - if response.status_code not in \ - [status.HTTP_200_OK, status.HTTP_206_PARTIAL_CONTENT]: - logger.error('Request 
failed: %s', response.content) + "Could not connect to metadata service: %s" % ex + ) from ex + + if response.status_code not in [ + HTTPStatus.OK, + HTTPStatus.PARTIAL_CONTENT, + ]: + logger.error("Request failed: %s", response.content) raise RequestFailed( - '%s: failed with %i: %s' % ( - document_ids, response.status_code, response.content - ) + "%s: failed with %i: %s" + % (document_ids, response.status_code, response.content) ) - logger.debug(f'{document_ids}: response OK') + logger.debug(f"{document_ids}: response OK") try: resp = response.json() # A list with metadata for each paper. data: List[DocMeta] - data = [DocMeta(**value) for value in resp] # type: ignore - except json.decoder.JSONDecodeError as e: - logger.error('JSONDecodeError: %s', e) + data = [DocMeta(**value) for value in resp] # type: ignore + except json.decoder.JSONDecodeError as ex: + logger.error("JSONDecodeError: %s", ex) raise BadResponse( - '%s: could not decode response: %s' % (document_ids, e) - ) from e - logger.debug(f'{document_ids}: response decoded; done!') + "%s: could not decode response: %s" % (document_ids, ex) + ) from ex + logger.debug(f"{document_ids}: response decoded; done!") return data def init_app(app: object = None) -> None: """Set default configuration parameters for an application instance.""" config = get_application_config(app) - config.setdefault('METADATA_ENDPOINT', 'https://arxiv.org/') - config.setdefault('METADATA_VERIFY_CERT', 'True') + config.setdefault("METADATA_ENDPOINT", "https://arxiv.org/") + config.setdefault("METADATA_VERIFY_CERT", "True") def get_session(app: object = None) -> DocMetaSession: """Get a new session with the docmeta endpoint.""" config = get_application_config(app) - endpoint = config.get('METADATA_ENDPOINT', 'https://arxiv.org/') - verify_cert = bool(eval(config.get('METADATA_VERIFY_CERT', 'True'))) - if ',' in endpoint: - return DocMetaSession(*(endpoint.split(',')), verify_cert=verify_cert) + endpoint = 
config.get("METADATA_ENDPOINT", "https://arxiv.org/") + verify_cert = bool( + ast.literal_eval(config.get("METADATA_VERIFY_CERT", "True")) + ) + if "," in endpoint: + return DocMetaSession(*(endpoint.split(",")), verify_cert=verify_cert) return DocMetaSession(endpoint, verify_cert=verify_cert) @@ -229,9 +231,9 @@ def current_session() -> DocMetaSession: g = get_application_global() if not g: return get_session() - elif 'docmeta' not in g: - g.docmeta = get_session() # type: ignore - return g.docmeta # type: ignore + elif "docmeta" not in g: + g.docmeta = get_session() # type: ignore + return g.docmeta # type: ignore @wraps(DocMetaSession.retrieve) diff --git a/search/services/tests/test_fulltext.py b/search/services/tests/test_fulltext.py index 93ba9711..384e1e79 100644 --- a/search/services/tests/test_fulltext.py +++ b/search/services/tests/test_fulltext.py @@ -8,16 +8,18 @@ class TestRetrieveExistantContent(unittest.TestCase): """Fulltext content is available for a paper.""" - @mock.patch('search.services.fulltext.requests.get') + @mock.patch("search.services.fulltext.requests.get") def test_calls_fulltext_endpoint(self, mock_get): """:func:`.fulltext.retrieve` calls passed endpoint with GET.""" - base = 'https://asdf.com/' + base = "https://asdf.com/" response = mock.MagicMock() - type(response).json = mock.MagicMock(return_value={ - 'content': 'The whole story', - 'version': 0.1, - 'created': '2017-08-30T08:24:58.525923' - }) + type(response).json = mock.MagicMock( + return_value={ + "content": "The whole story", + "version": 0.1, + "created": "2017-08-30T08:24:58.525923", + } + ) response.status_code = 200 mock_get.return_value = response @@ -25,9 +27,9 @@ def test_calls_fulltext_endpoint(self, mock_get): fulltext_session.endpoint = base try: - fulltext_session.retrieve('1234.5678v3') - except Exception as e: - self.fail('Choked on valid response: %s' % e) + fulltext_session.retrieve("1234.5678v3") + except Exception as ex: + self.fail("Choked on valid 
response: %s" % ex) args, _ = mock_get.call_args self.assertTrue(args[0].startswith(base)) @@ -35,7 +37,7 @@ def test_calls_fulltext_endpoint(self, mock_get): class TestRetrieveNonexistantRecord(unittest.TestCase): """Fulltext content is not available for a paper.""" - @mock.patch('search.services.fulltext.requests.get') + @mock.patch("search.services.fulltext.requests.get") def test_raise_ioerror_on_404(self, mock_get): """:func:`.fulltext.retrieve` raises IOError when text unvailable.""" response = mock.MagicMock() @@ -43,9 +45,9 @@ def test_raise_ioerror_on_404(self, mock_get): response.status_code = 404 mock_get.return_value = response with self.assertRaises(IOError): - fulltext.retrieve('1234.5678v3') + fulltext.retrieve("1234.5678v3") - @mock.patch('search.services.fulltext.requests.get') + @mock.patch("search.services.fulltext.requests.get") def test_raise_ioerror_on_503(self, mock_get): """:func:`.fulltext.retrieve` raises IOError when text unvailable.""" response = mock.MagicMock() @@ -53,38 +55,40 @@ def test_raise_ioerror_on_503(self, mock_get): response.status_code = 503 mock_get.return_value = response with self.assertRaises(IOError): - fulltext.retrieve('1234.5678v3') + fulltext.retrieve("1234.5678v3") - @mock.patch('search.services.fulltext.requests.get') + @mock.patch("search.services.fulltext.requests.get") def test_raise_ioerror_on_sslerror(self, mock_get): """:func:`.fulltext.retrieve` raises IOError when SSL fails.""" from requests.exceptions import SSLError + mock_get.side_effect = SSLError with self.assertRaises(IOError): try: - fulltext.retrieve('1234.5678v3') - except Exception as e: - if type(e) is SSLError: - self.fail('Should not return dependency exception') + fulltext.retrieve("1234.5678v3") + except Exception as ex: + if type(ex) is SSLError: + self.fail("Should not return dependency exception") raise class TestRetrieveMalformedRecord(unittest.TestCase): """Fulltext endpoint returns non-JSON response.""" - 
@mock.patch('search.services.fulltext.requests.get') + @mock.patch("search.services.fulltext.requests.get") def test_response_is_not_json(self, mock_get): """:func:`.fulltext.retrieve` raises IOError when not valid JSON.""" from json.decoder import JSONDecodeError + response = mock.MagicMock() # Ideally we would pass the exception itself as a side_effect, but it # doesn't have the expected signature. def raise_decodeerror(*args, **kwargs): - raise JSONDecodeError('Nope', 'Nope', 0) + raise JSONDecodeError("Nope", "Nope", 0) type(response).json = mock.MagicMock(side_effect=raise_decodeerror) response.status_code = 200 mock_get.return_value = response with self.assertRaises(IOError): - fulltext.retrieve('1234.5678v3') + fulltext.retrieve("1234.5678v3") diff --git a/search/services/tests/test_metadata.py b/search/services/tests/test_metadata.py index 2e56d874..562bc39e 100644 --- a/search/services/tests/test_metadata.py +++ b/search/services/tests/test_metadata.py @@ -1,10 +1,8 @@ """Tests for :mod:`search.services.metadata`.""" +import json import unittest from unittest import mock -import json -import os -from itertools import cycle from search.services import metadata from search.factory import create_ui_web_app @@ -13,16 +11,16 @@ class TestRetrieveExistantMetadata(unittest.TestCase): """Metadata is available for a paper.""" - @mock.patch('search.services.metadata.requests.get') + @mock.patch("search.services.metadata.requests.get") def test_calls_metadata_endpoint(self, mock_get): """:func:`.metadata.retrieve` calls passed endpoint with GET.""" - base = 'https://asdf.com/' + base = "https://asdf.com/" app = create_ui_web_app() - app.config['METADATA_ENDPOINT'] = base + app.config["METADATA_ENDPOINT"] = base response = mock.MagicMock() - with open('tests/data/docmeta.json') as f: + with open("tests/data/docmeta.json") as f: mock_content = json.load(f) type(response).json = mock.MagicMock(return_value=mock_content) @@ -33,26 +31,26 @@ def 
test_calls_metadata_endpoint(self, mock_get): docmeta_session = metadata.get_session() try: - docmeta_session.retrieve('1602.00123') - except Exception as e: - self.fail('Choked on valid response: %s' % e) + docmeta_session.retrieve("1602.00123") + except Exception as ex: + self.fail("Choked on valid response: %s" % ex) try: args, _ = mock_get.call_args - except Exception as e: - self.fail('Did not call requests.get as expected: %s' % e) + except Exception as ex: + self.fail("Did not call requests.get as expected: %s" % ex) self.assertTrue(args[0].startswith(base)) - @mock.patch('search.services.metadata.requests.get') + @mock.patch("search.services.metadata.requests.get") def test_calls_metadata_endpoint_roundrobin(self, mock_get): """:func:`.metadata.retrieve` calls passed endpoint with GET.""" - base = ['https://asdf.com/', 'https://asdf2.com/'] + base = ["https://asdf.com/", "https://asdf2.com/"] app = create_ui_web_app() - app.config['METADATA_ENDPOINT'] = ','.join(base) - app.config['METADATA_VERIFY_CERT'] = 'False' + app.config["METADATA_ENDPOINT"] = ",".join(base) + app.config["METADATA_VERIFY_CERT"] = "False" response = mock.MagicMock() - with open('tests/data/docmeta.json') as f: + with open("tests/data/docmeta.json") as f: mock_content = json.load(f) type(response).json = mock.MagicMock(return_value=mock_content) @@ -63,25 +61,25 @@ def test_calls_metadata_endpoint_roundrobin(self, mock_get): docmeta_session = metadata.get_session() try: - docmeta_session.retrieve('1602.00123') - except Exception as e: - self.fail('Choked on valid response: %s' % e) + docmeta_session.retrieve("1602.00123") + except Exception as ex: + self.fail("Choked on valid response: %s" % ex) try: args, _ = mock_get.call_args - except Exception as e: - self.fail('Did not call requests.get as expected: %s' % e) + except Exception as ex: + self.fail("Did not call requests.get as expected: %s" % ex) self.assertTrue( args[0].startswith(base[0]), "Expected call to %s" % base[0] ) try: - 
docmeta_session.retrieve('1602.00124') - except Exception as e: - self.fail('Choked on valid response: %s' % e) + docmeta_session.retrieve("1602.00124") + except Exception as ex: + self.fail("Choked on valid response: %s" % ex) try: args, _ = mock_get.call_args - except Exception as e: - self.fail('Did not call requests.get as expected: %s' % e) + except Exception as ex: + self.fail("Did not call requests.get as expected: %s" % ex) self.assertTrue( args[0].startswith(base[1]), "Expected call to %s" % base[1] ) @@ -90,7 +88,7 @@ def test_calls_metadata_endpoint_roundrobin(self, mock_get): class TestRetrieveNonexistantRecord(unittest.TestCase): """Metadata is not available for a paper.""" - @mock.patch('search.services.metadata.requests.get') + @mock.patch("search.services.metadata.requests.get") def test_raise_ioerror_on_404(self, mock_get): """:func:`.metadata.retrieve` raises IOError when unvailable.""" response = mock.MagicMock() @@ -98,9 +96,9 @@ def test_raise_ioerror_on_404(self, mock_get): response.status_code = 404 mock_get.return_value = response with self.assertRaises(IOError): - metadata.retrieve('1234.5678v3') + metadata.retrieve("1234.5678v3") - @mock.patch('search.services.metadata.requests.get') + @mock.patch("search.services.metadata.requests.get") def test_raise_ioerror_on_503(self, mock_get): """:func:`.metadata.retrieve` raises IOError when unvailable.""" response = mock.MagicMock() @@ -108,38 +106,40 @@ def test_raise_ioerror_on_503(self, mock_get): response.status_code = 503 mock_get.return_value = response with self.assertRaises(IOError): - metadata.retrieve('1234.5678v3') + metadata.retrieve("1234.5678v3") - @mock.patch('search.services.metadata.requests.get') + @mock.patch("search.services.metadata.requests.get") def test_raise_ioerror_on_sslerror(self, mock_get): """:func:`.metadata.retrieve` raises IOError when SSL fails.""" from requests.exceptions import SSLError + mock_get.side_effect = SSLError with self.assertRaises(IOError): try: - 
metadata.retrieve('1234.5678v3') - except Exception as e: - if type(e) is SSLError: - self.fail('Should not return dependency exception') + metadata.retrieve("1234.5678v3") + except Exception as ex: + if type(ex) is SSLError: + self.fail("Should not return dependency exception") raise class TestRetrieveMalformedRecord(unittest.TestCase): """Metadata endpoint returns non-JSON response.""" - @mock.patch('search.services.metadata.requests.get') + @mock.patch("search.services.metadata.requests.get") def test_response_is_not_json(self, mock_get): """:func:`.metadata.retrieve` raises IOError when not valid JSON.""" from json.decoder import JSONDecodeError + response = mock.MagicMock() # Ideally we would pass the exception itself as a side_effect, but it # doesn't have the expected signature. def raise_decodeerror(*args, **kwargs): - raise JSONDecodeError('Nope', 'Nope', 0) + raise JSONDecodeError("Nope", "Nope", 0) type(response).json = mock.MagicMock(side_effect=raise_decodeerror) response.status_code = 200 mock_get.return_value = response with self.assertRaises(IOError): - metadata.retrieve('1234.5678v3') + metadata.retrieve("1234.5678v3") diff --git a/search/templates/search/advanced_search.html b/search/templates/search/advanced_search.html index f3baac09..4f75f4ee 100644 --- a/search/templates/search/advanced_search.html +++ b/search/templates/search/advanced_search.html @@ -324,7 +324,7 @@ {% block title %} {% if not show_form and results %} - Showing {{ metadata.start + 1 }}–{{ metadata.end }} of {{ '{0:,}'.format(metadata.total) }} results + Showing {{ metadata.start + 1 }}–{{ metadata.end }} of {{ '{0:,}'.format(metadata.total_results) }} results {% elif show_form %} Advanced Search {% else %} diff --git a/search/templates/search/base.html b/search/templates/search/base.html index de259d43..0763a231 100644 --- a/search/templates/search/base.html +++ b/search/templates/search/base.html @@ -28,7 +28,7 @@ }, fieldValues: { "components": ["16000"], // Search 
component. - "versions": ["14157"], // Release search-0.5 + "versions": ["14260"], // Release search-0.5.6 "customfield_11401": window.location.href } }; diff --git a/search/templates/search/search.html b/search/templates/search/search.html index 5e75e67f..4d21f11b 100644 --- a/search/templates/search/search.html +++ b/search/templates/search/search.html @@ -4,7 +4,7 @@ {% block title %} {% if results %} - Showing {{ metadata.start + 1 }}–{{ metadata.end }} of {{ '{0:,}'.format(metadata.total) }} results for {{ query.search_field }}: {{ query.value }} + Showing {{ metadata.start + 1 }}–{{ metadata.end }} of {{ '{0:,}'.format(metadata.total_results) }} results for {{ query.search_field }}: {{ query.value }} {% else %} Search {% endif %} diff --git a/search/tests/mocks.py b/search/tests/mocks.py new file mode 100644 index 00000000..ba78a41c --- /dev/null +++ b/search/tests/mocks.py @@ -0,0 +1,46 @@ +from datetime import datetime + + +def document(): + """Return a mock document.""" + return { + "submitted_date": datetime.now(), + "submitted_date_first": datetime.now(), + "announced_date_first": datetime.now(), + "id": "1234.5678", + "abstract": "very abstract", + "authors": [{"full_name": "F. Bar", "orcid": "1234-5678-9012-3456"}], + "submitter": {"full_name": "S. 
Ubmitter", "author_id": "su_1"}, + "modified_date": datetime.now(), + "updated_date": datetime.now(), + "is_current": True, + "is_withdrawn": False, + "license": {"uri": "http://foo.license/1", "label": "Notalicense 5.4"}, + "paper_id": "1234.5678", + "paper_id_v": "1234.5678v6", + "title": "tiiiitle", + "source": {"flags": "A", "format": "pdftotex", "size_bytes": 2}, + "version": 6, + "latest": "1234.5678v6", + "latest_version": 6, + "report_num": "somenum1", + "msc_class": ["c1"], + "acm_class": ["z2"], + "journal_ref": "somejournal (1991): 2-34", + "doi": "10.123456/7890", + "comments": "very science", + "abs_categories": "astro-ph.CO foo.BR", + "formats": ["pdf", "other"], + "primary_classification": { + "group": {"id": "foo", "name": "Foo Group"}, + "archive": {"id": "foo", "name": "Foo Archive"}, + "category": {"id": "foo.BR", "name": "Foo Category"}, + }, + "secondary_classification": [ + { + "group": {"id": "foo", "name": "Foo Group"}, + "archive": {"id": "foo", "name": "Foo Archive"}, + "category": {"id": "foo.BZ", "name": "Baz Category"}, + } + ], + } diff --git a/search/tests/test_advanced_search.py b/search/tests/test_advanced_search.py index 72dba16a..8793dad4 100644 --- a/search/tests/test_advanced_search.py +++ b/search/tests/test_advanced_search.py @@ -1,6 +1,7 @@ -from unittest import TestCase, mock +from http import HTTPStatus +from unittest import TestCase -from arxiv import taxonomy, status +from arxiv import taxonomy from search.factory import create_ui_web_app @@ -15,12 +16,18 @@ def setUp(self): def test_archive_shortcut(self): """User requests a sub-path with classification archive.""" for archive in taxonomy.ARCHIVES.keys(): - response = self.client.get(f'/advanced/{archive}') - self.assertEqual(response.status_code, status.HTTP_200_OK, - "Should support shortcut for archive {archive}") + response = self.client.get(f"/advanced/{archive}") + self.assertEqual( + response.status_code, + HTTPStatus.OK, + "Should support shortcut for archive 
{archive}", + ) def test_nonexistant_archive_shortcut(self): """User requests a sub-path with non-existant archive.""" - response = self.client.get('/advanced/fooarchive') - self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND, - "Should return a 404 error") + response = self.client.get("/advanced/fooarchive") + self.assertEqual( + response.status_code, + HTTPStatus.NOT_FOUND, + "Should return a 404 error", + ) diff --git a/search/tests/test_param_persistence.py b/search/tests/test_param_persistence.py index faae8170..285f3927 100644 --- a/search/tests/test_param_persistence.py +++ b/search/tests/test_param_persistence.py @@ -1,8 +1,10 @@ """Tests related to the persistence of search parameters in a cookie.""" -from unittest import TestCase, mock import json +from unittest import TestCase, mock + from search.factory import create_ui_web_app +from search.controllers.simple.forms import SimpleSearchForm from search.routes import ui @@ -16,45 +18,61 @@ def setUp(self): def test_request_includes_params(self): """A request is made with parameters indicated for persistence.""" - ui.PARAMS_TO_PERSIST = ['foo', 'baz'] - ui.PARAMS_COOKIE_NAME = 'foo-cookie' - response = self.client.get('/?foo=bar&baz=bat') + ui.PARAMS_TO_PERSIST = ["foo", "baz"] + ui.PARAMS_COOKIE_NAME = "foo-cookie" + response = self.client.get("/?foo=bar&baz=bat") - self.assertIn('Set-Cookie', response.headers, "Should set a cookie") - expected = 'foo-cookie="{\\"foo\\": \\"bar\\"\\054 \\"baz\\": \\"bat\\"}"; Path=/' - self.assertEqual(response.headers['Set-Cookie'], expected, - "Cookie should contain request params") + self.assertIn("Set-Cookie", response.headers, "Should set a cookie") + expected = ( + 'foo-cookie="{\\"foo\\": \\"bar\\"\\054 \\"baz\\": \\"bat\\"}"; ' + "Path=/" + ) + self.assertEqual( + response.headers["Set-Cookie"], + expected, + "Cookie should contain request params", + ) def test_request_does_not_include_params(self): """The request does not include persistable 
params.""" - ui.PARAMS_TO_PERSIST = ['foo', 'baz'] - ui.PARAMS_COOKIE_NAME = 'foo-cookie' - response = self.client.get('/?nope=nope') - self.assertIn('Set-Cookie', response.headers, "Should set a cookie") - self.assertEqual(response.headers['Set-Cookie'], - 'foo-cookie="{}"; Path=/', - "Cookie should not contain request params") - - @mock.patch('search.routes.ui.simple') + ui.PARAMS_TO_PERSIST = ["foo", "baz"] + ui.PARAMS_COOKIE_NAME = "foo-cookie" + response = self.client.get("/?nope=nope") + self.assertIn("Set-Cookie", response.headers, "Should set a cookie") + self.assertEqual( + response.headers["Set-Cookie"], + 'foo-cookie="{}"; Path=/', + "Cookie should not contain request params", + ) + + @mock.patch("search.routes.ui.simple") def test_request_includes_cookie(self, mock_simple): """The request includes the params cookie.""" - mock_simple.search.return_value = '', 200, {} - ui.PARAMS_TO_PERSIST = ['foo', 'baz'] - ui.PARAMS_COOKIE_NAME = 'foo-cookie' - self.client.set_cookie('', ui.PARAMS_COOKIE_NAME, - json.dumps({'foo': 'ack'})) - self.client.get('/') - self.assertEqual(mock_simple.search.call_args[0][0]['foo'], 'ack', - 'The value in the cookie should be used') - - @mock.patch('search.routes.ui.simple') + mock_simple.search.return_value = {"form": SimpleSearchForm()}, 200, {} + ui.PARAMS_TO_PERSIST = ["foo", "baz"] + ui.PARAMS_COOKIE_NAME = "foo-cookie" + self.client.set_cookie( + "", ui.PARAMS_COOKIE_NAME, json.dumps({"foo": "ack"}) + ) + self.client.get("/") + self.assertEqual( + mock_simple.search.call_args[0][0]["foo"], + "ack", + "The value in the cookie should be used", + ) + + @mock.patch("search.routes.ui.simple") def test_request_includes_cookie_but_also_explicit_val(self, mock_simple): """The request includes the cookie, but also an explicit value.""" - mock_simple.search.return_value = '', 200, {} - ui.PARAMS_TO_PERSIST = ['foo', 'baz'] - ui.PARAMS_COOKIE_NAME = 'foo-cookie' - self.client.set_cookie('', ui.PARAMS_COOKIE_NAME, - json.dumps({'foo': 
'ack'})) - self.client.get('/?foo=oof') - self.assertEqual(mock_simple.search.call_args[0][0]['foo'], 'oof', - 'The explicit value should be used') + mock_simple.search.return_value = {"form": SimpleSearchForm()}, 200, {} + ui.PARAMS_TO_PERSIST = ["foo", "baz"] + ui.PARAMS_COOKIE_NAME = "foo-cookie" + self.client.set_cookie( + "", ui.PARAMS_COOKIE_NAME, json.dumps({"foo": "ack"}) + ) + self.client.get("/?foo=oof") + self.assertEqual( + mock_simple.search.call_args[0][0]["foo"], + "oof", + "The explicit value should be used", + ) diff --git a/setup.cfg b/setup.cfg index 2539e514..05677b17 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,3 +1,6 @@ +[flake8] +ignore = E203,W503 + [pydocstyle] convention = numpy -add-ignore = D401 +add-ignore = D100,D101,D102,D103,D104,D202,D401 diff --git a/tests/base_app_tests.py b/tests/base_app_tests.py new file mode 100644 index 00000000..18080cf3 --- /dev/null +++ b/tests/base_app_tests.py @@ -0,0 +1,14 @@ +""" +Run :mod:`arxiv.base.app_tests`. + +These are run separately from the rest of the tests in :mod:`search`. 
+""" + +import unittest +from search.factory import create_ui_web_app + +app = create_ui_web_app() +app.app_context().push() + +if __name__ == "__main__": + unittest.main() diff --git a/tests/stubs/docmeta.py b/tests/stubs/docmeta.py index 7832e7c7..dab16bf0 100644 --- a/tests/stubs/docmeta.py +++ b/tests/stubs/docmeta.py @@ -11,26 +11,26 @@ logger = logging.getLogger(__name__) -METADATA_DIR = os.environ.get('METADATA_DIR') +METADATA_DIR = os.environ.get("METADATA_DIR") -app = Flask('metadata') +app = Flask("metadata") Base(app) -app.url_map.converters['arxiv'] = ArXivConverter +app.url_map.converters["arxiv"] = ArXivConverter -@app.route('/docmeta/', methods=["GET"]) +@app.route("/docmeta/", methods=["GET"]) def docmeta(document_id): """Retrieve document metadata.""" - logger.debug(f'Get metadata for {document_id}') - logger.debug(f'Metadata base is {METADATA_DIR}') + logger.debug(f"Get metadata for {document_id}") + logger.debug(f"Metadata base is {METADATA_DIR}") if not METADATA_DIR: - raise InternalServerError('Metadata directory not set') + raise InternalServerError("Metadata directory not set") metadata_path = os.path.join(METADATA_DIR, f"{document_id}.json") - logger.debug(f'Metadata path is {metadata_path}') + logger.debug(f"Metadata path is {metadata_path}") if not os.path.exists(metadata_path): - raise NotFound('No such document') + raise NotFound("No such document") with open(metadata_path) as f: return jsonify(json.load(f)) diff --git a/tests/test_exceptions.py b/tests/test_exceptions.py index 828e9ba9..d1f0f491 100644 --- a/tests/test_exceptions.py +++ b/tests/test_exceptions.py @@ -1,11 +1,12 @@ """Tests exception handling in :mod:`arxiv.base.exceptions`.""" +from http import HTTPStatus from unittest import TestCase, mock -from flask import Flask -from arxiv import status -from search.factory import create_ui_web_app from werkzeug.exceptions import InternalServerError + +from search.controllers import simple +from search.factory import 
create_ui_web_app from search.services.index import IndexConnectionError, QueryError @@ -19,45 +20,44 @@ def setUp(self): def test_404(self): """A 404 response should be returned.""" - response = self.client.get('/foo') - self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND) - self.assertIn('text/html', response.content_type) + response = self.client.get("/foo") + self.assertEqual(response.status_code, HTTPStatus.NOT_FOUND) + self.assertIn("text/html", response.content_type) def test_405(self): """A 405 response should be returned.""" - response = self.client.post('/') - self.assertEqual(response.status_code, - status.HTTP_405_METHOD_NOT_ALLOWED) - self.assertIn('text/html', response.content_type) + response = self.client.post("/") + self.assertEqual(response.status_code, HTTPStatus.METHOD_NOT_ALLOWED) + self.assertIn("text/html", response.content_type) - @mock.patch('search.controllers.simple.search') + @mock.patch("search.controllers.simple.search") def test_500(self, mock_search): """A 500 response should be returned.""" # Raise an internal server error from the search controller. 
mock_search.side_effect = InternalServerError - response = self.client.get('/') - self.assertEqual(response.status_code, - status.HTTP_500_INTERNAL_SERVER_ERROR) - self.assertIn('text/html', response.content_type) + response = self.client.get("/") + self.assertEqual( + response.status_code, HTTPStatus.INTERNAL_SERVER_ERROR + ) + self.assertIn("text/html", response.content_type) - @mock.patch('search.controllers.simple.index') - def test_index_connection_error(self, mock_index): + @mock.patch(f"{simple.__name__}.SearchSession.search") + def test_index_connection_error(self, mock_search): """When an IndexConnectionError occurs, an error page is displayed.""" - mock_index.IndexConnectionError = IndexConnectionError - mock_index.search.side_effect = IndexConnectionError - response = self.client.get('/?searchtype=title&query=foo') - self.assertEqual(response.status_code, - status.HTTP_500_INTERNAL_SERVER_ERROR) - self.assertIn('text/html', response.content_type) - - @mock.patch('search.controllers.simple.index') - def test_query_error(self, mock_index): + mock_search.side_effect = IndexConnectionError + response = self.client.get("/?searchtype=title&query=foo") + self.assertEqual( + response.status_code, HTTPStatus.INTERNAL_SERVER_ERROR + ) + self.assertIn("text/html", response.content_type) + + @mock.patch(f"{simple.__name__}.SearchSession.search") + def test_query_error(self, mock_search): """When a QueryError occurs, an error page is displayed.""" - mock_index.IndexConnectionError = IndexConnectionError - mock_index.QueryError = QueryError - mock_index.search.side_effect = QueryError - response = self.client.get('/?searchtype=title&query=foo') - self.assertEqual(response.status_code, - status.HTTP_500_INTERNAL_SERVER_ERROR) - self.assertIn('text/html', response.content_type) + mock_search.side_effect = QueryError + response = self.client.get("/?searchtype=title&query=foo") + self.assertEqual( + response.status_code, HTTPStatus.INTERNAL_SERVER_ERROR + ) + 
self.assertIn("text/html", response.content_type) diff --git a/upload_static_assets.py b/upload_static_assets.py index a3cf3fae..50fa3b18 100644 --- a/upload_static_assets.py +++ b/upload_static_assets.py @@ -6,4 +6,4 @@ app = create_ui_web_app() # TODO: need a better way to exclude sass directories at every level -flask_s3.create_all(app, filepath_filter_regex=r'(base|css|images|js)') +flask_s3.create_all(app, filepath_filter_regex=r"(base|css|images|js)") diff --git a/wsgi-api.py b/wsgi-api.py index 1d41efba..3c76e577 100644 --- a/wsgi-api.py +++ b/wsgi-api.py @@ -1,10 +1,10 @@ -"""Web Server Gateway Interface entry-point.""" +"""Web Server Gateway Interface entry-point for API.""" from search.factory import create_api_web_app import os -__flask_app__ = create_api_web_app() +__flask_app__ = None def application(environ, start_response): @@ -15,8 +15,12 @@ def application(environ, start_response): # be a container ID, which is not helpful for things like building # URLs. We want to keep ``SERVER_NAME`` explicitly configured, either # in config.py or via an os.environ var loaded by config.py. - if key == 'SERVER_NAME': + if key == "SERVER_NAME": continue - os.environ[key] = str(value) - __flask_app__.config[key] = str(value) + if type(value) is str: + os.environ[key] = value + global __flask_app__ + if __flask_app__ is None: + __flask_app__ = create_api_web_app() + return __flask_app__(environ, start_response) diff --git a/wsgi-classic-api.py b/wsgi-classic-api.py new file mode 100644 index 00000000..33e1ddf4 --- /dev/null +++ b/wsgi-classic-api.py @@ -0,0 +1,26 @@ +"""Web Server Gateway Interface entry-point for classic API.""" + +from search.factory import create_classic_api_web_app +import os + + +__flask_app__ = None + + +def application(environ, start_response): + """WSGI application factory.""" + for key, value in environ.items(): + # In some deployment scenarios (e.g. uWSGI on k8s), uWSGI will pass in + # the hostname as part of the request environ. 
This will usually just + # be a container ID, which is not helpful for things like building + # URLs. We want to keep ``SERVER_NAME`` explicitly configured, either + # in config.py or via an os.environ var loaded by config.py. + if key == "SERVER_NAME": + continue + if type(value) is str: + os.environ[key] = value + global __flask_app__ + if __flask_app__ is None: + __flask_app__ = create_classic_api_web_app() + + return __flask_app__(environ, start_response) diff --git a/wsgi.py b/wsgi.py index 600c2530..dea38631 100644 --- a/wsgi.py +++ b/wsgi.py @@ -1,8 +1,11 @@ -"""Web Server Gateway Interface entry-point.""" +"""Web Server Gateway Interface entry-point for UI.""" -from search.factory import create_ui_web_app import os +from arxiv.base import logging + +from search.factory import create_ui_web_app +logger = logging.getLogger(__name__) __flask_app__ = None @@ -10,14 +13,22 @@ def application(environ, start_response): """WSGI application factory.""" for key, value in environ.items(): + # Copy string WSGI environ to os.environ. This is to get apache + # SetEnv vars. It needs to be done before the call to + # create_web_app() due to how config is setup from os in + # search/config.py. + # # In some deployment scenarios (e.g. uWSGI on k8s), uWSGI will pass in # the hostname as part of the request environ. This will usually just # be a container ID, which is not helpful for things like building # URLs. We want to keep ``SERVER_NAME`` explicitly configured, either # in config.py or via an os.environ var loaded by config.py. - if key == 'SERVER_NAME': + if key == "SERVER_NAME": continue - os.environ[key] = str(value) + if type(value) is str: + os.environ[key] = value + # 'global' actually means module scope, and that is exactly what + # we want here. global __flask_app__ if __flask_app__ is None: __flask_app__ = create_ui_web_app()