Skip to content

Commit

Permalink
chore: add index name to regionizers (#182)
Browse files Browse the repository at this point in the history
chore: add index name to regionizers, and extract index names to constant
  • Loading branch information
piotrgramacki authored Feb 22, 2023
1 parent ee45614 commit a82bc96
Show file tree
Hide file tree
Showing 6 changed files with 22 additions and 13 deletions.
14 changes: 8 additions & 6 deletions srai/joiners/intersection_joiner.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
import geopandas as gpd
import pandas as pd

from srai.utils.constants import FEATURES_INDEX, REGIONS_INDEX


class IntersectionJoiner:
"""
Expand Down Expand Up @@ -68,11 +70,11 @@ def _join_with_geom(
"""
joined_parts = [
gpd.overlay(
single[["geometry"]].reset_index(names="feature_id"),
regions[["geometry"]].reset_index(names="region_id"),
single[["geometry"]].reset_index(names=FEATURES_INDEX),
regions[["geometry"]].reset_index(names=REGIONS_INDEX),
how="intersection",
keep_geom_type=False,
).set_index(["region_id", "feature_id"])
).set_index([REGIONS_INDEX, FEATURES_INDEX])
for _, single in features.groupby(features["geometry"].geom_type)
]

Expand All @@ -95,12 +97,12 @@ def _join_without_geom(
"""
joint = (
gpd.sjoin(
regions.reset_index(names="region_id"),
features.reset_index(names="feature_id"),
regions.reset_index(names=REGIONS_INDEX),
features.reset_index(names=FEATURES_INDEX),
how="inner",
predicate="intersects",
)
.set_index(["region_id", "feature_id"])
.set_index([REGIONS_INDEX, FEATURES_INDEX])
.drop(columns=["index_right", "geometry"])
)
return joint
6 changes: 3 additions & 3 deletions srai/regionizers/administrative_boundary_regionizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from tqdm import tqdm

from srai.utils._optional import import_optional_dependencies
from srai.utils.constants import WGS84_CRS
from srai.utils.constants import REGIONS_INDEX, WGS84_CRS

from .base import BaseRegionizer

Expand Down Expand Up @@ -134,7 +134,7 @@ def transform(self, gdf: gpd.GeoDataFrame) -> gpd.GeoDataFrame:

regions_dicts = self._generate_regions_from_all_geometries(gdf_wgs84)

regions_gdf = gpd.GeoDataFrame(data=regions_dicts, crs=WGS84_CRS).set_index("region_id")
regions_gdf = gpd.GeoDataFrame(data=regions_dicts, crs=WGS84_CRS).set_index(REGIONS_INDEX)
regions_gdf = self._toposimplify_gdf(regions_gdf)

if self.clip_regions:
Expand Down Expand Up @@ -270,7 +270,7 @@ def _toposimplify_gdf(self, regions_gdf: gpd.GeoDataFrame) -> gpd.GeoDataFrame:
prevent_oversimplify=True,
)
regions_gdf = topo.to_gdf(winding_order="CW_CCW", crs=WGS84_CRS, validate=True)
regions_gdf.index.rename("region_id", inplace=True)
regions_gdf.index.rename(REGIONS_INDEX, inplace=True)
regions_gdf.geometry = regions_gdf.geometry.apply(make_valid)
for idx, r in regions_gdf.iterrows():
if not isinstance(r.geometry, (Polygon, MultiPolygon)):
Expand Down
4 changes: 3 additions & 1 deletion srai/regionizers/h3_regionizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from functional import seq
from shapely import geometry

from srai.utils.constants import WGS84_CRS
from srai.utils.constants import REGIONS_INDEX, WGS84_CRS

from .base import BaseRegionizer

Expand Down Expand Up @@ -91,6 +91,8 @@ def transform(self, gdf: gpd.GeoDataFrame) -> gpd.GeoDataFrame:
else gdf_h3
)

gdf_h3_clipped.index.name = REGIONS_INDEX

return gdf_h3_clipped.to_crs(gdf.crs)

def _polygon_shapely_to_h3(self, polygon: geometry.Polygon) -> h3.Polygon:
Expand Down
4 changes: 3 additions & 1 deletion srai/regionizers/s2_regionizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from s2 import s2
from shapely.geometry import Polygon

from srai.utils.constants import WGS84_CRS
from srai.utils.constants import REGIONS_INDEX, WGS84_CRS

from .base import BaseRegionizer

Expand Down Expand Up @@ -76,6 +76,8 @@ def transform(self, gdf: gpd.GeoDataFrame) -> gpd.GeoDataFrame:

res = res[~res.index.duplicated(keep="first")]

res.index.name = REGIONS_INDEX

return res

def _fill_with_s2_cells(self, gdf: gpd.GeoDataFrame) -> gpd.GeoDataFrame:
Expand Down
4 changes: 2 additions & 2 deletions srai/regionizers/voronoi_regionizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from shapely.geometry import box

from srai.utils._optional import import_optional_dependencies
from srai.utils.constants import WGS84_CRS
from srai.utils.constants import REGIONS_INDEX, WGS84_CRS

from .base import BaseRegionizer

Expand Down Expand Up @@ -110,7 +110,7 @@ def transform(self, gdf: Optional[gpd.GeoDataFrame] = None) -> gpd.GeoDataFrame:
regions_gdf = gpd.GeoDataFrame(
data={"geometry": generated_regions}, index=self.region_ids, crs=WGS84_CRS
)
regions_gdf.index.rename("region_id", inplace=True)
regions_gdf.index.rename(REGIONS_INDEX, inplace=True)
clipped_regions_gdf = regions_gdf.clip(mask=gdf_wgs84, keep_geom_type=False)
return clipped_regions_gdf

Expand Down
3 changes: 3 additions & 0 deletions srai/utils/constants.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
"""Constants used across the project."""

# CRS identifier (EPSG code for WGS 84) passed as `crs=` when regionizers build GeoDataFrames.
WGS84_CRS = "EPSG:4326"

# Name of the index on regions GeoDataFrames returned by regionizers; the intersection
# joiner also uses it as the regions level of its (regions, features) MultiIndex.
REGIONS_INDEX = "region_id"
# Name the intersection joiner gives to the features index when resetting it, and the
# features level of the resulting (regions, features) MultiIndex.
FEATURES_INDEX = "feature_id"

0 comments on commit a82bc96

Please sign in to comment.