Skip to content

Commit

Permalink
chore: make h3neighbourhood return only available indexes, add tests …
Browse files Browse the repository at this point in the history
…for get_neighbours
  • Loading branch information
Szymon Woźniak committed Mar 9, 2023
1 parent 4da1270 commit 8ec7988
Show file tree
Hide file tree
Showing 2 changed files with 90 additions and 4 deletions.
29 changes: 26 additions & 3 deletions srai/neighbourhoods/h3.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,9 @@
This module contains the H3Neighbourhood class, that allows to get the neighbours of an H3 region.
"""
from typing import Set
from typing import Optional, Set

import geopandas as gpd
import h3

from .neighbourhood import Neighbourhood
Expand All @@ -17,6 +18,23 @@ class H3Neighbourhood(Neighbourhood[str]):
This class allows to get the neighbours of an H3 region.
"""

def __init__(self, regions_gdf: Optional[gpd.GeoDataFrame] = None) -> None:
"""
Initializes the H3Neighbourhood.
If a regions GeoDataFrame is provided, only the neighbours
that are in the regions GeoDataFrame will be returned by the methods of this instance.
Args:
regions_gdf (Optional[gpd.GeoDataFrame], optional): The regions that are being analyzed.
The H3Neighbourhood will only look for neighbours among these regions.
Defaults to None.
"""
super().__init__()
self._available_indices: Optional[Set[str]] = None
if regions_gdf is not None:
self._available_indices = set(regions_gdf.index)

def get_neighbours(self, index: str) -> Set[str]:
"""
Get the direct neighbours of an H3 region using its index.
Expand Down Expand Up @@ -45,7 +63,7 @@ def get_neighbours_up_to_distance(self, index: str, distance: int) -> Set[str]:

neighbours: Set[str] = h3.grid_disk(index, distance)
neighbours.discard(index)
return neighbours
return self._select_available(neighbours)

def get_neighbours_at_distance(self, index: str, distance: int) -> Set[str]:
"""
Expand All @@ -63,7 +81,12 @@ def get_neighbours_at_distance(self, index: str, distance: int) -> Set[str]:

neighbours: Set[str] = h3.grid_ring(index, distance)
neighbours.discard(index)
return neighbours
return self._select_available(neighbours)

def _select_available(self, indices: Set[str]) -> Set[str]:
if self._available_indices is None:
return indices
return indices.intersection(self._available_indices)

def _distance_incorrect(self, distance: int) -> bool:
return distance <= 0
65 changes: 64 additions & 1 deletion tests/neighbourhoods/test_h3_neighbourhood.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,73 @@
from typing import Set
from typing import Any, Set

import geopandas as gpd
import pytest
from shapely.geometry import Polygon

from srai.neighbourhoods import H3Neighbourhood


@pytest.fixture # type: ignore
def empty_gdf() -> gpd.GeoDataFrame:
"""Fixture for an empty GeoDataFrame."""
return gpd.GeoDataFrame()


@pytest.fixture # type: ignore
def single_hex_gdf() -> gpd.GeoDataFrame:
"""Fixture for a GeoDataFrame with a single hexagon."""
return gpd.GeoDataFrame(
{"index": ["811e3ffffffffff"], "geometry": [Polygon([(0, 0), (0, 1), (1, 0), (1, 1)])]}
)


@pytest.fixture # type: ignore
def hex_without_one_neighbour_gdf() -> gpd.GeoDataFrame:
"""Fixture for a GeoDataFrame with a single hexagon."""
return gpd.GeoDataFrame(
geometry=gpd.points_from_xy([0, 1, 2, 3, 4, 5], [0, 1, 2, 3, 4, 5]),
index=[
"811e3ffffffffff",
"811f3ffffffffff",
"811fbffffffffff",
"811ebffffffffff",
"811efffffffffff",
"811e7ffffffffff",
],
)


@pytest.mark.parametrize( # type: ignore
"regions_gdf_fixture,expected",
[
(
"empty_gdf",
set(),
),
(
"single_hex_gdf",
set(),
),
(
"hex_without_one_neighbour_gdf",
{
"811f3ffffffffff",
"811fbffffffffff",
"811ebffffffffff",
"811efffffffffff",
"811e7ffffffffff",
},
),
],
)
def test_get_neighbours_with_regions_gdf(
regions_gdf_fixture: str, expected: Set[str], request: Any
) -> None:
"""Test get_neighbours of H3Neighbourhood with a specified regions GeoDataFrame."""
regions_gdf = request.getfixturevalue(regions_gdf_fixture)
assert H3Neighbourhood(regions_gdf).get_neighbours("811e3ffffffffff") == expected


@pytest.mark.parametrize( # type: ignore
"index,expected",
[
Expand Down

0 comments on commit 8ec7988

Please sign in to comment.