Skip to content

Commit

Permalink
feat: add duplicated seeds ids to the error message (#199)
Browse files Browse the repository at this point in the history
feat: add region ids to the error message
  • Loading branch information
RaczeQ authored Mar 10, 2023
1 parent 0ad9954 commit 070f7e3
Showing 1 changed file with 13 additions and 10 deletions.
23 changes: 13 additions & 10 deletions srai/regionizers/voronoi_regionizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,10 @@
This module contains voronoi regionizer implementation.
"""

from typing import Optional
from typing import Hashable, List, Optional

import geopandas as gpd
from shapely.geometry import box
from shapely.geometry import Point, box

from srai.regionizers import Regionizer
from srai.utils._optional import import_optional_dependencies
Expand Down Expand Up @@ -64,8 +64,8 @@ def __init__(
dependency_group="voronoi", modules=["haversine", "pymap3d", "spherical_geometry"]
)
seeds_wgs84 = seeds.to_crs(crs=WGS84_CRS)
self.region_ids = []
self.seeds = []
self.region_ids: List[Hashable] = []
self.seeds: List[Point] = []
self.max_meters_between_points = max_meters_between_points
self.num_of_multiprocessing_workers = num_of_multiprocessing_workers
self.multiprocessing_activation_threshold = multiprocessing_activation_threshold
Expand All @@ -75,8 +75,9 @@ def __init__(
self.region_ids.append(index)
self.seeds.append(candidate_point)

if self._check_duplicate_points():
raise ValueError("Duplicate seeds present.")
duplicated_seeds_ids = self._get_duplicated_seeds_ids()
if duplicated_seeds_ids:
raise ValueError(f"Duplicate seeds present: {duplicated_seeds_ids}.")

if len(self.seeds) < 4:
raise ValueError("Minimum 4 seeds are required.")
Expand Down Expand Up @@ -123,7 +124,9 @@ def transform(self, gdf: Optional[gpd.GeoDataFrame] = None) -> gpd.GeoDataFrame:
clipped_regions_gdf = regions_gdf.clip(mask=gdf_wgs84, keep_geom_type=False)
return clipped_regions_gdf

def _check_duplicate_points(self) -> bool:
"""Check if any point overlaps with another using quick sjoin operation."""
gdf = gpd.GeoDataFrame(data=[{"geometry": s} for s in self.seeds])
return len(gdf.sjoin(gdf).index) != len(self.seeds)
def _get_duplicated_seeds_ids(self) -> List[Hashable]:
"""Return all seeds ids that overlap with another using quick sjoin operation."""
gdf = gpd.GeoDataFrame(data={"geometry": self.seeds}, index=self.region_ids, crs=WGS84_CRS)
duplicated_seeds = gdf.sjoin(gdf).index.value_counts().loc[lambda x: x > 1]
duplicated_seeds_ids: List[Hashable] = duplicated_seeds.index.to_list()
return duplicated_seeds_ids

0 comments on commit 070f7e3

Please sign in to comment.