diff --git a/CHANGELOG.md b/CHANGELOG.md index 12b865a0b6..475648f22c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/). - FlowAuth now makes version information available at `/version` and displays it in the web ui. [#835](https://github.com/Flowminder/FlowKit/issues/835) - FlowETL now comes with a deployment example (in `flowetl/deployment_example/`). [#1126](https://github.com/Flowminder/FlowKit/issues/1126) - FlowETL now allows to run supplementary post-ETL queries. [#989](https://github.com/Flowminder/FlowKit/issues/989) +- Random sampling is now exposed via the API, for all non-aggregated query kinds. [#1007](https://github.com/Flowminder/FlowKit/issues/1007) ### Changed - FlowDB is now based on PostgreSQL 11.5 and PostGIS 2.5.3 diff --git a/flowclient/flowclient/__init__.py b/flowclient/flowclient/__init__.py index 8fbbeef693..28d2e7c407 100644 --- a/flowclient/flowclient/__init__.py +++ b/flowclient/flowclient/__init__.py @@ -44,6 +44,7 @@ pareto_interactions, nocturnal_events, handset, + random_sample, ) __all__ = [ @@ -80,4 +81,5 @@ "pareto_interactions", "nocturnal_events", "handset", + "random_sample", ] diff --git a/flowclient/flowclient/client.py b/flowclient/flowclient/client.py index e3f0927555..39dd53709d 100644 --- a/flowclient/flowclient/client.py +++ b/flowclient/flowclient/client.py @@ -1100,7 +1100,7 @@ def location_introversion( Unit of aggregation, e.g. "admin3" direction : {"in", "out", "both"}, default "both" Optionally, include only ingoing or outbound calls/texts can be one of "in", "out" or "both" -> + Returns ------- dict @@ -1131,6 +1131,7 @@ def total_network_objects( Unit of aggregation, e.g. 
"admin3" total_by : {"second", "minute", "hour", "day", "month", "year"} Time period to bucket by one of "second", "minute", "hour", "day", "month" or "year" + Returns ------- dict @@ -1276,6 +1277,7 @@ def unique_location_counts( Subset of subscribers to include in event counts. Must be None (= all subscribers) or a dictionary with the specification of a subset query. + Returns ------- dict @@ -1349,6 +1351,7 @@ def subscriber_degree( Subset of subscribers to include in event counts. Must be None (= all subscribers) or a dictionary with the specification of a subset query. + Returns ------- dict @@ -1385,6 +1388,7 @@ def topup_amount( Subset of subscribers to include in event counts. Must be None (= all subscribers) or a dictionary with the specification of a subset query. + Returns ------- dict @@ -1425,6 +1429,7 @@ def event_count( Subset of subscribers to include in event counts. Must be None (= all subscribers) or a dictionary with the specification of a subset query. + Returns ------- dict @@ -1465,6 +1470,7 @@ def displacement( Subset of subscribers to include in event counts. Must be None (= all subscribers) or a dictionary with the specification of a subset query. + Returns ------- dict @@ -1502,6 +1508,7 @@ def pareto_interactions( Subset of subscribers to include in result. Must be None (= all subscribers) or a dictionary with the specification of a subset query. + Returns ------- dict @@ -1520,7 +1527,7 @@ def nocturnal_events( *, start: str, stop: str, - hours: tuple((int, int)), + hours: Tuple[int, int], subscriber_subset: Union[dict, None] = None, ) -> dict: """ @@ -1538,6 +1545,7 @@ def nocturnal_events( Subset of subscribers to include in event counts. Must be None (= all subscribers) or a dictionary with the specification of a subset query. 
+ Returns ------- dict @@ -1557,14 +1565,8 @@ def handset( *, start_date: str, end_date: str, - characteristic: str = [ - "hnd_type", - "brand", - "model", - "software_os_name", - "software_os_vendor", - ], - method: str = ["last", "most-common"], + characteristic: str = "hnd_type", + method: str = "last", subscriber_subset: Union[dict, None] = None, ) -> dict: """ @@ -1576,14 +1578,15 @@ def handset( ISO format date of the first day for which to count handsets, e.g. "2016-01-01" stop : str ISO format date of the day _after_ the final date for which to count handsets, e.g. "2016-01-08" - characteristic: ["hnd_type", "brand", "model", "software_os_name", "software_os_vendor"], default "hnd_type" + characteristic: {"hnd_type", "brand", "model", "software_os_name", "software_os_vendor"}, default "hnd_type" The required handset characteristic. - method: ["last", "most-common"], default "last" + method: {"last", "most-common"}, default "last" Method for choosing a handset to associate with subscriber. subscriber_subset : dict or None, default None Subset of subscribers to include in event counts. Must be None (= all subscribers) or a dictionary with the specification of a subset query. + Returns ------- dict @@ -1597,3 +1600,69 @@ def handset( "method": method, "subscriber_subset": subscriber_subset, } + + +def random_sample( + *, + query: Dict[str, Union[str, dict]], + sampling_method: str = "system_rows", + size: Union[int, None] = None, + fraction: Union[float, None] = None, + estimate_count: bool = True, + seed: Union[float, None] = None, +) -> dict: + """ + Return spec for a random sample from a query result. + + Parameters + ---------- + query : dict + Specification of the query to be sampled. + sampling_method : {'system_rows', 'system', 'bernoulli', 'random_ids'}, default 'system_rows' + Specifies the method used to select the random sample. 
+ 'system_rows': performs block-level sampling by randomly sampling + each physical storage page of the underlying relation. This + sampling method is guaranteed to provide a sample of the specified + size + 'system': performs block-level sampling by randomly sampling each + physical storage page for the underlying relation. This + sampling method is not guaranteed to generate a sample of the + specified size, but an approximation. This method may not + produce a sample at all, so it might be worth running it again + if it returns an empty dataframe. + 'bernoulli': samples directly on each row of the underlying + relation. This sampling method is slower and is not guaranteed to + generate a sample of the specified size, but an approximation + 'random_ids': samples rows by randomly sampling the row number. + size : int, optional + The number of rows to draw. + Exactly one of the 'size' or 'fraction' arguments must be provided. + fraction : float, optional + Fraction of rows to draw. + Exactly one of the 'size' or 'fraction' arguments must be provided. + estimate_count : bool, default True + Whether to estimate the number of rows in the table using + information contained in the `pg_class` or whether to perform an + actual count in the number of rows. + seed : float, optional + Optionally provide a seed for repeatable random samples. + If using random_ids method, seed must be between -/+1. + Not available in combination with the system_rows method. + + Returns + ------- + dict + Dict which functions as the query specification. 
+ """ + sampled_query = dict(query) + sampling = { + "sampling_method": sampling_method, + "size": size, + "fraction": fraction, + "estimate_count": estimate_count, + } + if seed is not None: + # 'system_rows' method doesn't accept a seed parameter, so if seed is None we don't include it in the spec + sampling["seed"] = seed + sampled_query["sampling"] = sampling + return sampled_query diff --git a/flowdb/sql/functions_001_utilities.sql b/flowdb/sql/functions_001_utilities.sql index bf7b27d03f..d3d8c142cf 100644 --- a/flowdb/sql/functions_001_utilities.sql +++ b/flowdb/sql/functions_001_utilities.sql @@ -197,18 +197,11 @@ CREATE OR REPLACE FUNCTION random_ints (seed DOUBLE PRECISION, n_samples INT, ma RETURNS TABLE (id INT) AS $$ DECLARE new_seed NUMERIC; -DECLARE samples double precision[] := array[]::double precision[]; BEGIN new_seed = random(); PERFORM setseed(seed); - FOR i in 1..n_samples LOOP - samples := array_append(samples, random()); - END LOOP; + RETURN QUERY SELECT generate_series AS id FROM generate_series(1, max_val) ORDER BY random() LIMIT n_samples; PERFORM setseed(new_seed); - RETURN QUERY SELECT - round(samples[generate_series] * max_val)::integer as id - FROM generate_series(1, n_samples) - GROUP BY id; END; $$ LANGUAGE plpgsql diff --git a/flowdb/tests/test_utility_functions.py b/flowdb/tests/test_utility_functions.py index 62b6e305e4..3558313e6b 100644 --- a/flowdb/tests/test_utility_functions.py +++ b/flowdb/tests/test_utility_functions.py @@ -93,8 +93,18 @@ def test_seeded_random_ints(cursor): """Seeded random integers should return some predictable outputs.""" sql = "SELECT * from random_ints(0, 5, 10)" cursor.execute(sql) + first_vals = [x["id"] for x in cursor.fetchall()] + cursor.execute(sql) + second_vals = [x["id"] for x in cursor.fetchall()] + assert first_vals == second_vals + + +def test_random_ints_n_samples(cursor): + """random_ints should return the requested number of random integers.""" + sql = "SELECT * from random_ints(0, 5, 
10)" + cursor.execute(sql) vals = [x["id"] for x in cursor.fetchall()] - assert [9, 4, 8] == vals + assert len(vals) == 5 def test_seeded_random_ints_seed_reset(cursor): diff --git a/flowmachine/flowmachine/core/cache.py b/flowmachine/flowmachine/core/cache.py index 7774ab4d11..732ad2c3cd 100644 --- a/flowmachine/flowmachine/core/cache.py +++ b/flowmachine/flowmachine/core/cache.py @@ -166,7 +166,7 @@ def write_cache_metadata( try: self_storage = pickle.dumps(query) except Exception as e: - logger.debug("Can't pickle ({e}), attempting to cache anyway.") + logger.debug(f"Can't pickle ({e}), attempting to cache anyway.") pass try: diff --git a/flowmachine/flowmachine/core/query.py b/flowmachine/flowmachine/core/query.py index a2463fa409..2486f176a7 100644 --- a/flowmachine/flowmachine/core/query.py +++ b/flowmachine/flowmachine/core/query.py @@ -706,63 +706,6 @@ def store(self): store_future = self.to_sql(name, schema=schema) return store_future - def _db_store_cache_metadata(self, compute_time=None): - """ - Helper function for store, updates flowmachine metadata table to - log that this query is stored, but does not actually store - the query. 
- """ - - from ..__init__ import __version__ - - con = self.connection.engine - - self_storage = b"" - try: - self_storage = pickle.dumps(self) - except: - logger.debug("Can't pickle, attempting to cache anyway.") - pass - - try: - in_cache = bool( - self.connection.fetch( - f"SELECT * FROM cache.cached WHERE query_id='{self.md5}'" - ) - ) - - with con.begin(): - cache_record_insert = """ - INSERT INTO cache.cached - (query_id, version, query, created, access_count, last_accessed, compute_time, - cache_score_multiplier, class, schema, tablename, obj) - VALUES (%s, %s, %s, NOW(), 0, NOW(), %s, 0, %s, %s, %s, %s) - ON CONFLICT (query_id) DO UPDATE SET last_accessed = NOW();""" - con.execute( - cache_record_insert, - ( - self.md5, - __version__, - self._make_query(), - compute_time, - self.__class__.__name__, - *self.fully_qualified_table_name.split("."), - psycopg2.Binary(self_storage), - ), - ) - con.execute("SELECT touch_cache(%s);", self.md5) - logger.debug( - "{} added to cache.".format(self.fully_qualified_table_name) - ) - if not in_cache: - for dep in self._get_stored_dependencies(exclude_self=True): - con.execute( - "INSERT INTO cache.dependencies values (%s, %s) ON CONFLICT DO NOTHING", - (self.md5, dep.md5), - ) - except NotImplementedError: - logger.debug("Table has no standard name.") - @property def dependencies(self): """ @@ -1026,24 +969,13 @@ def get_stored(cls): objs = Query.connection.fetch(qry) return (pickle.loads(obj[0]) for obj in objs) - def random_sample( - self, - size=None, - fraction=None, - method="system_rows", - estimate_count=True, - seed=None, - ): + def random_sample(self, sampling_method="system_rows", **params): """ Draws a random sample from this query. 
Parameters ---------- - size : int - Number of rows to draw - fraction : float - Fraction of total rows to draw - method : {'system', 'system_rows', 'bernoulli', 'random_ids'}, default 'system_rows' + sampling_method : {'system', 'system_rows', 'bernoulli', 'random_ids'}, default 'system_rows' Specifies the method used to select the random sample. 'system_rows': performs block-level sampling by randomly sampling each physical storage page of the underlying relation. This @@ -1058,15 +990,20 @@ def random_sample( 'bernoulli': samples directly on each row of the underlying relation. This sampling method is slower and is not guaranteed to generate a sample of the specified size, but an approximation - 'random_ids': Assumes that the table contains a column named 'id' - with random numbers from 1 to the total number of rows in the - table. This method samples the ids from this table. + 'random_ids': samples rows by randomly sampling the row number. + size : int, optional + The number of rows to draw. + Exactly one of the 'size' or 'fraction' arguments must be provided. + fraction : float, optional + Fraction of rows to draw. + Exactly one of the 'size' or 'fraction' arguments must be provided. estimate_count : bool, default True Whether to estimate the number of rows in the table using information contained in the `pg_class` or whether to perform an actual count in the number of rows. seed : float, optional - Optionally provide a seed for repeatable random samples, which should be between -/+1. + Optionally provide a seed for repeatable random samples. + If using random_ids method, seed must be between -/+1. Not available in combination with the system_rows method. Returns @@ -1076,28 +1013,13 @@ def random_sample( See Also -------- - flowmachine.utils.random.random_factory + flowmachine.core.random.random_factory Notes ----- - Random samples may only be stored if a seed is supplied. 
- """ - if seed is not None: - if seed > 1 or seed < -1: - raise ValueError("Seed must be between -1 and 1.") - if method == "system_rows": - raise ValueError("Seed is not supported with system_rows method.") - from .random import random_factory - random_class = random_factory(self.__class__) - return random_class( - query=self, - size=size, - fraction=fraction, - method=method, - estimate_count=estimate_count, - seed=seed, - ) + random_class = random_factory(self.__class__, sampling_method=sampling_method) + return random_class(query=self, **params) diff --git a/flowmachine/flowmachine/core/random.py b/flowmachine/flowmachine/core/random.py index 23069490e5..a245852e62 100644 --- a/flowmachine/flowmachine/core/random.py +++ b/flowmachine/flowmachine/core/random.py @@ -3,293 +3,520 @@ # file, You can obtain one at http://mozilla.org/MPL/2.0/. """ -Classes to select random samples from the database. The random samples could be -random events, msisdns, sites or a given geographical boundary. All the samples -can be subsetted by time. +Classes to select random samples from queries or tables. """ import random -from typing import List +from typing import List, Optional, Dict, Any, Union, Type, Tuple +from abc import ABCMeta, abstractmethod from .query import Query from .table import Table -class RandomBase: +class _RandomGetter: """ - Base class containing method to construct a query used to obtain a random sample from a table. + Helper class for pickling/unpickling of dynamic random classes. + (see https://stackoverflow.com/questions/1947904/how-can-i-pickle-a-nested-class-in-python/11493777#11493777) """ - def _inheritance_check(self): + def __call__(self, query: Query, sampling_method: str, params: Dict[str, Any]): + return query.random_sample(sampling_method, **params) + + +class RandomBase(metaclass=ABCMeta): + """ + Base class for queries used to obtain a random sample from a table. 
+ """ + + def __init__( + self, + query: Query, + *, + size: Optional[int] = None, + fraction: Optional[float] = None, + estimate_count: bool = True, + ): + if size is None and fraction is None: + raise ValueError( + f"{self.__class__.__name__}() missing 1 required argument: 'size' or 'fraction'" + ) + if size is not None and fraction is not None: + raise ValueError( + f"{self.__class__.__name__}() expects only 1 argument to be defined: either 'size' or 'fraction'" + ) + + if fraction is not None and (fraction < 0 or fraction > 1): + raise ValueError( + f"{self.__class__.__name__}() expects fraction between 0 and 1." + ) + + self.query = query + self.size = size + self.fraction = fraction + self.estimate_count = estimate_count + + @property + def _sample_params(self) -> Dict[str, Any]: + """ + Parameters passed when initialising this query. """ - Raise a value error if the query is a table, and has children, and the method relies - on it not having children. + return { + "size": self.size, + "fraction": self.fraction, + "estimate_count": self.estimate_count, + } + + def _count_table_rows(self) -> int: + """ + Return a count of the number of rows in self.query table, either using + information contained in the `pg_class` (if self.estimate_count) or + by performing an actual count in the number of rows. """ - if (self.method != "system_rows") or (not isinstance(self.query, Table)): - return - if self.query.has_children(): + rowcount = 0 + + if self.estimate_count: + table = self.query.get_table() + rowcount = table.estimated_rowcount() + + if not self.estimate_count or rowcount == 0: + rowcount = len(self.query) + + return rowcount + + @abstractmethod + def _make_query(self): + raise NotImplementedError( + f"Class {self.__class__.__name__} does not implement the _make_query method." + ) + + # Overwrite the table_name method so that it cannot + # be stored by accident.
+ @property + def table_name(self): + raise NotImplementedError("Unseeded random samples cannot be stored.") + + # Overwrite to call on parent instead + @property + def column_names(self) -> List[str]: + return self.query.column_names + + +class RandomSystemRows(RandomBase): + """ + Gets a random sample from the result of a query, using a PostgreSQL TABLESAMPLE + clause with the 'system_rows' method. + This method performs block-level sampling by randomly sampling + each physical storage page of the underlying relation. This + sampling method is guaranteed to provide a sample of the specified + size. + + Parameters + ---------- + query : Query + A query specifying a table from which a random sample will be drawn. + size : int, optional + The number of rows to be selected from the table. + Exactly one of the 'size' or 'fraction' arguments must be provided. + fraction : float, optional + The fraction of rows to be selected from the table. + Exactly one of the 'size' or 'fraction' arguments must be provided. + estimate_count : bool, default True + Whether to estimate the number of rows in the table using + information contained in the `pg_class` or whether to perform an + actual count in the number of rows. + + See Also + -------- + flowmachine.core.random.random_factory + + Notes + ----- + The 'system_rows' sampling method does not support parent tables which have + child inheritance. + The 'system_rows' sampling method does not support supplying a seed for + reproducible samples, so random samples cannot be stored. + """ + + def __init__( + self, + query: Query, + *, + size: Optional[int] = None, + fraction: Optional[float] = None, + estimate_count: bool = True, + ): + # Raise a value error if the query is a table, and has children, as the + # method relies on it not having children.
+ if isinstance(query, Table) and query.has_children(): raise ValueError( "It is not possible to use the 'system_rows' method in tables with inheritance " + "as it selects a random sample for each child table and not for the set as a whole." ) - def _make_query(self): + super().__init__( + query=query, size=size, fraction=fraction, estimate_count=estimate_count + ) - # def _make_query(self, columns, table, size=None, - # fraction=None, method='system_rows', estimate_count=True): + def _make_query(self) -> str: + # TABLESAMPLE only works on tables, so silently store this query + self.query.store().result() + columns = ", ".join( + [ + "{}.{}".format(self.query.fully_qualified_table_name, c) + for c in self.query.column_names + ] + ) + + if self.size is None: + rowcount = self._count_table_rows() + size = int(self.fraction * float(rowcount)) + else: + size = self.size + + sampled_query = f""" + SELECT {columns} FROM {self.query.fully_qualified_table_name} TABLESAMPLE SYSTEM_ROWS({size}) + """ + + return sampled_query + + +class SeedableRandom(RandomBase, metaclass=ABCMeta): + """ + Base class for random samples that accept a seed parameter for reproducibility. + """ + + def __init__( + self, + query: Query, + *, + size: Optional[int] = None, + fraction: Optional[float] = None, + estimate_count: bool = True, + seed: Optional[float] = None, + ): + self._seed = seed + super().__init__( + query=query, size=size, fraction=fraction, estimate_count=estimate_count + ) + + # Make seed a property to avoid inadvertently changing it. + @property + def seed(self) -> Optional[float]: + return self._seed + + @property + def _sample_params(self) -> Dict[str, Any]: + """ + Parameters passed when initialising this query. + """ + return dict(seed=self.seed, **super()._sample_params) + + # Overwrite the table_name method so that it cannot + # be stored by accident. 
+ @property + def table_name(self) -> str: + if self.seed is None: + raise NotImplementedError("Unseeded random samples cannot be stored.") + return f"x{self.md5}" + + +class RandomTablesample(SeedableRandom): + """ + Gets a random sample from the result of a query, using a PostgreSQL TABLESAMPLE + clause with one of the following sampling methods: + 'system': performs block-level sampling by randomly sampling each + physical storage page for the underlying relation. This + sampling method is not guaranteed to generate a sample of the + specified size, but an approximation. This method may not + produce a sample at all, so it might be worth running it again + if it returns an empty dataframe. + 'bernoulli': samples directly on each row of the underlying + relation. This sampling method is slower and is not guaranteed to + generate a sample of the specified size, but an approximation. + The choice of method is determined from the _sampling_method attribute. + + Parameters + ---------- + query : Query + A query specifying a table from which a random sample will be drawn. + size : int, optional + The number of rows to be selected from the table. + Exactly one of the 'size' or 'fraction' arguments must be provided. + fraction : float, optional + The fraction of rows to be selected from the table. + Exactly one of the 'size' or 'fraction' arguments must be provided. + estimate_count : bool, default True + Whether to estimate the number of rows in the table using + information contained in the `pg_class` or whether to perform an + actual count in the number of rows. + seed : float, optional + Optionally provide a seed for repeatable random samples. + + See Also + -------- + flowmachine.core.random.random_factory + + Notes + ----- + Random samples may only be stored if a seed is supplied.
+ """ + + _sampling_method = None + + def __init__( + self, + query: Query, + *, + size: Optional[int] = None, + fraction: Optional[float] = None, + estimate_count: bool = True, + seed: Optional[float] = None, + ): + valid_methods = ["system", "bernoulli"] + if self._sampling_method not in valid_methods: + raise ValueError( + "RandomTablesample() expects a valid sampling method from any of these: " + + ", ".join(valid_methods) + ) + + super().__init__( + query=query, + size=size, + fraction=fraction, + estimate_count=estimate_count, + seed=seed, + ) + + def _make_query(self) -> str: # TABLESAMPLE only works on tables, so silently store this query self.query.store().result() - table = Table(self.table) columns = ", ".join( - ["{}.{}".format(self.table, c) for c in self.query.column_names] + [ + "{}.{}".format(self.query.fully_qualified_table_name, c) + for c in self.query.column_names + ] ) - size = self.size - fraction = self.fraction + if self.fraction is None: + rowcount = self._count_table_rows() + percent = 100 * self.size / float(rowcount) + else: + percent = self.fraction * 100 + + repeatable_statement = ( + f"REPEATABLE({self.seed})" if self.seed is not None else "" + ) + + if self.size is not None: + percent_buffer = min(percent + 10, 100) + sampled_query = f""" + SELECT {columns} FROM {self.query.fully_qualified_table_name} + TABLESAMPLE {self._sampling_method.upper()}({percent_buffer}) {repeatable_statement} LIMIT {self.size} + """ + else: + sampled_query = f""" + SELECT {columns} FROM {self.query.fully_qualified_table_name} + TABLESAMPLE {self._sampling_method.upper()}({percent}) {repeatable_statement} + """ + + return sampled_query - if (size and self.method != "system_rows") or ( - fraction and self.method in ["system_rows", "random_ids"] - ): - ct = 0 +class RandomIDs(SeedableRandom): + """ + Gets a random sample from the result of a query, using the 'random_ids' sampling method. + This method samples rows by randomly sampling the row number. 
+ + Parameters + ---------- + query : Query + A query specifying a table from which a random sample will be drawn. + size : int, optional + The number of rows to be selected from the table. + Exactly one of the 'size' or 'fraction' arguments must be provided. + fraction : float, optional + The fraction of rows to be selected from the table. + Exactly one of the 'size' or 'fraction' arguments must be provided. + estimate_count : bool, default True + Whether to estimate the number of rows in the table using + information contained in the `pg_class` or whether to perform an + actual count in the number of rows. + seed : float, optional + Optionally provide a seed for repeatable random samples. + For the 'random_ids' method, seed must be between -/+1. + + See Also + -------- + flowmachine.core.random.random_factory + + Notes + ----- + Random samples may only be stored if a seed is supplied. + """ + + def __init__( + self, + query: Query, + *, + size: Optional[int] = None, + fraction: Optional[float] = None, + estimate_count: bool = True, + seed: Optional[float] = None, + ): + if seed is not None and (seed > 1 or seed < -1): + raise ValueError("Seed must be between -1 and 1 for random_ids method.") + + super().__init__( + query=query, + size=size, + fraction=fraction, + estimate_count=estimate_count, + seed=seed, + ) - if self.estimate_count: - ct = table.estimated_rowcount() + def _make_query(self) -> str: + # TABLESAMPLE only works on tables, so silently store this query + # Note: The "random_ids" method doesn't use TABLESAMPLE, but we still + # store the query before sampling for consistency with the other + # sampling methods.
+ self.query.store().result() - if not self.estimate_count or ct == 0: - ct = len(self.query) + columns = ",".join(self.query.column_names) - if size is not None: - fraction = size / float(ct) - elif fraction is not None: - size = int(fraction * float(ct)) + rowcount = self._count_table_rows() - if fraction is not None: - fraction *= 100 + if self.size is None: + size = int(self.fraction * float(rowcount)) + else: + size = self.size # Set the seed used to a random one if none is provided seed = self.seed if self.seed is not None else random.random() - if self.method == "random_ids": - - query = """ - SELECT {cn} FROM ( - (SELECT * FROM random_ints({seed}, {ct}, {size_buffer})) r - LEFT JOIN - (SELECT *, row_number() OVER () as rid FROM {sc}.{tn}) s - ON r.id=s.rid - ) o - LIMIT {size} - """.format( - cn=",".join(self.query.column_names), - sc=table.schema, - tn=table.name, - ct=ct, - size_buffer=int(size * 1.1), - size=size, - seed=seed, - ) - else: - if size is not None: - if self.method == "system_rows": - fraction_buffer = size - else: - fraction_buffer = min(fraction + 10, 100) - query = """ - SELECT {cn} FROM {sc}.{tn} TABLESAMPLE {method}({fraction_buffer}) REPEATABLE({seed}) LIMIT {size} - """.format( - cn=columns, - sc=table.schema, - tn=table.name, - fraction_buffer=fraction_buffer, - size=size, - method=self.method.upper(), - seed=seed, - ) - if self.method == "system_rows": - query = """ - SELECT {cn} FROM {sc}.{tn} TABLESAMPLE {method}({fraction_buffer}) LIMIT {size} - """.format( - cn=columns, - sc=table.schema, - tn=table.name, - fraction_buffer=fraction_buffer, - size=size, - method=self.method.upper(), - ) - else: - query = """ - SELECT {cn} FROM {sc}.{tn} TABLESAMPLE {method}({fraction}) REPEATABLE({seed}) - """.format( - cn=columns, - sc=table.schema, - tn=table.name, - fraction=fraction, - method=self.method.upper(), - seed=seed, - ) - if self.method == "system_rows": - query = """ - SELECT {cn} FROM {sc}.{tn} TABLESAMPLE {method}({fraction}) - 
""".format( - cn=columns, - sc=table.schema, - tn=table.name, - fraction=fraction, - method=self.method.upper(), - ) - - return query - - -def random_factory(parent_class): + sampled_query = f""" + SELECT {columns} FROM ( + (SELECT * FROM random_ints({seed}, {size}, {rowcount})) r + LEFT JOIN + (SELECT *, row_number() OVER () as rid FROM {self.query.fully_qualified_table_name}) s + ON r.id=s.rid + ) o + """ + + return sampled_query + + +def random_factory(parent_class: Type[Query], sampling_method: str = "system_rows"): """ Dynamically creates a random class as a descendant of parent_class. The resulting object will query the underlying object for attributes, and methods. + + Parameters + ---------- + parent_class : class derived from flowmachine.core.Query + Class from which to derive random class + sampling_method : str, default 'system_rows' + One of 'system_rows', 'system', 'bernoulli', 'random_ids'. + Specifies the method used to select the random sample. + 'system_rows': performs block-level sampling by randomly sampling + each physical storage page of the underlying relation. This + sampling method is guaranteed to provide a sample of the specified + size. This method does not support parent tables which have child + inheritance, and is not reproducible. + 'system': performs block-level sampling by randomly sampling each + physical storage page for the underlying relation. This + sampling method is not guaranteed to generate a sample of the + specified size, but an approximation. This method may not + produce a sample at all, so it might be worth running it again + if it returns an empty dataframe. + 'bernoulli': samples directly on each row of the underlying + relation. This sampling method is slower and is not guaranteed to + generate a sample of the specified size, but an approximation. + 'random_ids': samples rows by randomly sampling the row number. + + Returns + ------- + class + A class which gets a random sample from the result of a query. 
+ + Examples + -------- + >>> query = UniqueSubscribers("2016-01-01", "2016-01-31") + >>> Random = random_factory(query.__class__) + >>> Random(query=query, size=10).get_dataframe() + + msisdn + 0 AgvE8pa3Bvqezmo6 + 1 3XKdxqvyNxO2vLD1 + 2 5Kgwy8Gp6DlN3Eq9 + 3 L4V537alj321eWz6 + 4 GJP3DWdGyb4QBnyo + 5 DAlqeZENbeOn2vBw + 6 By4j6PKdB4NGMpxr + 7 mkqQ4NPBPQLapbeg + 8 YNv2EgDJxxAoy0Gr + 9 2vmOlAENnxpPM1xX + + >>> query = VersionedInfrastructure("2016-01-01") + >>> Random = random_factory(query.__class__) + >>> Random(query=query, size=10).get_dataframe() + + id version + 0 o9yyxY 0 + 1 B8OaG5 0 + 2 DbWg4K 0 + 3 0xqNDj 0 + 4 pqg7ZE 0 + 5 nWM8R3 0 + 6 LVnDQL 0 + 7 pdVVV4 0 + 8 wzrXjw 0 + 9 RZgwVz 0 + + # The default method 'system_rows' does not support parent tables which have child inheritance + # as is the case with 'events.calls', so we choose another method here. + >>> Random = random_factory(flowmachine.core.Query, sampling_method='bernoulli') + >>> Random(query=Table('events.calls', columns=['id', 'duration']), size=10).get_dataframe() + id duration + 0 mQjOy-5eVrm-Ll5eE-P4V27 422.0 + 1 mQjOy-5eVrm-Ll5eE-P4V27 422.0 + 2 0r4KG-Rb4Lm-VK1bB-LZQxg 762.0 + 3 BDXMV-yb8Kl-zkmav-AZEJ2 318.0 + 4 vm9gW-4QbYm-OrKbz-qM5Yx 1407.0 + 5 WYxk8-mepk9-W3pdM-yJNjQ 1062.0 + 6 mQjOy-5eVn3-wK5eE-P4V27 1033.0 + 7 M7Vl4-zbqom-oPDep-rOZqE 879.0 + 8 58DKg-l9av9-NE8eG-1vzAp 3129.0 + 9 m9gW4-QbY62-WLYdz-qM5Yx 1117.0 """ + random_classes = { + "system_rows": RandomSystemRows, + "system": RandomTablesample, + "bernoulli": RandomTablesample, + "random_ids": RandomIDs, + } + + try: + random_class = random_classes[sampling_method] + except KeyError: + raise ValueError( + "random_factory expects a valid sampling method from any of these: " + + ", ".join(random_classes.keys()) + ) - class Random(RandomBase, parent_class): - """ - Gets a random sample from the database according to the specification. - - Parameters - ---------- - variable : str - Either 'msisdn' or 'sites'. 
The class will select a random sample - of msisdn or sites. If this argument is set, it is not possible to - set 'columns' and/or 'query'. The argument 'table' then refers to - any table in the 'events' schema and only has implications when - 'variable' is equal to 'msisdn'. - columns : str or list - The columns from the table to be selected. If this argument is set, - it is not possible to set 'variable' and/or 'query'. - table : str - Schema qualified name of the table which the analysis is based - upon. If 'ALL' it will use all tables that contain location data, - specified in flowmachine.yml. If this argument is set, it is not possible - to set 'query'. If 'variable' is set, then 'table' should refer to - a table in the 'events' schema. - query : str - A query specifying a table from which a random sample will be drawn - from. If this argument is set, it is not possible to set - 'variable' and/or 'table'. - size : int - The size of the random sample. - fraction : int - The fraction of rows to be selected from the table. - method : str, default 'system_rows' - Either 'system_rows', 'system', 'bernouilli', 'random_ids'. - Specifies the method used to select the random sample. - 'system_rows': performs block-level sampling by randomly sampling - each physical storage page of the underlying relation. This - sampling method is guaranteed to provide a sample of the specified - size - 'system': performs block-level sampling by randomly sampling each - physical storage page for the underlying relation. This - sampling method is not guaranteed to generate a sample of the - specified size, but an approximation. This method may not - produce a sample at all, so it might be worth running it again - if it returns an empty dataframe. - 'bernoulli': samples directly on each row of the underlying - relation. 
This sampling method is slower and is not guaranteed to - generate a sample of the specified size, but an approximation - 'random_ids': Assumes that the table contains a column named 'id' - with random numbers from 1 to the total number of rows in the - table. This method samples the ids from this table. - estimate_count : bool, default True - Whether to estimate the number of rows in the table using - information contained in the `pg_class` or whether to perform an - actual count in the number of rows. - - Examples - -------- - >>> Random(query=UniqueSubscribers("2016-01-01", "2016-01-31"), size=10).get_dataframe() - - msisdn - 0 AgvE8pa3Bvqezmo6 - 1 3XKdxqvyNxO2vLD1 - 2 5Kgwy8Gp6DlN3Eq9 - 3 L4V537alj321eWz6 - 4 GJP3DWdGyb4QBnyo - 5 DAlqeZENbeOn2vBw - 6 By4j6PKdB4NGMpxr - 7 mkqQ4NPBPQLapbeg - 8 YNv2EgDJxxAoy0Gr - 9 2vmOlAENnxpPM1xX - - >>> Random(VersionedInfrastructure("2016-01-01"), size=10).get_dataframe() - - id version - 0 o9yyxY 0 - 1 B8OaG5 0 - 2 DbWg4K 0 - 3 0xqNDj 0 - 4 pqg7ZE 0 - 5 nWM8R3 0 - 6 LVnDQL 0 - 7 pdVVV4 0 - 8 wzrXjw 0 - 9 RZgwVz 0 - - # The default method 'system_rows' does not support parent tables which have child inheritance - # as is the case with 'events.calls', so we choose another method here. 
- >>> Random(query=Table('events.calls', columns=['id', 'duration'), size=10, method='bernoulli').get_dataframe() - id duration - 0 mQjOy-5eVrm-Ll5eE-P4V27 422.0 - 1 mQjOy-5eVrm-Ll5eE-P4V27 422.0 - 2 0r4KG-Rb4Lm-VK1bB-LZQxg 762.0 - 3 BDXMV-yb8Kl-zkmav-AZEJ2 318.0 - 4 vm9gW-4QbYm-OrKbz-qM5Yx 1407.0 - 5 WYxk8-mepk9-W3pdM-yJNjQ 1062.0 - 6 mQjOy-5eVn3-wK5eE-P4V27 1033.0 - 7 M7Vl4-zbqom-oPDep-rOZqE 879.0 - 8 58DKg-l9av9-NE8eG-1vzAp 3129.0 - 9 m9gW4-QbY62-WLYdz-qM5Yx 1117.0 - """ + class Random(random_class, parent_class): + __doc__ = random_class.__doc__ - def __init__( - self, - query=None, - size=None, - fraction=None, - method="system_rows", - estimate_count=True, - seed=None, - ): - """ + # We define _sampling_method here so that __reduce__ can pass this to + # _RandomGetter for pickling/unpickling, and also so that it can be + # used within the RandomTablesample class without having to pass + # sampling_method as an init parameter + _sampling_method = sampling_method - """ - - self.query = query - self.table = self.query.fully_qualified_table_name - self.size = size - self.fraction = fraction - self.method = method - self.estimate_count = estimate_count - self.seed = seed - - if self.size is None and self.fraction is None: - raise ValueError( - "Random() missing 1 required argument: 'size' or 'fraction'" - ) - if self.size is not None and self.fraction is not None: - raise ValueError( - "Random() expects only 1 argument to be defined: either 'size' or 'fraction'" - ) - - valid_methods = ["system_rows", "system", "bernoulli", "random_ids"] - - if self.method not in valid_methods: - raise ValueError( - "Random() expects a valid method from any of those: " - + ", ".join(valid_methods) - ) - - if self.fraction and self.fraction > 1: - raise ValueError("Random() expects fraction between 0 and 1.") - self._inheritance_check() + def __init__(self, query: Query, **params): + super().__init__(query=query, **params) Query.__init__(self) # This voodoo incantation means that 
if we look for an attibute @@ -302,18 +529,15 @@ def __getattr__(self, name): raise AttributeError return self.query.__getattribute__(name) - # Overwrite the table_name method so that it cannot - # be stored by accident. - @property - def table_name(self): - if self.seed is None or self.method == "system_rows": - raise NotImplementedError - else: - return f"x{self.md5}" - - # Overwrite to call on parent instead - @property - def column_names(self) -> List[str]: - return self.query.column_names + def __reduce__(self) -> Tuple[_RandomGetter, Tuple[Query, str, Dict[str, Any]]]: + """ + Returns + ------- + A special object which recreates random samples. + """ + return ( + _RandomGetter(), + (self.query, self._sampling_method, self._sample_params), + ) return Random diff --git a/flowmachine/flowmachine/core/server/query_schemas/base_query_with_sampling.py b/flowmachine/flowmachine/core/server/query_schemas/base_query_with_sampling.py new file mode 100644 index 0000000000..26a509035f --- /dev/null +++ b/flowmachine/flowmachine/core/server/query_schemas/base_query_with_sampling.py @@ -0,0 +1,44 @@ +# This Source Code Form is subject to the terms of the Mozilla Public +# License, v. 2.0. If a copy of the MPL was not distributed with this +# file, You can obtain one at http://mozilla.org/MPL/2.0/. + +from abc import ABCMeta, abstractmethod +from marshmallow import Schema, fields + +from .base_exposed_query import BaseExposedQuery +from .random_sample import RandomSampleSchema + + +class BaseQueryWithSamplingSchema(Schema): + sampling = fields.Nested(RandomSampleSchema, allow_none=True) + + +class BaseExposedQueryWithSampling(BaseExposedQuery, metaclass=ABCMeta): + @property + @abstractmethod + def _unsampled_query_obj(self): + """ + Return the flowmachine query object to be sampled. + + Returns + ------- + Query + """ + raise NotImplementedError( + f"Class {self.__class__.__name__} does not have the _unsampled_query_obj property set." 
+ ) + + @property + def _flowmachine_query_obj(self): + """ + Return the underlying flowmachine query object which this class exposes. + + Returns + ------- + Query + """ + query = self._unsampled_query_obj + if self.sampling is None: + return query + else: + return self.sampling.make_random_sample_object(query) diff --git a/flowmachine/flowmachine/core/server/query_schemas/daily_location.py b/flowmachine/flowmachine/core/server/query_schemas/daily_location.py index 335d1aafe2..581253754b 100644 --- a/flowmachine/flowmachine/core/server/query_schemas/daily_location.py +++ b/flowmachine/flowmachine/core/server/query_schemas/daily_location.py @@ -2,18 +2,21 @@ # License, v. 2.0. If a copy of the MPL was not distributed with this # file, You can obtain one at http://mozilla.org/MPL/2.0/. -from marshmallow import Schema, fields, post_load +from marshmallow import fields, post_load from marshmallow.validate import OneOf from flowmachine.features import daily_location -from .base_exposed_query import BaseExposedQuery from .custom_fields import SubscriberSubset from .aggregation_unit import AggregationUnit, get_spatial_unit_obj +from .base_query_with_sampling import ( + BaseQueryWithSamplingSchema, + BaseExposedQueryWithSampling, +) __all__ = ["DailyLocationSchema", "DailyLocationExposed"] -class DailyLocationSchema(Schema): +class DailyLocationSchema(BaseQueryWithSamplingSchema): # query_kind parameter is required here for claims validation query_kind = fields.String(validate=OneOf(["daily_location"])) date = fields.Date(required=True) @@ -26,17 +29,20 @@ def make_query_object(self, params, **kwargs): return DailyLocationExposed(**params) -class DailyLocationExposed(BaseExposedQuery): - def __init__(self, date, *, method, aggregation_unit, subscriber_subset=None): +class DailyLocationExposed(BaseExposedQueryWithSampling): + def __init__( + self, date, *, method, aggregation_unit, subscriber_subset=None, sampling=None + ): # Note: all input parameters need to be defined 
as attributes on `self` # so that marshmallow can serialise the object correctly. self.date = date self.method = method self.aggregation_unit = aggregation_unit self.subscriber_subset = subscriber_subset + self.sampling = sampling @property - def _flowmachine_query_obj(self): + def _unsampled_query_obj(self): """ Return the underlying flowmachine daily_location object. diff --git a/flowmachine/flowmachine/core/server/query_schemas/displacement.py b/flowmachine/flowmachine/core/server/query_schemas/displacement.py index 90605a1314..f766d4ef41 100644 --- a/flowmachine/flowmachine/core/server/query_schemas/displacement.py +++ b/flowmachine/flowmachine/core/server/query_schemas/displacement.py @@ -2,15 +2,18 @@ # License, v. 2.0. If a copy of the MPL was not distributed with this # file, You can obtain one at http://mozilla.org/MPL/2.0/. -from marshmallow import Schema, fields, post_load +from marshmallow import fields, post_load from marshmallow.validate import OneOf, Length from marshmallow_oneofschema import OneOfSchema from flowmachine.features import Displacement -from .base_exposed_query import BaseExposedQuery from .custom_fields import SubscriberSubset, Statistic from .daily_location import DailyLocationSchema, DailyLocationExposed from .modal_location import ModalLocationSchema, ModalLocationExposed +from .base_query_with_sampling import ( + BaseQueryWithSamplingSchema, + BaseExposedQueryWithSampling, +) __all__ = ["DisplacementSchema", "DisplacementExposed"] @@ -24,7 +27,7 @@ class InputToDisplacementSchema(OneOfSchema): } -class DisplacementSchema(Schema): +class DisplacementSchema(BaseQueryWithSamplingSchema): query_kind = fields.String(validate=OneOf(["displacement"])) start = fields.Date(required=True) stop = fields.Date(required=True) @@ -37,9 +40,16 @@ def make_query_object(self, params, **kwargs): return DisplacementExposed(**params) -class DisplacementExposed(BaseExposedQuery): +class DisplacementExposed(BaseExposedQueryWithSampling): def __init__( - 
self, *, start, stop, statistic, reference_location, subscriber_subset=None + self, + *, + start, + stop, + statistic, + reference_location, + subscriber_subset=None, + sampling=None ): # Note: all input parameters need to be defined as attributes on `self` # so that marshmallow can serialise the object correctly. @@ -48,9 +58,10 @@ def __init__( self.statistic = statistic self.reference_location = reference_location self.subscriber_subset = subscriber_subset + self.sampling = sampling @property - def _flowmachine_query_obj(self): + def _unsampled_query_obj(self): """ Return the underlying flowmachine displacement object. diff --git a/flowmachine/flowmachine/core/server/query_schemas/event_count.py b/flowmachine/flowmachine/core/server/query_schemas/event_count.py index 6f3f3814fd..ccc40b663b 100644 --- a/flowmachine/flowmachine/core/server/query_schemas/event_count.py +++ b/flowmachine/flowmachine/core/server/query_schemas/event_count.py @@ -2,17 +2,20 @@ # License, v. 2.0. If a copy of the MPL was not distributed with this # file, You can obtain one at http://mozilla.org/MPL/2.0/. 
-from marshmallow import Schema, fields, post_load +from marshmallow import fields, post_load from marshmallow.validate import OneOf, Length from flowmachine.features import EventCount -from .base_exposed_query import BaseExposedQuery from .custom_fields import EventTypes, SubscriberSubset +from .base_query_with_sampling import ( + BaseQueryWithSamplingSchema, + BaseExposedQueryWithSampling, +) __all__ = ["EventCountSchema", "EventCountExposed"] -class EventCountSchema(Schema): +class EventCountSchema(BaseQueryWithSamplingSchema): query_kind = fields.String(validate=OneOf(["event_count"])) start = fields.Date(required=True) stop = fields.Date(required=True) @@ -27,8 +30,17 @@ def make_query_object(self, params, **kwargs): return EventCountExposed(**params) -class EventCountExposed(BaseExposedQuery): - def __init__(self, *, start, stop, direction, event_types, subscriber_subset=None): +class EventCountExposed(BaseExposedQueryWithSampling): + def __init__( + self, + *, + start, + stop, + direction, + event_types, + subscriber_subset=None, + sampling=None + ): # Note: all input parameters need to be defined as attributes on `self` # so that marshmallow can serialise the object correctly. self.start = start @@ -36,9 +48,10 @@ def __init__(self, *, start, stop, direction, event_types, subscriber_subset=Non self.direction = direction self.event_types = event_types self.subscriber_subset = subscriber_subset + self.sampling = sampling @property - def _flowmachine_query_obj(self): + def _unsampled_query_obj(self): """ Return the underlying flowmachine event_count object. diff --git a/flowmachine/flowmachine/core/server/query_schemas/handset.py b/flowmachine/flowmachine/core/server/query_schemas/handset.py index a9ba275c6d..d79036f803 100644 --- a/flowmachine/flowmachine/core/server/query_schemas/handset.py +++ b/flowmachine/flowmachine/core/server/query_schemas/handset.py @@ -2,18 +2,21 @@ # License, v. 2.0. 
If a copy of the MPL was not distributed with this # file, You can obtain one at http://mozilla.org/MPL/2.0/. -from marshmallow import Schema, fields, post_load +from marshmallow import fields, post_load from marshmallow.validate import OneOf from flowmachine.features import SubscriberHandsetCharacteristic from flowmachine.core import CustomQuery -from .base_exposed_query import BaseExposedQuery from .custom_fields import SubscriberSubset +from .base_query_with_sampling import ( + BaseQueryWithSamplingSchema, + BaseExposedQueryWithSampling, +) __all__ = ["HandsetSchema", "HandsetExposed"] -class HandsetSchema(Schema): +class HandsetSchema(BaseQueryWithSamplingSchema): # query_kind parameter is required here for claims validation query_kind = fields.String(validate=OneOf(["handset"])) start_date = fields.Date(required=True) @@ -31,9 +34,16 @@ def make_query_object(self, params, **kwargs): return HandsetExposed(**params) -class HandsetExposed(BaseExposedQuery): +class HandsetExposed(BaseExposedQueryWithSampling): def __init__( - self, *, start_date, end_date, method, characteristic, subscriber_subset=None + self, + *, + start_date, + end_date, + method, + characteristic, + subscriber_subset=None, + sampling=None ): # Note: all input parameters need to be defined as attributes on `self` # so that marshmallow can serialise the object correctly. @@ -42,9 +52,10 @@ def __init__( self.method = method self.characteristic = characteristic self.subscriber_subset = subscriber_subset + self.sampling = sampling @property - def _flowmachine_query_obj(self): + def _unsampled_query_obj(self): """ Return the underlying flowmachine handset object. 
diff --git a/flowmachine/flowmachine/core/server/query_schemas/modal_location.py b/flowmachine/flowmachine/core/server/query_schemas/modal_location.py index 20551c73f2..3635775bc7 100644 --- a/flowmachine/flowmachine/core/server/query_schemas/modal_location.py +++ b/flowmachine/flowmachine/core/server/query_schemas/modal_location.py @@ -2,14 +2,17 @@ # License, v. 2.0. If a copy of the MPL was not distributed with this # file, You can obtain one at http://mozilla.org/MPL/2.0/. -from marshmallow import Schema, fields, post_load +from marshmallow import fields, post_load from marshmallow.validate import OneOf, Length from marshmallow_oneofschema import OneOfSchema -from .base_exposed_query import BaseExposedQuery from .custom_fields import SubscriberSubset from .aggregation_unit import AggregationUnit from .daily_location import DailyLocationSchema, DailyLocationExposed +from .base_query_with_sampling import ( + BaseQueryWithSamplingSchema, + BaseExposedQueryWithSampling, +) class InputToModalLocationSchema(OneOfSchema): @@ -17,7 +20,7 @@ class InputToModalLocationSchema(OneOfSchema): type_schemas = {"daily_location": DailyLocationSchema} -class ModalLocationSchema(Schema): +class ModalLocationSchema(BaseQueryWithSamplingSchema): # query_kind parameter is required here for claims validation query_kind = fields.String(validate=OneOf(["modal_location"])) locations = fields.Nested( @@ -31,16 +34,19 @@ def make_query_object(self, data, **kwargs): return ModalLocationExposed(**data) -class ModalLocationExposed(BaseExposedQuery): - def __init__(self, locations, *, aggregation_unit, subscriber_subset=None): +class ModalLocationExposed(BaseExposedQueryWithSampling): + def __init__( + self, locations, *, aggregation_unit, subscriber_subset=None, sampling=None + ): # Note: all input parameters need to be defined as attributes on `self` # so that marshmallow can serialise the object correctly. 
self.locations = locations self.aggregation_unit = aggregation_unit self.subscriber_subset = subscriber_subset + self.sampling = sampling @property - def _flowmachine_query_obj(self): + def _unsampled_query_obj(self): """ Return the underlying flowmachine ModalLocation object. diff --git a/flowmachine/flowmachine/core/server/query_schemas/nocturnal_events.py b/flowmachine/flowmachine/core/server/query_schemas/nocturnal_events.py index 620ab1f343..dc46beb78b 100644 --- a/flowmachine/flowmachine/core/server/query_schemas/nocturnal_events.py +++ b/flowmachine/flowmachine/core/server/query_schemas/nocturnal_events.py @@ -2,17 +2,20 @@ # License, v. 2.0. If a copy of the MPL was not distributed with this # file, You can obtain one at http://mozilla.org/MPL/2.0/. -from marshmallow import Schema, fields, post_load +from marshmallow import fields, post_load from marshmallow.validate import OneOf, Length from flowmachine.features import NocturnalEvents -from .base_exposed_query import BaseExposedQuery from .custom_fields import SubscriberSubset +from .base_query_with_sampling import ( + BaseQueryWithSamplingSchema, + BaseExposedQueryWithSampling, +) __all__ = ["NocturnalEventsSchema", "NocturnalEventsExposed"] -class NocturnalEventsSchema(Schema): +class NocturnalEventsSchema(BaseQueryWithSamplingSchema): query_kind = fields.String(validate=OneOf(["nocturnal_events"])) start = fields.Date(required=True) stop = fields.Date(required=True) @@ -24,17 +27,18 @@ def make_query_object(self, params, **kwargs): return NocturnalEventsExposed(**params) -class NocturnalEventsExposed(BaseExposedQuery): - def __init__(self, *, start, stop, hours, subscriber_subset=None): +class NocturnalEventsExposed(BaseExposedQueryWithSampling): + def __init__(self, *, start, stop, hours, subscriber_subset=None, sampling=None): # Note: all input parameters need to be defined as attributes on `self` # so that marshmallow can serialise the object correctly. 
self.start = start self.stop = stop self.hours = hours self.subscriber_subset = subscriber_subset + self.sampling = sampling @property - def _flowmachine_query_obj(self): + def _unsampled_query_obj(self): """ Return the underlying flowmachine nocturnal_events object. diff --git a/flowmachine/flowmachine/core/server/query_schemas/pareto_interactions.py b/flowmachine/flowmachine/core/server/query_schemas/pareto_interactions.py index 19016e04b1..1fbe84570f 100644 --- a/flowmachine/flowmachine/core/server/query_schemas/pareto_interactions.py +++ b/flowmachine/flowmachine/core/server/query_schemas/pareto_interactions.py @@ -2,17 +2,20 @@ # License, v. 2.0. If a copy of the MPL was not distributed with this # file, You can obtain one at http://mozilla.org/MPL/2.0/. -from marshmallow import Schema, fields, post_load +from marshmallow import fields, post_load from marshmallow.validate import OneOf, Length, Range from flowmachine.features import ParetoInteractions -from .base_exposed_query import BaseExposedQuery from .custom_fields import SubscriberSubset +from .base_query_with_sampling import ( + BaseQueryWithSamplingSchema, + BaseExposedQueryWithSampling, +) __all__ = ["ParetoInteractionsSchema", "ParetoInteractionsExposed"] -class ParetoInteractionsSchema(Schema): +class ParetoInteractionsSchema(BaseQueryWithSamplingSchema): query_kind = fields.String(validate=OneOf(["pareto_interactions"])) start = fields.Date(required=True) stop = fields.Date(required=True) @@ -24,17 +27,20 @@ def make_query_object(self, params, **kwargs): return ParetoInteractionsExposed(**params) -class ParetoInteractionsExposed(BaseExposedQuery): - def __init__(self, *, start, stop, proportion, subscriber_subset=None): +class ParetoInteractionsExposed(BaseExposedQueryWithSampling): + def __init__( + self, *, start, stop, proportion, subscriber_subset=None, sampling=None + ): # Note: all input parameters need to be defined as attributes on `self` # so that marshmallow can serialise the object 
correctly. self.start = start self.stop = stop self.proportion = proportion self.subscriber_subset = subscriber_subset + self.sampling = sampling @property - def _flowmachine_query_obj(self): + def _unsampled_query_obj(self): """ Return the underlying flowmachine pareto_interactions object. diff --git a/flowmachine/flowmachine/core/server/query_schemas/radius_of_gyration.py b/flowmachine/flowmachine/core/server/query_schemas/radius_of_gyration.py index 6f227df885..7a4f136f81 100644 --- a/flowmachine/flowmachine/core/server/query_schemas/radius_of_gyration.py +++ b/flowmachine/flowmachine/core/server/query_schemas/radius_of_gyration.py @@ -2,17 +2,20 @@ # License, v. 2.0. If a copy of the MPL was not distributed with this # file, You can obtain one at http://mozilla.org/MPL/2.0/. -from marshmallow import Schema, fields, post_load +from marshmallow import fields, post_load from marshmallow.validate import OneOf from flowmachine.features import RadiusOfGyration -from .base_exposed_query import BaseExposedQuery from .custom_fields import SubscriberSubset +from .base_query_with_sampling import ( + BaseQueryWithSamplingSchema, + BaseExposedQueryWithSampling, +) __all__ = ["RadiusOfGyrationSchema", "RadiusOfGyrationExposed"] -class RadiusOfGyrationSchema(Schema): +class RadiusOfGyrationSchema(BaseQueryWithSamplingSchema): # query_kind parameter is required here for claims validation query_kind = fields.String(validate=OneOf(["radius_of_gyration"])) start_date = fields.Date(required=True) @@ -24,16 +27,17 @@ def make_query_object(self, params, **kwargs): return RadiusOfGyrationExposed(**params) -class RadiusOfGyrationExposed(BaseExposedQuery): - def __init__(self, *, start_date, end_date, subscriber_subset=None): +class RadiusOfGyrationExposed(BaseExposedQueryWithSampling): + def __init__(self, *, start_date, end_date, subscriber_subset=None, sampling=None): # Note: all input parameters need to be defined as attributes on `self` # so that marshmallow can serialise the 
object correctly. self.start_date = start_date self.end_date = end_date self.subscriber_subset = subscriber_subset + self.sampling = sampling @property - def _flowmachine_query_obj(self): + def _unsampled_query_obj(self): """ Return the underlying flowmachine radius_of_gyration object. diff --git a/flowmachine/flowmachine/core/server/query_schemas/random_sample.py b/flowmachine/flowmachine/core/server/query_schemas/random_sample.py new file mode 100644 index 0000000000..2f885b3d52 --- /dev/null +++ b/flowmachine/flowmachine/core/server/query_schemas/random_sample.py @@ -0,0 +1,125 @@ +# This Source Code Form is subject to the terms of the Mozilla Public +# License, v. 2.0. If a copy of the MPL was not distributed with this +# file, You can obtain one at http://mozilla.org/MPL/2.0/. + +from marshmallow import Schema, fields, validates_schema, ValidationError, post_load +from marshmallow.validate import OneOf, Range +from marshmallow_oneofschema import OneOfSchema + +__all__ = ["RandomSampleSchema", "RandomSampler"] + + +class BaseRandomSampleSchema(Schema): + size = fields.Integer(validate=Range(min=1), allow_none=True) + fraction = fields.Float( + validate=Range(0.0, 1.0, min_inclusive=False, max_inclusive=False), + allow_none=True, + ) + estimate_count = fields.Boolean(missing=True) + + @validates_schema + def validate_size_or_fraction(self, data, **kwargs): + if ("size" in data and data["size"] is not None) == ( + "fraction" in data and data["fraction"] is not None + ): + raise ValidationError( + "Must provide exactly one of 'size' or 'fraction' for a random sample." + ) + + +class SystemRowsRandomSampleSchema(BaseRandomSampleSchema): + # We must define the sampling_method field here for it to appear in the API spec. + # This field is removed by RandomSampleSchema before passing on to this schema, + # so the sampling_method parameter is never received here and is not included in the + # params passed to make_random_sampler. 
+ sampling_method = fields.String(validate=OneOf(["system_rows"])) + + @post_load + def make_random_sampler(self, params, **kwargs): + return RandomSampler(sampling_method="system_rows", **params) + + +class SystemRandomSampleSchema(BaseRandomSampleSchema): + # We must define the sampling_method field here for it to appear in the API spec. + # This field is removed by RandomSampleSchema before passing on to this schema, + # so the sampling_method parameter is never received here and is not included in the + # params passed to make_random_sampler. + sampling_method = fields.String(validate=OneOf(["system"])) + seed = fields.Float() + + @post_load + def make_random_sampler(self, params, **kwargs): + return RandomSampler(sampling_method="system", **params) + + +class BernoulliRandomSampleSchema(BaseRandomSampleSchema): + # We must define the sampling_method field here for it to appear in the API spec. + # This field is removed by RandomSampleSchema before passing on to this schema, + # so the sampling_method parameter is never received here and is not included in the + # params passed to make_random_sampler. + sampling_method = fields.String(validate=OneOf(["bernoulli"])) + seed = fields.Float() + + @post_load + def make_random_sampler(self, params, **kwargs): + return RandomSampler(sampling_method="bernoulli", **params) + + +class RandomIDsRandomSampleSchema(BaseRandomSampleSchema): + # We must define the sampling_method field here for it to appear in the API spec. + # This field is removed by RandomSampleSchema before passing on to this schema, + # so the sampling_method parameter is never received here and is not included in the + # params passed to make_random_sampler. 
+ sampling_method = fields.String(validate=OneOf(["random_ids"])) + seed = fields.Float(validate=Range(-1.0, 1.0)) + + @post_load + def make_random_sampler(self, params, **kwargs): + return RandomSampler(sampling_method="random_ids", **params) + + +class RandomSampler: + def __init__( + self, *, sampling_method, size=None, fraction=None, estimate_count, seed=None + ): + # Note: all input parameters need to be defined as attributes on `self` + # so that marshmallow can serialise the object correctly. + self.sampling_method = sampling_method + self.size = size + self.fraction = fraction + self.estimate_count = estimate_count + if sampling_method != "system_rows": + self.seed = seed + + def make_random_sample_object(self, query): + """ + Apply this random sample to a FlowMachine Query object + + Parameters + ---------- + query : Query + FlowMachine Query object to be sampled + + Returns + ------- + Random + """ + sample_params = { + "sampling_method": self.sampling_method, + "size": self.size, + "fraction": self.fraction, + "estimate_count": self.estimate_count, + } + if self.sampling_method != "system_rows": + sample_params["seed"] = self.seed + return query.random_sample(**sample_params) + + +class RandomSampleSchema(OneOfSchema): + type_field = "sampling_method" + type_schemas = { + "system_rows": SystemRowsRandomSampleSchema, + "system": SystemRandomSampleSchema, + "bernoulli": BernoulliRandomSampleSchema, + "random_ids": RandomIDsRandomSampleSchema, + } diff --git a/flowmachine/flowmachine/core/server/query_schemas/subscriber_degree.py b/flowmachine/flowmachine/core/server/query_schemas/subscriber_degree.py index ec2f259b5b..d63ff1c73b 100644 --- a/flowmachine/flowmachine/core/server/query_schemas/subscriber_degree.py +++ b/flowmachine/flowmachine/core/server/query_schemas/subscriber_degree.py @@ -2,17 +2,20 @@ # License, v. 2.0. If a copy of the MPL was not distributed with this # file, You can obtain one at http://mozilla.org/MPL/2.0/. 
-from marshmallow import Schema, fields, post_load +from marshmallow import fields, post_load from marshmallow.validate import OneOf, Length from flowmachine.features import SubscriberDegree -from .base_exposed_query import BaseExposedQuery from .custom_fields import SubscriberSubset +from .base_query_with_sampling import ( + BaseQueryWithSamplingSchema, + BaseExposedQueryWithSampling, +) __all__ = ["SubscriberDegreeSchema", "SubscriberDegreeExposed"] -class SubscriberDegreeSchema(Schema): +class SubscriberDegreeSchema(BaseQueryWithSamplingSchema): query_kind = fields.String(validate=OneOf(["subscriber_degree"])) start = fields.Date(required=True) stop = fields.Date(required=True) @@ -26,17 +29,20 @@ def make_query_object(self, params, **kwargs): return SubscriberDegreeExposed(**params) -class SubscriberDegreeExposed(BaseExposedQuery): - def __init__(self, *, start, stop, direction, subscriber_subset=None): +class SubscriberDegreeExposed(BaseExposedQueryWithSampling): + def __init__( + self, *, start, stop, direction, subscriber_subset=None, sampling=None + ): # Note: all input parameters need to be defined as attributes on `self` # so that marshmallow can serialise the object correctly. self.start = start self.stop = stop self.direction = direction self.subscriber_subset = subscriber_subset + self.sampling = sampling @property - def _flowmachine_query_obj(self): + def _unsampled_query_obj(self): """ Return the underlying flowmachine subscriber_degree object. diff --git a/flowmachine/flowmachine/core/server/query_schemas/topup_amount.py b/flowmachine/flowmachine/core/server/query_schemas/topup_amount.py index 0247ebfd0d..7d5a8305b6 100644 --- a/flowmachine/flowmachine/core/server/query_schemas/topup_amount.py +++ b/flowmachine/flowmachine/core/server/query_schemas/topup_amount.py @@ -2,17 +2,20 @@ # License, v. 2.0. If a copy of the MPL was not distributed with this # file, You can obtain one at http://mozilla.org/MPL/2.0/. 
-from marshmallow import Schema, fields, post_load +from marshmallow import fields, post_load from marshmallow.validate import OneOf, Length from flowmachine.features import TopUpAmount -from .base_exposed_query import BaseExposedQuery from .custom_fields import SubscriberSubset, Statistic +from .base_query_with_sampling import ( + BaseQueryWithSamplingSchema, + BaseExposedQueryWithSampling, +) __all__ = ["TopUpAmountSchema", "TopUpAmountExposed"] -class TopUpAmountSchema(Schema): +class TopUpAmountSchema(BaseQueryWithSamplingSchema): query_kind = fields.String(validate=OneOf(["topup_amount"])) start = fields.Date(required=True) stop = fields.Date(required=True) @@ -24,17 +27,20 @@ def make_query_object(self, params, **kwargs): return TopUpAmountExposed(**params) -class TopUpAmountExposed(BaseExposedQuery): - def __init__(self, *, start, stop, statistic, subscriber_subset=None): +class TopUpAmountExposed(BaseExposedQueryWithSampling): + def __init__( + self, *, start, stop, statistic, subscriber_subset=None, sampling=None + ): # Note: all input parameters need to be defined as attributes on `self` # so that marshmallow can serialise the object correctly. self.start = start self.stop = stop self.statistic = statistic self.subscriber_subset = subscriber_subset + self.sampling = sampling @property - def _flowmachine_query_obj(self): + def _unsampled_query_obj(self): """ Return the underlying flowmachine topup_amount object. diff --git a/flowmachine/flowmachine/core/server/query_schemas/topup_balance.py b/flowmachine/flowmachine/core/server/query_schemas/topup_balance.py index 0d8078a429..3a441b9183 100644 --- a/flowmachine/flowmachine/core/server/query_schemas/topup_balance.py +++ b/flowmachine/flowmachine/core/server/query_schemas/topup_balance.py @@ -2,17 +2,20 @@ # License, v. 2.0. If a copy of the MPL was not distributed with this # file, You can obtain one at http://mozilla.org/MPL/2.0/. 
-from marshmallow import Schema, fields, post_load +from marshmallow import fields, post_load from marshmallow.validate import OneOf, Length from flowmachine.features import TopUpBalance -from .base_exposed_query import BaseExposedQuery from .custom_fields import Statistic, SubscriberSubset +from .base_query_with_sampling import ( + BaseQueryWithSamplingSchema, + BaseExposedQueryWithSampling, +) __all__ = ["TopUpBalanceSchema", "TopUpBalanceExposed"] -class TopUpBalanceSchema(Schema): +class TopUpBalanceSchema(BaseQueryWithSamplingSchema): query_kind = fields.String(validate=OneOf(["topup_balance"])) start_date = fields.Date(required=True) end_date = fields.Date(required=True) @@ -24,9 +27,15 @@ def make_query_object(self, params, **kwargs): return TopUpBalanceExposed(**params) -class TopUpBalanceExposed(BaseExposedQuery): +class TopUpBalanceExposed(BaseExposedQueryWithSampling): def __init__( - self, *, start_date, end_date, statistic="avg", subscriber_subset=None + self, + *, + start_date, + end_date, + statistic="avg", + subscriber_subset=None, + sampling=None ): # Note: all input parameters need to be defined as attributes on `self` # so that marshmallow can serialise the object correctly. @@ -34,9 +43,10 @@ def __init__( self.end_date = end_date self.statistic = statistic self.subscriber_subset = subscriber_subset + self.sampling = sampling @property - def _flowmachine_query_obj(self): + def _unsampled_query_obj(self): """ Return the underlying flowmachine TopUpBalance object. diff --git a/flowmachine/flowmachine/core/server/query_schemas/unique_location_counts.py b/flowmachine/flowmachine/core/server/query_schemas/unique_location_counts.py index 848f2cee05..793850b547 100644 --- a/flowmachine/flowmachine/core/server/query_schemas/unique_location_counts.py +++ b/flowmachine/flowmachine/core/server/query_schemas/unique_location_counts.py @@ -2,18 +2,21 @@ # License, v. 2.0. 
If a copy of the MPL was not distributed with this # file, You can obtain one at http://mozilla.org/MPL/2.0/. -from marshmallow import Schema, fields, post_load +from marshmallow import fields, post_load from marshmallow.validate import OneOf, Length from flowmachine.features import UniqueLocationCounts -from .base_exposed_query import BaseExposedQuery from .custom_fields import SubscriberSubset from .aggregation_unit import AggregationUnit, get_spatial_unit_obj +from .base_query_with_sampling import ( + BaseQueryWithSamplingSchema, + BaseExposedQueryWithSampling, +) __all__ = ["UniqueLocationCountsSchema", "UniqueLocationCountsExposed"] -class UniqueLocationCountsSchema(Schema): +class UniqueLocationCountsSchema(BaseQueryWithSamplingSchema): query_kind = fields.String(validate=OneOf(["unique_location_counts"])) start_date = fields.Date(required=True) end_date = fields.Date(required=True) @@ -25,9 +28,15 @@ def make_query_object(self, params, **kwargs): return UniqueLocationCountsExposed(**params) -class UniqueLocationCountsExposed(BaseExposedQuery): +class UniqueLocationCountsExposed(BaseExposedQueryWithSampling): def __init__( - self, *, start_date, end_date, aggregation_unit, subscriber_subset=None + self, + *, + start_date, + end_date, + aggregation_unit, + subscriber_subset=None, + sampling=None ): # Note: all input parameters need to be defined as attributes on `self` # so that marshmallow can serialise the object correctly. @@ -35,9 +44,10 @@ def __init__( self.end_date = end_date self.aggregation_unit = aggregation_unit self.subscriber_subset = subscriber_subset + self.sampling = sampling @property - def _flowmachine_query_obj(self): + def _unsampled_query_obj(self): """ Return the underlying flowmachine unique_location_counts object. 
diff --git a/flowmachine/flowmachine/core/table.py b/flowmachine/flowmachine/core/table.py index 4b197032d2..8b2be171c6 100644 --- a/flowmachine/flowmachine/core/table.py +++ b/flowmachine/flowmachine/core/table.py @@ -14,6 +14,7 @@ from .errors import NotConnectedError from .query import Query from .subset import subset_factory +from .cache import write_cache_metadata import structlog @@ -121,7 +122,7 @@ def __init__(self, name=None, schema=None, columns=None): q_state_machine = QueryStateMachine(self.redis, self.md5) q_state_machine.enqueue() q_state_machine.execute() - self._db_store_cache_metadata(compute_time=0) + write_cache_metadata(self.connection, self, compute_time=0) q_state_machine.finish() def __format__(self, fmt): @@ -238,19 +239,60 @@ def invalidate_db_cache(self, name=None, schema=None, cascade=True, drop=False): name=name, schema=schema, cascade=cascade, drop=drop ) - def random_sample( - self, size=None, fraction=None, method="system_rows", estimate_count=True - ): + def random_sample(self, sampling_method="system_rows", **params): + """ + Draws a random sample from this table. + + Parameters + ---------- + sampling_method : {'system', 'system_rows', 'bernoulli', 'random_ids'}, default 'system_rows' + Specifies the method used to select the random sample. + 'system_rows': performs block-level sampling by randomly sampling + each physical storage page of the underlying relation. This + sampling method is guaranteed to provide a sample of the specified + size + 'system': performs block-level sampling by randomly sampling each + physical storage page for the underlying relation. This + sampling method is not guaranteed to generate a sample of the + specified size, but an approximation. This method may not + produce a sample at all, so it might be worth running it again + if it returns an empty dataframe. + 'bernoulli': samples directly on each row of the underlying + relation. 
This sampling method is slower and is not guaranteed to + generate a sample of the specified size, but an approximation + 'random_ids': samples rows by randomly sampling the row number. + size : int, optional + The number of rows to draw. + Exactly one of the 'size' or 'fraction' arguments must be provided. + fraction : float, optional + Fraction of rows to draw. + Exactly one of the 'size' or 'fraction' arguments must be provided. + estimate_count : bool, default True + Whether to estimate the number of rows in the table using + information contained in the `pg_class` or whether to perform an + actual count in the number of rows. + seed : float, optional + Optionally provide a seed for repeatable random samples. + If using random_ids method, seed must be between -/+1. + Not available in combination with the system_rows method. + + Returns + ------- + Random + A special query object which contains a random sample from this table + + See Also + -------- + flowmachine.core.random.random_factory + + Notes + ----- + Random samples may only be stored if a seed is supplied. 
+ """ from .random import random_factory - random_class = random_factory(Query) - return random_class( - query=self, - size=size, - fraction=fraction, - method=method, - estimate_count=estimate_count, - ) + random_class = random_factory(Query, sampling_method=sampling_method) + return random_class(query=self, **params) def subset(self, col, subset): """ diff --git a/flowmachine/tests/test_cache.py b/flowmachine/tests/test_cache.py index 62af454bf9..ebebbb36ba 100644 --- a/flowmachine/tests/test_cache.py +++ b/flowmachine/tests/test_cache.py @@ -8,7 +8,7 @@ import pytest -from flowmachine.core.cache import cache_table_exists +from flowmachine.core.cache import cache_table_exists, write_cache_metadata from flowmachine.core.query import Query from flowmachine.features import daily_location, ModalLocation, Flows @@ -31,7 +31,7 @@ def test_do_cache_simple(flowmachine_connect): """ dl1 = daily_location("2016-01-01") - dl1._db_store_cache_metadata() + write_cache_metadata(flowmachine_connect, dl1) assert cache_table_exists(flowmachine_connect, dl1.md5) @@ -42,7 +42,7 @@ def test_do_cache_multi(flowmachine_connect): """ hl1 = ModalLocation(daily_location("2016-01-01"), daily_location("2016-01-02")) - hl1._db_store_cache_metadata() + write_cache_metadata(flowmachine_connect, hl1) assert cache_table_exists(flowmachine_connect, hl1.md5) @@ -55,7 +55,7 @@ def test_do_cache_nested(flowmachine_connect): hl1 = ModalLocation(daily_location("2016-01-01"), daily_location("2016-01-02")) hl2 = ModalLocation(daily_location("2016-01-03"), daily_location("2016-01-04")) flow = Flows(hl1, hl2) - flow._db_store_cache_metadata() + write_cache_metadata(flowmachine_connect, flow) assert cache_table_exists(flowmachine_connect, flow.md5) diff --git a/flowmachine/tests/test_query.py b/flowmachine/tests/test_query.py index 293168e2f9..d1f009fd52 100644 --- a/flowmachine/tests/test_query.py +++ b/flowmachine/tests/test_query.py @@ -126,7 +126,7 @@ def test_iteration(): def test_limited_head(): 
"""Test that we can call head on a query with a limit clause.""" dl = daily_location("2016-01-01") - dl.random_sample(2).head() + dl.random_sample(size=2, sampling_method="bernoulli").head() def test_make_sql_no_overwrite(): diff --git a/flowmachine/tests/test_query_object_construction.py b/flowmachine/tests/test_query_object_construction.py index b44853b614..98ce70f09c 100644 --- a/flowmachine/tests/test_query_object_construction.py +++ b/flowmachine/tests/test_query_object_construction.py @@ -25,6 +25,33 @@ def test_construct_query(diff_reporter): "subscriber_subset": None, }, }, + { + "query_kind": "spatial_aggregate", + "locations": { + "query_kind": "daily_location", + "date": "2016-01-01", + "aggregation_unit": "admin3", + "method": "last", + "subscriber_subset": None, + "sampling": None, + }, + }, + { + "query_kind": "spatial_aggregate", + "locations": { + "query_kind": "daily_location", + "date": "2016-01-01", + "aggregation_unit": "admin3", + "method": "last", + "subscriber_subset": None, + "sampling": { + "sampling_method": "system_rows", + "size": 10, + "fraction": None, + "estimate_count": False, + }, + }, + }, { "query_kind": "location_event_counts", "start_date": "2016-01-01", @@ -264,3 +291,43 @@ def test_wrong_geography_aggregation_unit_raises_error(): _ = FlowmachineQuerySchema().load( {"query_kind": "geography", "aggregation_unit": "DUMMY_AGGREGATION_UNIT"} ) + + +@pytest.mark.parametrize( + "sampling, message", + [ + ( + {"sampling_method": "system_rows", "size": 10, "fraction": 0.2}, + "Must provide exactly one of 'size' or 'fraction' for a random sample", + ), + ( + {"sampling_method": "system_rows"}, + "Must provide exactly one of 'size' or 'fraction' for a random sample", + ), + ( + {"sampling_method": "system_rows", "fraction": 1.2}, + "Must be greater than 0.0 and less than 1.0.", + ), + ( + {"sampling_method": "system_rows", "size": -1}, + "Must be greater or equal to 1.", + ), + ( + {"sampling_method": "random_ids", "size": 10, "seed": 
185}, + "Must be greater or equal to -1.0 and less or equal to 1.0.", + ), + ], +) +def test_invalid_sampling_params_raises_error(sampling, message): + query_spec = { + "query_kind": "spatial_aggregate", + "locations": { + "query_kind": "daily_location", + "date": "2016-01-01", + "aggregation_unit": "admin3", + "method": "last", + "sampling": sampling, + }, + } + with pytest.raises(ValidationError, match=message): + _ = FlowmachineQuerySchema().load(query_spec) diff --git a/flowmachine/tests/test_query_object_construction.test_construct_query.approved.txt b/flowmachine/tests/test_query_object_construction.test_construct_query.approved.txt index 857f5ed6d6..17da803b95 100644 --- a/flowmachine/tests/test_query_object_construction.test_construct_query.approved.txt +++ b/flowmachine/tests/test_query_object_construction.test_construct_query.approved.txt @@ -6,7 +6,24 @@ "date": "2016-01-01", "aggregation_unit": "admin3", "method": "last", - "subscriber_subset": null + "subscriber_subset": null, + "sampling": null + } + }, + "d1b69e1ca8b44f717d3335aaf93eb97c": { + "query_kind": "spatial_aggregate", + "locations": { + "query_kind": "daily_location", + "date": "2016-01-01", + "aggregation_unit": "admin3", + "method": "last", + "subscriber_subset": null, + "sampling": { + "sampling_method": "system_rows", + "size": 10, + "fraction": null, + "estimate_count": false + } } }, "1a1e4e159d05f2ec1f081ac9c2bfd6d5": { diff --git a/flowmachine/tests/test_random.py b/flowmachine/tests/test_random.py index dad1f0f4ae..ad4cb21aa9 100644 --- a/flowmachine/tests/test_random.py +++ b/flowmachine/tests/test_random.py @@ -9,6 +9,7 @@ import pytest +import pickle from flowmachine.core.mixins import GraphMixin from flowmachine.features import daily_location, Flows @@ -22,7 +23,7 @@ def test_random_msisdn(get_dataframe): Tests whether class selects a random sample of msisdn without failing. 
""" df = get_dataframe( - UniqueSubscribers(start="2016-01-01", stop="2016-01-04").random_sample(10) + UniqueSubscribers(start="2016-01-01", stop="2016-01-04").random_sample(size=10) ) assert list(df.columns) == ["subscriber"] assert len(df) == 10 @@ -36,12 +37,12 @@ def test_seeded_random(sample_method, get_dataframe): df = get_dataframe( UniqueSubscribers(start="2016-01-01", stop="2016-01-04").random_sample( - 10, method=sample_method, seed=0.1 + size=10, sampling_method=sample_method, seed=0.1 ) ) df2 = get_dataframe( UniqueSubscribers(start="2016-01-01", stop="2016-01-04").random_sample( - 10, method=sample_method, seed=0.1 + size=10, sampling_method=sample_method, seed=0.1 ) ) assert df.values.tolist() == df2.values.tolist() @@ -54,7 +55,7 @@ def test_bad_method_errors(): with pytest.raises(ValueError): UniqueSubscribers(start="2016-01-01", stop="2016-01-04").random_sample( - 10, method="BAD_METHOD_TYPE", seed=-50 + size=10, sampling_method="BAD_METHOD_TYPE", seed=-50 ) @@ -65,7 +66,7 @@ def test_bad_must_provide_sample_size_or_fraction(): with pytest.raises(ValueError): UniqueSubscribers(start="2016-01-01", stop="2016-01-04").random_sample( - None, fraction=None + size=None, fraction=None ) @@ -76,7 +77,7 @@ def test_bad_must_provide_either_sample_size_or_fraction(): with pytest.raises(ValueError): UniqueSubscribers(start="2016-01-01", stop="2016-01-04").random_sample( - 10, fraction=0.5 + size=10, fraction=0.5 ) @@ -87,7 +88,7 @@ def test_seeded_random_oob(): with pytest.raises(ValueError): UniqueSubscribers(start="2016-01-01", stop="2016-01-04").random_sample( - 10, method="bernoulli", seed=-50 + size=10, sampling_method="random_ids", seed=-50 ) @@ -98,7 +99,7 @@ def test_seeded_random_zero(sample_method): """ sample = UniqueSubscribers(start="2016-01-01", stop="2016-01-04").random_sample( - 10, method=sample_method, seed=0 + size=10, sampling_method=sample_method, seed=0 ) assert sample.get_query() == sample.get_query() @@ -108,9 +109,9 @@ def 
test_seeded_random_badmethod(): Tests whether seeds don't work with system_rows. """ - with pytest.raises(ValueError): + with pytest.raises(TypeError, match="got an unexpected keyword argument 'seed'"): UniqueSubscribers(start="2016-01-01", stop="2016-01-04").random_sample( - 10, method="system_rows", seed=-0.5 + size=10, sampling_method="system_rows", seed=-0.5 ) @@ -155,13 +156,13 @@ def test_system_rows(get_dataframe): """ df = get_dataframe( UniqueSubscribers(start="2016-01-01", stop="2016-01-04").random_sample( - size=10, method="system_rows" + size=10, sampling_method="system_rows" ) ) assert len(df) == 10 df = get_dataframe( UniqueSubscribers(start="2016-01-01", stop="2016-01-04").random_sample( - fraction=0.1, method="system_rows" + fraction=0.1, sampling_method="system_rows" ) ) assert len(df) == 50 @@ -177,7 +178,7 @@ def test_system(get_dataframe): while len(df) == 0: df = get_dataframe( UniqueSubscribers(start="2016-01-01", stop="2016-01-04").random_sample( - size=20, method="system" + size=20, sampling_method="system" ) ) assert list(df.columns) == ["subscriber"] @@ -189,7 +190,7 @@ def test_system(get_dataframe): while len(df) == 0: df = get_dataframe( UniqueSubscribers(start="2016-01-01", stop="2016-01-04").random_sample( - fraction=0.25, method="system" + fraction=0.25, sampling_method="system" ) ) assert list(df.columns) == ["subscriber"] @@ -201,7 +202,7 @@ def test_bernoulli(get_dataframe): """ df = get_dataframe( UniqueSubscribers(start="2016-01-01", stop="2016-01-04").random_sample( - size=10, method="bernoulli" + size=10, sampling_method="bernoulli" ) ) assert list(df.columns) == ["subscriber"] @@ -209,7 +210,7 @@ def test_bernoulli(get_dataframe): df = get_dataframe( UniqueSubscribers(start="2016-01-01", stop="2016-01-04").random_sample( - fraction=0.1, method="bernoulli" + fraction=0.1, sampling_method="bernoulli" ) ) assert list(df.columns) == ["subscriber"] @@ -221,7 +222,7 @@ def test_not_estimate_count(get_dataframe): """ df = 
get_dataframe( UniqueSubscribers(start="2016-01-01", stop="2016-01-04").random_sample( - size=10, method="bernoulli", estimate_count=False + size=10, sampling_method="bernoulli", estimate_count=False ) ) assert list(df.columns) == ["subscriber"] @@ -233,7 +234,9 @@ def test_system_rows_fail_with_inheritance(): Test whether the system row method fails if the subscriber queries for random rows on a parent table. """ with pytest.raises(ValueError): - df = Table(name="events.calls").random_sample(size=8) + df = Table(name="events.calls").random_sample( + size=8, sampling_method="system_rows" + ) def test_random_sample(get_dataframe): @@ -248,29 +251,48 @@ def test_random_sample(get_dataframe): assert len(df) == 6 -def is_subclass(): +def test_is_subclass(): """ Test that a random sample is an instance of the sampled thing. """ qur = UniqueSubscribers(start="2016-01-01", stop="2016-01-04") - sample = qur.random_sample(size=10, method="bernoulli", estimate_count=False) + sample = qur.random_sample( + size=10, sampling_method="bernoulli", estimate_count=False + ) assert isinstance(sample, UniqueSubscribers) -def gets_parent_attributes(): +def test_gets_parent_attributes(): """ Test that a random sample is an instance of the sampled thing. """ - qur = UniqueSubscribers(start="2016-01-01", stop="2016-01-04", level="admin3") - sample = qur.random_sample(size=10, method="bernoulli", estimate_count=False) - assert sample.level == "admin3" + qur = UniqueSubscribers(start="2016-01-01", stop="2016-01-04", hours=(4, 17)) + sample = qur.random_sample( + size=10, sampling_method="bernoulli", estimate_count=False + ) + assert sample.hours == (4, 17) -def gets_mixins(): +def test_gets_mixins(): """ Test that a random sample gets applicable mixins. 
""" dl1 = daily_location("2016-01-01") dl2 = daily_location("2016-01-02") flow = Flows(dl1, dl2) - assert isinstance(flow.random_sample(10), GraphMixin) + assert isinstance(flow.random_sample(size=10), GraphMixin) + + +def test_pickling(): + """ + Test that we can pickle and unpickle random classes. + """ + ss1 = UniqueSubscribers(start="2016-01-01", stop="2016-01-04").random_sample( + size=10 + ) + ss2 = Table("events.calls").random_sample( + size=10, sampling_method="bernoulli", seed=0.73 + ) + for ss in [ss1, ss2]: + assert ss.get_query() == pickle.loads(pickle.dumps(ss)).get_query() + assert ss.md5 == pickle.loads(pickle.dumps(ss)).md5 diff --git a/flowmachine/tests/test_subscriber_subsetting.py b/flowmachine/tests/test_subscriber_subsetting.py index 663df52e29..3ba6e1f75b 100644 --- a/flowmachine/tests/test_subscriber_subsetting.py +++ b/flowmachine/tests/test_subscriber_subsetting.py @@ -165,7 +165,7 @@ def test_cdrs_can_be_subset_by_list(get_dataframe, subscriber_list): def test_can_subset_by_sampler(get_dataframe): """Test that we can use the output of another query to subset by.""" unique_subs_sample = UniqueSubscribers("2016-01-01", "2016-01-07").random_sample( - size=10, method="system", seed=0.1 + size=10, sampling_method="system", seed=0.1 ) su = EventTableSubset( start="2016-01-01", stop="2016-01-03", subscriber_subset=unique_subs_sample diff --git a/integration_tests/tests/flowapi_tests/test_api_spec.test_generated_openapi_json_spec.approved.txt b/integration_tests/tests/flowapi_tests/test_api_spec.test_generated_openapi_json_spec.approved.txt index ce134c2fd3..a70dffbd1d 100644 --- a/integration_tests/tests/flowapi_tests/test_api_spec.test_generated_openapi_json_spec.approved.txt +++ b/integration_tests/tests/flowapi_tests/test_api_spec.test_generated_openapi_json_spec.approved.txt @@ -44,6 +44,38 @@ ], "type": "object" }, + "BernoulliRandomSample": { + "properties": { + "estimate_count": { + "default": true, + "type": "boolean" + }, + "fraction": { 
+ "format": "float", + "maximum": 1.0, + "minimum": 0.0, + "nullable": true, + "type": "number" + }, + "sampling_method": { + "enum": [ + "bernoulli" + ], + "type": "string" + }, + "seed": { + "format": "float", + "type": "number" + }, + "size": { + "format": "int32", + "minimum": 1, + "nullable": true, + "type": "integer" + } + }, + "type": "object" + }, "DFSTotalMetricAmount": { "properties": { "aggregation_unit": { @@ -118,6 +150,14 @@ ], "type": "string" }, + "sampling": { + "allOf": [ + { + "$ref": "#/components/schemas/RandomSample" + } + ], + "nullable": true + }, "subscriber_subset": { "enum": [ null @@ -145,6 +185,14 @@ "reference_location": { "$ref": "#/components/schemas/InputToDisplacement" }, + "sampling": { + "allOf": [ + { + "$ref": "#/components/schemas/RandomSample" + } + ], + "nullable": true + }, "start": { "format": "date", "type": "string" @@ -229,6 +277,14 @@ ], "type": "string" }, + "sampling": { + "allOf": [ + { + "$ref": "#/components/schemas/RandomSample" + } + ], + "nullable": true + }, "start": { "format": "date", "type": "string" @@ -404,6 +460,14 @@ ], "type": "string" }, + "sampling": { + "allOf": [ + { + "$ref": "#/components/schemas/RandomSample" + } + ], + "nullable": true + }, "start_date": { "format": "date", "type": "string" @@ -987,6 +1051,14 @@ ], "type": "string" }, + "sampling": { + "allOf": [ + { + "$ref": "#/components/schemas/RandomSample" + } + ], + "nullable": true + }, "subscriber_subset": { "enum": [ null @@ -1012,6 +1084,14 @@ ], "type": "string" }, + "sampling": { + "allOf": [ + { + "$ref": "#/components/schemas/RandomSample" + } + ], + "nullable": true + }, "start": { "format": "date", "type": "string" @@ -1049,6 +1129,14 @@ ], "type": "string" }, + "sampling": { + "allOf": [ + { + "$ref": "#/components/schemas/RandomSample" + } + ], + "nullable": true + }, "start": { "format": "date", "type": "string" @@ -1085,6 +1173,14 @@ ], "type": "string" }, + "sampling": { + "allOf": [ + { + "$ref": 
"#/components/schemas/RandomSample" + } + ], + "nullable": true + }, "start_date": { "format": "date", "type": "string" @@ -1104,6 +1200,65 @@ ], "type": "object" }, + "RandomIDsRandomSample": { + "properties": { + "estimate_count": { + "default": true, + "type": "boolean" + }, + "fraction": { + "format": "float", + "maximum": 1.0, + "minimum": 0.0, + "nullable": true, + "type": "number" + }, + "sampling_method": { + "enum": [ + "random_ids" + ], + "type": "string" + }, + "seed": { + "format": "float", + "maximum": 1.0, + "minimum": -1.0, + "type": "number" + }, + "size": { + "format": "int32", + "minimum": 1, + "nullable": true, + "type": "integer" + } + }, + "type": "object" + }, + "RandomSample": { + "discriminator": { + "mapping": { + "bernoulli": "#/components/schemas/BernoulliRandomSample", + "random_ids": "#/components/schemas/RandomIDsRandomSample", + "system": "#/components/schemas/SystemRandomSample", + "system_rows": "#/components/schemas/SystemRowsRandomSample" + }, + "propertyName": "sampling_method" + }, + "oneOf": [ + { + "$ref": "#/components/schemas/BernoulliRandomSample" + }, + { + "$ref": "#/components/schemas/RandomIDsRandomSample" + }, + { + "$ref": "#/components/schemas/SystemRandomSample" + }, + { + "$ref": "#/components/schemas/SystemRowsRandomSample" + } + ] + }, "SpatialAggregate": { "properties": { "locations": { @@ -1138,6 +1293,14 @@ ], "type": "string" }, + "sampling": { + "allOf": [ + { + "$ref": "#/components/schemas/RandomSample" + } + ], + "nullable": true + }, "start": { "format": "date", "type": "string" @@ -1161,6 +1324,66 @@ ], "type": "object" }, + "SystemRandomSample": { + "properties": { + "estimate_count": { + "default": true, + "type": "boolean" + }, + "fraction": { + "format": "float", + "maximum": 1.0, + "minimum": 0.0, + "nullable": true, + "type": "number" + }, + "sampling_method": { + "enum": [ + "system" + ], + "type": "string" + }, + "seed": { + "format": "float", + "type": "number" + }, + "size": { + "format": 
"int32", + "minimum": 1, + "nullable": true, + "type": "integer" + } + }, + "type": "object" + }, + "SystemRowsRandomSample": { + "properties": { + "estimate_count": { + "default": true, + "type": "boolean" + }, + "fraction": { + "format": "float", + "maximum": 1.0, + "minimum": 0.0, + "nullable": true, + "type": "number" + }, + "sampling_method": { + "enum": [ + "system_rows" + ], + "type": "string" + }, + "size": { + "format": "int32", + "minimum": 1, + "nullable": true, + "type": "integer" + } + }, + "type": "object" + }, "TopUpAmount": { "properties": { "query_kind": { @@ -1169,6 +1392,14 @@ ], "type": "string" }, + "sampling": { + "allOf": [ + { + "$ref": "#/components/schemas/RandomSample" + } + ], + "nullable": true + }, "start": { "format": "date", "type": "string" @@ -1217,6 +1448,14 @@ ], "type": "string" }, + "sampling": { + "allOf": [ + { + "$ref": "#/components/schemas/RandomSample" + } + ], + "nullable": true + }, "start_date": { "format": "date", "type": "string" @@ -1318,6 +1557,14 @@ ], "type": "string" }, + "sampling": { + "allOf": [ + { + "$ref": "#/components/schemas/RandomSample" + } + ], + "nullable": true + }, "start_date": { "format": "date", "type": "string" diff --git a/integration_tests/tests/flowapi_tests/test_api_spec.test_generated_openapi_redoc_spec.approved.txt b/integration_tests/tests/flowapi_tests/test_api_spec.test_generated_openapi_redoc_spec.approved.txt index dd5f37e8d4..680d5da1b0 100644 --- a/integration_tests/tests/flowapi_tests/test_api_spec.test_generated_openapi_redoc_spec.approved.txt +++ b/integration_tests/tests/flowapi_tests/test_api_spec.test_generated_openapi_redoc_spec.approved.txt @@ -44,6 +44,38 @@ ], "type": "object" }, + "BernoulliRandomSample": { + "properties": { + "estimate_count": { + "default": true, + "type": "boolean" + }, + "fraction": { + "format": "float", + "maximum": 1.0, + "minimum": 0.0, + "nullable": true, + "type": "number" + }, + "sampling_method": { + "enum": [ + "bernoulli" + ], + "type": 
"string" + }, + "seed": { + "format": "float", + "type": "number" + }, + "size": { + "format": "int32", + "minimum": 1, + "nullable": true, + "type": "integer" + } + }, + "type": "object" + }, "DFSTotalMetricAmount": { "properties": { "aggregation_unit": { @@ -118,6 +150,14 @@ ], "type": "string" }, + "sampling": { + "allOf": [ + { + "$ref": "#/components/schemas/RandomSample" + } + ], + "nullable": true + }, "subscriber_subset": { "enum": [ null @@ -145,6 +185,14 @@ "reference_location": { "$ref": "#/components/schemas/InputToDisplacement" }, + "sampling": { + "allOf": [ + { + "$ref": "#/components/schemas/RandomSample" + } + ], + "nullable": true + }, "start": { "format": "date", "type": "string" @@ -229,6 +277,14 @@ ], "type": "string" }, + "sampling": { + "allOf": [ + { + "$ref": "#/components/schemas/RandomSample" + } + ], + "nullable": true + }, "start": { "format": "date", "type": "string" @@ -385,6 +441,14 @@ ], "type": "string" }, + "sampling": { + "allOf": [ + { + "$ref": "#/components/schemas/RandomSample" + } + ], + "nullable": true + }, "start_date": { "format": "date", "type": "string" @@ -920,6 +984,14 @@ ], "type": "string" }, + "sampling": { + "allOf": [ + { + "$ref": "#/components/schemas/RandomSample" + } + ], + "nullable": true + }, "subscriber_subset": { "enum": [ null @@ -945,6 +1017,14 @@ ], "type": "string" }, + "sampling": { + "allOf": [ + { + "$ref": "#/components/schemas/RandomSample" + } + ], + "nullable": true + }, "start": { "format": "date", "type": "string" @@ -982,6 +1062,14 @@ ], "type": "string" }, + "sampling": { + "allOf": [ + { + "$ref": "#/components/schemas/RandomSample" + } + ], + "nullable": true + }, "start": { "format": "date", "type": "string" @@ -1018,6 +1106,14 @@ ], "type": "string" }, + "sampling": { + "allOf": [ + { + "$ref": "#/components/schemas/RandomSample" + } + ], + "nullable": true + }, "start_date": { "format": "date", "type": "string" @@ -1037,6 +1133,56 @@ ], "type": "object" }, + "RandomIDsRandomSample": 
{ + "properties": { + "estimate_count": { + "default": true, + "type": "boolean" + }, + "fraction": { + "format": "float", + "maximum": 1.0, + "minimum": 0.0, + "nullable": true, + "type": "number" + }, + "sampling_method": { + "enum": [ + "random_ids" + ], + "type": "string" + }, + "seed": { + "format": "float", + "maximum": 1.0, + "minimum": -1.0, + "type": "number" + }, + "size": { + "format": "int32", + "minimum": 1, + "nullable": true, + "type": "integer" + } + }, + "type": "object" + }, + "RandomSample": { + "oneOf": [ + { + "$ref": "#/components/schemas/BernoulliRandomSample" + }, + { + "$ref": "#/components/schemas/RandomIDsRandomSample" + }, + { + "$ref": "#/components/schemas/SystemRandomSample" + }, + { + "$ref": "#/components/schemas/SystemRowsRandomSample" + } + ] + }, "SpatialAggregate": { "properties": { "locations": { @@ -1071,6 +1217,14 @@ ], "type": "string" }, + "sampling": { + "allOf": [ + { + "$ref": "#/components/schemas/RandomSample" + } + ], + "nullable": true + }, "start": { "format": "date", "type": "string" @@ -1094,6 +1248,66 @@ ], "type": "object" }, + "SystemRandomSample": { + "properties": { + "estimate_count": { + "default": true, + "type": "boolean" + }, + "fraction": { + "format": "float", + "maximum": 1.0, + "minimum": 0.0, + "nullable": true, + "type": "number" + }, + "sampling_method": { + "enum": [ + "system" + ], + "type": "string" + }, + "seed": { + "format": "float", + "type": "number" + }, + "size": { + "format": "int32", + "minimum": 1, + "nullable": true, + "type": "integer" + } + }, + "type": "object" + }, + "SystemRowsRandomSample": { + "properties": { + "estimate_count": { + "default": true, + "type": "boolean" + }, + "fraction": { + "format": "float", + "maximum": 1.0, + "minimum": 0.0, + "nullable": true, + "type": "number" + }, + "sampling_method": { + "enum": [ + "system_rows" + ], + "type": "string" + }, + "size": { + "format": "int32", + "minimum": 1, + "nullable": true, + "type": "integer" + } + }, + "type": 
"object" + }, "TopUpAmount": { "properties": { "query_kind": { @@ -1102,6 +1316,14 @@ ], "type": "string" }, + "sampling": { + "allOf": [ + { + "$ref": "#/components/schemas/RandomSample" + } + ], + "nullable": true + }, "start": { "format": "date", "type": "string" @@ -1150,6 +1372,14 @@ ], "type": "string" }, + "sampling": { + "allOf": [ + { + "$ref": "#/components/schemas/RandomSample" + } + ], + "nullable": true + }, "start_date": { "format": "date", "type": "string" @@ -1251,6 +1481,14 @@ ], "type": "string" }, + "sampling": { + "allOf": [ + { + "$ref": "#/components/schemas/RandomSample" + } + ], + "nullable": true + }, "start_date": { "format": "date", "type": "string" diff --git a/integration_tests/tests/flowapi_tests/test_api_spec.test_generated_openapi_yaml_spec.approved.txt b/integration_tests/tests/flowapi_tests/test_api_spec.test_generated_openapi_yaml_spec.approved.txt index ce134c2fd3..a70dffbd1d 100644 --- a/integration_tests/tests/flowapi_tests/test_api_spec.test_generated_openapi_yaml_spec.approved.txt +++ b/integration_tests/tests/flowapi_tests/test_api_spec.test_generated_openapi_yaml_spec.approved.txt @@ -44,6 +44,38 @@ ], "type": "object" }, + "BernoulliRandomSample": { + "properties": { + "estimate_count": { + "default": true, + "type": "boolean" + }, + "fraction": { + "format": "float", + "maximum": 1.0, + "minimum": 0.0, + "nullable": true, + "type": "number" + }, + "sampling_method": { + "enum": [ + "bernoulli" + ], + "type": "string" + }, + "seed": { + "format": "float", + "type": "number" + }, + "size": { + "format": "int32", + "minimum": 1, + "nullable": true, + "type": "integer" + } + }, + "type": "object" + }, "DFSTotalMetricAmount": { "properties": { "aggregation_unit": { @@ -118,6 +150,14 @@ ], "type": "string" }, + "sampling": { + "allOf": [ + { + "$ref": "#/components/schemas/RandomSample" + } + ], + "nullable": true + }, "subscriber_subset": { "enum": [ null @@ -145,6 +185,14 @@ "reference_location": { "$ref": 
"#/components/schemas/InputToDisplacement" }, + "sampling": { + "allOf": [ + { + "$ref": "#/components/schemas/RandomSample" + } + ], + "nullable": true + }, "start": { "format": "date", "type": "string" @@ -229,6 +277,14 @@ ], "type": "string" }, + "sampling": { + "allOf": [ + { + "$ref": "#/components/schemas/RandomSample" + } + ], + "nullable": true + }, "start": { "format": "date", "type": "string" @@ -404,6 +460,14 @@ ], "type": "string" }, + "sampling": { + "allOf": [ + { + "$ref": "#/components/schemas/RandomSample" + } + ], + "nullable": true + }, "start_date": { "format": "date", "type": "string" @@ -987,6 +1051,14 @@ ], "type": "string" }, + "sampling": { + "allOf": [ + { + "$ref": "#/components/schemas/RandomSample" + } + ], + "nullable": true + }, "subscriber_subset": { "enum": [ null @@ -1012,6 +1084,14 @@ ], "type": "string" }, + "sampling": { + "allOf": [ + { + "$ref": "#/components/schemas/RandomSample" + } + ], + "nullable": true + }, "start": { "format": "date", "type": "string" @@ -1049,6 +1129,14 @@ ], "type": "string" }, + "sampling": { + "allOf": [ + { + "$ref": "#/components/schemas/RandomSample" + } + ], + "nullable": true + }, "start": { "format": "date", "type": "string" @@ -1085,6 +1173,14 @@ ], "type": "string" }, + "sampling": { + "allOf": [ + { + "$ref": "#/components/schemas/RandomSample" + } + ], + "nullable": true + }, "start_date": { "format": "date", "type": "string" @@ -1104,6 +1200,65 @@ ], "type": "object" }, + "RandomIDsRandomSample": { + "properties": { + "estimate_count": { + "default": true, + "type": "boolean" + }, + "fraction": { + "format": "float", + "maximum": 1.0, + "minimum": 0.0, + "nullable": true, + "type": "number" + }, + "sampling_method": { + "enum": [ + "random_ids" + ], + "type": "string" + }, + "seed": { + "format": "float", + "maximum": 1.0, + "minimum": -1.0, + "type": "number" + }, + "size": { + "format": "int32", + "minimum": 1, + "nullable": true, + "type": "integer" + } + }, + "type": "object" + }, + 
"RandomSample": { + "discriminator": { + "mapping": { + "bernoulli": "#/components/schemas/BernoulliRandomSample", + "random_ids": "#/components/schemas/RandomIDsRandomSample", + "system": "#/components/schemas/SystemRandomSample", + "system_rows": "#/components/schemas/SystemRowsRandomSample" + }, + "propertyName": "sampling_method" + }, + "oneOf": [ + { + "$ref": "#/components/schemas/BernoulliRandomSample" + }, + { + "$ref": "#/components/schemas/RandomIDsRandomSample" + }, + { + "$ref": "#/components/schemas/SystemRandomSample" + }, + { + "$ref": "#/components/schemas/SystemRowsRandomSample" + } + ] + }, "SpatialAggregate": { "properties": { "locations": { @@ -1138,6 +1293,14 @@ ], "type": "string" }, + "sampling": { + "allOf": [ + { + "$ref": "#/components/schemas/RandomSample" + } + ], + "nullable": true + }, "start": { "format": "date", "type": "string" @@ -1161,6 +1324,66 @@ ], "type": "object" }, + "SystemRandomSample": { + "properties": { + "estimate_count": { + "default": true, + "type": "boolean" + }, + "fraction": { + "format": "float", + "maximum": 1.0, + "minimum": 0.0, + "nullable": true, + "type": "number" + }, + "sampling_method": { + "enum": [ + "system" + ], + "type": "string" + }, + "seed": { + "format": "float", + "type": "number" + }, + "size": { + "format": "int32", + "minimum": 1, + "nullable": true, + "type": "integer" + } + }, + "type": "object" + }, + "SystemRowsRandomSample": { + "properties": { + "estimate_count": { + "default": true, + "type": "boolean" + }, + "fraction": { + "format": "float", + "maximum": 1.0, + "minimum": 0.0, + "nullable": true, + "type": "number" + }, + "sampling_method": { + "enum": [ + "system_rows" + ], + "type": "string" + }, + "size": { + "format": "int32", + "minimum": 1, + "nullable": true, + "type": "integer" + } + }, + "type": "object" + }, "TopUpAmount": { "properties": { "query_kind": { @@ -1169,6 +1392,14 @@ ], "type": "string" }, + "sampling": { + "allOf": [ + { + "$ref": 
"#/components/schemas/RandomSample" + } + ], + "nullable": true + }, "start": { "format": "date", "type": "string" @@ -1217,6 +1448,14 @@ ], "type": "string" }, + "sampling": { + "allOf": [ + { + "$ref": "#/components/schemas/RandomSample" + } + ], + "nullable": true + }, "start_date": { "format": "date", "type": "string" @@ -1318,6 +1557,14 @@ ], "type": "string" }, + "sampling": { + "allOf": [ + { + "$ref": "#/components/schemas/RandomSample" + } + ], + "nullable": true + }, "start_date": { "format": "date", "type": "string" diff --git a/integration_tests/tests/flowmachine_server_tests/test_server.test_api_spec_of_flowmachine_query_schemas.approved.txt b/integration_tests/tests/flowmachine_server_tests/test_server.test_api_spec_of_flowmachine_query_schemas.approved.txt index b83b244173..550341169c 100644 --- a/integration_tests/tests/flowmachine_server_tests/test_server.test_api_spec_of_flowmachine_query_schemas.approved.txt +++ b/integration_tests/tests/flowmachine_server_tests/test_server.test_api_spec_of_flowmachine_query_schemas.approved.txt @@ -41,6 +41,38 @@ ], "type": "object" }, + "BernoulliRandomSample": { + "properties": { + "estimate_count": { + "default": true, + "type": "boolean" + }, + "fraction": { + "format": "float", + "maximum": 1.0, + "minimum": 0.0, + "nullable": true, + "type": "number" + }, + "sampling_method": { + "enum": [ + "bernoulli" + ], + "type": "string" + }, + "seed": { + "format": "float", + "type": "number" + }, + "size": { + "format": "int32", + "minimum": 1, + "nullable": true, + "type": "integer" + } + }, + "type": "object" + }, "DFSTotalMetricAmount": { "properties": { "aggregation_unit": { @@ -114,6 +146,14 @@ ], "type": "string" }, + "sampling": { + "allOf": [ + { + "$ref": "#/components/schemas/RandomSample" + } + ], + "nullable": true + }, "subscriber_subset": { "enum": [ null @@ -140,6 +180,14 @@ "reference_location": { "$ref": "#/components/schemas/InputToDisplacement" }, + "sampling": { + "allOf": [ + { + "$ref": 
"#/components/schemas/RandomSample" + } + ], + "nullable": true + }, "start": { "format": "date", "type": "string" @@ -222,6 +270,14 @@ ], "type": "string" }, + "sampling": { + "allOf": [ + { + "$ref": "#/components/schemas/RandomSample" + } + ], + "nullable": true + }, "start": { "format": "date", "type": "string" @@ -394,6 +450,14 @@ ], "type": "string" }, + "sampling": { + "allOf": [ + { + "$ref": "#/components/schemas/RandomSample" + } + ], + "nullable": true + }, "start_date": { "format": "date", "type": "string" @@ -970,6 +1034,14 @@ ], "type": "string" }, + "sampling": { + "allOf": [ + { + "$ref": "#/components/schemas/RandomSample" + } + ], + "nullable": true + }, "subscriber_subset": { "enum": [ null @@ -994,6 +1066,14 @@ ], "type": "string" }, + "sampling": { + "allOf": [ + { + "$ref": "#/components/schemas/RandomSample" + } + ], + "nullable": true + }, "start": { "format": "date", "type": "string" @@ -1030,6 +1110,14 @@ ], "type": "string" }, + "sampling": { + "allOf": [ + { + "$ref": "#/components/schemas/RandomSample" + } + ], + "nullable": true + }, "start": { "format": "date", "type": "string" @@ -1065,6 +1153,14 @@ ], "type": "string" }, + "sampling": { + "allOf": [ + { + "$ref": "#/components/schemas/RandomSample" + } + ], + "nullable": true + }, "start_date": { "format": "date", "type": "string" @@ -1083,6 +1179,65 @@ ], "type": "object" }, + "RandomIDsRandomSample": { + "properties": { + "estimate_count": { + "default": true, + "type": "boolean" + }, + "fraction": { + "format": "float", + "maximum": 1.0, + "minimum": 0.0, + "nullable": true, + "type": "number" + }, + "sampling_method": { + "enum": [ + "random_ids" + ], + "type": "string" + }, + "seed": { + "format": "float", + "maximum": 1.0, + "minimum": -1.0, + "type": "number" + }, + "size": { + "format": "int32", + "minimum": 1, + "nullable": true, + "type": "integer" + } + }, + "type": "object" + }, + "RandomSample": { + "discriminator": { + "mapping": { + "bernoulli": 
"#/components/schemas/BernoulliRandomSample", + "random_ids": "#/components/schemas/RandomIDsRandomSample", + "system": "#/components/schemas/SystemRandomSample", + "system_rows": "#/components/schemas/SystemRowsRandomSample" + }, + "propertyName": "sampling_method" + }, + "oneOf": [ + { + "$ref": "#/components/schemas/BernoulliRandomSample" + }, + { + "$ref": "#/components/schemas/RandomIDsRandomSample" + }, + { + "$ref": "#/components/schemas/SystemRandomSample" + }, + { + "$ref": "#/components/schemas/SystemRowsRandomSample" + } + ] + }, "SpatialAggregate": { "properties": { "locations": { @@ -1116,6 +1271,14 @@ ], "type": "string" }, + "sampling": { + "allOf": [ + { + "$ref": "#/components/schemas/RandomSample" + } + ], + "nullable": true + }, "start": { "format": "date", "type": "string" @@ -1138,6 +1301,66 @@ ], "type": "object" }, + "SystemRandomSample": { + "properties": { + "estimate_count": { + "default": true, + "type": "boolean" + }, + "fraction": { + "format": "float", + "maximum": 1.0, + "minimum": 0.0, + "nullable": true, + "type": "number" + }, + "sampling_method": { + "enum": [ + "system" + ], + "type": "string" + }, + "seed": { + "format": "float", + "type": "number" + }, + "size": { + "format": "int32", + "minimum": 1, + "nullable": true, + "type": "integer" + } + }, + "type": "object" + }, + "SystemRowsRandomSample": { + "properties": { + "estimate_count": { + "default": true, + "type": "boolean" + }, + "fraction": { + "format": "float", + "maximum": 1.0, + "minimum": 0.0, + "nullable": true, + "type": "number" + }, + "sampling_method": { + "enum": [ + "system_rows" + ], + "type": "string" + }, + "size": { + "format": "int32", + "minimum": 1, + "nullable": true, + "type": "integer" + } + }, + "type": "object" + }, "TopUpAmount": { "properties": { "query_kind": { @@ -1146,6 +1369,14 @@ ], "type": "string" }, + "sampling": { + "allOf": [ + { + "$ref": "#/components/schemas/RandomSample" + } + ], + "nullable": true + }, "start": { "format": "date", 
"type": "string" @@ -1193,6 +1424,14 @@ ], "type": "string" }, + "sampling": { + "allOf": [ + { + "$ref": "#/components/schemas/RandomSample" + } + ], + "nullable": true + }, "start_date": { "format": "date", "type": "string" @@ -1292,6 +1531,14 @@ ], "type": "string" }, + "sampling": { + "allOf": [ + { + "$ref": "#/components/schemas/RandomSample" + } + ], + "nullable": true + }, "start_date": { "format": "date", "type": "string" diff --git a/integration_tests/tests/query_tests/test_queries.py b/integration_tests/tests/query_tests/test_queries.py index 4dad9dbaf6..cfa17a8333 100644 --- a/integration_tests/tests/query_tests/test_queries.py +++ b/integration_tests/tests/query_tests/test_queries.py @@ -551,6 +551,35 @@ "method": "distr", }, ), + ( + "spatial_aggregate", + { + "locations": flowclient.random_sample( + query=flowclient.daily_location( + date="2016-01-01", + aggregation_unit="admin3", + method="most-common", + ), + size=10, + ) + }, + ), + ( + "spatial_aggregate", + { + "locations": flowclient.random_sample( + query=flowclient.daily_location( + date="2016-01-01", + aggregation_unit="admin3", + method="most-common", + ), + sampling_method="bernoulli", + fraction=0.5, + estimate_count=False, + seed=0.2, + ) + }, + ), ], ) def test_run_query(query_kind, params, universal_access_token, flowapi_url):