diff --git a/.travis.yml b/.travis.yml index 017ca611fd..83e9b1409b 100644 --- a/.travis.yml +++ b/.travis.yml @@ -9,7 +9,7 @@ python: before_install: - cd flowclient install: - - pip install . + - pip install .[test] # command to run tests script: - pytest \ No newline at end of file diff --git a/CHANGELOG.md b/CHANGELOG.md index 778508a0f3..0bc472d228 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -23,6 +23,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/). - Added new Flows type query to FlowAPI `unique_locations`, which produces the paired regional connectivity [COVID-19 indicator](https://github.com/Flowminder/COVID-19/blob/master/od_matrix_undirected_all_pairs.md) - Added FlowClient function `unique_locations_spec`, which can be used on either side of a `flows` query - Added FlowClient functions: `unique_visitor_counts`, `active_at_reference_location_counts`, `unmoving_counts`, `unmoving_at_reference_location_counts`, `trips_od_matrix`, and `consecutive_trips_od_matrix`. [#2333](https://github.com/Flowminder/FlowKit/issues/2333) +- FlowClient now has an asyncio API. Use `connect_async` instead of `connect` to create an `ASyncConnection`, and `await` methods on `APIQuery` objects. [#2199](https://github.com/Flowminder/FlowKit/issues/2199) ### Fixed - Fixed FlowMachine server becoming deadlocked under load. [#2390](https://github.com/Flowminder/FlowKit/issues/2390) diff --git a/flowclient/Pipfile b/flowclient/Pipfile index 2625cd5dfc..a7219c5653 100644 --- a/flowclient/Pipfile +++ b/flowclient/Pipfile @@ -12,7 +12,9 @@ tqdm = "*" [dev-packages] pytest = "*" -"pytest-cov" = "*" +pytest-asyncio = "*" +pytest-cov = "*" +asynctest = "*" versioneer = "*" black = "==19.10b0" ipython = "*" diff --git a/flowclient/Pipfile.lock b/flowclient/Pipfile.lock index 7604f39a3f..7b01bd3d7f 100644 --- a/flowclient/Pipfile.lock +++ b/flowclient/Pipfile.lock @@ -1,7 +1,7 @@ { "_meta": { "hash": { - "sha256": "8941bdfdbeb2ccdd876f86dc0913c5214e241d69ffb3aa269b6517817d766e68" + "sha256": "9f0c5efc4b89aca5df70f7b758e5197add4b0b69da95d8139ce425ac92655c30" }, "pipfile-spec": 6, "requires": { @@ -162,6 +162,14 @@ "markers": "sys_platform == 'darwin'", "version": "==0.1.0" }, + "asynctest": { + "hashes": [ + "sha256:5da6118a7e6d6b54d83a8f7197769d046922a44d2a99c21382f0a6e4fadae676", + "sha256:c27862842d15d83e6a34eb0b2866c323880eb3a75e4485b079ea11748fd77fac" + ], + "index": "pypi", + "version": "==0.13.0" + }, "attrs": { "hashes": [ "sha256:08a96c641c3a74e44eb59afb61a24f2cb9f4d7188748e76ba4bb5edfa3cb7d1c", @@ -364,6 +372,14 @@ "index": "pypi", "version": "==5.4.1" }, + "pytest-asyncio": { + "hashes": [ + "sha256:9fac5100fd716cbecf6ef89233e8590a4ad61d729d1732e0a96b84182df1daaf", + "sha256:d734718e25cfc32d2bf78d346e99d33724deeba774cc4afdf491530c6184b63b" + ], + "index": "pypi", + "version": "==0.10.0" + }, "pytest-cov": { "hashes": [ "sha256:cc6742d8bac45070217169f5f72ceee1e0e55b0221f54bcf24845972d3a47f2b", diff --git a/flowclient/flowclient/__init__.py b/flowclient/flowclient/__init__.py index a8e59ea713..4be6141d47 100644 --- a/flowclient/flowclient/__init__.py +++ b/flowclient/flowclient/__init__.py @@ -10,12 +10,16 @@ __version__ = get_versions()["version"] del get_versions +from flowclient.api_query import APIQuery +from .connection import Connection +from flowclient.client import connect + +from flowclient.async_api_query import ASyncAPIQuery +from .async_connection import ASyncConnection +from flowclient.async_client import connect_async + + from .client import 
( - Connection, - connect, - daily_location_spec, - modal_location_spec, - modal_location_from_dates_spec, get_geography, get_result, get_result_by_query_id, @@ -24,11 +28,16 @@ query_is_ready, run_query, get_available_dates, +) +from .query_specs import ( + daily_location_spec, + modal_location_spec, + modal_location_from_dates_spec, radius_of_gyration_spec, unique_location_counts_spec, + topup_balance_spec, subscriber_degree_spec, topup_amount_spec, - topup_balance_spec, event_count_spec, displacement_spec, pareto_interactions_spec, @@ -37,7 +46,6 @@ random_sample_spec, unique_locations_spec, ) -from .api_query import APIQuery from . import aggregates from .aggregates import ( location_event_counts, @@ -61,11 +69,8 @@ __all__ = [ "aggregates", - "Connection", + "connect_async", "connect", - "daily_location_spec", - "modal_location_spec", - "modal_location_from_dates_spec", "get_geography", "get_result", "get_result_by_query_id", @@ -74,18 +79,8 @@ "query_is_ready", "run_query", "get_available_dates", - "radius_of_gyration_spec", - "unique_location_counts_spec", - "subscriber_degree_spec", - "topup_amount_spec", - "topup_balance_spec", - "event_count_spec", - "displacement_spec", - "pareto_interactions_spec", - "nocturnal_events_spec", - "handset_spec", - "random_sample_spec", "APIQuery", + "ASyncAPIQuery", "location_event_counts", "meaningful_locations_aggregate", "meaningful_locations_between_label_od_matrix", diff --git a/flowclient/flowclient/aggregates.py b/flowclient/flowclient/aggregates.py index fd5d148308..5875ddf6a7 100644 --- a/flowclient/flowclient/aggregates.py +++ b/flowclient/flowclient/aggregates.py @@ -7,7 +7,7 @@ from merge_args import merge_args -from flowclient.client import Connection +from flowclient import Connection from flowclient.api_query import APIQuery @@ -93,9 +93,7 @@ def location_event_counts(*, connection: Connection, **kwargs) -> APIQuery: APIQuery Location event counts query """ - return APIQuery( - connection=connection, parameters=location_event_counts_spec(**kwargs) - ) + return connection.make_api_query(parameters=location_event_counts_spec(**kwargs)) def meaningful_locations_aggregate_spec( @@ -266,8 +264,8 @@ def meaningful_locations_aggregate(*, connection: Connection, **kwargs) -> APIQu ----- Does not return any value below 15. """ - return APIQuery( - connection=connection, parameters=meaningful_locations_aggregate_spec(**kwargs) + return connection.make_api_query( + parameters=meaningful_locations_aggregate_spec(**kwargs) ) @@ -443,8 +441,7 @@ def meaningful_locations_between_label_od_matrix( .. [1] S. Isaacman et al., "Identifying Important Places in People's Lives from Cellular Network Data", International Conference on Pervasive Computing (2011), pp 133-151. .. [2] Zagatti, Guilherme Augusto, et al. "A trip to work: Estimation of origin and destination of commuting patterns in the main metropolitan regions of Haiti using CDR." Development Engineering 3 (2018): 133-165. """ - return APIQuery( - connection=connection, + return connection.make_api_query( parameters=meaningful_locations_between_label_od_matrix_spec(**kwargs), ) @@ -633,8 +630,7 @@ def meaningful_locations_between_dates_od_matrix( .. [1] S. Isaacman et al., "Identifying Important Places in People's Lives from Cellular Network Data", International Conference on Pervasive Computing (2011), pp 133-151. .. [2] Zagatti, Guilherme Augusto, et al. "A trip to work: Estimation of origin and destination of commuting patterns in the main metropolitan regions of Haiti using CDR." 
Development Engineering 3 (2018): 133-165. """ - return APIQuery( - connection=connection, + return connection.make_api_query( parameters=meaningful_locations_between_dates_od_matrix_spec(**kwargs), ) @@ -687,7 +683,7 @@ def flows(*, connection: Connection, **kwargs) -> APIQuery: Flows query """ - return APIQuery(connection=connection, parameters=flows_spec(**kwargs)) + return connection.make_api_query(parameters=flows_spec(**kwargs)) def unique_subscriber_counts_spec( @@ -739,9 +735,7 @@ def unique_subscriber_counts(*, connection: Connection, **kwargs) -> APIQuery: APIQuery Unique subscriber counts query """ - return APIQuery( - connection=connection, parameters=unique_subscriber_counts_spec(**kwargs) - ) + return connection.make_api_query(parameters=unique_subscriber_counts_spec(**kwargs)) def location_introversion_spec( @@ -798,9 +792,7 @@ def location_introversion(*, connection: Connection, **kwargs) -> APIQuery: APIQuery Location introversion query """ - return APIQuery( - connection=connection, parameters=location_introversion_spec(**kwargs) - ) + return connection.make_api_query(parameters=location_introversion_spec(**kwargs)) def total_network_objects_spec( @@ -857,9 +849,7 @@ def total_network_objects(*, connection: Connection, **kwargs) -> APIQuery: APIQuery Total network objects query """ - return APIQuery( - connection=connection, parameters=total_network_objects_spec(**kwargs) - ) + return connection.make_api_query(parameters=total_network_objects_spec(**kwargs)) def aggregate_network_objects_spec( @@ -913,8 +903,8 @@ def aggregate_network_objects(*, connection: Connection, **kwargs) -> APIQuery: APIQuery Aggregate network objects query """ - return APIQuery( - connection=connection, parameters=aggregate_network_objects_spec(**kwargs) + return connection.make_api_query( + parameters=aggregate_network_objects_spec(**kwargs) ) @@ -952,7 +942,7 @@ def spatial_aggregate(*, connection: Connection, **kwargs) -> APIQuery: APIQuery Spatial aggregate query """ - return APIQuery(connection=connection, parameters=spatial_aggregate_spec(**kwargs)) + return connection.make_api_query(parameters=spatial_aggregate_spec(**kwargs)) def consecutive_trips_od_matrix_spec( @@ -1014,8 +1004,8 @@ def consecutive_trips_od_matrix(*, connection: Connection, **kwargs) -> APIQuery APIQuery consecutive_trips_od_matrix query """ - return APIQuery( - connection=connection, parameters=consecutive_trips_od_matrix_spec(**kwargs), + return connection.make_api_query( + parameters=consecutive_trips_od_matrix_spec(**kwargs), ) @@ -1078,7 +1068,7 @@ def trips_od_matrix(*, connection: Connection, **kwargs) -> APIQuery: APIQuery trips_od_matrix query """ - return APIQuery(connection=connection, parameters=trips_od_matrix_spec(**kwargs),) + return connection.make_api_query(parameters=trips_od_matrix_spec(**kwargs),) def unmoving_counts_spec( @@ -1118,7 +1108,7 @@ def unmoving_counts(*, connection: Connection, **kwargs) -> APIQuery: APIQuery unmoving_counts query """ - return APIQuery(connection=connection, parameters=unmoving_counts_spec(**kwargs),) + return connection.make_api_query(parameters=unmoving_counts_spec(**kwargs),) def unmoving_at_reference_location_counts_spec( @@ -1170,8 +1160,7 @@ def unmoving_at_reference_location_counts( APIQuery unmoving_at_reference_location_counts query """ - return APIQuery( - connection=connection, + return connection.make_api_query( parameters=unmoving_at_reference_location_counts_spec(**kwargs), ) @@ -1225,8 +1214,7 @@ def active_at_reference_location_counts( APIQuery 
active_at_reference_location_counts query """ - return APIQuery( - connection=connection, + return connection.make_api_query( parameters=active_at_reference_location_counts_spec(**kwargs), ) @@ -1284,9 +1272,7 @@ def joined_spatial_aggregate(*, connection: Connection, **kwargs) -> APIQuery: APIQuery Joined spatial aggregate query """ - return APIQuery( - connection=connection, parameters=joined_spatial_aggregate_spec(**kwargs) - ) + return connection.make_api_query(parameters=joined_spatial_aggregate_spec(**kwargs)) def histogram_aggregate_spec( @@ -1344,6 +1330,4 @@ def histogram_aggregate(*, connection: Connection, **kwargs) -> APIQuery: APIQuery Histogram aggregate query """ - return APIQuery( - connection=connection, parameters=histogram_aggregate_spec(**kwargs) - ) + return connection.make_api_query(parameters=histogram_aggregate_spec(**kwargs)) diff --git a/flowclient/flowclient/api_query.py b/flowclient/flowclient/api_query.py index a85a788f08..c65721c17d 100644 --- a/flowclient/flowclient/api_query.py +++ b/flowclient/flowclient/api_query.py @@ -6,14 +6,13 @@ from typing import Union, Optional from flowclient.client import ( - FlowclientConnectionError, - Connection, run_query, get_status, get_result_by_query_id, get_geojson_result_by_query_id, wait_for_query_to_be_ready, ) +from flowclient.connection import Connection class APIQuery: diff --git a/flowclient/flowclient/async_api_query.py b/flowclient/flowclient/async_api_query.py new file mode 100644 index 0000000000..cfa027a0c8 --- /dev/null +++ b/flowclient/flowclient/async_api_query.py @@ -0,0 +1,163 @@ +# This Source Code Form is subject to the terms of the Mozilla Public +# License, v. 2.0. If a copy of the MPL was not distributed with this +# file, You can obtain one at http://mozilla.org/MPL/2.0/. + + +from typing import Union, Optional + +from flowclient.api_query import APIQuery +from flowclient.async_connection import ASyncConnection +from flowclient.async_client import ( + run_query, + get_status, + get_result_by_query_id, + get_geojson_result_by_query_id, + wait_for_query_to_be_ready, +) + + +class ASyncAPIQuery(APIQuery): + """ + Representation of a FlowKit query. + + Parameters + ---------- + connection : ASyncConnection + Connection to FlowKit server on which to run this query + parameters : dict + Parameters that specify the query + + Attributes + ---------- + parameters + connection + status + """ + + def __init__(self, *, connection: ASyncConnection, parameters: dict): + self._connection = connection + self.parameters = dict(parameters) + + async def run(self) -> None: + """ + Set this query running in FlowKit + + Raises + ------ + FlowclientConnectionError + if the query cannot be set running + """ + self._query_id = await run_query( + connection=self.connection, query_spec=self.parameters + ) + # TODO: Return a future? + + @property + def connection(self) -> ASyncConnection: + """ + Connection that is used for running this query. + + Returns + ------- + ASyncConnection + Connection to FlowKit API + """ + return self._connection + + @property + async def status(self) -> str: + """ + Status of this query. 
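Taken together with `make_api_query` on `ASyncConnection`, this gives a fully awaitable query lifecycle. A minimal usage sketch — the URL and token are placeholders, and the `spatial_aggregate` spec shape mirrors the synchronous spec builders:

```python
import asyncio

import flowclient


async def main():
    # connect_async is a coroutine, so it must be awaited
    conn = await flowclient.connect_async(
        url="https://localhost:9090", token="JWT_STRING"
    )
    # On an ASyncConnection, make_api_query returns an ASyncAPIQuery
    query = conn.make_api_query(
        parameters={
            "query_kind": "spatial_aggregate",
            "locations": flowclient.daily_location_spec(
                date="2016-01-01", aggregation_unit="admin3", method="last"
            ),
        }
    )
    await query.run()
    print(await query.status)  # e.g. "queued" or "executing"
    df = await query.get_result()  # polls until the query completes
    print(df.head())


asyncio.run(main())
```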
+ + Returns + ------- + str + One of: + - "not_running" + - "queued" + - "executing" + - "completed" + """ + if not hasattr(self, "_query_id"): + return "not_running" + return await get_status(connection=self.connection, query_id=self._query_id) + + async def get_result( + self, + format: str = "pandas", + poll_interval: int = 1, + disable_progress: Optional[bool] = None, + ) -> Union["pandas.DataFrame", dict]: + """ + Get the result of this query, as a pandas DataFrame or GeoJSON dict. + + Parameters + ---------- + format : str, default 'pandas' + Result format. One of {'pandas', 'geojson'} + poll_interval : int, default 1 + Number of seconds to wait between checks for the query being ready + disable_progress : bool, default None + Set to True to disable progress bar display entirely, None to disable on + non-TTY, or False to always enable + + Returns + ------- + pandas.DataFrame or dict + Query result + """ + if format == "pandas": + result_getter = get_result_by_query_id + elif format == "geojson": + result_getter = get_geojson_result_by_query_id + else: + raise ValueError( + f"Invalid format: '{format}'. Expected one of {{'pandas', 'geojson'}}." + ) + + # TODO: Cache result internally? + try: + return await result_getter( + connection=self.connection, + query_id=self._query_id, + poll_interval=poll_interval, + disable_progress=disable_progress, + ) + except (AttributeError, FileNotFoundError): + # TODO: Warn before running? + await self.run() + return await result_getter( + connection=self.connection, + query_id=self._query_id, + poll_interval=poll_interval, + disable_progress=disable_progress, + ) + + async def wait_until_ready( + self, poll_interval: int = 1, disable_progress: Optional[bool] = None + ) -> None: + """ + Wait until this query has finished running. + + Parameters + ---------- + poll_interval : int, default 1 + Number of seconds to wait between checks for the query being ready + disable_progress : bool, default None + Set to True to disable progress bar display entirely, None to disable on + non-TTY, or False to always enable + + + Raises + ------ + FileNotFoundError + if the query has not yet been run + FlowclientConnectionError + if the query has errored + """ + if not hasattr(self, "_query_id"): + raise FileNotFoundError("Query is not running.") + await wait_for_query_to_be_ready( + connection=self.connection, + query_id=self._query_id, + poll_interval=poll_interval, + disable_progress=disable_progress, + ) diff --git a/flowclient/flowclient/async_client.py b/flowclient/flowclient/async_client.py new file mode 100644 index 0000000000..22047fb1c8 --- /dev/null +++ b/flowclient/flowclient/async_client.py @@ -0,0 +1,540 @@ +# This Source Code Form is subject to the terms of the Mozilla Public +# License, v. 2.0. If a copy of the MPL was not distributed with this +# file, You can obtain one at http://mozilla.org/MPL/2.0/. + +import logging +import re +from asyncio import sleep + +import pandas as pd +import requests +from typing import Tuple, Union, List, Optional +from tqdm.auto import tqdm + + +import flowclient.errors +from flowclient.async_connection import ASyncConnection + +logger = logging.getLogger(__name__) + + +async def connect_async( + *, + url: str, + token: str, + api_version: int = 0, + ssl_certificate: Union[str, None] = None, +) -> ASyncConnection: + """ + Connect to a FlowKit API server and return the resulting ASyncConnection object. + + Parameters + ---------- + url : str + URL of the API server, e.g.
"https://localhost:9090" + token : str + JSON Web Token for this API server + api_version : int, default 0 + Version of the API to connect to + ssl_certificate: str or None + Provide a path to an ssl certificate to use, or None to use + default root certificates. + + Returns + ------- + ASyncConnection + """ + return ASyncConnection( + url=url, token=token, api_version=api_version, ssl_certificate=ssl_certificate + ) + + +async def query_is_ready( + *, connection: ASyncConnection, query_id: str +) -> Tuple[bool, requests.Response]: + """ + Check if a query id has results available. + + Parameters + ---------- + connection : ASyncConnection + API connection to use + query_id : str + Identifier of the query to retrieve + + Returns + ------- + Tuple[bool, requests.Response] + True if the query result is available + + Raises + ------ + FlowclientConnectionError + if query has errored + """ + logger.info( + f"Polling server on {connection.url}/api/{connection.api_version}/poll/{query_id}" + ) + reply = await connection.get_url(route=f"poll/{query_id}") + + if reply.status_code == 303: + logger.info( + f"{connection.url}/api/{connection.api_version}/poll/{query_id} ready." + ) + return True, reply # Query is ready, so exit the loop + elif reply.status_code == 202: + logger.info( + "{eligible} parts to run, {queued} in queue and {running} running.".format( + **reply.json()["progress"] + ) + ) + return False, reply + else: + raise flowclient.errors.FlowclientConnectionError( + f"Something went wrong: {reply}. API returned with status code: {reply.status_code}" + ) + + +async def get_status(*, connection: ASyncConnection, query_id: str) -> str: + """ + Check the status of a query. + + Parameters + ---------- + connection : ASyncConnection + API connection to use + query_id : str + Identifier of the query to retrieve + + Returns + ------- + str + Query status + + Raises + ------ + FlowclientConnectionError + if response does not contain the query status + """ + try: + ready, reply = await query_is_ready(connection=connection, query_id=query_id) + except FileNotFoundError: + # Can't distinguish 'known', 'cancelled', 'resetting' and 'awol' from the error, + # so return generic 'not_running' status. + return "not_running" + + if ready: + return "completed" + else: + try: + return reply.json()["status"] + except (KeyError, TypeError): + raise flowclient.errors.FlowclientConnectionError(f"No status reported.") + + +async def wait_for_query_to_be_ready( + *, + connection: ASyncConnection, + query_id: str, + poll_interval: int = 1, + disable_progress: Optional[bool] = None, +) -> requests.Response: + """ + Wait until a query id has finished running, and if it finished successfully + return the reply from flowapi. 
+ + Parameters + ---------- + connection : ASyncConnection + API connection to use + query_id : str + Identifier of the query to retrieve + poll_interval : int + Number of seconds to wait between checks for the query being ready + disable_progress : bool, default None + Set to True to disable progress bar display entirely, None to disable on + non-TTY, or False to always enable + + Returns + ------- + requests.Response + Response object containing the reply from flowapi + + Raises + ------ + FlowclientConnectionError + If the query has finished running unsuccessfully + """ + query_ready, reply = await query_is_ready( + connection=connection, query_id=query_id + ) # Poll the server + + if not query_ready: + progress = reply.json()["progress"] + total_eligible = progress["eligible"] + completed = 0 + with tqdm( + desc="Parts run", disable=disable_progress, unit="q", total=total_eligible + ) as total_bar: + while not query_ready: + logger.info("Waiting before polling again.") + await sleep( + poll_interval + ) # Wait for the poll interval, then check if the query is ready again + query_ready, reply = await query_is_ready( + connection=connection, query_id=query_id + ) # Poll the server + if query_ready: + break + else: + progress = reply.json()["progress"] + completion_change = ( + total_eligible - progress["eligible"] + ) - completed + completed += completion_change + + total_bar.update(completion_change) + + total_bar.update(total_eligible - completed) # Finish the bar + + return reply + + +async def get_result_location_from_id_when_ready( + *, + connection: ASyncConnection, + query_id: str, + poll_interval: int = 1, + disable_progress: Optional[bool] = None, +) -> str: + """ + Return, once ready, the location at which results of a query will be obtainable. + + Parameters + ---------- + connection : ASyncConnection + API connection to use + query_id : str + Identifier of the query to retrieve + poll_interval : int + Number of seconds to wait between checks for the query being ready + disable_progress : bool, default None + Set to True to disable progress bar display entirely, None to disable on + non-TTY, or False to always enable + + Returns + ------- + str + Endpoint to retrieve results from + + """ + reply = await wait_for_query_to_be_ready( + connection=connection, + query_id=query_id, + poll_interval=poll_interval, + disable_progress=disable_progress, + ) + + result_location = reply.headers[ + "Location" + ] # Location header includes the /api/<version>/ prefix + return re.sub( + "^/api/[0-9]+/", "", result_location + ) # Strip off the /api/<version>/ prefix + + +async def get_json_dataframe( + *, connection: ASyncConnection, location: str +) -> pd.DataFrame: + """ + Get a dataframe from a json source. + + Parameters + ---------- + connection : ASyncConnection + API connection to use + location : str + API endpoint to retrieve json from + + Returns + ------- + pandas.DataFrame + Dataframe containing the result + + """ + + response = await connection.get_url(route=location) + if response.status_code != 200: + try: + msg = response.json()["msg"] + more_info = f" Reason: {msg}" + except KeyError: + more_info = "" + raise flowclient.errors.FlowclientConnectionError( + f"Could not get result.
API returned with status code: {response.status_code}.{more_info}" + ) + result = response.json() + logger.info(f"Got {connection.url}/api/{connection.api_version}/{location}") + return pd.DataFrame.from_records(result["query_result"]) + + +async def get_geojson_result_by_query_id( + *, + connection: ASyncConnection, + query_id: str, + poll_interval: int = 1, + disable_progress: Optional[bool] = None, +) -> dict: + """ + Get a query by id, and return it as a geojson dict + + Parameters + ---------- + connection : ASyncConnection + API connection to use + query_id : str + Identifier of the query to retrieve + poll_interval : int + Number of seconds to wait between checks for the query being ready + disable_progress : bool, default None + Set to True to disable progress bar display entirely, None to disable on + non-TTY, or False to always enable + + Returns + ------- + dict + geojson + + """ + result_endpoint = await get_result_location_from_id_when_ready( + connection=connection, + query_id=query_id, + poll_interval=poll_interval, + disable_progress=disable_progress, + ) + response = await connection.get_url(route=f"{result_endpoint}.geojson") + if response.status_code != 200: + try: + msg = response.json()["msg"] + more_info = f" Reason: {msg}" + except KeyError: + more_info = "" + raise flowclient.errors.FlowclientConnectionError( + f"Could not get result. API returned with status code: {response.status_code}.{more_info}" + ) + return response.json() + + +async def get_result_by_query_id( + *, + connection: ASyncConnection, + query_id: str, + poll_interval: int = 1, + disable_progress: Optional[bool] = None, +) -> pd.DataFrame: + """ + Get a query by id, and return it as a dataframe + + Parameters + ---------- + connection : ASyncConnection + API connection to use + query_id : str + Identifier of the query to retrieve + poll_interval : int + Number of seconds to wait between checks for the query being ready + disable_progress : bool, default None + Set to True to disable progress bar display entirely, None to disable on + non-TTY, or False to always enable + + Returns + ------- + pandas.DataFrame + Dataframe containing the result + + """ + result_endpoint = await get_result_location_from_id_when_ready( + connection=connection, + query_id=query_id, + poll_interval=poll_interval, + disable_progress=disable_progress, + ) + return await get_json_dataframe(connection=connection, location=result_endpoint) + + +async def get_geojson_result( + *, + connection: ASyncConnection, + query_spec: dict, + disable_progress: Optional[bool] = None, +) -> dict: + """ + Run and retrieve a query of a specified kind with parameters. + + Parameters + ---------- + connection : ASyncConnection + API connection to use + query_spec : dict + A query specification to run, e.g. `{'query_kind': 'daily_location', 'date': '2016-01-01', 'aggregation_unit': 'admin3', 'method': 'last'}` + disable_progress : bool, default None + Set to True to disable progress bar display entirely, None to disable on + non-TTY, or False to always enable + + Returns + ------- + dict + Geojson + + """ + return await get_geojson_result_by_query_id( + connection=connection, + query_id=await run_query(connection=connection, query_spec=query_spec), + disable_progress=disable_progress, + ) + + +async def get_result( + *, + connection: ASyncConnection, + query_spec: dict, + disable_progress: Optional[bool] = None, +) -> pd.DataFrame: + """ + Run and retrieve a query of a specified kind with parameters.
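The `*_by_query_id` variants make it possible to reattach to a query started earlier (for example, in a previous session). A sketch, assuming `query_id` was returned by an earlier `run_query` call and the URL and token are placeholders:

```python
from flowclient.async_client import connect_async, get_result_by_query_id


async def fetch(query_id: str):
    conn = await connect_async(url="https://localhost:9090", token="JWT_STRING")
    # Polls every 5 seconds until the query is done, then downloads
    # the result as a pandas DataFrame
    return await get_result_by_query_id(
        connection=conn, query_id=query_id, poll_interval=5
    )
```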
+ + Parameters + ---------- + connection : ASyncConnection + API connection to use + query_spec : dict + A query specification to run, e.g. `{'query_kind': 'daily_location', 'date': '2016-01-01', 'aggregation_unit': 'admin3', 'method': 'last'}` + disable_progress : bool, default None + Set to True to disable progress bar display entirely, None to disable on + non-TTY, or False to always enable + + Returns + ------- + pd.DataFrame + Pandas dataframe containing the results + + """ + return await get_result_by_query_id( + connection=connection, + query_id=await run_query(connection=connection, query_spec=query_spec), + disable_progress=disable_progress, + ) + + +async def get_geography(*, connection: ASyncConnection, aggregation_unit: str) -> dict: + """ + Get geography data from the database. + + Parameters + ---------- + connection : ASyncConnection + API connection to use + aggregation_unit : str + aggregation unit, e.g. 'admin3' + + Returns + ------- + dict + geography data as a GeoJSON FeatureCollection + + """ + logger.info( + f"Getting {connection.url}/api/{connection.api_version}/geography/{aggregation_unit}" + ) + response = await connection.get_url(route=f"geography/{aggregation_unit}") + if response.status_code != 200: + try: + msg = response.json()["msg"] + more_info = f" Reason: {msg}" + except KeyError: + more_info = "" + raise flowclient.errors.FlowclientConnectionError( + f"Could not get result. API returned with status code: {response.status_code}.{more_info}" + ) + result = response.json() + logger.info( + f"Got {connection.url}/api/{connection.api_version}/geography/{aggregation_unit}" + ) + return result + + +async def get_available_dates( + *, connection: ASyncConnection, event_types: Union[None, List[str]] = None +) -> dict: + """ + Get available dates for different event types from the database. + + Parameters + ---------- + connection : ASyncConnection + API connection to use + event_types : list of str, optional + The event types for which to return available dates (for example: ["calls", "sms"]). + If None, return available dates for all available event types. + + Returns + ------- + dict + Available dates in the format {event_type: [list of dates]} + + """ + logger.info( + f"Getting {connection.url}/api/{connection.api_version}/available_dates" + ) + response = await connection.get_url(route="available_dates") + if response.status_code != 200: + try: + msg = response.json()["msg"] + more_info = f" Reason: {msg}" + except KeyError: + more_info = "" + raise flowclient.errors.FlowclientConnectionError( + f"Could not get available dates. API returned with status code: {response.status_code}.{more_info}" + ) + result = response.json()["available_dates"] + logger.info(f"Got {connection.url}/api/{connection.api_version}/available_dates") + if event_types is None: + return result + else: + return {k: v for k, v in result.items() if k in event_types} + + +async def run_query(*, connection: ASyncConnection, query_spec: dict) -> str: + """ + Run a query of a specified kind with parameters and get the identifier for it.
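The metadata helpers are awaitable in the same way as the result getters. A sketch combining the two introduced above, assuming `conn` is an `ASyncConnection`:

```python
from flowclient.async_client import get_available_dates, get_geography


async def describe(conn):
    # Dates with data available, restricted to call events
    dates = await get_available_dates(connection=conn, event_types=["calls"])
    # Admin-level-3 boundaries as a GeoJSON FeatureCollection
    boundaries = await get_geography(connection=conn, aggregation_unit="admin3")
    return dates, boundaries
```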
+ + Parameters + ---------- + connection : ASyncConnection + API connection to use + query_spec : dict + Query specification to run + + Returns + ------- + str + Identifier of the query + """ + logger.info( + f"Requesting run of {query_spec} at {connection.url}/api/{connection.api_version}" + ) + r = await connection.post_json(route="run", data=query_spec) + if r.status_code == 202: + query_id = r.headers["Location"].split("/").pop() + logger.info( + f"Accepted {query_spec} at {connection.url}/api/{connection.api_version} with id {query_id}" + ) + return query_id + else: + try: + error = r.json()["msg"] + except (ValueError, KeyError): + error = "Unknown error" + raise flowclient.errors.FlowclientConnectionError( + f"Error running the query: {error}. Status code: {r.status_code}." + ) diff --git a/flowclient/flowclient/async_connection.py b/flowclient/flowclient/async_connection.py new file mode 100644 index 0000000000..8fe91477f3 --- /dev/null +++ b/flowclient/flowclient/async_connection.py @@ -0,0 +1,89 @@ +# This Source Code Form is subject to the terms of the Mozilla Public +# License, v. 2.0. If a copy of the MPL was not distributed with this +# file, You can obtain one at http://mozilla.org/MPL/2.0/. + +from typing import Union + +import requests + +import flowclient.connection + + +class ASyncConnection(flowclient.connection.Connection): + """ + An asynchronous connection to a FlowKit API server. + + Attributes + ---------- + url : str + URL of the API server + token : str + JSON Web Token for this API server + api_version : int + Version of the API to connect to + user : str + Username of token + + Parameters + ---------- + url : str + URL of the API server, e.g. "https://localhost:9090" + token : str + JSON Web Token for this API server + api_version : int, default 0 + Version of the API to connect to + ssl_certificate: str or None + Provide a path to an ssl certificate to use, or None to use + default root certificates. + """ + + async def get_url( + self, *, route: str, data: Union[None, dict] = None + ) -> requests.Response: + """ + Attempt to get something from the API, and return the raw + response object if an error response wasn't received. + If an error response was received, raises an error. + + Parameters + ---------- + route : str + Path relative to API host to get + + data : dict, optional + JSON data to send in the request body (optional) + + Returns + ------- + requests.Response + + """ + # Delegates to the blocking implementation inherited from Connection; + # the coroutine wrapper keeps the async API uniform. + return super().get_url(route=route, data=data) + + async def post_json(self, *, route: str, data: dict) -> requests.Response: + """ + Attempt to post json to the API, and return the raw + response object if an error response wasn't received. + If an error response was received, raises an error. + + Parameters + ---------- + route : str + Path relative to API host to post_json to + data: dict + Dictionary of json-encodeable data to post_json + + Returns + ------- + requests.Response + + """ + # Delegates to the blocking implementation inherited from Connection + return super().post_json(route=route, data=data) + + def make_api_query(self, parameters: dict) -> "ASyncAPIQuery": + # Import here to avoid a circular import with async_api_query + from flowclient.async_api_query import ASyncAPIQuery + + return ASyncAPIQuery(connection=self, parameters=parameters) + + def __repr__(self): + return f"{super().__repr__()} (async)" diff --git a/flowclient/flowclient/client.py b/flowclient/flowclient/client.py index 0ca7c31378..42a763aa4b 100644 --- a/flowclient/flowclient/client.py +++ b/flowclient/flowclient/client.py @@ -3,214 +3,18 @@ # file, You can obtain one at http://mozilla.org/MPL/2.0/.
import logging -import warnings import re -from functools import partial -import jwt import pandas as pd import requests import time -from requests import ConnectionError -from typing import Tuple, Union, Dict, List, Optional +from typing import Tuple, Union, List, Optional from tqdm.auto import tqdm -logger = logging.getLogger(__name__) - - -class FlowclientConnectionError(Exception): - """ - Custom exception to indicate an error when connecting to a FlowKit API. - """ - - -class Connection: - """ - A connection to a FlowKit API server. - - Attributes - ---------- - url : str - URL of the API server - token : str - JSON Web Token for this API server - api_version : int - Version of the API to connect to - user : str - Username of token - - Parameters - ---------- - url : str - URL of the API server, e.g. "https://localhost:9090" - token : str - JSON Web Token for this API server - api_version : int, default 0 - Version of the API to connect to - ssl_certificate: str or None - Provide a path to an ssl certificate to use, or None to use - default root certificates. - """ - - url: str - token: str - user: str - api_version: int - - def __init__( - self, - *, - url: str, - token: str, - api_version: int = 0, - ssl_certificate: Union[str, None] = None, - ) -> None: - if not url.lower().startswith("https://"): - warnings.warn( - "Communications with this server are NOT SECURE.", stacklevel=2 - ) - self.url = url - self.api_version = api_version - self.session = requests.Session() - if ssl_certificate is not None: - self.session.verify = ssl_certificate - self.update_token(token=token) - - def update_token(self, token: str) -> None: - """ - Replace this connection's API token with a new one. - - Parameters - ---------- - token : str - JSON Web Token for this API server - """ - try: - self.user = jwt.decode(token, verify=False)["identity"] - except jwt.DecodeError: - raise FlowclientConnectionError(f"Unable to decode token: '{token}'") - except KeyError: - raise FlowclientConnectionError(f"Token does not contain user identity.") - self.token = token - self.session.headers["Authorization"] = f"Bearer {self.token}" - - def get_url( - self, *, route: str, data: Union[None, dict] = None - ) -> requests.Response: - """ - Attempt to get something from the API, and return the raw - response object if an error response wasn't received. - If an error response was received, raises an error. - - Parameters - ---------- - route : str - Path relative to API host to get - - data : dict, optional - JSON data to send in the request body (optional) - - Returns - ------- - requests.Response - - """ - logger.debug(f"Getting {self.url}/api/{self.api_version}/{route}") - try: - response = self.session.get( - f"{self.url}/api/{self.api_version}/{route}", - allow_redirects=False, - json=data, - ) - except ConnectionError as e: - error_msg = f"Unable to connect to FlowKit API at {self.url}: {e}" - logger.info(error_msg) - raise FlowclientConnectionError(error_msg) - if response.status_code in {202, 200, 303}: - return response - elif response.status_code == 404: - raise FileNotFoundError( - f"{self.url}/api/{self.api_version}/{route} not found." 
- ) - elif response.status_code in {401, 403}: - try: - error = response.json()["msg"] - except (ValueError, KeyError): - error = "Unknown access denied error" - raise FlowclientConnectionError(error) - else: - try: - error = response.json()["msg"] - except (ValueError, KeyError): - error = "Unknown error" - try: - status = response.json()["status"] - except (ValueError, KeyError): - status = "Unknown status" - raise FlowclientConnectionError( - f"Something went wrong: {error}. API returned with status code: {response.status_code} and status '{status}'" - ) +from flowclient.connection import Connection +from flowclient.errors import FlowclientConnectionError - def post_json(self, *, route: str, data: dict) -> requests.Response: - """ - Attempt to post json to the API, and return the raw - response object if an error response wasn't received. - If an error response was received, raises an error. - - Parameters - ---------- - route : str - Path relative to API host to post_json to - data: dict - Dictionary of json-encodeable data to post_json - - Returns - ------- - requests.Response - - """ - logger.debug(f"Posting {data} to {self.url}/api/{self.api_version}/{route}") - try: - response = self.session.post( - f"{self.url}/api/{self.api_version}/{route}", json=data - ) - except ConnectionError as e: - error_msg = f"Unable to connect to FlowKit API at {self.url}: {e}" - logger.info(error_msg) - raise FlowclientConnectionError(error_msg) - if response.status_code == 202: - return response - elif response.status_code == 404: - raise FileNotFoundError( - f"{self.url}/api/{self.api_version}/{route} not found." - ) - elif response.status_code in {401, 403}: - try: - error_msg = response.json()["msg"] - except ValueError: - error_msg = "Unknown access denied error" - raise FlowclientConnectionError(error_msg) - else: - try: - error_msg = response.json()["msg"] - try: - returned_payload = response.json()["payload"] - payload_info = ( - "" if not returned_payload else f" Payload: {returned_payload}" - ) - except KeyError: - payload_info = "" - except ValueError: - # Happens if the response body does not contain valid JSON - # (see http://docs.python-requests.org/en/master/api/#requests.Response.json) - error_msg = f"the response did not contain valid JSON" - payload_info = "" - raise FlowclientConnectionError( - f"Something went wrong. API returned with status code {response.status_code}. Error message: '{error_msg}'.{payload_info}" - ) - - def __repr__(self) -> str: - return f"{self.user}@{self.url} v{self.api_version}" +logger = logging.getLogger(__name__) def connect( @@ -731,591 +535,3 @@ def run_query(*, connection: Connection, query_spec: dict) -> str: raise FlowclientConnectionError( f"Error running the query: {error}. Status code: {r.status_code}." ) - - -def unique_locations_spec( - *, - start_date: str, - end_date: str, - aggregation_unit: str, - subscriber_subset: Union[dict, None] = None, -) -> dict: - """ - Subscriber level query which retrieves the unique set of locations visited by each subscriber - in the time period. - - Parameters - ---------- - start_date, end_date : str - ISO format dates between which to get unique locations, e.g. "2016-01-01" - aggregation_unit : str - Unit of aggregation, e.g. "admin3" - subscriber_subset : dict or None - Subset of subscribers to retrieve daily locations for. Must be None - (= all subscribers) or a dictionary with the specification of a - subset query. - - Returns - ------- - dict - Unique locations query specification. 
- - """ - return dict( - query_kind="unique_locations", - start_date=start_date, - end_date=end_date, - aggregation_unit=aggregation_unit, - subscriber_subset=subscriber_subset, - ) - - -def daily_location_spec( - *, - date: str, - aggregation_unit: str, - method: str, - subscriber_subset: Union[dict, None] = None, -) -> dict: - """ - Return query spec for a daily location query for a date and unit of aggregation. - Must be passed to `spatial_aggregate` to retrieve a result from the aggregates API. - - Parameters - ---------- - date : str - ISO format date to get the daily location for, e.g. "2016-01-01" - aggregation_unit : str - Unit of aggregation, e.g. "admin3" - method : str - Method to use for daily location, one of 'last' or 'most-common' - subscriber_subset : dict or None - Subset of subscribers to retrieve daily locations for. Must be None - (= all subscribers) or a dictionary with the specification of a - subset query. - - Returns - ------- - dict - Dict which functions as the query specification - - """ - return { - "query_kind": "daily_location", - "date": date, - "aggregation_unit": aggregation_unit, - "method": method, - "subscriber_subset": subscriber_subset, - } - - -def modal_location_spec( - *, locations: List[Dict[str, Union[str, Dict[str, str]]]] -) -> dict: - """ - Return query spec for a modal location query for a list of locations. - Must be passed to `spatial_aggregate` to retrieve a result from the aggregates API. - - Parameters - ---------- - locations : list of dicts - List of location query specifications - - - Returns - ------- - dict - Dict which functions as the query specification for the modal location - - """ - return { - "query_kind": "modal_location", - "locations": locations, - } - - -def modal_location_from_dates_spec( - *, - start_date: str, - end_date: str, - aggregation_unit: str, - method: str, - subscriber_subset: Union[dict, None] = None, -) -> dict: - """ - Return query spec for a modal location query for a date range and unit of aggregation. - Must be passed to `spatial_aggregate` to retrieve a result from the aggregates API. - - Parameters - ---------- - start_date : str - ISO format date that begins the period, e.g. "2016-01-01" - end_date : str - ISO format date for the day _after_ the final date of the period, e.g. "2016-01-08" - aggregation_unit : str - Unit of aggregation, e.g. "admin3" - method : str - Method to use for daily locations, one of 'last' or 'most-common' - subscriber_subset : dict or None - Subset of subscribers to retrieve modal locations for. Must be None - (= all subscribers) or a dictionary with the specification of a - subset query. - - Returns - ------- - dict - Dict which functions as the query specification - - """ - dates = [ - d.strftime("%Y-%m-%d") - for d in pd.date_range(start_date, end_date, freq="D", closed="left") - ] - daily_locations = [ - daily_location_spec( - date=date, - aggregation_unit=aggregation_unit, - method=method, - subscriber_subset=subscriber_subset, - ) - for date in dates - ] - return modal_location_spec(locations=daily_locations) - - -def radius_of_gyration_spec( - *, start_date: str, end_date: str, subscriber_subset: Union[dict, None] = None -) -> dict: - """ - Return query spec for radius of gyration - - Parameters - ---------- - start_date : str - ISO format date of the first day of the count, e.g. "2016-01-01" - end_date : str - ISO format date of the day _after_ the final date of the count, e.g. 
"2016-01-08" - subscriber_subset : dict or None, default None - Subset of subscribers to include in event counts. Must be None - (= all subscribers) or a dictionary with the specification of a - subset query. - - Returns - ------- - dict - Dict which functions as the query specification - """ - return { - "query_kind": "radius_of_gyration", - "start_date": start_date, - "end_date": end_date, - "subscriber_subset": subscriber_subset, - } - - -def unique_location_counts_spec( - *, - start_date: str, - end_date: str, - aggregation_unit: str, - subscriber_subset: Union[dict, None] = None, -) -> dict: - """ - Return query spec for unique location count - - Parameters - ---------- - start_date : str - ISO format date of the first day of the count, e.g. "2016-01-01" - end_date : str - ISO format date of the day _after_ the final date of the count, e.g. "2016-01-08" - aggregation_unit : str - Unit of aggregation, e.g. "admin3" - subscriber_subset : dict or None, default None - Subset of subscribers to include in event counts. Must be None - (= all subscribers) or a dictionary with the specification of a - subset query. - - Returns - ------- - dict - Dict which functions as the query specification - """ - return { - "query_kind": "unique_location_counts", - "start_date": start_date, - "end_date": end_date, - "aggregation_unit": aggregation_unit, - "subscriber_subset": subscriber_subset, - } - - -def topup_balance_spec( - *, - start_date: str, - end_date: str, - statistic: str, - subscriber_subset: Union[dict, None] = None, -) -> dict: - """ - Return query spec for top-up balance. - - Parameters - ---------- - start_date : str - ISO format date of the first day of the count, e.g. "2016-01-01" - end_date : str - ISO format date of the day _after_ the final date of the count, e.g. "2016-01-08" - statistic : {"avg", "max", "min", "median", "mode", "stddev", "variance"} - Statistic type one of "avg", "max", "min", "median", "mode", "stddev" or "variance". - subscriber_subset : dict or None, default None - Subset of subscribers to include in event counts. Must be None - (= all subscribers) or a dictionary with the specification of a - subset query. - - Returns - ------- - dict - Dict which functions as the query specification - """ - return { - "query_kind": "topup_balance", - "start_date": start_date, - "end_date": end_date, - "statistic": statistic, - "subscriber_subset": subscriber_subset, - } - - -def subscriber_degree_spec( - *, - start: str, - stop: str, - direction: str = "both", - subscriber_subset: Union[dict, None] = None, -) -> dict: - """ - Return query spec for subscriber degree - - Parameters - ---------- - start : str - ISO format date of the first day of the count, e.g. "2016-01-01" - stop : str - ISO format date of the day _after_ the final date of the count, e.g. "2016-01-08" - direction : {"in", "out", "both"}, default "both" - Optionally, include only ingoing or outbound calls/texts. Can be one of "in", "out" or "both". - subscriber_subset : dict or None, default None - Subset of subscribers to include in event counts. Must be None - (= all subscribers) or a dictionary with the specification of a - subset query. 
- - Returns - ------- - dict - Dict which functions as the query specification - """ - return { - "query_kind": "subscriber_degree", - "start": start, - "stop": stop, - "direction": direction, - "subscriber_subset": subscriber_subset, - } - - -def topup_amount_spec( - *, - start: str, - stop: str, - statistic: str, - subscriber_subset: Union[dict, None] = None, -) -> dict: - """ - Return query spec for topup amount - - Parameters - ---------- - start : str - ISO format date of the first day of the count, e.g. "2016-01-01" - stop : str - ISO format date of the day _after_ the final date of the count, e.g. "2016-01-08" - statistic : {"avg", "max", "min", "median", "mode", "stddev", "variance"} - Statistic type one of "avg", "max", "min", "median", "mode", "stddev" or "variance". - subscriber_subset : dict or None, default None - Subset of subscribers to include in event counts. Must be None - (= all subscribers) or a dictionary with the specification of a - subset query. - - Returns - ------- - dict - Dict which functions as the query specification - """ - return { - "query_kind": "topup_amount", - "start": start, - "stop": stop, - "statistic": statistic, - "subscriber_subset": subscriber_subset, - } - - -def event_count_spec( - *, - start: str, - stop: str, - direction: str = "both", - event_types: Optional[List[str]] = None, - subscriber_subset: Union[dict, None] = None, -) -> dict: - """ - Return query spec for event count - - Parameters - ---------- - start : str - ISO format date of the first day of the count, e.g. "2016-01-01" - stop : str - ISO format date of the day _after_ the final date of the count, e.g. "2016-01-08" - direction : {"in", "out", "both"}, default "both" - Optionally, include only ingoing or outbound calls/texts. Can be one of "in", "out" or "both". - event_types : list of str, optional - The event types to include in the count (for example: ["calls", "sms"]). - If None, include all event types in the count. - subscriber_subset : dict or None, default None - Subset of subscribers to include in event counts. Must be None - (= all subscribers) or a dictionary with the specification of a - subset query. - - Returns - ------- - dict - Dict which functions as the query specification - """ - return { - "query_kind": "event_count", - "start": start, - "stop": stop, - "direction": direction, - "event_types": event_types, - "subscriber_subset": subscriber_subset, - } - - -def displacement_spec( - *, - start: str, - stop: str, - statistic: str, - reference_location: Dict[str, str], - subscriber_subset: Union[dict, None] = None, -) -> dict: - """ - Return query spec for displacement - - Parameters - ---------- - start : str - ISO format date of the first day of the count, e.g. "2016-01-01" - stop : str - ISO format date of the day _after_ the final date of the count, e.g. "2016-01-08" - statistic : {"avg", "max", "min", "median", "mode", "stddev", "variance"} - Statistic type one of "avg", "max", "min", "median", "mode", "stddev" or "variance". - reference_location: - - subscriber_subset : dict or None, default None - Subset of subscribers to include in event counts. Must be None - (= all subscribers) or a dictionary with the specification of a - subset query. 
- - Returns - ------- - dict - Dict which functions as the query specification - """ - return { - "query_kind": "displacement", - "start": start, - "stop": stop, - "statistic": statistic, - "reference_location": reference_location, - "subscriber_subset": subscriber_subset, - } - - -def pareto_interactions_spec( - *, - start: str, - stop: str, - proportion: float, - subscriber_subset: Union[dict, None] = None, -) -> dict: - """ - Return query spec for pareto interactions - - Parameters - ---------- - start : str - ISO format date of the first day of the time interval to be considered, e.g. "2016-01-01" - stop : str - ISO format date of the day _after_ the final date of the time interval to be considered, e.g. "2016-01-08" - proportion : float - proportion to track below - subscriber_subset : dict or None, default None - Subset of subscribers to include in result. Must be None - (= all subscribers) or a dictionary with the specification of a - subset query. - - Returns - ------- - dict - Dict which functions as the query specification - """ - return { - "query_kind": "pareto_interactions", - "start": start, - "stop": stop, - "proportion": proportion, - "subscriber_subset": subscriber_subset, - } - - -def nocturnal_events_spec( - *, - start: str, - stop: str, - hours: Tuple[int, int], - subscriber_subset: Union[dict, None] = None, -) -> dict: - """ - Return query spec for nocturnal events - - Parameters - ---------- - start : str - ISO format date of the first day for which to count nocturnal events, e.g. "2016-01-01" - stop : str - ISO format date of the day _after_ the final date for which to count nocturnal events, e.g. "2016-01-08" - hours: tuple(int,int) - Tuple defining beginning and end of night - - subscriber_subset : dict or None, default None - Subset of subscribers to include in event counts. Must be None - (= all subscribers) or a dictionary with the specification of a - subset query. - - Returns - ------- - dict - Dict which functions as the query specification - """ - - return { - "query_kind": "nocturnal_events", - "start": start, - "stop": stop, - "night_start_hour": hours[0], - "night_end_hour": hours[1], - "subscriber_subset": subscriber_subset, - } - - -def handset_spec( - *, - start_date: str, - end_date: str, - characteristic: str = "hnd_type", - method: str = "last", - subscriber_subset: Union[dict, None] = None, -) -> dict: - """ - Return query spec for handset - - Parameters - ---------- - start : str - ISO format date of the first day for which to count handsets, e.g. "2016-01-01" - stop : str - ISO format date of the day _after_ the final date for which to count handsets, e.g. "2016-01-08" - characteristic: {"hnd_type", "brand", "model", "software_os_name", "software_os_vendor"}, default "hnd_type" - The required handset characteristic. - method: {"last", "most-common"}, default "last" - Method for choosing a handset to associate with subscriber. - subscriber_subset : dict or None, default None - Subset of subscribers to include in event counts. Must be None - (= all subscribers) or a dictionary with the specification of a - subset query. 
- - Returns - ------- - dict - Dict which functions as the query specification - """ - return { - "query_kind": "handset", - "start_date": start_date, - "end_date": end_date, - "characteristic": characteristic, - "method": method, - "subscriber_subset": subscriber_subset, - } - - -def random_sample_spec( - *, - query: Dict[str, Union[str, dict]], - seed: float, - sampling_method: str = "random_ids", - size: Union[int, None] = None, - fraction: Union[float, None] = None, - estimate_count: bool = True, -) -> dict: - """ - Return spec for a random sample from a query result. - - Parameters - ---------- - query : dict - Specification of the query to be sampled. - sampling_method : {'system', 'bernoulli', 'random_ids'}, default 'random_ids' - Specifies the method used to select the random sample. - 'system': performs block-level sampling by randomly sampling each - physical storage page for the underlying relation. This - sampling method is not guaranteed to generate a sample of the - specified size, but an approximation. This method may not - produce a sample at all, so it might be worth running it again - if it returns an empty dataframe. - 'bernoulli': samples directly on each row of the underlying - relation. This sampling method is slower and is not guaranteed to - generate a sample of the specified size, but an approximation - 'random_ids': samples rows by randomly sampling the row number. - size : int, optional - The number of rows to draw. - Exactly one of the 'size' or 'fraction' arguments must be provided. - fraction : float, optional - Fraction of rows to draw. - Exactly one of the 'size' or 'fraction' arguments must be provided. - estimate_count : bool, default True - Whether to estimate the number of rows in the table using - information contained in the `pg_class` or whether to perform an - actual count in the number of rows. - seed : float - A seed for repeatable random samples. - If using random_ids method, seed must be between -/+1. - - Returns - ------- - dict - Dict which functions as the query specification. - """ - sampled_query = dict(query) - sampling = dict( - seed=seed, - sampling_method=sampling_method, - size=size, - fraction=fraction, - estimate_count=estimate_count, - ) - sampled_query["sampling"] = sampling - return sampled_query diff --git a/flowclient/flowclient/connection.py b/flowclient/flowclient/connection.py new file mode 100644 index 0000000000..c83e0af1af --- /dev/null +++ b/flowclient/flowclient/connection.py @@ -0,0 +1,207 @@ +# This Source Code Form is subject to the terms of the Mozilla Public +# License, v. 2.0. If a copy of the MPL was not distributed with this +# file, You can obtain one at http://mozilla.org/MPL/2.0/. +import logging +import warnings +from typing import Union + +import jwt +import requests +from requests import ConnectionError +from flowclient.errors import FlowclientConnectionError + +logger = logging.getLogger(__name__) + + +class Connection: + """ + A connection to a FlowKit API server. + + Attributes + ---------- + url : str + URL of the API server + token : str + JSON Web Token for this API server + api_version : int + Version of the API to connect to + user : str + Username of token + + Parameters + ---------- + url : str + URL of the API server, e.g. "https://localhost:9090" + token : str + JSON Web Token for this API server + api_version : int, default 0 + Version of the API to connect to + ssl_certificate: str or None + Provide a path to an ssl certificate to use, or None to use + default root certificates. 
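For synchronous callers nothing changes: `connect` still returns a `Connection` with the same constructor arguments. A sketch of typical construction (URL, token, and certificate path are placeholders):

```python
import flowclient

conn = flowclient.connect(
    url="https://localhost:9090",
    token="JWT_STRING",
    ssl_certificate="/path/to/cert.pem",  # omit to use default root certificates
)
print(conn)  # e.g. "user@https://localhost:9090 v0"
```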
+ """ + + url: str + token: str + user: str + api_version: int + + def __init__( + self, + *, + url: str, + token: str, + api_version: int = 0, + ssl_certificate: Union[str, None] = None, + ) -> None: + if not url.lower().startswith("https://"): + warnings.warn( + "Communications with this server are NOT SECURE.", stacklevel=2 + ) + self.url = url + self.api_version = api_version + self.session = requests.Session() + if ssl_certificate is not None: + self.session.verify = ssl_certificate + self.update_token(token=token) + + def update_token(self, token: str) -> None: + """ + Replace this connection's API token with a new one. + + Parameters + ---------- + token : str + JSON Web Token for this API server + """ + try: + self.user = jwt.decode(token, verify=False)["identity"] + except jwt.DecodeError: + raise FlowclientConnectionError(f"Unable to decode token: '{token}'") + except KeyError: + raise FlowclientConnectionError(f"Token does not contain user identity.") + self.token = token + self.session.headers["Authorization"] = f"Bearer {self.token}" + + def get_url( + self, *, route: str, data: Union[None, dict] = None + ) -> requests.Response: + """ + Attempt to get something from the API, and return the raw + response object if an error response wasn't received. + If an error response was received, raises an error. + + Parameters + ---------- + route : str + Path relative to API host to get + + data : dict, optional + JSON data to send in the request body (optional) + + Returns + ------- + requests.Response + + """ + logger.debug(f"Getting {self.url}/api/{self.api_version}/{route}") + try: + response = self.session.get( + f"{self.url}/api/{self.api_version}/{route}", + allow_redirects=False, + json=data, + ) + except ConnectionError as e: + error_msg = f"Unable to connect to FlowKit API at {self.url}: {e}" + logger.info(error_msg) + raise FlowclientConnectionError(error_msg) + if response.status_code in {202, 200, 303}: + return response + elif response.status_code == 404: + raise FileNotFoundError( + f"{self.url}/api/{self.api_version}/{route} not found." + ) + elif response.status_code in {401, 403}: + try: + error = response.json()["msg"] + except (ValueError, KeyError): + error = "Unknown access denied error" + raise FlowclientConnectionError(error) + else: + try: + error = response.json()["msg"] + except (ValueError, KeyError): + error = "Unknown error" + try: + status = response.json()["status"] + except (ValueError, KeyError): + status = "Unknown status" + raise FlowclientConnectionError( + f"Something went wrong: {error}. API returned with status code: {response.status_code} and status '{status}'" + ) + + def post_json(self, *, route: str, data: dict) -> requests.Response: + """ + Attempt to post json to the API, and return the raw + response object if an error response wasn't received. + If an error response was received, raises an error. 
+ + Parameters + ---------- + route : str + Path relative to API host to post_json to + data: dict + Dictionary of json-encodeable data to post_json + + Returns + ------- + requests.Response + + """ + logger.debug(f"Posting {data} to {self.url}/api/{self.api_version}/{route}") + try: + response = self.session.post( + f"{self.url}/api/{self.api_version}/{route}", json=data + ) + except ConnectionError as e: + error_msg = f"Unable to connect to FlowKit API at {self.url}: {e}" + logger.info(error_msg) + raise FlowclientConnectionError(error_msg) + if response.status_code == 202: + return response + elif response.status_code == 404: + raise FileNotFoundError( + f"{self.url}/api/{self.api_version}/{route} not found." + ) + elif response.status_code in {401, 403}: + try: + error_msg = response.json()["msg"] + except ValueError: + error_msg = "Unknown access denied error" + raise FlowclientConnectionError(error_msg) + else: + try: + error_msg = response.json()["msg"] + try: + returned_payload = response.json()["payload"] + payload_info = ( + "" if not returned_payload else f" Payload: {returned_payload}" + ) + except KeyError: + payload_info = "" + except ValueError: + # Happens if the response body does not contain valid JSON + # (see http://docs.python-requests.org/en/master/api/#requests.Response.json) + error_msg = f"the response did not contain valid JSON" + payload_info = "" + raise FlowclientConnectionError( + f"Something went wrong. API returned with status code {response.status_code}. Error message: '{error_msg}'.{payload_info}" + ) + + def make_api_query(self, parameters: dict) -> "APIQuery": + from flowclient.api_query import APIQuery + + return APIQuery(connection=self, parameters=parameters) + + def __repr__(self) -> str: + return f"{self.user}@{self.url} v{self.api_version}" diff --git a/flowclient/flowclient/errors.py b/flowclient/flowclient/errors.py new file mode 100644 index 0000000000..e6cf72be32 --- /dev/null +++ b/flowclient/flowclient/errors.py @@ -0,0 +1,9 @@ +# This Source Code Form is subject to the terms of the Mozilla Public +# License, v. 2.0. If a copy of the MPL was not distributed with this +# file, You can obtain one at http://mozilla.org/MPL/2.0/. + + +class FlowclientConnectionError(Exception): + """ + Custom exception to indicate an error when connecting to a FlowKit API. + """ diff --git a/flowclient/flowclient/query_specs.py b/flowclient/flowclient/query_specs.py new file mode 100644 index 0000000000..d9d2f866c1 --- /dev/null +++ b/flowclient/flowclient/query_specs.py @@ -0,0 +1,595 @@ +# This Source Code Form is subject to the terms of the Mozilla Public +# License, v. 2.0. If a copy of the MPL was not distributed with this +# file, You can obtain one at http://mozilla.org/MPL/2.0/. + +from typing import Union, List, Dict, Optional, Tuple + +import pandas as pd + + +def unique_locations_spec( + *, + start_date: str, + end_date: str, + aggregation_unit: str, + subscriber_subset: Union[dict, None] = None, +) -> dict: + """ + Subscriber level query which retrieves the unique set of locations visited by each subscriber + in the time period. + + Parameters + ---------- + start_date, end_date : str + ISO format dates between which to get unique locations, e.g. "2016-01-01" + aggregation_unit : str + Unit of aggregation, e.g. "admin3" + subscriber_subset : dict or None + Subset of subscribers to retrieve daily locations for. Must be None + (= all subscribers) or a dictionary with the specification of a + subset query. 
+ + Returns + ------- + dict + Unique locations query specification. + + """ + return dict( + query_kind="unique_locations", + start_date=start_date, + end_date=end_date, + aggregation_unit=aggregation_unit, + subscriber_subset=subscriber_subset, + ) + + +def daily_location_spec( + *, + date: str, + aggregation_unit: str, + method: str, + subscriber_subset: Union[dict, None] = None, +) -> dict: + """ + Return query spec for a daily location query for a date and unit of aggregation. + Must be passed to `spatial_aggregate` to retrieve a result from the aggregates API. + + Parameters + ---------- + date : str + ISO format date to get the daily location for, e.g. "2016-01-01" + aggregation_unit : str + Unit of aggregation, e.g. "admin3" + method : str + Method to use for daily location, one of 'last' or 'most-common' + subscriber_subset : dict or None + Subset of subscribers to retrieve daily locations for. Must be None + (= all subscribers) or a dictionary with the specification of a + subset query. + + Returns + ------- + dict + Dict which functions as the query specification + + """ + return { + "query_kind": "daily_location", + "date": date, + "aggregation_unit": aggregation_unit, + "method": method, + "subscriber_subset": subscriber_subset, + } + + +def modal_location_spec( + *, locations: List[Dict[str, Union[str, Dict[str, str]]]] +) -> dict: + """ + Return query spec for a modal location query for a list of locations. + Must be passed to `spatial_aggregate` to retrieve a result from the aggregates API. + + Parameters + ---------- + locations : list of dicts + List of location query specifications + + + Returns + ------- + dict + Dict which functions as the query specification for the modal location + + """ + return { + "query_kind": "modal_location", + "locations": locations, + } + + +def modal_location_from_dates_spec( + *, + start_date: str, + end_date: str, + aggregation_unit: str, + method: str, + subscriber_subset: Union[dict, None] = None, +) -> dict: + """ + Return query spec for a modal location query for a date range and unit of aggregation. + Must be passed to `spatial_aggregate` to retrieve a result from the aggregates API. + + Parameters + ---------- + start_date : str + ISO format date that begins the period, e.g. "2016-01-01" + end_date : str + ISO format date for the day _after_ the final date of the period, e.g. "2016-01-08" + aggregation_unit : str + Unit of aggregation, e.g. "admin3" + method : str + Method to use for daily locations, one of 'last' or 'most-common' + subscriber_subset : dict or None + Subset of subscribers to retrieve modal locations for. Must be None + (= all subscribers) or a dictionary with the specification of a + subset query. + + Returns + ------- + dict + Dict which functions as the query specification + + """ + dates = [ + d.strftime("%Y-%m-%d") + for d in pd.date_range(start_date, end_date, freq="D", closed="left") + ] + daily_locations = [ + daily_location_spec( + date=date, + aggregation_unit=aggregation_unit, + method=method, + subscriber_subset=subscriber_subset, + ) + for date in dates + ] + return modal_location_spec(locations=daily_locations) + + +def radius_of_gyration_spec( + *, start_date: str, end_date: str, subscriber_subset: Union[dict, None] = None +) -> dict: + """ + Return query spec for radius of gyration + + Parameters + ---------- + start_date : str + ISO format date of the first day of the count, e.g. "2016-01-01" + end_date : str + ISO format date of the day _after_ the final date of the count, e.g. 
"2016-01-08" + subscriber_subset : dict or None, default None + Subset of subscribers to include in event counts. Must be None + (= all subscribers) or a dictionary with the specification of a + subset query. + + Returns + ------- + dict + Dict which functions as the query specification + """ + return { + "query_kind": "radius_of_gyration", + "start_date": start_date, + "end_date": end_date, + "subscriber_subset": subscriber_subset, + } + + +def unique_location_counts_spec( + *, + start_date: str, + end_date: str, + aggregation_unit: str, + subscriber_subset: Union[dict, None] = None, +) -> dict: + """ + Return query spec for unique location count + + Parameters + ---------- + start_date : str + ISO format date of the first day of the count, e.g. "2016-01-01" + end_date : str + ISO format date of the day _after_ the final date of the count, e.g. "2016-01-08" + aggregation_unit : str + Unit of aggregation, e.g. "admin3" + subscriber_subset : dict or None, default None + Subset of subscribers to include in event counts. Must be None + (= all subscribers) or a dictionary with the specification of a + subset query. + + Returns + ------- + dict + Dict which functions as the query specification + """ + return { + "query_kind": "unique_location_counts", + "start_date": start_date, + "end_date": end_date, + "aggregation_unit": aggregation_unit, + "subscriber_subset": subscriber_subset, + } + + +def topup_balance_spec( + *, + start_date: str, + end_date: str, + statistic: str, + subscriber_subset: Union[dict, None] = None, +) -> dict: + """ + Return query spec for top-up balance. + + Parameters + ---------- + start_date : str + ISO format date of the first day of the count, e.g. "2016-01-01" + end_date : str + ISO format date of the day _after_ the final date of the count, e.g. "2016-01-08" + statistic : {"avg", "max", "min", "median", "mode", "stddev", "variance"} + Statistic type one of "avg", "max", "min", "median", "mode", "stddev" or "variance". + subscriber_subset : dict or None, default None + Subset of subscribers to include in event counts. Must be None + (= all subscribers) or a dictionary with the specification of a + subset query. + + Returns + ------- + dict + Dict which functions as the query specification + """ + return { + "query_kind": "topup_balance", + "start_date": start_date, + "end_date": end_date, + "statistic": statistic, + "subscriber_subset": subscriber_subset, + } + + +def subscriber_degree_spec( + *, + start: str, + stop: str, + direction: str = "both", + subscriber_subset: Union[dict, None] = None, +) -> dict: + """ + Return query spec for subscriber degree + + Parameters + ---------- + start : str + ISO format date of the first day of the count, e.g. "2016-01-01" + stop : str + ISO format date of the day _after_ the final date of the count, e.g. "2016-01-08" + direction : {"in", "out", "both"}, default "both" + Optionally, include only ingoing or outbound calls/texts. Can be one of "in", "out" or "both". + subscriber_subset : dict or None, default None + Subset of subscribers to include in event counts. Must be None + (= all subscribers) or a dictionary with the specification of a + subset query. 
+ + Returns + ------- + dict + Dict which functions as the query specification + """ + return { + "query_kind": "subscriber_degree", + "start": start, + "stop": stop, + "direction": direction, + "subscriber_subset": subscriber_subset, + } + + +def topup_amount_spec( + *, + start: str, + stop: str, + statistic: str, + subscriber_subset: Union[dict, None] = None, +) -> dict: + """ + Return query spec for topup amount + + Parameters + ---------- + start : str + ISO format date of the first day of the count, e.g. "2016-01-01" + stop : str + ISO format date of the day _after_ the final date of the count, e.g. "2016-01-08" + statistic : {"avg", "max", "min", "median", "mode", "stddev", "variance"} + Statistic type one of "avg", "max", "min", "median", "mode", "stddev" or "variance". + subscriber_subset : dict or None, default None + Subset of subscribers to include in event counts. Must be None + (= all subscribers) or a dictionary with the specification of a + subset query. + + Returns + ------- + dict + Dict which functions as the query specification + """ + return { + "query_kind": "topup_amount", + "start": start, + "stop": stop, + "statistic": statistic, + "subscriber_subset": subscriber_subset, + } + + +def event_count_spec( + *, + start: str, + stop: str, + direction: str = "both", + event_types: Optional[List[str]] = None, + subscriber_subset: Union[dict, None] = None, +) -> dict: + """ + Return query spec for event count + + Parameters + ---------- + start : str + ISO format date of the first day of the count, e.g. "2016-01-01" + stop : str + ISO format date of the day _after_ the final date of the count, e.g. "2016-01-08" + direction : {"in", "out", "both"}, default "both" + Optionally, include only ingoing or outbound calls/texts. Can be one of "in", "out" or "both". + event_types : list of str, optional + The event types to include in the count (for example: ["calls", "sms"]). + If None, include all event types in the count. + subscriber_subset : dict or None, default None + Subset of subscribers to include in event counts. Must be None + (= all subscribers) or a dictionary with the specification of a + subset query. + + Returns + ------- + dict + Dict which functions as the query specification + """ + return { + "query_kind": "event_count", + "start": start, + "stop": stop, + "direction": direction, + "event_types": event_types, + "subscriber_subset": subscriber_subset, + } + + +def displacement_spec( + *, + start: str, + stop: str, + statistic: str, + reference_location: Dict[str, str], + subscriber_subset: Union[dict, None] = None, +) -> dict: + """ + Return query spec for displacement + + Parameters + ---------- + start : str + ISO format date of the first day of the count, e.g. "2016-01-01" + stop : str + ISO format date of the day _after_ the final date of the count, e.g. "2016-01-08" + statistic : {"avg", "max", "min", "median", "mode", "stddev", "variance"} + Statistic type one of "avg", "max", "min", "median", "mode", "stddev" or "variance". + reference_location: + + subscriber_subset : dict or None, default None + Subset of subscribers to include in event counts. Must be None + (= all subscribers) or a dictionary with the specification of a + subset query. 
+ + Returns + ------- + dict + Dict which functions as the query specification + """ + return { + "query_kind": "displacement", + "start": start, + "stop": stop, + "statistic": statistic, + "reference_location": reference_location, + "subscriber_subset": subscriber_subset, + } + + +def pareto_interactions_spec( + *, + start: str, + stop: str, + proportion: float, + subscriber_subset: Union[dict, None] = None, +) -> dict: + """ + Return query spec for pareto interactions + + Parameters + ---------- + start : str + ISO format date of the first day of the time interval to be considered, e.g. "2016-01-01" + stop : str + ISO format date of the day _after_ the final date of the time interval to be considered, e.g. "2016-01-08" + proportion : float + proportion to track below + subscriber_subset : dict or None, default None + Subset of subscribers to include in result. Must be None + (= all subscribers) or a dictionary with the specification of a + subset query. + + Returns + ------- + dict + Dict which functions as the query specification + """ + return { + "query_kind": "pareto_interactions", + "start": start, + "stop": stop, + "proportion": proportion, + "subscriber_subset": subscriber_subset, + } + + +def nocturnal_events_spec( + *, + start: str, + stop: str, + hours: Tuple[int, int], + subscriber_subset: Union[dict, None] = None, +) -> dict: + """ + Return query spec for nocturnal events + + Parameters + ---------- + start : str + ISO format date of the first day for which to count nocturnal events, e.g. "2016-01-01" + stop : str + ISO format date of the day _after_ the final date for which to count nocturnal events, e.g. "2016-01-08" + hours: tuple(int,int) + Tuple defining beginning and end of night + + subscriber_subset : dict or None, default None + Subset of subscribers to include in event counts. Must be None + (= all subscribers) or a dictionary with the specification of a + subset query. + + Returns + ------- + dict + Dict which functions as the query specification + """ + + return { + "query_kind": "nocturnal_events", + "start": start, + "stop": stop, + "night_start_hour": hours[0], + "night_end_hour": hours[1], + "subscriber_subset": subscriber_subset, + } + + +def handset_spec( + *, + start_date: str, + end_date: str, + characteristic: str = "hnd_type", + method: str = "last", + subscriber_subset: Union[dict, None] = None, +) -> dict: + """ + Return query spec for handset + + Parameters + ---------- + start : str + ISO format date of the first day for which to count handsets, e.g. "2016-01-01" + stop : str + ISO format date of the day _after_ the final date for which to count handsets, e.g. "2016-01-08" + characteristic: {"hnd_type", "brand", "model", "software_os_name", "software_os_vendor"}, default "hnd_type" + The required handset characteristic. + method: {"last", "most-common"}, default "last" + Method for choosing a handset to associate with subscriber. + subscriber_subset : dict or None, default None + Subset of subscribers to include in event counts. Must be None + (= all subscribers) or a dictionary with the specification of a + subset query. 
+ + Returns + ------- + dict + Dict which functions as the query specification + """ + return { + "query_kind": "handset", + "start_date": start_date, + "end_date": end_date, + "characteristic": characteristic, + "method": method, + "subscriber_subset": subscriber_subset, + } + + +def random_sample_spec( + *, + query: Dict[str, Union[str, dict]], + seed: float, + sampling_method: str = "random_ids", + size: Union[int, None] = None, + fraction: Union[float, None] = None, + estimate_count: bool = True, +) -> dict: + """ + Return spec for a random sample from a query result. + + Parameters + ---------- + query : dict + Specification of the query to be sampled. + sampling_method : {'system', 'bernoulli', 'random_ids'}, default 'random_ids' + Specifies the method used to select the random sample. + 'system': performs block-level sampling by randomly sampling each + physical storage page for the underlying relation. This + sampling method is not guaranteed to generate a sample of the + specified size, but an approximation. This method may not + produce a sample at all, so it might be worth running it again + if it returns an empty dataframe. + 'bernoulli': samples directly on each row of the underlying + relation. This sampling method is slower and is not guaranteed to + generate a sample of the specified size, but an approximation + 'random_ids': samples rows by randomly sampling the row number. + size : int, optional + The number of rows to draw. + Exactly one of the 'size' or 'fraction' arguments must be provided. + fraction : float, optional + Fraction of rows to draw. + Exactly one of the 'size' or 'fraction' arguments must be provided. + estimate_count : bool, default True + Whether to estimate the number of rows in the table using + information contained in the `pg_class` or whether to perform an + actual count in the number of rows. + seed : float + A seed for repeatable random samples. + If using random_ids method, seed must be between -/+1. + + Returns + ------- + dict + Dict which functions as the query specification. + """ + sampled_query = dict(query) + sampling = dict( + seed=seed, + sampling_method=sampling_method, + size=size, + fraction=fraction, + estimate_count=estimate_count, + ) + sampled_query["sampling"] = sampling + return sampled_query diff --git a/flowclient/setup.py b/flowclient/setup.py index 4ee4c44347..cc1440d03d 100644 --- a/flowclient/setup.py +++ b/flowclient/setup.py @@ -23,7 +23,7 @@ with open("README.md", "r") as fh: long_description = fh.read() -test_requirements = ["pytest", "pytest-cov"] +test_requirements = ["pytest>=5.4.0", "pytest-cov", "asynctest", "pytest-asyncio"] setup( name="flowclient", diff --git a/flowclient/test_results/pytest/results.xml b/flowclient/test_results/pytest/results.xml new file mode 100644 index 0000000000..00bbcdcdbd --- /dev/null +++ b/flowclient/test_results/pytest/results.xml @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/flowclient/tests/unit/test_async_api_query.py b/flowclient/tests/unit/test_async_api_query.py new file mode 100644 index 0000000000..cca8c3b59a --- /dev/null +++ b/flowclient/tests/unit/test_async_api_query.py @@ -0,0 +1,191 @@ +# This Source Code Form is subject to the terms of the Mozilla Public +# License, v. 2.0. If a copy of the MPL was not distributed with this +# file, You can obtain one at http://mozilla.org/MPL/2.0/. 
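+
+# The tests below exercise the asyncio flavour of the client against mocked
+# connections, so no FlowAPI server is required. For orientation, the usage
+# pattern under test is roughly (a sketch only; assumes a reachable FlowAPI
+# server and a valid JWT in TOKEN):
+#
+#     conn = flowclient.connect_async(url="https://localhost:9090", token=TOKEN)
+#     query = conn.make_api_query(parameters={"query_kind": "..."})
+#     await query.run()
+#     df = await query.get_result()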
+ +import pytest +from unittest.mock import Mock +from asynctest import Mock as AMock, CoroutineMock + +from flowclient.async_api_query import ASyncAPIQuery + + +@pytest.mark.asyncio +async def test_query_run(): + """ + Test that the 'run' method runs the query and records the query ID internally. + """ + connection_mock = AMock() + connection_mock.post_json = CoroutineMock( + return_value=Mock( + status_code=202, headers={"Location": "DUMMY_LOCATION/DUMMY_ID"} + ) + ) + query_spec = {"query_kind": "dummy_query"} + query = ASyncAPIQuery(connection=connection_mock, parameters=query_spec) + assert not hasattr(query, "_query_id") + await query.run() + connection_mock.post_json.assert_called_once_with(route="run", data=query_spec) + assert query._query_id == "DUMMY_ID" + + +def test_can_get_query_connection(): + """ + Test that 'connection' property returns the internal connection object + (e.g. so that token can be updated). + """ + connection_mock = Mock() + query = ASyncAPIQuery( + connection=connection_mock, parameters={"query_kind": "dummy_query"} + ) + assert query.connection is connection_mock + + +def test_cannot_replace_query_connection(): + """ + Test that 'connection' property does not allow setting a new connection + (which could invalidate internal state) + """ + query = ASyncAPIQuery(connection=Mock(), parameters={"query_kind": "dummy_query"}) + with pytest.raises(AttributeError, match="can't set attribute"): + query.connection = "NEW_CONNECTION" + + +@pytest.mark.asyncio +async def test_query_status(): + """ + Test that the 'status' property returns the status reported by the API. + """ + connection_mock = AMock() + connection_mock.post_json = CoroutineMock( + return_value=Mock( + status_code=202, headers={"Location": "DUMMY_LOCATION/DUMMY_ID"} + ) + ) + connection_mock.get_url = CoroutineMock( + return_value=Mock( + status_code=202, + json=Mock( + return_value={ + "status": "executing", + "progress": {"eligible": 0, "queued": 0, "running": 0}, + } + ), + ) + ) + query = ASyncAPIQuery( + connection=connection_mock, parameters={"query_kind": "dummy_query"} + ) + await query.run() + assert await query.status == "executing" + + +@pytest.mark.asyncio +async def test_query_status_not_running(): + """ + Test that the 'status' property returns 'not_running' if the query has not been set running. + """ + query = ASyncAPIQuery(connection=Mock(), parameters={"query_kind": "dummy_query"}) + assert await query.status == "not_running" + + +@pytest.mark.asyncio +async def test_wait_until_ready(monkeypatch): + """ + Test that wait_until_ready polls until query_is_ready returns True + """ + reply_mock = Mock( + json=Mock( + return_value={ + "status": "executing", + "progress": {"eligible": 0, "queued": 0, "running": 0}, + } + ) + ) + ready_mock = CoroutineMock(side_effect=[(False, reply_mock,), (True, reply_mock),]) + monkeypatch.setattr("flowclient.async_client.query_is_ready", ready_mock) + connection_mock = AMock() + connection_mock.post_json = CoroutineMock( + return_value=Mock( + status_code=202, headers={"Location": "DUMMY_LOCATION/DUMMY_ID"} + ) + ) + query = ASyncAPIQuery( + connection=connection_mock, parameters={"query_kind": "dummy_query"} + ) + await query.run() + await query.wait_until_ready() + + assert 2 == ready_mock.call_count + + +@pytest.mark.asyncio +async def test_wait_until_ready_raises(): + """ + Test that 'wait_until_ready' raises an error if the query has not been set running. 
+ """ + query = ASyncAPIQuery(connection=Mock(), parameters={"query_kind": "dummy_query"}) + with pytest.raises(FileNotFoundError): + await query.wait_until_ready() + + +@pytest.mark.parametrize( + "format,function", + [ + ("pandas", "get_result_by_query_id"), + ("geojson", "get_geojson_result_by_query_id"), + ], +) +@pytest.mark.asyncio +async def test_query_get_result_pandas(monkeypatch, format, function): + get_result_mock = CoroutineMock(return_value="DUMMY_RESULT") + monkeypatch.setattr(f"flowclient.async_api_query.{function}", get_result_mock) + connection_mock = AMock() + connection_mock.post_json = CoroutineMock( + return_value=Mock( + status_code=202, headers={"Location": "DUMMY_LOCATION/DUMMY_ID"} + ) + ) + query = ASyncAPIQuery( + connection=connection_mock, parameters={"query_kind": "dummy_query"} + ) + await query.run() + assert "DUMMY_RESULT" == await query.get_result(format=format, poll_interval=2) + get_result_mock.assert_called_once_with( + connection=connection_mock, + disable_progress=None, + query_id="DUMMY_ID", + poll_interval=2, + ) + + +@pytest.mark.asyncio +async def test_query_get_result_runs(monkeypatch): + """ + Test that get_result runs the query if it's not already running. + """ + get_result_mock = CoroutineMock(return_value="DUMMY_RESULT") + monkeypatch.setattr( + f"flowclient.async_api_query.get_result_by_query_id", get_result_mock + ) + connection_mock = AMock() + query_spec = {"query_kind": "dummy_query"} + connection_mock.post_json = CoroutineMock( + return_value=Mock( + status_code=202, headers={"Location": "DUMMY_LOCATION/DUMMY_ID"} + ) + ) + query = ASyncAPIQuery(connection=connection_mock, parameters=query_spec) + await query.get_result() + connection_mock.post_json.assert_called_once_with(route="run", data=query_spec) + + +@pytest.mark.asyncio +async def test_query_get_result_invalid_format(): + """ + Test that get_result raises an error for format other than 'pandas' or 'geojson'. + """ + query = ASyncAPIQuery( + connection="DUMMY_CONNECTION", parameters={"query_kind": "dummy_query"} + ) + with pytest.raises(ValueError): + await query.get_result(format="INVALID_FORMAT") diff --git a/flowclient/tests/unit/test_async_client.py b/flowclient/tests/unit/test_async_client.py new file mode 100644 index 0000000000..d4fc4b7d3d --- /dev/null +++ b/flowclient/tests/unit/test_async_client.py @@ -0,0 +1,373 @@ +# This Source Code Form is subject to the terms of the Mozilla Public +# License, v. 2.0. If a copy of the MPL was not distributed with this +# file, You can obtain one at http://mozilla.org/MPL/2.0/. +from unittest.mock import Mock + +from asynctest import Mock as AMock, CoroutineMock + +import pytest + +from flowclient.async_client import ( + query_is_ready, + get_geography, + get_status, + get_json_dataframe, + get_geojson_result_by_query_id, + run_query, + get_available_dates, + get_result_location_from_id_when_ready, + get_result, + get_geojson_result, +) +from flowclient.errors import FlowclientConnectionError + + +@pytest.mark.asyncio +async def test_query_ready_reports_false(): + """ Test that status code 202 is interpreted as query running. 
""" + con_mock = AMock() + con_mock.get_url = CoroutineMock( + return_value=AMock( + status_code=202, + json=Mock( + return_value={ + "status": "completed", + "progress": {"eligible": 0, "queued": 0, "running": 0}, + } + ), + ) + ) + is_ready, reply = await query_is_ready(connection=con_mock, query_id="foo") + assert not is_ready + + +@pytest.mark.asyncio +async def test_query_ready_reports_true(): + """ Test that status code 303 is interpreted as query ready. """ + con_mock = AMock() + con_mock.get_url = CoroutineMock(return_value=AMock(status_code=303)) + is_ready, reply = await query_is_ready(connection=con_mock, query_id="foo") + assert is_ready + + +@pytest.mark.asyncio +async def test_query_ready_raises(): + """ Test that status codes other than 202, 303, 401, and 404 raise a generic error. """ + con_mock = AMock() + con_mock.get_url = CoroutineMock(return_value=AMock(status_code=999)) + with pytest.raises(FlowclientConnectionError): + await query_is_ready(connection=con_mock, query_id="foo") + + +@pytest.mark.asyncio +async def test_run_query_raises(): + con_mock = AMock() + con_mock.post_json = CoroutineMock( + return_value=Mock( + status_code=500, json=Mock(return_value=dict(msg="DUMMY_ERROR")) + ) + ) + with pytest.raises( + FlowclientConnectionError, + match="Error running the query: DUMMY_ERROR. Status code: 500.", + ): + await run_query(connection=con_mock, query_spec="foo") + + +@pytest.mark.asyncio +async def test_run_query_raises_with_default_error(): + con_mock = AMock() + con_mock.post_json = CoroutineMock( + return_value=Mock(status_code=500, json=Mock(return_value=dict())) + ) + with pytest.raises( + FlowclientConnectionError, + match="Error running the query: Unknown error. Status code: 500.", + ): + await run_query(connection=con_mock, query_spec="foo") + + +@pytest.mark.parametrize("http_code", [401, 500]) +@pytest.mark.asyncio +async def test_available_dates_error(http_code): + """ + Any unexpected http code should raise an exception. + """ + connection_mock = AMock() + connection_mock.get_url = CoroutineMock( + return_value=Mock( + status_code=http_code, json=Mock(return_value=dict(msg="MESSAGE")) + ) + ) + with pytest.raises( + FlowclientConnectionError, + match=f"Could not get available dates. API returned with status code: {http_code}. Reason: MESSAGE", + ): + await get_available_dates(connection=connection_mock, event_types=["FOOBAR"]) + + +@pytest.mark.parametrize( + "arg, expected", [(None, dict(DUMMY=1, DUMMY_2=1)), (["DUMMY"], dict(DUMMY=1))] +) +@pytest.mark.asyncio +async def test_available_dates(arg, expected): + """ + Dates should be returned and filtered. + """ + connection_mock = AMock() + connection_mock.get_url = CoroutineMock( + return_value=Mock( + status_code=200, + json=Mock(return_value=dict(available_dates=dict(DUMMY=1, DUMMY_2=1))), + ) + ) + assert ( + await get_available_dates(connection=connection_mock, event_types=arg) + == expected + ) + + +@pytest.mark.asyncio +async def test_get_result_location_from_id_when_ready(): + """ + Any unexpected http code should raise an exception. + """ + connection_mock = AMock() + connection_mock.get_url = CoroutineMock( + return_value=Mock( + status_code=303, headers=dict(Location="/api/0/DUMMY_LOCATION") + ) + ) + assert ( + await get_result_location_from_id_when_ready( + connection=connection_mock, query_id="DUMMY_ID" + ) + == "DUMMY_LOCATION" + ) + + +@pytest.mark.asyncio +async def test_available_dates_error_with_no_info(): + """ + Any unexpected http code should raise an exception. 
+ """ + connection_mock = AMock() + connection_mock.get_url = CoroutineMock( + return_value=Mock(status_code=401, json=Mock(return_value=dict(msg="MESSAGE"))) + ) + with pytest.raises( + FlowclientConnectionError, + match=f"Could not get available dates. API returned with status code: 401.", + ): + await get_available_dates(connection=connection_mock, event_types=["FOOBAR"]) + + +@pytest.mark.asyncio +@pytest.mark.parametrize("running_status", ["queued", "executing"]) +async def test_get_status_reports_running(running_status): + """ Test that status code 202 is interpreted as query running or queued. """ + con_mock = AMock() + con_mock.get_url = CoroutineMock( + return_value=Mock( + status_code=202, + json=Mock( + return_value={ + "status": running_status, + "progress": {"eligible": 0, "queued": 0, "running": 0}, + } + ), + ) + ) + status = await get_status(connection=con_mock, query_id="foo") + assert status == running_status + + +@pytest.mark.asyncio +async def test_get_status_reports_finished(): + """ Test that status code 303 is interpreted as query finished. """ + con_mock = AMock() + con_mock.get_url = CoroutineMock(return_value=Mock(status_code=303)) + status = await get_status(connection=con_mock, query_id="foo") + assert status == "completed" + + +@pytest.mark.asyncio +async def test_get_status_404(): + """ Test that get_status reports that a query is not running. """ + con_mock = AMock() + con_mock.get_url = CoroutineMock(side_effect=FileNotFoundError("DUMMY_404")) + status_returned = await get_status(connection=con_mock, query_id="foo") + assert status_returned == "not_running" + + +@pytest.mark.asyncio +async def test_get_status_raises(): + """ Test that get_status raises an error for a status code other than 202, 303 or 404. """ + con_mock = AMock() + con_mock.get_url = CoroutineMock(return_value=Mock(status_code=500)) + with pytest.raises(FlowclientConnectionError): + await get_status(connection=con_mock, query_id="foo") + + +@pytest.mark.asyncio +async def test_get_json_dataframe(): + """ Test that get_json_dataframe returns results. """ + con_mock = AMock() + con_mock.get_url = CoroutineMock( + return_value=Mock( + status_code=200, json=Mock(return_value=dict(query_result=[{"0": 1}])) + ) + ) + assert ( + await get_json_dataframe(connection=con_mock, location="foo") + ).values.tolist() == [[1]] + + +@pytest.mark.asyncio +async def test_get_json_dataframe_raises(): + """ Test that get_json_dataframe raises an error. """ + con_mock = AMock() + con_mock.get_url = CoroutineMock( + return_value=Mock( + status_code=500, json=Mock(return_value=dict(msg="DUMMY_ERROR")) + ) + ) + with pytest.raises(FlowclientConnectionError, match=r".*Reason: DUMMY_ERROR"): + await get_json_dataframe(connection=con_mock, location="foo") + + +@pytest.mark.asyncio +async def test_get_geojson_result_by_query_id_raises(monkeypatch): + """ Test that get_geojson_result_by_query_id raises an error. """ + con_mock = AMock() + con_mock.get_url = CoroutineMock( + return_value=Mock( + status_code=500, json=Mock(return_value=dict(msg="DUMMY_ERROR")) + ) + ) + monkeypatch.setattr( + "flowclient.async_client.get_result_location_from_id_when_ready", + CoroutineMock(return_value="DUMMY"), + ) + with pytest.raises(FlowclientConnectionError, match=r".*Reason: DUMMY_ERROR"): + await get_geojson_result_by_query_id(connection=con_mock, query_id="foo") + + +@pytest.mark.asyncio +async def test_get_status_raises_without_status(): + """ Test that get_status raises an error if the status field is absent. 
""" + con_mock = AMock() + con_mock.get_url = CoroutineMock( + return_value=Mock( + status_code=202, + json=Mock( + return_value=dict(progress=dict(queued=0, running=0, eligible=0)) + ), + ) + ) + with pytest.raises(FlowclientConnectionError, match="No status reported"): + await get_status(connection=con_mock, query_id="foo") + + +@pytest.mark.asyncio +async def test_get_geography(token): + """ + Test that getting geography returns the returned dict + """ + connection_mock = AMock() + connection_mock.get_url = CoroutineMock( + return_value=Mock(status_code=200, json=Mock(return_value={"some": "json"})) + ) + gj = await get_geography( + connection=connection_mock, aggregation_unit="DUMMY_AGGREGATION" + ) + assert {"some": "json"} == gj + + +@pytest.mark.parametrize("http_code", [401, 404, 418, 400]) +@pytest.mark.asyncio +async def test_get_geography_error(http_code, token): + """ + Any unexpected http code should raise an exception. + """ + connection_mock = AMock() + connection_mock.get_url = CoroutineMock( + return_value=Mock( + status_code=http_code, json=Mock(return_value={"msg": "MESSAGE"}) + ) + ) + with pytest.raises( + FlowclientConnectionError, + match=f"Could not get result. API returned with status code: {http_code}. Reason: MESSAGE", + ): + await get_geography( + connection=connection_mock, aggregation_unit="DUMMY_AGGREGATION" + ) + + +@pytest.mark.asyncio +async def test_get_geography_no_msg_error(token): + """ + A response with an unexpected http code and no "msg" should raise a FlowclientConnectionError. + """ + connection_mock = AMock() + connection_mock.get_url = CoroutineMock( + return_value=Mock(status_code=404, json=Mock(return_value={})) + ) + with pytest.raises( + FlowclientConnectionError, + match=f"Could not get result. API returned with status code: 404.", + ): + await get_geography( + connection=connection_mock, aggregation_unit="DUMMY_AGGREGATION" + ) + + +@pytest.mark.asyncio +async def test_get_result(): + con_mock = AMock() + con_mock.post_json = CoroutineMock( + return_value=Mock( + status_code=202, + json=Mock(return_value=dict(msg="DUMMY_ERROR"),), + headers=dict(Location="DUMMY"), + ), + ) + con_mock.get_url = CoroutineMock( + side_effect=[ + Mock(status_code=303, headers=dict(Location="DUMMY"),), + Mock( + status_code=200, + headers=dict(Location="DUMMY"), + json=Mock(return_value=dict(query_result=[{"0": 1}])), + ), + ] + ) + assert ( + await get_result(connection=con_mock, query_spec="foo") + ).values.tolist() == [[1]] + + +@pytest.mark.asyncio +async def test_get_geojson_result(): + con_mock = AMock() + con_mock.post_json = CoroutineMock( + return_value=Mock( + status_code=202, + json=Mock(return_value=dict(msg="DUMMY_ERROR"),), + headers=dict(Location="DUMMY"), + ), + ) + con_mock.get_url = CoroutineMock( + side_effect=[ + Mock(status_code=303, headers=dict(Location="DUMMY"),), + Mock( + status_code=200, + headers=dict(Location="DUMMY"), + json=Mock(return_value=dict(query_result=[{"0": 1}])), + ), + ] + ) + assert (await get_geojson_result(connection=con_mock, query_spec="foo")) == { + "query_result": [{"0": 1}] + } diff --git a/flowclient/tests/unit/test_async_connection.py b/flowclient/tests/unit/test_async_connection.py new file mode 100644 index 0000000000..78ff8564fc --- /dev/null +++ b/flowclient/tests/unit/test_async_connection.py @@ -0,0 +1,34 @@ +# This Source Code Form is subject to the terms of the Mozilla Public +# License, v. 2.0. 
If a copy of the MPL was not distributed with this +# file, You can obtain one at http://mozilla.org/MPL/2.0/. + +from unittest.mock import Mock + +import pytest + +from flowclient import ASyncConnection + + +@pytest.mark.asyncio +async def test_get_url(monkeypatch, token): + monkeypatch.setattr( + "flowclient.connection.Connection.get_url", Mock(return_value="DUMMY_RETURN") + ) + con = ASyncConnection(url="DUMMY_URL", token=token) + assert await con.get_url(route="DUMMY_ROUTE", data="DUMMY_DATA") + + +@pytest.mark.asyncio +async def test_post_json(monkeypatch, token): + monkeypatch.setattr( + "flowclient.connection.Connection.post_json", Mock(return_value="DUMMY_RETURN") + ) + con = ASyncConnection(url="DUMMY_URL", token=token) + assert await con.post_json(route="DUMMY_ROUTE", data="DUMMY_DATA") + + +def test_make_query_object(monkeypatch, token): + con = ASyncConnection(url="DUMMY_URL", token=token) + dummy_params = dict(dummy_params="DUMMY_PARAMS") + assert con.make_api_query(dummy_params)._connection == con + assert con.make_api_query(dummy_params).parameters == dummy_params diff --git a/flowclient/tests/unit/test_connection.py b/flowclient/tests/unit/test_connection.py index 97a0d15049..ba94663558 100644 --- a/flowclient/tests/unit/test_connection.py +++ b/flowclient/tests/unit/test_connection.py @@ -6,7 +6,7 @@ import pytest import flowclient -from flowclient.client import FlowclientConnectionError +from flowclient.errors import FlowclientConnectionError def test_update_token(session_mock, token): diff --git a/flowclient/tests/unit/test_get_available_dates.py b/flowclient/tests/unit/test_get_available_dates.py index abdddb6bda..6236eebb54 100644 --- a/flowclient/tests/unit/test_get_available_dates.py +++ b/flowclient/tests/unit/test_get_available_dates.py @@ -5,11 +5,12 @@ from unittest.mock import Mock import pytest -from flowclient.client import FlowclientConnectionError, get_available_dates +from flowclient.client import get_available_dates +from flowclient.errors import FlowclientConnectionError @pytest.mark.parametrize("http_code", [401, 500]) -def test_get_geography_error(http_code): +def test_available_dates_error(http_code): """ Any unexpected http code should raise an exception. """ @@ -21,3 +22,17 @@ def test_get_geography_error(http_code): match=f"Could not get available dates. API returned with status code: {http_code}. Reason: MESSAGE", ): get_available_dates(connection=connection_mock, event_types=["FOOBAR"]) + + +def test_available_dates_error_with_no_info(): + """ + Any unexpected http code should raise an exception. + """ + connection_mock = Mock() + connection_mock.get_url.return_value.status_code = 401 + connection_mock.get_url.return_value.json.return_value = {} + with pytest.raises( + FlowclientConnectionError, + match=f"Could not get available dates. 
API returned with status code: 401.", + ): + get_available_dates(connection=connection_mock, event_types=["FOOBAR"]) diff --git a/flowclient/tests/unit/test_get_geography.py b/flowclient/tests/unit/test_get_geography.py index ba6d613f22..b9c0f209e6 100644 --- a/flowclient/tests/unit/test_get_geography.py +++ b/flowclient/tests/unit/test_get_geography.py @@ -6,7 +6,8 @@ import pytest import flowclient -from flowclient.client import FlowclientConnectionError, get_geography +from flowclient.client import get_geography +from flowclient.errors import FlowclientConnectionError def test_get_geography(token): diff --git a/flowclient/tests/unit/test_get_status.py b/flowclient/tests/unit/test_get_status.py index 817ab0bfe2..e341d799e1 100644 --- a/flowclient/tests/unit/test_get_status.py +++ b/flowclient/tests/unit/test_get_status.py @@ -6,7 +6,8 @@ import pytest -from flowclient.client import get_status, FlowclientConnectionError +from flowclient.client import get_status +from flowclient.errors import FlowclientConnectionError @pytest.mark.parametrize("running_status", ["queued", "executing"]) diff --git a/flowclient/tests/unit/test_get_url.py b/flowclient/tests/unit/test_get_url.py index 1cbf410ac1..9641b39712 100644 --- a/flowclient/tests/unit/test_get_url.py +++ b/flowclient/tests/unit/test_get_url.py @@ -4,7 +4,7 @@ import pytest from requests import ConnectionError import flowclient -from flowclient.client import FlowclientConnectionError +from flowclient.errors import FlowclientConnectionError @pytest.mark.parametrize("status_code", [200, 202, 303]) diff --git a/flowclient/tests/unit/test_login.py b/flowclient/tests/unit/test_login.py index 8f03392529..368205e14f 100644 --- a/flowclient/tests/unit/test_login.py +++ b/flowclient/tests/unit/test_login.py @@ -6,8 +6,8 @@ import jwt import pytest -from flowclient.client import Connection, FlowclientConnectionError - +from flowclient.errors import FlowclientConnectionError +from flowclient import Connection pytestmark = pytest.mark.usefixtures("session_mock") diff --git a/flowclient/tests/unit/test_post_json.py b/flowclient/tests/unit/test_post_json.py index 7f7f937696..7e5dc993af 100644 --- a/flowclient/tests/unit/test_post_json.py +++ b/flowclient/tests/unit/test_post_json.py @@ -6,7 +6,7 @@ from requests import ConnectionError import flowclient -from flowclient.client import FlowclientConnectionError +from flowclient.errors import FlowclientConnectionError from .zmq_helpers import ZMQReply diff --git a/flowclient/tests/unit/test_queries.py b/flowclient/tests/unit/test_queries.py index 7c1bf7cfed..99e4412bdf 100644 --- a/flowclient/tests/unit/test_queries.py +++ b/flowclient/tests/unit/test_queries.py @@ -8,12 +8,12 @@ import flowclient import flowclient.client from flowclient.client import ( - Connection, - FlowclientConnectionError, get_result_by_query_id, get_result, query_is_ready, ) +from flowclient.errors import FlowclientConnectionError +from flowclient import Connection def test_get_result_by_params(monkeypatch, token): diff --git a/flowclient/tests/unit/test_query_ready.py b/flowclient/tests/unit/test_query_ready.py index 7148e9458a..c234ade0ea 100644 --- a/flowclient/tests/unit/test_query_ready.py +++ b/flowclient/tests/unit/test_query_ready.py @@ -7,7 +7,7 @@ import pytest from flowclient.client import query_is_ready -from flowclient.client import FlowclientConnectionError +from flowclient.errors import FlowclientConnectionError def test_query_ready_reports_false(): diff --git a/integration_tests/tests/query_tests/test_queries.py 
b/integration_tests/tests/query_tests/test_queries.py index fb198bca5e..0f334abc69 100644 --- a/integration_tests/tests/query_tests/test_queries.py +++ b/integration_tests/tests/query_tests/test_queries.py @@ -5,11 +5,14 @@ import geojson import flowclient -from flowclient.client import get_result import pytest +@pytest.mark.asyncio +@pytest.mark.parametrize( + "connection", [flowclient.Connection, flowclient.ASyncConnection] +) @pytest.mark.parametrize( "query", [ @@ -544,20 +547,27 @@ ], ids=lambda val: val.func.__name__, ) -def test_run_query(query, universal_access_token, flowapi_url): +async def test_run_query(connection, query, universal_access_token, flowapi_url): """ Test that queries can be run, and return a QueryResult object. """ - con = flowclient.Connection(url=flowapi_url, token=universal_access_token) + con = connection(url=flowapi_url, token=universal_access_token) - query(connection=con).get_result() + try: + await query(connection=con).get_result() + except TypeError: + query(connection=con).get_result() # Ideally we'd check the contents, but several queries will be totally redacted and therefore empty # so we can only check it runs without erroring -def test_geo_result(universal_access_token, flowapi_url): +@pytest.mark.asyncio +@pytest.mark.parametrize( + "connection", [flowclient.Connection, flowclient.ASyncConnection] +) +async def test_geo_result(connection, universal_access_token, flowapi_url): query = flowclient.joined_spatial_aggregate( - connection=flowclient.Connection(url=flowapi_url, token=universal_access_token), + connection=connection(url=flowapi_url, token=universal_access_token), **{ "locations": flowclient.daily_location_spec( date="2016-01-01", aggregation_unit="admin3", method="last" @@ -572,75 +582,83 @@ def test_geo_result(universal_access_token, flowapi_url): } ) - result = query.get_result(format="geojson") + try: + result = await query.get_result(format="geojson") + except TypeError: + result = query.get_result(format="geojson") assert geojson.GeoJSON(result).is_valid +@pytest.mark.asyncio +@pytest.mark.parametrize( + "connection", [flowclient.Connection, flowclient.ASyncConnection] +) @pytest.mark.parametrize( - "query_kind, params", + "query", [ - ( + partial( flowclient.joined_spatial_aggregate, - { - "locations": { - "query_kind": "daily_location", - "date": "2016-01-01", - "aggregation_unit": "admin3", - "method": "last", - }, - "metric": { - "query_kind": "topup_balance", - "start_date": "2016-01-01", - "end_date": "2016-01-02", - "statistic": "avg", - }, - "method": "distr", - }, + locations=flowclient.daily_location_spec( + date="2016-01-01", aggregation_unit="admin3", method="last", + ), + metric=flowclient.topup_balance_spec( + start_date="2016-01-01", end_date="2016-01-02", statistic="avg", + ), + method="distr", ), - ( + partial( flowclient.joined_spatial_aggregate, - { - "locations": { - "query_kind": "daily_location", - "date": "2016-01-01", - "aggregation_unit": "admin3", - "method": "last", - }, - "metric": { - "query_kind": "handset", - "start_date": "2016-01-01", - "end_date": "2016-01-02", - "characteristic": "hnd_type", - "method": "last", - }, - "method": "avg", - }, + locations=flowclient.daily_location_spec( + date="2016-01-01", aggregation_unit="admin3", method="last", + ), + metric=flowclient.handset_spec( + start_date="2016-01-01", + end_date="2016-01-02", + characteristic="hnd_type", + method="last", + ), + method="avg", ), ], ) -def test_fail_query_incorrect_parameters( - query_kind, params, universal_access_token, 
flowapi_url +async def test_fail_query_incorrect_parameters( + connection, query, universal_access_token, flowapi_url ): """ Test that queries fail with incorrect parameters. """ - con = flowclient.Connection(url=flowapi_url, token=universal_access_token) - query = query_kind(connection=con, **params) + con = connection(url=flowapi_url, token=universal_access_token) with pytest.raises( flowclient.client.FlowclientConnectionError, match="Must be one of:" ): - result_dataframe = query.get_result() + try: + await query(connection=con).get_result() + except TypeError: + query(connection=con).get_result() -def test_get_geography(access_token_builder, flowapi_url): +@pytest.mark.asyncio +@pytest.mark.parametrize( + "connection, module", + [ + (flowclient.Connection, flowclient), + (flowclient.ASyncConnection, flowclient.async_client), + ], +) +async def test_get_geography(connection, module, access_token_builder, flowapi_url): """ Test that queries can be run, and return a GeoJSON dict. """ - con = flowclient.Connection( + con = connection( url=flowapi_url, token=access_token_builder(["get_result&geography.aggregation_unit.admin3"]), ) - result_geojson = flowclient.get_geography(connection=con, aggregation_unit="admin3") + try: + result_geojson = await module.get_geography( + connection=con, aggregation_unit="admin3" + ) + except TypeError: + result_geojson = module.get_geography(connection=con, aggregation_unit="admin3") assert "FeatureCollection" == result_geojson["type"] assert 0 < len(result_geojson["features"]) feature0 = result_geojson["features"][0] @@ -651,6 +669,14 @@ def test_get_geography(access_token_builder, flowapi_url): assert 0 < len(feature0["geometry"]["coordinates"]) +@pytest.mark.asyncio +@pytest.mark.parametrize( + "connection, module", + [ + (flowclient.Connection, flowclient), + (flowclient.ASyncConnection, flowclient.async_client), + ], +) @pytest.mark.parametrize( "event_types, expected_result", [ @@ -711,14 +737,19 @@ def test_get_geography(access_token_builder, flowapi_url): ), ], ) -def test_get_available_dates( - event_types, expected_result, access_token_builder, flowapi_url +async def test_get_available_dates( + connection, module, event_types, expected_result, access_token_builder, flowapi_url ): """ Test that queries can be run, and return the expected JSON result. """ - con = flowclient.Connection( + con = connection( url=flowapi_url, token=access_token_builder(["get_result&available_dates"]), ) - result = flowclient.get_available_dates(connection=con, event_types=event_types) + try: + result = await module.get_available_dates( + connection=con, event_types=event_types + ) + except TypeError: + result = module.get_available_dates(connection=con, event_types=event_types) assert expected_result == result