From b49b512291a2b8ea3a070568792599e4822ce275 Mon Sep 17 00:00:00 2001 From: Piotr Date: Mon, 24 Apr 2023 20:08:08 +0200 Subject: [PATCH] fix: use common interface in joiners and loaders --- srai/joiners/_base.py | 4 +++- srai/joiners/intersection_joiner.py | 3 ++- srai/loaders/_base.py | 9 ++++----- srai/loaders/geoparquet_loader.py | 3 ++- srai/loaders/gtfs_loader.py | 3 ++- srai/loaders/osm_loaders/osm_online_loader.py | 3 ++- srai/loaders/osm_loaders/osm_pbf_loader.py | 3 ++- srai/loaders/osm_way_loader/osm_way_loader.py | 3 ++- 8 files changed, 19 insertions(+), 12 deletions(-) diff --git a/srai/joiners/_base.py b/srai/joiners/_base.py index 00e35f32b..098abac03 100644 --- a/srai/joiners/_base.py +++ b/srai/joiners/_base.py @@ -10,7 +10,9 @@ class Joiner(abc.ABC): @abc.abstractmethod def transform( - self, regions: gpd.GeoDataFrame, features: gpd.GeoDataFrame + self, + regions: gpd.GeoDataFrame, + features: gpd.GeoDataFrame, ) -> gpd.GeoDataFrame: # pragma: no cover """ Join features to regions. diff --git a/srai/joiners/intersection_joiner.py b/srai/joiners/intersection_joiner.py index 930fdf1c2..ebea9e2b3 100644 --- a/srai/joiners/intersection_joiner.py +++ b/srai/joiners/intersection_joiner.py @@ -8,9 +8,10 @@ import pandas as pd from srai.constants import FEATURES_INDEX, GEOMETRY_COLUMN, REGIONS_INDEX +from srai.joiners import Joiner -class IntersectionJoiner: +class IntersectionJoiner(Joiner): """ Intersection Joiner. diff --git a/srai/loaders/_base.py b/srai/loaders/_base.py index c38abb7d7..b708da892 100644 --- a/srai/loaders/_base.py +++ b/srai/loaders/_base.py @@ -1,8 +1,7 @@ """Base class for loaders.""" import abc -from pathlib import Path -from typing import Union +from typing import Any import geopandas as gpd @@ -11,13 +10,13 @@ class Loader(abc.ABC): """Abstract class for loaders.""" @abc.abstractmethod - def load(self, area: Union[gpd.GeoDataFrame, Path]) -> gpd.GeoDataFrame: # pragma: no cover + def load(self, *args: Any, **kwargs: Any) -> gpd.GeoDataFrame: # pragma: no cover """ Load data for a given area. Args: - area (gdf.GeoDataFrame | Path): GeoDataFrame with the area of interest or a path - to a file with a geometry. + *args: Positional arguments dependating on a specific loader. + **kwargs: Keyword arguments dependating on a specific loader. Returns: GeoDataFrame with the downloaded data. diff --git a/srai/loaders/geoparquet_loader.py b/srai/loaders/geoparquet_loader.py index 0661a6965..a7a6e343b 100644 --- a/srai/loaders/geoparquet_loader.py +++ b/srai/loaders/geoparquet_loader.py @@ -10,9 +10,10 @@ import geopandas as gpd from srai.constants import GEOMETRY_COLUMN, WGS84_CRS +from srai.loaders import Loader -class GeoparquetLoader: +class GeoparquetLoader(Loader): """ GeoparquetLoader. diff --git a/srai/loaders/gtfs_loader.py b/srai/loaders/gtfs_loader.py index 1b669def1..106009abf 100644 --- a/srai/loaders/gtfs_loader.py +++ b/srai/loaders/gtfs_loader.py @@ -18,6 +18,7 @@ from shapely.geometry import Point from srai.constants import GEOMETRY_COLUMN, WGS84_CRS +from srai.loaders import Loader from srai.utils._optional import import_optional_dependencies if TYPE_CHECKING: # pragma: no cover @@ -27,7 +28,7 @@ GTFS2VEC_TRIPS_PREFIX = "trips_at_" -class GTFSLoader: +class GTFSLoader(Loader): """ GTFSLoader. diff --git a/srai/loaders/osm_loaders/osm_online_loader.py b/srai/loaders/osm_loaders/osm_online_loader.py index 24460fded..b06ea2995 100644 --- a/srai/loaders/osm_loaders/osm_online_loader.py +++ b/srai/loaders/osm_loaders/osm_online_loader.py @@ -12,11 +12,12 @@ from tqdm import tqdm from srai.constants import FEATURES_INDEX, GEOMETRY_COLUMN, WGS84_CRS +from srai.loaders import Loader from srai.loaders.osm_loaders.filters.osm_tags_type import osm_tags_type from srai.utils._optional import import_optional_dependencies -class OSMOnlineLoader: +class OSMOnlineLoader(Loader): """ OSMOnlineLoader. diff --git a/srai/loaders/osm_loaders/osm_pbf_loader.py b/srai/loaders/osm_loaders/osm_pbf_loader.py index f81717bc5..91cfd799f 100644 --- a/srai/loaders/osm_loaders/osm_pbf_loader.py +++ b/srai/loaders/osm_loaders/osm_pbf_loader.py @@ -10,11 +10,12 @@ import pandas as pd from srai.constants import FEATURES_INDEX, WGS84_CRS +from srai.loaders import Loader from srai.loaders.osm_loaders.filters.osm_tags_type import osm_tags_type from srai.utils._optional import import_optional_dependencies -class OSMPbfLoader: +class OSMPbfLoader(Loader): """ OSMPbfLoader. diff --git a/srai/loaders/osm_way_loader/osm_way_loader.py b/srai/loaders/osm_way_loader/osm_way_loader.py index d65b24c5c..911bee299 100644 --- a/srai/loaders/osm_way_loader/osm_way_loader.py +++ b/srai/loaders/osm_way_loader/osm_way_loader.py @@ -16,6 +16,7 @@ from srai.constants import FEATURES_INDEX, GEOMETRY_COLUMN, WGS84_CRS from srai.exceptions import LoadedDataIsEmptyException +from srai.loaders import Loader from srai.utils._optional import import_optional_dependencies from . import constants @@ -41,7 +42,7 @@ class NetworkType(str, Enum): WALK = "walk" -class OSMWayLoader: +class OSMWayLoader(Loader): """ OSMWayLoader downloads road infrastructure from OSM.