Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore: add index name to regionizers #182

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 8 additions & 6 deletions srai/joiners/intersection_joiner.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
import geopandas as gpd
import pandas as pd

from srai.utils.constants import FEATURES_INDEX, REGIONS_INDEX


class IntersectionJoiner:
"""
Expand Down Expand Up @@ -68,11 +70,11 @@ def _join_with_geom(
"""
joined_parts = [
gpd.overlay(
single[["geometry"]].reset_index(names="feature_id"),
regions[["geometry"]].reset_index(names="region_id"),
single[["geometry"]].reset_index(names=FEATURES_INDEX),
regions[["geometry"]].reset_index(names=REGIONS_INDEX),
how="intersection",
keep_geom_type=False,
).set_index(["region_id", "feature_id"])
).set_index([REGIONS_INDEX, FEATURES_INDEX])
for _, single in features.groupby(features["geometry"].geom_type)
]

Expand All @@ -95,12 +97,12 @@ def _join_without_geom(
"""
joint = (
gpd.sjoin(
regions.reset_index(names="region_id"),
features.reset_index(names="feature_id"),
regions.reset_index(names=REGIONS_INDEX),
features.reset_index(names=FEATURES_INDEX),
how="inner",
predicate="intersects",
)
.set_index(["region_id", "feature_id"])
.set_index([REGIONS_INDEX, FEATURES_INDEX])
.drop(columns=["index_right", "geometry"])
)
return joint
6 changes: 3 additions & 3 deletions srai/regionizers/administrative_boundary_regionizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from tqdm import tqdm

from srai.utils._optional import import_optional_dependencies
from srai.utils.constants import WGS84_CRS
from srai.utils.constants import REGIONS_INDEX, WGS84_CRS

from .base import BaseRegionizer

Expand Down Expand Up @@ -134,7 +134,7 @@ def transform(self, gdf: gpd.GeoDataFrame) -> gpd.GeoDataFrame:

regions_dicts = self._generate_regions_from_all_geometries(gdf_wgs84)

regions_gdf = gpd.GeoDataFrame(data=regions_dicts, crs=WGS84_CRS).set_index("region_id")
regions_gdf = gpd.GeoDataFrame(data=regions_dicts, crs=WGS84_CRS).set_index(REGIONS_INDEX)
regions_gdf = self._toposimplify_gdf(regions_gdf)

if self.clip_regions:
Expand Down Expand Up @@ -270,7 +270,7 @@ def _toposimplify_gdf(self, regions_gdf: gpd.GeoDataFrame) -> gpd.GeoDataFrame:
prevent_oversimplify=True,
)
regions_gdf = topo.to_gdf(winding_order="CW_CCW", crs=WGS84_CRS, validate=True)
regions_gdf.index.rename("region_id", inplace=True)
regions_gdf.index.rename(REGIONS_INDEX, inplace=True)
regions_gdf.geometry = regions_gdf.geometry.apply(make_valid)
for idx, r in regions_gdf.iterrows():
if not isinstance(r.geometry, (Polygon, MultiPolygon)):
Expand Down
4 changes: 3 additions & 1 deletion srai/regionizers/h3_regionizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from functional import seq
from shapely import geometry

from srai.utils.constants import WGS84_CRS
from srai.utils.constants import REGIONS_INDEX, WGS84_CRS

from .base import BaseRegionizer

Expand Down Expand Up @@ -91,6 +91,8 @@ def transform(self, gdf: gpd.GeoDataFrame) -> gpd.GeoDataFrame:
else gdf_h3
)

gdf_h3_clipped.index.name = REGIONS_INDEX

return gdf_h3_clipped.to_crs(gdf.crs)

def _polygon_shapely_to_h3(self, polygon: geometry.Polygon) -> h3.Polygon:
Expand Down
4 changes: 3 additions & 1 deletion srai/regionizers/s2_regionizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from s2 import s2
from shapely.geometry import Polygon

from srai.utils.constants import WGS84_CRS
from srai.utils.constants import REGIONS_INDEX, WGS84_CRS

from .base import BaseRegionizer

Expand Down Expand Up @@ -76,6 +76,8 @@ def transform(self, gdf: gpd.GeoDataFrame) -> gpd.GeoDataFrame:

res = res[~res.index.duplicated(keep="first")]

res.index.name = REGIONS_INDEX

return res

def _fill_with_s2_cells(self, gdf: gpd.GeoDataFrame) -> gpd.GeoDataFrame:
Expand Down
4 changes: 2 additions & 2 deletions srai/regionizers/voronoi_regionizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from shapely.geometry import box

from srai.utils._optional import import_optional_dependencies
from srai.utils.constants import WGS84_CRS
from srai.utils.constants import REGIONS_INDEX, WGS84_CRS

from .base import BaseRegionizer

Expand Down Expand Up @@ -110,7 +110,7 @@ def transform(self, gdf: Optional[gpd.GeoDataFrame] = None) -> gpd.GeoDataFrame:
regions_gdf = gpd.GeoDataFrame(
data={"geometry": generated_regions}, index=self.region_ids, crs=WGS84_CRS
)
regions_gdf.index.rename("region_id", inplace=True)
regions_gdf.index.rename(REGIONS_INDEX, inplace=True)
clipped_regions_gdf = regions_gdf.clip(mask=gdf_wgs84, keep_geom_type=False)
return clipped_regions_gdf

Expand Down
3 changes: 3 additions & 0 deletions srai/utils/constants.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
"""Constants used across the project."""

# CRS identifier passed as `crs=` by the regionizers when they build GeoDataFrames.
WGS84_CRS = "EPSG:4326"

# Index name the regionizers set on the regions GeoDataFrame they return;
# the intersection joiner uses the same name for the regions level of its joint index.
REGIONS_INDEX = "region_id"
# Index name the intersection joiner gives to the features level of its joint index.
FEATURES_INDEX = "feature_id"