Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add osm data loader based on pbf files #205

Merged
merged 24 commits into from
Mar 21, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
9b36354
feat: add first version of PbfTagLoader
RaczeQ Mar 16, 2023
c6645f2
fix: apply refurb suggestion
RaczeQ Mar 16, 2023
971cdf2
feat: improve buffering and pbf downloading
RaczeQ Mar 17, 2023
a6f1b83
chore: modify pdm file
RaczeQ Mar 17, 2023
f4c992a
ci: added new nvidia libraries licenses to ignore
RaczeQ Mar 17, 2023
03aee70
chore: refactor code and add docstrings
RaczeQ Mar 17, 2023
2e9bdeb
feat(OSMPbfLoader): added example notebook
RaczeQ Mar 17, 2023
cb475dc
chore: modified modules docstrings
RaczeQ Mar 17, 2023
cf3eaaf
feat: modify pbf loader examples
RaczeQ Mar 18, 2023
841bf1c
feat: modify OSMPbfLoader geometries parsing
RaczeQ Mar 18, 2023
5aefb34
chore: changed example plots
RaczeQ Mar 18, 2023
66239f3
chore: change protomaps progress bars
RaczeQ Mar 18, 2023
efa5c73
chore: add prettymaps contribution
RaczeQ Mar 18, 2023
11ba806
chore: change directories names
RaczeQ Mar 19, 2023
0775eef
chore: add tests for OSMPbfLoader
RaczeQ Mar 19, 2023
9b3204f
fix: changed geometry hash calculation
RaczeQ Mar 19, 2023
b61f412
feat: added user agent header with library info
RaczeQ Mar 19, 2023
9b18518
chore: formatted notes in osm loaders
RaczeQ Mar 20, 2023
3d3098e
chore: merge remote-tracking branch 'origin/main' into 52-add-osmload…
RaczeQ Mar 20, 2023
97ef594
chore: refactor utils module
RaczeQ Mar 20, 2023
ab558fd
chore: update srai/loaders/osm_loaders/filters/osm_tags_type.py
RaczeQ Mar 20, 2023
f7cce35
chore: update srai/loaders/osm_loaders/osm_pbf_loader.py
RaczeQ Mar 20, 2023
ce9b032
chore: applied CR suggestions
RaczeQ Mar 21, 2023
a099b56
chore: applied CR suggestions
RaczeQ Mar 21, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -142,4 +142,7 @@ requirements.txt
cache/

# pytorch lightning
lightning_logs/
lightning_logs/

# files_cache
files/
2 changes: 1 addition & 1 deletion examples/loaders/gtfs_loader.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
"import numpy as np\n",
"from shapely.geometry import Point\n",
"from srai.utils.constants import WGS84_CRS\n",
"from utils import download"
"from srai.utils.download import download"
]
},
{
Expand Down
4 changes: 3 additions & 1 deletion srai/loaders/osm_tag_loader/filters/hex2vec.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,9 @@
References:
1. https://dl.acm.org/doi/10.1145/3486635.3491076
"""
HEX2VEC_FILTER = {
from srai.loaders.osm_tag_loader.filters.osm_tags_type import osm_tags_type

HEX2VEC_FILTER: osm_tags_type = {
"aeroway": [
"aerodrome",
"apron",
Expand Down
4 changes: 4 additions & 0 deletions srai/loaders/osm_tag_loader/filters/osm_tags_type.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
"""DOCSTRING TODO."""
from typing import Dict, List, Union

# Shared type alias for an OSM tags filter: a dictionary keyed by strings whose
# values are a list of strings, a single string, or a bool.
# Per the `tags` argument docs of the tag loader, keys are OSM tags
# (e.g. `building`, `amenity`) and `True` retrieves all objects with the tag.
# NOTE(review): string / list values presumably select specific tag values — confirm.
osm_tags_type = Dict[str, Union[List[str], str, bool]]
8 changes: 5 additions & 3 deletions srai/loaders/osm_tag_loader/filters/popular.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,15 @@
import requests
from functional import seq

from srai.loaders.osm_tag_loader.filters.osm_tags_type import osm_tags_type

_TAGINFO_API_ADDRESS = "https://taginfo.openstreetmap.org"
_TAGINFO_API_TAGS = _TAGINFO_API_ADDRESS + "/api/4/tags/popular"


def get_popular_tags(
in_wiki_only: bool = False, min_count: int = 0, min_fraction: float = 0.0
) -> Dict[str, List[str]]:
) -> osm_tags_type:
"""
Download the OSM's most popular tags from taginfo api.

Expand Down Expand Up @@ -47,15 +49,15 @@ def get_popular_tags(

def _parse_taginfo_response(
taginfo_data: List[Dict[str, Any]], in_wiki_only: bool, min_count: int, min_fraction: float
) -> Dict[str, List[str]]:
) -> osm_tags_type:
result_tags = (
seq(taginfo_data)
.filter(lambda t: t["count_all"] >= min_count)
.filter(lambda t: t["count_all_fraction"] >= min_fraction)
)
if in_wiki_only:
result_tags = result_tags.filter(lambda t: t["in_wiki"])
taginfo_grouped: Dict[str, List[str]] = (
taginfo_grouped: osm_tags_type = (
result_tags.map(lambda t: (t["key"], t["value"])).group_by_key().to_dict()
)
return taginfo_grouped
11 changes: 5 additions & 6 deletions srai/loaders/osm_tag_loader/osm_tag_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,14 @@
This module contains loader capable of loading OpenStreetMap tags.
"""
from itertools import product
from typing import Dict, List, Tuple, Union
from typing import List, Tuple, Union

import geopandas as gpd
import pandas as pd
from functional import seq
from tqdm import tqdm

from srai.loaders.osm_tag_loader.filters.osm_tags_type import osm_tags_type
from srai.utils._optional import import_optional_dependencies
from srai.utils.constants import FEATURES_INDEX, WGS84_CRS

Expand Down Expand Up @@ -42,7 +43,7 @@ def __init__(self) -> None:
def load(
self,
area: gpd.GeoDataFrame,
tags: Dict[str, Union[List[str], str, bool]],
tags: osm_tags_type,
) -> gpd.GeoDataFrame:
"""
Download OSM objects with specified tags for a given area.
Expand All @@ -55,7 +56,7 @@ def load(

Args:
area (gpd.GeoDataFrame): Area for which to download objects.
tags (Dict[str, Union[List[str], str, bool]]): A dictionary
tags (osm_tags_type): A dictionary
specifying which tags to download.
The keys should be OSM tags (e.g. `building`, `amenity`).
The values should either be `True` for retrieving all objects with the tag,
Expand Down Expand Up @@ -93,9 +94,7 @@ def load(

return self._flatten_index(result_gdf)

def _flatten_tags(
self, tags: Dict[str, Union[List[str], str, bool]]
) -> List[Tuple[str, Union[str, bool]]]:
def _flatten_tags(self, tags: osm_tags_type) -> List[Tuple[str, Union[str, bool]]]:
tags_flat: List[Tuple[str, Union[str, bool]]] = (
seq(tags.items())
.starmap(lambda k, v: product([k], v if isinstance(v, list) else [v]))
Expand Down
178 changes: 178 additions & 0 deletions srai/loaders/osm_tag_loader/pbf_file_downloader.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,178 @@
"""DOCSTRING TODO."""
import hashlib
from pathlib import Path
from time import sleep, time
from typing import Any, Dict, Sequence

import geopandas as gpd
import requests
import shapely.wkt as wktlib
import topojson as tp
from shapely.geometry import Polygon, mapping
from shapely.geometry.base import BaseGeometry
from shapely.validation import make_valid
from tqdm import tqdm

from srai.utils.constants import WGS84_CRS
from srai.utils.download import download
from srai.utils.geometry import flatten_geometry, remove_interiors


class PbfFileDownloader:
    """
    Downloader of OSM `*.pbf` extracts from the Protomaps download service.

    For a given polygon an extraction is requested from the Protomaps API,
    its status is polled until it reports completion, and the resulting file
    is saved to the `files` directory under the current working directory.
    Files are named after the SHA-256 hash of the simplified polygon's WKT,
    so a polygon that was already downloaded is not requested again.
    """

    PROTOMAPS_API_START_URL = "https://app.protomaps.com/downloads/osm"
    PROTOMAPS_API_DOWNLOAD_URL = "https://app.protomaps.com/downloads/{}/download"

    _PBAR_FORMAT = "Downloading pbf file ({})"

    # (status key prefix, progress bar label) for each phase reported by the
    # status endpoint, in the order in which they are displayed. The status
    # response carries `<prefix>Total` and `<prefix>Prog` entries.
    _PROGRESS_STAGES = (("Cells", "Cells"), ("Nodes", "Nodes"), ("Elems", "Elements"))

    # Exterior size threshold used while simplifying a polygon.
    # NOTE(review): presumably a limit of the extraction service — confirm.
    _MAX_EXTERIOR_COORDS = 1000

    # Tolerances tried one by one (finest first) when simplifying a polygon.
    SIMPLIFICATION_TOLERANCE_VALUES = [
        1e-07,
        2e-07,
        5e-07,
        1e-06,
        2e-06,
        5e-06,
        1e-05,
        2e-05,
        5e-05,
        0.0001,
        0.0002,
        0.0005,
        0.001,
        0.002,
        0.005,
        0.01,
        0.02,
        0.05,
    ]

    def download_pbf_files_for_region_gdf(
        self, region_gdf: gpd.GeoDataFrame
    ) -> Dict[str, Sequence[Path]]:
        """
        Download pbf files for every row of a regions GeoDataFrame.

        Each row's geometry is passed through `flatten_geometry` and one file
        is downloaded per resulting polygon.

        Args:
            region_gdf (gpd.GeoDataFrame): Regions to download files for.

        Returns:
            Dict[str, Sequence[Path]]: Mapping from the row index value to the
                paths of the files downloaded for that row.
        """
        return {
            region_id: [
                self.download_pbf_file_for_polygon(polygon)
                for polygon in flatten_geometry(row.geometry)
            ]
            for region_id, row in region_gdf.iterrows()
        }

    def download_pbf_file_for_polygon(self, polygon: Polygon) -> Path:
        """
        Download a pbf file covering a single polygon.

        The polygon is passed through `remove_interiors`, simplified and hashed.
        If a file named after that hash already exists it is returned without
        any network traffic; otherwise an extraction is requested, awaited and
        downloaded.

        Args:
            polygon (Polygon): Polygon to download the extract for.

        Returns:
            Path: Path to the (possibly already existing) pbf file.

        Raises:
            requests.HTTPError: If the extraction service answers any of the
                requests made here with an error status.
        """
        closed_polygon = remove_interiors(polygon)
        simplified_polygon = self._simplify_polygon(closed_polygon)
        geometry_hash = self._get_geometry_hash(simplified_polygon)
        pbf_file_path = Path().resolve() / "files" / f"{geometry_hash}.pbf"

        if pbf_file_path.exists():
            return pbf_file_path

        # The context manager closes the session's connections once the
        # extraction is ready, even if polling fails midway.
        with requests.Session() as session:
            start_extract_result = self._start_extraction(
                session, mapping(simplified_polygon), geometry_hash
            )
            extraction_uuid = start_extract_result["uuid"]
            self._wait_for_extraction(session, start_extract_result["url"])

        download(
            url=self.PROTOMAPS_API_DOWNLOAD_URL.format(extraction_uuid),
            fname=pbf_file_path.as_posix(),
        )

        return pbf_file_path

    def _start_extraction(
        self, session: requests.Session, geometry_geojson: Dict[str, Any], name: str
    ) -> Dict[str, Any]:
        """Request a new extraction and return the decoded JSON response."""
        start_page = session.get(url=self.PROTOMAPS_API_START_URL)
        # Fail with a clear HTTP error instead of a KeyError on the cookie below.
        start_page.raise_for_status()

        csrf_token = start_page.cookies["csrftoken"]
        headers = {
            "Referer": self.PROTOMAPS_API_START_URL,
            "Cookie": f"csrftoken={csrf_token}",
            "X-CSRFToken": csrf_token,
            "Content-Type": "application/json; charset=utf-8",
        }
        request_payload = {
            "region": {"type": "geojson", "data": geometry_geojson},
            "name": name,
        }

        start_extract_request = session.post(
            url=self.PROTOMAPS_API_START_URL,
            json=request_payload,
            headers=headers,
            cookies=dict(csrftoken=csrf_token),
        )
        start_extract_request.raise_for_status()

        start_extract_result: Dict[str, Any] = start_extract_request.json()
        return start_extract_result

    def _wait_for_extraction(self, session: requests.Session, status_check_url: str) -> None:
        """Poll the status url until it reports `Complete`, showing a progress bar."""
        # Largest total seen so far for every stage.
        totals = {key: 0 for key, _ in self._PROGRESS_STAGES}
        status_response: Dict[str, Any] = {}

        with tqdm() as pbar:
            while not status_response.get("Complete", False):
                sleep(0.5)
                status_request = session.get(url=status_check_url)
                # Surface HTTP failures directly rather than as a JSON decoding error.
                status_request.raise_for_status()
                status_response = status_request.json()

                for key in totals:
                    totals[key] = max(totals[key], status_response.get(f"{key}Total", 0))

                # Display the first stage that has started but not finished yet.
                for key, label in self._PROGRESS_STAGES:
                    stage_total = totals[key]
                    stage_progress = status_response.get(f"{key}Prog", None)
                    if (
                        stage_total > 0
                        and stage_progress is not None
                        and stage_progress < stage_total
                    ):
                        pbar.set_description(self._PBAR_FORMAT.format(label))
                        shown_total, shown_n = stage_total, stage_progress
                        break
                else:
                    # No stage in progress: show the elements stage as full.
                    shown_total = shown_n = totals["Elems"]

                pbar.total = shown_total
                pbar.n = shown_n
                pbar.last_print_n = shown_n

                # The bar's timers are restarted on every poll.
                # NOTE(review): presumably to keep tqdm's rate estimate from
                # spanning different stages — confirm.
                pbar.start_t = time()
                pbar.last_print_t = time()
                pbar.refresh()

    def _simplify_polygon(self, polygon: Polygon) -> Polygon:
        """
        Simplify a polygon until its exterior is small enough.

        Every tolerance from `SIMPLIFICATION_TOLERANCE_VALUES` is applied to the
        original polygon (not cumulatively) until the repaired result is a
        polygon with fewer than `_MAX_EXTERIOR_COORDS` exterior coordinates.
        If the final result is still too big the convex hull of the input is
        used, and after that its minimum rotated rectangle.
        """
        simplified_polygon = polygon

        for simplify_tolerance in self.SIMPLIFICATION_TOLERANCE_VALUES:
            simplified_polygon = make_valid(
                tp.Topology(
                    polygon,
                    toposimplify=simplify_tolerance,
                    prevent_oversimplify=True,
                )
                .to_gdf(winding_order="CW_CCW", crs=WGS84_CRS, validate=True)
                .geometry[0]
            )
            # `make_valid` may hand back a multi-part geometry, which has no
            # `exterior`; such a result is not usable, so try the next tolerance.
            if (
                isinstance(simplified_polygon, Polygon)
                and len(simplified_polygon.exterior.coords) < self._MAX_EXTERIOR_COORDS
            ):
                break

        if (
            not isinstance(simplified_polygon, Polygon)
            or len(simplified_polygon.exterior.coords) > self._MAX_EXTERIOR_COORDS
        ):
            simplified_polygon = polygon.convex_hull

        if len(simplified_polygon.exterior.coords) > self._MAX_EXTERIOR_COORDS:
            simplified_polygon = polygon.minimum_rotated_rectangle

        return simplified_polygon

    def _get_geometry_hash(self, geometry: BaseGeometry) -> str:
        """Return the hex SHA-256 digest of the geometry's WKT representation."""
        wkt_string = wktlib.dumps(geometry)
        return hashlib.sha256(wkt_string.encode()).hexdigest()
Loading