diff --git a/flowclient/flowclient/__init__.py b/flowclient/flowclient/__init__.py
index 55b61c57bf..369cbfb120 100644
--- a/flowclient/flowclient/__init__.py
+++ b/flowclient/flowclient/__init__.py
@@ -47,6 +47,7 @@
     unique_locations_spec,
     most_frequent_location_spec,
     total_active_periods_spec,
+    per_subscriber_aggregate_spec,
 )
 from . import aggregates
 from .aggregates import (
diff --git a/flowclient/flowclient/query_specs.py b/flowclient/flowclient/query_specs.py
index bfcb47ba95..c825afbf9a 100644
--- a/flowclient/flowclient/query_specs.py
+++ b/flowclient/flowclient/query_specs.py
@@ -840,3 +840,52 @@ def random_sample_spec(
     )
     sampled_query["sampling"] = sampling
     return sampled_query
+
+
+def per_subscriber_aggregate_spec(
+    *,
+    subscriber_queries: list = None,
+    agg_method: str = "avg",
+    subscriber_query: Dict = None,
+) -> dict:
+    """
+    Query that performs per-subscriber aggregation over one or more subscriber
+    queries. Returns a column 'subscriber' containing unique subscribers and a
+    column 'value' containing the aggregation.
+
+    Parameters
+    ----------
+    subscriber_queries : list of dict
+        Specs of queries with a `subscriber` column; at least one is needed
+    agg_method : {"count", "sum", "avg", "max", "min", "median", "stddev", "variance"}, default "avg"
+        The method of aggregation to perform
+    subscriber_query : dict, optional
+        Single-query alias kept for backwards compatibility; sent as a
+        one-element `subscriber_queries` list. Mutually exclusive with
+        `subscriber_queries`.
+
+    Returns
+    -------
+    dict
+        Query specification with query_kind "per_subscriber_aggregate"
+
+    Raises
+    ------
+    ValueError
+        If both `subscriber_queries` and `subscriber_query` are given, or if
+        no subscriber query is given at all
+    """
+    if subscriber_query is not None:
+        if subscriber_queries is not None:
+            raise ValueError(
+                "Pass either subscriber_queries or subscriber_query, not both."
+            )
+        subscriber_queries = [subscriber_query]
+    if not subscriber_queries:
+        raise ValueError("At least one subscriber query is required.")
+    # The server-side schema reads the list under the key "subscriber_queries"
+    return {
+        "query_kind": "per_subscriber_aggregate",
+        "subscriber_queries": list(subscriber_queries),
+        "agg_method": agg_method,
+    }
diff --git a/flowmachine/flowmachine/core/server/query_schemas/flowmachine_query.py b/flowmachine/flowmachine/core/server/query_schemas/flowmachine_query.py
index 57bd9a1b01..a101ddd3c9 100644
--- a/flowmachine/flowmachine/core/server/query_schemas/flowmachine_query.py
+++ b/flowmachine/flowmachine/core/server/query_schemas/flowmachine_query.py
@@ -30,6 +30,7 @@
 from .geography import GeographySchema
 from .location_event_counts import LocationEventCountsSchema
 from .most_frequent_location import MostFrequentLocationSchema
+from .per_subscriber_aggregate import PerSubscriberAggregateSchema
 from .trips_od_matrix import TripsODMatrixSchema
 from
.unique_subscriber_counts import UniqueSubscriberCountsSchema from .location_introversion import LocationIntroversionSchema @@ -66,6 +67,7 @@ class FlowmachineQuerySchema(OneOfSchema): "unmoving_counts": UnmovingCountsSchema, "unmoving_at_reference_location_counts": UnmovingAtReferenceLocationCountsSchema, "trips_od_matrix": TripsODMatrixSchema, + "per_subscriber_aggregate": PerSubscriberAggregateSchema, } diff --git a/flowmachine/flowmachine/core/server/query_schemas/histogram_aggregate.py b/flowmachine/flowmachine/core/server/query_schemas/histogram_aggregate.py index 29383781a3..e80df522fc 100644 --- a/flowmachine/flowmachine/core/server/query_schemas/histogram_aggregate.py +++ b/flowmachine/flowmachine/core/server/query_schemas/histogram_aggregate.py @@ -4,26 +4,8 @@ from marshmallow import Schema, fields, post_load, validates_schema, ValidationError from marshmallow.validate import OneOf -from marshmallow_oneofschema import OneOfSchema from flowmachine.core.server.query_schemas.custom_fields import Bounds -from flowmachine.core.server.query_schemas.radius_of_gyration import ( - RadiusOfGyrationSchema, -) -from flowmachine.core.server.query_schemas.subscriber_degree import ( - SubscriberDegreeSchema, -) -from flowmachine.core.server.query_schemas.topup_amount import TopUpAmountSchema -from flowmachine.core.server.query_schemas.event_count import EventCountSchema -from flowmachine.core.server.query_schemas.nocturnal_events import NocturnalEventsSchema -from flowmachine.core.server.query_schemas.unique_location_counts import ( - UniqueLocationCountsSchema, -) -from flowmachine.core.server.query_schemas.displacement import DisplacementSchema -from flowmachine.core.server.query_schemas.pareto_interactions import ( - ParetoInteractionsSchema, -) -from flowmachine.core.server.query_schemas.topup_balance import TopUpBalanceSchema from flowmachine.features import HistogramAggregation from .base_exposed_query import BaseExposedQuery @@ -32,23 +14,7 @@ __all__ = 
["HistogramAggregateSchema", "HistogramAggregateExposed"] from .base_schema import BaseSchema -from .total_active_periods import TotalActivePeriodsSchema - - -class HistogrammableMetrics(OneOfSchema): - type_field = "query_kind" - type_schemas = { - "radius_of_gyration": RadiusOfGyrationSchema, - "unique_location_counts": UniqueLocationCountsSchema, - "topup_balance": TopUpBalanceSchema, - "subscriber_degree": SubscriberDegreeSchema, - "topup_amount": TopUpAmountSchema, - "event_count": EventCountSchema, - "pareto_interactions": ParetoInteractionsSchema, - "nocturnal_events": NocturnalEventsSchema, - "displacement": DisplacementSchema, - "total_active_periods": TotalActivePeriodsSchema, - } +from .numeric_subscriber_metrics import NumericSubscriberMetricsSchema class HistogramBins(Schema): @@ -99,7 +65,7 @@ def _flowmachine_query_obj(self): class HistogramAggregateSchema(BaseSchema): # query_kind parameter is required here for claims validation query_kind = fields.String(validate=OneOf(["histogram_aggregate"])) - metric = fields.Nested(HistogrammableMetrics, required=True) + metric = fields.Nested(NumericSubscriberMetricsSchema, required=True) range = fields.Nested(Bounds) bins = fields.Nested(HistogramBins) diff --git a/flowmachine/flowmachine/core/server/query_schemas/numeric_subscriber_metrics.py b/flowmachine/flowmachine/core/server/query_schemas/numeric_subscriber_metrics.py new file mode 100644 index 0000000000..7d09f2ae44 --- /dev/null +++ b/flowmachine/flowmachine/core/server/query_schemas/numeric_subscriber_metrics.py @@ -0,0 +1,38 @@ +from marshmallow_oneofschema import OneOfSchema + +from flowmachine.core.server.query_schemas.displacement import DisplacementSchema +from flowmachine.core.server.query_schemas.event_count import EventCountSchema +from flowmachine.core.server.query_schemas.nocturnal_events import NocturnalEventsSchema +from flowmachine.core.server.query_schemas.pareto_interactions import ( + ParetoInteractionsSchema, +) +from 
flowmachine.core.server.query_schemas.radius_of_gyration import (
+    RadiusOfGyrationSchema,
+)
+from flowmachine.core.server.query_schemas.subscriber_degree import (
+    SubscriberDegreeSchema,
+)
+from flowmachine.core.server.query_schemas.topup_amount import TopUpAmountSchema
+from flowmachine.core.server.query_schemas.topup_balance import TopUpBalanceSchema
+from flowmachine.core.server.query_schemas.total_active_periods import (
+    TotalActivePeriodsSchema,
+)
+from flowmachine.core.server.query_schemas.unique_location_counts import (
+    UniqueLocationCountsSchema,
+)
+
+
+class NumericSubscriberMetricsSchema(OneOfSchema):
+    type_field = "query_kind"
+    type_schemas = {
+        "radius_of_gyration": RadiusOfGyrationSchema,
+        "unique_location_counts": UniqueLocationCountsSchema,
+        "topup_balance": TopUpBalanceSchema,
+        "subscriber_degree": SubscriberDegreeSchema,
+        "topup_amount": TopUpAmountSchema,
+        "event_count": EventCountSchema,
+        "pareto_interactions": ParetoInteractionsSchema,
+        "nocturnal_events": NocturnalEventsSchema,
+        "displacement": DisplacementSchema,
+        "total_active_periods": TotalActivePeriodsSchema,
+    }
diff --git a/flowmachine/flowmachine/core/server/query_schemas/per_subscriber_aggregate.py b/flowmachine/flowmachine/core/server/query_schemas/per_subscriber_aggregate.py
new file mode 100644
index 0000000000..00ed67e732
--- /dev/null
+++ b/flowmachine/flowmachine/core/server/query_schemas/per_subscriber_aggregate.py
@@ -0,0 +1,72 @@
+# This Source Code Form is subject to the terms of the Mozilla Public
+# License, v. 2.0. If a copy of the MPL was not distributed with this
+# file, You can obtain one at http://mozilla.org/MPL/2.0/.
+from functools import reduce
+
+from marshmallow import fields
+from marshmallow.validate import OneOf, Length
+
+from flowmachine.core.server.query_schemas import BaseExposedQuery
+from flowmachine.core.server.query_schemas.numeric_subscriber_metrics import (
+    NumericSubscriberMetricsSchema,
+)
+from flowmachine.core.server.query_schemas.base_schema import BaseSchema
+from flowmachine.features.subscriber.per_subscriber_aggregate import (
+    PerSubscriberAggregate,
+    agg_methods,
+)
+
+
+class PerSubscriberAggregateExposed(BaseExposedQuery):
+    """
+    Exposed query that unions one or more subscriber metric queries and
+    aggregates the 'value' column of the result per subscriber.
+
+    Parameters
+    ----------
+    subscriber_queries : list
+        Exposed query objects, each providing a `_flowmachine_query_obj`
+    agg_method : str
+        Aggregation method handed on to PerSubscriberAggregate
+    """
+
+    def __init__(self, subscriber_queries, agg_method):
+        self.subscriber_queries = subscriber_queries
+        self.agg_method = agg_method
+
+    @property
+    def _flowmachine_query_obj(self):
+        """
+        Build the flowmachine query for this spec.
+
+        Returns
+        -------
+        PerSubscriberAggregate
+        """
+        # Unwrap each exposed query before folding, so that `union` is only
+        # ever called on flowmachine query objects - whether there is one
+        # query (no union happens) or more than two.
+        # TODO: Replace with Jono's new list input to union
+        subscriber_query = reduce(
+            lambda x, y: x.union(y),
+            (query._flowmachine_query_obj for query in self.subscriber_queries),
+        )
+        return PerSubscriberAggregate(
+            subscriber_query=subscriber_query,
+            agg_column="value",
+            agg_method=self.agg_method,
+        )
+
+
+class PerSubscriberAggregateSchema(BaseSchema):
+    # query_kind parameter is required here for claims validation
+    query_kind = fields.String(validate=OneOf(["per_subscriber_aggregate"]))
+    # Both fields are required: the exposed query's constructor has no defaults
+    subscriber_queries = fields.List(
+        fields.Nested(NumericSubscriberMetricsSchema),
+        validate=Length(min=1),
+        required=True,
+    )
+    agg_method = fields.String(validate=OneOf(agg_methods), required=True)
+
+    __model__ = PerSubscriberAggregateExposed
diff --git a/flowmachine/tests/test_query_object_construction.py b/flowmachine/tests/test_query_object_construction.py
index 73e9db9e51..8818d82f94 100644
--- a/flowmachine/tests/test_query_object_construction.py
+++ b/flowmachine/tests/test_query_object_construction.py
@@ -269,6 +269,24 @@ def test_construct_query(diff_reporter):
             "event_types": ["calls", "sms"],
             "subscriber_subset": None,
         },
+        {
+            "query_kind": "per_subscriber_aggregate",
+            "subscriber_queries": [
+                {
+                    "query_kind": "total_active_periods",
+                    "start_date": "2016-01-01",
+                    "total_periods": 1,
+                    "event_types": ["calls", "sms"],
+                },
+                {
+                    
"query_kind": "total_active_periods", + "start_date": "2016-01-02", + "total_periods": 1, + "event_types": ["calls", "sms"], + }, + ], + "agg_method": "min", + }, ] def get_query_id_for_query_spec(query_spec): diff --git a/flowmachine/tests/test_query_object_construction.test_construct_query.approved.txt b/flowmachine/tests/test_query_object_construction.test_construct_query.approved.txt index ce33d069b2..e32a23daec 100644 --- a/flowmachine/tests/test_query_object_construction.test_construct_query.approved.txt +++ b/flowmachine/tests/test_query_object_construction.test_construct_query.approved.txt @@ -354,5 +354,29 @@ "sms" ], "subscriber_subset": null + }, + "ce4e263fd59363ef7c4632f87b924d02": { + "query_kind": "per_subscriber_aggregate", + "subscriber_queries": [ + { + "query_kind": "total_active_periods", + "start_date": "2016-01-01", + "total_periods": 1, + "event_types": [ + "calls", + "sms" + ] + }, + { + "query_kind": "total_active_periods", + "start_date": "2016-01-02", + "total_periods": 1, + "event_types": [ + "calls", + "sms" + ] + } + ], + "agg_method": "min" } } diff --git a/integration_tests/tests/query_tests/test_queries.py b/integration_tests/tests/query_tests/test_queries.py index c35b1b5fcb..f0ab81d612 100644 --- a/integration_tests/tests/query_tests/test_queries.py +++ b/integration_tests/tests/query_tests/test_queries.py @@ -706,6 +706,22 @@ aggregation_unit="admin3", event_types=["calls", "sms"], ), + partial( + flowclient.per_subscriber_aggregate_spec, + subscriber_queries=[ + flowclient.total_active_periods_spec( + start_date="2016-01-01", + total_periods=1, + event_types=["calls", "sms"], + ), + flowclient.total_active_periods_spec( + start_date="2016-01-02", + total_periods=1, + event_types=["calls", "sms"], + ), + ], + agg_method="min", + ), ]