diff --git a/tests/analysis/test_label.py b/tests/analysis/test_label.py index 50912064..fce26156 100644 --- a/tests/analysis/test_label.py +++ b/tests/analysis/test_label.py @@ -1,11 +1,9 @@ import os -import numpy as np import pandas as pd import pytest import trackintel as ti -from trackintel.analysis.labelling import _check_categories class TestCreate_activity_flag: @@ -100,14 +98,3 @@ def test_simple_coarse_identification_projected(self): assert tpls_transport_mode_3.iloc[0]["mode"] == "slow_mobility" assert tpls_transport_mode_3.iloc[1]["mode"] == "motorized_mobility" assert tpls_transport_mode_3.iloc[2]["mode"] == "fast_mobility" - - def test_check_categories(self): - """Asserts the correct identification of valid category dictionaries.""" - tpls_file = os.path.join("tests", "data", "triplegs_transport_mode_identification.csv") - tpls = ti.read_triplegs_csv(tpls_file, sep=";", index_col="id") - correct_dict = {2: "cat1", 7: "cat2", np.inf: "cat3"} - - assert _check_categories(correct_dict) - with pytest.raises(ValueError): - incorrect_dict = {10: "cat1", 5: "cat2", np.inf: "cat3"} - tpls.as_triplegs.predict_transport_mode(method="simple-coarse", categories=incorrect_dict) diff --git a/trackintel/analysis/labelling.py b/trackintel/analysis/labelling.py index 031efdb5..cd330603 100644 --- a/trackintel/analysis/labelling.py +++ b/trackintel/analysis/labelling.py @@ -1,6 +1,7 @@ import datetime import numpy as np +import pandas as pd from trackintel.geogr import get_speed_triplegs @@ -81,90 +82,39 @@ def predict_transport_mode(triplegs, method="simple-coarse", **kwargs): categories = kwargs.pop( "categories", {15 / 3.6: "slow_mobility", 100 / 3.6: "motorized_mobility", np.inf: "fast_mobility"} ) - - return _predict_transport_mode_simple_coarse(triplegs, categories) + triplegs = triplegs.copy() + triplegs["mode"] = _predict_transport_mode_simple_coarse(triplegs, categories) + return triplegs else: raise AttributeError(f"Method {method} not known for predicting tripleg transport modes.") -def _predict_transport_mode_simple_coarse(triplegs_in, categories): +def _predict_transport_mode_simple_coarse(triplegs, categories): """ - Predict a transport mode out of three coarse classes. + Predict a transport mode based on provided categories. Implements a simple speed based heuristic (over the whole tripleg). As such, it is very fast, but also very simple and coarse. Parameters ---------- - triplegs_in : Triplegs + triplegs : Triplegs The triplegs for the transport mode prediction. categories : dict, optional - The categories for the speed classification {upper_boundary:'category_name'}. + The categories for the speed classification {upper_boundary: 'category_name'}. The unit for the upper boundary is m/s. - The default is {15/3.6: 'slow_mobility', 100/3.6: 'motorized_mobility', np.inf: 'fast_mobility'}. - - Raises - ------ - ValueError - In case the boundaries of the categories are not in ascending order. Returns ------- - triplegs : trackintel triplegs GeoDataFrame - the triplegs with added column mode, containing the predicted transport modes. + cuts : pd.Series + Column containing the predicted transport modes. For additional documentation, see :func:`trackintel.analysis.transport_mode_identification.predict_transport_mode`. - - """ - if not (_check_categories(categories)): - raise ValueError("the categories must be in increasing order") - - triplegs = triplegs_in.copy() - - def category_by_speed(speed): - """ - Identify the mode based on the (overall) tripleg speed. - - Parameters - ---------- - speed : float - the speed of one tripleg - - Returns - ------- - str - the identified mode. - """ - for bound in categories: - if speed < bound: - return categories[bound] - - triplegs_speed = get_speed_triplegs(triplegs) - - triplegs["mode"] = triplegs_speed["speed"].apply(category_by_speed) - return triplegs - - -def _check_categories(cat): - """ - Check if the keys of a dictionary are in ascending order. - - Parameters - ---------- - cat : disct - the dictionary to be checked. - - Returns - ------- - correct : bool - True if dict keys are in ascending order False otherwise. - """ - correct = True - bounds = list(cat.keys()) - for i in range(len(bounds) - 1): - if bounds[i] >= bounds[i + 1]: - correct = False - return correct + categories = dict(sorted(categories.items(), key=lambda item: item[0])) + intervals = pd.IntervalIndex.from_breaks([-np.inf] + list(categories.keys()), closed="left") + speed = get_speed_triplegs(triplegs)["speed"] + cuts = pd.cut(speed, intervals) + return cuts.cat.rename_categories(categories.values())