From b49b512291a2b8ea3a070568792599e4822ce275 Mon Sep 17 00:00:00 2001
From: Piotr
Date: Mon, 24 Apr 2023 20:08:08 +0200
Subject: [PATCH] fix: use common interface in joiners and loaders
---
srai/joiners/_base.py | 4 +++-
srai/joiners/intersection_joiner.py | 3 ++-
srai/loaders/_base.py | 9 ++++-----
srai/loaders/geoparquet_loader.py | 3 ++-
srai/loaders/gtfs_loader.py | 3 ++-
srai/loaders/osm_loaders/osm_online_loader.py | 3 ++-
srai/loaders/osm_loaders/osm_pbf_loader.py | 3 ++-
srai/loaders/osm_way_loader/osm_way_loader.py | 3 ++-
8 files changed, 19 insertions(+), 12 deletions(-)
diff --git a/srai/joiners/_base.py b/srai/joiners/_base.py
index 00e35f32b..098abac03 100644
--- a/srai/joiners/_base.py
+++ b/srai/joiners/_base.py
@@ -10,7 +10,9 @@ class Joiner(abc.ABC):
@abc.abstractmethod
def transform(
- self, regions: gpd.GeoDataFrame, features: gpd.GeoDataFrame
+ self,
+ regions: gpd.GeoDataFrame,
+ features: gpd.GeoDataFrame,
) -> gpd.GeoDataFrame: # pragma: no cover
"""
Join features to regions.
diff --git a/srai/joiners/intersection_joiner.py b/srai/joiners/intersection_joiner.py
index 930fdf1c2..ebea9e2b3 100644
--- a/srai/joiners/intersection_joiner.py
+++ b/srai/joiners/intersection_joiner.py
@@ -8,9 +8,10 @@
import pandas as pd
from srai.constants import FEATURES_INDEX, GEOMETRY_COLUMN, REGIONS_INDEX
+from srai.joiners import Joiner
-class IntersectionJoiner:
+class IntersectionJoiner(Joiner):
"""
Intersection Joiner.
diff --git a/srai/loaders/_base.py b/srai/loaders/_base.py
index c38abb7d7..b708da892 100644
--- a/srai/loaders/_base.py
+++ b/srai/loaders/_base.py
@@ -1,8 +1,7 @@
"""Base class for loaders."""
import abc
-from pathlib import Path
-from typing import Union
+from typing import Any
import geopandas as gpd
@@ -11,13 +10,13 @@ class Loader(abc.ABC):
"""Abstract class for loaders."""
@abc.abstractmethod
- def load(self, area: Union[gpd.GeoDataFrame, Path]) -> gpd.GeoDataFrame: # pragma: no cover
+ def load(self, *args: Any, **kwargs: Any) -> gpd.GeoDataFrame: # pragma: no cover
"""
Load data for a given area.
Args:
- area (gdf.GeoDataFrame | Path): GeoDataFrame with the area of interest or a path
- to a file with a geometry.
+ *args: Positional arguments dependating on a specific loader.
+ **kwargs: Keyword arguments dependating on a specific loader.
Returns:
GeoDataFrame with the downloaded data.
diff --git a/srai/loaders/geoparquet_loader.py b/srai/loaders/geoparquet_loader.py
index 0661a6965..a7a6e343b 100644
--- a/srai/loaders/geoparquet_loader.py
+++ b/srai/loaders/geoparquet_loader.py
@@ -10,9 +10,10 @@
import geopandas as gpd
from srai.constants import GEOMETRY_COLUMN, WGS84_CRS
+from srai.loaders import Loader
-class GeoparquetLoader:
+class GeoparquetLoader(Loader):
"""
GeoparquetLoader.
diff --git a/srai/loaders/gtfs_loader.py b/srai/loaders/gtfs_loader.py
index 1b669def1..106009abf 100644
--- a/srai/loaders/gtfs_loader.py
+++ b/srai/loaders/gtfs_loader.py
@@ -18,6 +18,7 @@
from shapely.geometry import Point
from srai.constants import GEOMETRY_COLUMN, WGS84_CRS
+from srai.loaders import Loader
from srai.utils._optional import import_optional_dependencies
if TYPE_CHECKING: # pragma: no cover
@@ -27,7 +28,7 @@
GTFS2VEC_TRIPS_PREFIX = "trips_at_"
-class GTFSLoader:
+class GTFSLoader(Loader):
"""
GTFSLoader.
diff --git a/srai/loaders/osm_loaders/osm_online_loader.py b/srai/loaders/osm_loaders/osm_online_loader.py
index 24460fded..b06ea2995 100644
--- a/srai/loaders/osm_loaders/osm_online_loader.py
+++ b/srai/loaders/osm_loaders/osm_online_loader.py
@@ -12,11 +12,12 @@
from tqdm import tqdm
from srai.constants import FEATURES_INDEX, GEOMETRY_COLUMN, WGS84_CRS
+from srai.loaders import Loader
from srai.loaders.osm_loaders.filters.osm_tags_type import osm_tags_type
from srai.utils._optional import import_optional_dependencies
-class OSMOnlineLoader:
+class OSMOnlineLoader(Loader):
"""
OSMOnlineLoader.
diff --git a/srai/loaders/osm_loaders/osm_pbf_loader.py b/srai/loaders/osm_loaders/osm_pbf_loader.py
index f81717bc5..91cfd799f 100644
--- a/srai/loaders/osm_loaders/osm_pbf_loader.py
+++ b/srai/loaders/osm_loaders/osm_pbf_loader.py
@@ -10,11 +10,12 @@
import pandas as pd
from srai.constants import FEATURES_INDEX, WGS84_CRS
+from srai.loaders import Loader
from srai.loaders.osm_loaders.filters.osm_tags_type import osm_tags_type
from srai.utils._optional import import_optional_dependencies
-class OSMPbfLoader:
+class OSMPbfLoader(Loader):
"""
OSMPbfLoader.
diff --git a/srai/loaders/osm_way_loader/osm_way_loader.py b/srai/loaders/osm_way_loader/osm_way_loader.py
index d65b24c5c..911bee299 100644
--- a/srai/loaders/osm_way_loader/osm_way_loader.py
+++ b/srai/loaders/osm_way_loader/osm_way_loader.py
@@ -16,6 +16,7 @@
from srai.constants import FEATURES_INDEX, GEOMETRY_COLUMN, WGS84_CRS
from srai.exceptions import LoadedDataIsEmptyException
+from srai.loaders import Loader
from srai.utils._optional import import_optional_dependencies
from . import constants
@@ -41,7 +42,7 @@ class NetworkType(str, Enum):
WALK = "walk"
-class OSMWayLoader:
+class OSMWayLoader(Loader):
"""
OSMWayLoader downloads road infrastructure from OSM.