diff --git a/.flake8 b/.flake8 new file mode 100644 index 0000000..35cc5f6 --- /dev/null +++ b/.flake8 @@ -0,0 +1,8 @@ +[flake8] +# Exceptions below are ignored because their solution was not clear to the author of the PR, +# but they should be solved and taken out from the ignore. +ignore = E203, E501, W503, B950, F821, B007, E402, E722, F401, F811, B001, B008, C901, E731, E231, B009, B303, E731, B903, B011 +max-line-length = 100 +max-complexity = 18 +select = A,B,C,E,F,W,T4,B9 +exclude = .git,__pycache__,misc diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml new file mode 100644 index 0000000..745e965 --- /dev/null +++ b/.gitlab-ci.yml @@ -0,0 +1,30 @@ +stages: +- test + +linting: + stage: test + image: python:3.5-slim + before_script: + - apt-get update + - apt-get install -y git + - pip3 install -U pre-commit==1.21.0 + script: + - pre-commit run --all-files + +test: + stage: test + image: python:3.5 + variables: + PYTHONPATH: $PWD:$PYTHONPATH + script: + - ./setup.py test + +security: + stage: test + image: python:3.5-slim + allow_failure: true + before_script: + - pip install . + - pip install safety + script: + - safety check \ No newline at end of file diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000..bad7ef9 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,24 @@ +repos: + - repo: https://github.com/ambv/black + rev: 19.10b0 + hooks: + - id: black + - repo: https://github.com/pre-commit/pre-commit-hooks + rev: v2.4.0 + hooks: + - id: flake8 + exclude: misc + additional_dependencies: [ + 'flake8==3.7.9', + 'flake8-builtins==1.4.2', + 'flake8-bugbear==20.1.2', + ] + - repo: https://github.com/pycqa/bandit + rev: 1.6.2 + hooks: + - id: bandit + args: [-lll] + - repo: https://github.com/pre-commit/mirrors-mypy + rev: v0.761 + hooks: + - id: mypy \ No newline at end of file diff --git a/.travis.yml b/.travis.yml index 3d3e47e..896d80f 100644 --- a/.travis.yml +++ b/.travis.yml @@ -2,16 +2,34 @@ language: python python: - "3.8" -# command to install dependencies +install: "pip install Cython && pip install . pre-commit" + +jobs: + include: + + - python: 3.6 + stage: test + + - python: 3.7 + stage: test + + - python: 3.8 + stage: test -install: "pip install ." # command to run tests services: - xvfb before_script: # configure a headless display for testing plot generation - "export DISPLAY=:99.0" -script: nosetests . +script: + - pre-commit run --all-files + - nosetests . + +cache: + directories: + - $HOME/.cache/pre-commit + - $HOME/.cache/pip notifications: email: diff --git a/examples/example_export.py b/examples/example_export.py index e372ede..f668786 100644 --- a/examples/example_export.py +++ b/examples/example_export.py @@ -10,24 +10,21 @@ # get elementary bus events (connections) taking place within a given time interval: -all_events = networks.temporal_network(g, - start_time_ut=start_ut, - end_time_ut=end_ut - ) +all_events = networks.temporal_network(g, start_time_ut=start_ut, end_time_ut=end_ut) print("Number of elementary PT events during rush hour in Kuopio: ", len(all_events)) # get elementary bus events (connections) taking place within a given time interval: -tram_events = networks.temporal_network(g, - start_time_ut=start_ut, - end_time_ut=end_ut, - route_type=route_types.TRAM - ) -assert(len(tram_events) == 0) # there should be no trams in our example city (Kuopio, Finland) +tram_events = networks.temporal_network( + g, start_time_ut=start_ut, end_time_ut=end_ut, route_type=route_types.TRAM +) +assert len(tram_events) == 0 # there should be no trams in our example city (Kuopio, Finland) # construct a networkx graph print("\nConstructing a combined stop_to_stop_network") -graph = networks.combined_stop_to_stop_transit_network(g, start_time_ut=start_ut, end_time_ut=end_ut) +graph = networks.combined_stop_to_stop_transit_network( + g, start_time_ut=start_ut, end_time_ut=end_ut +) print("Number of edges: ", len(graph.edges())) print("Number of nodes: ", len(graph.nodes())) print("Example edge: ", list(graph.edges(data=True))[0]) @@ -37,4 +34,3 @@ ################################################# # See also other functions in gtfspy.networks ! # ################################################# - diff --git a/examples/example_filter.py b/examples/example_filter.py index 74a278f..017eb5a 100644 --- a/examples/example_filter.py +++ b/examples/example_filter.py @@ -16,16 +16,24 @@ # filter by time and 3 kilometers from the city center week_start = G.get_weekly_extract_start_date() week_end = week_start + datetime.timedelta(days=7) -fe = FilterExtract(G, filtered_database_path, start_date=week_start, end_date=week_end, - buffer_lat=62.8930796, buffer_lon=27.6671316, buffer_distance_km=3) +fe = FilterExtract( + G, + filtered_database_path, + start_date=week_start, + end_date=week_end, + buffer_lat=62.8930796, + buffer_lon=27.6671316, + buffer_distance_km=3, +) fe.create_filtered_copy() -assert (os.path.exists(filtered_database_path)) +assert os.path.exists(filtered_database_path) G = GTFS(filtered_database_path) # visualize the routes of the filtered database from gtfspy import mapviz from matplotlib import pyplot as plt + mapviz.plot_route_network_from_gtfs(G) -plt.show() \ No newline at end of file +plt.show() diff --git a/examples/example_import.py b/examples/example_import.py index 71559c3..f004d3e 100644 --- a/examples/example_import.py +++ b/examples/example_import.py @@ -7,17 +7,21 @@ def load_or_import_example_gtfs(verbose=False): imported_database_path = "test_db_kuopio.sqlite" - if not os.path.exists(imported_database_path): # reimport only if the imported database does not already exist + if not os.path.exists( + imported_database_path + ): # reimport only if the imported database does not already exist print("Importing gtfs zip file") - import_gtfs.import_gtfs(["data/gtfs_kuopio_finland.zip"], # input: list of GTFS zip files (or directories) - imported_database_path, # output: where to create the new sqlite3 database - print_progress=verbose, # whether to print progress when importing data - location_name="Kuopio") + import_gtfs.import_gtfs( + ["data/gtfs_kuopio_finland.zip"], # input: list of GTFS zip files (or directories) + imported_database_path, # output: where to create the new sqlite3 database + print_progress=verbose, # whether to print progress when importing data + location_name="Kuopio", + ) # Not this is an optional step, which is not necessary for many things. print("Computing walking paths using OSM") G = gtfs.GTFS(imported_database_path) - G.meta['download_date'] = "2017-03-15" + G.meta["download_date"] = "2017-03-15" osm_path = "data/kuopio_extract_mapzen_2017_03_15.osm.pbf" @@ -25,7 +29,9 @@ def load_or_import_example_gtfs(verbose=False): # this should raise a warning due to no nearby OSM nodes for one of the stops. osm_transfers.add_walk_distances_to_db_python(imported_database_path, osm_path) - print("Note: for large cities we have also a faster option for computing footpaths that uses Java.)") + print( + "Note: for large cities we have also a faster option for computing footpaths that uses Java.)" + ) dir_path = os.path.dirname(os.path.realpath(__file__)) java_path = os.path.join(dir_path, "../java_routing/") print("Please see the contents of " + java_path + " for more details.") @@ -35,7 +41,10 @@ def load_or_import_example_gtfs(verbose=False): if verbose: print("Location name:" + G.get_location_name()) # should print Kuopio - print("Time span of the data in unixtime: " + str(G.get_approximate_schedule_time_span_in_ut())) + print( + "Time span of the data in unixtime: " + + str(G.get_approximate_schedule_time_span_in_ut()) + ) # prints the time span in unix time return G diff --git a/examples/example_map_visualization.py b/examples/example_map_visualization.py index 0b93223..c680a99 100644 --- a/examples/example_map_visualization.py +++ b/examples/example_map_visualization.py @@ -14,4 +14,4 @@ # ax_thumbnail.figure.savefig("test_thumbnail.jpg") -plt.show() \ No newline at end of file +plt.show() diff --git a/examples/example_plot_trip_counts.py b/examples/example_plot_trip_counts.py index 87945c9..0a47ae6 100644 --- a/examples/example_plot_trip_counts.py +++ b/examples/example_plot_trip_counts.py @@ -1,8 +1,8 @@ -import functools import os from example_import import load_or_import_example_gtfs from matplotlib import pyplot as plt + from gtfspy.gtfs import GTFS G = load_or_import_example_gtfs() @@ -10,11 +10,11 @@ daily_trip_counts = G.get_trip_counts_per_day() f, ax = plt.subplots() -datetimes = [date.to_pydatetime() for date in daily_trip_counts['date']] -trip_counts = daily_trip_counts['trip_counts'] +datetimes = [date.to_pydatetime() for date in daily_trip_counts["date"]] +trip_counts = daily_trip_counts["trip_counts"] ax.bar(datetimes, trip_counts) -ax.axvline(G.meta['download_date'], color="red") +ax.axvline(G.meta["download_date"], color="red") threshold = 0.96 ax.axhline(trip_counts.max() * threshold, color="red") ax.axvline(G.get_weekly_extract_start_date(weekdays_at_least_of_max=threshold), color="yellow") @@ -24,18 +24,17 @@ G = GTFS(weekly_db_path) f, ax = plt.subplots() daily_trip_counts = G.get_trip_counts_per_day() - datetimes = [date.to_pydatetime() for date in daily_trip_counts['date']] - trip_counts = daily_trip_counts['trip_counts'] + datetimes = [date.to_pydatetime() for date in daily_trip_counts["date"]] + trip_counts = daily_trip_counts["trip_counts"] ax.bar(datetimes, trip_counts) - events = list(G.generate_routable_transit_events(0, G.get_approximate_schedule_time_span_in_ut()[0])) - min_ut = float('inf') + events = list( + G.generate_routable_transit_events(0, G.get_approximate_schedule_time_span_in_ut()[0]) + ) + min_ut = float("inf") for e in events: min_ut = min(e.dep_time_ut, min_ut) print(G.get_approximate_schedule_time_span_in_ut()) plt.show() - - - diff --git a/examples/example_temporal_distance_profile.py b/examples/example_temporal_distance_profile.py index 72f5ede..dcf40a0 100644 --- a/examples/example_temporal_distance_profile.py +++ b/examples/example_temporal_distance_profile.py @@ -3,7 +3,9 @@ import example_import from gtfspy.routing.helpers import get_transit_connections, get_walk_network -from gtfspy.routing.multi_objective_pseudo_connection_scan_profiler import MultiObjectivePseudoCSAProfiler +from gtfspy.routing.multi_objective_pseudo_connection_scan_profiler import ( + MultiObjectivePseudoCSAProfiler, +) from gtfspy.routing.node_profile_analyzer_time_and_veh_legs import NodeProfileAnalyzerTimeAndVehLegs G = example_import.load_or_import_example_gtfs() @@ -14,12 +16,12 @@ to_stop_I = None stop_dict = G.stops().to_dict("index") for stop_I, data in stop_dict.items(): - if data['name'] == from_stop_name: + if data["name"] == from_stop_name: from_stop_I = stop_I - if data['name'] == to_stop_name: + if data["name"] == to_stop_name: to_stop_I = stop_I -assert (from_stop_I is not None) -assert (to_stop_I is not None) +assert from_stop_I is not None +assert to_stop_I is not None # The start and end times between which PT operations (and footpaths) are scanned: ANALYSIS_START_TIME_UT = G.get_suitable_date_for_daily_extract(ut=True) + 10 * 3600 @@ -40,17 +42,18 @@ # gtfspy.osm_transfers.add_walk_distances_to_db_python(..., cutoff_distance_m=2000). - -mpCSA = MultiObjectivePseudoCSAProfiler(connections, - targets=[to_stop_I], - start_time_ut=CONNECTION_SCAN_START_TIME_UT, - end_time_ut=CONNECTION_SCAN_END_TIME_UT, - transfer_margin=120, # seconds - walk_network=walk_network, - walk_speed=1.5, # meters per second - verbose=True, - track_vehicle_legs=True, - track_time=True) +mpCSA = MultiObjectivePseudoCSAProfiler( + connections, + targets=[to_stop_I], + start_time_ut=CONNECTION_SCAN_START_TIME_UT, + end_time_ut=CONNECTION_SCAN_END_TIME_UT, + transfer_margin=120, # seconds + walk_network=walk_network, + walk_speed=1.5, # meters per second + verbose=True, + track_vehicle_legs=True, + track_time=True, +) mpCSA.run() profiles = mpCSA.stop_profiles @@ -60,19 +63,21 @@ direct_walk_duration = departure_stop_profile.get_walk_to_target_duration() # This equals inf, if walking distance between the departure_stop (from_stop_I) and target_stop (to_stop_I) # is longer than MAX_WALK_LENGTH -analyzer = NodeProfileAnalyzerTimeAndVehLegs(departure_stop_profile.get_final_optimal_labels(), - direct_walk_duration, - ANALYSIS_START_TIME_UT, - ANALYSIS_END_TIME_UT) +analyzer = NodeProfileAnalyzerTimeAndVehLegs( + departure_stop_profile.get_final_optimal_labels(), + direct_walk_duration, + ANALYSIS_START_TIME_UT, + ANALYSIS_END_TIME_UT, +) # Print out results: stop_dict = G.stops().to_dict("index") print("Origin: ", stop_dict[from_stop_I]) print("Destination: ", stop_dict[to_stop_I]) -print("Minimum temporal distance: ", analyzer.min_temporal_distance() / 60., " minutes") -print("Mean temporal distance: ", analyzer.mean_temporal_distance() / 60., " minutes") -print("Medan temporal distance: ", analyzer.median_temporal_distance() / 60., " minutes") -print("Maximum temporal distance: ", analyzer.max_temporal_distance() / 60., " minutes") +print("Minimum temporal distance: ", analyzer.min_temporal_distance() / 60.0, " minutes") +print("Mean temporal distance: ", analyzer.mean_temporal_distance() / 60.0, " minutes") +print("Medan temporal distance: ", analyzer.median_temporal_distance() / 60.0, " minutes") +print("Maximum temporal distance: ", analyzer.max_temporal_distance() / 60.0, " minutes") # Note that the mean and max temporal distances have the value of `direct_walk_duration`, # if there are no journey alternatives departing after (or at the same time as) `ANALYSIS_END_TIME_UT`. # Thus, if you obtain a float('inf') value for some of the temporal distance measures, it could probably be @@ -85,8 +90,9 @@ # use tex in plotting rc("text", usetex=True) -fig1 = analyzer.plot_new_transfer_temporal_distance_profile(timezone=timezone_pytz, - format_string="%H:%M") +fig1 = analyzer.plot_new_transfer_temporal_distance_profile( + timezone=timezone_pytz, format_string="%H:%M" +) fig2 = analyzer.plot_temporal_distance_pdf_horizontal(use_minutes=True) print("Showing...") plt.show() diff --git a/gtfspy/calc_transfers.py b/gtfspy/calc_transfers.py index af823d2..1ed3154 100644 --- a/gtfspy/calc_transfers.py +++ b/gtfspy/calc_transfers.py @@ -9,16 +9,17 @@ from gtfspy.gtfs import GTFS from gtfspy.util import wgs84_distance, wgs84_height, wgs84_width -create_stmt = ('CREATE TABLE IF NOT EXISTS main.stop_distances ' - '(from_stop_I INT, ' - ' to_stop_I INT, ' - ' d INT, ' - ' d_walk INT, ' - ' min_transfer_time INT, ' - ' timed_transfer INT, ' - 'UNIQUE (from_stop_I, to_stop_I)' - ')' - ) +create_stmt = ( + "CREATE TABLE IF NOT EXISTS main.stop_distances " + "(from_stop_I INT, " + " to_stop_I INT, " + " d INT, " + " d_walk INT, " + " min_transfer_time INT, " + " timed_transfer INT, " + "UNIQUE (from_stop_I, to_stop_I)" + ")" +) def bind_functions(conn): @@ -36,11 +37,14 @@ def _get_geo_hash_precision(search_radius_in_km): suggested_precision = precision break if suggested_precision is None: - raise RuntimeError("GeoHash cannot work with this large search radius (km): " + search_radius_in_km) + raise RuntimeError( + "GeoHash cannot work with this large search radius (km): " + search_radius_in_km + ) return suggested_precision + def calc_transfers(conn, threshold_meters=1000): - geohash_precision = _get_geo_hash_precision(threshold_meters / 1000.) + geohash_precision = _get_geo_hash_precision(threshold_meters / 1000.0) geo_index = GeoGridIndex(precision=geohash_precision) g = GTFS(conn) stops = g.get_table("stops") @@ -52,7 +56,9 @@ def calc_transfers(conn, threshold_meters=1000): geo_index.add_point(stop_geopoint) stop_geopoints.append(stop_geopoint) for stop_geopoint in stop_geopoints: - nearby_stop_geopoints = geo_index.get_nearest_points_dirty(stop_geopoint, threshold_meters / 1000.0, "km") + nearby_stop_geopoints = geo_index.get_nearest_points_dirty( + stop_geopoint, threshold_meters / 1000.0, "km" + ) from_stop_I = int(stop_geopoint.ref) from_lat = stop_geopoint.latitude from_lon = stop_geopoint.longitude @@ -71,32 +77,44 @@ def calc_transfers(conn, threshold_meters=1000): distances.append(distance) n_pairs = len(to_stop_Is) - from_stop_Is = [from_stop_I]*n_pairs - cursor.executemany('INSERT OR REPLACE INTO stop_distances VALUES (?, ?, ?, ?, ?, ?);', - zip(from_stop_Is, to_stop_Is, distances, [None]*n_pairs, [None]*n_pairs, [None]*n_pairs)) - cursor.execute('CREATE INDEX IF NOT EXISTS idx_sd_fsid ON stop_distances (from_stop_I);') + from_stop_Is = [from_stop_I] * n_pairs + cursor.executemany( + "INSERT OR REPLACE INTO stop_distances VALUES (?, ?, ?, ?, ?, ?);", + zip( + from_stop_Is, + to_stop_Is, + distances, + [None] * n_pairs, + [None] * n_pairs, + [None] * n_pairs, + ), + ) + cursor.execute("CREATE INDEX IF NOT EXISTS idx_sd_fsid ON stop_distances (from_stop_I);") def _export_transfers(conn, fname): conn = GTFS(conn).conn cur = conn.cursor() - cur.execute('SELECT S1.lat, S1.lon, S2.lat, S2.lon, SD.d ' - 'FROM stop_distances SD ' - ' LEFT JOIN stops S1 ON (SD.from_stop_I=S1.stop_I) ' - ' LEFT JOIN stops S2 ON (SD.to_stop_I =S2.stop_I)') - f = open(fname, 'w') + cur.execute( + "SELECT S1.lat, S1.lon, S2.lat, S2.lon, SD.d " + "FROM stop_distances SD " + " LEFT JOIN stops S1 ON (SD.from_stop_I=S1.stop_I) " + " LEFT JOIN stops S2 ON (SD.to_stop_I =S2.stop_I)" + ) + f = open(fname, "w") for row in cur: - print(' '.join(str(x) for x in row), file=f) + print(" ".join(str(x) for x in row), file=f) def main(): import sys + cmd = sys.argv[1] - if cmd == 'calc': + if cmd == "calc": dbname = sys.argv[2] conn = GTFS(dbname).conn calc_transfers(conn) - elif cmd == 'export': + elif cmd == "export": _export_transfers(sys.argv[2], sys.argv[3]) diff --git a/gtfspy/colormaps.py b/gtfspy/colormaps.py index c6270c7..7192fbc 100644 --- a/gtfspy/colormaps.py +++ b/gtfspy/colormaps.py @@ -1,8 +1,8 @@ -import matplotlib.colors import matplotlib.cm import matplotlib.colorbar +import matplotlib.colors import matplotlib.pyplot -import numpy + # colormaps: "viridis", "plasma_r","seismic" @@ -68,11 +68,9 @@ def get_list_of_colors(values, observable_name=None): colorvalues.append(colorvalue) return colorvalues, norm, cmap + def createcolorbar(cmap, norm): """Create a colourbar with limits of lwr and upr""" cax, kw = matplotlib.colorbar.make_axes(matplotlib.pyplot.gca()) c = matplotlib.colorbar.ColorbarBase(cax, cmap=cmap, norm=norm) return c - - - diff --git a/gtfspy/exports.py b/gtfspy/exports.py index 0e27f94..c26b810 100644 --- a/gtfspy/exports.py +++ b/gtfspy/exports.py @@ -9,8 +9,11 @@ from gtfspy import route_types from gtfspy.gtfs import GTFS from gtfspy import util -from gtfspy.networks import stop_to_stop_networks_by_type, temporal_network, \ - combined_stop_to_stop_transit_network +from gtfspy.networks import ( + stop_to_stop_networks_by_type, + temporal_network, + combined_stop_to_stop_transit_network, +) from gtfspy.route_types import ROUTE_TYPE_TO_ZORDER @@ -22,9 +25,9 @@ def write_walk_transfer_edges(gtfs, output_file_name): output_file_name: str """ transfers = gtfs.get_table("stop_distances") - transfers.drop([u"min_transfer_time", u"timed_transfer"], 1, inplace=True) + transfers.drop(["min_transfer_time", "timed_transfer"], 1, inplace=True) with util.create_file(output_file_name, tmpdir=True, keepext=True) as tmpfile: - transfers.to_csv(tmpfile, encoding='utf-8', index=False) + transfers.to_csv(tmpfile, encoding="utf-8", index=False) def write_nodes(gtfs, output, fields=None): @@ -41,43 +44,33 @@ def write_nodes(gtfs, output, fields=None): if fields is not None: nodes = nodes[fields] with util.create_file(output, tmpdir=True, keepext=True) as tmpfile: - nodes.to_csv(tmpfile, encoding='utf-8', index=False, sep=";") + nodes.to_csv(tmpfile, encoding="utf-8", index=False, sep=";") def create_stops_geojson_dict(gtfs, fields=None): nodes = gtfs.get_table("stops") if fields is None: - fields = {'name': 'stop_name', 'stop_I': 'stop_I', 'lat': 'lat', 'lon': 'lon'} - assert (fields['lat'] == 'lat' and fields['lon'] == 'lon') + fields = {"name": "stop_name", "stop_I": "stop_I", "lat": "lat", "lon": "lon"} + assert fields["lat"] == "lat" and fields["lon"] == "lon" nodes = nodes[list(fields.keys())] nodes.replace(list(fields.keys()), [fields[key] for key in fields.keys()], inplace=True) - assert ('lat' in nodes.columns) - assert ('lon' in nodes.columns) + assert "lat" in nodes.columns + assert "lon" in nodes.columns features = [] for i, node_tuple in enumerate(nodes.itertuples()): - feature = {"type": "Feature", - "id": str(i), - "geometry": { - "type": "Point", - "coordinates": [ - node_tuple.lon, - node_tuple.lat - ] - }, - "properties": { - "stop_I": str(node_tuple.stop_I), - "name": node_tuple.name - } - } + feature = { + "type": "Feature", + "id": str(i), + "geometry": {"type": "Point", "coordinates": [node_tuple.lon, node_tuple.lat]}, + "properties": {"stop_I": str(node_tuple.stop_I), "name": node_tuple.name}, + } features.append(feature) - geojson = { - "type": "FeatureCollection", - "features": features - } + geojson = {"type": "FeatureCollection", "features": features} return geojson + def write_stops_geojson(gtfs, out_file, fields=None): """ Parameters @@ -94,7 +87,7 @@ def write_stops_geojson(gtfs, out_file, fields=None): out_file.write(json.dumps(geojson)) else: with util.create_file(out_file, tmpdir=True, keepext=True) as tmpfile_path: - tmpfile = open(tmpfile_path, 'w') + tmpfile = open(tmpfile_path, "w") tmpfile.write(json.dumps(geojson)) @@ -146,10 +139,12 @@ def write_temporal_networks_by_route_type(gtfs, extract_output_dir): """ util.makedirs(extract_output_dir) for route_type in route_types.TRANSIT_ROUTE_TYPES: - pandas_data_frame = temporal_network(gtfs, start_time_ut=None, end_time_ut=None, route_type=route_type) + pandas_data_frame = temporal_network( + gtfs, start_time_ut=None, end_time_ut=None, route_type=route_type + ) tag = route_types.ROUTE_TYPE_TO_LOWERCASE_TAG[route_type] out_file_name = os.path.join(extract_output_dir, tag + ".tnet") - pandas_data_frame.to_csv(out_file_name, encoding='utf-8', index=False) + pandas_data_frame.to_csv(out_file_name, encoding="utf-8", index=False) def write_temporal_network(gtfs, output_filename, start_time_ut=None, end_time_ut=None): @@ -166,7 +161,7 @@ def write_temporal_network(gtfs, output_filename, start_time_ut=None, end_time_u """ util.makedirs(os.path.dirname(os.path.abspath(output_filename))) pandas_data_frame = temporal_network(gtfs, start_time_ut=start_time_ut, end_time_ut=end_time_ut) - pandas_data_frame.to_csv(output_filename, encoding='utf-8', index=False) + pandas_data_frame.to_csv(output_filename, encoding="utf-8", index=False) def _write_stop_to_stop_network_edges(net, file_name, data=True, fmt=None): @@ -192,7 +187,7 @@ def _write_stop_to_stop_network_edges(net, file_name, data=True, fmt=None): else: networkx.write_edgelist(net, file_name) elif fmt == "csv": - with open(file_name, 'w') as f: + with open(file_name, "w") as f: # writing out the header edge_iter = net.edges_iter(data=True) _, _, edg_data = next(edge_iter) @@ -214,69 +209,75 @@ def _write_stop_to_stop_network_edges(net, file_name, data=True, fmt=None): def create_sections_geojson_dict(G, start_time_ut=None, end_time_ut=None): - multi_di_graph = combined_stop_to_stop_transit_network(G, start_time_ut=start_time_ut, end_time_ut=end_time_ut) + multi_di_graph = combined_stop_to_stop_transit_network( + G, start_time_ut=start_time_ut, end_time_ut=end_time_ut + ) stops = G.get_table("stops") stop_I_to_coords = {row.stop_I: [row.lon, row.lat] for row in stops.itertuples()} gjson = {"type": "FeatureCollection"} features = [] gjson["features"] = features data = list(multi_di_graph.edges(data=True)) - data.sort(key=lambda el: ROUTE_TYPE_TO_ZORDER[el[2]['route_type']]) + data.sort(key=lambda el: ROUTE_TYPE_TO_ZORDER[el[2]["route_type"]]) for from_stop_I, to_stop_I, data in data: feature = {"type": "Feature"} geometry = { "type": "LineString", - 'coordinates': [stop_I_to_coords[from_stop_I], stop_I_to_coords[to_stop_I]] + "coordinates": [stop_I_to_coords[from_stop_I], stop_I_to_coords[to_stop_I]], } - feature['geometry'] = geometry - route_I_counts = data['route_I_counts'] + feature["geometry"] = geometry + route_I_counts = data["route_I_counts"] route_I_counts = {str(key): int(value) for key, value in route_I_counts.items()} - data['route_I_counts'] = route_I_counts + data["route_I_counts"] = route_I_counts properties = data - properties['from_stop_I'] = int(from_stop_I) - properties['to_stop_I'] = int(to_stop_I) - feature['properties'] = data + properties["from_stop_I"] = int(from_stop_I) + properties["to_stop_I"] = int(to_stop_I) + feature["properties"] = data features.append(feature) return gjson + def write_sections_geojson(G, output_file, start_time_ut=None, end_time_ut=None): gjson = create_sections_geojson_dict(G, start_time_ut=start_time_ut, end_time_ut=end_time_ut) if hasattr(output_file, "write"): output_file.write(json.dumps(gjson)) else: - with open(output_file, 'w') as f: + with open(output_file, "w") as f: f.write(json.dumps(gjson)) + def create_routes_geojson_dict(G): - assert(isinstance(G, GTFS)) + assert isinstance(G, GTFS) gjson = {"type": "FeatureCollection"} features = [] for routeShape in G.get_all_route_shapes(use_shapes=False): feature = {"type": "Feature"} geometry = { "type": "LineString", - "coordinates": list(zip(routeShape['lons'], routeShape['lats'])) + "coordinates": list(zip(routeShape["lons"], routeShape["lats"])), + } + feature["geometry"] = geometry + properties = { + "route_type": int(routeShape["type"]), + "route_I": int(routeShape["route_I"]), + "route_name": str(routeShape["name"]), } - feature['geometry'] = geometry - properties = {"route_type": int(routeShape['type']), - "route_I": int(routeShape['route_I']), - "route_name": str(routeShape['name'])} - feature['properties'] = properties + feature["properties"] = properties features.append(feature) - gjson['features'] = features + gjson["features"] = features return gjson + def write_routes_geojson(G, output_file): gjson = create_routes_geojson_dict(G) if hasattr(output_file, "write"): output_file.write(json.dumps(gjson)) else: - with open(output_file, 'w') as f: + with open(output_file, "w") as f: f.write(json.dumps(gjson)) return None - def write_gtfs(gtfs, output): """ Write out the database according to the GTFS format. @@ -294,15 +295,15 @@ def write_gtfs(gtfs, output): """ output = os.path.abspath(output) uuid_str = "tmp_" + str(uuid.uuid1()) - if output[-4:] == '.zip': - zip = True + if output[-4:] == ".zip": + zip_file = True out_basepath = os.path.dirname(os.path.abspath(output)) if not os.path.exists(out_basepath): raise IOError(out_basepath + " does not exist, cannot write gtfs as a zip") tmp_dir = os.path.join(out_basepath, str(uuid_str)) # zip_file_na,e = ../out_basedir + ".zip else: - zip = False + zip_file = False out_basepath = output tmp_dir = os.path.join(out_basepath + "_" + str(uuid_str)) @@ -325,19 +326,18 @@ def write_gtfs(gtfs, output): } for table, writer in gtfs_table_to_writer.items(): - fname_to_write = os.path.join(tmp_dir, table + '.txt') + fname_to_write = os.path.join(tmp_dir, table + ".txt") print(fname_to_write) - writer(gtfs, open(os.path.join(tmp_dir, table + '.txt'), 'w')) + writer(gtfs, open(os.path.join(tmp_dir, table + ".txt"), "w")) - if zip: - shutil.make_archive(output[:-4], 'zip', tmp_dir) + if zip_file: + shutil.make_archive(output[:-4], "zip", tmp_dir) shutil.rmtree(tmp_dir) else: print("moving " + str(tmp_dir) + " to " + out_basepath) os.rename(tmp_dir, out_basepath) - def _remove_I_columns(df): """ Remove columns ending with I from a pandas.DataFrame @@ -355,13 +355,21 @@ def _remove_I_columns(df): del df[column] -def __replace_I_with_id(gtfs, current_table, from_table_name, old_column_current, old_column_from, new_column_in_from, - new_column_name=None): +def __replace_I_with_id( + gtfs, + current_table, + from_table_name, + old_column_current, + old_column_from, + new_column_in_from, + new_column_name=None, +): if new_column_name is None: new_column_name = new_column_in_from from_table = gtfs.get_table(from_table_name) - merged = pandas.merge(current_table, from_table, how="left", - left_on=old_column_current, right_on=old_column_from) + merged = pandas.merge( + current_table, from_table, how="left", left_on=old_column_current, right_on=old_column_from + ) series = pandas.Series(merged[new_column_in_from], index=current_table.index) current_table[new_column_name] = series @@ -369,12 +377,14 @@ def __replace_I_with_id(gtfs, current_table, from_table_name, old_column_current def _write_gtfs_agencies(gtfs, output_file): # remove agency_I agencies_table = gtfs.get_table("agencies") - assert (isinstance(agencies_table, pandas.DataFrame)) - columns_to_change = {'name': 'agency_name', - 'url': 'agency_url', - 'timezone': 'agency_timezone', - 'lang': 'agency_lang', - 'phone': 'agency_phone'} + assert isinstance(agencies_table, pandas.DataFrame) + columns_to_change = { + "name": "agency_name", + "url": "agency_url", + "timezone": "agency_timezone", + "lang": "agency_lang", + "phone": "agency_phone", + } agencies_table = agencies_table.rename(columns=columns_to_change) _remove_I_columns(agencies_table) agencies_table.to_csv(output_file, index=False) @@ -382,14 +392,15 @@ def _write_gtfs_agencies(gtfs, output_file): def _write_gtfs_stops(gtfs, output_file): stops_table = gtfs.get_table("stops") - assert (isinstance(stops_table, pandas.DataFrame)) - columns_to_change = {'name': 'stop_name', - 'url': 'stop_url', - 'lat': 'stop_lat', - 'lon': 'stop_lon', - 'code': 'stop_code', - 'desc': 'stop_desc' - } + assert isinstance(stops_table, pandas.DataFrame) + columns_to_change = { + "name": "stop_name", + "url": "stop_url", + "lat": "stop_lat", + "lon": "stop_lon", + "code": "stop_code", + "desc": "stop_desc", + } stops_table = stops_table.rename(columns=columns_to_change) # Remove stop_I @@ -401,27 +412,28 @@ def _write_gtfs_stops(gtfs, output_file): except KeyError: parent_station = "" parent_stations.append(parent_station) - stops_table['parent_station'] = pandas.Series(parent_stations, index=stops_table.index) + stops_table["parent_station"] = pandas.Series(parent_stations, index=stops_table.index) _remove_I_columns(stops_table) stops_table.to_csv(output_file, index=False) def _write_gtfs_routes(gtfs, output_file): routes_table = gtfs.get_table("routes") - columns_to_change = {'name': 'route_short_name', - 'long_name': 'route_long_name', - 'url': 'route_url', - 'type': 'route_type', - 'desc': 'route_desc', - 'color': 'route_color', - 'text_color': 'route_text_color' - } + columns_to_change = { + "name": "route_short_name", + "long_name": "route_long_name", + "url": "route_url", + "type": "route_type", + "desc": "route_desc", + "color": "route_color", + "text_color": "route_text_color", + } routes_table = routes_table.rename(columns=columns_to_change) # replace agency_I agencies_table = gtfs.get_table("agencies") - agency_ids = pandas.merge(routes_table, agencies_table, how="left", on='agency_I')['agency_id'] - routes_table['agency_id'] = pandas.Series(agency_ids, index=routes_table.index) + agency_ids = pandas.merge(routes_table, agencies_table, how="left", on="agency_I")["agency_id"] + routes_table["agency_id"] = pandas.Series(agency_ids, index=routes_table.index) _remove_I_columns(routes_table) routes_table.to_csv(output_file, index=False) @@ -429,104 +441,116 @@ def _write_gtfs_routes(gtfs, output_file): def _write_gtfs_trips(gtfs, output_file): trips_table = gtfs.get_table("trips") columns_to_change = { - 'headsign': 'trip_headsign', + "headsign": "trip_headsign", } trips_table = trips_table.rename(columns=columns_to_change) - __replace_I_with_id(gtfs, trips_table, 'routes', 'route_I', 'route_I', 'route_id') - __replace_I_with_id(gtfs, trips_table, 'calendar', 'service_I', 'service_I', 'service_id') + __replace_I_with_id(gtfs, trips_table, "routes", "route_I", "route_I", "route_id") + __replace_I_with_id(gtfs, trips_table, "calendar", "service_I", "service_I", "service_id") _remove_I_columns(trips_table) - del [trips_table['start_time_ds']] - del [trips_table['end_time_ds']] + del [trips_table["start_time_ds"]] + del [trips_table["end_time_ds"]] trips_table.to_csv(output_file, index=False) def _write_gtfs_stop_times(gtfs, output_file): - stop_times_table = gtfs.get_table('stop_times') + stop_times_table = gtfs.get_table("stop_times") columns_to_change = { - 'seq': 'stop_sequence', - 'arr_time': 'arrival_time', - 'dep_time': 'departure_time' + "seq": "stop_sequence", + "arr_time": "arrival_time", + "dep_time": "departure_time", } stop_times_table = stop_times_table.rename(columns=columns_to_change) # replace trip_I with trip_id - __replace_I_with_id(gtfs, stop_times_table, 'trips', 'trip_I', 'trip_I', 'trip_id') - __replace_I_with_id(gtfs, stop_times_table, 'stops', 'stop_I', 'stop_I', 'stop_id') + __replace_I_with_id(gtfs, stop_times_table, "trips", "trip_I", "trip_I", "trip_id") + __replace_I_with_id(gtfs, stop_times_table, "stops", "stop_I", "stop_I", "stop_id") # delete unneeded columns: - del [stop_times_table['arr_time_hour']] - del [stop_times_table['arr_time_ds']] - del [stop_times_table['dep_time_ds']] - del [stop_times_table['shape_break']] + del [stop_times_table["arr_time_hour"]] + del [stop_times_table["arr_time_ds"]] + del [stop_times_table["dep_time_ds"]] + del [stop_times_table["shape_break"]] _remove_I_columns(stop_times_table) stop_times_table.to_csv(output_file, index=False) def _write_gtfs_calendar(gtfs, output_file): - calendar_table = gtfs.get_table('calendar') + calendar_table = gtfs.get_table("calendar") columns_to_change = { - 'm': 'monday', - 't': 'tuesday', - 'w': 'wednesday', - 'th': 'thursday', - 'f': 'friday', - 's': 'saturday', - 'su': 'sunday' + "m": "monday", + "t": "tuesday", + "w": "wednesday", + "th": "thursday", + "f": "friday", + "s": "saturday", + "su": "sunday", } calendar_table = calendar_table.rename(columns=columns_to_change) - calendar_table['start_date'] = [date.replace("-", "") for date in calendar_table['start_date']] - calendar_table['end_date'] = [date.replace("-", "") for date in calendar_table['end_date']] + calendar_table["start_date"] = [date.replace("-", "") for date in calendar_table["start_date"]] + calendar_table["end_date"] = [date.replace("-", "") for date in calendar_table["end_date"]] _remove_I_columns(calendar_table) calendar_table.to_csv(output_file, index=False) def _write_gtfs_calendar_dates(gtfs, output_file): - calendar_dates_table = gtfs.get_table('calendar_dates') - __replace_I_with_id(gtfs, calendar_dates_table, 'calendar', 'service_I', 'service_I', 'service_id') + calendar_dates_table = gtfs.get_table("calendar_dates") + __replace_I_with_id( + gtfs, calendar_dates_table, "calendar", "service_I", "service_I", "service_id" + ) _remove_I_columns(calendar_dates_table) calendar_dates_table.to_csv(output_file, index=False) def _write_gtfs_shapes(gtfs, ouput_file): - shapes_table = gtfs.get_table('shapes') + shapes_table = gtfs.get_table("shapes") columns_to_change = { - 'lat': 'shape_pt_lat', - 'lon': 'shape_pt_lon', - 'seq': 'shape_pt_sequence', - 'd': 'shape_dist_traveled' + "lat": "shape_pt_lat", + "lon": "shape_pt_lon", + "seq": "shape_pt_sequence", + "d": "shape_dist_traveled", } shapes_table = shapes_table.rename(columns=columns_to_change) shapes_table.to_csv(ouput_file, index=False) def _write_gtfs_feed_info(gtfs, output_file): - gtfs.get_table('feed_info').to_csv(output_file, index=False) + gtfs.get_table("feed_info").to_csv(output_file, index=False) def _write_gtfs_frequencies(gtfs, output_file): - raise NotImplementedError("Frequencies should not be outputted from GTFS as they are included in other tables.") + raise NotImplementedError( + "Frequencies should not be outputted from GTFS as they are included in other tables." + ) def _write_gtfs_transfers(gtfs, output_file): - transfers_table = gtfs.get_table('transfers') - __replace_I_with_id(gtfs, transfers_table, 'stops', 'from_stop_I', 'stop_I', 'stop_id', 'from_stop_id') - __replace_I_with_id(gtfs, transfers_table, 'stops', 'to_stop_I', 'stop_I', 'stop_id', 'to_stop_id') + transfers_table = gtfs.get_table("transfers") + __replace_I_with_id( + gtfs, transfers_table, "stops", "from_stop_I", "stop_I", "stop_id", "from_stop_id" + ) + __replace_I_with_id( + gtfs, transfers_table, "stops", "to_stop_I", "stop_I", "stop_id", "to_stop_id" + ) _remove_I_columns(transfers_table) transfers_table.to_csv(output_file, index=False) def _write_gtfs_stop_distances(gtfs, output_file): - stop_distances = gtfs.get_table('stop_distances') - __replace_I_with_id(gtfs, stop_distances, 'stops', 'from_stop_I', 'stop_I', 'stop_id', 'from_stop_id') - __replace_I_with_id(gtfs, stop_distances, 'stops', 'to_stop_I', 'stop_I', 'stop_id', 'to_stop_id') + stop_distances = gtfs.get_table("stop_distances") + __replace_I_with_id( + gtfs, stop_distances, "stops", "from_stop_I", "stop_I", "stop_id", "from_stop_id" + ) + __replace_I_with_id( + gtfs, stop_distances, "stops", "to_stop_I", "stop_I", "stop_id", "to_stop_id" + ) _remove_I_columns(stop_distances) - del stop_distances['min_transfer_time'] - del stop_distances['timed_transfer'] + del stop_distances["min_transfer_time"] + del stop_distances["timed_transfer"] stop_distances.to_csv(output_file, index=False) @@ -539,28 +563,33 @@ def _write_gtfs_stop_distances(gtfs, output_file): # stop_times_table['departure_time'] = pandas.Series(departure_times, stop_times_table.index) - - def main(): import argparse - parser = argparse.ArgumentParser(description="Create network extracts from already imported GTFS files.") - subparsers = parser.add_subparsers(dest='cmd') + parser = argparse.ArgumentParser( + description="Create network extracts from already imported GTFS files." + ) + subparsers = parser.add_subparsers(dest="cmd") # parsing import - parser_routingnets = subparsers.add_parser('extract_temporal', help="Direct import GTFS->sqlite") - parser_routingnets.add_argument('gtfs', help='Input GTFS .sqlite (must end in .sqlite)') - parser_routingnets.add_argument('basename', help='Basename for the output files') # Parsing copy + parser_routingnets = subparsers.add_parser( + "extract_temporal", help="Direct import GTFS->sqlite" + ) + parser_routingnets.add_argument("gtfs", help="Input GTFS .sqlite (must end in .sqlite)") + parser_routingnets.add_argument( + "basename", help="Basename for the output files" + ) # Parsing copy args = parser.parse_args() # if the first argument is import, import a GTFS directory to a .sqlite database. # Both directory and - if args.cmd == 'extract_temporal': + if args.cmd == "extract_temporal": gtfs_fname = args.gtfs output_basename = args.basename from gtfspy.gtfs import GTFS + gtfs = GTFS(gtfs_fname) nodes_filename = output_basename + ".nodes.csv" @@ -582,5 +611,3 @@ def main(): if __name__ == "__main__": main() - - diff --git a/gtfspy/extended_route_types.py b/gtfspy/extended_route_types.py index 6c7933b..7c0b528 100644 --- a/gtfspy/extended_route_types.py +++ b/gtfspy/extended_route_types.py @@ -46,5 +46,5 @@ 1701: 5, 1702: 3, 1100: 99, - 1104: 99 + 1104: 99, } diff --git a/gtfspy/filter.py b/gtfspy/filter.py index 1d37932..92669ae 100644 --- a/gtfspy/filter.py +++ b/gtfspy/filter.py @@ -1,56 +1,81 @@ -import time +import datetime +import logging import os import shutil -import logging import sqlite3 -import datetime +import time import pandas import gtfspy +from gtfspy import gtfs +from gtfspy import stats from gtfspy import util from gtfspy.gtfs import GTFS from gtfspy.import_loaders.day_loader import recreate_days_table from gtfspy.import_loaders.day_trips_materializer import recreate_day_trips2_table from gtfspy.import_loaders.stop_times_loader import resequence_stop_times_seq_values from gtfspy.import_loaders.trip_loader import update_trip_travel_times_ds -from gtfspy.util import wgs84_distance, set_process_timezone -from gtfspy import stats -from gtfspy import gtfs +from gtfspy.util import wgs84_distance FILTERED = True NOT_FILTERED = False -DELETE_FREQUENCIES_NOT_REFERENCED_IN_TRIPS_SQL = "DELETE FROM frequencies WHERE trip_I NOT IN (SELECT DISTINCT trip_I FROM trips)" -DELETE_SHAPES_NOT_REFERENCED_IN_TRIPS_SQL = 'DELETE FROM shapes WHERE shape_id NOT IN (SELECT shape_id FROM trips)' -DELETE_ROUTES_NOT_PRESENT_IN_TRIPS_SQL = 'DELETE FROM routes WHERE route_I NOT IN (SELECT route_I FROM trips)' -DELETE_DAYS_ENTRIES_NOT_PRESENT_IN_TRIPS_SQL = "DELETE FROM days WHERE trip_I NOT IN (SELECT trip_I FROM trips)" -DELETE_DAY_TRIPS2_ENTRIES_NOT_PRESENT_IN_TRIPS_SQL = "DELETE FROM day_trips2 WHERE trip_I NOT IN (SELECT trip_I FROM trips)" -DELETE_FREQUENCIES_ENTRIES_NOT_PRESENT_IN_TRIPS = "DELETE FROM frequencies WHERE trip_I NOT IN (SELECT trip_I FROM trips)" -DELETE_CALENDAR_ENTRIES_FOR_NON_REFERENCE_SERVICE_IS_SQL = "DELETE FROM calendar WHERE service_I NOT IN (SELECT distinct(service_I) FROM trips)" -DELETE_CALENDAR_DATES_ENTRIES_FOR_NON_REFERENCE_SERVICE_IS_SQL = "DELETE FROM calendar_dates WHERE service_I NOT IN (SELECT distinct(service_I) FROM trips)" -DELETE_AGENCIES_NOT_REFERENCED_IN_ROUTES_SQL = "DELETE FROM agencies WHERE agency_I NOT IN (SELECT distinct(agency_I) FROM routes)" -DELETE_STOP_TIMES_NOT_REFERENCED_IN_TRIPS_SQL = 'DELETE FROM stop_times WHERE trip_I NOT IN (SELECT trip_I FROM trips)' -DELETE_STOP_DISTANCE_ENTRIES_WITH_NONEXISTENT_STOPS_SQL = "DELETE FROM stop_distances " \ - "WHERE from_stop_I NOT IN (SELECT stop_I FROM stops) " \ - " OR to_stop_I NOT IN (SELECT stop_I FROM stops)" -DELETE_TRIPS_NOT_IN_DAYS_SQL = 'DELETE FROM trips WHERE trip_I NOT IN (SELECT trip_I FROM days)' -DELETE_TRIPS_NOT_REFERENCED_IN_STOP_TIMES = 'DELETE FROM trips WHERE trip_I NOT IN (SELECT trip_I FROM stop_times)' +DELETE_FREQUENCIES_NOT_REFERENCED_IN_TRIPS_SQL = ( + "DELETE FROM frequencies WHERE trip_I NOT IN (SELECT DISTINCT trip_I FROM trips)" +) +DELETE_SHAPES_NOT_REFERENCED_IN_TRIPS_SQL = ( + "DELETE FROM shapes WHERE shape_id NOT IN (SELECT shape_id FROM trips)" +) +DELETE_ROUTES_NOT_PRESENT_IN_TRIPS_SQL = ( + "DELETE FROM routes WHERE route_I NOT IN (SELECT route_I FROM trips)" +) +DELETE_DAYS_ENTRIES_NOT_PRESENT_IN_TRIPS_SQL = ( + "DELETE FROM days WHERE trip_I NOT IN (SELECT trip_I FROM trips)" +) +DELETE_DAY_TRIPS2_ENTRIES_NOT_PRESENT_IN_TRIPS_SQL = ( + "DELETE FROM day_trips2 WHERE trip_I NOT IN (SELECT trip_I FROM trips)" +) +DELETE_FREQUENCIES_ENTRIES_NOT_PRESENT_IN_TRIPS = ( + "DELETE FROM frequencies WHERE trip_I NOT IN (SELECT trip_I FROM trips)" +) +DELETE_CALENDAR_ENTRIES_FOR_NON_REFERENCE_SERVICE_IS_SQL = ( + "DELETE FROM calendar WHERE service_I NOT IN (SELECT distinct(service_I) FROM trips)" +) +DELETE_CALENDAR_DATES_ENTRIES_FOR_NON_REFERENCE_SERVICE_IS_SQL = ( + "DELETE FROM calendar_dates WHERE service_I NOT IN (SELECT distinct(service_I) FROM trips)" +) +DELETE_AGENCIES_NOT_REFERENCED_IN_ROUTES_SQL = ( + "DELETE FROM agencies WHERE agency_I NOT IN (SELECT distinct(agency_I) FROM routes)" +) +DELETE_STOP_TIMES_NOT_REFERENCED_IN_TRIPS_SQL = ( + "DELETE FROM stop_times WHERE trip_I NOT IN (SELECT trip_I FROM trips)" +) +DELETE_STOP_DISTANCE_ENTRIES_WITH_NONEXISTENT_STOPS_SQL = ( + "DELETE FROM stop_distances " + "WHERE from_stop_I NOT IN (SELECT stop_I FROM stops) " + " OR to_stop_I NOT IN (SELECT stop_I FROM stops)" +) +DELETE_TRIPS_NOT_IN_DAYS_SQL = "DELETE FROM trips WHERE trip_I NOT IN (SELECT trip_I FROM days)" +DELETE_TRIPS_NOT_REFERENCED_IN_STOP_TIMES = ( + "DELETE FROM trips WHERE trip_I NOT IN (SELECT trip_I FROM stop_times)" +) class FilterExtract(object): - - def __init__(self, - G, - copy_db_path, - buffer_distance_km=None, - buffer_lat=None, - buffer_lon=None, - update_metadata=True, - start_date=None, - end_date=None, - agency_ids_to_preserve=None, - agency_distance=None): + def __init__( + self, + G, + copy_db_path, + buffer_distance_km=None, + buffer_lat=None, + buffer_lon=None, + update_metadata=True, + start_date=None, + end_date=None, + agency_ids_to_preserve=None, + agency_distance=None, + ): """ Copy a database, and then based on various filters. Only method `create_filtered_copy` is provided as we do not want to take the risk of @@ -122,9 +147,12 @@ def __init__(self, self.this_db_path = self.gtfs.get_main_database_path() assert os.path.exists(self.this_db_path), "Copying of in-memory databases is not supported" - assert os.path.exists(os.path.dirname(os.path.abspath(copy_db_path))), \ - "the directory where the copied database will reside should exist beforehand" - assert not os.path.exists(copy_db_path), "the resulting database exists already: %s" % copy_db_path + assert os.path.exists( + os.path.dirname(os.path.abspath(copy_db_path)) + ), "the directory where the copied database will reside should exist beforehand" + assert not os.path.exists(copy_db_path), ( + "the resulting database exists already: %s" % copy_db_path + ) def create_filtered_copy(self): # this with statement @@ -137,8 +165,8 @@ def create_filtered_copy(self): filtered = False filtered = self._delete_rows_by_start_and_end_date() or filtered - if self.copy_db_conn.execute('SELECT count(*) FROM days').fetchone() == (0,): - raise ValueError('No data left after filtering') + if self.copy_db_conn.execute("SELECT count(*) FROM days").fetchone() == (0,): + raise ValueError("No data left after filtering") filtered = self._filter_by_calendar() or filtered filtered = self._filter_by_agency() or filtered filtered = self._filter_spatially() or filtered @@ -162,36 +190,43 @@ def _delete_rows_by_start_and_end_date(self): if (self.start_date is not None) and (self.end_date is not None): start_date_ut = self.gtfs.get_day_start_ut(self.start_date) end_date_ut = self.gtfs.get_day_start_ut(self.end_date) - if self.copy_db_conn.execute("SELECT count(*) FROM day_trips2 WHERE start_time_ut IS null " - "OR end_time_ut IS null").fetchone() != (0,): - raise ValueError("Missing information in day_trips2 (start_time_ut and/or end_time_ut), " - "check trips.start_time_ds and trips.end_time_ds.") + if self.copy_db_conn.execute( + "SELECT count(*) FROM day_trips2 WHERE start_time_ut IS null " + "OR end_time_ut IS null" + ).fetchone() != (0,): + raise ValueError( + "Missing information in day_trips2 (start_time_ut and/or end_time_ut), " + "check trips.start_time_ds and trips.end_time_ds." + ) logging.info("Filtering based on start_time_ut and end_time_ut") table_to_preserve_map = { "calendar": "start_date < date({filter_end_ut}, 'unixepoch', 'localtime') " - "AND " - "end_date >= date({filter_start_ut}, 'unixepoch', 'localtime') ", + "AND " + "end_date >= date({filter_start_ut}, 'unixepoch', 'localtime') ", "calendar_dates": "date >= date({filter_start_ut}, 'unixepoch', 'localtime') " - "AND " - "date < date({filter_end_ut}, 'unixepoch', 'localtime') ", - "day_trips2": 'start_time_ut < {filter_end_ut} ' - 'AND ' - 'end_time_ut > {filter_start_ut} ', + "AND " + "date < date({filter_end_ut}, 'unixepoch', 'localtime') ", + "day_trips2": "start_time_ut < {filter_end_ut} " + "AND " + "end_time_ut > {filter_start_ut} ", "days": "day_start_ut >= {filter_start_ut} " - "AND " - "day_start_ut < {filter_end_ut} " - } - table_to_remove_map = {key: "WHERE NOT ( " + to_preserve + " );" - for key, to_preserve in table_to_preserve_map.items() } + "AND " + "day_start_ut < {filter_end_ut} ", + } + table_to_remove_map = { + key: "WHERE NOT ( " + to_preserve + " );" + for key, to_preserve in table_to_preserve_map.items() + } # Ensure that process timezone is correct as we rely on 'localtime' in the SQL statements. GTFS(self.copy_db_conn).set_current_process_time_zone() # remove the 'source' entries from tables for table, query_template in table_to_remove_map.items(): - param_dict = {"filter_start_ut": str(start_date_ut), - "filter_end_ut": str(end_date_ut)} - query = "DELETE FROM " + table + " " + \ - query_template.format(**param_dict) + param_dict = { + "filter_start_ut": str(start_date_ut), + "filter_end_ut": str(end_date_ut), + } + query = "DELETE FROM " + table + " " + query_template.format(**param_dict) self.copy_db_conn.execute(query) self.copy_db_conn.commit() @@ -242,15 +277,20 @@ def _filter_by_calendar(self): if (self.start_date is not None) and (self.end_date is not None): logging.info("Making date extract") - start_date_query = "UPDATE calendar " \ - "SET start_date='{start_date}' " \ - "WHERE start_date<'{start_date}' ".format(start_date=self.start_date) + start_date_query = ( + "UPDATE calendar " + "SET start_date='{start_date}' " + "WHERE start_date<'{start_date}' ".format(start_date=self.start_date) + ) self.copy_db_conn.execute(start_date_query) - end_date_query = "UPDATE calendar " \ - "SET end_date='{end_date_to_include}' " \ - "WHERE end_date>'{end_date_to_include}' " \ - .format(end_date_to_include=self.end_date_to_include_str) + end_date_query = ( + "UPDATE calendar " + "SET end_date='{end_date_to_include}' " + "WHERE end_date>'{end_date_to_include}' ".format( + end_date_to_include=self.end_date_to_include_str + ) + ) self.copy_db_conn.execute(end_date_query) # then recursively delete further data: @@ -279,29 +319,38 @@ def _filter_by_agency(self): agencies = pandas.read_sql("SELECT * FROM agencies", self.copy_db_conn) agencies_to_remove = [] for idx, row in agencies.iterrows(): - if row['agency_id'] not in agency_ids_to_preserve: - agencies_to_remove.append(row['agency_id']) + if row["agency_id"] not in agency_ids_to_preserve: + agencies_to_remove.append(row["agency_id"]) for agency_id in agencies_to_remove: - self.copy_db_conn.execute('DELETE FROM agencies WHERE agency_id=?', (agency_id,)) + self.copy_db_conn.execute("DELETE FROM agencies WHERE agency_id=?", (agency_id,)) # and remove recursively related to the agencies: - self.copy_db_conn.execute('DELETE FROM routes WHERE ' - 'agency_I NOT IN (SELECT agency_I FROM agencies)') - self.copy_db_conn.execute('DELETE FROM trips WHERE ' - 'route_I NOT IN (SELECT route_I FROM routes)') - self.copy_db_conn.execute('DELETE FROM calendar WHERE ' - 'service_I NOT IN (SELECT service_I FROM trips)') - self.copy_db_conn.execute('DELETE FROM calendar_dates WHERE ' - 'service_I NOT IN (SELECT service_I FROM trips)') - self.copy_db_conn.execute('DELETE FROM days WHERE ' - 'trip_I NOT IN (SELECT trip_I FROM trips)') - self.copy_db_conn.execute('DELETE FROM stop_times WHERE ' - 'trip_I NOT IN (SELECT trip_I FROM trips)') - self.copy_db_conn.execute('DELETE FROM stop_times WHERE ' - 'trip_I NOT IN (SELECT trip_I FROM trips)') - self.copy_db_conn.execute('DELETE FROM shapes WHERE ' - 'shape_id NOT IN (SELECT shape_id FROM trips)') - self.copy_db_conn.execute('DELETE FROM day_trips2 WHERE ' - 'trip_I NOT IN (SELECT trip_I FROM trips)') + self.copy_db_conn.execute( + "DELETE FROM routes WHERE " "agency_I NOT IN (SELECT agency_I FROM agencies)" + ) + self.copy_db_conn.execute( + "DELETE FROM trips WHERE " "route_I NOT IN (SELECT route_I FROM routes)" + ) + self.copy_db_conn.execute( + "DELETE FROM calendar WHERE " "service_I NOT IN (SELECT service_I FROM trips)" + ) + self.copy_db_conn.execute( + "DELETE FROM calendar_dates WHERE " "service_I NOT IN (SELECT service_I FROM trips)" + ) + self.copy_db_conn.execute( + "DELETE FROM days WHERE " "trip_I NOT IN (SELECT trip_I FROM trips)" + ) + self.copy_db_conn.execute( + "DELETE FROM stop_times WHERE " "trip_I NOT IN (SELECT trip_I FROM trips)" + ) + self.copy_db_conn.execute( + "DELETE FROM stop_times WHERE " "trip_I NOT IN (SELECT trip_I FROM trips)" + ) + self.copy_db_conn.execute( + "DELETE FROM shapes WHERE " "shape_id NOT IN (SELECT shape_id FROM trips)" + ) + self.copy_db_conn.execute( + "DELETE FROM day_trips2 WHERE " "trip_I NOT IN (SELECT trip_I FROM trips)" + ) self.copy_db_conn.commit() return FILTERED else: @@ -321,14 +370,21 @@ def _filter_spatially(self): if self.buffer_lat is None or self.buffer_lon is None or self.buffer_distance_km is None: return NOT_FILTERED - print("filtering with lat: " + str(self.buffer_lat) + - " lon: " + str(self.buffer_lon) + - " buffer distance: " + str(self.buffer_distance_km)) - remove_all_trips_fully_outside_buffer(self.copy_db_conn, - self.buffer_lat, - self.buffer_lon, - self.buffer_distance_km, - update_secondary_data=False) + print( + "filtering with lat: " + + str(self.buffer_lat) + + " lon: " + + str(self.buffer_lon) + + " buffer distance: " + + str(self.buffer_distance_km) + ) + remove_all_trips_fully_outside_buffer( + self.copy_db_conn, + self.buffer_lat, + self.buffer_lon, + self.buffer_distance_km, + update_secondary_data=False, + ) logging.info("Making spatial extract") find_distance_func_name = add_wgs84_distance_function_to_db(self.copy_db_conn) @@ -336,16 +392,18 @@ def _filter_spatially(self): # select all stops that are within the buffer and have some stop_times assigned. stop_distance_filter_sql_base = ( - "SELECT DISTINCT stops.stop_I FROM stops, stop_times" + - " WHERE CAST(find_distance(lat, lon, {buffer_lat}, {buffer_lon}) AS INT) < {buffer_distance_meters}" + - " AND stops.stop_I=stop_times.stop_I" + "SELECT DISTINCT stops.stop_I FROM stops, stop_times" + + " WHERE CAST(find_distance(lat, lon, {buffer_lat}, {buffer_lon}) AS INT) < {buffer_distance_meters}" + + " AND stops.stop_I=stop_times.stop_I" ) stops_within_buffer_sql = stop_distance_filter_sql_base.format( buffer_lat=float(self.buffer_lat), buffer_lon=float(self.buffer_lon), - buffer_distance_meters=int(self.buffer_distance_km * 1000) + buffer_distance_meters=int(self.buffer_distance_km * 1000), + ) + stops_within_buffer = set( + row[0] for row in self.copy_db_conn.execute(stops_within_buffer_sql) ) - stops_within_buffer = set(row[0] for row in self.copy_db_conn.execute(stops_within_buffer_sql)) # For each trip_I, find smallest (min_seq) and largest (max_seq) stop sequence numbers that # are within the soft buffer_distance from the buffer_lon and buffer_lat, and add them into the @@ -353,12 +411,14 @@ def _filter_spatially(self): # Note that if a trip is OUT-IN-OUT-IN-OUT, this process preserves (at least) the part IN-OUT-IN of the trip. # Repeat until no more stops are found. - stops_within_buffer_string = "(" +",".join(str(stop_I) for stop_I in stops_within_buffer) + ")" - trip_min_max_include_seq_sql = ( - 'SELECT trip_I, min(seq) AS min_seq, max(seq) AS max_seq FROM stop_times, stops ' - 'WHERE stop_times.stop_I = stops.stop_I ' - ' AND stops.stop_I IN {stop_I_list}' - ' GROUP BY trip_I' + stops_within_buffer_string = ( + "(" + ",".join(str(stop_I) for stop_I in stops_within_buffer) + ")" + ) + trip_min_max_include_seq_sql = ( + "SELECT trip_I, min(seq) AS min_seq, max(seq) AS max_seq FROM stop_times, stops " + "WHERE stop_times.stop_I = stops.stop_I " + " AND stops.stop_I IN {stop_I_list}" + " GROUP BY trip_I" ).format(stop_I_list=stops_within_buffer_string) trip_I_min_seq_max_seq_df = pandas.read_sql(trip_min_max_include_seq_sql, self.copy_db_conn) @@ -369,57 +429,71 @@ def _filter_spatially(self): # DELETE FROM STOP_TIMES if min_seq == max_seq: # Only one entry in stop_times to be left, remove whole trip. - self.copy_db_conn.execute("DELETE FROM stop_times WHERE trip_I={trip_I}".format(trip_I=trip_I)) - self.copy_db_conn.execute("DELETE FROM trips WHERE trip_i={trip_I}".format(trip_I=trip_I)) + self.copy_db_conn.execute( + "DELETE FROM stop_times WHERE trip_I={trip_I}".format(trip_I=trip_I) + ) + self.copy_db_conn.execute( + "DELETE FROM trips WHERE trip_i={trip_I}".format(trip_I=trip_I) + ) else: # DELETE STOP_TIME ENTRIES BEFORE ENTERING AND AFTER DEPARTING THE BUFFER AREA - DELETE_STOP_TIME_ENTRIES_SQL = \ - "DELETE FROM stop_times WHERE trip_I={trip_I} AND (seq<{min_seq} OR seq>{max_seq})"\ - .format(trip_I=trip_I, max_seq=max_seq, min_seq=min_seq) + DELETE_STOP_TIME_ENTRIES_SQL = "DELETE FROM stop_times WHERE trip_I={trip_I} AND (seq<{min_seq} OR seq>{max_seq})".format( + trip_I=trip_I, max_seq=max_seq, min_seq=min_seq + ) self.copy_db_conn.execute(DELETE_STOP_TIME_ENTRIES_SQL) - STOPS_NOT_WITHIN_BUFFER__FOR_TRIP_SQL = \ - "SELECT seq, stop_I IN {stops_within_hard_buffer} AS within FROM stop_times WHERE trip_I={trip_I} ORDER BY seq"\ - .format(stops_within_hard_buffer=stops_within_buffer_string, trip_I=trip_I) - stop_times_within_buffer_df = pandas.read_sql(STOPS_NOT_WITHIN_BUFFER__FOR_TRIP_SQL, self.copy_db_conn) - if stop_times_within_buffer_df['within'].all(): + STOPS_NOT_WITHIN_BUFFER__FOR_TRIP_SQL = "SELECT seq, stop_I IN {stops_within_hard_buffer} AS within FROM stop_times WHERE trip_I={trip_I} ORDER BY seq".format( + stops_within_hard_buffer=stops_within_buffer_string, trip_I=trip_I + ) + stop_times_within_buffer_df = pandas.read_sql( + STOPS_NOT_WITHIN_BUFFER__FOR_TRIP_SQL, self.copy_db_conn + ) + if stop_times_within_buffer_df["within"].all(): continue else: _split_trip(self.copy_db_conn, trip_I, stop_times_within_buffer_df) - # Delete all shapes that are not fully within the buffer to avoid shapes going outside # the buffer area in a some cases. # This could probably be done in some more sophisticated way though (per trip) - SHAPE_IDS_NOT_WITHIN_BUFFER_SQL = \ - "SELECT DISTINCT shape_id FROM SHAPES " \ - "WHERE CAST(find_distance(lat, lon, {buffer_lat}, {buffer_lon}) AS INT) > {buffer_distance_meters}" \ - .format(buffer_lat=self.buffer_lat, - buffer_lon=self.buffer_lon, - buffer_distance_meters=self.buffer_distance_km * 1000) - DELETE_ALL_SHAPE_IDS_NOT_WITHIN_BUFFER_SQL = "DELETE FROM shapes WHERE shape_id IN (" \ - + SHAPE_IDS_NOT_WITHIN_BUFFER_SQL + ")" + SHAPE_IDS_NOT_WITHIN_BUFFER_SQL = ( + "SELECT DISTINCT shape_id FROM SHAPES " + "WHERE CAST(find_distance(lat, lon, {buffer_lat}, {buffer_lon}) AS INT) > {buffer_distance_meters}".format( + buffer_lat=self.buffer_lat, + buffer_lon=self.buffer_lon, + buffer_distance_meters=self.buffer_distance_km * 1000, + ) + ) + DELETE_ALL_SHAPE_IDS_NOT_WITHIN_BUFFER_SQL = ( + "DELETE FROM shapes WHERE shape_id IN (" + SHAPE_IDS_NOT_WITHIN_BUFFER_SQL + ")" + ) self.copy_db_conn.execute(DELETE_ALL_SHAPE_IDS_NOT_WITHIN_BUFFER_SQL) - SET_SHAPE_ID_TO_NULL_FOR_HARD_BUFFER_FILTERED_SHAPE_IDS = \ - "UPDATE trips SET shape_id=NULL WHERE trips.shape_id IN (" + SHAPE_IDS_NOT_WITHIN_BUFFER_SQL + ")" + SET_SHAPE_ID_TO_NULL_FOR_HARD_BUFFER_FILTERED_SHAPE_IDS = ( + "UPDATE trips SET shape_id=NULL WHERE trips.shape_id IN (" + + SHAPE_IDS_NOT_WITHIN_BUFFER_SQL + + ")" + ) self.copy_db_conn.execute(SET_SHAPE_ID_TO_NULL_FOR_HARD_BUFFER_FILTERED_SHAPE_IDS) - # Delete trips with only one stop - self.copy_db_conn.execute('DELETE FROM stop_times WHERE ' - 'trip_I IN (SELECT trip_I FROM ' - '(SELECT trip_I, count(*) AS N_stops from stop_times ' - 'GROUP BY trip_I) q1 ' - 'WHERE N_stops = 1)') + self.copy_db_conn.execute( + "DELETE FROM stop_times WHERE " + "trip_I IN (SELECT trip_I FROM " + "(SELECT trip_I, count(*) AS N_stops from stop_times " + "GROUP BY trip_I) q1 " + "WHERE N_stops = 1)" + ) # Delete trips with only one stop but several instances in stop_times - self.copy_db_conn.execute('DELETE FROM stop_times WHERE ' - 'trip_I IN (SELECT q1.trip_I AS trip_I FROM ' - '(SELECT trip_I, stop_I, count(*) AS stops_per_stop FROM stop_times ' - 'GROUP BY trip_I, stop_I) q1, ' - '(SELECT trip_I, count(*) as n_stops FROM stop_times ' - 'GROUP BY trip_I) q2 ' - 'WHERE q1.trip_I = q2.trip_I AND n_stops = stops_per_stop)') + self.copy_db_conn.execute( + "DELETE FROM stop_times WHERE " + "trip_I IN (SELECT q1.trip_I AS trip_I FROM " + "(SELECT trip_I, stop_I, count(*) AS stops_per_stop FROM stop_times " + "GROUP BY trip_I, stop_I) q1, " + "(SELECT trip_I, count(*) as n_stops FROM stop_times " + "GROUP BY trip_I) q2 " + "WHERE q1.trip_I = q2.trip_I AND n_stops = stops_per_stop)" + ) # Delete all stop_times for uncovered stops delete_stops_not_in_stop_times_and_not_as_parent_stop(self.copy_db_conn) @@ -441,58 +515,69 @@ def _update_metadata(self): print("Updating metadata") logging.info("Updating metadata") G_copy = gtfs.GTFS(self.copy_db_conn) - G_copy.meta['copied_from'] = self.this_db_path - G_copy.meta['copy_time_ut'] = time.time() - G_copy.meta['copy_time'] = time.ctime() + G_copy.meta["copied_from"] = self.this_db_path + G_copy.meta["copy_time_ut"] = time.time() + G_copy.meta["copy_time"] = time.ctime() # Copy some keys directly. try: - for key in ['original_gtfs', - 'download_date', - 'location_name', - 'timezone', ]: + for key in [ + "original_gtfs", + "download_date", + "location_name", + "timezone", + ]: G_copy.meta[key] = G_orig.meta[key] # This part is for gtfs objects with multiple sources except: for k, v in G_copy.meta.items(): - if 'feed_' in k: + if "feed_" in k: G_copy.meta[k] = G_orig.meta[k] - for key in ['location_name', - 'timezone', ]: + for key in [ + "location_name", + "timezone", + ]: G_copy.meta[key] = G_orig.meta[key] # Update *all* original metadata under orig_ namespace. - G_copy.meta.update(('orig_' + k, v) for k, v in G_orig.meta.items()) + G_copy.meta.update(("orig_" + k, v) for k, v in G_orig.meta.items()) stats.update_stats(G_copy) # print "Vacuuming..." - self.copy_db_conn.execute('VACUUM;') + self.copy_db_conn.execute("VACUUM;") # print "Analyzing..." - self.copy_db_conn.execute('ANALYZE;') + self.copy_db_conn.execute("ANALYZE;") self.copy_db_conn.commit() return + def delete_stops_not_in_stop_times_and_not_as_parent_stop(conn): - _STOPS_REFERENCED_IN_STOP_TIMES_OR_AS_PARENT_STOP_I_SQL = \ - "SELECT DISTINCT stop_I FROM stop_times " \ - "UNION " \ + _STOPS_REFERENCED_IN_STOP_TIMES_OR_AS_PARENT_STOP_I_SQL = ( + "SELECT DISTINCT stop_I FROM stop_times " + "UNION " "SELECT DISTINCT parent_I as stop_I FROM stops WHERE parent_I IS NOT NULL" - DELETE_STOPS_NOT_REFERENCED_IN_STOP_TIMES_AND_NOT_PARENT_STOP_SQL = \ - "DELETE FROM stops WHERE stop_I NOT IN (" + \ - _STOPS_REFERENCED_IN_STOP_TIMES_OR_AS_PARENT_STOP_I_SQL + ")" + ) + DELETE_STOPS_NOT_REFERENCED_IN_STOP_TIMES_AND_NOT_PARENT_STOP_SQL = ( + "DELETE FROM stops WHERE stop_I NOT IN (" + + _STOPS_REFERENCED_IN_STOP_TIMES_OR_AS_PARENT_STOP_I_SQL + + ")" + ) # It is possible that there is some "parent_I" recursion going on, and thus we # execute the same SQL query three times. conn.execute(DELETE_STOPS_NOT_REFERENCED_IN_STOP_TIMES_AND_NOT_PARENT_STOP_SQL) conn.execute(DELETE_STOPS_NOT_REFERENCED_IN_STOP_TIMES_AND_NOT_PARENT_STOP_SQL) conn.execute(DELETE_STOPS_NOT_REFERENCED_IN_STOP_TIMES_AND_NOT_PARENT_STOP_SQL) + def add_wgs84_distance_function_to_db(conn): function_name = "find_distance" conn.create_function(function_name, 4, wgs84_distance) return function_name -def remove_all_trips_fully_outside_buffer(db_conn, center_lat, center_lon, buffer_km, update_secondary_data=True): +def remove_all_trips_fully_outside_buffer( + db_conn, center_lat, center_lon, buffer_km, update_secondary_data=True +): """ Not used in the regular filter process for the time being. @@ -505,15 +590,31 @@ def remove_all_trips_fully_outside_buffer(db_conn, center_lat, center_lon, buffe buffer_km: float """ distance_function_str = add_wgs84_distance_function_to_db(db_conn) - stops_within_buffer_query_sql = "SELECT stop_I FROM stops WHERE CAST(" + distance_function_str + \ - "(lat, lon, {lat} , {lon}) AS INT) < {d_m}"\ - .format(lat=float(center_lat), lon=float(center_lon), d_m=int(1000*buffer_km)) - select_all_trip_Is_where_stop_I_is_within_buffer_sql = "SELECT distinct(trip_I) FROM stop_times WHERE stop_I IN (" + stops_within_buffer_query_sql + ")" - trip_Is_to_remove_sql = "SELECT trip_I FROM trips WHERE trip_I NOT IN ( " + select_all_trip_Is_where_stop_I_is_within_buffer_sql + ")" + stops_within_buffer_query_sql = ( + "SELECT stop_I FROM stops WHERE CAST(" + + distance_function_str + + "(lat, lon, {lat} , {lon}) AS INT) < {d_m}".format( + lat=float(center_lat), lon=float(center_lon), d_m=int(1000 * buffer_km) + ) + ) + select_all_trip_Is_where_stop_I_is_within_buffer_sql = ( + "SELECT distinct(trip_I) FROM stop_times WHERE stop_I IN (" + + stops_within_buffer_query_sql + + ")" + ) + trip_Is_to_remove_sql = ( + "SELECT trip_I FROM trips WHERE trip_I NOT IN ( " + + select_all_trip_Is_where_stop_I_is_within_buffer_sql + + ")" + ) trip_Is_to_remove = pandas.read_sql(trip_Is_to_remove_sql, db_conn)["trip_I"].values trip_Is_to_remove_string = ",".join([str(trip_I) for trip_I in trip_Is_to_remove]) - remove_all_trips_fully_outside_buffer_sql = "DELETE FROM trips WHERE trip_I IN (" + trip_Is_to_remove_string + ")" - remove_all_stop_times_where_trip_I_fully_outside_buffer_sql = "DELETE FROM stop_times WHERE trip_I IN (" + trip_Is_to_remove_string + ")" + remove_all_trips_fully_outside_buffer_sql = ( + "DELETE FROM trips WHERE trip_I IN (" + trip_Is_to_remove_string + ")" + ) + remove_all_stop_times_where_trip_I_fully_outside_buffer_sql = ( + "DELETE FROM stop_times WHERE trip_I IN (" + trip_Is_to_remove_string + ")" + ) db_conn.execute(remove_all_trips_fully_outside_buffer_sql) db_conn.execute(remove_all_stop_times_where_trip_I_fully_outside_buffer_sql) delete_stops_not_in_stop_times_and_not_as_parent_stop(db_conn) @@ -539,25 +640,27 @@ def remove_dangling_shapes(db_conn): connection to the GTFS object """ db_conn.execute(DELETE_SHAPES_NOT_REFERENCED_IN_TRIPS_SQL) - SELECT_MIN_MAX_SHAPE_BREAKS_BY_TRIP_I_SQL = \ - "SELECT trips.trip_I, shape_id, min(shape_break) as min_shape_break, max(shape_break) as max_shape_break FROM trips, stop_times WHERE trips.trip_I=stop_times.trip_I GROUP BY trips.trip_I" - trip_min_max_shape_seqs= pandas.read_sql(SELECT_MIN_MAX_SHAPE_BREAKS_BY_TRIP_I_SQL, db_conn) + SELECT_MIN_MAX_SHAPE_BREAKS_BY_TRIP_I_SQL = "SELECT trips.trip_I, shape_id, min(shape_break) as min_shape_break, max(shape_break) as max_shape_break FROM trips, stop_times WHERE trips.trip_I=stop_times.trip_I GROUP BY trips.trip_I" + trip_min_max_shape_seqs = pandas.read_sql(SELECT_MIN_MAX_SHAPE_BREAKS_BY_TRIP_I_SQL, db_conn) rows = [] for row in trip_min_max_shape_seqs.itertuples(): - shape_id, min_shape_break, max_shape_break = row.shape_id, row.min_shape_break, row.max_shape_break + shape_id, min_shape_break, max_shape_break = ( + row.shape_id, + row.min_shape_break, + row.max_shape_break, + ) if min_shape_break is None or max_shape_break is None: - min_shape_break = float('-inf') - max_shape_break = float('-inf') - rows.append( (shape_id, min_shape_break, max_shape_break) ) + min_shape_break = float("-inf") + max_shape_break = float("-inf") + rows.append((shape_id, min_shape_break, max_shape_break)) DELETE_SQL_BASE = "DELETE FROM shapes WHERE shape_id=? AND (seq?)" db_conn.executemany(DELETE_SQL_BASE, rows) remove_dangling_shapes_references(db_conn) def remove_dangling_shapes_references(db_conn): - remove_danging_shapes_references_sql = \ - "UPDATE trips SET shape_id=NULL WHERE trips.shape_id NOT IN (SELECT DISTINCT shape_id FROM shapes)" + remove_danging_shapes_references_sql = "UPDATE trips SET shape_id=NULL WHERE trips.shape_id NOT IN (SELECT DISTINCT shape_id FROM shapes)" db_conn.execute(remove_danging_shapes_references_sql) @@ -573,38 +676,56 @@ def _split_trip(copy_db_conn, orig_trip_I, stop_times_within_buffer_df): next_block = [] if len(next_block) > 1: blocks.append(next_block) - orig_trip_df = pandas.read_sql("SELECT * FROM trips WHERE trip_I={trip_I}".format(trip_I=orig_trip_I), copy_db_conn) + orig_trip_df = pandas.read_sql( + "SELECT * FROM trips WHERE trip_I={trip_I}".format(trip_I=orig_trip_I), copy_db_conn + ) orig_trip_dict = orig_trip_df.to_dict(orient="records")[0] for i, seq_block in enumerate(blocks): # create new trip for each block, # with mostly same trip information as the original - trip_id_generated = orig_trip_dict['trip_id'] + "_splitted_part_" + str(i) - insert_generated_trip_sql = \ - "INSERT INTO trips (trip_id, route_I, service_I, direction_id, " \ - "shape_id, headsign, start_time_ds, end_time_ds) " \ + trip_id_generated = orig_trip_dict["trip_id"] + "_splitted_part_" + str(i) + insert_generated_trip_sql = ( + "INSERT INTO trips (trip_id, route_I, service_I, direction_id, " + "shape_id, headsign, start_time_ds, end_time_ds) " "VALUES (?, ?, ?, ?, ?, ?, ?, ?)" - values = [trip_id_generated, orig_trip_dict['route_I'], - orig_trip_dict['service_I'], orig_trip_dict['direction_id'], - None, orig_trip_dict['headsign'], None, None] + ) + values = [ + trip_id_generated, + orig_trip_dict["route_I"], + orig_trip_dict["service_I"], + orig_trip_dict["direction_id"], + None, + orig_trip_dict["headsign"], + None, + None, + ] copy_db_conn.execute(insert_generated_trip_sql, values) - block_trip_I = copy_db_conn.execute("SELECT trip_I from trips WHERE trips.trip_id=?", - [trip_id_generated]).fetchone()[0] + block_trip_I = copy_db_conn.execute( + "SELECT trip_I from trips WHERE trips.trip_id=?", [trip_id_generated] + ).fetchone()[0] # alter the trip_I values in the stop_times table for seq_values_to_update_str = "(" + ",".join(str(seq) for seq in seq_block) + ")" - stop_times_update_sql = \ - "UPDATE stop_times SET trip_I={trip_I_generated} WHERE trip_I={orig_trip_I} AND seq IN {seq_block}".format( - trip_I_generated=block_trip_I, - orig_trip_I=orig_trip_I, - seq_block=seq_values_to_update_str - ) + stop_times_update_sql = "UPDATE stop_times SET trip_I={trip_I_generated} WHERE trip_I={orig_trip_I} AND seq IN {seq_block}".format( + trip_I_generated=block_trip_I, + orig_trip_I=orig_trip_I, + seq_block=seq_values_to_update_str, + ) copy_db_conn.execute(stop_times_update_sql) - copy_db_conn.execute("DELETE FROM trips WHERE trip_I={orig_trip_I}".format(orig_trip_I=orig_trip_I)) - copy_db_conn.execute("DELETE from stop_times WHERE trip_I={orig_trip_I}".format(orig_trip_I=orig_trip_I)) - copy_db_conn.execute("DELETE FROM shapes WHERE shape_id IN " - " (SELECT DISTINCT shapes.shape_id FROM shapes, trips " - " WHERE trip_I={orig_trip_I} AND shapes.shape_id=trips.shape_id)" - .format(orig_trip_I=orig_trip_I)) + copy_db_conn.execute( + "DELETE FROM trips WHERE trip_I={orig_trip_I}".format(orig_trip_I=orig_trip_I) + ) + copy_db_conn.execute( + "DELETE from stop_times WHERE trip_I={orig_trip_I}".format(orig_trip_I=orig_trip_I) + ) + copy_db_conn.execute( + "DELETE FROM shapes WHERE shape_id IN " + " (SELECT DISTINCT shapes.shape_id FROM shapes, trips " + " WHERE trip_I={orig_trip_I} AND shapes.shape_id=trips.shape_id)".format( + orig_trip_I=orig_trip_I + ) + ) + def update_secondary_data_copies(db_conn): G = gtfspy.gtfs.GTFS(db_conn) diff --git a/gtfspy/geometry.py b/gtfspy/geometry.py index 9dadac4..7526f38 100644 --- a/gtfspy/geometry.py +++ b/gtfspy/geometry.py @@ -18,33 +18,38 @@ def get_convex_hull_coordinates(gtfs): lons, lats = _get_stop_lat_lons(gtfs) lon_lats = list(zip(lons, lats)) polygon = MultiPoint(lon_lats).convex_hull - hull_lons, hull_lats= polygon.exterior.coords.xy + hull_lons, hull_lats = polygon.exterior.coords.xy return hull_lats, hull_lons + def _get_stop_lat_lons(gtfs): stops = gtfs.stops() - lats = stops['lat'] - lons = stops['lon'] + lats = stops["lat"] + lons = stops["lon"] return lons, lats + def get_approximate_convex_hull_area_km2(gtfs): lons, lats = _get_stop_lat_lons(gtfs) return approximate_convex_hull_area(lons, lats) + def approximate_convex_hull_area(lons, lats): lon_meters, lat_meters = _get_lon_lat_meters(lons, lats) lon_lat_meters = list(zip(lon_meters, lat_meters)) return MultiPoint(lon_lat_meters).convex_hull.area / 1000 ** 2 + def _get_lon_lat_meters(lons, lats): lat_min = min(lats) lat_max = max(lats) - lat_mean = (lat_max + lat_min) / 2. + lat_mean = (lat_max + lat_min) / 2.0 lon_min = min(lons) lon_max = max(lons) - lon_mean = (lon_max + lon_min) / 2. + lon_mean = (lon_max + lon_min) / 2.0 from gtfspy.util import wgs84_distance + lat_span_meters = wgs84_distance(lat_min, lon_mean, lat_max, lon_mean) lon_span_meters = wgs84_distance(lat_mean, lon_min, lat_mean, lon_max) @@ -62,7 +67,7 @@ def get_buffered_area_of_stops(gtfs, buffer_meters, resolution): ---------- gtfs: gtfs.GTFS buffer_meters: meters around the stop to buffer. - resolution: increases the accuracy of the calculated area with computation time. Default = 16 + resolution: increases the accuracy of the calculated area with computation time. Default = 16 Returns ------- @@ -88,5 +93,8 @@ def compute_buffered_area_of_stops(lats, lons, buffer_meters, resolution=16): lon_meters = lons lat_meters = lats - return MultiPoint(points=list(zip(lon_meters, lat_meters))).buffer(buffer_meters, resolution=resolution).area - + return ( + MultiPoint(points=list(zip(lon_meters, lat_meters))) + .buffer(buffer_meters, resolution=resolution) + .area + ) diff --git a/gtfspy/gtfs.py b/gtfspy/gtfs.py index 084bbd6..5f885c6 100644 --- a/gtfspy/gtfs.py +++ b/gtfspy/gtfs.py @@ -21,7 +21,6 @@ class GTFS(object): - def __init__(self, fname_or_conn): """Open a GTFS object @@ -35,9 +34,9 @@ def __init__(self, fname_or_conn): self.conn = sqlite3.connect(fname_or_conn) self.fname = fname_or_conn # memory-mapped IO size, in bytes - self.conn.execute('PRAGMA mmap_size = 1000000000;') + self.conn.execute("PRAGMA mmap_size = 1000000000;") # page cache size, in negative KiB. - self.conn.execute('PRAGMA cache_size = -2000000;') + self.conn.execute("PRAGMA cache_size = -2000000;") else: raise FileNotFoundError("File " + fname_or_conn + " missing") elif isinstance(fname_or_conn, sqlite3.Connection): @@ -45,9 +44,15 @@ def __init__(self, fname_or_conn): self._dont_close = True else: raise NotImplementedError( - "Initiating GTFS using an object with type " + str(type(fname_or_conn)) + " is not supported") - - assert self.conn.execute("SELECT name FROM sqlite_master WHERE type='table';").fetchone() is not None + "Initiating GTFS using an object with type " + + str(type(fname_or_conn)) + + " is not supported" + ) + + assert ( + self.conn.execute("SELECT name FROM sqlite_master WHERE type='table';").fetchone() + is not None + ) self.meta = GTFSMetadata(self.conn) # Bind functions @@ -57,7 +62,7 @@ def __init__(self, fname_or_conn): self._timezone = pytz.timezone(self.get_timezone_name()) def __del__(self): - if not getattr(self, '_dont_close', False) and hasattr(self, "conn"): + if not getattr(self, "_dont_close", False) and hasattr(self, "conn"): self.conn.close() @classmethod @@ -72,11 +77,9 @@ def from_directory_as_inmemory_db(cls, gtfs_directory): """ # this import is here to avoid circular imports (which turned out to be a problem) from gtfspy.import_gtfs import import_gtfs + conn = sqlite3.connect(":memory:") - import_gtfs(gtfs_directory, - conn, - preserve_connection=True, - print_progress=False) + import_gtfs(gtfs_directory, conn, preserve_connection=True, print_progress=False) return cls(conn) def get_main_database_path(self): @@ -96,7 +99,7 @@ def get_main_database_path(self): return row[2] def get_location_name(self): - return self.meta.get('location_name', "location_unknown") + return self.meta.get("location_name", "location_unknown") def get_shape_distance_between_stops(self, trip_I, from_stop_seq, to_stop_seq): """ @@ -124,10 +127,14 @@ def get_shape_distance_between_stops(self, trip_I, from_stop_seq, to_stop_seq): for seq in stop_seqs: q = query_template.format(seq=seq, trip_I=trip_I) shape_breaks.append(self.conn.execute(q).fetchone()) - query_template = "SELECT max(d) - min(d) " \ - "FROM shapes JOIN trips ON(trips.shape_id=shapes.shape_id) " \ - "WHERE trip_I={trip_I} AND shapes.seq>={from_stop_seq} AND shapes.seq<={to_stop_seq};" - distance_query = query_template.format(trip_I=trip_I, from_stop_seq=from_stop_seq, to_stop_seq=to_stop_seq) + query_template = ( + "SELECT max(d) - min(d) " + "FROM shapes JOIN trips ON(trips.shape_id=shapes.shape_id) " + "WHERE trip_I={trip_I} AND shapes.seq>={from_stop_seq} AND shapes.seq<={to_stop_seq};" + ) + distance_query = query_template.format( + trip_I=trip_I, from_stop_seq=from_stop_seq, to_stop_seq=to_stop_seq + ) return self.conn.execute(distance_query).fetchone()[0] def get_stop_distance(self, from_stop_I, to_stop_I): @@ -140,8 +147,11 @@ def get_stop_distance(self, from_stop_I, to_stop_I): def get_stops_within_distance(self, stop, distance): query = """SELECT stops.* FROM stop_distances, stops - WHERE stop_distances.to_stop_I = stops.stop_I - AND d < %s AND from_stop_I = %s""" % (distance, stop) + WHERE stop_distances.to_stop_I = stops.stop_I + AND d < %s AND from_stop_I = %s""" % ( + distance, + stop, + ) return pd.read_sql_query(query, self.conn) def get_directly_accessible_stops_within_distance(self, stop, distance): @@ -152,15 +162,18 @@ def get_directly_accessible_stops_within_distance(self, stop, distance): :return: """ query = """SELECT stop.* FROM - (SELECT st2.* FROM + (SELECT st2.* FROM (SELECT * FROM stop_distances WHERE from_stop_I = %s) sd, (SELECT * FROM stop_times) st1, (SELECT * FROM stop_times) st2 - WHERE sd.d < %s AND sd.to_stop_I = st1.stop_I AND st1.trip_I = st2.trip_I + WHERE sd.d < %s AND sd.to_stop_I = st1.stop_I AND st1.trip_I = st2.trip_I GROUP BY st2.stop_I) sq, (SELECT * FROM stops) stop - WHERE sq.stop_I = stop.stop_I""" % (stop, distance) + WHERE sq.stop_I = stop.stop_I""" % ( + stop, + distance, + ) return pd.read_sql_query(query, self.conn) def get_cursor(self): @@ -198,7 +211,9 @@ def get_table_names(self): ------- table_names: list[str] """ - return list(pd.read_sql("SELECT * FROM main.sqlite_master WHERE type='table'", self.conn)["name"]) + return list( + pd.read_sql("SELECT * FROM main.sqlite_master WHERE type='table'", self.conn)["name"] + ) def set_current_process_time_zone(self): """ @@ -212,7 +227,7 @@ def set_current_process_time_zone(self): Alters os.environ['TZ'] """ - TZ = self.conn.execute('SELECT timezone FROM agencies LIMIT 1').fetchall()[0][0] + TZ = self.conn.execute("SELECT timezone FROM agencies LIMIT 1").fetchall()[0][0] # TODO!: This is dangerous (?). # In my opinion, we should get rid of this at some point (RK): return set_process_timezone(TZ) @@ -229,7 +244,7 @@ def get_timezone_name(self): timezone_name : str name of the time zone, e.g. "Europe/Helsinki" """ - tz_name = self.conn.execute('SELECT timezone FROM agencies LIMIT 1').fetchone() + tz_name = self.conn.execute("SELECT timezone FROM agencies LIMIT 1").fetchone() if tz_name is None: raise ValueError("This database does not have a timezone defined.") return tz_name[0] @@ -255,9 +270,9 @@ def get_timezone_string(self, dt=None): timezone_string : str """ if dt is None: - download_date = self.meta.get('download_date') + download_date = self.meta.get("download_date") if download_date: - dt = datetime.datetime.strptime(download_date, '%Y-%m-%d') + dt = datetime.datetime.strptime(download_date, "%Y-%m-%d") else: dt = datetime.datetime.today() loc_dt = self._timezone.localize(dt) @@ -312,7 +327,7 @@ def get_day_start_ut(self, date): start time of the day in unixtime """ if isinstance(date, string_types): - date = datetime.datetime.strptime(date, '%Y-%m-%d') + date = datetime.datetime.strptime(date, "%Y-%m-%d") date_noon = datetime.datetime(date.year, date.month, date.day, 12, 0, 0) ut_noon = self.unlocalized_datetime_to_ut_seconds(date_noon) @@ -358,8 +373,8 @@ def get_trip_trajectories_within_timespan(self, start, end, use_shapes=True, fil trip = {} name, route_type = self.get_route_name_and_type_of_tripI(trip_I) - trip['route_type'] = int(route_type) - trip['name'] = str(name) + trip["route_type"] = int(route_type) + trip["name"] = str(name) if filter_name and (name != filter_name): continue @@ -393,23 +408,25 @@ def get_trip_trajectories_within_timespan(self, start, end, use_shapes=True, fil shape_data = shape_cache[shape_id] # noinspection PyBroadException try: - trip['times'] = shapes.interpolate_shape_times(shape_data['d'], shape_breaks, stop_dep_times) - trip['lats'] = shape_data['lats'] - trip['lons'] = shape_data['lons'] + trip["times"] = shapes.interpolate_shape_times( + shape_data["d"], shape_breaks, stop_dep_times + ) + trip["lats"] = shape_data["lats"] + trip["lons"] = shape_data["lons"] start_break = shape_breaks[0] end_break = shape_breaks[-1] - trip['times'] = trip['times'][start_break:end_break + 1] - trip['lats'] = trip['lats'][start_break:end_break + 1] - trip['lons'] = trip['lons'][start_break:end_break + 1] + trip["times"] = trip["times"][start_break : end_break + 1] + trip["lats"] = trip["lats"][start_break : end_break + 1] + trip["lons"] = trip["lons"][start_break : end_break + 1] except: # In case interpolation fails: - trip['times'] = stop_dep_times - trip['lats'] = stop_lats - trip['lons'] = stop_lons + trip["times"] = stop_dep_times + trip["lats"] = stop_lats + trip["lons"] = stop_lons else: - trip['times'] = stop_dep_times - trip['lats'] = stop_lats - trip['lons'] = stop_lons + trip["times"] = stop_dep_times + trip["lats"] = stop_lats + trip["lons"] = stop_lons trips.append(trip) return {"trips": trips} @@ -442,7 +459,9 @@ def get_stop_count_data(self, start_ut, end_ut): # get stop_data and store it: stops_seq = self.get_trip_stop_time_data(row.trip_I, row.day_start_ut) for stop_time_row in stops_seq.itertuples(index=False): - if (stop_time_row.dep_time_ut >= start_ut) and (stop_time_row.dep_time_ut <= end_ut): + if (stop_time_row.dep_time_ut >= start_ut) and ( + stop_time_row.dep_time_ut <= end_ut + ): stop_counts[stop_time_row.stop_I] += 1 all_stop_data = self.stops() @@ -486,7 +505,9 @@ def get_segment_count_data(self, start, end, use_shapes=True): stops_df = self.get_trip_stop_time_data(row.trip_I, row.day_start_ut) for i in range(len(stops_df) - 1): (stop_I, dep_time_ut, s_lat, s_lon, s_seq, shape_break) = stops_df.iloc[i] - (stop_I_n, dep_time_ut_n, s_lat_n, s_lon_n, s_seq_n, shape_break_n) = stops_df.iloc[i + 1] + (stop_I_n, dep_time_ut_n, s_lat_n, s_lon_n, s_seq_n, shape_break_n) = stops_df.iloc[ + i + 1 + ] # test if _contained_ in the interval # overlap would read: # (dep_time_ut <= end) and (start <= dep_time_ut_n) @@ -495,12 +516,12 @@ def get_segment_count_data(self, start, end, use_shapes=True): segment_counts[seg] += 1 if seg not in seg_to_info: seg_to_info[seg] = { - u"trip_I": row.trip_I, - u"lats": [s_lat, s_lat_n], - u"lons": [s_lon, s_lon_n], - u"shape_id": row.shape_id, - u"stop_seqs": [s_seq, s_seq_n], - u"shape_breaks": [shape_break, shape_break_n] + "trip_I": row.trip_I, + "lats": [s_lat, s_lat_n], + "lons": [s_lon, s_lon_n], + "shape_id": row.shape_id, + "stop_seqs": [s_seq, s_seq_n], + "shape_breaks": [shape_break, shape_break_n], } tripI_to_seq[row.trip_I].append(seg) @@ -508,30 +529,27 @@ def get_segment_count_data(self, start, end, use_shapes=True): for (stop_I, stop_J) in segment_counts.keys(): for s in [stop_I, stop_J]: if s not in stop_names: - stop_names[s] = self.stop(s)[u'name'].values[0] + stop_names[s] = self.stop(s)["name"].values[0] seg_data = [] for seg, count in segment_counts.items(): segInfo = seg_to_info[seg] - shape_breaks = segInfo[u"shape_breaks"] + shape_breaks = segInfo["shape_breaks"] seg_el = {} if use_shapes and shape_breaks and shape_breaks[0] and shape_breaks[1]: shape = shapes.get_shape_between_stops( - cur, - segInfo[u'trip_I'], - shape_breaks=shape_breaks + cur, segInfo["trip_I"], shape_breaks=shape_breaks ) - seg_el[u'lats'] = segInfo[u'lats'][:1] + shape[u'lat'] + segInfo[u'lats'][1:] - seg_el[u'lons'] = segInfo[u'lons'][:1] + shape[u'lon'] + segInfo[u'lons'][1:] + seg_el["lats"] = segInfo["lats"][:1] + shape["lat"] + segInfo["lats"][1:] + seg_el["lons"] = segInfo["lons"][:1] + shape["lon"] + segInfo["lons"][1:] else: - seg_el[u'lats'] = segInfo[u'lats'] - seg_el[u'lons'] = segInfo[u'lons'] - seg_el[u'name'] = stop_names[seg[0]] + u"-" + stop_names[seg[1]] - seg_el[u'count'] = count + seg_el["lats"] = segInfo["lats"] + seg_el["lons"] = segInfo["lons"] + seg_el["name"] = stop_names[seg[0]] + "-" + stop_names[seg[1]] + seg_el["count"] = count seg_data.append(seg_el) return seg_data - def get_all_route_shapes(self, use_shapes=True): """ Get the shapes of all routes. @@ -556,31 +574,38 @@ def get_all_route_shapes(self, use_shapes=True): # FROM trips LEFT JOIN routes USING(route_I)" # data1 = pd.read_sql_query(query, self.conn) # one (arbitrary) shape_id per route_I ("one direction") -> less than half of the routes - query = "SELECT routes.name as name, shape_id, route_I, trip_I, routes.type, " \ - " agency_id, agencies.name as agency_name, max(end_time_ds-start_time_ds) as trip_duration " \ - "FROM trips " \ - "LEFT JOIN routes " \ - "USING(route_I) " \ - "LEFT JOIN agencies " \ - "USING(agency_I) " \ - "GROUP BY routes.route_I" + query = ( + "SELECT routes.name as name, shape_id, route_I, trip_I, routes.type, " + " agency_id, agencies.name as agency_name, max(end_time_ds-start_time_ds) as trip_duration " + "FROM trips " + "LEFT JOIN routes " + "USING(route_I) " + "LEFT JOIN agencies " + "USING(agency_I) " + "GROUP BY routes.route_I" + ) data = pd.read_sql_query(query, self.conn) routeShapes = [] for i, row in enumerate(data.itertuples()): - datum = {"name": str(row.name), "type": int(row.type), "route_I": row.route_I, "agency": str(row.agency_id), - "agency_name": str(row.agency_name)} + datum = { + "name": str(row.name), + "type": int(row.type), + "route_I": row.route_I, + "agency": str(row.agency_id), + "agency_name": str(row.agency_name), + } # this function should be made also non-shape friendly (at this point) if use_shapes and row.shape_id: shape = shapes.get_shape_points2(cur, row.shape_id) - lats = shape['lats'] - lons = shape['lons'] + lats = shape["lats"] + lons = shape["lons"] else: stop_shape = self.get_trip_stop_coordinates(row.trip_I) - lats = list(stop_shape['lat']) - lons = list(stop_shape['lon']) - datum['lats'] = [float(lat) for lat in lats] - datum['lons'] = [float(lon) for lon in lons] + lats = list(stop_shape["lat"]) + lons = list(stop_shape["lon"]) + datum["lats"] = [float(lat) for lat in lats] + datum["lons"] = [float(lon) for lon in lons] routeShapes.append(datum) return routeShapes @@ -603,10 +628,13 @@ def get_tripIs_active_in_range(self, start, end): trip_I, day_start_ut, start_time_ut, end_time_ut, shape_id """ to_select = "trip_I, day_start_ut, start_time_ut, end_time_ut, shape_id " - query = "SELECT " + to_select + \ - "FROM day_trips " \ - "WHERE " \ - "(end_time_ut > {start_ut} AND start_time_ut < {end_ut})".format(start_ut=start, end_ut=end) + query = ( + "SELECT " + to_select + "FROM day_trips " + "WHERE " + "(end_time_ut > {start_ut} AND start_time_ut < {end_ut})".format( + start_ut=start, end_ut=end + ) + ) return pd.read_sql_query(query, self.conn) def get_trip_counts_per_day(self): @@ -625,8 +653,8 @@ def get_trip_counts_per_day(self): # (necessary for some visualizations) max_day = trip_counts_per_day.index.max() min_day = trip_counts_per_day.index.min() - min_date = datetime.datetime.strptime(min_day, '%Y-%m-%d') - max_date = datetime.datetime.strptime(max_day, '%Y-%m-%d') + min_date = datetime.datetime.strptime(min_day, "%Y-%m-%d") + max_date = datetime.datetime.strptime(max_day, "%Y-%m-%d") num_days = (max_date - min_date).days dates = [min_date + datetime.timedelta(days=x) for x in range(num_days + 1)] trip_counts = [] @@ -635,7 +663,7 @@ def get_trip_counts_per_day(self): date_string = date.strftime("%Y-%m-%d") date_strings.append(date_string) try: - value = trip_counts_per_day.loc[date_string, 'number_of_trips'] + value = trip_counts_per_day.loc[date_string, "number_of_trips"] except KeyError: # set value to 0 if dsut is not present, i.e. when no trips # take place on that day @@ -663,12 +691,12 @@ def get_suitable_date_for_daily_extract(self, date=None, ut=False): If the download date is out of range, the process will look through the dates from first to last. """ daily_trips = self.get_trip_counts_per_day() - max_daily_trips = daily_trips[u'trip_counts'].max(axis=0) - if date in daily_trips[u'date_str']: - start_index = daily_trips[daily_trips[u'date_str'] == date].index.tolist()[0] - daily_trips[u'old_index'] = daily_trips.index - daily_trips[u'date_dist'] = abs(start_index - daily_trips.index) - daily_trips = daily_trips.sort_values(by=[u'date_dist', u'old_index']).reindex() + max_daily_trips = daily_trips["trip_counts"].max(axis=0) + if date in daily_trips["date_str"]: + start_index = daily_trips[daily_trips["date_str"] == date].index.tolist()[0] + daily_trips["old_index"] = daily_trips.index + daily_trips["date_dist"] = abs(start_index - daily_trips.index) + daily_trips = daily_trips.sort_values(by=["date_dist", "old_index"]).reindex() for row in daily_trips.itertuples(): if row.trip_counts >= 0.9 * max_daily_trips: if ut: @@ -676,8 +704,9 @@ def get_suitable_date_for_daily_extract(self, date=None, ut=False): else: return row.date_str - def get_weekly_extract_start_date(self, ut=False, weekdays_at_least_of_max=0.9, - verbose=False, download_date_override=None): + def get_weekly_extract_start_date( + self, ut=False, weekdays_at_least_of_max=0.9, verbose=False, download_date_override=None + ): """ Find a suitable weekly extract start date (monday). The goal is to obtain as 'usual' week as possible. @@ -709,65 +738,86 @@ def get_weekly_extract_start_date(self, ut=False, weekdays_at_least_of_max=0.9, search_start_date = download_date_override else: assert download_date_override is None - download_date_str = self.meta['download_date'] + download_date_str = self.meta["download_date"] if download_date_str == "": - warnings.warn("Download date is not speficied in the database. " - "Download date used in GTFS." + self.get_weekly_extract_start_date.__name__ + - "() defaults to the smallest date when any operations take place.") - search_start_date = daily_trip_counts['date'].min() + warnings.warn( + "Download date is not speficied in the database. " + "Download date used in GTFS." + + self.get_weekly_extract_start_date.__name__ + + "() defaults to the smallest date when any operations take place." + ) + search_start_date = daily_trip_counts["date"].min() else: search_start_date = datetime.datetime.strptime(download_date_str, "%Y-%m-%d") - feed_min_date = daily_trip_counts['date'].min() - feed_max_date = daily_trip_counts['date'].max() - assert (feed_max_date - feed_min_date >= datetime.timedelta(days=7)), \ - "Dataset is not long enough for providing week long extracts" + feed_min_date = daily_trip_counts["date"].min() + feed_max_date = daily_trip_counts["date"].max() + assert feed_max_date - feed_min_date >= datetime.timedelta( + days=7 + ), "Dataset is not long enough for providing week long extracts" # get first a valid monday where the search for the week can be started: - next_monday_from_search_start_date = search_start_date + timedelta(days=(7 - search_start_date.weekday())) + next_monday_from_search_start_date = search_start_date + timedelta( + days=(7 - search_start_date.weekday()) + ) if not (feed_min_date <= next_monday_from_search_start_date <= feed_max_date): - warnings.warn("The next monday after the (possibly user) specified download date is not present in the database." - "Resorting to first monday after the beginning of operations instead.") - next_monday_from_search_start_date = feed_min_date + timedelta(days=(7 - feed_min_date.weekday())) - - max_trip_count = daily_trip_counts['trip_counts'].quantile(0.95) + warnings.warn( + "The next monday after the (possibly user) specified download date is not present in the database." + "Resorting to first monday after the beginning of operations instead." + ) + next_monday_from_search_start_date = feed_min_date + timedelta( + days=(7 - feed_min_date.weekday()) + ) + + max_trip_count = daily_trip_counts["trip_counts"].quantile(0.95) # Take 95th percentile to omit special days, if any exist. threshold = weekdays_at_least_of_max * max_trip_count - threshold_fulfilling_days = daily_trip_counts['trip_counts'] > threshold + threshold_fulfilling_days = daily_trip_counts["trip_counts"] > threshold # look forward first # get the index of the trip: - search_start_monday_index = daily_trip_counts[daily_trip_counts['date'] == next_monday_from_search_start_date].index[0] + search_start_monday_index = daily_trip_counts[ + daily_trip_counts["date"] == next_monday_from_search_start_date + ].index[0] # get starting point while_loop_monday_index = search_start_monday_index while len(daily_trip_counts.index) >= while_loop_monday_index + 7: - if all(threshold_fulfilling_days[while_loop_monday_index:while_loop_monday_index + 5]): + if all( + threshold_fulfilling_days[while_loop_monday_index : while_loop_monday_index + 5] + ): row = daily_trip_counts.iloc[while_loop_monday_index] if ut: return self.get_day_start_ut(row.date_str) else: - return row['date'] + return row["date"] while_loop_monday_index += 7 while_loop_monday_index = search_start_monday_index - 7 # then backwards while while_loop_monday_index >= 0: - if all(threshold_fulfilling_days[while_loop_monday_index:while_loop_monday_index + 5]): + if all( + threshold_fulfilling_days[while_loop_monday_index : while_loop_monday_index + 5] + ): row = daily_trip_counts.iloc[while_loop_monday_index] if ut: return self.get_day_start_ut(row.date_str) else: - return row['date'] + return row["date"] while_loop_monday_index -= 7 raise RuntimeError("No suitable weekly extract start date could be determined!") - def get_spreading_trips(self, start_time_ut, lat, lon, - max_duration_ut=4 * 3600, - min_transfer_time=30, - use_shapes=False): + def get_spreading_trips( + self, + start_time_ut, + lat, + lon, + max_duration_ut=4 * 3600, + min_transfer_time=30, + use_shapes=False, + ): """ Starting from a specific point and time, get complete single source shortest path spreading dynamics as trips, or "events". @@ -799,7 +849,10 @@ def get_spreading_trips(self, start_time_ut, lat, lon, el['name'] : name of the route """ from gtfspy.spreading.spreader import Spreader - spreader = Spreader(self, start_time_ut, lat, lon, max_duration_ut, min_transfer_time, use_shapes) + + spreader = Spreader( + self, start_time_ut, lat, lon, max_duration_ut, min_transfer_time, use_shapes + ) return spreader.spread() def get_closest_stop(self, lat, lon): @@ -831,7 +884,9 @@ def get_closest_stop(self, lat, lon): def get_stop_coordinates(self, stop_I): cur = self.conn.cursor() - results = cur.execute("SELECT lat, lon FROM stops WHERE stop_I={stop_I}".format(stop_I=stop_I)) + results = cur.execute( + "SELECT lat, lon FROM stops WHERE stop_I={stop_I}".format(stop_I=stop_I) + ) lat, lon = results.fetchone() return lat, lon @@ -854,10 +909,12 @@ def get_bounding_box_by_stops(self, stop_Is, buffer_ratio=None): lat_diff = wgs84_height(distance) lon_diff = wgs84_width(distance, (max_lat - min_lat) / 2 + min_lat) - return {"lat_min": min_lat - lat_diff, - "lat_max": max_lat + lat_diff, - "lon_min": min_lon - lon_diff, - "lon_max": max_lon + lon_diff} + return { + "lat_min": min_lat - lat_diff, + "lat_max": max_lat + lat_diff, + "lon_min": min_lon - lon_diff, + "lon_max": max_lon + lon_diff, + } def get_route_name_and_type_of_tripI(self, trip_I): """ @@ -876,10 +933,13 @@ def get_route_name_and_type_of_tripI(self, trip_I): route_type according to the GTFS standard """ cur = self.conn.cursor() - results = cur.execute("SELECT name, type FROM routes JOIN trips USING(route_I) WHERE trip_I={trip_I}" - .format(trip_I=trip_I)) + results = cur.execute( + "SELECT name, type FROM routes JOIN trips USING(route_I) WHERE trip_I={trip_I}".format( + trip_I=trip_I + ) + ) name, rtype = results.fetchone() - return u"%s" % str(name), int(rtype) + return "%s" % str(name), int(rtype) def get_route_name_and_type(self, route_I): """ @@ -921,7 +981,9 @@ def get_trip_stop_coordinates(self, trip_I): JOIN stops USING(stop_I) WHERE trip_I={trip_I} - ORDER BY stop_times.seq""".format(trip_I=trip_I) + ORDER BY stop_times.seq""".format( + trip_I=trip_I + ) stop_coords = pd.read_sql(query, self.conn) return stop_coords @@ -946,16 +1008,23 @@ def get_trip_stop_time_data(self, trip_I, day_start_ut): df has the following columns 'departure_time_ut, lat, lon, seq, shape_break' """ - to_select = "stop_I, " + str(day_start_ut) + "+dep_time_ds AS dep_time_ut, lat, lon, seq, shape_break" - str_to_run = "SELECT " + to_select + """ + to_select = ( + "stop_I, " + + str(day_start_ut) + + "+dep_time_ds AS dep_time_ut, lat, lon, seq, shape_break" + ) + str_to_run = ( + "SELECT " + + to_select + + """ FROM stop_times JOIN stops USING(stop_I) WHERE (trip_I ={trip_I}) ORDER BY seq """ + ) str_to_run = str_to_run.format(trip_I=trip_I) return pd.read_sql_query(str_to_run, self.conn) - def get_events_by_tripI_and_dsut(self, trip_I, day_start_ut, - start_ut=None, end_ut=None): + def get_events_by_tripI_and_dsut(self, trip_I, day_start_ut, start_ut=None, end_ut=None): """ Get trip data as a list of events (i.e. dicts). @@ -995,8 +1064,7 @@ def get_events_by_tripI_and_dsut(self, trip_I, day_start_ut, WHERE (trip_I = ?) """ - params = [day_start_ut, day_start_ut, - trip_I] + params = [day_start_ut, day_start_ut, trip_I] if start_ut: query += "AND (dep_time_ds > ?-?)" params += [start_ut, day_start_ut] @@ -1012,7 +1080,7 @@ def get_events_by_tripI_and_dsut(self, trip_I, day_start_ut, "from_stop": stop_data[i][0], "to_stop": stop_data[i + 1][0], "dep_time_ut": stop_data[i][2], - "arr_time_ut": stop_data[i + 1][1] + "arr_time_ut": stop_data[i + 1][1], } events.append(event) return events @@ -1041,7 +1109,7 @@ def tripI_takes_place_on_dsut(self, trip_I, day_start_ut): if len(rows) == 0: return False else: - assert len(rows) == 1, 'On a day, a trip_I should be present at most once' + assert len(rows) == 1, "On a day, a trip_I should be present at most once" return True # unused and (untested) code: @@ -1124,9 +1192,10 @@ def increment_day_start_ut(self, day_start_ut, n_days=1): """ old_tz = self.set_current_process_time_zone() day0 = time.localtime(day_start_ut + 43200) # time of noon - dayN = time.mktime(day0[:2] + # YYYY, MM - (day0[2] + n_days,) + # DD - (12, 00, 0, 0, 0, -1)) - 43200 # HHMM, etc. Minus 12 hours. + dayN = ( + time.mktime(day0[:2] + (day0[2] + n_days,) + (12, 00, 0, 0, 0, -1)) # YYYY, MM # DD + - 43200 + ) # HHMM, etc. Minus 12 hours. set_process_timezone(old_tz) return dayN @@ -1201,9 +1270,7 @@ def _get_possible_day_starts(self, start_ut, end_ut, max_time_overnight=None): # Return three tuples which can be zip:ped together. return day_start_times_ut, start_times_ds, end_times_ds - def get_tripIs_within_range_by_dsut(self, - start_time_ut, - end_time_ut): + def get_tripIs_within_range_by_dsut(self, start_time_ut, end_time_ut): """ Obtain a list of trip_Is that take place during a time interval. The trip needs to be only partially overlapping with the given time interval. @@ -1224,13 +1291,11 @@ def get_tripIs_within_range_by_dsut(self, """ cur = self.conn.cursor() assert start_time_ut <= end_time_ut - dst_ut, st_ds, et_ds = \ - self._get_possible_day_starts(start_time_ut, end_time_ut, 7) + dst_ut, st_ds, et_ds = self._get_possible_day_starts(start_time_ut, end_time_ut, 7) # noinspection PyTypeChecker assert len(dst_ut) >= 0 trip_I_dict = {} - for day_start_ut, start_ds, end_ds in \ - zip(dst_ut, st_ds, et_ds): + for day_start_ut, start_ds, end_ds in zip(dst_ut, st_ds, et_ds): query = """ SELECT distinct(trip_I) FROM days @@ -1273,14 +1338,16 @@ def stop(self, stop_I): ------- stop: pandas.DataFrame """ - return pd.read_sql_query("SELECT * FROM stops WHERE stop_I={stop_I}".format(stop_I=stop_I), self.conn) + return pd.read_sql_query( + "SELECT * FROM stops WHERE stop_I={stop_I}".format(stop_I=stop_I), self.conn + ) - def add_coordinates_to_df(self, df, join_column='stop_I', lat_name="lat", lon_name="lon"): + def add_coordinates_to_df(self, df, join_column="stop_I", lat_name="lat", lon_name="lon"): assert join_column in df.columns stops_df = self.stops() coord_df = stops_df[["stop_I", "lat", "lon"]] - df_merged = pd.merge(coord_df, df, left_on='stop_I', right_on=join_column) + df_merged = pd.merge(coord_df, df, left_on="stop_I", right_on=join_column) df_merged.drop(["stop_I"], axis=1, inplace=True) df_merged3 = df_merged.rename(columns={"lat": lat_name, "lon": lon_name}) @@ -1290,7 +1357,9 @@ def get_n_stops(self): return pd.read_sql_query("SELECT count(*) from stops;", self.conn).values[0, 0] def get_modes(self): - modes = list(pd.read_sql_query("SELECT distinct(type) from routes;", self.conn).values.flatten()) + modes = list( + pd.read_sql_query("SELECT distinct(type) from routes;", self.conn).values.flatten() + ) return modes def get_stops_for_route_type(self, route_type): @@ -1307,16 +1376,22 @@ def get_stops_for_route_type(self, route_type): if route_type is WALK: return self.stops() else: - return pd.read_sql_query("SELECT DISTINCT stops.* " - "FROM stops JOIN stop_times ON stops.stop_I == stop_times.stop_I " - " JOIN trips ON stop_times.trip_I = trips.trip_I" - " JOIN routes ON trips.route_I == routes.route_I " - "WHERE routes.type=(?)", self.conn, params=(route_type,)) + return pd.read_sql_query( + "SELECT DISTINCT stops.* " + "FROM stops JOIN stop_times ON stops.stop_I == stop_times.stop_I " + " JOIN trips ON stop_times.trip_I = trips.trip_I" + " JOIN routes ON trips.route_I == routes.route_I " + "WHERE routes.type=(?)", + self.conn, + params=(route_type,), + ) def get_stops_connected_to_stop(self): pass - def generate_routable_transit_events(self, start_time_ut=None, end_time_ut=None, route_type=None): + def generate_routable_transit_events( + self, start_time_ut=None, end_time_ut=None, route_type=None + ): """ Generates events that take place during a time interval [start_time_ut, end_time_ut]. Each event needs to be only partially overlap the given time interval. @@ -1343,7 +1418,10 @@ def generate_routable_transit_events(self, start_time_ut=None, end_time_ut=None, seq: int """ from gtfspy.networks import temporal_network - df = temporal_network(self, start_time_ut=start_time_ut, end_time_ut=end_time_ut, route_type=route_type) + + df = temporal_network( + self, start_time_ut=start_time_ut, end_time_ut=end_time_ut, route_type=route_type + ) df.sort_values("dep_time_ut", ascending=False, inplace=True) for row in df.itertuples(): @@ -1381,27 +1459,35 @@ def get_transit_events(self, start_time_ut=None, end_time_ut=None, route_type=No get_transit_events_in_time_span : an older version of the same thing """ table_name = self._get_day_trips_table_name() - event_query = "SELECT stop_I, seq, trip_I, route_I, routes.route_id AS route_id, routes.type AS route_type, " \ - "shape_id, day_start_ut+dep_time_ds AS dep_time_ut, day_start_ut+arr_time_ds AS arr_time_ut " \ - "FROM " + table_name + " " \ - "JOIN trips USING(trip_I) " \ - "JOIN routes USING(route_I) " \ - "JOIN stop_times USING(trip_I)" + event_query = ( + "SELECT stop_I, seq, trip_I, route_I, routes.route_id AS route_id, routes.type AS route_type, " + "shape_id, day_start_ut+dep_time_ds AS dep_time_ut, day_start_ut+arr_time_ds AS arr_time_ut " + "FROM " + table_name + " " + "JOIN trips USING(trip_I) " + "JOIN routes USING(route_I) " + "JOIN stop_times USING(trip_I)" + ) where_clauses = [] if end_time_ut: - where_clauses.append(table_name + ".start_time_ut< {end_time_ut}".format(end_time_ut=end_time_ut)) + where_clauses.append( + table_name + ".start_time_ut< {end_time_ut}".format(end_time_ut=end_time_ut) + ) where_clauses.append("dep_time_ut <={end_time_ut}".format(end_time_ut=end_time_ut)) if start_time_ut: - where_clauses.append(table_name + ".end_time_ut > {start_time_ut}".format(start_time_ut=start_time_ut)) - where_clauses.append("arr_time_ut >={start_time_ut}".format(start_time_ut=start_time_ut)) + where_clauses.append( + table_name + ".end_time_ut > {start_time_ut}".format(start_time_ut=start_time_ut) + ) + where_clauses.append( + "arr_time_ut >={start_time_ut}".format(start_time_ut=start_time_ut) + ) if route_type is not None: assert route_type in ALL_ROUTE_TYPES where_clauses.append("routes.type={route_type}".format(route_type=route_type)) if len(where_clauses) > 0: event_query += " WHERE " for i, where_clause in enumerate(where_clauses): - if i is not 0: + if i != 0: event_query += " AND " event_query += where_clause # ordering is required for later stages @@ -1409,36 +1495,62 @@ def get_transit_events(self, start_time_ut=None, end_time_ut=None, route_type=No events_result = pd.read_sql_query(event_query, self.conn) # 'filter' results so that only real "events" are taken into account from_indices = numpy.nonzero( - (events_result['trip_I'][:-1].values == events_result['trip_I'][1:].values) * - (events_result['seq'][:-1].values < events_result['seq'][1:].values) + (events_result["trip_I"][:-1].values == events_result["trip_I"][1:].values) + * (events_result["seq"][:-1].values < events_result["seq"][1:].values) )[0] to_indices = from_indices + 1 # these should have same trip_ids - assert (events_result['trip_I'][from_indices].values == events_result['trip_I'][to_indices].values).all() - trip_Is = events_result['trip_I'][from_indices] - from_stops = events_result['stop_I'][from_indices] - to_stops = events_result['stop_I'][to_indices] - shape_ids = events_result['shape_id'][from_indices] - dep_times = events_result['dep_time_ut'][from_indices] - arr_times = events_result['arr_time_ut'][to_indices] - route_types = events_result['route_type'][from_indices] - route_ids = events_result['route_id'][from_indices] - route_Is = events_result['route_I'][from_indices] + assert ( + events_result["trip_I"][from_indices].values + == events_result["trip_I"][to_indices].values + ).all() + trip_Is = events_result["trip_I"][from_indices] + from_stops = events_result["stop_I"][from_indices] + to_stops = events_result["stop_I"][to_indices] + shape_ids = events_result["shape_id"][from_indices] + dep_times = events_result["dep_time_ut"][from_indices] + arr_times = events_result["arr_time_ut"][to_indices] + route_types = events_result["route_type"][from_indices] + route_ids = events_result["route_id"][from_indices] + route_Is = events_result["route_I"][from_indices] durations = arr_times.values - dep_times.values assert (durations >= 0).all() - from_seqs = events_result['seq'][from_indices] - to_seqs = events_result['seq'][to_indices] - data_tuples = zip(from_stops, to_stops, dep_times, arr_times, - shape_ids, route_types, route_ids, trip_Is, - durations, from_seqs, to_seqs, route_Is) - columns = ["from_stop_I", "to_stop_I", "dep_time_ut", "arr_time_ut", - "shape_id", "route_type", "route_id", "trip_I", - "duration", "from_seq", "to_seq", "route_I"] + from_seqs = events_result["seq"][from_indices] + to_seqs = events_result["seq"][to_indices] + data_tuples = zip( + from_stops, + to_stops, + dep_times, + arr_times, + shape_ids, + route_types, + route_ids, + trip_Is, + durations, + from_seqs, + to_seqs, + route_Is, + ) + columns = [ + "from_stop_I", + "to_stop_I", + "dep_time_ut", + "arr_time_ut", + "shape_id", + "route_type", + "route_id", + "trip_I", + "duration", + "from_seq", + "to_seq", + "route_I", + ] df = pd.DataFrame.from_records(data_tuples, columns=columns) return df - def get_route_difference_with_other_db(self, other_gtfs, start_time, end_time, uniqueness_threshold=None, - uniqueness_ratio=None): + def get_route_difference_with_other_db( + self, other_gtfs, start_time, end_time, uniqueness_threshold=None, uniqueness_ratio=None + ): """ Compares the routes based on stops in the schedule with the routes in another db and returns the ones without match. Uniqueness thresholds or ratio can be used to allow small differences @@ -1450,8 +1562,8 @@ def get_route_difference_with_other_db(self, other_gtfs, start_time, end_time, u this_df = frequencies_by_generated_route(self, start_time, end_time) other_df = frequencies_by_generated_route(other_gtfs, start_time, end_time) - this_routes = {x: set(x.split(',')) for x in this_df["route"]} - other_routes = {x: set(x.split(',')) for x in other_df["route"]} + this_routes = {x: set(x.split(",")) for x in this_df["route"]} + other_routes = {x: set(x.split(",")) for x in other_df["route"]} # this_df["route_set"] = this_df.apply(lambda x: set(x.route.split(',')), axis=1) # other_df["route_set"] = other_df.apply(lambda x: set(x.route.split(',')), axis=1) @@ -1463,7 +1575,7 @@ def get_route_difference_with_other_db(self, other_gtfs, start_time, end_time, u for j_key, j in other_routes.items(): union = i | j intersection = i & j - symmetric_difference = i ^ j + # symmetric_difference = i ^ j if uniqueness_ratio: if len(intersection) / len(union) >= uniqueness_ratio: try: @@ -1495,7 +1607,10 @@ def get_section_difference_with_other_db(self, other_conn, start_time, end_time) AND t1.trip_I = trips.trip_I AND trips.route_I = routes.route_I GROUP BY from_stop_I, to_stop_I, routes.route_I ORDER BY route_id) sq1 - GROUP BY from_stop_I, to_stop_I""" % (start_time, end_time) + GROUP BY from_stop_I, to_stop_I""" % ( + start_time, + end_time, + ) prev_df = None result = pd.DataFrame @@ -1503,13 +1618,17 @@ def get_section_difference_with_other_db(self, other_conn, start_time, end_time) df = conn.execute_custom_query_pandas(query) df.set_index(["from_stop_I", "to_stop_I"], inplace=True, drop=True) if prev_df is not None: - result = prev_df.merge(df, how="outer", left_index=True, right_index=True, suffixes=["_old", "_new"]) + result = prev_df.merge( + df, how="outer", left_index=True, right_index=True, suffixes=["_old", "_new"] + ) break prev_df = df for suffix in ["_new", "_old"]: result["all_routes" + suffix] = result["all_routes" + suffix].fillna(value="") - result["all_routes" + suffix] = result["all_routes" + suffix].apply(lambda x: x.split(",")) + result["all_routes" + suffix] = result["all_routes" + suffix].apply( + lambda x: x.split(",") + ) result.reset_index(inplace=True) result.fillna(value=0, inplace=True) for column in ["n_trips", "n_routes"]: @@ -1534,12 +1653,12 @@ def get_straight_line_transfer_distances(self, stop_I=None): d: float or int #distance in meters """ if stop_I is not None: - query = u""" SELECT from_stop_I, to_stop_I, d + query = """ SELECT from_stop_I, to_stop_I, d FROM stop_distances WHERE from_stop_I=? """ - params = (u"{stop_I}".format(stop_I=stop_I),) + params = ("{stop_I}".format(stop_I=stop_I),) else: query = """ SELECT from_stop_I, to_stop_I, d FROM stop_distances @@ -1550,7 +1669,7 @@ def get_straight_line_transfer_distances(self, stop_I=None): def update_stats(self, stats): self.meta.update(stats) - self.meta['stats_calc_at_ut'] = time.time() + self.meta["stats_calc_at_ut"] = time.time() def get_approximate_schedule_time_span_in_ut(self): """ @@ -1576,8 +1695,9 @@ def get_day_start_ut_span(self): last_day_start_ut: int """ cur = self.conn.cursor() - first_day_start_ut, last_day_start_ut = \ - cur.execute("SELECT min(day_start_ut), max(day_start_ut) FROM days;").fetchone() + first_day_start_ut, last_day_start_ut = cur.execute( + "SELECT min(day_start_ut), max(day_start_ut) FROM days;" + ).fetchone() return first_day_start_ut, last_day_start_ut def get_min_date(self): @@ -1597,6 +1717,7 @@ def print_validation_warnings(self): warnings_container: validator.TimetableValidationWarningsContainer """ from .timetable_validator import TimetableValidator + validator = TimetableValidator(self) return validator.validate_and_get_warnings() @@ -1608,10 +1729,13 @@ def execute_custom_query_pandas(self, query): def get_stats(self): from gtfspy import stats + return stats.get_stats(self) def _get_day_trips_table_name(self): - cur = self.conn.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='day_trips2'") + cur = self.conn.execute( + "SELECT name FROM sqlite_master WHERE type='table' AND name='day_trips2'" + ) if len(cur.fetchall()) > 0: table_name = "day_trips2" else: @@ -1636,10 +1760,13 @@ def homogenize_stops_table_with_other_db(self, source): AND find_distance(t1.lon, t1.lat, t2.lon, t2.lat) <= 50""" df_inner_join = self.execute_custom_query_pandas(query_inner_join) print("number of common stops: ", len(df_inner_join.index)) - df_not_in_other = self.execute_custom_query_pandas("SELECT * FROM stops EXCEPT " + query_inner_join) + df_not_in_other = self.execute_custom_query_pandas( + "SELECT * FROM stops EXCEPT " + query_inner_join + ) print("number of stops missing in second feed: ", len(df_not_in_other.index)) - df_not_in_self = self.execute_custom_query_pandas("SELECT * FROM other.stops EXCEPT " + - query_inner_join.replace("t1.*", "t2.*")) + df_not_in_self = self.execute_custom_query_pandas( + "SELECT * FROM other.stops EXCEPT " + query_inner_join.replace("t1.*", "t2.*") + ) print("number of stops missing in first feed: ", len(df_not_in_self.index)) try: self.execute_custom_query("""ALTER TABLE stops ADD COLUMN stop_pair_I INT """) @@ -1661,14 +1788,20 @@ def homogenize_stops_table_with_other_db(self, source): for items in df_not_in_other.itertuples(index=False): rows_to_update_self.append((counter, items[1])) - rows_to_add_to_other.append((stop_id_stub + str(counter),) + tuple(items[x] for x in [2, 3, 4, 5, 6, 8, 9]) - + (counter,)) + rows_to_add_to_other.append( + (stop_id_stub + str(counter),) + + tuple(items[x] for x in [2, 3, 4, 5, 6, 8, 9]) + + (counter,) + ) counter += 1 for items in df_not_in_self.itertuples(index=False): rows_to_update_other.append((counter, items[1])) - rows_to_add_to_self.append((stop_id_stub + str(counter),) + tuple(items[x] for x in [2, 3, 4, 5, 6, 8, 9]) - + (counter,)) + rows_to_add_to_self.append( + (stop_id_stub + str(counter),) + + tuple(items[x] for x in [2, 3, 4, 5, 6, 8, 9]) + + (counter,) + ) counter += 1 query_add_row = """INSERT INTO stops( @@ -1680,7 +1813,9 @@ def homogenize_stops_table_with_other_db(self, source): lon, location_type, wheelchair_boarding, - stop_pair_I) VALUES (%s) """ % (", ".join(["?" for x in range(9)])) + stop_pair_I) VALUES (%s) """ % ( + ", ".join(["?" for x in range(9)]) + ) query_update_row = """UPDATE stops SET stop_pair_I=? WHERE stop_id=?""" print("adding rows to databases") @@ -1698,10 +1833,8 @@ def replace_stop_i_with_stop_pair_i(self): "(SELECT stops.stop_pair_I AS stop_I FROM stops WHERE stops.stop_I = stop_times.stop_I)", # Replace stop_distances "ALTER TABLE stop_distances RENAME TO stop_distances_old", - "CREATE TABLE stop_distances (from_stop_I INT, to_stop_I INT, d INT, d_walk INT, min_transfer_time INT, " "timed_transfer INT, UNIQUE (from_stop_I, to_stop_I))", - "INSERT INTO stop_distances(from_stop_I, to_stop_I, d, d_walk, min_transfer_time, timed_transfer) " "SELECT f_stop.stop_pair_I AS from_stop_I, t_stop.stop_pair_I AS to_stop_I, d, d_walk, min_transfer_time, " "timed_transfer " @@ -1715,25 +1848,20 @@ def replace_stop_i_with_stop_pair_i(self): " JOIN " "(SELECT stop_I, stop_pair_I FROM stops) t_stop " "ON sd_o.to_stop_I = t_stop.stop_I ;", - "DROP TABLE stop_distances_old", - # Replace stops table with other "ALTER TABLE stops RENAME TO stops_old", - "CREATE TABLE stops (stop_I INTEGER PRIMARY KEY, stop_id TEXT UNIQUE NOT NULL, code TEXT, name TEXT, " "desc TEXT, lat REAL, lon REAL, parent_I INT, location_type INT, wheelchair_boarding BOOL, " "self_or_parent_I INT, old_stop_I INT)", - "INSERT INTO stops(stop_I, stop_id, code, name, desc, lat, lon, parent_I, location_type, " "wheelchair_boarding, self_or_parent_I, old_stop_I) " "SELECT stop_pair_I AS stop_I, stop_id, code, name, desc, lat, lon, parent_I, location_type, " "wheelchair_boarding, self_or_parent_I, stop_I AS old_stop_I " "FROM stops_old;", - "DROP TABLE stops_old", - - "CREATE INDEX idx_stops_sid ON stops (stop_I)"] + "CREATE INDEX idx_stops_sid ON stops (stop_I)", + ] for query in queries: cur.execute(query) self.conn.commit() @@ -1741,28 +1869,32 @@ def replace_stop_i_with_stop_pair_i(self): def regenerate_parent_stop_I(self): raise NotImplementedError # get max stop_I - cur = self.conn.cursor() + # cur = self.conn.cursor() - query = "SELECT stop_I FROM stops ORDER BY stop_I DESC LIMIT 1" - max_stop_I = cur.execute(query).fetchall()[0] + # query = "SELECT stop_I FROM stops ORDER BY stop_I DESC LIMIT 1" + # max_stop_I = cur.execute(query).fetchall()[0] - query_update_row = """UPDATE stops SET parent_I=? WHERE parent_I=?""" + # query_update_row = """UPDATE stops SET parent_I=? WHERE parent_I=?""" def add_stops_from_csv(self, csv_dir): - stops_to_add = pd.read_csv(csv_dir, encoding='utf-8') - assert all([x in stops_to_add.columns for x in ["stop_id", "code", "name", "desc", "lat", "lon"]]) + stops_to_add = pd.read_csv(csv_dir, encoding="utf-8") + assert all( + [x in stops_to_add.columns for x in ["stop_id", "code", "name", "desc", "lat", "lon"]] + ) for s in stops_to_add.itertuples(): self.add_stop(s.stop_id, s.code, s.name, s.desc, s.lat, s.lon) def add_stop(self, stop_id, code, name, desc, lat, lon): cur = self.conn.cursor() - query_add_row = 'INSERT INTO stops( stop_id, code, name, desc, lat, lon) ' \ - 'VALUES (?, ?, ?, ?, ?, ?)' + query_add_row = ( + "INSERT INTO stops( stop_id, code, name, desc, lat, lon) " "VALUES (?, ?, ?, ?, ?, ?)" + ) cur.executemany(query_add_row, [[stop_id, code, name, desc, lat, lon]]) self.conn.commit() def recalculate_stop_distances(self, max_distance): from gtfspy.calc_transfers import calc_transfers + calc_transfers(self.conn, max_distance) def attach_gtfs_database(self, gtfs_dir): @@ -1779,7 +1911,9 @@ def update_stop_coordinates(self, stop_updates): """ cur = self.conn.cursor() - stop_values = [(values.lat, values.lon, values.stop_id) for values in stop_updates.itertuples()] + stop_values = [ + (values.lat, values.lon, values.stop_id) for values in stop_updates.itertuples() + ] cur.executemany("""UPDATE stops SET lat = ?, lon = ? WHERE stop_id = ?""", stop_values) self.conn.commit() @@ -1795,8 +1929,7 @@ def __init__(self, conn): self._conn = conn def __getitem__(self, key): - val = self._conn.execute('SELECT value FROM metadata WHERE key=?', - (key,)).fetchone() + val = self._conn.execute("SELECT value FROM metadata WHERE key=?", (key,)).fetchone() if not val: raise KeyError("This GTFS does not have metadata: %s" % key) return val[0] @@ -1804,50 +1937,47 @@ def __getitem__(self, key): def __setitem__(self, key, value): """Get metadata from the DB""" if isinstance(value, bytes): - value = value.decode('utf-8') - self._conn.execute('INSERT OR REPLACE INTO metadata ' - '(key, value) VALUES (?, ?)', - (key, value)).fetchone() + value = value.decode("utf-8") + self._conn.execute( + "INSERT OR REPLACE INTO metadata " "(key, value) VALUES (?, ?)", (key, value) + ).fetchone() self._conn.commit() def __delitem__(self, key): - self._conn.execute('DELETE FROM metadata WHERE key=?', - (key,)).fetchone() + self._conn.execute("DELETE FROM metadata WHERE key=?", (key,)).fetchone() self._conn.commit() def __iter__(self): - cur = self._conn.execute('SELECT key FROM metadata ORDER BY key') + cur = self._conn.execute("SELECT key FROM metadata ORDER BY key") return (x[0] for x in cur) def __contains__(self, key): - val = self._conn.execute('SELECT value FROM metadata WHERE key=?', - (key,)).fetchone() + val = self._conn.execute("SELECT value FROM metadata WHERE key=?", (key,)).fetchone() return val is not None def get(self, key, default=None): - val = self._conn.execute('SELECT value FROM metadata WHERE key=?', - (key,)).fetchone() + val = self._conn.execute("SELECT value FROM metadata WHERE key=?", (key,)).fetchone() if not val: return default return val[0] def items(self): - cur = self._conn.execute('SELECT key, value FROM metadata ORDER BY key') + cur = self._conn.execute("SELECT key, value FROM metadata ORDER BY key") return cur def keys(self): - cur = self._conn.execute('SELECT key FROM metadata ORDER BY key') + cur = self._conn.execute("SELECT key FROM metadata ORDER BY key") return cur def values(self): - cur = self._conn.execute('SELECT value FROM metadata ORDER BY key') + cur = self._conn.execute("SELECT value FROM metadata ORDER BY key") return cur def update(self, dict_): # Would be more efficient to do it in a new query here, but # preferring simplicity. metadata updates are probably # infrequent. - if hasattr(dict_, 'items'): + if hasattr(dict_, "items"): for key, value in dict_.items(): self[key] = value else: @@ -1857,8 +1987,9 @@ def update(self, dict_): def main(cmd, args): from gtfspy import filter + # noinspection PyPackageRequirements - if cmd == 'stats': + if cmd == "stats": print(args[0]) G = GTFS(args[0]) stats = G.get_stats() @@ -1868,30 +1999,30 @@ def main(cmd, args): elif cmd == "validate": G = GTFS(args[0]) G.print_validation_warnings() - elif cmd == 'metadata-list': + elif cmd == "metadata-list": # print args[0] # need to not print to be valid json on stdout G = GTFS(args[0]) # for row in G.meta.items(): # print row stats = dict(G.meta.items()) import json - print(json.dumps(stats, sort_keys=True, - indent=4, separators=(',', ': '))) - elif cmd == 'make-daily': + + print(json.dumps(stats, sort_keys=True, indent=4, separators=(",", ": "))) + elif cmd == "make-daily": from_db = args[0] g = GTFS(from_db) to_db = args[1] - download_date = g.meta['download_date'] - d = datetime.datetime.strptime(download_date, '%Y-%m-%d').date() + download_date = g.meta["download_date"] + d = datetime.datetime.strptime(download_date, "%Y-%m-%d").date() start_time = d + datetime.timedelta(7 - d.isoweekday() + 1) # inclusive end_time = d + datetime.timedelta(7 - d.isoweekday() + 1 + 1) # exclusive filter.filter_extract(g, to_db, start_date=start_time, end_date=end_time) - elif cmd == 'make-weekly': + elif cmd == "make-weekly": from_db = args[0] g = GTFS(from_db) to_db = args[1] - download_date = g.meta['download_date'] - d = datetime.datetime.strptime(download_date, '%Y-%m-%d').date() + download_date = g.meta["download_date"] + d = datetime.datetime.strptime(download_date, "%Y-%m-%d").date() start_time = d + datetime.timedelta(7 - d.isoweekday() + 1) # inclusive end_time = d + datetime.timedelta(7 - d.isoweekday() + 1 + 7) # exclusive print(start_time, end_time) @@ -1904,31 +2035,38 @@ def main(cmd, args): radius_in_km = float(args[3]) to_db = args[4] except Exception as e: - print("spatial-extract usage: python gtfs.py spatial-extract fromdb.sqlite center_lat center_lon " - "radius_in_km todb.sqlite") + print( + "spatial-extract usage: python gtfs.py spatial-extract fromdb.sqlite center_lat center_lon " + "radius_in_km todb.sqlite" + ) raise e logging.basicConfig(level=logging.INFO) logging.info("Loading initial database") g = GTFS(from_db) - filter.filter_extract(g, to_db, buffer_distance=radius_in_km * 1000, buffer_lat=lat, buffer_lon=lon) - elif cmd == 'interact': + filter.filter_extract( + g, to_db, buffer_distance=radius_in_km * 1000, buffer_lat=lat, buffer_lon=lon + ) + elif cmd == "interact": # noinspection PyUnusedLocal G = GTFS(args[0]) # noinspection PyPackageRequirements import IPython + IPython.embed() - elif 'export_shapefile' in cmd: + elif "export_shapefile" in cmd: from gtfspy.util import write_shapefile + from_db = args[ - 0] # '/m/cs/project/networks/jweckstr/transit/scratch/proc_latest/helsinki/2016-04-06/main.day.sqlite' + 0 + ] # '/m/cs/project/networks/jweckstr/transit/scratch/proc_latest/helsinki/2016-04-06/main.day.sqlite' shapefile_path = args[1] # '/m/cs/project/networks/jweckstr/TESTDATA/helsinki_routes.shp' g = GTFS(from_db) - if cmd == 'export_shapefile_routes': + if cmd == "export_shapefile_routes": data = g.get_all_route_shapes(use_shapes=True) - elif cmd == 'export_shapefile_segment_counts': + elif cmd == "export_shapefile_segment_counts": date = args[2] # '2016-04-06' - d = datetime.datetime.strptime(date, '%Y-%m-%d').date() + d = datetime.datetime.strptime(date, "%Y-%m-%d").date() day_start = g.get_day_start_ut(d + datetime.timedelta(7 - d.isoweekday() + 1)) start_time = day_start + 3600 * 7 end_time = day_start + 3600 * 8 @@ -1936,7 +2074,6 @@ def main(cmd, args): write_shapefile(data, shapefile_path) - else: print("Unrecognized command: %s" % cmd) exit(1) diff --git a/gtfspy/import_gtfs.py b/gtfspy/import_gtfs.py index 7c093ca..5c92fbb 100644 --- a/gtfspy/import_gtfs.py +++ b/gtfspy/import_gtfs.py @@ -4,9 +4,23 @@ from __future__ import print_function from __future__ import unicode_literals -from gtfspy.import_loaders import AgencyLoader, CalendarDatesLoader, CalendarLoader, DayLoader, \ - DayTripsMaterializer, FeedInfoLoader, FrequenciesLoader, TripLoader, MetadataLoader, RouteLoader, \ - ShapeLoader, StopDistancesLoader, StopLoader, StopTimesLoader, TransfersLoader +from gtfspy.import_loaders import ( + AgencyLoader, + CalendarDatesLoader, + CalendarLoader, + DayLoader, + DayTripsMaterializer, + FeedInfoLoader, + FrequenciesLoader, + TripLoader, + MetadataLoader, + RouteLoader, + ShapeLoader, + StopDistancesLoader, + StopLoader, + StopTimesLoader, + TransfersLoader, +) from gtfspy.import_loaders.table_loader import ignore_tables, decode_six """ @@ -25,29 +39,36 @@ from gtfspy.gtfs import GTFS -Loaders = [AgencyLoader, # deps: - - RouteLoader, # deps: Agency - MetadataLoader, # deps: - - CalendarLoader, # deps: - - CalendarDatesLoader, # deps: Calendar - ShapeLoader, # deps: - - FeedInfoLoader, # deps: - - StopLoader, # deps: - - TransfersLoader, # deps: Stop - StopDistancesLoader, # deps: (pi: Stop) - TripLoader, # deps: Route, Calendar, (Shape) | (pi2: StopTimes) - StopTimesLoader, # deps: Stop, Trip | |(v: Trip, Day) - FrequenciesLoader, # deps: Trip (pi: Trip, StopTimes) | - DayLoader, # deps: (pi: Calendar, CalendarDates, Trip) | - DayTripsMaterializer # deps: | (pi2: Day) - ] -postprocessors = [ - #validate_day_start_ut, +Loaders = [ + AgencyLoader, # deps: - + RouteLoader, # deps: Agency + MetadataLoader, # deps: - + CalendarLoader, # deps: - + CalendarDatesLoader, # deps: Calendar + ShapeLoader, # deps: - + FeedInfoLoader, # deps: - + StopLoader, # deps: - + TransfersLoader, # deps: Stop + StopDistancesLoader, # deps: (pi: Stop) + TripLoader, # deps: Route, Calendar, (Shape) | (pi2: StopTimes) + StopTimesLoader, # deps: Stop, Trip | |(v: Trip, Day) + FrequenciesLoader, # deps: Trip (pi: Trip, StopTimes) | + DayLoader, # deps: (pi: Calendar, CalendarDates, Trip) | + DayTripsMaterializer, # deps: | (pi2: Day) ] +# postprocessors = [validate_day_start_ut] +postprocessors = [] # type: ignore -def import_gtfs(gtfs_sources, output, preserve_connection=False, - print_progress=True, location_name=None, **kwargs): + +def import_gtfs( + gtfs_sources, + output, + preserve_connection=False, + print_progress=True, + location_name=None, + **kwargs +): """Import a GTFS database gtfs_sources: str, dict, list @@ -78,17 +99,17 @@ def import_gtfs(gtfs_sources, output, preserve_connection=False, # These are a bit unsafe, but make importing much faster, # especially on scratch. - cur.execute('PRAGMA page_size = 4096;') - cur.execute('PRAGMA mmap_size = 1073741824;') - cur.execute('PRAGMA cache_size = -2000000;') - cur.execute('PRAGMA temp_store=2;') + cur.execute("PRAGMA page_size = 4096;") + cur.execute("PRAGMA mmap_size = 1073741824;") + cur.execute("PRAGMA cache_size = -2000000;") + cur.execute("PRAGMA temp_store=2;") # Changes of isolation level are python3.6 workarounds - # eventually will probably be fixed and this can be removed. conn.isolation_level = None # change to autocommit mode (former default) - cur.execute('PRAGMA journal_mode = OFF;') - #cur.execute('PRAGMA journal_mode = WAL;') - cur.execute('PRAGMA synchronous = OFF;') - conn.isolation_level = '' # change back to python default. + cur.execute("PRAGMA journal_mode = OFF;") + # cur.execute('PRAGMA journal_mode = WAL;') + cur.execute("PRAGMA synchronous = OFF;") + conn.isolation_level = "" # change back to python default. # end python3.6 workaround # Do the actual importing. @@ -117,13 +138,14 @@ def import_gtfs(gtfs_sources, output, preserve_connection=False, # Set up same basic metadata. from gtfspy import gtfs as mod_gtfs + G = mod_gtfs.GTFS(output) - G.meta['gen_time_ut'] = time.time() - G.meta['gen_time'] = time.ctime() - G.meta['import_seconds'] = time.time() - time_import_start - G.meta['download_date'] = '' - G.meta['location_name'] = '' - G.meta['n_gtfs_sources'] = len(gtfs_sources) + G.meta["gen_time_ut"] = time.time() + G.meta["gen_time"] = time.ctime() + G.meta["import_seconds"] = time.time() - time_import_start + G.meta["download_date"] = "" + G.meta["location_name"] = "" + G.meta["n_gtfs_sources"] = len(gtfs_sources) # Extract things from GTFS download_date_strs = [] @@ -133,31 +155,31 @@ def import_gtfs(gtfs_sources, output, preserve_connection=False, else: prefix = "feed_" + str(i) + "_" if isinstance(source, string_types): - G.meta[prefix + 'original_gtfs'] = decode_six(source) if source else None + G.meta[prefix + "original_gtfs"] = decode_six(source) if source else None # Extract GTFS date. Last date pattern in filename. - filename_date_list = re.findall(r'\d{4}-\d{2}-\d{2}', source) + filename_date_list = re.findall(r"\d{4}-\d{2}-\d{2}", source) if filename_date_list: date_str = filename_date_list[-1] - G.meta[prefix + 'download_date'] = date_str + G.meta[prefix + "download_date"] = date_str download_date_strs.append(date_str) if location_name: - G.meta['location_name'] = location_name + G.meta["location_name"] = location_name else: - location_name_list = re.findall(r'/([^/]+)/\d{4}-\d{2}-\d{2}', source) + location_name_list = re.findall(r"/([^/]+)/\d{4}-\d{2}-\d{2}", source) if location_name_list: - G.meta[prefix + 'location_name'] = location_name_list[-1] + G.meta[prefix + "location_name"] = location_name_list[-1] else: try: - G.meta[prefix + 'location_name'] = source.split("/")[-4] + G.meta[prefix + "location_name"] = source.split("/")[-4] except: - G.meta[prefix + 'location_name'] = source + G.meta[prefix + "location_name"] = source - if G.meta['download_date'] == "": + if G.meta["download_date"] == "": unique_download_dates = list(set(download_date_strs)) if len(unique_download_dates) == 1: - G.meta['download_date'] = unique_download_dates[0] + G.meta["download_date"] = unique_download_dates[0] - G.meta['timezone'] = cur.execute('SELECT timezone FROM agencies LIMIT 1').fetchone()[0] + G.meta["timezone"] = cur.execute("SELECT timezone FROM agencies LIMIT 1").fetchone()[0] stats.update_stats(G) del G @@ -165,21 +187,22 @@ def import_gtfs(gtfs_sources, output, preserve_connection=False, print("Vacuuming...") # Next 3 lines are python 3.6 work-arounds again. conn.isolation_level = None # former default of autocommit mode - cur.execute('VACUUM;') - conn.isolation_level = '' # back to python default + cur.execute("VACUUM;") + conn.isolation_level = "" # back to python default # end python3.6 workaround if print_progress: print("Analyzing...") - cur.execute('ANALYZE') + cur.execute("ANALYZE") if not (preserve_connection is True): conn.close() + def validate_day_start_ut(conn): """This validates the day_start_ut of the days table.""" G = GTFS(conn) - cur = conn.execute('SELECT date, day_start_ut FROM days') + cur = conn.execute("SELECT date, day_start_ut FROM days") for date, day_start_ut in cur: - #print date, day_start_ut + # print date, day_start_ut assert day_start_ut == G.get_day_start_ut(date) @@ -192,10 +215,12 @@ def main_make_views(gtfs_fname): L(None).make_views(conn) conn.commit() + def main(): import argparse - parser = argparse.ArgumentParser(description=""" + parser = argparse.ArgumentParser( + description=""" Import GTFS files. Imports gtfs. There are two subcommands. The 'import' subcommand converts from a GTFS directory or zipfile to a sqlite database. Both must be specified on the command line. The @@ -203,62 +228,73 @@ def main(): automatically find databases and output files (in scratch/gtfs and scratch/db) based on the shortname given on the command line. This should probably not be used much anymore, instead automate that before - calling this program.""") - subparsers = parser.add_subparsers(dest='cmd') + calling this program.""" + ) + subparsers = parser.add_subparsers(dest="cmd") # parsing import - parser_import = subparsers.add_parser('import', help="Direct import GTFS->sqlite") - parser_import.add_argument('gtfs', help='Input GTFS filename (dir or .zip)') - parser_import.add_argument('output', help='Output .sqlite filename (must end in .sqlite)') - parser.add_argument('--fast', action='store_true', - help='Skip stop_times and shapes tables.') + parser_import = subparsers.add_parser("import", help="Direct import GTFS->sqlite") + parser_import.add_argument("gtfs", help="Input GTFS filename (dir or .zip)") + parser_import.add_argument("output", help="Output .sqlite filename (must end in .sqlite)") + parser.add_argument("--fast", action="store_true", help="Skip stop_times and shapes tables.") # parsing import-auto - parser_importauto = subparsers.add_parser('import-auto', help="Automatic GTFS import from files") - parser_importauto.add_argument('gtfsname', help='Input GTFS filename') + parser_importauto = subparsers.add_parser( + "import-auto", help="Automatic GTFS import from files" + ) + parser_importauto.add_argument("gtfsname", help="Input GTFS filename") # parsing import-multiple - parser_import_multiple = subparsers.add_parser('import-multiple', help="GTFS import from multiple zip-files") - parser_import_multiple.add_argument('zipfiles', metavar='zipfiles', type=str, nargs=argparse.ONE_OR_MORE, - help='zipfiles for the import') - parser_import_multiple.add_argument('output', help='Output .sqlite filename (must end in .sqlite)') + parser_import_multiple = subparsers.add_parser( + "import-multiple", help="GTFS import from multiple zip-files" + ) + parser_import_multiple.add_argument( + "zipfiles", + metavar="zipfiles", + type=str, + nargs=argparse.ONE_OR_MORE, + help="zipfiles for the import", + ) + parser_import_multiple.add_argument( + "output", help="Output .sqlite filename (must end in .sqlite)" + ) # parsing import-list # Parsing copy - parser_copy = subparsers.add_parser('copy', help="Copy database") - parser_copy.add_argument('source', help='Input GTFS .sqlite') - parser_copy.add_argument('dest', help='Output GTFS .sqlite') - parser_copy.add_argument('--start', help='Start copy time (inclusive)') - parser_copy.add_argument('--end', help='Start copy time (exclusive)') + parser_copy = subparsers.add_parser("copy", help="Copy database") + parser_copy.add_argument("source", help="Input GTFS .sqlite") + parser_copy.add_argument("dest", help="Output GTFS .sqlite") + parser_copy.add_argument("--start", help="Start copy time (inclusive)") + parser_copy.add_argument("--end", help="Start copy time (exclusive)") # Parsing copy - parser_copy = subparsers.add_parser('make-views', help="Re-create views") - parser_copy.add_argument('gtfs', help='Input GTFS .sqlite') + parser_copy = subparsers.add_parser("make-views", help="Re-create views") + parser_copy.add_argument("gtfs", help="Input GTFS .sqlite") # make-weekly-download - parser_copy = subparsers.add_parser('make-weekly') - parser_copy.add_argument('source', help='Input GTFS .sqlite') - parser_copy.add_argument('dest', help='Output GTFS .sqlite') + parser_copy = subparsers.add_parser("make-weekly") + parser_copy.add_argument("source", help="Input GTFS .sqlite") + parser_copy.add_argument("dest", help="Output GTFS .sqlite") - parser_copy = subparsers.add_parser('make-daily') - parser_copy.add_argument('source', help='Input GTFS .sqlite') - parser_copy.add_argument('dest', help='Output GTFS .sqlite') + parser_copy = subparsers.add_parser("make-daily") + parser_copy.add_argument("source", help="Input GTFS .sqlite") + parser_copy.add_argument("dest", help="Output GTFS .sqlite") # Export stop distances - parser_copy = subparsers.add_parser('export-stop-distances') - parser_copy.add_argument('gtfs', help='Input GTFS .sqlite') - parser_copy.add_argument('output', help='Output for .txt file') + parser_copy = subparsers.add_parser("export-stop-distances") + parser_copy.add_argument("gtfs", help="Input GTFS .sqlite") + parser_copy.add_argument("output", help="Output for .txt file") # Custom stuff - parser_copy = subparsers.add_parser('custom') - parser_copy.add_argument('gtfs', help='Input GTFS .sqlite') + parser_copy = subparsers.add_parser("custom") + parser_copy.add_argument("gtfs", help="Input GTFS .sqlite") args = parser.parse_args() if args.fast: - ignore_tables.update(('stop_times', 'shapes')) + ignore_tables.update(("stop_times", "shapes")) # if the first argument is import, import a GTFS directory to a .sqlite database. # Both directory and - if args.cmd == 'import': + if args.cmd == "import": gtfs = args.gtfs output = args.output # This context manager makes a tmpfile for import. If there @@ -272,7 +308,7 @@ def main(): print("loaders") with util.create_file(output, tmpdir=True, keepext=True) as tmpfile: import_gtfs(zipfiles, output=tmpfile) - elif args.cmd == 'make-views': + elif args.cmd == "make-views": main_make_views(args.gtfs) # This is now implemented in gtfs.py, please remove the commented code # if no one has touched this in a while.: @@ -291,11 +327,11 @@ def main(): # date_start = d + timedelta(7-d.isoweekday()+1) # inclusive # date_end = d + timedelta(7-d.isoweekday()+1 + 1) # exclusive # G.copy_and_filter(args.dest, start_date=date_start, end_date=date_end) - elif args.cmd == 'export-stop-distances': + elif args.cmd == "export-stop-distances": conn = sqlite3.connect(args.gtfs) L = StopDistancesLoader(conn) - L.export_stop_distances(conn, open(args.output, 'w')) - elif args.cmd == 'custom': + L.export_stop_distances(conn, open(args.output, "w")) + elif args.cmd == "custom": pass # This is designed for just testing things. This code should # always be commented out in the VCS. @@ -307,5 +343,6 @@ def main(): print("Unrecognized command: %s" % args.cmd) exit(1) + if __name__ == "__main__": main() diff --git a/gtfspy/import_loaders/__init__.py b/gtfspy/import_loaders/__init__.py index 6d0e6aa..f2060c4 100644 --- a/gtfspy/import_loaders/__init__.py +++ b/gtfspy/import_loaders/__init__.py @@ -12,4 +12,4 @@ from gtfspy.import_loaders.stop_distances_loader import StopDistancesLoader from gtfspy.import_loaders.stop_loader import StopLoader from gtfspy.import_loaders.stop_times_loader import StopTimesLoader -from gtfspy.import_loaders.transfer_loader import TransfersLoader \ No newline at end of file +from gtfspy.import_loaders.transfer_loader import TransfersLoader diff --git a/gtfspy/import_loaders/agency_loader.py b/gtfspy/import_loaders/agency_loader.py index a60d20a..3b23a89 100644 --- a/gtfspy/import_loaders/agency_loader.py +++ b/gtfspy/import_loaders/agency_loader.py @@ -1,5 +1,3 @@ -import os -import time from datetime import datetime from gtfspy.import_loaders.table_loader import TableLoader, decode_six @@ -7,10 +5,12 @@ class AgencyLoader(TableLoader): - fname = 'agency.txt' - table = 'agencies' - tabledef = ('(agency_I INTEGER PRIMARY KEY, agency_id TEXT UNIQUE NOT NULL, ' - 'name TEXT, url TEXT, timezone TEXT, lang TEXT, phone TEXT)') + fname = "agency.txt" + table = "agencies" + tabledef = ( + "(agency_I INTEGER PRIMARY KEY, agency_id TEXT UNIQUE NOT NULL, " + "name TEXT, url TEXT, timezone TEXT, lang TEXT, phone TEXT)" + ) # shape_id,shape_pt_lat,shape_pt_lon,shape_pt_sequence # 1001_20140811_1,60.167430,24.951684,1 @@ -19,29 +19,32 @@ def gen_rows(self, readers, prefixes): for reader, prefix in zip(readers, prefixes): for row in reader: yield dict( - agency_id =prefix + decode_six(row.get('agency_id', '1')), - name = decode_six(row['agency_name']), - timezone = decode_six(row['agency_timezone']), - url = decode_six(row['agency_url']), - lang = decode_six(row['agency_lang']) if 'agency_lang' in row else None, - phone = decode_six(row['agency_phone']) if 'agency_phone' in row else None, + agency_id=prefix + decode_six(row.get("agency_id", "1")), + name=decode_six(row["agency_name"]), + timezone=decode_six(row["agency_timezone"]), + url=decode_six(row["agency_url"]), + lang=decode_six(row["agency_lang"]) if "agency_lang" in row else None, + phone=decode_six(row["agency_phone"]) if "agency_phone" in row else None, ) def post_import(self, cur): - TZs = cur.execute('SELECT DISTINCT timezone FROM agencies').fetchall() + TZs = cur.execute("SELECT DISTINCT timezone FROM agencies").fetchall() if len(TZs) == 0: raise ValueError("Error: no timezones defined in sources: %s" % self.gtfs_sources) elif len(TZs) > 1: first_tz = TZs[0][0] import pytz + for tz in TZs[1:]: generic_date = datetime(2009, 9, 1) ftz = pytz.timezone(first_tz).utcoffset(generic_date, is_dst=True) ctz = pytz.timezone(tz[0]).utcoffset(generic_date, is_dst=True) if not str(ftz) == str(ctz): - raise ValueError("Error: multiple timezones defined in sources:: %s" % self.gtfs_sources) + raise ValueError( + "Error: multiple timezones defined in sources:: %s" % self.gtfs_sources + ) TZ = TZs[0][0] set_process_timezone(TZ) def index(self, cur): - pass \ No newline at end of file + pass diff --git a/gtfspy/import_loaders/calendar_dates_loader.py b/gtfspy/import_loaders/calendar_dates_loader.py index a16c578..a84a820 100644 --- a/gtfspy/import_loaders/calendar_dates_loader.py +++ b/gtfspy/import_loaders/calendar_dates_loader.py @@ -2,42 +2,46 @@ class CalendarDatesLoader(TableLoader): - fname = 'calendar_dates.txt' - table = 'calendar_dates' - tabledef = '(service_I INTEGER NOT NULL, date TEXT, exception_type INT)' - copy_where = ("WHERE date({start_ut}, 'unixepoch', 'localtime') <= date " - "AND date < date({end_ut}, 'unixepoch', 'localtime')") + fname = "calendar_dates.txt" + table = "calendar_dates" + tabledef = "(service_I INTEGER NOT NULL, date TEXT, exception_type INT)" + copy_where = ( + "WHERE date({start_ut}, 'unixepoch', 'localtime') <= date " + "AND date < date({end_ut}, 'unixepoch', 'localtime')" + ) def gen_rows(self, readers, prefixes): conn = self._conn cur = conn.cursor() for reader, prefix in zip(readers, prefixes): for row in reader: - date = row['date'] - date_str = '%s-%s-%s' % (date[:4], date[4:6], date[6:8]) - service_id = prefix+row['service_id'] + date = row["date"] + date_str = "%s-%s-%s" % (date[:4], date[4:6], date[6:8]) + service_id = prefix + row["service_id"] # We need to find the service_I of this. To do this we # need to check the calendar table, since that (and only # that) is the absolute list of service_ids. service_I = cur.execute( - 'SELECT service_I FROM calendar WHERE service_id=?', - (decode_six(service_id),)).fetchone() + "SELECT service_I FROM calendar WHERE service_id=?", (decode_six(service_id),) + ).fetchone() if service_I is None: # We have to add a new fake row in order to get a # service_I. calendar is *the* authoritative source # for service_I:s. - cur.execute('INSERT INTO calendar ' - '(service_id, m,t,w,th,f,s,su, start_date,end_date)' - 'VALUES (?, 0,0,0,0,0,0,0, ?,?)', - (decode_six(service_id), date_str, date_str) - ) + cur.execute( + "INSERT INTO calendar " + "(service_id, m,t,w,th,f,s,su, start_date,end_date)" + "VALUES (?, 0,0,0,0,0,0,0, ?,?)", + (decode_six(service_id), date_str, date_str), + ) service_I = cur.execute( - 'SELECT service_I FROM calendar WHERE service_id=?', - (decode_six(service_id),)).fetchone() + "SELECT service_I FROM calendar WHERE service_id=?", + (decode_six(service_id),), + ).fetchone() service_I = service_I[0] # row tuple -> int yield dict( - service_I = int(service_I), - date = date_str, - exception_type= int(row['exception_type']), - ) \ No newline at end of file + service_I=int(service_I), + date=date_str, + exception_type=int(row["exception_type"]), + ) diff --git a/gtfspy/import_loaders/calendar_loader.py b/gtfspy/import_loaders/calendar_loader.py index afa7968..e8c975e 100644 --- a/gtfspy/import_loaders/calendar_loader.py +++ b/gtfspy/import_loaders/calendar_loader.py @@ -2,11 +2,13 @@ class CalendarLoader(TableLoader): - fname = 'calendar.txt' - table = 'calendar' - tabledef = '(service_I INTEGER PRIMARY KEY, service_id TEXT UNIQUE NOT NULL, m INT, t INT, w INT, th INT, f INT, s INT, su INT, start_date TEXT, end_date TEXT)' - copy_where = ("WHERE date({start_ut}, 'unixepoch', 'localtime') < end_date " - "AND start_date < date({end_ut}, 'unixepoch', 'localtime')") + fname = "calendar.txt" + table = "calendar" + tabledef = "(service_I INTEGER PRIMARY KEY, service_id TEXT UNIQUE NOT NULL, m INT, t INT, w INT, th INT, f INT, s INT, su INT, start_date TEXT, end_date TEXT)" + copy_where = ( + "WHERE date({start_ut}, 'unixepoch', 'localtime') < end_date " + "AND start_date < date({end_ut}, 'unixepoch', 'localtime')" + ) # service_id,monday,tuesday,wednesday,thursday,friday,saturday,sunday,start_date,end_date # 1001_20150810_20151014_Ke,0,0,1,0,0,0,0,20150810,20151014 @@ -14,29 +16,31 @@ def gen_rows(self, readers, prefixes): for reader, prefix in zip(readers, prefixes): for row in reader: # print row - start = row['start_date'] - end = row['end_date'] + start = row["start_date"] + end = row["end_date"] yield dict( - service_id = prefix + decode_six(row['service_id']), - m = int(row['monday']), - t = int(row['tuesday']), - w = int(row['wednesday']), - th = int(row['thursday']), - f = int(row['friday']), - s = int(row['saturday']), - su = int(row['sunday']), - start_date = '%s-%s-%s' % (start[:4], start[4:6], start[6:8]), - end_date = '%s-%s-%s' % (end[:4], end[4:6], end[6:8]), + service_id=prefix + decode_six(row["service_id"]), + m=int(row["monday"]), + t=int(row["tuesday"]), + w=int(row["wednesday"]), + th=int(row["thursday"]), + f=int(row["friday"]), + s=int(row["saturday"]), + su=int(row["sunday"]), + start_date="%s-%s-%s" % (start[:4], start[4:6], start[6:8]), + end_date="%s-%s-%s" % (end[:4], end[4:6], end[6:8]), ) @classmethod def index(cls, cur): # cur.execute('CREATE INDEX IF NOT EXISTS idx_calendar_svid ON calendar (service_id)') - cur.execute('CREATE INDEX IF NOT EXISTS idx_calendar_s_e ON calendar (start_date, end_date)') - cur.execute('CREATE INDEX IF NOT EXISTS idx_calendar_m ON calendar (m)') - cur.execute('CREATE INDEX IF NOT EXISTS idx_calendar_t ON calendar (t)') - cur.execute('CREATE INDEX IF NOT EXISTS idx_calendar_w ON calendar (w)') - cur.execute('CREATE INDEX IF NOT EXISTS idx_calendar_th ON calendar (th)') - cur.execute('CREATE INDEX IF NOT EXISTS idx_calendar_f ON calendar (f)') - cur.execute('CREATE INDEX IF NOT EXISTS idx_calendar_s ON calendar (s)') - cur.execute('CREATE INDEX IF NOT EXISTS idx_calendar_su ON calendar (su)') \ No newline at end of file + cur.execute( + "CREATE INDEX IF NOT EXISTS idx_calendar_s_e ON calendar (start_date, end_date)" + ) + cur.execute("CREATE INDEX IF NOT EXISTS idx_calendar_m ON calendar (m)") + cur.execute("CREATE INDEX IF NOT EXISTS idx_calendar_t ON calendar (t)") + cur.execute("CREATE INDEX IF NOT EXISTS idx_calendar_w ON calendar (w)") + cur.execute("CREATE INDEX IF NOT EXISTS idx_calendar_th ON calendar (th)") + cur.execute("CREATE INDEX IF NOT EXISTS idx_calendar_f ON calendar (f)") + cur.execute("CREATE INDEX IF NOT EXISTS idx_calendar_s ON calendar (s)") + cur.execute("CREATE INDEX IF NOT EXISTS idx_calendar_su ON calendar (su)") diff --git a/gtfspy/import_loaders/day_loader.py b/gtfspy/import_loaders/day_loader.py index eef4e80..72ad96e 100644 --- a/gtfspy/import_loaders/day_loader.py +++ b/gtfspy/import_loaders/day_loader.py @@ -2,12 +2,13 @@ from gtfspy.import_loaders.table_loader import TableLoader + class DayLoader(TableLoader): # Note: calendar and calendar_dates should have been imported before # importing with DayLoader fname = None - table = 'days' - tabledef = '(date TEXT, day_start_ut INT, trip_I INT)' + table = "days" + tabledef = "(date TEXT, day_start_ut INT, trip_I INT)" copy_where = "WHERE {start_ut} <= day_start_ut AND day_start_ut < {end_ut}" def post_import(self, cur): @@ -18,8 +19,8 @@ def index(self, cur): def create_day_table_indices(cursor): - cursor.execute('CREATE INDEX IF NOT EXISTS idx_days_day ON days (date)') - cursor.execute('CREATE INDEX IF NOT EXISTS idx_days_dsut_tid ON days (day_start_ut, trip_I)') + cursor.execute("CREATE INDEX IF NOT EXISTS idx_days_day ON days (date)") + cursor.execute("CREATE INDEX IF NOT EXISTS idx_days_dsut_tid ON days (day_start_ut, trip_I)") def drop_day_table_indices(cursor): @@ -35,9 +36,9 @@ def insert_data_to_days(cur, conn): # future processing. So, create it here, delete it at the end # of the function. If this index was important, it could be # moved to CalendarDatesLoader. - cur.execute('CREATE INDEX IF NOT EXISTS idx_calendar_dates_sid ON calendar_dates (service_I)') + cur.execute("CREATE INDEX IF NOT EXISTS idx_calendar_dates_sid ON calendar_dates (service_I)") - cur.execute('SELECT * FROM calendar') + cur.execute("SELECT * FROM calendar") colnames = cur.description cur2 = conn.cursor() @@ -53,27 +54,28 @@ def iter_dates(start, end): yield date date += one_day - weekdays = ['m', 't', 'w', 'th', 'f', 's', 'su'] + weekdays = ["m", "t", "w", "th", "f", "s", "su"] # For every row in the calendar... for row in cur: row = make_dict(row) - service_I = int(row['service_I']) + service_I = int(row["service_I"]) # EXCEPTIONS (calendar_dates): Get a set of all # exceptional days. exception_type=2 means that service # removed on that day. Below, we will exclude all dates # that are in this set. - cur2.execute('SELECT date FROM calendar_dates ' - 'WHERE service_I=? and exception_type=?', - (service_I, 2)) + cur2.execute( + "SELECT date FROM calendar_dates " "WHERE service_I=? and exception_type=?", + (service_I, 2), + ) exception_dates = set(x[0] for x in cur2.fetchall()) # - start_date = datetime.strptime(row['start_date'], '%Y-%m-%d').date() - end_date = datetime.strptime(row['end_date'], '%Y-%m-%d').date() + start_date = datetime.strptime(row["start_date"], "%Y-%m-%d").date() + end_date = datetime.strptime(row["end_date"], "%Y-%m-%d").date() # For every date in that row's date range... for date in iter_dates(start_date, end_date): weekday = date.isoweekday() - 1 # -1 to match weekdays list above # Exclude dates with service exceptions - date_str = date.strftime('%Y-%m-%d') + date_str = date.strftime("%Y-%m-%d") if date_str in exception_dates: # print "calendar_dates.txt exception: removing %s from %s"%(service_I, date) continue @@ -83,24 +85,28 @@ def iter_dates(start, end): days.append((date, service_I)) # Store in database, day_start_ut is "noon minus 12 hours". - cur.executemany("""INSERT INTO days + cur.executemany( + """INSERT INTO days (date, day_start_ut, trip_I) SELECT ?, strftime('%s', ?, '12:00', 'utc')-43200, trip_I FROM trips WHERE service_I=? - """, ((date, date, service_I) - for date, service_I in days)) + """, + ((date, date, service_I) for date, service_I in days), + ) # EXCEPTIONS: Add in dates with exceptions. Find them and # store them directly in the database. - cur2.execute("INSERT INTO days " - "(date, day_start_ut, trip_I) " - "SELECT date, strftime('%s',date,'12:00','utc')-43200, trip_I " - "FROM trips " - "JOIN calendar_dates USING(service_I) " - "WHERE exception_type=?", - (1,)) + cur2.execute( + "INSERT INTO days " + "(date, day_start_ut, trip_I) " + "SELECT date, strftime('%s',date,'12:00','utc')-43200, trip_I " + "FROM trips " + "JOIN calendar_dates USING(service_I) " + "WHERE exception_type=?", + (1,), + ) conn.commit() - cur.execute('DROP INDEX IF EXISTS main.idx_calendar_dates_sid') + cur.execute("DROP INDEX IF EXISTS main.idx_calendar_dates_sid") def recreate_days_table(conn): diff --git a/gtfspy/import_loaders/day_trips_materializer.py b/gtfspy/import_loaders/day_trips_materializer.py index 494d1c1..a109cc2 100644 --- a/gtfspy/import_loaders/day_trips_materializer.py +++ b/gtfspy/import_loaders/day_trips_materializer.py @@ -12,23 +12,22 @@ class DayTripsMaterializer(TableLoader): day_trips: Replacement for the old day_trips view. day_trips2+trips day_stop_times: day_trips2+trips+stop_times """ + fname = None - table = 'day_trips2' - tabledef = ('(date TEXT, ' - 'trip_I INT, ' - 'start_time_ut INT, ' - 'end_time_ut INT, ' - 'day_start_ut INT)') - copy_where = 'WHERE {start_ut} < end_time_ut AND start_time_ut < {end_ut}' + table = "day_trips2" + tabledef = ( + "(date TEXT, " "trip_I INT, " "start_time_ut INT, " "end_time_ut INT, " "day_start_ut INT)" + ) + copy_where = "WHERE {start_ut} < end_time_ut AND start_time_ut < {end_ut}" @classmethod def post_import_round2(cls, conn): insert_data_to_day_trips2(conn) - def index(cls, cur): + @staticmethod + def index(cur): create_day_trips_indices(cur) - @classmethod def make_views(cls, conn): """Create day_trips and day_stop_times views. @@ -36,59 +35,67 @@ def make_views(cls, conn): day_trips: day_trips2 x trips = days x trips day_stop_times: day_trips2 x trips x stop_times = days x trips x stop_times """ - conn.execute('DROP VIEW IF EXISTS main.day_trips') - conn.execute('CREATE VIEW day_trips AS ' - 'SELECT day_trips2.*, trips.* ' - #'days.day_start_ut+trips.start_time_ds AS start_time_ut, ' - #'days.day_start_ut+trips.end_time_ds AS end_time_ut ' - 'FROM day_trips2 JOIN trips USING (trip_I);') + conn.execute("DROP VIEW IF EXISTS main.day_trips") + conn.execute( + "CREATE VIEW day_trips AS " + "SELECT day_trips2.*, trips.* " + # 'days.day_start_ut+trips.start_time_ds AS start_time_ut, ' + # 'days.day_start_ut+trips.end_time_ds AS end_time_ut ' + "FROM day_trips2 JOIN trips USING (trip_I);" + ) conn.commit() - conn.execute('DROP VIEW IF EXISTS main.day_stop_times') - conn.execute('CREATE VIEW day_stop_times AS ' - 'SELECT day_trips2.*, trips.*, stop_times.*, ' - #'days.day_start_ut+trips.start_time_ds AS start_time_ut, ' - #'days.day_start_ut+trips.end_time_ds AS end_time_ut, ' - 'day_trips2.day_start_ut+stop_times.arr_time_ds AS arr_time_ut, ' - 'day_trips2.day_start_ut+stop_times.dep_time_ds AS dep_time_ut ' - 'FROM day_trips2 ' - 'JOIN trips USING (trip_I) ' - 'JOIN stop_times USING (trip_I)') + conn.execute("DROP VIEW IF EXISTS main.day_stop_times") + conn.execute( + "CREATE VIEW day_stop_times AS " + "SELECT day_trips2.*, trips.*, stop_times.*, " + # 'days.day_start_ut+trips.start_time_ds AS start_time_ut, ' + # 'days.day_start_ut+trips.end_time_ds AS end_time_ut, ' + "day_trips2.day_start_ut+stop_times.arr_time_ds AS arr_time_ut, " + "day_trips2.day_start_ut+stop_times.dep_time_ds AS dep_time_ut " + "FROM day_trips2 " + "JOIN trips USING (trip_I) " + "JOIN stop_times USING (trip_I)" + ) conn.commit() def create_day_trips_indices(cur): - cur.execute('CREATE INDEX IF NOT EXISTS idx_day_trips2_tid ' - 'ON day_trips2 (trip_I)') - cur.execute('CREATE INDEX IF NOT EXISTS idx_day_trips2_d ' - 'ON day_trips2 (date)') - cur.execute('CREATE INDEX IF NOT EXISTS idx_day_trips2_stut_etut ' - 'ON day_trips2 (start_time_ut, end_time_ut)') + cur.execute("CREATE INDEX IF NOT EXISTS idx_day_trips2_tid " "ON day_trips2 (trip_I)") + cur.execute("CREATE INDEX IF NOT EXISTS idx_day_trips2_d " "ON day_trips2 (date)") + cur.execute( + "CREATE INDEX IF NOT EXISTS idx_day_trips2_stut_etut " + "ON day_trips2 (start_time_ut, end_time_ut)" + ) # This index may not be needed anymore. - cur.execute('CREATE INDEX IF NOT EXISTS idx_day_trips2_dsut ' - 'ON day_trips2 (day_start_ut)') + cur.execute("CREATE INDEX IF NOT EXISTS idx_day_trips2_dsut " "ON day_trips2 (day_start_ut)") + def drop_day_trip_indices(cur): - cur.execute('DROP INDEX IF EXISTS idx_day_trips2_tid') - cur.execute('DROP INDEX IF EXISTS idx_day_trips2_d') - cur.execute('DROP INDEX IF EXISTS idx_day_trips2_stut_etut') - cur.execute('DROP INDEX IF EXISTS idx_day_trips2_dsut') + cur.execute("DROP INDEX IF EXISTS idx_day_trips2_tid") + cur.execute("DROP INDEX IF EXISTS idx_day_trips2_d") + cur.execute("DROP INDEX IF EXISTS idx_day_trips2_stut_etut") + cur.execute("DROP INDEX IF EXISTS idx_day_trips2_dsut") + def insert_data_to_day_trips2(conn): cur = conn.cursor() - cur.execute('DELETE FROM day_trips2') - cur.execute('INSERT INTO day_trips2 ' - 'SELECT date, trip_I, ' - 'days.day_start_ut+trips.start_time_ds AS start_time_ut, ' - 'days.day_start_ut+trips.end_time_ds AS end_time_ut, ' - 'day_start_ut ' - 'FROM days JOIN trips USING (trip_I)') + cur.execute("DELETE FROM day_trips2") + cur.execute( + "INSERT INTO day_trips2 " + "SELECT date, trip_I, " + "days.day_start_ut+trips.start_time_ds AS start_time_ut, " + "days.day_start_ut+trips.end_time_ds AS end_time_ut, " + "day_start_ut " + "FROM days JOIN trips USING (trip_I)" + ) # Delete rows, where start_time_ut or end_time_ut IS NULL. # This could happen e.g. if stop_times are missing for some trip. cur.execute("DELETE FROM day_trips2 WHERE start_time_ut IS NULL or end_time_ut IS NULL") conn.commit() + def recreate_day_trips2_table(conn): drop_day_trip_indices(conn.cursor()) insert_data_to_day_trips2(conn) - create_day_trips_indices(conn.cursor()) \ No newline at end of file + create_day_trips_indices(conn.cursor()) diff --git a/gtfspy/import_loaders/feed_info_loader.py b/gtfspy/import_loaders/feed_info_loader.py index c38cec9..4e9eb2b 100644 --- a/gtfspy/import_loaders/feed_info_loader.py +++ b/gtfspy/import_loaders/feed_info_loader.py @@ -5,41 +5,52 @@ class FeedInfoLoader(TableLoader): """feed_info.txt: various feed metadata""" - fname = 'feed_info.txt' - table = 'feed_info' - tabledef = ('(feed_publisher_name TEXT, ' - 'feed_publisher_url TEXT, ' - 'feed_lang TEXT, ' - 'feed_start_date TEXT, ' - 'feed_end_date TEXT, ' - 'feed_version TEXT, ' - 'feed_id TEXT) ') + + fname = "feed_info.txt" + table = "feed_info" + tabledef = ( + "(feed_publisher_name TEXT, " + "feed_publisher_url TEXT, " + "feed_lang TEXT, " + "feed_start_date TEXT, " + "feed_end_date TEXT, " + "feed_version TEXT, " + "feed_id TEXT) " + ) def gen_rows(self, readers, prefixes): for reader, prefix in zip(readers, prefixes): for row in reader: - #print row - start = row['feed_start_date'] if 'feed_start_date' in row else None - end = row['feed_end_date'] if 'feed_end_date' in row else None + # print row + start = row["feed_start_date"] if "feed_start_date" in row else None + end = row["feed_end_date"] if "feed_end_date" in row else None yield dict( - feed_publisher_name = decode_six(row['feed_publisher_name']) if 'feed_publisher_name' in row else None, - feed_publisher_url = decode_six(row['feed_publisher_url']) if 'feed_publisher_url' in row else None, - feed_lang = decode_six(row['feed_lang']) if 'feed_lang' in row else None, - feed_start_date = '%s-%s-%s'%(start[:4], start[4:6], start[6:8]) if start else None, - feed_end_date = '%s-%s-%s'%(end[:4], end[4:6], end[6:8]) if end else None, - feed_version = decode_six(row['feed_version']) if 'feed_version' in row else None, - feed_id = prefix[:-1] if len(prefix) > 0 else prefix + feed_publisher_name=decode_six(row["feed_publisher_name"]) + if "feed_publisher_name" in row + else None, + feed_publisher_url=decode_six(row["feed_publisher_url"]) + if "feed_publisher_url" in row + else None, + feed_lang=decode_six(row["feed_lang"]) if "feed_lang" in row else None, + feed_start_date="%s-%s-%s" % (start[:4], start[4:6], start[6:8]) + if start + else None, + feed_end_date="%s-%s-%s" % (end[:4], end[4:6], end[6:8]) if end else None, + feed_version=decode_six(row["feed_version"]) if "feed_version" in row else None, + feed_id=prefix[:-1] if len(prefix) > 0 else prefix, ) def post_import2(self, conn): # TODO! Something whould be done with this! Multiple feeds are possible, currently only selects one row for all feeds G = GTFS(conn) - for name in ['feed_publisher_name', - 'feed_publisher_url', - 'feed_lang', - 'feed_start_date', - 'feed_end_date', - 'feed_version']: - value = conn.execute('SELECT %s FROM feed_info' % name).fetchone()[0] - if value: - G.meta['feed_info_' + name] = value \ No newline at end of file + for name in [ + "feed_publisher_name", + "feed_publisher_url", + "feed_lang", + "feed_start_date", + "feed_end_date", + "feed_version", + ]: + value = conn.execute("SELECT %s FROM feed_info" % name).fetchone()[0] + if value: + G.meta["feed_info_" + name] = value diff --git a/gtfspy/import_loaders/frequencies_loader.py b/gtfspy/import_loaders/frequencies_loader.py index 9516248..c7c61f1 100644 --- a/gtfspy/import_loaders/frequencies_loader.py +++ b/gtfspy/import_loaders/frequencies_loader.py @@ -6,35 +6,42 @@ class FrequenciesLoader(TableLoader): """Load the general frequency table.""" - fname = 'frequencies.txt' - table = 'frequencies' - - tabledef = (u'(trip_I INT, ' - u'start_time TEXT, ' - u'end_time TEXT, ' - u'headway_secs INT,' - u'exact_times INT, ' - u'start_time_ds INT, ' - u'end_time_ds INT' - u')') - extra_keys = [u'trip_I', - u'start_time_ds', - u'end_time_ds', - ] - extra_values = [u'(SELECT trip_I FROM trips WHERE trip_id=:_trip_id )', - '(substr(:start_time,-8,2)*3600 + substr(:start_time,-5,2)*60 + substr(:start_time,-2))', - '(substr(:end_time,-8,2)*3600 + substr(:end_time,-5,2)*60 + substr(:end_time,-2))', - ] + + fname = "frequencies.txt" + table = "frequencies" + + tabledef = ( + "(trip_I INT, " + "start_time TEXT, " + "end_time TEXT, " + "headway_secs INT," + "exact_times INT, " + "start_time_ds INT, " + "end_time_ds INT" + ")" + ) + extra_keys = [ + "trip_I", + "start_time_ds", + "end_time_ds", + ] + extra_values = [ + "(SELECT trip_I FROM trips WHERE trip_id=:_trip_id )", + "(substr(:start_time,-8,2)*3600 + substr(:start_time,-5,2)*60 + substr(:start_time,-2))", + "(substr(:end_time,-8,2)*3600 + substr(:end_time,-5,2)*60 + substr(:end_time,-2))", + ] def gen_rows(self, readers, prefixes): for reader, prefix in zip(readers, prefixes): for row in reader: yield dict( - _trip_id=prefix + decode_six(row['trip_id']), - start_time=row['start_time'], - end_time=row['end_time'], - headway_secs=int(row['headway_secs']), - exact_times=int(row['exact_times']) if 'exact_times' in row and row['exact_times'].isdigit() else 0 + _trip_id=prefix + decode_six(row["trip_id"]), + start_time=row["start_time"], + end_time=row["end_time"], + headway_secs=int(row["headway_secs"]), + exact_times=int(row["exact_times"]) + if "exact_times" in row and row["exact_times"].isdigit() + else 0, ) def post_import(self, cur): @@ -43,17 +50,24 @@ def post_import(self, cur): frequencies_df = pandas.read_sql("SELECT * FROM " + self.table, conn) for freq_tuple in frequencies_df.itertuples(): - trip_data = pandas.read_sql_query("SELECT * FROM trips WHERE trip_I= " + str(int(freq_tuple.trip_I)), conn) + trip_data = pandas.read_sql_query( + "SELECT * FROM trips WHERE trip_I= " + str(int(freq_tuple.trip_I)), conn + ) assert len(trip_data) == 1 trip_data = list(trip_data.itertuples())[0] freq_start_time_ds = freq_tuple.start_time_ds freq_end_time_ds = freq_tuple.end_time_ds - trip_duration = cur.execute("SELECT max(arr_time_ds) - min(dep_time_ds) " - "FROM stop_times " - "WHERE trip_I={trip_I}".format(trip_I=str(int(freq_tuple.trip_I))) - ).fetchone()[0] + trip_duration = cur.execute( + "SELECT max(arr_time_ds) - min(dep_time_ds) " + "FROM stop_times " + "WHERE trip_I={trip_I}".format(trip_I=str(int(freq_tuple.trip_I))) + ).fetchone()[0] if trip_duration is None: - raise ValueError("Stop times for frequency trip " + trip_data.trip_id + " are not properly defined") + raise ValueError( + "Stop times for frequency trip " + + trip_data.trip_id + + " are not properly defined" + ) headway = freq_tuple.headway_secs sql = "SELECT * FROM stop_times WHERE trip_I=" + str(trip_data.trip_I) + " ORDER BY seq" @@ -61,7 +75,7 @@ def post_import(self, cur): start_times_ds = range(freq_start_time_ds, freq_end_time_ds, headway) for i, start_time in enumerate(start_times_ds): - trip_id = trip_data.trip_id + u"_freq_" + str(start_time) + trip_id = trip_data.trip_id + "_freq_" + str(start_time) route_I = trip_data.route_I service_I = trip_data.service_I @@ -71,11 +85,22 @@ def post_import(self, cur): end_time_ds = start_time + trip_duration # insert these into trips - query = "INSERT INTO trips (trip_id, route_I, service_I, shape_id, direction_id, " \ - "headsign, start_time_ds, end_time_ds)" \ - " VALUES (?, ?, ?, ?, ?, ?, ?, ?)" + query = ( + "INSERT INTO trips (trip_id, route_I, service_I, shape_id, direction_id, " + "headsign, start_time_ds, end_time_ds)" + " VALUES (?, ?, ?, ?, ?, ?, ?, ?)" + ) - params = [trip_id, int(route_I), int(service_I), shape_id, direction_id, headsign, int(start_time), int(end_time_ds)] + params = [ + trip_id, + int(route_I), + int(service_I), + shape_id, + direction_id, + headsign, + int(start_time), + int(end_time_ds), + ] cur.execute(query, params) query = "SELECT trip_I FROM trips WHERE trip_id='{trip_id}'".format(trip_id=trip_id) @@ -83,12 +108,12 @@ def post_import(self, cur): # insert into stop_times # TODO! get the original data - dep_times_ds = stop_time_data['dep_time_ds'] + dep_times_ds = stop_time_data["dep_time_ds"] dep_times_ds = dep_times_ds - min(dep_times_ds) + start_time - arr_times_ds = stop_time_data['arr_time_ds'] + arr_times_ds = stop_time_data["arr_time_ds"] arr_times_ds = arr_times_ds - min(arr_times_ds) + start_time - shape_breaks_series = stop_time_data['shape_break'] - stop_Is = stop_time_data['stop_I'] + shape_breaks_series = stop_time_data["shape_break"] + stop_Is = stop_time_data["stop_I"] shape_breaks = [] for shape_break in shape_breaks_series: @@ -99,22 +124,37 @@ def post_import(self, cur): pass shape_breaks.append(value) - for seq, (dep_time_ds, arr_time_ds, shape_break, stop_I) in enumerate(zip(dep_times_ds, - arr_times_ds, - shape_breaks, - stop_Is)): + for seq, (dep_time_ds, arr_time_ds, shape_break, stop_I) in enumerate( + zip(dep_times_ds, arr_times_ds, shape_breaks, stop_Is) + ): arr_time_hour = int(arr_time_ds // 3600) - query = "INSERT INTO stop_times (trip_I, stop_I, arr_time, " \ - "dep_time, seq, arr_time_hour, shape_break, arr_time_ds, dep_time_ds) " \ - "VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)" + query = ( + "INSERT INTO stop_times (trip_I, stop_I, arr_time, " + "dep_time, seq, arr_time_hour, shape_break, arr_time_ds, dep_time_ds) " + "VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)" + ) arr_time = util.day_seconds_to_str_time(arr_time_ds) dep_time = util.day_seconds_to_str_time(dep_time_ds) - cur.execute(query, (int(trip_I), int(stop_I), arr_time, dep_time, int(seq + 1), - int(arr_time_hour), shape_break, int(arr_time_ds), int(dep_time_ds))) - - trip_Is = frequencies_df['trip_I'].unique() + cur.execute( + query, + ( + int(trip_I), + int(stop_I), + arr_time, + dep_time, + int(seq + 1), + int(arr_time_hour), + shape_break, + int(arr_time_ds), + int(dep_time_ds), + ), + ) + + trip_Is = frequencies_df["trip_I"].unique() for trip_I in trip_Is: for table in ["trips", "stop_times"]: - cur.execute("DELETE FROM {table} WHERE trip_I={trip_I}".format(table=table, trip_I=trip_I)) - self._conn.commit() \ No newline at end of file + cur.execute( + "DELETE FROM {table} WHERE trip_I={trip_I}".format(table=table, trip_I=trip_I) + ) + self._conn.commit() diff --git a/gtfspy/import_loaders/metadata_loader.py b/gtfspy/import_loaders/metadata_loader.py index 2c20433..d577bcf 100644 --- a/gtfspy/import_loaders/metadata_loader.py +++ b/gtfspy/import_loaders/metadata_loader.py @@ -3,15 +3,15 @@ class MetadataLoader(TableLoader): """Table to be used for any type of metadata""" + fname = None - table = 'metadata' - tabledef = '(key TEXT UNIQUE NOT NULL, value BLOB, value2 BLOB)' + table = "metadata" + tabledef = "(key TEXT UNIQUE NOT NULL, value BLOB, value2 BLOB)" @classmethod def index(cls, cur): - cur.execute('CREATE INDEX IF NOT EXISTS idx_metadata_name ' - 'ON metadata (key)') + cur.execute("CREATE INDEX IF NOT EXISTS idx_metadata_name " "ON metadata (key)") @classmethod def copy(cls, conn, **where): - pass \ No newline at end of file + pass diff --git a/gtfspy/import_loaders/route_loader.py b/gtfspy/import_loaders/route_loader.py index 6716f4e..58cc140 100644 --- a/gtfspy/import_loaders/route_loader.py +++ b/gtfspy/import_loaders/route_loader.py @@ -2,43 +2,53 @@ class RouteLoader(TableLoader): - fname = 'routes.txt' - table = 'routes' - tabledef = '(route_I INTEGER PRIMARY KEY, ' \ - 'route_id TEXT UNIQUE NOT NULL, ' \ - 'agency_I INT, ' \ - 'name TEXT, ' \ - 'long_name TEXT, ' \ - 'desc TEXT, ' \ - 'type INT, ' \ - 'url TEXT, ' \ - 'color TEXT, ' \ - 'text_color TEXT' \ - ')' - extra_keys = ['agency_I', ] - extra_values = ['(SELECT agency_I FROM agencies WHERE agency_id=:_agency_id )', - ] + fname = "routes.txt" + table = "routes" + tabledef = ( + "(route_I INTEGER PRIMARY KEY, " + "route_id TEXT UNIQUE NOT NULL, " + "agency_I INT, " + "name TEXT, " + "long_name TEXT, " + "desc TEXT, " + "type INT, " + "url TEXT, " + "color TEXT, " + "text_color TEXT" + ")" + ) + extra_keys = [ + "agency_I", + ] + extra_values = [ + "(SELECT agency_I FROM agencies WHERE agency_id=:_agency_id )", + ] # route_id,agency_id,route_short_name,route_long_name,route_desc,route_type,route_url # 1001,HSL,1,Kauppatori - Kapyla,0,http://aikataulut.hsl.fi/linjat/fi/h1_1a.html def gen_rows(self, readers, prefixes): from gtfspy import extended_route_types + for reader, prefix in zip(readers, prefixes): for row in reader: - #print (row) + # print (row) yield dict( - route_id = prefix + decode_six(row['route_id']), - _agency_id = prefix + decode_six(row['agency_id']) if 'agency_id' in row else None, - name = decode_six(row['route_short_name']), - long_name = decode_six(row['route_long_name']), - desc = decode_six(row['route_desc']) if 'route_desc' in row else None, - type = extended_route_types.ROUTE_TYPE_CONVERSION[int(row['route_type'])], - url = decode_six(row['route_url']) if 'route_url' in row else None, - color = decode_six(row['route_color']) if 'route_color' in row else None, - text_color = decode_six(row['route_text_color']) if 'route_text_color' in row else None, + route_id=prefix + decode_six(row["route_id"]), + _agency_id=prefix + decode_six(row["agency_id"]) + if "agency_id" in row + else None, + name=decode_six(row["route_short_name"]), + long_name=decode_six(row["route_long_name"]), + desc=decode_six(row["route_desc"]) if "route_desc" in row else None, + type=extended_route_types.ROUTE_TYPE_CONVERSION[int(row["route_type"])], + url=decode_six(row["route_url"]) if "route_url" in row else None, + color=decode_six(row["route_color"]) if "route_color" in row else None, + text_color=decode_six(row["route_text_color"]) + if "route_text_color" in row + else None, ) @classmethod def index(cls, cur): # cur.execute('CREATE INDEX IF NOT EXISTS idx_rid ON route (route_id)') - cur.execute('CREATE INDEX IF NOT EXISTS idx_route_name ON routes (name)') \ No newline at end of file + cur.execute("CREATE INDEX IF NOT EXISTS idx_route_name ON routes (name)") diff --git a/gtfspy/import_loaders/shape_loader.py b/gtfspy/import_loaders/shape_loader.py index aa285f5..c928786 100644 --- a/gtfspy/import_loaders/shape_loader.py +++ b/gtfspy/import_loaders/shape_loader.py @@ -2,53 +2,52 @@ class ShapeLoader(TableLoader): - fname = 'shapes.txt' - table = 'shapes' - tabledef = '(shape_id TEXT, lat REAL, lon REAL, seq INT, d INT)' + fname = "shapes.txt" + table = "shapes" + tabledef = "(shape_id TEXT, lat REAL, lon REAL, seq INT, d INT)" # shape_id,shape_pt_lat,shape_pt_lon,shape_pt_sequence # 1001_20140811_1,60.167430,24.951684,1 def gen_rows(self, readers, prefixes): for reader, prefix in zip(readers, prefixes): for row in reader: - #print row + # print row yield dict( - shape_id = prefix + decode_six(row['shape_id']), - lat = float(row['shape_pt_lat']), - lon = float(row['shape_pt_lon']), - seq = int(row['shape_pt_sequence']) + shape_id=prefix + decode_six(row["shape_id"]), + lat=float(row["shape_pt_lat"]), + lon=float(row["shape_pt_lon"]), + seq=int(row["shape_pt_sequence"]), ) @classmethod def index(cls, cur): # cur.execute('CREATE INDEX IF NOT EXISTS idx_shapes_shid ON shapes (shape_id)') # cur.execute('CREATE INDEX IF NOT EXISTS idx_shapes_id_seq ON shapes (shape_I, seq)') - cur.execute('CREATE INDEX IF NOT EXISTS idx_shapes_id_seq ON shapes (shape_id, seq)') + cur.execute("CREATE INDEX IF NOT EXISTS idx_shapes_id_seq ON shapes (shape_id, seq)") @classmethod def post_import(cls, cur): from gtfspy import shapes - cur.execute('SELECT DISTINCT shape_id FROM shapes') + + cur.execute("SELECT DISTINCT shape_id FROM shapes") shape_ids = tuple(x[0] for x in cur) # print "Renumbering sequences to start from 0 and Calculating shape cumulative distances" for shape_id in shape_ids: - rows = cur.execute("SELECT shape_id, seq " - "FROM shapes " - "WHERE shape_id=? " - "ORDER BY seq", (shape_id,) + rows = cur.execute( + "SELECT shape_id, seq " "FROM shapes " "WHERE shape_id=? " "ORDER BY seq", + (shape_id,), ).fetchall() - cur.executemany("UPDATE shapes SET seq=? " - "WHERE shape_id=? AND seq=?", - ( (i, shape_id, seq) - for i, (shape_id, seq) in enumerate(rows)) - ) + cur.executemany( + "UPDATE shapes SET seq=? " "WHERE shape_id=? AND seq=?", + ((i, shape_id, seq) for i, (shape_id, seq) in enumerate(rows)), + ) for shape_id in shape_ids: shape_points = shapes.get_shape_points(cur, shape_id) shapes.gen_cumulative_distances(shape_points) - cur.executemany('UPDATE shapes SET d=? ' - 'WHERE shape_id=? AND seq=? ', - ((pt['d'], shape_id, pt['seq']) - for pt in shape_points)) \ No newline at end of file + cur.executemany( + "UPDATE shapes SET d=? " "WHERE shape_id=? AND seq=? ", + ((pt["d"], shape_id, pt["seq"]) for pt in shape_points), + ) diff --git a/gtfspy/import_loaders/stop_distances_loader.py b/gtfspy/import_loaders/stop_distances_loader.py index e214352..072f95d 100644 --- a/gtfspy/import_loaders/stop_distances_loader.py +++ b/gtfspy/import_loaders/stop_distances_loader.py @@ -12,10 +12,11 @@ class StopDistancesLoader(TableLoader): d_walk: distance, routed somehow min_transfer_time: from transfers.txt, in seconds. """ + # This loader is special. calc_transfers creates the table there, # too. We put a tabledef here so that copy() will work. fname = None - table = 'stop_distances' + table = "stop_distances" tabledef = calc_transfers.create_stmt threshold = 1000 @@ -35,86 +36,91 @@ def post_import(self, cur): # Add min transfer times (transfer_type=2). This just copies # min_transfer_time from `transfers` to `stop_distances`. - stmt = ('SELECT min_transfer_time, from_stop_I, to_stop_I ' - 'FROM transfers ' - 'WHERE transfer_type=2 ' - 'and from_stop_I!=to_stop_I') + stmt = ( + "SELECT min_transfer_time, from_stop_I, to_stop_I " + "FROM transfers " + "WHERE transfer_type=2 " + "and from_stop_I!=to_stop_I" + ) # First we have to run with INSERT OR IGNORE to add in any # rows that are missing. Unfortunately there is no INSERT OR # UPDATE, so we do this in two stages. First we insert any # missing rows (there is a unique constraint on (from_stop_I, # to_stop_I)) and then we update all rows. cur.execute(stmt) - cur2.executemany('INSERT OR IGNORE INTO stop_distances ' - '(min_transfer_time, from_stop_I, to_stop_I) ' - 'VALUES (?,?,?)', - cur) + cur2.executemany( + "INSERT OR IGNORE INTO stop_distances " + "(min_transfer_time, from_stop_I, to_stop_I) " + "VALUES (?,?,?)", + cur, + ) # Now, run again to do UPDATE any pre-existing rows. cur.execute(stmt) - cur2.executemany('UPDATE stop_distances ' - 'SET min_transfer_time=? ' - 'WHERE from_stop_I=? and to_stop_I=?', - cur) + cur2.executemany( + "UPDATE stop_distances " + "SET min_transfer_time=? " + "WHERE from_stop_I=? and to_stop_I=?", + cur, + ) conn.commit() # Add timed transfers (transfer_type=1). This is added with # timed_transfer=1 and min_transfer_time=0. Again, first we # add missing rows, and then we update the relevant rows. - stmt = ('SELECT from_stop_I, to_stop_I ' - 'FROM transfers ' - 'WHERE transfer_type=1 ' - 'and from_stop_I!=to_stop_I') + stmt = ( + "SELECT from_stop_I, to_stop_I " + "FROM transfers " + "WHERE transfer_type=1 " + "and from_stop_I!=to_stop_I" + ) cur.execute(stmt) - cur2.executemany('INSERT OR IGNORE INTO stop_distances ' - '(from_stop_I, to_stop_I) ' - 'VALUES (?,?)', - cur) + cur2.executemany( + "INSERT OR IGNORE INTO stop_distances " "(from_stop_I, to_stop_I) " "VALUES (?,?)", cur + ) cur.execute(stmt) - cur2.executemany('UPDATE stop_distances ' - 'SET timed_transfer=1, ' - ' min_transfer_time=0 ' - 'WHERE from_stop_I=? and to_stop_I=?', - cur) + cur2.executemany( + "UPDATE stop_distances " + "SET timed_transfer=1, " + " min_transfer_time=0 " + "WHERE from_stop_I=? and to_stop_I=?", + cur, + ) conn.commit() # Excluded transfers. Delete any transfer point with # transfer_type=3. cur = conn.cursor() cur2 = conn.cursor() - cur.execute('SELECT from_stop_I, to_stop_I ' - 'FROM transfers ' - 'WHERE transfer_type=3') - cur2.executemany('DELETE FROM stop_distances ' - 'WHERE from_stop_I=? and to_stop_I=?', - cur) + cur.execute("SELECT from_stop_I, to_stop_I " "FROM transfers " "WHERE transfer_type=3") + cur2.executemany("DELETE FROM stop_distances " "WHERE from_stop_I=? and to_stop_I=?", cur) conn.commit() # Calculate any `d`s missing because of inserted rows in the # previous two steps. - cur.execute('UPDATE stop_distances ' - 'SET d=CAST (find_distance(' - ' (SELECT lat FROM stops WHERE stop_I=from_stop_I), ' - ' (SELECT lon FROM stops WHERE stop_I=from_stop_I), ' - ' (SELECT lat FROM stops WHERE stop_I=to_stop_I), ' - ' (SELECT lon FROM stops WHERE stop_I=to_stop_I) ) ' - ' AS INT)' - 'WHERE d ISNULL' - ) + cur.execute( + "UPDATE stop_distances " + "SET d=CAST (find_distance(" + " (SELECT lat FROM stops WHERE stop_I=from_stop_I), " + " (SELECT lon FROM stops WHERE stop_I=from_stop_I), " + " (SELECT lat FROM stops WHERE stop_I=to_stop_I), " + " (SELECT lon FROM stops WHERE stop_I=to_stop_I) ) " + " AS INT)" + "WHERE d ISNULL" + ) conn.commit() def export_stop_distances(self, conn, f_out): cur = conn.cursor() - cur.execute('SELECT ' - 'from_stop_I, to_stop_I, ' - 'S1.lat, S1.lon, S2.lat, S2.lon, ' - 'd, ' - 'min_transfer_time ' - 'FROM stop_distances ' - 'LEFT JOIN stops S1 ON (from_stop_I=S1.stop_I)' - 'LEFT JOIN stops S2 ON (to_stop_I =S2.stop_I)' - ) - f_out.write('#from_stop_I,to_stop_I,' - 'lat1,lon1,lat2,lon2,' - 'd,min_transfer_time\n') + cur.execute( + "SELECT " + "from_stop_I, to_stop_I, " + "S1.lat, S1.lon, S2.lat, S2.lon, " + "d, " + "min_transfer_time " + "FROM stop_distances " + "LEFT JOIN stops S1 ON (from_stop_I=S1.stop_I)" + "LEFT JOIN stops S2 ON (to_stop_I =S2.stop_I)" + ) + f_out.write("#from_stop_I,to_stop_I," "lat1,lon1,lat2,lon2," "d,min_transfer_time\n") for row in cur: - f_out.write(','.join(str(x) for x in row) + '\n') + f_out.write(",".join(str(x) for x in row) + "\n") diff --git a/gtfspy/import_loaders/stop_loader.py b/gtfspy/import_loaders/stop_loader.py index 24c0364..067841c 100644 --- a/gtfspy/import_loaders/stop_loader.py +++ b/gtfspy/import_loaders/stop_loader.py @@ -5,9 +5,9 @@ class StopLoader(TableLoader): # This class is documented to explain what it does, others are not. # Metadata needed to create table. GTFS filename, table name, and # the CREATE TABLE syntax (last part only). - fname = 'stops.txt' - table = 'stops' - tabledef = '''(stop_I INTEGER PRIMARY KEY, stop_id TEXT UNIQUE NOT NULL, code TEXT, name TEXT, desc TEXT, lat REAL, lon REAL, parent_I INT, location_type INT, wheelchair_boarding BOOL, self_or_parent_I INT)''' + fname = "stops.txt" + table = "stops" + tabledef = """(stop_I INTEGER PRIMARY KEY, stop_id TEXT UNIQUE NOT NULL, code TEXT, name TEXT, desc TEXT, lat REAL, lon REAL, parent_I INT, location_type INT, wheelchair_boarding BOOL, self_or_parent_I INT)""" def gen_rows(self, readers, prefixes): for reader, prefix in zip(readers, prefixes): @@ -16,33 +16,38 @@ def gen_rows(self, readers, prefixes): # dictionary, which is yielded. There can be different # transformations here, as needed. yield dict( - stop_id = prefix + decode_six(row['stop_id']), - code = decode_six(row['stop_code']) if 'stop_code' in row else None, - name = decode_six(row['stop_name']), - desc = decode_six(row['stop_desc']) if 'stop_desc' in row else None, - lat = float(row['stop_lat']), - lon = float(row['stop_lon']), - _parent_id = prefix + decode_six(row['parent_station']) if row.get('parent_station','') else None, - location_type = int(row['location_type']) if row.get('location_type') else None, - wheelchair_boarding = int(row['wheelchair_boarding']) if row.get('wheelchair_boarding','') else None, + stop_id=prefix + decode_six(row["stop_id"]), + code=decode_six(row["stop_code"]) if "stop_code" in row else None, + name=decode_six(row["stop_name"]), + desc=decode_six(row["stop_desc"]) if "stop_desc" in row else None, + lat=float(row["stop_lat"]), + lon=float(row["stop_lon"]), + _parent_id=prefix + decode_six(row["parent_station"]) + if row.get("parent_station", "") + else None, + location_type=int(row["location_type"]) if row.get("location_type") else None, + wheelchair_boarding=int(row["wheelchair_boarding"]) + if row.get("wheelchair_boarding", "") + else None, ) def post_import(self, cur): # if parent_id, set also parent_I: # :_parent_id stands for a named parameter _parent_id # inputted through a dictionary in cur.executemany - stmt = ('UPDATE %s SET parent_I=CASE WHEN (:_parent_id IS NOT "") THEN ' - '(SELECT stop_I FROM %s WHERE stop_id=:_parent_id) END ' - 'WHERE stop_id=:stop_id') % (self.table, self.table) + stmt = ( + 'UPDATE %s SET parent_I=CASE WHEN (:_parent_id IS NOT "") THEN ' + "(SELECT stop_I FROM %s WHERE stop_id=:_parent_id) END " + "WHERE stop_id=:stop_id" + ) % (self.table, self.table) if self.exists(): cur.executemany(stmt, self.gen_rows0()) - stmt = 'UPDATE %s ' \ - 'SET self_or_parent_I=coalesce(parent_I, stop_I)' % self.table + stmt = "UPDATE %s " "SET self_or_parent_I=coalesce(parent_I, stop_I)" % self.table cur.execute(stmt) def index(self, cur): # Make indexes/ views as needed. - #cur.execute('CREATE INDEX IF NOT EXISTS idx_stop_sid ON stop (stop_id)') - #cur.execute('CREATE INDEX IF NOT EXISTS idx_stops_pid_sid ON stops (parent_id, stop_I)') - #conn.commit() - pass \ No newline at end of file + # cur.execute('CREATE INDEX IF NOT EXISTS idx_stop_sid ON stop (stop_id)') + # cur.execute('CREATE INDEX IF NOT EXISTS idx_stops_pid_sid ON stops (parent_id, stop_I)') + # conn.commit() + pass diff --git a/gtfspy/import_loaders/stop_times_loader.py b/gtfspy/import_loaders/stop_times_loader.py index 59e1a52..40c8a4b 100644 --- a/gtfspy/import_loaders/stop_times_loader.py +++ b/gtfspy/import_loaders/stop_times_loader.py @@ -2,39 +2,53 @@ class StopTimesLoader(TableLoader): - fname = 'stop_times.txt' - table = 'stop_times' - tabledef = ('(stop_I INT, trip_I INT, arr_time TEXT, dep_time TEXT, ' - 'seq INT, arr_time_hour INT, shape_break INT, ' - 'arr_time_ds INT, dep_time_ds INT)') - extra_keys = ['stop_I', - 'trip_I', - 'arr_time_ds', - 'dep_time_ds', - ] - extra_values = ['(SELECT stop_I FROM stops WHERE stop_id=:_stop_id )', - '(SELECT trip_I FROM trips WHERE trip_id=:_trip_id )', - '(substr(:arr_time,-8,2)*3600 + substr(:arr_time,-5,2)*60 + substr(:arr_time,-2))', - '(substr(:dep_time,-8,2)*3600 + substr(:dep_time,-5,2)*60 + substr(:dep_time,-2))', - ] + fname = "stop_times.txt" + table = "stop_times" + tabledef = ( + "(stop_I INT, trip_I INT, arr_time TEXT, dep_time TEXT, " + "seq INT, arr_time_hour INT, shape_break INT, " + "arr_time_ds INT, dep_time_ds INT)" + ) + extra_keys = [ + "stop_I", + "trip_I", + "arr_time_ds", + "dep_time_ds", + ] + extra_values = [ + "(SELECT stop_I FROM stops WHERE stop_id=:_stop_id )", + "(SELECT trip_I FROM trips WHERE trip_id=:_trip_id )", + "(substr(:arr_time,-8,2)*3600 + substr(:arr_time,-5,2)*60 + substr(:arr_time,-2))", + "(substr(:dep_time,-8,2)*3600 + substr(:dep_time,-5,2)*60 + substr(:dep_time,-2))", + ] # trip_id,arrival_time,departure_time,stop_id,stop_sequence,stop_headsign,pickup_type,drop_off_type,shape_dist_traveled # 1001_20150424_Ke_1_0953,09:53:00,09:53:00,1030423,1,,0,1,0.0000 def gen_rows(self, readers, prefixes): for reader, prefix in zip(readers, prefixes): for row in reader: - #print row - assert row['arrival_time'] != "", "Some stop_times entries is missing arrival time information." - assert row['departure_time'] != "", "Some stop_times entries is missing departure time information." - assert row['stop_sequence'] != "", "Some stop_times entries is missing seq information." - assert row['stop_id'] != "", "Some stop_times entries is missing stop_id information." - assert row['trip_id'] != "", "Some stop_times entries is missing trip_id information." + # print row + assert ( + row["arrival_time"] != "" + ), "Some stop_times entries is missing arrival time information." + assert ( + row["departure_time"] != "" + ), "Some stop_times entries is missing departure time information." + assert ( + row["stop_sequence"] != "" + ), "Some stop_times entries is missing seq information." + assert ( + row["stop_id"] != "" + ), "Some stop_times entries is missing stop_id information." + assert ( + row["trip_id"] != "" + ), "Some stop_times entries is missing trip_id information." yield dict( - _stop_id = prefix + decode_six(row['stop_id']), - _trip_id = prefix + decode_six(row['trip_id']), - arr_time = row['arrival_time'], - dep_time = row['departure_time'], - seq = int(row['stop_sequence']), + _stop_id=prefix + decode_six(row["stop_id"]), + _trip_id=prefix + decode_six(row["trip_id"]), + arr_time=row["arrival_time"], + dep_time=row["departure_time"], + seq=int(row["stop_sequence"]), ) def post_import(self, cur): @@ -42,27 +56,28 @@ def post_import(self, cur): # integer of the arrival time hour. Conversion to integer is # done in the sqlite engine, since the column affinity is # declared to be INT. - cur.execute('UPDATE stop_times SET arr_time_hour = substr(arr_time, -8, 2)') + cur.execute("UPDATE stop_times SET arr_time_hour = substr(arr_time, -8, 2)") calculate_trip_shape_breakpoints(self._conn) # Resequence seq value to increments of 1 starting from 1 resequence_stop_times_seq_values(self._conn) - @classmethod def index(cls, cur): - cur.execute('CREATE INDEX IF NOT EXISTS idx_stop_times_tid_seq ON stop_times (trip_I, seq)') + cur.execute("CREATE INDEX IF NOT EXISTS idx_stop_times_tid_seq ON stop_times (trip_I, seq)") # Do *not* use this index, use the one below - #cur.execute('CREATE INDEX idx_stop_times_tid_ath ON stop_times (trip_id, arr_time_hour)') + # cur.execute('CREATE INDEX idx_stop_times_tid_ath ON stop_times (trip_id, arr_time_hour)') # This is used for the stop frequencies analysis. - #cur.execute('CREATE INDEX idx_stop_times_tid_ath_sid ON stop_times (trip_I, arr_time_hour, stop_id)') - # ^-- much slower than the next index. - cur.execute('CREATE INDEX idx_stop_times_ath_tid_sid ON stop_times (arr_time_hour, trip_I, stop_I)') + # cur.execute('CREATE INDEX idx_stop_times_tid_ath_sid ON stop_times (trip_I, arr_time_hour, stop_id)') + # ^-- much slower than the next index. + cur.execute( + "CREATE INDEX idx_stop_times_ath_tid_sid ON stop_times (arr_time_hour, trip_I, stop_I)" + ) # This has now been moved to DayTripsMaterializer, but is left # here in case we someday want to make DayTripsMaterializer # optional. - #def make_views(self, conn): + # def make_views(self, conn): # conn.execute('DROP VIEW IF EXISTS main.day_stop_times') # conn.execute('CREATE VIEW day_stop_times AS ' # 'SELECT stop_times.*, trips.*, days.*, ' @@ -78,8 +93,10 @@ def index(cls, cur): def resequence_stop_times_seq_values(conn): cursor = conn.cursor() - rows = cursor.execute('SELECT ROWID, trip_I, seq FROM stop_times ORDER BY trip_I, seq').fetchall() - old_trip_I = '' + rows = cursor.execute( + "SELECT ROWID, trip_I, seq FROM stop_times ORDER BY trip_I, seq" + ).fetchall() + old_trip_I = "" correct_seq = 1 for row in rows: rowid = row[0] @@ -89,7 +106,7 @@ def resequence_stop_times_seq_values(conn): if old_trip_I != trip_I: correct_seq = 1 if seq != correct_seq: - cursor.execute('UPDATE stop_times SET seq = ? WHERE ROWID = ?', (correct_seq, rowid)) + cursor.execute("UPDATE stop_times SET seq = ? WHERE ROWID = ?", (correct_seq, rowid)) old_trip_I = trip_I correct_seq += 1 @@ -108,41 +125,44 @@ def calculate_trip_shape_breakpoints(conn): count_bad_shape_fit = 0 count_no_shape_fit = 0 - trip_Is = [x[0] for x in - cur.execute('SELECT DISTINCT trip_I FROM stop_times').fetchall()] + trip_Is = [x[0] for x in cur.execute("SELECT DISTINCT trip_I FROM stop_times").fetchall()] for trip_I in trip_Is: # Get the shape points - row = cur.execute('''SELECT shape_id - FROM trips WHERE trip_I=?''', (trip_I,)).fetchone() + row = cur.execute( + """SELECT shape_id + FROM trips WHERE trip_I=?""", + (trip_I,), + ).fetchone() if row is None: continue shape_id = row[0] - if shape_id is None or shape_id == '': + if shape_id is None or shape_id == "": continue # Get the stop points - cur.execute('''SELECT seq, lat, lon, stop_id + cur.execute( + """SELECT seq, lat, lon, stop_id FROM stop_times LEFT JOIN stops USING (stop_I) WHERE trip_I=? - ORDER BY seq''', - (trip_I,)) - #print '%20s, %s'%(run_code, datetime.fromtimestamp(run_sch_starttime)) - stop_points = [dict(seq=row[0], - lat=row[1], - lon=row[2], - stop_I=row[3]) - for row in cur if row[1] and row[2]] + ORDER BY seq""", + (trip_I,), + ) + # print '%20s, %s'%(run_code, datetime.fromtimestamp(run_sch_starttime)) + stop_points = [ + dict(seq=row[0], lat=row[1], lon=row[2], stop_I=row[3]) + for row in cur + if row[1] and row[2] + ] # Calculate a cache key for this sequence. # If both shape_id, and all stop_Is are same, then we can re-use existing breakpoints: - cache_key = (shape_id, tuple(x['stop_I'] for x in stop_points)) + cache_key = (shape_id, tuple(x["stop_I"] for x in stop_points)) if cache_key in breakpoints_cache: breakpoints = breakpoints_cache[cache_key] else: # Must re-calculate breakpoints: shape_points = shapes.get_shape_points(cur, shape_id) - breakpoints, badness \ - = shapes.find_segments(stop_points, shape_points) + breakpoints, badness = shapes.find_segments(stop_points, shape_points) if breakpoints != sorted(breakpoints): # route_name, route_id, route_I, trip_id, trip_I = \ # cur.execute('''SELECT name, route_id, route_I, trip_id, trip_I @@ -157,7 +177,7 @@ def calculate_trip_shape_breakpoints(conn): breakpoints_cache[cache_key] = breakpoints if badness > 30 * len(breakpoints): - #print "bad shape fit: %s (%s, %s, %s)" % (badness, trip_I, shape_id, len(breakpoints)) + # print "bad shape fit: %s (%s, %s, %s)" % (badness, trip_I, shape_id, len(breakpoints)) count_bad_shape_fit += 1 if breakpoints is None: @@ -165,20 +185,23 @@ def calculate_trip_shape_breakpoints(conn): if len(breakpoints) == 0: # No valid route could be identified. - #print "Ignoring: No shape identified for trip_I=%s, shape_id=%s" % (trip_I, shape_id) + # print "Ignoring: No shape identified for trip_I=%s, shape_id=%s" % (trip_I, shape_id) count_no_shape_fit += 1 continue # breakpoints is the corresponding points for each stop assert len(breakpoints) == len(stop_points) - cur.executemany('UPDATE stop_times SET shape_break=? ' - 'WHERE trip_I=? AND seq=? ', - ((int(bkpt), int(trip_I), int(stpt['seq'])) - for bkpt, stpt in zip(breakpoints, stop_points))) + cur.executemany( + "UPDATE stop_times SET shape_break=? " "WHERE trip_I=? AND seq=? ", + ( + (int(bkpt), int(trip_I), int(stpt["seq"])) + for bkpt, stpt in zip(breakpoints, stop_points) + ), + ) if count_bad_shape_fit > 0: print(" Shape trip breakpoints: %s bad fits" % count_bad_shape_fit) if count_bad_shape_ordering > 0: print(" Shape trip breakpoints: %s bad shape orderings" % count_bad_shape_ordering) if count_no_shape_fit > 0: print(" Shape trip breakpoints: %s no shape fits" % count_no_shape_fit) - conn.commit() \ No newline at end of file + conn.commit() diff --git a/gtfspy/import_loaders/table_loader.py b/gtfspy/import_loaders/table_loader.py index b09ef1e..875e8ef 100644 --- a/gtfspy/import_loaders/table_loader.py +++ b/gtfspy/import_loaders/table_loader.py @@ -20,7 +20,8 @@ class TableLoader(object): This class is just instantiated, and it does its stuff, and then it is destroyed. """ - mode = 'all' # None or 'import' or 'index'. "None" does everything. + + mode = "all" # None or 'import' or 'index'. "None" does everything. # The following properties need to be defined in a subclass. Examples here. # fname = 'routes.txt' @@ -35,8 +36,8 @@ class TableLoader(object): # Finally, a subclass needs to define these methods: # def gen_rows(self, reader): # def index(self): - extra_keys = [] - extra_values = [] + extra_keys = [] # type: ignore + extra_values = [] # type: ignore is_zipfile = False table = "" # e.g. stops for StopLoader @@ -88,19 +89,16 @@ def __init__(self, gtfssource=None, print_progress=True): if os.path.isdir(source): self.gtfs_sources.append(source) else: - z = zipfile.ZipFile(source, mode='r') + z = zipfile.ZipFile(source, mode="r") zip_commonprefix = os.path.commonprefix(z.namelist()) - zip_source_datum = { - "zipfile": source, - "zip_commonprefix": zip_commonprefix - } + zip_source_datum = {"zipfile": source, "zip_commonprefix": zip_commonprefix} self.gtfs_sources.append(zip_source_datum) # Methods that should be implemented by inheriting classes # when necessary. # - #def post_import(self, cur): - #def index(self, cur): + # def post_import(self, cur): + # def index(self, cur): # Methods for these classes: @@ -122,13 +120,13 @@ def exists_by_source(self): # Handle zipfiles specially if "zipfile" in source: try: - Z = zipfile.ZipFile(source['zipfile'], mode='r') - Z.getinfo(os.path.join(source['zip_commonprefix'], self.fname)) + Z = zipfile.ZipFile(source["zipfile"], mode="r") + Z.getinfo(os.path.join(source["zip_commonprefix"], self.fname)) exists_list.append(True) continue # File does not exist in the zip archive except KeyError: - print(self.fname, ' missing in ', source) + print(self.fname, " missing in ", source) exists_list.append(False) continue # Normal filename @@ -141,7 +139,13 @@ def exists_by_source(self): return exists_list def assert_exists_if_required(self): - REQUIRED_FILES_GTFS = ["agency.txt", "stops.txt", "routes.txt", "trips.txt", "stop_times.txt"] + REQUIRED_FILES_GTFS = [ + "agency.txt", + "stops.txt", + "routes.txt", + "trips.txt", + "stop_times.txt", + ] if self.fname in REQUIRED_FILES_GTFS: for gtfs_source, exists in zip(self.gtfs_sources, self.exists_by_source()): if not exists: @@ -189,17 +193,17 @@ def _iter_file(file_obj): f = data_obj elif "zipfile" in source: try: - Z = zipfile.ZipFile(source['zipfile'], mode='r') + Z = zipfile.ZipFile(source["zipfile"], mode="r") # print(Z.namelist()) - f = util.zip_open(Z, os.path.join(source['zip_commonprefix'], self.fname)) + f = util.zip_open(Z, os.path.join(source["zip_commonprefix"], self.fname)) except KeyError: pass elif isinstance(source, string_types): # now source is a directory try: f = open(os.path.join(source, self.fname)) - # except OSError as e: - except IOError as e: + # except OSError: + except IOError: f = [] fs.append(f) @@ -214,22 +218,26 @@ def _iter_file(file_obj): # `skipinitialspace` option, but let's make sure that we strip # it from both sides. # The following results in a generator, the complicated - csv_reader_stripped = (dict((k, (v.strip() if v is not None else None)) # v is not always a string - for k, v in row.items()) - for row in csv_reader) + csv_reader_stripped = ( + dict( + (k, (v.strip() if v is not None else None)) # v is not always a string + for k, v in row.items() + ) + for row in csv_reader + ) csv_reader_generators.append(csv_reader_stripped) except TypeError as e: if "NoneType" in str(e): print(self.fname + " missing from feed " + str(i)) csv_reader_generators.append(iter(())) - #raise e here will make every multifeed download with incompatible number of tables fail + # raise e here will make every multifeed download with incompatible number of tables fail else: raise e - prefixes = [u"feed_{i}_".format(i=i) for i in range(len(csv_reader_generators))] + prefixes = ["feed_{i}_".format(i=i) for i in range(len(csv_reader_generators))] if len(prefixes) == 1: # no prefix for a single source feed - prefixes = [u""] + prefixes = [""] return csv_reader_generators, prefixes def gen_rows(self, csv_readers, prefixes): @@ -243,15 +251,13 @@ def create_table(self, conn): # Drop table if it already exists, to be recreated. This # could in the future abort if table already exists, and not # recreate it from scratch. - #cur.execute('''DROP TABLE IF EXISTS %s'''%self.table) - #conn.commit() + # cur.execute('''DROP TABLE IF EXISTS %s'''%self.table) + # conn.commit() if self.tabledef is None: return - if not self.tabledef.startswith('CREATE'): + if not self.tabledef.startswith("CREATE"): # "normal" table creation. - cur.execute('CREATE TABLE IF NOT EXISTS %s %s' - % (self.table, self.tabledef) - ) + cur.execute("CREATE TABLE IF NOT EXISTS %s %s" % (self.table, self.tabledef)) else: # When tabledef contains the full CREATE statement (for # virtual tables). @@ -279,59 +285,60 @@ def insert_data(self, conn): # proceed. Since there is nothing to import, just continue the loop print("Not importing %s into %s for %s" % (self.fname, self.table, prefix)) continue - stmt = '''INSERT INTO %s (%s) VALUES (%s)''' % ( + stmt = """INSERT INTO %s (%s) VALUES (%s)""" % ( self.table, - (', '.join([x for x in fields if x[0] != '_'] + self.extra_keys)), - (', '.join([":" + x for x in fields if x[0] != '_'] + self.extra_values)) + (", ".join([x for x in fields if x[0] != "_"] + self.extra_keys)), + (", ".join([":" + x for x in fields if x[0] != "_"] + self.extra_values)), ) # This does the actual insertions. Passed the INSERT # statement and then an iterator over dictionaries. Each # dictionary is inserted. if self.print_progress: - print('Importing %s into %s for %s' % (self.fname, self.table, prefix)) + print("Importing %s into %s for %s" % (self.fname, self.table, prefix)) # the first row was consumed by fetching the fields # (this could be optimized) from itertools import chain + rows = chain([row], self.gen_rows([csv_reader], [prefix])) cur.executemany(stmt, rows) conn.commit() # This was used for debugging the missing service_I: # if self.__class__.__name__ == 'TripLoader': # and False: - # for i in self.gen_rows([new_csv_readers[i]], [prefix]): - # print(stmt) - # rows = cur.execute('SELECT agency_id, trips.service_id FROM agencies, routes, trips + # for i in self.gen_rows([new_csv_readers[i]], [prefix]): + # print(stmt) + # rows = cur.execute('SELECT agency_id, trips.service_id FROM agencies, routes, trips # LEFT JOIN calendar ON(calendar.service_id=trips.service_id) # WHERE trips.route_I = routes.route_I and routes.agency_I = agencies.agency_I and trips.service_I is NULL # GROUP BY trips.service_id, agency_id').fetchall() - # rows = cur.execute('SELECT distinct trips.service_id FROM trips + # rows = cur.execute('SELECT distinct trips.service_id FROM trips # LEFT JOIN calendar ON(calendar.service_id=trips.service_id) WHERE trips.service_I is NULL').fetchall() - # print('trips, etc', [description[0] for description in cur.description]) - # for i, row in enumerate(rows): - # print(row) - #if i == 100: - #exit(0) + # print('trips, etc', [description[0] for description in cur.description]) + # for i, row in enumerate(rows): + # print(row) + # if i == 100: + # exit(0) - # rows = cur.execute('SELECT distinct service_id FROM calendar').fetchall() - # print('calendar_columns',[description[0] for description in cur.description]) - # for row in rows: - # print(row) + # rows = cur.execute('SELECT distinct service_id FROM calendar').fetchall() + # print('calendar_columns',[description[0] for description in cur.description]) + # for row in rows: + # print(row) def run_post_import(self, conn): if self.print_progress: - print('Post-import %s into %s' % (self.fname, self.table)) + print("Post-import %s into %s" % (self.fname, self.table)) cur = conn.cursor() self.post_import(cur) conn.commit() def create_index(self, conn): - if not hasattr(self, 'index'): + if not hasattr(self, "index"): return cur = conn.cursor() if self.print_progress: - print('Indexing %s' % (self.table,)) + print("Indexing %s" % (self.table,)) self.index(cur) conn.commit() @@ -347,19 +354,24 @@ def import_(self, conn): after all tables are loaded. """ if self.print_progress: - print('Beginning', self.__class__.__name__) + print("Beginning", self.__class__.__name__) # what is this mystical self._conn ? self._conn = conn self.create_table(conn) # This does insertions - if self.mode in ('all', 'import') and self.fname and self.exists() and self.table not in ignore_tables: + if ( + self.mode in ("all", "import") + and self.fname + and self.exists() + and self.table not in ignore_tables + ): self.insert_data(conn) # This makes indexes in the DB. - if self.mode in ('all', 'index') and hasattr(self, 'index'): + if self.mode in ("all", "index") and hasattr(self, "index"): self.create_index(conn) # Any post-processing to be done after the full import. - if self.mode in ('all', 'import') and hasattr(self, 'post_import'): + if self.mode in ("all", "import") and hasattr(self, "post_import"): self.run_post_import(conn) # Commit it all conn.commit() @@ -369,7 +381,7 @@ def make_views(cls, conn): """The make views should be run after all tables imported.""" pass - copy_where = '' + copy_where = "" @classmethod def copy(cls, conn, **where): @@ -387,22 +399,23 @@ def copy(cls, conn, **where): copy_where = cls.copy_where.format(**where) # print(copy_where) else: - copy_where = '' - cur.execute('INSERT INTO %s ' - 'SELECT * FROM source.%s %s' % (cls.table, cls.table, copy_where)) + copy_where = "" + cur.execute( + "INSERT INTO %s " "SELECT * FROM source.%s %s" % (cls.table, cls.table, copy_where) + ) @classmethod def post_import_round2(cls, conn): pass -ignore_tables = set() +ignore_tables = set() # type: ignore def decode_six(string): version = sys.version_info[0] if version == 2: - return string.decode('utf-8') + return string.decode("utf-8") else: - assert(isinstance(string, str)) - return string \ No newline at end of file + assert isinstance(string, str) + return string diff --git a/gtfspy/import_loaders/transfer_loader.py b/gtfspy/import_loaders/transfer_loader.py index 2b31fe0..2adebe1 100644 --- a/gtfspy/import_loaders/transfer_loader.py +++ b/gtfspy/import_loaders/transfer_loader.py @@ -11,33 +11,33 @@ class TransfersLoader(TableLoader): 3: transfers not possible """ + # This loader is special. calc_transfers creates the table there, # too. We put a tabledef here so that copy() will work. - fname = 'transfers.txt' - table = 'transfers' + fname = "transfers.txt" + table = "transfers" # TODO: this is copy-pasted from calc_transfers. - tabledef = ('(from_stop_I INT, ' - 'to_stop_I INT, ' - 'transfer_type INT, ' - 'min_transfer_time INT' - ')') - extra_keys = ['from_stop_I', - 'to_stop_I', - ] - extra_values = ['(SELECT stop_I FROM stops WHERE stop_id=:_from_stop_id)', - '(SELECT stop_I FROM stops WHERE stop_id=:_to_stop_id)', - ] + tabledef = ( + "(from_stop_I INT, " "to_stop_I INT, " "transfer_type INT, " "min_transfer_time INT" ")" + ) + extra_keys = [ + "from_stop_I", + "to_stop_I", + ] + extra_values = [ + "(SELECT stop_I FROM stops WHERE stop_id=:_from_stop_id)", + "(SELECT stop_I FROM stops WHERE stop_id=:_to_stop_id)", + ] def gen_rows(self, readers, prefixes): for reader, prefix in zip(readers, prefixes): for row in reader: - #print row + # print row yield dict( - _from_stop_id = prefix + decode_six(row['from_stop_id']).strip(), - _to_stop_id = prefix + decode_six(row['to_stop_id']).strip(), - transfer_type = int(row['transfer_type']), - min_transfer_time = int(row['min_transfer_time']) - if ('min_transfer_time' in row - and (row.get('min_transfer_time').strip()) ) - else None - ) \ No newline at end of file + _from_stop_id=prefix + decode_six(row["from_stop_id"]).strip(), + _to_stop_id=prefix + decode_six(row["to_stop_id"]).strip(), + transfer_type=int(row["transfer_type"]), + min_transfer_time=int(row["min_transfer_time"]) + if ("min_transfer_time" in row and (row.get("min_transfer_time").strip())) + else None, + ) diff --git a/gtfspy/import_loaders/trip_loader.py b/gtfspy/import_loaders/trip_loader.py index a48ec46..826b69e 100644 --- a/gtfspy/import_loaders/trip_loader.py +++ b/gtfspy/import_loaders/trip_loader.py @@ -2,42 +2,49 @@ class TripLoader(TableLoader): - fname = 'trips.txt' - table = 'trips' + fname = "trips.txt" + table = "trips" # service_I INT NOT NULL - tabledef = ('(trip_I INTEGER PRIMARY KEY, trip_id TEXT UNIQUE NOT NULL, ' - 'route_I INT, service_I INT, direction_id TEXT, shape_id TEXT, ' - 'headsign TEXT, ' - 'start_time_ds INT, end_time_ds INT)') - extra_keys = ['route_I', 'service_I' ] #'shape_I'] - extra_values = ['(SELECT route_I FROM routes WHERE route_id=:_route_id )', - '(SELECT service_I FROM calendar WHERE service_id=:_service_id )', - #'(SELECT shape_I FROM shapes WHERE shape_id=:_shape_id )' - ] + tabledef = ( + "(trip_I INTEGER PRIMARY KEY, trip_id TEXT UNIQUE NOT NULL, " + "route_I INT, service_I INT, direction_id TEXT, shape_id TEXT, " + "headsign TEXT, " + "start_time_ds INT, end_time_ds INT)" + ) + extra_keys = ["route_I", "service_I"] # 'shape_I'] + extra_values = [ + "(SELECT route_I FROM routes WHERE route_id=:_route_id )", + "(SELECT service_I FROM calendar WHERE service_id=:_service_id )", + # '(SELECT shape_I FROM shapes WHERE shape_id=:_shape_id )' + ] # route_id,service_id,trip_id,trip_headsign,direction_id,shape_id,wheelchair_accessible,bikes_allowed # 1001,1001_20150424_20150426_Ke,1001_20150424_Ke_1_0953,"Kapyla",0,1001_20140811_1,1,2 def gen_rows(self, readers, prefixes): - #try: + # try: for reader, prefix in zip(readers, prefixes): for row in reader: - #print row - yield dict( - _route_id = prefix + decode_six(row['route_id']), - _service_id = prefix + decode_six(row['service_id']), - trip_id = prefix + decode_six(row['trip_id']), - direction_id = decode_six(row['direction_id']) if row.get('direction_id','') else None, - shape_id = prefix + decode_six(row['shape_id']) if row.get('shape_id','') else None, - headsign = decode_six(row['trip_headsign']) if 'trip_headsign' in row else None, - ) - #except: - #print(row) + # print row + yield dict( + _route_id=prefix + decode_six(row["route_id"]), + _service_id=prefix + decode_six(row["service_id"]), + trip_id=prefix + decode_six(row["trip_id"]), + direction_id=decode_six(row["direction_id"]) + if row.get("direction_id", "") + else None, + shape_id=prefix + decode_six(row["shape_id"]) + if row.get("shape_id", "") + else None, + headsign=decode_six(row["trip_headsign"]) if "trip_headsign" in row else None, + ) + # except: + # print(row) @classmethod def index(cls, cur): # cur.execute('CREATE INDEX IF NOT EXISTS idx_trips_tid ON trips (trip_id)') - cur.execute('CREATE INDEX IF NOT EXISTS idx_trips_svid ON trips (service_I)') - cur.execute('CREATE INDEX IF NOT EXISTS idx_trips_rid ON trips (route_I)') + cur.execute("CREATE INDEX IF NOT EXISTS idx_trips_svid ON trips (service_I)") + cur.execute("CREATE INDEX IF NOT EXISTS idx_trips_rid ON trips (route_I)") def post_import_round2(self, conn): update_trip_travel_times_ds(conn) @@ -45,8 +52,8 @@ def post_import_round2(self, conn): # This has now been moved to DayTripsMaterializer, but is left # here in case we someday want to make DayTripsMaterializer # optional. - #@classmethod - #def make_views(cls, conn): + # @classmethod + # def make_views(cls, conn): # conn.execute('DROP VIEW IF EXISTS main.day_trips') # conn.execute('CREATE VIEW day_trips AS ' # 'SELECT trips.*, days.*, ' @@ -59,26 +66,29 @@ def post_import_round2(self, conn): def update_trip_travel_times_ds(conn): cur0 = conn.cursor() cur = conn.cursor() - cur0.execute('''SELECT trip_I, min(dep_time), max(arr_time) + cur0.execute( + """SELECT trip_I, min(dep_time), max(arr_time) FROM trips JOIN stop_times USING (trip_I) - GROUP BY trip_I''') + GROUP BY trip_I""" + ) print("updating trips travel times") def iter_rows(cur0): for row in cur0: if row[1]: - st = row[1].split(':') + st = row[1].split(":") start_time_ds = int(st[0]) * 3600 + int(st[1]) * 60 + int(st[2]) else: start_time_ds = None if row[2]: - et = row[2].split(':') + et = row[2].split(":") end_time_ds = int(et[0]) * 3600 + int(et[1]) * 60 + int(et[2]) else: end_time_ds = None yield start_time_ds, end_time_ds, row[0] - cur.executemany('''UPDATE trips SET start_time_ds=?, end_time_ds=? WHERE trip_I=?''', - iter_rows(cur0)) + cur.executemany( + """UPDATE trips SET start_time_ds=?, end_time_ds=? WHERE trip_I=?""", iter_rows(cur0) + ) conn.commit() diff --git a/gtfspy/import_validator.py b/gtfspy/import_validator.py index 59bc844..952a63d 100644 --- a/gtfspy/import_validator.py +++ b/gtfspy/import_validator.py @@ -32,18 +32,18 @@ DANGLER_QUERIES = [ - 'SELECT count(*) FROM stops ' - 'LEFT JOIN stop_times ON(stop_times.stop_I=stops.stop_I) ' - 'LEFT JOIN stops as parents ON(stops.stop_I=parents.parent_I) ' - 'WHERE (stop_times.stop_I IS NULL AND parents.parent_I IS NULL)', - 'SELECT count(*) FROM stop_times LEFT JOIN stops ON(stop_times.stop_I=stops.stop_I) WHERE stops.stop_I IS NULL', - 'SELECT count(*) FROM stop_times LEFT JOIN trips ON(stop_times.trip_I=trips.trip_I) WHERE trips.trip_I IS NULL', - 'SELECT count(*) FROM trips LEFT JOIN stop_times ON(stop_times.trip_I=trips.trip_I) WHERE stop_times.trip_I IS NULL', - 'SELECT count(*) FROM trips LEFT JOIN days ON(days.trip_I=trips.trip_I) WHERE days.trip_I IS NULL', - 'SELECT count(*) FROM trips LEFT JOIN calendar ON(calendar.service_I=trips.service_I) WHERE trips.service_I IS NULL', - 'SELECT count(*) FROM trips LEFT JOIN routes ON(routes.route_I=trips.route_I) WHERE routes.route_I IS NULL', - 'SELECT count(*) FROM days LEFT JOIN trips ON(days.trip_I=trips.trip_I) WHERE trips.trip_I IS NULL', - 'SELECT count(*) FROM routes LEFT JOIN trips ON(routes.route_I=trips.route_I) WHERE trips.route_I IS NULL' + "SELECT count(*) FROM stops " + "LEFT JOIN stop_times ON(stop_times.stop_I=stops.stop_I) " + "LEFT JOIN stops as parents ON(stops.stop_I=parents.parent_I) " + "WHERE (stop_times.stop_I IS NULL AND parents.parent_I IS NULL)", + "SELECT count(*) FROM stop_times LEFT JOIN stops ON(stop_times.stop_I=stops.stop_I) WHERE stops.stop_I IS NULL", + "SELECT count(*) FROM stop_times LEFT JOIN trips ON(stop_times.trip_I=trips.trip_I) WHERE trips.trip_I IS NULL", + "SELECT count(*) FROM trips LEFT JOIN stop_times ON(stop_times.trip_I=trips.trip_I) WHERE stop_times.trip_I IS NULL", + "SELECT count(*) FROM trips LEFT JOIN days ON(days.trip_I=trips.trip_I) WHERE days.trip_I IS NULL", + "SELECT count(*) FROM trips LEFT JOIN calendar ON(calendar.service_I=trips.service_I) WHERE trips.service_I IS NULL", + "SELECT count(*) FROM trips LEFT JOIN routes ON(routes.route_I=trips.route_I) WHERE routes.route_I IS NULL", + "SELECT count(*) FROM days LEFT JOIN trips ON(days.trip_I=trips.trip_I) WHERE trips.trip_I IS NULL", + "SELECT count(*) FROM routes LEFT JOIN trips ON(routes.route_I=trips.route_I) WHERE trips.route_I IS NULL", ] DANGLER_WARNINGS = [ @@ -59,49 +59,65 @@ ] DB_TABLE_NAME_TO_FIELDS_WHERE_NULL_NOT_OK = { - 'agencies': ['agency_I', 'agency_id', "timezone"], - 'stops': ['stop_I', 'stop_id', 'lat', 'lon'], - 'routes': ['route_I', 'route_id', 'type'], - 'trips': ['trip_I', 'trip_id', 'service_I', "route_I"], - 'stop_times': ["trip_I", "stop_I", "arr_time_ds", "dep_time_ds"], - 'calendar': ['service_id', 'service_I', 'm', "t", "w", "th", "f", "s", "su", "start_date", "end_date"], - 'calendar_dates': ['service_I', 'date', 'exception_type'], - 'days': ["date","day_start_ut","trip_I"], - 'shapes': ["shape_id", "lat", "lon", "seq"], - 'stop_distances': ["from_stop_I", "to_stop_I", "d"] + "agencies": ["agency_I", "agency_id", "timezone"], + "stops": ["stop_I", "stop_id", "lat", "lon"], + "routes": ["route_I", "route_id", "type"], + "trips": ["trip_I", "trip_id", "service_I", "route_I"], + "stop_times": ["trip_I", "stop_I", "arr_time_ds", "dep_time_ds"], + "calendar": [ + "service_id", + "service_I", + "m", + "t", + "w", + "th", + "f", + "s", + "su", + "start_date", + "end_date", + ], + "calendar_dates": ["service_I", "date", "exception_type"], + "days": ["date", "day_start_ut", "trip_I"], + "shapes": ["shape_id", "lat", "lon", "seq"], + "stop_distances": ["from_stop_I", "to_stop_I", "d"], } DB_TABLE_NAME_TO_FIELDS_WHERE_NULL_OK_BUT_WARN = { - 'agencies': ['name', "url"], - 'stops': ['name'], - 'routes': ['name', 'long_name'], - 'trips': [], - 'calendar': [], - 'calendar_dates': [], - 'days': [], - 'shapes': [], - 'stop_times': [], - 'stop_distances': ["d_walk"] + "agencies": ["name", "url"], + "stops": ["name"], + "routes": ["name", "long_name"], + "trips": [], + "calendar": [], + "calendar_dates": [], + "days": [], + "shapes": [], + "stop_times": [], + "stop_distances": ["d_walk"], } DB_TABLE_NAMES = list(sorted(DB_TABLE_NAME_TO_FIELDS_WHERE_NULL_NOT_OK.keys())) DB_TABLE_NAME_TO_SOURCE_FILE = { - 'agencies': "agency", - 'routes': "routes", - 'trips': "trips", - 'calendar': "calendar", - 'calendar_dates': "calendar_dates", - 'stop_times': "stop_times", - 'stops': "stops", - "shapes": 'shapes' + "agencies": "agency", + "routes": "routes", + "trips": "trips", + "calendar": "calendar", + "calendar_dates": "calendar_dates", + "stop_times": "stop_times", + "stops": "stops", + "shapes": "shapes", } DB_TABLE_NAME_TO_ROWS_MISSING_WARNING = {} for _db_table_name in DB_TABLE_NAMES: - DB_TABLE_NAME_TO_ROWS_MISSING_WARNING[_db_table_name] = "Rows missing in {table}".format(table=_db_table_name) -DB_TABLE_NAME_TO_ROWS_MISSING_WARNING["calendar"] = "There are extra/missing rows in calendar that cannot be explained " \ - "by dummy entries required by the calendar_dates table." + DB_TABLE_NAME_TO_ROWS_MISSING_WARNING[_db_table_name] = "Rows missing in {table}".format( + table=_db_table_name + ) +DB_TABLE_NAME_TO_ROWS_MISSING_WARNING["calendar"] = ( + "There are extra/missing rows in calendar that cannot be explained " + "by dummy entries required by the calendar_dates table." +) for dictionary in [DB_TABLE_NAME_TO_SOURCE_FILE, DB_TABLE_NAME_TO_ROWS_MISSING_WARNING]: for key in dictionary.keys(): @@ -110,11 +126,10 @@ for key in DB_TABLE_NAME_TO_FIELDS_WHERE_NULL_OK_BUT_WARN.keys(): assert key in DB_TABLE_NAME_TO_FIELDS_WHERE_NULL_NOT_OK -#SOURCE_TABLE_NAMES = ['agency', 'routes', 'trips', 'calendar', 'calendar_dates', 'stop_times', 'stops', 'shapes'] +# SOURCE_TABLE_NAMES = ['agency', 'routes', 'trips', 'calendar', 'calendar_dates', 'stop_times', 'stops', 'shapes'] class ImportValidator(object): - def __init__(self, gtfssource, gtfs, verbose=True): """ Parameters @@ -130,7 +145,9 @@ def __init__(self, gtfssource, gtfs, verbose=True): else: assert isinstance(gtfssource, list) self.gtfs_sources = gtfssource - assert len(self.gtfs_sources) > 0, "There needs to be some source files for validating an import" + assert ( + len(self.gtfs_sources) > 0 + ), "There needs to be some source files for validating an import" if not isinstance(gtfs, GTFS): self.gtfs = GTFS(gtfs) @@ -139,7 +156,7 @@ def __init__(self, gtfssource, gtfs, verbose=True): self.location = self.gtfs.get_location_name() self.warnings_container = WarningsContainer() - self.verbose=verbose + self.verbose = verbose def validate_and_get_warnings(self): self.warnings_container.clear() @@ -163,13 +180,15 @@ def _validate_table_row_counts(self): # Row counts in source files: source_row_count = 0 for gtfs_source in self.gtfs_sources: - frequencies_in_source = source_csv_to_pandas(gtfs_source, 'frequencies.txt') + frequencies_in_source = source_csv_to_pandas(gtfs_source, "frequencies.txt") try: - if table_name_source_file == 'trips' and not frequencies_in_source.empty: + if table_name_source_file == "trips" and not frequencies_in_source.empty: source_row_count += self._frequency_generated_trips_rows(gtfs_source) - elif table_name_source_file == 'stop_times' and not frequencies_in_source.empty: - source_row_count += self._compute_number_of_frequency_generated_stop_times(gtfs_source) + elif table_name_source_file == "stop_times" and not frequencies_in_source.empty: + source_row_count += self._compute_number_of_frequency_generated_stop_times( + gtfs_source + ) else: df = source_csv_to_pandas(gtfs_source, table_name_source_file) @@ -180,51 +199,80 @@ def _validate_table_row_counts(self): else: raise e - if source_row_count == database_row_count and self.verbose: - print("Row counts match for " + table_name_source_file + " between the source and database (" - + str(database_row_count) + ")") + print( + "Row counts match for " + + table_name_source_file + + " between the source and database (" + + str(database_row_count) + + ")" + ) else: difference = database_row_count - source_row_count - ('Row counts do not match for ' + str(table_name_source_file) + ': (source=' + str(source_row_count) + - ', database=' + str(database_row_count) + ")") + ( + "Row counts do not match for " + + str(table_name_source_file) + + ": (source=" + + str(source_row_count) + + ", database=" + + str(database_row_count) + + ")" + ) if table_name_source_file == "calendar" and difference > 0: - query = "SELECT count(*) FROM (SELECT * FROM calendar ORDER BY service_I DESC LIMIT " \ - + str(int(difference)) + \ - ") WHERE start_date=end_date AND m=0 AND t=0 AND w=0 AND th=0 AND f=0 AND s=0 AND su=0" - number_of_entries_added_by_calendar_dates_loader = self.gtfs.execute_custom_query(query).fetchone()[ - 0] - if number_of_entries_added_by_calendar_dates_loader == difference and self.verbose: - print(" But don't worry, the extra entries seem to just dummy entries due to calendar_dates") + query = ( + "SELECT count(*) FROM (SELECT * FROM calendar ORDER BY service_I DESC LIMIT " + + str(int(difference)) + + ") WHERE start_date=end_date AND m=0 AND t=0 AND w=0 AND th=0 AND f=0 AND s=0 AND su=0" + ) + number_of_entries_added_by_calendar_dates_loader = self.gtfs.execute_custom_query( + query + ).fetchone()[ + 0 + ] + if ( + number_of_entries_added_by_calendar_dates_loader == difference + and self.verbose + ): + print( + " But don't worry, the extra entries seem to just dummy entries due to calendar_dates" + ) else: if self.verbose: print(" Reason for this is unknown.") - self.warnings_container.add_warning(row_warning_str, self.location, difference) + self.warnings_container.add_warning( + row_warning_str, self.location, difference + ) else: self.warnings_container.add_warning(row_warning_str, self.location, difference) - def _validate_no_null_values(self): """ Loads the tables from the gtfs object and counts the number of rows that have null values in fields that should not be null. Stores the number of null rows in warnings_container """ for table in DB_TABLE_NAMES: - null_not_ok_warning = "Null values in must-have columns in table {table}".format(table=table) - null_warn_warning = "Null values in good-to-have columns in table {table}".format(table=table) + null_not_ok_warning = "Null values in must-have columns in table {table}".format( + table=table + ) + null_warn_warning = "Null values in good-to-have columns in table {table}".format( + table=table + ) null_not_ok_fields = DB_TABLE_NAME_TO_FIELDS_WHERE_NULL_NOT_OK[table] null_warn_fields = DB_TABLE_NAME_TO_FIELDS_WHERE_NULL_OK_BUT_WARN[table] # CW, TODO: make this validation source by source df = self.gtfs.get_table(table) - for warning, fields in zip([null_not_ok_warning, null_warn_warning], [null_not_ok_fields, null_warn_fields]): + for warning, fields in zip( + [null_not_ok_warning, null_warn_warning], [null_not_ok_fields, null_warn_fields] + ): null_unwanted_df = df[fields] rows_having_null = null_unwanted_df.isnull().any(1) if sum(rows_having_null) > 0: rows_having_unwanted_null = df[rows_having_null.values] - self.warnings_container.add_warning(warning, rows_having_unwanted_null, len(rows_having_unwanted_null)) - + self.warnings_container.add_warning( + warning, rows_having_unwanted_null, len(rows_having_unwanted_null) + ) def _validate_danglers(self): """ @@ -252,13 +300,20 @@ def _frequency_generated_trips_rows(self, gtfs_soure_path, return_df_freq=False) param txt: txt file in question :return: sum of all trips """ - df_freq = source_csv_to_pandas(gtfs_soure_path, 'frequencies') + df_freq = source_csv_to_pandas(gtfs_soure_path, "frequencies") df_trips = source_csv_to_pandas(gtfs_soure_path, "trips") - df_freq['n_trips'] = df_freq.apply(lambda row: len(range(str_time_to_day_seconds(row['start_time']), - str_time_to_day_seconds(row['end_time']), - row['headway_secs'])), axis=1) - df_trips_freq = pd.merge(df_freq, df_trips, how='outer', on='trip_id') - n_freq_generated_trips = int(df_trips_freq['n_trips'].fillna(1).sum(axis=0)) + df_freq["n_trips"] = df_freq.apply( + lambda row: len( + range( + str_time_to_day_seconds(row["start_time"]), + str_time_to_day_seconds(row["end_time"]), + row["headway_secs"], + ) + ), + axis=1, + ) + df_trips_freq = pd.merge(df_freq, df_trips, how="outer", on="trip_id") + n_freq_generated_trips = int(df_trips_freq["n_trips"].fillna(1).sum(axis=0)) if return_df_freq: return df_trips_freq else: @@ -277,5 +332,5 @@ def _compute_number_of_frequency_generated_stop_times(self, gtfs_source_path): """ df_freq = self._frequency_generated_trips_rows(gtfs_source_path, return_df_freq=True) df_stop_times = source_csv_to_pandas(gtfs_source_path, "stop_times") - df_stop_freq = pd.merge(df_freq, df_stop_times, how='outer', on='trip_id') - return int(df_stop_freq['n_trips'].fillna(1).sum(axis=0)) + df_stop_freq = pd.merge(df_freq, df_stop_times, how="outer", on="trip_id") + return int(df_stop_freq["n_trips"].fillna(1).sum(axis=0)) diff --git a/gtfspy/mapviz.py b/gtfspy/mapviz.py index 28ddfa0..1deb9ea 100644 --- a/gtfspy/mapviz.py +++ b/gtfspy/mapviz.py @@ -1,16 +1,21 @@ +import math from urllib.error import URLError +import matplotlib as mpl +import matplotlib.pyplot as plt import numpy import smopy -import matplotlib.pyplot as plt from matplotlib import colors as mcolors -import math -from gtfspy.gtfs import GTFS -from gtfspy.stats import get_spatial_bounds, get_percentile_stop_bounds, get_median_lat_lon_of_stops -from gtfspy.route_types import ROUTE_TYPE_TO_COLOR, ROUTE_TYPE_TO_ZORDER, ROUTE_TYPE_TO_SHORT_DESCRIPTION -import matplotlib as mpl from matplotlib_scalebar.scalebar import ScaleBar + from gtfspy import util +from gtfspy.gtfs import GTFS +from gtfspy.route_types import ( + ROUTE_TYPE_TO_COLOR, + ROUTE_TYPE_TO_ZORDER, + ROUTE_TYPE_TO_SHORT_DESCRIPTION, +) +from gtfspy.stats import get_spatial_bounds, get_median_lat_lon_of_stops """ This module contains functions for plotting (static) visualizations of the public transport networks using matplotlib. @@ -27,7 +32,7 @@ "light_nolabels", "light_only_labels", "dark_nolabels", - "dark_only_labels" + "dark_only_labels", ] @@ -43,8 +48,16 @@ def _get_median_centered_plot_bounds(g): return plot_lon_min, plot_lon_max, plot_lat_min, plot_lat_max -def plot_route_network_from_gtfs(g, ax=None, spatial_bounds=None, map_alpha=0.8, scalebar=True, legend=True, - return_smopy_map=False, map_style=None): +def plot_route_network_from_gtfs( + g, + ax=None, + spatial_bounds=None, + map_alpha=0.8, + scalebar=True, + legend=True, + return_smopy_map=False, + map_style=None, +): """ Parameters ---------- @@ -62,7 +75,7 @@ def plot_route_network_from_gtfs(g, ax=None, spatial_bounds=None, map_alpha=0.8, ax: matplotlib.axes.Axes """ - assert(isinstance(g, GTFS)) + assert isinstance(g, GTFS) route_shapes = g.get_all_route_shapes() if spatial_bounds is None: @@ -71,18 +84,30 @@ def plot_route_network_from_gtfs(g, ax=None, spatial_bounds=None, map_alpha=0.8, bbox = ax.get_window_extent().transformed(ax.figure.dpi_scale_trans.inverted()) width, height = bbox.width, bbox.height spatial_bounds = _expand_spatial_bounds_to_fit_axes(spatial_bounds, width, height) - return plot_as_routes(route_shapes, - ax=ax, - spatial_bounds=spatial_bounds, - map_alpha=map_alpha, - plot_scalebar=scalebar, - legend=legend, - return_smopy_map=return_smopy_map, - map_style=map_style) - - -def plot_as_routes(route_shapes, ax=None, spatial_bounds=None, map_alpha=0.8, plot_scalebar=True, legend=True, - return_smopy_map=False, line_width_attribute=None, line_width_scale=1.0, map_style=None): + return plot_as_routes( + route_shapes, + ax=ax, + spatial_bounds=spatial_bounds, + map_alpha=map_alpha, + plot_scalebar=scalebar, + legend=legend, + return_smopy_map=return_smopy_map, + map_style=map_style, + ) + + +def plot_as_routes( + route_shapes, + ax=None, + spatial_bounds=None, + map_alpha=0.8, + plot_scalebar=True, + legend=True, + return_smopy_map=False, + line_width_attribute=None, + line_width_scale=1.0, + map_style=None, +): """ Parameters ---------- @@ -103,35 +128,45 @@ def plot_as_routes(route_shapes, ax=None, spatial_bounds=None, map_alpha=0.8, pl ------- ax: matplotlib.axes object """ - lon_min = spatial_bounds['lon_min'] - lon_max = spatial_bounds['lon_max'] - lat_min = spatial_bounds['lat_min'] - lat_max = spatial_bounds['lat_max'] + lon_min = spatial_bounds["lon_min"] + lon_max = spatial_bounds["lon_max"] + lat_min = spatial_bounds["lat_min"] + lat_max = spatial_bounds["lat_max"] if ax is None: fig = plt.figure() ax = fig.add_subplot(111) smopy_map = get_smopy_map(lon_min, lon_max, lat_min, lat_max, map_style=map_style) ax = smopy_map.show_mpl(figsize=None, ax=ax, alpha=map_alpha) - bound_pixel_xs, bound_pixel_ys = smopy_map.to_pixels(numpy.array([lat_min, lat_max]), - numpy.array([lon_min, lon_max])) + bound_pixel_xs, bound_pixel_ys = smopy_map.to_pixels( + numpy.array([lat_min, lat_max]), numpy.array([lon_min, lon_max]) + ) route_types_to_lines = {} for shape in route_shapes: - route_type = ROUTE_TYPE_CONVERSION[shape['type']] - lats = numpy.array(shape['lats']) - lons = numpy.array(shape['lons']) + route_type = ROUTE_TYPE_CONVERSION[shape["type"]] + lats = numpy.array(shape["lats"]) + lons = numpy.array(shape["lons"]) if line_width_attribute: line_width = line_width_scale * shape[line_width_attribute] else: line_width = 1 xs, ys = smopy_map.to_pixels(lats, lons) - line, = ax.plot(xs, ys, linewidth=line_width, color=ROUTE_TYPE_TO_COLOR[route_type], zorder=ROUTE_TYPE_TO_ZORDER[route_type]) + (line,) = ax.plot( + xs, + ys, + linewidth=line_width, + color=ROUTE_TYPE_TO_COLOR[route_type], + zorder=ROUTE_TYPE_TO_ZORDER[route_type], + ) route_types_to_lines[route_type] = line if legend: lines = list(route_types_to_lines.values()) - labels = [ROUTE_TYPE_TO_SHORT_DESCRIPTION[route_type] for route_type in route_types_to_lines.keys()] + labels = [ + ROUTE_TYPE_TO_SHORT_DESCRIPTION[route_type] + for route_type in route_types_to_lines.keys() + ] ax.legend(lines, labels, loc="upper left") if plot_scalebar: @@ -148,29 +183,37 @@ def plot_as_routes(route_shapes, ax=None, spatial_bounds=None, map_alpha=0.8, pl return ax -def plot_routes_as_stop_to_stop_network(from_lats, from_lons, to_lats, to_lons, attributes=None, color_attributes=None, - zorders=None, - line_labels=None, - ax=None, - spatial_bounds=None, - alpha=1, - map_alpha=0.8, - scalebar=True, - return_smopy_map=False, - c=None, linewidth=None, - linewidth_multiplier=1, - use_log_scale=False): +def plot_routes_as_stop_to_stop_network( + from_lats, + from_lons, + to_lats, + to_lons, + attributes=None, + color_attributes=None, + zorders=None, + line_labels=None, + ax=None, + spatial_bounds=None, + alpha=1, + map_alpha=0.8, + scalebar=True, + return_smopy_map=False, + c=None, + linewidth=None, + linewidth_multiplier=1, + use_log_scale=False, +): if attributes is None: - attributes = len(list(from_lats))*[None] + attributes = len(list(from_lats)) * [None] if not linewidth: linewidth = 1 if color_attributes is None: - color_attributes = len(list(from_lats))*[None] + color_attributes = len(list(from_lats)) * [None] assert c is not None if zorders is None: - zorders = len(list(from_lats))*[1] + zorders = len(list(from_lats)) * [1] if line_labels is None: - line_labels = len(list(from_lats))*[None] + line_labels = len(list(from_lats)) * [None] if spatial_bounds is None: lon_min = min(list(from_lons) + list(to_lons)) @@ -178,29 +221,23 @@ def plot_routes_as_stop_to_stop_network(from_lats, from_lons, to_lats, to_lons, lat_min = min(list(from_lats) + list(to_lats)) lat_max = max(list(from_lats) + list(to_lats)) else: - lon_min = spatial_bounds['lon_min'] - lon_max = spatial_bounds['lon_max'] - lat_min = spatial_bounds['lat_min'] - lat_max = spatial_bounds['lat_max'] + lon_min = spatial_bounds["lon_min"] + lon_max = spatial_bounds["lon_max"] + lat_min = spatial_bounds["lat_min"] + lat_max = spatial_bounds["lat_max"] if ax is None: fig = plt.figure() ax = fig.add_subplot(111) smopy_map = get_smopy_map(lon_min, lon_max, lat_min, lat_max) ax = smopy_map.show_mpl(figsize=None, ax=ax, alpha=map_alpha) - bound_pixel_xs, bound_pixel_ys = smopy_map.to_pixels(numpy.array([lat_min, lat_max]), - numpy.array([lon_min, lon_max])) - - for from_lat, from_lon, to_lat, to_lon, attribute, color_attribute, zorder, line_label in zip(from_lats, - from_lons, - to_lats, - to_lons, - attributes, - color_attributes, - zorders, - line_labels): - + bound_pixel_xs, bound_pixel_ys = smopy_map.to_pixels( + numpy.array([lat_min, lat_max]), numpy.array([lon_min, lon_max]) + ) + for from_lat, from_lon, to_lat, to_lon, attribute, color_attribute, zorder, line_label in zip( + from_lats, from_lons, to_lats, to_lons, attributes, color_attributes, zorders, line_labels + ): if color_attribute is None: color = c @@ -213,31 +250,50 @@ def plot_routes_as_stop_to_stop_network(from_lats, from_lons, to_lats, to_lons, if use_log_scale: attribute = math.log10(attribute) - xs, ys = smopy_map.to_pixels(numpy.array([from_lat, to_lat]), numpy.array([from_lon, to_lon])) - - ax.plot(xs, ys, color=color, linewidth=attribute*linewidth_multiplier, zorder=zorder, alpha=alpha) + xs, ys = smopy_map.to_pixels( + numpy.array([from_lat, to_lat]), numpy.array([from_lon, to_lon]) + ) + + ax.plot( + xs, + ys, + color=color, + linewidth=attribute * linewidth_multiplier, + zorder=zorder, + alpha=alpha, + ) if line_label: - ax.text(xs.mean(), ys.mean(), line_label, - # verticalalignment='bottom', horizontalalignment='right', - color='green', fontsize=15) + ax.text( + xs.mean(), + ys.mean(), + line_label, + # verticalalignment='bottom', horizontalalignment='right', + color="green", + fontsize=15, + ) legend = True if color_attributes[0] is not None else False import matplotlib.lines as mlines - + if legend: unique_types = set(color_attributes) lines = [] - + for i in unique_types: - line = mlines.Line2D([], [], color=ROUTE_TYPE_TO_COLOR[i], markersize=15, - label=ROUTE_TYPE_TO_SHORT_DESCRIPTION[i]) + line = mlines.Line2D( + [], + [], + color=ROUTE_TYPE_TO_COLOR[i], + markersize=15, + label=ROUTE_TYPE_TO_SHORT_DESCRIPTION[i], + ) lines.append(line) handles = lines - labels = [h.get_label() for h in handles] - + labels = [h.get_label() for h in handles] + ax.legend(handles=handles, labels=labels, loc=4) - + if scalebar: _add_scale_bar(ax, lat_max, lon_min, lon_max, bound_pixel_xs.max() - bound_pixel_xs.min()) @@ -272,40 +328,39 @@ def _expand_spatial_bounds_to_fit_axes(bounds, ax_width, ax_height): spatial_bounds """ b = bounds - height_meters = util.wgs84_distance(b['lat_min'], b['lon_min'], b['lat_max'], b['lon_min']) - width_meters = util.wgs84_distance(b['lat_min'], b['lon_min'], b['lat_min'], b['lon_max']) + height_meters = util.wgs84_distance(b["lat_min"], b["lon_min"], b["lat_max"], b["lon_min"]) + width_meters = util.wgs84_distance(b["lat_min"], b["lon_min"], b["lat_min"], b["lon_max"]) x_per_y_meters = width_meters / height_meters x_per_y_axes = ax_width / ax_height if x_per_y_axes > x_per_y_meters: # x-axis # axis x_axis has slack -> the spatial longitude bounds need to be extended - width_meters_new = (height_meters * x_per_y_axes) - d_lon_new = ((b['lon_max'] - b['lon_min']) / width_meters) * width_meters_new - mean_lon = (b['lon_min'] + b['lon_max'])/2. - lon_min = mean_lon - d_lon_new / 2. - lon_max = mean_lon + d_lon_new / 2. + width_meters_new = height_meters * x_per_y_axes + d_lon_new = ((b["lon_max"] - b["lon_min"]) / width_meters) * width_meters_new + mean_lon = (b["lon_min"] + b["lon_max"]) / 2.0 + lon_min = mean_lon - d_lon_new / 2.0 + lon_max = mean_lon + d_lon_new / 2.0 spatial_bounds = { "lon_min": lon_min, "lon_max": lon_max, - "lat_min": b['lat_min'], - "lat_max": b['lat_max'] + "lat_min": b["lat_min"], + "lat_max": b["lat_max"], } else: # axis y_axis has slack -> the spatial latitude bounds need to be extended - height_meters_new = (width_meters / x_per_y_axes) - d_lat_new = ((b['lat_max'] - b['lat_min']) / height_meters) * height_meters_new - mean_lat = (b['lat_min'] + b['lat_max']) / 2. - lat_min = mean_lat - d_lat_new / 2. - lat_max = mean_lat + d_lat_new / 2. + height_meters_new = width_meters / x_per_y_axes + d_lat_new = ((b["lat_max"] - b["lat_min"]) / height_meters) * height_meters_new + mean_lat = (b["lat_min"] + b["lat_max"]) / 2.0 + lat_min = mean_lat - d_lat_new / 2.0 + lat_max = mean_lat + d_lat_new / 2.0 spatial_bounds = { - "lon_min": b['lon_min'], - "lon_max": b['lon_max'], + "lon_min": b["lon_min"], + "lon_max": b["lon_max"], "lat_min": lat_min, - "lat_max": lat_max + "lat_max": lat_max, } return spatial_bounds - def plot_route_network_thumbnail(g, map_style=None): width = 512 # pixels height = 300 # pixels @@ -321,25 +376,38 @@ def plot_route_network_thumbnail(g, map_style=None): "lon_min": median_lon - dlon, "lon_max": median_lon + dlon, "lat_min": median_lat - dlat, - "lat_max": median_lat + dlat + "lat_max": median_lat + dlat, } - fig = plt.figure(figsize=(width/dpi, height/dpi)) + fig = plt.figure(figsize=(width / dpi, height / dpi)) ax = fig.add_subplot(111) plt.subplots_adjust(bottom=0.0, left=0.0, right=1.0, top=1.0) - return plot_route_network_from_gtfs(g, ax, spatial_bounds, map_alpha=1.0, scalebar=False, legend=False, map_style=map_style) - - -def plot_stops_with_categorical_attributes(lats_list, lons_list, attributes_list, s=0.5, spatial_bounds=None, colorbar=False, ax=None, cmap=None, norm=None, alpha=None): + return plot_route_network_from_gtfs( + g, ax, spatial_bounds, map_alpha=1.0, scalebar=False, legend=False, map_style=map_style + ) + + +def plot_stops_with_categorical_attributes( + lats_list, + lons_list, + attributes_list, + s=0.5, + spatial_bounds=None, + colorbar=False, + ax=None, + cmap=None, + norm=None, + alpha=None, +): if not spatial_bounds: lon_min = min([min(x) for x in lons_list]) lon_max = max([max(x) for x in lons_list]) lat_min = min([min(x) for x in lats_list]) lat_max = max([max(x) for x in lats_list]) else: - lon_min = spatial_bounds['lon_min'] - lon_max = spatial_bounds['lon_max'] - lat_min = spatial_bounds['lat_min'] - lat_max = spatial_bounds['lat_max'] + lon_min = spatial_bounds["lon_min"] + lon_max = spatial_bounds["lon_max"] + lat_min = spatial_bounds["lat_min"] + lat_max = spatial_bounds["lat_max"] smopy_map = get_smopy_map(lon_min, lon_max, lat_min, lat_max) if ax is None: fig = plt.figure() @@ -366,24 +434,37 @@ def plot_stops_with_categorical_attributes(lats_list, lons_list, attributes_list ax = smopy_map.show_mpl(figsize=None, ax=ax, alpha=0.8) axes = [] - for lats, lons, attributes, c in zip(lats_list, lons_list, attributes_list, mcolors.BASE_COLORS): + for lats, lons, attributes, c in zip( + lats_list, lons_list, attributes_list, mcolors.BASE_COLORS + ): x, y = zip(*[smopy_map.to_pixels(lat, lon) for lat, lon in zip(lats, lons)]) - ax = plt.scatter(x, y, s=s, c=c) #, marker=".") + ax = plt.scatter(x, y, s=s, c=c) # , marker=".") axes.append(ax) return axes -def plot_stops_with_attributes(lats, lons, attribute, s=0.5, spatial_bounds=None, colorbar=False, ax=None, cmap=None, norm=None, alpha=None): +def plot_stops_with_attributes( + lats, + lons, + attribute, + s=0.5, + spatial_bounds=None, + colorbar=False, + ax=None, + cmap=None, + norm=None, + alpha=None, +): if not spatial_bounds: lon_min = min(lons) lon_max = max(lons) lat_min = min(lats) lat_max = max(lats) else: - lon_min = spatial_bounds['lon_min'] - lon_max = spatial_bounds['lon_max'] - lat_min = spatial_bounds['lat_min'] - lat_max = spatial_bounds['lat_max'] + lon_min = spatial_bounds["lon_min"] + lon_max = spatial_bounds["lon_max"] + lat_min = spatial_bounds["lat_min"] + lat_max = spatial_bounds["lat_max"] smopy_map = get_smopy_map(lon_min, lon_max, lat_min, lat_max) if ax is None: fig = plt.figure() @@ -415,7 +496,7 @@ def plot_all_stops(g, ax=None, scalebar=False): ax: matplotlib.Axes """ - assert(isinstance(g, GTFS)) + assert isinstance(g, GTFS) lon_min, lon_max, lat_min, lat_max = get_spatial_bounds(g) smopy_map = get_smopy_map(lon_min, lon_max, lat_min, lat_max) if ax is None: @@ -424,8 +505,8 @@ def plot_all_stops(g, ax=None, scalebar=False): ax = smopy_map.show_mpl(figsize=None, ax=ax, alpha=0.8) stops = g.stops() - lats = numpy.array(stops['lat']) - lons = numpy.array(stops['lon']) + lats = numpy.array(stops["lat"]) + lons = numpy.array(stops["lon"]) xs, ys = smopy_map.to_pixels(lats, lons) ax.scatter(xs, ys, color="red", s=10) @@ -438,9 +519,11 @@ def plot_all_stops(g, ax=None, scalebar=False): def get_smopy_map(lon_min, lon_max, lat_min, lat_max, z=None, map_style=None): if map_style is not None: - assert map_style in MAP_STYLES, map_style + \ - " (map_style parameter) is not a valid CartoDB mapping style. Options are " + \ - str(MAP_STYLES) + assert map_style in MAP_STYLES, ( + map_style + + " (map_style parameter) is not a valid CartoDB mapping style. Options are " + + str(MAP_STYLES) + ) tileserver = "http://1.basemaps.cartocdn.com/" + map_style + "/{z}/{x}/{y}.png" else: tileserver = "http://1.basemaps.cartocdn.com/light_all/{z}/{x}/{y}.png" @@ -450,16 +533,20 @@ def get_smopy_map(lon_min, lon_max, lat_min, lat_max, z=None, map_style=None): kwargs = {} if z is not None: # this hack may not work smopy.Map.get_allowed_zoom = lambda self, z: z - kwargs['z'] = z + kwargs["z"] = z try: - get_smopy_map.maps[args] = smopy.Map((lat_min, lon_min, lat_max, lon_max), tileserver=tileserver, **kwargs) + get_smopy_map.maps[args] = smopy.Map( + (lat_min, lon_min, lat_max, lon_max), tileserver=tileserver, **kwargs + ) except URLError: - raise RuntimeError("\n Could not load background map from the tile server: " - + tileserver + - "\n Please check that the tile server exists and " - "that your are connected to the internet.") + raise RuntimeError( + "\n Could not load background map from the tile server: " + + tileserver + + "\n Please check that the tile server exists and " + "that your are connected to the internet." + ) return get_smopy_map.maps[args] -get_smopy_map.maps = {} +get_smopy_map.maps = {} # type: ignore diff --git a/gtfspy/networks.py b/gtfspy/networks.py index 03fae14..4ad9814 100644 --- a/gtfspy/networks.py +++ b/gtfspy/networks.py @@ -6,16 +6,20 @@ from warnings import warn ALL_STOP_TO_STOP_LINK_ATTRIBUTES = [ - "capacity_estimate", "duration_min", "duration_max", - "duration_median", "duration_avg", "n_vehicles", "route_types", - "d", "distance_shape", - "route_I_counts" + "capacity_estimate", + "duration_min", + "duration_max", + "duration_median", + "duration_avg", + "n_vehicles", + "route_types", + "d", + "distance_shape", + "route_I_counts", ] -DEFAULT_STOP_TO_STOP_LINK_ATTRIBUTES = [ - "n_vehicles", "duration_avg", - "d", "route_I_counts" -] +DEFAULT_STOP_TO_STOP_LINK_ATTRIBUTES = ["n_vehicles", "duration_avg", "d", "route_I_counts"] + def walk_transfer_stop_to_stop_network(gtfs, max_link_distance=None): """ @@ -46,8 +50,10 @@ def walk_transfer_stop_to_stop_network(gtfs, max_link_distance=None): stop_distances = gtfs.get_table("stop_distances") if stop_distances["d_walk"][0] is None: osm_distances_available = False - warn("Warning: OpenStreetMap-based walking distances have not been computed, using euclidean distances instead." - "Ignore this warning if running unit tests.") + warn( + "Warning: OpenStreetMap-based walking distances have not been computed, using euclidean distances instead." + "Ignore this warning if running unit tests." + ) else: osm_distances_available = True @@ -58,20 +64,18 @@ def walk_transfer_stop_to_stop_network(gtfs, max_link_distance=None): if osm_distances_available: if stop_distance_tuple.d_walk > max_link_distance or isnan(stop_distance_tuple.d_walk): continue - data = {'d': stop_distance_tuple.d, 'd_walk': stop_distance_tuple.d_walk} + data = {"d": stop_distance_tuple.d, "d_walk": stop_distance_tuple.d_walk} else: if stop_distance_tuple.d > max_link_distance: continue - data = {'d': stop_distance_tuple.d} + data = {"d": stop_distance_tuple.d} net.add_edge(from_node, to_node, data) return net -def stop_to_stop_network_for_route_type(gtfs, - route_type, - link_attributes=None, - start_time_ut=None, - end_time_ut=None): +def stop_to_stop_network_for_route_type( + gtfs, route_type, link_attributes=None, start_time_ut=None, end_time_ut=None +): """ Get a stop-to-stop network describing a single mode of travel. @@ -103,20 +107,20 @@ def stop_to_stop_network_for_route_type(gtfs, """ if link_attributes is None: link_attributes = DEFAULT_STOP_TO_STOP_LINK_ATTRIBUTES - assert(route_type in route_types.TRANSIT_ROUTE_TYPES) + assert route_type in route_types.TRANSIT_ROUTE_TYPES stops_dataframe = gtfs.get_stops_for_route_type(route_type) net = networkx.DiGraph() _add_stops_to_net(net, stops_dataframe) - events_df = gtfs.get_transit_events(start_time_ut=start_time_ut, - end_time_ut=end_time_ut, - route_type=route_type) + events_df = gtfs.get_transit_events( + start_time_ut=start_time_ut, end_time_ut=end_time_ut, route_type=route_type + ) if len(net.nodes()) < 2: assert events_df.shape[0] == 0 # group events by links, and loop over them (i.e. each link): - link_event_groups = events_df.groupby(['from_stop_I', 'to_stop_I'], sort=False) + link_event_groups = events_df.groupby(["from_stop_I", "to_stop_I"], sort=False) for key, link_events in link_event_groups: from_stop_I, to_stop_I = key assert isinstance(link_events, pd.DataFrame) @@ -126,26 +130,27 @@ def stop_to_stop_network_for_route_type(gtfs, else: link_data = {} if "duration_min" in link_attributes: - link_data['duration_min'] = float(link_events['duration'].min()) + link_data["duration_min"] = float(link_events["duration"].min()) if "duration_max" in link_attributes: - link_data['duration_max'] = float(link_events['duration'].max()) + link_data["duration_max"] = float(link_events["duration"].max()) if "duration_median" in link_attributes: - link_data['duration_median'] = float(link_events['duration'].median()) + link_data["duration_median"] = float(link_events["duration"].median()) if "duration_avg" in link_attributes: - link_data['duration_avg'] = float(link_events['duration'].mean()) + link_data["duration_avg"] = float(link_events["duration"].mean()) # statistics on numbers of vehicles: if "n_vehicles" in link_attributes: - link_data['n_vehicles'] = int(link_events.shape[0]) + link_data["n_vehicles"] = int(link_events.shape[0]) if "capacity_estimate" in link_attributes: - link_data['capacity_estimate'] = route_types.ROUTE_TYPE_TO_APPROXIMATE_CAPACITY[route_type] \ - * int(link_events.shape[0]) + link_data["capacity_estimate"] = route_types.ROUTE_TYPE_TO_APPROXIMATE_CAPACITY[ + route_type + ] * int(link_events.shape[0]) if "d" in link_attributes: - from_lat = net.node[from_stop_I]['lat'] - from_lon = net.node[from_stop_I]['lon'] - to_lat = net.node[to_stop_I]['lat'] - to_lon = net.node[to_stop_I]['lon'] + from_lat = net.node[from_stop_I]["lat"] + from_lon = net.node[from_stop_I]["lon"] + to_lat = net.node[to_stop_I]["lat"] + to_lon = net.node[to_stop_I]["lon"] distance = wgs84_distance(from_lat, from_lon, to_lat, to_lon) - link_data['d'] = int(distance) + link_data["d"] = int(distance) if "distance_shape" in link_attributes: assert "shape_id" in link_events.columns.values found = None @@ -158,11 +163,9 @@ def stop_to_stop_network_for_route_type(gtfs, else: link_event = link_events.iloc[found] distance = gtfs.get_shape_distance_between_stops( - link_event["trip_I"], - int(link_event["from_seq"]), - int(link_event["to_seq"]) + link_event["trip_I"], int(link_event["from_seq"]), int(link_event["to_seq"]) ) - link_data['distance_shape'] = distance + link_data["distance_shape"] = distance if "route_I_counts" in link_attributes: link_data["route_I_counts"] = link_events.groupby("route_I").size().to_dict() net.add_edge(from_stop_I, to_stop_I, attr_dict=link_data) @@ -192,6 +195,7 @@ def stop_to_stop_networks_by_type(gtfs): assert len(route_type_to_network) == len(route_types.ALL_ROUTE_TYPES) return route_type_to_network + def combined_stop_to_stop_transit_network(gtfs, start_time_ut=None, end_time_ut=None): """ Compute stop-to-stop networks for all travel modes and combine them into a single network. @@ -210,14 +214,16 @@ def combined_stop_to_stop_transit_network(gtfs, start_time_ut=None, end_time_ut= """ multi_di_graph = networkx.MultiDiGraph() for route_type in route_types.TRANSIT_ROUTE_TYPES: - graph = stop_to_stop_network_for_route_type(gtfs, route_type, - start_time_ut=start_time_ut, end_time_ut=end_time_ut) + graph = stop_to_stop_network_for_route_type( + gtfs, route_type, start_time_ut=start_time_ut, end_time_ut=end_time_ut + ) for from_node, to_node, data in graph.edges(data=True): - data['route_type'] = route_type + data["route_type"] = route_type multi_di_graph.add_edges_from(graph.edges(data=True)) multi_di_graph.add_nodes_from(graph.nodes(data=True)) return multi_di_graph + def _add_stops_to_net(net, stops): """ Add nodes to the network from the pandas dataframe describing (a part of the) stops table in the GTFS database. @@ -228,18 +234,11 @@ def _add_stops_to_net(net, stops): stops: pandas.DataFrame """ for stop in stops.itertuples(): - data = { - "lat": stop.lat, - "lon": stop.lon, - "name": stop.name - } + data = {"lat": stop.lat, "lon": stop.lon, "name": stop.name} net.add_node(stop.stop_I, data) -def temporal_network(gtfs, - start_time_ut=None, - end_time_ut=None, - route_type=None): +def temporal_network(gtfs, start_time_ut=None, end_time_ut=None, route_type=None): """ Compute the temporal network of the data, and return it as a pandas.DataFrame @@ -261,19 +260,14 @@ def temporal_network(gtfs, events_df: pandas.DataFrame Columns: departure_stop, arrival_stop, departure_time_ut, arrival_time_ut, route_type, route_I, trip_I """ - events_df = gtfs.get_transit_events(start_time_ut=start_time_ut, - end_time_ut=end_time_ut, - route_type=route_type) - events_df.drop('to_seq', 1, inplace=True) - events_df.drop('shape_id', 1, inplace=True) - events_df.drop('duration', 1, inplace=True) - events_df.drop('route_id', 1, inplace=True) - events_df.rename( - columns={ - 'from_seq': "seq" - }, - inplace=True + events_df = gtfs.get_transit_events( + start_time_ut=start_time_ut, end_time_ut=end_time_ut, route_type=route_type ) + events_df.drop("to_seq", 1, inplace=True) + events_df.drop("shape_id", 1, inplace=True) + events_df.drop("duration", 1, inplace=True) + events_df.drop("route_id", 1, inplace=True) + events_df.rename(columns={"from_seq": "seq"}, inplace=True) return events_df @@ -291,20 +285,26 @@ def route_to_route_network(gtfs, walking_threshold, start_time, end_time): routes = gtfs.get_table("routes") for i in routes.itertuples(): - graph.add_node(i.route_id, attr_dict={"type": i.type, "color": route_types.ROUTE_TYPE_TO_COLOR[i.type]}) - + graph.add_node( + i.route_id, attr_dict={"type": i.type, "color": route_types.ROUTE_TYPE_TO_COLOR[i.type]} + ) query = """SELECT stop1.route_id AS route_id1, stop1.type, stop2.route_id AS route_id2, stop2.type FROM (SELECT * FROM stop_distances WHERE d_walk < %s) sd, - (SELECT * FROM stop_times, trips, routes - WHERE stop_times.trip_I=trips.trip_I AND trips.route_I=routes.route_I + (SELECT * FROM stop_times, trips, routes + WHERE stop_times.trip_I=trips.trip_I AND trips.route_I=routes.route_I AND stop_times.dep_time_ds > %s AND stop_times.dep_time_ds < %s) stop1, - (SELECT * FROM stop_times, trips, routes - WHERE stop_times.trip_I=trips.trip_I AND trips.route_I=routes.route_I + (SELECT * FROM stop_times, trips, routes + WHERE stop_times.trip_I=trips.trip_I AND trips.route_I=routes.route_I AND stop_times.dep_time_ds > %s AND stop_times.dep_time_ds < %s) stop2 WHERE sd.from_stop_I = stop1.stop_I AND sd.to_stop_I = stop2.stop_I AND stop1.route_id != stop2.route_id - GROUP BY stop1.route_id, stop2.route_id""" % (walking_threshold, start_time, end_time, start_time, - end_time) + GROUP BY stop1.route_id, stop2.route_id""" % ( + walking_threshold, + start_time, + end_time, + start_time, + end_time, + ) df = gtfs.execute_custom_query_pandas(query) for items in df.itertuples(): @@ -313,8 +313,6 @@ def route_to_route_network(gtfs, walking_threshold, start_time, end_time): return graph - - # def cluster_network_stops(stop_to_stop_net, distance): # """ # Aggregate graph by grouping nodes that are within a specified distance. diff --git a/gtfspy/osm_transfers.py b/gtfspy/osm_transfers.py index 1aca8d5..f1aa8c9 100644 --- a/gtfspy/osm_transfers.py +++ b/gtfspy/osm_transfers.py @@ -1,16 +1,13 @@ import os +from warnings import warn import networkx -import pandas +from geoindex import GeoGridIndex, GeoPoint from osmread import parse_file, Way, Node from gtfspy.gtfs import GTFS from gtfspy.util import wgs84_distance -from warnings import warn - -from geoindex import GeoGridIndex, GeoPoint - def add_walk_distances_to_db_python(gtfs, osm_path, cutoff_distance_m=1000): """ @@ -36,11 +33,13 @@ def add_walk_distances_to_db_python(gtfs, osm_path, cutoff_distance_m=1000): """ if isinstance(gtfs, str): gtfs = GTFS(gtfs) - assert (isinstance(gtfs, GTFS)) + assert isinstance(gtfs, GTFS) print("Reading in walk network") walk_network = create_walk_network_from_osm(osm_path) print("Matching stops to the OSM network") - stop_I_to_nearest_osm_node, stop_I_to_nearest_osm_node_distance = match_stops_to_nodes(gtfs, walk_network) + stop_I_to_nearest_osm_node, stop_I_to_nearest_osm_node_distance = match_stops_to_nodes( + gtfs, walk_network + ) transfers = gtfs.get_straight_line_transfer_distances() @@ -54,22 +53,31 @@ def add_walk_distances_to_db_python(gtfs, osm_path, cutoff_distance_m=1000): for from_I, to_stop_Is in from_I_to_to_stop_Is.items(): from_node = stop_I_to_nearest_osm_node[from_I] from_dist = stop_I_to_nearest_osm_node_distance[from_I] - shortest_paths = networkx.single_source_dijkstra_path_length(walk_network, - from_node, - cutoff=cutoff_distance_m - from_dist, - weight="distance") + shortest_paths = networkx.single_source_dijkstra_path_length( + walk_network, from_node, cutoff=cutoff_distance_m - from_dist, weight="distance" + ) for to_I in to_stop_Is: to_distance = stop_I_to_nearest_osm_node_distance[to_I] to_node = stop_I_to_nearest_osm_node[to_I] - osm_distance = shortest_paths.get(to_node, float('inf')) + osm_distance = shortest_paths.get(to_node, float("inf")) total_distance = from_dist + osm_distance + to_distance - from_stop_I_transfers = transfers[transfers['from_stop_I'] == from_I] - straigth_distance = from_stop_I_transfers[from_stop_I_transfers["to_stop_I"] == to_I]["d"].values[0] - assert (straigth_distance < total_distance + 2) # allow for a maximum of 2 meters in calculations + from_stop_I_transfers = transfers[transfers["from_stop_I"] == from_I] + straigth_distance = from_stop_I_transfers[from_stop_I_transfers["to_stop_I"] == to_I][ + "d" + ].values[0] + assert ( + straigth_distance < total_distance + 2 + ) # allow for a maximum of 2 meters in calculations if total_distance <= cutoff_distance_m: - gtfs.conn.execute("UPDATE stop_distances " - "SET d_walk = " + str(int(total_distance)) + - " WHERE from_stop_I=" + str(from_I) + " AND to_stop_I=" + str(to_I)) + gtfs.conn.execute( + "UPDATE stop_distances " + "SET d_walk = " + + str(int(total_distance)) + + " WHERE from_stop_I=" + + str(from_I) + + " AND to_stop_I=" + + str(to_I) + ) gtfs.conn.commit() @@ -90,19 +98,19 @@ def match_stops_to_nodes(gtfs, walk_network): """ network_nodes = walk_network.nodes(data="true") - stop_Is = set(gtfs.get_straight_line_transfer_distances()['from_stop_I']) + stop_Is = set(gtfs.get_straight_line_transfer_distances()["from_stop_I"]) stops_df = gtfs.stops() geo_index = GeoGridIndex(precision=6) for net_node, data in network_nodes: - geo_index.add_point(GeoPoint(data['lat'], data['lon'], ref=net_node)) + geo_index.add_point(GeoPoint(data["lat"], data["lon"], ref=net_node)) stop_I_to_node = {} stop_I_to_dist = {} for stop_I in stop_Is: stop_lat = float(stops_df[stops_df.stop_I == stop_I].lat) stop_lon = float(stops_df[stops_df.stop_I == stop_I].lon) geo_point = GeoPoint(stop_lat, stop_lon) - min_dist = float('inf') + min_dist = float("inf") min_dist_node = None search_distances_m = [0.100, 0.500] for search_distance_m in search_distances_m: @@ -119,14 +127,29 @@ def match_stops_to_nodes(gtfs, walk_network): return stop_I_to_node, stop_I_to_dist -OSM_HIGHWAY_WALK_TAGS = {"trunk", "trunk_link", "primary", "primary_link", "secondary", "secondary_link", "tertiary", - "tertiary_link", "unclassified", "residential", "living_street", "road", "pedestrian", "path", - "cycleway", "footway"} +OSM_HIGHWAY_WALK_TAGS = { + "trunk", + "trunk_link", + "primary", + "primary_link", + "secondary", + "secondary_link", + "tertiary", + "tertiary_link", + "unclassified", + "residential", + "living_street", + "road", + "pedestrian", + "path", + "cycleway", + "footway", +} def create_walk_network_from_osm(osm_file): walk_network = networkx.Graph() - assert (os.path.exists(osm_file)) + assert os.path.exists(osm_file) ways = [] for i, entity in enumerate(parse_file(osm_file)): if isinstance(entity, Node): @@ -141,16 +164,15 @@ def create_walk_network_from_osm(osm_file): # Remove all singleton nodes (note that taking the giant component does not necessarily provide proper results. for node, degree in walk_network.degree().items(): - if degree is 0: + if degree == 0: walk_network.remove_node(node) - node_lats = networkx.get_node_attributes(walk_network, 'lat') - node_lons = networkx.get_node_attributes(walk_network, 'lon') + node_lats = networkx.get_node_attributes(walk_network, "lat") + node_lons = networkx.get_node_attributes(walk_network, "lon") for source, dest, data in walk_network.edges(data=True): - data["distance"] = wgs84_distance(node_lats[source], - node_lons[source], - node_lats[dest], - node_lons[dest]) + data["distance"] = wgs84_distance( + node_lats[source], node_lons[source], node_lats[dest], node_lons[dest] + ) return walk_network diff --git a/gtfspy/plots.py b/gtfspy/plots.py index 082888c..b4a9b44 100644 --- a/gtfspy/plots.py +++ b/gtfspy/plots.py @@ -6,7 +6,9 @@ """ -def plot_trip_counts_per_day(G, ax=None, highlight_dates=None, highlight_date_labels=None, show=False): +def plot_trip_counts_per_day( + G, ax=None, highlight_dates=None, highlight_date_labels=None, show=False +): """ Parameters ---------- @@ -27,15 +29,25 @@ def plot_trip_counts_per_day(G, ax=None, highlight_dates=None, highlight_date_la if ax is None: _fig, ax = plt.subplots() daily_trip_counts["datetime"] = pandas.to_datetime(daily_trip_counts["date_str"]) - daily_trip_counts.plot("datetime", "trip_counts", kind="line", ax=ax, marker="o", color="C0", ls=":", - label="Trip counts") + daily_trip_counts.plot( + "datetime", + "trip_counts", + kind="line", + ax=ax, + marker="o", + color="C0", + ls=":", + label="Trip counts", + ) ax.set_xlabel("Date") ax.set_ylabel("Trip counts per day") if highlight_dates is not None: assert isinstance(highlight_dates, list) if highlight_date_labels is not None: assert isinstance(highlight_date_labels, list) - assert len(highlight_dates) == len(highlight_date_labels), "Number of highlight date labels do not match" + assert len(highlight_dates) == len( + highlight_date_labels + ), "Number of highlight date labels do not match" else: highlight_date_labels = [None] * len(highlight_dates) for i, (highlight_date, label) in enumerate(zip(highlight_dates, highlight_date_labels)): diff --git a/gtfspy/route_types.py b/gtfspy/route_types.py index 479fad1..778b6be 100644 --- a/gtfspy/route_types.py +++ b/gtfspy/route_types.py @@ -25,7 +25,7 @@ CABLE_CAR: 8, GONDOLA: 9, FUNICULAR: 10, - AIRCRAFT: 11 + AIRCRAFT: 11, } ROUTE_TYPE_TO_DESCRIPTION = { @@ -35,11 +35,11 @@ BUS: "Bus. Used for short- and long-distance bus routes.", FERRY: "Ferry. Used for short- and long-distance boat service.", CABLE_CAR: "Cable car. Used for street-level cable cars " - "where the cable runs beneath the car.", + "where the cable runs beneath the car.", GONDOLA: "Gondola, Suspended cable car. " - "Typically used for aerial cable cars where " - "the car is suspended from the cable.", - FUNICULAR: "Funicular. Any rail system designed for steep inclines." + "Typically used for aerial cable cars where " + "the car is suspended from the cable.", + FUNICULAR: "Funicular. Any rail system designed for steep inclines.", } ROUTE_TYPE_TO_SHORT_DESCRIPTION = { @@ -52,7 +52,7 @@ CABLE_CAR: "Cable car", GONDOLA: "Gondola", FUNICULAR: "Funicular", - AIRCRAFT: "Aircraft" + AIRCRAFT: "Aircraft", } ROUTE_TYPE_TO_LOWERCASE_TAG = { @@ -64,7 +64,7 @@ FERRY: "ferry", CABLE_CAR: "cablecar", GONDOLA: "gondola", - FUNICULAR: "funicular" + FUNICULAR: "funicular", } # Use these on your own risk! @@ -77,26 +77,26 @@ FERRY: 200, CABLE_CAR: 40, GONDOLA: 20, - FUNICULAR: 20 + FUNICULAR: 20, } ROUTE_TYPE_TO_COLOR = { WALK: "black", - TRAM: '#33a02c', + TRAM: "#33a02c", SUBWAY: "#ff7f00", - RAIL: '#e31a1c', - BUS: '#1f78b4', + RAIL: "#e31a1c", + BUS: "#1f78b4", FERRY: "#ffff99", CABLE_CAR: "#6a3d9a", GONDOLA: "#b15928", FUNICULAR: "#fb9a99", - AIRCRAFT: "#fb9a99" + AIRCRAFT: "#fb9a99", } -def route_type_to_color_iterable(type): - return [ROUTE_TYPE_TO_COLOR[x] for x in type] +def route_type_to_color_iterable(route_type): + return [ROUTE_TYPE_TO_COLOR[x] for x in route_type] -def route_type_to_zorder(type): - return [ROUTE_TYPE_TO_ZORDER[x] for x in type] \ No newline at end of file +def route_type_to_zorder(route_type): + return [ROUTE_TYPE_TO_ZORDER[x] for x in route_type] diff --git a/gtfspy/routing/__init__.py b/gtfspy/routing/__init__.py index 6b1d17e..3a8caba 100644 --- a/gtfspy/routing/__init__.py +++ b/gtfspy/routing/__init__.py @@ -1,2 +1,3 @@ import pyximport -pyximport.install() \ No newline at end of file + +pyximport.install() diff --git a/gtfspy/routing/connection.py b/gtfspy/routing/connection.py index b6d6440..9c30cf3 100644 --- a/gtfspy/routing/connection.py +++ b/gtfspy/routing/connection.py @@ -6,8 +6,17 @@ class Connection: WALK_SEQ = -1 WALK_TRIP_ID = -1 - def __init__(self, departure_stop, arrival_stop, departure_time, arrival_time, trip_id, seq, - is_walk=False, arrival_stop_next_departure_time=float('inf')): + def __init__( + self, + departure_stop, + arrival_stop, + departure_time, + arrival_time, + trip_id, + seq, + is_walk=False, + arrival_stop_next_departure_time=float("inf"), + ): self.departure_stop = departure_stop self.arrival_stop = arrival_stop self.departure_time = departure_time @@ -27,9 +36,16 @@ def __eq__(self, other): return self.__dict__ == other.__dict__ def __repr__(self): - return '<%s:%s:%s:%s:%s:%s:%s:%s>' % ( - self.__class__.__name__, self.departure_stop, self.arrival_stop, - self.departure_time, self.arrival_time, self.trip_id, self.is_walk, self.arrival_stop_next_departure_time) + return "<%s:%s:%s:%s:%s:%s:%s:%s>" % ( + self.__class__.__name__, + self.departure_stop, + self.arrival_stop, + self.departure_time, + self.arrival_time, + self.trip_id, + self.is_walk, + self.arrival_stop_next_departure_time, + ) def __hash__(self): - return hash(self.__repr__()) \ No newline at end of file + return hash(self.__repr__()) diff --git a/gtfspy/routing/connection_scan.py b/gtfspy/routing/connection_scan.py index 91e5b86..3be8b95 100644 --- a/gtfspy/routing/connection_scan.py +++ b/gtfspy/routing/connection_scan.py @@ -17,8 +17,16 @@ class ConnectionScan(AbstractRoutingAlgorithm): http://i11www.iti.uni-karlsruhe.de/extra/publications/dpsw-isftr-13.pdf """ - def __init__(self, transit_events, seed_stop, start_time, - end_time, transfer_margin, walk_network, walk_speed): + def __init__( + self, + transit_events, + seed_stop, + start_time, + end_time, + transfer_margin, + walk_network, + walk_speed, + ): """ Parameters ---------- @@ -46,7 +54,7 @@ def __init__(self, transit_events, seed_stop, start_time, self._walk_speed = walk_speed # algorithm internals - self.__stop_labels = defaultdict(lambda: float('inf')) + self.__stop_labels = defaultdict(lambda: float("inf")) self.__stop_labels[seed_stop] = start_time # trip flags: @@ -101,5 +109,3 @@ def _scan_footpaths(self, stop_id, walk_departure_time): d_walk = data["d_walk"] arrival_time = walk_departure_time + d_walk / self._walk_speed self._update_stop_label(neighbor, arrival_time) - - diff --git a/gtfspy/routing/connection_scan_profile.py b/gtfspy/routing/connection_scan_profile.py index 1d94ee3..1fb3269 100644 --- a/gtfspy/routing/connection_scan_profile.py +++ b/gtfspy/routing/connection_scan_profile.py @@ -43,15 +43,17 @@ class ConnectionScanProfiler(AbstractRoutingAlgorithm): http://i11www.iti.uni-karlsruhe.de/extra/publications/dpsw-isftr-13.pdf """ - def __init__(self, - transit_events, - target_stop, - start_time=None, - end_time=None, - transfer_margin=0, - walk_network=None, - walk_speed=1.5, - verbose=False): + def __init__( + self, + transit_events, + target_stop, + start_time=None, + end_time=None, + transfer_margin=0, + walk_network=None, + walk_speed=1.5, + verbose=False, + ): """ Parameters ---------- @@ -113,8 +115,8 @@ def _run(self): # basic checking + printing progress: if self._verbose and i % 1000 == 0: print(i, "/", n_connections) - assert (isinstance(connection, Connection)) - assert (connection.departure_time <= previous_departure_time) + assert isinstance(connection, Connection) + assert connection.departure_time <= previous_departure_time previous_departure_time = connection.departure_time # get all different "accessible" / arrival times (Pareto-optimal sets) @@ -131,8 +133,9 @@ def _run(self): earliest_arrival_time_via_same_trip = self.__trip_min_arrival_time[connection.trip_id] # then, take the minimum (or the Pareto-optimal set) of these three alternatives. - min_arrival_time = min(earliest_arrival_time_via_same_trip, - earliest_arrival_time_via_transfer) + min_arrival_time = min( + earliest_arrival_time_via_same_trip, earliest_arrival_time_via_transfer + ) # If there are no 'labels' to progress, nothing needs to be done. if min_arrival_time == float("inf"): @@ -140,7 +143,9 @@ def _run(self): # Update information for the trip if earliest_arrival_time_via_same_trip > min_arrival_time: - self.__trip_min_arrival_time[connection.trip_id] = earliest_arrival_time_via_transfer + self.__trip_min_arrival_time[ + connection.trip_id + ] = earliest_arrival_time_via_transfer # Compute the new "best" pareto_tuple possible (later: merge the sets of pareto-optimal labels) pareto_tuple = LabelTimeSimple(connection.departure_time, min_arrival_time) @@ -150,17 +155,22 @@ def _run(self): updated_dep_stop = dep_stop_profile.update_pareto_optimal_tuples(pareto_tuple) # if the departure stop is updated, one also needs to scan the footpaths from the departure stop if updated_dep_stop: - self._scan_footpaths_to_departure_stop(connection.departure_stop, - connection.departure_time, - min_arrival_time) + self._scan_footpaths_to_departure_stop( + connection.departure_stop, connection.departure_time, min_arrival_time + ) - def _scan_footpaths_to_departure_stop(self, connection_dep_stop, connection_dep_time, arrival_time_target): + def _scan_footpaths_to_departure_stop( + self, connection_dep_stop, connection_dep_time, arrival_time_target + ): """ A helper method for scanning the footpaths. Updates self._stop_profiles accordingly""" - for _, neighbor, data in self._walk_network.edges_iter(nbunch=[connection_dep_stop], - data=True): - d_walk = data['d_walk'] + for _, neighbor, data in self._walk_network.edges_iter( + nbunch=[connection_dep_stop], data=True + ): + d_walk = data["d_walk"] neighbor_dep_time = connection_dep_time - d_walk / self._walk_speed - pt = LabelTimeSimple(departure_time=neighbor_dep_time, arrival_time_target=arrival_time_target) + pt = LabelTimeSimple( + departure_time=neighbor_dep_time, arrival_time_target=arrival_time_target + ) self._stop_profiles[neighbor].update_pareto_optimal_tuples(pt) @property diff --git a/gtfspy/routing/fastest_path_analyzer.py b/gtfspy/routing/fastest_path_analyzer.py index 0315eea..33bd498 100644 --- a/gtfspy/routing/fastest_path_analyzer.py +++ b/gtfspy/routing/fastest_path_analyzer.py @@ -1,14 +1,20 @@ -import copy - from gtfspy.routing.label import compute_pareto_front + from gtfspy.routing.node_profile_analyzer_time import NodeProfileAnalyzerTime -from gtfspy.routing.profile_block_analyzer import ProfileBlockAnalyzer from gtfspy.routing.profile_block import ProfileBlock +from gtfspy.routing.profile_block_analyzer import ProfileBlockAnalyzer class FastestPathAnalyzer: - - def __init__(self, labels, start_time_dep, end_time_dep, walk_duration=float('inf'), label_props_to_consider=None, **kwargs): + def __init__( + self, + labels, + start_time_dep, + end_time_dep, + walk_duration=float("inf"), + label_props_to_consider=None, + **kwargs + ): """ Parameters ---------- @@ -19,8 +25,8 @@ def __init__(self, labels, start_time_dep, end_time_dep, walk_duration=float('in label_props_to_consider: list """ for label in labels: - assert (hasattr(label, "departure_time")) - assert (hasattr(label, "arrival_time_target")) + assert hasattr(label, "departure_time") + assert hasattr(label, "arrival_time_target") self.start_time_dep = start_time_dep self.end_time_dep = end_time_dep self.walk_duration = walk_duration @@ -32,17 +38,23 @@ def __init__(self, labels, start_time_dep, end_time_dep, walk_duration=float('in # assert each label has the required properties for label in self._fastest_path_labels: for prop in self.label_props: - assert (hasattr(label, prop)) + assert hasattr(label, prop) self.kwargs = kwargs def _compute_fastest_path_labels(self, labels): - relevant_labels = [label.get_copy() for label in labels if (self.start_time_dep < label.departure_time <= self.end_time_dep)] - if len(relevant_labels) is 0 or relevant_labels[-1].departure_time < self.end_time_dep: + relevant_labels = [ + label.get_copy() + for label in labels + if (self.start_time_dep < label.departure_time <= self.end_time_dep) + ] + if len(relevant_labels) == 0 or relevant_labels[-1].departure_time < self.end_time_dep: # add an after label - smallest_arr_time_after_end_time = float('inf') + smallest_arr_time_after_end_time = float("inf") smallest_arr_time_label = None for label in labels: - if self.end_time_dep < label.departure_time and (label.arrival_time_target < smallest_arr_time_after_end_time): + if self.end_time_dep < label.departure_time and ( + label.arrival_time_target < smallest_arr_time_after_end_time + ): smallest_arr_time_after_end_time = label.arrival_time_target smallest_arr_time_label = label if smallest_arr_time_label is not None: @@ -56,8 +68,8 @@ def _compute_fastest_path_labels(self, labels): # assert ordered: for i in range(len(fp_labels) - 1): try: - assert (fp_labels[i].arrival_time_target <= fp_labels[i + 1].arrival_time_target) - assert (fp_labels[i].departure_time < fp_labels[i + 1].departure_time) + assert fp_labels[i].arrival_time_target <= fp_labels[i + 1].arrival_time_target + assert fp_labels[i].departure_time < fp_labels[i + 1].departure_time except AssertionError as e: for fp_label in fp_labels: print(fp_label) @@ -90,12 +102,13 @@ def get_fastest_path_temporal_distance_blocks(self): ------- blocks: list[ProfileBlock] """ + def _label_to_prop_dict(label): return {prop: getattr(label, prop) for prop in self.label_props} labels = self._fastest_path_labels for i in range(len(labels) - 1): - assert (labels[i].departure_time < labels[i + 1].departure_time) + assert labels[i].departure_time < labels[i + 1].departure_time previous_dep_time = self.start_time_dep blocks = [] @@ -103,26 +116,33 @@ def _label_to_prop_dict(label): if previous_dep_time >= self.end_time_dep: break end_time = min(label.departure_time, self.end_time_dep) - assert (end_time >= previous_dep_time) + assert end_time >= previous_dep_time temporal_distance_start = label.duration() + (label.departure_time - previous_dep_time) if temporal_distance_start > self.walk_duration: - split_point_x_computed = label.departure_time - (self.walk_duration - label.duration()) + split_point_x_computed = label.departure_time - ( + self.walk_duration - label.duration() + ) split_point_x = min(split_point_x_computed, end_time) if previous_dep_time < split_point_x: # add walk block, only if it is required - walk_block = ProfileBlock(previous_dep_time, - split_point_x, - self.walk_duration, - self.walk_duration, - **_label_to_prop_dict(label)) + walk_block = ProfileBlock( + previous_dep_time, + split_point_x, + self.walk_duration, + self.walk_duration, + **_label_to_prop_dict(label), + ) blocks.append(walk_block) if split_point_x < end_time: - trip_block = ProfileBlock(split_point_x, end_time, - label.duration() + (end_time - split_point_x), - label.duration(), - **_label_to_prop_dict(label)) + trip_block = ProfileBlock( + split_point_x, + end_time, + label.duration() + (end_time - split_point_x), + label.duration(), + **_label_to_prop_dict(label), + ) blocks.append(trip_block) else: journey_block = ProfileBlock( @@ -130,15 +150,15 @@ def _label_to_prop_dict(label): end_time, temporal_distance_start, temporal_distance_start - (end_time - previous_dep_time), - **_label_to_prop_dict(label)) + **_label_to_prop_dict(label), + ) blocks.append(journey_block) previous_dep_time = blocks[-1].end_time if previous_dep_time < self.end_time_dep: - last_block = ProfileBlock(previous_dep_time, - self.end_time_dep, - self.walk_duration, - self.walk_duration) + last_block = ProfileBlock( + previous_dep_time, self.end_time_dep, self.walk_duration, self.walk_duration + ) blocks.append(last_block) return blocks @@ -148,10 +168,9 @@ def get_time_analyzer(self): ------- NodeProfileAnalyzerTime """ - return NodeProfileAnalyzerTime(self._fastest_path_labels, - self.walk_duration, - self.start_time_dep, - self.end_time_dep) + return NodeProfileAnalyzerTime( + self._fastest_path_labels, self.walk_duration, self.start_time_dep, self.end_time_dep + ) def get_props(self): return list(self.label_props) @@ -174,13 +193,13 @@ def get_prop_analyzer_for_pre_journey_wait(self): prop_blocks.append(prop_block) return ProfileBlockAnalyzer(prop_blocks, **kwargs) - def get_prop_analyzer_flat(self, property, value_no_next_journey, value_cutoff): + def get_prop_analyzer_flat(self, property_name, value_no_next_journey, value_cutoff): """ Get a journey property analyzer, where each journey is weighted by the number of. Parameters ---------- - property: string + property_name: string Name of the property, needs to be one of label_props given on initialization. value_no_next_journey: Value of the profile, when there is no next journey available. @@ -196,15 +215,12 @@ def get_prop_analyzer_flat(self, property, value_no_next_journey, value_cutoff): prop_blocks = [] for b in fp_blocks: if b.is_flat(): - if b.distance_end == self.walk_duration and b.distance_end != float('inf'): + if b.distance_end == self.walk_duration and b.distance_end != float("inf"): prop_value = value_cutoff else: prop_value = value_no_next_journey else: - prop_value = b[property] + prop_value = b[property_name] prop_block = ProfileBlock(b.start_time, b.end_time, prop_value, prop_value) prop_blocks.append(prop_block) return ProfileBlockAnalyzer(prop_blocks, **kwargs) - - - diff --git a/gtfspy/routing/forwardjourney.py b/gtfspy/routing/forwardjourney.py index 9f16dfb..5e4a8b5 100644 --- a/gtfspy/routing/forwardjourney.py +++ b/gtfspy/routing/forwardjourney.py @@ -30,7 +30,7 @@ def add_leg(self, leg): ---------- leg: Connection """ - assert(isinstance(leg, Connection)) + assert isinstance(leg, Connection) if not self.legs: self.departure_time = leg.departure_time self.arrival_time = leg.arrival_time @@ -68,7 +68,11 @@ def get_transfer_stop_pairs(self): previous_arrival_stop = None current_trip_id = None for leg in self.legs: - if leg.trip_id is not None and leg.trip_id != current_trip_id and previous_arrival_stop is not None: + if ( + leg.trip_id is not None + and leg.trip_id != current_trip_id + and previous_arrival_stop is not None + ): transfer_stop_pair = (previous_arrival_stop, leg.departure_stop) transfer_stop_pairs.append(transfer_stop_pair) previous_arrival_stop = leg.arrival_stop @@ -96,7 +100,7 @@ def get_total_waiting_time(self): def get_invehicle_times(self): invehicle_times = [] for leg in self.legs: - assert(isinstance(leg, Connection)) + assert isinstance(leg, Connection) if leg.trip_id is not None: invehicle_times.append(leg.duration()) return invehicle_times @@ -116,15 +120,15 @@ def get_total_walking_time(self): def dominates(self, other, consider_time=True, consider_boardings=True): if consider_time: - dominates_time = (self.departure_time >= other.departure_time and - self.arrival_time <= other.arrival_time) + dominates_time = ( + self.departure_time >= other.departure_time + and self.arrival_time <= other.arrival_time + ) if not dominates_time: return False if consider_boardings: - dominates_boardings = (self.n_boardings <= other.n_boardings) + dominates_boardings = self.n_boardings <= other.n_boardings if not dominates_boardings: return False # dominates w.r.t all aspects: return True - - diff --git a/gtfspy/routing/helpers.py b/gtfspy/routing/helpers.py index 51f8bbd..70611c8 100644 --- a/gtfspy/routing/helpers.py +++ b/gtfspy/routing/helpers.py @@ -18,15 +18,21 @@ def get_transit_connections(gtfs, start_time_ut, end_time_ut): list[Connection] """ if start_time_ut + 20 * 3600 < end_time_ut: - warn("Note that it is possible that same trip_I's can take place during multiple days, " - "which could (potentially) affect the outcomes of the CSA routing!") - assert (isinstance(gtfs, GTFS)) + warn( + "Note that it is possible that same trip_I's can take place during multiple days, " + "which could (potentially) affect the outcomes of the CSA routing!" + ) + assert isinstance(gtfs, GTFS) events_df = temporal_network(gtfs, start_time_ut=start_time_ut, end_time_ut=end_time_ut) - assert (isinstance(events_df, pandas.DataFrame)) - return list(map(lambda e: Connection(e.from_stop_I, e.to_stop_I, e.dep_time_ut, e.arr_time_ut, e.trip_I, e.seq), - events_df.itertuples() - ) - ) + assert isinstance(events_df, pandas.DataFrame) + return list( + map( + lambda e: Connection( + e.from_stop_I, e.to_stop_I, e.dep_time_ut, e.arr_time_ut, e.trip_I, e.seq + ), + events_df.itertuples(), + ) + ) def get_walk_network(gtfs, max_link_distance_m=1000): @@ -39,5 +45,5 @@ def get_walk_network(gtfs, max_link_distance_m=1000): ------- walk_network: networkx.Graph: """ - assert (isinstance(gtfs, GTFS)) + assert isinstance(gtfs, GTFS) return walk_transfer_stop_to_stop_network(gtfs, max_link_distance=max_link_distance_m) diff --git a/gtfspy/routing/journey_data.py b/gtfspy/routing/journey_data.py index 9bf70f3..344f47a 100644 --- a/gtfspy/routing/journey_data.py +++ b/gtfspy/routing/journey_data.py @@ -4,8 +4,13 @@ from gtfspy.routing.connection import Connection from gtfspy.gtfs import GTFS -from gtfspy.routing.label import LabelTimeAndRoute, LabelTimeWithBoardingsCount, LabelTimeBoardingsAndRoute, \ - compute_pareto_front, LabelGeneric +from gtfspy.routing.label import ( + LabelTimeAndRoute, + LabelTimeWithBoardingsCount, + LabelTimeBoardingsAndRoute, + compute_pareto_front, + LabelGeneric, +) from gtfspy.routing.travel_impedance_data_store import TravelImpedanceDataStore from gtfspy.routing.fastest_path_analyzer import FastestPathAnalyzer from gtfspy.routing.node_profile_analyzer_time_and_veh_legs import NodeProfileAnalyzerTimeAndVehLegs @@ -19,12 +24,20 @@ def attach_database(conn, other_db_path, name="other"): print("other database attached:", cur.fetchall()) return conn + _T_WALK_STR = "t_walk" -class JourneyDataManager: - def __init__(self, gtfs_path, journey_db_path, routing_params=None, multitarget_routing=False, - track_vehicle_legs=True, track_route=False): +class JourneyDataManager: + def __init__( + self, + gtfs_path, + journey_db_path, + routing_params=None, + multitarget_routing=False, + track_vehicle_legs=True, + track_route=False, + ): """ :param gtfs: GTFS object :param list_of_stop_profiles: dict of NodeProfileMultiObjective @@ -59,15 +72,16 @@ def __init__(self, gtfs_path, journey_db_path, routing_params=None, multitarget_ self._assert_journey_computation_paramaters_match() self.journey_properties = {"journey_duration": (_T_WALK_STR, _T_WALK_STR)} - if routing_params.get('track_vehicle_legs', False) or \ - self.routing_parameters.get('track_vehicle_legs', False): + if routing_params.get("track_vehicle_legs", False) or self.routing_parameters.get( + "track_vehicle_legs", False + ): self.journey_properties["n_boardings"] = (float("inf"), 0) if self.track_route: additional_journey_parameters = { - "in_vehicle_duration": (float('inf'), 0), - "transfer_wait_duration": (float('inf'), 0), + "in_vehicle_duration": (float("inf"), 0), + "transfer_wait_duration": (float("inf"), 0), "walking_duration": (_T_WALK_STR, _T_WALK_STR), - "pre_journey_wait_fp": (float('inf'), 0) + "pre_journey_wait_fp": (float("inf"), 0), } self.journey_properties.update(additional_journey_parameters) self.travel_impedance_measure_names = list(self.journey_properties.keys()) @@ -79,7 +93,9 @@ def __del__(self): self.conn.close() @timeit - def import_journey_data_for_target_stop(self, target_stop_I, origin_stop_I_to_journey_labels, enforce_synchronous_writes=False): + def import_journey_data_for_target_stop( + self, target_stop_I, origin_stop_I_to_journey_labels, enforce_synchronous_writes=False + ): """ Parameters ---------- @@ -89,14 +105,18 @@ def import_journey_data_for_target_stop(self, target_stop_I, origin_stop_I_to_jo target_stop_I: int """ cur = self.conn.cursor() - self.conn.isolation_level = 'EXCLUSIVE' + self.conn.isolation_level = "EXCLUSIVE" # if not enforce_synchronous_writes: - cur.execute('PRAGMA synchronous = 0;') + cur.execute("PRAGMA synchronous = 0;") if self.track_route: - self._insert_journeys_with_route_into_db(origin_stop_I_to_journey_labels, target_stop=int(target_stop_I)) + self._insert_journeys_with_route_into_db( + origin_stop_I_to_journey_labels, target_stop=int(target_stop_I) + ) else: - self._insert_journeys_into_db_no_route(origin_stop_I_to_journey_labels, target_stop=int(target_stop_I)) + self._insert_journeys_into_db_no_route( + origin_stop_I_to_journey_labels, target_stop=int(target_stop_I) + ) print("Finished import process") self.conn.commit() @@ -121,40 +141,44 @@ def _insert_journeys_into_db_no_route(self, stop_profiles, target_stop=None): print("Collecting journey data") journey_id = 1 journey_list = [] - tot = len(stop_profiles) + # tot = len(stop_profiles) for i, (origin_stop, labels) in enumerate(stop_profiles.items(), start=1): - #print("\r Stop " + str(i) + " of " + str(tot), end='', flush=True) + # print("\r Stop " + str(i) + " of " + str(tot), end='', flush=True) for label in labels: - assert (isinstance(label, LabelTimeWithBoardingsCount)) + assert isinstance(label, LabelTimeWithBoardingsCount) if self.multitarget_routing: target_stop = None else: target_stop = int(target_stop) - values = [int(journey_id), - int(origin_stop), - target_stop, - int(label.departure_time), - int(label.arrival_time_target), - int(label.n_boardings)] + values = [ + int(journey_id), + int(origin_stop), + target_stop, + int(label.departure_time), + int(label.arrival_time_target), + int(label.n_boardings), + ] journey_list.append(values) journey_id += 1 print("Inserting journeys without route into database") - insert_journeys_stmt = '''INSERT INTO journeys( + insert_journeys_stmt = """INSERT INTO journeys( journey_id, from_stop_I, to_stop_I, departure_time, arrival_time_target, - n_boardings) VALUES (%s) ''' % (", ".join(["?" for x in range(6)])) - #self.conn.executemany(insert_journeys_stmt, journey_list) + n_boardings) VALUES (%s) """ % ( + ", ".join(["?" for x in range(6)]) + ) + # self.conn.executemany(insert_journeys_stmt, journey_list) self._executemany_exclusive(insert_journeys_stmt, journey_list) self.conn.commit() @timeit def _executemany_exclusive(self, statement, rows): - self.conn.execute('BEGIN EXCLUSIVE') + self.conn.execute("BEGIN EXCLUSIVE") last_id = self._get_largest_journey_id() rows = [[x[0] + last_id] + x[1:] for x in rows] self.conn.executemany(statement, rows) @@ -167,39 +191,47 @@ def _insert_journeys_with_route_into_db(self, stop_I_to_journey_labels, target_s label = None for i, (origin_stop, labels) in enumerate(stop_I_to_journey_labels.items(), start=1): # tot = len(stop_profiles) - #print("\r Stop " + str(i) + " of " + str(tot), end='', flush=True) + # print("\r Stop " + str(i) + " of " + str(tot), end='', flush=True) - assert (isinstance(stop_I_to_journey_labels[origin_stop], list)) + assert isinstance(stop_I_to_journey_labels[origin_stop], list) for label in labels: - assert (isinstance(label, LabelTimeAndRoute) or isinstance(label, LabelTimeBoardingsAndRoute)) + assert isinstance(label, LabelTimeAndRoute) or isinstance( + label, LabelTimeBoardingsAndRoute + ) # We need to "unpack" the journey to actually figure out where the trip went # (there can be several targets). if label.departure_time == label.arrival_time_target: print("Weird label:", label) continue - target_stop, new_connection_values, route_stops = self._collect_connection_data(journey_id, label) + target_stop, new_connection_values, route_stops = self._collect_connection_data( + journey_id, label + ) if origin_stop == target_stop: continue if isinstance(label, LabelTimeBoardingsAndRoute): - values = [int(journey_id), - int(origin_stop), - int(target_stop), - int(label.departure_time), - int(label.arrival_time_target), - label.n_boardings, - label.movement_duration, - route_stops] + values = [ + int(journey_id), + int(origin_stop), + int(target_stop), + int(label.departure_time), + int(label.arrival_time_target), + label.n_boardings, + label.movement_duration, + route_stops, + ] else: - values = [int(journey_id), - int(origin_stop), - int(target_stop), - int(label.departure_time), - int(label.arrival_time_target), - label.movement_duration, - route_stops] + values = [ + int(journey_id), + int(origin_stop), + int(target_stop), + int(label.departure_time), + int(label.arrival_time_target), + label.movement_duration, + route_stops, + ] journey_list.append(values) connection_list += new_connection_values @@ -208,7 +240,7 @@ def _insert_journeys_with_route_into_db(self, stop_I_to_journey_labels, target_s print("Inserting journeys into database") if label: if isinstance(label, LabelTimeBoardingsAndRoute): - insert_journeys_stmt = '''INSERT INTO journeys( + insert_journeys_stmt = """INSERT INTO journeys( journey_id, from_stop_I, to_stop_I, @@ -216,20 +248,24 @@ def _insert_journeys_with_route_into_db(self, stop_I_to_journey_labels, target_s arrival_time_target, n_boardings, movement_duration, - route) VALUES (%s) ''' % (", ".join(["?" for x in range(8)])) + route) VALUES (%s) """ % ( + ", ".join(["?" for x in range(8)]) + ) else: - insert_journeys_stmt = '''INSERT INTO journeys( + insert_journeys_stmt = """INSERT INTO journeys( journey_id, from_stop_I, to_stop_I, departure_time, arrival_time_target, movement_duration, - route) VALUES (%s) ''' % (", ".join(["?" for x in range(7)])) + route) VALUES (%s) """ % ( + ", ".join(["?" for x in range(7)]) + ) self.conn.executemany(insert_journeys_stmt, journey_list) print("Inserting legs into database") - insert_legs_stmt = '''INSERT INTO legs( + insert_legs_stmt = """INSERT INTO legs( journey_id, from_stop_I, to_stop_I, @@ -237,16 +273,19 @@ def _insert_journeys_with_route_into_db(self, stop_I_to_journey_labels, target_s arrival_time_target, trip_I, seq, - leg_stops) VALUES (%s) ''' % (", ".join(["?" for x in range(8)])) + leg_stops) VALUES (%s) """ % ( + ", ".join(["?" for x in range(8)]) + ) self.conn.executemany(insert_legs_stmt, connection_list) - self.routing_parameters["target_list"] += (str(target_stop) + ",") + self.routing_parameters["target_list"] += str(target_stop) + "," self.conn.commit() - def create_index_for_journeys_table(self): self.conn.execute("PRAGMA temp_store=2") self.conn.commit() - self.conn.execute("CREATE INDEX IF NOT EXISTS journeys_to_stop_I_idx ON journeys (to_stop_I)") + self.conn.execute( + "CREATE INDEX IF NOT EXISTS journeys_to_stop_I_idx ON journeys (to_stop_I)" + ) def _collect_connection_data(self, journey_id, label): target_stop = None @@ -283,8 +322,8 @@ def _collect_connection_data(self, journey_id, label): int(leg_arrival_time), int(prev_trip_id), int(seq), - ','.join([str(x) for x in leg_stops]) - ) + ",".join([str(x) for x in leg_stops]), + ) value_list.append(values) seq += 1 leg_stops = [] @@ -307,14 +346,14 @@ def _collect_connection_data(self, journey_id, label): int(leg_arrival_time), int(prev_trip_id), int(seq), - ','.join([str(x) for x in leg_stops]) + ",".join([str(x) for x in leg_stops]), ) value_list.append(values) break cur_label = cur_label.previous_label route_stops.append(target_stop) - route_stops = ','.join([str(x) for x in route_stops]) + route_stops = ",".join([str(x) for x in route_stops]) return target_stop, value_list, route_stops def populate_additional_journey_columns(self): @@ -326,27 +365,29 @@ def populate_additional_journey_columns(self): def get_od_pairs_having_journeys(self): cur = self.conn.cursor() if not self.od_pairs: - cur.execute('SELECT from_stop_I, to_stop_I FROM journeys GROUP BY from_stop_I, to_stop_I') + cur.execute( + "SELECT from_stop_I, to_stop_I FROM journeys GROUP BY from_stop_I, to_stop_I" + ) self.od_pairs = cur.fetchall() return self.od_pairs def get_targets_having_journeys(self): cur = self.conn.cursor() if not self._targets: - cur.execute('SELECT to_stop_I FROM journeys GROUP BY to_stop_I') + cur.execute("SELECT to_stop_I FROM journeys GROUP BY to_stop_I") self._targets = [target[0] for target in cur.fetchall()] return self._targets def get_origins_having_journeys(self): cur = self.conn.cursor() if not self._origins: - cur.execute('SELECT from_stop_I FROM journeys GROUP BY from_stop_I') + cur.execute("SELECT from_stop_I FROM journeys GROUP BY from_stop_I") self._origins = [origin[0] for origin in cur.fetchall()] return self._origins def get_table_with_coordinates(self, table_name, target=None): df = self.get_table_as_dataframe(table_name, target) - return self.gtfs.add_coordinates_to_df(df, join_column='from_stop_I') + return self.gtfs.add_coordinates_to_df(df, join_column="from_stop_I") def get_table_as_dataframe(self, table_name, to_stop_I_target=None): query = "SELECT * FROM " + table_name @@ -361,16 +402,28 @@ def add_fastest_path_column(self): for target in self.get_targets_having_journeys(): fastest_path_journey_ids = [] for origin in self.get_origins_having_journeys(): - cur.execute('SELECT departure_time, arrival_time_target, journey_id FROM journeys ' - 'WHERE from_stop_I = ? AND to_stop_I = ? ' - 'ORDER BY departure_time ASC', (origin, target)) + cur.execute( + "SELECT departure_time, arrival_time_target, journey_id FROM journeys " + "WHERE from_stop_I = ? AND to_stop_I = ? " + "ORDER BY departure_time ASC", + (origin, target), + ) all_trips = cur.fetchall() - all_labels = [LabelTimeAndRoute(x[0], x[1], x[2], False) for x in all_trips] #putting journey_id as movement_duration - all_fp_labels = compute_pareto_front(all_labels, finalization=False, ignore_n_boardings=True) + all_labels = [ + LabelTimeAndRoute(x[0], x[1], x[2], False) for x in all_trips + ] # putting journey_id as movement_duration + all_fp_labels = compute_pareto_front( + all_labels, finalization=False, ignore_n_boardings=True + ) fastest_path_journey_ids.append(all_fp_labels) - fastest_path_journey_ids = [(1, x.movement_duration) for sublist in fastest_path_journey_ids for x in sublist] - cur.executemany("UPDATE journeys SET fastest_path = ? WHERE journey_id = ?", fastest_path_journey_ids) + fastest_path_journey_ids = [ + (1, x.movement_duration) for sublist in fastest_path_journey_ids for x in sublist + ] + cur.executemany( + "UPDATE journeys SET fastest_path = ? WHERE journey_id = ?", + fastest_path_journey_ids, + ) self.conn.commit() @timeit @@ -379,9 +432,12 @@ def add_time_to_prev_journey_fp_column(self): cur = self.conn.cursor() for target in self.get_targets_having_journeys(): - cur.execute('SELECT journey_id, from_stop_I, to_stop_I, departure_time FROM journeys ' - 'WHERE fastest_path = 1 AND to_stop_I = ? ' - 'ORDER BY from_stop_I, to_stop_I, departure_time ', (target,)) + cur.execute( + "SELECT journey_id, from_stop_I, to_stop_I, departure_time FROM journeys " + "WHERE fastest_path = 1 AND to_stop_I = ? " + "ORDER BY from_stop_I, to_stop_I, departure_time ", + (target,), + ) all_trips = cur.fetchall() time_to_prev_journey = [] @@ -400,7 +456,10 @@ def add_time_to_prev_journey_fp_column(self): prev_origin = from_stop_I prev_destination = to_stop_I prev_departure_time = departure_time - cur.executemany("UPDATE journeys SET pre_journey_wait_fp = ? WHERE journey_id = ?", time_to_prev_journey) + cur.executemany( + "UPDATE journeys SET pre_journey_wait_fp = ? WHERE journey_id = ?", + time_to_prev_journey, + ) self.conn.commit() @timeit @@ -410,18 +469,24 @@ def compute_journey_time_components(self): cur.execute("UPDATE journeys SET journey_duration = arrival_time_target - departure_time") if self.track_route: - cur.execute("UPDATE journeys " - "SET " - "in_vehicle_duration = " - "(SELECT sum(arrival_time_target - departure_time) AS in_vehicle_duration FROM legs " - "WHERE journeys.journey_id = legs.journey_id AND trip_I != -1 GROUP BY journey_id)") - cur.execute("UPDATE journeys " - "SET " - "walking_duration = " - "(SELECT sum(arrival_time_target - departure_time) AS walking_duration FROM legs " - "WHERE journeys.journey_id = legs.journey_id AND trip_I < 0 GROUP BY journey_id)") - cur.execute("UPDATE journeys " - "SET transfer_wait_duration = journey_duration - in_vehicle_duration - walking_duration") + cur.execute( + "UPDATE journeys " + "SET " + "in_vehicle_duration = " + "(SELECT sum(arrival_time_target - departure_time) AS in_vehicle_duration FROM legs " + "WHERE journeys.journey_id = legs.journey_id AND trip_I != -1 GROUP BY journey_id)" + ) + cur.execute( + "UPDATE journeys " + "SET " + "walking_duration = " + "(SELECT sum(arrival_time_target - departure_time) AS walking_duration FROM legs " + "WHERE journeys.journey_id = legs.journey_id AND trip_I < 0 GROUP BY journey_id)" + ) + cur.execute( + "UPDATE journeys " + "SET transfer_wait_duration = journey_duration - in_vehicle_duration - walking_duration" + ) self.conn.commit() def _journey_label_generator(self, destination_stop_Is=None, origin_stop_Is=None): @@ -444,19 +509,28 @@ def _journey_label_generator(self, destination_stop_Is=None, origin_stop_Is=None for destination_stop_I in destination_stop_Is: if self.track_route: - label_features = "journey_id, from_stop_I, to_stop_I, n_boardings, movement_duration, " \ - "journey_duration, in_vehicle_duration, transfer_wait_duration, walking_duration, " \ - "departure_time, arrival_time_target""" + label_features = ( + "journey_id, from_stop_I, to_stop_I, n_boardings, movement_duration, " + "journey_duration, in_vehicle_duration, transfer_wait_duration, walking_duration, " + "departure_time, arrival_time_target" + "" + ) else: - label_features = "journey_id, from_stop_I, to_stop_I, n_boardings, departure_time, " \ - "arrival_time_target" - sql = "SELECT " + label_features + " FROM journeys WHERE to_stop_I = %s" % destination_stop_I + label_features = ( + "journey_id, from_stop_I, to_stop_I, n_boardings, departure_time, " + "arrival_time_target" + ) + sql = ( + "SELECT " + + label_features + + " FROM journeys WHERE to_stop_I = %s" % destination_stop_I + ) df = pd.read_sql_query(sql, self.conn) for origin_stop_I in origin_stop_Is: - selection = df.loc[df['from_stop_I'] == origin_stop_I] + selection = df.loc[df["from_stop_I"] == origin_stop_I] journey_labels = [] - for journey in selection.to_dict(orient='records'): + for journey in selection.to_dict(orient="records"): journey["pre_journey_wait_fp"] = -1 try: journey_labels.append(LabelGeneric(journey)) @@ -468,21 +542,31 @@ def _journey_label_generator(self, destination_stop_Is=None, origin_stop_Is=None def get_node_profile_time_analyzer(self, target, origin, start_time_dep, end_time_dep): sql = """SELECT journey_id, from_stop_I, to_stop_I, n_boardings, movement_duration, journey_duration, in_vehicle_duration, transfer_wait_duration, walking_duration, departure_time, arrival_time_target - FROM journeys WHERE to_stop_I = %s AND from_stop_I = %s""" % (target, origin) + FROM journeys WHERE to_stop_I = %s AND from_stop_I = %s""" % ( + target, + origin, + ) df = pd.read_sql_query(sql, self.conn) journey_labels = [] - for journey in df.to_dict(orient='records'): + for journey in df.to_dict(orient="records"): journey_labels.append(LabelGeneric(journey)) - fpa = FastestPathAnalyzer(journey_labels, - start_time_dep, - end_time_dep, - walk_duration=float('inf'), # walking time - label_props_to_consider=list(self.journey_properties.keys())) + fpa = FastestPathAnalyzer( + journey_labels, + start_time_dep, + end_time_dep, + walk_duration=float("inf"), # walking time + label_props_to_consider=list(self.journey_properties.keys()), + ) return fpa.get_time_analyzer() - def get_node_profile_analyzer_time_and_veh_legs(self, target, origin, start_time_dep, end_time_dep): - sql = """SELECT from_stop_I, to_stop_I, n_boardings, departure_time, arrival_time_target FROM journeys WHERE to_stop_I = %s AND from_stop_I = %s""" % (target, origin) + def get_node_profile_analyzer_time_and_veh_legs( + self, target, origin, start_time_dep, end_time_dep + ): + sql = ( + """SELECT from_stop_I, to_stop_I, n_boardings, departure_time, arrival_time_target FROM journeys WHERE to_stop_I = %s AND from_stop_I = %s""" + % (target, origin) + ) df = pd.read_sql_query(sql, self.conn) journey_labels = [] @@ -490,30 +574,30 @@ def get_node_profile_analyzer_time_and_veh_legs(self, target, origin, start_time departure_time = journey.departure_time arrival_time_target = journey.arrival_time_target n_boardings = journey.n_boardings - journey_labels.append(LabelTimeWithBoardingsCount(departure_time, - arrival_time_target, - n_boardings, - first_leg_is_walk=float('nan'))) + journey_labels.append( + LabelTimeWithBoardingsCount( + departure_time, arrival_time_target, n_boardings, first_leg_is_walk=float("nan") + ) + ) # This ought to be optimized... - query = """SELECT d, d_walk FROM stop_distances WHERE to_stop_I = %s AND from_stop_I = %s""" % (target, origin) + query = ( + """SELECT d, d_walk FROM stop_distances WHERE to_stop_I = %s AND from_stop_I = %s""" + % (target, origin) + ) df = self.gtfs.execute_custom_query_pandas(query) if len(df) > 0: - walk_duration = float(df['d_walk']) / self.routing_params_input['walk_speed'] + walk_duration = float(df["d_walk"]) / self.routing_params_input["walk_speed"] else: - walk_duration = float('inf') - analyzer = NodeProfileAnalyzerTimeAndVehLegs(journey_labels, - walk_duration, # walking time - start_time_dep, - end_time_dep) + walk_duration = float("inf") + analyzer = NodeProfileAnalyzerTimeAndVehLegs( + journey_labels, walk_duration, start_time_dep, end_time_dep # walking time + ) return analyzer - def __compute_travel_impedance_measure_dict(self, - origin, - target, - journey_labels, - analysis_start_time, - analysis_end_time): + def __compute_travel_impedance_measure_dict( + self, origin, target, journey_labels, analysis_start_time, analysis_end_time + ): measure_summaries = {} kwargs = {"from_stop_I": origin, "to_stop_I": target} walking_distance = self.gtfs.get_stop_distance(origin, target) @@ -522,52 +606,59 @@ def __compute_travel_impedance_measure_dict(self, walking_duration = walking_distance / self.routing_params_input["walk_speed"] else: walking_duration = float("inf") - fpa = FastestPathAnalyzer(journey_labels, - analysis_start_time, - analysis_end_time, - walk_duration=walking_duration, # walking time - label_props_to_consider=list(self.journey_properties.keys()), - **kwargs) + fpa = FastestPathAnalyzer( + journey_labels, + analysis_start_time, + analysis_end_time, + walk_duration=walking_duration, # walking time + label_props_to_consider=list(self.journey_properties.keys()), + **kwargs, + ) temporal_distance_analyzer = fpa.get_temporal_distance_analyzer() # Note: the summary_as_dict automatically includes also the from_stop_I and to_stop_I -fields. measure_summaries["temporal_distance"] = temporal_distance_analyzer.summary_as_dict() fpa.calculate_pre_journey_waiting_times_ignoring_direct_walk() for key, (value_no_next_journey, value_cutoff) in self.journey_properties.items(): value_cutoff = walking_duration if value_cutoff == _T_WALK_STR else value_cutoff - value_no_next_journey = walking_duration if value_no_next_journey == _T_WALK_STR else value_no_next_journey + value_no_next_journey = ( + walking_duration if value_no_next_journey == _T_WALK_STR else value_no_next_journey + ) if key == "pre_journey_wait_fp": property_analyzer = fpa.get_prop_analyzer_for_pre_journey_wait() else: - property_analyzer = fpa.get_prop_analyzer_flat(key, value_no_next_journey, value_cutoff) + property_analyzer = fpa.get_prop_analyzer_flat( + key, value_no_next_journey, value_cutoff + ) measure_summaries[key] = property_analyzer.summary_as_dict() return measure_summaries - def compute_travel_impedance_measures_for_target(self, - analysis_start_time, - analysis_end_time, - target, origins=None): + def compute_travel_impedance_measures_for_target( + self, analysis_start_time, analysis_end_time, target, origins=None + ): if origins is None: origins = self.get_origins_having_journeys() measure_to_measure_summary_dicts = {} for measure in ["temporal_distance"] + list(self.journey_properties): measure_to_measure_summary_dicts[measure] = [] for origin, target, journey_labels in self._journey_label_generator([target], origins): - measure_summary_dicts_for_pair = \ - self.__compute_travel_impedance_measure_dict( - origin, target, journey_labels, - analysis_start_time, analysis_end_time + measure_summary_dicts_for_pair = self.__compute_travel_impedance_measure_dict( + origin, target, journey_labels, analysis_start_time, analysis_end_time ) for measure in measure_summary_dicts_for_pair: - measure_to_measure_summary_dicts[measure].append(measure_summary_dicts_for_pair[measure]) + measure_to_measure_summary_dicts[measure].append( + measure_summary_dicts_for_pair[measure] + ) return measure_to_measure_summary_dicts @timeit - def compute_and_store_travel_impedance_measures(self, - analysis_start_time, - analysis_end_time, - travel_impedance_store_fname, - origins=None, - targets=None): + def compute_and_store_travel_impedance_measures( + self, + analysis_start_time, + analysis_end_time, + travel_impedance_store_fname, + origins=None, + targets=None, + ): data_store = TravelImpedanceDataStore(travel_impedance_store_fname) @@ -575,7 +666,7 @@ def compute_and_store_travel_impedance_measures(self, for travel_impedance_measure in self.travel_impedance_measure_names: data_store.create_table(travel_impedance_measure) - print("Computing total number of origins and targets..", end='', flush=True) + print("Computing total number of origins and targets..", end="", flush=True) if targets is None: targets = self.get_targets_having_journeys() if origins is None: @@ -583,9 +674,11 @@ def compute_and_store_travel_impedance_measures(self, print("\rComputed total number of origins and targets") n_pairs_tot = len(origins) * len(targets) - print("TODO!, compute_and_store_travel_impedance_measures, " - "should be adjusted to use " - "travel_impedance_measure_data_store.py instead") + print( + "TODO!, compute_and_store_travel_impedance_measures, " + "should be adjusted to use " + "travel_impedance_measure_data_store.py instead" + ) def _flush_data_to_db(results): for travel_impedance_measure, data in results.items(): @@ -596,18 +689,32 @@ def _flush_data_to_db(results): # This initializes the meaasure_to_measure_summary_dict properly _flush_data_to_db(measure_to_measure_summary_dicts) - for i, (origin, target, journey_labels) in enumerate(self._journey_label_generator(targets, origins)): + for i, (origin, target, journey_labels) in enumerate( + self._journey_label_generator(targets, origins) + ): print("i", origin, target, journey_labels) if len(journey_labels) == 0: continue - measure_summary_dicts_for_pair = self.__compute_travel_impedance_measure_dict(origin, target, journey_labels, - analysis_start_time, analysis_end_time) + measure_summary_dicts_for_pair = self.__compute_travel_impedance_measure_dict( + origin, target, journey_labels, analysis_start_time, analysis_end_time + ) for measure in measure_summary_dicts_for_pair: - measure_to_measure_summary_dicts[measure].append(measure_summary_dicts_for_pair[measure]) + measure_to_measure_summary_dicts[measure].append( + measure_summary_dicts_for_pair[measure] + ) - if i % 1000 == 0: # update in batches of 1000 - print("\r", i, "/", n_pairs_tot, " : ", "%.2f" % round(float(i) / n_pairs_tot, 3), end='', flush=True) + if i % 1000 == 0: # update in batches of 1000 + print( + "\r", + i, + "/", + n_pairs_tot, + " : ", + "%.2f" % round(float(i) / n_pairs_tot, 3), + end="", + flush=True, + ) _flush_data_to_db(measure_to_measure_summary_dicts) # flush everything that remains @@ -619,10 +726,12 @@ def calculate_pre_journey_waiting_times_ignoring_direct_walk(self): for origin, target, journey_labels in self._journey_label_generator(): if not journey_labels: continue - fpa = FastestPathAnalyzer(journey_labels, - self.routing_parameters["routing_start_time_dep"], - self.routing_parameters["routing_end_time_dep"], - walk_duration=float('inf')) + fpa = FastestPathAnalyzer( + journey_labels, + self.routing_parameters["routing_start_time_dep"], + self.routing_parameters["routing_end_time_dep"], + walk_duration=float("inf"), + ) fpa.calculate_pre_journey_waiting_times_ignoring_direct_walk() all_fp_labels += fpa.get_fastest_path_labels() self.update_journey_from_labels(all_fp_labels, "pre_journey_wait_fp") @@ -647,41 +756,72 @@ def _insert_travel_impedance_data_to_db(self, travel_impedance_measure_name, dat "from_stop_I", "to_stop_I", "min", "max", "median" and "mean" """ f = float - data_tuple = [(x["from_stop_I"], x["to_stop_I"], f(x["min"]), f(x["max"]), f(x["median"]), f(x["mean"])) for x in data] - insert_stmt = '''INSERT OR REPLACE INTO ''' + travel_impedance_measure_name + ''' ( + data_tuple = [ + ( + x["from_stop_I"], + x["to_stop_I"], + f(x["min"]), + f(x["max"]), + f(x["median"]), + f(x["mean"]), + ) + for x in data + ] + insert_stmt = ( + """INSERT OR REPLACE INTO """ + + travel_impedance_measure_name + + """ ( from_stop_I, to_stop_I, min, max, median, - mean) VALUES (?, ?, ?, ?, ?, ?) ''' + mean) VALUES (?, ?, ?, ?, ?, ?) """ + ) self.conn.executemany(insert_stmt, data_tuple) self.conn.commit() def _create_index_for_journeys_table(self): - self.conn.execute("CREATE INDEX IF NOT EXISTS journeys_to_stop_I_idx ON journeys (to_stop_I)") + self.conn.execute( + "CREATE INDEX IF NOT EXISTS journeys_to_stop_I_idx ON journeys (to_stop_I)" + ) @timeit def initialize_comparison_tables(self, diff_db_path, before_db_tuple, after_db_tuple): self.diff_conn = sqlite3.connect(diff_db_path) - self.diff_conn = attach_database(self.diff_conn, before_db_tuple[0], name=before_db_tuple[1]) + self.diff_conn = attach_database( + self.diff_conn, before_db_tuple[0], name=before_db_tuple[1] + ) self.diff_conn = attach_database(self.diff_conn, after_db_tuple[0], name=after_db_tuple[1]) for table in self.travel_impedance_measure_names: - self.diff_conn.execute("CREATE TABLE IF NOT EXISTS diff_" + table + - " (from_stop_I, to_stop_I, diff_min, diff_max, diff_median, diff_mean)") - insert_stmt = "INSERT OR REPLACE INTO diff_" + table + \ - "(from_stop_I, to_stop_I, diff_min, diff_max, diff_median, diff_mean) " \ - "SELECT t1.from_stop_I, t1.to_stop_I, " \ - "t1.min - t2.min AS diff_min, " \ - "t1.max - t2.max AS diff_max, " \ - "t1.median - t2.median AS diff_median, " \ - "t1.mean - t2.mean AS diff_mean " \ - "FROM " + before_db_tuple[1] + "." + table + " AS t1, " \ - + before_db_tuple[1] + "." + table + " AS t2 " \ - "WHERE t1.from_stop_I = t2.from_stop_I " \ - "AND t1.to_stop_I = t2.to_stop_I " + self.diff_conn.execute( + "CREATE TABLE IF NOT EXISTS diff_" + + table + + " (from_stop_I, to_stop_I, diff_min, diff_max, diff_median, diff_mean)" + ) + insert_stmt = ( + "INSERT OR REPLACE INTO diff_" + + table + + "(from_stop_I, to_stop_I, diff_min, diff_max, diff_median, diff_mean) " + "SELECT t1.from_stop_I, t1.to_stop_I, " + "t1.min - t2.min AS diff_min, " + "t1.max - t2.max AS diff_max, " + "t1.median - t2.median AS diff_median, " + "t1.mean - t2.mean AS diff_mean " + "FROM " + + before_db_tuple[1] + + "." + + table + + " AS t1, " + + before_db_tuple[1] + + "." + + table + + " AS t2 " + "WHERE t1.from_stop_I = t2.from_stop_I " + "AND t1.to_stop_I = t2.to_stop_I " + ) self.diff_conn.execute(insert_stmt) self.diff_conn.commit() @@ -691,11 +831,14 @@ def initialize_database(self): print("Database initialized!") def _set_up_database(self): - self.conn.execute('''CREATE TABLE IF NOT EXISTS parameters( + self.conn.execute( + """CREATE TABLE IF NOT EXISTS parameters( key TEXT UNIQUE, - value BLOB)''') + value BLOB)""" + ) if self.track_route: - self.conn.execute('''CREATE TABLE IF NOT EXISTS journeys( + self.conn.execute( + """CREATE TABLE IF NOT EXISTS journeys( journey_id INTEGER PRIMARY KEY, from_stop_I INT, to_stop_I INT, @@ -709,9 +852,11 @@ def _set_up_database(self): in_vehicle_duration INT, transfer_wait_duration INT, walking_duration INT, - fastest_path INT)''') + fastest_path INT)""" + ) - self.conn.execute('''CREATE TABLE IF NOT EXISTS legs( + self.conn.execute( + """CREATE TABLE IF NOT EXISTS legs( journey_id INT, from_stop_I INT, to_stop_I INT, @@ -719,7 +864,8 @@ def _set_up_database(self): arrival_time_target INT, trip_I INT, seq INT, - leg_stops TEXT)''') + leg_stops TEXT)""" + ) """ self.conn.execute('''CREATE TABLE IF NOT EXISTS nodes( stop_I INT, @@ -759,7 +905,8 @@ def _set_up_database(self): n_trips INT)''') """ else: - self.conn.execute('''CREATE TABLE IF NOT EXISTS journeys( + self.conn.execute( + """CREATE TABLE IF NOT EXISTS journeys( journey_id INTEGER PRIMARY KEY, from_stop_I INT, to_stop_I INT, @@ -768,7 +915,8 @@ def _set_up_database(self): n_boardings INT, journey_duration INT, time_to_prev_journey_fp INT, - fastest_path INT)''') + fastest_path INT)""" + ) self.conn.commit() @@ -778,13 +926,15 @@ def _initialize_parameter_table(self): parameters["multiple_targets"] = self.multitarget_routing parameters["gtfs_dir"] = self.gtfs_path - for param in ["location_name", - "lat_median", - "lon_median", - "start_time_ut", - "end_time_ut", - "start_date", - "end_date"]: + for param in [ + "location_name", + "lat_median", + "lon_median", + "start_time_ut", + "end_time_ut", + "start_date", + "end_date", + ]: parameters[param] = self.gtfs_meta[param] parameters["target_list"] = "," for key, value in self.routing_params_input.items(): @@ -797,23 +947,23 @@ def create_indices(self): # Next 3 lines are python 3.6 work-arounds again. self.conn.isolation_level = None # former default of autocommit mode cur = self.conn.cursor() - cur.execute('VACUUM;') - self.conn.isolation_level = '' # back to python default + cur.execute("VACUUM;") + self.conn.isolation_level = "" # back to python default # end python3.6 workaround print("Analyzing...") - cur.execute('ANALYZE') + cur.execute("ANALYZE") print("Indexing") cur = self.conn.cursor() - cur.execute('CREATE INDEX IF NOT EXISTS idx_journeys_jid ON journeys (journey_id)') - cur.execute('CREATE INDEX IF NOT EXISTS idx_journeys_fid ON journeys (from_stop_I)') - cur.execute('CREATE INDEX IF NOT EXISTS idx_journeys_tid ON journeys (to_stop_I)') + cur.execute("CREATE INDEX IF NOT EXISTS idx_journeys_jid ON journeys (journey_id)") + cur.execute("CREATE INDEX IF NOT EXISTS idx_journeys_fid ON journeys (from_stop_I)") + cur.execute("CREATE INDEX IF NOT EXISTS idx_journeys_tid ON journeys (to_stop_I)") if self.track_route: - cur.execute('CREATE INDEX IF NOT EXISTS idx_legs_jid ON legs (journey_id)') - cur.execute('CREATE INDEX IF NOT EXISTS idx_legs_trid ON legs (trip_I)') - cur.execute('CREATE INDEX IF NOT EXISTS idx_legs_fid ON legs (from_stop_I)') - cur.execute('CREATE INDEX IF NOT EXISTS idx_legs_tid ON legs (to_stop_I)') - cur.execute('CREATE INDEX IF NOT EXISTS idx_journeys_route ON journeys (route)') + cur.execute("CREATE INDEX IF NOT EXISTS idx_legs_jid ON legs (journey_id)") + cur.execute("CREATE INDEX IF NOT EXISTS idx_legs_trid ON legs (trip_I)") + cur.execute("CREATE INDEX IF NOT EXISTS idx_legs_fid ON legs (from_stop_I)") + cur.execute("CREATE INDEX IF NOT EXISTS idx_legs_tid ON legs (to_stop_I)") + cur.execute("CREATE INDEX IF NOT EXISTS idx_journeys_route ON journeys (route)") self.conn.commit() @@ -831,27 +981,37 @@ def initialize_journey_comparison_tables(self, tables, before_db_tuple, after_db self.conn = self.attach_database(after_db_path, name=after_db_name) for table in tables: - self.conn.execute("CREATE TABLE IF NOT EXISTS diff_" + table + - "(from_stop_I INT, to_stop_I INT, " - "diff_min INT, diff_max INT, diff_median INT, diff_mean INT, " - "rel_diff_min REAL, rel_diff_max REAL, rel_diff_median REAL, rel_diff_mean REAL)") - insert_stmt = "INSERT OR REPLACE INTO diff_" + table + \ - " (from_stop_I, to_stop_I, diff_min, diff_max, diff_median, diff_mean, " \ - "rel_diff_min, rel_diff_max, rel_diff_median, rel_diff_mean) " \ - "SELECT " \ - "t1.from_stop_I, " \ - "t1.to_stop_I, " \ - "t1.min - t2.min AS diff_min, " \ - "t1.max - t2.max AS diff_max, " \ - "t1.median - t2.median AS diff_median, " \ - "t1.mean - t2.mean AS diff_mean, " \ - "(t1.min - t2.min)*1.0/t2.min AS rel_diff_min, " \ - "(t1.max - t2.max)*1.0/t2.max AS rel_diff_max, " \ - "(t1.median - t2.median)*1.0/t2.median AS rel_diff_median, " \ - "(t1.mean - t2.mean)*1.0/t2.mean AS rel_diff_mean " \ - "FROM " + after_db_name + "." + table + " AS t1, "\ - + before_db_name + "." + table + \ - " AS t2 WHERE t1.from_stop_I = t2.from_stop_I AND t1.to_stop_I = t2.to_stop_I " + self.conn.execute( + "CREATE TABLE IF NOT EXISTS diff_" + table + "(from_stop_I INT, to_stop_I INT, " + "diff_min INT, diff_max INT, diff_median INT, diff_mean INT, " + "rel_diff_min REAL, rel_diff_max REAL, rel_diff_median REAL, rel_diff_mean REAL)" + ) + insert_stmt = ( + "INSERT OR REPLACE INTO diff_" + + table + + " (from_stop_I, to_stop_I, diff_min, diff_max, diff_median, diff_mean, " + "rel_diff_min, rel_diff_max, rel_diff_median, rel_diff_mean) " + "SELECT " + "t1.from_stop_I, " + "t1.to_stop_I, " + "t1.min - t2.min AS diff_min, " + "t1.max - t2.max AS diff_max, " + "t1.median - t2.median AS diff_median, " + "t1.mean - t2.mean AS diff_mean, " + "(t1.min - t2.min)*1.0/t2.min AS rel_diff_min, " + "(t1.max - t2.max)*1.0/t2.max AS rel_diff_max, " + "(t1.median - t2.median)*1.0/t2.median AS rel_diff_median, " + "(t1.mean - t2.mean)*1.0/t2.mean AS rel_diff_mean " + "FROM " + + after_db_name + + "." + + table + + " AS t1, " + + before_db_name + + "." + + table + + " AS t2 WHERE t1.from_stop_I = t2.from_stop_I AND t1.to_stop_I = t2.to_stop_I " + ) self.conn.execute(insert_stmt) self.conn.commit() @@ -864,14 +1024,19 @@ def attach_database(self, other_db_path, name="other"): def get_table_with_coordinates(self, gtfs, table_name, target=None, use_relative=False): df = self.get_table_as_dataframe(table_name, use_relative, target) - return gtfs.add_coordinates_to_df(df, join_column='from_stop_I') + return gtfs.add_coordinates_to_df(df, join_column="from_stop_I") def get_table_as_dataframe(self, table_name, use_relative, target=None): if use_relative: - query = "SELECT from_stop_I, to_stop_I, rel_diff_min, rel_diff_max, rel_diff_median, rel_diff_mean FROM "\ - + table_name + query = ( + "SELECT from_stop_I, to_stop_I, rel_diff_min, rel_diff_max, rel_diff_median, rel_diff_mean FROM " + + table_name + ) else: - query = "SELECT from_stop_I, to_stop_I, diff_min, diff_max, diff_median, diff_mean FROM " + table_name + query = ( + "SELECT from_stop_I, to_stop_I, diff_min, diff_max, diff_median, diff_mean FROM " + + table_name + ) if target: query += " WHERE to_stop_I = %s" % target return pd.read_sql_query(query, self.conn) @@ -879,38 +1044,51 @@ def get_table_as_dataframe(self, table_name, use_relative, target=None): def get_temporal_distance_change_o_d_pairs(self, target, threshold): cur = self.conn.cursor() query = """SELECT from_stop_I FROM diff_temporal_distance - WHERE to_stop_I = %s AND abs(diff_mean) >= %s""" % (target, threshold) + WHERE to_stop_I = %s AND abs(diff_mean) >= %s""" % ( + target, + threshold, + ) rows = [x[0] for x in cur.execute(query).fetchall()] return rows def get_largest_component(self, target, threshold=180): - query = """SELECT diff_pre_journey_wait_fp.from_stop_I AS stop_I, - diff_pre_journey_wait_fp.diff_mean AS pre_journey_wait, + query = """SELECT diff_pre_journey_wait_fp.from_stop_I AS stop_I, + diff_pre_journey_wait_fp.diff_mean AS pre_journey_wait, diff_in_vehicle_duration.diff_mean AS in_vehicle_duration, - diff_transfer_wait_duration.diff_mean AS transfer_wait, + diff_transfer_wait_duration.diff_mean AS transfer_wait, diff_walking_duration.diff_mean AS walking_duration, diff_temporal_distance.diff_mean AS temporal_distance - FROM diff_pre_journey_wait_fp, diff_in_vehicle_duration, + FROM diff_pre_journey_wait_fp, diff_in_vehicle_duration, diff_transfer_wait_duration, diff_walking_duration, diff_temporal_distance WHERE diff_pre_journey_wait_fp.rowid = diff_in_vehicle_duration.rowid AND diff_pre_journey_wait_fp.rowid = diff_transfer_wait_duration.rowid AND diff_pre_journey_wait_fp.rowid = diff_walking_duration.rowid AND diff_pre_journey_wait_fp.rowid = diff_temporal_distance.rowid - AND diff_pre_journey_wait_fp.to_stop_I = %s""" % (target,) + AND diff_pre_journey_wait_fp.to_stop_I = %s""" % ( + target, + ) df = pd.read_sql_query(query, self.conn) - df['max_component'] = df[["pre_journey_wait", "in_vehicle_duration", "transfer_wait", "walking_duration"]].idxmax(axis=1) - df['max_value'] = df[["pre_journey_wait", "in_vehicle_duration", "transfer_wait", "walking_duration"]].max(axis=1) + df["max_component"] = df[ + ["pre_journey_wait", "in_vehicle_duration", "transfer_wait", "walking_duration"] + ].idxmax(axis=1) + df["max_value"] = df[ + ["pre_journey_wait", "in_vehicle_duration", "transfer_wait", "walking_duration"] + ].max(axis=1) - mask = (df['max_value'] < threshold) + mask = df["max_value"] < threshold - df.loc[mask, 'max_component'] = "no_change_within_threshold" + df.loc[mask, "max_component"] = "no_change_within_threshold" - df['min_component'] = df[["pre_journey_wait", "in_vehicle_duration", "transfer_wait", "walking_duration"]].idxmin(axis=1) - df['min_value'] = df[["pre_journey_wait", "in_vehicle_duration", "transfer_wait", "walking_duration"]].min(axis=1) + df["min_component"] = df[ + ["pre_journey_wait", "in_vehicle_duration", "transfer_wait", "walking_duration"] + ].idxmin(axis=1) + df["min_value"] = df[ + ["pre_journey_wait", "in_vehicle_duration", "transfer_wait", "walking_duration"] + ].min(axis=1) - mask = (df['min_value'] > -1 * threshold) + mask = df["min_value"] > -1 * threshold - df.loc[mask, 'min_component'] = "no_change_within_threshold" + df.loc[mask, "min_component"] = "no_change_within_threshold" return df @@ -925,7 +1103,9 @@ def __init__(self, conn): self._conn.execute("CREATE TABLE IF NOT EXISTS parameters (key, value)") def __setitem__(self, key, value): - self._conn.execute("INSERT OR REPLACE INTO parameters('key', 'value') VALUES (?, ?)", (key, value)) + self._conn.execute( + "INSERT OR REPLACE INTO parameters('key', 'value') VALUES (?, ?)", (key, value) + ) self._conn.commit() def __getitem__(self, key): @@ -941,29 +1121,27 @@ def __delitem__(self, key): self._conn.commit() def __iter__(self): - cur = self._conn.execute('SELECT key FROM parameters ORDER BY key') + cur = self._conn.execute("SELECT key FROM parameters ORDER BY key") return (x[0] for x in cur) def __contains__(self, key): - val = self._conn.execute('SELECT value FROM parameters WHERE key=?', - (key,)).fetchone() + val = self._conn.execute("SELECT value FROM parameters WHERE key=?", (key,)).fetchone() return val is not None def get(self, key, default=None): - val = self._conn.execute('SELECT value FROM parameters WHERE key=?', - (key,)).fetchone() + val = self._conn.execute("SELECT value FROM parameters WHERE key=?", (key,)).fetchone() if not val: return default return val[0] def items(self): - cur = self._conn.execute('SELECT key, value FROM parameters ORDER BY key') + cur = self._conn.execute("SELECT key, value FROM parameters ORDER BY key") return cur def keys(self): - cur = self._conn.execute('SELECT key FROM parameters ORDER BY key') + cur = self._conn.execute("SELECT key FROM parameters ORDER BY key") return cur def values(self): - cur = self._conn.execute('SELECT value FROM parameters ORDER BY key') + cur = self._conn.execute("SELECT value FROM parameters ORDER BY key") return cur diff --git a/gtfspy/routing/journey_data_analyzer.py b/gtfspy/routing/journey_data_analyzer.py index e2b9587..b511790 100644 --- a/gtfspy/routing/journey_data_analyzer.py +++ b/gtfspy/routing/journey_data_analyzer.py @@ -23,8 +23,16 @@ def __init__(self, journey_db_path, gtfs_path): def __del__(self): self.conn.close() - def get_journey_legs_to_target(self, target, fastest_path=True, min_boardings=False, all_leg_sections=True, - ignore_walk=False, diff_threshold=None, diff_path=None): + def get_journey_legs_to_target( + self, + target, + fastest_path=True, + min_boardings=False, + all_leg_sections=True, + ignore_walk=False, + diff_threshold=None, + diff_path=None, + ): """ Returns a dataframe of aggregated sections from source nodes to target. The returned sections are either transfer point to transfer point or stop to stop. In a before after setting, the results can be filtered based @@ -53,9 +61,11 @@ def get_journey_legs_to_target(self, target, fastest_path=True, min_boardings=Fa if diff_path and diff_threshold: self.conn = attach_database(self.conn, diff_path, name="diff") add_diff = ", diff.diff_temporal_distance" - added_constraints += " AND abs(diff_temporal_distance.diff_mean) >= %s " \ - "AND diff_temporal_distance.from_stop_I = journeys.from_stop_I " \ - "AND diff_temporal_distance.to_stop_I = journeys.to_stop_I" % (diff_threshold,) + added_constraints += ( + " AND abs(diff_temporal_distance.diff_mean) >= %s " + "AND diff_temporal_distance.from_stop_I = journeys.from_stop_I " + "AND diff_temporal_distance.to_stop_I = journeys.to_stop_I" % (diff_threshold,) + ) if all_leg_sections: df = self._get_journey_legs_to_target_with_all_sections(target, added_constraints) @@ -67,7 +77,11 @@ def get_journey_legs_to_target(self, target, fastest_path=True, min_boardings=Fa WHERE journeys.journey_id = legs.journey_id AND journeys.to_stop_I = %s %s) q1 LEFT JOIN (SELECT * FROM other.trips, other.routes WHERE trips.route_I = routes.route_I) q2 ON q1.trip_I = q2.trip_I - GROUP BY from_stop_I, to_stop_I, type""" % (add_diff, str(target), added_constraints) + GROUP BY from_stop_I, to_stop_I, type""" % ( + add_diff, + str(target), + added_constraints, + ) df = read_sql_query(query, self.conn) return df @@ -88,25 +102,39 @@ def gen_pairs(stop_lists): WHERE journeys.journey_id = legs.journey_id AND journeys.to_stop_I = %s %s) q1 LEFT JOIN (SELECT * FROM other.trips, other.routes WHERE trips.route_I = routes.route_I) q2 ON q1.trip_I = q2.trip_I - GROUP BY leg_stops, type""" % (str(target), added_constraint) + GROUP BY leg_stops, type""" % ( + str(target), + added_constraint, + ) orig_df = read_sql_query(query, self.conn) - df = DataFrame([x for x in gen_pairs(orig_df.leg_stops.str.split(',').tolist())], - index=[orig_df.type, orig_df.n_trips]).stack() + df = DataFrame( + [x for x in gen_pairs(orig_df.leg_stops.str.split(",").tolist())], + index=[orig_df.type, orig_df.n_trips], + ).stack() df = df.reset_index() df = df.rename(columns={0: "stop_tuple"}) - df[['from_stop_I', 'to_stop_I']] = df['stop_tuple'].apply(Series) + df[["from_stop_I", "to_stop_I"]] = df["stop_tuple"].apply(Series) - df = df.groupby(['from_stop_I', 'to_stop_I', 'type']).agg({'n_trips': [np.sum]}) + df = df.groupby(["from_stop_I", "to_stop_I", "type"]).agg({"n_trips": [np.sum]}) df = df.reset_index() df.columns = df.columns.droplevel(1) - df_to_return = df[['from_stop_I', 'to_stop_I', 'type', 'n_trips']] + df_to_return = df[["from_stop_I", "to_stop_I", "type", "n_trips"]] return df_to_return - def get_origin_target_journey_legs(self, origin, target, start_time=None, end_time=None, fastest_path=True, min_boardings=False, - ignore_walk=False, add_coordinates=True): + def get_origin_target_journey_legs( + self, + origin, + target, + start_time=None, + end_time=None, + fastest_path=True, + min_boardings=False, + ignore_walk=False, + add_coordinates=True, + ): assert not (fastest_path and min_boardings) if min_boardings: @@ -126,41 +154,64 @@ def get_origin_target_journey_legs(self, origin, target, start_time=None, end_ti count(*) AS n_trips, group_concat(dep_time) AS dep_times FROM (SELECT legs.*, journeys.departure_time as dep_time FROM legs, journeys - WHERE journeys.journey_id = legs.journey_id AND journeys.from_stop_I = %s + WHERE journeys.journey_id = legs.journey_id AND journeys.from_stop_I = %s AND journeys.to_stop_I = %s %s ORDER BY dep_time ) q1 LEFT JOIN (SELECT * FROM other.trips, other.routes WHERE trips.route_I = routes.route_I) q2 ON q1.trip_I = q2.trip_I - GROUP BY from_stop_I, to_stop_I, type""" % (str(origin), str(target), added_constraints) + GROUP BY from_stop_I, to_stop_I, type""" % ( + str(origin), + str(target), + added_constraints, + ) df = read_sql_query(query, self.conn) if add_coordinates: - df = self.g.add_coordinates_to_df(df, join_column="from_stop_I", lat_name="from_lat", lon_name="from_lon") - df = self.g.add_coordinates_to_df(df, join_column="to_stop_I", lat_name="to_lat", lon_name="to_lon") + df = self.g.add_coordinates_to_df( + df, join_column="from_stop_I", lat_name="from_lat", lon_name="from_lon" + ) + df = self.g.add_coordinates_to_df( + df, join_column="to_stop_I", lat_name="to_lat", lon_name="to_lon" + ) return df - def get_journey_routes_not_in_other_db(self, target, other_journey_conn, fastest_path=True, min_boardings=False, all_leg_sections=True, - ignore_walk=False, diff_threshold=None, diff_path=None): - name = "ojdb" + def get_journey_routes_not_in_other_db( + self, + target, + other_journey_conn, + fastest_path=True, + min_boardings=False, + all_leg_sections=True, + ignore_walk=False, + diff_threshold=None, + diff_path=None, + ): added_constraints = "" if fastest_path: added_constraints += " AND journeys.pre_journey_wait_fp>=0" if ignore_walk: added_constraints += " AND legs.trip_I >= 0" - query = """SELECT from_stop_I, to_stop_I, coalesce(type, -1) AS type, route FROM + query = """SELECT from_stop_I, to_stop_I, coalesce(type, -1) AS type, route FROM (SELECT legs.*, route FROM journeys, legs WHERE legs.journey_id=journeys.journey_id AND journeys.to_stop_I = %s %s) q1 - LEFT JOIN + LEFT JOIN (SELECT * FROM other.trips, other.routes WHERE trips.route_I = routes.route_I) q2 ON q1.trip_I = q2.trip_I - """ % (str(target), added_constraints) + """ % ( + str(target), + added_constraints, + ) df = read_sql_query(query, self.conn) - routes = other_journey_conn.execute("SELECT DISTINCT route FROM journeys WHERE to_stop_I = %s" % (str(target),)).fetchall() + routes = other_journey_conn.execute( + "SELECT DISTINCT route FROM journeys WHERE to_stop_I = %s" % (str(target),) + ).fetchall() routes = [x[0] for x in routes] other_set = set(routes) - these_routes = self.conn.execute("SELECT DISTINCT route FROM journeys WHERE to_stop_I = %s" % (str(target),)).fetchall() + these_routes = self.conn.execute( + "SELECT DISTINCT route FROM journeys WHERE to_stop_I = %s" % (str(target),) + ).fetchall() these_routes = [x[0] for x in these_routes] this_set = set(these_routes) print("n unique routes for this db: ", len(this_set)) @@ -171,19 +222,25 @@ def get_journey_routes_not_in_other_db(self, target, other_journey_conn, fastest print("n unique routes", len(union)) print("n common routes", len(intersection)) - df = df.loc[~df['route'].isin(routes)] + df = df.loc[~df["route"].isin(routes)] df = df[["from_stop_I", "to_stop_I", "type"]] - df = DataFrame({"n_trips": df.groupby(["from_stop_I", "to_stop_I", "type"]).size()}).reset_index() + df = DataFrame( + {"n_trips": df.groupby(["from_stop_I", "to_stop_I", "type"]).size()} + ).reset_index() return df def journey_alternatives_per_stop_pair(self, target, start_time, end_time): query = """SELECT from_stop_I, to_stop_I, ifnull(1.0*sum(n_sq)/(sum(n_trips)*(sum(n_trips)-1)), 1) AS simpson, - sum(n_trips) AS n_trips, count(*) AS n_routes FROM - (SELECT from_stop_I, to_stop_I, count(*) AS n_trips, count(*)*(count(*)-1) AS n_sq + sum(n_trips) AS n_trips, count(*) AS n_routes FROM + (SELECT from_stop_I, to_stop_I, count(*) AS n_trips, count(*)*(count(*)-1) AS n_sq FROM journeys WHERE pre_journey_wait_fp > 0 AND to_stop_I = %s AND departure_time >= %s AND departure_time <= %s GROUP BY route) sq1 - GROUP BY from_stop_I, to_stop_I""" % (target, start_time, end_time) + GROUP BY from_stop_I, to_stop_I""" % ( + target, + start_time, + end_time, + ) df = read_sql_query(query, self.conn) df = self.g.add_coordinates_to_df(df, join_column="from_stop_I") @@ -191,31 +248,44 @@ def journey_alternatives_per_stop_pair(self, target, start_time, end_time): def journey_alternative_data_time_weighted(self, target, start_time, end_time): query = """SELECT sum(p*p) AS simpson, sum(n_trips) AS n_trips, count(*) AS n_routes, from_stop_I, to_stop_I FROM - (SELECT 1.0*sum(pre_journey_wait_fp)/total_time AS p, count(*) AS n_trips, route, + (SELECT 1.0*sum(pre_journey_wait_fp)/total_time AS p, count(*) AS n_trips, route, journeys.from_stop_I, journeys.to_stop_I FROM journeys, (SELECT sum(pre_journey_wait_fp) AS total_time, from_stop_I, to_stop_I FROM journeys WHERE departure_time >= %s AND departure_time <= %s GROUP BY from_stop_I, to_stop_I) sq1 - WHERE pre_journey_wait_fp > 0 AND sq1.to_stop_I=journeys.to_stop_I AND departure_time >= %s - AND departure_time <= %s AND journeys.to_stop_I = %s AND sq1.from_stop_I = journeys.from_stop_I + WHERE pre_journey_wait_fp > 0 AND sq1.to_stop_I=journeys.to_stop_I AND departure_time >= %s + AND departure_time <= %s AND journeys.to_stop_I = %s AND sq1.from_stop_I = journeys.from_stop_I GROUP BY route) sq2 - GROUP BY from_stop_I, to_stop_I""" % (start_time, end_time, start_time, end_time, target) + GROUP BY from_stop_I, to_stop_I""" % ( + start_time, + end_time, + start_time, + end_time, + target, + ) df = read_sql_query(query, self.conn) df = self.g.add_coordinates_to_df(df, join_column="from_stop_I") return df def _add_to_from_coordinates_to_df(self, df): - df = self.g.add_coordinates_to_df(df, join_column="from_stop_I", lat_name="from_lat", lon_name="from_lon") - df = self.g.add_coordinates_to_df(df, join_column="to_stop_I", lat_name="to_lat", lon_name="to_lon") + df = self.g.add_coordinates_to_df( + df, join_column="from_stop_I", lat_name="from_lat", lon_name="from_lon" + ) + df = self.g.add_coordinates_to_df( + df, join_column="to_stop_I", lat_name="to_lat", lon_name="to_lon" + ) return df def get_upstream_stops(self, target, stop): - query = """SELECT stops.* FROM other.stops, - (SELECT journeys.from_stop_I AS stop_I FROM journeys, legs + query = """SELECT stops.* FROM other.stops, + (SELECT journeys.from_stop_I AS stop_I FROM journeys, legs WHERE journeys.journey_id=legs.journey_id AND legs.from_stop_I = %s AND journeys.to_stop_I = %s AND pre_journey_wait_fp >= 0 GROUP BY journeys.from_stop_I) q1 - WHERE stops.stop_I = q1.stop_I""" % (stop, target) + WHERE stops.stop_I = q1.stop_I""" % ( + stop, + target, + ) df = read_sql_query(query, self.conn) return df @@ -229,16 +299,18 @@ def get_upstream_stops_ratio(self, target, trough_stops, ratio): """ if isinstance(trough_stops, list): trough_stops = ",".join(trough_stops) - query = """SELECT stops.* FROM other.stops, - (SELECT q2.from_stop_I AS stop_I FROM + query = """SELECT stops.* FROM other.stops, + (SELECT q2.from_stop_I AS stop_I FROM (SELECT journeys.from_stop_I, count(*) AS n_total FROM journeys - WHERE journeys.to_stop_I = {target} + WHERE journeys.to_stop_I = {target} GROUP BY from_stop_I) q1, - (SELECT journeys.from_stop_I, count(*) AS n_trough FROM journeys, legs + (SELECT journeys.from_stop_I, count(*) AS n_trough FROM journeys, legs WHERE journeys.journey_id=legs.journey_id AND legs.from_stop_I IN ({trough_stops}) AND journeys.to_stop_I = {target} GROUP BY journeys.from_stop_I) q2 WHERE q1.from_stop_I = q2.from_stop_I AND n_trough/(n_total*1.0) >= {ratio}) q1 - WHERE stops.stop_I = q1.stop_I""".format(target=target, trough_stops=trough_stops, ratio=ratio) + WHERE stops.stop_I = q1.stop_I""".format( + target=target, trough_stops=trough_stops, ratio=ratio + ) df = read_sql_query(query, self.conn) return df @@ -248,6 +320,7 @@ def passing_journeys_per_stop(self): :return: """ pass + @timeit def journeys_per_section(self, fastest_path=False, time_weighted=False): """ @@ -303,4 +376,3 @@ def get_journey_time_per_mode(self, modes=None): def get_walking_time(self): pass - diff --git a/gtfspy/routing/multi_objective_pseudo_connection_scan_profiler.py b/gtfspy/routing/multi_objective_pseudo_connection_scan_profiler.py index 8931c9d..29655a6 100644 --- a/gtfspy/routing/multi_objective_pseudo_connection_scan_profiler.py +++ b/gtfspy/routing/multi_objective_pseudo_connection_scan_profiler.py @@ -6,8 +6,15 @@ from gtfspy.routing.connection import Connection from gtfspy.routing.abstract_routing_algorithm import AbstractRoutingAlgorithm from gtfspy.routing.node_profile_multiobjective import NodeProfileMultiObjective -from gtfspy.routing.label import merge_pareto_frontiers, LabelTimeWithBoardingsCount, LabelTime, compute_pareto_front, \ - LabelVehLegCount, LabelTimeBoardingsAndRoute, LabelTimeAndRoute +from gtfspy.routing.label import ( + merge_pareto_frontiers, + LabelTimeWithBoardingsCount, + LabelTime, + compute_pareto_front, + LabelVehLegCount, + LabelTimeBoardingsAndRoute, + LabelTimeAndRoute, +) from gtfspy.util import timeit @@ -18,18 +25,20 @@ class MultiObjectivePseudoCSAProfiler(AbstractRoutingAlgorithm): http://i11www.iti.uni-karlsruhe.de/extra/publications/dpsw-isftr-13.pdf """ - def __init__(self, - transit_events, - targets, - start_time_ut=None, - end_time_ut=None, - transfer_margin=0, - walk_network=None, - walk_speed=1.5, - verbose=False, - track_vehicle_legs=True, - track_time=True, - track_route=False): + def __init__( + self, + transit_events, + targets, + start_time_ut=None, + end_time_ut=None, + transfer_margin=0, + walk_network=None, + walk_speed=1.5, + verbose=False, + track_vehicle_legs=True, + track_time=True, + track_route=False, + ): """ Parameters ---------- @@ -55,7 +64,7 @@ def __init__(self, whether to consider time in the set of pareto_optimal """ AbstractRoutingAlgorithm.__init__(self) - assert (len(transit_events) == len(set(transit_events))), "Duplicate transit events spotted!" + assert len(transit_events) == len(set(transit_events)), "Duplicate transit events spotted!" self._transit_connections = transit_events if start_time_ut is None: start_time_ut = transit_events[-1].departure_time @@ -76,7 +85,7 @@ def __init__(self, self._count_vehicle_legs = track_vehicle_legs self._consider_time = track_time - assert(track_time or track_vehicle_legs) + assert track_time or track_vehicle_legs if track_vehicle_legs: if track_time: if track_route: @@ -92,15 +101,22 @@ def __init__(self, self._label_class = LabelTime print("using label:", str(self._label_class)) - self._stop_departure_times, self._stop_arrival_times = self.__compute_stop_dep_and_arrival_times() - self._all_nodes = set.union(set(self._stop_departure_times.keys()), - set(self._stop_arrival_times.keys()), - set(self._walk_network.nodes())) + ( + self._stop_departure_times, + self._stop_arrival_times, + ) = self.__compute_stop_dep_and_arrival_times() + self._all_nodes = set.union( + set(self._stop_departure_times.keys()), + set(self._stop_arrival_times.keys()), + set(self._walk_network.nodes()), + ) self._pseudo_connections = self.__compute_pseudo_connections() self._add_pseudo_connection_departures_to_stop_departure_times() self._all_connections = self._pseudo_connections + self._transit_connections - self._all_connections.sort(key=lambda connection: (-connection.departure_time, -connection.seq)) + self._all_connections.sort( + key=lambda connection: (-connection.departure_time, -connection.seq) + ) self._augment_all_connections_with_arrival_stop_next_dep_time() if isinstance(targets, list): self._targets = targets @@ -117,18 +133,20 @@ def _add_pseudo_connection_departures_to_stop_departure_times(self): for key, value in self._stop_departure_times_with_pseudo_connections.items(): self._stop_departure_times_with_pseudo_connections[key] = list(value) for pseudo_connection in self._pseudo_connections: - assert(isinstance(pseudo_connection, Connection)) - self._stop_departure_times_with_pseudo_connections[pseudo_connection.departure_stop]\ - .append(pseudo_connection.departure_time) + assert isinstance(pseudo_connection, Connection) + self._stop_departure_times_with_pseudo_connections[ + pseudo_connection.departure_stop + ].append(pseudo_connection.departure_time) for stop, dep_times in self._stop_departure_times_with_pseudo_connections.items(): - self._stop_departure_times_with_pseudo_connections[stop] = numpy.array(list(sorted(set(dep_times)))) - + self._stop_departure_times_with_pseudo_connections[stop] = numpy.array( + list(sorted(set(dep_times))) + ) @timeit def __initialize_node_profiles(self): self._stop_profiles = dict() for node in self._all_nodes: - walk_duration_to_target = float('inf') + walk_duration_to_target = float("inf") closest_target = None if node in self._targets: walk_duration_to_target = 0 @@ -142,12 +160,15 @@ def __initialize_node_profiles(self): walk_duration_to_target = walk_duration closest_target = target - self._stop_profiles[node] = NodeProfileMultiObjective(dep_times=self._stop_departure_times_with_pseudo_connections[node], - label_class=self._label_class, - walk_to_target_duration=walk_duration_to_target, - transit_connection_dep_times=self._stop_departure_times[node], - closest_target=closest_target, - node_id=node) + self._stop_profiles[node] = NodeProfileMultiObjective( + dep_times=self._stop_departure_times_with_pseudo_connections[node], + label_class=self._label_class, + walk_to_target_duration=walk_duration_to_target, + transit_connection_dep_times=self._stop_departure_times[node], + closest_target=closest_target, + node_id=node, + ) + @timeit def __compute_stop_dep_and_arrival_times(self): stop_departure_times = defaultdict(lambda: list()) @@ -161,14 +182,15 @@ def __compute_stop_dep_and_arrival_times(self): stop_arrival_times[stop] = numpy.array(sorted(list(set(stop_arrival_times[stop])))) return stop_departure_times, stop_arrival_times - @timeit def __compute_pseudo_connections(self): print("Started computing pseudoconnections") pseudo_connections = [] # DiGraph makes things iterate both ways (!) for u, v, data in networkx.DiGraph(self._walk_network).edges(data=True): - walk_duration = int(data["d_walk"] / float(self._walk_speed)) # round to one second accuracy + walk_duration = int( + data["d_walk"] / float(self._walk_speed) + ) # round to one second accuracy total_walk_time_with_transfer = walk_duration + self._transfer_margin in_times = self._stop_arrival_times[u] out_times = self._stop_departure_times[v] @@ -183,18 +205,26 @@ def __compute_pseudo_connections(self): j += 1 # connection j cannot be reached -> need to check next j -> increase out_time else: # if next element still satisfies the wanted condition, go on and increase i! - while i + 1 < n_in_times and in_times[i + 1] + total_walk_time_with_transfer < out_times[j]: + while ( + i + 1 < n_in_times + and in_times[i + 1] + total_walk_time_with_transfer < out_times[j] + ): i += 1 dep_time = in_times[i] arr_time = out_times[j] from_stop = u to_stop = v waiting_time = arr_time - dep_time - total_walk_time_with_transfer - assert(waiting_time >= 0) - pseudo = Connection(from_stop, to_stop, arr_time - walk_duration, arr_time, - Connection.WALK_TRIP_ID, - Connection.WALK_SEQ, - is_walk=True) + assert waiting_time >= 0 + pseudo = Connection( + from_stop, + to_stop, + arr_time - walk_duration, + arr_time, + Connection.WALK_TRIP_ID, + Connection.WALK_SEQ, + is_walk=True, + ) pseudo_connections.append(pseudo) i += 1 print("Computed pseudoconnections") @@ -203,39 +233,43 @@ def __compute_pseudo_connections(self): @timeit def _augment_all_connections_with_arrival_stop_next_dep_time(self): for connection in self._all_connections: - assert(isinstance(connection, Connection)) + assert isinstance(connection, Connection) to_stop = connection.arrival_stop arr_stop_dep_times = self._stop_departure_times_with_pseudo_connections[to_stop] - arr_stop_next_dep_time = float('inf') + arr_stop_next_dep_time = float("inf") if len(arr_stop_dep_times) > 0: if connection.is_walk: index = numpy.searchsorted(arr_stop_dep_times, connection.arrival_time) else: - index = numpy.searchsorted(arr_stop_dep_times, connection.arrival_time + self._transfer_margin) + index = numpy.searchsorted( + arr_stop_dep_times, connection.arrival_time + self._transfer_margin + ) if 0 <= index < len(arr_stop_dep_times): arr_stop_next_dep_time = arr_stop_dep_times[index] - if connection.is_walk and not (arr_stop_next_dep_time < float('inf')): - assert (arr_stop_next_dep_time < float('inf')) + if connection.is_walk and not (arr_stop_next_dep_time < float("inf")): + assert arr_stop_next_dep_time < float("inf") connection.arrival_stop_next_departure_time = arr_stop_next_dep_time def _get_modified_arrival_node_labels(self, connection): # get all different "accessible" / arrival times (Pareto-optimal sets) arrival_profile = self._stop_profiles[connection.arrival_stop] # NodeProfileMultiObjective - assert (isinstance(arrival_profile, NodeProfileMultiObjective)) + assert isinstance(arrival_profile, NodeProfileMultiObjective) - arrival_node_labels_orig = arrival_profile.evaluate(connection.arrival_stop_next_departure_time, - first_leg_can_be_walk=not connection.is_walk, - connection_arrival_time=connection.arrival_time) + arrival_node_labels_orig = arrival_profile.evaluate( + connection.arrival_stop_next_departure_time, + first_leg_can_be_walk=not connection.is_walk, + connection_arrival_time=connection.arrival_time, + ) - increment_vehicle_count = (self._count_vehicle_legs and not connection.is_walk) + increment_vehicle_count = self._count_vehicle_legs and not connection.is_walk # TODO: (?) this copying / modification logic should be moved to the Label / ForwardJourney class ? arrival_node_labels_modified = self._copy_and_modify_labels( arrival_node_labels_orig, connection, increment_vehicle_count=increment_vehicle_count, - first_leg_is_walk=connection.is_walk + first_leg_is_walk=connection.is_walk, ) if connection.is_walk: connection.is_walk = True @@ -245,10 +279,12 @@ def _get_modified_arrival_node_labels(self, connection): def _get_trip_labels(self, connection): # best labels from this current trip if not connection.is_walk: - trip_labels = self._copy_and_modify_labels(self.__trip_labels[connection.trip_id], - connection, - increment_vehicle_count=False, - first_leg_is_walk=False) + trip_labels = self._copy_and_modify_labels( + self.__trip_labels[connection.trip_id], + connection, + increment_vehicle_count=False, + first_leg_is_walk=False, + ) else: trip_labels = list() return trip_labels @@ -260,9 +296,18 @@ def _run(self): for i, connection in enumerate(self._all_connections): # basic checking + printing progress: if self._verbose and i % 1000 == 0: - print("\r", i, "/", n_connections_tot, " : ", "%.2f" % round(float(i) / n_connections_tot, 3), end='', flush=True) - assert (isinstance(connection, Connection)) - assert (connection.departure_time <= previous_departure_time) + print( + "\r", + i, + "/", + n_connections_tot, + " : ", + "%.2f" % round(float(i) / n_connections_tot, 3), + end="", + flush=True, + ) + assert isinstance(connection, Connection) + assert connection.departure_time <= previous_departure_time previous_departure_time = connection.departure_time # Get labels from the stop (possibly subject to buffer time) @@ -278,9 +323,9 @@ def _run(self): self.__trip_labels[connection.trip_id] = all_pareto_optimal_labels # Update labels for the departure stop profile (later: with the sets of pareto-optimal labels) - self._stop_profiles[connection.departure_stop].update(all_pareto_optimal_labels, - connection.departure_time) - + self._stop_profiles[connection.departure_stop].update( + all_pareto_optimal_labels, connection.departure_time + ) print("finalizing profiles!") self._finalize_profiles() @@ -290,7 +335,7 @@ def _finalize_profiles(self): Deal with the first walks by joining profiles to other stops within walking distance. """ for stop, stop_profile in self._stop_profiles.items(): - assert (isinstance(stop_profile, NodeProfileMultiObjective)) + assert isinstance(stop_profile, NodeProfileMultiObjective) neighbor_label_bags = [] walk_durations_to_neighbors = [] departure_arrival_stop_pairs = [] @@ -298,13 +343,21 @@ def _finalize_profiles(self): neighbors = networkx.all_neighbors(self._walk_network, stop) for neighbor in neighbors: neighbor_profile = self._stop_profiles[neighbor] - assert (isinstance(neighbor_profile, NodeProfileMultiObjective)) - neighbor_real_connection_labels = neighbor_profile.get_labels_for_real_connections() + assert isinstance(neighbor_profile, NodeProfileMultiObjective) + neighbor_real_connection_labels = ( + neighbor_profile.get_labels_for_real_connections() + ) neighbor_label_bags.append(neighbor_real_connection_labels) - walk_durations_to_neighbors.append(int(self._walk_network.get_edge_data(stop, neighbor)["d_walk"] / - self._walk_speed)) + walk_durations_to_neighbors.append( + int( + self._walk_network.get_edge_data(stop, neighbor)["d_walk"] + / self._walk_speed + ) + ) departure_arrival_stop_pairs.append((stop, neighbor)) - stop_profile.finalize(neighbor_label_bags, walk_durations_to_neighbors, departure_arrival_stop_pairs) + stop_profile.finalize( + neighbor_label_bags, walk_durations_to_neighbors, departure_arrival_stop_pairs + ) @property def stop_profiles(self): @@ -317,15 +370,23 @@ def stop_profiles(self): assert self._has_run return self._stop_profiles - def _copy_and_modify_labels(self, labels, connection, increment_vehicle_count=False, first_leg_is_walk=False): - if self._label_class == LabelTimeBoardingsAndRoute or self._label_class == LabelTimeAndRoute: + def _copy_and_modify_labels( + self, labels, connection, increment_vehicle_count=False, first_leg_is_walk=False + ): + if ( + self._label_class == LabelTimeBoardingsAndRoute + or self._label_class == LabelTimeAndRoute + ): labels_copy = [label.get_label_with_connection_added(connection) for label in labels] else: labels_copy = [label.get_copy() for label in labels] for label in labels_copy: label.departure_time = connection.departure_time - if self._label_class == LabelTimeAndRoute or self._label_class == LabelTimeBoardingsAndRoute: + if ( + self._label_class == LabelTimeAndRoute + or self._label_class == LabelTimeBoardingsAndRoute + ): label.movement_duration += connection.duration() if increment_vehicle_count: label.n_boardings += 1 @@ -339,7 +400,7 @@ def reset(self, targets): else: self._targets = [targets] for target in targets: - assert(target in self._all_nodes) + assert target in self._all_nodes self.__initialize_node_profiles() self.__trip_labels = defaultdict(lambda: list()) self._has_run = False diff --git a/gtfspy/routing/node_profile_analyzer_time.py b/gtfspy/routing/node_profile_analyzer_time.py index 0a770d7..f4dbaf5 100644 --- a/gtfspy/routing/node_profile_analyzer_time.py +++ b/gtfspy/routing/node_profile_analyzer_time.py @@ -17,20 +17,21 @@ def wrapper(self): if self.trip_durations: return func(self) else: - return float('inf') + return float("inf") return wrapper class NodeProfileAnalyzerTime: - @classmethod def from_profile(cls, node_profile, start_time_dep, end_time_dep): assert isinstance(node_profile, NodeProfileSimple), type(node_profile) - return NodeProfileAnalyzerTime(node_profile.get_final_optimal_labels(), - node_profile.get_walk_to_target_duration(), - start_time_dep, - end_time_dep) + return NodeProfileAnalyzerTime( + node_profile.get_final_optimal_labels(), + node_profile.get_walk_to_target_duration(), + start_time_dep, + end_time_dep, + ) def __init__(self, labels, walk_time_to_target, start_time_dep, end_time_dep): """ @@ -44,23 +45,32 @@ def __init__(self, labels, walk_time_to_target, start_time_dep, end_time_dep): self.start_time_dep = start_time_dep self.end_time_dep = end_time_dep # used for computing temporal distances: - all_pareto_optimal_tuples = [pt for pt in labels if - (start_time_dep < pt.departure_time < end_time_dep)] + all_pareto_optimal_tuples = [ + pt for pt in labels if (start_time_dep < pt.departure_time < end_time_dep) + ] - labels_after_dep_time = [label for label in labels if label.departure_time >= self.end_time_dep] + labels_after_dep_time = [ + label for label in labels if label.departure_time >= self.end_time_dep + ] if labels_after_dep_time: - next_label_after_end_time = min(labels_after_dep_time, key=lambda el: el.arrival_time_target) + next_label_after_end_time = min( + labels_after_dep_time, key=lambda el: el.arrival_time_target + ) all_pareto_optimal_tuples.append(next_label_after_end_time) - all_pareto_optimal_tuples = sorted(all_pareto_optimal_tuples, key=lambda ptuple: ptuple.departure_time) + all_pareto_optimal_tuples = sorted( + all_pareto_optimal_tuples, key=lambda ptuple: ptuple.departure_time + ) arrival_time_target_at_end_time = end_time_dep + walk_time_to_target previous_trip = None for trip_tuple in all_pareto_optimal_tuples: if previous_trip: - assert(trip_tuple.arrival_time_target > previous_trip.arrival_time_target) - if trip_tuple.departure_time > self.end_time_dep \ - and trip_tuple.arrival_time_target < arrival_time_target_at_end_time: + assert trip_tuple.arrival_time_target > previous_trip.arrival_time_target + if ( + trip_tuple.departure_time > self.end_time_dep + and trip_tuple.arrival_time_target < arrival_time_target_at_end_time + ): arrival_time_target_at_end_time = trip_tuple.arrival_time_target previous_trip = trip_tuple @@ -74,23 +84,29 @@ def __init__(self, labels, walk_time_to_target, start_time_dep, end_time_dep): continue if self._walk_time_to_target <= trip_pareto_tuple.duration(): print(self._walk_time_to_target, trip_pareto_tuple.duration()) - assert(self._walk_time_to_target > trip_pareto_tuple.duration()) + assert self._walk_time_to_target > trip_pareto_tuple.duration() effective_trip_previous_departure_time = max( previous_departure_time, - trip_pareto_tuple.departure_time - (self._walk_time_to_target - trip_pareto_tuple.duration()) + trip_pareto_tuple.departure_time + - (self._walk_time_to_target - trip_pareto_tuple.duration()), ) if effective_trip_previous_departure_time > previous_departure_time: - walk_block = ProfileBlock(start_time=previous_departure_time, - end_time=effective_trip_previous_departure_time, - distance_start=self._walk_time_to_target, - distance_end=self._walk_time_to_target - ) + walk_block = ProfileBlock( + start_time=previous_departure_time, + end_time=effective_trip_previous_departure_time, + distance_start=self._walk_time_to_target, + distance_end=self._walk_time_to_target, + ) self._profile_blocks.append(walk_block) - trip_waiting_time = trip_pareto_tuple.departure_time - effective_trip_previous_departure_time - trip_block = ProfileBlock(end_time=trip_pareto_tuple.departure_time, - start_time=effective_trip_previous_departure_time, - distance_start=trip_pareto_tuple.duration() + trip_waiting_time, - distance_end=trip_pareto_tuple.duration()) + trip_waiting_time = ( + trip_pareto_tuple.departure_time - effective_trip_previous_departure_time + ) + trip_block = ProfileBlock( + end_time=trip_pareto_tuple.departure_time, + start_time=effective_trip_previous_departure_time, + distance_start=trip_pareto_tuple.duration() + trip_waiting_time, + distance_end=trip_pareto_tuple.duration(), + ) self.trip_durations.append(trip_pareto_tuple.duration()) self.trip_departure_times.append(trip_pareto_tuple.departure_time) self._profile_blocks.append(trip_block) @@ -105,32 +121,37 @@ def __init__(self, labels, walk_time_to_target, start_time_dep, end_time_dep): waiting_time = end_time_dep - dep_previous distance_end_trip = arrival_time_target_at_end_time - end_time_dep - walking_wait_time = min(end_time_dep - dep_previous, - waiting_time - (self._walk_time_to_target - distance_end_trip)) + walking_wait_time = min( + end_time_dep - dep_previous, + waiting_time - (self._walk_time_to_target - distance_end_trip), + ) walking_wait_time = max(0, walking_wait_time) if walking_wait_time > 0: - walk_block = ProfileBlock(start_time=dep_previous, - end_time=dep_previous + walking_wait_time, - distance_start=self._walk_time_to_target, - distance_end=self._walk_time_to_target - ) - assert (walk_block.start_time <= walk_block.end_time) - assert (walk_block.distance_end <= walk_block.distance_start) + walk_block = ProfileBlock( + start_time=dep_previous, + end_time=dep_previous + walking_wait_time, + distance_start=self._walk_time_to_target, + distance_end=self._walk_time_to_target, + ) + assert walk_block.start_time <= walk_block.end_time + assert walk_block.distance_end <= walk_block.distance_start self._profile_blocks.append(walk_block) trip_waiting_time = waiting_time - walking_wait_time if trip_waiting_time > 0: try: - trip_block = ProfileBlock(start_time=dep_previous + walking_wait_time, - end_time=dep_previous + walking_wait_time + trip_waiting_time, - distance_start=distance_end_trip + trip_waiting_time, - distance_end=distance_end_trip) - assert (trip_block.start_time <= trip_block.end_time) - assert (trip_block.distance_end <= trip_block.distance_start) + trip_block = ProfileBlock( + start_time=dep_previous + walking_wait_time, + end_time=dep_previous + walking_wait_time + trip_waiting_time, + distance_start=distance_end_trip + trip_waiting_time, + distance_end=distance_end_trip, + ) + assert trip_block.start_time <= trip_block.end_time + assert trip_block.distance_end <= trip_block.distance_start self._profile_blocks.append(trip_block) - except AssertionError as e: + except AssertionError: # the error was due to a very small waiting timesmall numbers - assert(trip_waiting_time < 10**-5) + assert trip_waiting_time < 10 ** -5 # TODO? Refactor to use the cutoff_distance feature in ProfileBlockAnalyzer? self.profile_block_analyzer = ProfileBlockAnalyzer(profile_blocks=self._profile_blocks) @@ -270,8 +291,13 @@ def plot_temporal_distance_pdf(self, use_minutes=True, color="green", ax=None): fig: matplotlib.Figure """ from matplotlib import pyplot as plt - plt.rc('text', usetex=True) - temporal_distance_split_points_ordered, densities, delta_peaks = self._temporal_distance_pdf() + + plt.rc("text", usetex=True) + ( + temporal_distance_split_points_ordered, + densities, + delta_peaks, + ) = self._temporal_distance_pdf() xs = [] for i, x in enumerate(temporal_distance_split_points_ordered): xs.append(x) @@ -309,30 +335,40 @@ def plot_temporal_distance_pdf(self, use_minutes=True, color="green", ax=None): for loc, mass in delta_peaks.items(): ax.plot([loc, loc], [0, peak_height], color="green", lw=5) - ax.text(loc + text_x_offset, peak_height * 0.99, "$P(\\mathrm{walk}) = %.2f$" % (mass), color="green") + ax.text( + loc + text_x_offset, + peak_height * 0.99, + "$P(\\mathrm{walk}) = %.2f$" % (mass), + color="green", + ) ax.set_xlim(now_min_x, now_max_x) tot_delta_peak_mass = sum(delta_peaks.values()) transit_text_x = (min_x + max_x) / 2 - transit_text_y = min(ys[ys > 0]) / 2. - ax.text(transit_text_x, - transit_text_y, - "$P(mathrm{PT}) = %.2f$" % (1 - tot_delta_peak_mass), - color="green", - va="center", - ha="center") + transit_text_y = min(ys[ys > 0]) / 2.0 + ax.text( + transit_text_x, + transit_text_y, + "$P(mathrm{PT}) = %.2f$" % (1 - tot_delta_peak_mass), + color="green", + va="center", + ha="center", + ) ax.set_xlabel(xlabel) ax.set_ylabel(ylabel) ax.set_ylim(bottom=0) return ax.figure - def plot_temporal_distance_pdf_horizontal(self, use_minutes=True, - color="green", - ax=None, - duration_divider=60.0, - legend_font_size=None, - legend_loc=None): + def plot_temporal_distance_pdf_horizontal( + self, + use_minutes=True, + color="green", + ax=None, + duration_divider=60.0, + legend_font_size=None, + legend_loc=None, + ): """ Plot the temporal distance probability density function. @@ -341,13 +377,18 @@ def plot_temporal_distance_pdf_horizontal(self, use_minutes=True, fig: matplotlib.Figure """ from matplotlib import pyplot as plt - plt.rc('text', usetex=True) + + plt.rc("text", usetex=True) if ax is None: fig = plt.figure() ax = fig.add_subplot(111) - temporal_distance_split_points_ordered, densities, delta_peaks = self._temporal_distance_pdf() + ( + temporal_distance_split_points_ordered, + densities, + delta_peaks, + ) = self._temporal_distance_pdf() xs = [] for i, x in enumerate(temporal_distance_split_points_ordered): xs.append(x) @@ -370,12 +411,12 @@ def plot_temporal_distance_pdf_horizontal(self, use_minutes=True, if delta_peaks: peak_height = max(ys) * 1.4 - max_x = max(xs) - min_x = min(xs) - now_max_x = max(xs) + 0.3 * (max_x - min_x) - now_min_x = min_x - 0.1 * (max_x - min_x) - - text_x_offset = 0.1 * (now_max_x - max_x) + # max_x = max(xs) + # min_x = min(xs) + # now_max_x = max(xs) + 0.3 * (max_x - min_x) + # now_min_x = min_x - 0.1 * (max_x - min_x) + # + # text_x_offset = 0.1 * (now_max_x - max_x) for loc, mass in delta_peaks.items(): text = "$P(\\mathrm{walk}) = " + ("%.2f$" % (mass)) @@ -384,7 +425,7 @@ def plot_temporal_distance_pdf_horizontal(self, use_minutes=True, ax.plot(ys, xs, "k-") if delta_peaks: tot_delta_peak_mass = sum(delta_peaks.values()) - fill_label = "$P(\\mathrm{PT}) = %.2f$" % (1-tot_delta_peak_mass) + fill_label = "$P(\\mathrm{PT}) = %.2f$" % (1 - tot_delta_peak_mass) else: fill_label = None ax.fill_betweenx(xs, ys, color=color, alpha=0.2, label=fill_label) @@ -397,37 +438,42 @@ def plot_temporal_distance_pdf_horizontal(self, use_minutes=True, legend_font_size = 12 if legend_loc is None: legend_loc = "best" - ax.legend(loc=legend_loc, prop={'size': legend_font_size}) - + ax.legend(loc=legend_loc, prop={"size": legend_font_size}) if True: line_tyles = ["-.", "--", "-"][::-1] - to_plot_funcs = [self.max_temporal_distance, self.mean_temporal_distance, self.min_temporal_distance] + to_plot_funcs = [ + self.max_temporal_distance, + self.mean_temporal_distance, + self.min_temporal_distance, + ] xmin, xmax = ax.get_xlim() for to_plot_func, ls in zip(to_plot_funcs, line_tyles): y = to_plot_func() / duration_divider - assert y < float('inf') + assert y < float("inf") # factor of 10 just to be safe that the lines cover the whole region. - ax.plot([xmin, xmax*10], [y, y], color="black", ls=ls, lw=1) + ax.plot([xmin, xmax * 10], [y, y], color="black", ls=ls, lw=1) return ax.figure - def plot_temporal_distance_profile(self, - timezone=None, - color="black", - alpha=0.15, - ax=None, - lw=2, - label="", - plot_tdist_stats=False, - plot_trip_stats=False, - format_string="%Y-%m-%d %H:%M:%S", - plot_journeys=False, - duration_divider=60.0, - fill_color="green", - journey_letters=None, - return_letters=False): + def plot_temporal_distance_profile( + self, + timezone=None, + color="black", + alpha=0.15, + ax=None, + lw=2, + label="", + plot_tdist_stats=False, + plot_trip_stats=False, + format_string="%Y-%m-%d %H:%M:%S", + plot_journeys=False, + duration_divider=60.0, + fill_color="green", + journey_letters=None, + return_letters=False, + ): """ Parameters ---------- @@ -457,28 +503,43 @@ def _ut_to_unloc_datetime(ut): _ut_to_unloc_datetime = lambda x: x ax.set_xlim( - _ut_to_unloc_datetime(self.start_time_dep), - _ut_to_unloc_datetime(self.end_time_dep) + _ut_to_unloc_datetime(self.start_time_dep), _ut_to_unloc_datetime(self.end_time_dep) ) if plot_tdist_stats: line_tyles = ["-.", "--", "-"][::-1] # to_plot_labels = ["maximum temporal distance", "mean temporal distance", "minimum temporal distance"] - to_plot_labels = ["$\\tau_\\mathrm{max} \\;$ = ", "$\\tau_\\mathrm{mean}$ = ", "$\\tau_\\mathrm{min} \\:\\:$ = "] - to_plot_funcs = [self.max_temporal_distance, self.mean_temporal_distance, self.min_temporal_distance] + to_plot_labels = [ + "$\\tau_\\mathrm{max} \\;$ = ", + "$\\tau_\\mathrm{mean}$ = ", + "$\\tau_\\mathrm{min} \\:\\:$ = ", + ] + to_plot_funcs = [ + self.max_temporal_distance, + self.mean_temporal_distance, + self.min_temporal_distance, + ] xmin, xmax = ax.get_xlim() for to_plot_label, to_plot_func, ls in zip(to_plot_labels, to_plot_funcs, line_tyles): y = to_plot_func() / duration_divider - assert y < float('inf'), to_plot_label + assert y < float("inf"), to_plot_label to_plot_label = to_plot_label + "%.1f min" % (y) ax.plot([xmin, xmax], [y, y], color="black", ls=ls, lw=1, label=to_plot_label) if plot_trip_stats: - assert (not plot_tdist_stats) + assert not plot_tdist_stats line_tyles = ["-", "-.", "--"] - to_plot_labels = ["min journey duration", "max journey duration", "mean journey duration"] - to_plot_funcs = [self.min_trip_duration, self.max_trip_duration, self.mean_trip_duration] + to_plot_labels = [ + "min journey duration", + "max journey duration", + "mean journey duration", + ] + to_plot_funcs = [ + self.min_trip_duration, + self.max_trip_duration, + self.mean_trip_duration, + ] xmin, xmax = ax.get_xlim() for to_plot_label, to_plot_func, ls in zip(to_plot_labels, to_plot_funcs, line_tyles): @@ -486,7 +547,9 @@ def _ut_to_unloc_datetime(ut): if not numpy.math.isnan(y): ax.plot([xmin, xmax], [y, y], color="red", ls=ls, lw=2) txt = to_plot_label + "\n = %.1f min" % y - ax.text(xmax + 0.01 * (xmax - xmin), y, txt, color="red", va="center", ha="left") + ax.text( + xmax + 0.01 * (xmax - xmin), y, txt, color="red", va="center", ha="left" + ) old_xmax = xmax xmax += (xmax - xmin) * 0.3 @@ -497,23 +560,27 @@ def _ut_to_unloc_datetime(ut): # plot the actual profile vertical_lines, slopes = self.profile_block_analyzer.get_vlines_and_slopes_for_plotting() for i, line in enumerate(slopes): - xs = [_ut_to_unloc_datetime(x) for x in line['x']] - if i is 0: - label = u"profile" + xs = [_ut_to_unloc_datetime(x) for x in line["x"]] + if i == 0: + label = "profile" else: label = None - ax.plot(xs, numpy.array(line['y']) / duration_divider, "-", color=color, lw=lw, label=label) + ax.plot( + xs, numpy.array(line["y"]) / duration_divider, "-", color=color, lw=lw, label=label + ) for line in vertical_lines: - xs = [_ut_to_unloc_datetime(x) for x in line['x']] - ax.plot(xs, numpy.array(line['y']) / duration_divider, ":", color=color) # , lw=lw) + xs = [_ut_to_unloc_datetime(x) for x in line["x"]] + ax.plot(xs, numpy.array(line["y"]) / duration_divider, ":", color=color) # , lw=lw) - assert (isinstance(ax, plt.Axes)) + assert isinstance(ax, plt.Axes) if plot_journeys: xs = [_ut_to_unloc_datetime(x) for x in self.trip_departure_times] ys = self.trip_durations - ax.plot(xs, numpy.array(ys) / duration_divider, "o", color="black", ms=8, label="journeys") + ax.plot( + xs, numpy.array(ys) / duration_divider, "o", color="black", ms=8, label="journeys" + ) if journey_letters is None: journey_letters = "ABCDEFGHIJKLMNOPQRSTUVWXYZ" @@ -528,26 +595,41 @@ def cycle_journey_letters(journey_letters): for element in saved: yield element + str(count) count += 1 + journey_letters_iterator = cycle_journey_letters(journey_letters) - time_letters = {int(time): letter for letter, time in zip(journey_letters_iterator, self.trip_departure_times)} + time_letters = { + int(time): letter + for letter, time in zip(journey_letters_iterator, self.trip_departure_times) + } for x, y, letter in zip(xs, ys, journey_letters_iterator): - walking = - self._walk_time_to_target / 30 if numpy.isfinite(self._walk_time_to_target) else 0 - ax.text(x + datetime.timedelta(seconds=(self.end_time_dep - self.start_time_dep) / 60), - (y + walking) / duration_divider, letter, va="top", ha="left") + walking = ( + -self._walk_time_to_target / 30 + if numpy.isfinite(self._walk_time_to_target) + else 0 + ) + ax.text( + x + datetime.timedelta(seconds=(self.end_time_dep - self.start_time_dep) / 60), + (y + walking) / duration_divider, + letter, + va="top", + ha="left", + ) fill_between_x = [] fill_between_y = [] for line in slopes: - xs = [_ut_to_unloc_datetime(x) for x in line['x']] + xs = [_ut_to_unloc_datetime(x) for x in line["x"]] fill_between_x.extend(xs) fill_between_y.extend(numpy.array(line["y"]) / duration_divider) - ax.fill_between(fill_between_x, y1=fill_between_y, color=fill_color, alpha=alpha, label=label) + ax.fill_between( + fill_between_x, y1=fill_between_y, color=fill_color, alpha=alpha, label=label + ) ax.set_ylim(bottom=0) ax.set_ylim(ax.get_ylim()[0], ax.get_ylim()[1] * 1.05) - if rcParams['text.usetex']: + if rcParams["text.usetex"]: ax.set_xlabel(r"Departure time $t_{\mathrm{dep}}$") else: ax.set_xlabel("Departure time") @@ -569,7 +651,10 @@ def _temporal_distance_pdf(self): len(density) == len(temporal_distance_split_points_ordered) -1 delta_peak_loc_to_probability_mass : dict """ - temporal_distance_split_points_ordered, norm_cdf = self.profile_block_analyzer._temporal_distance_cdf() + ( + temporal_distance_split_points_ordered, + norm_cdf, + ) = self.profile_block_analyzer._temporal_distance_cdf() delta_peak_loc_to_probability_mass = {} non_delta_peak_split_points = [temporal_distance_split_points_ordered[0]] @@ -584,18 +669,16 @@ def _temporal_distance_pdf(self): else: non_delta_peak_split_points.append(right) non_delta_peak_densities.append(prob_mass / float(width)) - assert (len(non_delta_peak_densities) == len(non_delta_peak_split_points) - 1) - return numpy.array(non_delta_peak_split_points), \ - numpy.array(non_delta_peak_densities), \ - delta_peak_loc_to_probability_mass + assert len(non_delta_peak_densities) == len(non_delta_peak_split_points) - 1 + return ( + numpy.array(non_delta_peak_split_points), + numpy.array(non_delta_peak_densities), + delta_peak_loc_to_probability_mass, + ) def get_temporal_distance_at(self, dep_time): return self.profile_block_analyzer.interpolate(dep_time) - def get_temporal_distance_at(self, dep_time): - return self.profile_block_analyzer - - @staticmethod def all_measures_and_names_as_lists(): NPA = NodeProfileAnalyzerTime @@ -608,7 +691,7 @@ def all_measures_and_names_as_lists(): NPA.mean_temporal_distance, NPA.median_temporal_distance, NPA.min_temporal_distance, - NPA.n_pareto_optimal_trips + NPA.n_pareto_optimal_trips, ] profile_observable_names = [ "max_trip_duration", @@ -619,7 +702,7 @@ def all_measures_and_names_as_lists(): "mean_temporal_distance", "median_temporal_distance", "min_temporal_distance", - "n_pareto_optimal_trips" + "n_pareto_optimal_trips", ] - assert (len(profile_summary_methods) == len(profile_observable_names)) + assert len(profile_summary_methods) == len(profile_observable_names) return profile_summary_methods, profile_observable_names diff --git a/gtfspy/routing/node_profile_analyzer_time_and_veh_legs.py b/gtfspy/routing/node_profile_analyzer_time_and_veh_legs.py index e474941..68f9307 100644 --- a/gtfspy/routing/node_profile_analyzer_time_and_veh_legs.py +++ b/gtfspy/routing/node_profile_analyzer_time_and_veh_legs.py @@ -1,34 +1,33 @@ from __future__ import print_function +import datetime import warnings from collections import defaultdict -import datetime import matplotlib -import numpy import matplotlib.pyplot as plt +import numpy import pytz -from matplotlib import lines - +from gtfspy.routing.label import LabelTimeWithBoardingsCount, compute_pareto_front, LabelTimeSimple from matplotlib import dates as md +from matplotlib import lines from matplotlib.colors import ListedColormap, LinearSegmentedColormap from gtfspy.routing.fastest_path_analyzer import FastestPathAnalyzer -from gtfspy.routing.node_profile_multiobjective import NodeProfileMultiObjective -from gtfspy.routing.label import LabelTimeWithBoardingsCount, compute_pareto_front, LabelTimeSimple from gtfspy.routing.node_profile_analyzer_time import NodeProfileAnalyzerTime +from gtfspy.routing.node_profile_multiobjective import NodeProfileMultiObjective from gtfspy.routing.node_profile_simple import NodeProfileSimple from gtfspy.routing.profile_block_analyzer import ProfileBlockAnalyzer def _check_for_no_labels_for_n_veh_counts(func): def wrapper(self): - assert (isinstance(self, NodeProfileAnalyzerTimeAndVehLegs)) + assert isinstance(self, NodeProfileAnalyzerTimeAndVehLegs) if len(self._labels_within_time_frame) == 0: if self._walk_to_target_duration is None: return 0 else: - return float('nan') + return float("nan") else: return func(self) @@ -40,7 +39,8 @@ def wrapper(self): if self._labels_within_time_frame: return func(self) else: - return float('inf') + return float("inf") + return wrapper @@ -50,21 +50,22 @@ def _truncate_colormap(cmap, minval=0.0, maxval=1.0, n=100): Code originall from http://stackoverflow.com/questions/18926031/how-to-extract-a-subset-of-a-colormap-as-a-new-colormap-in-matplotlib """ new_cmap = LinearSegmentedColormap.from_list( - 'trunc({n},{a:.2f},{b:.2f})'.format(n=cmap.name, a=minval, b=maxval), - cmap(numpy.linspace(minval, maxval, n)) + "trunc({n},{a:.2f},{b:.2f})".format(n=cmap.name, a=minval, b=maxval), + cmap(numpy.linspace(minval, maxval, n)), ) return new_cmap class NodeProfileAnalyzerTimeAndVehLegs: - @classmethod def from_profile(cls, node_profile, start_time_dep, end_time_dep): - assert (node_profile.label_class == LabelTimeWithBoardingsCount) - return NodeProfileAnalyzerTimeAndVehLegs(node_profile.get_final_optimal_labels(), - node_profile.get_walk_to_target_duration(), - start_time_dep, - end_time_dep) + assert node_profile.label_class == LabelTimeWithBoardingsCount + return NodeProfileAnalyzerTimeAndVehLegs( + node_profile.get_final_optimal_labels(), + node_profile.get_walk_to_target_duration(), + start_time_dep, + end_time_dep, + ) def __init__(self, labels, walk_to_target_duration, start_time_dep, end_time_dep): """ @@ -77,12 +78,18 @@ def __init__(self, labels, walk_to_target_duration, start_time_dep, end_time_dep self._node_profile_final_labels = labels self.start_time_dep = start_time_dep self.end_time_dep = end_time_dep - self.all_labels = [label for label in self._node_profile_final_labels if - (start_time_dep <= label.departure_time <= end_time_dep)] - after_label_candidates = [label for label in self._node_profile_final_labels if - (label.departure_time > self.end_time_dep)] + self.all_labels = [ + label + for label in self._node_profile_final_labels + if (start_time_dep <= label.departure_time <= end_time_dep) + ] + after_label_candidates = [ + label + for label in self._node_profile_final_labels + if (label.departure_time > self.end_time_dep) + ] after_label_candidates.sort(key=lambda el: (el.arrival_time_target, el.n_boardings)) - min_n_boardings_observed = float('inf') + min_n_boardings_observed = float("inf") after_labels = [] for candidate_after_label in after_label_candidates: if candidate_after_label.n_boardings < min_n_boardings_observed: @@ -90,36 +97,38 @@ def __init__(self, labels, walk_to_target_duration, start_time_dep, end_time_dep min_n_boardings_observed = candidate_after_label.n_boardings self.all_labels.extend(after_labels) - if len(after_labels) is 0: + if len(after_labels) == 0: self._labels_within_time_frame = self.all_labels else: - self._labels_within_time_frame = self.all_labels[:-len(after_labels)] + self._labels_within_time_frame = self.all_labels[: -len(after_labels)] self._walk_to_target_duration = walk_to_target_duration self._n_boardings_to_simple_time_analyzers = {} self._transfers_on_fastest_paths_analyzer = self._get_transfers_on_fastest_path_analyzer() def __get_fastest_path_analyzer(self): - return FastestPathAnalyzer(self.all_labels, - self.start_time_dep, - self.end_time_dep, - walk_duration=self._walk_to_target_duration, - label_props_to_consider=["n_boardings"]) + return FastestPathAnalyzer( + self.all_labels, + self.start_time_dep, + self.end_time_dep, + walk_duration=self._walk_to_target_duration, + label_props_to_consider=["n_boardings"], + ) def _get_transfers_on_fastest_path_analyzer(self): fp_analyzer = self.__get_fastest_path_analyzer() - if self._walk_to_target_duration < float('inf'): + if self._walk_to_target_duration < float("inf"): cutoff_value = 0 else: - cutoff_value = float('inf') - return fp_analyzer.get_prop_analyzer_flat("n_boardings", float('inf'), cutoff_value) + cutoff_value = float("inf") + return fp_analyzer.get_prop_analyzer_flat("n_boardings", float("inf"), cutoff_value) def min_n_boardings(self): - if self._walk_to_target_duration < float('inf'): + if self._walk_to_target_duration < float("inf"): return 0 else: - if len(self.all_labels) is 0: - return float('inf') + if len(self.all_labels) == 0: + return float("inf") else: return min([label.n_boardings for label in self.all_labels]) @@ -134,6 +143,7 @@ def max_finite_n_boardings_on_fastest_paths(self): def mean_n_boardings_on_shortest_paths(self): import math + mean = self._transfers_on_fastest_paths_analyzer.mean() if math.isnan(mean): mean = self._transfers_on_fastest_paths_analyzer.mean() @@ -156,20 +166,26 @@ def get_time_profile_analyzer(self, max_n_boardings=None): if max_n_boardings is None: max_n_boardings = self.max_trip_n_boardings() # compute only if not yet computed - if not max_n_boardings in self._n_boardings_to_simple_time_analyzers: + if max_n_boardings not in self._n_boardings_to_simple_time_analyzers: if max_n_boardings == 0: valids = [] else: - candidate_labels = [LabelTimeSimple(label.departure_time, label.arrival_time_target) - for label in self._node_profile_final_labels if - ((self.start_time_dep <= label.departure_time) - and label.n_boardings <= max_n_boardings)] + candidate_labels = [ + LabelTimeSimple(label.departure_time, label.arrival_time_target) + for label in self._node_profile_final_labels + if ( + (self.start_time_dep <= label.departure_time) + and label.n_boardings <= max_n_boardings + ) + ] valids = compute_pareto_front(candidate_labels) valids.sort(key=lambda label: -label.departure_time) profile = NodeProfileSimple(self._walk_to_target_duration) for valid in valids: profile.update_pareto_optimal_tuples(valid) - npat = NodeProfileAnalyzerTime.from_profile(profile, self.start_time_dep, self.end_time_dep) + npat = NodeProfileAnalyzerTime.from_profile( + profile, self.start_time_dep, self.end_time_dep + ) self._n_boardings_to_simple_time_analyzers[max_n_boardings] = npat return self._n_boardings_to_simple_time_analyzers[max_n_boardings] @@ -190,15 +206,18 @@ def mean_trip_n_boardings(self): @_check_for_no_labels_for_n_veh_counts def median_trip_n_boardings(self): return numpy.median([label.n_boardings for label in self._labels_within_time_frame]) - + @_check_for_no_labels_for_n_veh_counts def n_boardings_on_fastest_trip(self): - return min(self._labels_within_time_frame, key=lambda label: label.arrival_time_target - label.departure_time).n_boardings + return min( + self._labels_within_time_frame, + key=lambda label: label.arrival_time_target - label.departure_time, + ).n_boardings @_if_no_labels_return_inf def min_temporal_distance(self): result = self.get_time_profile_analyzer().min_temporal_distance() - assert (result >= 0), result + assert result >= 0, result return result @_if_no_labels_return_inf @@ -255,7 +274,9 @@ def median_temporal_distances(self, min_n_boardings=None, max_n_boardings=None): if max_n_boardings is None: max_n_boardings = 0 - median_temporal_distances = [float('inf') for _ in range(min_n_boardings, max_n_boardings + 1)] + median_temporal_distances = [ + float("inf") for _ in range(min_n_boardings, max_n_boardings + 1) + ] for n_boardings in range(min_n_boardings, max_n_boardings + 1): simple_analyzer = self.get_time_profile_analyzer(n_boardings) median_temporal_distances[n_boardings] = simple_analyzer.median_temporal_distance() @@ -264,19 +285,23 @@ def median_temporal_distances(self, min_n_boardings=None, max_n_boardings=None): @classmethod def _get_colors_for_boardings(cls, min_n_boardings, max_n_boardings): cmap = NodeProfileAnalyzerTimeAndVehLegs.get_colormap_for_boardings(max_n_boardings) - colors = [cmap(float(n_boardings) / max_n_boardings) for n_boardings in range(int(max_n_boardings) + 1)] - return colors[min_n_boardings:max_n_boardings + 1] + colors = [ + cmap(float(n_boardings) / max_n_boardings) + for n_boardings in range(int(max_n_boardings) + 1) + ] + return colors[min_n_boardings : max_n_boardings + 1] @classmethod def get_colormap_for_boardings(cls, max_n_boardings=None): n_default = 5 - if max_n_boardings in [float('nan'), None]: + if max_n_boardings in [float("nan"), None]: max_n_boardings = n_default from matplotlib import cm + cmap = cm.get_cmap("cubehelix_r") start = 0.1 end = 0.9 - if max_n_boardings is 0: + if max_n_boardings == 0: step = 0 else: divider = max(n_default, max_n_boardings) @@ -311,7 +336,9 @@ def _get_fill_and_line_colors(self, min_n, max_n): line_saturation_multiplier = 1 / max_saturation for n, color_tuple in nboardings_to_color.items(): - c = NodeProfileAnalyzerTimeAndVehLegs._multiply_color_saturation(color_tuple, line_saturation_multiplier) + c = NodeProfileAnalyzerTimeAndVehLegs._multiply_color_saturation( + color_tuple, line_saturation_multiplier + ) c = NodeProfileAnalyzerTimeAndVehLegs._multiply_color_brightness(c, 1) n_boardings_to_line_color[n] = c @@ -321,30 +348,32 @@ def _get_fill_and_line_colors(self, min_n, max_n): return n_boardings_to_fill_color, n_boardings_to_line_color @classmethod - def n_boardings_to_label(self, n): - if n is 0: + def n_boardings_to_label(cls, n): + if n == 0: return "walk" - elif n is 1: + elif n == 1: return "1 boarding" else: return str(n) + " boardings" - def plot_new_transfer_temporal_distance_profile(self, - timezone=None, - format_string="%Y-%m-%d %H:%M:%S", - duration_divider=60.0, - ax=None, - plot_journeys=False, - highlight_fastest_path=True, - default_lw=5, - ncol_legend=1, - fastest_path_lw=3, - legend_alpha=0.9, - journey_letters=None, - legend_font_size=None): + def plot_new_transfer_temporal_distance_profile( + self, + timezone=None, + format_string="%Y-%m-%d %H:%M:%S", + duration_divider=60.0, + ax=None, + plot_journeys=False, + highlight_fastest_path=True, + default_lw=5, + ncol_legend=1, + fastest_path_lw=3, + legend_alpha=0.9, + journey_letters=None, + legend_font_size=None, + ): max_n = self.max_finite_n_boardings_on_fastest_paths() min_n = self.min_n_boardings() - if self._walk_to_target_duration < float('inf'): + if self._walk_to_target_duration < float("inf"): min_n = 0 if max_n is None: return None @@ -352,7 +381,7 @@ def plot_new_transfer_temporal_distance_profile(self, fig = plt.figure() ax = fig.add_subplot(111) fig = ax.figure - assert (isinstance(ax, matplotlib.axes.Axes)) + assert isinstance(ax, matplotlib.axes.Axes) if timezone is None: warnings.warn("Warning: No timezone specified, defaulting to UTC") @@ -369,18 +398,18 @@ def _ut_to_unloc_datetime(ut): _ut_to_unloc_datetime = lambda x: x ax.set_xlim( - _ut_to_unloc_datetime(self.start_time_dep), - _ut_to_unloc_datetime(self.end_time_dep) + _ut_to_unloc_datetime(self.start_time_dep), _ut_to_unloc_datetime(self.end_time_dep) ) n_boardings_range = range(min_n, max_n + 1) n_boardings_to_lw = {n: default_lw for i, n in enumerate(n_boardings_range)} - n_boardings_to_fill_color, n_boardings_to_line_color = self._get_fill_and_line_colors(min_n, max_n) + n_boardings_to_fill_color, n_boardings_to_line_color = self._get_fill_and_line_colors( + min_n, max_n + ) # get all trips ordered by departure time - deptime_ordered_labels = sorted(list(self.all_labels), key=lambda x: x.departure_time) n_boardings_to_labels = defaultdict(list) @@ -388,9 +417,14 @@ def _ut_to_unloc_datetime(ut): n_boardings_to_labels[journey_label.n_boardings].append(journey_label) walk_duration = self._walk_to_target_duration / duration_divider - if walk_duration < float('inf'): + if walk_duration < float("inf"): xs = [_ut_to_unloc_datetime(x) for x in [self.start_time_dep, self.end_time_dep]] - ax.plot(xs, [walk_duration, walk_duration], lw=n_boardings_to_lw[0], color=n_boardings_to_line_color[0]) + ax.plot( + xs, + [walk_duration, walk_duration], + lw=n_boardings_to_lw[0], + color=n_boardings_to_line_color[0], + ) ax.fill_between(xs, 0, walk_duration, color=n_boardings_to_fill_color[0]) max_tdist = walk_duration else: @@ -403,7 +437,7 @@ def _ut_to_unloc_datetime(ut): continue prev_analyzer = self.get_time_profile_analyzer(n_boardings - 1) prev_profile_block_analyzer = prev_analyzer.profile_block_analyzer - assert (isinstance(prev_profile_block_analyzer, ProfileBlockAnalyzer)) + assert isinstance(prev_profile_block_analyzer, ProfileBlockAnalyzer) labels = n_boardings_to_labels[n_boardings] prev_was_larger_already = False @@ -413,28 +447,35 @@ def _ut_to_unloc_datetime(ut): continue prev_dep_time = self.start_time_dep - if i is not 0: + if i != 0: prev_dep_time = labels[i - 1].departure_time # this could perhaps be made a while loop of some sort # to not loop over things multiple times for block in prev_profile_block_analyzer._profile_blocks: if block.distance_end != block.distance_start: if block.distance_end <= journey_label.duration() + ( - journey_label.departure_time - block.end_time): + journey_label.departure_time - block.end_time + ): prev_dep_time = max(prev_dep_time, block.end_time) elif block.distance_end == block.distance_start: # look for the time when - waiting_time = (block.distance_end - journey_label.duration()) - prev_dep_time = max(prev_dep_time, journey_label.departure_time - waiting_time) + waiting_time = block.distance_end - journey_label.duration() + prev_dep_time = max( + prev_dep_time, journey_label.departure_time - waiting_time + ) # prev dep time is now known waiting_time = journey_label.departure_time - prev_dep_time lw = n_boardings_to_lw[n_boardings] - xs = [_ut_to_unloc_datetime(prev_dep_time), _ut_to_unloc_datetime(journey_label.departure_time)] - ys = numpy.array([journey_label.duration() + waiting_time, journey_label.duration()]) / duration_divider + xs = [ + _ut_to_unloc_datetime(prev_dep_time), + _ut_to_unloc_datetime(journey_label.departure_time), + ] + ys = ( + numpy.array([journey_label.duration() + waiting_time, journey_label.duration()]) + / duration_divider + ) max_tdist = max(ys[0], max_tdist) - ax.plot(xs, ys, - color=n_boardings_to_line_color[n_boardings], # "k", - lw=lw) + ax.plot(xs, ys, color=n_boardings_to_line_color[n_boardings], lw=lw) # "k", ax.fill_between(xs, 0, ys, color=n_boardings_to_fill_color[n_boardings]) if plot_journeys: journeys.append((xs[1], ys[1])) @@ -442,22 +483,41 @@ def _ut_to_unloc_datetime(ut): legend_patches = [] for n_boardings in n_boardings_range: text = self.n_boardings_to_label(n_boardings) - p = lines.Line2D([0, 1], [0, 1], lw=n_boardings_to_lw[n_boardings], - color=n_boardings_to_line_color[n_boardings], - label=text) + p = lines.Line2D( + [0, 1], + [0, 1], + lw=n_boardings_to_lw[n_boardings], + color=n_boardings_to_line_color[n_boardings], + label=text, + ) legend_patches.append(p) if highlight_fastest_path: fastest_path_time_analyzer = self.get_time_profile_analyzer() - vlines, slopes = fastest_path_time_analyzer.profile_block_analyzer.get_vlines_and_slopes_for_plotting() + ( + vlines, + slopes, + ) = ( + fastest_path_time_analyzer.profile_block_analyzer.get_vlines_and_slopes_for_plotting() + ) lw = fastest_path_lw ls = "--" for vline in vlines: - ax.plot([_ut_to_unloc_datetime(x) for x in vline['x']], numpy.array(vline['y']) / duration_divider, - ls=":", lw=lw / 2, color="k") + ax.plot( + [_ut_to_unloc_datetime(x) for x in vline["x"]], + numpy.array(vline["y"]) / duration_divider, + ls=":", + lw=lw / 2, + color="k", + ) for slope in slopes: - ax.plot([_ut_to_unloc_datetime(x) for x in slope['x']], numpy.array(slope['y']) / duration_divider, - ls=ls, color="k", lw=lw) + ax.plot( + [_ut_to_unloc_datetime(x) for x in slope["x"]], + numpy.array(slope["y"]) / duration_divider, + ls=ls, + color="k", + lw=lw, + ) p = lines.Line2D([0, 1], [0, 1], ls=ls, lw=lw, color="k", label="fastest path profile") legend_patches.append(p) @@ -468,22 +528,42 @@ def _ut_to_unloc_datetime(ut): for (x, y), letter in zip(journeys, journey_letters): if x < _ut_to_unloc_datetime(self.end_time_dep): ax.plot(x, y, "o", ms=8, color="k") - ax.text(x + datetime.timedelta(seconds=(self.end_time_dep - self.start_time_dep) / 40.), - y, letter, va="center", ha="left") + ax.text( + x + + datetime.timedelta( + seconds=(self.end_time_dep - self.start_time_dep) / 40.0 + ), + y, + letter, + va="center", + ha="left", + ) p = lines.Line2D([0, 0], [1, 1], ls="", marker="o", ms=8, color="k", label="journeys") legend_patches.append(p) if legend_font_size is None: legend_font_size = 12 - leg = ax.legend(handles=legend_patches, loc="best", ncol=ncol_legend, fancybox=True, prop={"size": legend_font_size}) + leg = ax.legend( + handles=legend_patches, + loc="best", + ncol=ncol_legend, + fancybox=True, + prop={"size": legend_font_size}, + ) leg.get_frame().set_alpha(legend_alpha) ax.set_ylim(0, 1.1 * max_tdist) ax.set_xlabel("Departure time $t_\\mathrm{dep}$") ax.set_ylabel("Temporal distance $\\tau$") return fig - def plot_temporal_distance_pdf_horizontal(self, use_minutes=True, ax=None, duration_divider=60.0, - legend_font_size=None, legend_loc=None): + def plot_temporal_distance_pdf_horizontal( + self, + use_minutes=True, + ax=None, + duration_divider=60.0, + legend_font_size=None, + legend_loc=None, + ): if ax is None: fig = plt.figure() ax = fig.add_subplot(111) @@ -494,7 +574,7 @@ def plot_temporal_distance_pdf_horizontal(self, use_minutes=True, ax=None, durat non_walk_blocks = [] tdist_split_points = set() for block in blocks: - if block.is_flat(): # test for walk + if block.is_flat(): # test for walk walking_is_fastest_time += block.width() else: non_walk_blocks.append(block) @@ -502,19 +582,20 @@ def plot_temporal_distance_pdf_horizontal(self, use_minutes=True, ax=None, durat tdist_split_points.add(block.distance_start) distance_split_points_ordered = numpy.array(sorted(list(tdist_split_points))) - temporal_distance_split_widths = distance_split_points_ordered[1:] - distance_split_points_ordered[:-1] + temporal_distance_split_widths = ( + distance_split_points_ordered[1:] - distance_split_points_ordered[:-1] + ) non_walk_blocks_total_time = sum((block.width() for block in non_walk_blocks)) - assert ( - numpy.isclose(walking_is_fastest_time + non_walk_blocks_total_time, - self.end_time_dep - self.start_time_dep) + assert numpy.isclose( + walking_is_fastest_time + non_walk_blocks_total_time, + self.end_time_dep - self.start_time_dep, ) min_n_boardings = int(self.min_n_boardings()) max_n_boardings = int(self.max_finite_n_boardings_on_fastest_paths()) fill_colors, line_colors = self._get_fill_and_line_colors(min_n_boardings, max_n_boardings) - temporal_distance_values_to_plot = [] for x in distance_split_points_ordered: temporal_distance_values_to_plot.append(x) @@ -530,8 +611,12 @@ def plot_temporal_distance_pdf_horizontal(self, use_minutes=True, ax=None, durat journey_counts = numpy.zeros(len(temporal_distance_split_widths)) for block_now in blocks_now: - first_index = numpy.searchsorted(distance_split_points_ordered, block_now.distance_end) - last_index = numpy.searchsorted(distance_split_points_ordered, block_now.distance_start) + first_index = numpy.searchsorted( + distance_split_points_ordered, block_now.distance_end + ) + last_index = numpy.searchsorted( + distance_split_points_ordered, block_now.distance_start + ) journey_counts[first_index:last_index] += 1 part_pdf = journey_counts / (self.end_time_dep - self.start_time_dep) @@ -545,27 +630,45 @@ def plot_temporal_distance_pdf_horizontal(self, use_minutes=True, ax=None, durat pdf_values_to_plot_by_n_boardings[n_boardings] = numpy.array(pdf_values_to_plot) if walking_is_fastest_time > 0: - text = "$P(\\mathrm{walk}) = " + ("%.2f$" % (walking_is_fastest_time/(self.end_time_dep - self.start_time_dep))) - ax.plot([0, 10], [self._walk_to_target_duration / duration_divider, self._walk_to_target_duration / duration_divider], color=line_colors[0], - lw=5, label=text, zorder=10) + text = "$P(\\mathrm{walk}) = " + ( + "%.2f$" % (walking_is_fastest_time / (self.end_time_dep - self.start_time_dep)) + ) + ax.plot( + [0, 10], + [ + self._walk_to_target_duration / duration_divider, + self._walk_to_target_duration / duration_divider, + ], + color=line_colors[0], + lw=5, + label=text, + zorder=10, + ) for n_boardings in range(max(1, min_n_boardings), max_n_boardings + 1): if n_boardings is max_n_boardings: prob = pdf_areas[n_boardings] else: - prob = pdf_areas[n_boardings] - pdf_areas[n_boardings+1] + prob = pdf_areas[n_boardings] - pdf_areas[n_boardings + 1] pdf_values_to_plot = pdf_values_to_plot_by_n_boardings[n_boardings] - if n_boardings is 0: + if n_boardings == 0: label = "$P(\\mathrm{walk})= %.2f $" % (prob) else: label = "$P(b=" + str(n_boardings) + ")= %.2f $" % (prob) # "$P_\\mathrm{" + self.n_boardings_to_label(n_boardings).replace(" ", "\\,") + "} = %.2f $" % (prob) - ax.fill_betweenx(temporal_distance_values_to_plot / duration_divider, - pdf_values_to_plot * duration_divider, - label=label, - color=fill_colors[n_boardings], zorder=n_boardings) - ax.plot(pdf_values_to_plot * duration_divider, temporal_distance_values_to_plot / duration_divider, - color=line_colors[n_boardings], zorder=n_boardings) + ax.fill_betweenx( + temporal_distance_values_to_plot / duration_divider, + pdf_values_to_plot * duration_divider, + label=label, + color=fill_colors[n_boardings], + zorder=n_boardings, + ) + ax.plot( + pdf_values_to_plot * duration_divider, + temporal_distance_values_to_plot / duration_divider, + color=line_colors[n_boardings], + zorder=n_boardings, + ) # get counts for each plot if legend_font_size is None: @@ -583,8 +686,7 @@ def plot_fastest_temporal_distance_profile(self, timezone=None, **kwargs): ax = fig.add_subplot(111) kwargs["ax"] = ax npat = self.get_time_profile_analyzer(max_n) - fig = npat.plot_temporal_distance_profile(timezone=timezone, - **kwargs) + fig = npat.plot_temporal_distance_profile(timezone=timezone, **kwargs) return fig def n_pareto_optimal_trips(self): @@ -620,7 +722,7 @@ def all_measures_and_names_as_lists(): NPA.max_n_boardings_on_shortest_paths, NPA.median_n_boardings_on_shortest_paths, NPA.mean_temporal_distance_with_min_n_boardings, - NPA.min_temporal_distance_with_min_n_boardings + NPA.min_temporal_distance_with_min_n_boardings, ] profile_observable_names = [ "max_trip_duration", @@ -642,13 +744,15 @@ def all_measures_and_names_as_lists(): "max_n_boardings_on_shortest_paths", "median_n_boardings_on_shortest_paths", "mean_temporal_distance_with_min_n_boardings", - "min_temporal_distance_with_min_n_boardings" + "min_temporal_distance_with_min_n_boardings", ] - assert (len(profile_summary_methods) == len(profile_observable_names)) + assert len(profile_summary_methods) == len(profile_observable_names) return profile_summary_methods, profile_observable_names def get_node_profile_measures_as_dict(self): profile_summary_methods, profile_observable_names = self.all_measures_and_names_as_lists() - profile_measure_dict = {key: value(self) for key, value in zip(profile_observable_names, profile_summary_methods)} + profile_measure_dict = { + key: value(self) + for key, value in zip(profile_observable_names, profile_summary_methods) + } return profile_measure_dict - diff --git a/gtfspy/routing/node_profile_c.py b/gtfspy/routing/node_profile_c.py index bc25f08..5e6a5f5 100644 --- a/gtfspy/routing/node_profile_c.py +++ b/gtfspy/routing/node_profile_c.py @@ -1,5 +1,3 @@ -import copy - from gtfspy.routing.label import LabelTime, compute_pareto_front_naive @@ -9,7 +7,7 @@ class NodeProfileC: that stores information on the Pareto-Optimal (departure_time_this_node, arrival_time_target_node) tuples. """ - def __init__(self, walk_to_target_duration=float('inf')): + def __init__(self, walk_to_target_duration=float("inf")): self._labels = [] # list[LabelTime] # always ordered by decreasing departure_time self._walk_to_target_duration = walk_to_target_duration @@ -26,22 +24,26 @@ def update_pareto_optimal_tuples(self, new_label): ------- updated: bool """ - assert (isinstance(new_label, LabelTime)) + assert isinstance(new_label, LabelTime) if self._labels: - assert (new_label.departure_time <= self._labels[-1].departure_time) + assert new_label.departure_time <= self._labels[-1].departure_time best_later_departing_arrival_time = self._labels[-1].arrival_time_target else: - best_later_departing_arrival_time = float('inf') + best_later_departing_arrival_time = float("inf") walk_to_target_arrival_time = new_label.departure_time + self._walk_to_target_duration - best_arrival_time = min(walk_to_target_arrival_time, - best_later_departing_arrival_time, - new_label.arrival_time_target) + best_arrival_time = min( + walk_to_target_arrival_time, + best_later_departing_arrival_time, + new_label.arrival_time_target, + ) # this should be changed to get constant time insertions / additions # (with time-indexing) - if (new_label.arrival_time_target < walk_to_target_arrival_time and - new_label.arrival_time_target < best_later_departing_arrival_time): + if ( + new_label.arrival_time_target < walk_to_target_arrival_time + and new_label.arrival_time_target < best_later_departing_arrival_time + ): self._labels.append(LabelTime(new_label.departure_time, best_arrival_time)) return True else: diff --git a/gtfspy/routing/node_profile_multiobjective.py b/gtfspy/routing/node_profile_multiobjective.py index 6279fa9..00fdd55 100644 --- a/gtfspy/routing/node_profile_multiobjective.py +++ b/gtfspy/routing/node_profile_multiobjective.py @@ -1,7 +1,14 @@ import numpy -from gtfspy.routing.label import LabelTimeWithBoardingsCount, merge_pareto_frontiers, compute_pareto_front, \ - LabelVehLegCount, LabelTime, LabelTimeBoardingsAndRoute, LabelTimeAndRoute +from gtfspy.routing.label import ( + LabelTimeWithBoardingsCount, + merge_pareto_frontiers, + compute_pareto_front, + LabelVehLegCount, + LabelTime, + LabelTimeBoardingsAndRoute, + LabelTimeAndRoute, +) from gtfspy.routing.connection import Connection @@ -11,13 +18,15 @@ class NodeProfileMultiObjective: each stop has a profile entry containing all Pareto-optimal entries. """ - def __init__(self, - dep_times=None, - walk_to_target_duration=float('inf'), - label_class=LabelTimeWithBoardingsCount, - transit_connection_dep_times=None, - closest_target=None, - node_id=None): + def __init__( + self, + dep_times=None, + walk_to_target_duration=float("inf"), + label_class=LabelTimeWithBoardingsCount, + transit_connection_dep_times=None, + closest_target=None, + node_id=None, + ): """ Parameters ---------- @@ -35,20 +44,25 @@ def __init__(self, n_dep_times = len(dep_times) assert n_dep_times == len(set(dep_times)), "There should be no duplicate departure times" self._departure_times = list(reversed(sorted(dep_times))) - self.dep_times_to_index = dict(zip(self._departure_times, range(len(self._departure_times)))) + self.dep_times_to_index = dict( + zip(self._departure_times, range(len(self._departure_times))) + ) self._label_bags = [[]] * len(self._departure_times) self._walk_to_target_duration = walk_to_target_duration - self._min_dep_time = float('inf') + self._min_dep_time = float("inf") self.label_class = label_class self.closest_target = closest_target - if self.label_class == LabelTimeBoardingsAndRoute and self._walk_to_target_duration < float('inf'): - assert (self.closest_target is not None) + if ( + self.label_class == LabelTimeBoardingsAndRoute + and self._walk_to_target_duration < float("inf") + ): + assert self.closest_target is not None if transit_connection_dep_times is not None: self._connection_dep_times = transit_connection_dep_times else: self._connection_dep_times = dep_times - assert (isinstance(self._connection_dep_times, (list, numpy.ndarray))) + assert isinstance(self._connection_dep_times, (list, numpy.ndarray)) self._closed = False self._finalized = False self._final_pareto_optimal_labels = None @@ -68,14 +82,19 @@ def _check_dep_time_is_valid(self, dep_time): ------- None """ - assert dep_time <= self._min_dep_time, "Labels should be entered in decreasing order of departure time." + assert ( + dep_time <= self._min_dep_time + ), "Labels should be entered in decreasing order of departure time." dep_time_index = self.dep_times_to_index[dep_time] - if self._min_dep_time < float('inf'): + if self._min_dep_time < float("inf"): min_dep_index = self.dep_times_to_index[self._min_dep_time] - assert min_dep_index == dep_time_index or (min_dep_index == dep_time_index - 1), \ - "dep times should be ordered sequentially" + assert min_dep_index == dep_time_index or ( + min_dep_index == dep_time_index - 1 + ), "dep times should be ordered sequentially" else: - assert dep_time_index is 0, "first dep_time index should be zero (ensuring that all connections are properly handled)" + assert ( + dep_time_index == 0 + ), "first dep_time index should be zero (ensuring that all connections are properly handled)" self._min_dep_time = dep_time def get_walk_to_target_duration(self): @@ -111,13 +130,15 @@ def update(self, new_labels, departure_time_backup=None): self._check_dep_time_is_valid(departure_time) for new_label in new_labels: - assert (new_label.departure_time == departure_time) + assert new_label.departure_time == departure_time dep_time_index = self.dep_times_to_index[departure_time] if dep_time_index > 0: # Departure time is modified in order to not pass on labels which are not Pareto-optimal when departure time is ignored. - mod_prev_labels = [label.get_copy_with_specified_departure_time(departure_time) for label - in self._label_bags[dep_time_index - 1]] + mod_prev_labels = [ + label.get_copy_with_specified_departure_time(departure_time) + for label in self._label_bags[dep_time_index - 1] + ] else: mod_prev_labels = list() mod_prev_labels += self._label_bags[dep_time_index] @@ -154,7 +175,7 @@ def evaluate(self, dep_time, first_leg_can_be_walk=True, connection_arrival_time """ walk_labels = list() # walk label towards target - if first_leg_can_be_walk and self._walk_to_target_duration != float('inf'): + if first_leg_can_be_walk and self._walk_to_target_duration != float("inf"): # add walk_label if connection_arrival_time is not None: walk_labels.append(self._get_label_to_target(connection_arrival_time)) @@ -163,7 +184,7 @@ def evaluate(self, dep_time, first_leg_can_be_walk=True, connection_arrival_time # if dep time is larger than the largest dep time -> only walk labels are possible if dep_time in self.dep_times_to_index: - assert (dep_time != float('inf')) + assert dep_time != float("inf") index = self.dep_times_to_index[dep_time] labels = self._label_bags[index] pareto_optimal_labels = merge_pareto_frontiers(labels, walk_labels) @@ -171,45 +192,57 @@ def evaluate(self, dep_time, first_leg_can_be_walk=True, connection_arrival_time pareto_optimal_labels = walk_labels if not first_leg_can_be_walk: - pareto_optimal_labels = [label for label in pareto_optimal_labels if not label.first_leg_is_walk] + pareto_optimal_labels = [ + label for label in pareto_optimal_labels if not label.first_leg_is_walk + ] return pareto_optimal_labels def _get_label_to_target(self, departure_time): - if departure_time != float('inf') and self._walk_to_target_duration != float('inf'): + if departure_time != float("inf") and self._walk_to_target_duration != float("inf"): if self._walk_to_target_duration == 0: first_leg_is_walk = False else: first_leg_is_walk = True - if self.label_class == LabelTimeBoardingsAndRoute or self.label_class == LabelTimeAndRoute: + if ( + self.label_class == LabelTimeBoardingsAndRoute + or self.label_class == LabelTimeAndRoute + ): if self._walk_to_target_duration > 0: - walk_connection = Connection(self.node_id, - self.closest_target, - departure_time, - departure_time + self._walk_to_target_duration, - Connection.WALK_TRIP_ID, - Connection.WALK_SEQ, - is_walk=True - ) + walk_connection = Connection( + self.node_id, + self.closest_target, + departure_time, + departure_time + self._walk_to_target_duration, + Connection.WALK_TRIP_ID, + Connection.WALK_SEQ, + is_walk=True, + ) else: walk_connection = None if self.label_class == LabelTimeAndRoute: - label = self.label_class(departure_time=float(departure_time), - arrival_time_target=float(departure_time + self._walk_to_target_duration), - movement_duration=self._walk_to_target_duration, - first_leg_is_walk=first_leg_is_walk, - connection=walk_connection) + label = self.label_class( + departure_time=float(departure_time), + arrival_time_target=float(departure_time + self._walk_to_target_duration), + movement_duration=self._walk_to_target_duration, + first_leg_is_walk=first_leg_is_walk, + connection=walk_connection, + ) else: - label = self.label_class(departure_time=float(departure_time), - arrival_time_target=float(departure_time + self._walk_to_target_duration), - movement_duration=self._walk_to_target_duration, - n_boardings=0, - first_leg_is_walk=first_leg_is_walk, - connection=walk_connection) + label = self.label_class( + departure_time=float(departure_time), + arrival_time_target=float(departure_time + self._walk_to_target_duration), + movement_duration=self._walk_to_target_duration, + n_boardings=0, + first_leg_is_walk=first_leg_is_walk, + connection=walk_connection, + ) else: - label = self.label_class(departure_time=float(departure_time), - arrival_time_target=float(departure_time + self._walk_to_target_duration), - n_boardings=0, - first_leg_is_walk=first_leg_is_walk) + label = self.label_class( + departure_time=float(departure_time), + arrival_time_target=float(departure_time + self._walk_to_target_duration), + n_boardings=0, + first_leg_is_walk=first_leg_is_walk, + ) return label else: @@ -231,7 +264,9 @@ def get_final_optimal_labels(self): assert self._finalized, "finalize() first!" return self._final_pareto_optimal_labels - def finalize(self, neighbor_label_bags=None, walk_durations=None, departure_arrival_stop_pairs=None): + def finalize( + self, neighbor_label_bags=None, walk_durations=None, departure_arrival_stop_pairs=None + ): """ Parameters ---------- @@ -244,14 +279,14 @@ def finalize(self, neighbor_label_bags=None, walk_durations=None, departure_arri ------- None """ - assert (not self._finalized) + assert not self._finalized if self._final_pareto_optimal_labels is None: self._compute_real_connection_labels() if neighbor_label_bags is not None: - assert (len(walk_durations) == len(neighbor_label_bags)) - self._compute_final_pareto_optimal_labels(neighbor_label_bags, - walk_durations, - departure_arrival_stop_pairs) + assert len(walk_durations) == len(neighbor_label_bags) + self._compute_final_pareto_optimal_labels( + neighbor_label_bags, walk_durations, departure_arrival_stop_pairs + ) else: self._final_pareto_optimal_labels = self._real_connection_labels self._finalized = True @@ -262,43 +297,62 @@ def _compute_real_connection_labels(self): # do not take those bags with first event is a pseudo-connection for dep_time in self._connection_dep_times: index = self.dep_times_to_index[dep_time] - pareto_optimal_labels.extend([label for label in self._label_bags[index] if not label.first_leg_is_walk]) - if self.label_class == LabelTimeWithBoardingsCount or self.label_class == LabelTime \ - or self.label_class == LabelTimeBoardingsAndRoute: - pareto_optimal_labels = [label for label in pareto_optimal_labels - if label.duration() < self._walk_to_target_duration] - - if self.label_class == LabelVehLegCount and self._walk_to_target_duration < float('inf'): + pareto_optimal_labels.extend( + [label for label in self._label_bags[index] if not label.first_leg_is_walk] + ) + if ( + self.label_class == LabelTimeWithBoardingsCount + or self.label_class == LabelTime + or self.label_class == LabelTimeBoardingsAndRoute + ): + pareto_optimal_labels = [ + label + for label in pareto_optimal_labels + if label.duration() < self._walk_to_target_duration + ] + + if self.label_class == LabelVehLegCount and self._walk_to_target_duration < float("inf"): pareto_optimal_labels.append(LabelVehLegCount(0)) - self._real_connection_labels = [label.get_copy() for label in compute_pareto_front(pareto_optimal_labels, - finalization=True)] - - def _compute_final_pareto_optimal_labels(self, neighbor_label_bags, walk_durations, departure_arrival_stops): + self._real_connection_labels = [ + label.get_copy() + for label in compute_pareto_front(pareto_optimal_labels, finalization=True) + ] + + def _compute_final_pareto_optimal_labels( + self, neighbor_label_bags, walk_durations, departure_arrival_stops + ): labels_from_neighbors = [] - for i, (label_bag, walk_duration)in enumerate(zip(neighbor_label_bags, walk_durations)): + for i, (label_bag, walk_duration) in enumerate(zip(neighbor_label_bags, walk_durations)): for label in label_bag: - if self.label_class == LabelTimeBoardingsAndRoute or self.label_class == LabelTimeAndRoute: + if ( + self.label_class == LabelTimeBoardingsAndRoute + or self.label_class == LabelTimeAndRoute + ): departure_arrival_tuple = departure_arrival_stops[i] departure_time = label.departure_time - walk_duration arrival_time = label.departure_time - connection = Connection(departure_arrival_tuple[0], - departure_arrival_tuple[1], - departure_time, - arrival_time, - Connection.WALK_TRIP_ID, - Connection.WALK_SEQ, - is_walk=True) - labels_from_neighbors.append(label.get_copy_with_walk_added(walk_duration, connection)) + connection = Connection( + departure_arrival_tuple[0], + departure_arrival_tuple[1], + departure_time, + arrival_time, + Connection.WALK_TRIP_ID, + Connection.WALK_SEQ, + is_walk=True, + ) + labels_from_neighbors.append( + label.get_copy_with_walk_added(walk_duration, connection) + ) else: labels_from_neighbors.append(label.get_copy_with_walk_added(walk_duration)) - pareto_front = compute_pareto_front(self._real_connection_labels + - labels_from_neighbors, - finalization=True) + pareto_front = compute_pareto_front( + self._real_connection_labels + labels_from_neighbors, finalization=True + ) if pareto_front and hasattr(pareto_front[0], "duration"): - self._final_pareto_optimal_labels = list(filter(lambda label: label.duration() < self._walk_to_target_duration, pareto_front)) + self._final_pareto_optimal_labels = list( + filter(lambda label: label.duration() < self._walk_to_target_duration, pareto_front) + ) else: self._final_pareto_optimal_labels = pareto_front - - diff --git a/gtfspy/routing/node_profile_simple.py b/gtfspy/routing/node_profile_simple.py index dff39b9..6545a42 100644 --- a/gtfspy/routing/node_profile_simple.py +++ b/gtfspy/routing/node_profile_simple.py @@ -7,7 +7,7 @@ class NodeProfileSimple: that stores information on the Pareto-Optimal (departure_time_this_node, arrival_time_target_node) tuples. """ - def __init__(self, walk_to_target_duration=float('inf'), label_class=LabelTimeSimple): + def __init__(self, walk_to_target_duration=float("inf"), label_class=LabelTimeSimple): self._labels = [] # list[LabelTimeSimple] # always ordered by decreasing departure_time self._walk_to_target_duration = walk_to_target_duration self._label_class = label_class @@ -29,11 +29,14 @@ def update_pareto_optimal_tuples(self, new_pareto_tuple): whether new_pareto_tuple was added to the set of pareto-optimal tuples """ if new_pareto_tuple.duration() > self._walk_to_target_duration: - direct_walk_label = self._label_class.direct_walk_label(new_pareto_tuple.departure_time, - self._walk_to_target_duration) + direct_walk_label = self._label_class.direct_walk_label( + new_pareto_tuple.departure_time, self._walk_to_target_duration + ) if not direct_walk_label.dominates(new_pareto_tuple): raise - direct_walk_label = self._label_class.direct_walk_label(new_pareto_tuple.departure_time, self._walk_to_target_duration) + direct_walk_label = self._label_class.direct_walk_label( + new_pareto_tuple.departure_time, self._walk_to_target_duration + ) if direct_walk_label.dominates(new_pareto_tuple): return False @@ -92,10 +95,12 @@ def evaluate_earliest_arrival_time_at_target(self, dep_time, transfer_margin): minimum = dep_time + self._walk_to_target_duration dep_time_plus_transfer_margin = dep_time + transfer_margin for label in self._labels: - if label.departure_time >= dep_time_plus_transfer_margin and label.arrival_time_target < minimum: + if ( + label.departure_time >= dep_time_plus_transfer_margin + and label.arrival_time_target < minimum + ): minimum = label.arrival_time_target return float(minimum) def get_final_optimal_labels(self): return [label.get_copy() for label in self._labels] - diff --git a/gtfspy/routing/profile_block.py b/gtfspy/routing/profile_block.py index 08105d7..f04ca2e 100644 --- a/gtfspy/routing/profile_block.py +++ b/gtfspy/routing/profile_block.py @@ -1,11 +1,8 @@ - - class ProfileBlock: - def __init__(self, start_time, end_time, distance_start, distance_end, **extra_properties): self.start_time = start_time self.end_time = end_time - assert(self.start_time < self.end_time) + assert self.start_time < self.end_time self.distance_start = distance_start self.distance_end = distance_end self.extra_properties = extra_properties @@ -19,10 +16,10 @@ def mean(self): def width(self): return self.end_time - self.start_time - def max(self): + def max(self): # noqa: A003 return max(self.distance_start, self.distance_end) - def min(self): + def min(self): # noqa: A003 return min(self.distance_start, self.distance_end) def is_flat(self): @@ -30,7 +27,7 @@ def is_flat(self): def interpolate(self, time): p = (time - self.start_time) / (self.end_time - self.start_time) - return (1-p) * self.distance_start + p * self.distance_end + return (1 - p) * self.distance_start + p * self.distance_end def __getitem__(self, extra_property_name): return self.extra_properties[extra_property_name] @@ -42,4 +39,4 @@ def __str__(self): parts.append(self.distance_start) parts.append(self.distance_end) parts.append(self.extra_properties) - return str(parts) \ No newline at end of file + return str(parts) diff --git a/gtfspy/routing/profile_block_analyzer.py b/gtfspy/routing/profile_block_analyzer.py index ab151f4..f4a082b 100644 --- a/gtfspy/routing/profile_block_analyzer.py +++ b/gtfspy/routing/profile_block_analyzer.py @@ -6,7 +6,6 @@ class ProfileBlockAnalyzer: - def __init__(self, profile_blocks, cutoff_distance=None, **kwargs): """ Parameters @@ -40,31 +39,42 @@ def _apply_cutoff(self, cutoff_distance): if block_max > cutoff_distance: print("applying cutoff") blocks = [] - if block.distance_start == block.distance_end or \ - (block.distance_start > cutoff_distance and block.distance_end > cutoff_distance): + if block.distance_start == block.distance_end or ( + block.distance_start > cutoff_distance and block.distance_end > cutoff_distance + ): blocks.append( - ProfileBlock(distance_end=cutoff_distance, - distance_start=cutoff_distance, - start_time=block.start_time, - end_time=block.end_time) + ProfileBlock( + distance_end=cutoff_distance, + distance_start=cutoff_distance, + start_time=block.start_time, + end_time=block.end_time, + ) ) else: - if (block.distance_end >= cutoff_distance): - assert (block.distance_end < cutoff_distance) - split_point_x = block.start_time + (block.distance_start - cutoff_distance) / ( - block.distance_start - block.distance_end) * block.width() + if block.distance_end >= cutoff_distance: + assert block.distance_end < cutoff_distance + split_point_x = ( + block.start_time + + (block.distance_start - cutoff_distance) + / (block.distance_start - block.distance_end) + * block.width() + ) if block.distance_start > block.distance_end: start_distance = cutoff_distance end_distance = block.distance_end else: start_distance = block.distance_start end_distance = cutoff_distance - first_block = ProfileBlock(block.start_time, split_point_x, start_distance, cutoff_distance) - second_block = ProfileBlock(split_point_x, block.end_time, cutoff_distance, end_distance) + first_block = ProfileBlock( + block.start_time, split_point_x, start_distance, cutoff_distance + ) + second_block = ProfileBlock( + split_point_x, block.end_time, cutoff_distance, end_distance + ) blocks.append(first_block) blocks.append(second_block) index = self._profile_blocks.index(block) - self._profile_blocks[index:index + 1] = blocks + self._profile_blocks[index : index + 1] = blocks def mean(self): total_width = self._profile_blocks[-1].end_time - self._profile_blocks[0].start_time @@ -74,32 +84,40 @@ def mean(self): def median(self): try: distance_split_points_ordered, norm_cdf = self._temporal_distance_cdf() - except RuntimeError as e: - return float('inf') + except RuntimeError: + return float("inf") if len(distance_split_points_ordered) == 0: - return float('inf') + return float("inf") left = numpy.searchsorted(norm_cdf, 0.5, side="left") right = numpy.searchsorted(norm_cdf, 0.5, side="right") if left == len(norm_cdf): - return float('inf') + return float("inf") elif left == right: left_cdf_val = norm_cdf[right - 1] right_cdf_val = norm_cdf[right] delta_y = right_cdf_val - left_cdf_val - assert (delta_y > 0) - delta_x = (distance_split_points_ordered[right] - distance_split_points_ordered[right - 1]) - median = (0.5 - left_cdf_val) / delta_y * delta_x + distance_split_points_ordered[right - 1] + assert delta_y > 0 + delta_x = ( + distance_split_points_ordered[right] - distance_split_points_ordered[right - 1] + ) + median = (0.5 - left_cdf_val) / delta_y * delta_x + distance_split_points_ordered[ + right - 1 + ] return median else: return distance_split_points_ordered[left] - def min(self): - return min([min(block.distance_end, block.distance_start) for block in self._profile_blocks]) + def min(self): # noqa: A003 + return min( + [min(block.distance_end, block.distance_start) for block in self._profile_blocks] + ) - def max(self): - return max([max(block.distance_end, block.distance_start) for block in self._profile_blocks]) + def max(self): # noqa: A003 + return max( + [max(block.distance_end, block.distance_start) for block in self._profile_blocks] + ) def largest_finite_distance(self): """ @@ -109,10 +127,16 @@ def largest_finite_distance(self): ------- max_temporal_distance : float """ - block_start_distances = [block.distance_start for block in self._profile_blocks if - block.distance_start < float('inf')] - block_end_distances = [block.distance_end for block in self._profile_blocks if - block.distance_end < float('inf')] + block_start_distances = [ + block.distance_start + for block in self._profile_blocks + if block.distance_start < float("inf") + ] + block_end_distances = [ + block.distance_end + for block in self._profile_blocks + if block.distance_end < float("inf") + ] distances = block_start_distances + block_end_distances if len(distances) > 0: return max(distances) @@ -120,14 +144,16 @@ def largest_finite_distance(self): return None def summary_as_dict(self): - summary = {"max": self.max(), - "min": self.min(), - "mean": self.mean(), - "median": self.median()} + summary = { + "max": self.max(), + "min": self.min(), + "mean": self.mean(), + "median": self.median(), + } if hasattr(self, "from_stop_I"): - summary['from_stop_I'] = self.from_stop_I + summary["from_stop_I"] = self.from_stop_I if hasattr(self, "to_stop_I"): - summary['to_stop_I'] = self.to_stop_I + summary["to_stop_I"] = self.to_stop_I return summary def _temporal_distance_cdf(self): @@ -143,12 +169,14 @@ def _temporal_distance_cdf(self): """ distance_split_points = set() for block in self._profile_blocks: - if block.distance_start != float('inf'): + if block.distance_start != float("inf"): distance_split_points.add(block.distance_end) distance_split_points.add(block.distance_start) distance_split_points_ordered = numpy.array(sorted(list(distance_split_points))) - temporal_distance_split_widths = distance_split_points_ordered[1:] - distance_split_points_ordered[:-1] + temporal_distance_split_widths = ( + distance_split_points_ordered[1:] - distance_split_points_ordered[:-1] + ) trip_counts = numpy.zeros(len(temporal_distance_split_widths)) delta_peaks = defaultdict(lambda: 0) @@ -160,26 +188,36 @@ def _temporal_distance_cdf(self): end_index = numpy.searchsorted(distance_split_points_ordered, block.distance_start) trip_counts[start_index:end_index] += 1 - unnormalized_cdf = numpy.array([0] + list(numpy.cumsum(temporal_distance_split_widths * trip_counts))) - if not (numpy.isclose( + unnormalized_cdf = numpy.array( + [0] + list(numpy.cumsum(temporal_distance_split_widths * trip_counts)) + ) + if not ( + numpy.isclose( [unnormalized_cdf[-1]], - [self._end_time - self._start_time - sum(delta_peaks.values())], atol=1E-4 - ).all()): - print(unnormalized_cdf[-1], self._end_time - self._start_time - sum(delta_peaks.values())) + [self._end_time - self._start_time - sum(delta_peaks.values())], + atol=1e-4, + ).all() + ): + print( + unnormalized_cdf[-1], self._end_time - self._start_time - sum(delta_peaks.values()) + ) raise RuntimeError("Something went wrong with cdf computation!") if len(delta_peaks) > 0: for peak in delta_peaks.keys(): - if peak == float('inf'): + if peak == float("inf"): continue index = numpy.nonzero(distance_split_points_ordered == peak)[0][0] unnormalized_cdf = numpy.insert(unnormalized_cdf, index, unnormalized_cdf[index]) - distance_split_points_ordered = numpy.insert(distance_split_points_ordered, index, - distance_split_points_ordered[index]) + distance_split_points_ordered = numpy.insert( + distance_split_points_ordered, index, distance_split_points_ordered[index] + ) # walk_waiting_time_fraction = walk_total_time / (self.end_time_dep - self.start_time_dep) - unnormalized_cdf[(index + 1):] = unnormalized_cdf[(index + 1):] + delta_peaks[peak] + unnormalized_cdf[(index + 1) :] = ( + unnormalized_cdf[(index + 1) :] + delta_peaks[peak] + ) - norm_cdf = unnormalized_cdf / (unnormalized_cdf[-1] + delta_peaks[float('inf')]) + norm_cdf = unnormalized_cdf / (unnormalized_cdf[-1] + delta_peaks[float("inf")]) return distance_split_points_ordered, norm_cdf def _temporal_distance_pdf(self): @@ -208,9 +246,12 @@ def _temporal_distance_pdf(self): else: non_delta_peak_split_points.append(right) non_delta_peak_densities.append(prob_mass / float(width)) - assert (len(non_delta_peak_densities) == len(non_delta_peak_split_points) - 1) - return numpy.array(non_delta_peak_split_points), \ - numpy.array(non_delta_peak_densities), delta_peak_loc_to_probability_mass + assert len(non_delta_peak_densities) == len(non_delta_peak_split_points) - 1 + return ( + numpy.array(non_delta_peak_split_points), + numpy.array(non_delta_peak_densities), + delta_peak_loc_to_probability_mass, + ) def get_vlines_and_slopes_for_plotting(self): vertical_lines = [] @@ -220,24 +261,29 @@ def get_vlines_and_slopes_for_plotting(self): distance_end_minutes = block.distance_end distance_start_minutes = block.distance_start - slope = dict(x=[block.start_time, block.end_time], - y=[distance_start_minutes, distance_end_minutes]) + slope = dict( + x=[block.start_time, block.end_time], + y=[distance_start_minutes, distance_end_minutes], + ) slopes.append(slope) if i != 0: # no vertical line for the first observation previous_duration_minutes = self._profile_blocks[i - 1].distance_end - vertical_lines.append(dict(x=[block.start_time, block.start_time], - y=[previous_duration_minutes, distance_start_minutes])) + vertical_lines.append( + dict( + x=[block.start_time, block.start_time], + y=[previous_duration_minutes, distance_start_minutes], + ) + ) return vertical_lines, slopes def get_blocks(self): return self._profile_blocks def interpolate(self, time): - assert(self._start_time <= time <= self._end_time) + assert self._start_time <= time <= self._end_time for profile_block in self._profile_blocks: # find the first block whose end time is larger than or equal to that of the queried time if profile_block.end_time >= time: return profile_block.interpolate(time) - diff --git a/gtfspy/routing/pseudo_connection_scan_profiler.py b/gtfspy/routing/pseudo_connection_scan_profiler.py index 2cb029e..370f29d 100644 --- a/gtfspy/routing/pseudo_connection_scan_profiler.py +++ b/gtfspy/routing/pseudo_connection_scan_profiler.py @@ -29,13 +29,13 @@ from collections import defaultdict import networkx +from gtfspy.routing.label import LabelTime +from gtfspy.routing.abstract_routing_algorithm import AbstractRoutingAlgorithm from gtfspy.routing.connection import Connection -from gtfspy.routing.label import LabelTime +from gtfspy.routing.node_profile_c import NodeProfileC from gtfspy.routing.node_profile_simple import NodeProfileSimple -from gtfspy.routing.abstract_routing_algorithm import AbstractRoutingAlgorithm from gtfspy.routing.pseudo_connections import compute_pseudo_connections -from gtfspy.routing.node_profile_c import NodeProfileC class PseudoConnectionScanProfiler(AbstractRoutingAlgorithm): @@ -45,15 +45,17 @@ class PseudoConnectionScanProfiler(AbstractRoutingAlgorithm): http://i11www.iti.uni-karlsruhe.de/extra/publications/dpsw-isftr-13.pdf """ - def __init__(self, - transit_events, - target_stop, - start_time=None, - end_time=None, - transfer_margin=0, - walk_network=None, - walk_speed=1.5, - verbose=False): + def __init__( + self, + transit_events, + target_stop, + start_time=None, + end_time=None, + transfer_margin=0, + walk_network=None, + walk_speed=1.5, + verbose=False, + ): """ Parameters ---------- @@ -105,9 +107,14 @@ def __init__(self, edge_data = walk_network.get_edge_data(target_neighbor, target_stop) walk_duration = edge_data["d_walk"] / self._walk_speed self._stop_profiles[target_neighbor] = NodeProfileC(walk_duration) - pseudo_connection_set = compute_pseudo_connections(transit_events, self._start_time, self._end_time, - self._transfer_margin, self._walk_network, - self._walk_speed) + pseudo_connection_set = compute_pseudo_connections( + transit_events, + self._start_time, + self._end_time, + self._transfer_margin, + self._walk_network, + self._walk_speed, + ) self._pseudo_connections = list(pseudo_connection_set) self._all_connections = self._pseudo_connections + self._transit_connections self._all_connections.sort(key=lambda connection: -connection.departure_time) @@ -115,14 +122,14 @@ def __init__(self, def _run(self): # if source node in s1: previous_departure_time = float("inf") - connections = self._all_connections # list[Connection] + connections = self._all_connections # list[Connection] n_connections_tot = len(connections) for i, connection in enumerate(connections): # basic checking + printing progress: if self._verbose and i % 1000 == 0: print(i, "/", n_connections_tot) - assert (isinstance(connection, Connection)) - assert (connection.departure_time <= previous_departure_time) + assert isinstance(connection, Connection) + assert connection.departure_time <= previous_departure_time previous_departure_time = connection.departure_time # get all different "accessible" / arrival times (Pareto-optimal sets) @@ -139,22 +146,29 @@ def _run(self): earliest_arrival_time_via_same_trip = self.__trip_min_arrival_time[connection.trip_id] # then, take the minimum (or the Pareto-optimal set) of these three alternatives. - min_arrival_time = min(earliest_arrival_time_via_same_trip, - earliest_arrival_time_via_transfer) + min_arrival_time = min( + earliest_arrival_time_via_same_trip, earliest_arrival_time_via_transfer + ) # If there are no 'labels' to progress, nothing needs to be done. if min_arrival_time == float("inf"): continue # Update information for the trip - if (not connection.is_walk) and (earliest_arrival_time_via_same_trip > min_arrival_time): - self.__trip_min_arrival_time[connection.trip_id] = earliest_arrival_time_via_transfer + if (not connection.is_walk) and ( + earliest_arrival_time_via_same_trip > min_arrival_time + ): + self.__trip_min_arrival_time[ + connection.trip_id + ] = earliest_arrival_time_via_transfer # Compute the new "best" pareto_tuple possible (later: merge the sets of pareto-optimal labels) pareto_tuple = LabelTime(connection.departure_time, min_arrival_time) # update departure stop profile (later: with the sets of pareto-optimal labels) - self._stop_profiles[connection.departure_stop].update_pareto_optimal_tuples(pareto_tuple) + self._stop_profiles[connection.departure_stop].update_pareto_optimal_tuples( + pareto_tuple + ) @property def stop_profiles(self): diff --git a/gtfspy/routing/pseudo_connections.py b/gtfspy/routing/pseudo_connections.py index 2b2ea65..23ef186 100644 --- a/gtfspy/routing/pseudo_connections.py +++ b/gtfspy/routing/pseudo_connections.py @@ -1,9 +1,9 @@ from gtfspy.routing.connection import Connection -def compute_pseudo_connections(transit_connections, start_time_dep, - end_time_dep, transfer_margin, - walk_network, walk_speed): +def compute_pseudo_connections( + transit_connections, start_time_dep, end_time_dep, transfer_margin, walk_network, walk_speed +): """ Given a set of transit events and the static walk network, "transform" the static walking network into a set of "pseudo-connections". @@ -36,18 +36,17 @@ def compute_pseudo_connections(transit_connections, start_time_dep, walk_arr_stop = c.departure_stop walk_arr_time = c.departure_time - transfer_margin for _, walk_dep_stop, data in walk_network.edges(nbunch=[walk_arr_stop], data=True): - walk_dep_time = walk_arr_time - data['d_walk'] / float(walk_speed) + walk_dep_time = walk_arr_time - data["d_walk"] / float(walk_speed) if walk_dep_time > end_time_dep or walk_dep_time < start_time_dep: continue - pseudo_connection = Connection(walk_dep_stop, - walk_arr_stop, - walk_dep_time, - walk_arr_time, - Connection.WALK_TRIP_ID, - Connection.WALK_SEQ, - is_walk=True) + pseudo_connection = Connection( + walk_dep_stop, + walk_arr_stop, + walk_dep_time, + walk_arr_time, + Connection.WALK_TRIP_ID, + Connection.WALK_SEQ, + is_walk=True, + ) pseudo_connection_set.add(pseudo_connection) return pseudo_connection_set - - - diff --git a/gtfspy/routing/test/test_connection_scan.py b/gtfspy/routing/test/test_connection_scan.py index 59444a2..1924c19 100644 --- a/gtfspy/routing/test/test_connection_scan.py +++ b/gtfspy/routing/test/test_connection_scan.py @@ -7,14 +7,13 @@ class ConnectionScanTest(unittest.TestCase): - def setUp(self): event_list_raw_data = [ (1, 2, 0, 10, "trip_1", 1), (1, 3, 1, 10, "trip_2", 1), (2, 3, 10, 11, "trip_1", 2), (3, 4, 11, 13, "trip_1", 3), - (3, 6, 12, 14, "trip_3", 1) + (3, 6, 12, 14, "trip_3", 1), ] self.transit_connections = map(lambda el: Connection(*el), event_list_raw_data) self.walk_network = networkx.Graph() @@ -32,9 +31,15 @@ def test_basics(self): 2. Stop labels are respected. 3. Walk network is used properly. """ - csa = ConnectionScan(self.transit_connections, self.source_stop, - self.start_time, self.end_time, - self.transfer_margin, self.walk_network, self.walk_speed) + csa = ConnectionScan( + self.transit_connections, + self.source_stop, + self.start_time, + self.end_time, + self.transfer_margin, + self.walk_network, + self.walk_speed, + ) csa.run() arrival_times = csa.get_arrival_times() self.assertEqual(arrival_times[1], self.start_time) @@ -43,29 +48,41 @@ def test_basics(self): self.assertEqual(arrival_times[4], 13) self.assertEqual(arrival_times[5], 13 + 100) self.assertEqual(arrival_times[6], 14) - self.assertEqual(arrival_times[7], float('inf')) + self.assertEqual(arrival_times[7], float("inf")) self.assertGreater(csa.get_run_time(), 0) def test_change_starttime(self): start_time = 1 - self.transfer_margin - csa = ConnectionScan(self.transit_connections, self.source_stop, - start_time, self.end_time, - self.transfer_margin, self.walk_network, self.walk_speed) + csa = ConnectionScan( + self.transit_connections, + self.source_stop, + start_time, + self.end_time, + self.transfer_margin, + self.walk_network, + self.walk_speed, + ) csa.run() arrival_times = csa.get_arrival_times() self.assertEqual(arrival_times[1], start_time) - self.assertEqual(arrival_times[2], float('inf')) + self.assertEqual(arrival_times[2], float("inf")) self.assertEqual(arrival_times[3], 10) - self.assertEqual(arrival_times[4], float('inf')) - self.assertEqual(arrival_times[5], float('inf')) + self.assertEqual(arrival_times[4], float("inf")) + self.assertEqual(arrival_times[5], float("inf")) self.assertEqual(arrival_times[6], 14) - self.assertEqual(arrival_times[7], float('inf')) + self.assertEqual(arrival_times[7], float("inf")) def test_change_endtime(self): end_time = 11 - csa = ConnectionScan(self.transit_connections, self.source_stop, - self.start_time, end_time, - self.transfer_margin, self.walk_network, self.walk_speed) + csa = ConnectionScan( + self.transit_connections, + self.source_stop, + self.start_time, + end_time, + self.transfer_margin, + self.walk_network, + self.walk_speed, + ) csa.run() arrival_times = csa.get_arrival_times() self.assertEqual(arrival_times[1], self.start_time) @@ -73,14 +90,12 @@ def test_change_endtime(self): self.assertEqual(arrival_times[3], 10) self.assertEqual(arrival_times[4], 13) self.assertEqual(arrival_times[5], 13 + 100) - self.assertEqual(arrival_times[6], float('inf')) - self.assertEqual(arrival_times[7], float('inf')) + self.assertEqual(arrival_times[6], float("inf")) + self.assertEqual(arrival_times[7], float("inf")) def test_starts_with_walk(self): end_time = 11 - event_list_raw_data = [ - (1, 2, 0, 10, "trip_1", 1) - ] + event_list_raw_data = [(1, 2, 0, 10, "trip_1", 1)] transit_connections = map(lambda el: Connection(*el), event_list_raw_data) walk_network = networkx.Graph() walk_network.add_edge(1, 2, {"d_walk": 10}) @@ -88,19 +103,16 @@ def test_starts_with_walk(self): source_stop = 1 start_time = 0 transfer_margin = 0 - csa = ConnectionScan(transit_connections, source_stop, - start_time, end_time, - transfer_margin, walk_network, walk_speed) + csa = ConnectionScan( + transit_connections, + source_stop, + start_time, + end_time, + transfer_margin, + walk_network, + walk_speed, + ) csa.run() arrival_times = csa.get_arrival_times() self.assertEqual(arrival_times[1], start_time) self.assertEqual(arrival_times[2], 1) - - - - - - - - - diff --git a/gtfspy/routing/test/test_connection_scan_profile.py b/gtfspy/routing/test/test_connection_scan_profile.py index 4cf837b..fd5d24f 100644 --- a/gtfspy/routing/test/test_connection_scan_profile.py +++ b/gtfspy/routing/test/test_connection_scan_profile.py @@ -9,8 +9,8 @@ # noinspection PyAttributeOutsideInit -class ConnectionScanProfileTest(unittest.TestCase): +class ConnectionScanProfileTest(unittest.TestCase): def setUp(self): event_list_raw_data = [ (2, 4, 40, 50, "trip_6", 1), @@ -18,7 +18,7 @@ def setUp(self): (3, 4, 32, 35, "trip_4", 1), (2, 3, 25, 30, "trip_3", 1), (1, 2, 10, 20, "trip_2", 1), - (0, 1, 0, 10, "trip_1", 1) + (0, 1, 0, 10, "trip_1", 1), ] self.transit_connections = list(map(lambda el: Connection(*el), event_list_raw_data)) self.walk_network = networkx.Graph() @@ -32,9 +32,15 @@ def setUp(self): self.end_time = 50 def test_basics(self): - csa_profile = ConnectionScanProfiler(self.transit_connections, self.target_stop, - self.start_time, self.end_time, self.transfer_margin, - self.walk_network, self.walk_speed) + csa_profile = ConnectionScanProfiler( + self.transit_connections, + self.target_stop, + self.start_time, + self.end_time, + self.transfer_margin, + self.walk_network, + self.walk_speed, + ) csa_profile.run() stop_3_pareto_tuples = csa_profile.stop_profiles[3].get_final_optimal_labels() @@ -54,10 +60,7 @@ def test_basics(self): pareto_tuples.append(LabelTimeSimple(departure_time=20, arrival_time_target=50)) pareto_tuples.append(LabelTimeSimple(departure_time=32, arrival_time_target=55)) - self._assert_pareto_labels_equal( - pareto_tuples, - source_stop_pareto_optimal_tuples - ) + self._assert_pareto_labels_equal(pareto_tuples, source_stop_pareto_optimal_tuples) def test_wrong_event_data_ordering(self): event_list_wrong_ordering = [ @@ -66,11 +69,17 @@ def test_wrong_event_data_ordering(self): (2, 3, 25, 30, "trip_3", 1), (3, 4, 32, 35, "trip_4", 1), (1, 3, 32, 40, "trip_5", 1), - (2, 4, 40, 50, "trip_5", 1) + (2, 4, 40, 50, "trip_5", 1), ] - csa_profile = ConnectionScanProfiler(event_list_wrong_ordering, self.target_stop, - self.start_time, self.end_time, self.transfer_margin, - self.walk_network, self.walk_speed) + csa_profile = ConnectionScanProfiler( + event_list_wrong_ordering, + self.target_stop, + self.start_time, + self.end_time, + self.transfer_margin, + self.walk_network, + self.walk_speed, + ) self.assertRaises(AssertionError, csa_profile.run) def test_simple(self): @@ -91,22 +100,23 @@ def test_simple(self): pareto_tuples = list() pareto_tuples.append(LabelTimeSimple(departure_time=20, arrival_time_target=50)) - csa_profile = ConnectionScanProfiler(transit_connections, target_stop, - start_time, end_time, transfer_margin, - walk_network, walk_speed) + csa_profile = ConnectionScanProfiler( + transit_connections, + target_stop, + start_time, + end_time, + transfer_margin, + walk_network, + walk_speed, + ) csa_profile.run() source_stop_profile = csa_profile.stop_profiles[source_stop] source_stop_pareto_tuples = source_stop_profile.get_final_optimal_labels() - self._assert_pareto_labels_equal( - pareto_tuples, - source_stop_pareto_tuples - ) + self._assert_pareto_labels_equal(pareto_tuples, source_stop_pareto_tuples) def test_last_leg_is_walk(self): - event_list_raw_data = [ - (0, 1, 0, 10, "trip_1", 1) - ] + event_list_raw_data = [(0, 1, 0, 10, "trip_1", 1)] transit_connections = list(map(lambda el: Connection(*el), event_list_raw_data)) walk_network = networkx.Graph() walk_network.add_edge(1, 2, {"d_walk": 20}) @@ -120,17 +130,21 @@ def test_last_leg_is_walk(self): pareto_tuples = list() pareto_tuples.append(LabelTimeSimple(departure_time=0, arrival_time_target=30)) - csa_profile = ConnectionScanProfiler(transit_connections, target_stop, - start_time, end_time, transfer_margin, - walk_network, walk_speed) + csa_profile = ConnectionScanProfiler( + transit_connections, + target_stop, + start_time, + end_time, + transfer_margin, + walk_network, + walk_speed, + ) csa_profile.run() found_tuples = csa_profile.stop_profiles[source_stop].get_final_optimal_labels() self._assert_pareto_labels_equal(found_tuples, pareto_tuples) def test_walk_is_faster_than_by_trip(self): - event_list_raw_data = [ - (0, 1, 0, 10, "trip_1", 1) - ] + event_list_raw_data = [(0, 1, 0, 10, "trip_1", 1)] transit_connections = list(map(lambda el: Connection(*el), event_list_raw_data)) walk_speed = 2 source_stop = 0 @@ -141,9 +155,15 @@ def test_walk_is_faster_than_by_trip(self): walk_network = networkx.Graph() walk_network.add_edge(0, 1, {"d_walk": 1}) - csa_profile = ConnectionScanProfiler(transit_connections, target_stop, - start_time, end_time, transfer_margin, - walk_network, walk_speed) + csa_profile = ConnectionScanProfiler( + transit_connections, + target_stop, + start_time, + end_time, + transfer_margin, + walk_network, + walk_speed, + ) csa_profile.run() source_profile = csa_profile.stop_profiles[source_stop] self.assertEqual(source_profile.evaluate_earliest_arrival_time_at_target(0, 0), 0.5) @@ -151,9 +171,7 @@ def test_walk_is_faster_than_by_trip(self): self.assertEqual(len(found_tuples), 0) def test_target_node_not_in_walk_network(self): - event_list_raw_data = [ - (0, 1, 0, 10, "trip_1", 1) - ] + event_list_raw_data = [(0, 1, 0, 10, "trip_1", 1)] transit_connections = list(map(lambda el: Connection(*el), event_list_raw_data)) walk_speed = 2 source_stop = 0 @@ -163,16 +181,21 @@ def test_target_node_not_in_walk_network(self): end_time = 50 walk_network = networkx.Graph() - csa_profile = ConnectionScanProfiler(transit_connections, target_stop, - start_time, end_time, transfer_margin, - walk_network, walk_speed) + csa_profile = ConnectionScanProfiler( + transit_connections, + target_stop, + start_time, + end_time, + transfer_margin, + walk_network, + walk_speed, + ) csa_profile.run() source_profile = csa_profile.stop_profiles[source_stop] self.assertEqual(source_profile.evaluate_earliest_arrival_time_at_target(0, 0), 10) found_tuples = source_profile.get_final_optimal_labels() self.assertEqual(len(found_tuples), 1) - def _assert_pareto_labels_equal(self, found_tuples, should_be_tuples): for found_tuple in found_tuples: self.assertIn(found_tuple, should_be_tuples) diff --git a/gtfspy/routing/test/test_forward_journey.py b/gtfspy/routing/test/test_forward_journey.py index 079b876..929ed63 100644 --- a/gtfspy/routing/test/test_forward_journey.py +++ b/gtfspy/routing/test/test_forward_journey.py @@ -5,19 +5,32 @@ class ForwardJourneyTest(unittest.TestCase): - def test_add_leg(self): journey = ForwardJourney() - leg1 = Connection(departure_stop=0, arrival_stop=1, departure_time=0, arrival_time=1, - trip_id="tripI", seq=1, is_walk=False) + leg1 = Connection( + departure_stop=0, + arrival_stop=1, + departure_time=0, + arrival_time=1, + trip_id="tripI", + seq=1, + is_walk=False, + ) journey.add_leg(leg1) self.assertEqual(len(journey.legs), 1) self.assertEqual(journey.departure_time, leg1.departure_time) self.assertEqual(journey.arrival_time, leg1.arrival_time) self.assertEqual(journey.n_boardings, 1) - leg2 = Connection(departure_stop=1, arrival_stop=2, departure_time=1, arrival_time=2, - trip_id="tripI", seq=1, is_walk=False) + leg2 = Connection( + departure_stop=1, + arrival_stop=2, + departure_time=1, + arrival_time=2, + trip_id="tripI", + seq=1, + is_walk=False, + ) journey.add_leg(leg2) self.assertEqual(len(journey.legs), 2) self.assertEqual(journey.departure_time, leg1.departure_time) @@ -25,12 +38,33 @@ def test_add_leg(self): self.assertEqual(journey.n_boardings, 1) def test_dominates(self): - leg1 = Connection(departure_stop=0, arrival_stop=1, departure_time=0, arrival_time=1, - trip_id="tripI", seq=1, is_walk=False) - leg2 = Connection(departure_stop=1, arrival_stop=2, departure_time=1, arrival_time=2, - trip_id="tripI", seq=1, is_walk=False) - leg3 = Connection(departure_stop=1, arrival_stop=2, departure_time=1, arrival_time=3, - trip_id="tripI", seq=1, is_walk=False) + leg1 = Connection( + departure_stop=0, + arrival_stop=1, + departure_time=0, + arrival_time=1, + trip_id="tripI", + seq=1, + is_walk=False, + ) + leg2 = Connection( + departure_stop=1, + arrival_stop=2, + departure_time=1, + arrival_time=2, + trip_id="tripI", + seq=1, + is_walk=False, + ) + leg3 = Connection( + departure_stop=1, + arrival_stop=2, + departure_time=1, + arrival_time=3, + trip_id="tripI", + seq=1, + is_walk=False, + ) journey1 = ForwardJourney(legs=[leg1]) journey2 = ForwardJourney(legs=[leg2]) journey12 = ForwardJourney(legs=[leg1, leg2]) @@ -45,7 +79,7 @@ def test_basics(self): (1, 100, 32, 36, "trip_5", 1), (100, 3, 36, 40, "trip_5", 2), (3, 4, 40, 41, "trip_4", 1), - (4, 2, 44, 50, None, 1) + (4, 2, 44, 50, None, 1), ] legs = list(map(lambda el: Connection(*el), event_list_raw_data)) test_journey = ForwardJourney(legs) @@ -67,7 +101,7 @@ def test_transfer_stop_pairs(self): (100, 3, 36, 40, "trip_5", 2), (3, 4, 40, 41, "trip_4", 1), (4, 2, 44, 50, None, 1), - (10, 11, 52, 55, "trip_6", 1) + (10, 11, 52, 55, "trip_6", 1), ] legs = list(map(lambda el: Connection(*el), event_list_raw_data)) test_journey = ForwardJourney(legs) @@ -81,4 +115,3 @@ def test_transfer_stop_pairs(self): self.assertEqual(transfer_stop_pairs[1][1], 3) self.assertEqual(transfer_stop_pairs[2][0], 2) self.assertEqual(transfer_stop_pairs[2][1], 10) - diff --git a/gtfspy/routing/test/test_journey_data.py b/gtfspy/routing/test/test_journey_data.py index d7548a6..a5632a1 100644 --- a/gtfspy/routing/test/test_journey_data.py +++ b/gtfspy/routing/test/test_journey_data.py @@ -16,7 +16,10 @@ class TestJourneyData(TestCase): # noinspection PyAttributeOutsideInit def _import_sample_gtfs_db(self): - import_gtfs([os.path.join(os.path.dirname(__file__), "../../test/test_data/test_gtfs.zip")], self.gtfs_path) + import_gtfs( + [os.path.join(os.path.dirname(__file__), "../../test/test_data/test_gtfs.zip")], + self.gtfs_path, + ) def _remove_routing_test_data_directory_if_exists(self): try: @@ -31,14 +34,18 @@ def _create_routing_test_data_directory(self): def setUp(self): self.routing_tmp_test_data_dir = "./tmp_routing_test_data/" self.gtfs_path = os.path.join(self.routing_tmp_test_data_dir, "test_gtfs.sqlite") - self.data_store_path = os.path.join(self.routing_tmp_test_data_dir, "test_data_store.sqlite") + self.data_store_path = os.path.join( + self.routing_tmp_test_data_dir, "test_data_store.sqlite" + ) self._remove_routing_test_data_directory_if_exists() self._create_routing_test_data_directory() self._import_sample_gtfs_db() - self.jdm = JourneyDataManager(self.gtfs_path, - os.path.join(self.routing_tmp_test_data_dir, "test_journeys.sqlite"), - routing_params={"track_vehicle_legs": True}) + self.jdm = JourneyDataManager( + self.gtfs_path, + os.path.join(self.routing_tmp_test_data_dir, "test_journeys.sqlite"), + routing_params={"track_vehicle_legs": True}, + ) def tearDown(self): self._remove_routing_test_data_directory_if_exists() @@ -47,16 +54,20 @@ def test_boardings_computations_based_on_journeys(self): # input some journeys destination_stop = 1 origin_stop = 2 - self.jdm.import_journey_data_for_target_stop(destination_stop, - {origin_stop: - [LabelTimeWithBoardingsCount(1, 2, 1, True), - LabelTimeWithBoardingsCount(2, 3, 2, True)]}, - enforce_synchronous_writes=True - ) + self.jdm.import_journey_data_for_target_stop( + destination_stop, + { + origin_stop: [ + LabelTimeWithBoardingsCount(1, 2, 1, True), + LabelTimeWithBoardingsCount(2, 3, 2, True), + ] + }, + enforce_synchronous_writes=True, + ) self.jdm.compute_and_store_travel_impedance_measures(0, 2, self.data_store_path) store = TravelImpedanceDataStore(self.data_store_path) df = store.read_data_as_dataframe("temporal_distance") self.assertAlmostEqual(df.iloc[0]["min"], 1) self.assertAlmostEqual(df.iloc[0]["mean"], 1.5) self.assertAlmostEqual(df.iloc[0]["max"], 2.0) - self.assertIn(df.iloc[0]["median"],[1, 2, 1.0, 1.5, 2.0]) \ No newline at end of file + self.assertIn(df.iloc[0]["median"], [1, 2, 1.0, 1.5, 2.0]) diff --git a/gtfspy/routing/test/test_label.py b/gtfspy/routing/test/test_label.py index 08c22a8..0bf5495 100644 --- a/gtfspy/routing/test/test_label.py +++ b/gtfspy/routing/test/test_label.py @@ -5,12 +5,19 @@ from unittest import TestCase -from gtfspy.routing.label import LabelTime, LabelTimeWithBoardingsCount, merge_pareto_frontiers, \ - LabelVehLegCount, compute_pareto_front, compute_pareto_front_naive, LabelTimeAndRoute, LabelTimeBoardingsAndRoute +from gtfspy.routing.label import ( + LabelTime, + LabelTimeWithBoardingsCount, + merge_pareto_frontiers, + LabelVehLegCount, + compute_pareto_front, + compute_pareto_front_naive, + LabelTimeAndRoute, + LabelTimeBoardingsAndRoute, +) class TestLabelTime(TestCase): - def test_dominates(self): label1 = LabelTime(departure_time=0, arrival_time_target=20) label2 = LabelTime(departure_time=1, arrival_time_target=10) @@ -61,52 +68,76 @@ def test_large_numbers_do_not_overflow(self): label = LabelTime( departure_time=float(departure_time), arrival_time_target=float(arrival_time), - first_leg_is_walk=False + first_leg_is_walk=False, ) self.assertEqual(departure_time, label.departure_time) self.assertEqual(arrival_time, label.arrival_time_target) class TestLabelTimeAndVehLegCount(TestCase): - def test_dominates_simple(self): - label1 = LabelTimeWithBoardingsCount(departure_time=0, arrival_time_target=20, n_boardings=0, first_leg_is_walk=False) - label2 = LabelTimeWithBoardingsCount(departure_time=1, arrival_time_target=10, n_boardings=0, first_leg_is_walk=False) + label1 = LabelTimeWithBoardingsCount( + departure_time=0, arrival_time_target=20, n_boardings=0, first_leg_is_walk=False + ) + label2 = LabelTimeWithBoardingsCount( + departure_time=1, arrival_time_target=10, n_boardings=0, first_leg_is_walk=False + ) self.assertTrue(label2.dominates(label1)) def test_does_not_dominate_same(self): - label2 = LabelTimeWithBoardingsCount(departure_time=1, arrival_time_target=10, n_boardings=0, first_leg_is_walk=False) - label3 = LabelTimeWithBoardingsCount(departure_time=1, arrival_time_target=10, n_boardings=0, first_leg_is_walk=False) + label2 = LabelTimeWithBoardingsCount( + departure_time=1, arrival_time_target=10, n_boardings=0, first_leg_is_walk=False + ) + label3 = LabelTimeWithBoardingsCount( + departure_time=1, arrival_time_target=10, n_boardings=0, first_leg_is_walk=False + ) self.assertTrue(label2.dominates(label3)) def test_dominates_later_arrival_time(self): - label2 = LabelTimeWithBoardingsCount(departure_time=1, arrival_time_target=10, n_boardings=0, first_leg_is_walk=False) - label4 = LabelTimeWithBoardingsCount(departure_time=1, arrival_time_target=11, n_boardings=0, first_leg_is_walk=False) + label2 = LabelTimeWithBoardingsCount( + departure_time=1, arrival_time_target=10, n_boardings=0, first_leg_is_walk=False + ) + label4 = LabelTimeWithBoardingsCount( + departure_time=1, arrival_time_target=11, n_boardings=0, first_leg_is_walk=False + ) self.assertTrue(label2.dominates(label4)) def test_dominates_earlier_departure_time(self): - label2 = LabelTimeWithBoardingsCount(departure_time=1, arrival_time_target=10, n_boardings=0, first_leg_is_walk=False) - label5 = LabelTimeWithBoardingsCount(departure_time=0, arrival_time_target=10, n_boardings=0, first_leg_is_walk=False) + label2 = LabelTimeWithBoardingsCount( + departure_time=1, arrival_time_target=10, n_boardings=0, first_leg_is_walk=False + ) + label5 = LabelTimeWithBoardingsCount( + departure_time=0, arrival_time_target=10, n_boardings=0, first_leg_is_walk=False + ) self.assertTrue(label2.dominates(label5)) def test_dominates_less_transfers(self): - labela = LabelTimeWithBoardingsCount(departure_time=1, arrival_time_target=10, n_boardings=1, first_leg_is_walk=False) - labelb = LabelTimeWithBoardingsCount(departure_time=1, arrival_time_target=10, n_boardings=0, first_leg_is_walk=False) + labela = LabelTimeWithBoardingsCount( + departure_time=1, arrival_time_target=10, n_boardings=1, first_leg_is_walk=False + ) + labelb = LabelTimeWithBoardingsCount( + departure_time=1, arrival_time_target=10, n_boardings=0, first_leg_is_walk=False + ) self.assertTrue(labelb.dominates(labela)) def test_dominates_less_transfers_different_travel_time(self): - labela = LabelTimeWithBoardingsCount(departure_time=1, arrival_time_target=9, n_boardings=1, first_leg_is_walk=False) - labelb = LabelTimeWithBoardingsCount(departure_time=1, arrival_time_target=10, n_boardings=0, first_leg_is_walk=False) + labela = LabelTimeWithBoardingsCount( + departure_time=1, arrival_time_target=9, n_boardings=1, first_leg_is_walk=False + ) + labelb = LabelTimeWithBoardingsCount( + departure_time=1, arrival_time_target=10, n_boardings=0, first_leg_is_walk=False + ) self.assertFalse(labelb.dominates(labela)) self.assertFalse(labela.dominates(labelb)) - def test_duration(self): label1 = LabelTime(departure_time=0, arrival_time_target=20, last_leg_is_walk=False) self.assertEqual(20, label1.duration()) def test_sort(self): - l1 = LabelTimeWithBoardingsCount(departure_time=1, arrival_time_target=1, n_boardings=3, first_leg_is_walk=False) + l1 = LabelTimeWithBoardingsCount( + departure_time=1, arrival_time_target=1, n_boardings=3, first_leg_is_walk=False + ) l2 = LabelTimeWithBoardingsCount(0, 0, 0, False) self.assertTrue(l1 > l2) self.assertTrue(l1 >= l2) @@ -132,13 +163,11 @@ def test_sort(self): self.assertTrue(sorted([l1, l2])[0] == l2) - l1 = LabelTimeWithBoardingsCount(1, 1, 10, True) l2 = LabelTimeWithBoardingsCount(1, 1, 10, False) self.assertTrue(l1 < l2) self.assertFalse(l1 > l2) - def test_large_numbers_do_not_overflow(self): departure_time = 1475530980 arrival_time = 1475530980 @@ -146,21 +175,18 @@ def test_large_numbers_do_not_overflow(self): departure_time=float(departure_time), arrival_time_target=float(arrival_time), n_boardings=0, - first_leg_is_walk=False + first_leg_is_walk=False, ) self.assertEqual(departure_time, label.departure_time) self.assertEqual(arrival_time, label.arrival_time_target) - class TestLabelVehLegCount(TestCase): - def test_dominates_simple(self): label1 = LabelVehLegCount(n_boardings=1) label2 = LabelVehLegCount(n_boardings=0) self.assertTrue(label2.dominates(label1)) - def test_sort(self): l1 = LabelVehLegCount(departure_time=1, n_boardings=3) l2 = LabelVehLegCount(departure_time=0, n_boardings=0) @@ -194,79 +220,114 @@ def test_pareto_frontier(self): class TestLabelTimeAndRoute(TestCase): - def test_dominates_simple(self): - label1 = LabelTimeAndRoute(departure_time=0, arrival_time_target=20, movement_duration=0, first_leg_is_walk=False) - label2 = LabelTimeAndRoute(departure_time=1, arrival_time_target=10, movement_duration=0, first_leg_is_walk=False) + label1 = LabelTimeAndRoute( + departure_time=0, arrival_time_target=20, movement_duration=0, first_leg_is_walk=False + ) + label2 = LabelTimeAndRoute( + departure_time=1, arrival_time_target=10, movement_duration=0, first_leg_is_walk=False + ) self.assertTrue(label2.dominates(label1)) self.assertFalse(label1.dominates(label2)) def test_does_not_dominate_same(self): - label2 = LabelTimeAndRoute(departure_time=1, arrival_time_target=10, movement_duration=0, first_leg_is_walk=False) - label3 = LabelTimeAndRoute(departure_time=1, arrival_time_target=10, movement_duration=0, first_leg_is_walk=False) + label2 = LabelTimeAndRoute( + departure_time=1, arrival_time_target=10, movement_duration=0, first_leg_is_walk=False + ) + label3 = LabelTimeAndRoute( + departure_time=1, arrival_time_target=10, movement_duration=0, first_leg_is_walk=False + ) self.assertTrue(label2.dominates(label3)) self.assertTrue(label3.dominates(label2)) def test_dominates_later_arrival_time(self): - label2 = LabelTimeAndRoute(departure_time=1, arrival_time_target=10, movement_duration=0, first_leg_is_walk=False) - label4 = LabelTimeAndRoute(departure_time=1, arrival_time_target=11, movement_duration=0, first_leg_is_walk=False) + label2 = LabelTimeAndRoute( + departure_time=1, arrival_time_target=10, movement_duration=0, first_leg_is_walk=False + ) + label4 = LabelTimeAndRoute( + departure_time=1, arrival_time_target=11, movement_duration=0, first_leg_is_walk=False + ) self.assertTrue(label2.dominates(label4)) self.assertFalse(label4.dominates(label2)) def test_dominates_earlier_departure_time(self): - label2 = LabelTimeAndRoute(departure_time=1, arrival_time_target=10, movement_duration=0, first_leg_is_walk=False) - label5 = LabelTimeAndRoute(departure_time=0, arrival_time_target=10, movement_duration=0, first_leg_is_walk=False) + label2 = LabelTimeAndRoute( + departure_time=1, arrival_time_target=10, movement_duration=0, first_leg_is_walk=False + ) + label5 = LabelTimeAndRoute( + departure_time=0, arrival_time_target=10, movement_duration=0, first_leg_is_walk=False + ) self.assertTrue(label2.dominates(label5)) self.assertFalse(label5.dominates(label2)) def test_dominates_less_movement_duration(self): - labela = LabelTimeAndRoute(departure_time=1, arrival_time_target=10, movement_duration=0, first_leg_is_walk=False) - labelb = LabelTimeAndRoute(departure_time=1, arrival_time_target=10, movement_duration=1, first_leg_is_walk=False) + labela = LabelTimeAndRoute( + departure_time=1, arrival_time_target=10, movement_duration=0, first_leg_is_walk=False + ) + labelb = LabelTimeAndRoute( + departure_time=1, arrival_time_target=10, movement_duration=1, first_leg_is_walk=False + ) self.assertFalse(labelb.dominates(labela)) self.assertTrue(labela.dominates(labelb)) def test_dominates_less_movement_duration_when_arrival_time_not_the_same(self): # a should dominate b as the travel time is shorter - labela = LabelTimeAndRoute(departure_time=1, arrival_time_target=9, movement_duration=1, first_leg_is_walk=False) - labelb = LabelTimeAndRoute(departure_time=1, arrival_time_target=10, movement_duration=0, first_leg_is_walk=False) + labela = LabelTimeAndRoute( + departure_time=1, arrival_time_target=9, movement_duration=1, first_leg_is_walk=False + ) + labelb = LabelTimeAndRoute( + departure_time=1, arrival_time_target=10, movement_duration=0, first_leg_is_walk=False + ) self.assertFalse(labelb.dominates(labela)) self.assertTrue(labela.dominates(labelb)) def test_dominates_less_movement_duration_when_departure_time_not_the_same(self): - labela = LabelTimeAndRoute(departure_time=4, arrival_time_target=10, movement_duration=1, first_leg_is_walk=False) - labelb = LabelTimeAndRoute(departure_time=1, arrival_time_target=10, movement_duration=0, first_leg_is_walk=False) + labela = LabelTimeAndRoute( + departure_time=4, arrival_time_target=10, movement_duration=1, first_leg_is_walk=False + ) + labelb = LabelTimeAndRoute( + departure_time=1, arrival_time_target=10, movement_duration=0, first_leg_is_walk=False + ) self.assertFalse(labelb.dominates(labela)) self.assertTrue(labela.dominates(labelb)) def test_dominates_ignoring_dep_time_finalization_less_movement_duration(self): - labela = LabelTimeAndRoute(departure_time=1, arrival_time_target=10, movement_duration=0, - first_leg_is_walk=False) - labelb = LabelTimeAndRoute(departure_time=1, arrival_time_target=10, movement_duration=1, - first_leg_is_walk=False) + labela = LabelTimeAndRoute( + departure_time=1, arrival_time_target=10, movement_duration=0, first_leg_is_walk=False + ) + labelb = LabelTimeAndRoute( + departure_time=1, arrival_time_target=10, movement_duration=1, first_leg_is_walk=False + ) self.assertTrue(labelb.dominates_ignoring_dep_time_finalization(labela)) self.assertTrue(labela.dominates_ignoring_dep_time_finalization(labelb)) def test_dominates_ignoring_dep_time_finalization_arrival_time(self): - labela = LabelTimeAndRoute(departure_time=1, arrival_time_target=9, movement_duration=1, - first_leg_is_walk=False) - labelb = LabelTimeAndRoute(departure_time=1, arrival_time_target=10, movement_duration=1, - first_leg_is_walk=False) + labela = LabelTimeAndRoute( + departure_time=1, arrival_time_target=9, movement_duration=1, first_leg_is_walk=False + ) + labelb = LabelTimeAndRoute( + departure_time=1, arrival_time_target=10, movement_duration=1, first_leg_is_walk=False + ) self.assertFalse(labelb.dominates_ignoring_dep_time_finalization(labela)) self.assertTrue(labela.dominates_ignoring_dep_time_finalization(labelb)) def test_dominates_ignoring_dep_time_less_movement_duration(self): - labela = LabelTimeAndRoute(departure_time=1, arrival_time_target=10, movement_duration=0, - first_leg_is_walk=False) - labelb = LabelTimeAndRoute(departure_time=1, arrival_time_target=10, movement_duration=1, - first_leg_is_walk=False) + labela = LabelTimeAndRoute( + departure_time=1, arrival_time_target=10, movement_duration=0, first_leg_is_walk=False + ) + labelb = LabelTimeAndRoute( + departure_time=1, arrival_time_target=10, movement_duration=1, first_leg_is_walk=False + ) self.assertTrue(labelb.dominates_ignoring_dep_time(labela)) self.assertTrue(labela.dominates_ignoring_dep_time(labelb)) def test_dominates_ignoring_dep_time_arrival_time(self): - labela = LabelTimeAndRoute(departure_time=1, arrival_time_target=9, movement_duration=1, - first_leg_is_walk=False) - labelb = LabelTimeAndRoute(departure_time=1, arrival_time_target=10, movement_duration=1, - first_leg_is_walk=False) + labela = LabelTimeAndRoute( + departure_time=1, arrival_time_target=9, movement_duration=1, first_leg_is_walk=False + ) + labelb = LabelTimeAndRoute( + departure_time=1, arrival_time_target=10, movement_duration=1, first_leg_is_walk=False + ) self.assertFalse(labelb.dominates_ignoring_dep_time(labela)) self.assertTrue(labela.dominates_ignoring_dep_time(labelb)) @@ -279,12 +340,17 @@ def test_dominates_ignoring_dep_time_finalization_equal(self): self.assertFalse(labelb.dominates_ignoring_dep_time_finalization(labela)) self.assertFalse(labela.dominates_ignoring_dep_time_finalization(labelb)) """ + def test_duration(self): - label1 = LabelTimeAndRoute(departure_time=0, arrival_time_target=20, movement_duration=1, first_leg_is_walk=False) + label1 = LabelTimeAndRoute( + departure_time=0, arrival_time_target=20, movement_duration=1, first_leg_is_walk=False + ) self.assertEqual(20, label1.duration()) def test_sort(self): - l1 = LabelTimeAndRoute(departure_time=1, arrival_time_target=1, movement_duration=3, first_leg_is_walk=False) + l1 = LabelTimeAndRoute( + departure_time=1, arrival_time_target=1, movement_duration=3, first_leg_is_walk=False + ) l2 = LabelTimeAndRoute(0, 0, 0, False) self.assertTrue(l1 > l2) self.assertTrue(l1 >= l2) @@ -322,102 +388,245 @@ def test_large_numbers_do_not_overflow(self): departure_time=float(departure_time), arrival_time_target=float(arrival_time), movement_duration=0, - first_leg_is_walk=False + first_leg_is_walk=False, ) self.assertEqual(departure_time, label.departure_time) self.assertEqual(arrival_time, label.arrival_time_target) class TestLabelTimeBoardingsAndRoute(TestCase): - def test_dominates_simple(self): - label1 = LabelTimeBoardingsAndRoute(departure_time=0, arrival_time_target=20, movement_duration=0, n_boardings=1, first_leg_is_walk=False) - label2 = LabelTimeBoardingsAndRoute(departure_time=1, arrival_time_target=10, movement_duration=0, n_boardings=1, first_leg_is_walk=False) + label1 = LabelTimeBoardingsAndRoute( + departure_time=0, + arrival_time_target=20, + movement_duration=0, + n_boardings=1, + first_leg_is_walk=False, + ) + label2 = LabelTimeBoardingsAndRoute( + departure_time=1, + arrival_time_target=10, + movement_duration=0, + n_boardings=1, + first_leg_is_walk=False, + ) self.assertTrue(label2.dominates(label1)) self.assertFalse(label1.dominates(label2)) def test_does_not_dominate_same(self): - label2 = LabelTimeBoardingsAndRoute(departure_time=1, arrival_time_target=10, movement_duration=0, n_boardings=1, first_leg_is_walk=False) - label3 = LabelTimeBoardingsAndRoute(departure_time=1, arrival_time_target=10, movement_duration=0, n_boardings=1, first_leg_is_walk=False) + label2 = LabelTimeBoardingsAndRoute( + departure_time=1, + arrival_time_target=10, + movement_duration=0, + n_boardings=1, + first_leg_is_walk=False, + ) + label3 = LabelTimeBoardingsAndRoute( + departure_time=1, + arrival_time_target=10, + movement_duration=0, + n_boardings=1, + first_leg_is_walk=False, + ) self.assertTrue(label2.dominates(label3)) self.assertTrue(label3.dominates(label2)) def test_dominates_later_arrival_time(self): - label2 = LabelTimeBoardingsAndRoute(departure_time=1, arrival_time_target=10, movement_duration=0, n_boardings=1, first_leg_is_walk=False) - label4 = LabelTimeBoardingsAndRoute(departure_time=1, arrival_time_target=11, movement_duration=0, n_boardings=1, first_leg_is_walk=False) + label2 = LabelTimeBoardingsAndRoute( + departure_time=1, + arrival_time_target=10, + movement_duration=0, + n_boardings=1, + first_leg_is_walk=False, + ) + label4 = LabelTimeBoardingsAndRoute( + departure_time=1, + arrival_time_target=11, + movement_duration=0, + n_boardings=1, + first_leg_is_walk=False, + ) self.assertTrue(label2.dominates(label4)) self.assertFalse(label4.dominates(label2)) def test_dominates_earlier_departure_time(self): - label2 = LabelTimeBoardingsAndRoute(departure_time=1, arrival_time_target=10, movement_duration=0, n_boardings=1, first_leg_is_walk=False) - label5 = LabelTimeBoardingsAndRoute(departure_time=0, arrival_time_target=10, movement_duration=0, n_boardings=1, first_leg_is_walk=False) + label2 = LabelTimeBoardingsAndRoute( + departure_time=1, + arrival_time_target=10, + movement_duration=0, + n_boardings=1, + first_leg_is_walk=False, + ) + label5 = LabelTimeBoardingsAndRoute( + departure_time=0, + arrival_time_target=10, + movement_duration=0, + n_boardings=1, + first_leg_is_walk=False, + ) self.assertTrue(label2.dominates(label5)) self.assertFalse(label5.dominates(label2)) def test_dominates_less_movement_duration(self): - labela = LabelTimeBoardingsAndRoute(departure_time=1, arrival_time_target=10, movement_duration=0, n_boardings=1, first_leg_is_walk=False) - labelb = LabelTimeBoardingsAndRoute(departure_time=1, arrival_time_target=10, movement_duration=1, n_boardings=1, first_leg_is_walk=False) + labela = LabelTimeBoardingsAndRoute( + departure_time=1, + arrival_time_target=10, + movement_duration=0, + n_boardings=1, + first_leg_is_walk=False, + ) + labelb = LabelTimeBoardingsAndRoute( + departure_time=1, + arrival_time_target=10, + movement_duration=1, + n_boardings=1, + first_leg_is_walk=False, + ) self.assertFalse(labelb.dominates(labela)) self.assertTrue(labela.dominates(labelb)) def test_dominates_less_movement_duration_when_arrival_time_not_the_same(self): # a should dominate b as the travel time is shorter - labela = LabelTimeBoardingsAndRoute(departure_time=1, arrival_time_target=9, movement_duration=1, n_boardings=1, first_leg_is_walk=False) - labelb = LabelTimeBoardingsAndRoute(departure_time=1, arrival_time_target=10, movement_duration=0, n_boardings=1, first_leg_is_walk=False) + labela = LabelTimeBoardingsAndRoute( + departure_time=1, + arrival_time_target=9, + movement_duration=1, + n_boardings=1, + first_leg_is_walk=False, + ) + labelb = LabelTimeBoardingsAndRoute( + departure_time=1, + arrival_time_target=10, + movement_duration=0, + n_boardings=1, + first_leg_is_walk=False, + ) self.assertFalse(labelb.dominates(labela)) self.assertTrue(labela.dominates(labelb)) def test_dominates_less_movement_duration_when_departure_time_not_the_same(self): - labela = LabelTimeBoardingsAndRoute(departure_time=4, arrival_time_target=10, movement_duration=1, n_boardings=1, first_leg_is_walk=False) - labelb = LabelTimeBoardingsAndRoute(departure_time=1, arrival_time_target=10, movement_duration=0, n_boardings=1, first_leg_is_walk=False) + labela = LabelTimeBoardingsAndRoute( + departure_time=4, + arrival_time_target=10, + movement_duration=1, + n_boardings=1, + first_leg_is_walk=False, + ) + labelb = LabelTimeBoardingsAndRoute( + departure_time=1, + arrival_time_target=10, + movement_duration=0, + n_boardings=1, + first_leg_is_walk=False, + ) self.assertFalse(labelb.dominates(labela)) self.assertTrue(labela.dominates(labelb)) def test_dominates_ignoring_dep_time_finalization_less_movement_duration(self): - labela = LabelTimeBoardingsAndRoute(departure_time=1, arrival_time_target=10, movement_duration=0, n_boardings=1, - first_leg_is_walk=False) - labelb = LabelTimeBoardingsAndRoute(departure_time=1, arrival_time_target=10, movement_duration=1, n_boardings=1, - first_leg_is_walk=False) + labela = LabelTimeBoardingsAndRoute( + departure_time=1, + arrival_time_target=10, + movement_duration=0, + n_boardings=1, + first_leg_is_walk=False, + ) + labelb = LabelTimeBoardingsAndRoute( + departure_time=1, + arrival_time_target=10, + movement_duration=1, + n_boardings=1, + first_leg_is_walk=False, + ) self.assertFalse(labelb.dominates_ignoring_dep_time_finalization(labela)) self.assertTrue(labela.dominates_ignoring_dep_time_finalization(labelb)) def test_dominates_ignoring_dep_time_finalization_arrival_time(self): - labela = LabelTimeBoardingsAndRoute(departure_time=1, arrival_time_target=9, movement_duration=1, n_boardings=1, - first_leg_is_walk=False) - labelb = LabelTimeBoardingsAndRoute(departure_time=1, arrival_time_target=10, movement_duration=0, n_boardings=1, - first_leg_is_walk=False) + labela = LabelTimeBoardingsAndRoute( + departure_time=1, + arrival_time_target=9, + movement_duration=1, + n_boardings=1, + first_leg_is_walk=False, + ) + labelb = LabelTimeBoardingsAndRoute( + departure_time=1, + arrival_time_target=10, + movement_duration=0, + n_boardings=1, + first_leg_is_walk=False, + ) self.assertFalse(labelb.dominates_ignoring_dep_time_finalization(labela)) self.assertTrue(labela.dominates_ignoring_dep_time_finalization(labelb)) def test_dominates_ignoring_dep_time_finalization_both_pareto(self): - labela = LabelTimeBoardingsAndRoute(departure_time=1, arrival_time_target=9, movement_duration=1, n_boardings=1, - first_leg_is_walk=False) - labelb = LabelTimeBoardingsAndRoute(departure_time=1, arrival_time_target=10, movement_duration=0, n_boardings=0, - first_leg_is_walk=False) + labela = LabelTimeBoardingsAndRoute( + departure_time=1, + arrival_time_target=9, + movement_duration=1, + n_boardings=1, + first_leg_is_walk=False, + ) + labelb = LabelTimeBoardingsAndRoute( + departure_time=1, + arrival_time_target=10, + movement_duration=0, + n_boardings=0, + first_leg_is_walk=False, + ) self.assertFalse(labelb.dominates_ignoring_dep_time_finalization(labela)) self.assertFalse(labela.dominates_ignoring_dep_time_finalization(labelb)) def test_dominates_ignoring_dep_time_less_movement_duration(self): - labela = LabelTimeBoardingsAndRoute(departure_time=1, arrival_time_target=10, movement_duration=0, n_boardings=1, - first_leg_is_walk=False) - labelb = LabelTimeBoardingsAndRoute(departure_time=1, arrival_time_target=10, movement_duration=1, n_boardings=1, - first_leg_is_walk=False) + labela = LabelTimeBoardingsAndRoute( + departure_time=1, + arrival_time_target=10, + movement_duration=0, + n_boardings=1, + first_leg_is_walk=False, + ) + labelb = LabelTimeBoardingsAndRoute( + departure_time=1, + arrival_time_target=10, + movement_duration=1, + n_boardings=1, + first_leg_is_walk=False, + ) self.assertFalse(labelb.dominates_ignoring_dep_time(labela)) self.assertTrue(labela.dominates_ignoring_dep_time(labelb)) def test_dominates_ignoring_dep_time_arrival_time(self): - labela = LabelTimeBoardingsAndRoute(departure_time=1, arrival_time_target=9, movement_duration=1, n_boardings=1, - first_leg_is_walk=False) - labelb = LabelTimeBoardingsAndRoute(departure_time=1, arrival_time_target=10, movement_duration=1, n_boardings=1, - first_leg_is_walk=False) + labela = LabelTimeBoardingsAndRoute( + departure_time=1, + arrival_time_target=9, + movement_duration=1, + n_boardings=1, + first_leg_is_walk=False, + ) + labelb = LabelTimeBoardingsAndRoute( + departure_time=1, + arrival_time_target=10, + movement_duration=1, + n_boardings=1, + first_leg_is_walk=False, + ) self.assertFalse(labelb.dominates_ignoring_dep_time(labela)) self.assertTrue(labela.dominates_ignoring_dep_time(labelb)) def test_various_dominates(self): - labela = LabelTimeBoardingsAndRoute(departure_time=1481520618, arrival_time_target=1481521300, n_boardings=1, - movement_duration=681, first_leg_is_walk=True) - labelb = LabelTimeBoardingsAndRoute(departure_time=1481520618, arrival_time_target=1481521215, n_boardings=1, - movement_duration=597, first_leg_is_walk=True) + labela = LabelTimeBoardingsAndRoute( + departure_time=1481520618, + arrival_time_target=1481521300, + n_boardings=1, + movement_duration=681, + first_leg_is_walk=True, + ) + labelb = LabelTimeBoardingsAndRoute( + departure_time=1481520618, + arrival_time_target=1481521215, + n_boardings=1, + movement_duration=597, + first_leg_is_walk=True, + ) self.assertTrue(labelb.dominates(labela)) self.assertFalse(labela.dominates(labelb)) @@ -430,22 +639,33 @@ def test_dominates_ignoring_dep_time_finalization_equal(self): self.assertFalse(labelb.dominates_ignoring_dep_time_finalization(labela)) self.assertFalse(labela.dominates_ignoring_dep_time_finalization(labelb)) """ + def test_duration(self): - label1 = LabelTimeBoardingsAndRoute(departure_time=0, arrival_time_target=20, movement_duration=1, - n_boardings=1, first_leg_is_walk=False) + label1 = LabelTimeBoardingsAndRoute( + departure_time=0, + arrival_time_target=20, + movement_duration=1, + n_boardings=1, + first_leg_is_walk=False, + ) self.assertEqual(20, label1.duration()) def test_sort(self): - l1 = LabelTimeBoardingsAndRoute(departure_time=1, arrival_time_target=1, movement_duration=3, - n_boardings=1, first_leg_is_walk=False) + l1 = LabelTimeBoardingsAndRoute( + departure_time=1, + arrival_time_target=1, + movement_duration=3, + n_boardings=1, + first_leg_is_walk=False, + ) l2 = LabelTimeBoardingsAndRoute(0, 0, 0, 1, False) self.assertTrue(l1 > l2) self.assertTrue(l1 >= l2) self.assertFalse(l1 < l2) self.assertFalse(l1 <= l2) - l1 = LabelTimeBoardingsAndRoute(0, 0, 0, 1, False) - l2 = LabelTimeBoardingsAndRoute(0, 0, 0, 1, False) + l1 = LabelTimeBoardingsAndRoute(0, 0, 0, 1, False) + l2 = LabelTimeBoardingsAndRoute(0, 0, 0, 1, False) self.assertTrue(l1 == l2) self.assertTrue(l1 >= l2) self.assertTrue(l1 <= l2) @@ -476,30 +696,42 @@ def test_large_numbers_do_not_overflow(self): arrival_time_target=float(arrival_time), movement_duration=0, n_boardings=1, - first_leg_is_walk=False + first_leg_is_walk=False, ) self.assertEqual(departure_time, label.departure_time) self.assertEqual(arrival_time, label.arrival_time_target) - - - class TestParetoFrontier(TestCase): - def test_compute_pareto_front_all_include(self): - label_a = LabelTimeWithBoardingsCount(departure_time=1, arrival_time_target=2, n_boardings=0, first_leg_is_walk=False) - label_b = LabelTimeWithBoardingsCount(departure_time=2, arrival_time_target=3, n_boardings=0, first_leg_is_walk=False) - label_c = LabelTimeWithBoardingsCount(departure_time=3, arrival_time_target=4, n_boardings=0, first_leg_is_walk=False) - label_d = LabelTimeWithBoardingsCount(departure_time=4, arrival_time_target=5, n_boardings=0, first_leg_is_walk=False) + label_a = LabelTimeWithBoardingsCount( + departure_time=1, arrival_time_target=2, n_boardings=0, first_leg_is_walk=False + ) + label_b = LabelTimeWithBoardingsCount( + departure_time=2, arrival_time_target=3, n_boardings=0, first_leg_is_walk=False + ) + label_c = LabelTimeWithBoardingsCount( + departure_time=3, arrival_time_target=4, n_boardings=0, first_leg_is_walk=False + ) + label_d = LabelTimeWithBoardingsCount( + departure_time=4, arrival_time_target=5, n_boardings=0, first_leg_is_walk=False + ) labels = [label_a, label_b, label_c, label_d] self.assertEqual(4, len(compute_pareto_front(labels))) def test_one_dominates_all(self): - label_a = LabelTimeWithBoardingsCount(departure_time=1, arrival_time_target=12, n_boardings=0, first_leg_is_walk=False) - label_b = LabelTimeWithBoardingsCount(departure_time=2, arrival_time_target=13, n_boardings=0, first_leg_is_walk=False) - label_c = LabelTimeWithBoardingsCount(departure_time=3, arrival_time_target=14, n_boardings=0, first_leg_is_walk=False) - label_d = LabelTimeWithBoardingsCount(departure_time=4, arrival_time_target=5, n_boardings=0, first_leg_is_walk=False) + label_a = LabelTimeWithBoardingsCount( + departure_time=1, arrival_time_target=12, n_boardings=0, first_leg_is_walk=False + ) + label_b = LabelTimeWithBoardingsCount( + departure_time=2, arrival_time_target=13, n_boardings=0, first_leg_is_walk=False + ) + label_c = LabelTimeWithBoardingsCount( + departure_time=3, arrival_time_target=14, n_boardings=0, first_leg_is_walk=False + ) + label_d = LabelTimeWithBoardingsCount( + departure_time=4, arrival_time_target=5, n_boardings=0, first_leg_is_walk=False + ) labels = [label_a, label_b, label_c, label_d] pareto_front = compute_pareto_front(labels) self.assertEqual(1, len(pareto_front)) @@ -510,10 +742,18 @@ def test_empty(self): self.assertEqual(0, len(compute_pareto_front(labels))) def test_some_are_optimal_some_are_not(self): - label_a = LabelTimeWithBoardingsCount(departure_time=1, arrival_time_target=2, n_boardings=1, first_leg_is_walk=False) # optimal - label_b = LabelTimeWithBoardingsCount(departure_time=1, arrival_time_target=5, n_boardings=0, first_leg_is_walk=False) # label_d dominates - label_c = LabelTimeWithBoardingsCount(departure_time=3, arrival_time_target=4, n_boardings=1, first_leg_is_walk=False) # optimal - label_d = LabelTimeWithBoardingsCount(departure_time=4, arrival_time_target=5, n_boardings=0, first_leg_is_walk=False) # optimal + label_a = LabelTimeWithBoardingsCount( + departure_time=1, arrival_time_target=2, n_boardings=1, first_leg_is_walk=False + ) # optimal + label_b = LabelTimeWithBoardingsCount( + departure_time=1, arrival_time_target=5, n_boardings=0, first_leg_is_walk=False + ) # label_d dominates + label_c = LabelTimeWithBoardingsCount( + departure_time=3, arrival_time_target=4, n_boardings=1, first_leg_is_walk=False + ) # optimal + label_d = LabelTimeWithBoardingsCount( + departure_time=4, arrival_time_target=5, n_boardings=0, first_leg_is_walk=False + ) # optimal labels = [label_a, label_b, label_c, label_d] pareto_front = compute_pareto_front(labels) @@ -521,10 +761,18 @@ def test_some_are_optimal_some_are_not(self): self.assertNotIn(label_b, pareto_front) def test_merge_pareto_frontiers(self): - label_a = LabelTimeWithBoardingsCount(departure_time=1, arrival_time_target=2, n_boardings=1, first_leg_is_walk=False) # optimal - label_b = LabelTimeWithBoardingsCount(departure_time=1, arrival_time_target=5, n_boardings=0, first_leg_is_walk=False) # d dominates - label_c = LabelTimeWithBoardingsCount(departure_time=3, arrival_time_target=4, n_boardings=1, first_leg_is_walk=False) # optimal - label_d = LabelTimeWithBoardingsCount(departure_time=4, arrival_time_target=5, n_boardings=0, first_leg_is_walk=False) # optimal + label_a = LabelTimeWithBoardingsCount( + departure_time=1, arrival_time_target=2, n_boardings=1, first_leg_is_walk=False + ) # optimal + label_b = LabelTimeWithBoardingsCount( + departure_time=1, arrival_time_target=5, n_boardings=0, first_leg_is_walk=False + ) # d dominates + label_c = LabelTimeWithBoardingsCount( + departure_time=3, arrival_time_target=4, n_boardings=1, first_leg_is_walk=False + ) # optimal + label_d = LabelTimeWithBoardingsCount( + departure_time=4, arrival_time_target=5, n_boardings=0, first_leg_is_walk=False + ) # optimal front_1 = [label_a, label_b] front_2 = [label_c, label_d] @@ -538,13 +786,16 @@ def test_merge_pareto_frontiers_empty(self): def test_compute_pareto_front_smart(self): labels = [] - for n in [1, 2, 10]: #, 500]: + for n in [1, 2, 10]: # , 500]: for dep_time in range(0, n): for n_veh_legs in range(2): - for arr_time in range(dep_time, dep_time + 10): - label = LabelTimeWithBoardingsCount(dep_time, arr_time - n_veh_legs, n_veh_legs, False) + for arr_time in range(dep_time, dep_time + 10): + label = LabelTimeWithBoardingsCount( + dep_time, arr_time - n_veh_legs, n_veh_legs, False + ) labels.append(label) import random + random.shuffle(labels) labels_copy = copy.deepcopy(labels) pareto_optimal_labels = compute_pareto_front_naive(labels) @@ -554,10 +805,14 @@ def test_compute_pareto_front_smart(self): def test_compute_pareto_front_smart_randomized(self): import random + for i in range(10): - labels = [LabelTimeWithBoardingsCount(random.randint(0, 1000), random.randint(0, 1000), random.randint(0, 10), 0) - for _ in range(1000)] + labels = [ + LabelTimeWithBoardingsCount( + random.randint(0, 1000), random.randint(0, 1000), random.randint(0, 10), 0 + ) + for _ in range(1000) + ] pareto_optimal_labels_old = compute_pareto_front_naive(labels) pareto_optimal_labels_smart = compute_pareto_front(labels) self.assertEqual(len(pareto_optimal_labels_old), len(pareto_optimal_labels_smart)) - diff --git a/gtfspy/routing/test/test_multi_objective_pseudo_connection_scan_profiler.py b/gtfspy/routing/test/test_multi_objective_pseudo_connection_scan_profiler.py index d2e810d..daba4d5 100644 --- a/gtfspy/routing/test/test_multi_objective_pseudo_connection_scan_profiler.py +++ b/gtfspy/routing/test/test_multi_objective_pseudo_connection_scan_profiler.py @@ -5,12 +5,16 @@ from gtfspy.routing.connection import Connection from gtfspy.routing.label import min_arrival_time_target, LabelTimeWithBoardingsCount, LabelTime -from gtfspy.routing.multi_objective_pseudo_connection_scan_profiler import MultiObjectivePseudoCSAProfiler +from gtfspy.routing.multi_objective_pseudo_connection_scan_profiler import ( + MultiObjectivePseudoCSAProfiler, +) from gtfspy.routing.node_profile_multiobjective import NodeProfileMultiObjective import pyximport + pyximport.install() + class TestMultiObjectivePseudoCSAProfiler(TestCase): # noinspection PyAttributeOutsideInit @@ -21,7 +25,7 @@ def setUp(self): (3, 4, 32, 35, "trip_4", 1), (2, 3, 25, 30, "trip_3", 1), (1, 2, 10, 20, "trip_2", 1), - (0, 1, 0, 10, "trip_1", 1) + (0, 1, 0, 10, "trip_1", 1), ] self.transit_connections = list(map(lambda el: Connection(*el), event_list_raw_data)) self.walk_network = networkx.Graph() @@ -34,10 +38,7 @@ def setUp(self): self.end_time = 50 def test_pseudo_connections(self): - event_list_raw_data = [ - (0, 1, 10, 20, "trip_6", 1), - (2, 3, 42, 50, "trip_5", 1) - ] + event_list_raw_data = [(0, 1, 10, 20, "trip_6", 1), (2, 3, 42, 50, "trip_5", 1)] transit_connections = list(map(lambda el: Connection(*el), event_list_raw_data)) walk_network = networkx.Graph() walk_network.add_edge(1, 2, {"d_walk": 20}) @@ -46,9 +47,15 @@ def test_pseudo_connections(self): transfer_margin = 0 start_time = 0 end_time = 50 - csa_profile = MultiObjectivePseudoCSAProfiler(transit_connections, target_stop, - start_time, end_time, transfer_margin, - walk_network, walk_speed) + csa_profile = MultiObjectivePseudoCSAProfiler( + transit_connections, + target_stop, + start_time, + end_time, + transfer_margin, + walk_network, + walk_speed, + ) self.assertEqual(len(csa_profile._all_connections), 3) pseudo_connection = csa_profile._all_connections[1] self.assertTrue(pseudo_connection.is_walk) @@ -77,15 +84,14 @@ def test_pseudo_connections(self): self.assertIsInstance(arrival_stop_profile, NodeProfileMultiObjective) self.assertIsInstance(departure_stop_profile, NodeProfileMultiObjective) self.assertIn(connection.departure_time, departure_stop_profile.dep_times_to_index) - if connection.arrival_stop_next_departure_time != float('inf'): - self.assertIn(connection.arrival_stop_next_departure_time, arrival_stop_profile.dep_times_to_index) - + if connection.arrival_stop_next_departure_time != float("inf"): + self.assertIn( + connection.arrival_stop_next_departure_time, + arrival_stop_profile.dep_times_to_index, + ) def test_pseudo_connections_with_transfer_margin(self): - event_list_raw_data = [ - (0, 1, 10, 20, "trip_6", 1), - (2, 3, 42, 50, "trip_5", 1) - ] + event_list_raw_data = [(0, 1, 10, 20, "trip_6", 1), (2, 3, 42, 50, "trip_5", 1)] transit_connections = list(map(lambda el: Connection(*el), event_list_raw_data)) walk_network = networkx.Graph() walk_network.add_edge(1, 2, {"d_walk": 10}) @@ -94,9 +100,15 @@ def test_pseudo_connections_with_transfer_margin(self): transfer_margin = 5 start_time = 0 end_time = 50 - csa_profile = MultiObjectivePseudoCSAProfiler(transit_connections, target_stop, - start_time, end_time, transfer_margin, - walk_network, walk_speed) + csa_profile = MultiObjectivePseudoCSAProfiler( + transit_connections, + target_stop, + start_time, + end_time, + transfer_margin, + walk_network, + walk_speed, + ) transfer_connection = csa_profile._all_connections[1] self.assertEqual(transfer_connection.arrival_stop, 2) self.assertEqual(transfer_connection.arrival_stop_next_departure_time, 42) @@ -106,30 +118,59 @@ def test_pseudo_connections_with_transfer_margin(self): self.assertEqual(transfer_connection.arrival_time, 42) def test_basics(self): - csa_profile = MultiObjectivePseudoCSAProfiler(self.transit_connections, self.target_stop, - self.start_time, self.end_time, self.transfer_margin, - self.walk_network, self.walk_speed) + csa_profile = MultiObjectivePseudoCSAProfiler( + self.transit_connections, + self.target_stop, + self.start_time, + self.end_time, + self.transfer_margin, + self.walk_network, + self.walk_speed, + ) csa_profile.run() stop_3_labels = csa_profile.stop_profiles[3].get_final_optimal_labels() self.assertEqual(len(stop_3_labels), 1) - self.assertIn(LabelTimeWithBoardingsCount(32, 35, n_boardings=1, first_leg_is_walk=False), stop_3_labels) + self.assertIn( + LabelTimeWithBoardingsCount(32, 35, n_boardings=1, first_leg_is_walk=False), + stop_3_labels, + ) stop_2_labels = csa_profile.stop_profiles[2].get_final_optimal_labels() self.assertEqual(len(stop_2_labels), 3) - self.assertIn(LabelTimeWithBoardingsCount(40, 50, n_boardings=1, first_leg_is_walk=False), stop_2_labels) - self.assertIn(LabelTimeWithBoardingsCount(25, 35, n_boardings=2, first_leg_is_walk=False), stop_2_labels) - self.assertIn(LabelTimeWithBoardingsCount(25, 45, n_boardings=1, first_leg_is_walk=False), stop_2_labels) - + self.assertIn( + LabelTimeWithBoardingsCount(40, 50, n_boardings=1, first_leg_is_walk=False), + stop_2_labels, + ) + self.assertIn( + LabelTimeWithBoardingsCount(25, 35, n_boardings=2, first_leg_is_walk=False), + stop_2_labels, + ) + self.assertIn( + LabelTimeWithBoardingsCount(25, 45, n_boardings=1, first_leg_is_walk=False), + stop_2_labels, + ) - stop_one_profile = csa_profile.stop_profiles[1] - stop_one_pareto_labels = stop_one_profile.get_final_optimal_labels() + # stop_one_profile = csa_profile.stop_profiles[1] + # stop_one_pareto_labels = stop_one_profile.get_final_optimal_labels() labels = list() # these should exist at least: - labels.append(LabelTimeWithBoardingsCount(departure_time=10, arrival_time_target=35, n_boardings=3, first_leg_is_walk=False)) - labels.append(LabelTimeWithBoardingsCount(departure_time=20, arrival_time_target=50, n_boardings=1, first_leg_is_walk=False)) - labels.append(LabelTimeWithBoardingsCount(departure_time=32, arrival_time_target=55, n_boardings=1, first_leg_is_walk=False)) + labels.append( + LabelTimeWithBoardingsCount( + departure_time=10, arrival_time_target=35, n_boardings=3, first_leg_is_walk=False + ) + ) + labels.append( + LabelTimeWithBoardingsCount( + departure_time=20, arrival_time_target=50, n_boardings=1, first_leg_is_walk=False + ) + ) + labels.append( + LabelTimeWithBoardingsCount( + departure_time=32, arrival_time_target=55, n_boardings=1, first_leg_is_walk=False + ) + ) def test_multiple_targets(self): event_list_raw_data = [ @@ -145,9 +186,15 @@ def test_multiple_targets(self): start_time = 0 end_time = 60 - csa_profile = MultiObjectivePseudoCSAProfiler(transit_connections, targets, - start_time, end_time, transfer_margin, - walk_network, walk_speed) + csa_profile = MultiObjectivePseudoCSAProfiler( + transit_connections, + targets, + start_time, + end_time, + transfer_margin, + walk_network, + walk_speed, + ) csa_profile.run() source_stop_profile = csa_profile.stop_profiles[source_stop] final_labels = source_stop_profile.get_final_optimal_labels() @@ -168,9 +215,15 @@ def test_simple(self): start_time = 0 end_time = 50 - csa_profile = MultiObjectivePseudoCSAProfiler(transit_connections, target_stop, - start_time, end_time, transfer_margin, - walk_network, walk_speed) + csa_profile = MultiObjectivePseudoCSAProfiler( + transit_connections, + target_stop, + start_time, + end_time, + transfer_margin, + walk_network, + walk_speed, + ) csa_profile.run() source_stop_profile = csa_profile.stop_profiles[source_stop] self.assertTrue(source_stop_profile._finalized) @@ -179,20 +232,16 @@ def test_simple(self): source_stop_labels = source_stop_profile.get_final_optimal_labels() labels = list() - labels.append(LabelTimeWithBoardingsCount(departure_time=20, - arrival_time_target=50, - n_boardings=1, - first_leg_is_walk=True)) - - self._assert_label_sets_equal( - labels, - source_stop_labels + labels.append( + LabelTimeWithBoardingsCount( + departure_time=20, arrival_time_target=50, n_boardings=1, first_leg_is_walk=True + ) ) + self._assert_label_sets_equal(labels, source_stop_labels) + def test_last_leg_is_walk(self): - event_list_raw_data = [ - (0, 1, 0, 10, "trip_1", 1) - ] + event_list_raw_data = [(0, 1, 0, 10, "trip_1", 1)] transit_connections = list(map(lambda el: Connection(*el), event_list_raw_data)) walk_network = networkx.Graph() walk_network.add_edge(1, 2, {"d_walk": 20}) @@ -204,19 +253,27 @@ def test_last_leg_is_walk(self): start_time = 0 end_time = 50 labels = list() - labels.append(LabelTimeWithBoardingsCount(departure_time=0, arrival_time_target=30, n_boardings=1, first_leg_is_walk=False)) + labels.append( + LabelTimeWithBoardingsCount( + departure_time=0, arrival_time_target=30, n_boardings=1, first_leg_is_walk=False + ) + ) - csa_profile = MultiObjectivePseudoCSAProfiler(transit_connections, target_stop, - start_time, end_time, transfer_margin, - walk_network, walk_speed) + csa_profile = MultiObjectivePseudoCSAProfiler( + transit_connections, + target_stop, + start_time, + end_time, + transfer_margin, + walk_network, + walk_speed, + ) csa_profile.run() found_tuples = csa_profile.stop_profiles[source_stop].get_final_optimal_labels() self._assert_label_sets_equal(found_tuples, labels) def test_walk_is_faster_than_by_trip(self): - event_list_raw_data = [ - (0, 1, 0, 10, "trip_1", 1) - ] + event_list_raw_data = [(0, 1, 0, 10, "trip_1", 1)] transit_connections = list(map(lambda el: Connection(*el), event_list_raw_data)) walk_speed = 0.5 source_stop = 0 @@ -227,12 +284,20 @@ def test_walk_is_faster_than_by_trip(self): walk_network = networkx.Graph() walk_network.add_edge(0, 1, {"d_walk": 1}) - csa_profile = MultiObjectivePseudoCSAProfiler(transit_connections, target_stop, - start_time, end_time, transfer_margin, - walk_network, walk_speed) + csa_profile = MultiObjectivePseudoCSAProfiler( + transit_connections, + target_stop, + start_time, + end_time, + transfer_margin, + walk_network, + walk_speed, + ) csa_profile.run() source_profile = csa_profile.stop_profiles[source_stop] - self.assertEqual(min_arrival_time_target(source_profile.evaluate(0, first_leg_can_be_walk=True)), 2) + self.assertEqual( + min_arrival_time_target(source_profile.evaluate(0, first_leg_can_be_walk=True)), 2 + ) found_tuples = source_profile.get_final_optimal_labels() self.assertEqual(len(found_tuples), 0) @@ -247,7 +312,7 @@ def test_no_multiple_walks(self): (1, 2, 5, 6, "trip_7", 1), (2, 1, 5, 6, "trip_8", 1), (1, 2, 2, 3, "trip_7", 2), - (2, 1, 2, 3, "trip_8", 2) + (2, 1, 2, 3, "trip_8", 2), ] transit_connections = list(map(lambda el: Connection(*el), event_list_raw_data)) walk_network = networkx.Graph() @@ -258,9 +323,9 @@ def test_no_multiple_walks(self): start_time = 0 end_time = 50 - csa_profile = MultiObjectivePseudoCSAProfiler(transit_connections, 2, - start_time, end_time, transfer_margin, - walk_network, walk_speed) + csa_profile = MultiObjectivePseudoCSAProfiler( + transit_connections, 2, start_time, end_time, transfer_margin, walk_network, walk_speed + ) csa_profile.run() source_profile = csa_profile.stop_profiles[0] print(source_profile.get_final_optimal_labels()) @@ -268,9 +333,7 @@ def test_no_multiple_walks(self): self.assertGreater(label.n_boardings, 0) def test_target_node_not_in_walk_network(self): - event_list_raw_data = [ - (0, 1, 0, 10, "trip_1", 1) - ] + event_list_raw_data = [(0, 1, 0, 10, "trip_1", 1)] transit_connections = list(map(lambda el: Connection(*el), event_list_raw_data)) walk_speed = 2 source_stop = 0 @@ -280,9 +343,15 @@ def test_target_node_not_in_walk_network(self): end_time = 50 walk_network = networkx.Graph() - csa_profile = MultiObjectivePseudoCSAProfiler(transit_connections, target_stop, - start_time, end_time, transfer_margin, - walk_network, walk_speed) + csa_profile = MultiObjectivePseudoCSAProfiler( + transit_connections, + target_stop, + start_time, + end_time, + transfer_margin, + walk_network, + walk_speed, + ) csa_profile.run() source_profile = csa_profile.stop_profiles[source_stop] self.assertEqual(min_arrival_time_target(source_profile.evaluate(0, 0)), 10) @@ -293,7 +362,7 @@ def test_pareto_optimality(self): event_list_raw_data = [ (0, 2, 0, 10, "trip_1", 1), (0, 1, 2, 5, "trip_2", 1), - (1, 2, 5, 8, "trip_3", 1) + (1, 2, 5, 8, "trip_3", 1), ] transit_connections = list(map(lambda el: Connection(*el), event_list_raw_data)) walk_speed = 2 @@ -303,16 +372,26 @@ def test_pareto_optimality(self): start_time = 0 end_time = 20 walk_network = networkx.Graph() - csa_profile = MultiObjectivePseudoCSAProfiler(transit_connections, target_stop, - start_time, end_time, transfer_margin, - walk_network, walk_speed) + csa_profile = MultiObjectivePseudoCSAProfiler( + transit_connections, + target_stop, + start_time, + end_time, + transfer_margin, + walk_network, + walk_speed, + ) csa_profile.run() source_profile = csa_profile.stop_profiles[source_stop] self.assertEqual(min_arrival_time_target(source_profile.evaluate(0, 0)), 8) found_labels = source_profile.get_final_optimal_labels() labels_should_be = list() - labels_should_be.append(LabelTimeWithBoardingsCount(0, 10, n_boardings=1, first_leg_is_walk=False)) - labels_should_be.append(LabelTimeWithBoardingsCount(2, 8, n_boardings=2, first_leg_is_walk=False)) + labels_should_be.append( + LabelTimeWithBoardingsCount(0, 10, n_boardings=1, first_leg_is_walk=False) + ) + labels_should_be.append( + LabelTimeWithBoardingsCount(2, 8, n_boardings=2, first_leg_is_walk=False) + ) self._assert_label_sets_equal(found_labels, labels_should_be) def test_transfer_margin(self): @@ -327,9 +406,15 @@ def test_transfer_margin(self): ] # case without any transfer margin transfer_margin = 0 - csa_profile = MultiObjectivePseudoCSAProfiler(transit_connections, target_stop, - start_time, end_time, transfer_margin, - networkx.Graph(), walk_speed) + csa_profile = MultiObjectivePseudoCSAProfiler( + transit_connections, + target_stop, + start_time, + end_time, + transfer_margin, + networkx.Graph(), + walk_speed, + ) csa_profile.run() stop_profile_1 = csa_profile.stop_profiles[1] stop_profile_3 = csa_profile.stop_profiles[3] @@ -338,9 +423,15 @@ def test_transfer_margin(self): # case with transfer margin transfer_margin = 1 - csa_profile = MultiObjectivePseudoCSAProfiler(transit_connections, target_stop, - start_time, end_time, transfer_margin, - networkx.Graph(), walk_speed) + csa_profile = MultiObjectivePseudoCSAProfiler( + transit_connections, + target_stop, + start_time, + end_time, + transfer_margin, + networkx.Graph(), + walk_speed, + ) csa_profile.run() stop_profile_3 = csa_profile.stop_profiles[3] stop_profile_1 = csa_profile.stop_profiles[1] @@ -356,13 +447,19 @@ def test_possible_transfer_margin_bug_with_multiple_arrivals(self): transit_connections = [ Connection(0, 1, 100, 101, "trip_0", 1), Connection(4, 1, 102, 104, "trip_1", 1), - Connection(2, 3, 106, 108, "trip_2", 1) + Connection(2, 3, 106, 108, "trip_2", 1), ] walk_network = networkx.Graph() walk_network.add_edge(1, 2, {"d_walk": 1}) - csa_profile = MultiObjectivePseudoCSAProfiler(transit_connections, target_stop, - start_time, end_time, transfer_margin, - walk_network, walk_speed) + csa_profile = MultiObjectivePseudoCSAProfiler( + transit_connections, + target_stop, + start_time, + end_time, + transfer_margin, + walk_network, + walk_speed, + ) csa_profile.run() profile = csa_profile.stop_profiles[4] self.assertEqual(len(profile.get_final_optimal_labels()), 0) @@ -391,31 +488,48 @@ def test_transfer_margin_with_walk(self): journey_dep_times = [1030, 1020, 1010, 1000, 1030] for transfer_margin, dep_time in zip(transfer_margins, journey_dep_times): - csa_profile = MultiObjectivePseudoCSAProfiler(transit_connections, target_stop, - start_time, end_time, transfer_margin, - walk_network, walk_speed) + csa_profile = MultiObjectivePseudoCSAProfiler( + transit_connections, + target_stop, + start_time, + end_time, + transfer_margin, + walk_network, + walk_speed, + ) csa_profile.run() profile = csa_profile.stop_profiles[0] - self.assertEqual(len(profile.get_final_optimal_labels()), 1, "transfer_margin=" + str(transfer_margin)) + self.assertEqual( + len(profile.get_final_optimal_labels()), + 1, + "transfer_margin=" + str(transfer_margin), + ) label = profile.get_final_optimal_labels()[0] - self.assertEqual(label.departure_time, dep_time, "transfer_margin=" + str(transfer_margin)) + self.assertEqual( + label.departure_time, dep_time, "transfer_margin=" + str(transfer_margin) + ) def test_basics_no_transfer_tracking(self): csa_profile = MultiObjectivePseudoCSAProfiler( - self.transit_connections, self.target_stop, - self.start_time, self.end_time, self.transfer_margin, - self.walk_network, self.walk_speed, track_vehicle_legs=False + self.transit_connections, + self.target_stop, + self.start_time, + self.end_time, + self.transfer_margin, + self.walk_network, + self.walk_speed, + track_vehicle_legs=False, ) csa_profile.run() stop_3_pareto_tuples = csa_profile.stop_profiles[3].get_final_optimal_labels() self.assertEqual(len(stop_3_pareto_tuples), 1) - self.assertIn(LabelTime(32., 35.), stop_3_pareto_tuples) + self.assertIn(LabelTime(32.0, 35.0), stop_3_pareto_tuples) stop_2_pareto_tuples = csa_profile.stop_profiles[2].get_final_optimal_labels() self.assertEqual(len(stop_2_pareto_tuples), 2) - self.assertIn(LabelTime(40., 50.), stop_2_pareto_tuples) - self.assertIn(LabelTime(25., 35.), stop_2_pareto_tuples) + self.assertIn(LabelTime(40.0, 50.0), stop_2_pareto_tuples) + self.assertIn(LabelTime(25.0, 35.0), stop_2_pareto_tuples) source_stop_profile = csa_profile.stop_profiles[1] source_stop_pareto_optimal_tuples = source_stop_profile.get_final_optimal_labels() @@ -425,10 +539,7 @@ def test_basics_no_transfer_tracking(self): pareto_tuples.append(LabelTime(departure_time=20, arrival_time_target=50)) pareto_tuples.append(LabelTime(departure_time=32, arrival_time_target=55)) - self._assert_label_sets_equal( - pareto_tuples, - source_stop_pareto_optimal_tuples - ) + self._assert_label_sets_equal(pareto_tuples, source_stop_pareto_optimal_tuples) def test_transfers_only(self): event_list_raw_data = [ @@ -445,23 +556,25 @@ def test_transfers_only(self): start_time = 0 end_time = 50 - csa_profile = MultiObjectivePseudoCSAProfiler(transit_connections, target_stop, - start_time, end_time, transfer_margin, - walk_network, walk_speed, track_time=False) + csa_profile = MultiObjectivePseudoCSAProfiler( + transit_connections, + target_stop, + start_time, + end_time, + transfer_margin, + walk_network, + walk_speed, + track_time=False, + ) csa_profile.run() - stop_to_n_boardings = { - 2: 1, - 7: 2, - 3: 0 - } + stop_to_n_boardings = {2: 1, 7: 2, 3: 0} for stop, n_veh_legs in stop_to_n_boardings.items(): labels = csa_profile.stop_profiles[stop].get_final_optimal_labels() self.assertEqual(len(labels), 1) self.assertEqual(labels[0].n_boardings, n_veh_legs) - def test_reset(self): walk_speed = 1 target_stop = 2 @@ -471,11 +584,17 @@ def test_reset(self): transit_connections = [ Connection(0, 1, 40, 50, "trip_1", 1), Connection(1, 2, 55, 60, "trip_1", 1), - Connection(3, 1, 40, 60, "trip_2", 1) + Connection(3, 1, 40, 60, "trip_2", 1), ] - csa_profile = MultiObjectivePseudoCSAProfiler(transit_connections, target_stop, - start_time, end_time, transfer_margin, - networkx.Graph(), walk_speed) + csa_profile = MultiObjectivePseudoCSAProfiler( + transit_connections, + target_stop, + start_time, + end_time, + transfer_margin, + networkx.Graph(), + walk_speed, + ) csa_profile.run() nodes = [0, 1, 2, 3] label_counts = [1, 1, 0, 0] @@ -499,21 +618,28 @@ def test_550_problem(self): # There used to be a problem when working with real unixtimes (c-side floating point number problems), # this test is one check for that event_data = StringIO( - "from_stop_I,to_stop_I,dep_time_ut,arr_time_ut,route_type,route_id,trip_I,seq\n" + - "2198,2247,1475530740,1475530860,3,2550,158249,36\n" + - "2247,2177,1475530860,1475530980,3,2550,158249,37\n") + "from_stop_I,to_stop_I,dep_time_ut,arr_time_ut,route_type,route_id,trip_I,seq\n" + + "2198,2247,1475530740,1475530860,3,2550,158249,36\n" + + "2247,2177,1475530860,1475530980,3,2550,158249,37\n" + ) import pandas as pd + events = pd.read_csv(event_data) events.sort_values("dep_time_ut", ascending=False, inplace=True) connections = [ - Connection(int(e.from_stop_I), int(e.to_stop_I), int(e.dep_time_ut), int(e.arr_time_ut), - int(e.trip_I), - int(e.seq)) + Connection( + int(e.from_stop_I), + int(e.to_stop_I), + int(e.dep_time_ut), + int(e.arr_time_ut), + int(e.trip_I), + int(e.seq), + ) for e in events.itertuples() ] - csa_profiler = MultiObjectivePseudoCSAProfiler(connections, 2177, - 0, 1475530860*10, 0, - networkx.Graph(), 0) + csa_profiler = MultiObjectivePseudoCSAProfiler( + connections, 2177, 0, 1475530860 * 10, 0, networkx.Graph(), 0 + ) csa_profiler.run() @@ -536,15 +662,21 @@ def test_transfer_on_same_stop_with_multiple_departures(self): Connection(4, 1, 50, 60, "trip_2", 1), Connection(4, 2, 50, 60, "trip_3", 1), Connection(4, 3, 50, 60, "trip_4", 1), - Connection(4, target_stop, 70, 100, "trip_5", 1) + Connection(4, target_stop, 70, 100, "trip_5", 1), ] - csa_profiler = MultiObjectivePseudoCSAProfiler(transit_connections, target_stop, - start_time, end_time, transfer_margin, - networkx.Graph(), walk_speed) + csa_profiler = MultiObjectivePseudoCSAProfiler( + transit_connections, + target_stop, + start_time, + end_time, + transfer_margin, + networkx.Graph(), + walk_speed, + ) csa_profiler.run() profiles = csa_profiler.stop_profiles - assert(profiles[0].get_final_optimal_labels()[0]) - assert(len(profiles[0].get_final_optimal_labels()) > 0) + assert profiles[0].get_final_optimal_labels()[0] + assert len(profiles[0].get_final_optimal_labels()) > 0 def test_transfer_connections_do_not_affect_transfers(self): walk_speed = 1000 @@ -557,20 +689,25 @@ def test_transfer_connections_do_not_affect_transfers(self): Connection(3, 4, 45, 50, "trip_2", 1), Connection(4, 3, 45, 50, "trip_3", 1), Connection(5, 3, 45, 50, "trip_4", 1), - Connection(1, target_stop, 70, 100, "trip_5", 1) + Connection(1, target_stop, 70, 100, "trip_5", 1), ] walk_network = networkx.Graph() walk_network.add_edge(1, 3, {"d_walk": 1}) walk_network.add_edge(1, 4, {"d_walk": 1}) walk_network.add_edge(1, 5, {"d_walk": 1}) - csa_profiler = MultiObjectivePseudoCSAProfiler(transit_connections, target_stop, - start_time, end_time, transfer_margin, - walk_network, walk_speed) + csa_profiler = MultiObjectivePseudoCSAProfiler( + transit_connections, + target_stop, + start_time, + end_time, + transfer_margin, + walk_network, + walk_speed, + ) csa_profiler.run() profiles = csa_profiler.stop_profiles - assert(profiles[0].get_final_optimal_labels()[0]) - assert(len(profiles[0].get_final_optimal_labels()) > 0) - + assert profiles[0].get_final_optimal_labels()[0] + assert len(profiles[0].get_final_optimal_labels()) > 0 def test_transfer_connections_do_not_affect_transfers2(self): walk_speed = 1 @@ -581,17 +718,23 @@ def test_transfer_connections_do_not_affect_transfers2(self): transit_connections = [ Connection(3, 0, 10, 11, "trip_1", 1), Connection(2, 1, 5, 6, "trip_2", 1), - Connection(4, 3, 0, 1, "trip_3", 1) + Connection(4, 3, 0, 1, "trip_3", 1), ] walk_network = networkx.Graph() walk_network.add_edge(2, 3, {"d_walk": 1}) walk_network.add_edge(1, 0, {"d_walk": 1}) - csa_profiler = MultiObjectivePseudoCSAProfiler(transit_connections, target_stop, - start_time, end_time, transfer_margin, - walk_network, walk_speed) + csa_profiler = MultiObjectivePseudoCSAProfiler( + transit_connections, + target_stop, + start_time, + end_time, + transfer_margin, + walk_network, + walk_speed, + ) csa_profiler.run() profiles = csa_profiler.stop_profiles - assert(len(profiles[4].get_final_optimal_labels()) == 1) + assert len(profiles[4].get_final_optimal_labels()) == 1 optimal_label = profiles[4].get_final_optimal_labels()[0] self.assertEqual(optimal_label.departure_time, 0) self.assertEqual(optimal_label.arrival_time_target, 7) @@ -608,21 +751,27 @@ def test_transfer_connections_do_not_affect_transfers3(self): Connection(2, 1, 5, 6, "t2", 1), Connection(7, 2, 3, 4, "tX", 1), Connection(5, 6, 2, 3, "--", 1), - Connection(4, 3, 0, 1, "t3", 1) + Connection(4, 3, 0, 1, "t3", 1), ] walk_network = networkx.Graph() walk_network.add_edge(7, 3, {"d_walk": 1}) walk_network.add_edge(1, 0, {"d_walk": 1}) walk_network.add_edge(5, 3, {"d_walk": 1}) - csa_profiler = MultiObjectivePseudoCSAProfiler(transit_connections, target_stop, - start_time, end_time, transfer_margin, - walk_network, walk_speed) + csa_profiler = MultiObjectivePseudoCSAProfiler( + transit_connections, + target_stop, + start_time, + end_time, + transfer_margin, + walk_network, + walk_speed, + ) csa_profiler.run() profiles = csa_profiler.stop_profiles print(profiles[4].get_final_optimal_labels()[0]) optimal_labels = profiles[4].get_final_optimal_labels() - assert(len(optimal_labels) == 2) + assert len(optimal_labels) == 2 boardings_to_arr_time = {} for label in optimal_labels: boardings_to_arr_time[label.n_boardings] = label.arrival_time_target @@ -641,9 +790,16 @@ def test_stored_route(self): # - test with multiple targets # - test with continuing route # - test that timestamps for label and the connection objects match - csa_profile = MultiObjectivePseudoCSAProfiler(self.transit_connections, self.target_stop, - self.start_time, self.end_time, self.transfer_margin, - self.walk_network, self.walk_speed, track_route=True) + csa_profile = MultiObjectivePseudoCSAProfiler( + self.transit_connections, + self.target_stop, + self.start_time, + self.end_time, + self.transfer_margin, + self.walk_network, + self.walk_speed, + track_route=True, + ) csa_profile.run() for stop, profile in csa_profile.stop_profiles.items(): for bag in profile._label_bags: @@ -674,9 +830,7 @@ def test_stored_route(self): prev_arr_node = arr_node def test_target_self_loops(self): - event_list_raw_data = [ - (3, 1, 30, 40, "trip_3", 1) - ] + event_list_raw_data = [(3, 1, 30, 40, "trip_3", 1)] transit_connections = list(map(lambda el: Connection(*el), event_list_raw_data)) walk_network = networkx.Graph() walk_network.add_edge(1, 3, {"d_walk": 11}) @@ -687,10 +841,18 @@ def test_target_self_loops(self): end_time = 50 print(walk_network.edges()) print(transit_connections) - csa_profile = MultiObjectivePseudoCSAProfiler(transit_connections, target_stop, - start_time, end_time, transfer_margin, - walk_network, walk_speed, track_vehicle_legs=True, - track_time=True, track_route=True) + csa_profile = MultiObjectivePseudoCSAProfiler( + transit_connections, + target_stop, + start_time, + end_time, + transfer_margin, + walk_network, + walk_speed, + track_vehicle_legs=True, + track_time=True, + track_route=True, + ) csa_profile.run() for stop, profile in csa_profile.stop_profiles.items(): if stop == target_stop: @@ -717,7 +879,6 @@ def unpack_route_from_labels(cur_label): (1, 2, 0, 10, "trip_1", 1), (2, 3, 10, 20, "trip_1", 1), (4, 5, 30, 40, "trip_2", 1), - ] transit_connections = list(map(lambda el: Connection(*el), event_list_raw_data)) walk_network = networkx.Graph() @@ -729,22 +890,44 @@ def unpack_route_from_labels(cur_label): start_time = 0 end_time = 50 - csa_profile = MultiObjectivePseudoCSAProfiler(transit_connections, target_stop, - start_time, end_time, transfer_margin, - walk_network, walk_speed, track_vehicle_legs=False, - track_time=True, track_route=True) + csa_profile = MultiObjectivePseudoCSAProfiler( + transit_connections, + target_stop, + start_time, + end_time, + transfer_margin, + walk_network, + walk_speed, + track_vehicle_legs=False, + track_time=True, + track_route=True, + ) csa_profile.run() for stop, profile in csa_profile.stop_profiles.items(): for label_bag in profile._label_bags: for label in label_bag: - print('origin:', stop, 'n_boardings/movement_duration:', label.movement_duration, 'route:', unpack_route_from_labels(label)) - print('optimal labels:') + print( + "origin:", + stop, + "n_boardings/movement_duration:", + label.movement_duration, + "route:", + unpack_route_from_labels(label), + ) + print("optimal labels:") for stop, profile in csa_profile.stop_profiles.items(): for label in profile.get_final_optimal_labels(): - print('origin:', stop, 'n_boardings/movement_duration:', label.movement_duration, 'route:', unpack_route_from_labels(label)) - #if stop == 1: - #assert 3 not in unpack_route_from_labels(label) + print( + "origin:", + stop, + "n_boardings/movement_duration:", + label.movement_duration, + "route:", + unpack_route_from_labels(label), + ) + # if stop == 1: + # assert 3 not in unpack_route_from_labels(label) # print('origin:', stop, 'n_boardings:', label.n_boardings, 'route:', unpack_route_from_labels(label)) def test_journeys_using_movement_duration_last_stop_walk(self): @@ -770,7 +953,6 @@ def unpack_route_from_labels(cur_label): (1, 2, 0, 10, "trip_1", 1), (2, 3, 10, 20, "trip_2", 1), (4, 5, 30, 40, "trip_3", 1), - ] transit_connections = list(map(lambda el: Connection(*el), event_list_raw_data)) walk_network = networkx.Graph() @@ -783,34 +965,48 @@ def unpack_route_from_labels(cur_label): start_time = 0 end_time = 50 - csa_profile = MultiObjectivePseudoCSAProfiler(transit_connections, target_stop, - start_time, end_time, transfer_margin, - walk_network, walk_speed, track_vehicle_legs=False, - track_time=True, track_route=True) + csa_profile = MultiObjectivePseudoCSAProfiler( + transit_connections, + target_stop, + start_time, + end_time, + transfer_margin, + walk_network, + walk_speed, + track_vehicle_legs=False, + track_time=True, + track_route=True, + ) csa_profile.run() for stop, profile in csa_profile.stop_profiles.items(): for label_bag in profile._label_bags: for label in label_bag: - print('origin:', stop, - 'n_boardings/movement_duration:', label.movement_duration, - 'route:', unpack_route_from_labels(label)) - print('optimal labels:') + print( + "origin:", + stop, + "n_boardings/movement_duration:", + label.movement_duration, + "route:", + unpack_route_from_labels(label), + ) + print("optimal labels:") for stop, profile in csa_profile.stop_profiles.items(): for label in profile.get_final_optimal_labels(): - print('origin:', stop, - 'n_boardings/movement_duration:', label.movement_duration, - 'route:', unpack_route_from_labels(label)) - #if stop == 1: - #assert 3 not in unpack_route_from_labels(label) + print( + "origin:", + stop, + "n_boardings/movement_duration:", + label.movement_duration, + "route:", + unpack_route_from_labels(label), + ) + # if stop == 1: + # assert 3 not in unpack_route_from_labels(label) # print('origin:', stop, 'n_boardings:', label.n_boardings, 'route:', unpack_route_from_labels(label)) - def test_zero_length_journeys_potential_bug_1(self): - event_list_raw_data = [ - (0, 1, 0, 0, "trip_1", 0), - (1, 2, 0, 0, "trip_1", 1) - ] + event_list_raw_data = [(0, 1, 0, 0, "trip_1", 0), (1, 2, 0, 0, "trip_1", 1)] transit_connections = list(map(lambda el: Connection(*el), event_list_raw_data)) walk_network = networkx.Graph() walk_network.add_edge(10, 1, {"d_walk": 20}) @@ -820,19 +1016,29 @@ def test_zero_length_journeys_potential_bug_1(self): transfer_margin = 0 start_time = 0 end_time = 50 - csa_profiler = MultiObjectivePseudoCSAProfiler(transit_connections, target_stop, - start_time, end_time, transfer_margin, - walk_network, walk_speed, - track_vehicle_legs=True, - track_time=True, - track_route=True) + csa_profiler = MultiObjectivePseudoCSAProfiler( + transit_connections, + target_stop, + start_time, + end_time, + transfer_margin, + walk_network, + walk_speed, + track_vehicle_legs=True, + track_time=True, + track_route=True, + ) csa_profiler.run() stop_profile_1 = csa_profiler._stop_profiles[1] - all_labels_stop_profile_1 = [label for label_bag in stop_profile_1._label_bags for label in label_bag] + all_labels_stop_profile_1 = [ + label for label_bag in stop_profile_1._label_bags for label in label_bag + ] for label in all_labels_stop_profile_1: - self.assertLess(label.n_boardings, 1, "There should at most a walking label when going from 11 to 1 at any " - "point in time, now one label has " + str(label.n_boardings) + - " boardings" + self.assertLess( + label.n_boardings, + 1, + "There should at most a walking label when going from 11 to 1 at any " + "point in time, now one label has " + str(label.n_boardings) + " boardings", ) def test_zero_length_journeys_potential_bug(self): @@ -844,7 +1050,7 @@ def test_zero_length_journeys_potential_bug(self): event_list_raw_data = [ (s, a, 0, 0, "trip_1", 1), (a, b, 0, 0, "trip_1", 2), - (b, t, 1, 2, "trip_2", 0) + (b, t, 1, 2, "trip_2", 0), ] transit_connections = list(map(lambda el: Connection(*el), event_list_raw_data)) walk_network = networkx.Graph() @@ -853,15 +1059,20 @@ def test_zero_length_journeys_potential_bug(self): transfer_margin = 0 start_time = 0 end_time = 50 - csa_profiler = MultiObjectivePseudoCSAProfiler(transit_connections, target_stop, - start_time, end_time, transfer_margin, - walk_network, walk_speed, - track_vehicle_legs=True, - track_time=True, - track_route=True) + csa_profiler = MultiObjectivePseudoCSAProfiler( + transit_connections, + target_stop, + start_time, + end_time, + transfer_margin, + walk_network, + walk_speed, + track_vehicle_legs=True, + track_time=True, + track_route=True, + ) csa_profiler.run() stop_profile_a_labels = csa_profiler.stop_profiles[a].get_final_optimal_labels() stop_profile_s_labels = csa_profiler.stop_profiles[s].get_final_optimal_labels() self.assertEqual(len(stop_profile_a_labels), 1) self.assertEqual(len(stop_profile_s_labels), 1) - diff --git a/gtfspy/routing/test/test_node_profile_analyzer_time.py b/gtfspy/routing/test/test_node_profile_analyzer_time.py index 381a6dd..a96dc02 100644 --- a/gtfspy/routing/test/test_node_profile_analyzer_time.py +++ b/gtfspy/routing/test/test_node_profile_analyzer_time.py @@ -10,26 +10,21 @@ class TestNodeProfileAnalyzerTime(TestCase): - def test_trip_duration_statistics_empty_profile(self): profile = NodeProfileSimple() analyzer = NodeProfileAnalyzerTime.from_profile(profile, 0, 10) - self.assertEqual(float('inf'), analyzer.max_trip_duration()) - self.assertEqual(float('inf'), analyzer.min_trip_duration()) - self.assertEqual(float('inf'), analyzer.mean_trip_duration()) - self.assertEqual(float('inf'), analyzer.median_trip_duration()) + self.assertEqual(float("inf"), analyzer.max_trip_duration()) + self.assertEqual(float("inf"), analyzer.min_trip_duration()) + self.assertEqual(float("inf"), analyzer.mean_trip_duration()) + self.assertEqual(float("inf"), analyzer.median_trip_duration()) - self.assertEqual(float('inf'), analyzer.max_temporal_distance()) - self.assertEqual(float('inf'), analyzer.min_temporal_distance()) - self.assertEqual(float('inf'), analyzer.mean_temporal_distance()) - self.assertEqual(float('inf'), analyzer.median_temporal_distance()) + self.assertEqual(float("inf"), analyzer.max_temporal_distance()) + self.assertEqual(float("inf"), analyzer.min_temporal_distance()) + self.assertEqual(float("inf"), analyzer.mean_temporal_distance()) + self.assertEqual(float("inf"), analyzer.median_temporal_distance()) def test_trip_duration_statistics_simple(self): - pairs = [ - LabelTimeSimple(1.0, 2.0), - LabelTimeSimple(2.0, 4.0), - LabelTimeSimple(4.0, 5.0) - ] + pairs = [LabelTimeSimple(1.0, 2.0), LabelTimeSimple(2.0, 4.0), LabelTimeSimple(4.0, 5.0)] profile = NodeProfileSimple() for pair in pairs: profile.update_pareto_optimal_tuples(pair) @@ -40,11 +35,7 @@ def test_trip_duration_statistics_simple(self): self.assertAlmostEqual(1, analyzer.median_trip_duration()) def test_temporal_distance_statistics(self): - pairs = [ - LabelTimeSimple(1, 2), - LabelTimeSimple(2, 4), - LabelTimeSimple(4, 5) - ] + pairs = [LabelTimeSimple(1, 2), LabelTimeSimple(2, 4), LabelTimeSimple(4, 5)] profile = NodeProfileSimple() for pair in pairs: profile.update_pareto_optimal_tuples(pair) @@ -52,7 +43,9 @@ def test_temporal_distance_statistics(self): analyzer = NodeProfileAnalyzerTime.from_profile(profile, 0, 3) self.assertAlmostEqual(4 - 1, analyzer.max_temporal_distance()) # 1 -wait-> 2 -travel->4 self.assertAlmostEqual(1, analyzer.min_temporal_distance()) - self.assertAlmostEqual((1.5 * 1 + 2.5 * 1 + 2.5 * 1) / 3., analyzer.mean_temporal_distance()) + self.assertAlmostEqual( + (1.5 * 1 + 2.5 * 1 + 2.5 * 1) / 3.0, analyzer.mean_temporal_distance() + ) self.assertAlmostEqual(2.25, analyzer.median_temporal_distance()) def test_temporal_distances_no_transit_trips_within_range(self): @@ -65,14 +58,16 @@ def test_temporal_distances_no_transit_trips_within_range(self): analyzer = NodeProfileAnalyzerTime.from_profile(profile, 0, 10) self.assertAlmostEqual(5, analyzer.max_temporal_distance()) self.assertAlmostEqual(2, analyzer.min_temporal_distance()) - self.assertAlmostEqual((7 * 5 + 3 * (5 + 2) / 2.) / 10.0, analyzer.mean_temporal_distance()) + self.assertAlmostEqual( + (7 * 5 + 3 * (5 + 2) / 2.0) / 10.0, analyzer.mean_temporal_distance() + ) self.assertAlmostEqual(5, analyzer.median_temporal_distance()) def test_temporal_distances_no_transit_trips_within_range_and_no_walk(self): pairs = [ LabelTimeSimple(departure_time=11, arrival_time_target=12), ] - profile = NodeProfileSimple(walk_to_target_duration=float('inf')) + profile = NodeProfileSimple(walk_to_target_duration=float("inf")) for pair in pairs: profile.update_pareto_optimal_tuples(pair) analyzer = NodeProfileAnalyzerTime.from_profile(profile, 0, 10) @@ -87,13 +82,13 @@ def test_time_offset(self): labels = [ LabelTimeSimple(departure_time=7248 + offset, arrival_time_target=14160 + offset), ] - profile = NodeProfileSimple(walk_to_target_duration=float('inf')) + profile = NodeProfileSimple(walk_to_target_duration=float("inf")) for label in labels: profile.update_pareto_optimal_tuples(label) analyzer = NodeProfileAnalyzerTime.from_profile(profile, 0 + offset, 7200 + offset) max_distances.append(analyzer.max_temporal_distance()) max_distances = numpy.array(max_distances) - assert((max_distances == max_distances[0]).all()) + assert (max_distances == max_distances[0]).all() # self.assertAlmostEqual(12, analyzer.max_temporal_distance()) # self.assertAlmostEqual(2, analyzer.min_temporal_distance()) # self.assertAlmostEqual((12 + 2) / 2.0, analyzer.mean_temporal_distance()) @@ -101,7 +96,9 @@ def test_time_offset(self): def test_temporal_distance_statistics_with_walk(self): pt1 = LabelTimeSimple(departure_time=1, arrival_time_target=2) - pt2 = LabelTimeSimple(departure_time=4, arrival_time_target=5) # not taken into account by the analyzer + pt2 = LabelTimeSimple( + departure_time=4, arrival_time_target=5 + ) # not taken into account by the analyzer profile = NodeProfileSimple(1.5) assert isinstance(pt1, LabelTimeSimple), type(pt1) profile.update_pareto_optimal_tuples(pt1) @@ -109,7 +106,7 @@ def test_temporal_distance_statistics_with_walk(self): analyzer = NodeProfileAnalyzerTime.from_profile(profile, 0, 3) self.assertAlmostEqual(1.5, analyzer.max_temporal_distance()) # 1 -wait-> 2 -travel->4 self.assertAlmostEqual(1, analyzer.min_temporal_distance()) - self.assertAlmostEqual((2.5 * 1.5 + 0.5 * 1.25) / 3., analyzer.mean_temporal_distance()) + self.assertAlmostEqual((2.5 * 1.5 + 0.5 * 1.25) / 3.0, analyzer.mean_temporal_distance()) self.assertAlmostEqual(1.5, analyzer.median_temporal_distance()) def test_temporal_distance_statistics_with_walk2(self): @@ -133,7 +130,11 @@ def test_temporal_distance_pdf_with_walk(self): self.assertEqual(len(analyzer.profile_block_analyzer._temporal_distance_pdf()), 3) - split_points, densities, delta_peaks = analyzer.profile_block_analyzer._temporal_distance_pdf() + ( + split_points, + densities, + delta_peaks, + ) = analyzer.profile_block_analyzer._temporal_distance_pdf() self.assertEqual(len(split_points), 2) self.assertEqual(split_points[0], 20) self.assertEqual(split_points[1], 25) @@ -144,7 +145,6 @@ def test_temporal_distance_pdf_with_walk(self): self.assertIn(25, delta_peaks) self.assertEqual(delta_peaks[25], 0.5) - @unittest.skip("Skipping plotting test") def test_all_plots(self): profile = NodeProfileSimple(25) @@ -157,24 +157,36 @@ def test_all_plots(self): plt.show() profile = NodeProfileSimple() - profile.update_pareto_optimal_tuples(LabelTimeSimple(departure_time=2 * 60, arrival_time_target=11 * 60)) - profile.update_pareto_optimal_tuples(LabelTimeSimple(departure_time=20 * 60, arrival_time_target=25 * 60)) - profile.update_pareto_optimal_tuples(LabelTimeSimple(departure_time=40 * 60, arrival_time_target=45 * 60)) + profile.update_pareto_optimal_tuples( + LabelTimeSimple(departure_time=2 * 60, arrival_time_target=11 * 60) + ) + profile.update_pareto_optimal_tuples( + LabelTimeSimple(departure_time=20 * 60, arrival_time_target=25 * 60) + ) + profile.update_pareto_optimal_tuples( + LabelTimeSimple(departure_time=40 * 60, arrival_time_target=45 * 60) + ) analyzer = NodeProfileAnalyzerTime.from_profile(profile, 0, 60 * 60) analyzer.plot_temporal_distance_profile() analyzer.plot_temporal_distance_cdf() analyzer.plot_temporal_distance_pdf() profile = NodeProfileSimple() - profile.update_pareto_optimal_tuples(LabelTimeSimple(departure_time=2 * 60, arrival_time_target=3 * 60)) - profile.update_pareto_optimal_tuples(LabelTimeSimple(departure_time=4 * 60, arrival_time_target=25 * 60)) + profile.update_pareto_optimal_tuples( + LabelTimeSimple(departure_time=2 * 60, arrival_time_target=3 * 60) + ) + profile.update_pareto_optimal_tuples( + LabelTimeSimple(departure_time=4 * 60, arrival_time_target=25 * 60) + ) analyzer = NodeProfileAnalyzerTime.from_profile(profile, 0, 5 * 60) analyzer.plot_temporal_distance_profile() analyzer.plot_temporal_distance_cdf() analyzer.plot_temporal_distance_pdf() pt1 = LabelTimeSimple(departure_time=1, arrival_time_target=2) - pt2 = LabelTimeSimple(departure_time=4, arrival_time_target=5) # not taken into account by the analyzer + pt2 = LabelTimeSimple( + departure_time=4, arrival_time_target=5 + ) # not taken into account by the analyzer profile = NodeProfileSimple(1.5) profile.update_pareto_optimal_tuples(pt1) profile.update_pareto_optimal_tuples(pt2) @@ -183,4 +195,3 @@ def test_all_plots(self): analyzer.plot_temporal_distance_cdf() plt.show() - diff --git a/gtfspy/routing/test/test_node_profile_analyzer_time_and_veh_legs.py b/gtfspy/routing/test/test_node_profile_analyzer_time_and_veh_legs.py index fe3b633..87a67ae 100644 --- a/gtfspy/routing/test/test_node_profile_analyzer_time_and_veh_legs.py +++ b/gtfspy/routing/test/test_node_profile_analyzer_time_and_veh_legs.py @@ -8,15 +8,16 @@ class TestNodeProfileAnalyzerTimeAndVehLegs(TestCase): - def setUp(self): self.label_class = LabelTimeWithBoardingsCount - def _get_analyzer(self, labels, start_time, end_time, walk_to_target_duration=float('inf')): + def _get_analyzer(self, labels, start_time, end_time, walk_to_target_duration=float("inf")): dep_times = list(set(map(lambda el: el.departure_time, labels))) - p = NodeProfileMultiObjective(dep_times=dep_times, - walk_to_target_duration=walk_to_target_duration, - label_class=LabelTimeWithBoardingsCount) + p = NodeProfileMultiObjective( + dep_times=dep_times, + walk_to_target_duration=walk_to_target_duration, + label_class=LabelTimeWithBoardingsCount, + ) for label in labels: p.update([label]) p.finalize() @@ -33,24 +34,33 @@ def test_trip_duration_statistics_empty_profile(self): def test_temporal_distances_by_n_vehicles(self): labels = [ - LabelTimeWithBoardingsCount(departure_time=10, arrival_time_target=12, n_boardings=4, first_leg_is_walk=False), - LabelTimeWithBoardingsCount(departure_time=10, arrival_time_target=15, n_boardings=2, first_leg_is_walk=False), - LabelTimeWithBoardingsCount(departure_time=10, arrival_time_target=17, n_boardings=1, first_leg_is_walk=False) + LabelTimeWithBoardingsCount( + departure_time=10, arrival_time_target=12, n_boardings=4, first_leg_is_walk=False + ), + LabelTimeWithBoardingsCount( + departure_time=10, arrival_time_target=15, n_boardings=2, first_leg_is_walk=False + ), + LabelTimeWithBoardingsCount( + departure_time=10, arrival_time_target=17, n_boardings=1, first_leg_is_walk=False + ), ] analyzer = self._get_analyzer(labels, 0, 10, walk_to_target_duration=10) median_temporal_distances = analyzer.median_temporal_distances() self.assertEqual(len(median_temporal_distances), 4 + 1) for i in range(len(median_temporal_distances) - 1): - assert(median_temporal_distances[i] >= median_temporal_distances[i + 1]) + assert median_temporal_distances[i] >= median_temporal_distances[i + 1] def test_n_boardings_on_shortest_paths(self): labels = [ - LabelTimeWithBoardingsCount(departure_time=10, arrival_time_target=12, n_boardings=4, - first_leg_is_walk=False), - LabelTimeWithBoardingsCount(departure_time=5, arrival_time_target=10, n_boardings=2, - first_leg_is_walk=False), - LabelTimeWithBoardingsCount(departure_time=5, arrival_time_target=12, n_boardings=0, - first_leg_is_walk=False) + LabelTimeWithBoardingsCount( + departure_time=10, arrival_time_target=12, n_boardings=4, first_leg_is_walk=False + ), + LabelTimeWithBoardingsCount( + departure_time=5, arrival_time_target=10, n_boardings=2, first_leg_is_walk=False + ), + LabelTimeWithBoardingsCount( + departure_time=5, arrival_time_target=12, n_boardings=0, first_leg_is_walk=False + ), ] analyzer = self._get_analyzer(labels, 0, 10, walk_to_target_duration=10) self.assertEqual(analyzer.mean_n_boardings_on_shortest_paths(), 3) @@ -59,93 +69,123 @@ def test_n_boardings_on_shortest_paths(self): def test_min_n_boardings(self): labels = [ - LabelTimeWithBoardingsCount(departure_time=10, arrival_time_target=12, n_boardings=4, - first_leg_is_walk=False), - LabelTimeWithBoardingsCount(departure_time=5, arrival_time_target=10, n_boardings=2, - first_leg_is_walk=False), - LabelTimeWithBoardingsCount(departure_time=5, arrival_time_target=12, n_boardings=1, - first_leg_is_walk=False) + LabelTimeWithBoardingsCount( + departure_time=10, arrival_time_target=12, n_boardings=4, first_leg_is_walk=False + ), + LabelTimeWithBoardingsCount( + departure_time=5, arrival_time_target=10, n_boardings=2, first_leg_is_walk=False + ), + LabelTimeWithBoardingsCount( + departure_time=5, arrival_time_target=12, n_boardings=1, first_leg_is_walk=False + ), ] analyzer = self._get_analyzer(labels, 0, 10, walk_to_target_duration=10) self.assertEqual(analyzer.min_n_boardings(), 0) - analyzer2 = self._get_analyzer(labels, 0, 10, walk_to_target_duration=float('inf')) + analyzer2 = self._get_analyzer(labels, 0, 10, walk_to_target_duration=float("inf")) self.assertEqual(analyzer2.min_n_boardings(), 1) def test_min_n_boardings_after_departure_time_2(self): # (This tests the bug experienced with the Jollas region) labels = [ - LabelTimeWithBoardingsCount(departure_time=12, arrival_time_target=14, n_boardings=2, - first_leg_is_walk=False), - LabelTimeWithBoardingsCount(departure_time=11, arrival_time_target=12, n_boardings=3, - first_leg_is_walk=False), - LabelTimeWithBoardingsCount(departure_time=5, arrival_time_target=10, n_boardings=4, - first_leg_is_walk=False), + LabelTimeWithBoardingsCount( + departure_time=12, arrival_time_target=14, n_boardings=2, first_leg_is_walk=False + ), + LabelTimeWithBoardingsCount( + departure_time=11, arrival_time_target=12, n_boardings=3, first_leg_is_walk=False + ), + LabelTimeWithBoardingsCount( + departure_time=5, arrival_time_target=10, n_boardings=4, first_leg_is_walk=False + ), ] - analyzer = self._get_analyzer(labels, 0, 10, walk_to_target_duration=float('inf')) + analyzer = self._get_analyzer(labels, 0, 10, walk_to_target_duration=float("inf")) self.assertEqual(analyzer.min_n_boardings(), 2) - - def test_min_n_boardings_on_fastest_paths(self): labels = [ - LabelTimeWithBoardingsCount(departure_time=5, arrival_time_target=12, n_boardings=4, - first_leg_is_walk=False), + LabelTimeWithBoardingsCount( + departure_time=5, arrival_time_target=12, n_boardings=4, first_leg_is_walk=False + ), ] analyzer = self._get_analyzer(labels, 0, 10, walk_to_target_duration=10) self.assertEqual(analyzer.min_n_boardings_on_shortest_paths(), 0) def test_mean_n_boardings_on_fastest_paths(self): labels = [ - LabelTimeWithBoardingsCount(departure_time=5, arrival_time_target=6, n_boardings=1, - first_leg_is_walk=False), + LabelTimeWithBoardingsCount( + departure_time=5, arrival_time_target=6, n_boardings=1, first_leg_is_walk=False + ), ] analyzer = self._get_analyzer(labels, 0, 10, walk_to_target_duration=10) self.assertEqual(analyzer.mean_n_boardings_on_shortest_paths(), 0.5) labels = [ - LabelTimeWithBoardingsCount(departure_time=10, arrival_time_target=15, - n_boardings=1, first_leg_is_walk=False), + LabelTimeWithBoardingsCount( + departure_time=10, arrival_time_target=15, n_boardings=1, first_leg_is_walk=False + ), ] analyzer = self._get_analyzer(labels, 0, 10, walk_to_target_duration=10) self.assertEqual(analyzer.mean_n_boardings_on_shortest_paths(), 0.5) def test_mean_temporal_distance_with_min_n_boardings(self): labels = [ - LabelTimeWithBoardingsCount(departure_time=10, arrival_time_target=22, n_boardings=4, - first_leg_is_walk=False), - LabelTimeWithBoardingsCount(departure_time=10, arrival_time_target=24, n_boardings=2, - first_leg_is_walk=False), - LabelTimeWithBoardingsCount(departure_time=10, arrival_time_target=26, n_boardings=1, - first_leg_is_walk=False) + LabelTimeWithBoardingsCount( + departure_time=10, arrival_time_target=22, n_boardings=4, first_leg_is_walk=False + ), + LabelTimeWithBoardingsCount( + departure_time=10, arrival_time_target=24, n_boardings=2, first_leg_is_walk=False + ), + LabelTimeWithBoardingsCount( + departure_time=10, arrival_time_target=26, n_boardings=1, first_leg_is_walk=False + ), ] - analyzer = self._get_analyzer(labels, 0, 5, walk_to_target_duration=float('inf')) + analyzer = self._get_analyzer(labels, 0, 5, walk_to_target_duration=float("inf")) self.assertEqual(analyzer.mean_temporal_distance_with_min_n_boardings(), 2.5 + 5 + 26 - 10) def test_n_boardings_on_fastest_trip(self): labels = [ - LabelTimeWithBoardingsCount(departure_time=10, arrival_time_target=22, n_boardings=4, - first_leg_is_walk=False), - LabelTimeWithBoardingsCount(departure_time=10, arrival_time_target=24, n_boardings=2, - first_leg_is_walk=False), - LabelTimeWithBoardingsCount(departure_time=10, arrival_time_target=26, n_boardings=1, - first_leg_is_walk=False) + LabelTimeWithBoardingsCount( + departure_time=10, arrival_time_target=22, n_boardings=4, first_leg_is_walk=False + ), + LabelTimeWithBoardingsCount( + departure_time=10, arrival_time_target=24, n_boardings=2, first_leg_is_walk=False + ), + LabelTimeWithBoardingsCount( + departure_time=10, arrival_time_target=26, n_boardings=1, first_leg_is_walk=False + ), ] - analyzer = self._get_analyzer(labels, start_time=0, end_time=11, walk_to_target_duration=float('inf')) + analyzer = self._get_analyzer( + labels, start_time=0, end_time=11, walk_to_target_duration=float("inf") + ) self.assertEqual(analyzer.n_boardings_on_fastest_trip(), 4) - @unittest.skip + @unittest.skip # type: ignore def test_plot(self): labels = [ - LabelTimeWithBoardingsCount(departure_time=20, arrival_time_target=22, n_boardings=5, first_leg_is_walk=False), - LabelTimeWithBoardingsCount(departure_time=15, arrival_time_target=20, n_boardings=6, first_leg_is_walk=False), - LabelTimeWithBoardingsCount(departure_time=14, arrival_time_target=21, n_boardings=5, first_leg_is_walk=False), - LabelTimeWithBoardingsCount(departure_time=13, arrival_time_target=22, n_boardings=4, first_leg_is_walk=False), + LabelTimeWithBoardingsCount( + departure_time=20, arrival_time_target=22, n_boardings=5, first_leg_is_walk=False + ), + LabelTimeWithBoardingsCount( + departure_time=15, arrival_time_target=20, n_boardings=6, first_leg_is_walk=False + ), + LabelTimeWithBoardingsCount( + departure_time=14, arrival_time_target=21, n_boardings=5, first_leg_is_walk=False + ), + LabelTimeWithBoardingsCount( + departure_time=13, arrival_time_target=22, n_boardings=4, first_leg_is_walk=False + ), # LabelTimeWithBoardingsCount(departure_time=12, arrival_time_target=23, n_vehicle_legs=3), - LabelTimeWithBoardingsCount(departure_time=11, arrival_time_target=24, n_boardings=2, first_leg_is_walk=True), - LabelTimeWithBoardingsCount(departure_time=10, arrival_time_target=25, n_boardings=1, first_leg_is_walk=True), - LabelTimeWithBoardingsCount(departure_time=5, arrival_time_target=10, n_boardings=1, first_leg_is_walk=True) + LabelTimeWithBoardingsCount( + departure_time=11, arrival_time_target=24, n_boardings=2, first_leg_is_walk=True + ), + LabelTimeWithBoardingsCount( + departure_time=10, arrival_time_target=25, n_boardings=1, first_leg_is_walk=True + ), + LabelTimeWithBoardingsCount( + departure_time=5, arrival_time_target=10, n_boardings=1, first_leg_is_walk=True + ), ] - analyzer = self._get_analyzer(labels, 0, 20, 35) + self._get_analyzer(labels, 0, 20, 35) print(fig) import matplotlib.pyplot as plt + plt.show() diff --git a/gtfspy/routing/test/test_node_profile_c.py b/gtfspy/routing/test/test_node_profile_c.py index 0d14056..bf80418 100644 --- a/gtfspy/routing/test/test_node_profile_c.py +++ b/gtfspy/routing/test/test_node_profile_c.py @@ -5,15 +5,18 @@ class TestNodeProfileC(TestCase): - def test_earliest_arrival_time(self): node_profile = NodeProfileC() self.assertEquals(float("inf"), node_profile.evaluate_earliest_arrival_time_at_target(0, 0)) - node_profile.update_pareto_optimal_tuples(LabelTime(departure_time=3, arrival_time_target=4)) + node_profile.update_pareto_optimal_tuples( + LabelTime(departure_time=3, arrival_time_target=4) + ) self.assertEquals(4, node_profile.evaluate_earliest_arrival_time_at_target(2, 0)) - node_profile.update_pareto_optimal_tuples(LabelTime(departure_time=1, arrival_time_target=1)) + node_profile.update_pareto_optimal_tuples( + LabelTime(departure_time=1, arrival_time_target=1) + ) self.assertEquals(1, node_profile.evaluate_earliest_arrival_time_at_target(0, 0)) def test_pareto_optimality(self): diff --git a/gtfspy/routing/test/test_node_profile_multi_objective.py b/gtfspy/routing/test/test_node_profile_multi_objective.py index f35d1e3..dba5a23 100644 --- a/gtfspy/routing/test/test_node_profile_multi_objective.py +++ b/gtfspy/routing/test/test_node_profile_multi_objective.py @@ -1,14 +1,19 @@ import pyximport + pyximport.install() from unittest import TestCase from gtfspy.routing.node_profile_multiobjective import NodeProfileMultiObjective -from gtfspy.routing.label import LabelTime, min_arrival_time_target, LabelTimeWithBoardingsCount, LabelVehLegCount +from gtfspy.routing.label import ( + LabelTime, + min_arrival_time_target, + LabelTimeWithBoardingsCount, + LabelVehLegCount, +) class TestNodeProfileMultiObjective(TestCase): - def test_evaluate(self): node_profile = NodeProfileMultiObjective(dep_times=[3, 1], label_class=LabelTime) @@ -29,10 +34,14 @@ def test_pareto_optimality2(self): def test_identity_profile(self): identity_profile = NodeProfileMultiObjective(dep_times=[10]) identity_profile.update([LabelTimeWithBoardingsCount(10, 10, 0, True)]) - self.assertEqual(10, min_arrival_time_target(identity_profile.evaluate(10, first_leg_can_be_walk=True))) + self.assertEqual( + 10, min_arrival_time_target(identity_profile.evaluate(10, first_leg_can_be_walk=True)) + ) def test_walk_duration(self): - node_profile = NodeProfileMultiObjective(dep_times=[10, 5], walk_to_target_duration=27, label_class=LabelTime) + node_profile = NodeProfileMultiObjective( + dep_times=[10, 5], walk_to_target_duration=27, label_class=LabelTime + ) self.assertEqual(27, node_profile.get_walk_to_target_duration()) pt2 = LabelTime(departure_time=10, arrival_time_target=35) pt1 = LabelTime(departure_time=5, arrival_time_target=35) @@ -41,9 +50,15 @@ def test_walk_duration(self): def test_pareto_optimality_with_transfers_and_time(self): node_profile = NodeProfileMultiObjective(dep_times=[5, 6, 7]) - pt3 = LabelTimeWithBoardingsCount(departure_time=5, arrival_time_target=45, n_boardings=0, first_leg_is_walk=False) - pt2 = LabelTimeWithBoardingsCount(departure_time=6, arrival_time_target=40, n_boardings=1, first_leg_is_walk=False) - pt1 = LabelTimeWithBoardingsCount(departure_time=7, arrival_time_target=35, n_boardings=2, first_leg_is_walk=False) + pt3 = LabelTimeWithBoardingsCount( + departure_time=5, arrival_time_target=45, n_boardings=0, first_leg_is_walk=False + ) + pt2 = LabelTimeWithBoardingsCount( + departure_time=6, arrival_time_target=40, n_boardings=1, first_leg_is_walk=False + ) + pt1 = LabelTimeWithBoardingsCount( + departure_time=7, arrival_time_target=35, n_boardings=2, first_leg_is_walk=False + ) self.assertTrue(node_profile.update([pt1])) self.assertTrue(node_profile.update([pt2])) self.assertTrue(node_profile.update([pt3])) @@ -62,30 +77,44 @@ def test_pareto_optimality_with_transfers_only(self): self.assertEqual(1, len(node_profile.get_final_optimal_labels())) def test_finalize(self): - node_profile = NodeProfileMultiObjective(label_class=LabelTimeWithBoardingsCount, dep_times=[10]) - own_label = LabelTimeWithBoardingsCount(departure_time=10, arrival_time_target=20, n_boardings=0, first_leg_is_walk=False) + node_profile = NodeProfileMultiObjective( + label_class=LabelTimeWithBoardingsCount, dep_times=[10] + ) + own_label = LabelTimeWithBoardingsCount( + departure_time=10, arrival_time_target=20, n_boardings=0, first_leg_is_walk=False + ) self.assertTrue(node_profile.update([own_label])) - neighbor_label = LabelTimeWithBoardingsCount(departure_time=15, arrival_time_target=18, n_boardings=2, first_leg_is_walk=False) - assert(len(node_profile.get_labels_for_real_connections()) == 1) + neighbor_label = LabelTimeWithBoardingsCount( + departure_time=15, arrival_time_target=18, n_boardings=2, first_leg_is_walk=False + ) + assert len(node_profile.get_labels_for_real_connections()) == 1 node_profile.finalize([[neighbor_label]], [3]) - assert (len(node_profile.get_final_optimal_labels()) == 2) - self.assertTrue(any(map(lambda el: el.departure_time == 12, node_profile.get_final_optimal_labels()))) + assert len(node_profile.get_final_optimal_labels()) == 2 + self.assertTrue( + any(map(lambda el: el.departure_time == 12, node_profile.get_final_optimal_labels())) + ) def test_same_dep_times_fail_in_init(self): with self.assertRaises(AssertionError): - node_profile = NodeProfileMultiObjective(label_class=LabelTimeWithBoardingsCount, dep_times=[10, 10, 20, 20]) - + NodeProfileMultiObjective( + label_class=LabelTimeWithBoardingsCount, dep_times=[10, 10, 20, 20] + ) def test_dep_time_skipped_in_update(self): - label3 = LabelTimeWithBoardingsCount(departure_time=30, arrival_time_target=20, n_boardings=0, - first_leg_is_walk=False) - label2 = LabelTimeWithBoardingsCount(departure_time=20, arrival_time_target=20, n_boardings=0, - first_leg_is_walk=False) - label1 = LabelTimeWithBoardingsCount(departure_time=10, arrival_time_target=20, n_boardings=0, - first_leg_is_walk=False) + label3 = LabelTimeWithBoardingsCount( + departure_time=30, arrival_time_target=20, n_boardings=0, first_leg_is_walk=False + ) + label2 = LabelTimeWithBoardingsCount( + departure_time=20, arrival_time_target=20, n_boardings=0, first_leg_is_walk=False + ) + label1 = LabelTimeWithBoardingsCount( + departure_time=10, arrival_time_target=20, n_boardings=0, first_leg_is_walk=False + ) # This should work ok - node_profile = NodeProfileMultiObjective(label_class=LabelTimeWithBoardingsCount, dep_times=[10, 20, 30]) + node_profile = NodeProfileMultiObjective( + label_class=LabelTimeWithBoardingsCount, dep_times=[10, 20, 30] + ) node_profile.update([label3]) node_profile.update([label2]) node_profile.update([label2]) @@ -93,12 +122,15 @@ def test_dep_time_skipped_in_update(self): # This should fail due to dep time 20 missing in between with self.assertRaises(AssertionError): - node_profile = NodeProfileMultiObjective(label_class=LabelTimeWithBoardingsCount, dep_times=[10, 20, 30]) + node_profile = NodeProfileMultiObjective( + label_class=LabelTimeWithBoardingsCount, dep_times=[10, 20, 30] + ) node_profile.update([label3]) node_profile.update([label1]) # This should fail due to dep time 30 not being the first to deal with with self.assertRaises(AssertionError): - node_profile = NodeProfileMultiObjective(label_class=LabelTimeWithBoardingsCount, dep_times=[10, 20, 30]) + node_profile = NodeProfileMultiObjective( + label_class=LabelTimeWithBoardingsCount, dep_times=[10, 20, 30] + ) node_profile.update([label2]) - diff --git a/gtfspy/routing/test/test_node_profile_simple.py b/gtfspy/routing/test/test_node_profile_simple.py index 0b3ed46..d3003e3 100644 --- a/gtfspy/routing/test/test_node_profile_simple.py +++ b/gtfspy/routing/test/test_node_profile_simple.py @@ -1,4 +1,5 @@ from pyximport import install + install() from unittest import TestCase @@ -8,15 +9,18 @@ class TestNodeProfileSimple(TestCase): - def test_earliest_arrival_time(self): node_profile = NodeProfileSimple() self.assertEquals(float("inf"), node_profile.evaluate_earliest_arrival_time_at_target(0, 0)) - node_profile.update_pareto_optimal_tuples(LabelTimeSimple(departure_time=1, arrival_time_target=1)) + node_profile.update_pareto_optimal_tuples( + LabelTimeSimple(departure_time=1, arrival_time_target=1) + ) self.assertEquals(1, node_profile.evaluate_earliest_arrival_time_at_target(0, 0)) - node_profile.update_pareto_optimal_tuples(LabelTimeSimple(departure_time=3, arrival_time_target=4)) + node_profile.update_pareto_optimal_tuples( + LabelTimeSimple(departure_time=3, arrival_time_target=4) + ) self.assertEquals(4, node_profile.evaluate_earliest_arrival_time_at_target(2, 0)) def test_pareto_optimality(self): @@ -32,7 +36,9 @@ def test_pareto_optimality(self): pair3 = LabelTimeSimple(departure_time=1, arrival_time_target=1) self.assertTrue(node_profile.update_pareto_optimal_tuples(pair3)) - self.assertEquals(2, len(node_profile._labels), msg=str(node_profile.get_final_optimal_labels())) + self.assertEquals( + 2, len(node_profile._labels), msg=str(node_profile.get_final_optimal_labels()) + ) pair4 = LabelTimeSimple(departure_time=1, arrival_time_target=2) self.assertFalse(node_profile.update_pareto_optimal_tuples(pair4)) @@ -60,11 +66,16 @@ def test_walk_duration(self): def test_pareto_optimality_with_transfers(self): node_profile = NodeProfileSimple(label_class=LabelTimeWithBoardingsCount) - pt3 = LabelTimeWithBoardingsCount(departure_time=5, arrival_time_target=35, n_boardings=0, first_leg_is_walk=True) - pt2 = LabelTimeWithBoardingsCount(departure_time=5, arrival_time_target=35, n_boardings=1, first_leg_is_walk=True) - pt1 = LabelTimeWithBoardingsCount(departure_time=5, arrival_time_target=35, n_boardings=2, first_leg_is_walk=True) + pt3 = LabelTimeWithBoardingsCount( + departure_time=5, arrival_time_target=35, n_boardings=0, first_leg_is_walk=True + ) + pt2 = LabelTimeWithBoardingsCount( + departure_time=5, arrival_time_target=35, n_boardings=1, first_leg_is_walk=True + ) + pt1 = LabelTimeWithBoardingsCount( + departure_time=5, arrival_time_target=35, n_boardings=2, first_leg_is_walk=True + ) self.assertTrue(node_profile.update_pareto_optimal_tuples(pt1)) self.assertTrue(node_profile.update_pareto_optimal_tuples(pt2)) self.assertTrue(node_profile.update_pareto_optimal_tuples(pt3)) self.assertEqual(1, len(node_profile.get_final_optimal_labels())) - diff --git a/gtfspy/routing/test/test_profile_block_analyzer.py b/gtfspy/routing/test/test_profile_block_analyzer.py index 45a5415..2f04356 100644 --- a/gtfspy/routing/test/test_profile_block_analyzer.py +++ b/gtfspy/routing/test/test_profile_block_analyzer.py @@ -3,16 +3,14 @@ from gtfspy.routing.profile_block_analyzer import ProfileBlockAnalyzer from gtfspy.routing.profile_block import ProfileBlock -class TestProfileBlockAnalyzer(TestCase): +class TestProfileBlockAnalyzer(TestCase): def test_interpolate(self): blocks = [ProfileBlock(0, 1, 2, 1), ProfileBlock(1, 2, 2, 2)] - analyzer = ProfileBlockAnalyzer(blocks, cutoff_distance=3.0) + analyzer = ProfileBlockAnalyzer(blocks, cutoff_distance=3.0) self.assertAlmostEqual(analyzer.interpolate(0.2), 1.8) - self.assertAlmostEqual(analyzer.interpolate(1-10**-9), 1.) + self.assertAlmostEqual(analyzer.interpolate(1 - 10 ** -9), 1.0) self.assertAlmostEqual(analyzer.interpolate(1), 1) - self.assertAlmostEqual(analyzer.interpolate(1.+10**-9), 2) + self.assertAlmostEqual(analyzer.interpolate(1.0 + 10 ** -9), 2) self.assertAlmostEqual(analyzer.interpolate(1.23), 2) self.assertAlmostEqual(analyzer.interpolate(2), 2) - - diff --git a/gtfspy/routing/test/test_pseudo_connection_scan_profiler.py b/gtfspy/routing/test/test_pseudo_connection_scan_profiler.py index e91ca8f..814e1c3 100644 --- a/gtfspy/routing/test/test_pseudo_connection_scan_profiler.py +++ b/gtfspy/routing/test/test_pseudo_connection_scan_profiler.py @@ -17,7 +17,7 @@ def setUp(self): (3, 4, 32, 35, "trip_4", 1), (2, 3, 25, 30, "trip_3", 1), (1, 2, 10, 20, "trip_2", 1), - (0, 1, 0, 10, "trip_1", 1) + (0, 1, 0, 10, "trip_1", 1), ] self.transit_connections = list(map(lambda el: Connection(*el), event_list_raw_data)) self.walk_network = networkx.Graph() @@ -31,9 +31,15 @@ def setUp(self): self.end_time = 50 def test_basics(self): - csa_profile = PseudoConnectionScanProfiler(self.transit_connections, self.target_stop, - self.start_time, self.end_time, self.transfer_margin, - self.walk_network, self.walk_speed) + csa_profile = PseudoConnectionScanProfiler( + self.transit_connections, + self.target_stop, + self.start_time, + self.end_time, + self.transfer_margin, + self.walk_network, + self.walk_speed, + ) csa_profile.run() stop_3_labels = csa_profile.stop_profiles[3].get_final_optimal_labels() @@ -53,10 +59,7 @@ def test_basics(self): labels.append(LabelTime(departure_time=20, arrival_time_target=50)) labels.append(LabelTime(departure_time=32, arrival_time_target=55)) - self._assert_pareto_tuple_sets_equal( - labels, - source_stop_pareto_optimal_tuples - ) + self._assert_pareto_tuple_sets_equal(labels, source_stop_pareto_optimal_tuples) def test_simple(self): event_list_raw_data = [ @@ -76,22 +79,23 @@ def test_simple(self): labels = [] labels.append(LabelTime(departure_time=20, arrival_time_target=50)) - csa_profile = PseudoConnectionScanProfiler(transit_connections, target_stop, - start_time, end_time, transfer_margin, - walk_network, walk_speed) + csa_profile = PseudoConnectionScanProfiler( + transit_connections, + target_stop, + start_time, + end_time, + transfer_margin, + walk_network, + walk_speed, + ) csa_profile.run() source_stop_profile = csa_profile.stop_profiles[source_stop] source_stop_labels = source_stop_profile.get_final_optimal_labels() - self._assert_pareto_tuple_sets_equal( - labels, - source_stop_labels - ) + self._assert_pareto_tuple_sets_equal(labels, source_stop_labels) def test_last_leg_is_walk(self): - event_list_raw_data = [ - (0, 1, 0, 10, "trip_1", 1) - ] + event_list_raw_data = [(0, 1, 0, 10, "trip_1", 1)] transit_connections = list(map(lambda el: Connection(*el), event_list_raw_data)) walk_network = networkx.Graph() walk_network.add_edge(1, 2, {"d_walk": 20}) @@ -105,17 +109,21 @@ def test_last_leg_is_walk(self): labels = [] labels.append(LabelTime(departure_time=0, arrival_time_target=30)) - csa_profile = PseudoConnectionScanProfiler(transit_connections, target_stop, - start_time, end_time, transfer_margin, - walk_network, walk_speed) + csa_profile = PseudoConnectionScanProfiler( + transit_connections, + target_stop, + start_time, + end_time, + transfer_margin, + walk_network, + walk_speed, + ) csa_profile.run() found_tuples = csa_profile.stop_profiles[source_stop].get_final_optimal_labels() self._assert_pareto_tuple_sets_equal(found_tuples, labels) def test_walk_is_faster_than_by_trip(self): - event_list_raw_data = [ - (0, 1, 0, 10, "trip_1", 1) - ] + event_list_raw_data = [(0, 1, 0, 10, "trip_1", 1)] transit_connections = list(map(lambda el: Connection(*el), event_list_raw_data)) walk_speed = 2 source_stop = 0 @@ -126,9 +134,15 @@ def test_walk_is_faster_than_by_trip(self): walk_network = networkx.Graph() walk_network.add_edge(0, 1, {"d_walk": 1}) - csa_profile = PseudoConnectionScanProfiler(transit_connections, target_stop, - start_time, end_time, transfer_margin, - walk_network, walk_speed) + csa_profile = PseudoConnectionScanProfiler( + transit_connections, + target_stop, + start_time, + end_time, + transfer_margin, + walk_network, + walk_speed, + ) csa_profile.run() source_profile = csa_profile.stop_profiles[source_stop] self.assertEqual(source_profile.evaluate_earliest_arrival_time_at_target(0, 0), 0.5) @@ -136,9 +150,7 @@ def test_walk_is_faster_than_by_trip(self): self.assertEqual(len(found_tuples), 0) def test_target_node_not_in_walk_network(self): - event_list_raw_data = [ - (0, 1, 0, 10, "trip_1", 1) - ] + event_list_raw_data = [(0, 1, 0, 10, "trip_1", 1)] transit_connections = list(map(lambda el: Connection(*el), event_list_raw_data)) walk_speed = 2 source_stop = 0 @@ -148,9 +160,15 @@ def test_target_node_not_in_walk_network(self): end_time = 50 walk_network = networkx.Graph() - csa_profile = PseudoConnectionScanProfiler(transit_connections, target_stop, - start_time, end_time, transfer_margin, - walk_network, walk_speed) + csa_profile = PseudoConnectionScanProfiler( + transit_connections, + target_stop, + start_time, + end_time, + transfer_margin, + walk_network, + walk_speed, + ) csa_profile.run() source_profile = csa_profile.stop_profiles[source_stop] self.assertEqual(source_profile.evaluate_earliest_arrival_time_at_target(0, 0), 10) diff --git a/gtfspy/routing/travel_impedance_data_store.py b/gtfspy/routing/travel_impedance_data_store.py index a0ad15c..df3053f 100644 --- a/gtfspy/routing/travel_impedance_data_store.py +++ b/gtfspy/routing/travel_impedance_data_store.py @@ -3,17 +3,14 @@ class TravelImpedanceDataStore: - def __init__(self, db_fname, timeout=100): self.db_fname = db_fname self.timeout = timeout self.conn = sqlite3.connect(self.db_fname, timeout) - def read_data_as_dataframe(self, - travel_impedance_measure, - from_stop_I=None, - to_stop_I=None, - statistic=None): + def read_data_as_dataframe( + self, travel_impedance_measure, from_stop_I=None, to_stop_I=None, statistic=None + ): """ Recover pre-computed travel_impedance between od-pairs from the database. @@ -39,18 +36,22 @@ def read_data_as_dataframe(self, to_select_clause = ",".join(to_select) if not to_select_clause: to_select_clause = "*" - sql = "SELECT " + to_select_clause + " FROM " + travel_impedance_measure + where_clause + ";" + sql = ( + "SELECT " + to_select_clause + " FROM " + travel_impedance_measure + where_clause + ";" + ) df = pd.read_sql(sql, self.conn) return df def create_table(self, travel_impedance_measure, ensure_uniqueness=True): print("Creating table: ", travel_impedance_measure) - sql = "CREATE TABLE IF NOT EXISTS " + travel_impedance_measure + " (from_stop_I INT, " \ - "to_stop_I INT, " \ - "min REAL, " \ - "max REAL, " \ - "median REAL, " \ - "mean REAL" + sql = ( + "CREATE TABLE IF NOT EXISTS " + travel_impedance_measure + " (from_stop_I INT, " + "to_stop_I INT, " + "min REAL, " + "max REAL, " + "median REAL, " + "mean REAL" + ) if ensure_uniqueness: sql = sql + ", UNIQUE (from_stop_I, to_stop_I) )" else: @@ -69,8 +70,16 @@ def create_indices_for_all_tables(self, use_memory_as_temp_store=False): def create_indices(self, travel_impedance_measure_name): table = travel_impedance_measure_name - sql_from_to = "CREATE UNIQUE INDEX IF NOT EXISTS " + table + "_from_stop_I_to_stop_I ON " + table + " (from_stop_I, to_stop_I)" - sql_from = "CREATE INDEX IF NOT EXISTS " + table + "_from_stop_I ON " + table + " (from_stop_I)" + sql_from_to = ( + "CREATE UNIQUE INDEX IF NOT EXISTS " + + table + + "_from_stop_I_to_stop_I ON " + + table + + " (from_stop_I, to_stop_I)" + ) + sql_from = ( + "CREATE INDEX IF NOT EXISTS " + table + "_from_stop_I ON " + table + " (from_stop_I)" + ) sql_to = "CREATE INDEX IF NOT EXISTS " + table + "_to_stop_I ON " + table + " (to_stop_I)" print("Executing: " + sql_from_to) self.conn.execute(sql_from_to) @@ -90,20 +99,30 @@ def insert_data(self, travel_impedance_measure_name, data): "from_stop_I", "to_stop_I", "min", "max", "median" and "mean" """ f = float - data_tuple = [(int(x["from_stop_I"]), int(x["to_stop_I"]), f(x["min"]), f(x["max"]), f(x["median"]), f(x["mean"])) for - x in data] - insert_stmt = '''INSERT OR REPLACE INTO ''' + travel_impedance_measure_name + ''' ( + data_tuple = [ + ( + int(x["from_stop_I"]), + int(x["to_stop_I"]), + f(x["min"]), + f(x["max"]), + f(x["median"]), + f(x["mean"]), + ) + for x in data + ] + insert_stmt = ( + """INSERT OR REPLACE INTO """ + + travel_impedance_measure_name + + """ ( from_stop_I, to_stop_I, min, max, median, - mean) VALUES (?, ?, ?, ?, ?, ?) ''' + mean) VALUES (?, ?, ?, ?, ?, ?) """ + ) self.conn.executemany(insert_stmt, data_tuple) self.conn.commit() def apply_insertion_speedups(self): self.conn.execute("PRAGMA SYNCHRONOUS = OFF") - - - diff --git a/gtfspy/routing/util.py b/gtfspy/routing/util.py index 5f6dc47..94d8c37 100644 --- a/gtfspy/routing/util.py +++ b/gtfspy/routing/util.py @@ -1,5 +1,6 @@ import time + def timeit(method): """ A Python decorator for printing out the execution time for a function. @@ -7,11 +8,15 @@ def timeit(method): Adapted from: www.andreas-jung.com/contents/a-python-decorator-for-measuring-the-execution-time-of-methods """ + def timed(*args, **kw): time_start = time.time() result = method(*args, **kw) time_end = time.time() - print('timeit: %r %2.2f sec (%r, %r) ' % (method.__name__, time_end-time_start, str(args)[:20], kw)) + print( + "timeit: %r %2.2f sec (%r, %r) " + % (method.__name__, time_end - time_start, str(args)[:20], kw) + ) return result - return timed \ No newline at end of file + return timed diff --git a/gtfspy/segments.py b/gtfspy/segments.py index a7d2030..dfdcbb6 100644 --- a/gtfspy/segments.py +++ b/gtfspy/segments.py @@ -1,7 +1,4 @@ - - class Segments(object): - def __init__(self, gtfs): self._gtfs @@ -16,7 +13,8 @@ def get_segments(self): cur = self._gtfs.get_cursor() # Find our IDs that are relevant. - cur.execute('''SELECT trip_I, cnt, seq1, seq2, S1.code, S2.code, + cur.execute( + """SELECT trip_I, cnt, seq1, seq2, S1.code, S2.code, S1.name AS name1, S2.name AS name2, S1.lat, S1.lon, S2.lat, S2.lon @@ -32,13 +30,14 @@ def get_segments(self): LEFT JOIN stops S1 ON (sid1=S1.stop_I) LEFT JOIN stops S2 ON (sid2=S2.stop_I) --ORDER BY cnt DESC LIMIT 10 ; - ''') - + """ + ) class Segment(object): - - def __init__(self, from_node, to_node, distance, time, vehicle_count, capacity_per_hour, lines, modes): + def __init__( + self, from_node, to_node, distance, time, vehicle_count, capacity_per_hour, lines, modes + ): self.from_node = from_node self.to_node = to_node self.distance = distance @@ -47,4 +46,3 @@ def __init__(self, from_node, to_node, distance, time, vehicle_count, capacity_p self.capacity_per_hour = capacity_per_hour self.lines = lines self.modes = modes - diff --git a/gtfspy/shapes.py b/gtfspy/shapes.py index 4f5e96a..133da21 100644 --- a/gtfspy/shapes.py +++ b/gtfspy/shapes.py @@ -34,16 +34,16 @@ from .util import wgs84_distance -def print_coords(rows, prefix=''): +def print_coords(rows, prefix=""): """Print coordinates within a sequence. This is only used for debugging. Printed in a form that can be pasted into Python for visualization.""" - lat = [row['lat'] for row in rows] - lon = [row['lon'] for row in rows] - print('COORDS'+'-' * 5) + lat = [row["lat"] for row in rows] + lon = [row["lon"] for row in rows] + print("COORDS" + "-" * 5) print("%slat, %slon = %r, %r" % (prefix, prefix, lat, lon)) - print('-'*5) + print("-" * 5) def find_segments(stops, shape): @@ -75,19 +75,19 @@ def find_segments(stops, shape): last_i = 0 cumul_d = 0 badness = 0 - d_last_stop = float('inf') + d_last_stop = float("inf") lstlat, lstlon = None, None break_shape_points = [] for stop in stops: - stlat, stlon = stop['lat'], stop['lon'] - best_d = float('inf') + stlat, stlon = stop["lat"], stop["lon"] + best_d = float("inf") # print stop if badness > 500 and badness > 30 * len(break_points): return [], badness for i in range(last_i, len(shape)): - d = wgs84_distance(stlat, stlon, shape[i]['lat'], shape[i]['lon']) + d = wgs84_distance(stlat, stlon, shape[i]["lat"], shape[i]["lon"]) if lstlat: - d_last_stop = wgs84_distance(lstlat, lstlon, shape[i]['lat'], shape[i]['lon']) + d_last_stop = wgs84_distance(lstlat, lstlon, shape[i]["lat"], shape[i]["lon"]) # If we are getting closer to next stop, record this as # the best stop so far.continue if d < best_d: @@ -98,7 +98,7 @@ def find_segments(stops, shape): # We have to be very careful about our stop condition. # This is trial and error, basically. if (d_last_stop < d) or (d > 500) or (i < best_i + 100): - continue + continue # We have decided our best stop, stop looking and continue # the outer loop. else: @@ -123,8 +123,7 @@ def find_segments(stops, shape): return break_points, badness -def find_best_segments(cur, stops, shape_ids, route_id=None, - breakpoints_cache=None): +def find_best_segments(cur, stops, shape_ids, route_id=None, breakpoints_cache=None): """Finds the best shape_id for a stop-sequence. This is used in cases like when you have GPS data with a route @@ -147,18 +146,20 @@ def find_best_segments(cur, stops, shape_ids, route_id=None, # Calculate a cache key for this sequence. If shape_id and # all stop_Is are the same, then we assume that it is the same # route and re-use existing breakpoints. - cache_key = (route_id, tuple(x['stop_I'] for x in stops)) + cache_key = (route_id, tuple(x["stop_I"] for x in stops)) if cache_key in breakpoints_cache: - print('found in cache') + print("found in cache") return breakpoints_cache[cache_key] if route_id is not None: - cur.execute('''SELECT DISTINCT shape_id + cur.execute( + """SELECT DISTINCT shape_id FROM routes LEFT JOIN trips USING (route_I) - WHERE route_id=?''', - (route_id,)) + WHERE route_id=?""", + (route_id,), + ) data = cur.fetchall() # If not data, then route_id didn't match anything, or there # were no shapes defined. We have to exit in this case. @@ -173,7 +174,7 @@ def find_best_segments(cur, stops, shape_ids, route_id=None, shape = get_shape_points(cur, shape_id) breakpoints, badness = find_segments(stops, shape) results.append([badness, breakpoints, shape, shape_id]) - if len(stops) > 5 and badness < 5*(len(stops)): + if len(stops) > 5 and badness < 5 * (len(stops)): break best = np.argmin(zip(*results)[0]) @@ -201,12 +202,12 @@ def return_segments(shape, break_points): # print break_points # assert len(stops) == len(break_points) segs = [] - bp = 0 # not used + bp = 0 # not used bp2 = 0 - for i in range(len(break_points)-1): + for i in range(len(break_points) - 1): bp = break_points[i] if break_points[i] is not None else bp2 - bp2 = break_points[i+1] if break_points[i+1] is not None else bp - segs.append(shape[bp:bp2+1]) + bp2 = break_points[i + 1] if break_points[i + 1] is not None else bp + segs.append(shape[bp : bp2 + 1]) segs.append([]) return segs @@ -228,14 +229,13 @@ def gen_cumulative_distances(stops): and the function adds the 'd' key ('d' stands for distance) to the dictionaries """ - stops[0]['d'] = 0.0 + stops[0]["d"] = 0.0 for i in range(1, len(stops)): - stops[i]['d'] = stops[i-1]['d'] + wgs84_distance( - stops[i-1]['lat'], stops[i-1]['lon'], - stops[i]['lat'], stops[i]['lon'], - ) + stops[i]["d"] = stops[i - 1]["d"] + wgs84_distance( + stops[i - 1]["lat"], stops[i - 1]["lon"], stops[i]["lat"], stops[i]["lon"], + ) for stop in stops: - stop['d'] = int(stop['d']) + stop["d"] = int(stop["d"]) # stop['d'] = round(stop['d'], 1) @@ -255,10 +255,12 @@ def get_shape_points(cur, shape_id): shape_points: list elements are dictionaries containing the 'seq', 'lat', and 'lon' of the shape """ - cur.execute('''SELECT seq, lat, lon, d FROM shapes where shape_id=? - ORDER BY seq''', (shape_id,)) - shape_points = [dict(seq=row[0], lat=row[1], lon=row[2], d=row[3]) - for row in cur] + cur.execute( + """SELECT seq, lat, lon, d FROM shapes where shape_id=? + ORDER BY seq""", + (shape_id,), + ) + shape_points = [dict(seq=row[0], lat=row[1], lon=row[2], d=row[3]) for row in cur] return shape_points @@ -279,14 +281,17 @@ def get_shape_points2(cur, shape_id): shape_points: dict of lists dict contains keys 'seq', 'lat', 'lon', and 'd'(istance) of the shape """ - cur.execute('''SELECT seq, lat, lon, d FROM shapes where shape_id=? - ORDER BY seq''', (shape_id,)) - shape_points = {'seqs': [], 'lats': [], 'lons': [], 'd': []} + cur.execute( + """SELECT seq, lat, lon, d FROM shapes where shape_id=? + ORDER BY seq""", + (shape_id,), + ) + shape_points = {"seqs": [], "lats": [], "lons": [], "d": []} for row in cur: - shape_points['seqs'].append(row[0]) - shape_points['lats'].append(row[1]) - shape_points['lons'].append(row[2]) - shape_points['d'].append(row[3]) + shape_points["seqs"].append(row[0]) + shape_points["lats"].append(row[1]) + shape_points["lons"].append(row[2]) + shape_points["d"].append(row[3]) return shape_points @@ -306,7 +311,8 @@ def get_route_shape_segments(cur, route_id): shape_points: list elements are dictionaries containing the 'seq', 'lat', and 'lon' of the shape """ - cur.execute('''SELECT seq, lat, lon + cur.execute( + """SELECT seq, lat, lon FROM ( SELECT shape_id FROM route @@ -316,7 +322,9 @@ def get_route_shape_segments(cur, route_id): ) JOIN shapes USING (shape_id) - ORDER BY seq''', (route_id,)) + ORDER BY seq""", + (route_id,), + ) shape_points = [dict(seq=row[0], lat=row[1], lon=row[2]) for row in cur] return shape_points @@ -363,7 +371,10 @@ def get_shape_between_stops(cur, trip_I, seq_stop1=None, seq_stop2=None, shape_b for seq_stop in [seq_stop1, seq_stop2]: query = """SELECT shape_break FROM stop_times WHERE trip_I=%d AND seq=%d - """ % (trip_I, seq_stop) + """ % ( + trip_I, + seq_stop, + ) for row in cur.execute(query): shape_breaks.append(row[0]) assert len(shape_breaks) == 2 @@ -372,16 +383,20 @@ def get_shape_between_stops(cur, trip_I, seq_stop1=None, seq_stop2=None, shape_b FROM (SELECT shape_id FROM trips WHERE trip_I=%d) JOIN shapes USING (shape_id) WHERE seq>=%d AND seq <= %d; - """ % (trip_I, shape_breaks[0], shape_breaks[1]) - shapedict = {'lat': [], 'lon': [], 'seq': []} + """ % ( + trip_I, + shape_breaks[0], + shape_breaks[1], + ) + shapedict = {"lat": [], "lon": [], "seq": []} for row in cur.execute(query): - shapedict['seq'].append(row[0]) - shapedict['lat'].append(row[1]) - shapedict['lon'].append(row[2]) + shapedict["seq"].append(row[0]) + shapedict["lat"].append(row[1]) + shapedict["lon"].append(row[2]) return shapedict -def get_trip_points(cur, route_id, offset=0, tripid_glob=''): +def get_trip_points(cur, route_id, offset=0, tripid_glob=""): """Get all scheduled stops on a particular route_id. Given a route_id, return the trip-stop-list with @@ -407,16 +422,19 @@ def get_trip_points(cur, route_id, offset=0, tripid_glob=''): stop-list List of stops in stop-seq format. """ - extra_where = '' + extra_where = "" if tripid_glob: extra_where = "AND trip_id GLOB '%s'" % tripid_glob - cur.execute('SELECT seq, lat, lon ' - 'FROM (select trip_I from route ' - ' LEFT JOIN trips USING (route_I) ' - ' WHERE route_id=? %s limit 1 offset ? ) ' - 'JOIN stop_times USING (trip_I) ' - 'LEFT JOIN stop USING (stop_id) ' - 'ORDER BY seq' % extra_where, (route_id, offset)) + cur.execute( + "SELECT seq, lat, lon " + "FROM (select trip_I from route " + " LEFT JOIN trips USING (route_I) " + " WHERE route_id=? %s limit 1 offset ? ) " + "JOIN stop_times USING (trip_I) " + "LEFT JOIN stop USING (stop_id) " + "ORDER BY seq" % extra_where, + (route_id, offset), + ) stop_points = [dict(seq=row[0], lat=row[1], lon=row[2]) for row in cur] return stop_points @@ -444,20 +462,21 @@ def interpolate_shape_times(shape_distances, shape_breaks, stop_times): given the value of the last shape point. """ shape_times = np.zeros(len(shape_distances)) - shape_times[:shape_breaks[0]] = stop_times[0] - for i in range(len(shape_breaks)-1): + shape_times[: shape_breaks[0]] = stop_times[0] + for i in range(len(shape_breaks) - 1): cur_break = shape_breaks[i] cur_time = stop_times[i] - next_break = shape_breaks[i+1] - next_time = stop_times[i+1] + next_break = shape_breaks[i + 1] + next_time = stop_times[i + 1] if cur_break == next_break: shape_times[cur_break] = stop_times[i] else: - cur_distances = shape_distances[cur_break:next_break+1] - norm_distances = ((np.array(cur_distances)-float(cur_distances[0])) / - float(cur_distances[-1] - cur_distances[0])) - times = (1.-norm_distances)*cur_time+norm_distances*next_time + cur_distances = shape_distances[cur_break : next_break + 1] + norm_distances = (np.array(cur_distances) - float(cur_distances[0])) / float( + cur_distances[-1] - cur_distances[0] + ) + times = (1.0 - norm_distances) * cur_time + norm_distances * next_time shape_times[cur_break:next_break] = times[:-1] # deal final ones separately: - shape_times[shape_breaks[-1]:] = stop_times[-1] + shape_times[shape_breaks[-1] :] = stop_times[-1] return list(shape_times) diff --git a/gtfspy/spreading/event.py b/gtfspy/spreading/event.py index c39c917..86dad18 100644 --- a/gtfspy/spreading/event.py +++ b/gtfspy/spreading/event.py @@ -1,3 +1,3 @@ from collections import namedtuple -Event = namedtuple('Event', ['arr_time_ut', 'dep_time_ut', 'from_stop_I', 'to_stop_I', 'trip_I']) +Event = namedtuple("Event", ["arr_time_ut", "dep_time_ut", "from_stop_I", "to_stop_I", "trip_I"]) diff --git a/gtfspy/spreading/heap.py b/gtfspy/spreading/heap.py index f438c2e..44ff48b 100644 --- a/gtfspy/spreading/heap.py +++ b/gtfspy/spreading/heap.py @@ -5,6 +5,7 @@ from gtfspy.route_types import WALK from .event import Event + class EventHeap: """ EventHeap represents a container for the event @@ -19,7 +20,7 @@ def __init__(self, pd_df=None): Initial list of """ self.heap = [] - keys = ['arr_time_ut', 'dep_time_ut', 'from_stop_I', 'to_stop_I', 'trip_I'] + keys = ["arr_time_ut", "dep_time_ut", "from_stop_I", "to_stop_I", "trip_I"] # pd_df.iterrows() is slow as it creates new Series objects! n = len(pd_df) @@ -55,7 +56,9 @@ def size(self): """ return len(self.heap) - def add_walk_events_to_heap(self, transfer_distances, e, start_time_ut, walk_speed, uninfected_stops, max_duration_ut): + def add_walk_events_to_heap( + self, transfer_distances, e, start_time_ut, walk_speed, uninfected_stops, max_duration_ut + ): """ Parameters ---------- @@ -68,14 +71,14 @@ def add_walk_events_to_heap(self, transfer_distances, e, start_time_ut, walk_spe """ n = len(transfer_distances) dists_values = transfer_distances.values - to_stop_I_index = np.nonzero(transfer_distances.columns == 'to_stop_I')[0][0] - d_index = np.nonzero(transfer_distances.columns == 'd')[0][0] + to_stop_I_index = np.nonzero(transfer_distances.columns == "to_stop_I")[0][0] + d_index = np.nonzero(transfer_distances.columns == "d")[0][0] for i in range(n): transfer_to_stop_I = dists_values[i, to_stop_I_index] if transfer_to_stop_I in uninfected_stops: d = dists_values[i, d_index] - transfer_arr_time = e.arr_time_ut + int(d/float(walk_speed)) - if transfer_arr_time > start_time_ut+max_duration_ut: + transfer_arr_time = e.arr_time_ut + int(d / float(walk_speed)) + if transfer_arr_time > start_time_ut + max_duration_ut: continue te = Event(transfer_arr_time, e.arr_time_ut, e.to_stop_I, transfer_to_stop_I, WALK) self.add_event(te) diff --git a/gtfspy/spreading/spreader.py b/gtfspy/spreading/spreader.py index 7fd1089..8454f39 100644 --- a/gtfspy/spreading/spreader.py +++ b/gtfspy/spreading/spreader.py @@ -15,8 +15,17 @@ class Spreader(object): shortest path spreading dynamics as trips, or "events". """ - def __init__(self, gtfs, start_time_ut, lat, lon, max_duration_ut, min_transfer_time=30, - shapes=True, walk_speed=0.5): + def __init__( + self, + gtfs, + start_time_ut, + lat, + lon, + max_duration_ut, + min_transfer_time=30, + shapes=True, + walk_speed=0.5, + ): """ Parameters ---------- @@ -56,8 +65,10 @@ def spread(self): def _initialize(self): if self._initialized: - raise RuntimeError("This spreader instance has already been initialized: " - "create a new Spreader object for a new run.") + raise RuntimeError( + "This spreader instance has already been initialized: " + "create a new Spreader object for a new run." + ) # events are sorted by arrival time, so in order to use the # heapq, we need to have events coded as # (arrival_time, (from_stop, to_stop)) @@ -66,7 +77,7 @@ def _initialize(self): print("Computing/fetching events") events_df = self.gtfs.get_transit_events(self.start_time_ut, end_time_ut) - all_stops = set(self.gtfs.stops()['stop_I']) + all_stops = set(self.gtfs.stops()["stop_I"]) self._uninfected_stops = all_stops.copy() self._uninfected_stops.remove(start_stop_I) @@ -74,9 +85,7 @@ def _initialize(self): # match stop_I to a more advanced stop object seed_stop = SpreadingStop(start_stop_I, self.min_transfer_time) - self._stop_I_to_spreading_stop = { - start_stop_I: seed_stop - } + self._stop_I_to_spreading_stop = {start_stop_I: seed_stop} for stop in self._uninfected_stops: self._stop_I_to_spreading_stop[stop] = SpreadingStop(stop, self.min_transfer_time) @@ -84,11 +93,9 @@ def _initialize(self): print("intializing heap") self.event_heap = EventHeap(events_df) - start_event = Event(self.start_time_ut - 1, - self.start_time_ut - 1, - start_stop_I, - start_stop_I, - -1) + start_event = Event( + self.start_time_ut - 1, self.start_time_ut - 1, start_stop_I, start_stop_I, -1 + ) seed_stop.visit(start_event) assert len(seed_stop.visit_events) > 0 @@ -100,7 +107,7 @@ def _initialize(self): self.start_time_ut, self.walk_speed, self._uninfected_stops, - self.max_duration_ut + self.max_duration_ut, ) self._initialized = True @@ -109,8 +116,10 @@ def _run(self): Run the actual simulation. """ if self._has_run: - raise RuntimeError("This spreader instance has already been run: " - "create a new Spreader object for a new run.") + raise RuntimeError( + "This spreader instance has already been run: " + "create a new Spreader object for a new run." + ) i = 1 while self.event_heap.size() > 0 and len(self._uninfected_stops) > 0: event = self.event_heap.pop_next_event() @@ -128,10 +137,17 @@ def _run(self): if not already_visited: self._uninfected_stops.remove(event.to_stop_I) print(i, self.event_heap.size()) - transfer_distances = self.gtfs.get_straight_line_transfer_distances(event.to_stop_I) - self.event_heap.add_walk_events_to_heap(transfer_distances, event, self.start_time_ut, - self.walk_speed, self._uninfected_stops, - self.max_duration_ut) + transfer_distances = self.gtfs.get_straight_line_transfer_distances( + event.to_stop_I + ) + self.event_heap.add_walk_events_to_heap( + transfer_distances, + event, + self.start_time_ut, + self.walk_speed, + self._uninfected_stops, + self.max_duration_ut, + ) i += 1 self._has_run = True @@ -151,13 +167,17 @@ def _get_shortest_path_trips(self): if not self._has_run: raise RuntimeError("This spreader object has not run yet. Can not return any trips.") # create new transfer events and add them to the heap (=queue) - inf_times = [[stop_I, el.get_min_visit_time() - self.start_time_ut] - for stop_I, el in self._stop_I_to_spreading_stop.items()] + inf_times = [ + [stop_I, el.get_min_visit_time() - self.start_time_ut] + for stop_I, el in self._stop_I_to_spreading_stop.items() + ] inf_times = numpy.array(inf_times) inf_time_data = pd.DataFrame(inf_times, columns=["stop_I", "inf_time_ut"]) stop_data = self.gtfs.stops() - combined = inf_time_data.merge(stop_data, how='inner', on='stop_I', suffixes=('_infs', '_stops'), copy=True) + combined = inf_time_data.merge( + stop_data, how="inner", on="stop_I", suffixes=("_infs", "_stops"), copy=True + ) trips = [] for stop_I, dest_stop_obj in self._stop_I_to_spreading_stop.items(): @@ -165,11 +185,11 @@ def _get_shortest_path_trips(self): if inf_event is None: continue dep_stop_I = inf_event.from_stop_I - dep_lat = float(combined[combined['stop_I'] == dep_stop_I]['lat'].values) - dep_lon = float(combined[combined['stop_I'] == dep_stop_I]['lon'].values) + dep_lat = float(combined[combined["stop_I"] == dep_stop_I]["lat"].values) + dep_lon = float(combined[combined["stop_I"] == dep_stop_I]["lon"].values) - dest_lat = float(combined[combined['stop_I'] == stop_I]['lat'].values) - dest_lon = float(combined[combined['stop_I'] == stop_I]['lon'].values) + dest_lat = float(combined[combined["stop_I"] == stop_I]["lat"].values) + dest_lon = float(combined[combined["stop_I"] == stop_I]["lon"].values) if inf_event.trip_I == -1: name = "walk" @@ -178,11 +198,11 @@ def _get_shortest_path_trips(self): name, rtype = self.gtfs.get_route_name_and_type_of_tripI(inf_event.trip_I) trip = { - "lats" : [dep_lat, dest_lat], - "lons" : [dep_lon, dest_lon], - "times" : [inf_event.dep_time_ut, inf_event.arr_time_ut], - "name" : name, - "route_type": rtype + "lats": [dep_lat, dest_lat], + "lons": [dep_lon, dest_lon], + "times": [inf_event.dep_time_ut, inf_event.arr_time_ut], + "name": name, + "route_type": rtype, } trips.append(trip) return {"trips": trips} diff --git a/gtfspy/spreading/spreading_stop.py b/gtfspy/spreading/spreading_stop.py index 46050e8..f0ad0eb 100644 --- a/gtfspy/spreading/spreading_stop.py +++ b/gtfspy/spreading/spreading_stop.py @@ -1,5 +1,4 @@ class SpreadingStop: - def __init__(self, stop_I, min_transfer_time): self.stop_I = stop_I self.min_transfer_time = min_transfer_time @@ -10,7 +9,7 @@ def get_min_visit_time(self): Get the earliest visit time of the stop. """ if not self.visit_events: - return float('inf') + return float("inf") else: return min(self.visit_events, key=lambda event: event.arr_time_ut).arr_time_ut @@ -36,7 +35,7 @@ def visit(self, event): if visit is stored, returns True, otherwise False """ to_visit = False - if event.arr_time_ut <= self.min_transfer_time+self.get_min_visit_time(): + if event.arr_time_ut <= self.min_transfer_time + self.get_min_visit_time(): to_visit = True else: for ve in self.visit_events: @@ -47,7 +46,9 @@ def visit(self, event): self.visit_events.append(event) min_time = self.get_min_visit_time() # remove any visits that are 'too old' - self.visit_events = [v for v in self.visit_events if v.arr_time_ut <= min_time+self.min_transfer_time] + self.visit_events = [ + v for v in self.visit_events if v.arr_time_ut <= min_time + self.min_transfer_time + ] return to_visit def has_been_visited(self): @@ -63,7 +64,7 @@ def can_infect(self, event): if not self.has_been_visited(): return False else: - time_sep = event.dep_time_ut-self.get_min_visit_time() + time_sep = event.dep_time_ut - self.get_min_visit_time() # if the gap between the earliest visit_time and current time is # smaller than the min. transfer time, the stop can pass the spreading # forward diff --git a/gtfspy/stats.py b/gtfspy/stats.py index c89a715..b1f9577 100644 --- a/gtfspy/stats.py +++ b/gtfspy/stats.py @@ -1,11 +1,11 @@ from __future__ import unicode_literals import csv -import pandas as pd +import os +import sys import numpy -import sys -import os +import pandas as pd from gtfspy.gtfs import GTFS from gtfspy.util import wgs84_distance @@ -25,12 +25,12 @@ def get_spatial_bounds(gtfs, as_dict=False): max_lat: float """ stats = get_stats(gtfs) - lon_min = stats['lon_min'] - lon_max = stats['lon_max'] - lat_min = stats['lat_min'] - lat_max = stats['lat_max'] + lon_min = stats["lon_min"] + lon_max = stats["lon_max"] + lat_min = stats["lat_min"] + lat_max = stats["lat_max"] if as_dict: - return {'lon_min': lon_min, 'lon_max': lon_max, 'lat_min': lat_min, 'lat_max': lat_max} + return {"lon_min": lon_min, "lon_max": lon_max, "lat_min": lat_min, "lat_max": lat_max} else: return lon_min, lon_max, lat_min, lat_max @@ -38,10 +38,10 @@ def get_spatial_bounds(gtfs, as_dict=False): def get_percentile_stop_bounds(gtfs, percentile): stops = gtfs.get_table("stops") percentile = min(percentile, 100 - percentile) - lat_min = numpy.percentile(stops['lat'].values, percentile) - lat_max = numpy.percentile(stops['lat'].values, 100 - percentile) - lon_min = numpy.percentile(stops['lon'].values, percentile) - lon_max = numpy.percentile(stops['lon'].values, 100 - percentile) + lat_min = numpy.percentile(stops["lat"].values, percentile) + lat_max = numpy.percentile(stops["lat"].values, 100 - percentile) + lon_min = numpy.percentile(stops["lon"].values, percentile) + lon_max = numpy.percentile(stops["lon"].values, 100 - percentile) return lon_min, lon_max, lat_min, lat_max @@ -59,10 +59,11 @@ def get_median_lat_lon_of_stops(gtfs): median_lon : float """ stops = gtfs.get_table("stops") - median_lat = numpy.percentile(stops['lat'].values, 50) - median_lon = numpy.percentile(stops['lon'].values, 50) + median_lat = numpy.percentile(stops["lat"].values, 50) + median_lon = numpy.percentile(stops["lon"].values, 50) return median_lat, median_lon + def get_centroid_of_stops(gtfs): """ Get mean latitude AND longitude of stops @@ -77,8 +78,8 @@ def get_centroid_of_stops(gtfs): mean_lon : float """ stops = gtfs.get_table("stops") - mean_lat = numpy.mean(stops['lat'].values) - mean_lon = numpy.mean(stops['lon'].values) + mean_lat = numpy.mean(stops["lat"].values) + mean_lon = numpy.mean(stops["lon"].values) return mean_lat, mean_lon @@ -97,25 +98,25 @@ def write_stats_as_csv(gtfs, path_to_csv, re_write=False): stats_dict = get_stats(gtfs) # check if file exist if re_write: - os.remove(path_to_csv) - - #if not os.path.isfile(path_to_csv): - # is_new = True - #else: - # is_new = False - + os.remove(path_to_csv) + + # if not os.path.isfile(path_to_csv): + # is_new = True + # else: + # is_new = False + is_new = True - mode = 'r' if os.path.exists(path_to_csv) else 'w+' + mode = "r" if os.path.exists(path_to_csv) else "w+" with open(path_to_csv, mode) as csvfile: for line in csvfile: - if line: - is_new = False - else: - is_new = True - - with open(path_to_csv, 'a') as csvfile: - if (sys.version_info > (3, 0)): - delimiter = u"," + if line: + is_new = False + else: + is_new = True + + with open(path_to_csv, "a") as csvfile: + if sys.version_info > (3, 0): + delimiter = "," else: delimiter = b"," statswriter = csv.writer(csvfile, delimiter=delimiter) @@ -147,18 +148,31 @@ def get_stats(gtfs): """ stats = {} # Basic table counts - for table in ['agencies', 'routes', 'stops', 'stop_times', 'trips', 'calendar', 'shapes', 'calendar_dates', - 'days', 'stop_distances', 'frequencies', 'feed_info', 'transfers']: + for table in [ + "agencies", + "routes", + "stops", + "stop_times", + "trips", + "calendar", + "shapes", + "calendar_dates", + "days", + "stop_distances", + "frequencies", + "feed_info", + "transfers", + ]: stats["n_" + table] = gtfs.get_row_count(table) # Agency names agencies = gtfs.get_table("agencies") - stats["agencies"] = "_".join(agencies['name'].values) + stats["agencies"] = "_".join(agencies["name"].values) # Stop lat/lon range stops = gtfs.get_table("stops") - lats = stops['lat'].values - lons = stops['lon'].values + lats = stops["lat"].values + lons = stops["lon"].values percentiles = [0, 10, 50, 90, 100] try: @@ -184,8 +198,8 @@ def get_stats(gtfs): stats["lon_max"] = lon_max if len(lats) > 0: - stats["height_km"] = wgs84_distance(lat_min, lon_median, lat_max, lon_median) / 1000. - stats["width_km"] = wgs84_distance(lon_min, lat_median, lon_max, lat_median) / 1000. + stats["height_km"] = wgs84_distance(lat_min, lon_median, lat_max, lon_median) / 1000.0 + stats["width_km"] = wgs84_distance(lon_min, lat_median, lon_max, lat_median) / 1000.0 else: stats["height_km"] = None stats["width_km"] = None @@ -203,17 +217,24 @@ def get_stats(gtfs): # Maximum activity day max_activity_date = gtfs.execute_custom_query( - 'SELECT count(*), date ' - 'FROM days ' - 'GROUP BY date ' - 'ORDER BY count(*) DESC, date ' - 'LIMIT 1;').fetchone() + "SELECT count(*), date " + "FROM days " + "GROUP BY date " + "ORDER BY count(*) DESC, date " + "LIMIT 1;" + ).fetchone() if max_activity_date: stats["max_activity_date"] = max_activity_date[1] - max_activity_hour = gtfs.get_cursor().execute( - 'SELECT count(*), arr_time_hour FROM day_stop_times ' - 'WHERE date=? GROUP BY arr_time_hour ' - 'ORDER BY count(*) DESC;', (stats["max_activity_date"],)).fetchone() + max_activity_hour = ( + gtfs.get_cursor() + .execute( + "SELECT count(*), arr_time_hour FROM day_stop_times " + "WHERE date=? GROUP BY arr_time_hour " + "ORDER BY count(*) DESC;", + (stats["max_activity_date"],), + ) + .fetchone() + ) if max_activity_hour: stats["max_activity_hour"] = max_activity_hour[1] else: @@ -221,21 +242,25 @@ def get_stats(gtfs): # Fleet size estimate: considering each line separately if max_activity_date and max_activity_hour: - fleet_size_estimates = _fleet_size_estimate(gtfs, stats['max_activity_hour'], stats['max_activity_date']) + fleet_size_estimates = _fleet_size_estimate( + gtfs, stats["max_activity_hour"], stats["max_activity_date"] + ) stats.update(fleet_size_estimates) # Compute simple distributions of various columns that have a finite range of values. # Commented lines refer to values that are not imported yet, ? - stats['routes__type__dist'] = _distribution(gtfs, 'routes', 'type') + stats["routes__type__dist"] = _distribution(gtfs, "routes", "type") # stats['stop_times__pickup_type__dist'] = _distribution(gtfs, 'stop_times', 'pickup_type') # stats['stop_times__drop_off_type__dist'] = _distribution(gtfs, 'stop_times', 'drop_off_type') # stats['stop_times__timepoint__dist'] = _distribution(gtfs, 'stop_times', 'timepoint') - stats['calendar_dates__exception_type__dist'] = _distribution(gtfs, 'calendar_dates', 'exception_type') - stats['frequencies__exact_times__dist'] = _distribution(gtfs, 'frequencies', 'exact_times') - stats['transfers__transfer_type__dist'] = _distribution(gtfs, 'transfers', 'transfer_type') - stats['agencies__lang__dist'] = _distribution(gtfs, 'agencies', 'lang') - stats['stops__location_type__dist'] = _distribution(gtfs, 'stops', 'location_type') + stats["calendar_dates__exception_type__dist"] = _distribution( + gtfs, "calendar_dates", "exception_type" + ) + stats["frequencies__exact_times__dist"] = _distribution(gtfs, "frequencies", "exact_times") + stats["transfers__transfer_type__dist"] = _distribution(gtfs, "transfers", "transfer_type") + stats["agencies__lang__dist"] = _distribution(gtfs, "agencies", "lang") + stats["stops__location_type__dist"] = _distribution(gtfs, "stops", "location_type") # stats['stops__wheelchair_boarding__dist'] = _distribution(gtfs, 'stops', 'wheelchair_boarding') # stats['trips__wheelchair_accessible__dist'] = _distribution(gtfs, 'trips', 'wheelchair_accessible') # stats['trips__bikes_allowed__dist'] = _distribution(gtfs, 'trips', 'bikes_allowed') @@ -250,10 +275,12 @@ def _distribution(gtfs, table, column): Example return value: '1:5 2:15'""" cur = gtfs.conn.cursor() - cur.execute('SELECT {column}, count(*) ' - 'FROM {table} GROUP BY {column} ' - 'ORDER BY {column}'.format(column=column, table=table)) - return ' '.join('%s:%s' % (t, c) for t, c in cur) + cur.execute( + "SELECT {column}, count(*) " + "FROM {table} GROUP BY {column} " + "ORDER BY {column}".format(column=column, table=table) + ) + return " ".join("%s:%s" % (t, c) for t, c in cur) def _fleet_size_estimate(gtfs, hour, date): @@ -281,44 +308,46 @@ def _fleet_size_estimate(gtfs, hour, date): fleet_size_list = [] cur = gtfs.conn.cursor() rows = cur.execute( - 'SELECT type, max(vehicles) ' - 'FROM (' - 'SELECT type, direction_id, sum(vehicles) as vehicles ' - 'FROM ' - '(' - 'SELECT trips.route_I, trips.direction_id, routes.route_id, name, type, count(*) as vehicles, cycle_time_min ' - 'FROM trips, routes, days, ' - '(' - 'SELECT first_trip.route_I, first_trip.direction_id, first_trip_start_time, first_trip_end_time, ' - 'MIN(start_time_ds) as return_trip_start_time, end_time_ds as return_trip_end_time, ' - '(end_time_ds - first_trip_start_time)/60 as cycle_time_min ' - 'FROM ' - 'trips, ' - '(SELECT route_I, direction_id, MIN(start_time_ds) as first_trip_start_time, ' - 'end_time_ds as first_trip_end_time ' - 'FROM trips, days ' - 'WHERE trips.trip_I=days.trip_I AND start_time_ds >= ? * 3600 ' - 'AND start_time_ds <= (? + 1) * 3600 AND date = ? ' - 'GROUP BY route_I, direction_id) first_trip ' - 'WHERE first_trip.route_I = trips.route_I ' - 'AND first_trip.direction_id != trips.direction_id ' - 'AND start_time_ds >= first_trip_end_time ' - 'GROUP BY trips.route_I, trips.direction_id' - ') return_trip ' - 'WHERE trips.trip_I=days.trip_I AND trips.route_I= routes.route_I ' - 'AND date = ? AND trips.route_I = return_trip.route_I ' - 'AND trips.direction_id = return_trip.direction_id ' - 'AND start_time_ds >= first_trip_start_time ' - 'AND start_time_ds < return_trip_end_time ' - 'GROUP BY trips.route_I, trips.direction_id ' - 'ORDER BY type, name, vehicles desc' - ') cycle_times ' - 'GROUP BY direction_id, type' - ') vehicles_type ' - 'GROUP BY type;', (hour, hour, date, date)) + "SELECT type, max(vehicles) " + "FROM (" + "SELECT type, direction_id, sum(vehicles) as vehicles " + "FROM " + "(" + "SELECT trips.route_I, trips.direction_id, routes.route_id, name, type, count(*) as vehicles, cycle_time_min " + "FROM trips, routes, days, " + "(" + "SELECT first_trip.route_I, first_trip.direction_id, first_trip_start_time, first_trip_end_time, " + "MIN(start_time_ds) as return_trip_start_time, end_time_ds as return_trip_end_time, " + "(end_time_ds - first_trip_start_time)/60 as cycle_time_min " + "FROM " + "trips, " + "(SELECT route_I, direction_id, MIN(start_time_ds) as first_trip_start_time, " + "end_time_ds as first_trip_end_time " + "FROM trips, days " + "WHERE trips.trip_I=days.trip_I AND start_time_ds >= ? * 3600 " + "AND start_time_ds <= (? + 1) * 3600 AND date = ? " + "GROUP BY route_I, direction_id) first_trip " + "WHERE first_trip.route_I = trips.route_I " + "AND first_trip.direction_id != trips.direction_id " + "AND start_time_ds >= first_trip_end_time " + "GROUP BY trips.route_I, trips.direction_id" + ") return_trip " + "WHERE trips.trip_I=days.trip_I AND trips.route_I= routes.route_I " + "AND date = ? AND trips.route_I = return_trip.route_I " + "AND trips.direction_id = return_trip.direction_id " + "AND start_time_ds >= first_trip_start_time " + "AND start_time_ds < return_trip_end_time " + "GROUP BY trips.route_I, trips.direction_id " + "ORDER BY type, name, vehicles desc" + ") cycle_times " + "GROUP BY direction_id, type" + ") vehicles_type " + "GROUP BY type;", + (hour, hour, date, date), + ) for row in rows: - fleet_size_list.append(str(row[0]) + ':' + str(row[1])) - results['fleet_size_route_based'] = " ".join(fleet_size_list) + fleet_size_list.append(str(row[0]) + ":" + str(row[1])) + results["fleet_size_route_based"] = " ".join(fleet_size_list) # Fleet size estimate: maximum number of vehicles in movement fleet_size_list = [] @@ -326,29 +355,31 @@ def _fleet_size_estimate(gtfs, hour, date): if hour: for minute in range(hour * 3600, (hour + 1) * 3600, 60): rows = gtfs.conn.cursor().execute( - 'SELECT type, count(*) ' - 'FROM trips, routes, days ' - 'WHERE trips.route_I = routes.route_I ' - 'AND trips.trip_I=days.trip_I ' - 'AND start_time_ds <= ? ' - 'AND end_time_ds > ? + 60 ' - 'AND date = ? ' - 'GROUP BY type;', - (minute, minute, date)) + "SELECT type, count(*) " + "FROM trips, routes, days " + "WHERE trips.route_I = routes.route_I " + "AND trips.trip_I=days.trip_I " + "AND start_time_ds <= ? " + "AND end_time_ds > ? + 60 " + "AND date = ? " + "GROUP BY type;", + (minute, minute, date), + ) for row in rows: if fleet_size_dict.get(row[0], 0) < row[1]: fleet_size_dict[row[0]] = row[1] for key in fleet_size_dict.keys(): - fleet_size_list.append(str(key) + ':' + str(fleet_size_dict[key])) - results["fleet_size_max_movement"] = ' '.join(fleet_size_list) + fleet_size_list.append(str(key) + ":" + str(fleet_size_dict[key])) + results["fleet_size_max_movement"] = " ".join(fleet_size_list) return results def _n_gtfs_sources(gtfs): n_gtfs_sources = gtfs.execute_custom_query( - "SELECT value FROM metadata WHERE key = 'n_gtfs_sources';").fetchone() + "SELECT value FROM metadata WHERE key = 'n_gtfs_sources';" + ).fetchone() if not n_gtfs_sources: n_gtfs_sources = [1] return n_gtfs_sources @@ -376,9 +407,15 @@ def _feed_calendar_span(gtfs, stats): feed_key = "feed_" + str(i) + "_" start_key = feed_key + "calendar_start" end_key = feed_key + "calendar_end" - calendar_span = gtfs.conn.cursor().execute( - 'SELECT min(date), max(date) FROM trips, days ' - 'WHERE trips.trip_I = days.trip_I AND trip_id LIKE ?;', (feed_key + '%',)).fetchone() + calendar_span = ( + gtfs.conn.cursor() + .execute( + "SELECT min(date), max(date) FROM trips, days " + "WHERE trips.trip_I = days.trip_I AND trip_id LIKE ?;", + (feed_key + "%",), + ) + .fetchone() + ) stats[start_key] = calendar_span[0] stats[end_key] = calendar_span[1] @@ -431,37 +468,38 @@ def trip_stats(gtfs, results_by_mode=False): conn = gtfs.conn conn.create_function("find_distance", 4, wgs84_distance) - cur = conn.cursor() # this query calculates the distance and travel time for each complete trip # stop_data_df = pd.read_sql_query(query, self.conn, params=params) - query = 'SELECT ' \ - 'startstop.trip_I AS trip_I, ' \ - 'type, ' \ - 'sum(CAST(find_distance(startstop.lat, startstop.lon, endstop.lat, endstop.lon) AS INT)) as total_distance, ' \ - 'sum(endstop.arr_time_ds - startstop.arr_time_ds) as total_traveltime ' \ - 'FROM ' \ - '(SELECT * FROM stop_times, stops WHERE stop_times.stop_I = stops.stop_I) startstop, ' \ - '(SELECT * FROM stop_times, stops WHERE stop_times.stop_I = stops.stop_I) endstop, ' \ - 'trips, ' \ - 'routes ' \ - 'WHERE ' \ - 'startstop.trip_I = endstop.trip_I ' \ - 'AND startstop.seq + 1 = endstop.seq ' \ - 'AND startstop.trip_I = trips.trip_I ' \ - 'AND trips.route_I = routes.route_I ' \ - 'GROUP BY startstop.trip_I' + query = ( + "SELECT " + "startstop.trip_I AS trip_I, " + "type, " + "sum(CAST(find_distance(startstop.lat, startstop.lon, endstop.lat, endstop.lon) AS INT)) as total_distance, " + "sum(endstop.arr_time_ds - startstop.arr_time_ds) as total_traveltime " + "FROM " + "(SELECT * FROM stop_times, stops WHERE stop_times.stop_I = stops.stop_I) startstop, " + "(SELECT * FROM stop_times, stops WHERE stop_times.stop_I = stops.stop_I) endstop, " + "trips, " + "routes " + "WHERE " + "startstop.trip_I = endstop.trip_I " + "AND startstop.seq + 1 = endstop.seq " + "AND startstop.trip_I = trips.trip_I " + "AND trips.route_I = routes.route_I " + "GROUP BY startstop.trip_I" + ) q_result = pd.read_sql_query(query, conn) - q_result['avg_speed_kmh'] = 3.6 * q_result['total_distance'] / q_result['total_traveltime'] - q_result['total_distance'] = q_result['total_distance'] / 1000 - q_result['total_traveltime'] = q_result['total_traveltime'] / 60 - q_result = q_result.loc[q_result['avg_speed_kmh'] != float("inf")] + q_result["avg_speed_kmh"] = 3.6 * q_result["total_distance"] / q_result["total_traveltime"] + q_result["total_distance"] = q_result["total_distance"] / 1000 + q_result["total_traveltime"] = q_result["total_traveltime"] / 60 + q_result = q_result.loc[q_result["avg_speed_kmh"] != float("inf")] if results_by_mode: q_results = {} - for type in q_result['type'].unique().tolist(): - q_results[type] = q_result.loc[q_result['type'] == type] + for type in q_result["type"].unique().tolist(): # noqa: A003, A001 + q_results[type] = q_result.loc[q_result["type"] == type] # noqa: A003 return q_results else: return q_result @@ -471,33 +509,34 @@ def get_section_stats(gtfs, results_by_mode=False): conn = gtfs.conn conn.create_function("find_distance", 4, wgs84_distance) - cur = conn.cursor() # this query calculates the distance and travel time for each stop to stop section for each trip # stop_data_df = pd.read_sql_query(query, self.conn, params=params) - query = 'SELECT type, from_stop_I, to_stop_I, distance, min(travel_time) AS min_time, max(travel_time) AS max_time, avg(travel_time) AS mean_time ' \ - 'FROM ' \ - '(SELECT q1.trip_I, type, q1.stop_I as from_stop_I, q2.stop_I as to_stop_I, ' \ - 'CAST(find_distance(q1.lat, q1.lon, q2.lat, q2.lon) AS INT) as distance, ' \ - 'q2.arr_time_ds - q1.arr_time_ds as travel_time, ' \ - 'q1.lat AS from_lat, q1.lon AS from_lon, q2.lat AS to_lat, q2.lon AS to_lon ' \ - 'FROM ' \ - '(SELECT * FROM stop_times, stops WHERE stop_times.stop_I = stops.stop_I) q1, ' \ - '(SELECT * FROM stop_times, stops WHERE stop_times.stop_I = stops.stop_I) q2, ' \ - 'trips, ' \ - 'routes ' \ - 'WHERE q1.trip_I = q2.trip_I ' \ - 'AND q1.seq + 1 = q2.seq ' \ - 'AND q1.trip_I = trips.trip_I ' \ - 'AND trips.route_I = routes.route_I) sq1 ' \ - 'GROUP BY to_stop_I, from_stop_I, type ' + query = ( + "SELECT type, from_stop_I, to_stop_I, distance, min(travel_time) AS min_time, max(travel_time) AS max_time, avg(travel_time) AS mean_time " + "FROM " + "(SELECT q1.trip_I, type, q1.stop_I as from_stop_I, q2.stop_I as to_stop_I, " + "CAST(find_distance(q1.lat, q1.lon, q2.lat, q2.lon) AS INT) as distance, " + "q2.arr_time_ds - q1.arr_time_ds as travel_time, " + "q1.lat AS from_lat, q1.lon AS from_lon, q2.lat AS to_lat, q2.lon AS to_lon " + "FROM " + "(SELECT * FROM stop_times, stops WHERE stop_times.stop_I = stops.stop_I) q1, " + "(SELECT * FROM stop_times, stops WHERE stop_times.stop_I = stops.stop_I) q2, " + "trips, " + "routes " + "WHERE q1.trip_I = q2.trip_I " + "AND q1.seq + 1 = q2.seq " + "AND q1.trip_I = trips.trip_I " + "AND trips.route_I = routes.route_I) sq1 " + "GROUP BY to_stop_I, from_stop_I, type " + ) q_result = pd.read_sql_query(query, conn) if results_by_mode: q_results = {} - for type in q_result['type'].unique().tolist(): - q_results[type] = q_result.loc[q_result['type'] == type] + for type in q_result["type"].unique().tolist(): # noqa: A003, A001 + q_results[type] = q_result.loc[q_result["type"] == type] # noqa: A003 return q_results else: return q_result @@ -528,8 +567,9 @@ def route_frequencies(gtfs, results_by_mode=False): " GROUP by route_I, trip_I)" " GROUP BY route_I) as f" " ON f.route_I = r.route_I" - " ORDER BY frequency DESC".format(day=day)) - + " ORDER BY frequency DESC".format(day=day) + ) + return pd.DataFrame(gtfs.execute_custom_query_pandas(query)) @@ -550,55 +590,65 @@ def hourly_frequencies(gtfs, st, et, route_type): ------- numeric pandas.DataFrame with columns stop_I, lat, lon, frequency - """ - timeframe = et-st - hours = timeframe/ 3600 + """ + timeframe = et - st + hours = timeframe / 3600 day = gtfs.get_suitable_date_for_daily_extract() stops = gtfs.get_stops_for_route_type(route_type).T.drop_duplicates().T - query = ("SELECT * FROM stops as x" - " JOIN" - " (SELECT * , COUNT(*)/{h} as frequency" - " FROM stop_times, days" - " WHERE stop_times.trip_I = days.trip_I" - " AND dep_time_ds > {st}" - " AND dep_time_ds < {et}" - " AND date = '{day}'" - " GROUP BY stop_I) as y" - " ON y.stop_I = x.stop_I".format(h=hours, st=st, et=et, day=day)) + query = ( + "SELECT * FROM stops as x" + " JOIN" + " (SELECT * , COUNT(*)/{h} as frequency" + " FROM stop_times, days" + " WHERE stop_times.trip_I = days.trip_I" + " AND dep_time_ds > {st}" + " AND dep_time_ds < {et}" + " AND date = '{day}'" + " GROUP BY stop_I) as y" + " ON y.stop_I = x.stop_I".format(h=hours, st=st, et=et, day=day) + ) try: trips_frequency = gtfs.execute_custom_query_pandas(query).T.drop_duplicates().T - df = pd.merge(stops[['stop_I', 'lat', 'lon']], trips_frequency[['stop_I', 'frequency']], - on='stop_I', how='inner') + df = pd.merge( + stops[["stop_I", "lat", "lon"]], + trips_frequency[["stop_I", "frequency"]], + on="stop_I", + how="inner", + ) return df.apply(pd.to_numeric) except: raise ValueError("Maybe too short time frame!") def frequencies_by_generated_route(gtfs, st, et, day=None): - timeframe = et-st - hours = timeframe/3600 + timeframe = et - st + hours = timeframe / 3600 if not day: day = gtfs.get_suitable_date_for_daily_extract() - query = """SELECT count(*)/{h} AS frequency, count(*) AS n_trips, route, type FROM + query = """SELECT count(*)/{h} AS frequency, count(*) AS n_trips, route, type FROM (SELECT trip_I, group_concat(stop_I) AS route, name, type FROM (SELECT * FROM stop_times, days, trips, routes - WHERE stop_times.trip_I = days.trip_I AND stop_times.trip_I = trips.trip_I AND routes.route_I = trips.route_I AND + WHERE stop_times.trip_I = days.trip_I AND stop_times.trip_I = trips.trip_I AND routes.route_I = trips.route_I AND days.date = '{day}' AND start_time_ds >= {st} AND start_time_ds < {et} ORDER BY trip_I, seq) q1 GROUP BY trip_I) q2 - GROUP BY route""".format(h=hours, st=st, et=et, day=day) + GROUP BY route""".format( + h=hours, st=st, et=et, day=day + ) df = gtfs.execute_custom_query_pandas(query) return df def departure_stops(gtfs, st, et): day = gtfs.get_suitable_date_for_daily_extract() - query = """select stop_I, count(*) as n_departures from + query = """select stop_I, count(*) as n_departures from (select min(seq), * from stop_times, days, trips where stop_times.trip_I = days.trip_I and stop_times.trip_I = trips.trip_I and days.date = '{day}' and start_time_ds >= {st} and start_time_ds < {et} group by stop_times.trip_I) q1 - group by stop_I""".format(st=st, et=et, day=day) + group by stop_I""".format( + st=st, et=et, day=day + ) df = gtfs.execute_custom_query_pandas(query) df = gtfs.add_coordinates_to_df(df) return df @@ -610,16 +660,18 @@ def get_vehicle_hours_by_type(gtfs, route_type): """ day = gtfs.get_suitable_date_for_daily_extract() - query = (" SELECT * , SUM(end_time_ds - start_time_ds)/3600 as vehicle_hours_type" - " FROM" - " (SELECT * FROM day_trips as q1" - " INNER JOIN" - " (SELECT route_I, type FROM routes) as q2" - " ON q1.route_I = q2.route_I" - " WHERE type = {route_type}" - " AND date = '{day}')".format(day=day, route_type=route_type)) + query = ( + " SELECT * , SUM(end_time_ds - start_time_ds)/3600 as vehicle_hours_type" + " FROM" + " (SELECT * FROM day_trips as q1" + " INNER JOIN" + " (SELECT route_I, type FROM routes) as q2" + " ON q1.route_I = q2.route_I" + " WHERE type = {route_type}" + " AND date = '{day}')".format(day=day, route_type=route_type) + ) df = gtfs.execute_custom_query_pandas(query) - return df['vehicle_hours_type'].item() + return df["vehicle_hours_type"].item() def trips_frequencies(gtfs): @@ -631,8 +683,10 @@ def trips_frequencies(gtfs): " (SELECT * FROM stop_times) q1," " (SELECT * FROM stop_times) q2" " WHERE q1.seq+1=q2.seq AND q1.trip_I=q2.trip_I" - " GROUP BY from_stop_I, to_stop_I") - return(gtfs.execute_custom_query_pandas(query)) + " GROUP BY from_stop_I, to_stop_I" + ) + return gtfs.execute_custom_query_pandas(query) + # def route_circuity(): # pass diff --git a/gtfspy/test/test_exports.py b/gtfspy/test/test_exports.py index 155fcc1..33d1d40 100644 --- a/gtfspy/test/test_exports.py +++ b/gtfspy/test/test_exports.py @@ -17,7 +17,6 @@ class ExportsTest(unittest.TestCase): - @classmethod def setUpClass(cls): """ This method is run once before executing any tests""" @@ -29,7 +28,9 @@ def setUp(self): """This method is run once before _each_ test method is executed""" self.gtfs_source_dir = self.__class__.gtfs_source_dir self.gtfs = self.__class__.G - self.extract_output_dir = os.path.join(self.gtfs_source_dir, "../", "test_gtfspy_extracts_8211231/") + self.extract_output_dir = os.path.join( + self.gtfs_source_dir, "../", "test_gtfspy_extracts_8211231/" + ) if not os.path.exists(self.extract_output_dir): makedirs(self.extract_output_dir) @@ -45,10 +46,12 @@ def test_walk_network(self): self.assertIn("d", data_dict) self.assertGreaterEqual(data_dict["d"], 0) threshold = 670 - walk_net = networks.walk_transfer_stop_to_stop_network(self.gtfs, max_link_distance=threshold) + walk_net = networks.walk_transfer_stop_to_stop_network( + self.gtfs, max_link_distance=threshold + ) self.assertEqual(len(walk_net.edges()), 2) for form_node, to_node, data_dict in walk_net.edges(data=True): - self.assertLess(data_dict['d'], threshold) + self.assertLess(data_dict["d"], threshold) def test_write_stop_to_stop_networks(self): exports.write_static_networks(self.gtfs, self.extract_output_dir) @@ -63,9 +66,9 @@ def test_write_combined_stop_to_stop_networks(self): def test_stop_to_stop_network_by_route_type(self): # test that distance works - nxGraph = networks.stop_to_stop_network_for_route_type(self.gtfs, - BUS, - link_attributes=ALL_STOP_TO_STOP_LINK_ATTRIBUTES) + nxGraph = networks.stop_to_stop_network_for_route_type( + self.gtfs, BUS, link_attributes=ALL_STOP_TO_STOP_LINK_ATTRIBUTES + ) self.assertTrue(isinstance(nxGraph, networkx.DiGraph), type(nxGraph)) nodes = nxGraph.nodes(data=True) self.assertGreater(len(nodes), 0) @@ -81,18 +84,26 @@ def test_stop_to_stop_network_by_route_type(self): at_least_one_shape_distance = False for from_I, to_I, linkData in edges: - ds = linkData['distance_shape'] - self.assertTrue(isinstance(ds, int) or (ds is None), - "distance_shape should be either int or None (in case shapes are not available)") + ds = linkData["distance_shape"] + self.assertTrue( + isinstance(ds, int) or (ds is None), + "distance_shape should be either int or None (in case shapes are not available)", + ) if isinstance(ds, int): at_least_one_shape_distance = True - self.assertLessEqual(linkData['duration_min'], linkData["duration_avg"]) - self.assertLessEqual(linkData['duration_avg'], linkData["duration_max"]) - self.assertLessEqual(linkData['duration_median'], linkData["duration_max"]) - self.assertGreaterEqual(linkData['duration_median'], linkData["duration_min"]) - self.assertTrue(isinstance(linkData['d'], int), "straight line distance should always exist and be an int") - self.assertGreaterEqual(linkData['d'], 0, - "straight line distance should be always greater than or equal to 0 (?)") + self.assertLessEqual(linkData["duration_min"], linkData["duration_avg"]) + self.assertLessEqual(linkData["duration_avg"], linkData["duration_max"]) + self.assertLessEqual(linkData["duration_median"], linkData["duration_max"]) + self.assertGreaterEqual(linkData["duration_median"], linkData["duration_min"]) + self.assertTrue( + isinstance(linkData["d"], int), + "straight line distance should always exist and be an int", + ) + self.assertGreaterEqual( + linkData["d"], + 0, + "straight line distance should be always greater than or equal to 0 (?)", + ) n_veh = linkData["n_vehicles"] route_ids = linkData["route_I_counts"] route_ids_sum = sum([count for route_type, count in route_ids.items()]) @@ -115,8 +126,14 @@ def test_write_temporal_network(self): exports.write_temporal_network(self.gtfs, path, None, None) self.assertTrue(os.path.exists(path)) df = pandas.read_csv(path) - columns_should_exist = ["dep_time_ut", "arr_time_ut", "from_stop_I", "to_stop_I", - "route_type", "trip_I"] + columns_should_exist = [ + "dep_time_ut", + "arr_time_ut", + "from_stop_I", + "to_stop_I", + "route_type", + "trip_I", + ] for col in columns_should_exist: self.assertIn(col, df.columns.values) @@ -125,68 +142,105 @@ def test_write_temporal_networks_by_route_type(self): self.assertTrue(os.path.exists(os.path.join(self.extract_output_dir + "bus.tnet"))) def test_write_gtfs_agencies(self): - required_columns = 'agency_id,agency_name,agency_url,agency_timezone,agency_phone,agency_lang'.split(",") - optional_columns = ['agency_lang', 'agency_phone', 'agency_fare_url', 'agency_email'] - self.__test_write_gtfs_table(exports._write_gtfs_agencies, required_columns, optional_columns) + required_columns = "agency_id,agency_name,agency_url,agency_timezone,agency_phone,agency_lang".split( + "," + ) + optional_columns = ["agency_lang", "agency_phone", "agency_fare_url", "agency_email"] + self.__test_write_gtfs_table( + exports._write_gtfs_agencies, required_columns, optional_columns + ) def test_write_gtfs_stops(self): - required_columns = 'stop_id,stop_name,stop_desc,stop_lat,stop_lon'.split(",") - optional_columns = ['stop_code', 'stop_desc', 'zone_id', 'stop_url', 'location_type', 'parent_station', - 'stop_timezone', 'wheelchair_boarding'] + required_columns = "stop_id,stop_name,stop_desc,stop_lat,stop_lon".split(",") + optional_columns = [ + "stop_code", + "stop_desc", + "zone_id", + "stop_url", + "location_type", + "parent_station", + "stop_timezone", + "wheelchair_boarding", + ] self.__test_write_gtfs_table(exports._write_gtfs_stops, required_columns, optional_columns) def test_write_gtfs_routes(self): - required_columns = 'route_id,agency_id,route_short_name,route_long_name,route_desc,route_type'.split(",") - optional_columns = ['route_desc', 'route_url', 'route_color', 'route_text_color'] + required_columns = "route_id,agency_id,route_short_name,route_long_name,route_desc,route_type".split( + "," + ) + optional_columns = ["route_desc", "route_url", "route_color", "route_text_color"] self.__test_write_gtfs_table(exports._write_gtfs_routes, required_columns, optional_columns) def test_write_gtfs_trips(self): - required_columns = 'route_id,service_id,trip_id'.split(",") - optional_columns = ['trip_headsign', 'trip_short_name', 'direction_id', 'block_id', - 'shape_id', 'wheelchair_accessible', 'bikes_allowed'] + required_columns = "route_id,service_id,trip_id".split(",") + optional_columns = [ + "trip_headsign", + "trip_short_name", + "direction_id", + "block_id", + "shape_id", + "wheelchair_accessible", + "bikes_allowed", + ] self.__test_write_gtfs_table(exports._write_gtfs_trips, required_columns, optional_columns) def test_write_gtfs_stop_times(self): - required_columns = 'trip_id,arrival_time,departure_time,stop_id,stop_sequence'.split(",") - optional_columns = ['stop_headsign', 'pickup_type', 'drop_off_type', 'shape_dist_traveled', 'timepoint'] - self.__test_write_gtfs_table(exports._write_gtfs_stop_times, required_columns, optional_columns) + required_columns = "trip_id,arrival_time,departure_time,stop_id,stop_sequence".split(",") + optional_columns = [ + "stop_headsign", + "pickup_type", + "drop_off_type", + "shape_dist_traveled", + "timepoint", + ] + self.__test_write_gtfs_table( + exports._write_gtfs_stop_times, required_columns, optional_columns + ) def test_write_gtfs_calendar(self): - required_columns = 'service_id,monday,tuesday,wednesday,thursday,friday,saturday,sunday,' \ - 'start_date,end_date'.split(",") + required_columns = ( + "service_id,monday,tuesday,wednesday,thursday,friday,saturday,sunday," + "start_date,end_date".split(",") + ) self.__test_write_gtfs_table(exports._write_gtfs_calendar, required_columns, []) in_memory_file = io.StringIO() exports._write_gtfs_calendar(self.gtfs, in_memory_file) in_memory_file.seek(0) df = pandas.read_csv(in_memory_file) - self.assertTrue("-" not in str(df['start_date'][0])) - self.assertTrue("-" not in str(df['end_date'][0])) - self.assertTrue(len(str(df['start_date'][0])) == 8) - self.assertTrue(len(str(df['start_date'][0])) == 8) + self.assertTrue("-" not in str(df["start_date"][0])) + self.assertTrue("-" not in str(df["end_date"][0])) + self.assertTrue(len(str(df["start_date"][0])) == 8) + self.assertTrue(len(str(df["start_date"][0])) == 8) def test_write_gtfs_calendar_dates(self): - required_columns = 'service_id,date,exception_type'.split(",") + required_columns = "service_id,date,exception_type".split(",") self.__test_write_gtfs_table(exports._write_gtfs_calendar_dates, required_columns, []) def test_write_gtfs_shapes(self): - required_columns = 'shape_id,shape_pt_lat,shape_pt_lon,shape_pt_sequence'.split(",") - optional_columns = ['shape_dist_traveled'] + required_columns = "shape_id,shape_pt_lat,shape_pt_lon,shape_pt_sequence".split(",") + optional_columns = ["shape_dist_traveled"] self.__test_write_gtfs_table(exports._write_gtfs_shapes, required_columns, optional_columns) def test_write_gtfs_transfers(self): - required_columns = 'from_stop_id,to_stop_id,transfer_type'.split(",") - optional_columns = ['min_transfer_time'] - self.__test_write_gtfs_table(exports._write_gtfs_transfers, required_columns, optional_columns) + required_columns = "from_stop_id,to_stop_id,transfer_type".split(",") + optional_columns = ["min_transfer_time"] + self.__test_write_gtfs_table( + exports._write_gtfs_transfers, required_columns, optional_columns + ) def test_write_gtfs_stop_distances(self): - required_columns = 'from_stop_id,to_stop_id,d,d_walk'.split(",") + required_columns = "from_stop_id,to_stop_id,d,d_walk".split(",") optional_columns = [] - self.__test_write_gtfs_table(exports._write_gtfs_stop_distances, required_columns, optional_columns) + self.__test_write_gtfs_table( + exports._write_gtfs_stop_distances, required_columns, optional_columns + ) def test_write_feed_info(self): - required_columns = 'feed_publisher_name,feed_publisher_url,feed_lang'.split(",") - columns_not_present = ['feed_start_date', 'feed_end_date', 'feed_version', 'feed_id'] - self.__test_write_gtfs_table(exports._write_gtfs_feed_info, required_columns, columns_not_present) + required_columns = "feed_publisher_name,feed_publisher_url,feed_lang".split(",") + columns_not_present = ["feed_start_date", "feed_end_date", "feed_version", "feed_id"] + self.__test_write_gtfs_table( + exports._write_gtfs_feed_info, required_columns, columns_not_present + ) def __test_write_gtfs_table(self, table_write_func, required_columns, optional_columns): """ @@ -214,6 +268,7 @@ def test_write_gtfs(self): # A simple import-output-import test" for ending in ["", ".zip"]: from gtfspy.import_gtfs import import_gtfs + UUID = "36167f3012fe11e793ae92361f002671" sqlite_fname = "test_" + UUID + ".sqlite" test_output_dir = "./test_output_dir_" + UUID @@ -226,7 +281,7 @@ def test_write_gtfs(self): exports.write_gtfs(self.gtfs, test_output_dir + ending) self.assertTrue(os.path.exists(test_output_dir + ending)) try: - G = import_gtfs(test_output_dir + ending, os.path.join(sqlite_fname)) + import_gtfs(test_output_dir + ending, os.path.join(sqlite_fname)) self.assertTrue(os.path.exists(sqlite_fname)) finally: os.remove(sqlite_fname) @@ -245,7 +300,7 @@ def test_write_stops_geojson(self): self.assertTrue(gjson.is_valid) - gjson_properties = gjson['features'][0]['properties'] + gjson_properties = gjson["features"][0]["properties"] self.assertIn("name", gjson_properties.keys()) self.assertIn("stop_I", gjson_properties.keys()) @@ -259,7 +314,7 @@ def test_write_sections_geojson(self): self.assertTrue(gjson.is_valid) - gjson_properties = gjson['features'][0]['properties'] + gjson_properties = gjson["features"][0]["properties"] self.assertIn("from_stop_I", gjson_properties.keys()) self.assertIn("to_stop_I", gjson_properties.keys()) @@ -277,15 +332,14 @@ def test_write_routes_geojson(self): self.assertTrue(gjson.is_valid) - gjson_properties = gjson['features'][0]['properties'] + gjson_properties = gjson["features"][0]["properties"] self.assertIn("route_type", gjson_properties.keys()) self.assertIn("route_I", gjson_properties.keys()) self.assertIn("route_name", gjson_properties.keys()) - - # def test_clustered_stops_network(self): + # orig_net = networks.undirected_stop_to_stop_network_with_route_information(self.gtfs) # aggregate_net = networks.aggregate_route_network(self.gtfs, 1000) # self.assertGreater(len(orig_net.nodes()), len(aggregate_net.nodes())) diff --git a/gtfspy/test/test_filter.py b/gtfspy/test/test_filter.py index 5cfdd40..d55ceb3 100644 --- a/gtfspy/test/test_filter.py +++ b/gtfspy/test/test_filter.py @@ -13,7 +13,6 @@ class TestGTFSFilter(unittest.TestCase): - def setUp(self): self.gtfs_source_dir = os.path.join(os.path.dirname(__file__), "test_data") self.gtfs_source_dir_filter_test = os.path.join(self.gtfs_source_dir, "filter_test_feed/") @@ -31,12 +30,17 @@ def setUp(self): conn = sqlite3.connect(self.fname) import_gtfs(self.gtfs_source_dir, conn, preserve_connection=True, print_progress=False) conn_filter = sqlite3.connect(self.fname_filter) - import_gtfs(self.gtfs_source_dir_filter_test, conn_filter, preserve_connection=True, print_progress=False) + import_gtfs( + self.gtfs_source_dir_filter_test, + conn_filter, + preserve_connection=True, + print_progress=False, + ) self.G = GTFS(conn) self.G_filter_test = GTFS(conn_filter) - self.hash_orig = hashlib.md5(open(self.fname, 'rb').read()).hexdigest() + self.hash_orig = hashlib.md5(open(self.fname, "rb").read()).hexdigest() def _remove_temporary_files(self): for fn in [self.fname, self.fname_copy, self.fname_filter]: @@ -51,7 +55,7 @@ def test_copy(self): FilterExtract(self.G, self.fname_copy, update_metadata=False).create_filtered_copy() # check that the copying has been properly performed: - hash_copy = hashlib.md5(open(self.fname_copy, 'rb').read()).hexdigest() + hash_copy = hashlib.md5(open(self.fname_copy, "rb").read()).hexdigest() self.assertTrue(os.path.exists(self.fname_copy)) self.assertEqual(self.hash_orig, hash_copy) @@ -59,39 +63,50 @@ def test_filter_change_metadata(self): # A simple test that changing update_metadata to True, does update some stuff: FilterExtract(self.G, self.fname_copy, update_metadata=True).create_filtered_copy() # check that the copying has been properly performed: - hash_orig = hashlib.md5(open(self.fname, 'rb').read()).hexdigest() - hash_copy = hashlib.md5(open(self.fname_copy, 'rb').read()).hexdigest() + hash_orig = hashlib.md5(open(self.fname, "rb").read()).hexdigest() + hash_copy = hashlib.md5(open(self.fname_copy, "rb").read()).hexdigest() self.assertTrue(os.path.exists(self.fname_copy)) self.assertNotEqual(hash_orig, hash_copy) os.remove(self.fname_copy) def test_filter_by_agency(self): - FilterExtract(self.G, self.fname_copy, agency_ids_to_preserve=['DTA']).create_filtered_copy() - hash_copy = hashlib.md5(open(self.fname_copy, 'rb').read()).hexdigest() + FilterExtract( + self.G, self.fname_copy, agency_ids_to_preserve=["DTA"] + ).create_filtered_copy() + hash_copy = hashlib.md5(open(self.fname_copy, "rb").read()).hexdigest() self.assertNotEqual(self.hash_orig, hash_copy) G_copy = GTFS(self.fname_copy) agency_table = G_copy.get_table("agencies") - assert "EXA" not in agency_table['agency_id'].values, "EXA agency should not be preserved" - assert "DTA" in agency_table['agency_id'].values, "DTA agency should be preserved" + assert "EXA" not in agency_table["agency_id"].values, "EXA agency should not be preserved" + assert "DTA" in agency_table["agency_id"].values, "DTA agency should be preserved" routes_table = G_copy.get_table("routes") - assert "EXR1" not in routes_table['route_id'].values, "EXR1 route_id should not be preserved" - assert "AB" in routes_table['route_id'].values, "AB route_id should be preserved" + assert ( + "EXR1" not in routes_table["route_id"].values + ), "EXR1 route_id should not be preserved" + assert "AB" in routes_table["route_id"].values, "AB route_id should be preserved" trips_table = G_copy.get_table("trips") - assert "EXT1" not in trips_table['trip_id'].values, "EXR1 route_id should not be preserved" - assert "AB1" in trips_table['trip_id'].values, "AB1 route_id should be preserved" + assert "EXT1" not in trips_table["trip_id"].values, "EXR1 route_id should not be preserved" + assert "AB1" in trips_table["trip_id"].values, "AB1 route_id should be preserved" calendar_table = G_copy.get_table("calendar") - assert "FULLW" in calendar_table['service_id'].values, "FULLW service_id should be preserved" + assert ( + "FULLW" in calendar_table["service_id"].values + ), "FULLW service_id should be preserved" # stop_times stop_times_table = G_copy.get_table("stop_times") # 01:23:45 corresponds to 3600 + (32 * 60) + 45 [in day seconds] - assert 3600 + (32 * 60) + 45 not in stop_times_table['arr_time'] + assert 3600 + (32 * 60) + 45 not in stop_times_table["arr_time"] os.remove(self.fname_copy) - def test_filter_by_start_and_end_full_range(self): # untested tables with filtering: stops, shapes # test filtering by start and end time, copy full range - FilterExtract(self.G, self.fname_copy, start_date=u"2007-01-01", end_date=u"2011-01-01", update_metadata=False).create_filtered_copy() + FilterExtract( + self.G, + self.fname_copy, + start_date="2007-01-01", + end_date="2011-01-01", + update_metadata=False, + ).create_filtered_copy() G_copy = GTFS(self.fname_copy) dsut_end = G_copy.get_day_start_ut("2010-12-31") dsut_to_trip_I = G_copy.get_tripIs_within_range_by_dsut(dsut_end, dsut_end + 24 * 3600) @@ -100,9 +115,11 @@ def test_filter_by_start_and_end_full_range(self): def test_filter_end_date_not_included(self): # the end date should not be included: - FilterExtract(self.G, self.fname_copy, start_date="2007-01-02", end_date="2010-12-31").create_filtered_copy() + FilterExtract( + self.G, self.fname_copy, start_date="2007-01-02", end_date="2010-12-31" + ).create_filtered_copy() - hash_copy = hashlib.md5(open(self.fname_copy, 'rb').read()).hexdigest() + hash_copy = hashlib.md5(open(self.fname_copy, "rb").read()).hexdigest() self.assertNotEqual(self.hash_orig, hash_copy) G_copy = GTFS(self.fname_copy) dsut_end = G_copy.get_day_start_ut("2010-12-31") @@ -110,40 +127,58 @@ def test_filter_end_date_not_included(self): self.assertEqual(len(dsut_to_trip_I), 0) calendar_copy = G_copy.get_table("calendar") - max_date_calendar = max([datetime.datetime.strptime(el, "%Y-%m-%d") - for el in calendar_copy["end_date"].values]) - min_date_calendar = max([datetime.datetime.strptime(el, "%Y-%m-%d") - for el in calendar_copy["start_date"].values]) + max_date_calendar = max( + [datetime.datetime.strptime(el, "%Y-%m-%d") for el in calendar_copy["end_date"].values] + ) + min_date_calendar = max( + [ + datetime.datetime.strptime(el, "%Y-%m-%d") + for el in calendar_copy["start_date"].values + ] + ) end_date_not_included = datetime.datetime.strptime("2010-12-31", "%Y-%m-%d") start_date_not_included = datetime.datetime.strptime("2007-01-01", "%Y-%m-%d") - self.assertLess(max_date_calendar, end_date_not_included, msg="the last date should not be included in calendar") + self.assertLess( + max_date_calendar, + end_date_not_included, + msg="the last date should not be included in calendar", + ) self.assertLess(start_date_not_included, min_date_calendar) os.remove(self.fname_copy) def test_filter_spatially(self): # test that the db is split by a given spatial boundary - FilterExtract(self.G, self.fname_copy, buffer_lat=36.914893, buffer_lon=-116.76821, buffer_distance_km=50).create_filtered_copy() + FilterExtract( + self.G, + self.fname_copy, + buffer_lat=36.914893, + buffer_lon=-116.76821, + buffer_distance_km=50, + ).create_filtered_copy() G_copy = GTFS(self.fname_copy) stops_table = G_copy.get_table("stops") - self.assertNotIn("FUR_CREEK_RES", stops_table['stop_id'].values) - self.assertIn("AMV", stops_table['stop_id'].values) - self.assertEqual(len(stops_table['stop_id'].values), 8) + self.assertNotIn("FUR_CREEK_RES", stops_table["stop_id"].values) + self.assertIn("AMV", stops_table["stop_id"].values) + self.assertEqual(len(stops_table["stop_id"].values), 8) conn_copy = sqlite3.connect(self.fname_copy) - stop_ids_df = pandas.read_sql('SELECT stop_id from stop_times ' - 'left join stops ' - 'on stops.stop_I = stop_times.stop_I', conn_copy) + stop_ids_df = pandas.read_sql( + "SELECT stop_id from stop_times " + "left join stops " + "on stops.stop_I = stop_times.stop_I", + conn_copy, + ) stop_ids = stop_ids_df["stop_id"].values self.assertNotIn("FUR_CREEK_RES", stop_ids) self.assertIn("AMV", stop_ids) trips_table = G_copy.get_table("trips") - self.assertNotIn("BFC1", trips_table['trip_id'].values) + self.assertNotIn("BFC1", trips_table["trip_id"].values) routes_table = G_copy.get_table("routes") - self.assertNotIn("BFC", routes_table['route_id'].values) + self.assertNotIn("BFC", routes_table["route_id"].values) # cases: # whole trip excluded # whole route excluded @@ -159,21 +194,15 @@ def test_filter_spatially(self): # -> stop C preserved def test_filter_spatially_2(self): - n_rows_before = { - "routes": 4, - "stop_times": 14, - "trips": 4, - "stops": 6, - "shapes": 4 - } - n_rows_after_1000 = { # within "soft buffer" in the feed data + n_rows_before = {"routes": 4, "stop_times": 14, "trips": 4, "stops": 6, "shapes": 4} + n_rows_after_1000 = { # within "soft buffer" in the feed data "routes": 1, "stop_times": 2, "trips": 1, "stops": 2, - "shapes": 0 + "shapes": 0, } - n_rows_after_3000 = { # within "hard buffer" in the feed data + n_rows_after_3000 = { # within "hard buffer" in the feed data "routes": len(["t1", "t3", "t4"]), "stop_times": 11, "trips": 4, @@ -183,44 +212,65 @@ def test_filter_spatially_2(self): paris_lat = 48.832781 paris_lon = 2.360734 - SELECT_MIN_MAX_SHAPE_BREAKS_BY_TRIP_I_SQL = \ - "SELECT trips.trip_I, shape_id, min(shape_break) as min_shape_break, max(shape_break) as max_shape_break FROM trips, stop_times WHERE trips.trip_I=stop_times.trip_I GROUP BY trips.trip_I" - trip_min_max_shape_seqs = pandas.read_sql(SELECT_MIN_MAX_SHAPE_BREAKS_BY_TRIP_I_SQL, self.G_filter_test.conn) + # SELECT_MIN_MAX_SHAPE_BREAKS_BY_TRIP_I_SQL = "SELECT trips.trip_I, shape_id, min(shape_break) as min_shape_break, max(shape_break) as max_shape_break FROM trips, stop_times WHERE trips.trip_I=stop_times.trip_I GROUP BY trips.trip_I" + # trip_min_max_shape_seqs = pandas.read_sql( + # SELECT_MIN_MAX_SHAPE_BREAKS_BY_TRIP_I_SQL, self.G_filter_test.conn + # ) for distance_km, n_rows_after in zip([1000, 3000], [n_rows_after_1000, n_rows_after_3000]): try: os.remove(self.fname_copy) except FileNotFoundError: pass - FilterExtract(self.G_filter_test, - self.fname_copy, - buffer_lat=paris_lat, - buffer_lon=paris_lon, - buffer_distance_km=distance_km).create_filtered_copy() + FilterExtract( + self.G_filter_test, + self.fname_copy, + buffer_lat=paris_lat, + buffer_lon=paris_lon, + buffer_distance_km=distance_km, + ).create_filtered_copy() for table_name, n_rows in n_rows_before.items(): - self.assertEqual(len(self.G_filter_test.get_table(table_name)), n_rows, "Row counts before differ in " + table_name + ", distance: " + str(distance_km)) + self.assertEqual( + len(self.G_filter_test.get_table(table_name)), + n_rows, + "Row counts before differ in " + table_name + ", distance: " + str(distance_km), + ) G_copy = GTFS(self.fname_copy) for table_name, n_rows in n_rows_after.items(): table = G_copy.get_table(table_name) - self.assertEqual(len(table), n_rows, "Row counts after differ in " + table_name + ", distance: " + str(distance_km) + "\n" + str(table)) + self.assertEqual( + len(table), + n_rows, + "Row counts after differ in " + + table_name + + ", distance: " + + str(distance_km) + + "\n" + + str(table), + ) # assert that stop_times are resequenced starting from one - counts = pandas.read_sql("SELECT count(*) FROM stop_times GROUP BY trip_I ORDER BY trip_I", G_copy.conn) - max_values = pandas.read_sql("SELECT max(seq) FROM stop_times GROUP BY trip_I ORDER BY trip_I", G_copy.conn) + counts = pandas.read_sql( + "SELECT count(*) FROM stop_times GROUP BY trip_I ORDER BY trip_I", G_copy.conn + ) + max_values = pandas.read_sql( + "SELECT max(seq) FROM stop_times GROUP BY trip_I ORDER BY trip_I", G_copy.conn + ) self.assertTrue((counts.values == max_values.values).all()) def test_remove_all_trips_fully_outside_buffer(self): stops = self.G.stops() - stop_1 = stops[stops['stop_I'] == 1] + stop_1 = stops[stops["stop_I"] == 1] n_trips_before = len(self.G.get_table("trips")) - remove_all_trips_fully_outside_buffer(self.G.conn, float(stop_1.lat), float(stop_1.lon), 100000) + remove_all_trips_fully_outside_buffer( + self.G.conn, float(stop_1.lat), float(stop_1.lon), 100000 + ) self.assertEqual(len(self.G.get_table("trips")), n_trips_before) # 0.002 (=max 2 meters from the stop), rounding errors can take place... - remove_all_trips_fully_outside_buffer(self.G.conn, float(stop_1.lat), float(stop_1.lon), 0.002) + remove_all_trips_fully_outside_buffer( + self.G.conn, float(stop_1.lat), float(stop_1.lon), 0.002 + ) self.assertEqual(len(self.G.get_table("trips")), 2) # value "2" comes from the data - - - diff --git a/gtfspy/test/test_geometry.py b/gtfspy/test/test_geometry.py index 0498db9..1fca652 100644 --- a/gtfspy/test/test_geometry.py +++ b/gtfspy/test/test_geometry.py @@ -3,10 +3,15 @@ import numpy as np from gtfspy.gtfs import GTFS -from gtfspy.geometry import get_convex_hull_coordinates, get_approximate_convex_hull_area_km2, approximate_convex_hull_area, compute_buffered_area_of_stops +from gtfspy.geometry import ( + get_convex_hull_coordinates, + get_approximate_convex_hull_area_km2, + approximate_convex_hull_area, + compute_buffered_area_of_stops, +) -class GeometryTest(unittest.TestCase): +class GeometryTest(unittest.TestCase): @classmethod def setUpClass(cls): """ This method is run once before executing any tests""" @@ -38,52 +43,42 @@ def test_approximate_convex_hull_area(self): main_railway_station_coords = 60.171545, 24.940734 # lat, lon lats, lons = list(zip(leppavaara_coords, pasila_coords, main_railway_station_coords)) - approximate_reference = 9.91 # computed using https://asiointi.maanmittauslaitos.fi/karttapaikka/ - computed = approximate_convex_hull_area(lons, lats) - self.assertTrue(approximate_reference * 0.9 < computed < approximate_reference * 1.1) - - def test_approximate_convex_hull_area(self): - # helsinki railway station, Helsinki - # leppavaara station, Helsinki - # pasila railway station, Helsinki - leppavaara_coords = 60.219163, 24.813390 - pasila_coords = 60.199136, 24.934090 - main_railway_station_coords = 60.171545, 24.940734 - # lat, lon - lats, lons = list(zip(leppavaara_coords, pasila_coords, main_railway_station_coords)) - approximate_reference = 9.91 # computed using https://asiointi.maanmittauslaitos.fi/karttapaikka/ + approximate_reference = ( + 9.91 # computed using https://asiointi.maanmittauslaitos.fi/karttapaikka/ + ) computed = approximate_convex_hull_area(lons, lats) - self.assertTrue(approximate_reference * 0.9 < computed < approximate_reference * 1.1) + self.assertTrue(approximate_reference * 0.9 < computed < approximate_reference * 1.1) def test_get_buffered_area_of_stops(self): # stop1 is far from stop2, theres no overlap # stop1 and stop3 are close and could have overlap # The area has an accuracy between 95%-99% of the real value. - stop1_coords = 61.129094,24.027896 - stop2_coords = 61.747408,23.924279 - stop3_coords = 61.129621,24.027363 - #lat, lon + stop1_coords = 61.129094, 24.027896 + stop2_coords = 61.747408, 23.924279 + stop3_coords = 61.129621, 24.027363 + # lat, lon lats_1, lons_1 = list(zip(stop1_coords)) lats_1_2, lons_1_2 = list(zip(stop1_coords, stop2_coords)) lats_1_3, lons_1_3 = list(zip(stop1_coords, stop3_coords)) - - #One point buffer - buffer_onepoint = 100 #100 meters of radius - true_area = 10000 * np.pi #area = pi * square radius - area_1 = compute_buffered_area_of_stops(lats_1, lons_1, buffer_onepoint) + + # One point buffer + buffer_onepoint = 100 # 100 meters of radius + true_area = 10000 * np.pi # area = pi * square radius + area_1 = compute_buffered_area_of_stops(lats_1, lons_1, buffer_onepoint) confidence = true_area * 0.95 self.assertTrue(confidence < area_1 < true_area) - - + # Two points buffer non-overlap # Note: the points are "far away" to avoid overlap, but since they are points in the same city # a "really big buffer" could cause overlap and the test is going fail. - buffer_nonoverlap = 100 #100 meters of radius - two_points_nonoverlap_true_area = 2 * buffer_nonoverlap ** 2 * np.pi #area = pi * square radius - area_1_2 = compute_buffered_area_of_stops(lats_1_2, lons_1_2, buffer_nonoverlap) + buffer_nonoverlap = 100 # 100 meters of radius + two_points_nonoverlap_true_area = ( + 2 * buffer_nonoverlap ** 2 * np.pi + ) # area = pi * square radius + area_1_2 = compute_buffered_area_of_stops(lats_1_2, lons_1_2, buffer_nonoverlap) confidence_2 = two_points_nonoverlap_true_area * 0.95 - self.assertTrue(confidence_2 < area_1_2 and area_1_2 < two_points_nonoverlap_true_area) - + self.assertTrue(confidence_2 < area_1_2 and area_1_2 < two_points_nonoverlap_true_area) + # Two points buffer with overlap # Points so close that will overlap with a radius of 100 meters buffer_overlap = 100 # 100 meters of radius @@ -93,20 +88,21 @@ def test_get_buffered_area_of_stops(self): # 'Half-overlap' from gtfspy.util import wgs84_distance + lat1, lat3 = lats_1_3 lon1, lon3 = lons_1_3 distance = wgs84_distance(lat1, lon1, lat3, lon3) # just a little overlap - buffer = distance / 2. + 1 + buffer = distance / 2.0 + 1 area_1_3b = compute_buffered_area_of_stops(lats_1_3, lons_1_3, buffer, resolution=100) - one_point_true_area = np.pi * buffer**2 + one_point_true_area = np.pi * buffer ** 2 self.assertLess(one_point_true_area * 1.5, area_1_3b) self.assertLess(area_1_3b, 2 * one_point_true_area) # no overlap - buffer = distance / 2. - 1 + buffer = distance / 2.0 - 1 area_1_3b = compute_buffered_area_of_stops(lats_1_3, lons_1_3, buffer, resolution=100) two_points_nonoverlap_true_area = 2 * buffer ** 2 * np.pi self.assertGreater(area_1_3b, two_points_nonoverlap_true_area * 0.95) - self.assertLess(area_1_3b , two_points_nonoverlap_true_area) + self.assertLess(area_1_3b, two_points_nonoverlap_true_area) diff --git a/gtfspy/test/test_gtfs.py b/gtfspy/test/test_gtfs.py index f56e53d..b395452 100644 --- a/gtfspy/test/test_gtfs.py +++ b/gtfspy/test/test_gtfs.py @@ -52,9 +52,12 @@ def test_get_day_start_ut(self): self.assertEquals(day_start_ut_should_be, day_start_ut_is) def test_get_main_database_path(self): - self.assertEqual(self.gtfs.get_main_database_path(), "", "path of an in-memory database should equal ''") + self.assertEqual( + self.gtfs.get_main_database_path(), "", "path of an in-memory database should equal ''" + ) from gtfspy.import_gtfs import import_gtfs + try: fname = self.gtfs_source_dir + "/test_gtfs.sqlite" if os.path.exists(fname) and os.path.isfile(fname): @@ -63,37 +66,37 @@ def test_get_main_database_path(self): import_gtfs(self.gtfs_source_dir, conn, preserve_connection=True, print_progress=False) G = GTFS(conn) self.assertTrue(os.path.exists(G.get_main_database_path())) - self.assertIn(u"/test_gtfs.sqlite", G.get_main_database_path(), "path should be correct") + self.assertIn("/test_gtfs.sqlite", G.get_main_database_path(), "path should be correct") finally: if os.path.exists(fname) and os.path.isfile(fname): os.remove(fname) def test_get_table(self): - df = self.gtfs.get_table(u"agencies") + df = self.gtfs.get_table("agencies") self.assertTrue(isinstance(df, pandas.DataFrame)) def test_get_table_names(self): tables = self.gtfs.get_table_names() self.assertTrue(isinstance(tables, list)) - self.assertGreater(len(tables), 11, u"quite many tables should be available") - self.assertIn(u"routes", tables) + self.assertGreater(len(tables), 11, "quite many tables should be available") + self.assertIn("routes", tables) def test_get_all_route_shapes(self): res = self.gtfs.get_all_route_shapes() self.assertTrue(isinstance(res, list)) el = res[0] - keys = u"name type agency lats lons".split() + keys = "name type agency lats lons".split() for key in keys: self.assertTrue(key in el) for el in res: - self.assertTrue(isinstance(el[u"name"], string_types), type(el[u"name"])) - self.assertTrue(isinstance(el[u"type"], (int, numpy.int_)), type(el[u'type'])) - self.assertTrue(isinstance(el[u"agency"], string_types)) - self.assertTrue(isinstance(el[u"lats"], list), type(el[u'lats'])) - self.assertTrue(isinstance(el[u"lons"], list)) - self.assertTrue(isinstance(el[u'lats'][0], float)) - self.assertTrue(isinstance(el[u'lons'][0], float)) + self.assertTrue(isinstance(el["name"], string_types), type(el["name"])) + self.assertTrue(isinstance(el["type"], (int, numpy.int_)), type(el["type"])) + self.assertTrue(isinstance(el["agency"], string_types)) + self.assertTrue(isinstance(el["lats"], list), type(el["lats"])) + self.assertTrue(isinstance(el["lons"], list)) + self.assertTrue(isinstance(el["lats"][0], float)) + self.assertTrue(isinstance(el["lons"][0], float)) def test_get_shape_distance_between_stops(self): # tested as a part of test_to_directed_graph, although this could be made a separate test as well @@ -115,8 +118,8 @@ def test_get_timezone_string(self): self.assertIn(tz_string[0], "+-") for i in range(1, 5): self.assertIn(tz_string[i], "0123456789") - dt = datetime.datetime(1970, 1, 1) - tz_string_epoch = self.gtfs.get_timezone_string(dt) + # dt = datetime.datetime(1970, 1, 1) + # tz_string_epoch = self.gtfs.get_timezone_string(dt) # self.assertEqual(tz_string, tz_string_epoch) def test_timezone_conversions(self): @@ -155,7 +158,7 @@ def test_get_stop_count_data(self): self.assertTrue(isinstance(el, float)) if c in ["name"]: self.assertTrue(isinstance(el, string_types), type(el)) - self.assertTrue((df['count'].values > 0).any()) + self.assertTrue((df["count"].values > 0).any()) def test_get_segment_count_data(self): dt_start_query = datetime.datetime(2007, 1, 1, 7, 59, 59) @@ -181,8 +184,10 @@ def test_get_tripIs_active_in_range(self): df = self.gtfs.get_tripIs_active_in_range(start_query, end_query) self.assertGreater(len(df), 0) for row in df.itertuples(): - self.assertTrue((row.start_time_ut <= end_query) and \ - (row.end_time_ut >= start_query), "some trip does not overlap!") + self.assertTrue( + (row.start_time_ut <= end_query) and (row.end_time_ut >= start_query), + "some trip does not overlap!", + ) if row.start_time_ut == start_real and row.end_time_ut == end_real: found = True # check that overlaps @@ -214,11 +219,11 @@ def test_get_closest_stop(self): # First row in test_data: # FUR_CREEK_RES, Furnace Creek Resort (Demo),, 36.425288, -117.133162,, lat_s, lon_s = 36.425288, -117.133162 - lat, lon = lat_s + 10**-5, lon_s + 10**-5 + lat, lon = lat_s + 10 ** -5, lon_s + 10 ** -5 stop_I = self.gtfs.get_closest_stop(lat, lon) self.assertTrue(isinstance(stop_I, int)) df = self.gtfs.stop(stop_I) - name = df['name'][0] + name = df["name"][0] # print name # check that correct stop has been found: self.assertTrue(name == "Furnace Creek Resort (Demo)") @@ -238,23 +243,23 @@ def test_get_trip_stop_time_data(self): dsut, trip_Is = list(dsut_dict.items())[0] df = self.gtfs.get_trip_stop_time_data(trip_Is[0], dsut) self.assertTrue(isinstance(df, pandas.DataFrame)) - columns = u"dep_time_ut lat lon seq shape_break".split(" ") + columns = "dep_time_ut lat lon seq shape_break".split(" ") el = df.iloc[0] for c in columns: self.assertTrue(c in df.columns) - if c in u"dep_time_ut lat lon".split(" "): + if c in "dep_time_ut lat lon".split(" "): self.assertTrue(isinstance(el[c], float)) - if c in u"seq".split(" "): + if c in "seq".split(" "): self.assertTrue(isinstance(el[c], (int, numpy.int_)), type(el[c])) def test_get_straight_line_transfer_distances(self): data = self.gtfs.get_straight_line_transfer_distances() a_stop_I = None for index, row in data.iterrows(): - self.assertTrue(row[u'from_stop_I'] is not None) - a_stop_I = row[u'from_stop_I'] - self.assertTrue(row[u'to_stop_I'] is not None) - self.assertTrue(row[u'd'] is not None) + self.assertTrue(row["from_stop_I"] is not None) + a_stop_I = row["from_stop_I"] + self.assertTrue(row["to_stop_I"] is not None) + self.assertTrue(row["d"] is not None) data = self.gtfs.get_straight_line_transfer_distances(a_stop_I) self.assertGreater(len(data), 0) @@ -291,11 +296,14 @@ def test_homogenize_stops_table_with_other_db(self): def test_get_weekly_extract_start_date(self): trip_counts_per_day = self.G.get_trip_counts_per_day() - first_day = trip_counts_per_day['date'].min() - last_day = trip_counts_per_day['date'].max() # a monday not in reach + first_day = trip_counts_per_day["date"].min() + last_day = trip_counts_per_day["date"].max() # a monday not in reach # print(first_day, last_day) first_monday = self.G.get_weekly_extract_start_date(download_date_override=first_day) early_monday = self.G.get_weekly_extract_start_date( - download_date_override=first_day + datetime.timedelta(days=10)) - end_monday = self.G.get_weekly_extract_start_date(download_date_override=last_day - datetime.timedelta(days=5)) + download_date_override=first_day + datetime.timedelta(days=10) + ) + end_monday = self.G.get_weekly_extract_start_date( + download_date_override=last_day - datetime.timedelta(days=5) + ) assert first_monday < early_monday < end_monday diff --git a/gtfspy/test/test_import_gtfs.py b/gtfspy/test/test_import_gtfs.py index b7aeb06..5161a5a 100644 --- a/gtfspy/test/test_import_gtfs.py +++ b/gtfspy/test/test_import_gtfs.py @@ -10,7 +10,6 @@ # noinspection PyTypeChecker class TestImport(unittest.TestCase): - @classmethod def setup_class(cls): """This method is run once for each class before any tests are run""" @@ -27,84 +26,95 @@ def tearDown(self): def setUp(self): """This method is run once before _each_ test method is executed""" - self.conn = sqlite3.connect(':memory:') - self.agencyText = \ - 'agency_id, agency_name, agency_timezone, agency_url' \ - '\n ag1, CompNet, Europe/Zurich, www.example.com' - self.stopsText = \ - 'stop_id, stop_name, stop_lat, stop_lon, parent_station' \ - '\nSID1, "Parent-Stop-Name", 1.0, 2.0, ' \ - '\nSID2, Boring Stop Name, 1.1, 2.1, SID1' \ - '\n3, Boring Stop Name1, 1.2, 2.2, ' \ - '\n4, Boring Stop Name2, 1.3, 2.3, 3' \ - '\n5, StopCloseToFancyStop, 1.0001, 2.0001, ' \ - '\nT1, t1, 1.0001, 2.2, ' \ - '\nT2, t2, 1.0002, 2.2, ' \ - '\nT3, t3, 1.00015, 2.2, ' \ - '\nT4, t4, 1.0001, 2.2, ' - self.calendarText = \ - 'service_id, monday, tuesday, wednesday, thursday, friday, saturday, sunday,' \ - 'start_date, end_date' \ - '\nservice1, 1, 1, 1, 1, 1, 1, 1, 20160321, 20160327' \ - '\nservice2, 0, 0, 0, 0, 0, 0, 0, 20160321, 20160327' \ - '\nfreq_service, 1, 1, 1, 1, 1, 1, 1, 20160329, 20160329' - self.calendarDatesText = \ - 'service_id, date, exception_type' \ - '\nservice1, 20160322, 2' \ - '\nextra_service, 20160321, 1' \ - '\nservice2, 20160322, 1' \ - '\nphantom_service, 20160320, 2' + self.conn = sqlite3.connect(":memory:") + self.agencyText = ( + "agency_id, agency_name, agency_timezone, agency_url" + "\n ag1, CompNet, Europe/Zurich, www.example.com" + ) + self.stopsText = ( + "stop_id, stop_name, stop_lat, stop_lon, parent_station" + '\nSID1, "Parent-Stop-Name", 1.0, 2.0, ' + "\nSID2, Boring Stop Name, 1.1, 2.1, SID1" + "\n3, Boring Stop Name1, 1.2, 2.2, " + "\n4, Boring Stop Name2, 1.3, 2.3, 3" + "\n5, StopCloseToFancyStop, 1.0001, 2.0001, " + "\nT1, t1, 1.0001, 2.2, " + "\nT2, t2, 1.0002, 2.2, " + "\nT3, t3, 1.00015, 2.2, " + "\nT4, t4, 1.0001, 2.2, " + ) + self.calendarText = ( + "service_id, monday, tuesday, wednesday, thursday, friday, saturday, sunday," + "start_date, end_date" + "\nservice1, 1, 1, 1, 1, 1, 1, 1, 20160321, 20160327" + "\nservice2, 0, 0, 0, 0, 0, 0, 0, 20160321, 20160327" + "\nfreq_service, 1, 1, 1, 1, 1, 1, 1, 20160329, 20160329" + ) + self.calendarDatesText = ( + "service_id, date, exception_type" + "\nservice1, 20160322, 2" + "\nextra_service, 20160321, 1" + "\nservice2, 20160322, 1" + "\nphantom_service, 20160320, 2" + ) # 1 -> service added # 2 -> service removed # note some same service IDs as in self.calendarText - self.tripText = \ - "route_id, service_id, trip_id, trip_headsign, trip_short_name, shape_id" \ - "\nservice1_route, service1, service1_trip1, going north, trip_s1t1, shape_s1t1" \ - "\nservice2_route, service2, service2_trip1, going north, trip_s2t1, shape_s2t1" \ - "\nes_route, extra_service, es_trip1, going north, trip_es1, shape_es1" \ + self.tripText = ( + "route_id, service_id, trip_id, trip_headsign, trip_short_name, shape_id" + "\nservice1_route, service1, service1_trip1, going north, trip_s1t1, shape_s1t1" + "\nservice2_route, service2, service2_trip1, going north, trip_s2t1, shape_s2t1" + "\nes_route, extra_service, es_trip1, going north, trip_es1, shape_es1" "\nfrequency_route, freq_service, freq_trip_scheduled, going north, freq_name, shape_es1" - self.routesText = \ - "route_id, agency_id, route_short_name, route_long_name, route_type" \ - "\nservice1_route, ag1, r1, route1, 0" \ - "\nservice2_route, ag1, r2, route2, 1" \ + ) + self.routesText = ( + "route_id, agency_id, route_short_name, route_long_name, route_type" + "\nservice1_route, ag1, r1, route1, 0" + "\nservice2_route, ag1, r2, route2, 1" "\nfrequency_route, ag1, freq_route, schedule frequency route, 2" - self.shapeText = \ - "shape_id, shape_pt_lat, shape_pt_lon, shape_pt_sequence" \ - "\n shape_s1t1,1.0,2.0,0" \ - "\n shape_s1t1,1.001,2.0,1" \ - "\n shape_s1t1,1.001,2.001,10" \ + ) + self.shapeText = ( + "shape_id, shape_pt_lat, shape_pt_lon, shape_pt_sequence" + "\n shape_s1t1,1.0,2.0,0" + "\n shape_s1t1,1.001,2.0,1" + "\n shape_s1t1,1.001,2.001,10" "\n shape_s1t1,1.10001,2.10001,100" - self.stopTimesText = \ - "trip_id, arrival_time, departure_time, stop_sequence, stop_id" \ - "\nservice1_trip1,0:06:10,0:06:10,0,SID1" \ - "\nservice1_trip1,0:06:15,0:06:16,1,SID2" \ - "\nfreq_trip_scheduled,0:00:00,0:00:00,1,SID1" \ + ) + self.stopTimesText = ( + "trip_id, arrival_time, departure_time, stop_sequence, stop_id" + "\nservice1_trip1,0:06:10,0:06:10,0,SID1" + "\nservice1_trip1,0:06:15,0:06:16,1,SID2" + "\nfreq_trip_scheduled,0:00:00,0:00:00,1,SID1" "\nfreq_trip_scheduled,0:02:00,0:02:00,1,SID2" - self.frequenciesText = \ - "trip_id, start_time, end_time, headway_secs, exact_times" \ + ) + self.frequenciesText = ( + "trip_id, start_time, end_time, headway_secs, exact_times" "\nfreq_trip_scheduled, 14:00:00, 16:00:00, 600, 1" - self.transfersText = \ - "from_stop_id, to_stop_id, transfer_type, min_transfer_time" \ - "\nT1, T2, 0, " \ - "\nT2, T3, 1, " \ - "\nT3, T1, 2, 120" \ + ) + self.transfersText = ( + "from_stop_id, to_stop_id, transfer_type, min_transfer_time" + "\nT1, T2, 0, " + "\nT2, T3, 1, " + "\nT3, T1, 2, 120" "\nT1, T4, 3, " - self.feedInfoText = \ - "feed_publisher_name, feed_publisher_url, feed_lang, feed_start_date, feed_end_date, feed_version" \ + ) + self.feedInfoText = ( + "feed_publisher_name, feed_publisher_url, feed_lang, feed_start_date, feed_end_date, feed_version" "\nThePublisher, www.example.com, en, 20160321, 20160327, 1.0" + ) self.fdict = { - 'agency.txt': self.agencyText, - 'stops.txt': self.stopsText, - 'calendar.txt': self.calendarText, - 'calendar_dates.txt': self.calendarDatesText, - 'trips.txt': self.tripText, - 'routes.txt': self.routesText, - 'shapes.txt': self.shapeText, - 'stop_times.txt': self.stopTimesText, - 'frequencies.txt': self.frequenciesText, - 'transfers.txt': self.transfersText, - 'feed_info.txt': self.feedInfoText + "agency.txt": self.agencyText, + "stops.txt": self.stopsText, + "calendar.txt": self.calendarText, + "calendar_dates.txt": self.calendarDatesText, + "trips.txt": self.tripText, + "routes.txt": self.routesText, + "shapes.txt": self.shapeText, + "stop_times.txt": self.stopTimesText, + "frequencies.txt": self.frequenciesText, + "transfers.txt": self.transfersText, + "feed_info.txt": self.feedInfoText, } self.orig_row_factory = self.conn.row_factory @@ -141,13 +151,13 @@ def printTable(self, table_name): cur = self.conn.execute("SELECT * FROM %s" % table_name) names = [d[0] for d in cur.description] for name in names: - print(name + ', ', end="") + print(name + ", ", end="") print("") for row in cur: print(row) self.conn.row_factory = prev_row_factory - def tearDown(self): + def tearDown(self): # type: ignore """This method is run once after _each_ test method is executed""" pass @@ -162,24 +172,24 @@ def test_stopLoader(self): # sqlite returns now list of dicts rows = self.conn.execute("SELECT * FROM stops").fetchall() assert len(rows) > 4 # some data should be imported - assert rows[0]['stop_I'] == 1 + assert rows[0]["stop_I"] == 1 # Store quotes in names: parent_index = None for i, row in enumerate(rows): - if row['name'] == '"Parent-Stop-Name"': + if row["name"] == '"Parent-Stop-Name"': parent_index = i break assert parent_index is not None - parent_stop_I = rows[parent_index]['stop_I'] + parent_stop_I = rows[parent_index]["stop_I"] boring_index = None for i, row in enumerate(rows): - if row['name'] == "Boring Stop Name": + if row["name"] == "Boring Stop Name": boring_index = i break assert boring_index is not None - assert rows[boring_index]['parent_I'] == parent_stop_I - assert rows[boring_index]['self_or_parent_I'] == parent_stop_I - assert rows[3]['self_or_parent_I'] == 3 + assert rows[boring_index]["parent_I"] == parent_stop_I + assert rows[boring_index]["self_or_parent_I"] == parent_stop_I + assert rows[3]["self_or_parent_I"] == 3 def test_agencyLoader(self): import_gtfs(self.fdict, self.conn, preserve_connection=True) @@ -187,12 +197,13 @@ def test_agencyLoader(self): cursor = self.conn.cursor() rows = cursor.execute("SELECT agency_id FROM agencies").fetchall() assert len(rows) == 1 - assert rows[0][0] == u'ag1', rows[0][0] + assert rows[0][0] == "ag1", rows[0][0] def test_agencyLoaderTwoTimeZonesFail(self): - newagencytext = \ + newagencytext = ( self.agencyText + "\n123, AgencyFromDifferentTZ, Europe/Helsinki, www.buahaha.com" - self.fdict['agency.txt'] = newagencytext + ) + self.fdict["agency.txt"] = newagencytext with self.assertRaises(ValueError): import_gtfs(self.fdict, self.conn, preserve_connection=True) @@ -205,8 +216,8 @@ def test_calendarLoader(self): self.setDictConn() rows = self.conn.execute("SELECT * FROM calendar").fetchall() assert len(rows[0]) == 11 - for key in 'm t w th f s su start_date end_date service_id service_I'.split(): - assert key in rows[0], 'no key ' + key + for key in "m t w th f s su start_date end_date service_id service_I".split(): + assert key in rows[0], "no key " + key def test_calendarDatesLoader(self): import_gtfs(self.fdict, self.conn, preserve_connection=True) @@ -214,11 +225,11 @@ def test_calendarDatesLoader(self): self.setDictConn() rows = self.conn.execute("SELECT * FROM calendar_dates").fetchall() for row in rows: - assert isinstance(row['service_I'], int) + assert isinstance(row["service_I"], int) # calendar table should be increased by two dummy row - rows = self.conn.execute("SELECT * " - "FROM calendar " - "WHERE service_id='phantom_service'").fetchall() + rows = self.conn.execute( + "SELECT * " "FROM calendar " "WHERE service_id='phantom_service'" + ).fetchall() # Whether this should be the case is negotiable, though self.assertEqual(len(rows), 1, "phantom service should be present in the calendar") @@ -232,63 +243,84 @@ def test_dayLoader(self): # Now, there should be # a regular trip according to calendar dates without any exceptions: self.setDictConn() - query1 = "SELECT trip_I " \ - "FROM days " \ - "JOIN trips " \ - "USING(trip_I) " \ - "JOIN calendar " \ - "USING(service_I) " \ - "WHERE date='2016-03-21'" \ - "AND service_id='service1'" + query1 = ( + "SELECT trip_I " + "FROM days " + "JOIN trips " + "USING(trip_I) " + "JOIN calendar " + "USING(service_I) " + "WHERE date='2016-03-21'" + "AND service_id='service1'" + ) res = self.conn.execute(query1).fetchall() assert len(res) == 1 - trip_I_service_1 = res[0]['trip_I'] + trip_I_service_1 = res[0]["trip_I"] print(trip_I_service_1) query2 = "SELECT * FROM days WHERE trip_I=%s" % trip_I_service_1 - self.assertEqual(len(self.conn.execute(query2).fetchall()), 6, - "There should be 6 days with the trip_I " - "corresponding to service_id service1") - query3 = "SELECT * " \ - "FROM days " \ - "JOIN trips " \ - "USING(trip_I) " \ - "JOIN calendar " \ - "USING(service_I) " \ - "WHERE date='2016-03-22'" \ - "AND service_id='service1'" - self.assertEqual(len(self.conn.execute(query3).fetchall()), 0, - "There should be no trip on date 2016-03-22" - "for service1 due to calendar_dates") - query4 = "SELECT date " \ - "FROM days " \ - "JOIN trips " \ - "USING(trip_I) " \ - "JOIN calendar " \ - "USING(service_I) " \ - "WHERE service_id='service2'" - self.assertEqual(len(self.conn.execute(query4).fetchall()), 1, - "There should be only one trip for service 2") - self.assertEqual(self.conn.execute(query4).fetchone()['date'], "2016-03-22", - "and the date should be 2016-03-22") - query6 = "SELECT * " \ - "FROM days " \ - "JOIN trips " \ - "USING(trip_I) " \ - "JOIN calendar " \ - "USING(service_I) " \ - "WHERE service_id='phantom_service'" + self.assertEqual( + len(self.conn.execute(query2).fetchall()), + 6, + "There should be 6 days with the trip_I " "corresponding to service_id service1", + ) + query3 = ( + "SELECT * " + "FROM days " + "JOIN trips " + "USING(trip_I) " + "JOIN calendar " + "USING(service_I) " + "WHERE date='2016-03-22'" + "AND service_id='service1'" + ) + self.assertEqual( + len(self.conn.execute(query3).fetchall()), + 0, + "There should be no trip on date 2016-03-22" "for service1 due to calendar_dates", + ) + query4 = ( + "SELECT date " + "FROM days " + "JOIN trips " + "USING(trip_I) " + "JOIN calendar " + "USING(service_I) " + "WHERE service_id='service2'" + ) + self.assertEqual( + len(self.conn.execute(query4).fetchall()), + 1, + "There should be only one trip for service 2", + ) + self.assertEqual( + self.conn.execute(query4).fetchone()["date"], + "2016-03-22", + "and the date should be 2016-03-22", + ) + query6 = ( + "SELECT * " + "FROM days " + "JOIN trips " + "USING(trip_I) " + "JOIN calendar " + "USING(service_I) " + "WHERE service_id='phantom_service'" + ) res = self.conn.execute(query6).fetchall() - self.assertEqual(len(res), 0, "there should be no phantom trips due to phantom service" - "even though phantom service is in calendar" - ) + self.assertEqual( + len(res), + 0, + "there should be no phantom trips due to phantom service" + "even though phantom service is in calendar", + ) def test_shapeLoader(self): import_gtfs(self.fdict, self.conn, preserve_connection=True) self.setDictConn() - keys = ['shape_id', 'lat', 'lon', 'seq', 'd'] + keys = ["shape_id", "lat", "lon", "seq", "d"] table = self.conn.execute("SELECT * FROM shapes").fetchall() - assert table[1]['d'] > 0, "distance traveled should be > 0" + assert table[1]["d"] > 0, "distance traveled should be > 0" for key in keys: assert key in table[0], "key " + key + " not in shapes table" @@ -296,13 +328,21 @@ def test_stopTimesLoader(self): import_gtfs(self.fdict, self.conn, preserve_connection=True) self.setDictConn() stoptimes = self.conn.execute("SELECT * FROM stop_times").fetchall() - keys = ['stop_I', 'shape_break', 'trip_I', 'arr_time', - 'dep_time', 'seq', 'arr_time_ds', 'dep_time_ds'] + keys = [ + "stop_I", + "shape_break", + "trip_I", + "arr_time", + "dep_time", + "seq", + "arr_time_ds", + "dep_time_ds", + ] for key in keys: assert key in stoptimes[0] - assert stoptimes[0]['dep_time_ds'] == 370 - assert stoptimes[0]['shape_break'] == 0 - assert stoptimes[1]['shape_break'] == 3 + assert stoptimes[0]["dep_time_ds"] == 370 + assert stoptimes[0]["shape_break"] == 0 + assert stoptimes[1]["shape_break"] == 3 def test_stopDistancesLoader(self): import_gtfs(self.fdict, self.conn, preserve_connection=True) @@ -313,7 +353,9 @@ def test_stopDistancesLoader(self): assert len(rows) > 0 for row in rows: print(row) - assert row['d'] >= 0, "distance should be defined for all pairs in the stop_distances table" + assert ( + row["d"] >= 0 + ), "distance should be defined for all pairs in the stop_distances table" def test_metaDataLoader(self): import_gtfs(self.fdict, self.conn, preserve_connection=True) @@ -326,7 +368,15 @@ def test_metaDataLoader(self): def test_frequencyLoader(self): import_gtfs(self.fdict, self.conn, preserve_connection=True) # "\nfrequency_route, freq_service, freq_trip, going north, freq_name, shape_es1" \ - keys = ["trip_I", "start_time", "end_time", "headway_secs", "exact_times", "start_time_ds", "end_time_ds"] + keys = [ + "trip_I", + "start_time", + "end_time", + "headway_secs", + "exact_times", + "start_time_ds", + "end_time_ds", + ] self.setDictConn() rows = self.conn.execute("SELECT * FROM frequencies").fetchall() for key in keys: @@ -336,15 +386,21 @@ def test_frequencyLoader(self): if row["start_time_ds"] == 14 * 3600: self.assertEqual(row["exact_times"], 1) # there should be twelve trips with service_I freq - count = self.conn.execute("SELECT count(*) AS count FROM trips JOIN calendar " - "USING(service_I) WHERE service_id='freq_service'").fetchone()['count'] + count = self.conn.execute( + "SELECT count(*) AS count FROM trips JOIN calendar " + "USING(service_I) WHERE service_id='freq_service'" + ).fetchone()["count"] assert count == 12, count - rows = self.conn.execute("SELECT trip_I FROM trips JOIN calendar " - "USING(service_I) WHERE service_id='freq_service'").fetchall() + rows = self.conn.execute( + "SELECT trip_I FROM trips JOIN calendar " + "USING(service_I) WHERE service_id='freq_service'" + ).fetchall() for row in rows: - trip_I = row['trip_I'] - res = self.conn.execute("SELECT * FROM stop_times WHERE trip_I={trip_I}".format(trip_I=trip_I)).fetchall() + trip_I = row["trip_I"] + res = self.conn.execute( + "SELECT * FROM stop_times WHERE trip_I={trip_I}".format(trip_I=trip_I) + ).fetchall() assert len(res) > 1, res self.setRowConn() g = GTFS(self.conn) @@ -376,8 +432,8 @@ def test_transfersLoader(self): for transfer in transfers: transfer_type = transfer["transfer_type"] - from_stop_I = transfer['from_stop_I'] - to_stop_I = transfer['to_stop_I'] + from_stop_I = transfer["from_stop_I"] + to_stop_I = transfer["to_stop_I"] min_transfer_time = transfer["min_transfer_time"] assert isinstance(from_stop_I, int) assert isinstance(to_stop_I, int) @@ -396,20 +452,30 @@ def test_transfersLoader(self): base_query = "SELECT * FROM stop_distances WHERE from_stop_I=? and to_stop_I=?" # no_transfer - no_transfer_rows = self.conn.execute(base_query, (from_stop_I_no_transfer, to_stop_I_no_transfer)).fetchall() + no_transfer_rows = self.conn.execute( + base_query, (from_stop_I_no_transfer, to_stop_I_no_transfer) + ).fetchall() assert len(no_transfer_rows) == 0 - timed_transfer_rows = \ - self.conn.execute(base_query, (from_stop_I_timed_transfer, to_stop_I_timed_transfer)).fetchall() + timed_transfer_rows = self.conn.execute( + base_query, (from_stop_I_timed_transfer, to_stop_I_timed_transfer) + ).fetchall() assert len(timed_transfer_rows) == 1 - assert timed_transfer_rows[0]['min_transfer_time'] == 0 - min_transfer_rows = \ - self.conn.execute(base_query, (from_stop_I_min_transfer, to_stop_I_min_transfer)).fetchall() + assert timed_transfer_rows[0]["min_transfer_time"] == 0 + min_transfer_rows = self.conn.execute( + base_query, (from_stop_I_min_transfer, to_stop_I_min_transfer) + ).fetchall() assert len(min_transfer_rows) == 1 - assert min_transfer_rows[0]['min_transfer_time'] == min_transfer_time_min_transfer + assert min_transfer_rows[0]["min_transfer_time"] == min_transfer_time_min_transfer def test_feedInfoLoader(self): import_gtfs(self.fdict, self.conn, preserve_connection=True) - keys = ["feed_publisher_name", "feed_publisher_url", "feed_lang", "feed_start_date", "feed_end_date"] + keys = [ + "feed_publisher_name", + "feed_publisher_url", + "feed_lang", + "feed_start_date", + "feed_end_date", + ] self.setDictConn() rows = self.conn.execute("SELECT * FROM feed_info").fetchall() for key in keys: @@ -435,12 +501,10 @@ def test_importMultiple(self): error_raised = True assert error_raised, "different timezones in multiple feeds should raise an error" - - #mod_agencyText = \ + # mod_agencyText = \ # 'agency_id, agency_name, agency_timezone, agency_url' \ - # '\nag1, CompNet, America/Los_Angeles, www.example.com' - #self.fdict['agency.txt'] = mod_agencyText - + # '\nag1, CompNet, America/Los_Angeles, www.example.com' + # self.fdict['agency.txt'] = mod_agencyText # test that if trip_id:s (or stop_id:s etc. ) are the same in two feeds, # they get different trip_Is in the database created @@ -457,12 +521,12 @@ def test_importMultiple(self): gtfs_sources = [self.fdict, self.fdict] import_gtfs(gtfs_sources, self.conn, preserve_connection=True) n_rows_double = self.conn.execute("SELECT count(*) FROM trips").fetchone()[0] - self.assertEqual(n_rows_double, 2*n_rows_ref) + self.assertEqual(n_rows_double, 2 * n_rows_ref) # check for duplicate trip_I's rows = self.conn.execute("SELECT count(*) FROM trips GROUP BY trip_I").fetchall() for row in rows: - self.assertIs(row[0],1) + self.assertIs(row[0], 1) # check for duplicate service_I's in calendar rows = self.conn.execute("SELECT count(*) FROM calendar GROUP BY service_I").fetchall() @@ -470,7 +534,9 @@ def test_importMultiple(self): self.assertIs(row[0], 1) # check for duplicate service_I's in calendar_dates - rows = self.conn.execute("SELECT count(*) FROM calendar_dates GROUP BY service_I").fetchall() + rows = self.conn.execute( + "SELECT count(*) FROM calendar_dates GROUP BY service_I" + ).fetchall() for row in rows: self.assertIs(row[0], 1) @@ -502,15 +568,15 @@ def test_sources_required_multiple(self): def test_resequencing_stop_times(self): gtfs_source = self.fdict.copy() - gtfs_source.pop('stop_times.txt') - - gtfs_source['stop_times.txt'] = \ - self.stopTimesText = \ - "trip_id, arrival_time, departure_time, stop_sequence, stop_id" \ - "\nservice1_trip1,0:06:10,0:06:10,0,SID1" \ - "\nservice1_trip1,0:06:15,0:06:16,10,SID2" \ - "\nfreq_trip_scheduled,0:00:00,0:00:00,1,SID1" \ + gtfs_source.pop("stop_times.txt") + + gtfs_source["stop_times.txt"] = self.stopTimesText = ( + "trip_id, arrival_time, departure_time, stop_sequence, stop_id" + "\nservice1_trip1,0:06:10,0:06:10,0,SID1" + "\nservice1_trip1,0:06:15,0:06:16,10,SID2" + "\nfreq_trip_scheduled,0:00:00,0:00:00,1,SID1" "\nfreq_trip_scheduled,0:02:00,0:02:00,123,SID2" + ) import_gtfs(gtfs_source, self.conn, preserve_connection=True) rows = self.conn.execute("SELECT seq FROM stop_times ORDER BY trip_I, seq").fetchall() diff --git a/gtfspy/test/test_import_validator.py b/gtfspy/test/test_import_validator.py index 48e1be5..51a4abb 100644 --- a/gtfspy/test/test_import_validator.py +++ b/gtfspy/test/test_import_validator.py @@ -5,14 +5,15 @@ class TestImportValidator(unittest.TestCase): - def setUp(self): # create validator object using textfiles test_feed_dir = os.path.join(os.path.dirname(__file__), "test_data/") test_feed_b_dir = os.path.join(test_feed_dir, "feed_b") self.gtfs_source_dir = os.path.join(os.path.dirname(__file__), "test_data") self.G_txt = GTFS.from_directory_as_inmemory_db([test_feed_dir, test_feed_b_dir]) - self.import_validator = ImportValidator([test_feed_dir, test_feed_b_dir], self.G_txt, verbose=False) + self.import_validator = ImportValidator( + [test_feed_dir, test_feed_b_dir], self.G_txt, verbose=False + ) def test_source_gtfsobj_comparison(self): self.import_validator._validate_table_row_counts() @@ -27,7 +28,3 @@ def test_validate(self): if "stop_distances" in warning: stop_dist_warning_exists = True assert stop_dist_warning_exists - - - - diff --git a/gtfspy/test/test_mapviz.py b/gtfspy/test/test_mapviz.py index 4fc8cae..4c004df 100644 --- a/gtfspy/test/test_mapviz.py +++ b/gtfspy/test/test_mapviz.py @@ -8,7 +8,6 @@ class TestMapviz(unittest.TestCase): - def setUp(self): self.gtfs_source_dir = os.path.join(os.path.dirname(__file__), "test_data/filter_test_feed") self.fname = self.gtfs_source_dir + "/test_gtfs.sqlite" @@ -27,10 +26,10 @@ def tearDown(self): def test_plot_trip_counts_per_day(self): # simple "it compiles" tests: - ax = plot_route_network_from_gtfs(self.G) - ax = plot_route_network_from_gtfs(self.G, map_style="light_all") - ax = plot_route_network_from_gtfs(self.G, map_style="dark_all") - ax = plot_route_network_from_gtfs(self.G, map_style="rastertiles/voyager") + plot_route_network_from_gtfs(self.G) + plot_route_network_from_gtfs(self.G, map_style="light_all") + plot_route_network_from_gtfs(self.G, map_style="dark_all") + plot_route_network_from_gtfs(self.G, map_style="rastertiles/voyager") # for interactive testing # from matplotlib import pyplot as plt # plt.show() diff --git a/gtfspy/test/test_plots.py b/gtfspy/test/test_plots.py index fd01b06..2b294e1 100644 --- a/gtfspy/test/test_plots.py +++ b/gtfspy/test/test_plots.py @@ -28,13 +28,11 @@ def tearDown(self): def test_plot_trip_counts_per_day(self): # simple test - ax = plot_trip_counts_per_day(self.G, - highlight_dates=["2009-01-01"], - highlight_date_labels=["test_date"]) + ax = plot_trip_counts_per_day( + self.G, highlight_dates=["2009-01-01"], highlight_date_labels=["test_date"] + ) # test with multiple dates and datetime dates = [datetime(2009, month=10, day=1), datetime(2010, month=10, day=1)] labels = ["test_date_1", "test_date_2"] - ax = plot_trip_counts_per_day(self.G, - highlight_dates=dates, - highlight_date_labels=labels) + ax = plot_trip_counts_per_day(self.G, highlight_dates=dates, highlight_date_labels=labels) assert isinstance(ax, Axes) diff --git a/gtfspy/test/test_shapes.py b/gtfspy/test/test_shapes.py index c05d875..03a57d8 100644 --- a/gtfspy/test/test_shapes.py +++ b/gtfspy/test/test_shapes.py @@ -8,55 +8,53 @@ from gtfspy import shapes - class ShapesTest(unittest.TestCase): - def test_shape_break_order(self): for trip_I in [ - # These trip IDs require the hsl-2015-07-12 DB. - 73775, # Route 18 Eira -> Munkkivuori - 172258, # Route 94A in Helsinki, direction 0. - 84380, # 36 in Helsinki. Has lots of dead ends. - 83734, # route 1032 - 84044, # route 1034 - 240709, # 143K - 194350, # 802 - 194530, # 802K - 270813, # P20 - 270849, # P21 - ]: + # These trip IDs require the hsl-2015-07-12 DB. + 73775, # Route 18 Eira -> Munkkivuori + 172258, # Route 94A in Helsinki, direction 0. + 84380, # 36 in Helsinki. Has lots of dead ends. + 83734, # route 1032 + 84044, # route 1034 + 240709, # 143K + 194350, # 802 + 194530, # 802K + 270813, # P20 + 270849, # P21 + ]: yield self.test_shape_break_order_1, trip_I pass - #yield test_shape_break_order_1, 83734 + # yield test_shape_break_order_1, 83734 @unittest.skip("skipping test_shape_break_order_1") def test_shape_break_order_1(self, trip_I=73775): """This is to a bug related to shape alignment.""" - conn = GTFS('../scratch/db/hsl-2015-07-12.sqlite').conn + conn = GTFS("../scratch/db/hsl-2015-07-12.sqlite").conn cur = conn.cursor() - cur.execute('''SELECT seq, lat, lon + cur.execute( + """SELECT seq, lat, lon FROM stop_times LEFT JOIN stops USING (stop_I) WHERE trip_I=? - ORDER BY seq''', - (trip_I,)) - #print '%20s, %s'%(run_code, datetime.fromtimestamp(run_sch_starttime)) - stop_points = [ dict(seq=row[0], - lat=row[1], - lon=row[2]) - for row in cur] + ORDER BY seq""", + (trip_I,), + ) + # print '%20s, %s'%(run_code, datetime.fromtimestamp(run_sch_starttime)) + stop_points = [dict(seq=row[0], lat=row[1], lon=row[2]) for row in cur] # Get the shape points - shape_id = cur.execute('''SELECT shape_id - FROM trips WHERE trip_I=?''', (trip_I,)).fetchone()[0] + shape_id = cur.execute( + """SELECT shape_id + FROM trips WHERE trip_I=?""", + (trip_I,), + ).fetchone()[0] shape_points = shapes.get_shape_points(cur, shape_id) - breakpoints, badness \ - = shapes.find_segments(stop_points, shape_points) + breakpoints, badness = shapes.find_segments(stop_points, shape_points) print(badness) if badness > 30: print("bad shape fit: %s (%s, %s)" % (badness, trip_I, shape_id)) - for b1, b2 in zip(breakpoints, sorted(breakpoints)): self.assertEqual(b1, b2) @@ -65,7 +63,7 @@ def test_interpolate_shape_times(): shape_distances = [0, 2, 5, 10, 20, 100] shape_breaks = [0, 2, 5] stop_times = [0, 1, 20] - result_should_be = [0, 0.4, 1, 1 + (19 * 5 / 95.), 1 + (19 * 15 / 95.), 20] + result_should_be = [0, 0.4, 1, 1 + (19 * 5 / 95.0), 1 + (19 * 15 / 95.0), 20] result = shapes.interpolate_shape_times(shape_distances, shape_breaks, stop_times) assert len(result) == len(result_should_be) diff --git a/gtfspy/test/test_spreading.py b/gtfspy/test/test_spreading.py index 3b5d6de..fdace8b 100644 --- a/gtfspy/test/test_spreading.py +++ b/gtfspy/test/test_spreading.py @@ -9,13 +9,12 @@ class SpreadingTest(unittest.TestCase): - @staticmethod def test_get_min_visit_time(): stop_I = 1 min_transfer_time = 60 ss = SpreadingStop(stop_I, min_transfer_time) - assert ss.get_min_visit_time() == float('inf') + assert ss.get_min_visit_time() == float("inf") ss.visit_events = [Event(10, 0, stop_I, stop_I, -1)] assert ss.get_min_visit_time() == 10 ss.visit_events.append(Event(5, 0, stop_I, stop_I, -1)) @@ -74,5 +73,3 @@ def test_get_trips(): keys = "lats lons times route_type name".split() for key in keys: assert key in el, el - - diff --git a/gtfspy/test/test_stats.py b/gtfspy/test/test_stats.py index 42c2301..fe03588 100644 --- a/gtfspy/test/test_stats.py +++ b/gtfspy/test/test_stats.py @@ -9,7 +9,6 @@ class StatsTest(unittest.TestCase): - @classmethod def setUpClass(cls): """ This method is run once before executing any tests""" @@ -21,11 +20,11 @@ def setUp(self): self.gtfs = GTFS.from_directory_as_inmemory_db(self.gtfs_source_dir) def test_write_stats_as_csv(self): - testfile = temp.NamedTemporaryFile(mode='w+b') + testfile = temp.NamedTemporaryFile(mode="w+b") stats.write_stats_as_csv(self.gtfs, testfile.name) df = pd.read_csv(testfile.name) - print('len is ' + str(len(df))) + print("len is " + str(len(df))) self.assertEqual(len(df), 1) stats.write_stats_as_csv(self.gtfs, testfile.name) @@ -38,20 +37,25 @@ def test_get_stats(self): self.assertTrue(isinstance(d, dict)) def test_calc_and_store_stats(self): - self.gtfs.meta['stats_calc_at_ut'] = None + self.gtfs.meta["stats_calc_at_ut"] = None stats.update_stats(self.gtfs) self.assertTrue(isinstance(stats.get_stats(self.gtfs), dict)) - self.assertTrue(self.G.meta['stats_calc_at_ut'] is not None) + self.assertTrue(self.G.meta["stats_calc_at_ut"] is not None) def test_get_median_lat_lon_of_stops(self): lat, lon = stats.get_median_lat_lon_of_stops(self.gtfs) - self.assertTrue(lat != lon, "probably median lat and median lon should not be equal for any real data set") + self.assertTrue( + lat != lon, + "probably median lat and median lon should not be equal for any real data set", + ) self.assertTrue(isinstance(lat, float)) self.assertTrue(isinstance(lon, float)) def test_get_centroid_of_stops(self): lat, lon = stats.get_centroid_of_stops(self.gtfs) - self.assertTrue(lat != lon, "probably centroid lat and lon should not be equal for any real data set") + self.assertTrue( + lat != lon, "probably centroid lat and lon should not be equal for any real data set" + ) self.assertTrue(isinstance(lat, float)) self.assertTrue(isinstance(lon, float)) @@ -72,5 +76,3 @@ def test_hourly_frequencies(self): self.assertTrue(isinstance(df, pd.DataFrame)) self.assertTrue(isinstance(df, object)) self.assertTrue(len(df.columns), 4) - - diff --git a/gtfspy/test/test_timetable_validator.py b/gtfspy/test/test_timetable_validator.py index 10d2406..7e4edc1 100644 --- a/gtfspy/test/test_timetable_validator.py +++ b/gtfspy/test/test_timetable_validator.py @@ -4,8 +4,8 @@ from gtfspy.gtfs import GTFS from gtfspy.timetable_validator import TimetableValidator -class TestGTFSValidator(unittest.TestCase): +class TestGTFSValidator(unittest.TestCase): def setUp(self): self.gtfs_source_dir = os.path.join(os.path.dirname(__file__), "test_data") self.G = GTFS.from_directory_as_inmemory_db(self.gtfs_source_dir) @@ -15,5 +15,3 @@ def test_compiles(self): warnings = validator.validate_and_get_warnings() warning_counts = warnings.get_warning_counter() assert len(warning_counts) > 0 - - diff --git a/gtfspy/test/test_util.py b/gtfspy/test/test_util.py index 966f8d0..7c397bb 100644 --- a/gtfspy/test/test_util.py +++ b/gtfspy/test/test_util.py @@ -6,7 +6,6 @@ class TestUtil(unittest.TestCase): - @staticmethod def _approximately_equal(a, b): return abs(a - b) / float(abs(a + b)) < 1e-2 @@ -37,7 +36,16 @@ def test_day_seconds_to_str_time(self): def test_txt_to_pandas(self): source_dir = os.path.join(os.path.dirname(__file__), "test_data") - txtnames = ['agency', 'routes', 'trips', 'calendar', 'calendar_dates', 'stop_times', 'stops', 'shapes'] + txtnames = [ + "agency", + "routes", + "trips", + "calendar", + "calendar_dates", + "stop_times", + "stops", + "shapes", + ] df = util.source_csv_to_pandas(source_dir, txtnames[3]) self.assertIsInstance(df, pd.DataFrame) source_zip = os.path.join(os.path.dirname(__file__), "test_data/test_gtfs.zip") @@ -50,4 +58,4 @@ def test_difference_of_pandas_dfs(self): df1 = pd.DataFrame(dict1) df2 = pd.DataFrame(dict2) df = util.difference_of_pandas_dfs(df1, df2, ["lat", "lon"]) - self.assertEqual(len(df.index), 2) \ No newline at end of file + self.assertEqual(len(df.index), 2) diff --git a/gtfspy/test/test_warnings_container.py b/gtfspy/test/test_warnings_container.py index a2b495b..c114cc8 100644 --- a/gtfspy/test/test_warnings_container.py +++ b/gtfspy/test/test_warnings_container.py @@ -5,7 +5,6 @@ class TestWarningsContainer(TestCase): - def test_summary_print(self): wc = WarningsContainer() wc.add_warning("DUMMY_WARNING", ["dummy1", "dummy2"], 2) @@ -22,4 +21,4 @@ def test_details_print(self): f = io.StringIO("") wc.write_details(output_stream=f) f.seek(0) - assert len(f.readlines()) > len(wc.get_warning_counter().keys()) + 1 \ No newline at end of file + assert len(f.readlines()) > len(wc.get_warning_counter().keys()) + 1 diff --git a/gtfspy/timetable_validator.py b/gtfspy/timetable_validator.py index 37fb358..d9ee87e 100644 --- a/gtfspy/timetable_validator.py +++ b/gtfspy/timetable_validator.py @@ -5,9 +5,9 @@ # (i.e. using the if __name__ == "__main__": part at the end of this file) from gtfspy.warnings_container import WarningsContainer -if __name__ == '__main__' and __package__ is None: +if __name__ == "__main__" and __package__ is None: # import gtfspy - __package__ = 'gtfspy' + __package__ = "gtfspy" # noqa: A001 from gtfspy import route_types @@ -15,17 +15,31 @@ from gtfspy.util import wgs84_distance -WARNING_5_OR_MORE_CONSECUTIVE_STOPS_WITH_SAME_TIME = "trip--arr_time -combinations with five or more consecutive stops having same stop time" +WARNING_5_OR_MORE_CONSECUTIVE_STOPS_WITH_SAME_TIME = ( + "trip--arr_time -combinations with five or more consecutive stops having same stop time" +) WARNING_LONG_TRIP_TIME = "Trip time longer than {MAX_TRIP_TIME} seconds" -WARNING_TRIP_UNREALISTIC_AVERAGE_SPEED = "trips whose average speed is unrealistic relative to travel mode" +WARNING_TRIP_UNREALISTIC_AVERAGE_SPEED = ( + "trips whose average speed is unrealistic relative to travel mode" +) MAX_ALLOWED_DISTANCE_BETWEEN_CONSECUTIVE_STOPS = 20000 # meters -WARNING_LONG_STOP_SPACING = "distance between consecutive stops longer than " + str(MAX_ALLOWED_DISTANCE_BETWEEN_CONSECUTIVE_STOPS) + " meters" +WARNING_LONG_STOP_SPACING = ( + "distance between consecutive stops longer than " + + str(MAX_ALLOWED_DISTANCE_BETWEEN_CONSECUTIVE_STOPS) + + " meters" +) MAX_TIME_BETWEEN_STOPS = 1800 # seconds -WARNING_LONG_TRAVEL_TIME_BETWEEN_STOPS = "trip--stop_times-combinations with travel time between consecutive stops longer than " + str(MAX_TIME_BETWEEN_STOPS / 60) + " minutes" +WARNING_LONG_TRAVEL_TIME_BETWEEN_STOPS = ( + "trip--stop_times-combinations with travel time between consecutive stops longer than " + + str(MAX_TIME_BETWEEN_STOPS / 60) + + " minutes" +) WARNING_STOP_SEQUENCE_ORDER_ERROR = "stop sequence is not in right order" -WARNING_STOP_SEQUENCE_NOT_INCREMENTAL = "stop sequences are not increasing always by one in stop_times" +WARNING_STOP_SEQUENCE_NOT_INCREMENTAL = ( + "stop sequences are not increasing always by one in stop_times" +) WARNING_STOP_FAR_AWAY_FROM_FILTER_BOUNDARY = "stop far away from spatial filter boundary" ALL_WARNINGS = { @@ -35,7 +49,7 @@ WARNING_TRIP_UNREALISTIC_AVERAGE_SPEED, WARNING_LONG_TRAVEL_TIME_BETWEEN_STOPS, WARNING_STOP_SEQUENCE_ORDER_ERROR, - WARNING_STOP_SEQUENCE_NOT_INCREMENTAL + WARNING_STOP_SEQUENCE_NOT_INCREMENTAL, } GTFS_TYPE_TO_MAX_SPEED = { @@ -47,12 +61,12 @@ route_types.CABLE_CAR: 50, route_types.GONDOLA: 50, route_types.FUNICULAR: 50, - route_types.AIRCRAFT: 1000 + route_types.AIRCRAFT: 1000, } MAX_TRIP_TIME = 7200 # seconds -class TimetableValidator(object): +class TimetableValidator(object): def __init__(self, gtfs, buffer_params=None): """ Parameters @@ -88,49 +102,57 @@ def validate_and_get_warnings(self): def _validate_misplaced_stops(self): if self.buffer_params: p = self.buffer_params - center_lat = p['lat'] - center_lon = p['lon'] - buffer_distance = p['buffer_distance'] * 1000 * 1.002 # some error margin for rounding + center_lat = p["lat"] + center_lon = p["lon"] + buffer_distance = p["buffer_distance"] * 1000 * 1.002 # some error margin for rounding for stop_row in self.gtfs.stops().itertuples(): - if buffer_distance < wgs84_distance(center_lat, center_lon, stop_row.lat, stop_row.lon): - self.warnings_container.add_warning(WARNING_STOP_FAR_AWAY_FROM_FILTER_BOUNDARY, stop_row) + if buffer_distance < wgs84_distance( + center_lat, center_lon, stop_row.lat, stop_row.lon + ): + self.warnings_container.add_warning( + WARNING_STOP_FAR_AWAY_FROM_FILTER_BOUNDARY, stop_row + ) print(WARNING_STOP_FAR_AWAY_FROM_FILTER_BOUNDARY, stop_row) def _validate_stops_with_same_stop_time(self): n_stops_with_same_time = 5 # this query returns the trips where there are N or more stops with the same stop time rows = self.gtfs.get_cursor().execute( - 'SELECT ' - 'trip_I, ' - 'arr_time, ' - 'N ' - 'FROM ' - '(SELECT trip_I, arr_time, count(*) as N FROM stop_times GROUP BY trip_I, arr_time) q1 ' - 'WHERE N >= ?', (n_stops_with_same_time,) + "SELECT " + "trip_I, " + "arr_time, " + "N " + "FROM " + "(SELECT trip_I, arr_time, count(*) as N FROM stop_times GROUP BY trip_I, arr_time) q1 " + "WHERE N >= ?", + (n_stops_with_same_time,), ) for row in rows: - self.warnings_container.add_warning(WARNING_5_OR_MORE_CONSECUTIVE_STOPS_WITH_SAME_TIME, row) + self.warnings_container.add_warning( + WARNING_5_OR_MORE_CONSECUTIVE_STOPS_WITH_SAME_TIME, row + ) def _validate_stop_spacings(self): self.gtfs.conn.create_function("find_distance", 4, wgs84_distance) # this query calculates distance and travel time between consecutive stops rows = self.gtfs.execute_custom_query( - 'SELECT ' - 'q1.trip_I, ' - 'type, ' - 'q1.stop_I as stop_1, ' - 'q2.stop_I as stop_2, ' - 'CAST(find_distance(q1.lat, q1.lon, q2.lat, q2.lon) AS INT) as distance, ' - 'q2.arr_time_ds - q1.arr_time_ds as traveltime ' - 'FROM ' - '(SELECT * FROM stop_times, stops WHERE stop_times.stop_I = stops.stop_I) q1, ' - '(SELECT * FROM stop_times, stops WHERE stop_times.stop_I = stops.stop_I) q2, ' - 'trips, ' - 'routes ' - 'WHERE q1.trip_I = q2.trip_I ' - 'AND q1.seq + 1 = q2.seq ' - 'AND q1.trip_I = trips.trip_I ' - 'AND trips.route_I = routes.route_I ').fetchall() + "SELECT " + "q1.trip_I, " + "type, " + "q1.stop_I as stop_1, " + "q2.stop_I as stop_2, " + "CAST(find_distance(q1.lat, q1.lon, q2.lat, q2.lon) AS INT) as distance, " + "q2.arr_time_ds - q1.arr_time_ds as traveltime " + "FROM " + "(SELECT * FROM stop_times, stops WHERE stop_times.stop_I = stops.stop_I) q1, " + "(SELECT * FROM stop_times, stops WHERE stop_times.stop_I = stops.stop_I) q2, " + "trips, " + "routes " + "WHERE q1.trip_I = q2.trip_I " + "AND q1.seq + 1 = q2.seq " + "AND q1.trip_I = trips.trip_I " + "AND trips.route_I = routes.route_I " + ).fetchall() for row in rows: if row[4] > MAX_ALLOWED_DISTANCE_BETWEEN_CONSECUTIVE_STOPS: self.warnings_container.add_warning(WARNING_LONG_STOP_SPACING, row) @@ -144,35 +166,42 @@ def _validate_speeds_and_trip_times(self): # this query returns the total distance and travel time for each trip calculated for each stop spacing separately rows = pandas.read_sql( - 'SELECT ' - 'q1.trip_I, ' - 'type, ' - 'sum(CAST(find_distance(q1.lat, q1.lon, q2.lat, q2.lon) AS INT)) AS total_distance, ' # sum used for getting total - 'sum(q2.arr_time_ds - q1.arr_time_ds) AS total_traveltime, ' # sum used for getting total - 'count(*)' # for getting the total number of stops - 'FROM ' - ' (SELECT * FROM stop_times, stops WHERE stop_times.stop_I = stops.stop_I) q1, ' - ' (SELECT * FROM stop_times, stops WHERE stop_times.stop_I = stops.stop_I) q2, ' - ' trips, ' - ' routes ' - 'WHERE q1.trip_I = q2.trip_I AND q1.seq + 1 = q2.seq AND q1.trip_I = trips.trip_I ' - 'AND trips.route_I = routes.route_I GROUP BY q1.trip_I', self.gtfs.conn) + "SELECT " + "q1.trip_I, " + "type, " + "sum(CAST(find_distance(q1.lat, q1.lon, q2.lat, q2.lon) AS INT)) AS total_distance, " # sum used for getting total + "sum(q2.arr_time_ds - q1.arr_time_ds) AS total_traveltime, " # sum used for getting total + "count(*)" # for getting the total number of stops + "FROM " + " (SELECT * FROM stop_times, stops WHERE stop_times.stop_I = stops.stop_I) q1, " + " (SELECT * FROM stop_times, stops WHERE stop_times.stop_I = stops.stop_I) q2, " + " trips, " + " routes " + "WHERE q1.trip_I = q2.trip_I AND q1.seq + 1 = q2.seq AND q1.trip_I = trips.trip_I " + "AND trips.route_I = routes.route_I GROUP BY q1.trip_I", + self.gtfs.conn, + ) for row in rows.itertuples(): avg_velocity_km_per_h = row.total_distance / max(row.total_traveltime, 1) * 3.6 if avg_velocity_km_per_h > GTFS_TYPE_TO_MAX_SPEED[row.type]: - self.warnings_container.add_warning(WARNING_TRIP_UNREALISTIC_AVERAGE_SPEED + " (route_type=" + str(row.type) + ")", - row + self.warnings_container.add_warning( + WARNING_TRIP_UNREALISTIC_AVERAGE_SPEED + " (route_type=" + str(row.type) + ")", + row, ) if row.total_traveltime > MAX_TRIP_TIME: - self.warnings_container.add_warning(WARNING_LONG_TRIP_TIME.format(MAX_TRIP_TIME=MAX_TRIP_TIME), row, 1) + self.warnings_container.add_warning( + WARNING_LONG_TRIP_TIME.format(MAX_TRIP_TIME=MAX_TRIP_TIME), row, 1 + ) def _validate_stop_sequence(self): # This function checks if the seq values in stop_times are increasing with departure_time, # and that seq always increases by one. - rows = self.gtfs.execute_custom_query('SELECT trip_I, dep_time_ds, seq ' - 'FROM stop_times ' - 'ORDER BY trip_I, dep_time_ds, seq').fetchall() + rows = self.gtfs.execute_custom_query( + "SELECT trip_I, dep_time_ds, seq " + "FROM stop_times " + "ORDER BY trip_I, dep_time_ds, seq" + ).fetchall() old_trip_id = None old_seq = None for row in rows: @@ -195,6 +224,6 @@ def main(): warningscontainer = validator.validate_and_get_warnings() warningscontainer.write_summary() + if __name__ == "__main__": main() - diff --git a/gtfspy/util.py b/gtfspy/util.py index 0ae0219..a6cbe72 100644 --- a/gtfspy/util.py +++ b/gtfspy/util.py @@ -4,11 +4,11 @@ import io import math import os -import zipfile import shutil import sys import tempfile import time +import zipfile from math import cos import networkx @@ -25,8 +25,8 @@ current_umask = os.umask(0) os.umask(current_umask) -TORADIANS = 3.141592653589793 / 180. -EARTH_RADIUS = 6378137. +TORADIANS = 3.141592653589793 / 180.0 +EARTH_RADIUS = 6378137.0 def set_process_timezone(TZ): @@ -36,12 +36,12 @@ def set_process_timezone(TZ): TZ: string """ try: - prev_timezone = os.environ['TZ'] + prev_timezone = os.environ["TZ"] except KeyError: prev_timezone = None - os.environ['TZ'] = TZ + os.environ["TZ"] = TZ - if sys.platform == 'win32': # tzset() does not work on Windows + if sys.platform == "win32": # tzset() does not work on Windows system_time = SystemTime() lpSystemTime = ctypes.pointer(system_time) ctypes.windll.kernel32.GetLocalTime(lpSystemTime) @@ -53,23 +53,24 @@ def set_process_timezone(TZ): class SystemTime(ctypes.Structure): _fields_ = [ - ('wYear', ctypes.c_int16), - ('wMonth', ctypes.c_int16), - ('wDayOfWeek', ctypes.c_int16), - ('wDay', ctypes.c_int16), - ('wHour', ctypes.c_int16), - ('wMinute', ctypes.c_int16), - ('wSecond', ctypes.c_int16), - ('wMilliseconds', ctypes.c_int16)] + ("wYear", ctypes.c_int16), + ("wMonth", ctypes.c_int16), + ("wDayOfWeek", ctypes.c_int16), + ("wDay", ctypes.c_int16), + ("wHour", ctypes.c_int16), + ("wMinute", ctypes.c_int16), + ("wSecond", ctypes.c_int16), + ("wMilliseconds", ctypes.c_int16), + ] def wgs84_distance(lat1, lon1, lat2, lon2): """Distance (in meters) between two points in WGS84 coord system.""" dLat = math.radians(lat2 - lat1) dLon = math.radians(lon2 - lon1) - a = (math.sin(dLat / 2) * math.sin(dLat / 2) + - math.cos(math.radians(lat1)) * math.cos(math.radians(lat2)) * - math.sin(dLon / 2) * math.sin(dLon / 2)) + a = math.sin(dLat / 2) * math.sin(dLat / 2) + math.cos(math.radians(lat1)) * math.cos( + math.radians(lat2) + ) * math.sin(dLon / 2) * math.sin(dLon / 2) c = 2 * math.atan2(math.sqrt(a), math.sqrt(1 - a)) d = EARTH_RADIUS * c return d @@ -86,19 +87,15 @@ def wgs84_width(meters, lat): # cython implementation of this. It is called _often_. try: - from gtfspy.cutil import wgs84_distance + from gtfspy.cutil import wgs84_distance # type: ignore except ImportError: pass -possible_tmpdirs = [ - '/tmp', - '' -] +possible_tmpdirs = ["/tmp", ""] @contextlib.contextmanager -def create_file(fname=None, fname_tmp=None, tmpdir=None, - save_tmpfile=False, keepext=False): +def create_file(fname=None, fname_tmp=None, tmpdir=None, save_tmpfile=False, keepext=False): """Context manager for making files with possibility of failure. If you are creating a file, it is possible that the code will fail @@ -135,7 +132,7 @@ def create_file(fname=None, fname_tmp=None, tmpdir=None, Re-raises any except occuring during the context block. """ # Do nothing if requesting sqlite memory DB. - if fname == ':memory:': + if fname == ":memory:": yield fname return if fname_tmp is None: @@ -147,7 +144,7 @@ def create_file(fname=None, fname_tmp=None, tmpdir=None, # automatic things itself. if not keepext: root = root + ext - ext = '' + ext = "" if tmpdir: # we should use a different temporary directory if tmpdir is True: @@ -161,11 +158,12 @@ def create_file(fname=None, fname_tmp=None, tmpdir=None, # extension. Set it to not delete automatically, since on # success we will move it to elsewhere. tmpfile = tempfile.NamedTemporaryFile( - prefix='tmp-' + root + '-', suffix=ext, dir=dir_, delete=False) + prefix="tmp-" + root + "-", suffix=ext, dir=dir_, delete=False + ) fname_tmp = tmpfile.name try: yield fname_tmp - except Exception as e: + except Exception: if save_tmpfile: print("Temporary file is '%s'" % fname_tmp) else: @@ -181,10 +179,11 @@ def create_file(fname=None, fname_tmp=None, tmpdir=None, # filesystems. So, we have to fallback to moving it. But, we # want to move it using tmpfiles also, so that the final file # appearing is atomic. We use... tmpfiles. - except OSError as e: + except OSError: # New temporary file in same directory tmpfile2 = tempfile.NamedTemporaryFile( - prefix='tmp-' + root + '-', suffix=ext, dir=this_dir, delete=False) + prefix="tmp-" + root + "-", suffix=ext, dir=this_dir, delete=False + ) # Copy contents over shutil.copy(fname_tmp, tmpfile2.name) # Rename new tmpfile, unlink old one on other filesystem. @@ -204,7 +203,7 @@ def execute(cur, *args): """ stmt = args[0] if len(args) > 1: - stmt = stmt.replace('%', '%%').replace('?', '%r') + stmt = stmt.replace("%", "%%").replace("?", "%r") print(stmt % (args[1])) return cur.execute(*args) @@ -212,7 +211,7 @@ def execute(cur, *args): def to_date_string(date): if isinstance(date, numpy.int64) or isinstance(date, int): date = str(date) - date = '%s-%s-%s' % (date[:4], date[4:6], date[6:8]) + date = "%s-%s-%s" % (date[:4], date[4:6], date[6:8]) return date @@ -227,7 +226,7 @@ def str_time_to_day_seconds(time): :param time: %H:%M:%S string :return: integer seconds """ - t = str(time).split(':') + t = str(time).split(":") seconds = int(t[0]) * 3600 + int(t[1]) * 60 + int(t[2]) return seconds @@ -263,7 +262,7 @@ def timed(*args, **kw): time_start = time.time() result = method(*args, **kw) time_end = time.time() - print('timeit: %r %2.2f sec ' % (method.__name__, time_end - time_start)) + print("timeit: %r %2.2f sec " % (method.__name__, time_end - time_start)) return result return timed @@ -271,8 +270,9 @@ def timed(*args, **kw): def corrupted_zip(zip_path): import zipfile + try: - zip_to_test = zipfile.ZipFile(zip_path) + zipfile.ZipFile(zip_path) # warning = zip_to_test.testzip() # if warning is not None: # return str(warning) @@ -297,8 +297,8 @@ def source_csv_to_pandas(path, table, read_csv_args=None): ------- df: pandas:DataFrame """ - if '.txt' not in table: - table += '.txt' + if ".txt" not in table: + table += ".txt" if isinstance(path, dict): data_obj = path[table] @@ -316,7 +316,7 @@ def source_csv_to_pandas(path, table, read_csv_args=None): break try: f = zip_open(z, table) - except KeyError as e: + except KeyError: return pd.DataFrame() if read_csv_args: @@ -328,6 +328,7 @@ def source_csv_to_pandas(path, table, read_csv_args=None): def write_shapefile(data, shapefile_path): from numpy import int64 + """ :param data: list of dicts where dictionary contains the keys lons and lats :param shapefile_path: path where shapefile is saved @@ -337,7 +338,6 @@ def write_shapefile(data, shapefile_path): w = shp.Writer(shp.POLYLINE) # shapeType=3) fields = [] - encode_strings = [] # This makes sure every geom has all the attributes w.autoBalance = 1 @@ -345,28 +345,27 @@ def write_shapefile(data, shapefile_path): # datastoring phase. Encode_strings stores .encode methods as strings for all fields that are strings if not fields: for key, value in data[0].items(): - if key != u'lats' and key != u'lons': + if key != "lats" and key != "lons": fields.append(key) if type(value) == float: - w.field(key.encode('ascii'), fieldType='N', size=11, decimal=3) + w.field(key.encode("ascii"), fieldType="N", size=11, decimal=3) print("float", type(value)) elif type(value) == int or type(value) == int64: print("int", type(value)) # encode_strings.append(".encode('ascii')") - w.field(key.encode('ascii'), fieldType='N', size=6, decimal=0) + w.field(key.encode("ascii"), fieldType="N", size=6, decimal=0) else: print("other type", type(value)) - w.field(key.encode('ascii')) + w.field(key.encode("ascii")) for dict_item in data: line = [] lineparts = [] - records = [] - records_string = '' - for lat, lon in zip(dict_item[u'lats'], dict_item[u'lons']): + records_string = "" + for lat, lon in zip(dict_item["lats"], dict_item["lons"]): line.append([float(lon), float(lat)]) lineparts.append(line) w.line(parts=lineparts) @@ -388,9 +387,9 @@ def write_shapefile(data, shapefile_path): # Opening files with Universal newlines is done differently in py3 def zip_open(z, filename): if sys.version_info[0] == 2: - return z.open(filename, 'rU') + return z.open(filename, "rU") else: - return io.TextIOWrapper(z.open(filename, 'r'), "utf-8") + return io.TextIOWrapper(z.open(filename, "r"), "utf-8") def draw_net_using_node_coords(net): @@ -405,10 +404,11 @@ def draw_net_using_node_coords(net): the figure object where the network is plotted """ import matplotlib.pyplot as plt + fig = plt.figure() node_coords = {} for node, data in net.nodes(data=True): - node_coords[node] = (data['lon'], data['lat']) + node_coords[node] = (data["lon"], data["lat"]) ax = fig.add_subplot(111) networkx.draw(net, pos=node_coords, ax=ax, node_size=50) return fig @@ -417,12 +417,14 @@ def draw_net_using_node_coords(net): def make_sure_path_exists(path): import os import errno + try: os.makedirs(path) except OSError as exception: if exception.errno != errno.EEXIST: raise + def difference_of_pandas_dfs(df_self, df_other, col_names=None): """ Returns a dataframe with all of df_other that are not in df_self, when considering the columns specified in col_names diff --git a/gtfspy/warnings_container.py b/gtfspy/warnings_container.py index f1eb190..57438a8 100644 --- a/gtfspy/warnings_container.py +++ b/gtfspy/warnings_container.py @@ -3,7 +3,6 @@ class WarningsContainer(object): - def __init__(self): self._warnings_counter = Counter() # key: "warning type" string, value: "number of errors" int @@ -22,14 +21,14 @@ def add_warning(self, warning, reason, count=None): def write_summary(self, output_stream=None): if output_stream is None: output_stream = sys.stdout - output_stream.write('The feed produced the following warnings:\n') + output_stream.write("The feed produced the following warnings:\n") for warning, count in self._warnings_counter.most_common(): output_stream.write(warning + ": " + str(count) + "\n") def write_details(self, output_stream=None): if output_stream is None: output_stream = sys.stdout - output_stream.write('The feed produced the following warnings (with details):\n') + output_stream.write("The feed produced the following warnings (with details):\n") for warning, count in self._warnings_counter.most_common(): output_stream.write(warning + ": " + str(count) + "\n") for reason in self._warnings_records[warning]: @@ -54,4 +53,4 @@ def get_warnings_by_query_rows(self): def clear(self): self._warnings_counter.clear() - self._warnings_records.clear() \ No newline at end of file + self._warnings_records.clear() diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..91e1728 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,29 @@ +# Example configuration for Black. + +# NOTE: you have to use single-quoted strings in TOML for regular expressions. +# It's the equivalent of r-strings in Python. Multiline strings are treated as +# verbose regular expressions by Black. Use [ ] to denote a significant space +# character. + +[tool.black] +line-length = 100 +target-version = ["py35"] +include = '\.pyi?$' +exclude = ''' +/( + \.git + | \.hg + | \.mypy_cache + | \.tox + | \.venv + | venv + | _build + | buck-out + | build + | dist + # The following are specific to Black, you probably don't want those. + | blib2to3 + | tests/data + | profiling +)/ +''' \ No newline at end of file diff --git a/setup.py b/setup.py index 90132c5..a1c70c3 100644 --- a/setup.py +++ b/setup.py @@ -1,6 +1,6 @@ from setuptools import setup, Extension, find_packages -version="0.0.4" +version = "0.0.4" setup( name="gtfspy", @@ -10,48 +10,47 @@ packages=find_packages(exclude=["java_routing", "examples"]), author="Rainer Kujala", author_email="Rainer.Kujala@gmail.com", - license='MIT', + license="MIT", classifiers=[ # How mature is this project? Common values are # 3 - Alpha # 4 - Beta # 5 - Production/Stable - 'Development Status :: 3 - Alpha', - + "Development Status :: 3 - Alpha", # Indicate who your project is intended for - 'Intended Audience :: Developers', - 'Intended Audience :: Science/Research', - 'Topic :: Scientific/Engineering :: GIS', - + "Intended Audience :: Developers", + "Intended Audience :: Science/Research", + "Topic :: Scientific/Engineering :: GIS", # Pick your license as you wish (should match "license" above) - 'License :: OSI Approved :: MIT License', - + "License :: OSI Approved :: MIT License", # Specify the Python versions you support here. In particular, ensure # that you indicate whether you support Python 2, Python 3 or both. - 'Programming Language :: Python :: 3', - 'Programming Language :: Python :: 3.5' - ], - install_requires = [ - "setuptools>=18.0", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.6", + "Programming Language :: Python :: 3.7", + "Programming Language :: Python :: 3.8", + ], + setup_requires=["setuptools>=18.0", "cython"], + install_requires=[ "pandas", "networkx==1.11", "pyshp", "smopy", "nose", - "Cython", "six", "geoindex", "osmread==0.2", "shapely", "geojson>=2.0.0", "pyproj", - "matplotlib-scalebar==0.6.1" - ], - ext_modules=[ - Extension( - 'gtfspy.routing.label', - sources=["gtfspy/routing/label.pyx"], - ), + "matplotlib-scalebar==0.6.1", ], - keywords = ['transit', 'routing' 'gtfs', 'public transport', 'analysis', 'visualization'], # arbitrary keywords + ext_modules=[Extension("gtfspy.routing.label", sources=["gtfspy/routing/label.pyx"])], + keywords=[ + "transit", + "routing" "gtfs", + "public transport", + "analysis", + "visualization", + ], # arbitrary keywords )